From 77280b11e35c747c9ffc01c97509761aefcae37b Mon Sep 17 00:00:00 2001 From: Junjie Bai Date: Thu, 28 Mar 2019 10:18:46 -0700 Subject: [PATCH] Revert D14635130: Improved onnx export for 3 onnx ops. Differential Revision: D14635130 Original commit changeset: d54a2b6e2950 fbshipit-source-id: f624e2befdde245cb88435a95508b2a8e6b12e61 --- caffe2/onnx/backend.cc | 18 +----------------- caffe2/onnx/backend.h | 2 -- caffe2/python/onnx/tests/onnx_backend_test.py | 1 + torch/onnx/symbolic.py | 10 +--------- 4 files changed, 3 insertions(+), 28 deletions(-) diff --git a/caffe2/onnx/backend.cc b/caffe2/onnx/backend.cc index 3564ebe..e7c512a 100644 --- a/caffe2/onnx/backend.cc +++ b/caffe2/onnx/backend.cc @@ -362,8 +362,7 @@ Caffe2Backend::get_special_operators() const { {"Dropout", &Caffe2Backend::CreateDropout}, {"LRN", &Caffe2Backend::CreateLRN}, {"DynamicSlice", &Caffe2Backend::CreateDynamicSlice}, - {"RandomNormal", &Caffe2Backend::CreateRandomNormal}, - {"Where", &Caffe2Backend::CreateWhereOp}}; + {"RandomNormal", &Caffe2Backend::CreateRandomNormal}}; return kSpecialOperators; } @@ -581,21 +580,6 @@ Caffe2Ops Caffe2Backend::CreateRandomNormal( return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx); } -Caffe2Ops Caffe2Backend::CreateWhereOp( - OnnxNode* onnx_node, - const ConversionContext& ctx) { - // The native Caffe2 op doesn't support broadcasting, so we defer the handling - // of this op to the ATen library that does. - onnx::NodeProto converted; - converted.CopyFrom(onnx_node->node); - converted.set_op_type("ATen"); - onnx::AttributeProto* attr = converted.add_attribute(); - attr->set_name("operator"); - attr->set_s("where"); - OnnxNode new_node(converted); - return CommonOnnxNodeToCaffe2Ops(&new_node, ctx); -} - Caffe2Ops Caffe2Backend::CreateReciprocal( OnnxNode* onnx_node, const ConversionContext& ctx) { diff --git a/caffe2/onnx/backend.h b/caffe2/onnx/backend.h index 8ee33ef..d61af29 100644 --- a/caffe2/onnx/backend.h +++ b/caffe2/onnx/backend.h @@ -236,8 +236,6 @@ class CAFFE2_API Caffe2Backend { OnnxNode* onnx_node, const ConversionContext& ctx); - Caffe2Ops CreateWhereOp(OnnxNode* onnx_node, const ConversionContext& ctx); - Caffe2Ops CreateBatchNormalization( OnnxNode* onnx_node, const ConversionContext& ctx); diff --git a/caffe2/python/onnx/tests/onnx_backend_test.py b/caffe2/python/onnx/tests/onnx_backend_test.py index f353e22..75d4b5a 100644 --- a/caffe2/python/onnx/tests/onnx_backend_test.py +++ b/caffe2/python/onnx/tests/onnx_backend_test.py @@ -52,6 +52,7 @@ backend_test.exclude(r'(test_hardsigmoid' # Does not support Hardsigmoid. '|test_isnan.*' # Needs implementation '|test_scatter.*' # Should be similar to ScatterAssign '|test_constantofshape_int.*' # Needs implementation + '|test_where.*' # Needs implementation '|test_shrink.*' # Needs implementation '|test_strnorm.*' # Needs implementation '|test_nonzero.*' # Needs implementation diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index 9a1911f..fbb8d97 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -548,14 +548,6 @@ def relu(g, input): return g.op("Relu", input) -def ceil(g, input): - return g.op("Ceil", input) - - -def floor(g, input): - return g.op("Floor", input) - - @parse_args('v', 't', 't') def threshold(g, self, threshold, value): # See Note [Export inplace] @@ -930,7 +922,7 @@ def le(g, input, other): def where(g, condition, self, other): - return g.op("Where", condition, self, other) + return g.op("ATen", condition, self, other, operator_s="where") @parse_args('v', 'i', 'i') -- 2.7.4