Extend support for exporting reshape to onnx. (#16971)
authorBowenBao <semisqg@gmail.com>
Fri, 15 Feb 2019 08:14:25 +0000 (00:14 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 15 Feb 2019 08:17:05 +0000 (00:17 -0800)
Summary:
Resolve issue with reshape_as test case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16971

Differential Revision: D14098871

Pulled By: houseroad

fbshipit-source-id: ed6b966821462d374313256abbbe27f96ce11b2c

test/onnx/test_pytorch_onnx_caffe2.py
torch/onnx/symbolic.py

index c31d7c2..4c7830f 100644 (file)
@@ -1043,6 +1043,23 @@ class TestCaffe2Backend(unittest.TestCase):
         x = torch.randn(1, 2, 3, 4, requires_grad=True)
         self.run_model_test(FlattenModel(), train=False, input=x, batch_size=BATCH_SIZE)
 
+    def test_reshape(self):
+        class ReshapeModel(torch.nn.Module):
+            def forward(self, input):
+                return input.reshape(1, 1)
+
+        x = torch.randn(1, requires_grad=True)
+        self.run_model_test(ReshapeModel(), train=False, input=x, batch_size=BATCH_SIZE)
+
+    def test_reshape_as(self):
+        class ReshapeAsModel(torch.nn.Module):
+            def forward(self, input):
+                y = torch.randn(3, 1, 2, 1, requires_grad=False)
+                return input.reshape_as(y)
+
+        x = torch.randn(2, 3, requires_grad=True)
+        self.run_model_test(ReshapeAsModel(), train=False, input=x, batch_size=BATCH_SIZE)
+
 # a bit of metaprogramming to set up all the rnn tests
 
 
index 3f00a7b..8dc72c8 100644 (file)
@@ -241,6 +241,15 @@ def _reshape_from_tensor(g, input, shape):
     return g.op('Reshape', input, shape)
 
 
+def reshape(g, self, shape):
+    return view(g, self, shape)
+
+
+def reshape_as(g, self, other):
+    shape = g.op('Shape', other)
+    return reshape(g, self, shape)
+
+
 def add(g, self, other, alpha=None):
     # default alpha arg is to allow no-alpha add (aten add st overload no alpha)
     if alpha and _scalar(_maybe_get_scalar(alpha)) != 1: