[TOPI][CUDA] schedule for group_conv2d (#3663)
authorWuwei Lin <vincentl13x@gmail.com>
Wed, 31 Jul 2019 08:26:05 +0000 (16:26 +0800)
committermasahi <masahi129@gmail.com>
Wed, 31 Jul 2019 08:26:05 +0000 (17:26 +0900)
* [TOPI][CUDA] schedule for group_conv2d

* Fix #flops

topi/python/topi/cuda/group_conv2d_nchw.py
topi/tests/python/test_topi_group_conv2d.py

index cbdb3cb..cd8c823 100644 (file)
@@ -321,8 +321,124 @@ def schedule_group_conv2d_NCHWc_int8(cfg, s, output):
     return s
 
 
+def schedule_group_conv2d_nchw_direct(cfg, s, conv):
+    """Schedule group conv2d NCHW direct template"""
+    workload = conv.op.attrs["workload"]
+    groups = get_const_int(workload[6])
+    num_filters = get_const_int(conv.shape[1])
+
+    ##### space definition begin #####
+    n, f, y, x = s[conv].op.axis
+    rc, ry, rx = s[conv].op.reduce_axis
+    cfg.define_split("tile_n", n, num_outputs=4)
+    cfg.define_split("tile_g", cfg.axis(groups), num_outputs=2)
+    cfg.define_split("tile_f", cfg.axis(num_filters // groups), num_outputs=4)
+    cfg.define_split("tile_y", y, num_outputs=4)
+    cfg.define_split("tile_x", x, num_outputs=4)
+    cfg.define_split("tile_rc", rc, num_outputs=2)
+    cfg.define_split("tile_ry", ry, num_outputs=2)
+    cfg.define_split("tile_rx", rx, num_outputs=2)
+    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
+
+    target = tvm.target.current_target()
+    if target.target_name in ['nvptx', 'rocm']:
+        cfg.define_knob("unroll_explicit", [1])
+    else:
+        cfg.define_knob("unroll_explicit", [0, 1])
+
+    pad_data, kernel = s[conv].op.input_tensors
+
+    s[pad_data].compute_inline()
+
+    if conv.op in s.outputs:
+        output = conv
+        OL = s.cache_write(conv, 'local')
+    else:
+        output = s.outputs[0].output(0)
+        s[conv].set_scope('local')
+        OL = conv
+
+    # create cache stage
+    AA = s.cache_read(pad_data, 'shared', [OL])
+    WW = s.cache_read(kernel, 'shared', [OL])
+
+    # tile and bind spatial axes
+    n, f, y, x = s[output].op.axis
+    kernel_scope, n = s[output].split(n, nparts=1)
+
+    g, f = s[output].split(f, nparts=groups)
+    bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
+    bg, vg = cfg["tile_g"].apply(s, output, g)
+    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
+    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
+    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
+
+    s[output].reorder(bn, bg, bf, by, bx, vn, vg, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
+    s[output].bind(bn, tvm.thread_axis("blockIdx.z"))
+    s[output].bind(s[output].fuse(bg, bf), tvm.thread_axis("blockIdx.y"))
+    s[output].bind(s[output].fuse(by, bx), tvm.thread_axis("blockIdx.x"))
+    s[output].bind(vn, tvm.thread_axis("vthread"))
+    s[output].bind(vg, tvm.thread_axis("vthread"))
+    s[output].bind(vf, tvm.thread_axis("vthread"))
+    s[output].bind(vy, tvm.thread_axis("vthread"))
+    s[output].bind(vx, tvm.thread_axis("vthread"))
+
+    cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
+    if cfg["fuse_yx"].val:
+        s[output].bind(tn, tvm.thread_axis("threadIdx.z"))
+        s[output].bind(tf, tvm.thread_axis("threadIdx.y"))
+        tyx = s[output].fuse(ty, tx)
+        s[output].bind(tyx, tvm.thread_axis("threadIdx.x"))
+        s[OL].compute_at(s[output], tyx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2]
+        n_ty = cfg["tile_f"].size[2]
+        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
+    else:
+        s[output].bind(s[output].fuse(tn, tf), tvm.thread_axis("threadIdx.z"))
+        s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
+        s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
+        s[OL].compute_at(s[output], tx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
+        n_ty = cfg["tile_y"].size[2]
+        n_tx = cfg["tile_x"].size[2]
+
+    # tile reduction axes
+    n, f, y, x = s[OL].op.axis
+    rc, ry, rx = s[OL].op.reduce_axis
+    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
+    ryo, ryi = cfg['tile_rx'].apply(s, OL, ry)
+    rxo, rxi = cfg['tile_ry'].apply(s, OL, rx)
+    s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x)
+
+    s[AA].compute_at(s[OL], rxo)
+    s[WW].compute_at(s[OL], rxo)
+
+    # cooperative fetching
+    for load in [AA, WW]:
+        n, f, y, x = s[load].op.axis
+        fused = s[load].fuse(n, f, y, x)
+        fused, tx = s[load].split(fused, factor=n_tx)
+        fused, ty = s[load].split(fused, factor=n_ty)
+        fused, tz = s[load].split(fused, factor=n_tz)
+        s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
+        s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
+        s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
+
+    # unroll
+    s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
+    s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+
+    N, CO, OH, OW = get_const_tuple(output.shape)
+    _, CI_div_groups, KH, KW = get_const_tuple(kernel.shape)
+    cfg.add_flop(2 * N * OH * OW * CO * CI_div_groups * KH * KW)
+
+
 @autotvm.register_topi_schedule(generic.schedule_group_conv2d_nchw,
-                                ["cuda", "gpu"], ["int8"])
+                                ["cuda", "gpu"], ["int8", "direct"])
 def schedule_conv2d_nchw_cuda(cfg, outs):
     """TOPI schedule callback of group conv2d for cuda gpu
 
@@ -347,7 +463,7 @@ def schedule_conv2d_nchw_cuda(cfg, outs):
         if op.tag == "group_conv2d_NCHWc_int8":
             schedule_group_conv2d_NCHWc_int8(cfg, s, op.output(0))
         if op.tag == "group_conv2d_nchw":
-            raise tvm.error.OpNotImplemented("group_conv2d_nchw not supported")
+            schedule_group_conv2d_nchw_direct(cfg, s, op.output(0))
 
     traverse_inline(s, outs[0].op, _callback)
     return s
index e809999..0e17678 100644 (file)
@@ -67,9 +67,6 @@ def verify_group_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, str
         if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
-        if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version):
-            print("Skip because int8 intrinsics are not available")
-            return
 
         print("Running on target: %s" % device)
         with tvm.target.create(device):
@@ -94,7 +91,7 @@ def verify_group_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, str
             func(a, w, c)
         tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
 
-    for device in ["llvm"]:
+    for device in ["llvm", "cuda"]:
         check_device(device)