[Frontend][TFLite] support for FILL and SPLIT_V operators (#5330)
author Mahesh Ambule <15611578+maheshambule@users.noreply.github.com>
Wed, 29 Apr 2020 10:51:57 +0000 (16:21 +0530)
committer GitHub <noreply@github.com>
Wed, 29 Apr 2020 10:51:57 +0000 (18:51 +0800)
* tflite splitv ops

* TFLITE fill and splitv ops

* TFLITE fill and splitv ops

* TFLITE fill and splitv ops

* remove unnecessary operator check

python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

index 2065d60..b9a1657 100644 (file)
@@ -78,6 +78,7 @@ class OperatorConverter(object):
             'ELU': self.convert_elu,
             'EQUAL': self.convert_equal,
             'EXP': self.convert_exp,
+            'FILL': self.convert_fill,
             'FLOOR_DIV': self.convert_floor_div,
             'FLOOR_MOD': self.convert_floor_mod,
             'FLOOR': self.convert_floor,
@@ -123,6 +124,7 @@ class OperatorConverter(object):
             'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
             'SPACE_TO_DEPTH': self.convert_space_to_depth,
             'SPLIT': self.convert_split,
+            'SPLIT_V': self.convert_split_v,
             'SQRT': self.convert_sqrt,
             'SQUARE': self.convert_square,
             'SQUARED_DIFFERENCE': self.convert_squared_difference,
@@ -1212,6 +1214,21 @@ class OperatorConverter(object):
 
         return out
 
+    def convert_fill(self, op):
+        """Convert TFLite FILL"""
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "input tensors length should be 2"
+
+        if self.has_expr(input_tensors[0].tensor_idx):
+            raise tvm.error.OpNotImplemented("For dims parameter of Fill operator,"
+                                             " only constant values are supported.")
+
+        in_dims = list(self.get_tensor_value(input_tensors[0]))
+        in_value_expr = self.get_expr(input_tensors[1].tensor_idx)
+        out = _op.full(in_value_expr, in_dims)
+
+        return out
+
     def _convert_reduce(self, relay_op, op):
         """Generic method to Convert TFLite MEAN operators"""
         try:
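
For readers of the patch, a minimal sketch (not part of the change) of the Relay expression that convert_fill produces, assuming a scalar float32 fill value and the constant dims (1, 2, 2, 4):

    from tvm import relay

    # The fill value arrives as a scalar Relay expression; dims must be constant.
    value = relay.var("value", shape=(), dtype="float32")
    out = relay.full(value, shape=(1, 2, 2, 4), dtype="float32")
    func = relay.Function([value], out)
    print(func)

Because a dynamic dims input has no constant value to read at conversion time, the converter rejects it with OpNotImplemented, as the hunk above shows.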
@@ -1617,6 +1634,35 @@ class OperatorConverter(object):
 
         return out
 
+    def convert_split_v(self, op):
+        """SPLIT_V implementation."""
+        input_tensors = self.get_input_tensors(op)
+
+        assert len(input_tensors) == 3, "input tensors length should be 3"
+
+        input_tensor = input_tensors[0]
+        input_tensor_idx = input_tensor.tensor_idx
+        in_expr = self.get_expr(input_tensor_idx)
+
+        if self.has_expr(input_tensors[1].tensor_idx):
+            raise tvm.error.OpNotImplemented("For size_splits parameter of SPLIT_V operator, "
+                                             "only constant values are supported.")
+        size_splits = list(self.get_tensor_value(input_tensors[1]))
+        size_splits = tuple(np.cumsum(size_splits)[:-1])
+
+        axis_tensor = input_tensors[2]
+        split_axis = self.get_tensor_value(axis_tensor)
+
+        out = _op.split(in_expr, size_splits, axis=int(split_axis))
+        # Relay does not like a TupleWrapper of 1 element, further this
+        # only shows up with tf1.13 if we use a split with num_splits==1.
+        # In tf 1.14 this doesn't appear as it is automatically a reshape
+        # operation.
+        if isinstance(out, _expr.TupleWrapper) and out.size == 1:
+            out = out[0]
+
+        return out
+
     def convert_slice(self, op):
         """Convert TFLite SLICE"""
         input_tensors = self.get_input_tensors(op)
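
A small illustration (not part of the patch) of the size_splits handling above: the TFLite section sizes are turned into the cumulative split indices that relay.split expects, with the trailing boundary dropped.

    import numpy as np

    size_splits = [1, 4, 1]                       # SPLIT_V section sizes
    indices = tuple(np.cumsum(size_splits)[:-1])  # -> (1, 5)
    # relay.split(in_expr, (1, 5), axis=-2) then yields sections of
    # length 1, 4 and 1 along the chosen axis.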
index eb65d82..7ff4c31 100644 (file)
@@ -216,15 +216,19 @@ def with_fused_activation_function(input_tensor, fn_name):
         return math_ops.tanh(input_tensor)
     raise AssertionError("Unknown fused_activation_function {}".format(fn_name))
 
-def _test_split(in_shape, axis, num_Splits, dtype):
-    '''internal split tester taking as parameters in_shape, number of tensors to split into
-       and dtype (data type)'''
+
+def _test_split(in_shape, axis, num_splits, dtype):
+    """internal split tester taking as parameters in_shape, number of tensors to split into
+       and dtype (data type)"""
+
     np_data = np.random.uniform(-5, 5, size=in_shape).astype(dtype)
     with tf.Graph().as_default():
-        in_data = array_ops.placeholder(shape=in_shape, dtype=dtype)
-        out = array_ops.split(in_data, num_Splits, axis=axis)
-        out_names = ['out_' + str(n) + ':0' for n in range(num_Splits)]
-        compare_tflite_with_tvm([np_data], ['Placeholder:0'],  [in_data], out,
+        in_data = array_ops.placeholder(shape=in_shape, dtype=dtype, name="in_data")
+        out = array_ops.split(in_data, num_splits, axis=axis)
+        num_splits = len(num_splits) if isinstance(num_splits, list) \
+            else num_splits
+        out_names = ['out_' + str(n) + ':0' for n in range(num_splits)]
+        compare_tflite_with_tvm([np_data], ['in_data'],  [in_data], out,
                                 out_names=out_names)
 
 def test_forward_split():
@@ -252,6 +256,9 @@ def test_forward_split():
     _test_split((1, 6, 3, 5), -3, 3, 'float32')
     _test_split((1, 3, 6, 5), -2, 3, 'float32')
     _test_split((1, 3, 5, 6), -1, 3, 'float32')
+    # size_splits split
+    _test_split((6,), 0, [1, 2, 3], 'float32')
+    _test_split((3, 6, 4), -2, [1, 4, 1], 'float32')
 
 #######################################################################
 # slice
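
For context, a minimal standalone sketch (TF 1.x graph mode, mirroring the harness above; not part of the patch) of the two flavours _test_split now covers: an integer count lowers to the TFLite SPLIT op, while a list of section sizes lowers to SPLIT_V.

    import tensorflow as tf
    from tensorflow.python.ops import array_ops

    with tf.Graph().as_default():
        in_data = array_ops.placeholder(shape=(3, 6, 4), dtype='float32', name="in_data")
        even_parts = array_ops.split(in_data, 3, axis=1)           # -> SPLIT
        sized_parts = array_ops.split(in_data, [1, 4, 1], axis=1)  # -> SPLIT_V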
@@ -1210,6 +1217,39 @@ def test_forward_zeros_like():
     """ ZEROS LIKE """
     _test_zeros_like(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
 
+
+#######################################################################
+# Fill
+# ----
+
+def _test_fill(dims, value_data, value_dtype):
+    """ Use the fill op to create a tensor of value_data with constant dims."""
+
+    value_data = np.array(value_data, dtype=value_dtype)
+    # TF 1.13 TFLite convert method does not accept empty shapes
+    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+        with tf.Graph().as_default():
+            value = array_ops.placeholder(dtype=value_dtype, name="value", shape=[])
+            out = tf.fill(dims,  value)
+            compare_tflite_with_tvm([value_data], ["value"], [value], [out])
+
+    with tf.Graph().as_default():
+        input1 = array_ops.placeholder(dtype=value_dtype, name="input1", shape=dims)
+        # Fill op gets converted to static tensor during conversion
+        out = tf.fill(dims,  value_data)
+        out1 = tf.add(out, input1)
+        input1_data = np.random.uniform(0, 5, size=dims).astype(value_dtype)
+        compare_tflite_with_tvm([input1_data], ["input1"], [input1], [out1])
+
+
+def test_forward_fill():
+    """ Test FILL op """
+
+    _test_fill((1, 2, 2, 4), 5, "int32")
+    _test_fill((1, 2, 2, 4), 5, "float32")
+    _test_fill((5, ), 5, "int32")
+
+
 #######################################################################
 # Reduce
 # ------
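
A side note on the second graph in _test_fill above: with both dims and the value constant, the TFLite converter folds tf.fill into a static tensor, so the test adds the result to a placeholder to keep a live op in the converted graph. A minimal sketch (assumption, not part of the patch) of that graph shape:

    import tensorflow as tf
    from tensorflow.python.ops import array_ops

    dims = (2, 3)
    with tf.Graph().as_default():
        live_in = array_ops.placeholder(dtype="float32", shape=dims, name="live_in")
        filled = tf.fill(dims, 5.0)      # folded to a constant during conversion
        out = tf.add(filled, live_in)    # keeps the converted graph non-trivial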
@@ -1980,6 +2020,9 @@ if __name__ == '__main__':
     # Zeros Like
     test_forward_zeros_like()
 
+    # Fill
+    test_forward_fill()
+
     # Reduce
     test_all_reduce()