* [CUDA] Op strategy changes for Int8 schedules.
* Applying Haichen's suggestions.
* Make 4D output work for task extraction.
* Make x86 work.
* Fix lint.
* Lint fixes.
* Tests, comments, out channel a multiple of 4.
* Topi test.
Co-authored-by: Ubuntu <ubuntu@ip-172-31-38-96.us-west-2.compute.internal>
# 3) Clip/cast to change the out dtype.
_res = relay.clip(_res,
- a_min=float(tvm.api.min_value(out_dtype).value),
- a_max=float(tvm.api.max_value(out_dtype).value))
+ a_min=float(tvm.tir.op.min_value(out_dtype).value),
+ a_max=float(tvm.tir.op.max_value(out_dtype).value))
_res = relay.cast(_res, out_dtype)
return _res
_op.multiply(_op.cast(bias_data, 'float32'), bias_requantize_scale)
rounded_bias = _op.round(multiplied_bias)
clipped_bias = _op.clip(rounded_bias,
- a_min=tvm.api.min_value('int32').value,
- a_max=tvm.api.max_value('int32').value)
+ a_min=tvm.tir.op.min_value('int32').value,
+ a_max=tvm.tir.op.max_value('int32').value)
requantized_bias = _op.cast(clipped_bias, 'int32')
res = _op.nn.bias_add(res, requantized_bias, axis=-1)
enable_float_output = attrs.get_bool('enable_float_output', False)
if groups == 1:
if layout == "NCHW":
- # TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8.
assert kernel_layout == "OIHW"
- strategy.add_implementation(
- wrap_compute_conv2d(topi.cuda.conv2d_nchw),
- wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
- name="conv2d_nchw.cuda")
+ if data.dtype in ('int8', 'uint8') and kernel.dtype in ('int8', 'uint8'):
+ assert data.dtype == kernel.dtype
+ strategy.add_implementation(
+ wrap_compute_conv2d(topi.cuda.conv2d_nchw_int8),
+ wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_int8),
+ name="conv2d_nchw_int8.cuda")
+ else:
+ strategy.add_implementation(
+ wrap_compute_conv2d(topi.cuda.conv2d_nchw),
+ wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
+ name="conv2d_nchw.cuda")
_, _, kh, kw = get_const_tuple(kernel.shape)
if 2 < kh < 8 and 2 < kw < 8 and kh == kw and stride_h == 1 and stride_w == 1 and \
dilation_h == 1 and dilation_w == 1:
if is_fast_int8_on_intel():
return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.dense)
return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
+
+#####################
+# CUDA legalizations.
+#####################
+
+@qnn_conv2d_legalize.register('cuda')
+def _qnn_conv2d_legalize_cuda(attrs, inputs, types):
+    """Legalize qnn.conv2d for the 'cuda' target via helper_change_dtypes_to_be_same."""
+    # CUDA prefers the dtypes to be same, so the work is delegated to the
+    # shared dtype-unifying helper, which is told to rebuild a qnn conv2d.
+    qnn_conv2d_op = relay.qnn.op.conv2d
+    return helper_change_dtypes_to_be_same(attrs, inputs, types, qnn_conv2d_op)
+
+@qnn_dense_legalize.register('cuda')
+def _qnn_dense_legalize_cuda(attrs, inputs, types):
+    """Legalize qnn.dense for the 'cuda' target via helper_change_dtypes_to_be_same."""
+    # CUDA prefers the dtypes to be same, so the work is delegated to the
+    # shared dtype-unifying helper, which is told to rebuild a qnn dense.
+    qnn_dense_op = relay.qnn.op.dense
+    return helper_change_dtypes_to_be_same(attrs, inputs, types, qnn_dense_op)
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+ ###########################################
+ # Check transformations for CUDA platforms.
+ ###########################################
+ with tvm.target.create('cuda'):
+ legalized_mod = relay.qnn.transform.Legalize()(mod)
+ assert 'cast' in legalized_mod.astext() and "qnn" in legalized_mod.astext()
+
def test_qnn_legalize_qnn_dense():
def _get_mod(data_dtype, kernel_dtype):
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+ ###########################################
+ # Check transformations for CUDA platforms.
+ ###########################################
+ with tvm.target.create('cuda'):
+ legalized_mod = relay.qnn.transform.Legalize()(mod)
+ assert 'cast' in legalized_mod.astext() and "qnn" in legalized_mod.astext()
+
if __name__ == "__main__":
test_qnn_legalize()
from .. import nn
from ..util import get_const_tuple
from .conv2d_winograd import _infer_tile_size
+from ..nn import conv2d_legalize
logger = logging.getLogger('topi')
return relay.nn.conv2d(*inputs, **new_attrs)
return None
+
+@conv2d_legalize.register("cuda")
+def _conv2d_legalize(attrs, inputs, arg_types):
+    """Legalizes Conv2D op.
+
+    For int8/uint8 data in NCHW/OIHW layout with unit dilation and groups == 1,
+    the input and output channels are padded (with relay.nn.pad) up to a
+    multiple of 4 so that the int8 schedule can be used. When the output
+    channels were padded, the result is sliced back to the original output
+    shape.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    arg_types : list of types
+        List of input and output types, ordered as (data, kernel, output)
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr, or None when this conv2d is not legalized
+    """
+
+    # Dilation not supported yet. Return None if dilation is not (1, 1)
+    dilation = attrs.get_int_tuple("dilation")
+    if not (dilation[0] == 1 and dilation[1] == 1):
+        return None
+
+    # No legalization for depthwise (or any grouped) convolutions yet.
+    groups = attrs.get_int("groups")
+    if groups != 1:
+        return None
+
+    # Collect the input tensors.
+    data_tensor, kernel_tensor = arg_types[0], arg_types[1]
+    data_dtype = data_tensor.dtype
+
+    # Collect the output tensor.
+    output_tensor = arg_types[2]
+
+    # Collect the input exprs.
+    data, kernel = inputs
+
+    # Get the conv attrs
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
+
+    # Get the layouts. Only NCHW data with an OIHW kernel is handled below.
+    data_layout = attrs['data_layout']
+    kernel_layout = attrs['kernel_layout']
+
+    # Channels are only padded for the int8 schedule.
+    if data_dtype not in ['int8', 'uint8']:
+        return None
+    if not (data_layout == 'NCHW' and kernel_layout == "OIHW"):
+        return None
+
+    in_channel = data_tensor.shape[1].value
+    out_channel = kernel_tensor.shape[0].value
+
+    # Pad input channel: axis 1 of both the NCHW data and the OIHW kernel.
+    if in_channel % 4 != 0:
+        new_in_channel = (in_channel + 3) // 4 * 4
+        diff = new_in_channel - in_channel
+        pad_width = ((0, 0), (0, diff), (0, 0), (0, 0))
+        data = relay.nn.pad(data, pad_width=pad_width)
+        kernel = relay.nn.pad(kernel, pad_width=pad_width)
+
+    # Output channels already aligned: the output shape is unchanged.
+    if out_channel % 4 == 0:
+        return relay.nn.conv2d(data, kernel, **new_attrs)
+
+    # Pad output channel: axis 0 of the OIHW kernel.
+    new_out_channel = (out_channel + 3) // 4 * 4
+    diff = new_out_channel - out_channel
+    kernel = relay.nn.pad(kernel, pad_width=((0, diff), (0, 0), (0, 0), (0, 0)))
+    new_attrs['channels'] = new_out_channel
+    out = relay.nn.conv2d(data, kernel, **new_attrs)
+
+    # Drop the padded output channels again so consumers see the original shape.
+    original_out_shape = [x.value for x in output_tensor.shape]
+    return relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape)
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
+# pylint: disable=no-value-for-parameter
"""Int8 conv2d in NCHWc layout"""
import tvm
from tvm import te
from .injective import schedule_injective_from_existing
from .tensor_intrin import dp4a
from ..nn.pad import pad
+from ..nn.conv2d import unpack_NCHWc_to_nchw
from ..nn.util import get_pad_tuple
from ..util import get_const_tuple, traverse_inline
+def conv2d_nchw_int8(data, kernel, strides, padding, dilation, out_dtype='int32'):
+    """Compute int8 conv2d on NCHW input by going through conv2d_NCHWc_int8.
+
+    The result of conv2d_NCHWc_int8 is passed through unpack_NCHWc_to_nchw
+    before it is returned. data and kernel must share the same dtype, which
+    must be 'int8' or 'uint8'.
+    """
+    int8_dtypes = ('int8', 'uint8')
+    assert data.dtype in int8_dtypes
+    assert kernel.dtype in int8_dtypes
+    assert data.dtype == kernel.dtype
+    # NOTE(review): conv2d_NCHWc_int8 is declared with a leading `cfg`
+    # parameter that is not passed here - presumably supplied by its autotvm
+    # registration wrapper; confirm.
+    nchwc_out = conv2d_NCHWc_int8(data, kernel, strides, padding, dilation, "NCHW", out_dtype)
+    return unpack_NCHWc_to_nchw(nchwc_out, out_dtype)
+
+def schedule_conv2d_nchw_int8(outs):
+    """Create the schedule for conv2d_nchw_int8 by reusing schedule_conv2d_NCHWc_int8."""
+    nchwc_schedule = schedule_conv2d_NCHWc_int8(outs)
+    return nchwc_schedule
+
@autotvm.register_topi_compute("conv2d_NCHWc_int8.cuda")
def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, layout, out_dtype):
"""Convolution operator in NCHW[x]c layout for int8.
output = s.outputs[0].output(0)
# tile and bind spatial axes
- n, f, y, x, c = s[output].op.axis
+ if len(s[output].op.axis) == 5:
+ n, f, y, x, c = s[output].op.axis
+ else:
+ # For task extraction of auto-tuning, the expected output is 4D. Because auto-tuning tasks
+ # are created from scratch, the real auto-tuning will still happen on the 5D output.
+ n, f, y, x = s[output].op.axis
+
cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
s[data_vec].parallel(parallel_axis)
- oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis
+ # conv2d_nchwc_int8 has 7D kernel
+ oc_chunk, ic_chunk, oh, ow, ic_block, oc_block, _ = s[kernel_vec].op.axis
s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
oc_bn = cfg["tile_oc"].size[-1]
if oc_bn > 1:
s[CC].unroll(oc_f_inner)
if C != O:
- batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
- ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
- s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
- parallel_axis = s[O].fuse(batch, oc_chunk, oh)
- s[C].compute_at(s[O], parallel_axis)
- s[O].vectorize(oc_block)
- s[O].parallel(parallel_axis)
+ out_ndim = len(s[O].op.axis)
+ if out_ndim == 5:
+ batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
+ ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
+ s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+ parallel_axis = s[O].fuse(batch, oc_chunk, oh)
+ s[C].compute_at(s[O], parallel_axis)
+ s[O].vectorize(oc_block)
+ s[O].parallel(parallel_axis)
+ elif out_ndim == 4:
+ batch, oc, oh, ow = s[O].op.axis
+ ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
+ oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
+ s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+ parallel_axis = s[O].fuse(batch, oc_chunk, oh)
+ s[C].compute_at(s[O], parallel_axis)
+ s[O].vectorize(oc_block)
+ s[O].parallel(parallel_axis)
+ else:
+ raise ValueError("Unsupported output ndim: %s" % out_ndim)
return s
parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
s[data_vec].parallel(parallel_axis)
- oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis
+ # Conv2d int8 schedule has 7D kernel
+ oc_chunk, ic_chunk, oh, ow, ic_block, oc_block, _ = s[kernel_vec].op.axis
s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
oc_bn = cfg["tile_oc"].size[-1]
if oc_bn > 1:
s[CC].unroll(oh_inner)
if C != O:
- batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
- oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
- ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
- s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
-
- parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
- s[C].compute_at(s[O], parallel_axis)
- s[O].vectorize(oc_block)
- s[O].parallel(parallel_axis)
+ out_ndim = len(s[O].op.axis)
+ if out_ndim == 5:
+ batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
+ oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
+ ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
+ s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
+
+ parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
+ s[C].compute_at(s[O], parallel_axis)
+ s[O].vectorize(oc_block)
+ s[O].parallel(parallel_axis)
+ elif out_ndim == 4:
+ batch, oc, oh, ow = s[O].op.axis
+ oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
+ oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
+ ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
+ s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
+
+ parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
+ s[C].compute_at(s[O], parallel_axis)
+ s[O].vectorize(oc_block)
+ s[O].parallel(parallel_axis)
+ else:
+ raise ValueError("Unsupported output ndim: %s" % out_ndim)
return s
check_device(device)
+def verify_conv2d_nchw_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
+    """Check topi.cuda.conv2d_nchw_int8 against the NumPy reference conv2d.
+
+    Builds int8 NCHW data and OIHW weights, runs the CUDA int8 compute and
+    schedule, and compares the result with conv2d_nchw_python (rtol=1e-5).
+    The check is skipped when the CUDA device does not exist or lacks int8
+    intrinsics.
+
+    Parameters
+    ----------
+    batch, in_channel, in_size, num_filter, kernel, stride : int
+        Workload description; input and kernel are square (in_size, kernel).
+    padding : int, str or tuple
+        Padding spec, forwarded to get_pad_tuple and to the conv2d calls.
+    dilation : int
+        Dilation applied to both spatial axes.
+    add_bias, add_relu : bool
+        Whether to append a bias add and/or a ReLU after the convolution.
+    """
+    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+    padding_sum = pad_top + pad_left + pad_bottom + pad_right
+    workload = (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
+    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % workload)
+
+    in_height = in_width = in_size
+
+    A = te.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='int8')
+    W = te.placeholder((num_filter, in_channel, kernel, kernel), name='W', dtype='int8')
+    bias = te.placeholder((num_filter, 1, 1), name='bias', dtype='int8')
+
+    a_shape = get_const_tuple(A.shape)
+    w_shape = get_const_tuple(W.shape)
+    bias_shape = get_const_tuple(bias.shape)
+    dtype = A.dtype
+
+    @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+    def get_ref_data():
+        a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
+        w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
+        # NOTE: uniform samples lie in [0, 1), so the int8 cast truncates them
+        # to 0 and the bias is all zeros.
+        b_np = np.random.uniform(size=bias_shape).astype(dtype)
+        dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+        c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
+
+        if add_bias:
+            c_np += b_np
+        if add_relu:
+            c_np = np.maximum(c_np, 0)
+
+        return a_np, w_np, b_np, c_np
+
+    a_np, w_np, b_np, c_np = get_ref_data()
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version):
+            print("Skip because int8 intrinsics are not available")
+            return
+
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            C = topi.cuda.conv2d_nchw_int8(A, W, (stride, stride), padding, (dilation, dilation),
+                                           dtype)
+            if add_bias:
+                C = topi.add(C, bias)
+            if add_relu:
+                C = topi.nn.relu(C)
+            s = topi.cuda.schedule_conv2d_nchw_int8([C])
+
+        a = tvm.nd.array(a_np, ctx)
+        w = tvm.nd.array(w_np, ctx)
+        b = tvm.nd.array(b_np, ctx)
+        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
+        func_name = "relu_%d_%d_%d_%d_%d_%d_%d_%d" % workload
+        # Build once per configuration; the bias placeholder is only an
+        # argument of the built function when a bias add is part of the graph.
+        if add_bias:
+            func = tvm.build(s, [A, W, bias, C], device, name=func_name)
+            func(a, w, b, c)
+        else:
+            func = tvm.build(s, [A, W, C], device, name=func_name)
+            func(a, w, c)
+        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
+
+    for device in ["cuda"]:
+        check_device(device)
+
+
def test_conv2d_nchw():
with Int8Fallback():
# ResNet18 workloads where channels in / out are multiple of oc_block_factor
verify_conv2d_NCHWc_int8(1, 64, 56, 64, 3, 1, "VALID", add_bias=True, add_relu=True)
verify_conv2d_NCHWc_int8(1, 64, 56, 64, 24, 1, "SAME", add_bias=True, add_relu=True)
+ # Conv2d NCHW int8 schedule testing. Internally, it uses the NCHWc schedule, so we only
+ # perform basic testing - one test for each scenario (batch, dilation, etc.).
+ verify_conv2d_nchw_int8(1, 64, 56, 64, 3, 1, 1)
+ verify_conv2d_nchw_int8(1, 64, 56, 64, 3, 1, 1, add_relu=True)
+ verify_conv2d_nchw_int8(1, 64, 56, 64, 3, 1, 1, dilation=2)
+ verify_conv2d_nchw_int8(9, 64, 56, 64, 3, 1, 1)
+ verify_conv2d_nchw_int8(4, 4, 4, 4, 4, 4, 4)
+ verify_conv2d_nchw_int8(1, 32, 149, 32, 3, 1, 0)
+ verify_conv2d_nchw_int8(7, 32, 149, 32, 3, 1, 0)
+ verify_conv2d_nchw_int8(1, 32, 35, 64, 7, 2, (0, 0, 1, 1))
+
if __name__ == "__main__":
test_conv2d_nchw()