From 88f9bfd46e71216d0c904dffbf9becd250128f20 Mon Sep 17 00:00:00 2001
From: Jon Soifer
Date: Thu, 12 Sep 2019 13:04:45 -0700
Subject: [PATCH] [TOPI][CUDA] Support cuBLAS BatchMatMul (#3936)

* Support cuBLAS BatchMatMul

* Add test and check target name
---
 python/tvm/contrib/cublas.py          | 28 +++++++++++++
 src/contrib/cublas/cublas.cc          | 60 +++++++++++++++++++++++++++
 tests/python/contrib/test_cublas.py   | 28 +++++++++++++
 topi/include/topi/contrib/cublas.h    | 32 ++++++++++++++
 topi/python/topi/cuda/batch_matmul.py | 29 ++++++++++++-
 5 files changed, 176 insertions(+), 1 deletion(-)

diff --git a/python/tvm/contrib/cublas.py b/python/tvm/contrib/cublas.py
index d7351a3ad..6edac1ba1 100644
--- a/python/tvm/contrib/cublas.py
+++ b/python/tvm/contrib/cublas.py
@@ -46,3 +46,31 @@ def matmul(lhs, rhs, transa=False, transb=False):
         lambda ins, outs: _intrin.call_packed(
             "tvm.contrib.cublas.matmul",
             ins[0], ins[1], outs[0], transa, transb), name="C")
+
+def batch_matmul(lhs, rhs, transa=False, transb=False):
+    """Create an extern op that computes batch matrix multiplication of lhs and rhs with cuBLAS
+
+    Parameters
+    ----------
+    lhs : Tensor
+        The left matrix operand
+    rhs : Tensor
+        The right matrix operand
+    transa : bool
+        Whether to transpose lhs
+    transb : bool
+        Whether to transpose rhs
+
+    Returns
+    -------
+    C : Tensor
+        The result tensor.
+    """
+    b = lhs.shape[0]
+    n = lhs.shape[2] if transa else lhs.shape[1]
+    m = rhs.shape[1] if transb else rhs.shape[2]
+    return _api.extern(
+        (b, n, m), [lhs, rhs],
+        lambda ins, outs: _intrin.call_packed(
+            "tvm.contrib.cublas.batch_matmul",
+            ins[0], ins[1], outs[0], transa, transb), name="C")
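For reference, a minimal sketch of driving the new Python entry point directly; it mirrors the unit test added later in this patch, with illustrative shapes, and assumes a CUDA-enabled TVM build with cuBLAS available:

    import numpy as np
    import tvm
    from tvm.contrib import cublas

    # 16 independent (1024 x 128) x (128 x 235) products in one extern op.
    A = tvm.placeholder((16, 1024, 128), name='A')
    B = tvm.placeholder((16, 128, 235), name='B')
    C = cublas.batch_matmul(A, B)          # evaluated by cuBLAS at run time
    s = tvm.create_schedule(C.op)
    f = tvm.build(s, [A, B, C], "cuda")

    ctx = tvm.gpu(0)
    a = tvm.nd.array(np.random.uniform(size=(16, 1024, 128)).astype("float32"), ctx)
    b = tvm.nd.array(np.random.uniform(size=(16, 128, 235)).astype("float32"), ctx)
    c = tvm.nd.array(np.zeros((16, 1024, 235), dtype="float32"), ctx)
    f(a, b, c)                             # c now matches np.matmul(a, b)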
diff --git a/src/contrib/cublas/cublas.cc b/src/contrib/cublas/cublas.cc
index 694d19b39..5cee5be28 100644
--- a/src/contrib/cublas/cublas.cc
+++ b/src/contrib/cublas/cublas.cc
@@ -81,6 +81,50 @@ struct CublasDgemmOp {
   }
 };
 
+struct CublasSgemmBatchOp {
+  typedef float TDatatype;
+  cublasHandle_t handle;
+  explicit CublasSgemmBatchOp(cublasHandle_t hdl)
+    : handle(hdl)
+  {}
+  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A,
+                  int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C,
+                  int c_stride, int ldc) {
+    CHECK_CUBLAS_ERROR(cublasSgemmStridedBatched(handle,
+                                                 BooleanToTranspose(ta),
+                                                 BooleanToTranspose(tb),
+                                                 M, N, K,
+                                                 &alpha,
+                                                 A, lda, a_stride,
+                                                 B, ldb, b_stride,
+                                                 &beta,
+                                                 C, ldc, c_stride,
+                                                 batch_size));
+  }
+};
+
+struct CublasDgemmBatchOp {
+  typedef double TDatatype;
+  cublasHandle_t handle;
+  explicit CublasDgemmBatchOp(cublasHandle_t hdl)
+    : handle(hdl)
+  {}
+  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
+                  int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
+                  int c_stride, int ldc) {
+    CHECK_CUBLAS_ERROR(cublasDgemmStridedBatched(handle,
+                                                 BooleanToTranspose(ta),
+                                                 BooleanToTranspose(tb),
+                                                 M, N, K,
+                                                 &alpha,
+                                                 A, lda, a_stride,
+                                                 B, ldb, b_stride,
+                                                 &beta,
+                                                 C, ldc, c_stride,
+                                                 batch_size));
+  }
+};
+
 // matrix multiplication for row major
 TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
@@ -96,5 +140,21 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul")
   else
     CallGemm(args, ret, CublasDgemmOp(entry_ptr->handle));
 });
+
+TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  DLTensor* A = args[0];
+
+  CHECK(TypeMatch(A->dtype, kDLFloat, 32) ||
+        TypeMatch(A->dtype, kDLFloat, 64));
+
+  CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();
+
+  if (TypeMatch(A->dtype, kDLFloat, 32))
+    CallBatchGemm(args, ret, CublasSgemmBatchOp(entry_ptr->handle));
+  else
+    CallBatchGemm(args, ret, CublasDgemmBatchOp(entry_ptr->handle));
+});
+
 }  // namespace contrib
 }  // namespace tvm
diff --git a/tests/python/contrib/test_cublas.py b/tests/python/contrib/test_cublas.py
index de19b602f..f0519ebe9 100644
--- a/tests/python/contrib/test_cublas.py
+++ b/tests/python/contrib/test_cublas.py
@@ -44,6 +44,34 @@ def test_matmul_add():
             c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5)
     verify()
 
+def test_batch_matmul():
+    j = 16
+    n = 1024
+    l = 128
+    m = 235
+    A = tvm.placeholder((j, n, l), name='A')
+    B = tvm.placeholder((j, l, m), name='B')
+    C = cublas.batch_matmul(A, B)
+    s = tvm.create_schedule(C.op)
+
+    def verify(target="cuda"):
+        if not tvm.module.enabled(target):
+            print("skip because %s is not enabled..." % target)
+            return
+        if not tvm.get_global_func("tvm.contrib.cublas.batch_matmul", True):
+            print("skip because extern function is not available")
+            return
+        ctx = tvm.gpu(0)
+        f = tvm.build(s, [A, B, C], target)
+        a = tvm.nd.array(np.random.uniform(size=(j, n, l)).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=(j, l, m)).astype(B.dtype), ctx)
+        c = tvm.nd.array(np.zeros((j, n, m), dtype=C.dtype), ctx)
+        f(a, b, c)
+        tvm.testing.assert_allclose(
+            c.asnumpy(), np.matmul(a.asnumpy(), b.asnumpy()), rtol=1e-5)
+    verify()
+
 if __name__ == "__main__":
     test_matmul_add()
+    test_batch_matmul()
diff --git a/topi/include/topi/contrib/cublas.h b/topi/include/topi/contrib/cublas.h
index 3b47a93e1..4dce9a0bc 100644
--- a/topi/include/topi/contrib/cublas.h
+++ b/topi/include/topi/contrib/cublas.h
@@ -61,6 +61,38 @@ inline Tensor cublas_matmul(const Tensor& lhs,
     }, "C", "", {})[0];
 }
 
+/*!
+* \brief Create an op that multiplies batch matrices
+* lhs and rhs with cuBLAS
+*
+* \param lhs The left matrix operand
+* \param rhs The right matrix operand
+* \param transa Whether to transpose lhs
+* \param transb Whether to transpose rhs
+*
+* \return The output tensor
+*/
+inline Tensor cublas_batch_matmul(const Tensor& lhs,
+                                  const Tensor& rhs,
+                                  bool transa,
+                                  bool transb) {
+  auto b = lhs->shape[0];
+  auto n = transa ? lhs->shape[2] : lhs->shape[1];
+  auto m = transb ? rhs->shape[1] : rhs->shape[2];
+
+  return make_extern(
+    { { b, n, m } }, { lhs->dtype }, { lhs, rhs },
+    [&](Array<Buffer> ins, Array<Buffer> outs) {
+      return call_packed({
+        Expr("tvm.contrib.cublas.batch_matmul"),
+        pack_buffer(ins[0]),
+        pack_buffer(ins[1]),
+        pack_buffer(outs[0]),
+        transa,
+        transb });
+    }, "C", "", {})[0];
+}
+
 }  // namespace contrib
 }  // namespace topi
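The TOPI registration in the next file follows topi.nn.batch_matmul's layout, where the second operand is stored as [batch, N, K]; that is why it calls the extern op with transb=True. A small numpy model of the intended semantics (illustrative values only, not TVM code):

    import numpy as np

    batch, M, K, N = 2, 3, 4, 5
    x = np.random.rand(batch, M, K).astype("float32")   # [batch, M, K]
    y = np.random.rand(batch, N, K).astype("float32")   # [batch, N, K]
    # Contract over K with y transposed within each batch:
    out = np.matmul(x, y.transpose(0, 2, 1))            # [batch, M, N]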
diff --git a/topi/python/topi/cuda/batch_matmul.py b/topi/python/topi/cuda/batch_matmul.py
index b5dd802ad..2d1b93ec0 100644
--- a/topi/python/topi/cuda/batch_matmul.py
+++ b/topi/python/topi/cuda/batch_matmul.py
@@ -18,10 +18,33 @@
 """cuda batch_matmul operators"""
 from __future__ import absolute_import as _abs
 import tvm
-
+from tvm.contrib import cublas
+from topi.nn import batch_matmul, batch_matmul_default
 from .. import generic
 from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
 
+@batch_matmul.register(["cuda", "gpu"])
+def batch_matmul_cuda(x, y):
+    """Computes batch matrix multiplication of `x` and `y` when `x` and `y`
+    are both 3-D tensors batched along their first dimension.
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        3-D with shape [batch, M, K]
+
+    y : tvm.Tensor
+        3-D with shape [batch, N, K]
+
+    Returns
+    -------
+    output : tvm.Tensor
+        3-D with shape [batch, M, N]
+    """
+    target = tvm.target.current_target()
+    if target.target_name == "cuda" and "cublas" in target.libs:
+        return cublas.batch_matmul(x, y, False, True)
+    return batch_matmul_default(x, y)
 
 @generic.schedule_batch_matmul.register(["cuda", "gpu"])
 def schedule_batch_matmul(outs):
@@ -38,6 +61,10 @@ def schedule_batch_matmul(outs):
         s: Schedule
             The computation schedule for the op.
     """
+    target = tvm.target.current_target()
+    if target.target_name == "cuda" and "cublas" in target.libs:
+        return generic.schedule_extern(outs)
+
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
-- 
2.34.1
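With both halves in place, the cuBLAS path is selected purely by the target string. A rough end-to-end sketch, with illustrative shapes, using the API names of the TVM version this patch targets:

    import numpy as np
    import tvm
    import topi

    x = tvm.placeholder((16, 1024, 128), name='x')      # [batch, M, K]
    y = tvm.placeholder((16, 235, 128), name='y')       # [batch, N, K]
    target = tvm.target.create("cuda -libs=cublas")
    with target:
        out = topi.nn.batch_matmul(x, y)                # dispatches to cublas.batch_matmul
        s = topi.generic.schedule_batch_matmul(out)     # reduces to schedule_extern here
    f = tvm.build(s, [x, y, out], target)

Without "-libs=cublas" in the target, the same code falls back to batch_matmul_default and the hand-written CUDA schedule in this file.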