[REDO AFTER GH BUG] Add support for quantized models via QNN (#5016)
authormasahi <masahi129@gmail.com>
Tue, 10 Mar 2020 01:36:50 +0000 (10:36 +0900)
committerGitHub <noreply@github.com>
Tue, 10 Mar 2020 01:36:50 +0000 (10:36 +0900)
This reverts commit f346c60287b50950275e20db9e6d84b3fc568a00.

python/tvm/relay/frontend/pytorch.py
python/tvm/relay/frontend/qnn_torch.py [new file with mode: 0644]
tests/python/frontend/pytorch/qnn_test.py [new file with mode: 0644]
tests/python/frontend/pytorch/test_forward.py

index e284e48..ff37f82 100644 (file)
@@ -19,6 +19,7 @@
 # pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension
 """PT: PyTorch frontend."""
 import itertools
+import logging
 
 import numpy as np
 
@@ -32,6 +33,8 @@ from .common import get_relay_op
 from .common import infer_shape as _infer_shape
 from .common import infer_value as _infer_value
 
+from . import qnn_torch
+
 __all__ = ["from_pytorch"]
 
 # operator implementation
@@ -146,6 +149,10 @@ def _zeros():
 def _relu():
     def _impl(inputs, input_types):
         data = inputs[0]
+        if input_types[0] == "quint8":
+            assert len(inputs) == 3, "Input quant param not found in op inputs"
+            input_zero_point = _expr.const(inputs[2], dtype="int32")
+            return qnn_torch.quantized_relu(data, input_zero_point)
         return _op.nn.relu(data)
     return _impl
 
@@ -154,9 +161,14 @@ def _adaptive_avg_2d():
         data = inputs[0]
         output_size = _infer_shape(inputs[1])
 
-        return _op.nn.adaptive_avg_pool2d(
-            data,
-            output_size=output_size)
+        def func(x):
+            return _op.nn.adaptive_avg_pool2d(x, output_size=output_size)
+
+        if input_types[0] == "quint8":
+            return qnn_torch.quantized_adaptive_avg_2d(data, func)
+
+        return func(data)
+
     return _impl
 
 def _adaptive_max_2d():
@@ -506,7 +518,18 @@ def _mean():
         else:
             exclude = False
 
-        return _op.mean(data, axis, keepdims, exclude)
+        def func(x):
+            return _op.mean(x, axis, keepdims, exclude)
+
+        if input_types[0] == "quint8":
+            assert len(inputs) == 6, "Input quant param not found in op inputs"
+            input_scale = _expr.const(inputs[4])
+            input_zero_point = _expr.const(inputs[5])
+            return qnn_torch.quantized_mean(data, input_scale,
+                                            input_zero_point, func)
+
+        return func(data)
+
     return _impl
 
 def _chunk():
@@ -668,10 +691,40 @@ def _upsample(method):
         else:
             coord_trans = "half_pixel"
 
-        return _op.image.resize(data, out_size, "NCHW", method, coord_trans)
+        def func(x):
+            return _op.image.resize(x, out_size, "NCHW", method, coord_trans)
+
+        if input_types[0] == "quint8":
+            import torch
+            from packaging import version
+
+            # Torch version > 1.4 changed upsampling API
+            if version.parse(torch.__version__) > version.parse("1.4.0"):
+                num_inputs = 7
+            else:
+                num_inputs = 5
+
+            assert len(inputs) == num_inputs, "Input quant param not found in op inputs"
+
+            input_scale = _expr.const(inputs[-2])
+            input_zero_point = _expr.const(inputs[-1])
+            return qnn_torch.quantized_upsample(data, input_scale,
+                                                input_zero_point, func)
+        return func(data)
 
     return _impl
 
+
+def _expand_as():
+    def _impl(inputs, input_types):
+        # TODO: maybe fix this
+        # This assumes expand_as can be removed because TVM has broadcast op
+        msg = "aten::expand_as(...) found, assume it is part of broadcast op"
+        logging.warning(msg)
+        return inputs[0]
+    return _impl
+
+
 # Helper functions for operator implementation
 
 def _convert_data_type(input_type):
@@ -792,6 +845,7 @@ _convert_map = {
     "aten::detach"                          : _identity(),
     "aten::upsample_bilinear2d"             : _upsample("bilinear"),
     "aten::upsample_nearest2d"              : _upsample("nearest_neighbor"),
+    "aten::expand_as"                       : _expand_as()
 }
 
 
@@ -842,6 +896,7 @@ def _report_missing_conversion(op_names):
                  "prim::ListConstruct", "prim::ListUnpack",
                  "prim::TupleConstruct", "prim::TupleUnpack"]
     known_ops += list(_convert_map.keys())
+    known_ops += list(qnn_torch.convert_map.keys())
 
     missing = [op_name for op_name in op_names
                if op_name not in known_ops]
@@ -1008,6 +1063,7 @@ def parse_params(graph, state_dict):
     getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True)
     params = {}
     param_tensors = {}
+    packed_param_map = {}
     seen = set()
 
     for node in getattr_nodes:
@@ -1020,14 +1076,18 @@ def parse_params(graph, state_dict):
             full_attr = _getattr_full_name(getattrs)
             full_attr_node_name = _get_output_name(getattrs[-1])
 
-            if full_attr in state_dict:
+            if full_attr.endswith("_packed_params"):  # for quantized models
+                err_msg = "parameter %s not found in state dict" % full_attr
+                assert full_attr in state_dict, err_msg
+                packed_param_map[full_attr_node_name] = full_attr
+            elif full_attr in state_dict:
                 torch_tensor = state_dict[full_attr]
                 tensor, var = _get_tensor_and_var(torch_tensor,
                                                   full_attr_node_name)
                 param_tensors[full_attr_node_name] = tensor
                 params[full_attr_node_name] = var
 
-    return params, param_tensors
+    return params, param_tensors, packed_param_map
 
 
 def parse_operators(operators, outputs, output_index_map, ret_name):
@@ -1108,16 +1168,26 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
 
     params = script_module.state_dict()
     input_vars = parse_inputs(graph.inputs(), input_shapes)
-    param_vars, tensors = parse_params(graph, params)
+    param_vars, tensors, packed_param_map = parse_params(graph, params)
+    tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
 
     input_vars.update(param_vars)
     outputs = list(input_vars.values())
     output_index_map = dict(zip(input_vars.keys(), range(len(outputs))))
     ret_name = _get_input_names(graph.return_node())[0]
 
+    # For quantized models
+    if "aten::quantize_per_tensor" in op_names:
+        weight_quant_params = qnn_torch.get_weight_quant_params(script_module)
+        qnn_torch.add_input_quant_params_to_op_inputs(graph)
+        qnn_torch.add_quant_params_to_outputs(outputs, output_index_map,
+                                              packed_param_map,
+                                              weight_quant_params)
+        qnn_torch.add_quant_params(tvm_params, weight_quant_params)
+        _convert_map.update(qnn_torch.convert_map)
+
     body = parse_operators(_get_operator_nodes(graph.nodes()), outputs,
                            output_index_map, ret_name)
     func = tvm.relay.Function(_analysis.free_vars(body), body)
-    tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
 
     return _module.IRModule.from_expr(func), tvm_params
diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py
new file mode 100644 (file)
index 0000000..0704e34
--- /dev/null
@@ -0,0 +1,692 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, import-outside-toplevel
+""" Functions to convert quantized torch models to QNN """
+
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm.relay import expr as _expr
+from tvm.relay import op as _op
+from tvm.relay.frontend.common import infer_shape
+
+
+class QNNParam:
+    """ A placeholder for weight quantization parameters """
+
+    def __init__(self, weight, bias, scale, zero_point, param_key):
+        param_prefix = param_key[:-len("._packed_params")]
+        self.weight_var = _expr.var(param_prefix + "_weight",
+                                    shape=weight.shape)
+        self.weight = weight
+
+        if bias is not None:
+            self.bias_var = _expr.var(param_prefix + "_bias",
+                                      shape=bias.shape)
+            self.bias = bias.detach().numpy()
+        else:
+            self.bias_var = None
+            self.bias = None
+
+        self.scale = _expr.const(scale)
+        self.zero_point = _expr.const(zero_point, dtype="int32")
+
+
+def _unpack_quant_params(param_name, packed_params, unpack_func):
+    # Torch stores quantized params in a custom packed format,
+    # need to unpack and retrieve them as numpy arrays
+    qweight, bias = unpack_func(packed_params)
+    weight_np = qweight.dequantize().numpy()
+
+    import torch
+    if qweight.qscheme() == torch.per_tensor_affine:
+        param = QNNParam(weight_np, bias, qweight.q_scale(),
+                         int(qweight.q_zero_point()), param_name)
+    else:
+        scales = qweight.q_per_channel_scales().numpy()
+        zero_points = qweight.q_per_channel_zero_points().numpy()
+        # This is an assumption posed by QNN
+        msg = "The values of zero points should be all zero for per channel"
+        assert np.all(zero_points == 0), msg
+        param = QNNParam(weight_np, bias, scales, 0, param_name)
+
+    return param
+
+
+def get_weight_quant_params(script_module):
+    """ Retrive and unpack weight parameters from quantized modules """
+    conv_packed_params = []
+    linear_packed_params = []
+
+    import torch
+    # conv and linear requires different unpacking function
+    # extract all conv and linear parameters separately to distinguish them
+    for name, m in script_module.named_modules():
+        if isinstance(m, torch.jit.RecursiveScriptModule):
+            if "Conv" in m.original_name:
+                conv_packed_params.append((name, m.state_dict()))
+            elif m.original_name == "LinearPackedParams":
+                linear_packed_params.append((name, m.state_dict()))
+
+    pairs = [(torch.ops.quantized.conv2d_unpack, conv_packed_params),
+             (torch.ops.quantized.linear_unpack, linear_packed_params)]
+
+    quant_params = {}
+    param_name = "_packed_params"
+    for unpack_func, params in pairs:
+        for name, state_dict in params:
+            assert len(state_dict) == 1
+            assert param_name in state_dict
+            key = name + "." + param_name
+            packed_param = state_dict[param_name]
+            quant_params[key] = _unpack_quant_params(key, packed_param,
+                                                     unpack_func)
+
+    return quant_params
+
+
+def add_quant_params_to_outputs(outputs, output_index_map,
+                                packed_param_map, quant_params):
+    """
+    Add quant params to outputs so that they can be referenced by other
+    ops later. Weights are quantized here.
+    """
+    for node_name, packed_param_name in packed_param_map.items():
+        qparam = quant_params[packed_param_name]
+        output_index_map[node_name] = len(outputs)
+        qweight = relay.qnn.op.quantize(qparam.weight_var, qparam.scale,
+                                        qparam.zero_point, out_dtype="int8",
+                                        axis=0)
+        param_tup = (qweight, qparam.scale, qparam.zero_point, qparam.bias_var)
+        outputs.append(param_tup)
+
+
+def _get_quant_param_for_input(input_value):
+    """
+    We want to know the input scale and zp of this input_value, since
+    input quant params are not explicitly passed around in torch (they
+    are embeded in a QTensor data structure, not visible statically).
+    We know that it is quantized using output scale and zp
+    of some previous quantized op. The purpose of this function
+    is to find that pair of parameters.
+    """
+    # Indices for output scale and zp
+    # For example, in quantized::conv2d(%input, %1, %2, %3, %4, %5, %6, %7),
+    # 6th and 7th arg are output scale and zp respectively.
+    output_quant_param_indices = {
+        "aten::quantize_per_tensor": (1, 2),
+        "quantized::conv2d": (6, 7),
+        "quantized::conv2d_relu": (6, 7),
+        "quantized::linear": (2, 3),
+        "quantized::linear_relu": (2, 3),
+        "quantized::add_relu": (2, 3),
+        "quantized::add": (2, 3),
+        "quantized::mul_relu": (2, 3),
+        "quantized::mul": (2, 3),
+        "quantized::cat": (2, 3),
+        "quantized::mul_scalar": (2, 3),
+        "quantized::add_scalar": (2, 3)
+    }
+
+    def dfs(current_node):
+        # trace back to find the producer of this input value
+        current_op = current_node.kind()
+        if current_op in output_quant_param_indices:
+            indices = output_quant_param_indices[current_op]
+            scale = current_node.inputsAt(indices[0])
+            zp = current_node.inputsAt(indices[1])
+            return scale, zp
+
+        # Trace back eariler nodes, dfs order
+        # Assume quantized tensor comes earlier in the args
+        for arg in current_node.inputs():
+            return dfs(arg.node())
+
+        # shouldn't happen
+        assert False, "No producer for %s" % (str(current_node))
+
+    return dfs(input_value.node())
+
+
+def _get_add_scalar_output_quant_param(input_scale, input_zero_point,
+                                       scalar):
+    """
+    Determine the output scale and zp of quantized::add_scalar op
+    This is used for mobilenet v3
+    Refer to aten/src/ATen/native/quantized/cpu/qadd.cpp
+    The names of variables are the same as torch impl
+    """
+    q_min = 0
+    q_max = 255
+    s = input_scale
+    z = input_zero_point
+    c = scalar
+    c_q = round(c / s)
+
+    if q_min > z - c_q:
+        s_prime = (float(q_max) - (z - c_q)) / (float(q_max) - q_min) * s
+        z_prime = q_min
+    elif q_max < z - c_q:
+        s_prime = (float(z - c_q) - q_min) / (float(q_max) - q_min) * s
+        z_prime = q_max
+    else:
+        s_prime = s
+        z_prime = z - c_q
+
+    return s_prime, z_prime
+
+
+def _get_mul_scalar_output_quant_param(input_scale, input_zero_point,
+                                       scalar):
+    """
+    Determine the output scale and zp of quantized::mul_scalar op
+    This is used for mobilenet v3
+    Refer to aten/src/ATen/native/quantized/cpu/qmul.cpp
+    The names of variables are the same as torch impl
+    """
+    q_min = 0
+    q_max = 255
+    self_scale = input_scale
+    self_zero_point = input_zero_point
+    other_val = scalar
+
+    if other_val > 0.0:
+        s_prime = other_val * self_scale
+        z_prime = self_zero_point
+    elif other_val == 0.0:
+        s_prime = 1.0
+        z_prime = 0
+    else:
+        s_prime = abs(other_val) * self_scale
+        z_prime = q_max - (self_zero_point - q_min)
+
+    return s_prime, z_prime
+
+
+def _add_output_quant_params_to_scalar_op(node, graph,
+                                          input_scale, input_zero_point,
+                                          scalar):
+    """
+    The output scale and zp of {add,mul}_scalar op are not explicit in the IR
+    They are required for _get_quant_param_for_input above to work correctly
+    So calculate these params using the same way torch does, and make new
+    constant nodes in the input IR. Also add these params to the inputs of
+    scalar op.
+
+    For example,
+       %6 : float = prim::Constant[value=3.]()
+       %input : QUInt8(1, 3, 224, 224) = quantized::add_scalar(%x.1, %6)
+    becomes
+       %6 : float = prim::Constant[value=3.]()
+       %7 : float = prim::Constant[value=0.015686161816120148]()
+       %8 : int = prim::Constant[value=0]()
+       %input : UInt8(1, 3, 224, 224) = quantized::add_scalar(%x.1, %6, %7, %8)
+
+    %7 and %8 are newly created output scale and zp constant nodes
+    """
+    import torch
+    operator = node.kind()
+
+    if operator == "quantized::mul_scalar":
+        out_scale, out_zero_point = \
+          _get_mul_scalar_output_quant_param(input_scale, input_zero_point,
+                                             scalar)
+    elif operator == "quantized::add_scalar":
+        out_scale, out_zero_point = \
+          _get_add_scalar_output_quant_param(input_scale, input_zero_point,
+                                             scalar)
+    else:
+        raise NotImplementedError("unsupported scalar op: %s" % operator)
+
+    # create new constant nodes and add them to graph
+    out_scale_node = graph.create("prim::Constant")
+    out_zero_point_node = graph.create("prim::Constant")
+    out_scale_node.insertBefore(node)
+    out_zero_point_node.insertBefore(node)
+    out_scale_node.f_("value", out_scale)
+    out_zero_point_node.i_("value", out_zero_point)
+    out_scale_node.output().setType(torch._C.FloatType.get())
+    out_zero_point_node.output().setType(torch._C.IntType.get())
+    node.addInput(out_scale_node.output())
+    node.addInput(out_zero_point_node.output())
+
+
+def add_input_quant_params_to_op_inputs(graph):
+    """
+    In Torch, input quant params are not explicitly passed around
+    Instead, they are stored in QTensor data structure, and retrieved
+    at runtime by each quantized ops.
+    However, they need to be known statically for QNN translation.
+    To workaround and simplify the translation of inputs, we manually add
+    input quant params to inputs of Torch quantized operators listed below.
+    See _quantized_conv2d() below for example of why this is helpful.
+
+    For example,
+      %input : QUInt8(1, 512, 7, 7) = quantized::add(%x.8, %x.9, %434, %435)
+    becomes
+      %395 : float = prim::Constant[value=0.036212071776390076]()
+      %396 : int = prim::Constant[value=0]()
+      %430 : float = prim::Constant[value=0.16080744564533234]()
+      %431 : int = prim::Constant[value=42]()
+      %input : QUInt8(1, 512, 7, 7) = quantized::add(%x.8, %x.9, %434, %435,
+                                                     %430, %431, %395, %396)
+
+    %434, %435 are output scale and zp of quantized::add op
+    %430, %431, %395, %396 are two pairs of input (scale, zp) for two tensors
+    added by this function
+    """
+    # How many quantized tensors each op takes as inputs?
+    # A pair of (scale, zp) for each input quantized tensor will be added
+    # to the input nodes
+    num_quantized_inputs = {"quantized::conv2d": 1,
+                            "quantized::conv2d_relu": 1,
+                            "quantized::linear": 1,
+                            "quantized::linear_relu": 1,
+                            "quantized::add_relu": 2,
+                            "quantized::add": 2,
+                            "quantized::mul_relu": 2,
+                            "quantized::mul": 2,
+                            "aten::dequantize": 1,
+                            "aten::mean": 1,
+                            "aten::upsample_bilinear2d": 1,
+                            "aten::relu_": 1,
+                            "aten::relu": 1,
+                            "quantized::add_scalar": 1,
+                            "quantized::mul_scalar": 1,
+                            'quantized::relu6': 1}
+
+    need_input_quant_param = set(num_quantized_inputs.keys())
+    need_input_quant_param.add("quantized::cat")
+
+    for node in graph.nodes():
+        operator = node.kind()
+        if operator not in need_input_quant_param:
+            continue
+
+        input_scales = []
+        input_zero_points = []
+
+        if operator == "quantized::cat":
+            # the number of inputs to concat is not constant
+            # so handle it separately
+            inputs = node.inputsAt(0).node().inputs()
+            for inp in inputs:
+                scale, zp = _get_quant_param_for_input(inp)
+                input_scales.append(scale)
+                input_zero_points.append(zp)
+        else:
+            for i in range(num_quantized_inputs[operator]):
+                scale, zp = _get_quant_param_for_input(node.inputsAt(i))
+                input_scales.append(scale)
+                input_zero_points.append(zp)
+
+        if operator in ["quantized::add_scalar", "quantized::mul_scalar"]:
+            scalar = node.inputsAt(1).node().f("value")
+            inp_scale = input_scales[0].node().f("value")
+            inp_zero_point = input_zero_points[0].node().i("value")
+
+            # see the comments in this function above
+            _add_output_quant_params_to_scalar_op(node, graph,
+                                                  inp_scale, inp_zero_point,
+                                                  scalar)
+
+        for scale, zp in zip(input_scales, input_zero_points):
+            node.addInput(scale)
+            node.addInput(zp)
+
+
+def add_quant_params(params, quant_params):
+    """ Add quant parameters to TVM param map """
+    for qparam in quant_params.values():
+        params[qparam.weight_var.name_hint] = tvm.nd.array(qparam.weight)
+        if qparam.bias is not None:
+            params[qparam.bias_var.name_hint] = tvm.nd.array(qparam.bias)
+
+
+def quantized_adaptive_avg_2d(data, func_fp32):
+    # this follows tflite impl
+    inp = _op.cast(data, dtype="int32")
+    out = func_fp32(inp)
+    return _op.cast(out, "uint8")
+
+
+def quantized_mean(data, input_scale, input_zero_point, func_fp32):
+    # refer to aten/src/ATen/native/quantized/cpu/qreduction.cpp
+    dequantized = relay.qnn.op.dequantize(data, input_scale, input_zero_point)
+    out = func_fp32(dequantized)
+    return relay.qnn.op.quantize(out, input_scale, input_zero_point,
+                                 out_dtype="uint8", axis=1)
+
+
+def quantized_upsample(data, input_scale, input_zero_point, func_fp32):
+    # currently piggy backs to fp32, it gets identical output as torch
+    data = relay.qnn.op.dequantize(data, input_scale, input_zero_point)
+    out = func_fp32(data)
+    return relay.qnn.op.quantize(out, input_scale, input_zero_point,
+                                 out_dtype="uint8", axis=1)
+
+
+def quantized_relu(data, input_zero_point):
+    # refer to aten/src/ATen/native/quantized/cpu/qrelu.cpp
+    zp = _op.cast(input_zero_point, dtype="uint8")
+    return _op.tensor.maximum(data, zp)
+
+
+def _quantize_per_tensor():
+    def _impl(inputs, _):
+        return relay.qnn.op.quantize(inputs[0], _expr.const(inputs[1]),
+                                     _expr.const(inputs[2]), out_dtype="uint8",
+                                     axis=1)
+    return _impl
+
+
+def _dequantize():
+    def _impl(inputs, _):
+        assert len(inputs) == 3, "Input quant params not found in op inputs"
+        inp_scale = _expr.const(inputs[1])
+        inp_zero_point = _expr.const(inputs[2])
+        return relay.qnn.op.dequantize(inputs[0], inp_scale, inp_zero_point)
+    return _impl
+
+
+def _get_numpy(relay_const_scalar):
+    return relay_const_scalar.data.asnumpy()
+
+
+def _get_scalar(relay_const_scalar):
+    return np.asscalar(_get_numpy(relay_const_scalar))
+
+
+def _do_bias_and_requantize(output, bias, input_scale, weight_scale,
+                            output_scale, output_zero_point,
+                            with_relu):
+    """ Output processing for conv and linear """
+    # this is a vector for per channel case
+    requant_input_scale = _expr.const(_get_numpy(input_scale) *
+                                      _get_numpy(weight_scale))
+    # Torch does bias add and requanize scale in fp32
+    # refer to third_party/fbgemm/include/fbgemm/OutputProcessing-inl.h
+    # Instead, we do bias add in int32 and use qnn requantize, which needs
+    # integer input.
+    # We observed no loss in accuracy in doing this way, and it is better
+    # for tvm because bias quantization can be done at compile time
+    # Instead, the torch way requires rounding of activation at runtime
+
+    if bias is not None:
+        qbias = relay.qnn.op.quantize(bias, requant_input_scale,
+                                      _expr.const(0, "int32"),
+                                      out_dtype="int32", axis=0)
+        requantize_input = _op.nn.bias_add(output, qbias)
+    else:
+        requantize_input = output
+
+    requantized = relay.qnn.op.requantize(requantize_input,
+                                          requant_input_scale,
+                                          relay.const(0, 'int32'),
+                                          output_scale, output_zero_point,
+                                          out_dtype="int32", axis=1)
+    clip_min = 0
+    if with_relu:
+        clip_min = _get_scalar(output_zero_point)
+
+    clip = _op.tensor.clip(requantized, clip_min, 255.)
+    return _op.cast(clip, dtype="uint8")
+
+
+def _quantized_conv2d(with_relu=False):
+    def _impl(inputs, _):
+        # refer to src/ATen/native/quantized/cpu/qconv.cpp
+        # inputs[0]: input tensor
+        # inputs[1]: (weight, scale, zero_point, bias)
+        # inputs[2-5]: stride, padding, dilation, groups
+        # inputs[6]: output_scale
+        # inputs[7]: output_zero_point
+        # inputs[8]: input_scale (added manually by frontend)
+        # inputs[9]: input_zero_point (added manually by frontend)
+        weight = inputs[1][0]
+        weight_scale = inputs[1][1]
+        weight_zero_point = inputs[1][2]
+
+        output_scale = _expr.const(inputs[6])
+        output_zero_point = _expr.const(inputs[7])
+
+        assert len(inputs) == 10, "Input quant params not found in op inputs"
+        # These are manually added by add_input_quant_params_to_op_inputs above
+        # In torch, they are retrieved from QTensor data structure at runtime
+        input_scale = _expr.const(inputs[8])
+        input_zero_point = _expr.const(inputs[9])
+
+        strides, padding, dilation = inputs[2], inputs[3], inputs[4]
+        strides = infer_shape(inputs[2])
+        padding = infer_shape(inputs[3])
+        dilation = infer_shape(inputs[4])
+        groups = inputs[5]
+
+        weight_shape = infer_shape(weight)
+        kernel_size = (weight_shape[2], weight_shape[3])
+        out_channels = weight_shape[0]
+
+        if padding[0] != 0 or padding[1] != 0:
+            pad_val = _get_scalar(input_zero_point)
+            inp = _op.nn.pad(inputs[0], pad_width=((0, 0),
+                                                   (0, 0),
+                                                   (padding[0], padding[0]),
+                                                   (padding[1], padding[1])),
+                             pad_value=float(pad_val))
+        else:
+            inp = inputs[0]
+
+        # padding is (0, 0) because we did explicit pad op with
+        # pad value being zero point above
+        conv_out = relay.qnn.op.conv2d(inp, weight,
+                                       input_zero_point, weight_zero_point,
+                                       input_scale, weight_scale,
+                                       kernel_size=kernel_size,
+                                       dilation=dilation, strides=strides,
+                                       padding=(0, 0), groups=groups,
+                                       channels=out_channels)
+        bias_var = inputs[1][3]
+
+        return _do_bias_and_requantize(conv_out, bias_var, input_scale,
+                                       weight_scale, output_scale,
+                                       output_zero_point, with_relu)
+
+    return _impl
+
+
+def _linear(with_relu=False):
+    # similar to conv
+    def _impl(inputs, _):
+        weight = inputs[1][0]
+        weight_scale = inputs[1][1]
+        weight_zero_point = inputs[1][2]
+        output_scale = _expr.const(inputs[2])
+        output_zero_point = _expr.const(inputs[3])
+        assert len(inputs) == 6, "Input quant params not found in op inputs"
+        # Manually added by add_input_quant_params_to_op_inputs above
+        input_scale = _expr.const(inputs[4])
+        input_zero_point = _expr.const(inputs[5])
+
+        weight_shape = infer_shape(weight)
+        dense = relay.qnn.op.dense(inputs[0], weight,
+                                   input_zero_point, weight_zero_point,
+                                   input_scale, weight_scale,
+                                   units=weight_shape[0])
+        bias_var = inputs[1][3]
+
+        return _do_bias_and_requantize(dense, bias_var, input_scale,
+                                       weight_scale, output_scale,
+                                       output_zero_point, with_relu)
+
+    return _impl
+
+
+def _binop(relay_op, with_relu=False):
+    # refer to aten/src/ATen/native/quantized/cpu/{qadd, qmul}.cpp
+    # they piggy backs to fp32 math by dequantize -> fp32 math -> quantize
+    def _impl(inputs, _):
+        output_scale = _expr.const(inputs[2])
+        output_zero_point = _expr.const(inputs[3])
+        assert len(inputs) == 8, "Input quant params not found in op inputs"
+        # Manually added by add_input_quant_params_to_op_inputs above
+        input_scale_lhs = _expr.const(inputs[4])
+        input_zero_point_lhs = _expr.const(inputs[5])
+        input_scale_rhs = _expr.const(inputs[6])
+        input_zero_point_rhs = _expr.const(inputs[7])
+        lhs = inputs[0]
+        rhs = inputs[1]
+
+        if isinstance(lhs, _expr.Call) and lhs.op.name == 'qnn.quantize':
+            lhs = lhs.args[0]
+        else:
+            lhs = relay.qnn.op.dequantize(lhs,
+                                          input_scale_lhs,
+                                          input_zero_point_lhs)
+
+        if isinstance(rhs, _expr.Call) and rhs.op.name == 'qnn.quantize':
+            rhs = rhs.args[0]
+        else:
+            rhs = relay.qnn.op.dequantize(rhs,
+                                          input_scale_rhs,
+                                          input_zero_point_rhs)
+        fp32_out = relay_op(lhs, rhs)
+
+        if with_relu:
+            fp32_out = _op.nn.relu(fp32_out)
+
+        return relay.qnn.op.quantize(fp32_out,
+                                     output_scale,
+                                     output_zero_point,
+                                     axis=-1,
+                                     out_dtype="uint8")
+    return _impl
+
+
+def _cat():
+    # refer to aten/src/ATen/native/quantized/cpu/qconcat.cpp
+    # for concat they also piggy backs to fp32(!)
+    # dequantize -> fp32 math -> quantize
+    # we can also use QNN concat op. we observed no change in accuracy
+    def _impl(inputs, _):
+        axis = inputs[1]
+        output_scale = _expr.const(inputs[2])
+        output_zero_point = _expr.const(inputs[3])
+        num_inputs = (len(inputs) - 4) // 2
+        dequantized = []
+
+        for i in range(0, num_inputs):
+            inp_scale = _expr.const(inputs[4+i*2])
+            inp_zp = _expr.const(inputs[4+i*2+1])
+            dequantized.append(relay.qnn.op.dequantize(inputs[0][i],
+                                                       inp_scale, inp_zp))
+
+        concat = _op.tensor.concatenate(dequantized, axis=axis)
+        return relay.qnn.op.quantize(concat, output_scale, output_zero_point,
+                                     axis=1, out_dtype="uint8")
+
+    return _impl
+
+
+def _add_scalar():
+    # this is used for mobilenet v3
+    def _impl(inputs, _):
+        # refer to aten/src/ATen/native/quantized/cpu/qadd.cpp
+        assert len(inputs) == 6, "Input quant params not found in op inputs"
+        s = inputs[4]
+        z = inputs[5]
+        c = inputs[1]
+        c_q = round(c / s)
+        q_min = 0
+        q_max = 255
+
+        # math for calculating output scale and zp are already done
+        # during _add_output_quant_params_to_scalar_op above
+        out_scale = _expr.const(inputs[2])
+        out_zp = _expr.const(inputs[3])
+
+        if q_min > z - c_q or q_max < z - c_q:
+            dequant = relay.qnn.op.dequantize(inputs[0],
+                                              _expr.const(s), _expr.const(z))
+            dequantized_add = _op.tensor.add(dequant, _expr.const(c_q * s))
+            return relay.qnn.op.quantize(dequantized_add, out_scale, out_zp,
+                                         axis=1, out_dtype="uint8")
+        # only scale change
+        return inputs[0]
+
+    return _impl
+
+
+def quantize_scalar(data, scale, zero_point):
+    # used to quantize 6., in mobilenet v3
+    transformed = zero_point + data / scale
+    return max(0, min(round(transformed), 255))
+
+
+def _relu6():
+    # refer to src/ATen/native/quantized/cpu/qrelu.cpp
+    def _impl(inputs, _):
+        assert len(inputs) == 4, "Input quant params not found in op inputs"
+        input_scale = inputs[2]
+        input_zero_point = inputs[3]
+        six = quantize_scalar(6., input_scale, input_zero_point)
+        return _op.tensor.clip(inputs[0], input_zero_point, six)
+    return _impl
+
+
+def _mul_scalar():
+    # this is used for mobilenet v3
+    def _impl(inputs, _):
+        # refer to aten/src/ATen/native/quantized/cpu/qmul.cpp
+        # math for calculating output scale and zp are already done
+        # during _add_output_quant_params_to_scalar_op above
+        assert len(inputs) == 6, "Input quant params not found in op inputs"
+        other_val = inputs[1]  # scalar
+
+        if other_val > 0.0:
+            # only scale change
+            return inputs[0]
+        if other_val == 0.0:
+            shape = infer_shape(inputs[0])
+            return _op.full(_expr.const(0), shape, dtype="uint8")
+
+        # negative scale case
+        q_min = 0
+        q_max = 255
+        bias = _expr.const(q_max + q_min, dtype="int8")
+        int8 = bias - _op.cast(inputs[0], "int8")
+        return _op.cast(int8, "uint8")
+
+    return _impl
+
+
+convert_map = {
+    'aten::quantize_per_tensor': _quantize_per_tensor(),
+    'quantized::conv2d_relu': _quantized_conv2d(True),
+    'aten::dequantize': _dequantize(),
+    'quantized::conv2d': _quantized_conv2d(),
+    'quantized::add_relu': _binop(relay.add, True),
+    'quantized::add': _binop(relay.add),
+    'quantized::mul_relu': _binop(relay.multiply, True),
+    'quantized::mul': _binop(relay.multiply),
+    'quantized::linear': _linear(),
+    'quantized::linear_relu': _linear(True),
+    'quantized::cat': _cat(),
+    'quantized::add_scalar': _add_scalar(),
+    'quantized::mul_scalar': _mul_scalar(),
+    'quantized::relu6': _relu6()
+}
diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py
new file mode 100644 (file)
index 0000000..e3a876c
--- /dev/null
@@ -0,0 +1,455 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Tests on quantized torch model conversion """
+import os
+
+from PIL import Image
+
+import numpy as np
+
+import torch
+from torch import nn
+from torch.quantization import QuantStub, DeQuantStub
+from torch.quantization import fuse_modules, QuantWrapper
+
+import tvm
+from tvm import relay
+from tvm.relay.frontend.pytorch import get_graph_input_names
+from tvm.contrib.download import download_testdata
+
+
+def torch_version_check():
+    from packaging import version
+    return version.parse(torch.__version__) > version.parse("1.4.0")
+
+
+def get_tvm_runtime(script_module, input_name, ishape):
+
+    input_shapes = {input_name: ishape}
+    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
+
+    with relay.build_config(opt_level=3):
+        # test on only cpu for now, torch cannot run quant models on cuda
+        # also not to make CI too slow
+        json, lib, params = relay.build(mod, target="llvm", params=params)
+
+    runtime = tvm.contrib.graph_runtime.create(json, lib, tvm.cpu(0))
+    runtime.set_input(**params)
+    return runtime
+
+
+def get_qconfig(per_channel):
+    from torch.quantization.observer import MovingAverageMinMaxObserver
+    from torch.quantization.observer import default_weight_observer
+
+    if per_channel:
+        return torch.quantization.get_default_qconfig('fbgemm')
+    else:
+        act = MovingAverageMinMaxObserver.with_args(reduce_range=False)
+        return torch.quantization.QConfig(activation=act,
+                                          weight=default_weight_observer)
+
+
+def quantize_model(model, inp, per_channel=False, dummy=True):
+    model.fuse_model()
+    model.qconfig = get_qconfig(per_channel)
+    torch.quantization.prepare(model, inplace=True)
+    model(inp)
+    torch.quantization.convert(model, inplace=True)
+
+
+class ConvBn(nn.Module):
+    def __init__(self, with_relu=False):
+        super().__init__()
+        layers = [nn.Conv2d(3, 32, 3, bias=True),
+                  nn.BatchNorm2d(32)]
+        if with_relu:
+            layers.append(nn.ReLU())
+        self.conv = nn.Sequential(*layers)
+        self.quant_wrap = QuantWrapper(self.conv)
+        self.with_relu = with_relu
+
+    def forward(self, x):
+        return self.quant_wrap(x)
+
+    def fuse_model(self):
+        indices = ["0", "1"]
+        if self.with_relu:
+            indices.append("2")
+        fuse_modules(self.conv, indices, inplace=True)
+
+
+class Linear(nn.Module):
+    def __init__(self, with_relu=False):
+        super().__init__()
+        layers = [nn.Linear(16, 32)]
+        if with_relu:
+            layers.append(nn.ReLU())
+        self.fc = nn.Sequential(*layers)
+        self.quant_wrap = QuantWrapper(self.fc)
+        self.with_relu = with_relu
+
+    def forward(self, x):
+        return self.quant_wrap(x)
+
+    def fuse_model(self):
+        if self.with_relu:
+            fuse_modules(self.fc, ["0", "1"], inplace=True)
+
+
+class ReLU(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.relu = QuantWrapper(nn.ReLU())
+
+    def forward(self, x):
+        return self.relu(x)
+
+    def fuse_model(self):
+        pass
+
+
+# Mobilenet V3 related modules
+class Hsigmoid(nn.Module):
+    def __init__(self, inplace=True, add_stub=False):
+        super().__init__()
+        self.float_op = nn.quantized.FloatFunctional()
+        self.relu6 = nn.ReLU6(inplace=inplace)
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+        self.add_stub = add_stub
+
+    def forward(self, x):
+        if self.add_stub:
+            x = self.quant(x)
+        relu6 = self.relu6(self.float_op.add_scalar(x, 3.))
+        mul = self.float_op.mul_scalar(relu6, 1/6.)
+        if self.add_stub:
+            mul = self.dequant(mul)
+        return mul
+
+    def fuse_model(self):
+        pass
+
+
+class Hswish(nn.Module):
+    def __init__(self, inplace=True, add_stub=False):
+        super(Hswish, self).__init__()
+        self.float_op = nn.quantized.FloatFunctional()
+        self.hsigmoid = Hsigmoid(inplace, add_stub=False)
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+        self.add_stub = add_stub
+
+    def forward(self, x):
+        if self.add_stub:
+            x = self.quant(x)
+        mul = self.float_op.mul(x, self.hsigmoid(x))
+        if self.add_stub:
+            mul = self.dequant(mul)
+        return mul
+
+    def fuse_model(self):
+        pass
+
+
+class SqueezeExcite(nn.Module):
+    def __init__(self, channel, reduction=4, add_stub=False):
+        super(SqueezeExcite, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(channel, channel // reduction, bias=False),
+            nn.ReLU(inplace=True),
+            nn.Linear(channel // reduction, channel, bias=False),
+            Hsigmoid(add_stub=False)
+        )
+        self.fmul = nn.quantized.FloatFunctional()
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+        self.add_stub = add_stub
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        if self.add_stub:
+            x = self.quant(x)
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        out = self.fmul.mul(x, y.expand_as(x))
+        if self.add_stub:
+            return self.dequant(out)
+        else:
+            return out
+
+    def fuse_model(self):
+        fuse_modules(self.fc, ["0", "1"], inplace=True)
+
+
+# test on quantized::mul_scalar with negative scale
+class MulScalarNegative(nn.Module):
+    def __init__(self, ):
+        super().__init__()
+        self.float_op = nn.quantized.FloatFunctional()
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+
+    def forward(self, x):
+        x = self.quant(x)
+        mul = self.float_op.mul_scalar(x, -0.3)
+        return self.dequant(mul)
+
+    def fuse_model(self):
+        pass
+
+
+class UpsamplingBilinear(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.relu = QuantWrapper(nn.ReLU())
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+
+    def forward(self, x):
+        x = self.quant(x)
+        upsample = nn.functional.interpolate(x, scale_factor=2,
+                                             mode='bilinear',
+                                             align_corners=True)
+        return self.dequant(upsample)
+
+    def fuse_model(self):
+        pass
+
+
+def test_quantized_modules():
+    imagenet_ishape = (1, 3, 224, 224)
+
+    qmodules = [
+       ("relu", imagenet_ishape, ReLU(), False),
+       ("upsample bilinear", (1, 3, 64, 64), UpsamplingBilinear(), False),
+    ]
+
+    for per_channel in [False, True]:
+        if per_channel:
+            postfix = ", per_channel"
+        else:
+            postfix = ""
+
+        qmodules += [
+           ("conv_bn" + postfix, imagenet_ishape, ConvBn(), per_channel),
+           ("conv_bn_relu" + postfix, imagenet_ishape, ConvBn(with_relu=True), per_channel),
+           ("linear" + postfix, (16, 16), Linear(), per_channel),
+           ("linear_relu" + postfix, (16, 16), Linear(with_relu=True), per_channel)
+        ]
+
+    if torch_version_check():
+        qmodules += [
+           ("hsigmoid", imagenet_ishape, Hsigmoid(add_stub=True), False),
+           ("hswish", imagenet_ishape, Hswish(add_stub=True), False),
+           ("semodule", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), False),
+           ("semodule, per_channel", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), True),
+           ("mul_scalar negative", imagenet_ishape, MulScalarNegative(), False)
+        ]
+    else:
+        print("Skipping tests that require torch > 1.4")
+
+    for (module_name, ishape, raw_module, per_channel) in qmodules:
+        raw_module.eval()
+        inp = torch.rand(ishape)
+
+        quantize_model(raw_module, inp, per_channel=per_channel, dummy=True)
+        script_module = torch.jit.trace(raw_module, inp).eval()
+
+        with torch.no_grad():
+            pt_result = script_module(inp.clone()).numpy()
+
+        input_name = get_graph_input_names(script_module)[0]
+
+        runtime = get_tvm_runtime(script_module, input_name, ishape)
+        runtime.set_input(input_name, inp.numpy().copy())
+        runtime.run()
+        tvm_result = runtime.get_output(0).asnumpy()
+
+        max_abs_diff = np.max(np.abs(tvm_result - pt_result))
+        mean_abs_diff = np.mean(np.abs(tvm_result - pt_result))
+        num_identical = np.sum(tvm_result == pt_result)
+        match_ratio = num_identical / float(np.prod(tvm_result.shape))
+
+        print(module_name, max_abs_diff, mean_abs_diff, match_ratio)
+
+        # sample outputs
+        """
+        relu 0.0039215684 2.6052087e-08 0.9999933567176871
+        upsample bilinear 0.0 0.0 1.0
+        conv_bn 0.22062653 0.011478779 0.6909348115006899
+        conv_bn_relu 0.3700896 0.010921672 0.7489366477964451
+        linear 0.15987062 0.009231662 0.794921875
+        linear_relu 0.14180502 0.0053220326 0.8828125
+        conv_bn, per_channel 0.01654929 2.9486866e-06 0.9998218235127019
+        conv_bn_relu, per_channel 0.009089053 1.4926576e-06 0.9998357732732732
+        linear, per_channel 0.0 0.0 1.0
+        linear_relu, per_channel 0.0 0.0 1.0
+        hsigmoid 0.002614379 0.00020525524 0.9214896896258503
+        hswish 0.0052286386 0.00063522335 0.7587359162414966
+        semodule, per_channel 0.0039885044 0.0008620687 0.7838592529296875
+        mul_scalar negative 0.0011764616 7.815566e-09 0.9999933567176871
+        """
+
+        # we cannot make any guarantee on how close the raw output is to torch
+        # tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-1, atol=1e-1)
+
+
+def test_quantized_imagenet():
+    def get_transform():
+        import torchvision.transforms as transforms
+        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                         std=[0.229, 0.224, 0.225])
+        return transforms.Compose([
+                transforms.Resize(256),
+                transforms.CenterCrop(224),
+                transforms.ToTensor(),
+                normalize,
+            ])
+
+    def get_real_image(im_height, im_width):
+        repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
+        img_name = 'elephant-299.jpg'
+        image_url = os.path.join(repo_base, img_name)
+        img_path = download_testdata(image_url, img_name, module='data')
+        return Image.open(img_path).resize((im_height, im_width))
+
+    def get_imagenet_input():
+        im = get_real_image(224, 224)
+        preprocess = get_transform()
+        pt_tensor = preprocess(im)
+        return np.expand_dims(pt_tensor.numpy(), 0)
+
+    from torchvision.models.quantization import resnet as qresnet
+    from torchvision.models.quantization import mobilenet as qmobilenet
+    from torchvision.models.quantization import inception as qinception
+    from torchvision.models.quantization import googlenet as qgooglenet
+
+    qmodels = []
+
+    for per_channel in [False, True]:
+        qmodels += [
+            ("resnet18", qresnet.resnet18(pretrained=True), per_channel),
+            ("mobilenet_v2", qmobilenet.mobilenet_v2(pretrained=True), per_channel),
+            ("inception_v3", qinception.inception_v3(pretrained=True), per_channel),
+            ("googlenet", qgooglenet(pretrained=True), per_channel),
+        ]
+
+    results = []
+
+    for (model_name, raw_model, per_channel) in qmodels:
+        raw_model.eval()
+
+        if per_channel:
+            model_name += ", per channel quantization"
+        else:
+            model_name += ", per tensor quantization"
+
+        inp = get_imagenet_input()
+        pt_inp = torch.from_numpy(inp)
+
+        quantize_model(raw_model, pt_inp, per_channel=per_channel, dummy=False)
+        script_module = torch.jit.trace(raw_model, pt_inp).eval()
+
+        with torch.no_grad():
+            pt_result = script_module(pt_inp).numpy()
+
+        input_name = get_graph_input_names(script_module)[0]
+        runtime = get_tvm_runtime(script_module, input_name, (1, 3, 224, 224))
+        runtime.set_input(input_name, inp)
+        runtime.run()
+
+        tvm_result = runtime.get_output(0).asnumpy()
+
+        results.append((model_name, pt_result[0], tvm_result[0]))
+
+    for (model_name, pt_result, tvm_result) in results:
+        max_abs_diff = np.max(np.abs(tvm_result - pt_result))
+        mean_abs_diff = np.mean(np.abs(tvm_result - pt_result))
+        num_identical = np.sum(tvm_result == pt_result)
+        pt_top3_labels = np.argsort(pt_result)[::-1][:3]
+        tvm_top3_labels = np.argsort(pt_result)[::-1][:3]
+
+        print("\nModel name: %s" % model_name)
+        print("PyTorch top3 label:", pt_top3_labels)
+        print("TVM top3 label:", tvm_top3_labels)
+        print("max abs diff:", max_abs_diff)
+        print("mean abs_diff:", mean_abs_diff)
+        print("%d in 1000 raw outputs identical." % num_identical)
+
+        assert set(pt_top3_labels) == set(tvm_top3_labels)
+
+        # sample outputs
+        """
+        Model name: resnet18, per tensor quantization
+        PyTorch top3 label: [386 101 385]
+        TVM top3 label: [386 101 385]
+        max abs diff: 0.65681696
+        mean abs_diff: 0.14055882
+        236 in 1000 raw outputs identical.
+
+        Model name: mobilenet_v2, per tensor quantization
+        PyTorch top3 label: [101 386 385]
+        TVM top3 label: [101 386 385]
+        max abs diff: 2.1262953
+        mean abs_diff: 0.41025686
+        101 in 1000 raw outputs identical.
+
+        Model name: inception_v3, per tensor quantization
+        PyTorch top3 label: [101 386 385]
+        TVM top3 label: [101 386 385]
+        max abs diff: 0.9994669
+        mean abs_diff: 0.098697364
+        272 in 1000 raw outputs identical.
+
+        Model name: googlenet, per tensor quantization
+        PyTorch top3 label: [101 386 385]
+        TVM top3 label: [101 386 385]
+        max abs diff: 0.28248847
+        mean abs_diff: 0.0634469
+        274 in 1000 raw outputs identical.
+
+        Model name: resnet18, per channel quantization
+        PyTorch top3 label: [101 386 385]
+        TVM top3 label: [101 386 385]
+        max abs diff: 0.65908074
+        mean abs_diff: 0.1274223
+        469 in 1000 raw outputs identical.
+
+        Model name: mobilenet_v2, per channel quantization
+        PyTorch top3 label: [101 386 385]
+        TVM top3 label: [101 386 385]
+        max abs diff: 0.71120834
+        mean abs_diff: 0.15883648
+        423 in 1000 raw outputs identical.
+
+        Model name: inception_v3, per channel quantization
+        PyTorch top3 label: [386 101 385]
+        TVM top3 label: [386 101 385]
+        max abs diff: 1.3372154
+        mean abs_diff: 0.1225224
+        401 in 1000 raw outputs identical.
+
+        Model name: googlenet, per channel quantization
+        PyTorch top3 label: [101 386 385]
+        TVM top3 label: [101 386 385]
+        max abs diff: 0.34015465
+        mean abs_diff: 0.054197952
+        558 in 1000 raw outputs identical.
+        """
index e60c1fd..eed47ea 100644 (file)
@@ -854,3 +854,9 @@ if __name__ == "__main__":
     test_custom_conversion_map()
 
     test_segmentaton_models()
+
+    # Quantization test
+    from qnn_test import test_quantized_imagenet, test_quantized_modules
+
+    test_quantized_modules()
+    test_quantized_imagenet()