ONNX Export Narrow op
authorLara Haidar <haidar.lara@gmail.com>
Thu, 7 Mar 2019 06:35:12 +0000 (22:35 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 7 Mar 2019 06:37:58 +0000 (22:37 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17550

Differential Revision: D14350401

Pulled By: houseroad

fbshipit-source-id: 4d88079bb7a8bbd270b0272009826eb3b202cc33

test/onnx/expect/TestOperators.test_narrow.expect [new file with mode: 0644]
test/onnx/test_operators.py
test/onnx/test_pytorch_onnx_caffe2.py
torch/onnx/symbolic.py

diff --git a/test/onnx/expect/TestOperators.test_narrow.expect b/test/onnx/expect/TestOperators.test_narrow.expect
new file mode 100644 (file)
index 0000000..ff04330
--- /dev/null
@@ -0,0 +1,61 @@
+ir_version: 4
+producer_name: "pytorch"
+producer_version: "1.1"
+graph {
+  node {
+    input: "0"
+    output: "1"
+    op_type: "Slice"
+    attribute {
+      name: "axes"
+      ints: 0
+      type: INTS
+    }
+    attribute {
+      name: "ends"
+      ints: 2
+      type: INTS
+    }
+    attribute {
+      name: "starts"
+      ints: 0
+      type: INTS
+    }
+  }
+  name: "torch-jit-export"
+  input {
+    name: "0"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 3
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "1"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 3
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 10
+}
index c0a2823..d09a2a1 100644 (file)
@@ -410,6 +410,10 @@ class TestOperators(TestCase):
         x = torch.rand(3, 4, requires_grad=True)
         self.assertONNX(lambda x: x[:, 1:2], x)
 
+    def test_narrow(self):
+        x = torch.randn(3, 3, requires_grad=True)
+        self.assertONNX(lambda x: torch.narrow(x, 0, 0, 2), x)
+
     def test_atan(self):
         x = torch.randn(3, 4, requires_grad=True)
         self.assertONNX(lambda x: x.atan(), x)
index c2532d6..9d8cb78 100644 (file)
@@ -1107,6 +1107,14 @@ class TestCaffe2Backend(unittest.TestCase):
         x = torch.randn(2, 3, requires_grad=True)
         self.run_model_test(ReshapeAsModel(), train=False, input=x, batch_size=BATCH_SIZE)
 
+    def test_narrow(self):
+        class NarrowModel(torch.nn.Module):
+            def forward(self, input):
+                return torch.narrow(input, 0, 0, 2)
+
+        x = torch.randn(3, 3, requires_grad=True)
+        self.run_model_test(NarrowModel(), train=False, input=x, batch_size=BATCH_SIZE)
+
 # a bit of metaprogramming to set up all the rnn tests
 
 
index a3c145f..b5955cf 100644 (file)
@@ -1626,6 +1626,11 @@ def nonzero(g, input):
     return g.op('NonZero', input)
 
 
+@parse_args('v', 'i', 'i', 'i')
+def narrow(g, input, dim, start, length):
+    return g.op("Slice", input, axes_i=[dim], starts_i=[start], ends_i=[start + length])
+
+
 @parse_args('v', 'i', 'i')
 def _argmax(g, input, dim, keepdim):
     return g.op('ArgMax', input, axis_i=dim, keepdims_i=keepdim)