From 3edf52607b1209fb0d10f2dcc53ebe7cd177042d Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Mon, 16 Sep 2019 14:52:28 -0700
Subject: [PATCH] [TOPI] Setting up AutoTVM template for Intel Int8 conv2D (#3955)

---
 python/tvm/autotvm/task/topi_integration.py |   3 +
 topi/python/topi/nn/conv2d.py               |   1 -
 topi/python/topi/x86/conv2d.py              |  44 +++++++-----
 topi/python/topi/x86/conv2d_int8.py         | 104 ++++++++++++++++++++++----
 4 files changed, 119 insertions(+), 33 deletions(-)

diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py
index bc43471..09f08ad 100644
--- a/python/tvm/autotvm/task/topi_integration.py
+++ b/python/tvm/autotvm/task/topi_integration.py
@@ -85,6 +85,7 @@ class TaskExtractEnv:
             topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
             topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
             topi.nn.conv2d_NCHWc: "topi_x86_conv2d_NCHWc",
+            topi.nn.conv2d_NCHWc_int8: "topi_x86_conv2d_NCHWc_int8",
             topi.nn.dense: "topi_nn_dense",
             topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw",
             topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
@@ -100,6 +101,7 @@ class TaskExtractEnv:
             topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
             topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
             topi.nn.conv2d_NCHWc: [topi.generic.schedule_conv2d_NCHWc],
+            topi.nn.conv2d_NCHWc_int8: [topi.generic.schedule_conv2d_NCHWc_int8],
             topi.nn.dense: [topi.generic.schedule_dense],
             topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw],
             topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
@@ -111,6 +113,7 @@ class TaskExtractEnv:
         self.func_to_reflection = {
             topi.nn.conv2d: lambda x: setattr(topi.nn, 'conv2d', x),
             topi.nn.conv2d_NCHWc: lambda x: setattr(topi.nn, 'conv2d_NCHWc', x),
+            topi.nn.conv2d_NCHWc_int8: lambda x: setattr(topi.nn, 'conv2d_NCHWc_int8', x),
             topi.nn.depthwise_conv2d_nchw: lambda x: setattr(topi.nn, 'depthwise_conv2d_nchw', x),
             topi.nn.group_conv2d_nchw: lambda x: setattr(topi.nn, 'group_conv2d_nchw', x),
             topi.nn.conv2d_transpose_nchw: lambda x: setattr(topi.nn, 'conv2d_transpose_nchw', x),
diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py
index e2a4720..e52d0a6 100644
--- a/topi/python/topi/nn/conv2d.py
+++ b/topi/python/topi/nn/conv2d.py
@@ -669,7 +669,6 @@ def conv2d_NCHWc_int8_compute(data, kernel, strides, padding, dilation, layout,
                        name='conv2d_NCHWc_int8',
                        tag="conv2d_NCHWc_int8")
 
-
 def conv2d_winograd_weight_transform(kernel, tile_size):
     """Weight transformation for winograd
 
diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py
index 6f134ea..09e7a88 100644
--- a/topi/python/topi/x86/conv2d.py
+++ b/topi/python/topi/x86/conv2d.py
@@ -27,7 +27,7 @@ from tvm.autotvm.task import get_config
 from .. import generic, tag
 from .. import nn
 from ..util import get_const_tuple, get_shape
-from ..nn.conv2d import conv2d, conv2d_NCHWc, \
+from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_NCHWc_int8, \
     conv2d_alter_layout, conv2d_infer_layout, _get_workload as _get_conv2d_workload
 from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
 from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw
@@ -77,7 +77,6 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depth
     else:
         conv2d_avx_common._fallback_schedule(cfg, wkl)
 
-
 def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
     """Create schedule configuration from input arguments"""
     dshape = get_const_tuple(data.shape)
@@ -92,19 +91,15 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
     elif pat.match(layout) is not None:
         n, ic_chunk, h, w, ic_bn = dshape
         target = tvm.target.current_target(allow_none=False)
-        if _is_int8_hw_support(data.dtype, kernel.dtype, target):
-            oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape
-            ic = ic_chunk*ic_bn
-            assert ic == k_ic*k_ic_f*k_ic_s
-        else:
-            oc_chunk, k_ic_chunk, kh, kw, k_ic_bn, oc_bn = kshape
-            assert ic_chunk == k_ic_chunk
-            assert ic_bn == k_ic_bn
-            ic = ic_chunk*ic_bn
+        oc_chunk, k_ic_chunk, kh, kw, k_ic_bn, oc_bn = kshape
+        assert ic_chunk == k_ic_chunk
+        assert ic_bn == k_ic_bn
+        ic = ic_chunk*ic_bn
         oc = oc_chunk*oc_bn
     else:
         raise ValueError("Not support this layout {} with "
                          "schedule template.".format(layout))
+
     is_kernel_1x1 = kh == 1 and kw == 1
     ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding)
     sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
@@ -444,14 +439,25 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
                                               in_channel//ic_bn, ic_bn//n_elems, n_elems))
         kernel_OIHWioe = F.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5, 3, 6))
         copy_inputs = [data_expr, kernel_OIHWioe]
-        # Store altered operator's config
-        new_kernel = tvm.placeholder((out_channel//oc_bn, kh, kw, oc_bn,
-                                      in_channel//ic_bn, ic_bn//n_elems,
-                                      n_elems))
-        new_workload = autotvm.task.args_to_workload(
-            [new_data, new_kernel, strides, padding, dilation,
-             new_attrs[layout_name], new_attrs['out_layout'], out_dtype],
-            conv2d_NCHWc)
+
+        # Store the altered operator's config. The new kernel layout is OIHWioe (n_elems = 4)
+        new_kernel = tvm.placeholder((out_channel // oc_bn,
+                                      in_channel // ic_bn,
+                                      kh,
+                                      kw,
+                                      ic_bn // n_elems,
+                                      oc_bn,
+                                      n_elems), dtype=kernel.dtype)
+
+        new_workload = autotvm.task.args_to_workload([new_data,
+                                                      new_kernel,
+                                                      strides,
+                                                      padding,
+                                                      dilation,
+                                                      new_attrs[layout_name],
+                                                      new_attrs['out_layout'],
+                                                      out_dtype],
+                                                     conv2d_NCHWc_int8)
         dispatch_ctx.update(target, new_workload, cfg)
         if F.__name__ == 'nnvm.symbol':
             logging.warning("Use native layout for int8 convolution on NNVM.")
diff --git a/topi/python/topi/x86/conv2d_int8.py b/topi/python/topi/x86/conv2d_int8.py
index d02ddae..c5ef585 100644
--- a/topi/python/topi/x86/conv2d_int8.py
+++ b/topi/python/topi/x86/conv2d_int8.py
@@ -17,30 +17,108 @@
 # pylint: disable=invalid-name,unused-variable,unused-argument,no-member
 """Conv2D int8 schedule on x86"""
+import re
 import tvm
 from tvm import autotvm
+from tvm.autotvm.task import get_config
+from tvm.autotvm.task.topi_integration import deserialize_args
 from .. import generic, tag
 from ..util import get_const_tuple
 from ..nn.conv2d import conv2d_NCHWc_int8
 from .. import nn
-from .conv2d import _get_default_config
 from . import conv2d_avx_1x1, conv2d_avx_common
 
+
+def _create_tuning_space_int8(cfg, data, kernel, strides, padding, dilation, layout):
+    """Create schedule configuration from input arguments"""
+    dshape = get_const_tuple(data.shape)
+    kshape = get_const_tuple(kernel.shape)
+    pat = re.compile(r'NCHW.+(\d+)c')
+    if layout == 'NCHW':
+        n, ic, h, w = dshape
+        oc, _, kh, kw = kshape
+    elif layout == 'NHWC':
+        n, h, w, ic = dshape
+        kh, kw, oc, _ = kshape
+    elif pat.match(layout) is not None:
+        n, ic_chunk, h, w, ic_bn = dshape
+        target = tvm.target.current_target(allow_none=False)
+        oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape
+        ic = ic_chunk * ic_bn
+        assert ic == k_ic * k_ic_f * k_ic_s
+        oc = oc_chunk*oc_bn
+    else:
+        raise ValueError("Unsupported layout {} for this "
+                         "schedule template.".format(layout))
+
+    is_kernel_1x1 = kh == 1 and kw == 1
+    ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding)
+    sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+    oh = (h - kh + 2 * ph) // sh + 1
+    ow = (w - kw + 2 * pw) // sw + 1
+
+    # Create schedule config
+    cfg.define_split('tile_ic', ic, num_outputs=2, filter=lambda y: y.size[-1] % 4 == 0)
+    cfg.define_split('tile_oc', oc, num_outputs=2, filter=lambda y: y.size[-1] % 16 == 0)
+    cfg.define_split("tile_ow", ow, num_outputs=2, filter=lambda y: y.size[-1] <= 64)
+    if is_kernel_1x1:
+        cfg.define_knob("tile_oh", [1, 2] if oh > 1 else [1])
+    else:
+        cfg.define_knob("unroll_kw", [True, False])
+
+
+# Define the template function for the autotvm task.
+# We define the schedule template in this function rather than in the
+# compute declaration, because the actual input arguments need to be
+# altered according to the selected schedule.
+@autotvm.task.register("topi_x86_conv2d_NCHWc_int8")
+def _topi_nn_conv2d_NCHWc_int8(*args, **kwargs):
+    assert not kwargs, "Do not support kwargs in template function call"
+    args = deserialize_args(args)
+
+    if len(args) == 7:
+        data, kernel, strides, padding, dilation, origin_layout, dtype = args
+    else:
+        assert len(args) == 8
+        data, kernel, strides, padding, dilation, origin_layout, out_layout, dtype = args
+
+    raw_data_shape = get_const_tuple(data.shape)
+    raw_kernel_shape = get_const_tuple(kernel.shape)
+
+    # Get the config for this workload
+    cfg = get_config()
+    _create_tuning_space_int8(cfg, data, kernel, strides, padding, dilation, origin_layout)
+
+    # Change the shapes according to the values chosen in the config
+    ic_bn, oc_bn, ow_bn = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
+                           cfg["tile_ow"].size[-1])
+
+    data_layout = "NCHW%dc" % ic_bn
+    out_layout = "NCHW%dc" % oc_bn
+
+    # Set up the new shapes for data and kernel
+    new_data_shape = (raw_data_shape[0], raw_data_shape[1] // ic_bn,
+                      raw_data_shape[2], raw_data_shape[3], ic_bn)
+    n_elems = 4
+    new_kernel_shape = (raw_kernel_shape[0] // oc_bn,
+                        raw_kernel_shape[1] // ic_bn,
+                        raw_kernel_shape[2],
+                        raw_kernel_shape[3],
+                        ic_bn // n_elems,
+                        oc_bn,
+                        n_elems)
+
+    new_data = tvm.placeholder(new_data_shape, data.dtype)
+    new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)
+
+    C = _declaration_conv_NCHWc_int8(cfg, new_data, new_kernel, strides, padding, dilation,
+                                     data_layout, out_layout, dtype)
+    s = _schedule_conv2d_NCHWc_int8(cfg, [C])
+    return s, [new_data, new_kernel, C]
+
+
 @autotvm.register_topi_compute(conv2d_NCHWc_int8, 'cpu', 'direct')
 def _declaration_conv_NCHWc_int8(cfg, data, kernel, strides, padding, dilation,
                                  layout, out_layout, out_dtype):
-    n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
-    in_channel = ic_chunk * ic_bn
-    oc_chunk, _, kernel_height, kernel_width, _, oc_bn, _ = \
-        get_const_tuple(kernel.shape)
-    num_filter = oc_chunk * oc_bn
-
-    # If config is not set, we can reuse the default config for NCHW.
-    if cfg.is_fallback:
-        _get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
-                            tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width),
-                                            dtype=kernel.dtype),
-                            strides, padding, out_dtype)
     return nn.conv2d_NCHWc_int8_compute(data,
                                         kernel,
                                         strides,
-- 
2.7.4
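Usage note: the sketch below shows one way the template registered above,
"topi_x86_conv2d_NCHWc_int8", could be instantiated as a standalone tuning
task. It is a minimal sketch, not part of the patch, assuming the AutoTVM API
of this TVM generation (autotvm.task.create with a registered template name,
serialize_args from topi_integration, measure_option, RandomTuner). The
workload shape, target string, and log file name are illustrative
placeholders.

import tvm
from tvm import autotvm
from tvm.autotvm.task.topi_integration import serialize_args

# Hypothetical int8 workload: uint8 activations, int8 weights, NCHW layout.
# This exercises the 7-argument form accepted by the template.
data = tvm.placeholder((1, 64, 56, 56), name='data', dtype='uint8')
kernel = tvm.placeholder((64, 64, 3, 3), name='kernel', dtype='int8')
args = serialize_args([data, kernel, (1, 1), (1, 1), (1, 1), 'NCHW', 'int32'])

# Create the task from the registered template and inspect its search space.
# An AVX-512 capable target is assumed so the int8 schedules are relevant.
task = autotvm.task.create("topi_x86_conv2d_NCHWc_int8", args,
                           target='llvm -mcpu=skylake-avx512',
                           template_key='direct')
print(task.config_space)

# Tune a few trials with a random tuner and log the measured configs.
measure_option = autotvm.measure_option(builder=autotvm.LocalBuilder(),
                                        runner=autotvm.LocalRunner(number=10))
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=20,
           measure_option=measure_option,
           callbacks=[autotvm.callback.log_to_file('conv2d_int8_x86.log')])

The tile_ic and tile_oc split filters in _create_tuning_space_int8 mirror the
hardware constraints of the int8 dot-product path (the innermost input-channel
block must be a multiple of 4 and the output-channel block a multiple of 16),
so only configs that can be vectorized this way enter the search space.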