[TFLite] Support PRelu (#4298)
author    Zhao Wu <wuzhaozju@gmail.com>
Sun, 10 Nov 2019 19:45:10 +0000 (03:45 +0800)
committer Tianqi Chen <tqchen@users.noreply.github.com>
Sun, 10 Nov 2019 19:45:10 +0000 (11:45 -0800)
python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

python/tvm/relay/frontend/tflite.py
index b042af9..8966aa6 100644 (file)
@@ -94,7 +94,8 @@ class OperatorConverter(object):
             'CAST': self.convert_cast,
             'TILE': self.convert_tile,
             'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd,
-            'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd
+            'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
+            'PRELU': self.convert_prelu,
         }
 
     def check_unsupported_ops(self):
@@ -1325,6 +1326,29 @@ class OperatorConverter(object):
 
         return reshaped_permuted_reshaped_padded
 
+    def convert_prelu(self, op):
+        """Convert TFLite PReLU"""
+        try:
+            from tflite.Operator import Operator
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "PRELU expects two input tensors (input and alpha)"
+
+        input_tensor = input_tensors[0]
+        alpha_tensor = input_tensors[1]
+        alpha_tensor_type = alpha_tensor.tensor.Type()
+        alpha_tensor_type_str = self.get_tensor_type_str(alpha_tensor_type)
+        alpha_expr = self.exp_tab.new_const(self.get_tensor_value(alpha_tensor),
+                                            dtype=alpha_tensor_type_str)
+        in_expr = self.get_expr(input_tensor.tensor_idx)
+        # TFLite tensors use NHWC layout, so the per-channel slope applies on axis 3
+        out = _op.nn.prelu(in_expr, alpha_expr, axis=3)
+
+        return out
+
     def get_expr(self, input_tensor_idx):
         return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))
 
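For reference, a minimal sketch (not part of the patch; everything besides relay.nn.prelu is illustrative) of the Relay expression convert_prelu builds for the NHWC input used in the test below, with the per-channel alpha bound as a constant:

    import numpy as np
    from tvm import relay

    # NHWC input; alpha carries one slope per channel and is applied on axis 3.
    data = relay.var("data", shape=(1, 32, 32, 3), dtype="float32")
    alpha = relay.const(np.full((3,), 0.2, dtype="float32"))
    out = relay.nn.prelu(data, alpha, axis=3)  # axis=3: channel axis in NHWC
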
tests/python/frontend/tflite/test_forward.py
index de19fe3..c2c3cb5 100644 (file)
@@ -894,6 +894,19 @@ def test_forward_relu():
     """ ReLU """
     _test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
 
+def _test_prelu(data):
+    """ One iteration of PReLU """
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+        alpha = np.full((data.shape[-1],), 0.2, dtype=data.dtype)
+        # This specific pattern will be replaced into PRelu by tflite
+        out = nn_ops.relu(in_data) + (-alpha * nn_ops.relu(-in_data))
+        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+
+def test_forward_prelu():
+    """ PReLU """
+    _test_prelu(np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32"))
+
 #######################################################################
 # Fully Connected
 # -------
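
The rewrite in _test_prelu above relies on the identity relu(x) - alpha * relu(-x) == prelu(x), which is what lets the TFLite converter collapse the pattern into a single PRELU op. A quick NumPy check of that identity (illustrative only, not part of the patch):

    import numpy as np

    x = np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32")
    alpha = np.full((3,), 0.2, dtype="float32")  # per-channel slope

    fused = np.maximum(x, 0) + (-alpha * np.maximum(-x, 0))  # pattern in the test
    prelu = np.where(x > 0, x, alpha * x)                     # PReLU definition
    np.testing.assert_allclose(fused, prelu, rtol=1e-6)
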
@@ -1121,6 +1134,7 @@ if __name__ == '__main__':
     test_forward_softmax()
     test_forward_tanh()
     test_forward_relu()
+    test_forward_prelu()
     test_forward_fully_connected()
 
     # Elemwise
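
Since the test file keeps the __main__ driver shown above, the new coverage can be exercised directly with python tests/python/frontend/tflite/test_forward.py, assuming TensorFlow and the tflite Python package are installed.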