'PAD': self.convert_pad,
'PACK': self.convert_pack,
'LOGISTIC': self.convert_logistic,
- 'SPLIT': self.convert_split
+ 'SPLIT': self.convert_split,
+ 'TRANSPOSE': self.convert_transpose
}
def check_unsupported_ops(self):
return out
+ def convert_transpose(self, op):
+ """transpose implementation."""
+ try:
+ from tflite.Operator import Operator
+ except ImportError:
+ raise ImportError("The tflite package must be installed")
+
+ assert isinstance(op, Operator)
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 2, "input tensors length should be 2"
+ input_tensor = input_tensors[0]
+ input_tensor_idx = input_tensor.tensor_idx
+
+ in_expr = self.get_expr(input_tensor_idx)
+
+ # axis
+ in_axis = tuple(self.get_tensor_value(input_tensors[1]))
+
+ if not in_axis:
+ out = _op.transpose(in_expr)
+ else:
+ out = _op.transpose(in_expr, in_axis)
+
+ return out
+
def convert_pool2d(self, op, pool_type):
"""pool2d implementation."""
try:
_test_split((1, 3, 5, 6), -1, 3, 'float32')
#######################################################################
+# transpose
+# ---------
+
+
+def _test_forward_transpose(ishape, axes=()):
+ data = np.random.uniform(size=ishape).astype(np.float32)
+
+ with tf.Graph().as_default():
+ in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+
+ if not axes:
+ out = array_ops.transpose(in_data)
+ else:
+ out = array_ops.transpose(in_data, axes)
+
+ compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+
+
+def test_forward_transpose():
+ _test_forward_transpose((2, 2))
+ _test_forward_transpose((2, 3, 4))
+ _test_forward_transpose((7, 8, 8, 10))
+ _test_forward_transpose((2, 3, 4), (1, 2, 0))
+ _test_forward_transpose((2, 3, 4), (0, 1, 2))
+ _test_forward_transpose((2, 3, 4, 5), (3, 0, 1, 2))
+ _test_forward_transpose((2, 3, 4, 5), ())
+
+
+#######################################################################
# Pooling
# -------
def _test_pooling_iteration(input_shape, **kwargs):
if __name__ == '__main__':
# Split
test_forward_split()
+
+ # Transpose
+ test_forward_transpose()
+
# Transforms
test_forward_concatenation()
test_forward_pad()