From 4ab73634c195fb7006e6279d9dbd71ecba33997b Mon Sep 17 00:00:00 2001 From: Cody Hao Yu Date: Thu, 24 Oct 2019 12:03:15 -0700 Subject: [PATCH] [TOPI] Tunable Template for Conv2D HWCN on CUDA (#4168) * support conv2d HWCN in AutoTVM and Relay * fix lint * fix comments and unit tests --- python/tvm/autotvm/task/task.py | 2 +- python/tvm/autotvm/task/topi_integration.py | 7 ++- python/tvm/relay/op/nn/_nn.py | 12 ++-- src/pass/vectorize_loop.cc | 2 - topi/python/topi/cuda/conv2d_hwcn.py | 85 +++++++++++++++++------------ topi/python/topi/generic/nn.py | 18 ++++++ topi/python/topi/nn/conv2d.py | 4 +- topi/tests/python/test_topi_conv2d_hwcn.py | 49 +++++++++++------ topi/tests/python/test_topi_conv2d_nchw.py | 1 - 9 files changed, 115 insertions(+), 65 deletions(-) diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index e0db275..4f3cc90 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -226,7 +226,7 @@ def args_to_workload(x, topi_compute_func=None): elif x is None: workload = 0 else: - raise RuntimeError('Do not support type "%s" in argument. Consider to use' + raise RuntimeError('Do not support type "%s" in argument. Consider to use ' 'primitive types only' % type(x)) return (get_func_name(topi_compute_func), ) + workload if topi_compute_func else workload diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py index 09f08ad..ac4683d 100644 --- a/python/tvm/autotvm/task/topi_integration.py +++ b/python/tvm/autotvm/task/topi_integration.py @@ -176,9 +176,12 @@ class TaskExtractEnv: args = deserialize_args(args) A, W = args[:2] layout = args[-2] - assert layout == 'NCHW', "only support NCHW currently" + assert layout == 'NCHW' or layout == 'HWCN', "only support NCHW/HWCN currently" C = topi.nn.conv2d(*args, **kwargs) - s = topi.generic.schedule_conv2d_nchw([C]) + if layout == 'NCHW': + s = topi.generic.schedule_conv2d_nchw([C]) + else: + s = topi.generic.schedule_conv2d_hwcn([C]) return s, [A, W, C] @register("topi_nn_depthwise_conv2d_nchw") diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index b857234..0043ffa 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -153,14 +153,14 @@ def compute_conv2d(attrs, inputs, out_type, target): out_dtype = (inputs[0].dtype if out_dtype in ("same", "") else out_dtype) - assert layout in ["NCHW", "NHWC", "NCHW4c"] + assert layout in ["NCHW", "NHWC", "NCHW4c", "HWCN"] (dilation_h, dilation_w) = dilation if dilation_h < 1 or dilation_w < 1: raise ValueError("dilation should be positive value") def _get_out_depth(): weight_shape = get_const_tuple(inputs[1].shape) - if kernel_layout == "HWOI": + if kernel_layout.startswith("HW"): return weight_shape[2] * weight_shape[3] return weight_shape[0] * weight_shape[1] @@ -192,11 +192,13 @@ def schedule_conv2d(attrs, outs, target): with target: if groups == 1 and layout == "NCHW": return topi.generic.schedule_conv2d_nchw(outs) - if groups == 1 and layout == "NCHW4c": + elif groups == 1 and layout == "NCHW4c": return topi.generic.schedule_conv2d_nchw(outs) - if groups == 1 and layout == "NHWC": + elif groups == 1 and layout == "NHWC": return topi.generic.schedule_conv2d_nhwc(outs) - if groups != 1: + elif groups == 1 and layout == "HWCN": + return topi.generic.schedule_conv2d_hwcn(outs) + elif groups != 1: # collect in_channels to distinguish depthwise and group conv2d op = _find_conv2d_op(outs[0].op) assert op is not None diff --git 
a/src/pass/vectorize_loop.cc b/src/pass/vectorize_loop.cc index 10db37d..1870330 100644 --- a/src/pass/vectorize_loop.cc +++ b/src/pass/vectorize_loop.cc @@ -368,7 +368,6 @@ class Vectorizer : public IRMutator { CHECK(!op->extent.type().is_vector()); Expr extent = Mutate(op->extent); if (extent.type().is_vector()) { - LOG(WARNING) << "Detect vectorized extent type, scalarizing..."; return Scalarize(s); } Stmt body = Mutate(op->body); @@ -386,7 +385,6 @@ class Vectorizer : public IRMutator { CHECK(!op->condition.type().is_vector()); Expr condition = this->Mutate(op->condition); if (condition.type().is_vector()) { - LOG(WARNING) << "Detect vector condition in Vectorized Loop, scalarizing..."; return Scalarize(s); } Stmt then_case = this->Mutate(op->then_case); diff --git a/topi/python/topi/cuda/conv2d_hwcn.py b/topi/python/topi/cuda/conv2d_hwcn.py index 5d101b9..18a624a 100644 --- a/topi/python/topi/cuda/conv2d_hwcn.py +++ b/topi/python/topi/cuda/conv2d_hwcn.py @@ -17,9 +17,14 @@ # pylint: disable=invalid-name, too-many-locals, too-many-statements """Schedule for conv2d_hwcn with auto fusion""" import tvm -from .. import tag +from tvm import autotvm +from tvm.autotvm.task.space import SplitEntity -def schedule_conv2d_hwcn(outs): +from .. import generic, tag + + +@autotvm.register_topi_schedule(generic.schedule_conv2d_hwcn, ["cuda", "gpu"], ["direct"]) +def schedule_conv2d_hwcn(cfg, outs): """Schedule for conv2d_hwcn and any element-wise operations. Parameters @@ -51,36 +56,44 @@ def schedule_conv2d_hwcn(outs): sch[B].set_scope("local") BL = B - tile = 8 - num_thread = 8 - block_factor = tile * num_thread - step = 8 - vthread = 2 + hi, wi, fi, ni = sch[Out].op.axis - block_x = tvm.thread_axis("blockIdx.x") - block_y = tvm.thread_axis("blockIdx.y") - block_z = tvm.thread_axis("blockIdx.z") - thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x") - thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y") - thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx") - thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy") + # Create tuning space + n_thread_cand = [1, 2, 4, 8, 16, 32] + vthread_cand = [1, 2, 4, 8] + + cfg.define_split( + 'tile_fi', + fi, + num_outputs=4, + filter=lambda x: + (x.size[1] in vthread_cand and x.size[2] in n_thread_cand)) + cfg.define_split( + 'tile_ni', + ni, + num_outputs=4, + filter=lambda x: + (x.size[1] in vthread_cand and x.size[2] in n_thread_cand)) + + if cfg.is_fallback: + cfg['tile_fi'] = SplitEntity([-1, 2, 8, 4]) + cfg['tile_ni'] = SplitEntity([-1, 2, 8, 4]) + + # Scheduling + step = 8 - hi, wi, fi, ni = sch[Out].op.axis bz = sch[Out].fuse(hi, wi) - by, fi = sch[Out].split(fi, factor=block_factor) - bx, ni = sch[Out].split(ni, factor=block_factor) - tyz, fi = sch[Out].split(fi, nparts=vthread) - txz, ni = sch[Out].split(ni, nparts=vthread) - ty, fi = sch[Out].split(fi, nparts=num_thread) - tx, ni = sch[Out].split(ni, nparts=num_thread) + by, tyz, ty, fi = cfg['tile_fi'].apply(sch, Out, fi) + bx, txz, tx, ni = cfg['tile_ni'].apply(sch, Out, ni) sch[Out].reorder(bz, by, bx, tyz, txz, ty, tx, fi, ni) - sch[Out].bind(bz, block_z) - sch[Out].bind(by, block_y) - sch[Out].bind(bx, block_x) - sch[Out].bind(tyz, thread_yz) - sch[Out].bind(txz, thread_xz) - sch[Out].bind(ty, thread_y) - sch[Out].bind(tx, thread_x) + + sch[Out].bind(bz, tvm.thread_axis('blockIdx.z')) + sch[Out].bind(by, tvm.thread_axis('blockIdx.y')) + sch[Out].bind(bx, tvm.thread_axis('blockIdx.x')) + sch[Out].bind(tyz, tvm.thread_axis('vthread')) + sch[Out].bind(txz, 
tvm.thread_axis('vthread')) + sch[Out].bind(ty, tvm.thread_axis('threadIdx.y')) + sch[Out].bind(tx, tvm.thread_axis('threadIdx.x')) # Schedule BL local write sch[BL].compute_at(sch[Out], tx) @@ -98,21 +111,21 @@ def schedule_conv2d_hwcn(outs): sch[WL].compute_at(sch[BL], rci) # Schedule for A's shared memory load yi, xi, ci, ni = sch[AA].op.axis - ty, ci = sch[AA].split(ci, nparts=num_thread) - tx, ni = sch[AA].split(ni, nparts=num_thread) + ty, ci = sch[AA].split(ci, nparts=cfg['tile_fi'].size[2]) + tx, ni = sch[AA].split(ni, nparts=cfg['tile_ni'].size[2]) _, ni = sch[AA].split(ni, factor=4) sch[AA].reorder(ty, tx, yi, xi, ci, ni) - sch[AA].bind(ty, thread_y) - sch[AA].bind(tx, thread_x) + sch[AA].bind(ty, tvm.thread_axis('threadIdx.y')) + sch[AA].bind(tx, tvm.thread_axis('threadIdx.x')) sch[AA].vectorize(ni) # Schedule for W's shared memory load yi, xi, ci, fi = sch[WW].op.axis - ty, ci = sch[WW].split(ci, nparts=num_thread) - tx, fi = sch[WW].split(fi, nparts=num_thread) + ty, ci = sch[WW].split(ci, nparts=cfg['tile_fi'].size[2]) + tx, fi = sch[WW].split(fi, nparts=cfg['tile_ni'].size[2]) _, fi = sch[WW].split(fi, factor=4) sch[WW].reorder(ty, tx, yi, xi, ci, fi) - sch[WW].bind(ty, thread_y) - sch[WW].bind(tx, thread_x) + sch[WW].bind(ty, tvm.thread_axis('threadIdx.y')) + sch[WW].bind(tx, tvm.thread_axis('threadIdx.x')) sch[WW].vectorize(fi) scheduled_ops = [] diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py index c2cb2b2..4043cb7 100644 --- a/topi/python/topi/generic/nn.py +++ b/topi/python/topi/generic/nn.py @@ -35,6 +35,24 @@ def _default_schedule(outs, auto_inline): @tvm.target.generic_func +def schedule_conv2d_hwcn(outs): + """Schedule for conv2d_hwcn + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of conv2d_hwcn + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. 
+ """ + return _default_schedule(outs, False) + + +@tvm.target.generic_func def schedule_conv2d_nchw(outs): """Schedule for conv2d_nchw diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index ffae4b2..130632f 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -64,9 +64,9 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N # default declaration if layout == 'NCHW': return conv2d_nchw(input, filter, strides, padding, dilation, out_dtype) - if layout == 'HWCN': + elif layout == 'HWCN': return conv2d_hwcn(input, filter, strides, padding, dilation, out_dtype) - if layout == 'NHWC': + elif layout == 'NHWC': return conv2d_nhwc(input, filter, strides, padding, dilation, out_dtype) raise ValueError("not support this layout {} yet".format(layout)) diff --git a/topi/tests/python/test_topi_conv2d_hwcn.py b/topi/tests/python/test_topi_conv2d_hwcn.py index 297df82..35423a6 100644 --- a/topi/tests/python/test_topi_conv2d_hwcn.py +++ b/topi/tests/python/test_topi_conv2d_hwcn.py @@ -29,24 +29,25 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p A = tvm.placeholder((in_height, in_width, in_channel, batch), name='A') W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W') - B = topi.nn.conv2d_hwcn(A, W, stride, padding, dilation) - C = topi.nn.relu(B) - s1 = topi.cuda.schedule_conv2d_hwcn([B]) - s2 = topi.cuda.schedule_conv2d_hwcn([C]) + B = tvm.placeholder((1, num_filter, 1), name='bias') a_shape = get_const_tuple(A.shape) w_shape = get_const_tuple(W.shape) + b_shape = get_const_tuple(B.shape) dtype = A.dtype @memoize("topi.tests.test_topi_conv2d_hwcn.verify_hwcn") def get_ref_data(): a_np = np.random.uniform(size=a_shape).astype(dtype) w_np = np.random.uniform(size=w_shape).astype(dtype) + b_np = np.random.uniform(size=b_shape).astype(dtype) dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) - b_np = topi.testing.conv2d_hwcn_python(a_np, dw_np, stride, padding) - c_np = np.maximum(b_np, 0) - return a_np, w_np, b_np, c_np - a_np, w_np, b_np, c_np = get_ref_data() + c1_np = topi.testing.conv2d_hwcn_python(a_np, dw_np, stride, padding) + c2_np = c1_np + b_np + c3_np = np.maximum(c2_np, 0) + return a_np, w_np, b_np, c1_np, c2_np, c3_np + + a_np, w_np, b_np, c1_np, c2_np, c3_np = get_ref_data() def check_device(device): ctx = tvm.context(device, 0) @@ -54,16 +55,32 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) + with tvm.target.create(device): + t_conv = topi.nn.conv2d(A, W, stride, padding, dilation, layout='HWCN') + t_bias = topi.add(t_conv, B) + t_relu = topi.nn.relu(t_bias) + s1 = topi.generic.schedule_conv2d_hwcn([t_conv]) + s2 = topi.generic.schedule_conv2d_hwcn([t_bias]) + s3 = topi.generic.schedule_conv2d_hwcn([t_relu]) a = tvm.nd.array(a_np, ctx) w = tvm.nd.array(w_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) - func1 = tvm.build(s1, [A, W, B], device) - func2 = tvm.build(s2, [A, W, C], device) - func1(a, w, b) - func2(a, w, c) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) + b = tvm.nd.array(b_np, ctx) + + conv_out = tvm.nd.array( + np.zeros(get_const_tuple(t_conv.shape), dtype=t_conv.dtype), ctx) + bias_out = tvm.nd.array( + 
np.zeros(get_const_tuple(t_bias.shape), dtype=t_bias.dtype), ctx) + relu_out = tvm.nd.array( + np.zeros(get_const_tuple(t_relu.shape), dtype=t_relu.dtype), ctx) + func1 = tvm.build(s1, [A, W, t_conv], device) + func2 = tvm.build(s2, [A, W, B, t_bias], device) + func3 = tvm.build(s3, [A, W, B, t_relu], device) + func1(a, w, conv_out) + func2(a, w, b, bias_out) + func3(a, w, b, relu_out) + tvm.testing.assert_allclose(conv_out.asnumpy(), c1_np, rtol=1e-5) + tvm.testing.assert_allclose(bias_out.asnumpy(), c2_np, rtol=1e-5) + tvm.testing.assert_allclose(relu_out.asnumpy(), c3_np, rtol=1e-5) for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']: check_device(device) diff --git a/topi/tests/python/test_topi_conv2d_nchw.py b/topi/tests/python/test_topi_conv2d_nchw.py index ca1cef2..d7c39a9 100644 --- a/topi/tests/python/test_topi_conv2d_nchw.py +++ b/topi/tests/python/test_topi_conv2d_nchw.py @@ -48,7 +48,6 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding) if add_bias: - b_np = np.random.uniform(size=bias_shape).astype(dtype) c_np += b_np if add_relu: c_np = np.maximum(c_np, 0) -- 2.7.4
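
Editor's note (not part of the patch): the hunks above register conv2d HWCN as an AutoTVM template on CUDA and route it through topi.generic.schedule_conv2d_hwcn, falling back to the SplitEntity defaults when no tuned config is found. The following is a minimal, illustrative sketch of exercising that path, mirroring the updated test_topi_conv2d_hwcn.py; the shapes and variable names are examples only, and a CUDA-enabled TVM build at this revision is assumed.

    # Illustrative sketch, assuming TVM/TOPI at this revision with CUDA enabled.
    import numpy as np
    import tvm
    import topi
    from topi.util import get_const_tuple

    # HWCN input and (KH, KW, C, F) weight, as in the test above. Example shapes.
    A = tvm.placeholder((224, 224, 3, 1), name='A')   # (H, W, C, N)
    W = tvm.placeholder((3, 3, 3, 64), name='W')      # (KH, KW, C, F)

    with tvm.target.create('cuda'):
        # Dispatches to the new tunable template; without a tuned log it uses
        # the fallback splits defined in schedule_conv2d_hwcn.
        C = topi.nn.conv2d(A, W, 1, 1, 1, layout='HWCN')
        s = topi.generic.schedule_conv2d_hwcn([C])

    func = tvm.build(s, [A, W, C], 'cuda')

    # Run once with random data on the GPU.
    ctx = tvm.gpu(0)
    a = tvm.nd.array(np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype), ctx)
    w = tvm.nd.array(np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype), ctx)
    c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
    func(a, w, c)

Tuned configurations for the new 'tile_fi'/'tile_ni' knobs can be produced with the usual AutoTVM tuning flow and applied via a dispatch context before scheduling.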