"Constrain index and length to integral tensors."));
ONNX_PYTORCH_OPERATOR_SET_SCHEMA(
+ // ONNX-domain schema mirroring Caffe2's SparseLengthsWeightedSum (see
+ // SetDoc below). Inputs DATA and WEIGHTS are floating-point (T1);
+ // INDICES and LENGTHS are integral (T2); the single output matches T1.
+ SparseLengthsWeightedSum,
+ 1,
+ OpSchema()
+ .SetDoc("Mirror Caffe2 SparseLengthsWeightedSum operator")
+ .Input(0, "DATA", "data tensor", "T1")
+ // Fixed copy-paste error: input 1 is WEIGHTS, not a second data tensor.
+ .Input(1, "WEIGHTS", "weights tensor", "T1")
+ .Input(2, "INDICES", "indices tensor", "T2")
+ .Input(3, "LENGTHS", "lengths tensor", "T2")
+ .Output(0, "output", "Output tensor", "T1")
+ .TypeConstraint(
+ "T1",
+ {"tensor(float16)", "tensor(float)", "tensor(double)"},
+ "Constrain input and output types to float tensors.")
+ .TypeConstraint(
+ "T2",
+ {"tensor(int8)",
+ "tensor(int16)",
+ "tensor(int32)",
+ "tensor(int64)",
+ "tensor(uint8)",
+ "tensor(uint16)",
+ "tensor(uint32)",
+ "tensor(uint64)"},
+ "Constrain index and length to integral tensors."));
+
+ONNX_PYTORCH_OPERATOR_SET_SCHEMA(
BatchGather,
1,
OpSchema()
1,
SparseLengthsSumFused8BitRowwise);
+// Forward declarations of the generated schema classes for the PyTorch
+// operator set, version 1. The SparseLengthsWeightedSum entry is added
+// alongside the existing operators so it can be registered below.
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, SparseLengthsSum);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, SparseLengthsWeightedSum);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, BatchGather);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, DotProduct);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, FCTransposed);
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
PyTorch, 1, SparseLengthsSum)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
+ PyTorch, 1, SparseLengthsWeightedSum)>());
+ fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
PyTorch, 1, BatchGather)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
PyTorch, 1, DotProduct)>());
SparseLengthsWeightedSumOp::WEIGHT)
.SetDoc(FormatDoc<SparseLengthsWeightedSumDef>())
.Output(0, "OUTPUT", "Aggregated tensor")
- .FillUsing(SparseLengthsWeightedSumDef::PopulateSchema);
+ .FillUsing(SparseLengthsWeightedSumDef::PopulateSchema)
+ .InheritOnnxSchema();
REGISTER_CPU_OPERATOR(
SparseLengthsWeightedSumGradient,
SparseLengthsWeightedSumDef::BackwardOp);