[PYTORCH] Matmul fix for batch_matmul (#5604)
author     Samuel <siju.samuel@huawei.com>
Fri, 15 May 2020 21:43:23 +0000 (03:13 +0530)
committer  GitHub <noreply@github.com>
Fri, 15 May 2020 21:43:23 +0000 (06:43 +0900)
python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py

python/tvm/relay/frontend/pytorch.py
index 4ecac00..efb95f9 100644 (file)
@@ -1249,13 +1249,47 @@ def _chunk(prelude):
         return chunks
     return _impl
 
-def _matmul():
-    def _impl(inputs, input_types):
-        data0 = inputs[0]
-        data1 = inputs[1]
-        data1_t = _op.transpose(data1, axes=(1, 0))
+def _matmul(prelude):
+    def _impl(inputs, input_types):
+
+        inputs_0 = inputs[0]
+        inputs_1 = inputs[1]
+
+        # Check the input shapes, since inputs with more than two dimensions
+        # have to go through batch_matmul.
+        a_shape = _infer_shape(inputs_0, prelude.mod)
+        b_shape = _infer_shape(inputs_1, prelude.mod)
+
+        # When performing a batch matmul, we need to properly handle N-dim shapes.
+        if len(a_shape) > 2 or len(b_shape) > 2:
+            # Convert a and b into 3 dimensional tensors.
+            a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]])
+            b = _op.reshape(inputs_1, [-1, b_shape[-2], b_shape[-1]])
+            # Broadcast b to match batch size of a
+            new_b_shape = list(_infer_shape(b, prelude.mod))
+            new_a_shape = _infer_shape(a, prelude.mod)
+            if new_a_shape[0] > new_b_shape[0]:
+                new_b_shape[0] = new_a_shape[0]
+                b = _op.broadcast_to(b, new_b_shape)
+            # nn.batch_matmul computes A x transpose(B) on the last two axes,
+            # so transposing b here yields a plain batched matmul overall.
+            b = _op.transpose(b, [0, 2, 1])
+            # Perform a batch matmul.
+            output = _op.nn.batch_matmul(a, b)
+            # Reshape output to original dimensions.
+            return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
+
+        # Otherwise a simple dense op will get the job done.
+        if len(b_shape) == 1:
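+            # nn.dense(X, W) computes X x transpose(W), so a 1-D rhs is lifted
+            # to shape (1, k) here; the extra unit dim is squeezed off below.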
+            input_1 = _op.expand_dims(inputs_1, 0, 1)
+        else:
+            input_1 = _op.transpose(inputs_1, axes=(1, 0))
+
+        out = _op.nn.dense(inputs_0, input_1)
+
+        if len(b_shape) == 1:
+            out = _op.squeeze(out, axis=[-1])
+
+        return out
 
-        return _op.nn.dense(data0, data1_t)
     return _impl
 
 
@@ -1702,7 +1736,7 @@ def _get_convert_map(prelude):
         "aten::alpha_dropout"                   : _dropout(),
         "aten::mean"                            : _mean(),
         "aten::chunk"                           : _chunk(prelude),
-        "aten::matmul"                          : _matmul(),
+        "aten::matmul"                          : _matmul(prelude),
         "aten::expand"                          : _expand(),
         "aten::Int"                             : _int(),
         "prim::NumToTensor"                     : _numtotensor(),
@@ -1763,7 +1797,7 @@ def _get_convert_map(prelude):
         "aten::rsub"                            : _rsub(),
         "aten::embedding"                       : _embedding(),
         "aten::one_hot"                         : _one_hot(),
-        "aten::mm"                              : _matmul(),
+        "aten::mm"                              : _matmul(prelude),
         "relay::tensor_array_stack"             : _tensor_array_stack(prelude),
         "aten::add"                             : _add(prelude),
         "aten::add_"                            : _add(prelude),
tests/python/frontend/pytorch/test_forward.py
index 3d9d22b..30036db 100644 (file)
@@ -2064,11 +2064,45 @@ def test_forward_addcmul():
     verify_model(Addcmul2().float().eval(), input_data=[input_data, t1, t2])
 
 
+def test_forward_matmul():
+    torch.set_grad_enabled(False)
+
+    class MatMul1(Module):
+        def forward(self, *args):
+            return torch.matmul(args[0], args[1])
+
+    # matrix x vector
+    tensor1 = torch.randn(3, 4)
+    tensor2 = torch.randn(4)
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+
+    # matrix x matrix
+    tensor1 = torch.randn(10, 4)
+    tensor2 = torch.randn(4, 10)
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+
+    # batched matrix x batched matrix
+    tensor1 = torch.randn(10, 3, 4)
+    tensor2 = torch.randn(10, 4, 5)
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+
+    # batched matrix x broadcasted matrix
+    tensor1 = torch.randn(10, 3, 4)
+    tensor2 = torch.randn(4, 5)
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+
+    # 4-D batched matrix x 4-D batched matrix
+    tensor1 = torch.randn(1, 12, 14, 64)
+    tensor2 = torch.randn(1, 12, 64, 14)
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+
+
 if __name__ == "__main__":
     # Single operator tests
     test_forward_add()
     test_forward_subtract()
     test_forward_multiply()
+    test_forward_matmul()
     test_forward_rsub()
     test_forward_onehot()
     test_forward_embedding()
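
The new cases above run through verify_model, which traces each module with
TorchScript and imports it via the PyTorch frontend. A rough standalone sketch
of that flow is shown below; it is not part of the patch, and the (name, shape)
input format passed to relay.frontend.from_pytorch, as well as the input names
"a" and "b", are assumptions, since the accepted format has varied across TVM
versions.

import torch
from tvm import relay

class MatMul1(torch.nn.Module):
    def forward(self, a, b):
        return torch.matmul(a, b)

a = torch.randn(10, 3, 4)
b = torch.randn(4, 5)
scripted = torch.jit.trace(MatMul1().eval(), (a, b))
# Input names and the (name, shape) pairs below are illustrative assumptions.
mod, params = relay.frontend.from_pytorch(scripted, [("a", list(a.shape)), ("b", list(b.shape))])
print(mod)  # this batched case should lower to nn.batch_matmul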