lambda ins, outs: _intrin.call_packed(
"tvm.contrib.cublas.matmul",
ins[0], ins[1], outs[0], transa, transb), name="C")
+
+def batch_matmul(lhs, rhs, transa=False, transb=False):
+ """Create an extern op that compute batch matrix mult of A and rhs with cuBLAS
+
+ Parameters
+ ----------
+ lhs : Tensor
+ The left matrix operand
+ rhs : Tensor
+ The right matrix operand
+ transa : bool
+ Whether to transpose lhs
+ transb : bool
+ Whether to transpose rhs
+
+ Returns
+ -------
+ C : Tensor
+ The result tensor.
+ """
+ b = lhs.shape[0]
+ n = lhs.shape[2] if transa else lhs.shape[1]
+ m = rhs.shape[1] if transb else rhs.shape[2]
+ return _api.extern(
+ (b, n, m), [lhs, rhs],
+ lambda ins, outs: _intrin.call_packed(
+ "tvm.contrib.cublas.batch_matmul",
+ ins[0], ins[1], outs[0], transa, transb), name="C")
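For reference, a minimal usage sketch of the new helper (it mirrors the unit
test added below and assumes a CUDA-enabled TVM build with cuBLAS):

    import tvm
    from tvm.contrib import cublas

    A = tvm.placeholder((16, 1024, 128), name='A')  # [batch, n, k]
    B = tvm.placeholder((16, 128, 235), name='B')   # [batch, k, m]
    C = cublas.batch_matmul(A, B)                   # extern op, shape (16, 1024, 235)
    s = tvm.create_schedule(C.op)
    f = tvm.build(s, [A, B, C], "cuda")             # cuBLAS does the work at run time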
}
};
+struct CublasSgemmBatchOp {
+ typedef float TDatatype;
+ cublasHandle_t handle;
+ explicit CublasSgemmBatchOp(cublasHandle_t hdl)
+ : handle(hdl)
+ {}
+ void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A,
+ int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C,
+ int c_stride, int ldc) {
+ CHECK_CUBLAS_ERROR(cublasSgemmStridedBatched(handle,
+ BooleanToTranspose(ta),
+ BooleanToTranspose(tb),
+ M, N, K,
+ &alpha,
+ A, lda, a_stride,
+ B, ldb, b_stride,
+ &beta,
+ C, ldc, c_stride,
+ batch_size));
+ }
+};
+
+struct CublasDgemmBatchOp {
+ typedef double TDatatype;
+ cublasHandle_t handle;
+ explicit CublasDgemmBatchOp(cublasHandle_t hdl)
+ : handle(hdl)
+ {}
+ void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
+ int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
+ int c_stride, int ldc) {
+ CHECK_CUBLAS_ERROR(cublasDgemmStridedBatched(handle,
+ BooleanToTranspose(ta),
+ BooleanToTranspose(tb),
+ M, N, K,
+ &alpha,
+ A, lda, a_stride,
+ B, ldb, b_stride,
+ &beta,
+ C, ldc, c_stride,
+ batch_size));
+ }
+};
+
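The semantics of cublas{S,D}gemmStridedBatched, sketched in NumPy purely for
illustration (not the real binding): each batch index advances A, B, and C by
their per-matrix strides, which for contiguous 3-D tensors is simply the i-th
matrix. The row-major vs. column-major bookkeeping lives in the CallBatchGemm
helper, which is outside this hunk.

    import numpy as np

    def strided_batched_gemm(ta, tb, alpha, A, B, beta, C):
        """Reference semantics of a strided-batched GEMM on 3-D arrays."""
        for i in range(A.shape[0]):
            a = A[i].T if ta else A[i]
            b = B[i].T if tb else B[i]
            C[i] = alpha * (a @ b) + beta * C[i]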
// matrix multiplication for row major
TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) {
else
CallGemm(args, ret, CublasDgemmOp(entry_ptr->handle));
});
+
+TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+ DLTensor* A = args[0];
+
+ CHECK(TypeMatch(A->dtype, kDLFloat, 32) ||
+ TypeMatch(A->dtype, kDLFloat, 64));
+
+ CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();
+
+ if (TypeMatch(A->dtype, kDLFloat, 32))
+ CallBatchGemm(args, ret, CublasSgemmBatchOp(entry_ptr->handle));
+ else
+ CallBatchGemm(args, ret, CublasDgemmBatchOp(entry_ptr->handle));
+});
+
} // namespace contrib
} // namespace tvm
c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5)
verify()
+
+def test_batch_matmul():
+ j = 16
+ n = 1024
+ l = 128
+ m = 235
+ A = tvm.placeholder((j, n, l), name='A')
+ B = tvm.placeholder((j, l, m), name='B')
+ C = cublas.batch_matmul(A, B)
+ s = tvm.create_schedule(C.op)
+
+ def verify(target="cuda"):
+ if not tvm.module.enabled(target):
+ print("skip because %s is not enabled..." % target)
+ return
+ if not tvm.get_global_func("tvm.contrib.cublas.batch_matmul", True):
+ print("skip because extern function is not available")
+ return
+ ctx = tvm.gpu(0)
+ f = tvm.build(s, [A, B, C], target)
+ a = tvm.nd.array(np.random.uniform(size=(j, n, l)).astype(A.dtype), ctx)
+ b = tvm.nd.array(np.random.uniform(size=(j, l, m)).astype(B.dtype), ctx)
+ c = tvm.nd.array(np.zeros((j, n, m), dtype=C.dtype), ctx)
+ f(a, b, c)
+ tvm.testing.assert_allclose(
+ c.asnumpy(), np.matmul(a.asnumpy(), b.asnumpy()), rtol=1e-5)
+ verify()
+
if __name__ == "__main__":
test_matmul_add()
+ test_batch_matmul()
}, "C", "", {})[0];
}
+/*!
+ * \brief Create an op that multiplies batch matrices
+ * lhs and rhs with cuBLAS
+ *
+ * \param lhs The left matrix operand
+ * \param rhs The right matrix operand
+ * \param transa Whether to transpose lhs
+ * \param transb Whether to transpose rhs
+ *
+ * \return The output tensor
+ */
+inline Tensor cublas_batch_matmul(const Tensor& lhs,
+ const Tensor& rhs,
+ bool transa,
+ bool transb) {
+ auto b = lhs->shape[0];
+ auto n = transa ? lhs->shape[2] : lhs->shape[1];
+ auto m = transb ? rhs->shape[1] : rhs->shape[2];
+
+ return make_extern(
+ { { b, n, m } }, { lhs->dtype }, { lhs, rhs },
+ [&](Array<Buffer> ins, Array<Buffer> outs) {
+ return call_packed({
+ Expr("tvm.contrib.cublas.batch_matmul"),
+ pack_buffer(ins[0]),
+ pack_buffer(ins[1]),
+ pack_buffer(outs[0]),
+ transa,
+ transb });
+ }, "C", "", {})[0];
+}
+
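The shape selection above, restated as a small illustrative Python helper:

    def batch_matmul_out_shape(lhs_shape, rhs_shape, transa=False, transb=False):
        # Mirrors the n/m selection in cublas_batch_matmul above.
        b = lhs_shape[0]
        n = lhs_shape[2] if transa else lhs_shape[1]
        m = rhs_shape[1] if transb else rhs_shape[2]
        return (b, n, m)

    # e.g. batch_matmul_out_shape((8, 3, 4), (8, 5, 4), transb=True) -> (8, 3, 5)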
} // namespace contrib
} // namespace topi
"""cuda batch_matmul operators"""
from __future__ import absolute_import as _abs
import tvm
-
+from tvm.contrib import cublas
+from topi.nn import batch_matmul, batch_matmul_default
from .. import generic
from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
+
+@batch_matmul.register(["cuda", "gpu"])
+def batch_matmul_cuda(x, y):
+ """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
+ data in batch.
+
+ Parameters
+ ----------
+ x : tvm.Tensor
+ 3-D with shape [batch, M, K]
+
+ y : tvm.Tensor
+ 3-D with shape [batch, N, K]
+
+ Returns
+ -------
+ output : tvm.Tensor
+ 3-D with shape [batch, M, N]
+ """
+ target = tvm.target.current_target()
+ if target.target_name == "cuda" and "cublas" in target.libs:
+ return cublas.batch_matmul(x, y, False, True)
+ return batch_matmul_default(x, y)
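A sketch of how the cuBLAS path gets selected (written against the old-style
target API this patch uses; dispatch assumes the registration above has been
imported):

    import tvm
    import topi

    x = tvm.placeholder((8, 32, 64), name='x')  # [batch, M, K]
    y = tvm.placeholder((8, 16, 64), name='y')  # [batch, N, K]

    with tvm.target.create("cuda -libs=cublas"):
        out = topi.nn.batch_matmul(x, y)  # -> cublas.batch_matmul(x, y, False, True)

Note that transb=True because `y` is laid out [batch, N, K]; without the
cublas lib flag, the call falls back to batch_matmul_default.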
@generic.schedule_batch_matmul.register(["cuda", "gpu"])
def schedule_batch_matmul(outs):
s: Schedule
The computation schedule for the op.
"""
+ target = tvm.target.current_target()
+ if target.target_name == "cuda" and "cublas" in target.libs:
+ return generic.schedule_extern(outs)
+
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])