From dee52466dbed6095ffbf46e29a09dcb172ce636d Mon Sep 17 00:00:00 2001 From: Neo Chien Date: Sun, 1 Sep 2019 09:56:39 +0800 Subject: [PATCH] Implementation of tile for TFLite (#3814) --- python/tvm/relay/frontend/tflite.py | 25 ++++++++++++++++++++++++- tests/python/frontend/tflite/test_forward.py | 23 +++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index f4c10f2..e0f9775 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -82,7 +82,8 @@ class OperatorConverter(object): 'PACK': self.convert_pack, 'LOGISTIC': self.convert_logistic, 'SPLIT': self.convert_split, - 'TRANSPOSE': self.convert_transpose + 'TRANSPOSE': self.convert_transpose, + 'TILE': self.convert_tile } def check_unsupported_ops(self): @@ -769,6 +770,28 @@ class OperatorConverter(object): return out + def convert_tile(self, op): + """tile implementation.""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + input_tensor = input_tensors[0] + input_tensor_idx = input_tensor.tensor_idx + + in_expr = self.get_expr(input_tensor_idx) + + # reps (tuple of int) – The number of times repeating the tensor data. + reps = tuple(self.get_tensor_value(input_tensors[1])) + + out = _op.tile(in_expr, reps) + + return out + def convert_pool2d(self, op, pool_type): """pool2d implementation.""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index a78225c..7712260 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -229,6 +229,26 @@ def test_forward_transpose(): _test_forward_transpose((2, 3, 4, 5), (3, 0, 1, 2)) _test_forward_transpose((2, 3, 4, 5), ()) +####################################################################### +# tile +# --------- + + +def _test_forward_tile(in_shape, reps, dtype): + data = np.random.uniform(-5, 5, size=in_shape).astype(dtype) + + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + + out = array_ops.tile(in_data, reps) + + compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) + + +def test_forward_tile(): + _test_forward_tile((2, ), (3, ), "int32") + _test_forward_tile((2, 2), (2, 3), "float32") + ####################################################################### # Pooling @@ -856,6 +876,9 @@ if __name__ == '__main__': # Transpose test_forward_transpose() + # Tile + test_forward_tile() + # Transforms test_forward_concatenation() test_forward_pad() -- 2.7.4