[Relay] pytorch frontend support conv1d (#6203)
authorTianming Xu <tianmingxu.tmxu@gmail.com>
Wed, 5 Aug 2020 20:01:27 +0000 (04:01 +0800)
committerGitHub <noreply@github.com>
Wed, 5 Aug 2020 20:01:27 +0000 (05:01 +0900)
* [Relay] pytorch frontend support conv1d

* add tests for conv1d

Co-authored-by: xutianming.xtm <xutianming.xtm@bytedance.com>
python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py

index 57b64ac..3dfdb2f 100644 (file)
@@ -752,7 +752,7 @@ def _convolution():
         # If groups > 1 but weight_shape[1] != 1, this is group convolution
         if groups > 1 and weight_shape[1] == 1:
             channel_multiplier = channels // groups
-            new_weight_shape = (groups, channel_multiplier, weight_shape[2], weight_shape[3])
+            new_weight_shape = (groups, channel_multiplier) + tuple(weight_shape[2:])
             weight = _op.transform.reshape(weight, new_weight_shape)
 
         kernel_size = weight_shape[2:]
@@ -760,12 +760,18 @@ def _convolution():
 
         if isinstance(strides, _expr.Expr):
             strides = _infer_shape(strides)
+            if len(kernel_size) == 1:
+                strides = (1, ) + strides
 
         if isinstance(padding, _expr.Expr):
             padding = _infer_shape(padding)
+            if len(kernel_size) == 1:
+                padding = (0, ) + padding
 
         if isinstance(dilation, _expr.Expr):
             dilation = _infer_shape(dilation)
+            if len(kernel_size) == 1:
+                dilation = (1, ) + dilation
 
         if use_transpose:
             if len(kernel_size) == 3:
@@ -785,6 +791,9 @@ def _convolution():
             data_layout = "NCHW"
             kernel_layout = "OIHW"
 
+        if len(kernel_size) == 1:
+            data = _op.expand_dims(data, axis=2)
+            weight = _op.expand_dims(weight, axis=2)
 
         conv_out = conv_op(data,
                            weight,
@@ -793,15 +802,21 @@ def _convolution():
                            dilation=dilation,
                            groups=groups,
                            channels=channels,
-                           kernel_size=kernel_size,
+                           kernel_size=[1] + kernel_size \
+                                        if len(kernel_size) == 1 \
+                                        else kernel_size,
                            data_layout=data_layout,
                            kernel_layout=kernel_layout,
                            out_layout="",
                            out_dtype="")
         if use_bias:
-            return _op.nn.bias_add(conv_out, bias)
+            res = _op.nn.bias_add(conv_out, bias)
         else:
-            return conv_out
+            res = conv_out
+        if len(kernel_size) == 1:
+            res = _op.squeeze(res, axis=[2])
+        return res
+
     return _impl
 
 def _softmax():
index 6a572db..ab9cca1 100644 (file)
@@ -702,7 +702,8 @@ def test_forward_hardtanh():
 
 def test_forward_conv():
     torch.set_grad_enabled(False)
-    input_shape = [1, 3, 10, 10]
+    conv1d_input_shape = [1, 3, 10]
+    conv2d_input_shape = [1, 3, 10, 10]
 
     class Conv2D1(Module):
         def __init__(self):
@@ -731,23 +732,59 @@ def test_forward_conv():
         def forward(self, *args):
             return self.softmax(self.conv(args[0]))
 
-    input_data = torch.rand(input_shape).float()
-    verify_model(Conv2D1().float().eval(), input_data=input_data)
-    verify_model(Conv2D2().float().eval(), input_data=input_data)
+    class Conv1D1(Module):
+        def __init__(self):
+            super(Conv1D1, self).__init__()
+            self.conv = torch.nn.Conv1d(3, 6, 7)
+            self.softmax = torch.nn.Softmax()
+
+        def forward(self, *args):
+            return self.softmax(self.conv(args[0]))
+
+    class Conv1D2(Module):
+        def __init__(self):
+            super(Conv1D2, self).__init__()
+            self.conv = torch.nn.Conv1d(3, 6, 7, bias=False)
+            self.softmax = torch.nn.Softmax()
+
+        def forward(self, *args):
+            return self.softmax(self.conv(args[0]))
+
+    class Conv1D3(Module):
+        def __init__(self):
+            super(Conv1D3, self).__init__()
+            self.conv = torch.nn.Conv1d(3, 6, 7, groups=3, bias=False)
+            self.softmax = torch.nn.Softmax()
+
+        def forward(self, *args):
+            return self.softmax(self.conv(args[0]))
+
+    conv2d_input_data = torch.rand(conv2d_input_shape).float()
+    verify_model(Conv2D1().float().eval(), input_data=conv2d_input_data)
+    verify_model(Conv2D2().float().eval(), input_data=conv2d_input_data)
     # depth wise conv with channel mult 2
-    verify_model(Conv2D3().float().eval(), input_data=input_data)
+    verify_model(Conv2D3().float().eval(), input_data=conv2d_input_data)
     # group conv
     verify_model(torch.nn.Conv2d(8, 8, kernel_size=(3, 3),
                                  stride=(1, 1), groups=2).eval(),
                  input_data=torch.randn((1, 8, 16, 16)))
 
+    conv1d_input_data = torch.rand(conv1d_input_shape).float()
+    verify_model(Conv1D1().float().eval(), input_data=conv1d_input_data)
+    verify_model(Conv1D2().float().eval(), input_data=conv1d_input_data)
+    verify_model(Conv1D3().float().eval(), input_data=conv1d_input_data)
 
 def test_forward_conv_transpose():
     torch.set_grad_enabled(False)
-    input_shape = [1, 3, 10, 10]
-    input_data = torch.rand(input_shape).float()
-    verify_model(torch.nn.ConvTranspose2d(3, 6, 7, bias=True), input_data=input_data)
-    verify_model(torch.nn.ConvTranspose2d(3, 12, 3, bias=False), input_data=input_data)
+    conv2d_input_shape = [1, 3, 10, 10]
+    conv2d_input_data = torch.rand(conv2d_input_shape).float()
+    verify_model(torch.nn.ConvTranspose2d(3, 6, 7, bias=True), input_data=conv2d_input_data)
+    verify_model(torch.nn.ConvTranspose2d(3, 12, 3, bias=False), input_data=conv2d_input_data)
+
+    conv1d_input_shape = [1, 3, 10]
+    conv1d_input_data = torch.rand(conv1d_input_shape).float()
+    verify_model(torch.nn.ConvTranspose1d(3, 6, 7, bias=True), input_data=conv1d_input_data)
+    verify_model(torch.nn.ConvTranspose1d(3, 12, 3, bias=False), input_data=conv1d_input_data)
 
 
 def test_forward_threshold():