[PYTORCH]AvgPool3d, MaxPool3d and Squeeze op support (#5220)
authorSamuel <siju.samuel@huawei.com>
Thu, 2 Apr 2020 22:20:41 +0000 (03:50 +0530)
committerGitHub <noreply@github.com>
Thu, 2 Apr 2020 22:20:41 +0000 (07:20 +0900)
* [PYTORCH]AvgPool3d, MaxPool3d and Squeeze op support

* Testcases added

* review comments

python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py

index 9a08af9..977a899 100644 (file)
@@ -57,6 +57,17 @@ def _elemwise(name):
         return get_relay_op(name)(data0, data1)
     return _impl
 
+def _squeeze():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        if len(inputs) == 1:
+            axis = None
+        else:
+            axis = [int(inputs[1])]
+
+        return _op.transform.squeeze(data, axis)
+    return _impl
+
 def _unsqueeze():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -297,6 +308,26 @@ def _maxpool_1d():
         return _op.nn.max_pool1d(data, pool_size, strides, padding, "NCW", ceil_mode)
     return _impl
 
+def _maxpool_3d():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+
+        pool_size = _infer_shape(inputs[1])
+        strides = _infer_shape(inputs[2])
+        padding = _infer_shape(inputs[3])
+        dilation = _infer_shape(inputs[4])
+        ceil_mode = int(inputs[5])
+        if dilation != (1, 1, 1):
+            msg = "MaxPool3d with dilation %s is not implemented" % (str(dilation), )
+            raise NotImplementedError(msg)
+
+        return _op.nn.max_pool3d(data,
+                                 pool_size=pool_size,
+                                 strides=strides,
+                                 padding=padding,
+                                 ceil_mode=ceil_mode)
+    return _impl
+
 def _hardtanh():
     def _impl(inputs, input_types):
         a = inputs[0]
@@ -631,6 +662,27 @@ def _avg_pool2d():
 
     return _impl
 
+def _avg_pool3d():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+
+        pool_size = _infer_shape(inputs[1])
+        if inputs[2]:
+            strides = _infer_shape(inputs[2])
+        else:
+            strides = pool_size
+        padding = _infer_shape(inputs[3])
+
+        ceil_mode = int(inputs[4])
+        count_include_pad = int(inputs[5])
+
+        return _op.nn.avg_pool3d(data,
+                                 pool_size=pool_size,
+                                 strides=strides,
+                                 padding=padding,
+                                 ceil_mode=ceil_mode,
+                                 count_include_pad=count_include_pad)
+    return _impl
 
 def _dropout():
     def _impl(inputs, input_types):
@@ -970,6 +1022,7 @@ _convert_map = {
     "aten::ones"                            : _ones(),
     "aten::zeros"                           : _zeros(),
     "aten::to"                              : _to(),
+    "aten::squeeze"                         : _squeeze(),
     "aten::unsqueeze"                       : _unsqueeze(),
     "aten::cat"                             : _concatenate(),
     "aten::slice"                           : _slice(),
@@ -987,6 +1040,7 @@ _convert_map = {
     "aten::max_pool2d"                      : _maxpool_2d(),
     "aten::max_pool2d_with_indices"         : _maxpool_2d(),
     "aten::max_pool1d"                      : _maxpool_1d(),
+    "aten::max_pool3d"                      : _maxpool_3d(),
     "aten::hardtanh"                        : _hardtanh(),
     "aten::hardtanh_"                       : _hardtanh(),
     "aten::_convolution"                    : _convolution(),
@@ -1007,6 +1061,7 @@ _convert_map = {
     "aten::log_softmax"                     : _log_softmax(),
     "aten::sigmoid"                         : _sigmoid(),
     "aten::avg_pool2d"                      : _avg_pool2d(),
+    "aten::avg_pool3d"                      : _avg_pool3d(),
     "aten::dropout"                         : _dropout(),
     "aten::dropout_"                        : _dropout(),
     "aten::feature_dropout"                 : _dropout(),
index c75ae6e..e7c2e08 100644 (file)
@@ -304,6 +304,22 @@ def test_forward_unsqueeze():
     input_data = torch.rand(input_shape).float()
     verify_model(Unsqueeze1().float().eval(), input_data=input_data)
 
+def test_forward_squeeze():
+    torch.set_grad_enabled(False)
+    input_shape = [2, 1, 10, 1, 10]
+
+    class Squeeze1(Module):
+        def forward(self, *args):
+            return args[0].squeeze()
+
+    class Squeeze2(Module):
+        def forward(self, *args):
+            return args[0].squeeze(1)
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(Squeeze1().float().eval(), input_data=input_data)
+    verify_model(Squeeze2().float().eval(), input_data=input_data)
+
 def test_forward_concatenate():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
@@ -388,6 +404,20 @@ def test_forward_maxpool1d():
                                     stride=2).eval(),
                 input_data)
 
+def test_forward_maxpool3d():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10, 10]
+    input_data = torch.rand(input_shape).float()
+
+    verify_model(torch.nn.MaxPool3d(kernel_size=[1, 1, 1]).eval(),
+                input_data)
+    verify_model(torch.nn.MaxPool3d(kernel_size=[10, 10, 10]).eval(),
+                input_data)
+    verify_model(torch.nn.MaxPool3d(kernel_size=[4, 4, 4],
+                                    padding=2,
+                                    stride=2).eval(),
+                input_data)
+
 def test_forward_split():
     torch.set_grad_enabled(False)
     input_shape = [4, 10]
@@ -423,6 +453,18 @@ def test_forward_avgpool():
     verify_model(torch.nn.AvgPool2d(kernel_size=[10, 10]).eval(), input_data=input_data)
     verify_model(AvgPool2D2().float().eval(), input_data=input_data)
 
+def test_forward_avgpool3d():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10, 10]
+
+    class AvgPool3D1(Module):
+        def forward(self, *args):
+            return torch.nn.functional.avg_pool3d(args[0], kernel_size=[10, 10, 10])
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(torch.nn.AvgPool3d(kernel_size=[10, 10, 10]).eval(), input_data=input_data)
+    verify_model(AvgPool3D1().float().eval(), input_data=input_data)
+
 def test_forward_hardtanh():
     torch.set_grad_enabled(False)
     input_shape = [10]
@@ -1071,6 +1113,7 @@ if __name__ == "__main__":
     test_forward_add()
     test_forward_subtract()
     test_forward_multiply()
+    test_forward_squeeze()
     test_forward_unsqueeze()
     test_forward_concatenate()
     test_forward_relu()
@@ -1081,6 +1124,7 @@ if __name__ == "__main__":
     test_forward_adaptiveavgpool()
     test_forward_maxpool2d()
     test_forward_maxpool1d()
+    test_forward_maxpool3d()
     test_forward_hardtanh()
     test_forward_conv()
     test_forward_conv_transpose()
@@ -1097,6 +1141,7 @@ if __name__ == "__main__":
     test_forward_sigmoid()
     test_forward_dense()
     test_forward_avgpool()
+    test_forward_avgpool3d()
     test_forward_dropout()
     test_forward_slice()
     test_forward_mean()