Add packing for int8 1x1 convolution and support the int8 group convolution on X86...
authorllyfacebook <34827865+llyfacebook@users.noreply.github.com>
Wed, 22 May 2019 16:11:17 +0000 (09:11 -0700)
committerYizhi Liu <liuyizhi@apache.org>
Wed, 22 May 2019 16:11:17 +0000 (09:11 -0700)
* Support the 1x1 int8 conv with NHWC layout and weight packing

fix linter

* fix the memoize issue

* fix the failed nhwc test

* add the schedule for pack to unbreak other tests

* skip avx512 compile

* Support the 1x1 int8 conv with NHWC layout and weight packing

fix linter

* fix the memoize issue

* fix the failed nhwc test

* add the schedule for pack to unbreak other tests

* skip avx512 compile

* Unify the data_layout and kernel_layout relation

* add asf header

* fix the comment

* retrigger the build/test

topi/python/topi/generic/nn.py
topi/python/topi/nn/conv2d.py
topi/python/topi/x86/conv2d.py
topi/python/topi/x86/conv2d_avx_1x1.py
topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py [new file with mode: 0644]
topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py [new file with mode: 0644]

index db1c772..60a2d55 100644 (file)
@@ -53,6 +53,24 @@ def schedule_conv2d_nchw(outs):
 
 
 @tvm.target.generic_func
+def schedule_conv2d_nhwc_pack(outs):
+    """Schedule for conv2d_nhwc_pack
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of conv2d_nhwc_pack
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
+@tvm.target.generic_func
 def schedule_conv2d_nhwc(outs):
     """Schedule for conv2d_nhwc
 
index 06d4074..83e0274 100644 (file)
@@ -28,8 +28,8 @@ from ..util import simplify, const_matrix, get_const_tuple
 
 # workload description of conv2d
 Workload = namedtuple('Workload',
-                      ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'out_filter',
-                       'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
+                      ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'groups',
+                       'out_filter', 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
 
 @tvm.target.generic_func
 def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=None):
@@ -95,11 +95,24 @@ def conv2d_alter_layout(attrs, inputs, tinfos, F):
     return None
 
 
-def _get_workload(data, kernel, stride, padding, out_dtype):
+def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'):
     """ Get the workload structure. """
-    _, CI, IH, IW = [x.value for x in data.shape]
-    CO, _, KH, KW = [x.value for x in kernel.shape]
+    if data_layout == 'NCHW':
+        _, CI, IH, IW = [x.value for x in data.shape]
+    elif data_layout == 'NHWC':
+        _, IH, IW, CI = [x.value for x in data.shape]
+    elif data_layout == 'HWCN':
+        IH, IW, CI, _ = [x.value for x in data.shape]
+    else:
+        raise ValueError("not support this layout {} yet".format(data_layout))
+
+    if data_layout == 'NCHW':
+        CO, CIG, KH, KW = [x.value for x in kernel.shape]
+    else:
+        KH, KW, CO, CIG = [x.value for x in kernel.shape]
+
     HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
+    GRPS = CI // CIG
     if isinstance(stride, (tuple, list)):
         HSTR, WSTR = stride
     else:
@@ -107,7 +120,7 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
     assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
         "Do not support inputs with different data types now. ' \
         '{} vs. {}".format(data.dtype, kernel.dtype)
-    return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
+    return Workload(data.dtype, out_dtype, IH, IW, CI, GRPS, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
 
 
 def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
index 02f78f8..de18abd 100644 (file)
@@ -37,7 +37,8 @@ from . import conv2d_avx_1x1, conv2d_avx_common
 
 logger = logging.getLogger('topi')
 
-def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False):
+def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False,
+                        layout='NCHW'):
     """
     Get default schedule config for the workload
     """
@@ -46,7 +47,7 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depth
         from .depthwise_conv2d import _fallback_schedule
         _fallback_schedule(cfg, wkl)
     else:
-        wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype)
+        wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout)
         is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
         if is_kernel_1x1:
             conv2d_avx_1x1._fallback_schedule(cfg, wkl)
@@ -62,6 +63,9 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
     if layout == 'NCHW':
         n, ic, h, w = dshape
         oc, _, kh, kw = kshape
+    elif layout == 'NHWC':
+        n, h, w, ic = dshape
+        kh, kw, oc, _ = kshape
     elif pat.match(layout) is not None:
         n, ic_chunk, h, w, ic_bn = dshape
         if data.dtype == 'uint8':
@@ -93,21 +97,31 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
         cfg.define_knob("unroll_kw", [True, False])
 
 
-@autotvm.register_topi_compute(conv2d, 'cpu', 'direct')
+@autotvm.register_topi_compute(conv2d, 'cpu', ['direct'])
 def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
     out_dtype = data.dtype if out_dtype is None else out_dtype
     padding = padding if isinstance(padding, (tuple, list)) else (padding, padding)
     strides = strides if isinstance(strides, (tuple, list)) else (strides, strides)
     dilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
+
     if layout == 'NCHW':
         _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout)
         if cfg.is_fallback:
             _get_default_config(cfg, data, kernel, strides, padding, out_dtype)
         return _declaration_conv_impl(cfg, data, kernel, strides,
                                       padding, dilation, layout, out_dtype)
+
+    # HWOI kernel layout is for NHWC and HWCN
+    kh, kw, _, _ = get_const_tuple(kernel.shape)
     if layout == 'HWCN':
         return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype)
-    if layout == 'NHWC':
+    elif layout == 'NHWC' and kh == 1 and kw == 1 and kernel.dtype == "int8":
+        if cfg.is_fallback:
+            _get_default_config(cfg, data, kernel, strides, padding, out_dtype, False, layout)
+        # specialize for INT8 1X1 conv on X86
+        return conv2d_avx_1x1._declaration_conv_nhwc_pack(cfg, data, kernel, strides,
+                                                          padding, dilation, out_dtype)
+    elif layout == 'NHWC':
         return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
     raise ValueError("not support this layout {} yet".format(layout))
 
@@ -226,6 +240,58 @@ def schedule_conv2d(cfg, outs):
     return s
 
 
+@autotvm.register_topi_schedule(generic.schedule_conv2d_nhwc_pack, 'cpu', ['direct'])
+def schedule_conv2d_nhwc_pack(cfg, outs):
+    """Create schedule for tensors"""
+    s = tvm.create_schedule([x.op for x in outs])
+    output_op = outs[0].op
+    scheduled_ops = []
+
+    def traverse(op):
+        """Traverse operators from computation graph"""
+        # inline all one-to-one-mapping operators except the last stage (output)
+        if tag.is_broadcast(op.tag):
+            if op not in s.outputs:
+                s[op].compute_inline()
+            else: # inject custom schedule
+                if len(op.axis) == 4: # schedule bias + bn + relu
+                    n, h, w, c = op.axis
+                    fused = s[op].fuse(n, h, w)
+                    s[op].parallel(fused)
+                    s[op].vectorize(c)
+            for tensor in op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                    traverse(tensor.op)
+
+        if 'conv2d_nhwc_pack_int8' in op.tag:
+            conv_out = op.output(0)
+            kernel = conv_out.op.input_tensors[1]
+            data_vec = conv_out.op.input_tensors[0]
+            data = data_vec.op.input_tensors[0] \
+                if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \
+                else data_vec
+            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
+                data_pad = data
+                data = data_pad.op.input_tensors[0]
+
+            args = [s, cfg, data_vec, conv_out, outs[0]]
+            if data.dtype == 'uint8':
+                # int8 conv kernel is 7-dim
+                kh, kw, _, _, _ = get_const_tuple(kernel.shape)
+                if kh == 1 and kw == 1:
+                    conv2d_avx_1x1._schedule_conv_nhwc_pack_int8(*args)
+                else:
+                    raise ValueError("Only support 1x1 kernel with "
+                                     "schedule_conv2d_nhwc_pack.")
+            else:
+                raise ValueError("Not support this data type {} with "
+                                 "schedule_conv2d_nhwc_pack. Only support int8".format(data.dtype))
+
+        scheduled_ops.append(op)
+    traverse(output_op)
+    return s
+
+
 @generic.schedule_conv2d_nhwc.register("cpu")
 def schedule_conv2d_nhwc(outs):
     """Create schedule for tensors"""
@@ -422,10 +488,13 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
     n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
     in_channel = ic_chunk * ic_bn
     if data.dtype == 'uint8':
-        oc_chunk, _, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(kernel.shape)
+        oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = \
+            get_const_tuple(kernel.shape)
     else:
-        oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape)
+        oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = \
+            get_const_tuple(kernel.shape)
     num_filter = oc_chunk * oc_bn
+    groups = ic_chunk // ic_chunk_group
 
     if cfg.is_fallback:
         _get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
@@ -449,7 +518,7 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
     kh = tvm.reduce_axis((0, kernel_height), name='kh')
     kw = tvm.reduce_axis((0, kernel_width), name='kw')
 
-    if data.dtype == 'uint8':
+    if data.dtype == 'uint8' and groups == 1:
         assert out_dtype == "int32", \
             "INT8 convolution requires input dtype = uint8 and output dtype=int32"
         # Intel performs dot product of 2 "4" Int8 values
@@ -468,6 +537,24 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
                                           oc_block, ic_s_inner].astype(out_dtype),
                                    axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
                            name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
+    if data.dtype == 'uint8':
+       # for int8 group conv support
+        n_elems = 4
+        ic_chunk = in_channel//ic_bn
+        ic_outer = tvm.reduce_axis((0, ic_chunk//groups), name='ic_outer')
+        ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner')
+        ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
+        oshape = (n, oc_chunk, out_height, out_width, oc_bn)
+        return tvm.compute(oshape, lambda n, occ, oh, ow, oc_block:
+                           tvm.sum(data_pad[n, (occ*oc_bn//(oc_chunk*oc_bn//groups))*\
+                                            (ic_chunk//groups)+ic_outer,
+                                            oh*HSTR+kh, ow*WSTR+kw,
+                                            ic_f_inner * n_elems +  ic_s_inner].astype(out_dtype) *
+                                   kernel[occ, ic_outer, kh, kw, ic_f_inner,
+                                          oc_block, ic_s_inner].astype(out_dtype),
+                                   axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
+                           name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
+
     # else: fp implementation
     return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
                        tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh, ow*WSTR+kw,
index bcd2cef..4994d45 100644 (file)
@@ -20,8 +20,9 @@ from __future__ import absolute_import as _abs
 import tvm
 from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
 
-from ..nn.util import infer_pad
-from ..util import get_const_tuple
+from ..nn.pad import pad
+from ..nn.util import infer_pad, get_pad_tuple
+from ..util import get_const_tuple, simplify
 from .tensor_intrin import dot_16x1x16_int8_int8_int32
 from .check_targets import check_skylake
 from .util import get_fp32_len
@@ -251,3 +252,104 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
         s[O].parallel(parallel_axis)
 
     return s
+
+
+def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, out_dtype):
+    # more assertion for the shapes
+    assert isinstance(stride, int) or len(stride) == 2
+    assert isinstance(dilation, int) or len(dilation) == 2
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    batch, in_height, in_width, in_channel = Input.shape
+    kernel_h, kernel_w, num_filter, channel = Filter.shape
+
+    # compute the output shape
+    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    out_channel = num_filter
+    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
+    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
+    pad_before = [0, pad_top, pad_left, 0]
+    pad_after = [0, pad_down, pad_right, 0]
+    PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
+    # todo: padding filter to accomodate the intrinsic
+
+    # packing the Filter to let memory access be consecutive for AVX512 intrinsic
+    # Done in pre-compute stage
+    packw_shape = (kernel_h, kernel_w, num_filter/16, 16*(channel/4), 4)
+    PackW = tvm.compute(packw_shape, lambda a, b, c, d, e: Filter[a][b][c*16+d%16][d/16*4+e],
+                        name="packed_filter")
+
+    rc = tvm.reduce_axis((0, in_channel), name='rc')
+    ry = tvm.reduce_axis((0, kernel_h), name='ry')
+    rx = tvm.reduce_axis((0, kernel_w), name='rx')
+    Output = tvm.compute(
+        (batch, out_height, out_width, out_channel),
+        lambda nn, yy, xx, ff: tvm.sum(
+            PaddedInput[nn, yy * stride_h + ry * dilation_h,
+                        xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
+            PackW[ry, rx, ff/16, (rc/4)*16+ff%16, rc%4].astype(out_dtype), axis=[ry, rx, rc]),
+        name="Conv2d_1x1_Output_int8", tag="conv2d_nhwc_pack_int8")
+    return Output
+
+
+def _schedule_conv_nhwc_pack_int8(s, cfg, data, conv_out, last):
+    """
+    Defines the schedule for the int8 nhwc layout. For 1x1 conv, it
+    is a matrix-multiply operation by using nhwc layout. We will do
+    packing of weight to make the address access be friendly to int8
+    intrinsic
+    """
+    target = tvm.target.current_target(allow_none=False)
+    int32_lanes = -1
+    if check_skylake(target):
+        int32_lanes = 16
+    else:
+        return s
+    assert int32_lanes != -1
+
+    # assertion to fail the unhandled case
+    _, _, _, ic_num = get_const_tuple(data.shape)
+    _, _, _, oc_num = get_const_tuple(conv_out.shape)
+    assert ic_num % 4 == 0
+    assert oc_num % 16 == 0
+
+    ic_factor, oc_factor = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
+    # schedule data
+    A = data
+    if isinstance(s[A].op, tvm.tensor.ComputeOp):
+        batch, ih, iw, ic = s[A].op.axis
+        d_ic_chunk, d_ic_block = s[A].split(ic, factor=4)
+        s[A].vectorize(d_ic_block)
+
+    C, O = conv_out, last
+
+    batch, oh, ow, oc = s[C].op.axis
+    kh, kw, ic = s[C].op.reduce_axis
+    # match the x86 intrinsic
+    ic_outer, ic_inner = s[C].split(ic, factor=4)
+    oc_outer, oc_inner = s[C].split(oc, factor=int32_lanes)
+
+    ic_f_outer, ic_s_outer = s[C].split(ic_outer, factor=ic_factor)
+    s[C].reorder(oc_outer, oh, ow, ic_f_outer, ic_s_outer, kh, kw, oc_inner, ic_inner)
+
+    pc = dot_16x1x16_int8_int8_int32()
+    s[C].tensorize(oc_inner, pc)
+
+    if C != O:
+        batch, last_oh, last_ow, last_oc = s[O].op.axis
+        oc_chunk, oc_block = s[O].split(ochannel, 16)
+        # not saw perf improvement to split oh/ow here
+        s[O].vectorize(oc_block)
+
+    return s
diff --git a/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py b/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py
new file mode 100644 (file)
index 0000000..763150a
--- /dev/null
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Example code to do convolution."""
+import os
+import numpy as np
+import tvm
+from tvm import autotvm
+from tvm.autotvm.task.space import FallbackConfigEntity
+import topi
+import topi.testing
+from tvm.contrib.pickle_memoize import memoize
+from topi.util import get_const_tuple
+
+
+def verify_conv2d_1x1_nhwc_pack_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
+    in_height = in_width = in_size
+
+    A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='uint8')
+    W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype='int8')
+
+    a_shape = get_const_tuple(A.shape)
+    w_shape = get_const_tuple(W.shape)
+    adtype = A.dtype
+    wdtype = W.dtype
+
+    @memoize("topi.tests.test_topi_conv2d_1x1_nhwc_pack_int8.verify_nhwc.v2")
+    def get_ref_data():
+        a_np = np.random.uniform(size=a_shape).astype(adtype)
+        w_np = np.random.uniform(size=w_shape).astype(wdtype)
+        dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
+        b_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding)
+        return a_np, w_np, b_np
+
+    a_np, w_np, b_np = get_ref_data()
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+
+        with tvm.target.create(device):
+            B = topi.nn.conv2d(A, W, stride, padding, dilation, layout='NHWC', out_dtype="int32")
+            s = topi.generic.schedule_conv2d_nhwc_pack([B])
+        a = tvm.nd.array(a_np, ctx)
+        w = tvm.nd.array(w_np, ctx)
+        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
+        func = tvm.build(s, [A, W, B], device)
+        func(a, w, b)
+        tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
+
+    # for device in ['llvm -mcpu=skylake-avx512']:
+    for device in ['llvm']:
+        check_device(device)
+
+
+class DefaultFallback(autotvm.FallbackContext):
+    def _query_inside(self, target, workload):
+        key = (target, workload)
+        if key in self.memory:
+            return self.memory[key]
+        cfg = FallbackConfigEntity()
+        cfg.template_key = 'direct'
+        self.memory[key] = cfg
+        return cfg
+
+
+def test_conv2d_nhwc():
+    autotvm.DispatchContext.current.silent = True
+    with DefaultFallback():
+        verify_conv2d_1x1_nhwc_pack_int8(1, 256, 32, 256, 1, 1, 0)
+
+
+if __name__ == "__main__":
+    test_conv2d_nhwc()
diff --git a/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py b/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py
new file mode 100644 (file)
index 0000000..6ed1b4a
--- /dev/null
@@ -0,0 +1,111 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Test for NCHW[x]c convolution"""
+
+import numpy as np
+import tvm
+from tvm import autotvm
+import topi
+import topi.testing
+from tvm.contrib.pickle_memoize import memoize
+from topi.util import get_const_tuple
+
+from common import get_all_backend
+
+def _transform_data(data, bn):
+    # NCHW -> NCHW[x]c
+    batch_size, channel, height, width = data.shape
+    data = np.reshape(data, (batch_size, channel//bn, bn, height, width))
+    data = np.transpose(data, (0, 1, 3, 4, 2))
+    return data
+
+def _transform_kernel(kernel, ic_bn, oc_bn):
+    # OIHW -> OIHW[x]i[x]o
+    out_channel, in_channel, kh, kw = kernel.shape
+    kernel = np.reshape(kernel, (out_channel//oc_bn, oc_bn, in_channel//ic_bn, ic_bn//4, kh, kw, 4))
+    kernel = np.transpose(kernel, (0, 2, 4, 5, 3, 1, 6))
+    return kernel
+
+def verify_group_conv2d_NCHWc_int8(batch, in_channel, groups, in_size, num_filter, kernel, stride,
+                        padding, dilation=1, add_bias=False, add_relu=False, dtype="int32"):
+    assert dilation == 1, "conv2d_NCHWc does not support dilation for now."
+    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" %
+          (batch, in_channel, groups, in_size, num_filter, kernel, stride, padding))
+
+    in_height = in_width = in_size
+
+    # for testing functionality,
+    # we choose arbitrary block size that can divide the channel,
+    # regardless of the performance.
+    oc_block = 1
+    for bn in range(16, 0, -1):
+        if num_filter % bn == 0:
+            oc_block = bn
+            break
+
+    ic_block = 8
+    autotvm.DispatchContext.current.silent = True
+    A = tvm.placeholder((batch, in_channel//ic_block, in_height, in_width, ic_block), name='A', dtype='uint8')
+    W = tvm.placeholder((num_filter//oc_block, in_channel//ic_block//groups, kernel, kernel, ic_block//4, oc_block, 4), name='W', dtype='int8')
+
+    @memoize("topi.tests.test_topi_conv2d_NCHWc_int8.verify_conv2d_NCHWc_int8")
+    def get_ref_data():
+        a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype("uint8")
+        w_np = np.random.uniform(size=(num_filter, in_channel//groups, kernel, kernel)).astype("int8")
+        c_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding, groups)
+        return _transform_data(a_np, ic_block), _transform_kernel(w_np, ic_block, oc_block), \
+               _transform_data(c_np, oc_block)
+
+    a_np, w_np, c_np = get_ref_data()
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            C = topi.nn.conv2d_NCHWc(A, W, (stride, stride), (padding, padding),
+                                     (dilation, dilation),
+                                     layout='NCHW%dc'%ic_block,
+                                     out_layout="NCHW%dc"%oc_block,
+                                     out_dtype=dtype)
+            s = topi.generic.schedule_conv2d_NCHWc([C])
+
+        a = tvm.nd.array(a_np, ctx)
+        w = tvm.nd.array(w_np, ctx)
+        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
+        func = tvm.build(s, [A, W, C], device,
+                         name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
+                              (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
+        # print(tvm.lower(s, [A, W, C], simple_mode=True))
+        func(a, w, c)
+        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-3)
+
+    # for device in ["llvm -mcpu=skylake-avx512"]:
+    for device in ["llvm"]:
+        with autotvm.tophub.context(device):  # load tophub pre-tuned parameters
+            check_device(device)
+
+
+def test_conv2d_NCHWc():
+    # ResNet50 workloads
+    verify_group_conv2d_NCHWc_int8(1, 256, 32, 224, 64, 7, 2, 3)
+
+if __name__ == "__main__":
+    test_conv2d_NCHWc()