// SPDX-License-Identifier: Apache-2.0
/**
 * Copyright (C) 2020 Jijoong Moon <jijoong.moon@samsung.com>
 *
 * @file   blas_interface.cpp
 * @see    https://github.com/nnstreamer/nntrainer
 * @author Jijoong Moon <jijoong.moon@samsung.com>
 * @bug    No known bugs except for NYI items
 * @brief  BLAS interface: cblas/NEON-accelerated paths with raw scalar
 *         fallbacks
 */
#include <blas_interface.h>
#include <nntrainer_error.h>

#include <cmath>
#include <cstdlib>

#if (defined USE__FP16 && USE_NEON)
#include <blas_neon.h>
#endif
#define sgemv_loop(ci, cj, cM, cN)           \
  do {                                       \
    double y0;                               \
    unsigned int i, j;                       \
    for (ci = 0; ci != cM; ci++) {           \
      y0 = Y[ci * incy] * beta;              \
      for (cj = 0; cj != cN; cj++)           \
        y0 += A[i + j * lda] * X[cj * incx]; \
      Y[ci * incy] = y0;                     \
    }                                        \
  } while (0);
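// NOTE: sgemv_loop hard-codes the matrix index as A[i + j * lda], so the
// ci/cj arguments must be the literal identifiers i and j declared inside
// the macro; calling it as (i, j, N, M) walks A transposed, while
// (j, i, M, N) walks it non-transposed.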
#ifdef ENABLE_FP16

#define sgemv_loop_fp16(ci, cj, cM, cN)             \
  do {                                              \
    _FP16 y0;                                       \
    unsigned int i, j;                              \
    for (ci = 0; ci != cM; ci++) {                  \
      y0 = Y[ci * incy] * static_cast<_FP16>(beta); \
      for (cj = 0; cj != cN; cj++)                  \
        y0 += A[i + j * lda] * X[cj * incx];        \
      Y[ci * incy] = y0;                            \
    }                                               \
  } while (0);

#define saxpy_loop_fp16()                                                  \
  do {                                                                     \
    unsigned int i;                                                        \
    for (i = 0; i < N; ++i)                                                \
      Y[i * incY] = Y[i * incY] + static_cast<_FP16>(alpha) * X[i * incX]; \
  } while (0);

#endif // ENABLE_FP16
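/*
 * FP16 variants. The functions below are compiled only when fp16 support is
 * enabled (ENABLE_FP16). On Android builds with NEON (USE__FP16 && USE_NEON)
 * they dispatch to the hand-written kernels in blas_neon.h when stride and
 * size constraints allow, and fall back to the scalar loops otherwise.
 */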
namespace nntrainer {

#ifdef ENABLE_FP16

static void saxpy_FP16(const unsigned int N, const float alpha, const _FP16 *X,
                       const int incX, _FP16 *Y, const int incY) {
  if (incX < 0 or incY < 0)
    throw std::invalid_argument(
      "Error: negative inc not supported without cblas");

#if (defined USE__FP16 && USE_NEON)
  // USE__FP16 is defined when the platform is Android
  if (incX == 1 && incY == 1) {
    nntrainer::neon::saxpy_neon_fp16(N, alpha, X, Y);
  } else {
    saxpy_loop_fp16();
  }
#else
  saxpy_loop_fp16();
#endif
}
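/**
 * @brief fp16 GEMV: Y = alpha * op(A) * X + beta * Y for row-major A.
 * The NEON kernels require unit strides and N to be a multiple of 8;
 * otherwise the scalar loop macro is used, with its arguments swapped to
 * select the transposed or non-transposed traversal.
 */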
static void sgemv_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
                       const unsigned int M, const unsigned int N,
                       const float alpha, const _FP16 *A,
                       const unsigned int lda, const _FP16 *X, const int incX,
                       const float beta, _FP16 *Y, const int incY) {

  unsigned int incy = abs(incY);
  unsigned int incx = abs(incX);

  if (TransA == CblasTrans) {
#if (defined USE__FP16 && USE_NEON)
    if (incX == 1 && incY == 1 && (N % 16 == 0 || N % 8 == 0)) {
      nntrainer::neon::sgemv_transpose_neon_fp16(A, X, Y, M, N, alpha, beta);
    } else {
      sgemv_loop_fp16(i, j, N, M);
    }
#else
    sgemv_loop_fp16(i, j, N, M);
#endif
  } else {
#if (defined USE__FP16 && USE_NEON)
    if (incX == 1 && incY == 1 && (N % 16 == 0 || N % 8 == 0)) {
      nntrainer::neon::sgemv_neon_fp16(A, X, Y, M, N, alpha, beta);
    } else {
      sgemv_loop_fp16(j, i, M, N);
    }
#else
    sgemv_loop_fp16(j, i, M, N);
#endif
  }
}
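/**
 * @brief fp16 dot product: returns the sum over i of X[i * incX] *
 * Y[i * incY], using the NEON kernel when both strides are 1.
 */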
static _FP16 sdot_FP16(const unsigned int N, const _FP16 *X,
                       const unsigned int incX, const _FP16 *Y,
                       const unsigned int incY) {
  // incX/incY are unsigned here, so the negative-stride check used in the
  // signed-stride functions is unnecessary

  _FP16 ret = 0;

#if (defined USE__FP16 && USE_NEON)
  if (incX == 1 && incY == 1) {
    ret = nntrainer::neon::sdot_neon_fp16(N, X, Y);
  } else {
    for (unsigned int i = 0; i < N; ++i) {
      ret += X[i * incX] * Y[i * incY];
    }
  }
#else
  for (unsigned int i = 0; i < N; ++i) {
    ret += X[i * incX] * Y[i * incY];
  }
#endif
  return ret;
}
static void scopy_FP16(const unsigned int N, const _FP16 *X, const int incX,
                       _FP16 *Y, const int incY) {
  unsigned int incy = abs(incY);
  unsigned int incx = abs(incX);

  for (unsigned int i = 0; i < N; ++i)
    Y[i * incy] = X[i * incx];
}
void sscal(const unsigned int N, const float alpha, _FP16 *X, const int incX) {
  unsigned int incx = abs(incX);

  for (unsigned int i = 0; i < N; ++i)
    X[i * incx] = static_cast<_FP16>(alpha) * X[i * incx];
}
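/**
 * @brief fp16 Euclidean norm of X. The sum of squares is accumulated in
 * _FP16, so very large N or magnitudes can overflow half precision.
 */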
static _FP16 snrm2_FP16(const unsigned int N, const _FP16 *X, const int incX) {
  unsigned int incx = abs(incX);
  _FP16 sum = 0;
  _FP16 tmp;
#if (defined USE__FP16 && USE_NEON)
  if (incx == 1) {
    sum = nntrainer::neon::snrm2_neon_fp16(N, X);
  } else {
    for (unsigned int i = 0; i < N; i++) {
      tmp = X[i * incx];
      sum += tmp * tmp;
    }
  }
#else
  for (unsigned int i = 0; i < N; i++) {
    tmp = X[i * incx];
    sum += tmp * tmp;
  }
#endif
  return static_cast<_FP16>(sqrt(sum));
}
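/**
 * @brief Naive O(M * N * K) fp16 GEMM fallback:
 * C = alpha * op(A) * op(B) + beta * C. Each dot product is accumulated in
 * float to limit rounding error before being narrowed back to _FP16.
 */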
static void sgemm_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
                       CBLAS_TRANSPOSE TransB, const unsigned int M,
                       const unsigned int N, const unsigned int K,
                       const float alpha, const _FP16 *A,
                       const unsigned int lda, const _FP16 *B,
                       const unsigned int ldb, const float beta, _FP16 *C,
                       const unsigned int ldc) {

  for (unsigned int m = 0; m < M; ++m) {
    for (unsigned int n = 0; n < N; ++n) {
      float c = 0;
      _FP16 c_old = C[m * ldc + n];
      for (unsigned int k = 0; k < K; ++k) {
        _FP16 a, b;
        a = ((TransA == CblasTrans) ? A[k * lda + m] : A[m * lda + k]);
        b = ((TransB == CblasTrans) ? B[n * ldb + k] : B[k * ldb + n]);
        c += static_cast<float>(a) * static_cast<float>(b);
      }
      C[m * ldc + n] = static_cast<_FP16>(alpha * c);
      if (beta != 0.0f)
        C[m * ldc + n] += static_cast<_FP16>(beta) * c_old;
    }
  }
}
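/**
 * @brief Returns the (0-based) index of the element with the largest
 * absolute value, walking the vector with stride incX, analogous to
 * cblas_isamax.
 */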
static unsigned int isamax_FP16(const unsigned int N, const _FP16 *X,
                                const int incX) {
  unsigned int max_idx = 0;
  _FP16 max_val = (X[0] >= 0) ? X[0] : static_cast<_FP16>(-X[0]);
  for (unsigned int i = 1; i < N; ++i) {
    _FP16 cur = X[i * incX];
    _FP16 cur_val = (cur >= 0) ? cur : static_cast<_FP16>(-cur);
    if (cur_val > max_val) {
      max_val = cur_val;
      max_idx = i;
    }
  }
  return max_idx;
}
void saxpy(const unsigned int N, const float alpha, const _FP16 *X,
           const int incX, _FP16 *Y, const int incY) {
  saxpy_FP16(N, alpha, X, incX, Y, incY);
}

void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
           const unsigned int M, const unsigned int N, const unsigned int K,
           const float alpha, const _FP16 *A, const unsigned int lda,
           const _FP16 *B, const unsigned int ldb, const float beta, _FP16 *C,
           const unsigned int ldc) {
  sgemm_FP16(order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
             ldc);
}

void scopy(const unsigned int N, const _FP16 *X, const int incX, _FP16 *Y,
           const int incY) {
  scopy_FP16(N, X, incX, Y, incY);
}

_FP16 snrm2(const int N, const _FP16 *X, const int incX) {
  return snrm2_FP16(N, X, incX);
}

_FP16 sdot(const unsigned int N, const _FP16 *X, const unsigned int incX,
           const _FP16 *Y, const unsigned int incY) {
  return sdot_FP16(N, X, incX, Y, incY);
}

void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
           const unsigned int N, const float alpha, const _FP16 *A,
           const unsigned int lda, const _FP16 *X, const int incX,
           const float beta, _FP16 *Y, const int incY) {
  sgemv_FP16(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
}

unsigned int isamax(const unsigned int N, const _FP16 *X, const int incX) {
  /// @todo isamax_FP16 for BLAS_NUM_THREADS
  return isamax_FP16(N, X, incX);
}

#endif // ENABLE_FP16
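/*
 * Raw FP32 fallbacks: plain scalar loops used when the build provides
 * neither cblas nor cuBLAS.
 */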
static void saxpy_raw(const unsigned int N, const float alpha, const float *X,
                      const int incX, float *Y, const int incY) {
  if (incX < 0 or incY < 0)
    throw std::invalid_argument(
      "Error: negative inc not supported without cblas");
  for (unsigned int i = 0; i < N; ++i)
    Y[i * incY] = Y[i * incY] + X[i * incX] * alpha;
}
static void sgemv_raw(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
                      const unsigned int M, const unsigned int N,
                      const float alpha, const float *A, const unsigned int lda,
                      const float *X, const int incX, const float beta,
                      float *Y, const int incY) {

  unsigned int incy = abs(incY);
  unsigned int incx = abs(incX);

  if (TransA == CblasTrans) {
    sgemv_loop(i, j, N, M);
  } else {
    sgemv_loop(j, i, M, N);
  }
}
static float sdot_raw(const unsigned int N, const float *X,
                      const unsigned int incX, const float *Y,
                      const unsigned int incY) {
  float ret = 0;
  for (unsigned int i = 0; i < N; ++i) {
    ret += X[i * incX] * Y[i * incY];
  }
  return ret;
}
static void scopy_raw(const unsigned int N, const float *X, const int incX,
                      float *Y, const int incY) {
  unsigned int incy = abs(incY);
  unsigned int incx = abs(incX);

  for (unsigned int i = 0; i < N; ++i)
    Y[i * incy] = X[i * incx];
}
static void sscal_raw(const unsigned int N, const float alpha, float *X,
                      const int incX) {
  unsigned int incx = abs(incX);

  for (unsigned int i = 0; i < N; ++i)
    X[i * incx] = alpha * X[i * incx];
}
static float snrm2_raw(const unsigned int N, const float *X, const int incX) {
  unsigned int incx = abs(incX);
  float sum = 0.0f;
  float tmp;

  for (unsigned int i = 0; i < N; i++) {
    tmp = X[i * incx];
    sum += tmp * tmp;
  }
  return static_cast<float>(sqrt(sum));
}
static void sgemm_raw(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
                      CBLAS_TRANSPOSE TransB, const unsigned int M,
                      const unsigned int N, const unsigned int K,
                      const float alpha, const float *A, const unsigned int lda,
                      const float *B, const unsigned int ldb, const float beta,
                      float *C, const unsigned int ldc) {

  for (unsigned int m = 0; m < M; ++m) {
    for (unsigned int n = 0; n < N; ++n) {
      double c = 0.0;
      float c_old = C[m * ldc + n];
      for (unsigned int k = 0; k < K; ++k) {
        float a, b;
        a = ((TransA == CblasTrans) ? A[k * lda + m] : A[m * lda + k]);
        b = ((TransB == CblasTrans) ? B[n * ldb + k] : B[k * ldb + n]);
        c += a * b;
      }
      C[m * ldc + n] = alpha * c;
      if (beta != 0.0)
        C[m * ldc + n] += beta * c_old;
    }
  }
}
static unsigned int isamax_raw(const unsigned int N, const float *X,
                               const int incX) {
  unsigned int max_idx = 0;
  // fabs (not integer abs) must be used for floating-point magnitudes
  float max_val = std::fabs(X[0]);
  for (unsigned int i = 1; i < N; ++i) {
    float cur_val = std::fabs(X[i * incX]);
    if (cur_val > max_val) {
      max_val = cur_val;
      max_idx = i;
    }
  }
  return max_idx;
}
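/*
 * Public entry points. The void * overloads dispatch on
 * ml::train::TensorDim::DataType: FP32 goes to cblas when USE_BLAS is
 * defined (raw loops otherwise); FP16 goes to the _FP16 path when fp16
 * support is compiled in, and throws std::invalid_argument if it is not.
 */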
void sscal(const unsigned int N, const float alpha, void *X, const int incX,
           ml::train::TensorDim::DataType d_type) {

  if (d_type == ml::train::TensorDim::DataType::FP32) {
#ifdef USE_BLAS
#ifdef BLAS_NUM_THREADS
    openblas_set_num_threads(BLAS_NUM_THREADS);
#endif // BLAS_NUM_THREADS
    cblas_sscal(N, alpha, (float *)X, incX);
#else  // USE_BLAS else
    sscal_raw(N, alpha, (float *)X, incX);
#endif // USE_BLAS
  } else if (d_type == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
    sscal(N, alpha, (_FP16 *)X, incX);
#else
    throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
  }
}
void sscal(const unsigned int N, const float alpha, float *X, const int incX) {
#ifdef USE_BLAS
#ifdef BLAS_NUM_THREADS
  openblas_set_num_threads(BLAS_NUM_THREADS);
#endif
  cblas_sscal(N, alpha, X, incX);
#else
  sscal_raw(N, alpha, X, incX);
#endif
}
void saxpy(const unsigned int N, const float alpha, const void *X,
           const int incX, void *Y, const int incY,
           ml::train::TensorDim::DataType d_type) {
  if (d_type == ml::train::TensorDim::DataType::FP32) {
#ifdef USE_BLAS
#ifdef BLAS_NUM_THREADS
    openblas_set_num_threads(BLAS_NUM_THREADS);
#endif
    cblas_saxpy(N, alpha, static_cast<const float *>(X), incX,
                static_cast<float *>(Y), incY);
#else
    saxpy_raw(N, alpha, static_cast<const float *>(X), incX,
              static_cast<float *>(Y), incY);
#endif
  } else if (d_type == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
    saxpy_FP16(N, alpha, static_cast<const _FP16 *>(X), incX,
               static_cast<_FP16 *>(Y), incY);
#else
    throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
  }
}
void saxpy(const unsigned int N, const float alpha, const float *X,
           const int incX, float *Y, const int incY) {
#ifdef USE_BLAS
#ifdef BLAS_NUM_THREADS
  openblas_set_num_threads(BLAS_NUM_THREADS);
#endif
  cblas_saxpy(N, alpha, X, incX, Y, incY);
#else
  saxpy_raw(N, alpha, X, incX, Y, incY);
#endif
}
void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
           const unsigned int M, const unsigned int N, const unsigned int K,
           const float alpha, const void *A, const unsigned int lda,
           const void *B, const unsigned int ldb, const float beta, void *C,
           const unsigned int ldc, ml::train::TensorDim::DataType d_type) {

  if (d_type == ml::train::TensorDim::DataType::FP32) {
#ifdef USE_CUBLAS
    int devID = 0;
    cudaDeviceProp deviceProp;
    cudaGetDeviceProperties(&deviceProp, devID);
    float *d_A, *d_B, *d_C;

    unsigned int size_A = M * K * sizeof(float);
    unsigned int size_B = K * N * sizeof(float);
    unsigned int size_C = M * N * sizeof(float);

    cudaMalloc((void **)&d_A, size_A);
    cudaMalloc((void **)&d_B, size_B);
    cudaMemcpy(d_A, A, size_A, cudaMemcpyHostToDevice);
    cudaMemcpy(d_B, B, size_B, cudaMemcpyHostToDevice);
    cudaMalloc((void **)&d_C, size_C);

    cublasHandle_t handle;
    cublasCreate(&handle);
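    // cuBLAS is column-major while this interface is row-major, so the call
    // below computes C^T = B^T * A^T by passing the operands in swapped
    // order. Note the hard-coded sizes and leading dimensions (K and N) are
    // only valid for tightly packed, non-transposed inputs.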
    cublasOperation_t transA =
      (TransA == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
    cublasOperation_t transB =
      (TransB == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
    cublasSgemm(handle, transA, transB, N, M, K, &alpha, d_B, N, d_A, K, &beta,
                d_C, N);

    cudaMemcpy(C, d_C, size_C, cudaMemcpyDeviceToHost);
    cublasDestroy(handle);
    cudaFree(d_A); // release device buffers to avoid leaking them
    cudaFree(d_B);
    cudaFree(d_C);
#elif defined USE_BLAS
#ifdef BLAS_NUM_THREADS
    openblas_set_num_threads(BLAS_NUM_THREADS);
#endif
    cblas_sgemm(
      order, TransA, TransB, M, N, K, alpha, static_cast<const float *>(A), lda,
      static_cast<const float *>(B), ldb, beta, static_cast<float *>(C), ldc);
#else
    sgemm_raw(order, TransA, TransB, M, N, K, alpha,
              static_cast<const float *>(A), lda, static_cast<const float *>(B),
              ldb, beta, static_cast<float *>(C), ldc);
#endif
  } else if (d_type == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
    sgemm_FP16(
      order, TransA, TransB, M, N, K, alpha, static_cast<const _FP16 *>(A), lda,
      static_cast<const _FP16 *>(B), ldb, beta, static_cast<_FP16 *>(C), ldc);
#else
    throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
  }
}
void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
           const unsigned int M, const unsigned int N, const unsigned int K,
           const float alpha, const float *A, const unsigned int lda,
           const float *B, const unsigned int ldb, const float beta, float *C,
           const unsigned int ldc) {

#ifdef USE_CUBLAS
  int devID = 0;
  cudaDeviceProp deviceProp;
  cudaGetDeviceProperties(&deviceProp, devID);
  float *d_A, *d_B, *d_C;

  unsigned int size_A = M * K * sizeof(float);
  unsigned int size_B = K * N * sizeof(float);
  unsigned int size_C = M * N * sizeof(float);

  cudaMalloc((void **)&d_A, size_A);
  cudaMalloc((void **)&d_B, size_B);
  cudaMemcpy(d_A, A, size_A, cudaMemcpyHostToDevice);
  cudaMemcpy(d_B, B, size_B, cudaMemcpyHostToDevice);
  cudaMalloc((void **)&d_C, size_C);

  cublasHandle_t handle;
  cublasCreate(&handle);
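  // Same row-major to column-major swap as in the void * overload above;
  // the hard-coded leading dimensions only hold for tightly packed,
  // non-transposed inputs.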
  cublasOperation_t transA = (TransA == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
  cublasOperation_t transB = (TransB == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
  cublasSgemm(handle, transA, transB, N, M, K, &alpha, d_B, N, d_A, K, &beta,
              d_C, N);

  cudaMemcpy(C, d_C, size_C, cudaMemcpyDeviceToHost);
  cublasDestroy(handle);
  cudaFree(d_A); // release device buffers to avoid leaking them
  cudaFree(d_B);
  cudaFree(d_C);
#elif defined USE_BLAS
#ifdef BLAS_NUM_THREADS
  openblas_set_num_threads(BLAS_NUM_THREADS);
#endif
  cblas_sgemm(order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
              ldc);
#else
  sgemm_raw(order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
            ldc);
#endif
}
void scopy(const unsigned int N, const void *X, const int incX, void *Y,
           const int incY, ml::train::TensorDim::DataType d_type) {

  if (d_type == ml::train::TensorDim::DataType::FP32) {
#ifdef USE_BLAS
#ifdef BLAS_NUM_THREADS
    openblas_set_num_threads(BLAS_NUM_THREADS);
#endif
    cblas_scopy(N, static_cast<const float *>(X), incX,
                static_cast<float *>(Y), incY);
#else
    scopy_raw(N, static_cast<const float *>(X), incX, static_cast<float *>(Y),
              incY);
#endif
  } else if (d_type == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
    scopy_FP16(N, static_cast<const _FP16 *>(X), incX,
               static_cast<_FP16 *>(Y), incY);
#else
    throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
  }
}
void scopy(const unsigned int N, const float *X, const int incX, float *Y,
           const int incY) {
#ifdef USE_BLAS
#ifdef BLAS_NUM_THREADS
  openblas_set_num_threads(BLAS_NUM_THREADS);
#endif
  cblas_scopy(N, X, incX, Y, incY);
#else
  scopy_raw(N, X, incX, Y, incY);
#endif
}
float snrm2(const int N, const float *X, const int incX) {
#ifdef USE_BLAS
#ifdef BLAS_NUM_THREADS
  openblas_set_num_threads(BLAS_NUM_THREADS);
#endif
  return cblas_snrm2(N, X, incX);
#else
  return snrm2_raw(N, X, incX);
#endif
}
float sdot(const unsigned int N, const float *X, const unsigned int incX,
           const float *Y, const unsigned int incY) {
#ifdef USE_BLAS
#ifdef BLAS_NUM_THREADS
  openblas_set_num_threads(BLAS_NUM_THREADS);
#endif
  return cblas_sdot(N, X, incX, Y, incY);
#else
  return sdot_raw(N, X, incX, Y, incY);
#endif
}
void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
           const unsigned int N, const float alpha, const void *A,
           const unsigned int lda, const void *X, const int incX,
           const float beta, void *Y, const int incY,
           ml::train::TensorDim::DataType d_type) {
  if (d_type == ml::train::TensorDim::DataType::FP32) {
#ifdef USE_BLAS
#ifdef BLAS_NUM_THREADS
    openblas_set_num_threads(BLAS_NUM_THREADS);
#endif
    return cblas_sgemv(
      order, TransA, M, N, alpha, static_cast<const float *>(A), lda,
      static_cast<const float *>(X), incX, beta, static_cast<float *>(Y), incY);
#else
    return sgemv_raw(order, TransA, M, N, alpha, static_cast<const float *>(A),
                     lda, static_cast<const float *>(X), incX, beta,
                     static_cast<float *>(Y), incY);
#endif
  } else if (d_type == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
    return sgemv_FP16(order, TransA, M, N, alpha,
                      static_cast<const _FP16 *>(A), lda,
                      static_cast<const _FP16 *>(X), incX, beta,
                      static_cast<_FP16 *>(Y), incY);
#else
    throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
  }
}
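/**
 * @brief FP32 GEMV wrapper. A minimal usage sketch, assuming packed
 * row-major storage (illustrative only, not part of the original file):
 * @code
 *   float A[6] = {1, 2, 3, 4, 5, 6}; // 2x3, row-major
 *   float x[3] = {1, 1, 1};
 *   float y[2] = {0, 0};
 *   // y = 1.0f * A * x + 0.0f * y  ==>  y == {6, 15}
 *   nntrainer::sgemv(CblasRowMajor, CblasNoTrans, 2, 3, 1.0f, A, 3, x, 1,
 *                    0.0f, y, 1);
 * @endcode
 */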
void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
           const unsigned int N, const float alpha, const float *A,
           const unsigned int lda, const float *X, const int incX,
           const float beta, float *Y, const int incY) {
#ifdef USE_BLAS
#ifdef BLAS_NUM_THREADS
  openblas_set_num_threads(BLAS_NUM_THREADS);
#endif
  return cblas_sgemv(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y,
                     incY);
#else
  return sgemv_raw(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
#endif
}
unsigned int isamax(const unsigned int N, const float *X, const int incX) {
#ifdef USE_BLAS
#ifdef BLAS_NUM_THREADS
  openblas_set_num_threads(BLAS_NUM_THREADS);
#endif
  return cblas_isamax(N, X, incX);
#else
  return isamax_raw(N, X, incX);
#endif
}

} // namespace nntrainer