from . import op
+def _get_name_static(canonical, dtype, shape):
+ """Get name for static shape tensor array op corresponding
+ to the canonical name"""
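+ # For example, _get_name_static('tensor_t', 'float32', [2, 3]) returns
+ # 'static_tensor_float32_2_3_t', while _get_name_static('tensor_take',
+ # 'float32', [2, 3]) returns 'tensor_take_float32_2_3'.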
+ shape_str = '_'.join([str(dim) for dim in shape])
+ if len(shape_str) == 0:
+ shape_str = "scalar"
+ if canonical == 'tensor_t':
+ return 'static_tensor_{}_{}_t'.format(dtype, shape_str)
+ return "{}_{}_{}".format(canonical, dtype, shape_str)
+
+class StaticTensorArrayOps(object):
+ """Contains tensor array related ops for fixed rank tensor array"""
+
+ def __init__(self, prelude, dtype, shape):
+ """Create tensor array ops registry"""
+ self.prelude = prelude
+ self.dtype = dtype
+ self.shape = shape
+
+ def get_name(self, canonical):
+ """Get name corresponding to the canonical name"""
+ return _get_name_static(canonical, self.dtype, self.shape)
+
+ def get_var(self, canonical):
+ """Get var corresponding to the canonical name"""
+ name = self.get_name(canonical)
+ return getattr(self.prelude, name)
+
+ def define_tensor_adt(self):
+ """Defines the static tensor ADT, which is the container for tensors
+ with fixed shapes."""
+ tensor_type_name = self.get_name('tensor_t')
+ # Skip registration if the tensor type is already registered.
+ global_type_names = set()
+ for g_ty_var in self.prelude.mod.get_global_type_vars():
+ global_type_names.add(g_ty_var.name_hint)
+ if tensor_type_name in global_type_names:
+ return
+
+ tensor_type_var = GlobalTypeVar(tensor_type_name)
+ setattr(self.prelude, tensor_type_name, tensor_type_var)
+ tensor_type = TensorType(self.shape, self.dtype)
+ tensor_constructor_name = self.get_name('tensor_constructor')
+
+ tensor_nil_name = self.get_name('tensor_nil')
+ tensor_nil_case = Constructor(tensor_nil_name, [], tensor_type_var)
+ tensor_case = Constructor(tensor_constructor_name, [tensor_type], tensor_type_var)
+
+ setattr(self.prelude, tensor_nil_name, tensor_nil_case)
+ setattr(self.prelude, tensor_constructor_name, tensor_case)
+ self.prelude.mod[tensor_type_var] = TypeData(tensor_type_var,
+ [],
+ [tensor_nil_case, tensor_case])
+
+ def define_tensor_array(self):
+ """Defines a function to create a tensor array with size n.
+ tensor_array(n) : Tensor[(), int32] -> list[tensor_t]
+ """
+ tensor_array_constructor_name = self.get_name("tensor_array")
+ tensor_array_constructor_var = self._create_global_var(tensor_array_constructor_name)
+ setattr(self.prelude, tensor_array_constructor_name, tensor_array_constructor_var)
+ tensor_nil_var = self.get_var('tensor_nil')
+ tensor_type_var = self.get_var('tensor_t')
+ n = Var("x", scalar_type('int32'))
+ body = If(equal(n, const(0)),
+ self.prelude.nil(),
+ self.prelude.cons(tensor_nil_var(),
+ tensor_array_constructor_var(subtract(n, const(1)))))
+ self.prelude.mod[tensor_array_constructor_var] = \
+ Function([n], body, self.prelude.l(tensor_type_var()), [])
+
+ def define_tensor_take(self):
+ """Defines a function to return a range of tensor_t on axis 0.
+ tensor_take(t, lower, upper) :
+ tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t
+ """
+ # We don't register take for scalar tensor.
+ ndim = len(self.shape)
+ if ndim == 0:
+ return
+
+ take_name = self.get_name("tensor_take")
+ take_var = self._create_global_var(take_name)
+ setattr(self.prelude, take_name, take_var)
+ origin_tensor_constructor = self.get_var('tensor_constructor')
+
+ output_shape = [Any(),] + list(self.shape[1:])
+ tensor_type_var, tensor_constructor = \
+ self._get_adt_by_shape(output_shape)
+
+ t = Var('tensor', self.get_var('tensor_t')())
+ lower = Var('lower', scalar_type('int32'))
+ upper = Var('upper', scalar_type('int32'))
+ tvar = Var('t')
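+ # Match out the underlying tensor and gather rows [lower, upper) along axis 0.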
+ case = Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(tvar)]),
+ tensor_constructor(op.take(tvar,
+ op.arange(lower, upper, dtype='int32'),
+ axis=0)))
+ self.prelude.mod[take_var] = \
+ Function([t, lower, upper],
+ Match(t, [case], False), tensor_type_var(), [])
+
+ def define_tensor_concatenate(self):
+ """Defines a function to concatenate two tensor_t on axis 0.
+ tensor_concatenate(t) : tensor_t -> tensor_t -> tensor_t
+ """
+ # We don't register concatenate for scalar tensor.
+ ndim = len(self.shape)
+ if ndim == 0:
+ return
+
+ concat_name = self.get_name("tensor_concatenate")
+ concat_var = self._create_global_var(concat_name)
+ setattr(self.prelude, concat_name, concat_var)
+ output_shape = [Any(),] + list(self.shape[1:])
+ tensor_type_var, tensor_constructor = \
+ self._get_adt_by_shape(output_shape)
+
+ origin_tensor_constructor = self.get_var('tensor_constructor')
+ origin_tensor_type_var = self.get_var('tensor_t')
+ x = Var("x", origin_tensor_type_var())
+ y = Var("y", origin_tensor_type_var())
+ t1 = Var("t1")
+ t2 = Var("t2")
+
+ case = Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(t1)]),
+ Match(y,
+ [Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(t2)]),
+ tensor_constructor(op.concatenate([t1, t2], axis=0)))],
+ False))
+
+ self.prelude.mod[concat_var] = \
+ Function([x, y], Match(x, [case], False), tensor_type_var(), [])
+
+ def define_tensor_expand_dims(self):
+ """Defines a function to grow a tensor_t's rank by adding one dimension in front
+ of the original tensor_t.
+ tensor_expand_dims(t) : tensor_t -> tensor_t
+ """
+ expand_dims_name = self.get_name("tensor_expand_dims")
+ expand_dims_var = self._create_global_var(expand_dims_name)
+ setattr(self.prelude, expand_dims_name, expand_dims_var)
+ origin_tensor_type_var = self.get_var('tensor_t')
+ origin_tensor_constructor = self.get_var('tensor_constructor')
+ x = Var("x", origin_tensor_type_var())
+
+ # Note: we set the added axis to Any() instead of 1 because the stack op
+ # recursively concatenates along this axis, so it must stay symbolic.
+ tensor_type_var, tensor_constructor = \
+ self._get_adt_by_shape([Any(),] + list(self.shape))
+ t = Var("t")
+ case = Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(t)]),
+ tensor_constructor(op.expand_dims(t, 0, 1)))
+
+ self.prelude.mod[expand_dims_var] = \
+ Function([x], Match(x, [case], False), tensor_type_var(), [])
+
+ def define_tensor_array_read(self):
+ """Defines a function to get the nth element of a list. Assume the list has at least one
+ element.
+ tensor_array_read(ta, n) : list[static_tensor_t] -> Tensor[(), int32] ->
+ static_tensor_t
+ """
+ read_name = self.get_name("tensor_array_read")
+ read_var = self._create_global_var(read_name)
+ setattr(self.prelude, read_name, read_var)
+ tensor_type_var = self.get_var('tensor_t')
+
+ tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
+ n = Var("x", scalar_type('int32'))
+ self.prelude.mod[read_var] = \
+ Function([tensor_array, n], self.prelude.nth(tensor_array, n), tensor_type_var(), [])
+
+ def define_tensor_array_write(self):
+ """Defines a function to update a tensor array at index n with value v.
+ tensor_array_write(ta, n, v) :
+ list[static_tensor_t] -> Tensor[(), int32] -> static_tensor_t ->
+ list[static_tensor_t]
+ """
+ write_name = self.get_name("tensor_array_write")
+ write_var = self._create_global_var(write_name)
+ setattr(self.prelude, write_name, write_var)
+ tensor_type_var = self.get_var('tensor_t')
+ tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
+ n = Var("x", scalar_type('int32'))
+ v = Var("v", tensor_type_var())
+ self.prelude.mod[write_var] = \
+ Function([tensor_array, n, v], self.prelude.update(tensor_array, n, v),
+ self.prelude.l(tensor_type_var()), [])
+
+ def define_tensor_array_unstack(self):
+ """Defines a function to unstack the values of a tensor_t in a tensor array.
+ tensor_array_unstack(t) : tensor_t -> list[tensor_t]
+ """
+ ndim = len(self.shape)
+ # We don't register unstack for scalar tensor array
+ if ndim == 0:
+ return
+
+ helper_name = self.get_name("tensor_array_unstack_helper")
+ helper_var = self._create_global_var(helper_name)
+ setattr(self.prelude, helper_name, helper_var)
+ tensor = Var("t", TensorType(self.shape, self.dtype))
+ up = Var("up", scalar_type('int32'))
+ i = Var("i", scalar_type('int32'))
+ tensor_var = Var("tensor", TensorType(self.shape, self.dtype))
+
+ reduced_tensor_type_var, tensor_constructor = \
+ self._get_adt_by_shape(self.shape[1:])
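+ # Recursively take the i-th slice along axis 0 until i reaches the length `up`.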
+ helper_body = \
+ If(equal(i, up),
+ self.prelude.nil(),
+ self.prelude.cons(tensor_constructor(op.take(tensor, i, axis=0)),
+ helper_var(add(i, const(1)), up, tensor)))
+ self.prelude.mod[helper_var] = \
+ Function([i, up, tensor], helper_body, self.prelude.l(reduced_tensor_type_var()), [])
+
+ unstack_name = self.get_name("tensor_array_unstack")
+ unstack_var = self._create_global_var(unstack_name)
+ setattr(self.prelude, unstack_name, unstack_var)
+ shape = op.shape_of(tensor_var)
+ unstack_length = op.take(shape, const(0))
+ self.prelude.mod[unstack_var] = \
+ Function([tensor_var], helper_var(const(0), unstack_length, tensor_var),
+ self.prelude.l(reduced_tensor_type_var()), [])
+
+ def define_tensor_array_scatter(self, indices_shape=None, force_update=False):
+ """Defines a function to scatter the values of a tensor_t in indices of a tensor array.
+ tensor_array_scatter(ta, indices, value) :
+ list[tensor_t] -> Tensor[(Any), int32] -> tensor_t -> list[tensor_t]
+
+ Set static indices shape by specifying indices_shape.
+ Set force_update to get static indices shape operator.
+ """
+ # When this operator has already been registered, only update
+ # when force_update is set. This should be used only when we need to
+ # redefine this op for static indices shape.
+ tensor_array_scatter_name = self.get_name("tensor_array_scatter")
+ if hasattr(self.prelude, tensor_array_scatter_name) and not force_update:
+ return
+
+ tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper")
+ tensor_array_scatter_helper_var = \
+ self._create_global_var(tensor_array_scatter_helper_name)
+ tensor_type_var = self.get_var('tensor_t')
+ ta = Var("ta", self.prelude.l(tensor_type_var()))
+ current = Var("current", scalar_type('int32'))
+ limit = Var("limit", scalar_type('int32'))
+ indices_ = Var('indices_', TensorType(indices_shape or [Any()], 'int32'))
+ values_ = Var('values_', self.prelude.l(tensor_type_var()))
+ write_var = self.get_var('tensor_array_write')
+ read_var = self.get_var('tensor_array_read')
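+ # Loop from current to limit, writing values_[current] into ta at
+ # position indices_[current].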
+ helper_body = If(equal(current, limit),
+ ta,
+ tensor_array_scatter_helper_var(
+ write_var(ta, op.take(indices_, current),
+ read_var(values_, current)),
+ add(current, const(1)),
+ limit, indices_, values_))
+ self.prelude.mod[tensor_array_scatter_helper_var] = \
+ Function([ta, current, limit, indices_, values_],
+ helper_body, self.prelude.l(tensor_type_var()), [])
+
+ tensor_array_scatter_var = self._create_global_var(tensor_array_scatter_name)
+ setattr(self.prelude, tensor_array_scatter_name, tensor_array_scatter_var)
+ tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
+
+ indices = Var('indices', TensorType(indices_shape or [Any()], 'int32'))
+ values = Var('values', self.prelude.l(tensor_type_var()))
+ if indices_shape is None:
+ indices_shape = op.shape_of(indices)
+ limit = op.take(indices_shape, const(0))
+ else:
+ limit = const(indices_shape[0])
+
+ body = tensor_array_scatter_helper_var(tensor_array, const(0), limit, indices, values)
+ self.prelude.mod[tensor_array_scatter_var] = \
+ Function([tensor_array, indices, values], body,
+ self.prelude.l(tensor_type_var()), [])
+
+ def define_tensor_array_split(self,
+ value_shape=None,
+ lengths_shape=None,
+ force_update=False):
+ """Defines a function to split the values of a tensor_t into a tensor array.
+ tensor_array_split(ta, value, lengths) :
+ list[tensor_t] -> tensor_t -> Tensor[(Any), int32] -> list[tensor_t]
+
+ Set static value and lengths shapes by specifying value_shape and lengths_shape.
+ Set force_update to True to redefine the op with the static shapes.
+ """
+ # Skip scalar case
+ ndim = len(self.shape)
+ if ndim == 0:
+ return
+
+ # When this operator has already been registered, only update
+ # when force_update is set. This should be used only when we need to
+ # redefine this op for static value/lengths shapes.
+ split_name = self.get_name("tensor_array_split")
+ if hasattr(self.prelude, split_name) and not force_update:
+ return
+
+ tensor_type_var = self.get_var('tensor_t')
+ tensor_array_split_helper_name = self.get_name("ta_split_helper")
+ tensor_array_split_helper_var = \
+ self._create_global_var(tensor_array_split_helper_name)
+ setattr(self.prelude, tensor_array_split_helper_name, tensor_array_split_helper_var)
+ output_shape = [Any(),] + list(self.shape[1:])
+ output_tensor_type_var, _ = self._get_adt_by_shape(output_shape)
+
+ if value_shape is None:
+ value_type_var = tensor_type_var
+ take_var = self.get_var('tensor_take')
+ else:
+ value_type_var, _ = self._get_adt_by_shape(value_shape)
+ # Also get static shape take operator
+ origin_shape = list(self.shape)
+ self.shape = value_shape
+ self.define_tensor_take()
+ take_var = self.get_var('tensor_take')
+ self.shape = origin_shape
+
+ ta1 = Var("tensor_array", self.prelude.l(output_tensor_type_var()))
+ value1 = Var('value1', value_type_var())
+ offset1 = Var('offset1', scalar_type('int32'))
+ current1 = Var('current1', scalar_type('int32'))
+ limit1 = Var('limit1', scalar_type('int32'))
+ lengths1 = Var('lengths', TensorType(lengths_shape or [Any()], 'int32'))
+
+ # Register write for output shape
+ origin_shape = list(self.shape)
+ self.shape = output_shape
+ self.define_tensor_array_write()
+ write_var = self.get_var('tensor_array_write')
+ self.shape = origin_shape
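+ # Each step writes the slice value[offset : offset + lengths[current]] into
+ # slot `current` of the array produced by the recursive call.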
+ helper1_body = If(equal(current1, limit1),
+ ta1,
+ write_var(
+ tensor_array_split_helper_var(
+ ta1,
+ value1,
+ add(offset1, op.take(lengths1, current1)),
+ add(current1, const(1)),
+ limit1,
+ lengths1
+ ),
+ current1,
+ take_var(value1,
+ offset1,
+ add(op.take(lengths1, current1), offset1))))
+ self.prelude.mod[tensor_array_split_helper_var] = \
+ Function([ta1, value1, offset1, current1, limit1, lengths1],
+ helper1_body, self.prelude.l(output_tensor_type_var()), [])
+ split_var = self._create_global_var(split_name)
+ setattr(self.prelude, split_name, split_var)
+ tensor_array = Var("tensor_array", self.prelude.l(output_tensor_type_var()))
+
+ value = Var('value', value_type_var())
+ lengths = Var('lengths', TensorType(lengths_shape or [Any()], 'int32'))
+ if lengths_shape is None:
+ lengths_shape = op.shape_of(lengths)
+ lengths_limit = op.take(lengths_shape, const(0))
+ else:
+ lengths_limit = const(lengths_shape[0])
+ body = tensor_array_split_helper_var(
+ tensor_array,
+ value,
+ const(0),
+ const(0),
+ lengths_limit,
+ lengths)
+
+ self.prelude.mod[split_var] = \
+ Function([tensor_array, value, lengths], body,
+ self.prelude.l(output_tensor_type_var()), [])
+
+ def define_tensor_array_concat(self):
+ """Defines a function to return the values in the tensor array as concatenated tensor_t.
+ tensor_array_concat(ta) : list[tensor_t] -> tensor_t
+ """
+ # We don't register concat for scalar tensor array.
+ ndim = len(self.shape)
+ if ndim == 0:
+ return
+
+ concat_name = self.get_name("tensor_array_concat")
+ concat_var = self._create_global_var(concat_name)
+ setattr(self.prelude, concat_name, concat_var)
+
+ output_shape = [Any(),] + list(self.shape[1:])
+ tensor_type_var, _ = self._get_adt_by_shape(output_shape)
+
+ # Register tensor concatenate and get tensor_nil var for output shape
+ origin_shape = self.shape
+ self.shape = output_shape
+ self.define_tensor_concatenate()
+ tensor_concat_var = self.get_var('tensor_concatenate')
+ tensor_nil_var = self.get_var('tensor_nil')
+ self.shape = origin_shape
+
+ tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
+ hd = Var("hd")
+ tl = Var("tl")
+ nil_case = Clause(PatternConstructor(self.prelude.nil), tensor_nil_var())
+ cons_case = Clause(PatternConstructor(self.prelude.cons, [PatternVar(hd), PatternVar(tl)]),
+ Match(tl, [
+ Clause(PatternConstructor(self.prelude.nil), hd),
+ Clause(PatternWildcard(),
+ tensor_concat_var(hd, concat_var(tl)))
+ ], False))
+ self.prelude.mod[concat_var] = \
+ Function([tensor_array],
+ Match(tensor_array, [nil_case, cons_case], False), tensor_type_var(), [])
+
+ def define_tensor_array_stack(self):
+ """Defines a function to get the values in the tensor array as a stack tensor_t.
+ tensor_array_stack(l) : list[tensor_t] -> tensor_t
+ """
+ stack_name = self.get_name("tensor_array_stack")
+ stack_var = self._create_global_var(stack_name)
+ setattr(self.prelude, stack_name, stack_var)
+ tensor_type_var = self.get_var('tensor_t')
+ tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
+ expand_dims_var = self.get_var('tensor_expand_dims')
+
+ # Register tensor_concatenate for output_shape
+ origin_shape = self.shape
+ output_shape = [Any(),] + list(self.shape)
+ self.shape = output_shape
+ self.define_tensor_concatenate()
+ concat_var = self.get_var('tensor_concatenate')
+ self.shape = origin_shape
+
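+ # expand_dims every element to rank len(shape) + 1, then fold the list with
+ # tensor_concatenate along the new leading axis.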
+ tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array)
+ tensors = self.prelude.foldl(concat_var,
+ self.prelude.hd(tensor_array_expand_dims),
+ self.prelude.tl(tensor_array_expand_dims))
+ output_tensor_type_var, _ = self._get_adt_by_shape(output_shape)
+ self.prelude.mod[stack_var] = Function([tensor_array], tensors,
+ output_tensor_type_var(), [])
+
+ def define_tensor_array_gather(self):
+ """Defines a function to return the selected values in a tensor array as tensor_t.
+ tensor_array_gather(ta, indices) : list[tensor_t] -> Tensor[(Any), int32] -> tensor_t
+ """
+ helper_name = self.get_name("tensor_array_gather_helper")
+ helper_var = self._create_global_var(helper_name)
+ setattr(self.prelude, helper_name, helper_var)
+ tensor_type_var = self.get_var('tensor_t')
+ output_shape = [Any(),] + list(self.shape)
+ output_tensor_type_var, _ = self._get_adt_by_shape(output_shape)
+ stack_var = self.get_var('tensor_array_stack')
+ read_var = self.get_var('tensor_array_read')
+ ta = Var("ta", self.prelude.l(tensor_type_var()))
+ accu = Var("accu", self.prelude.l(tensor_type_var()))
+ current = Var("current", scalar_type('int32'))
+ limit = Var("limit", scalar_type('int32'))
+ indices_ = Var('indices_', TensorType([Any()], 'int32'))
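+ # Walk the indices from the back, consing read elements onto the accumulator,
+ # then stack the accumulated list once current reaches 0.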
+ helper_body = \
+ If(equal(current, const(0)),
+ stack_var(accu),
+ helper_var(
+ ta,
+ self.prelude.cons(
+ read_var(
+ ta, op.take(indices_, subtract(current, const(1)))), accu),
+ subtract(current, const(1)),
+ limit, indices_))
+ self.prelude.mod[helper_var] = \
+ Function([ta, accu, current, limit, indices_],
+ helper_body, output_tensor_type_var(), [])
+ gather_name = self.get_name("tensor_array_gather")
+ gather_var = self._create_global_var(gather_name)
+ setattr(self.prelude, gather_name, gather_var)
+ tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
+ indices = Var('indices', TensorType([Any()], 'int32'))
+ indices_shape = op.shape_of(indices)
+ limit = op.take(indices_shape, const(0))
+ body = helper_var(tensor_array, self.prelude.nil(), limit, limit, indices)
+ self.prelude.mod[gather_var] = \
+ Function([tensor_array, indices], body, output_tensor_type_var(), [])
+
+ def define_tensor_get_data(self, data_shape):
+ """Defines a function to get a Tensor from tensor_t with given shape.
+ """
+ tensor_get_data_name = self.get_name("tensor_get_data")
+ tensor_get_data_var = self._create_global_var(tensor_get_data_name)
+ setattr(self.prelude, tensor_get_data_name, tensor_get_data_var)
+
+ tensor_type_var, tensor_constructor = self._get_adt_by_shape(data_shape)
+ t = Var('tensor', tensor_type_var())
+ tvar = Var('t')
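+ # Unwrap the tensor constructor to recover the underlying Tensor value.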
+ case =\
+ Clause(PatternConstructor(tensor_constructor, [PatternVar(tvar)]), tvar)
+ self.prelude.mod[tensor_get_data_var] = \
+ Function([t], Match(t, [case], False),
+ TensorType(data_shape, self.dtype), [])
+
+ def register(self):
+ """Register all tensor array ops in Prelude"""
+ self.define_tensor_adt()
+ self.define_tensor_take()
+ self.define_tensor_concatenate()
+ self.define_tensor_expand_dims()
+ self.define_tensor_array()
+ self.define_tensor_array_read()
+ self.define_tensor_array_write()
+ self.define_tensor_array_unstack()
+ self.define_tensor_array_scatter()
+ self.define_tensor_array_split()
+ self.define_tensor_array_concat()
+ self.define_tensor_array_stack()
+ self.define_tensor_array_gather()
+
+ def _get_adt_by_shape(self, shape):
+ """Get ADT type and constructor with given shape."""
+ origin_shape = self.shape
+ self.shape = shape
+ self.define_tensor_adt()
+ tensor_type_var = self.get_var("tensor_t")
+ tensor_constructor = self.get_var("tensor_constructor")
+ self.shape = origin_shape
+ return tensor_type_var, tensor_constructor
+
+ def _create_global_var(self, name):
+ """Create a GlobalVar if doesn't exist in prelude."""
+ global_var_name_set = set()
+ for g_var_name in self.prelude.mod.get_global_vars():
+ global_var_name_set.add(g_var_name.name_hint)
+ if name not in global_var_name_set:
+ gvar = GlobalVar(name)
+ else:
+ gvar = self.prelude.mod.get_global_var(name)
+
+ return gvar
+
class TensorArrayOps(object):
"""Contains tensor array related ops"""
name = self.get_name(canonical, dtype)
return getattr(self, name)
+ def get_name_static(self, canonical, dtype, shape):
+ """Get name corresponding to the canonical name"""
+ return _get_name_static(canonical, dtype, shape)
+
+ def get_var_static(self, canonical, dtype, shape):
+ """Get var corresponding to the canonical name"""
+ name = self.get_name_static(canonical, dtype, shape)
+ return getattr(self, name)
+
def load_prelude(self):
"""Parses the Prelude from Relay's text format into a module."""
# TODO(@jroesch): we should remove this helper when we port over prelude
from tvm import relay
from tvm.relay.backend.interpreter import ConstructorValue
from tvm.relay import create_executor
-from tvm.relay.prelude import Prelude
+from tvm.relay.prelude import Prelude, StaticTensorArrayOps
from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr
import numpy as np
run('float32')
run('int32')
+def test_static_tensor_take():
+ def run(dtype, shape):
+ mod = tvm.IRModule()
+ p = Prelude(mod)
+ static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+ static_tensor_array_ops.register()
+
+ take = p.get_var_static('tensor_take', dtype, shape)
+ tensor_constructor = p.get_var_static('tensor_constructor', dtype, shape)
+ v = relay.var('v')
+ lower = relay.var('lower')
+ upper = relay.var('upper')
+ mod["main"] = relay.Function([v, lower, upper], take(tensor_constructor(v), lower, upper))
+ v_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
+ expected = [np.take(v_data, range(2, 5), axis=0)]
+ check_tensor_array(mod, expected, *(v_data, 2, 5), dtype=dtype)
+ expected = [np.take(v_data, range(0, 9), axis=0)]
+ check_tensor_array(mod, expected, *(v_data, 0, 9), dtype=dtype)
+ run('float32', [10, 10])
+ run('int32', [15, 11])
+
+
+def test_static_tensor_concatenate():
+ def run(dtype, shape):
+ mod = tvm.IRModule()
+ p = Prelude(mod)
+ static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+ static_tensor_array_ops.register()
+
+ concat = p.get_var_static('tensor_concatenate', dtype, shape)
+ tensor = p.get_var_static('tensor_constructor', dtype, shape)
+ v1 = relay.var('v1')
+ v2 = relay.var('v2')
+ mod["main"] = relay.Function([v1, v2], concat(tensor(v1),
+ tensor(v2)))
+ v1_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
+ v2_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
+ expected = [np.concatenate((v1_data, v2_data))]
+ check_tensor_array(mod, expected, *(v1_data, v2_data), dtype=dtype)
+ run('float32', [5,])
+ run('int32', [2, 3])
+
+
+def test_static_tensor_expand_dims():
+ def run(dtype, shape):
+ x = relay.var('x')
+ mod = tvm.IRModule()
+ p = Prelude(mod)
+ static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+ static_tensor_array_ops.register()
+
+ expand_dims_func = p.get_var_static('tensor_expand_dims', dtype, shape)
+ tensor = p.get_var_static('tensor_constructor', dtype, shape)
+ mod["main"] = relay.Function([x], expand_dims_func(tensor(x)))
+ x_np = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
+ expected = [np.expand_dims(x_np, axis=0)]
+ check_tensor_array(mod, expected, x_np)
+ run('float32', [])
+ run('int32', [2,])
+
+
+def test_static_tensor_array_constructor():
+ def run(dtype, shape):
+ mod = tvm.IRModule()
+ p = Prelude(mod)
+ static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+ static_tensor_array_ops.register()
+ tensor_constructor = p.get_name_static('tensor_constructor', dtype, shape)
+ assert tensor_constructor is not None
+ run('float32', [1, 1])
+
+
+def test_static_tensor_array_read():
+ def run(dtype, shape):
+ mod = tvm.IRModule()
+ p = Prelude(mod)
+ static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+ static_tensor_array_ops.register()
+
+ np_data_list = []
+ ta_length = 3
+ for _ in range(ta_length):
+ np_data_list.append(np.random.uniform(0, 10, size=shape).astype(dtype))
+
+ v0 = relay.var('v0')
+ v1 = relay.var('v1')
+ v2 = relay.var('v2')
+ n = relay.var('n')
+ tensor = p.get_var_static('tensor_constructor', dtype, shape)
+ tensor_array = p.get_var_static('tensor_array', dtype, shape)
+ init_tensor_array = tensor_array(relay.const(ta_length))
+ read_func = p.get_var_static('tensor_array_read', dtype, shape)
+ write_func = p.get_var_static('tensor_array_write', dtype, shape)
+ tensor_array0 = write_func(init_tensor_array, relay.const(0),
+ tensor(v0))
+ tensor_array1 = write_func(tensor_array0, relay.const(1),
+ tensor(v1))
+ tensor_array2 = write_func(tensor_array1, relay.const(2),
+ tensor(v2))
+
+ mod["main"] = relay.Function([v0, v1, v2, n], read_func(tensor_array2, n))
+ expected = [np_data_list[0]]
+ check_tensor_array(mod, expected, *list(np_data_list + [0]), dtype=dtype)
+ expected = [np_data_list[1]]
+ check_tensor_array(mod, expected, *list(np_data_list + [1]), dtype=dtype)
+ expected = [np_data_list[2]]
+ check_tensor_array(mod, expected, *list(np_data_list + [2]), dtype=dtype)
+ run('float32', [])
+ run('int32', [2, 3])
+
+
+def test_static_tensor_array_write():
+ def run(dtype, shape):
+ mod = tvm.IRModule()
+ p = Prelude(mod)
+ static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+ static_tensor_array_ops.register()
+
+ ta_length = 2
+ np_data_list = [np.random.uniform(0, 10, size=shape).astype(dtype) for _ in range(ta_length)]
+
+ v0 = relay.var('v0')
+ v1 = relay.var('v1')
+ tensor_array = p.get_var_static('tensor_array', dtype, shape)
+ init_tensor_array = tensor_array(relay.const(ta_length))
+ write_func = p.get_var_static('tensor_array_write', dtype, shape)
+ tensor = p.get_var_static('tensor_constructor', dtype, shape)
+ tensor_array0 = write_func(init_tensor_array, relay.const(0),
+ tensor(v0))
+ tensor_array1 = write_func(tensor_array0, relay.const(1), tensor(v1))
+ mod["main"] = relay.Function([v0, v1], tensor_array1)
+ expected = np_data_list
+ check_tensor_array(mod, expected, *np_data_list, dtype=dtype)
+ run('float32', [])
+ run('int32', [2, 3])
+
+
+def test_static_tensor_array_unstack():
+ def run(dtype, shape):
+ mod = tvm.IRModule()
+ p = Prelude(mod)
+ static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+ static_tensor_array_ops.register()
+
+ unstack_tensor = p.get_var_static('tensor_array_unstack', dtype, shape)
+ v = relay.var('v')
+ mod["main"] = relay.Function([v], unstack_tensor(v))
+ t = np.random.uniform(low=0, high=10, size=shape).astype(dtype)
+ expected = list(t)
+ check_tensor_array(mod, expected, t, dtype=dtype)
+ run('float32', [4])
+ run('int32', [2, 3])
+
+
+def test_static_tensor_array_scatter():
+ def run(dtype, shape, indices_shape=None):
+ mod = tvm.IRModule()
+ p = Prelude(mod)
+ static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+ static_tensor_array_ops.register()
+ if indices_shape is not None:
+ static_tensor_array_ops.define_tensor_array_scatter(indices_shape, True)
+
+ # tensor array
+ v1 = relay.var('v1')
+ v2 = relay.var('v2')
+ v3 = relay.var('v3')
+ tensor_array = p.get_var_static('tensor_array', dtype, shape)
+ tensor_array0 = tensor_array(relay.const(3))
+ write_func = p.get_var_static('tensor_array_write', dtype, shape)
+ scatter_func = p.get_var_static('tensor_array_scatter', dtype, shape)
+ tensor = p.get_var_static('tensor_constructor', dtype, shape)
+ tensor_array1 = write_func(tensor_array0, relay.const(0), tensor(v1))
+ tensor_array1 = write_func(tensor_array1, relay.const(1), tensor(v2))
+ tensor_array1 = write_func(tensor_array1, relay.const(2), tensor(v3))
+
+ # indices array
+ index = relay.var('index')
+
+ # values array
+ value_0 = relay.var('value_0')
+ value_1 = relay.var('value_1')
+ values_array = tensor_array(relay.const(2))
+ values_array = write_func(values_array, relay.const(0),
+ tensor(value_0))
+ values_array = write_func(values_array, relay.const(1),
+ tensor(value_1))
+
+ # create the scatter function
+ tensor_array_scatter = scatter_func(tensor_array1, index, values_array)
+ mod["main"] = relay.Function([v1, v2, v3, index, value_0, value_1],
+ tensor_array_scatter)
+
+ # initialize and check
+ v1_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
+ v2_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
+ v3_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
+ index_data = np.array([0, 1], dtype="int32")
+ val1_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
+ val2_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
+ expected = [val1_data, val2_data, v3_data]
+ check_tensor_array(mod, expected, *(v1_data, v2_data, v3_data,
+ index_data, val1_data,
+ val2_data), dtype=dtype)
+ run('float32', [2, 3])
+ run('int32', [2, 3])
+ run('float32', [2, 3], [2,])
+
+
+def test_static_tensor_array_split():
+ def run(dtype, shape, value_shape=None, lengths_shape=None):
+ mod = tvm.IRModule()
+ p = Prelude(mod)
+ static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+ static_tensor_array_ops.register()
+ if value_shape is not None or lengths_shape is not None:
+ static_tensor_array_ops.define_tensor_array_split(value_shape, lengths_shape, True)
+
+ # tensor array
+ v1 = relay.var('v1')
+ v2 = relay.var('v2')
+ v3 = relay.var('v3')
+
+ adt_shape = [relay.Any(),] + shape[1:]
+ origin_shape = static_tensor_array_ops.shape
+ static_tensor_array_ops.shape = adt_shape
+ static_tensor_array_ops.define_tensor_array()
+ tensor_array = p.get_var_static('tensor_array', dtype, adt_shape)
+ static_tensor_array_ops.shape = origin_shape
+ tensor_array1 = tensor_array(relay.const(3))
+ write_func = p.get_var_static('tensor_array_write', dtype, adt_shape)
+ split_func = p.get_var_static('tensor_array_split', dtype, shape)
+ tensor = p.get_var_static('tensor_constructor', dtype, adt_shape)
+ tensor_array1 = write_func(tensor_array1, relay.const(0), tensor(v1))
+ tensor_array1 = write_func(tensor_array1, relay.const(1), tensor(v2))
+ tensor_array1 = write_func(tensor_array1, relay.const(2), tensor(v3))
+
+ # value tensor
+ value = relay.var('value')
+
+ # lengths tensor
+ ta_len = relay.var('length')
+
+ # create the split function
+ if value_shape is None:
+ tensor1 = p.get_var_static('tensor_constructor', dtype, shape)
+ else:
+ static_tensor_array_ops = StaticTensorArrayOps(p, dtype, value_shape)
+ static_tensor_array_ops.register()
+ tensor1 = p.get_var_static('tensor_constructor', dtype, value_shape)
+ tensor_array_split = split_func(tensor_array1, tensor1(value), ta_len)
+ mod["main"] = relay.Function([v1, v2, v3, value, ta_len],
+ tensor_array_split)
+
+ # initialize and check
+ v1_data = np.random.uniform(low=0.0, high=8.0, size=[2, 3]).astype(dtype)
+ v2_data = np.random.uniform(low=0.0, high=8.0, size=[2, 3]).astype(dtype)
+ v3_data = np.random.uniform(low=0.0, high=8.0, size=[2, 3]).astype(dtype)
+ value_data = np.random.uniform(low=0.0, high=8.0,
+ size=value_shape or shape).astype(dtype)
+ length_data = np.array([2, 2], dtype="int32")
+ expected = np.concatenate([value_data, v3_data])
+ expected = np.split(expected, indices_or_sections=[2, 4])
+ check_tensor_array(mod, expected, *(v1_data, v2_data, v3_data,
+ value_data, length_data),
+ dtype=dtype)
+
+ run('float32', [4, 3])
+ run('int32', [4, 3])
+ run('int32', [relay.Any(), 3], [4, 3], [2,])
+
+
+def test_static_tensor_array_concat():
+ def run(dtype, shape):
+ mod = tvm.IRModule()
+ p = Prelude(mod)
+ static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+ static_tensor_array_ops.register()
+
+ v1 = relay.var('v1')
+ v2 = relay.var('v2')
+ tensor_array = p.get_var_static('tensor_array', dtype, shape)
+ tensor_array1 = tensor_array(relay.const(2))
+ write_func = p.get_var_static('tensor_array_write', dtype, shape)
+ concat_func = p.get_var_static('tensor_array_concat', dtype, shape)
+ tensor = p.get_var_static('tensor_constructor', dtype, shape)
+ tensor_array1 = write_func(tensor_array1, relay.const(0), tensor(v1))
+ tensor_array1 = write_func(tensor_array1, relay.const(1), tensor(v2))
+ tensor_array_concat = concat_func(tensor_array1)
+ mod["main"] = relay.Function([v1, v2], tensor_array_concat)
+ v1_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype)
+ v2_data = np.random.uniform(low=0.0, high=8.0, size=(1, 3)).astype(dtype)
+ expected = [np.concatenate((v1_data, v2_data), axis=0)]
+ check_tensor_array(mod, expected, *(v1_data, v2_data), dtype=dtype)
+ run('float32', [relay.Any(), 3])
+ run('int32', [relay.Any(), 3])
+
+
+def test_static_tensor_array_gather():
+ def run(dtype, shape):
+ mod = tvm.IRModule()
+ p = Prelude(mod)
+ static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+ static_tensor_array_ops.register()
+
+ tensor_array = p.get_var_static('tensor_array', dtype, shape)
+ tensor = p.get_var_static('tensor_constructor', dtype, shape)
+ write = p.get_var_static('tensor_array_write', dtype, shape)
+ gather = p.get_var_static('tensor_array_gather', dtype, shape)
+ v = relay.var('v')
+ indices = relay.var('indices')
+ init_tensor_array = tensor_array(relay.const(3))
+ tensor_array1 = write(init_tensor_array, relay.const(0), tensor(v))
+ tensor_array2 = write(tensor_array1, relay.const(1), tensor(v))
+ tensor_array3 = write(tensor_array2, relay.const(2), tensor(v))
+ out = gather(tensor_array3, indices)
+ mod["main"] = relay.Function([v, indices], out)
+ t = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
+ indices_data = np.array([0, 2], dtype="int32")
+ expected = [np.stack([t, t])]
+ check_tensor_array(mod, expected, *(t, indices_data), dtype=dtype)
+ run('float32', [])
+ run('int32', [2, 3])
+
+
+def test_static_tensor_array_stack():
+ def run(dtype, shape):
+ mod = tvm.IRModule()
+ p = Prelude(mod)
+ static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+ static_tensor_array_ops.register()
+
+ tensor_array = p.get_var_static('tensor_array', dtype, shape)
+ tensor = p.get_var_static('tensor_constructor', dtype, shape)
+ write = p.get_var_static('tensor_array_write', dtype, shape)
+ stack = p.get_var_static('tensor_array_stack', dtype, shape)
+ v = relay.var('v')
+ init_tensor_array = tensor_array(relay.const(3))
+ tensor_array1 = write(init_tensor_array, relay.const(0), tensor(v))
+ tensor_array2 = write(tensor_array1, relay.const(1), tensor(v))
+ tensor_array3 = write(tensor_array2, relay.const(2), tensor(v))
+ tensor_array4 = stack(tensor_array3)
+ mod["main"] = relay.Function([v], tensor_array4)
+ t = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
+ expected = [np.stack([t, t, t])]
+ check_tensor_array(mod, expected, t, dtype=dtype)
+ run('float32', [])
+ run('int32', [2, 3])
+
+
+def test_static_tensor_get_data():
+ def run(dtype, shape):
+ mod = tvm.IRModule()
+ p = Prelude(mod)
+ static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+ static_tensor_array_ops.register()
+ static_tensor_array_ops.define_tensor_get_data(shape)
+
+ np_data_list = []
+ ta_length = 3
+ for _ in range(ta_length):
+ np_data_list.append(np.random.uniform(0, 10, size=shape).astype(dtype))
+
+ v0 = relay.var('v0')
+ v1 = relay.var('v1')
+ v2 = relay.var('v2')
+ n = relay.var('n')
+ tensor = p.get_var_static('tensor_constructor', dtype, shape)
+ tensor_array = p.get_var_static('tensor_array', dtype, shape)
+ init_tensor_array = tensor_array(relay.const(ta_length))
+ read_func = p.get_var_static('tensor_array_read', dtype, shape)
+ write_func = p.get_var_static('tensor_array_write', dtype, shape)
+ get_data_func = p.get_var_static('tensor_get_data', dtype, shape)
+ tensor_array0 = write_func(init_tensor_array, relay.const(0),
+ tensor(v0))
+ tensor_array1 = write_func(tensor_array0, relay.const(1),
+ tensor(v1))
+ tensor_array2 = write_func(tensor_array1, relay.const(2),
+ tensor(v2))
+
+ mod["main"] = relay.Function([v0, v1, v2, n], get_data_func(read_func(tensor_array2, n)))
+ expected = [np_data_list[0]]
+ check_tensor_array(mod, expected, *list(np_data_list + [0]), dtype=dtype)
+ expected = [np_data_list[1]]
+ check_tensor_array(mod, expected, *list(np_data_list + [1]), dtype=dtype)
+ expected = [np_data_list[2]]
+ check_tensor_array(mod, expected, *list(np_data_list + [2]), dtype=dtype)
+ run('float32', [])
+ run('int32', [2, 3])
if __name__ == "__main__":
test_nat_constructor()
test_tensor_array_concat()
test_tensor_array_scatter()
test_tensor_array_split()
+
+ test_static_tensor_take()
+ test_static_tensor_concatenate()
+ test_static_tensor_expand_dims()
+ test_static_tensor_array_constructor()
+ test_static_tensor_array_read()
+ test_static_tensor_array_write()
+ test_static_tensor_array_unstack()
+ test_static_tensor_array_scatter()
+ test_static_tensor_array_split()
+ test_static_tensor_array_concat()
+ test_static_tensor_array_stack()
+ test_static_tensor_array_gather()
+ test_static_tensor_get_data()