* Refactor.
* update
* update
* update
* update
* update
* update
TVM_DLL Kind KindCheck(const Type& t, const Module& mod);
/*!
+ * \brief Check whether an expression is constant.
+ *
+ * If the inputs of an expression are all constant, the expression
+ * itself is constant as well.
+ *
+ * \param e the expression.
+ *
+ * \return whether the expression is constant.
+ */
+TVM_DLL bool ConstantCheck(const Expr& e);
+
+/*!
* \brief Compare two expressions for structural equivalence.
*
* This comparison operator respects scoping and compares
}
};
+/*!
+ * \brief Annotate an expression to be cast into a specific data type.
+ */
+struct CastHintAttrs : public tvm::AttrsNode<CastHintAttrs> {
+ DataType dtype;
+
+ TVM_DECLARE_ATTRS(CastHintAttrs, "relay.attrs.CastHintAttrs") {
+ TVM_ATTR_FIELD(dtype)
+ .describe(
+ "The data type to cast the expression into.");
+ }
+};
+
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_
return _analysis.check_kind(t)
+def check_constant(expr):
+ """Check whether an expression is constant.
+
+ Parameters
+ ----------
+ expr : tvm.relay.Expr
+ The input expression
+
+ Returns
+ -------
+ result : bool
+ Whether the expression is constant.
+ """
+ return _analysis.check_constant(expr)
+
+
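For reference, a minimal usage sketch of `check_constant` (illustrative only, not part of this patch; assumes the `tvm.relay.analysis` module exposed above):

import numpy as np
from tvm import relay
from tvm.relay import analysis

c = relay.const(np.ones((1, 4), dtype="float32"))  # a relay.Constant
x = relay.var("x", shape=(1, 4))                    # a free variable

assert analysis.check_constant(c)       # constants are constant
assert not analysis.check_constant(x)   # variables are not
# The same helper is used inside add_rewrite below to decide whether the
# right-hand side should be quantized into the WEIGHT or the INPUT field.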
def free_vars(expr):
"""Get free Vars from expression expr in Post DFS order.
from __future__ import absolute_import as _abs
from .quantize import *
+from ._partition import register_partition_function
from ._annotate import register_annotate_function
from .kl_divergence import kl_divergence_scale
import warnings
import topi
-from . import _quantize
-from .quantize import QAnnotateKind, current_qconfig
-from .quantize import annotate_context
+from ..._ffi.function import register_func
from .. import expr as _expr
+from .. import analysis as _analysis
from .. import op as _op
from ..op import op as _reg
from ..base import register_relay_node
-from ..._ffi.function import register_func
+from . import _quantize
+from .quantize import QAnnotateKind, current_qconfig, quantize_context
+from .quantize import _forward_op
@_reg.register_compute("relay.op.annotation.simulated_quantize")
_quantize.make_annotate_expr, expr, kind)
-def _forward_op(ref_call, args):
- """forward the operator of ref_call with provided arguments"""
- return _expr.Call(
- ref_call.op, args, ref_call.attrs, ref_call.type_args)
-
-
def _get_expr_kind(anno):
"""Get the expression and QAnnotateKind from QAnnotateExpr or Expr"""
if isinstance(anno, QAnnotateExpr):
if not current_qconfig().guard(ref_call):
return default_rewrite(ref_call, new_args, ctx)
return func(ref_call, new_args, ctx)
- _op.op._Register(op_name, "FQAnnotateRewrite", frewrite_with_guard, level)
+ _reg._Register(op_name, "FQAnnotateRewrite", frewrite_with_guard, level)
return frewrite_with_guard
return _register(frewrite) if frewrite is not None else _register
if data.attrs.kind == kind and data.attrs.sign == sign and data.attrs.rounding == rounding:
return data
- actx = annotate_context()
+ qctx = quantize_context()
key = tuple([data, kind, sign, rounding])
- if key in actx.qnode_map:
- return actx.qnode_map[key]
+ if key in qctx.qnode_map:
+ return qctx.qnode_map[key]
dom_scale = _expr.var("dom_scale")
clip_min = _expr.var("clip_min")
clip_max = _expr.var("clip_max")
qnode = _quantize.simulated_quantize(
data, dom_scale, clip_min, clip_max, kind, sign, rounding)
- actx.qnode_map[key] = qnode
+ qctx.qnode_map[key] = qnode
return qnode
register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize)
"""Rewrite function for conv2d. Lhs of conv will be quantized to
input field, and rhs of conv will be quantized to weight field.
Output would be in activation field"""
- actx = annotate_context()
- if current_qconfig().skip_conv_layers is not None:
- skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
- if actx.conv2d_counter() in skipped_indices:
- actx.count_conv2d()
- return None
- actx.count_conv2d()
+ if quantize_context().check_to_skip(ref_call):
+ return None
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
-def check_to_skip():
- """Check the index of conv2d layer to decide whether to skip the current operator."""
- if current_qconfig().skip_conv_layers is not None:
- skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
- if annotate_context().conv2d_counter() - 1 in skipped_indices:
- return True
- return False
-
-
# TODO(tmoreau89,ziheng) need to include an option to turn off dense quant
# @register_annotate_function("nn.dense")
def dense_rewrite(ref_call, new_args, ctx):
"""Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of
dense will be quantized to weight field. Output would be in activation field."""
- if check_to_skip():
+ if quantize_context().check_to_skip(ref_call):
return None
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
@register_annotate_function("multiply")
def multiply_rewrite(ref_call, new_args, ctx):
"""Rewrite function for multiply."""
- if check_to_skip():
+ if quantize_context().check_to_skip(ref_call):
return None
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
@register_annotate_function("add")
def add_rewrite(ref_call, new_args, ctx):
"""Rewrite function for add."""
- if check_to_skip():
+ if quantize_context().check_to_skip(ref_call):
return None
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
if lhs_kind is None and rhs_kind is None:
+ # trivial case
return None
if lhs_kind is None and rhs_kind is not None:
return QAnnotateExpr(expr, QAnnotateKind.INPUT)
if lhs_kind is not None and rhs_kind is None:
- if isinstance(rhs_expr, _expr.Constant):
- # quantize rhs to WEIGHT field if it is Constant
+ if _analysis.check_constant(rhs_expr):
+ # - introduced by batch_norm: add(out, const)
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
else:
- # quantize rhs to INPUT field if it is not Constant
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)
if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.ACTIVATION:
- # quantize rhs to INPUT field if both lhs and rhs are ACTIVATION
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
raise ValueError()
-@register_annotate_function("stop_fusion")
-def stop_fusion_rewrite(ref_call, new_args, ctx):
- """Rewrite function for add."""
- if check_to_skip():
- return None
-
- x_expr, x_kind = _get_expr_kind(new_args[0])
- if x_kind is None:
- return None
-
- ret_expr = attach_simulated_quantize(x_expr, QAnnotateKind.INPUT)
- ret_expr = _forward_op(ref_call, [ret_expr])
- return QAnnotateExpr(ret_expr, QAnnotateKind.INPUT)
-
-
def identity_rewrite(ref_call, new_args, ctx):
"""Simply forward the original operation"""
- if check_to_skip():
+ if quantize_context().check_to_skip(ref_call):
return None
x_expr, x_kind = _get_expr_kind(new_args[0])
def pool2d_rewrite(ref_call, new_args, ctx):
"""Rewrite function for max pool2d"""
- if check_to_skip():
+ if quantize_context().check_to_skip(ref_call):
return None
expr, x_kind = _get_expr_kind(new_args[0])
register_annotate_function("nn.max_pool2d", pool2d_rewrite)
-@register_annotate_function("annotation.force_cast")
-def force_cast_rewrite(ref_call, new_args, ctx):
+@register_annotate_function("annotation.cast_hint")
+def cast_hint_rewrite(ref_call, new_args, ctx):
"""Rewrite function to force cast"""
- if check_to_skip():
- return None
-
expr, x_kind = _get_expr_kind(new_args[0])
+ if quantize_context().check_to_skip(ref_call):
+ return expr
+
if x_kind is None:
return new_args[0]
if x_kind == QAnnotateKind.ACTIVATION:
@register_annotate_function("concatenate")
def concatenate_rewrite(ref_call, new_args, ctx):
"""Rewrite function for concatenate"""
- if check_to_skip():
+ if quantize_context().check_to_skip(ref_call):
return None
input_tuple = new_args[0]
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
-# Graph rewrite function registration for VTA target
-def register_vta_rewrite(op_name, frewrite=None, level=10):
- def _register(func):
- return _op.op._Register(op_name, "FQVTARewrite", func, level)
- return _register(frewrite) if frewrite is not None else _register
+@register_annotate_function("nn.global_avg_pool2d")
+def global_avg_pool2d_rewrite(ref_call, new_args, ctx):
+ """Rewrite function for global_avg_pool2d. Quantization stops after this operator."""
+ if quantize_context().check_to_skip(ref_call):
+ return None
+ expr, x_kind = _get_expr_kind(new_args[0])
-@register_relay_node
-class QVTAExpr(_expr.TempExpr):
- def __init__(self, expr):
- self.__init_handle_by_constructor__(
- _quantize.make_vta_expr, expr)
-
- def realize(self):
- return _quantize.temp_expr_realize(self)
-
-
-def vta_expr_check(expr):
- if isinstance(expr, QVTAExpr):
- return True, expr.expr
- return False, expr
-
-
-@register_vta_rewrite("nn.conv2d")
-def conv2d_vta_rewrite(ref_call, new_args, ctx):
- """Rewrite function for conv2d for VTA target"""
- actx = annotate_context()
- if current_qconfig().skip_conv_layers is not None:
- skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
- if actx.conv2d_counter() in skipped_indices:
- actx.count_conv2d()
- return None
- actx.count_conv2d()
-
- data_cond, data = vta_expr_check(new_args[0])
- kernel_cond, kernel = vta_expr_check(new_args[1])
-
- assert not kernel_cond
- if data_cond:
- data = new_args[0].realize()
- ret = _forward_op(ref_call, [data, kernel])
- return QVTAExpr(ret)
-
-
-def identity_vta_rewrite(ref_call, new_args, ctx):
- cond, expr = vta_expr_check(new_args[0])
- if cond:
- return QVTAExpr(_forward_op(ref_call, [expr]))
- return None
-
-register_vta_rewrite("nn.relu", identity_vta_rewrite)
-register_vta_rewrite("nn.max_pool2d", identity_vta_rewrite)
-
-
-@register_vta_rewrite("add")
-def add_vta_rewrite(ref_call, new_args, ctx):
- """Rewrite function for ewise add for VTA target"""
- lhs_cond, lhs = vta_expr_check(new_args[0])
- rhs_cond, rhs = vta_expr_check(new_args[1])
- if lhs_cond and rhs_cond:
- lhs = new_args[0].realize()
- rhs = new_args[1].realize()
- return _forward_op(ref_call, [lhs, rhs])
- elif lhs_cond and not rhs_cond:
- return QVTAExpr(_forward_op(ref_call, [lhs, rhs]))
- return None
+ if x_kind is None:
+ return None
+ expr = _forward_op(ref_call, [new_args[0].realize()])
+
+ # stop quantize after global_avg_pool2d
+ quantize_context().stop_quantize()
+ return expr
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#pylint: disable=unused-argument,inconsistent-return-statements
+"""Internal module for registering attributes for partitioning."""
+from __future__ import absolute_import
+
+from ... import target as _target
+from .. import expr as _expr
+from .. import analysis as _analysis
+from ..base import register_relay_node
+from ..op import op as _reg
+from . import _quantize
+from .quantize import _forward_op
+
+def register_partition_function(op_name, frewrite=None, level=10):
+ def _register(func):
+ return _reg._Register(op_name, "FQPartitionRewrite", func, level)
+ return _register(frewrite) if frewrite is not None else _register
+
+
+@register_relay_node
+class QPartitionExpr(_expr.TempExpr):
+ def __init__(self, expr):
+ self.__init_handle_by_constructor__(
+ _quantize.make_partition_expr, expr)
+
+
+def partition_expr_check(expr):
+ if isinstance(expr, QPartitionExpr):
+ return True, expr.expr
+ return False, expr
+
+
+@register_partition_function("nn.conv2d")
+def conv2d_partition_function(ref_call, new_args, ctx):
+ """Rewrite function for conv2d for partition"""
+ data_cond, data = partition_expr_check(new_args[0])
+ kernel_cond, kernel = partition_expr_check(new_args[1])
+
+ assert not kernel_cond
+ if data_cond:
+ data = new_args[0].realize()
+ ret = _forward_op(ref_call, [data, kernel])
+ return QPartitionExpr(ret)
+
+
+def identity_partition_function(ref_call, new_args, ctx):
+ cond, expr = partition_expr_check(new_args[0])
+ if cond:
+ return QPartitionExpr(_forward_op(ref_call, [expr]))
+ return None
+
+register_partition_function("clip", identity_partition_function)
+register_partition_function("nn.relu", identity_partition_function)
+register_partition_function("nn.max_pool2d", identity_partition_function)
+
+
+def add_partition_generic(ref_call, new_args, ctx):
+ """Rewrite function for ewise add for partition for generic devices"""
+ lhs_cond, lhs = partition_expr_check(new_args[0])
+ rhs_cond, rhs = partition_expr_check(new_args[1])
+ if lhs_cond and rhs_cond:
+ # - introduced by the first residual connection in ResNet
+ # ...
+ # %0 = nn.conv2d(%data, %meta[relay.Constant])
+ # %1 = add(%0, %meta[relay.Constant])
+ # %2 = nn.relu(%1)
+ # %3 = nn.max_pool2d(%2)
+ # ...
+ # %9 = nn.conv2d(%8, %meta[relay.Constant])
+ # %10 = add(%9, %meta[relay.Constant])
+ # %11 = add(%3, %10) <- need to insert annotations for %3, %10
+ # ...
+ lhs = new_args[0].realize()
+ rhs = new_args[1].realize()
+ return _forward_op(ref_call, [lhs, rhs])
+ elif not lhs_cond and rhs_cond:
+ # - introduced by residual connection in ResNet
+ # ...
+ # %13 = nn.conv2d(%12, %meta[relay.Constant])
+ # %14 = add(%13, %meta[relay.Constant])
+ # %15 = annotation.cast_hint(%14, 'int8')
+ # %16 = annotation.stop_fusion(%15)
+ # %17 = add(%5, %16)
+ # %18 = nn.relu(%17)
+ # ...
+ # %24 = nn.conv2d(%23, %meta[relay.Constant])
+ # %25 = add(%24, %meta[relay.Constant])
+ # %26 = add(%18, %25) <- need to insert annotations for %25
+ # ...
+ rhs = new_args[1].realize()
+ return _forward_op(ref_call, [lhs, rhs])
+ elif lhs_cond and not rhs_cond:
+ if _analysis.check_constant(rhs):
+ # - introduced by batch_norm: add(out, bias)
+ return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
+ # - introduced by residual connection in MobileNetV2
+ # ...
+ # %81 = add(%80, meta[relay.Constant])
+ # %82 = annotation.cast_hint(%81, 'int8')
+ # %83 = annotation.stop_fusion(%82)
+ # %84 = add(%79, %83)
+ # ...
+ # %95 = nn.conv2d(%94, %meta[relay.Constant])
+ # %96 = add(%95, %meta[relay.Constant])
+ # %97 = add(%96, %84) <- need to insert annotations for %96
+ # ...
+ lhs = new_args[0].realize()
+ return _forward_op(ref_call, [lhs, rhs])
+ elif not lhs_cond and not rhs_cond:
+ # trivial case
+ return None
+ else:
+ raise ValueError
+
+
+# TODO(ziheng) enhance `register_partition_function` to dispatch
+# for target automatically
+@register_partition_function("add")
+def add_partition_function(ref_call, new_args, ctx):
+ """Rewrite function for ewise add for partition"""
+ if 'cuda' in _target.current_target().keys:
+ # TODO(wuwei/ziheng) CUDA-specific rules
+ return add_partition_generic(ref_call, new_args, ctx)
+ return add_partition_generic(ref_call, new_args, ctx)
+
+
+@register_partition_function("multiply")
+def multiply_partition_function(ref_call, new_args, ctx):
+ """Rewrite function for ewise multiply for partition"""
+ lhs_cond, lhs = partition_expr_check(new_args[0])
+ rhs_cond, rhs = partition_expr_check(new_args[1])
+ if lhs_cond:
+ # introduced by bn: multiply(out, scale)
+ return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
+ assert (not lhs_cond) and (not rhs_cond)
+ return None
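To illustrate the registration API above, a hedged sketch of registering one more identity-style partition function; the op choice here is hypothetical and not part of this patch:

# Hypothetical example: keep nn.avg_pool2d inside the current low-precision
# section, mirroring how clip / nn.relu / nn.max_pool2d are handled above.
@register_partition_function("nn.avg_pool2d")
def avg_pool2d_partition_function(ref_call, new_args, ctx):
    cond, expr = partition_expr_check(new_args[0])
    if cond:
        # stay inside the partition section started by a preceding conv2d
        return QPartitionExpr(_forward_op(ref_call, [expr]))
    return None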
return str_map[kind]
+def _forward_op(ref_call, args):
+ """forward the operator of ref_call with provided arguments"""
+ return _expr.Call(
+ ref_call.op, args, ref_call.attrs, ref_call.type_args)
+
+
@register_relay_node("relay.quantize.QConfig")
class QConfig(NodeBase):
"""Configure the quantization behavior by setting config variables.
"dtype_activation": "int32",
"global_scale": 8.0,
"skip_conv_layers": [0],
+ "do_simulation": False,
"round_for_shift": True,
- "store_lowbit_output": True,
"debug_enabled_ops": None,
}
self.handle = handle
def guard(self, ref_call):
+ """Return True if the op is enabled, otherwise return False."""
op_name = ref_call.op.name
if self.debug_enabled_ops is not None:
name_list = [x.value for x in self.debug_enabled_ops]
"""Get the current quantization configuration."""
return _quantize._GetCurrentQConfig()
-# TODO(tmoreau89, ZihengJiang) the skip parameters are
-# hacky - we should explore a more future-proof way to
-# skip operators based on pattern matching
+
def qconfig(**kwargs):
"""Configure the quantization behavior by setting config variables.
skip_conv_layers: list
Specifying which layers to be skipped. Provide a list of indices
- that indicate which conv2d layers to leave untouched.
+ that indicate which conv2d layers to leave untouched. Indices start from 0.
+
+ do_simulation: boolean
+ Whether to only simulate quantization in float, without realizing to low-precision operations.
round_for_shift: boolean
Whether to add bias for rounding during shift.
- store_lowbit_output: boolean
- Whether to store low-bit integer back as output before dequantizing.
- Some accelerators need this, e.g. VTA.
-
debug_enabled_ops: None or list of str
Partially quantize specified operators for debugging. The default value
is None, which means it will try to call all operators' annotate rewrite
return _make.node("relay.quantize.QConfig", **node_args)
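A short usage sketch of the updated options (illustrative; `qconfig` is used as a scope, as elsewhere in relay.quantize):

# Illustrative: enter a configuration scope with the new/renamed options and
# inspect the active configuration.
from tvm import relay

with relay.quantize.qconfig(skip_conv_layers=[0],   # leave the first conv2d in float
                            do_simulation=False,    # realize to low-precision ops
                            round_for_shift=True):
    print(relay.quantize.current_qconfig())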
-class AnnotateContext(object):
- """A global singleton annotate scope"""
+class QuantizeContext(object):
+ """An internally used global context object for annotation,
+ holding state variables such as `conv2d_counter`."""
Current = None
def __init__(self):
self.qnode_map = dict()
self._conv2d_counter = 0
+ self._stop_quantize = False
+
+ def check_to_skip(self, ref_call):
+ """Check whether to skip the current operator, based on the
+ stop-quantize flag and the index of the conv2d layer."""
+ if self._stop_quantize:
+ return True
+
+ if current_qconfig().skip_conv_layers is not None:
+ # check skip conv layers
+ skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+ if self._conv2d_counter in skipped_indices:
+ if ref_call.op.name == 'nn.conv2d':
+ self._conv2d_counter += 1
+ return True
+ if ref_call.op.name == 'nn.conv2d':
+ self._conv2d_counter += 1
+
+ return False
+
+ def stop_quantize(self):
+ self._stop_quantize = True
+
+ def reset(self):
+ self._conv2d_counter = 0
+ self._stop_quantize = False
def __enter__(self):
- self._conv2d_counter = 0
+ self.reset()
return self
- def conv2d_counter(self):
- """Get the counter for conv2d."""
- return self._conv2d_counter
-
- def count_conv2d(self):
- """Increase the value of the conv2d counter by one."""
- self._conv2d_counter += 1
-
def __exit__(self, ptype, value, traceback):
pass
-def annotate_context():
+def quantize_context():
"""Get the global singleton scope"""
- if AnnotateContext.Current is None:
- AnnotateContext.Current = AnnotateContext()
- return AnnotateContext.Current
+ if QuantizeContext.Current is None:
+ QuantizeContext.Current = QuantizeContext()
+ return QuantizeContext.Current
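To make the skipping rule concrete, a short illustrative walk-through of `check_to_skip` with `skip_conv_layers=[0]` (commentary only, not part of this patch):

# Only nn.conv2d calls advance the internal counter; any call seen while the
# counter equals a skipped index is left un-annotated.
#
#   call visited         counter before   skipped?   counter after
#   nn.conv2d (1st)            0             yes            1
#   add / nn.relu / ...        1             no             1
#   nn.conv2d (2nd)            1             no             2
#
# Once global_avg_pool2d triggers stop_quantize(), check_to_skip returns True
# for every subsequent call, so the tail of the network stays in float.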
+
+
+def partition():
+ """Partition the graph into small low-precision sections with `cast_hint` and
+ `stop_fusion`.
+
+ Returns
+ -------
+ ret: tvm.relay.Pass
+ The registered pass for quantization partitioning.
+ """
+ return _quantize.QuantizePartition()
+
+
+def annotate():
+ """Given a float32 graph, this pass will rewrite the graph and return
+ a graph which simulates the error introduced by the current quantization
+ scheme.
+
+ Returns
+ -------
+ ret: tvm.relay.Pass
+ The registered pass for quantization annotation.
+ """
+ return _quantize.QuantizeAnnotate()
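A hedged sketch of running just these two passes to look at the partitioned and annotated (simulated) graph before calibration and realization (illustrative; `mod` is assumed to be an existing relay Module, and the import paths mirror how the quantize package is used above):

# Illustrative: inspect the simulated graph produced by partition + annotate.
from tvm.relay import transform as _transform
from tvm.relay import quantize as qtz

seq = _transform.Sequential([qtz.partition(), qtz.annotate()])
with _transform.PassContext(opt_level=3, required_pass=["QuantizeAnnotate"]):
    with qtz.quantize_context():
        simulated_mod = seq(mod)   # `mod` assumed defined elsewhere
print(simulated_mod["main"])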
def collect_stats(graph):
const_params[nclip_max] = _make_const((valid_range - 1))
_analysis.post_order_visit(graph, visit_func)
- return _expr.bind(graph, const_params)
-
-
-def annotate():
- """Given a float32 graph, this pass will rewrite the graph and return
- a graph which simulates the error brought by the current quantization
- scheme.
-
- Returns
- -------
- ret: tvm.relay.Pass
- The registered pass for quantization annotation.
- """
- return _quantize.QuantizeAnnotate()
+ ret = _expr.bind(graph, const_params)
+ return ret
def realize():
return _quantize.QuantizeRealize()
-def rewrite_for_vta():
- """Performs rewriting for VTA target.
-
- Returns
- -------
- ret: tvm.relay.Pass
- The registered pass for VTA rewrite.
- """
- return _quantize.QuantizeRewriteForVTA()
-
-
def _bind_params(func, params):
"""Bind the params to the expression.
"""
return _expr.bind(func, bind_dict)
+def prerequisite_optimize(graph, params=None):
+ """Prerequisite optimization passes for quantization. Performs
+ "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
+ "CanonicalizeOps" optimizations before quantization."""
+ optimize = _transform.Sequential([_transform.SimplifyInference(),
+ _transform.FoldConstant(),
+ _transform.FoldScaleAxis(),
+ _transform.CanonicalizeOps(),
+ _transform.FoldConstant()])
+
+ if params:
+ graph = _bind_params(graph, params)
+
+ mod = _module.Module.from_expr(graph)
+ with _transform.PassContext(opt_level=3):
+ mod = optimize(mod)
+ return mod["main"]
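A one-line usage sketch (illustrative; `mod` and `params` are assumed to exist):

# Illustrative: run only the prerequisite passes to see the graph that the
# partition/annotate passes will operate on.
preprocessed = prerequisite_optimize(mod["main"], params)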
+
+
def quantize(graph, params=None, dataset=None):
""" The quantization procedure. Before running the three main
procedures of quantization, "annotate", "calibrate" and "realize"
ret: Function
The graph after quantization
"""
- if params:
- graph = _bind_params(graph, params)
+ graph = prerequisite_optimize(graph, params)
mod = _module.Module.from_expr(graph)
- # Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
- # "CanonicalizeOps" optimization before quantization.
- optimize = _transform.Sequential([_transform.SimplifyInference(),
- _transform.FoldConstant(),
- _transform.FoldScaleAxis(),
- _transform.CanonicalizeOps(),
- _transform.FoldConstant()])
-
calibrate_pass = _transform.function_pass(calibrate, opt_level=1,
name="QuantizeCalibrate")
- # Quantize pass list
- quant_passes = [annotate(),
- calibrate_pass,
- realize(),
- _transform.FoldConstant()]
- if current_qconfig().store_lowbit_output:
- quant_passes = [rewrite_for_vta()] + quant_passes
+ quant_passes = [partition(),
+ annotate(),
+ calibrate_pass]
+ if not current_qconfig().do_simulation:
+ quant_passes.append(realize())
+ quant_passes.append(_transform.FoldConstant())
quantize_seq = _transform.Sequential(quant_passes)
with _transform.PassContext(opt_level=3,
required_pass=["QuantizeAnnotate",
"QuantizeCalibrate",
"QuantizeRealize"]):
- mod = optimize(mod)
- mod = quantize_seq(mod)
+ with quantize_context():
+ mod = quantize_seq(mod)
return mod["main"]
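To tie the new pipeline together, a hedged end-to-end sketch (illustrative; `func` and `params` are assumed to come from an existing model):

# Illustrative end-to-end use of the reworked pipeline.
from tvm import relay

# Default: partition -> annotate -> calibrate -> realize -> fold constants.
with relay.quantize.qconfig(skip_conv_layers=[0], do_simulation=False):
    qfunc = relay.quantize.quantize(func, params)

# With do_simulation=True the realize step is skipped, so the returned graph
# stays in float32 but models the quantization error via simulated_quantize.
with relay.quantize.qconfig(do_simulation=True):
    simulated = relay.quantize.quantize(func, params)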
return {topi::identity(inputs[0])};
});
-Expr ForceCast(Expr data) {
- static const Op& op = Op::Get("annotation.force_cast");
- return CallNode::make(op, {data}, Attrs{}, {});
+// relay.annotation.cast_hint
+TVM_REGISTER_NODE_TYPE(CastHintAttrs);
+
+Expr CastHint(Expr data, DataType dtype) {
+ auto attrs = make_node<CastHintAttrs>();
+ attrs->dtype = dtype;
+ static const Op& op = Op::Get("annotation.cast_hint");
+ return CallNode::make(op, {data}, Attrs{attrs}, {});
}
-RELAY_REGISTER_OP("annotation.force_cast")
-.describe(R"code(Annotate an expression to force a cast.)code"
+RELAY_REGISTER_OP("annotation.cast_hint")
+.describe(R"code(Annotate an expression to be cast into a specific data type.)code"
TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input data.")
}
};
+bool ConstantCheck(const Expr& e) {
+ return ConstantChecker().Check(e);
+}
+
+TVM_REGISTER_API("relay._analysis.check_constant")
+.set_body_typed(ConstantCheck);
+
// TODO(tvm-team) consider combine dead-code with constant folder.
// or make a more powerful partial evaluator.
#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
+#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/attrs/reduce.h>
Expr StopFusion(Expr data);
-Expr ForceCast(Expr data);
+Expr CastHint(Expr data, DataType dtype);
} // namespace relay
} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ *
+ * \file annotate.cc
+ *
+ * \brief Annotating the graph with simulated quantize operators.
+ */
+
+#include <tvm/relay/transform.h>
+#include <tvm/relay/analysis.h>
+#include "./quantize.h"
+
+namespace tvm {
+namespace relay {
+namespace quantize {
+
+using namespace relay::transform;
+
+class QAnnotateExpr;
+class QAnnotateExprNode : public TempExprNode {
+ public:
+ Expr expr;
+ QAnnotateKind kind;
+
+ void VisitAttrs(tvm::AttrVisitor* v) final {
+ v->Visit("expr", &expr);
+ v->Visit("kind", &kind);
+ }
+
+ TVM_DLL static QAnnotateExpr make(Expr expr, QAnnotateKind kind);
+
+ Expr Realize() const final;
+
+ static constexpr const char* _type_key = "relay.QAnnotateExpr";
+ TVM_DECLARE_NODE_TYPE_INFO(QAnnotateExprNode, TempExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(QAnnotateExpr, QAnnotateExprNode, TempExpr);
+
+
+Expr QAnnotateExprNode::Realize() const {
+ return expr;
+}
+
+QAnnotateExpr QAnnotateExprNode::make(Expr expr, QAnnotateKind kind) {
+ auto rnode = make_node<QAnnotateExprNode>();
+ rnode->expr = expr;
+ rnode->kind = kind;
+ return QAnnotateExpr(rnode);
+}
+
+TVM_REGISTER_API("relay._quantize.make_annotate_expr")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+ *ret = QAnnotateExprNode::make(args[0],
+ static_cast<QAnnotateKind>(args[1].operator int()));
+ });
+
+
+Pass QuantizeAnnotate() {
+ // TODO(tvm-teams): since partition has added cast_hint in different
+ // branches, try to remove this in the future.
+ std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
+ if (e->derived_from<TempExprNode>()) {
+ const auto* n = e.as<QAnnotateExprNode>();
+ CHECK(n);
+ const PackedFunc* f =
+ runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
+ Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
+ return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
+ }
+ return e;
+ };
+
+ runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+ [=](Function f, Module m, PassContext pc) {
+ auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref));
+ auto new_params = func->params;
+ for (const auto& x : FreeVars(func)) {
+ new_params.push_back(x);
+ }
+ return FunctionNode::make(new_params,
+ func->body,
+ func->ret_type,
+ func->type_params,
+ func->attrs);
+ };
+ return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
+}
+
+TVM_REGISTER_API("relay._quantize.QuantizeAnnotate")
+.set_body_typed(QuantizeAnnotate);
+
+} // namespace quantize
+} // namespace relay
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ *
+ * \file partition.cc
+ *
+ * \brief Partition a graph into sections for quantization.
+ */
+
+#include <tvm/relay/transform.h>
+#include "../pattern_util.h"
+#include "./quantize.h"
+
+namespace tvm {
+namespace relay {
+namespace quantize {
+
+using namespace relay::transform;
+
+class QPartitionExpr;
+class QPartitionExprNode : public TempExprNode {
+ public:
+ /*! \brief The original expression */
+ Expr expr;
+
+ void VisitAttrs(tvm::AttrVisitor* v) final {
+ v->Visit("expr", &expr);
+ }
+
+ TVM_DLL static QPartitionExpr make(Expr expr);
+
+ Expr Realize() const final;
+
+ static constexpr const char* _type_key = "relay.QPartitionExpr";
+ TVM_DECLARE_NODE_TYPE_INFO(QPartitionExprNode, TempExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(QPartitionExpr, QPartitionExprNode, TempExpr);
+
+
+Expr QPartitionExprNode::Realize() const {
+ // insert cast hint and stop fusion
+ const QConfig& cfg = QConfig::Current();
+ Expr ret = CastHint(this->expr, cfg->dtype_input);
+ return StopFusion(ret);
+}
+
+QPartitionExpr QPartitionExprNode::make(Expr expr) {
+ auto rnode = make_node<QPartitionExprNode>();
+ rnode->expr = expr;
+ return QPartitionExpr(rnode);
+}
+
+TVM_REGISTER_API("relay._quantize.make_partition_expr")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+ *ret = QPartitionExprNode::make(args[0]);
+ });
+
+Pass QuantizePartition() {
+ runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+ [=](Function f, Module m, PassContext pc) {
+ auto ret = Downcast<Function>(
+ ForwardRewrite(f, "FQPartitionRewrite", nullptr, nullptr));
+ return ret;
+ };
+ return CreateFunctionPass(pass_func, 1, "QuantizePartition", {});
+}
+
+TVM_REGISTER_API("relay._quantize.QuantizePartition")
+.set_body_typed(QuantizePartition);
+
+} // namespace quantize
+} // namespace relay
+} // namespace tvm
* for compression and acceleration.
*/
#include <dmlc/thread_local.h>
-#include <tvm/base.h>
-#include <tvm/relay/analysis.h>
-#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
-#include <cmath>
-#include <string>
-#include <vector>
#include <stack>
-#include <utility>
-#include "../pattern_util.h"
#include "./quantize.h"
namespace relay {
namespace quantize {
-using namespace relay::transform;
-
TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs);
bool SimulatedQuantizeRel(const Array<Type>& types,
});
-// =============
-// annotate pass
-
-Expr QAnnotateExprNode::Realize() const {
- const auto& cfg = QConfig::Current();
- if (cfg->store_lowbit_output) {
- // store low bit output back for VTA
- const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
- return (*f)(this->expr, static_cast<int>(kQInput));
- } else {
- return expr;
- }
-}
-
-QAnnotateExpr QAnnotateExprNode::make(Expr expr, QAnnotateKind kind) {
- auto rnode = make_node<QAnnotateExprNode>();
- rnode->expr = expr;
- rnode->kind = kind;
- return QAnnotateExpr(rnode);
-}
-
-TVM_REGISTER_API("relay._quantize.make_annotate_expr")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = QAnnotateExprNode::make(args[0],
- static_cast<QAnnotateKind>(args[1].operator int()));
- });
-
-
-// =============
-// realize pass
-
-Expr QRealizeIntExprNode::Realize() const {
- const auto& cfg = QConfig::Current();
- Expr data = this->data;
- if (cfg->store_lowbit_output) {
- data = Cast(data, cfg->dtype_input);
- }
- // dequantize
- data = Cast(data, Float(32));
- data = Multiply(data, this->dom_scale);
- return data;
-}
-
-QRealizeIntExpr QRealizeIntExprNode::make(Expr data, Expr dom_scale, DataType dtype) {
- NodePtr<QRealizeIntExprNode> n = make_node<QRealizeIntExprNode>();
- n->data = std::move(data);
- n->dom_scale = std::move(dom_scale);
- n->dtype = std::move(dtype);
- return QRealizeIntExpr(n);
-}
-
-
-inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
- return CallNode::make(ref_call->op,
- args, ref_call->attrs, ref_call->type_args);
-}
-
-
-/* calculate `data * s1 / s2`, use shift if possible */
-inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
- // here we assume the dtype of data is dtype activation
- if (s1 == s2) return data;
-
- float factor = s1 / s2;
- float shift_factor = std::log2(factor);
- CHECK_GT(shift_factor, 0);
- if (static_cast<int>(shift_factor) == shift_factor) {
- return LeftShift(data, MakeConstantScalar(dtype,
- static_cast<int>(shift_factor)));
- } else if (static_cast<int>(factor) == factor) {
- return Multiply(data, MakeConstantScalar(dtype, factor));
- } else {
- data = Cast(data, Float(32));
- data = Multiply(data, MakeConstantScalar(Float(32), factor));
- return Cast(Round(data), dtype);
- }
-}
-
-Expr QuantizeRealize(const Call& ref_call,
- const Array<Expr>& new_args,
- const NodeRef& ctx) {
- const QConfig& cfg = QConfig::Current();
- // do not handle data type cast
- const auto param = ref_call->attrs.as<SimulatedQuantizeAttrs>();
- CHECK_EQ(param->rounding, "round");
-
- Expr dom_scale = new_args[1];
- Expr clip_min = new_args[2];
- Expr clip_max = new_args[3];
-
- float dom_scale_imm = GetScalarFromConstant<float>(dom_scale);
- float clip_min_imm = GetScalarFromConstant<float>(clip_min);
- float clip_max_imm = GetScalarFromConstant<float>(clip_max);
-
- // x * idom_scale = y * odom_scale
- // => y = x * idom_scale / odom_scale
- if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
- // int32->int8
- Expr data = n->data;
- float idom_scale_imm = GetScalarFromConstant<float>(n->dom_scale);
- float odom_scale_imm = GetScalarFromConstant<float>(dom_scale);
- if (idom_scale_imm == odom_scale_imm) {
- // same domain scale, only clip
- data = Clip(data, clip_min_imm, clip_max_imm);
- return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
- }
-
- float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm);
- CHECK_NE(shift_nbit, 0);
- if (static_cast<int>(shift_nbit) == shift_nbit) {
- if (shift_nbit > 0) {
- // use right shift
- if (cfg->round_for_shift) {
- float round_bias = std::pow(2.0, shift_nbit - 1);
- data = Add(data, MakeConstantScalar(cfg->dtype_activation,
- static_cast<int>(round_bias)));
- }
- data = RightShift(data, MakeConstantScalar(cfg->dtype_activation,
- static_cast<int>(shift_nbit)));
- } else {
- data = LeftShift(data, MakeConstantScalar(cfg->dtype_activation,
- static_cast<int>(shift_nbit)));
- }
- data = Clip(data, clip_min_imm, clip_max_imm);
- return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
- } else {
- // float computation
- data = Cast(data, Float(32));
- Expr scaled_data = Multiply(data, Divide(n->dom_scale, dom_scale));
- Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
- return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
- }
- }
-
- // quantize from real
- CHECK(!new_args[0]->derived_from<TempExprNode>());
- Expr data = new_args[0];
- Expr scaled_data = Multiply(data, MakeConstantScalar(Float(32), 1 / dom_scale_imm));
- Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
- return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
-}
-
-Expr FoldConstantOpt(const Expr& expr) {
- auto mod = ModuleNode::FromExpr(expr);
- mod = transform::FoldConstant()(mod);
- auto entry_func = mod->Lookup("main");
- return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
-}
-
-RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", QuantizeRealize);
-
-
-Expr Conv2dRealize(const Call& ref_call,
- const Array<Expr>& new_args,
- const NodeRef& ctx) {
- const QConfig& cfg = QConfig::Current();
- CHECK_EQ(new_args.size(), 2);
- if (!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>()) {
- return Expr(nullptr);
- }
- const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
- CHECK(lhs);
- const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
- CHECK(rhs);
-
- Expr ldata = lhs->data;
- if (lhs->dtype != cfg->dtype_input) {
- ldata = Cast(ldata, cfg->dtype_input);
- }
- Expr rdata = Cast(rhs->data, cfg->dtype_weight);
-
- const auto ref_attrs = ref_call->attrs.as<Conv2DAttrs>();
- auto attrs = make_node<Conv2DAttrs>();
- *attrs = *ref_attrs;
- DataType out_dtype = cfg->dtype_activation;
- attrs->out_dtype = out_dtype;
-
- Expr ret = CallNode::make(ref_call->op,
- {ldata, rdata}, Attrs(attrs), ref_call->type_args);
- Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
- Expr dom_scale = FoldConstantOpt(mul);
- return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
-}
-
-RELAY_REGISTER_OP("nn.conv2d")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", Conv2dRealize);
-
-
-Expr DenseRealize(const Call& ref_call,
- const Array<Expr>& new_args,
- const NodeRef& ctx) {
- const QConfig& cfg = QConfig::Current();
- CHECK_EQ(new_args.size(), 2);
- if (!new_args[0]->derived_from<TempExprNode>() || !new_args[1]->derived_from<TempExprNode>()) {
- return Expr(nullptr);
- }
- const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
- const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
-
- Expr ldata = lhs->data;
- if (lhs->dtype != cfg->dtype_input) {
- ldata = Cast(ldata, cfg->dtype_input);
- }
- Expr rdata = Cast(rhs->data, cfg->dtype_weight);
-
- const auto ref_attrs = ref_call->attrs.as<DenseAttrs>();
- auto attrs = make_node<DenseAttrs>();
- *attrs = *ref_attrs;
- DataType out_dtype = cfg->dtype_activation;
- attrs->out_dtype = out_dtype;
-
- Expr ret = CallNode::make(ref_call->op,
- {ldata, rdata}, Attrs(attrs), ref_call->type_args);
- Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
- Expr dom_scale = FoldConstantOpt(mul);
- return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
-}
-
-RELAY_REGISTER_OP("nn.dense")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", DenseRealize);
-
-
-Expr MulRealize(const Call& ref_call,
- const Array<Expr>& new_args,
- const NodeRef& ctx) {
- const QConfig& cfg = QConfig::Current();
- CHECK_EQ(new_args.size(), 2);
- if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
- // execute the operation with activation data type.
- const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
- const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
- Expr ldata = lhs->data;
- Expr rdata = rhs->data;
-
- DataType dtype = cfg->dtype_activation;
- if (lhs->dtype != dtype) {
- ldata = Cast(ldata, dtype);
- }
- if (rhs->dtype != dtype) {
- rdata = Cast(rdata, dtype);
- }
-
- Expr ret = ForwardOp(ref_call, {ldata, rdata});
- Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
- Expr dom_scale = FoldConstantOpt(mul);
- return QRealizeIntExprNode::make(ret, dom_scale, dtype);
- }
- CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
- return Expr(nullptr);
-}
-
-RELAY_REGISTER_OP("multiply")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", MulRealize);
-
-
-float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
- if (nptrs.size() == 2) {
- // x = a * s1, y = b * s2
- // x + y = (a * s1 / s2 + b) * s2, if s1 > s2
- // = (a + b * s2 / s1) * s1, if s2 > s1
- float s1 = GetScalarFromConstant<float>(nptrs[0]->dom_scale);
- float s2 = GetScalarFromConstant<float>(nptrs[1]->dom_scale);
- return s1 > s2 ? s2 : s1;
- } else {
- const QConfig& cfg = QConfig::Current();
- float scale = cfg->global_scale;
- return scale / std::pow(2.0, cfg->nbit_activation - 1);
- }
-}
-
-
-/* \brief Unify the dom scale of arguments */
-Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args,
- DataType* dtype_ptr, Expr* scale_ptr) {
- static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
- const QConfig& cfg = QConfig::Current();
-
- std::vector<const QRealizeIntExprNode*> nptrs;
- Array<Expr> ret;
- for (auto arg : args) {
- const auto* nptr = arg.as<QRealizeIntExprNode>();
- CHECK(nptr);
- nptrs.push_back(nptr);
- ret.push_back(nptr->data);
- }
-
- // unify the data type
- CHECK_EQ(ref_args.size(), args.size());
- DataType dtype;
- if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) {
- dtype = cfg->dtype_input;
- } else {
- dtype = cfg->dtype_activation;
- }
- for (size_t i = 0; i < ret.size(); ++i) {
- auto ref_arg = ref_args[i].as<CallNode>();
- if (nptrs[i]->dtype != dtype) {
- ret.Set(i, Cast(ret[i], dtype));
- } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
- ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
- auto new_arg = Cast(ret[i], cfg->dtype_input);
- new_arg = StopFusion(new_arg);
- ret.Set(i, Cast(new_arg, dtype));
- }
- }
-
- // unify the dom_scale
- float s = ChooseDomScale(nptrs);
- Expr dom_scale = MakeConstantScalar(Float(32), s);
- for (size_t i = 0; i < ret.size(); ++i) {
- float cur_s = GetScalarFromConstant<float>(nptrs[i]->dom_scale);
- ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype));
- }
-
- *dtype_ptr = dtype;
- *scale_ptr = dom_scale;
- return ret;
-}
-
-Expr AddRealize(const Call& ref_call,
- const Array<Expr>& new_args,
- const NodeRef& ctx) {
- CHECK_EQ(new_args.size(), 2);
- if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
- DataType dtype;
- Expr dom_scale;
- Array<Expr> ret_args = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale);
- Expr ret = ForwardOp(ref_call, ret_args);
- return QRealizeIntExprNode::make(ret, dom_scale, dtype);
- }
-
- CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
- return Expr(nullptr);
-}
-
-RELAY_REGISTER_OP("add")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", AddRealize);
-
-Expr ClipRealize(const Call& ref_call,
- const Array<Expr>& new_args,
- const NodeRef& ctx) {
- CHECK_EQ(new_args.size(), 1);
- if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
- const auto ref_attrs = ref_call->attrs.as<ClipAttrs>();
- auto attrs = make_node<ClipAttrs>();
- double dom_scale = GetScalarFromConstant<float>(n->dom_scale);
- attrs->a_min = ref_attrs->a_min / dom_scale;
- attrs->a_max = ref_attrs->a_max / dom_scale;
-
- Expr ret = CallNode::make(ref_call->op,
- {n->data}, Attrs(attrs), ref_call->type_args);
- return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype);
- }
- CHECK(!new_args[0]->derived_from<TempExprNode>());
- return Expr(nullptr);
-}
-
-RELAY_REGISTER_OP("clip")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", ClipRealize);
-
-
-Expr ConcatenateRealize(const Call& ref_call,
- const Array<Expr>& new_args,
- const NodeRef& ctx) {
- CHECK_EQ(new_args.size(), 1);
- CHECK_EQ(ref_call->args.size(), 1);
-
- const auto* tuple = new_args[0].as<TupleNode>();
- const auto* ref_tuple = ref_call->args[0].as<TupleNode>();
- CHECK(tuple);
- CHECK(ref_tuple);
- const Array<Expr>& arr = tuple->fields;
- const Array<Expr>& ref_arr = ref_tuple->fields;
-
- if (arr[0].as<QRealizeIntExprNode>()) {
- DataType dtype;
- Expr dom_scale;
- Array<Expr> ret_args = UnifyDTypeScale(ref_arr, arr, &dtype, &dom_scale);
- Expr ret = ForwardOp(ref_call, {TupleNode::make(ret_args)});
- return QRealizeIntExprNode::make(ret, dom_scale, dtype);
- } else {
- for (auto arg : new_args) {
- CHECK(!arg->derived_from<TempExprNode>());
- }
- return Expr(nullptr);
- }
-}
-
-RELAY_REGISTER_OP("concatenate")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", ConcatenateRealize);
-
-
-/* \brief forward the original operator */
-Expr IdentityRealize(const Call& ref_call,
- const Array<Expr>& new_args,
- const NodeRef& ctx) {
- CHECK_EQ(new_args.size(), 1);
- if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
- Expr ret = ForwardOp(ref_call, {n->data});
- return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype);
- }
- CHECK(!new_args[0]->derived_from<TempExprNode>());
- return Expr(nullptr);
-}
-
-RELAY_REGISTER_OP("nn.relu")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
-
-RELAY_REGISTER_OP("strided_slice")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
-
-RELAY_REGISTER_OP("annotation.stop_fusion")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
-
-/* \brief for unary operators which requantize its input to dtype_nbit */
-Expr CastDtypeInputRealize(const Call& ref_call,
- const Array<Expr>& new_args,
- const NodeRef& ctx) {
- const QConfig& cfg = QConfig::Current();
- CHECK_EQ(new_args.size(), 1);
- if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
- Expr data = Cast(n->data, cfg->dtype_input);
- Expr ret = ForwardOp(ref_call, {data});
- return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_input);
- }
- CHECK(!new_args[0]->derived_from<TempExprNode>());
- return Expr(nullptr);
-}
-
-RELAY_REGISTER_OP("nn.max_pool2d")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);
-
-
-Expr AvgPoolRealize(const Call& ref_call,
- const Array<Expr>& new_args,
- const NodeRef& ctx) {
- const QConfig& cfg = QConfig::Current();
- CHECK_EQ(new_args.size(), 1);
- if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
- Expr data = n->data;
- if (n->dtype != cfg->dtype_activation) {
- data = Cast(n->data, cfg->dtype_activation);
- }
- Expr ret = ForwardOp(ref_call, {data});
- return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_activation);
- }
- CHECK(!new_args[0]->derived_from<TempExprNode>());
- return Expr(nullptr);
-}
-
-RELAY_REGISTER_OP("nn.avg_pool2d")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
-
-Expr ForceCastRealize(const Call& ref_call,
- const Array<Expr>& new_args,
- const NodeRef& ctx) {
- const QConfig& cfg = QConfig::Current();
- CHECK_EQ(new_args.size(), 1);
- if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
- Expr ret = Cast(n->data, cfg->dtype_input);
- return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_input);
- }
- CHECK(!new_args[0]->derived_from<TempExprNode>());
- return Expr(nullptr);
-}
-
-RELAY_REGISTER_OP("annotation.force_cast")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", ForceCastRealize);
-
-TVM_REGISTER_API("relay._quantize.realize")
-.set_body_typed<Expr(Expr)>([](const Expr& e) {
- Expr ret = ForwardRewrite(e, "FQRealizeRewrite", nullptr, nullptr);
- return ret;
-});
-
-
-// =============
-// qconfig
-
-QConfig qconfig() {
- return QConfig(make_node<QConfigNode>());
-}
-
/*! \brief Entry to hold the BuildConfig context stack. */
struct TVMQConfigThreadLocalEntry {
/*! \brief The default build config if the stack is empty */
std::stack<QConfig> context_stack;
TVMQConfigThreadLocalEntry() :
- default_config(qconfig()) {
+ default_config(make_node<QConfigNode>()) {
}
};
p->stream << "nbit_activation=" << op->nbit_activation << ", ";
p->stream << "global_scale=" << op->global_scale << ", ";
p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", ";
+ p->stream << "do_simulation==" << op->do_simulation << ", ";
p->stream << "round_for_shift==" << op->round_for_shift << ", ";
- p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", ";
p->stream << "debug_enabled_ops==" << op->debug_enabled_ops;
p->stream << ")";
});
TVM_REGISTER_API("relay._quantize._ExitQConfigScope")
.set_body_typed(QConfig::ExitQConfigScope);
-Pass QuantizeAnnotate() {
- std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
- if (e->derived_from<TempExprNode>()) {
- const auto* n = e.as<QAnnotateExprNode>();
- CHECK(n);
- const PackedFunc* f =
- runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
- Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
- return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
- }
- return e;
- };
-
- runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
- [=](Function f, Module m, PassContext pc) {
- auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref));
- auto new_params = func->params;
- for (const auto& x : FreeVars(func)) {
- new_params.push_back(x);
- }
- return FunctionNode::make(new_params,
- func->body,
- func->ret_type,
- func->type_params,
- func->attrs);
- };
- return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
-}
-
-TVM_REGISTER_API("relay._quantize.QuantizeAnnotate")
-.set_body_typed(QuantizeAnnotate);
-
-Pass QuantizeRealizePass() {
- runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
- [=](Function f, Module m, PassContext pc) {
- return Downcast<Function>(
- ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr));
- };
- return CreateFunctionPass(pass_func, 1, "QuantizeRealize", {});
-}
-
-TVM_REGISTER_API("relay._quantize.QuantizeRealize")
-.set_body_typed(QuantizeRealizePass);
-
-Pass QuantizeRewriteForVTAPass() {
- runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
- [=](Function f, Module m, PassContext pc) {
- return Downcast<Function>(
- ForwardRewrite(f, "FQVTARewrite", nullptr, nullptr));
- };
- return CreateFunctionPass(pass_func, 1, "QuantizeRewriteForVTA", {});
-}
-
-TVM_REGISTER_API("relay._quantize.QuantizeRewriteForVTA")
-.set_body_typed(QuantizeRewriteForVTAPass);
-
-// =============
-// Insert stop_fusion for vta.
-
-
-Expr QVTAExprNode::Realize() const {
- Expr ret = ForceCast(this->expr);
- return StopFusion(ret);
-}
-
-QVTAExpr QVTAExprNode::make(Expr expr) {
- auto rnode = make_node<QVTAExprNode>();
- rnode->expr = expr;
- return QVTAExpr(rnode);
-}
-
-TVM_REGISTER_API("relay._quantize.make_vta_expr")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = QVTAExprNode::make(args[0]);
- });
-
-TVM_REGISTER_API("relay._quantize.make_stop_fusion")
-.set_body_typed<Expr(Expr)>([] (const Expr& expr) {
- return StopFusion(expr);
-});
-
-TVM_REGISTER_API("relay._quantize.temp_expr_realize")
-.set_body_typed<Expr(Expr)>([] (const Expr& expr) {
- const QVTAExprNode* n = expr.as<QVTAExprNode>();
- CHECK(n);
- return n->Realize();
-});
-
-
} // namespace quantize
} // namespace relay
} // namespace tvm
}
};
-/*!
- * \brief TempExpr used during annotate forward rewrite.
- */
-class QAnnotateExpr;
-/*!
- * \brief TempExprNode used during annotate forward rewrite.
- */
-class QAnnotateExprNode : public TempExprNode {
- public:
- /*! \brief The original expression */
- Expr expr;
- /*! \brief The kind of annotate field */
- QAnnotateKind kind;
-
- void VisitAttrs(tvm::AttrVisitor* v) final {
- v->Visit("expr", &expr);
- v->Visit("kind", &kind);
- }
-
- TVM_DLL static QAnnotateExpr make(Expr expr, QAnnotateKind kind);
-
- Expr Realize() const final;
-
- static constexpr const char* _type_key = "relay.QAnnotateExpr";
- TVM_DECLARE_NODE_TYPE_INFO(QAnnotateExprNode, TempExprNode);
-};
-
-RELAY_DEFINE_NODE_REF(QAnnotateExpr, QAnnotateExprNode, TempExpr);
-
-
-/*!
- * \brief TempExpr used to insert `force_cast` for VTA.
- */
-class QVTAExpr;
-/*!
- * \brief TempExprNode used to insert `force_cast` for VTA.
- */
-class QVTAExprNode : public TempExprNode {
- public:
- /*! \brief The original expression */
- Expr expr;
-
- void VisitAttrs(tvm::AttrVisitor* v) final {
- v->Visit("expr", &expr);
- }
-
- TVM_DLL static QVTAExpr make(Expr expr);
-
- Expr Realize() const final;
-
- static constexpr const char* _type_key = "relay.QVTAExpr";
- TVM_DECLARE_NODE_TYPE_INFO(QVTAExprNode, TempExprNode);
-};
-
-RELAY_DEFINE_NODE_REF(QVTAExpr, QVTAExprNode, TempExpr);
-
-
-/*! \brief TempExpr used during realize forward rewrite. */
-class QRealizeExpr;
-/*! \brief TempExpr representing integer. */
-class QRealizeIntExpr;
-
-class QRealizeExprNode : public TempExprNode {
- public:
- /*! \brief The original expression */
- Expr data;
- static constexpr const char* _type_key = "relay.quantize.QRealizeExpr";
- TVM_DECLARE_BASE_NODE_INFO(QRealizeExprNode, TempExprNode);
-};
-
-RELAY_DEFINE_NODE_REF(QRealizeExpr, QRealizeExprNode, TempExpr);
-
-
-class QRealizeIntExprNode : public QRealizeExprNode {
- public:
- Expr dom_scale;
- /*! \brief current data type */
- DataType dtype;
-
- void VisitAttrs(tvm::AttrVisitor* v) final {
- v->Visit("data", &data);
- v->Visit("dom_scale", &dom_scale);
- v->Visit("dtype", &dtype);
- }
-
- Expr Realize() const final;
-
- TVM_DLL static QRealizeIntExpr make(Expr data, Expr dom_scale, DataType dtype);
-
- static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr";
- TVM_DECLARE_NODE_TYPE_INFO(QRealizeIntExprNode, QRealizeExprNode);
-};
-
-RELAY_DEFINE_NODE_REF(QRealizeIntExpr, QRealizeIntExprNode, QRealizeExpr);
-
class QConfig;
-
/*!
* \brief Container for build configuration options
*/
DataType dtype_activation = Int(32);
double global_scale = 8.0;
Array<Expr> skip_conv_layers = Array<Expr>(NodePtr<Node>(nullptr));
+ bool do_simulation = false;
bool round_for_shift = true;
- bool store_lowbit_output = true;
Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype_activation", &dtype_activation);
v->Visit("global_scale", &global_scale);
v->Visit("skip_conv_layers", &skip_conv_layers);
+ v->Visit("do_simulation", &do_simulation);
v->Visit("round_for_shift", &round_for_shift);
- v->Visit("store_lowbit_output", &store_lowbit_output);
v->Visit("debug_enabled_ops", &debug_enabled_ops);
}
}
};
-/*!
-* \brief Construct a BuildConfig containing a new BuildConfigNode
-* \return The new BuildConfig
-*/
-TVM_DLL QConfig qconfig();
-
} // namespace quantize
} // namespace relay
} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ *
+ * \file realize.cc
+ *
+ * \brief Realizing the simulated graph into a real low-precision
+ * graph.
+ */
+
+#include <tvm/relay/transform.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/annotation.h>
+#include "./quantize.h"
+#include "../pattern_util.h"
+
+namespace tvm {
+namespace relay {
+namespace quantize {
+
+using namespace relay::transform;
+
+class QRealizeExpr;
+class QRealizeIntExpr;
+
+class QRealizeExprNode : public TempExprNode {
+ public:
+ Expr data;
+ static constexpr const char* _type_key = "relay.quantize.QRealizeExpr";
+ TVM_DECLARE_BASE_NODE_INFO(QRealizeExprNode, TempExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(QRealizeExpr, QRealizeExprNode, TempExpr);
+
+
+class QRealizeIntExprNode : public QRealizeExprNode {
+ public:
+ Expr dom_scale;
+ DataType dtype;
+
+ void VisitAttrs(tvm::AttrVisitor* v) final {
+ v->Visit("data", &data);
+ v->Visit("dom_scale", &dom_scale);
+ v->Visit("dtype", &dtype);
+ }
+
+ Expr Realize() const final;
+
+ TVM_DLL static QRealizeIntExpr make(Expr data, Expr dom_scale, DataType dtype);
+
+ static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr";
+ TVM_DECLARE_NODE_TYPE_INFO(QRealizeIntExprNode, QRealizeExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(QRealizeIntExpr, QRealizeIntExprNode, QRealizeExpr);
+
+
+Expr QRealizeIntExprNode::Realize() const {
+ Expr data = this->data;
+ // dequantize
+ data = Cast(data, Float(32));
+ data = Multiply(data, this->dom_scale);
+ return data;
+}
+
+QRealizeIntExpr QRealizeIntExprNode::make(Expr data, Expr dom_scale, DataType dtype) {
+ NodePtr<QRealizeIntExprNode> n = make_node<QRealizeIntExprNode>();
+ n->data = std::move(data);
+ n->dom_scale = std::move(dom_scale);
+ n->dtype = std::move(dtype);
+ return QRealizeIntExpr(n);
+}
+
+
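+/* \brief Re-emit ref_call with the rewritten arguments, keeping its op, attrs and type_args */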
+inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
+ return CallNode::make(ref_call->op,
+ args, ref_call->attrs, ref_call->type_args);
+}
+
+
+/* \brief Calculate `data * s1 / s2`, using a left shift when the ratio is a power of two */
+inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
+  // here we assume the dtype of data is the activation dtype
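+  // e.g. s1 = 16 and s2 = 4 give factor = 4 = 2^2, so a left shift by 2 is emitted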
+ if (s1 == s2) return data;
+
+ float factor = s1 / s2;
+ float shift_factor = std::log2(factor);
+ CHECK_GT(shift_factor, 0);
+ if (static_cast<int>(shift_factor) == shift_factor) {
+ return LeftShift(data, MakeConstantScalar(dtype,
+ static_cast<int>(shift_factor)));
+ } else if (static_cast<int>(factor) == factor) {
+ return Multiply(data, MakeConstantScalar(dtype, factor));
+ } else {
+ LOG(FATAL) << "fall back to float computation";
+ data = Cast(data, Float(32));
+ data = Multiply(data, MakeConstantScalar(Float(32), factor));
+ return Cast(Round(data), dtype);
+ }
+}
+
+Expr QuantizeRealize(const Call& ref_call,
+ const Array<Expr>& new_args,
+ const NodeRef& ctx) {
+ const QConfig& cfg = QConfig::Current();
+ // do not handle data type cast
+ const auto param = ref_call->attrs.as<SimulatedQuantizeAttrs>();
+ CHECK_EQ(param->rounding, "round");
+
+ Expr dom_scale = new_args[1];
+ Expr clip_min = new_args[2];
+ Expr clip_max = new_args[3];
+
+ float dom_scale_imm = GetScalarFromConstant<float>(dom_scale);
+ float clip_min_imm = GetScalarFromConstant<float>(clip_min);
+ float clip_max_imm = GetScalarFromConstant<float>(clip_max);
+
+ // x * idom_scale = y * odom_scale
+ // => y = x * idom_scale / odom_scale
+ if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
+ // int32->int8
+ Expr data = n->data;
+ float idom_scale_imm = GetScalarFromConstant<float>(n->dom_scale);
+ float odom_scale_imm = GetScalarFromConstant<float>(dom_scale);
+ if (idom_scale_imm == odom_scale_imm) {
+ // same domain scale, only clip
+ data = Clip(data, clip_min_imm, clip_max_imm);
+ return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
+ }
+
+ float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm);
+ CHECK_GT(shift_nbit, 0);
+ if (static_cast<int>(shift_nbit) == shift_nbit) {
+ // use right shift
+ if (cfg->round_for_shift) {
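+        // add half of the shift divisor (2^(shift_nbit - 1)) so the arithmetic
+        // right shift rounds to nearest instead of flooring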
+ float round_bias = std::pow(2.0, shift_nbit - 1);
+ data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast<int>(round_bias)));
+ }
+ data = RightShift(data, MakeConstantScalar(cfg->dtype_activation,
+ static_cast<int>(shift_nbit)));
+ data = Clip(data, clip_min_imm, clip_max_imm);
+ return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
+ } else {
+ // float computation
+ data = Cast(data, Float(32));
+ Expr scaled_data = Multiply(data, Divide(n->dom_scale, dom_scale));
+ Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
+ return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
+ }
+ }
+
+ // quantize from real
+ CHECK(!new_args[0]->derived_from<TempExprNode>());
+ Expr data = new_args[0];
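+  // scale the float input by 1 / dom_scale, then round and clip it into the
+  // quantized integer range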
+ Expr scaled_data = Multiply(data, MakeConstantScalar(Float(32), 1 / dom_scale_imm));
+ Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
+ return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
+}
+
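+/* \brief Fold constants in expr, e.g. to collapse a product of constant dom_scales
+ * into a single constant scalar. */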
+Expr FoldConstantOpt(const Expr& expr) {
+ auto mod = ModuleNode::FromExpr(expr);
+ mod = transform::FoldConstant()(mod);
+ auto entry_func = mod->Lookup("main");
+ return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
+}
+
+RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", QuantizeRealize);
+
+
+Expr Conv2dRealize(const Call& ref_call,
+ const Array<Expr>& new_args,
+ const NodeRef& ctx) {
+ const QConfig& cfg = QConfig::Current();
+ CHECK_EQ(new_args.size(), 2);
+ if (!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>()) {
+ return Expr(nullptr);
+ }
+ const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
+ CHECK(lhs);
+ const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
+ CHECK(rhs);
+
+ Expr ldata = lhs->data;
+ if (lhs->dtype != cfg->dtype_input) {
+ ldata = Cast(ldata, cfg->dtype_input);
+ }
+ Expr rdata = Cast(rhs->data, cfg->dtype_weight);
+
+ const auto ref_attrs = ref_call->attrs.as<Conv2DAttrs>();
+ auto attrs = make_node<Conv2DAttrs>();
+ *attrs = *ref_attrs;
+ DataType out_dtype = cfg->dtype_activation;
+ attrs->out_dtype = out_dtype;
+
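+  // re-emit conv2d on the quantized operands; the output dom_scale is the
+  // product of the input and weight dom_scales, folded to a constant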
+ Expr ret = CallNode::make(ref_call->op,
+ {ldata, rdata}, Attrs(attrs), ref_call->type_args);
+ Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
+ Expr dom_scale = FoldConstantOpt(mul);
+ return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
+}
+
+RELAY_REGISTER_OP("nn.conv2d")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", Conv2dRealize);
+
+
+Expr DenseRealize(const Call& ref_call,
+ const Array<Expr>& new_args,
+ const NodeRef& ctx) {
+ const QConfig& cfg = QConfig::Current();
+ CHECK_EQ(new_args.size(), 2);
+ if (!new_args[0]->derived_from<TempExprNode>() || !new_args[1]->derived_from<TempExprNode>()) {
+ return Expr(nullptr);
+ }
+ const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
+ const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
+
+ Expr ldata = lhs->data;
+ if (lhs->dtype != cfg->dtype_input) {
+ ldata = Cast(ldata, cfg->dtype_input);
+ }
+ Expr rdata = Cast(rhs->data, cfg->dtype_weight);
+
+ const auto ref_attrs = ref_call->attrs.as<DenseAttrs>();
+ auto attrs = make_node<DenseAttrs>();
+ *attrs = *ref_attrs;
+ DataType out_dtype = cfg->dtype_activation;
+ attrs->out_dtype = out_dtype;
+
+ Expr ret = CallNode::make(ref_call->op,
+ {ldata, rdata}, Attrs(attrs), ref_call->type_args);
+ Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
+ Expr dom_scale = FoldConstantOpt(mul);
+ return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
+}
+
+RELAY_REGISTER_OP("nn.dense")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", DenseRealize);
+
+
+Expr MulRealize(const Call& ref_call,
+ const Array<Expr>& new_args,
+ const NodeRef& ctx) {
+ const QConfig& cfg = QConfig::Current();
+ CHECK_EQ(new_args.size(), 2);
+ if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
+ // execute the operation with activation data type.
+ const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
+ const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
+ Expr ldata = lhs->data;
+ Expr rdata = rhs->data;
+
+ DataType dtype = cfg->dtype_activation;
+ if (lhs->dtype != dtype) {
+ ldata = Cast(ldata, dtype);
+ } else {
+ CHECK_EQ(lhs->dtype, dtype);
+ }
+ if (rhs->dtype != dtype) {
+ rdata = Cast(rdata, dtype);
+ } else {
+ CHECK_EQ(rhs->dtype, dtype);
+ }
+
+ Expr ret = ForwardOp(ref_call, {ldata, rdata});
+ Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
+ Expr dom_scale = FoldConstantOpt(mul);
+ return QRealizeIntExprNode::make(ret, dom_scale, dtype);
+ }
+ CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
+ return Expr(nullptr);
+}
+
+RELAY_REGISTER_OP("multiply")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", MulRealize);
+
+
+float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
+ if (nptrs.size() == 2) {
+ // x = a * s1, y = b * s2
+ // x + y = (a * s1 / s2 + b) * s2, if s1 > s2
+ // = (a + b * s2 / s1) * s1, if s2 > s1
+ float s1 = GetScalarFromConstant<float>(nptrs[0]->dom_scale);
+ float s2 = GetScalarFromConstant<float>(nptrs[1]->dom_scale);
+ return s1 > s2 ? s2 : s1;
+ } else {
+ const QConfig& cfg = QConfig::Current();
+ float scale = cfg->global_scale;
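+    // fall back to the global scale: dom_scale = global_scale / 2^(nbit_activation - 1),
+    // so values in [-global_scale, global_scale] map onto the signed activation range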
+ return scale / std::pow(2.0, cfg->nbit_activation - 1);
+ }
+}
+
+
+/* \brief Unify the dtype and dom_scale of the arguments */
+Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args,
+ DataType* dtype_ptr, Expr* scale_ptr) {
+ static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
+ const QConfig& cfg = QConfig::Current();
+
+ std::vector<const QRealizeIntExprNode*> nptrs;
+ Array<Expr> ret;
+ for (auto arg : args) {
+ const auto* nptr = arg.as<QRealizeIntExprNode>();
+ CHECK(nptr);
+ nptrs.push_back(nptr);
+ ret.push_back(nptr->data);
+ }
+
+ // unify the data type
+ CHECK_EQ(ref_args.size(), args.size());
+ DataType dtype;
+
+ if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) {
+ dtype = cfg->dtype_input;
+ } else {
+ dtype = cfg->dtype_activation;
+ }
+ for (size_t i = 0; i < ret.size(); ++i) {
+ auto ref_arg = ref_args[i].as<CallNode>();
+ if (nptrs[i]->dtype != dtype) {
+ ret.Set(i, Cast(ret[i], dtype));
+ } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
+ ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
+ auto new_arg = Cast(ret[i], cfg->dtype_input);
+ new_arg = StopFusion(new_arg);
+ ret.Set(i, Cast(new_arg, dtype));
+ }
+ }
+
+ // unify the dom_scale
+ float s = ChooseDomScale(nptrs);
+ Expr dom_scale = MakeConstantScalar(Float(32), s);
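+  // rescale each argument from its own dom_scale to the unified scale s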
+ for (size_t i = 0; i < ret.size(); ++i) {
+ float cur_s = GetScalarFromConstant<float>(nptrs[i]->dom_scale);
+ ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype));
+ }
+
+ *dtype_ptr = dtype;
+ *scale_ptr = dom_scale;
+ return ret;
+}
+
+Expr AddRealize(const Call& ref_call,
+ const Array<Expr>& new_args,
+ const NodeRef& ctx) {
+ CHECK_EQ(new_args.size(), 2);
+ if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
+ DataType dtype;
+ Expr dom_scale;
+ Array<Expr> ret_args = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale);
+ Expr ret = ForwardOp(ref_call, ret_args);
+ return QRealizeIntExprNode::make(ret, dom_scale, dtype);
+ }
+
+ CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
+ return Expr(nullptr);
+}
+
+RELAY_REGISTER_OP("add")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", AddRealize);
+
+Expr ClipRealize(const Call& ref_call,
+ const Array<Expr>& new_args,
+ const NodeRef& ctx) {
+ CHECK_EQ(new_args.size(), 1);
+ if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
+ const auto ref_attrs = ref_call->attrs.as<ClipAttrs>();
+ auto attrs = make_node<ClipAttrs>();
+ double dom_scale = GetScalarFromConstant<float>(n->dom_scale);
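+    // rescale the clip bounds into the quantized domain (real value = data * dom_scale)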
+ attrs->a_min = ref_attrs->a_min / dom_scale;
+ attrs->a_max = ref_attrs->a_max / dom_scale;
+
+ Expr ret = CallNode::make(ref_call->op,
+ {n->data}, Attrs(attrs), ref_call->type_args);
+ return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype);
+ }
+ CHECK(!new_args[0]->derived_from<TempExprNode>());
+ return Expr(nullptr);
+}
+
+RELAY_REGISTER_OP("clip")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", ClipRealize);
+
+
+Expr ConcatenateRealize(const Call& ref_call,
+ const Array<Expr>& new_args,
+ const NodeRef& ctx) {
+ CHECK_EQ(new_args.size(), 1);
+ CHECK_EQ(ref_call->args.size(), 1);
+
+ const auto* tuple = new_args[0].as<TupleNode>();
+ const auto* ref_tuple = ref_call->args[0].as<TupleNode>();
+ CHECK(tuple);
+ CHECK(ref_tuple);
+ const Array<Expr>& arr = tuple->fields;
+ const Array<Expr>& ref_arr = ref_tuple->fields;
+
+ if (arr[0].as<QRealizeIntExprNode>()) {
+ DataType dtype;
+ Expr dom_scale;
+ Array<Expr> ret_args = UnifyDTypeScale(ref_arr, arr, &dtype, &dom_scale);
+ Expr ret = ForwardOp(ref_call, {TupleNode::make(ret_args)});
+ return QRealizeIntExprNode::make(ret, dom_scale, dtype);
+ } else {
+ for (auto arg : new_args) {
+ CHECK(!arg->derived_from<TempExprNode>());
+ }
+ return Expr(nullptr);
+ }
+}
+
+RELAY_REGISTER_OP("concatenate")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", ConcatenateRealize);
+
+
+/* \brief Forward the original operator, keeping the argument's dom_scale and dtype */
+Expr IdentityRealize(const Call& ref_call,
+ const Array<Expr>& new_args,
+ const NodeRef& ctx) {
+ CHECK_EQ(new_args.size(), 1);
+ if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
+ Expr ret = ForwardOp(ref_call, {n->data});
+ return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype);
+ }
+ CHECK(!new_args[0]->derived_from<TempExprNode>());
+ return Expr(nullptr);
+}
+
+RELAY_REGISTER_OP("nn.relu")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
+
+RELAY_REGISTER_OP("strided_slice")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
+
+RELAY_REGISTER_OP("annotation.stop_fusion")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
+
+/* \brief For unary operators that cast their input back to dtype_input */
+Expr CastDtypeInputRealize(const Call& ref_call,
+ const Array<Expr>& new_args,
+ const NodeRef& ctx) {
+ const QConfig& cfg = QConfig::Current();
+ CHECK_EQ(new_args.size(), 1);
+ if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
+ Expr data = Cast(n->data, cfg->dtype_input);
+ Expr ret = ForwardOp(ref_call, {data});
+ return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_input);
+ }
+ CHECK(!new_args[0]->derived_from<TempExprNode>());
+ return Expr(nullptr);
+}
+
+RELAY_REGISTER_OP("nn.max_pool2d")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);
+
+
+Expr AvgPoolRealize(const Call& ref_call,
+ const Array<Expr>& new_args,
+ const NodeRef& ctx) {
+ const QConfig& cfg = QConfig::Current();
+ CHECK_EQ(new_args.size(), 1);
+ if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
+ Expr data = n->data;
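+    // widen to the activation dtype before pooling, since the intermediate sum
+    // in avg_pool2d can overflow a narrower input dtype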
+ if (n->dtype != cfg->dtype_activation) {
+ data = Cast(n->data, cfg->dtype_activation);
+ }
+ Expr ret = ForwardOp(ref_call, {data});
+ return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_activation);
+ }
+ CHECK(!new_args[0]->derived_from<TempExprNode>());
+ return Expr(nullptr);
+}
+
+RELAY_REGISTER_OP("nn.avg_pool2d")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
+
+Expr CastHintRealize(const Call& ref_call,
+ const Array<Expr>& new_args,
+ const NodeRef& ctx) {
+ const auto param = ref_call->attrs.as<CastHintAttrs>();
+ CHECK_EQ(new_args.size(), 1);
+ if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
+ Expr ret = Cast(n->data, param->dtype);
+ return QRealizeIntExprNode::make(ret, n->dom_scale, param->dtype);
+ }
+ CHECK(!new_args[0]->derived_from<TempExprNode>());
+ return Expr(nullptr);
+}
+
+RELAY_REGISTER_OP("annotation.cast_hint")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", CastHintRealize);
+
+Pass QuantizeRealizePass() {
+ runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+ [=](Function f, Module m, PassContext pc) {
+ return Downcast<Function>(
+ ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr));
+ };
+ return CreateFunctionPass(pass_func, 1, "QuantizeRealize", {});
+}
+
+TVM_REGISTER_API("relay._quantize.QuantizeRealize")
+.set_body_typed(QuantizeRealizePass);
+
+} // namespace quantize
+} // namespace relay
+} // namespace tvm
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from collections import namedtuple
+import tvm
+from tvm import relay
+from tvm.relay import quantize as qtz
+import mxnet as mx
+from mxnet import gluon
+import logging
+import os
+
+logging.basicConfig(level=logging.INFO)
+
+Config = namedtuple('Config', ['model', 'nbit_input', 'dtype_input', 'nbit_output',
+                               'dtype_output', 'global_scale', 'expected_acc'])
+
+
+def get_val_data(model_name,
+ rec_val,
+ batch_size,
+ num_workers=4):
+ rec_val = os.path.expanduser(rec_val)
+ mean_rgb = [123.68, 116.779, 103.939]
+ std_rgb = [58.393, 57.12, 57.375]
+ def batch_fn(batch, ctx):
+ data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
+ label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
+ return data, label
+
+ img_size = 299 if model_name == 'inceptionv3' else 224
+ val_data = mx.io.ImageRecordIter(
+ path_imgrec = rec_val,
+ preprocess_threads = num_workers,
+ shuffle = False,
+ batch_size = batch_size,
+ resize = 256,
+ data_shape = (3, img_size, img_size),
+ mean_r = mean_rgb[0],
+ mean_g = mean_rgb[1],
+ mean_b = mean_rgb[2],
+ std_r = std_rgb[0],
+ std_g = std_rgb[1],
+ std_b = std_rgb[2],
+ )
+ return val_data, batch_fn
+
+
+def get_model(model_name, batch_size, qconfig, target=None, original=False, simulated=False):
+ gluon_model = gluon.model_zoo.vision.get_model(model_name, pretrained=True)
+ img_size = 299 if model_name == 'inceptionv3' else 224
+ data_shape = (batch_size, 3, img_size, img_size)
+ mod, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
+ net = mod['main']
+
+ with relay.build_config(opt_level=3):
+ qfunc = relay.quantize.prerequisite_optimize(net, params=params)
+ logging.debug('original')
+ logging.debug(qfunc.astext(show_meta_data=False))
+ if original:
+ return qfunc
+
+ with qconfig:
+ logging.debug('current quantize config')
+ logging.debug(qtz.current_qconfig())
+ qfunc = qtz.quantize(qfunc)
+ logging.debug('after quantize')
+ logging.debug(qfunc.astext(show_meta_data=False))
+ return qfunc
+
+
+def eval_acc(model, dataset, batch_fn, target=tvm.target.cuda(), ctx=tvm.gpu(), log_interval=100):
+ with relay.build_config(opt_level=3):
+ graph, lib, params = relay.build(model, target)
+ # create runtime module
+ m = tvm.contrib.graph_runtime.create(graph, lib, ctx)
+ m.set_input(**params)
+
+    # setup evaluation metrics
+ dataset.reset()
+ batch_size = dataset.batch_size
+ acc_top1 = mx.metric.Accuracy()
+ acc_top5 = mx.metric.TopKAccuracy(5)
+ acc_top1.reset()
+ acc_top5.reset()
+ # Execute
+ for i, batch in enumerate(dataset):
+ data, label = batch_fn(batch, [mx.cpu(0)])
+ m.run(data=data[0].asnumpy())
+ out_arr = m.get_output(0)
+ acc_top1.update(label, [mx.nd.array(out_arr.asnumpy())])
+ acc_top5.update(label, [mx.nd.array(out_arr.asnumpy())])
+
+ if not (i + 1) % log_interval:
+ _, top1 = acc_top1.get()
+ _, top5 = acc_top5.get()
+ nsamples = (i + 1) * batch_size
+ logging.info('[%d samples] validation: acc-top1=%f acc-top5=%f', nsamples, top1, top5)
+ logging.info('[final] validation: acc-top1=%f acc-top5=%f', top1, top5)
+ return top1
+
+def test_quantize_acc(cfg, rec_val):
+ qconfig = qtz.qconfig(skip_conv_layers=[0],
+ nbit_input=cfg.nbit_input,
+ nbit_weight=cfg.nbit_input,
+ global_scale=cfg.global_scale,
+ dtype_input=cfg.dtype_input,
+ dtype_weight=cfg.dtype_input,
+ dtype_activation=cfg.dtype_output,
+ debug_enabled_ops=None)
+
+ model = get_model(cfg.model, 32, qconfig, tvm.target.cuda())
+ val_data, batch_fn = get_val_data(cfg.model, rec_val=rec_val, batch_size=32)
+
+ acc = eval_acc(model, val_data, batch_fn)
+ assert acc > cfg.expected_acc
+ return acc
+
+
+if __name__ == "__main__":
+    # TODO(user): replace this line with the path to the ImageNet validation dataset
+ rec_val = "/scratch/tqchen/imagenet/val.rec"
+
+ results = []
+ configs = [
+ Config('mobilenetv2_1.0', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=4.0, expected_acc=0.666),
+
+ Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=8.0, expected_acc=0.692),
+ Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.692),
+ Config('resnet34_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.733),
+ Config('resnet50_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.747),
+ Config('resnet101_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.756),
+ # TODO: need to fix accuracy
+ # Config('mobilenetv2_1.0', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=4.0),
+ ]
+
+ for config in configs:
+ acc = test_quantize_acc(config, rec_val)
+ results.append((config, acc))
+ for res in results:
+ print(res)
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import math
-import numpy as np
-import tvm
-from tvm import relay
-from tvm.relay import quantize as qtz
-from tvm.relay import transform
-
-
-def run_infer_type(expr):
- mod = relay.Module.from_expr(expr)
- mod = transform.InferType()(mod)
- entry = mod["main"]
- return entry if isinstance(expr, relay.Function) else entry.body
-
-
-def make_dataset(graph, size=100):
- args = run_infer_type(graph).params
- def create_arr(var):
- ttype = var.type_annotation
- np_arr = np.random.uniform(-1.0, 1.0, size=ttype.concrete_shape).astype(ttype.dtype)
- return tvm.ndarray.array(np_arr)
-
- params = {}
- for arg in args:
- if arg.name_hint == 'data':
- dataset = [{'data': create_arr(arg)} for _ in range(size)]
- else:
- params[arg.name_hint] = create_arr(arg)
- return dataset, params
-
-
-def test_simulated_quantize():
- data = relay.var("data", relay.ty.TensorType((3, 4, 5, 6), "float32"))
- out = qtz._annotate.attach_simulated_quantize(data, 1)
- out = run_infer_type(out)
- assert out.checked_type == out.args[0].checked_type
- assert out.args[1].checked_type == relay.ty.TensorType(tuple(), "float32")
- assert out.args[2].checked_type == relay.ty.TensorType(tuple(), "float32")
- assert out.args[3].checked_type == relay.ty.TensorType(tuple(), "float32")
-
-
-def test_quantize_pass():
- def quantize_weight(arr):
- maximum = np.amax(np.abs(arr.asnumpy()))
- scale = 2**math.ceil(math.log(maximum, 2))
- out = np.around(arr.asnumpy() / scale * 128).astype('int8')
- out = np.clip(out, -127, 127)
- return relay.const(out, 'int8')
-
- n, c, h, w = 1, 3, 224, 224
- def make_graph(data):
- weight = relay.var("conv_weight")
- out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c)
- out = relay.Function(relay.analysis.free_vars(out), out)
- return out
-
- def make_qgraph(data, weight):
- out = data * relay.const(32.0)
- out = relay.round(out)
- out = relay.clip(out, a_min=-127, a_max=127)
- out = out.astype('int8')
-
- out = relay.nn.conv2d(out, weight, kernel_size=(3, 3),
- padding=(1, 1), channels=c, out_dtype='int32')
- out = out.astype('float32')
- out = relay.multiply(out, relay.const(0.00024414062))
- out = relay.Function(relay.analysis.free_vars(out), out)
- return out
-
- np.random.seed(42)
-
- data = relay.var("data", relay.TensorType((n, c, h, w), "float32"))
- graph = make_graph(data)
- dataset, params = make_dataset(graph, 10)
-
- with qtz.qconfig(skip_conv_layers=None, global_scale=4.0,
- round_for_shift=False, store_lowbit_output=False):
- qgraph0 = qtz.quantize(graph, params)
- qgraph0 = run_infer_type(qgraph0)
-
- conv_weight = quantize_weight(params['conv_weight'])
- qgraph1 = make_qgraph(data, conv_weight)
- qgraph1 = run_infer_type(qgraph1)
-
- graph = relay.create_executor('graph')
- res0 = graph.evaluate(qgraph0)(dataset[0]['data'])
- res1 = graph.evaluate(qgraph1)(dataset[0]['data'])
- tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy(), rtol=1e-3)
-
-
-if __name__ == "__main__":
- test_simulated_quantize()
- test_quantize_pass()
--- /dev/null
+#!/bin/bash
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -e
+set -u
+
+export PYTHONPATH=python:topi/python
+
+# Rebuild cython
+make cython3
+
+rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc
+rm -rf topi/python/topi/*.pyc topi/python/topi/*/*.pyc topi/python/topi/*/*/*.pyc topi/python/topi/*/*/*/*.pyc
+
+python3 -m nose -v topi/tests/python/nightly