From 3dba1285ab1c78b035f79493d25379ef36f53512 Mon Sep 17 00:00:00 2001 From: Lara Haidar Date: Wed, 6 Mar 2019 22:35:12 -0800 Subject: [PATCH] ONNX Export Narrow op Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17550 Differential Revision: D14350401 Pulled By: houseroad fbshipit-source-id: 4d88079bb7a8bbd270b0272009826eb3b202cc33 --- test/onnx/expect/TestOperators.test_narrow.expect | 61 +++++++++++++++++++++++ test/onnx/test_operators.py | 4 ++ test/onnx/test_pytorch_onnx_caffe2.py | 8 +++ torch/onnx/symbolic.py | 5 ++ 4 files changed, 78 insertions(+) create mode 100644 test/onnx/expect/TestOperators.test_narrow.expect diff --git a/test/onnx/expect/TestOperators.test_narrow.expect b/test/onnx/expect/TestOperators.test_narrow.expect new file mode 100644 index 0000000..ff04330 --- /dev/null +++ b/test/onnx/expect/TestOperators.test_narrow.expect @@ -0,0 +1,61 @@ +ir_version: 4 +producer_name: "pytorch" +producer_version: "1.1" +graph { + node { + input: "0" + output: "1" + op_type: "Slice" + attribute { + name: "axes" + ints: 0 + type: INTS + } + attribute { + name: "ends" + ints: 2 + type: INTS + } + attribute { + name: "starts" + ints: 0 + type: INTS + } + } + name: "torch-jit-export" + input { + name: "0" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 3 + } + } + } + } + } + output { + name: "1" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + } + } + } + } +} +opset_import { + version: 10 +} diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index c0a2823..d09a2a1 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -410,6 +410,10 @@ class TestOperators(TestCase): x = torch.rand(3, 4, requires_grad=True) self.assertONNX(lambda x: x[:, 1:2], x) + def test_narrow(self): + x = torch.randn(3, 3, requires_grad=True) + self.assertONNX(lambda x: torch.narrow(x, 0, 0, 2), x) + def test_atan(self): x = torch.randn(3, 4, requires_grad=True) self.assertONNX(lambda x: x.atan(), x) diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py index c2532d6..9d8cb78 100644 --- a/test/onnx/test_pytorch_onnx_caffe2.py +++ b/test/onnx/test_pytorch_onnx_caffe2.py @@ -1107,6 +1107,14 @@ class TestCaffe2Backend(unittest.TestCase): x = torch.randn(2, 3, requires_grad=True) self.run_model_test(ReshapeAsModel(), train=False, input=x, batch_size=BATCH_SIZE) + def test_narrow(self): + class NarrowModel(torch.nn.Module): + def forward(self, input): + return torch.narrow(input, 0, 0, 2) + + x = torch.randn(3, 3, requires_grad=True) + self.run_model_test(NarrowModel(), train=False, input=x, batch_size=BATCH_SIZE) + # a bit of metaprogramming to set up all the rnn tests diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index a3c145f..b5955cf 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -1626,6 +1626,11 @@ def nonzero(g, input): return g.op('NonZero', input) +@parse_args('v', 'i', 'i', 'i') +def narrow(g, input, dim, start, length): + return g.op("Slice", input, axes_i=[dim], starts_i=[start], ends_i=[start + length]) + + @parse_args('v', 'i', 'i') def _argmax(g, input, dim, keepdim): return g.op('ArgMax', input, axis_i=dim, keepdims_i=keepdim) -- 2.7.4