op_names = map(inplace_add_to_add, set(use.user.kind() for use in uses))
- list_ops = set(["aten::add", "aten::__getitem__", "aten::stack"])
+ list_ops = set(["aten::add", "aten::__getitem__"])
intersect = list_ops.intersection(op_names)
if len(intersect) > 0 and intersect != set(["aten::add"]):
def _tensor_array_stack(prelude):
def _impl(inputs, input_types):
+ dim = inputs[1]
+ assert dim == 0, "stacking on a dynamic tensor list is only supported on the first axis"
tensor_array, shape = _convert_to_tensor_array(inputs[0], prelude)
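+ # the leading Any() dimension is the runtime length of the tensor list, which is unknown at compile time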
stacked_shape = (Any(),) + shape
return _impl
+def _stack(prelude):
+ def _impl(inputs, input_types):
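+ # inputs[0] is either a static Python list of tensors (e.g. a fixed-length
+ # prim::ListConstruct) or a Relay List ADT built dynamically; handle both cases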
+ if isinstance(inputs[0], list):
+ # a static Python list of tensors
+ dim = inputs[1]
+ return _op.stack(inputs[0], dim)
+ else:
+ # List ADT case
+ assert isinstance(inputs[0], _expr.Expr)
+ ty = _infer_type_with_prelude(inputs[0], prelude)
+ list_ty = prelude.mod.get_global_type_var("List")
+ msg = "The input list is expected to be a List ADT"
+ assert isinstance(ty, tvm.ir.TypeCall) and ty.func == list_ty, msg
+ return _tensor_array_stack(prelude)(inputs, input_types)
+ return _impl
+
+
def _rsub():
def _impl(inputs, input_types):
data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2])
"aten::embedding" : _embedding(),
"aten::one_hot" : _one_hot(),
"aten::mm" : _matmul(prelude),
- "relay::tensor_array_stack" : _tensor_array_stack(prelude),
"aten::add" : _add(prelude),
"aten::add_" : _add(prelude),
- "aten::stack" : _tensor_array_stack(prelude),
+ "aten::stack" : _stack(prelude),
"aten::__getitem__" : _list_getitem(prelude),
"aten::len" : _list_len(prelude),
"aten::type_as" : _type_as(),
// Sanity check: axis
int axis = param->axis;
- CHECK(-ndim <= axis && axis < ndim) << "stack only accepts `axis` in [-ndim, ndim)"
- << ", but got axis = " << axis << ", and ndim = " << ndim;
+ CHECK(-(ndim + 1) <= axis && axis < ndim + 1)
+ << "stack only accepts `axis` in [-(ndim+1), ndim+1)"
+ << ", but got axis = " << axis << ", and ndim = " << ndim;
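+ // stack inserts a new axis, so `axis` ranges over the output rank (ndim + 1);
+ // negative values are normalized against it below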
axis = axis < 0 ? ndim + axis + 1 : axis;
// Sanity check: ndim and dtype.
verify_model(Logsumexp(1, keepdim=True), input_data=input_data.double())
+def test_stack():
+ class Stack(torch.nn.Module):
+ def __init__(self, axis=0):
+ super().__init__()
+ self.axis = axis
+
+ def forward(self, x):
+ return torch.stack((x, x), dim=self.axis)
+
+ inp = torch.randn(8, 8, 8)
+ verify_model(Stack(), input_data=inp)
+ verify_model(Stack(axis=-1), input_data=inp)
+ verify_model(Stack(axis=3), input_data=inp)
+ verify_model(Stack(axis=-4), input_data=inp)
+
+
+def test_stack_dynamic():
+ class Stack(torch.nn.Module):
+ def forward(self, x):
+ tensor_list = []
+ for i in range(x.size(0)):
+ # this is a workaround to avoid generating an impure aten::append op
+ tensor_list += [x[i]]
+ # the Relay tensor array only supports stacking along the first axis
+ return torch.stack(tensor_list, dim=0)
+
+ verify_script_model(Stack(), [(8, 8, 8)], _get_default_vm_targets())
+
+
def test_forward_pretrained_bert_base_uncased():
######################################################################
# This is an example how to run BERT models using TVM
test_forward_index()
test_min_max()
test_logsumexp()
+ test_stack()
+ test_stack_dynamic()
# Model tests
test_resnet18()
from torch import Tensor
import tvm
+import tvm.testing
from tvm import relay
from tvm.relay.frontend.pytorch import from_pytorch
from tvm.relay.prelude import Prelude
verify_stack([(2,), (2,), (2,)], 0)
verify_stack([(2, 2, 4), (2, 2, 4), (2, 2, 4)], 1)
verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1)
+ verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], 4)
@tvm.testing.uses_gpu
def test_isinf():
_verify_infiniteness_ops(relay.isinf, np.isinf)
-
+
@tvm.testing.uses_gpu
def test_unravel_index():
def verify_unravel_index(indices, shape, dtype):