[Relay][Quantize] Integrate data-aware calibration into quantization (#4295)
authorWuwei Lin <wuwei@apache.org>
Tue, 19 Nov 2019 22:54:57 +0000 (17:54 -0500)
committerGitHub <noreply@github.com>
Tue, 19 Nov 2019 22:54:57 +0000 (17:54 -0500)
* [Relay][Quantize] Integrate data-aware calibration into quantization

* Update _calibrate.py

* trigger ci

* Address comments

* address comments

python/tvm/relay/quantize/__init__.py
python/tvm/relay/quantize/_annotate.py
python/tvm/relay/quantize/_calibrate.py [new file with mode: 0644]
python/tvm/relay/quantize/kl_divergence.py
python/tvm/relay/quantize/quantize.py
src/relay/pass/quantize/calibrate.cc
src/relay/pass/quantize/quantize.cc
src/relay/pass/quantize/quantize.h

index 29b6895..09dfa8f 100644 (file)
@@ -21,4 +21,3 @@ from __future__ import absolute_import as _abs
 from .quantize import *
 from ._partition import register_partition_function
 from ._annotate import register_annotate_function
-from .kl_divergence import kl_divergence_scale
index 55f3597..9d679d2 100644 (file)
@@ -57,6 +57,7 @@ _reg.register_schedule("relay.op.annotation.simulated_quantize",
                        _reg.schedule_injective)
 _reg.register_pattern("relay.op.annotation.simulated_quantize",
                       _reg.OpPattern.ELEMWISE)
+_reg.register_schedule("annotation.cast_hint", _reg.schedule_injective)
 
 
 @register_relay_node
diff --git a/python/tvm/relay/quantize/_calibrate.py b/python/tvm/relay/quantize/_calibrate.py
new file mode 100644 (file)
index 0000000..aae5051
--- /dev/null
@@ -0,0 +1,184 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Find scales for quantization on the dataset."""
+from __future__ import absolute_import
+import logging
+import multiprocessing as mp
+import numpy as np
+import tvm
+
+from . import _quantize
+from . import quantize
+from .. import op as _op
+from .. import expr as _expr
+from .. import module as _module
+from .. import analysis as _analysis
+from .. import transform as _transform
+from .. import build_module as _build_module
+from ...contrib import graph_runtime
+from .kl_divergence import _find_scale_by_kl
+
+
+def collect_stats(mod, dataset):
+    """Given an annotated graph, create a profile graph to collect profile data from the
+    calibration dataset. This pass collects simulated_quantize op input into a tuple.
+    Simulated_quantize ops are rewritten to identity mode. The tuple is the output of the profile
+    graph.
+
+    Parameters
+    ----------
+    mod: Module
+        The simulation graph after annotation.
+
+    Returns
+    -------
+    ret: list of ndarray
+        List of output data of each layer
+    """
+
+    logging.info("collecting statistics for calibration...")
+    func = mod['main']
+    func = _quantize.CreateStatsCollector(func)
+    target = tvm.target.current_target() or 'llvm'
+    with _transform.build_config(opt_level=3):
+        graph, lib, params = _build_module.build(func, target=target)
+    outputs = []
+    runtime = graph_runtime.create(graph, lib, tvm.context(target))
+    runtime.set_input(**params)
+
+    num_outputs = runtime.get_num_outputs()
+    outputs = [[] for i in range(num_outputs)]
+
+    for batch in dataset:
+        runtime.set_input(**batch)
+        runtime.run()
+        for i in range(num_outputs):
+            output = runtime.get_output(i).asnumpy()
+            outputs[i].append(output)
+    for i in range(num_outputs):
+        outputs[i] = np.concatenate(outputs[i]).reshape(-1)
+    return outputs
+
+
+def _kl_scale(stats):
+    with mp.Pool() as pool:
+        logging.info("finding threshold with kl for calibration...")
+        scales = list(pool.map(_find_scale_by_kl, stats))
+
+    def func(sq_call):  # pylint: disable=unused-argument
+        scale = scales[func.scale_idx]
+        func.scale_idx += 1
+        return scale
+    func.scale_idx = 0
+
+    return func
+
+
+def _set_params(mod, input_scale_func, weight_scale_func):
+    quantize_op = _op.get("relay.op.annotation.simulated_quantize")
+    cfg = quantize.current_qconfig()
+    const_params = {}
+
+    def visit_func(expr):
+        '''visitor function for traverse'''
+        if isinstance(expr, _expr.Call) and expr.op == quantize_op:
+            _, ndom_scale, nclip_min, nclip_max = expr.args
+            attrs = expr.attrs
+            kind = attrs.kind
+            nbit = cfg.get_nbit_by_kind(kind)
+            valid_bit = nbit - attrs.sign
+
+            # set scale
+            if kind == quantize.QAnnotateKind.WEIGHT:
+                assert isinstance(expr.args[0], _expr.Constant)
+                scale = weight_scale_func(expr)
+            else:
+                scale = input_scale_func(expr)
+
+            def _make_const(val):
+                return _expr.const(val, 'float32')
+
+            valid_range = 2**valid_bit
+            const_params[ndom_scale] = _make_const(scale / valid_range)
+            const_params[nclip_min] = _make_const(- (valid_range - 1))
+            const_params[nclip_max] = _make_const((valid_range - 1))
+
+    func = mod['main']
+    _analysis.post_order_visit(func, visit_func)
+    func = _expr.bind(func, const_params)
+    return _module.Module.from_expr(func)
+
+
+# weight scale functions
+def _power2_scale(sq_call):  # pylint: disable=unused-argument
+    """calculate weight scale with nearest mode-2 scale"""
+    var = sq_call.args[0]
+    assert isinstance(var, _expr.Constant)
+    val = np.amax(np.abs(var.data.asnumpy()))
+    return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0
+
+
+def _max_scale(sq_call):
+    """calculate weight scale with maximum absolute value"""
+    var = sq_call.args[0]
+    assert isinstance(var, _expr.Constant)
+    val = np.amax(np.abs(var.data.asnumpy()))
+    return val
+
+
+# input scale functions
+def _global_scale(sq_call): # pylint: disable=unused-argument
+    cfg = quantize.current_qconfig()
+    return cfg.global_scale
+
+
+def calibrate(dataset=None):
+    """The calibrate procedure will try to calculate the content of
+    dom_scale, nbit, clip_min, clip_max for every `simulated_quantize`
+    operator.
+
+    Parameters
+    ---------
+    dataset: Optional[Iterable[NDArray]]
+        The calibration dataset.
+
+    Returns
+    -------
+    ret: Function
+        The module pass function.
+    """
+    def wrapped_func(mod, ctx): # pylint: disable=unused-argument
+        """make transform.module pass happy"""
+        cfg = quantize.current_qconfig()
+
+        if cfg.calibrate_mode == 'kl_divergence':
+            stats = collect_stats(mod, dataset)
+            input_scale_func = _kl_scale(stats)
+        elif cfg.calibrate_mode == 'global_scale':
+            input_scale_func = _global_scale
+        else:
+            raise ValueError("Unknown calibrate mode {}".format(cfg.calibrate_mode))
+
+        if cfg.weight_scale == 'max':
+            weight_scale_func = _max_scale
+        elif cfg.weight_scale == 'power2':
+            weight_scale_func = _power2_scale
+        else:
+            raise ValueError("Unknown weight scale mode {}".format(cfg.weight_scale))
+
+        return _set_params(mod, input_scale_func, weight_scale_func)
+    return wrapped_func
index bce45dc..2feb514 100644 (file)
@@ -45,7 +45,7 @@ def _smooth_distribution(p, eps=0.0001):
 
 
 # pylint: disable=invalid-name
-def kl_divergence_scale(arr, quantized_dtype='int8', num_bins=8001, num_quantized_bins=255):
+def _find_scale_by_kl(arr, quantized_dtype='int8', num_bins=8001, num_quantized_bins=255):
     """Given a tensor, find the optimal threshold for quantizing it.
     The reference distribution is `q`, and the candidate distribution is `p`.
     `q` is a truncated version of the original distribution.
@@ -54,6 +54,8 @@ def kl_divergence_scale(arr, quantized_dtype='int8', num_bins=8001, num_quantize
     http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
     """
     assert isinstance(arr, np.ndarray)
+    assert stats is not None, "scipy needs to be installed for \
+    utilizing kl calibration during quantization"
 
     min_val = np.min(arr)
     max_val = np.max(arr)
index 7fa8a66..1d60145 100644 (file)
 #pylint: disable=unused-argument
 """Automatic quantization toolkit."""
 from __future__ import absolute_import
-import numpy as np
-
 from . import _quantize
+from ._calibrate import calibrate
 from .. import expr as _expr
-from .. import module as _module
-from .. import analysis as _analysis
 from .. import transform as _transform
-from .. import op as _op
 from ... import make as _make
 from ..base import NodeBase, register_relay_node
 
@@ -78,7 +74,9 @@ class QConfig(NodeBase):
         "dtype_input": "int8",
         "dtype_weight": "int8",
         "dtype_activation": "int32",
+        "calibrate_mode": "global_scale",
         "global_scale": 8.0,
+        "weight_scale": "power2",
         "skip_conv_layers": [0],
         "do_simulation": False,
         "round_for_shift": True,
@@ -143,9 +141,20 @@ def qconfig(**kwargs):
     nbit_dict: dict of QAnnotateKind -> int
         Number of bit for every kind of annotate field.
 
+    calibrate_mode: str
+        The calibration mode. 'global_scale' or 'kl_divergence'.
+        global_scale: use global scale
+        kl_divergence: find scales by kl divergence on the dataset.
+
     global_scale: float
         The global scale for calibration.
 
+    weight_scale: str
+        The way to calculate scales for weights (annotated with QAnnotateKind.WEIGHT).
+        power2: Find the maximum of the absolute value of the tensor, and then round up to power
+        of two.
+        max: Find the maximum of the absolute value of the tensor
+
     skip_conv_layers: list
         Specifying which layers to be skipped. Provide a list of indices
         that indicate which conv2d layers to leave untouched. Start from 0.
@@ -249,113 +258,6 @@ def annotate():
     return _quantize.QuantizeAnnotate()
 
 
-def collect_stats(graph):
-    """Given an annotated graph, create a profile graph to collect profile data from the
-    calibration dataset. This pass collects simulated_quantize op input into a tuple.
-    Simulated_quantize ops are rewritten to identity mode. The tuple is the output of the profile
-    graph.
-
-    Parameters
-    ----------
-    graph: Function
-        The simulation graph after annotation.
-
-    Returns
-    -------
-    ret: Function
-        The profile graph which outputs a tuple of profile data.
-    """
-    return _quantize.CollectStats(graph)
-
-
-def calibrate(graph, mod=None, ctx=None, weight_scales='power2', scales=None):
-    """The calibrate procedure will try to calculate the content of
-    dom_scale, nbit, clip_min, clip_max for every `simulated_quantize`
-    operator.
-
-    Parameters
-    ---------
-    graph: Function
-        The simulation graph after annotation.
-
-    mod: tvm.relay.Module
-        The module where calibration happens on.
-
-    ctx: tvm.relay.PassContext
-        The pass context used for calibration.
-
-    weight_scales: 'power2' or 'max'.
-        The way to calculate scales for weights (annotated with QAnnotateKind.WEIGHT).
-        power2: Find the maximum of the absolute value of the tensor, and then round up to power
-        of two.
-        max: Find the maximum of the absolute value of the tensor.
-
-    scales: List[float]
-        Pre-calculated scales for input and activations. Length and the order of elements of the
-        scales list should match the output tuple of the profile graph created by collect_stats.
-
-    Returns
-    -------
-    ret: Function
-        The graph after calibration
-    """
-    def power2_scale(arr):
-        """calculate weight scale with nearest mode-2 scale"""
-        val = np.amax(np.abs(arr.asnumpy()))
-        return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0
-
-    def max_scale(arr):
-        """calculate weight scale with maximum absolute value"""
-        val = np.amax(np.abs(arr.asnumpy()))
-        return val
-
-    scale_idx = 0
-
-    cfg = current_qconfig()
-    const_params = {}
-    quantize_op = _op.get("relay.op.annotation.simulated_quantize")
-
-    def visit_func(expr):
-        """Internal visit function"""
-        nonlocal scale_idx
-        if isinstance(expr, _expr.Call) and expr.op == quantize_op:
-            _, ndom_scale, nclip_min, nclip_max = expr.args
-            attrs = expr.attrs
-            kind = attrs.kind
-            nbit = cfg.get_nbit_by_kind(kind)
-
-            valid_bit = nbit - attrs.sign
-            if kind in [QAnnotateKind.WEIGHT]:
-                if all([isinstance(arg, _expr.Constant)
-                        for arg in [ndom_scale, nclip_min, nclip_max]]):
-                    return
-                var = expr.args[0]
-                assert isinstance(var, _expr.Constant)
-                if weight_scales == 'max':
-                    scale = max_scale(var.data)
-                elif weight_scales == 'power2':
-                    scale = power2_scale(var.data)
-                else:
-                    raise ValueError('{} not supported'.format(weight_scales))
-            elif scales is not None:
-                scale = scales[scale_idx]
-                scale_idx += 1
-            else:
-                scale = cfg.global_scale
-
-            def _make_const(val):
-                return _expr.const(val, 'float32')
-
-            valid_range = 2**valid_bit
-            const_params[ndom_scale] = _make_const(scale / valid_range)
-            const_params[nclip_min] = _make_const(- (valid_range - 1))
-            const_params[nclip_max] = _make_const((valid_range - 1))
-
-    _analysis.post_order_visit(graph, visit_func)
-    ret = _expr.bind(graph, const_params)
-    return ret
-
-
 def realize():
     """The realize pass will transform the simulated quantized graph, which
     actually computes with float32, to a real low-bit integer graph. It will
@@ -391,7 +293,7 @@ def _bind_params(func, params):
     return _expr.bind(func, bind_dict)
 
 
-def prerequisite_optimize(graph, params=None):
+def prerequisite_optimize(mod, params=None):
     """ Prerequisite optimization passes for quantization. Perform
     "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
     "CanonicalizeOps" optimization before quantization. """
@@ -402,15 +304,13 @@ def prerequisite_optimize(graph, params=None):
                                       _transform.FoldConstant()])
 
     if params:
-        graph = _bind_params(graph, params)
+        mod['main'] = _bind_params(mod['main'], params)
 
-    mod = _module.Module.from_expr(graph)
-    with _transform.PassContext(opt_level=3):
-        mod = optimize(mod)
-    return mod["main"]
+    mod = optimize(mod)
+    return mod
 
 
-def quantize(graph, params=None, dataset=None):
+def quantize(mod, params=None, dataset=None):
     """ The quantization procedure. Before running the three main
     procedure of quantization, "annotate", "calibrate" and "realize"
     , we need to do "SimplifyInference", "FoldScaleAxis", "FoldConstant"
@@ -418,8 +318,8 @@ def quantize(graph, params=None, dataset=None):
 
     Parameters
     ---------
-    graph: Function
-        The original graph.
+    mod: Module
+        The original module.
 
     params : dict of str to NDArray
         Input parameters to the graph that do not change
@@ -433,11 +333,10 @@ def quantize(graph, params=None, dataset=None):
     ret: Function
         The graph after quantization
     """
-    graph = prerequisite_optimize(graph, params)
+    mod = prerequisite_optimize(mod, params)
 
-    mod = _module.Module.from_expr(graph)
-    calibrate_pass = _transform.function_pass(calibrate, opt_level=1,
-                                              name="QuantizeCalibrate")
+    calibrate_pass = _transform.module_pass(calibrate(dataset), opt_level=1,
+                                            name="QuantizeCalibrate")
     quant_passes = [partition(),
                     annotate(),
                     calibrate_pass]
@@ -452,4 +351,4 @@ def quantize(graph, params=None, dataset=None):
         with quantize_context():
             mod = quantize_seq(mod)
 
-    return mod["main"]
+    return mod
index 30b47ba..9757e58 100644 (file)
@@ -87,12 +87,12 @@ class StatsCollector : private ExprMutator {
  * \param expr The simulation graph after annotation.
  * \return The profile graph.
  */
-Expr CollectStats(const Expr& expr) {
+Expr CreateStatsCollector(const Expr& expr) {
   return StatsCollector().Collect(expr);
 }
 
-TVM_REGISTER_API("relay._quantize.CollectStats")
-.set_body_typed(CollectStats);
+TVM_REGISTER_API("relay._quantize.CreateStatsCollector")
+.set_body_typed(CreateStatsCollector);
 
 }  // namespace quantize
 }  // namespace relay
index 2793577..be24ad7 100644 (file)
@@ -123,7 +123,9 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   p->stream << "nbit_input=" << op->nbit_input << ", ";
   p->stream << "nbit_weight=" << op->nbit_weight << ", ";
   p->stream << "nbit_activation=" << op->nbit_activation << ", ";
+  p->stream << "calibrate_mode=" << op->calibrate_mode << ", ";
   p->stream << "global_scale=" << op->global_scale << ", ";
+  p->stream << "weight_scale=" << op->weight_scale << ", ";
   p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", ";
   p->stream << "do_simulation==" << op->do_simulation << ", ";
   p->stream << "round_for_shift==" << op->round_for_shift << ", ";
index 8a0282a..3af13a9 100644 (file)
@@ -70,7 +70,9 @@ class QConfigNode : public Node {
   DataType dtype_input = Int(8);
   DataType dtype_weight = Int(8);
   DataType dtype_activation = Int(32);
+  std::string calibrate_mode = "global_scale";
   double global_scale = 8.0;
+  std::string weight_scale = "power2";
   Array<Expr> skip_conv_layers = Array<Expr>(NodePtr<Node>(nullptr));
   bool do_simulation = false;
   bool round_for_shift = true;
@@ -84,7 +86,9 @@ class QConfigNode : public Node {
     v->Visit("dtype_input", &dtype_input);
     v->Visit("dtype_weight", &dtype_weight);
     v->Visit("dtype_activation", &dtype_activation);
+    v->Visit("calibrate_mode", &calibrate_mode);
     v->Visit("global_scale", &global_scale);
+    v->Visit("weight_scale", &weight_scale);
     v->Visit("skip_conv_layers", &skip_conv_layers);
     v->Visit("do_simulation", &do_simulation);
     v->Visit("round_for_shift", &round_for_shift);