From 2327bb9f6b7cd156a573a3ea50da075d1e07923a Mon Sep 17 00:00:00 2001 From: Ina Dobreva <55383260+inadob@users.noreply.github.com> Date: Fri, 10 Jan 2020 22:57:16 +0000 Subject: [PATCH] [Relay][Frontend][TFlite] Add parses support for SLICE (#4502) * [Relay][Frontend][TFlite] Add parses support for SLICE * TFlite 1.13: convertor gives nonsense output when size[i]==-1 * TF parser: SLICE need fixing for size[i]==-1 -> gives wrong output bcs of indices * Set end[i] = input_tensor_shape[i] as suggested in PR review * Add another test to cover size=-1 case --- python/tvm/relay/frontend/tflite.py | 30 ++++++++++++++++++++++++++++ tests/python/frontend/tflite/test_forward.py | 21 +++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 5737eae..cb6dbea 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -103,6 +103,7 @@ class OperatorConverter(object): 'TANH':self.convert_tanh, 'RELU':self.convert_relu, 'SPLIT': self.convert_split, + 'SLICE': self.convert_slice, 'TRANSPOSE': self.convert_transpose, 'CAST': self.convert_cast, 'TILE': self.convert_tile, @@ -1152,6 +1153,35 @@ class OperatorConverter(object): return out + def convert_slice(self, op): + """Convert TFLite SLICE""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 3, "input tensors length should be == 3" + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + + begin = list(self.get_tensor_value(input_tensors[1])) + size = list(self.get_tensor_value(input_tensors[2])) + # strided_slice(Relay) needs the slice's end indices, not the size + end = size + input_tensor_shape = input_tensor.tensor.ShapeAsNumpy() + input_tensor_rank = len(input_tensor_shape) + for i in range(input_tensor_rank): + if size[i] == -1: + end[i] = input_tensor_shape[i] + else: + end[i] += begin[i] + + out = _op.strided_slice(in_expr, begin, end) + + return out + def convert_transpose(self, op): """transpose implementation.""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index e6805a9..1478b25 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -225,6 +225,26 @@ def test_forward_split(): _test_split((1, 3, 5, 6), -1, 3, 'float32') ####################################################################### +# slice +# ----- + +def _test_slice(data, begin, size): + """ One iteration of SLICE """ + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + out = array_ops.slice(in_data, begin, size) + compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) + +def test_forward_slice(): + """ SLICE """ + _test_slice(np.arange(4, dtype=np.float32).reshape((4, )), begin=[0], size=[2]) + _test_slice(np.arange(18, dtype=np.int32).reshape((3, 2, 3)), begin=[1, 0, 0], size=[1, 1, 3]) + # tflite 1.13 outputs nonsense values if size[i] == -1 + if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): + _test_slice(np.arange(8, dtype=np.int32).reshape((2, 4)), begin=[0, 1], size=[-1, -1]) + _test_slice(np.arange(5, dtype=np.int32).reshape((5, )), begin=[4], size=[-1]) + +####################################################################### # transpose # --------- @@ -1408,6 +1428,7 @@ if __name__ == '__main__': test_forward_reshape() test_all_resize() test_forward_squeeze() + test_forward_slice() # NN test_forward_convolution() -- 2.7.4