[ blas/neon ] Add NEON fp16 function for snrm2
[platform/core/ml/nntrainer.git] / nntrainer / tensor / blas_interface.cpp
1 // SPDX-License-Identifier: Apache-2.0
2 /**
3  * Copyright (C) 2020 Jijoong Moon <jijoong.moon@samsung.com>
4  *
5  * @file   blas_interface.cpp
6  * @date   28 Aug 2020
7  * @see    https://github.com/nnstreamer/nntrainer
8  * @author Jijoong Moon <jijoong.moon@samsung.com>
9  * @bug    No known bugs except for NYI items
10  * @brief  This is dummy header for blas support
11  *
12  */
13
14 #include <blas_interface.h>
15 #include <nntrainer_error.h>
16
17 #ifdef USE__FP16
18 #include <blas_neon.h>
19 #endif
20
21 #include <cmath>
22
23 #define sgemv_loop(ci, cj, cM, cN)           \
24   do {                                       \
25     float y0;                                \
26     unsigned int i, j;                       \
27     for (ci = 0; ci != cM; ci++) {           \
28       y0 = Y[ci * incy] * beta;              \
29       for (cj = 0; cj != cN; cj++)           \
30         y0 += A[i + j * lda] * X[cj * incx]; \
31       Y[ci * incy] = y0;                     \
32     }                                        \
33   } while (0);
34
35 #define sgemv_loop_fp16(ci, cj, cM, cN)             \
36   do {                                              \
37     _FP16 y0;                                       \
38     unsigned int i, j;                              \
39     for (ci = 0; ci != cM; ci++) {                  \
40       y0 = Y[ci * incy] * static_cast<_FP16>(beta); \
41       for (cj = 0; cj != cN; cj++)                  \
42         y0 += A[i + j * lda] * X[cj * incx];        \
43       Y[ci * incy] = y0;                            \
44     }                                               \
45   } while (0);
46
47 #define saxpy_loop_fp16()                                                  \
48   do {                                                                     \
49     unsigned int i;                                                        \
50     for (i = 0; i < N; ++i)                                                \
51       Y[i * incY] = Y[i * incY] + static_cast<_FP16>(alpha) * X[i * incX]; \
52   } while (0);
53
54 namespace nntrainer {
55
56 #ifdef ENABLE_FP16
57 static void saxpy_FP16(const unsigned int N, const float alpha, const _FP16 *X,
58                        const int incX, _FP16 *Y, const int incY) {
59   if (incX < 0 or incY < 0)
60     throw std::invalid_argument(
61       "Error: negative inc not supported without cblas");
62
63 #ifdef USE__FP16
64   // USE__FP16 is defined when platform is android
65   if (incX == 1 && incY == 1) {
66     nntrainer::neon::saxpy_neon_fp16(N, alpha, X, Y);
67   } else {
68     saxpy_loop_fp16();
69   }
70 #else
71   saxpy_loop_fp16();
72 #endif
73 }
74
75 static void sgemv_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
76                        const unsigned int M, const unsigned int N,
77                        const float alpha, const _FP16 *A,
78                        const unsigned int lda, const _FP16 *X, const int incX,
79                        const float beta, _FP16 *Y, const int incY) {
80
81   unsigned int incy = abs(incY);
82   unsigned int incx = abs(incX);
83
84   if (TransA == CblasTrans) {
85 #ifdef USE__FP16
86     if (incX == 1 && incY == 1 && (N % 16 == 0 || N % 8 == 0)) {
87       nntrainer::neon::sgemv_transpose_neon_fp16(A, X, Y, M, N, alpha, beta);
88     } else {
89       sgemv_loop_fp16(i, j, N, M);
90     }
91 #else
92     sgemv_loop_fp16(i, j, N, M);
93 #endif
94   } else {
95 #ifdef USE__FP16
96     if (incX == 1 && incY == 1 && (N % 16 == 0 || N % 8 == 0)) {
97       nntrainer::neon::sgemv_neon_fp16(A, X, Y, M, N, alpha, beta);
98     } else {
99       sgemv_loop_fp16(j, i, M, N);
100     }
101 #else
102     sgemv_loop_fp16(j, i, M, N);
103 #endif
104   }
105 }
106
107 static _FP16 sdot_FP16(const unsigned int N, const _FP16 *X,
108                        const unsigned int incX, const _FP16 *Y,
109                        const unsigned int incY) {
110
111   if (incX < 0 or incY < 0)
112     throw std::invalid_argument("Error: negative inc not supported");
113
114   _FP16 ret = 0;
115
116 #ifdef USE__FP16
117   if (incX == 1 && incY == 1) {
118     ret = nntrainer::neon::sdot_neon_fp16(N, X, Y);
119   } else {
120     for (unsigned int i = 0; i < N; ++i) {
121       ret += X[i * incX] * Y[i * incY];
122     }
123   }
124 #else
125   for (unsigned int i = 0; i < N; ++i) {
126     ret += X[i * incX] * Y[i * incY];
127   }
128 #endif
129   return ret;
130 }
131
132 static void scopy_FP16(const unsigned int N, const _FP16 *X, const int incX,
133                        _FP16 *Y, const int incY) {
134   unsigned int incy = abs(incY);
135   unsigned int incx = abs(incX);
136
137   for (unsigned int i = 0; i < N; ++i)
138     Y[i * incy] = X[i * incx];
139 }
140
141 void sscal(const unsigned int N, const float alpha, _FP16 *X, const int incX) {
142   unsigned int incx = abs(incX);
143
144   for (unsigned int i = 0; i < N; ++i)
145     X[i * incx] = static_cast<_FP16>(alpha) * X[i * incx];
146 }
147
148 static _FP16 snrm2_FP16(const unsigned int N, const _FP16 *X, const int incX) {
149   unsigned int incx = abs(incX);
150   _FP16 sum = 0;
151   _FP16 tmp;
152 #ifdef USE__FP16
153   if (incX == 1) {
154     sum = nntrainer::neon::snrm2_neon_fp16(N, X);
155   } else {
156     for (unsigned int i = 0; i < N; i++) {
157       tmp = X[i * incx];
158       sum += tmp * tmp;
159     }
160   }
161 #else
162   for (unsigned int i = 0; i < N; i++) {
163     tmp = X[i * incx];
164     sum += tmp * tmp;
165   }
166 #endif
167   return static_cast<_FP16>(sqrt(sum));
168 }
169
170 static void sgemm_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
171                        CBLAS_TRANSPOSE TransB, const unsigned int M,
172                        const unsigned int N, const unsigned int K,
173                        const float alpha, const _FP16 *A,
174                        const unsigned int lda, const _FP16 *B,
175                        const unsigned int ldb, const float beta, _FP16 *C,
176                        const unsigned int ldc) {
177
178   for (unsigned int m = 0; m < M; ++m) {
179     for (unsigned int n = 0; n < N; ++n) {
180       _FP16 c = 0;
181       _FP16 c_old = C[m * ldc + n];
182       for (unsigned int k = 0; k < K; ++k) {
183         _FP16 a, b;
184         a = ((TransA == CblasTrans) ? A[k * lda + m] : A[m * lda + k]);
185         b = ((TransB == CblasTrans) ? B[n * ldb + k] : B[k * ldb + n]);
186         c += a * b;
187       }
188       C[m * ldc + n] = static_cast<_FP16>(alpha) * c;
189       if (beta != 0.0)
190         C[m * ldc + n] += static_cast<_FP16>(beta) * c_old;
191     }
192   }
193 }
194
195 static unsigned int isamax_FP16(const unsigned int N, const _FP16 *X,
196                                 const int incX) {
197
198   unsigned int max_idx = 0;
199   _FP16 max_val = X[0];
200   for (unsigned int n = 1; n < N; n += incX) {
201     _FP16 cur_val = (X[n] >= 0) ? X[n] : -1 * X[n];
202     if (cur_val > max_val) {
203       max_val = cur_val;
204       max_idx = n;
205     }
206   }
207
208   return max_idx;
209 }
210
211 void saxpy(const unsigned int N, const float alpha, const _FP16 *X,
212            const int incX, _FP16 *Y, const int incY) {
213   saxpy_FP16(N, alpha, X, incX, Y, incY);
214 }
215
216 void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
217            const unsigned int M, const unsigned int N, const unsigned int K,
218            const float alpha, const _FP16 *A, const unsigned int lda,
219            const _FP16 *B, const unsigned int ldb, const float beta, _FP16 *C,
220            const unsigned int ldc) {
221   sgemm_FP16(order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
222              ldc);
223 }
224
225 void scopy(const unsigned int N, const _FP16 *X, const int incX, _FP16 *Y,
226            const int incY) {
227   scopy_FP16(N, X, incX, Y, incY);
228
229 } // namespace nntrainer
230
231 _FP16 snrm2(const int N, const _FP16 *X, const int incX) {
232   return snrm2_FP16(N, X, incX);
233 }
234
235 _FP16 sdot(const unsigned int N, const _FP16 *X, const unsigned int incX,
236            const _FP16 *Y, const unsigned int incY) {
237   return sdot_FP16(N, X, incX, Y, incY);
238 }
239
240 void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
241            const unsigned int N, const float alpha, const _FP16 *A,
242            const unsigned int lda, const _FP16 *X, const int incX,
243            const float beta, _FP16 *Y, const int incY) {
244   sgemv_FP16(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
245 }
246
247 unsigned int isamax(const unsigned int N, const _FP16 *X, const int incX) {
248   /// @todo isamax_FP16 for BLAS_NUM_THREADS
249   return isamax_FP16(N, X, incX);
250 }
251
252 #endif
253
254 #ifndef USE_BLAS
255 static void saxpy_raw(const unsigned int N, const float alpha, const float *X,
256                       const int incX, float *Y, const int incY) {
257   if (incX < 0 or incY < 0)
258     throw std::invalid_argument(
259       "Error: negative inc not supported without cblas");
260   for (unsigned int i = 0; i < N; ++i)
261     Y[i * incY] = Y[i * incY] + X[i * incX] * alpha;
262 }
263
264 static void sgemv_raw(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
265                       const unsigned int M, const unsigned int N,
266                       const float alpha, const float *A, const unsigned int lda,
267                       const float *X, const int incX, const float beta,
268                       float *Y, const int incY) {
269
270   unsigned int incy = abs(incY);
271   unsigned int incx = abs(incX);
272
273   if (TransA == CblasTrans) {
274     sgemv_loop(i, j, N, M);
275   } else {
276     sgemv_loop(j, i, M, N);
277   }
278 }
279
280 static float sdot_raw(const unsigned int N, const float *X,
281                       const unsigned int incX, const float *Y,
282                       const unsigned int incY) {
283   float ret = 0;
284   for (unsigned int i = 0; i < N; ++i) {
285     ret += X[i * incX] * Y[i * incY];
286   }
287   return ret;
288 }
289
290 static void scopy_raw(const unsigned int N, const float *X, const int incX,
291                       float *Y, const int incY) {
292   unsigned int incy = abs(incY);
293   unsigned int incx = abs(incX);
294
295   for (unsigned int i = 0; i < N; ++i)
296     Y[i * incy] = X[i * incx];
297 }
298
299 static void sscal_raw(const unsigned int N, const float alpha, float *X,
300                       const int incX) {
301   unsigned int incx = abs(incX);
302
303   for (unsigned int i = 0; i < N; ++i)
304     X[i * incx] = alpha * X[i * incx];
305 }
306
307 static float snrm2_raw(const unsigned int N, const float *X, const int incX) {
308   unsigned int incx = abs(incX);
309   float sum = 0.0f;
310   float tmp;
311
312   for (unsigned int i = 0; i < N; i++) {
313     tmp = X[i * incx];
314     sum += tmp * tmp;
315   }
316   return sqrt(sum);
317 }
318
319 static void sgemm_raw(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
320                       CBLAS_TRANSPOSE TransB, const unsigned int M,
321                       const unsigned int N, const unsigned int K,
322                       const float alpha, const float *A, const unsigned int lda,
323                       const float *B, const unsigned int ldb, const float beta,
324                       float *C, const unsigned int ldc) {
325
326   for (unsigned int m = 0; m < M; ++m) {
327     for (unsigned int n = 0; n < N; ++n) {
328       double c = 0.0;
329       float c_old = C[m * ldc + n];
330       for (unsigned int k = 0; k < K; ++k) {
331         float a, b;
332         a = ((TransA == CblasTrans) ? A[k * lda + m] : A[m * lda + k]);
333         b = ((TransB == CblasTrans) ? B[n * ldb + k] : B[k * ldb + n]);
334         c += a * b;
335       }
336       C[m * ldc + n] = alpha * c;
337       if (beta != 0.0)
338         C[m * ldc + n] += beta * c_old;
339     }
340   }
341 }
342
343 static unsigned int isamax_raw(const unsigned int N, const float *X,
344                                const int incX) {
345
346   unsigned int max_idx = 0;
347   float max_val = X[0];
348   for (unsigned int n = 1; n < N; n += incX) {
349     float cur_val = abs(X[n]);
350     if (cur_val > max_val) {
351       max_val = cur_val;
352       max_idx = n;
353     }
354   }
355
356   return max_idx;
357 }
358
359 #endif
360
361 void sscal(const unsigned int N, const float alpha, void *X, const int incX,
362            ml::train::TensorDim::DataType d_type) {
363
364   if (d_type == ml::train::TensorDim::DataType::FP32) {
365
366 #ifdef USE_BLAS
367 #ifdef BLAS_NUM_THREADS
368     openblas_set_num_threads(BLAS_NUM_THREADS);
369 #endif // BLAS_NUM_THREADS
370     cblas_sscal(N, alpha, (float *)X, incX);
371 #else  // USE_BLAS else
372     sscal_raw(N, alpha, (float *)X, incX);
373 #endif //  USE_BLAS
374   } else if (d_type == ml::train::TensorDim::DataType::FP16) {
375 #ifdef ENABLE_FP16
376     sscal(N, alpha, (_FP16 *)X, incX);
377 #else
378     throw std::invalid_argument("Error: enable-fp16 is not enabled");
379 #endif
380   }
381 }
382
383 void sscal(const unsigned int N, const float alpha, float *X, const int incX) {
384 #ifdef USE_BLAS
385 #ifdef BLAS_NUM_THREADS
386   openblas_set_num_threads(BLAS_NUM_THREADS);
387 #endif
388   cblas_sscal(N, alpha, X, incX);
389 #else
390   sscal_raw(N, alpha, X, incX);
391 #endif
392 }
393
394 void saxpy(const unsigned int N, const float alpha, const void *X,
395            const int incX, void *Y, const int incY,
396            ml::train::TensorDim::DataType d_type) {
397   if (d_type == ml::train::TensorDim::DataType::FP32) {
398 #ifdef USE_BLAS
399 #ifdef BLAS_NUM_THREADS
400     openblas_set_num_threads(BLAS_NUM_THREADS);
401 #endif
402     cblas_saxpy(N, alpha, static_cast<const float *>(X), incX,
403                 static_cast<float *>(Y), incY);
404 #else
405     saxpy_raw(N, alpha, static_cast<const float *>(X), incX,
406               static_cast<float *>(Y), incY);
407 #endif
408   } else if (d_type == ml::train::TensorDim::DataType::FP16) {
409 #ifdef ENABLE_FP16
410     saxpy_FP16(N, alpha, static_cast<const _FP16 *>(X), incX,
411                static_cast<_FP16 *>(Y), incY);
412 #else
413     throw std::invalid_argument("Error: enable-fp16 is not enabled");
414 #endif
415   }
416 }
417
418 void saxpy(const unsigned int N, const float alpha, const float *X,
419            const int incX, float *Y, const int incY) {
420 #ifdef USE_BLAS
421 #ifdef BLAS_NUM_THREADS
422   openblas_set_num_threads(BLAS_NUM_THREADS);
423 #endif
424   cblas_saxpy(N, alpha, X, incX, Y, incY);
425 #else
426   saxpy_raw(N, alpha, X, incX, Y, incY);
427 #endif
428 }
429
430 void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
431            const unsigned int M, const unsigned int N, const unsigned int K,
432            const float alpha, const void *A, const unsigned int lda,
433            const void *B, const unsigned int ldb, const float beta, void *C,
434            const unsigned int ldc, ml::train::TensorDim::DataType d_type) {
435
436   if (d_type == ml::train::TensorDim::DataType::FP32) {
437 #ifdef USE_CUBLAS
438     int devID = 0;
439     cudaDeviceProp deviceProp;
440     cudaGetDeviceProperties(&deviceProp, devID);
441     float *d_A, *d_B, *d_C;
442
443     unsigned int size_A = M * K * sizeof(float);
444     unsigned int size_B = K * N * sizeof(float);
445     unsigned int size_C = M * N * sizeof(float);
446
447     cudaMalloc((void **)&d_A, size_A);
448     cudaMalloc((void **)&d_B, size_B);
449     cudaMemcpy(d_A, A, size_A, cudaMemcpyHostToDevice);
450     cudaMemcpy(d_B, B, size_B, cudaMemcpyHostToDevice);
451     cudaMalloc((void **)&d_C, size_C);
452
453     cublasHandle_t handle;
454     cublasCreate(&handle);
455
456     cublasOperation_t transA =
457       (TransA == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
458     cublasOperation_t transB =
459       (TransB == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
460     cublasSgemm(handle, transA, transB, N, M, K, &alpha, d_B, N, d_A, K, &beta,
461                 d_C, N);
462
463     cudaMemcpy(C, d_C, size_C, cudaMemcpyDeviceToHost);
464     cublasDestroy(handle);
465
466 #elif defined USE_BLAS
467
468 #ifdef BLAS_NUM_THREADS
469     openblas_set_num_threads(BLAS_NUM_THREADS);
470 #endif
471
472     cblas_sgemm(
473       order, TransA, TransB, M, N, K, alpha, static_cast<const float *>(A), lda,
474       static_cast<const float *>(B), ldb, beta, static_cast<float *>(C), ldc);
475 #else
476     sgemm_raw(order, TransA, TransB, M, N, K, alpha,
477               static_cast<const float *>(A), lda, static_cast<const float *>(B),
478               ldb, beta, static_cast<float *>(C), ldc);
479 #endif
480
481   } else if (d_type == ml::train::TensorDim::DataType::FP16) {
482 #ifdef ENABLE_FP16
483     sgemm_FP16(
484       order, TransA, TransB, M, N, K, alpha, static_cast<const _FP16 *>(A), lda,
485       static_cast<const _FP16 *>(B), ldb, beta, static_cast<_FP16 *>(C), ldc);
486 #else
487     throw std::invalid_argument("Error: enable-fp16 is not enabled");
488 #endif
489   }
490 } // namespace nntrainer
491
492 void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
493            const unsigned int M, const unsigned int N, const unsigned int K,
494            const float alpha, const float *A, const unsigned int lda,
495            const float *B, const unsigned int ldb, const float beta, float *C,
496            const unsigned int ldc) {
497
498 #ifdef USE_CUBLAS
499   int devID = 0;
500   cudaDeviceProp deviceProp;
501   cudaGetDeviceProperties(&deviceProp, devID);
502   float *d_A, *d_B, *d_C;
503
504   unsigned int size_A = M * K * sizeof(float);
505   unsigned int size_B = K * N * sizeof(float);
506   unsigned int size_C = M * N * sizeof(float);
507
508   cudaMalloc((void **)&d_A, size_A);
509   cudaMalloc((void **)&d_B, size_B);
510   cudaMemcpy(d_A, A, size_A, cudaMemcpyHostToDevice);
511   cudaMemcpy(d_B, B, size_B, cudaMemcpyHostToDevice);
512   cudaMalloc((void **)&d_C, size_C);
513
514   cublasHandle_t handle;
515   cublasCreate(&handle);
516
517   cublasOperation_t transA = (TransA == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
518   cublasOperation_t transB = (TransB == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
519   cublasSgemm(handle, transA, transB, N, M, K, &alpha, d_B, N, d_A, K, &beta,
520               d_C, N);
521
522   cudaMemcpy(C, d_C, size_C, cudaMemcpyDeviceToHost);
523   cublasDestroy(handle);
524 #elif defined USE_BLAS
525 #ifdef BLAS_NUM_THREADS
526   openblas_set_num_threads(BLAS_NUM_THREADS);
527 #endif
528   cblas_sgemm(order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
529               ldc);
530 #else
531   sgemm_raw(order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
532             ldc);
533 #endif
534 }
535
536 void scopy(const unsigned int N, const void *X, const int incX, void *Y,
537            const int incY, ml::train::TensorDim::DataType d_type) {
538
539   if (d_type == ml::train::TensorDim::DataType::FP32) {
540
541 #ifdef USE_BLAS
542 #ifdef BLAS_NUM_THREADS
543     openblas_set_num_threads(BLAS_NUM_THREADS);
544 #endif
545     cblas_scopy(N, (float *)X, incX, (float *)Y, incY);
546 #else
547     scopy_raw(N, (float *)X, incX, (float *)Y, incY);
548 #endif
549
550   } else if (d_type == ml::train::TensorDim::DataType::FP16) {
551 #ifdef ENABLE_FP16
552     scopy_FP16(N, (_FP16 *)X, incX, (_FP16 *)Y, incY);
553 #else
554     throw std::invalid_argument("Error: enable-fp16 is not enabled");
555 #endif
556   }
557
558 } // namespace nntrainer
559
560 void scopy(const unsigned int N, const float *X, const int incX, float *Y,
561            const int incY) {
562 #ifdef USE_BLAS
563 #ifdef BLAS_NUM_THREADS
564   openblas_set_num_threads(BLAS_NUM_THREADS);
565 #endif
566   cblas_scopy(N, X, incX, Y, incY);
567 #else
568   scopy_raw(N, X, incX, Y, incY);
569 #endif
570 } // namespace nntrainer
571
572 float snrm2(const int N, const float *X, const int incX) {
573 #ifdef USE_BLAS
574 #ifdef BLAS_NUM_THREADS
575   openblas_set_num_threads(BLAS_NUM_THREADS);
576 #endif
577   return cblas_snrm2(N, X, incX);
578 #else
579   return snrm2_raw(N, X, incX);
580 #endif
581 }
582
583 float sdot(const unsigned int N, const float *X, const unsigned int incX,
584            const float *Y, const unsigned int incY) {
585 #ifdef USE_BLAS
586 #ifdef BLAS_NUM_THREADS
587   openblas_set_num_threads(BLAS_NUM_THREADS);
588 #endif
589   return cblas_sdot(N, X, incX, Y, incY);
590 #else
591   return sdot_raw(N, X, incX, Y, incY);
592 #endif
593 }
594
595 void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
596            const unsigned int N, const float alpha, const void *A,
597            const unsigned int lda, const void *X, const int incX,
598            const float beta, void *Y, const int incY,
599            ml::train::TensorDim::DataType d_type) {
600   if (d_type == ml::train::TensorDim::DataType::FP32) {
601 #ifdef USE_BLAS
602 #ifdef BLAS_NUM_THREADS
603     openblas_set_num_threads(BLAS_NUM_THREADS);
604 #endif
605     return cblas_sgemv(
606       order, TransA, M, N, alpha, static_cast<const float *>(A), lda,
607       static_cast<const float *>(X), incX, beta, static_cast<float *>(Y), incY);
608 #else
609
610     return sgemv_raw(order, TransA, M, N, alpha, static_cast<const float *>(A),
611                      lda, static_cast<const float *>(X), incX, beta,
612                      static_cast<float *>(Y), incY);
613 #endif
614   } else if (d_type == ml::train::TensorDim::DataType::FP16) {
615 #ifdef ENABLE_FP16
616     return sgemv_FP16(order, TransA, M, N, alpha, static_cast<const _FP16 *>(A),
617                       lda, static_cast<const _FP16 *>(X), incX, beta,
618                       static_cast<_FP16 *>(Y), incY);
619 #else
620     throw std::invalid_argument("Error: enable-fp16 is not enabled");
621 #endif
622   }
623 }
624
625 void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
626            const unsigned int N, const float alpha, const float *A,
627            const unsigned int lda, const float *X, const int incX,
628            const float beta, float *Y, const int incY) {
629 #ifdef USE_BLAS
630 #ifdef BLAS_NUM_THREADS
631   openblas_set_num_threads(BLAS_NUM_THREADS);
632 #endif
633   return cblas_sgemv(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y,
634                      incY);
635 #else
636   return sgemv_raw(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
637 #endif
638 }
639
640 unsigned int isamax(const unsigned int N, const float *X, const int incX) {
641 #ifdef USE_BLAS
642 #ifdef BLAS_NUM_THREADS
643   openblas_set_num_threads(BLAS_NUM_THREADS);
644 #endif
645   return cblas_isamax(N, X, incX);
646 #else
647   return isamax_raw(N, X, incX);
648 #endif
649 }
650
651 } // namespace nntrainer