{QNN] Making scale/zero_points as expr instead of attrs. (#4611)
authorAnimesh Jain <anijain@umich.edu>
Fri, 3 Jan 2020 13:39:56 +0000 (05:39 -0800)
committerWuwei Lin <wuwei@apache.org>
Fri, 3 Jan 2020 13:39:56 +0000 (22:39 +0900)
27 files changed:
include/tvm/relay/qnn/attrs.h
python/tvm/relay/frontend/mxnet_qnn_op_utils.py
python/tvm/relay/frontend/tflite.py
python/tvm/relay/qnn/op/__init__.py
python/tvm/relay/qnn/op/legalizations.py
python/tvm/relay/qnn/op/qnn.py
python/tvm/relay/util.py [moved from python/tvm/relay/qnn/op/op_attrs.py with 53% similarity]
src/relay/pass/pattern_util.h
src/relay/qnn/op/add.cc
src/relay/qnn/op/concatenate.cc
src/relay/qnn/op/convolution.cc
src/relay/qnn/op/dense.cc
src/relay/qnn/op/dequantize.cc
src/relay/qnn/op/mul.cc
src/relay/qnn/op/op_common.h
src/relay/qnn/op/quantize.cc
src/relay/qnn/op/requantize.cc
src/relay/qnn/util.h
tests/python/relay/test_op_qnn_add.py
tests/python/relay/test_op_qnn_concatenate.py
tests/python/relay/test_op_qnn_conv2d.py
tests/python/relay/test_op_qnn_dense.py
tests/python/relay/test_op_qnn_dequantize.py
tests/python/relay/test_op_qnn_mul.py
tests/python/relay/test_op_qnn_quantize.py
tests/python/relay/test_op_qnn_requantize.py
tests/python/relay/test_pass_qnn_legalize.py

index 66029c8..dbc8ded 100644 (file)
@@ -33,22 +33,10 @@ namespace qnn {
 
 /*! \brief Attribute for requantize operator */
 struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
-  double input_scale;
-  int32_t input_zero_point;
-  double output_scale;
-  int32_t output_zero_point;
   std::string rounding;
   DataType out_dtype;
 
   TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") {
-    TVM_ATTR_FIELD(input_scale)
-        .describe("The scale of the input tensor.");
-    TVM_ATTR_FIELD(input_zero_point)
-        .describe("The zero point of the input tensor.");
-    TVM_ATTR_FIELD(output_scale)
-        .describe("The scale of the output tensor.");
-    TVM_ATTR_FIELD(output_zero_point)
-        .describe("The zero point of the output tensor.");
     TVM_ATTR_FIELD(rounding).set_default("UPWARD")
         .describe("Defines the rounding direction when the value is midway between"
                   "two representable values. There are two supported modes - UPWARD"
@@ -67,175 +55,11 @@ struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
 
 /*! \brief Attribute for quantize operator */
 struct QuantizeAttrs : public tvm::AttrsNode<QuantizeAttrs> {
-  int32_t output_zero_point;
-  double output_scale;
   DataType out_dtype;
 
   TVM_DECLARE_ATTRS(QuantizeAttrs, "relay.attrs.QuantizeAttrs") {
     TVM_ATTR_FIELD(out_dtype)
       .describe("Output data type, can be one of [int8 or uint8].");
-    TVM_ATTR_FIELD(output_zero_point)
-      .describe("The zero_point for the activation of this op.");
-    TVM_ATTR_FIELD(output_scale)
-      .describe("The scale for the activation of this op.");
-  }
-};
-
-/*! \brief Attribute for dequantize operator */
-struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
-  int32_t input_zero_point;
-  double input_scale;
-
-  TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") {
-    TVM_ATTR_FIELD(input_zero_point)
-      .describe("The zero_point for the input tensor of this op.");
-    TVM_ATTR_FIELD(input_scale)
-      .describe("The scale for the input tensor of this op.");
-  }
-};
-
-/*! \brief Attributes used in QNN concatenate operator */
-struct QnnConcatenateAttrs : public tvm::AttrsNode<QnnConcatenateAttrs> {
-  Array<tvm::Expr> input_scales;
-  Array<tvm::Expr> input_zero_points;
-  double output_scale;
-  int32_t output_zero_point;
-  int axis;
-
-  TVM_DECLARE_ATTRS(QnnConcatenateAttrs, "relay.attrs.QnnConcatenateAttrs") {
-    TVM_ATTR_FIELD(input_scales)
-        .describe("The list of scales of input quantized tensors.");
-    TVM_ATTR_FIELD(input_zero_points)
-        .describe("The list of zero points of input quantized tensors.");
-    TVM_ATTR_FIELD(output_zero_point)
-      .describe("The zero_point for the output tensor.");
-    TVM_ATTR_FIELD(output_scale)
-      .describe("The scale for the output tensor.");
-    TVM_ATTR_FIELD(axis)
-        .describe("The axis at which the input arrays are concatenated."
-                  "Should lie in range `[-ndim, ndim)`.")
-        .set_default(0);
-  }
-};  // struct QnnConcatenateAttrs
-
-/*! \brief Attribute for QNN Conv2d operator */
-struct QnnConv2DAttrs : public tvm::AttrsNode<QnnConv2DAttrs> {
-  // Traditional conv2d attributes.
-  Array<IndexExpr> strides;
-  Array<IndexExpr> padding;
-  Array<IndexExpr> dilation;
-  int groups;
-  IndexExpr channels;
-  Array<IndexExpr> kernel_size;
-  std::string data_layout;
-  std::string kernel_layout;
-  std::string out_layout;
-  DataType out_dtype;
-
-  // Quantization related attributes.
-  int32_t input_zero_point;
-  int32_t kernel_zero_point;
-  // The input tensor scale and kernel tensor scales are stored
-  // for easy access to this information.
-  double input_scale;
-  double kernel_scale;
-
-  TVM_DECLARE_ATTRS(QnnConv2DAttrs, "relay.attrs.QnnConv2DAttrs") {
-    TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
-        .describe("Specifies the strides of the convolution.");
-    TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
-        .describe("If padding is non-zero, then the input is implicitly zero-padded"
-                  "on both sides for padding number of points");
-    TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
-        .describe("Specifies the dilation rate to use for dilated convolution.");
-    TVM_ATTR_FIELD(groups).set_default(1)
-        .describe("Controls the connections between inputs and outputs."
-                  "At groups=1, all inputs are convolved to all outputs."
-                  "At groups=2, the operation becomes equivalent to having two convolution"
-                  "layers side by side, each seeing half the input channels, and producing"
-                  "half the output channels, and both subsequently concatenated.");
-    TVM_ATTR_FIELD(channels)
-        .describe("The number of output channels in the convolution."
-                  " If it is not set, inferred by shape of the weight.")
-        .set_default(NullValue<IndexExpr>());
-    TVM_ATTR_FIELD(kernel_size)
-        .describe("Specifies the dimensions of the convolution window.")
-        .set_default(NullValue<Array<IndexExpr> >());
-    TVM_ATTR_FIELD(data_layout).set_default("NCHW")
-        .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
-                  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
-                  "dimensions respectively. Convolution is applied on the 'H' and"
-                  "'W' dimensions.");
-    TVM_ATTR_FIELD(kernel_layout).set_default("OIHW")
-        .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
-                  "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
-                  "dimensions respectively.");
-    TVM_ATTR_FIELD(out_layout).set_default("")
-        .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
-                  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
-                  "dimensions respectively. Default to be same as input layout.");
-    TVM_ATTR_FIELD(out_dtype)
-        .set_default(NullValue<DataType>())
-        .describe("Output data type, set to explicit type under mixed precision setting");
-    TVM_ATTR_FIELD(input_zero_point)
-        .describe("The zero point of the input tensor.");
-    TVM_ATTR_FIELD(kernel_zero_point)
-        .describe("The zero point of the kernel tensor.");
-    TVM_ATTR_FIELD(input_scale)
-      .describe("The quantization scale for the input tensor.");
-    TVM_ATTR_FIELD(kernel_scale)
-      .describe("The quantization scale for the weight tensor.");
-  }
-};
-
-/*! \brief Attribute for QNN binary operator */
-struct QnnBinaryOpAttrs : public tvm::AttrsNode<QnnBinaryOpAttrs> {
-  int32_t lhs_zero_point;
-  double lhs_scale;
-  int32_t rhs_zero_point;
-  double rhs_scale;
-  int32_t output_zero_point;
-  double output_scale;
-
-  TVM_DECLARE_ATTRS(QnnBinaryOpAttrs, "relay.attrs.QnnBinaryOpAttrs") {
-    TVM_ATTR_FIELD(lhs_zero_point)
-      .describe("The zero_point for the lhs input tensor of this op.");
-    TVM_ATTR_FIELD(lhs_scale)
-      .describe("The scale for the lhs input tensor of this op.");
-    TVM_ATTR_FIELD(rhs_zero_point)
-      .describe("The zero_point for the rhs input tensor of this op.");
-    TVM_ATTR_FIELD(rhs_scale)
-      .describe("The scale for the rhs input tensor of this op.");
-    TVM_ATTR_FIELD(output_zero_point)
-      .describe("The zero_point for the activation of this op.");
-    TVM_ATTR_FIELD(output_scale)
-      .describe("The scale for the activation of this op.");
-  }
-};
-
-/*! \brief Attributes for qnn dense operator */
-struct QnnDenseAttrs : public tvm::AttrsNode<QnnDenseAttrs> {
-  IndexExpr units;
-  DataType out_dtype;
-  // Quantization related attributes.
-  int32_t input_zero_point;
-  int32_t kernel_zero_point;
-  double input_scale;
-  double kernel_scale;
-
-  TVM_DECLARE_ATTRS(QnnDenseAttrs, "relay.attrs.QnnDenseAttrs") {
-    TVM_ATTR_FIELD(units)
-      .describe("Number of hidden units of the dense transformation.");
-    TVM_ATTR_FIELD(out_dtype)
-      .describe("Output data type, set to explicit type under mixed precision setting");
-    TVM_ATTR_FIELD(input_zero_point)
-      .describe("The zero point of the input tensor.");
-    TVM_ATTR_FIELD(kernel_zero_point)
-      .describe("The zero point of the kernel tensor.");
-    TVM_ATTR_FIELD(input_scale)
-      .describe("The input tensor scale.");
-    TVM_ATTR_FIELD(kernel_scale)
-      .describe("The kernel tensor scale.");
   }
 };
 
index e2aaa79..73d18a4 100644 (file)
@@ -20,6 +20,7 @@
 """
 
 import numpy as np
+from tvm import relay
 from tvm.relay.qnn.op.qnn import dequantize
 
 zero_centered_uint8_quantized_range = np.float32(255)
@@ -54,8 +55,8 @@ def _dequantize_zero_centered(data,
 
     real_range = np.max([np.abs(np.float32(data_min)),
                          np.abs(np.float32(data_max))])
-    scale = np.divide(real_range, quantized_range)
-    zero_point = 0
+    scale = relay.const(np.divide(real_range, quantized_range), 'float32')
+    zero_point = relay.const(0, 'int32')
     return dequantize(data, scale, zero_point)
 
 
@@ -186,9 +187,11 @@ def _dequantize_mxnet_min_max_uint8(data,
     max_limit = np.float64(iinfo.max)
     imin_range = np.float64(imin_range)
     imax_range = np.float64(imax_range)
-    scale = np.divide((imax_range - imin_range),
-                      (max_limit - min_limit))
-    zero_point = np.int(-1 * np.divide(imin_range, scale))
+    scale_val = np.divide((imax_range - imin_range),
+                          (max_limit - min_limit))
+    zero_point_val = np.int(-1 * np.divide(imin_range, scale_val))
+    scale = relay.const(scale_val, 'float32')
+    zero_point = relay.const(zero_point_val, 'int32')
     return dequantize(data, scale, zero_point)
 
 
index e2e01e5..284a8c8 100644 (file)
@@ -20,11 +20,13 @@ from __future__ import absolute_import as _abs
 import math
 import numpy as np
 import tvm
+from tvm import relay
 from .. import analysis
 from .. import expr as _expr
 from .. import module as _module
 from .. import op as _op
 from .. import qnn as _qnn
+from ..util import get_scalar_from_constant
 from ... import nd as _nd
 from .common import ExprTable
 from .common import infer_shape as _infer_shape
@@ -177,8 +179,8 @@ class OperatorConverter(object):
                 # Check that the scale and zero points are valid.
                 if scale != 0 or zero_point != 0:
                     qnn_params = dict()
-                    qnn_params['scale'] = scale
-                    qnn_params['zero_point'] = zero_point
+                    qnn_params['scale'] = relay.const(scale, 'float32')
+                    qnn_params['zero_point'] = relay.const(zero_point, 'int32')
             return_list.append(TensorWrapper(tensor_idx, tensor, buffer, qnn_params))
         return return_list
 
@@ -225,8 +227,16 @@ class OperatorConverter(object):
                                   .format(str(tensor_type)))
 
     def has_same_qnn_params(self, lhs_tensor, rhs_tensor):
-        return lhs_tensor.qnn_params['scale'] == rhs_tensor.qnn_params['scale'] and \
-                lhs_tensor.qnn_params['zero_point'] == rhs_tensor.qnn_params['zero_point']
+        lhs_scale = lhs_tensor.qnn_params['scale']
+        rhs_scale = rhs_tensor.qnn_params['scale']
+        lhs_zero_point = lhs_tensor.qnn_params['zero_point']
+        rhs_zero_point = rhs_tensor.qnn_params['zero_point']
+        lhs_scale_value = get_scalar_from_constant(lhs_scale)
+        rhs_scale_value = get_scalar_from_constant(rhs_scale)
+        lhs_zero_point_value = get_scalar_from_constant(lhs_zero_point)
+        rhs_zero_point_value = get_scalar_from_constant(rhs_zero_point)
+        return lhs_scale_value == rhs_scale_value and \
+                lhs_zero_point_value == rhs_zero_point_value
 
     def is_quantized(self, op):
         """Check if an input tensor is quantized."""
@@ -750,13 +760,11 @@ class OperatorConverter(object):
         weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)
 
         if input_tensor.qnn_params:
-            input_scale = input_tensor.qnn_params['scale']
-            kernel_scale = weight_tensor.qnn_params['scale']
             out = _qnn.op.dense(in_expr, weight_expr,
                                 input_zero_point=input_tensor.qnn_params['zero_point'],
                                 kernel_zero_point=weight_tensor.qnn_params['zero_point'],
-                                input_scale=input_scale,
-                                kernel_scale=kernel_scale,
+                                input_scale=input_tensor.qnn_params['scale'],
+                                kernel_scale=weight_tensor.qnn_params['scale'],
                                 out_dtype='int32')
         else:
             out = _op.nn.dense(in_expr, weight_expr)
@@ -783,11 +791,16 @@ class OperatorConverter(object):
 
         # Finally if the dense is quantized. Add a requantize at the end.
         if output_tensor.qnn_params:
-            input_scale = input_tensor.qnn_params['scale'] * weight_tensor.qnn_params['scale']
-            input_zero_point = 0
+            data_scale = input_tensor.qnn_params['scale']
+            weight_scale = weight_tensor.qnn_params['scale']
+            data_scale_val = get_scalar_from_constant(data_scale)
+            weight_scale_val = get_scalar_from_constant(weight_scale)
+            new_input_scale_val = data_scale_val * weight_scale_val
+            new_input_scale = relay.const(new_input_scale_val, 'float32')
+            new_input_zero_point = relay.const(0, 'int32')
             out = _qnn.op.requantize(out,
-                                     input_scale=input_scale,
-                                     input_zero_point=input_zero_point,
+                                     input_scale=new_input_scale,
+                                     input_zero_point=new_input_zero_point,
                                      output_scale=output_tensor.qnn_params['scale'],
                                      output_zero_point=output_tensor.qnn_params['zero_point'],
                                      out_dtype=output_tensor_type_str)
@@ -989,11 +1002,16 @@ class OperatorConverter(object):
 
         # Finally if the conv is quantized. Add a requantize at the end.
         if output_tensor.qnn_params:
-            input_scale = input_tensor.qnn_params['scale'] * weight_tensor.qnn_params['scale']
-            input_zero_point = 0
+            data_scale = input_tensor.qnn_params['scale']
+            weight_scale = weight_tensor.qnn_params['scale']
+            data_scale_val = get_scalar_from_constant(data_scale)
+            weight_scale_val = get_scalar_from_constant(weight_scale)
+            new_input_scale_val = data_scale_val * weight_scale_val
+            new_input_scale = relay.const(new_input_scale_val, 'float32')
+            new_input_zero_point = relay.const(0, 'int32')
             out = _qnn.op.requantize(out,
-                                     input_scale=input_scale,
-                                     input_zero_point=input_zero_point,
+                                     input_scale=new_input_scale,
+                                     input_zero_point=new_input_zero_point,
                                      output_scale=output_tensor.qnn_params['scale'],
                                      output_zero_point=output_tensor.qnn_params['zero_point'],
                                      out_dtype=output_tensor_type_str)
index 47badca..042dcb9 100644 (file)
@@ -20,4 +20,3 @@ from __future__ import absolute_import as _abs
 from .qnn import *
 from .op import register_qnn_legalize
 from . import legalizations
-from . import op_attrs
index 1bb28d8..f57fef2 100644 (file)
@@ -21,6 +21,7 @@ from __future__ import absolute_import
 import tvm
 from tvm import relay
 from .. import op as reg
+from ...util import get_scalar_from_constant
 
 #################################################
 # Register the functions for different operators.
@@ -76,20 +77,13 @@ def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
     """
 
     # Collect the input exprs.
-    data, kernel = inputs
-
-    input_zp = attrs['input_zero_point']
-    kernel_zp = attrs['kernel_zero_point']
+    data, kernel, input_zero_point, kernel_zero_point, _, _ = inputs
 
     shift_data = relay.subtract(relay.cast(data, dtype='int16'),
-                                relay.const(input_zp, 'int16'))
+                                relay.cast(input_zero_point, 'int16'))
     shift_kernel = relay.subtract(relay.cast(kernel, dtype='int16'),
-                                  relay.const(kernel_zp, 'int16'))
+                                  relay.cast(kernel_zero_point, 'int16'))
     new_attrs = {k : attrs[k] for k in attrs.keys()}
-    del new_attrs['kernel_zero_point']
-    del new_attrs['input_zero_point']
-    del new_attrs['input_scale']
-    del new_attrs['kernel_scale']
     return relay_op(shift_data, shift_kernel, **new_attrs)
 
 # Helper function to change dtypes to uint8 x int8. Intel VNNI instructions prefer this setting.
@@ -136,36 +130,36 @@ def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op):
         data_modified = relay.cast(data, 'int32')
         data_modified = relay.add(data_modified, relay.const(shift, 'int32'))
         data_modified = relay.cast(data_modified, out_dtype)
-        return (data_modified, zero_point + shift)
+        zero_point_val = get_scalar_from_constant(zero_point)
+        zero_point_modified = relay.const(zero_point_val + shift, 'int32')
+        return (data_modified, zero_point_modified)
 
     # Collect the dtypes.
     data_dtype = types[0].dtype
     kernel_dtype = types[1].dtype
 
     # Collect the input exprs.
-    data, kernel = inputs
+    data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale = inputs
 
     # VNNI supports u8 x i8 fast conv/MM. Don't do anything if it is already satisfied.
     if data_dtype == 'uint8' and kernel_dtype == 'int8':
         return None
 
     # Shift input if necessary.
-    input_zp = attrs['input_zero_point']
     if data_dtype == 'int8':
         # Compute (QA + 128) and (zp_a + 128)
-        data, input_zp = _shift(data, input_zp, 'uint8')
+        data, input_zero_point = _shift(data, input_zero_point, 'uint8')
 
     # Shift kernel if necessary.
-    kernel_zp = attrs['kernel_zero_point']
     if kernel_dtype == 'uint8':
         # Compute (QA - 128) and (zp_a - 128)
-        kernel, kernel_zp = _shift(kernel, kernel_zp, 'int8')
+        kernel, kernel_zero_point = _shift(kernel, kernel_zero_point, 'int8')
 
     # Call qnn.conv2d with modified inputs and zero points.
     new_attrs = {k : attrs[k] for k in attrs.keys()}
-    new_attrs['input_zero_point'] = input_zp
-    new_attrs['kernel_zero_point'] = kernel_zp
-    return relay_op(data, kernel, **new_attrs)
+    return relay_op(data, kernel,
+                    input_zero_point, kernel_zero_point,
+                    input_scale, kernel_scale, **new_attrs)
 
 # Helper function to change dtypes to be same. ARM dotprod instructions prefer this setting.
 def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
@@ -199,7 +193,9 @@ def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
         data_modified = relay.cast(data, 'int32')
         data_modified = relay.add(data_modified, relay.const(shift, 'int32'))
         data_modified = relay.cast(data_modified, out_dtype)
-        return (data_modified, zero_point + shift)
+        zero_point_val = get_scalar_from_constant(zero_point)
+        zero_point_modified = relay.const(zero_point_val + shift, 'int32')
+        return (data_modified, zero_point_modified)
 
     # Collect the dtypes.
     data_dtype = types[0].dtype
@@ -209,18 +205,18 @@ def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
         return None
 
     # Collect the input exprs.
-    data, kernel = inputs
+    data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale = inputs
 
     assert 'int8' in data_dtype and 'int8' in kernel_dtype, \
             "Qnn Conv2D/Dense only accepts uint8 or int8 inputs"
 
     # Shift input if necessary.
-    input_zp = attrs['input_zero_point']
-    data, input_zp = _shift(data, input_zp, kernel_dtype)
+    data, input_zero_point = _shift(data, input_zero_point, kernel_dtype)
 
     new_attrs = {k : attrs[k] for k in attrs.keys()}
-    new_attrs['input_zero_point'] = input_zp
-    return relay_op(data, kernel, **new_attrs)
+    return relay_op(data, kernel,
+                    input_zero_point, kernel_zero_point,
+                    input_scale, kernel_scale, **new_attrs)
 
 def is_fast_int8_on_intel():
     """ Checks whether the hardware has support for fast Int8 arithmetic operations. """
index c79c99c..2d950ba 100644 (file)
@@ -18,7 +18,6 @@
 """QNN dialect operators."""
 
 from __future__ import absolute_import as _abs
-from tvm.expr import FloatImm, IntImm
 from tvm.relay.expr import Tuple
 from . import _make
 
@@ -42,16 +41,16 @@ def requantize(data,
     data : tvm.relay.Expr
         The input data to the operator.
 
-    input_scale: float
+    input_scale: tvm.relay.Expr
         The quantization scale for the input tensor.
 
-    input_zero_point: int
+    input_zero_point: tvm.relay.Expr
         The zero point of the input tensor.
 
-    output_scale: float
+    output_scale: tvm.relay.Expr
         The quantization scale for the output tensor.
 
-    output_zero_point: int
+    output_zero_point: tvm.relay.Expr
         The zero point of the output tensor.
 
     rounding : string, optional
@@ -92,9 +91,9 @@ def quantize(data,
     ----------
     data : tvm.relay.Expr
         The input tensor to be quantized. Can be of type float32.
-    output_zero_point : int
+    output_zero_point : tvm.relay.Expr
         The output zero_point.
-    output_scale : float
+    output_scale : tvm.relay.Expr
         The output scale.
     out_dtype : str, optional
         The data type of the input tensor. Can be [int8, uint8]
@@ -122,9 +121,9 @@ def dequantize(data,
     ----------
     data : tvm.relay.Expr
         The input tensor to be dequantized. Can be of type [int8, uint8].
-    input_zero_point : int
+    input_zero_point : tvm.relay.Expr
         The output zero_point.
-    input_scale : float
+    input_scale : tvm.relay.Expr
         The output scale.
     Returns
     -------
@@ -150,16 +149,16 @@ def concatenate(data,
     data : Union(List[relay.Expr], Tuple[relay.Expr])
         The list of quantized tensors.
 
-    input_scales : List[float32]
+    input_scales : List[relay.Expr]
         The list of scales of input quantized tensors.
 
-    input_zero_points : List[int32]
+    input_zero_points : List[relay.Expr]
         The list of zero points of input quantized tensors.
 
-    output_scale : float32
+    output_scale : relay.Expr
         The scale of the output quantized tensor.
 
-    output_zero_point : int32
+    output_zero_point : relay.Expr
         The zero point of the output quantized tensor.
 
     axis : int
@@ -176,10 +175,12 @@ def concatenate(data,
         raise ValueError("relay.concatenate requires data to be non-empty.")
     if not isinstance(axis, int):
         raise ValueError("For now, we only support integer axis")
+    input_scales = list(input_scales)
+    input_zero_points = list(input_zero_points)
 
     return _make.concatenate(Tuple(data),
-                             [FloatImm("float64", x) for x in input_scales],
-                             [IntImm("int32", x) for x in input_zero_points],
+                             Tuple(input_scales),
+                             Tuple(input_zero_points),
                              output_scale,
                              output_zero_point,
                              axis)
@@ -218,22 +219,22 @@ def conv2d(data,
     kernel : tvm.relay.Expr
         The kernel expressions.
 
-    input_zero_point: int
+    input_zero_point: tvm.relay.Expr
            The zero point of the data distribution.
 
-    input_scale: float
+    kernel_zero_point: tvm.relay.Expr
+           The zero point of the quantized_kernel distribution.
+
+    input_scale: tvm.relay.Expr
            The scale for the input tensor. The scale for the input tensor is
            stored purely for convenience here. See more commentary below.
 
-    kernel_scale: float
+    kernel_scale: tvm.relay.Expr
            The scale for the weight tensor. The scale for the weight tensor is
            stored for access to this during relay. This information is not
            needed in the pass pipeline after qnn.conv2d is lowered to the
            sequence of steps as in nn.conv2d. See also input_scale in Requantize.
 
-    kernel_zero_point: int
-           The zero point of the quantized_kernel distribution.
-
     strides : tuple of int, optional
         The strides of convolution.
 
@@ -299,19 +300,22 @@ def add(lhs,
     lhs_scale: float
         The scale of the lhs quantized expr.
 
-    lhs_zero_point: int
+    lhs_scale: relay.Expr
+        The scale of the lhs quantized expr.
+
+    lhs_zero_point: relay.Expr
        The zero point of lhs quantized expr.
 
-    rhs_scale: float
+    rhs_scale: relay.Expr
         The scale of the rhs quantized expr.
 
-    rhs_zero_point: int
+    rhs_zero_point: relay.Expr
        The zero point of rhs quantized expr.
 
-    output_scale: float
+    output_scale: relay.Expr
         The scale of the output quantized expr.
 
-    output_zero_point: int
+    output_zero_point: relay.Expr
        The zero point of output quantized expr.
 
     Returns
@@ -347,13 +351,13 @@ def dense(data,
         The quantized input data to the operator.
     weight : tvm.relay.Expr
         The quantized weight expressions.
-    input_zero_point: int
+    input_zero_point: tvm.relay.Expr
         The input zero point.
-    kernel_zero_point: int
+    kernel_zero_point: tvm.relay.Expr
         The kernel zero point.
-    input_scale: float
+    input_scale: tvm.relay.Expr
         The scale for the input tensor.
-    kernel_scale: float
+    kernel_scale: tvm.relay.Expr
         The scale for the weight tensor. The scale for the weight tensor is
         stored for access to this during relay. This information is not
         needed in the pass pipeline after qnn.conv2d is lowered to the
@@ -391,22 +395,22 @@ def mul(lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point,
     rhs : relay.Expr
         The right hand side quantized input data.
 
-    lhs_scale: float
+    lhs_scale: relay.Expr
         The scale of the lhs quantized expr.
 
-    lhs_zero_point: int
+    lhs_zero_point: relay.Expr
        The zero point of lhs quantized expr.
 
-    rhs_scale: float
+    rhs_scale: relay.Expr
         The scale of the rhs quantized expr.
 
-    rhs_zero_point: int
+    rhs_zero_point: relay.Expr
        The zero point of rhs quantized expr.
 
-    output_scale: float
+    output_scale: relay.Expr
         The scale of the output quantized expr.
 
-    output_zero_point: int
+    output_zero_point: relay.Expr
        The zero point of output quantized expr.
 
     Returns
similarity index 53%
rename from python/tvm/relay/qnn/op/op_attrs.py
rename to python/tvm/relay/util.py
index 24ca3b4..b207182 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""The attributes node used for QNN operators"""
+# pylint: disable=wildcard-import, redefined-builtin, invalid-name
+""" Utility functions that are used across many directories. """
+from __future__ import absolute_import
+import numpy as np
+from . import expr as _expr
 
-from ....attrs import Attrs
-from ...base import register_relay_attr_node
-
-@register_relay_attr_node
-class QnnConv2DAttrs(Attrs):
-    """Attributes for qnn.conv2d"""
-
-@register_relay_attr_node
-class QnnDenseAttrs(Attrs):
-    """Attributes for qnn.dense"""
+def get_scalar_from_constant(expr):
+    """ Returns scalar value from Relay constant scalar. """
+    assert isinstance(expr, _expr.Constant) and not expr.data.shape, \
+            "Expr is not a constant scalar."
+    value = expr.data.asnumpy()
+    if value.dtype == np.dtype(np.int32):
+        return int(value)
+    if value.dtype == np.dtype(np.float32):
+        return float(value)
+    assert False, "Constant expr must be float32/int32"
+    return None  # To suppress pylint
index d3ec342..16356b6 100644 (file)
@@ -285,6 +285,7 @@ inline Expr Log(Expr e) {
 template <typename T>
 T GetScalarFromConstant(Expr expr) {
   const auto* n = expr.as<ConstantNode>();
+  CHECK(n) << "Expr must be a constant expr - " << AsText(expr, false);
   CHECK(n->is_scalar());
   return static_cast<T*>(n->data->data)[0];
 }
index f535567..e970e2b 100644 (file)
@@ -42,20 +42,18 @@ namespace qnn {
 Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                         const Array<tvm::relay::Type>& arg_types) {
   // Get the attrs.
-  CHECK_EQ(new_args.size(), 2);
+  CHECK_EQ(new_args.size(), 8);
   auto& lhs = new_args[0];
   auto& rhs = new_args[1];
-  const auto* binary_op_attrs = attrs.as<QnnBinaryOpAttrs>();
-  CHECK(binary_op_attrs != nullptr);
-  auto lhs_scale = binary_op_attrs->lhs_scale;
-  auto lhs_zero_point = binary_op_attrs->lhs_zero_point;
-  auto rhs_scale = binary_op_attrs->rhs_scale;
-  auto rhs_zero_point = binary_op_attrs->rhs_zero_point;
-  auto output_scale = binary_op_attrs->output_scale;
-  auto output_zero_point = binary_op_attrs->output_zero_point;
+  auto& lhs_scale = new_args[2];
+  auto& lhs_zero_point = new_args[3];
+  auto& rhs_scale = new_args[4];
+  auto& rhs_zero_point = new_args[5];
+  auto& output_scale = new_args[6];
+  auto& output_zero_point = new_args[7];
 
   // Get the input dtype and shape.
-  CHECK_EQ(arg_types.size(), 3);
+  CHECK_EQ(arg_types.size(), 9);
   auto tensor_type = arg_types[0].as<TensorTypeNode>();
   auto input_dtype = tensor_type->dtype;
   auto input_shape = tensor_type->shape;
@@ -82,7 +80,8 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
 
   // Requantize LHS if necessary.
   auto requantized_lhs = lhs;
-  if (lhs_scale != output_scale || lhs_zero_point != output_zero_point) {
+  if (!IsEqualScalar(lhs_scale, output_scale) ||
+      !IsEqualScalar(lhs_zero_point, output_zero_point)) {
     requantized_lhs = Requantize(lhs, input_shape, lhs_scale, lhs_zero_point, output_scale,
                                  output_zero_point, DataType::Int(32));
   } else {
@@ -91,7 +90,8 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
 
   // Requantize RHS if necessary.
   auto requantized_rhs = rhs;
-  if (rhs_scale != output_scale || rhs_zero_point != output_zero_point) {
+  if (!IsEqualScalar(rhs_scale, output_scale) ||
+      !IsEqualScalar(rhs_zero_point, output_zero_point)) {
     requantized_rhs = Requantize(rhs, input_shape, rhs_scale, rhs_zero_point, output_scale,
                                  output_zero_point, DataType::Int(32));
   } else {
@@ -101,9 +101,9 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   auto output = Add(requantized_lhs, requantized_rhs);
 
   // Subtract zero point.
-  if (output_zero_point != 0) {
-    auto output_zp = MakeConstantScalar(DataType::Int(32), output_zero_point);
-    output = Subtract(output, output_zp);
+  auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
+  if (!IsEqualScalar(output_zero_point, zero_scalar)) {
+    output = Subtract(output, output_zero_point);
   }
 
   // Go back to lower precision.
index 43d47e2..7dfa63f 100644 (file)
@@ -34,19 +34,48 @@ namespace tvm {
 namespace relay {
 namespace qnn {
 
-TVM_REGISTER_NODE_TYPE(QnnConcatenateAttrs);
-
-Expr MakeQnnConcatenate(Expr data, Array<tvm::Expr> input_scales,
-                        Array<tvm::Expr> input_zero_points, double output_scale,
-                        int32_t output_zero_point, int axis) {
-  auto attrs = make_object<QnnConcatenateAttrs>();
-  attrs->input_scales = std::move(input_scales);
-  attrs->input_zero_points = std::move(input_zero_points);
-  attrs->output_scale = output_scale;
-  attrs->output_zero_point = output_zero_point;
+bool QnnConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                       const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 6);
+
+  // Check the scale and zero point types
+  const auto* input_scales_tuple = types[1].as<TupleTypeNode>();
+  if (input_scales_tuple == nullptr) {
+    throw relay::Error(
+        RELAY_ERROR("qnn concatenate requires a tuple of scales as the second argument, found "
+                    << PrettyPrint(types[1])));
+  }
+  for (const auto& input_scale : input_scales_tuple->fields) {
+    CHECK(IsScalarType(input_scale, DataType::Float(32)));  // input_scales[idx]
+  }
+
+  const auto* input_zero_points_tuple = types[2].as<TupleTypeNode>();
+  if (input_zero_points_tuple == nullptr) {
+    throw relay::Error(
+        RELAY_ERROR("qnn concatenate requires a tuple of zero_points as the third argument, found "
+                    << PrettyPrint(types[2])));
+  }
+  for (const auto& input_zero_point : input_zero_points_tuple->fields) {
+    CHECK(IsScalarType(input_zero_point, DataType::Int(32)));  // input_zero_points[idx]
+  }
+
+  CHECK(IsScalarType(types[3], DataType::Float(32)));  // output_scale
+  CHECK(IsScalarType(types[4], DataType::Int(32)));    // output_zero_point
+
+  // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
+  // Concatenate infer type function.
+  Array<Type> tensor_types = {types[0], types[5]};
+  return ConcatenateRel<ConcatenateAttrs>(tensor_types, 2, attrs, reporter);
+}
+
+Expr MakeQnnConcatenate(Expr data, Expr input_scales, Expr input_zero_points, Expr output_scale,
+                        Expr output_zero_point, int axis) {
+  auto attrs = make_object<ConcatenateAttrs>();
   attrs->axis = axis;
   static const Op& op = Op::Get("qnn.concatenate");
-  return CallNode::make(op, {data}, Attrs(attrs), {});
+  return CallNode::make(op,
+                        {data, input_scales, input_zero_points, output_scale, output_zero_point},
+                        Attrs(attrs), {});
 }
 
 /*
@@ -59,14 +88,14 @@ Expr MakeQnnConcatenate(Expr data, Array<tvm::Expr> input_scales,
 Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                                 const Array<tvm::relay::Type>& arg_types) {
   // Get the attrs.
-  CHECK_EQ(new_args.size(), 1);
+  CHECK_EQ(new_args.size(), 5);
   auto& data = new_args[0];
-  const auto* concatenate_attrs = attrs.as<QnnConcatenateAttrs>();
+  auto& input_scales = new_args[1];
+  auto& input_zero_points = new_args[2];
+  auto& output_scale = new_args[3];
+  auto& output_zero_point = new_args[4];
+  const auto* concatenate_attrs = attrs.as<ConcatenateAttrs>();
   CHECK(concatenate_attrs != nullptr);
-  auto input_scales = concatenate_attrs->input_scales;
-  auto input_zero_points = concatenate_attrs->input_zero_points;
-  auto output_scale = concatenate_attrs->output_scale;
-  auto output_zero_point = concatenate_attrs->output_zero_point;
 
   // Get the input dtype and shape.
   CHECK_GE(arg_types.size(), 1);
@@ -83,21 +112,24 @@ Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   auto tuple_data = data.as<TupleNode>();
   CHECK(tuple_data != nullptr);
 
+  auto tuple_input_scales = input_scales.as<TupleNode>();
+  CHECK(tuple_input_scales != nullptr);
+
+  auto tuple_input_zero_points = input_zero_points.as<TupleNode>();
+  CHECK(tuple_input_zero_points != nullptr);
+
   int idx = 0;
   Array<Expr> requantized_exprs;
   for (auto quantized_expr : tuple_data->fields) {
     // Get the input scale for the idx quantized input tensor.
-    auto input_scale_expr = input_scales[idx].as<tvm::ir::FloatImm>();
-    CHECK(input_scale_expr != nullptr);
-    auto input_scale = input_scale_expr->value;
+    auto input_scale = tuple_input_scales->fields[idx];
 
     // Get the zero point for the idx quantized input tensor.
-    auto input_zero_point_expr = input_zero_points[idx].as<tvm::ir::IntImm>();
-    CHECK(input_zero_point_expr != nullptr);
-    auto input_zero_point = input_zero_point_expr->value;
+    auto input_zero_point = tuple_input_zero_points->fields[idx];
 
     // Check if output and input qnn params are same. If not, requantize.
-    if (input_scale != output_scale || input_zero_point != output_zero_point) {
+    if (!IsEqualScalar(input_scale, output_scale) ||
+        !IsEqualScalar(input_zero_point, output_zero_point)) {
       // Get the input shape and dtype.
       auto tensor_type = tuple_type->fields[idx].as<TensorTypeNode>();
       auto input_dtype = tensor_type->dtype;
@@ -118,11 +150,15 @@ Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
 RELAY_REGISTER_OP("qnn.concatenate")
 .describe(R"code(Concatenate the quantized input tensors along the given axis.
 )code" TVM_ADD_FILELINE)
-.set_attrs_type<QnnConcatenateAttrs>()
-.set_num_inputs(1)
+.set_attrs_type<ConcatenateAttrs>()
+.set_num_inputs(5)
 .add_argument("data", "Tensor", "The tensor to concatenate.")
+.add_argument("input_scales", "Tensor", "The quantization scales of the input tensors.")
+.add_argument("input_zero_points", "Tensor", "The quantization zero_points of the input tensors.")
+.add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")
+.add_argument("output_zero_point", "Tensor", "The quantization zero_point of the output tensor.")
 .set_support_level(11)
-.add_type_rel("QnnConcatenate", ConcatenateRel<QnnConcatenateAttrs>)
+.add_type_rel("QnnConcatenate", QnnConcatenateRel)
 .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize);
 
 TVM_REGISTER_API("relay.qnn.op._make.concatenate")
index 669b04f..839fcbd 100644 (file)
@@ -36,16 +36,15 @@ namespace relay {
 namespace qnn {
 
 // relay.op.qnn.conv2d
-TVM_REGISTER_NODE_TYPE(QnnConv2DAttrs);
 
 bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                   const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 3);
+  CHECK_EQ(types.size(), 7);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto* weight = types[1].as<TensorTypeNode>();
   if (data == nullptr || weight == nullptr) return false;
-  const auto* param = attrs.as<QnnConv2DAttrs>();
-  CHECK(param != nullptr) << "QnnConv2DAttrs cannot be nullptr.";
+  const auto* param = attrs.as<Conv2DAttrs>();
+  CHECK(param != nullptr) << "Conv2DAttrs cannot be nullptr.";
   CHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
       << "Expected qnn conv2d type(int8, uint8) for input but was " << data->dtype;
   CHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8))
@@ -53,10 +52,20 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   CHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32))
       << "Expected qnn conv2d type(int32, int16) for output but was " << param->out_dtype;
   CHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";
-  return Conv2DRel<QnnConv2DAttrs>(types, num_inputs, attrs, reporter);
+
+  // Check the types of scale and zero points.
+  CHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
+  CHECK(IsScalarType(types[3], DataType::Int(32)));    // kernel_zero_point
+  CHECK(IsScalarType(types[4], DataType::Float(32)));  // input_scale
+  CHECK(IsScalarType(types[5], DataType::Float(32)));  // kernel_scale
+
+  // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
+  // Conv2D infer type function.
+  Array<Type> tensor_types = {types[0], types[1], types[6]};
+  return Conv2DRel<Conv2DAttrs>(tensor_types, 3, attrs, reporter);
 }
 
-bool is_depthwise(const QnnConv2DAttrs* param) {
+bool is_depthwise(const Conv2DAttrs* param) {
   return param->channels.defined() && tvm::ir::Equal(param->channels, param->groups) &&
          param->groups != 1;
 }
@@ -70,7 +79,7 @@ using WorkloadType = std::tuple<int, int, int, int, int, int>;
  * \param param The qnn conv2d attributes.
  * \return A tuple of workload.
  */
-WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const QnnConv2DAttrs* param) {
+WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const Conv2DAttrs* param) {
   // Get conv parameters.
   const auto in_shape = get_shape(arg_types[0]);
   int batch_size, in_channels;
@@ -121,6 +130,8 @@ WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const QnnConv
  * \brief Fallback to simpler lowering for dilation or grouped conv.
  * \param data The input expr.
  * \param weight The weight expr.
+ * \param input_zero_point The input zero point expr.
+ * \param kernel_zero_point The kernel zero point expr.
  * \param param The qnn conv2d attributes.
  * \return The fallback lowered sequence of Relay expr.
  * \note In case of dilation, normal lowering would require a dilated pool.
@@ -128,18 +139,20 @@ WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const QnnConv
  *       Relay operations. This will potentially lead to performance degradation
  *       as the convolution is called on int32 tensors instead of int8 tensors.
  */
-Expr Conv2DFallBack(const Expr& data, const Expr& weight, const QnnConv2DAttrs* param) {
+Expr Conv2DFallBack(const Expr& data, const Expr& weight, const Expr& input_zero_point,
+                    const Expr& kernel_zero_point, const Conv2DAttrs* param) {
   // Upcast the zero point to Int16.
-  auto zp_data = MakeConstantScalar(DataType::Int(16), param->input_zero_point);
-  auto zp_kernel = MakeConstantScalar(DataType::Int(16), param->kernel_zero_point);
+  auto zp_data = Cast(input_zero_point, DataType::Int(16));
+  auto zp_kernel = Cast(kernel_zero_point, DataType::Int(16));
 
   auto shifted_data = Cast(data, DataType::Int(16));
-  if (param->input_zero_point != 0) {
+  auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
+  if (!IsEqualScalar(input_zero_point, zero_scalar)) {
     shifted_data = Subtract(Cast(data, DataType::Int(16)), zp_data);
   }
 
   auto shifted_kernel = Cast(weight, DataType::Int(16));
-  if (param->kernel_zero_point != 0) {
+  if (!IsEqualScalar(kernel_zero_point, zero_scalar)) {
     shifted_kernel = Subtract(Cast(weight, DataType::Int(16)), zp_kernel);
   }
 
@@ -151,13 +164,14 @@ Expr Conv2DFallBack(const Expr& data, const Expr& weight, const QnnConv2DAttrs*
 /*
  * \brief Pad the input data.
  * \param data The input expr.
+ * \param input_zero_point The input zero point expr.
  * \return The padded input expr.
  * \note For quantized convolution, the input has to be padded with zero point
  *       instead of zero. This might lead to performance degradation as pad
  *       cannot be fused with conv in Relay. In case we see performance
  *       degradation, we can change the conv2D API to accept a pad_const value.
  */
-Expr Conv2DPadInput(const Expr& data, const QnnConv2DAttrs* param) {
+Expr Conv2DPadInput(const Expr& data, const Expr& input_zero_point, const Conv2DAttrs* param) {
   // 1) Pad the input data
   auto padded_data = data;
   auto pad_h_value = get_const_int(param->padding[0]);
@@ -176,7 +190,8 @@ Expr Conv2DPadInput(const Expr& data, const QnnConv2DAttrs* param) {
     } else {
       LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout";
     }
-    padded_data = Pad(data, pad_width, param->input_zero_point, "constant");
+    auto pad_value = GetScalarFromConstant<int>(input_zero_point);
+    padded_data = Pad(data, pad_width, pad_value, "constant");
   }
   return padded_data;
 }
@@ -184,6 +199,7 @@ Expr Conv2DPadInput(const Expr& data, const QnnConv2DAttrs* param) {
 /*
  * \brief Calculates the second term in the qnn.conv2d depthwise lowering sequence.
  * \param padded_data The padded data expr.
+ * \param kernel_zero_point The kernel zero point expr.
  * \param param The qnn conv2d attributes.
  * \param kernel_h The height of kernel.
  * \param kernel_w The width of kernel.
@@ -197,11 +213,9 @@ Expr Conv2DPadInput(const Expr& data, const QnnConv2DAttrs* param) {
  *       However, deeper analysis shows that we can reduce r,s using avg_pool2d,
  *       followed by repeat on the C axis by cm times.
  */
-Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int kernel_h,
-                               int kernel_w, int channel_multiplier) {
-  // Constant Expr for the kernel zero point.
-  auto zp_kernel = MakeConstantScalar(DataType::Int(32), param->kernel_zero_point);
-
+Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point,
+                               const Conv2DAttrs* param, int kernel_h, int kernel_w,
+                               int channel_multiplier) {
   auto casted_t2 = Cast(padded_data, DataType::Int(32));
 
   // We can reduce the H and W axis by using avg_pool2d. However, avg_pool2d averages the sum.
@@ -210,8 +224,8 @@ Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* pa
   // pool_size is 1x1, we don't need avg_pool2d.
   auto reduced_t2 = casted_t2;
   if (kernel_h * kernel_w != 1) {
-    auto scaled_hw_t2 = Multiply(
-        casted_t2, MakeConstantScalar(DataType::Int(32), kernel_h * kernel_w));
+    auto scaled_hw_t2 =
+        Multiply(casted_t2, MakeConstantScalar(DataType::Int(32), kernel_h * kernel_w));
     Array<IndexExpr> padding({0, 0});
     reduced_t2 =
         AvgPool2D(scaled_hw_t2, param->kernel_size, param->strides, padding, param->data_layout,
@@ -220,8 +234,9 @@ Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* pa
   }
 
   auto multiplied_t2 = reduced_t2;
-  if (param->kernel_zero_point != 1) {
-    multiplied_t2 = Multiply(zp_kernel, reduced_t2);
+  auto one_scalar = MakeConstantScalar(DataType::Int(32), 1);
+  if (!IsEqualScalar(kernel_zero_point, one_scalar)) {
+    multiplied_t2 = Multiply(kernel_zero_point, reduced_t2);
   }
 
   // Reduce the C dimension. Find the dimension.
@@ -243,6 +258,7 @@ Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* pa
 /*
  * \brief Calculates the third term in the qnn.conv2d depthwise lowering sequence.
  * \param weight The weight expr.
+ * \param input_zero_point The input zero point expr.
  * \param param The qnn conv2d attributes.
  * \param out_channels The number of output channels.
  * \param channel_multiplier The channel/depth multiplier.
@@ -254,11 +270,8 @@ Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* pa
  *       This can be achieved by calling reduce on r and s axis. The tensor can be then reshaped to
  *       (1, oc, 1, 1) as (oc/m, oc%m) are just contiguous memory locations.
  */
-Expr DepthwiseConv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int out_channels,
-                              int channel_multiplier) {
-  // Constant expr for input zero point.
-  auto zp_data = MakeConstantScalar(DataType::Int(32), param->input_zero_point);
-
+Expr DepthwiseConv2DThirdTerm(const Expr& weight, const Expr& input_zero_point,
+                              const Conv2DAttrs* param, int out_channels, int channel_multiplier) {
   // Find which dimensions are R, S.
   Array<Integer> axes_t3;
   if (param->kernel_layout == "OIHW") {
@@ -284,15 +297,17 @@ Expr DepthwiseConv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, i
   }
   auto reshaped_t3 = Reshape(reduced_t3, newshape);
 
-  if (param->input_zero_point == 1) {
+  auto one_scalar = MakeConstantScalar(DataType::Int(32), 1);
+  if (IsEqualScalar(input_zero_point, one_scalar)) {
     return reshaped_t3;
   }
-  return Multiply(zp_data, reshaped_t3);
+  return Multiply(input_zero_point, reshaped_t3);
 }
 
 /*
  * \brief Calculates the fourth term in the qnn.conv2d depthwise lowering sequence.
- * \param param The qnn conv2d attributes.
+ * \param input_zero_point_int The int value of input zero point.
+ * \param kernel_zero_point_int The int value of kernel zero point.
  * \param kernel_h The height of kernel.
  * \param kernel_w The width of kernel.
  * \return The sequence of Relay operators for term4.
@@ -300,8 +315,9 @@ Expr DepthwiseConv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, i
  *
  *       Sigma(r, s) zp_a * zp_w
  */
-Expr DepthwiseConv2DFourthTerm(const QnnConv2DAttrs* param, int kernel_h, int kernel_w) {
-  int scalar_term4 = param->input_zero_point * param->kernel_zero_point * kernel_h * kernel_w;
+Expr DepthwiseConv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int kernel_h,
+                               int kernel_w) {
+  int scalar_term4 = input_zero_point_int * kernel_zero_point_int * kernel_h * kernel_w;
   return MakeConstantScalar(DataType::Int(32), scalar_term4);
 }
 
@@ -315,7 +331,7 @@ Expr DepthwiseConv2DFourthTerm(const QnnConv2DAttrs* param, int kernel_h, int ke
  *       Sigma(c,r,s) QW(k, c, r, s) * QA(n, c, h + r, w + s)
  *       This is just conv2d on int tensors.
  */
-Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const QnnConv2DAttrs* param) {
+Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const Conv2DAttrs* param) {
   // Lowering for Term 1
   Array<IndexExpr> padding({0, 0});
   return Conv2D(padded_data, weight, param->strides, padding, param->dilation, param->groups,
@@ -326,6 +342,7 @@ Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const QnnConv2
 /*
  * \brief Calculates the second term in the qnn.conv2d lowering sequence.
  * \param padded_data The padded data expr.
+ * \param kernel_zero_point The kernel zero point expr.
  * \param param The qnn conv2d attributes.
  * \param kernel_h The height of kernel.
  * \param kernel_w The width of kernel.
@@ -339,11 +356,8 @@ Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const QnnConv2
  *       followed by a reduce on the C axis. Using avg_pool2d also gives an
  *       opportunity to reuse alter_op_layout infrastructure.
  */
-Expr Conv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int kernel_h,
-                      int kernel_w, int out_channels) {
-  // Constant Expr for the kernel zero point.
-  auto zp_kernel = MakeConstantScalar(DataType::Int(32), param->kernel_zero_point);
-
+Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point,
+                      const Conv2DAttrs* param, int kernel_h, int kernel_w, int out_channels) {
   auto casted_t2 = Cast(padded_data, DataType::Int(32));
 
   // We can reduce the H and W axis by using avg_pool2d. However, avg_pool2d averages the sum.
@@ -366,20 +380,18 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int
   // If the pool_size is 1x1, we don't need avg_pool2d.
   auto reduced_t2 = reduced_c_t2;
   if (kernel_h * kernel_w != 1) {
-    reduced_c_t2 = Multiply(
-        reduced_c_t2, MakeConstantScalar(DataType::Int(32), kernel_h * kernel_w));
+    reduced_c_t2 =
+        Multiply(reduced_c_t2, MakeConstantScalar(DataType::Int(32), kernel_h * kernel_w));
     reduced_t2 =
-        AvgPool2D(reduced_c_t2,
-                  param->kernel_size,
-                  param->strides,
-                  padding, param->data_layout,
+        AvgPool2D(reduced_c_t2, param->kernel_size, param->strides, padding, param->data_layout,
                   false,   // ceil_mode
                   false);  // count_include_pad
   }
 
   auto multiplied_t2 = reduced_t2;
-  if (param->kernel_zero_point != 1) {
-    multiplied_t2 = Multiply(zp_kernel, reduced_t2);
+  auto one_scalar = MakeConstantScalar(DataType::Int(32), 1);
+  if (!IsEqualScalar(kernel_zero_point, one_scalar)) {
+    multiplied_t2 = Multiply(kernel_zero_point, reduced_t2);
   }
   return multiplied_t2;
 }
@@ -387,6 +399,7 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int
 /*
  * \brief Calculates the third term in the qnn.conv2d lowering sequence.
  * \param weight The weight expr.
+ * \param input_zero_point The input zero point expr.
  * \param param The qnn conv2d attributes.
  * \param out_channels The number of output channels.
  * \return The sequence of Relay operatos for term3.
@@ -398,10 +411,8 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int
  *       a 1D tensor. The tensor is then reshaped to conform to NHWC/NCHW
  *       format.
  */
-Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int out_channels) {
-  // Constant expr for input zero point.
-  auto zp_data = MakeConstantScalar(DataType::Int(32), param->input_zero_point);
-
+Expr Conv2DThirdTerm(const Expr& weight, const Expr& input_zero_point, const Conv2DAttrs* param,
+                     int out_channels) {
   // Find which dimensions are C, R, S.
   Array<Integer> axes_t3;
   if (param->kernel_layout == "OIHW") {
@@ -427,15 +438,17 @@ Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int out_ch
   }
   auto reshaped_t3 = Reshape(reduced_t3, newshape);
 
-  if (param->input_zero_point == 1) {
+  auto one_scalar = MakeConstantScalar(DataType::Int(32), 1);
+  if (IsEqualScalar(input_zero_point, one_scalar)) {
     return reshaped_t3;
   }
-  return Multiply(zp_data, reshaped_t3);
+  return Multiply(input_zero_point, reshaped_t3);
 }
 
 /*
  * \brief Calculates the fourth term in the qnn.conv2d lowering sequence.
- * \param param The qnn conv2d attributes.
+ * \param input_zero_point_int The int value of input zero point.
+ * \param kernel_zero_point_int The int value of kernel zero point.
  * \param in_channels The number of input channels.
  * \param kernel_h The height of kernel.
  * \param kernel_w The width of kernel.
@@ -445,9 +458,10 @@ Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int out_ch
  *       Sigma(c,r,s) zp_a * zp_w
  *
  */
-Expr Conv2DFourthTerm(const QnnConv2DAttrs* param, int in_channels, int kernel_h, int kernel_w) {
+Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int in_channels,
+                      int kernel_h, int kernel_w) {
   int scalar_term4 =
-      param->input_zero_point * param->kernel_zero_point * in_channels * kernel_h * kernel_w;
+      input_zero_point_int * kernel_zero_point_int * in_channels * kernel_h * kernel_w;
   return MakeConstantScalar(DataType::Int(32), scalar_term4);
 }
 
@@ -457,6 +471,8 @@ Expr Conv2DFourthTerm(const QnnConv2DAttrs* param, int in_channels, int kernel_h
  * \param term2 The term2 of qnn conv2d lowering.
  * \param term3 The term3 of qnn conv2d lowering.
  * \param term4 The term4 of qnn conv2d lowering.
+ * \param input_zero_point_int The int value of input zero point.
+ * \param kernel_zero_point_int The int value of kernel zero point.
  * \param param The qnn conv2d attributes.
  * \return The combined sequence of relay operations.
  * \note The combined operation looks like this
@@ -468,14 +484,14 @@ Expr Conv2DFourthTerm(const QnnConv2DAttrs* param, int in_channels, int kernel_h
  *
  */
 Expr Conv2DCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3, const Expr& term4,
-                        const QnnConv2DAttrs* param) {
-  if (param->input_zero_point == 0 && param->kernel_zero_point == 0) {
+                        int input_zero_point_int, int kernel_zero_point_int) {
+  if (input_zero_point_int == 0 && kernel_zero_point_int == 0) {
     // term 2, 3 and 4 become zero.
     return term1;
-  } else if (param->input_zero_point == 0 && param->kernel_zero_point != 0) {
+  } else if (input_zero_point_int == 0 && kernel_zero_point_int != 0) {
     // term 3 and term 4 become zero.
     return Subtract(term1, term2);
-  } else if (param->input_zero_point != 0 && param->kernel_zero_point == 0) {
+  } else if (input_zero_point_int != 0 && kernel_zero_point_int == 0) {
     // term 2 and term 4 become zero.
     return Subtract(term1, term3);
   } else {
@@ -556,10 +572,12 @@ Expr Conv2DCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3,
  */
 Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                            const Array<tvm::relay::Type>& arg_types) {
-  CHECK_EQ(new_args.size(), 2);
+  CHECK_EQ(new_args.size(), 6);
   Expr data = new_args[0];
   Expr weight = new_args[1];
-  const auto* param = attrs.as<QnnConv2DAttrs>();
+  Expr input_zero_point = new_args[2];
+  Expr kernel_zero_point = new_args[3];
+  const auto* param = attrs.as<Conv2DAttrs>();
   CHECK(param != nullptr);
   // Assertion checks for exisiing support.
   CHECK_EQ(param->padding.size(), 2) << "qnn.conv2d only supports 2D padding";
@@ -573,41 +591,50 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier) =
       GetWorkload(arg_types, param);
 
+  // Extract the integer zero points.
+  auto input_zero_point_int = GetScalarFromConstant<int>(input_zero_point);
+  auto kernel_zero_point_int = GetScalarFromConstant<int>(kernel_zero_point);
+
   // Fallback to int32 conv if there is dilation or grouped conv2d
 
   CHECK_EQ(param->dilation.size(), 2) << "qnn.conv2d only supports 2D dilation";
   auto dilation_h = get_const_int(param->dilation[0]);
   auto dilation_w = get_const_int(param->dilation[1]);
   if (dilation_h != 1 || dilation_w != 1 || (param->groups != 1 && !is_depthwise(param))) {
-    return Conv2DFallBack(data, weight, param);
+    return Conv2DFallBack(data, weight, input_zero_point, kernel_zero_point, param);
   } else if (is_depthwise(param)) {
     CHECK_NE(channel_multiplier, -1);
-    auto padded_data = Conv2DPadInput(data, param);
+    auto padded_data = Conv2DPadInput(data, input_zero_point, param);
     auto term1 = Conv2DFirstTerm(padded_data, weight, param);
-    auto term2 =
-        DepthwiseConv2DSecondTerm(padded_data, param, kernel_h, kernel_w, channel_multiplier);
-    auto term3 = DepthwiseConv2DThirdTerm(weight, param, out_channels, channel_multiplier);
-    auto term4 = DepthwiseConv2DFourthTerm(param, kernel_h, kernel_w);
-    return Conv2DCombineTerms(term1, term2, term3, term4, param);
+    auto term2 = DepthwiseConv2DSecondTerm(padded_data, kernel_zero_point, param, kernel_h,
+                                           kernel_w, channel_multiplier);
+    auto term3 =
+        DepthwiseConv2DThirdTerm(weight, input_zero_point, param, out_channels, channel_multiplier);
+    auto term4 =
+        DepthwiseConv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, kernel_h, kernel_w);
+    return Conv2DCombineTerms(term1, term2, term3, term4, input_zero_point_int,
+                              kernel_zero_point_int);
   }
 
-  auto padded_data = Conv2DPadInput(data, param);
+  auto padded_data = Conv2DPadInput(data, input_zero_point, param);
   auto term1 = Conv2DFirstTerm(padded_data, weight, param);
-  auto term2 = Conv2DSecondTerm(padded_data, param, kernel_h, kernel_w, out_channels);
-  auto term3 = Conv2DThirdTerm(weight, param, out_channels);
-  auto term4 = Conv2DFourthTerm(param, in_channels, kernel_h, kernel_w);
-  return Conv2DCombineTerms(term1, term2, term3, term4, param);
+  auto term2 =
+      Conv2DSecondTerm(padded_data, kernel_zero_point, param, kernel_h, kernel_w, out_channels);
+  auto term3 = Conv2DThirdTerm(weight, input_zero_point, param, out_channels);
+  auto term4 = Conv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, in_channels, kernel_h,
+                                kernel_w);
+  return Conv2DCombineTerms(term1, term2, term3, term4, input_zero_point_int,
+                            kernel_zero_point_int);
 }
 
 // Positional relay function to create quantized conv2d operator
 // used by frontend FFI.
-Expr MakeQnnConv2D(Expr data, Expr weight, int32_t input_zero_point, int32_t kernel_zero_point,
-                   double input_scale, double kernel_scale, Array<IndexExpr> strides,
-                   Array<IndexExpr> padding, Array<IndexExpr> dilation,
-                   int groups, IndexExpr channels, Array<IndexExpr> kernel_size,
-                   std::string data_layout, std::string kernel_layout, std::string out_layout,
-                   DataType out_dtype) {
-  auto attrs = make_object<QnnConv2DAttrs>();
+Expr MakeQnnConv2D(Expr data, Expr weight, Expr input_zero_point, Expr kernel_zero_point,
+                   Expr input_scale, Expr kernel_scale, Array<IndexExpr> strides,
+                   Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups,
+                   IndexExpr channels, Array<IndexExpr> kernel_size, std::string data_layout,
+                   std::string kernel_layout, std::string out_layout, DataType out_dtype) {
+  auto attrs = make_object<Conv2DAttrs>();
   attrs->strides = std::move(strides);
   attrs->padding = std::move(padding);
   attrs->dilation = std::move(dilation);
@@ -618,12 +645,10 @@ Expr MakeQnnConv2D(Expr data, Expr weight, int32_t input_zero_point, int32_t ker
   attrs->kernel_layout = std::move(kernel_layout);
   attrs->out_layout = std::move(out_layout);
   attrs->out_dtype = std::move(out_dtype);
-  attrs->input_zero_point = std::move(input_zero_point);
-  attrs->kernel_zero_point = std::move(kernel_zero_point);
-  attrs->input_scale = std::move(input_scale);
-  attrs->kernel_scale = std::move(kernel_scale);
   static const Op& op = Op::Get("qnn.conv2d");
-  return CallNode::make(op, {data, weight}, Attrs(attrs), {});
+  return CallNode::make(
+      op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale},
+      Attrs(attrs), {});
 }
 
 RELAY_REGISTER_OP("qnn.conv2d")
@@ -639,10 +664,14 @@ operator to understand how to scale back the int32 output to (u)int8.
 - **out**:  This depends on the `layout` parameter. Output is 4D array of shape
             (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
 )code" TVM_ADD_FILELINE)
-.set_attrs_type<QnnConv2DAttrs>()
-.set_num_inputs(2)
+.set_attrs_type<Conv2DAttrs>()
+.set_num_inputs(6)
 .add_argument("data", "Tensor", "The quantized input data tensor.")
 .add_argument("weight", "Tensor", "The quantized weight tensor.")
+.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.")
+.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
+.add_argument("weight_scale", "Tensor", "The quantization scale of the weight tensor.")
+.add_argument("weight_zero_point", "Tensor", "The quantization zero_point of the weight tensor.")
 .set_support_level(11)
 .add_type_rel("QnnConv2D", QnnConv2DRel)
 .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnConv2DCanonicalize);
index 2353e5a..c762331 100644 (file)
@@ -35,59 +35,67 @@ namespace relay {
 namespace qnn {
 
 // relay.op.qnn.dense
-TVM_REGISTER_NODE_TYPE(QnnDenseAttrs);
 
 bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                  const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 3);
+  CHECK_EQ(types.size(), 7);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto* weight = types[1].as<TensorTypeNode>();
   if (data == nullptr || weight == nullptr) return false;
-  const auto* param = attrs.as<QnnDenseAttrs>();
-  CHECK(param != nullptr) << "QnnDenseAttrs cannot be nullptr.";
+  const auto* param = attrs.as<DenseAttrs>();
+  CHECK(param != nullptr) << "DenseAttrs cannot be nullptr.";
   CHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
       << "Expected quantized dense type(int8, uint8) for input but was " << data->dtype;
   CHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8))
       << "Expected quantized dense type(int8, uint8) for weight but was " << weight->dtype;
   CHECK(param->out_dtype == DataType::Int(32))
       << "Expected quantized dense type(int32) for output but was " << param->out_dtype;
+
+  // Check the types of scale and zero points.
+  CHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
+  CHECK(IsScalarType(types[3], DataType::Int(32)));    // kernel_zero_point
+  CHECK(IsScalarType(types[4], DataType::Float(32)));  // input_scale
+  CHECK(IsScalarType(types[5], DataType::Float(32)));  // kernel_scale
+
   CHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";
-  return DenseRel<QnnDenseAttrs>(types, num_inputs, attrs, reporter);
+
+  // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
+  // Dense infer type function.
+  Array<Type> tensor_types = {types[0], types[1], types[6]};
+  return DenseRel<DenseAttrs>(tensor_types, 3, attrs, reporter);
 }
 
 // Positional relay function to create quantized dense operator used by frontend FFI.
-Expr MakeQuantizedDense(Expr data, Expr weight, int32_t input_zero_point,
-                        int32_t kernel_zero_point,  double input_scale,
-                        double kernel_scale, IndexExpr units,
-                        DataType out_dtype) {
-  auto attrs = make_object<QnnDenseAttrs>();
+Expr MakeQuantizedDense(Expr data, Expr weight, Expr input_zero_point, Expr kernel_zero_point,
+                        Expr input_scale, Expr kernel_scale, IndexExpr units, DataType out_dtype) {
+  auto attrs = make_object<DenseAttrs>();
   attrs->units = std::move(units);
   attrs->out_dtype = out_dtype;
-  attrs->input_zero_point = input_zero_point;
-  attrs->kernel_zero_point = kernel_zero_point;
-  attrs->input_scale = input_scale;
-  attrs->kernel_scale = kernel_scale;
   static const Op& op = Op::Get("qnn.dense");
-  return CallNode::make(op, {data, weight}, Attrs(attrs), {});
+  return CallNode::make(
+      op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale},
+      Attrs(attrs), {});
 }
 
 Expr DenseFirstTerm(const Expr& quantized_data, const Expr& quantized_kernel,
-                    const QnnDenseAttrs* attrs) {
+                    const DenseAttrs* attrs) {
   return Dense(quantized_data, quantized_kernel, attrs->units, attrs->out_dtype);
 }
 
-Expr DenseSecondTerm(const Expr& quantized_data, const Expr& zp_kernel) {
+Expr DenseSecondTerm(const Expr& quantized_data, const Expr& kernel_zero_point) {
   Array<Integer> axes = {1};
-  return Multiply(zp_kernel, Sum(Cast(quantized_data, DataType::Int(32)), axes, true, false));
+  return Multiply(kernel_zero_point,
+                  Sum(Cast(quantized_data, DataType::Int(32)), axes, true, false));
 }
 
-Expr DenseThirdTerm(const Expr& quantized_kernel, const Expr& zp_data) {
+Expr DenseThirdTerm(const Expr& quantized_kernel, const Expr& input_zero_point) {
   Array<Integer> axes = {1};
-  return Multiply(zp_data, Sum(Cast(quantized_kernel, DataType::Int(32)), axes, false, false));
+  return Multiply(input_zero_point,
+                  Sum(Cast(quantized_kernel, DataType::Int(32)), axes, false, false));
 }
 
-Expr DenseFourthTerm(const QnnDenseAttrs* attrs, int reduction_dim_size) {
-  int32_t scalar_term = attrs->input_zero_point * attrs->kernel_zero_point * reduction_dim_size;
+Expr DenseFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int reduction_dim_size) {
+  int32_t scalar_term = input_zero_point_int * kernel_zero_point_int * reduction_dim_size;
   return MakeConstantScalar(DataType::Int(32), scalar_term);
 }
 
@@ -125,31 +133,35 @@ Expr DenseFourthTerm(const QnnDenseAttrs* attrs, int reduction_dim_size) {
  */
 Expr QnnDenseCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                           const Array<tvm::relay::Type>& arg_types) {
-  CHECK_EQ(new_args.size(), 2);
+  CHECK_EQ(new_args.size(), 6);
   Expr quantized_data = new_args[0];
   Expr quantized_kernel = new_args[1];
+  Expr input_zero_point = new_args[2];
+  Expr kernel_zero_point = new_args[3];
 
   const auto in_shape = get_shape(arg_types[0]);
   const int reduction_dim_size = get_const_int(in_shape[1]);
 
-  const auto* qnn_dense_attrs = attrs.as<QnnDenseAttrs>();
-  auto zp_kernel = MakeConstantScalar(DataType::Int(32), qnn_dense_attrs->kernel_zero_point);
-  auto zp_data = MakeConstantScalar(DataType::Int(32), qnn_dense_attrs->input_zero_point);
+  const auto* qnn_dense_attrs = attrs.as<DenseAttrs>();
+
+  // Extract the integer zero points.
+  auto input_zero_point_int = GetScalarFromConstant<int>(input_zero_point);
+  auto kernel_zero_point_int = GetScalarFromConstant<int>(kernel_zero_point);
 
   // Get all the terms as described in the comments.
   auto term1 = DenseFirstTerm(quantized_data, quantized_kernel, qnn_dense_attrs);
-  auto term2 = DenseSecondTerm(quantized_data, zp_kernel);
-  auto term3 = DenseThirdTerm(quantized_kernel, zp_data);
-  auto term4 = DenseFourthTerm(qnn_dense_attrs, reduction_dim_size);
+  auto term2 = DenseSecondTerm(quantized_data, kernel_zero_point);
+  auto term3 = DenseThirdTerm(quantized_kernel, input_zero_point);
+  auto term4 = DenseFourthTerm(input_zero_point_int, kernel_zero_point_int, reduction_dim_size);
 
   // Combine those 4 terms depending on the zero points to get the best lowering.
-  if (qnn_dense_attrs->input_zero_point == 0 && qnn_dense_attrs->kernel_zero_point == 0) {
+  if (input_zero_point_int == 0 && kernel_zero_point_int == 0) {
     // term 2, 3 and 4 become zero.
     return term1;
-  } else if (qnn_dense_attrs->input_zero_point == 0 && qnn_dense_attrs->kernel_zero_point != 0) {
+  } else if (input_zero_point_int == 0 && kernel_zero_point_int != 0) {
     // term 3 and term 4 become zero.
     return Subtract(term1, term2);
-  } else if (qnn_dense_attrs->input_zero_point != 0 && qnn_dense_attrs->kernel_zero_point == 0) {
+  } else if (input_zero_point_int != 0 && kernel_zero_point_int == 0) {
     // term 2 and term 4 become zero.
     return Subtract(term1, term3);
   } else {
@@ -166,12 +178,16 @@ RELAY_REGISTER_OP("qnn.dense")
 - **weight**: quantized(int8, unit8) `(units, input_dim)`
 - **out**: quantized(int32) `(x1, x2, ..., xn, units)`.
 )code" TVM_ADD_FILELINE)
-.set_attrs_type<QnnDenseAttrs>()
-.set_num_inputs(2)
+.set_attrs_type<DenseAttrs>()
+.set_num_inputs(6)
 .add_argument("data", "quantized nD Tensor", "Input data.")
 .add_argument("weight", "quantized 2D Tensor", "Weight matrix.")
+.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.")
+.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
+.add_argument("weight_scale", "Tensor", "The quantization scale of the weight tensor.")
+.add_argument("weight_zero_point", "Tensor", "The quantization zero_point of the weight tensor.")
 .set_support_level(11)
-.add_type_rel("QDense", DenseRel<QnnDenseAttrs>)
+.add_type_rel("QDense", QnnDenseRel)
 .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnDenseCanonicalize);
 
 TVM_REGISTER_API("relay.qnn.op._make.dense")
index a1e2380..94f2f89 100644 (file)
@@ -33,13 +33,11 @@ namespace tvm {
 namespace relay {
 namespace qnn {
 
-TVM_REGISTER_NODE_TYPE(DequantizeAttrs);
-
 bool DequantizeRel(const Array<Type>& types,
                    int num_inputs,
                    const Attrs& attrs,
                    const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 2);
+  CHECK_EQ(types.size(), 4);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto input_dtype = data->dtype;
   CHECK(input_dtype == DataType::Int(8) ||
@@ -47,42 +45,40 @@ bool DequantizeRel(const Array<Type>& types,
         input_dtype == DataType::Int(32))
     << "Input type should be one of the quantized types [unit8, int8, int32] but was "
     <<  input_dtype;
+
+  // Check the types of scale and zero points.
+  CHECK(IsScalarType(types[1], DataType::Float(32)));  // input_scale
+  CHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
+
   const Array<tvm::Expr> oshape = data->shape;
   // assign output type, output will always be float 32.
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, DataType::Float(32)));
+  reporter->Assign(types[3], TensorTypeNode::make(oshape, DataType::Float(32)));
   return true;
 }
 
-Expr MakeDequantize(Expr data,
-                    double input_scale,
-                    int32_t input_zero_point) {
-  auto attrs = make_object<DequantizeAttrs>();
-  attrs->input_scale = input_scale;
-  attrs->input_zero_point = input_zero_point;
+Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point) {
   // real_value = scale * (quantized_value - zero_point)
-  // A more detailed explanation can be found here - https://github.com/google/gemmlowp/blob/master/doc/quantization.md
+  // A more detailed explanation can be found here -
+  // https://github.com/google/gemmlowp/blob/master/doc/quantization.md
   static const Op& op = Op::Get("qnn.dequantize");
-  return CallNode::make(op, {data}, Attrs(attrs), {});
+  return CallNode::make(op, {data, input_scale, input_zero_point}, Attrs(), {});
 }
 
-Expr DequantizeLower(const Expr& input_tensor,
-                     const DequantizeAttrs* attrs) {
-  const auto input_zero_point = MakeConstantScalar(DataType::Int(32), attrs->input_zero_point);
-  const auto input_scale = MakeConstantScalar(DataType::Float(32), attrs->input_scale);
+Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,
+                     const Expr& input_zero_point) {
   auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), input_zero_point);
   auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), input_scale);
   return scaled_output;
 }
 
-Expr DequantizeQnnCanonicalize(const Attrs& attrs,
-                               const Array<Expr>& new_args,
+Expr DequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                                const Array<tvm::relay::Type>& types) {
-  CHECK_EQ(new_args.size(), 1);
+  CHECK_EQ(new_args.size(), 3);
   auto& data = new_args[0];
-  const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
-  CHECK(dequantize_attrs != nullptr);
-  CHECK_EQ(types.size(), 2);
-  return DequantizeLower(data, dequantize_attrs);
+  auto& input_scale = new_args[1];
+  auto& input_zero_point = new_args[2];
+  CHECK_EQ(types.size(), 4);
+  return DequantizeLower(data, input_scale, input_zero_point);
 }
 
 RELAY_REGISTER_OP("qnn.dequantize")
@@ -90,9 +86,10 @@ RELAY_REGISTER_OP("qnn.dequantize")
 The input is always quantized (int8, uint8) and will be converted to float32 given input scale and zero_point.
 - **data**: Quantized tensor of any shape to dequantize. The input data can be of floating point
 )code" TVM_ADD_FILELINE)
-.set_attrs_type<DequantizeAttrs>()
-.set_num_inputs(1)
+.set_num_inputs(3)
 .add_argument("data", "Tensor", "The tensor to dequantize.")
+.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.")
+.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
 .set_support_level(11)
 .add_type_rel("Dequantize", DequantizeRel)
 .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", DequantizeQnnCanonicalize);
index a34098f..c8ea3fc 100644 (file)
@@ -42,20 +42,18 @@ namespace qnn {
 Expr QnnMulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                         const Array<tvm::relay::Type>& arg_types) {
   // Get the attrs.
-  CHECK_EQ(new_args.size(), 2);
+  CHECK_EQ(new_args.size(), 8);
   auto& lhs = new_args[0];
   auto& rhs = new_args[1];
-  const auto* binary_op_attrs = attrs.as<QnnBinaryOpAttrs>();
-  CHECK(binary_op_attrs != nullptr);
-  auto lhs_scale = binary_op_attrs->lhs_scale;
-  auto lhs_zero_point = binary_op_attrs->lhs_zero_point;
-  auto rhs_scale = binary_op_attrs->rhs_scale;
-  auto rhs_zero_point = binary_op_attrs->rhs_zero_point;
-  auto output_scale = binary_op_attrs->output_scale;
-  auto output_zero_point = binary_op_attrs->output_zero_point;
+  auto& lhs_scale = new_args[2];
+  auto& lhs_zero_point = new_args[3];
+  auto& rhs_scale = new_args[4];
+  auto& rhs_zero_point = new_args[5];
+  auto& output_scale = new_args[6];
+  auto& output_zero_point = new_args[7];
 
   // Get the input dtype and shape.
-  CHECK_EQ(arg_types.size(), 3);
+  CHECK_EQ(arg_types.size(), 9);
   auto tensor_type = arg_types[0].as<TensorTypeNode>();
   auto input_dtype = tensor_type->dtype;
   auto input_shape = tensor_type->shape;
@@ -75,24 +73,28 @@ Expr QnnMulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   auto lhs_shifted = Cast(lhs, DataType::Int(32));
   auto rhs_shifted = Cast(rhs, DataType::Int(32));
 
-  if (lhs_zero_point != 0) {
-    auto lhs_zp = MakeConstantScalar(DataType::Int(32), lhs_zero_point);
-    lhs_shifted = Subtract(lhs_shifted, lhs_zp);
+  auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
+  if (!IsEqualScalar(lhs_zero_point, zero_scalar)) {
+    lhs_shifted = Subtract(lhs_shifted, lhs_zero_point);
   }
 
-  if (rhs_zero_point != 0) {
-    auto rhs_zp = MakeConstantScalar(DataType::Int(32), rhs_zero_point);
-    rhs_shifted = Subtract(rhs_shifted, rhs_zp);
+  if (!IsEqualScalar(rhs_zero_point, zero_scalar)) {
+    rhs_shifted = Subtract(rhs_shifted, rhs_zero_point);
   }
 
   // Create a new tensor Q'
   auto output = Multiply(lhs_shifted, rhs_shifted);
 
-  auto scale_new = rhs_scale * lhs_scale;
+  // Get the adjusted new scale and zero points.
+  float lhs_scale_float = GetScalarFromConstant<float>(lhs_scale);
+  float rhs_scale_float = GetScalarFromConstant<float>(rhs_scale);
+  float new_scale_float = lhs_scale_float * rhs_scale_float;
+  auto new_input_scale = MakeConstantScalar(DataType::Float(32), new_scale_float);
+  auto new_input_zero_point = zero_scalar;
 
   // Requantize to get Q_c
-  output = Requantize(output, input_shape, scale_new, 0, output_scale,
-    output_zero_point, input_dtype);
+  output = Requantize(output, input_shape, new_input_scale, new_input_zero_point, output_scale,
+                      output_zero_point, input_dtype);
 
   return output;
 }
index 2c116fe..41e8335 100644 (file)
@@ -35,6 +35,24 @@ namespace tvm {
 namespace relay {
 namespace qnn {
 
+static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                                   const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 9);
+
+  // Check the scale and zero point types
+  CHECK(IsScalarType(types[2], DataType::Float(32)));  // lhs_scale
+  CHECK(IsScalarType(types[3], DataType::Int(32)));    // lhs_zero_point
+  CHECK(IsScalarType(types[4], DataType::Float(32)));  // rhs_scale
+  CHECK(IsScalarType(types[5], DataType::Int(32)));    // rhs_zero_point
+  CHECK(IsScalarType(types[6], DataType::Float(32)));  // output_scale
+  CHECK(IsScalarType(types[7], DataType::Int(32)));    // output_zero_point
+
+  // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
+  // BroadcastRel infer type function.
+  Array<Type> tensor_types = {types[0], types[1], types[8]};
+  return BroadcastRel(tensor_types, 3, attrs, reporter);
+}
+
 /*! Quick helper macro
  * - Expose a positional make function to construct the node.
  * - Register op to the registry.
@@ -47,24 +65,26 @@ namespace qnn {
  */
 #define QNN_REGISTER_BINARY_OP(OpName)                                                     \
   TVM_REGISTER_API("relay.qnn.op._make." OpName)                                           \
-    .set_body_typed<Expr(Expr, Expr, double, int32_t, double, int32_t, double, int32_t)>(  \
-        [](Expr lhs, Expr rhs, double lhs_scale, int32_t lhs_zero_point, double rhs_scale, \
-           int32_t rhs_zero_point, double output_scale, int32_t output_zero_point) {       \
-          auto attrs = make_object<QnnBinaryOpAttrs>();                                      \
-          attrs->lhs_scale = lhs_scale;                                                    \
-          attrs->lhs_zero_point = lhs_zero_point;                                          \
-          attrs->rhs_scale = rhs_scale;                                                    \
-          attrs->rhs_zero_point = rhs_zero_point;                                          \
-          attrs->output_scale = output_scale;                                              \
-          attrs->output_zero_point = output_zero_point;                                    \
+    .set_body_typed<Expr(Expr, Expr, Expr, Expr, Expr, Expr, Expr, Expr)>(                 \
+        [](Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale,        \
+           Expr rhs_zero_point, Expr output_scale, Expr output_zero_point) {               \
           static const Op& op = Op::Get("qnn." OpName);                                    \
-          return CallNode::make(op, {lhs, rhs}, Attrs(attrs), {});                         \
+          return CallNode::make(op, {lhs, rhs,                                             \
+                                     lhs_scale, lhs_zero_point,                            \
+                                     rhs_scale, rhs_zero_point,                            \
+                                     output_scale, output_zero_point}, Attrs(), {});       \
         });                                                                                \
   RELAY_REGISTER_OP("qnn." OpName)                                                         \
-    .set_num_inputs(2)                                                                     \
+    .set_num_inputs(8)                                                                     \
     .add_argument("lhs", "Tensor", "The left hand side quantized tensor.")                 \
     .add_argument("rhs", "Tensor", "The right hand side quantized tensor.")                \
-    .add_type_rel("Broadcast", BroadcastRel)
+    .add_argument("lhs_scale", "Tensor", "The scale of the lhs tensor.")                   \
+    .add_argument("lhs_zero_point", "Tensor", "The zero_point of the lhs tensor.")         \
+    .add_argument("rhs_scale", "Tensor", "The scale of the rhs tensor.")                   \
+    .add_argument("rhs_zero_point", "Tensor", "The zero_point of the rhs tensor.")         \
+    .add_argument("output_scale", "Tensor", "The scale of the output tensor.")             \
+    .add_argument("output_zero_point", "Tensor", "The zero_point of the output tensor.")   \
+    .add_type_rel("QnnBroadcast", QnnBroadcastRel)
 
 }  // namespace qnn
 }  // namespace relay
index 18dd9aa..9749fb8 100644 (file)
@@ -39,61 +39,61 @@ bool QuantizeRel(const Array<Type>& types,
                  int num_inputs,
                  const Attrs& attrs,
                  const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 2);
+  CHECK_EQ(types.size(), 4);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto input_dtype = data->dtype;
   CHECK(input_dtype == DataType::Float(32))
     << "Input type should be one of float32 but was " <<  input_dtype;
+
+  // Check the types of scale and zero points.
+  CHECK(IsScalarType(types[1], DataType::Float(32)));  // output_scale
+  CHECK(IsScalarType(types[2], DataType::Int(32)));    // output_zero_point
+
   const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
   const Array<tvm::Expr> oshape = data->shape;
   const DataType out_dtype = quantize_attrs->out_dtype;
-  CHECK(out_dtype == DataType::Int(8) ||
-        out_dtype == DataType::UInt(8) ||
+  CHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
         out_dtype == DataType::Int(32))
-    << "Output type should be one of [int8, unit8, int32] but was " << out_dtype;
+      << "Output type should be one of [int8, unit8, int32] but was " << out_dtype;
   // assign output type
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, out_dtype));
+  reporter->Assign(types[3], TensorTypeNode::make(oshape, out_dtype));
   return true;
 }
 
-Expr MakeQuantize(Expr data,
-                  double output_scale,
-                  int32_t output_zero_point,
-                  DataType out_dtype) {
+Expr MakeQuantize(Expr data, Expr output_scale, Expr output_zero_point, DataType out_dtype) {
   auto attrs = make_object<QuantizeAttrs>();
-  attrs->output_scale = output_scale;
-  attrs->output_zero_point = output_zero_point;
   attrs->out_dtype = std::move(out_dtype);
   // result_quantized_value = result_zero_point + result_real_value / result_scale.
-  // A more detailed explanation can be found here - https://github.com/google/gemmlowp/blob/master/doc/quantization.md
+  // A more detailed explanation can be found here -
+  // https://github.com/google/gemmlowp/blob/master/doc/quantization.md
   static const Op& op = Op::Get("qnn.quantize");
-  return CallNode::make(op, {data}, Attrs(attrs), {});
+  return CallNode::make(op, {data, output_scale, output_zero_point}, Attrs(attrs), {});
 }
 
-Expr QuantizeLower(const Expr& input_tensor,
-                   const QuantizeAttrs* attrs) {
+Expr QuantizeLower(const Expr& input_tensor, const Expr& output_scale,
+                   const Expr& output_zero_point, const QuantizeAttrs* attrs) {
   const auto out_dtype = attrs->out_dtype;
-  const auto output_zero_point = MakeConstantScalar(DataType::Float(32), attrs->output_zero_point);
-  const auto scale = MakeConstantScalar(DataType::Float(32), attrs->output_scale);
   const int32_t min_val = GetQmin(out_dtype);
   const int32_t max_val = GetQmax(out_dtype);
-  auto scale_data = Divide(input_tensor, scale);
-  auto add_zero_point = Cast(Round(Add(scale_data, output_zero_point)), DataType::Int(32));
+  auto scale_data = Divide(input_tensor, output_scale);
+  auto add_zero_point =
+      Cast(Round(Add(scale_data, Cast(output_zero_point, DataType::Float(32)))), DataType::Int(32));
   auto clamped_output = Clip(add_zero_point, min_val, max_val);
   auto clamp_out_dtype = Cast(clamped_output, out_dtype);
   return clamp_out_dtype;
 }
 
-Expr QuantizeQnnCanonicalize(const Attrs& attrs,
-                             const Array<Expr>& new_args,
+Expr QuantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                              const Array<tvm::relay::Type>& types) {
-  CHECK_EQ(new_args.size(), 1);
+  CHECK_EQ(new_args.size(), 3);
   auto& data = new_args[0];
+  auto& output_scale = new_args[1];
+  auto& output_zero_point = new_args[2];
   const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
   CHECK(quantize_attrs != nullptr);
 
-  CHECK_EQ(types.size(), 2);
-  return QuantizeLower(data, quantize_attrs);
+  CHECK_EQ(types.size(), 4);
+  return QuantizeLower(data, output_scale, output_zero_point, quantize_attrs);
 }
 
 RELAY_REGISTER_OP("qnn.quantize")
@@ -108,8 +108,10 @@ scale and zero point.
           or quantized.
 )code" TVM_ADD_FILELINE)
 .set_attrs_type<QuantizeAttrs>()
-.set_num_inputs(1)
+.set_num_inputs(3)
 .add_argument("data", "Tensor", "The tensor to quantize.")
+.add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")
+.add_argument("output_zero_point", "Tensor", "The quantization zero_point of the output tensor.")
 .set_support_level(11)
 .add_type_rel("Quantize", QuantizeRel)
 .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QuantizeQnnCanonicalize);
index 93284cb..68b0b08 100644 (file)
@@ -54,31 +54,35 @@ TVM_REGISTER_NODE_TYPE(RequantizeAttrs);
  *       4) Add the output zero point.
  *       5) Cast to the out_dtype.
  */
-Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
+Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
+                     const Expr& input_zero_point, const Expr& output_scale,
+                     const Expr& output_zero_point, const RequantizeAttrs* param,
                      const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
-  double double_multiplier = param->input_scale / param->output_scale;
+  float input_scale_float = GetScalarFromConstant<float>(input_scale);
+  float output_scale_float = GetScalarFromConstant<float>(output_scale);
+  double double_multiplier =
+      static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);
 
   DataType hp_dtype = DataType::Int(64);
 
   auto tensor = Cast(input_tensor, hp_dtype);
   // 1) Subtract the input_zero_point
-  if (param->input_zero_point != 0) {
-    auto input_zp = MakeConstantScalar(hp_dtype, param->input_zero_point);
-    tensor = Subtract(tensor, input_zp);
+  auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
+  if (!IsEqualScalar(input_zero_point, zero_scalar)) {
+    tensor = Subtract(tensor, Cast(input_zero_point, hp_dtype));
   }
 
   // 2) If the input and output scales are same, we can skip the fixed point multiplication.
   auto scaled_int64_t = tensor;
-  if (param->input_scale != param->output_scale) {
+  if (!IsEqualScalar(input_scale, output_scale)) {
     scaled_int64_t =
         FixedPointMultiply(scaled_int64_t, double_multiplier, input_shape, param->rounding);
   }
 
   // 3) Add the output zero point.
   auto shifted_int64_t = scaled_int64_t;
-  if (param->output_zero_point != 0) {
-    auto output_zp = MakeConstantScalar(hp_dtype, param->output_zero_point);
-    shifted_int64_t = Add(output_zp, scaled_int64_t);
+  if (!IsEqualScalar(output_zero_point, zero_scalar)) {
+    shifted_int64_t = Add(Cast(output_zero_point, hp_dtype), scaled_int64_t);
   }
 
   // 4) Clip to the out_dtype min/max.
@@ -103,13 +107,17 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
  */
 Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                                const Array<tvm::relay::Type>& types) {
-  CHECK_EQ(new_args.size(), 1);
+  CHECK_EQ(new_args.size(), 5);
   auto& quantized_data = new_args[0];
+  auto& input_scale = new_args[1];
+  auto& input_zero_point = new_args[2];
+  auto& output_scale = new_args[3];
+  auto& output_zero_point = new_args[4];
   const auto* param = attrs.as<RequantizeAttrs>();
   CHECK(param != nullptr);
 
   // Find input shape.
-  CHECK_EQ(types.size(), 2);
+  CHECK_EQ(types.size(), 6);
   auto in_type = types[0];
   auto in_tensor_type = in_type.as<TensorTypeNode>();
   CHECK(in_tensor_type != nullptr) << "Type information missing."
@@ -117,7 +125,7 @@ Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   Array<IndexExpr> input_shape = in_tensor_type->shape;
 
   // Find the output dtype.
-  auto out_type = types[1];
+  auto out_type = types[5];
   auto out_tensor_type = out_type.as<TensorTypeNode>();
   CHECK(out_tensor_type != nullptr) << "Type information missing."
                                     << " Please run infer_type pass.";
@@ -127,7 +135,8 @@ Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   CHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST")
       << "QNN requantize supports two rounding modes - UPWARD and "
       << "TONEAREST";
-  return RequantizeLower(quantized_data, param, input_shape, out_dtype);
+  return RequantizeLower(quantized_data, input_scale, input_zero_point, output_scale,
+                         output_zero_point, param, input_shape, out_dtype);
 }
 
 /*
@@ -140,7 +149,7 @@ Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
  */
 bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                    const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 2);
+  CHECK_EQ(types.size(), 6);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto in_dtype = data->dtype;
   CHECK(in_dtype == DataType::Int(8) ||
@@ -148,6 +157,12 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
         in_dtype == DataType::Int(32))
       << "Input type should be one of [int8, uint8, int32] but was " << in_dtype;
 
+  // Check the types of scale and zero points.
+  CHECK(IsScalarType(types[1], DataType::Float(32)));  // input_scale
+  CHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
+  CHECK(IsScalarType(types[3], DataType::Float(32)));  // output_scale
+  CHECK(IsScalarType(types[4], DataType::Int(32)));    // output_zero_point
+
   const Array<tvm::Expr> oshape = data->shape;
   // assign output type
   const RequantizeAttrs* param = attrs.as<RequantizeAttrs>();
@@ -156,23 +171,20 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
         out_dtype == DataType::UInt(8) ||
         out_dtype == DataType::Int(32))
       << "Output type should be one of [int8, uint8, int32] but was " << out_dtype;
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, out_dtype));
+  reporter->Assign(types[5], TensorTypeNode::make(oshape, out_dtype));
   return true;
 }
 
 // Positional relay function to create qnn requantize operator
 // used by frontend FFI.
-Expr MakeRequantize(Expr data, double input_scale, int32_t input_zero_point, double output_scale,
-                    int32_t output_zero_point, std::string rounding, DataType out_dtype) {
+Expr MakeRequantize(Expr data, Expr input_scale, Expr input_zero_point, Expr output_scale,
+                    Expr output_zero_point, std::string rounding, DataType out_dtype) {
   auto attrs = make_object<RequantizeAttrs>();
-  attrs->input_scale = std::move(input_scale);
-  attrs->input_zero_point = std::move(input_zero_point);
-  attrs->output_scale = std::move(output_scale);
-  attrs->output_zero_point = std::move(output_zero_point);
   attrs->rounding = std::move(rounding);
   attrs->out_dtype = std::move(out_dtype);
   static const Op& op = Op::Get("qnn.requantize");
-  return CallNode::make(op, {data}, Attrs(attrs), {});
+  return CallNode::make(op, {data, input_scale, input_zero_point, output_scale, output_zero_point},
+                        Attrs(attrs), {});
 }
 
 RELAY_REGISTER_OP("qnn.requantize")
@@ -185,8 +197,12 @@ Q_output = zp_output +  (scale_input)/(scale_output) * (Q_input - zp_input)
 
 )code" TVM_ADD_FILELINE)
 .set_attrs_type<RequantizeAttrs>()
-.set_num_inputs(1)
+.set_num_inputs(5)
 .add_argument("data", "Tensor", "The quantized input tensor.")
+.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.")
+.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
+.add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")
+.add_argument("output_zero_point", "Tensor", "The quantization zero_point of the output tensor.")
 .set_support_level(11)
 .add_type_rel("Requantize", RequantizeRel)
 .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", RequantizeQnnCanonicalize);
index e359296..67b934e 100644 (file)
@@ -77,21 +77,20 @@ static inline const int32_t GetQmax(const DataType& dtype) {
   }
 }
 
-Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
+Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
+                     const Expr& input_zero_point, const Expr& output_scale,
+                     const Expr& output_zero_point, const RequantizeAttrs* param,
                      const Array<IndexExpr>& input_shape, const DataType& out_dtype);
 
 static inline Expr Requantize(const Expr& data, const Array<IndexExpr>& input_shape,
-                              double input_scale, int32_t input_zero_point, double output_scale,
-                              int32_t output_zero_point, const DataType& out_dtype,
-                              const std::string& rounding = "UPWARD") {
+                              const Expr& input_scale, const Expr& input_zero_point,
+                              const Expr& output_scale, const Expr& output_zero_point,
+                              const DataType& out_dtype, const std::string& rounding = "UPWARD") {
   auto attrs = make_object<RequantizeAttrs>();
-  attrs->input_scale = std::move(input_scale);
-  attrs->input_zero_point = std::move(input_zero_point);
-  attrs->output_scale = std::move(output_scale);
-  attrs->output_zero_point = std::move(output_zero_point);
   attrs->rounding = std::move(rounding);
   attrs->out_dtype = std::move(out_dtype);
-  return RequantizeLower(data, attrs.operator->(), input_shape, out_dtype);
+  return RequantizeLower(data, input_scale, input_zero_point, output_scale, output_zero_point,
+                         attrs.operator->(), input_shape, out_dtype);
 }
 
 static inline int64_t get_const_int(const tvm::Expr& x) {
@@ -122,10 +121,22 @@ static inline int64_t get_const_int(const tvm::Expr& x) {
  *       2) Round the result.
  *       3) Right shift the result
  */
-Expr FixedPointMultiply(Expr tensor, double multiplier,
-                        const Array<IndexExpr>& input_shape,
+Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>& input_shape,
                         const std::string& rounding);
 
+/*
+ * \brief Checks whether an expr type is scalar of a given data type.
+ * \param expr_type The type of expr to be checked.
+ * \param dtype The expected dtype.
+ * \return True if the type is a scalar of given dtype
+ */
+static inline bool IsScalarType(const Type& expr_type, const DataType& dtype) {
+  const auto* scale = expr_type.as<TensorTypeNode>();
+  CHECK_EQ(scale->shape.size(), 0);
+  CHECK(scale->dtype == dtype) << "Expected " << dtype << " but got " << scale->dtype;
+  return true;
+}
+
 }  // namespace qnn
 }  // namespace relay
 }  // namespace tvm
index e919206..033a104 100644 (file)
@@ -27,12 +27,12 @@ def test_tflite_same_io_qnn_params():
     x = relay.var("x", shape=(1, 4), dtype=data_dtype)
     y = relay.var("y", shape=(1, 4), dtype=data_dtype)
     z = relay.qnn.op.add(lhs=x, rhs=y,
-                         lhs_scale=0.00784314,
-                         lhs_zero_point=127,
-                         rhs_scale=0.00784314,
-                         rhs_zero_point=127,
-                         output_scale=0.00784314,
-                         output_zero_point=127)
+                         lhs_scale=relay.const(0.00784314, 'float32'),
+                         lhs_zero_point=relay.const(127, 'int32'),
+                         rhs_scale=relay.const(0.00784314, 'float32'),
+                         rhs_zero_point=relay.const(127, 'int32'),
+                         output_scale=relay.const(0.00784314, 'float32'),
+                         output_zero_point=relay.const(127, 'int32'))
 
     func = relay.Function([x, y], z)
     mod = relay.Module.from_expr(func)
@@ -65,12 +65,12 @@ def test_tflite_different_io_qnn_params():
     x = relay.var("x", shape=(1, 4), dtype=data_dtype)
     y = relay.var("y", shape=(1, 4), dtype=data_dtype)
     z = relay.qnn.op.add(lhs=x, rhs=y,
-                         lhs_scale=0.0156863,
-                         lhs_zero_point=127,
-                         rhs_scale=0.0117647,
-                         rhs_zero_point=85,
-                         output_scale=0.0235294,
-                         output_zero_point=128)
+                         lhs_scale=relay.const(0.0156863, 'float32'),
+                         lhs_zero_point=relay.const(127, 'int32'),
+                         rhs_scale=relay.const(0.0117647, 'float32'),
+                         rhs_zero_point=relay.const(85, 'int32'),
+                         output_scale=relay.const(0.0235294, 'float32'),
+                         output_zero_point=relay.const(128, 'int32'))
 
     func = relay.Function([x, y], z)
     mod = relay.Module.from_expr(func)
@@ -103,12 +103,12 @@ def test_saturation():
     x = relay.var("x", shape=(1, 4), dtype=data_dtype)
     y = relay.var("y", shape=(1, 4), dtype=data_dtype)
     z = relay.qnn.op.add(lhs=x, rhs=y,
-                         lhs_scale=0.125,
-                         lhs_zero_point=0,
-                         rhs_scale=0.125,
-                         rhs_zero_point=0,
-                         output_scale=0.125,
-                         output_zero_point=0)
+                         lhs_scale=relay.const(0.125, 'float32'),
+                         lhs_zero_point=relay.const(0, 'int32'),
+                         rhs_scale=relay.const(0.125, 'float32'),
+                         rhs_zero_point=relay.const(0, 'int32'),
+                         output_scale=relay.const(0.125, 'float32'),
+                         output_zero_point=relay.const(0, 'int32'))
 
     func = relay.Function([x, y], z)
     mod = relay.Module.from_expr(func)
@@ -125,12 +125,12 @@ def test_saturation():
 
     # Same params, different scale
     z = relay.qnn.op.add(lhs=x, rhs=y,
-                         lhs_scale=0.125,
-                         lhs_zero_point=0,
-                         rhs_scale=0.125,
-                         rhs_zero_point=0,
-                         output_scale=0.25,
-                         output_zero_point=0)
+                         lhs_scale=relay.const(0.125, 'float32'),
+                         lhs_zero_point=relay.const(0, 'int32'),
+                         rhs_scale=relay.const(0.125, 'float32'),
+                         rhs_zero_point=relay.const(0, 'int32'),
+                         output_scale=relay.const(0.25, 'float32'),
+                         output_zero_point=relay.const(0, 'int32'))
 
     func = relay.Function([x, y], z)
     mod = relay.Module.from_expr(func)
@@ -147,12 +147,12 @@ def test_saturation():
 
     # Same io params, different output scale
     z = relay.qnn.op.add(lhs=x, rhs=y,
-                         lhs_scale=0.125,
-                         lhs_zero_point=0,
-                         rhs_scale=0.125,
-                         rhs_zero_point=0,
-                         output_scale=0.25,
-                         output_zero_point=0)
+                         lhs_scale=relay.const(0.125, 'float32'),
+                         lhs_zero_point=relay.const(0, 'int32'),
+                         rhs_scale=relay.const(0.125, 'float32'),
+                         rhs_zero_point=relay.const(0, 'int32'),
+                         output_scale=relay.const(0.25, 'float32'),
+                         output_zero_point=relay.const(0, 'int32'))
 
     func = relay.Function([x, y], z)
     mod = relay.Module.from_expr(func)
@@ -169,12 +169,12 @@ def test_saturation():
 
     # All params different
     z = relay.qnn.op.add(lhs=x, rhs=y,
-                         lhs_scale=0.5,
-                         lhs_zero_point=0,
-                         rhs_scale=0.25,
-                         rhs_zero_point=0,
-                         output_scale=0.125,
-                         output_zero_point=0)
+                         lhs_scale=relay.const(0.5, 'float32'),
+                         lhs_zero_point=relay.const(0, 'int32'),
+                         rhs_scale=relay.const(0.25, 'float32'),
+                         rhs_zero_point=relay.const(0, 'int32'),
+                         output_scale=relay.const(0.125, 'float32'),
+                         output_zero_point=relay.const(0, 'int32'))
 
     func = relay.Function([x, y], z)
     mod = relay.Module.from_expr(func)
index b24e1a0..ed49694 100644 (file)
@@ -26,16 +26,17 @@ def test_same_io_qnn_params():
     axis = 0
     x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype)
     y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype)
-    x_scale = (62 + 64) / (np.power(2, 32) - 1.0)
-    y_scale = (62 + 64) / (np.power(2, 32) - 1.0)
+    x_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
+    y_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
+    zero = relay.const(0, 'int32')
 
     x = relay.var("x", shape=(1, 64), dtype=data_dtype)
     y = relay.var("y", shape=(1, 64), dtype=data_dtype)
     z = relay.qnn.op.concatenate((x, y),
-                                 input_scales=[x_scale, y_scale],
-                                 input_zero_points=[0, 0],
+                                 input_scales=(x_scale, y_scale),
+                                 input_zero_points=(zero, zero),
                                  output_scale=y_scale,
-                                 output_zero_point=0,
+                                 output_zero_point=zero,
                                  axis=axis)
 
     func = relay.Function([x, y], z)
@@ -54,16 +55,19 @@ def test_different_io_qnn_params():
     axis = 0
     x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype)
     y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype)
-    x_scale = (62 + 64) / (np.power(2, 32) - 1.0)
-    y_scale = (62 + 64) / (np.power(2, 32) - 1.0)
+
+    x_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
+    y_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
+    x_zero_point = relay.const(3, 'int32')
+    y_zero_point = relay.const(4, 'int32')
 
     x = relay.var("x", shape=(1, 64), dtype=data_dtype)
     y = relay.var("y", shape=(1, 64), dtype=data_dtype)
     z = relay.qnn.op.concatenate((x, y),
-                                 input_scales=[x_scale, y_scale],
-                                 input_zero_points=[3, 4],
+                                 input_scales=(x_scale, y_scale),
+                                 input_zero_points=(x_zero_point, y_zero_point),
                                  output_scale=y_scale,
-                                 output_zero_point=1,
+                                 output_zero_point=relay.const(1, 'int32'),
                                  axis=axis)
 
     func = relay.Function([x, y], z)
@@ -82,16 +86,19 @@ def test_few_same_io_qnn_params():
     axis = 0
     x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype)
     y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype)
-    x_scale = (62 + 64) / (np.power(2, 32) - 1.0)
-    y_scale = (62 + 64) / (np.power(2, 32) - 1.0)
+
+    x_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
+    y_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
+    x_zero_point = relay.const(0, 'int32')
+    y_zero_point = relay.const(1, 'int32')
 
     x = relay.var("x", shape=(1, 64), dtype=data_dtype)
     y = relay.var("y", shape=(1, 64), dtype=data_dtype)
     z = relay.qnn.op.concatenate((x, y),
-                                 input_scales=[x_scale, y_scale],
-                                 input_zero_points=[0, 1],
+                                 input_scales=(x_scale, y_scale),
+                                 input_zero_points=(x_zero_point, y_zero_point),
                                  output_scale=y_scale,
-                                 output_zero_point=1,
+                                 output_zero_point=relay.const(1, 'int32'),
                                  axis=axis)
 
     func = relay.Function([x, y], z)
@@ -110,16 +117,19 @@ def test_same_i_qnn_params():
     axis = 0
     x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype)
     y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype)
-    x_scale = (62 + 64) / (np.power(2, 32) - 1.0)
-    y_scale = (62 + 64) / (np.power(2, 32) - 1.0)
+
+    x_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
+    y_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
+    x_zero_point = relay.const(0, 'int32')
+    y_zero_point = relay.const(0, 'int32')
 
     x = relay.var("x", shape=(1, 64), dtype=data_dtype)
     y = relay.var("y", shape=(1, 64), dtype=data_dtype)
     z = relay.qnn.op.concatenate((x, y),
-                                 input_scales=[x_scale, y_scale],
-                                 input_zero_points=[0, 0],
+                                 input_scales=(x_scale, y_scale),
+                                 input_zero_points=(x_zero_point, y_zero_point),
                                  output_scale=y_scale,
-                                 output_zero_point=1,
+                                 output_zero_point=relay.const(1, 'int32'),
                                  axis=axis)
 
     func = relay.Function([x, y], z)
index eda47e6..9effa6f 100644 (file)
@@ -52,16 +52,16 @@ def get_ref_func(data,
     shifted_kernel = relay.op.subtract(casted_kernel,
             relay.const(kernel_zero_point, "int32"))
     func = relay.op.nn.conv2d(shifted_data,
-                             shifted_kernel,
-                             padding=padding,
-                             strides=strides,
-                             dilation=dilation,
-                             groups=groups,
-                             channels=channels,
-                             kernel_size=kernel_size,
-                             out_dtype=out_dtype,
-                             data_layout=data_layout,
-                             kernel_layout=kernel_layout)
+                              shifted_kernel,
+                              padding=padding,
+                              strides=strides,
+                              dilation=dilation,
+                              groups=groups,
+                              channels=channels,
+                              kernel_size=kernel_size,
+                              out_dtype=out_dtype,
+                              data_layout=data_layout,
+                              kernel_layout=kernel_layout)
 
     func = relay.Function(relay.analysis.free_vars(func), func)
     return func
@@ -83,10 +83,10 @@ def get_qnn_func(data,
                  channels=None):
     func = relay.qnn.op.conv2d(
             data, kernel,
-            input_zero_point=input_zero_point,
-            kernel_zero_point=kernel_zero_point,
-            input_scale=input_scale,
-            kernel_scale=kernel_scale,
+            input_zero_point=relay.const(input_zero_point, 'int32'),
+            kernel_zero_point=relay.const(kernel_zero_point, 'int32'),
+            input_scale=relay.const(input_scale, 'float32'),
+            kernel_scale=relay.const(kernel_scale, 'float32'),
             kernel_size=kernel_size,
             strides=strides,
             dilation=dilation,
index e0df76d..11987a5 100644 (file)
@@ -179,10 +179,10 @@ def qnn_dense_driver(test_configuration):
     mod = relay.qnn.op.dense(
         quantized_data,
         quantized_kernel,
-        test_configuration['input_zero_point'],
-        test_configuration['kernel_zero_point'],
-        test_configuration['input_scale'],
-        test_configuration['kernel_scale'],
+        relay.const(test_configuration['input_zero_point'], 'int32'),
+        relay.const(test_configuration['kernel_zero_point'], 'int32'),
+        relay.const(test_configuration['input_scale'], 'float32'),
+        relay.const(test_configuration['kernel_scale'], 'float32'),
         test_configuration['units'])
     if test_configuration[bias_name] is not None:
         bias = relay.var(bias_name,
@@ -193,10 +193,10 @@ def qnn_dense_driver(test_configuration):
         requantize_config = test_configuration['requantize']
         mod = relay.qnn.op.requantize(
             mod,
-            input_scale=requantize_config['input_scale'],
-            input_zero_point=0,
-            output_scale=requantize_config['output_scale'],
-            output_zero_point=requantize_config['output_zero_point'],
+            input_scale=relay.const(requantize_config['input_scale'], 'float32'),
+            input_zero_point=relay.const(0, 'int32'),
+            output_scale=relay.const(requantize_config['output_scale'], 'float32'),
+            output_zero_point=relay.const(requantize_config['output_zero_point'], 'int32'),
             out_dtype=requantize_config['out_dtype'])
         expected_out_dtype = requantize_config['out_dtype']
 
index a99e78d..4510c57 100644 (file)
@@ -20,61 +20,56 @@ import numpy as np
 from tvm import relay
 from tvm.contrib import graph_runtime
 
-def test_dequantize_op():
+def quantize_test_driver(in_dtype, quant_args, in_data, verify_output_data):
+    shape = in_data.shape
+    input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
+    input_zero_point = relay.const(quant_args['in_zero_point'], 'int32')
+    input_scale = relay.const(quant_args['in_scale'], 'float32')
+    quantized_output = relay.qnn.op.dequantize(input_data, input_scale=input_scale,
+                                               input_zero_point=input_zero_point)
+    mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
+    mod = relay.Module.from_expr(mod)
+    with relay.build_config(opt_level=3):
+        graph, lib, params = relay.build(mod, "llvm", params=None)
+        rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+        rt_mod.set_input(input_data=in_data)
+        rt_mod.set_input(**params)
+        rt_mod.run()
+        res = rt_mod.get_output(0).asnumpy()
+        np.testing.assert_equal(res, verify_output_data)
+        assert res.dtype == np.float32
 
-    def quantize_test_driver(in_dtype, quant_args, in_data, verify_output_data):
-        shape = in_data.shape
-        input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
-        input_zero_point = quant_args['in_zero_point']
-        input_scale = quant_args['in_scale']
-        quantized_output = relay.qnn.op.dequantize(input_data, input_scale=input_scale,
-                                                   input_zero_point=input_zero_point)
-        mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
-        mod = relay.Module.from_expr(mod)
-        with relay.build_config(opt_level=3):
-            graph, lib, params = relay.build(mod, "llvm", params=None)
-            rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
-            rt_mod.set_input(input_data=in_data)
-            rt_mod.set_input(**params)
-            rt_mod.run()
-            res = rt_mod.get_output(0).asnumpy()
-            np.testing.assert_equal(res, verify_output_data)
-            assert res.dtype == np.float32
+def test_uint8_to_float32():
+    data = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \
+        .astype('uint8') \
+        .reshape((2, 5))
+    output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
+        .astype('float32') \
+        .reshape((2, 5))
+    quant_args = {"in_zero_point":127, "in_scale":0.5}
+    quantize_test_driver(in_dtype='uint8', quant_args=quant_args, in_data=data,
+                         verify_output_data=output)
 
-    def test_uint8_to_float32():
-        data = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \
-            .astype('uint8') \
-            .reshape((2, 5))
-        output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
-            .astype('float32') \
-            .reshape((2, 5))
-        quant_args = {"in_zero_point":127, "in_scale":0.5}
-        quantize_test_driver(in_dtype='uint8', quant_args=quant_args, in_data=data,
-                             verify_output_data=output)
+def test_int8_to_float32():
+    data = np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127]) \
+        .astype('int8') \
+        .reshape((2, 5))
+    output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
+        .astype('float32') \
+        .reshape((2, 5))
+    quant_args = {"in_zero_point": -1, "in_scale": 0.5}
+    quantize_test_driver(in_dtype='int8', quant_args=quant_args, in_data=data,
+                         verify_output_data=output)
 
-    def test_int8_to_float32():
-        data = np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127]) \
-            .astype('int8') \
-            .reshape((2, 5))
-        output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
-            .astype('float32') \
-            .reshape((2, 5))
-        quant_args = {"in_zero_point": -1, "in_scale": 0.5}
-        quantize_test_driver(in_dtype='int8', quant_args=quant_args, in_data=data,
-                             verify_output_data=output)
+def test_int32_to_float32():
+    data = np.array([113, 29, -1052]).astype('int32')
+    output = np.array([0.6550452, 0.16810896, -6.098297]).astype('float32')
+    quant_args = {"in_zero_point": 0, "in_scale": 0.0057968604}
+    quantize_test_driver(in_dtype='int32', quant_args=quant_args, in_data=data,
+                         verify_output_data=output)
 
-    def test_int32_to_float32():
-        data = np.array([113, 29, -1052]).astype('int32')
-        output = np.array([0.6550452, 0.16810896, -6.098297]).astype('float32')
-        quant_args = {"in_zero_point": 0, "in_scale": 0.0057968604}
-        quantize_test_driver(in_dtype='int32', quant_args=quant_args, in_data=data,
-                             verify_output_data=output)
 
+if __name__ == "__main__":
     test_uint8_to_float32()
     test_int8_to_float32()
     test_int32_to_float32()
-
-
-if __name__ == "__main__":
-    test_dequantize_op()
-
index 8c08c1a..16f0be7 100644 (file)
@@ -44,12 +44,12 @@ def test_tflite_same_io_qnn_params():
     x = relay.var("x", shape=(1, 4), dtype=data_dtype)
     y = relay.var("y", shape=(1, 4), dtype=data_dtype)
     z = relay.qnn.op.mul(lhs=x, rhs=y,
-                         lhs_scale=lhs_scale,
-                         lhs_zero_point=lhs_zero_point,
-                         rhs_scale=rhs_scale,
-                         rhs_zero_point=rhs_zero_point,
-                         output_scale=output_scale,
-                         output_zero_point=output_zero_point)
+                         lhs_scale=relay.const(lhs_scale, 'float32'),
+                         lhs_zero_point=relay.const(lhs_zero_point, 'int32'),
+                         rhs_scale=relay.const(rhs_scale, 'float32'),
+                         rhs_zero_point=relay.const(rhs_zero_point, 'int32'),
+                         output_scale=relay.const(output_scale, 'float32'),
+                         output_zero_point=relay.const(output_zero_point, 'int32'))
 
     func = relay.Function([x, y], z)
     mod = relay.Module.from_expr(func)
@@ -95,12 +95,12 @@ def test_tflite_different_io_qnn_params():
     x = relay.var("x", shape=(1, 4), dtype=data_dtype)
     y = relay.var("y", shape=(1, 4), dtype=data_dtype)
     z = relay.qnn.op.mul(lhs=x, rhs=y,
-                         lhs_scale=lhs_scale,
-                         lhs_zero_point=lhs_zero_point,
-                         rhs_scale=rhs_scale,
-                         rhs_zero_point=rhs_zero_point,
-                         output_scale=output_scale,
-                         output_zero_point=output_zero_point)
+                         lhs_scale=relay.const(lhs_scale, 'float32'),
+                         lhs_zero_point=relay.const(lhs_zero_point, 'int32'),
+                         rhs_scale=relay.const(rhs_scale, 'float32'),
+                         rhs_zero_point=relay.const(rhs_zero_point, 'int32'),
+                         output_scale=relay.const(output_scale, 'float32'),
+                         output_zero_point=relay.const(output_zero_point, 'int32'))
 
     func = relay.Function([x, y], z)
     mod = relay.Module.from_expr(func)
@@ -141,12 +141,12 @@ def test_saturation():
     x = relay.var("x", shape=(1, 4), dtype=data_dtype)
     y = relay.var("y", shape=(1, 4), dtype=data_dtype)
     z = relay.qnn.op.mul(lhs=x, rhs=y,
-                         lhs_scale=lhs_scale,
-                         lhs_zero_point=lhs_zero_point,
-                         rhs_scale=rhs_scale,
-                         rhs_zero_point=rhs_zero_point,
-                         output_scale=output_scale,
-                         output_zero_point=output_zero_point)
+                         lhs_scale=relay.const(lhs_scale, 'float32'),
+                         lhs_zero_point=relay.const(lhs_zero_point, 'int32'),
+                         rhs_scale=relay.const(rhs_scale, 'float32'),
+                         rhs_zero_point=relay.const(rhs_zero_point, 'int32'),
+                         output_scale=relay.const(output_scale, 'float32'),
+                         output_zero_point=relay.const(output_zero_point, 'int32'))
 
     func = relay.Function([x, y], z)
     mod = relay.Module.from_expr(func)
@@ -172,12 +172,12 @@ def test_saturation():
     output_scale = 0.25
 
     z = relay.qnn.op.mul(lhs=x, rhs=y,
-                         lhs_scale=lhs_scale,
-                         lhs_zero_point=lhs_zero_point,
-                         rhs_scale=rhs_scale,
-                         rhs_zero_point=rhs_zero_point,
-                         output_scale=output_scale,
-                         output_zero_point=output_zero_point)
+                         lhs_scale=relay.const(lhs_scale, 'float32'),
+                         lhs_zero_point=relay.const(lhs_zero_point, 'int32'),
+                         rhs_scale=relay.const(rhs_scale, 'float32'),
+                         rhs_zero_point=relay.const(rhs_zero_point, 'int32'),
+                         output_scale=relay.const(output_scale, 'float32'),
+                         output_zero_point=relay.const(output_zero_point, 'int32'))
 
     func = relay.Function([x, y], z)
     mod = relay.Module.from_expr(func)
@@ -204,12 +204,12 @@ def test_saturation():
     output_scale = 0.125
 
     z = relay.qnn.op.mul(lhs=x, rhs=y,
-                         lhs_scale=lhs_scale,
-                         lhs_zero_point=lhs_zero_point,
-                         rhs_scale=rhs_scale,
-                         rhs_zero_point=rhs_zero_point,
-                         output_scale=output_scale,
-                         output_zero_point=output_zero_point)
+                         lhs_scale=relay.const(lhs_scale, 'float32'),
+                         lhs_zero_point=relay.const(lhs_zero_point, 'int32'),
+                         rhs_scale=relay.const(rhs_scale, 'float32'),
+                         rhs_zero_point=relay.const(rhs_zero_point, 'int32'),
+                         output_scale=relay.const(output_scale, 'float32'),
+                         output_zero_point=relay.const(output_zero_point, 'int32'))
 
     func = relay.Function([x, y], z)
     mod = relay.Module.from_expr(func)
index 9805db5..6871888 100644 (file)
@@ -20,51 +20,47 @@ import numpy as np
 from tvm import relay
 from tvm.contrib import graph_runtime
 
-def test_quantize_op():
+def quantize_test_driver(in_dtype, quant_args, out_dtype, in_data, verify_output_data):
+    shape = in_data.shape
+    input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
+    output_zero_point = relay.const(quant_args['out_zero_point'], 'int32')
+    output_scale = relay.const(quant_args['out_scale'], 'float32')
+    quantized_output = relay.qnn.op.quantize(input_data, output_scale=output_scale,
+                                             output_zero_point=output_zero_point,out_dtype=out_dtype)
+    mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
+    mod = relay.Module.from_expr(mod)
+    with relay.build_config(opt_level=3):
+        graph, lib, params = relay.build(mod, "llvm", params=None)
+        rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+        rt_mod.set_input(input_data=in_data)
+        rt_mod.set_input(**params)
+        rt_mod.run()
+        res = rt_mod.get_output(0).asnumpy()
+        np.testing.assert_equal(res, verify_output_data)
+        assert res.dtype == out_dtype
 
-    def quantize_test_driver(in_dtype, quant_args, out_dtype, in_data, verify_output_data):
-        shape = in_data.shape
-        input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
-        output_zero_point = quant_args['out_zero_point']
-        output_scale = quant_args['out_scale']
-        quantized_output = relay.qnn.op.quantize(input_data, output_scale=output_scale,
-                                                 output_zero_point=output_zero_point,out_dtype=out_dtype)
-        mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
-        mod = relay.Module.from_expr(mod)
-        with relay.build_config(opt_level=3):
-            graph, lib, params = relay.build(mod, "llvm", params=None)
-            rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
-            rt_mod.set_input(input_data=in_data)
-            rt_mod.set_input(**params)
-            rt_mod.run()
-            res = rt_mod.get_output(0).asnumpy()
-            np.testing.assert_equal(res, verify_output_data)
-            assert res.dtype == out_dtype
+def test_float32_to_uint8():
+    data = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
+        .astype('float32') \
+        .reshape((2,5))
+    output = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \
+        .astype('uint8') \
+        .reshape((2,5))
+    quant_args = {"out_zero_point":127, "out_scale":0.5}
+    quantize_test_driver(in_dtype='float32', quant_args=quant_args, out_dtype='uint8', in_data=data,
+                         verify_output_data=output)
 
-    def test_float32_to_uint8():
-        data = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
-            .astype('float32') \
-            .reshape((2,5))
-        output = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \
-            .astype('uint8') \
-            .reshape((2,5))
-        quant_args = {"out_zero_point":127, "out_scale":0.5}
-        quantize_test_driver(in_dtype='float32', quant_args=quant_args, out_dtype='uint8', in_data=data,
-                             verify_output_data=output)
-
-    def test_float32_to_int8():
-        data = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
-            .astype('float32') \
-            .reshape((2,5))
-        output = np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127]) \
-            .astype('int8') \
-            .reshape((2,5))
-        quant_args = {"out_zero_point":-1, "out_scale":0.5}
-        quantize_test_driver(in_dtype='float32', quant_args=quant_args, out_dtype='int8', in_data=data,
-                             verify_output_data=output)
+def test_float32_to_int8():
+    data = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
+        .astype('float32') \
+        .reshape((2,5))
+    output = np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127]) \
+        .astype('int8') \
+        .reshape((2,5))
+    quant_args = {"out_zero_point":-1, "out_scale":0.5}
+    quantize_test_driver(in_dtype='float32', quant_args=quant_args, out_dtype='int8', in_data=data,
+                         verify_output_data=output)
 
+if __name__ == "__main__":
     test_float32_to_uint8()
     test_float32_to_int8()
-
-if __name__ == "__main__":
-    test_quantize_op()
index 3818135..6d6b9b4 100644 (file)
@@ -39,10 +39,10 @@ def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale,
             dtype=data_dtype)
     mod = relay.qnn.op.requantize(
             quantized_data,
-            input_scale=input_scale,
-            input_zero_point=input_zero_point,
-            output_scale=output_scale,
-            output_zero_point=output_zero_point,
+            input_scale=relay.const(input_scale, 'float32'),
+            input_zero_point=relay.const(input_zero_point, 'int32'),
+            output_scale=relay.const(output_scale, 'float32'),
+            output_zero_point=relay.const(output_zero_point, 'int32'),
             rounding=rounding,
             out_dtype=out_dtype)
 
index b57578e..6992f28 100644 (file)
@@ -46,10 +46,10 @@ def test_qnn_legalize():
     def before():
         x = relay.var("x", shape=(1, 64, 56, 56), dtype='int8')
         y = relay.qnn.op.requantize(x,
-                                    input_scale=1,
-                                    input_zero_point=0,
-                                    output_scale=1,
-                                    output_zero_point=0,
+                                    input_scale=relay.const(1, 'float32'),
+                                    input_zero_point=relay.const(0, 'int32'),
+                                    output_scale=relay.const(1, 'float32'),
+                                    output_zero_point=relay.const(0, 'int32'),
                                     out_dtype='int8')
         y = relay.Function([x], y)
         return y
@@ -58,10 +58,10 @@ def test_qnn_legalize():
         data = inputs[0]
         data = relay.add(relay.const(0, 'int8'), data)
         y = relay.qnn.op.requantize(data,
-                                    input_scale=1,
-                                    input_zero_point=0,
-                                    output_scale=1,
-                                    output_zero_point=0,
+                                    input_scale=relay.const(1, 'float32'),
+                                    input_zero_point=relay.const(0, 'int32'),
+                                    output_scale=relay.const(1, 'float32'),
+                                    output_zero_point=relay.const(0, 'int32'),
                                     out_dtype='int8')
         return y
 
@@ -69,10 +69,10 @@ def test_qnn_legalize():
         x = relay.var("x", shape=(1, 64, 56, 56), dtype='int8')
         y = relay.add(relay.const(0, 'int8'), x)
         z = relay.qnn.op.requantize(y,
-                                    input_scale=1,
-                                    input_zero_point=0,
-                                    output_scale=1,
-                                    output_zero_point=0,
+                                    input_scale=relay.const(1, 'float32'),
+                                    input_zero_point=relay.const(0, 'int32'),
+                                    output_scale=relay.const(1, 'float32'),
+                                    output_zero_point=relay.const(0, 'int32'),
                                     out_dtype='int8')
         z = relay.Function([x], z)
         return z
@@ -102,10 +102,10 @@ def test_qnn_legalize_qnn_conv2d():
                 dtype=kernel_dtype)
         func = relay.qnn.op.conv2d(
                 data, kernel,
-                input_zero_point=1,
-                kernel_zero_point=1,
-                input_scale=1.0,
-                kernel_scale=1.0,
+                input_zero_point=relay.const(1, 'int32'),
+                kernel_zero_point=relay.const(1, 'int32'),
+                input_scale=relay.const(1.0, 'float32'),
+                kernel_scale=relay.const(1.0, 'float32'),
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 dilation=(1, 1),
@@ -186,10 +186,10 @@ def test_qnn_legalize_qnn_dense():
                 dtype=kernel_dtype)
         func = relay.qnn.op.dense(
                 data, kernel,
-                input_zero_point=1,
-                kernel_zero_point=1,
-                input_scale=1,
-                kernel_scale=1,
+                input_zero_point=relay.const(1, 'int32'),
+                kernel_zero_point=relay.const(1, 'int32'),
+                input_scale=relay.const(1, 'float32'),
+                kernel_scale=relay.const(1, 'float32'),
                 out_dtype='int32')
 
         mod = relay.Function(relay.analysis.free_vars(func), func)