From 700271d0e92c33a4cce3fee34b39ad5baa179b00 Mon Sep 17 00:00:00 2001 From: Spandan Tiwari Date: Mon, 17 Dec 2018 13:45:21 -0800 Subject: [PATCH] Adding ONNX export for torch.expand and torch.ne (#15050) Summary: `torch.expand` and `torch.ne` are used often in models and this PR adds ONNX export support for them. ArmenAg has created issue https://github.com/pytorch/pytorch/issues/10882 for this. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15050 Differential Revision: D13453036 Pulled By: houseroad fbshipit-source-id: 4724b4ffcebda6cd6b2acac51d6733cb27318daf --- test/onnx/expect/TestOperators.test_expand.expect | 63 ++++++++++++++++++ test/onnx/expect/TestOperators.test_ne.expect | 80 +++++++++++++++++++++++ test/onnx/test_operators.py | 9 +++ test/test_jit.py | 21 ------ torch/csrc/jit/export.cpp | 4 -- torch/onnx/symbolic.py | 10 ++- 6 files changed, 160 insertions(+), 27 deletions(-) create mode 100644 test/onnx/expect/TestOperators.test_expand.expect create mode 100644 test/onnx/expect/TestOperators.test_ne.expect diff --git a/test/onnx/expect/TestOperators.test_expand.expect b/test/onnx/expect/TestOperators.test_expand.expect new file mode 100644 index 0000000..13b0715 --- /dev/null +++ b/test/onnx/expect/TestOperators.test_expand.expect @@ -0,0 +1,63 @@ +ir_version: 3 +producer_name: "pytorch" +producer_version: "0.4" +graph { + node { + output: "1" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 3 + data_type: 7 + raw_data: "\004\000\000\000\000\000\000\000\006\000\000\000\000\000\000\000\002\000\000\000\000\000\000\000" + } + type: TENSOR + } + } + node { + input: "0" + input: "1" + output: "2" + op_type: "Expand" + } + name: "torch-jit-export" + input { + name: "0" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 6 + } + dim { + dim_value: 1 + } + } + } + } + } + output { + name: "2" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 4 + } + dim { + dim_value: 6 + } + dim { + dim_value: 2 + } + } + } + } + } +} +opset_import { + version: 9 +} diff --git a/test/onnx/expect/TestOperators.test_ne.expect b/test/onnx/expect/TestOperators.test_ne.expect new file mode 100644 index 0000000..078204f --- /dev/null +++ b/test/onnx/expect/TestOperators.test_ne.expect @@ -0,0 +1,80 @@ +ir_version: 3 +producer_name: "pytorch" +producer_version: "0.4" +graph { + node { + input: "0" + input: "1" + output: "2" + op_type: "Equal" + } + node { + input: "2" + output: "3" + op_type: "Not" + } + name: "torch-jit-export" + input { + name: "0" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "1" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 4 + } + } + } + } + } + output { + name: "3" + type { + tensor_type { + elem_type: 2 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + } + } + } + } +} +opset_import { + version: 9 +} diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index d472e6a..4913585 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -487,6 +487,15 @@ class TestOperators(TestCase): x = torch.randn(6, 10, requires_grad=True) self.assertONNX(lambda x: torch.ones_like(x), x) + def test_expand(self): + x = torch.randn(6, 1, requires_grad=True) + self.assertONNX(lambda x: x.expand(4, 6, 2), x) + + def test_ne(self): + x = torch.randn(1, 2, 3, 1, requires_grad=False).int() + y = torch.randn(1, 4, requires_grad=False).int() + self.assertONNX(lambda x, y: torch.ne(x, y), (x, y)) + if __name__ == '__main__': no_onnx_dep_flag = '--no-onnx' _onnx_dep = no_onnx_dep_flag not in common.UNITTEST_ARGS diff --git a/test/test_jit.py b/test/test_jit.py index 77df56b..ac3a8c3 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -1722,27 +1722,6 @@ class TestJit(JitTestCase): test_x = torch.rand(6, 3) self.assertEqual(foo(test_x), traced(test_x)) - def test_export_expand_aten_fallback(self): - class ExpandTest(torch.jit.ScriptModule): - @torch.jit.script_method - def forward(self, x): - y = x - for i in range(5): - y = x.expand([3, 4, i]) - return y - - mod = ExpandTest() - example_outs = mod(torch.rand(3, 4, 1)) - f = io.BytesIO() - with self.assertRaisesRegex(RuntimeError, 'Could not export a broadcasted operation'): - torch.onnx.export_to_pretty_string(mod, (torch.rand(3, 4, 1),), f, verbose=False, - example_outputs=example_outs) - - self.assertExpected( - torch.onnx.export_to_pretty_string(mod, (torch.rand(3, 4, 1),), f, verbose=False, - example_outputs=example_outs, - operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)) - def test_export_dropout(self): test = torch.nn.Dropout() test.eval() diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp index b708214..f7207f7 100644 --- a/torch/csrc/jit/export.cpp +++ b/torch/csrc/jit/export.cpp @@ -69,10 +69,6 @@ void validateBlock(Block *b, onnx_torch::OperatorExportTypes operator_export_typ node->output(i)->replaceAllUsesWith(new_node->output(i)); } new_node->s_(Symbol::fromQualString("attr::operator"), "expand"); - } else { - FAIL_EXPORT( - "Could not export a broadcasted operation; ONNX likely does not support this form of broadcasting.\n\nBroadcast occurred at:\n" + - getNodeStackTraceString(node)); } } if (node->kind() == prim::PackPadded || node->kind() == prim::PadPacked) { diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index 303a183..e5e2536 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -346,9 +346,11 @@ def t(g, self): return g.op("Transpose", self, perm_i=(1, 0)) -# There is no translation for it, but we don't want to raise an error yet def expand(g, self, size, implicit): - return None + size = _maybe_get_const(size, 'is') + if not _is_value(size): + size = g.op("Constant", value_t=torch.LongTensor(size)) + return g.op("Expand", self, size) def expand_as(g, self, other): @@ -913,6 +915,10 @@ def eq(g, self, other): return g.op("Equal", self, other) +def ne(g, self, other): + return g.op("Not", eq(g, self, other)) + + def exp(g, self): return g.op("Exp", self) -- 2.7.4