1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
5 #ifndef OPENCV_DNN_SRC_CUDA4DNN_CSL_CUBLAS_HPP
6 #define OPENCV_DNN_SRC_CUDA4DNN_CSL_CUBLAS_HPP
#include "pointer.hpp"

#include <opencv2/core.hpp>

#include <cublas_v2.h>

#include <cstddef>
#include <limits>
#include <memory>
#include <utility>
/* Checks the status returned by a cuBLAS API call.
 *
 * Evaluates `call` once and forwards its status, together with the name of the calling
 * function (CV_Func), the file and the line, to cublas::detail::check.
 */
#define CUDA4DNN_CHECK_CUBLAS(call) \
    ::cv::dnn::cuda4dnn::csl::cublas::detail::check((call), CV_Func, __FILE__, __LINE__)
24 namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cublas {
26 /** @brief exception class for errors thrown by the cuBLAS API */
27 class cuBLASException : public CUDAException {
29 using CUDAException::CUDAException;
33 static void check(cublasStatus_t status, const char* func, const char* file, int line) {
34 auto cublasGetErrorString = [](cublasStatus_t err) {
36 case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
37 case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
38 case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
39 case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
40 case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
41 case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
42 case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
43 case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
44 case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED";
45 case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR";
47 return "UNKNOWN_CUBLAS_ERROR";
50 if (status != CUBLAS_STATUS_SUCCESS)
51 throw cuBLASException(Error::GpuApiCallError, cublasGetErrorString(status), func, file, line);
55 /** noncopyable cuBLAS smart handle
57 * UniqueHandle is a smart non-sharable wrapper for cuBLAS handle which ensures that the handle
58 * is destroyed after use. The handle can be associated with a CUDA stream by specifying the
59 * stream during construction. By default, the handle is associated with the default stream.
63 UniqueHandle() { CUDA4DNN_CHECK_CUBLAS(cublasCreate(&handle)); }
64 UniqueHandle(UniqueHandle&) = delete;
65 UniqueHandle(UniqueHandle&& other) noexcept
66 : stream(std::move(other.stream)), handle{ other.handle } {
67 other.handle = nullptr;
70 UniqueHandle(Stream strm) : stream(std::move(strm)) {
71 CUDA4DNN_CHECK_CUBLAS(cublasCreate(&handle));
73 CUDA4DNN_CHECK_CUBLAS(cublasSetStream(handle, stream.get()));
75 /* cublasDestroy won't throw if a valid handle is passed */
76 CUDA4DNN_CHECK_CUBLAS(cublasDestroy(handle));
81 ~UniqueHandle() noexcept {
82 if (handle != nullptr) {
83 /* cublasDestroy won't throw if a valid handle is passed */
84 CUDA4DNN_CHECK_CUBLAS(cublasDestroy(handle));
88 UniqueHandle& operator=(const UniqueHandle&) = delete;
89 UniqueHandle& operator=(UniqueHandle&& other) noexcept {
90 stream = std::move(other.stream);
91 handle = other.handle;
92 other.handle = nullptr;
96 /** @brief returns the raw cuBLAS handle */
97 cublasHandle_t get() const noexcept { return handle; }
101 cublasHandle_t handle;
104 /** @brief sharable cuBLAS smart handle
106 * Handle is a smart sharable wrapper for cuBLAS handle which ensures that the handle
107 * is destroyed after all references to the handle are destroyed. The handle can be
108 * associated with a CUDA stream by specifying the stream during construction. By default,
109 * the handle is associated with the default stream.
111 * @note Moving a Handle object to another invalidates the former
115 Handle() : handle(std::make_shared<UniqueHandle>()) { }
116 Handle(const Handle&) = default;
117 Handle(Handle&&) = default;
118 Handle(Stream strm) : handle(std::make_shared<UniqueHandle>(std::move(strm))) { }
120 Handle& operator=(const Handle&) = default;
121 Handle& operator=(Handle&&) = default;
123 /** returns true if the handle is valid */
124 explicit operator bool() const noexcept { return static_cast<bool>(handle); }
126 cublasHandle_t get() const noexcept {
128 return handle->get();
132 std::shared_ptr<UniqueHandle> handle;
135 /** @brief GEMM for colummn-major matrices
137 * \f$ C = \alpha AB + \beta C \f$
139 * @tparam T matrix element type (must be `half` or `float`)
141 * @param handle valid cuBLAS Handle
142 * @param transa use transposed matrix of A for computation
143 * @param transb use transposed matrix of B for computation
144 * @param rows_c number of rows in C
145 * @param cols_c number of columns in C
146 * @param common_dim common dimension of A (or trans A) and B (or trans B)
147 * @param alpha scale factor for AB
148 * @param[in] A pointer to column-major matrix A in device memory
149 * @param lda leading dimension of matrix A
150 * @param[in] B pointer to column-major matrix B in device memory
151 * @param ldb leading dimension of matrix B
152 * @param beta scale factor for C
153 * @param[in,out] C pointer to column-major matrix C in device memory
154 * @param ldc leading dimension of matrix C
156 * Exception Guarantee: Basic
159 void gemm(const Handle& handle,
160 bool transa, bool transb,
161 std::size_t rows_c, std::size_t cols_c, std::size_t common_dim,
162 T alpha, const DevicePtr<const T> A, std::size_t lda,
163 const DevicePtr<const T> B, std::size_t ldb,
164 T beta, const DevicePtr<T> C, std::size_t ldc);
167 void gemm<half>(const Handle& handle,
168 bool transa, bool transb,
169 std::size_t rows_c, std::size_t cols_c, std::size_t common_dim,
170 half alpha, const DevicePtr<const half> A, std::size_t lda,
171 const DevicePtr<const half> B, std::size_t ldb,
172 half beta, const DevicePtr<half> C, std::size_t ldc)
176 auto opa = transa ? CUBLAS_OP_T : CUBLAS_OP_N,
177 opb = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
178 int irows_c = static_cast<int>(rows_c),
179 icols_c = static_cast<int>(cols_c),
180 icommon_dim = static_cast<int>(common_dim),
181 ilda = static_cast<int>(lda),
182 ildb = static_cast<int>(ldb),
183 ildc = static_cast<int>(ldc);
185 CUDA4DNN_CHECK_CUBLAS(
189 irows_c, icols_c, icommon_dim,
190 &alpha, A.get(), ilda,
198 void gemm<float>(const Handle& handle,
199 bool transa, bool transb,
200 std::size_t rows_c, std::size_t cols_c, std::size_t common_dim,
201 float alpha, const DevicePtr<const float> A, std::size_t lda,
202 const DevicePtr<const float> B, std::size_t ldb,
203 float beta, const DevicePtr<float> C, std::size_t ldc)
207 auto opa = transa ? CUBLAS_OP_T : CUBLAS_OP_N,
208 opb = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
209 int irows_c = static_cast<int>(rows_c),
210 icols_c = static_cast<int>(cols_c),
211 icommon_dim = static_cast<int>(common_dim),
212 ilda = static_cast<int>(lda),
213 ildb = static_cast<int>(ldb),
214 ildc = static_cast<int>(ldc);
216 CUDA4DNN_CHECK_CUBLAS(
220 irows_c, icols_c, icommon_dim,
221 &alpha, A.get(), ilda,
228 }}}}} /* namespace cv::dnn::cuda4dnn::csl::cublas */
230 #endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_CUBLAS_HPP */