[TFLite] Convert TFLite NCHW to NHWC (#3141)
authorZhao Wu <wuzhaozju@gmail.com>
Wed, 22 May 2019 02:44:47 +0000 (10:44 +0800)
committerSiva <sivar.b@huawei.com>
Wed, 22 May 2019 02:44:47 +0000 (08:14 +0530)
* Convert TFLite NCHW to NHWC

* Minor comment fix

python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py
tutorials/frontend/from_tflite.py

index ff62d89..bfd63bb 100644 (file)
@@ -209,44 +209,10 @@ class OperatorConverter(object):
         reshape_options = ReshapeOptions()
         reshape_options.Init(op_options.Bytes, op_options.Pos)
         target_shape = reshape_options.NewShapeAsNumpy()
-        input_shape_length = len(input_tensor.tensor.ShapeAsNumpy())
 
         in_expr = self.get_expr(input_tensor_idx)
-
-        if input_shape_length in (1, 2):
-            # The rule is channel first (after N but before H, W).
-            # length of 1 means N*H*W*C, do nothing.
-            # length of 2 means N*H*W, C, do nothing.
-            pass
-        elif input_shape_length == 3:
-            # convert N C H*W to N H*W C
-            in_expr = _op.transpose(in_expr, axes=(0, 2, 1))
-        elif input_shape_length == 4:
-            # convert input to N H W C, then reshape to target shape,
-            # finally convert back if necessary
-            in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1))
-        else:
-            msg = 'Input shape length {} for operator Reshape is not valid.'
-            raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length))
-
         out = _op.reshape(in_expr, newshape=tuple(target_shape))
 
-        # The rule is channel first.
-        # 1: N*H*W*C
-        # 2: N*H*W, C
-        # 3: N H W C, reshape to N H*W C, transpose to N C H*W
-        # 4: N H W C, transpose to N C H W
-        # add more if we need target shapes in future
-        if len(target_shape) == 1 or len(target_shape) == 2:
-            pass
-        elif len(target_shape) == 3:
-            out = _op.transpose(out, axes=(0, 2, 1))
-        elif len(target_shape) == 4:
-            out = _op.transpose(out, axes=(0, 3, 1, 2))
-        else:
-            raise tvm.error.OpAttributeInvalid(
-                'Length of target shape must be between 1 and 5 for operator Reshape.')
-
         return out
 
     def convert_softmax(self, op):
@@ -269,7 +235,7 @@ class OperatorConverter(object):
         return out
 
     def convert_concatenation(self, op):
-        """ convert TFLite concatenation"""
+        """Convert TFLite concatenation"""
         try:
             from tflite.Operator import Operator
             from tflite.ConcatenationOptions import ConcatenationOptions
@@ -292,15 +258,6 @@ class OperatorConverter(object):
         concatenation_options.Init(op_options.Bytes, op_options.Pos)
         concatenation_axis = concatenation_options.Axis()
         fused_activation_fn = concatenation_options.FusedActivationFunction()
-        input_shape_length = len(input_tensors[0].tensor.ShapeAsNumpy())
-
-        # TFLite is N H W C, our layout is N C H W
-        if input_shape_length <= 4:
-            axis_convert_map = [0] + list(range(2, input_shape_length)) + [1]
-            concatenation_axis = axis_convert_map[concatenation_axis]
-        else:
-            raise NotImplementedError("Not support input shape length {} of concatenatio : "
-                                      .format(str(input_shape_length)))
 
         # with axis in N H W C
         out = _op.concatenate(in_exprs, axis=concatenation_axis)
@@ -336,20 +293,6 @@ class OperatorConverter(object):
             rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor),
                                               dtype=rhs_type_str)
 
-            # In this case, we have to be careful about formatting.
-            input_shape_length = len(rhs_tensor.tensor.ShapeAsNumpy())
-            if input_shape_length in (1, 2):
-                pass
-            elif input_shape_length == 3:
-                # N H*W C to N C H*W
-                rhs_expr = _op.transpose(rhs_expr, axes=(0, 2, 1))
-            elif input_shape_length == 4:
-                # N H W C to N C H W
-                rhs_expr = _op.transpose(rhs_expr, axes=(0, 3, 1, 2))
-            else:
-                msg = 'Input shape length {} for operator ADD is not valid.'
-                raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length))
-
         out = _op.add(lhs_expr, rhs_expr)
         return out
 
@@ -440,46 +383,10 @@ class OperatorConverter(object):
         squeeze_options = SqueezeOptions()
         squeeze_options.Init(op_options.Bytes, op_options.Pos)
         squeeze_axis = squeeze_options.SqueezeDimsAsNumpy()
-        input_shape_length = len(input_tensor.tensor.ShapeAsNumpy())
-        output_shape_length = len(output_tensors[0].tensor.ShapeAsNumpy())
 
         in_expr = self.get_expr(input_tensor_idx)
-
-        # TFLite is N H W C, our layout is N C H W
-        if input_shape_length in (1, 2):
-            # The rule is channel first (after N but before H, W).
-            # length of 1 means N*H*W*C, do nothing.
-            # length of 2 means N*H*W, C, do nothing.
-            pass
-        elif input_shape_length == 3:
-            # convert N C H*W to N H*W C
-            in_expr = _op.transpose(in_expr, axes=(0, 2, 1))
-        elif input_shape_length == 4:
-            # convert input to N H W C, then reshape to target shape,
-            # finally convert back if necessary
-            in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1))
-        else:
-            msg = 'Input shape length {} for operator Squeeze is not valid.'
-            raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length))
-
         out = _op.squeeze(in_expr, axis=tuple(squeeze_axis))
 
-        # The rule is channel first.
-        # 1: N*H*W*C
-        # 2: N*H*W, C
-        # 3: N H W C, reshape to N H*W C, transpose to N C H*W
-        # 4: N H W C, transpose to N C H W
-        # add more if we need target shapes in future
-        if output_shape_length in (1, 2):
-            pass
-        elif output_shape_length == 3:
-            out = _op.transpose(out, axes=(0, 2, 1))
-        elif output_shape_length == 4:
-            out = _op.transpose(out, axes=(0, 3, 1, 2))
-        else:
-            msg = 'Output shape length {} for operator Squeeze is not valid.'
-            raise tvm.error.OpAttributeInvalid(msg.format(output_shape_length))
-
         return out
 
     def convert_fused_activation_function(self, in_expr, fused_activation_fn):
@@ -562,13 +469,16 @@ class OperatorConverter(object):
         params = {'kernel_size': [kernel_h, kernel_w],
                   'strides': [stride_h, stride_w],
                   'dilation': [dilation_h, dilation_w],
-                  'padding': [0, 0]}
+                  'padding': [0, 0],
+                  'data_layout': 'NHWC'}
 
         if is_depthwise_conv:
             params['channels'] = int(in_channels * multiplier)
             params['groups'] = int(in_channels)
+            params['kernel_layout'] = 'HWOI'
         else:
             params['channels'] = int(output_channels)
+            params['kernel_layout'] = 'HWIO'
 
         # weight tensor type should be UINT8 (quantization) or FLOAT32
         weight_tensor_type = weight_tensor.tensor.Type()
@@ -578,12 +488,9 @@ class OperatorConverter(object):
         in_expr = self.get_expr(input_tensor_idx)
         weight_value = self.get_tensor_value(weight_tensor)
 
-        if is_depthwise_conv:
-            # TFLite is M KH KW IC, we require IC M KH KW
-            weight_value = weight_value.transpose((3, 0, 1, 2))
-        else:
-            # TFLite is OC KH KW IC, we require OC IC KH kW
-            weight_value = weight_value.transpose((0, 3, 1, 2))
+        # TFLite is OC/M KH KW IC, we require KH KW IC OC/M
+        # M means multiplier in depthwise convolution
+        weight_value = weight_value.transpose((1, 2, 3, 0))
 
         weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)
 
@@ -592,9 +499,10 @@ class OperatorConverter(object):
         elif padding == Padding.SAME:
             pad_top, pad_bottom = get_pad_value(input_h, dilated_kernel_h, stride_h)
             pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w)
-            in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0), (0, 0),
+            in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0),
                                                           (pad_top, pad_bottom),
-                                                          (pad_left, pad_right)))
+                                                          (pad_left, pad_right),
+                                                          (0, 0)))
         else:
             raise tvm.error.OpAttributeUnimplemented(
                 'Padding format {} is not supported for operator Conv.'.format(padding))
@@ -610,7 +518,8 @@ class OperatorConverter(object):
             bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
             bias_expr = self.exp_tab.new_const(self.get_tensor_value(bias_tensor),
                                                dtype=bias_tensor_type_str)
-            out = _op.nn.bias_add(out, bias_expr)
+            channel_axis = 3
+            out = _op.nn.bias_add(out, bias_expr, axis=channel_axis)
 
         # If we have fused activations
         if fused_activation_fn != ActivationFunctionType.NONE:
@@ -648,7 +557,8 @@ class OperatorConverter(object):
 
         params = {'pool_size': (filter_h, filter_w),
                   'strides': (stride_h, stride_w),
-                  'padding': [0, 0]}
+                  'padding': [0, 0],
+                  'layout': 'NHWC'}
 
         in_expr = self.get_expr(input_tensor_idx)
 
index 63a345a..8fc2d55 100644 (file)
@@ -116,12 +116,10 @@ def run_tflite_graph(tflite_model_buf, input_data):
     return tflite_output
 
 
-def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors,
-                            output_tensors, output_need_transpose=False,
-                            init_global_variables=False):
+def compare_tflite_with_tvm(in_data, in_name, input_tensors,
+                            output_tensors, init_global_variables=False):
     """Generic function to generate and compare TFLite and TVM output"""
-    tflite_in_data = convert_to_list(tflite_in_data)
-    tvm_in_data = convert_to_list(tvm_in_data)
+    in_data = convert_to_list(in_data)
     in_name = convert_to_list(in_name)
     in_node = [0] * len(in_name)
     for i in range(len(in_name)):
@@ -134,7 +132,7 @@ def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors,
         converter = tf.contrib.lite.TFLiteConverter.from_session(
             sess, input_tensors, output_tensors)
         tflite_model_buffer = converter.convert()
-        tflite_output = run_tflite_graph(tflite_model_buffer, tflite_in_data)
+        tflite_output = run_tflite_graph(tflite_model_buffer, in_data)
 
         for device in ["llvm"]:
             ctx = tvm.context(device, 0)
@@ -142,25 +140,9 @@ def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors,
                 print("Skip because %s is not enabled" % device)
                 continue
 
-            tvm_output = run_tvm_graph(tflite_model_buffer, tvm_in_data, in_node, target=device)
+            tvm_output = run_tvm_graph(tflite_model_buffer, in_data, in_node, target=device)
             for i in range(len(tflite_output)):
-                if output_need_transpose:
-                    dim = len(tvm_output[i].shape)
-                    if dim == 3:
-                        # N C H*W to N H*W C
-                        axes = (0, 2, 1)
-                    elif dim == 4:
-                        # N C H W to N H W C
-                        axes = (0, 2, 3, 1)
-                    else:
-                        raise NotImplementedError("Not support input shape {} of transpose : ".
-                                                  format(str(dim)))
-                    tvm.testing.assert_allclose(tflite_output[i],
-                                                np.transpose(tvm_output[i], axes=axes),
-                                                atol=1e-5, rtol=1e-5)
-                else:
-                    tvm.testing.assert_allclose(tflite_output[i], tvm_output[i],
-                                                atol=1e-5, rtol=1e-5)
+                tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
 
         sess.close()
 
@@ -173,14 +155,12 @@ def _test_pooling_iteration(input_shape, **kwargs):
 
     x = -np.arange(
         np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1
-    tvm_data = np.transpose(x, axes=(0, 3, 1, 2))
 
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=input_shape, dtype='float32')
         out = nn_ops.pool(in_data, **kwargs)
 
-        compare_tflite_with_tvm(x, tvm_data, 'Placeholder:0', [in_data], [out],
-                                output_need_transpose=True)
+        compare_tflite_with_tvm(x,'Placeholder:0', [in_data], [out])
 
 
 def _test_pooling(input_shape, **kwargs):
@@ -258,13 +238,8 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes,
                                 strides=strides,
                                 padding=padding,
                                 data_format=data_format)
-        # TFLite is NHWC, TVM is NCHW
-        tflite_data_array = np.reshape(data_array, tensor_in_sizes).astype('float32')
-        tvm_data_array = np.transpose(tflite_data_array, axes=(0, 3, 1, 2))
-        # TFLite output is NHWC, TVM is NCHW, we need transpose
-        compare_tflite_with_tvm(tflite_data_array, tvm_data_array,
-                                'Placeholder:0', [in_data], [out],
-                                output_need_transpose=True)
+        data_array = np.reshape(data_array, tensor_in_sizes).astype('float32')
+        compare_tflite_with_tvm(data_array, 'Placeholder:0', [in_data], [out])
 
 
 def test_forward_convolution():
@@ -286,22 +261,11 @@ def test_forward_convolution():
 
 def _test_reshape(data, out_shape):
     """ One iteration of reshape operation with given data and out shape """
-    # see relay/frontend/tflite.py convert_reshape more detail of channel first rule
-    if len(data.shape) == 1 or len(data.shape) == 2:
-        tvm_data = data
-    elif len(data.shape) == 3:
-        tvm_data = np.transpose(data, axes=(0, 2, 1))
-    elif len(data.shape) == 4:
-        tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
-    else:
-        raise NotImplementedError("Not support input shape {} of reshape : ".
-                                  format(str(len(data))))
-
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         out = array_ops.reshape(in_data, out_shape)
 
-        compare_tflite_with_tvm(data, tvm_data, 'Placeholder:0', [in_data], [out])
+        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
 
 
 def test_forward_reshape():
@@ -319,18 +283,6 @@ def _test_concatenation(data, axis):
     """ One iteration of concatenation """
 
     assert len(data) >= 1
-    need_transpose = False
-    if len(data[0].shape) == 1 or len(data[0].shape) == 2:
-        tvm_data = data
-    elif len(data[0].shape) == 3:
-        #need_transpose = True
-        tvm_data = [np.transpose(d, axes=(0, 2, 1)) for d in data]
-    elif len(data[0].shape) == 4:
-        need_transpose = True
-        tvm_data = [np.transpose(d, axes=(0, 3, 1, 2)) for d in data]
-    else:
-        raise NotImplementedError("Not support input shape {} of reshape : ".
-                                  format(str(len(data))))
 
     with tf.Graph().as_default():
         in_data = [
@@ -339,7 +291,7 @@ def _test_concatenation(data, axis):
         out = array_ops.concat(in_data, axis=axis)
         name = ["in_{}:0".format(idx) for idx in range(len(data))]
 
-        compare_tflite_with_tvm(data, tvm_data, name, in_data, [out], need_transpose)
+        compare_tflite_with_tvm(data, name, in_data, [out])
 
 
 def test_forward_concatenation():
@@ -366,33 +318,19 @@ def _test_add(data):
     """ One iteration of add """
 
     assert len(data) == 2
-    need_transpose = False
-    if len(data[0].shape) == 1 or len(data[0].shape) == 2:
-        tvm_data = data
-    elif len(data[0].shape) == 3:
-        need_transpose = True
-        tvm_data = [np.transpose(d, axes=(0, 2, 1)) for d in data]
-    elif len(data[0].shape) == 4:
-        need_transpose = True
-        tvm_data = [np.transpose(d, axes=(0, 3, 1, 2)) for d in data]
-    else:
-        raise NotImplementedError("Not support input shape {} of add : ".
-                                  format(str(len(data.shape))))
 
     # Test with two tensors
     with tf.Graph().as_default():
         in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in_0'),
                    array_ops.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in_1')]
         out = math_ops.add(in_data[0], in_data[1])
-        compare_tflite_with_tvm(data, tvm_data, ['in_0:0','in_1:0'],
-                                in_data, [out], need_transpose)
+        compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out])
 
     # Test with tensor and constant
     with tf.Graph().as_default():
         in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')]
         out = math_ops.add(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype))
-        compare_tflite_with_tvm([data[0]], [tvm_data[0]], ['in:0'],
-                                in_data, [out], need_transpose)
+        compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out])
 
 
 def test_forward_add():
@@ -415,19 +353,6 @@ def _test_squeeze(data, squeeze_dims=None):
     if squeeze_dims is None:
         squeeze_dims = []
 
-    # see relay/frontend/tflite.py convert_squeeze more detail of channel first rule
-    if len(data.shape) == 1 or len(data.shape) == 2:
-        tvm_data = data
-    elif len(data.shape) == 3:
-        tvm_data = np.transpose(data, axes=(0, 2, 1))
-    elif len(data.shape) == 4:
-        tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
-    else:
-        raise NotImplementedError("Not support input shape {} of reshape : ".
-                                  format(str(len(data.shape))))
-
-    tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
-
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
 
@@ -436,7 +361,7 @@ def _test_squeeze(data, squeeze_dims=None):
         else:
             out = array_ops.squeeze(in_data)
 
-        compare_tflite_with_tvm(data, tvm_data, 'Placeholder:0', [in_data], [out])
+        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
 
 
 def test_forward_squeeze():
@@ -453,7 +378,7 @@ def _test_softmax(data):
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         out = nn_ops.softmax(in_data)
-        compare_tflite_with_tvm(data, data, 'Placeholder:0', [in_data], [out])
+        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
 
 def test_forward_softmax():
     """ Softmax """
@@ -496,10 +421,8 @@ def _test_fully_connected(tensor_in_sizes, filter_in_sizes, bias_in_size=None):
             in_bias = constant_op.constant(bias_array, shape=bias_in_size, dtype='float32')
             out = nn_ops.bias_add(out, in_bias)
 
-        tflite_data_array = np.reshape(data_array, tensor_in_sizes).astype('float32')
-        tvm_data_array = np.transpose(tflite_data_array, axes=(0, 3, 1, 2))
-        compare_tflite_with_tvm(tflite_data_array, tvm_data_array,
-                                'Placeholder:0', [in_data], [out])
+        data_array = np.reshape(data_array, tensor_in_sizes).astype('float32')
+        compare_tflite_with_tvm(data_array, 'Placeholder:0', [in_data], [out])
 
 
 def test_forward_fully_connected():
@@ -523,9 +446,8 @@ def test_forward_mobilenet_v1():
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
     data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
-    tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
     tflite_output = run_tflite_graph(tflite_model_buf, data)
-    tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input')
+    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
     tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
                                 rtol=1e-5, atol=1e-5)
 
@@ -538,9 +460,8 @@ def test_forward_mobilenet_v2():
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
     data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
-    tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
     tflite_output = run_tflite_graph(tflite_model_buf, data)
-    tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input')
+    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
     tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
                                 rtol=1e-5, atol=1e-5)
 
@@ -557,9 +478,8 @@ def test_forward_inception_v3_net():
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
     data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')
-    tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
     tflite_output = run_tflite_graph(tflite_model_buf, data)
-    tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input')
+    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
     tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
                                 rtol=1e-5, atol=1e-5)
 
@@ -572,9 +492,8 @@ def test_forward_inception_v4_net():
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
     data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')
-    tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
     tflite_output = run_tflite_graph(tflite_model_buf, data)
-    tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input')
+    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
     tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
                                 rtol=1e-5, atol=1e-5)
 
index 67edeb8..f8686e9 100644 (file)
@@ -117,32 +117,23 @@ plt.imshow(resized_image)
 plt.show()
 image_data = np.asarray(resized_image).astype("float32")
 
-# convert HWC to CHW
-image_data = image_data.transpose((2, 0, 1))
-
-# after expand_dims, we have format NCHW
+# after expand_dims, we have format NHWC
 image_data = np.expand_dims(image_data, axis=0)
 
 # preprocess image as described here:
 # https://github.com/tensorflow/models/blob/edb6ed22a801665946c63d650ab9a0b23d98e1b1/research/slim/preprocessing/inception_preprocessing.py#L243
-image_data[:, 0, :, :] = 2.0 / 255.0 * image_data[:, 0, :, :] - 1
-image_data[:, 1, :, :] = 2.0 / 255.0 * image_data[:, 1, :, :] - 1
-image_data[:, 2, :, :] = 2.0 / 255.0 * image_data[:, 2, :, :] - 1
+image_data[:, :, :, 0] = 2.0 / 255.0 * image_data[:, :, :, 0] - 1
+image_data[:, :, :, 1] = 2.0 / 255.0 * image_data[:, :, :, 1] - 1
+image_data[:, :, :, 2] = 2.0 / 255.0 * image_data[:, :, :, 2] - 1
 print('input', image_data.shape)
 
-####################################################################
-#
-# .. note:: Input layout:
-#
-#   Currently, TVM TFLite frontend accepts ``NCHW`` as input layout.
-
 ######################################################################
 # Compile the model with relay
 # ---------------------------------------------
 
 # TFLite input tensor name, shape and type
 input_tensor = "input"
-input_shape = (1, 3, 224, 224)
+input_shape = (1, 224, 224, 3)
 input_dtype = "float32"
 
 # parse TFLite model and convert into Relay computation graph