TFLite: Add fused_activation_function for ADD, SUB, MUL, DIV (#3372)
authorAlexander Pivovarov <pivovaa@amazon.com>
Mon, 17 Jun 2019 19:36:31 +0000 (12:36 -0700)
committerYao Wang <kevinthesunwy@gmail.com>
Mon, 17 Jun 2019 19:36:31 +0000 (12:36 -0700)
python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

index 7b4394e..0ffd07e 100644 (file)
@@ -298,6 +298,12 @@ class OperatorConverter(object):
         """Generic method to Convert TFLite elemwise"""
         try:
             from tflite.Operator import Operator
+            from tflite.AddOptions import AddOptions
+            from tflite.SubOptions import SubOptions
+            from tflite.MulOptions import MulOptions
+            from tflite.DivOptions import DivOptions
+            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.ActivationFunctionType import ActivationFunctionType
         except ImportError:
             raise ImportError("The tflite package must be installed")
 
@@ -320,6 +326,26 @@ class OperatorConverter(object):
             rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor),
                                               dtype=rhs_type_str)
         out = relay_op(lhs_expr, rhs_expr)
+
+        # Options (fused_activation_function)
+        options = None
+        if op.BuiltinOptionsType() == BuiltinOptions.AddOptions:
+            options = AddOptions()
+        elif op.BuiltinOptionsType() == BuiltinOptions.SubOptions:
+            options = SubOptions()
+        elif op.BuiltinOptionsType() == BuiltinOptions.MulOptions:
+            options = MulOptions()
+        elif op.BuiltinOptionsType() == BuiltinOptions.DivOptions:
+            options = DivOptions()
+
+        if options is not None:
+            op_options = op.BuiltinOptions()
+            options.Init(op_options.Bytes, op_options.Pos)
+            fused_activation_fn = options.FusedActivationFunction()
+            # if we have activation fn
+            if fused_activation_fn != ActivationFunctionType.NONE:
+                out = self.convert_fused_activation_function(out, fused_activation_fn)
+
         return out
 
     def convert_add(self, op):
index 3b76fad..795a089 100644 (file)
@@ -21,6 +21,7 @@ TFLite testcases
 This article is a test script to test TFLite operator with Relay.
 """
 from __future__ import print_function
+from functools import partial
 import numpy as np
 import tvm
 from tvm import relay
@@ -146,6 +147,20 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
                 tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
 
 
+def with_fused_activation_function(input_tensor, fn_name):
+    if fn_name is None or fn_name == "NONE":
+        return input_tensor
+    if fn_name == "RELU":
+        return nn_ops.relu(input_tensor)
+    if fn_name == "RELU6":
+        return nn_ops.relu6(input_tensor)
+    if fn_name == "RELU_N1_TO_1":
+        return math_ops.maximum(-1, math_ops.minimum(input_tensor, 1))
+    if fn_name == "TANH":
+        return math_ops.tanh(input_tensor)
+    raise AssertionError("Unknown fused_activation_function {}".format(fn_name))
+
+
 #######################################################################
 # Pooling
 # -------
@@ -313,7 +328,7 @@ def test_forward_concatenation():
 # Element-wise
 # ---
 
-def _test_elemwise(math_op, data):
+def _test_elemwise(math_op, data, fused_activation_function=None):
     """ One iteration of add """
 
     assert len(data) == 2
@@ -323,12 +338,14 @@ def _test_elemwise(math_op, data):
         in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in_0'),
                    array_ops.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in_1')]
         out = math_op(in_data[0], in_data[1])
+        out = with_fused_activation_function(out, fused_activation_function)
         compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out])
 
     # Test with tensor and constant
     with tf.Graph().as_default():
         in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')]
         out = math_op(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype))
+        out = with_fused_activation_function(out, fused_activation_function)
         compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out])
 
 
@@ -336,31 +353,31 @@ def _test_elemwise(math_op, data):
 # Add
 # ---
 
-def _test_add(data):
+def _test_add(data, fused_activation_function=None):
     """ One iteration of add """
-    return _test_elemwise(math_ops.add, data)
+    return _test_elemwise(math_ops.add, data, fused_activation_function)
 
 #######################################################################
 # Subtract
 # --------
 
-def _test_sub(data):
+def _test_sub(data, fused_activation_function=None):
     """ One iteration of subtract """
-    return _test_elemwise(math_ops.subtract, data)
+    return _test_elemwise(math_ops.subtract, data, fused_activation_function)
 #######################################################################
 # Mul
 # ---
-def _test_mul(data):
+def _test_mul(data, fused_activation_function=None):
     """ One iteration of mul """
-    return _test_elemwise(math_ops.multiply, data)
+    return _test_elemwise(math_ops.multiply, data, fused_activation_function)
 
 #######################################################################
 # Divide
 # ------
 
-def _test_div(data):
+def _test_div(data, fused_activation_function=None):
     """ One iteration of divide """
-    return _test_elemwise(math_ops.divide, data)
+    return _test_elemwise(math_ops.divide, data, fused_activation_function)
 #######################################################################
 # Power
 # -----
@@ -386,17 +403,25 @@ def _test_minimum(data):
 def _test_forward_elemwise(testop):
     """ Elewise"""
     testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),
-               np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3))])
+               np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3))])
     testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)),
-               np.arange(6.0, dtype=np.float32).reshape((2, 1, 3))])
+               np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3))])
     testop([np.arange(3.0, dtype=np.float32).reshape((1, 3)),
-               np.arange(3.0, dtype=np.float32).reshape((1, 3))])
+               np.arange(1.0, 4.0, dtype=np.float32).reshape((1, 3))])
 
 def test_all_elemwise():
     _test_forward_elemwise(_test_add)
+    _test_forward_elemwise(partial(_test_add, fused_activation_function="RELU"))
+    _test_forward_elemwise(partial(_test_add, fused_activation_function="RELU6"))
     _test_forward_elemwise(_test_sub)
+    _test_forward_elemwise(partial(_test_sub, fused_activation_function="RELU"))
+    _test_forward_elemwise(partial(_test_sub, fused_activation_function="RELU6"))
     _test_forward_elemwise(_test_mul)
+    _test_forward_elemwise(partial(_test_mul, fused_activation_function="RELU"))
+    _test_forward_elemwise(partial(_test_mul, fused_activation_function="RELU6"))
     _test_forward_elemwise(_test_div)
+    _test_forward_elemwise(partial(_test_div, fused_activation_function="RELU"))
+    _test_forward_elemwise(partial(_test_div, fused_activation_function="RELU6"))
     _test_forward_elemwise(_test_pow)
     _test_forward_elemwise(_test_maximum)
     _test_forward_elemwise(_test_minimum)