[Relay, Torch] Fix stack op axis check, support torch::stack conversion for a static...
authormasahi <masahi129@gmail.com>
Thu, 10 Sep 2020 11:42:49 +0000 (20:42 +0900)
committerGitHub <noreply@github.com>
Thu, 10 Sep 2020 11:42:49 +0000 (20:42 +0900)
* fix torch::stack conversion, add dynamic stack test

* add test to relay stack

* add comment

* add more comment

* uncomment relay op tests

* check for List ADT properly

* improve assertion

Co-authored-by: masa <masa@pop-os.localdomain>
python/tvm/relay/frontend/pytorch.py
src/relay/op/tensor/transform.cc
tests/python/frontend/pytorch/test_forward.py
tests/python/frontend/pytorch/test_lstm.py
tests/python/relay/test_op_level3.py

index 51d90e1..7203150 100644 (file)
@@ -96,7 +96,7 @@ def _should_construct_dynamic_list(list_construct_node):
 
     op_names = map(inplace_add_to_add, set(use.user.kind() for use in uses))
 
-    list_ops = set(["aten::add", "aten::__getitem__", "aten::stack"])
+    list_ops = set(["aten::add", "aten::__getitem__"])
     intersect = list_ops.intersection(op_names)
 
     if len(intersect) > 0 and intersect != set(["aten::add"]):
@@ -1744,6 +1744,8 @@ def _add(prelude):
 
 def _tensor_array_stack(prelude):
     def _impl(inputs, input_types):
+        dim = inputs[1]
+        assert dim == 0, "stacking on a dynamic tensor list only supported on a first axis"
         tensor_array, shape = _convert_to_tensor_array(inputs[0], prelude)
 
         stacked_shape = (Any(),) + shape
@@ -1757,6 +1759,23 @@ def _tensor_array_stack(prelude):
     return _impl
 
 
+def _stack(prelude):
+    def _impl(inputs, input_types):
+        if isinstance(inputs[0], list):
+            # a static python list of tensors
+            dim = inputs[1]
+            return _op.stack(inputs[0], dim)
+        else:
+            # List ADT case
+            assert isinstance(inputs[0], _expr.Expr)
+            ty = _infer_type_with_prelude(inputs[0], prelude)
+            list_ty = prelude.mod.get_global_type_var("List")
+            msg = "The input list is expected to be List ADT"
+            assert isinstance(ty, tvm.ir.TypeCall) and ty.func == list_ty, msg
+            return _tensor_array_stack(prelude)(inputs, input_types)
+    return _impl
+
+
 def _rsub():
     def _impl(inputs, input_types):
         data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2])
@@ -2193,10 +2212,9 @@ def _get_convert_map(prelude, default_dtype):
         "aten::embedding"                       : _embedding(),
         "aten::one_hot"                         : _one_hot(),
         "aten::mm"                              : _matmul(prelude),
-        "relay::tensor_array_stack"             : _tensor_array_stack(prelude),
         "aten::add"                             : _add(prelude),
         "aten::add_"                            : _add(prelude),
-        "aten::stack"                           : _tensor_array_stack(prelude),
+        "aten::stack"                           : _stack(prelude),
         "aten::__getitem__"                     : _list_getitem(prelude),
         "aten::len"                             : _list_len(prelude),
         "aten::type_as"                         : _type_as(),
index 293875e..bf6ce4d 100644 (file)
@@ -295,8 +295,9 @@ bool StackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 
   // Sanity check: axis
   int axis = param->axis;
-  CHECK(-ndim <= axis && axis < ndim) << "stack only accepts `axis` in [-ndim, ndim)"
-                                      << ", but got axis = " << axis << ", and ndim = " << ndim;
+  CHECK(-(ndim + 1) <= axis && axis < ndim + 1)
+      << "stack only accepts `axis` in [-(ndim+1), ndim+1)"
+      << ", but got axis = " << axis << ", and ndim = " << ndim;
   axis = axis < 0 ? ndim + axis + 1 : axis;
 
   // Sanity check: ndim and dtype.
index fe14c91..2ce669f 100644 (file)
@@ -2925,6 +2925,35 @@ def test_logsumexp():
     verify_model(Logsumexp(1, keepdim=True), input_data=input_data.double())
 
 
+def test_stack():
+    class Stack(torch.nn.Module):
+        def __init__(self, axis=0):
+            super().__init__()
+            self.axis = axis
+
+        def forward(self, x):
+            return torch.stack((x, x), dim=self.axis)
+
+    inp = torch.randn(8, 8, 8)
+    verify_model(Stack(), input_data=inp)
+    verify_model(Stack(axis=-1), input_data=inp)
+    verify_model(Stack(axis=3), input_data=inp)
+    verify_model(Stack(axis=-4), input_data=inp)
+
+
+def test_stack_dynamic():
+    class Stack(torch.nn.Module):
+        def forward(self, x):
+            tensor_list = []
+            for i in range(x.size(0)):
+                # this is a workaround to avoid generating impure aten::append op
+                tensor_list += [x[i]]
+            # relay tensor array only supports stacking on the first axis
+            return torch.stack(tensor_list, dim=0)
+
+    verify_script_model(Stack(), [(8, 8, 8)], _get_default_vm_targets())
+
+
 def test_forward_pretrained_bert_base_uncased():
     ######################################################################
     # This is an example how to run BERT models using TVM
@@ -3169,6 +3198,8 @@ if __name__ == "__main__":
     test_forward_index()
     test_min_max()
     test_logsumexp()
+    test_stack()
+    test_stack_dynamic()
 
     # Model tests
     test_resnet18()
index 4524a72..dcaa5e1 100644 (file)
@@ -26,6 +26,7 @@ from typing import List, Tuple
 from torch import Tensor
 
 import tvm
+import tvm.testing
 from tvm import relay
 from tvm.relay.frontend.pytorch import from_pytorch
 from tvm.relay.prelude import Prelude
index 940bb70..f709aa2 100644 (file)
@@ -714,6 +714,7 @@ def test_stack():
     verify_stack([(2,), (2,), (2,)], 0)
     verify_stack([(2, 2, 4), (2, 2, 4), (2, 2, 4)], 1)
     verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1)
+    verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], 4)
 
 
 @tvm.testing.uses_gpu
@@ -998,7 +999,7 @@ def test_isfinite():
 def test_isinf():
     _verify_infiniteness_ops(relay.isinf, np.isinf)
 
-    
+
 @tvm.testing.uses_gpu
 def test_unravel_index():
     def verify_unravel_index(indices, shape, dtype):