[Strategy] Support for Int8 schedules - CUDA/x86 (#5031)
author    Animesh Jain <anijain@umich.edu>
          Thu, 12 Mar 2020 16:51:09 +0000 (09:51 -0700)
committer GitHub <noreply@github.com>
          Thu, 12 Mar 2020 16:51:09 +0000 (09:51 -0700)
* [CUDA] Op strategy changes for Int8 schedules.

* Applying Haichen's suggestions.

* Make 4D output work for task extraction.

* Make x86 work.

* Fix lint.

* Lint fixes.

* Tests, comments, out channel a multiple of 4.

* Topi test.

Co-authored-by: Ubuntu <ubuntu@ip-172-31-38-96.us-west-2.compute.internal>
python/tvm/relay/frontend/mxnet.py
python/tvm/relay/op/strategy/cuda.py
python/tvm/relay/qnn/op/legalizations.py
tests/python/relay/test_pass_qnn_legalize.py
topi/python/topi/cuda/conv2d_alter_op.py
topi/python/topi/cuda/conv2d_int8.py
topi/python/topi/generic/conv2d.py
topi/tests/python/test_topi_conv2d_int8.py

diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index c2bfd75..ba93bb2 100644
@@ -1373,8 +1373,8 @@ def _qnn_conv(inputs, attrs, subgraphs, params):
 
         # 3) Clip/cast to change the out dtype.
         _res = relay.clip(_res,
-                          a_min=float(tvm.api.min_value(out_dtype).value),
-                          a_max=float(tvm.api.max_value(out_dtype).value))
+                          a_min=float(tvm.tir.op.min_value(out_dtype).value),
+                          a_max=float(tvm.tir.op.max_value(out_dtype).value))
         _res = relay.cast(_res, out_dtype)
         return _res
 
@@ -1647,8 +1647,8 @@ def _qnn_fully_connected(inputs, attrs, subgraphs, params):
                 _op.multiply(_op.cast(bias_data, 'float32'), bias_requantize_scale)
             rounded_bias = _op.round(multiplied_bias)
             clipped_bias = _op.clip(rounded_bias,
-                                    a_min=tvm.api.min_value('int32').value,
-                                    a_max=tvm.api.max_value('int32').value)
+                                    a_min=tvm.tir.op.min_value('int32').value,
+                                    a_max=tvm.tir.op.max_value('int32').value)
             requantized_bias = _op.cast(clipped_bias, 'int32')
             res = _op.nn.bias_add(res, requantized_bias, axis=-1)
         enable_float_output = attrs.get_bool('enable_float_output', False)
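
The two hunks above only move the min/max helpers from the deprecated tvm.api namespace to tvm.tir.op. A minimal sketch of what the new calls evaluate to (illustrative, not part of the patch):

    import tvm

    # min_value/max_value return IntImm nodes; .value yields the plain Python
    # number used as the clip bound, e.g. the int8 range used above.
    lo = tvm.tir.op.min_value('int8').value   # -128
    hi = tvm.tir.op.max_value('int8').value   # 127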
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index b1e77bd..e5eff1c 100644
@@ -85,12 +85,18 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
 
     if groups == 1:
         if layout == "NCHW":
-            # TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8.
             assert kernel_layout == "OIHW"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.conv2d_nchw),
-                wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
-                name="conv2d_nchw.cuda")
+            if data.dtype in ('int8', 'uint8') and kernel.dtype in ('int8', 'uint8'):
+                assert data.dtype == kernel.dtype
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.cuda.conv2d_nchw_int8),
+                    wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_int8),
+                    name="conv2d_nchw_int8.cuda")
+            else:
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.cuda.conv2d_nchw),
+                    wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
+                    name="conv2d_nchw.cuda")
             _, _, kh, kw = get_const_tuple(kernel.shape)
             if 2 < kh < 8 and 2 < kw < 8 and kh == kw and stride_h == 1 and stride_w == 1 and \
                 dilation_h == 1 and dilation_w == 1:
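
With the strategy change above, an NCHW conv2d whose data and kernel are both int8 (or both uint8) is dispatched to conv2d_nchw_int8.cuda instead of the fp32 schedule. A rough usage sketch of a graph that would take this path (shapes and names are illustrative, not taken from the patch):

    import tvm
    from tvm import relay

    # Both operands are int8 and the out channels are a multiple of 4, so the
    # strategy above should select conv2d_nchw_int8.cuda on a CUDA target.
    data = relay.var("data", shape=(1, 64, 56, 56), dtype="int8")
    weight = relay.var("weight", shape=(64, 64, 3, 3), dtype="int8")
    conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1),
                           channels=64, out_dtype="int32")
    mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))
    graph, lib, params = relay.build(mod, target="cuda")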
diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py
index ad71313..f9874b7 100644
@@ -264,3 +264,17 @@ def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
     if is_fast_int8_on_intel():
         return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.dense)
     return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
+
+#####################
+# CUDA legalizations.
+#####################
+
+@qnn_conv2d_legalize.register('cuda')
+def _qnn_conv2d_legalize_cuda(attrs, inputs, types):
+    # CUDA prefers the dtypes to be the same.
+    return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
+
+@qnn_dense_legalize.register('cuda')
+def _qnn_dense_legalize_cuda(attrs, inputs, types):
+    # CUDA prefers the dtypes to be the same.
+    return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py
index 7d3d9cc..ed05096 100644
@@ -177,6 +177,13 @@ def test_qnn_legalize_qnn_conv2d():
         legalized_mod = relay.qnn.transform.Legalize()(mod)
         assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
 
+    ###########################################
+    # Check transformations for CUDA platforms.
+    ###########################################
+    with tvm.target.create('cuda'):
+        legalized_mod = relay.qnn.transform.Legalize()(mod)
+        assert 'cast' in legalized_mod.astext() and "qnn" in legalized_mod.astext()
+
 
 def test_qnn_legalize_qnn_dense():
     def _get_mod(data_dtype, kernel_dtype):
@@ -257,6 +264,13 @@ def test_qnn_legalize_qnn_dense():
         legalized_mod = relay.qnn.transform.Legalize()(mod)
         assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
 
+    ###########################################
+    # Check transformations for CUDA platforms.
+    ###########################################
+    with tvm.target.create('cuda'):
+        legalized_mod = relay.qnn.transform.Legalize()(mod)
+        assert 'cast' in legalized_mod.astext() and "qnn" in legalized_mod.astext()
+
 
 if __name__ == "__main__":
     test_qnn_legalize()
diff --git a/topi/python/topi/cuda/conv2d_alter_op.py b/topi/python/topi/cuda/conv2d_alter_op.py
index b598271..8d9e86c 100644
@@ -26,6 +26,7 @@ from tvm import autotvm
 from .. import nn
 from ..util import get_const_tuple
 from .conv2d_winograd import _infer_tile_size
+from ..nn import conv2d_legalize
 
 logger = logging.getLogger('topi')
 
@@ -135,3 +136,82 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         return relay.nn.conv2d(*inputs, **new_attrs)
 
     return None
+
+@conv2d_legalize.register("cuda")
+def _conv2d_legalize(attrs, inputs, arg_types):
+    """Legalizes Conv2D op.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    arg_types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+
+    # Dilation not supported yet. Return None if dilation is not (1, 1)
+    dilation = attrs.get_int_tuple("dilation")
+    if not (dilation[0] == 1 and dilation[1] == 1):
+        return None
+
+    # No legalization for depthwise convolutions yet.
+    groups = attrs.get_int("groups")
+    if groups != 1:
+        return None
+
+    # Collect the input tensors.
+    data_tensor, kernel_tensor = arg_types[0], arg_types[1]
+    data_dtype = data_tensor.dtype
+
+    # Collect the output tensor.
+    output_tensor = arg_types[2]
+
+    # Collect the input exprs.
+    data, kernel = inputs
+
+    # Get the conv attrs
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
+
+    # Get data layout. Return None if not NCHW
+    data_layout = attrs['data_layout']
+    kernel_layout = attrs['kernel_layout']
+
+    # Pad input and output channels to use int8 schedule.
+    if data_dtype in ['int8', 'uint8']:
+        if data_layout == 'NCHW' and kernel_layout == "OIHW":
+            oc_modified = False
+            in_channel = data_tensor.shape[1].value
+            out_channel = kernel_tensor.shape[0].value
+
+            # Pad input channel
+            if in_channel % 4 != 0:
+                new_in_channel = ((in_channel + 4) // 4) * 4
+                diff = new_in_channel - in_channel
+                pad_width = ((0, 0), (0, diff), (0, 0), (0, 0))
+                data = relay.nn.pad(data, pad_width=pad_width)
+                kernel = relay.nn.pad(kernel, pad_width=pad_width)
+
+            # Pad output channel
+            new_out_channel = out_channel
+            if out_channel % 4 != 0:
+                new_out_channel = ((out_channel + 4) // 4) * 4
+                diff = new_out_channel - out_channel
+                kernel = relay.nn.pad(kernel, pad_width=((0, diff), (0, 0), (0, 0), (0, 0)))
+                oc_modified = True
+
+            if oc_modified:
+                new_attrs['channels'] = new_out_channel
+                out = tvm.relay.nn.conv2d(data, kernel, **new_attrs)
+                original_out_shape = [x.value for x in output_tensor.shape]
+                out = relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape)
+            else:
+                out = relay.nn.conv2d(data, kernel, **new_attrs)
+            return out
+    return None
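
The legalization above rounds both channel dimensions up to the next multiple of 4 so the dp4a-based NCHWc int8 schedule can be used, then slices the padded output channels back off. A small worked example of the arithmetic, on a hypothetical 3-in / 10-out channel workload (not from the patch):

    # in_channel = 3 and out_channel = 10, neither a multiple of 4
    in_channel, out_channel = 3, 10
    new_in_channel = ((in_channel + 4) // 4) * 4    # 4  -> pad data/kernel C by 1
    new_out_channel = ((out_channel + 4) // 4) * 4  # 12 -> pad kernel O by 2
    # data   (1, 3, 224, 224)   -> relay.nn.pad -> (1, 4, 224, 224)
    # kernel (10, 3, 3, 3)      -> relay.nn.pad -> (12, 4, 3, 3)
    # conv output (1, 12, H, W) -> relay.strided_slice back to (1, 10, H, W)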
diff --git a/topi/python/topi/cuda/conv2d_int8.py b/topi/python/topi/cuda/conv2d_int8.py
index ad97fa6..bc8aa35 100644
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=invalid-name
+# pylint: disable=no-value-for-parameter
 """Int8 conv2d in NCHWc layout"""
 import tvm
 from tvm import te
@@ -23,10 +24,23 @@ from tvm import autotvm
 from .injective import schedule_injective_from_existing
 from .tensor_intrin import dp4a
 from ..nn.pad import pad
+from ..nn.conv2d import unpack_NCHWc_to_nchw
 from ..nn.util import get_pad_tuple
 from ..util import get_const_tuple, traverse_inline
 
 
+def conv2d_nchw_int8(data, kernel, strides, padding, dilation, out_dtype='int32'):
+    """Compute conv2d internally using conv2d_nchwc layout for int8 dtype"""
+    assert data.dtype in ('int8', 'uint8')
+    assert kernel.dtype in ('int8', 'uint8')
+    assert data.dtype == kernel.dtype
+    packed_out = conv2d_NCHWc_int8(data, kernel, strides, padding, dilation, "NCHW", out_dtype)
+    return unpack_NCHWc_to_nchw(packed_out, out_dtype)
+
+def schedule_conv2d_nchw_int8(outs):
+    """Create schedule for tensors"""
+    return schedule_conv2d_NCHWc_int8(outs)
+
 @autotvm.register_topi_compute("conv2d_NCHWc_int8.cuda")
 def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, layout, out_dtype):
     """Convolution operator in NCHW[x]c layout for int8.
@@ -205,7 +219,13 @@ def _schedule_conv2d_NCHWc_int8(cfg, s, output):
         output = s.outputs[0].output(0)
 
     # tile and bind spatial axes
-    n, f, y, x, c = s[output].op.axis
+    if len(s[output].op.axis) == 5:
+        n, f, y, x, c = s[output].op.axis
+    else:
+        # For auto-tuning task extraction, the expected output is 4D. Since auto-tuning
+        # tasks are created from scratch, the real tuning will still happen on the 5D output.
+        n, f, y, x = s[output].op.axis
+
     cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
     cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
     cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
diff --git a/topi/python/topi/generic/conv2d.py b/topi/python/topi/generic/conv2d.py
index 69984a1..2d9f78b 100644
@@ -144,7 +144,8 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out,
         parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
         s[data_vec].parallel(parallel_axis)
 
-        oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis
+        # conv2d_nchwc_int8 has 7D kernel
+        oc_chunk, ic_chunk, oh, ow, ic_block, oc_block, _ = s[kernel_vec].op.axis
         s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
         oc_bn = cfg["tile_oc"].size[-1]
         if oc_bn > 1:
@@ -189,13 +190,26 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out,
     s[CC].unroll(oc_f_inner)
 
     if C != O:
-        batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
-        ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
-        s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
-        parallel_axis = s[O].fuse(batch, oc_chunk, oh)
-        s[C].compute_at(s[O], parallel_axis)
-        s[O].vectorize(oc_block)
-        s[O].parallel(parallel_axis)
+        out_ndim = len(s[O].op.axis)
+        if out_ndim == 5:
+            batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
+            ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
+            s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+            parallel_axis = s[O].fuse(batch, oc_chunk, oh)
+            s[C].compute_at(s[O], parallel_axis)
+            s[O].vectorize(oc_block)
+            s[O].parallel(parallel_axis)
+        elif out_ndim == 4:
+            batch, oc, oh, ow = s[O].op.axis
+            ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
+            oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
+            s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+            parallel_axis = s[O].fuse(batch, oc_chunk, oh)
+            s[C].compute_at(s[O], parallel_axis)
+            s[O].vectorize(oc_block)
+            s[O].parallel(parallel_axis)
+        else:
+            raise ValueError("Unsupported output ndim: %s" % out_ndim)
 
     return s
 
@@ -234,7 +248,8 @@ def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec, conv_out,
         parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
         s[data_vec].parallel(parallel_axis)
 
-        oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis
+        # Conv2d int8 schedule has 7D kernel
+        oc_chunk, ic_chunk, oh, ow, ic_block, oc_block, _ = s[kernel_vec].op.axis
         s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
         oc_bn = cfg["tile_oc"].size[-1]
         if oc_bn > 1:
@@ -277,14 +292,29 @@ def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec, conv_out,
     s[CC].unroll(oh_inner)
 
     if C != O:
-        batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
-        oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
-        ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
-        s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
-
-        parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
-        s[C].compute_at(s[O], parallel_axis)
-        s[O].vectorize(oc_block)
-        s[O].parallel(parallel_axis)
+        out_ndim = len(s[O].op.axis)
+        if out_ndim == 5:
+            batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
+            oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
+            ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
+            s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
+
+            parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
+            s[C].compute_at(s[O], parallel_axis)
+            s[O].vectorize(oc_block)
+            s[O].parallel(parallel_axis)
+        elif out_ndim == 4:
+            batch, oc, oh, ow = s[O].op.axis
+            oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
+            oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
+            ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
+            s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
+
+            parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
+            s[C].compute_at(s[O], parallel_axis)
+            s[O].vectorize(oc_block)
+            s[O].parallel(parallel_axis)
+        else:
+            raise ValueError("Unsupported output ndim: %s" % out_ndim)
 
     return s
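
In the 4D branches above, the flat oc axis is split on the fly with the same block size the 5D NCHWc layout already uses, so vectorization still runs over an inner block of width oc_bn. A quick sketch with illustrative numbers (not from the patch):

    oc, oc_bn = 64, 16            # illustrative sizes; oc_bn comes from cfg["tile_oc"]
    oc_chunk, oc_block = oc // oc_bn, oc_bn
    # oc_chunk (= 4) is fused with batch and oh into the parallel axis, while
    # oc_block (= 16) is the vectorized inner axis, mirroring the 5D NCHWc layout.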
diff --git a/topi/tests/python/test_topi_conv2d_int8.py b/topi/tests/python/test_topi_conv2d_int8.py
index d784e5c..06f930c 100644
@@ -108,6 +108,76 @@ def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, str
         check_device(device)
 
 
+def verify_conv2d_nchw_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
+    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+    padding_sum = pad_top + pad_left + pad_bottom + pad_right
+    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+
+    in_height = in_width = in_size
+
+    A = te.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='int8')
+    W = te.placeholder((num_filter, in_channel, kernel, kernel), name='W', dtype='int8')
+    bias = te.placeholder((num_filter, 1, 1), name='bias', dtype='int8')
+
+    a_shape = get_const_tuple(A.shape)
+    w_shape = get_const_tuple(W.shape)
+    bias_shape = get_const_tuple(bias.shape)
+    dtype = A.dtype
+
+    @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
+    def get_ref_data():
+        a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
+        w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
+        b_np = np.random.uniform(size=bias_shape).astype(dtype)
+        dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+        c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
+
+        if add_bias:
+            b_np = np.random.uniform(size=bias_shape).astype(dtype)
+            c_np += b_np
+        if add_relu:
+            c_np = np.maximum(c_np, 0)
+
+        return a_np, w_np, b_np, c_np
+
+    a_np, w_np, b_np, c_np = get_ref_data()
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version):
+            print("Skip because int8 intrinsics are not available")
+            return
+
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            C = topi.cuda.conv2d_nchw_int8(A, W, (stride, stride), padding, (dilation, dilation),
+                                           dtype)
+            if add_bias:
+                C = topi.add(C, bias)
+            if add_relu:
+                C = topi.nn.relu(C)
+            s = topi.cuda.schedule_conv2d_nchw_int8([C])
+
+        a = tvm.nd.array(a_np, ctx)
+        w = tvm.nd.array(w_np, ctx)
+        b = tvm.nd.array(b_np, ctx)
+        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
+        if add_bias:
+            tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+            func(a, w, b, c)
+        else:
+            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+            func(a, w, c)
+        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
+
+    for device in ["cuda"]:
+        check_device(device)
+
+
 def test_conv2d_nchw():
     with Int8Fallback():
         # ResNet18 workloads where channels in / out are multiple of oc_block_factor
@@ -204,6 +274,17 @@ def test_conv2d_nchw():
         verify_conv2d_NCHWc_int8(1,  64,   56,  64,  3, 1, "VALID", add_bias=True, add_relu=True)
         verify_conv2d_NCHWc_int8(1,  64,   56,  64, 24, 1, "SAME", add_bias=True, add_relu=True)
 
+        # Conv2d NCHW int8 schedule testing. Internally it uses the NCHWc schedule, so we
+        # only do basic testing - one test per scenario (batch, dilation, etc.).
+        verify_conv2d_nchw_int8(1,  64,  56,  64, 3, 1, 1)
+        verify_conv2d_nchw_int8(1, 64, 56, 64, 3, 1, 1, add_relu=True)
+        verify_conv2d_nchw_int8(1, 64, 56, 64, 3, 1, 1, dilation=2)
+        verify_conv2d_nchw_int8(9, 64, 56, 64, 3, 1, 1)
+        verify_conv2d_nchw_int8(4, 4, 4, 4, 4, 4, 4)
+        verify_conv2d_nchw_int8(1,   32, 149,  32, 3, 1, 0)
+        verify_conv2d_nchw_int8(7,   32, 149,  32, 3, 1, 0)
+        verify_conv2d_nchw_int8(1,  32,   35,  64,  7, 2, (0, 0, 1, 1))
+
 
 if __name__ == "__main__":
     test_conv2d_nchw()