[TOPI] Setting up AutoTVM template for Intel Int8 conv2D (#3955)
authorAnimesh Jain <anijain@umich.edu>
Mon, 16 Sep 2019 21:52:28 +0000 (14:52 -0700)
committerYizhi Liu <liuyizhi@apache.org>
Mon, 16 Sep 2019 21:52:28 +0000 (05:52 +0800)
python/tvm/autotvm/task/topi_integration.py
topi/python/topi/nn/conv2d.py
topi/python/topi/x86/conv2d.py
topi/python/topi/x86/conv2d_int8.py

index bc43471..09f08ad 100644 (file)
@@ -85,6 +85,7 @@ class TaskExtractEnv:
             topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
             topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
             topi.nn.conv2d_NCHWc: "topi_x86_conv2d_NCHWc",
+            topi.nn.conv2d_NCHWc_int8: "topi_x86_conv2d_NCHWc_int8",
             topi.nn.dense: "topi_nn_dense",
             topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw",
             topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
@@ -100,6 +101,7 @@ class TaskExtractEnv:
             topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
             topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
             topi.nn.conv2d_NCHWc: [topi.generic.schedule_conv2d_NCHWc],
+            topi.nn.conv2d_NCHWc_int8: [topi.generic.schedule_conv2d_NCHWc_int8],
             topi.nn.dense: [topi.generic.schedule_dense],
             topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw],
             topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
@@ -111,6 +113,7 @@ class TaskExtractEnv:
         self.func_to_reflection = {
             topi.nn.conv2d:                 lambda x: setattr(topi.nn, 'conv2d', x),
             topi.nn.conv2d_NCHWc:           lambda x: setattr(topi.nn, 'conv2d_NCHWc', x),
+            topi.nn.conv2d_NCHWc_int8:      lambda x: setattr(topi.nn, 'conv2d_NCHWc_int8', x),
             topi.nn.depthwise_conv2d_nchw:  lambda x: setattr(topi.nn, 'depthwise_conv2d_nchw', x),
             topi.nn.group_conv2d_nchw:      lambda x: setattr(topi.nn, 'group_conv2d_nchw', x),
             topi.nn.conv2d_transpose_nchw:  lambda x: setattr(topi.nn, 'conv2d_transpose_nchw', x),
index e2a4720..e52d0a6 100644 (file)
@@ -669,7 +669,6 @@ def conv2d_NCHWc_int8_compute(data, kernel, strides, padding, dilation, layout,
                        name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
 
 
-
 def conv2d_winograd_weight_transform(kernel, tile_size):
     """Weight transformation for winograd
 
index 6f134ea..09e7a88 100644 (file)
@@ -27,7 +27,7 @@ from tvm.autotvm.task import get_config
 from .. import generic, tag
 from .. import nn
 from ..util import get_const_tuple, get_shape
-from ..nn.conv2d import conv2d, conv2d_NCHWc, \
+from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_NCHWc_int8, \
     conv2d_alter_layout, conv2d_infer_layout, _get_workload as _get_conv2d_workload
 from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
 from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw
@@ -77,7 +77,6 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depth
         else:
             conv2d_avx_common._fallback_schedule(cfg, wkl)
 
-
 def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
     """Create schedule configuration from input arguments"""
     dshape = get_const_tuple(data.shape)
@@ -92,19 +91,15 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
     elif pat.match(layout) is not None:
         n, ic_chunk, h, w, ic_bn = dshape
         target = tvm.target.current_target(allow_none=False)
-        if _is_int8_hw_support(data.dtype, kernel.dtype, target):
-            oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape
-            ic = ic_chunk*ic_bn
-            assert ic == k_ic*k_ic_f*k_ic_s
-        else:
-            oc_chunk, k_ic_chunk, kh, kw, k_ic_bn, oc_bn = kshape
-            assert ic_chunk == k_ic_chunk
-            assert ic_bn == k_ic_bn
-            ic = ic_chunk*ic_bn
+        oc_chunk, k_ic_chunk, kh, kw, k_ic_bn, oc_bn = kshape
+        assert ic_chunk == k_ic_chunk
+        assert ic_bn == k_ic_bn
+        ic = ic_chunk*ic_bn
         oc = oc_chunk*oc_bn
     else:
         raise ValueError("Not support this layout {} with "
                          "schedule template.".format(layout))
+
     is_kernel_1x1 = kh == 1 and kw == 1
     ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding)
     sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
@@ -444,14 +439,25 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
                                                    in_channel//ic_bn, ic_bn//n_elems, n_elems))
         kernel_OIHWioe = F.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5, 3, 6))
         copy_inputs = [data_expr, kernel_OIHWioe]
-        # Store altered operator's config
-        new_kernel = tvm.placeholder((out_channel//oc_bn, kh, kw, oc_bn,
-                                      in_channel//ic_bn, ic_bn//n_elems,
-                                      n_elems))
-        new_workload = autotvm.task.args_to_workload(
-            [new_data, new_kernel, strides, padding, dilation,
-             new_attrs[layout_name], new_attrs['out_layout'], out_dtype],
-            conv2d_NCHWc)
+
+        # Store altered operator's config. New kernel layout OIHWio4
+        new_kernel = tvm.placeholder((out_channel // oc_bn,
+                                      in_channel // ic_bn,
+                                      kh,
+                                      kw,
+                                      ic_bn // n_elems,
+                                      oc_bn,
+                                      n_elems), dtype=kernel.dtype)
+
+        new_workload = autotvm.task.args_to_workload([new_data,
+                                                      new_kernel,
+                                                      strides,
+                                                      padding,
+                                                      dilation,
+                                                      new_attrs[layout_name],
+                                                      new_attrs['out_layout'],
+                                                      out_dtype],
+                                                     conv2d_NCHWc_int8)
         dispatch_ctx.update(target, new_workload, cfg)
         if F.__name__ == 'nnvm.symbol':
             logging.warning("Use native layout for int8 convolution on NNVM.")
index d02ddae..c5ef585 100644 (file)
 # pylint: disable=invalid-name,unused-variable,unused-argument,no-member
 """Conv2D int8 schedule on x86"""
 
+import re
 import tvm
 from tvm import autotvm
+from tvm.autotvm.task import get_config
+from tvm.autotvm.task.topi_integration import deserialize_args
 from .. import generic, tag
 from ..util import get_const_tuple
 from ..nn.conv2d import conv2d_NCHWc_int8
 from .. import nn
-from .conv2d import _get_default_config
 from . import conv2d_avx_1x1, conv2d_avx_common
 
+def _create_tuning_space_int8(cfg, data, kernel, strides, padding, dilation, layout):
+    """Create schedule configuration from input arguments"""
+    dshape = get_const_tuple(data.shape)
+    kshape = get_const_tuple(kernel.shape)
+    pat = re.compile(r'NCHW.+(\d+)c')
+    if layout == 'NCHW':
+        n, ic, h, w = dshape
+        oc, _, kh, kw = kshape
+    elif layout == 'NHWC':
+        n, h, w, ic = dshape
+        kh, kw, oc, _ = kshape
+    elif pat.match(layout) is not None:
+        n, ic_chunk, h, w, ic_bn = dshape
+        target = tvm.target.current_target(allow_none=False)
+        oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape
+        ic = ic_chunk * ic_bn
+        assert ic == k_ic * k_ic_f * k_ic_s
+        oc = oc_chunk*oc_bn
+    else:
+        raise ValueError("Not support this layout {} with "
+                         "schedule template.".format(layout))
+
+    is_kernel_1x1 = kh == 1 and kw == 1
+    ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding)
+    sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+    oh = (h - kh + 2 * ph) // sh + 1
+    ow = (w - kw + 2 * pw) // sw + 1
+
+    # Create schedule config
+    cfg.define_split('tile_ic', ic, num_outputs=2, filter=lambda y: y.size[-1] % 4 == 0)
+    cfg.define_split('tile_oc', oc, num_outputs=2, filter=lambda y: y.size[-1] % 16 == 0)
+    cfg.define_split("tile_ow", ow, num_outputs=2, filter=lambda y: y.size[-1] <= 64)
+    if is_kernel_1x1:
+        cfg.define_knob("tile_oh", [1, 2] if oh > 1 else [1])
+    else:
+        cfg.define_knob("unroll_kw", [True, False])
+
+
+# Define template function for autotvm task
+# We define schedule template in this function instead of
+# declaration function since actual input arguments need
+# to be altered by the schedule selected.
+@autotvm.task.register("topi_x86_conv2d_NCHWc_int8")
+def _topi_nn_conv2d_NCHWc_int8(*args, **kwargs):
+    assert not kwargs, "Do not support kwargs in template function call"
+    args = deserialize_args(args)
+
+    if len(args) == 7:
+        data, kernel, strides, padding, dilation, origin_layout, dtype = args
+    else:
+        assert len(args) == 8
+        data, kernel, strides, padding, dilation, origin_layout, out_layout, dtype = args
+
+    raw_data_shape = get_const_tuple(data.shape)
+    raw_kernel_shape = get_const_tuple(kernel.shape)
+
+    # get config here
+    cfg = get_config()
+    _create_tuning_space_int8(cfg, data, kernel, strides, padding, dilation, origin_layout)
+
+    # change shape with the value in config
+    ic_bn, oc_bn, ow_bn = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
+                           cfg["tile_ow"].size[-1])
+
+    data_layout = "NCHW%dc" % ic_bn
+    out_layout = "NCHW%dc" % oc_bn
+
+    # Set up the new shape for data and kernel
+    new_data_shape = (raw_data_shape[0], raw_data_shape[1] // ic_bn,
+                      raw_data_shape[2], raw_data_shape[3], ic_bn)
+    n_elems = 4
+    new_kernel_shape = (raw_kernel_shape[0] // oc_bn,
+                        raw_kernel_shape[1] // ic_bn,
+                        raw_kernel_shape[2],
+                        raw_kernel_shape[3],
+                        ic_bn // n_elems,
+                        oc_bn,
+                        n_elems)
+
+    new_data = tvm.placeholder(new_data_shape, data.dtype)
+    new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)
+
+    C = _declaration_conv_NCHWc_int8(cfg, new_data, new_kernel, strides, padding, dilation,
+                                     data_layout, out_layout, dtype)
+    s = _schedule_conv2d_NCHWc_int8(cfg, [C])
+    return s, [new_data, new_kernel, C]
+
+
 @autotvm.register_topi_compute(conv2d_NCHWc_int8, 'cpu', 'direct')
 def _declaration_conv_NCHWc_int8(cfg, data, kernel, strides,
                                  padding, dilation, layout, out_layout, out_dtype):
-    n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
-    in_channel = ic_chunk * ic_bn
-    oc_chunk, _, kernel_height, kernel_width, _, oc_bn, _ = \
-            get_const_tuple(kernel.shape)
-    num_filter = oc_chunk * oc_bn
-
-    # If config is not set, we can reuse the default config for NCHW.
-    if cfg.is_fallback:
-        _get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
-                            tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width),
-                                            dtype=kernel.dtype),
-                            strides, padding, out_dtype)
     return nn.conv2d_NCHWc_int8_compute(data,
                                         kernel,
                                         strides,