Retrigger build. (#6304)
authorRishabh Jain <56974688+jainris@users.noreply.github.com>
Fri, 21 Aug 2020 17:32:42 +0000 (23:02 +0530)
committerGitHub <noreply@github.com>
Fri, 21 Aug 2020 17:32:42 +0000 (10:32 -0700)
python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

index b8e961f..200352c 100644 (file)
@@ -136,6 +136,7 @@ class OperatorConverter(object):
             'ROUND': self.convert_round,
             'RSQRT': self.convert_rsqrt,
             'REVERSE_SEQUENCE': self.convert_reverse_sequence,
+            'REVERSE_V2': self.convert_reverse_v2,
             'SELECT': self.convert_select,
             'SHAPE': self.convert_shape,
             'SIN': self.convert_sin,
@@ -2972,6 +2973,22 @@ class OperatorConverter(object):
 
         return out
 
+    def convert_reverse_v2(self, op):
+        """Convert TFLite REVERSE_V2"""
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "input tensor's length should be 2"
+
+        input_expr = self.get_expr(input_tensors[0].tensor_idx)
+
+        # Getting axis value
+        axis = self.get_tensor_value(input_tensors[1])
+        if isinstance(axis, np.ndarray):
+            assert len(axis) == 1, "TFLite does not support multi-axis yet"
+            axis = int(axis)
+
+        out = _op.reverse(input_expr, axis)
+        return out
+
 
     def get_expr(self, input_tensor_idx):
         return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))
index 099e719..4d7fc06 100644 (file)
@@ -2627,6 +2627,32 @@ def test_forward_fully_connected():
 
 
 #######################################################################
+# REVERSE_V2
+# ----------
+
+def _test_reverse_v2(input_shape, axis, dtype):
+    """ One iteration of REVERSE_V2 """
+    with tf.Graph().as_default():
+        input = np.random.randint(0, 100, size=input_shape).astype(dtype)
+        in_input = tf.placeholder(dtype=input.dtype, shape=input.shape, name="input")
+        in_axis = ops.convert_to_tensor(axis, dtype=axis.dtype)
+
+        out = array_ops.reverse(in_input, in_axis)
+
+        compare_tflite_with_tvm(
+            [input],
+            ["input"],
+            [in_input],
+            [out])
+
+def test_forward_reverse_v2():
+    """ REVERSE_V2 """
+    for dtype in ['float32', 'int32']:
+        _test_reverse_v2((5), np.array([0], dtype='int32'), dtype)
+        _test_reverse_v2((5, 6, 4, 2), np.array([2], dtype='int32'), dtype)
+
+
+#######################################################################
 # Custom Operators
 # ----------------
 
@@ -3104,6 +3130,7 @@ if __name__ == '__main__':
     test_forward_quantize_dequantize()
     test_forward_arg_min_max()
     test_forward_expand_dims()
+    test_forward_reverse_v2()
 
     # NN
     test_forward_convolution()