[Relay][ADT]Static Tensor Array (#5103)
authorYao Wang <kevinthesunwy@gmail.com>
Sun, 5 Apr 2020 20:42:28 +0000 (13:42 -0700)
committerGitHub <noreply@github.com>
Sun, 5 Apr 2020 20:42:28 +0000 (13:42 -0700)
* Add other static tensor array ops

* Add tensor array get data

* Minor refactor

* Fix pylint

* Update docstring

* Make get data more generic

* Improve test

* Improve split test

* Improve get data

* Minor fix

* Further improvement for static shape

* Improve shape parsing

* Unify get_static_name

python/tvm/relay/prelude.py
tests/python/relay/test_adt.py

index 0e64a2f..47c3ba7 100644 (file)
@@ -27,6 +27,545 @@ from .adt import PatternConstructor, PatternVar, PatternWildcard
 from . import op
 
 
+def _get_name_static(canonical, dtype, shape):
+    """Get name for static shape tensor array op corresponding
+    to the canonical name"""
+    shape_str = '_'.join([str(dim) for dim in shape])
+    if len(shape_str) == 0:
+        shape_str = "scalar"
+    if canonical == 'tensor_t':
+        return 'static_tensor_{}_{}_t'.format(dtype, shape_str)
+    return "{}_{}_{}".format(canonical, dtype, shape_str)
+
+class StaticTensorArrayOps(object):
+    """Contains tensor array related ops for fixed rank tensor array"""
+
+    def __init__(self, prelude, dtype, shape):
+        """Create tensor array ops registry"""
+        self.prelude = prelude
+        self.dtype = dtype
+        self.shape = shape
+
+    def get_name(self, canonical):
+        """Get name corresponding to the canonical name"""
+        return _get_name_static(canonical, self.dtype, self.shape)
+
+    def get_var(self, canonical):
+        """Get var corresponding to the canonical name"""
+        name = self.get_name(canonical)
+        return getattr(self.prelude, name)
+
+    def define_tensor_adt(self):
+        """Defines the static tensor ADT, which is the container for tensors
+        with fixed shapes."""
+        tensor_type_name = self.get_name('tensor_t')
+        # Skip register if tensor type is already registered.
+        global_type_names = set()
+        for g_ty_var in self.prelude.mod.get_global_type_vars():
+            global_type_names.add(g_ty_var.name_hint)
+        if tensor_type_name in global_type_names:
+            return
+
+        tensor_type_var = GlobalTypeVar(tensor_type_name)
+        setattr(self.prelude, tensor_type_name, tensor_type_var)
+        tensor_type = TensorType(self.shape, self.dtype)
+        tensor_constructor_name = self.get_name('tensor_constructor')
+
+        tensor_nil_name = self.get_name('tensor_nil')
+        tensor_nil_case = Constructor(tensor_nil_name, [], tensor_type_var)
+        tensor_case = Constructor(tensor_constructor_name, [tensor_type], tensor_type_var)
+
+        setattr(self.prelude, tensor_nil_name, tensor_nil_case)
+        setattr(self.prelude, tensor_constructor_name, tensor_case)
+        self.prelude.mod[tensor_type_var] = TypeData(tensor_type_var,
+                                                     [],
+                                                     [tensor_nil_case, tensor_case])
+
+    def define_tensor_array(self):
+        """Defines a function to create a tensor array with size n.
+        tensor_array(n) : Tensor[(), int32] -> list[tensor_t]
+        """
+        tensor_array_constructor_name = self.get_name("tensor_array")
+        tensor_array_constructor_var = self._create_global_var(tensor_array_constructor_name)
+        setattr(self.prelude, tensor_array_constructor_name, tensor_array_constructor_var)
+        tensor_nil_var = self.get_var('tensor_nil')
+        tensor_type_var = self.get_var('tensor_t')
+        n = Var("x", scalar_type('int32'))
+        body = If(equal(n, const(0)),
+                  self.prelude.nil(),
+                  self.prelude.cons(tensor_nil_var(),
+                                    tensor_array_constructor_var(subtract(n, const(1)))))
+        self.prelude.mod[tensor_array_constructor_var] = \
+            Function([n], body, self.prelude.l(tensor_type_var()), [])
+
+    def define_tensor_take(self):
+        """Defines a function to return a range of tensor_t on axis 0.
+            tensor_take(t, lower, upper) :
+            tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t
+        """
+        # We don't register take for scalar tensor.
+        ndim = len(self.shape)
+        if ndim == 0:
+            return
+
+        take_name = self.get_name("tensor_take")
+        take_var = self._create_global_var(take_name)
+        setattr(self.prelude, take_name, take_var)
+        origin_tensor_constructor = self.get_var('tensor_constructor')
+
+        output_shape = [Any(),] + list(self.shape[1:])
+        tensor_type_var, tensor_constructor = \
+            self._get_adt_by_shape(output_shape)
+
+        t = Var('tensor', self.get_var('tensor_t')())
+        lower = Var('lower', scalar_type('int32'))
+        upper = Var('upper', scalar_type('int32'))
+        tvar = Var('t')
+        case = Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(tvar)]),
+                      tensor_constructor(op.take(tvar,
+                                                 op.arange(lower, upper, dtype='int32'),
+                                                 axis=0)))
+        self.prelude.mod[take_var] = \
+            Function([t, lower, upper],
+                     Match(t, [case], False), tensor_type_var(), [])
+
+    def define_tensor_concatenate(self):
+        """Defines a function to concatenate two tensor_t on axis 0.
+        tensor_concatenate(t) : tensor_t -> tensor_t -> tensor_t
+        """
+         # We don't register concatenate for scalar tensor.
+        ndim = len(self.shape)
+        if ndim == 0:
+            return
+
+        concat_name = self.get_name("tensor_concatenate")
+        concat_var = self._create_global_var(concat_name)
+        setattr(self.prelude, concat_name, concat_var)
+        output_shape = [Any(),] + list(self.shape[1:])
+        tensor_type_var, tensor_constructor = \
+            self._get_adt_by_shape(output_shape)
+
+        origin_tensor_constructor = self.get_var('tensor_constructor')
+        origin_tensor_type_var = self.get_var('tensor_t')
+        x = Var("x", origin_tensor_type_var())
+        y = Var("y", origin_tensor_type_var())
+        t1 = Var("t1")
+        t2 = Var("t2")
+
+        case = Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(t1)]),
+                      Match(y,
+                            [Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(t2)]),
+                                    tensor_constructor(op.concatenate([t1, t2], axis=0)))],
+                            False))
+
+        self.prelude.mod[concat_var] = \
+            Function([x, y], Match(x, [case], False), tensor_type_var(), [])
+
+
+    def define_tensor_expand_dims(self):
+        """Defines a function to grow a tensor_t's rank by adding one dimension in front
+        of the original tensor_t.
+        tensor_expand_dims(t) : tensor_t -> tensor_t
+        """
+        expand_dims_name = self.get_name("tensor_expand_dims")
+        expand_dims_var = self._create_global_var(expand_dims_name)
+        setattr(self.prelude, expand_dims_name, expand_dims_var)
+        origin_tensor_type_var = self.get_var('tensor_t')
+        origin_tensor_constructor = self.get_var('tensor_constructor')
+        x = Var("x", origin_tensor_type_var())
+
+        # Note: we set the added axis to be Any() instead of 1 due to
+        # in stack op, we need to recursively concatenate.
+        tensor_type_var, tensor_constructor = \
+            self._get_adt_by_shape([Any(),] + list(self.shape))
+        t = Var("t")
+        case = Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(t)]),
+                      tensor_constructor(op.expand_dims(t, 0, 1)))
+
+        self.prelude.mod[expand_dims_var] = \
+            Function([x], Match(x, [case], False), tensor_type_var(), [])
+
+    def define_tensor_array_read(self):
+        """Defines a function to get the nth element of a list. Assume the list has at least one
+        element.
+        tensor_array_read(ta, n) : list[static_tensor_t] -> Tensor[(), int32] ->
+        Tensor[self.shape, self.dtype]
+        """
+        read_name = self.get_name("tensor_array_read")
+        read_var = self._create_global_var(read_name)
+        setattr(self.prelude, read_name, read_var)
+        tensor_type_var = self.get_var('tensor_t')
+
+        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
+        n = Var("x", scalar_type('int32'))
+        self.prelude.mod[read_var] = \
+            Function([tensor_array, n], self.prelude.nth(tensor_array, n), tensor_type_var(), [])
+
+    def define_tensor_array_write(self):
+        """Defines a function to update a tensor array at index n with value v.
+        tensor_array_write(ta, n, v) :
+            list[static_tensor_t] -> Tensor[(), int32] -> Tensor[self.shape, self.dtype] ->
+            list[static_tensor_t]
+        """
+        write_name = self.get_name("tensor_array_write")
+        write_var = self._create_global_var(write_name)
+        setattr(self.prelude, write_name, write_var)
+        tensor_type_var = self.get_var('tensor_t')
+        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
+        n = Var("x", scalar_type('int32'))
+        v = Var("v", tensor_type_var())
+        self.prelude.mod[write_var] = \
+            Function([tensor_array, n, v], self.prelude.update(tensor_array, n, v),
+                     self.prelude.l(tensor_type_var()), [])
+
+    def define_tensor_array_unstack(self):
+        """Defines a function to unstack the values of a tensor_t in a tensor array.
+        tensor_array_unstack_tensor(t) : tensor_t -> list[tensor_t]
+        """
+        ndim = len(self.shape)
+        # We don't register unstack for scalar tensor array
+        if ndim == 0:
+            return
+
+        helper_name = self.get_name("tensor_array_unstack_helper")
+        helper_var = self._create_global_var(helper_name)
+        setattr(self.prelude, helper_name, helper_var)
+        tensor = Var("t", TensorType(self.shape, self.dtype))
+        up = Var("up", scalar_type('int32'))
+        i = Var("i", scalar_type('int32'))
+        tensor_var = Var("tensor", TensorType(self.shape, self.dtype))
+
+        reduced_tensor_type_var, tensor_constructor = \
+            self._get_adt_by_shape(self.shape[1:])
+        helper_body = \
+            If(equal(i, up),
+               self.prelude.nil(),
+               self.prelude.cons(tensor_constructor(op.take(tensor, i, axis=0)),
+                                 helper_var(add(i, const(1)), up, tensor)))
+        self.prelude.mod[helper_var] = \
+            Function([i, up, tensor], helper_body, self.prelude.l(reduced_tensor_type_var()), [])
+
+        unstack_name = self.get_name("tensor_array_unstack")
+        unstack_var = self._create_global_var(unstack_name)
+        setattr(self.prelude, unstack_name, unstack_var)
+        shape = op.shape_of(tensor_var)
+        unstack_length = op.take(shape, const(0))
+        self.prelude.mod[unstack_var] = \
+            Function([tensor_var], helper_var(const(0), unstack_length, tensor_var),
+                     self.prelude.l(reduced_tensor_type_var()), [])
+
+    def define_tensor_array_scatter(self, indices_shape=None, force_update=False):
+        """Defines a function to scatter the values of a tensor_t in indices of a tensor array.
+        tensor_array_scatter(ta, indices, value) :
+            list[tensor_t] -> Tensor[(Any), int32] -> tensor_t -> list[tensor_t]
+
+        Set static indices shape by specifying indices_shape.
+        Set force_update to get static indices shape operator.
+        """
+        # When this operator has already been registered, only update
+        # when force_update is set. This should be used only when we need to
+        # redefine this op for static indices shape.
+        tensor_array_scatter_name = self.get_name("tensor_array_scatter")
+        if hasattr(self.prelude, tensor_array_scatter_name) and not force_update:
+            return
+
+        tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper")
+        tensor_array_scatter_helper_var = \
+            self._create_global_var(tensor_array_scatter_helper_name)
+        tensor_type_var = self.get_var('tensor_t')
+        ta = Var("ta", self.prelude.l(tensor_type_var()))
+        current = Var("current", scalar_type('int32'))
+        limit = Var("limit", scalar_type('int32'))
+        indices_ = Var('indices_', TensorType(indices_shape or [Any()], 'int32'))
+        values_ = Var('values_', self.prelude.l(tensor_type_var()))
+        write_var = self.get_var('tensor_array_write')
+        read_var = self.get_var('tensor_array_read')
+        helper_body = If(equal(current, limit),
+                         ta,
+                         tensor_array_scatter_helper_var(
+                             write_var(ta, op.take(indices_, current),
+                                       read_var(values_, current)),
+                             add(current, const(1)),
+                             limit, indices_, values_))
+        self.prelude.mod[tensor_array_scatter_helper_var] = \
+            Function([ta, current, limit, indices_, values_],
+                     helper_body, self.prelude.l(tensor_type_var()), [])
+
+        tensor_array_scatter_var = self._create_global_var(tensor_array_scatter_name)
+        setattr(self.prelude, tensor_array_scatter_name, tensor_array_scatter_var)
+        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
+
+        indices = Var('indices', TensorType(indices_shape or [Any()], 'int32'))
+        values = Var('values', self.prelude.l(tensor_type_var()))
+        if indices_shape is None:
+            indices_shape = op.shape_of(indices)
+            limit = op.take(indices_shape, const(0))
+        else:
+            limit = const(indices_shape[0])
+
+        body = tensor_array_scatter_helper_var(tensor_array, const(0), limit, indices, values)
+        self.prelude.mod[tensor_array_scatter_var] = \
+            Function([tensor_array, indices, values], body,
+                     self.prelude.l(tensor_type_var()), [])
+
+    def define_tensor_array_split(self,
+                                  value_shape=None,
+                                  lengths_shape=None,
+                                  force_update=False):
+        """Defines a function to split the values of a tensor_t into a tensor array.
+        tensor_array_split(ta, value, lengths) :
+            list[tensor_t] -> tensor_t -> Tensor[(Any), int32] -> list[tensor_t]
+
+        Set static value and lengths shapes by specifying value_shape and lengths_shape.
+        Set force_update to get static value and lengths shape operator.
+        """
+        # Skip scalar case
+        ndim = len(self.shape)
+        if ndim == 0:
+            return
+
+        # When this operator has already been registered, only update
+        # when force_update is set. This should be used only when we need to
+        # redefine this op for static value/indices shape.
+        split_name = self.get_name("tensor_array_split")
+        if hasattr(self.prelude, split_name) and not force_update:
+            return
+
+        tensor_type_var = self.get_var('tensor_t')
+        tensor_array_split_helper_name = self.get_name("ta_split_helper")
+        tensor_array_split_helper_var = \
+            self._create_global_var(tensor_array_split_helper_name)
+        setattr(self.prelude, tensor_array_split_helper_name, tensor_array_split_helper_var)
+        output_shape = [Any(),] + list(self.shape[1:])
+        output_tensor_type_var, _ = self._get_adt_by_shape(output_shape)
+
+        if value_shape is None:
+            value_type_var = tensor_type_var
+            take_var = self.get_var('tensor_take')
+        else:
+            value_type_var, _ = self._get_adt_by_shape(value_shape)
+            # Also get static shape take operator
+            origin_shape = list(self.shape)
+            self.shape = value_shape
+            self.define_tensor_take()
+            take_var = self.get_var('tensor_take')
+            self.shape = origin_shape
+
+
+        ta1 = Var("tensor_array", self.prelude.l(output_tensor_type_var()))
+        value1 = Var('value1', value_type_var())
+        offset1 = Var('offset1', scalar_type('int32'))
+        current1 = Var('current1', scalar_type('int32'))
+        limit1 = Var('limit1', scalar_type('int32'))
+        lengths1 = Var('lengths', TensorType(lengths_shape or [Any()], 'int32'))
+
+        # Register write for output shape
+        origin_shape = list(self.shape)
+        self.shape = output_shape
+        self.define_tensor_array_write()
+        write_var = self.get_var('tensor_array_write')
+        self.shape = origin_shape
+        helper1_body = If(equal(current1, limit1),
+                          ta1,
+                          write_var(
+                              tensor_array_split_helper_var(
+                                  ta1,
+                                  value1,
+                                  add(offset1, op.take(lengths1, current1)),
+                                  add(current1, const(1)),
+                                  limit1,
+                                  lengths1
+                              ),
+                              current1,
+                              take_var(value1,
+                                       offset1,
+                                       add(op.take(lengths1, current1), offset1))))
+        self.prelude.mod[tensor_array_split_helper_var] = \
+            Function([ta1, value1, offset1, current1, limit1, lengths1],
+                     helper1_body, self.prelude.l(output_tensor_type_var()), [])
+        split_var = self._create_global_var(split_name)
+        setattr(self.prelude, split_name, split_var)
+        tensor_array = Var("tensor_array", self.prelude.l(output_tensor_type_var()))
+
+        value = Var('value', value_type_var())
+        lengths = Var('lengths', TensorType(lengths_shape or [Any()], 'int32'))
+        if lengths_shape is None:
+            lengths_shape = op.shape_of(lengths)
+            lengths_limit = op.take(lengths_shape, const(0))
+        else:
+            lengths_limit = const(lengths_shape[0])
+        body = tensor_array_split_helper_var(
+            tensor_array,
+            value,
+            const(0),
+            const(0),
+            lengths_limit,
+            lengths)
+
+        self.prelude.mod[split_var] = \
+            Function([tensor_array, value, lengths], body,
+                     self.prelude.l(output_tensor_type_var()), [])
+
+    def define_tensor_array_concat(self):
+        """Defines a function to return the values in the tensor array as concatenated tensor_t.
+        tensor_array_concat(ta) : list[tensor_t] -> tensor_t
+        """
+        # We don't register concat for scalar tensor array.
+        ndim = len(self.shape)
+        if ndim == 0:
+            return
+
+        concat_name = self.get_name("tensor_array_concat")
+        concat_var = self._create_global_var(concat_name)
+        setattr(self.prelude, concat_name, concat_var)
+
+        output_shape = [Any(),] + list(self.shape[1:])
+        tensor_type_var, _ = self._get_adt_by_shape(output_shape)
+
+        # Register tensor concatenate and get tensor_nil var for output shape
+        origin_shape = self.shape
+        self.shape = output_shape
+        self.define_tensor_concatenate()
+        tensor_concat_var = self.get_var('tensor_concatenate')
+        tensor_nil_var = self.get_var('tensor_nil')
+        self.shape = origin_shape
+
+        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
+        hd = Var("hd")
+        tl = Var("tl")
+        nil_case = Clause(PatternConstructor(self.prelude.nil), tensor_nil_var())
+        cons_case = Clause(PatternConstructor(self.prelude.cons, [PatternVar(hd), PatternVar(tl)]),
+                           Match(tl, [
+                               Clause(PatternConstructor(self.prelude.nil), hd),
+                               Clause(PatternWildcard(),
+                                      tensor_concat_var(hd, concat_var(tl)))
+                           ], False))
+        self.prelude.mod[concat_var] = \
+            Function([tensor_array],
+                     Match(tensor_array, [nil_case, cons_case], False), tensor_type_var(), [])
+
+    def define_tensor_array_stack(self):
+        """Defines a function to get the values in the tensor array as a stack tensor_t.
+        tensor_array_stack(l) : list[tensor_t] -> tensor_t
+        """
+        stack_name = self.get_name("tensor_array_stack")
+        stack_var = self._create_global_var(stack_name)
+        setattr(self.prelude, stack_name, stack_var)
+        tensor_type_var = self.get_var('tensor_t')
+        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
+        expand_dims_var = self.get_var('tensor_expand_dims')
+
+        # Register tensor_concatenate for output_shape
+        origin_shape = self.shape
+        output_shape = [Any(),] + list(self.shape)
+        self.shape = output_shape
+        self.define_tensor_concatenate()
+        concat_var = self.get_var('tensor_concatenate')
+        self.shape = origin_shape
+
+        tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array)
+        tensors = self.prelude.foldl(concat_var,
+                                     self.prelude.hd(tensor_array_expand_dims),
+                                     self.prelude.tl(tensor_array_expand_dims))
+        output_tensor_type_var, _ = self._get_adt_by_shape(output_shape)
+        self.prelude.mod[stack_var] = Function([tensor_array], tensors,
+                                               output_tensor_type_var(), [])
+
+    def define_tensor_array_gather(self):
+        """Defines a function to return the selected values in a tensor array as tensor_t.
+        tensor_array_gather(ta, indices) : list[tensor_t] -> Tensor[(Any), int32] -> tensor_t
+        """
+        helper_name = self.get_name("tensor_array_gather_helper")
+        helper_var = self._create_global_var(helper_name)
+        setattr(self.prelude, helper_name, helper_var)
+        tensor_type_var = self.get_var('tensor_t')
+        output_shape = [Any(),] + list(self.shape)
+        output_tensor_type_var, _ = self._get_adt_by_shape(output_shape)
+        stack_var = self.get_var('tensor_array_stack')
+        read_var = self.get_var('tensor_array_read')
+        ta = Var("ta", self.prelude.l(tensor_type_var()))
+        accu = Var("accu", self.prelude.l(tensor_type_var()))
+        current = Var("current", scalar_type('int32'))
+        limit = Var("limit", scalar_type('int32'))
+        indices_ = Var('indices_', TensorType([Any()], 'int32'))
+        helper_body = \
+            If(equal(current, const(0)),
+               stack_var(accu),
+               helper_var(
+                   ta,
+                   self.prelude.cons(
+                       read_var(
+                           ta, op.take(indices_, subtract(current, const(1)))), accu),
+                   subtract(current, const(1)),
+                   limit, indices_))
+        self.prelude.mod[helper_var] = \
+            Function([ta, accu, current, limit, indices_],
+                     helper_body, output_tensor_type_var(), [])
+        gather_name = self.get_name("tensor_array_gather")
+        gather_var = self._create_global_var(gather_name)
+        setattr(self.prelude, gather_name, gather_var)
+        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
+        indices = Var('indices', TensorType([Any()], 'int32'))
+        indices_shape = op.shape_of(indices)
+        limit = op.take(indices_shape, const(0))
+        body = helper_var(tensor_array, self.prelude.nil(), limit, limit, indices)
+        self.prelude.mod[gather_var] = \
+            Function([tensor_array, indices], body, output_tensor_type_var(), [])
+
+    def define_tensor_get_data(self, data_shape):
+        """Defines a function to get a Tensor from tensor_t with given shape.
+        """
+        tensor_get_data_name = self.get_name("tensor_get_data")
+        tensor_get_data_var = self._create_global_var(tensor_get_data_name)
+        setattr(self.prelude, tensor_get_data_name, tensor_get_data_var)
+
+        tensor_type_var, tensor_constructor = self._get_adt_by_shape(data_shape)
+        t = Var('tensor', tensor_type_var())
+        tvar = Var('t')
+        case =\
+            Clause(PatternConstructor(tensor_constructor, [PatternVar(tvar)]), tvar)
+        self.prelude.mod[tensor_get_data_var] = \
+            Function([t], Match(t, [case], False),
+                     TensorType(data_shape, self.dtype), [])
+
+    def register(self):
+        """Register all tensor array ops in Prelude"""
+        self.define_tensor_adt()
+        self.define_tensor_take()
+        self.define_tensor_concatenate()
+        self.define_tensor_expand_dims()
+        self.define_tensor_array()
+        self.define_tensor_array_read()
+        self.define_tensor_array_write()
+        self.define_tensor_array_unstack()
+        self.define_tensor_array_scatter()
+        self.define_tensor_array_split()
+        self.define_tensor_array_concat()
+        self.define_tensor_array_stack()
+        self.define_tensor_array_gather()
+
+    def _get_adt_by_shape(self, shape):
+        """Get ADT type and constructor with given shape."""
+        origin_shape = self.shape
+        self.shape = shape
+        self.define_tensor_adt()
+        tensor_type_var = self.get_var("tensor_t")
+        tensor_constructor = self.get_var("tensor_constructor")
+        self.shape = origin_shape
+        return tensor_type_var, tensor_constructor
+
+    def _create_global_var(self, name):
+        """Create a GlobalVar if doesn't exist in prelude."""
+        global_var_name_set = set()
+        for g_var_name in self.prelude.mod.get_global_vars():
+            global_var_name_set.add(g_var_name.name_hint)
+        if name not in global_var_name_set:
+            gvar = GlobalVar(name)
+        else:
+            gvar = self.prelude.mod.get_global_var(name)
+
+        return gvar
+
 class TensorArrayOps(object):
     """Contains tensor array related ops"""
 
@@ -666,6 +1205,15 @@ class Prelude:
         name = self.get_name(canonical, dtype)
         return getattr(self, name)
 
+    def get_name_static(self, canonical, dtype, shape):
+        """Get name corresponding to the canonical name"""
+        return _get_name_static(canonical, dtype, shape)
+
+    def get_var_static(self, canonical, dtype, shape):
+        """Get var corresponding to the canonical name"""
+        name = self.get_name_static(canonical, dtype, shape)
+        return getattr(self, name)
+
     def load_prelude(self):
         """Parses the Prelude from Relay's text format into a module."""
         # TODO(@jroesch): we should remove this helper when we port over prelude
index deeb733..c9b13d2 100644 (file)
@@ -19,7 +19,7 @@ from tvm import te
 from tvm import relay
 from tvm.relay.backend.interpreter import ConstructorValue
 from tvm.relay import create_executor
-from tvm.relay.prelude import Prelude
+from tvm.relay.prelude import Prelude, StaticTensorArrayOps
 from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr
 
 import numpy as np
@@ -980,6 +980,395 @@ def test_tensor_array_split():
     run('float32')
     run('int32')
 
+def test_static_tensor_take():
+    def run(dtype, shape):
+        mod = tvm.IRModule()
+        p = Prelude(mod)
+        static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+        static_tensor_array_ops.register()
+
+        take = p.get_var_static('tensor_take', dtype, shape)
+        tensor_constructor = p.get_var_static('tensor_constructor', dtype, shape)
+        v = relay.var('v')
+        lower = relay.var('lower')
+        upper = relay.var('upper')
+        mod["main"] = relay.Function([v, lower, upper], take(tensor_constructor(v), lower, upper))
+        v_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
+        expected = [np.take(v_data, range(2, 5), axis=0)]
+        check_tensor_array(mod, expected, *(v_data, 2, 5), dtype=dtype)
+        expected = [np.take(v_data, range(0, 9), axis=0)]
+        check_tensor_array(mod, expected, *(v_data, 0, 9), dtype=dtype)
+    run('float32', [10, 10])
+    run('int32', [15, 11])
+
+
+def test_static_tensor_concatenate():
+    def run(dtype, shape):
+        mod = tvm.IRModule()
+        p = Prelude(mod)
+        static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+        static_tensor_array_ops.register()
+
+        concat = p.get_var_static('tensor_concatenate', dtype, shape)
+        tensor = p.get_var_static('tensor_constructor', dtype, shape)
+        v1 = relay.var('v1')
+        v2 = relay.var('v2')
+        mod["main"] = relay.Function([v1, v2], concat(tensor(v1),
+                                                      tensor(v2)))
+        v1_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
+        v2_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
+        expected = [np.concatenate((v1_data, v2_data))]
+        check_tensor_array(mod, expected, *(v1_data, v2_data), dtype=dtype)
+    run('float32', [5,])
+    run('int32', [2, 3])
+
+
+def test_static_tensor_expand_dims():
+    def run(dtype, shape):
+        x = relay.var('x')
+        mod = tvm.IRModule()
+        p = Prelude(mod)
+        static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+        static_tensor_array_ops.register()
+
+        expand_dims_func = p.get_var_static('tensor_expand_dims', dtype, shape)
+        tensor = p.get_var_static('tensor_constructor', dtype, shape)
+        mod["main"] = relay.Function([x], expand_dims_func(tensor(x)))
+        x_np = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
+        expected = [np.expand_dims(x_np, axis=0)]
+        check_tensor_array(mod, expected, x_np)
+    run('float32', [])
+    run('int32', [2,])
+
+
+def test_static_tensor_array_constructor():
+    def run(dtype, shape):
+        mod = tvm.IRModule()
+        p = Prelude(mod)
+        static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+        static_tensor_array_ops.register()
+        tensor_constructor = p.get_name_static('tensor_constructor', dtype, shape)
+        assert tensor_constructor != None
+    run('float32', [1, 1])
+
+
+def test_static_tensor_array_read():
+    def run(dtype, shape):
+        mod = tvm.IRModule()
+        p = Prelude(mod)
+        static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+        static_tensor_array_ops.register()
+
+        np_data_list = []
+        ta_length = 3
+        for _ in range(ta_length):
+            np_data_list.append(np.random.uniform(0, 10, size=shape).astype(dtype))
+
+        v0 = relay.var('v0')
+        v1 = relay.var('v1')
+        v2 = relay.var('v2')
+        n = relay.var('n')
+        tensor = p.get_var_static('tensor_constructor', dtype, shape)
+        tensor_array = p.get_var_static('tensor_array', dtype, shape)
+        init_tensor_array = tensor_array(relay.const(ta_length))
+        read_func = p.get_var_static('tensor_array_read', dtype, shape)
+        write_func = p.get_var_static('tensor_array_write', dtype, shape)
+        tensor_array0 = write_func(init_tensor_array, relay.const(0),
+                                   tensor(v0))
+        tensor_array1 = write_func(tensor_array0, relay.const(1),
+                                   tensor(v1))
+        tensor_array2 = write_func(tensor_array1, relay.const(2),
+                                   tensor(v2))
+
+        mod["main"] = relay.Function([v0, v1, v2, n], read_func(tensor_array2, n))
+        expected = [np_data_list[0]]
+        check_tensor_array(mod, expected, *list(np_data_list + [0]), dtype=dtype)
+        expected = [np_data_list[1]]
+        check_tensor_array(mod, expected, *list(np_data_list + [1]), dtype=dtype)
+        expected = [np_data_list[2]]
+        check_tensor_array(mod, expected, *list(np_data_list + [2]), dtype=dtype)
+    run('float32', [])
+    run('int32', [2, 3])
+
+
+def test_static_tensor_array_write():
+    def run(dtype, shape):
+        mod = tvm.IRModule()
+        p = Prelude(mod)
+        static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+        static_tensor_array_ops.register()
+
+        ta_length = 2
+        np_data_list = [np.random.uniform(0, 10, size=shape).astype(dtype) for _ in range(ta_length)]
+
+        v0 = relay.var('v0')
+        v1 = relay.var('v1')
+        tensor_array = p.get_var_static('tensor_array', dtype, shape)
+        init_tensor_array = tensor_array(relay.const(ta_length))
+        write_func = p.get_var_static('tensor_array_write', dtype, shape)
+        tensor = p.get_var_static('tensor_constructor', dtype, shape)
+        tensor_array0 = write_func(init_tensor_array, relay.const(0),
+                                   tensor(v0))
+        tensor_array1 = write_func(tensor_array0, relay.const(1), tensor(v1))
+        mod["main"] = relay.Function([v0, v1], tensor_array1)
+        expected = np_data_list
+        check_tensor_array(mod, expected, *np_data_list, dtype=dtype)
+    run('float32', [])
+    run('int32', [2, 3])
+
+
+def test_static_tensor_array_unstack():
+    def run(dtype, shape):
+        mod = tvm.IRModule()
+        p = Prelude(mod)
+        static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+        static_tensor_array_ops.register()
+
+        unstack_tensor = p.get_var_static('tensor_array_unstack', dtype, shape)
+        v = relay.var('v')
+        mod["main"] = relay.Function([v], unstack_tensor(v))
+        t = np.random.uniform(low=0, high=10, size=shape).astype(dtype)
+        *expected, = t
+        check_tensor_array(mod, expected, t, dtype=dtype)
+    run('float32', [4])
+    run('int32', [2, 3])
+
+
+def test_static_tensor_array_scatter():
+    def run(dtype, shape, indices_shape=None):
+        mod = tvm.IRModule()
+        p = Prelude(mod)
+        static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+        static_tensor_array_ops.register()
+        if indices_shape is not None:
+            static_tensor_array_ops.define_tensor_array_scatter(indices_shape, True)
+
+        # tensor array
+        v1 = relay.var('v1')
+        v2 = relay.var('v2')
+        v3 = relay.var('v2')
+        tensor_array = p.get_var_static('tensor_array', dtype, shape)
+        tensor_array0 = tensor_array(relay.const(3))
+        write_func = p.get_var_static('tensor_array_write', dtype, shape)
+        scatter_func = p.get_var_static('tensor_array_scatter', dtype, shape)
+        tensor = p.get_var_static('tensor_constructor', dtype, shape)
+        tensor_array1 = write_func(tensor_array0, relay.const(0), tensor(v1))
+        tensor_array1 = write_func(tensor_array1, relay.const(1), tensor(v2))
+        tensor_array1 = write_func(tensor_array1, relay.const(2), tensor(v3))
+
+        # indices array
+        index = relay.var('index')
+
+        # values array
+        value_0 = relay.var('value_0')
+        value_1 = relay.var('value_1')
+        values_array = tensor_array(relay.const(2))
+        values_array = write_func(values_array, relay.const(0),
+                                  tensor(value_0))
+        values_array = write_func(values_array, relay.const(1),
+                                  tensor(value_1))
+
+        # create the scatter function
+        tensor_array_scatter = scatter_func(tensor_array1, index, values_array)
+        mod["main"] = relay.Function([v1, v2, v3, index, value_0, value_1],
+                                     tensor_array_scatter)
+
+        # initialize and check
+        v1_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
+        v2_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
+        v3_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
+        index_data = np.array([0, 1], dtype="int32")
+        val1_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
+        val2_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
+        expected = [val1_data, val2_data, v3_data]
+        check_tensor_array(mod, expected, *(v1_data, v2_data, v3_data,
+                                            index_data, val1_data,
+                                            val2_data), dtype=dtype)
+    run('float32', [2, 3])
+    run('int32', [2, 3])
+    run('float32', [2, 3], [2,])
+
+
+def test_static_tensor_array_split():
+    def run(dtype, shape, value_shape=None, lengths_shape=None):
+        mod = tvm.IRModule()
+        p = Prelude(mod)
+        static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+        static_tensor_array_ops.register()
+        if value_shape is not None or lengths_shape is not None:
+            static_tensor_array_ops.define_tensor_array_split(value_shape, lengths_shape, True)
+
+        # tensor array
+        v1 = relay.var('v1')
+        v2 = relay.var('v2')
+        v3 = relay.var('v2')
+
+        adt_shape = [relay.Any(),] + shape[1:]
+        origin_shape = static_tensor_array_ops.shape
+        static_tensor_array_ops.shape = adt_shape
+        static_tensor_array_ops.define_tensor_array()
+        tensor_array = p.get_var_static('tensor_array', dtype, adt_shape)
+        static_tensor_array_ops.shape = origin_shape
+        tensor_array1 = tensor_array(relay.const(3))
+        write_func = p.get_var_static('tensor_array_write', dtype, adt_shape)
+        split_func = p.get_var_static('tensor_array_split', dtype, shape)
+        tensor = p.get_var_static('tensor_constructor', dtype, adt_shape)
+        tensor_array1 = write_func(tensor_array1, relay.const(0), tensor(v1))
+        tensor_array1 = write_func(tensor_array1, relay.const(1), tensor(v2))
+        tensor_array1 = write_func(tensor_array1, relay.const(2), tensor(v3))
+
+        # value tensor
+        value = relay.var('value')
+
+        # lengths tensor
+        ta_len = relay.var('length')
+
+        # create the split function
+        if value_shape is None:
+            tensor1 = p.get_var_static('tensor_constructor', dtype, shape)
+        else:
+            static_tensor_array_ops = StaticTensorArrayOps(p, dtype, value_shape)
+            static_tensor_array_ops.register()
+            tensor1 = p.get_var_static('tensor_constructor', dtype, value_shape)
+        tensor_array_split = split_func(tensor_array1, tensor1(value), ta_len)
+        mod["main"] = relay.Function([v1, v2, v3, value, ta_len],
+                                     tensor_array_split)
+
+        # initialize and check
+        v1_data = np.random.uniform(low=0.0, high=8.0, size=[2, 3]).astype(dtype)
+        v2_data = np.random.uniform(low=0.0, high=8.0, size=[2, 3]).astype(dtype)
+        v3_data = np.random.uniform(low=0.0, high=8.0, size=[2, 3]).astype(dtype)
+        value_data = np.random.uniform(low=0.0, high=8.0,
+                                       size=value_shape or shape).astype(dtype)
+        length_data = np.array([2, 2], dtype="int32")
+        expected = np.concatenate([value_data, v3_data])
+        expected = np.split(expected, indices_or_sections=[2, 4])
+        check_tensor_array(mod, expected, *(v1_data, v2_data, v3_data,
+                                            value_data, length_data),
+                           dtype=dtype)
+
+    run('float32', [4, 3])
+    run('int32', [4, 3])
+    run('int32', [relay.Any(), 3], [4, 3], [2,])
+
+
+def test_static_tensor_array_concat():
+    def run(dtype, shape):
+        mod = tvm.IRModule()
+        p = Prelude(mod)
+        static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+        static_tensor_array_ops.register()
+
+        v1 = relay.var('v1')
+        v2 = relay.var('v2')
+        tensor_array = p.get_var_static('tensor_array', dtype, shape)
+        tensor_array1 = tensor_array(relay.const(2))
+        write_func = p.get_var_static('tensor_array_write', dtype, shape)
+        concat_func = p.get_var_static('tensor_array_concat', dtype, shape)
+        tensor = p.get_var_static('tensor_constructor', dtype, shape)
+        tensor_array1 = write_func(tensor_array1, relay.const(0), tensor(v1))
+        tensor_array1 = write_func(tensor_array1, relay.const(1), tensor(v2))
+        tensor_array_concat = concat_func(tensor_array1)
+        mod["main"] = relay.Function([v1, v2], tensor_array_concat)
+        v1_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype)
+        v2_data = np.random.uniform(low=0.0, high=8.0, size=(1, 3)).astype(dtype)
+        expected = [np.concatenate((v1_data, v2_data), axis=0)]
+        check_tensor_array(mod, expected, *(v1_data, v2_data), dtype=dtype)
+    run('float32', [relay.Any(), 3])
+    run('int32', [relay.Any(), 3])
+
+
+def test_static_tensor_array_gather():
+    def run(dtype, shape):
+        mod = tvm.IRModule()
+        p = Prelude(mod)
+        static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+        static_tensor_array_ops.register()
+
+        tensor_array = p.get_var_static('tensor_array', dtype, shape)
+        tensor = p.get_var_static('tensor_constructor', dtype, shape)
+        write = p.get_var_static('tensor_array_write', dtype, shape)
+        gather = p.get_var_static('tensor_array_gather', dtype, shape)
+        v = relay.var('v')
+        indice = relay.var('indice')
+        init_tensor_array = tensor_array(relay.const(3))
+        tensor_array1 = write(init_tensor_array, relay.const(0), tensor(v))
+        tensor_array2 = write(tensor_array1, relay.const(1), tensor(v))
+        tensor_array3 = write(tensor_array2, relay.const(2), tensor(v))
+        out = gather(tensor_array3, indice)
+        mod["main"] = relay.Function([v, indice], out)
+        t = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
+        indice_data = np.array([0, 2], dtype="int32")
+        expected = [np.stack([t, t])]
+        check_tensor_array(mod, expected, *(t, indice_data), dtype=dtype)
+    run('float32', [])
+    run('int32', [2, 3])
+
+
+def test_static_tensor_array_stack():
+    def run(dtype, shape):
+        mod = tvm.IRModule()
+        p = Prelude(mod)
+        static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+        static_tensor_array_ops.register()
+
+        tensor_array = p.get_var_static('tensor_array', dtype, shape)
+        tensor = p.get_var_static('tensor_constructor', dtype, shape)
+        write = p.get_var_static('tensor_array_write', dtype, shape)
+        stack = p.get_var_static('tensor_array_stack', dtype, shape)
+        v = relay.var('v')
+        init_tensor_array = tensor_array(relay.const(3))
+        tensor_array1 = write(init_tensor_array, relay.const(0), tensor(v))
+        tensor_array2 = write(tensor_array1, relay.const(1), tensor(v))
+        tensor_array3 = write(tensor_array2, relay.const(2), tensor(v))
+        tensor_array4 = stack(tensor_array3)
+        mod["main"] = relay.Function([v], tensor_array4)
+        t = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
+        expected = [np.stack([t, t, t])]
+        check_tensor_array(mod, expected, t, dtype=dtype)
+    run('float32', [])
+    run('int32', [2, 3])
+
+
+def test_static_tensor_get_data():
+    def run(dtype, shape):
+        mod = tvm.IRModule()
+        p = Prelude(mod)
+        static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+        static_tensor_array_ops.register()
+        static_tensor_array_ops.define_tensor_get_data(shape)
+
+        np_data_list = []
+        ta_length = 3
+        for _ in range(ta_length):
+            np_data_list.append(np.random.uniform(0, 10, size=shape).astype(dtype))
+
+        v0 = relay.var('v0')
+        v1 = relay.var('v1')
+        v2 = relay.var('v2')
+        n = relay.var('n')
+        tensor = p.get_var_static('tensor_constructor', dtype, shape)
+        tensor_array = p.get_var_static('tensor_array', dtype, shape)
+        init_tensor_array = tensor_array(relay.const(ta_length))
+        read_func = p.get_var_static('tensor_array_read', dtype, shape)
+        write_func = p.get_var_static('tensor_array_write', dtype, shape)
+        get_data_func = p.get_var_static('tensor_get_data', dtype, shape)
+        tensor_array0 = write_func(init_tensor_array, relay.const(0),
+                                   tensor(v0))
+        tensor_array1 = write_func(tensor_array0, relay.const(1),
+                                   tensor(v1))
+        tensor_array2 = write_func(tensor_array1, relay.const(2),
+                                   tensor(v2))
+
+        mod["main"] = relay.Function([v0, v1, v2, n], get_data_func(read_func(tensor_array2, n)))
+        expected = [np_data_list[0]]
+        check_tensor_array(mod, expected, *list(np_data_list + [0]), dtype=dtype)
+        expected = [np_data_list[1]]
+        check_tensor_array(mod, expected, *list(np_data_list + [1]), dtype=dtype)
+        expected = [np_data_list[2]]
+        check_tensor_array(mod, expected, *list(np_data_list + [2]), dtype=dtype)
+    run('float32', [])
+    run('int32', [2, 3])
 
 if __name__ == "__main__":
     test_nat_constructor()
@@ -1016,3 +1405,17 @@ if __name__ == "__main__":
     test_tensor_array_concat()
     test_tensor_array_scatter()
     test_tensor_array_split()
+
+    test_static_tensor_take()
+    test_static_tensor_concatenate()
+    test_static_tensor_expand_dims()
+    test_static_tensor_array_constructor()
+    test_static_tensor_array_read()
+    test_static_tensor_array_write()
+    test_static_tensor_array_unstack()
+    test_static_tensor_array_scatter()
+    test_static_tensor_array_split()
+    test_static_tensor_array_concat()
+    test_static_tensor_array_stack()
+    test_static_tensor_array_gather()
+    test_static_tensor_get_data()