Merge pull request #14827 from YashasSamaga:cuda4dnn-csl-low
[platform/upstream/opencv.git] / modules / dnn / src / cuda4dnn / csl / cublas.hpp
1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4
5 #ifndef OPENCV_DNN_SRC_CUDA4DNN_CSL_CUBLAS_HPP
6 #define OPENCV_DNN_SRC_CUDA4DNN_CSL_CUBLAS_HPP
7
8 #include "error.hpp"
9 #include "stream.hpp"
10 #include "pointer.hpp"
11 #include "fp16.hpp"
12
13 #include <opencv2/core.hpp>
14
15 #include <cublas_v2.h>
16
17 #include <cstddef>
18 #include <memory>
19 #include <utility>
20
21 #define CUDA4DNN_CHECK_CUBLAS(call) \
22     ::cv::dnn::cuda4dnn::csl::cublas::detail::check((call), CV_Func, __FILE__, __LINE__)
23
24 namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cublas {
25
26     /** @brief exception class for errors thrown by the cuBLAS API */
27     class cuBLASException : public CUDAException {
28     public:
29         using CUDAException::CUDAException;
30     };
31
32     namespace detail {
33         static void check(cublasStatus_t status, const char* func, const char* file, int line) {
34             auto cublasGetErrorString = [](cublasStatus_t err) {
35                 switch (err) {
36                 case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
37                 case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
38                 case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
39                 case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
40                 case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
41                 case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
42                 case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
43                 case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
44                 case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED";
45                 case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR";
46                 }
47                 return "UNKNOWN_CUBLAS_ERROR";
48             };
49
50             if (status != CUBLAS_STATUS_SUCCESS)
51                 throw cuBLASException(Error::GpuApiCallError, cublasGetErrorString(status), func, file, line);
52         }
53     }
54
55     /** noncopyable cuBLAS smart handle
56      *
57      * UniqueHandle is a smart non-sharable wrapper for cuBLAS handle which ensures that the handle
58      * is destroyed after use. The handle can be associated with a CUDA stream by specifying the
59      * stream during construction. By default, the handle is associated with the default stream.
60      */
61     class UniqueHandle {
62     public:
63         UniqueHandle() { CUDA4DNN_CHECK_CUBLAS(cublasCreate(&handle)); }
64         UniqueHandle(UniqueHandle&) = delete;
65         UniqueHandle(UniqueHandle&& other) noexcept
66             : stream(std::move(other.stream)), handle{ other.handle } {
67             other.handle = nullptr;
68         }
69
70         UniqueHandle(Stream strm) : stream(std::move(strm)) {
71             CUDA4DNN_CHECK_CUBLAS(cublasCreate(&handle));
72             try {
73                 CUDA4DNN_CHECK_CUBLAS(cublasSetStream(handle, stream.get()));
74             } catch (...) {
75                 /* cublasDestroy won't throw if a valid handle is passed */
76                 CUDA4DNN_CHECK_CUBLAS(cublasDestroy(handle));
77                 throw;
78             }
79         }
80
81         ~UniqueHandle() noexcept {
82             if (handle != nullptr) {
83                 /* cublasDestroy won't throw if a valid handle is passed */
84                 CUDA4DNN_CHECK_CUBLAS(cublasDestroy(handle));
85             }
86         }
87
88         UniqueHandle& operator=(const UniqueHandle&) = delete;
89         UniqueHandle& operator=(UniqueHandle&& other) noexcept {
90             stream = std::move(other.stream);
91             handle = other.handle;
92             other.handle = nullptr;
93             return *this;
94         }
95
96         /** @brief returns the raw cuBLAS handle */
97         cublasHandle_t get() const noexcept { return handle; }
98
99     private:
100         Stream stream;
101         cublasHandle_t handle;
102     };
103
104     /** @brief sharable cuBLAS smart handle
105      *
106      * Handle is a smart sharable wrapper for cuBLAS handle which ensures that the handle
107      * is destroyed after all references to the handle are destroyed. The handle can be
108      * associated with a CUDA stream by specifying the stream during construction. By default,
109      * the handle is associated with the default stream.
110      *
111      * @note Moving a Handle object to another invalidates the former
112      */
113     class Handle {
114     public:
115         Handle() : handle(std::make_shared<UniqueHandle>()) { }
116         Handle(const Handle&) = default;
117         Handle(Handle&&) = default;
118         Handle(Stream strm) : handle(std::make_shared<UniqueHandle>(std::move(strm))) { }
119
120         Handle& operator=(const Handle&) = default;
121         Handle& operator=(Handle&&) = default;
122
123         /** returns true if the handle is valid */
124         explicit operator bool() const noexcept { return static_cast<bool>(handle); }
125
126         cublasHandle_t get() const noexcept {
127             CV_Assert(handle);
128             return handle->get();
129         }
130
131     private:
132         std::shared_ptr<UniqueHandle> handle;
133     };
134
135     /** @brief GEMM for colummn-major matrices
136      *
137      * \f$ C = \alpha AB + \beta C \f$
138      *
139      * @tparam          T           matrix element type (must be `half` or `float`)
140      *
141      * @param           handle      valid cuBLAS Handle
142      * @param           transa      use transposed matrix of A for computation
143      * @param           transb      use transposed matrix of B for computation
144      * @param           rows_c      number of rows in C
145      * @param           cols_c      number of columns in C
146      * @param           common_dim  common dimension of A (or trans A) and B (or trans B)
147      * @param           alpha       scale factor for AB
148      * @param[in]       A           pointer to column-major matrix A in device memory
149      * @param           lda         leading dimension of matrix A
150      * @param[in]       B           pointer to column-major matrix B in device memory
151      * @param           ldb         leading dimension of matrix B
152      * @param           beta        scale factor for C
153      * @param[in,out]   C           pointer to column-major matrix C in device memory
154      * @param           ldc         leading dimension of matrix C
155      *
156      * Exception Guarantee: Basic
157      */
158     template <class T>
159     void gemm(const Handle& handle,
160         bool transa, bool transb,
161         std::size_t rows_c, std::size_t cols_c, std::size_t common_dim,
162         T alpha, const DevicePtr<const T> A, std::size_t lda,
163         const DevicePtr<const T> B, std::size_t ldb,
164         T beta, const DevicePtr<T> C, std::size_t ldc);
165
166     template <> inline
167     void gemm<half>(const Handle& handle,
168         bool transa, bool transb,
169         std::size_t rows_c, std::size_t cols_c, std::size_t common_dim,
170         half alpha, const DevicePtr<const half> A, std::size_t lda,
171         const DevicePtr<const half> B, std::size_t ldb,
172         half beta, const DevicePtr<half> C, std::size_t ldc)
173     {
174         CV_Assert(handle);
175
176         auto opa = transa ? CUBLAS_OP_T : CUBLAS_OP_N,
177             opb = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
178         int irows_c = static_cast<int>(rows_c),
179             icols_c = static_cast<int>(cols_c),
180             icommon_dim = static_cast<int>(common_dim),
181             ilda = static_cast<int>(lda),
182             ildb = static_cast<int>(ldb),
183             ildc = static_cast<int>(ldc);
184
185         CUDA4DNN_CHECK_CUBLAS(
186             cublasHgemm(
187                 handle.get(),
188                 opa, opb,
189                 irows_c, icols_c, icommon_dim,
190                 &alpha, A.get(), ilda,
191                 B.get(), ildb,
192                 &beta, C.get(), ildc
193             )
194         );
195     }
196
197     template <> inline
198     void gemm<float>(const Handle& handle,
199         bool transa, bool transb,
200         std::size_t rows_c, std::size_t cols_c, std::size_t common_dim,
201         float alpha, const DevicePtr<const float> A, std::size_t lda,
202         const DevicePtr<const float> B, std::size_t ldb,
203         float beta, const DevicePtr<float> C, std::size_t ldc)
204     {
205         CV_Assert(handle);
206
207         auto opa = transa ? CUBLAS_OP_T : CUBLAS_OP_N,
208             opb = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
209         int irows_c = static_cast<int>(rows_c),
210             icols_c = static_cast<int>(cols_c),
211             icommon_dim = static_cast<int>(common_dim),
212             ilda = static_cast<int>(lda),
213             ildb = static_cast<int>(ldb),
214             ildc = static_cast<int>(ldc);
215
216         CUDA4DNN_CHECK_CUBLAS(
217             cublasSgemm(
218                 handle.get(),
219                 opa, opb,
220                 irows_c, icols_c, icommon_dim,
221                 &alpha, A.get(), ilda,
222                 B.get(), ildb,
223                 &beta, C.get(), ildc
224             )
225         );
226     }
227
228 }}}}} /* namespace cv::dnn::cuda4dnn::csl::cublas */
229
230 #endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_CUBLAS_HPP */