topi.nn.conv2d_NCHWc: "topi_x86_conv2d_NCHWc",
topi.nn.conv2d_NCHWc_int8: "topi_x86_conv2d_NCHWc_int8",
topi.nn.dense: "topi_nn_dense",
+ topi.nn.batch_matmul: "topi_nn_batch_matmul",
topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw",
topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
topi.nn.bitserial_dense: "topi_nn_bitserial_dense",
topi.nn.conv2d_NCHWc: [topi.generic.schedule_conv2d_NCHWc],
topi.nn.conv2d_NCHWc_int8: [topi.generic.schedule_conv2d_NCHWc_int8],
topi.nn.dense: [topi.generic.schedule_dense],
+ topi.nn.batch_matmul: [topi.generic.schedule_batch_matmul],
topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw],
topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
topi.nn.bitserial_dense: [topi.generic.schedule_bitserial_dense],
topi.nn.group_conv2d_nchw: lambda x: setattr(topi.nn, 'group_conv2d_nchw', x),
topi.nn.conv2d_transpose_nchw: lambda x: setattr(topi.nn, 'conv2d_transpose_nchw', x),
topi.nn.dense: lambda x: setattr(topi.nn, 'dense', x),
+ topi.nn.batch_matmul: lambda x: setattr(topi.nn, 'batch_matmul', x),
topi.nn.bitserial_conv2d_nchw: lambda x: setattr(topi.nn, 'bitserial_conv2d_nchw', x),
topi.nn.bitserial_conv2d_nhwc: lambda x: setattr(topi.nn, 'bitserial_conv2d_nhwc', x),
topi.nn.bitserial_dense: lambda x: setattr(topi.nn, 'bitserial_dense', x),
return s, [data, weight, bias, C]
return s, [data, weight, C]
+ @register("topi_nn_batch_matmul")
+ def _topi_nn_batch_matmul(*args, **kwargs):
+ assert not kwargs, "Do not support kwargs in template function call"
+ args = deserialize_args(args)
+ A, B = args
+ C = topi.nn.batch_matmul(A, B)
+ s = topi.generic.schedule_batch_matmul([C])
+ return s, [A, B, C]
+
@register("topi_nn_bitserial_conv2d_nhwc")
def _topi_bitserial_conv2d_nhwc(*args, **kwargs):
args = deserialize_args(args)
"""x86 batch_matmul operators"""
from __future__ import absolute_import as _abs
import tvm
+from tvm import autotvm
+from tvm.autotvm.task.space import SplitEntity
from tvm.contrib import cblas
-from topi.nn import batch_matmul, batch_matmul_default
-from .. import generic
+from .. import generic, nn
from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
-@batch_matmul.register(["cpu"])
-def batch_matmul_x86(x, y):
+
+@autotvm.register_topi_compute(nn.batch_matmul, "cpu", "direct")
+def _declaration_batch_matmul_nopack(cfg, x, y):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.
Parameters
----------
+ cfg : ConfigSpace
+ Autotvm tuning space config file
x : tvm.Tensor
3-D with shape [batch, M, K]
-
y : tvm.Tensor
3-D with shape [batch, N, K]
-
Returns
-------
output : tvm.Tensor
target = tvm.target.current_target()
if "cblas" in target.libs:
return cblas.batch_matmul(x, y, False, True)
- return batch_matmul_default(x, y)
-@generic.schedule_batch_matmul.register(["cpu"])
-def schedule_batch_matmul(outs):
+ assert len(x.shape) == 3 and len(
+ y.shape) == 3, "only support 3-dim batch_matmul"
+ XB, M, XK = get_const_tuple(x.shape)
+ YB, N, YK = get_const_tuple(y.shape)
+ assert XB == YB, "batch dimension doesn't match"
+ assert XK == YK, "shapes of x and y is inconsistant"
+ B = XB
+ K = XK
+ if cfg.is_fallback:
+ _default_batch_matmul_nopack_config(cfg, M, N, K)
+
+ k = tvm.reduce_axis((0, K), name='k')
+ C = tvm.compute(
+ (B, M, N),
+ lambda b, i, j: tvm.sum(x[b, i, k] * y[b, j, k], axis=k),
+ tag='batch_matmul')
+ return C
+
+
+@autotvm.register_topi_schedule(generic.schedule_batch_matmul, "cpu", "direct")
+def schedule_batch_matmul(cfg, outs):
"""Schedule for batch_matmul
Parameters
----------
- outs: Array of Tensor
- The computation graph description of batch_matmul
- in the format of an array of tensors.
+ cfg : ConfigSpace
+ AutoTVM tuning space config file.
+ outs : Array of Tensor
+ The computation graph description of batch_matmul
+ in the format of an array of tensors.
Returns
-------
if "batch_matmul" in op.tag:
C = op.output(0)
A, B = s[C].op.input_tensors
- _, M, N = get_const_tuple(C.shape)
+ _, M, K = get_const_tuple(A.shape)
+ _, _, N = get_const_tuple(C.shape)
+
+ # create tuning space
+ cfg.define_split("tile_y", M, num_outputs=2)
+ cfg.define_split("tile_x", N, num_outputs=2)
+ cfg.define_split("tile_k", K, num_outputs=2)
+
k, = s[C].op.reduce_axis
- ko, ki = s[C].split(k, 16)
+
+ ko, ki = cfg["tile_k"].apply(s, C, k)
CC = s.rfactor(C, ki)
b, y, x = s[C].op.axis
- y_bn = get_max_power2_factor(M, 8)
- x_bn = get_max_power2_factor(N, 8)
- yo, yi = s[C].split(y, y_bn)
- xo, xi = s[C].split(x, x_bn)
+ yo, yi = cfg["tile_y"].apply(s, C, y)
+ xo, xi = cfg["tile_x"].apply(s, C, x)
s[C].reorder(b, yo, xo, yi, xi)
bxyo = s[C].fuse(b, yo, xo)
s[C].parallel(bxyo)
traverse_inline(s, outs[0].op, _callback)
return s
+
+
+def _default_batch_matmul_nopack_config(cfg, M, N, K):
+ cfg["tile_k"] = SplitEntity([K // 16, 16])
+ x_bn = get_max_power2_factor(N, 8)
+ cfg["tile_x"] = SplitEntity([N // x_bn, x_bn])
+ y_bn = get_max_power2_factor(M, 8)
+ cfg["tile_y"] = SplitEntity([M // y_bn, y_bn])