From 03a29da76433654afbe6b3fbfe0dd4564788fa1a Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Tue, 12 Nov 2019 12:36:28 -0800 Subject: [PATCH] [Relay][Op][TF] Complete tensor array unstack with all ranks support (#4309) --- python/tvm/relay/frontend/tensorflow.py | 11 ++- python/tvm/relay/prelude.py | 120 +++++++++++++++++++++++ tests/python/frontend/tensorflow/test_forward.py | 20 ++++ 3 files changed, 148 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 5a17d5f..2a65678 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -40,6 +40,7 @@ from .common import infer_type as _infer_type from .common import infer_shape as _infer_shape from .common import infer_channels as _infer_channels from .common import infer_value as _infer_value +from .common import infer_value_simulated as _infer_value_simulated __all__ = ['from_tensorflow'] @@ -1079,9 +1080,13 @@ def _rank(): def _range(): def _impl(inputs, attr, params): start = _get_param(params, inputs[0])[0] - limit = _get_param(params, inputs[1])[0] \ - if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant) \ - else params.pop('Rank').asnumpy()[0] + if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant): + limit = _get_param(params, inputs[1])[0] + else: + if any(['Rank' in param for param in params]): + limit = params.pop('Rank').asnumpy()[0] + else: + limit = _infer_value_simulated(inputs[1], params).asnumpy()[0] delta = _get_param(params, inputs[2])[0] dtype = attr['Tidx'].name if 'Tidx' in attr else str(start.dtype) return AttrCvt( diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 1625e19..ddb9302 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -336,6 +336,122 @@ class TensorArrayOps(object): Function([tensor2], helper_var(const(0), ndim, tensor2), self.prelude.l(self.get_var('tensor_t')()), []) + def define_tensor_array_unstack_tensor3(self): + """Defines a function to unstack the values of a tensor_t with rank 3 in a tensor array. + + tensor_array_unstack_tensor3(t) : tensor_t -> list[tensor_t] + """ + helper_name = self.get_name("tensor_array_unstack_tensor3_helper") + helper_var = GlobalVar(helper_name) + setattr(self.prelude, helper_name, helper_var) + tensor = Var("t", TensorType([Any(), Any(), Any()], self.dtype)) + up = Var("up", scalar_type('int32')) + i = Var("i", scalar_type('int32')) + + helper_body = If(equal(i, up), + self.prelude.nil(), + self.prelude.cons(self.get_var('tensor2')(op.take(tensor, i, axis=0)), + helper_var(add(i, const(1)), up, tensor))) + self.prelude.mod[helper_var] =\ + Function([i, up, tensor], helper_body, self.prelude.l(self.get_var('tensor_t')()), []) + + tensor_array_unstack_tensor3_name = self.get_name("tensor_array_unstack_tensor3") + tensor_array_unstack_tensor3_var = GlobalVar(tensor_array_unstack_tensor3_name) + setattr(self.prelude, tensor_array_unstack_tensor3_name, tensor_array_unstack_tensor3_var) + tensor3 = Var("tensor", TensorType([Any(), Any(), Any()], self.dtype)) + shape = op.shape_of(tensor3) + ndim = op.take(shape, const(0)) + self.prelude.mod[tensor_array_unstack_tensor3_var] =\ + Function([tensor3], helper_var(const(0), ndim, tensor3), + self.prelude.l(self.get_var('tensor_t')()), []) + + def define_tensor_array_unstack_tensor4(self): + """Defines a function to unstack the values of a tensor_t with rank 4 in a tensor array. + + tensor_array_unstack_tensor4(t) : tensor_t -> list[tensor_t] + """ + helper_name = self.get_name("tensor_array_unstack_tensor4_helper") + helper_var = GlobalVar(helper_name) + setattr(self.prelude, helper_name, helper_var) + tensor = Var("t", TensorType([Any(), Any(), Any(), Any()], self.dtype)) + up = Var("up", scalar_type('int32')) + i = Var("i", scalar_type('int32')) + + helper_body = If(equal(i, up), + self.prelude.nil(), + self.prelude.cons(self.get_var('tensor3')(op.take(tensor, i, axis=0)), + helper_var(add(i, const(1)), up, tensor))) + self.prelude.mod[helper_var] =\ + Function([i, up, tensor], helper_body, self.prelude.l(self.get_var('tensor_t')()), []) + + tensor_array_unstack_tensor4_name = self.get_name("tensor_array_unstack_tensor4") + tensor_array_unstack_tensor4_var = GlobalVar(tensor_array_unstack_tensor4_name) + setattr(self.prelude, tensor_array_unstack_tensor4_name, tensor_array_unstack_tensor4_var) + tensor4 = Var("tensor", TensorType([Any(), Any(), Any(), Any()], self.dtype)) + shape = op.shape_of(tensor4) + ndim = op.take(shape, const(0)) + self.prelude.mod[tensor_array_unstack_tensor4_var] =\ + Function([tensor4], helper_var(const(0), ndim, tensor4), + self.prelude.l(self.get_var('tensor_t')()), []) + + def define_tensor_array_unstack_tensor5(self): + """Defines a function to unstack the values of a tensor_t with rank 5 in a tensor array. + + tensor_array_unstack_tensor5(t) : tensor_t -> list[tensor_t] + """ + helper_name = self.get_name("tensor_array_unstack_tensor5_helper") + helper_var = GlobalVar(helper_name) + setattr(self.prelude, helper_name, helper_var) + tensor = Var("t", TensorType([Any(), Any(), Any(), Any(), Any()], self.dtype)) + up = Var("up", scalar_type('int32')) + i = Var("i", scalar_type('int32')) + + helper_body = If(equal(i, up), + self.prelude.nil(), + self.prelude.cons(self.get_var('tensor4')(op.take(tensor, i, axis=0)), + helper_var(add(i, const(1)), up, tensor))) + self.prelude.mod[helper_var] =\ + Function([i, up, tensor], helper_body, self.prelude.l(self.get_var('tensor_t')()), []) + + tensor_array_unstack_tensor5_name = self.get_name("tensor_array_unstack_tensor5") + tensor_array_unstack_tensor5_var = GlobalVar(tensor_array_unstack_tensor5_name) + setattr(self.prelude, tensor_array_unstack_tensor5_name, tensor_array_unstack_tensor5_var) + tensor5 = Var("tensor", TensorType([Any(), Any(), Any(), Any(), Any()], self.dtype)) + shape = op.shape_of(tensor5) + ndim = op.take(shape, const(0)) + self.prelude.mod[tensor_array_unstack_tensor5_var] =\ + Function([tensor5], helper_var(const(0), ndim, tensor5), + self.prelude.l(self.get_var('tensor_t')()), []) + + def define_tensor_array_unstack_tensor6(self): + """Defines a function to unstack the values of a tensor_t with rank 6 in a tensor array. + + tensor_array_unstack_tensor6(t) : tensor_t -> list[tensor_t] + """ + helper_name = self.get_name("tensor_array_unstack_tensor6_helper") + helper_var = GlobalVar(helper_name) + setattr(self.prelude, helper_name, helper_var) + tensor = Var("t", TensorType([Any(), Any(), Any(), Any(), Any(), Any()], self.dtype)) + up = Var("up", scalar_type('int32')) + i = Var("i", scalar_type('int32')) + + helper_body = If(equal(i, up), + self.prelude.nil(), + self.prelude.cons(self.get_var('tensor5')(op.take(tensor, i, axis=0)), + helper_var(add(i, const(1)), up, tensor))) + self.prelude.mod[helper_var] =\ + Function([i, up, tensor], helper_body, self.prelude.l(self.get_var('tensor_t')()), []) + + tensor_array_unstack_tensor6_name = self.get_name("tensor_array_unstack_tensor6") + tensor_array_unstack_tensor6_var = GlobalVar(tensor_array_unstack_tensor6_name) + setattr(self.prelude, tensor_array_unstack_tensor6_name, tensor_array_unstack_tensor6_var) + tensor6 = Var("tensor", TensorType([Any(), Any(), Any(), Any(), Any(), Any()], self.dtype)) + shape = op.shape_of(tensor6) + ndim = op.take(shape, const(0)) + self.prelude.mod[tensor_array_unstack_tensor6_var] =\ + Function([tensor6], helper_var(const(0), ndim, tensor6), + self.prelude.l(self.get_var('tensor_t')()), []) + def define_tensor_array_scatter(self): """Defines a function to scatter the values of a tensor_t in indices of a tensor array. tensor_array_scatter(ta, indices, value) : @@ -516,6 +632,10 @@ class TensorArrayOps(object): self.define_tensor_array_write() self.define_tensor_array_unstack_tensor1() self.define_tensor_array_unstack_tensor2() + self.define_tensor_array_unstack_tensor3() + self.define_tensor_array_unstack_tensor4() + self.define_tensor_array_unstack_tensor5() + self.define_tensor_array_unstack_tensor6() self.define_tensor_array_scatter() self.define_tensor_array_split() self.define_tensor_array_concat() diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 30b6dfe..17db2f5 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -763,6 +763,26 @@ def test_tensor_array_size(): for dtype in tf_dtypes.keys(): run(dtype) +def test_tensor_array_unstack(): + def run(dtype_str, input_shape): + with tf.Graph().as_default(): + dtype = tf_dtypes[dtype_str] + t = tf.constant(np.random.choice([0, 1, 2, 3], + size=input_shape).astype(dtype.name)) + ta1 = tf.TensorArray(dtype=dtype, infer_shape=False, size=input_shape[0]) + ta2 = ta1.unstack(t) + out0 = ta2.size() + out1 = ta2.read(0) + compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug') + compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug') + for dtype in tf_dtypes.keys(): + run(dtype, (5,)) + run(dtype, (5, 5)) + run(dtype, (5, 5, 5)) + run(dtype, (5, 5, 5, 5)) + run(dtype, (5, 5, 5, 5, 5)) + run(dtype, (5, 5, 5, 5, 5, 5)) + ####################################################################### # ConcatV2 # -------- -- 2.7.4