From c870261facb8bcc452975c4138bee22dd749e1c5 Mon Sep 17 00:00:00 2001
From: Jon Soifer
Date: Tue, 20 Aug 2019 22:24:10 -0700
Subject: [PATCH] [TOPI] Use cblas for dense and batch_matmul when "cblas" is in the target libraries (#3787)

* Support cblas library in dense

* start to add support for generic batch_matmul compute

* Add x86 override for batch_matmul

* Fix linting

* reset file

* Fix typos

* dummy change to re-trigger CI
---
 python/tvm/relay/op/nn/_nn.py        |   3 +-
 topi/python/topi/nn/batch_matmul.py  |  25 +++++++-
 topi/python/topi/x86/batch_matmul.py |  29 +++++++++-
 topi/python/topi/x86/dense.py        | 107 ++++++++++++++++++++---------------
 4 files changed, 114 insertions(+), 50 deletions(-)

diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index c896a00..6dcff6b 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -73,7 +73,8 @@ reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 @reg.register_compute("nn.batch_matmul")
 def compute_batch_matmul(attrs, inputs, out_type, target):
     """Compute definition of batch_matmul"""
-    return [topi.nn.batch_matmul(inputs[0], inputs[1])]
+    with target:
+        return [topi.nn.batch_matmul(inputs[0], inputs[1])]
 
 
 @reg.register_schedule("nn.batch_matmul")
diff --git a/topi/python/topi/nn/batch_matmul.py b/topi/python/topi/nn/batch_matmul.py
index c8ca3e3..7b872ce 100644
--- a/topi/python/topi/nn/batch_matmul.py
+++ b/topi/python/topi/nn/batch_matmul.py
@@ -20,8 +20,7 @@ from __future__ import absolute_import as _abs
 import tvm
 from ..util import get_const_tuple
 
-
-def batch_matmul(x, y):
+def batch_matmul_default(x, y):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
     data in batch.
 
@@ -30,7 +29,7 @@ def batch_matmul(x, y):
     x : tvm.Tensor
         3-D with shape [batch, M, K]
 
-    y : tvm.TEnsor
+    y : tvm.Tensor
         3-D with shape [batch, N, K]
 
     Returns
@@ -49,3 +48,23 @@ def batch_matmul(x, y):
     return tvm.compute((batch, M, N),
                        lambda b, i, j: tvm.sum(x[b, i, k] * y[b, j, k], axis=k),
                        tag='batch_matmul')
+
+@tvm.target.generic_func
+def batch_matmul(x, y):
+    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
+    data in batch.
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        3-D with shape [batch, M, K]
+
+    y : tvm.Tensor
+        3-D with shape [batch, N, K]
+
+    Returns
+    -------
+    output : tvm.Tensor
+        3-D with shape [batch, M, N]
+    """
+    return batch_matmul_default(x, y)
diff --git a/topi/python/topi/x86/batch_matmul.py b/topi/python/topi/x86/batch_matmul.py
index 7368760..047e97f 100644
--- a/topi/python/topi/x86/batch_matmul.py
+++ b/topi/python/topi/x86/batch_matmul.py
@@ -18,10 +18,33 @@
 """x86 batch_matmul operators"""
 from __future__ import absolute_import as _abs
 import tvm
-
+from tvm.contrib import cblas
+from topi.nn import batch_matmul, batch_matmul_default
 from .. import generic
 from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
 
+@batch_matmul.register(["cpu"])
+def batch_matmul_x86(x, y):
+    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
+    data in batch.
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        3-D with shape [batch, M, K]
+
+    y : tvm.Tensor
+        3-D with shape [batch, N, K]
+
+    Returns
+    -------
+    output : tvm.Tensor
+        3-D with shape [batch, M, N]
+    """
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        return cblas.batch_matmul(x, y, False, True)
+    return batch_matmul_default(x, y)
 
 @generic.schedule_batch_matmul.register(["cpu"])
 def schedule_batch_matmul(outs):
@@ -38,6 +61,10 @@ def schedule_batch_matmul(outs):
     sch: Schedule
         The computation schedule for the op.
     """
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        return generic.schedule_extern(outs)
+
     s = tvm.create_schedule([x.op for x in outs])
 
     def _callback(op):
diff --git a/topi/python/topi/x86/dense.py b/topi/python/topi/x86/dense.py
index 2525ba0..e22ad44 100644
--- a/topi/python/topi/x86/dense.py
+++ b/topi/python/topi/x86/dense.py
@@ -20,6 +20,7 @@ from __future__ import absolute_import as _abs
 import tvm
 from tvm import autotvm
 from tvm.autotvm.task.space import SplitEntity
+from tvm.contrib import cblas
 
 from .util import get_fp32_len
 from .. import generic, tag, nn
@@ -40,29 +41,33 @@ def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None):
 # Declare dense compute with packing weight into cache-friendly layout
 @autotvm.register_topi_compute(nn.dense, "cpu", "direct_pack")
 def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
-    if out_dtype is None:
-        out_dtype = data.dtype
-    batch, in_dim = get_const_tuple(data.shape)
-    out_dim, _ = get_const_tuple(weight.shape)
-    # create tuning space
-    cfg.define_split("tile_y", batch, num_outputs=3)
-    cfg.define_split("tile_x", out_dim, num_outputs=3)
-    cfg.define_split("tile_k", in_dim, num_outputs=2)
-    if cfg.is_fallback:
-        _default_dense_pack_config(cfg, batch, out_dim, in_dim)
-
-    packw_bn = cfg["tile_x"].size[-1]
-    packw_shape = (out_dim // packw_bn, in_dim, packw_bn)
-    packw = tvm.compute(packw_shape,
-                        lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
-
-    k = tvm.reduce_axis((0, in_dim), name="k")
-    C = tvm.compute((batch, out_dim),
-                    lambda y, x: tvm.sum(
-                        data[y, k].astype(out_dtype) *
-                        packw[x // packw_bn, k, x % packw_bn].astype(out_dtype),
-                        axis=k),
-                    tag="dense_pack")
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        C = cblas.matmul(data, weight, False, True)
+    else:
+        if out_dtype is None:
+            out_dtype = data.dtype
+        batch, in_dim = get_const_tuple(data.shape)
+        out_dim, _ = get_const_tuple(weight.shape)
+        # create tuning space
+        cfg.define_split("tile_y", batch, num_outputs=3)
+        cfg.define_split("tile_x", out_dim, num_outputs=3)
+        cfg.define_split("tile_k", in_dim, num_outputs=2)
+        if cfg.is_fallback:
+            _default_dense_pack_config(cfg, batch, out_dim, in_dim)
+
+        packw_bn = cfg["tile_x"].size[-1]
+        packw_shape = (out_dim // packw_bn, in_dim, packw_bn)
+        packw = tvm.compute(packw_shape,
+                            lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
+
+        k = tvm.reduce_axis((0, in_dim), name="k")
+        C = tvm.compute((batch, out_dim),
+                        lambda y, x: tvm.sum(
+                            data[y, k].astype(out_dtype) *
+                            packw[x // packw_bn, k, x % packw_bn].astype(out_dtype),
+                            axis=k),
+                        tag="dense_pack")
     if bias is not None:
         C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
                         tag=tag.BROADCAST)
@@ -72,28 +77,32 @@ def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
 # Declare dense compute without packing weight
 @autotvm.register_topi_compute(nn.dense, "cpu", "direct_nopack")
 def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
-    if out_dtype is None:
-        out_dtype = data.dtype
-    batch, in_dim = get_const_tuple(data.shape)
-    out_dim, _ = get_const_tuple(weight.shape)
-    # create tuning space
-    cfg.define_split("tile_x", out_dim, num_outputs=2)
-    cfg.define_split("tile_y", batch, num_outputs=2)
-    cfg.define_split("tile_k", in_dim, num_outputs=2)
-    if cfg.is_fallback:
-        _default_dense_nopack_config(cfg, batch, out_dim, in_dim)
-
-    vec = cfg["tile_k"].size[-1]
-    k = tvm.reduce_axis((0, in_dim // vec), "k")
-    CC = tvm.compute((batch, out_dim, vec),
-                     lambda z, y, x: tvm.sum(
-                         data[z, k * vec + x].astype(out_dtype) *
-                         weight[y, k * vec + x].astype(out_dtype), axis=k))
-
-    kk = tvm.reduce_axis((0, vec), "kk")
-    C = tvm.compute((batch, out_dim),
-                    lambda y, x: tvm.sum(CC[y, x, kk], axis=kk),
-                    tag="dense_nopack")
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        C = cblas.matmul(data, weight, False, True)
+    else:
+        if out_dtype is None:
+            out_dtype = data.dtype
+        batch, in_dim = get_const_tuple(data.shape)
+        out_dim, _ = get_const_tuple(weight.shape)
+        # create tuning space
+        cfg.define_split("tile_x", out_dim, num_outputs=2)
+        cfg.define_split("tile_y", batch, num_outputs=2)
+        cfg.define_split("tile_k", in_dim, num_outputs=2)
+        if cfg.is_fallback:
+            _default_dense_nopack_config(cfg, batch, out_dim, in_dim)
+
+        vec = cfg["tile_k"].size[-1]
+        k = tvm.reduce_axis((0, in_dim // vec), "k")
+        CC = tvm.compute((batch, out_dim, vec),
+                         lambda z, y, x: tvm.sum(
+                             data[z, k * vec + x].astype(out_dtype) *
+                             weight[y, k * vec + x].astype(out_dtype), axis=k))
+
+        kk = tvm.reduce_axis((0, vec), "kk")
+        C = tvm.compute((batch, out_dim),
+                        lambda y, x: tvm.sum(CC[y, x, kk], axis=kk),
+                        tag="dense_nopack")
     if bias is not None:
         C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
                         tag=tag.BROADCAST)
@@ -116,6 +125,10 @@ def _schedule_dense(cfg, outs):
 
 @autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_pack")
 def _schedule_dense_pack(cfg, outs):
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        return generic.schedule_extern(outs)
+
     s = tvm.create_schedule([x.op for x in outs])
 
     def _callback(op):
@@ -127,6 +140,10 @@ def _schedule_dense_pack(cfg, outs):
 
 @autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_nopack")
 def _schedule_dense_nopack(cfg, outs):
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        return generic.schedule_extern(outs)
+
     s = tvm.create_schedule([x.op for x in outs])
 
     def _callback(op):
-- 
2.7.4
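
Usage sketch (not part of the patch): one way the new cblas path could be exercised at the TOPI level. It assumes TVM was built with a BLAS library enabled so that tvm.contrib.cblas is usable, and that this patch is applied; the shapes and variable names are illustrative. The "-libs=cblas" flag in the target string is what places "cblas" in target.libs, routing batch_matmul to the extern cblas kernel and its schedule to generic.schedule_extern; the dense changes in topi/python/topi/x86/dense.py are reached the same way through topi.nn.dense and generic.schedule_dense.

import numpy as np
import tvm
import topi

# Illustrative shapes: x is [batch, M, K], y is [batch, N, K], matching the docstrings above.
batch, M, N, K = 4, 64, 96, 128
x = tvm.placeholder((batch, M, K), name="x")
y = tvm.placeholder((batch, N, K), name="y")

# "-libs=cblas" puts "cblas" into target.libs, which is the condition the new code checks.
target = tvm.target.create("llvm -libs=cblas")
with target:
    out = topi.nn.batch_matmul(x, y)               # dispatches to the x86 override
    s = topi.generic.schedule_batch_matmul([out])  # falls through to generic.schedule_extern

func = tvm.build(s, [x, y, out], target=target)

ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(batch, M, K)).astype("float32"), ctx)
b = tvm.nd.array(np.random.uniform(size=(batch, N, K)).astype("float32"), ctx)
c = tvm.nd.array(np.zeros((batch, M, N), dtype="float32"), ctx)
func(a, b, c)
np.testing.assert_allclose(
    c.asnumpy(), np.matmul(a.asnumpy(), b.asnumpy().transpose(0, 2, 1)), rtol=1e-5)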