'RELU':self.convert_relu,
'SPLIT': self.convert_split,
'TRANSPOSE': self.convert_transpose,
+ 'CAST': self.convert_cast,
'TILE': self.convert_tile,
'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd,
'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd
if tensor_wrapper.tensor.Type() == TensorType.INT32:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int32).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
+ if tensor_wrapper.tensor.Type() == TensorType.INT64:
+ return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int64).reshape(
+ tensor_wrapper.tensor.ShapeAsNumpy())
raise NotImplementedError("Tensor type {} is currently not supported"
.format(str(tensor_wrapper.tensor.Type())))
return "float32"
if tensor_type == TensorType.INT32:
return "int32"
+ if tensor_type == TensorType.INT64:
+ return "int64"
raise NotImplementedError("Tensor type {} is currently not supported"
.format(str(tensor_type)))
return out
+ def convert_cast(self, op):
+ """Convert TFLite CAST"""
+ try:
+ from tflite.Operator import Operator
+ from tflite.BuiltinOptions import BuiltinOptions
+ from tflite.CastOptions import CastOptions
+ except ImportError:
+ raise ImportError("The tflite package must be installed")
+
+ assert isinstance(op, Operator)
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 1, "input tensors length should be 1"
+ input_tensor = input_tensors[0]
+ in_expr = self.get_expr(input_tensor.tensor_idx)
+
+ assert op.BuiltinOptionsType() == BuiltinOptions.CastOptions
+ op_options = op.BuiltinOptions()
+ cast_options = CastOptions()
+ cast_options.Init(op_options.Bytes, op_options.Pos)
+ cast_dtype = cast_options.OutDataType()
+
+ out = _op.cast(in_expr, self.get_tensor_type_str(cast_dtype))
+
+ return out
+
def convert_tile(self, op):
"""tile implementation."""
try:
_test_forward_transpose((2, 3, 4, 5), ())
#######################################################################
+# Cast
+# --------
+
+def _test_cast(data, cast_dtype):
+ """ One iteration of CAST """
+ with tf.Graph().as_default():
+ in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+ out = math_ops.cast(in_data, cast_dtype)
+ compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+
+
+def test_forward_cast():
+ """ CAST """
+ _test_cast(np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.int32)
+ _test_cast(np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.uint8)
+ _test_cast(np.arange(6.0, dtype=np.int32).reshape((1, 6)), cast_dtype=tf.int64)
+
+#######################################################################
# tile
# ---------
# Transpose
test_forward_transpose()
+ # Cast
+ test_forward_cast()
+
# Tile
test_forward_tile()