From c48e1cc166bd1c71d4eb2fc45714a830780402c9 Mon Sep 17 00:00:00 2001 From: Ina Dobreva <55383260+inadob@users.noreply.github.com> Date: Tue, 24 Sep 2019 18:18:41 +0100 Subject: [PATCH] add parser support for TANH tflite operator (#3996) --- python/tvm/relay/frontend/tflite.py | 18 ++++++++++++++++++ tests/python/frontend/tflite/test_forward.py | 15 +++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 046bf69..7c27283 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -82,6 +82,7 @@ class OperatorConverter(object): 'PAD': self.convert_pad, 'PACK': self.convert_pack, 'LOGISTIC': self.convert_logistic, + 'TANH':self.convert_tanh, 'SPLIT': self.convert_split, 'TRANSPOSE': self.convert_transpose, 'TILE': self.convert_tile, @@ -326,6 +327,23 @@ class OperatorConverter(object): return out + def convert_tanh(self, op): + """Convert TFLite TANH""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + out = _op.tanh(in_expr) + + return out + def convert_concatenation(self, op): """Convert TFLite concatenation""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index c856908..e501758 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -813,6 +813,20 @@ def test_forward_softmax(): """ Softmax """ _test_softmax(np.arange(6.0, dtype=np.float32).reshape((1, 6))) +####################################################################### +# Tanh +# -------- + +def _test_tanh(data): + """ One iteration of TANH """ + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + out = math_ops.sigmoid(in_data) + compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) + +def test_forward_tanh(): + """ TANH """ + _test_tanh(np.arange(6.0, dtype=np.float32).reshape((1, 6))) ####################################################################### # Fully Connected @@ -976,6 +990,7 @@ if __name__ == '__main__': test_forward_logistic() test_forward_pooling() test_forward_softmax() + test_forward_tanh() test_forward_fully_connected() # Elemwise -- 2.7.4