[Relay] Option to select which convolution layers are quantized. (#3173)
authorJosh Fromm <jwfromm@uw.edu>
Thu, 16 May 2019 02:03:40 +0000 (19:03 -0700)
committerWuwei Lin <vincentl13x@gmail.com>
Thu, 16 May 2019 02:03:40 +0000 (10:03 +0800)
* Stashing for later maybe.

* Added new option to leave specific layers unquantized.

* Better error checking.

* remove unneeded import

* tab to spaces

* pylint fixes

* more pylint fixes

python/tvm/relay/quantize/_annotate.py
python/tvm/relay/quantize/quantize.py
src/relay/pass/quantize.cc
src/relay/pass/quantize.h
topi/python/topi/cuda/conv2d.py

index e52ce14..9bf546f 100644 (file)
@@ -156,6 +156,13 @@ def conv2d_rewrite(ref_call, new_args, ctx):
     if cnt < current_qconfig().skip_k_conv:
         _set_conv_counter(cnt + 1)
         return None
+
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt in leave_alone_indices:
+            _set_conv_counter(cnt + 1)
+            return None
+
     _set_conv_counter(cnt + 1)
 
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
@@ -168,6 +175,7 @@ def conv2d_rewrite(ref_call, new_args, ctx):
     rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
 
     expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
+
     return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
 
 
@@ -178,6 +186,11 @@ def dense_rewrite(ref_call, new_args, ctx):
     cnt = _conv_counter()
     if cnt < current_qconfig().skip_k_conv:
         return None
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt - 1 in leave_alone_indices:
+            return None
+
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
     rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
 
@@ -194,8 +207,13 @@ def dense_rewrite(ref_call, new_args, ctx):
 @register_annotate_function("multiply")
 def multiply_rewrite(ref_call, new_args, ctx):
     """Rewrite function for multiply."""
-    if _conv_counter() <= current_qconfig().skip_k_conv:
+    cnt = _conv_counter()
+    if cnt <= current_qconfig().skip_k_conv:
         return None
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt - 1 in leave_alone_indices:
+            return None
 
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
     rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
@@ -216,8 +234,13 @@ def multiply_rewrite(ref_call, new_args, ctx):
 @register_annotate_function("add")
 def add_rewrite(ref_call, new_args, ctx):
     """Rewrite function for add."""
-    if _conv_counter() <= current_qconfig().skip_k_conv:
+    cnt = _conv_counter()
+    if cnt <= current_qconfig().skip_k_conv:
         return None
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt - 1 in leave_alone_indices:
+            return None
 
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
     rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
@@ -244,8 +267,13 @@ def add_rewrite(ref_call, new_args, ctx):
 
 def identity_rewrite(ref_call, new_args, ctx):
     """Simply forward the original operation"""
-    if _conv_counter() <= current_qconfig().skip_k_conv:
+    cnt = _conv_counter()
+    if cnt <= current_qconfig().skip_k_conv:
         return None
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt - 1 in leave_alone_indices:
+            return None
 
     x_expr, x_kind = _get_expr_kind(new_args[0])
     if x_kind is None:
@@ -262,8 +290,14 @@ register_annotate_function("nn.avg_pool2d", identity_rewrite)
 
 def pool2d_rewrite(ref_call, new_args, ctx):
     """Rewrite function for max pool2d"""
-    if _conv_counter() <= current_qconfig().skip_k_conv:
+    cnt = _conv_counter()
+    if cnt <= current_qconfig().skip_k_conv:
         return None
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt - 1 in leave_alone_indices:
+            return None
+
     expr, x_kind = _get_expr_kind(new_args[0])
 
     if x_kind is None:
@@ -280,8 +314,13 @@ register_annotate_function("nn.max_pool2d", pool2d_rewrite)
 @register_annotate_function("concatenate")
 def concatenate_rewrite(ref_call, new_args, ctx):
     """Rewrite function for concatenate"""
-    if _conv_counter() <= current_qconfig().skip_k_conv:
+    cnt = _conv_counter()
+    if cnt <= current_qconfig().skip_k_conv:
         return None
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt - 1 in leave_alone_indices:
+            return None
 
     input_tuple = new_args[0]
     expr_list = [_get_expr_kind(x)[0] for x in input_tuple]
index b84d3eb..7fd0099 100644 (file)
@@ -71,6 +71,7 @@ class QConfig(NodeBase):
         "dtype_activation": "int32",
         "global_scale": 8.0,
         "skip_k_conv": 1,
+        "skip_conv_layers": None,
         "round_for_shift": True,
         "store_lowbit_output": True,
         "debug_enabled_ops": None,
@@ -139,6 +140,10 @@ def qconfig(**kwargs):
     skip_k_conv: int
         The number of skipped conv2d.
 
+    skip_conv_layers: list
+        Different way of specifying which layers to avoid. Provide a list of indices
+        that indicate which conv2d layers to leave untouched.
+
     round_for_shift: boolean
         Whether to add bias for rounding during shift.
 
index 7fd27b4..3a2e54c 100644 (file)
@@ -596,6 +596,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   p->stream << "nbit_activation=" << op->nbit_activation << ", ";
   p->stream << "global_scale=" << op->global_scale << ", ";
   p->stream << "skip_k_conv==" << op->skip_k_conv << ", ";
+  p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", ";
   p->stream << "round_for_shift==" << op->round_for_shift << ", ";
   p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", ";
   p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", ";
index 4d26dd6..2c70da1 100644 (file)
@@ -126,6 +126,7 @@ class QConfigNode : public Node {
   DataType dtype_activation = Int(32);
   double global_scale = 8.0;
   int skip_k_conv = 1;
+  Array<Expr> skip_conv_layers = Array<Expr>(NodePtr<Node>(nullptr));
   bool round_for_shift = true;
   bool store_lowbit_output = true;
   Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
@@ -140,6 +141,7 @@ class QConfigNode : public Node {
     v->Visit("dtype_activation", &dtype_activation);
     v->Visit("global_scale", &global_scale);
     v->Visit("skip_k_conv", &skip_k_conv);
+    v->Visit("skip_conv_layers", &skip_conv_layers);
     v->Visit("round_for_shift", &round_for_shift);
     v->Visit("store_lowbit_output", &store_lowbit_output);
     v->Visit("debug_enabled_ops", &debug_enabled_ops);
index 006e2fc..4d764b0 100644 (file)
@@ -105,7 +105,9 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
         return winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
                              pre_computed=False)
     if cfg.template_key == 'int8':
-        return conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out_dtype)
+        if (data.dtype == 'int8' or data.dtype == 'uint8'):
+            return conv2d_NCHWc_int8(
+                cfg, data, kernel, strides, padding, dilation, layout, out_dtype)
 
     if layout == 'NCHW':
         return nn.conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)