{"Dropout", &Caffe2Backend::CreateDropout},
{"LRN", &Caffe2Backend::CreateLRN},
{"DynamicSlice", &Caffe2Backend::CreateDynamicSlice},
- {"RandomNormal", &Caffe2Backend::CreateRandomNormal},
- {"Where", &Caffe2Backend::CreateWhereOp}};
+ {"RandomNormal", &Caffe2Backend::CreateRandomNormal}};
return kSpecialOperators;
}
return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
}
-Caffe2Ops Caffe2Backend::CreateWhereOp(
- OnnxNode* onnx_node,
- const ConversionContext& ctx) {
- // The native Caffe2 op doesn't support broadcasting, so we defer the handling
- // of this op to the ATen library that does.
- onnx::NodeProto converted;
- converted.CopyFrom(onnx_node->node);
- converted.set_op_type("ATen");
- onnx::AttributeProto* attr = converted.add_attribute();
- attr->set_name("operator");
- attr->set_s("where");
- OnnxNode new_node(converted);
- return CommonOnnxNodeToCaffe2Ops(&new_node, ctx);
-}
-
Caffe2Ops Caffe2Backend::CreateReciprocal(
OnnxNode* onnx_node,
const ConversionContext& ctx) {
OnnxNode* onnx_node,
const ConversionContext& ctx);
- Caffe2Ops CreateWhereOp(OnnxNode* onnx_node, const ConversionContext& ctx);
-
Caffe2Ops CreateBatchNormalization(
OnnxNode* onnx_node,
const ConversionContext& ctx);
'|test_isnan.*' # Needs implementation
'|test_scatter.*' # Should be similar to ScatterAssign
'|test_constantofshape_int.*' # Needs implementation
+ '|test_where.*' # Needs implementation
'|test_shrink.*' # Needs implementation
'|test_strnorm.*' # Needs implementation
'|test_nonzero.*' # Needs implementation
return g.op("Relu", input)
-def ceil(g, input):
- return g.op("Ceil", input)
-
-
-def floor(g, input):
- return g.op("Floor", input)
-
-
@parse_args('v', 't', 't')
def threshold(g, self, threshold, value):
# See Note [Export inplace]
def where(g, condition, self, other):
- return g.op("Where", condition, self, other)
+ return g.op("ATen", condition, self, other, operator_s="where")
@parse_args('v', 'i', 'i')