from .quantize import *
from ._partition import register_partition_function
from ._annotate import register_annotate_function
-from .kl_divergence import kl_divergence_scale
_reg.schedule_injective)
_reg.register_pattern("relay.op.annotation.simulated_quantize",
_reg.OpPattern.ELEMWISE)
+_reg.register_schedule("annotation.cast_hint", _reg.schedule_injective)
@register_relay_node
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Find scales for quantization on the dataset."""
+from __future__ import absolute_import
+import logging
+import multiprocessing as mp
+import numpy as np
+import tvm
+
+from . import _quantize
+from . import quantize
+from .. import op as _op
+from .. import expr as _expr
+from .. import module as _module
+from .. import analysis as _analysis
+from .. import transform as _transform
+from .. import build_module as _build_module
+from ...contrib import graph_runtime
+from .kl_divergence import _find_scale_by_kl
+
+
+def collect_stats(mod, dataset):
+ """Given an annotated graph, create a profile graph to collect profile data from the
+ calibration dataset. This pass collects simulated_quantize op input into a tuple.
+ Simulated_quantize ops are rewritten to identity mode. The tuple is the output of the profile
+ graph.
+
+ Parameters
+ ----------
+ mod: Module
+ The simulation graph after annotation.
+
+ Returns
+ -------
+ ret: list of ndarray
+ List of output data of each layer
+ """
+
+ logging.info("collecting statistics for calibration...")
+ func = mod['main']
+ func = _quantize.CreateStatsCollector(func)
+ target = tvm.target.current_target() or 'llvm'
+ with _transform.build_config(opt_level=3):
+ graph, lib, params = _build_module.build(func, target=target)
+ outputs = []
+ runtime = graph_runtime.create(graph, lib, tvm.context(target))
+ runtime.set_input(**params)
+
+ num_outputs = runtime.get_num_outputs()
+ outputs = [[] for i in range(num_outputs)]
+
+ for batch in dataset:
+ runtime.set_input(**batch)
+ runtime.run()
+ for i in range(num_outputs):
+ output = runtime.get_output(i).asnumpy()
+ outputs[i].append(output)
+ for i in range(num_outputs):
+ outputs[i] = np.concatenate(outputs[i]).reshape(-1)
+ return outputs
+
+
+def _kl_scale(stats):
+ with mp.Pool() as pool:
+ logging.info("finding threshold with kl for calibration...")
+ scales = list(pool.map(_find_scale_by_kl, stats))
+
+ def func(sq_call): # pylint: disable=unused-argument
+ scale = scales[func.scale_idx]
+ func.scale_idx += 1
+ return scale
+ func.scale_idx = 0
+
+ return func
+
+
+def _set_params(mod, input_scale_func, weight_scale_func):
+ quantize_op = _op.get("relay.op.annotation.simulated_quantize")
+ cfg = quantize.current_qconfig()
+ const_params = {}
+
+ def visit_func(expr):
+ '''visitor function for traverse'''
+ if isinstance(expr, _expr.Call) and expr.op == quantize_op:
+ _, ndom_scale, nclip_min, nclip_max = expr.args
+ attrs = expr.attrs
+ kind = attrs.kind
+ nbit = cfg.get_nbit_by_kind(kind)
+ valid_bit = nbit - attrs.sign
+
+ # set scale
+ if kind == quantize.QAnnotateKind.WEIGHT:
+ assert isinstance(expr.args[0], _expr.Constant)
+ scale = weight_scale_func(expr)
+ else:
+ scale = input_scale_func(expr)
+
+ def _make_const(val):
+ return _expr.const(val, 'float32')
+
+ valid_range = 2**valid_bit
+ const_params[ndom_scale] = _make_const(scale / valid_range)
+ const_params[nclip_min] = _make_const(- (valid_range - 1))
+ const_params[nclip_max] = _make_const((valid_range - 1))
+
+ func = mod['main']
+ _analysis.post_order_visit(func, visit_func)
+ func = _expr.bind(func, const_params)
+ return _module.Module.from_expr(func)
+
+
+# weight scale functions
+def _power2_scale(sq_call): # pylint: disable=unused-argument
+ """calculate weight scale with nearest mode-2 scale"""
+ var = sq_call.args[0]
+ assert isinstance(var, _expr.Constant)
+ val = np.amax(np.abs(var.data.asnumpy()))
+ return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0
+
+
+def _max_scale(sq_call):
+ """calculate weight scale with maximum absolute value"""
+ var = sq_call.args[0]
+ assert isinstance(var, _expr.Constant)
+ val = np.amax(np.abs(var.data.asnumpy()))
+ return val
+
+
+# input scale functions
+def _global_scale(sq_call): # pylint: disable=unused-argument
+ cfg = quantize.current_qconfig()
+ return cfg.global_scale
+
+
+def calibrate(dataset=None):
+ """The calibrate procedure will try to calculate the content of
+ dom_scale, nbit, clip_min, clip_max for every `simulated_quantize`
+ operator.
+
+ Parameters
+ ---------
+ dataset: Optional[Iterable[NDArray]]
+ The calibration dataset.
+
+ Returns
+ -------
+ ret: Function
+ The module pass function.
+ """
+ def wrapped_func(mod, ctx): # pylint: disable=unused-argument
+ """make transform.module pass happy"""
+ cfg = quantize.current_qconfig()
+
+ if cfg.calibrate_mode == 'kl_divergence':
+ stats = collect_stats(mod, dataset)
+ input_scale_func = _kl_scale(stats)
+ elif cfg.calibrate_mode == 'global_scale':
+ input_scale_func = _global_scale
+ else:
+ raise ValueError("Unknown calibrate mode {}".format(cfg.calibrate_mode))
+
+ if cfg.weight_scale == 'max':
+ weight_scale_func = _max_scale
+ elif cfg.weight_scale == 'power2':
+ weight_scale_func = _power2_scale
+ else:
+ raise ValueError("Unknown weight scale mode {}".format(cfg.weight_scale))
+
+ return _set_params(mod, input_scale_func, weight_scale_func)
+ return wrapped_func
# pylint: disable=invalid-name
-def kl_divergence_scale(arr, quantized_dtype='int8', num_bins=8001, num_quantized_bins=255):
+def _find_scale_by_kl(arr, quantized_dtype='int8', num_bins=8001, num_quantized_bins=255):
"""Given a tensor, find the optimal threshold for quantizing it.
The reference distribution is `q`, and the candidate distribution is `p`.
`q` is a truncated version of the original distribution.
http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
"""
assert isinstance(arr, np.ndarray)
+ assert stats is not None, "scipy needs to be installed for \
+ utilizing kl calibration during quantization"
min_val = np.min(arr)
max_val = np.max(arr)
#pylint: disable=unused-argument
"""Automatic quantization toolkit."""
from __future__ import absolute_import
-import numpy as np
-
from . import _quantize
+from ._calibrate import calibrate
from .. import expr as _expr
-from .. import module as _module
-from .. import analysis as _analysis
from .. import transform as _transform
-from .. import op as _op
from ... import make as _make
from ..base import NodeBase, register_relay_node
"dtype_input": "int8",
"dtype_weight": "int8",
"dtype_activation": "int32",
+ "calibrate_mode": "global_scale",
"global_scale": 8.0,
+ "weight_scale": "power2",
"skip_conv_layers": [0],
"do_simulation": False,
"round_for_shift": True,
nbit_dict: dict of QAnnotateKind -> int
Number of bit for every kind of annotate field.
+ calibrate_mode: str
+ The calibration mode. 'global_scale' or 'kl_divergence'.
+ global_scale: use global scale
+ kl_divergence: find scales by kl divergence on the dataset.
+
global_scale: float
The global scale for calibration.
+ weight_scale: str
+ The way to calculate scales for weights (annotated with QAnnotateKind.WEIGHT).
+ power2: Find the maximum of the absolute value of the tensor, and then round up to power
+ of two.
+ max: Find the maximum of the absolute value of the tensor
+
skip_conv_layers: list
Specifying which layers to be skipped. Provide a list of indices
that indicate which conv2d layers to leave untouched. Start from 0.
return _quantize.QuantizeAnnotate()
-def collect_stats(graph):
- """Given an annotated graph, create a profile graph to collect profile data from the
- calibration dataset. This pass collects simulated_quantize op input into a tuple.
- Simulated_quantize ops are rewritten to identity mode. The tuple is the output of the profile
- graph.
-
- Parameters
- ----------
- graph: Function
- The simulation graph after annotation.
-
- Returns
- -------
- ret: Function
- The profile graph which outputs a tuple of profile data.
- """
- return _quantize.CollectStats(graph)
-
-
-def calibrate(graph, mod=None, ctx=None, weight_scales='power2', scales=None):
- """The calibrate procedure will try to calculate the content of
- dom_scale, nbit, clip_min, clip_max for every `simulated_quantize`
- operator.
-
- Parameters
- ---------
- graph: Function
- The simulation graph after annotation.
-
- mod: tvm.relay.Module
- The module where calibration happens on.
-
- ctx: tvm.relay.PassContext
- The pass context used for calibration.
-
- weight_scales: 'power2' or 'max'.
- The way to calculate scales for weights (annotated with QAnnotateKind.WEIGHT).
- power2: Find the maximum of the absolute value of the tensor, and then round up to power
- of two.
- max: Find the maximum of the absolute value of the tensor.
-
- scales: List[float]
- Pre-calculated scales for input and activations. Length and the order of elements of the
- scales list should match the output tuple of the profile graph created by collect_stats.
-
- Returns
- -------
- ret: Function
- The graph after calibration
- """
- def power2_scale(arr):
- """calculate weight scale with nearest mode-2 scale"""
- val = np.amax(np.abs(arr.asnumpy()))
- return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0
-
- def max_scale(arr):
- """calculate weight scale with maximum absolute value"""
- val = np.amax(np.abs(arr.asnumpy()))
- return val
-
- scale_idx = 0
-
- cfg = current_qconfig()
- const_params = {}
- quantize_op = _op.get("relay.op.annotation.simulated_quantize")
-
- def visit_func(expr):
- """Internal visit function"""
- nonlocal scale_idx
- if isinstance(expr, _expr.Call) and expr.op == quantize_op:
- _, ndom_scale, nclip_min, nclip_max = expr.args
- attrs = expr.attrs
- kind = attrs.kind
- nbit = cfg.get_nbit_by_kind(kind)
-
- valid_bit = nbit - attrs.sign
- if kind in [QAnnotateKind.WEIGHT]:
- if all([isinstance(arg, _expr.Constant)
- for arg in [ndom_scale, nclip_min, nclip_max]]):
- return
- var = expr.args[0]
- assert isinstance(var, _expr.Constant)
- if weight_scales == 'max':
- scale = max_scale(var.data)
- elif weight_scales == 'power2':
- scale = power2_scale(var.data)
- else:
- raise ValueError('{} not supported'.format(weight_scales))
- elif scales is not None:
- scale = scales[scale_idx]
- scale_idx += 1
- else:
- scale = cfg.global_scale
-
- def _make_const(val):
- return _expr.const(val, 'float32')
-
- valid_range = 2**valid_bit
- const_params[ndom_scale] = _make_const(scale / valid_range)
- const_params[nclip_min] = _make_const(- (valid_range - 1))
- const_params[nclip_max] = _make_const((valid_range - 1))
-
- _analysis.post_order_visit(graph, visit_func)
- ret = _expr.bind(graph, const_params)
- return ret
-
-
def realize():
"""The realize pass will transform the simulated quantized graph, which
actually computes with float32, to a real low-bit integer graph. It will
return _expr.bind(func, bind_dict)
-def prerequisite_optimize(graph, params=None):
+def prerequisite_optimize(mod, params=None):
""" Prerequisite optimization passes for quantization. Perform
"SimplifyInference", "FoldScaleAxis", "FoldConstant", and
"CanonicalizeOps" optimization before quantization. """
_transform.FoldConstant()])
if params:
- graph = _bind_params(graph, params)
+ mod['main'] = _bind_params(mod['main'], params)
- mod = _module.Module.from_expr(graph)
- with _transform.PassContext(opt_level=3):
- mod = optimize(mod)
- return mod["main"]
+ mod = optimize(mod)
+ return mod
-def quantize(graph, params=None, dataset=None):
+def quantize(mod, params=None, dataset=None):
""" The quantization procedure. Before running the three main
procedure of quantization, "annotate", "calibrate" and "realize"
, we need to do "SimplifyInference", "FoldScaleAxis", "FoldConstant"
Parameters
---------
- graph: Function
- The original graph.
+ mod: Module
+ The original module.
params : dict of str to NDArray
Input parameters to the graph that do not change
ret: Function
The graph after quantization
"""
- graph = prerequisite_optimize(graph, params)
+ mod = prerequisite_optimize(mod, params)
- mod = _module.Module.from_expr(graph)
- calibrate_pass = _transform.function_pass(calibrate, opt_level=1,
- name="QuantizeCalibrate")
+ calibrate_pass = _transform.module_pass(calibrate(dataset), opt_level=1,
+ name="QuantizeCalibrate")
quant_passes = [partition(),
annotate(),
calibrate_pass]
with quantize_context():
mod = quantize_seq(mod)
- return mod["main"]
+ return mod
* \param expr The simulation graph after annotation.
* \return The profile graph.
*/
-Expr CollectStats(const Expr& expr) {
+Expr CreateStatsCollector(const Expr& expr) {
return StatsCollector().Collect(expr);
}
-TVM_REGISTER_API("relay._quantize.CollectStats")
-.set_body_typed(CollectStats);
+TVM_REGISTER_API("relay._quantize.CreateStatsCollector")
+.set_body_typed(CreateStatsCollector);
} // namespace quantize
} // namespace relay
p->stream << "nbit_input=" << op->nbit_input << ", ";
p->stream << "nbit_weight=" << op->nbit_weight << ", ";
p->stream << "nbit_activation=" << op->nbit_activation << ", ";
+ p->stream << "calibrate_mode=" << op->calibrate_mode << ", ";
p->stream << "global_scale=" << op->global_scale << ", ";
+ p->stream << "weight_scale=" << op->weight_scale << ", ";
p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", ";
p->stream << "do_simulation==" << op->do_simulation << ", ";
p->stream << "round_for_shift==" << op->round_for_shift << ", ";
DataType dtype_input = Int(8);
DataType dtype_weight = Int(8);
DataType dtype_activation = Int(32);
+ std::string calibrate_mode = "global_scale";
double global_scale = 8.0;
+ std::string weight_scale = "power2";
Array<Expr> skip_conv_layers = Array<Expr>(NodePtr<Node>(nullptr));
bool do_simulation = false;
bool round_for_shift = true;
v->Visit("dtype_input", &dtype_input);
v->Visit("dtype_weight", &dtype_weight);
v->Visit("dtype_activation", &dtype_activation);
+ v->Visit("calibrate_mode", &calibrate_mode);
v->Visit("global_scale", &global_scale);
+ v->Visit("weight_scale", &weight_scale);
v->Visit("skip_conv_layers", &skip_conv_layers);
v->Visit("do_simulation", &do_simulation);
v->Visit("round_for_shift", &round_for_shift);