    return s
+
+
+def schedule_group_conv2d_nchw_direct(cfg, s, conv):
+    """Schedule group conv2d NCHW direct template"""
+    workload = conv.op.attrs["workload"]
+    groups = get_const_int(workload[6])
+    num_filters = get_const_int(conv.shape[1])
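+    # workload[6] holds the group count in the serialized conv2d workload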
+
+    ##### space definition begin #####
+    n, f, y, x = s[conv].op.axis
+    rc, ry, rx = s[conv].op.reduce_axis
+    cfg.define_split("tile_n", n, num_outputs=4)
+    cfg.define_split("tile_g", cfg.axis(groups), num_outputs=2)
+    cfg.define_split("tile_f", cfg.axis(num_filters // groups), num_outputs=4)
+    cfg.define_split("tile_y", y, num_outputs=4)
+    cfg.define_split("tile_x", x, num_outputs=4)
+    cfg.define_split("tile_rc", rc, num_outputs=2)
+    cfg.define_split("tile_ry", ry, num_outputs=2)
+    cfg.define_split("tile_rx", rx, num_outputs=2)
+    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
+
+    target = tvm.target.current_target()
+    if target.target_name in ['nvptx', 'rocm']:
+        cfg.define_knob("unroll_explicit", [1])
+    else:
+        cfg.define_knob("unroll_explicit", [0, 1])
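+    ##### space definition end #####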
+
+    pad_data, kernel = s[conv].op.input_tensors
+
+    s[pad_data].compute_inline()
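+    # inlining the padding stage folds it into the shared memory load of the data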
+
+    if conv.op in s.outputs:
+        output = conv
+        OL = s.cache_write(conv, 'local')
+    else:
+        output = s.outputs[0].output(0)
+        s[conv].set_scope('local')
+        OL = conv
+
+    # create cache stage
+    AA = s.cache_read(pad_data, 'shared', [OL])
+    WW = s.cache_read(kernel, 'shared', [OL])
+
+    # tile and bind spatial axes
+    n, f, y, x = s[output].op.axis
+    kernel_scope, n = s[output].split(n, nparts=1)
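+    # kernel_scope is a dummy outer axis used to attach the unroll pragmas below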
+
+    g, f = s[output].split(f, nparts=groups)
+    bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
+    bg, vg = cfg["tile_g"].apply(s, output, g)
+    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
+    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
+    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
+
+    s[output].reorder(bn, bg, bf, by, bx, vn, vg, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
+    s[output].bind(bn, tvm.thread_axis("blockIdx.z"))
+    s[output].bind(s[output].fuse(bg, bf), tvm.thread_axis("blockIdx.y"))
+    s[output].bind(s[output].fuse(by, bx), tvm.thread_axis("blockIdx.x"))
+    s[output].bind(vn, tvm.thread_axis("vthread"))
+    s[output].bind(vg, tvm.thread_axis("vthread"))
+    s[output].bind(vf, tvm.thread_axis("vthread"))
+    s[output].bind(vy, tvm.thread_axis("vthread"))
+    s[output].bind(vx, tvm.thread_axis("vthread"))
+
+ cfg.define_knob("fuse_yx", [0, 1]) # fuse ty,tx or tn,tf
+ if cfg["fuse_yx"].val:
+ s[output].bind(tn, tvm.thread_axis("threadIdx.z"))
+ s[output].bind(tf, tvm.thread_axis("threadIdx.y"))
+ tyx = s[output].fuse(ty, tx)
+ s[output].bind(tyx, tvm.thread_axis("threadIdx.x"))
+ s[OL].compute_at(s[output], tyx)
+
+ # number of threads
+ n_tz = cfg["tile_n"].size[2]
+ n_ty = cfg["tile_f"].size[2]
+ n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
+    else:
+        s[output].bind(s[output].fuse(tn, tf), tvm.thread_axis("threadIdx.z"))
+        s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
+        s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
+        s[OL].compute_at(s[output], tx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
+        n_ty = cfg["tile_y"].size[2]
+        n_tx = cfg["tile_x"].size[2]
+
+    # tile reduction axes
+    n, f, y, x = s[OL].op.axis
+    rc, ry, rx = s[OL].op.reduce_axis
+    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
+    ryo, ryi = cfg['tile_ry'].apply(s, OL, ry)
+    rxo, rxi = cfg['tile_rx'].apply(s, OL, rx)
+    s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x)
+
+    s[AA].compute_at(s[OL], rxo)
+    s[WW].compute_at(s[OL], rxo)
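+    # each outer reduction step reloads one tile of data and kernel into shared memory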
+
+    # cooperative fetching
+    for load in [AA, WW]:
+        n, f, y, x = s[load].op.axis
+        fused = s[load].fuse(n, f, y, x)
+        fused, tx = s[load].split(fused, factor=n_tx)
+        fused, ty = s[load].split(fused, factor=n_ty)
+        fused, tz = s[load].split(fused, factor=n_tz)
+        s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
+        s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
+        s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
+
+    # unroll
+    s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
+    s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+
+    N, CO, OH, OW = get_const_tuple(output.shape)
+    _, CI_div_groups, KH, KW = get_const_tuple(kernel.shape)
+    cfg.add_flop(2 * N * OH * OW * CO * CI_div_groups * KH * KW)
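+    # register the FLOP count so autotvm can report GFLOPS for each measured config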
+
+
@autotvm.register_topi_schedule(generic.schedule_group_conv2d_nchw,
- ["cuda", "gpu"], ["int8"])
+ ["cuda", "gpu"], ["int8", "direct"])
def schedule_conv2d_nchw_cuda(cfg, outs):
"""TOPI schedule callback of group conv2d for cuda gpu
if op.tag == "group_conv2d_NCHWc_int8":
schedule_group_conv2d_NCHWc_int8(cfg, s, op.output(0))
if op.tag == "group_conv2d_nchw":
- raise tvm.error.OpNotImplemented("group_conv2d_nchw not supported")
+ schedule_group_conv2d_nchw_direct(cfg, s, op.output(0))
    traverse_inline(s, outs[0].op, _callback)
    return s