From 61f007507ebee82b73d343fb300e27cd6048503a Mon Sep 17 00:00:00 2001
From: skykongkong8
Date: Wed, 19 Jul 2023 14:18:42 +0900
Subject: [PATCH] [WIP] [Tensor] Add __fp16 supporting functions in
 blas_interface

* Add __fp16 support with #ifdef, and parameter overloading
* (trivial) fix typo
* TODO: replace with valid __fp16 supporting functions

Signed-off-by: skykongkong8
---
 nntrainer/tensor/blas_interface.cpp | 212 +++++++++++++++++++++++++++++++++++-
 nntrainer/tensor/blas_interface.h   |  37 +++++++
 nntrainer/tensor/tensor.cpp         |   6 +-
 3 files changed, 251 insertions(+), 4 deletions(-)

diff --git a/nntrainer/tensor/blas_interface.cpp b/nntrainer/tensor/blas_interface.cpp
index 5335db3..cd81d90 100644
--- a/nntrainer/tensor/blas_interface.cpp
+++ b/nntrainer/tensor/blas_interface.cpp
@@ -12,8 +12,8 @@
  */
 
 #include
-#include
 #include
+#include
 
 #include
 
@@ -42,6 +42,15 @@ static void saxpy_raw(const unsigned int N, const float alpha, const float *X,
     Y[i * incY] = Y[i * incY] + X[i * incX] * alpha;
 }
 
+static void saxpy_FP16(const unsigned int N, const float alpha, const __fp16 *X,
+                       const int incX, __fp16 *Y, const int incY) {
+  if (incX < 0 or incY < 0)
+    throw std::invalid_argument(
+      "Error: negative inc not supported without cblas");
+  for (unsigned int i = 0; i < N; ++i)
+    Y[i * incY] = Y[i * incY] + X[i * incX] * alpha;
+}
+
 static void sgemv_raw(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
                       const unsigned int M, const unsigned int N,
                       const float alpha, const float *A, const unsigned int lda,
@@ -58,6 +67,22 @@ static void sgemv_raw(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
   }
 }
 
+static void sgemv_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
+                       const unsigned int M, const unsigned int N,
+                       const float alpha, const __fp16 *A,
+                       const unsigned int lda, const __fp16 *X, const int incX,
+                       const float beta, __fp16 *Y, const int incY) {
+
+  unsigned int incy = abs(incY);
+  unsigned int incx = abs(incX);
+
+  if (TransA == CblasTrans) {
+    sgemv_loop(i, j, N, M);
+  } else {
+    sgemv_loop(j, i, M, N);
+  }
+}
+
 static float sdot_raw(const unsigned int N, const float *X,
                       const unsigned int incX, const float *Y,
                       const unsigned int incY) {
@@ -68,6 +93,16 @@ static float sdot_raw(const unsigned int N, const float *X,
   return ret;
 }
 
+static __fp16 sdot_FP16(const unsigned int N, const __fp16 *X,
+                        const unsigned int incX, const __fp16 *Y,
+                        const unsigned int incY) {
+  __fp16 ret = 0;
+  for (unsigned int i = 0; i < N; ++i) {
+    ret += X[i * incX] * Y[i * incY];
+  }
+  return ret;
+}
+
 static void scopy_raw(const unsigned int N, const float *X, const int incX,
                       float *Y, const int incY) {
   unsigned int incy = abs(incY);
@@ -141,6 +176,18 @@ static float snrm2_raw(const unsigned int N, const float *X, const int incX) {
   return sqrt(sum);
 }
 
+static float snrm2_FP16(const unsigned int N, const __fp16 *X, const int incX) {
+  unsigned int incx = abs(incX);
+  __fp16 sum = 0.0f;
+  __fp16 tmp;
+#pragma omp parallel for private(tmp) reduction(+ : sum)
+  for (unsigned int i = 0; i < N; i++) {
+    tmp = X[i * incx];
+    sum += tmp * tmp;
+  }
+  return sqrt(sum);
+}
+
 static void sgemm_raw(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
                       CBLAS_TRANSPOSE TransB, const unsigned int M,
                       const unsigned int N, const unsigned int K,
@@ -165,6 +212,31 @@ static void sgemm_raw(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
   }
 }
 
+static void sgemm_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
+                       CBLAS_TRANSPOSE TransB, const unsigned int M,
+                       const unsigned int N, const unsigned int K,
+                       const float alpha, const __fp16 *A,
+                       const unsigned int lda, const __fp16 *B,
+                       const unsigned int ldb, const float beta, __fp16 *C,
+                       const unsigned int ldc) {
+
+  for (unsigned int m = 0; m < M; ++m) {
+    for (unsigned int n = 0; n < N; ++n) {
+      double c = 0.0;
+      __fp16 c_old = C[m * ldc + n];
+      for (unsigned int k = 0; k < K; ++k) {
+        __fp16 a, b;
+        a = ((TransA == CblasTrans) ? A[k * lda + m] : A[m * lda + k]);
+        b = ((TransB == CblasTrans) ? B[n * ldb + k] : B[k * ldb + n]);
+        c += a * b;
+      }
+      C[m * ldc + n] = alpha * c;
+      if (beta != 0.0)
+        C[m * ldc + n] += beta * c_old;
+    }
+  }
+}
+
 static unsigned int isamax_raw(const unsigned int N, const float *X,
                                const int incX) {
 
@@ -181,7 +253,40 @@ static unsigned int isamax_raw(const unsigned int N, const float *X,
   return max_idx;
 }
 
+static unsigned int isamax_FP16(const unsigned int N, const __fp16 *X,
+                                const int incX) {
+
+  unsigned int max_idx = 0;
+  __fp16 max_val = X[0];
+  for (unsigned int n = 1; n < N; n += incX) {
+    __fp16 cur_val = abs(X[n]);
+    if (cur_val > max_val) {
+      max_val = cur_val;
+      max_idx = n;
+    }
+  }
+
+  return max_idx;
+}
+
+#endif
+
+void saxpy(const unsigned int N, const float alpha, const void *X,
+           const int incX, void *Y, const int incY,
+           ml::train::TensorDim::DataType d_type) {
+#ifdef USE_BLAS
+#ifdef BLAS_NUM_THREADS
+  openblas_set_num_threads(BLAS_NUM_THREADS);
+#endif
+  cblas_saxpy(N, alpha, X, incX, Y, incY);
+#else
+  if (d_type == ml::train::TensorDim::DataType::FP32) {
+    saxpy_raw(N, alpha, (float *)X, incX, (float *)Y, incY);
+  } else if (d_type == ml::train::TensorDim::DataType::FP16) {
+    saxpy_FP16(N, alpha, (__fp16 *)X, incX, (__fp16 *)Y, incY);
+  }
 #endif
+}
 
 void saxpy(const unsigned int N, const float alpha, const float *X,
            const int incX, float *Y, const int incY) {
@@ -195,6 +300,59 @@ void saxpy(const unsigned int N, const float alpha, const float *X,
 #endif
 }
 
+void saxpy(const unsigned int N, const float alpha, const __fp16 *X,
+           const int incX, __fp16 *Y, const int incY) {
+  saxpy_FP16(N, alpha, X, incX, Y, incY);
+}
+
+void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
+           const unsigned int M, const unsigned int N, const unsigned int K,
+           const float alpha, const void *A, const unsigned int lda,
+           const void *B, const unsigned int ldb, const float beta, void *C,
+           const unsigned int ldc, ml::train::TensorDim::DataType d_type) {
+#ifdef USE_CUBLAS
+  int devID = 0;
+  cudaDeviceProp deviceProp;
+  cudaGetDeviceProperties(&deviceProp, devID);
+  float *d_A, *d_B, *d_C;
+
+  unsigned int size_A = M * K * sizeof(float);
+  unsigned int size_B = K * N * sizeof(float);
+  unsigned int size_C = M * N * sizeof(float);
+
+  cudaMalloc((void **)&d_A, size_A);
+  cudaMalloc((void **)&d_B, size_B);
+  cudaMemcpy(d_A, A, size_A, cudaMemcpyHostToDevice);
+  cudaMemcpy(d_B, B, size_B, cudaMemcpyHostToDevice);
+  cudaMalloc((void **)&d_C, size_C);
+
+  cublasHandle_t handle;
+  cublasCreate(&handle);
+
+  cublasOperation_t transA = (TransA == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
+  cublasOperation_t transB = (TransB == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
+  cublasSgemm(handle, transA, transB, N, M, K, &alpha, d_B, N, d_A, K, &beta,
+              d_C, N);
+
+  cudaMemcpy(C, d_C, size_C, cudaMemcpyDeviceToHost);
+  cublasDestroy(handle);
+#elif defined USE_BLAS
+#ifdef BLAS_NUM_THREADS
+  openblas_set_num_threads(BLAS_NUM_THREADS);
+#endif
+  cblas_sgemm(order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
+              ldc);
+#else
+  if (d_type == ml::train::TensorDim::DataType::FP32) {
+    sgemm_raw(order, TransA, TransB, M, N, K, alpha, (float *)A, lda,
+              (float *)B, ldb, beta, (float *)C, ldc);
+  } else if (d_type == ml::train::TensorDim::DataType::FP16) {
+    sgemm_FP16(order, TransA, TransB, M, N, K, alpha, (__fp16 *)A, lda,
+               (__fp16 *)B, ldb, beta, (__fp16 *)C, ldc);
+  }
+#endif
+}
+
 void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
            const unsigned int M, const unsigned int N, const unsigned int K,
            const float alpha, const float *A, const unsigned int lda,
@@ -239,6 +397,15 @@ void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
 #endif
 }
 
+void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
+           const unsigned int M, const unsigned int N, const unsigned int K,
+           const float alpha, const __fp16 *A, const unsigned int lda,
+           const __fp16 *B, const unsigned int ldb, const float beta, __fp16 *C,
+           const unsigned int ldc) {
+  sgemm_FP16(order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
+             ldc);
+}
+
 void scopy(const unsigned int N, const void *X, const int incX, void *Y,
            const int incY, ml::train::TensorDim::DataType d_type) {
 #ifdef USE_BLAS
@@ -286,6 +453,10 @@ float snrm2(const int N, const float *X, const int incX) {
 #endif
 }
 
+__fp16 snrm2(const int N, const __fp16 *X, const int incX) {
+  return snrm2_FP16(N, X, incX);
+}
+
 float sdot(const unsigned int N, const float *X, const unsigned int incX,
            const float *Y, const unsigned int incY) {
 #ifdef USE_BLAS
@@ -298,6 +469,33 @@ float sdot(const unsigned int N, const float *X, const unsigned int incX,
 #endif
 }
 
+__fp16 sdot(const unsigned int N, const __fp16 *X, const unsigned int incX,
+            const __fp16 *Y, const unsigned int incY) {
+  return sdot_FP16(N, X, incX, Y, incY);
+}
+
+void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
+           const unsigned int N, const float alpha, const void *A,
+           const unsigned int lda, const void *X, const int incX,
+           const float beta, void *Y, const int incY,
+           ml::train::TensorDim::DataType d_type) {
+#ifdef USE_BLAS
+#ifdef BLAS_NUM_THREADS
+  openblas_set_num_threads(BLAS_NUM_THREADS);
+#endif
+  return cblas_sgemv(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y,
+                     incY);
+#else
+  if (d_type == ml::train::TensorDim::DataType::FP32) {
+    return sgemv_raw(order, TransA, M, N, alpha, (float *)A, lda, (float *)X,
+                     incX, beta, (float *)Y, incY);
+  } else if (d_type == ml::train::TensorDim::DataType::FP16) {
+    return sgemv_FP16(order, TransA, M, N, alpha, (__fp16 *)A, lda, (__fp16 *)X,
+                      incX, beta, (__fp16 *)Y, incY);
+  }
+#endif
+}
+
 void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
            const unsigned int N, const float alpha, const float *A,
            const unsigned int lda, const float *X, const int incX,
@@ -313,6 +511,13 @@ void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
 #endif
 }
 
+void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
+           const unsigned int N, const float alpha, const __fp16 *A,
+           const unsigned int lda, const __fp16 *X, const int incX,
+           const float beta, __fp16 *Y, const int incY) {
+  sgemv_FP16(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
+}
+
 unsigned int isamax(const unsigned int N, const float *X, const int incX) {
 #ifdef USE_BLAS
 #ifdef BLAS_NUM_THREADS
@@ -324,4 +529,9 @@ unsigned int isamax(const unsigned int N, const float *X, const int incX) {
 #endif
 }
 
+unsigned int isamax(const unsigned int N, const __fp16 *X, const int incX) {
+  /// @todo isamax_FP16 for BLAS_NUM_THREADS
+  return isamax_FP16(N, X, incX);
+}
+
 } // namespace nntrainer
diff --git a/nntrainer/tensor/blas_interface.h b/nntrainer/tensor/blas_interface.h
index b560c28..63274a9 100644
--- a/nntrainer/tensor/blas_interface.h
+++ b/nntrainer/tensor/blas_interface.h
@@ -48,6 +48,8 @@ void sscal(const unsigned int N, const float alpha, __fp16 *X, const int incX);
 
 float snrm2(const int N, const float *X, const int incX);
 
+__fp16 snrm2(const int N, const __fp16 *X, const int incX);
+
 void scopy(const unsigned int N, const void *X, const int incX, void *Y,
            const int incY, ml::train::TensorDim::DataType d_type);
 
@@ -60,22 +62,57 @@ void scopy(const unsigned int N, const __fp16 *X, const int incX, __fp16 *Y,
 float sdot(const unsigned int N, const float *X, const unsigned int incX,
            const float *Y, const unsigned int incY);
 
+__fp16 sdot(const unsigned int N, const __fp16 *X, const unsigned int incX,
+            const __fp16 *Y, const unsigned int incY);
+
+void saxpy(const unsigned int N, const float alpha, const void *X,
+           const int incX, void *Y, const int incY,
+           ml::train::TensorDim::DataType d_type);
+
 void saxpy(const unsigned int N, const float alpha, const float *X,
            const int incX, float *Y, const int incY);
 
+void saxpy(const unsigned int N, const float alpha, const __fp16 *X,
+           const int incX, __fp16 *Y, const int incY);
+
+void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
+           const unsigned int M, const unsigned int N, const unsigned int K,
+           const float alpha, const void *A, const unsigned int lda,
+           const void *B, const unsigned int ldb, const float beta, void *C,
+           const unsigned int ldc, ml::train::TensorDim::DataType d_type);
+
 void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
            const unsigned int M, const unsigned int N, const unsigned int K,
            const float alpha, const float *A, const unsigned int lda,
            const float *B, const unsigned int ldb, const float beta, float *C,
            const unsigned int ldc);
 
+void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
+           const unsigned int M, const unsigned int N, const unsigned int K,
+           const float alpha, const __fp16 *A, const unsigned int lda,
+           const __fp16 *B, const unsigned int ldb, const float beta, __fp16 *C,
+           const unsigned int ldc);
+
+void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
+           const unsigned int N, const float alpha, const void *A,
+           const unsigned int lda, const void *X, const int incX,
+           const float beta, void *Y, const int incY,
+           ml::train::TensorDim::DataType d_type);
+
 void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
            const unsigned int N, const float alpha, const float *A,
            const unsigned int lda, const float *X, const int incX,
            const float beta, float *Y, const int incY);
 
+void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
+           const unsigned int N, const float alpha, const __fp16 *A,
+           const unsigned int lda, const __fp16 *X, const int incX,
+           const float beta, __fp16 *Y, const int incY);
+
 unsigned int isamax(const unsigned int N, const float *X, const int incX);
 
+unsigned int isamax(const unsigned int N, const __fp16 *X, const int incX);
+
 } /* namespace nntrainer */
 #endif /* __cplusplus */
 #endif /* __BLAS_INTERFACE_H__ */
diff --git a/nntrainer/tensor/tensor.cpp b/nntrainer/tensor/tensor.cpp
index 6e45c4d..a2681a9 100644
--- a/nntrainer/tensor/tensor.cpp
+++ b/nntrainer/tensor/tensor.cpp
@@ -1355,7 +1355,7 @@ Tensor Tensor::cat(const std::vector &tensors, int axis) {
   auto iter_value = [is_format_nchw](std::array &loc,
                                      const std::array &start_loc, Tensor &t,
-                                     const std::array &ref_dim_arr) -> float & {
+                                     const std::array &ref_dim_arr) -> __fp16 & {
     auto &value = is_format_nchw
                     ? t.getValue<__fp16>(loc[0], loc[1], loc[2], loc[3])
                     : t.getValue<__fp16>(loc[0], loc[3], loc[1], loc[2]);
 
@@ -2090,7 +2090,7 @@ Tensor &Tensor::dot(Tensor const &m, Tensor &result, bool trans, bool trans_m,
     const __fp16 *data = getData<__fp16>();
     const __fp16 *mdata = m.getData<__fp16>();
     __fp16 *rdata = result.getData<__fp16>();
-    const __fp16 alpha = 1.0f;
+    const float alpha = 1.0f;
     enum CBLAS_TRANSPOSE transA = trans ? CblasTrans : CblasNoTrans;
     enum CBLAS_TRANSPOSE transB = trans_m ? CblasTrans : CblasNoTrans;
 
@@ -2117,7 +2117,7 @@ Tensor &Tensor::dot(Tensor const &m, Tensor &result, bool trans, bool trans_m,
       sgemv(CblasRowMajor, transB, mdim1, mdim2, alpha, mdata, ldb, data, 1,
             beta, rdata, 1);
     }
-    /// case others: use gemm
+    /// case others: use sgemm
     else {
       sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, data, lda, mdata,
             ldb, beta, rdata, ldc);
-- 
2.7.4
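Usage sketch (not part of the patch): the snippet below shows how the new d_type-dispatching sgemm() entry point is meant to be called with raw __fp16 buffers. It is a minimal sketch under a few assumptions: a build without USE_BLAS/USE_CUBLAS (so the call falls through to sgemm_FP16), a toolchain with native __fp16 support, and a blas_interface.h that exposes the CBLAS_* enums and the ml::train::TensorDim::DataType enum referenced above.

  #include <blas_interface.h>
  #include <vector>

  int main() {
    const unsigned int M = 2, N = 3, K = 4;

    // Row-major operands with arbitrary fill values (assumed layout).
    std::vector<__fp16> A(M * K, static_cast<__fp16>(1.0f)); // M x K
    std::vector<__fp16> B(K * N, static_cast<__fp16>(0.5f)); // K x N
    std::vector<__fp16> C(M * N, static_cast<__fp16>(0.0f)); // M x N

    // C = 1.0 * A * B + 0.0 * C; the d_type argument routes the void*
    // overload to the __fp16 path when no BLAS backend is compiled in.
    nntrainer::sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
                     /*alpha=*/1.0f, A.data(), /*lda=*/K, B.data(), /*ldb=*/N,
                     /*beta=*/0.0f, C.data(), /*ldc=*/N,
                     ml::train::TensorDim::DataType::FP16);

    // With these fills every entry of C should be K * 0.5 = 2.0.
    return C[0] == static_cast<__fp16>(2.0f) ? 0 : 1;
  }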