Add parser support for ReLU tflite operator (#4022)
authorIna Dobreva <55383260+inadob@users.noreply.github.com>
Sat, 28 Sep 2019 00:30:11 +0000 (01:30 +0100)
committerYao Wang <kevinthesunwy@gmail.com>
Sat, 28 Sep 2019 00:30:11 +0000 (17:30 -0700)
python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

index eba19a4..01f6c67 100644 (file)
@@ -84,6 +84,7 @@ class OperatorConverter(object):
             'PACK': self.convert_pack,
             'LOGISTIC': self.convert_logistic,
             'TANH':self.convert_tanh,
+            'RELU': self.convert_relu,
             'SPLIT': self.convert_split,
             'TRANSPOSE': self.convert_transpose,
             'TILE': self.convert_tile,
@@ -345,6 +346,23 @@ class OperatorConverter(object):
 
         return out
 
+    def convert_relu(self, op):
+        """Convert TFLite ReLU"""
+        try:
+            from tflite.Operator import Operator
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+
+        input_tensor = input_tensors[0]
+        in_expr = self.get_expr(input_tensor.tensor_idx)
+        out = _op.nn.relu(in_expr)
+
+        return out
+
     def convert_concatenation(self, op):
         """Convert TFLite concatenation"""
         try:
index 5d97ce8..06afa59 100644 (file)
@@ -837,6 +837,21 @@ def test_forward_tanh():
     _test_tanh(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
 
 #######################################################################
+# ReLU
+# --------
+
+def _test_relu(data):
+    """ One iteration of ReLU """
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+        out = nn_ops.relu(in_data)
+        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+
+def test_forward_relu():
+    """ ReLU """
+    _test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
+
+#######################################################################
 # Fully Connected
 # -------
 
@@ -999,6 +1014,7 @@ if __name__ == '__main__':
     test_forward_pooling()
     test_forward_softmax()
     test_forward_tanh()
+    test_forward_relu()
     test_forward_fully_connected()
 
     # Elemwise