Reshape with dynamic shape arg (#6208)
authorDmitriy Smirnov <dmitriy.smirnov@arm.com>
Fri, 7 Aug 2020 03:08:35 +0000 (04:08 +0100)
committerGitHub <noreply@github.com>
Fri, 7 Aug 2020 03:08:35 +0000 (11:08 +0800)
Reshape operation updated to take shape from second operand.
In case if shape is provided using second operand it
can be a tensor now.

python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

index 6e032b1..fe28741 100644 (file)
@@ -459,26 +459,43 @@ class OperatorConverter(object):
             raise ImportError("The tflite package must be installed")
 
         input_tensors = self.get_input_tensors(op)
-        assert input_tensors, "input tensors should not be empty"
+        assert len(input_tensors) in (1, 2), "input tensors should not be empty"
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "There should be only 1 output tensor"
+
         input_tensor = input_tensors[0]
         input_tensor_idx = input_tensor.tensor_idx
 
-        assert op.BuiltinOptionsType() == BuiltinOptions.ReshapeOptions
-        op_options = op.BuiltinOptions()
-        reshape_options = ReshapeOptions()
-        reshape_options.Init(op_options.Bytes, op_options.Pos)
-        target_shape = reshape_options.NewShapeAsNumpy()
+        if len(input_tensors) == 2:
+            shape_tensor = input_tensors[1]
+            if self.has_expr(shape_tensor.tensor_idx):
+                target_shape = self.get_expr(shape_tensor.tensor_idx)
+            else:
+                target_shape = self.get_tensor_value(shape_tensor)
+                # convert to flattened list
+                from itertools import chain
+                try:
+                    target_shape = list(chain(*target_shape))
+                except TypeError:
+                    target_shape = list(chain(target_shape))
+
+        else:
+            assert op.BuiltinOptionsType() == BuiltinOptions.ReshapeOptions
+            op_options = op.BuiltinOptions()
+            reshape_options = ReshapeOptions()
+            reshape_options.Init(op_options.Bytes, op_options.Pos)
+            target_shape = tuple(reshape_options.NewShapeAsNumpy())
 
         in_expr = self.get_expr(input_tensor_idx)
 
         # If the tensors are quantized, ensure that input/output qnn params are same.
         if input_tensor.qnn_params:
-            output_tensors = self.get_output_tensors(op)
-            assert len(output_tensors) == 1, "There should be only 1 output tensor"
             output_tensor = output_tensors[0]
             assert self.has_same_qnn_params(input_tensor, output_tensor), \
                     "TFLite reshape requires input and output scale and zero points to be equal"
-        out = _op.reshape(in_expr, newshape=tuple(target_shape))
+
+        out = _op.reshape(in_expr, newshape=target_shape)
         return out
 
     def _convert_resize(self, method, op):
index 603eb11..30a6631 100644 (file)
@@ -984,20 +984,35 @@ def test_forward_transpose_conv():
 # Reshape
 # -------
 
-def _test_reshape(data, out_shape):
+def _test_reshape(data, out_shape, wrap_shape):
     """ One iteration of reshape operation with given data and out shape """
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
-        out = array_ops.reshape(in_data, out_shape)
 
-        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+        out_shape = out_shape if not wrap_shape\
+            else np.array(out_shape, dtype=np.int32)
+
+        in_shape = out_shape if not wrap_shape\
+            else array_ops.placeholder(shape=out_shape.shape,\
+                                        dtype=out_shape.dtype,\
+                                        name="Newshape")
+
+        out = array_ops.reshape(in_data, in_shape)
+
+        compare_tflite_with_tvm(
+            [data, out_shape]               if wrap_shape else [data],\
+            ['Placeholder:0', 'Newshape:0'] if wrap_shape else ['Placeholder:0'],\
+            [in_data, in_shape]             if wrap_shape else [in_data],\
+            [out],
+            mode='vm')
 
 
 def test_forward_reshape():
-    _test_reshape(np.arange(6.0, dtype=np.float32), [2, 3])
-    _test_reshape(np.arange(6), [-1, 2])
-    _test_reshape(np.arange(6), [3, -1])
-    _test_reshape(np.arange(6), [-1])
+    for wrap in [True, False]:
+        _test_reshape(np.arange(6.0, dtype=np.float32), [2, 3], wrap)
+        _test_reshape(np.arange(6), [-1, 2], wrap)
+        _test_reshape(np.arange(6), [3, -1], wrap)
+        _test_reshape(np.arange(6), [-1], wrap)
 
 
 #######################################################################