From 24c53a343b0ecb76ed766d3f29e968ee0f8b0816 Mon Sep 17 00:00:00 2001 From: masahi Date: Fri, 14 Feb 2020 13:28:07 +0900 Subject: [PATCH] [QNN] More doc fix on quantize and convolution (#4874) * [QNN] Doc fix on quantize and convolution * update test --- python/tvm/relay/qnn/op/qnn.py | 10 +++---- tests/python/relay/test_op_qnn_conv2d.py | 41 ++++++++++++++++++---------- tests/python/relay/test_pass_qnn_legalize.py | 1 + 3 files changed, 32 insertions(+), 20 deletions(-) diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index f76d7b3..eaca625 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -104,7 +104,7 @@ def quantize(data, axis : int The channel axis for quantization. Default value is -1 which corresponds to the last axis. out_dtype : str, optional - The data type of the input tensor. Can be [int8, uint8] + The data type of the input tensor. Can be [int8, uint8, int32] Returns ------- result : tvm.relay.Expr @@ -202,11 +202,11 @@ def conv2d(data, input_scale, kernel_scale, kernel_size, + channels, strides=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1, - channels=None, data_layout="NCHW", kernel_layout="OIHW", out_layout="", @@ -247,6 +247,9 @@ def conv2d(data, kernel_size : tuple of int The spatial width and height of the convolution kernel. + channels : int + Number of output channels of this convolution. + strides : tuple of int, optional The strides of convolution. @@ -259,9 +262,6 @@ def conv2d(data, groups : int, optional Number of groups for grouped convolution. - channels : int, optional - Number of output channels of this convolution. - data_layout : str, optional Layout of the input. diff --git a/tests/python/relay/test_op_qnn_conv2d.py b/tests/python/relay/test_op_qnn_conv2d.py index 264475c..67a7ef6 100644 --- a/tests/python/relay/test_op_qnn_conv2d.py +++ b/tests/python/relay/test_op_qnn_conv2d.py @@ -79,8 +79,8 @@ def get_qnn_func(data, data_layout, kernel_layout, out_dtype, - groups, - channels=None): + channels, + groups): func = relay.qnn.op.conv2d( data, kernel, input_zero_point=relay.const(input_zero_point, 'int32'), @@ -116,12 +116,23 @@ def get_funcs(data_shape, data_layout, kernel_layout, out_dtype, - groups=1, - channels=None): + groups=1): data = relay.var("data", shape=data_shape, dtype=data_dtype) kernel = relay.var("kernel", shape=kernel_shape, dtype=kernel_dtype) + + if groups > 1: + channels = groups + elif kernel_layout == "OIHW": + channels = kernel_shape[0] + elif kernel_layout == "HWIO": + channels = kernel_shape[3] + elif kernel_layout == "HWOI": + channels = kernel_shape[2] + else: + raise NotImplementedError + ref_func = get_ref_func(data, kernel, input_zero_point, @@ -152,8 +163,9 @@ def get_funcs(data_shape, data_layout, kernel_layout, out_dtype, - groups, - channels) + channels, + groups) + return (ref_func, qnn_func) def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, @@ -418,7 +430,7 @@ def test_layout(): verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) - # NHWC and HWIO layout. Used in depthwise conv. + # NHWC and HWOI layout. Used in depthwise conv. data_shape = (2, 2, 4, 1) # NHWC data_dtype = 'uint8' kernel_shape = (2, 2, 1, 1) # HWOI @@ -568,6 +580,7 @@ def test_const_folding(): data_layout="NCHW", kernel_layout="OIHW", out_dtype="int32", + channels=kernel_shape[0], groups=1) folded_mod = transform.FoldConstant()(qnn_func) folded_func = folded_mod["main"] @@ -787,8 +800,8 @@ def test_depthwise_depth_multiplier(): data_layout="NCHW", kernel_layout="OIHW", out_dtype="int32", - groups=4, - channels=4) + groups=4) + verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) @@ -813,8 +826,7 @@ def test_depthwise_depth_multiplier(): data_layout="NCHW", kernel_layout="OIHW", out_dtype="int32", - groups=8, - channels=8) + groups=8) verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) @@ -839,8 +851,7 @@ def test_depthwise_depth_multiplier(): data_layout="NHWC", kernel_layout="HWOI", out_dtype="int32", - groups=4, - channels=4) + groups=4) verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) @@ -864,8 +875,7 @@ def test_depthwise_depth_multiplier(): data_layout="NHWC", kernel_layout="HWOI", out_dtype="int32", - groups=8, - channels=8) + groups=8) verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) @@ -888,6 +898,7 @@ def test_per_channel_kernel_scale(): input_scale=relay.const(2.0, 'float32'), kernel_scale=kernel_scales, kernel_size=(2, 2), + channels=kernel_shape[0], padding=(0, 0), strides=(1, 1), dilation=(1, 1), diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py index 38fdb7d..e5893c9 100644 --- a/tests/python/relay/test_pass_qnn_legalize.py +++ b/tests/python/relay/test_pass_qnn_legalize.py @@ -107,6 +107,7 @@ def test_qnn_legalize_qnn_conv2d(): input_scale=relay.const(1.0, 'float32'), kernel_scale=relay.const(1.0, 'float32'), kernel_size=(3, 3), + channels=kernel_shape[0], strides=(1, 1), dilation=(1, 1), out_dtype='int32', -- 2.7.4