From b63267b92d942b9c64f814b73567b2fe908e67fb Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Wed, 22 May 2019 10:44:47 +0800 Subject: [PATCH] [TFLite] Convert TFLite NCHW to NHWC (#3141) * Convert TFLite NCHW to NHWC * Minor comment fix --- python/tvm/relay/frontend/tflite.py | 120 ++++---------------------- tests/python/frontend/tflite/test_forward.py | 123 +++++---------------------- tutorials/frontend/from_tflite.py | 19 ++--- 3 files changed, 41 insertions(+), 221 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index ff62d89..bfd63bb 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -209,44 +209,10 @@ class OperatorConverter(object): reshape_options = ReshapeOptions() reshape_options.Init(op_options.Bytes, op_options.Pos) target_shape = reshape_options.NewShapeAsNumpy() - input_shape_length = len(input_tensor.tensor.ShapeAsNumpy()) in_expr = self.get_expr(input_tensor_idx) - - if input_shape_length in (1, 2): - # The rule is channel first (after N but before H, W). - # length of 1 means N*H*W*C, do nothing. - # length of 2 means N*H*W, C, do nothing. - pass - elif input_shape_length == 3: - # convert N C H*W to N H*W C - in_expr = _op.transpose(in_expr, axes=(0, 2, 1)) - elif input_shape_length == 4: - # convert input to N H W C, then reshape to target shape, - # finally convert back if necessary - in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1)) - else: - msg = 'Input shape length {} for operator Reshape is not valid.' - raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length)) - out = _op.reshape(in_expr, newshape=tuple(target_shape)) - # The rule is channel first. - # 1: N*H*W*C - # 2: N*H*W, C - # 3: N H W C, reshape to N H*W C, transpose to N C H*W - # 4: N H W C, transpose to N C H W - # add more if we need target shapes in future - if len(target_shape) == 1 or len(target_shape) == 2: - pass - elif len(target_shape) == 3: - out = _op.transpose(out, axes=(0, 2, 1)) - elif len(target_shape) == 4: - out = _op.transpose(out, axes=(0, 3, 1, 2)) - else: - raise tvm.error.OpAttributeInvalid( - 'Length of target shape must be between 1 and 5 for operator Reshape.') - return out def convert_softmax(self, op): @@ -269,7 +235,7 @@ class OperatorConverter(object): return out def convert_concatenation(self, op): - """ convert TFLite concatenation""" + """Convert TFLite concatenation""" try: from tflite.Operator import Operator from tflite.ConcatenationOptions import ConcatenationOptions @@ -292,15 +258,6 @@ class OperatorConverter(object): concatenation_options.Init(op_options.Bytes, op_options.Pos) concatenation_axis = concatenation_options.Axis() fused_activation_fn = concatenation_options.FusedActivationFunction() - input_shape_length = len(input_tensors[0].tensor.ShapeAsNumpy()) - - # TFLite is N H W C, our layout is N C H W - if input_shape_length <= 4: - axis_convert_map = [0] + list(range(2, input_shape_length)) + [1] - concatenation_axis = axis_convert_map[concatenation_axis] - else: - raise NotImplementedError("Not support input shape length {} of concatenatio : " - .format(str(input_shape_length))) # with axis in N H W C out = _op.concatenate(in_exprs, axis=concatenation_axis) @@ -336,20 +293,6 @@ class OperatorConverter(object): rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor), dtype=rhs_type_str) - # In this case, we have to be careful about formatting. - input_shape_length = len(rhs_tensor.tensor.ShapeAsNumpy()) - if input_shape_length in (1, 2): - pass - elif input_shape_length == 3: - # N H*W C to N C H*W - rhs_expr = _op.transpose(rhs_expr, axes=(0, 2, 1)) - elif input_shape_length == 4: - # N H W C to N C H W - rhs_expr = _op.transpose(rhs_expr, axes=(0, 3, 1, 2)) - else: - msg = 'Input shape length {} for operator ADD is not valid.' - raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length)) - out = _op.add(lhs_expr, rhs_expr) return out @@ -440,46 +383,10 @@ class OperatorConverter(object): squeeze_options = SqueezeOptions() squeeze_options.Init(op_options.Bytes, op_options.Pos) squeeze_axis = squeeze_options.SqueezeDimsAsNumpy() - input_shape_length = len(input_tensor.tensor.ShapeAsNumpy()) - output_shape_length = len(output_tensors[0].tensor.ShapeAsNumpy()) in_expr = self.get_expr(input_tensor_idx) - - # TFLite is N H W C, our layout is N C H W - if input_shape_length in (1, 2): - # The rule is channel first (after N but before H, W). - # length of 1 means N*H*W*C, do nothing. - # length of 2 means N*H*W, C, do nothing. - pass - elif input_shape_length == 3: - # convert N C H*W to N H*W C - in_expr = _op.transpose(in_expr, axes=(0, 2, 1)) - elif input_shape_length == 4: - # convert input to N H W C, then reshape to target shape, - # finally convert back if necessary - in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1)) - else: - msg = 'Input shape length {} for operator Squeeze is not valid.' - raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length)) - out = _op.squeeze(in_expr, axis=tuple(squeeze_axis)) - # The rule is channel first. - # 1: N*H*W*C - # 2: N*H*W, C - # 3: N H W C, reshape to N H*W C, transpose to N C H*W - # 4: N H W C, transpose to N C H W - # add more if we need target shapes in future - if output_shape_length in (1, 2): - pass - elif output_shape_length == 3: - out = _op.transpose(out, axes=(0, 2, 1)) - elif output_shape_length == 4: - out = _op.transpose(out, axes=(0, 3, 1, 2)) - else: - msg = 'Output shape length {} for operator Squeeze is not valid.' - raise tvm.error.OpAttributeInvalid(msg.format(output_shape_length)) - return out def convert_fused_activation_function(self, in_expr, fused_activation_fn): @@ -562,13 +469,16 @@ class OperatorConverter(object): params = {'kernel_size': [kernel_h, kernel_w], 'strides': [stride_h, stride_w], 'dilation': [dilation_h, dilation_w], - 'padding': [0, 0]} + 'padding': [0, 0], + 'data_layout': 'NHWC'} if is_depthwise_conv: params['channels'] = int(in_channels * multiplier) params['groups'] = int(in_channels) + params['kernel_layout'] = 'HWOI' else: params['channels'] = int(output_channels) + params['kernel_layout'] = 'HWIO' # weight tensor type should be UINT8 (quantization) or FLOAT32 weight_tensor_type = weight_tensor.tensor.Type() @@ -578,12 +488,9 @@ class OperatorConverter(object): in_expr = self.get_expr(input_tensor_idx) weight_value = self.get_tensor_value(weight_tensor) - if is_depthwise_conv: - # TFLite is M KH KW IC, we require IC M KH KW - weight_value = weight_value.transpose((3, 0, 1, 2)) - else: - # TFLite is OC KH KW IC, we require OC IC KH kW - weight_value = weight_value.transpose((0, 3, 1, 2)) + # TFLite is OC/M KH KW IC, we require KH KW IC OC/M + # M means multiplier in depthwise convolution + weight_value = weight_value.transpose((1, 2, 3, 0)) weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str) @@ -592,9 +499,10 @@ class OperatorConverter(object): elif padding == Padding.SAME: pad_top, pad_bottom = get_pad_value(input_h, dilated_kernel_h, stride_h) pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w) - in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0), (0, 0), + in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0), (pad_top, pad_bottom), - (pad_left, pad_right))) + (pad_left, pad_right), + (0, 0))) else: raise tvm.error.OpAttributeUnimplemented( 'Padding format {} is not supported for operator Conv.'.format(padding)) @@ -610,7 +518,8 @@ class OperatorConverter(object): bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type) bias_expr = self.exp_tab.new_const(self.get_tensor_value(bias_tensor), dtype=bias_tensor_type_str) - out = _op.nn.bias_add(out, bias_expr) + channel_axis = 3 + out = _op.nn.bias_add(out, bias_expr, axis=channel_axis) # If we have fused activations if fused_activation_fn != ActivationFunctionType.NONE: @@ -648,7 +557,8 @@ class OperatorConverter(object): params = {'pool_size': (filter_h, filter_w), 'strides': (stride_h, stride_w), - 'padding': [0, 0]} + 'padding': [0, 0], + 'layout': 'NHWC'} in_expr = self.get_expr(input_tensor_idx) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 63a345a..8fc2d55 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -116,12 +116,10 @@ def run_tflite_graph(tflite_model_buf, input_data): return tflite_output -def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors, - output_tensors, output_need_transpose=False, - init_global_variables=False): +def compare_tflite_with_tvm(in_data, in_name, input_tensors, + output_tensors, init_global_variables=False): """Generic function to generate and compare TFLite and TVM output""" - tflite_in_data = convert_to_list(tflite_in_data) - tvm_in_data = convert_to_list(tvm_in_data) + in_data = convert_to_list(in_data) in_name = convert_to_list(in_name) in_node = [0] * len(in_name) for i in range(len(in_name)): @@ -134,7 +132,7 @@ def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors, converter = tf.contrib.lite.TFLiteConverter.from_session( sess, input_tensors, output_tensors) tflite_model_buffer = converter.convert() - tflite_output = run_tflite_graph(tflite_model_buffer, tflite_in_data) + tflite_output = run_tflite_graph(tflite_model_buffer, in_data) for device in ["llvm"]: ctx = tvm.context(device, 0) @@ -142,25 +140,9 @@ def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors, print("Skip because %s is not enabled" % device) continue - tvm_output = run_tvm_graph(tflite_model_buffer, tvm_in_data, in_node, target=device) + tvm_output = run_tvm_graph(tflite_model_buffer, in_data, in_node, target=device) for i in range(len(tflite_output)): - if output_need_transpose: - dim = len(tvm_output[i].shape) - if dim == 3: - # N C H*W to N H*W C - axes = (0, 2, 1) - elif dim == 4: - # N C H W to N H W C - axes = (0, 2, 3, 1) - else: - raise NotImplementedError("Not support input shape {} of transpose : ". - format(str(dim))) - tvm.testing.assert_allclose(tflite_output[i], - np.transpose(tvm_output[i], axes=axes), - atol=1e-5, rtol=1e-5) - else: - tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], - atol=1e-5, rtol=1e-5) + tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], atol=1e-5, rtol=1e-5) sess.close() @@ -173,14 +155,12 @@ def _test_pooling_iteration(input_shape, **kwargs): x = -np.arange( np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1 - tvm_data = np.transpose(x, axes=(0, 3, 1, 2)) with tf.Graph().as_default(): in_data = array_ops.placeholder(shape=input_shape, dtype='float32') out = nn_ops.pool(in_data, **kwargs) - compare_tflite_with_tvm(x, tvm_data, 'Placeholder:0', [in_data], [out], - output_need_transpose=True) + compare_tflite_with_tvm(x,'Placeholder:0', [in_data], [out]) def _test_pooling(input_shape, **kwargs): @@ -258,13 +238,8 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes, strides=strides, padding=padding, data_format=data_format) - # TFLite is NHWC, TVM is NCHW - tflite_data_array = np.reshape(data_array, tensor_in_sizes).astype('float32') - tvm_data_array = np.transpose(tflite_data_array, axes=(0, 3, 1, 2)) - # TFLite output is NHWC, TVM is NCHW, we need transpose - compare_tflite_with_tvm(tflite_data_array, tvm_data_array, - 'Placeholder:0', [in_data], [out], - output_need_transpose=True) + data_array = np.reshape(data_array, tensor_in_sizes).astype('float32') + compare_tflite_with_tvm(data_array, 'Placeholder:0', [in_data], [out]) def test_forward_convolution(): @@ -286,22 +261,11 @@ def test_forward_convolution(): def _test_reshape(data, out_shape): """ One iteration of reshape operation with given data and out shape """ - # see relay/frontend/tflite.py convert_reshape more detail of channel first rule - if len(data.shape) == 1 or len(data.shape) == 2: - tvm_data = data - elif len(data.shape) == 3: - tvm_data = np.transpose(data, axes=(0, 2, 1)) - elif len(data.shape) == 4: - tvm_data = np.transpose(data, axes=(0, 3, 1, 2)) - else: - raise NotImplementedError("Not support input shape {} of reshape : ". - format(str(len(data)))) - with tf.Graph().as_default(): in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) out = array_ops.reshape(in_data, out_shape) - compare_tflite_with_tvm(data, tvm_data, 'Placeholder:0', [in_data], [out]) + compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) def test_forward_reshape(): @@ -319,18 +283,6 @@ def _test_concatenation(data, axis): """ One iteration of concatenation """ assert len(data) >= 1 - need_transpose = False - if len(data[0].shape) == 1 or len(data[0].shape) == 2: - tvm_data = data - elif len(data[0].shape) == 3: - #need_transpose = True - tvm_data = [np.transpose(d, axes=(0, 2, 1)) for d in data] - elif len(data[0].shape) == 4: - need_transpose = True - tvm_data = [np.transpose(d, axes=(0, 3, 1, 2)) for d in data] - else: - raise NotImplementedError("Not support input shape {} of reshape : ". - format(str(len(data)))) with tf.Graph().as_default(): in_data = [ @@ -339,7 +291,7 @@ def _test_concatenation(data, axis): out = array_ops.concat(in_data, axis=axis) name = ["in_{}:0".format(idx) for idx in range(len(data))] - compare_tflite_with_tvm(data, tvm_data, name, in_data, [out], need_transpose) + compare_tflite_with_tvm(data, name, in_data, [out]) def test_forward_concatenation(): @@ -366,33 +318,19 @@ def _test_add(data): """ One iteration of add """ assert len(data) == 2 - need_transpose = False - if len(data[0].shape) == 1 or len(data[0].shape) == 2: - tvm_data = data - elif len(data[0].shape) == 3: - need_transpose = True - tvm_data = [np.transpose(d, axes=(0, 2, 1)) for d in data] - elif len(data[0].shape) == 4: - need_transpose = True - tvm_data = [np.transpose(d, axes=(0, 3, 1, 2)) for d in data] - else: - raise NotImplementedError("Not support input shape {} of add : ". - format(str(len(data.shape)))) # Test with two tensors with tf.Graph().as_default(): in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in_0'), array_ops.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in_1')] out = math_ops.add(in_data[0], in_data[1]) - compare_tflite_with_tvm(data, tvm_data, ['in_0:0','in_1:0'], - in_data, [out], need_transpose) + compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out]) # Test with tensor and constant with tf.Graph().as_default(): in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')] out = math_ops.add(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype)) - compare_tflite_with_tvm([data[0]], [tvm_data[0]], ['in:0'], - in_data, [out], need_transpose) + compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out]) def test_forward_add(): @@ -415,19 +353,6 @@ def _test_squeeze(data, squeeze_dims=None): if squeeze_dims is None: squeeze_dims = [] - # see relay/frontend/tflite.py convert_squeeze more detail of channel first rule - if len(data.shape) == 1 or len(data.shape) == 2: - tvm_data = data - elif len(data.shape) == 3: - tvm_data = np.transpose(data, axes=(0, 2, 1)) - elif len(data.shape) == 4: - tvm_data = np.transpose(data, axes=(0, 3, 1, 2)) - else: - raise NotImplementedError("Not support input shape {} of reshape : ". - format(str(len(data.shape)))) - - tvm_data = np.transpose(data, axes=(0, 3, 1, 2)) - with tf.Graph().as_default(): in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) @@ -436,7 +361,7 @@ def _test_squeeze(data, squeeze_dims=None): else: out = array_ops.squeeze(in_data) - compare_tflite_with_tvm(data, tvm_data, 'Placeholder:0', [in_data], [out]) + compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) def test_forward_squeeze(): @@ -453,7 +378,7 @@ def _test_softmax(data): with tf.Graph().as_default(): in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) out = nn_ops.softmax(in_data) - compare_tflite_with_tvm(data, data, 'Placeholder:0', [in_data], [out]) + compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) def test_forward_softmax(): """ Softmax """ @@ -496,10 +421,8 @@ def _test_fully_connected(tensor_in_sizes, filter_in_sizes, bias_in_size=None): in_bias = constant_op.constant(bias_array, shape=bias_in_size, dtype='float32') out = nn_ops.bias_add(out, in_bias) - tflite_data_array = np.reshape(data_array, tensor_in_sizes).astype('float32') - tvm_data_array = np.transpose(tflite_data_array, axes=(0, 3, 1, 2)) - compare_tflite_with_tvm(tflite_data_array, tvm_data_array, - 'Placeholder:0', [in_data], [out]) + data_array = np.reshape(data_array, tensor_in_sizes).astype('float32') + compare_tflite_with_tvm(data_array, 'Placeholder:0', [in_data], [out]) def test_forward_fully_connected(): @@ -523,9 +446,8 @@ def test_forward_mobilenet_v1(): with open(tflite_model_file, "rb") as f: tflite_model_buf = f.read() data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32') - tvm_data = np.transpose(data, axes=(0, 3, 1, 2)) tflite_output = run_tflite_graph(tflite_model_buf, data) - tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input') + tvm_output = run_tvm_graph(tflite_model_buf, data, 'input') tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5) @@ -538,9 +460,8 @@ def test_forward_mobilenet_v2(): with open(tflite_model_file, "rb") as f: tflite_model_buf = f.read() data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32') - tvm_data = np.transpose(data, axes=(0, 3, 1, 2)) tflite_output = run_tflite_graph(tflite_model_buf, data) - tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input') + tvm_output = run_tvm_graph(tflite_model_buf, data, 'input') tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5) @@ -557,9 +478,8 @@ def test_forward_inception_v3_net(): with open(tflite_model_file, "rb") as f: tflite_model_buf = f.read() data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32') - tvm_data = np.transpose(data, axes=(0, 3, 1, 2)) tflite_output = run_tflite_graph(tflite_model_buf, data) - tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input') + tvm_output = run_tvm_graph(tflite_model_buf, data, 'input') tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5) @@ -572,9 +492,8 @@ def test_forward_inception_v4_net(): with open(tflite_model_file, "rb") as f: tflite_model_buf = f.read() data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32') - tvm_data = np.transpose(data, axes=(0, 3, 1, 2)) tflite_output = run_tflite_graph(tflite_model_buf, data) - tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input') + tvm_output = run_tvm_graph(tflite_model_buf, data, 'input') tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5) diff --git a/tutorials/frontend/from_tflite.py b/tutorials/frontend/from_tflite.py index 67edeb8..f8686e9 100644 --- a/tutorials/frontend/from_tflite.py +++ b/tutorials/frontend/from_tflite.py @@ -117,32 +117,23 @@ plt.imshow(resized_image) plt.show() image_data = np.asarray(resized_image).astype("float32") -# convert HWC to CHW -image_data = image_data.transpose((2, 0, 1)) - -# after expand_dims, we have format NCHW +# after expand_dims, we have format NHWC image_data = np.expand_dims(image_data, axis=0) # preprocess image as described here: # https://github.com/tensorflow/models/blob/edb6ed22a801665946c63d650ab9a0b23d98e1b1/research/slim/preprocessing/inception_preprocessing.py#L243 -image_data[:, 0, :, :] = 2.0 / 255.0 * image_data[:, 0, :, :] - 1 -image_data[:, 1, :, :] = 2.0 / 255.0 * image_data[:, 1, :, :] - 1 -image_data[:, 2, :, :] = 2.0 / 255.0 * image_data[:, 2, :, :] - 1 +image_data[:, :, :, 0] = 2.0 / 255.0 * image_data[:, :, :, 0] - 1 +image_data[:, :, :, 1] = 2.0 / 255.0 * image_data[:, :, :, 1] - 1 +image_data[:, :, :, 2] = 2.0 / 255.0 * image_data[:, :, :, 2] - 1 print('input', image_data.shape) -#################################################################### -# -# .. note:: Input layout: -# -# Currently, TVM TFLite frontend accepts ``NCHW`` as input layout. - ###################################################################### # Compile the model with relay # --------------------------------------------- # TFLite input tensor name, shape and type input_tensor = "input" -input_shape = (1, 3, 224, 224) +input_shape = (1, 224, 224, 3) input_dtype = "float32" # parse TFLite model and convert into Relay computation graph -- 2.7.4