[PYTORCH]expand bug fix (#5576)
authorSamuel <siju.samuel@huawei.com>
Wed, 13 May 2020 00:09:41 +0000 (05:39 +0530)
committerGitHub <noreply@github.com>
Wed, 13 May 2020 00:09:41 +0000 (09:09 +0900)
python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py

index 3af1051..d95a912 100644 (file)
@@ -1245,27 +1245,36 @@ def _matmul():
         return _op.nn.dense(data0, data1_t)
     return _impl
 
+
 def _expand():
     def _impl(inputs, input_types):
         data_in = inputs[0]
         if isinstance(data_in, _expr.Expr):
-            shape = _infer_shape(data_in)
+            shape = list(_infer_shape(data_in))
 
         ndims = len(shape)
         sizes = _infer_shape(inputs[1])
         out = inputs[0]
 
+        out_dims = len(sizes)
+        if ndims < out_dims:
+            num_newaxis = out_dims - ndims
+            out = _op.expand_dims(out, axis=0, num_newaxis=num_newaxis)
+            shape = [1] * num_newaxis + shape
+
         for i in range(ndims):
-            if sizes[i] in {-1, shape[i]}:
+            if sizes[i] == -1 or sizes[i] == shape[i]:
                 continue
             data = list()
             for temp in range(sizes[i]):
                 data.append(out)
-            call = _op.tensor.concatenate(data, i)
 
-        return call
+            out = _op.tensor.concatenate(data, i)
+
+        return out
     return _impl
 
+
 def _int():
     def _impl(inputs, input_types):
         if isinstance(inputs[0], _expr.Expr):
index e1c276b..82a027f 100644 (file)
@@ -902,15 +902,24 @@ def test_forward_mean():
 
 def test_forward_expand():
     torch.set_grad_enabled(False)
-    input_shape = [1, 3, 10, 10]
 
     class Expand1(Module):
         def forward(self, *args):
             return args[0].expand((3, -1, -1, -1))
 
+    input_shape = [1, 3, 10, 10]
     input_data = torch.rand(input_shape).float()
     verify_model(Expand1().float().eval(), input_data=input_data)
 
+    class Expand2(Module):
+        def forward(self, *args):
+            return args[0].expand((3, 3, 3, 1))
+
+    input_shape = [3, 1]
+    input_data = torch.rand(input_shape).float()
+    verify_model(Expand2().float().eval(), input_data=input_data)
+
+
 def test_forward_pow():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]