[QUANTIZE] Refactor quantization codebase and fix model accuracy (#3543)
authorziheng <ziheng@apache.org>
Thu, 15 Aug 2019 09:31:30 +0000 (02:31 -0700)
committerGitHub <noreply@github.com>
Thu, 15 Aug 2019 09:31:30 +0000 (02:31 -0700)
* Refactor.

* update

* update

* update

* update

* update

* update

18 files changed:
include/tvm/relay/analysis.h
include/tvm/relay/attrs/annotation.h
python/tvm/relay/analysis.py
python/tvm/relay/quantize/__init__.py
python/tvm/relay/quantize/_annotate.py
python/tvm/relay/quantize/_partition.py [new file with mode: 0644]
python/tvm/relay/quantize/quantize.py
src/relay/op/annotation/annotation.cc
src/relay/pass/fold_constant.cc
src/relay/pass/pattern_util.h
src/relay/pass/quantize/annotate.cc [new file with mode: 0644]
src/relay/pass/quantize/partition.cc [new file with mode: 0644]
src/relay/pass/quantize/quantize.cc
src/relay/pass/quantize/quantize.h
src/relay/pass/quantize/realize.cc [new file with mode: 0644]
tests/python/nightly/quantization/test_quantization_accuracy.py [new file with mode: 0644]
tests/python/relay/test_pass_quantize.py [deleted file]
tests/scripts/task_python_nightly.sh [new file with mode: 0755]

index 3672a22..8c14f02 100644 (file)
@@ -52,6 +52,18 @@ namespace relay {
 TVM_DLL Kind KindCheck(const Type& t, const Module& mod);
 
 /*!
+ * \brief Check whether an expression is constant.
+ *
+ * If the inputs of an expression are all constant, it means the expression
+ * itself is constant also.
+ *
+ * \param e the expression.
+ *
+ * \return whether the expression is constant.
+ */
+TVM_DLL bool ConstantCheck(const Expr& e);
+
+/*!
  * \brief Compare two expressions for structural equivalence.
  *
  * This comparison operator respects scoping and compares
index 29750c5..fd21db5 100644 (file)
@@ -44,6 +44,19 @@ struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
   }
 };
 
+/*!
+ * \brief Annotate an expression to be cast into specific data type.
+ */
+struct CastHintAttrs : public tvm::AttrsNode<CastHintAttrs> {
+  DataType dtype;
+
+  TVM_DECLARE_ATTRS(CastHintAttrs, "relay.attrs.CastHintAttrs") {
+    TVM_ATTR_FIELD(dtype)
+      .describe(
+         "The data type denoted to be cast.");
+  }
+};
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_ANNOTATION_H_
index 91b53bb..7372fcd 100644 (file)
@@ -91,6 +91,22 @@ def check_kind(t, mod=None):
         return _analysis.check_kind(t)
 
 
+def check_constant(expr):
+    """Check whether an expression is constant
+
+    Parameters
+    ----------
+    expr : tvm.relay.Expr
+        The input expression
+
+    Returns
+    -------
+    result : bool
+        Whether the expression is constant.
+    """
+    return _analysis.check_constant(expr)
+
+
 def free_vars(expr):
     """Get free Vars from expression expr in Post DFS order.
 
index a9e7b40..29b6895 100644 (file)
@@ -19,5 +19,6 @@
 from __future__ import absolute_import as _abs
 
 from .quantize import *
+from ._partition import register_partition_function
 from ._annotate import register_annotate_function
 from .kl_divergence import kl_divergence_scale
index e03eaab..55f3597 100644 (file)
@@ -20,14 +20,15 @@ from __future__ import absolute_import
 import warnings
 
 import topi
-from . import _quantize
-from .quantize import QAnnotateKind, current_qconfig
-from .quantize import annotate_context
+from ..._ffi.function import register_func
 from .. import expr as _expr
+from .. import analysis as _analysis
 from .. import op as _op
 from ..op import op as _reg
 from ..base import register_relay_node
-from ..._ffi.function import register_func
+from . import _quantize
+from .quantize import QAnnotateKind, current_qconfig, quantize_context
+from .quantize import _forward_op
 
 
 @_reg.register_compute("relay.op.annotation.simulated_quantize")
@@ -75,12 +76,6 @@ class QAnnotateExpr(_expr.TempExpr):
             _quantize.make_annotate_expr, expr, kind)
 
 
-def _forward_op(ref_call, args):
-    """forward the operator of ref_call with provided arguments"""
-    return _expr.Call(
-        ref_call.op, args, ref_call.attrs, ref_call.type_args)
-
-
 def _get_expr_kind(anno):
     """Get the expression and QAnnotateKind from QAnnotateExpr or Expr"""
     if isinstance(anno, QAnnotateExpr):
@@ -113,7 +108,7 @@ def register_annotate_function(op_name, frewrite=None, level=10):
             if not current_qconfig().guard(ref_call):
                 return default_rewrite(ref_call, new_args, ctx)
             return func(ref_call, new_args, ctx)
-        _op.op._Register(op_name, "FQAnnotateRewrite", frewrite_with_guard, level)
+        _reg._Register(op_name, "FQAnnotateRewrite", frewrite_with_guard, level)
         return frewrite_with_guard
 
     return _register(frewrite) if frewrite is not None else _register
@@ -135,17 +130,17 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
         if data.attrs.kind == kind and data.attrs.sign == sign and data.attrs.rounding == rounding:
             return data
 
-    actx = annotate_context()
+    qctx = quantize_context()
     key = tuple([data, kind, sign, rounding])
-    if key in actx.qnode_map:
-        return actx.qnode_map[key]
+    if key in qctx.qnode_map:
+        return qctx.qnode_map[key]
 
     dom_scale = _expr.var("dom_scale")
     clip_min = _expr.var("clip_min")
     clip_max = _expr.var("clip_max")
     qnode = _quantize.simulated_quantize(
         data, dom_scale, clip_min, clip_max, kind, sign, rounding)
-    actx.qnode_map[key] = qnode
+    qctx.qnode_map[key] = qnode
     return qnode
 
 register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize)
@@ -163,13 +158,8 @@ def conv2d_rewrite(ref_call, new_args, ctx):
     """Rewrite function for conv2d. Lhs of conv will be quantized to
     input field, and rhs of conv will be quantized to weight field.
     Output would be in activation field"""
-    actx = annotate_context()
-    if current_qconfig().skip_conv_layers is not None:
-        skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
-        if actx.conv2d_counter() in skipped_indices:
-            actx.count_conv2d()
-            return None
-    actx.count_conv2d()
+    if quantize_context().check_to_skip(ref_call):
+        return None
 
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
     rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
@@ -185,21 +175,12 @@ def conv2d_rewrite(ref_call, new_args, ctx):
     return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
 
 
-def check_to_skip():
-    """Check the index of conv2d layer to decide whether to skip the current operator."""
-    if current_qconfig().skip_conv_layers is not None:
-        skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
-        if annotate_context().conv2d_counter() - 1 in skipped_indices:
-            return True
-    return False
-
-
 # TODO(tmoreau89,ziheng) need to include an option to turn off dense quant
 # @register_annotate_function("nn.dense")
 def dense_rewrite(ref_call, new_args, ctx):
     """Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of
     dense will be quantized to weight field. Output would be in activation field."""
-    if check_to_skip():
+    if quantize_context().check_to_skip(ref_call):
         return None
 
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
@@ -219,7 +200,7 @@ def dense_rewrite(ref_call, new_args, ctx):
 @register_annotate_function("multiply")
 def multiply_rewrite(ref_call, new_args, ctx):
     """Rewrite function for multiply."""
-    if check_to_skip():
+    if quantize_context().check_to_skip(ref_call):
         return None
 
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
@@ -243,13 +224,14 @@ def multiply_rewrite(ref_call, new_args, ctx):
 @register_annotate_function("add")
 def add_rewrite(ref_call, new_args, ctx):
     """Rewrite function for add."""
-    if check_to_skip():
+    if quantize_context().check_to_skip(ref_call):
         return None
 
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
     rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
 
     if lhs_kind is None and rhs_kind is None:
+        # trivial case
         return None
 
     if lhs_kind is None and rhs_kind is not None:
@@ -260,11 +242,10 @@ def add_rewrite(ref_call, new_args, ctx):
         return QAnnotateExpr(expr, QAnnotateKind.INPUT)
 
     if lhs_kind is not None and rhs_kind is None:
-        if isinstance(rhs_expr, _expr.Constant):
-            # quantize rhs to WEIGHT field if it is Constant
+        if _analysis.check_constant(rhs_expr):
+            # - introduced by batch_norm: add(out, const)
             rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
         else:
-            # quantize rhs to INPUT field if it is not Constant
             rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
         expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
         return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
@@ -274,7 +255,6 @@ def add_rewrite(ref_call, new_args, ctx):
             expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
             return QAnnotateExpr(expr, QAnnotateKind.INPUT)
         if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.ACTIVATION:
-            # quantize rhs to INPUT field if both lhs and rhs are ACTIVATION
             rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
             expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
             return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
@@ -285,24 +265,9 @@ def add_rewrite(ref_call, new_args, ctx):
     raise ValueError()
 
 
-@register_annotate_function("stop_fusion")
-def stop_fusion_rewrite(ref_call, new_args, ctx):
-    """Rewrite function for add."""
-    if check_to_skip():
-        return None
-
-    x_expr, x_kind = _get_expr_kind(new_args[0])
-    if x_kind is None:
-        return None
-
-    ret_expr = attach_simulated_quantize(x_expr, QAnnotateKind.INPUT)
-    ret_expr = _forward_op(ref_call, [ret_expr])
-    return QAnnotateExpr(ret_expr, QAnnotateKind.INPUT)
-
-
 def identity_rewrite(ref_call, new_args, ctx):
     """Simply forward the original operation"""
-    if check_to_skip():
+    if quantize_context().check_to_skip(ref_call):
         return None
 
     x_expr, x_kind = _get_expr_kind(new_args[0])
@@ -322,7 +287,7 @@ register_annotate_function("annotation.stop_fusion", identity_rewrite)
 
 def pool2d_rewrite(ref_call, new_args, ctx):
     """Rewrite function for max pool2d"""
-    if check_to_skip():
+    if quantize_context().check_to_skip(ref_call):
         return None
 
     expr, x_kind = _get_expr_kind(new_args[0])
@@ -339,14 +304,14 @@ def pool2d_rewrite(ref_call, new_args, ctx):
 register_annotate_function("nn.max_pool2d", pool2d_rewrite)
 
 
-@register_annotate_function("annotation.force_cast")
-def force_cast_rewrite(ref_call, new_args, ctx):
+@register_annotate_function("annotation.cast_hint")
+def cast_hint_rewrite(ref_call, new_args, ctx):
     """Rewrite function to force cast"""
-    if check_to_skip():
-        return None
-
     expr, x_kind = _get_expr_kind(new_args[0])
 
+    if quantize_context().check_to_skip(ref_call):
+        return expr
+
     if x_kind is None:
         return new_args[0]
     if x_kind == QAnnotateKind.ACTIVATION:
@@ -359,7 +324,7 @@ def force_cast_rewrite(ref_call, new_args, ctx):
 @register_annotate_function("concatenate")
 def concatenate_rewrite(ref_call, new_args, ctx):
     """Rewrite function for concatenate"""
-    if check_to_skip():
+    if quantize_context().check_to_skip(ref_call):
         return None
 
     input_tuple = new_args[0]
@@ -377,69 +342,18 @@ def concatenate_rewrite(ref_call, new_args, ctx):
     return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
 
 
-# Graph rewrite function registration for VTA target
-def register_vta_rewrite(op_name, frewrite=None, level=10):
-    def _register(func):
-        return _op.op._Register(op_name, "FQVTARewrite", func, level)
-    return _register(frewrite) if frewrite is not None else _register
+@register_annotate_function("nn.global_avg_pool2d")
+def global_avg_pool2d_rewrite(ref_call, new_args, ctx):
+    """Rewrite function for global_avg_pool2d for stopping quantize"""
+    if quantize_context().check_to_skip(ref_call):
+        return None
 
+    expr, x_kind = _get_expr_kind(new_args[0])
 
-@register_relay_node
-class QVTAExpr(_expr.TempExpr):
-    def __init__(self, expr):
-        self.__init_handle_by_constructor__(
-            _quantize.make_vta_expr, expr)
-
-    def realize(self):
-        return _quantize.temp_expr_realize(self)
-
-
-def vta_expr_check(expr):
-    if isinstance(expr, QVTAExpr):
-        return True, expr.expr
-    return False, expr
-
-
-@register_vta_rewrite("nn.conv2d")
-def conv2d_vta_rewrite(ref_call, new_args, ctx):
-    """Rewrite function for conv2d for VTA target"""
-    actx = annotate_context()
-    if current_qconfig().skip_conv_layers is not None:
-        skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
-        if actx.conv2d_counter() in skipped_indices:
-            actx.count_conv2d()
-            return None
-    actx.count_conv2d()
-
-    data_cond, data = vta_expr_check(new_args[0])
-    kernel_cond, kernel = vta_expr_check(new_args[1])
-
-    assert not kernel_cond
-    if data_cond:
-        data = new_args[0].realize()
-    ret = _forward_op(ref_call, [data, kernel])
-    return QVTAExpr(ret)
-
-
-def identity_vta_rewrite(ref_call, new_args, ctx):
-    cond, expr = vta_expr_check(new_args[0])
-    if cond:
-        return QVTAExpr(_forward_op(ref_call, [expr]))
-    return None
-
-register_vta_rewrite("nn.relu", identity_vta_rewrite)
-register_vta_rewrite("nn.max_pool2d", identity_vta_rewrite)
-
-
-@register_vta_rewrite("add")
-def add_vta_rewrite(ref_call, new_args, ctx):
-    """Rewrite function for ewise add for VTA target"""
-    lhs_cond, lhs = vta_expr_check(new_args[0])
-    rhs_cond, rhs = vta_expr_check(new_args[1])
-    if lhs_cond and rhs_cond:
-        lhs = new_args[0].realize()
-        rhs = new_args[1].realize()
-        return _forward_op(ref_call, [lhs, rhs])
-    elif lhs_cond and not rhs_cond:
-        return QVTAExpr(_forward_op(ref_call, [lhs, rhs]))
-    return None
+    if x_kind is None:
+        return None
+    expr = _forward_op(ref_call, [new_args[0].realize()])
+
+    # stop quantize after global_avg_pool2d
+    quantize_context().stop_quantize()
+    return expr
diff --git a/python/tvm/relay/quantize/_partition.py b/python/tvm/relay/quantize/_partition.py
new file mode 100644 (file)
index 0000000..597c55c
--- /dev/null
@@ -0,0 +1,151 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#pylint: disable=unused-argument,inconsistent-return-statements
+"""Internal module for registering attribute for annotation."""
+from __future__ import absolute_import
+
+from ... import target as _target
+from .. import expr as _expr
+from .. import analysis as _analysis
+from ..base import register_relay_node
+from ..op import op as _reg
+from . import _quantize
+from .quantize import _forward_op
+
+def register_partition_function(op_name, frewrite=None, level=10):
+    def _register(func):
+        return _reg._Register(op_name, "FQPartitionRewrite", func, level)
+    return _register(frewrite) if frewrite is not None else _register
+
+
+@register_relay_node
+class QPartitionExpr(_expr.TempExpr):
+    def __init__(self, expr):
+        self.__init_handle_by_constructor__(
+            _quantize.make_partition_expr, expr)
+
+
+def partition_expr_check(expr):
+    if isinstance(expr, QPartitionExpr):
+        return True, expr.expr
+    return False, expr
+
+
+@register_partition_function("nn.conv2d")
+def conv2d_partition_function(ref_call, new_args, ctx):
+    """Rewrite function for conv2d for partition"""
+    data_cond, data = partition_expr_check(new_args[0])
+    kernel_cond, kernel = partition_expr_check(new_args[1])
+
+    assert not kernel_cond
+    if data_cond:
+        data = new_args[0].realize()
+    ret = _forward_op(ref_call, [data, kernel])
+    return QPartitionExpr(ret)
+
+
+def identity_partition_function(ref_call, new_args, ctx):
+    cond, expr = partition_expr_check(new_args[0])
+    if cond:
+        return QPartitionExpr(_forward_op(ref_call, [expr]))
+    return None
+
+register_partition_function("clip", identity_partition_function)
+register_partition_function("nn.relu", identity_partition_function)
+register_partition_function("nn.max_pool2d", identity_partition_function)
+
+
+def add_partition_generic(ref_call, new_args, ctx):
+    """Rewrite function for ewise add for partition for generic devices"""
+    lhs_cond, lhs = partition_expr_check(new_args[0])
+    rhs_cond, rhs = partition_expr_check(new_args[1])
+    if lhs_cond and rhs_cond:
+        # - introduced by ResNet, when for the first residual connection
+        #     ...
+        #     %0 = nn.conv2d(%data, %meta[relay.Constant])
+        #     %1 = add(%0, %meta[relay.Constant])
+        #     %2 = nn.relu(%1)
+        #     %3 = nn.max_pool2d(%2)
+        #     ...
+        #     %9 = nn.conv2d(%8, %meta[relay.Constant])
+        #     %10 = add(%9, %meta[relay.Constant])
+        #     %11 = add(%3, %10)  <- need to insert annotations for %3, %10
+        #     ...
+        lhs = new_args[0].realize()
+        rhs = new_args[1].realize()
+        return _forward_op(ref_call, [lhs, rhs])
+    elif not lhs_cond and rhs_cond:
+        # - introduced by residual connection in ResNet
+        #     ...
+        #     %13 = nn.conv2d(%12, %meta[relay.Constant])
+        #     %14 = add(%13, %meta[relay.Constant])
+        #     %15 = annotation.cast_hint(%15, 'int8')
+        #     %16 = annotation.stop_fusion(%16)
+        #     %17 = add(%5, %16)
+        #     %18 = nn.relu(%17)
+        #     ...
+        #     %24 = nn.conv2d(%23, %meta[relay.Constant])
+        #     %25 = add(%24, %meta[relay.Constant])
+        #     %26 = add(%18, %25)  <- need to insert annotations for %25
+        #     ...
+        rhs = new_args[1].realize()
+        return _forward_op(ref_call, [lhs, rhs])
+    elif lhs_cond and not rhs_cond:
+        if _analysis.check_constant(rhs):
+            # - introduced by batch_norm: add(out, bias)
+            return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
+        # - introduced by residual connection in MobileNetV2
+        #     ...
+        #     %81 = add(%80, meta[relay.Constant])
+        #     %82 = annotation.cast_hint(%81, 'int8')
+        #     %83 = annotation.stop_fusion(%82)
+        #     %84 = add(%79, %83)
+        #     ...
+        #     %96 = nn.conv2d(%94, %meta[relay.Constant])
+        #     %96 = add(%95, %meta[relay.Constant])
+        #     %97 = add(%96, %84)  <- need to insert annotations for %96
+        #     ...
+        lhs = new_args[0].realize()
+        return _forward_op(ref_call, [lhs, rhs])
+    elif not lhs_cond and not rhs_cond:
+        # trivial case
+        return None
+    else:
+        raise ValueError
+
+
+# TODO(ziheng) enhance `register_partition_function` to dispatch
+# for target automatically
+@register_partition_function("add")
+def add_partition_function(ref_call, new_args, ctx):
+    """Rewrite function for ewise add for partition"""
+    if 'cuda' in _target.current_target().keys:
+        #TODO(wuwei/ziheng) cuda specific rules
+        return add_partition_generic(ref_call, new_args, ctx)
+    return add_partition_generic(ref_call, new_args, ctx)
+
+
+@register_partition_function("multiply")
+def multiply_partition_function(ref_call, new_args, ctx):
+    """Rewrite function for ewise add for partition"""
+    lhs_cond, lhs = partition_expr_check(new_args[0])
+    rhs_cond, rhs = partition_expr_check(new_args[1])
+    if lhs_cond:
+        # introduced by bn: multiply(out, scale)
+        return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
+    assert (not lhs_cond) and (not rhs_cond)
+    return None
index 07d4d9d..adde205 100644 (file)
@@ -50,6 +50,12 @@ def kind2str(kind):
     return str_map[kind]
 
 
+def _forward_op(ref_call, args):
+    """forward the operator of ref_call with provided arguments"""
+    return _expr.Call(
+        ref_call.op, args, ref_call.attrs, ref_call.type_args)
+
+
 @register_relay_node("relay.quantize.QConfig")
 class QConfig(NodeBase):
     """Configure the quantization behavior by setting config variables.
@@ -74,8 +80,8 @@ class QConfig(NodeBase):
         "dtype_activation": "int32",
         "global_scale": 8.0,
         "skip_conv_layers": [0],
+        "do_simulation": False,
         "round_for_shift": True,
-        "store_lowbit_output": True,
         "debug_enabled_ops": None,
     }
 
@@ -92,6 +98,7 @@ class QConfig(NodeBase):
         self.handle = handle
 
     def guard(self, ref_call):
+        """Return true if op is enabled, otherwise return false"""
         op_name = ref_call.op.name
         if self.debug_enabled_ops is not None:
             name_list = [x.value for x in self.debug_enabled_ops]
@@ -126,9 +133,7 @@ def current_qconfig():
     """Get the current quantization configuration."""
     return _quantize._GetCurrentQConfig()
 
-# TODO(tmoreau89, ZihengJiang) the skip parameters are
-# hacky - we should explore a more future-proof way to
-# skip operators based on pattern matching
+
 def qconfig(**kwargs):
     """Configure the quantization behavior by setting config variables.
 
@@ -142,15 +147,14 @@ def qconfig(**kwargs):
 
     skip_conv_layers: list
         Specifying which layers to be skipped. Provide a list of indices
-        that indicate which conv2d layers to leave untouched.
+        that indicate which conv2d layers to leave untouched. Start from 0.
+
+    do_simulation: boolean
+        Whether to do simulation with float operation only.
 
     round_for_shift: boolean
         Whether to add bias for rounding during shift.
 
-    store_lowbit_output: boolean
-        Whether to store low-bit integer back as output before dequantizing.
-        Some accelerators need this, e.g. VTA.
-
     debug_enabled_ops: None or list of str
         Partially quantize specified operators for debugging. The default value
         is None, which means will try to call all operartors' annotate rewrite
@@ -166,35 +170,79 @@ def qconfig(**kwargs):
     return _make.node("relay.quantize.QConfig", **node_args)
 
 
-class AnnotateContext(object):
-    """A global singleton annotate scope"""
+class QuantizeContext(object):
+    """An internal used global context object for annotation,
+    for putting some state variables like `conv2d_counter`."""
     Current = None
 
     def __init__(self):
         self.qnode_map = dict()
         self._conv2d_counter = 0
+        self._stop_quantize = False
+
+    def check_to_skip(self, ref_call):
+        """Check the index of conv2d layer to decide whether to
+        skip the current operator."""
+        if self._stop_quantize:
+            return True
+
+        if current_qconfig().skip_conv_layers is not None:
+            # check skip conv layers
+            skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+            if self._conv2d_counter in skipped_indices:
+                if ref_call.op.name == 'nn.conv2d':
+                    self._conv2d_counter += 1
+                return True
+            if ref_call.op.name == 'nn.conv2d':
+                self._conv2d_counter += 1
+
+        return False
+
+    def stop_quantize(self):
+        self._stop_quantize = True
+
+    def reset(self):
+        self._conv2d_counter = 0
+        self._stop_quantize = False
 
     def __enter__(self):
-        self._conv2d_counter = 0
+        self.reset()
         return self
 
-    def conv2d_counter(self):
-        """Get the counter for conv2d."""
-        return self._conv2d_counter
-
-    def count_conv2d(self):
-        """Increase the value of the conv2d counter by one."""
-        self._conv2d_counter += 1
-
     def __exit__(self, ptype, value, traceback):
         pass
 
 
-def annotate_context():
+def quantize_context():
     """Get the global singleton scope"""
-    if AnnotateContext.Current is None:
-        AnnotateContext.Current = AnnotateContext()
-    return AnnotateContext.Current
+    if QuantizeContext.Current is None:
+        QuantizeContext.Current = QuantizeContext()
+    return QuantizeContext.Current
+
+
+def partition():
+    """Partition graph into small low-precision sections by `cast_hint` and
+    `stop_fusion`.
+
+    Returns
+    -------
+    ret: tvm.relay.Pass
+        The registered pass for VTA rewrite.
+    """
+    return _quantize.QuantizePartition()
+
+
+def annotate():
+    """Given a float32 graph, this pass will rewrite the graph and return
+    a graph which simulates the error brought by the current quantization
+    scheme.
+
+    Returns
+    -------
+    ret: tvm.relay.Pass
+        The registered pass for quantization annotation.
+    """
+    return _quantize.QuantizeAnnotate()
 
 
 def collect_stats(graph):
@@ -300,20 +348,8 @@ def calibrate(graph, mod=None, ctx=None, weight_scales='power2', scales=None):
             const_params[nclip_max] = _make_const((valid_range - 1))
 
     _analysis.post_order_visit(graph, visit_func)
-    return _expr.bind(graph, const_params)
-
-
-def annotate():
-    """Given a float32 graph, this pass will rewrite the graph and return
-    a graph which simulates the error brought by the current quantization
-    scheme.
-
-    Returns
-    -------
-    ret: tvm.relay.Pass
-        The registered pass for quantization annotation.
-    """
-    return _quantize.QuantizeAnnotate()
+    ret = _expr.bind(graph, const_params)
+    return ret
 
 
 def realize():
@@ -330,17 +366,6 @@ def realize():
     return _quantize.QuantizeRealize()
 
 
-def rewrite_for_vta():
-    """Performs rewriting for VTA target.
-
-    Returns
-    -------
-    ret: tvm.relay.Pass
-        The registered pass for VTA rewrite.
-    """
-    return _quantize.QuantizeRewriteForVTA()
-
-
 def _bind_params(func, params):
     """Bind the params to the expression.
     """
@@ -362,6 +387,25 @@ def _bind_params(func, params):
     return _expr.bind(func, bind_dict)
 
 
+def prerequisite_optimize(graph, params=None):
+    """ Prerequisite optimization passes for quantization. Perform
+    "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
+    "CanonicalizeOps" optimization before quantization. """
+    optimize = _transform.Sequential([_transform.SimplifyInference(),
+                                      _transform.FoldConstant(),
+                                      _transform.FoldScaleAxis(),
+                                      _transform.CanonicalizeOps(),
+                                      _transform.FoldConstant()])
+
+    if params:
+        graph = _bind_params(graph, params)
+
+    mod = _module.Module.from_expr(graph)
+    with _transform.PassContext(opt_level=3):
+        mod = optimize(mod)
+    return mod["main"]
+
+
 def quantize(graph, params=None, dataset=None):
     """ The quantization procedure. Before running the three main
     procedure of quantization, "annotate", "calibrate" and "realize"
@@ -385,33 +429,23 @@ def quantize(graph, params=None, dataset=None):
     ret: Function
         The graph after quantization
     """
-    if params:
-        graph = _bind_params(graph, params)
+    graph = prerequisite_optimize(graph, params)
 
     mod = _module.Module.from_expr(graph)
-    # Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
-    # "CanonicalizeOps" optimization before quantization.
-    optimize = _transform.Sequential([_transform.SimplifyInference(),
-                                      _transform.FoldConstant(),
-                                      _transform.FoldScaleAxis(),
-                                      _transform.CanonicalizeOps(),
-                                      _transform.FoldConstant()])
-
     calibrate_pass = _transform.function_pass(calibrate, opt_level=1,
                                               name="QuantizeCalibrate")
-    # Quantize pass list
-    quant_passes = [annotate(),
-                    calibrate_pass,
-                    realize(),
-                    _transform.FoldConstant()]
-    if current_qconfig().store_lowbit_output:
-        quant_passes = [rewrite_for_vta()] + quant_passes
+    quant_passes = [partition(),
+                    annotate(),
+                    calibrate_pass]
+    if not current_qconfig().do_simulation:
+        quant_passes.append(realize())
+    quant_passes.append(_transform.FoldConstant())
     quantize_seq = _transform.Sequential(quant_passes)
     with _transform.PassContext(opt_level=3,
                                 required_pass=["QuantizeAnnotate",
                                                "QuantizeCalibrate",
                                                "QuantizeRealize"]):
-        mod = optimize(mod)
-        mod = quantize_seq(mod)
+        with quantize_context():
+            mod = quantize_seq(mod)
 
     return mod["main"]
index a5ade5b..eeacc6c 100644 (file)
@@ -83,13 +83,18 @@ TVM_ADD_FILELINE)
                          return {topi::identity(inputs[0])};
                        });
 
-Expr ForceCast(Expr data) {
-  static const Op& op = Op::Get("annotation.force_cast");
-  return CallNode::make(op, {data}, Attrs{}, {});
+// relay.annotation.cast_hint
+TVM_REGISTER_NODE_TYPE(CastHintAttrs);
+
+Expr CastHint(Expr data, DataType dtype) {
+  auto attrs = make_node<CastHintAttrs>();
+  attrs->dtype = dtype;
+  static const Op& op = Op::Get("annotation.cast_hint");
+  return CallNode::make(op, {data}, Attrs{attrs}, {});
 }
 
-RELAY_REGISTER_OP("annotation.force_cast")
-.describe(R"code(Annotate an expression to force a cast.)code"
+RELAY_REGISTER_OP("annotation.cast_hint")
+.describe(R"code(Annotate an expression to be cast into specific data type.)code"
 TVM_ADD_FILELINE)
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "The input data.")
index 7b896a8..eba77c7 100644 (file)
@@ -66,6 +66,13 @@ class ConstantChecker : private ExprVisitor {
   }
 };
 
+bool ConstantCheck(const Expr& e) {
+  return ConstantChecker().Check(e);
+}
+
+TVM_REGISTER_API("relay._analysis.check_constant")
+.set_body_typed(ConstantCheck);
+
 
 // TODO(tvm-team) consider combine dead-code with constant folder.
 // or make a more powerful partial evaluator.
index 3ccfff0..18e5df3 100644 (file)
@@ -31,6 +31,7 @@
 #include <tvm/data_layout.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/expr.h>
+#include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/attrs/transform.h>
 #include <tvm/relay/attrs/reduce.h>
@@ -420,7 +421,7 @@ Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array
 
 Expr StopFusion(Expr data);
 
-Expr ForceCast(Expr data);
+Expr CastHint(Expr data, DataType dtype);
 
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/pass/quantize/annotate.cc b/src/relay/pass/quantize/annotate.cc
new file mode 100644 (file)
index 0000000..d8a7a0f
--- /dev/null
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ *
+ * \file annotate.cc
+ *
+ * \brief Annotating the graph with simulated quantize operators.
+ */
+
+#include <tvm/relay/transform.h>
+#include <tvm/relay/analysis.h>
+#include "./quantize.h"
+
+namespace tvm {
+namespace relay {
+namespace quantize {
+
+using namespace relay::transform;
+
+class QAnnotateExpr;
+class QAnnotateExprNode : public TempExprNode {
+ public:
+  Expr expr;
+  QAnnotateKind kind;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("expr", &expr);
+    v->Visit("kind", &kind);
+  }
+
+  TVM_DLL static QAnnotateExpr make(Expr expr, QAnnotateKind kind);
+
+  Expr Realize() const final;
+
+  static constexpr const char* _type_key = "relay.QAnnotateExpr";
+  TVM_DECLARE_NODE_TYPE_INFO(QAnnotateExprNode, TempExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(QAnnotateExpr, QAnnotateExprNode, TempExpr);
+
+
+Expr QAnnotateExprNode::Realize() const {
+  return expr;
+}
+
+QAnnotateExpr QAnnotateExprNode::make(Expr expr, QAnnotateKind kind) {
+  auto rnode = make_node<QAnnotateExprNode>();
+  rnode->expr = expr;
+  rnode->kind = kind;
+  return QAnnotateExpr(rnode);
+}
+
+TVM_REGISTER_API("relay._quantize.make_annotate_expr")
+.set_body([](TVMArgs args,  TVMRetValue *ret) {
+    *ret = QAnnotateExprNode::make(args[0],
+      static_cast<QAnnotateKind>(args[1].operator int()));
+  });
+
+
+Pass QuantizeAnnotate() {
+  // TODO(tvm-teams): since partition has added cast_hint in different
+  // branches, try to remove this in the future.
+  std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
+    if (e->derived_from<TempExprNode>()) {
+      const auto* n = e.as<QAnnotateExprNode>();
+      CHECK(n);
+      const PackedFunc* f =
+          runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
+      Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
+      return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
+    }
+    return e;
+  };
+
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+      auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref));
+      auto new_params = func->params;
+      for (const auto& x : FreeVars(func)) {
+        new_params.push_back(x);
+      }
+      return FunctionNode::make(new_params,
+                                func->body,
+                                func->ret_type,
+                                func->type_params,
+                                func->attrs);
+  };
+  return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
+}
+
+TVM_REGISTER_API("relay._quantize.QuantizeAnnotate")
+.set_body_typed(QuantizeAnnotate);
+
+}  // namespace quantize
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/pass/quantize/partition.cc b/src/relay/pass/quantize/partition.cc
new file mode 100644 (file)
index 0000000..3f46cf2
--- /dev/null
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ *
+ * \file partition.cc
+ *
+ * \brief Partition a graph into sections for quantization.
+ */
+
+#include <tvm/relay/transform.h>
+#include "../pattern_util.h"
+#include "./quantize.h"
+
+namespace tvm {
+namespace relay {
+namespace quantize {
+
+using namespace relay::transform;
+
+class QPartitionExpr;
+class QPartitionExprNode : public TempExprNode {
+ public:
+  /*! \brief The original expression */
+  Expr expr;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("expr", &expr);
+  }
+
+  TVM_DLL static QPartitionExpr make(Expr expr);
+
+  Expr Realize() const final;
+
+  static constexpr const char* _type_key = "relay.QPartitionExpr";
+  TVM_DECLARE_NODE_TYPE_INFO(QPartitionExprNode, TempExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(QPartitionExpr, QPartitionExprNode, TempExpr);
+
+
+Expr QPartitionExprNode::Realize() const {
+  // insert cast hint and stop fusion
+  const QConfig& cfg = QConfig::Current();
+  Expr ret = CastHint(this->expr, cfg->dtype_input);
+  return StopFusion(ret);
+}
+
+QPartitionExpr QPartitionExprNode::make(Expr expr) {
+  auto rnode = make_node<QPartitionExprNode>();
+  rnode->expr = expr;
+  return QPartitionExpr(rnode);
+}
+
+TVM_REGISTER_API("relay._quantize.make_partition_expr")
+.set_body([](TVMArgs args,  TVMRetValue *ret) {
+    *ret = QPartitionExprNode::make(args[0]);
+  });
+
+Pass QuantizePartition() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+      auto ret = Downcast<Function>(
+          ForwardRewrite(f, "FQPartitionRewrite", nullptr, nullptr));
+      return ret;
+  };
+  return CreateFunctionPass(pass_func, 1, "QuantizePartition", {});
+}
+
+TVM_REGISTER_API("relay._quantize.QuantizePartition")
+.set_body_typed(QuantizePartition);
+
+}  // namespace quantize
+}  // namespace relay
+}  // namespace tvm
index 6cffc20..c6d71ba 100644 (file)
  *   for compression and acceleration.
  */
 #include <dmlc/thread_local.h>
-#include <tvm/base.h>
-#include <tvm/relay/analysis.h>
-#include <tvm/relay/expr_functor.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/transform.h>
-#include <cmath>
-#include <string>
-#include <vector>
 #include <stack>
-#include <utility>
-#include "../pattern_util.h"
 #include "./quantize.h"
 
 
@@ -44,8 +36,6 @@ namespace tvm {
 namespace relay {
 namespace quantize {
 
-using namespace relay::transform;
-
 TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs);
 
 bool SimulatedQuantizeRel(const Array<Type>& types,
@@ -91,490 +81,6 @@ TVM_REGISTER_API("relay._quantize.simulated_quantize")
   });
 
 
-// =============
-// annotate pass
-
-Expr QAnnotateExprNode::Realize() const {
-  const auto& cfg = QConfig::Current();
-  if (cfg->store_lowbit_output) {
-    // store low bit output back for VTA
-    const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
-    return (*f)(this->expr, static_cast<int>(kQInput));
-  } else {
-    return expr;
-  }
-}
-
-QAnnotateExpr QAnnotateExprNode::make(Expr expr, QAnnotateKind kind) {
-  auto rnode = make_node<QAnnotateExprNode>();
-  rnode->expr = expr;
-  rnode->kind = kind;
-  return QAnnotateExpr(rnode);
-}
-
-TVM_REGISTER_API("relay._quantize.make_annotate_expr")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    *ret = QAnnotateExprNode::make(args[0],
-      static_cast<QAnnotateKind>(args[1].operator int()));
-  });
-
-
-// =============
-// realize pass
-
-Expr QRealizeIntExprNode::Realize() const {
-  const auto& cfg = QConfig::Current();
-  Expr data = this->data;
-  if (cfg->store_lowbit_output) {
-    data = Cast(data, cfg->dtype_input);
-  }
-  // dequantize
-  data = Cast(data, Float(32));
-  data = Multiply(data, this->dom_scale);
-  return data;
-}
-
-QRealizeIntExpr QRealizeIntExprNode::make(Expr data, Expr dom_scale, DataType dtype) {
-  NodePtr<QRealizeIntExprNode> n = make_node<QRealizeIntExprNode>();
-  n->data = std::move(data);
-  n->dom_scale = std::move(dom_scale);
-  n->dtype = std::move(dtype);
-  return QRealizeIntExpr(n);
-}
-
-
-inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
-  return CallNode::make(ref_call->op,
-    args, ref_call->attrs, ref_call->type_args);
-}
-
-
-/* calculate `data * s1 / s2`, use shift if possible */
-inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
-  // here we assume the dtype of data is dtype activation
-  if (s1 == s2) return data;
-
-  float factor = s1 / s2;
-  float shift_factor = std::log2(factor);
-  CHECK_GT(shift_factor, 0);
-  if (static_cast<int>(shift_factor) == shift_factor) {
-    return LeftShift(data, MakeConstantScalar(dtype,
-                                              static_cast<int>(shift_factor)));
-  } else if (static_cast<int>(factor) == factor) {
-    return Multiply(data, MakeConstantScalar(dtype, factor));
-  } else {
-    data = Cast(data, Float(32));
-    data = Multiply(data, MakeConstantScalar(Float(32), factor));
-    return Cast(Round(data), dtype);
-  }
-}
-
-Expr QuantizeRealize(const Call& ref_call,
-                     const Array<Expr>& new_args,
-                     const NodeRef& ctx) {
-  const QConfig& cfg = QConfig::Current();
-  // do not handle data type cast
-  const auto param = ref_call->attrs.as<SimulatedQuantizeAttrs>();
-  CHECK_EQ(param->rounding, "round");
-
-  Expr dom_scale = new_args[1];
-  Expr clip_min = new_args[2];
-  Expr clip_max = new_args[3];
-
-  float dom_scale_imm = GetScalarFromConstant<float>(dom_scale);
-  float clip_min_imm = GetScalarFromConstant<float>(clip_min);
-  float clip_max_imm = GetScalarFromConstant<float>(clip_max);
-
-  // x * idom_scale = y * odom_scale
-  // => y = x * idom_scale / odom_scale
-  if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
-    // int32->int8
-    Expr data = n->data;
-    float idom_scale_imm = GetScalarFromConstant<float>(n->dom_scale);
-    float odom_scale_imm = GetScalarFromConstant<float>(dom_scale);
-    if (idom_scale_imm == odom_scale_imm) {
-      // same domain scale, only clip
-      data = Clip(data, clip_min_imm, clip_max_imm);
-      return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
-    }
-
-    float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm);
-    CHECK_NE(shift_nbit, 0);
-    if (static_cast<int>(shift_nbit) == shift_nbit) {
-      if (shift_nbit > 0) {
-        // use right shift
-        if (cfg->round_for_shift) {
-          float round_bias = std::pow(2.0, shift_nbit - 1);
-          data = Add(data, MakeConstantScalar(cfg->dtype_activation,
-                                              static_cast<int>(round_bias)));
-        }
-        data = RightShift(data, MakeConstantScalar(cfg->dtype_activation,
-                                                   static_cast<int>(shift_nbit)));
-      } else {
-        data = LeftShift(data, MakeConstantScalar(cfg->dtype_activation,
-                                                  static_cast<int>(shift_nbit)));
-      }
-      data = Clip(data, clip_min_imm, clip_max_imm);
-      return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
-    } else {
-      // float computation
-      data = Cast(data, Float(32));
-      Expr scaled_data = Multiply(data, Divide(n->dom_scale, dom_scale));
-      Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
-      return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
-    }
-  }
-
-  // quantize from real
-  CHECK(!new_args[0]->derived_from<TempExprNode>());
-  Expr data = new_args[0];
-  Expr scaled_data = Multiply(data, MakeConstantScalar(Float(32), 1 / dom_scale_imm));
-  Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
-  return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
-}
-
-Expr FoldConstantOpt(const Expr& expr) {
-  auto mod = ModuleNode::FromExpr(expr);
-  mod = transform::FoldConstant()(mod);
-  auto entry_func = mod->Lookup("main");
-  return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
-}
-
-RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", QuantizeRealize);
-
-
-Expr Conv2dRealize(const Call& ref_call,
-                   const Array<Expr>& new_args,
-                   const NodeRef& ctx) {
-  const QConfig& cfg = QConfig::Current();
-  CHECK_EQ(new_args.size(), 2);
-  if (!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>()) {
-    return Expr(nullptr);
-  }
-  const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
-  CHECK(lhs);
-  const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
-  CHECK(rhs);
-
-  Expr ldata = lhs->data;
-  if (lhs->dtype != cfg->dtype_input) {
-    ldata = Cast(ldata, cfg->dtype_input);
-  }
-  Expr rdata = Cast(rhs->data, cfg->dtype_weight);
-
-  const auto ref_attrs = ref_call->attrs.as<Conv2DAttrs>();
-  auto attrs = make_node<Conv2DAttrs>();
-  *attrs = *ref_attrs;
-  DataType out_dtype = cfg->dtype_activation;
-  attrs->out_dtype = out_dtype;
-
-  Expr ret = CallNode::make(ref_call->op,
-    {ldata, rdata}, Attrs(attrs), ref_call->type_args);
-  Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
-  Expr dom_scale = FoldConstantOpt(mul);
-  return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
-}
-
-RELAY_REGISTER_OP("nn.conv2d")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", Conv2dRealize);
-
-
-Expr DenseRealize(const Call& ref_call,
-                  const Array<Expr>& new_args,
-                  const NodeRef& ctx) {
-  const QConfig& cfg = QConfig::Current();
-  CHECK_EQ(new_args.size(), 2);
-  if (!new_args[0]->derived_from<TempExprNode>() || !new_args[1]->derived_from<TempExprNode>()) {
-    return Expr(nullptr);
-  }
-  const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
-  const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
-
-  Expr ldata = lhs->data;
-  if (lhs->dtype != cfg->dtype_input) {
-    ldata = Cast(ldata, cfg->dtype_input);
-  }
-  Expr rdata = Cast(rhs->data, cfg->dtype_weight);
-
-  const auto ref_attrs = ref_call->attrs.as<DenseAttrs>();
-  auto attrs = make_node<DenseAttrs>();
-  *attrs = *ref_attrs;
-  DataType out_dtype = cfg->dtype_activation;
-  attrs->out_dtype = out_dtype;
-
-  Expr ret = CallNode::make(ref_call->op,
-          {ldata, rdata}, Attrs(attrs), ref_call->type_args);
-  Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
-  Expr dom_scale = FoldConstantOpt(mul);
-  return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
-}
-
-RELAY_REGISTER_OP("nn.dense")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", DenseRealize);
-
-
-Expr MulRealize(const Call& ref_call,
-                const Array<Expr>& new_args,
-                const NodeRef& ctx) {
-  const QConfig& cfg = QConfig::Current();
-  CHECK_EQ(new_args.size(), 2);
-  if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
-    // execute the operation with activation data type.
-    const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
-    const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
-    Expr ldata = lhs->data;
-    Expr rdata = rhs->data;
-
-    DataType dtype = cfg->dtype_activation;
-    if (lhs->dtype != dtype) {
-      ldata = Cast(ldata, dtype);
-    }
-    if (rhs->dtype != dtype) {
-      rdata = Cast(rdata, dtype);
-    }
-
-    Expr ret = ForwardOp(ref_call, {ldata, rdata});
-    Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
-    Expr dom_scale = FoldConstantOpt(mul);
-    return QRealizeIntExprNode::make(ret, dom_scale, dtype);
-  }
-  CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
-  return Expr(nullptr);
-}
-
-RELAY_REGISTER_OP("multiply")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", MulRealize);
-
-
-float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
-  if (nptrs.size() == 2) {
-    // x = a * s1, y = b * s2
-    // x + y = (a * s1 / s2 + b) * s2, if s1 > s2
-    //       = (a + b * s2 / s1) * s1, if s2 > s1
-    float s1 = GetScalarFromConstant<float>(nptrs[0]->dom_scale);
-    float s2 = GetScalarFromConstant<float>(nptrs[1]->dom_scale);
-    return s1 > s2 ? s2 : s1;
-  } else {
-    const QConfig& cfg = QConfig::Current();
-    float scale = cfg->global_scale;
-    return scale / std::pow(2.0, cfg->nbit_activation - 1);
-  }
-}
-
-
-/* \brief Unify the dom scale of arguments */
-Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args,
-                            DataType* dtype_ptr, Expr* scale_ptr) {
-  static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
-  const QConfig& cfg = QConfig::Current();
-
-  std::vector<const QRealizeIntExprNode*> nptrs;
-  Array<Expr> ret;
-  for (auto arg : args) {
-    const auto* nptr = arg.as<QRealizeIntExprNode>();
-    CHECK(nptr);
-    nptrs.push_back(nptr);
-    ret.push_back(nptr->data);
-  }
-
-  // unify the data type
-  CHECK_EQ(ref_args.size(), args.size());
-  DataType dtype;
-  if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) {
-    dtype = cfg->dtype_input;
-  } else {
-    dtype = cfg->dtype_activation;
-  }
-  for (size_t i = 0; i < ret.size(); ++i) {
-    auto ref_arg = ref_args[i].as<CallNode>();
-    if (nptrs[i]->dtype != dtype) {
-      ret.Set(i, Cast(ret[i], dtype));
-    } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
-               ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
-      auto new_arg = Cast(ret[i], cfg->dtype_input);
-      new_arg = StopFusion(new_arg);
-      ret.Set(i, Cast(new_arg, dtype));
-    }
-  }
-
-  // unify the dom_scale
-  float s = ChooseDomScale(nptrs);
-  Expr dom_scale = MakeConstantScalar(Float(32), s);
-  for (size_t i = 0; i < ret.size(); ++i) {
-    float cur_s = GetScalarFromConstant<float>(nptrs[i]->dom_scale);
-    ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype));
-  }
-
-  *dtype_ptr = dtype;
-  *scale_ptr = dom_scale;
-  return ret;
-}
-
-Expr AddRealize(const Call& ref_call,
-                const Array<Expr>& new_args,
-                const NodeRef& ctx) {
-  CHECK_EQ(new_args.size(), 2);
-  if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
-    DataType dtype;
-    Expr dom_scale;
-    Array<Expr> ret_args = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale);
-    Expr ret = ForwardOp(ref_call, ret_args);
-    return QRealizeIntExprNode::make(ret, dom_scale, dtype);
-  }
-
-  CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
-  return Expr(nullptr);
-}
-
-RELAY_REGISTER_OP("add")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", AddRealize);
-
-Expr ClipRealize(const Call& ref_call,
-                 const Array<Expr>& new_args,
-                 const NodeRef& ctx) {
-  CHECK_EQ(new_args.size(), 1);
-  if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
-    const auto ref_attrs = ref_call->attrs.as<ClipAttrs>();
-    auto attrs = make_node<ClipAttrs>();
-    double dom_scale = GetScalarFromConstant<float>(n->dom_scale);
-    attrs->a_min = ref_attrs->a_min / dom_scale;
-    attrs->a_max = ref_attrs->a_max / dom_scale;
-
-    Expr ret = CallNode::make(ref_call->op,
-      {n->data}, Attrs(attrs), ref_call->type_args);
-    return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype);
-  }
-  CHECK(!new_args[0]->derived_from<TempExprNode>());
-  return Expr(nullptr);
-}
-
-RELAY_REGISTER_OP("clip")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", ClipRealize);
-
-
-Expr ConcatenateRealize(const Call& ref_call,
-                        const Array<Expr>& new_args,
-                        const NodeRef& ctx) {
-  CHECK_EQ(new_args.size(), 1);
-  CHECK_EQ(ref_call->args.size(), 1);
-
-  const auto* tuple = new_args[0].as<TupleNode>();
-  const auto* ref_tuple = ref_call->args[0].as<TupleNode>();
-  CHECK(tuple);
-  CHECK(ref_tuple);
-  const Array<Expr>& arr = tuple->fields;
-  const Array<Expr>& ref_arr = ref_tuple->fields;
-
-  if (arr[0].as<QRealizeIntExprNode>()) {
-    DataType dtype;
-    Expr dom_scale;
-    Array<Expr> ret_args = UnifyDTypeScale(ref_arr, arr, &dtype, &dom_scale);
-    Expr ret = ForwardOp(ref_call, {TupleNode::make(ret_args)});
-    return QRealizeIntExprNode::make(ret, dom_scale, dtype);
-  } else {
-    for (auto arg : new_args) {
-      CHECK(!arg->derived_from<TempExprNode>());
-    }
-    return Expr(nullptr);
-  }
-}
-
-RELAY_REGISTER_OP("concatenate")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", ConcatenateRealize);
-
-
-/* \brief forward the original operator */
-Expr IdentityRealize(const Call& ref_call,
-                     const Array<Expr>& new_args,
-                     const NodeRef& ctx) {
-  CHECK_EQ(new_args.size(), 1);
-  if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
-    Expr ret = ForwardOp(ref_call, {n->data});
-    return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype);
-  }
-  CHECK(!new_args[0]->derived_from<TempExprNode>());
-  return Expr(nullptr);
-}
-
-RELAY_REGISTER_OP("nn.relu")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
-
-RELAY_REGISTER_OP("strided_slice")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
-
-RELAY_REGISTER_OP("annotation.stop_fusion")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
-
-/* \brief for unary operators which requantize its input to dtype_nbit */
-Expr CastDtypeInputRealize(const Call& ref_call,
-                           const Array<Expr>& new_args,
-                           const NodeRef& ctx) {
-  const QConfig& cfg = QConfig::Current();
-  CHECK_EQ(new_args.size(), 1);
-  if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
-    Expr data = Cast(n->data, cfg->dtype_input);
-    Expr ret = ForwardOp(ref_call, {data});
-    return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_input);
-  }
-  CHECK(!new_args[0]->derived_from<TempExprNode>());
-  return Expr(nullptr);
-}
-
-RELAY_REGISTER_OP("nn.max_pool2d")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);
-
-
-Expr AvgPoolRealize(const Call& ref_call,
-                    const Array<Expr>& new_args,
-                    const NodeRef& ctx) {
-  const QConfig& cfg = QConfig::Current();
-  CHECK_EQ(new_args.size(), 1);
-  if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
-    Expr data = n->data;
-    if (n->dtype != cfg->dtype_activation) {
-      data = Cast(n->data, cfg->dtype_activation);
-    }
-    Expr ret = ForwardOp(ref_call, {data});
-    return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_activation);
-  }
-  CHECK(!new_args[0]->derived_from<TempExprNode>());
-  return Expr(nullptr);
-}
-
-RELAY_REGISTER_OP("nn.avg_pool2d")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
-
-Expr ForceCastRealize(const Call& ref_call,
-                      const Array<Expr>& new_args,
-                      const NodeRef& ctx) {
-  const QConfig& cfg = QConfig::Current();
-  CHECK_EQ(new_args.size(), 1);
-  if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
-    Expr ret = Cast(n->data, cfg->dtype_input);
-    return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_input);
-  }
-  CHECK(!new_args[0]->derived_from<TempExprNode>());
-  return Expr(nullptr);
-}
-
-RELAY_REGISTER_OP("annotation.force_cast")
-.set_attr<FForwardRewrite>("FQRealizeRewrite", ForceCastRealize);
-
-TVM_REGISTER_API("relay._quantize.realize")
-.set_body_typed<Expr(Expr)>([](const Expr& e) {
-  Expr ret = ForwardRewrite(e, "FQRealizeRewrite", nullptr, nullptr);
-  return ret;
-});
-
-
-// =============
-// qconfig
-
-QConfig qconfig() {
-  return QConfig(make_node<QConfigNode>());
-}
-
 /*! \brief Entry to hold the BuildConfig context stack. */
 struct TVMQConfigThreadLocalEntry {
   /*! \brief The default build config if the stack is empty */
@@ -584,7 +90,7 @@ struct TVMQConfigThreadLocalEntry {
   std::stack<QConfig> context_stack;
 
   TVMQConfigThreadLocalEntry() :
-    default_config(qconfig()) {
+    default_config(make_node<QConfigNode>()) {
   }
 };
 
@@ -620,8 +126,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   p->stream << "nbit_activation=" << op->nbit_activation << ", ";
   p->stream << "global_scale=" << op->global_scale << ", ";
   p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", ";
+  p->stream << "do_simulation==" << op->do_simulation << ", ";
   p->stream << "round_for_shift==" << op->round_for_shift << ", ";
-  p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", ";
   p->stream << "debug_enabled_ops==" << op->debug_enabled_ops;
   p->stream << ")";
 });
@@ -635,95 +141,6 @@ TVM_REGISTER_API("relay._quantize._EnterQConfigScope")
 TVM_REGISTER_API("relay._quantize._ExitQConfigScope")
 .set_body_typed(QConfig::ExitQConfigScope);
 
-Pass QuantizeAnnotate() {
-  std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
-    if (e->derived_from<TempExprNode>()) {
-      const auto* n = e.as<QAnnotateExprNode>();
-      CHECK(n);
-      const PackedFunc* f =
-          runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
-      Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
-      return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
-    }
-    return e;
-  };
-
-  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
-    [=](Function f, Module m, PassContext pc) {
-      auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref));
-      auto new_params = func->params;
-      for (const auto& x : FreeVars(func)) {
-        new_params.push_back(x);
-      }
-      return FunctionNode::make(new_params,
-                                func->body,
-                                func->ret_type,
-                                func->type_params,
-                                func->attrs);
-  };
-  return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
-}
-
-TVM_REGISTER_API("relay._quantize.QuantizeAnnotate")
-.set_body_typed(QuantizeAnnotate);
-
-Pass QuantizeRealizePass() {
-  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
-    [=](Function f, Module m, PassContext pc) {
-      return Downcast<Function>(
-          ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr));
-  };
-  return CreateFunctionPass(pass_func, 1, "QuantizeRealize", {});
-}
-
-TVM_REGISTER_API("relay._quantize.QuantizeRealize")
-.set_body_typed(QuantizeRealizePass);
-
-Pass QuantizeRewriteForVTAPass() {
-  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
-    [=](Function f, Module m, PassContext pc) {
-      return Downcast<Function>(
-          ForwardRewrite(f, "FQVTARewrite", nullptr, nullptr));
-  };
-  return CreateFunctionPass(pass_func, 1, "QuantizeRewriteForVTA", {});
-}
-
-TVM_REGISTER_API("relay._quantize.QuantizeRewriteForVTA")
-.set_body_typed(QuantizeRewriteForVTAPass);
-
-// =============
-// Insert stop_fusion for vta.
-
-
-Expr QVTAExprNode::Realize() const {
-  Expr ret = ForceCast(this->expr);
-  return StopFusion(ret);
-}
-
-QVTAExpr QVTAExprNode::make(Expr expr) {
-  auto rnode = make_node<QVTAExprNode>();
-  rnode->expr = expr;
-  return QVTAExpr(rnode);
-}
-
-TVM_REGISTER_API("relay._quantize.make_vta_expr")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    *ret = QVTAExprNode::make(args[0]);
-  });
-
-TVM_REGISTER_API("relay._quantize.make_stop_fusion")
-.set_body_typed<Expr(Expr)>([] (const Expr& expr) {
-  return StopFusion(expr);
-});
-
-TVM_REGISTER_API("relay._quantize.temp_expr_realize")
-.set_body_typed<Expr(Expr)>([] (const Expr& expr) {
-  const QVTAExprNode* n = expr.as<QVTAExprNode>();
-  CHECK(n);
-  return n->Realize();
-});
-
-
 }  // namespace quantize
 }  // namespace relay
 }  // namespace tvm
index 4965a70..4c153d5 100644 (file)
@@ -59,104 +59,8 @@ struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
   }
 };
 
-/*!
- * \brief TempExpr used during annotate forward rewrite.
- */
-class QAnnotateExpr;
-/*!
- * \brief TempExprNode used during annotate forward rewrite.
- */
-class QAnnotateExprNode : public TempExprNode {
- public:
-  /*! \brief The original expression */
-  Expr expr;
-  /*! \brief The kind of annotate field */
-  QAnnotateKind kind;
-
-  void VisitAttrs(tvm::AttrVisitor* v) final {
-    v->Visit("expr", &expr);
-    v->Visit("kind", &kind);
-  }
-
-  TVM_DLL static QAnnotateExpr make(Expr expr, QAnnotateKind kind);
-
-  Expr Realize() const final;
-
-  static constexpr const char* _type_key = "relay.QAnnotateExpr";
-  TVM_DECLARE_NODE_TYPE_INFO(QAnnotateExprNode, TempExprNode);
-};
-
-RELAY_DEFINE_NODE_REF(QAnnotateExpr, QAnnotateExprNode, TempExpr);
-
-
-/*!
- * \brief TempExpr used to insert `force_cast` for VTA.
- */
-class QVTAExpr;
-/*!
- * \brief TempExprNode used to insert `force_cast` for VTA.
- */
-class QVTAExprNode : public TempExprNode {
- public:
-  /*! \brief The original expression */
-  Expr expr;
-
-  void VisitAttrs(tvm::AttrVisitor* v) final {
-    v->Visit("expr", &expr);
-  }
-
-  TVM_DLL static QVTAExpr make(Expr expr);
-
-  Expr Realize() const final;
-
-  static constexpr const char* _type_key = "relay.QVTAExpr";
-  TVM_DECLARE_NODE_TYPE_INFO(QVTAExprNode, TempExprNode);
-};
-
-RELAY_DEFINE_NODE_REF(QVTAExpr, QVTAExprNode, TempExpr);
-
-
-/*! \brief TempExpr used during realize forward rewrite. */
-class QRealizeExpr;
-/*! \brief TempExpr representing integer. */
-class QRealizeIntExpr;
-
-class QRealizeExprNode : public TempExprNode {
- public:
-  /*! \brief The original expression */
-  Expr data;
-  static constexpr const char* _type_key = "relay.quantize.QRealizeExpr";
-  TVM_DECLARE_BASE_NODE_INFO(QRealizeExprNode, TempExprNode);
-};
-
-RELAY_DEFINE_NODE_REF(QRealizeExpr, QRealizeExprNode, TempExpr);
-
-
-class QRealizeIntExprNode : public QRealizeExprNode {
- public:
-  Expr dom_scale;
-  /*! \brief current data type */
-  DataType dtype;
-
-  void VisitAttrs(tvm::AttrVisitor* v) final {
-    v->Visit("data", &data);
-    v->Visit("dom_scale", &dom_scale);
-    v->Visit("dtype", &dtype);
-  }
-
-  Expr Realize() const final;
-
-  TVM_DLL static QRealizeIntExpr make(Expr data, Expr dom_scale, DataType dtype);
-
-  static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr";
-  TVM_DECLARE_NODE_TYPE_INFO(QRealizeIntExprNode, QRealizeExprNode);
-};
-
-RELAY_DEFINE_NODE_REF(QRealizeIntExpr, QRealizeIntExprNode, QRealizeExpr);
-
 
 class QConfig;
-
 /*!
 * \brief Container for build configuration options
 */
@@ -170,8 +74,8 @@ class QConfigNode : public Node {
   DataType dtype_activation = Int(32);
   double global_scale = 8.0;
   Array<Expr> skip_conv_layers = Array<Expr>(NodePtr<Node>(nullptr));
+  bool do_simulation = false;
   bool round_for_shift = true;
-  bool store_lowbit_output = true;
   Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
 
   void VisitAttrs(AttrVisitor* v) final {
@@ -183,8 +87,8 @@ class QConfigNode : public Node {
     v->Visit("dtype_activation", &dtype_activation);
     v->Visit("global_scale", &global_scale);
     v->Visit("skip_conv_layers", &skip_conv_layers);
+    v->Visit("do_simulation", &do_simulation);
     v->Visit("round_for_shift", &round_for_shift);
-    v->Visit("store_lowbit_output", &store_lowbit_output);
     v->Visit("debug_enabled_ops", &debug_enabled_ops);
   }
 
@@ -250,12 +154,6 @@ struct QConfigContext {
   }
 };
 
-/*!
-* \brief Construct a BuildConfig containing a new BuildConfigNode
-* \return The new BuildConfig
-*/
-TVM_DLL QConfig qconfig();
-
 }  // namespace quantize
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/pass/quantize/realize.cc b/src/relay/pass/quantize/realize.cc
new file mode 100644 (file)
index 0000000..e4bc63a
--- /dev/null
@@ -0,0 +1,525 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ *
+ * \file realize.cc
+ *
+ * \brief Realizing the simulated graph into real low-precision
+ *   graph.
+ */
+
+#include <tvm/relay/transform.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/annotation.h>
+#include "./quantize.h"
+#include "../pattern_util.h"
+
+namespace tvm {
+namespace relay {
+namespace quantize {
+
+using namespace relay::transform;
+
+class QRealizeExpr;
+class QRealizeIntExpr;
+
+class QRealizeExprNode : public TempExprNode {
+ public:
+  Expr data;
+  static constexpr const char* _type_key = "relay.quantize.QRealizeExpr";
+  TVM_DECLARE_BASE_NODE_INFO(QRealizeExprNode, TempExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(QRealizeExpr, QRealizeExprNode, TempExpr);
+
+
+class QRealizeIntExprNode : public QRealizeExprNode {
+ public:
+  Expr dom_scale;
+  DataType dtype;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("data", &data);
+    v->Visit("dom_scale", &dom_scale);
+    v->Visit("dtype", &dtype);
+  }
+
+  Expr Realize() const final;
+
+  TVM_DLL static QRealizeIntExpr make(Expr data, Expr dom_scale, DataType dtype);
+
+  static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr";
+  TVM_DECLARE_NODE_TYPE_INFO(QRealizeIntExprNode, QRealizeExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(QRealizeIntExpr, QRealizeIntExprNode, QRealizeExpr);
+
+
+Expr QRealizeIntExprNode::Realize() const {
+  Expr data = this->data;
+  // dequantize
+  data = Cast(data, Float(32));
+  data = Multiply(data, this->dom_scale);
+  return data;
+}
+
+QRealizeIntExpr QRealizeIntExprNode::make(Expr data, Expr dom_scale, DataType dtype) {
+  NodePtr<QRealizeIntExprNode> n = make_node<QRealizeIntExprNode>();
+  n->data = std::move(data);
+  n->dom_scale = std::move(dom_scale);
+  n->dtype = std::move(dtype);
+  return QRealizeIntExpr(n);
+}
+
+
+inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
+  return CallNode::make(ref_call->op,
+    args, ref_call->attrs, ref_call->type_args);
+}
+
+
+/* calculate `data * s1 / s2`, use shift if possible */
+inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
+  // here we assume the dtype of data is dtype activation
+  if (s1 == s2) return data;
+
+  float factor = s1 / s2;
+  float shift_factor = std::log2(factor);
+  CHECK_GT(shift_factor, 0);
+  if (static_cast<int>(shift_factor) == shift_factor) {
+    return LeftShift(data, MakeConstantScalar(dtype,
+                                              static_cast<int>(shift_factor)));
+  } else if (static_cast<int>(factor) == factor) {
+    return Multiply(data, MakeConstantScalar(dtype, factor));
+  } else {
+    LOG(FATAL) << "fall back to float computation";
+    data = Cast(data, Float(32));
+    data = Multiply(data, MakeConstantScalar(Float(32), factor));
+    return Cast(Round(data), dtype);
+  }
+}
+
+Expr QuantizeRealize(const Call& ref_call,
+                     const Array<Expr>& new_args,
+                     const NodeRef& ctx) {
+  const QConfig& cfg = QConfig::Current();
+  // do not handle data type cast
+  const auto param = ref_call->attrs.as<SimulatedQuantizeAttrs>();
+  CHECK_EQ(param->rounding, "round");
+
+  Expr dom_scale = new_args[1];
+  Expr clip_min = new_args[2];
+  Expr clip_max = new_args[3];
+
+  float dom_scale_imm = GetScalarFromConstant<float>(dom_scale);
+  float clip_min_imm = GetScalarFromConstant<float>(clip_min);
+  float clip_max_imm = GetScalarFromConstant<float>(clip_max);
+
+  // x * idom_scale = y * odom_scale
+  // => y = x * idom_scale / odom_scale
+  if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
+    // int32->int8
+    Expr data = n->data;
+    float idom_scale_imm = GetScalarFromConstant<float>(n->dom_scale);
+    float odom_scale_imm = GetScalarFromConstant<float>(dom_scale);
+    if (idom_scale_imm == odom_scale_imm) {
+      // same domain scale, only clip
+      data = Clip(data, clip_min_imm, clip_max_imm);
+      return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
+    }
+
+    float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm);
+    CHECK_GT(shift_nbit, 0);
+    if (static_cast<int>(shift_nbit) == shift_nbit) {
+      // use right shift
+      if (cfg->round_for_shift) {
+        float round_bias = std::pow(2.0, shift_nbit - 1);
+        data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast<int>(round_bias)));
+      }
+      data = RightShift(data, MakeConstantScalar(cfg->dtype_activation,
+                                                 static_cast<int>(shift_nbit)));
+      data = Clip(data, clip_min_imm, clip_max_imm);
+      return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
+    } else {
+      // float computation
+      data = Cast(data, Float(32));
+      Expr scaled_data = Multiply(data, Divide(n->dom_scale, dom_scale));
+      Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
+      return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
+    }
+  }
+
+  // quantize from real
+  CHECK(!new_args[0]->derived_from<TempExprNode>());
+  Expr data = new_args[0];
+  Expr scaled_data = Multiply(data, MakeConstantScalar(Float(32), 1 / dom_scale_imm));
+  Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
+  return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
+}
+
+Expr FoldConstantOpt(const Expr& expr) {
+  auto mod = ModuleNode::FromExpr(expr);
+  mod = transform::FoldConstant()(mod);
+  auto entry_func = mod->Lookup("main");
+  return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
+}
+
+RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", QuantizeRealize);
+
+
+Expr Conv2dRealize(const Call& ref_call,
+                   const Array<Expr>& new_args,
+                   const NodeRef& ctx) {
+  const QConfig& cfg = QConfig::Current();
+  CHECK_EQ(new_args.size(), 2);
+  if (!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>()) {
+    return Expr(nullptr);
+  }
+  const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
+  CHECK(lhs);
+  const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
+  CHECK(rhs);
+
+  Expr ldata = lhs->data;
+  if (lhs->dtype != cfg->dtype_input) {
+    ldata = Cast(ldata, cfg->dtype_input);
+  }
+  Expr rdata = Cast(rhs->data, cfg->dtype_weight);
+
+  const auto ref_attrs = ref_call->attrs.as<Conv2DAttrs>();
+  auto attrs = make_node<Conv2DAttrs>();
+  *attrs = *ref_attrs;
+  DataType out_dtype = cfg->dtype_activation;
+  attrs->out_dtype = out_dtype;
+
+  Expr ret = CallNode::make(ref_call->op,
+    {ldata, rdata}, Attrs(attrs), ref_call->type_args);
+  Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
+  Expr dom_scale = FoldConstantOpt(mul);
+  return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
+}
+
+RELAY_REGISTER_OP("nn.conv2d")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", Conv2dRealize);
+
+
+Expr DenseRealize(const Call& ref_call,
+                  const Array<Expr>& new_args,
+                  const NodeRef& ctx) {
+  const QConfig& cfg = QConfig::Current();
+  CHECK_EQ(new_args.size(), 2);
+  if (!new_args[0]->derived_from<TempExprNode>() || !new_args[1]->derived_from<TempExprNode>()) {
+    return Expr(nullptr);
+  }
+  const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
+  const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
+
+  Expr ldata = lhs->data;
+  if (lhs->dtype != cfg->dtype_input) {
+    ldata = Cast(ldata, cfg->dtype_input);
+  }
+  Expr rdata = Cast(rhs->data, cfg->dtype_weight);
+
+  const auto ref_attrs = ref_call->attrs.as<DenseAttrs>();
+  auto attrs = make_node<DenseAttrs>();
+  *attrs = *ref_attrs;
+  DataType out_dtype = cfg->dtype_activation;
+  attrs->out_dtype = out_dtype;
+
+  Expr ret = CallNode::make(ref_call->op,
+          {ldata, rdata}, Attrs(attrs), ref_call->type_args);
+  Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
+  Expr dom_scale = FoldConstantOpt(mul);
+  return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
+}
+
+RELAY_REGISTER_OP("nn.dense")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", DenseRealize);
+
+
+Expr MulRealize(const Call& ref_call,
+                const Array<Expr>& new_args,
+                const NodeRef& ctx) {
+  const QConfig& cfg = QConfig::Current();
+  CHECK_EQ(new_args.size(), 2);
+  if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
+    // execute the operation with activation data type.
+    const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
+    const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
+    Expr ldata = lhs->data;
+    Expr rdata = rhs->data;
+
+    DataType dtype = cfg->dtype_activation;
+    if (lhs->dtype != dtype) {
+      ldata = Cast(ldata, dtype);
+    } else {
+      CHECK_EQ(lhs->dtype, dtype);
+    }
+    if (rhs->dtype != dtype) {
+      rdata = Cast(rdata, dtype);
+    } else {
+      CHECK_EQ(rhs->dtype, dtype);
+    }
+
+    Expr ret = ForwardOp(ref_call, {ldata, rdata});
+    Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
+    Expr dom_scale = FoldConstantOpt(mul);
+    return QRealizeIntExprNode::make(ret, dom_scale, dtype);
+  }
+  CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
+  return Expr(nullptr);
+}
+
+RELAY_REGISTER_OP("multiply")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", MulRealize);
+
+
+float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
+  if (nptrs.size() == 2) {
+    // x = a * s1, y = b * s2
+    // x + y = (a * s1 / s2 + b) * s2, if s1 > s2
+    //       = (a + b * s2 / s1) * s1, if s2 > s1
+    float s1 = GetScalarFromConstant<float>(nptrs[0]->dom_scale);
+    float s2 = GetScalarFromConstant<float>(nptrs[1]->dom_scale);
+    return s1 > s2 ? s2 : s1;
+  } else {
+    const QConfig& cfg = QConfig::Current();
+    float scale = cfg->global_scale;
+    return scale / std::pow(2.0, cfg->nbit_activation - 1);
+  }
+}
+
+
+/* \brief Unify the dom scale of arguments */
+Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args,
+                            DataType* dtype_ptr, Expr* scale_ptr) {
+  static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
+  const QConfig& cfg = QConfig::Current();
+
+  std::vector<const QRealizeIntExprNode*> nptrs;
+  Array<Expr> ret;
+  for (auto arg : args) {
+    const auto* nptr = arg.as<QRealizeIntExprNode>();
+    CHECK(nptr);
+    nptrs.push_back(nptr);
+    ret.push_back(nptr->data);
+  }
+
+  // unify the data type
+  CHECK_EQ(ref_args.size(), args.size());
+  DataType dtype;
+
+  if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) {
+    dtype = cfg->dtype_input;
+  } else {
+    dtype = cfg->dtype_activation;
+  }
+  for (size_t i = 0; i < ret.size(); ++i) {
+    auto ref_arg = ref_args[i].as<CallNode>();
+    if (nptrs[i]->dtype != dtype) {
+      ret.Set(i, Cast(ret[i], dtype));
+    } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
+               ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
+      auto new_arg = Cast(ret[i], cfg->dtype_input);
+      new_arg = StopFusion(new_arg);
+      ret.Set(i, Cast(new_arg, dtype));
+    }
+  }
+
+  // unify the dom_scale
+  float s = ChooseDomScale(nptrs);
+  Expr dom_scale = MakeConstantScalar(Float(32), s);
+  for (size_t i = 0; i < ret.size(); ++i) {
+    float cur_s = GetScalarFromConstant<float>(nptrs[i]->dom_scale);
+    ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype));
+  }
+
+  *dtype_ptr = dtype;
+  *scale_ptr = dom_scale;
+  return ret;
+}
+
+Expr AddRealize(const Call& ref_call,
+                const Array<Expr>& new_args,
+                const NodeRef& ctx) {
+  CHECK_EQ(new_args.size(), 2);
+  if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
+    DataType dtype;
+    Expr dom_scale;
+    Array<Expr> ret_args = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale);
+    Expr ret = ForwardOp(ref_call, ret_args);
+    return QRealizeIntExprNode::make(ret, dom_scale, dtype);
+  }
+
+  CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
+  return Expr(nullptr);
+}
+
+RELAY_REGISTER_OP("add")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", AddRealize);
+
+Expr ClipRealize(const Call& ref_call,
+                 const Array<Expr>& new_args,
+                 const NodeRef& ctx) {
+  CHECK_EQ(new_args.size(), 1);
+  if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
+    const auto ref_attrs = ref_call->attrs.as<ClipAttrs>();
+    auto attrs = make_node<ClipAttrs>();
+    double dom_scale = GetScalarFromConstant<float>(n->dom_scale);
+    attrs->a_min = ref_attrs->a_min / dom_scale;
+    attrs->a_max = ref_attrs->a_max / dom_scale;
+
+    Expr ret = CallNode::make(ref_call->op,
+      {n->data}, Attrs(attrs), ref_call->type_args);
+    return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype);
+  }
+  CHECK(!new_args[0]->derived_from<TempExprNode>());
+  return Expr(nullptr);
+}
+
+RELAY_REGISTER_OP("clip")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", ClipRealize);
+
+
+Expr ConcatenateRealize(const Call& ref_call,
+                        const Array<Expr>& new_args,
+                        const NodeRef& ctx) {
+  CHECK_EQ(new_args.size(), 1);
+  CHECK_EQ(ref_call->args.size(), 1);
+
+  const auto* tuple = new_args[0].as<TupleNode>();
+  const auto* ref_tuple = ref_call->args[0].as<TupleNode>();
+  CHECK(tuple);
+  CHECK(ref_tuple);
+  const Array<Expr>& arr = tuple->fields;
+  const Array<Expr>& ref_arr = ref_tuple->fields;
+
+  if (arr[0].as<QRealizeIntExprNode>()) {
+    DataType dtype;
+    Expr dom_scale;
+    Array<Expr> ret_args = UnifyDTypeScale(ref_arr, arr, &dtype, &dom_scale);
+    Expr ret = ForwardOp(ref_call, {TupleNode::make(ret_args)});
+    return QRealizeIntExprNode::make(ret, dom_scale, dtype);
+  } else {
+    for (auto arg : new_args) {
+      CHECK(!arg->derived_from<TempExprNode>());
+    }
+    return Expr(nullptr);
+  }
+}
+
+RELAY_REGISTER_OP("concatenate")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", ConcatenateRealize);
+
+
+/* \brief forward the original operator */
+Expr IdentityRealize(const Call& ref_call,
+                     const Array<Expr>& new_args,
+                     const NodeRef& ctx) {
+  CHECK_EQ(new_args.size(), 1);
+  if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
+    Expr ret = ForwardOp(ref_call, {n->data});
+    return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype);
+  }
+  CHECK(!new_args[0]->derived_from<TempExprNode>());
+  return Expr(nullptr);
+}
+
+RELAY_REGISTER_OP("nn.relu")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
+
+RELAY_REGISTER_OP("strided_slice")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
+
+RELAY_REGISTER_OP("annotation.stop_fusion")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
+
+/* \brief for unary operators which requantize its input to dtype_nbit */
+Expr CastDtypeInputRealize(const Call& ref_call,
+                           const Array<Expr>& new_args,
+                           const NodeRef& ctx) {
+  const QConfig& cfg = QConfig::Current();
+  CHECK_EQ(new_args.size(), 1);
+  if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
+    Expr data = Cast(n->data, cfg->dtype_input);
+    Expr ret = ForwardOp(ref_call, {data});
+    return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_input);
+  }
+  CHECK(!new_args[0]->derived_from<TempExprNode>());
+  return Expr(nullptr);
+}
+
+RELAY_REGISTER_OP("nn.max_pool2d")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);
+
+
+Expr AvgPoolRealize(const Call& ref_call,
+                    const Array<Expr>& new_args,
+                    const NodeRef& ctx) {
+  const QConfig& cfg = QConfig::Current();
+  CHECK_EQ(new_args.size(), 1);
+  if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
+    Expr data = n->data;
+    if (n->dtype != cfg->dtype_activation) {
+      data = Cast(n->data, cfg->dtype_activation);
+    }
+    Expr ret = ForwardOp(ref_call, {data});
+    return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_activation);
+  }
+  CHECK(!new_args[0]->derived_from<TempExprNode>());
+  return Expr(nullptr);
+}
+
+RELAY_REGISTER_OP("nn.avg_pool2d")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
+
+Expr CastHintRealize(const Call& ref_call,
+                     const Array<Expr>& new_args,
+                     const NodeRef& ctx) {
+  const auto param = ref_call->attrs.as<CastHintAttrs>();
+  CHECK_EQ(new_args.size(), 1);
+  if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
+    Expr ret = Cast(n->data, param->dtype);
+    return QRealizeIntExprNode::make(ret, n->dom_scale, param->dtype);
+  }
+  CHECK(!new_args[0]->derived_from<TempExprNode>());
+  return Expr(nullptr);
+}
+
+RELAY_REGISTER_OP("annotation.cast_hint")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", CastHintRealize);
+
+Pass QuantizeRealizePass() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+      return Downcast<Function>(
+          ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr));
+  };
+  return CreateFunctionPass(pass_func, 1, "QuantizeRealize", {});
+}
+
+TVM_REGISTER_API("relay._quantize.QuantizeRealize")
+.set_body_typed(QuantizeRealizePass);
+
+}  // namespace quantize
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/nightly/quantization/test_quantization_accuracy.py b/tests/python/nightly/quantization/test_quantization_accuracy.py
new file mode 100644 (file)
index 0000000..f047952
--- /dev/null
@@ -0,0 +1,153 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from collections import namedtuple
+import tvm
+from tvm import relay
+from tvm.relay import quantize as qtz
+import mxnet as mx
+from mxnet import gluon
+import logging
+import os
+
+logging.basicConfig(level=logging.INFO)
+
+Config = namedtuple('Config', ['model', 'nbit_input',  'dtype_input', 'nbit_output', 'dtype_output', 'global_scale', 'expected_acc'])
+
+
+def get_val_data(model_name,
+                 rec_val,
+                 batch_size,
+                 num_workers=4):
+    rec_val = os.path.expanduser(rec_val)
+    mean_rgb = [123.68, 116.779, 103.939]
+    std_rgb = [58.393, 57.12, 57.375]
+    def batch_fn(batch, ctx):
+        data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
+        label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
+        return data, label
+
+    img_size = 299 if model_name == 'inceptionv3' else 224
+    val_data = mx.io.ImageRecordIter(
+        path_imgrec         = rec_val,
+        preprocess_threads  = num_workers,
+        shuffle             = False,
+        batch_size          = batch_size,
+        resize              = 256,
+        data_shape          = (3, img_size, img_size),
+        mean_r              = mean_rgb[0],
+        mean_g              = mean_rgb[1],
+        mean_b              = mean_rgb[2],
+        std_r               = std_rgb[0],
+        std_g               = std_rgb[1],
+        std_b               = std_rgb[2],
+    )
+    return val_data, batch_fn
+
+
+def get_model(model_name, batch_size, qconfig, target=None, original=False, simulated=False):
+    gluon_model = gluon.model_zoo.vision.get_model(model_name, pretrained=True)
+    img_size = 299 if model_name == 'inceptionv3' else 224
+    data_shape = (batch_size, 3, img_size, img_size)
+    mod, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
+    net = mod['main']
+
+    with relay.build_config(opt_level=3):
+        qfunc = relay.quantize.prerequisite_optimize(net, params=params)
+    logging.debug('original')
+    logging.debug(qfunc.astext(show_meta_data=False))
+    if original:
+        return qfunc
+
+    with qconfig:
+        logging.debug('current quantize config')
+        logging.debug(qtz.current_qconfig())
+        qfunc = qtz.quantize(qfunc)
+        logging.debug('after quantize')
+        logging.debug(qfunc.astext(show_meta_data=False))
+    return qfunc
+
+
+def eval_acc(model, dataset, batch_fn, target=tvm.target.cuda(), ctx=tvm.gpu(), log_interval=100):
+    with relay.build_config(opt_level=3):
+        graph, lib, params = relay.build(model, target)
+    # create runtime module
+    m = tvm.contrib.graph_runtime.create(graph, lib, ctx)
+    m.set_input(**params)
+
+    # setup evaluaiton metric
+    dataset.reset()
+    batch_size = dataset.batch_size
+    acc_top1 = mx.metric.Accuracy()
+    acc_top5 = mx.metric.TopKAccuracy(5)
+    acc_top1.reset()
+    acc_top5.reset()
+    # Execute
+    for i, batch in enumerate(dataset):
+        data, label = batch_fn(batch, [mx.cpu(0)])
+        m.run(data=data[0].asnumpy())
+        out_arr = m.get_output(0)
+        acc_top1.update(label, [mx.nd.array(out_arr.asnumpy())])
+        acc_top5.update(label, [mx.nd.array(out_arr.asnumpy())])
+
+        if not (i + 1) % log_interval:
+            _, top1 = acc_top1.get()
+            _, top5 = acc_top5.get()
+            nsamples = (i + 1) * batch_size
+            logging.info('[%d samples] validation: acc-top1=%f acc-top5=%f', nsamples, top1, top5)
+    logging.info('[final] validation: acc-top1=%f acc-top5=%f', top1, top5)
+    return top1
+
+def test_quantize_acc(cfg, rec_val):
+    qconfig = qtz.qconfig(skip_conv_layers=[0],
+                          nbit_input=cfg.nbit_input,
+                          nbit_weight=cfg.nbit_input,
+                          global_scale=cfg.global_scale,
+                          dtype_input=cfg.dtype_input,
+                          dtype_weight=cfg.dtype_input,
+                          dtype_activation=cfg.dtype_output,
+                          debug_enabled_ops=None)
+
+    model = get_model(cfg.model, 32, qconfig, tvm.target.cuda())
+    val_data, batch_fn = get_val_data(cfg.model, rec_val=rec_val, batch_size=32)
+
+    acc = eval_acc(model, val_data, batch_fn)
+    assert acc > cfg.expected_acc
+    return acc
+
+
+if __name__ == "__main__":
+    #TODO(for user): replace the line with the path to imagenet validation dataset
+    rec_val = "/scratch/tqchen/imagenet/val.rec"
+
+    results = []
+    configs = [
+        Config('mobilenetv2_1.0', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=4.0, expected_acc=0.666),
+
+        Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=8.0, expected_acc=0.692),
+        Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.692),
+        Config('resnet34_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.733),
+        Config('resnet50_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.747),
+        Config('resnet101_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.756),
+        # TODO: need to fix accuracy
+        # Config('mobilenetv2_1.0', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=4.0),
+    ]
+
+    for config in configs:
+        acc = test_quantize_acc(config, rec_val)
+        results.append((config, acc))
+    for res in results:
+        print(res)
diff --git a/tests/python/relay/test_pass_quantize.py b/tests/python/relay/test_pass_quantize.py
deleted file mode 100644 (file)
index f6f67d6..0000000
+++ /dev/null
@@ -1,109 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import math
-import numpy as np
-import tvm
-from tvm import relay
-from tvm.relay import quantize as qtz
-from tvm.relay import transform
-
-
-def run_infer_type(expr):
-    mod = relay.Module.from_expr(expr)
-    mod = transform.InferType()(mod)
-    entry = mod["main"]
-    return entry if isinstance(expr, relay.Function) else entry.body
-
-
-def make_dataset(graph, size=100):
-    args = run_infer_type(graph).params
-    def create_arr(var):
-        ttype = var.type_annotation
-        np_arr = np.random.uniform(-1.0, 1.0, size=ttype.concrete_shape).astype(ttype.dtype)
-        return tvm.ndarray.array(np_arr)
-
-    params = {}
-    for arg in args:
-        if arg.name_hint == 'data':
-            dataset = [{'data': create_arr(arg)} for _ in range(size)]
-        else:
-            params[arg.name_hint] = create_arr(arg)
-    return dataset, params
-
-
-def test_simulated_quantize():
-    data = relay.var("data", relay.ty.TensorType((3, 4, 5, 6), "float32"))
-    out = qtz._annotate.attach_simulated_quantize(data, 1)
-    out = run_infer_type(out)
-    assert out.checked_type == out.args[0].checked_type
-    assert out.args[1].checked_type == relay.ty.TensorType(tuple(), "float32")
-    assert out.args[2].checked_type == relay.ty.TensorType(tuple(), "float32")
-    assert out.args[3].checked_type == relay.ty.TensorType(tuple(), "float32")
-
-
-def test_quantize_pass():
-    def quantize_weight(arr):
-        maximum = np.amax(np.abs(arr.asnumpy()))
-        scale = 2**math.ceil(math.log(maximum, 2))
-        out = np.around(arr.asnumpy() / scale * 128).astype('int8')
-        out = np.clip(out, -127, 127)
-        return relay.const(out, 'int8')
-
-    n, c, h, w = 1, 3, 224, 224
-    def make_graph(data):
-        weight = relay.var("conv_weight")
-        out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c)
-        out = relay.Function(relay.analysis.free_vars(out), out)
-        return out
-
-    def make_qgraph(data, weight):
-        out = data * relay.const(32.0)
-        out = relay.round(out)
-        out = relay.clip(out, a_min=-127, a_max=127)
-        out = out.astype('int8')
-
-        out = relay.nn.conv2d(out, weight, kernel_size=(3, 3),
-                              padding=(1, 1), channels=c, out_dtype='int32')
-        out = out.astype('float32')
-        out = relay.multiply(out, relay.const(0.00024414062))
-        out = relay.Function(relay.analysis.free_vars(out), out)
-        return out
-
-    np.random.seed(42)
-
-    data = relay.var("data", relay.TensorType((n, c, h, w), "float32"))
-    graph = make_graph(data)
-    dataset, params = make_dataset(graph, 10)
-
-    with qtz.qconfig(skip_conv_layers=None, global_scale=4.0,
-                     round_for_shift=False, store_lowbit_output=False):
-        qgraph0 = qtz.quantize(graph, params)
-        qgraph0 = run_infer_type(qgraph0)
-
-    conv_weight = quantize_weight(params['conv_weight'])
-    qgraph1 = make_qgraph(data, conv_weight)
-    qgraph1 = run_infer_type(qgraph1)
-
-    graph = relay.create_executor('graph')
-    res0 = graph.evaluate(qgraph0)(dataset[0]['data'])
-    res1 = graph.evaluate(qgraph1)(dataset[0]['data'])
-    tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy(), rtol=1e-3)
-
-
-if __name__ == "__main__":
-    test_simulated_quantize()
-    test_quantize_pass()
diff --git a/tests/scripts/task_python_nightly.sh b/tests/scripts/task_python_nightly.sh
new file mode 100755 (executable)
index 0000000..09f7e8a
--- /dev/null
@@ -0,0 +1,30 @@
+#!/bin/bash
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -e
+set -u
+
+export PYTHONPATH=python:topi/python
+
+# Rebuild cython
+make cython3
+
+rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc
+rm -rf topi/python/topi/*.pyc topi/python/topi/*/*.pyc topi/python/topi/*/*/*.pyc topi/python/topi/*/*/*/*.pyc 
+
+python3 -m nose -v topi/tests/python/nightly