[PYTORCH]Repeat, Reciprocal & Reshape Op support (#5280)
authorSamuel <siju.samuel@huawei.com>
Fri, 10 Apr 2020 15:08:56 +0000 (20:38 +0530)
committerGitHub <noreply@github.com>
Fri, 10 Apr 2020 15:08:56 +0000 (00:08 +0900)
python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py

index 46068a4..b8b32e7 100644 (file)
@@ -154,6 +154,34 @@ def _select():
         return _op.transform.take(data, index, axis=dim)
     return _impl
 
+def _reciprocal():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        return _expr.const(1.0) / data
+    return _impl
+
+def _repeat():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        reps = _get_dims(inputs[1])
+        return _op.transform.tile(data, reps=reps)
+    return _impl
+
+def _repeat_interleave():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        if isinstance(inputs[1], int):
+            repeats = inputs[1]
+            axis = inputs[2]
+        else:
+            msg = "Only repeat with one value as repeat is currently supported."
+            raise AssertionError(msg)
+        if axis is None: # Flatten the data if no axis is given from torch
+            data = _op.transform.reshape(data, [-1])
+            axis = 0
+        return _op.transform.repeat(data, repeats=repeats, axis=axis)
+    return _impl
+
 def _ones():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -675,6 +703,16 @@ def _view():
         return _op.transform.reshape(data, new_shape)
     return _impl
 
+def _reshape():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        if isinstance(inputs[1], list):
+            new_shape = inputs[1]
+        else:
+            new_shape = _infer_shape(inputs[1])
+        return _op.transform.reshape(data, new_shape)
+    return _impl
+
 def _clone():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -1082,6 +1120,9 @@ _convert_map = {
     "aten::div_"                            : _elemwise("divide"),
     "aten::ones"                            : _ones(),
     "aten::zeros"                           : _zeros(),
+    "aten::reciprocal"                      : _reciprocal(),
+    "aten::repeat"                          : _repeat(),
+    "aten::repeat_interleave"               : _repeat_interleave(),
     "aten::to"                              : _to(),
     "aten::squeeze"                         : _squeeze(),
     "aten::unsqueeze"                       : _unsqueeze(),
@@ -1122,6 +1163,7 @@ _convert_map = {
     "aten::addmm"                           : _dense(),
     "aten::size"                            : _size(),
     "aten::view"                            : _view(),
+    "aten::reshape"                         : _reshape(),
     "aten::clone"                           : _clone(),
     "aten::log_softmax"                     : _log_softmax(),
     "aten::sigmoid"                         : _sigmoid(),
index 05bf7e4..4226463 100644 (file)
@@ -293,6 +293,61 @@ def test_forward_multiply():
     verify_model(Multiply3().float().eval(), input_data=input_data)
     verify_model(Multiply4().float().eval(), input_data=input_data)
 
+def test_forward_reciprocal():
+    torch.set_grad_enabled(False)
+    input_shape = [2, 1, 10, 1, 10]
+    class Reciprocal1(Module):
+        def forward(self, *args):
+            return args[0].reciprocal()
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(Reciprocal1().float().eval(), input_data=input_data)
+
+def test_forward_repeat():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3]
+    class Repeat1(Module):
+        def forward(self, *args):
+            return args[0].repeat(1, 1)
+
+    class Repeat2(Module):
+        def forward(self, *args):
+            return args[0].repeat(4, 2)
+
+    class Repeat3(Module):
+        def forward(self, *args):
+            return args[0].repeat(4, 2, 1)
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(Repeat1().float().eval(), input_data=input_data)
+    verify_model(Repeat2().float().eval(), input_data=input_data)
+    verify_model(Repeat3().float().eval(), input_data=input_data)
+
+def test_forward_repeat_interleave():
+    torch.set_grad_enabled(False)
+    input_shape = [2, 2, 3]
+    class RepeatInterleave1(Module):
+        def forward(self, *args):
+            return args[0].repeat_interleave(2)
+
+    class RepeatInterleave2(Module):
+        def forward(self, *args):
+            return args[0].repeat_interleave(3, dim=0)
+
+    class RepeatInterleave3(Module):
+        def forward(self, *args):
+            return args[0].repeat_interleave(2, dim=1)
+
+    class RepeatInterleave4(Module):
+        def forward(self, *args):
+            return args[0].repeat_interleave(4, dim=2)
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(RepeatInterleave1().float().eval(), input_data=input_data)
+    verify_model(RepeatInterleave2().float().eval(), input_data=input_data)
+    verify_model(RepeatInterleave3().float().eval(), input_data=input_data)
+    verify_model(RepeatInterleave4().float().eval(), input_data=input_data)
+
 def test_forward_unsqueeze():
     torch.set_grad_enabled(False)
     input_shape = [10, 10]
@@ -600,6 +655,22 @@ def test_forward_layernorm():
         init_weight(ln.eval())
         verify_model(ln.eval(), input_data=inp)
 
+def test_forward_reshape():
+    torch.set_grad_enabled(False)
+    input_shape = [2, 1, 10, 1, 10]
+    new_shape = [2, 1, 10, 10]
+    class Reshape1(Module):
+        def forward(self, *args):
+            return args[0].reshape(new_shape)
+
+    class Reshape2(Module):
+        def forward(self, *args):
+            return args[0].reshape([-1])
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(Reshape1().float().eval(), input_data=input_data)
+    verify_model(Reshape2().float().eval(), input_data=input_data)
+
 def test_forward_transpose():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
@@ -1151,6 +1222,10 @@ if __name__ == "__main__":
     test_forward_add()
     test_forward_subtract()
     test_forward_multiply()
+    test_forward_reshape()
+    test_forward_reciprocal()
+    test_forward_repeat()
+    test_forward_repeat_interleave()
     test_forward_squeeze()
     test_forward_unsqueeze()
     test_forward_concatenate()