From: Neo Chien Date: Sun, 31 May 2020 19:32:25 +0000 (+0800) Subject: [AutoTVM][TOPI] Fix bifrost spatial packing conv2d auto tune (#5684) X-Git-Tag: upstream/0.7.0~632 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=12cfe4ae47c44da43542fc889b927d1a2efc4801;p=platform%2Fupstream%2Ftvm.git [AutoTVM][TOPI] Fix bifrost spatial packing conv2d auto tune (#5684) * [AutoTVM][TOPI] Fix bifrost spatial packing conv2d auto tune * [AutoTVM][TOPI] Putting placeholder replacement in compute * Fix winograd kernel replacement * Fix sanity check: Line too long --- diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py index df63ae3..4faee42 100644 --- a/topi/python/topi/arm_cpu/conv2d.py +++ b/topi/python/topi/arm_cpu/conv2d.py @@ -167,15 +167,20 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, til idxm(b*VP + bb, nW) * m + nu], name='d') - # transform kernel - if pre_computed: - U = kernel + if autotvm.GLOBAL_SCOPE.in_tuning: + VC = cfg['tile_k'].size[-1] + kvshape = (KH + tile_size - 1, KW + tile_size - 1, idxd(CO, VC), CI, VC) + U = tvm.te.placeholder(kvshape, kernel.dtype, name="U") else: - r_kh = te.reduce_axis((0, KH), 'r_kh') - r_kw = te.reduce_axis((0, KW), 'r_kw') - U = te.compute((alpha, alpha, idxd(K, VK), C, VK), lambda eps, nu, k, c, kk: - te.sum(kernel[k * VK + kk][c][r_kh][r_kw].astype(out_dtype) * - G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]), name='U') + # transform kernel + if pre_computed: + U = kernel + else: + r_kh = te.reduce_axis((0, KH), 'r_kh') + r_kw = te.reduce_axis((0, KW), 'r_kw') + U = te.compute((alpha, alpha, idxd(K, VK), C, VK), lambda eps, nu, k, c, kk: + te.sum(kernel[k * VK + kk][c][r_kh][r_kw].astype(out_dtype) * + G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]), name='U') # transform image r_eps = te.reduce_axis((0, alpha), 'r_eps') diff --git a/topi/python/topi/arm_cpu/conv2d_spatial_pack.py b/topi/python/topi/arm_cpu/conv2d_spatial_pack.py index a4d7ad8..8cf8401 100644 --- a/topi/python/topi/arm_cpu/conv2d_spatial_pack.py +++ b/topi/python/topi/arm_cpu/conv2d_spatial_pack.py @@ -109,12 +109,15 @@ def conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, dilation, data_pad[n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw], name='data_vec') - if pre_packed: - kernel_vec = kernel + if autotvm.GLOBAL_SCOPE.in_tuning: + kernel_vec = tvm.te.placeholder(kvshape, kernel.dtype, name="kernel") else: - kernel_vec = te.compute(kvshape, lambda co, ci, kh, kw, vc: - kernel[co*VC+vc][ci][kh][kw], - name='kernel_vec') + if pre_packed: + kernel_vec = kernel + else: + kernel_vec = te.compute(kvshape, lambda co, ci, kh, kw, vc: + kernel[co*VC+vc][ci][kh][kw], + name='kernel_vec') ci = te.reduce_axis((0, CI), name='ci') kh = te.reduce_axis((0, KH), name='kh') @@ -187,12 +190,8 @@ def schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec, s[data_vec].parallel(h) if kernel_vec.op.name == 'kernel_vec': - co, _, _, _, _ = s[kernel_vec].op.axis - if autotvm.GLOBAL_SCOPE.in_tuning: - # kernel packing will be pre-computed during compilation, so we skip - # this part to make tuning records correct - s[kernel_vec].pragma(co, 'debug_skip_region') - else: + if not autotvm.GLOBAL_SCOPE.in_tuning: + co, _, _, _, _ = s[kernel_vec].op.axis s[kernel_vec].parallel(co) elif kernel_vec.op.name == 'kernel_vec_conv2d_transpose': # for conv2d transpose co, _, _, _, _ = s[kernel_vec].op.axis @@ -267,9 +266,13 @@ def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_ data_vec = te.compute(dvshape, lambda n, oho, owo, ohi, owi, ic: data_pad[n][oho*OHI*HSTR+ohi][owo*OWI*WSTR+owi][ic], name='data_vec') - kernel_vec = te.compute(kvshape, lambda oco, kh, kw, ic, oci: \ - kernel[kh][kw][ic][oco*OCI+oci], - name='kernel_vec') + + if autotvm.GLOBAL_SCOPE.in_tuning: + kernel_vec = tvm.te.placeholder(kvshape, kernel.dtype, name="kernel") + else: + kernel_vec = te.compute(kvshape, lambda oco, kh, kw, ic, oci: \ + kernel[kh][kw][ic][oco*OCI+oci], + name='kernel_vec') ic = te.reduce_axis((0, IC), name='ic') kh = te.reduce_axis((0, KH), name='kh') @@ -339,12 +342,13 @@ def schedule_conv2d_spatial_pack_nhwc(cfg, s, op, output): s[kernel_vec].compute_at(s[conv], compat_axis) s[data_vec].compute_at(s[conv], compat_axis) - # schedule kernel pack - oco, kh, kw, ic, oci = kernel_vec.op.axis - s[kernel_vec].vectorize(oci) - s[kernel_vec].unroll(ic) - if cfg['compat'].val == 2: - s[kernel_vec].parallel(oco) + if not autotvm.GLOBAL_SCOPE.in_tuning: + # schedule kernel pack + oco, kh, kw, ic, oci = kernel_vec.op.axis + s[kernel_vec].vectorize(oci) + s[kernel_vec].unroll(ic) + if cfg['compat'].val == 2: + s[kernel_vec].parallel(oco) # schedule data pack if data_vec.op.name == 'data_vec_undilated': diff --git a/topi/python/topi/bifrost/conv2d.py b/topi/python/topi/bifrost/conv2d.py index 92e874a..ecc67c7 100644 --- a/topi/python/topi/bifrost/conv2d.py +++ b/topi/python/topi/bifrost/conv2d.py @@ -142,11 +142,7 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec): s[data_vec].unroll(vw) if isinstance(kernel_vec.op, tvm.te.ComputeOp) and kernel_vec.name == 'kernel_vec': - if autotvm.GLOBAL_SCOPE.in_tuning: - # kernel packing will be pre-computed during compilation, so we skip - # this part to make tuning records correct - s[kernel_vec].pragma(s[kernel_vec].op.axis[0], 'debug_skip_region') - else: + if not autotvm.GLOBAL_SCOPE.in_tuning: max_threads = tvm.target.Target.current(allow_none=False).max_num_threads co, ci, kh, kw, vc = s[kernel_vec].op.axis fused = s[kernel_vec].fuse(co, ci, kh, kw, vc) @@ -313,10 +309,15 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, til data_pad[n][c][h][w], name='d') - if pre_computed: - U = kernel + if autotvm.GLOBAL_SCOPE.in_tuning: + VC = cfg['tile_k'].size[-1] + kvshape = (KH + tile_size - 1, KW + tile_size - 1, tvm.tir.indexdiv(CO, VC), CI, VC) + U = tvm.te.placeholder(kvshape, kernel.dtype, name="U") else: - U = _decl_winograd_kernel_transform(kernel, tile_size, G) + if pre_computed: + U = kernel + else: + U = _decl_winograd_kernel_transform(kernel, tile_size, G) # V [alpha * alpha, C, P_round) # Perform the image transform @@ -370,12 +371,7 @@ def _schedule_winograd(cfg, s, op): s[G].compute_inline() eps, _, _, _ = s[U].op.axis y, _, _, _ = s[padded_kernel].op.axis - if autotvm.GLOBAL_SCOPE.in_tuning: - # Kernel transformation will be pre-computed during compilation, so we skip - # this part to make tuning records correct - s[U].pragma(eps, 'debug_skip_region') - s[padded_kernel].pragma(y, 'debug_skip_region') - else: + if not autotvm.GLOBAL_SCOPE.in_tuning: # Pad kernel y, x, ky, kx = s[padded_kernel].op.axis s[padded_kernel].unroll(ky) diff --git a/topi/python/topi/mali/conv2d.py b/topi/python/topi/mali/conv2d.py index 12eb3d7..ed19326 100644 --- a/topi/python/topi/mali/conv2d.py +++ b/topi/python/topi/mali/conv2d.py @@ -138,14 +138,9 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec): s[data_vec].unroll(vw) if isinstance(kernel_vec.op, tvm.te.ComputeOp) and kernel_vec.name == 'kernel_vec': - co, ci, kh, kw, vc = s[kernel_vec].op.axis - if autotvm.GLOBAL_SCOPE.in_tuning: - # Directly use modified data layout placeholder. - kvshape = (co // vc, ci, kh, kw, vc) - kernel_vec = tvm.te.placeholder(kvshape, kernel_vec.dtype, name="kernel") - s[kernel_vec] = kernel_vec - else: + if not autotvm.GLOBAL_SCOPE.in_tuning: max_threads = tvm.target.Target.current(allow_none=False).max_num_threads + co, ci, kh, kw, vc = s[kernel_vec].op.axis fused = s[kernel_vec].fuse(co, ci, kh, kw, vc) fused, vec = s[kernel_vec].split(fused, VC) bb, tt = s[kernel_vec].split(fused, max_threads) @@ -280,15 +275,21 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, til data_pad[(b*bnb+bb) // (nH*nW)][ci][(b*bnb+bb) // nW % nH * m + eps] [(b*bnb+bb) % nW * m + nu], tvm.tir.const(0, data_pad.dtype)), name='d') - # transform kernel - if pre_computed: - U = kernel + if autotvm.GLOBAL_SCOPE.in_tuning: + VC = cfg['tile_k'].size[-1] + kvshape = (KH + tile_size - 1, KW + tile_size - 1, tvm.tir.indexdiv(CO, VC), CI, VC) + U = tvm.te.placeholder(kvshape, kernel.dtype, name="U") else: - r_kh = te.reduce_axis((0, KH), 'r_kh') - r_kw = te.reduce_axis((0, KW), 'r_kw') - U = te.compute((alpha, alpha, CO // bna, CI, bna), lambda eps, nu, co, ci, vco: - te.sum(kernel[co * bna + vco][ci][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], - axis=[r_kh, r_kw]), name='U') + # transform kernel + if pre_computed: + U = kernel + else: + r_kh = te.reduce_axis((0, KH), 'r_kh') + r_kw = te.reduce_axis((0, KW), 'r_kw') + U = te.compute((alpha, alpha, CO // bna, CI, bna), lambda eps, nu, co, ci, vco: + te.sum(kernel[co * bna + vco][ci][r_kh][r_kw] * + G[eps][r_kh] * G[nu][r_kw], + axis=[r_kh, r_kw]), name='U') # transform image r_a = te.reduce_axis((0, alpha), 'r_a')