[QNN] Add support for per channel weight scale in dense op (#4880)
authormasahi <masahi129@gmail.com>
Sat, 15 Feb 2020 01:12:06 +0000 (10:12 +0900)
committerGitHub <noreply@github.com>
Sat, 15 Feb 2020 01:12:06 +0000 (10:12 +0900)
* add test case for per channel dense

* add unit arg in tflite frontend

* update qnn legalize test

* fix output dim index

python/tvm/relay/frontend/tflite.py
python/tvm/relay/qnn/op/qnn.py
src/relay/qnn/op/dense.cc
tests/python/relay/test_op_qnn_dense.py
tests/python/relay/test_pass_qnn_legalize.py

index d889631..e92e4ce 100644 (file)
@@ -982,6 +982,7 @@ class OperatorConverter(object):
 
         weight_value = self.get_tensor_value(weight_tensor)
         weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)
+        weight_shape = _infer_shape(weight_expr)
 
         if input_tensor.qnn_params:
             out = _qnn.op.dense(in_expr, weight_expr,
@@ -989,6 +990,7 @@ class OperatorConverter(object):
                                 kernel_zero_point=weight_tensor.qnn_params['zero_point'],
                                 input_scale=input_tensor.qnn_params['scale'],
                                 kernel_scale=weight_tensor.qnn_params['scale'],
+                                units=weight_shape[0],
                                 out_dtype='int32')
         else:
             out = _op.nn.dense(in_expr, weight_expr)
index eaca625..a7529f6 100644 (file)
@@ -345,7 +345,7 @@ def dense(data,
           kernel_zero_point,
           input_scale,
           kernel_scale,
-          units=None,
+          units,
           out_dtype="int32"):
     """Qnn Dense operator.
     Applies a quantized linear transformation
@@ -371,7 +371,7 @@ def dense(data,
         stored for access to this during relay. This information is not
         needed in the pass pipeline after qnn.conv2d is lowered to the
         sequence of steps as in nn.conv2d. See also input_scale in Requantize.
-    units : int, optional
+    units : int
         Number of hidden units of the dense transformation.
     out_dtype : str, optional
         Specifies the output data type for mixed precision dense can be int32 or int16.
index b7a12e1..de3c4db 100644 (file)
@@ -55,7 +55,7 @@ bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   CHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
   CHECK(IsScalarType(types[3], DataType::Int(32)));    // kernel_zero_point
   CHECK(IsScalarType(types[4], DataType::Float(32)));  // input_scale
-  CHECK(IsScalarType(types[5], DataType::Float(32)));  // kernel_scale
+  AssignType(types[5], DataType::Float(32), param->units, reporter);
 
   CHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";
 
index 0e7c284..43600cb 100644 (file)
@@ -75,52 +75,8 @@ def make_configuration(quantized_data,
     return config
 
 
-def make_uint_configuration(use_bias=False, requantize_output=False):
-    input_shape, kernel_shape, output_shape = (2, 10), (3,10), (2, 3)
-    input_zero_point, kernel_zero_point = 127, 127
-    input_scale = 0.5
-    kernel_scale = 0.5
-    output_scale = 1.0
-    in_dtype = 'uint8'
-    out_dtype = 'int32' if not requantize_output else 'uint8'
-    units = 3
-    quantized_data_np = np.array([129, 131, 133, 135, 137, 139, 141, 143, 109, 107,
-                                  129, 131, 133, 135, 137, 139, 141, 111, 145, 107]) \
-        .astype(in_dtype) \
-        .reshape(input_shape)
-    quantized_kernel_np = np.array([129, 131, 133, 135, 137, 139, 141, 143, 145, 147,
-                                    129, 131, 133, 135, 137, 139, 141, 143, 145, 147,
-                                    129, 131, 133, 135, 137, 139, 141, 143, 145, 147]) \
-        .astype(in_dtype) \
-        .reshape(kernel_shape)
-    bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units, )) if use_bias else None
-    requant_params = make_requantize_params(input_scale * kernel_scale, output_scale, 127, 'uint8') if requantize_output else None
-
-    if requantize_output:
-        assert use_bias
-        output = np.array([151, 152, 153, 185, 186, 187])
-    elif use_bias:
-        output = np.array([96, 100, 104, 232, 236, 240 ])
-    else:
-        output = np.array([92, 92, 92, 228, 228, 228 ])
-    output = output.astype(out_dtype).reshape(output_shape)
-    return make_configuration(quantized_data=quantized_data_np,
-                              quantized_kernel=quantized_kernel_np,
-                              dtype=in_dtype,
-                              input_shape=input_shape,
-                              kernel_shape=kernel_shape,
-                              input_zero_point=input_zero_point,
-                              kernel_zero_point=kernel_zero_point,
-                              input_scale=input_scale,
-                              kernel_scale= kernel_scale,
-                              units=units,
-                              output=output,
-                              bias=bias,
-                              requantize=requant_params)
-
-
-def make_int_configuration(use_bias=False, requantize_output=False):
-    input_shape, kernel_shape, output_shape = (2, 10), (3,10), (2, 3)
+def make_int_configuration(use_bias=False, requantize_output=False, per_channel=False):
+    input_shape, kernel_shape, output_shape = (2, 10), (3, 10), (2, 3)
     input_zero_point, kernel_zero_point = -1, -1
     in_dtype = 'int8'
     out_dtype = 'int32' if not requantize_output else 'int8'
@@ -138,15 +94,22 @@ def make_int_configuration(use_bias=False, requantize_output=False):
     kernel_scale = 0.5
     output_scale = 1.0
     bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units, )) if use_bias else None
-    requant_params = make_requantize_params(input_scale * kernel_scale, output_scale, -1, 'int8') if requantize_output else None
 
-    if requantize_output:
+    if per_channel:
+        assert use_bias and requantize_output
+        kernel_scale = np.array([0.5, 0.3, 0.4], dtype=np.float32)
+        output = np.array([23, 14, 20, 57, 34, 47])
+    elif requantize_output:
         assert use_bias
         output = np.array([23, 24, 25, 57, 58, 59])
     elif use_bias:
-        output = np.array([96, 100, 104, 232, 236, 240 ])
+        output = np.array([96, 100, 104, 232, 236, 240])
     else:
-        output = np.array([92, 92, 92, 228, 228, 228 ])
+        output = np.array([92, 92, 92, 228, 228, 228])
+
+    requant_params = make_requantize_params(input_scale * kernel_scale,
+                                            output_scale, -1, 'int8') if requantize_output else None
+
     output = output.astype(out_dtype).reshape(output_shape)
     return make_configuration(quantized_data=quantized_data_np,
                               quantized_kernel=quantized_kernel_np,
@@ -206,8 +169,8 @@ def qnn_dense_driver(test_configuration):
     with relay.build_config(opt_level=2):
         graph, lib, params = relay.build(mod, "llvm", params=None)
         mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
-        mod.set_input(quantized_data_name,test_configuration[quantized_data_name])
-        mod.set_input(quantized_kernel_name,test_configuration[quantized_kernel_name])
+        mod.set_input(quantized_data_name, test_configuration[quantized_data_name])
+        mod.set_input(quantized_kernel_name, test_configuration[quantized_kernel_name])
         if test_configuration[bias_name] is not None:
             mod.set_input(bias_name, test_configuration[bias_name])
         mod.set_input(**params)
@@ -241,7 +204,15 @@ def test_qnn_dense_with_requantized_output():
         qnn_dense_driver(int8_requantized_output_with_bias_params)
 
 
+def test_per_channel_weight_scale():
+    with TempOpAttr("qnn.dense", "FTVMQnnLegalize", legalize_qnn_dense):
+        config = make_int_configuration(use_bias=True, requantize_output=True,
+                                        per_channel=True)
+        qnn_dense_driver(config)
+
+
 if __name__ == "__main__":
     test_qnn_dense_without_bias()
     test_qnn_dense_with_bias()
     test_qnn_dense_with_requantized_output()
+    test_per_channel_weight_scale()
index e5893c9..dee19f7 100644 (file)
@@ -191,6 +191,7 @@ def test_qnn_legalize_qnn_dense():
                 kernel_zero_point=relay.const(1, 'int32'),
                 input_scale=relay.const(1, 'float32'),
                 kernel_scale=relay.const(1, 'float32'),
+                units=kernel_shape[0],
                 out_dtype='int32')
 
         mod = relay.Function(relay.analysis.free_vars(func), func)