[Relay][Frontend][TFLite] transpose implementation for tflite.py (#3705)
authorNeo Chien <cchung100m@cs.ccu.edu.tw>
Mon, 19 Aug 2019 17:20:38 +0000 (01:20 +0800)
committerHaichen Shen <shenhaichen@gmail.com>
Mon, 19 Aug 2019 17:20:38 +0000 (10:20 -0700)
* transpose implementation for tflite.py

* add TRANSPOSE to convert_map

* Fix Unexpected keyword argument 'axis' in function call

* add test for transpose oprator

* Add the parameter 'axes' handling

* add test for transpose oprator

* solve conflict within CONTRIBUTORS.md

* Improve the if condition for empty tuple

* Add one unit test to cover empty tuple

* solve conflict within CONTRIBUTORS.md

CONTRIBUTORS.md
python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

index 6211151..03f5b1b 100644 (file)
@@ -113,3 +113,4 @@ We do encourage everyone to work anything they are interested in.
 - [Cody Hao Yu](https://github.com/comaniac)
 - [Chris Nuernberger](https://github.com/cnuernber)
 - [Shoubhik Bhattacharya](https://github.com/shoubhik)
+- [Neo Chien](https://github.com/cchung100m)
index eed3d81..162cc36 100644 (file)
@@ -81,7 +81,8 @@ class OperatorConverter(object):
             'PAD': self.convert_pad,
             'PACK': self.convert_pack,
             'LOGISTIC': self.convert_logistic,
-            'SPLIT': self.convert_split
+            'SPLIT': self.convert_split,
+            'TRANSPOSE': self.convert_transpose
         }
 
     def check_unsupported_ops(self):
@@ -743,6 +744,31 @@ class OperatorConverter(object):
 
         return out
 
+    def convert_transpose(self, op):
+        """transpose implementation."""
+        try:
+            from tflite.Operator import Operator
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "input tensors length should be 2"
+        input_tensor = input_tensors[0]
+        input_tensor_idx = input_tensor.tensor_idx
+
+        in_expr = self.get_expr(input_tensor_idx)
+
+        # axis
+        in_axis = tuple(self.get_tensor_value(input_tensors[1]))
+
+        if not in_axis:
+            out = _op.transpose(in_expr)
+        else:
+            out = _op.transpose(in_expr, in_axis)
+
+        return out
+
     def convert_pool2d(self, op, pool_type):
         """pool2d implementation."""
         try:
index 2c356d8..a78225c 100644 (file)
@@ -202,6 +202,35 @@ def test_forward_split():
     _test_split((1, 3, 5, 6), -1, 3, 'float32')
 
 #######################################################################
+# transpose
+# ---------
+
+
+def _test_forward_transpose(ishape, axes=()):
+    data = np.random.uniform(size=ishape).astype(np.float32)
+
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+
+        if not axes:
+            out = array_ops.transpose(in_data)
+        else:
+            out = array_ops.transpose(in_data, axes)
+
+        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+
+
+def test_forward_transpose():
+    _test_forward_transpose((2, 2))
+    _test_forward_transpose((2, 3, 4))
+    _test_forward_transpose((7, 8, 8, 10))
+    _test_forward_transpose((2, 3, 4), (1, 2, 0))
+    _test_forward_transpose((2, 3, 4), (0, 1, 2))
+    _test_forward_transpose((2, 3, 4, 5), (3, 0, 1, 2))
+    _test_forward_transpose((2, 3, 4, 5), ())
+
+
+#######################################################################
 # Pooling
 # -------
 def _test_pooling_iteration(input_shape, **kwargs):
@@ -823,6 +852,10 @@ def test_forward_ssd_mobilenet_v1():
 if __name__ == '__main__':
     # Split
     test_forward_split()
+
+    # Transpose
+    test_forward_transpose()
+
     # Transforms
     test_forward_concatenation()
     test_forward_pad()