raise ImportError("The tflite package must be installed")
input_tensors = self.get_input_tensors(op)
- assert input_tensors, "input tensors should not be empty"
+ assert len(input_tensors) in (1, 2), "input tensors should not be empty"
+
+ output_tensors = self.get_output_tensors(op)
+ assert len(output_tensors) == 1, "There should be only 1 output tensor"
+
input_tensor = input_tensors[0]
input_tensor_idx = input_tensor.tensor_idx
- assert op.BuiltinOptionsType() == BuiltinOptions.ReshapeOptions
- op_options = op.BuiltinOptions()
- reshape_options = ReshapeOptions()
- reshape_options.Init(op_options.Bytes, op_options.Pos)
- target_shape = reshape_options.NewShapeAsNumpy()
+ if len(input_tensors) == 2:
+ shape_tensor = input_tensors[1]
+ if self.has_expr(shape_tensor.tensor_idx):
+ target_shape = self.get_expr(shape_tensor.tensor_idx)
+ else:
+ target_shape = self.get_tensor_value(shape_tensor)
+ # convert to flattened list
+ from itertools import chain
+ try:
+ target_shape = list(chain(*target_shape))
+ except TypeError:
+ target_shape = list(chain(target_shape))
+
+ else:
+ assert op.BuiltinOptionsType() == BuiltinOptions.ReshapeOptions
+ op_options = op.BuiltinOptions()
+ reshape_options = ReshapeOptions()
+ reshape_options.Init(op_options.Bytes, op_options.Pos)
+ target_shape = tuple(reshape_options.NewShapeAsNumpy())
in_expr = self.get_expr(input_tensor_idx)
# If the tensors are quantized, ensure that input/output qnn params are same.
if input_tensor.qnn_params:
- output_tensors = self.get_output_tensors(op)
- assert len(output_tensors) == 1, "There should be only 1 output tensor"
output_tensor = output_tensors[0]
assert self.has_same_qnn_params(input_tensor, output_tensor), \
"TFLite reshape requires input and output scale and zero points to be equal"
- out = _op.reshape(in_expr, newshape=tuple(target_shape))
+
+ out = _op.reshape(in_expr, newshape=target_shape)
return out
def _convert_resize(self, method, op):
# Reshape
# -------
-def _test_reshape(data, out_shape):
+def _test_reshape(data, out_shape, wrap_shape):
""" One iteration of reshape operation with given data and out shape """
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
- out = array_ops.reshape(in_data, out_shape)
- compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+ out_shape = out_shape if not wrap_shape\
+ else np.array(out_shape, dtype=np.int32)
+
+ in_shape = out_shape if not wrap_shape\
+ else array_ops.placeholder(shape=out_shape.shape,\
+ dtype=out_shape.dtype,\
+ name="Newshape")
+
+ out = array_ops.reshape(in_data, in_shape)
+
+ compare_tflite_with_tvm(
+ [data, out_shape] if wrap_shape else [data],\
+ ['Placeholder:0', 'Newshape:0'] if wrap_shape else ['Placeholder:0'],\
+ [in_data, in_shape] if wrap_shape else [in_data],\
+ [out],
+ mode='vm')
def test_forward_reshape():
- _test_reshape(np.arange(6.0, dtype=np.float32), [2, 3])
- _test_reshape(np.arange(6), [-1, 2])
- _test_reshape(np.arange(6), [3, -1])
- _test_reshape(np.arange(6), [-1])
+ for wrap in [True, False]:
+ _test_reshape(np.arange(6.0, dtype=np.float32), [2, 3], wrap)
+ _test_reshape(np.arange(6), [-1, 2], wrap)
+ _test_reshape(np.arange(6), [3, -1], wrap)
+ _test_reshape(np.arange(6), [-1], wrap)
#######################################################################