[Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array (#5243)
authorYao Wang <kevinthesunwy@gmail.com>
Sat, 11 Apr 2020 01:43:23 +0000 (18:43 -0700)
committerGitHub <noreply@github.com>
Sat, 11 Apr 2020 01:43:23 +0000 (10:43 +0900)
* Support TF Frontend Static TensorArray

* Fix pylint

* Fix lint

* Move get_tensor_array_shape into prelude

* Fix lint

* Fix common

python/tvm/relay/frontend/common.py
python/tvm/relay/frontend/tensorflow.py
python/tvm/relay/prelude.py
tests/python/frontend/tensorflow/test_forward.py
topi/python/topi/util.py

index 5465e50..e86890f 100644 (file)
@@ -456,22 +456,20 @@ def get_name(node):
 
 def infer_type(node, mod=None):
     """A method to infer the type of an intermediate node in the relay graph."""
-    new_mod = IRModule.from_expr(node)
-    if mod is not None:
-        new_mod.update(mod)
-    new_mod = _transform.InferType()(new_mod)
-    entry = new_mod["main"]
-    return entry if isinstance(node, _function.Function) else entry.body
+    if isinstance(mod, IRModule):
+        mod["main"] = _function.Function([], node)
+        mod = _transform.InferType()(mod)
+        entry = mod["main"]
+        ret = entry.body
+    else:
+        new_mod = IRModule.from_expr(node)
+        if mod is not None:
+            new_mod.update(mod)
+            new_mod = _transform.InferType()(new_mod)
+        entry = new_mod["main"]
+        ret = entry if isinstance(node, _function.Function) else entry.body
 
-def infer_shape(inputs, mod=None):
-    """A method to get the output type of an intermediate node in the graph."""
-    out_type = infer_type(inputs, mod=mod)
-    checked_type = out_type.checked_type
-    if hasattr(checked_type, 'shape'):
-        # Regular operator that outputs tensors
-        return get_const_tuple(out_type.checked_type.shape)
-    # The return type is not a tensor, for example List
-    return checked_type
+    return ret
 
 def infer_channels(inputs, transpose=False):
     """A hack for getting 'channels' or 'units' since caffe2 does not provide
@@ -483,6 +481,17 @@ def infer_channels(inputs, transpose=False):
     return channels
 
 
+def infer_shape(inputs, mod=None):
+    """A method to get the output type of an intermediate node in the graph."""
+    out_type = infer_type(inputs, mod=mod)
+    checked_type = out_type.checked_type
+    if hasattr(checked_type, 'shape'):
+        # Regular operator that outputs tensors
+        return get_const_tuple(checked_type.shape)
+    # The return type is not a tensor, for example List
+    return checked_type
+
+
 def infer_value(input_val, params, mod=None):
     """A hack for getting the value of an expression by evaluating a
     portion of the relay graph. This is often needed for functions that
@@ -505,7 +514,7 @@ def infer_value(input_val, params, mod=None):
         return m.get_output(0)
     except Exception:
         if isinstance(mod, IRModule):
-            mod["main"] = _expr.Function(analysis.free_vars(input_val), input_val)
+            mod["main"] = _function.Function(analysis.free_vars(input_val), input_val)
         else:
             mod = IRModule.from_expr(input_val)
         exc = tvm.relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
index 8a72423..120631e 100644 (file)
@@ -26,13 +26,14 @@ import numpy as np
 import tvm
 
 from tvm.ir import IRModule
-from tvm.relay.prelude import Prelude
+from tvm.relay.prelude import Prelude, StaticTensorArrayOps, get_tensor_array_shape
 from tvm.ir import structural_hash as s_hash
 
 from .. import analysis
 from .. import expr as _expr
 from .. import function as _function
 from .. import op as _op
+from ..ty import Any
 from ..expr_functor import ExprMutator, ExprVisitor
 from .common import AttrCvt, get_relay_op
 from .common import infer_type as _infer_type
@@ -259,8 +260,6 @@ def _conv(opname):
         if opname == 'conv_transpose' and attr['data_format'] == 'NHWC':
             # transform to NCHW for TVM backend compatible and set 'flip_layout'
             # to have output flip back to NHWC
-            tmp_shape = _infer_shape(inputs[2], mod)
-            tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
             inputs[2] = _op.transpose(inputs[2], axes=(0, 3, 1, 2))
             attr['strides'][1], attr['strides'][2], attr['strides'][3] = \
                 attr['strides'][3], attr['strides'][1], attr['strides'][2]
@@ -789,25 +788,152 @@ def _pack():
 
 def _tensor_array():
     def _impl(inputs, attr, params, prelude):
+        try:
+            from tensorflow.python.framework import tensor_util
+        except ImportError as e:
+            raise ImportError(
+                "Unable to import tensorflow which is required {}".format(e))
+
         dtype_str = attr.get('dtype').name
-        tensor_array_constructor = prelude.get_var('tensor_array', dtype_str)
-        return tensor_array_constructor(_op.take(inputs[0], tvm.relay.const(0)))
+        assert not attr["dynamic_size"], "Dynamic size tensor array is " \
+                                         "not supported in TVM yet."
+
+        raw_elem_shape = tensor_util.TensorShapeProtoToList(attr['element_shape'])
+        elem_shape = []
+        for dim in raw_elem_shape:
+            if dim < 0:
+                elem_shape.append(Any())
+            else:
+                elem_shape.append(dim)
+
+        if elem_shape:
+            # Element shape is specified.
+            # Directly create static tensor array with given shape.
+            static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                           dtype_str,
+                                                           elem_shape)
+            static_tensor_array_ops.register()
+            tensor_array_constructor = prelude.get_var_static('tensor_array',
+                                                              dtype_str,
+                                                              elem_shape)
+            tensor_array = tensor_array_constructor(inputs[0])
+            _static_tensor_array_map[tensor_array] = tensor_array
+        elif attr['identical_element_shapes']:
+            # identical_element_shapes is set but element shape is not given.
+            # We create a static tensor array with dummy shape and record it in
+            # _static_tensor_array_map. Later when creating other tensor array ops
+            # which uses this tensor array, we reconstruct this tensor array with
+            # actual shape.
+            dummy_shape = ()
+            static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                           dtype_str,
+                                                           dummy_shape)
+            static_tensor_array_ops.register()
+            tensor_array_constructor = prelude.get_var_static('tensor_array',
+                                                              dtype_str,
+                                                              dummy_shape)
+            tensor_array = tensor_array_constructor(inputs[0])
+            _static_tensor_array_map[tensor_array] = None
+        else:
+            tensor_array_constructor = prelude.get_var('tensor_array', dtype_str)
+            tensor_array = tensor_array_constructor(inputs[0])
+        return tensor_array
     return _impl
 
 def _tensor_array_scatter():
     def _impl(inputs, attr, params, prelude):
         dtype_str = attr.get('T').name
-        values_rank = len(inputs[2].type_annotation.shape)
-        unstack_name = "tensor_array_unstack_tensor{}".format(values_rank)
-        unstack_function = prelude.get_var(unstack_name, dtype_str)
-        values = unstack_function(inputs[2])
-        tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str)
-        return tensor_array_scatter_func(inputs[0], inputs[1], values)
+        input_ta = inputs[0]
+        input_shape = get_tensor_array_shape(input_ta, dtype_str, prelude)
+        values_shape = _infer_shape(inputs[2], prelude.mod)
+        input_t_shape = values_shape[1:]
+        indices_shape = _infer_shape(inputs[1], prelude.mod)
+
+        if input_shape is None:
+            values_rank = len(values_shape)
+            unstack_name = "tensor_array_unstack_tensor{}".format(values_rank)
+            unstack_function = prelude.get_var(unstack_name, dtype_str)
+            values = unstack_function(inputs[2])
+            tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str)
+        else:
+            static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                           dtype_str,
+                                                           input_t_shape)
+            static_tensor_array_ops.register()
+            # For scatter operation, it is possible to write to a newly create
+            # tensor array. We need to check and recreate its input tensor array.
+            if input_ta in _static_tensor_array_map and \
+                    _static_tensor_array_map[input_ta] is None:
+                ta_constructor = prelude.get_var_static('tensor_array',
+                                                        dtype_str,
+                                                        input_t_shape)
+                new_ta = ta_constructor(input_ta.args[0])
+                _static_tensor_array_map[input_ta] = new_ta
+                input_ta = new_ta
+
+            # Register static indices shape
+            if isinstance(indices_shape[0], int):
+                static_tensor_array_ops.define_tensor_array_scatter(indices_shape, True)
+            tensor_array_scatter_func = prelude.get_var_static('tensor_array_scatter',
+                                                               dtype_str,
+                                                               input_t_shape)
+
+            static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                           dtype_str,
+                                                           values_shape)
+            static_tensor_array_ops.register()
+            unstack_function = prelude.get_var_static('tensor_array_unstack',
+                                                      dtype_str,
+                                                      values_shape)
+            values = unstack_function(inputs[2])
+        ret = tensor_array_scatter_func(input_ta, inputs[1], values)
+        return ret
     return _impl
 
 def _tensor_array_gather():
     def _impl(inputs, attr, params, prelude):
-        return prelude.tensor_array_gather(inputs[2], inputs[1])
+        dtype_str = attr.get('dtype').name
+        input_shape = get_tensor_array_shape(inputs[2], dtype_str, prelude)
+        indices_shape = _infer_shape(inputs[1], prelude.mod)
+
+        if input_shape is None:
+            gather_func = prelude.get_var('tensor_array_gather', dtype_str)
+            out = gather_func(inputs[2], inputs[1])
+        else:
+            static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                           dtype_str,
+                                                           input_shape)
+            static_tensor_array_ops.register()
+            if not isinstance(indices_shape[0], int):
+                gather_function = prelude.get_var_static('tensor_array_gather',
+                                                         dtype_str,
+                                                         input_shape)
+                out_tensor_t = gather_function(inputs[2], inputs[1])
+
+                # Output shape is (indices_shape[0],) + input_shape
+                static_tensor_array_ops.define_tensor_get_data((indices_shape[0],) + input_shape)
+                get_data_func = prelude.get_var_static('tensor_get_data',
+                                                       dtype_str,
+                                                       input_shape)
+                out = get_data_func(out_tensor_t)
+            else:
+                # For fixed length indices, directly generate static shape output
+                read_func = prelude.get_var_static('tensor_array_read',
+                                                   dtype_str,
+                                                   input_shape)
+                static_tensor_array_ops.define_tensor_get_data(input_shape)
+                get_data_func = prelude.get_var_static('tensor_get_data',
+                                                       dtype_str,
+                                                       input_shape)
+                tensor_list = []
+                for i in range(indices_shape[0]):
+                    index = _op.take(inputs[1], tvm.relay.const(i))
+                    out_tensor = get_data_func(read_func(inputs[2], index))
+                    tensor_list.append(_op.expand_dims(out_tensor, axis=0))
+
+                out = _op.concatenate(tensor_list, axis=0)
+
+        return out
     return _impl
 
 def _tensor_array_size():
@@ -817,37 +943,163 @@ def _tensor_array_size():
 
 def _tensor_array_write():
     def _impl(inputs, attr, params, prelude):
-        input_rank = len(inputs[2].type_annotation.shape)
-        dtype = attr.get('T').name
+        dtype_str = attr.get('T').name
+        input_ta = inputs[3]
+        input_ta_shape = get_tensor_array_shape(input_ta, dtype_str, prelude)
+        input_t_shape = _infer_shape(inputs[2], prelude.mod)
+        input_rank = len(input_t_shape)
+
+        if input_ta_shape is None:
+            tensor_name = 'tensor{}'.format(input_rank)
+            tensor_func = prelude.get_var(tensor_name, dtype_str)
+            v = tensor_func(inputs[2])
+            write_func = prelude.get_var('tensor_array_write', dtype_str)
+        else:
+            # For write operation, it is possible to write to a newly create
+            # tensor array. We need to check and recreate its input tensor array.
+            if input_ta in _static_tensor_array_map and \
+                    _static_tensor_array_map[input_ta] is None:
+                static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                               dtype_str,
+                                                               input_t_shape)
+                static_tensor_array_ops.register()
+                ta_constructor = prelude.get_var_static('tensor_array',
+                                                        dtype_str,
+                                                        input_t_shape)
+                new_ta = ta_constructor(input_ta.args[0])
+                _static_tensor_array_map[input_ta] = new_ta
+                input_ta = new_ta
+                input_ta_shape = input_t_shape
+            else:
+                input_ta_rank = len(input_ta_shape)
+                assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}". \
+                    format(input_ta_rank, input_rank)
+                static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                               dtype_str,
+                                                               input_ta_shape)
+                static_tensor_array_ops.register()
 
-        tensor_name = 'tensor{}'.format(input_rank)
-        tensor_func = prelude.get_var(tensor_name, dtype)
-        v = tensor_func(inputs[2])
-        write_func = prelude.get_var('tensor_array_write', dtype)
+            tensor_func = prelude.get_var_static("tensor_constructor",
+                                                 dtype_str,
+                                                 input_ta_shape)
+            v = tensor_func(inputs[2])
+            write_func = prelude.get_var_static('tensor_array_write',
+                                                dtype_str,
+                                                input_ta_shape)
 
-        return write_func(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v)
+        return write_func(input_ta, _op.take(inputs[1], tvm.relay.const(0)), v)
     return _impl
 
 def _tensor_array_read():
     def _impl(inputs, attr, params, prelude):
-        read_func = prelude.get_var('tensor_array_read', attr.get('dtype').name)
-        return read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0)))
+        dtype_str = attr['dtype'].name
+        input_shape = get_tensor_array_shape(inputs[2], dtype_str, prelude)
+
+        if input_shape is None:
+            read_func = prelude.get_var('tensor_array_read', dtype_str)
+            out = read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0)))
+        else:
+            static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                           dtype_str,
+                                                           input_shape)
+            static_tensor_array_ops.register()
+            static_tensor_array_ops.define_tensor_get_data(input_shape)
+            read_func = prelude.get_var_static("tensor_array_read", dtype_str, input_shape)
+            out_tensor = read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0)))
+            get_data_func = prelude.get_var_static('tensor_get_data',
+                                                   dtype_str,
+                                                   input_shape)
+            out = get_data_func(out_tensor)
+
+        return out
     return _impl
 
 def _tensor_array_split():
     def _impl(inputs, attr, params, prelude):
-        input_rank = len(inputs[1].type_annotation.shape)
         dtype_str = attr.get('T').name
-        v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1])
+        input_ta = inputs[0]
+        input_ta_shape = get_tensor_array_shape(input_ta, dtype_str, prelude)
+        input_t_shape = _infer_shape(inputs[1], prelude.mod)
+        input_rank = len(input_t_shape)
         lengths = _op.cast(inputs[2], 'int32')
-        split_var = prelude.get_var('tensor_array_split', dtype_str)
-        return split_var(inputs[0], v, lengths)
+        lengths_shape = _infer_shape(lengths, prelude.mod)
+        value_shape = _infer_shape(inputs[1], prelude.mod)
+
+        if input_ta_shape is None:
+            v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1])
+            split_func = prelude.get_var('tensor_array_split', dtype_str)
+        else:
+            # For split operation, it is possible to write to a newly create
+            # tensor array. We need to check and recreate its input tensor array.
+            if input_ta in _static_tensor_array_map and \
+                    _static_tensor_array_map[input_ta] is None:
+                input_ta_shape = (Any(),) + input_t_shape[1:]
+                static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                               dtype_str,
+                                                               input_ta_shape)
+                static_tensor_array_ops.register()
+                ta_constructor = prelude.get_var_static('tensor_array',
+                                                        dtype_str,
+                                                        input_ta_shape)
+                new_ta = ta_constructor(input_ta.args[0])
+                _static_tensor_array_map[input_ta] = new_ta
+                input_ta = new_ta
+            else:
+                input_ta_rank = len(input_ta_shape)
+                assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}". \
+                    format(input_ta_rank, input_rank)
+                static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                               dtype_str,
+                                                               input_ta_shape)
+                static_tensor_array_ops.register()
+
+            # Check static value/indices shape
+            if isinstance(value_shape[0], int) or isinstance(lengths_shape[0], int):
+                static_tensor_array_ops.define_tensor_array_split(value_shape,
+                                                                  lengths_shape,
+                                                                  True)
+
+            tensor_func_name = prelude.get_name_static("tensor_constructor",
+                                                       dtype_str,
+                                                       value_shape)
+            if not hasattr(prelude, tensor_func_name):
+                static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                               dtype_str,
+                                                               value_shape)
+                static_tensor_array_ops.register()
+            tensor_func = prelude.get_var_static("tensor_constructor",
+                                                 dtype_str,
+                                                 value_shape)
+            v = tensor_func(inputs[1])
+            split_func = prelude.get_var_static('tensor_array_split',
+                                                dtype_str,
+                                                input_ta_shape)
+
+        return split_func(input_ta, v, lengths)
     return _impl
 
 def _tensor_array_concat():
     def _impl(inputs, attr, params, prelude):
-        concat_func = prelude.get_var('tensor_array_concat', attr['dtype'].name)
-        return concat_func(inputs[1])
+        dtype_str = attr['dtype'].name
+        input_shape = get_tensor_array_shape(inputs[1], dtype_str, prelude)
+
+        if input_shape is None:
+            concat_func = prelude.get_var('tensor_array_concat', dtype_str)
+            out = concat_func(inputs[1])
+        else:
+            static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                           dtype_str,
+                                                           input_shape)
+            static_tensor_array_ops.register()
+            concat_func = prelude.get_var_static("tensor_array_concat", dtype_str, input_shape)
+            out_tensor = concat_func(inputs[1])
+            static_tensor_array_ops.define_tensor_get_data((Any(),) + input_shape[1:])
+            get_data_func = prelude.get_var_static('tensor_get_data',
+                                                   dtype_str,
+                                                   input_shape)
+            out = get_data_func(out_tensor)
+
+        return out
     return _impl
 
 def _tile():
@@ -1370,7 +1622,7 @@ def _range():
 
         return AttrCvt(
             op_name="arange",
-            ignores=['Tidx'],
+            ignores=['Tidx', '_class'],
             extras={'start': start,
                     'stop': limit,
                     'step': delta,
@@ -2084,6 +2336,9 @@ class RecurrentNetworks(object):
 # 1.x.
 _control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond']
 
+# A map to record tensor array with fixed rank shape
+_static_tensor_array_map = {}
+
 class RewriteSubgraph(ExprMutator):
     """
     A helper class to rewrite expr in while loop function to variable
index 47c3ba7..243eace 100644 (file)
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
 """A prelude containing useful global functions and ADT definitions."""
-from tvm.ir import IRModule
+from tvm.ir import IRModule, TypeCall
 
 from .ty import GlobalTypeVar, TensorType, Any, scalar_type
 from .expr import Var, GlobalVar, If, const
@@ -24,8 +24,51 @@ from .function import Function
 from .op.tensor import add, subtract, equal
 from .adt import Constructor, TypeData, Clause, Match
 from .adt import PatternConstructor, PatternVar, PatternWildcard
-from . import op
-
+from . import op, transform
+
+
+def get_tensor_array_shape(expr, dtype, prelude):
+    """Get the static shape of a tensor array if it has fixed rank shape.
+
+    By design, static ADT tensor in TVM has type name in the format
+    of static_tensor_dim0_dim1_..._dimN_t.
+
+    Parameters
+    ----------
+    expr : Relay Expr
+        Input expression.
+
+    dtype : str
+        Data type.
+
+    prelude : Prelude
+        Tensor array prelude
+
+    Returns
+    -------
+    shape : tuple of (int, Any) or None
+        The output shape. None if input tensor array
+        has dynamic shape.
+    """
+    mod = prelude.mod
+    mod["main"] = Function([], expr)
+    mod = transform.InferType()(mod)
+    checked_type = mod["main"].body.checked_type
+    assert isinstance(checked_type, TypeCall), "Input must be a tensor array."
+    ta_type_str = checked_type.args[0].func.name_hint
+    static_ta_ty_start = "static_tensor_{}".format(dtype)
+    if ta_type_str.startswith(static_ta_ty_start):
+        shape_str = ta_type_str.replace("{}_".format(static_ta_ty_start), '') \
+            .replace("_t", '')
+        shape = []
+        if "scalar" not in shape_str:
+            for dim_str in shape_str.split("_"):
+                if dim_str == "?":
+                    shape.append(Any())
+                else:
+                    shape.append(int(dim_str))
+        return tuple(shape)
+    return None
 
 def _get_name_static(canonical, dtype, shape):
     """Get name for static shape tensor array op corresponding
index fdb8912..bc884bb 100644 (file)
@@ -839,63 +839,75 @@ def test_forward_squeeze():
     _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5, -1])
 
 
-def test_tensor_array_constructor():
-    def run(dtype_str):
+#######################################################################
+# TensorArray
+# -----------
+def test_tensor_array_write_read():
+    def run(dtype_str, infer_shape, element_shape):
         with tf.Graph().as_default():
             dtype = tf_dtypes[dtype_str]
-            t = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype)
-            t2 = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype)
-            ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False)
-            ta2 = ta1.write(0, t)
+            np_data = np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str)
+            in_data = [np_data, np_data]
+            t1 = tf.constant(np_data, dtype=dtype)
+            t2 = tf.constant(np_data, dtype=dtype)
+            ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=infer_shape,
+                                 element_shape=element_shape)
+            ta2 = ta1.write(0, t1)
             ta3 = ta2.write(1, t2)
             out = ta3.read(0)
             g = tf.get_default_graph()
-            compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug')
-    for dtype in tf_dtypes.keys():
-        run(dtype)
+            compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='vm')
+
+    for dtype in ["float32", "int8"]:
+        run(dtype, False, None)
+        run(dtype, False, tf.TensorShape([None, 2]))
+        run(dtype, True, None)
 
 
 def test_tensor_array_scatter():
-    def run(dtype_str):
+    def run(dtype_str, infer_shape):
         with tf.Graph().as_default():
             dtype =  tf_dtypes[dtype_str]
             t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype)
             indices = tf.constant([2, 1, 0])
             ta1 = tf.TensorArray(dtype=dtype, size=3,
-                                 infer_shape=False, dynamic_size=False)
+                                 infer_shape=infer_shape)
             ta2 = ta1.scatter(indices, t)
             out0 = ta2.read(0)
             out1 = ta2.read(1)
             out2 = ta2.read(2)
             g = tf.get_default_graph()
-            compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug')
-            compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
-            compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
-    for dtype in tf_dtypes.keys():
-        run(dtype)
-
-# TODO(wweic): Fix gather issue with PartialEvaluate
-# def test_tensor_array_gather():
-#     with tf.Graph().as_default():
-#         dtype = 'float32'
-#         t = tf.constant([[1.0], [2.0], [3.0]])
-#         scatter_indices = tf.constant([2, 1, 0])
-#         gather_indices = tf.constant([1, 2])
-#         ta1 = tf.TensorArray(dtype=tf.float32, size=3, infer_shape=False, dynamic_size=False)
-#         ta2 = ta1.scatter(scatter_indices, t)
-#         t1 = ta2.gather(gather_indices)
-#         g = tf.get_default_graph()
-#         compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='debug')
+            compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='vm')
+            compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='vm')
+            compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='vm')
+    for dtype in ["float32", "int8"]:
+        run(dtype, False)
+        run(dtype, True)
+
+
+def test_tensor_array_gather():
+    def run(dtype_str, infer_shape):
+        with tf.Graph().as_default():
+            dtype =  tf_dtypes[dtype_str]
+            t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str))
+            scatter_indices = tf.constant([2, 1, 0])
+            gather_indices = tf.constant([1, 2])
+            ta1 = tf.TensorArray(dtype=dtype, size=3, infer_shape=infer_shape)
+            ta2 = ta1.scatter(scatter_indices, t)
+            t1 = ta2.gather(gather_indices)
+            g = tf.get_default_graph()
+            compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='vm')
+    for dtype in ["float32", "int8"]:
+        run(dtype, True)
 
 
 def test_tensor_array_split():
-    def run(dtype_str):
+    def run(dtype_str, infer_shape):
         with tf.Graph().as_default():
             dtype =  tf_dtypes[dtype_str]
             t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
             split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
-            ta1 = tf.TensorArray(dtype=dtype, size=4,
-                                 infer_shape=False, dynamic_size=False)
+            ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=infer_shape)
             ta2 = ta1.split(t, split_length)
             out0 = ta2.read(0)
             out1 = ta2.read(1)
@@ -906,56 +918,76 @@ def test_tensor_array_split():
             compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
             compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
             compare_tf_with_tvm([], [], ['TensorArrayReadV3_3:0'], mode='debug')
-    for dtype in tf_dtypes.keys():
-        run(dtype)
+    for dtype in ["float32", "int8"]:
+        run(dtype, False)
+        run(dtype, True)
 
 
 def test_tensor_array_concat():
-    def run(dtype_str):
+    def run(dtype_str, infer_shape):
         with tf.Graph().as_default():
             dtype = tf_dtypes[dtype_str]
             t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
             split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
             ta1 = tf.TensorArray(dtype=dtype, size=4,
-                                 infer_shape=False, dynamic_size=False)
+                                 infer_shape=infer_shape)
             ta2 = ta1.split(t, split_length)
             t = ta2.concat()
             out = tf.identity(t)
             compare_tf_with_tvm([], [], ['Identity:0'], mode='debug')
-    for dtype in tf_dtypes.keys():
-        run(dtype)
+    for dtype in ["float32", "int8"]:
+        run(dtype, False)
+        run(dtype, True)
 
 
 def test_tensor_array_size():
-    def run(dtype_str):
+    def run(dtype_str, infer_shape):
         with tf.Graph().as_default():
             dtype =  tf_dtypes[dtype_str]
-            ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False)
+            ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=infer_shape)
             out = ta1.size()
             g = tf.get_default_graph()
             compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')
-    for dtype in tf_dtypes.keys():
-        run(dtype)
+    for dtype in ["float32", "int8"]:
+        run(dtype, False)
+        run(dtype, True)
+
+
+def test_tensor_array_stack():
+    def run(dtype_str, infer_shape):
+        with tf.Graph().as_default():
+            dtype =  tf_dtypes[dtype_str]
+            t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str))
+            scatter_indices = tf.constant([2, 1, 0])
+            ta1 = tf.TensorArray(dtype=dtype, size=3, infer_shape=infer_shape)
+            ta2 = ta1.scatter(scatter_indices, t)
+            t1 = ta2.stack()
+            print(t1)
+            g = tf.get_default_graph()
+
+            compare_tf_with_tvm([], [], ['TensorArrayStack/TensorArrayGatherV3:0'], mode='vm')
+    for dtype in ["float32", "int8"]:
+        run(dtype, True)
+
 
 def test_tensor_array_unstack():
-    def run(dtype_str, input_shape):
+    def run(dtype_str, input_shape, infer_shape):
         with tf.Graph().as_default():
             dtype = tf_dtypes[dtype_str]
             t = tf.constant(np.random.choice([0, 1, 2, 3],
                                              size=input_shape).astype(dtype.name))
-            ta1 = tf.TensorArray(dtype=dtype, infer_shape=False, size=input_shape[0])
+            ta1 = tf.TensorArray(dtype=dtype, infer_shape=infer_shape, size=input_shape[0])
             ta2 = ta1.unstack(t)
             out0 = ta2.size()
             out1 = ta2.read(0)
             compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')
             compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug')
-    for dtype in tf_dtypes.keys():
-        run(dtype, (5,))
-        run(dtype, (5, 5))
-        run(dtype, (5, 5, 5))
-        run(dtype, (5, 5, 5, 5))
-        run(dtype, (5, 5, 5, 5, 5))
-        run(dtype, (5, 5, 5, 5, 5, 5))
+    for dtype in ["float32", "int8"]:
+        run(dtype, (5,), False)
+        run(dtype, (5, 5), True)
+        run(dtype, (5, 5, 5), False)
+        run(dtype, (5, 5, 5, 5), True)
+
 
 #######################################################################
 # ConcatV2
@@ -3241,6 +3273,16 @@ if __name__ == '__main__':
     test_forward_reduce()
     test_forward_mean()
 
+    # TensorArray
+    test_tensor_array_write_read()
+    test_tensor_array_concat()
+    test_tensor_array_scatter()
+    test_tensor_array_gather()
+    test_tensor_array_size()
+    test_tensor_array_split()
+    test_tensor_array_stack()
+    test_tensor_array_unstack()
+
     # General
     test_forward_multi_input()
     test_forward_multi_output()
index 6815357..50a6a36 100644 (file)
@@ -166,12 +166,14 @@ def get_const_tuple(in_tuple):
     """
     ret = []
     for elem in in_tuple:
-        if isinstance(elem, tvm.tir.Var):
+        if isinstance(elem, (tvm.tir.Var, tvm.tir.expr.Any)):
             ret.append(elem)
         elif not isinstance(elem, (tvm.tir.IntImm, int)):
             elem = tvm.tir.ir_pass.Simplify(elem)
             if not isinstance(elem, tvm.tir.IntImm):
                 ret.append(elem)
+            else:
+                ret.append(get_const_int(elem))
         else:
             ret.append(get_const_int(elem))
     return tuple(ret)