[QUANTIZE] Memoizing the quantize node mapping (#3233)
author ziheng <ziheng@apache.org>
Sat, 22 Jun 2019 21:59:41 +0000 (14:59 -0700)
committer GitHub <noreply@github.com>
Sat, 22 Jun 2019 21:59:41 +0000 (14:59 -0700)
* [QUANTIZE] Support for clip operator

* [QUANTIZE] Memoizing the quantize node mapping.

* [QUANTIZE] Remove use_stop_fusion and skip_k_conv in qconfig

* update

* update

* update

* update

python/tvm/relay/backend/_backend.py
python/tvm/relay/quantize/_annotate.py
python/tvm/relay/quantize/quantize.py
src/relay/pass/quantize.cc
src/relay/pass/quantize.h
tests/python/relay/test_pass_quantize.py
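
Taken together: skip_k_conv and use_stop_fusion are removed from qconfig,
skip_conv_layers (default [0]) becomes the only way to leave conv2d layers
unquantized, and store_lowbit_output now also controls stop_fusion insertion.
A minimal usage sketch of the new surface (graph and params stand in for a
Relay function and its parameters):

    import tvm.relay.quantize as qtz

    # skip_conv_layers=[0] (the new default) leaves the first conv2d layer
    # untouched; pass None to quantize every conv2d layer.
    with qtz.qconfig(skip_conv_layers=[0], global_scale=8.0):
        qgraph = qtz.quantize(graph, params)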

python/tvm/relay/backend/_backend.py
index 50e9694..860788a 100644 (file)
@@ -17,7 +17,6 @@
 """The interface of expr function exposed from C++."""
 from __future__ import absolute_import
 
-import logging
 from ... import build_module as _build
 from ... import container as _container
 from ..._ffi.function import _init_api, register_func
@@ -50,8 +49,8 @@ def lower(sch, inputs, func_name, source_func):
     # pylint: disable=broad-except
     try:
         f = _build.lower(sch, inputs, name=func_name)
-        logging.debug("lower function %s", func_name)
-        logging.debug("%s", _build.lower(sch, inputs, simple_mode=True))
+        logging.debug("lower function %s", func_name)
+        logging.debug("%s", _build.lower(sch, inputs, simple_mode=True))
     except Exception:
         msg = traceback.format_exc()
         msg += "Error during compile function\n"
python/tvm/relay/quantize/_annotate.py
index 9bf546f..61e895a 100644 (file)
@@ -22,7 +22,7 @@ import warnings
 import topi
 from . import _quantize
 from .quantize import QAnnotateKind, current_qconfig
-from .quantize import _conv_counter, _set_conv_counter
+from .quantize import annotate_context
 from .. import expr as _expr
 from .. import op as _op
 from ..op import op as _reg
@@ -116,7 +116,6 @@ def register_annotate_function(op_name, frewrite=None, level=10):
     return _register(frewrite) if frewrite is not None else _register
 
 
-@register_func("relay.quantize.attach_simulated_quantize")
 def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
     """Attach a simulated quantize operation after input data expr.
 
@@ -133,11 +132,20 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
         if data.attrs.kind == kind and data.attrs.sign == sign and data.attrs.rounding == rounding:
             return data
 
+    actx = annotate_context()
+    key = (data, kind, sign, rounding)
+    if key in actx.qnode_map:
+        return actx.qnode_map[key]
+
     dom_scale = _expr.var("dom_scale")
     clip_min = _expr.var("clip_min")
     clip_max = _expr.var("clip_max")
-    return _quantize.simulated_quantize(
+    qnode = _quantize.simulated_quantize(
         data, dom_scale, clip_min, clip_max, kind, sign, rounding)
+    actx.qnode_map[key] = qnode
+    return qnode
+
+register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize)
 
 
 @register_annotate_function("nn.contrib_conv2d_NCHWc")
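
The hunk above is the memoization that gives the commit its title: the cache is
keyed on the (data, kind, sign, rounding) tuple, so requesting the same
simulated quantization of the same expression twice returns the node created
the first time instead of building a duplicate. A stripped-down sketch of the
pattern (make_qnode is a placeholder, not TVM API):

    _qnode_map = {}  # (data, kind, sign, rounding) -> cached quantize node

    def attach(data, kind, sign=True, rounding="round"):
        key = (data, kind, sign, rounding)  # exprs hash by object identity
        if key in _qnode_map:
            return _qnode_map[key]          # reuse the memoized node
        qnode = make_qnode(data, kind, sign, rounding)
        _qnode_map[key] = qnode
        return qnode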
@@ -152,18 +160,13 @@ def conv2d_rewrite(ref_call, new_args, ctx):
     """Rewrite function for conv2d. Lhs of conv will be quantized to
     input field, and rhs of conv will be quantized to weight field.
     Output would be in activation field"""
-    cnt = _conv_counter()
-    if cnt < current_qconfig().skip_k_conv:
-        _set_conv_counter(cnt + 1)
-        return None
-
+    actx = annotate_context()
     if current_qconfig().skip_conv_layers is not None:
-        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
-        if cnt in leave_alone_indices:
-            _set_conv_counter(cnt + 1)
+        skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if actx.conv2d_counter() in skipped_indices:
+            actx.count_conv2d()
             return None
-
-    _set_conv_counter(cnt + 1)
+    actx.count_conv2d()
 
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
     rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
@@ -179,17 +182,21 @@ def conv2d_rewrite(ref_call, new_args, ctx):
     return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
 
 
+def check_to_skip():
+    """Check the index of conv2d layer to decide whether to skip the current operator."""
+    if current_qconfig().skip_conv_layers is not None:
+        skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if annotate_context().conv2d_counter() - 1 in skipped_indices:
+            return True
+    return False
+
+
 @register_annotate_function("nn.dense")
 def dense_rewrite(ref_call, new_args, ctx):
     """Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of
     dense will be quantized to weight field. Output would be in activation field."""
-    cnt = _conv_counter()
-    if cnt < current_qconfig().skip_k_conv:
+    if check_to_skip():
         return None
-    if current_qconfig().skip_conv_layers is not None:
-        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
-        if cnt - 1 in leave_alone_indices:
-            return None
 
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
     rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
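
One detail in check_to_skip above: conv2d_rewrite increments the counter as
soon as it handles a conv2d, so for every operator rewritten after it,
conv2d_counter() - 1 is the index of the most recently seen conv2d layer. A toy
trace, assuming skip_conv_layers=[0]:

    # counter starts at 0
    # conv2d #0 : conv2d_counter() == 0 is in [0]  -> skipped, counter -> 1
    # relu      : counter - 1 == 0 is in [0]       -> skipped with its conv2d
    # conv2d #1 : conv2d_counter() == 1 not in [0] -> annotated, counter -> 2
    # add       : counter - 1 == 1 not in [0]      -> annotated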
@@ -207,13 +214,8 @@ def dense_rewrite(ref_call, new_args, ctx):
 @register_annotate_function("multiply")
 def multiply_rewrite(ref_call, new_args, ctx):
     """Rewrite function for multiply."""
-    cnt = _conv_counter()
-    if cnt <= current_qconfig().skip_k_conv:
+    if check_to_skip():
         return None
-    if current_qconfig().skip_conv_layers is not None:
-        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
-        if cnt - 1 in leave_alone_indices:
-            return None
 
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
     rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
@@ -234,13 +236,8 @@ def multiply_rewrite(ref_call, new_args, ctx):
 @register_annotate_function("add")
 def add_rewrite(ref_call, new_args, ctx):
     """Rewrite function for add."""
-    cnt = _conv_counter()
-    if cnt <= current_qconfig().skip_k_conv:
+    if check_to_skip():
         return None
-    if current_qconfig().skip_conv_layers is not None:
-        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
-        if cnt - 1 in leave_alone_indices:
-            return None
 
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
     rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
@@ -265,15 +262,25 @@ def add_rewrite(ref_call, new_args, ctx):
     return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
 
 
+@register_annotate_function("stop_fusion")
+def stop_fusion_rewrite(ref_call, new_args, ctx):
+    """Rewrite function for add."""
+    if check_to_skip():
+        return None
+
+    x_expr, x_kind = _get_expr_kind(new_args[0])
+    if x_kind is None:
+        return None
+
+    ret_expr = attach_simulated_quantize(x_expr, QAnnotateKind.INPUT)
+    ret_expr = _forward_op(ref_call, [ret_expr])
+    return QAnnotateExpr(ret_expr, QAnnotateKind.INPUT)
+
+
 def identity_rewrite(ref_call, new_args, ctx):
     """Simply forward the original operation"""
-    cnt = _conv_counter()
-    if cnt <= current_qconfig().skip_k_conv:
+    if check_to_skip():
         return None
-    if current_qconfig().skip_conv_layers is not None:
-        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
-        if cnt - 1 in leave_alone_indices:
-            return None
 
     x_expr, x_kind = _get_expr_kind(new_args[0])
     if x_kind is None:
@@ -283,6 +290,7 @@ def identity_rewrite(ref_call, new_args, ctx):
     return QAnnotateExpr(ret_expr, x_kind)
 
 
+register_annotate_function("clip", identity_rewrite)
 register_annotate_function("nn.relu", identity_rewrite)
 register_annotate_function("strided_slice", identity_rewrite)
 register_annotate_function("nn.avg_pool2d", identity_rewrite)
@@ -290,13 +298,8 @@ register_annotate_function("nn.avg_pool2d", identity_rewrite)
 
 def pool2d_rewrite(ref_call, new_args, ctx):
     """Rewrite function for max pool2d"""
-    cnt = _conv_counter()
-    if cnt <= current_qconfig().skip_k_conv:
+    if check_to_skip():
         return None
-    if current_qconfig().skip_conv_layers is not None:
-        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
-        if cnt - 1 in leave_alone_indices:
-            return None
 
     expr, x_kind = _get_expr_kind(new_args[0])
 
@@ -314,13 +317,8 @@ register_annotate_function("nn.max_pool2d", pool2d_rewrite)
 @register_annotate_function("concatenate")
 def concatenate_rewrite(ref_call, new_args, ctx):
     """Rewrite function for concatenate"""
-    cnt = _conv_counter()
-    if cnt <= current_qconfig().skip_k_conv:
+    if check_to_skip():
         return None
-    if current_qconfig().skip_conv_layers is not None:
-        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
-        if cnt - 1 in leave_alone_indices:
-            return None
 
     input_tuple = new_args[0]
     expr_list = [_get_expr_kind(x)[0] for x in input_tuple]
python/tvm/relay/quantize/quantize.py
index 992217c..a7749d4 100644 (file)
@@ -71,12 +71,10 @@ class QConfig(NodeBase):
         "dtype_weight": "int8",
         "dtype_activation": "int32",
         "global_scale": 8.0,
-        "skip_k_conv": 1,
-        "skip_conv_layers": None,
+        "skip_conv_layers": [0],
         "round_for_shift": True,
         "store_lowbit_output": True,
         "debug_enabled_ops": None,
-        "use_stop_fusion": True
     }
 
     # pylint: disable=no-member
@@ -138,11 +136,8 @@ def qconfig(**kwargs):
     global_scale: float
         The global scale for calibration.
 
-    skip_k_conv: int
-        The number of skipped conv2d.
-
     skip_conv_layers: list
-        Different way of specifying which layers to avoid. Provide a list of indices
+        Specifies which layers to skip. Provide a list of indices
         that indicate which conv2d layers to leave untouched.
 
     round_for_shift: boolean
@@ -152,9 +147,10 @@ def qconfig(**kwargs):
         Whether to store low-bit integer back as output before dequantizing.
         Some accelerators need this, e.g. VTA.
 
-    use_stop_fusion: boolean
-        Whether add stop_fusion when casting to dtype_activation. stop_fusion forces lowbit
-        results to be stored in memory.
+    debug_enabled_ops: None or list of str
+        Partially quantize specified operators for debugging. The default value
+        is None, which means it will try to call every operator's annotate
+        rewrite function.
 
     Returns
     -------
@@ -166,18 +162,35 @@ def qconfig(**kwargs):
     return _make.node("relay.quantize.QConfig", **node_args)
 
 
-CONV_COUNTER = 0
+class AnnotateContext(object):
+    """A global singleton annotate scope"""
+    Current = None
+
+    def __init__(self):
+        self.qnode_map = dict()
+        self._conv2d_counter = 0
+
+    def __enter__(self):
+        self._conv2d_counter = 0
+        return self
+
+    def conv2d_counter(self):
+        """Get the counter for conv2d."""
+        return self._conv2d_counter
 
+    def count_conv2d(self):
+        """Increase the value of the conv2d counter by one."""
+        self._conv2d_counter += 1
 
-def _conv_counter():
-    """Get the global counter for conv2d."""
-    return CONV_COUNTER
+    def __exit__(self, ptype, value, traceback):
+        pass
 
 
-def _set_conv_counter(n):
-    """Set the value of the global conv2d counter."""
-    global CONV_COUNTER
-    CONV_COUNTER = n
+def annotate_context():
+    """Get the global singleton scope"""
+    if AnnotateContext.Current is None:
+        AnnotateContext.Current = AnnotateContext()
+    return AnnotateContext.Current
 
 
 def calibrate(graph, mod=None, ctx=None):
@@ -324,15 +337,15 @@ def quantize(graph, params=None, dataset=None):
 
     calibrate_pass = _transform.function_pass(calibrate, opt_level=1,
                                               name="QuantizeCalibrate")
-    _set_conv_counter(0)  # reset counter
     quantize_seq = _transform.Sequential([annotate(),
                                           calibrate_pass,
                                           realize(),
                                           _transform.FoldConstant()])
-    with _transform.PassContext(opt_level=3,
-                                required_pass=["QuantizeAnnotate",
-                                               "QuantizeCalibrate",
-                                               "QuantizeRealize"]):
-        mod = optimize(mod)
-        mod = quantize_seq(mod)
+    with annotate_context():
+        with _transform.PassContext(opt_level=3,
+                                    required_pass=["QuantizeAnnotate",
+                                                   "QuantizeCalibrate",
+                                                   "QuantizeRealize"]):
+            mod = optimize(mod)
+            mod = quantize_seq(mod)
     return mod[mod.entry_func.name_hint]
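
Note the lifecycle this with-block implies: annotate_context() lazily creates
one process-wide AnnotateContext, __enter__ only resets the conv2d counter, and
qnode_map is never cleared, so memoized quantize nodes outlive a single
quantize() call. A small sketch of that behavior (assuming annotate_context is
importable from tvm.relay.quantize.quantize, as in the diff above):

    from tvm.relay.quantize.quantize import annotate_context

    actx = annotate_context()
    with actx:                         # __enter__ resets _conv2d_counter to 0
        actx.count_conv2d()
    assert annotate_context() is actx  # same singleton object afterwards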
src/relay/pass/quantize.cc
index 7b6c1ff..07233a8 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -393,7 +393,7 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args,
     } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
                ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
       auto new_arg = Cast(ret[i], cfg->dtype_input);
-      if (cfg->use_stop_fusion) {
+      if (cfg->store_lowbit_output) {
         new_arg = StopFusion(new_arg);
       }
       ret.Set(i, Cast(new_arg, dtype));
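
With use_stop_fusion gone, store_lowbit_output does double duty: besides
keeping low-bit outputs stored before dequantization, it now also gates the
stop_fusion insertion shown above. In Python-flavored pseudocode (a paraphrase
of the C++ branch, not actual API):

    # realizing an INPUT-kind argument:
    new_arg = cast(ret[i], cfg.dtype_input)
    if cfg.store_lowbit_output:          # previously: cfg.use_stop_fusion
        new_arg = stop_fusion(new_arg)   # force the low-bit value into memory
    ret[i] = cast(new_arg, dtype)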
@@ -431,6 +431,28 @@ Expr AddRealize(const Call& ref_call,
 RELAY_REGISTER_OP("add")
 .set_attr<FForwardRewrite>("FQRealizeRewrite", AddRealize);
 
+Expr ClipRealize(const Call& ref_call,
+                 const Array<Expr>& new_args,
+                 const NodeRef& ctx) {
+  CHECK_EQ(new_args.size(), 1);
+  if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
+    const auto ref_attrs = ref_call->attrs.as<ClipAttrs>();
+    auto attrs = make_node<ClipAttrs>();
+    double dom_scale = GetScalarFromConstant<float>(n->dom_scale);
+    attrs->a_min = ref_attrs->a_min / dom_scale;
+    attrs->a_max = ref_attrs->a_max / dom_scale;
+
+    Expr ret = CallNode::make(ref_call->op,
+      {n->data}, Attrs(attrs), ref_call->type_args);
+    return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype);
+  }
+  CHECK(!new_args[0]->derived_from<TempExprNode>());
+  return Expr(nullptr);
+}
+
+RELAY_REGISTER_OP("clip")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", ClipRealize);
+
 
 Expr ConcatenateRealize(const Call& ref_call,
                         const Array<Expr>& new_args,
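
The key arithmetic in ClipRealize: if a real value r is represented by the
quantized value q = r / dom_scale, clipping r to [a_min, a_max] is equivalent
to clipping q to [a_min / dom_scale, a_max / dom_scale]. A quick illustrative
check (plain Python, not TVM API):

    dom_scale, a_min, a_max, r = 0.25, 0.0, 6.0, 1.5
    q = r / dom_scale
    clipped_q = min(max(q, a_min / dom_scale), a_max / dom_scale)
    assert clipped_q * dom_scale == min(max(r, a_min), a_max)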
@@ -572,12 +594,10 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   p->stream << "nbit_weight=" << op->nbit_weight << ", ";
   p->stream << "nbit_activation=" << op->nbit_activation << ", ";
   p->stream << "global_scale=" << op->global_scale << ", ";
-  p->stream << "skip_k_conv==" << op->skip_k_conv << ", ";
   p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", ";
   p->stream << "round_for_shift==" << op->round_for_shift << ", ";
   p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", ";
-  p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", ";
-  p->stream << "use_stop_fusion==" << op->use_stop_fusion;
+  p->stream << "debug_enabled_ops==" << op->debug_enabled_ops;
   p->stream << ")";
 });
 
src/relay/pass/quantize.h
index 2c70da1..da95a6c 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -125,12 +125,10 @@ class QConfigNode : public Node {
   DataType dtype_weight = Int(8);
   DataType dtype_activation = Int(32);
   double global_scale = 8.0;
-  int skip_k_conv = 1;
   Array<Expr> skip_conv_layers = Array<Expr>(NodePtr<Node>(nullptr));
   bool round_for_shift = true;
   bool store_lowbit_output = true;
   Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
-  bool use_stop_fusion = true;
 
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("nbit_input", &nbit_input);
@@ -140,12 +138,10 @@ class QConfigNode : public Node {
     v->Visit("dtype_weight", &dtype_weight);
     v->Visit("dtype_activation", &dtype_activation);
     v->Visit("global_scale", &global_scale);
-    v->Visit("skip_k_conv", &skip_k_conv);
     v->Visit("skip_conv_layers", &skip_conv_layers);
     v->Visit("round_for_shift", &round_for_shift);
     v->Visit("store_lowbit_output", &store_lowbit_output);
     v->Visit("debug_enabled_ops", &debug_enabled_ops);
-    v->Visit("use_stop_fusion", &use_stop_fusion);
   }
 
   static constexpr const char* _type_key = "relay.quantize.QConfig";
tests/python/relay/test_pass_quantize.py
index e02601e..fe62c3b 100644 (file)
@@ -81,7 +81,7 @@ def test_quantize_pass():
     graph = make_graph(data)
     dataset, params = make_dataset(graph, 10)
 
-    with qtz.qconfig(skip_k_conv=0, global_scale=4.0,
+    with qtz.qconfig(skip_conv_layers=None, global_scale=4.0,
                      round_for_shift=False, store_lowbit_output=False):
         qgraph0 = qtz.quantize(graph, params)
         qgraph0 = relay.ir_pass.infer_type(qgraph0)
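
For downstream code the migration is mechanical in the common case:
skip_k_conv=k becomes skip_conv_layers=list(range(k)), and skip_k_conv=0 (skip
nothing, as in this test) becomes skip_conv_layers=None. A hedged before/after
sketch, with graph and params as placeholders:

    # before this commit
    with qtz.qconfig(skip_k_conv=1, global_scale=4.0):
        qgraph = qtz.quantize(graph, params)

    # after this commit
    with qtz.qconfig(skip_conv_layers=[0], global_scale=4.0):
        qgraph = qtz.quantize(graph, params)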