[QNN] More doc fix on quantize and convolution (#4874)
author masahi <masahi129@gmail.com>
Fri, 14 Feb 2020 04:28:07 +0000 (13:28 +0900)
committer GitHub <noreply@github.com>
Fri, 14 Feb 2020 04:28:07 +0000 (20:28 -0800)
* [QNN] Doc fix on quantize and convolution

* update test

python/tvm/relay/qnn/op/qnn.py
tests/python/relay/test_op_qnn_conv2d.py
tests/python/relay/test_pass_qnn_legalize.py

diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py
index f76d7b3..eaca625 100644
--- a/python/tvm/relay/qnn/op/qnn.py
+++ b/python/tvm/relay/qnn/op/qnn.py
@@ -104,7 +104,7 @@ def quantize(data,
     axis : int
         The channel axis for quantization. Default value is -1 which corresponds to the last axis.
     out_dtype : str, optional
-        The data type of the input tensor. Can be [int8, uint8]
+        The data type of the output tensor. Can be [int8, uint8, int32]
     Returns
     -------
     result : tvm.relay.Expr
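
With int32 now an accepted out_dtype, a quantize call like the following
type-checks; int32 quantize output is what QNN frontends typically use for
bias tensors. A minimal sketch, with argument names taken from the quantize
signature in this file; the shape, scale, and zero point are illustrative:

    from tvm import relay

    data = relay.var("data", shape=(1, 4), dtype="float32")
    # "int32" is now legal here alongside "int8" and "uint8".
    quantized = relay.qnn.op.quantize(
            data,
            output_scale=relay.const(0.5, "float32"),
            output_zero_point=relay.const(0, "int32"),
            out_dtype="int32")
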
@@ -202,11 +202,11 @@ def conv2d(data,
            input_scale,
            kernel_scale,
            kernel_size,
+           channels,
            strides=(1, 1),
            padding=(0, 0),
            dilation=(1, 1),
            groups=1,
-           channels=None,
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_layout="",
@@ -247,6 +247,9 @@ def conv2d(data,
     kernel_size : tuple of int
         The spatial width and height of the convolution kernel.
 
+    channels : int
+        Number of output channels of this convolution.
+
     strides : tuple of int, optional
         The strides of convolution.
 
@@ -259,9 +262,6 @@ def conv2d(data,
     groups : int, optional
         Number of groups for grouped convolution.
 
-    channels : int, optional
-        Number of output channels of this convolution.
-
     data_layout : str, optional
         Layout of the input.
 
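After this change, channels is a required argument of qnn.conv2d rather than
an optional one, so direct callers must pass it explicitly. A minimal sketch
for an OIHW kernel, mirroring the keyword style used in the tests below;
the shapes and quantization parameters are illustrative only:

    from tvm import relay

    data = relay.var("data", shape=(1, 4, 8, 8), dtype="uint8")       # NCHW
    kernel = relay.var("kernel", shape=(16, 4, 3, 3), dtype="uint8")  # OIHW
    conv = relay.qnn.op.conv2d(
            data, kernel,
            input_zero_point=relay.const(0, "int32"),
            kernel_zero_point=relay.const(0, "int32"),
            input_scale=relay.const(1.0, "float32"),
            kernel_scale=relay.const(1.0, "float32"),
            kernel_size=(3, 3),
            channels=16,  # kernel_shape[0] for OIHW; no longer defaults to None
            out_dtype="int32")
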
diff --git a/tests/python/relay/test_op_qnn_conv2d.py b/tests/python/relay/test_op_qnn_conv2d.py
index 264475c..67a7ef6 100644
--- a/tests/python/relay/test_op_qnn_conv2d.py
+++ b/tests/python/relay/test_op_qnn_conv2d.py
@@ -79,8 +79,8 @@ def get_qnn_func(data,
                  data_layout,
                  kernel_layout,
                  out_dtype,
-                 groups,
-                 channels=None):
+                 channels,
+                 groups):
     func = relay.qnn.op.conv2d(
             data, kernel,
             input_zero_point=relay.const(input_zero_point, 'int32'),
@@ -116,12 +116,23 @@ def get_funcs(data_shape,
               data_layout,
               kernel_layout,
               out_dtype,
-              groups=1,
-              channels=None):
+              groups=1):
     data = relay.var("data", shape=data_shape,
             dtype=data_dtype)
     kernel = relay.var("kernel", shape=kernel_shape,
             dtype=kernel_dtype)
+
+    if groups > 1:
+        channels = groups
+    elif kernel_layout == "OIHW":
+        channels = kernel_shape[0]
+    elif kernel_layout == "HWIO":
+        channels = kernel_shape[3]
+    elif kernel_layout == "HWOI":
+        channels = kernel_shape[2]
+    else:
+        raise NotImplementedError
+
     ref_func = get_ref_func(data,
                             kernel,
                             input_zero_point,
@@ -152,8 +163,9 @@ def get_funcs(data_shape,
                             data_layout,
                             kernel_layout,
                             out_dtype,
-                            groups,
-                            channels)
+                            channels,
+                            groups)
+
     return (ref_func, qnn_func)
 
 def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape,
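
Since get_funcs no longer takes channels from the caller, it infers the value
from the kernel shape and layout, as added above. The same logic restated as
a standalone helper (hypothetical name; note that in the grouped/depthwise
cases this file exercises, the output channel count equals the group count):

    def infer_channels(kernel_shape, kernel_layout, groups):
        # Grouped (depthwise) tests in this file construct kernels whose
        # output channel count equals the number of groups.
        if groups > 1:
            return groups
        # Output-channel axis for each supported kernel layout.
        out_axis = {"OIHW": 0, "HWIO": 3, "HWOI": 2}
        if kernel_layout not in out_axis:
            raise NotImplementedError(kernel_layout)
        return kernel_shape[out_axis[kernel_layout]]
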
@@ -418,7 +430,7 @@ def test_layout():
         verify(ref_func, qnn_func, data_shape, data_dtype,
                 kernel_shape, kernel_dtype)
 
-        # NHWC and HWIO layout. Used in depthwise conv.
+        # NHWC and HWOI layout. Used in depthwise conv.
         data_shape = (2, 2, 4, 1) # NHWC
         data_dtype = 'uint8'
         kernel_shape = (2, 2, 1, 1) # HWOI
@@ -568,6 +580,7 @@ def test_const_folding():
                                 data_layout="NCHW",
                                 kernel_layout="OIHW",
                                 out_dtype="int32",
+                                channels=kernel_shape[0],
                                 groups=1)
         folded_mod = transform.FoldConstant()(qnn_func)
         folded_func = folded_mod["main"]
@@ -787,8 +800,8 @@ def test_depthwise_depth_multiplier():
                                        data_layout="NCHW",
                                        kernel_layout="OIHW",
                                        out_dtype="int32",
-                                       groups=4,
-                                       channels=4)
+                                       groups=4)
+
         verify(ref_func, qnn_func, data_shape, data_dtype,
                 kernel_shape, kernel_dtype)
 
@@ -813,8 +826,7 @@ def test_depthwise_depth_multiplier():
                                        data_layout="NCHW",
                                        kernel_layout="OIHW",
                                        out_dtype="int32",
-                                       groups=8,
-                                       channels=8)
+                                       groups=8)
         verify(ref_func, qnn_func, data_shape, data_dtype,
                 kernel_shape, kernel_dtype)
 
@@ -839,8 +851,7 @@ def test_depthwise_depth_multiplier():
                                        data_layout="NHWC",
                                        kernel_layout="HWOI",
                                        out_dtype="int32",
-                                       groups=4,
-                                       channels=4)
+                                       groups=4)
         verify(ref_func, qnn_func, data_shape, data_dtype,
                 kernel_shape, kernel_dtype)
 
@@ -864,8 +875,7 @@ def test_depthwise_depth_multiplier():
                                        data_layout="NHWC",
                                        kernel_layout="HWOI",
                                        out_dtype="int32",
-                                       groups=8,
-                                       channels=8)
+                                       groups=8)
         verify(ref_func, qnn_func, data_shape, data_dtype,
                 kernel_shape, kernel_dtype)
 
@@ -888,6 +898,7 @@ def test_per_channel_kernel_scale():
                 input_scale=relay.const(2.0, 'float32'),
                 kernel_scale=kernel_scales,
                 kernel_size=(2, 2),
+                channels=kernel_shape[0],
                 padding=(0, 0),
                 strides=(1, 1),
                 dilation=(1, 1),
diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py
index 38fdb7d..e5893c9 100644
--- a/tests/python/relay/test_pass_qnn_legalize.py
+++ b/tests/python/relay/test_pass_qnn_legalize.py
@@ -107,6 +107,7 @@ def test_qnn_legalize_qnn_conv2d():
                 input_scale=relay.const(1.0, 'float32'),
                 kernel_scale=relay.const(1.0, 'float32'),
                 kernel_size=(3, 3),
+                channels=kernel_shape[0],
                 strides=(1, 1),
                 dilation=(1, 1),
                 out_dtype='int32',