From b0ddcff6748bddc69881a2ff4216a830407421c9 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 18 Sep 2019 21:54:01 -0700 Subject: [PATCH] [Relay] Legalize and AlterOpLayout for Int8 Intel. (#3961) --- tests/python/relay/test_op_level2.py | 85 +++++++--- topi/python/topi/nn/conv2d.py | 2 +- topi/python/topi/x86/__init__.py | 1 + topi/python/topi/x86/conv2d.py | 157 +----------------- topi/python/topi/x86/conv2d_alter_op.py | 267 ++++++++++++++++++++++++++++++ topi/python/topi/x86/conv2d_avx_1x1.py | 30 ++++ topi/python/topi/x86/conv2d_avx_common.py | 28 ++++ topi/python/topi/x86/conv2d_int8.py | 40 +++++ 8 files changed, 432 insertions(+), 178 deletions(-) create mode 100644 topi/python/topi/x86/conv2d_alter_op.py diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index a94a203..0155824 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -541,18 +541,35 @@ def test_upsampling(): def test_conv2d_int8_intrinsics(): - def _compile(input_dtype, weight_dtype, output_dtype, target): - n, ic, h, w, oc, ch, cw = 1, 16, 224, 224, 32, 3, 3 - x = relay.var("x", relay.TensorType((n, ic, h, w), input_dtype)) - w = relay.var("w", relay.TensorType((oc, ic, ch, cw), weight_dtype)) + def _compile(ic, oc, target, data_layout, kernel_layout, dtypes): + input_dtype, weight_dtype, output_dtype = dtypes + + n, h, w, ch, cw = 1, 64, 64, 3, 3 + if data_layout == 'NCHW': + x = relay.var("x", relay.TensorType((n, ic, h, w), input_dtype)) + elif data_layout == 'NHWC': + x = relay.var("x", relay.TensorType((n, h, w, ic), input_dtype)) + else: + raise ValueError('Not supported') + + if kernel_layout == 'OIHW': + kernel_shape = (oc, ic, ch, cw) + elif kernel_layout == 'HWIO': + kernel_shape = (ch, cw, ic, oc) + else: + raise ValueError('Not supported') + + w = relay.var("w", relay.TensorType(kernel_shape, weight_dtype)) y = relay.nn.conv2d(x, w, kernel_size=(ch, cw), channels=oc, padding=(1, 1), dilation=(1, 1), + data_layout=data_layout, + kernel_layout=kernel_layout, out_dtype=output_dtype) func = relay.Function([x, w], y) - wdata = np.random.rand(oc, ic, ch, cw) * 10 + wdata = np.random.rand(*kernel_shape) * 10 parameters = {"w": tvm.nd.array(wdata.astype(weight_dtype))} with relay.build_config(opt_level=3): graph, lib, params = relay.build(func, target, params=parameters) @@ -564,37 +581,59 @@ def test_conv2d_int8_intrinsics(): name = "llvm.x86.avx512.pmaddubs.w.512" llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name) if llvm_id != 0: - # Intel Int8 instruction need uint8 data and int8 kernel - asm = _compile(input_dtype="uint8", - weight_dtype="int8", - output_dtype="int32", - target=target) - # Check that intrinisic is present in the assembly. + fast_int8_dtypes = ('uint8', 'int8', 'int32') + # Sweep the input channels to check int8 robustness + for ic in range(1, 24): + asm = _compile(ic=ic, oc=32, target=target, data_layout="NCHW", kernel_layout='OIHW', + dtypes=fast_int8_dtypes) + assert "pmaddubs" in asm + + for ic in range(1, 24): + asm = _compile(ic=ic, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO', + dtypes=fast_int8_dtypes) + assert "pmaddubs" in asm + + + # Sweep the output channels to check int8 robustness + for oc in range(2, 24): + asm = _compile(ic=16, oc=oc, target=target, data_layout="NCHW", kernel_layout='OIHW', + dtypes=fast_int8_dtypes) + assert "pmaddubs" in asm + + for oc in range(2, 24): + asm = _compile(ic=16, oc=oc, target=target, data_layout="NHWC", kernel_layout='HWIO', + dtypes=fast_int8_dtypes) + assert "pmaddubs" in asm + + # Check that both non-divisible oc and ic work + asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout='OIHW', + dtypes=fast_int8_dtypes) + assert "pmaddubs" in asm + + asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO', + dtypes=fast_int8_dtypes) assert "pmaddubs" in asm # Ensure that code is generated when datatypes are not HW supported. - asm = _compile(input_dtype="int8", - weight_dtype="int8", - output_dtype="int32", - target=target) + dtypes = ('int8', 'int8', 'int32') + asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO', + dtypes=dtypes) # Check that intrinisic is not present in the assembly. assert "pmaddubs" not in asm # Ensure that code is generated when datatypes are not HW supported. - asm = _compile(input_dtype="uint8", - weight_dtype="uint8", - output_dtype="int32", - target=target) + dtypes = ('uint8', 'uint8', 'int32') + asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO', + dtypes=dtypes) # Check that intrinisic is not present in the assembly. assert "pmaddubs" not in asm # Check that a vectorized instruction is generated for older Intel # generations, because we default to NCHWc layout. target = "llvm -mcpu=core-avx2" - asm = _compile(input_dtype="int8", - weight_dtype="int8", - output_dtype="int32", - target=target) + fast_int8_dtypes = ('uint8', 'int8', 'int32') + asm = _compile(ic=16, oc=32, target=target, data_layout="NCHW", kernel_layout='OIHW', + dtypes=fast_int8_dtypes) # Check that vector int mult and add instructions are generated. assert "vpmulld" in asm and "vpadd" in asm diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index e52d0a6..590600a 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -151,7 +151,7 @@ def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'): if data_layout == 'NCHW': CO, CIG, KH, KW = [x.value for x in kernel.shape] else: - KH, KW, CO, CIG = [x.value for x in kernel.shape] + KH, KW, CIG, CO = [x.value for x in kernel.shape] HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) GRPS = CI // CIG diff --git a/topi/python/topi/x86/__init__.py b/topi/python/topi/x86/__init__.py index fe4060b..2f38c19 100644 --- a/topi/python/topi/x86/__init__.py +++ b/topi/python/topi/x86/__init__.py @@ -17,3 +17,4 @@ from .batch_matmul import schedule_batch_matmul from .roi_align import roi_align_nchw from .conv2d_transpose import _schedule_conv2d_transpose_nchw from .sparse import * +from .conv2d_alter_op import * diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index 09e7a88..6565de2 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -26,40 +26,16 @@ from tvm.autotvm.task.topi_integration import deserialize_args from tvm.autotvm.task import get_config from .. import generic, tag from .. import nn -from ..util import get_const_tuple, get_shape -from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_NCHWc_int8, \ - conv2d_alter_layout, conv2d_infer_layout, _get_workload as _get_conv2d_workload +from ..nn.conv2d import conv2d, conv2d_NCHWc, \ + conv2d_infer_layout, _get_workload as _get_conv2d_workload from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload -from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw from ..nn.pad import pad +from ..util import get_const_tuple from . import conv2d_avx_1x1, conv2d_avx_common logger = logging.getLogger('topi') -def _is_int8_hw_support(data_dtype, kernel_dtype, target): - """ - Checks to ensure that we can use Intel DLBoost instructions - 1) The datatypes are correct. - 2) LLVM version has support for the instructions. - 3) Target is skylake and above. - """ - # 1) Check datatypes - is_dtype_support = data_dtype == 'uint8' and kernel_dtype == 'int8' - - # 2) Check LLVM support - llvm_intrin_fast_int8 = "llvm.x86.avx512.pmaddubs.w.512" - llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(llvm_intrin_fast_int8) - is_llvm_support = llvm_id != 0 - - # 3) Check target - is_target_support = False - for opt in target.options: - if opt == '-mcpu=skylake-avx512': - is_target_support = True - - return is_dtype_support and is_llvm_support and is_target_support - def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False, layout='NCHW'): """ @@ -353,133 +329,6 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs): return s, [new_data, new_kernel, C] -@conv2d_alter_layout.register("cpu") -def _alter_conv2d_layout(attrs, inputs, tinfo, F): - - copy_inputs = [s for s in inputs] - new_attrs = {k : attrs[k] for k in attrs.keys()} - - if F.__name__ == 'tvm.relay.op': - # Derive channels for frontends (e.g ONNX) that miss "channel" field. - new_attrs["channels"] = inputs[1].checked_type.shape[attrs['kernel_layout'].index('O')] - - data, kernel = tinfo[0], tinfo[1] - batch_size, in_channel, height, width = get_const_tuple(data.shape) - - groups = attrs.get_int("groups") - out_channel = attrs.get_int("channels") \ - if F.__name__ == 'nnvm.symbol' else new_attrs["channels"] - padding = attrs.get_int_tuple("padding") - strides = attrs.get_int_tuple("strides") - dilation = attrs.get_int_tuple("dilation") - out_dtype = attrs["out_dtype"] - - layout_name = 'layout' if F.__name__ == 'nnvm.symbol' else 'data_layout' - - layout = attrs[layout_name] - kh, kw = attrs.get_int_tuple("kernel_size") - - dtype = data.dtype - out_dtype = dtype if out_dtype in ("same", "") else out_dtype - - kshape = get_shape(kernel.shape, attrs["kernel_layout"], "OIHW") - is_depthwise = groups == kshape[0] and kshape[1] == 1 - - # only optimize for NCHW - if layout != 'NCHW' or attrs["kernel_layout"] != "OIHW": - return None - - if groups != 1 and not is_depthwise: - return None - - dispatch_ctx = autotvm.task.DispatchContext.current - target = tvm.target.current_target() - # query schedule and fallback if necessary - workload = autotvm.task.args_to_workload( - [data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw) \ - if is_depthwise else \ - autotvm.task.args_to_workload( - [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d) - cfg = dispatch_ctx.query(target, workload) - if cfg.is_fallback: - _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise) - - ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] - - new_attrs[layout_name] = 'NCHW%dc' % ic_bn - new_attrs['out_layout'] = 'NCHW%dc' % oc_bn - - new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn), - dtype=data.dtype) - - if is_depthwise: - new_attrs['kernel_layout'] = 'OIHW1i%do' % oc_bn - # Store altered operator's config - new_kernel = tvm.placeholder((out_channel//oc_bn, 1, kh, kw, 1, oc_bn), dtype=kernel.dtype) - new_workload = autotvm.task.args_to_workload( - [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name], - new_attrs['out_layout'], out_dtype], depthwise_conv2d_NCHWc) - dispatch_ctx.update(target, new_workload, cfg) - if F.__name__ == 'nnvm.symbol': - logging.warning("Use native layout for depthwise convolution on NNVM.") - return None - return F.nn.contrib_depthwise_conv2d_nchwc(*copy_inputs, **new_attrs) - - if _is_int8_hw_support(data.dtype, kernel.dtype, target): - # Convert kernel data layout from 4D to 7D - n_elems = 4 - out_channel, _, kh, kw = get_const_tuple(kernel.shape) - data_expr, kernel_expr = inputs - kernel_IHWO = F.transpose(kernel_expr, axes=(1, 2, 3, 0)) - kernel_IHWOo = F.reshape(kernel_IHWO, (in_channel, kh, kw, out_channel//oc_bn, oc_bn)) - kernel_OHWoI = F.transpose(kernel_IHWOo, axes=(3, 1, 2, 4, 0)) - kernel_OHWoIi = F.reshape(kernel_OHWoI, (out_channel//oc_bn, kh, kw, oc_bn, - in_channel//ic_bn, ic_bn)) - kernel_OHWoIie = F.reshape(kernel_OHWoIi, (out_channel//oc_bn, kh, kw, oc_bn, - in_channel//ic_bn, ic_bn//n_elems, n_elems)) - kernel_OIHWioe = F.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5, 3, 6)) - copy_inputs = [data_expr, kernel_OIHWioe] - - # Store altered operator's config. New kernel layout OIHWio4 - new_kernel = tvm.placeholder((out_channel // oc_bn, - in_channel // ic_bn, - kh, - kw, - ic_bn // n_elems, - oc_bn, - n_elems), dtype=kernel.dtype) - - new_workload = autotvm.task.args_to_workload([new_data, - new_kernel, - strides, - padding, - dilation, - new_attrs[layout_name], - new_attrs['out_layout'], - out_dtype], - conv2d_NCHWc_int8) - dispatch_ctx.update(target, new_workload, cfg) - if F.__name__ == 'nnvm.symbol': - logging.warning("Use native layout for int8 convolution on NNVM.") - return None - return F.nn.contrib_conv2d_nchwc_int8(*copy_inputs, **new_attrs) - - out_channel, _, kh, kw = get_const_tuple(kernel.shape) - # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc) - new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn) - # Store altered operator's config - new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, - kh, kw, ic_bn, oc_bn), dtype=kernel.dtype) - new_workload = autotvm.task.args_to_workload( - [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name], - new_attrs['out_layout'], out_dtype], conv2d_NCHWc) - dispatch_ctx.update(target, new_workload, cfg) - - if F.__name__ == 'nnvm.symbol': - return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs) - return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs) - - @conv2d_infer_layout.register("cpu") def _conv2d_infer_layout(workload, cfg): _, data, kernel, strides, padding, dilation, layout, dtype = workload diff --git a/topi/python/topi/x86/conv2d_alter_op.py b/topi/python/topi/x86/conv2d_alter_op.py new file mode 100644 index 0000000..a6333d8 --- /dev/null +++ b/topi/python/topi/x86/conv2d_alter_op.py @@ -0,0 +1,267 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument,no-member +"""Conv2D alter op and legalize functions for x86""" + +import logging + +import tvm +from tvm import relay +from tvm import autotvm +from .conv2d import _get_default_config +from .conv2d_int8 import _is_int8_hw_support, _get_default_config_int8 +from ..util import get_const_tuple, get_shape +from ..nn import conv2d_legalize +from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_NCHWc_int8, conv2d_alter_layout +from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw + +logger = logging.getLogger('topi') + +@conv2d_alter_layout.register("cpu") +def _alter_conv2d_layout(attrs, inputs, tinfo, F): + # Parse the attributes. + groups = attrs.get_int("groups") + padding = attrs.get_int_tuple("padding") + strides = attrs.get_int_tuple("strides") + dilation = attrs.get_int_tuple("dilation") + out_dtype = attrs["out_dtype"] + layout_name = 'layout' if F.__name__ == 'nnvm.symbol' else 'data_layout' + data_layout = attrs[layout_name] + kh, kw = attrs.get_int_tuple("kernel_size") + + data_tensor, kernel_tensor = tinfo[0], tinfo[1] + if attrs[layout_name] == 'NHWC' and attrs['kernel_layout'] == 'HWIO': + batch_size, height, width, in_channel = get_const_tuple(data_tensor.shape) + kh, kw, _, out_channel = get_const_tuple(kernel_tensor.shape) + elif attrs[layout_name] == 'NCHW' and attrs['kernel_layout'] == 'OIHW': + batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape) + out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape) + else: + return None + + data_dtype = data_tensor.dtype + kernel_dtype = kernel_tensor.dtype + out_dtype = data_dtype if out_dtype in ("same", "") else out_dtype + + # Check if depthwise. + kshape = get_shape(kernel_tensor.shape, attrs["kernel_layout"], "OIHW") + is_depthwise = groups == kshape[0] and kshape[1] == 1 + + # Save the input exprs. + copy_inputs = [s for s in inputs] + + # Set the new attrs + new_attrs = {k : attrs[k] for k in attrs.keys()} + new_attrs['channels'] = out_channel + + # Return if the groups is not 1 and depthwise. + if groups != 1 and not is_depthwise: + return None + + # Set workload. Config update. + dispatch_ctx = autotvm.task.DispatchContext.current + target = tvm.target.current_target() + + if is_depthwise: + workload = autotvm.task.args_to_workload( + [data_tensor, kernel_tensor, strides, padding, dilation, out_dtype], + depthwise_conv2d_nchw) + else: + workload = autotvm.task.args_to_workload( + [data_tensor, kernel_tensor, strides, padding, dilation, data_layout, out_dtype], + conv2d) + + cfg = dispatch_ctx.query(target, workload) + if cfg.is_fallback: + if _is_int8_hw_support(data_dtype, kernel_dtype): + _get_default_config_int8(cfg, data_tensor, kernel_tensor, strides, padding, out_dtype, + is_depthwise, data_layout) + else: + _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding, out_dtype, + is_depthwise, data_layout) + + # Get the tiling parameters to set the layout names. + ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] + new_attrs[layout_name] = 'NCHW%dc' % ic_bn + new_attrs['out_layout'] = 'NCHW%dc' % oc_bn + new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn), + dtype=data_dtype) + + if is_depthwise and data_layout == 'NCHW' and attrs['kernel_layout'] == 'OIHW': + new_attrs['kernel_layout'] = 'OIHW1i%do' % oc_bn + # Store altered operator's config + new_kernel = tvm.placeholder((out_channel//oc_bn, 1, kh, kw, 1, oc_bn), dtype=kernel_dtype) + new_workload = autotvm.task.args_to_workload( + [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name], + new_attrs['out_layout'], out_dtype], depthwise_conv2d_NCHWc) + dispatch_ctx.update(target, new_workload, cfg) + if F.__name__ == 'nnvm.symbol': + logging.warning("Use native layout for depthwise convolution on NNVM.") + return None + return F.nn.contrib_depthwise_conv2d_nchwc(*copy_inputs, **new_attrs) + + if _is_int8_hw_support(data_dtype, kernel_dtype): + # Convert kernel data layout from 4D to 7D + n_elems = 4 + data_expr, kernel_expr = inputs + if attrs['kernel_layout'] == 'HWIO': + kernel_IHWO = F.transpose(kernel_expr, axes=(2, 0, 1, 3)) + elif attrs['kernel_layout'] == 'OIHW': + kernel_IHWO = F.transpose(kernel_expr, axes=(1, 2, 3, 0)) + else: + return None + + kernel_IHWOo = F.reshape(kernel_IHWO, (in_channel, kh, kw, out_channel//oc_bn, oc_bn)) + kernel_OHWoI = F.transpose(kernel_IHWOo, axes=(3, 1, 2, 4, 0)) + kernel_OHWoIi = F.reshape(kernel_OHWoI, (out_channel//oc_bn, kh, kw, oc_bn, + in_channel//ic_bn, ic_bn)) + kernel_OHWoIie = F.reshape(kernel_OHWoIi, (out_channel//oc_bn, kh, kw, oc_bn, + in_channel//ic_bn, ic_bn//n_elems, n_elems)) + kernel_OIHWioe = F.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5, 3, 6)) + copy_inputs = [data_expr, kernel_OIHWioe] + + # Store altered operator's config. New kernel layout OIHWio4 + new_kernel = tvm.placeholder((out_channel // oc_bn, + in_channel // ic_bn, + kh, + kw, + ic_bn // n_elems, + oc_bn, + n_elems), dtype=kernel_dtype) + + new_workload = autotvm.task.args_to_workload([new_data, + new_kernel, + strides, + padding, + dilation, + new_attrs[layout_name], + new_attrs['out_layout'], + out_dtype], + conv2d_NCHWc_int8) + dispatch_ctx.update(target, new_workload, cfg) + if F.__name__ == 'nnvm.symbol': + logging.warning("Use native layout for int8 convolution on NNVM.") + return None + return F.nn.contrib_conv2d_nchwc_int8(*copy_inputs, **new_attrs) + + if data_layout == 'NCHW' and attrs['kernel_layout'] == 'OIHW': + # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc) + new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn) + # Store altered operator's config + new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, + kh, kw, ic_bn, oc_bn), dtype=kernel_tensor.dtype) + new_workload = autotvm.task.args_to_workload( + [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name], + new_attrs['out_layout'], out_dtype], conv2d_NCHWc) + dispatch_ctx.update(target, new_workload, cfg) + + if F.__name__ == 'nnvm.symbol': + return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs) + return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs) + return None + + +@conv2d_legalize.register("cpu") +def _conv2d_legalize(attrs, inputs, arg_types): + """Legalizes Conv2D op. + + Parameters + ---------- + attrs : tvm.attrs.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + + # Collect the input tensors. + data_tensor, kernel_tensor = arg_types[0], arg_types[1] + + # Collect the output tensor. + output_tensor = arg_types[2] + + # Legalize if the datatypes are suitable for fast Int8 instructions. Int8 instructions require + # input channel to be a multiple of 4 and output channels to be a multiple of 16. For input + # channels, we pad both the inputs and weights input channels. For output channels, we pad the + # weight and stride_slice the output. + if _is_int8_hw_support(data_tensor.dtype, kernel_tensor.dtype): + # Flags to remember if the expr is modified + ic_modified = False + oc_modified = False + + # Collect the input exprs. + data, kernel = inputs + + # Find the value of input and output channel. + in_channel = -1 + out_channel = -1 + if attrs['data_layout'] == 'NHWC' and attrs['kernel_layout'] == 'HWIO': + in_channel = data_tensor.shape[3].value + out_channel = kernel_tensor.shape[3].value + elif attrs['data_layout'] == 'NCHW' and attrs['kernel_layout'] == 'OIHW': + in_channel = data_tensor.shape[1].value + out_channel = kernel_tensor.shape[0].value + else: + return None + + if in_channel % 4 != 0: + new_in_channel = ((in_channel + 4) // 4) * 4 + diff = new_in_channel - in_channel + if attrs['data_layout'] == 'NHWC' and attrs['kernel_layout'] == 'HWIO': + data = relay.nn.pad(data, pad_width=((0, 0), (0, 0), (0, 0), (0, diff))) + kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, diff), (0, 0))) + ic_modified = True + elif attrs['data_layout'] == 'NCHW' and attrs['kernel_layout'] == 'OIHW': + pad_width = ((0, 0), (0, diff), (0, 0), (0, 0)) + data = relay.nn.pad(data, pad_width=pad_width) + kernel = relay.nn.pad(kernel, pad_width=pad_width) + ic_modified = True + else: + return None + + new_out_channel = out_channel + if out_channel % 16 != 0: + new_out_channel = ((out_channel + 16) // 16) * 16 + diff = new_out_channel - out_channel + if attrs['data_layout'] == 'NHWC' and attrs['kernel_layout'] == 'HWIO': + kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, 0), (0, diff))) + oc_modified = True + elif attrs['data_layout'] == 'NCHW' and attrs['kernel_layout'] == 'OIHW': + kernel = relay.nn.pad(kernel, pad_width=((0, diff), (0, 0), (0, 0), (0, 0))) + oc_modified = True + else: + return None + + if not (ic_modified or oc_modified): + return None + + if ic_modified and not oc_modified: + return relay.nn.conv2d(data, kernel, **attrs) + + if oc_modified: + new_attrs = {k: attrs[k] for k in attrs.keys()} + new_attrs['channels'] = new_out_channel + out = tvm.relay.nn.conv2d(data, kernel, **new_attrs) + original_out_shape = [x.value for x in output_tensor.shape] + return relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape) + return None diff --git a/topi/python/topi/x86/conv2d_avx_1x1.py b/topi/python/topi/x86/conv2d_avx_1x1.py index 88112b6..3d0978c 100644 --- a/topi/python/topi/x86/conv2d_avx_1x1.py +++ b/topi/python/topi/x86/conv2d_avx_1x1.py @@ -57,6 +57,36 @@ def _fallback_schedule(cfg, wkl): raise ValueError("cannot decide default schedule for workload: {}".format(wkl)) +def _fallback_schedule_int8(cfg, wkl): + simd_width = get_fp32_len() + HPAD, WPAD = wkl.hpad, wkl.wpad + HSTR, WSTR = wkl.hstride, wkl.wstride + out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1 + out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1 + + oc_bn = 16 + assert wkl.out_filter % oc_bn == 0 + + ic_bn = 1 + for bn in range(oc_bn, 0, -4): + if wkl.in_filter % bn == 0: + ic_bn = bn + break + assert wkl.in_filter % 4 == 0 + + for ow_factor in range(out_width, 0, -1): + if out_width % ow_factor == 0: + for oh_factor in range(out_height, 0, -1): + if out_height % oh_factor == 0 and ow_factor * oh_factor < 32: + cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn]) + cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn]) + cfg["tile_oh"] = OtherOptionEntity(oh_factor) + cfg["tile_ow"] = SplitEntity([out_width // ow_factor, ow_factor]) + return + raise ValueError("cannot decide default schedule for workload: {}".format(wkl)) + + + def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last): # fetch schedule ic_bn, oc_bn, oh_factor, ow_factor = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1], diff --git a/topi/python/topi/x86/conv2d_avx_common.py b/topi/python/topi/x86/conv2d_avx_common.py index 5b17212..a7f38ac 100644 --- a/topi/python/topi/x86/conv2d_avx_common.py +++ b/topi/python/topi/x86/conv2d_avx_common.py @@ -55,6 +55,34 @@ def _fallback_schedule(cfg, wkl): cfg["unroll_kw"] = OtherOptionEntity(False) +def _fallback_schedule_int8(cfg, wkl): + simd_width = get_fp32_len() + HPAD, WPAD = wkl.hpad, wkl.wpad + HSTR, WSTR = wkl.hstride, wkl.wstride + out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1 + + oc_bn = 16 + assert wkl.out_filter % oc_bn == 0 + + ic_bn = 1 + for bn in range(oc_bn, 0, -4): + if wkl.in_filter % bn == 0: + ic_bn = bn + break + assert wkl.in_filter % 4 == 0 + + reg_n = 1 + for n in range(31, 0, -1): + if out_width % n == 0: + reg_n = n + break + + cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn]) + cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn]) + cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n]) + cfg["unroll_kw"] = OtherOptionEntity(False) + + def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last): # fetch schedule ic_bn, oc_bn, reg_n, unroll_kw = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1], diff --git a/topi/python/topi/x86/conv2d_int8.py b/topi/python/topi/x86/conv2d_int8.py index c5ef585..3f65db4 100644 --- a/topi/python/topi/x86/conv2d_int8.py +++ b/topi/python/topi/x86/conv2d_int8.py @@ -22,12 +22,52 @@ import tvm from tvm import autotvm from tvm.autotvm.task import get_config from tvm.autotvm.task.topi_integration import deserialize_args +from ..nn.conv2d import _get_workload as _get_conv2d_workload from .. import generic, tag from ..util import get_const_tuple from ..nn.conv2d import conv2d_NCHWc_int8 from .. import nn from . import conv2d_avx_1x1, conv2d_avx_common +def _get_default_config_int8(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False, + layout='NCHW'): + """ + Get default schedule config for the workload + """ + assert not is_depthwise, "Depthwise Int8 not supported" + wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout) + is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1 + if is_kernel_1x1: + conv2d_avx_1x1._fallback_schedule_int8(cfg, wkl) + else: + conv2d_avx_common._fallback_schedule_int8(cfg, wkl) + + +def _is_int8_hw_support(data_dtype, kernel_dtype): + """ + Checks to ensure that we can use Intel DLBoost instructions + 1) The datatypes are correct. + 2) LLVM version has support for the instructions. + 3) Target is skylake and above. + """ + # 1) Check datatypes + is_dtype_support = data_dtype == 'uint8' and kernel_dtype == 'int8' + + # 2) Check LLVM support + llvm_intrin_fast_int8 = "llvm.x86.avx512.pmaddubs.w.512" + llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(llvm_intrin_fast_int8) + is_llvm_support = llvm_id != 0 + + # 3) Check target + target = tvm.target.current_target() + is_target_support = False + for opt in target.options: + if opt == '-mcpu=skylake-avx512': + is_target_support = True + + return is_dtype_support and is_llvm_support and is_target_support + + def _create_tuning_space_int8(cfg, data, kernel, strides, padding, dilation, layout): """Create schedule configuration from input arguments""" dshape = get_const_tuple(data.shape) -- 2.7.4