[topi] add ARM v8.2 udot (uint8) support (#3978)
authorYizhi Liu <liuyizhi@apache.org>
Tue, 1 Oct 2019 15:40:16 +0000 (23:40 +0800)
committerGitHub <noreply@github.com>
Tue, 1 Oct 2019 15:40:16 +0000 (23:40 +0800)
* [topi] add ARM v8.2 udot (uint8) support

* fix test case

* fix common conv2d schedule

* add back fp32_time in test

* fix lint

* fix doc, add support for int32_lanes=4, signed int

* fix lint

* add ic_bn % 4 checker in schedule

topi/python/topi/arm_cpu/__init__.py
topi/python/topi/arm_cpu/conv2d_int8.py [new file with mode: 0644]
topi/python/topi/arm_cpu/tensor_intrin.py [new file with mode: 0644]
topi/python/topi/generic/conv2d.py [new file with mode: 0644]
topi/python/topi/nn/conv2d.py
topi/python/topi/x86/conv2d_avx_1x1.py
topi/python/topi/x86/conv2d_avx_common.py
topi/python/topi/x86/conv2d_int8.py
topi/recipe/conv/test_conv_int8_arm.py [new file with mode: 0644]

index 6cf4d91..32751bf 100644 (file)
@@ -3,6 +3,7 @@
 from . import conv2d
 from . import depthwise_conv2d
 from . import conv2d_transpose
+from . import conv2d_int8
 from . import bitserial_conv2d
 from . import bitserial_dense
 from . import injective
diff --git a/topi/python/topi/arm_cpu/conv2d_int8.py b/topi/python/topi/arm_cpu/conv2d_int8.py
new file mode 100644 (file)
index 0000000..8f43f5c
--- /dev/null
@@ -0,0 +1,112 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
+"""Conv2D int8 schedule on ARM"""
+
+import tvm
+from tvm import autotvm
+from .. import generic, tag
+from ..util import get_const_tuple
+from ..nn.conv2d import conv2d_NCHWc_int8
+from ..generic import conv2d as conv2d_generic
+from .. import nn
+from ..nn.conv2d import _get_workload as _get_conv2d_workload
+from .tensor_intrin import dot_int8_int8_int32
+
+
+def _get_default_config(cfg, data, kernel, strides, padding, out_dtype):
+    """
+    Get default int8 schedule config for the workload
+    """
+    wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype)
+    is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
+    if is_kernel_1x1:
+        conv2d_generic.fallback_schedule_cpu_1x1_int8(
+            cfg, wkl, int32_lanes=2, num_int8_elements=4)
+    else:
+        conv2d_generic.fallback_schedule_cpu_common_int8(
+            cfg, wkl, int32_lanes=2, num_int8_elements=4)
+
+
+@autotvm.register_topi_compute(conv2d_NCHWc_int8, ['arm_cpu'], 'direct')
+def _declaration_conv_NCHWc_int8(cfg, data, kernel, strides,
+                                 padding, dilation, layout, out_layout, out_dtype):
+    # layout and out_layout are not used here,
+    # we keep them for debug convenience when dumping autotvm workload
+    n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
+    in_channel = ic_chunk * ic_bn
+
+    oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn, n_elems = get_const_tuple(kernel.shape)
+    num_filter = oc_chunk * oc_bn
+
+    # If no config was set, we can fallback to NCHW config.
+    if cfg.is_fallback:
+        _get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
+                            tvm.placeholder((num_filter, in_channel, kh, kw), dtype=kernel.dtype),
+                            strides, padding, out_dtype)
+    return nn.conv2d_NCHWc_int8_compute(data,
+                                        kernel,
+                                        strides,
+                                        padding,
+                                        dilation,
+                                        layout,
+                                        out_layout,
+                                        out_dtype)
+
+
+@autotvm.register_topi_schedule(generic.schedule_conv2d_NCHWc_int8, ['arm_cpu'], ['direct'])
+def _schedule_conv2d_NCHWc_int8(cfg, outs):
+    """Create schedule for tensors"""
+    s = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []
+
+    def traverse(op):
+        """Traverse operators from computation graph"""
+        # inline all one-to-one-mapping operators except the last stage (output)
+        if tag.is_broadcast(op.tag):
+            if op not in s.outputs:
+                s[op].compute_inline()
+            for tensor in op.input_tensors:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
+                    traverse(tensor.op)
+
+        if 'conv2d_NCHWc_int8' in op.tag:
+            conv_out = op.output(0)
+            kernel = conv_out.op.input_tensors[1]
+            data_vec = conv_out.op.input_tensors[0]
+            data = data_vec.op.input_tensors[0] \
+                if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \
+                else data_vec
+            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
+                data_pad = data
+                data = data_pad.op.input_tensors[0]
+
+            args = [s, cfg, data_vec, conv_out, outs[0]]
+            # int8 conv kernel is 7-dim
+            _, _, kh, kw, _, _, _ = get_const_tuple(kernel.shape)
+            dtype = "uint" if data.dtype == "uint8" else "int"
+            if kh == 1 and kw == 1:
+                conv2d_generic.schedule_conv_NCHWc_cpu_1x1_int8(
+                    *args, int32_lanes=4, intrin=dot_int8_int8_int32(int32_lanes=4, dtype=dtype))
+            else:
+                conv2d_generic.schedule_conv_NCHWc_cpu_common_int8(
+                    *args, int32_lanes=4, intrin=dot_int8_int8_int32(int32_lanes=4, dtype=dtype))
+
+        scheduled_ops.append(op)
+
+    traverse(outs[0].op)
+    return s
diff --git a/topi/python/topi/arm_cpu/tensor_intrin.py b/topi/python/topi/arm_cpu/tensor_intrin.py
new file mode 100644 (file)
index 0000000..2f300a1
--- /dev/null
@@ -0,0 +1,110 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
+"""Conv2D int8 schedule on ARM"""
+
+import tvm
+
+def dot_int8_int8_int32(int32_lanes, dtype='uint'):
+    """
+    Int8 dot product by every 4 elements using ARM v8.2 udot.
+    This function takes two arrays of int8 datatype -- data[4] and
+    kernel[int32_lanes][4] -- and computes a dot product of data[4] with every
+    4 elements of kernels, resulting in output[int32_lanes] of uint32 datatype.
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void dot_int8_int8_int32(int8 data[4], int8 kernel[16][4], int32 output[16]){
+            for (int i = 0; i < int32_lanes; i++){
+                out[i] = 0;
+                for (int k = 0; k < 4; k++){
+                    out[i] += data[k] * kernel[i][k]
+                }
+            }
+        }
+
+    Physically, the kernel array sits in a vector register and
+    the data[4] is broadcasted to another vector register. This
+    function returns a TensorIntrin that can be used to tensorize
+    a schedule.
+
+    Parameters
+    ----------
+    int32_lanes: int
+        How many int32/uint32 to produce
+    dtype: str, optional, {"uint", "int"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The ARM uint8 TensorIntrin that can be used in tensorizing schedule
+    """
+    num_int8_elements = 4  # 4 int8 elements in int32
+
+    data = tvm.placeholder((num_int8_elements,), dtype='%s8' % dtype, name='data')
+    kernel = tvm.placeholder((int32_lanes, num_int8_elements), dtype='%s8' % dtype, name='kernel')
+
+    k = tvm.reduce_axis((0, num_int8_elements), name='k')
+    C = tvm.compute((int32_lanes,),
+                    lambda i: tvm.sum(data[k].astype('%s32' % dtype) *
+                                      kernel[i, k].astype('%s32' % dtype),
+                                      axis=k), name="C")
+
+    a_buffer = tvm.decl_buffer(data.shape, dtype='%s8' % dtype, name="a_buffer",
+                               offset_factor=1,
+                               strides=[1])
+    b_buffer = tvm.decl_buffer(kernel.shape, dtype='%s8' % dtype, name="b_buffer",
+                               offset_factor=1,
+                               strides=[tvm.var('s'), 1])
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.ir_builder.create()
+            if index == 1:
+                ib.emit(outs[0].vstore(0, tvm.const(0, '%s32x%d' % (dtype, int32_lanes))))
+                return ib.get()
+
+            dtype_a = '%s8x%d' % (dtype, num_int8_elements)
+            dtype_b = '%s8x%d' % (dtype, int32_lanes * num_int8_elements)
+            dtype_c = '%s32x%d' % (dtype, int32_lanes)
+
+            a_int8 = ins[0].vload([0], dtype_a)
+            re_int32 = tvm.call_pure_intrin('%s32' % dtype, 'reinterpret', a_int8)
+            # broadcast a
+            vec_ai32 = re_int32.astype(dtype_c)
+
+            vec_a = tvm.call_pure_intrin(dtype_b, 'reinterpret', vec_ai32)
+            vec_b = ins[1].vload([0, 0], dtype_b)
+            vec_c = outs[0].vload([0], dtype_c)
+
+            inst = 'udot' if dtype == 'uint' else 'sdot'
+            inst = 'llvm.aarch64.neon.%s.v%di32.v%di8' % (
+                inst, int32_lanes, int32_lanes * num_int8_elements)
+            vdot = tvm.call_llvm_intrin(dtype_c,
+                                        inst,
+                                        tvm.const(2, 'uint32'),
+                                        vec_c, vec_a, vec_b)
+            ib.emit(outs[0].vstore(0, vdot))
+            return ib.get()
+
+        # body, reset, update
+        return _instr(0), _instr(1), _instr(2)
+
+    with tvm.build_config(offset_factor=1, partition_const_loop=True):
+        return tvm.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer})
diff --git a/topi/python/topi/generic/conv2d.py b/topi/python/topi/generic/conv2d.py
new file mode 100644 (file)
index 0000000..332c2fd
--- /dev/null
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, too-many-locals
+# pylint: disable=unused-argument, redefined-builtin
+"""Generic convolution schedules"""
+from __future__ import absolute_import as _abs
+import tvm
+from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
+from ..util import get_const_tuple
+
+def fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements):
+    """Fallback schedule for conv2d int8 on cpu.
+    Normally the inner most pattern takes two int8/uint8 tensors
+    data[num_int8_elements] and kernel[int32_lanes, num_int8_elements],
+    produces a dot product int32/uint32 output[int32_lanes].
+
+    Parameters
+    ----------
+    int32_lanes : int
+        How many numbers of int32/uint32 will be produced using intrinsic.
+        This is related to output channel.
+    num_int8_elements : int
+        How many numbers of input int32/uint32 will be multiplied and reduced.
+        This is related to input channel.
+    """
+    HPAD, WPAD = wkl.hpad, wkl.wpad
+    HSTR, WSTR = wkl.hstride, wkl.wstride
+    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
+
+    assert wkl.out_filter % int32_lanes == 0, \
+        "wkl.out_filter=%d, int32_lanes=%d" % (wkl.out_filter, int32_lanes)
+    assert wkl.in_filter % num_int8_elements == 0, \
+        "wkl.in_filter=%d, num_int8_elements=%d" % (wkl.in_filter, num_int8_elements)
+
+    oc_bn = int32_lanes
+    ic_bn = 1
+    for bn in range(oc_bn, 0, -4):
+        if wkl.in_filter % bn == 0:
+            ic_bn = bn
+            break
+
+    reg_n = 1
+    for n in range(31, 0, -1):
+        if out_width % n == 0:
+            reg_n = n
+            break
+
+    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
+    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
+    cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
+    cfg["unroll_kw"] = OtherOptionEntity(False)
+
+
+def fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes, num_int8_elements):
+    """Fallback schedule for 1x1 conv2d int8 on cpu.
+    Normally the inner most pattern takes two int8/uint8 tensors
+    data[num_int8_elements] and kernel[int32_lanes, num_int8_elements],
+    produces a dot product int32/uint32 output[int32_lanes].
+
+    Parameters
+    ----------
+    int32_lanes : int
+        How many numbers of int32/uint32 will be produced using intrinsic.
+        This is related to output channel.
+    num_int8_elements : int
+        How many numbers of input int32/uint32 will be multiplied and reduced.
+        This is related to input channel.
+    """
+    HPAD, WPAD = wkl.hpad, wkl.wpad
+    HSTR, WSTR = wkl.hstride, wkl.wstride
+    out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
+    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
+
+    assert wkl.out_filter % int32_lanes == 0, \
+        "wkl.out_filter=%d, int32_lanes=%d" % (wkl.out_filter, int32_lanes)
+    assert wkl.in_filter % num_int8_elements == 0, \
+        "wkl.in_filter=%d, num_int8_elements=%d" % (wkl.in_filter, num_int8_elements)
+
+    oc_bn = int32_lanes
+    ic_bn = 1
+    for bn in range(oc_bn, 0, -4):
+        if wkl.in_filter % bn == 0:
+            ic_bn = bn
+            break
+
+    for ow_factor in range(out_width, 0, -1):
+        if out_width % ow_factor == 0:
+            for oh_factor in range(out_height, 0, -1):
+                if out_height % oh_factor == 0 and ow_factor * oh_factor < 32:
+                    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
+                    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
+                    cfg["tile_oh"] = OtherOptionEntity(oh_factor)
+                    cfg["tile_ow"] = SplitEntity([out_width // ow_factor, ow_factor])
+                    return
+    raise ValueError("cannot decide default schedule for workload: {}".format(wkl))
+
+
+def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data, conv_out, last, int32_lanes=16, intrin=None):
+    """
+    Defines the schedule for INT8 for Intel and ARM machines
+    Uses the Intel/ARM intrinsics to use INT8 operations
+    More details - https://software.intel.com/en-us/articles/
+    lower-numerical-precision-deep-learning-inference-and-training
+    """
+    reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val
+    _, _, _, _, ic_bn = get_const_tuple(data.shape)
+    _, _, _, _, oc_bn = get_const_tuple(conv_out.shape)
+
+    A = data
+    if isinstance(s[A].op, tvm.tensor.ComputeOp):
+        batch, ic_chunk, ih, iw, _ = s[A].op.axis
+        parallel_axis = s[A].fuse(batch, ic_chunk, ih)
+        s[A].parallel(parallel_axis)
+
+    # schedule 5-D NCHW[x]c conv
+    C, O = conv_out, last
+    CC = s.cache_write(C, 'global')
+
+    batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
+    ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
+    s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+    parallel_axis = s[C].fuse(batch, oc_chunk, oh)
+    s[C].vectorize(oc_block)
+    if C == O:
+        s[C].parallel(parallel_axis)
+
+    s[CC].compute_at(s[C], ow_chunk)
+    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
+    kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
+
+    ow_chunk, ow_block = s[CC].split(ow, factor=reg_n)
+
+    assert oc_bn % int32_lanes == 0
+    assert ic_bn % 4 == 0  # 4 (u)int8 elements in (u)int32
+
+    oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
+
+    if unroll_kw:
+        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, ic_f_inner, kw,
+                      ow_block, oc_f_inner, oc_s_inner, ic_s_inner)
+        s[CC].unroll(kw)
+    else:
+        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, kw, ic_f_inner,
+                      ow_block, oc_f_inner, oc_s_inner, ic_s_inner)
+
+    if intrin is not None:
+        s[CC].tensorize(oc_s_inner, intrin)
+    s[CC].unroll(ow_block)
+    s[CC].unroll(oc_f_inner)
+
+    if C != O:
+        batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
+        ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
+        s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+        parallel_axis = s[O].fuse(batch, oc_chunk, oh)
+        s[C].compute_at(s[O], parallel_axis)
+        s[O].vectorize(oc_block)
+        s[O].parallel(parallel_axis)
+
+    return s
+
+def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data, conv_out, last, int32_lanes=16, intrin=None):
+    """
+    Defines the 1x1 conv schedule for INT8 for Intel and ARM machines
+    Uses the Intel/ARM intrinsics to use INT8 operations
+    More details - https://software.intel.com/en-us/articles/
+    lower-numerical-precision-deep-learning-inference-and-training
+    """
+    oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1]
+    _, _, _, _, ic_bn = get_const_tuple(data.shape)
+    _, _, _, _, oc_bn = get_const_tuple(conv_out.shape)
+
+    # schedule data
+    A = data
+    if isinstance(s[A].op, tvm.tensor.ComputeOp):
+        batch, ic_chunk, ih, iw, ic_block = s[A].op.axis
+        parallel_axis = s[A].fuse(batch, ic_chunk, ih)
+        s[A].parallel(parallel_axis)
+
+    C, O = conv_out, last
+    CC = s.cache_write(C, 'global')
+
+    batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
+    oh_outer, oh_inner = s[C].split(oh, factor=oh_factor)
+    ow_outer, ow_inner = s[C].split(ow, factor=ow_factor)
+    s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
+    s[C].vectorize(oc_block)
+
+    parallel_axis = s[C].fuse(batch, oc_chunk, oh_outer)
+    s[CC].compute_at(s[C], parallel_axis)
+    if C == O:
+        s[C].parallel(parallel_axis)
+
+    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
+    kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
+
+    assert oc_bn % int32_lanes == 0
+    assert ic_bn % 4 == 0  # 4 (u)int8 elements in (u)int32
+
+    oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
+
+    oh_outer, oh_inner = s[CC].split(oh, factor=oh_factor)
+    ow_outer, ow_inner = s[CC].split(ow, factor=ow_factor)
+
+    s[CC].reorder(oc_chunk, oh_outer, ow_outer, kh, kw, ic_outer, ic_f_inner, oh_inner,
+                  ow_inner, oc_f_inner, oc_s_inner, ic_s_inner)
+    s[CC].fuse(oc_chunk, oh_outer)
+
+    if intrin is not None:
+        s[CC].tensorize(oc_s_inner, intrin)
+    s[CC].unroll(ow_inner)
+    s[CC].unroll(oh_inner)
+
+    if C != O:
+        batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
+        oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
+        ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
+        s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
+
+        parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
+        s[C].compute_at(s[O], parallel_axis)
+        s[O].vectorize(oc_block)
+        s[O].parallel(parallel_axis)
+
+    return s
index 904dd54..ffae4b2 100644 (file)
@@ -595,19 +595,11 @@ def conv2d_NCHWc_int8_compute(data, kernel, strides, padding, dilation, layout,
 
     n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
     in_channel = ic_chunk * ic_bn
-    target = tvm.target.current_target(allow_none=False)
     oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = \
         get_const_tuple(kernel.shape)
     num_filter = oc_chunk * oc_bn
     groups = ic_chunk // ic_chunk_group
 
-    # Since the weight is 7-D and the last element size is 4, we have to
-    # check ic_bn should be a multiple of 4.
-    # Similary, oc_bn has to be a multiple of 4.
-
-    assert ic_bn % 4 == 0
-    assert oc_bn % 16 == 0
-
     dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
     dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
 
index 6e36e93..96b6e47 100644 (file)
@@ -22,6 +22,7 @@ from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
 
 from ..nn.pad import pad
 from ..nn.util import infer_pad, get_pad_tuple
+from ..generic import conv2d as conv2d_generic
 from ..util import get_const_tuple, simplify
 from .tensor_intrin import dot_16x1x16_int8_int8_int32
 from .util import get_fp32_len
@@ -57,36 +58,6 @@ def _fallback_schedule(cfg, wkl):
     raise ValueError("cannot decide default schedule for workload: {}".format(wkl))
 
 
-def _fallback_schedule_int8(cfg, wkl):
-    simd_width = get_fp32_len()
-    HPAD, WPAD = wkl.hpad, wkl.wpad
-    HSTR, WSTR = wkl.hstride, wkl.wstride
-    out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
-    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
-
-    oc_bn = 16
-    assert wkl.out_filter % oc_bn == 0
-
-    ic_bn = 1
-    for bn in range(oc_bn, 0, -4):
-        if wkl.in_filter % bn == 0:
-            ic_bn = bn
-            break
-    assert wkl.in_filter % 4 == 0
-
-    for ow_factor in range(out_width, 0, -1):
-        if out_width % ow_factor == 0:
-            for oh_factor in range(out_height, 0, -1):
-                if out_height % oh_factor == 0 and ow_factor * oh_factor < 32:
-                    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
-                    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
-                    cfg["tile_oh"] = OtherOptionEntity(oh_factor)
-                    cfg["tile_ow"] = SplitEntity([out_width // ow_factor, ow_factor])
-                    return
-    raise ValueError("cannot decide default schedule for workload: {}".format(wkl))
-
-
-
 def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last):
     # fetch schedule
     ic_bn, oc_bn, oh_factor, ow_factor = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
@@ -210,71 +181,9 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
 
 
 def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
-    """
-    Defines the schedule for INT8 for intel machines
-    Uses the Intel intrinsics to use INT8 operations
-    More details - https://software.intel.com/en-us/articles/
-    lower-numerical-precision-deep-learning-inference-and-training
-    """
-    int32_lanes = 16
-
-    oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1]
-    _, _, _, _, ic_bn = get_const_tuple(data.shape)
-    _, _, _, _, oc_bn = get_const_tuple(conv_out.shape)
-
-    # schedule data
-    A = data
-    if isinstance(s[A].op, tvm.tensor.ComputeOp):
-        batch, ic_chunk, ih, iw, ic_block = s[A].op.axis
-        parallel_axis = s[A].fuse(batch, ic_chunk, ih)
-        s[A].parallel(parallel_axis)
-
-    C, O = conv_out, last
-    CC = s.cache_write(C, 'global')
-
-    batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
-    oh_outer, oh_inner = s[C].split(oh, factor=oh_factor)
-    ow_outer, ow_inner = s[C].split(ow, factor=ow_factor)
-    s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
-    s[C].vectorize(oc_block)
-
-    parallel_axis = s[C].fuse(batch, oc_chunk, oh_outer)
-    s[CC].compute_at(s[C], parallel_axis)
-    if C == O:
-        s[C].parallel(parallel_axis)
-
-    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
-    kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
-
-    # Skylake and future processors have 16 vector lanes
-    assert oc_bn % int32_lanes == 0
-
-    oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
-
-    oh_outer, oh_inner = s[CC].split(oh, factor=oh_factor)
-    ow_outer, ow_inner = s[CC].split(ow, factor=ow_factor)
-
-    s[CC].reorder(oc_chunk, oh_outer, ow_outer, kh, kw, ic_outer, ic_f_inner, oh_inner,
-                  ow_inner, oc_f_inner, oc_s_inner, ic_s_inner)
-    s[CC].fuse(oc_chunk, oh_outer)
-
-    pc = dot_16x1x16_int8_int8_int32()
-    s[CC].tensorize(oc_s_inner, pc)
-    s[CC].unroll(ow_inner)
-    s[CC].unroll(oh_inner)
-
-    if C != O:
-        batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
-        oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
-        ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
-        s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
-
-        parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
-        s[C].compute_at(s[O], parallel_axis)
-        s[O].vectorize(oc_block)
-        s[O].parallel(parallel_axis)
-
-    return s
+    return conv2d_generic.schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data, conv_out, last,
+                                                           int32_lanes=16,
+                                                           intrin=dot_16x1x16_int8_int8_int32())
 
 
 def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, out_dtype):
index a7f38ac..53b79bd 100644 (file)
@@ -21,6 +21,7 @@ import tvm
 from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
 
 from ..nn.util import infer_pad
+from ..generic import conv2d as conv2d_generic
 from ..util import get_const_tuple
 from .tensor_intrin import dot_16x1x16_int8_int8_int32
 from .util import get_fp32_len
@@ -56,7 +57,6 @@ def _fallback_schedule(cfg, wkl):
 
 
 def _fallback_schedule_int8(cfg, wkl):
-    simd_width = get_fp32_len()
     HPAD, WPAD = wkl.hpad, wkl.wpad
     HSTR, WSTR = wkl.hstride, wkl.wstride
     out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
@@ -207,68 +207,6 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
 
 
 def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
-    """
-    Defines the schedule for INT8 for intel machines
-    Uses the Intel intrinsics to use INT8 operations
-    More details - https://software.intel.com/en-us/articles/
-    lower-numerical-precision-deep-learning-inference-and-training
-    """
-    int32_lanes = 16
-
-    reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val
-    _, _, _, _, ic_bn = get_const_tuple(data.shape)
-    _, _, _, _, oc_bn = get_const_tuple(conv_out.shape)
-
-    A = data
-    if isinstance(s[A].op, tvm.tensor.ComputeOp):
-        batch, ic_chunk, ih, iw, _ = s[A].op.axis
-        parallel_axis = s[A].fuse(batch, ic_chunk, ih)
-        s[A].parallel(parallel_axis)
-
-    # schedule 5-D NCHW[x]c conv
-    C, O = conv_out, last
-    CC = s.cache_write(C, 'global')
-
-    batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
-    ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
-    s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
-    parallel_axis = s[C].fuse(batch, oc_chunk, oh)
-    s[C].vectorize(oc_block)
-    if C == O:
-        s[C].parallel(parallel_axis)
-
-    s[CC].compute_at(s[C], ow_chunk)
-    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
-    kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
-
-    ow_chunk, ow_block = s[CC].split(ow, factor=reg_n)
-
-    # Skylake and future processors have 16 vector lanes
-    assert oc_bn % int32_lanes == 0
-
-    oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
-
-    if unroll_kw:
-        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, ic_f_inner, kw,
-                      ow_block, oc_f_inner, oc_s_inner, ic_s_inner)
-        s[CC].unroll(kw)
-    else:
-        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, kw, ic_f_inner,
-                      ow_block, oc_f_inner, oc_s_inner, ic_s_inner)
-
-
-    pc = dot_16x1x16_int8_int8_int32()
-    s[CC].tensorize(oc_s_inner, pc)
-    s[CC].unroll(ow_block)
-    s[CC].unroll(oc_f_inner)
-
-    if C != O:
-        batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
-        ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
-        s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
-        parallel_axis = s[O].fuse(batch, oc_chunk, oh)
-        s[C].compute_at(s[O], parallel_axis)
-        s[O].vectorize(oc_block)
-        s[O].parallel(parallel_axis)
-
-    return s
+    return conv2d_generic.schedule_conv_NCHWc_cpu_common_int8(s, cfg, data, conv_out, last,
+                                                              int32_lanes=16,
+                                                              intrin=dot_16x1x16_int8_int8_int32())
index 3f65db4..f701108 100644 (file)
@@ -24,6 +24,7 @@ from tvm.autotvm.task import get_config
 from tvm.autotvm.task.topi_integration import deserialize_args
 from ..nn.conv2d import _get_workload as _get_conv2d_workload
 from .. import generic, tag
+from ..generic import conv2d as conv2d_generic
 from ..util import get_const_tuple
 from ..nn.conv2d import conv2d_NCHWc_int8
 from .. import nn
@@ -38,9 +39,11 @@ def _get_default_config_int8(cfg, data, kernel, strides, padding, out_dtype, is_
     wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout)
     is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
     if is_kernel_1x1:
-        conv2d_avx_1x1._fallback_schedule_int8(cfg, wkl)
+        conv2d_generic.fallback_schedule_cpu_1x1_int8(
+            cfg, wkl, int32_lanes=16, num_int8_elements=4)
     else:
-        conv2d_avx_common._fallback_schedule_int8(cfg, wkl)
+        conv2d_generic.fallback_schedule_cpu_common_int8(
+            cfg, wkl, int32_lanes=16, num_int8_elements=4)
 
 
 def _is_int8_hw_support(data_dtype, kernel_dtype):
diff --git a/topi/recipe/conv/test_conv_int8_arm.py b/topi/recipe/conv/test_conv_int8_arm.py
new file mode 100644 (file)
index 0000000..ff0d37d
--- /dev/null
@@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#pylint: disable-msg=too-many-arguments, too-many-locals, assignment-from-no-return
+""" Conv Int8 functional and performance testing"""
+import sys
+import logging
+import numpy as np
+import tvm
+import topi
+
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+LOGGER = logging.getLogger('test_conv_int8_intel')
+LOGGER.disabled = False
+
+# All the WORKLOADS from Resnet except first layer
+# Workload is ['height', 'width', 'in_filter', 'out_filter',
+#              'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
+WORKLOADS = [(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
+             (56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
+             (56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
+             (56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
+             (28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
+             (28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
+             (28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
+             (14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
+             (14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
+             (14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
+             (7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
+             (56, 56, 64, 256, 1, 1, 0, 0, 1, 1),
+             (56, 56, 256, 64, 1, 1, 0, 0, 1, 1),
+             (56, 56, 256, 128, 1, 1, 0, 0, 2, 2),
+             (28, 28, 128, 512, 1, 1, 0, 0, 1, 1),
+             (56, 56, 256, 512, 1, 1, 0, 0, 2, 2),
+             (28, 28, 512, 128, 1, 1, 0, 0, 1, 1),
+             (28, 28, 512, 256, 1, 1, 0, 0, 2, 2),
+             (14, 14, 256, 1024, 1, 1, 0, 0, 1, 1),
+             (28, 28, 512, 1024, 1, 1, 0, 0, 2, 2),
+             (14, 14, 1024, 256, 1, 1, 0, 0, 1, 1),
+             (14, 14, 1024, 512, 1, 1, 0, 0, 2, 2),
+             (7, 7, 512, 2048, 1, 1, 0, 0, 1, 1),
+             (14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2),
+             (7, 7, 2048, 512, 1, 1, 0, 0, 1, 1)
+            ]
+
+
+TARGET_NAME = 'llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'
+NUM_VEC_LANES = 16
+CTX = tvm.context(TARGET_NAME, 0)
+
+def get_shape(im_height, im_width, in_filter, out_filter, k_h, k_w, hpad, wpad,
+              hstride, wstride, out_dtype):
+    """
+    Finds out the shape of all data structures
+    """
+    data_shape = (1, in_filter//NUM_VEC_LANES, im_height, im_width, NUM_VEC_LANES)
+
+    if out_dtype == 'int32' or out_dtype == 'uint32':
+        kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w,
+                        NUM_VEC_LANES//4, NUM_VEC_LANES, 4)
+    elif out_dtype == 'float32':
+        kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w,
+                        NUM_VEC_LANES, NUM_VEC_LANES)
+    out_height = (im_height + 2 * hpad - k_h) // hstride + 1
+    out_width = (im_width + 2 * wpad - k_w) // wstride + 1
+    o_shape = (1, out_filter//NUM_VEC_LANES, out_height, out_width, NUM_VEC_LANES)
+    return (data_shape, kernel_shape, o_shape)
+
+
+
+def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_filter,
+                  out_filter, k_h, k_w, hpad, wpad, hstride, wstride):
+    """
+    Runs the inference and checks the functional correctness between
+    compute and schedule outputs
+    """
+    (data_shape, kernel_shape, o_shape) = get_shape(im_height, im_width, in_filter,
+                                                    out_filter, k_h, k_w, hpad, wpad,
+                                                    hstride, wstride, out_dtype)
+
+    # Create TVM placeholders
+    data = tvm.placeholder(data_shape, name='data', dtype=data_dtype)
+    kernel = tvm.placeholder(kernel_shape, name='kernel', dtype=kernel_dtype)
+
+    # Create the numpy arrays to be used for executing conv models
+    if data_dtype == 'float32':
+        data_array = tvm.nd.array(np.random.rand(*data_shape).astype(dtype=data_dtype), CTX)
+        kernel_array = tvm.nd.array(np.random.rand(*kernel_shape).astype(dtype=kernel_dtype), CTX)
+    else:
+        data_array = tvm.nd.array(np.random.randint(100, size=data_shape).astype(data_dtype))
+        kernel_array = tvm.nd.array(np.random.randint(100, size=kernel_shape).astype(kernel_dtype))
+
+    # c_orig will be used for declaration ouptut
+    # c_sch will be used for scheduled computation output
+    c_orig = tvm.nd.array(np.zeros(o_shape, dtype=out_dtype), CTX)
+    c_sch = tvm.nd.array(np.zeros(o_shape, dtype=out_dtype), CTX)
+
+
+    with tvm.target.create(TARGET_NAME):
+        if out_dtype == "float32":
+            conv = topi.nn.conv2d_NCHWc(data, kernel, stride=hstride,
+                                        padding=hpad, dilation=(1, 1),
+                                        layout='NCHWc', out_layout='NCHWc', out_dtype=out_dtype)
+        else:
+            conv = topi.nn.conv2d_NCHWc_int8(data, kernel, strides=hstride,
+                                             padding=hpad, dilation=(1, 1),
+                                             layout='NCHWc', out_layout='NCHWc', out_dtype=out_dtype)
+        out = topi.nn.relu(conv)
+        sch = tvm.create_schedule(out.op)
+        func = tvm.build(sch, [data, kernel, out], target=TARGET_NAME, name='out')
+        func(data_array, kernel_array, c_orig)
+        LOGGER.debug(tvm.lower(sch, [data, kernel], simple_mode=True))
+
+        # Generate and run the optimized schedule
+        if out_dtype == "float32":
+            sconv = topi.generic.nn.schedule_conv2d_NCHWc(outs=[out])
+        else:
+            sconv = topi.generic.nn.schedule_conv2d_NCHWc_int8(outs=[out])
+        func = tvm.build(sconv, [data, kernel, out], target=TARGET_NAME, name='conv')
+        func(data_array, kernel_array, c_sch)
+
+        # Functional check
+        if data_dtype == 'uint8':
+            np.testing.assert_equal(c_orig.asnumpy(), c_sch.asnumpy())
+        else:
+            assert np.allclose(c_orig.asnumpy(), c_sch.asnumpy())
+
+        evaluator = func.time_evaluator(func.entry_name, CTX, number=1000)
+        LOGGER.debug(tvm.lower(sconv, [data, kernel], simple_mode=True))
+        return evaluator(data_array, kernel_array, c_sch).mean
+
+if __name__ == "__main__":
+    LOGGER.info("Workload, Kernel_size, FP32_time, INT8_time, Speedup")
+    SPEEDUP_ARRAY = []
+    for i, wkl in enumerate(WORKLOADS):
+        for dtype in ["uint", "int"]:
+            fp32_time = run_inference('float32', 'float32', 'float32', *wkl)
+            int8_time = run_inference('%s8' % dtype, '%s8' % dtype, '%s32' % dtype, *wkl)
+            kernel_h = wkl[4]
+            kernel_w = wkl[5]
+            LOGGER.info("[%s] Workload#" % dtype + str(i) + ", " + str(kernel_h) + "x" + str(kernel_w) + ", "
+                        + str(fp32_time) + ", " + str(int8_time) + ", " + str(fp32_time/int8_time))
+
+            SPEEDUP_ARRAY.append(fp32_time/int8_time)
+    LOGGER.info("Average speedup --> %s" % str(sum(SPEEDUP_ARRAY)/float(len(SPEEDUP_ARRAY))))