[PYTORCH]ReplicationPad support added (#5708)
authorSamuel <siju.samuel@huawei.com>
Tue, 2 Jun 2020 00:15:41 +0000 (05:45 +0530)
committerGitHub <noreply@github.com>
Tue, 2 Jun 2020 00:15:41 +0000 (09:15 +0900)
python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py

index f2e24a1..da8c094 100644 (file)
@@ -1369,7 +1369,7 @@ def _none():
         return None
     return _impl
 
-def _pad():
+def _pad(mode):
     def _impl(inputs, input_types):
         data = inputs[0]
         if isinstance(inputs[1], list):
@@ -1394,9 +1394,11 @@ def _pad():
         # group into tuple of 2 ints
         paddings = [paddings[i:i + 2] for i in range(0, len(paddings), 2)]
 
-        pad_value = inputs[2]
+        if mode == "constant":
+            return _op.nn.pad(data, paddings, pad_value=inputs[2], pad_mode=mode)
+        else:
+            return _op.nn.pad(data, paddings, pad_mode=mode)
 
-        return _op.nn.pad(data, paddings, pad_value)
     return _impl
 
 
@@ -1654,22 +1656,6 @@ def _one_hot():
     return _impl
 
 
-def _reflection_pad2d():
-    def _impl(inputs, input_types):
-        if isinstance(inputs[1], list):
-            pad_list = inputs[1]
-        else:
-            pad_list = list(_infer_shape(inputs[1]))
-        padding_left = pad_list[0]
-        padding_right = pad_list[1]
-        padding_top = pad_list[2]
-        padding_bottom = pad_list[3]
-        paddings = [[0, 0], [0, 0], [padding_top, padding_bottom], [padding_left, padding_right]]
-
-        return _op.nn.mirror_pad(inputs[0], paddings, mode='REFLECT')
-    return _impl
-
-
 # Helper functions for operator implementation
 def _convert_dtype_value(val):
     convert_torch_dtype_map = {7:"torch.float64",
@@ -1836,7 +1822,12 @@ def _get_convert_map(prelude):
         "aten::Int"                             : _int(),
         "prim::NumToTensor"                     : _numtotensor(),
         "prim::ImplicitTensorToNum"             : _tensortonum(),
-        "aten::constant_pad_nd"                 : _pad(),
+        "aten::constant_pad_nd"                 : _pad("constant"),
+        "aten::reflection_pad1d"                : _pad("reflect"),
+        "aten::reflection_pad2d"                : _pad("reflect"),
+        "aten::replication_pad1d"               : _pad("edge"),
+        "aten::replication_pad2d"               : _pad("edge"),
+        "aten::replication_pad3d"               : _pad("edge"),
         "aten::permute"                         : _transpose(prelude),
         "aten::sum"                             : _reduce("sum"),
         "aten::prod"                            : _reduce("prod"),
@@ -1895,7 +1886,6 @@ def _get_convert_map(prelude):
         "aten::embedding"                       : _embedding(),
         "aten::one_hot"                         : _one_hot(),
         "aten::mm"                              : _matmul(prelude),
-        "aten::reflection_pad2d"                : _reflection_pad2d(),
         "relay::tensor_array_stack"             : _tensor_array_stack(prelude),
         "aten::add"                             : _add(prelude),
         "aten::add_"                            : _add(prelude),
index f6edbf1..e41da7e 100644 (file)
@@ -1116,6 +1116,15 @@ def test_forward_constant_pad3d():
     verify_model(torch.nn.ConstantPad3d((3, 4, 5, 6, 0, 1), 3.5).eval(), inp)
 
 
+def test_forward_reflection_pad1d():
+    inp = torch.rand((1, 2, 4))
+    verify_model(torch.nn.ReflectionPad1d(2).eval(), inp)
+    verify_model(torch.nn.ReflectionPad1d((3, 1)).eval(), inp)
+
+    inp = torch.rand((2, 4, 5))
+    verify_model(torch.nn.ReflectionPad1d((2, 3)).eval(), inp)
+
+
 def test_forward_reflection_pad2d():
     inp = torch.rand((1, 1, 3, 3))
     verify_model(torch.nn.ReflectionPad2d(2).eval(), inp)
@@ -1125,6 +1134,33 @@ def test_forward_reflection_pad2d():
     verify_model(torch.nn.ReflectionPad2d((1, 3, 2, 4)).eval(), inp)
 
 
+def test_forward_replication_pad1d():
+    inp = torch.rand((1, 2, 4))
+    verify_model(torch.nn.ReplicationPad1d(2).eval(), inp)
+    verify_model(torch.nn.ReplicationPad1d((3, 1)).eval(), inp)
+
+    inp = torch.rand((2, 4, 5))
+    verify_model(torch.nn.ReplicationPad1d((2, 3)).eval(), inp)
+
+
+def test_forward_replication_pad2d():
+    inp = torch.rand((1, 1, 3, 3))
+    verify_model(torch.nn.ReplicationPad2d(2).eval(), inp)
+    verify_model(torch.nn.ReplicationPad2d((1, 1, 2, 0)).eval(), inp)
+
+    inp = torch.rand((2, 4, 5, 6))
+    verify_model(torch.nn.ReplicationPad2d((1, 3, 2, 4)).eval(), inp)
+
+
+def test_forward_replication_pad3d():
+    inp = torch.rand((1, 1, 3, 3, 3))
+    verify_model(torch.nn.ReplicationPad3d(3).eval(), inp)
+    verify_model(torch.nn.ReplicationPad3d((1, 1, 2, 2, 1, 1)).eval(), inp)
+
+    inp = torch.rand((7, 5, 4, 5, 6))
+    verify_model(torch.nn.ReplicationPad3d((2, 3, 2, 5, 1, 4)).eval(), inp)
+
+
 def test_forward_upsample3d():
     inp = torch.arange(1, 9, dtype=torch.float32).view(1, 1, 2, 2, 2)
     verify_model(torch.nn.Upsample(scale_factor=2, mode='nearest').eval(), inp)
@@ -2429,7 +2465,11 @@ if __name__ == "__main__":
     test_forward_constant_pad1d()
     test_forward_constant_pad2d()
     test_forward_constant_pad3d()
+    test_forward_reflection_pad1d()
     test_forward_reflection_pad2d()
+    test_forward_replication_pad1d()
+    test_forward_replication_pad2d()
+    test_forward_replication_pad3d()
     test_adaptive_pool3d()
     test_conv3d()