From 5cc17649f491299ddf15a8eb144fbb6732382c9a Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 1 Oct 2019 23:40:16 +0800 Subject: [PATCH] [topi] add ARM v8.2 udot (uint8) support (#3978) * [topi] add ARM v8.2 udot (uint8) support * fix test case * fix common conv2d schedule * add back fp32_time in test * fix lint * fix doc, add support for int32_lanes=4, signed int * fix lint * add ic_bn % 4 checker in schedule --- topi/python/topi/arm_cpu/__init__.py | 1 + topi/python/topi/arm_cpu/conv2d_int8.py | 112 ++++++++++++++ topi/python/topi/arm_cpu/tensor_intrin.py | 110 ++++++++++++++ topi/python/topi/generic/conv2d.py | 239 ++++++++++++++++++++++++++++++ topi/python/topi/nn/conv2d.py | 8 - topi/python/topi/x86/conv2d_avx_1x1.py | 99 +------------ topi/python/topi/x86/conv2d_avx_common.py | 70 +-------- topi/python/topi/x86/conv2d_int8.py | 7 +- topi/recipe/conv/test_conv_int8_arm.py | 158 ++++++++++++++++++++ 9 files changed, 633 insertions(+), 171 deletions(-) create mode 100644 topi/python/topi/arm_cpu/conv2d_int8.py create mode 100644 topi/python/topi/arm_cpu/tensor_intrin.py create mode 100644 topi/python/topi/generic/conv2d.py create mode 100644 topi/recipe/conv/test_conv_int8_arm.py diff --git a/topi/python/topi/arm_cpu/__init__.py b/topi/python/topi/arm_cpu/__init__.py index 6cf4d91..32751bf 100644 --- a/topi/python/topi/arm_cpu/__init__.py +++ b/topi/python/topi/arm_cpu/__init__.py @@ -3,6 +3,7 @@ from . import conv2d from . import depthwise_conv2d from . import conv2d_transpose +from . import conv2d_int8 from . import bitserial_conv2d from . import bitserial_dense from . import injective diff --git a/topi/python/topi/arm_cpu/conv2d_int8.py b/topi/python/topi/arm_cpu/conv2d_int8.py new file mode 100644 index 0000000..8f43f5c --- /dev/null +++ b/topi/python/topi/arm_cpu/conv2d_int8.py @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument,no-member +"""Conv2D int8 schedule on ARM""" + +import tvm +from tvm import autotvm +from .. import generic, tag +from ..util import get_const_tuple +from ..nn.conv2d import conv2d_NCHWc_int8 +from ..generic import conv2d as conv2d_generic +from .. import nn +from ..nn.conv2d import _get_workload as _get_conv2d_workload +from .tensor_intrin import dot_int8_int8_int32 + + +def _get_default_config(cfg, data, kernel, strides, padding, out_dtype): + """ + Get default int8 schedule config for the workload + """ + wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype) + is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1 + if is_kernel_1x1: + conv2d_generic.fallback_schedule_cpu_1x1_int8( + cfg, wkl, int32_lanes=2, num_int8_elements=4) + else: + conv2d_generic.fallback_schedule_cpu_common_int8( + cfg, wkl, int32_lanes=2, num_int8_elements=4) + + +@autotvm.register_topi_compute(conv2d_NCHWc_int8, ['arm_cpu'], 'direct') +def _declaration_conv_NCHWc_int8(cfg, data, kernel, strides, + padding, dilation, layout, out_layout, out_dtype): + # layout and out_layout are not used here, + # we keep them for debug convenience when dumping autotvm workload + n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) + in_channel = ic_chunk * ic_bn + + oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn, n_elems = get_const_tuple(kernel.shape) + num_filter = oc_chunk * oc_bn + + # If no config was set, we can fallback to NCHW config. + if cfg.is_fallback: + _get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype), + tvm.placeholder((num_filter, in_channel, kh, kw), dtype=kernel.dtype), + strides, padding, out_dtype) + return nn.conv2d_NCHWc_int8_compute(data, + kernel, + strides, + padding, + dilation, + layout, + out_layout, + out_dtype) + + +@autotvm.register_topi_schedule(generic.schedule_conv2d_NCHWc_int8, ['arm_cpu'], ['direct']) +def _schedule_conv2d_NCHWc_int8(cfg, outs): + """Create schedule for tensors""" + s = tvm.create_schedule([x.op for x in outs]) + scheduled_ops = [] + + def traverse(op): + """Traverse operators from computation graph""" + # inline all one-to-one-mapping operators except the last stage (output) + if tag.is_broadcast(op.tag): + if op not in s.outputs: + s[op].compute_inline() + for tensor in op.input_tensors: + if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops: + traverse(tensor.op) + + if 'conv2d_NCHWc_int8' in op.tag: + conv_out = op.output(0) + kernel = conv_out.op.input_tensors[1] + data_vec = conv_out.op.input_tensors[0] + data = data_vec.op.input_tensors[0] \ + if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \ + else data_vec + if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: + data_pad = data + data = data_pad.op.input_tensors[0] + + args = [s, cfg, data_vec, conv_out, outs[0]] + # int8 conv kernel is 7-dim + _, _, kh, kw, _, _, _ = get_const_tuple(kernel.shape) + dtype = "uint" if data.dtype == "uint8" else "int" + if kh == 1 and kw == 1: + conv2d_generic.schedule_conv_NCHWc_cpu_1x1_int8( + *args, int32_lanes=4, intrin=dot_int8_int8_int32(int32_lanes=4, dtype=dtype)) + else: + conv2d_generic.schedule_conv_NCHWc_cpu_common_int8( + *args, int32_lanes=4, intrin=dot_int8_int8_int32(int32_lanes=4, dtype=dtype)) + + scheduled_ops.append(op) + + traverse(outs[0].op) + return s diff --git a/topi/python/topi/arm_cpu/tensor_intrin.py b/topi/python/topi/arm_cpu/tensor_intrin.py new file mode 100644 index 0000000..2f300a1 --- /dev/null +++ b/topi/python/topi/arm_cpu/tensor_intrin.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument,no-member +"""Conv2D int8 schedule on ARM""" + +import tvm + +def dot_int8_int8_int32(int32_lanes, dtype='uint'): + """ + Int8 dot product by every 4 elements using ARM v8.2 udot. + This function takes two arrays of int8 datatype -- data[4] and + kernel[int32_lanes][4] -- and computes a dot product of data[4] with every + 4 elements of kernels, resulting in output[int32_lanes] of uint32 datatype. + The pseudo code is as follows. + + .. code-block:: c + + void dot_int8_int8_int32(int8 data[4], int8 kernel[16][4], int32 output[16]){ + for (int i = 0; i < int32_lanes; i++){ + out[i] = 0; + for (int k = 0; k < 4; k++){ + out[i] += data[k] * kernel[i][k] + } + } + } + + Physically, the kernel array sits in a vector register and + the data[4] is broadcasted to another vector register. This + function returns a TensorIntrin that can be used to tensorize + a schedule. + + Parameters + ---------- + int32_lanes: int + How many int32/uint32 to produce + dtype: str, optional, {"uint", "int"} + Whether it works on unsigned int or signed int + + Returns + ------- + intrin : TensorIntrin + The ARM uint8 TensorIntrin that can be used in tensorizing schedule + """ + num_int8_elements = 4 # 4 int8 elements in int32 + + data = tvm.placeholder((num_int8_elements,), dtype='%s8' % dtype, name='data') + kernel = tvm.placeholder((int32_lanes, num_int8_elements), dtype='%s8' % dtype, name='kernel') + + k = tvm.reduce_axis((0, num_int8_elements), name='k') + C = tvm.compute((int32_lanes,), + lambda i: tvm.sum(data[k].astype('%s32' % dtype) * + kernel[i, k].astype('%s32' % dtype), + axis=k), name="C") + + a_buffer = tvm.decl_buffer(data.shape, dtype='%s8' % dtype, name="a_buffer", + offset_factor=1, + strides=[1]) + b_buffer = tvm.decl_buffer(kernel.shape, dtype='%s8' % dtype, name="b_buffer", + offset_factor=1, + strides=[tvm.var('s'), 1]) + + def _intrin_func(ins, outs): + def _instr(index): + ib = tvm.ir_builder.create() + if index == 1: + ib.emit(outs[0].vstore(0, tvm.const(0, '%s32x%d' % (dtype, int32_lanes)))) + return ib.get() + + dtype_a = '%s8x%d' % (dtype, num_int8_elements) + dtype_b = '%s8x%d' % (dtype, int32_lanes * num_int8_elements) + dtype_c = '%s32x%d' % (dtype, int32_lanes) + + a_int8 = ins[0].vload([0], dtype_a) + re_int32 = tvm.call_pure_intrin('%s32' % dtype, 'reinterpret', a_int8) + # broadcast a + vec_ai32 = re_int32.astype(dtype_c) + + vec_a = tvm.call_pure_intrin(dtype_b, 'reinterpret', vec_ai32) + vec_b = ins[1].vload([0, 0], dtype_b) + vec_c = outs[0].vload([0], dtype_c) + + inst = 'udot' if dtype == 'uint' else 'sdot' + inst = 'llvm.aarch64.neon.%s.v%di32.v%di8' % ( + inst, int32_lanes, int32_lanes * num_int8_elements) + vdot = tvm.call_llvm_intrin(dtype_c, + inst, + tvm.const(2, 'uint32'), + vec_c, vec_a, vec_b) + ib.emit(outs[0].vstore(0, vdot)) + return ib.get() + + # body, reset, update + return _instr(0), _instr(1), _instr(2) + + with tvm.build_config(offset_factor=1, partition_const_loop=True): + return tvm.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer}) diff --git a/topi/python/topi/generic/conv2d.py b/topi/python/topi/generic/conv2d.py new file mode 100644 index 0000000..332c2fd --- /dev/null +++ b/topi/python/topi/generic/conv2d.py @@ -0,0 +1,239 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-variable, too-many-locals +# pylint: disable=unused-argument, redefined-builtin +"""Generic convolution schedules""" +from __future__ import absolute_import as _abs +import tvm +from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity +from ..util import get_const_tuple + +def fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements): + """Fallback schedule for conv2d int8 on cpu. + Normally the inner most pattern takes two int8/uint8 tensors + data[num_int8_elements] and kernel[int32_lanes, num_int8_elements], + produces a dot product int32/uint32 output[int32_lanes]. + + Parameters + ---------- + int32_lanes : int + How many numbers of int32/uint32 will be produced using intrinsic. + This is related to output channel. + num_int8_elements : int + How many numbers of input int32/uint32 will be multiplied and reduced. + This is related to input channel. + """ + HPAD, WPAD = wkl.hpad, wkl.wpad + HSTR, WSTR = wkl.hstride, wkl.wstride + out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1 + + assert wkl.out_filter % int32_lanes == 0, \ + "wkl.out_filter=%d, int32_lanes=%d" % (wkl.out_filter, int32_lanes) + assert wkl.in_filter % num_int8_elements == 0, \ + "wkl.in_filter=%d, num_int8_elements=%d" % (wkl.in_filter, num_int8_elements) + + oc_bn = int32_lanes + ic_bn = 1 + for bn in range(oc_bn, 0, -4): + if wkl.in_filter % bn == 0: + ic_bn = bn + break + + reg_n = 1 + for n in range(31, 0, -1): + if out_width % n == 0: + reg_n = n + break + + cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn]) + cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn]) + cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n]) + cfg["unroll_kw"] = OtherOptionEntity(False) + + +def fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes, num_int8_elements): + """Fallback schedule for 1x1 conv2d int8 on cpu. + Normally the inner most pattern takes two int8/uint8 tensors + data[num_int8_elements] and kernel[int32_lanes, num_int8_elements], + produces a dot product int32/uint32 output[int32_lanes]. + + Parameters + ---------- + int32_lanes : int + How many numbers of int32/uint32 will be produced using intrinsic. + This is related to output channel. + num_int8_elements : int + How many numbers of input int32/uint32 will be multiplied and reduced. + This is related to input channel. + """ + HPAD, WPAD = wkl.hpad, wkl.wpad + HSTR, WSTR = wkl.hstride, wkl.wstride + out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1 + out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1 + + assert wkl.out_filter % int32_lanes == 0, \ + "wkl.out_filter=%d, int32_lanes=%d" % (wkl.out_filter, int32_lanes) + assert wkl.in_filter % num_int8_elements == 0, \ + "wkl.in_filter=%d, num_int8_elements=%d" % (wkl.in_filter, num_int8_elements) + + oc_bn = int32_lanes + ic_bn = 1 + for bn in range(oc_bn, 0, -4): + if wkl.in_filter % bn == 0: + ic_bn = bn + break + + for ow_factor in range(out_width, 0, -1): + if out_width % ow_factor == 0: + for oh_factor in range(out_height, 0, -1): + if out_height % oh_factor == 0 and ow_factor * oh_factor < 32: + cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn]) + cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn]) + cfg["tile_oh"] = OtherOptionEntity(oh_factor) + cfg["tile_ow"] = SplitEntity([out_width // ow_factor, ow_factor]) + return + raise ValueError("cannot decide default schedule for workload: {}".format(wkl)) + + +def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data, conv_out, last, int32_lanes=16, intrin=None): + """ + Defines the schedule for INT8 for Intel and ARM machines + Uses the Intel/ARM intrinsics to use INT8 operations + More details - https://software.intel.com/en-us/articles/ + lower-numerical-precision-deep-learning-inference-and-training + """ + reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val + _, _, _, _, ic_bn = get_const_tuple(data.shape) + _, _, _, _, oc_bn = get_const_tuple(conv_out.shape) + + A = data + if isinstance(s[A].op, tvm.tensor.ComputeOp): + batch, ic_chunk, ih, iw, _ = s[A].op.axis + parallel_axis = s[A].fuse(batch, ic_chunk, ih) + s[A].parallel(parallel_axis) + + # schedule 5-D NCHW[x]c conv + C, O = conv_out, last + CC = s.cache_write(C, 'global') + + batch, oc_chunk, oh, ow, oc_block = s[C].op.axis + ow_chunk, ow_block = s[C].split(ow, factor=reg_n) + s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) + parallel_axis = s[C].fuse(batch, oc_chunk, oh) + s[C].vectorize(oc_block) + if C == O: + s[C].parallel(parallel_axis) + + s[CC].compute_at(s[C], ow_chunk) + _, oc_chunk, oh, ow, oc_block = s[CC].op.axis + kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis + + ow_chunk, ow_block = s[CC].split(ow, factor=reg_n) + + assert oc_bn % int32_lanes == 0 + assert ic_bn % 4 == 0 # 4 (u)int8 elements in (u)int32 + + oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes) + + if unroll_kw: + s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, ic_f_inner, kw, + ow_block, oc_f_inner, oc_s_inner, ic_s_inner) + s[CC].unroll(kw) + else: + s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, kw, ic_f_inner, + ow_block, oc_f_inner, oc_s_inner, ic_s_inner) + + if intrin is not None: + s[CC].tensorize(oc_s_inner, intrin) + s[CC].unroll(ow_block) + s[CC].unroll(oc_f_inner) + + if C != O: + batch, oc_chunk, oh, ow, oc_block = s[O].op.axis + ow_chunk, ow_block = s[O].split(ow, factor=reg_n) + s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) + parallel_axis = s[O].fuse(batch, oc_chunk, oh) + s[C].compute_at(s[O], parallel_axis) + s[O].vectorize(oc_block) + s[O].parallel(parallel_axis) + + return s + +def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data, conv_out, last, int32_lanes=16, intrin=None): + """ + Defines the 1x1 conv schedule for INT8 for Intel and ARM machines + Uses the Intel/ARM intrinsics to use INT8 operations + More details - https://software.intel.com/en-us/articles/ + lower-numerical-precision-deep-learning-inference-and-training + """ + oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1] + _, _, _, _, ic_bn = get_const_tuple(data.shape) + _, _, _, _, oc_bn = get_const_tuple(conv_out.shape) + + # schedule data + A = data + if isinstance(s[A].op, tvm.tensor.ComputeOp): + batch, ic_chunk, ih, iw, ic_block = s[A].op.axis + parallel_axis = s[A].fuse(batch, ic_chunk, ih) + s[A].parallel(parallel_axis) + + C, O = conv_out, last + CC = s.cache_write(C, 'global') + + batch, oc_chunk, oh, ow, oc_block = s[C].op.axis + oh_outer, oh_inner = s[C].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[C].split(ow, factor=ow_factor) + s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) + s[C].vectorize(oc_block) + + parallel_axis = s[C].fuse(batch, oc_chunk, oh_outer) + s[CC].compute_at(s[C], parallel_axis) + if C == O: + s[C].parallel(parallel_axis) + + _, oc_chunk, oh, ow, oc_block = s[CC].op.axis + kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis + + assert oc_bn % int32_lanes == 0 + assert ic_bn % 4 == 0 # 4 (u)int8 elements in (u)int32 + + oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes) + + oh_outer, oh_inner = s[CC].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[CC].split(ow, factor=ow_factor) + + s[CC].reorder(oc_chunk, oh_outer, ow_outer, kh, kw, ic_outer, ic_f_inner, oh_inner, + ow_inner, oc_f_inner, oc_s_inner, ic_s_inner) + s[CC].fuse(oc_chunk, oh_outer) + + if intrin is not None: + s[CC].tensorize(oc_s_inner, intrin) + s[CC].unroll(ow_inner) + s[CC].unroll(oh_inner) + + if C != O: + batch, oc_chunk, oh, ow, oc_block = s[O].op.axis + oh_outer, oh_inner = s[O].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[O].split(ow, factor=ow_factor) + s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) + + parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer) + s[C].compute_at(s[O], parallel_axis) + s[O].vectorize(oc_block) + s[O].parallel(parallel_axis) + + return s diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 904dd54..ffae4b2 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -595,19 +595,11 @@ def conv2d_NCHWc_int8_compute(data, kernel, strides, padding, dilation, layout, n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) in_channel = ic_chunk * ic_bn - target = tvm.target.current_target(allow_none=False) oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = \ get_const_tuple(kernel.shape) num_filter = oc_chunk * oc_bn groups = ic_chunk // ic_chunk_group - # Since the weight is 7-D and the last element size is 4, we have to - # check ic_bn should be a multiple of 4. - # Similary, oc_bn has to be a multiple of 4. - - assert ic_bn % 4 == 0 - assert oc_bn % 16 == 0 - dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 dilated_kernel_w = (kernel_width - 1) * dilation_w + 1 diff --git a/topi/python/topi/x86/conv2d_avx_1x1.py b/topi/python/topi/x86/conv2d_avx_1x1.py index 6e36e93..96b6e47 100644 --- a/topi/python/topi/x86/conv2d_avx_1x1.py +++ b/topi/python/topi/x86/conv2d_avx_1x1.py @@ -22,6 +22,7 @@ from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity from ..nn.pad import pad from ..nn.util import infer_pad, get_pad_tuple +from ..generic import conv2d as conv2d_generic from ..util import get_const_tuple, simplify from .tensor_intrin import dot_16x1x16_int8_int8_int32 from .util import get_fp32_len @@ -57,36 +58,6 @@ def _fallback_schedule(cfg, wkl): raise ValueError("cannot decide default schedule for workload: {}".format(wkl)) -def _fallback_schedule_int8(cfg, wkl): - simd_width = get_fp32_len() - HPAD, WPAD = wkl.hpad, wkl.wpad - HSTR, WSTR = wkl.hstride, wkl.wstride - out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1 - out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1 - - oc_bn = 16 - assert wkl.out_filter % oc_bn == 0 - - ic_bn = 1 - for bn in range(oc_bn, 0, -4): - if wkl.in_filter % bn == 0: - ic_bn = bn - break - assert wkl.in_filter % 4 == 0 - - for ow_factor in range(out_width, 0, -1): - if out_width % ow_factor == 0: - for oh_factor in range(out_height, 0, -1): - if out_height % oh_factor == 0 and ow_factor * oh_factor < 32: - cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn]) - cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn]) - cfg["tile_oh"] = OtherOptionEntity(oh_factor) - cfg["tile_ow"] = SplitEntity([out_width // ow_factor, ow_factor]) - return - raise ValueError("cannot decide default schedule for workload: {}".format(wkl)) - - - def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last): # fetch schedule ic_bn, oc_bn, oh_factor, ow_factor = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1], @@ -210,71 +181,9 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last): def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last): - """ - Defines the schedule for INT8 for intel machines - Uses the Intel intrinsics to use INT8 operations - More details - https://software.intel.com/en-us/articles/ - lower-numerical-precision-deep-learning-inference-and-training - """ - int32_lanes = 16 - - oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1] - _, _, _, _, ic_bn = get_const_tuple(data.shape) - _, _, _, _, oc_bn = get_const_tuple(conv_out.shape) - - # schedule data - A = data - if isinstance(s[A].op, tvm.tensor.ComputeOp): - batch, ic_chunk, ih, iw, ic_block = s[A].op.axis - parallel_axis = s[A].fuse(batch, ic_chunk, ih) - s[A].parallel(parallel_axis) - - C, O = conv_out, last - CC = s.cache_write(C, 'global') - - batch, oc_chunk, oh, ow, oc_block = s[C].op.axis - oh_outer, oh_inner = s[C].split(oh, factor=oh_factor) - ow_outer, ow_inner = s[C].split(ow, factor=ow_factor) - s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) - s[C].vectorize(oc_block) - - parallel_axis = s[C].fuse(batch, oc_chunk, oh_outer) - s[CC].compute_at(s[C], parallel_axis) - if C == O: - s[C].parallel(parallel_axis) - - _, oc_chunk, oh, ow, oc_block = s[CC].op.axis - kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis - - # Skylake and future processors have 16 vector lanes - assert oc_bn % int32_lanes == 0 - - oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes) - - oh_outer, oh_inner = s[CC].split(oh, factor=oh_factor) - ow_outer, ow_inner = s[CC].split(ow, factor=ow_factor) - - s[CC].reorder(oc_chunk, oh_outer, ow_outer, kh, kw, ic_outer, ic_f_inner, oh_inner, - ow_inner, oc_f_inner, oc_s_inner, ic_s_inner) - s[CC].fuse(oc_chunk, oh_outer) - - pc = dot_16x1x16_int8_int8_int32() - s[CC].tensorize(oc_s_inner, pc) - s[CC].unroll(ow_inner) - s[CC].unroll(oh_inner) - - if C != O: - batch, oc_chunk, oh, ow, oc_block = s[O].op.axis - oh_outer, oh_inner = s[O].split(oh, factor=oh_factor) - ow_outer, ow_inner = s[O].split(ow, factor=ow_factor) - s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) - - parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer) - s[C].compute_at(s[O], parallel_axis) - s[O].vectorize(oc_block) - s[O].parallel(parallel_axis) - - return s + return conv2d_generic.schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data, conv_out, last, + int32_lanes=16, + intrin=dot_16x1x16_int8_int8_int32()) def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, out_dtype): diff --git a/topi/python/topi/x86/conv2d_avx_common.py b/topi/python/topi/x86/conv2d_avx_common.py index a7f38ac..53b79bd 100644 --- a/topi/python/topi/x86/conv2d_avx_common.py +++ b/topi/python/topi/x86/conv2d_avx_common.py @@ -21,6 +21,7 @@ import tvm from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity from ..nn.util import infer_pad +from ..generic import conv2d as conv2d_generic from ..util import get_const_tuple from .tensor_intrin import dot_16x1x16_int8_int8_int32 from .util import get_fp32_len @@ -56,7 +57,6 @@ def _fallback_schedule(cfg, wkl): def _fallback_schedule_int8(cfg, wkl): - simd_width = get_fp32_len() HPAD, WPAD = wkl.hpad, wkl.wpad HSTR, WSTR = wkl.hstride, wkl.wstride out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1 @@ -207,68 +207,6 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last): def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last): - """ - Defines the schedule for INT8 for intel machines - Uses the Intel intrinsics to use INT8 operations - More details - https://software.intel.com/en-us/articles/ - lower-numerical-precision-deep-learning-inference-and-training - """ - int32_lanes = 16 - - reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val - _, _, _, _, ic_bn = get_const_tuple(data.shape) - _, _, _, _, oc_bn = get_const_tuple(conv_out.shape) - - A = data - if isinstance(s[A].op, tvm.tensor.ComputeOp): - batch, ic_chunk, ih, iw, _ = s[A].op.axis - parallel_axis = s[A].fuse(batch, ic_chunk, ih) - s[A].parallel(parallel_axis) - - # schedule 5-D NCHW[x]c conv - C, O = conv_out, last - CC = s.cache_write(C, 'global') - - batch, oc_chunk, oh, ow, oc_block = s[C].op.axis - ow_chunk, ow_block = s[C].split(ow, factor=reg_n) - s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) - parallel_axis = s[C].fuse(batch, oc_chunk, oh) - s[C].vectorize(oc_block) - if C == O: - s[C].parallel(parallel_axis) - - s[CC].compute_at(s[C], ow_chunk) - _, oc_chunk, oh, ow, oc_block = s[CC].op.axis - kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis - - ow_chunk, ow_block = s[CC].split(ow, factor=reg_n) - - # Skylake and future processors have 16 vector lanes - assert oc_bn % int32_lanes == 0 - - oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes) - - if unroll_kw: - s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, ic_f_inner, kw, - ow_block, oc_f_inner, oc_s_inner, ic_s_inner) - s[CC].unroll(kw) - else: - s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, kw, ic_f_inner, - ow_block, oc_f_inner, oc_s_inner, ic_s_inner) - - - pc = dot_16x1x16_int8_int8_int32() - s[CC].tensorize(oc_s_inner, pc) - s[CC].unroll(ow_block) - s[CC].unroll(oc_f_inner) - - if C != O: - batch, oc_chunk, oh, ow, oc_block = s[O].op.axis - ow_chunk, ow_block = s[O].split(ow, factor=reg_n) - s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) - parallel_axis = s[O].fuse(batch, oc_chunk, oh) - s[C].compute_at(s[O], parallel_axis) - s[O].vectorize(oc_block) - s[O].parallel(parallel_axis) - - return s + return conv2d_generic.schedule_conv_NCHWc_cpu_common_int8(s, cfg, data, conv_out, last, + int32_lanes=16, + intrin=dot_16x1x16_int8_int8_int32()) diff --git a/topi/python/topi/x86/conv2d_int8.py b/topi/python/topi/x86/conv2d_int8.py index 3f65db4..f701108 100644 --- a/topi/python/topi/x86/conv2d_int8.py +++ b/topi/python/topi/x86/conv2d_int8.py @@ -24,6 +24,7 @@ from tvm.autotvm.task import get_config from tvm.autotvm.task.topi_integration import deserialize_args from ..nn.conv2d import _get_workload as _get_conv2d_workload from .. import generic, tag +from ..generic import conv2d as conv2d_generic from ..util import get_const_tuple from ..nn.conv2d import conv2d_NCHWc_int8 from .. import nn @@ -38,9 +39,11 @@ def _get_default_config_int8(cfg, data, kernel, strides, padding, out_dtype, is_ wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout) is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1 if is_kernel_1x1: - conv2d_avx_1x1._fallback_schedule_int8(cfg, wkl) + conv2d_generic.fallback_schedule_cpu_1x1_int8( + cfg, wkl, int32_lanes=16, num_int8_elements=4) else: - conv2d_avx_common._fallback_schedule_int8(cfg, wkl) + conv2d_generic.fallback_schedule_cpu_common_int8( + cfg, wkl, int32_lanes=16, num_int8_elements=4) def _is_int8_hw_support(data_dtype, kernel_dtype): diff --git a/topi/recipe/conv/test_conv_int8_arm.py b/topi/recipe/conv/test_conv_int8_arm.py new file mode 100644 index 0000000..ff0d37d --- /dev/null +++ b/topi/recipe/conv/test_conv_int8_arm.py @@ -0,0 +1,158 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +#pylint: disable-msg=too-many-arguments, too-many-locals, assignment-from-no-return +""" Conv Int8 functional and performance testing""" +import sys +import logging +import numpy as np +import tvm +import topi + +logging.basicConfig(stream=sys.stdout, level=logging.INFO) +LOGGER = logging.getLogger('test_conv_int8_intel') +LOGGER.disabled = False + +# All the WORKLOADS from Resnet except first layer +# Workload is ['height', 'width', 'in_filter', 'out_filter', +# 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) +WORKLOADS = [(56, 56, 64, 64, 3, 3, 1, 1, 1, 1), + (56, 56, 64, 64, 1, 1, 0, 0, 1, 1), + (56, 56, 64, 128, 3, 3, 1, 1, 2, 2), + (56, 56, 64, 128, 1, 1, 0, 0, 2, 2), + (28, 28, 128, 128, 3, 3, 1, 1, 1, 1), + (28, 28, 128, 256, 3, 3, 1, 1, 2, 2), + (28, 28, 128, 256, 1, 1, 0, 0, 2, 2), + (14, 14, 256, 256, 3, 3, 1, 1, 1, 1), + (14, 14, 256, 512, 3, 3, 1, 1, 2, 2), + (14, 14, 256, 512, 1, 1, 0, 0, 2, 2), + (7, 7, 512, 512, 3, 3, 1, 1, 1, 1), + (56, 56, 64, 256, 1, 1, 0, 0, 1, 1), + (56, 56, 256, 64, 1, 1, 0, 0, 1, 1), + (56, 56, 256, 128, 1, 1, 0, 0, 2, 2), + (28, 28, 128, 512, 1, 1, 0, 0, 1, 1), + (56, 56, 256, 512, 1, 1, 0, 0, 2, 2), + (28, 28, 512, 128, 1, 1, 0, 0, 1, 1), + (28, 28, 512, 256, 1, 1, 0, 0, 2, 2), + (14, 14, 256, 1024, 1, 1, 0, 0, 1, 1), + (28, 28, 512, 1024, 1, 1, 0, 0, 2, 2), + (14, 14, 1024, 256, 1, 1, 0, 0, 1, 1), + (14, 14, 1024, 512, 1, 1, 0, 0, 2, 2), + (7, 7, 512, 2048, 1, 1, 0, 0, 1, 1), + (14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2), + (7, 7, 2048, 512, 1, 1, 0, 0, 1, 1) + ] + + +TARGET_NAME = 'llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod' +NUM_VEC_LANES = 16 +CTX = tvm.context(TARGET_NAME, 0) + +def get_shape(im_height, im_width, in_filter, out_filter, k_h, k_w, hpad, wpad, + hstride, wstride, out_dtype): + """ + Finds out the shape of all data structures + """ + data_shape = (1, in_filter//NUM_VEC_LANES, im_height, im_width, NUM_VEC_LANES) + + if out_dtype == 'int32' or out_dtype == 'uint32': + kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w, + NUM_VEC_LANES//4, NUM_VEC_LANES, 4) + elif out_dtype == 'float32': + kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w, + NUM_VEC_LANES, NUM_VEC_LANES) + out_height = (im_height + 2 * hpad - k_h) // hstride + 1 + out_width = (im_width + 2 * wpad - k_w) // wstride + 1 + o_shape = (1, out_filter//NUM_VEC_LANES, out_height, out_width, NUM_VEC_LANES) + return (data_shape, kernel_shape, o_shape) + + + +def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_filter, + out_filter, k_h, k_w, hpad, wpad, hstride, wstride): + """ + Runs the inference and checks the functional correctness between + compute and schedule outputs + """ + (data_shape, kernel_shape, o_shape) = get_shape(im_height, im_width, in_filter, + out_filter, k_h, k_w, hpad, wpad, + hstride, wstride, out_dtype) + + # Create TVM placeholders + data = tvm.placeholder(data_shape, name='data', dtype=data_dtype) + kernel = tvm.placeholder(kernel_shape, name='kernel', dtype=kernel_dtype) + + # Create the numpy arrays to be used for executing conv models + if data_dtype == 'float32': + data_array = tvm.nd.array(np.random.rand(*data_shape).astype(dtype=data_dtype), CTX) + kernel_array = tvm.nd.array(np.random.rand(*kernel_shape).astype(dtype=kernel_dtype), CTX) + else: + data_array = tvm.nd.array(np.random.randint(100, size=data_shape).astype(data_dtype)) + kernel_array = tvm.nd.array(np.random.randint(100, size=kernel_shape).astype(kernel_dtype)) + + # c_orig will be used for declaration ouptut + # c_sch will be used for scheduled computation output + c_orig = tvm.nd.array(np.zeros(o_shape, dtype=out_dtype), CTX) + c_sch = tvm.nd.array(np.zeros(o_shape, dtype=out_dtype), CTX) + + + with tvm.target.create(TARGET_NAME): + if out_dtype == "float32": + conv = topi.nn.conv2d_NCHWc(data, kernel, stride=hstride, + padding=hpad, dilation=(1, 1), + layout='NCHWc', out_layout='NCHWc', out_dtype=out_dtype) + else: + conv = topi.nn.conv2d_NCHWc_int8(data, kernel, strides=hstride, + padding=hpad, dilation=(1, 1), + layout='NCHWc', out_layout='NCHWc', out_dtype=out_dtype) + out = topi.nn.relu(conv) + sch = tvm.create_schedule(out.op) + func = tvm.build(sch, [data, kernel, out], target=TARGET_NAME, name='out') + func(data_array, kernel_array, c_orig) + LOGGER.debug(tvm.lower(sch, [data, kernel], simple_mode=True)) + + # Generate and run the optimized schedule + if out_dtype == "float32": + sconv = topi.generic.nn.schedule_conv2d_NCHWc(outs=[out]) + else: + sconv = topi.generic.nn.schedule_conv2d_NCHWc_int8(outs=[out]) + func = tvm.build(sconv, [data, kernel, out], target=TARGET_NAME, name='conv') + func(data_array, kernel_array, c_sch) + + # Functional check + if data_dtype == 'uint8': + np.testing.assert_equal(c_orig.asnumpy(), c_sch.asnumpy()) + else: + assert np.allclose(c_orig.asnumpy(), c_sch.asnumpy()) + + evaluator = func.time_evaluator(func.entry_name, CTX, number=1000) + LOGGER.debug(tvm.lower(sconv, [data, kernel], simple_mode=True)) + return evaluator(data_array, kernel_array, c_sch).mean + +if __name__ == "__main__": + LOGGER.info("Workload, Kernel_size, FP32_time, INT8_time, Speedup") + SPEEDUP_ARRAY = [] + for i, wkl in enumerate(WORKLOADS): + for dtype in ["uint", "int"]: + fp32_time = run_inference('float32', 'float32', 'float32', *wkl) + int8_time = run_inference('%s8' % dtype, '%s8' % dtype, '%s32' % dtype, *wkl) + kernel_h = wkl[4] + kernel_w = wkl[5] + LOGGER.info("[%s] Workload#" % dtype + str(i) + ", " + str(kernel_h) + "x" + str(kernel_w) + ", " + + str(fp32_time) + ", " + str(int8_time) + ", " + str(fp32_time/int8_time)) + + SPEEDUP_ARRAY.append(fp32_time/int8_time) + LOGGER.info("Average speedup --> %s" % str(sum(SPEEDUP_ARRAY)/float(len(SPEEDUP_ARRAY)))) -- 2.7.4