[Relay][Transform] quantize opt passes to pass manager (#3289)
authorZhi <5145158+zhiics@users.noreply.github.com>
Thu, 13 Jun 2019 15:57:43 +0000 (08:57 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Thu, 13 Jun 2019 15:57:43 +0000 (08:57 -0700)
python/tvm/relay/quantize/quantize.py
src/relay/pass/pass_manager.cc
src/relay/pass/quantize.cc

index 66c35b6..992217c 100644 (file)
@@ -21,7 +21,9 @@ import numpy as np
 
 from . import _quantize
 from .. import expr as _expr
+from .. import module as _module
 from .. import ir_pass as _ir_pass
+from .. import transform as _transform
 from .. import op as _op
 from ... import make as _make
 from ..base import NodeBase, register_relay_node
@@ -178,26 +180,7 @@ def _set_conv_counter(n):
     CONV_COUNTER = n
 
 
-def annotate(graph):
-    """Given a float32 graph, annotate will rewrite the graph
-    and return back a graph which simulates the error brought by
-    current quantization scheme.
-
-    Parameters
-    ---------
-    graph: Function
-        The original graph
-
-    Returns
-    -------
-    ret: Function
-        The graph after annotation
-    """
-    _set_conv_counter(0)  # reset counter
-    return _quantize.annotate(graph)
-
-
-def calibrate(graph, dataset=None):
+def calibrate(graph, mod=None, ctx=None):
     """The calibrate procedure will try to calculate the content of
     dom_scale, nbit, clip_min, clip_max for every `simulated_quantize`
     operator.
@@ -207,8 +190,11 @@ def calibrate(graph, dataset=None):
     graph: Function
         The simulation graph after annotation.
 
-    dataset: list of dict of Var -> NDArray
-        The calibration dataset.
+    mod: tvm.relay.Module
+        The module where calibration happens on.
+
+    ctx: tvm.relay.PassContext
+        The pass context used for calibration.
 
     Returns
     -------
@@ -253,93 +239,52 @@ def calibrate(graph, dataset=None):
     return _expr.bind(graph, const_params)
 
 
-def realize(graph):
-    """The realize pass will transform the simulated quantized
-    graph, which computes with float32 actually, to a real low-bit
-    integer graph. It will replace the simulated_quantize with
-    several fine-grained operators like add, multiply, and shift
-    as more as possible for performance (fusion, etc.)
-
-    Parameters
-    ---------
-    graph: Function
-        The simulated graph after calibrating.
+def annotate():
+    """Given a float32 graph, this pass will rewrite the graph and return
+    a graph which simulates the error brought by the current quantization
+    scheme.
 
     Returns
     -------
-    ret: Function
-        The graph after realization
+    ret: tvm.relay.Pass
+        The registered pass for quantization annotation.
     """
-    return _quantize.realize(graph)
+    return _quantize.QuantizeAnnotate()
 
 
-def optimize(func, params=None):
-    """ Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
-    "CanonicalizeOps" optimization before quantization.
-
-    # TODO(zhiics) These passes are executed one by one so far. We need to
-    # move them to the pass manager.
-
-    Parameters
-    ---------
-    func: tvm.relay.Function
-        The original Relay function to be optimized.
-
-    params : dict of str to tvm.NDArray
-        Input parameters to the graph that do not change
-        during inference time. Used for constant folding.
+def realize():
+    """The realize pass will transform the simulated quantized graph, which
+    actually computes with float32, to a real low-bit integer graph. It will
+    replace the `simulated_quantize` with several fine-grained operators like
+    add, multiply, and shift as much as possible for better performance.
 
     Returns
     -------
-    ret: tvm.relay.Function
-        The graph after quantization
+    ret: tvm.relay.Pass
+        The registered pass for quantization realization.
     """
+    return _quantize.QuantizeRealize()
 
-    opt_passes = ["SimplifyInference",
-                  "FoldScaleAxis",
-                  "FoldConstant",
-                  "CanonicalizeOps"]
 
-    if params:
-        name_dict = {}
-        for arg in func.params:
-            name = arg.name_hint
-            if name in name_dict:
-                name_dict[name] = None
-            else:
-                name_dict[name] = arg
-        bind_dict = {}
-        for k, v in params.items():
-            if k not in name_dict:
-                continue
-            arg = name_dict[k]
-            if arg is None:
-                raise ValueError("Multiple args in the function have name %s" % k)
-            bind_dict[arg] = _expr.const(v)
-        func = _expr.bind(func, bind_dict)
-
-    if "SimplifyInference" in opt_passes:
-        func = _ir_pass.infer_type(func)
-        func = _ir_pass.simplify_inference(func)
-
-    if "FoldConstant" in opt_passes:
-        func = _ir_pass.fold_constant(func)
-
-    if "FoldScaleAxis" in opt_passes:
-        func = _ir_pass.infer_type(func)
-        func = _ir_pass.backward_fold_scale_axis(func)
-        func = _ir_pass.infer_type(func)
-        func = _ir_pass.forward_fold_scale_axis(func)
-        func = _ir_pass.fold_constant(func)
-
-    if "CanonicalizeOps" in opt_passes:
-        func = _ir_pass.infer_type(func)
-        func = _ir_pass.canonicalize_ops(func)
-
-    if "FoldConstant" in opt_passes:
-        func = _ir_pass.fold_constant(func)
-
-    return func
+def _bind_params(func, params):
+    """Bind the params to the expression.
+    """
+    name_dict = {}
+    for arg in func.params:
+        name = arg.name_hint
+        if name in name_dict:
+            name_dict[name] = None
+        else:
+            name_dict[name] = arg
+    bind_dict = {}
+    for k, v in params.items():
+        if k not in name_dict:
+            continue
+        arg = name_dict[k]
+        if arg is None:
+            raise ValueError("Multiple args in the function have name %s" % k)
+        bind_dict[arg] = _expr.const(v)
+    return _expr.bind(func, bind_dict)
 
 
 def quantize(graph, params=None, dataset=None):
@@ -365,11 +310,29 @@ def quantize(graph, params=None, dataset=None):
     ret: Function
         The graph after quantization
     """
-    # TODO(zhiics) Move this to the pass manager.
-    graph = optimize(graph, params)
-
-    graph = annotate(graph)
-    graph = calibrate(graph, dataset)
-    graph = realize(graph)
-    graph = _ir_pass.fold_constant(graph)
-    return graph
+    if params:
+        graph = _bind_params(graph, params)
+
+    mod = _module.Module.from_expr(graph)
+    # Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
+    # "CanonicalizeOps" optimization before quantization.
+    optimize = _transform.Sequential([_transform.SimplifyInference(),
+                                      _transform.FoldConstant(),
+                                      _transform.FoldScaleAxis(),
+                                      _transform.CanonicalizeOps(),
+                                      _transform.FoldConstant()])
+
+    calibrate_pass = _transform.function_pass(calibrate, opt_level=1,
+                                              name="QuantizeCalibrate")
+    _set_conv_counter(0)  # reset counter
+    quantize_seq = _transform.Sequential([annotate(),
+                                          calibrate_pass,
+                                          realize(),
+                                          _transform.FoldConstant()])
+    with _transform.PassContext(opt_level=3,
+                                required_pass=["QuantizeAnnotate",
+                                               "QuantizeCalibrate",
+                                               "QuantizeRealize"]):
+        mod = optimize(mod)
+        mod = quantize_seq(mod)
+    return mod[mod.entry_func.name_hint]
index fa79a5e..d63d912 100644 (file)
@@ -313,6 +313,7 @@ Module FunctionPassNode::operator()(const Module& mod,
              << pass_info->name
              << " with opt level: "
              << pass_info->opt_level;
+
   Module updated_mod = mod;
   // Execute the pass function and return a new module.
   std::vector<std::pair<GlobalVar, Function> > updates;
index 3a2e54c..7b6c1ff 100644 (file)
@@ -43,6 +43,8 @@ namespace tvm {
 namespace relay {
 namespace quantize {
 
+using namespace relay::transform;
+
 /*! \brief Attribute for simulated quantize operator */
 struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
   int kind;
@@ -131,23 +133,6 @@ TVM_REGISTER_API("relay._quantize.make_annotate_expr")
       static_cast<QAnnotateKind>(args[1].operator int()));
   });
 
-
-TVM_REGISTER_API("relay._quantize.annotate")
-.set_body_typed<Expr(Expr)>([] (const Expr& expr) {
-  std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
-      if (e->derived_from<TempExprNode>()) {
-        const auto* n = e.as<QAnnotateExprNode>();
-        CHECK(n);
-        const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
-        Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
-        return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
-      }
-      return e;
-    };
-  return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, fmulti_ref);
-});
-
-
 // =============
 // realize pass
 
@@ -536,14 +521,6 @@ Expr AvgPoolRealize(const Call& ref_call,
 RELAY_REGISTER_OP("nn.avg_pool2d")
 .set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
 
-
-TVM_REGISTER_API("relay._quantize.realize")
-.set_body_typed<Expr(Expr)>([](const Expr& e) {
-  Expr ret = ForwardRewrite(e, "FQRealizeRewrite", nullptr, nullptr);
-  return ret;
-});
-
-
 // =============
 // qconfig
 
@@ -613,6 +590,42 @@ TVM_REGISTER_API("relay._quantize._EnterQConfigScope")
 TVM_REGISTER_API("relay._quantize._ExitQConfigScope")
 .set_body_typed(QConfig::ExitQConfigScope);
 
+Pass QuantizeAnnotate() {
+  std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
+    if (e->derived_from<TempExprNode>()) {
+      const auto* n = e.as<QAnnotateExprNode>();
+      CHECK(n);
+      const PackedFunc* f =
+          runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
+      Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
+      return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
+    }
+    return e;
+  };
+
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+      return Downcast<Function>(
+          ForwardRewrite(f, "FQAnnotateRewrite", fmulti_ref));
+  };
+  return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
+}
+
+TVM_REGISTER_API("relay._quantize.QuantizeAnnotate")
+.set_body_typed(QuantizeAnnotate);
+
+Pass QuantizeRealizePass() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+      return Downcast<Function>(
+          ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr));
+  };
+  return CreateFunctionPass(pass_func, 1, "QuantizeRealize", {});
+}
+
+TVM_REGISTER_API("relay._quantize.QuantizeRealize")
+.set_body_typed(QuantizeRealizePass);
+
 }  // namespace quantize
 }  // namespace relay
 }  // namespace tvm