[ Tensor ] Remove CBLAS params from Tensor related files.
authorskykongkong8 <ss.kong@samsung.com>
Mon, 12 Aug 2024 04:15:53 +0000 (13:15 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Thu, 22 Aug 2024 22:46:53 +0000 (07:46 +0900)
- Remove cblas params from tensor related files since nntrainer is not fully-dependent on cblas anymore.
- Letting tensors to be aware of Cblas related parameters is a nonsense at the first place.
- CBLAS params will be declared only when functions from cblas is called.

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
nntrainer/tensor/blas_interface.cpp
nntrainer/tensor/blas_interface.h
nntrainer/tensor/cl_operations/blas_kernel_interface.cpp
nntrainer/tensor/cl_operations/blas_kernels.cpp
nntrainer/tensor/cl_operations/blas_kernels.h
nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp
nntrainer/tensor/float_tensor.cpp
nntrainer/tensor/half_tensor.cpp

index d4d21d4fc5bb02854dc05431543d3982d49ee59a..4832ebd8796bd63fdd6169be2c2a5917902888b9 100644 (file)
 #include <blas_avx.h>
 #endif
 
+#ifdef USE_BLAS
+extern "C" {
+#include <cblas.h>
+}
+#endif
+
 #include <cmath>
 
 #define sgemv_loop(ci, cj, cM, cN)           \
       Y[i * incY] = Y[i * incY] + static_cast<_FP16>(alpha) * X[i * incX]; \
   } while (0);
 
-#define hgemm_loop()                                                      \
-  do {                                                                    \
-    for (unsigned int m = 0; m < M; ++m) {                                \
-      for (unsigned int n = 0; n < N; ++n) {                              \
-        float c = 0;                                                      \
-        _FP16 c_old = C[m * ldc + n];                                     \
-        for (unsigned int k = 0; k < K; ++k) {                            \
-          _FP16 a, b;                                                     \
-          a = ((TransA == CblasTrans) ? A[k * lda + m] : A[m * lda + k]); \
-          b = ((TransB == CblasTrans) ? B[n * ldb + k] : B[k * ldb + n]); \
-          c += static_cast<float>(a * b);                                 \
-        }                                                                 \
-        C[m * ldc + n] = static_cast<_FP16>(alpha * c);                   \
-        if (beta != 0.0)                                                  \
-          C[m * ldc + n] += static_cast<_FP16>(beta) * c_old;             \
-      }                                                                   \
-    }                                                                     \
+#define hgemm_loop()                                          \
+  do {                                                        \
+    for (unsigned int m = 0; m < M; ++m) {                    \
+      for (unsigned int n = 0; n < N; ++n) {                  \
+        float c = 0;                                          \
+        _FP16 c_old = C[m * ldc + n];                         \
+        for (unsigned int k = 0; k < K; ++k) {                \
+          _FP16 a, b;                                         \
+          a = ((TransA) ? A[k * lda + m] : A[m * lda + k]);   \
+          b = ((TransB) ? B[n * ldb + k] : B[k * ldb + n]);   \
+          c += static_cast<float>(a * b);                     \
+        }                                                     \
+        C[m * ldc + n] = static_cast<_FP16>(alpha * c);       \
+        if (beta != 0.0)                                      \
+          C[m * ldc + n] += static_cast<_FP16>(beta) * c_old; \
+      }                                                       \
+    }                                                         \
   } while (0);
 
 namespace nntrainer {
@@ -93,8 +99,7 @@ static inline void transpose_fallback(unsigned int M, unsigned int N,
 static void saxpy_FP16(const unsigned int N, const float alpha, const _FP16 *X,
                        const int incX, _FP16 *Y, const int incY) {
   if (incX < 0 or incY < 0)
-    throw std::invalid_argument(
-      "Error: negative inc not supported without cblas");
+    throw std::invalid_argument("Error: negative inc not supported");
 
 #if (defined USE__FP16 && USE_NEON)
   // USE__FP16 is defined when platform is android
@@ -108,22 +113,22 @@ static void saxpy_FP16(const unsigned int N, const float alpha, const _FP16 *X,
 #endif
 }
 
-static void sgemv_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
+static void sgemv_FP16(const unsigned int TStorageOrder, bool TransA,
                        const unsigned int M, const unsigned int N,
                        const float alpha, const _FP16 *A,
                        const unsigned int lda, const _FP16 *X, const int incX,
                        const float beta, _FP16 *Y, const int incY) {
 #if (defined USE__FP16 && USE_NEON)
-  if (TransA == CblasTrans) {
+  if (TransA) {
     nntrainer::neon::hgemv_transpose(A, X, Y, M, N, alpha, beta);
   } else {
     nntrainer::neon::hgemv(A, X, Y, M, N, alpha, beta);
   }
 #else
   unsigned int lenX =
-    (TransA == CblasTrans) ? 1 + (M - 1) * abs(incX) : 1 + (N - 1) * abs(incX);
+    (TransA) ? 1 + (M - 1) * abs(incX) : 1 + (N - 1) * abs(incX);
   unsigned int lenY =
-    (TransA == CblasTrans) ? 1 + (N - 1) * abs(incY) : 1 + (M - 1) * abs(incY);
+    (TransA) ? 1 + (N - 1) * abs(incY) : 1 + (M - 1) * abs(incY);
 
   float *A_ = new float[M * N];
   float *X_ = new float[lenX];
@@ -317,18 +322,20 @@ static _FP16 snrm2_FP16(const unsigned int N, const _FP16 *X, const int incX) {
   return sum;
 }
 
-static void sgemm_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
-                       CBLAS_TRANSPOSE TransB, const unsigned int M,
-                       const unsigned int N, const unsigned int K,
-                       const float alpha, const _FP16 *A,
+static void sgemm_FP16(const unsigned int TStorageOrder, bool TransA,
+                       bool TransB, const unsigned int M, const unsigned int N,
+                       const unsigned int K, const float alpha, const _FP16 *A,
                        const unsigned int lda, const _FP16 *B,
                        const unsigned int ldb, const float beta, _FP16 *C,
                        const unsigned int ldc) {
 
 #if (defined USE__FP16 && USE_NEON)
-  nntrainer::neon::custom_hgemm(A, B, C, M, N, K, alpha, beta,
-                                TransA == CblasTrans, TransB == CblasTrans);
+  nntrainer::neon::custom_hgemm(A, B, C, M, N, K, alpha, beta, TransA, TransB);
 #else
+  CBLAS_TRANSPOSE transA = TransA ? CblasTrans : CblasNoTrans;
+  CBLAS_TRANSPOSE transB = TransB ? CblasTrans : CblasNoTrans;
+  CBLAS_ORDER order = TStorageOrder ? CblasColMajor : CblasRowMajor;
+
   float *A_ = new float[M * K];
   float *B_ = new float[N * K];
   float *C_ = new float[M * N];
@@ -336,7 +343,7 @@ static void sgemm_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
   scopy(M * K, A, 1, A_, 1);
   scopy(N * K, B, 1, B_, 1);
   scopy(M * N, C, 1, C_, 1);
-  sgemm(order, TransA, TransB, M, N, K, alpha, A_, lda, B_, ldb, beta, C_, ldc);
+  sgemm(order, transA, transB, M, N, K, alpha, A_, lda, B_, ldb, beta, C_, ldc);
   scopy(M * N, C_, 1, C, 1);
 
   delete[] A_;
@@ -381,13 +388,13 @@ void saxpy(const unsigned int N, const float alpha, const _FP16 *X,
   saxpy_FP16(N, alpha, X, incX, Y, incY);
 }
 
-void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
+void sgemm(const unsigned int TStorageOrder, bool TransA, bool TransB,
            const unsigned int M, const unsigned int N, const unsigned int K,
            const float alpha, const _FP16 *A, const unsigned int lda,
            const _FP16 *B, const unsigned int ldb, const float beta, _FP16 *C,
            const unsigned int ldc) {
-  sgemm_FP16(order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
-             ldc);
+  sgemm_FP16(TStorageOrder, TransA, TransB, M, N, K, alpha, A, lda, B, ldb,
+             beta, C, ldc);
 }
 
 void scopy(const unsigned int N, const _FP16 *X, const int incX, _FP16 *Y,
@@ -520,11 +527,12 @@ _FP16 sdot(const unsigned int N, const _FP16 *X, const unsigned int incX,
   return sdot_FP16(N, X, incX, Y, incY);
 }
 
-void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
+void sgemv(const unsigned int TStorageOrder, bool TransA, const unsigned int M,
            const unsigned int N, const float alpha, const _FP16 *A,
            const unsigned int lda, const _FP16 *X, const int incX,
            const float beta, _FP16 *Y, const int incY) {
-  sgemv_FP16(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
+  sgemv_FP16(TStorageOrder, TransA, M, N, alpha, A, lda, X, incX, beta, Y,
+             incY);
 }
 
 unsigned int isamax(const unsigned int N, const _FP16 *X, const int incX) {
@@ -557,13 +565,12 @@ void transpose_matrix(const unsigned int M, const unsigned int N,
 static void saxpy_raw(const unsigned int N, const float alpha, const float *X,
                       const int incX, float *Y, const int incY) {
   if (incX < 0 or incY < 0)
-    throw std::invalid_argument(
-      "Error: negative inc not supported without cblas");
+    throw std::invalid_argument("Error: negative inc not supported");
   for (unsigned int i = 0; i < N; ++i)
     Y[i * incY] = Y[i * incY] + X[i * incX] * alpha;
 }
 
-static void sgemv_raw(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
+static void sgemv_raw(const unsigned int TStorageOrder, bool TransA,
                       const unsigned int M, const unsigned int N,
                       const float alpha, const float *A, const unsigned int lda,
                       const float *X, const int incX, const float beta,
@@ -572,7 +579,7 @@ static void sgemv_raw(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
   unsigned int incy = abs(incY);
   unsigned int incx = abs(incX);
 
-  if (TransA == CblasTrans) {
+  if (TransA) {
     sgemv_loop(i, j, N, M);
   } else {
     sgemv_loop(j, i, M, N);
@@ -618,12 +625,12 @@ static float snrm2_raw(const unsigned int N, const float *X, const int incX) {
   return sqrt(sum);
 }
 
-static void sgemm_raw(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
-                      CBLAS_TRANSPOSE TransB, const unsigned int M,
-                      const unsigned int N, const unsigned int K,
-                      const float alpha, const float *A, const unsigned int lda,
-                      const float *B, const unsigned int ldb, const float beta,
-                      float *C, const unsigned int ldc) {
+static void sgemm_raw(const unsigned int TStorageOrder, bool TransA,
+                      bool TransB, const unsigned int M, const unsigned int N,
+                      const unsigned int K, const float alpha, const float *A,
+                      const unsigned int lda, const float *B,
+                      const unsigned int ldb, const float beta, float *C,
+                      const unsigned int ldc) {
 
   for (unsigned int m = 0; m < M; ++m) {
     for (unsigned int n = 0; n < N; ++n) {
@@ -631,8 +638,8 @@ static void sgemm_raw(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
       float c_old = C[m * ldc + n];
       for (unsigned int k = 0; k < K; ++k) {
         float a, b;
-        a = ((TransA == CblasTrans) ? A[k * lda + m] : A[m * lda + k]);
-        b = ((TransB == CblasTrans) ? B[n * ldb + k] : B[k * ldb + n]);
+        a = ((TransA) ? A[k * lda + m] : A[m * lda + k]);
+        b = ((TransB) ? B[n * ldb + k] : B[k * ldb + n]);
         c += a * b;
       }
       C[m * ldc + n] = alpha * c;
@@ -729,12 +736,11 @@ void saxpy(const unsigned int N, const float alpha, const float *X,
 #endif
 }
 
-void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
+void sgemm(const unsigned int TStorageOrder, bool TransA, bool TransB,
            const unsigned int M, const unsigned int N, const unsigned int K,
            const float alpha, const void *A, const unsigned int lda,
            const void *B, const unsigned int ldb, const float beta, void *C,
            const unsigned int ldc, ml::train::TensorDim::DataType d_type) {
-
   if (d_type == ml::train::TensorDim::DataType::FP32) {
 #ifdef USE_CUBLAS
     int devID = 0;
@@ -755,10 +761,8 @@ void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
     cublasHandle_t handle;
     cublasCreate(&handle);
 
-    cublasOperation_t transA =
-      (TransA == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
-    cublasOperation_t transB =
-      (TransB == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
+    cublasOperation_t transA = (TransA) ? CUBLAS_OP_T : CUBLAS_OP_N;
+    cublasOperation_t transB = (TransB) ? CUBLAS_OP_T : CUBLAS_OP_N;
     cublasSgemm(handle, transA, transB, N, M, K, &alpha, d_B, N, d_A, K, &beta,
                 d_C, N);
 
@@ -770,33 +774,35 @@ void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
 #ifdef BLAS_NUM_THREADS
     openblas_set_num_threads(BLAS_NUM_THREADS);
 #endif
-
+    CBLAS_TRANSPOSE transA = TransA ? CblasTrans : CblasNoTrans;
+    CBLAS_TRANSPOSE transB = TransB ? CblasTrans : CblasNoTrans;
+    CBLAS_ORDER order = TStorageOrder ? CblasColMajor : CblasRowMajor;
     cblas_sgemm(
-      order, TransA, TransB, M, N, K, alpha, static_cast<const float *>(A), lda,
+      order, transA, transB, M, N, K, alpha, static_cast<const float *>(A), lda,
       static_cast<const float *>(B), ldb, beta, static_cast<float *>(C), ldc);
 #else
-    sgemm_raw(order, TransA, TransB, M, N, K, alpha,
+    sgemm_raw(TStorageOrder, TransA, TransB, M, N, K, alpha,
               static_cast<const float *>(A), lda, static_cast<const float *>(B),
               ldb, beta, static_cast<float *>(C), ldc);
 #endif
 
   } else if (d_type == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    sgemm_FP16(
-      order, TransA, TransB, M, N, K, alpha, static_cast<const _FP16 *>(A), lda,
-      static_cast<const _FP16 *>(B), ldb, beta, static_cast<_FP16 *>(C), ldc);
+    sgemm_FP16(TStorageOrder, TransA, TransB, M, N, K, alpha,
+               static_cast<const _FP16 *>(A), lda,
+               static_cast<const _FP16 *>(B), ldb, beta,
+               static_cast<_FP16 *>(C), ldc);
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
   }
 } // namespace nntrainer
 
-void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
+void sgemm(const unsigned int TStorageOrder, bool TransA, bool TransB,
            const unsigned int M, const unsigned int N, const unsigned int K,
            const float alpha, const float *A, const unsigned int lda,
            const float *B, const unsigned int ldb, const float beta, float *C,
            const unsigned int ldc) {
-
 #ifdef USE_CUBLAS
   int devID = 0;
   cudaDeviceProp deviceProp;
@@ -816,8 +822,8 @@ void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
   cublasHandle_t handle;
   cublasCreate(&handle);
 
-  cublasOperation_t transA = (TransA == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
-  cublasOperation_t transB = (TransB == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
+  cublasOperation_t transA = (TransA) ? CUBLAS_OP_T : CUBLAS_OP_N;
+  cublasOperation_t transB = (TransB) ? CUBLAS_OP_T : CUBLAS_OP_N;
   cublasSgemm(handle, transA, transB, N, M, K, &alpha, d_B, N, d_A, K, &beta,
               d_C, N);
 
@@ -827,11 +833,14 @@ void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
 #ifdef BLAS_NUM_THREADS
   openblas_set_num_threads(BLAS_NUM_THREADS);
 #endif
-  cblas_sgemm(order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
+  CBLAS_TRANSPOSE transA = TransA ? CblasTrans : CblasNoTrans;
+  CBLAS_TRANSPOSE transB = TransB ? CblasTrans : CblasNoTrans;
+  CBLAS_ORDER order = TStorageOrder ? CblasColMajor : CblasRowMajor;
+  cblas_sgemm(order, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C,
               ldc);
 #else
-  sgemm_raw(order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
-            ldc);
+  sgemm_raw(TStorageOrder, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta,
+            C, ldc);
 #endif
 }
 
@@ -927,37 +936,39 @@ float sdot(const unsigned int N, const float *X, const unsigned int incX,
 #endif
 }
 
-void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
+void sgemv(const unsigned int TStorageOrder, bool TransA, const unsigned int M,
            const unsigned int N, const float alpha, const void *A,
            const unsigned int lda, const void *X, const int incX,
            const float beta, void *Y, const int incY,
            ml::train::TensorDim::DataType d_type) {
+
   if (d_type == ml::train::TensorDim::DataType::FP32) {
 #ifdef USE_BLAS
 #ifdef BLAS_NUM_THREADS
     openblas_set_num_threads(BLAS_NUM_THREADS);
 #endif
+    CBLAS_TRANSPOSE transA = TransA ? CblasTrans : CblasNoTrans;
+    CBLAS_ORDER order = TStorageOrder ? CblasColMajor : CblasRowMajor;
     return cblas_sgemv(
-      order, TransA, M, N, alpha, static_cast<const float *>(A), lda,
+      order, transA, M, N, alpha, static_cast<const float *>(A), lda,
       static_cast<const float *>(X), incX, beta, static_cast<float *>(Y), incY);
 #else
-
-    return sgemv_raw(order, TransA, M, N, alpha, static_cast<const float *>(A),
-                     lda, static_cast<const float *>(X), incX, beta,
-                     static_cast<float *>(Y), incY);
+    return sgemv_raw(
+      TStorageOrder, TransA, M, N, alpha, static_cast<const float *>(A), lda,
+      static_cast<const float *>(X), incX, beta, static_cast<float *>(Y), incY);
 #endif
   } else if (d_type == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    return sgemv_FP16(order, TransA, M, N, alpha, static_cast<const _FP16 *>(A),
-                      lda, static_cast<const _FP16 *>(X), incX, beta,
-                      static_cast<_FP16 *>(Y), incY);
+    return sgemv_FP16(
+      TStorageOrder, TransA, M, N, alpha, static_cast<const _FP16 *>(A), lda,
+      static_cast<const _FP16 *>(X), incX, beta, static_cast<_FP16 *>(Y), incY);
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
   }
 }
 
-void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
+void sgemv(const unsigned int TStorageOrder, bool TransA, const unsigned int M,
            const unsigned int N, const float alpha, const float *A,
            const unsigned int lda, const float *X, const int incX,
            const float beta, float *Y, const int incY) {
@@ -965,10 +976,13 @@ void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
 #ifdef BLAS_NUM_THREADS
   openblas_set_num_threads(BLAS_NUM_THREADS);
 #endif
-  return cblas_sgemv(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y,
+  CBLAS_TRANSPOSE transA = TransA ? CblasTrans : CblasNoTrans;
+  CBLAS_ORDER order = TStorageOrder ? CblasColMajor : CblasRowMajor;
+  return cblas_sgemv(order, transA, M, N, alpha, A, lda, X, incX, beta, Y,
                      incY);
 #else
-  return sgemv_raw(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
+  return sgemv_raw(TStorageOrder, TransA, M, N, alpha, A, lda, X, incX, beta, Y,
+                   incY);
 #endif
 }
 
index 69cdda01f9e072ffb2bc7bf7dfdba0dc26479ea9..b57ea3e057ef9c2c7cf5c1e600245bfd873acdf3 100644 (file)
 #define __BLAS_INTERFACE_H_
 #ifdef __cplusplus
 
-#ifdef USE_BLAS
-extern "C" {
-#include <cblas.h>
-}
-#else
-enum CBLAS_ORDER { CblasRowMajor = 101, CblasColMajor = 102 };
-
-enum CBLAS_TRANSPOSE {
-  CblasNoTrans = 111,
-  CblasTrans = 112,
-  CblasConjTrans = 113
-};
-
-#endif
-
 #ifdef USE_CUBLAS
 #include <helper_cuda.h>
 #include <helper_functions.h>
@@ -132,7 +117,7 @@ void saxpy(const unsigned int N, const float alpha, const _FP16 *X,
  * @param[in] alpha float number
  * @param[in] beta float number
  */
-void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
+void sgemm(const unsigned int TStorageOrder, bool TransA, bool TransB,
            const unsigned int M, const unsigned int N, const unsigned int K,
            const float alpha, const _FP16 *A, const unsigned int lda,
            const _FP16 *B, const unsigned int ldb, const float beta, _FP16 *C,
@@ -147,7 +132,7 @@ void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
  * @param[in] alpha float number
  * @param[in] beta float number
  */
-void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
+void sgemv(const unsigned int TStorageOrder, bool TransA, const unsigned int M,
            const unsigned int N, const float alpha, const _FP16 *A,
            const unsigned int lda, const _FP16 *X, const int incX,
            const float beta, _FP16 *Y, const int incY);
@@ -346,7 +331,7 @@ void saxpy(const unsigned int N, const float alpha, const float *X,
  * @param[in] alpha float number
  * @param[in] beta float number
  */
-void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
+void sgemm(const unsigned int TStorageOrder, bool TransA, bool TransB,
            const unsigned int M, const unsigned int N, const unsigned int K,
            const float alpha, const void *A, const unsigned int lda,
            const void *B, const unsigned int ldb, const float beta, void *C,
@@ -363,7 +348,7 @@ void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
  * @param[in] alpha float number
  * @param[in] beta float number
  */
-void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
+void sgemm(const unsigned int TStorageOrder, bool TransA, bool TransB,
            const unsigned int M, const unsigned int N, const unsigned int K,
            const float alpha, const float *A, const unsigned int lda,
            const float *B, const unsigned int ldb, const float beta, float *C,
@@ -378,7 +363,7 @@ void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
  * @param[in] alpha float number
  * @param[in] beta float number
  */
-void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
+void sgemv(const unsigned int TStorageOrder, bool TransA, const unsigned int M,
            const unsigned int N, const float alpha, const void *A,
            const unsigned int lda, const void *X, const int incX,
            const float beta, void *Y, const int incY,
@@ -393,7 +378,7 @@ void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
  * @param[in] alpha float number
  * @param[in] beta float number
  */
-void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
+void sgemv(const unsigned int TStorageOrder, bool TransA, const unsigned int M,
            const unsigned int N, const float alpha, const float *A,
            const unsigned int lda, const float *X, const int incX,
            const float beta, float *Y, const int incY);
index 9e8422d404f5b67d9799dfc694052cd4847deaed..c1ecf2ddc1562b94fe3ff4bc2886048cb4481c91 100644 (file)
@@ -119,8 +119,6 @@ void dotCl(Tensor const &input, Tensor const &m, Tensor &result,
     const float *data = input.getData();
     const float *mdata = m.getData();
     float *rdata = result.getData();
-    enum CBLAS_TRANSPOSE transA = trans ? CblasTrans : CblasNoTrans;
-    enum CBLAS_TRANSPOSE transB = trans_m ? CblasTrans : CblasNoTrans;
 
     /// shortcut handling in case of vector
     /// for vector, (1 * K) == (K * 1) in current memory layout...
@@ -134,20 +132,19 @@ void dotCl(Tensor const &input, Tensor const &m, Tensor &result,
     }
     /// case2: (M * K) X (K * 1)
     else if (N == 1) {
-      transA ? sgemv_cl(data, mdata, rdata, dim2, dim1, lda, context)
-             : sgemv_cl(data, mdata, rdata, dim1, dim2, lda, context);
+      trans ? sgemv_cl(data, mdata, rdata, dim2, dim1, lda, context)
+            : sgemv_cl(data, mdata, rdata, dim1, dim2, lda, context);
     }
     /// case3: (1 * K) X (K * N) = 1 * N = R
     /// = R^T = (K * N) ^T * (1 * K) ^T = (N * K) * (K * 1) = (N * K) * (1 * K)
     /// Effectively a translation of sgemv
     else if (M == 1) {
-      transB = transB == CblasTrans ? CblasNoTrans : CblasTrans;
-      transB ? sgemv_cl(mdata, data, rdata, mdim2, mdim1, ldb, context)
-             : sgemv_cl(mdata, data, rdata, mdim1, mdim2, ldb, context);
+      trans_m ? sgemv_cl(mdata, data, rdata, mdim2, mdim1, ldb, context)
+              : sgemv_cl(mdata, data, rdata, mdim1, mdim2, ldb, context);
     }
     /// case others: use gemm
     else {
-      sgemm_cl(transA, transB, data, mdata, rdata, M, N, K, lda, ldb, ldc,
+      sgemm_cl(trans, trans_m, data, mdata, rdata, M, N, K, lda, ldb, ldc,
                context);
     }
   } else if (input.getDataType() == ml::train::TensorDim::DataType::FP16) {
@@ -155,8 +152,6 @@ void dotCl(Tensor const &input, Tensor const &m, Tensor &result,
     const _FP16 *data = input.getData<_FP16>();
     const _FP16 *mdata = m.getData<_FP16>();
     _FP16 *rdata = result.getData<_FP16>();
-    enum CBLAS_TRANSPOSE transA = trans ? CblasTrans : CblasNoTrans;
-    enum CBLAS_TRANSPOSE transB = trans_m ? CblasTrans : CblasNoTrans;
 
     /// shortcut handling in case of vector
     /// for vector, (1 * K) == (K * 1) in current memory layout...
@@ -170,20 +165,19 @@ void dotCl(Tensor const &input, Tensor const &m, Tensor &result,
     }
     /// case2: (M * K) X (K * 1)
     else if (N == 1) {
-      transA ? sgemv_cl(data, mdata, rdata, dim2, dim1, lda, context)
-             : sgemv_cl(data, mdata, rdata, dim1, dim2, lda, context);
+      trans ? sgemv_cl(data, mdata, rdata, dim2, dim1, lda, context)
+            : sgemv_cl(data, mdata, rdata, dim1, dim2, lda, context);
     }
     /// case3: (1 * K) X (K * N) = 1 * N = R
     /// = R^T = (K * N) ^T * (1 * K) ^T = (N * K) * (K * 1) = (N * K) * (1 * K)
     /// Effectively a translation of sgemv
     else if (M == 1) {
-      transB = transB == CblasTrans ? CblasNoTrans : CblasTrans;
-      transB ? sgemv_cl(mdata, data, rdata, mdim2, mdim1, ldb, context)
-             : sgemv_cl(mdata, data, rdata, mdim1, mdim2, ldb, context);
+      trans_m ? sgemv_cl(mdata, data, rdata, mdim2, mdim1, ldb, context)
+              : sgemv_cl(mdata, data, rdata, mdim1, mdim2, ldb, context);
     }
     /// case others: use sgemm
     else {
-      sgemm_cl(transA, transB, data, mdata, rdata, M, N, K, lda, ldb, ldc,
+      sgemm_cl(trans, trans_m, data, mdata, rdata, M, N, K, lda, ldb, ldc,
                context);
     }
 #else
index 791cdc5e6b8c6e450314081b145e934d73ceebd1..5c0d1dfa72abcb2a99fe751da03fa7272d54d167 100644 (file)
@@ -282,24 +282,24 @@ float dot_cl(const float *vecAdata, const float *vecXdata, unsigned int dim1,
   return cl_ret;
 }
 
-void sgemm_cl(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, const float *A,
-              const float *B, float *C, unsigned int M, unsigned int N,
-              unsigned int K, unsigned int lda, unsigned int ldb,
-              unsigned int ldc, RunLayerContext &context) {
+void sgemm_cl(bool TransA, bool TransB, const float *A, const float *B,
+              float *C, unsigned int M, unsigned int N, unsigned int K,
+              unsigned int lda, unsigned int ldb, unsigned int ldc,
+              RunLayerContext &context) {
 
   opencl::Kernel *kernel_sgemm = nullptr;
   RunLayerContext::LayerKernel layerKernel;
   std::string sgemm_cl_kernel_;
 
-  if (TransA != CblasTrans && TransB != CblasTrans) {
+  if (!TransA && !TransB) {
     kernel_sgemm = &kernel_sgemm_noTrans;
     layerKernel = context.LayerKernel::SGEMM_NOTRANS;
     sgemm_cl_kernel_ = sgemm_cl_noTrans_kernel_;
-  } else if (TransA == CblasTrans && TransB != CblasTrans) {
+  } else if (TransA && !TransB) {
     kernel_sgemm = &kernel_sgemm_transA;
     layerKernel = context.LayerKernel::SGEMM_TRANSA;
     sgemm_cl_kernel_ = sgemm_cl_transA_kernel_;
-  } else if (TransA != CblasTrans && TransB == CblasTrans) {
+  } else if (!TransA && TransB) {
     kernel_sgemm = &kernel_sgemm_transB;
     layerKernel = context.LayerKernel::SGEMM_TRANSB;
     sgemm_cl_kernel_ = sgemm_cl_transB_kernel_;
index 6b118c68dd0b92715a11ed4c46adfaa2a92c3f3d..008345eef27d5c6add6f7498b01a2bcb492d2d25 100644 (file)
@@ -61,8 +61,8 @@ float dot_cl(const float *vecAdata, const float *vecXdata, unsigned int dim1,
 /**
  * @brief     sgemm computation : Y = op(A)*op(B) + C,
  * where op(X) is one of X or X**T
- * @param[in] transA CBLAS_TRANSPOSE
- * @param[in] transB CBLAS_TRANSPOSE
+ * @param[in] transA bool transpose
+ * @param[in] transB bool transpose
  * @param[in] A float * for Matrix A
  * @param[in] B float * for Matrix B
  * @param[in] C float * for Matrix C
@@ -74,10 +74,10 @@ float dot_cl(const float *vecAdata, const float *vecXdata, unsigned int dim1,
  * @param[in] ldc number of C's columns
  * @param[in] context RunLayerContext reference
  */
-void sgemm_cl(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, const float *A,
-              const float *B, float *C, unsigned int M, unsigned int N,
-              unsigned int K, unsigned int lda, unsigned int ldb,
-              unsigned int ldc, RunLayerContext &context);
+void sgemm_cl(bool TransA, bool TransB, const float *A, const float *B,
+              float *C, unsigned int M, unsigned int N, unsigned int K,
+              unsigned int lda, unsigned int ldb, unsigned int ldc,
+              RunLayerContext &context);
 
 /**
  * @brief     addition : sum of all input vectors
@@ -140,8 +140,8 @@ __fp16 dot_cl(const __fp16 *vecAdata, const __fp16 *vecXdata, unsigned int dim1,
 /**
  * @brief     fp16 sgemm computation : Y = op(A)*op(B) + C,
  * where op(X) is one of X or X**T
- * @param[in] transA CBLAS_TRANSPOSE
- * @param[in] transB CBLAS_TRANSPOSE
+ * @param[in] transA bool transpose
+ * @param[in] transB bool transpose
  * @param[in] A fp16 * for Matrix A
  * @param[in] B fp16 * for Matrix B
  * @param[in] C fp16 * for Matrix C
@@ -153,10 +153,10 @@ __fp16 dot_cl(const __fp16 *vecAdata, const __fp16 *vecXdata, unsigned int dim1,
  * @param[in] ldc number of C's columns
  * @param[in] context RunLayerContext reference
  */
-void sgemm_cl(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, const __fp16 *A,
-              const __fp16 *B, __fp16 *C, unsigned int M, unsigned int N,
-              unsigned int K, unsigned int lda, unsigned int ldb,
-              unsigned int ldc, RunLayerContext &context);
+void sgemm_cl(bool TransA, bool TransB, const __fp16 *A, const __fp16 *B,
+              __fp16 *C, unsigned int M, unsigned int N, unsigned int K,
+              unsigned int lda, unsigned int ldb, unsigned int ldc,
+              RunLayerContext &context);
 
 /**
  * @brief     fp16 addition : sum of all input vectors
index 96c7ce9c90ac9a8bf6c6f72366f70984c59ca9ed..e7f2f8b2f9d9f60dde91575d89a9873e7785d256 100644 (file)
@@ -302,24 +302,24 @@ __fp16 dot_cl(const __fp16 *vecAdata, const __fp16 *vecXdata, unsigned int dim1,
   return cl_ret;
 }
 
-void sgemm_cl(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, const __fp16 *A,
-              const __fp16 *B, __fp16 *C, unsigned int M, unsigned int N,
-              unsigned int K, unsigned int lda, unsigned int ldb,
-              unsigned int ldc, RunLayerContext &context) {
+void sgemm_cl(bool TransA, bool TransB, const __fp16 *A, const __fp16 *B,
+              __fp16 *C, unsigned int M, unsigned int N, unsigned int K,
+              unsigned int lda, unsigned int ldb, unsigned int ldc,
+              RunLayerContext &context) {
 
   opencl::Kernel *kernel_sgemm_fp16 = nullptr;
   RunLayerContext::LayerKernel layerKernel;
   std::string sgemm_cl_kernel_fp16_;
 
-  if (TransA != CblasTrans && TransB != CblasTrans) {
+  if (!TransA && !TransB) {
     kernel_sgemm_fp16 = &kernel_sgemm_noTrans_fp16;
     layerKernel = context.LayerKernel::SGEMM_NOTRANS_FP16;
     sgemm_cl_kernel_fp16_ = sgemm_cl_noTrans_kernel_fp16_;
-  } else if (TransA == CblasTrans && TransB != CblasTrans) {
+  } else if (TransA && !TransB) {
     kernel_sgemm_fp16 = &kernel_sgemm_transA_fp16;
     layerKernel = context.LayerKernel::SGEMM_TRANSA_FP16;
     sgemm_cl_kernel_fp16_ = sgemm_cl_transA_kernel_fp16_;
-  } else if (TransA != CblasTrans && TransB == CblasTrans) {
+  } else if (!TransA && TransB) {
     kernel_sgemm_fp16 = &kernel_sgemm_transB_fp16;
     layerKernel = context.LayerKernel::SGEMM_TRANSB_FP16;
     sgemm_cl_kernel_fp16_ = sgemm_cl_transB_kernel_fp16_;
index b35894f9ec6c960502cc3ad3bc875f5c19eea09e..7ca18a7b401763b740943bd8426f0579da74fc90 100644 (file)
@@ -493,8 +493,8 @@ void FloatTensor::sum_by_batch(Tensor &output) const {
 
   Tensor ones(1, 1, 1, feat_len, this->getFormat());
   ones.setValue(1.0);
-  sgemv(CblasRowMajor, CblasNoTrans, batch, feat_len, 1, data, feat_len,
-        ones.getData<float>(), 1, 0.0, out_data, 1);
+  sgemv((unsigned int)dim.getStorageOrder(), false, batch, feat_len, 1, data,
+        feat_len, ones.getData<float>(), 1, 0.0, out_data, 1);
 }
 
 Tensor &FloatTensor::sum(unsigned int axis, Tensor &output, float alpha,
@@ -521,8 +521,8 @@ Tensor &FloatTensor::sum(unsigned int axis, Tensor &output, float alpha,
     size_t batch = dim.batch();
     Tensor ones(1, 1, 1, batch, getTensorType());
     ones.setValue(alpha);
-    sgemv(CblasRowMajor, CblasTrans, batch, feat_len, 1, data, feat_len,
-          ones.getData<float>(), 1, beta, output.getData<float>(), 1);
+    sgemv((unsigned int)dim.getStorageOrder(), true, batch, feat_len, 1, data,
+          feat_len, ones.getData<float>(), 1, beta, output.getData<float>(), 1);
   } break;
   case 1: {
     CREATE_IF_EMPTY_DIMS(output, dim[0], 1, dim[2], dim[3], getTensorType());
@@ -531,8 +531,9 @@ Tensor &FloatTensor::sum(unsigned int axis, Tensor &output, float alpha,
       unsigned int t_axis = dim[1];
       Tensor ones(1, 1, 1, t_axis, getTensorType());
       ones.setValue(alpha);
-      sgemv(CblasRowMajor, CblasNoTrans, feat_len, t_axis, 1, data, t_axis,
-            ones.getData<float>(), 1, beta, output.getData<float>(), 1);
+      sgemv((unsigned int)dim.getStorageOrder(), false, feat_len, t_axis, 1,
+            data, t_axis, ones.getData<float>(), 1, beta,
+            output.getData<float>(), 1);
     } else {
       unsigned int feat_len = dim[2] * dim[3];
       unsigned int t_axis = dim[1];
@@ -540,7 +541,7 @@ Tensor &FloatTensor::sum(unsigned int axis, Tensor &output, float alpha,
       ones.setValue(alpha);
       float *rdata = output.getData<float>();
       for (unsigned int k = 0; k < dim[0]; ++k) {
-        sgemv(CblasRowMajor, CblasTrans, t_axis, feat_len, 1,
+        sgemv((unsigned int)dim.getStorageOrder(), true, t_axis, feat_len, 1,
               &data[k * dim.getFeatureLen()], feat_len, ones.getData<float>(),
               1, beta, &rdata[k * feat_len], 1);
       }
@@ -555,7 +556,7 @@ Tensor &FloatTensor::sum(unsigned int axis, Tensor &output, float alpha,
       ones.setValue(alpha);
       float *rdata = output.getData<float>();
       for (unsigned int k = 0; k < dim[0]; ++k) {
-        sgemv(CblasRowMajor, CblasTrans, t_axis, feat_len, 1,
+        sgemv((unsigned int)dim.getStorageOrder(), true, t_axis, feat_len, 1,
               &data[k * dim.getFeatureLen()], feat_len, ones.getData<float>(),
               1, beta, &rdata[k * feat_len], 1);
       }
@@ -573,14 +574,15 @@ Tensor &FloatTensor::sum(unsigned int axis, Tensor &output, float alpha,
             unsigned int ridx =
               k * output.getDim().getFeatureLen() + c * dim[3];
 
-            sgemv(CblasRowMajor, CblasTrans, t_axis, t_3, 1, &data[idx], t_3,
-                  ones.getData<float>(), 1, beta, &rdata[ridx], 1);
+            sgemv((unsigned int)dim.getStorageOrder(), true, t_axis, t_3, 1,
+                  &data[idx], t_3, ones.getData<float>(), 1, beta, &rdata[ridx],
+                  1);
           }
         }
       } else {
-        sgemv(CblasColMajor, CblasTrans, t_axis, output.getDim().getDataLen(),
-              1, data, t_axis, ones.getData<float>(), 1, beta,
-              output.getData<float>(), 1);
+        sgemv((unsigned int)dim.getStorageOrder(), true, t_axis,
+              output.getDim().getDataLen(), 1, data, t_axis,
+              ones.getData<float>(), 1, beta, output.getData<float>(), 1);
       }
     }
   } break;
@@ -597,8 +599,9 @@ Tensor &FloatTensor::sum(unsigned int axis, Tensor &output, float alpha,
         for (unsigned int c = 0; c < dim[2]; ++c) {
           unsigned int idx = k * dim.getFeatureLen() + c * dim[3] * dim[1];
           unsigned int ridx = k * output.getDim().getFeatureLen() + c * dim[1];
-          sgemv(CblasRowMajor, CblasTrans, t_axis, t_3, 1, &data[idx], t_3,
-                ones.getData<float>(), 1, beta, &rdata[ridx], 1);
+          sgemv((unsigned int)dim.getStorageOrder(), true, t_axis, t_3, 1,
+                &data[idx], t_3, ones.getData<float>(), 1, beta, &rdata[ridx],
+                1);
         }
       }
     } else {
@@ -608,7 +611,7 @@ Tensor &FloatTensor::sum(unsigned int axis, Tensor &output, float alpha,
       ones.setValue(alpha);
 
       if (dim.getStorageOrder() == TStorageOrder::ROW_MAJOR) {
-        sgemv(CblasRowMajor, CblasNoTrans, m, n, 1, data, n,
+        sgemv((unsigned int)dim.getStorageOrder(), false, m, n, 1, data, n,
               ones.getData<float>(), 1, beta, output.getData<float>(), 1);
       } else {
         float *rdata = output.getData<float>();
@@ -618,8 +621,9 @@ Tensor &FloatTensor::sum(unsigned int axis, Tensor &output, float alpha,
             unsigned int idx = k * dim.getFeatureLen() + c * dim[3] * dim[2];
             unsigned int ridx = k * dim[1] * dim[2] + c * dim[2];
 
-            sgemv(CblasColMajor, CblasNoTrans, dim[2], n, 1, &data[idx], dim[2],
-                  ones.getData<float>(), 1, beta, &rdata[ridx], 1);
+            sgemv((unsigned int)dim.getStorageOrder(), false, dim[2], n, 1,
+                  &data[idx], dim[2], ones.getData<float>(), 1, beta,
+                  &rdata[ridx], 1);
           }
         }
       }
@@ -699,8 +703,6 @@ Tensor &FloatTensor::dot(Tensor const &input, Tensor &output, bool trans,
   const float *mdata = input.getData<float>();
   float *rdata = output.getData<float>();
   const float alpha = 1.0f;
-  enum CBLAS_TRANSPOSE transA = trans ? CblasTrans : CblasNoTrans;
-  enum CBLAS_TRANSPOSE transB = trans_in ? CblasTrans : CblasNoTrans;
 
   /// shortcut handling in case of vector
   /// for vector, (1 * K) == (K * 1) in current memory layout...
@@ -714,21 +716,21 @@ Tensor &FloatTensor::dot(Tensor const &input, Tensor &output, bool trans,
   }
   /// case2: (M * K) X (K * 1)
   else if (N == 1) {
-    sgemv(CblasRowMajor, transA, first_three_flat, last_axis, alpha, data, lda,
-          mdata, 1, beta, rdata, 1);
+    sgemv((unsigned int)dim.getStorageOrder(), trans, first_three_flat,
+          last_axis, alpha, data, lda, mdata, 1, beta, rdata, 1);
   }
   /// case3: (1 * K) X (K * N) = 1 * N = R
   /// = R^T = (K * N) ^T * (1 * K) ^T = (N * K) * (K * 1) = (N * K) * (1 * K)
   /// Effectively a translation of sgemv
   else if (M == 1) {
-    transB = transB == CblasTrans ? CblasNoTrans : CblasTrans;
-    sgemv(CblasRowMajor, transB, input_first_three_flat, input_last_axis, alpha,
-          mdata, ldb, data, 1, beta, rdata, 1);
+    sgemv((unsigned int)dim.getStorageOrder(), !trans_in,
+          input_first_three_flat, input_last_axis, alpha, mdata, ldb, data, 1,
+          beta, rdata, 1);
   }
   /// case others: use gemm
   else {
-    sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, data, lda, mdata, ldb,
-          beta, rdata, ldc);
+    sgemm((unsigned int)dim.getStorageOrder(), trans, trans_in, M, N, K, alpha,
+          data, lda, mdata, ldb, beta, rdata, ldc);
   }
 
   return output;
index bea483df54f7c1147036249ad2419d865c1cdf08..6753d51d341ee94dd4a6be557d0d141cd429fcdb 100644 (file)
@@ -478,8 +478,8 @@ void HalfTensor::sum_by_batch(Tensor &output) const {
 
   Tensor ones(1, 1, 1, feat_len, this->getTensorType());
   ones.setValue((_FP16)1.0);
-  sgemv(CblasRowMajor, CblasNoTrans, batch, feat_len, 1, data, feat_len,
-        ones.getData<_FP16>(), 1, 0.0, out_data, 1);
+  sgemv((unsigned int)dim.getStorageOrder(), false, batch, feat_len, 1, data,
+        feat_len, ones.getData<_FP16>(), 1, 0.0, out_data, 1);
 }
 
 Tensor &HalfTensor::sum(unsigned int axis, Tensor &output, float alpha,
@@ -507,8 +507,8 @@ Tensor &HalfTensor::sum(unsigned int axis, Tensor &output, float alpha,
     size_t batch = dim.batch();
     Tensor ones(1, 1, 1, batch, this->getTensorType());
     ones.setValue(alpha);
-    sgemv(CblasRowMajor, CblasTrans, batch, feat_len, 1, data, feat_len,
-          ones.getData<_FP16>(), 1, beta, output.getData<_FP16>(), 1);
+    sgemv((unsigned int)dim.getStorageOrder(), true, batch, feat_len, 1, data,
+          feat_len, ones.getData<_FP16>(), 1, beta, output.getData<_FP16>(), 1);
   } break;
   case 1: {
     CREATE_IF_EMPTY_DIMS(output, dim[0], 1, dim[2], dim[3], getTensorType());
@@ -517,8 +517,9 @@ Tensor &HalfTensor::sum(unsigned int axis, Tensor &output, float alpha,
       unsigned int t_axis = dim[1];
       Tensor ones(1, 1, 1, t_axis, this->getTensorType());
       ones.setValue(alpha);
-      sgemv(CblasRowMajor, CblasNoTrans, feat_len, t_axis, 1, data, t_axis,
-            ones.getData<_FP16>(), 1, beta, output.getData<_FP16>(), 1);
+      sgemv((unsigned int)dim.getStorageOrder(), false, feat_len, t_axis, 1,
+            data, t_axis, ones.getData<_FP16>(), 1, beta,
+            output.getData<_FP16>(), 1);
     } else {
       unsigned int feat_len = dim[2] * dim[3];
       unsigned int t_axis = dim[1];
@@ -526,7 +527,7 @@ Tensor &HalfTensor::sum(unsigned int axis, Tensor &output, float alpha,
       ones.setValue(alpha);
       _FP16 *rdata = output.getData<_FP16>();
       for (unsigned int k = 0; k < dim[0]; ++k) {
-        sgemv(CblasRowMajor, CblasTrans, t_axis, feat_len, 1,
+        sgemv((unsigned int)dim.getStorageOrder(), true, t_axis, feat_len, 1,
               &data[k * dim.getFeatureLen()], feat_len, ones.getData<_FP16>(),
               1, beta, &rdata[k * feat_len], 1);
       }
@@ -542,7 +543,7 @@ Tensor &HalfTensor::sum(unsigned int axis, Tensor &output, float alpha,
       ones.setValue(alpha);
       _FP16 *rdata = output.getData<_FP16>();
       for (unsigned int k = 0; k < dim[0]; ++k) {
-        sgemv(CblasRowMajor, CblasTrans, t_axis, feat_len, 1,
+        sgemv((unsigned int)dim.getStorageOrder(), true, t_axis, feat_len, 1,
               &data[k * dim.getFeatureLen()], feat_len, ones.getData<_FP16>(),
               1, beta, &rdata[k * feat_len], 1);
       }
@@ -556,8 +557,9 @@ Tensor &HalfTensor::sum(unsigned int axis, Tensor &output, float alpha,
         for (unsigned int c = 0; c < dim[1]; ++c) {
           unsigned int idx = k * dim.getFeatureLen() + c * dim[3] * dim[2];
           unsigned int ridx = k * output.getDim().getFeatureLen() + c * dim[3];
-          sgemv(CblasRowMajor, CblasTrans, t_axis, t_3, 1, &data[idx], t_3,
-                ones.getData<_FP16>(), 1, beta, &rdata[ridx], 1);
+          sgemv((unsigned int)dim.getStorageOrder(), true, t_axis, t_3, 1,
+                &data[idx], t_3, ones.getData<_FP16>(), 1, beta, &rdata[ridx],
+                1);
         }
       }
     }
@@ -574,8 +576,9 @@ Tensor &HalfTensor::sum(unsigned int axis, Tensor &output, float alpha,
         for (unsigned int c = 0; c < dim[2]; ++c) {
           unsigned int idx = k * dim.getFeatureLen() + c * dim[3] * dim[1];
           unsigned int ridx = k * output.getDim().getFeatureLen() + c * dim[1];
-          sgemv(CblasRowMajor, CblasTrans, t_axis, t_3, 1, &data[idx], t_3,
-                ones.getData<_FP16>(), 1, beta, &rdata[ridx], 1);
+          sgemv((unsigned int)dim.getStorageOrder(), true, t_axis, t_3, 1,
+                &data[idx], t_3, ones.getData<_FP16>(), 1, beta, &rdata[ridx],
+                1);
         }
       }
     } else {
@@ -583,7 +586,7 @@ Tensor &HalfTensor::sum(unsigned int axis, Tensor &output, float alpha,
       unsigned int n = dim[3];
       Tensor ones(1, 1, 1, n, getTensorType());
       ones.setValue(alpha);
-      sgemv(CblasRowMajor, CblasNoTrans, m, n, 1, data, n,
+      sgemv((unsigned int)dim.getStorageOrder(), false, m, n, 1, data, n,
             ones.getData<_FP16>(), 1, beta, output.getData<_FP16>(), 1);
     }
   } break;
@@ -651,8 +654,6 @@ Tensor &HalfTensor::dot(Tensor const &input, Tensor &output, bool trans,
   const _FP16 *mdata = input.getData<_FP16>();
   _FP16 *rdata = output.getData<_FP16>();
   const float alpha = 1.0f;
-  enum CBLAS_TRANSPOSE transA = trans ? CblasTrans : CblasNoTrans;
-  enum CBLAS_TRANSPOSE transB = trans_in ? CblasTrans : CblasNoTrans;
 
   /// shortcut handling in case of vector
   /// for vector, (1 * K) == (K * 1) in current memory layout...
@@ -666,21 +667,21 @@ Tensor &HalfTensor::dot(Tensor const &input, Tensor &output, bool trans,
   }
   /// case2: (M * K) X (K * 1)
   else if (N == 1) {
-    sgemv(CblasRowMajor, transA, first_three_flat, last_axis, alpha, data, lda,
-          mdata, 1, beta, rdata, 1);
+    sgemv((unsigned int)dim.getStorageOrder(), trans, first_three_flat,
+          last_axis, alpha, data, lda, mdata, 1, beta, rdata, 1);
   }
   /// case3: (1 * K) X (K * N) = 1 * N = R
   /// = R^T = (K * N) ^T * (1 * K) ^T = (N * K) * (K * 1) = (N * K) * (1 * K)
   /// Effectively a translation of sgemv
   else if (M == 1) {
-    transB = transB == CblasTrans ? CblasNoTrans : CblasTrans;
-    sgemv(CblasRowMajor, transB, input_first_three_flat, input_last_axis, alpha,
-          mdata, ldb, data, 1, beta, rdata, 1);
+    sgemv((unsigned int)dim.getStorageOrder(), !trans_in,
+          input_first_three_flat, input_last_axis, alpha, mdata, ldb, data, 1,
+          beta, rdata, 1);
   }
   /// case others: use sgemm
   else {
-    sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, data, lda, mdata, ldb,
-          beta, rdata, ldc);
+    sgemm((unsigned int)dim.getStorageOrder(), trans, trans_in, M, N, K, alpha,
+          data, lda, mdata, ldb, beta, rdata, ldc);
   }
 
   return output;