Add onnxifi support to SparseLengthsWeightedSum (#14210)
author: Yinghai Lu <yinghai@fb.com>
Wed, 21 Nov 2018 23:43:10 +0000 (15:43 -0800)
committer: Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 21 Nov 2018 23:47:24 +0000 (15:47 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14210

We had left `SparseLengthsWeightedSum` unsupported because the benchmark was not exercising it due to an fp16 filler issue. The gap was flushed out by unit tests, so we add the support here.

Reviewed By: bddppq

Differential Revision: D13132320

fbshipit-source-id: b21c30c185c9e1fbf3980641bc3cdc39e85af2e1

caffe2/onnx/torch_ops/defs.cc
caffe2/onnx/torch_ops/operator_sets.h
caffe2/operators/lengths_reducer_ops.cc

index ff02d9b..24e6493 100644 (file)
@@ -56,6 +56,32 @@ ONNX_PYTORCH_OPERATOR_SET_SCHEMA(
             "Constrain index and length to integral tensors."));
 
 ONNX_PYTORCH_OPERATOR_SET_SCHEMA(
+    SparseLengthsWeightedSum,
+    1,
+    OpSchema()
+        .SetDoc("Mirror Caffe2 SparseLengthsWeightedSum operator")
+        .Input(0, "DATA", "data tensor", "T1")
+        .Input(1, "WEIGHTS", "data tensor", "T1")
+        .Input(2, "INDICES", "indices tensor", "T2")
+        .Input(3, "LENGTHS", "lengths tensor", "T2")
+        .Output(0, "output", "Output tensor", "T1")
+        .TypeConstraint(
+            "T1",
+            {"tensor(float16)", "tensor(float)", "tensor(double)"},
+            "Constrain input and output types to float tensors.")
+        .TypeConstraint(
+            "T2",
+            {"tensor(int8)",
+             "tensor(int16)",
+             "tensor(int32)",
+             "tensor(int64)",
+             "tensor(uint8)",
+             "tensor(uint16)",
+             "tensor(uint32)",
+             "tensor(uint64)"},
+            "Constrain index and length to integral tensors."));
+
+ONNX_PYTORCH_OPERATOR_SET_SCHEMA(
     BatchGather,
     1,
     OpSchema()
index cee3198..f7380af 100644 (file)
@@ -9,6 +9,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
     1,
     SparseLengthsSumFused8BitRowwise);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, SparseLengthsSum);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, SparseLengthsWeightedSum);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, BatchGather);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, DotProduct);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, FCTransposed);
@@ -24,6 +25,8 @@ class OpSet_PyTorch_ver1 {
     fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
            PyTorch, 1, SparseLengthsSum)>());
     fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
+           PyTorch, 1, SparseLengthsWeightedSum)>());
+    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
            PyTorch, 1, BatchGather)>());
     fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
            PyTorch, 1, DotProduct)>());
index d4c8df6..85513dc 100644 (file)
@@ -106,7 +106,8 @@ OPERATOR_SCHEMA(SparseLengthsWeightedSum)
         SparseLengthsWeightedSumOp::WEIGHT)
     .SetDoc(FormatDoc<SparseLengthsWeightedSumDef>())
     .Output(0, "OUTPUT", "Aggregated tensor")
-    .FillUsing(SparseLengthsWeightedSumDef::PopulateSchema);
+    .FillUsing(SparseLengthsWeightedSumDef::PopulateSchema)
+    .InheritOnnxSchema();
 REGISTER_CPU_OPERATOR(
     SparseLengthsWeightedSumGradient,
     SparseLengthsWeightedSumDef::BackwardOp);