[TOPI][CUDA] Support cuBLAS BatchMatMul (#3936)
authorJon Soifer <soiferj@gmail.com>
Thu, 12 Sep 2019 20:04:45 +0000 (13:04 -0700)
committerHaichen Shen <shenhaichen@gmail.com>
Thu, 12 Sep 2019 20:04:44 +0000 (13:04 -0700)
* Support cuBLAS BatchMatMul

* Add test and check target name

python/tvm/contrib/cublas.py
src/contrib/cublas/cublas.cc
tests/python/contrib/test_cublas.py
topi/include/topi/contrib/cublas.h
topi/python/topi/cuda/batch_matmul.py

index d7351a3..6edac1b 100644 (file)
@@ -46,3 +46,31 @@ def matmul(lhs, rhs, transa=False, transb=False):
         lambda ins, outs: _intrin.call_packed(
             "tvm.contrib.cublas.matmul",
             ins[0], ins[1], outs[0], transa, transb), name="C")
+
+def batch_matmul(lhs, rhs, transa=False, transb=False):
+    """Create an extern op that compute batch matrix mult of A and rhs with cuBLAS
+
+    Parameters
+    ----------
+    lhs : Tensor
+        The left matrix operand
+    rhs : Tensor
+        The right matrix operand
+    transa : bool
+        Whether transpose lhs
+    transb : bool
+        Whether transpose rhs
+
+    Returns
+    -------
+    C : Tensor
+        The result tensor.
+    """
+    b = lhs.shape[0]
+    n = lhs.shape[2] if transa else lhs.shape[1]
+    m = rhs.shape[1] if transb else rhs.shape[2]
+    return _api.extern(
+        (b, n, m), [lhs, rhs],
+        lambda ins, outs: _intrin.call_packed(
+            "tvm.contrib.cublas.batch_matmul",
+            ins[0], ins[1], outs[0], transa, transb), name="C")
index 694d19b..5cee5be 100644 (file)
@@ -81,6 +81,50 @@ struct CublasDgemmOp {
   }
 };
 
+struct CublasSgemmBatchOp {
+  typedef float TDatatype;
+  cublasHandle_t handle;
+  explicit CublasSgemmBatchOp(cublasHandle_t hdl)
+    : handle(hdl)
+    {}
+  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A,
+                  int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C,
+                  int c_stride, int ldc) {
+    CHECK_CUBLAS_ERROR(cublasSgemmStridedBatched(handle,
+                                                 BooleanToTranspose(ta),
+                                                 BooleanToTranspose(tb),
+                                                 M, N, K,
+                                                 &alpha,
+                                                 A, lda, a_stride,
+                                                 B, ldb, b_stride,
+                                                 &beta,
+                                                 C, ldc, c_stride,
+                                                 batch_size));
+  }
+};
+
+struct CublasDgemmBatchOp {
+  typedef double TDatatype;
+  cublasHandle_t handle;
+  explicit CublasDgemmBatchOp(cublasHandle_t hdl)
+    : handle(hdl)
+    {}
+  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
+                  int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
+                  int c_stride, int ldc) {
+    CHECK_CUBLAS_ERROR(cublasDgemmStridedBatched(handle,
+                                                 BooleanToTranspose(ta),
+                                                 BooleanToTranspose(tb),
+                                                 M, N, K,
+                                                 &alpha,
+                                                 A, lda, a_stride,
+                                                 B, ldb, b_stride,
+                                                 &beta,
+                                                 C, ldc, c_stride,
+                                                 batch_size));
+  }
+};
+
 // matrix multiplication for row major
 TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
@@ -96,5 +140,21 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul")
     else
       CallGemm(args, ret, CublasDgemmOp(entry_ptr->handle));
 });
+
+TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    DLTensor* A = args[0];
+
+    CHECK(TypeMatch(A->dtype, kDLFloat, 32) ||
+          TypeMatch(A->dtype, kDLFloat, 64));
+
+    CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();
+
+    if (TypeMatch(A->dtype, kDLFloat, 32))
+      CallBatchGemm(args, ret, CublasSgemmBatchOp(entry_ptr->handle));
+    else
+      CallBatchGemm(args, ret, CublasDgemmBatchOp(entry_ptr->handle));
+});
+
 }  // namespace contrib
 }  // namespace tvm
index de19b60..f0519eb 100644 (file)
@@ -44,6 +44,34 @@ def test_matmul_add():
             c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5)
     verify()
 
+def test_batch_matmul():
+    j = 16
+    n = 1024
+    l = 128
+    m = 235
+    A = tvm.placeholder((j, n, l), name='A')
+    B = tvm.placeholder((j, l, m), name='B')
+    C = cublas.batch_matmul(A, B)
+    s = tvm.create_schedule(C.op)
+
+    def verify(target="cuda"):
+        if not tvm.module.enabled(target):
+            print("skip because %s is not enabled..." % target)
+            return
+        if not tvm.get_global_func("tvm.contrib.cublas.matmul", True):
+            print("skip because extern function is not available")
+            return
+        ctx = tvm.gpu(0)
+        f = tvm.build(s, [A, B, C], target)
+        a = tvm.nd.array(np.random.uniform(size=(j, n, l)).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=(j, l, m)).astype(B.dtype), ctx)
+        c = tvm.nd.array(np.zeros((j, n, m), dtype=C.dtype), ctx)
+        f(a, b, c)
+        tvm.testing.assert_allclose(
+            c.asnumpy(), np.matmul(a.asnumpy(), b.asnumpy()), rtol=1e-5)
+    verify()
+
 
 if __name__ == "__main__":
     test_matmul_add()
+    test_batch_matmul()
index 3b47a93..4dce9a0 100644 (file)
@@ -61,6 +61,38 @@ inline Tensor cublas_matmul(const Tensor& lhs,
     }, "C", "", {})[0];
 }
 
+/*!
+* \brief Create an op that multiplies batch matrices 
+*        lhs and rhs with cuBLAS
+*
+* \param lhs The left matrix operand
+* \param rhs The right matrix operand
+* \param transa Whether to transpose lhs
+* \param transb Whether to transpose rhs
+*
+* \return The output tensor
+*/
+inline Tensor cublas_batch_matmul(const Tensor& lhs,
+                                  const Tensor& rhs,
+                                  bool transa,
+                                  bool transb) {
+  auto b = lhs->shape[0];
+  auto n = transa ? lhs->shape[2] : lhs->shape[1];
+  auto m = transb ? rhs->shape[1] : rhs->shape[2];
+
+  return make_extern(
+    { { b, n, m } }, { lhs->dtype }, { lhs, rhs },
+    [&](Array<Buffer> ins, Array<Buffer> outs) {
+      return call_packed({
+        Expr("tvm.contrib.cublas.batch_matmul"),
+        pack_buffer(ins[0]),
+        pack_buffer(ins[1]),
+        pack_buffer(outs[0]),
+        transa,
+        transb });
+    }, "C", "", {})[0];
+}
+
 }  // namespace contrib
 }  // namespace topi
 
index b5dd802..2d1b93e 100644 (file)
 """cuda batch_matmul operators"""
 from __future__ import absolute_import as _abs
 import tvm
-
+from tvm.contrib import cublas
+from topi.nn import batch_matmul, batch_matmul_default
 from .. import generic
 from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
 
+@batch_matmul.register(["cuda", "gpu"])
+def batch_matmul_cuda(x, y):
+    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
+    data in batch.
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        3-D with shape [batch, M, K]
+
+    y : tvm.Tensor
+        3-D with shape [batch, N, K]
+
+    Returns
+    -------
+    output : tvm.Tensor
+        3-D with shape [batch, M, N]
+    """
+    target = tvm.target.current_target()
+    if target.target_name == "cuda" and "cublas" in target.libs:
+        return cublas.batch_matmul(x, y, False, True)
+    return batch_matmul_default(x, y)
 
 @generic.schedule_batch_matmul.register(["cuda", "gpu"])
 def schedule_batch_matmul(outs):
@@ -38,6 +61,10 @@ def schedule_batch_matmul(outs):
     s: Schedule
         The computation schedule for the op.
     """
+    target = tvm.target.current_target()
+    if target.target_name == "cuda" and "cublas" in target.libs:
+        return generic.schedule_extern(outs)
+
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])