[AutoTVM] Add batch_matmul to tunable operations (#4242)
author     Josh Fromm <jwfromm@uw.edu>
Thu, 7 Nov 2019 00:07:09 +0000 (16:07 -0800)
committer  Haichen Shen <shenhaichen@gmail.com>
Thu, 7 Nov 2019 00:07:09 +0000 (16:07 -0800)
* Batch matmul tuning running but with errors.

* Default x86 schedule as good as before.

* Code Cleanup

* Remove unused argument.

* Improved template documentation.

* Silly lint fix

* Removed leftover comment.

* Moved cfg declaration to schedule for batch_matmul

* Moved x86 dense cfg declaration to schedule.

* Lint fix

* Removed duplicate cfg declaration in dense.

* Reverted changes to dense.

python/tvm/autotvm/task/relay_integration.py
python/tvm/autotvm/task/topi_integration.py
topi/python/topi/x86/batch_matmul.py

index 6ee8bc0..345da66 100644
@@ -117,6 +117,7 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
                                  topi.nn.group_conv2d_nchw, topi.nn.conv2d_NCHWc],
         tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
         tvm.relay.op.nn.dense: [topi.nn.dense],
+        tvm.relay.op.nn.batch_matmul: [topi.nn.batch_matmul],
         tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
     }
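With this mapping in place, Relay task extraction can now emit tunable batch_matmul tasks. A minimal sketch, assuming hypothetical shapes and an llvm target:

    import tvm
    from tvm import relay, autotvm

    # Hypothetical workload: [batch, M, K] x [batch, N, K] -> [batch, M, N]
    data = relay.var("data", shape=(8, 32, 64))
    weight = relay.var("weight", shape=(8, 96, 64))
    func = relay.Function([data, weight], relay.nn.batch_matmul(data, weight))

    # batch_matmul is now part of the tunable op set
    tasks = autotvm.task.extract_from_program(
        func, params={}, ops=(relay.op.nn.batch_matmul,), target="llvm")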
 
index ac4683d..7bfc313 100644
@@ -87,6 +87,7 @@ class TaskExtractEnv:
             topi.nn.conv2d_NCHWc: "topi_x86_conv2d_NCHWc",
             topi.nn.conv2d_NCHWc_int8: "topi_x86_conv2d_NCHWc_int8",
             topi.nn.dense: "topi_nn_dense",
+            topi.nn.batch_matmul: "topi_nn_batch_matmul",
             topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw",
             topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
             topi.nn.bitserial_dense: "topi_nn_bitserial_dense",
@@ -103,6 +104,7 @@ class TaskExtractEnv:
             topi.nn.conv2d_NCHWc: [topi.generic.schedule_conv2d_NCHWc],
             topi.nn.conv2d_NCHWc_int8: [topi.generic.schedule_conv2d_NCHWc_int8],
             topi.nn.dense: [topi.generic.schedule_dense],
+            topi.nn.batch_matmul: [topi.generic.schedule_batch_matmul],
             topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw],
             topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
             topi.nn.bitserial_dense: [topi.generic.schedule_bitserial_dense],
@@ -118,6 +120,7 @@ class TaskExtractEnv:
             topi.nn.group_conv2d_nchw:      lambda x: setattr(topi.nn, 'group_conv2d_nchw', x),
             topi.nn.conv2d_transpose_nchw:  lambda x: setattr(topi.nn, 'conv2d_transpose_nchw', x),
             topi.nn.dense:                  lambda x: setattr(topi.nn, 'dense', x),
+            topi.nn.batch_matmul:           lambda x: setattr(topi.nn, 'batch_matmul', x),
             topi.nn.bitserial_conv2d_nchw:  lambda x: setattr(topi.nn, 'bitserial_conv2d_nchw', x),
             topi.nn.bitserial_conv2d_nhwc:  lambda x: setattr(topi.nn, 'bitserial_conv2d_nhwc', x),
             topi.nn.bitserial_dense:        lambda x: setattr(topi.nn, 'bitserial_dense', x),
@@ -226,6 +229,15 @@ class TaskExtractEnv:
                 return s, [data, weight, bias, C]
             return s, [data, weight, C]
 
+        @register("topi_nn_batch_matmul")
+        def _topi_nn_batch_matmul(*args, **kwargs):
+            assert not kwargs, "kwargs are not supported in template function call"
+            args = deserialize_args(args)
+            A, B = args
+            C = topi.nn.batch_matmul(A, B)
+            s = topi.generic.schedule_batch_matmul([C])
+            return s, [A, B, C]
+
         @register("topi_nn_bitserial_conv2d_nhwc")
         def _topi_bitserial_conv2d_nhwc(*args, **kwargs):
             args = deserialize_args(args)
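The registered template can also be instantiated by name; tensor arguments travel in the serialized ('TENSOR', shape, dtype) form that deserialize_args unpacks. A sketch with hypothetical shapes:

    from tvm import autotvm

    A = ('TENSOR', (8, 32, 64), 'float32')
    B = ('TENSOR', (8, 96, 64), 'float32')
    task = autotvm.task.create("topi_nn_batch_matmul",
                               args=(A, B), target="llvm")
    print(task.config_space)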
index 047e97f..b505cbf 100644
 """x86 batch_matmul operators"""
 from __future__ import absolute_import as _abs
 import tvm
+from tvm import autotvm
+from tvm.autotvm.task.space import SplitEntity
 from tvm.contrib import cblas
-from topi.nn import batch_matmul, batch_matmul_default
-from .. import generic
+from .. import generic, nn
 from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
 
-@batch_matmul.register(["cpu"])
-def batch_matmul_x86(x, y):
+
+@autotvm.register_topi_compute(nn.batch_matmul, "cpu", "direct")
+def _declaration_batch_matmul_nopack(cfg, x, y):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
     data in batch.
 
     Parameters
     ----------
+    cfg : ConfigSpace
+        The AutoTVM tuning space config.
     x : tvm.Tensor
         3-D with shape [batch, M, K]
-
     y : tvm.Tensor
         3-D with shape [batch, N, K]
-
     Returns
     -------
     output : tvm.Tensor
@@ -44,17 +46,37 @@ def batch_matmul_x86(x, y):
     target = tvm.target.current_target()
     if "cblas" in target.libs:
         return cblas.batch_matmul(x, y, False, True)
-    return batch_matmul_default(x, y)
 
-@generic.schedule_batch_matmul.register(["cpu"])
-def schedule_batch_matmul(outs):
+    assert len(x.shape) == 3 and len(y.shape) == 3, \
+        "only support 3-dim batch_matmul"
+    XB, M, XK = get_const_tuple(x.shape)
+    YB, N, YK = get_const_tuple(y.shape)
+    assert XB == YB, "batch dimensions don't match"
+    assert XK == YK, "shapes of x and y are inconsistent"
+    B = XB
+    K = XK
+    if cfg.is_fallback:
+        _default_batch_matmul_nopack_config(cfg, M, N, K)
+
+    k = tvm.reduce_axis((0, K), name='k')
+    C = tvm.compute(
+        (B, M, N),
+        lambda b, i, j: tvm.sum(x[b, i, k] * y[b, j, k], axis=k),
+        tag='batch_matmul')
+    return C
+
+
+@autotvm.register_topi_schedule(generic.schedule_batch_matmul, "cpu", "direct")
+def schedule_batch_matmul(cfg, outs):
     """Schedule for batch_matmul
 
     Parameters
     ----------
-    outs: Array of Tensor
-          The computation graph description of batch_matmul
-          in the format of an array of tensors.
+    cfg : ConfigSpace
+        The AutoTVM tuning space config.
+    outs : Array of Tensor
+        The computation graph description of batch_matmul
+        in the format of an array of tensors.
 
     Returns
     -------
@@ -71,16 +93,22 @@ def schedule_batch_matmul(outs):
         if "batch_matmul" in op.tag:
             C = op.output(0)
             A, B = s[C].op.input_tensors
-            _, M, N = get_const_tuple(C.shape)
+            _, M, K = get_const_tuple(A.shape)
+            _, _, N = get_const_tuple(C.shape)
+
+            # create tuning space
+            cfg.define_split("tile_y", M, num_outputs=2)
+            cfg.define_split("tile_x", N, num_outputs=2)
+            cfg.define_split("tile_k", K, num_outputs=2)
+
             k, = s[C].op.reduce_axis
-            ko, ki = s[C].split(k, 16)
+
+            ko, ki = cfg["tile_k"].apply(s, C, k)
             CC = s.rfactor(C, ki)
 
             b, y, x = s[C].op.axis
-            y_bn = get_max_power2_factor(M, 8)
-            x_bn = get_max_power2_factor(N, 8)
-            yo, yi = s[C].split(y, y_bn)
-            xo, xi = s[C].split(x, x_bn)
+            yo, yi = cfg["tile_y"].apply(s, C, y)
+            xo, xi = cfg["tile_x"].apply(s, C, x)
             s[C].reorder(b, yo, xo, yi, xi)
             bxyo = s[C].fuse(b, yo, xo)
             s[C].parallel(bxyo)
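The fixed split factors are gone: tile_y, tile_x, and tile_k are now knobs, so the tiling is searched rather than hard-coded. One plausible way to explore the space, reusing a task from the extraction sketch above (tuner choice and trial count are arbitrary):

    from tvm import autotvm

    # Build and run candidate schedules locally
    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(number=5))

    tuner = autotvm.tuner.XGBTuner(tasks[0])
    tuner.tune(n_trial=200,
               measure_option=measure_option,
               callbacks=[autotvm.callback.log_to_file("batch_matmul.log")])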
@@ -94,3 +122,11 @@ def schedule_batch_matmul(outs):
 
     traverse_inline(s, outs[0].op, _callback)
     return s
+
+
+def _default_batch_matmul_nopack_config(cfg, M, N, K):
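+    """Fallback config mirroring the previous fixed x86 schedule."""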
+    cfg["tile_k"] = SplitEntity([K // 16, 16])
+    x_bn = get_max_power2_factor(N, 8)
+    cfg["tile_x"] = SplitEntity([N // x_bn, x_bn])
+    y_bn = get_max_power2_factor(M, 8)
+    cfg["tile_y"] = SplitEntity([M // y_bn, y_bn])