elif x is None:
workload = 0
else:
- raise RuntimeError('Do not support type "%s" in argument. Consider to use'
+ raise RuntimeError('Do not support type "%s" in argument. Consider to use '
'primitive types only' % type(x))
return (get_func_name(topi_compute_func), ) + workload if topi_compute_func else workload
args = deserialize_args(args)
A, W = args[:2]
layout = args[-2]
- assert layout == 'NCHW', "only support NCHW currently"
+ assert layout == 'NCHW' or layout == 'HWCN', "only support NCHW/HWCN currently"
C = topi.nn.conv2d(*args, **kwargs)
- s = topi.generic.schedule_conv2d_nchw([C])
+ if layout == 'NCHW':
+ s = topi.generic.schedule_conv2d_nchw([C])
+ else:
+ s = topi.generic.schedule_conv2d_hwcn([C])
return s, [A, W, C]
@register("topi_nn_depthwise_conv2d_nchw")
out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
else out_dtype)
- assert layout in ["NCHW", "NHWC", "NCHW4c"]
+ assert layout in ["NCHW", "NHWC", "NCHW4c", "HWCN"]
(dilation_h, dilation_w) = dilation
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")
def _get_out_depth():
weight_shape = get_const_tuple(inputs[1].shape)
- if kernel_layout == "HWOI":
+ if kernel_layout.startswith("HW"):
return weight_shape[2] * weight_shape[3]
return weight_shape[0] * weight_shape[1]
with target:
if groups == 1 and layout == "NCHW":
return topi.generic.schedule_conv2d_nchw(outs)
- if groups == 1 and layout == "NCHW4c":
+ elif groups == 1 and layout == "NCHW4c":
return topi.generic.schedule_conv2d_nchw(outs)
- if groups == 1 and layout == "NHWC":
+ elif groups == 1 and layout == "NHWC":
return topi.generic.schedule_conv2d_nhwc(outs)
- if groups != 1:
+ elif groups == 1 and layout == "HWCN":
+ return topi.generic.schedule_conv2d_hwcn(outs)
+ elif groups != 1:
# collect in_channels to distinguish depthwise and group conv2d
op = _find_conv2d_op(outs[0].op)
assert op is not None
CHECK(!op->extent.type().is_vector());
Expr extent = Mutate(op->extent);
if (extent.type().is_vector()) {
- LOG(WARNING) << "Detect vectorized extent type, scalarizing...";
return Scalarize(s);
}
Stmt body = Mutate(op->body);
CHECK(!op->condition.type().is_vector());
Expr condition = this->Mutate(op->condition);
if (condition.type().is_vector()) {
- LOG(WARNING) << "Detect vector condition in Vectorized Loop, scalarizing...";
return Scalarize(s);
}
Stmt then_case = this->Mutate(op->then_case);
# pylint: disable=invalid-name, too-many-locals, too-many-statements
"""Schedule for conv2d_hwcn with auto fusion"""
import tvm
-from .. import tag
+from tvm import autotvm
+from tvm.autotvm.task.space import SplitEntity
-def schedule_conv2d_hwcn(outs):
+from .. import generic, tag
+
+
+@autotvm.register_topi_schedule(generic.schedule_conv2d_hwcn, ["cuda", "gpu"], ["direct"])
+def schedule_conv2d_hwcn(cfg, outs):
"""Schedule for conv2d_hwcn and any element-wise operations.
Parameters
sch[B].set_scope("local")
BL = B
- tile = 8
- num_thread = 8
- block_factor = tile * num_thread
- step = 8
- vthread = 2
+ hi, wi, fi, ni = sch[Out].op.axis
- block_x = tvm.thread_axis("blockIdx.x")
- block_y = tvm.thread_axis("blockIdx.y")
- block_z = tvm.thread_axis("blockIdx.z")
- thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
- thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
- thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
- thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")
+ # Create tuning space
+ n_thread_cand = [1, 2, 4, 8, 16, 32]
+ vthread_cand = [1, 2, 4, 8]
+
+ cfg.define_split(
+ 'tile_fi',
+ fi,
+ num_outputs=4,
+ filter=lambda x:
+ (x.size[1] in vthread_cand and x.size[2] in n_thread_cand))
+ cfg.define_split(
+ 'tile_ni',
+ ni,
+ num_outputs=4,
+ filter=lambda x:
+ (x.size[1] in vthread_cand and x.size[2] in n_thread_cand))
+
+ if cfg.is_fallback:
+ cfg['tile_fi'] = SplitEntity([-1, 2, 8, 4])
+ cfg['tile_ni'] = SplitEntity([-1, 2, 8, 4])
+
+ # Scheduling
+ step = 8
- hi, wi, fi, ni = sch[Out].op.axis
bz = sch[Out].fuse(hi, wi)
- by, fi = sch[Out].split(fi, factor=block_factor)
- bx, ni = sch[Out].split(ni, factor=block_factor)
- tyz, fi = sch[Out].split(fi, nparts=vthread)
- txz, ni = sch[Out].split(ni, nparts=vthread)
- ty, fi = sch[Out].split(fi, nparts=num_thread)
- tx, ni = sch[Out].split(ni, nparts=num_thread)
+ by, tyz, ty, fi = cfg['tile_fi'].apply(sch, Out, fi)
+ bx, txz, tx, ni = cfg['tile_ni'].apply(sch, Out, ni)
sch[Out].reorder(bz, by, bx, tyz, txz, ty, tx, fi, ni)
- sch[Out].bind(bz, block_z)
- sch[Out].bind(by, block_y)
- sch[Out].bind(bx, block_x)
- sch[Out].bind(tyz, thread_yz)
- sch[Out].bind(txz, thread_xz)
- sch[Out].bind(ty, thread_y)
- sch[Out].bind(tx, thread_x)
+
+ sch[Out].bind(bz, tvm.thread_axis('blockIdx.z'))
+ sch[Out].bind(by, tvm.thread_axis('blockIdx.y'))
+ sch[Out].bind(bx, tvm.thread_axis('blockIdx.x'))
+ sch[Out].bind(tyz, tvm.thread_axis('vthread'))
+ sch[Out].bind(txz, tvm.thread_axis('vthread'))
+ sch[Out].bind(ty, tvm.thread_axis('threadIdx.y'))
+ sch[Out].bind(tx, tvm.thread_axis('threadIdx.x'))
# Schedule BL local write
sch[BL].compute_at(sch[Out], tx)
sch[WL].compute_at(sch[BL], rci)
# Schedule for A's shared memory load
yi, xi, ci, ni = sch[AA].op.axis
- ty, ci = sch[AA].split(ci, nparts=num_thread)
- tx, ni = sch[AA].split(ni, nparts=num_thread)
+ ty, ci = sch[AA].split(ci, nparts=cfg['tile_fi'].size[2])
+ tx, ni = sch[AA].split(ni, nparts=cfg['tile_ni'].size[2])
_, ni = sch[AA].split(ni, factor=4)
sch[AA].reorder(ty, tx, yi, xi, ci, ni)
- sch[AA].bind(ty, thread_y)
- sch[AA].bind(tx, thread_x)
+ sch[AA].bind(ty, tvm.thread_axis('threadIdx.y'))
+ sch[AA].bind(tx, tvm.thread_axis('threadIdx.x'))
sch[AA].vectorize(ni)
# Schedule for W's shared memory load
yi, xi, ci, fi = sch[WW].op.axis
- ty, ci = sch[WW].split(ci, nparts=num_thread)
- tx, fi = sch[WW].split(fi, nparts=num_thread)
+ ty, ci = sch[WW].split(ci, nparts=cfg['tile_fi'].size[2])
+ tx, fi = sch[WW].split(fi, nparts=cfg['tile_ni'].size[2])
_, fi = sch[WW].split(fi, factor=4)
sch[WW].reorder(ty, tx, yi, xi, ci, fi)
- sch[WW].bind(ty, thread_y)
- sch[WW].bind(tx, thread_x)
+ sch[WW].bind(ty, tvm.thread_axis('threadIdx.y'))
+ sch[WW].bind(tx, tvm.thread_axis('threadIdx.x'))
sch[WW].vectorize(fi)
scheduled_ops = []
@tvm.target.generic_func
+def schedule_conv2d_hwcn(outs):
+ """Schedule for conv2d_hwcn
+
+ This generic definition only delegates to
+ ``_default_schedule(outs, False)``; it performs no scheduling of its own.
+ NOTE(review): the meaning of the ``False`` argument is defined by
+ ``_default_schedule`` and is not visible here -- confirm before changing.
+ NOTE(review): target-specific schedules are presumably dispatched through
+ this function via the ``generic_func`` decorator -- confirm.
+
+ Parameters
+ ----------
+ outs: Array of Tensor
+ The computation graph description of conv2d_hwcn
+ in the format of an array of tensors.
+
+ Returns
+ -------
+ sch: Schedule
+ The computation schedule for the op.
+ """
+ return _default_schedule(outs, False)
+
+
+@tvm.target.generic_func
def schedule_conv2d_nchw(outs):
"""Schedule for conv2d_nchw
# default declaration
if layout == 'NCHW':
return conv2d_nchw(input, filter, strides, padding, dilation, out_dtype)
- if layout == 'HWCN':
+ elif layout == 'HWCN':
return conv2d_hwcn(input, filter, strides, padding, dilation, out_dtype)
- if layout == 'NHWC':
+ elif layout == 'NHWC':
return conv2d_nhwc(input, filter, strides, padding, dilation, out_dtype)
raise ValueError("not support this layout {} yet".format(layout))
A = tvm.placeholder((in_height, in_width, in_channel, batch), name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
- B = topi.nn.conv2d_hwcn(A, W, stride, padding, dilation)
- C = topi.nn.relu(B)
- s1 = topi.cuda.schedule_conv2d_hwcn([B])
- s2 = topi.cuda.schedule_conv2d_hwcn([C])
+ B = tvm.placeholder((1, num_filter, 1), name='bias')
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
+ b_shape = get_const_tuple(B.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_conv2d_hwcn.verify_hwcn")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
+ b_np = np.random.uniform(size=b_shape).astype(dtype)
dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
- b_np = topi.testing.conv2d_hwcn_python(a_np, dw_np, stride, padding)
- c_np = np.maximum(b_np, 0)
- return a_np, w_np, b_np, c_np
- a_np, w_np, b_np, c_np = get_ref_data()
+ c1_np = topi.testing.conv2d_hwcn_python(a_np, dw_np, stride, padding)
+ c2_np = c1_np + b_np
+ c3_np = np.maximum(c2_np, 0)
+ return a_np, w_np, b_np, c1_np, c2_np, c3_np
+
+ a_np, w_np, b_np, c1_np, c2_np, c3_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
+ with tvm.target.create(device):
+ t_conv = topi.nn.conv2d(A, W, stride, padding, dilation, layout='HWCN')
+ t_bias = topi.add(t_conv, B)
+ t_relu = topi.nn.relu(t_bias)
+ s1 = topi.generic.schedule_conv2d_hwcn([t_conv])
+ s2 = topi.generic.schedule_conv2d_hwcn([t_bias])
+ s3 = topi.generic.schedule_conv2d_hwcn([t_relu])
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
- b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
- c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
- func1 = tvm.build(s1, [A, W, B], device)
- func2 = tvm.build(s2, [A, W, C], device)
- func1(a, w, b)
- func2(a, w, c)
- tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
- tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
+ b = tvm.nd.array(b_np, ctx)
+
+ conv_out = tvm.nd.array(
+ np.zeros(get_const_tuple(t_conv.shape), dtype=t_conv.dtype), ctx)
+ bias_out = tvm.nd.array(
+ np.zeros(get_const_tuple(t_bias.shape), dtype=t_bias.dtype), ctx)
+ relu_out = tvm.nd.array(
+ np.zeros(get_const_tuple(t_relu.shape), dtype=t_relu.dtype), ctx)
+ func1 = tvm.build(s1, [A, W, t_conv], device)
+ func2 = tvm.build(s2, [A, W, B, t_bias], device)
+ func3 = tvm.build(s3, [A, W, B, t_relu], device)
+ func1(a, w, conv_out)
+ func2(a, w, b, bias_out)
+ func3(a, w, b, relu_out)
+ tvm.testing.assert_allclose(conv_out.asnumpy(), c1_np, rtol=1e-5)
+ tvm.testing.assert_allclose(bias_out.asnumpy(), c2_np, rtol=1e-5)
+ tvm.testing.assert_allclose(relu_out.asnumpy(), c3_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
check_device(device)
dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
if add_bias:
- b_np = np.random.uniform(size=bias_shape).astype(dtype)
c_np += b_np
if add_relu:
c_np = np.maximum(c_np, 0)