[Relay][Op][TF] Complete tensor array unstack with all ranks support (#4309)
authorWei Chen <ipondering.weic@gmail.com>
Tue, 12 Nov 2019 20:36:28 +0000 (12:36 -0800)
committerZhi <5145158+zhiics@users.noreply.github.com>
Tue, 12 Nov 2019 20:36:28 +0000 (12:36 -0800)
python/tvm/relay/frontend/tensorflow.py
python/tvm/relay/prelude.py
tests/python/frontend/tensorflow/test_forward.py

index 5a17d5f..2a65678 100644 (file)
@@ -40,6 +40,7 @@ from .common import infer_type as _infer_type
 from .common import infer_shape as _infer_shape
 from .common import infer_channels as _infer_channels
 from .common import infer_value as _infer_value
+from .common import infer_value_simulated as _infer_value_simulated
 
 __all__ = ['from_tensorflow']
 
@@ -1079,9 +1080,13 @@ def _rank():
 def _range():
     def _impl(inputs, attr, params):
         start = _get_param(params, inputs[0])[0]
-        limit = _get_param(params, inputs[1])[0] \
-            if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant) \
-            else params.pop('Rank').asnumpy()[0]
+        if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant):
+            limit = _get_param(params, inputs[1])[0]
+        else:
+            if any(['Rank' in param for param in params]):
+                limit = params.pop('Rank').asnumpy()[0]
+            else:
+                limit = _infer_value_simulated(inputs[1], params).asnumpy()[0]
         delta = _get_param(params, inputs[2])[0]
         dtype = attr['Tidx'].name if 'Tidx' in attr else str(start.dtype)
         return AttrCvt(
index 1625e19..ddb9302 100644 (file)
@@ -336,6 +336,122 @@ class TensorArrayOps(object):
             Function([tensor2], helper_var(const(0), ndim, tensor2),
                      self.prelude.l(self.get_var('tensor_t')()), [])
 
+    def define_tensor_array_unstack_tensor3(self):
+        """Defines a function to unstack the values of a tensor_t with rank 3 in a tensor array.
+
+        tensor_array_unstack_tensor3(t) : tensor_t -> list[tensor_t]
+        """
+        helper_name = self.get_name("tensor_array_unstack_tensor3_helper")
+        helper_var = GlobalVar(helper_name)
+        setattr(self.prelude, helper_name, helper_var)
+        tensor = Var("t", TensorType([Any(), Any(), Any()], self.dtype))
+        up = Var("up", scalar_type('int32'))
+        i = Var("i", scalar_type('int32'))
+
+        helper_body = If(equal(i, up),
+                         self.prelude.nil(),
+                         self.prelude.cons(self.get_var('tensor2')(op.take(tensor, i, axis=0)),
+                                           helper_var(add(i, const(1)), up, tensor)))
+        self.prelude.mod[helper_var] =\
+            Function([i, up, tensor], helper_body, self.prelude.l(self.get_var('tensor_t')()), [])
+
+        tensor_array_unstack_tensor3_name = self.get_name("tensor_array_unstack_tensor3")
+        tensor_array_unstack_tensor3_var = GlobalVar(tensor_array_unstack_tensor3_name)
+        setattr(self.prelude, tensor_array_unstack_tensor3_name, tensor_array_unstack_tensor3_var)
+        tensor3 = Var("tensor", TensorType([Any(), Any(), Any()], self.dtype))
+        shape = op.shape_of(tensor3)
+        ndim = op.take(shape, const(0))
+        self.prelude.mod[tensor_array_unstack_tensor3_var] =\
+            Function([tensor3], helper_var(const(0), ndim, tensor3),
+                     self.prelude.l(self.get_var('tensor_t')()), [])
+
+    def define_tensor_array_unstack_tensor4(self):
+        """Defines a function to unstack the values of a tensor_t with rank 4 in a tensor array.
+
+        tensor_array_unstack_tensor4(t) : tensor_t -> list[tensor_t]
+        """
+        helper_name = self.get_name("tensor_array_unstack_tensor4_helper")
+        helper_var = GlobalVar(helper_name)
+        setattr(self.prelude, helper_name, helper_var)
+        tensor = Var("t", TensorType([Any(), Any(), Any(), Any()], self.dtype))
+        up = Var("up", scalar_type('int32'))
+        i = Var("i", scalar_type('int32'))
+
+        helper_body = If(equal(i, up),
+                         self.prelude.nil(),
+                         self.prelude.cons(self.get_var('tensor3')(op.take(tensor, i, axis=0)),
+                                           helper_var(add(i, const(1)), up, tensor)))
+        self.prelude.mod[helper_var] =\
+            Function([i, up, tensor], helper_body, self.prelude.l(self.get_var('tensor_t')()), [])
+
+        tensor_array_unstack_tensor4_name = self.get_name("tensor_array_unstack_tensor4")
+        tensor_array_unstack_tensor4_var = GlobalVar(tensor_array_unstack_tensor4_name)
+        setattr(self.prelude, tensor_array_unstack_tensor4_name, tensor_array_unstack_tensor4_var)
+        tensor4 = Var("tensor", TensorType([Any(), Any(), Any(), Any()], self.dtype))
+        shape = op.shape_of(tensor4)
+        ndim = op.take(shape, const(0))
+        self.prelude.mod[tensor_array_unstack_tensor4_var] =\
+            Function([tensor4], helper_var(const(0), ndim, tensor4),
+                     self.prelude.l(self.get_var('tensor_t')()), [])
+
+    def define_tensor_array_unstack_tensor5(self):
+        """Defines a function to unstack the values of a tensor_t with rank 5 in a tensor array.
+
+        tensor_array_unstack_tensor5(t) : tensor_t -> list[tensor_t]
+        """
+        helper_name = self.get_name("tensor_array_unstack_tensor5_helper")
+        helper_var = GlobalVar(helper_name)
+        setattr(self.prelude, helper_name, helper_var)
+        tensor = Var("t", TensorType([Any(), Any(), Any(), Any(), Any()], self.dtype))
+        up = Var("up", scalar_type('int32'))
+        i = Var("i", scalar_type('int32'))
+
+        helper_body = If(equal(i, up),
+                         self.prelude.nil(),
+                         self.prelude.cons(self.get_var('tensor4')(op.take(tensor, i, axis=0)),
+                                           helper_var(add(i, const(1)), up, tensor)))
+        self.prelude.mod[helper_var] =\
+            Function([i, up, tensor], helper_body, self.prelude.l(self.get_var('tensor_t')()), [])
+
+        tensor_array_unstack_tensor5_name = self.get_name("tensor_array_unstack_tensor5")
+        tensor_array_unstack_tensor5_var = GlobalVar(tensor_array_unstack_tensor5_name)
+        setattr(self.prelude, tensor_array_unstack_tensor5_name, tensor_array_unstack_tensor5_var)
+        tensor5 = Var("tensor", TensorType([Any(), Any(), Any(), Any(), Any()], self.dtype))
+        shape = op.shape_of(tensor5)
+        ndim = op.take(shape, const(0))
+        self.prelude.mod[tensor_array_unstack_tensor5_var] =\
+            Function([tensor5], helper_var(const(0), ndim, tensor5),
+                     self.prelude.l(self.get_var('tensor_t')()), [])
+
+    def define_tensor_array_unstack_tensor6(self):
+        """Defines a function to unstack the values of a tensor_t with rank 6 in a tensor array.
+
+        tensor_array_unstack_tensor6(t) : tensor_t -> list[tensor_t]
+        """
+        helper_name = self.get_name("tensor_array_unstack_tensor6_helper")
+        helper_var = GlobalVar(helper_name)
+        setattr(self.prelude, helper_name, helper_var)
+        tensor = Var("t", TensorType([Any(), Any(), Any(), Any(), Any(), Any()], self.dtype))
+        up = Var("up", scalar_type('int32'))
+        i = Var("i", scalar_type('int32'))
+
+        helper_body = If(equal(i, up),
+                         self.prelude.nil(),
+                         self.prelude.cons(self.get_var('tensor5')(op.take(tensor, i, axis=0)),
+                                           helper_var(add(i, const(1)), up, tensor)))
+        self.prelude.mod[helper_var] =\
+            Function([i, up, tensor], helper_body, self.prelude.l(self.get_var('tensor_t')()), [])
+
+        tensor_array_unstack_tensor6_name = self.get_name("tensor_array_unstack_tensor6")
+        tensor_array_unstack_tensor6_var = GlobalVar(tensor_array_unstack_tensor6_name)
+        setattr(self.prelude, tensor_array_unstack_tensor6_name, tensor_array_unstack_tensor6_var)
+        tensor6 = Var("tensor", TensorType([Any(), Any(), Any(), Any(), Any(), Any()], self.dtype))
+        shape = op.shape_of(tensor6)
+        ndim = op.take(shape, const(0))
+        self.prelude.mod[tensor_array_unstack_tensor6_var] =\
+            Function([tensor6], helper_var(const(0), ndim, tensor6),
+                     self.prelude.l(self.get_var('tensor_t')()), [])
+
     def define_tensor_array_scatter(self):
         """Defines a function to scatter the values of a tensor_t in indices of a tensor array.
         tensor_array_scatter(ta, indices, value) :
@@ -516,6 +632,10 @@ class TensorArrayOps(object):
         self.define_tensor_array_write()
         self.define_tensor_array_unstack_tensor1()
         self.define_tensor_array_unstack_tensor2()
+        self.define_tensor_array_unstack_tensor3()
+        self.define_tensor_array_unstack_tensor4()
+        self.define_tensor_array_unstack_tensor5()
+        self.define_tensor_array_unstack_tensor6()
         self.define_tensor_array_scatter()
         self.define_tensor_array_split()
         self.define_tensor_array_concat()
index 30b6dfe..17db2f5 100644 (file)
@@ -763,6 +763,26 @@ def test_tensor_array_size():
     for dtype in tf_dtypes.keys():
         run(dtype)
 
+def test_tensor_array_unstack():
+    def run(dtype_str, input_shape):
+        with tf.Graph().as_default():
+            dtype = tf_dtypes[dtype_str]
+            t = tf.constant(np.random.choice([0, 1, 2, 3],
+                                             size=input_shape).astype(dtype.name))
+            ta1 = tf.TensorArray(dtype=dtype, infer_shape=False, size=input_shape[0])
+            ta2 = ta1.unstack(t)
+            out0 = ta2.size()
+            out1 = ta2.read(0)
+            compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')
+            compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug')
+    for dtype in tf_dtypes.keys():
+        run(dtype, (5,))
+        run(dtype, (5, 5))
+        run(dtype, (5, 5, 5))
+        run(dtype, (5, 5, 5, 5))
+        run(dtype, (5, 5, 5, 5, 5))
+        run(dtype, (5, 5, 5, 5, 5, 5))
+
 #######################################################################
 # ConcatV2
 # --------