MXNet pre-quantized BERT (#6039)
authorAnimesh Jain <anijain@umich.edu>
Mon, 20 Jul 2020 22:59:55 +0000 (15:59 -0700)
committerGitHub <noreply@github.com>
Mon, 20 Jul 2020 22:59:55 +0000 (15:59 -0700)
* MXNet pre-quantized BERT

* Comments.

* Trigger.

* Retrigger CI

* Retrigger CI

* Retrigger CI

* Retrigger

include/tvm/relay/qnn/attrs.h
python/tvm/relay/frontend/mxnet.py
python/tvm/relay/frontend/nnvm_common.py
python/tvm/relay/qnn/op/qnn.py
src/relay/qnn/op/dequantize.cc
tests/python/frontend/mxnet/test_forward.py
tests/python/relay/test_op_qnn_dequantize.py

index 4b5cd89..c5213fe 100644 (file)
@@ -75,6 +75,19 @@ struct QuantizeAttrs : public tvm::AttrsNode<QuantizeAttrs> {
   }
 };
 
+/*! \brief Attribute for dequantize operator */
+struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
+  int axis;
+
+  TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") {
+    TVM_ATTR_FIELD(axis)
+        .describe(
+            "The channel axis for channel wise dequantization. Default value is -1,"
+            "which corresponds to the last axis.")
+        .set_default(-1);
+  }
+};
+
 }  // namespace qnn
 }  // namespace relay
 }  // namespace tvm
index 97b9d7a..327bcd4 100644 (file)
@@ -1944,18 +1944,27 @@ def _qnn_batch_norm(inputs, attrs):
 
 def _qnn_fully_connected(inputs, attrs, subgraphs, params):
 
-    def _get_input_scale_zp(_data, _inputs, _has_bias):
+    def _get_input_scale_zp(_data_dtype, _inputs, _has_bias):
         data_min_idx, data_max_idx = (3, 4) if _has_bias else (2, 3)
         data_min, data_max = _inputs[data_min_idx], _inputs[data_max_idx]
-        data_dtype = _infer_type(_data).checked_type.dtype
         _data_scale = get_mkldnn_uint8_scale(data_min, data_max) \
-            if data_dtype == 'uint8' \
+            if _data_dtype == 'uint8' \
             else get_mkldnn_int8_scale(data_min, data_max)
         _data_zp = 0
         return _data_scale, _data_zp
 
-    def _get_kernel_scale_zp(_kernel, _inputs, _has_bias):
+    def _get_kernel_scale_zp_tensor_quantized(_kernel, _inputs, _has_bias):
         kernel_dtype = _infer_type(_kernel).checked_type.dtype
+
+        if kernel_dtype != "int8":
+            raise tvm.error.OpNotImplemented(\
+                "Tensor wise quantized expects weights in int8 data type")
+
+        if isinstance(_kernel, tvm.relay.Call) and _kernel.op.name == "qnn.quantize":
+            _kernel_scale = _kernel.args[1].data.asnumpy()
+            _kernel_zp = _kernel.args[2].data.asnumpy()
+            return _kernel_scale, _kernel_zp
+
         kernel_min_idx, kernel_max_idx = (5, 6) if _has_bias else (4, 5)
         kernel_min_name = _get_name(_inputs[kernel_min_idx])
         kernel_min = params[kernel_min_name].asnumpy()[0]
@@ -1967,7 +1976,34 @@ def _qnn_fully_connected(inputs, attrs, subgraphs, params):
         _kernel_zp = 0
         return _kernel_scale, _kernel_zp
 
+    def _get_kernel_scale_zp_channel_quantized(_kernel, _bias, _data_scale):
+        kernel_dtype = _infer_type(_kernel).checked_type.dtype
+        if kernel_dtype != "float32":
+            raise tvm.error.OpNotImplemented(\
+                "Channel wise quantized expects weights in float32 data type")
+
+        # Get the FP32 values, calculate min/max and then channel quantize them
+        np_kernel = _infer_value(_kernel, params).asnumpy()
+        kernel_channel_min = np.amin(np_kernel, axis=(1, ))
+        kernel_channel_max = np.amax(np_kernel, axis=(1, ))
+
+        np_bias = None
+        if _bias is not None:
+            np_bias = _infer_value(_bias, params).asnumpy()
+        return quantize_conv_weights_bias_channel_mkldnn_from_var(_kernel,
+                                                                  np_bias,
+                                                                  kernel_channel_min,
+                                                                  kernel_channel_max,
+                                                                  _data_scale)
+
     def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale):
+        _bias = _inputs[2]
+        if isinstance(_bias, tvm.relay.Call) and _bias.op.name == "qnn.quantize":
+            _bias_scale = _bias.args[1].data.asnumpy()
+            _bias_requantize_scale = _bias_scale/(_data_scale * _kernel_scale)
+            _bias_requantize_scale = _expr.const(_bias_requantize_scale, dtype="float32")
+            return _bias_requantize_scale
+
         bias_min_name = _get_name(_inputs[7])
         bias_min = params[bias_min_name].asnumpy()[0]
         bias_max_name = _get_name(_inputs[8])
@@ -1987,16 +2023,48 @@ def _qnn_fully_connected(inputs, attrs, subgraphs, params):
         return res
     else:
         has_bias = not subgraph_dense_attrs.get_bool("no_bias", False)
-        # input
-        data = inputs[0]
-        data_scale, data_zp = _get_input_scale_zp(data, inputs, has_bias)
-        # kernel
-        kernel = inputs[1]
-        kernel_scale, kernel_zp = _get_kernel_scale_zp(kernel, inputs, has_bias)
         units = subgraph_dense_attrs.get_int("num_hidden")
+        is_flatten = subgraph_dense_attrs.get_bool("flatten", True)
+        enable_float_output = attrs.get_bool('enable_float_output', False)
+        is_channel_quantized = attrs.get_bool('channel_wise_quantize', False)
+
+        ########################
+        # Get data, kernel, bias
+        ########################
+        data, kernel = inputs[0], inputs[1]
+        bias = None
+        if has_bias:
+            bias = inputs[2]
+
+        ##############################
+        # Handle for shape of data > 2
+        ##############################
+        if is_flatten:
+            data = _op.nn.batch_flatten(data)
         data_shape = _infer_type(data).checked_type.shape
         if len(data_shape) > 2:
-            data = _op.nn.batch_flatten(data)
+            data = _op.reverse_reshape(data, [-1, 0])
+
+        ###############################
+        # Get data scale and zero point
+        ###############################
+        data_dtype = _infer_type(data).checked_type.dtype
+        data_scale, data_zp = _get_input_scale_zp(data_dtype, inputs, has_bias)
+
+        #################################
+        # Get weight scale and zero point
+        #################################
+        if is_channel_quantized:
+            kernel, kernel_scale, kernel_zp = _get_kernel_scale_zp_channel_quantized(kernel,
+                                                                                     bias,
+                                                                                     data_scale)
+        else:
+            kernel_scale, kernel_zp = _get_kernel_scale_zp_tensor_quantized(kernel, inputs,
+                                                                            has_bias)
+
+        ################
+        # Call QNN dense
+        ################
         res = relay.qnn.op.dense(data,
                                  kernel,
                                  input_zero_point=relay.const(data_zp, 'int32'),
@@ -2004,22 +2072,46 @@ def _qnn_fully_connected(inputs, attrs, subgraphs, params):
                                  input_scale=relay.const(data_scale, 'float32'),
                                  kernel_scale=relay.const(kernel_scale, 'float32'),
                                  units=units)
+
+        #################
+        # Handle bias add
+        #################
         if has_bias:
-            bias_data = inputs[2]
-            bias_requantize_scale = \
-                _get_bias_requantize_scale(inputs, data_scale, kernel_scale)
-            multiplied_bias = \
-                _op.multiply(_op.cast(bias_data, 'float32'), bias_requantize_scale)
-            rounded_bias = _op.round(multiplied_bias)
-            clipped_bias = _op.clip(rounded_bias,
-                                    a_min=tvm.tir.op.min_value('int32').value,
-                                    a_max=tvm.tir.op.max_value('int32').value)
-            requantized_bias = _op.cast(clipped_bias, 'int32')
-            res = _op.nn.bias_add(res, requantized_bias, axis=-1)
-        enable_float_output = attrs.get_bool('enable_float_output', False)
-        out_dtype = 'uint8' if attrs.get_bool('with_relu', False) else 'int8'
-        input_scale = np.float32(data_scale * kernel_scale)
-        if not enable_float_output:
+            if is_channel_quantized:
+                bias_scale = data_scale * kernel_scale
+                int32_bias = quantize_conv_bias_mkldnn_from_var(bias, bias_scale)
+                res = _op.nn.bias_add(res, int32_bias, axis=-1)
+            else:
+                bias_data = inputs[2]
+                bias_requantize_scale = \
+                    _get_bias_requantize_scale(inputs, data_scale, kernel_scale)
+                multiplied_bias = \
+                    _op.multiply(_op.cast(bias_data, 'float32'), bias_requantize_scale)
+                rounded_bias = _op.round(multiplied_bias)
+                clipped_bias = _op.clip(rounded_bias,
+                                        a_min=tvm.tir.op.min_value('int32').value,
+                                        a_max=tvm.tir.op.max_value('int32').value)
+                requantized_bias = _op.cast(clipped_bias, 'int32')
+                res = _op.nn.bias_add(res, requantized_bias, axis=-1)
+
+        ##############################################
+        # Dequantize if float32 output else Requantize
+        ##############################################
+        if enable_float_output:
+            output_scale = np.float32(data_scale * kernel_scale)
+            res = relay.qnn.op.dequantize(res,
+                                          relay.const(output_scale),
+                                          input_zero_point=relay.const(0, 'int32'),
+                                          axis=1)
+            if with_relu:
+                res = _op.nn.relu(res)
+        else:
+
+            if is_channel_quantized:
+                raise tvm.error.OpNotImplemented(\
+                    "Channel wise quantized dense with non float output is not supported yet")
+            out_dtype = 'uint8' if attrs.get_bool('with_relu', False) else 'int8'
+            input_scale = np.float32(data_scale * kernel_scale)
             min_output_range = attrs.get_float('min_calib_range')
             max_output_range = attrs.get_float('max_calib_range')
             output_scale = get_mkldnn_requantize_scale_outDtype(min_output_range,
@@ -2034,17 +2126,20 @@ def _qnn_fully_connected(inputs, attrs, subgraphs, params):
                 out_dtype=out_dtype)
             if with_relu:
                 res = _op.nn.relu(res)
-            return res, min_output_range, max_output_range
-        else:
-            output_scale = np.float32(data_scale * kernel_scale)
-            res = relay.qnn.op.dequantize(res,
-                                          relay.const(output_scale, 'float32'),
-                                          input_zero_point=relay.const(0, 'int32'))
-            if with_relu:
-                res = _op.nn.relu(res)
-            return res
 
 
+        ##############################
+        # Handle for shape of data > 2
+        ##############################
+        if len(data_shape) > 2:
+            new_shape = data_shape[:-1]
+            new_shape.append(units)
+            res = _op.reshape(res, new_shape)
+
+        if enable_float_output:
+            return res
+        return res, min_output_range, max_output_range
+
 def _mx_broadcast_to(inputs, attrs):
     data = inputs[0]
     tgt_shape = attrs.get_int_tuple("shape", [])
index a2eea94..7dd9c02 100644 (file)
 # pylint: disable=invalid-name, import-self, len-as-condition
 """Utility functions common to NNVM and MxNet conversion."""
 import warnings
+from ... import error
+from ...tir.op import min_value
 from .. import expr as _expr
 from .. import op as _op
 from .common import get_relay_op
 from .common import infer_type as _infer_type
+from .common import infer_shape as _infer_shape
 
 def _warn_not_used(attr, op='nnvm'):
     err = "{} is ignored in {}.".format(attr, op)
@@ -57,9 +60,54 @@ def _init_op(new_op):
 def _softmax_op(new_op):
     """softmax/log_softmax"""
     def _impl(inputs, attrs, _dtype='float32'):
-        # TODO(@icemelon9): currently ignore the 2nd input to softmax for mxnet 1.6
-        # assert len(inputs) == 1
         axis = attrs.get_int("axis", -1)
+        use_length = attrs.get_bool("use_length", False)
+        if use_length:
+            # The second arg is valid_length. We can use sequence mask to mask the input before
+            # computing softmax
+            assert len(inputs) == 2
+
+            data = inputs[0]
+            length = inputs[1]
+            data_shape = _infer_shape(data)
+            data_dtype = _infer_type(data).checked_type.dtype
+            length_shape = _infer_shape(length)
+
+            if axis < 0:
+                axis = len(data_shape) + axis
+
+            data_ndims = len(data_shape)
+            length_ndims = len(length_shape)
+
+            # Sequence_mask supports axis = 0 and 1 and requires data to be in specific format.
+            if axis == data_ndims - 1 and data_ndims > 2 and length_ndims == 2:
+                new_batch_size = 1
+                for dim in range(length_ndims):
+                    assert data_shape[dim] == length_shape[dim]
+                    new_batch_size *= data_shape[dim]
+
+                # Reshape the data and length to satisfy sequence mask
+                data = _op.reshape(data, newshape=(new_batch_size, -1))
+                length = _op.reshape(length, newshape=(new_batch_size))
+
+                # Input data is now 2D, we can set the axis = 1
+                axis = 1
+            elif data_ndims > 2:
+                raise error.OpNotImplemented(\
+                        "Operator softmax with use_length=True is supported only for axis -1")
+
+            res = _op.sequence_mask(data=data,
+                                    valid_length=length,
+                                    mask_value=float(min_value(data_dtype).value),
+                                    axis=axis)
+
+            # Apply softmax
+            res = new_op(res, axis=axis)
+
+            # Reshape back to input data shape
+            if len(data_shape) > 2:
+                return _op.reshape(res, newshape=data_shape)
+            return res
         return new_op(inputs[0], axis=axis)
     return _impl
 
index 5a3106d..14d74bf 100644 (file)
@@ -121,7 +121,8 @@ def quantize(data,
 
 def dequantize(data,
                input_scale,
-               input_zero_point):
+               input_zero_point,
+               axis=-1):
     r""" Dequantize op
     This operator takes quantized int8 and unit8 as input and produces
     dequantized float32 as output. The output shape is the same as input shape. The input
@@ -135,6 +136,8 @@ def dequantize(data,
         The input zero_point.
     input_scale : tvm.relay.Expr
         The input scale.
+    axis : int
+        The channel axis for quantization. Default value is -1 which corresponds to the last axis.
     Returns
     -------
     result : tvm.relay.Expr
@@ -143,7 +146,8 @@ def dequantize(data,
 
     return _make.dequantize(data,
                             input_scale,
-                            input_zero_point)
+                            input_zero_point,
+                            axis)
 
 
 def concatenate(data,
index 7c014d7..da804da 100644 (file)
@@ -34,6 +34,8 @@ namespace tvm {
 namespace relay {
 namespace qnn {
 
+TVM_REGISTER_NODE_TYPE(DequantizeAttrs);
+
 bool DequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                    const TypeReporter& reporter) {
   CHECK_EQ(types.size(), 4);
@@ -45,9 +47,16 @@ bool DequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
       << "Input type should be one of the quantized types [unit8, int8, int32] but was "
       << input_dtype;
 
-  // Check the types of scale and zero points.
-  CHECK(IsScalarType(types[1], DataType::Float(32)));  // input_scale
-  CHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
+  const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
+  int axis = dequantize_attrs->axis;
+  axis = (axis == -1) ? data->shape.size() - 1 : axis;
+  CHECK_LT(axis, static_cast<int>(data->shape.size()))
+      << "axis " << dequantize_attrs->axis << " is out of range";
+  CHECK_GE(axis, 0) << "axis " << dequantize_attrs->axis << " is out of range";
+
+  // Check and assign types for scale and zero points.
+  AssignType(types[1], DataType::Float(32), data->shape[axis], reporter);  // scale
+  AssignType(types[2], DataType::Int(32), data->shape[axis], reporter);    // zero point
 
   const Array<tvm::PrimExpr> oshape = data->shape;
   // assign output type, output will always be float 32.
@@ -55,16 +64,34 @@ bool DequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   return true;
 }
 
-Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point) {
+Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis) {
   // real_value = scale * (quantized_value - zero_point)
   // A more detailed explanation can be found here -
   // https://github.com/google/gemmlowp/blob/master/doc/quantization.md
+  auto attrs = make_object<DequantizeAttrs>();
+  attrs->axis = axis;
   static const Op& op = Op::Get("qnn.dequantize");
-  return Call(op, {data, input_scale, input_zero_point}, Attrs(), {});
+  return Call(op, {data, input_scale, input_zero_point}, Attrs(attrs), {});
 }
 
 Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,
-                     const Expr& input_zero_point) {
+                     const Expr& input_zero_point, const Array<IndexExpr>& input_shape,
+                     const DequantizeAttrs* attrs) {
+  const auto axis = attrs->axis;
+
+  size_t n_dim = input_shape.size();
+
+  // Expand scale and zero point if the input tensor is channel quantized
+  auto expanded_input_scale = input_scale;
+  if (!IsConstScalar(input_scale)) {
+    expanded_input_scale = ExpandBiasToMatchAxis(input_scale, n_dim, {axis});
+  }
+
+  auto expanded_input_zero_point = input_zero_point;
+  if (!IsConstScalar(input_zero_point)) {
+    expanded_input_zero_point = ExpandBiasToMatchAxis(input_zero_point, n_dim, {axis});
+  }
+
   auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), input_zero_point);
   auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), input_scale);
   return scaled_output;
@@ -77,7 +104,20 @@ Expr DequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   auto& input_scale = new_args[1];
   auto& input_zero_point = new_args[2];
   CHECK_EQ(types.size(), 4);
-  return DequantizeLower(data, input_scale, input_zero_point);
+
+  // Get attrs.
+  const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
+  CHECK(dequantize_attrs != nullptr);
+
+  // Find input shape.
+  CHECK_EQ(types.size(), 4);
+  auto in_type = types[0];
+  auto in_tensor_type = in_type.as<TensorTypeNode>();
+  CHECK(in_tensor_type != nullptr) << "Type information missing."
+                                   << " Please run infer_type pass.";
+  Array<IndexExpr> input_shape = in_tensor_type->shape;
+
+  return DequantizeLower(data, input_scale, input_zero_point, input_shape, dequantize_attrs);
 }
 
 RELAY_REGISTER_OP("qnn.dequantize")
@@ -85,6 +125,7 @@ RELAY_REGISTER_OP("qnn.dequantize")
 The input is always quantized (int8, uint8) and will be converted to float32 given input scale and zero_point.
 - **data**: Quantized tensor of any shape to dequantize. The input data can be of floating point
 )code" TVM_ADD_FILELINE)
+    .set_attrs_type<DequantizeAttrs>()
     .set_num_inputs(3)
     .add_argument("data", "Tensor", "The tensor to dequantize.")
     .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.")
index c8bbf88..48ad736 100644 (file)
@@ -1373,6 +1373,43 @@ def test_forward_box_decode():
     verify((1, 10, 4), (1, 10, 4), in_format="center")
 
 
+def test_forward_softmax():
+    def verify(data_shape, axis, use_length, length):
+        dtype = "float32"
+        x = np.random.uniform(low=-100, high=100, size=data_shape).astype(dtype)
+        if use_length:
+            ref_res = mx.nd.softmax(data=mx.nd.array(x),
+                                    length=mx.nd.array(length, dtype="int32"),
+                                    axis=axis, use_length=use_length)
+            mx_sym = mx.symbol.softmax(data=mx.sym.var("data"),
+                                       length=mx.sym.var("length"),
+                                       axis=axis, use_length=use_length)
+            shape_dict = {"data": data_shape, "length": (length.shape)}
+            dtype_dict = {"data": dtype, "length": "int32"}
+            mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict, dtype_dict)
+        else:
+            ref_res = mx.nd.softmax(data=mx.nd.array(x), axis=axis)
+            mx_sym = mx.symbol.softmax(data=mx.sym.var("data"), axis=axis)
+            shape_dict = {"data": data_shape}
+            mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
+
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                if use_length:
+                    op_res = intrp.evaluate()(x, length)
+                else:
+                    op_res = intrp.evaluate()(x)
+
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
+
+    verify((2, 3, 5), -1, False, None)
+    verify((2, 3, 5), 2, False, None)
+    verify((2, 3), -1, True, np.array([2, 1]).astype('int32'))
+    verify((2, 3, 4), -1, True, np.array([[3, 4, 2], [2, 1, 1]]).astype('int32'))
+    verify((2, 3, 4), 2, True, np.array([[3, 4, 2], [1, 2, 1]]).astype('int32'))
+
+
 if __name__ == '__main__':
     test_forward_mlp()
     test_forward_vgg()
@@ -1449,3 +1486,4 @@ if __name__ == '__main__':
     test_forward_box_decode()
     test_forward_amp_multicast()
     test_forward_amp_cast()
+    test_forward_softmax()
index 3c82b7f..361d6f0 100644 (file)
@@ -21,13 +21,14 @@ import numpy as np
 from tvm import relay
 from tvm.contrib import graph_runtime
 
-def quantize_test_driver(in_dtype, quant_args, in_data, verify_output_data):
+def dequantize_test_driver(in_dtype, quant_args, in_data, verify_output_data, axis):
     shape = in_data.shape
     input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
     input_zero_point = relay.const(quant_args['in_zero_point'], 'int32')
     input_scale = relay.const(quant_args['in_scale'], 'float32')
     quantized_output = relay.qnn.op.dequantize(input_data, input_scale=input_scale,
-                                               input_zero_point=input_zero_point)
+                                               input_zero_point=input_zero_point,
+                                               axis=axis)
     mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
     mod = tvm.IRModule.from_expr(mod)
     with tvm.transform.PassContext(opt_level=3):
@@ -48,8 +49,8 @@ def test_uint8_to_float32():
         .astype('float32') \
         .reshape((2, 5))
     quant_args = {"in_zero_point":127, "in_scale":0.5}
-    quantize_test_driver(in_dtype='uint8', quant_args=quant_args, in_data=data,
-                         verify_output_data=output)
+    dequantize_test_driver(in_dtype='uint8', quant_args=quant_args, in_data=data,
+                           verify_output_data=output, axis=-1)
 
 def test_int8_to_float32():
     data = np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127]) \
@@ -59,18 +60,31 @@ def test_int8_to_float32():
         .astype('float32') \
         .reshape((2, 5))
     quant_args = {"in_zero_point": -1, "in_scale": 0.5}
-    quantize_test_driver(in_dtype='int8', quant_args=quant_args, in_data=data,
-                         verify_output_data=output)
+    dequantize_test_driver(in_dtype='int8', quant_args=quant_args, in_data=data,
+                           verify_output_data=output, axis=-1)
 
 def test_int32_to_float32():
     data = np.array([113, 29, -1052]).astype('int32')
     output = np.array([0.6550452, 0.16810896, -6.098297]).astype('float32')
     quant_args = {"in_zero_point": 0, "in_scale": 0.0057968604}
-    quantize_test_driver(in_dtype='int32', quant_args=quant_args, in_data=data,
-                         verify_output_data=output)
+    dequantize_test_driver(in_dtype='int32', quant_args=quant_args, in_data=data,
+                           verify_output_data=output, axis=-1)
+
+
+def test_channelwise_axis_1():
+    data = np.transpose(np.array([0, 1, 2, 3, 4, 243, 247, 249, 250, 251]) \
+                        .astype('uint8').reshape((2,5)))
+    output = np.transpose(np.array([-63.5, -63, -62.5, -62, -61.5, 30, 31, 31.5, 31.75, 32]) \
+                         .astype('float32').reshape((2,5)))
+    quant_args = {"in_zero_point" : np.array([127, 123]).astype('int32'),
+                  "in_scale"      : np.array([0.5, 0.25]).astype('float32')}
+
+    dequantize_test_driver(in_dtype='uint8', quant_args=quant_args, in_data=data,
+                           verify_output_data=output, axis=1)
 
 
 if __name__ == "__main__":
     test_uint8_to_float32()
     test_int8_to_float32()
     test_int32_to_float32()
+    test_channelwise_axis_1()