[TOPI] Add proper scheduling for dense on CUDA (#3923)
authorCody Hao Yu <comaniac0422@gmail.com>
Thu, 19 Sep 2019 21:23:20 +0000 (14:23 -0700)
committerYuwei Hu <huyuwei1995@gmail.com>
Thu, 19 Sep 2019 21:23:20 +0000 (14:23 -0700)
* add proper scheduling for dense on CUDA

* add fallback config and fix unit test

* fix corner cases

* refactoring

* fix bias and add testcase

* let fusion happen

topi/python/topi/cuda/dense.py
topi/tests/python/test_topi_dense.py

index 0cf33ec..1db56d1 100644 (file)
 # pylint: disable=invalid-name, unused-variable
 """Schedule for dense operator"""
 from __future__ import absolute_import as _abs
+import logging
 import tvm
 import tvm.autotvm as autotvm
+from tvm.autotvm.task.space import SplitEntity
 from tvm.contrib import cublas
 from .tensor_intrin import dp4a
 from ..nn.dense import dense, dense_default
@@ -26,6 +28,8 @@ from .. import tag
 from .. import generic
 from ..util import traverse_inline, get_const_tuple
 
+logger = logging.getLogger('topi')
+
 
 @autotvm.register_topi_compute(dense, ["cuda", "gpu"], "direct")
 def dense_cuda(cfg, data, weight, bias=None, out_dtype=None):
@@ -85,31 +89,23 @@ def schedule_dense(cfg, outs):
     """
     # pylint: disable=unused-argument
     target = tvm.target.current_target()
+
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     if target.target_name == "cuda" and "cublas" in target.libs:
+        A, B = outs[0].op.input_tensors
+        b, i = get_const_tuple(A.shape)
+        o, _ = get_const_tuple(B.shape)
+        cfg.add_flop(2 * i * b * o)
         return generic.schedule_extern(outs)
 
-    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
-    def _schedule(Dense):
-        num_thread = 64
-        k = Dense.op.reduce_axis[0]
-        ko, kf = s[Dense].split(k, factor=num_thread)
-        DenseF = s.rfactor(Dense, kf)
-
-        if Dense.op in s.outputs:
-            Out = Dense
-        else:
-            Out = outs[0].op.output(0)
-            s[Dense].compute_at(s[Out], s[Out].op.axis[1])
-        s[Out].bind(s[Out].op.axis[0], tvm.thread_axis("blockIdx.y"))
-        s[Out].bind(s[Out].op.axis[1], tvm.thread_axis("blockIdx.x"))
-
-        tx = s[Dense].op.reduce_axis[0]
-        thread_x = tvm.thread_axis("threadIdx.x")
-        s[Dense].bind(tx, thread_x)
-        s[DenseF].compute_at(s[Dense], tx)
-        s[Dense].set_store_predicate(thread_x.var.equal(0))
-        s[Out].set_store_predicate(thread_x.var.equal(0))
+
+    def _schedule(C):
+        A, _ = C.op.input_tensors
+        batch, _ = get_const_tuple(A.shape)
+        if batch < 32:
+            return schedule_dense_small_batch(cfg, s, C)
+        return schedule_dense_large_batch(cfg, s, C)
 
     scheduled_ops = []
 
@@ -135,6 +131,130 @@ def schedule_dense(cfg, outs):
     return s
 
 
+def schedule_dense_small_batch(cfg, s, C):
+    """Schedule float32/64 dense with small batch size"""
+    A, _ = C.op.input_tensors
+    _, in_dim = get_const_tuple(A.shape)
+    cfg.define_split('tile_k', in_dim, num_outputs=2)
+    if cfg.is_fallback:
+        cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64])
+
+    _, kf = cfg['tile_k'].apply(s, C, C.op.reduce_axis[0])
+    CF = s.rfactor(C, kf)
+
+    if C.op in s.outputs:
+        Out = C
+    else:
+        Out = s.outputs[0].output(0)
+        s[C].compute_at(s[Out], s[Out].op.axis[1])
+    s[Out].bind(s[Out].op.axis[0], tvm.thread_axis("blockIdx.y"))
+    s[Out].bind(s[Out].op.axis[1], tvm.thread_axis("blockIdx.x"))
+
+    tx = s[C].op.reduce_axis[0]
+    thread_x = tvm.thread_axis("threadIdx.x")
+    s[C].bind(tx, thread_x)
+    s[CF].compute_at(s[C], tx)
+    s[C].set_store_predicate(thread_x.var.equal(0))
+    s[Out].set_store_predicate(thread_x.var.equal(0))
+
+def schedule_dense_large_batch(cfg, s, C):
+    """Schedule float32/64 dense with large batch size"""
+    A, B = C.op.input_tensors
+    batch, in_dim = get_const_tuple(A.shape)
+    out_dim, _ = get_const_tuple(B.shape)
+    k = C.op.reduce_axis[0]
+
+    # create tuning space
+    try:
+        block_cand = [64, 128]
+        vthread_cand = [2**x for x in range(1, 7)]
+        n_thread_cand = [2**x for x in range(3, 7)]
+        cfg.define_split('tile_x', batch, num_outputs=4,
+                         filter=lambda x: (x.size[1] in vthread_cand and
+                                           x.size[2] in n_thread_cand and
+                                           (x.size[1] * x.size[2] * x.size[3]) in block_cand))
+        cfg.define_split('tile_y', out_dim, num_outputs=4,
+                         filter=lambda x: (x.size[1] in vthread_cand and
+                                           x.size[2] in n_thread_cand and
+                                           (x.size[1] * x.size[2] * x.size[3]) in block_cand))
+        cfg.define_split('tile_k', in_dim, num_outputs=3, filter=lambda x: x.size[0] > 2)
+    except IndexError:
+        # Index error happens when no entities left after filtering, which was designed
+        # to prune tuning space for better search efficiency.
+        logger.debug(
+            'Tuning space was created without pruning due to unfit shapes')
+        cfg.define_split('tile_x', batch, num_outputs=4)
+        cfg.define_split('tile_y', out_dim, num_outputs=4)
+        cfg.define_split('tile_k', in_dim, num_outputs=3)
+
+    if cfg.is_fallback:
+        if batch > 1:
+            cfg['tile_x'] = SplitEntity([-1, 2, 16, 2])
+        else:
+            cfg['tile_x'] = SplitEntity([1, 1, 1, 1])
+        if out_dim > 1:
+            cfg['tile_y'] = SplitEntity([-1, 2, 16, 2])
+        else:
+            cfg['tile_y'] = SplitEntity([1, 1, 1, 1])
+        if in_dim > 8:
+            cfg['tile_k'] = SplitEntity([-1, 8, 1])
+        else:
+            cfg['tile_k'] = SplitEntity([-1, 1, 1])
+
+    # Explicit memory access
+    AA = s.cache_read(A, "shared", [C])
+    BB = s.cache_read(B, "shared", [C])
+    AL = s.cache_read(AA, "local", [C])
+    BL = s.cache_read(BB, "local", [C])
+    CC = s.cache_write(C, "local")
+
+    # Deal with op fusion
+    if C.op not in s.outputs:
+        s[C].compute_inline()
+        C = s.outputs[0].output(0)
+
+    # Split and reorder computation
+    bx, txz, tx, xi = cfg['tile_x'].apply(s, C, C.op.axis[0])
+    by, tyz, ty, yi = cfg['tile_y'].apply(s, C, C.op.axis[1])
+    s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi)
+    s[CC].compute_at(s[C], tx)
+
+    # Binding
+    s[C].bind(by, tvm.thread_axis("blockIdx.y"))
+    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
+    s[C].bind(tyz, tvm.thread_axis("vthread"))
+    s[C].bind(txz, tvm.thread_axis("vthread"))
+    s[C].bind(ty, tvm.thread_axis("threadIdx.y"))
+    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
+
+    # Split reduction
+    yo, xo = CC.op.axis
+    ko, kt, ki = cfg['tile_k'].apply(s, CC, k)
+    s[CC].reorder(ko, kt, ki, yo, xo)
+    s[AA].compute_at(s[CC], ko)
+    s[BB].compute_at(s[CC], ko)
+    s[CC].unroll(kt)
+    s[AL].compute_at(s[CC], kt)
+    s[BL].compute_at(s[CC], kt)
+
+    # Schedule for A's shared memory load
+    num_thread_x = cfg['tile_x'].size[2]
+    ty, _ = s[AA].split(s[AA].op.axis[0], nparts=num_thread_x)
+    _, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread_x * 4)
+    tx, xi = s[AA].split(xi, nparts=num_thread_x)
+    s[AA].bind(ty, tvm.thread_axis("threadIdx.y"))
+    s[AA].bind(tx, tvm.thread_axis("threadIdx.x"))
+    s[AA].double_buffer()
+
+    # Schedule for B' shared memory load
+    num_thread_y = cfg['tile_y'].size[2]
+    ty, _ = s[BB].split(s[BB].op.axis[0], nparts=num_thread_y)
+    _, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread_y * 4)
+    tx, xi = s[BB].split(xi, nparts=num_thread_y)
+    s[BB].bind(ty, tvm.thread_axis("threadIdx.y"))
+    s[BB].bind(tx, tvm.thread_axis("threadIdx.x"))
+    s[BB].double_buffer()
+
 @autotvm.register_topi_compute(dense, ['cuda'], ['int8'])
 def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
     """Dense operator for int8 on CUDA"""
index 412eb30..3b74771 100644 (file)
@@ -117,8 +117,9 @@ def verify_dense_int8(batch, in_dim, out_dim, use_bias=True):
 def test_dense():
     verify_dense(1, 1024, 1000, use_bias=True)
     verify_dense(1, 1024, 1000, use_bias=False)
-
     verify_dense(2, 1024, 1000, use_bias=True)
+    verify_dense(128, 1024, 1000, use_bias=False)
+    verify_dense(128, 1024, 1000, use_bias=True)
 
 
 def test_dense_int8():