From aeef16d2857fc00ab83ba28476281e99a0a78e0c Mon Sep 17 00:00:00 2001 From: Rishabh Jain <56974688+jainris@users.noreply.github.com> Date: Thu, 10 Sep 2020 23:03:11 +0530 Subject: [PATCH] [QNN][Relay] Fixed bug in quantized conv2d. (#6420) * Fixed bug in quantized conv2d where when kernel size = (1,1) and strides != (1,1) it would raise size mismatch error. * Added test to check qnn.conv2d with kernel size = (1,1) and strides != (1,1). --- src/relay/qnn/op/convolution.cc | 23 +++++++++++++++++++++-- tests/python/relay/test_op_qnn_conv2d.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 5d2e360..847f81f 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -258,7 +258,7 @@ Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_ // We can reduce the H and W axis by using avg_pool2d. However, avg_pool2d averages the sum. // Since, this is integer division (floor), we can first multiply the data by the pool_size and // then perform avg_pool2d. Reversing this causes inaccuracy due to floor division. If the - // pool_size is 1x1, we don't need avg_pool2d. + // pool_size and strides are 1x1, we don't need avg_pool2d. 
auto reduced_t2 = casted_t2; if (kernel_h * kernel_w != 1) { auto scaled_hw_t2 = @@ -268,6 +268,16 @@ Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_ AvgPool2D(scaled_hw_t2, param->kernel_size, param->strides, padding, param->data_layout, false, // ceil_mode false); // count_include_pad + } else { + int stride1 = get_const_int(param->strides[0]); + int stride2 = get_const_int(param->strides[1]); + if (stride1 * stride2 != 1) { + Array<IndexExpr> padding({0, 0}); + reduced_t2 = + AvgPool2D(reduced_t2, param->kernel_size, param->strides, padding, param->data_layout, + false, // ceil_mode + false); // count_include_pad + } } auto multiplied_t2 = reduced_t2; @@ -414,7 +424,7 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point, // Keep dims true to retain 4D tensor auto reduced_c_t2 = Sum(casted_t2, axes_t2, true, false); - // If the pool_size is 1x1, we don't need avg_pool2d. + // If the pool_size and strides are 1x1, we don't need avg_pool2d. 
auto reduced_t2 = reduced_c_t2; if (kernel_h * kernel_w != 1) { reduced_c_t2 = @@ -423,6 +433,15 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point, AvgPool2D(reduced_c_t2, param->kernel_size, param->strides, padding, param->data_layout, false, // ceil_mode false); // count_include_pad + } else { + int stride1 = get_const_int(param->strides[0]); + int stride2 = get_const_int(param->strides[1]); + if (stride1 * stride2 != 1) { + reduced_t2 = + AvgPool2D(reduced_c_t2, param->kernel_size, param->strides, padding, param->data_layout, + false, // ceil_mode + false); // count_include_pad + } } auto multiplied_t2 = reduced_t2; diff --git a/tests/python/relay/test_op_qnn_conv2d.py b/tests/python/relay/test_op_qnn_conv2d.py index fcb335f..bb848e9 100644 --- a/tests/python/relay/test_op_qnn_conv2d.py +++ b/tests/python/relay/test_op_qnn_conv2d.py @@ -629,6 +629,33 @@ def test_kernel_size_1x1(): verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) +def test_kernel_size_1x1_strides_2(): + with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d): + + # uint8 input + data_shape = (2, 4, 2, 4) + data_dtype = 'uint8' + kernel_shape = (3, 4, 1, 1) + kernel_dtype = 'uint8' + ref_func, qnn_func = get_funcs(data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=5, + kernel_zero_point=3, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(1, 1), + padding=(0, 0), + strides=(2, 2), + dilation=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32") + assert 'avg_pool2d' not in qnn_func.astext() + verify(ref_func, qnn_func, data_shape, data_dtype, + kernel_shape, kernel_dtype) + def test_tflite_large_irregular(): with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d): @@ -937,6 +964,7 @@ if __name__ == "__main__": test_dilation() test_const_folding() test_kernel_size_1x1() + test_kernel_size_1x1_strides_2() 
test_tflite_large_irregular() test_broadcast_layout() test_tflite_output_multiplier_greater_than_one() -- 2.7.4