Add parser support for CAST tflite operator (#4096)
authorIna Dobreva <55383260+inadob@users.noreply.github.com>
Sun, 13 Oct 2019 05:06:50 +0000 (06:06 +0100)
committerTianqi Chen <tqchen@users.noreply.github.com>
Sun, 13 Oct 2019 05:06:50 +0000 (22:06 -0700)
This implementation provides cast to limited number of dtypes
that tflite currently supports for placeholder op. Add INT64 in the
possible dtypes as it appears to be supported accrording to tlfite schema.

python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

index a519c6f..79ad917 100644 (file)
@@ -88,6 +88,7 @@ class OperatorConverter(object):
             'RELU':self.convert_relu,
             'SPLIT': self.convert_split,
             'TRANSPOSE': self.convert_transpose,
+            'CAST': self.convert_cast,
             'TILE': self.convert_tile,
             'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd,
             'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd
@@ -181,6 +182,9 @@ class OperatorConverter(object):
         if tensor_wrapper.tensor.Type() == TensorType.INT32:
             return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int32).reshape(
                 tensor_wrapper.tensor.ShapeAsNumpy())
+        if tensor_wrapper.tensor.Type() == TensorType.INT64:
+            return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int64).reshape(
+                tensor_wrapper.tensor.ShapeAsNumpy())
         raise NotImplementedError("Tensor type {} is currently not supported"
                                   .format(str(tensor_wrapper.tensor.Type())))
 
@@ -197,6 +201,8 @@ class OperatorConverter(object):
             return "float32"
         if tensor_type == TensorType.INT32:
             return "int32"
+        if tensor_type == TensorType.INT64:
+            return "int64"
         raise NotImplementedError("Tensor type {} is currently not supported"
                                   .format(str(tensor_type)))
 
@@ -840,6 +846,31 @@ class OperatorConverter(object):
 
         return out
 
+    def convert_cast(self, op):
+        """Convert TFLite CAST"""
+        try:
+            from tflite.Operator import Operator
+            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.CastOptions import CastOptions
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+        input_tensor = input_tensors[0]
+        in_expr = self.get_expr(input_tensor.tensor_idx)
+
+        assert op.BuiltinOptionsType() == BuiltinOptions.CastOptions
+        op_options = op.BuiltinOptions()
+        cast_options = CastOptions()
+        cast_options.Init(op_options.Bytes, op_options.Pos)
+        cast_dtype = cast_options.OutDataType()
+
+        out = _op.cast(in_expr, self.get_tensor_type_str(cast_dtype))
+
+        return out
+
     def convert_tile(self, op):
         """tile implementation."""
         try:
index 670e85b..e4013d3 100644 (file)
@@ -231,6 +231,24 @@ def test_forward_transpose():
     _test_forward_transpose((2, 3, 4, 5), ())
 
 #######################################################################
+# Cast
+# --------
+
+def _test_cast(data, cast_dtype):
+    """ One iteration of CAST """
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+        out = math_ops.cast(in_data, cast_dtype)
+        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+
+
+def test_forward_cast():
+    """ CAST """
+    _test_cast(np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.int32)
+    _test_cast(np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.uint8)
+    _test_cast(np.arange(6.0, dtype=np.int32).reshape((1, 6)), cast_dtype=tf.int64)
+
+#######################################################################
 # tile
 # ---------
 
@@ -1013,6 +1031,9 @@ if __name__ == '__main__':
     # Transpose
     test_forward_transpose()
 
+    # Cast
+    test_forward_cast()
+
     # Tile
     test_forward_tile()