[PYTORCH]ReflectionPad2d op (#5624)
authorSamuel <siju.samuel@huawei.com>
Wed, 20 May 2020 01:10:23 +0000 (06:40 +0530)
committerGitHub <noreply@github.com>
Wed, 20 May 2020 01:10:23 +0000 (10:10 +0900)
python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py

index efb95f9..46b5cec 100644 (file)
@@ -1577,6 +1577,22 @@ def _one_hot():
     return _impl
 
 
+def _reflection_pad2d():
+    def _impl(inputs, input_types):
+        if isinstance(inputs[1], list):
+            pad_list = inputs[1]
+        else:
+            pad_list = list(_infer_shape(inputs[1]))
+        padding_left = pad_list[0]
+        padding_right = pad_list[1]
+        padding_top = pad_list[2]
+        padding_bottom = pad_list[3]
+        paddings = [[0, 0], [0, 0], [padding_top, padding_bottom], [padding_left, padding_right]]
+
+        return _op.nn.mirror_pad(inputs[0], paddings, mode='REFLECT')
+    return _impl
+
+
 # Helper functions for operator implementation
 def _convert_dtype_value(val):
     convert_torch_dtype_map = {7:"torch.float64",
@@ -1695,6 +1711,7 @@ def _get_convert_map(prelude):
         "aten::prelu"                           : _prelu(),
         "aten::leaky_relu"                      : _leaky_relu(),
         "aten::elu"                             : _elu(),
+        "aten::elu_"                            : _elu(),
         "aten::celu"                            : _celu(),
         "aten::gelu"                            : _gelu(),
         "aten::selu"                            : _selu(),
@@ -1798,6 +1815,7 @@ def _get_convert_map(prelude):
         "aten::embedding"                       : _embedding(),
         "aten::one_hot"                         : _one_hot(),
         "aten::mm"                              : _matmul(prelude),
+        "aten::reflection_pad2d"                : _reflection_pad2d(),
         "relay::tensor_array_stack"             : _tensor_array_stack(prelude),
         "aten::add"                             : _add(prelude),
         "aten::add_"                            : _add(prelude),
index 30036db..50c3ede 100644 (file)
@@ -1020,6 +1020,15 @@ def test_adaptive_pool3d():
         verify_model(torch.nn.AdaptiveMaxPool3d((7, 8, 9)).eval(), inp)
 
 
+def test_forward_reflection_pad2d():
+    inp = torch.rand((1, 1, 3, 3))
+    verify_model(torch.nn.ReflectionPad2d(2).eval(), inp)
+    verify_model(torch.nn.ReflectionPad2d((1, 1, 2, 0)).eval(), inp)
+
+    inp = torch.rand((2, 4, 5, 6))
+    verify_model(torch.nn.ReflectionPad2d((1, 3, 2, 4)).eval(), inp)
+
+
 def test_conv3d():
     for ishape in [(1, 32, 16, 16, 16),
                    (1, 32, 9, 15, 15),
@@ -2183,6 +2192,7 @@ if __name__ == "__main__":
     test_forward_split()
     test_upsample()
     test_to()
+    test_forward_reflection_pad2d()
     test_adaptive_pool3d()
     test_conv3d()