{"Dropout", &Caffe2Backend::CreateDropout},
{"LRN", &Caffe2Backend::CreateLRN},
{"DynamicSlice", &Caffe2Backend::CreateDynamicSlice},
- {"RandomNormal", &Caffe2Backend::CreateRandomNormal}};
+ {"RandomNormal", &Caffe2Backend::CreateRandomNormal},
+ {"Where", &Caffe2Backend::CreateWhereOp}};
return kSpecialOperators;
}
return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
}
+Caffe2Ops Caffe2Backend::CreateWhereOp(
+ OnnxNode* onnx_node,
+ const ConversionContext& ctx) {
+ // The native Caffe2 op doesn't support broadcasting, so we defer the handling
+ // of this op to the ATen library that does.
+ onnx::NodeProto converted;
+ converted.CopyFrom(onnx_node->node);
+ converted.set_op_type("ATen");
+ onnx::AttributeProto* attr = converted.add_attribute();
+ attr->set_name("operator");
+ attr->set_s("where");
+ OnnxNode new_node(converted);
+ return CommonOnnxNodeToCaffe2Ops(&new_node, ctx);
+}
+
Caffe2Ops Caffe2Backend::CreateReciprocal(
OnnxNode* onnx_node,
const ConversionContext& ctx) {
OnnxNode* onnx_node,
const ConversionContext& ctx);
+ Caffe2Ops CreateWhereOp(OnnxNode* onnx_node, const ConversionContext& ctx);
+
Caffe2Ops CreateBatchNormalization(
OnnxNode* onnx_node,
const ConversionContext& ctx);
'|test_isnan.*' # Needs implementation
'|test_scatter.*' # Should be similar to ScatterAssign
'|test_constantofshape_int.*' # Needs implementation
- '|test_where.*' # Needs implementation
'|test_shrink.*' # Needs implementation
'|test_strnorm.*' # Needs implementation
'|test_nonzero.*' # Needs implementation
return g.op("Relu", input)
+def ceil(g, input):
+ return g.op("Ceil", input)
+
+
+def floor(g, input):
+ return g.op("Floor", input)
+
+
@parse_args('v', 't', 't')
def threshold(g, self, threshold, value):
# See Note [Export inplace]
def where(g, condition, self, other):
- return g.op("ATen", condition, self, other, operator_s="where")
+ return g.op("Where", condition, self, other)
@parse_args('v', 'i', 'i')