Revert D14635130: Improved onnx export for 3 onnx ops.
authorJunjie Bai <jbai@fb.com>
Thu, 28 Mar 2019 17:18:46 +0000 (10:18 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 28 Mar 2019 17:26:34 +0000 (10:26 -0700)
Differential Revision:
D14635130

Original commit changeset: d54a2b6e2950

fbshipit-source-id: f624e2befdde245cb88435a95508b2a8e6b12e61

caffe2/onnx/backend.cc
caffe2/onnx/backend.h
caffe2/python/onnx/tests/onnx_backend_test.py
torch/onnx/symbolic.py

index 3564ebe..e7c512a 100644 (file)
@@ -362,8 +362,7 @@ Caffe2Backend::get_special_operators() const {
               {"Dropout", &Caffe2Backend::CreateDropout},
               {"LRN", &Caffe2Backend::CreateLRN},
               {"DynamicSlice", &Caffe2Backend::CreateDynamicSlice},
-              {"RandomNormal", &Caffe2Backend::CreateRandomNormal},
-              {"Where", &Caffe2Backend::CreateWhereOp}};
+              {"RandomNormal", &Caffe2Backend::CreateRandomNormal}};
   return kSpecialOperators;
 }
 
@@ -581,21 +580,6 @@ Caffe2Ops Caffe2Backend::CreateRandomNormal(
   return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
 }
 
-Caffe2Ops Caffe2Backend::CreateWhereOp(
-    OnnxNode* onnx_node,
-    const ConversionContext& ctx) {
-  // The native Caffe2 op doesn't support broadcasting, so we defer the handling
-  // of this op to the ATen library that does.
-  onnx::NodeProto converted;
-  converted.CopyFrom(onnx_node->node);
-  converted.set_op_type("ATen");
-  onnx::AttributeProto* attr = converted.add_attribute();
-  attr->set_name("operator");
-  attr->set_s("where");
-  OnnxNode new_node(converted);
-  return CommonOnnxNodeToCaffe2Ops(&new_node, ctx);
-}
-
 Caffe2Ops Caffe2Backend::CreateReciprocal(
     OnnxNode* onnx_node,
     const ConversionContext& ctx) {
index 8ee33ef..d61af29 100644 (file)
@@ -236,8 +236,6 @@ class CAFFE2_API Caffe2Backend {
       OnnxNode* onnx_node,
       const ConversionContext& ctx);
 
-  Caffe2Ops CreateWhereOp(OnnxNode* onnx_node, const ConversionContext& ctx);
-
   Caffe2Ops CreateBatchNormalization(
       OnnxNode* onnx_node,
       const ConversionContext& ctx);
index f353e22..75d4b5a 100644 (file)
@@ -52,6 +52,7 @@ backend_test.exclude(r'(test_hardsigmoid'  # Does not support Hardsigmoid.
                      '|test_isnan.*'  # Needs implementation
                      '|test_scatter.*'  # Should be similar to ScatterAssign
                      '|test_constantofshape_int.*'  # Needs implementation
+                     '|test_where.*'  # Needs implementation
                      '|test_shrink.*'  # Needs implementation
                      '|test_strnorm.*'  # Needs implementation
                      '|test_nonzero.*'  # Needs implementation
index 9a1911f..fbb8d97 100644 (file)
@@ -548,14 +548,6 @@ def relu(g, input):
     return g.op("Relu", input)
 
 
-def ceil(g, input):
-    return g.op("Ceil", input)
-
-
-def floor(g, input):
-    return g.op("Floor", input)
-
-
 @parse_args('v', 't', 't')
 def threshold(g, self, threshold, value):
     # See Note [Export inplace]
@@ -930,7 +922,7 @@ def le(g, input, other):
 
 
 def where(g, condition, self, other):
-    return g.op("Where", condition, self, other)
+    return g.op("ATen", condition, self, other, operator_s="where")
 
 
 @parse_args('v', 'i', 'i')