From bfb4884e47f7d2c759b2ee707aa18acf35116380 Mon Sep 17 00:00:00 2001 From: ziheng Date: Sat, 22 Jun 2019 14:59:41 -0700 Subject: [PATCH] [QUANTIZE] Memorizing the quantize node mapping (#3233) * [QUANTIZE] Support for clip operator * [QUANTIZE] Memorizing the quantize node mapping. * [QUANTIZE] Remove use_stop_fusion and skip_k_conv in qconfig * update * update * update * update --- python/tvm/relay/backend/_backend.py | 5 +- python/tvm/relay/quantize/_annotate.py | 96 ++++++++++++++++---------------- python/tvm/relay/quantize/quantize.py | 63 ++++++++++++--------- src/relay/pass/quantize.cc | 32 +++++++++-- src/relay/pass/quantize.h | 8 +-- tests/python/relay/test_pass_quantize.py | 2 +- 6 files changed, 116 insertions(+), 90 deletions(-) diff --git a/python/tvm/relay/backend/_backend.py b/python/tvm/relay/backend/_backend.py index 50e9694..860788a 100644 --- a/python/tvm/relay/backend/_backend.py +++ b/python/tvm/relay/backend/_backend.py @@ -17,7 +17,6 @@ """The interface of expr function exposed from C++.""" from __future__ import absolute_import -import logging from ... import build_module as _build from ... import container as _container from ..._ffi.function import _init_api, register_func @@ -50,8 +49,8 @@ def lower(sch, inputs, func_name, source_func): # pylint: disable=broad-except try: f = _build.lower(sch, inputs, name=func_name) - logging.debug("lower function %s", func_name) - logging.debug("%s", _build.lower(sch, inputs, simple_mode=True)) + # logging.debug("lower function %s", func_name) + # logging.debug("%s", _build.lower(sch, inputs, simple_mode=True)) except Exception: msg = traceback.format_exc() msg += "Error during compile function\n" diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index 9bf546f..61e895a 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -22,7 +22,7 @@ import warnings import topi from . import _quantize from .quantize import QAnnotateKind, current_qconfig -from .quantize import _conv_counter, _set_conv_counter +from .quantize import annotate_context from .. import expr as _expr from .. import op as _op from ..op import op as _reg @@ -116,7 +116,6 @@ def register_annotate_function(op_name, frewrite=None, level=10): return _register(frewrite) if frewrite is not None else _register -@register_func("relay.quantize.attach_simulated_quantize") def attach_simulated_quantize(data, kind, sign=True, rounding="round"): """Attach a simulated quantize operation after input data expr. @@ -133,11 +132,20 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"): if data.attrs.kind == kind and data.attrs.sign == sign and data.attrs.rounding == rounding: return data + actx = annotate_context() + key = tuple([data, kind, sign, rounding]) + if key in actx.qnode_map: + return actx.qnode_map[key] + dom_scale = _expr.var("dom_scale") clip_min = _expr.var("clip_min") clip_max = _expr.var("clip_max") - return _quantize.simulated_quantize( + qnode = _quantize.simulated_quantize( data, dom_scale, clip_min, clip_max, kind, sign, rounding) + actx.qnode_map[key] = qnode + return qnode + +register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize) @register_annotate_function("nn.contrib_conv2d_NCHWc") @@ -152,18 +160,13 @@ def conv2d_rewrite(ref_call, new_args, ctx): """Rewrite function for conv2d. Lhs of conv will be quantized to input field, and rhs of conv will be quantized to weight field. Output would be in activation field""" - cnt = _conv_counter() - if cnt < current_qconfig().skip_k_conv: - _set_conv_counter(cnt + 1) - return None - + actx = annotate_context() if current_qconfig().skip_conv_layers is not None: - leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt in leave_alone_indices: - _set_conv_counter(cnt + 1) + skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers] + if actx.conv2d_counter() in skipped_indices: + actx.count_conv2d() return None - - _set_conv_counter(cnt + 1) + actx.count_conv2d() lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) @@ -179,17 +182,21 @@ def conv2d_rewrite(ref_call, new_args, ctx): return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) +def check_to_skip(): + """Check the index of conv2d layer to decide whether to skip the current operator.""" + if current_qconfig().skip_conv_layers is not None: + skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers] + if annotate_context().conv2d_counter() - 1 in skipped_indices: + return True + return False + + @register_annotate_function("nn.dense") def dense_rewrite(ref_call, new_args, ctx): """Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of dense will be quantized to weight field. Output would be in activation field.""" - cnt = _conv_counter() - if cnt < current_qconfig().skip_k_conv: + if check_to_skip(): return None - if current_qconfig().skip_conv_layers is not None: - leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt - 1 in leave_alone_indices: - return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) @@ -207,13 +214,8 @@ def dense_rewrite(ref_call, new_args, ctx): @register_annotate_function("multiply") def multiply_rewrite(ref_call, new_args, ctx): """Rewrite function for multiply.""" - cnt = _conv_counter() - if cnt <= current_qconfig().skip_k_conv: + if check_to_skip(): return None - if current_qconfig().skip_conv_layers is not None: - leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt - 1 in leave_alone_indices: - return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) @@ -234,13 +236,8 @@ def multiply_rewrite(ref_call, new_args, ctx): @register_annotate_function("add") def add_rewrite(ref_call, new_args, ctx): """Rewrite function for add.""" - cnt = _conv_counter() - if cnt <= current_qconfig().skip_k_conv: + if check_to_skip(): return None - if current_qconfig().skip_conv_layers is not None: - leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt - 1 in leave_alone_indices: - return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) @@ -265,15 +262,25 @@ def add_rewrite(ref_call, new_args, ctx): return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) +@register_annotate_function("stop_fusion") +def stop_fusion_rewrite(ref_call, new_args, ctx): + """Rewrite function for add.""" + if check_to_skip(): + return None + + x_expr, x_kind = _get_expr_kind(new_args[0]) + if x_kind is None: + return None + + ret_expr = attach_simulated_quantize(x_expr, QAnnotateKind.INPUT) + ret_expr = _forward_op(ref_call, [ret_expr]) + return QAnnotateExpr(ret_expr, QAnnotateKind.INPUT) + + def identity_rewrite(ref_call, new_args, ctx): """Simply forward the original operation""" - cnt = _conv_counter() - if cnt <= current_qconfig().skip_k_conv: + if check_to_skip(): return None - if current_qconfig().skip_conv_layers is not None: - leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt - 1 in leave_alone_indices: - return None x_expr, x_kind = _get_expr_kind(new_args[0]) if x_kind is None: @@ -283,6 +290,7 @@ def identity_rewrite(ref_call, new_args, ctx): return QAnnotateExpr(ret_expr, x_kind) +register_annotate_function("clip", identity_rewrite) register_annotate_function("nn.relu", identity_rewrite) register_annotate_function("strided_slice", identity_rewrite) register_annotate_function("nn.avg_pool2d", identity_rewrite) @@ -290,13 +298,8 @@ register_annotate_function("nn.avg_pool2d", identity_rewrite) def pool2d_rewrite(ref_call, new_args, ctx): """Rewrite function for max pool2d""" - cnt = _conv_counter() - if cnt <= current_qconfig().skip_k_conv: + if check_to_skip(): return None - if current_qconfig().skip_conv_layers is not None: - leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt - 1 in leave_alone_indices: - return None expr, x_kind = _get_expr_kind(new_args[0]) @@ -314,13 +317,8 @@ register_annotate_function("nn.max_pool2d", pool2d_rewrite) @register_annotate_function("concatenate") def concatenate_rewrite(ref_call, new_args, ctx): """Rewrite function for concatenate""" - cnt = _conv_counter() - if cnt <= current_qconfig().skip_k_conv: + if check_to_skip(): return None - if current_qconfig().skip_conv_layers is not None: - leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt - 1 in leave_alone_indices: - return None input_tuple = new_args[0] expr_list = [_get_expr_kind(x)[0] for x in input_tuple] diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 992217c..a7749d4 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -71,12 +71,10 @@ class QConfig(NodeBase): "dtype_weight": "int8", "dtype_activation": "int32", "global_scale": 8.0, - "skip_k_conv": 1, - "skip_conv_layers": None, + "skip_conv_layers": [0], "round_for_shift": True, "store_lowbit_output": True, "debug_enabled_ops": None, - "use_stop_fusion": True } # pylint: disable=no-member @@ -138,11 +136,8 @@ def qconfig(**kwargs): global_scale: float The global scale for calibration. - skip_k_conv: int - The number of skipped conv2d. - skip_conv_layers: list - Different way of specifying which layers to avoid. Provide a list of indices + Specifying which layers to be skipped. Provide a list of indices that indicate which conv2d layers to leave untouched. round_for_shift: boolean @@ -152,9 +147,10 @@ def qconfig(**kwargs): Whether to store low-bit integer back as output before dequantizing. Some accelerators need this, e.g. VTA. - use_stop_fusion: boolean - Whether add stop_fusion when casting to dtype_activation. stop_fusion forces lowbit - results to be stored in memory. + debug_enabled_ops: None or list of str + Partially quantize specified operators for debugging. The default value + is None, which means will try to call all operartors' annotate rewrite + function. Returns ------- @@ -166,18 +162,35 @@ def qconfig(**kwargs): return _make.node("relay.quantize.QConfig", **node_args) -CONV_COUNTER = 0 +class AnnotateContext(object): + """A global singleton annotate scope""" + Current = None + + def __init__(self): + self.qnode_map = dict() + self._conv2d_counter = 0 + + def __enter__(self): + self._conv2d_counter = 0 + return self + + def conv2d_counter(self): + """Get the counter for conv2d.""" + return self._conv2d_counter + def count_conv2d(self): + """Increase the value of the conv2d counter by one.""" + self._conv2d_counter += 1 -def _conv_counter(): - """Get the global counter for conv2d.""" - return CONV_COUNTER + def __exit__(self, ptype, value, traceback): + pass -def _set_conv_counter(n): - """Set the value of the global conv2d counter.""" - global CONV_COUNTER - CONV_COUNTER = n +def annotate_context(): + """Get the global singleton scope""" + if AnnotateContext.Current is None: + AnnotateContext.Current = AnnotateContext() + return AnnotateContext.Current def calibrate(graph, mod=None, ctx=None): @@ -324,15 +337,15 @@ def quantize(graph, params=None, dataset=None): calibrate_pass = _transform.function_pass(calibrate, opt_level=1, name="QuantizeCalibrate") - _set_conv_counter(0) # reset counter quantize_seq = _transform.Sequential([annotate(), calibrate_pass, realize(), _transform.FoldConstant()]) - with _transform.PassContext(opt_level=3, - required_pass=["QuantizeAnnotate", - "QuantizeCalibrate", - "QuantizeRealize"]): - mod = optimize(mod) - mod = quantize_seq(mod) + with annotate_context(): + with _transform.PassContext(opt_level=3, + required_pass=["QuantizeAnnotate", + "QuantizeCalibrate", + "QuantizeRealize"]): + mod = optimize(mod) + mod = quantize_seq(mod) return mod[mod.entry_func.name_hint] diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index 7b6c1ff..07233a8 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -393,7 +393,7 @@ Array UnifyDTypeScale(const Array& ref_args, } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) && ref_arg->attrs.as()->kind == kQInput) { auto new_arg = Cast(ret[i], cfg->dtype_input); - if (cfg->use_stop_fusion) { + if (cfg->store_lowbit_output) { new_arg = StopFusion(new_arg); } ret.Set(i, Cast(new_arg, dtype)); @@ -431,6 +431,28 @@ Expr AddRealize(const Call& ref_call, RELAY_REGISTER_OP("add") .set_attr("FQRealizeRewrite", AddRealize); +Expr ClipRealize(const Call& ref_call, + const Array& new_args, + const NodeRef& ctx) { + CHECK_EQ(new_args.size(), 1); + if (const auto* n = new_args[0].as()) { + const auto ref_attrs = ref_call->attrs.as(); + auto attrs = make_node(); + double dom_scale = GetScalarFromConstant(n->dom_scale); + attrs->a_min = ref_attrs->a_min / dom_scale; + attrs->a_max = ref_attrs->a_max / dom_scale; + + Expr ret = CallNode::make(ref_call->op, + {n->data}, Attrs(attrs), ref_call->type_args); + return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype); + } + CHECK(!new_args[0]->derived_from()); + return Expr(nullptr); +} + +RELAY_REGISTER_OP("clip") +.set_attr("FQRealizeRewrite", ClipRealize); + Expr ConcatenateRealize(const Call& ref_call, const Array& new_args, @@ -572,12 +594,10 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "nbit_weight=" << op->nbit_weight << ", "; p->stream << "nbit_activation=" << op->nbit_activation << ", "; p->stream << "global_scale=" << op->global_scale << ", "; - p->stream << "skip_k_conv==" << op->skip_k_conv << ", "; p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", "; p->stream << "round_for_shift==" << op->round_for_shift << ", "; p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", "; - p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", "; - p->stream << "use_stop_fusion==" << op->use_stop_fusion; + p->stream << "debug_enabled_ops==" << op->debug_enabled_ops; p->stream << ")"; }); diff --git a/src/relay/pass/quantize.h b/src/relay/pass/quantize.h index 2c70da1..da95a6c 100644 --- a/src/relay/pass/quantize.h +++ b/src/relay/pass/quantize.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -125,12 +125,10 @@ class QConfigNode : public Node { DataType dtype_weight = Int(8); DataType dtype_activation = Int(32); double global_scale = 8.0; - int skip_k_conv = 1; Array skip_conv_layers = Array(NodePtr(nullptr)); bool round_for_shift = true; bool store_lowbit_output = true; Array debug_enabled_ops = Array(NodePtr(nullptr)); - bool use_stop_fusion = true; void VisitAttrs(AttrVisitor* v) final { v->Visit("nbit_input", &nbit_input); @@ -140,12 +138,10 @@ class QConfigNode : public Node { v->Visit("dtype_weight", &dtype_weight); v->Visit("dtype_activation", &dtype_activation); v->Visit("global_scale", &global_scale); - v->Visit("skip_k_conv", &skip_k_conv); v->Visit("skip_conv_layers", &skip_conv_layers); v->Visit("round_for_shift", &round_for_shift); v->Visit("store_lowbit_output", &store_lowbit_output); v->Visit("debug_enabled_ops", &debug_enabled_ops); - v->Visit("use_stop_fusion", &use_stop_fusion); } static constexpr const char* _type_key = "relay.quantize.QConfig"; diff --git a/tests/python/relay/test_pass_quantize.py b/tests/python/relay/test_pass_quantize.py index e02601e..fe62c3b 100644 --- a/tests/python/relay/test_pass_quantize.py +++ b/tests/python/relay/test_pass_quantize.py @@ -81,7 +81,7 @@ def test_quantize_pass(): graph = make_graph(data) dataset, params = make_dataset(graph, 10) - with qtz.qconfig(skip_k_conv=0, global_scale=4.0, + with qtz.qconfig(skip_conv_layers=None, global_scale=4.0, round_for_shift=False, store_lowbit_output=False): qgraph0 = qtz.quantize(graph, params) qgraph0 = relay.ir_pass.infer_type(qgraph0) -- 2.7.4