from . import _quantize
from .. import expr as _expr
+from .. import module as _module
from .. import ir_pass as _ir_pass
+from .. import transform as _transform
from .. import op as _op
from ... import make as _make
from ..base import NodeBase, register_relay_node
CONV_COUNTER = 0
-def annotate(graph):
- """Given a float32 graph, annotate will rewrite the graph
- and return back a graph which simulates the error brought by
- current quantization scheme.
-
- Parameters
- ---------
- graph: Function
- The original graph
-
- Returns
- -------
- ret: Function
- The graph after annotation
- """
- _set_conv_counter(0) # reset counter
- return _quantize.annotate(graph)
-
-
-def calibrate(graph, dataset=None):
+def calibrate(graph, mod=None, ctx=None):
"""The calibrate procedure will try to calculate the content of
dom_scale, nbit, clip_min, clip_max for every `simulated_quantize`
operator.
graph: Function
The simulation graph after annotation.
- dataset: list of dict of Var -> NDArray
- The calibration dataset.
+    mod: tvm.relay.Module
+        The module in which calibration is performed.
+
+ ctx: tvm.relay.PassContext
+ The pass context used for calibration.
Returns
-------
return _expr.bind(graph, const_params)
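For context, the constants filled in by this pass follow a global-scale, power-of-two scheme. A minimal sketch of the per-tensor arithmetic, with illustrative helper names that are not the exact ones used in this file:

import math
import numpy as np

def _power2_scale(arr):
    # Round the largest absolute value up to the next power of two.
    val = float(np.amax(np.abs(arr)))
    return 2.0 ** math.ceil(math.log(val, 2)) if val > 0 else 1.0

def _simulated_quantize_consts(scale, nbit, sign=True):
    # Derive the constants a simulated_quantize node expects from a
    # per-tensor scale and bit width (sketch only).
    valid_bit = nbit - int(sign)
    dom_scale = scale / float(2 ** valid_bit)
    clip_min = -(2 ** valid_bit - 1)
    clip_max = 2 ** valid_bit - 1
    return dom_scale, clip_min, clip_max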
-def realize(graph):
- """The realize pass will transform the simulated quantized
- graph, which computes with float32 actually, to a real low-bit
- integer graph. It will replace the simulated_quantize with
- several fine-grained operators like add, multiply, and shift
- as more as possible for performance (fusion, etc.)
-
- Parameters
- ---------
- graph: Function
- The simulated graph after calibrating.
+def annotate():
+    """Given a float32 graph, this pass will rewrite the graph and return
+    a graph which simulates the error introduced by the current quantization
+    scheme.
Returns
-------
- ret: Function
- The graph after realization
+ ret: tvm.relay.Pass
+ The registered pass for quantization annotation.
"""
- return _quantize.realize(graph)
+ return _quantize.QuantizeAnnotate()
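Since `annotate` now returns a pass object instead of rewriting a function directly, it is applied through a module, just as `quantize` does below. A usage sketch, for some Relay function `func`:

# Run the annotation pass on a module under a PassContext.
mod = _module.Module.from_expr(func)
with _transform.PassContext(opt_level=3, required_pass=["QuantizeAnnotate"]):
    mod = annotate()(mod)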
-def optimize(func, params=None):
- """ Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
- "CanonicalizeOps" optimization before quantization.
-
- # TODO(zhiics) These passes are executed one by one so far. We need to
- # move them to the pass manager.
-
- Parameters
- ---------
- func: tvm.relay.Function
- The original Relay function to be optimized.
-
- params : dict of str to tvm.NDArray
- Input parameters to the graph that do not change
- during inference time. Used for constant folding.
+def realize():
+    """The realize pass will transform the simulated quantized graph, which
+    internally still computes with float32, into a real low-bit integer graph.
+    It will replace `simulated_quantize` with fine-grained operators such as
+    add, multiply and shift wherever possible for better performance.
Returns
-------
- ret: tvm.relay.Function
- The graph after quantization
+ ret: tvm.relay.Pass
+ The registered pass for quantization realization.
"""
+ return _quantize.QuantizeRealize()
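To make the float32-versus-integer distinction concrete: a `simulated_quantize` node only models the rounding and clipping error while still computing in float32, and realization is what lowers it to genuine integer arithmetic. A numpy sketch of the simulated semantics (rounding details simplified):

import numpy as np

def simulated_quantize_ref(data, dom_scale, clip_min, clip_max):
    # Quantize, clip, then immediately dequantize, so downstream operators
    # observe the quantization error but keep working on float32 values.
    q = np.clip(np.round(data / dom_scale), clip_min, clip_max)
    return q * dom_scale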
- opt_passes = ["SimplifyInference",
- "FoldScaleAxis",
- "FoldConstant",
- "CanonicalizeOps"]
- if params:
- name_dict = {}
- for arg in func.params:
- name = arg.name_hint
- if name in name_dict:
- name_dict[name] = None
- else:
- name_dict[name] = arg
- bind_dict = {}
- for k, v in params.items():
- if k not in name_dict:
- continue
- arg = name_dict[k]
- if arg is None:
- raise ValueError("Multiple args in the function have name %s" % k)
- bind_dict[arg] = _expr.const(v)
- func = _expr.bind(func, bind_dict)
-
- if "SimplifyInference" in opt_passes:
- func = _ir_pass.infer_type(func)
- func = _ir_pass.simplify_inference(func)
-
- if "FoldConstant" in opt_passes:
- func = _ir_pass.fold_constant(func)
-
- if "FoldScaleAxis" in opt_passes:
- func = _ir_pass.infer_type(func)
- func = _ir_pass.backward_fold_scale_axis(func)
- func = _ir_pass.infer_type(func)
- func = _ir_pass.forward_fold_scale_axis(func)
- func = _ir_pass.fold_constant(func)
-
- if "CanonicalizeOps" in opt_passes:
- func = _ir_pass.infer_type(func)
- func = _ir_pass.canonicalize_ops(func)
-
- if "FoldConstant" in opt_passes:
- func = _ir_pass.fold_constant(func)
-
- return func
+def _bind_params(func, params):
+ """Bind the params to the expression.
+ """
+ name_dict = {}
+ for arg in func.params:
+ name = arg.name_hint
+ if name in name_dict:
+ name_dict[name] = None
+ else:
+ name_dict[name] = arg
+ bind_dict = {}
+ for k, v in params.items():
+ if k not in name_dict:
+ continue
+ arg = name_dict[k]
+ if arg is None:
+ raise ValueError("Multiple args in the function have name %s" % k)
+ bind_dict[arg] = _expr.const(v)
+ return _expr.bind(func, bind_dict)
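A usage sketch (the parameter name below is hypothetical): binding turns matching arguments into Relay constants, so the FoldConstant and FoldScaleAxis passes used in `quantize` can simplify them away.

import numpy as np
import tvm

# 'conv0_weight' is a hypothetical parameter name of `func`.
weights = {"conv0_weight": tvm.nd.array(np.zeros((16, 3, 3, 3), "float32"))}
func = _bind_params(func, weights)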
def quantize(graph, params=None, dataset=None):
ret: Function
The graph after quantization
"""
- # TODO(zhiics) Move this to the pass manager.
- graph = optimize(graph, params)
-
- graph = annotate(graph)
- graph = calibrate(graph, dataset)
- graph = realize(graph)
- graph = _ir_pass.fold_constant(graph)
- return graph
+ if params:
+ graph = _bind_params(graph, params)
+
+ mod = _module.Module.from_expr(graph)
+ # Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
+ # "CanonicalizeOps" optimization before quantization.
+ optimize = _transform.Sequential([_transform.SimplifyInference(),
+ _transform.FoldConstant(),
+ _transform.FoldScaleAxis(),
+ _transform.CanonicalizeOps(),
+ _transform.FoldConstant()])
+
+ calibrate_pass = _transform.function_pass(calibrate, opt_level=1,
+ name="QuantizeCalibrate")
+ _set_conv_counter(0) # reset counter
+ quantize_seq = _transform.Sequential([annotate(),
+ calibrate_pass,
+ realize(),
+ _transform.FoldConstant()])
+ with _transform.PassContext(opt_level=3,
+ required_pass=["QuantizeAnnotate",
+ "QuantizeCalibrate",
+ "QuantizeRealize"]):
+ mod = optimize(mod)
+ mod = quantize_seq(mod)
+ return mod[mod.entry_func.name_hint]
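The public entry point is unchanged; only its internals now run through the pass infrastructure. A usage sketch, with illustrative qconfig arguments:

import tvm.relay as relay

# Quantize a function with its parameters, under a quantization config.
with relay.quantize.qconfig(global_scale=8.0, skip_k_conv=1):
    qfunc = relay.quantize.quantize(func, params=params)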
namespace relay {
namespace quantize {
+using namespace relay::transform;
+
/*! \brief Attribute for simulated quantize operator */
struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
int kind;
static_cast<QAnnotateKind>(args[1].operator int()));
});
-
-TVM_REGISTER_API("relay._quantize.annotate")
-.set_body_typed<Expr(Expr)>([] (const Expr& expr) {
- std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
- if (e->derived_from<TempExprNode>()) {
- const auto* n = e.as<QAnnotateExprNode>();
- CHECK(n);
- const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
- Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
- return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
- }
- return e;
- };
- return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, fmulti_ref);
-});
-
-
// =============
// realize pass
RELAY_REGISTER_OP("nn.avg_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
-
-TVM_REGISTER_API("relay._quantize.realize")
-.set_body_typed<Expr(Expr)>([](const Expr& e) {
- Expr ret = ForwardRewrite(e, "FQRealizeRewrite", nullptr, nullptr);
- return ret;
-});
-
-
// =============
// qconfig
TVM_REGISTER_API("relay._quantize._ExitQConfigScope")
.set_body_typed(QConfig::ExitQConfigScope);
+Pass QuantizeAnnotate() {
+ std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
+ if (e->derived_from<TempExprNode>()) {
+ const auto* n = e.as<QAnnotateExprNode>();
+ CHECK(n);
+ const PackedFunc* f =
+ runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
+ Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
+ return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
+ }
+ return e;
+ };
+
+ runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+ [=](Function f, Module m, PassContext pc) {
+ return Downcast<Function>(
+          ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref));
+ };
+ return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
+}
+
+TVM_REGISTER_API("relay._quantize.QuantizeAnnotate")
+.set_body_typed(QuantizeAnnotate);
+
+Pass QuantizeRealizePass() {
+ runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+ [=](Function f, Module m, PassContext pc) {
+ return Downcast<Function>(
+ ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr));
+ };
+ return CreateFunctionPass(pass_func, 1, "QuantizeRealize", {});
+}
+
+TVM_REGISTER_API("relay._quantize.QuantizeRealize")
+.set_body_typed(QuantizeRealizePass);
+
} // namespace quantize
} // namespace relay
} // namespace tvm
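On the Python side, these registrations are what `annotate()` and `realize()` call through `_quantize`; the pass constructors can also be fetched directly from the global registry, e.g.:

import tvm

# TVM_REGISTER_API exposes the pass constructors as packed functions.
annotate_pass = tvm.get_global_func("relay._quantize.QuantizeAnnotate")()
realize_pass = tvm.get_global_func("relay._quantize.QuantizeRealize")()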