[Coverity] Fix coverity issues
[platform/core/ml/nntrainer.git] / nntrainer / tensor / blas_interface.cpp
1 // SPDX-License-Identifier: Apache-2.0
2 /**
3  * Copyright (C) 2020 Jijoong Moon <jijoong.moon@samsung.com>
4  *
5  * @file   blas_interface.cpp
6  * @date   28 Aug 2020
7  * @see    https://github.com/nnstreamer/nntrainer
8  * @author Jijoong Moon <jijoong.moon@samsung.com>
9  * @bug    No known bugs except for NYI items
10  * @brief  This is dummy header for blas support
11  *
12  */
13
14 #include <blas_interface.h>
15 #include <nntrainer_error.h>
16
17 #ifdef USE__FP16
18 #include <blas_neon.h>
19 #endif
20
21 #include <cmath>
22
23 #define sgemv_loop(ci, cj, cM, cN)           \
24   do {                                       \
25     float y0;                                \
26     unsigned int i, j;                       \
27     for (ci = 0; ci != cM; ci++) {           \
28       y0 = Y[ci * incy] * beta;              \
29       for (cj = 0; cj != cN; cj++)           \
30         y0 += A[i + j * lda] * X[cj * incx]; \
31       Y[ci * incy] = y0;                     \
32     }                                        \
33   } while (0);
34
35 #define sgemv_loop_fp16(ci, cj, cM, cN)             \
36   do {                                              \
37     _FP16 y0;                                       \
38     unsigned int i, j;                              \
39     for (ci = 0; ci != cM; ci++) {                  \
40       y0 = Y[ci * incy] * static_cast<_FP16>(beta); \
41       for (cj = 0; cj != cN; cj++)                  \
42         y0 += A[i + j * lda] * X[cj * incx];        \
43       Y[ci * incy] = y0;                            \
44     }                                               \
45   } while (0);
46
47 #define saxpy_loop_fp16()                                                  \
48   do {                                                                     \
49     unsigned int i;                                                        \
50     for (i = 0; i < N; ++i)                                                \
51       Y[i * incY] = Y[i * incY] + static_cast<_FP16>(alpha) * X[i * incX]; \
52   } while (0);
53
54 #define sgemm_loop_fp16()                                                 \
55   do {                                                                    \
56     for (unsigned int m = 0; m < M; ++m) {                                \
57       for (unsigned int n = 0; n < N; ++n) {                              \
58         _FP16 c = 0;                                                      \
59         _FP16 c_old = C[m * ldc + n];                                     \
60         for (unsigned int k = 0; k < K; ++k) {                            \
61           _FP16 a, b;                                                     \
62           a = ((TransA == CblasTrans) ? A[k * lda + m] : A[m * lda + k]); \
63           b = ((TransB == CblasTrans) ? B[n * ldb + k] : B[k * ldb + n]); \
64           c += a * b;                                                     \
65         }                                                                 \
66         C[m * ldc + n] = static_cast<_FP16>(alpha) * c;                   \
67         if (beta != 0.0)                                                  \
68           C[m * ldc + n] += static_cast<_FP16>(beta) * c_old;             \
69       }                                                                   \
70     }                                                                     \
71   } while (0);
72
73 namespace nntrainer {
74
75 #ifdef ENABLE_FP16
76 static void saxpy_FP16(const unsigned int N, const float alpha, const _FP16 *X,
77                        const int incX, _FP16 *Y, const int incY) {
78   if (incX < 0 or incY < 0)
79     throw std::invalid_argument(
80       "Error: negative inc not supported without cblas");
81
82 #ifdef USE__FP16
83   // USE__FP16 is defined when platform is android
84   if (incX == 1 && incY == 1) {
85     nntrainer::neon::saxpy_neon_fp16(N, alpha, X, Y);
86   } else {
87     saxpy_loop_fp16();
88   }
89 #else
90   saxpy_loop_fp16();
91 #endif
92 }
93
94 static void sgemv_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
95                        const unsigned int M, const unsigned int N,
96                        const float alpha, const _FP16 *A,
97                        const unsigned int lda, const _FP16 *X, const int incX,
98                        const float beta, _FP16 *Y, const int incY) {
99
100   unsigned int incy = abs(incY);
101   unsigned int incx = abs(incX);
102
103   if (TransA == CblasTrans) {
104 #ifdef USE__FP16
105     if (incX == 1 && incY == 1 && (N % 16 == 0 || N % 8 == 0)) {
106       nntrainer::neon::sgemv_transpose_neon_fp16(A, X, Y, M, N, alpha, beta);
107     } else {
108       sgemv_loop_fp16(i, j, N, M);
109     }
110 #else
111     sgemv_loop_fp16(i, j, N, M);
112 #endif
113   } else {
114 #ifdef USE__FP16
115     if (incX == 1 && incY == 1 && (N % 16 == 0 || N % 8 == 0)) {
116       nntrainer::neon::sgemv_neon_fp16(A, X, Y, M, N, alpha, beta);
117     } else {
118       sgemv_loop_fp16(j, i, M, N);
119     }
120 #else
121     sgemv_loop_fp16(j, i, M, N);
122 #endif
123   }
124 }
125
126 static _FP16 sdot_FP16(const unsigned int N, const _FP16 *X,
127                        const unsigned int incX, const _FP16 *Y,
128                        const unsigned int incY) {
129
130   if (incX < 0 or incY < 0)
131     throw std::invalid_argument("Error: negative inc not supported");
132
133   _FP16 ret = 0;
134
135 #ifdef USE__FP16
136   if (incX == 1 && incY == 1) {
137     ret = nntrainer::neon::sdot_neon_fp16(N, X, Y);
138   } else {
139     for (unsigned int i = 0; i < N; ++i) {
140       ret += X[i * incX] * Y[i * incY];
141     }
142   }
143 #else
144   for (unsigned int i = 0; i < N; ++i) {
145     ret += X[i * incX] * Y[i * incY];
146   }
147 #endif
148   return ret;
149 }
150
151 static void scopy_FP16(const unsigned int N, const _FP16 *X, const int incX,
152                        _FP16 *Y, const int incY) {
153   unsigned int incy = abs(incY);
154   unsigned int incx = abs(incX);
155
156 #ifdef USE__FP16
157   if (incX == 1 && incY == 1) {
158     nntrainer::neon::scopy_neon_fp16(N, X, Y);
159   } else {
160     for (unsigned int i = 0; i < N; ++i)
161       Y[i * incy] = X[i * incx];
162   }
163 #else
164   for (unsigned int i = 0; i < N; ++i)
165     Y[i * incy] = X[i * incx];
166 #endif
167 }
168
169 void sscal(const unsigned int N, const float alpha, _FP16 *X, const int incX) {
170   unsigned int incx = abs(incX);
171
172 #ifdef USE__FP16
173   if (incX == 1) {
174     nntrainer::neon::sscal_neon_fp16(N, X, alpha);
175   } else {
176     for (unsigned int i = 0; i < N; ++i)
177       X[i * incx] = static_cast<_FP16>(alpha) * X[i * incx];
178   }
179 #else
180   for (unsigned int i = 0; i < N; ++i)
181     X[i * incx] = static_cast<_FP16>(alpha) * X[i * incx];
182 #endif
183 }
184
185 static _FP16 snrm2_FP16(const unsigned int N, const _FP16 *X, const int incX) {
186   unsigned int incx = abs(incX);
187   _FP16 sum = 0;
188   _FP16 tmp;
189 #ifdef USE__FP16
190   if (incX == 1) {
191     sum = nntrainer::neon::snrm2_neon_fp16(N, X);
192   } else {
193     for (unsigned int i = 0; i < N; i++) {
194       tmp = X[i * incx];
195       sum += tmp * tmp;
196     }
197   }
198 #else
199   for (unsigned int i = 0; i < N; i++) {
200     tmp = X[i * incx];
201     sum += tmp * tmp;
202   }
203 #endif
204   return static_cast<_FP16>(sqrt(sum));
205 }
206
207 static void sgemm_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
208                        CBLAS_TRANSPOSE TransB, const unsigned int M,
209                        const unsigned int N, const unsigned int K,
210                        const float alpha, const _FP16 *A,
211                        const unsigned int lda, const _FP16 *B,
212                        const unsigned int ldb, const float beta, _FP16 *C,
213                        const unsigned int ldc) {
214
215 #ifdef USE__FP16
216   if ((N % 8 == 0) && (K % 8 == 0)) {
217     nntrainer::neon::sgemm_neon_fp16(A, B, C, M, N, K, alpha, beta,
218                                      TransA == CblasTrans,
219                                      TransB == CblasTrans);
220   } else {
221     sgemm_loop_fp16();
222   }
223 #else
224   sgemm_loop_fp16();
225 #endif
226 }
227
228 static unsigned int isamax_FP16(const unsigned int N, const _FP16 *X,
229                                 const int incX) {
230   unsigned int max_idx = 0;
231
232 #ifdef USE__FP16
233   if (incX == 1 && N >= 8) {
234     max_idx = nntrainer::neon::isamax_neon_fp16(N, X);
235   } else {
236     _FP16 max_val = X[0];
237     for (unsigned int n = 1; n < N; n += incX) {
238       _FP16 cur_val = (X[n] >= 0) ? X[n] : -1 * X[n];
239       if (cur_val > max_val) {
240         max_val = cur_val;
241         max_idx = n;
242       }
243     }
244   }
245 #else
246   _FP16 max_val = X[0];
247   for (unsigned int n = 1; n < N; n += incX) {
248     _FP16 cur_val = (X[n] >= 0) ? X[n] : -1 * X[n];
249     if (cur_val > max_val) {
250       max_val = cur_val;
251       max_idx = n;
252     }
253   }
254 #endif
255
256   return max_idx;
257 }
258
259 void saxpy(const unsigned int N, const float alpha, const _FP16 *X,
260            const int incX, _FP16 *Y, const int incY) {
261   saxpy_FP16(N, alpha, X, incX, Y, incY);
262 }
263
264 void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
265            const unsigned int M, const unsigned int N, const unsigned int K,
266            const float alpha, const _FP16 *A, const unsigned int lda,
267            const _FP16 *B, const unsigned int ldb, const float beta, _FP16 *C,
268            const unsigned int ldc) {
269   sgemm_FP16(order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
270              ldc);
271 }
272
273 void scopy(const unsigned int N, const _FP16 *X, const int incX, _FP16 *Y,
274            const int incY) {
275   scopy_FP16(N, X, incX, Y, incY);
276
277 } // namespace nntrainer
278
279 _FP16 snrm2(const int N, const _FP16 *X, const int incX) {
280   return snrm2_FP16(N, X, incX);
281 }
282
283 _FP16 sdot(const unsigned int N, const _FP16 *X, const unsigned int incX,
284            const _FP16 *Y, const unsigned int incY) {
285   return sdot_FP16(N, X, incX, Y, incY);
286 }
287
288 void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
289            const unsigned int N, const float alpha, const _FP16 *A,
290            const unsigned int lda, const _FP16 *X, const int incX,
291            const float beta, _FP16 *Y, const int incY) {
292   sgemv_FP16(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
293 }
294
295 unsigned int isamax(const unsigned int N, const _FP16 *X, const int incX) {
296   /// @todo isamax_FP16 for BLAS_NUM_THREADS
297   return isamax_FP16(N, X, incX);
298 }
299
300 #endif
301
302 #ifndef USE_BLAS
303 static void saxpy_raw(const unsigned int N, const float alpha, const float *X,
304                       const int incX, float *Y, const int incY) {
305   if (incX < 0 or incY < 0)
306     throw std::invalid_argument(
307       "Error: negative inc not supported without cblas");
308   for (unsigned int i = 0; i < N; ++i)
309     Y[i * incY] = Y[i * incY] + X[i * incX] * alpha;
310 }
311
312 static void sgemv_raw(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
313                       const unsigned int M, const unsigned int N,
314                       const float alpha, const float *A, const unsigned int lda,
315                       const float *X, const int incX, const float beta,
316                       float *Y, const int incY) {
317
318   unsigned int incy = abs(incY);
319   unsigned int incx = abs(incX);
320
321   if (TransA == CblasTrans) {
322     sgemv_loop(i, j, N, M);
323   } else {
324     sgemv_loop(j, i, M, N);
325   }
326 }
327
328 static float sdot_raw(const unsigned int N, const float *X,
329                       const unsigned int incX, const float *Y,
330                       const unsigned int incY) {
331   float ret = 0;
332   for (unsigned int i = 0; i < N; ++i) {
333     ret += X[i * incX] * Y[i * incY];
334   }
335   return ret;
336 }
337
338 static void scopy_raw(const unsigned int N, const float *X, const int incX,
339                       float *Y, const int incY) {
340   unsigned int incy = abs(incY);
341   unsigned int incx = abs(incX);
342
343   for (unsigned int i = 0; i < N; ++i)
344     Y[i * incy] = X[i * incx];
345 }
346
347 static void sscal_raw(const unsigned int N, const float alpha, float *X,
348                       const int incX) {
349   unsigned int incx = abs(incX);
350
351   for (unsigned int i = 0; i < N; ++i)
352     X[i * incx] = alpha * X[i * incx];
353 }
354
355 static float snrm2_raw(const unsigned int N, const float *X, const int incX) {
356   unsigned int incx = abs(incX);
357   float sum = 0.0f;
358   float tmp;
359
360   for (unsigned int i = 0; i < N; i++) {
361     tmp = X[i * incx];
362     sum += tmp * tmp;
363   }
364   return sqrt(sum);
365 }
366
367 static void sgemm_raw(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
368                       CBLAS_TRANSPOSE TransB, const unsigned int M,
369                       const unsigned int N, const unsigned int K,
370                       const float alpha, const float *A, const unsigned int lda,
371                       const float *B, const unsigned int ldb, const float beta,
372                       float *C, const unsigned int ldc) {
373
374   for (unsigned int m = 0; m < M; ++m) {
375     for (unsigned int n = 0; n < N; ++n) {
376       double c = 0.0;
377       float c_old = C[m * ldc + n];
378       for (unsigned int k = 0; k < K; ++k) {
379         float a, b;
380         a = ((TransA == CblasTrans) ? A[k * lda + m] : A[m * lda + k]);
381         b = ((TransB == CblasTrans) ? B[n * ldb + k] : B[k * ldb + n]);
382         c += a * b;
383       }
384       C[m * ldc + n] = alpha * c;
385       if (beta != 0.0)
386         C[m * ldc + n] += beta * c_old;
387     }
388   }
389 }
390
391 static unsigned int isamax_raw(const unsigned int N, const float *X,
392                                const int incX) {
393
394   unsigned int max_idx = 0;
395   float max_val = X[0];
396   for (unsigned int n = 1; n < N; n += incX) {
397     float cur_val = abs(X[n]);
398     if (cur_val > max_val) {
399       max_val = cur_val;
400       max_idx = n;
401     }
402   }
403
404   return max_idx;
405 }
406
407 #endif
408
409 void sscal(const unsigned int N, const float alpha, void *X, const int incX,
410            ml::train::TensorDim::DataType d_type) {
411
412   if (d_type == ml::train::TensorDim::DataType::FP32) {
413
414 #ifdef USE_BLAS
415 #ifdef BLAS_NUM_THREADS
416     openblas_set_num_threads(BLAS_NUM_THREADS);
417 #endif // BLAS_NUM_THREADS
418     cblas_sscal(N, alpha, (float *)X, incX);
419 #else  // USE_BLAS else
420     sscal_raw(N, alpha, (float *)X, incX);
421 #endif //  USE_BLAS
422   } else if (d_type == ml::train::TensorDim::DataType::FP16) {
423 #ifdef ENABLE_FP16
424     sscal(N, alpha, (_FP16 *)X, incX);
425 #else
426     throw std::invalid_argument("Error: enable-fp16 is not enabled");
427 #endif
428   }
429 }
430
431 void sscal(const unsigned int N, const float alpha, float *X, const int incX) {
432 #ifdef USE_BLAS
433 #ifdef BLAS_NUM_THREADS
434   openblas_set_num_threads(BLAS_NUM_THREADS);
435 #endif
436   cblas_sscal(N, alpha, X, incX);
437 #else
438   sscal_raw(N, alpha, X, incX);
439 #endif
440 }
441
442 void saxpy(const unsigned int N, const float alpha, const void *X,
443            const int incX, void *Y, const int incY,
444            ml::train::TensorDim::DataType d_type) {
445   if (d_type == ml::train::TensorDim::DataType::FP32) {
446 #ifdef USE_BLAS
447 #ifdef BLAS_NUM_THREADS
448     openblas_set_num_threads(BLAS_NUM_THREADS);
449 #endif
450     cblas_saxpy(N, alpha, static_cast<const float *>(X), incX,
451                 static_cast<float *>(Y), incY);
452 #else
453     saxpy_raw(N, alpha, static_cast<const float *>(X), incX,
454               static_cast<float *>(Y), incY);
455 #endif
456   } else if (d_type == ml::train::TensorDim::DataType::FP16) {
457 #ifdef ENABLE_FP16
458     saxpy_FP16(N, alpha, static_cast<const _FP16 *>(X), incX,
459                static_cast<_FP16 *>(Y), incY);
460 #else
461     throw std::invalid_argument("Error: enable-fp16 is not enabled");
462 #endif
463   }
464 }
465
466 void saxpy(const unsigned int N, const float alpha, const float *X,
467            const int incX, float *Y, const int incY) {
468 #ifdef USE_BLAS
469 #ifdef BLAS_NUM_THREADS
470   openblas_set_num_threads(BLAS_NUM_THREADS);
471 #endif
472   cblas_saxpy(N, alpha, X, incX, Y, incY);
473 #else
474   saxpy_raw(N, alpha, X, incX, Y, incY);
475 #endif
476 }
477
478 void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
479            const unsigned int M, const unsigned int N, const unsigned int K,
480            const float alpha, const void *A, const unsigned int lda,
481            const void *B, const unsigned int ldb, const float beta, void *C,
482            const unsigned int ldc, ml::train::TensorDim::DataType d_type) {
483
484   if (d_type == ml::train::TensorDim::DataType::FP32) {
485 #ifdef USE_CUBLAS
486     int devID = 0;
487     cudaDeviceProp deviceProp;
488     cudaGetDeviceProperties(&deviceProp, devID);
489     float *d_A, *d_B, *d_C;
490
491     unsigned int size_A = M * K * sizeof(float);
492     unsigned int size_B = K * N * sizeof(float);
493     unsigned int size_C = M * N * sizeof(float);
494
495     cudaMalloc((void **)&d_A, size_A);
496     cudaMalloc((void **)&d_B, size_B);
497     cudaMemcpy(d_A, A, size_A, cudaMemcpyHostToDevice);
498     cudaMemcpy(d_B, B, size_B, cudaMemcpyHostToDevice);
499     cudaMalloc((void **)&d_C, size_C);
500
501     cublasHandle_t handle;
502     cublasCreate(&handle);
503
504     cublasOperation_t transA =
505       (TransA == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
506     cublasOperation_t transB =
507       (TransB == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
508     cublasSgemm(handle, transA, transB, N, M, K, &alpha, d_B, N, d_A, K, &beta,
509                 d_C, N);
510
511     cudaMemcpy(C, d_C, size_C, cudaMemcpyDeviceToHost);
512     cublasDestroy(handle);
513
514 #elif defined USE_BLAS
515
516 #ifdef BLAS_NUM_THREADS
517     openblas_set_num_threads(BLAS_NUM_THREADS);
518 #endif
519
520     cblas_sgemm(
521       order, TransA, TransB, M, N, K, alpha, static_cast<const float *>(A), lda,
522       static_cast<const float *>(B), ldb, beta, static_cast<float *>(C), ldc);
523 #else
524     sgemm_raw(order, TransA, TransB, M, N, K, alpha,
525               static_cast<const float *>(A), lda, static_cast<const float *>(B),
526               ldb, beta, static_cast<float *>(C), ldc);
527 #endif
528
529   } else if (d_type == ml::train::TensorDim::DataType::FP16) {
530 #ifdef ENABLE_FP16
531     sgemm_FP16(
532       order, TransA, TransB, M, N, K, alpha, static_cast<const _FP16 *>(A), lda,
533       static_cast<const _FP16 *>(B), ldb, beta, static_cast<_FP16 *>(C), ldc);
534 #else
535     throw std::invalid_argument("Error: enable-fp16 is not enabled");
536 #endif
537   }
538 } // namespace nntrainer
539
540 void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
541            const unsigned int M, const unsigned int N, const unsigned int K,
542            const float alpha, const float *A, const unsigned int lda,
543            const float *B, const unsigned int ldb, const float beta, float *C,
544            const unsigned int ldc) {
545
546 #ifdef USE_CUBLAS
547   int devID = 0;
548   cudaDeviceProp deviceProp;
549   cudaGetDeviceProperties(&deviceProp, devID);
550   float *d_A, *d_B, *d_C;
551
552   unsigned int size_A = M * K * sizeof(float);
553   unsigned int size_B = K * N * sizeof(float);
554   unsigned int size_C = M * N * sizeof(float);
555
556   cudaMalloc((void **)&d_A, size_A);
557   cudaMalloc((void **)&d_B, size_B);
558   cudaMemcpy(d_A, A, size_A, cudaMemcpyHostToDevice);
559   cudaMemcpy(d_B, B, size_B, cudaMemcpyHostToDevice);
560   cudaMalloc((void **)&d_C, size_C);
561
562   cublasHandle_t handle;
563   cublasCreate(&handle);
564
565   cublasOperation_t transA = (TransA == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
566   cublasOperation_t transB = (TransB == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
567   cublasSgemm(handle, transA, transB, N, M, K, &alpha, d_B, N, d_A, K, &beta,
568               d_C, N);
569
570   cudaMemcpy(C, d_C, size_C, cudaMemcpyDeviceToHost);
571   cublasDestroy(handle);
572 #elif defined USE_BLAS
573 #ifdef BLAS_NUM_THREADS
574   openblas_set_num_threads(BLAS_NUM_THREADS);
575 #endif
576   cblas_sgemm(order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
577               ldc);
578 #else
579   sgemm_raw(order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
580             ldc);
581 #endif
582 }
583
584 void scopy(const unsigned int N, const void *X, const int incX, void *Y,
585            const int incY, ml::train::TensorDim::DataType d_type) {
586
587   if (d_type == ml::train::TensorDim::DataType::FP32) {
588
589 #ifdef USE_BLAS
590 #ifdef BLAS_NUM_THREADS
591     openblas_set_num_threads(BLAS_NUM_THREADS);
592 #endif
593     cblas_scopy(N, (float *)X, incX, (float *)Y, incY);
594 #else
595     scopy_raw(N, (float *)X, incX, (float *)Y, incY);
596 #endif
597
598   } else if (d_type == ml::train::TensorDim::DataType::FP16) {
599 #ifdef ENABLE_FP16
600     scopy_FP16(N, (_FP16 *)X, incX, (_FP16 *)Y, incY);
601 #else
602     throw std::invalid_argument("Error: enable-fp16 is not enabled");
603 #endif
604   }
605
606 } // namespace nntrainer
607
608 void scopy(const unsigned int N, const float *X, const int incX, float *Y,
609            const int incY) {
610 #ifdef USE_BLAS
611 #ifdef BLAS_NUM_THREADS
612   openblas_set_num_threads(BLAS_NUM_THREADS);
613 #endif
614   cblas_scopy(N, X, incX, Y, incY);
615 #else
616   scopy_raw(N, X, incX, Y, incY);
617 #endif
618 } // namespace nntrainer
619
620 float snrm2(const int N, const float *X, const int incX) {
621 #ifdef USE_BLAS
622 #ifdef BLAS_NUM_THREADS
623   openblas_set_num_threads(BLAS_NUM_THREADS);
624 #endif
625   return cblas_snrm2(N, X, incX);
626 #else
627   return snrm2_raw(N, X, incX);
628 #endif
629 }
630
631 float sdot(const unsigned int N, const float *X, const unsigned int incX,
632            const float *Y, const unsigned int incY) {
633 #ifdef USE_BLAS
634 #ifdef BLAS_NUM_THREADS
635   openblas_set_num_threads(BLAS_NUM_THREADS);
636 #endif
637   return cblas_sdot(N, X, incX, Y, incY);
638 #else
639   return sdot_raw(N, X, incX, Y, incY);
640 #endif
641 }
642
643 void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
644            const unsigned int N, const float alpha, const void *A,
645            const unsigned int lda, const void *X, const int incX,
646            const float beta, void *Y, const int incY,
647            ml::train::TensorDim::DataType d_type) {
648   if (d_type == ml::train::TensorDim::DataType::FP32) {
649 #ifdef USE_BLAS
650 #ifdef BLAS_NUM_THREADS
651     openblas_set_num_threads(BLAS_NUM_THREADS);
652 #endif
653     return cblas_sgemv(
654       order, TransA, M, N, alpha, static_cast<const float *>(A), lda,
655       static_cast<const float *>(X), incX, beta, static_cast<float *>(Y), incY);
656 #else
657
658     return sgemv_raw(order, TransA, M, N, alpha, static_cast<const float *>(A),
659                      lda, static_cast<const float *>(X), incX, beta,
660                      static_cast<float *>(Y), incY);
661 #endif
662   } else if (d_type == ml::train::TensorDim::DataType::FP16) {
663 #ifdef ENABLE_FP16
664     return sgemv_FP16(order, TransA, M, N, alpha, static_cast<const _FP16 *>(A),
665                       lda, static_cast<const _FP16 *>(X), incX, beta,
666                       static_cast<_FP16 *>(Y), incY);
667 #else
668     throw std::invalid_argument("Error: enable-fp16 is not enabled");
669 #endif
670   }
671 }
672
673 void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
674            const unsigned int N, const float alpha, const float *A,
675            const unsigned int lda, const float *X, const int incX,
676            const float beta, float *Y, const int incY) {
677 #ifdef USE_BLAS
678 #ifdef BLAS_NUM_THREADS
679   openblas_set_num_threads(BLAS_NUM_THREADS);
680 #endif
681   return cblas_sgemv(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y,
682                      incY);
683 #else
684   return sgemv_raw(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
685 #endif
686 }
687
688 unsigned int isamax(const unsigned int N, const float *X, const int incX) {
689 #ifdef USE_BLAS
690 #ifdef BLAS_NUM_THREADS
691   openblas_set_num_threads(BLAS_NUM_THREADS);
692 #endif
693   return cblas_isamax(N, X, incX);
694 #else
695   return isamax_raw(N, X, incX);
696 #endif
697 }
698
699 } // namespace nntrainer