Implementation of tile for TFLite (#3814)
authorNeo Chien <cchung100m@cs.ccu.edu.tw>
Sun, 1 Sep 2019 01:56:39 +0000 (09:56 +0800)
committerJared Roesch <roeschinc@gmail.com>
Sun, 1 Sep 2019 01:56:39 +0000 (18:56 -0700)
python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

index f4c10f2..e0f9775 100644 (file)
@@ -82,7 +82,8 @@ class OperatorConverter(object):
             'PACK': self.convert_pack,
             'LOGISTIC': self.convert_logistic,
             'SPLIT': self.convert_split,
-            'TRANSPOSE': self.convert_transpose
+            'TRANSPOSE': self.convert_transpose,
+            'TILE': self.convert_tile
         }
 
     def check_unsupported_ops(self):
@@ -769,6 +770,28 @@ class OperatorConverter(object):
 
         return out
 
+    def convert_tile(self, op):
+        """tile implementation."""
+        try:
+            from tflite.Operator import Operator
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "input tensors length should be 2"
+        input_tensor = input_tensors[0]
+        input_tensor_idx = input_tensor.tensor_idx
+
+        in_expr = self.get_expr(input_tensor_idx)
+
+        # reps (tuple of int) – The number of times repeating the tensor data.
+        reps = tuple(self.get_tensor_value(input_tensors[1]))
+
+        out = _op.tile(in_expr, reps)
+
+        return out
+
     def convert_pool2d(self, op, pool_type):
         """pool2d implementation."""
         try:
index a78225c..7712260 100644 (file)
@@ -229,6 +229,26 @@ def test_forward_transpose():
     _test_forward_transpose((2, 3, 4, 5), (3, 0, 1, 2))
     _test_forward_transpose((2, 3, 4, 5), ())
 
+#######################################################################
+# tile
+# ---------
+
+
+def _test_forward_tile(in_shape, reps, dtype):
+    data = np.random.uniform(-5, 5, size=in_shape).astype(dtype)
+
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+
+        out = array_ops.tile(in_data, reps)
+
+        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+
+
+def test_forward_tile():
+    _test_forward_tile((2, ), (3, ), "int32")
+    _test_forward_tile((2, 2), (2, 3), "float32")
+
 
 #######################################################################
 # Pooling
@@ -856,6 +876,9 @@ if __name__ == '__main__':
     # Transpose
     test_forward_transpose()
 
+    # Tile
+    test_forward_tile()
+
     # Transforms
     test_forward_concatenation()
     test_forward_pad()