Revert "[Torch, QNN] Add support for quantized models via QNN (#4977)" (#5013)
authorAnimesh Jain <anijain@umich.edu>
Mon, 9 Mar 2020 20:14:58 +0000 (13:14 -0700)
committerGitHub <noreply@github.com>
Mon, 9 Mar 2020 20:14:58 +0000 (13:14 -0700)
This reverts commit fc7f0783940c362bf48cd46817956381196201e2.

python/tvm/relay/frontend/pytorch.py
python/tvm/relay/frontend/qnn_torch.py [deleted file]
tests/python/frontend/pytorch/qnn_test.py [deleted file]
tests/python/frontend/pytorch/test_forward.py

index ff37f82..e284e48 100644 (file)
@@ -19,7 +19,6 @@
 # pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension
 """PT: PyTorch frontend."""
 import itertools
-import logging
 
 import numpy as np
 
@@ -33,8 +32,6 @@ from .common import get_relay_op
 from .common import infer_shape as _infer_shape
 from .common import infer_value as _infer_value
 
-from . import qnn_torch
-
 __all__ = ["from_pytorch"]
 
 # operator implementation
@@ -149,10 +146,6 @@ def _zeros():
 def _relu():
     def _impl(inputs, input_types):
         data = inputs[0]
-        if input_types[0] == "quint8":
-            assert len(inputs) == 3, "Input quant param not found in op inputs"
-            input_zero_point = _expr.const(inputs[2], dtype="int32")
-            return qnn_torch.quantized_relu(data, input_zero_point)
         return _op.nn.relu(data)
     return _impl
 
@@ -161,14 +154,9 @@ def _adaptive_avg_2d():
         data = inputs[0]
         output_size = _infer_shape(inputs[1])
 
-        def func(x):
-            return _op.nn.adaptive_avg_pool2d(x, output_size=output_size)
-
-        if input_types[0] == "quint8":
-            return qnn_torch.quantized_adaptive_avg_2d(data, func)
-
-        return func(data)
-
+        return _op.nn.adaptive_avg_pool2d(
+            data,
+            output_size=output_size)
     return _impl
 
 def _adaptive_max_2d():
@@ -518,18 +506,7 @@ def _mean():
         else:
             exclude = False
 
-        def func(x):
-            return _op.mean(x, axis, keepdims, exclude)
-
-        if input_types[0] == "quint8":
-            assert len(inputs) == 6, "Input quant param not found in op inputs"
-            input_scale = _expr.const(inputs[4])
-            input_zero_point = _expr.const(inputs[5])
-            return qnn_torch.quantized_mean(data, input_scale,
-                                            input_zero_point, func)
-
-        return func(data)
-
+        return _op.mean(data, axis, keepdims, exclude)
     return _impl
 
 def _chunk():
@@ -691,40 +668,10 @@ def _upsample(method):
         else:
             coord_trans = "half_pixel"
 
-        def func(x):
-            return _op.image.resize(x, out_size, "NCHW", method, coord_trans)
-
-        if input_types[0] == "quint8":
-            import torch
-            from packaging import version
-
-            # Torch version > 1.4 changed upsampling API
-            if version.parse(torch.__version__) > version.parse("1.4.0"):
-                num_inputs = 7
-            else:
-                num_inputs = 5
-
-            assert len(inputs) == num_inputs, "Input quant param not found in op inputs"
-
-            input_scale = _expr.const(inputs[-2])
-            input_zero_point = _expr.const(inputs[-1])
-            return qnn_torch.quantized_upsample(data, input_scale,
-                                                input_zero_point, func)
-        return func(data)
+        return _op.image.resize(data, out_size, "NCHW", method, coord_trans)
 
     return _impl
 
-
-def _expand_as():
-    def _impl(inputs, input_types):
-        # TODO: maybe fix this
-        # This assumes expand_as can be removed because TVM has broadcast op
-        msg = "aten::expand_as(...) found, assume it is part of broadcast op"
-        logging.warning(msg)
-        return inputs[0]
-    return _impl
-
-
 # Helper functions for operator implementation
 
 def _convert_data_type(input_type):
@@ -845,7 +792,6 @@ _convert_map = {
     "aten::detach"                          : _identity(),
     "aten::upsample_bilinear2d"             : _upsample("bilinear"),
     "aten::upsample_nearest2d"              : _upsample("nearest_neighbor"),
-    "aten::expand_as"                       : _expand_as()
 }
 
 
@@ -896,7 +842,6 @@ def _report_missing_conversion(op_names):
                  "prim::ListConstruct", "prim::ListUnpack",
                  "prim::TupleConstruct", "prim::TupleUnpack"]
     known_ops += list(_convert_map.keys())
-    known_ops += list(qnn_torch.convert_map.keys())
 
     missing = [op_name for op_name in op_names
                if op_name not in known_ops]
@@ -1063,7 +1008,6 @@ def parse_params(graph, state_dict):
     getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True)
     params = {}
     param_tensors = {}
-    packed_param_map = {}
     seen = set()
 
     for node in getattr_nodes:
@@ -1076,18 +1020,14 @@ def parse_params(graph, state_dict):
             full_attr = _getattr_full_name(getattrs)
             full_attr_node_name = _get_output_name(getattrs[-1])
 
-            if full_attr.endswith("_packed_params"):  # for quantized models
-                err_msg = "parameter %s not found in state dict" % full_attr
-                assert full_attr in state_dict, err_msg
-                packed_param_map[full_attr_node_name] = full_attr
-            elif full_attr in state_dict:
+            if full_attr in state_dict:
                 torch_tensor = state_dict[full_attr]
                 tensor, var = _get_tensor_and_var(torch_tensor,
                                                   full_attr_node_name)
                 param_tensors[full_attr_node_name] = tensor
                 params[full_attr_node_name] = var
 
-    return params, param_tensors, packed_param_map
+    return params, param_tensors
 
 
 def parse_operators(operators, outputs, output_index_map, ret_name):
@@ -1168,26 +1108,16 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
 
     params = script_module.state_dict()
     input_vars = parse_inputs(graph.inputs(), input_shapes)
-    param_vars, tensors, packed_param_map = parse_params(graph, params)
-    tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
+    param_vars, tensors = parse_params(graph, params)
 
     input_vars.update(param_vars)
     outputs = list(input_vars.values())
     output_index_map = dict(zip(input_vars.keys(), range(len(outputs))))
     ret_name = _get_input_names(graph.return_node())[0]
 
-    # For quantized models
-    if "aten::quantize_per_tensor" in op_names:
-        weight_quant_params = qnn_torch.get_weight_quant_params(script_module)
-        qnn_torch.add_input_quant_params_to_op_inputs(graph)
-        qnn_torch.add_quant_params_to_outputs(outputs, output_index_map,
-                                              packed_param_map,
-                                              weight_quant_params)
-        qnn_torch.add_quant_params(tvm_params, weight_quant_params)
-        _convert_map.update(qnn_torch.convert_map)
-
     body = parse_operators(_get_operator_nodes(graph.nodes()), outputs,
                            output_index_map, ret_name)
     func = tvm.relay.Function(_analysis.free_vars(body), body)
+    tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
 
     return _module.IRModule.from_expr(func), tvm_params
diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py
deleted file mode 100644 (file)
index 0704e34..0000000
+++ /dev/null
@@ -1,692 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=invalid-name, import-outside-toplevel
-""" Functions to convert quantized torch models to QNN """
-
-import numpy as np
-
-import tvm
-from tvm import relay
-from tvm.relay import expr as _expr
-from tvm.relay import op as _op
-from tvm.relay.frontend.common import infer_shape
-
-
-class QNNParam:
-    """ A placeholder for weight quantization parameters """
-
-    def __init__(self, weight, bias, scale, zero_point, param_key):
-        param_prefix = param_key[:-len("._packed_params")]
-        self.weight_var = _expr.var(param_prefix + "_weight",
-                                    shape=weight.shape)
-        self.weight = weight
-
-        if bias is not None:
-            self.bias_var = _expr.var(param_prefix + "_bias",
-                                      shape=bias.shape)
-            self.bias = bias.detach().numpy()
-        else:
-            self.bias_var = None
-            self.bias = None
-
-        self.scale = _expr.const(scale)
-        self.zero_point = _expr.const(zero_point, dtype="int32")
-
-
-def _unpack_quant_params(param_name, packed_params, unpack_func):
-    # Torch stores quantized params in a custom packed format,
-    # need to unpack and retrieve them as numpy arrays
-    qweight, bias = unpack_func(packed_params)
-    weight_np = qweight.dequantize().numpy()
-
-    import torch
-    if qweight.qscheme() == torch.per_tensor_affine:
-        param = QNNParam(weight_np, bias, qweight.q_scale(),
-                         int(qweight.q_zero_point()), param_name)
-    else:
-        scales = qweight.q_per_channel_scales().numpy()
-        zero_points = qweight.q_per_channel_zero_points().numpy()
-        # This is an assumption posed by QNN
-        msg = "The values of zero points should be all zero for per channel"
-        assert np.all(zero_points == 0), msg
-        param = QNNParam(weight_np, bias, scales, 0, param_name)
-
-    return param
-
-
-def get_weight_quant_params(script_module):
-    """ Retrive and unpack weight parameters from quantized modules """
-    conv_packed_params = []
-    linear_packed_params = []
-
-    import torch
-    # conv and linear requires different unpacking function
-    # extract all conv and linear parameters separately to distinguish them
-    for name, m in script_module.named_modules():
-        if isinstance(m, torch.jit.RecursiveScriptModule):
-            if "Conv" in m.original_name:
-                conv_packed_params.append((name, m.state_dict()))
-            elif m.original_name == "LinearPackedParams":
-                linear_packed_params.append((name, m.state_dict()))
-
-    pairs = [(torch.ops.quantized.conv2d_unpack, conv_packed_params),
-             (torch.ops.quantized.linear_unpack, linear_packed_params)]
-
-    quant_params = {}
-    param_name = "_packed_params"
-    for unpack_func, params in pairs:
-        for name, state_dict in params:
-            assert len(state_dict) == 1
-            assert param_name in state_dict
-            key = name + "." + param_name
-            packed_param = state_dict[param_name]
-            quant_params[key] = _unpack_quant_params(key, packed_param,
-                                                     unpack_func)
-
-    return quant_params
-
-
-def add_quant_params_to_outputs(outputs, output_index_map,
-                                packed_param_map, quant_params):
-    """
-    Add quant params to outputs so that they can be referenced by other
-    ops later. Weights are quantized here.
-    """
-    for node_name, packed_param_name in packed_param_map.items():
-        qparam = quant_params[packed_param_name]
-        output_index_map[node_name] = len(outputs)
-        qweight = relay.qnn.op.quantize(qparam.weight_var, qparam.scale,
-                                        qparam.zero_point, out_dtype="int8",
-                                        axis=0)
-        param_tup = (qweight, qparam.scale, qparam.zero_point, qparam.bias_var)
-        outputs.append(param_tup)
-
-
-def _get_quant_param_for_input(input_value):
-    """
-    We want to know the input scale and zp of this input_value, since
-    input quant params are not explicitly passed around in torch (they
-    are embeded in a QTensor data structure, not visible statically).
-    We know that it is quantized using output scale and zp
-    of some previous quantized op. The purpose of this function
-    is to find that pair of parameters.
-    """
-    # Indices for output scale and zp
-    # For example, in quantized::conv2d(%input, %1, %2, %3, %4, %5, %6, %7),
-    # 6th and 7th arg are output scale and zp respectively.
-    output_quant_param_indices = {
-        "aten::quantize_per_tensor": (1, 2),
-        "quantized::conv2d": (6, 7),
-        "quantized::conv2d_relu": (6, 7),
-        "quantized::linear": (2, 3),
-        "quantized::linear_relu": (2, 3),
-        "quantized::add_relu": (2, 3),
-        "quantized::add": (2, 3),
-        "quantized::mul_relu": (2, 3),
-        "quantized::mul": (2, 3),
-        "quantized::cat": (2, 3),
-        "quantized::mul_scalar": (2, 3),
-        "quantized::add_scalar": (2, 3)
-    }
-
-    def dfs(current_node):
-        # trace back to find the producer of this input value
-        current_op = current_node.kind()
-        if current_op in output_quant_param_indices:
-            indices = output_quant_param_indices[current_op]
-            scale = current_node.inputsAt(indices[0])
-            zp = current_node.inputsAt(indices[1])
-            return scale, zp
-
-        # Trace back eariler nodes, dfs order
-        # Assume quantized tensor comes earlier in the args
-        for arg in current_node.inputs():
-            return dfs(arg.node())
-
-        # shouldn't happen
-        assert False, "No producer for %s" % (str(current_node))
-
-    return dfs(input_value.node())
-
-
-def _get_add_scalar_output_quant_param(input_scale, input_zero_point,
-                                       scalar):
-    """
-    Determine the output scale and zp of quantized::add_scalar op
-    This is used for mobilenet v3
-    Refer to aten/src/ATen/native/quantized/cpu/qadd.cpp
-    The names of variables are the same as torch impl
-    """
-    q_min = 0
-    q_max = 255
-    s = input_scale
-    z = input_zero_point
-    c = scalar
-    c_q = round(c / s)
-
-    if q_min > z - c_q:
-        s_prime = (float(q_max) - (z - c_q)) / (float(q_max) - q_min) * s
-        z_prime = q_min
-    elif q_max < z - c_q:
-        s_prime = (float(z - c_q) - q_min) / (float(q_max) - q_min) * s
-        z_prime = q_max
-    else:
-        s_prime = s
-        z_prime = z - c_q
-
-    return s_prime, z_prime
-
-
-def _get_mul_scalar_output_quant_param(input_scale, input_zero_point,
-                                       scalar):
-    """
-    Determine the output scale and zp of quantized::mul_scalar op
-    This is used for mobilenet v3
-    Refer to aten/src/ATen/native/quantized/cpu/qmul.cpp
-    The names of variables are the same as torch impl
-    """
-    q_min = 0
-    q_max = 255
-    self_scale = input_scale
-    self_zero_point = input_zero_point
-    other_val = scalar
-
-    if other_val > 0.0:
-        s_prime = other_val * self_scale
-        z_prime = self_zero_point
-    elif other_val == 0.0:
-        s_prime = 1.0
-        z_prime = 0
-    else:
-        s_prime = abs(other_val) * self_scale
-        z_prime = q_max - (self_zero_point - q_min)
-
-    return s_prime, z_prime
-
-
-def _add_output_quant_params_to_scalar_op(node, graph,
-                                          input_scale, input_zero_point,
-                                          scalar):
-    """
-    The output scale and zp of {add,mul}_scalar op are not explicit in the IR
-    They are required for _get_quant_param_for_input above to work correctly
-    So calculate these params using the same way torch does, and make new
-    constant nodes in the input IR. Also add these params to the inputs of
-    scalar op.
-
-    For example,
-       %6 : float = prim::Constant[value=3.]()
-       %input : QUInt8(1, 3, 224, 224) = quantized::add_scalar(%x.1, %6)
-    becomes
-       %6 : float = prim::Constant[value=3.]()
-       %7 : float = prim::Constant[value=0.015686161816120148]()
-       %8 : int = prim::Constant[value=0]()
-       %input : UInt8(1, 3, 224, 224) = quantized::add_scalar(%x.1, %6, %7, %8)
-
-    %7 and %8 are newly created output scale and zp constant nodes
-    """
-    import torch
-    operator = node.kind()
-
-    if operator == "quantized::mul_scalar":
-        out_scale, out_zero_point = \
-          _get_mul_scalar_output_quant_param(input_scale, input_zero_point,
-                                             scalar)
-    elif operator == "quantized::add_scalar":
-        out_scale, out_zero_point = \
-          _get_add_scalar_output_quant_param(input_scale, input_zero_point,
-                                             scalar)
-    else:
-        raise NotImplementedError("unsupported scalar op: %s" % operator)
-
-    # create new constant nodes and add them to graph
-    out_scale_node = graph.create("prim::Constant")
-    out_zero_point_node = graph.create("prim::Constant")
-    out_scale_node.insertBefore(node)
-    out_zero_point_node.insertBefore(node)
-    out_scale_node.f_("value", out_scale)
-    out_zero_point_node.i_("value", out_zero_point)
-    out_scale_node.output().setType(torch._C.FloatType.get())
-    out_zero_point_node.output().setType(torch._C.IntType.get())
-    node.addInput(out_scale_node.output())
-    node.addInput(out_zero_point_node.output())
-
-
-def add_input_quant_params_to_op_inputs(graph):
-    """
-    In Torch, input quant params are not explicitly passed around
-    Instead, they are stored in QTensor data structure, and retrieved
-    at runtime by each quantized ops.
-    However, they need to be known statically for QNN translation.
-    To workaround and simplify the translation of inputs, we manually add
-    input quant params to inputs of Torch quantized operators listed below.
-    See _quantized_conv2d() below for example of why this is helpful.
-
-    For example,
-      %input : QUInt8(1, 512, 7, 7) = quantized::add(%x.8, %x.9, %434, %435)
-    becomes
-      %395 : float = prim::Constant[value=0.036212071776390076]()
-      %396 : int = prim::Constant[value=0]()
-      %430 : float = prim::Constant[value=0.16080744564533234]()
-      %431 : int = prim::Constant[value=42]()
-      %input : QUInt8(1, 512, 7, 7) = quantized::add(%x.8, %x.9, %434, %435,
-                                                     %430, %431, %395, %396)
-
-    %434, %435 are output scale and zp of quantized::add op
-    %430, %431, %395, %396 are two pairs of input (scale, zp) for two tensors
-    added by this function
-    """
-    # How many quantized tensors each op takes as inputs?
-    # A pair of (scale, zp) for each input quantized tensor will be added
-    # to the input nodes
-    num_quantized_inputs = {"quantized::conv2d": 1,
-                            "quantized::conv2d_relu": 1,
-                            "quantized::linear": 1,
-                            "quantized::linear_relu": 1,
-                            "quantized::add_relu": 2,
-                            "quantized::add": 2,
-                            "quantized::mul_relu": 2,
-                            "quantized::mul": 2,
-                            "aten::dequantize": 1,
-                            "aten::mean": 1,
-                            "aten::upsample_bilinear2d": 1,
-                            "aten::relu_": 1,
-                            "aten::relu": 1,
-                            "quantized::add_scalar": 1,
-                            "quantized::mul_scalar": 1,
-                            'quantized::relu6': 1}
-
-    need_input_quant_param = set(num_quantized_inputs.keys())
-    need_input_quant_param.add("quantized::cat")
-
-    for node in graph.nodes():
-        operator = node.kind()
-        if operator not in need_input_quant_param:
-            continue
-
-        input_scales = []
-        input_zero_points = []
-
-        if operator == "quantized::cat":
-            # the number of inputs to concat is not constant
-            # so handle it separately
-            inputs = node.inputsAt(0).node().inputs()
-            for inp in inputs:
-                scale, zp = _get_quant_param_for_input(inp)
-                input_scales.append(scale)
-                input_zero_points.append(zp)
-        else:
-            for i in range(num_quantized_inputs[operator]):
-                scale, zp = _get_quant_param_for_input(node.inputsAt(i))
-                input_scales.append(scale)
-                input_zero_points.append(zp)
-
-        if operator in ["quantized::add_scalar", "quantized::mul_scalar"]:
-            scalar = node.inputsAt(1).node().f("value")
-            inp_scale = input_scales[0].node().f("value")
-            inp_zero_point = input_zero_points[0].node().i("value")
-
-            # see the comments in this function above
-            _add_output_quant_params_to_scalar_op(node, graph,
-                                                  inp_scale, inp_zero_point,
-                                                  scalar)
-
-        for scale, zp in zip(input_scales, input_zero_points):
-            node.addInput(scale)
-            node.addInput(zp)
-
-
-def add_quant_params(params, quant_params):
-    """ Add quant parameters to TVM param map """
-    for qparam in quant_params.values():
-        params[qparam.weight_var.name_hint] = tvm.nd.array(qparam.weight)
-        if qparam.bias is not None:
-            params[qparam.bias_var.name_hint] = tvm.nd.array(qparam.bias)
-
-
-def quantized_adaptive_avg_2d(data, func_fp32):
-    # this follows tflite impl
-    inp = _op.cast(data, dtype="int32")
-    out = func_fp32(inp)
-    return _op.cast(out, "uint8")
-
-
-def quantized_mean(data, input_scale, input_zero_point, func_fp32):
-    # refer to aten/src/ATen/native/quantized/cpu/qreduction.cpp
-    dequantized = relay.qnn.op.dequantize(data, input_scale, input_zero_point)
-    out = func_fp32(dequantized)
-    return relay.qnn.op.quantize(out, input_scale, input_zero_point,
-                                 out_dtype="uint8", axis=1)
-
-
-def quantized_upsample(data, input_scale, input_zero_point, func_fp32):
-    # currently piggy backs to fp32, it gets identical output as torch
-    data = relay.qnn.op.dequantize(data, input_scale, input_zero_point)
-    out = func_fp32(data)
-    return relay.qnn.op.quantize(out, input_scale, input_zero_point,
-                                 out_dtype="uint8", axis=1)
-
-
-def quantized_relu(data, input_zero_point):
-    # refer to aten/src/ATen/native/quantized/cpu/qrelu.cpp
-    zp = _op.cast(input_zero_point, dtype="uint8")
-    return _op.tensor.maximum(data, zp)
-
-
-def _quantize_per_tensor():
-    def _impl(inputs, _):
-        return relay.qnn.op.quantize(inputs[0], _expr.const(inputs[1]),
-                                     _expr.const(inputs[2]), out_dtype="uint8",
-                                     axis=1)
-    return _impl
-
-
-def _dequantize():
-    def _impl(inputs, _):
-        assert len(inputs) == 3, "Input quant params not found in op inputs"
-        inp_scale = _expr.const(inputs[1])
-        inp_zero_point = _expr.const(inputs[2])
-        return relay.qnn.op.dequantize(inputs[0], inp_scale, inp_zero_point)
-    return _impl
-
-
-def _get_numpy(relay_const_scalar):
-    return relay_const_scalar.data.asnumpy()
-
-
-def _get_scalar(relay_const_scalar):
-    return np.asscalar(_get_numpy(relay_const_scalar))
-
-
-def _do_bias_and_requantize(output, bias, input_scale, weight_scale,
-                            output_scale, output_zero_point,
-                            with_relu):
-    """ Output processing for conv and linear """
-    # this is a vector for per channel case
-    requant_input_scale = _expr.const(_get_numpy(input_scale) *
-                                      _get_numpy(weight_scale))
-    # Torch does bias add and requanize scale in fp32
-    # refer to third_party/fbgemm/include/fbgemm/OutputProcessing-inl.h
-    # Instead, we do bias add in int32 and use qnn requantize, which needs
-    # integer input.
-    # We observed no loss in accuracy in doing this way, and it is better
-    # for tvm because bias quantization can be done at compile time
-    # Instead, the torch way requires rounding of activation at runtime
-
-    if bias is not None:
-        qbias = relay.qnn.op.quantize(bias, requant_input_scale,
-                                      _expr.const(0, "int32"),
-                                      out_dtype="int32", axis=0)
-        requantize_input = _op.nn.bias_add(output, qbias)
-    else:
-        requantize_input = output
-
-    requantized = relay.qnn.op.requantize(requantize_input,
-                                          requant_input_scale,
-                                          relay.const(0, 'int32'),
-                                          output_scale, output_zero_point,
-                                          out_dtype="int32", axis=1)
-    clip_min = 0
-    if with_relu:
-        clip_min = _get_scalar(output_zero_point)
-
-    clip = _op.tensor.clip(requantized, clip_min, 255.)
-    return _op.cast(clip, dtype="uint8")
-
-
-def _quantized_conv2d(with_relu=False):
-    def _impl(inputs, _):
-        # refer to src/ATen/native/quantized/cpu/qconv.cpp
-        # inputs[0]: input tensor
-        # inputs[1]: (weight, scale, zero_point, bias)
-        # inputs[2-5]: stride, padding, dilation, groups
-        # inputs[6]: output_scale
-        # inputs[7]: output_zero_point
-        # inputs[8]: input_scale (added manually by frontend)
-        # inputs[9]: input_zero_point (added manually by frontend)
-        weight = inputs[1][0]
-        weight_scale = inputs[1][1]
-        weight_zero_point = inputs[1][2]
-
-        output_scale = _expr.const(inputs[6])
-        output_zero_point = _expr.const(inputs[7])
-
-        assert len(inputs) == 10, "Input quant params not found in op inputs"
-        # These are manually added by add_input_quant_params_to_op_inputs above
-        # In torch, they are retrieved from QTensor data structure at runtime
-        input_scale = _expr.const(inputs[8])
-        input_zero_point = _expr.const(inputs[9])
-
-        strides, padding, dilation = inputs[2], inputs[3], inputs[4]
-        strides = infer_shape(inputs[2])
-        padding = infer_shape(inputs[3])
-        dilation = infer_shape(inputs[4])
-        groups = inputs[5]
-
-        weight_shape = infer_shape(weight)
-        kernel_size = (weight_shape[2], weight_shape[3])
-        out_channels = weight_shape[0]
-
-        if padding[0] != 0 or padding[1] != 0:
-            pad_val = _get_scalar(input_zero_point)
-            inp = _op.nn.pad(inputs[0], pad_width=((0, 0),
-                                                   (0, 0),
-                                                   (padding[0], padding[0]),
-                                                   (padding[1], padding[1])),
-                             pad_value=float(pad_val))
-        else:
-            inp = inputs[0]
-
-        # padding is (0, 0) because we did explicit pad op with
-        # pad value being zero point above
-        conv_out = relay.qnn.op.conv2d(inp, weight,
-                                       input_zero_point, weight_zero_point,
-                                       input_scale, weight_scale,
-                                       kernel_size=kernel_size,
-                                       dilation=dilation, strides=strides,
-                                       padding=(0, 0), groups=groups,
-                                       channels=out_channels)
-        bias_var = inputs[1][3]
-
-        return _do_bias_and_requantize(conv_out, bias_var, input_scale,
-                                       weight_scale, output_scale,
-                                       output_zero_point, with_relu)
-
-    return _impl
-
-
-def _linear(with_relu=False):
-    # similar to conv
-    def _impl(inputs, _):
-        weight = inputs[1][0]
-        weight_scale = inputs[1][1]
-        weight_zero_point = inputs[1][2]
-        output_scale = _expr.const(inputs[2])
-        output_zero_point = _expr.const(inputs[3])
-        assert len(inputs) == 6, "Input quant params not found in op inputs"
-        # Manually added by add_input_quant_params_to_op_inputs above
-        input_scale = _expr.const(inputs[4])
-        input_zero_point = _expr.const(inputs[5])
-
-        weight_shape = infer_shape(weight)
-        dense = relay.qnn.op.dense(inputs[0], weight,
-                                   input_zero_point, weight_zero_point,
-                                   input_scale, weight_scale,
-                                   units=weight_shape[0])
-        bias_var = inputs[1][3]
-
-        return _do_bias_and_requantize(dense, bias_var, input_scale,
-                                       weight_scale, output_scale,
-                                       output_zero_point, with_relu)
-
-    return _impl
-
-
-def _binop(relay_op, with_relu=False):
-    # refer to aten/src/ATen/native/quantized/cpu/{qadd, qmul}.cpp
-    # they piggy backs to fp32 math by dequantize -> fp32 math -> quantize
-    def _impl(inputs, _):
-        output_scale = _expr.const(inputs[2])
-        output_zero_point = _expr.const(inputs[3])
-        assert len(inputs) == 8, "Input quant params not found in op inputs"
-        # Manually added by add_input_quant_params_to_op_inputs above
-        input_scale_lhs = _expr.const(inputs[4])
-        input_zero_point_lhs = _expr.const(inputs[5])
-        input_scale_rhs = _expr.const(inputs[6])
-        input_zero_point_rhs = _expr.const(inputs[7])
-        lhs = inputs[0]
-        rhs = inputs[1]
-
-        if isinstance(lhs, _expr.Call) and lhs.op.name == 'qnn.quantize':
-            lhs = lhs.args[0]
-        else:
-            lhs = relay.qnn.op.dequantize(lhs,
-                                          input_scale_lhs,
-                                          input_zero_point_lhs)
-
-        if isinstance(rhs, _expr.Call) and rhs.op.name == 'qnn.quantize':
-            rhs = rhs.args[0]
-        else:
-            rhs = relay.qnn.op.dequantize(rhs,
-                                          input_scale_rhs,
-                                          input_zero_point_rhs)
-        fp32_out = relay_op(lhs, rhs)
-
-        if with_relu:
-            fp32_out = _op.nn.relu(fp32_out)
-
-        return relay.qnn.op.quantize(fp32_out,
-                                     output_scale,
-                                     output_zero_point,
-                                     axis=-1,
-                                     out_dtype="uint8")
-    return _impl
-
-
-def _cat():
-    # refer to aten/src/ATen/native/quantized/cpu/qconcat.cpp
-    # for concat they also piggy backs to fp32(!)
-    # dequantize -> fp32 math -> quantize
-    # we can also use QNN concat op. we observed no change in accuracy
-    def _impl(inputs, _):
-        axis = inputs[1]
-        output_scale = _expr.const(inputs[2])
-        output_zero_point = _expr.const(inputs[3])
-        num_inputs = (len(inputs) - 4) // 2
-        dequantized = []
-
-        for i in range(0, num_inputs):
-            inp_scale = _expr.const(inputs[4+i*2])
-            inp_zp = _expr.const(inputs[4+i*2+1])
-            dequantized.append(relay.qnn.op.dequantize(inputs[0][i],
-                                                       inp_scale, inp_zp))
-
-        concat = _op.tensor.concatenate(dequantized, axis=axis)
-        return relay.qnn.op.quantize(concat, output_scale, output_zero_point,
-                                     axis=1, out_dtype="uint8")
-
-    return _impl
-
-
-def _add_scalar():
-    # this is used for mobilenet v3
-    def _impl(inputs, _):
-        # refer to aten/src/ATen/native/quantized/cpu/qadd.cpp
-        assert len(inputs) == 6, "Input quant params not found in op inputs"
-        s = inputs[4]
-        z = inputs[5]
-        c = inputs[1]
-        c_q = round(c / s)
-        q_min = 0
-        q_max = 255
-
-        # math for calculating output scale and zp are already done
-        # during _add_output_quant_params_to_scalar_op above
-        out_scale = _expr.const(inputs[2])
-        out_zp = _expr.const(inputs[3])
-
-        if q_min > z - c_q or q_max < z - c_q:
-            dequant = relay.qnn.op.dequantize(inputs[0],
-                                              _expr.const(s), _expr.const(z))
-            dequantized_add = _op.tensor.add(dequant, _expr.const(c_q * s))
-            return relay.qnn.op.quantize(dequantized_add, out_scale, out_zp,
-                                         axis=1, out_dtype="uint8")
-        # only scale change
-        return inputs[0]
-
-    return _impl
-
-
-def quantize_scalar(data, scale, zero_point):
-    # used to quantize 6., in mobilenet v3
-    transformed = zero_point + data / scale
-    return max(0, min(round(transformed), 255))
-
-
-def _relu6():
-    # refer to src/ATen/native/quantized/cpu/qrelu.cpp
-    def _impl(inputs, _):
-        assert len(inputs) == 4, "Input quant params not found in op inputs"
-        input_scale = inputs[2]
-        input_zero_point = inputs[3]
-        six = quantize_scalar(6., input_scale, input_zero_point)
-        return _op.tensor.clip(inputs[0], input_zero_point, six)
-    return _impl
-
-
-def _mul_scalar():
-    # this is used for mobilenet v3
-    def _impl(inputs, _):
-        # refer to aten/src/ATen/native/quantized/cpu/qmul.cpp
-        # math for calculating output scale and zp are already done
-        # during _add_output_quant_params_to_scalar_op above
-        assert len(inputs) == 6, "Input quant params not found in op inputs"
-        other_val = inputs[1]  # scalar
-
-        if other_val > 0.0:
-            # only scale change
-            return inputs[0]
-        if other_val == 0.0:
-            shape = infer_shape(inputs[0])
-            return _op.full(_expr.const(0), shape, dtype="uint8")
-
-        # negative scale case
-        q_min = 0
-        q_max = 255
-        bias = _expr.const(q_max + q_min, dtype="int8")
-        int8 = bias - _op.cast(inputs[0], "int8")
-        return _op.cast(int8, "uint8")
-
-    return _impl
-
-
-convert_map = {
-    'aten::quantize_per_tensor': _quantize_per_tensor(),
-    'quantized::conv2d_relu': _quantized_conv2d(True),
-    'aten::dequantize': _dequantize(),
-    'quantized::conv2d': _quantized_conv2d(),
-    'quantized::add_relu': _binop(relay.add, True),
-    'quantized::add': _binop(relay.add),
-    'quantized::mul_relu': _binop(relay.multiply, True),
-    'quantized::mul': _binop(relay.multiply),
-    'quantized::linear': _linear(),
-    'quantized::linear_relu': _linear(True),
-    'quantized::cat': _cat(),
-    'quantized::add_scalar': _add_scalar(),
-    'quantized::mul_scalar': _mul_scalar(),
-    'quantized::relu6': _relu6()
-}
diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py
deleted file mode 100644 (file)
index e3a876c..0000000
+++ /dev/null
@@ -1,455 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-""" Tests on quantized torch model conversion """
-import os
-
-from PIL import Image
-
-import numpy as np
-
-import torch
-from torch import nn
-from torch.quantization import QuantStub, DeQuantStub
-from torch.quantization import fuse_modules, QuantWrapper
-
-import tvm
-from tvm import relay
-from tvm.relay.frontend.pytorch import get_graph_input_names
-from tvm.contrib.download import download_testdata
-
-
-def torch_version_check():
-    from packaging import version
-    return version.parse(torch.__version__) > version.parse("1.4.0")
-
-
-def get_tvm_runtime(script_module, input_name, ishape):
-
-    input_shapes = {input_name: ishape}
-    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
-
-    with relay.build_config(opt_level=3):
-        # test on only cpu for now, torch cannot run quant models on cuda
-        # also not to make CI too slow
-        json, lib, params = relay.build(mod, target="llvm", params=params)
-
-    runtime = tvm.contrib.graph_runtime.create(json, lib, tvm.cpu(0))
-    runtime.set_input(**params)
-    return runtime
-
-
-def get_qconfig(per_channel):
-    from torch.quantization.observer import MovingAverageMinMaxObserver
-    from torch.quantization.observer import default_weight_observer
-
-    if per_channel:
-        return torch.quantization.get_default_qconfig('fbgemm')
-    else:
-        act = MovingAverageMinMaxObserver.with_args(reduce_range=False)
-        return torch.quantization.QConfig(activation=act,
-                                          weight=default_weight_observer)
-
-
-def quantize_model(model, inp, per_channel=False, dummy=True):
-    model.fuse_model()
-    model.qconfig = get_qconfig(per_channel)
-    torch.quantization.prepare(model, inplace=True)
-    model(inp)
-    torch.quantization.convert(model, inplace=True)
-
-
-class ConvBn(nn.Module):
-    def __init__(self, with_relu=False):
-        super().__init__()
-        layers = [nn.Conv2d(3, 32, 3, bias=True),
-                  nn.BatchNorm2d(32)]
-        if with_relu:
-            layers.append(nn.ReLU())
-        self.conv = nn.Sequential(*layers)
-        self.quant_wrap = QuantWrapper(self.conv)
-        self.with_relu = with_relu
-
-    def forward(self, x):
-        return self.quant_wrap(x)
-
-    def fuse_model(self):
-        indices = ["0", "1"]
-        if self.with_relu:
-            indices.append("2")
-        fuse_modules(self.conv, indices, inplace=True)
-
-
-class Linear(nn.Module):
-    def __init__(self, with_relu=False):
-        super().__init__()
-        layers = [nn.Linear(16, 32)]
-        if with_relu:
-            layers.append(nn.ReLU())
-        self.fc = nn.Sequential(*layers)
-        self.quant_wrap = QuantWrapper(self.fc)
-        self.with_relu = with_relu
-
-    def forward(self, x):
-        return self.quant_wrap(x)
-
-    def fuse_model(self):
-        if self.with_relu:
-            fuse_modules(self.fc, ["0", "1"], inplace=True)
-
-
-class ReLU(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.relu = QuantWrapper(nn.ReLU())
-
-    def forward(self, x):
-        return self.relu(x)
-
-    def fuse_model(self):
-        pass
-
-
-# Mobilenet V3 related modules
-class Hsigmoid(nn.Module):
-    def __init__(self, inplace=True, add_stub=False):
-        super().__init__()
-        self.float_op = nn.quantized.FloatFunctional()
-        self.relu6 = nn.ReLU6(inplace=inplace)
-        self.quant = QuantStub()
-        self.dequant = DeQuantStub()
-        self.add_stub = add_stub
-
-    def forward(self, x):
-        if self.add_stub:
-            x = self.quant(x)
-        relu6 = self.relu6(self.float_op.add_scalar(x, 3.))
-        mul = self.float_op.mul_scalar(relu6, 1/6.)
-        if self.add_stub:
-            mul = self.dequant(mul)
-        return mul
-
-    def fuse_model(self):
-        pass
-
-
-class Hswish(nn.Module):
-    def __init__(self, inplace=True, add_stub=False):
-        super(Hswish, self).__init__()
-        self.float_op = nn.quantized.FloatFunctional()
-        self.hsigmoid = Hsigmoid(inplace, add_stub=False)
-        self.quant = QuantStub()
-        self.dequant = DeQuantStub()
-        self.add_stub = add_stub
-
-    def forward(self, x):
-        if self.add_stub:
-            x = self.quant(x)
-        mul = self.float_op.mul(x, self.hsigmoid(x))
-        if self.add_stub:
-            mul = self.dequant(mul)
-        return mul
-
-    def fuse_model(self):
-        pass
-
-
-class SqueezeExcite(nn.Module):
-    def __init__(self, channel, reduction=4, add_stub=False):
-        super(SqueezeExcite, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        self.fc = nn.Sequential(
-            nn.Linear(channel, channel // reduction, bias=False),
-            nn.ReLU(inplace=True),
-            nn.Linear(channel // reduction, channel, bias=False),
-            Hsigmoid(add_stub=False)
-        )
-        self.fmul = nn.quantized.FloatFunctional()
-        self.quant = QuantStub()
-        self.dequant = DeQuantStub()
-        self.add_stub = add_stub
-
-    def forward(self, x):
-        b, c, _, _ = x.size()
-        if self.add_stub:
-            x = self.quant(x)
-        y = self.avg_pool(x).view(b, c)
-        y = self.fc(y).view(b, c, 1, 1)
-        out = self.fmul.mul(x, y.expand_as(x))
-        if self.add_stub:
-            return self.dequant(out)
-        else:
-            return out
-
-    def fuse_model(self):
-        fuse_modules(self.fc, ["0", "1"], inplace=True)
-
-
-# test on quantized::mul_scalar with negative scale
-class MulScalarNegative(nn.Module):
-    def __init__(self, ):
-        super().__init__()
-        self.float_op = nn.quantized.FloatFunctional()
-        self.quant = QuantStub()
-        self.dequant = DeQuantStub()
-
-    def forward(self, x):
-        x = self.quant(x)
-        mul = self.float_op.mul_scalar(x, -0.3)
-        return self.dequant(mul)
-
-    def fuse_model(self):
-        pass
-
-
-class UpsamplingBilinear(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.relu = QuantWrapper(nn.ReLU())
-        self.quant = QuantStub()
-        self.dequant = DeQuantStub()
-
-    def forward(self, x):
-        x = self.quant(x)
-        upsample = nn.functional.interpolate(x, scale_factor=2,
-                                             mode='bilinear',
-                                             align_corners=True)
-        return self.dequant(upsample)
-
-    def fuse_model(self):
-        pass
-
-
-def test_quantized_modules():
-    imagenet_ishape = (1, 3, 224, 224)
-
-    qmodules = [
-       ("relu", imagenet_ishape, ReLU(), False),
-       ("upsample bilinear", (1, 3, 64, 64), UpsamplingBilinear(), False),
-    ]
-
-    for per_channel in [False, True]:
-        if per_channel:
-            postfix = ", per_channel"
-        else:
-            postfix = ""
-
-        qmodules += [
-           ("conv_bn" + postfix, imagenet_ishape, ConvBn(), per_channel),
-           ("conv_bn_relu" + postfix, imagenet_ishape, ConvBn(with_relu=True), per_channel),
-           ("linear" + postfix, (16, 16), Linear(), per_channel),
-           ("linear_relu" + postfix, (16, 16), Linear(with_relu=True), per_channel)
-        ]
-
-    if torch_version_check():
-        qmodules += [
-           ("hsigmoid", imagenet_ishape, Hsigmoid(add_stub=True), False),
-           ("hswish", imagenet_ishape, Hswish(add_stub=True), False),
-           ("semodule", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), False),
-           ("semodule, per_channel", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), True),
-           ("mul_scalar negative", imagenet_ishape, MulScalarNegative(), False)
-        ]
-    else:
-        print("Skipping tests that require torch > 1.4")
-
-    for (module_name, ishape, raw_module, per_channel) in qmodules:
-        raw_module.eval()
-        inp = torch.rand(ishape)
-
-        quantize_model(raw_module, inp, per_channel=per_channel, dummy=True)
-        script_module = torch.jit.trace(raw_module, inp).eval()
-
-        with torch.no_grad():
-            pt_result = script_module(inp.clone()).numpy()
-
-        input_name = get_graph_input_names(script_module)[0]
-
-        runtime = get_tvm_runtime(script_module, input_name, ishape)
-        runtime.set_input(input_name, inp.numpy().copy())
-        runtime.run()
-        tvm_result = runtime.get_output(0).asnumpy()
-
-        max_abs_diff = np.max(np.abs(tvm_result - pt_result))
-        mean_abs_diff = np.mean(np.abs(tvm_result - pt_result))
-        num_identical = np.sum(tvm_result == pt_result)
-        match_ratio = num_identical / float(np.prod(tvm_result.shape))
-
-        print(module_name, max_abs_diff, mean_abs_diff, match_ratio)
-
-        # sample outputs
-        """
-        relu 0.0039215684 2.6052087e-08 0.9999933567176871
-        upsample bilinear 0.0 0.0 1.0
-        conv_bn 0.22062653 0.011478779 0.6909348115006899
-        conv_bn_relu 0.3700896 0.010921672 0.7489366477964451
-        linear 0.15987062 0.009231662 0.794921875
-        linear_relu 0.14180502 0.0053220326 0.8828125
-        conv_bn, per_channel 0.01654929 2.9486866e-06 0.9998218235127019
-        conv_bn_relu, per_channel 0.009089053 1.4926576e-06 0.9998357732732732
-        linear, per_channel 0.0 0.0 1.0
-        linear_relu, per_channel 0.0 0.0 1.0
-        hsigmoid 0.002614379 0.00020525524 0.9214896896258503
-        hswish 0.0052286386 0.00063522335 0.7587359162414966
-        semodule, per_channel 0.0039885044 0.0008620687 0.7838592529296875
-        mul_scalar negative 0.0011764616 7.815566e-09 0.9999933567176871
-        """
-
-        # we cannot make any guarantee on how close the raw output is to torch
-        # tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-1, atol=1e-1)
-
-
-def test_quantized_imagenet():
-    def get_transform():
-        import torchvision.transforms as transforms
-        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                         std=[0.229, 0.224, 0.225])
-        return transforms.Compose([
-                transforms.Resize(256),
-                transforms.CenterCrop(224),
-                transforms.ToTensor(),
-                normalize,
-            ])
-
-    def get_real_image(im_height, im_width):
-        repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
-        img_name = 'elephant-299.jpg'
-        image_url = os.path.join(repo_base, img_name)
-        img_path = download_testdata(image_url, img_name, module='data')
-        return Image.open(img_path).resize((im_height, im_width))
-
-    def get_imagenet_input():
-        im = get_real_image(224, 224)
-        preprocess = get_transform()
-        pt_tensor = preprocess(im)
-        return np.expand_dims(pt_tensor.numpy(), 0)
-
-    from torchvision.models.quantization import resnet as qresnet
-    from torchvision.models.quantization import mobilenet as qmobilenet
-    from torchvision.models.quantization import inception as qinception
-    from torchvision.models.quantization import googlenet as qgooglenet
-
-    qmodels = []
-
-    for per_channel in [False, True]:
-        qmodels += [
-            ("resnet18", qresnet.resnet18(pretrained=True), per_channel),
-            ("mobilenet_v2", qmobilenet.mobilenet_v2(pretrained=True), per_channel),
-            ("inception_v3", qinception.inception_v3(pretrained=True), per_channel),
-            ("googlenet", qgooglenet(pretrained=True), per_channel),
-        ]
-
-    results = []
-
-    for (model_name, raw_model, per_channel) in qmodels:
-        raw_model.eval()
-
-        if per_channel:
-            model_name += ", per channel quantization"
-        else:
-            model_name += ", per tensor quantization"
-
-        inp = get_imagenet_input()
-        pt_inp = torch.from_numpy(inp)
-
-        quantize_model(raw_model, pt_inp, per_channel=per_channel, dummy=False)
-        script_module = torch.jit.trace(raw_model, pt_inp).eval()
-
-        with torch.no_grad():
-            pt_result = script_module(pt_inp).numpy()
-
-        input_name = get_graph_input_names(script_module)[0]
-        runtime = get_tvm_runtime(script_module, input_name, (1, 3, 224, 224))
-        runtime.set_input(input_name, inp)
-        runtime.run()
-
-        tvm_result = runtime.get_output(0).asnumpy()
-
-        results.append((model_name, pt_result[0], tvm_result[0]))
-
-    for (model_name, pt_result, tvm_result) in results:
-        max_abs_diff = np.max(np.abs(tvm_result - pt_result))
-        mean_abs_diff = np.mean(np.abs(tvm_result - pt_result))
-        num_identical = np.sum(tvm_result == pt_result)
-        pt_top3_labels = np.argsort(pt_result)[::-1][:3]
-        tvm_top3_labels = np.argsort(pt_result)[::-1][:3]
-
-        print("\nModel name: %s" % model_name)
-        print("PyTorch top3 label:", pt_top3_labels)
-        print("TVM top3 label:", tvm_top3_labels)
-        print("max abs diff:", max_abs_diff)
-        print("mean abs_diff:", mean_abs_diff)
-        print("%d in 1000 raw outputs identical." % num_identical)
-
-        assert set(pt_top3_labels) == set(tvm_top3_labels)
-
-        # sample outputs
-        """
-        Model name: resnet18, per tensor quantization
-        PyTorch top3 label: [386 101 385]
-        TVM top3 label: [386 101 385]
-        max abs diff: 0.65681696
-        mean abs_diff: 0.14055882
-        236 in 1000 raw outputs identical.
-
-        Model name: mobilenet_v2, per tensor quantization
-        PyTorch top3 label: [101 386 385]
-        TVM top3 label: [101 386 385]
-        max abs diff: 2.1262953
-        mean abs_diff: 0.41025686
-        101 in 1000 raw outputs identical.
-
-        Model name: inception_v3, per tensor quantization
-        PyTorch top3 label: [101 386 385]
-        TVM top3 label: [101 386 385]
-        max abs diff: 0.9994669
-        mean abs_diff: 0.098697364
-        272 in 1000 raw outputs identical.
-
-        Model name: googlenet, per tensor quantization
-        PyTorch top3 label: [101 386 385]
-        TVM top3 label: [101 386 385]
-        max abs diff: 0.28248847
-        mean abs_diff: 0.0634469
-        274 in 1000 raw outputs identical.
-
-        Model name: resnet18, per channel quantization
-        PyTorch top3 label: [101 386 385]
-        TVM top3 label: [101 386 385]
-        max abs diff: 0.65908074
-        mean abs_diff: 0.1274223
-        469 in 1000 raw outputs identical.
-
-        Model name: mobilenet_v2, per channel quantization
-        PyTorch top3 label: [101 386 385]
-        TVM top3 label: [101 386 385]
-        max abs diff: 0.71120834
-        mean abs_diff: 0.15883648
-        423 in 1000 raw outputs identical.
-
-        Model name: inception_v3, per channel quantization
-        PyTorch top3 label: [386 101 385]
-        TVM top3 label: [386 101 385]
-        max abs diff: 1.3372154
-        mean abs_diff: 0.1225224
-        401 in 1000 raw outputs identical.
-
-        Model name: googlenet, per channel quantization
-        PyTorch top3 label: [101 386 385]
-        TVM top3 label: [101 386 385]
-        max abs diff: 0.34015465
-        mean abs_diff: 0.054197952
-        558 in 1000 raw outputs identical.
-        """
index eed47ea..e60c1fd 100644 (file)
@@ -854,9 +854,3 @@ if __name__ == "__main__":
     test_custom_conversion_map()
 
     test_segmentaton_models()
-
-    # Quantization test
-    from qnn_test import test_quantized_imagenet, test_quantized_modules
-
-    test_quantized_modules()
-    test_quantized_imagenet()