[Frontend][TFLite] ADD_N operator (#5474)
authorMahesh Ambule <15611578+maheshambule@users.noreply.github.com>
Thu, 7 May 2020 20:09:18 +0000 (01:39 +0530)
committerGitHub <noreply@github.com>
Thu, 7 May 2020 20:09:18 +0000 (13:09 -0700)
python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

index bb456f1..ab0eabc 100644 (file)
@@ -64,6 +64,7 @@ class OperatorConverter(object):
         self.convert_map = {
             'ABS': self.convert_abs,
             'ADD': self.convert_add,
+            'ADD_N': self.convert_add_n,
             'AVERAGE_POOL_2D': self.convert_average_pool2d,
             'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd,
             'CAST': self.convert_cast,
@@ -800,28 +801,9 @@ class OperatorConverter(object):
         assert len(input_tensors) == 2, "input tensors length should be 2"
 
         lhs_tensor = input_tensors[0]
-        if self.has_expr(lhs_tensor.tensor_idx):
-            # In most cases, we can assume that TOCO fuses elemwise operators
-            # with constants - it means both will be tensors.
-            lhs_expr = self.get_expr(lhs_tensor.tensor_idx)
-        else:
-            # However, in some corner cases, the elemwise operator is not fused,
-            # we can receive as constant.
-            lhs_type_str = self.get_tensor_type_str(lhs_tensor.tensor.Type())
-            lhs_expr = self.exp_tab.new_const(self.get_tensor_value(lhs_tensor),
-                                              dtype=lhs_type_str)
-
         rhs_tensor = input_tensors[1]
-        if self.has_expr(rhs_tensor.tensor_idx):
-            # In most cases, we can assume that TOCO fuses elemwise operators
-            # with constants - it means both will be tensors.
-            rhs_expr = self.get_expr(rhs_tensor.tensor_idx)
-        else:
-            # However, in some corner cases, the elemwise operator is not fused,
-            # we can receive as constant.
-            rhs_type_str = self.get_tensor_type_str(rhs_tensor.tensor.Type())
-            rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor),
-                                              dtype=rhs_type_str)
+        lhs_expr = self.get_tensor_expr(lhs_tensor)
+        rhs_expr = self.get_tensor_expr(rhs_tensor)
 
         output_tensors = self.get_output_tensors(op)
         assert len(output_tensors) == 1, "output tensors length should be 1"
@@ -873,6 +855,20 @@ class OperatorConverter(object):
             return self._convert_elemwise(_qnn.op.add, op)
         return self._convert_elemwise(_op.add, op)
 
+    def convert_add_n(self, op):
+        """Convert TFLite ADD_N"""
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "output tensors length should be 1"
+
+        input_tensors = self.get_input_tensors(op)
+        assert not input_tensors[0].qnn_params, "TFLite does not support quantized ADD_N."
+        lhs_expr = self.get_tensor_expr(input_tensors[0])
+        for rhs_tensor in input_tensors[1:]:
+            assert not rhs_tensor.qnn_params, "TFLite does not support quantized ADD_N"
+            rhs_expr = self.get_tensor_expr(rhs_tensor)
+            lhs_expr = _op.add(lhs_expr, rhs_expr)
+        return lhs_expr
+
     def convert_sub(self, op):
         """Convert TFLite SUB"""
         # Check if the input tensor is quantized, call QNN op
index 20a077f..957f622 100644 (file)
@@ -1199,6 +1199,43 @@ def test_all_elemwise():
         _test_forward_elemwise(_test_floor_divide)
         _test_forward_elemwise(_test_floor_mod)
 
+
+#######################################################################
+# AddN
+# ----
+
+
+def _test_forward_add_n(inputs):
+    tf.reset_default_graph()
+    with tf.Graph().as_default():
+        temp = []
+        for each in inputs:
+            temp.append(tf.placeholder(shape=each.shape, dtype=each.dtype))
+        output = tf.add_n(temp)
+        compare_tflite_with_tvm([each for each in inputs], [each.name for each in temp],
+                                [each for each in temp], [output])
+
+
+def test_forward_add_n():
+    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+        x = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32)
+        y = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32)
+        z = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32)
+        m, n, o = x.astype(np.float32), y.astype(np.float32), z.astype(np.float32)
+        in0 = x
+        in1 = [x, y]
+        in2 = (x, y, z)
+        in3 = m
+        in4 = [m, n]
+        in5 = (m, n, o)
+        _test_forward_add_n(in0)
+        _test_forward_add_n(in1)
+        _test_forward_add_n(in2)
+        _test_forward_add_n(in3)
+        _test_forward_add_n(in4)
+        _test_forward_add_n(in5)
+
+
 #######################################################################
 # Logical operators
 # -----------------
@@ -2005,6 +2042,7 @@ def test_forward_mediapipe_hand_landmark():
         tvm.testing.assert_allclose(np.squeeze(tvm_output[i]), np.squeeze(tflite_output[i]),
                                     rtol=1e-5, atol=1e-5)
 
+
 #######################################################################
 # Main
 # ----
@@ -2062,6 +2100,7 @@ if __name__ == '__main__':
 
     # Elemwise
     test_all_elemwise()
+    test_forward_add_n()
 
     # Unary elemwise
     test_all_unary_elemwise()