From 23c22812b8e2a46133cfbc89e4910e24dc0b976a Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Thu, 5 Sep 2019 11:29:42 -0700 Subject: [PATCH] [VTA][TOPI] Conv2d transpose (deconvolution) operator support (#3777) * initial conv2d_transpose * correct select operator * cleanup * fix * fix correcness check * conv2d transpose declaration fix * autotvm conv2d_transpose tuning script * ir pass fix * fix tuning script * deriving params from env, adding bias * removing bias comp from deconvolution * lint * fix * lint * lint * turning off cpu * lint, ops * lint * import fix * removing hard coded values * lint --- vta/python/vta/build_module.py | 3 +- vta/python/vta/ir_pass.py | 139 +++++++++++++ vta/python/vta/top/__init__.py | 2 + vta/python/vta/top/op.py | 46 +++- vta/python/vta/top/util.py | 25 +++ vta/python/vta/top/vta_conv2d.py | 10 +- vta/python/vta/top/vta_conv2d_transpose.py | 195 +++++++++++++++++ vta/scripts/tune_conv2d.py | 34 +-- vta/scripts/tune_conv2d_transpose.py | 142 +++++++++++++ .../integration/test_benchmark_topi_conv2d.py | 9 +- .../test_benchmark_topi_conv2d_transpose.py | 231 +++++++++++++++++++++ 11 files changed, 809 insertions(+), 27 deletions(-) create mode 100644 vta/python/vta/top/util.py create mode 100644 vta/python/vta/top/vta_conv2d_transpose.py create mode 100644 vta/scripts/tune_conv2d_transpose.py create mode 100644 vta/tests/python/integration/test_benchmark_topi_conv2d_transpose.py diff --git a/vta/python/vta/build_module.py b/vta/python/vta/build_module.py index dbd2e4b..5c24375 100644 --- a/vta/python/vta/build_module.py +++ b/vta/python/vta/build_module.py @@ -69,7 +69,8 @@ def build_config(debug_flag=0, **kwargs): debug_flag) return tvm.make.stmt_seq(debug, stmt) - pass_list = [(1, ir_pass.inject_dma_intrin), + pass_list = [(0, ir_pass.inject_conv2d_transpose_skip), + (1, ir_pass.inject_dma_intrin), (1, ir_pass.inject_skip_copy), (1, ir_pass.annotate_alu_coproc_scope), (1, lambda x: tvm.ir_pass.LiftAttrScope(x, "coproc_uop_scope", True)), diff --git a/vta/python/vta/ir_pass.py b/vta/python/vta/ir_pass.py index e809dd4..06a1975 100644 --- a/vta/python/vta/ir_pass.py +++ b/vta/python/vta/ir_pass.py @@ -579,6 +579,145 @@ def inject_dma_intrin(stmt_in): return tvm.ir_pass.InjectCopyIntrin(stmt_in, "dma_copy", _inject_copy) +def _get_gemm_intrin_buffer(): + env = get_env() + wgt_lanes = env.WGT_ELEM_BITS // env.WGT_WIDTH + assert wgt_lanes == env.BLOCK_OUT * env.BLOCK_IN + wgt_shape = (env.BLOCK_OUT, env.BLOCK_IN) + assert wgt_shape[0] * wgt_shape[1] == wgt_lanes + inp_lanes = env.INP_ELEM_BITS // env.INP_WIDTH + assert inp_lanes == env.BATCH * env.BLOCK_IN + inp_shape = (env.BATCH, env.BLOCK_IN) + assert inp_shape[0] * inp_shape[1] == inp_lanes + out_lanes = env.ACC_ELEM_BITS // env.ACC_WIDTH + assert out_lanes == env.BATCH * env.BLOCK_OUT + out_shape = (env.BATCH, env.BLOCK_OUT) + assert out_shape[0] * out_shape[1] == out_lanes + wgt = tvm.placeholder((wgt_shape[0], wgt_shape[1]), + dtype="int%d" % env.WGT_WIDTH, + name=env.wgt_scope) + inp = tvm.placeholder((inp_shape[0], inp_shape[1]), + dtype="int%d" % env.INP_WIDTH, + name=env.inp_scope) + k = tvm.reduce_axis((0, wgt_shape[1]), name="k") + out_dtype = "int%d" % env.ACC_WIDTH + out = tvm.compute((out_shape[0], out_shape[1]), + lambda i, j: tvm.sum(inp[i, k].astype(out_dtype) * + wgt[j, k].astype(out_dtype), + axis=[k]), + name="out") + wgt_layout = tvm.decl_buffer( + wgt.shape, wgt.dtype, env.wgt_scope, + scope=env.wgt_scope, offset_factor=wgt_lanes, data_alignment=wgt_lanes) + inp_layout = 
tvm.decl_buffer( + inp.shape, inp.dtype, env.inp_scope, + scope=env.inp_scope, offset_factor=inp_lanes, data_alignment=inp_lanes) + out_layout = tvm.decl_buffer( + out.shape, out.dtype, env.acc_scope, + scope=env.acc_scope, offset_factor=out_lanes, data_alignment=out_lanes) + + return wgt_layout, inp_layout, out_layout + + +def inject_conv2d_transpose_skip(stmt_in): + """Pass to skip 0-weights in conv2d transpose with stride > 1. + + Parameters + ---------- + stmt_in : Stmt + Input statement + + Returns + ------- + stmt_out : Stmt + Transformed statement + """ + env = get_env() + dwgt, dinp, dout = _get_gemm_intrin_buffer() + + calls = [] + selects = [] + + def _find_basics(op): + if isinstance(op, tvm.expr.Call): + calls.append(op) + elif isinstance(op, tvm.expr.Select): + selects.append(op) + + def _do_fold(op): + if _match_pragma(op, "conv2d_transpose_gemm"): + is_init = ".init" in str(op) + tvm.ir_pass.PostOrderVisit(op, _find_basics) + + if is_init: + # create inner most block + irb = tvm.ir_builder.create() + dev = env.dev + irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE)) + irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop) + irb.emit(tvm.call_extern("int32", "VTAUopPush", + 0, 1, + dout.access_ptr("rw", "int32"), + 0, 0, + 0, 0, 0)) + inner = irb.get() + args = op.body.body.args + res_tensor = op.body.body.func.output(0) + tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT) + inner = tvm.make.AttrStmt( + [dout, res_tensor], 'buffer_bind_scope', + tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner) + return inner + else: + conv_call, data_call, kernel_call = calls[-3:] + pad_data_tensor = data_call.func.output(0) + kernel_tensor = kernel_call.func.output(0) + res_tensor = conv_call.func.output(0) + + if selects: + condition = selects[0].condition + else: + condition = tvm.const(1, 'int') + + # create inner most block + irb = tvm.ir_builder.create() + with irb.if_scope(condition): + dev = env.dev + irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE)) + irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop) + irb.emit(tvm.call_extern("int32", "VTAUopPush", + 0, 0, + dout.access_ptr("rw", "int32"), + dinp.access_ptr("r", "int32"), + dwgt.access_ptr("r", "int32"), + 0, 0, 0)) + inner = irb.get() + + args = conv_call.args + tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], + 1, 0, 1, 0, env.BLOCK_OUT) + inner = tvm.make.AttrStmt( + [dout, res_tensor], 'buffer_bind_scope', + tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner) + args = kernel_call.args + tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], + 1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN) + inner = tvm.make.AttrStmt( + [dwgt, kernel_tensor], 'buffer_bind_scope', + tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner) + args = data_call.args + tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], + 1, 0, 1, 0, env.BLOCK_IN) + inner = tvm.make.AttrStmt( + [dinp, pad_data_tensor], 'buffer_bind_scope', + tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner) + return inner + return None + ret = tvm.ir_pass.IRTransform( + stmt_in, _do_fold, None, ["AttrStmt"]) + return ret + + def annotate_alu_coproc_scope(stmt_in): """Pass to insert ALU instruction. diff --git a/vta/python/vta/top/__init__.py b/vta/python/vta/top/__init__.py index 3b5132e..01a101a 100644 --- a/vta/python/vta/top/__init__.py +++ b/vta/python/vta/top/__init__.py @@ -4,7 +4,9 @@ from . import bitpack from .graphpack import graph_pack from . import op from . 
import vta_conv2d +from . import vta_conv2d_transpose from . import vta_dense +from . import util # NNVM is deprecated for VTA # from . import nnvm_bitpack diff --git a/vta/python/vta/top/op.py b/vta/python/vta/top/op.py index 96eaa8f..a02f62b 100644 --- a/vta/python/vta/top/op.py +++ b/vta/python/vta/top/op.py @@ -25,12 +25,14 @@ from tvm.relay.op import op as reg from tvm.relay.op.op import OpPattern from tvm.relay.op.nn import _nn -from .vta_conv2d import is_packed_layout +from .util import is_packed_layout from ..environment import get_env + # override to force partition at copy reg.register_pattern("copy", OpPattern.INJECTIVE, level=15) + @reg.register_compute("clip", level=15) def compute_clip(attrs, inputs, output_type, target): """ Clip operator. """ @@ -110,6 +112,48 @@ def schedule_conv2d(attrs, outs, target): return _nn.schedule_conv2d(attrs, outs, target) +@reg.register_compute("nn.conv2d_transpose", level=15) +def compute_conv2d_transpose(attrs, inputs, output_type, target): + """ 2D transposed convolution compute. + """ + padding = topi.util.get_const_tuple(attrs.padding) + strides = topi.util.get_const_tuple(attrs.strides) + dilation = tuple([int(d) for d in attrs.dilation]) + layout = attrs.data_layout + out_dtype = attrs.out_dtype + + if target.device_name == "vta": + assert dilation == (1, 1), "support for dilation limited to (1, 1)" + if is_packed_layout(layout): + return [topi.nn.conv2d_transpose_nchw( + inputs[0], inputs[1], strides, padding, out_dtype)] + else: + # If it's not packed, run on ARM CPU + with tvm.target.arm_cpu(tvm.target.current_target().model): + return _nn.compute_conv2d_transpose(attrs, inputs, output_type, target) + + # If VTA is not the target, default to _nn def + return _nn.compute_conv2d_transpose(attrs, inputs, output_type, target) + + +@reg.register_schedule("nn.conv2d_transpose", level=15) +def schedule_conv2d_transpose(attrs, outputs, target): + """ 2D transposed convolution schedule. + """ + layout = attrs.data_layout + + if target.device_name == "vta": + if is_packed_layout(layout): + return topi.generic.schedule_conv2d_transpose_nchw(outputs) + else: + # If it's not packed, run on ARM CPU + with tvm.target.arm_cpu(tvm.target.current_target().model): + return _nn.schedule_conv2d_transpose(attrs, outputs, tvm.target.current_target()) + + # If VTA is not the target, default to _nn def + return _nn.schedule_conv2d_transpose(attrs, outputs, tvm.target.current_target()) + + @reg.register_compute("nn.dense", level=15) def compute_dense(attrs, inputs, out_type, target): """Compute definition of dense""" diff --git a/vta/python/vta/top/util.py b/vta/python/vta/top/util.py new file mode 100644 index 0000000..0fbdb2f --- /dev/null +++ b/vta/python/vta/top/util.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
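The inject_conv2d_transpose_skip pass added to ir_pass.py above exists because a strided transposed convolution is lowered as a zero-dilated input followed by an ordinary convolution, so with stride 2 most multiply-accumulates read a structurally zero operand and can be skipped. A standalone numpy sketch of that dilation step (illustrative shapes only, not part of the patch):

import numpy as np

def dilate_hw(x, stride_h, stride_w):
    """Insert (stride - 1) zeros between elements along H and W of an NCHW tensor."""
    n, c, h, w = x.shape
    out = np.zeros((n, c, (h - 1) * stride_h + 1, (w - 1) * stride_w + 1), x.dtype)
    out[:, :, ::stride_h, ::stride_w] = x
    return out

x = np.random.randint(-128, 128, size=(1, 16, 4, 4)).astype("int8")
d = dilate_hw(x, 2, 2)
print(d.shape)                             # (1, 16, 7, 7): only 16 of 49 positions per map hold data
print(1.0 - np.count_nonzero(d) / d.size)  # >= 0.67 of the dilated entries are zero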
+"""VTA TOPI Utils.""" + +def is_packed_layout(layout): + """Check if layout is packed layout""" + if layout == "NCHW": + return False + if "n" in layout and "c" in layout: + return True + return False diff --git a/vta/python/vta/top/vta_conv2d.py b/vta/python/vta/top/vta_conv2d.py index c455f53..e15f6c1 100644 --- a/vta/python/vta/top/vta_conv2d.py +++ b/vta/python/vta/top/vta_conv2d.py @@ -17,20 +17,14 @@ """Conv2D operator declaration and schedule registration for VTA.""" import numpy as np + import tvm from tvm import autotvm import topi +from .util import is_packed_layout from ..environment import get_env -def is_packed_layout(layout): - """Check if layout is packed layout""" - if layout == "NCHW": - return False - if "n" in layout and "c" in layout: - return True - return False - @autotvm.register_topi_compute(topi.nn.conv2d, 'vta', 'direct') def _declaration_conv2d(cfg, data, diff --git a/vta/python/vta/top/vta_conv2d_transpose.py b/vta/python/vta/top/vta_conv2d_transpose.py new file mode 100644 index 0000000..a2750dc --- /dev/null +++ b/vta/python/vta/top/vta_conv2d_transpose.py @@ -0,0 +1,195 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Conv2D_transpose operator declaration and schedule registration for VTA.""" + +import numpy as np + +import tvm +from tvm import autotvm +import topi +from topi.util import get_const_tuple +from topi.nn.util import get_pad_tuple + +from ..environment import get_env + +@autotvm.register_topi_compute(topi.nn.conv2d_transpose_nchw, 'vta', 'direct') +def _declatation_conv2d_transpose(cfg, + data, + kernel, + strides, + padding, + out_dtype): + ishape = get_const_tuple(data.shape) + kshape = get_const_tuple(kernel.shape) + b, c_i, i_h, i_w, t_b, t_ci = ishape + c_o, _, k_h, k_w, t_co, t_ci = kshape + stride_h, stride_w = strides + + # derive padding parameters + fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (k_h, k_w)) + bpad_top = k_h - 1 - fpad_top + bpad_bottom = k_h - 1 - fpad_bottom + bpad_left = k_w - 1 - fpad_left + bpad_right = k_w - 1 - fpad_right + + # padding stage + dilated_input = topi.nn.dilate(data, [1, 1, stride_h, stride_w, 1, 1]) + data_pad = topi.nn.pad(dilated_input, + [0, 0, bpad_top, bpad_left, 0, 0], + [0, 0, bpad_bottom, bpad_right, 0, 0]) + + # convolution transpose stage + out_h = (i_h - 1) * stride_h - fpad_top - fpad_bottom + k_h + out_w = (i_w - 1) * stride_w - fpad_left - fpad_right + k_w + oshape = (b, c_o, out_h, out_w, t_b, t_co) + d_c = tvm.reduce_axis((0, c_i), name='d_c') + d_h = tvm.reduce_axis((0, k_h), name='d_h') + d_w = tvm.reduce_axis((0, k_w), name='d_w') + d_ci = tvm.reduce_axis((0, t_ci), name='d_ci') + + out = tvm.compute( + oshape, + lambda i_n, i_c, i_h, i_w, j_n, j_c: tvm.sum( + data_pad(i_n, d_c, i_h + d_h, i_w + d_w, j_n, d_ci).astype(out_dtype) * + kernel[i_c, d_c, d_h, d_w, j_c, d_ci].astype(out_dtype), + axis=[d_c, d_h, d_w, d_ci]), + tag="packed_conv2d_transpose", + name='res') + + cfg.add_flop(2 * np.prod(topi.util.get_const_tuple(oshape)) * + kshape[2] * kshape[3] * ishape[1] * ishape[-1]) + + return out + +@autotvm.register_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw, 'vta', 'direct') +def _schedule_conv2d_transpose(cfg, outs): + assert len(outs) == 1 + output = outs[0] + ewise_inputs = [] + ewise_ops = [] + conv2d_res = [] + assert output.dtype == "int8" + assert output.op.input_tensors[0].dtype == "int32" + + def _traverse(op): + if topi.tag.is_broadcast(op.tag): + if not op.same_as(output.op): + ewise_ops.append(op) + for tensor in op.input_tensors: + if isinstance(tensor.op, tvm.tensor.PlaceholderOp): + ewise_inputs.append((op, tensor)) + else: + _traverse(tensor.op) + else: + assert op.tag == "packed_conv2d_transpose" + conv2d_res.append(op) + + _traverse(output.op) + assert len(conv2d_res) == 1 + conv2d_stage = conv2d_res[0].output(0) + s = tvm.create_schedule(output.op) + + ##### space definition begin ##### + b, c_o, x_i, x_j, _, c_i = s[conv2d_stage].op.axis + c_i, _, _, _ = s[conv2d_stage].op.reduce_axis + cfg.define_split('tile_b', b, num_outputs=2) + cfg.define_split('tile_h', x_i, num_outputs=2) + cfg.define_split('tile_w', x_j, num_outputs=2) + cfg.define_split('tile_ci', c_i, num_outputs=2) + cfg.define_split('tile_co', c_o, num_outputs=2) + cfg.define_knob('oc_nthread', [1, 2]) + cfg.define_knob('h_nthread', [1, 2]) + ###### space definition end ###### + + data, kernel = conv2d_stage.op.input_tensors + if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: + temp = data.op.input_tensors[0] + pad_data = data + data = temp + else: + pad_data = None + + env = get_env() + + # setup pad + if pad_data is not None: + cdata = pad_data + s[pad_data].set_scope(env.inp_scope) 
+ else: + cdata = s.cache_read(data, env.inp_scope, [conv2d_stage]) + ckernel = s.cache_read(kernel, env.wgt_scope, [conv2d_stage]) + s[conv2d_stage].set_scope(env.acc_scope) + + # cache read input + cache_read_ewise = [] + for consumer, tensor in ewise_inputs: + cache_read_ewise.append( + s.cache_read(tensor, env.acc_scope, [consumer])) + # set ewise scope + for op in ewise_ops: + s[op].set_scope(env.acc_scope) + s[op].pragma(s[op].op.axis[0], env.alu) + + # tile + x_bo, x_co, x_i, x_j, x_bi, x_ci = s[output].op.axis + x_co0, x_co1 = cfg['tile_co'].apply(s, output, x_co) + x_i0, x_i1 = cfg['tile_h'].apply(s, output, x_i) + x_j0, x_j1 = cfg['tile_w'].apply(s, output, x_j) + s[output].reorder(x_bo, x_i0, x_co0, x_j0, x_co1, x_i1, x_j1, x_bi, x_ci) + store_pt = x_j0 + + # set all compute scopes + s[conv2d_stage].compute_at(s[output], store_pt) + for op in ewise_ops: + s[op].compute_at(s[output], store_pt) + + for tensor in cache_read_ewise: + s[tensor].compute_at(s[output], store_pt) + s[tensor].pragma(s[tensor].op.axis[0], env.dma_copy) + + # virtual threading along output channel axes + if cfg['oc_nthread'].val > 1: + _, v_t = s[output].split(x_co0, factor=cfg['oc_nthread'].val) + s[output].reorder(v_t, x_bo) + s[output].bind(v_t, tvm.thread_axis("cthread")) + + # virtual threading along spatial rows + if cfg['h_nthread'].val > 1: + _, v_t = s[output].split(x_i0, factor=cfg['h_nthread'].val) + s[output].reorder(v_t, x_bo) + s[output].bind(v_t, tvm.thread_axis("cthread")) + + x_bo, x_co, x_i, x_j, x_bi, x_ci = s[conv2d_stage].op.axis + k_o, d_i, d_j, k_i = s[conv2d_stage].op.reduce_axis + x_i, x_ii = s[conv2d_stage].split(x_i, 4) + x_j, x_jj = s[conv2d_stage].split(x_j, 2) + s[conv2d_stage].reorder(x_bo, k_o, x_j, x_co, x_i, x_jj, d_j, d_i, x_ii, x_bi, x_ci, k_i) + + for axis in [d_j, d_i, x_ii, x_jj]: + s[conv2d_stage].unroll(axis) + + k_o, _ = cfg['tile_ci'].apply(s, conv2d_stage, k_o) + s[cdata].compute_at(s[conv2d_stage], k_o) + s[ckernel].compute_at(s[conv2d_stage], k_o) + + # Use VTA instructions + s[cdata].pragma(s[cdata].op.axis[0], env.dma_copy) + s[ckernel].pragma(s[ckernel].op.axis[0], env.dma_copy) + s[conv2d_stage].pragma(x_bi, "conv2d_transpose_gemm") + s[output].pragma(x_co1, env.dma_copy) + + return s diff --git a/vta/scripts/tune_conv2d.py b/vta/scripts/tune_conv2d.py index 296ce99..87f7909 100644 --- a/vta/scripts/tune_conv2d.py +++ b/vta/scripts/tune_conv2d.py @@ -58,22 +58,28 @@ def my_clip(x, a_min, a_max): x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") return x -def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation, in_dtype, out_dtype): +def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation): data_shape = (N//env.BATCH, CI//env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN) kernel_shape = (CO//env.BLOCK_OUT, CI//env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN) bias_shape = (N//env.BATCH, CO//env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT) data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype) - bias = tvm.placeholder(bias_shape, name="bias", dtype=env.acc_dtype) kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype) + bias = tvm.placeholder(bias_shape, name="bias", dtype=env.acc_dtype) with tvm.target.vta(): - res = topi.nn.conv2d(data, kernel, padding=padding, strides=strides, dilation=dilation, - layout='NCHW%dn%dc' % (env.BATCH, env.BLOCK_IN), out_dtype='int32') + res = topi.nn.conv2d( + input=data, + filter=kernel, + padding=padding, + strides=strides, + dilation=dilation, + 
layout='NCHW%dn%dc' % (env.BATCH, env.BLOCK_IN), + out_dtype=env.acc_dtype) + res = topi.right_shift(res, env.WGT_WIDTH) res = topi.add(res, bias) - res = topi.right_shift(res, 8) - res = my_clip(res, 0, 127) - res = topi.cast(res, "int8") + res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1) + res = topi.cast(res, env.out_dtype) if tvm.target.current_target().device_name == 'vta': s = topi.generic.schedule_conv2d_nchw([res]) @@ -103,10 +109,9 @@ if __name__ == '__main__': exit() for idx, (wl_name, wl) in enumerate(resnet_wkls): - prefix = "[Task %2d/%2d] " % (idx, len(resnet_wkls)) - # Workload parameters + # Read in workload parameters N = wl.batch CI = wl.in_filter H = wl.height @@ -117,11 +122,14 @@ if __name__ == '__main__': strides = (wl.hstride, wl.wstride) padding = (wl.hpad, wl.wpad) dilation = (1, 1) - in_dtype = 'int8' - out_dtype = 'int32' - task = autotvm.task.create(conv2d, args=(N, CI, H, W, CO, KH, KW, strides, padding, dilation, in_dtype, out_dtype), - target=tvm.target.vta(), target_host=env.target_host, template_key='direct') + # Create task + task = autotvm.task.create( + conv2d, + args=(N, CI, H, W, CO, KH, KW, strides, padding, dilation), + target=tvm.target.vta(), + target_host=env.target_host, + template_key='direct') print(task.config_space) # Tune diff --git a/vta/scripts/tune_conv2d_transpose.py b/vta/scripts/tune_conv2d_transpose.py new file mode 100644 index 0000000..3e51d41 --- /dev/null +++ b/vta/scripts/tune_conv2d_transpose.py @@ -0,0 +1,142 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
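The conv2d_transpose compute declaration above turns the forward padding into a backward padding on the dilated input (bpad = kernel - 1 - fpad) and produces the usual transposed-convolution output size. A quick arithmetic check of those formulas on the DCGAN.CT1 workload tuned by this script (4x4 input, 4x4 kernel, stride 2, pad 1; not part of the patch):

i_h = i_w = 4
k_h = k_w = 4
stride_h = stride_w = 2
fpad = 1                                   # symmetric padding

bpad = k_h - 1 - fpad                      # padding applied to the dilated input
print(bpad)                                # 2
out_h = (i_h - 1) * stride_h - 2 * fpad + k_h
out_w = (i_w - 1) * stride_w - 2 * fpad + k_w
print(out_h, out_w)                        # 8 8: each DCGAN stage doubles H and W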
+ +"""Tuning a single conv2d transpose operator""" + +from collections import namedtuple +import logging +import os + +import tvm +from tvm import autotvm +from tvm.contrib.util import get_lower_ir +import topi +import vta +import vta.testing + +# Get batch info from env +env = vta.get_env() + +Workload = namedtuple("Conv2DTransposeWorkload", + ['batch', 'height', 'width', 'in_filter', 'out_filter', + 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) + +dcgan_wkls = [ + # dcgan + ('DCGAN.CT1', Workload(env.BATCH, 4, 4, 1024, 512, 4, 4, 1, 1, 2, 2)), + ('DCGAN.CT2', Workload(env.BATCH, 8, 8, 512, 256, 4, 4, 1, 1, 2, 2)), + ('DCGAN.CT3', Workload(env.BATCH, 16, 16, 256, 128, 4, 4, 1, 1, 2, 2)), +] + +@tvm.tag_scope(tag=topi.tag.ELEMWISE) +def my_clip(x, a_min, a_max): + """Unlike topi's current clip, put min and max into two stages.""" + const_min = tvm.const(a_min, x.dtype) + const_max = tvm.const(a_max, x.dtype) + x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") + x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") + return x + +def conv2d_transpose(N, CI, H, W, CO, KH, KW, strides, padding): + data_shape = (N//env.BATCH, CI//env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN) + kernel_shape = (CO//env.BLOCK_OUT, CI//env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN) + + data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype) + kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype) + + with tvm.target.vta(): + res = topi.nn.conv2d_transpose_nchw( + Input=data, + Filter=kernel, + strides=strides, + padding=padding, + out_dtype=env.acc_dtype) + res = topi.right_shift(res, env.WGT_WIDTH) + res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1) + res = topi.cast(res, env.out_dtype) + + if tvm.target.current_target().device_name == 'vta': + s = topi.generic.schedule_conv2d_transpose_nchw([res]) + else: + s = tvm.create_schedule([res.op]) + + return s, [data, kernel, res] + +if __name__ == '__main__': + + # Logging config (for printing tuning log to the screen) + logging.basicConfig() + # logging.getLogger('autotvm').setLevel(logging.DEBUG) + + # Tuning log files + log_file = "%s.conv2d_transpose.log" % (env.TARGET) + # create tmp log file + tmp_log_file = log_file + ".tmp" + if os.path.exists(log_file): + os.remove(log_file) + + # Get tracker info from env + tracker_host = os.environ.get("TVM_TRACKER_HOST", None) + tracker_port = os.environ.get("TVM_TRACKER_PORT", None) + if not tracker_host or not tracker_port: + print("Set your AutoTVM tracker node host and port variables to run the autotuner") + exit() + + for idx, (wl_name, wl) in enumerate(dcgan_wkls): + prefix = "[Task %2d/%2d] " % (idx, len(dcgan_wkls)) + + # Read in workload parameters + N = wl.batch + H = wl.height + W = wl.width + CI = wl.in_filter + CO = wl.out_filter + KH = wl.hkernel + KW = wl.wkernel + strides = (wl.hstride, wl.wstride) + padding = (wl.hpad, wl.wpad) + + # Create task + task = autotvm.task.create( + conv2d_transpose, + args=(N, CI, H, W, CO, KH, KW, strides, padding), + target=tvm.target.vta(), + target_host=env.target_host, + template_key='direct') + print(task.config_space) + + # Tune + measure_option = autotvm.measure_option( + builder=autotvm.LocalBuilder(), + runner=autotvm.RPCRunner( + env.TARGET, host=tracker_host, port=int(tracker_port), + number=5, timeout=60, + check_correctness=True)) + + # Run Tuner + tuner = autotvm.tuner.RandomTuner(task) + tuner.tune( + n_trial=len(task.config_space), + early_stopping=None, + 
measure_option=measure_option, + callbacks=[ + autotvm.callback.progress_bar(len(task.config_space), prefix=prefix), + autotvm.callback.log_to_file(tmp_log_file)]) + + # Pick best records to a cache file + autotvm.record.pick_best(tmp_log_file, log_file) + os.remove(tmp_log_file) diff --git a/vta/tests/python/integration/test_benchmark_topi_conv2d.py b/vta/tests/python/integration/test_benchmark_topi_conv2d.py index 96922c3..942776f 100644 --- a/vta/tests/python/integration/test_benchmark_topi_conv2d.py +++ b/vta/tests/python/integration/test_benchmark_topi_conv2d.py @@ -17,11 +17,11 @@ """Testing topi conv2d operator for VTA""" -import os import json -from collections import namedtuple +import os import numpy as np +from collections import namedtuple import tvm from tvm import autotvm @@ -34,6 +34,7 @@ from vta import program_fpga, reconfig_runtime import vta.testing from vta.testing import simulator + Workload = namedtuple("Conv2DWorkload", ['batch', 'height', 'width', 'in_filter', 'out_filter', 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) @@ -88,7 +89,7 @@ def run_conv2d(env, remote, wl, target, b_shape = (wl.batch, wl.out_filter, 1, 1) if data_pack: data_shape = (wl.batch//env.BATCH, wl.in_filter//env.BLOCK_IN, - wl.height, wl.width, env.BATCH, env.BLOCK_IN) + wl.height, wl.width, env.BATCH, env.BLOCK_IN) kernel_shape = (wl.out_filter//env.BLOCK_OUT, wl.in_filter//env.BLOCK_IN, wl.hkernel, wl.wkernel, env.BLOCK_OUT, env.BLOCK_IN) bias_shape = (wl.batch//env.BATCH, wl.out_filter//env.BLOCK_OUT, @@ -205,7 +206,7 @@ def run_conv2d(env, remote, wl, target, (0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, fout_height, fout_width) bias_np = bias_np.transpose( (0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, 1, 1) - res_ref = res_ref >> 8 + res_ref = res_ref >> env.WGT_WIDTH res_ref += bias_np res_ref = np.clip(res_ref, 0, (1 << env.OUT_WIDTH - 1) - 1) res_ref = res_ref.astype(env.out_dtype) diff --git a/vta/tests/python/integration/test_benchmark_topi_conv2d_transpose.py b/vta/tests/python/integration/test_benchmark_topi_conv2d_transpose.py new file mode 100644 index 0000000..e2601d1 --- /dev/null +++ b/vta/tests/python/integration/test_benchmark_topi_conv2d_transpose.py @@ -0,0 +1,231 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
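Both tuning scripts and the tests wrap the convolution in the same fixed-point requantization: shift the int32 accumulator right by WGT_WIDTH, clip to the non-negative part of the output range, and cast down. A standalone numpy sketch of that chain (WGT_WIDTH=8 and OUT_WIDTH=8 are assumed default VTA widths; not part of the patch):

import numpy as np

WGT_WIDTH, OUT_WIDTH = 8, 8                        # assumed defaults
acc = np.array([-70000, -3, 0, 4096, 70000], dtype="int32")

res = acc >> WGT_WIDTH                             # topi.right_shift(res, env.WGT_WIDTH)
res = np.clip(res, 0, (1 << (OUT_WIDTH - 1)) - 1)  # my_clip(res, 0, 127)
res = res.astype("int8")                           # topi.cast(res, env.out_dtype)
print(res)                                         # [  0   0   0  16 127]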
+ +"""Testing topi conv2d_transpose operator for VTA""" + +import json +import os + +import numpy as np +from collections import namedtuple + +import tvm +from tvm import autotvm +from tvm.contrib import util +from tvm.contrib.pickle_memoize import memoize +import topi +import topi.testing +import vta +from vta import program_fpga, reconfig_runtime +import vta.testing +from vta.testing import simulator + + +Workload = namedtuple("Conv2DTransposeWorkload", + ['batch', 'height', 'width', 'in_filter', 'out_filter', + 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) + +# Get batch info from env +env = vta.get_env() + +# DCGAN workloads +dcgan_wklds = [ + # dcgan + ('DCGAN.CT1', Workload(env.BATCH, 4, 4, 1024, 512, 4, 4, 1, 1, 2, 2)), + ('DCGAN.CT2', Workload(env.BATCH, 8, 8, 512, 256, 4, 4, 1, 1, 2, 2)), + ('DCGAN.CT3', Workload(env.BATCH, 16, 16, 256, 128, 4, 4, 1, 1, 2, 2)), +] + +# FIXME: we need a custom clip operator to circumvent a pattern detection limitation +@tvm.tag_scope(tag=topi.tag.ELEMWISE) +def my_clip(x, a_min, a_max): + """Unlike topi's current clip, put min and max into two stages.""" + const_min = tvm.const(a_min, x.dtype) + const_max = tvm.const(a_max, x.dtype) + x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") + x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") + return x + +# Helper function to get factors +def _find_factors(n): + factors = [] + for f in range(1, n + 1): + if n % f == 0: + factors.append(f) + return factors + + +def run_conv2d_transpose(env, remote, wl, target, + check_correctness=True, print_ir=False, + samples=4): + + # Workload assertions + assert wl.hpad == wl.wpad + + # Perform packing only if we are targeting the accelerator + if "arm_cpu" in target.keys: + data_pack = False + layout = "NCHW" + elif "vta" in target.keys: + data_pack = True + layout = "NCHW%dn%dc" % (env.BATCH, env.BLOCK_IN) + + # Derive shapes depending upon packing + + a_shape = (wl.batch, wl.in_filter, wl.height, wl.width) + w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel) + if data_pack: + data_shape = (wl.batch//env.BATCH, wl.in_filter//env.BLOCK_IN, + wl.height, wl.width, env.BATCH, env.BLOCK_IN) + kernel_shape = (wl.out_filter//env.BLOCK_OUT, wl.in_filter//env.BLOCK_IN, + wl.hkernel, wl.wkernel, env.BLOCK_OUT, env.BLOCK_IN) + else: + data_shape = a_shape + kernel_shape = w_shape + data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype) + kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype) + + # Define base computation schedule + with target: + res = topi.nn.conv2d_transpose_nchw( + data, kernel, (wl.hstride, wl.wstride), (wl.hpad, wl.wpad), env.acc_dtype) + res = topi.right_shift(res, env.WGT_WIDTH) + res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1) + res = topi.cast(res, env.out_dtype) + # Derive base schedule + s = topi.generic.schedule_conv2d_transpose_nchw([res]) + if print_ir: + print(vta.lower(s, [data, kernel, res], simple_mode=True)) + + # Derive number of ops + fout_height = (wl.height - 1) * wl.hstride - 2 * wl.hpad + wl.hkernel + fout_width = (wl.width - 1) * wl.wstride - 2 * wl.wpad + wl.wkernel + num_ops = 2 * wl.batch * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter + + # @memoize("vta.tests.test_benchmark_topi.conv2d.verify_nhwc") + def get_ref_data(): + # derive min max for act and wgt types (max non inclusive) + a_min, a_max = 0 - (1 << (env.INP_WIDTH - 1)), (1 << (env.INP_WIDTH - 1)) + w_min, w_max = 0 - 
(1 << (env.WGT_WIDTH - 1)), (1 << (env.WGT_WIDTH - 1)) + a_np = np.random.randint(a_min, a_max, size=a_shape).astype(data.dtype) + w_np = np.random.randint(w_min, w_max, size=(wl.in_filter, wl.out_filter, wl.hkernel, wl.wkernel)).astype(kernel.dtype) + r_np = topi.testing.conv2d_transpose_nchw_python( + a_np.astype(env.acc_dtype), w_np.astype(env.acc_dtype), (wl.hstride, wl.wstride), wl.hpad).astype(env.acc_dtype) + return a_np, w_np, r_np + + # Data in original format + data_np, kernel_np, res_ref = get_ref_data() + if data_pack: + data_np = data_np.reshape( + wl.batch//env.BATCH, env.BATCH, + wl.in_filter//env.BLOCK_IN, env.BLOCK_IN, + wl.height, wl.width).transpose((0, 2, 4, 5, 1, 3)) + kernel_np = kernel_np.reshape( + wl.in_filter//env.BLOCK_IN, env.BLOCK_IN, + wl.out_filter//env.BLOCK_OUT, env.BLOCK_OUT, + wl.hkernel, wl.wkernel).transpose((2, 0, 4, 5, 3, 1)) + kernel_np = np.flip(kernel_np, 2) + kernel_np = np.flip(kernel_np, 3) + + # Build + if "vta" in target.keys: + mod = vta.build(s, [data, kernel, res], + target=target, + target_host=env.target_host, + name="conv2d_transpose") + else: + mod = tvm.build(s, [data, kernel, res], + target=target, + target_host=env.target_host, + name="conv2d_transpose") + temp = util.tempdir() + mod.save(temp.relpath("conv2d_transpose.o")) + remote.upload(temp.relpath("conv2d_transpose.o")) + f = remote.load_module("conv2d_transpose.o") + ctx = remote.context(str(target)) + + res_np = np.zeros(topi.util.get_const_tuple(res.shape)).astype(res.dtype) + data_arr = tvm.nd.array(data_np, ctx) + kernel_arr = tvm.nd.array(kernel_np, ctx) + res_arr = tvm.nd.array(res_np, ctx) + time_f = f.time_evaluator("conv2d_transpose", ctx, number=samples) + + # In vta sim mode, collect simulator runtime statistics + stats = {} + cost = None + if env.TARGET in ["sim", "tsim"]: + # Check if we're in local RPC mode (allows us to rebuild the + # runtime on the fly when varying the VTA designs) + local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0")) + if local_rpc: + if env.TARGET == "sim": + remote.get_function("vta.simulator.profiler_clear")() + else: + remote.get_function("vta.tsim.profiler_clear")() + cost = time_f(data_arr, kernel_arr, res_arr) + if env.TARGET == "sim": + stats = json.loads(remote.get_function("vta.simulator.profiler_status")()) + else: + stats = json.loads(remote.get_function("vta.tsim.profiler_status")()) + else: + simulator.clear_stats() + cost = time_f(data_arr, kernel_arr, res_arr) + stats = simulator.stats() + else: + cost = time_f(data_arr, kernel_arr, res_arr) + + # Check correctness + correct = False + if check_correctness: + res_orig = res_arr.asnumpy() + if data_pack: + res_orig = res_orig.transpose( + (0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, fout_height, fout_width) + res_ref = res_ref >> env.WGT_WIDTH + res_ref = np.clip(res_ref, 0, (1 << env.OUT_WIDTH - 1) - 1) + res_ref = res_ref.astype(env.out_dtype) + correct = np.allclose(res_orig, res_ref) + + gops = (num_ops / cost.mean) / float(10 ** 9) + status = "PASSED" if correct else "FAILED" + if "arm_cpu" in target.keys: + device = "CPU" + elif "vta" in target.keys: + device = "VTA" + print("%s CONV2D TEST %s: Time cost = %g sec/op, %g GOPS" % (device, status, cost.mean, gops)) + + return correct, cost, stats + +def test_conv2d_transpose(device="vta"): + def _run(env, remote): + if device == "vta": + target = env.target + if env.TARGET not in ["sim", "tsim"]: + assert tvm.module.enabled("rpc") + program_fpga(remote, bitstream=None) + reconfig_runtime(remote) + elif device == 
"arm_cpu": + target = env.target_vta_cpu + with autotvm.tophub.context(target): # load pre-tuned schedule parameters + for _, wl in dcgan_wklds: + print(wl) + run_conv2d_transpose(env, remote, wl, target) + vta.testing.run(_run) + +if __name__ == "__main__": + # test_conv2d_transpose(device="arm_cpu") + test_conv2d_transpose(device="vta") -- 2.7.4