From f9a4cd48336dac92e1f0ee357c0a6b4719b21c42 Mon Sep 17 00:00:00 2001
From: Debadri Samaddar
Date: Tue, 23 Apr 2024 12:00:16 +0530
Subject: [PATCH] [hgemm] hgemm noTrans with 1x4 kernel

Added hgemm_kernel_1x4
Added hgemm_noTrans_1x4 calls
Added unittest dot_gemm_50_768_516

Signed-off-by: Debadri Samaddar
---
 nntrainer/tensor/hgemm/hgemm.cpp              | 138 +++++++++++++++++
 nntrainer/tensor/hgemm/hgemm.h                |  40 +++++
 nntrainer/tensor/hgemm/hgemm_kernel_1x4.h     | 144 ++++++++++++++++++
 .../unittest_nntrainer_tensor_neon_fp16.cpp   |  61 ++++++++
 4 files changed, 383 insertions(+)
 create mode 100644 nntrainer/tensor/hgemm/hgemm_kernel_1x4.h

diff --git a/nntrainer/tensor/hgemm/hgemm.cpp b/nntrainer/tensor/hgemm/hgemm.cpp
index d57f9a8d..a41a5ba6 100644
--- a/nntrainer/tensor/hgemm/hgemm.cpp
+++ b/nntrainer/tensor/hgemm/hgemm.cpp
@@ -13,6 +13,7 @@
  */

 #include
+#include
 #include
 #include
 #include
@@ -21,6 +22,7 @@
 #include
 #include

+#define HGEMM_KERNEL_1x4 hgemm_kernel_1x4
 #define HGEMM_KERNEL_4x4 hgemm_kernel_4x4
 #define HGEMM_KERNEL_1x8 hgemm_kernel_1x8
 #define HGEMM_KERNEL_4x8 hgemm_kernel_4x8
@@ -38,6 +40,8 @@ void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M,
       hgemm_noTrans_4x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
     } else if (N % 8 == 0) {
       hgemm_noTrans_1x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
+    } else if (N % 4 == 0) {
+      hgemm_noTrans_1x4(M, N, K, A, K, B, N, C32, N, alpha, beta);
     } else {
       hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
     }
@@ -58,10 +62,144 @@ void hgemm_noTrans(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
       hgemm_noTrans_1x8(M, N, K, A, K, B, N, C, N, alpha, beta);
     } else if (M % 4 == 0 && N % 4 == 0 && K % 4 == 0) {
       hgemm_noTrans_4x4(M, N, K, A, K, B, N, C, N, alpha, beta);
+    } else if (N % 4 == 0) {
+      hgemm_noTrans_1x4(M, N, K, A, K, B, N, C, N, alpha, beta);
     }
   }
 }
+
+void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, __fp16 *C, unsigned int ldc,
+                       float alpha, float beta) {
+  __fp16 *sa = alignedMalloc(M * K);
+  __fp16 *sb = alignedMalloc(K * N);
+
+  unsigned int ms, mms, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+      }
+      packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
+
+      for (mms = ms; mms < ms + m_min; mms += m2_min) {
+        m2_min = (ms + m_min) - mms;
+        if (m2_min >= 3 * GEMM_UNROLLING_1) {
+          m2_min = 3 * GEMM_UNROLLING_1;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_1) {
+          m2_min = 2 * GEMM_UNROLLING_1;
+        } else if (m2_min > GEMM_UNROLLING_1) {
+          m2_min = GEMM_UNROLLING_1;
+        }
+
+        packing_A1(m2_min, k_min, A + mms * lda + ks, lda,
+                   sa + k_min * (mms - ms));
+
+        HGEMM_KERNEL_1x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
+                         C + mms * ldc, ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min = (n_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+        }
+
+        packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+        HGEMM_KERNEL_1x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+      }
+    }
+  }
+
+  free(sa);
+  free(sb);
+}
+
+void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha, float beta) {
+  __fp16 *sa = alignedMalloc(M * K);
+  __fp16 *sb = alignedMalloc(K * N);
+
+  unsigned int ms, mms, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+      }
+      packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
+
+      for (mms = ms; mms < ms + m_min; mms += m2_min) {
+        m2_min = (ms + m_min) - mms;
+        if (m2_min >= 3 * GEMM_UNROLLING_1) {
+          m2_min = 3 * GEMM_UNROLLING_1;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_1) {
+          m2_min = 2 * GEMM_UNROLLING_1;
+        } else if (m2_min > GEMM_UNROLLING_1) {
+          m2_min = GEMM_UNROLLING_1;
+        }
+
+        packing_A1(m2_min, k_min, A + mms * lda + ks, lda,
+                   sa + k_min * (mms - ms));
+
+        HGEMM_KERNEL_1x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
+                         C + mms * ldc, ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min = (n_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+        }
+
+        packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+        HGEMM_KERNEL_1x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+      }
+    }
+  }
+
+  free(sa);
+  free(sb);
+}
+
 void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
                        const __fp16 *A, unsigned int lda, const __fp16 *B,
                        unsigned int ldb, __fp16 *C, unsigned int ldc,
diff --git a/nntrainer/tensor/hgemm/hgemm.h b/nntrainer/tensor/hgemm/hgemm.h
index 2d9cc2e3..b05d89cb 100644
--- a/nntrainer/tensor/hgemm/hgemm.h
+++ b/nntrainer/tensor/hgemm/hgemm.h
@@ -61,6 +61,46 @@ void hgemm_noTrans_fallback(unsigned int M, unsigned int N, unsigned int K,
                             unsigned int ldb, float *C, unsigned int ldc,
                             float alpha = 1.F, float beta = 0.F);

+/**
+ * @brief hgemm noTrans computation with 1x4 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, __fp16 *C, unsigned int ldc,
+                       float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 1x4 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha = 1.F, float beta = 0.F);
+
 /**
  * @brief hgemm noTrans computation with 4x4 kernel : C = A*B,
  *
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_1x4.h b/nntrainer/tensor/hgemm/hgemm_kernel_1x4.h
new file mode 100644
index 00000000..c189f636
--- /dev/null
+++ b/nntrainer/tensor/hgemm/hgemm_kernel_1x4.h
@@ -0,0 +1,144 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Debadri Samaddar
+ *
+ * @file   hgemm_kernel_1x4.h
+ * @date   23 April 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Debadri Samaddar
+ * @bug    No known bugs except for NYI items
+ * @brief  This is half-precision GEMM 1x4 kernel
+ *
+ */
+
+#include
+#include
+
+/**
+ * @brief hgemm 1x4 kernel sc = sa * sb
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading dimension of matrix C
+ */
+void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
+  assert(M > 0 && N > 0 && K > 0);
+  assert(N % 4 == 0);
+
+  __fp16 *a = sa, *b = sb, *c = sc;
+  unsigned int i, j, l;
+  for (i = 0; i < M; i++) {
+    for (j = 0; j < N; j += VL_FP16_HALF) {
+      __builtin_prefetch(b, 0, 3);
+      __builtin_prefetch(a, 0, 3);
+
+      for (l = 0; l < K; l += VL_FP16_HALF) {
+        float16x4_t v24 = {0.F};
+        float16x4_t v0 = vld1_f16(b);
+        float16_t v16 = *a;
+
+        v24 = vfma_n_f16(v24, v0, v16);
+
+        float16x4_t v1 = vld1_f16(b + 4);
+        float16_t v17 = *(a + 1);
+
+        v24 = vfma_n_f16(v24, v1, v17);
+
+        float16x4_t v2 = vld1_f16(b + 8);
+        float16_t v18 = *(a + 2);
+
+        v24 = vfma_n_f16(v24, v2, v18);
+
+        float16x4_t v3 = vld1_f16(b + 12);
+        float16_t v19 = *(a + 3);
+
+        v24 = vfma_n_f16(v24, v3, v19);
+
+        __builtin_prefetch(b + 16, 0, 3);
+        __builtin_prefetch(a + 4, 0, 3);
+
+        b += 16;
+        a += 4;
+
+        v24 = vadd_f16(vld1_f16(c), v24);
+
+        vst1_f16(c, v24);
+      }
+      c += 4;
+      a -= K;
+    }
+    sc += ldc;
+    c = sc;
+    a += K;
+    b = sb;
+  }
+}
+
+/**
+ * @brief hgemm 1x4 kernel sc = sa * sb
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading dimension of matrix C
+ */
+void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
+  assert(M > 0 && N > 0 && K > 0);
+  assert(N % 4 == 0);
+
+  __fp16 *a = sa, *b = sb;
+  float *c = sc;
+  unsigned int i, j, l;
+  for (i = 0; i < M; i++) {
+    for (j = 0; j < N; j += VL_FP16_HALF) {
+      __builtin_prefetch(b, 0, 3);
+      __builtin_prefetch(a, 0, 3);
+
+      for (l = 0; l < K; l += VL_FP16_HALF) {
+        float16x4_t v24 = {0.F};
+        float16x4_t v0 = vld1_f16(b);
+        float16_t v16 = *a;
+
+        v24 = vfma_n_f16(v24, v0, v16);
+
+        float16x4_t v1 = vld1_f16(b + 4);
+        float16_t v17 = *(a + 1);
+
+        v24 = vfma_n_f16(v24, v1, v17);
+
+        float16x4_t v2 = vld1_f16(b + 8);
+        float16_t v18 = *(a + 2);
+
+        v24 = vfma_n_f16(v24, v2, v18);
+
+        float16x4_t v3 = vld1_f16(b + 12);
+        float16_t v19 = *(a + 3);
+
+        v24 = vfma_n_f16(v24, v3, v19);
+
+        __builtin_prefetch(b + 16, 0, 3);
+        __builtin_prefetch(a + 4, 0, 3);
+
+        b += 16;
+        a += 4;
+
+        vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(v24)));
+      }
+      c += 4;
+      a -= K;
+    }
+    sc += ldc;
+    c = sc;
+    a += K;
+    b = sb;
+  }
+}
diff --git a/test/unittest/unittest_nntrainer_tensor_neon_fp16.cpp b/test/unittest/unittest_nntrainer_tensor_neon_fp16.cpp
index 6454feea..53d01858 100644
--- a/test/unittest/unittest_nntrainer_tensor_neon_fp16.cpp
+++ b/test/unittest/unittest_nntrainer_tensor_neon_fp16.cpp
@@ -658,6 +658,67 @@ TEST(nntrainer_Tensor, dot_gemm_50_768_20000) {
   EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1);
 }

+TEST(nntrainer_Tensor, dot_gemm_50_768_516) {
+  /// @note GEMM : A X B = C
+  int batch = 1;
+  int channel = 1;
+  int height = 50;
+  int width = 768;
+
+  int height_b = 768;
+  int width_b = 516;
+
+  bool transA = false;
+  bool transB = false;
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp16 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16};
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp32 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32};
+
+  nntrainer::Tensor A(batch, channel, height, width, t_type_nchw_fp16);
+  nntrainer::Tensor B(batch, channel, height_b, width_b, t_type_nchw_fp16);
+
+  nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32);
+  nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32);
+
+  const float alpha = 1e-1;
+  const int MOD = 10;
+
+  GEN_TEST_INPUT(A, ((i * (batch * height * channel) + j * (batch * height) +
+                      k * (width) + l + 1) %
+                     MOD) *
+                      alpha);
+  GEN_TEST_INPUT_B(B, ((i * (batch * height_b * channel) +
+                        j * (batch * height_b) + k * (width_b) + l + 1) %
+                       MOD) *
+                        alpha);
+  GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) +
+                           j * (batch * height) + k * (width) + l + 1) %
+                          MOD) *
+                           alpha);
+  GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) +
+                             j * (batch * height_b) + k * (width_b) + l + 1) %
+                            MOD) *
+                             alpha);
+
+  nntrainer::Tensor C = A.dot(B, transA, transB);
+
+  nntrainer::Tensor C_fp32 = A_fp32.dot(B_fp32, transA, transB);
+
+  float mseErrorNeon =
+    mse<__fp16>(C.getData<__fp16>(), C_fp32.getData(), C.size());
+
+  double cosSimNeon = cosine_similarity<__fp16>(
+    C.getData<__fp16>(), C_fp32.getData(), C.size());
+
+  const float epsilon = 1e-3 * width;
+
+  EXPECT_IN_RANGE(mseErrorNeon, 0, epsilon);
+  EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1);
+}
+
 TEST(nntrainer_Tensor, dot_gemv_768_96000) {
   /// @note GEMV : A X B = C
   int batch = 1;
--
2.34.1
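Reviewer note (sketch, not part of the patch): the 1x4 micro-kernel walks one row of A at a time and, for each group of four packed columns of B, accumulates C[i][j..j+3] += A[i][l] * B[l][j..j+3] via the broadcast FMA vfma_n_f16, which is why the new dispatch branch only needs N % 4 == 0. The plain C++ reference below illustrates that arithmetic on ordinary row-major, unpacked buffers; the name gemm_1x4_reference is hypothetical, float is used instead of __fp16 so it builds anywhere, and it deliberately ignores the packing_A1/packing_B4 layout, cache blocking, and alpha/beta handling of the real hgemm_noTrans_1x4 path.

#include <cstddef>

// Scalar reference for the 1x4 tile update: C[i][j..j+3] += A[i][l] * B[l][j..j+3].
// Assumes row-major, unpacked inputs and N divisible by 4.
void gemm_1x4_reference(std::size_t M, std::size_t N, std::size_t K,
                        const float *A, std::size_t lda, const float *B,
                        std::size_t ldb, float *C, std::size_t ldc) {
  for (std::size_t i = 0; i < M; ++i) {           // one row of A / C at a time
    for (std::size_t j = 0; j + 4 <= N; j += 4) { // four columns of B / C
      float acc[4] = {0.0f, 0.0f, 0.0f, 0.0f};
      for (std::size_t l = 0; l < K; ++l) {
        const float a = A[i * lda + l];           // scalar broadcast of A, as in vfma_n_f16
        for (std::size_t v = 0; v < 4; ++v) {
          acc[v] += a * B[l * ldb + j + v];
        }
      }
      for (std::size_t v = 0; v < 4; ++v) {
        C[i * ldc + j + v] += acc[v];             // accumulate into C, like the load/add/store in the kernel
      }
    }
  }
}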