reshape with non constant shape argument (#6411)
authormasahi <masahi129@gmail.com>
Mon, 7 Sep 2020 15:02:30 +0000 (00:02 +0900)
committerGitHub <noreply@github.com>
Mon, 7 Sep 2020 15:02:30 +0000 (00:02 +0900)
python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py

index fe5a45c..99d2dae 100644 (file)
@@ -1131,10 +1131,13 @@ def _view():
 def _reshape():
     def _impl(inputs, input_types):
         data = inputs[0]
-        if isinstance(inputs[1], list):
+        if _is_int_seq(inputs[1]):
             new_shape = inputs[1]
         else:
-            new_shape = _infer_shape(inputs[1])
+            assert isinstance(inputs[1], list)
+            infer_res = [_infer_value(_wrap_const(size), {}) for size in inputs[1]]
+            new_shape = [np.asscalar(res.asnumpy().astype(np.int))
+                         for res in infer_res]
         return _op.transform.reshape(data, new_shape)
     return _impl
 
index cfe9507..b35c7d6 100644 (file)
@@ -946,6 +946,7 @@ def test_forward_reshape():
     torch.set_grad_enabled(False)
     input_shape = [2, 1, 10, 1, 10]
     new_shape = [2, 1, 10, 10]
+
     class Reshape1(Module):
         def forward(self, *args):
             return args[0].reshape(new_shape)
@@ -954,9 +955,15 @@ def test_forward_reshape():
         def forward(self, *args):
             return args[0].reshape([-1])
 
+    class Reshape3(torch.nn.Module):
+        def forward(self, x):
+            x_shape = x.shape
+            return x.reshape((x_shape[0] * x_shape[1], x_shape[2]))
+
     input_data = torch.rand(input_shape).float()
-    verify_model(Reshape1().float().eval(), input_data=input_data)
-    verify_model(Reshape2().float().eval(), input_data=input_data)
+    verify_model(Reshape1(), input_data=input_data)
+    verify_model(Reshape2(), input_data=input_data)
+    verify_model(Reshape3(), input_data=torch.randn(2, 3, 4))
 
 
 @tvm.testing.uses_gpu