Make batch matrix multiplication on GPU tunable (#5752)
authorThomas Viehmann <tv.code@beamnet.de>
Thu, 11 Jun 2020 16:38:46 +0000 (18:38 +0200)
committerGitHub <noreply@github.com>
Thu, 11 Jun 2020 16:38:46 +0000 (09:38 -0700)
This is primarily aimed at the AMD GPU backend and done as part
of a project for AMD, but should work for all users of the GPU
schedule.

python/tvm/relay/op/strategy/cuda.py
topi/python/topi/cuda/batch_matmul.py
topi/tests/python/test_topi_batch_matmul.py

index 5ffb7f1..4b019cf 100644 (file)
@@ -481,7 +481,7 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
     """batch_matmul cuda strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_batch_matmul(topi.nn.batch_matmul),
+        wrap_compute_batch_matmul(topi.cuda.batch_matmul),
         wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
         name="batch_matmul.cuda",
         plevel=10)
index bf80182..7d92edf 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name,too-many-locals,unused-variable
+# pylint: disable=invalid-name,too-many-locals,unused-variable,unused-argument
 """cuda batch_matmul operators"""
+import tvm
+from tvm import autotvm
 from tvm import te
 from tvm.contrib import cublas
+from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
+from .. import nn
 from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
 
-def schedule_batch_matmul(outs):
+@autotvm.register_topi_compute("batch_matmul.cuda")
+def batch_matmul(cfg, x, y):
+    """Compute conv2d with NCHW layout"""
+    return nn.batch_matmul(x, y)
+
+
+@autotvm.register_topi_schedule("batch_matmul.cuda")
+def schedule_batch_matmul(cfg, outs):
     """Schedule for batch_matmul
 
     Parameters
@@ -37,7 +48,7 @@ def schedule_batch_matmul(outs):
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
     s = te.create_schedule([x.op for x in outs])
 
-    def _schedule(op):
+    def _schedule(cfg, op):
         C = op.output(0)
         A, B = s[C].op.input_tensors
         _, M, N = get_const_tuple(C.shape)
@@ -51,16 +62,34 @@ def schedule_batch_matmul(outs):
             C = s.outputs[0].output(0)
 
         b, y, x = s[C].op.axis
-        y_bn = get_max_power2_factor(M, 64)
-        x_bn = get_max_power2_factor(N, 64)
-        by, y = s[C].split(y, y_bn)
-        bx, x = s[C].split(x, x_bn)
-        y_nthreads = min(y_bn, 8)
-        x_nthreads = min(x_bn, 8)
-        ty, yi = s[C].split(y, nparts=y_nthreads)
-        tx, xi = s[C].split(x, nparts=x_nthreads)
-        thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x")
-        thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y")
+        k, = s[CC].op.reduce_axis
+
+        cfg.define_split("tile_y", y, num_outputs=3)
+        cfg.define_split("tile_x", x, num_outputs=3)
+        cfg.define_split("tile_k", k, num_outputs=2)
+        cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
+        target = tvm.target.Target.current()
+        if target.target_name in ['nvptx', 'rocm']:
+            # llvm-based backends cannot do non-explicit unrolling
+            cfg.define_knob("unroll_explicit", [1])
+        else:
+            cfg.define_knob("unroll_explicit", [0, 1])
+
+        if cfg.is_fallback:
+            y_bn = get_max_power2_factor(M, 64)
+            x_bn = get_max_power2_factor(N, 64)
+            y_nthreads = min(y_bn, 8)
+            x_nthreads = min(x_bn, 8)
+            cfg['tile_x'] = SplitEntity([-1, x_nthreads, x_bn // x_nthreads])
+            cfg['tile_y'] = SplitEntity([-1, y_nthreads, y_bn // y_nthreads])
+            cfg['tile_k'] = SplitEntity([-1, 8])
+            cfg['auto_unroll_max_step'] = OtherOptionEntity(16)
+
+        by, ty, yi = cfg["tile_y"].apply(s, C, y)
+        bx, tx, xi = cfg["tile_x"].apply(s, C, x)
+
+        thread_x = te.thread_axis("threadIdx.x")
+        thread_y = te.thread_axis("threadIdx.y")
 
         s[C].reorder(b, by, bx, ty, tx, yi, xi)
         s[C].bind(b, te.thread_axis("blockIdx.z"))
@@ -68,38 +97,41 @@ def schedule_batch_matmul(outs):
         s[C].bind(bx, te.thread_axis("blockIdx.x"))
         s[C].bind(ty, thread_y)
         s[C].bind(tx, thread_x)
-        s[C].pragma(yi, "auto_unroll_max_step", 16)
+        s[C].pragma(yi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val)
+        s[C].pragma(yi, 'unroll_explicit', cfg['unroll_explicit'].val)
 
         s[CC].compute_at(s[C], tx)
         _, yi, xi = s[CC].op.axis
-        k, = s[CC].op.reduce_axis
-        ko, ki = s[CC].split(k, 8)
+        ko, ki = cfg["tile_k"].apply(s, CC, k)
         s[CC].reorder(ko, ki, yi, xi)
-        s[CC].pragma(ki, "auto_unroll_max_step", 16)
+        s[CC].pragma(ki, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val)
+        s[CC].pragma(ki, 'unroll_explicit', cfg['unroll_explicit'].val)
 
         s[AA].compute_at(s[CC], ko)
         s[AL].compute_at(s[CC], ki)
         s[BB].compute_at(s[CC], ko)
         s[BL].compute_at(s[CC], ki)
         _, y, k = s[AA].op.axis
-        ty, yi = s[AA].split(y, nparts=y_nthreads)
-        tx, ki = s[AA].split(k, nparts=x_nthreads)
+        ty, yi = s[AA].split(y, nparts=cfg["tile_y"].size[1])
+        tx, ki = s[AA].split(k, nparts=cfg["tile_x"].size[1])
         s[AA].reorder(ty, tx, yi, ki)
         s[AA].bind(ty, thread_y)
         s[AA].bind(tx, thread_x)
-        s[AA].pragma(yi, "auto_unroll_max_step", 16)
+        s[AA].pragma(yi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val)
+        s[AA].pragma(yi, 'unroll_explicit', cfg['unroll_explicit'].val)
 
         _, x, k = s[BB].op.axis
-        ty, xi = s[BB].split(x, nparts=y_nthreads)
-        tx, ki = s[BB].split(k, nparts=x_nthreads)
+        ty, xi = s[BB].split(x, nparts=cfg["tile_y"].size[1])
+        tx, ki = s[BB].split(k, nparts=cfg["tile_x"].size[1])
         s[BB].bind(ty, thread_y)
         s[BB].bind(tx, thread_x)
         s[BB].reorder(ty, tx, xi, ki)
-        s[BB].pragma(xi, "auto_unroll_max_step", 16)
+        s[BB].pragma(xi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val)
+        s[BB].pragma(xi, 'unroll_explicit', cfg['unroll_explicit'].val)
 
     def _callback(op):
         if "batch_matmul" in op.tag:
-            _schedule(op)
+            _schedule(cfg, op)
 
     traverse_inline(s, outs[0].op, _callback)
     return s
index b8c8547..716f407 100644 (file)
@@ -28,7 +28,7 @@ from common import get_all_backend
 _batch_matmul_implement = {
     "generic": (topi.nn.batch_matmul, topi.generic.schedule_batch_matmul),
     "cpu": (topi.x86.batch_matmul, topi.x86.schedule_batch_matmul),
-    "gpu": (topi.nn.batch_matmul, topi.cuda.schedule_batch_matmul),
+    "gpu": (topi.cuda.batch_matmul, topi.cuda.schedule_batch_matmul),
 }
 
 def verify_batch_matmul(batch, M, N, K):