# pylint: disable=invalid-name, unused-variable
"""Schedule for dense operator"""
from __future__ import absolute_import as _abs
+import logging
import tvm
import tvm.autotvm as autotvm
+from tvm.autotvm.task.space import SplitEntity
from tvm.contrib import cublas
from .tensor_intrin import dp4a
from ..nn.dense import dense, dense_default
from .. import generic
from ..util import traverse_inline, get_const_tuple
+logger = logging.getLogger('topi')
+
@autotvm.register_topi_compute(dense, ["cuda", "gpu"], "direct")
def dense_cuda(cfg, data, weight, bias=None, out_dtype=None):
"""
# pylint: disable=unused-argument
target = tvm.target.current_target()
+
+ outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
if target.target_name == "cuda" and "cublas" in target.libs:
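+        # the computation is offloaded to cuBLAS; still record the FLOP count so AutoTVM can report GFLOPS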
+ A, B = outs[0].op.input_tensors
+ b, i = get_const_tuple(A.shape)
+ o, _ = get_const_tuple(B.shape)
+ cfg.add_flop(2 * i * b * o)
return generic.schedule_extern(outs)
- outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
- def _schedule(Dense):
- num_thread = 64
- k = Dense.op.reduce_axis[0]
- ko, kf = s[Dense].split(k, factor=num_thread)
- DenseF = s.rfactor(Dense, kf)
-
- if Dense.op in s.outputs:
- Out = Dense
- else:
- Out = outs[0].op.output(0)
- s[Dense].compute_at(s[Out], s[Out].op.axis[1])
- s[Out].bind(s[Out].op.axis[0], tvm.thread_axis("blockIdx.y"))
- s[Out].bind(s[Out].op.axis[1], tvm.thread_axis("blockIdx.x"))
-
- tx = s[Dense].op.reduce_axis[0]
- thread_x = tvm.thread_axis("threadIdx.x")
- s[Dense].bind(tx, thread_x)
- s[DenseF].compute_at(s[Dense], tx)
- s[Dense].set_store_predicate(thread_x.var.equal(0))
- s[Out].set_store_predicate(thread_x.var.equal(0))
+
+ def _schedule(C):
+ A, _ = C.op.input_tensors
+ batch, _ = get_const_tuple(A.shape)
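+        # pick a schedule based on the batch dimension of the input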
+ if batch < 32:
+ return schedule_dense_small_batch(cfg, s, C)
+ return schedule_dense_large_batch(cfg, s, C)
scheduled_ops = []
return s
+def schedule_dense_small_batch(cfg, s, C):
+ """Schedule float32/64 dense with small batch size"""
+ A, _ = C.op.input_tensors
+ _, in_dim = get_const_tuple(A.shape)
+ cfg.define_split('tile_k', in_dim, num_outputs=2)
+ if cfg.is_fallback:
+ cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64])
+
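+    # split the reduction axis and rfactor it out so partial sums are computed in parallel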
+ _, kf = cfg['tile_k'].apply(s, C, C.op.reduce_axis[0])
+ CF = s.rfactor(C, kf)
+
+ if C.op in s.outputs:
+ Out = C
+ else:
+ Out = s.outputs[0].output(0)
+ s[C].compute_at(s[Out], s[Out].op.axis[1])
+ s[Out].bind(s[Out].op.axis[0], tvm.thread_axis("blockIdx.y"))
+ s[Out].bind(s[Out].op.axis[1], tvm.thread_axis("blockIdx.x"))
+
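+    # bind the remaining reduction axis to threadIdx.x; only thread 0 stores the final result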
+ tx = s[C].op.reduce_axis[0]
+ thread_x = tvm.thread_axis("threadIdx.x")
+ s[C].bind(tx, thread_x)
+ s[CF].compute_at(s[C], tx)
+ s[C].set_store_predicate(thread_x.var.equal(0))
+ s[Out].set_store_predicate(thread_x.var.equal(0))
+
+def schedule_dense_large_batch(cfg, s, C):
+ """Schedule float32/64 dense with large batch size"""
+ A, B = C.op.input_tensors
+ batch, in_dim = get_const_tuple(A.shape)
+ out_dim, _ = get_const_tuple(B.shape)
+ k = C.op.reduce_axis[0]
+
+ # create tuning space
+ try:
+ block_cand = [64, 128]
+ vthread_cand = [2**x for x in range(1, 7)]
+ n_thread_cand = [2**x for x in range(3, 7)]
+ cfg.define_split('tile_x', batch, num_outputs=4,
+ filter=lambda x: (x.size[1] in vthread_cand and
+ x.size[2] in n_thread_cand and
+ (x.size[1] * x.size[2] * x.size[3]) in block_cand))
+ cfg.define_split('tile_y', out_dim, num_outputs=4,
+ filter=lambda x: (x.size[1] in vthread_cand and
+ x.size[2] in n_thread_cand and
+ (x.size[1] * x.size[2] * x.size[3]) in block_cand))
+ cfg.define_split('tile_k', in_dim, num_outputs=3, filter=lambda x: x.size[0] > 2)
+ except IndexError:
+        # An IndexError is raised when no entities are left after filtering, which is
+        # intended to prune the tuning space for better search efficiency.
+ logger.debug(
+ 'Tuning space was created without pruning due to unfit shapes')
+ cfg.define_split('tile_x', batch, num_outputs=4)
+ cfg.define_split('tile_y', out_dim, num_outputs=4)
+ cfg.define_split('tile_k', in_dim, num_outputs=3)
+
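+    # fallback configuration used when no tuned schedule is found in the AutoTVM log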
+ if cfg.is_fallback:
+ if batch > 1:
+ cfg['tile_x'] = SplitEntity([-1, 2, 16, 2])
+ else:
+ cfg['tile_x'] = SplitEntity([1, 1, 1, 1])
+ if out_dim > 1:
+ cfg['tile_y'] = SplitEntity([-1, 2, 16, 2])
+ else:
+ cfg['tile_y'] = SplitEntity([1, 1, 1, 1])
+ if in_dim > 8:
+ cfg['tile_k'] = SplitEntity([-1, 8, 1])
+ else:
+ cfg['tile_k'] = SplitEntity([-1, 1, 1])
+
+    # Explicitly manage the memory hierarchy: stage A and B in shared memory, then registers
+ AA = s.cache_read(A, "shared", [C])
+ BB = s.cache_read(B, "shared", [C])
+ AL = s.cache_read(AA, "local", [C])
+ BL = s.cache_read(BB, "local", [C])
+ CC = s.cache_write(C, "local")
+
+ # Deal with op fusion
+ if C.op not in s.outputs:
+ s[C].compute_inline()
+ C = s.outputs[0].output(0)
+
+ # Split and reorder computation
+ bx, txz, tx, xi = cfg['tile_x'].apply(s, C, C.op.axis[0])
+ by, tyz, ty, yi = cfg['tile_y'].apply(s, C, C.op.axis[1])
+ s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi)
+ s[CC].compute_at(s[C], tx)
+
+ # Binding
+ s[C].bind(by, tvm.thread_axis("blockIdx.y"))
+ s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
+ s[C].bind(tyz, tvm.thread_axis("vthread"))
+ s[C].bind(txz, tvm.thread_axis("vthread"))
+ s[C].bind(ty, tvm.thread_axis("threadIdx.y"))
+ s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
+
+ # Split reduction
+ yo, xo = CC.op.axis
+ ko, kt, ki = cfg['tile_k'].apply(s, CC, k)
+ s[CC].reorder(ko, kt, ki, yo, xo)
+ s[AA].compute_at(s[CC], ko)
+ s[BB].compute_at(s[CC], ko)
+ s[CC].unroll(kt)
+ s[AL].compute_at(s[CC], kt)
+ s[BL].compute_at(s[CC], kt)
+
+ # Schedule for A's shared memory load
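+    # cooperative fetching: the whole thread block loads the tile, four elements per thread along the inner axis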
+ num_thread_x = cfg['tile_x'].size[2]
+ ty, _ = s[AA].split(s[AA].op.axis[0], nparts=num_thread_x)
+ _, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread_x * 4)
+ tx, xi = s[AA].split(xi, nparts=num_thread_x)
+ s[AA].bind(ty, tvm.thread_axis("threadIdx.y"))
+ s[AA].bind(tx, tvm.thread_axis("threadIdx.x"))
+ s[AA].double_buffer()
+
+    # Schedule for B's shared memory load
+ num_thread_y = cfg['tile_y'].size[2]
+ ty, _ = s[BB].split(s[BB].op.axis[0], nparts=num_thread_y)
+ _, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread_y * 4)
+ tx, xi = s[BB].split(xi, nparts=num_thread_y)
+ s[BB].bind(ty, tvm.thread_axis("threadIdx.y"))
+ s[BB].bind(tx, tvm.thread_axis("threadIdx.x"))
+ s[BB].double_buffer()
+
@autotvm.register_topi_compute(dense, ['cuda'], ['int8'])
def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
"""Dense operator for int8 on CUDA"""