[PYTORCH] Padding support (#5638)
author Samuel <siju.samuel@huawei.com>
Thu, 21 May 2020 20:44:17 +0000 (02:14 +0530)
committer GitHub <noreply@github.com>
Thu, 21 May 2020 20:44:17 +0000 (05:44 +0900)
python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py

index 28703da..cc7cd48 100644
@@ -1342,10 +1342,31 @@ def _none():
 def _pad():
     def _impl(inputs, input_types):
         data = inputs[0]
-        padding = inputs[1]
-        pad_width = list(zip(padding, padding))
+        if isinstance(inputs[1], list):
+            pad_list = inputs[1]
+        else:
+            pad_list = list(_infer_shape(inputs[1]))
+
+        # initialize paddings based on the input rank (two entries per dimension)
+        pad_len = len(_infer_shape(data)) * 2
+        paddings = [0] * pad_len
+
+        if len(pad_list) >= 2:
+            paddings[-1] = pad_list[1]
+            paddings[-2] = pad_list[0]
+        if len(pad_list) >= 4:
+            paddings[-3] = pad_list[3]
+            paddings[-4] = pad_list[2]
+        if len(pad_list) >= 6:
+            paddings[-5] = pad_list[5]
+            paddings[-6] = pad_list[4]
+
+        # group into (pad_before, pad_after) pairs, one per dimension
+        paddings = [paddings[i:i + 2] for i in range(0, len(paddings), 2)]
+
         pad_value = inputs[2]
-        return _op.nn.pad(data, pad_width, pad_value)
+
+        return _op.nn.pad(data, paddings, pad_value)
     return _impl
 
 
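Note on the conversion above: torch.nn.functional.pad takes its padding as a flat sequence ordered from the last dimension backwards, two values (before, after) per padded dimension, while relay's nn.pad expects one (pad_before, pad_after) pair per dimension in regular dimension order. The converter reorders the flat list accordingly and leaves unpadded leading dimensions at zero. A minimal standalone sketch of that reordering (the helper name is illustrative, not part of the patch):

    # Illustrative only: mirrors the reordering done by _pad() above.
    # torch_pad follows the F.pad convention: (last_dim_before, last_dim_after,
    # second_last_before, second_last_after, ...).
    def torch_pad_to_pad_width(torch_pad, rank):
        paddings = [0] * (rank * 2)
        # walk the flat tuple two values at a time, filling from the last dimension back
        for i in range(0, len(torch_pad), 2):
            paddings[-(i + 1)] = torch_pad[i + 1]  # pad_after for this dimension
            paddings[-(i + 2)] = torch_pad[i]      # pad_before for this dimension
        # group into (pad_before, pad_after) pairs, one per dimension
        return [tuple(paddings[i:i + 2]) for i in range(0, len(paddings), 2)]

    # e.g. F.pad(x, (1, 2, 3, 4)) on an NCHW tensor pads W by (1, 2) and H by (3, 4):
    assert torch_pad_to_pad_width((1, 2, 3, 4), rank=4) == [(0, 0), (0, 0), (3, 4), (1, 2)]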
index f1543f0..85928bf 100644
@@ -1020,6 +1020,50 @@ def test_adaptive_pool3d():
         verify_model(torch.nn.AdaptiveMaxPool3d((7, 8, 9)).eval(), inp)
 
 
+def test_forward_functional_pad():
+    torch.set_grad_enabled(False)
+    pad = (0, 0)
+    class Pad1(Module):
+        def forward(self, *args):
+            return torch.nn.functional.pad(args[0], pad, "constant", 0)
+
+    input_data = torch.rand((3, 3, 4, 2))
+    pad = (1, 1)
+    verify_model(Pad1().float().eval(), input_data=input_data)
+
+    pad = (1, 1, 2, 2)
+    verify_model(Pad1().float().eval(), input_data=input_data)
+
+    pad = (0, 1, 2, 1, 3, 3)
+    verify_model(Pad1().float().eval(), input_data=input_data)
+
+
+def test_forward_zero_pad2d():
+    inp = torch.rand((1, 1, 3, 3))
+    verify_model(torch.nn.ZeroPad2d(2).eval(), inp)
+    verify_model(torch.nn.ZeroPad2d((1, 1, 2, 0)).eval(), inp)
+
+
+def test_forward_constant_pad1d():
+    inp = torch.rand((1, 2, 4))
+    verify_model(torch.nn.ConstantPad1d(2, 3.5).eval(), inp)
+
+    inp = torch.rand((1, 2, 3))
+    verify_model(torch.nn.ConstantPad1d((3, 1), 3.5).eval(), inp)
+
+
+def test_forward_constant_pad2d():
+    inp = torch.rand((1, 2, 2, 2))
+    verify_model(torch.nn.ConstantPad2d(2, 3.5).eval(), inp)
+    verify_model(torch.nn.ConstantPad2d((3, 0, 2, 1), 3.5).eval(), inp)
+
+
+def test_forward_constant_pad3d():
+    inp = torch.rand((1, 3, 2, 2, 2))
+    verify_model(torch.nn.ConstantPad3d(3, 3.5).eval(), inp)
+    verify_model(torch.nn.ConstantPad3d((3, 4, 5, 6, 0, 1), 3.5).eval(), inp)
+
+
 def test_forward_reflection_pad2d():
     inp = torch.rand((1, 1, 3, 3))
     verify_model(torch.nn.ReflectionPad2d(2).eval(), inp)
@@ -2200,6 +2244,11 @@ if __name__ == "__main__":
     test_upsample()
     test_forward_upsample3d()
     test_to()
+    test_forward_functional_pad()
+    test_forward_zero_pad2d()
+    test_forward_constant_pad1d()
+    test_forward_constant_pad2d()
+    test_forward_constant_pad3d()
     test_forward_reflection_pad2d()
     test_adaptive_pool3d()
     test_conv3d()
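The new tests all go through verify_model, which traces the module and compares the TVM output against PyTorch. Outside the test harness, a padded module is imported like any other traced model; a minimal sketch, assuming a standard TVM build with the PyTorch frontend (the input name "input0" is arbitrary):

    import torch
    from tvm import relay

    # trace one of the padding modules exercised by the tests above
    model = torch.nn.ZeroPad2d((1, 1, 2, 0)).eval()
    inp = torch.rand((1, 1, 3, 3))
    scripted = torch.jit.trace(model, inp)

    # from_pytorch takes the scripted module and a list of (input name, shape) pairs
    mod, params = relay.frontend.from_pytorch(scripted, [("input0", (1, 1, 3, 3))])
    print(mod)  # the printed Relay module contains the nn.pad call emitted by _pad()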