[QUANTIZE] Add config switch for nn.dense layer type. (#5801)
author Balint Cristian <cristian.balint@gmail.com>
Sun, 14 Jun 2020 17:37:32 +0000 (20:37 +0300)
committer GitHub <noreply@github.com>
Sun, 14 Jun 2020 17:37:32 +0000 (10:37 -0700)
python/tvm/relay/quantize/_annotate.py
python/tvm/relay/quantize/quantize.py
src/relay/quantize/quantize.h

index 5954e07..952a864 100644 (file)
@@ -173,11 +173,14 @@ def conv2d_rewrite(ref_call, new_args, ctx):
     return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
 
 
-# TODO(tmoreau89,ziheng) need to include an option to turn off dense quant
-# @register_annotate_function("nn.dense")
+@register_annotate_function("nn.dense")
 def dense_rewrite(ref_call, new_args, ctx):
     """Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of
     dense will be quantized to weight field. Output would be in activation field."""
+
+    if current_qconfig().skip_dense_layer:
+        return None
+
     if quantize_context().check_to_skip(ref_call):
         return None
 
index e5d5409..28ebf7f 100644 (file)
@@ -78,6 +78,7 @@ class QConfig(Object):
         "calibrate_mode": "global_scale",
         "global_scale": 8.0,
         "weight_scale": "power2",
+        "skip_dense_layer": True,
         "skip_conv_layers": [0],
         "do_simulation": False,
         "round_for_shift": True,
@@ -157,6 +158,9 @@ def qconfig(**kwargs):
         of two.
         max: Find the maximum of the absolute value of the tensor
 
+    skip_dense_layer: boolean
+        Whether to skip all nn.dense layers. By default, they are skipped.
+
     skip_conv_layers: list
         Specifying which layers to be skipped. Provide a list of indices
         that indicate which conv2d layers to leave untouched. Start from 0.
index a883cb1..86f8926 100644 (file)
@@ -67,6 +67,7 @@ class QConfigNode : public Object {
   std::string calibrate_mode = "global_scale";
   double global_scale = 8.0;
   std::string weight_scale = "power2";
+  bool skip_dense_layer = true;
   Array<Expr> skip_conv_layers = Array<Expr>(ObjectPtr<Object>(nullptr));
   bool do_simulation = false;
   bool round_for_shift = true;
@@ -84,6 +85,7 @@ class QConfigNode : public Object {
     v->Visit("calibrate_mode", &calibrate_mode);
     v->Visit("global_scale", &global_scale);
     v->Visit("weight_scale", &weight_scale);
+    v->Visit("skip_dense_layer", &skip_dense_layer);
     v->Visit("skip_conv_layers", &skip_conv_layers);
     v->Visit("do_simulation", &do_simulation);
     v->Visit("round_for_shift", &round_for_shift);