From: skykongkong8 Date: Wed, 10 Jul 2024 08:34:43 +0000 (+0900) Subject: [ hgemm/refactor ] Refactor hgemm file structure X-Git-Tag: accepted/tizen/7.0/unified/20240830.164841~44 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d257930267ed609678aa1dfb64897549d7274e38;p=platform%2Fcore%2Fml%2Fnntrainer.git [ hgemm/refactor ] Refactor hgemm file structure - Kernel functions are used regardless of matrix transpose, so they need to be included from a separate file. - For further optimal implementation of the matrix A / B / AB transpose blocking-kernel sequences, divide their files for convenience. - Function 'hgemm' itself should reside in the hgemm directory. **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: skykongkong8 --- diff --git a/nntrainer/tensor/blas_neon.h b/nntrainer/tensor/blas_neon.h index b7be55d4..81f8c060 100644 --- a/nntrainer/tensor/blas_neon.h +++ b/nntrainer/tensor/blas_neon.h @@ -327,9 +327,9 @@ unsigned int isamax(const unsigned int N, const __fp16 *X); * @param[in] alpha float number * @param[in] beta float number */ -void custom_hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, uint32_t N, - uint32_t K, float alpha, float beta, bool TransA, bool TransB); - +void custom_hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, + uint32_t N, uint32_t K, float alpha, float beta, bool TransA, + bool TransB); /** * @brief squared root transformation with neon : X = sqrt(X) diff --git a/nntrainer/tensor/hgemm/hgemm.cpp b/nntrainer/tensor/hgemm/hgemm.cpp index 81d22b22..cadb28bb 100644 --- a/nntrainer/tensor/hgemm/hgemm.cpp +++ b/nntrainer/tensor/hgemm/hgemm.cpp @@ -12,19 +12,20 @@ * */ +#include +#include #include +#include #include #include #include #include #include #include -#include -#include - -void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M, unsigned int N, - unsigned int K, float alpha, float beta, bool TransA, bool TransB) { +void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M, + unsigned int N, unsigned int K, float alpha, float beta, bool TransA, + bool TransB) { if (K == 1) { return hgemm_K1(A, B, C, M, N, K, alpha, beta, TransA, TransB); } @@ -67,9 +68,9 @@ void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M, unsigned } hgemm_ensure_divisibility(A, B, C32, M, N, K, alpha, beta, TransA, TransB); - - unsigned int L = M*N; - unsigned int L8 = (L >> 3) <<3; + + unsigned int L = M * N; + unsigned int L8 = (L >> 3) << 3; for (unsigned int idx = 0; idx < L8; idx += 8) { float32x4_t x1 = vld1q_f32(&C32[idx]); @@ -151,11 +152,11 @@ void hgemm_classify(const __fp16 *A, const __fp16 *B, float *C32, } void hgemm_K1(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M, - unsigned int N, unsigned int K, float alpha, float beta, bool TransA, - bool TransB) { + unsigned int N, unsigned int K, float alpha, float beta, + bool TransA, bool TransB) { unsigned int lda = (TransA) ? M : K; unsigned int ldb = (TransB) ?
K : N; - + return hgemm_K1_noTrans(M, N, K, A, lda, B, ldb, C, N, alpha, beta); if (!TransA && TransB) { diff --git a/nntrainer/tensor/hgemm/hgemm.h b/nntrainer/tensor/hgemm/hgemm.h index d2dd2894..a0c7b6f9 100644 --- a/nntrainer/tensor/hgemm/hgemm.h +++ b/nntrainer/tensor/hgemm/hgemm.h @@ -24,8 +24,9 @@ * @param[in] alpha float number * @param[in] beta float number */ -void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M, unsigned int N, - unsigned int K, float alpha, float beta, bool TransA, bool TransB); +void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M, + unsigned int N, unsigned int K, float alpha, float beta, bool TransA, + bool TransB); /** * @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C, @@ -54,9 +55,10 @@ void hgemm_ensure_divisibility(const __fp16 *A, const __fp16 *B, float *C32, * @param[in] alpha float number * @param[in] beta float number */ -void hgemm_classify(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M, - unsigned int N, unsigned int K, float alpha = 1.F, float beta = 0.F, - bool TransA = false, bool TransB = false); +void hgemm_classify(const __fp16 *A, const __fp16 *B, float *C32, + unsigned int M, unsigned int N, unsigned int K, + float alpha = 1.F, float beta = 0.F, bool TransA = false, + bool TransB = false); /** * @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C, * where op(X) is one of X or X**T @@ -70,5 +72,5 @@ void hgemm_classify(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M * @param[in] beta float number */ void hgemm_K1(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M, - unsigned int N, unsigned int K, float alpha, float beta, bool TransA, - bool TransB); + unsigned int N, unsigned int K, float alpha, float beta, + bool TransA, bool TransB); diff --git a/nntrainer/tensor/hgemm/hgemm_common.h b/nntrainer/tensor/hgemm/hgemm_common.h index bdf9bcbc..a041e431 100644 --- a/nntrainer/tensor/hgemm/hgemm_common.h +++ b/nntrainer/tensor/hgemm/hgemm_common.h @@ -10,13 +10,10 @@ * @brief This is common settings for hgemm * */ -#include -#include - -#define A(i, j) a[(i) * lda + (j)] -#define B(i, j) b[(i) * ldb + (j)] -#define C(i, j) c[(i) * ldc + (j)] +#define A(i, j) a[(i)*lda + (j)] +#define B(i, j) b[(i)*ldb + (j)] +#define C(i, j) c[(i)*ldc + (j)] #define N_BLOCKING (768) #define K_BLOCKING (256) @@ -27,9 +24,3 @@ #define GEMM_UNROLLING_1 (1) #define VL_FP16 (8) #define VL_FP16_HALF (4) - - - -/** - * @todo Add macro for instructions in other CPU architectures - */ diff --git a/nntrainer/tensor/hgemm/hgemm_kernel.h b/nntrainer/tensor/hgemm/hgemm_kernel.h deleted file mode 100644 index 4bcea0fe..00000000 --- a/nntrainer/tensor/hgemm/hgemm_kernel.h +++ /dev/null @@ -1,13 +0,0 @@ -// #include -// #include -// #include -// #include -// #include -// #include - -// #define HGEMM_KERNEL_1x4 hgemm_kernel_1x4 -// #define HGEMM_KERNEL_4x4 hgemm_kernel_4x4 -// #define HGEMM_KERNEL_1x8 hgemm_kernel_1x8 -// #define HGEMM_KERNEL_4x8 hgemm_kernel_4x8 -// #define HGEMM_KERNEL_8x8 hgemm_kernel_8x8 -// #define HGEMM_KERNEL_8x16 hgemm_kernel_8x16 \ No newline at end of file diff --git a/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel.h b/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel.h new file mode 100644 index 00000000..2ebc8b46 --- /dev/null +++ b/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel.h @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Sungsik Kong + * + * @file hgemm_kernel.h + * @date 10 July 2024 + * @see 
https://github.com/nnstreamer/nntrainer + * @author Sungsik Kong + * @bug No known bugs except for NYI items + * @brief This is a collection of all the KERNELs function for hgemm + * + */ + +/** + * @brief hgemm_kernel_8x16 KERNEL function + * + * @param M Length of blocked M + * @param N Length of blocked N + * @param K Length of blocked K + * @param sa Starting address of blocked A + * @param sb Starting address of blocked B + * @param sc Starting address of blocked C + * @param ldc Leading dimension of original matrix C + */ +void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc); +/** + * @brief hgemm_kernel_8x16 KERNEL function + * + * @param M Length of blocked M + * @param N Length of blocked N + * @param K Length of blocked K + * @param sa Starting address of blocked A + * @param sb Starting address of blocked B + * @param sc Starting address of blocked C + * @param ldc Leading dimension of original matrix C + */ +void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc); +/** + * @brief hgemm_kernel_8x8 KERNEL function + * + * @param M Length of blocked M + * @param N Length of blocked N + * @param K Length of blocked K + * @param sa Starting address of blocked A + * @param sb Starting address of blocked B + * @param sc Starting address of blocked C + * @param ldc Leading dimension of original matrix C + */ +void hgemm_kernel_8x8(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc); +/** + * @brief hgemm_kernel_8x8 KERNEL function + * + * @param M Length of blocked M + * @param N Length of blocked N + * @param K Length of blocked K + * @param sa Starting address of blocked A + * @param sb Starting address of blocked B + * @param sc Starting address of blocked C + * @param ldc Leading dimension of original matrix C + */ +void hgemm_kernel_8x8(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc); +/** + * @brief hgemm_kernel_4x8 KERNEL function + * + * @param M Length of blocked M + * @param N Length of blocked N + * @param K Length of blocked K + * @param sa Starting address of blocked A + * @param sb Starting address of blocked B + * @param sc Starting address of blocked C + * @param ldc Leading dimension of original matrix C + */ +void hgemm_kernel_4x8(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc); +/** + * @brief hgemm_kernel_4x8 KERNEL function + * + * @param M Length of blocked M + * @param N Length of blocked N + * @param K Length of blocked K + * @param sa Starting address of blocked A + * @param sb Starting address of blocked B + * @param sc Starting address of blocked C + * @param ldc Leading dimension of original matrix C + */ +void hgemm_kernel_4x8(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc); +/** + * @brief hgemm_kernel_4x4 KERNEL function + * + * @param M Length of blocked M + * @param N Length of blocked N + * @param K Length of blocked K + * @param sa Starting address of blocked A + * @param sb Starting address of blocked B + * @param sc Starting address of blocked C + * @param ldc Leading dimension of original matrix C + */ +void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc); +/** + * @brief hgemm_kernel_4x4 KERNEL function + * + * @param M 
Length of blocked M + * @param N Length of blocked N + * @param K Length of blocked K + * @param sa Starting address of blocked A + * @param sb Starting address of blocked B + * @param sc Starting address of blocked C + * @param ldc Leading dimension of original matrix C + */ +void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc); +/** + * @brief hgemm_kernel_1x8 KERNEL function + * + * @param M Length of blocked M + * @param N Length of blocked N + * @param K Length of blocked K + * @param sa Starting address of blocked A + * @param sb Starting address of blocked B + * @param sc Starting address of blocked C + * @param ldc Leading dimension of original matrix C + */ +void hgemm_kernel_1x8(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc); +/** + * @brief hgemm_kernel_1x8 KERNEL function + * + * @param M Length of blocked M + * @param N Length of blocked N + * @param K Length of blocked K + * @param sa Starting address of blocked A + * @param sb Starting address of blocked B + * @param sc Starting address of blocked C + * @param ldc Leading dimension of original matrix C + */ +void hgemm_kernel_1x8(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc); +/** + * @brief hgemm_kernel_1x4 KERNEL function + * + * @param M Length of blocked M + * @param N Length of blocked N + * @param K Length of blocked K + * @param sa Starting address of blocked A + * @param sb Starting address of blocked B + * @param sc Starting address of blocked C + * @param ldc Leading dimension of original matrix C + */ +void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc); +/** + * @brief hgemm_kernel_1x4 KERNEL function + * + * @param M Length of blocked M + * @param N Length of blocked N + * @param K Length of blocked K + * @param sa Starting address of blocked A + * @param sb Starting address of blocked B + * @param sc Starting address of blocked C + * @param ldc Leading dimension of original matrix C + */ +void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc); diff --git a/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_1x4.cpp b/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_1x4.cpp new file mode 100644 index 00000000..2c301e59 --- /dev/null +++ b/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_1x4.cpp @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Debadri Samaddar + * + * @file hgemm_kernel_1x4.cpp + * @date 23 April 2024 + * @see https://github.com/nnstreamer/nntrainer + * @author Debadri Samaddar + * @bug No known bugs except for NYI items + * @brief This is half-precision GEMM 1x4 kernel + * + */ + +#include +#include +#include +#include + +/** + * @brief hgemm 1x4 kernel sc = sa * sb + * + * @param m length of the row of matrix A + * @param n length of the col of matrix B + * @param k length of the col of matrix A + * @param sa sub-matrix of input matrix A + * @param sb sub-matrix of input matrix B + * @param sc sub-matrix of output matrix C + * @param ldc leading dimension of matrix C + */ +void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) { + assert(M > 0 && N > 0 && K > 0); + assert(N % 4 == 0); + + __fp16 *a = sa, *b = sb, *c = sc; + unsigned int i, j, l; + for (i = 0; i < M; 
i++) { + for (j = 0; j < N; j += 4) { + __builtin_prefetch(b, 0, 3); + __builtin_prefetch(a, 0, 3); + + for (l = 0; l < K; l += 4) { + float16x4_t v24 = {0.F}; + float16x4_t v0 = vld1_f16(b); + float16_t v16 = *a; + + v24 = vfma_n_f16(v24, v0, v16); + + float16x4_t v1 = vld1_f16(b + 4); + float16_t v17 = *(a + 1); + + v24 = vfma_n_f16(v24, v1, v17); + + float16x4_t v2 = vld1_f16(b + 8); + float16_t v18 = *(a + 2); + + v24 = vfma_n_f16(v24, v2, v18); + + float16x4_t v3 = vld1_f16(b + 12); + float16_t v19 = *(a + 3); + + v24 = vfma_n_f16(v24, v3, v19); + + __builtin_prefetch(b + 16, 0, 3); + __builtin_prefetch(a + 4, 0, 3); + + b += 16; + a += 4; + + v24 = vadd_f16(vld1_f16(c), v24); + + vst1_f16(c, v24); + } + c += 4; + a -= K; + } + sc += ldc; + c = sc; + a += K; + b = sb; + } +} + +/** + * @brief hgemm 1x4 kernel sc = sa * sb + * + * @param m length of the row of matrix A + * @param n length of the col of matrix B + * @param k length of the col of matrix A + * @param sa sub-matrix of input matrix A + * @param sb sub-matrix of input matrix B + * @param sc sub-matrix of output matrix C + * @param ldc leading dimension of matrix C + */ +void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) { + assert(M > 0 && N > 0 && K > 0); + assert(N % 4 == 0); + + __fp16 *a = sa, *b = sb; + float *c = sc; + unsigned int i, j, l; + for (i = 0; i < M; i++) { + for (j = 0; j < N; j += 4) { + __builtin_prefetch(b, 0, 3); + __builtin_prefetch(a, 0, 3); + + for (l = 0; l < K; l += 4) { + float16x4_t v24 = {0.F}; + float16x4_t v0 = vld1_f16(b); + float16_t v16 = *a; + + v24 = vfma_n_f16(v24, v0, v16); + + float16x4_t v1 = vld1_f16(b + 4); + float16_t v17 = *(a + 1); + + v24 = vfma_n_f16(v24, v1, v17); + + float16x4_t v2 = vld1_f16(b + 8); + float16_t v18 = *(a + 2); + + v24 = vfma_n_f16(v24, v2, v18); + + float16x4_t v3 = vld1_f16(b + 12); + float16_t v19 = *(a + 3); + + v24 = vfma_n_f16(v24, v3, v19); + + __builtin_prefetch(b + 16, 0, 3); + __builtin_prefetch(a + 4, 0, 3); + + b += 16; + a += 4; + + vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(v24))); + } + c += 4; + a -= K; + } + sc += ldc; + c = sc; + a += K; + b = sb; + } +} diff --git a/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_1x8.cpp b/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_1x8.cpp new file mode 100644 index 00000000..35927e55 --- /dev/null +++ b/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_1x8.cpp @@ -0,0 +1,185 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Debadri Samaddar + * + * @file hgemm_kernel_1x8.cpp + * @date 05 April 2024 + * @see https://github.com/nnstreamer/nntrainer + * @author Debadri Samaddar + * @author Sungsik Kong + * @bug No known bugs except for NYI items + * @brief This is half-precision GEMM 1x8 kernel + * + */ + +#include +#include +#include +#include + +// 1. 
Partial sum 64 digits : worst accuracy, best latency +#define KERNEL_1x8_ACC8() \ + do { \ + v0 = vdupq_n_f16(0.F); \ + dv0 = *a; \ + v24 = vld1q_f16(b); \ + v0 = vfmaq_n_f16(v0, v24, dv0); \ + dv1 = *(a + 1); \ + v25 = vld1q_f16(b + 8); \ + v0 = vfmaq_n_f16(v0, v25, dv1); \ + dv2 = *(a + 2); \ + v26 = vld1q_f16(b + 16); \ + v0 = vfmaq_n_f16(v0, v26, dv2); \ + dv3 = *(a + 3); \ + v27 = vld1q_f16(b + 24); \ + v0 = vfmaq_n_f16(v0, v27, dv3); \ + dv4 = *(a + 4); \ + v28 = vld1q_f16(b + 32); \ + v0 = vfmaq_n_f16(v0, v28, dv4); \ + dv5 = *(a + 5); \ + v29 = vld1q_f16(b + 40); \ + v0 = vfmaq_n_f16(v0, v29, dv5); \ + dv6 = *(a + 6); \ + v30 = vld1q_f16(b + 48); \ + v0 = vfmaq_n_f16(v0, v30, dv6); \ + dv7 = *(a + 7); \ + v31 = vld1q_f16(b + 56); \ + v0 = vfmaq_n_f16(v0, v31, dv7); \ + l += 8; \ + b += 8 * 8; \ + a += 8; \ + } while (0) + +// 2. Partial sum 32 digits : medium accuracy, medium latency +#define KERNEL_1x8_ACC4() \ + do { \ + v0 = vdupq_n_f16(0.F); \ + dv0 = *a; \ + v24 = vld1q_f16(b); \ + v0 = vfmaq_n_f16(v0, v24, dv0); \ + dv1 = *(a + 1); \ + v25 = vld1q_f16(b + 8); \ + v0 = vfmaq_n_f16(v0, v25, dv1); \ + dv2 = *(a + 2); \ + v26 = vld1q_f16(b + 16); \ + v0 = vfmaq_n_f16(v0, v26, dv2); \ + dv3 = *(a + 3); \ + v27 = vld1q_f16(b + 24); \ + v0 = vfmaq_n_f16(v0, v27, dv3); \ + l += 4; \ + b += 8 * 4; \ + a += 4; \ + } while (0) + +// 3. Partial sum 8 digits : Best accuracy, worst latency +#define KERNEL_1x8_ACC1() \ + do { \ + v0 = vdupq_n_f16(0.F); \ + dv0 = *(a); \ + v24 = vld1q_f16(b); \ + v0 = vfmaq_n_f16(v0, v24, dv0); \ + l += 1; \ + b += 8 * 1; \ + a++; \ + } while (0) + +/** + * @brief hgemm 1x8 kernel sc = sa * sb + * + * @param m length of the row of matrix A + * @param n length of the col of matrix B + * @param k length of the col of matrix A + * @param sa sub-matrix of input matrix A + * @param sb sub-matrix of input matrix B + * @param sc sub-matrix of output matrix C + * @param ldc leading-dimension of matrix C + */ +void hgemm_kernel_1x8(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) { + assert(M > 0 && N > 0 && K > 0); + assert(N % 8 == 0); + + __fp16 *a = sa, *b = sb, *c = sc; + unsigned int k8 = (K >> 3) << 3; + unsigned int i, j, l; + for (i = 0; i < M; i++) { + for (j = 0; j < N; j += 8) { + __builtin_prefetch(b, 0, 3); + __builtin_prefetch(a, 0, 3); + float16x8_t v0; + float16x8_t v24, v25, v26, v27, v28, v29, v30, v31; + float16_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7; + l = 0; + for (; l < k8;) { + KERNEL_1x8_ACC8(); + + vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0)); + } + for (; l < K;) { + KERNEL_1x8_ACC1(); + + vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0)); + } + c += 8; + a -= K; + } + sc += ldc; + c = sc; + a += K; + b = sb; + } +} + +/** + * @brief hgemm 1x8 kernel sc = sa * sb + * + * @param m length of the row of matrix A + * @param n length of the col of matrix B + * @param k length of the col of matrix A + * @param sa sub-matrix of input matrix A + * @param sb sub-matrix of input matrix B + * @param sc sub-matrix of output matrix C + * @param ldc leading-dimension of matrix C + */ +void hgemm_kernel_1x8(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) { + assert(M > 0 && N > 0 && K > 0); + assert(N % 8 == 0); + + __fp16 *a = sa, *b = sb; + float *c = sc; + unsigned int k8 = (K >> 3) << 3; + unsigned int i, j, l; + for (i = 0; i < M; i++) { + for (j = 0; j < N; j += 8) { + __builtin_prefetch(b, 0, 3); + __builtin_prefetch(a, 0, 3); + float16x8_t v0; + 
float16x8_t v24, v25, v26, v27, v28, v29, v30, v31; + float16_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7; + l = 0; + for (; l < k8;) { + KERNEL_1x8_ACC8(); + + vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0)))); + + vst1q_f32(c + 4, + vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0)))); + } + for (; l < K;) { + KERNEL_1x8_ACC1(); + + vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0)))); + + vst1q_f32(c + 4, + vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0)))); + } + c += 8; + a -= K; + } + sc += ldc; + c = sc; + a += K; + b = sb; + } +} diff --git a/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_4x4.cpp b/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_4x4.cpp new file mode 100644 index 00000000..40ab4eae --- /dev/null +++ b/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_4x4.cpp @@ -0,0 +1,360 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Sungsik Kong + * + * @file hgemm_kernel_4x4.cpp + * @date 01 April 2024 + * @see https://github.com/nnstreamer/nntrainer + * @author Sungsik Kong + * @bug No known bugs except for NYI items + * @brief This is half-precision GEMM 4x4 kernel + * + */ + +#include +#include +#include +#include + +#define INIT_KERNEL_4x4() \ + do { \ + v24 = vdup_n_f16(0.F); \ + v25 = vdup_n_f16(0.F); \ + v26 = vdup_n_f16(0.F); \ + v27 = vdup_n_f16(0.F); \ + } while (0) + +// 1. Partial sum 256 digits +#define KERNEL_4x4_ACC16() \ + do { \ + dv0 = vld1_f16(a); \ + vb0 = vld1_f16(b); \ + v24 = vfma_lane_f16(v24, vb0, dv0, 0); \ + v25 = vfma_lane_f16(v25, vb0, dv0, 1); \ + v26 = vfma_lane_f16(v26, vb0, dv0, 2); \ + v27 = vfma_lane_f16(v27, vb0, dv0, 3); \ + dv1 = vld1_f16(a + 4); \ + vb1 = vld1_f16(b + 4); \ + v24 = vfma_lane_f16(v24, vb1, dv1, 0); \ + v25 = vfma_lane_f16(v25, vb1, dv1, 1); \ + v26 = vfma_lane_f16(v26, vb1, dv1, 2); \ + v27 = vfma_lane_f16(v27, vb1, dv1, 3); \ + dv2 = vld1_f16(a + 4 * 2); \ + vb2 = vld1_f16(b + 4 * 2); \ + v24 = vfma_lane_f16(v24, vb2, dv2, 0); \ + v25 = vfma_lane_f16(v25, vb2, dv2, 1); \ + v26 = vfma_lane_f16(v26, vb2, dv2, 2); \ + v27 = vfma_lane_f16(v27, vb2, dv2, 3); \ + dv3 = vld1_f16(a + 4 * 3); \ + vb3 = vld1_f16(b + 4 * 3); \ + v24 = vfma_lane_f16(v24, vb3, dv3, 0); \ + v25 = vfma_lane_f16(v25, vb3, dv3, 1); \ + v26 = vfma_lane_f16(v26, vb3, dv3, 2); \ + v27 = vfma_lane_f16(v27, vb3, dv3, 3); \ + dv4 = vld1_f16(a + 4 * 4); \ + vb4 = vld1_f16(b + 4 * 4); \ + v24 = vfma_lane_f16(v24, vb4, dv4, 0); \ + v25 = vfma_lane_f16(v25, vb4, dv4, 1); \ + v26 = vfma_lane_f16(v26, vb4, dv4, 2); \ + v27 = vfma_lane_f16(v27, vb4, dv4, 3); \ + dv5 = vld1_f16(a + 4 * 5); \ + vb5 = vld1_f16(b + 4 * 5); \ + v24 = vfma_lane_f16(v24, vb5, dv5, 0); \ + v25 = vfma_lane_f16(v25, vb5, dv5, 1); \ + v26 = vfma_lane_f16(v26, vb5, dv5, 2); \ + v27 = vfma_lane_f16(v27, vb5, dv5, 3); \ + dv6 = vld1_f16(a + 4 * 6); \ + vb6 = vld1_f16(b + 4 * 6); \ + v24 = vfma_lane_f16(v24, vb6, dv6, 0); \ + v25 = vfma_lane_f16(v25, vb6, dv6, 1); \ + v26 = vfma_lane_f16(v26, vb6, dv6, 2); \ + v27 = vfma_lane_f16(v27, vb6, dv6, 3); \ + dv7 = vld1_f16(a + 4 * 7); \ + vb7 = vld1_f16(b + 4 * 7); \ + v24 = vfma_lane_f16(v24, vb7, dv7, 0); \ + v25 = vfma_lane_f16(v25, vb7, dv7, 1); \ + v26 = vfma_lane_f16(v26, vb7, dv7, 2); \ + v27 = vfma_lane_f16(v27, vb7, dv7, 3); \ + dv7 = vld1_f16(a + 4 * 8); \ + vb7 = vld1_f16(b + 4 * 8); \ + v24 = vfma_lane_f16(v24, vb7, dv7, 0); \ + v25 = vfma_lane_f16(v25, vb7, dv7, 1); \ + v26 = vfma_lane_f16(v26, vb7, dv7, 2); \ + v27 = vfma_lane_f16(v27, vb7, dv7, 3); \ + dv7 = 
vld1_f16(a + 4 * 9); \ + vb7 = vld1_f16(b + 4 * 9); \ + v24 = vfma_lane_f16(v24, vb7, dv7, 0); \ + v25 = vfma_lane_f16(v25, vb7, dv7, 1); \ + v26 = vfma_lane_f16(v26, vb7, dv7, 2); \ + v27 = vfma_lane_f16(v27, vb7, dv7, 3); \ + dv7 = vld1_f16(a + 4 * 10); \ + vb7 = vld1_f16(b + 4 * 10); \ + v24 = vfma_lane_f16(v24, vb7, dv7, 0); \ + v25 = vfma_lane_f16(v25, vb7, dv7, 1); \ + v26 = vfma_lane_f16(v26, vb7, dv7, 2); \ + v27 = vfma_lane_f16(v27, vb7, dv7, 3); \ + dv7 = vld1_f16(a + 4 * 11); \ + vb7 = vld1_f16(b + 4 * 11); \ + v24 = vfma_lane_f16(v24, vb7, dv7, 0); \ + v25 = vfma_lane_f16(v25, vb7, dv7, 1); \ + v26 = vfma_lane_f16(v26, vb7, dv7, 2); \ + v27 = vfma_lane_f16(v27, vb7, dv7, 3); \ + dv7 = vld1_f16(a + 4 * 12); \ + vb7 = vld1_f16(b + 4 * 12); \ + v24 = vfma_lane_f16(v24, vb7, dv7, 0); \ + v25 = vfma_lane_f16(v25, vb7, dv7, 1); \ + v26 = vfma_lane_f16(v26, vb7, dv7, 2); \ + v27 = vfma_lane_f16(v27, vb7, dv7, 3); \ + dv7 = vld1_f16(a + 4 * 13); \ + vb7 = vld1_f16(b + 4 * 13); \ + v24 = vfma_lane_f16(v24, vb7, dv7, 0); \ + v25 = vfma_lane_f16(v25, vb7, dv7, 1); \ + v26 = vfma_lane_f16(v26, vb7, dv7, 2); \ + v27 = vfma_lane_f16(v27, vb7, dv7, 3); \ + dv7 = vld1_f16(a + 4 * 14); \ + vb7 = vld1_f16(b + 4 * 14); \ + v24 = vfma_lane_f16(v24, vb7, dv7, 0); \ + v25 = vfma_lane_f16(v25, vb7, dv7, 1); \ + v26 = vfma_lane_f16(v26, vb7, dv7, 2); \ + v27 = vfma_lane_f16(v27, vb7, dv7, 3); \ + dv7 = vld1_f16(a + 4 * 15); \ + vb7 = vld1_f16(b + 4 * 15); \ + v24 = vfma_lane_f16(v24, vb7, dv7, 0); \ + v25 = vfma_lane_f16(v25, vb7, dv7, 1); \ + v26 = vfma_lane_f16(v26, vb7, dv7, 2); \ + v27 = vfma_lane_f16(v27, vb7, dv7, 3); \ + l += 16; \ + __builtin_prefetch(b + 64, 0, 3); \ + __builtin_prefetch(a + 64, 0, 3); \ + b += 4 * 16; \ + a += 4 * 16; \ + } while (0) + +// 2. 
Partial sum 128 digits +#define KERNEL_4x4_ACC8() \ + do { \ + dv0 = vld1_f16(a); \ + vb0 = vld1_f16(b); \ + v24 = vfma_lane_f16(v24, vb0, dv0, 0); \ + v25 = vfma_lane_f16(v25, vb0, dv0, 1); \ + v26 = vfma_lane_f16(v26, vb0, dv0, 2); \ + v27 = vfma_lane_f16(v27, vb0, dv0, 3); \ + dv1 = vld1_f16(a + 4); \ + vb1 = vld1_f16(b + 4); \ + v24 = vfma_lane_f16(v24, vb1, dv1, 0); \ + v25 = vfma_lane_f16(v25, vb1, dv1, 1); \ + v26 = vfma_lane_f16(v26, vb1, dv1, 2); \ + v27 = vfma_lane_f16(v27, vb1, dv1, 3); \ + dv2 = vld1_f16(a + 8); \ + vb2 = vld1_f16(b + 8); \ + v24 = vfma_lane_f16(v24, vb2, dv2, 0); \ + v25 = vfma_lane_f16(v25, vb2, dv2, 1); \ + v26 = vfma_lane_f16(v26, vb2, dv2, 2); \ + v27 = vfma_lane_f16(v27, vb2, dv2, 3); \ + dv3 = vld1_f16(a + 12); \ + vb3 = vld1_f16(b + 12); \ + v24 = vfma_lane_f16(v24, vb3, dv3, 0); \ + v25 = vfma_lane_f16(v25, vb3, dv3, 1); \ + v26 = vfma_lane_f16(v26, vb3, dv3, 2); \ + v27 = vfma_lane_f16(v27, vb3, dv3, 3); \ + dv4 = vld1_f16(a + 16); \ + vb4 = vld1_f16(b + 16); \ + v24 = vfma_lane_f16(v24, vb4, dv4, 0); \ + v25 = vfma_lane_f16(v25, vb4, dv4, 1); \ + v26 = vfma_lane_f16(v26, vb4, dv4, 2); \ + v27 = vfma_lane_f16(v27, vb4, dv4, 3); \ + dv5 = vld1_f16(a + 20); \ + vb5 = vld1_f16(b + 20); \ + v24 = vfma_lane_f16(v24, vb5, dv5, 0); \ + v25 = vfma_lane_f16(v25, vb5, dv5, 1); \ + v26 = vfma_lane_f16(v26, vb5, dv5, 2); \ + v27 = vfma_lane_f16(v27, vb5, dv5, 3); \ + dv6 = vld1_f16(a + 24); \ + vb6 = vld1_f16(b + 24); \ + v24 = vfma_lane_f16(v24, vb6, dv6, 0); \ + v25 = vfma_lane_f16(v25, vb6, dv6, 1); \ + v26 = vfma_lane_f16(v26, vb6, dv6, 2); \ + v27 = vfma_lane_f16(v27, vb6, dv6, 3); \ + dv7 = vld1_f16(a + 28); \ + vb7 = vld1_f16(b + 28); \ + v24 = vfma_lane_f16(v24, vb7, dv7, 0); \ + v25 = vfma_lane_f16(v25, vb7, dv7, 1); \ + v26 = vfma_lane_f16(v26, vb7, dv7, 2); \ + v27 = vfma_lane_f16(v27, vb7, dv7, 3); \ + l += 8; \ + __builtin_prefetch(b + 32, 0, 3); \ + __builtin_prefetch(a + 32, 0, 3); \ + b += 4 * 8; \ + a += 4 * 8; \ + } while (0) + +// 3. 
Partial sum 16 digits +#define KERNEL_4x4_ACC1() \ + do { \ + dv0 = vld1_f16(a); \ + vb0 = vld1_f16(b); \ + v24 = vfma_lane_f16(v24, vb0, dv0, 0); \ + v25 = vfma_lane_f16(v25, vb0, dv0, 1); \ + v26 = vfma_lane_f16(v26, vb0, dv0, 2); \ + v27 = vfma_lane_f16(v27, vb0, dv0, 3); \ + l += 1; \ + __builtin_prefetch(b + 4, 0, 3); \ + __builtin_prefetch(a + 4, 0, 3); \ + b += 4 * 1; \ + a += 4 * 1; \ + } while (0) + +#define SAVE_KERNEL_4X4_F16_F32() \ + do { \ + vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(v24))); \ + vst1q_f32(c + ldc, vaddq_f32(vld1q_f32(c + ldc), vcvt_f32_f16(v25))); \ + vst1q_f32(c + 2 * ldc, \ + vaddq_f32(vld1q_f32(c + 2 * ldc), vcvt_f32_f16(v26))); \ + vst1q_f32(c + 3 * ldc, \ + vaddq_f32(vld1q_f32(c + 3 * ldc), vcvt_f32_f16(v27))); \ + } while (0) + +/** + * @brief hgemm 4x4 kernel sc = sa * sb + * + * @param m length of the row of matrix A + * @param n length of the col of matrix B + * @param k length of the col of matrix A + * @param sa sub-matrix of input matrix A + * @param sb sub-matrix of input matrix B + * @param sc sub-matrix of output matrix C + * @param ldc leading dimension of matrix C + */ +void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) { + assert(M > 0 && N > 0 && K > 0); + assert(M % 4 == 0 && N % 4 == 0 && K % 4 == 0); + + __fp16 *a = sa, *b = sb, *c = sc; + unsigned int i, j, l; + for (i = 0; i < M; i += 4) { + for (j = 0; j < N; j += 4) { + __builtin_prefetch(b, 0, 3); + __builtin_prefetch(a, 0, 3); + + float16x4_t v24; + float16x4_t v25; + float16x4_t v26; + float16x4_t v27; + INIT_KERNEL_4x4(); + + for (l = 0; l < K; l += 4) { + float16x4_t v0 = vld1_f16(b); + float16x4_t v16 = vld1_f16(a); + + v24 = vfma_lane_f16(v24, v0, v16, 0); + v25 = vfma_lane_f16(v25, v0, v16, 1); + v26 = vfma_lane_f16(v26, v0, v16, 2); + v27 = vfma_lane_f16(v27, v0, v16, 3); + + float16x4_t v1 = vld1_f16(b + 4); + float16x4_t v17 = vld1_f16(a + 4); + + v24 = vfma_lane_f16(v24, v1, v17, 0); + v25 = vfma_lane_f16(v25, v1, v17, 1); + v26 = vfma_lane_f16(v26, v1, v17, 2); + v27 = vfma_lane_f16(v27, v1, v17, 3); + + float16x4_t v2 = vld1_f16(b + 8); + float16x4_t v18 = vld1_f16(a + 8); + + v24 = vfma_lane_f16(v24, v2, v18, 0); + v25 = vfma_lane_f16(v25, v2, v18, 1); + v26 = vfma_lane_f16(v26, v2, v18, 2); + v27 = vfma_lane_f16(v27, v2, v18, 3); + + float16x4_t v3 = vld1_f16(b + 12); + float16x4_t v19 = vld1_f16(a + 12); + + v24 = vfma_lane_f16(v24, v3, v19, 0); + v25 = vfma_lane_f16(v25, v3, v19, 1); + v26 = vfma_lane_f16(v26, v3, v19, 2); + v27 = vfma_lane_f16(v27, v3, v19, 3); + + __builtin_prefetch(b + 16, 0, 3); + __builtin_prefetch(a + 16, 0, 3); + + b += 16; + a += 16; + } + + v24 = vadd_f16(vld1_f16(c), v24); + v25 = vadd_f16(vld1_f16(c + ldc), v25); + v26 = vadd_f16(vld1_f16(c + 2 * ldc), v26); + v27 = vadd_f16(vld1_f16(c + 3 * ldc), v27); + + vst1_f16(c, v24); + vst1_f16(c + ldc, v25); + vst1_f16(c + 2 * ldc, v26); + vst1_f16(c + 3 * ldc, v27); + + c += 4; + a -= 4 * K; + } + sc += ldc * 4; + c = sc; + a += 4 * K; + b = sb; + } +} + +/** + * @brief hgemm 4x4 kernel sc = sa * sb + * + * @param m length of the row of matrix A + * @param n length of the col of matrix B + * @param k length of the col of matrix A + * @param sa sub-matrix of input matrix A + * @param sb sub-matrix of input matrix B + * @param sc sub-matrix of output matrix C + * @param ldc leading dimension of matrix C + */ +void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, float *sc, 
unsigned int ldc) { + assert(M > 0 && N > 0 && K > 0); + assert(M % 4 == 0 && N % 4 == 0 && K % 4 == 0); + + __fp16 *a = sa, *b = sb; + float *c = sc; + unsigned int i, j, l; + unsigned int K16 = (K >> 4) << 4; + unsigned int K8 = (K >> 3) << 3; + for (i = 0; i < M; i += 4) { + for (j = 0; j < N; j += 4) { + __builtin_prefetch(b, 0, 3); + __builtin_prefetch(a, 0, 3); + + float16x4_t v24, v25, v26, v27; + float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7; + float16x4_t vb0, vb1, vb2, vb3, vb4, vb5, vb6, vb7; + l = 0; + for (; l < K16;) { + INIT_KERNEL_4x4(); + KERNEL_4x4_ACC16(); + SAVE_KERNEL_4X4_F16_F32(); + } + for (; l < K8;) { + INIT_KERNEL_4x4(); + KERNEL_4x4_ACC8(); + SAVE_KERNEL_4X4_F16_F32(); + } + for (; l < K;) { + INIT_KERNEL_4x4(); + KERNEL_4x4_ACC1(); + SAVE_KERNEL_4X4_F16_F32(); + } + + c += 4; + a -= 4 * K; + } + sc += ldc * 4; + c = sc; + a += 4 * K; + b = sb; + } +} diff --git a/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_4x8.cpp b/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_4x8.cpp new file mode 100644 index 00000000..3cebee45 --- /dev/null +++ b/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_4x8.cpp @@ -0,0 +1,367 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Sungsik Kong + * + * @file hgemm_kernel_4x8.cpp + * @date 03 April 2024 + * @see https://github.com/nnstreamer/nntrainer + * @author Sungsik Kong + * @bug No known bugs except for NYI items + * @brief This is half-precision GEMM 8x8 kernel + * + */ + +#include +#include +#include +#include + +#define INIT_KERNEL_4X8() \ + do { \ + v0 = vdupq_n_f16(0.F); \ + v3 = vdupq_n_f16(0.F); \ + v6 = vdupq_n_f16(0.F); \ + v9 = vdupq_n_f16(0.F); \ + } while (0) + +// 1. Partial sum 256 digits +#define KERNEL_4x8_ACC16() \ + do { \ + dv0 = vld1_f16(a); \ + v24 = vld1q_f16(b); \ + v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \ + v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \ + v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \ + v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \ + dv1 = vld1_f16(a + 4); \ + v25 = vld1q_f16(b + 8); \ + v0 = vfmaq_lane_f16(v0, v25, dv1, 0); \ + v3 = vfmaq_lane_f16(v3, v25, dv1, 1); \ + v6 = vfmaq_lane_f16(v6, v25, dv1, 2); \ + v9 = vfmaq_lane_f16(v9, v25, dv1, 3); \ + dv2 = vld1_f16(a + 4 * 2); \ + v26 = vld1q_f16(b + 8 * 2); \ + v0 = vfmaq_lane_f16(v0, v26, dv2, 0); \ + v3 = vfmaq_lane_f16(v3, v26, dv2, 1); \ + v6 = vfmaq_lane_f16(v6, v26, dv2, 2); \ + v9 = vfmaq_lane_f16(v9, v26, dv2, 3); \ + dv3 = vld1_f16(a + 4 * 3); \ + v27 = vld1q_f16(b + 8 * 3); \ + v0 = vfmaq_lane_f16(v0, v27, dv3, 0); \ + v3 = vfmaq_lane_f16(v3, v27, dv3, 1); \ + v6 = vfmaq_lane_f16(v6, v27, dv3, 2); \ + v9 = vfmaq_lane_f16(v9, v27, dv3, 3); \ + dv4 = vld1_f16(a + 4 * 4); \ + v28 = vld1q_f16(b + 8 * 4); \ + v0 = vfmaq_lane_f16(v0, v28, dv4, 0); \ + v3 = vfmaq_lane_f16(v3, v28, dv4, 1); \ + v6 = vfmaq_lane_f16(v6, v28, dv4, 2); \ + v9 = vfmaq_lane_f16(v9, v28, dv4, 3); \ + dv5 = vld1_f16(a + 4 * 5); \ + v29 = vld1q_f16(b + 8 * 5); \ + v0 = vfmaq_lane_f16(v0, v29, dv5, 0); \ + v3 = vfmaq_lane_f16(v3, v29, dv5, 1); \ + v6 = vfmaq_lane_f16(v6, v29, dv5, 2); \ + v9 = vfmaq_lane_f16(v9, v29, dv5, 3); \ + dv6 = vld1_f16(a + 4 * 6); \ + v30 = vld1q_f16(b + 8 * 6); \ + v0 = vfmaq_lane_f16(v0, v30, dv6, 0); \ + v3 = vfmaq_lane_f16(v3, v30, dv6, 1); \ + v6 = vfmaq_lane_f16(v6, v30, dv6, 2); \ + v9 = vfmaq_lane_f16(v9, v30, dv6, 3); \ + dv7 = vld1_f16(a + 4 * 7); \ + v31 = vld1q_f16(b + 8 * 7); \ + v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \ + v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \ + v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \ + v9 = 
vfmaq_lane_f16(v9, v31, dv7, 3); \ + dv7 = vld1_f16(a + 4 * 8); \ + v31 = vld1q_f16(b + 8 * 8); \ + v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \ + v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \ + v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \ + v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \ + dv7 = vld1_f16(a + 4 * 9); \ + v31 = vld1q_f16(b + 8 * 9); \ + v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \ + v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \ + v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \ + v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \ + dv7 = vld1_f16(a + 4 * 10); \ + v31 = vld1q_f16(b + 8 * 10); \ + v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \ + v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \ + v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \ + v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \ + dv7 = vld1_f16(a + 4 * 11); \ + v31 = vld1q_f16(b + 8 * 11); \ + v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \ + v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \ + v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \ + v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \ + dv7 = vld1_f16(a + 4 * 12); \ + v31 = vld1q_f16(b + 8 * 12); \ + v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \ + v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \ + v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \ + v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \ + dv7 = vld1_f16(a + 4 * 13); \ + v31 = vld1q_f16(b + 8 * 13); \ + v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \ + v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \ + v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \ + v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \ + dv7 = vld1_f16(a + 4 * 14); \ + v31 = vld1q_f16(b + 8 * 14); \ + v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \ + v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \ + v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \ + v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \ + dv7 = vld1_f16(a + 4 * 15); \ + v31 = vld1q_f16(b + 8 * 15); \ + v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \ + v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \ + v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \ + v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \ + l += 16; \ + __builtin_prefetch(b + 128, 0, 3); \ + __builtin_prefetch(a + 64, 0, 3); \ + b += 8 * 16; \ + a += 4 * 16; \ + } while (0) + +// 1. 
Partial sum 256 digits +#define KERNEL_4x8_ACC8() \ + do { \ + dv0 = vld1_f16(a); \ + v24 = vld1q_f16(b); \ + v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \ + v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \ + v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \ + v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \ + dv1 = vld1_f16(a + 4); \ + v25 = vld1q_f16(b + 8); \ + v0 = vfmaq_lane_f16(v0, v25, dv1, 0); \ + v3 = vfmaq_lane_f16(v3, v25, dv1, 1); \ + v6 = vfmaq_lane_f16(v6, v25, dv1, 2); \ + v9 = vfmaq_lane_f16(v9, v25, dv1, 3); \ + dv2 = vld1_f16(a + 8); \ + v26 = vld1q_f16(b + 16); \ + v0 = vfmaq_lane_f16(v0, v26, dv2, 0); \ + v3 = vfmaq_lane_f16(v3, v26, dv2, 1); \ + v6 = vfmaq_lane_f16(v6, v26, dv2, 2); \ + v9 = vfmaq_lane_f16(v9, v26, dv2, 3); \ + dv3 = vld1_f16(a + 12); \ + v27 = vld1q_f16(b + 24); \ + v0 = vfmaq_lane_f16(v0, v27, dv3, 0); \ + v3 = vfmaq_lane_f16(v3, v27, dv3, 1); \ + v6 = vfmaq_lane_f16(v6, v27, dv3, 2); \ + v9 = vfmaq_lane_f16(v9, v27, dv3, 3); \ + dv4 = vld1_f16(a + 16); \ + v28 = vld1q_f16(b + 32); \ + v0 = vfmaq_lane_f16(v0, v28, dv4, 0); \ + v3 = vfmaq_lane_f16(v3, v28, dv4, 1); \ + v6 = vfmaq_lane_f16(v6, v28, dv4, 2); \ + v9 = vfmaq_lane_f16(v9, v28, dv4, 3); \ + dv5 = vld1_f16(a + 20); \ + v29 = vld1q_f16(b + 40); \ + v0 = vfmaq_lane_f16(v0, v29, dv5, 0); \ + v3 = vfmaq_lane_f16(v3, v29, dv5, 1); \ + v6 = vfmaq_lane_f16(v6, v29, dv5, 2); \ + v9 = vfmaq_lane_f16(v9, v29, dv5, 3); \ + dv6 = vld1_f16(a + 24); \ + v30 = vld1q_f16(b + 48); \ + v0 = vfmaq_lane_f16(v0, v30, dv6, 0); \ + v3 = vfmaq_lane_f16(v3, v30, dv6, 1); \ + v6 = vfmaq_lane_f16(v6, v30, dv6, 2); \ + v9 = vfmaq_lane_f16(v9, v30, dv6, 3); \ + dv7 = vld1_f16(a + 28); \ + v31 = vld1q_f16(b + 56); \ + v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \ + v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \ + v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \ + v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \ + l += 8; \ + __builtin_prefetch(b + 64, 0, 3); \ + __builtin_prefetch(a + 32, 0, 3); \ + b += 8 * 8; \ + a += 4 * 8; \ + } while (0) + +// 2. Partial sum 128 digits +#define KERNEL_4x8_ACC4() \ + do { \ + dv0 = vld1_f16(a); \ + v24 = vld1q_f16(b); \ + v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \ + v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \ + v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \ + v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \ + dv1 = vld1_f16(a + 4); \ + v25 = vld1q_f16(b + 8); \ + v0 = vfmaq_lane_f16(v0, v25, dv1, 0); \ + v3 = vfmaq_lane_f16(v3, v25, dv1, 1); \ + v6 = vfmaq_lane_f16(v6, v25, dv1, 2); \ + v9 = vfmaq_lane_f16(v9, v25, dv1, 3); \ + dv2 = vld1_f16(a + 8); \ + v26 = vld1q_f16(b + 16); \ + v0 = vfmaq_lane_f16(v0, v26, dv2, 0); \ + v3 = vfmaq_lane_f16(v3, v26, dv2, 1); \ + v6 = vfmaq_lane_f16(v6, v26, dv2, 2); \ + v9 = vfmaq_lane_f16(v9, v26, dv2, 3); \ + dv3 = vld1_f16(a + 12); \ + v27 = vld1q_f16(b + 24); \ + v0 = vfmaq_lane_f16(v0, v27, dv3, 0); \ + v3 = vfmaq_lane_f16(v3, v27, dv3, 1); \ + v6 = vfmaq_lane_f16(v6, v27, dv3, 2); \ + v9 = vfmaq_lane_f16(v9, v27, dv3, 3); \ + l += 4; \ + __builtin_prefetch(b + 32, 0, 3); \ + __builtin_prefetch(a + 16, 0, 3); \ + b += 8 * 4; \ + a += 4 * 4; \ + } while (0) + +// 3. 
Partial sum 32 digits +#define KERNEL_4x8_ACC1() \ + do { \ + dv0 = vld1_f16(a); \ + v24 = vld1q_f16(b); \ + v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \ + v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \ + v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \ + v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \ + l += 1; \ + __builtin_prefetch(b + 8, 0, 3); \ + __builtin_prefetch(a + 4, 0, 3); \ + b += 8 * 1; \ + a += 4 * 1; \ + } while (0) + +#define SAVE_KERNEL_4X8_F16_F32() \ + do { \ + vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0)))); \ + vst1q_f32(c + ldc, \ + vaddq_f32(vld1q_f32(c + ldc), vcvt_f32_f16(vget_low_f16(v3)))); \ + vst1q_f32(c + 2 * ldc, vaddq_f32(vld1q_f32(c + 2 * ldc), \ + vcvt_f32_f16(vget_low_f16(v6)))); \ + vst1q_f32(c + 3 * ldc, vaddq_f32(vld1q_f32(c + 3 * ldc), \ + vcvt_f32_f16(vget_low_f16(v9)))); \ + \ + vst1q_f32(c + 4, \ + vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0)))); \ + vst1q_f32(c + 4 + ldc, vaddq_f32(vld1q_f32(c + 4 + ldc), \ + vcvt_f32_f16(vget_high_f16(v3)))); \ + vst1q_f32(c + 4 + 2 * ldc, vaddq_f32(vld1q_f32(c + 4 + 2 * ldc), \ + vcvt_f32_f16(vget_high_f16(v6)))); \ + vst1q_f32(c + 4 + 3 * ldc, vaddq_f32(vld1q_f32(c + 4 + 3 * ldc), \ + vcvt_f32_f16(vget_high_f16(v9)))); \ + } while (0) + +/** + * @brief hgemm 4x8 kernel sc = sa * sb + * + * @param m length of the row of matrix A + * @param n length of the col of matrix B + * @param k length of the col of matrix A + * @param sa sub-matrix of input matrix A + * @param sb sub-matrix of input matrix B + * @param sc sub-matrix of output matrix C + * @param ldc leading-dimension of matrix C + */ +void hgemm_kernel_4x8(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) { + assert(M > 0 && N > 0 && K > 0); + assert(M % 4 == 0 && N % 8 == 0); + + __fp16 *a = sa, *b = sb, *c = sc; + unsigned int K8 = (K >> 3) << 3; + unsigned int i, j, l; + for (i = 0; i < M; i += 4) { + for (j = 0; j < N; j += 8) { + __builtin_prefetch(b, 0, 3); + __builtin_prefetch(a, 0, 3); + float16x8_t v0, v3, v6, v9; + float16x8_t v24, v25, v26, v27, v28, v29, v30, v31; + float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7; + INIT_KERNEL_4X8(); + l = 0; + for (; l < K8;) { + KERNEL_4x8_ACC8(); + } + for (; l < K;) { + KERNEL_4x8_ACC1(); + } + vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0)); + vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v3)); + vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v6)); + vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v9)); + c += 8; + a -= 4 * K; + } + sc += ldc * 4; + c = sc; + a += 4 * K; + b = sb; + } +} + +/** + * @brief hgemm 4x8 kernel sc = sa * sb + * + * @param m length of the row of matrix A + * @param n length of the col of matrix B + * @param k length of the col of matrix A + * @param sa sub-matrix of input matrix A + * @param sb sub-matrix of input matrix B + * @param sc sub-matrix of output matrix C + * @param ldc leading-dimension of matrix C + */ +void hgemm_kernel_4x8(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) { + assert(M > 0 && N > 0 && K > 0); + assert(M % 4 == 0 && N % 8 == 0); + + __fp16 *a = sa, *b = sb; + float *c = sc; + unsigned int K16 = (K >> 4) << 4; + unsigned int K8 = (K >> 3) << 3; + unsigned int K4 = (K >> 2) << 2; + unsigned int i, j, l; + for (i = 0; i < M; i += 4) { + for (j = 0; j < N; j += 8) { + __builtin_prefetch(b, 0, 3); + __builtin_prefetch(a, 0, 3); + float16x8_t v0, v3, v6, v9; + float16x8_t v24, v25, v26, v27, v28, v29, v30, v31; + 
float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7; + l = 0; + for (; l < K16;) { + INIT_KERNEL_4X8(); + KERNEL_4x8_ACC16(); + SAVE_KERNEL_4X8_F16_F32(); + } + for (; l < K8;) { + INIT_KERNEL_4X8(); + KERNEL_4x8_ACC8(); + SAVE_KERNEL_4X8_F16_F32(); + } + for (; l < K4;) { + INIT_KERNEL_4X8(); + KERNEL_4x8_ACC4(); + SAVE_KERNEL_4X8_F16_F32(); + } + for (; l < K;) { + INIT_KERNEL_4X8(); + KERNEL_4x8_ACC1(); + SAVE_KERNEL_4X8_F16_F32(); + } + c += 8; + a -= 4 * K; + } + sc += ldc * 4; + c = sc; + a += 4 * K; + b = sb; + } +} diff --git a/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_8x16.cpp b/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_8x16.cpp new file mode 100644 index 00000000..f8d6b56c --- /dev/null +++ b/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_8x16.cpp @@ -0,0 +1,863 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Sungsik Kong + * + * @file hgemm_kernel_8x16.cpp + * @date 04 April 2024 + * @see https://github.com/nnstreamer/nntrainer + * @author Sungsik Kong + * @bug No known bugs except for NYI items + * @brief This is half-precision GEMM 8x16 kernel + * + */ + +#include +#include +#include +#include +#include + +#define INIT_KERNEL_8X16() \ + do { \ + v0_7 = vdupq_n_f16(0.F); \ + v8_15 = vdupq_n_f16(0.F); \ + v16_23 = vdupq_n_f16(0.F); \ + v24_31 = vdupq_n_f16(0.F); \ + v32_39 = vdupq_n_f16(0.F); \ + v40_47 = vdupq_n_f16(0.F); \ + v48_55 = vdupq_n_f16(0.F); \ + v56_63 = vdupq_n_f16(0.F); \ + v64_71 = vdupq_n_f16(0.F); \ + v72_79 = vdupq_n_f16(0.F); \ + v80_87 = vdupq_n_f16(0.F); \ + v88_95 = vdupq_n_f16(0.F); \ + v96_103 = vdupq_n_f16(0.F); \ + v104_111 = vdupq_n_f16(0.F); \ + v112_119 = vdupq_n_f16(0.F); \ + v120_127 = vdupq_n_f16(0.F); \ + } while (0) + +// 1. Partial sum 2048 digits +#define KERNEL_8x16_ACC16() \ + do { \ + va0 = vld1q_f16(a + 8 * 0); \ + vb1 = vld1q_f16(b + 8 * 0); \ + vb2 = vld1q_f16(b + 8 * 1); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 8 * 1); \ + vb1 = vld1q_f16(b + 8 * 2); \ + vb2 = vld1q_f16(b + 8 * 3); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = 
vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 8 * 2); \ + vb1 = vld1q_f16(b + 8 * 4); \ + vb2 = vld1q_f16(b + 8 * 5); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 8 * 3); \ + vb1 = vld1q_f16(b + 8 * 6); \ + vb2 = vld1q_f16(b + 8 * 7); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 8 * 4); \ + vb1 = vld1q_f16(b + 8 * 8); \ + vb2 = vld1q_f16(b + 8 * 9); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 8 * 5); \ + vb1 = vld1q_f16(b + 8 * 10); \ + vb2 = vld1q_f16(b + 8 * 11); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 
3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 8 * 6); \ + vb1 = vld1q_f16(b + 8 * 12); \ + vb2 = vld1q_f16(b + 8 * 13); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 8 * 7); \ + vb1 = vld1q_f16(b + 8 * 14); \ + vb2 = vld1q_f16(b + 8 * 15); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 8 * 8); \ + vb1 = vld1q_f16(b + 8 * 16); \ + vb2 = vld1q_f16(b + 8 * 17); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 8 * 9); \ + vb1 = vld1q_f16(b + 8 * 18); \ + vb2 = vld1q_f16(b + 8 * 19); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = 
vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 8 * 10); \ + vb1 = vld1q_f16(b + 8 * 20); \ + vb2 = vld1q_f16(b + 8 * 21); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 8 * 11); \ + vb1 = vld1q_f16(b + 8 * 22); \ + vb2 = vld1q_f16(b + 8 * 23); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 8 * 12); \ + vb1 = vld1q_f16(b + 8 * 24); \ + vb2 = vld1q_f16(b + 8 * 25); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 8 * 13); \ + vb1 = vld1q_f16(b + 8 * 26); \ + vb2 = vld1q_f16(b + 8 * 27); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = 
vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 8 * 14); \ + vb1 = vld1q_f16(b + 8 * 28); \ + vb2 = vld1q_f16(b + 8 * 29); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 8 * 15); \ + vb1 = vld1q_f16(b + 8 * 30); \ + vb2 = vld1q_f16(b + 8 * 31); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + __builtin_prefetch(b + 256, 0, 3); \ + __builtin_prefetch(a + 128, 0, 3); \ + l += 16; \ + b += 16 * 16; \ + a += 8 * 16; \ + } while (0) + +// 2. 
Partial sum 1024 digits +#define KERNEL_8x16_ACC8() \ + do { \ + va0 = vld1q_f16(a); \ + vb1 = vld1q_f16(b); \ + vb2 = vld1q_f16(b + 8); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 8); \ + vb1 = vld1q_f16(b + 16); \ + vb2 = vld1q_f16(b + 24); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 16); \ + vb1 = vld1q_f16(b + 32); \ + vb2 = vld1q_f16(b + 40); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 24); \ + vb1 = vld1q_f16(b + 48); \ + vb2 = vld1q_f16(b + 56); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = 
vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 32); \ + vb1 = vld1q_f16(b + 64); \ + vb2 = vld1q_f16(b + 72); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 40); \ + vb1 = vld1q_f16(b + 80); \ + vb2 = vld1q_f16(b + 88); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 48); \ + vb1 = vld1q_f16(b + 96); \ + vb2 = vld1q_f16(b + 104); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 56); \ + vb1 = vld1q_f16(b + 112); \ + vb2 = vld1q_f16(b + 120); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = 
vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + l += 8; \ + __builtin_prefetch(b + 128, 0, 3); \ + __builtin_prefetch(a + 64, 0, 3); \ + b += 16 * 8; \ + a += 8 * 8; \ + } while (0) + +// 3. Partial sum 512 digits +#define KERNEL_8x16_ACC4() \ + do { \ + va0 = vld1q_f16(a); \ + vb1 = vld1q_f16(b); \ + vb2 = vld1q_f16(b + 8); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 8); \ + vb1 = vld1q_f16(b + 16); \ + vb2 = vld1q_f16(b + 24); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 16); \ + vb1 = vld1q_f16(b + 32); \ + vb2 = vld1q_f16(b + 40); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + va0 = vld1q_f16(a + 24); \ + vb1 = vld1q_f16(b + 48); \ + vb2 = vld1q_f16(b + 56); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = 
vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + l += 4; \ + __builtin_prefetch(b + 64, 0, 3); \ + __builtin_prefetch(a + 32, 0, 3); \ + b += 16 * 4; \ + a += 8 * 4; \ + } while (0) + +// 4. Partial sum 128 digits +#define KERNEL_8x16_ACC1() \ + do { \ + va0 = vld1q_f16(a); \ + vb1 = vld1q_f16(b); \ + vb2 = vld1q_f16(b + 8); \ + v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ + v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ + v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ + v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ + v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ + v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ + v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ + v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ + v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ + v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ + v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ + v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ + v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ + v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ + v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ + v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ + l += 1; \ + __builtin_prefetch(b + 16, 0, 3); \ + __builtin_prefetch(a + 8, 0, 3); \ + b += 16 * 1; \ + a += 8 * 1; \ + } while (0) + +#define SAVE_KERNEL_8X16_F16_F32() \ + do { \ + vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0_7)))); \ + vst1q_f32(c + 4, \ + vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0_7)))); \ + \ + vst1q_f32( \ + c + 8, vaddq_f32(vld1q_f32(c + 8), vcvt_f32_f16(vget_low_f16(v64_71)))); \ + vst1q_f32(c + 8 + 4, vaddq_f32(vld1q_f32(c + 8 + 4), \ + vcvt_f32_f16(vget_high_f16(v64_71)))); \ + \ + vst1q_f32(c + ldc, vaddq_f32(vld1q_f32(c + ldc), \ + vcvt_f32_f16(vget_low_f16(v8_15)))); \ + vst1q_f32(c + ldc + 4, vaddq_f32(vld1q_f32(c + ldc + 4), \ + vcvt_f32_f16(vget_high_f16(v8_15)))); \ + \ + vst1q_f32(c + ldc + 8, vaddq_f32(vld1q_f32(c + ldc + 8), \ + vcvt_f32_f16(vget_low_f16(v72_79)))); \ + vst1q_f32(c + ldc + 8 + 4, \ + vaddq_f32(vld1q_f32(c + ldc + 8 + 4), \ + vcvt_f32_f16(vget_high_f16(v72_79)))); \ + \ + vst1q_f32(c + 2 * ldc, vaddq_f32(vld1q_f32(c + 2 * ldc), \ + vcvt_f32_f16(vget_low_f16(v16_23)))); \ + vst1q_f32(c + 2 * ldc + 4, \ + vaddq_f32(vld1q_f32(c + 2 * ldc + 4), \ + vcvt_f32_f16(vget_high_f16(v16_23)))); \ + \ + vst1q_f32(c + 2 * ldc + 8, vaddq_f32(vld1q_f32(c + 2 * ldc + 8), \ + vcvt_f32_f16(vget_low_f16(v80_87)))); \ + vst1q_f32(c + 2 * ldc + 8 + 4, \ + vaddq_f32(vld1q_f32(c + 2 * ldc + 8 + 4), \ + vcvt_f32_f16(vget_high_f16(v80_87)))); \ + \ + vst1q_f32(c + 3 * ldc, vaddq_f32(vld1q_f32(c + 3 * ldc), \ + vcvt_f32_f16(vget_low_f16(v24_31)))); \ + vst1q_f32(c + 3 * ldc + 4, \ + vaddq_f32(vld1q_f32(c + 3 * ldc + 4), \ + vcvt_f32_f16(vget_high_f16(v24_31)))); \ + \ + vst1q_f32(c + 3 * ldc + 8, vaddq_f32(vld1q_f32(c + 3 * ldc + 8), \ + vcvt_f32_f16(vget_low_f16(v88_95)))); \ + vst1q_f32(c + 3 * ldc + 8 + 4, \ + vaddq_f32(vld1q_f32(c + 3 * ldc + 8 + 4), \ + 
vcvt_f32_f16(vget_high_f16(v88_95)))); \ + \ + vst1q_f32(c + 4 * ldc, vaddq_f32(vld1q_f32(c + 4 * ldc), \ + vcvt_f32_f16(vget_low_f16(v32_39)))); \ + vst1q_f32(c + 4 * ldc + 4, \ + vaddq_f32(vld1q_f32(c + 4 * ldc + 4), \ + vcvt_f32_f16(vget_high_f16(v32_39)))); \ + \ + vst1q_f32(c + 4 * ldc + 8, \ + vaddq_f32(vld1q_f32(c + 4 * ldc + 8), \ + vcvt_f32_f16(vget_low_f16(v96_103)))); \ + vst1q_f32(c + 4 * ldc + 8 + 4, \ + vaddq_f32(vld1q_f32(c + 4 * ldc + 8 + 4), \ + vcvt_f32_f16(vget_high_f16(v96_103)))); \ + \ + vst1q_f32(c + 5 * ldc, vaddq_f32(vld1q_f32(c + 5 * ldc), \ + vcvt_f32_f16(vget_low_f16(v40_47)))); \ + vst1q_f32(c + 5 * ldc + 4, \ + vaddq_f32(vld1q_f32(c + 5 * ldc + 4), \ + vcvt_f32_f16(vget_high_f16(v40_47)))); \ + vst1q_f32(c + 5 * ldc + 8, \ + vaddq_f32(vld1q_f32(c + 5 * ldc + 8), \ + vcvt_f32_f16(vget_low_f16(v104_111)))); \ + vst1q_f32(c + 5 * ldc + 8 + 4, \ + vaddq_f32(vld1q_f32(c + 5 * ldc + 8 + 4), \ + vcvt_f32_f16(vget_high_f16(v104_111)))); \ + \ + vst1q_f32(c + 6 * ldc, vaddq_f32(vld1q_f32(c + 6 * ldc), \ + vcvt_f32_f16(vget_low_f16(v48_55)))); \ + vst1q_f32(c + 6 * ldc + 4, \ + vaddq_f32(vld1q_f32(c + 6 * ldc + 4), \ + vcvt_f32_f16(vget_high_f16(v48_55)))); \ + \ + vst1q_f32(c + 6 * ldc + 8, \ + vaddq_f32(vld1q_f32(c + 6 * ldc + 8), \ + vcvt_f32_f16(vget_low_f16(v112_119)))); \ + vst1q_f32(c + 6 * ldc + 8 + 4, \ + vaddq_f32(vld1q_f32(c + 6 * ldc + 8 + 4), \ + vcvt_f32_f16(vget_high_f16(v112_119)))); \ + \ + vst1q_f32(c + 7 * ldc, vaddq_f32(vld1q_f32(c + 7 * ldc), \ + vcvt_f32_f16(vget_low_f16(v56_63)))); \ + vst1q_f32(c + 7 * ldc + 4, \ + vaddq_f32(vld1q_f32(c + 7 * ldc + 4), \ + vcvt_f32_f16(vget_high_f16(v56_63)))); \ + \ + vst1q_f32(c + 7 * ldc + 8, \ + vaddq_f32(vld1q_f32(c + 7 * ldc + 8), \ + vcvt_f32_f16(vget_low_f16(v120_127)))); \ + vst1q_f32(c + 7 * ldc + 8 + 4, \ + vaddq_f32(vld1q_f32(c + 7 * ldc + 8 + 4), \ + vcvt_f32_f16(vget_high_f16(v120_127)))); \ + } while (0) + +/** + * @brief hgemm 8x16 kernel sc = sa * sb + * + * @param M length of the row of matrix A + * @param N length of the col of matrix B + * @param K length of the col of matrix A + * @param sa sub-matrix of input matrix A + * @param sb sub-matrix of input matrix B + * @param sc sub-matrix of output matrix C + * @param ldc leading-dimension of matrix C + */ +void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) { + assert(M > 0 && N > 0 && K > 0); + assert(M % 8 == 0 && N % 16 == 0 && K % 8 == 0); + + __fp16 *a = sa, *b = sb, *c = sc; + unsigned int i, j, l; + for (i = 0; i < M; i += 8) { + for (j = 0; j < N; j += 16) { + __builtin_prefetch(b, 0, 3); + __builtin_prefetch(a, 0, 3); + // 8x16 + float16x8_t v0_7, v8_15; + float16x8_t v16_23, v24_31; + float16x8_t v32_39, v40_47; + float16x8_t v48_55, v56_63; + float16x8_t v64_71, v72_79; + float16x8_t v80_87, v88_95; + float16x8_t v96_103, v104_111; + float16x8_t v112_119, v120_127; + float16x8_t vb1, vb2; + float16x8_t va0; + + INIT_KERNEL_8X16(); + l = 0; + for (; l < K;) { + KERNEL_8x16_ACC1(); + } + vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0_7)); + vst1q_f16(c + 8, vaddq_f16(vld1q_f16(c + 8), v64_71)); + vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v8_15)); + vst1q_f16(c + ldc + 8, vaddq_f16(vld1q_f16(c + ldc + 8), v72_79)); + vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v16_23)); + vst1q_f16(c + 2 * ldc + 8, vaddq_f16(vld1q_f16(c + 2 * ldc + 8), v80_87)); + vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v24_31)); + vst1q_f16(c + 3 * ldc + 8, 
vaddq_f16(vld1q_f16(c + 3 * ldc + 8), v88_95)); + vst1q_f16(c + 4 * ldc, vaddq_f16(vld1q_f16(c + 4 * ldc), v32_39)); + vst1q_f16(c + 4 * ldc + 8, + vaddq_f16(vld1q_f16(c + 4 * ldc + 8), v96_103)); + vst1q_f16(c + 5 * ldc, vaddq_f16(vld1q_f16(c + 5 * ldc), v40_47)); + vst1q_f16(c + 5 * ldc + 8, + vaddq_f16(vld1q_f16(c + 5 * ldc + 8), v104_111)); + vst1q_f16(c + 6 * ldc, vaddq_f16(vld1q_f16(c + 6 * ldc), v48_55)); + vst1q_f16(c + 6 * ldc + 8, + vaddq_f16(vld1q_f16(c + 6 * ldc + 8), v112_119)); + vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v56_63)); + vst1q_f16(c + 7 * ldc + 8, + vaddq_f16(vld1q_f16(c + 7 * ldc + 8), v120_127)); + c += 16; + a -= 8 * K; + } + sc += ldc * 8; + c = sc; + a += 8 * K; + b = sb; + } +} + +/** + * @brief hgemm 8x16 kernel sc = sa * sb + * + * @param M length of the row of matrix A + * @param N length of the col of matrix B + * @param K length of the col of matrix A + * @param sa sub-matrix of input matrix A + * @param sb sub-matrix of input matrix B + * @param sc sub-matrix of output matrix C + * @param ldc leading-dimension of matrix C + */ +void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) { + assert(M > 0 && N > 0 && K > 0); + assert(M % 8 == 0 && N % 16 == 0 && K % 4 == 0); + + __fp16 *a = sa, *b = sb; + float *c = sc; + unsigned int i, j, l; + unsigned int K4 = (K >> 2) << 2; + unsigned int K8 = (K >> 3) << 3; + unsigned int K16 = (K >> 4) << 4; + for (i = 0; i < M; i += 8) { + for (j = 0; j < N; j += 16) { + __builtin_prefetch(b, 0, 3); + __builtin_prefetch(a, 0, 3); + float16x8_t v0_7, v8_15; + float16x8_t v16_23, v24_31; + float16x8_t v32_39, v40_47; + float16x8_t v48_55, v56_63; + float16x8_t v64_71, v72_79; + float16x8_t v80_87, v88_95; + float16x8_t v96_103, v104_111; + float16x8_t v112_119, v120_127; + float16x8_t vb1, vb2; + float16x8_t va0; + l = 0; + for (; l < K16;) { + INIT_KERNEL_8X16(); + KERNEL_8x16_ACC16(); + SAVE_KERNEL_8X16_F16_F32(); + } + for (; l < K8;) { + INIT_KERNEL_8X16(); + KERNEL_8x16_ACC8(); + SAVE_KERNEL_8X16_F16_F32(); + } + for (; l < K4;) { + INIT_KERNEL_8X16(); + KERNEL_8x16_ACC4(); + SAVE_KERNEL_8X16_F16_F32(); + } + for (; l < K;) { + INIT_KERNEL_8X16(); + KERNEL_8x16_ACC1(); + SAVE_KERNEL_8X16_F16_F32(); + } + c += 16; + a -= 8 * K; + } + sc += ldc * 8; + c = sc; + a += 8 * K; + b = sb; + } +} diff --git a/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_8x8.cpp b/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_8x8.cpp new file mode 100644 index 00000000..f799e527 --- /dev/null +++ b/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_8x8.cpp @@ -0,0 +1,512 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Sungsik Kong + * + * @file hgemm_kernel_8x8.cpp + * @date 01 April 2024 + * @see https://github.com/nnstreamer/nntrainer + * @author Sungsik Kong + * @bug No known bugs except for NYI items + * @brief This is half-precision GEMM 8x8 kernel + * + */ + +#include +#include +#include +#include + +#define INIT_KERNEL_8x8() \ + do { \ + v24 = vdupq_n_f16(0.F); \ + v25 = vdupq_n_f16(0.F); \ + v26 = vdupq_n_f16(0.F); \ + v27 = vdupq_n_f16(0.F); \ + v28 = vdupq_n_f16(0.F); \ + v29 = vdupq_n_f16(0.F); \ + v30 = vdupq_n_f16(0.F); \ + v31 = vdupq_n_f16(0.F); \ + } while (0) + +// 1. 
Partial sum 1024 digits +#define KERNEL_8x8_ACC16() \ + do { \ + va0 = vld1q_f16(a); \ + v16 = vld1q_f16(b); \ + v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \ + v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \ + v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \ + v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \ + v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \ + v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \ + v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \ + v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \ + va0 = vld1q_f16(a + 8); \ + v17 = vld1q_f16(b + 8); \ + v24 = vfmaq_laneq_f16(v24, v17, va0, 0); \ + v25 = vfmaq_laneq_f16(v25, v17, va0, 1); \ + v26 = vfmaq_laneq_f16(v26, v17, va0, 2); \ + v27 = vfmaq_laneq_f16(v27, v17, va0, 3); \ + v28 = vfmaq_laneq_f16(v28, v17, va0, 4); \ + v29 = vfmaq_laneq_f16(v29, v17, va0, 5); \ + v30 = vfmaq_laneq_f16(v30, v17, va0, 6); \ + v31 = vfmaq_laneq_f16(v31, v17, va0, 7); \ + va0 = vld1q_f16(a + 8 * 2); \ + v18 = vld1q_f16(b + 8 * 2); \ + v24 = vfmaq_laneq_f16(v24, v18, va0, 0); \ + v25 = vfmaq_laneq_f16(v25, v18, va0, 1); \ + v26 = vfmaq_laneq_f16(v26, v18, va0, 2); \ + v27 = vfmaq_laneq_f16(v27, v18, va0, 3); \ + v28 = vfmaq_laneq_f16(v28, v18, va0, 4); \ + v29 = vfmaq_laneq_f16(v29, v18, va0, 5); \ + v30 = vfmaq_laneq_f16(v30, v18, va0, 6); \ + v31 = vfmaq_laneq_f16(v31, v18, va0, 7); \ + va0 = vld1q_f16(a + 8 * 3); \ + v19 = vld1q_f16(b + 8 * 3); \ + v24 = vfmaq_laneq_f16(v24, v19, va0, 0); \ + v25 = vfmaq_laneq_f16(v25, v19, va0, 1); \ + v26 = vfmaq_laneq_f16(v26, v19, va0, 2); \ + v27 = vfmaq_laneq_f16(v27, v19, va0, 3); \ + v28 = vfmaq_laneq_f16(v28, v19, va0, 4); \ + v29 = vfmaq_laneq_f16(v29, v19, va0, 5); \ + v30 = vfmaq_laneq_f16(v30, v19, va0, 6); \ + v31 = vfmaq_laneq_f16(v31, v19, va0, 7); \ + va0 = vld1q_f16(a + 8 * 4); \ + v20 = vld1q_f16(b + 8 * 4); \ + v24 = vfmaq_laneq_f16(v24, v20, va0, 0); \ + v25 = vfmaq_laneq_f16(v25, v20, va0, 1); \ + v26 = vfmaq_laneq_f16(v26, v20, va0, 2); \ + v27 = vfmaq_laneq_f16(v27, v20, va0, 3); \ + v28 = vfmaq_laneq_f16(v28, v20, va0, 4); \ + v29 = vfmaq_laneq_f16(v29, v20, va0, 5); \ + v30 = vfmaq_laneq_f16(v30, v20, va0, 6); \ + v31 = vfmaq_laneq_f16(v31, v20, va0, 7); \ + va0 = vld1q_f16(a + 8 * 5); \ + v21 = vld1q_f16(b + 8 * 5); \ + v24 = vfmaq_laneq_f16(v24, v21, va0, 0); \ + v25 = vfmaq_laneq_f16(v25, v21, va0, 1); \ + v26 = vfmaq_laneq_f16(v26, v21, va0, 2); \ + v27 = vfmaq_laneq_f16(v27, v21, va0, 3); \ + v28 = vfmaq_laneq_f16(v28, v21, va0, 4); \ + v29 = vfmaq_laneq_f16(v29, v21, va0, 5); \ + v30 = vfmaq_laneq_f16(v30, v21, va0, 6); \ + v31 = vfmaq_laneq_f16(v31, v21, va0, 7); \ + va0 = vld1q_f16(a + 8 * 6); \ + v22 = vld1q_f16(b + 8 * 6); \ + v24 = vfmaq_laneq_f16(v24, v22, va0, 0); \ + v25 = vfmaq_laneq_f16(v25, v22, va0, 1); \ + v26 = vfmaq_laneq_f16(v26, v22, va0, 2); \ + v27 = vfmaq_laneq_f16(v27, v22, va0, 3); \ + v28 = vfmaq_laneq_f16(v28, v22, va0, 4); \ + v29 = vfmaq_laneq_f16(v29, v22, va0, 5); \ + v30 = vfmaq_laneq_f16(v30, v22, va0, 6); \ + v31 = vfmaq_laneq_f16(v31, v22, va0, 7); \ + va0 = vld1q_f16(a + 8 * 7); \ + v23 = vld1q_f16(b + 8 * 7); \ + v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \ + v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \ + v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \ + v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \ + v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \ + v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \ + v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \ + v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \ + va0 = vld1q_f16(a + 8 * 8); \ + v23 = vld1q_f16(b + 8 * 8); \ + v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \ + v25 = 
vfmaq_laneq_f16(v25, v23, va0, 1); \ + v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \ + v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \ + v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \ + v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \ + v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \ + v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \ + va0 = vld1q_f16(a + 8 * 9); \ + v23 = vld1q_f16(b + 8 * 9); \ + v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \ + v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \ + v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \ + v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \ + v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \ + v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \ + v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \ + v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \ + va0 = vld1q_f16(a + 8 * 10); \ + v23 = vld1q_f16(b + 8 * 10); \ + v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \ + v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \ + v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \ + v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \ + v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \ + v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \ + v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \ + v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \ + va0 = vld1q_f16(a + 8 * 11); \ + v23 = vld1q_f16(b + 8 * 11); \ + v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \ + v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \ + v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \ + v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \ + v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \ + v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \ + v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \ + v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \ + va0 = vld1q_f16(a + 8 * 12); \ + v23 = vld1q_f16(b + 8 * 12); \ + v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \ + v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \ + v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \ + v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \ + v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \ + v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \ + v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \ + v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \ + va0 = vld1q_f16(a + 8 * 13); \ + v23 = vld1q_f16(b + 8 * 13); \ + v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \ + v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \ + v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \ + v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \ + v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \ + v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \ + v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \ + v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \ + va0 = vld1q_f16(a + 8 * 14); \ + v23 = vld1q_f16(b + 8 * 14); \ + v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \ + v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \ + v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \ + v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \ + v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \ + v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \ + v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \ + v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \ + va0 = vld1q_f16(a + 8 * 15); \ + v23 = vld1q_f16(b + 8 * 15); \ + v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \ + v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \ + v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \ + v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \ + v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \ + v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \ + v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \ + v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \ + __builtin_prefetch(b + 128, 0, 3); \ + __builtin_prefetch(a + 128, 0, 3); \ + l += 16; \ + b += 8 * 16; \ + a += 8 * 16; \ + } while (0) + +// 2. 
Partial sum 512 digits +#define KERNEL_8x8_ACC8() \ + do { \ + va0 = vld1q_f16(a); \ + v16 = vld1q_f16(b); \ + v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \ + v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \ + v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \ + v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \ + v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \ + v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \ + v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \ + v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \ + va1 = vld1q_f16(a + 8); \ + v17 = vld1q_f16(b + 8); \ + v24 = vfmaq_laneq_f16(v24, v17, va1, 0); \ + v25 = vfmaq_laneq_f16(v25, v17, va1, 1); \ + v26 = vfmaq_laneq_f16(v26, v17, va1, 2); \ + v27 = vfmaq_laneq_f16(v27, v17, va1, 3); \ + v28 = vfmaq_laneq_f16(v28, v17, va1, 4); \ + v29 = vfmaq_laneq_f16(v29, v17, va1, 5); \ + v30 = vfmaq_laneq_f16(v30, v17, va1, 6); \ + v31 = vfmaq_laneq_f16(v31, v17, va1, 7); \ + va2 = vld1q_f16(a + 16); \ + v18 = vld1q_f16(b + 16); \ + v24 = vfmaq_laneq_f16(v24, v18, va2, 0); \ + v25 = vfmaq_laneq_f16(v25, v18, va2, 1); \ + v26 = vfmaq_laneq_f16(v26, v18, va2, 2); \ + v27 = vfmaq_laneq_f16(v27, v18, va2, 3); \ + v28 = vfmaq_laneq_f16(v28, v18, va2, 4); \ + v29 = vfmaq_laneq_f16(v29, v18, va2, 5); \ + v30 = vfmaq_laneq_f16(v30, v18, va2, 6); \ + v31 = vfmaq_laneq_f16(v31, v18, va2, 7); \ + va3 = vld1q_f16(a + 24); \ + v19 = vld1q_f16(b + 24); \ + v24 = vfmaq_laneq_f16(v24, v19, va3, 0); \ + v25 = vfmaq_laneq_f16(v25, v19, va3, 1); \ + v26 = vfmaq_laneq_f16(v26, v19, va3, 2); \ + v27 = vfmaq_laneq_f16(v27, v19, va3, 3); \ + v28 = vfmaq_laneq_f16(v28, v19, va3, 4); \ + v29 = vfmaq_laneq_f16(v29, v19, va3, 5); \ + v30 = vfmaq_laneq_f16(v30, v19, va3, 6); \ + v31 = vfmaq_laneq_f16(v31, v19, va3, 7); \ + va4 = vld1q_f16(a + 32); \ + v20 = vld1q_f16(b + 32); \ + v24 = vfmaq_laneq_f16(v24, v20, va4, 0); \ + v25 = vfmaq_laneq_f16(v25, v20, va4, 1); \ + v26 = vfmaq_laneq_f16(v26, v20, va4, 2); \ + v27 = vfmaq_laneq_f16(v27, v20, va4, 3); \ + v28 = vfmaq_laneq_f16(v28, v20, va4, 4); \ + v29 = vfmaq_laneq_f16(v29, v20, va4, 5); \ + v30 = vfmaq_laneq_f16(v30, v20, va4, 6); \ + v31 = vfmaq_laneq_f16(v31, v20, va4, 7); \ + va5 = vld1q_f16(a + 40); \ + v21 = vld1q_f16(b + 40); \ + v24 = vfmaq_laneq_f16(v24, v21, va5, 0); \ + v25 = vfmaq_laneq_f16(v25, v21, va5, 1); \ + v26 = vfmaq_laneq_f16(v26, v21, va5, 2); \ + v27 = vfmaq_laneq_f16(v27, v21, va5, 3); \ + v28 = vfmaq_laneq_f16(v28, v21, va5, 4); \ + v29 = vfmaq_laneq_f16(v29, v21, va5, 5); \ + v30 = vfmaq_laneq_f16(v30, v21, va5, 6); \ + v31 = vfmaq_laneq_f16(v31, v21, va5, 7); \ + va6 = vld1q_f16(a + 48); \ + v22 = vld1q_f16(b + 48); \ + v24 = vfmaq_laneq_f16(v24, v22, va6, 0); \ + v25 = vfmaq_laneq_f16(v25, v22, va6, 1); \ + v26 = vfmaq_laneq_f16(v26, v22, va6, 2); \ + v27 = vfmaq_laneq_f16(v27, v22, va6, 3); \ + v28 = vfmaq_laneq_f16(v28, v22, va6, 4); \ + v29 = vfmaq_laneq_f16(v29, v22, va6, 5); \ + v30 = vfmaq_laneq_f16(v30, v22, va6, 6); \ + v31 = vfmaq_laneq_f16(v31, v22, va6, 7); \ + va7 = vld1q_f16(a + 56); \ + v23 = vld1q_f16(b + 56); \ + v24 = vfmaq_laneq_f16(v24, v23, va7, 0); \ + v25 = vfmaq_laneq_f16(v25, v23, va7, 1); \ + v26 = vfmaq_laneq_f16(v26, v23, va7, 2); \ + v27 = vfmaq_laneq_f16(v27, v23, va7, 3); \ + v28 = vfmaq_laneq_f16(v28, v23, va7, 4); \ + v29 = vfmaq_laneq_f16(v29, v23, va7, 5); \ + v30 = vfmaq_laneq_f16(v30, v23, va7, 6); \ + v31 = vfmaq_laneq_f16(v31, v23, va7, 7); \ + __builtin_prefetch(b + 64, 0, 3); \ + __builtin_prefetch(a + 64, 0, 3); \ + l += 8; \ + b += 8 * 8; \ + a += 8 * 8; \ + } while (0) + +// 3. 
Partial sum 256 digits +#define KERNEL_8x8_ACC4() \ + do { \ + va0 = vld1q_f16(a); \ + v16 = vld1q_f16(b); \ + v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \ + v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \ + v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \ + v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \ + v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \ + v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \ + v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \ + v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \ + va1 = vld1q_f16(a + 8); \ + v17 = vld1q_f16(b + 8); \ + v24 = vfmaq_laneq_f16(v24, v17, va1, 0); \ + v25 = vfmaq_laneq_f16(v25, v17, va1, 1); \ + v26 = vfmaq_laneq_f16(v26, v17, va1, 2); \ + v27 = vfmaq_laneq_f16(v27, v17, va1, 3); \ + v28 = vfmaq_laneq_f16(v28, v17, va1, 4); \ + v29 = vfmaq_laneq_f16(v29, v17, va1, 5); \ + v30 = vfmaq_laneq_f16(v30, v17, va1, 6); \ + v31 = vfmaq_laneq_f16(v31, v17, va1, 7); \ + va2 = vld1q_f16(a + 16); \ + v18 = vld1q_f16(b + 16); \ + v24 = vfmaq_laneq_f16(v24, v18, va2, 0); \ + v25 = vfmaq_laneq_f16(v25, v18, va2, 1); \ + v26 = vfmaq_laneq_f16(v26, v18, va2, 2); \ + v27 = vfmaq_laneq_f16(v27, v18, va2, 3); \ + v28 = vfmaq_laneq_f16(v28, v18, va2, 4); \ + v29 = vfmaq_laneq_f16(v29, v18, va2, 5); \ + v30 = vfmaq_laneq_f16(v30, v18, va2, 6); \ + v31 = vfmaq_laneq_f16(v31, v18, va2, 7); \ + va3 = vld1q_f16(a + 24); \ + v19 = vld1q_f16(b + 24); \ + v24 = vfmaq_laneq_f16(v24, v19, va3, 0); \ + v25 = vfmaq_laneq_f16(v25, v19, va3, 1); \ + v26 = vfmaq_laneq_f16(v26, v19, va3, 2); \ + v27 = vfmaq_laneq_f16(v27, v19, va3, 3); \ + v28 = vfmaq_laneq_f16(v28, v19, va3, 4); \ + v29 = vfmaq_laneq_f16(v29, v19, va3, 5); \ + v30 = vfmaq_laneq_f16(v30, v19, va3, 6); \ + v31 = vfmaq_laneq_f16(v31, v19, va3, 7); \ + __builtin_prefetch(b + 32, 0, 3); \ + __builtin_prefetch(a + 32, 0, 3); \ + l += 4; \ + b += 8 * 4; \ + a += 8 * 4; \ + } while (0) + +// 4. 
Partial sum 64 digits +#define KERNEL_8x8_ACC1() \ + do { \ + va0 = vld1q_f16(a); \ + v16 = vld1q_f16(b); \ + v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \ + v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \ + v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \ + v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \ + v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \ + v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \ + v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \ + v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \ + __builtin_prefetch(b + 8, 0, 3); \ + __builtin_prefetch(a + 8, 0, 3); \ + l += 1; \ + b += 8 * 1; \ + a += 8 * 1; \ + } while (0) + +#define SAVE_KERNEL_8X8_F16_f32() \ + do { \ + vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v24)))); \ + vst1q_f32(c + 4, \ + vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v24)))); \ + \ + vst1q_f32(c + ldc, \ + vaddq_f32(vld1q_f32(c + ldc), vcvt_f32_f16(vget_low_f16(v25)))); \ + vst1q_f32(c + 4 + ldc, vaddq_f32(vld1q_f32(c + 4 + ldc), \ + vcvt_f32_f16(vget_high_f16(v25)))); \ + \ + vst1q_f32(c + 2 * ldc, vaddq_f32(vld1q_f32(c + 2 * ldc), \ + vcvt_f32_f16(vget_low_f16(v26)))); \ + vst1q_f32(c + 4 + 2 * ldc, vaddq_f32(vld1q_f32(c + 4 + 2 * ldc), \ + vcvt_f32_f16(vget_high_f16(v26)))); \ + \ + vst1q_f32(c + 3 * ldc, vaddq_f32(vld1q_f32(c + 3 * ldc), \ + vcvt_f32_f16(vget_low_f16(v27)))); \ + vst1q_f32(c + 4 + 3 * ldc, vaddq_f32(vld1q_f32(c + 4 + 3 * ldc), \ + vcvt_f32_f16(vget_high_f16(v27)))); \ + \ + vst1q_f32(c + 4 * ldc, vaddq_f32(vld1q_f32(c + 4 * ldc), \ + vcvt_f32_f16(vget_low_f16(v28)))); \ + vst1q_f32(c + 4 + 4 * ldc, vaddq_f32(vld1q_f32(c + 4 + 4 * ldc), \ + vcvt_f32_f16(vget_high_f16(v28)))); \ + \ + vst1q_f32(c + 5 * ldc, vaddq_f32(vld1q_f32(c + 5 * ldc), \ + vcvt_f32_f16(vget_low_f16(v29)))); \ + vst1q_f32(c + 4 + 5 * ldc, vaddq_f32(vld1q_f32(c + 4 + 5 * ldc), \ + vcvt_f32_f16(vget_high_f16(v29)))); \ + \ + vst1q_f32(c + 6 * ldc, vaddq_f32(vld1q_f32(c + 6 * ldc), \ + vcvt_f32_f16(vget_low_f16(v30)))); \ + vst1q_f32(c + 4 + 6 * ldc, vaddq_f32(vld1q_f32(c + 4 + 6 * ldc), \ + vcvt_f32_f16(vget_high_f16(v30)))); \ + \ + vst1q_f32(c + 7 * ldc, vaddq_f32(vld1q_f32(c + 7 * ldc), \ + vcvt_f32_f16(vget_low_f16(v31)))); \ + vst1q_f32(c + 4 + 7 * ldc, vaddq_f32(vld1q_f32(c + 4 + 7 * ldc), \ + vcvt_f32_f16(vget_high_f16(v31)))); \ + } while (0) + +/** + * @brief hgemm 8x8 kernel sc = sa * sb + * + * @param m length of the row of matrix A + * @param n length of the col of matrix B + * @param k length of the col of matrix A + * @param sa sub-matrix of input matrix A + * @param sb sub-matrix of input matrix B + * @param sc sub-matrix of output matrix C + * @param ldc leading-dimension of matrix C + */ +void hgemm_kernel_8x8(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) { + assert(M > 0 && N > 0 && K > 0); + assert(M % 8 == 0 && N % 8 == 0 && K % 4 == 0); + + __fp16 *a = sa, *b = sb, *c = sc; + unsigned int i, j, l; + for (i = 0; i < M; i += 8) { + for (j = 0; j < N; j += 8) { + __builtin_prefetch(b, 0, 3); + __builtin_prefetch(a, 0, 3); + + float16x8_t v16, v17, v18, v19, v20, v21, v22, v23; + float16x8_t v24, v25, v26, v27, v28, v29, v30, v31; + float16x8_t va0, va1, va2, va3, va4, va5, va6, va7; + INIT_KERNEL_8x8(); + l = 0; + for (; l < K;) { + KERNEL_8x8_ACC1(); + } + vst1q_f16(c, vaddq_f16(vld1q_f16(c), v24)); + vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v25)); + vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v26)); + vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v27)); + vst1q_f16(c + 4 
* ldc, vaddq_f16(vld1q_f16(c + 4 * ldc), v28)); + vst1q_f16(c + 5 * ldc, vaddq_f16(vld1q_f16(c + 5 * ldc), v29)); + vst1q_f16(c + 6 * ldc, vaddq_f16(vld1q_f16(c + 6 * ldc), v30)); + vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v31)); + c += 8; + a -= 8 * K; + } + sc += ldc * 8; + c = sc; + a += 8 * K; + b = sb; + } +} + +/** + * @brief hgemm 8x8 kernel sc = sa * sb + * + * @param m length of the row of matrix A + * @param n length of the col of matrix B + * @param k length of the col of matrix A + * @param sa sub-matrix of input matrix A + * @param sb sub-matrix of input matrix B + * @param sc sub-matrix of output matrix C + * @param ldc leading-dimension of matrix C + */ +void hgemm_kernel_8x8(unsigned int M, unsigned int N, unsigned int K, + __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) { + assert(M > 0 && N > 0 && K > 0); + assert(M % 8 == 0 && N % 8 == 0 && K % 8 == 0); + + __fp16 *a = sa, *b = sb; + float *c = sc; + unsigned int i, j, l; + unsigned int K4 = (K >> 2) << 2; + unsigned int K8 = (K >> 3) << 3; + unsigned int K16 = (K >> 4) << 4; + for (i = 0; i < M; i += 8) { + for (j = 0; j < N; j += 8) { + __builtin_prefetch(b, 0, 3); + __builtin_prefetch(a, 0, 3); + + float16x8_t v16, v17, v18, v19, v20, v21, v22, v23; + float16x8_t v24, v25, v26, v27, v28, v29, v30, v31; + float16x8_t va0, va1, va2, va3, va4, va5, va6, va7; + l = 0; + for (; l < K16;) { + INIT_KERNEL_8x8(); + KERNEL_8x8_ACC16(); + SAVE_KERNEL_8X8_F16_f32(); + } + for (; l < K8;) { + INIT_KERNEL_8x8(); + KERNEL_8x8_ACC8(); + SAVE_KERNEL_8X8_F16_f32(); + } + for (; l < K4;) { + INIT_KERNEL_8x8(); + KERNEL_8x8_ACC4(); + SAVE_KERNEL_8X8_F16_f32(); + } + for (; l < K;) { + INIT_KERNEL_8x8(); + KERNEL_8x8_ACC1(); + SAVE_KERNEL_8X8_F16_f32(); + } + + c += 8; + a -= 8 * K; + } + sc += ldc * 8; + c = sc; + a += 8 * K; + b = sb; + } +} diff --git a/nntrainer/tensor/hgemm/hgemm_kernel/meson.build b/nntrainer/tensor/hgemm/hgemm_kernel/meson.build new file mode 100644 index 00000000..1b6cc50f --- /dev/null +++ b/nntrainer/tensor/hgemm/hgemm_kernel/meson.build @@ -0,0 +1,22 @@ +hgemm_kernel_headers = [ + 'hgemm_kernel.h', +] + + +hgemm_kernel_sources = [ + 'hgemm_kernel_1x4.cpp', + 'hgemm_kernel_1x8.cpp', + 'hgemm_kernel_4x4.cpp', + 'hgemm_kernel_4x8.cpp', + 'hgemm_kernel_8x8.cpp', + 'hgemm_kernel_8x16.cpp', +] + +foreach s : hgemm_kernel_sources + nntrainer_sources += meson.current_source_dir() / s +endforeach + +foreach h : hgemm_kernel_headers + nntrainer_headers += meson.current_source_dir() / h +endforeach + diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_1x4.h b/nntrainer/tensor/hgemm/hgemm_kernel_1x4.h deleted file mode 100644 index d0018876..00000000 --- a/nntrainer/tensor/hgemm/hgemm_kernel_1x4.h +++ /dev/null @@ -1,145 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2024 Debadri Samaddar - * - * @file hgemm_kernel_1x4.h - * @date 23 April 2024 - * @see https://github.com/nnstreamer/nntrainer - * @author Debadri Samaddar - * @bug No known bugs except for NYI items - * @brief This is half-precision GEMM 1x4 kernel - * - */ - -#include -#include -#include - -/** - * @brief hgemm 1x4 kernel sc = sa * sb - * - * @param m length of the row of matrix A - * @param n length of the col of matrix B - * @param k length of the col of matrix A - * @param sa sub-matrix of input matrix A - * @param sb sub-matrix of input matrix B - * @param sc sub-matrix of output matrix C - * @param ldc leading dimension of matrix C - */ -void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K, - 
__fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) { - assert(M > 0 && N > 0 && K > 0); - assert(N % 4 == 0); - - __fp16 *a = sa, *b = sb, *c = sc; - unsigned int i, j, l; - for (i = 0; i < M; i++) { - for (j = 0; j < N; j += 4) { - __builtin_prefetch(b, 0, 3); - __builtin_prefetch(a, 0, 3); - - for (l = 0; l < K; l += 4) { - float16x4_t v24 = {0.F}; - float16x4_t v0 = vld1_f16(b); - float16_t v16 = *a; - - v24 = vfma_n_f16(v24, v0, v16); - - float16x4_t v1 = vld1_f16(b + 4); - float16_t v17 = *(a + 1); - - v24 = vfma_n_f16(v24, v1, v17); - - float16x4_t v2 = vld1_f16(b + 8); - float16_t v18 = *(a + 2); - - v24 = vfma_n_f16(v24, v2, v18); - - float16x4_t v3 = vld1_f16(b + 12); - float16_t v19 = *(a + 3); - - v24 = vfma_n_f16(v24, v3, v19); - - __builtin_prefetch(b + 16, 0, 3); - __builtin_prefetch(a + 4, 0, 3); - - b += 16; - a += 4; - - v24 = vadd_f16(vld1_f16(c), v24); - - vst1_f16(c, v24); - } - c += 4; - a -= K; - } - sc += ldc; - c = sc; - a += K; - b = sb; - } -} - -/** - * @brief hgemm 1x4 kernel sc = sa * sb - * - * @param m length of the row of matrix A - * @param n length of the col of matrix B - * @param k length of the col of matrix A - * @param sa sub-matrix of input matrix A - * @param sb sub-matrix of input matrix B - * @param sc sub-matrix of output matrix C - * @param ldc leading dimension of matrix C - */ -void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K, - __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) { - assert(M > 0 && N > 0 && K > 0); - assert(N % 4 == 0); - - __fp16 *a = sa, *b = sb; - float *c = sc; - unsigned int i, j, l; - for (i = 0; i < M; i++) { - for (j = 0; j < N; j += 4) { - __builtin_prefetch(b, 0, 3); - __builtin_prefetch(a, 0, 3); - - for (l = 0; l < K; l += 4) { - float16x4_t v24 = {0.F}; - float16x4_t v0 = vld1_f16(b); - float16_t v16 = *a; - - v24 = vfma_n_f16(v24, v0, v16); - - float16x4_t v1 = vld1_f16(b + 4); - float16_t v17 = *(a + 1); - - v24 = vfma_n_f16(v24, v1, v17); - - float16x4_t v2 = vld1_f16(b + 8); - float16_t v18 = *(a + 2); - - v24 = vfma_n_f16(v24, v2, v18); - - float16x4_t v3 = vld1_f16(b + 12); - float16_t v19 = *(a + 3); - - v24 = vfma_n_f16(v24, v3, v19); - - __builtin_prefetch(b + 16, 0, 3); - __builtin_prefetch(a + 4, 0, 3); - - b += 16; - a += 4; - - vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(v24))); - } - c += 4; - a -= K; - } - sc += ldc; - c = sc; - a += K; - b = sb; - } -} diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_1x8.h b/nntrainer/tensor/hgemm/hgemm_kernel_1x8.h deleted file mode 100644 index 3114ca32..00000000 --- a/nntrainer/tensor/hgemm/hgemm_kernel_1x8.h +++ /dev/null @@ -1,184 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2024 Debadri Samaddar - * - * @file hgemm_kernel_1x8.h - * @date 05 April 2024 - * @see https://github.com/nnstreamer/nntrainer - * @author Debadri Samaddar - * @author Sungsik Kong - * @bug No known bugs except for NYI items - * @brief This is half-precision GEMM 1x8 kernel - * - */ - -#include -#include -#include - -// 1. 
Partial sum 64 digits : worst accuracy, best latency -#define KERNEL_1x8_ACC8() \ - do { \ - v0 = vdupq_n_f16(0.F); \ - dv0 = *a; \ - v24 = vld1q_f16(b); \ - v0 = vfmaq_n_f16(v0, v24, dv0); \ - dv1 = *(a + 1); \ - v25 = vld1q_f16(b + 8); \ - v0 = vfmaq_n_f16(v0, v25, dv1); \ - dv2 = *(a + 2); \ - v26 = vld1q_f16(b + 16); \ - v0 = vfmaq_n_f16(v0, v26, dv2); \ - dv3 = *(a + 3); \ - v27 = vld1q_f16(b + 24); \ - v0 = vfmaq_n_f16(v0, v27, dv3); \ - dv4 = *(a + 4); \ - v28 = vld1q_f16(b + 32); \ - v0 = vfmaq_n_f16(v0, v28, dv4); \ - dv5 = *(a + 5); \ - v29 = vld1q_f16(b + 40); \ - v0 = vfmaq_n_f16(v0, v29, dv5); \ - dv6 = *(a + 6); \ - v30 = vld1q_f16(b + 48); \ - v0 = vfmaq_n_f16(v0, v30, dv6); \ - dv7 = *(a + 7); \ - v31 = vld1q_f16(b + 56); \ - v0 = vfmaq_n_f16(v0, v31, dv7); \ - l += 8; \ - b += 8 * 8; \ - a += 8; \ - } while (0) - -// 2. Partial sum 32 digits : medium accuracy, medium latency -#define KERNEL_1x8_ACC4() \ - do { \ - v0 = vdupq_n_f16(0.F); \ - dv0 = *a; \ - v24 = vld1q_f16(b); \ - v0 = vfmaq_n_f16(v0, v24, dv0); \ - dv1 = *(a + 1); \ - v25 = vld1q_f16(b + 8); \ - v0 = vfmaq_n_f16(v0, v25, dv1); \ - dv2 = *(a + 2); \ - v26 = vld1q_f16(b + 16); \ - v0 = vfmaq_n_f16(v0, v26, dv2); \ - dv3 = *(a + 3); \ - v27 = vld1q_f16(b + 24); \ - v0 = vfmaq_n_f16(v0, v27, dv3); \ - l += 4; \ - b += 8 * 4; \ - a += 4; \ - } while (0) - -// 3. Partial sum 8 digits : Best accuracy, worst latency -#define KERNEL_1x8_ACC1() \ - do { \ - v0 = vdupq_n_f16(0.F); \ - dv0 = *(a); \ - v24 = vld1q_f16(b); \ - v0 = vfmaq_n_f16(v0, v24, dv0); \ - l += 1; \ - b += 8 * 1; \ - a++; \ - } while (0) - -/** - * @brief hgemm 1x8 kernel sc = sa * sb - * - * @param m length of the row of matrix A - * @param n length of the col of matrix B - * @param k length of the col of matrix A - * @param sa sub-matrix of input matrix A - * @param sb sub-matrix of input matrix B - * @param sc sub-matrix of output matrix C - * @param ldc leading-dimension of matrix C - */ -void hgemm_kernel_1x8(unsigned int M, unsigned int N, unsigned int K, - __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) { - assert(M > 0 && N > 0 && K > 0); - assert(N % 8 == 0); - - __fp16 *a = sa, *b = sb, *c = sc; - unsigned int k8 = (K >> 3) << 3; - unsigned int i, j, l; - for (i = 0; i < M; i++) { - for (j = 0; j < N; j += 8) { - __builtin_prefetch(b, 0, 3); - __builtin_prefetch(a, 0, 3); - float16x8_t v0; - float16x8_t v24, v25, v26, v27, v28, v29, v30, v31; - float16_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7; - l = 0; - for (; l < k8;) { - KERNEL_1x8_ACC8(); - - vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0)); - } - for (; l < K;) { - KERNEL_1x8_ACC1(); - - vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0)); - } - c += 8; - a -= K; - } - sc += ldc; - c = sc; - a += K; - b = sb; - } -} - -/** - * @brief hgemm 1x8 kernel sc = sa * sb - * - * @param m length of the row of matrix A - * @param n length of the col of matrix B - * @param k length of the col of matrix A - * @param sa sub-matrix of input matrix A - * @param sb sub-matrix of input matrix B - * @param sc sub-matrix of output matrix C - * @param ldc leading-dimension of matrix C - */ -void hgemm_kernel_1x8(unsigned int M, unsigned int N, unsigned int K, - __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) { - assert(M > 0 && N > 0 && K > 0); - assert(N % 8 == 0); - - __fp16 *a = sa, *b = sb; - float *c = sc; - unsigned int k8 = (K >> 3) << 3; - unsigned int i, j, l; - for (i = 0; i < M; i++) { - for (j = 0; j < N; j += 8) { - __builtin_prefetch(b, 0, 3); - __builtin_prefetch(a, 0, 3); - float16x8_t v0; - 
float16x8_t v24, v25, v26, v27, v28, v29, v30, v31; - float16_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7; - l = 0; - for (; l < k8;) { - KERNEL_1x8_ACC8(); - - vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0)))); - - vst1q_f32(c + 4, - vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0)))); - } - for (; l < K;) { - KERNEL_1x8_ACC1(); - - vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0)))); - - vst1q_f32(c + 4, - vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0)))); - } - c += 8; - a -= K; - } - sc += ldc; - c = sc; - a += K; - b = sb; - } -} diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_4x4.h b/nntrainer/tensor/hgemm/hgemm_kernel_4x4.h deleted file mode 100644 index 18e86ccc..00000000 --- a/nntrainer/tensor/hgemm/hgemm_kernel_4x4.h +++ /dev/null @@ -1,359 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2024 Sungsik Kong - * - * @file hgemm_kernel_4x4.h - * @date 01 April 2024 - * @see https://github.com/nnstreamer/nntrainer - * @author Sungsik Kong - * @bug No known bugs except for NYI items - * @brief This is half-precision GEMM 4x4 kernel - * - */ - -#include -#include -#include - -#define INIT_KERNEL_4x4() \ - do { \ - v24 = vdup_n_f16(0.F); \ - v25 = vdup_n_f16(0.F); \ - v26 = vdup_n_f16(0.F); \ - v27 = vdup_n_f16(0.F); \ - } while (0) - -// 1. Partial sum 256 digits -#define KERNEL_4x4_ACC16() \ - do { \ - dv0 = vld1_f16(a); \ - vb0 = vld1_f16(b); \ - v24 = vfma_lane_f16(v24, vb0, dv0, 0); \ - v25 = vfma_lane_f16(v25, vb0, dv0, 1); \ - v26 = vfma_lane_f16(v26, vb0, dv0, 2); \ - v27 = vfma_lane_f16(v27, vb0, dv0, 3); \ - dv1 = vld1_f16(a + 4); \ - vb1 = vld1_f16(b + 4); \ - v24 = vfma_lane_f16(v24, vb1, dv1, 0); \ - v25 = vfma_lane_f16(v25, vb1, dv1, 1); \ - v26 = vfma_lane_f16(v26, vb1, dv1, 2); \ - v27 = vfma_lane_f16(v27, vb1, dv1, 3); \ - dv2 = vld1_f16(a + 4 * 2); \ - vb2 = vld1_f16(b + 4 * 2); \ - v24 = vfma_lane_f16(v24, vb2, dv2, 0); \ - v25 = vfma_lane_f16(v25, vb2, dv2, 1); \ - v26 = vfma_lane_f16(v26, vb2, dv2, 2); \ - v27 = vfma_lane_f16(v27, vb2, dv2, 3); \ - dv3 = vld1_f16(a + 4 * 3); \ - vb3 = vld1_f16(b + 4 * 3); \ - v24 = vfma_lane_f16(v24, vb3, dv3, 0); \ - v25 = vfma_lane_f16(v25, vb3, dv3, 1); \ - v26 = vfma_lane_f16(v26, vb3, dv3, 2); \ - v27 = vfma_lane_f16(v27, vb3, dv3, 3); \ - dv4 = vld1_f16(a + 4 * 4); \ - vb4 = vld1_f16(b + 4 * 4); \ - v24 = vfma_lane_f16(v24, vb4, dv4, 0); \ - v25 = vfma_lane_f16(v25, vb4, dv4, 1); \ - v26 = vfma_lane_f16(v26, vb4, dv4, 2); \ - v27 = vfma_lane_f16(v27, vb4, dv4, 3); \ - dv5 = vld1_f16(a + 4 * 5); \ - vb5 = vld1_f16(b + 4 * 5); \ - v24 = vfma_lane_f16(v24, vb5, dv5, 0); \ - v25 = vfma_lane_f16(v25, vb5, dv5, 1); \ - v26 = vfma_lane_f16(v26, vb5, dv5, 2); \ - v27 = vfma_lane_f16(v27, vb5, dv5, 3); \ - dv6 = vld1_f16(a + 4 * 6); \ - vb6 = vld1_f16(b + 4 * 6); \ - v24 = vfma_lane_f16(v24, vb6, dv6, 0); \ - v25 = vfma_lane_f16(v25, vb6, dv6, 1); \ - v26 = vfma_lane_f16(v26, vb6, dv6, 2); \ - v27 = vfma_lane_f16(v27, vb6, dv6, 3); \ - dv7 = vld1_f16(a + 4 * 7); \ - vb7 = vld1_f16(b + 4 * 7); \ - v24 = vfma_lane_f16(v24, vb7, dv7, 0); \ - v25 = vfma_lane_f16(v25, vb7, dv7, 1); \ - v26 = vfma_lane_f16(v26, vb7, dv7, 2); \ - v27 = vfma_lane_f16(v27, vb7, dv7, 3); \ - dv7 = vld1_f16(a + 4 * 8); \ - vb7 = vld1_f16(b + 4 * 8); \ - v24 = vfma_lane_f16(v24, vb7, dv7, 0); \ - v25 = vfma_lane_f16(v25, vb7, dv7, 1); \ - v26 = vfma_lane_f16(v26, vb7, dv7, 2); \ - v27 = vfma_lane_f16(v27, vb7, dv7, 3); \ - dv7 = vld1_f16(a + 4 * 9); \ - vb7 = vld1_f16(b + 4 * 9); \ - 
v24 = vfma_lane_f16(v24, vb7, dv7, 0); \ - v25 = vfma_lane_f16(v25, vb7, dv7, 1); \ - v26 = vfma_lane_f16(v26, vb7, dv7, 2); \ - v27 = vfma_lane_f16(v27, vb7, dv7, 3); \ - dv7 = vld1_f16(a + 4 * 10); \ - vb7 = vld1_f16(b + 4 * 10); \ - v24 = vfma_lane_f16(v24, vb7, dv7, 0); \ - v25 = vfma_lane_f16(v25, vb7, dv7, 1); \ - v26 = vfma_lane_f16(v26, vb7, dv7, 2); \ - v27 = vfma_lane_f16(v27, vb7, dv7, 3); \ - dv7 = vld1_f16(a + 4 * 11); \ - vb7 = vld1_f16(b + 4 * 11); \ - v24 = vfma_lane_f16(v24, vb7, dv7, 0); \ - v25 = vfma_lane_f16(v25, vb7, dv7, 1); \ - v26 = vfma_lane_f16(v26, vb7, dv7, 2); \ - v27 = vfma_lane_f16(v27, vb7, dv7, 3); \ - dv7 = vld1_f16(a + 4 * 12); \ - vb7 = vld1_f16(b + 4 * 12); \ - v24 = vfma_lane_f16(v24, vb7, dv7, 0); \ - v25 = vfma_lane_f16(v25, vb7, dv7, 1); \ - v26 = vfma_lane_f16(v26, vb7, dv7, 2); \ - v27 = vfma_lane_f16(v27, vb7, dv7, 3); \ - dv7 = vld1_f16(a + 4 * 13); \ - vb7 = vld1_f16(b + 4 * 13); \ - v24 = vfma_lane_f16(v24, vb7, dv7, 0); \ - v25 = vfma_lane_f16(v25, vb7, dv7, 1); \ - v26 = vfma_lane_f16(v26, vb7, dv7, 2); \ - v27 = vfma_lane_f16(v27, vb7, dv7, 3); \ - dv7 = vld1_f16(a + 4 * 14); \ - vb7 = vld1_f16(b + 4 * 14); \ - v24 = vfma_lane_f16(v24, vb7, dv7, 0); \ - v25 = vfma_lane_f16(v25, vb7, dv7, 1); \ - v26 = vfma_lane_f16(v26, vb7, dv7, 2); \ - v27 = vfma_lane_f16(v27, vb7, dv7, 3); \ - dv7 = vld1_f16(a + 4 * 15); \ - vb7 = vld1_f16(b + 4 * 15); \ - v24 = vfma_lane_f16(v24, vb7, dv7, 0); \ - v25 = vfma_lane_f16(v25, vb7, dv7, 1); \ - v26 = vfma_lane_f16(v26, vb7, dv7, 2); \ - v27 = vfma_lane_f16(v27, vb7, dv7, 3); \ - l += 16; \ - __builtin_prefetch(b + 64, 0, 3); \ - __builtin_prefetch(a + 64, 0, 3); \ - b += 4 * 16; \ - a += 4 * 16; \ - } while (0) - -// 2. Partial sum 128 digits -#define KERNEL_4x4_ACC8() \ - do { \ - dv0 = vld1_f16(a); \ - vb0 = vld1_f16(b); \ - v24 = vfma_lane_f16(v24, vb0, dv0, 0); \ - v25 = vfma_lane_f16(v25, vb0, dv0, 1); \ - v26 = vfma_lane_f16(v26, vb0, dv0, 2); \ - v27 = vfma_lane_f16(v27, vb0, dv0, 3); \ - dv1 = vld1_f16(a + 4); \ - vb1 = vld1_f16(b + 4); \ - v24 = vfma_lane_f16(v24, vb1, dv1, 0); \ - v25 = vfma_lane_f16(v25, vb1, dv1, 1); \ - v26 = vfma_lane_f16(v26, vb1, dv1, 2); \ - v27 = vfma_lane_f16(v27, vb1, dv1, 3); \ - dv2 = vld1_f16(a + 8); \ - vb2 = vld1_f16(b + 8); \ - v24 = vfma_lane_f16(v24, vb2, dv2, 0); \ - v25 = vfma_lane_f16(v25, vb2, dv2, 1); \ - v26 = vfma_lane_f16(v26, vb2, dv2, 2); \ - v27 = vfma_lane_f16(v27, vb2, dv2, 3); \ - dv3 = vld1_f16(a + 12); \ - vb3 = vld1_f16(b + 12); \ - v24 = vfma_lane_f16(v24, vb3, dv3, 0); \ - v25 = vfma_lane_f16(v25, vb3, dv3, 1); \ - v26 = vfma_lane_f16(v26, vb3, dv3, 2); \ - v27 = vfma_lane_f16(v27, vb3, dv3, 3); \ - dv4 = vld1_f16(a + 16); \ - vb4 = vld1_f16(b + 16); \ - v24 = vfma_lane_f16(v24, vb4, dv4, 0); \ - v25 = vfma_lane_f16(v25, vb4, dv4, 1); \ - v26 = vfma_lane_f16(v26, vb4, dv4, 2); \ - v27 = vfma_lane_f16(v27, vb4, dv4, 3); \ - dv5 = vld1_f16(a + 20); \ - vb5 = vld1_f16(b + 20); \ - v24 = vfma_lane_f16(v24, vb5, dv5, 0); \ - v25 = vfma_lane_f16(v25, vb5, dv5, 1); \ - v26 = vfma_lane_f16(v26, vb5, dv5, 2); \ - v27 = vfma_lane_f16(v27, vb5, dv5, 3); \ - dv6 = vld1_f16(a + 24); \ - vb6 = vld1_f16(b + 24); \ - v24 = vfma_lane_f16(v24, vb6, dv6, 0); \ - v25 = vfma_lane_f16(v25, vb6, dv6, 1); \ - v26 = vfma_lane_f16(v26, vb6, dv6, 2); \ - v27 = vfma_lane_f16(v27, vb6, dv6, 3); \ - dv7 = vld1_f16(a + 28); \ - vb7 = vld1_f16(b + 28); \ - v24 = vfma_lane_f16(v24, vb7, dv7, 0); \ - v25 = vfma_lane_f16(v25, vb7, dv7, 1); \ - v26 = vfma_lane_f16(v26, vb7, dv7, 
2); \ - v27 = vfma_lane_f16(v27, vb7, dv7, 3); \ - l += 8; \ - __builtin_prefetch(b + 32, 0, 3); \ - __builtin_prefetch(a + 32, 0, 3); \ - b += 4 * 8; \ - a += 4 * 8; \ - } while (0) - -// 3. Partial sum 16 digits -#define KERNEL_4x4_ACC1() \ - do { \ - dv0 = vld1_f16(a); \ - vb0 = vld1_f16(b); \ - v24 = vfma_lane_f16(v24, vb0, dv0, 0); \ - v25 = vfma_lane_f16(v25, vb0, dv0, 1); \ - v26 = vfma_lane_f16(v26, vb0, dv0, 2); \ - v27 = vfma_lane_f16(v27, vb0, dv0, 3); \ - l += 1; \ - __builtin_prefetch(b + 4, 0, 3); \ - __builtin_prefetch(a + 4, 0, 3); \ - b += 4 * 1; \ - a += 4 * 1; \ - } while (0) - -#define SAVE_KERNEL_4X4_F16_F32() \ - do { \ - vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(v24))); \ - vst1q_f32(c + ldc, vaddq_f32(vld1q_f32(c + ldc), vcvt_f32_f16(v25))); \ - vst1q_f32(c + 2 * ldc, \ - vaddq_f32(vld1q_f32(c + 2 * ldc), vcvt_f32_f16(v26))); \ - vst1q_f32(c + 3 * ldc, \ - vaddq_f32(vld1q_f32(c + 3 * ldc), vcvt_f32_f16(v27))); \ - } while (0) - -/** - * @brief hgemm 4x4 kernel sc = sa * sb - * - * @param m length of the row of matrix A - * @param n length of the col of matrix B - * @param k length of the col of matrix A - * @param sa sub-matrix of input matrix A - * @param sb sub-matrix of input matrix B - * @param sc sub-matrix of output matrix C - * @param ldc leading dimension of matrix C - */ -void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K, - __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) { - assert(M > 0 && N > 0 && K > 0); - assert(M % 4 == 0 && N % 4 == 0 && K % 4 == 0); - - __fp16 *a = sa, *b = sb, *c = sc; - unsigned int i, j, l; - for (i = 0; i < M; i += 4) { - for (j = 0; j < N; j += 4) { - __builtin_prefetch(b, 0, 3); - __builtin_prefetch(a, 0, 3); - - float16x4_t v24; - float16x4_t v25; - float16x4_t v26; - float16x4_t v27; - INIT_KERNEL_4x4(); - - for (l = 0; l < K; l += 4) { - float16x4_t v0 = vld1_f16(b); - float16x4_t v16 = vld1_f16(a); - - v24 = vfma_lane_f16(v24, v0, v16, 0); - v25 = vfma_lane_f16(v25, v0, v16, 1); - v26 = vfma_lane_f16(v26, v0, v16, 2); - v27 = vfma_lane_f16(v27, v0, v16, 3); - - float16x4_t v1 = vld1_f16(b + 4); - float16x4_t v17 = vld1_f16(a + 4); - - v24 = vfma_lane_f16(v24, v1, v17, 0); - v25 = vfma_lane_f16(v25, v1, v17, 1); - v26 = vfma_lane_f16(v26, v1, v17, 2); - v27 = vfma_lane_f16(v27, v1, v17, 3); - - float16x4_t v2 = vld1_f16(b + 8); - float16x4_t v18 = vld1_f16(a + 8); - - v24 = vfma_lane_f16(v24, v2, v18, 0); - v25 = vfma_lane_f16(v25, v2, v18, 1); - v26 = vfma_lane_f16(v26, v2, v18, 2); - v27 = vfma_lane_f16(v27, v2, v18, 3); - - float16x4_t v3 = vld1_f16(b + 12); - float16x4_t v19 = vld1_f16(a + 12); - - v24 = vfma_lane_f16(v24, v3, v19, 0); - v25 = vfma_lane_f16(v25, v3, v19, 1); - v26 = vfma_lane_f16(v26, v3, v19, 2); - v27 = vfma_lane_f16(v27, v3, v19, 3); - - __builtin_prefetch(b + 16, 0, 3); - __builtin_prefetch(a + 16, 0, 3); - - b += 16; - a += 16; - } - - v24 = vadd_f16(vld1_f16(c), v24); - v25 = vadd_f16(vld1_f16(c + ldc), v25); - v26 = vadd_f16(vld1_f16(c + 2 * ldc), v26); - v27 = vadd_f16(vld1_f16(c + 3 * ldc), v27); - - vst1_f16(c, v24); - vst1_f16(c + ldc, v25); - vst1_f16(c + 2 * ldc, v26); - vst1_f16(c + 3 * ldc, v27); - - c += 4; - a -= 4 * K; - } - sc += ldc * 4; - c = sc; - a += 4 * K; - b = sb; - } -} - -/** - * @brief hgemm 4x4 kernel sc = sa * sb - * - * @param m length of the row of matrix A - * @param n length of the col of matrix B - * @param k length of the col of matrix A - * @param sa sub-matrix of input matrix A - * @param sb sub-matrix of input matrix B - * @param 
sc sub-matrix of output matrix C
- * @param ldc leading dimension of matrix C
- */
-void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
-                      __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
-  assert(M > 0 && N > 0 && K > 0);
-  assert(M % 4 == 0 && N % 4 == 0 && K % 4 == 0);
-
-  __fp16 *a = sa, *b = sb;
-  float *c = sc;
-  unsigned int i, j, l;
-  unsigned int K16 = (K >> 4) << 4;
-  unsigned int K8 = (K >> 3) << 3;
-  for (i = 0; i < M; i += 4) {
-    for (j = 0; j < N; j += 4) {
-      __builtin_prefetch(b, 0, 3);
-      __builtin_prefetch(a, 0, 3);
-
-      float16x4_t v24, v25, v26, v27;
-      float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
-      float16x4_t vb0, vb1, vb2, vb3, vb4, vb5, vb6, vb7;
-      l = 0;
-      for (; l < K16;) {
-        INIT_KERNEL_4x4();
-        KERNEL_4x4_ACC16();
-        SAVE_KERNEL_4X4_F16_F32();
-      }
-      for (; l < K8;) {
-        INIT_KERNEL_4x4();
-        KERNEL_4x4_ACC8();
-        SAVE_KERNEL_4X4_F16_F32();
-      }
-      for (; l < K;) {
-        INIT_KERNEL_4x4();
-        KERNEL_4x4_ACC1();
-        SAVE_KERNEL_4X4_F16_F32();
-      }
-
-      c += 4;
-      a -= 4 * K;
-    }
-    sc += ldc * 4;
-    c = sc;
-    a += 4 * K;
-    b = sb;
-  }
-}
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_4x8.h b/nntrainer/tensor/hgemm/hgemm_kernel_4x8.h
deleted file mode 100644
index b1757bb5..00000000
--- a/nntrainer/tensor/hgemm/hgemm_kernel_4x8.h
+++ /dev/null
@@ -1,366 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-/**
- * Copyright (C) 2024 Sungsik Kong
- *
- * @file hgemm_kernel_4x8.h
- * @date 03 April 2024
- * @see https://github.com/nnstreamer/nntrainer
- * @author Sungsik Kong
- * @bug No known bugs except for NYI items
- * @brief This is half-precision GEMM 4x8 kernel
- *
- */
-
-#include
-#include
-#include
-
-#define INIT_KERNEL_4X8()    \
-  do {                       \
-    v0 = vdupq_n_f16(0.F);   \
-    v3 = vdupq_n_f16(0.F);   \
-    v6 = vdupq_n_f16(0.F);   \
-    v9 = vdupq_n_f16(0.F);   \
-  } while (0)
-
-// 1.
Partial sum 256 digits -#define KERNEL_4x8_ACC16() \ - do { \ - dv0 = vld1_f16(a); \ - v24 = vld1q_f16(b); \ - v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \ - v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \ - v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \ - v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \ - dv1 = vld1_f16(a + 4); \ - v25 = vld1q_f16(b + 8); \ - v0 = vfmaq_lane_f16(v0, v25, dv1, 0); \ - v3 = vfmaq_lane_f16(v3, v25, dv1, 1); \ - v6 = vfmaq_lane_f16(v6, v25, dv1, 2); \ - v9 = vfmaq_lane_f16(v9, v25, dv1, 3); \ - dv2 = vld1_f16(a + 4 * 2); \ - v26 = vld1q_f16(b + 8 * 2); \ - v0 = vfmaq_lane_f16(v0, v26, dv2, 0); \ - v3 = vfmaq_lane_f16(v3, v26, dv2, 1); \ - v6 = vfmaq_lane_f16(v6, v26, dv2, 2); \ - v9 = vfmaq_lane_f16(v9, v26, dv2, 3); \ - dv3 = vld1_f16(a + 4 * 3); \ - v27 = vld1q_f16(b + 8 * 3); \ - v0 = vfmaq_lane_f16(v0, v27, dv3, 0); \ - v3 = vfmaq_lane_f16(v3, v27, dv3, 1); \ - v6 = vfmaq_lane_f16(v6, v27, dv3, 2); \ - v9 = vfmaq_lane_f16(v9, v27, dv3, 3); \ - dv4 = vld1_f16(a + 4 * 4); \ - v28 = vld1q_f16(b + 8 * 4); \ - v0 = vfmaq_lane_f16(v0, v28, dv4, 0); \ - v3 = vfmaq_lane_f16(v3, v28, dv4, 1); \ - v6 = vfmaq_lane_f16(v6, v28, dv4, 2); \ - v9 = vfmaq_lane_f16(v9, v28, dv4, 3); \ - dv5 = vld1_f16(a + 4 * 5); \ - v29 = vld1q_f16(b + 8 * 5); \ - v0 = vfmaq_lane_f16(v0, v29, dv5, 0); \ - v3 = vfmaq_lane_f16(v3, v29, dv5, 1); \ - v6 = vfmaq_lane_f16(v6, v29, dv5, 2); \ - v9 = vfmaq_lane_f16(v9, v29, dv5, 3); \ - dv6 = vld1_f16(a + 4 * 6); \ - v30 = vld1q_f16(b + 8 * 6); \ - v0 = vfmaq_lane_f16(v0, v30, dv6, 0); \ - v3 = vfmaq_lane_f16(v3, v30, dv6, 1); \ - v6 = vfmaq_lane_f16(v6, v30, dv6, 2); \ - v9 = vfmaq_lane_f16(v9, v30, dv6, 3); \ - dv7 = vld1_f16(a + 4 * 7); \ - v31 = vld1q_f16(b + 8 * 7); \ - v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \ - v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \ - v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \ - v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \ - dv7 = vld1_f16(a + 4 * 8); \ - v31 = vld1q_f16(b + 8 * 8); \ - v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \ - v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \ - v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \ - v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \ - dv7 = vld1_f16(a + 4 * 9); \ - v31 = vld1q_f16(b + 8 * 9); \ - v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \ - v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \ - v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \ - v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \ - dv7 = vld1_f16(a + 4 * 10); \ - v31 = vld1q_f16(b + 8 * 10); \ - v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \ - v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \ - v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \ - v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \ - dv7 = vld1_f16(a + 4 * 11); \ - v31 = vld1q_f16(b + 8 * 11); \ - v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \ - v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \ - v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \ - v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \ - dv7 = vld1_f16(a + 4 * 12); \ - v31 = vld1q_f16(b + 8 * 12); \ - v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \ - v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \ - v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \ - v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \ - dv7 = vld1_f16(a + 4 * 13); \ - v31 = vld1q_f16(b + 8 * 13); \ - v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \ - v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \ - v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \ - v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \ - dv7 = vld1_f16(a + 4 * 14); \ - v31 = vld1q_f16(b + 8 * 14); \ - v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \ - v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \ - v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \ - v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \ - dv7 = vld1_f16(a + 4 * 15); \ - v31 = 
vld1q_f16(b + 8 * 15); \ - v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \ - v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \ - v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \ - v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \ - l += 16; \ - __builtin_prefetch(b + 128, 0, 3); \ - __builtin_prefetch(a + 64, 0, 3); \ - b += 8 * 16; \ - a += 4 * 16; \ - } while (0) - -// 1. Partial sum 256 digits -#define KERNEL_4x8_ACC8() \ - do { \ - dv0 = vld1_f16(a); \ - v24 = vld1q_f16(b); \ - v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \ - v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \ - v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \ - v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \ - dv1 = vld1_f16(a + 4); \ - v25 = vld1q_f16(b + 8); \ - v0 = vfmaq_lane_f16(v0, v25, dv1, 0); \ - v3 = vfmaq_lane_f16(v3, v25, dv1, 1); \ - v6 = vfmaq_lane_f16(v6, v25, dv1, 2); \ - v9 = vfmaq_lane_f16(v9, v25, dv1, 3); \ - dv2 = vld1_f16(a + 8); \ - v26 = vld1q_f16(b + 16); \ - v0 = vfmaq_lane_f16(v0, v26, dv2, 0); \ - v3 = vfmaq_lane_f16(v3, v26, dv2, 1); \ - v6 = vfmaq_lane_f16(v6, v26, dv2, 2); \ - v9 = vfmaq_lane_f16(v9, v26, dv2, 3); \ - dv3 = vld1_f16(a + 12); \ - v27 = vld1q_f16(b + 24); \ - v0 = vfmaq_lane_f16(v0, v27, dv3, 0); \ - v3 = vfmaq_lane_f16(v3, v27, dv3, 1); \ - v6 = vfmaq_lane_f16(v6, v27, dv3, 2); \ - v9 = vfmaq_lane_f16(v9, v27, dv3, 3); \ - dv4 = vld1_f16(a + 16); \ - v28 = vld1q_f16(b + 32); \ - v0 = vfmaq_lane_f16(v0, v28, dv4, 0); \ - v3 = vfmaq_lane_f16(v3, v28, dv4, 1); \ - v6 = vfmaq_lane_f16(v6, v28, dv4, 2); \ - v9 = vfmaq_lane_f16(v9, v28, dv4, 3); \ - dv5 = vld1_f16(a + 20); \ - v29 = vld1q_f16(b + 40); \ - v0 = vfmaq_lane_f16(v0, v29, dv5, 0); \ - v3 = vfmaq_lane_f16(v3, v29, dv5, 1); \ - v6 = vfmaq_lane_f16(v6, v29, dv5, 2); \ - v9 = vfmaq_lane_f16(v9, v29, dv5, 3); \ - dv6 = vld1_f16(a + 24); \ - v30 = vld1q_f16(b + 48); \ - v0 = vfmaq_lane_f16(v0, v30, dv6, 0); \ - v3 = vfmaq_lane_f16(v3, v30, dv6, 1); \ - v6 = vfmaq_lane_f16(v6, v30, dv6, 2); \ - v9 = vfmaq_lane_f16(v9, v30, dv6, 3); \ - dv7 = vld1_f16(a + 28); \ - v31 = vld1q_f16(b + 56); \ - v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \ - v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \ - v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \ - v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \ - l += 8; \ - __builtin_prefetch(b + 64, 0, 3); \ - __builtin_prefetch(a + 32, 0, 3); \ - b += 8 * 8; \ - a += 4 * 8; \ - } while (0) - -// 2. Partial sum 128 digits -#define KERNEL_4x8_ACC4() \ - do { \ - dv0 = vld1_f16(a); \ - v24 = vld1q_f16(b); \ - v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \ - v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \ - v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \ - v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \ - dv1 = vld1_f16(a + 4); \ - v25 = vld1q_f16(b + 8); \ - v0 = vfmaq_lane_f16(v0, v25, dv1, 0); \ - v3 = vfmaq_lane_f16(v3, v25, dv1, 1); \ - v6 = vfmaq_lane_f16(v6, v25, dv1, 2); \ - v9 = vfmaq_lane_f16(v9, v25, dv1, 3); \ - dv2 = vld1_f16(a + 8); \ - v26 = vld1q_f16(b + 16); \ - v0 = vfmaq_lane_f16(v0, v26, dv2, 0); \ - v3 = vfmaq_lane_f16(v3, v26, dv2, 1); \ - v6 = vfmaq_lane_f16(v6, v26, dv2, 2); \ - v9 = vfmaq_lane_f16(v9, v26, dv2, 3); \ - dv3 = vld1_f16(a + 12); \ - v27 = vld1q_f16(b + 24); \ - v0 = vfmaq_lane_f16(v0, v27, dv3, 0); \ - v3 = vfmaq_lane_f16(v3, v27, dv3, 1); \ - v6 = vfmaq_lane_f16(v6, v27, dv3, 2); \ - v9 = vfmaq_lane_f16(v9, v27, dv3, 3); \ - l += 4; \ - __builtin_prefetch(b + 32, 0, 3); \ - __builtin_prefetch(a + 16, 0, 3); \ - b += 8 * 4; \ - a += 4 * 4; \ - } while (0) - -// 3. 
Partial sum 32 digits -#define KERNEL_4x8_ACC1() \ - do { \ - dv0 = vld1_f16(a); \ - v24 = vld1q_f16(b); \ - v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \ - v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \ - v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \ - v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \ - l += 1; \ - __builtin_prefetch(b + 8, 0, 3); \ - __builtin_prefetch(a + 4, 0, 3); \ - b += 8 * 1; \ - a += 4 * 1; \ - } while (0) - -#define SAVE_KERNEL_4X8_F16_F32() \ - do { \ - vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0)))); \ - vst1q_f32(c + ldc, \ - vaddq_f32(vld1q_f32(c + ldc), vcvt_f32_f16(vget_low_f16(v3)))); \ - vst1q_f32(c + 2 * ldc, vaddq_f32(vld1q_f32(c + 2 * ldc), \ - vcvt_f32_f16(vget_low_f16(v6)))); \ - vst1q_f32(c + 3 * ldc, vaddq_f32(vld1q_f32(c + 3 * ldc), \ - vcvt_f32_f16(vget_low_f16(v9)))); \ - \ - vst1q_f32(c + 4, \ - vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0)))); \ - vst1q_f32(c + 4 + ldc, vaddq_f32(vld1q_f32(c + 4 + ldc), \ - vcvt_f32_f16(vget_high_f16(v3)))); \ - vst1q_f32(c + 4 + 2 * ldc, vaddq_f32(vld1q_f32(c + 4 + 2 * ldc), \ - vcvt_f32_f16(vget_high_f16(v6)))); \ - vst1q_f32(c + 4 + 3 * ldc, vaddq_f32(vld1q_f32(c + 4 + 3 * ldc), \ - vcvt_f32_f16(vget_high_f16(v9)))); \ - } while (0) - -/** - * @brief hgemm 4x8 kernel sc = sa * sb - * - * @param m length of the row of matrix A - * @param n length of the col of matrix B - * @param k length of the col of matrix A - * @param sa sub-matrix of input matrix A - * @param sb sub-matrix of input matrix B - * @param sc sub-matrix of output matrix C - * @param ldc leading-dimension of matrix C - */ -void hgemm_kernel_4x8(unsigned int M, unsigned int N, unsigned int K, - __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) { - assert(M > 0 && N > 0 && K > 0); - assert(M % 4 == 0 && N % 8 == 0); - - __fp16 *a = sa, *b = sb, *c = sc; - unsigned int K8 = (K >> 3) << 3; - unsigned int i, j, l; - for (i = 0; i < M; i += 4) { - for (j = 0; j < N; j += 8) { - __builtin_prefetch(b, 0, 3); - __builtin_prefetch(a, 0, 3); - float16x8_t v0, v3, v6, v9; - float16x8_t v24, v25, v26, v27, v28, v29, v30, v31; - float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7; - INIT_KERNEL_4X8(); - l = 0; - for (; l < K8;) { - KERNEL_4x8_ACC8(); - } - for (; l < K;) { - KERNEL_4x8_ACC1(); - } - vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0)); - vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v3)); - vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v6)); - vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v9)); - c += 8; - a -= 4 * K; - } - sc += ldc * 4; - c = sc; - a += 4 * K; - b = sb; - } -} - -/** - * @brief hgemm 4x8 kernel sc = sa * sb - * - * @param m length of the row of matrix A - * @param n length of the col of matrix B - * @param k length of the col of matrix A - * @param sa sub-matrix of input matrix A - * @param sb sub-matrix of input matrix B - * @param sc sub-matrix of output matrix C - * @param ldc leading-dimension of matrix C - */ -void hgemm_kernel_4x8(unsigned int M, unsigned int N, unsigned int K, - __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) { - assert(M > 0 && N > 0 && K > 0); - assert(M % 4 == 0 && N % 8 == 0); - - __fp16 *a = sa, *b = sb; - float *c = sc; - unsigned int K16 = (K >> 4) << 4; - unsigned int K8 = (K >> 3) << 3; - unsigned int K4 = (K >> 2) << 2; - unsigned int i, j, l; - for (i = 0; i < M; i += 4) { - for (j = 0; j < N; j += 8) { - __builtin_prefetch(b, 0, 3); - __builtin_prefetch(a, 0, 3); - float16x8_t v0, v3, v6, v9; - float16x8_t v24, v25, v26, v27, v28, v29, v30, v31; - 
float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7; - l = 0; - for (; l < K16;) { - INIT_KERNEL_4X8(); - KERNEL_4x8_ACC16(); - SAVE_KERNEL_4X8_F16_F32(); - } - for (; l < K8;) { - INIT_KERNEL_4X8(); - KERNEL_4x8_ACC8(); - SAVE_KERNEL_4X8_F16_F32(); - } - for (; l < K4;) { - INIT_KERNEL_4X8(); - KERNEL_4x8_ACC4(); - SAVE_KERNEL_4X8_F16_F32(); - } - for (; l < K;) { - INIT_KERNEL_4X8(); - KERNEL_4x8_ACC1(); - SAVE_KERNEL_4X8_F16_F32(); - } - c += 8; - a -= 4 * K; - } - sc += ldc * 4; - c = sc; - a += 4 * K; - b = sb; - } -} diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_8x16.h b/nntrainer/tensor/hgemm/hgemm_kernel_8x16.h deleted file mode 100644 index d29cbfc2..00000000 --- a/nntrainer/tensor/hgemm/hgemm_kernel_8x16.h +++ /dev/null @@ -1,862 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2024 Sungsik Kong - * - * @file hgemm_kernel_8x16.h - * @date 04 April 2024 - * @see https://github.com/nnstreamer/nntrainer - * @author Sungsik Kong - * @bug No known bugs except for NYI items - * @brief This is half-precision GEMM 8x16 kernel - * - */ - -#include -#include -#include -#include - -#define INIT_KERNEL_8X16() \ - do { \ - v0_7 = vdupq_n_f16(0.F); \ - v8_15 = vdupq_n_f16(0.F); \ - v16_23 = vdupq_n_f16(0.F); \ - v24_31 = vdupq_n_f16(0.F); \ - v32_39 = vdupq_n_f16(0.F); \ - v40_47 = vdupq_n_f16(0.F); \ - v48_55 = vdupq_n_f16(0.F); \ - v56_63 = vdupq_n_f16(0.F); \ - v64_71 = vdupq_n_f16(0.F); \ - v72_79 = vdupq_n_f16(0.F); \ - v80_87 = vdupq_n_f16(0.F); \ - v88_95 = vdupq_n_f16(0.F); \ - v96_103 = vdupq_n_f16(0.F); \ - v104_111 = vdupq_n_f16(0.F); \ - v112_119 = vdupq_n_f16(0.F); \ - v120_127 = vdupq_n_f16(0.F); \ - } while (0) - -// 1. Partial sum 2048 digits -#define KERNEL_8x16_ACC16() \ - do { \ - va0 = vld1q_f16(a + 8 * 0); \ - vb1 = vld1q_f16(b + 8 * 0); \ - vb2 = vld1q_f16(b + 8 * 1); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 8 * 1); \ - vb1 = vld1q_f16(b + 8 * 2); \ - vb2 = vld1q_f16(b + 8 * 3); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = 
vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 8 * 2); \ - vb1 = vld1q_f16(b + 8 * 4); \ - vb2 = vld1q_f16(b + 8 * 5); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 8 * 3); \ - vb1 = vld1q_f16(b + 8 * 6); \ - vb2 = vld1q_f16(b + 8 * 7); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 8 * 4); \ - vb1 = vld1q_f16(b + 8 * 8); \ - vb2 = vld1q_f16(b + 8 * 9); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 8 * 5); \ - vb1 = vld1q_f16(b + 8 * 10); \ - vb2 = vld1q_f16(b + 8 * 11); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 
4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 8 * 6); \ - vb1 = vld1q_f16(b + 8 * 12); \ - vb2 = vld1q_f16(b + 8 * 13); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 8 * 7); \ - vb1 = vld1q_f16(b + 8 * 14); \ - vb2 = vld1q_f16(b + 8 * 15); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 8 * 8); \ - vb1 = vld1q_f16(b + 8 * 16); \ - vb2 = vld1q_f16(b + 8 * 17); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 8 * 9); \ - vb1 = vld1q_f16(b + 8 * 18); \ - vb2 = vld1q_f16(b + 8 * 19); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = 
vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 8 * 10); \ - vb1 = vld1q_f16(b + 8 * 20); \ - vb2 = vld1q_f16(b + 8 * 21); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 8 * 11); \ - vb1 = vld1q_f16(b + 8 * 22); \ - vb2 = vld1q_f16(b + 8 * 23); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 8 * 12); \ - vb1 = vld1q_f16(b + 8 * 24); \ - vb2 = vld1q_f16(b + 8 * 25); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 8 * 13); \ - vb1 = vld1q_f16(b + 8 * 26); \ - vb2 = vld1q_f16(b + 8 * 27); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = 
vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 8 * 14); \ - vb1 = vld1q_f16(b + 8 * 28); \ - vb2 = vld1q_f16(b + 8 * 29); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 8 * 15); \ - vb1 = vld1q_f16(b + 8 * 30); \ - vb2 = vld1q_f16(b + 8 * 31); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - __builtin_prefetch(b + 256, 0, 3); \ - __builtin_prefetch(a + 128, 0, 3); \ - l += 16; \ - b += 16 * 16; \ - a += 8 * 16; \ - } while (0) - -// 2. 
Partial sum 1024 digits -#define KERNEL_8x16_ACC8() \ - do { \ - va0 = vld1q_f16(a); \ - vb1 = vld1q_f16(b); \ - vb2 = vld1q_f16(b + 8); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 8); \ - vb1 = vld1q_f16(b + 16); \ - vb2 = vld1q_f16(b + 24); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 16); \ - vb1 = vld1q_f16(b + 32); \ - vb2 = vld1q_f16(b + 40); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 24); \ - vb1 = vld1q_f16(b + 48); \ - vb2 = vld1q_f16(b + 56); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = 
vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 32); \ - vb1 = vld1q_f16(b + 64); \ - vb2 = vld1q_f16(b + 72); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 40); \ - vb1 = vld1q_f16(b + 80); \ - vb2 = vld1q_f16(b + 88); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 48); \ - vb1 = vld1q_f16(b + 96); \ - vb2 = vld1q_f16(b + 104); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 56); \ - vb1 = vld1q_f16(b + 112); \ - vb2 = vld1q_f16(b + 120); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = 
vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - l += 8; \ - __builtin_prefetch(b + 128, 0, 3); \ - __builtin_prefetch(a + 64, 0, 3); \ - b += 16 * 8; \ - a += 8 * 8; \ - } while (0) - -// 3. Partial sum 512 digits -#define KERNEL_8x16_ACC4() \ - do { \ - va0 = vld1q_f16(a); \ - vb1 = vld1q_f16(b); \ - vb2 = vld1q_f16(b + 8); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 8); \ - vb1 = vld1q_f16(b + 16); \ - vb2 = vld1q_f16(b + 24); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 16); \ - vb1 = vld1q_f16(b + 32); \ - vb2 = vld1q_f16(b + 40); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - va0 = vld1q_f16(a + 24); \ - vb1 = vld1q_f16(b + 48); \ - vb2 = vld1q_f16(b + 56); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = 
vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - l += 4; \ - __builtin_prefetch(b + 64, 0, 3); \ - __builtin_prefetch(a + 32, 0, 3); \ - b += 16 * 4; \ - a += 8 * 4; \ - } while (0) - -// 4. Partial sum 128 digits -#define KERNEL_8x16_ACC1() \ - do { \ - va0 = vld1q_f16(a); \ - vb1 = vld1q_f16(b); \ - vb2 = vld1q_f16(b + 8); \ - v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \ - v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \ - v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \ - v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \ - v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \ - v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \ - v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \ - v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \ - v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \ - v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \ - v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \ - v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \ - v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \ - v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \ - v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \ - v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \ - l += 1; \ - __builtin_prefetch(b + 16, 0, 3); \ - __builtin_prefetch(a + 8, 0, 3); \ - b += 16 * 1; \ - a += 8 * 1; \ - } while (0) - -#define SAVE_KERNEL_8X16_F16_F32() \ - do { \ - vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0_7)))); \ - vst1q_f32(c + 4, \ - vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0_7)))); \ - \ - vst1q_f32( \ - c + 8, vaddq_f32(vld1q_f32(c + 8), vcvt_f32_f16(vget_low_f16(v64_71)))); \ - vst1q_f32(c + 8 + 4, vaddq_f32(vld1q_f32(c + 8 + 4), \ - vcvt_f32_f16(vget_high_f16(v64_71)))); \ - \ - vst1q_f32(c + ldc, vaddq_f32(vld1q_f32(c + ldc), \ - vcvt_f32_f16(vget_low_f16(v8_15)))); \ - vst1q_f32(c + ldc + 4, vaddq_f32(vld1q_f32(c + ldc + 4), \ - vcvt_f32_f16(vget_high_f16(v8_15)))); \ - \ - vst1q_f32(c + ldc + 8, vaddq_f32(vld1q_f32(c + ldc + 8), \ - vcvt_f32_f16(vget_low_f16(v72_79)))); \ - vst1q_f32(c + ldc + 8 + 4, \ - vaddq_f32(vld1q_f32(c + ldc + 8 + 4), \ - vcvt_f32_f16(vget_high_f16(v72_79)))); \ - \ - vst1q_f32(c + 2 * ldc, vaddq_f32(vld1q_f32(c + 2 * ldc), \ - vcvt_f32_f16(vget_low_f16(v16_23)))); \ - vst1q_f32(c + 2 * ldc + 4, \ - vaddq_f32(vld1q_f32(c + 2 * ldc + 4), \ - vcvt_f32_f16(vget_high_f16(v16_23)))); \ - \ - vst1q_f32(c + 2 * ldc + 8, vaddq_f32(vld1q_f32(c + 2 * ldc + 8), \ - vcvt_f32_f16(vget_low_f16(v80_87)))); \ - vst1q_f32(c + 2 * ldc + 8 + 4, \ - vaddq_f32(vld1q_f32(c + 2 * ldc + 8 + 4), \ - vcvt_f32_f16(vget_high_f16(v80_87)))); \ - \ - vst1q_f32(c + 3 * ldc, vaddq_f32(vld1q_f32(c + 3 * ldc), \ - vcvt_f32_f16(vget_low_f16(v24_31)))); \ - vst1q_f32(c + 3 * ldc + 4, \ - vaddq_f32(vld1q_f32(c + 3 * ldc + 4), \ - vcvt_f32_f16(vget_high_f16(v24_31)))); \ - \ - vst1q_f32(c + 3 * ldc + 8, vaddq_f32(vld1q_f32(c + 3 * ldc + 8), \ - vcvt_f32_f16(vget_low_f16(v88_95)))); \ - vst1q_f32(c + 3 * ldc + 8 + 4, \ - vaddq_f32(vld1q_f32(c + 3 * ldc + 8 + 4), \ - 
vcvt_f32_f16(vget_high_f16(v88_95)))); \ - \ - vst1q_f32(c + 4 * ldc, vaddq_f32(vld1q_f32(c + 4 * ldc), \ - vcvt_f32_f16(vget_low_f16(v32_39)))); \ - vst1q_f32(c + 4 * ldc + 4, \ - vaddq_f32(vld1q_f32(c + 4 * ldc + 4), \ - vcvt_f32_f16(vget_high_f16(v32_39)))); \ - \ - vst1q_f32(c + 4 * ldc + 8, \ - vaddq_f32(vld1q_f32(c + 4 * ldc + 8), \ - vcvt_f32_f16(vget_low_f16(v96_103)))); \ - vst1q_f32(c + 4 * ldc + 8 + 4, \ - vaddq_f32(vld1q_f32(c + 4 * ldc + 8 + 4), \ - vcvt_f32_f16(vget_high_f16(v96_103)))); \ - \ - vst1q_f32(c + 5 * ldc, vaddq_f32(vld1q_f32(c + 5 * ldc), \ - vcvt_f32_f16(vget_low_f16(v40_47)))); \ - vst1q_f32(c + 5 * ldc + 4, \ - vaddq_f32(vld1q_f32(c + 5 * ldc + 4), \ - vcvt_f32_f16(vget_high_f16(v40_47)))); \ - vst1q_f32(c + 5 * ldc + 8, \ - vaddq_f32(vld1q_f32(c + 5 * ldc + 8), \ - vcvt_f32_f16(vget_low_f16(v104_111)))); \ - vst1q_f32(c + 5 * ldc + 8 + 4, \ - vaddq_f32(vld1q_f32(c + 5 * ldc + 8 + 4), \ - vcvt_f32_f16(vget_high_f16(v104_111)))); \ - \ - vst1q_f32(c + 6 * ldc, vaddq_f32(vld1q_f32(c + 6 * ldc), \ - vcvt_f32_f16(vget_low_f16(v48_55)))); \ - vst1q_f32(c + 6 * ldc + 4, \ - vaddq_f32(vld1q_f32(c + 6 * ldc + 4), \ - vcvt_f32_f16(vget_high_f16(v48_55)))); \ - \ - vst1q_f32(c + 6 * ldc + 8, \ - vaddq_f32(vld1q_f32(c + 6 * ldc + 8), \ - vcvt_f32_f16(vget_low_f16(v112_119)))); \ - vst1q_f32(c + 6 * ldc + 8 + 4, \ - vaddq_f32(vld1q_f32(c + 6 * ldc + 8 + 4), \ - vcvt_f32_f16(vget_high_f16(v112_119)))); \ - \ - vst1q_f32(c + 7 * ldc, vaddq_f32(vld1q_f32(c + 7 * ldc), \ - vcvt_f32_f16(vget_low_f16(v56_63)))); \ - vst1q_f32(c + 7 * ldc + 4, \ - vaddq_f32(vld1q_f32(c + 7 * ldc + 4), \ - vcvt_f32_f16(vget_high_f16(v56_63)))); \ - \ - vst1q_f32(c + 7 * ldc + 8, \ - vaddq_f32(vld1q_f32(c + 7 * ldc + 8), \ - vcvt_f32_f16(vget_low_f16(v120_127)))); \ - vst1q_f32(c + 7 * ldc + 8 + 4, \ - vaddq_f32(vld1q_f32(c + 7 * ldc + 8 + 4), \ - vcvt_f32_f16(vget_high_f16(v120_127)))); \ - } while (0) - -/** - * @brief hgemm 8x16 kernel sc = sa * sb - * - * @param M length of the row of matrix A - * @param N length of the col of matrix B - * @param K length of the col of matrix A - * @param sa sub-matrix of input matrix A - * @param sb sub-matrix of input matrix B - * @param sc sub-matrix of output matrix C - * @param ldc leading-dimension of matrix C - */ -void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K, - __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) { - assert(M > 0 && N > 0 && K > 0); - assert(M % 8 == 0 && N % 16 == 0 && K % 8 == 0); - - __fp16 *a = sa, *b = sb, *c = sc; - unsigned int i, j, l; - for (i = 0; i < M; i += 8) { - for (j = 0; j < N; j += 16) { - __builtin_prefetch(b, 0, 3); - __builtin_prefetch(a, 0, 3); - // 8x16 - float16x8_t v0_7, v8_15; - float16x8_t v16_23, v24_31; - float16x8_t v32_39, v40_47; - float16x8_t v48_55, v56_63; - float16x8_t v64_71, v72_79; - float16x8_t v80_87, v88_95; - float16x8_t v96_103, v104_111; - float16x8_t v112_119, v120_127; - float16x8_t vb1, vb2; - float16x8_t va0; - - INIT_KERNEL_8X16(); - l = 0; - for (; l < K;) { - KERNEL_8x16_ACC1(); - } - vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0_7)); - vst1q_f16(c + 8, vaddq_f16(vld1q_f16(c + 8), v64_71)); - vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v8_15)); - vst1q_f16(c + ldc + 8, vaddq_f16(vld1q_f16(c + ldc + 8), v72_79)); - vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v16_23)); - vst1q_f16(c + 2 * ldc + 8, vaddq_f16(vld1q_f16(c + 2 * ldc + 8), v80_87)); - vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v24_31)); - vst1q_f16(c + 3 * ldc + 8, 
vaddq_f16(vld1q_f16(c + 3 * ldc + 8), v88_95)); - vst1q_f16(c + 4 * ldc, vaddq_f16(vld1q_f16(c + 4 * ldc), v32_39)); - vst1q_f16(c + 4 * ldc + 8, - vaddq_f16(vld1q_f16(c + 4 * ldc + 8), v96_103)); - vst1q_f16(c + 5 * ldc, vaddq_f16(vld1q_f16(c + 5 * ldc), v40_47)); - vst1q_f16(c + 5 * ldc + 8, - vaddq_f16(vld1q_f16(c + 5 * ldc + 8), v104_111)); - vst1q_f16(c + 6 * ldc, vaddq_f16(vld1q_f16(c + 6 * ldc), v48_55)); - vst1q_f16(c + 6 * ldc + 8, - vaddq_f16(vld1q_f16(c + 6 * ldc + 8), v112_119)); - vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v56_63)); - vst1q_f16(c + 7 * ldc + 8, - vaddq_f16(vld1q_f16(c + 7 * ldc + 8), v120_127)); - c += 16; - a -= 8 * K; - } - sc += ldc * 8; - c = sc; - a += 8 * K; - b = sb; - } -} - -/** - * @brief hgemm 8x16 kernel sc = sa * sb - * - * @param M length of the row of matrix A - * @param N length of the col of matrix B - * @param K length of the col of matrix A - * @param sa sub-matrix of input matrix A - * @param sb sub-matrix of input matrix B - * @param sc sub-matrix of output matrix C - * @param ldc leading-dimension of matrix C - */ -void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K, - __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) { - assert(M > 0 && N > 0 && K > 0); - assert(M % 8 == 0 && N % 16 == 0 && K % 4 == 0); - - __fp16 *a = sa, *b = sb; - float *c = sc; - unsigned int i, j, l; - unsigned int K4 = (K >> 2) << 2; - unsigned int K8 = (K >> 3) << 3; - unsigned int K16 = (K >> 4) << 4; - for (i = 0; i < M; i += 8) { - for (j = 0; j < N; j += 16) { - __builtin_prefetch(b, 0, 3); - __builtin_prefetch(a, 0, 3); - float16x8_t v0_7, v8_15; - float16x8_t v16_23, v24_31; - float16x8_t v32_39, v40_47; - float16x8_t v48_55, v56_63; - float16x8_t v64_71, v72_79; - float16x8_t v80_87, v88_95; - float16x8_t v96_103, v104_111; - float16x8_t v112_119, v120_127; - float16x8_t vb1, vb2; - float16x8_t va0; - l = 0; - for (; l < K16;) { - INIT_KERNEL_8X16(); - KERNEL_8x16_ACC16(); - SAVE_KERNEL_8X16_F16_F32(); - } - for (; l < K8;) { - INIT_KERNEL_8X16(); - KERNEL_8x16_ACC8(); - SAVE_KERNEL_8X16_F16_F32(); - } - for (; l < K4;) { - INIT_KERNEL_8X16(); - KERNEL_8x16_ACC4(); - SAVE_KERNEL_8X16_F16_F32(); - } - for (; l < K;) { - INIT_KERNEL_8X16(); - KERNEL_8x16_ACC1(); - SAVE_KERNEL_8X16_F16_F32(); - } - c += 16; - a -= 8 * K; - } - sc += ldc * 8; - c = sc; - a += 8 * K; - b = sb; - } -} diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_8x8.h b/nntrainer/tensor/hgemm/hgemm_kernel_8x8.h deleted file mode 100644 index 2e3eb6a7..00000000 --- a/nntrainer/tensor/hgemm/hgemm_kernel_8x8.h +++ /dev/null @@ -1,511 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2024 Sungsik Kong - * - * @file hgemm_kernel_8x8.h - * @date 01 April 2024 - * @see https://github.com/nnstreamer/nntrainer - * @author Sungsik Kong - * @bug No known bugs except for NYI items - * @brief This is half-precision GEMM 8x8 kernel - * - */ - -#include -#include -#include - -#define INIT_KERNEL_8x8() \ - do { \ - v24 = vdupq_n_f16(0.F); \ - v25 = vdupq_n_f16(0.F); \ - v26 = vdupq_n_f16(0.F); \ - v27 = vdupq_n_f16(0.F); \ - v28 = vdupq_n_f16(0.F); \ - v29 = vdupq_n_f16(0.F); \ - v30 = vdupq_n_f16(0.F); \ - v31 = vdupq_n_f16(0.F); \ - } while (0) - -// 1. 
Partial sum 1024 digits -#define KERNEL_8x8_ACC16() \ - do { \ - va0 = vld1q_f16(a); \ - v16 = vld1q_f16(b); \ - v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \ - v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \ - v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \ - v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \ - v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \ - v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \ - v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \ - v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \ - va0 = vld1q_f16(a + 8); \ - v17 = vld1q_f16(b + 8); \ - v24 = vfmaq_laneq_f16(v24, v17, va0, 0); \ - v25 = vfmaq_laneq_f16(v25, v17, va0, 1); \ - v26 = vfmaq_laneq_f16(v26, v17, va0, 2); \ - v27 = vfmaq_laneq_f16(v27, v17, va0, 3); \ - v28 = vfmaq_laneq_f16(v28, v17, va0, 4); \ - v29 = vfmaq_laneq_f16(v29, v17, va0, 5); \ - v30 = vfmaq_laneq_f16(v30, v17, va0, 6); \ - v31 = vfmaq_laneq_f16(v31, v17, va0, 7); \ - va0 = vld1q_f16(a + 8 * 2); \ - v18 = vld1q_f16(b + 8 * 2); \ - v24 = vfmaq_laneq_f16(v24, v18, va0, 0); \ - v25 = vfmaq_laneq_f16(v25, v18, va0, 1); \ - v26 = vfmaq_laneq_f16(v26, v18, va0, 2); \ - v27 = vfmaq_laneq_f16(v27, v18, va0, 3); \ - v28 = vfmaq_laneq_f16(v28, v18, va0, 4); \ - v29 = vfmaq_laneq_f16(v29, v18, va0, 5); \ - v30 = vfmaq_laneq_f16(v30, v18, va0, 6); \ - v31 = vfmaq_laneq_f16(v31, v18, va0, 7); \ - va0 = vld1q_f16(a + 8 * 3); \ - v19 = vld1q_f16(b + 8 * 3); \ - v24 = vfmaq_laneq_f16(v24, v19, va0, 0); \ - v25 = vfmaq_laneq_f16(v25, v19, va0, 1); \ - v26 = vfmaq_laneq_f16(v26, v19, va0, 2); \ - v27 = vfmaq_laneq_f16(v27, v19, va0, 3); \ - v28 = vfmaq_laneq_f16(v28, v19, va0, 4); \ - v29 = vfmaq_laneq_f16(v29, v19, va0, 5); \ - v30 = vfmaq_laneq_f16(v30, v19, va0, 6); \ - v31 = vfmaq_laneq_f16(v31, v19, va0, 7); \ - va0 = vld1q_f16(a + 8 * 4); \ - v20 = vld1q_f16(b + 8 * 4); \ - v24 = vfmaq_laneq_f16(v24, v20, va0, 0); \ - v25 = vfmaq_laneq_f16(v25, v20, va0, 1); \ - v26 = vfmaq_laneq_f16(v26, v20, va0, 2); \ - v27 = vfmaq_laneq_f16(v27, v20, va0, 3); \ - v28 = vfmaq_laneq_f16(v28, v20, va0, 4); \ - v29 = vfmaq_laneq_f16(v29, v20, va0, 5); \ - v30 = vfmaq_laneq_f16(v30, v20, va0, 6); \ - v31 = vfmaq_laneq_f16(v31, v20, va0, 7); \ - va0 = vld1q_f16(a + 8 * 5); \ - v21 = vld1q_f16(b + 8 * 5); \ - v24 = vfmaq_laneq_f16(v24, v21, va0, 0); \ - v25 = vfmaq_laneq_f16(v25, v21, va0, 1); \ - v26 = vfmaq_laneq_f16(v26, v21, va0, 2); \ - v27 = vfmaq_laneq_f16(v27, v21, va0, 3); \ - v28 = vfmaq_laneq_f16(v28, v21, va0, 4); \ - v29 = vfmaq_laneq_f16(v29, v21, va0, 5); \ - v30 = vfmaq_laneq_f16(v30, v21, va0, 6); \ - v31 = vfmaq_laneq_f16(v31, v21, va0, 7); \ - va0 = vld1q_f16(a + 8 * 6); \ - v22 = vld1q_f16(b + 8 * 6); \ - v24 = vfmaq_laneq_f16(v24, v22, va0, 0); \ - v25 = vfmaq_laneq_f16(v25, v22, va0, 1); \ - v26 = vfmaq_laneq_f16(v26, v22, va0, 2); \ - v27 = vfmaq_laneq_f16(v27, v22, va0, 3); \ - v28 = vfmaq_laneq_f16(v28, v22, va0, 4); \ - v29 = vfmaq_laneq_f16(v29, v22, va0, 5); \ - v30 = vfmaq_laneq_f16(v30, v22, va0, 6); \ - v31 = vfmaq_laneq_f16(v31, v22, va0, 7); \ - va0 = vld1q_f16(a + 8 * 7); \ - v23 = vld1q_f16(b + 8 * 7); \ - v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \ - v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \ - v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \ - v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \ - v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \ - v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \ - v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \ - v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \ - va0 = vld1q_f16(a + 8 * 8); \ - v23 = vld1q_f16(b + 8 * 8); \ - v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \ - v25 = 
vfmaq_laneq_f16(v25, v23, va0, 1); \ - v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \ - v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \ - v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \ - v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \ - v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \ - v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \ - va0 = vld1q_f16(a + 8 * 9); \ - v23 = vld1q_f16(b + 8 * 9); \ - v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \ - v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \ - v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \ - v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \ - v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \ - v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \ - v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \ - v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \ - va0 = vld1q_f16(a + 8 * 10); \ - v23 = vld1q_f16(b + 8 * 10); \ - v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \ - v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \ - v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \ - v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \ - v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \ - v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \ - v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \ - v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \ - va0 = vld1q_f16(a + 8 * 11); \ - v23 = vld1q_f16(b + 8 * 11); \ - v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \ - v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \ - v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \ - v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \ - v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \ - v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \ - v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \ - v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \ - va0 = vld1q_f16(a + 8 * 12); \ - v23 = vld1q_f16(b + 8 * 12); \ - v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \ - v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \ - v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \ - v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \ - v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \ - v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \ - v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \ - v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \ - va0 = vld1q_f16(a + 8 * 13); \ - v23 = vld1q_f16(b + 8 * 13); \ - v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \ - v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \ - v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \ - v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \ - v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \ - v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \ - v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \ - v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \ - va0 = vld1q_f16(a + 8 * 14); \ - v23 = vld1q_f16(b + 8 * 14); \ - v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \ - v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \ - v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \ - v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \ - v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \ - v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \ - v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \ - v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \ - va0 = vld1q_f16(a + 8 * 15); \ - v23 = vld1q_f16(b + 8 * 15); \ - v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \ - v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \ - v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \ - v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \ - v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \ - v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \ - v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \ - v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \ - __builtin_prefetch(b + 128, 0, 3); \ - __builtin_prefetch(a + 128, 0, 3); \ - l += 16; \ - b += 8 * 16; \ - a += 8 * 16; \ - } while (0) - -// 2. 
Partial sum 512 digits -#define KERNEL_8x8_ACC8() \ - do { \ - va0 = vld1q_f16(a); \ - v16 = vld1q_f16(b); \ - v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \ - v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \ - v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \ - v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \ - v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \ - v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \ - v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \ - v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \ - va1 = vld1q_f16(a + 8); \ - v17 = vld1q_f16(b + 8); \ - v24 = vfmaq_laneq_f16(v24, v17, va1, 0); \ - v25 = vfmaq_laneq_f16(v25, v17, va1, 1); \ - v26 = vfmaq_laneq_f16(v26, v17, va1, 2); \ - v27 = vfmaq_laneq_f16(v27, v17, va1, 3); \ - v28 = vfmaq_laneq_f16(v28, v17, va1, 4); \ - v29 = vfmaq_laneq_f16(v29, v17, va1, 5); \ - v30 = vfmaq_laneq_f16(v30, v17, va1, 6); \ - v31 = vfmaq_laneq_f16(v31, v17, va1, 7); \ - va2 = vld1q_f16(a + 16); \ - v18 = vld1q_f16(b + 16); \ - v24 = vfmaq_laneq_f16(v24, v18, va2, 0); \ - v25 = vfmaq_laneq_f16(v25, v18, va2, 1); \ - v26 = vfmaq_laneq_f16(v26, v18, va2, 2); \ - v27 = vfmaq_laneq_f16(v27, v18, va2, 3); \ - v28 = vfmaq_laneq_f16(v28, v18, va2, 4); \ - v29 = vfmaq_laneq_f16(v29, v18, va2, 5); \ - v30 = vfmaq_laneq_f16(v30, v18, va2, 6); \ - v31 = vfmaq_laneq_f16(v31, v18, va2, 7); \ - va3 = vld1q_f16(a + 24); \ - v19 = vld1q_f16(b + 24); \ - v24 = vfmaq_laneq_f16(v24, v19, va3, 0); \ - v25 = vfmaq_laneq_f16(v25, v19, va3, 1); \ - v26 = vfmaq_laneq_f16(v26, v19, va3, 2); \ - v27 = vfmaq_laneq_f16(v27, v19, va3, 3); \ - v28 = vfmaq_laneq_f16(v28, v19, va3, 4); \ - v29 = vfmaq_laneq_f16(v29, v19, va3, 5); \ - v30 = vfmaq_laneq_f16(v30, v19, va3, 6); \ - v31 = vfmaq_laneq_f16(v31, v19, va3, 7); \ - va4 = vld1q_f16(a + 32); \ - v20 = vld1q_f16(b + 32); \ - v24 = vfmaq_laneq_f16(v24, v20, va4, 0); \ - v25 = vfmaq_laneq_f16(v25, v20, va4, 1); \ - v26 = vfmaq_laneq_f16(v26, v20, va4, 2); \ - v27 = vfmaq_laneq_f16(v27, v20, va4, 3); \ - v28 = vfmaq_laneq_f16(v28, v20, va4, 4); \ - v29 = vfmaq_laneq_f16(v29, v20, va4, 5); \ - v30 = vfmaq_laneq_f16(v30, v20, va4, 6); \ - v31 = vfmaq_laneq_f16(v31, v20, va4, 7); \ - va5 = vld1q_f16(a + 40); \ - v21 = vld1q_f16(b + 40); \ - v24 = vfmaq_laneq_f16(v24, v21, va5, 0); \ - v25 = vfmaq_laneq_f16(v25, v21, va5, 1); \ - v26 = vfmaq_laneq_f16(v26, v21, va5, 2); \ - v27 = vfmaq_laneq_f16(v27, v21, va5, 3); \ - v28 = vfmaq_laneq_f16(v28, v21, va5, 4); \ - v29 = vfmaq_laneq_f16(v29, v21, va5, 5); \ - v30 = vfmaq_laneq_f16(v30, v21, va5, 6); \ - v31 = vfmaq_laneq_f16(v31, v21, va5, 7); \ - va6 = vld1q_f16(a + 48); \ - v22 = vld1q_f16(b + 48); \ - v24 = vfmaq_laneq_f16(v24, v22, va6, 0); \ - v25 = vfmaq_laneq_f16(v25, v22, va6, 1); \ - v26 = vfmaq_laneq_f16(v26, v22, va6, 2); \ - v27 = vfmaq_laneq_f16(v27, v22, va6, 3); \ - v28 = vfmaq_laneq_f16(v28, v22, va6, 4); \ - v29 = vfmaq_laneq_f16(v29, v22, va6, 5); \ - v30 = vfmaq_laneq_f16(v30, v22, va6, 6); \ - v31 = vfmaq_laneq_f16(v31, v22, va6, 7); \ - va7 = vld1q_f16(a + 56); \ - v23 = vld1q_f16(b + 56); \ - v24 = vfmaq_laneq_f16(v24, v23, va7, 0); \ - v25 = vfmaq_laneq_f16(v25, v23, va7, 1); \ - v26 = vfmaq_laneq_f16(v26, v23, va7, 2); \ - v27 = vfmaq_laneq_f16(v27, v23, va7, 3); \ - v28 = vfmaq_laneq_f16(v28, v23, va7, 4); \ - v29 = vfmaq_laneq_f16(v29, v23, va7, 5); \ - v30 = vfmaq_laneq_f16(v30, v23, va7, 6); \ - v31 = vfmaq_laneq_f16(v31, v23, va7, 7); \ - __builtin_prefetch(b + 64, 0, 3); \ - __builtin_prefetch(a + 64, 0, 3); \ - l += 8; \ - b += 8 * 8; \ - a += 8 * 8; \ - } while (0) - -// 3. 
Partial sum 256 digits -#define KERNEL_8x8_ACC4() \ - do { \ - va0 = vld1q_f16(a); \ - v16 = vld1q_f16(b); \ - v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \ - v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \ - v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \ - v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \ - v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \ - v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \ - v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \ - v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \ - va1 = vld1q_f16(a + 8); \ - v17 = vld1q_f16(b + 8); \ - v24 = vfmaq_laneq_f16(v24, v17, va1, 0); \ - v25 = vfmaq_laneq_f16(v25, v17, va1, 1); \ - v26 = vfmaq_laneq_f16(v26, v17, va1, 2); \ - v27 = vfmaq_laneq_f16(v27, v17, va1, 3); \ - v28 = vfmaq_laneq_f16(v28, v17, va1, 4); \ - v29 = vfmaq_laneq_f16(v29, v17, va1, 5); \ - v30 = vfmaq_laneq_f16(v30, v17, va1, 6); \ - v31 = vfmaq_laneq_f16(v31, v17, va1, 7); \ - va2 = vld1q_f16(a + 16); \ - v18 = vld1q_f16(b + 16); \ - v24 = vfmaq_laneq_f16(v24, v18, va2, 0); \ - v25 = vfmaq_laneq_f16(v25, v18, va2, 1); \ - v26 = vfmaq_laneq_f16(v26, v18, va2, 2); \ - v27 = vfmaq_laneq_f16(v27, v18, va2, 3); \ - v28 = vfmaq_laneq_f16(v28, v18, va2, 4); \ - v29 = vfmaq_laneq_f16(v29, v18, va2, 5); \ - v30 = vfmaq_laneq_f16(v30, v18, va2, 6); \ - v31 = vfmaq_laneq_f16(v31, v18, va2, 7); \ - va3 = vld1q_f16(a + 24); \ - v19 = vld1q_f16(b + 24); \ - v24 = vfmaq_laneq_f16(v24, v19, va3, 0); \ - v25 = vfmaq_laneq_f16(v25, v19, va3, 1); \ - v26 = vfmaq_laneq_f16(v26, v19, va3, 2); \ - v27 = vfmaq_laneq_f16(v27, v19, va3, 3); \ - v28 = vfmaq_laneq_f16(v28, v19, va3, 4); \ - v29 = vfmaq_laneq_f16(v29, v19, va3, 5); \ - v30 = vfmaq_laneq_f16(v30, v19, va3, 6); \ - v31 = vfmaq_laneq_f16(v31, v19, va3, 7); \ - __builtin_prefetch(b + 32, 0, 3); \ - __builtin_prefetch(a + 32, 0, 3); \ - l += 4; \ - b += 8 * 4; \ - a += 8 * 4; \ - } while (0) - -// 4. 
Partial sum 64 digits -#define KERNEL_8x8_ACC1() \ - do { \ - va0 = vld1q_f16(a); \ - v16 = vld1q_f16(b); \ - v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \ - v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \ - v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \ - v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \ - v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \ - v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \ - v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \ - v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \ - __builtin_prefetch(b + 8, 0, 3); \ - __builtin_prefetch(a + 8, 0, 3); \ - l += 1; \ - b += 8 * 1; \ - a += 8 * 1; \ - } while (0) - -#define SAVE_KERNEL_8X8_F16_f32() \ - do { \ - vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v24)))); \ - vst1q_f32(c + 4, \ - vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v24)))); \ - \ - vst1q_f32(c + ldc, \ - vaddq_f32(vld1q_f32(c + ldc), vcvt_f32_f16(vget_low_f16(v25)))); \ - vst1q_f32(c + 4 + ldc, vaddq_f32(vld1q_f32(c + 4 + ldc), \ - vcvt_f32_f16(vget_high_f16(v25)))); \ - \ - vst1q_f32(c + 2 * ldc, vaddq_f32(vld1q_f32(c + 2 * ldc), \ - vcvt_f32_f16(vget_low_f16(v26)))); \ - vst1q_f32(c + 4 + 2 * ldc, vaddq_f32(vld1q_f32(c + 4 + 2 * ldc), \ - vcvt_f32_f16(vget_high_f16(v26)))); \ - \ - vst1q_f32(c + 3 * ldc, vaddq_f32(vld1q_f32(c + 3 * ldc), \ - vcvt_f32_f16(vget_low_f16(v27)))); \ - vst1q_f32(c + 4 + 3 * ldc, vaddq_f32(vld1q_f32(c + 4 + 3 * ldc), \ - vcvt_f32_f16(vget_high_f16(v27)))); \ - \ - vst1q_f32(c + 4 * ldc, vaddq_f32(vld1q_f32(c + 4 * ldc), \ - vcvt_f32_f16(vget_low_f16(v28)))); \ - vst1q_f32(c + 4 + 4 * ldc, vaddq_f32(vld1q_f32(c + 4 + 4 * ldc), \ - vcvt_f32_f16(vget_high_f16(v28)))); \ - \ - vst1q_f32(c + 5 * ldc, vaddq_f32(vld1q_f32(c + 5 * ldc), \ - vcvt_f32_f16(vget_low_f16(v29)))); \ - vst1q_f32(c + 4 + 5 * ldc, vaddq_f32(vld1q_f32(c + 4 + 5 * ldc), \ - vcvt_f32_f16(vget_high_f16(v29)))); \ - \ - vst1q_f32(c + 6 * ldc, vaddq_f32(vld1q_f32(c + 6 * ldc), \ - vcvt_f32_f16(vget_low_f16(v30)))); \ - vst1q_f32(c + 4 + 6 * ldc, vaddq_f32(vld1q_f32(c + 4 + 6 * ldc), \ - vcvt_f32_f16(vget_high_f16(v30)))); \ - \ - vst1q_f32(c + 7 * ldc, vaddq_f32(vld1q_f32(c + 7 * ldc), \ - vcvt_f32_f16(vget_low_f16(v31)))); \ - vst1q_f32(c + 4 + 7 * ldc, vaddq_f32(vld1q_f32(c + 4 + 7 * ldc), \ - vcvt_f32_f16(vget_high_f16(v31)))); \ - } while (0) - -/** - * @brief hgemm 8x8 kernel sc = sa * sb - * - * @param m length of the row of matrix A - * @param n length of the col of matrix B - * @param k length of the col of matrix A - * @param sa sub-matrix of input matrix A - * @param sb sub-matrix of input matrix B - * @param sc sub-matrix of output matrix C - * @param ldc leading-dimension of matrix C - */ -void hgemm_kernel_8x8(unsigned int M, unsigned int N, unsigned int K, - __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) { - assert(M > 0 && N > 0 && K > 0); - assert(M % 8 == 0 && N % 8 == 0 && K % 4 == 0); - - __fp16 *a = sa, *b = sb, *c = sc; - unsigned int i, j, l; - for (i = 0; i < M; i += 8) { - for (j = 0; j < N; j += 8) { - __builtin_prefetch(b, 0, 3); - __builtin_prefetch(a, 0, 3); - - float16x8_t v16, v17, v18, v19, v20, v21, v22, v23; - float16x8_t v24, v25, v26, v27, v28, v29, v30, v31; - float16x8_t va0, va1, va2, va3, va4, va5, va6, va7; - INIT_KERNEL_8x8(); - l = 0; - for (; l < K;) { - KERNEL_8x8_ACC1(); - } - vst1q_f16(c, vaddq_f16(vld1q_f16(c), v24)); - vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v25)); - vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v26)); - vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v27)); - vst1q_f16(c + 4 
* ldc, vaddq_f16(vld1q_f16(c + 4 * ldc), v28)); - vst1q_f16(c + 5 * ldc, vaddq_f16(vld1q_f16(c + 5 * ldc), v29)); - vst1q_f16(c + 6 * ldc, vaddq_f16(vld1q_f16(c + 6 * ldc), v30)); - vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v31)); - c += 8; - a -= 8 * K; - } - sc += ldc * 8; - c = sc; - a += 8 * K; - b = sb; - } -} - -/** - * @brief hgemm 8x8 kernel sc = sa * sb - * - * @param m length of the row of matrix A - * @param n length of the col of matrix B - * @param k length of the col of matrix A - * @param sa sub-matrix of input matrix A - * @param sb sub-matrix of input matrix B - * @param sc sub-matrix of output matrix C - * @param ldc leading-dimension of matrix C - */ -void hgemm_kernel_8x8(unsigned int M, unsigned int N, unsigned int K, - __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) { - assert(M > 0 && N > 0 && K > 0); - assert(M % 8 == 0 && N % 8 == 0 && K % 8 == 0); - - __fp16 *a = sa, *b = sb; - float *c = sc; - unsigned int i, j, l; - unsigned int K4 = (K >> 2) << 2; - unsigned int K8 = (K >> 3) << 3; - unsigned int K16 = (K >> 4) << 4; - for (i = 0; i < M; i += 8) { - for (j = 0; j < N; j += 8) { - __builtin_prefetch(b, 0, 3); - __builtin_prefetch(a, 0, 3); - - float16x8_t v16, v17, v18, v19, v20, v21, v22, v23; - float16x8_t v24, v25, v26, v27, v28, v29, v30, v31; - float16x8_t va0, va1, va2, va3, va4, va5, va6, va7; - l = 0; - for (; l < K16;) { - INIT_KERNEL_8x8(); - KERNEL_8x8_ACC16(); - SAVE_KERNEL_8X8_F16_f32(); - } - for (; l < K8;) { - INIT_KERNEL_8x8(); - KERNEL_8x8_ACC8(); - SAVE_KERNEL_8X8_F16_f32(); - } - for (; l < K4;) { - INIT_KERNEL_8x8(); - KERNEL_8x8_ACC4(); - SAVE_KERNEL_8X8_F16_f32(); - } - for (; l < K;) { - INIT_KERNEL_8x8(); - KERNEL_8x8_ACC1(); - SAVE_KERNEL_8X8_F16_f32(); - } - - c += 8; - a -= 8 * K; - } - sc += ldc * 8; - c = sc; - a += 8 * K; - b = sb; - } -} diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_pack.cpp b/nntrainer/tensor/hgemm/hgemm_kernel_pack.cpp deleted file mode 100644 index 649f6f33..00000000 --- a/nntrainer/tensor/hgemm/hgemm_kernel_pack.cpp +++ /dev/null @@ -1,449 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2024 Sungsik Kong - * - * @file hgemm_kernel_pack.cpp - * @date 02 July 2024 - * @see https://github.com/nnstreamer/nntrainer - * @author Sungsik Kong - * @bug No known bugs except for NYI items - * @brief This is a source file for half-precision packing for the matrix - * multiplication - */ - -#include -#include -#include -#include - -void packing_A1(unsigned int m, unsigned int k, const __fp16 *from, - unsigned int lda, const __fp16 *to) { - - assert(k != 0 && m != 0 && k % 4 == 0 && m % 4 == 0); - unsigned int i, j; - - __fp16 *a_offset, *a_offset1, *a_offset2, *a_offset3, *a_offset4; - __fp16 *b_offset; - __fp16 ctemp1, ctemp2, ctemp3, ctemp4; - - a_offset = (__fp16 *)from; - b_offset = (__fp16 *)to; - - j = m; - do { - a_offset1 = a_offset; - a_offset += lda; - - i = (k >> 2); - do { - ctemp1 = *(a_offset1 + 0); - ctemp2 = *(a_offset1 + 1); - ctemp3 = *(a_offset1 + 2); - ctemp4 = *(a_offset1 + 3); - - *(b_offset + 0) = ctemp1; - *(b_offset + 1) = ctemp2; - *(b_offset + 2) = ctemp3; - *(b_offset + 3) = ctemp4; - - a_offset1 += 4; - - b_offset += 4; - i--; - } while (i > 0); - j--; - } while (j > 0); -} - -void packing_A4(unsigned int M, unsigned int K, const __fp16 *src, - unsigned int lda, const __fp16 *dst) { - - assert(K != 0 && M != 0 && K % 4 == 0 && M % 4 == 0); - unsigned int i, j; - - __fp16 *a_off, *a_off1, *a_off2, *a_off3, *a_off4; - __fp16 *b_off; - __fp16 c1, c2, 
c3, c4; - __fp16 c5, c6, c7, c8; - __fp16 c9, c10, c11, c12; - __fp16 c13, c14, c15, c16; - - a_off = (__fp16 *)src; - b_off = (__fp16 *)dst; - - j = (M >> 2); - do { - a_off1 = a_off; - a_off2 = a_off1 + lda; - a_off3 = a_off2 + lda; - a_off4 = a_off3 + lda; - a_off += 4 * lda; - - i = (K >> 2); - do { - c1 = *(a_off1 + 0); - c2 = *(a_off1 + 1); - c3 = *(a_off1 + 2); - c4 = *(a_off1 + 3); - - c5 = *(a_off2 + 0); - c6 = *(a_off2 + 1); - c7 = *(a_off2 + 2); - c8 = *(a_off2 + 3); - - c9 = *(a_off3 + 0); - c10 = *(a_off3 + 1); - c11 = *(a_off3 + 2); - c12 = *(a_off3 + 3); - - c13 = *(a_off4 + 0); - c14 = *(a_off4 + 1); - c15 = *(a_off4 + 2); - c16 = *(a_off4 + 3); - - *(b_off + 0) = c1; - *(b_off + 1) = c5; - *(b_off + 2) = c9; - *(b_off + 3) = c13; - - *(b_off + 4) = c2; - *(b_off + 5) = c6; - *(b_off + 6) = c10; - *(b_off + 7) = c14; - - *(b_off + 8) = c3; - *(b_off + 9) = c7; - *(b_off + 10) = c11; - *(b_off + 11) = c15; - - *(b_off + 12) = c4; - *(b_off + 13) = c8; - *(b_off + 14) = c12; - *(b_off + 15) = c16; - - a_off1 += 4; - a_off2 += 4; - a_off3 += 4; - a_off4 += 4; - - b_off += 16; - i--; - } while (i > 0); - j--; - } while (j > 0); -} - -void packing_A8(unsigned int M, unsigned int K, const __fp16 *src, - unsigned int lda, const __fp16 *dst) { - - assert(K != 0 && M != 0 && K % 8 == 0 && M % 8 == 0); - - uint16x4_t msk = {0xFFFF, 0xFFFF, 0x0000, 0x0000}; - uint16x4_t inv_msk = {0x0000, 0x0000, 0xFFFF, 0xFFFF}; - - const __fp16 *a_off = (__fp16 *)src; - __fp16 *b_off = (__fp16 *)dst; - - for (unsigned int i = 0; i < M; i += 8) { - const __fp16 *a_off1 = a_off; - const __fp16 *a_off2 = a_off1 + lda; - const __fp16 *a_off3 = a_off2 + lda; - const __fp16 *a_off4 = a_off3 + lda; - const __fp16 *a_off5 = a_off4 + lda; - const __fp16 *a_off6 = a_off5 + lda; - const __fp16 *a_off7 = a_off6 + lda; - const __fp16 *a_off8 = a_off7 + lda; - a_off += 8 * lda; - - for (unsigned int j = 0; j < K; j += 8) { - float16x8_t _v0 = vld1q_f16(a_off1); - float16x8_t _v1 = vld1q_f16(a_off2); - float16x8_t _v2 = vld1q_f16(a_off3); - float16x8_t _v3 = vld1q_f16(a_off4); - - float16x8_t _v4 = vld1q_f16(a_off5); - float16x8_t _v5 = vld1q_f16(a_off6); - float16x8_t _v6 = vld1q_f16(a_off7); - float16x8_t _v7 = vld1q_f16(a_off8); - - a_off1 += 8; - a_off2 += 8; - a_off3 += 8; - a_off4 += 8; - a_off5 += 8; - a_off6 += 8; - a_off7 += 8; - a_off8 += 8; - - float16x8x2_t _vv0 = vtrnq_f16(_v0, _v1); - float16x8x2_t _vv1 = vtrnq_f16(_v2, _v3); - float16x8x2_t _vv2 = vtrnq_f16(_v4, _v5); - float16x8x2_t _vv3 = vtrnq_f16(_v6, _v7); - - float16x8_t _v8 = - vcombine_f16(vget_low_f16(_vv0.val[0]), vget_low_f16(_vv1.val[0])); - float16x8_t _v9 = - vcombine_f16(vget_low_f16(_vv0.val[1]), vget_low_f16(_vv1.val[1])); - float16x8_t _v10 = - vcombine_f16(vget_high_f16(_vv0.val[0]), vget_high_f16(_vv1.val[0])); - float16x8_t _v11 = - vcombine_f16(vget_high_f16(_vv0.val[1]), vget_high_f16(_vv1.val[1])); - - float16x8_t _v12 = - vcombine_f16(vget_low_f16(_vv2.val[0]), vget_low_f16(_vv3.val[0])); - float16x8_t _v13 = - vcombine_f16(vget_low_f16(_vv2.val[1]), vget_low_f16(_vv3.val[1])); - float16x8_t _v14 = - vcombine_f16(vget_high_f16(_vv2.val[0]), vget_high_f16(_vv3.val[0])); - float16x8_t _v15 = - vcombine_f16(vget_high_f16(_vv2.val[1]), vget_high_f16(_vv3.val[1])); - - // pack-in-pack - float16x4_t tmp_low_v8 = vget_low_f16(_v8); - float16x4_t tmp_high_v8 = vget_high_f16(_v8); - float16x4_t mid_v8 = vext_f16(tmp_low_v8, tmp_high_v8, 2); - - float16x4_t tmp_low_v9 = vget_low_f16(_v9); - float16x4_t tmp_high_v9 = 
vget_high_f16(_v9); - float16x4_t mid_v9 = vext_f16(tmp_low_v9, tmp_high_v9, 2); - - float16x4_t tmp_low_v10 = vget_low_f16(_v10); - float16x4_t tmp_high_v10 = vget_high_f16(_v10); - float16x4_t mid_v10 = vext_f16(tmp_low_v10, tmp_high_v10, 2); - - float16x4_t tmp_low_v11 = vget_low_f16(_v11); - float16x4_t tmp_high_v11 = vget_high_f16(_v11); - float16x4_t mid_v11 = vext_f16(tmp_low_v11, tmp_high_v11, 2); - - float16x4_t tmp_low_v12 = vget_low_f16(_v12); - float16x4_t tmp_high_v12 = vget_high_f16(_v12); - float16x4_t mid_v12 = vext_f16(tmp_low_v12, tmp_high_v12, 2); - - float16x4_t tmp_low_v13 = vget_low_f16(_v13); - float16x4_t tmp_high_v13 = vget_high_f16(_v13); - float16x4_t mid_v13 = vext_f16(tmp_low_v13, tmp_high_v13, 2); - - float16x4_t tmp_low_v14 = vget_low_f16(_v14); - float16x4_t tmp_high_v14 = vget_high_f16(_v14); - float16x4_t mid_v14 = vext_f16(tmp_low_v14, tmp_high_v14, 2); - - float16x4_t tmp_low_v15 = vget_low_f16(_v15); - float16x4_t tmp_high_v15 = vget_high_f16(_v15); - float16x4_t mid_v15 = vext_f16(tmp_low_v15, tmp_high_v15, 2); - - _v8 = vcombine_f16(vbsl_f16(msk, tmp_low_v8, mid_v8), - vbsl_f16(msk, tmp_low_v12, mid_v12)); - _v12 = vcombine_f16(vbsl_f16(msk, tmp_low_v9, mid_v9), - vbsl_f16(msk, tmp_low_v13, mid_v13)); - _v9 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v8, mid_v8), - vbsl_f16(inv_msk, tmp_high_v12, mid_v12)); - _v13 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v9, mid_v9), - vbsl_f16(inv_msk, tmp_high_v13, mid_v13)); - _v10 = vcombine_f16(vbsl_f16(msk, tmp_low_v10, mid_v10), - vbsl_f16(msk, tmp_low_v14, mid_v14)); - _v14 = vcombine_f16(vbsl_f16(msk, tmp_low_v11, mid_v11), - vbsl_f16(msk, tmp_low_v15, mid_v15)); - _v11 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v10, mid_v10), - vbsl_f16(inv_msk, tmp_high_v14, mid_v14)); - _v15 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v11, mid_v11), - vbsl_f16(inv_msk, tmp_high_v15, mid_v15)); - - vst1q_f16(b_off + 0, _v8); - vst1q_f16(b_off + 8, _v12); - vst1q_f16(b_off + 16, _v9); - vst1q_f16(b_off + 24, _v13); - vst1q_f16(b_off + 32, _v10); - vst1q_f16(b_off + 40, _v14); - vst1q_f16(b_off + 48, _v11); - vst1q_f16(b_off + 56, _v15); - b_off += 64; - } - } -} - -void packing_B1(unsigned int K, unsigned int N, const __fp16 *src, - unsigned int ldb, const __fp16 *dst) { - assert(K != 0 && N != 0 && N % 8 == 0); - - for (int i = 0; i < K; i++) { - const __fp16 *a_off = src + i * ldb; - __fp16 *b_off = (__fp16 *)dst + i; - for (int j = 0; j < N; j++) { - float16_t v = *(a_off); - a_off++; - - *b_off = v; - b_off += K; - } - } -} - -void packing_B4(unsigned int K, unsigned int N, const __fp16 *src, - unsigned int ldb, const __fp16 *dst) { - assert(K != 0 && N != 0 && K % 4 == 0 && N % 4 == 0); - unsigned int i, j; - - __fp16 *a_off, *a_off1, *a_off2, *a_off3, *a_off4; - __fp16 *b_off, *b_off1; - __fp16 c1, c2, c3, c4; - __fp16 c5, c6, c7, c8; - __fp16 c9, c10, c11, c12; - __fp16 c13, c14, c15, c16; - a_off = (__fp16 *)src; - b_off = (__fp16 *)dst; - - j = (K >> 2); - do { - a_off1 = a_off; - a_off2 = a_off1 + ldb; - a_off3 = a_off2 + ldb; - a_off4 = a_off3 + ldb; - a_off += 4 * ldb; - - b_off1 = b_off; - b_off += 16; - - i = (N >> 2); - do { - c1 = *(a_off1 + 0); - c2 = *(a_off1 + 1); - c3 = *(a_off1 + 2); - c4 = *(a_off1 + 3); - - c5 = *(a_off2 + 0); - c6 = *(a_off2 + 1); - c7 = *(a_off2 + 2); - c8 = *(a_off2 + 3); - - c9 = *(a_off3 + 0); - c10 = *(a_off3 + 1); - c11 = *(a_off3 + 2); - c12 = *(a_off3 + 3); - - c13 = *(a_off4 + 0); - c14 = *(a_off4 + 1); - c15 = *(a_off4 + 2); - c16 = *(a_off4 + 3); - - a_off1 += 4; - a_off2 += 4; 
- a_off3 += 4; - a_off4 += 4; - - *(b_off1 + 0) = c1; - *(b_off1 + 1) = c2; - *(b_off1 + 2) = c3; - *(b_off1 + 3) = c4; - - *(b_off1 + 4) = c5; - *(b_off1 + 5) = c6; - *(b_off1 + 6) = c7; - *(b_off1 + 7) = c8; - - *(b_off1 + 8) = c9; - *(b_off1 + 9) = c10; - *(b_off1 + 10) = c11; - *(b_off1 + 11) = c12; - - *(b_off1 + 12) = c13; - *(b_off1 + 13) = c14; - *(b_off1 + 14) = c15; - *(b_off1 + 15) = c16; - - b_off1 += K * 4; - i--; - } while (i > 0); - j--; - } while (j > 0); -} - -void packing_B8(unsigned int K, unsigned int N, const __fp16 *src, - unsigned int ldb, const __fp16 *dst) { - assert(K != 0 && N != 0 && N % 8 == 0); - - for (int i = 0; i < K; i++) { - const __fp16 *a_off = src + i * ldb; - __fp16 *b_off = (__fp16 *)dst + i * 8; - for (int j = 0; j < N; j += 8) { - float16x8_t v = vld1q_f16(a_off); - a_off += 8; - - vst1q_f16(b_off, v); - b_off += 8 * K; - } - } -} - -void packing_B16(unsigned int K, unsigned int N, const __fp16 *src, - unsigned int ldb, const __fp16 *dst) { - assert(K != 0 && N != 0 && N % 16 == 0); - - for (int i = 0; i < K; i++) { - const __fp16 *a_off = src + i * ldb; - __fp16 *b_off = (__fp16 *)dst + i * 16; - for (int j = 0; j < N; j += 16) { - float16x8_t v0_7 = vld1q_f16(a_off); - float16x8_t v8_15 = vld1q_f16(a_off + 8); - a_off += 16; - - vst1q_f16(b_off, v0_7); - vst1q_f16(b_off + 8, v8_15); - b_off += 16 * K; - } - } -} - -void packing_transB16(unsigned int K, unsigned int N, const __fp16 *src, - unsigned int ldb, const __fp16 *dst) { - /// @note ldb = K for here - assert(K != 0 && N != 0 && N % 16 == 0); - unsigned int K8 = (K >> 3) << 3; - - const __fp16 *src_off = (__fp16 *)src; - __fp16 *dst_off = (__fp16 *)dst; - - const unsigned int ld_tile_T = 16; - __fp16 *tile_T = new __fp16[8 * ld_tile_T]; - // __fp16 *tile_T = alignedMalloc(8 * ld_tile_T); - - // 1. Do something like 8x16 transpose kernel - // 2. Save linearized transposed output tile to dst - for (unsigned int n = 0; n < N; n += 16) { - const __fp16 *src_off1 = src_off; - __fp16 *dst_off1 = dst_off; - src_off += 16 * ldb; - dst_off += (K8 * 16 + (K - K8)); // ? 
- for (unsigned int k = 0; k < K8; k += 8) { - // 16x8 tile -> 8x16 - transpose_neon<__fp16>(16, 8, src_off1, ldb, tile_T, ld_tile_T); - - // Store with correct packing order linearly - vst1q_f16(&dst_off1[0], vld1q_f16(&tile_T[0 * ld_tile_T + 0])); - vst1q_f16(&dst_off1[8], vld1q_f16(&tile_T[0 * ld_tile_T + 8])); - vst1q_f16(&dst_off1[16], vld1q_f16(&tile_T[1 * ld_tile_T + 0])); - vst1q_f16(&dst_off1[24], vld1q_f16(&tile_T[1 * ld_tile_T + 8])); - vst1q_f16(&dst_off1[32], vld1q_f16(&tile_T[2 * ld_tile_T + 0])); - vst1q_f16(&dst_off1[40], vld1q_f16(&tile_T[2 * ld_tile_T + 8])); - vst1q_f16(&dst_off1[48], vld1q_f16(&tile_T[3 * ld_tile_T + 0])); - vst1q_f16(&dst_off1[56], vld1q_f16(&tile_T[3 * ld_tile_T + 8])); - vst1q_f16(&dst_off1[64], vld1q_f16(&tile_T[4 * ld_tile_T + 0])); - vst1q_f16(&dst_off1[72], vld1q_f16(&tile_T[4 * ld_tile_T + 8])); - vst1q_f16(&dst_off1[80], vld1q_f16(&tile_T[5 * ld_tile_T + 0])); - vst1q_f16(&dst_off1[88], vld1q_f16(&tile_T[5 * ld_tile_T + 8])); - vst1q_f16(&dst_off1[96], vld1q_f16(&tile_T[6 * ld_tile_T + 0])); - vst1q_f16(&dst_off1[104], vld1q_f16(&tile_T[6 * ld_tile_T + 8])); - vst1q_f16(&dst_off1[112], vld1q_f16(&tile_T[7 * ld_tile_T + 0])); - vst1q_f16(&dst_off1[120], vld1q_f16(&tile_T[7 * ld_tile_T + 8])); - - dst_off1 += 16 * 8; - src_off1 += 8; - } - - // Do the equivalent of one by one for the rest - for (unsigned int k = K8; k < K; ++k) { - for (unsigned int _n = 0; _n < 16; ++_n) { - dst_off1[_n] = src_off1[k]; - } - } - } -} diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_pack.h b/nntrainer/tensor/hgemm/hgemm_kernel_pack.h deleted file mode 100644 index fddc3511..00000000 --- a/nntrainer/tensor/hgemm/hgemm_kernel_pack.h +++ /dev/null @@ -1,102 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Copyright (C) 2024 Sungsik Kong - * - * @file hgemm_kernel_pack.h - * @date 01 April 2024 - * @see https://github.com/nnstreamer/nntrainer - * @author Sungsik Kong - * @author Debadri Samaddar - * @bug No known bugs except for NYI items - * @brief This is for half-precision packing for kernel-based GEMM - */ - - -/** - * @brief packing function of input matrix A - * - * @param M length of the row of the matrix - * @param K length of the col of the matrix - * @param src input of original source of the matrix - * @param lda leading dimension of the matrix - * @param dst output of packed data of the matrix - */ -void packing_A1(unsigned int m, unsigned int k, const __fp16 *from, - unsigned int lda, const __fp16 *to); -/** - * @brief packing function of input matrix A - * - * @param M length of the row of the matrix - * @param K length of the col of the matrix - * @param src input of original source of the matrix - * @param lda leading dimension of the matrix - * @param dst output of packed data of the matrix - */ -void packing_A4(unsigned int M, unsigned int K, const __fp16 *src, - unsigned int lda, const __fp16 *dst); -/** - * @brief packing function of input matrix A - * - * @param M length of the row of the matrix - * @param K length of the col of the matrix - * @param src input of original source of the matrix - * @param lda leading dimension of the matrix - * @param dst output of packed data of the matrix - */ -void packing_A8(unsigned int M, unsigned int K, const __fp16 *src, - unsigned int lda, const __fp16 *dst); -/** - * @brief packing function of input matrix B - * - * @param M length of the row of the matrix - * @param K length of the col of the matrix - * @param src input of original source of the matrix - * @param ldb leading dimension of the 
matrix - * @param dst output of packed data of the matrix - */ -void packing_B1(unsigned int K, unsigned int N, const __fp16 *src, - unsigned int ldb, const __fp16 *dst); -/** - * @brief packing function of input matrix B - * - * @param M length of the row of the matrix - * @param K length of the col of the matrix - * @param src input of original source of the matrix - * @param ldb leading dimension of the matrix - * @param dst output of packed data of the matrix - */ -void packing_B4(unsigned int K, unsigned int N, const __fp16 *src, - unsigned int ldb, const __fp16 *dst); -/** - * @brief packing function of input matrix B - * - * @param M length of the row of the matrix - * @param K length of the col of the matrix - * @param src input of original source of the matrix - * @param ldb leading dimension of the matrix - * @param dst output of packed data of the matrix - */ -void packing_B8(unsigned int K, unsigned int N, const __fp16 *src, - unsigned int ldb, const __fp16 *dst); -/** - * @brief packing function of input matrix B - * - * @param M length of the row of the matrix - * @param K length of the col of the matrix - * @param src input of original source of the matrix - * @param ldb leading dimension of the matrix - * @param dst output of packed data of the matrix - */ -void packing_B16(unsigned int K, unsigned int N, const __fp16 *src, - unsigned int ldb, const __fp16 *dst); -/** - * @brief - * - * @param K - * @param N - * @param src - * @param ldb - * @param dst - */ -void packing_transB16(unsigned int K, unsigned int N, const __fp16 *src, - unsigned int ldb, const __fp16 *dst); diff --git a/nntrainer/tensor/hgemm/hgemm_noTrans.cpp b/nntrainer/tensor/hgemm/hgemm_noTrans.cpp index 64a32b38..bff0c308 100644 --- a/nntrainer/tensor/hgemm/hgemm_noTrans.cpp +++ b/nntrainer/tensor/hgemm/hgemm_noTrans.cpp @@ -11,25 +11,15 @@ * */ +#include #include - -#include +#include +#include #include +#include #include #include -// #include - #include -#include - -#include -#include -#include -#include -#include -#include - - void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M, unsigned int N, unsigned int K, float alpha, float beta) { diff --git a/nntrainer/tensor/hgemm/hgemm_pack.cpp b/nntrainer/tensor/hgemm/hgemm_pack.cpp new file mode 100644 index 00000000..0f4b1470 --- /dev/null +++ b/nntrainer/tensor/hgemm/hgemm_pack.cpp @@ -0,0 +1,450 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Sungsik Kong + * + * @file hgemm_kernel_pack.cpp + * @date 02 July 2024 + * @see https://github.com/nnstreamer/nntrainer + * @author Sungsik Kong + * @bug No known bugs except for NYI items + * @brief This is a source file for half-precision packing for the matrix + * multiplication + */ + +#include +#include +#include +#include +#include + +void packing_A1(unsigned int m, unsigned int k, const __fp16 *from, + unsigned int lda, const __fp16 *to) { + + assert(k != 0 && m != 0 && k % 4 == 0 && m % 4 == 0); + unsigned int i, j; + + __fp16 *a_offset, *a_offset1, *a_offset2, *a_offset3, *a_offset4; + __fp16 *b_offset; + __fp16 ctemp1, ctemp2, ctemp3, ctemp4; + + a_offset = (__fp16 *)from; + b_offset = (__fp16 *)to; + + j = m; + do { + a_offset1 = a_offset; + a_offset += lda; + + i = (k >> 2); + do { + ctemp1 = *(a_offset1 + 0); + ctemp2 = *(a_offset1 + 1); + ctemp3 = *(a_offset1 + 2); + ctemp4 = *(a_offset1 + 3); + + *(b_offset + 0) = ctemp1; + *(b_offset + 1) = ctemp2; + *(b_offset + 2) = ctemp3; + *(b_offset + 3) = ctemp4; + + a_offset1 += 4; + + b_offset += 4; + i--; + } 
while (i > 0); + j--; + } while (j > 0); +} + +void packing_A4(unsigned int M, unsigned int K, const __fp16 *src, + unsigned int lda, const __fp16 *dst) { + + assert(K != 0 && M != 0 && K % 4 == 0 && M % 4 == 0); + unsigned int i, j; + + __fp16 *a_off, *a_off1, *a_off2, *a_off3, *a_off4; + __fp16 *b_off; + __fp16 c1, c2, c3, c4; + __fp16 c5, c6, c7, c8; + __fp16 c9, c10, c11, c12; + __fp16 c13, c14, c15, c16; + + a_off = (__fp16 *)src; + b_off = (__fp16 *)dst; + + j = (M >> 2); + do { + a_off1 = a_off; + a_off2 = a_off1 + lda; + a_off3 = a_off2 + lda; + a_off4 = a_off3 + lda; + a_off += 4 * lda; + + i = (K >> 2); + do { + c1 = *(a_off1 + 0); + c2 = *(a_off1 + 1); + c3 = *(a_off1 + 2); + c4 = *(a_off1 + 3); + + c5 = *(a_off2 + 0); + c6 = *(a_off2 + 1); + c7 = *(a_off2 + 2); + c8 = *(a_off2 + 3); + + c9 = *(a_off3 + 0); + c10 = *(a_off3 + 1); + c11 = *(a_off3 + 2); + c12 = *(a_off3 + 3); + + c13 = *(a_off4 + 0); + c14 = *(a_off4 + 1); + c15 = *(a_off4 + 2); + c16 = *(a_off4 + 3); + + *(b_off + 0) = c1; + *(b_off + 1) = c5; + *(b_off + 2) = c9; + *(b_off + 3) = c13; + + *(b_off + 4) = c2; + *(b_off + 5) = c6; + *(b_off + 6) = c10; + *(b_off + 7) = c14; + + *(b_off + 8) = c3; + *(b_off + 9) = c7; + *(b_off + 10) = c11; + *(b_off + 11) = c15; + + *(b_off + 12) = c4; + *(b_off + 13) = c8; + *(b_off + 14) = c12; + *(b_off + 15) = c16; + + a_off1 += 4; + a_off2 += 4; + a_off3 += 4; + a_off4 += 4; + + b_off += 16; + i--; + } while (i > 0); + j--; + } while (j > 0); +} + +void packing_A8(unsigned int M, unsigned int K, const __fp16 *src, + unsigned int lda, const __fp16 *dst) { + + assert(K != 0 && M != 0 && K % 8 == 0 && M % 8 == 0); + + uint16x4_t msk = {0xFFFF, 0xFFFF, 0x0000, 0x0000}; + uint16x4_t inv_msk = {0x0000, 0x0000, 0xFFFF, 0xFFFF}; + + const __fp16 *a_off = (__fp16 *)src; + __fp16 *b_off = (__fp16 *)dst; + + for (unsigned int i = 0; i < M; i += 8) { + const __fp16 *a_off1 = a_off; + const __fp16 *a_off2 = a_off1 + lda; + const __fp16 *a_off3 = a_off2 + lda; + const __fp16 *a_off4 = a_off3 + lda; + const __fp16 *a_off5 = a_off4 + lda; + const __fp16 *a_off6 = a_off5 + lda; + const __fp16 *a_off7 = a_off6 + lda; + const __fp16 *a_off8 = a_off7 + lda; + a_off += 8 * lda; + + for (unsigned int j = 0; j < K; j += 8) { + float16x8_t _v0 = vld1q_f16(a_off1); + float16x8_t _v1 = vld1q_f16(a_off2); + float16x8_t _v2 = vld1q_f16(a_off3); + float16x8_t _v3 = vld1q_f16(a_off4); + + float16x8_t _v4 = vld1q_f16(a_off5); + float16x8_t _v5 = vld1q_f16(a_off6); + float16x8_t _v6 = vld1q_f16(a_off7); + float16x8_t _v7 = vld1q_f16(a_off8); + + a_off1 += 8; + a_off2 += 8; + a_off3 += 8; + a_off4 += 8; + a_off5 += 8; + a_off6 += 8; + a_off7 += 8; + a_off8 += 8; + + float16x8x2_t _vv0 = vtrnq_f16(_v0, _v1); + float16x8x2_t _vv1 = vtrnq_f16(_v2, _v3); + float16x8x2_t _vv2 = vtrnq_f16(_v4, _v5); + float16x8x2_t _vv3 = vtrnq_f16(_v6, _v7); + + float16x8_t _v8 = + vcombine_f16(vget_low_f16(_vv0.val[0]), vget_low_f16(_vv1.val[0])); + float16x8_t _v9 = + vcombine_f16(vget_low_f16(_vv0.val[1]), vget_low_f16(_vv1.val[1])); + float16x8_t _v10 = + vcombine_f16(vget_high_f16(_vv0.val[0]), vget_high_f16(_vv1.val[0])); + float16x8_t _v11 = + vcombine_f16(vget_high_f16(_vv0.val[1]), vget_high_f16(_vv1.val[1])); + + float16x8_t _v12 = + vcombine_f16(vget_low_f16(_vv2.val[0]), vget_low_f16(_vv3.val[0])); + float16x8_t _v13 = + vcombine_f16(vget_low_f16(_vv2.val[1]), vget_low_f16(_vv3.val[1])); + float16x8_t _v14 = + vcombine_f16(vget_high_f16(_vv2.val[0]), vget_high_f16(_vv3.val[0])); + float16x8_t _v15 = + 
vcombine_f16(vget_high_f16(_vv2.val[1]), vget_high_f16(_vv3.val[1])); + + // pack-in-pack + float16x4_t tmp_low_v8 = vget_low_f16(_v8); + float16x4_t tmp_high_v8 = vget_high_f16(_v8); + float16x4_t mid_v8 = vext_f16(tmp_low_v8, tmp_high_v8, 2); + + float16x4_t tmp_low_v9 = vget_low_f16(_v9); + float16x4_t tmp_high_v9 = vget_high_f16(_v9); + float16x4_t mid_v9 = vext_f16(tmp_low_v9, tmp_high_v9, 2); + + float16x4_t tmp_low_v10 = vget_low_f16(_v10); + float16x4_t tmp_high_v10 = vget_high_f16(_v10); + float16x4_t mid_v10 = vext_f16(tmp_low_v10, tmp_high_v10, 2); + + float16x4_t tmp_low_v11 = vget_low_f16(_v11); + float16x4_t tmp_high_v11 = vget_high_f16(_v11); + float16x4_t mid_v11 = vext_f16(tmp_low_v11, tmp_high_v11, 2); + + float16x4_t tmp_low_v12 = vget_low_f16(_v12); + float16x4_t tmp_high_v12 = vget_high_f16(_v12); + float16x4_t mid_v12 = vext_f16(tmp_low_v12, tmp_high_v12, 2); + + float16x4_t tmp_low_v13 = vget_low_f16(_v13); + float16x4_t tmp_high_v13 = vget_high_f16(_v13); + float16x4_t mid_v13 = vext_f16(tmp_low_v13, tmp_high_v13, 2); + + float16x4_t tmp_low_v14 = vget_low_f16(_v14); + float16x4_t tmp_high_v14 = vget_high_f16(_v14); + float16x4_t mid_v14 = vext_f16(tmp_low_v14, tmp_high_v14, 2); + + float16x4_t tmp_low_v15 = vget_low_f16(_v15); + float16x4_t tmp_high_v15 = vget_high_f16(_v15); + float16x4_t mid_v15 = vext_f16(tmp_low_v15, tmp_high_v15, 2); + + _v8 = vcombine_f16(vbsl_f16(msk, tmp_low_v8, mid_v8), + vbsl_f16(msk, tmp_low_v12, mid_v12)); + _v12 = vcombine_f16(vbsl_f16(msk, tmp_low_v9, mid_v9), + vbsl_f16(msk, tmp_low_v13, mid_v13)); + _v9 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v8, mid_v8), + vbsl_f16(inv_msk, tmp_high_v12, mid_v12)); + _v13 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v9, mid_v9), + vbsl_f16(inv_msk, tmp_high_v13, mid_v13)); + _v10 = vcombine_f16(vbsl_f16(msk, tmp_low_v10, mid_v10), + vbsl_f16(msk, tmp_low_v14, mid_v14)); + _v14 = vcombine_f16(vbsl_f16(msk, tmp_low_v11, mid_v11), + vbsl_f16(msk, tmp_low_v15, mid_v15)); + _v11 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v10, mid_v10), + vbsl_f16(inv_msk, tmp_high_v14, mid_v14)); + _v15 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v11, mid_v11), + vbsl_f16(inv_msk, tmp_high_v15, mid_v15)); + + vst1q_f16(b_off + 0, _v8); + vst1q_f16(b_off + 8, _v12); + vst1q_f16(b_off + 16, _v9); + vst1q_f16(b_off + 24, _v13); + vst1q_f16(b_off + 32, _v10); + vst1q_f16(b_off + 40, _v14); + vst1q_f16(b_off + 48, _v11); + vst1q_f16(b_off + 56, _v15); + b_off += 64; + } + } +} + +void packing_B1(unsigned int K, unsigned int N, const __fp16 *src, + unsigned int ldb, const __fp16 *dst) { + assert(K != 0 && N != 0 && N % 8 == 0); + + for (int i = 0; i < K; i++) { + const __fp16 *a_off = src + i * ldb; + __fp16 *b_off = (__fp16 *)dst + i; + for (int j = 0; j < N; j++) { + float16_t v = *(a_off); + a_off++; + + *b_off = v; + b_off += K; + } + } +} + +void packing_B4(unsigned int K, unsigned int N, const __fp16 *src, + unsigned int ldb, const __fp16 *dst) { + assert(K != 0 && N != 0 && K % 4 == 0 && N % 4 == 0); + unsigned int i, j; + + __fp16 *a_off, *a_off1, *a_off2, *a_off3, *a_off4; + __fp16 *b_off, *b_off1; + __fp16 c1, c2, c3, c4; + __fp16 c5, c6, c7, c8; + __fp16 c9, c10, c11, c12; + __fp16 c13, c14, c15, c16; + a_off = (__fp16 *)src; + b_off = (__fp16 *)dst; + + j = (K >> 2); + do { + a_off1 = a_off; + a_off2 = a_off1 + ldb; + a_off3 = a_off2 + ldb; + a_off4 = a_off3 + ldb; + a_off += 4 * ldb; + + b_off1 = b_off; + b_off += 16; + + i = (N >> 2); + do { + c1 = *(a_off1 + 0); + c2 = *(a_off1 + 1); + c3 = *(a_off1 + 2); + c4 = 
*(a_off1 + 3); + + c5 = *(a_off2 + 0); + c6 = *(a_off2 + 1); + c7 = *(a_off2 + 2); + c8 = *(a_off2 + 3); + + c9 = *(a_off3 + 0); + c10 = *(a_off3 + 1); + c11 = *(a_off3 + 2); + c12 = *(a_off3 + 3); + + c13 = *(a_off4 + 0); + c14 = *(a_off4 + 1); + c15 = *(a_off4 + 2); + c16 = *(a_off4 + 3); + + a_off1 += 4; + a_off2 += 4; + a_off3 += 4; + a_off4 += 4; + + *(b_off1 + 0) = c1; + *(b_off1 + 1) = c2; + *(b_off1 + 2) = c3; + *(b_off1 + 3) = c4; + + *(b_off1 + 4) = c5; + *(b_off1 + 5) = c6; + *(b_off1 + 6) = c7; + *(b_off1 + 7) = c8; + + *(b_off1 + 8) = c9; + *(b_off1 + 9) = c10; + *(b_off1 + 10) = c11; + *(b_off1 + 11) = c12; + + *(b_off1 + 12) = c13; + *(b_off1 + 13) = c14; + *(b_off1 + 14) = c15; + *(b_off1 + 15) = c16; + + b_off1 += K * 4; + i--; + } while (i > 0); + j--; + } while (j > 0); +} + +void packing_B8(unsigned int K, unsigned int N, const __fp16 *src, + unsigned int ldb, const __fp16 *dst) { + assert(K != 0 && N != 0 && N % 8 == 0); + + for (int i = 0; i < K; i++) { + const __fp16 *a_off = src + i * ldb; + __fp16 *b_off = (__fp16 *)dst + i * 8; + for (int j = 0; j < N; j += 8) { + float16x8_t v = vld1q_f16(a_off); + a_off += 8; + + vst1q_f16(b_off, v); + b_off += 8 * K; + } + } +} + +void packing_B16(unsigned int K, unsigned int N, const __fp16 *src, + unsigned int ldb, const __fp16 *dst) { + assert(K != 0 && N != 0 && N % 16 == 0); + + for (int i = 0; i < K; i++) { + const __fp16 *a_off = src + i * ldb; + __fp16 *b_off = (__fp16 *)dst + i * 16; + for (int j = 0; j < N; j += 16) { + float16x8_t v0_7 = vld1q_f16(a_off); + float16x8_t v8_15 = vld1q_f16(a_off + 8); + a_off += 16; + + vst1q_f16(b_off, v0_7); + vst1q_f16(b_off + 8, v8_15); + b_off += 16 * K; + } + } +} + +void packing_transB16(unsigned int K, unsigned int N, const __fp16 *src, + unsigned int ldb, const __fp16 *dst) { + /// @note ldb = K for here + assert(K != 0 && N != 0 && N % 16 == 0); + unsigned int K8 = (K >> 3) << 3; + + const __fp16 *src_off = (__fp16 *)src; + __fp16 *dst_off = (__fp16 *)dst; + + const unsigned int ld_tile_T = 16; + __fp16 *tile_T = new __fp16[8 * ld_tile_T]; + // __fp16 *tile_T = alignedMalloc(8 * ld_tile_T); + + // 1. Do something like 8x16 transpose kernel + // 2. Save linearized transposed output tile to dst + for (unsigned int n = 0; n < N; n += 16) { + const __fp16 *src_off1 = src_off; + __fp16 *dst_off1 = dst_off; + src_off += 16 * ldb; + dst_off += (K8 * 16 + (K - K8)); // ? 
+ for (unsigned int k = 0; k < K8; k += 8) { + // 16x8 tile -> 8x16 + transpose_neon<__fp16>(16, 8, src_off1, ldb, tile_T, ld_tile_T); + + // Store with correct packing order linearly + vst1q_f16(&dst_off1[0], vld1q_f16(&tile_T[0 * ld_tile_T + 0])); + vst1q_f16(&dst_off1[8], vld1q_f16(&tile_T[0 * ld_tile_T + 8])); + vst1q_f16(&dst_off1[16], vld1q_f16(&tile_T[1 * ld_tile_T + 0])); + vst1q_f16(&dst_off1[24], vld1q_f16(&tile_T[1 * ld_tile_T + 8])); + vst1q_f16(&dst_off1[32], vld1q_f16(&tile_T[2 * ld_tile_T + 0])); + vst1q_f16(&dst_off1[40], vld1q_f16(&tile_T[2 * ld_tile_T + 8])); + vst1q_f16(&dst_off1[48], vld1q_f16(&tile_T[3 * ld_tile_T + 0])); + vst1q_f16(&dst_off1[56], vld1q_f16(&tile_T[3 * ld_tile_T + 8])); + vst1q_f16(&dst_off1[64], vld1q_f16(&tile_T[4 * ld_tile_T + 0])); + vst1q_f16(&dst_off1[72], vld1q_f16(&tile_T[4 * ld_tile_T + 8])); + vst1q_f16(&dst_off1[80], vld1q_f16(&tile_T[5 * ld_tile_T + 0])); + vst1q_f16(&dst_off1[88], vld1q_f16(&tile_T[5 * ld_tile_T + 8])); + vst1q_f16(&dst_off1[96], vld1q_f16(&tile_T[6 * ld_tile_T + 0])); + vst1q_f16(&dst_off1[104], vld1q_f16(&tile_T[6 * ld_tile_T + 8])); + vst1q_f16(&dst_off1[112], vld1q_f16(&tile_T[7 * ld_tile_T + 0])); + vst1q_f16(&dst_off1[120], vld1q_f16(&tile_T[7 * ld_tile_T + 8])); + + dst_off1 += 16 * 8; + src_off1 += 8; + } + + // Do the equivalent of one by one for the rest + for (unsigned int k = K8; k < K; ++k) { + for (unsigned int _n = 0; _n < 16; ++_n) { + dst_off1[_n] = src_off1[k]; + } + } + } +} diff --git a/nntrainer/tensor/hgemm/hgemm_pack.h b/nntrainer/tensor/hgemm/hgemm_pack.h new file mode 100644 index 00000000..7a671a51 --- /dev/null +++ b/nntrainer/tensor/hgemm/hgemm_pack.h @@ -0,0 +1,101 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Sungsik Kong + * + * @file hgemm_kernel_pack.h + * @date 01 April 2024 + * @see https://github.com/nnstreamer/nntrainer + * @author Sungsik Kong + * @author Debadri Samaddar + * @bug No known bugs except for NYI items + * @brief This is for half-precision packing for kernel-based GEMM + */ + +/** + * @brief packing function of input matrix A + * + * @param M length of the row of the matrix + * @param K length of the col of the matrix + * @param src input of original source of the matrix + * @param lda leading dimension of the matrix + * @param dst output of packed data of the matrix + */ +void packing_A1(unsigned int m, unsigned int k, const __fp16 *from, + unsigned int lda, const __fp16 *to); +/** + * @brief packing function of input matrix A + * + * @param M length of the row of the matrix + * @param K length of the col of the matrix + * @param src input of original source of the matrix + * @param lda leading dimension of the matrix + * @param dst output of packed data of the matrix + */ +void packing_A4(unsigned int M, unsigned int K, const __fp16 *src, + unsigned int lda, const __fp16 *dst); +/** + * @brief packing function of input matrix A + * + * @param M length of the row of the matrix + * @param K length of the col of the matrix + * @param src input of original source of the matrix + * @param lda leading dimension of the matrix + * @param dst output of packed data of the matrix + */ +void packing_A8(unsigned int M, unsigned int K, const __fp16 *src, + unsigned int lda, const __fp16 *dst); +/** + * @brief packing function of input matrix B + * + * @param M length of the row of the matrix + * @param K length of the col of the matrix + * @param src input of original source of the matrix + * @param ldb leading dimension of the matrix + * @param dst output of 
packed data of the matrix + */ +void packing_B1(unsigned int K, unsigned int N, const __fp16 *src, + unsigned int ldb, const __fp16 *dst); +/** + * @brief packing function of input matrix B + * + * @param M length of the row of the matrix + * @param K length of the col of the matrix + * @param src input of original source of the matrix + * @param ldb leading dimension of the matrix + * @param dst output of packed data of the matrix + */ +void packing_B4(unsigned int K, unsigned int N, const __fp16 *src, + unsigned int ldb, const __fp16 *dst); +/** + * @brief packing function of input matrix B + * + * @param M length of the row of the matrix + * @param K length of the col of the matrix + * @param src input of original source of the matrix + * @param ldb leading dimension of the matrix + * @param dst output of packed data of the matrix + */ +void packing_B8(unsigned int K, unsigned int N, const __fp16 *src, + unsigned int ldb, const __fp16 *dst); +/** + * @brief packing function of input matrix B + * + * @param M length of the row of the matrix + * @param K length of the col of the matrix + * @param src input of original source of the matrix + * @param ldb leading dimension of the matrix + * @param dst output of packed data of the matrix + */ +void packing_B16(unsigned int K, unsigned int N, const __fp16 *src, + unsigned int ldb, const __fp16 *dst); +/** + * @brief + * + * @param K + * @param N + * @param src + * @param ldb + * @param dst + */ +void packing_transB16(unsigned int K, unsigned int N, const __fp16 *src, + unsigned int ldb, const __fp16 *dst); diff --git a/nntrainer/tensor/hgemm/hgemm_padding_b.cpp b/nntrainer/tensor/hgemm/hgemm_padding_b.cpp index 3eb4f188..f0ef1052 100644 --- a/nntrainer/tensor/hgemm/hgemm_padding_b.cpp +++ b/nntrainer/tensor/hgemm/hgemm_padding_b.cpp @@ -82,7 +82,6 @@ void hgemm_padding_B_noTrans_wrt_KN(const __fp16 *B, __fp16 *Bp, unsigned int K, std::cerr << "Error : hgemm_padding_B_noTrans_wrt_KN NYI!\n"; } - void hgemm_padding_B_Trans_wrt_N(const __fp16 *B, __fp16 *Bp, unsigned int K, unsigned int N, unsigned int K8, unsigned int N16) { diff --git a/nntrainer/tensor/hgemm/hgemm_transB.cpp b/nntrainer/tensor/hgemm/hgemm_transB.cpp index adc7907f..f224e5f7 100644 --- a/nntrainer/tensor/hgemm/hgemm_transB.cpp +++ b/nntrainer/tensor/hgemm/hgemm_transB.cpp @@ -12,20 +12,15 @@ */ #include -#include #include -// #include -#include +#include #include +#include #include #include #include #include -// #define HGEMM_KERNEL_8x16 hgemm_kernel_8x16 /// @todo change to macro kernel -// #if !defined(HGEMM_KERNEL_8x16) hgemm_kernel_8x16 -// #endif - void hgemm_transB_8x16(unsigned int M, unsigned int N, unsigned int K, const __fp16 *A, unsigned int lda, const __fp16 *B, unsigned int ldb, float *C, unsigned int ldc, @@ -87,8 +82,7 @@ void hgemm_transB_8x16(unsigned int M, unsigned int N, unsigned int K, n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1); } packing_transB16(k_min, n_min, B + ks + ldb * ns, ldb, sB); - hgemm_kernel_8x16(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, - ldc); + hgemm_kernel_8x16(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc); } } } diff --git a/nntrainer/tensor/hgemm/meson.build b/nntrainer/tensor/hgemm/meson.build index ec5af07b..90ef0a37 100644 --- a/nntrainer/tensor/hgemm/meson.build +++ b/nntrainer/tensor/hgemm/meson.build @@ -1,18 +1,20 @@ hgemm_headers = [ 'hgemm.h', 'hgemm_util.h', - 'hgemm_kernel_pack.h', - 'hgemm_kernel_4x4.h', - 'hgemm_kernel_4x8.h', - 'hgemm_kernel_8x8.h', - 'hgemm_kernel_8x16.h', + 
'hgemm_pack.h', + 'hgemm_common.h', + 'hgemm_padding.h', ] +subdir('hgemm_kernel') +nntrainer_inc += include_directories('hgemm_kernel') +nntrainer_inc_abs += meson.current_source_dir() / 'hgemm_kernel' + hgemm_sources = [ 'hgemm.cpp', 'hgemm_padding_a.cpp', 'hgemm_padding_b.cpp', - 'hgemm_kernel_pack.cpp', + 'hgemm_pack.cpp', 'hgemm_noTrans.cpp', 'hgemm_transA.cpp', 'hgemm_transB.cpp',
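
The KERNEL_8x8_ACC{16,8,4,1} and SAVE_KERNEL_8X8_F16_f32 macros relocated above implement a bounded partial-sum scheme: each ACC macro applies a fixed number of fp16 rank-1 updates to an 8x8 register tile (v24..v31), and the save macro widens that tile into the fp32 output before the next batch of updates begins, presumably to keep fp16 accumulation error in check. The scalar sketch below shows the same cycle on a single output element; it is a hypothetical illustration, not code from this patch, and dot_with_partial_sums / acc are made-up names.

    // Scalar model of the INIT / KERNEL_8x8_ACCn / SAVE_KERNEL_8X8_F16_f32 cycle.
    // Requires an ARM toolchain providing __fp16, like the surrounding sources.
    // Note: __fp16 arithmetic may promote to float; the *stored* accumulator is
    // still half precision, which is what the periodic fp32 flush compensates for.
    void dot_with_partial_sums(const __fp16 *a, const __fp16 *b, float *c,
                               unsigned int K, unsigned int acc) {
      unsigned int l = 0;
      while (l < K) {
        __fp16 partial = 0; // INIT_KERNEL_8x8(): clear the fp16 accumulators
        unsigned int stop = (l + acc < K) ? l + acc : K;
        for (; l < stop; ++l)
          partial = partial + a[l] * b[l]; // KERNEL_8x8_ACCn: fp16 multiply-accumulate
        *c += static_cast<float>(partial); // SAVE_*_F16_f32: widen and add into fp32 C
      }
    }

In the fp32 variant of hgemm_kernel_8x8, that batch size is picked from the K16/K8/K4 remainders, and the fp32 destination is the C32 buffer that hgemm() converts back to __fp16 once the whole product is finished.

The packing routines gathered into hgemm_pack.cpp exist so the kernels can walk both operands with plain pointer bumps (a += 8 * n, b += 8 * n) instead of strided loads. As a layout reference only, the unvectorized sketch below produces the same ordering as packing_B8() above: B is cut into 8-column strips, and each strip is stored contiguously, row by row. pack_B8_reference is a hypothetical name, and its non-const dst differs from the signature in this patch.

    // Unvectorized reference for the packing_B8() layout: strip s of 8 columns,
    // row i, lane c of the original KxN panel lands at dst[s * 8 * K + i * 8 + c].
    void pack_B8_reference(unsigned int K, unsigned int N, const __fp16 *src,
                           unsigned int ldb, __fp16 *dst) {
      for (unsigned int j = 0; j < N; j += 8)   // one 8-wide strip of B at a time
        for (unsigned int i = 0; i < K; ++i)    // every row of that strip
          for (unsigned int c = 0; c < 8; ++c)  // 8 consecutive columns
            dst[(j / 8) * 8 * K + i * 8 + c] = src[i * ldb + j + c];
    }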