[WIP] [Tensor] Add __fp16 supporting functions in blas_interface
authorskykongkong8 <kssjustin98@gmail.com>
Wed, 19 Jul 2023 05:18:42 +0000 (14:18 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 21 Aug 2023 06:29:23 +0000 (15:29 +0900)
* Add __fp16 support with #ifdef, and parameter overloading
* (trivial) fix typo
* TODO: replace with valid __fp16 supporting functions

Signed-off-by: skykongkong8 <kssjustin98@gmail.com>
nntrainer/tensor/blas_interface.cpp
nntrainer/tensor/blas_interface.h
nntrainer/tensor/tensor.cpp

index 5335db3..cd81d90 100644 (file)
@@ -12,8 +12,8 @@
  */
 
 #include <blas_interface.h>
-#include <nntrainer_error.h>
 #include <iostream>
+#include <nntrainer_error.h>
 
 #include <cmath>
 
@@ -42,6 +42,15 @@ static void saxpy_raw(const unsigned int N, const float alpha, const float *X,
     Y[i * incY] = Y[i * incY] + X[i * incX] * alpha;
 }
 
+static void saxpy_FP16(const unsigned int N, const float alpha, const __fp16 *X,
+                       const int incX, __fp16 *Y, const int incY) {
+  if (incX < 0 or incY < 0)
+    throw std::invalid_argument(
+      "Error: negative inc not supported without cblas");
+  for (unsigned int i = 0; i < N; ++i)
+    Y[i * incY] = Y[i * incY] + X[i * incX] * alpha;
+}
+
 static void sgemv_raw(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
                       const unsigned int M, const unsigned int N,
                       const float alpha, const float *A, const unsigned int lda,
@@ -58,6 +67,22 @@ static void sgemv_raw(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
   }
 }
 
+static void sgemv_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
+                       const unsigned int M, const unsigned int N,
+                       const float alpha, const __fp16 *A,
+                       const unsigned int lda, const __fp16 *X, const int incX,
+                       const float beta, __fp16 *Y, const int incY) {
+
+  unsigned int incy = abs(incY);
+  unsigned int incx = abs(incX);
+
+  if (TransA == CblasTrans) {
+    sgemv_loop(i, j, N, M);
+  } else {
+    sgemv_loop(j, i, M, N);
+  }
+}
+
 static float sdot_raw(const unsigned int N, const float *X,
                       const unsigned int incX, const float *Y,
                       const unsigned int incY) {
@@ -68,6 +93,16 @@ static float sdot_raw(const unsigned int N, const float *X,
   return ret;
 }
 
+static __fp16 sdot_FP16(const unsigned int N, const __fp16 *X,
+                        const unsigned int incX, const __fp16 *Y,
+                        const unsigned int incY) {
+  __fp16 ret = 0;
+  for (unsigned int i = 0; i < N; ++i) {
+    ret += X[i * incX] * Y[i * incY];
+  }
+  return ret;
+}
+
 static void scopy_raw(const unsigned int N, const float *X, const int incX,
                       float *Y, const int incY) {
   unsigned int incy = abs(incY);
@@ -141,6 +176,18 @@ static float snrm2_raw(const unsigned int N, const float *X, const int incX) {
   return sqrt(sum);
 }
 
+static float snrm2_FP16(const unsigned int N, const __fp16 *X, const int incX) {
+  unsigned int incx = abs(incX);
+  __fp16 sum = 0.0f;
+  __fp16 tmp;
+#pragma omp parallel for private(tmp) reduction(+ : sum)
+  for (unsigned int i = 0; i < N; i++) {
+    tmp = X[i * incx];
+    sum += tmp * tmp;
+  }
+  return sqrt(sum);
+}
+
 static void sgemm_raw(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
                       CBLAS_TRANSPOSE TransB, const unsigned int M,
                       const unsigned int N, const unsigned int K,
@@ -165,6 +212,31 @@ static void sgemm_raw(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
   }
 }
 
+static void sgemm_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
+                       CBLAS_TRANSPOSE TransB, const unsigned int M,
+                       const unsigned int N, const unsigned int K,
+                       const float alpha, const __fp16 *A,
+                       const unsigned int lda, const __fp16 *B,
+                       const unsigned int ldb, const float beta, __fp16 *C,
+                       const unsigned int ldc) {
+
+  for (unsigned int m = 0; m < M; ++m) {
+    for (unsigned int n = 0; n < N; ++n) {
+      double c = 0.0;
+      __fp16 c_old = C[m * ldc + n];
+      for (unsigned int k = 0; k < K; ++k) {
+        __fp16 a, b;
+        a = ((TransA == CblasTrans) ? A[k * lda + m] : A[m * lda + k]);
+        b = ((TransB == CblasTrans) ? B[n * ldb + k] : B[k * ldb + n]);
+        c += a * b;
+      }
+      C[m * ldc + n] = alpha * c;
+      if (beta != 0.0)
+        C[m * ldc + n] += beta * c_old;
+    }
+  }
+}
+
 static unsigned int isamax_raw(const unsigned int N, const float *X,
                                const int incX) {
 
@@ -181,7 +253,40 @@ static unsigned int isamax_raw(const unsigned int N, const float *X,
   return max_idx;
 }
 
+static unsigned int isamax_FP16(const unsigned int N, const __fp16 *X,
+                                const int incX) {
+
+  unsigned int max_idx = 0;
+  __fp16 max_val = X[0];
+  for (unsigned int n = 1; n < N; n += incX) {
+    __fp16 cur_val = abs(X[n]);
+    if (cur_val > max_val) {
+      max_val = cur_val;
+      max_idx = n;
+    }
+  }
+
+  return max_idx;
+}
+
+#endif
+
+void saxpy(const unsigned int N, const float alpha, const void *X,
+           const int incX, void *Y, const int incY,
+           ml::train::TensorDim::DataType d_type) {
+#ifdef USE_BLAS
+#ifdef BLAS_NUM_THREADS
+  openblas_set_num_threads(BLAS_NUM_THREADS);
+#endif
+  cblas_saxpy(N, alpha, X, incX, Y, incY);
+#else
+  if (d_type == ml::train::TensorDim::DataType::FP32) {
+    saxpy_raw(N, alpha, (float *)X, incX, (float *)Y, incY);
+  } else if (d_type == ml::train::TensorDim::DataType::FP16) {
+    saxpy_FP16(N, alpha, (__fp16 *)X, incX, (__fp16 *)Y, incY);
+  }
 #endif
+}
 
 void saxpy(const unsigned int N, const float alpha, const float *X,
            const int incX, float *Y, const int incY) {
@@ -195,6 +300,59 @@ void saxpy(const unsigned int N, const float alpha, const float *X,
 #endif
 }
 
+void saxpy(const unsigned int N, const float alpha, const __fp16 *X,
+           const int incX, __fp16 *Y, const int incY) {
+  saxpy_FP16(N, alpha, X, incX, Y, incY);
+}
+
+void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
+           const unsigned int M, const unsigned int N, const unsigned int K,
+           const float alpha, const void *A, const unsigned int lda,
+           const void *B, const unsigned int ldb, const float beta, void *C,
+           const unsigned int ldc, ml::train::TensorDim::DataType d_type) {
+#ifdef USE_CUBLAS
+  int devID = 0;
+  cudaDeviceProp deviceProp;
+  cudaGetDeviceProperties(&deviceProp, devID);
+  float *d_A, *d_B, *d_C;
+
+  unsigned int size_A = M * K * sizeof(float);
+  unsigned int size_B = K * N * sizeof(float);
+  unsigned int size_C = M * N * sizeof(float);
+
+  cudaMalloc((void **)&d_A, size_A);
+  cudaMalloc((void **)&d_B, size_B);
+  cudaMemcpy(d_A, A, size_A, cudaMemcpyHostToDevice);
+  cudaMemcpy(d_B, B, size_B, cudaMemcpyHostToDevice);
+  cudaMalloc((void **)&d_C, size_C);
+
+  cublasHandle_t handle;
+  cublasCreate(&handle);
+
+  cublasOperation_t transA = (TransA == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
+  cublasOperation_t transB = (TransB == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
+  cublasSgemm(handle, transA, transB, N, M, K, &alpha, d_B, N, d_A, K, &beta,
+              d_C, N);
+
+  cudaMemcpy(C, d_C, size_C, cudaMemcpyDeviceToHost);
+  cublasDestroy(handle);
+#elif defined USE_BLAS
+#ifdef BLAS_NUM_THREADS
+  openblas_set_num_threads(BLAS_NUM_THREADS);
+#endif
+  cblas_sgemm(order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
+              ldc);
+#else
+  if (d_type == ml::train::TensorDim::DataType::FP32) {
+    sgemm_raw(order, TransA, TransB, M, N, K, alpha, (float *)A, lda,
+              (float *)B, ldb, beta, (float *)C, ldc);
+  } else if (d_type == ml::train::TensorDim::DataType::FP16) {
+    sgemm_FP16(order, TransA, TransB, M, N, K, alpha, (__fp16 *)A, lda,
+               (__fp16 *)B, ldb, beta, (__fp16 *)C, ldc);
+  }
+#endif
+}
+
 void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
            const unsigned int M, const unsigned int N, const unsigned int K,
            const float alpha, const float *A, const unsigned int lda,
@@ -239,6 +397,15 @@ void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
 #endif
 }
 
+void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
+           const unsigned int M, const unsigned int N, const unsigned int K,
+           const float alpha, const __fp16 *A, const unsigned int lda,
+           const __fp16 *B, const unsigned int ldb, const float beta, __fp16 *C,
+           const unsigned int ldc) {
+  sgemm_FP16(order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
+             ldc);
+}
+
 void scopy(const unsigned int N, const void *X, const int incX, void *Y,
            const int incY, ml::train::TensorDim::DataType d_type) {
 #ifdef USE_BLAS
@@ -286,6 +453,10 @@ float snrm2(const int N, const float *X, const int incX) {
 #endif
 }
 
+__fp16 snrm2(const int N, const __fp16 *X, const int incX) {
+  return snrm2_FP16(N, X, incX);
+}
+
 float sdot(const unsigned int N, const float *X, const unsigned int incX,
            const float *Y, const unsigned int incY) {
 #ifdef USE_BLAS
@@ -298,6 +469,33 @@ float sdot(const unsigned int N, const float *X, const unsigned int incX,
 #endif
 }
 
+__fp16 sdot(const unsigned int N, const __fp16 *X, const unsigned int incX,
+            const __fp16 *Y, const unsigned int incY) {
+  return sdot_FP16(N, X, incX, Y, incY);
+}
+
+void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
+           const unsigned int N, const float alpha, const void *A,
+           const unsigned int lda, const void *X, const int incX,
+           const float beta, void *Y, const int incY,
+           ml::train::TensorDim::DataType d_type) {
+#ifdef USE_BLAS
+#ifdef BLAS_NUM_THREADS
+  openblas_set_num_threads(BLAS_NUM_THREADS);
+#endif
+  return cblas_sgemv(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y,
+                     incY);
+#else
+  if (d_type == ml::train::TensorDim::DataType::FP32) {
+    return sgemv_raw(order, TransA, M, N, alpha, (float *)A, lda, (float *)X,
+                     incX, beta, (float *)Y, incY);
+  } else if (d_type == ml::train::TensorDim::DataType::FP16) {
+    return sgemv_FP16(order, TransA, M, N, alpha, (__fp16 *)A, lda, (__fp16 *)X,
+                      incX, beta, (__fp16 *)Y, incY);
+  }
+#endif
+}
+
 void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
            const unsigned int N, const float alpha, const float *A,
            const unsigned int lda, const float *X, const int incX,
@@ -313,6 +511,13 @@ void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
 #endif
 }
 
+void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
+           const unsigned int N, const float alpha, const __fp16 *A,
+           const unsigned int lda, const __fp16 *X, const int incX,
+           const float beta, __fp16 *Y, const int incY) {
+  sgemv_FP16(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
+}
+
 unsigned int isamax(const unsigned int N, const float *X, const int incX) {
 #ifdef USE_BLAS
 #ifdef BLAS_NUM_THREADS
@@ -324,4 +529,9 @@ unsigned int isamax(const unsigned int N, const float *X, const int incX) {
 #endif
 }
 
+unsigned int isamax(const unsigned int N, const __fp16 *X, const int incX) {
+  /// @todo isamax_FP16 for BLAS_NUM_THREADS
+  return isamax_FP16(N, X, incX);
+}
+
 } // namespace nntrainer
index b560c28..63274a9 100644 (file)
@@ -48,6 +48,8 @@ void sscal(const unsigned int N, const float alpha, __fp16 *X, const int incX);
 
 float snrm2(const int N, const float *X, const int incX);
 
+__fp16 snrm2(const int N, const __fp16 *X, const int incX);
+
 void scopy(const unsigned int N, const void *X, const int incX, void *Y,
            const int incY, ml::train::TensorDim::DataType d_type);
 
@@ -60,22 +62,57 @@ void scopy(const unsigned int N, const __fp16 *X, const int incX, __fp16 *Y,
 float sdot(const unsigned int N, const float *X, const unsigned int incX,
            const float *Y, const unsigned int incY);
 
+__fp16 sdot(const unsigned int N, const __fp16 *X, const unsigned int incX,
+            const __fp16 *Y, const unsigned int incY);
+
+void saxpy(const unsigned int N, const float alpha, const void *X,
+           const int incX, void *Y, const int incY,
+           ml::train::TensorDim::DataType d_type);
+
 void saxpy(const unsigned int N, const float alpha, const float *X,
            const int incX, float *Y, const int incY);
 
+void saxpy(const unsigned int N, const float alpha, const __fp16 *X,
+           const int incX, __fp16 *Y, const int incY);
+
+void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
+           const unsigned int M, const unsigned int N, const unsigned int K,
+           const float alpha, const void *A, const unsigned int lda,
+           const void *B, const unsigned int ldb, const float beta, void *C,
+           const unsigned int ldc, ml::train::TensorDim::DataType d_type);
+
 void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
            const unsigned int M, const unsigned int N, const unsigned int K,
            const float alpha, const float *A, const unsigned int lda,
            const float *B, const unsigned int ldb, const float beta, float *C,
            const unsigned int ldc);
 
+void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
+           const unsigned int M, const unsigned int N, const unsigned int K,
+           const float alpha, const __fp16 *A, const unsigned int lda,
+           const __fp16 *B, const unsigned int ldb, const float beta, __fp16 *C,
+           const unsigned int ldc);
+
+void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
+           const unsigned int N, const float alpha, const void *A,
+           const unsigned int lda, const void *X, const int incX,
+           const float beta, void *Y, const int incY,
+           ml::train::TensorDim::DataType d_type);
+
 void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
            const unsigned int N, const float alpha, const float *A,
            const unsigned int lda, const float *X, const int incX,
            const float beta, float *Y, const int incY);
 
+void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
+           const unsigned int N, const float alpha, const __fp16 *A,
+           const unsigned int lda, const __fp16 *X, const int incX,
+           const float beta, __fp16 *Y, const int incY);
+
 unsigned int isamax(const unsigned int N, const float *X, const int incX);
 
+unsigned int isamax(const unsigned int N, const __fp16 *X, const int incX);
+
 } /* namespace nntrainer */
 #endif /* __cplusplus */
 #endif /* __BLAS_INTERFACE_H__ */
index 6e45c4d..a2681a9 100644 (file)
@@ -1355,7 +1355,7 @@ Tensor Tensor::cat(const std::vector<Tensor> &tensors, int axis) {
     auto iter_value =
       [is_format_nchw](std::array<unsigned, 4> &loc,
                        const std::array<unsigned, 4> &start_loc, Tensor &t,
-                       const std::array<unsigned, 4> &ref_dim_arr) -> float & {
+                       const std::array<unsigned, 4> &ref_dim_arr) -> __fp16 & {
       auto &value = is_format_nchw
                       ? t.getValue<__fp16>(loc[0], loc[1], loc[2], loc[3])
                       : t.getValue<__fp16>(loc[0], loc[3], loc[1], loc[2]);
@@ -2090,7 +2090,7 @@ Tensor &Tensor::dot(Tensor const &m, Tensor &result, bool trans, bool trans_m,
     const __fp16 *data = getData<__fp16>();
     const __fp16 *mdata = m.getData<__fp16>();
     __fp16 *rdata = result.getData<__fp16>();
-    const __fp16 alpha = 1.0f;
+    const float alpha = 1.0f;
     enum CBLAS_TRANSPOSE transA = trans ? CblasTrans : CblasNoTrans;
     enum CBLAS_TRANSPOSE transB = trans_m ? CblasTrans : CblasNoTrans;
 
@@ -2117,7 +2117,7 @@ Tensor &Tensor::dot(Tensor const &m, Tensor &result, bool trans, bool trans_m,
       sgemv(CblasRowMajor, transB, mdim1, mdim2, alpha, mdata, ldb, data, 1,
             beta, rdata, 1);
     }
-    /// case others: use gemm
+    /// case others: use sgemm
     else {
       sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, data, lda, mdata,
             ldb, beta, rdata, ldc);