- Kernel functions are used regardless of matrix transpose, so they are now provided from a single, separate file (see the usage sketch below).
- To allow further optimization of the blocking-kernel sequences for the matrix A / B / AB transpose cases, split them into separate files for convenience.
- The function 'hgemm' itself should reside in the hgemm directory.
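A minimal usage sketch (illustrative only; `run_tile` is a hypothetical wrapper, not part of this patch): after the split, a blocking routine includes the single consolidated header and the kernel overload is selected by the accumulator type of C.

```cpp
#include <hgemm_kernel.h>

// Hypothetical caller sketch: run one packed 8x16 block with fp32 accumulation.
static void run_tile(unsigned int M, unsigned int N, unsigned int K,
                     __fp16 *packed_A, __fp16 *packed_B, float *C32,
                     unsigned int ldc) {
  // Passing a __fp16 *C here would select the half-precision overload
  // declared in the same header instead.
  hgemm_kernel_8x16(M, N, K, packed_A, packed_B, C32, ldc);
}
```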
**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped
Signed-off-by: skykongkong8 <ss.kong@samsung.com>
* @param[in] alpha float number
* @param[in] beta float number
*/
-void custom_hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, uint32_t N,
- uint32_t K, float alpha, float beta, bool TransA, bool TransB);
-
+void custom_hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M,
+ uint32_t N, uint32_t K, float alpha, float beta, bool TransA,
+ bool TransB);
/**
 * @brief square root transformation with neon : X = sqrt(X)
*
*/
+#include <arm_neon.h>
+#include <cmath>
#include <hgemm.h>
+#include <hgemm_common.h>
#include <hgemm_noTrans.h>
#include <hgemm_padding.h>
#include <hgemm_transA.h>
#include <hgemm_transAB.h>
#include <hgemm_transB.h>
#include <hgemm_util.h>
-#include <hgemm_common.h>
-#include <cmath>
-
-void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M, unsigned int N,
- unsigned int K, float alpha, float beta, bool TransA, bool TransB) {
+void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
+ unsigned int N, unsigned int K, float alpha, float beta, bool TransA,
+ bool TransB) {
if (K == 1) {
return hgemm_K1(A, B, C, M, N, K, alpha, beta, TransA, TransB);
}
}
hgemm_ensure_divisibility(A, B, C32, M, N, K, alpha, beta, TransA, TransB);
-
- unsigned int L = M*N;
- unsigned int L8 = (L >> 3) <<3;
+
+ unsigned int L = M * N;
+ unsigned int L8 = (L >> 3) << 3;
for (unsigned int idx = 0; idx < L8; idx += 8) {
float32x4_t x1 = vld1q_f32(&C32[idx]);
}
void hgemm_K1(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
- unsigned int N, unsigned int K, float alpha, float beta, bool TransA,
- bool TransB) {
+ unsigned int N, unsigned int K, float alpha, float beta,
+ bool TransA, bool TransB) {
unsigned int lda = (TransA) ? M : K;
unsigned int ldb = (TransB) ? K : N;
-
+
return hgemm_K1_noTrans(M, N, K, A, lda, B, ldb, C, N, alpha, beta);
if (!TransA && TransB) {
* @param[in] alpha float number
* @param[in] beta float number
*/
-void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M, unsigned int N,
- unsigned int K, float alpha, float beta, bool TransA, bool TransB);
+void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
+ unsigned int N, unsigned int K, float alpha, float beta, bool TransA,
+ bool TransB);
/**
* @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
* @param[in] alpha float number
* @param[in] beta float number
*/
-void hgemm_classify(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M,
- unsigned int N, unsigned int K, float alpha = 1.F, float beta = 0.F,
- bool TransA = false, bool TransB = false);
+void hgemm_classify(const __fp16 *A, const __fp16 *B, float *C32,
+ unsigned int M, unsigned int N, unsigned int K,
+ float alpha = 1.F, float beta = 0.F, bool TransA = false,
+ bool TransB = false);
/**
* @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
* where op(X) is one of X or X**T
* @param[in] beta float number
*/
void hgemm_K1(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
- unsigned int N, unsigned int K, float alpha, float beta, bool TransA,
- bool TransB);
+ unsigned int N, unsigned int K, float alpha, float beta,
+ bool TransA, bool TransB);
 * @brief Common settings for hgemm
*
*/
-#include <arm_neon.h>
-#include <assert.h>
-
-#define A(i, j) a[(i) * lda + (j)]
-#define B(i, j) b[(i) * ldb + (j)]
-#define C(i, j) c[(i) * ldc + (j)]
+#define A(i, j) a[(i)*lda + (j)]
+#define B(i, j) b[(i)*ldb + (j)]
+#define C(i, j) c[(i)*ldc + (j)]
#define N_BLOCKING (768)
#define K_BLOCKING (256)
#define GEMM_UNROLLING_1 (1)
#define VL_FP16 (8)
#define VL_FP16_HALF (4)
-
-
-
-/**
- * @todo Add macro for instructions in other CPU architectures
- */
+++ /dev/null
-// #include <hgemm_kernel_1x4.h>
-// #include <hgemm_kernel_1x8.h>
-// #include <hgemm_kernel_4x4.h>
-// #include <hgemm_kernel_4x8.h>
-// #include <hgemm_kernel_8x16.h>
-// #include <hgemm_kernel_8x8.h>
-
-// #define HGEMM_KERNEL_1x4 hgemm_kernel_1x4
-// #define HGEMM_KERNEL_4x4 hgemm_kernel_4x4
-// #define HGEMM_KERNEL_1x8 hgemm_kernel_1x8
-// #define HGEMM_KERNEL_4x8 hgemm_kernel_4x8
-// #define HGEMM_KERNEL_8x8 hgemm_kernel_8x8
-// #define HGEMM_KERNEL_8x16 hgemm_kernel_8x16
\ No newline at end of file
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file hgemm_kernel.h
+ * @date 10 July 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This is a collection of all the KERNEL functions for hgemm
+ *
+ */
+
+/**
+ * @brief hgemm_kernel_8x16 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc);
+/**
+ * @brief hgemm_kernel_8x16 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc);
+/**
+ * @brief hgemm_kernel_8x8 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_8x8(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc);
+/**
+ * @brief hgemm_kernel_8x8 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_8x8(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc);
+/**
+ * @brief hgemm_kernel_4x8 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_4x8(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc);
+/**
+ * @brief hgemm_kernel_4x8 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_4x8(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc);
+/**
+ * @brief hgemm_kernel_4x4 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc);
+/**
+ * @brief hgemm_kernel_4x4 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc);
+/**
+ * @brief hgemm_kernel_1x8 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_1x8(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc);
+/**
+ * @brief hgemm_kernel_1x8 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_1x8(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc);
+/**
+ * @brief hgemm_kernel_1x4 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc);
+/**
+ * @brief hgemm_kernel_1x4 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc);
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Debadri Samaddar <s.debadri@samsung.com>
+ *
+ * @file hgemm_kernel_1x4.cpp
+ * @date 23 April 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Debadri Samaddar <s.debadri@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This is half-precision GEMM 1x4 kernel
+ *
+ */
+
+#include <arm_neon.h>
+#include <assert.h>
+#include <hgemm_kernel.h>
+#include <stdlib.h>
+
+/**
+ * @brief hgemm 1x4 kernel sc = sa * sb
+ *
+ * @param M number of rows of matrix A
+ * @param N number of columns of matrix B
+ * @param K number of columns of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading dimension of matrix C
+ */
+void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
+ assert(M > 0 && N > 0 && K > 0);
+ assert(N % 4 == 0);
+
+ __fp16 *a = sa, *b = sb, *c = sc;
+ unsigned int i, j, l;
+ for (i = 0; i < M; i++) {
+ for (j = 0; j < N; j += 4) {
+ __builtin_prefetch(b, 0, 3);
+ __builtin_prefetch(a, 0, 3);
+
+ for (l = 0; l < K; l += 4) {
+ float16x4_t v24 = {0.F};
+ float16x4_t v0 = vld1_f16(b);
+ float16_t v16 = *a;
+
+ v24 = vfma_n_f16(v24, v0, v16);
+
+ float16x4_t v1 = vld1_f16(b + 4);
+ float16_t v17 = *(a + 1);
+
+ v24 = vfma_n_f16(v24, v1, v17);
+
+ float16x4_t v2 = vld1_f16(b + 8);
+ float16_t v18 = *(a + 2);
+
+ v24 = vfma_n_f16(v24, v2, v18);
+
+ float16x4_t v3 = vld1_f16(b + 12);
+ float16_t v19 = *(a + 3);
+
+ v24 = vfma_n_f16(v24, v3, v19);
+
+ __builtin_prefetch(b + 16, 0, 3);
+ __builtin_prefetch(a + 4, 0, 3);
+
+ b += 16;
+ a += 4;
+
+ v24 = vadd_f16(vld1_f16(c), v24);
+
+ vst1_f16(c, v24);
+ }
+ c += 4;
+ a -= K;
+ }
+ sc += ldc;
+ c = sc;
+ a += K;
+ b = sb;
+ }
+}
+
+/**
+ * @brief hgemm 1x4 kernel sc = sa * sb
+ *
+ * @param M number of rows of matrix A
+ * @param N number of columns of matrix B
+ * @param K number of columns of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading dimension of matrix C
+ */
+void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
+ assert(M > 0 && N > 0 && K > 0);
+ assert(N % 4 == 0);
+
+ __fp16 *a = sa, *b = sb;
+ float *c = sc;
+ unsigned int i, j, l;
+ for (i = 0; i < M; i++) {
+ for (j = 0; j < N; j += 4) {
+ __builtin_prefetch(b, 0, 3);
+ __builtin_prefetch(a, 0, 3);
+
+ for (l = 0; l < K; l += 4) {
+ float16x4_t v24 = {0.F};
+ float16x4_t v0 = vld1_f16(b);
+ float16_t v16 = *a;
+
+ v24 = vfma_n_f16(v24, v0, v16);
+
+ float16x4_t v1 = vld1_f16(b + 4);
+ float16_t v17 = *(a + 1);
+
+ v24 = vfma_n_f16(v24, v1, v17);
+
+ float16x4_t v2 = vld1_f16(b + 8);
+ float16_t v18 = *(a + 2);
+
+ v24 = vfma_n_f16(v24, v2, v18);
+
+ float16x4_t v3 = vld1_f16(b + 12);
+ float16_t v19 = *(a + 3);
+
+ v24 = vfma_n_f16(v24, v3, v19);
+
+ __builtin_prefetch(b + 16, 0, 3);
+ __builtin_prefetch(a + 4, 0, 3);
+
+ b += 16;
+ a += 4;
+
+ vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(v24)));
+ }
+ c += 4;
+ a -= K;
+ }
+ sc += ldc;
+ c = sc;
+ a += K;
+ b = sb;
+ }
+}
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Debadri Samaddar <s.debadri@samsung.com>
+ *
+ * @file hgemm_kernel_1x8.cpp
+ * @date 05 April 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Debadri Samaddar <s.debadri@samsung.com>
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This is half-precision GEMM 1x8 kernel
+ *
+ */
+
+#include <arm_neon.h>
+#include <assert.h>
+#include <hgemm_kernel.h>
+#include <stdlib.h>
+
+// 1. Partial sum 64 digits : worst accuracy, best latency
+#define KERNEL_1x8_ACC8() \
+ do { \
+ v0 = vdupq_n_f16(0.F); \
+ dv0 = *a; \
+ v24 = vld1q_f16(b); \
+ v0 = vfmaq_n_f16(v0, v24, dv0); \
+ dv1 = *(a + 1); \
+ v25 = vld1q_f16(b + 8); \
+ v0 = vfmaq_n_f16(v0, v25, dv1); \
+ dv2 = *(a + 2); \
+ v26 = vld1q_f16(b + 16); \
+ v0 = vfmaq_n_f16(v0, v26, dv2); \
+ dv3 = *(a + 3); \
+ v27 = vld1q_f16(b + 24); \
+ v0 = vfmaq_n_f16(v0, v27, dv3); \
+ dv4 = *(a + 4); \
+ v28 = vld1q_f16(b + 32); \
+ v0 = vfmaq_n_f16(v0, v28, dv4); \
+ dv5 = *(a + 5); \
+ v29 = vld1q_f16(b + 40); \
+ v0 = vfmaq_n_f16(v0, v29, dv5); \
+ dv6 = *(a + 6); \
+ v30 = vld1q_f16(b + 48); \
+ v0 = vfmaq_n_f16(v0, v30, dv6); \
+ dv7 = *(a + 7); \
+ v31 = vld1q_f16(b + 56); \
+ v0 = vfmaq_n_f16(v0, v31, dv7); \
+ l += 8; \
+ b += 8 * 8; \
+ a += 8; \
+ } while (0)
+
+// 2. Partial sum 32 digits : medium accuracy, medium latency
+#define KERNEL_1x8_ACC4() \
+ do { \
+ v0 = vdupq_n_f16(0.F); \
+ dv0 = *a; \
+ v24 = vld1q_f16(b); \
+ v0 = vfmaq_n_f16(v0, v24, dv0); \
+ dv1 = *(a + 1); \
+ v25 = vld1q_f16(b + 8); \
+ v0 = vfmaq_n_f16(v0, v25, dv1); \
+ dv2 = *(a + 2); \
+ v26 = vld1q_f16(b + 16); \
+ v0 = vfmaq_n_f16(v0, v26, dv2); \
+ dv3 = *(a + 3); \
+ v27 = vld1q_f16(b + 24); \
+ v0 = vfmaq_n_f16(v0, v27, dv3); \
+ l += 4; \
+ b += 8 * 4; \
+ a += 4; \
+ } while (0)
+
+// 3. Partial sum 8 digits : Best accuracy, worst latency
+#define KERNEL_1x8_ACC1() \
+ do { \
+ v0 = vdupq_n_f16(0.F); \
+ dv0 = *(a); \
+ v24 = vld1q_f16(b); \
+ v0 = vfmaq_n_f16(v0, v24, dv0); \
+ l += 1; \
+ b += 8 * 1; \
+ a++; \
+ } while (0)
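+
+// Note (illustrative sketch, not part of the kernel code): the ACCx macros
+// above differ only in how many fp16 products are accumulated into v0
+// before the partial sum is flushed into C. Per output element (c_elem is
+// a stand-in for one element of C), this is roughly:
+//
+//   __fp16 acc = 0;                        // fp16 partial sum
+//   for (unsigned int t = 0; t < ACC; ++t)
+//     acc += a[t] * b[t];                  // ACC = 8: fewest flushes, most
+//   c_elem += acc;                         // fp16 rounding; ACC = 1: the
+//                                          // opposite trade-off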
+
+/**
+ * @brief hgemm 1x8 kernel sc = sa * sb
+ *
+ * @param M number of rows of matrix A
+ * @param N number of columns of matrix B
+ * @param K number of columns of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading-dimension of matrix C
+ */
+void hgemm_kernel_1x8(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
+ assert(M > 0 && N > 0 && K > 0);
+ assert(N % 8 == 0);
+
+ __fp16 *a = sa, *b = sb, *c = sc;
+ unsigned int k8 = (K >> 3) << 3;
+ unsigned int i, j, l;
+ for (i = 0; i < M; i++) {
+ for (j = 0; j < N; j += 8) {
+ __builtin_prefetch(b, 0, 3);
+ __builtin_prefetch(a, 0, 3);
+ float16x8_t v0;
+ float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
+ float16_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
+ l = 0;
+ for (; l < k8;) {
+ KERNEL_1x8_ACC8();
+
+ vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0));
+ }
+ for (; l < K;) {
+ KERNEL_1x8_ACC1();
+
+ vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0));
+ }
+ c += 8;
+ a -= K;
+ }
+ sc += ldc;
+ c = sc;
+ a += K;
+ b = sb;
+ }
+}
+
+/**
+ * @brief hgemm 1x8 kernel sc = sa * sb
+ *
+ * @param M number of rows of matrix A
+ * @param N number of columns of matrix B
+ * @param K number of columns of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading-dimension of matrix C
+ */
+void hgemm_kernel_1x8(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
+ assert(M > 0 && N > 0 && K > 0);
+ assert(N % 8 == 0);
+
+ __fp16 *a = sa, *b = sb;
+ float *c = sc;
+ unsigned int k8 = (K >> 3) << 3;
+ unsigned int i, j, l;
+ for (i = 0; i < M; i++) {
+ for (j = 0; j < N; j += 8) {
+ __builtin_prefetch(b, 0, 3);
+ __builtin_prefetch(a, 0, 3);
+ float16x8_t v0;
+ float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
+ float16_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
+ l = 0;
+ for (; l < k8;) {
+ KERNEL_1x8_ACC8();
+
+ vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0))));
+
+ vst1q_f32(c + 4,
+ vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0))));
+ }
+ for (; l < K;) {
+ KERNEL_1x8_ACC1();
+
+ vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0))));
+
+ vst1q_f32(c + 4,
+ vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0))));
+ }
+ c += 8;
+ a -= K;
+ }
+ sc += ldc;
+ c = sc;
+ a += K;
+ b = sb;
+ }
+}
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file hgemm_kernel_4x4.cpp
+ * @date 01 April 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This is half-precision GEMM 4x4 kernel
+ *
+ */
+
+#include <arm_neon.h>
+#include <assert.h>
+#include <hgemm_kernel.h>
+#include <stdlib.h>
+
+#define INIT_KERNEL_4x4() \
+ do { \
+ v24 = vdup_n_f16(0.F); \
+ v25 = vdup_n_f16(0.F); \
+ v26 = vdup_n_f16(0.F); \
+ v27 = vdup_n_f16(0.F); \
+ } while (0)
+
+// 1. Partial sum 256 digits
+#define KERNEL_4x4_ACC16() \
+ do { \
+ dv0 = vld1_f16(a); \
+ vb0 = vld1_f16(b); \
+ v24 = vfma_lane_f16(v24, vb0, dv0, 0); \
+ v25 = vfma_lane_f16(v25, vb0, dv0, 1); \
+ v26 = vfma_lane_f16(v26, vb0, dv0, 2); \
+ v27 = vfma_lane_f16(v27, vb0, dv0, 3); \
+ dv1 = vld1_f16(a + 4); \
+ vb1 = vld1_f16(b + 4); \
+ v24 = vfma_lane_f16(v24, vb1, dv1, 0); \
+ v25 = vfma_lane_f16(v25, vb1, dv1, 1); \
+ v26 = vfma_lane_f16(v26, vb1, dv1, 2); \
+ v27 = vfma_lane_f16(v27, vb1, dv1, 3); \
+ dv2 = vld1_f16(a + 4 * 2); \
+ vb2 = vld1_f16(b + 4 * 2); \
+ v24 = vfma_lane_f16(v24, vb2, dv2, 0); \
+ v25 = vfma_lane_f16(v25, vb2, dv2, 1); \
+ v26 = vfma_lane_f16(v26, vb2, dv2, 2); \
+ v27 = vfma_lane_f16(v27, vb2, dv2, 3); \
+ dv3 = vld1_f16(a + 4 * 3); \
+ vb3 = vld1_f16(b + 4 * 3); \
+ v24 = vfma_lane_f16(v24, vb3, dv3, 0); \
+ v25 = vfma_lane_f16(v25, vb3, dv3, 1); \
+ v26 = vfma_lane_f16(v26, vb3, dv3, 2); \
+ v27 = vfma_lane_f16(v27, vb3, dv3, 3); \
+ dv4 = vld1_f16(a + 4 * 4); \
+ vb4 = vld1_f16(b + 4 * 4); \
+ v24 = vfma_lane_f16(v24, vb4, dv4, 0); \
+ v25 = vfma_lane_f16(v25, vb4, dv4, 1); \
+ v26 = vfma_lane_f16(v26, vb4, dv4, 2); \
+ v27 = vfma_lane_f16(v27, vb4, dv4, 3); \
+ dv5 = vld1_f16(a + 4 * 5); \
+ vb5 = vld1_f16(b + 4 * 5); \
+ v24 = vfma_lane_f16(v24, vb5, dv5, 0); \
+ v25 = vfma_lane_f16(v25, vb5, dv5, 1); \
+ v26 = vfma_lane_f16(v26, vb5, dv5, 2); \
+ v27 = vfma_lane_f16(v27, vb5, dv5, 3); \
+ dv6 = vld1_f16(a + 4 * 6); \
+ vb6 = vld1_f16(b + 4 * 6); \
+ v24 = vfma_lane_f16(v24, vb6, dv6, 0); \
+ v25 = vfma_lane_f16(v25, vb6, dv6, 1); \
+ v26 = vfma_lane_f16(v26, vb6, dv6, 2); \
+ v27 = vfma_lane_f16(v27, vb6, dv6, 3); \
+ dv7 = vld1_f16(a + 4 * 7); \
+ vb7 = vld1_f16(b + 4 * 7); \
+ v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
+ v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
+ v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
+ v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
+ dv7 = vld1_f16(a + 4 * 8); \
+ vb7 = vld1_f16(b + 4 * 8); \
+ v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
+ v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
+ v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
+ v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
+ dv7 = vld1_f16(a + 4 * 9); \
+ vb7 = vld1_f16(b + 4 * 9); \
+ v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
+ v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
+ v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
+ v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
+ dv7 = vld1_f16(a + 4 * 10); \
+ vb7 = vld1_f16(b + 4 * 10); \
+ v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
+ v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
+ v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
+ v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
+ dv7 = vld1_f16(a + 4 * 11); \
+ vb7 = vld1_f16(b + 4 * 11); \
+ v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
+ v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
+ v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
+ v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
+ dv7 = vld1_f16(a + 4 * 12); \
+ vb7 = vld1_f16(b + 4 * 12); \
+ v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
+ v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
+ v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
+ v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
+ dv7 = vld1_f16(a + 4 * 13); \
+ vb7 = vld1_f16(b + 4 * 13); \
+ v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
+ v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
+ v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
+ v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
+ dv7 = vld1_f16(a + 4 * 14); \
+ vb7 = vld1_f16(b + 4 * 14); \
+ v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
+ v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
+ v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
+ v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
+ dv7 = vld1_f16(a + 4 * 15); \
+ vb7 = vld1_f16(b + 4 * 15); \
+ v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
+ v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
+ v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
+ v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
+ l += 16; \
+ __builtin_prefetch(b + 64, 0, 3); \
+ __builtin_prefetch(a + 64, 0, 3); \
+ b += 4 * 16; \
+ a += 4 * 16; \
+ } while (0)
+
+// 2. Partial sum 128 digits
+#define KERNEL_4x4_ACC8() \
+ do { \
+ dv0 = vld1_f16(a); \
+ vb0 = vld1_f16(b); \
+ v24 = vfma_lane_f16(v24, vb0, dv0, 0); \
+ v25 = vfma_lane_f16(v25, vb0, dv0, 1); \
+ v26 = vfma_lane_f16(v26, vb0, dv0, 2); \
+ v27 = vfma_lane_f16(v27, vb0, dv0, 3); \
+ dv1 = vld1_f16(a + 4); \
+ vb1 = vld1_f16(b + 4); \
+ v24 = vfma_lane_f16(v24, vb1, dv1, 0); \
+ v25 = vfma_lane_f16(v25, vb1, dv1, 1); \
+ v26 = vfma_lane_f16(v26, vb1, dv1, 2); \
+ v27 = vfma_lane_f16(v27, vb1, dv1, 3); \
+ dv2 = vld1_f16(a + 8); \
+ vb2 = vld1_f16(b + 8); \
+ v24 = vfma_lane_f16(v24, vb2, dv2, 0); \
+ v25 = vfma_lane_f16(v25, vb2, dv2, 1); \
+ v26 = vfma_lane_f16(v26, vb2, dv2, 2); \
+ v27 = vfma_lane_f16(v27, vb2, dv2, 3); \
+ dv3 = vld1_f16(a + 12); \
+ vb3 = vld1_f16(b + 12); \
+ v24 = vfma_lane_f16(v24, vb3, dv3, 0); \
+ v25 = vfma_lane_f16(v25, vb3, dv3, 1); \
+ v26 = vfma_lane_f16(v26, vb3, dv3, 2); \
+ v27 = vfma_lane_f16(v27, vb3, dv3, 3); \
+ dv4 = vld1_f16(a + 16); \
+ vb4 = vld1_f16(b + 16); \
+ v24 = vfma_lane_f16(v24, vb4, dv4, 0); \
+ v25 = vfma_lane_f16(v25, vb4, dv4, 1); \
+ v26 = vfma_lane_f16(v26, vb4, dv4, 2); \
+ v27 = vfma_lane_f16(v27, vb4, dv4, 3); \
+ dv5 = vld1_f16(a + 20); \
+ vb5 = vld1_f16(b + 20); \
+ v24 = vfma_lane_f16(v24, vb5, dv5, 0); \
+ v25 = vfma_lane_f16(v25, vb5, dv5, 1); \
+ v26 = vfma_lane_f16(v26, vb5, dv5, 2); \
+ v27 = vfma_lane_f16(v27, vb5, dv5, 3); \
+ dv6 = vld1_f16(a + 24); \
+ vb6 = vld1_f16(b + 24); \
+ v24 = vfma_lane_f16(v24, vb6, dv6, 0); \
+ v25 = vfma_lane_f16(v25, vb6, dv6, 1); \
+ v26 = vfma_lane_f16(v26, vb6, dv6, 2); \
+ v27 = vfma_lane_f16(v27, vb6, dv6, 3); \
+ dv7 = vld1_f16(a + 28); \
+ vb7 = vld1_f16(b + 28); \
+ v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
+ v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
+ v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
+ v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
+ l += 8; \
+ __builtin_prefetch(b + 32, 0, 3); \
+ __builtin_prefetch(a + 32, 0, 3); \
+ b += 4 * 8; \
+ a += 4 * 8; \
+ } while (0)
+
+// 3. Partial sum 16 digits
+#define KERNEL_4x4_ACC1() \
+ do { \
+ dv0 = vld1_f16(a); \
+ vb0 = vld1_f16(b); \
+ v24 = vfma_lane_f16(v24, vb0, dv0, 0); \
+ v25 = vfma_lane_f16(v25, vb0, dv0, 1); \
+ v26 = vfma_lane_f16(v26, vb0, dv0, 2); \
+ v27 = vfma_lane_f16(v27, vb0, dv0, 3); \
+ l += 1; \
+ __builtin_prefetch(b + 4, 0, 3); \
+ __builtin_prefetch(a + 4, 0, 3); \
+ b += 4 * 1; \
+ a += 4 * 1; \
+ } while (0)
+
+#define SAVE_KERNEL_4X4_F16_F32() \
+ do { \
+ vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(v24))); \
+ vst1q_f32(c + ldc, vaddq_f32(vld1q_f32(c + ldc), vcvt_f32_f16(v25))); \
+ vst1q_f32(c + 2 * ldc, \
+ vaddq_f32(vld1q_f32(c + 2 * ldc), vcvt_f32_f16(v26))); \
+ vst1q_f32(c + 3 * ldc, \
+ vaddq_f32(vld1q_f32(c + 3 * ldc), vcvt_f32_f16(v27))); \
+ } while (0)
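+
+// Descriptive note: SAVE_KERNEL_4X4_F16_F32 widens each row's fp16
+// accumulator (v24..v27) to fp32 and adds it into the float C tile at
+// c + r * ldc, so fp16 rounding is confined to a single ACCx block while
+// accumulation across the K loop stays in fp32.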
+
+/**
+ * @brief hgemm 4x4 kernel sc = sa * sb
+ *
+ * @param M number of rows of matrix A
+ * @param N number of columns of matrix B
+ * @param K number of columns of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading dimension of matrix C
+ */
+void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
+ assert(M > 0 && N > 0 && K > 0);
+ assert(M % 4 == 0 && N % 4 == 0 && K % 4 == 0);
+
+ __fp16 *a = sa, *b = sb, *c = sc;
+ unsigned int i, j, l;
+ for (i = 0; i < M; i += 4) {
+ for (j = 0; j < N; j += 4) {
+ __builtin_prefetch(b, 0, 3);
+ __builtin_prefetch(a, 0, 3);
+
+ float16x4_t v24;
+ float16x4_t v25;
+ float16x4_t v26;
+ float16x4_t v27;
+ INIT_KERNEL_4x4();
+
+ for (l = 0; l < K; l += 4) {
+ float16x4_t v0 = vld1_f16(b);
+ float16x4_t v16 = vld1_f16(a);
+
+ v24 = vfma_lane_f16(v24, v0, v16, 0);
+ v25 = vfma_lane_f16(v25, v0, v16, 1);
+ v26 = vfma_lane_f16(v26, v0, v16, 2);
+ v27 = vfma_lane_f16(v27, v0, v16, 3);
+
+ float16x4_t v1 = vld1_f16(b + 4);
+ float16x4_t v17 = vld1_f16(a + 4);
+
+ v24 = vfma_lane_f16(v24, v1, v17, 0);
+ v25 = vfma_lane_f16(v25, v1, v17, 1);
+ v26 = vfma_lane_f16(v26, v1, v17, 2);
+ v27 = vfma_lane_f16(v27, v1, v17, 3);
+
+ float16x4_t v2 = vld1_f16(b + 8);
+ float16x4_t v18 = vld1_f16(a + 8);
+
+ v24 = vfma_lane_f16(v24, v2, v18, 0);
+ v25 = vfma_lane_f16(v25, v2, v18, 1);
+ v26 = vfma_lane_f16(v26, v2, v18, 2);
+ v27 = vfma_lane_f16(v27, v2, v18, 3);
+
+ float16x4_t v3 = vld1_f16(b + 12);
+ float16x4_t v19 = vld1_f16(a + 12);
+
+ v24 = vfma_lane_f16(v24, v3, v19, 0);
+ v25 = vfma_lane_f16(v25, v3, v19, 1);
+ v26 = vfma_lane_f16(v26, v3, v19, 2);
+ v27 = vfma_lane_f16(v27, v3, v19, 3);
+
+ __builtin_prefetch(b + 16, 0, 3);
+ __builtin_prefetch(a + 16, 0, 3);
+
+ b += 16;
+ a += 16;
+ }
+
+ v24 = vadd_f16(vld1_f16(c), v24);
+ v25 = vadd_f16(vld1_f16(c + ldc), v25);
+ v26 = vadd_f16(vld1_f16(c + 2 * ldc), v26);
+ v27 = vadd_f16(vld1_f16(c + 3 * ldc), v27);
+
+ vst1_f16(c, v24);
+ vst1_f16(c + ldc, v25);
+ vst1_f16(c + 2 * ldc, v26);
+ vst1_f16(c + 3 * ldc, v27);
+
+ c += 4;
+ a -= 4 * K;
+ }
+ sc += ldc * 4;
+ c = sc;
+ a += 4 * K;
+ b = sb;
+ }
+}
+
+/**
+ * @brief hgemm 4x4 kernel sc = sa * sb
+ *
+ * @param M number of rows of matrix A
+ * @param N number of columns of matrix B
+ * @param K number of columns of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading dimension of matrix C
+ */
+void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
+ assert(M > 0 && N > 0 && K > 0);
+ assert(M % 4 == 0 && N % 4 == 0 && K % 4 == 0);
+
+ __fp16 *a = sa, *b = sb;
+ float *c = sc;
+ unsigned int i, j, l;
+ unsigned int K16 = (K >> 4) << 4;
+ unsigned int K8 = (K >> 3) << 3;
+ for (i = 0; i < M; i += 4) {
+ for (j = 0; j < N; j += 4) {
+ __builtin_prefetch(b, 0, 3);
+ __builtin_prefetch(a, 0, 3);
+
+ float16x4_t v24, v25, v26, v27;
+ float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
+ float16x4_t vb0, vb1, vb2, vb3, vb4, vb5, vb6, vb7;
+ l = 0;
+ for (; l < K16;) {
+ INIT_KERNEL_4x4();
+ KERNEL_4x4_ACC16();
+ SAVE_KERNEL_4X4_F16_F32();
+ }
+ for (; l < K8;) {
+ INIT_KERNEL_4x4();
+ KERNEL_4x4_ACC8();
+ SAVE_KERNEL_4X4_F16_F32();
+ }
+ for (; l < K;) {
+ INIT_KERNEL_4x4();
+ KERNEL_4x4_ACC1();
+ SAVE_KERNEL_4X4_F16_F32();
+ }
+
+ c += 4;
+ a -= 4 * K;
+ }
+ sc += ldc * 4;
+ c = sc;
+ a += 4 * K;
+ b = sb;
+ }
+}
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file hgemm_kernel_4x8.cpp
+ * @date 03 April 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This is half-precision GEMM 4x8 kernel
+ *
+ */
+
+#include <arm_neon.h>
+#include <assert.h>
+#include <hgemm_kernel.h>
+#include <stdlib.h>
+
+#define INIT_KERNEL_4X8() \
+ do { \
+ v0 = vdupq_n_f16(0.F); \
+ v3 = vdupq_n_f16(0.F); \
+ v6 = vdupq_n_f16(0.F); \
+ v9 = vdupq_n_f16(0.F); \
+ } while (0)
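+
+// Accumulator layout (descriptive note): v0, v3, v6 and v9 each hold the
+// 8 fp16 outputs of one row of the 4x8 tile (rows 0..3), selected by the
+// four lanes of the packed-A vector in the ACCx macros below.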
+
+// 1. Partial sum 512 digits
+#define KERNEL_4x8_ACC16() \
+ do { \
+ dv0 = vld1_f16(a); \
+ v24 = vld1q_f16(b); \
+ v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
+ v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \
+ v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \
+ v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \
+ dv1 = vld1_f16(a + 4); \
+ v25 = vld1q_f16(b + 8); \
+ v0 = vfmaq_lane_f16(v0, v25, dv1, 0); \
+ v3 = vfmaq_lane_f16(v3, v25, dv1, 1); \
+ v6 = vfmaq_lane_f16(v6, v25, dv1, 2); \
+ v9 = vfmaq_lane_f16(v9, v25, dv1, 3); \
+ dv2 = vld1_f16(a + 4 * 2); \
+ v26 = vld1q_f16(b + 8 * 2); \
+ v0 = vfmaq_lane_f16(v0, v26, dv2, 0); \
+ v3 = vfmaq_lane_f16(v3, v26, dv2, 1); \
+ v6 = vfmaq_lane_f16(v6, v26, dv2, 2); \
+ v9 = vfmaq_lane_f16(v9, v26, dv2, 3); \
+ dv3 = vld1_f16(a + 4 * 3); \
+ v27 = vld1q_f16(b + 8 * 3); \
+ v0 = vfmaq_lane_f16(v0, v27, dv3, 0); \
+ v3 = vfmaq_lane_f16(v3, v27, dv3, 1); \
+ v6 = vfmaq_lane_f16(v6, v27, dv3, 2); \
+ v9 = vfmaq_lane_f16(v9, v27, dv3, 3); \
+ dv4 = vld1_f16(a + 4 * 4); \
+ v28 = vld1q_f16(b + 8 * 4); \
+ v0 = vfmaq_lane_f16(v0, v28, dv4, 0); \
+ v3 = vfmaq_lane_f16(v3, v28, dv4, 1); \
+ v6 = vfmaq_lane_f16(v6, v28, dv4, 2); \
+ v9 = vfmaq_lane_f16(v9, v28, dv4, 3); \
+ dv5 = vld1_f16(a + 4 * 5); \
+ v29 = vld1q_f16(b + 8 * 5); \
+ v0 = vfmaq_lane_f16(v0, v29, dv5, 0); \
+ v3 = vfmaq_lane_f16(v3, v29, dv5, 1); \
+ v6 = vfmaq_lane_f16(v6, v29, dv5, 2); \
+ v9 = vfmaq_lane_f16(v9, v29, dv5, 3); \
+ dv6 = vld1_f16(a + 4 * 6); \
+ v30 = vld1q_f16(b + 8 * 6); \
+ v0 = vfmaq_lane_f16(v0, v30, dv6, 0); \
+ v3 = vfmaq_lane_f16(v3, v30, dv6, 1); \
+ v6 = vfmaq_lane_f16(v6, v30, dv6, 2); \
+ v9 = vfmaq_lane_f16(v9, v30, dv6, 3); \
+ dv7 = vld1_f16(a + 4 * 7); \
+ v31 = vld1q_f16(b + 8 * 7); \
+ v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
+ v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
+ v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
+ v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
+ dv7 = vld1_f16(a + 4 * 8); \
+ v31 = vld1q_f16(b + 8 * 8); \
+ v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
+ v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
+ v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
+ v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
+ dv7 = vld1_f16(a + 4 * 9); \
+ v31 = vld1q_f16(b + 8 * 9); \
+ v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
+ v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
+ v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
+ v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
+ dv7 = vld1_f16(a + 4 * 10); \
+ v31 = vld1q_f16(b + 8 * 10); \
+ v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
+ v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
+ v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
+ v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
+ dv7 = vld1_f16(a + 4 * 11); \
+ v31 = vld1q_f16(b + 8 * 11); \
+ v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
+ v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
+ v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
+ v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
+ dv7 = vld1_f16(a + 4 * 12); \
+ v31 = vld1q_f16(b + 8 * 12); \
+ v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
+ v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
+ v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
+ v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
+ dv7 = vld1_f16(a + 4 * 13); \
+ v31 = vld1q_f16(b + 8 * 13); \
+ v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
+ v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
+ v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
+ v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
+ dv7 = vld1_f16(a + 4 * 14); \
+ v31 = vld1q_f16(b + 8 * 14); \
+ v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
+ v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
+ v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
+ v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
+ dv7 = vld1_f16(a + 4 * 15); \
+ v31 = vld1q_f16(b + 8 * 15); \
+ v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
+ v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
+ v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
+ v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
+ l += 16; \
+ __builtin_prefetch(b + 128, 0, 3); \
+ __builtin_prefetch(a + 64, 0, 3); \
+ b += 8 * 16; \
+ a += 4 * 16; \
+ } while (0)
+
+// 2. Partial sum 256 digits
+#define KERNEL_4x8_ACC8() \
+ do { \
+ dv0 = vld1_f16(a); \
+ v24 = vld1q_f16(b); \
+ v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
+ v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \
+ v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \
+ v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \
+ dv1 = vld1_f16(a + 4); \
+ v25 = vld1q_f16(b + 8); \
+ v0 = vfmaq_lane_f16(v0, v25, dv1, 0); \
+ v3 = vfmaq_lane_f16(v3, v25, dv1, 1); \
+ v6 = vfmaq_lane_f16(v6, v25, dv1, 2); \
+ v9 = vfmaq_lane_f16(v9, v25, dv1, 3); \
+ dv2 = vld1_f16(a + 8); \
+ v26 = vld1q_f16(b + 16); \
+ v0 = vfmaq_lane_f16(v0, v26, dv2, 0); \
+ v3 = vfmaq_lane_f16(v3, v26, dv2, 1); \
+ v6 = vfmaq_lane_f16(v6, v26, dv2, 2); \
+ v9 = vfmaq_lane_f16(v9, v26, dv2, 3); \
+ dv3 = vld1_f16(a + 12); \
+ v27 = vld1q_f16(b + 24); \
+ v0 = vfmaq_lane_f16(v0, v27, dv3, 0); \
+ v3 = vfmaq_lane_f16(v3, v27, dv3, 1); \
+ v6 = vfmaq_lane_f16(v6, v27, dv3, 2); \
+ v9 = vfmaq_lane_f16(v9, v27, dv3, 3); \
+ dv4 = vld1_f16(a + 16); \
+ v28 = vld1q_f16(b + 32); \
+ v0 = vfmaq_lane_f16(v0, v28, dv4, 0); \
+ v3 = vfmaq_lane_f16(v3, v28, dv4, 1); \
+ v6 = vfmaq_lane_f16(v6, v28, dv4, 2); \
+ v9 = vfmaq_lane_f16(v9, v28, dv4, 3); \
+ dv5 = vld1_f16(a + 20); \
+ v29 = vld1q_f16(b + 40); \
+ v0 = vfmaq_lane_f16(v0, v29, dv5, 0); \
+ v3 = vfmaq_lane_f16(v3, v29, dv5, 1); \
+ v6 = vfmaq_lane_f16(v6, v29, dv5, 2); \
+ v9 = vfmaq_lane_f16(v9, v29, dv5, 3); \
+ dv6 = vld1_f16(a + 24); \
+ v30 = vld1q_f16(b + 48); \
+ v0 = vfmaq_lane_f16(v0, v30, dv6, 0); \
+ v3 = vfmaq_lane_f16(v3, v30, dv6, 1); \
+ v6 = vfmaq_lane_f16(v6, v30, dv6, 2); \
+ v9 = vfmaq_lane_f16(v9, v30, dv6, 3); \
+ dv7 = vld1_f16(a + 28); \
+ v31 = vld1q_f16(b + 56); \
+ v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
+ v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
+ v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
+ v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
+ l += 8; \
+ __builtin_prefetch(b + 64, 0, 3); \
+ __builtin_prefetch(a + 32, 0, 3); \
+ b += 8 * 8; \
+ a += 4 * 8; \
+ } while (0)
+
+// 3. Partial sum 128 digits
+#define KERNEL_4x8_ACC4() \
+ do { \
+ dv0 = vld1_f16(a); \
+ v24 = vld1q_f16(b); \
+ v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
+ v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \
+ v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \
+ v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \
+ dv1 = vld1_f16(a + 4); \
+ v25 = vld1q_f16(b + 8); \
+ v0 = vfmaq_lane_f16(v0, v25, dv1, 0); \
+ v3 = vfmaq_lane_f16(v3, v25, dv1, 1); \
+ v6 = vfmaq_lane_f16(v6, v25, dv1, 2); \
+ v9 = vfmaq_lane_f16(v9, v25, dv1, 3); \
+ dv2 = vld1_f16(a + 8); \
+ v26 = vld1q_f16(b + 16); \
+ v0 = vfmaq_lane_f16(v0, v26, dv2, 0); \
+ v3 = vfmaq_lane_f16(v3, v26, dv2, 1); \
+ v6 = vfmaq_lane_f16(v6, v26, dv2, 2); \
+ v9 = vfmaq_lane_f16(v9, v26, dv2, 3); \
+ dv3 = vld1_f16(a + 12); \
+ v27 = vld1q_f16(b + 24); \
+ v0 = vfmaq_lane_f16(v0, v27, dv3, 0); \
+ v3 = vfmaq_lane_f16(v3, v27, dv3, 1); \
+ v6 = vfmaq_lane_f16(v6, v27, dv3, 2); \
+ v9 = vfmaq_lane_f16(v9, v27, dv3, 3); \
+ l += 4; \
+ __builtin_prefetch(b + 32, 0, 3); \
+ __builtin_prefetch(a + 16, 0, 3); \
+ b += 8 * 4; \
+ a += 4 * 4; \
+ } while (0)
+
+// 4. Partial sum 32 digits
+#define KERNEL_4x8_ACC1() \
+ do { \
+ dv0 = vld1_f16(a); \
+ v24 = vld1q_f16(b); \
+ v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
+ v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \
+ v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \
+ v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \
+ l += 1; \
+ __builtin_prefetch(b + 8, 0, 3); \
+ __builtin_prefetch(a + 4, 0, 3); \
+ b += 8 * 1; \
+ a += 4 * 1; \
+ } while (0)
+
+#define SAVE_KERNEL_4X8_F16_F32() \
+ do { \
+ vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0)))); \
+ vst1q_f32(c + ldc, \
+ vaddq_f32(vld1q_f32(c + ldc), vcvt_f32_f16(vget_low_f16(v3)))); \
+ vst1q_f32(c + 2 * ldc, vaddq_f32(vld1q_f32(c + 2 * ldc), \
+ vcvt_f32_f16(vget_low_f16(v6)))); \
+ vst1q_f32(c + 3 * ldc, vaddq_f32(vld1q_f32(c + 3 * ldc), \
+ vcvt_f32_f16(vget_low_f16(v9)))); \
+ \
+ vst1q_f32(c + 4, \
+ vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0)))); \
+ vst1q_f32(c + 4 + ldc, vaddq_f32(vld1q_f32(c + 4 + ldc), \
+ vcvt_f32_f16(vget_high_f16(v3)))); \
+ vst1q_f32(c + 4 + 2 * ldc, vaddq_f32(vld1q_f32(c + 4 + 2 * ldc), \
+ vcvt_f32_f16(vget_high_f16(v6)))); \
+ vst1q_f32(c + 4 + 3 * ldc, vaddq_f32(vld1q_f32(c + 4 + 3 * ldc), \
+ vcvt_f32_f16(vget_high_f16(v9)))); \
+ } while (0)
+
+/**
+ * @brief hgemm 4x8 kernel sc = sa * sb
+ *
+ * @param M number of rows of matrix A
+ * @param N number of columns of matrix B
+ * @param K number of columns of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading-dimension of matrix C
+ */
+void hgemm_kernel_4x8(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
+ assert(M > 0 && N > 0 && K > 0);
+ assert(M % 4 == 0 && N % 8 == 0);
+
+ __fp16 *a = sa, *b = sb, *c = sc;
+ unsigned int K8 = (K >> 3) << 3;
+ unsigned int i, j, l;
+ for (i = 0; i < M; i += 4) {
+ for (j = 0; j < N; j += 8) {
+ __builtin_prefetch(b, 0, 3);
+ __builtin_prefetch(a, 0, 3);
+ float16x8_t v0, v3, v6, v9;
+ float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
+ float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
+ INIT_KERNEL_4X8();
+ l = 0;
+ for (; l < K8;) {
+ KERNEL_4x8_ACC8();
+ }
+ for (; l < K;) {
+ KERNEL_4x8_ACC1();
+ }
+ vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0));
+ vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v3));
+ vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v6));
+ vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v9));
+ c += 8;
+ a -= 4 * K;
+ }
+ sc += ldc * 4;
+ c = sc;
+ a += 4 * K;
+ b = sb;
+ }
+}
+
+/**
+ * @brief hgemm 4x8 kernel sc = sa * sb
+ *
+ * @param M number of rows of matrix A
+ * @param N number of columns of matrix B
+ * @param K number of columns of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading-dimension of matrix C
+ */
+void hgemm_kernel_4x8(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
+ assert(M > 0 && N > 0 && K > 0);
+ assert(M % 4 == 0 && N % 8 == 0);
+
+ __fp16 *a = sa, *b = sb;
+ float *c = sc;
+ unsigned int K16 = (K >> 4) << 4;
+ unsigned int K8 = (K >> 3) << 3;
+ unsigned int K4 = (K >> 2) << 2;
+ unsigned int i, j, l;
+ for (i = 0; i < M; i += 4) {
+ for (j = 0; j < N; j += 8) {
+ __builtin_prefetch(b, 0, 3);
+ __builtin_prefetch(a, 0, 3);
+ float16x8_t v0, v3, v6, v9;
+ float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
+ float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
+ l = 0;
+ for (; l < K16;) {
+ INIT_KERNEL_4X8();
+ KERNEL_4x8_ACC16();
+ SAVE_KERNEL_4X8_F16_F32();
+ }
+ for (; l < K8;) {
+ INIT_KERNEL_4X8();
+ KERNEL_4x8_ACC8();
+ SAVE_KERNEL_4X8_F16_F32();
+ }
+ for (; l < K4;) {
+ INIT_KERNEL_4X8();
+ KERNEL_4x8_ACC4();
+ SAVE_KERNEL_4X8_F16_F32();
+ }
+ for (; l < K;) {
+ INIT_KERNEL_4X8();
+ KERNEL_4x8_ACC1();
+ SAVE_KERNEL_4X8_F16_F32();
+ }
+ c += 8;
+ a -= 4 * K;
+ }
+ sc += ldc * 4;
+ c = sc;
+ a += 4 * K;
+ b = sb;
+ }
+}
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file hgemm_kernel_8x16.cpp
+ * @date 04 April 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This is half-precision GEMM 8x16 kernel
+ *
+ */
+
+#include <arm_neon.h>
+#include <assert.h>
+#include <hgemm_kernel.h>
+#include <iostream>
+#include <stdlib.h>
+
+#define INIT_KERNEL_8X16() \
+ do { \
+ v0_7 = vdupq_n_f16(0.F); \
+ v8_15 = vdupq_n_f16(0.F); \
+ v16_23 = vdupq_n_f16(0.F); \
+ v24_31 = vdupq_n_f16(0.F); \
+ v32_39 = vdupq_n_f16(0.F); \
+ v40_47 = vdupq_n_f16(0.F); \
+ v48_55 = vdupq_n_f16(0.F); \
+ v56_63 = vdupq_n_f16(0.F); \
+ v64_71 = vdupq_n_f16(0.F); \
+ v72_79 = vdupq_n_f16(0.F); \
+ v80_87 = vdupq_n_f16(0.F); \
+ v88_95 = vdupq_n_f16(0.F); \
+ v96_103 = vdupq_n_f16(0.F); \
+ v104_111 = vdupq_n_f16(0.F); \
+ v112_119 = vdupq_n_f16(0.F); \
+ v120_127 = vdupq_n_f16(0.F); \
+ } while (0)
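+
+// Accumulator layout (descriptive note): each accumulator holds 8 fp16
+// outputs of one row of the 8x16 tile; v0_7..v56_63 cover columns 0-7 of
+// rows 0-7 (scaled by lanes 0-7 of the packed-A vector), and
+// v64_71..v120_127 cover columns 8-15 of the same rows.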
+
+// 1. Partial sum 2048 digits
+#define KERNEL_8x16_ACC16() \
+ do { \
+ va0 = vld1q_f16(a + 8 * 0); \
+ vb1 = vld1q_f16(b + 8 * 0); \
+ vb2 = vld1q_f16(b + 8 * 1); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 1); \
+ vb1 = vld1q_f16(b + 8 * 2); \
+ vb2 = vld1q_f16(b + 8 * 3); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 2); \
+ vb1 = vld1q_f16(b + 8 * 4); \
+ vb2 = vld1q_f16(b + 8 * 5); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 3); \
+ vb1 = vld1q_f16(b + 8 * 6); \
+ vb2 = vld1q_f16(b + 8 * 7); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 4); \
+ vb1 = vld1q_f16(b + 8 * 8); \
+ vb2 = vld1q_f16(b + 8 * 9); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 5); \
+ vb1 = vld1q_f16(b + 8 * 10); \
+ vb2 = vld1q_f16(b + 8 * 11); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 6); \
+ vb1 = vld1q_f16(b + 8 * 12); \
+ vb2 = vld1q_f16(b + 8 * 13); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 7); \
+ vb1 = vld1q_f16(b + 8 * 14); \
+ vb2 = vld1q_f16(b + 8 * 15); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 8); \
+ vb1 = vld1q_f16(b + 8 * 16); \
+ vb2 = vld1q_f16(b + 8 * 17); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 9); \
+ vb1 = vld1q_f16(b + 8 * 18); \
+ vb2 = vld1q_f16(b + 8 * 19); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 10); \
+ vb1 = vld1q_f16(b + 8 * 20); \
+ vb2 = vld1q_f16(b + 8 * 21); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 11); \
+ vb1 = vld1q_f16(b + 8 * 22); \
+ vb2 = vld1q_f16(b + 8 * 23); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 12); \
+ vb1 = vld1q_f16(b + 8 * 24); \
+ vb2 = vld1q_f16(b + 8 * 25); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 13); \
+ vb1 = vld1q_f16(b + 8 * 26); \
+ vb2 = vld1q_f16(b + 8 * 27); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 14); \
+ vb1 = vld1q_f16(b + 8 * 28); \
+ vb2 = vld1q_f16(b + 8 * 29); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 15); \
+ vb1 = vld1q_f16(b + 8 * 30); \
+ vb2 = vld1q_f16(b + 8 * 31); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ __builtin_prefetch(b + 256, 0, 3); \
+ __builtin_prefetch(a + 128, 0, 3); \
+ l += 16; \
+ b += 16 * 16; \
+ a += 8 * 16; \
+ } while (0)
+
+// 2. Partial sum 1024 digits
+#define KERNEL_8x16_ACC8() \
+ do { \
+ va0 = vld1q_f16(a); \
+ vb1 = vld1q_f16(b); \
+ vb2 = vld1q_f16(b + 8); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 8); \
+ vb1 = vld1q_f16(b + 16); \
+ vb2 = vld1q_f16(b + 24); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 16); \
+ vb1 = vld1q_f16(b + 32); \
+ vb2 = vld1q_f16(b + 40); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 24); \
+ vb1 = vld1q_f16(b + 48); \
+ vb2 = vld1q_f16(b + 56); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 32); \
+ vb1 = vld1q_f16(b + 64); \
+ vb2 = vld1q_f16(b + 72); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 40); \
+ vb1 = vld1q_f16(b + 80); \
+ vb2 = vld1q_f16(b + 88); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 48); \
+ vb1 = vld1q_f16(b + 96); \
+ vb2 = vld1q_f16(b + 104); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 56); \
+ vb1 = vld1q_f16(b + 112); \
+ vb2 = vld1q_f16(b + 120); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ l += 8; \
+ __builtin_prefetch(b + 128, 0, 3); \
+ __builtin_prefetch(a + 64, 0, 3); \
+ b += 16 * 8; \
+ a += 8 * 8; \
+ } while (0)
+
+// 3. Partial sum 512 digits
+#define KERNEL_8x16_ACC4() \
+ do { \
+ va0 = vld1q_f16(a); \
+ vb1 = vld1q_f16(b); \
+ vb2 = vld1q_f16(b + 8); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 8); \
+ vb1 = vld1q_f16(b + 16); \
+ vb2 = vld1q_f16(b + 24); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 16); \
+ vb1 = vld1q_f16(b + 32); \
+ vb2 = vld1q_f16(b + 40); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ va0 = vld1q_f16(a + 24); \
+ vb1 = vld1q_f16(b + 48); \
+ vb2 = vld1q_f16(b + 56); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ l += 4; \
+ __builtin_prefetch(b + 64, 0, 3); \
+ __builtin_prefetch(a + 32, 0, 3); \
+ b += 16 * 4; \
+ a += 8 * 4; \
+ } while (0)
+
+// 4. Partial sum 128 digits
+#define KERNEL_8x16_ACC1() \
+ do { \
+ va0 = vld1q_f16(a); \
+ vb1 = vld1q_f16(b); \
+ vb2 = vld1q_f16(b + 8); \
+ v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
+ v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
+ v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
+ v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
+ v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
+ v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
+ v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
+ v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
+ v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
+ v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
+ v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
+ v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
+ v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
+ v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+ v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+ v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+ l += 1; \
+ __builtin_prefetch(b + 16, 0, 3); \
+ __builtin_prefetch(a + 8, 0, 3); \
+ b += 16 * 1; \
+ a += 8 * 1; \
+ } while (0)
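
Each accumulator update in the ACC macros above is a single lane-broadcast fused multiply-add. As a minimal standalone sketch (illustrative helper, not part of this patch), the intrinsic behaves as follows:

#include <arm_neon.h>

// vfmaq_laneq_f16(acc, vb, va, L) computes, lane by lane, acc[i] += vb[i] * va[L].
// One KERNEL_8x16_ACC1 step therefore accumulates the rank-1 update of the 8-entry
// A column (va0) with the 16-entry B row (vb1:vb2) into the 8x16 tile v0_7..v120_127.
static inline float16x8_t fma_broadcast_lane0(float16x8_t acc, float16x8_t vb,
                                              float16x8_t va) {
  return vfmaq_laneq_f16(acc, vb, va, 0); // acc[i] += vb[i] * va[0]
}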
+
+#define SAVE_KERNEL_8X16_F16_F32() \
+ do { \
+ vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0_7)))); \
+ vst1q_f32(c + 4, \
+ vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0_7)))); \
+ \
+ vst1q_f32( \
+ c + 8, vaddq_f32(vld1q_f32(c + 8), vcvt_f32_f16(vget_low_f16(v64_71)))); \
+ vst1q_f32(c + 8 + 4, vaddq_f32(vld1q_f32(c + 8 + 4), \
+ vcvt_f32_f16(vget_high_f16(v64_71)))); \
+ \
+ vst1q_f32(c + ldc, vaddq_f32(vld1q_f32(c + ldc), \
+ vcvt_f32_f16(vget_low_f16(v8_15)))); \
+ vst1q_f32(c + ldc + 4, vaddq_f32(vld1q_f32(c + ldc + 4), \
+ vcvt_f32_f16(vget_high_f16(v8_15)))); \
+ \
+ vst1q_f32(c + ldc + 8, vaddq_f32(vld1q_f32(c + ldc + 8), \
+ vcvt_f32_f16(vget_low_f16(v72_79)))); \
+ vst1q_f32(c + ldc + 8 + 4, \
+ vaddq_f32(vld1q_f32(c + ldc + 8 + 4), \
+ vcvt_f32_f16(vget_high_f16(v72_79)))); \
+ \
+ vst1q_f32(c + 2 * ldc, vaddq_f32(vld1q_f32(c + 2 * ldc), \
+ vcvt_f32_f16(vget_low_f16(v16_23)))); \
+ vst1q_f32(c + 2 * ldc + 4, \
+ vaddq_f32(vld1q_f32(c + 2 * ldc + 4), \
+ vcvt_f32_f16(vget_high_f16(v16_23)))); \
+ \
+ vst1q_f32(c + 2 * ldc + 8, vaddq_f32(vld1q_f32(c + 2 * ldc + 8), \
+ vcvt_f32_f16(vget_low_f16(v80_87)))); \
+ vst1q_f32(c + 2 * ldc + 8 + 4, \
+ vaddq_f32(vld1q_f32(c + 2 * ldc + 8 + 4), \
+ vcvt_f32_f16(vget_high_f16(v80_87)))); \
+ \
+ vst1q_f32(c + 3 * ldc, vaddq_f32(vld1q_f32(c + 3 * ldc), \
+ vcvt_f32_f16(vget_low_f16(v24_31)))); \
+ vst1q_f32(c + 3 * ldc + 4, \
+ vaddq_f32(vld1q_f32(c + 3 * ldc + 4), \
+ vcvt_f32_f16(vget_high_f16(v24_31)))); \
+ \
+ vst1q_f32(c + 3 * ldc + 8, vaddq_f32(vld1q_f32(c + 3 * ldc + 8), \
+ vcvt_f32_f16(vget_low_f16(v88_95)))); \
+ vst1q_f32(c + 3 * ldc + 8 + 4, \
+ vaddq_f32(vld1q_f32(c + 3 * ldc + 8 + 4), \
+ vcvt_f32_f16(vget_high_f16(v88_95)))); \
+ \
+ vst1q_f32(c + 4 * ldc, vaddq_f32(vld1q_f32(c + 4 * ldc), \
+ vcvt_f32_f16(vget_low_f16(v32_39)))); \
+ vst1q_f32(c + 4 * ldc + 4, \
+ vaddq_f32(vld1q_f32(c + 4 * ldc + 4), \
+ vcvt_f32_f16(vget_high_f16(v32_39)))); \
+ \
+ vst1q_f32(c + 4 * ldc + 8, \
+ vaddq_f32(vld1q_f32(c + 4 * ldc + 8), \
+ vcvt_f32_f16(vget_low_f16(v96_103)))); \
+ vst1q_f32(c + 4 * ldc + 8 + 4, \
+ vaddq_f32(vld1q_f32(c + 4 * ldc + 8 + 4), \
+ vcvt_f32_f16(vget_high_f16(v96_103)))); \
+ \
+ vst1q_f32(c + 5 * ldc, vaddq_f32(vld1q_f32(c + 5 * ldc), \
+ vcvt_f32_f16(vget_low_f16(v40_47)))); \
+ vst1q_f32(c + 5 * ldc + 4, \
+ vaddq_f32(vld1q_f32(c + 5 * ldc + 4), \
+ vcvt_f32_f16(vget_high_f16(v40_47)))); \
+ vst1q_f32(c + 5 * ldc + 8, \
+ vaddq_f32(vld1q_f32(c + 5 * ldc + 8), \
+ vcvt_f32_f16(vget_low_f16(v104_111)))); \
+ vst1q_f32(c + 5 * ldc + 8 + 4, \
+ vaddq_f32(vld1q_f32(c + 5 * ldc + 8 + 4), \
+ vcvt_f32_f16(vget_high_f16(v104_111)))); \
+ \
+ vst1q_f32(c + 6 * ldc, vaddq_f32(vld1q_f32(c + 6 * ldc), \
+ vcvt_f32_f16(vget_low_f16(v48_55)))); \
+ vst1q_f32(c + 6 * ldc + 4, \
+ vaddq_f32(vld1q_f32(c + 6 * ldc + 4), \
+ vcvt_f32_f16(vget_high_f16(v48_55)))); \
+ \
+ vst1q_f32(c + 6 * ldc + 8, \
+ vaddq_f32(vld1q_f32(c + 6 * ldc + 8), \
+ vcvt_f32_f16(vget_low_f16(v112_119)))); \
+ vst1q_f32(c + 6 * ldc + 8 + 4, \
+ vaddq_f32(vld1q_f32(c + 6 * ldc + 8 + 4), \
+ vcvt_f32_f16(vget_high_f16(v112_119)))); \
+ \
+ vst1q_f32(c + 7 * ldc, vaddq_f32(vld1q_f32(c + 7 * ldc), \
+ vcvt_f32_f16(vget_low_f16(v56_63)))); \
+ vst1q_f32(c + 7 * ldc + 4, \
+ vaddq_f32(vld1q_f32(c + 7 * ldc + 4), \
+ vcvt_f32_f16(vget_high_f16(v56_63)))); \
+ \
+ vst1q_f32(c + 7 * ldc + 8, \
+ vaddq_f32(vld1q_f32(c + 7 * ldc + 8), \
+ vcvt_f32_f16(vget_low_f16(v120_127)))); \
+ vst1q_f32(c + 7 * ldc + 8 + 4, \
+ vaddq_f32(vld1q_f32(c + 7 * ldc + 8 + 4), \
+ vcvt_f32_f16(vget_high_f16(v120_127)))); \
+ } while (0)
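
The store sequence above repeats one pattern per 8-wide row fragment: split the fp16 accumulator, widen both halves to fp32, and add into C. A compressed sketch of that pattern (hypothetical helper name, not part of this patch):

#include <arm_neon.h>

// c_row[0..7] += (float)acc[0..7]; one row fragment of SAVE_KERNEL_8X16_F16_F32.
static inline void save_f16x8_into_f32(float16x8_t acc, float *c_row) {
  vst1q_f32(c_row,
            vaddq_f32(vld1q_f32(c_row), vcvt_f32_f16(vget_low_f16(acc))));
  vst1q_f32(c_row + 4,
            vaddq_f32(vld1q_f32(c_row + 4), vcvt_f32_f16(vget_high_f16(acc))));
}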
+
+/**
+ * @brief hgemm 8x16 kernel sc = sa * sb
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading-dimension of matrix C
+ */
+void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
+ assert(M > 0 && N > 0 && K > 0);
+ assert(M % 8 == 0 && N % 16 == 0 && K % 8 == 0);
+
+ __fp16 *a = sa, *b = sb, *c = sc;
+ unsigned int i, j, l;
+ for (i = 0; i < M; i += 8) {
+ for (j = 0; j < N; j += 16) {
+ __builtin_prefetch(b, 0, 3);
+ __builtin_prefetch(a, 0, 3);
+ // 8x16
+ float16x8_t v0_7, v8_15;
+ float16x8_t v16_23, v24_31;
+ float16x8_t v32_39, v40_47;
+ float16x8_t v48_55, v56_63;
+ float16x8_t v64_71, v72_79;
+ float16x8_t v80_87, v88_95;
+ float16x8_t v96_103, v104_111;
+ float16x8_t v112_119, v120_127;
+ float16x8_t vb1, vb2;
+ float16x8_t va0;
+
+ INIT_KERNEL_8X16();
+ l = 0;
+ for (; l < K;) {
+ KERNEL_8x16_ACC1();
+ }
+ vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0_7));
+ vst1q_f16(c + 8, vaddq_f16(vld1q_f16(c + 8), v64_71));
+ vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v8_15));
+ vst1q_f16(c + ldc + 8, vaddq_f16(vld1q_f16(c + ldc + 8), v72_79));
+ vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v16_23));
+ vst1q_f16(c + 2 * ldc + 8, vaddq_f16(vld1q_f16(c + 2 * ldc + 8), v80_87));
+ vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v24_31));
+ vst1q_f16(c + 3 * ldc + 8, vaddq_f16(vld1q_f16(c + 3 * ldc + 8), v88_95));
+ vst1q_f16(c + 4 * ldc, vaddq_f16(vld1q_f16(c + 4 * ldc), v32_39));
+ vst1q_f16(c + 4 * ldc + 8,
+ vaddq_f16(vld1q_f16(c + 4 * ldc + 8), v96_103));
+ vst1q_f16(c + 5 * ldc, vaddq_f16(vld1q_f16(c + 5 * ldc), v40_47));
+ vst1q_f16(c + 5 * ldc + 8,
+ vaddq_f16(vld1q_f16(c + 5 * ldc + 8), v104_111));
+ vst1q_f16(c + 6 * ldc, vaddq_f16(vld1q_f16(c + 6 * ldc), v48_55));
+ vst1q_f16(c + 6 * ldc + 8,
+ vaddq_f16(vld1q_f16(c + 6 * ldc + 8), v112_119));
+ vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v56_63));
+ vst1q_f16(c + 7 * ldc + 8,
+ vaddq_f16(vld1q_f16(c + 7 * ldc + 8), v120_127));
+ c += 16;
+ a -= 8 * K;
+ }
+ sc += ldc * 8;
+ c = sc;
+ a += 8 * K;
+ b = sb;
+ }
+}
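
A minimal usage sketch for the kernel above (assumption: the buffers are already packed into the 8-row / 16-column panel layout produced by the hgemm packing routines, and C holds the running partial result; sizes and names here are illustrative only):

#include <vector>

void example_hgemm_kernel_8x16_call() {
  const unsigned int M = 64, N = 64, K = 64; // M % 8 == 0, N % 16 == 0, K % 8 == 0
  std::vector<__fp16> packed_A(M * K), packed_B(K * N), C(M * N, __fp16(0.f));
  // Accumulates packed_A * packed_B into C (C must already hold valid data or zeros).
  hgemm_kernel_8x16(M, N, K, packed_A.data(), packed_B.data(), C.data(),
                    /*ldc=*/N);
}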
+
+/**
+ * @brief hgemm 8x16 kernel sc = sa * sb
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading-dimension of matrix C
+ */
+void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
+ assert(M > 0 && N > 0 && K > 0);
+ assert(M % 8 == 0 && N % 16 == 0 && K % 4 == 0);
+
+ __fp16 *a = sa, *b = sb;
+ float *c = sc;
+ unsigned int i, j, l;
+ unsigned int K4 = (K >> 2) << 2;
+ unsigned int K8 = (K >> 3) << 3;
+ unsigned int K16 = (K >> 4) << 4;
+ for (i = 0; i < M; i += 8) {
+ for (j = 0; j < N; j += 16) {
+ __builtin_prefetch(b, 0, 3);
+ __builtin_prefetch(a, 0, 3);
+ float16x8_t v0_7, v8_15;
+ float16x8_t v16_23, v24_31;
+ float16x8_t v32_39, v40_47;
+ float16x8_t v48_55, v56_63;
+ float16x8_t v64_71, v72_79;
+ float16x8_t v80_87, v88_95;
+ float16x8_t v96_103, v104_111;
+ float16x8_t v112_119, v120_127;
+ float16x8_t vb1, vb2;
+ float16x8_t va0;
+ l = 0;
+ for (; l < K16;) {
+ INIT_KERNEL_8X16();
+ KERNEL_8x16_ACC16();
+ SAVE_KERNEL_8X16_F16_F32();
+ }
+ for (; l < K8;) {
+ INIT_KERNEL_8X16();
+ KERNEL_8x16_ACC8();
+ SAVE_KERNEL_8X16_F16_F32();
+ }
+ for (; l < K4;) {
+ INIT_KERNEL_8X16();
+ KERNEL_8x16_ACC4();
+ SAVE_KERNEL_8X16_F16_F32();
+ }
+ for (; l < K;) {
+ INIT_KERNEL_8X16();
+ KERNEL_8x16_ACC1();
+ SAVE_KERNEL_8X16_F16_F32();
+ }
+ c += 16;
+ a -= 8 * K;
+ }
+ sc += ldc * 8;
+ c = sc;
+ a += 8 * K;
+ b = sb;
+ }
+}
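
The K16/K8/K4/K1 loop split in the fp32 variant above limits how many fp16 products are chained before the partial sum is widened into the fp32 C tile, trading a little latency for accuracy. A scalar analogue of that scheme (illustrative only, not part of this patch):

// Chunked fp16 accumulation with fp32 carry-out, mirroring INIT / KERNEL_ACC* / SAVE.
float dot_chunked_f16(const __fp16 *a, const __fp16 *b, unsigned int K) {
  float acc32 = 0.f;
  unsigned int l = 0;
  while (l < K) {
    // Largest chunk that still fits; equivalent to the K16 -> K8 -> K4 -> K1 loops.
    const unsigned int step =
      (K - l >= 16) ? 16 : (K - l >= 8) ? 8 : (K - l >= 4) ? 4 : 1;
    __fp16 acc16 = (__fp16)0.f; // INIT_KERNEL_*: reset the fp16 partial sum
    for (unsigned int t = 0; t < step; ++t)
      acc16 = acc16 + a[l + t] * b[l + t]; // KERNEL_*_ACC<step>: fp16 accumulation chain
    acc32 += (float)acc16; // SAVE_KERNEL_*: widen and accumulate into fp32
    l += step;
  }
  return acc32;
}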
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file hgemm_kernel_8x8.cpp
+ * @date 01 April 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This is half-precision GEMM 8x8 kernel
+ *
+ */
+
+#include <arm_neon.h>
+#include <assert.h>
+#include <hgemm_kernel.h>
+#include <stdlib.h>
+
+#define INIT_KERNEL_8x8() \
+ do { \
+ v24 = vdupq_n_f16(0.F); \
+ v25 = vdupq_n_f16(0.F); \
+ v26 = vdupq_n_f16(0.F); \
+ v27 = vdupq_n_f16(0.F); \
+ v28 = vdupq_n_f16(0.F); \
+ v29 = vdupq_n_f16(0.F); \
+ v30 = vdupq_n_f16(0.F); \
+ v31 = vdupq_n_f16(0.F); \
+ } while (0)
+
+// 1. Partial sum 1024 digits
+#define KERNEL_8x8_ACC16() \
+ do { \
+ va0 = vld1q_f16(a); \
+ v16 = vld1q_f16(b); \
+ v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
+ v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \
+ v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \
+ v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \
+ v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \
+ v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \
+ v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \
+ v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \
+ va0 = vld1q_f16(a + 8); \
+ v17 = vld1q_f16(b + 8); \
+ v24 = vfmaq_laneq_f16(v24, v17, va0, 0); \
+ v25 = vfmaq_laneq_f16(v25, v17, va0, 1); \
+ v26 = vfmaq_laneq_f16(v26, v17, va0, 2); \
+ v27 = vfmaq_laneq_f16(v27, v17, va0, 3); \
+ v28 = vfmaq_laneq_f16(v28, v17, va0, 4); \
+ v29 = vfmaq_laneq_f16(v29, v17, va0, 5); \
+ v30 = vfmaq_laneq_f16(v30, v17, va0, 6); \
+ v31 = vfmaq_laneq_f16(v31, v17, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 2); \
+ v18 = vld1q_f16(b + 8 * 2); \
+ v24 = vfmaq_laneq_f16(v24, v18, va0, 0); \
+ v25 = vfmaq_laneq_f16(v25, v18, va0, 1); \
+ v26 = vfmaq_laneq_f16(v26, v18, va0, 2); \
+ v27 = vfmaq_laneq_f16(v27, v18, va0, 3); \
+ v28 = vfmaq_laneq_f16(v28, v18, va0, 4); \
+ v29 = vfmaq_laneq_f16(v29, v18, va0, 5); \
+ v30 = vfmaq_laneq_f16(v30, v18, va0, 6); \
+ v31 = vfmaq_laneq_f16(v31, v18, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 3); \
+ v19 = vld1q_f16(b + 8 * 3); \
+ v24 = vfmaq_laneq_f16(v24, v19, va0, 0); \
+ v25 = vfmaq_laneq_f16(v25, v19, va0, 1); \
+ v26 = vfmaq_laneq_f16(v26, v19, va0, 2); \
+ v27 = vfmaq_laneq_f16(v27, v19, va0, 3); \
+ v28 = vfmaq_laneq_f16(v28, v19, va0, 4); \
+ v29 = vfmaq_laneq_f16(v29, v19, va0, 5); \
+ v30 = vfmaq_laneq_f16(v30, v19, va0, 6); \
+ v31 = vfmaq_laneq_f16(v31, v19, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 4); \
+ v20 = vld1q_f16(b + 8 * 4); \
+ v24 = vfmaq_laneq_f16(v24, v20, va0, 0); \
+ v25 = vfmaq_laneq_f16(v25, v20, va0, 1); \
+ v26 = vfmaq_laneq_f16(v26, v20, va0, 2); \
+ v27 = vfmaq_laneq_f16(v27, v20, va0, 3); \
+ v28 = vfmaq_laneq_f16(v28, v20, va0, 4); \
+ v29 = vfmaq_laneq_f16(v29, v20, va0, 5); \
+ v30 = vfmaq_laneq_f16(v30, v20, va0, 6); \
+ v31 = vfmaq_laneq_f16(v31, v20, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 5); \
+ v21 = vld1q_f16(b + 8 * 5); \
+ v24 = vfmaq_laneq_f16(v24, v21, va0, 0); \
+ v25 = vfmaq_laneq_f16(v25, v21, va0, 1); \
+ v26 = vfmaq_laneq_f16(v26, v21, va0, 2); \
+ v27 = vfmaq_laneq_f16(v27, v21, va0, 3); \
+ v28 = vfmaq_laneq_f16(v28, v21, va0, 4); \
+ v29 = vfmaq_laneq_f16(v29, v21, va0, 5); \
+ v30 = vfmaq_laneq_f16(v30, v21, va0, 6); \
+ v31 = vfmaq_laneq_f16(v31, v21, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 6); \
+ v22 = vld1q_f16(b + 8 * 6); \
+ v24 = vfmaq_laneq_f16(v24, v22, va0, 0); \
+ v25 = vfmaq_laneq_f16(v25, v22, va0, 1); \
+ v26 = vfmaq_laneq_f16(v26, v22, va0, 2); \
+ v27 = vfmaq_laneq_f16(v27, v22, va0, 3); \
+ v28 = vfmaq_laneq_f16(v28, v22, va0, 4); \
+ v29 = vfmaq_laneq_f16(v29, v22, va0, 5); \
+ v30 = vfmaq_laneq_f16(v30, v22, va0, 6); \
+ v31 = vfmaq_laneq_f16(v31, v22, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 7); \
+ v23 = vld1q_f16(b + 8 * 7); \
+ v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
+ v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
+ v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
+ v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
+ v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
+ v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
+ v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
+ v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 8); \
+ v23 = vld1q_f16(b + 8 * 8); \
+ v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
+ v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
+ v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
+ v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
+ v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
+ v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
+ v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
+ v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 9); \
+ v23 = vld1q_f16(b + 8 * 9); \
+ v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
+ v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
+ v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
+ v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
+ v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
+ v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
+ v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
+ v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 10); \
+ v23 = vld1q_f16(b + 8 * 10); \
+ v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
+ v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
+ v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
+ v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
+ v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
+ v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
+ v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
+ v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 11); \
+ v23 = vld1q_f16(b + 8 * 11); \
+ v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
+ v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
+ v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
+ v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
+ v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
+ v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
+ v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
+ v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 12); \
+ v23 = vld1q_f16(b + 8 * 12); \
+ v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
+ v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
+ v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
+ v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
+ v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
+ v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
+ v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
+ v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 13); \
+ v23 = vld1q_f16(b + 8 * 13); \
+ v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
+ v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
+ v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
+ v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
+ v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
+ v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
+ v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
+ v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 14); \
+ v23 = vld1q_f16(b + 8 * 14); \
+ v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
+ v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
+ v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
+ v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
+ v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
+ v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
+ v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
+ v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
+ va0 = vld1q_f16(a + 8 * 15); \
+ v23 = vld1q_f16(b + 8 * 15); \
+ v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
+ v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
+ v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
+ v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
+ v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
+ v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
+ v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
+ v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
+ __builtin_prefetch(b + 128, 0, 3); \
+ __builtin_prefetch(a + 128, 0, 3); \
+ l += 16; \
+ b += 8 * 16; \
+ a += 8 * 16; \
+ } while (0)
+
+// 2. Partial sum 512 digits
+#define KERNEL_8x8_ACC8() \
+ do { \
+ va0 = vld1q_f16(a); \
+ v16 = vld1q_f16(b); \
+ v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
+ v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \
+ v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \
+ v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \
+ v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \
+ v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \
+ v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \
+ v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \
+ va1 = vld1q_f16(a + 8); \
+ v17 = vld1q_f16(b + 8); \
+ v24 = vfmaq_laneq_f16(v24, v17, va1, 0); \
+ v25 = vfmaq_laneq_f16(v25, v17, va1, 1); \
+ v26 = vfmaq_laneq_f16(v26, v17, va1, 2); \
+ v27 = vfmaq_laneq_f16(v27, v17, va1, 3); \
+ v28 = vfmaq_laneq_f16(v28, v17, va1, 4); \
+ v29 = vfmaq_laneq_f16(v29, v17, va1, 5); \
+ v30 = vfmaq_laneq_f16(v30, v17, va1, 6); \
+ v31 = vfmaq_laneq_f16(v31, v17, va1, 7); \
+ va2 = vld1q_f16(a + 16); \
+ v18 = vld1q_f16(b + 16); \
+ v24 = vfmaq_laneq_f16(v24, v18, va2, 0); \
+ v25 = vfmaq_laneq_f16(v25, v18, va2, 1); \
+ v26 = vfmaq_laneq_f16(v26, v18, va2, 2); \
+ v27 = vfmaq_laneq_f16(v27, v18, va2, 3); \
+ v28 = vfmaq_laneq_f16(v28, v18, va2, 4); \
+ v29 = vfmaq_laneq_f16(v29, v18, va2, 5); \
+ v30 = vfmaq_laneq_f16(v30, v18, va2, 6); \
+ v31 = vfmaq_laneq_f16(v31, v18, va2, 7); \
+ va3 = vld1q_f16(a + 24); \
+ v19 = vld1q_f16(b + 24); \
+ v24 = vfmaq_laneq_f16(v24, v19, va3, 0); \
+ v25 = vfmaq_laneq_f16(v25, v19, va3, 1); \
+ v26 = vfmaq_laneq_f16(v26, v19, va3, 2); \
+ v27 = vfmaq_laneq_f16(v27, v19, va3, 3); \
+ v28 = vfmaq_laneq_f16(v28, v19, va3, 4); \
+ v29 = vfmaq_laneq_f16(v29, v19, va3, 5); \
+ v30 = vfmaq_laneq_f16(v30, v19, va3, 6); \
+ v31 = vfmaq_laneq_f16(v31, v19, va3, 7); \
+ va4 = vld1q_f16(a + 32); \
+ v20 = vld1q_f16(b + 32); \
+ v24 = vfmaq_laneq_f16(v24, v20, va4, 0); \
+ v25 = vfmaq_laneq_f16(v25, v20, va4, 1); \
+ v26 = vfmaq_laneq_f16(v26, v20, va4, 2); \
+ v27 = vfmaq_laneq_f16(v27, v20, va4, 3); \
+ v28 = vfmaq_laneq_f16(v28, v20, va4, 4); \
+ v29 = vfmaq_laneq_f16(v29, v20, va4, 5); \
+ v30 = vfmaq_laneq_f16(v30, v20, va4, 6); \
+ v31 = vfmaq_laneq_f16(v31, v20, va4, 7); \
+ va5 = vld1q_f16(a + 40); \
+ v21 = vld1q_f16(b + 40); \
+ v24 = vfmaq_laneq_f16(v24, v21, va5, 0); \
+ v25 = vfmaq_laneq_f16(v25, v21, va5, 1); \
+ v26 = vfmaq_laneq_f16(v26, v21, va5, 2); \
+ v27 = vfmaq_laneq_f16(v27, v21, va5, 3); \
+ v28 = vfmaq_laneq_f16(v28, v21, va5, 4); \
+ v29 = vfmaq_laneq_f16(v29, v21, va5, 5); \
+ v30 = vfmaq_laneq_f16(v30, v21, va5, 6); \
+ v31 = vfmaq_laneq_f16(v31, v21, va5, 7); \
+ va6 = vld1q_f16(a + 48); \
+ v22 = vld1q_f16(b + 48); \
+ v24 = vfmaq_laneq_f16(v24, v22, va6, 0); \
+ v25 = vfmaq_laneq_f16(v25, v22, va6, 1); \
+ v26 = vfmaq_laneq_f16(v26, v22, va6, 2); \
+ v27 = vfmaq_laneq_f16(v27, v22, va6, 3); \
+ v28 = vfmaq_laneq_f16(v28, v22, va6, 4); \
+ v29 = vfmaq_laneq_f16(v29, v22, va6, 5); \
+ v30 = vfmaq_laneq_f16(v30, v22, va6, 6); \
+ v31 = vfmaq_laneq_f16(v31, v22, va6, 7); \
+ va7 = vld1q_f16(a + 56); \
+ v23 = vld1q_f16(b + 56); \
+ v24 = vfmaq_laneq_f16(v24, v23, va7, 0); \
+ v25 = vfmaq_laneq_f16(v25, v23, va7, 1); \
+ v26 = vfmaq_laneq_f16(v26, v23, va7, 2); \
+ v27 = vfmaq_laneq_f16(v27, v23, va7, 3); \
+ v28 = vfmaq_laneq_f16(v28, v23, va7, 4); \
+ v29 = vfmaq_laneq_f16(v29, v23, va7, 5); \
+ v30 = vfmaq_laneq_f16(v30, v23, va7, 6); \
+ v31 = vfmaq_laneq_f16(v31, v23, va7, 7); \
+ __builtin_prefetch(b + 64, 0, 3); \
+ __builtin_prefetch(a + 64, 0, 3); \
+ l += 8; \
+ b += 8 * 8; \
+ a += 8 * 8; \
+ } while (0)
+
+// 3. Partial sum 256 digits
+#define KERNEL_8x8_ACC4() \
+ do { \
+ va0 = vld1q_f16(a); \
+ v16 = vld1q_f16(b); \
+ v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
+ v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \
+ v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \
+ v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \
+ v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \
+ v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \
+ v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \
+ v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \
+ va1 = vld1q_f16(a + 8); \
+ v17 = vld1q_f16(b + 8); \
+ v24 = vfmaq_laneq_f16(v24, v17, va1, 0); \
+ v25 = vfmaq_laneq_f16(v25, v17, va1, 1); \
+ v26 = vfmaq_laneq_f16(v26, v17, va1, 2); \
+ v27 = vfmaq_laneq_f16(v27, v17, va1, 3); \
+ v28 = vfmaq_laneq_f16(v28, v17, va1, 4); \
+ v29 = vfmaq_laneq_f16(v29, v17, va1, 5); \
+ v30 = vfmaq_laneq_f16(v30, v17, va1, 6); \
+ v31 = vfmaq_laneq_f16(v31, v17, va1, 7); \
+ va2 = vld1q_f16(a + 16); \
+ v18 = vld1q_f16(b + 16); \
+ v24 = vfmaq_laneq_f16(v24, v18, va2, 0); \
+ v25 = vfmaq_laneq_f16(v25, v18, va2, 1); \
+ v26 = vfmaq_laneq_f16(v26, v18, va2, 2); \
+ v27 = vfmaq_laneq_f16(v27, v18, va2, 3); \
+ v28 = vfmaq_laneq_f16(v28, v18, va2, 4); \
+ v29 = vfmaq_laneq_f16(v29, v18, va2, 5); \
+ v30 = vfmaq_laneq_f16(v30, v18, va2, 6); \
+ v31 = vfmaq_laneq_f16(v31, v18, va2, 7); \
+ va3 = vld1q_f16(a + 24); \
+ v19 = vld1q_f16(b + 24); \
+ v24 = vfmaq_laneq_f16(v24, v19, va3, 0); \
+ v25 = vfmaq_laneq_f16(v25, v19, va3, 1); \
+ v26 = vfmaq_laneq_f16(v26, v19, va3, 2); \
+ v27 = vfmaq_laneq_f16(v27, v19, va3, 3); \
+ v28 = vfmaq_laneq_f16(v28, v19, va3, 4); \
+ v29 = vfmaq_laneq_f16(v29, v19, va3, 5); \
+ v30 = vfmaq_laneq_f16(v30, v19, va3, 6); \
+ v31 = vfmaq_laneq_f16(v31, v19, va3, 7); \
+ __builtin_prefetch(b + 32, 0, 3); \
+ __builtin_prefetch(a + 32, 0, 3); \
+ l += 4; \
+ b += 8 * 4; \
+ a += 8 * 4; \
+ } while (0)
+
+// 4. Partial sum 64 digits
+#define KERNEL_8x8_ACC1() \
+ do { \
+ va0 = vld1q_f16(a); \
+ v16 = vld1q_f16(b); \
+ v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
+ v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \
+ v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \
+ v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \
+ v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \
+ v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \
+ v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \
+ v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \
+ __builtin_prefetch(b + 8, 0, 3); \
+ __builtin_prefetch(a + 8, 0, 3); \
+ l += 1; \
+ b += 8 * 1; \
+ a += 8 * 1; \
+ } while (0)
+
+#define SAVE_KERNEL_8X8_F16_F32()                                             \
+ do { \
+ vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v24)))); \
+ vst1q_f32(c + 4, \
+ vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v24)))); \
+ \
+ vst1q_f32(c + ldc, \
+ vaddq_f32(vld1q_f32(c + ldc), vcvt_f32_f16(vget_low_f16(v25)))); \
+ vst1q_f32(c + 4 + ldc, vaddq_f32(vld1q_f32(c + 4 + ldc), \
+ vcvt_f32_f16(vget_high_f16(v25)))); \
+ \
+ vst1q_f32(c + 2 * ldc, vaddq_f32(vld1q_f32(c + 2 * ldc), \
+ vcvt_f32_f16(vget_low_f16(v26)))); \
+ vst1q_f32(c + 4 + 2 * ldc, vaddq_f32(vld1q_f32(c + 4 + 2 * ldc), \
+ vcvt_f32_f16(vget_high_f16(v26)))); \
+ \
+ vst1q_f32(c + 3 * ldc, vaddq_f32(vld1q_f32(c + 3 * ldc), \
+ vcvt_f32_f16(vget_low_f16(v27)))); \
+ vst1q_f32(c + 4 + 3 * ldc, vaddq_f32(vld1q_f32(c + 4 + 3 * ldc), \
+ vcvt_f32_f16(vget_high_f16(v27)))); \
+ \
+ vst1q_f32(c + 4 * ldc, vaddq_f32(vld1q_f32(c + 4 * ldc), \
+ vcvt_f32_f16(vget_low_f16(v28)))); \
+ vst1q_f32(c + 4 + 4 * ldc, vaddq_f32(vld1q_f32(c + 4 + 4 * ldc), \
+ vcvt_f32_f16(vget_high_f16(v28)))); \
+ \
+ vst1q_f32(c + 5 * ldc, vaddq_f32(vld1q_f32(c + 5 * ldc), \
+ vcvt_f32_f16(vget_low_f16(v29)))); \
+ vst1q_f32(c + 4 + 5 * ldc, vaddq_f32(vld1q_f32(c + 4 + 5 * ldc), \
+ vcvt_f32_f16(vget_high_f16(v29)))); \
+ \
+ vst1q_f32(c + 6 * ldc, vaddq_f32(vld1q_f32(c + 6 * ldc), \
+ vcvt_f32_f16(vget_low_f16(v30)))); \
+ vst1q_f32(c + 4 + 6 * ldc, vaddq_f32(vld1q_f32(c + 4 + 6 * ldc), \
+ vcvt_f32_f16(vget_high_f16(v30)))); \
+ \
+ vst1q_f32(c + 7 * ldc, vaddq_f32(vld1q_f32(c + 7 * ldc), \
+ vcvt_f32_f16(vget_low_f16(v31)))); \
+ vst1q_f32(c + 4 + 7 * ldc, vaddq_f32(vld1q_f32(c + 4 + 7 * ldc), \
+ vcvt_f32_f16(vget_high_f16(v31)))); \
+ } while (0)
+
+/**
+ * @brief hgemm 8x8 kernel sc = sa * sb
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading-dimension of matrix C
+ */
+void hgemm_kernel_8x8(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
+ assert(M > 0 && N > 0 && K > 0);
+ assert(M % 8 == 0 && N % 8 == 0 && K % 4 == 0);
+
+ __fp16 *a = sa, *b = sb, *c = sc;
+ unsigned int i, j, l;
+ for (i = 0; i < M; i += 8) {
+ for (j = 0; j < N; j += 8) {
+ __builtin_prefetch(b, 0, 3);
+ __builtin_prefetch(a, 0, 3);
+
+ float16x8_t v16, v17, v18, v19, v20, v21, v22, v23;
+ float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
+ float16x8_t va0, va1, va2, va3, va4, va5, va6, va7;
+ INIT_KERNEL_8x8();
+ l = 0;
+ for (; l < K;) {
+ KERNEL_8x8_ACC1();
+ }
+ vst1q_f16(c, vaddq_f16(vld1q_f16(c), v24));
+ vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v25));
+ vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v26));
+ vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v27));
+ vst1q_f16(c + 4 * ldc, vaddq_f16(vld1q_f16(c + 4 * ldc), v28));
+ vst1q_f16(c + 5 * ldc, vaddq_f16(vld1q_f16(c + 5 * ldc), v29));
+ vst1q_f16(c + 6 * ldc, vaddq_f16(vld1q_f16(c + 6 * ldc), v30));
+ vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v31));
+ c += 8;
+ a -= 8 * K;
+ }
+ sc += ldc * 8;
+ c = sc;
+ a += 8 * K;
+ b = sb;
+ }
+}
+
+/**
+ * @brief hgemm 8x8 kernel sc = sa * sb
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading-dimension of matrix C
+ */
+void hgemm_kernel_8x8(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
+ assert(M > 0 && N > 0 && K > 0);
+ assert(M % 8 == 0 && N % 8 == 0 && K % 8 == 0);
+
+ __fp16 *a = sa, *b = sb;
+ float *c = sc;
+ unsigned int i, j, l;
+ unsigned int K4 = (K >> 2) << 2;
+ unsigned int K8 = (K >> 3) << 3;
+ unsigned int K16 = (K >> 4) << 4;
+ for (i = 0; i < M; i += 8) {
+ for (j = 0; j < N; j += 8) {
+ __builtin_prefetch(b, 0, 3);
+ __builtin_prefetch(a, 0, 3);
+
+ float16x8_t v16, v17, v18, v19, v20, v21, v22, v23;
+ float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
+ float16x8_t va0, va1, va2, va3, va4, va5, va6, va7;
+ l = 0;
+ for (; l < K16;) {
+ INIT_KERNEL_8x8();
+ KERNEL_8x8_ACC16();
+        SAVE_KERNEL_8X8_F16_F32();
+ }
+ for (; l < K8;) {
+ INIT_KERNEL_8x8();
+ KERNEL_8x8_ACC8();
+        SAVE_KERNEL_8X8_F16_F32();
+ }
+ for (; l < K4;) {
+ INIT_KERNEL_8x8();
+ KERNEL_8x8_ACC4();
+        SAVE_KERNEL_8X8_F16_F32();
+ }
+ for (; l < K;) {
+ INIT_KERNEL_8x8();
+ KERNEL_8x8_ACC1();
+        SAVE_KERNEL_8X8_F16_F32();
+ }
+
+ c += 8;
+ a -= 8 * K;
+ }
+ sc += ldc * 8;
+ c = sc;
+ a += 8 * K;
+ b = sb;
+ }
+}
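
For reference, what one call of the fp32-accumulating 8x8 kernel computes can be written as a plain scalar loop (assumption: the same packed panel layout the NEON kernel expects; this reference accumulates every product directly in fp32, so it is also the accuracy baseline the ACC-chunked kernels approximate; not part of this patch):

// Scalar reference: sc[(i+ii)*ldc + j+jj] += A(i+ii, l) * B(l, j+jj), with
// A packed in 8-row panels (sa[i*K + l*8 + ii]) and B packed in 8-col panels
// (sb[j*K + l*8 + jj]).
void hgemm_kernel_8x8_ref(unsigned int M, unsigned int N, unsigned int K,
                          const __fp16 *sa, const __fp16 *sb, float *sc,
                          unsigned int ldc) {
  for (unsigned int i = 0; i < M; i += 8)
    for (unsigned int j = 0; j < N; j += 8)
      for (unsigned int l = 0; l < K; ++l)
        for (unsigned int ii = 0; ii < 8; ++ii)
          for (unsigned int jj = 0; jj < 8; ++jj)
            sc[(i + ii) * ldc + j + jj] +=
              static_cast<float>(sa[i * K + l * 8 + ii]) *
              static_cast<float>(sb[j * K + l * 8 + jj]);
}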
--- /dev/null
+hgemm_kernel_headers = [
+ 'hgemm_kernel.h',
+]
+
+hgemm_kernel_sources = [
+ 'hgemm_kernel_1x4.cpp',
+ 'hgemm_kernel_1x8.cpp',
+ 'hgemm_kernel_4x4.cpp',
+ 'hgemm_kernel_4x8.cpp',
+ 'hgemm_kernel_8x8.cpp',
+ 'hgemm_kernel_8x16.cpp',
+]
+
+foreach s : hgemm_kernel_sources
+ nntrainer_sources += meson.current_source_dir() / s
+endforeach
+
+foreach h : hgemm_kernel_headers
+ nntrainer_headers += meson.current_source_dir() / h
+endforeach
+
+++ /dev/null
-// SPDX-License-Identifier: Apache-2.0
-/**
- * Copyright (C) 2024 Debadri Samaddar <s.debadri@samsung.com>
- *
- * @file hgemm_kernel_1x4.h
- * @date 23 April 2024
- * @see https://github.com/nnstreamer/nntrainer
- * @author Debadri Samaddar <s.debadri@samsung.com>
- * @bug No known bugs except for NYI items
- * @brief This is half-precision GEMM 1x4 kernel
- *
- */
-
-#include <stdlib.h>
-#include <arm_neon.h>
-#include <assert.h>
-
-/**
- * @brief hgemm 1x4 kernel sc = sa * sb
- *
- * @param m length of the row of matrix A
- * @param n length of the col of matrix B
- * @param k length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading dimension of matrix C
- */
-void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K,
- __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
- assert(M > 0 && N > 0 && K > 0);
- assert(N % 4 == 0);
-
- __fp16 *a = sa, *b = sb, *c = sc;
- unsigned int i, j, l;
- for (i = 0; i < M; i++) {
- for (j = 0; j < N; j += 4) {
- __builtin_prefetch(b, 0, 3);
- __builtin_prefetch(a, 0, 3);
-
- for (l = 0; l < K; l += 4) {
- float16x4_t v24 = {0.F};
- float16x4_t v0 = vld1_f16(b);
- float16_t v16 = *a;
-
- v24 = vfma_n_f16(v24, v0, v16);
-
- float16x4_t v1 = vld1_f16(b + 4);
- float16_t v17 = *(a + 1);
-
- v24 = vfma_n_f16(v24, v1, v17);
-
- float16x4_t v2 = vld1_f16(b + 8);
- float16_t v18 = *(a + 2);
-
- v24 = vfma_n_f16(v24, v2, v18);
-
- float16x4_t v3 = vld1_f16(b + 12);
- float16_t v19 = *(a + 3);
-
- v24 = vfma_n_f16(v24, v3, v19);
-
- __builtin_prefetch(b + 16, 0, 3);
- __builtin_prefetch(a + 4, 0, 3);
-
- b += 16;
- a += 4;
-
- v24 = vadd_f16(vld1_f16(c), v24);
-
- vst1_f16(c, v24);
- }
- c += 4;
- a -= K;
- }
- sc += ldc;
- c = sc;
- a += K;
- b = sb;
- }
-}
-
-/**
- * @brief hgemm 1x4 kernel sc = sa * sb
- *
- * @param m length of the row of matrix A
- * @param n length of the col of matrix B
- * @param k length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading dimension of matrix C
- */
-void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K,
- __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
- assert(M > 0 && N > 0 && K > 0);
- assert(N % 4 == 0);
-
- __fp16 *a = sa, *b = sb;
- float *c = sc;
- unsigned int i, j, l;
- for (i = 0; i < M; i++) {
- for (j = 0; j < N; j += 4) {
- __builtin_prefetch(b, 0, 3);
- __builtin_prefetch(a, 0, 3);
-
- for (l = 0; l < K; l += 4) {
- float16x4_t v24 = {0.F};
- float16x4_t v0 = vld1_f16(b);
- float16_t v16 = *a;
-
- v24 = vfma_n_f16(v24, v0, v16);
-
- float16x4_t v1 = vld1_f16(b + 4);
- float16_t v17 = *(a + 1);
-
- v24 = vfma_n_f16(v24, v1, v17);
-
- float16x4_t v2 = vld1_f16(b + 8);
- float16_t v18 = *(a + 2);
-
- v24 = vfma_n_f16(v24, v2, v18);
-
- float16x4_t v3 = vld1_f16(b + 12);
- float16_t v19 = *(a + 3);
-
- v24 = vfma_n_f16(v24, v3, v19);
-
- __builtin_prefetch(b + 16, 0, 3);
- __builtin_prefetch(a + 4, 0, 3);
-
- b += 16;
- a += 4;
-
- vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(v24)));
- }
- c += 4;
- a -= K;
- }
- sc += ldc;
- c = sc;
- a += K;
- b = sb;
- }
-}
+++ /dev/null
-// SPDX-License-Identifier: Apache-2.0
-/**
- * Copyright (C) 2024 Debadri Samaddar <s.debadri@samsung.com>
- *
- * @file hgemm_kernel_1x8.h
- * @date 05 April 2024
- * @see https://github.com/nnstreamer/nntrainer
- * @author Debadri Samaddar <s.debadri@samsung.com>
- * @author Sungsik Kong <ss.kong@samsung.com>
- * @bug No known bugs except for NYI items
- * @brief This is half-precision GEMM 1x8 kernel
- *
- */
-
-#include <arm_neon.h>
-#include <assert.h>
-#include <stdlib.h>
-
-// 1. Partial sum 64 digits : worst accuracy, best latency
-#define KERNEL_1x8_ACC8() \
- do { \
- v0 = vdupq_n_f16(0.F); \
- dv0 = *a; \
- v24 = vld1q_f16(b); \
- v0 = vfmaq_n_f16(v0, v24, dv0); \
- dv1 = *(a + 1); \
- v25 = vld1q_f16(b + 8); \
- v0 = vfmaq_n_f16(v0, v25, dv1); \
- dv2 = *(a + 2); \
- v26 = vld1q_f16(b + 16); \
- v0 = vfmaq_n_f16(v0, v26, dv2); \
- dv3 = *(a + 3); \
- v27 = vld1q_f16(b + 24); \
- v0 = vfmaq_n_f16(v0, v27, dv3); \
- dv4 = *(a + 4); \
- v28 = vld1q_f16(b + 32); \
- v0 = vfmaq_n_f16(v0, v28, dv4); \
- dv5 = *(a + 5); \
- v29 = vld1q_f16(b + 40); \
- v0 = vfmaq_n_f16(v0, v29, dv5); \
- dv6 = *(a + 6); \
- v30 = vld1q_f16(b + 48); \
- v0 = vfmaq_n_f16(v0, v30, dv6); \
- dv7 = *(a + 7); \
- v31 = vld1q_f16(b + 56); \
- v0 = vfmaq_n_f16(v0, v31, dv7); \
- l += 8; \
- b += 8 * 8; \
- a += 8; \
- } while (0)
-
-// 2. Partial sum 32 digits : medium accuracy, medium latency
-#define KERNEL_1x8_ACC4() \
- do { \
- v0 = vdupq_n_f16(0.F); \
- dv0 = *a; \
- v24 = vld1q_f16(b); \
- v0 = vfmaq_n_f16(v0, v24, dv0); \
- dv1 = *(a + 1); \
- v25 = vld1q_f16(b + 8); \
- v0 = vfmaq_n_f16(v0, v25, dv1); \
- dv2 = *(a + 2); \
- v26 = vld1q_f16(b + 16); \
- v0 = vfmaq_n_f16(v0, v26, dv2); \
- dv3 = *(a + 3); \
- v27 = vld1q_f16(b + 24); \
- v0 = vfmaq_n_f16(v0, v27, dv3); \
- l += 4; \
- b += 8 * 4; \
- a += 4; \
- } while (0)
-
-// 3. Partial sum 8 digits : Best accuracy, worst latency
-#define KERNEL_1x8_ACC1() \
- do { \
- v0 = vdupq_n_f16(0.F); \
- dv0 = *(a); \
- v24 = vld1q_f16(b); \
- v0 = vfmaq_n_f16(v0, v24, dv0); \
- l += 1; \
- b += 8 * 1; \
- a++; \
- } while (0)
-
-/**
- * @brief hgemm 1x8 kernel sc = sa * sb
- *
- * @param m length of the row of matrix A
- * @param n length of the col of matrix B
- * @param k length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading-dimension of matrix C
- */
-void hgemm_kernel_1x8(unsigned int M, unsigned int N, unsigned int K,
- __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
- assert(M > 0 && N > 0 && K > 0);
- assert(N % 8 == 0);
-
- __fp16 *a = sa, *b = sb, *c = sc;
- unsigned int k8 = (K >> 3) << 3;
- unsigned int i, j, l;
- for (i = 0; i < M; i++) {
- for (j = 0; j < N; j += 8) {
- __builtin_prefetch(b, 0, 3);
- __builtin_prefetch(a, 0, 3);
- float16x8_t v0;
- float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
- float16_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
- l = 0;
- for (; l < k8;) {
- KERNEL_1x8_ACC8();
-
- vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0));
- }
- for (; l < K;) {
- KERNEL_1x8_ACC1();
-
- vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0));
- }
- c += 8;
- a -= K;
- }
- sc += ldc;
- c = sc;
- a += K;
- b = sb;
- }
-}
-
-/**
- * @brief hgemm 1x8 kernel sc = sa * sb
- *
- * @param m length of the row of matrix A
- * @param n length of the col of matrix B
- * @param k length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading-dimension of matrix C
- */
-void hgemm_kernel_1x8(unsigned int M, unsigned int N, unsigned int K,
- __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
- assert(M > 0 && N > 0 && K > 0);
- assert(N % 8 == 0);
-
- __fp16 *a = sa, *b = sb;
- float *c = sc;
- unsigned int k8 = (K >> 3) << 3;
- unsigned int i, j, l;
- for (i = 0; i < M; i++) {
- for (j = 0; j < N; j += 8) {
- __builtin_prefetch(b, 0, 3);
- __builtin_prefetch(a, 0, 3);
- float16x8_t v0;
- float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
- float16_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
- l = 0;
- for (; l < k8;) {
- KERNEL_1x8_ACC8();
-
- vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0))));
-
- vst1q_f32(c + 4,
- vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0))));
- }
- for (; l < K;) {
- KERNEL_1x8_ACC1();
-
- vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0))));
-
- vst1q_f32(c + 4,
- vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0))));
- }
- c += 8;
- a -= K;
- }
- sc += ldc;
- c = sc;
- a += K;
- b = sb;
- }
-}
+++ /dev/null
-// SPDX-License-Identifier: Apache-2.0
-/**
- * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
- *
- * @file hgemm_kernel_4x4.h
- * @date 01 April 2024
- * @see https://github.com/nnstreamer/nntrainer
- * @author Sungsik Kong <ss.kong@samsung.com>
- * @bug No known bugs except for NYI items
- * @brief This is half-precision GEMM 4x4 kernel
- *
- */
-
-#include <arm_neon.h>
-#include <assert.h>
-#include <stdlib.h>
-
-#define INIT_KERNEL_4x4() \
- do { \
- v24 = vdup_n_f16(0.F); \
- v25 = vdup_n_f16(0.F); \
- v26 = vdup_n_f16(0.F); \
- v27 = vdup_n_f16(0.F); \
- } while (0)
-
-// 1. Partial sum 256 digits
-#define KERNEL_4x4_ACC16() \
- do { \
- dv0 = vld1_f16(a); \
- vb0 = vld1_f16(b); \
- v24 = vfma_lane_f16(v24, vb0, dv0, 0); \
- v25 = vfma_lane_f16(v25, vb0, dv0, 1); \
- v26 = vfma_lane_f16(v26, vb0, dv0, 2); \
- v27 = vfma_lane_f16(v27, vb0, dv0, 3); \
- dv1 = vld1_f16(a + 4); \
- vb1 = vld1_f16(b + 4); \
- v24 = vfma_lane_f16(v24, vb1, dv1, 0); \
- v25 = vfma_lane_f16(v25, vb1, dv1, 1); \
- v26 = vfma_lane_f16(v26, vb1, dv1, 2); \
- v27 = vfma_lane_f16(v27, vb1, dv1, 3); \
- dv2 = vld1_f16(a + 4 * 2); \
- vb2 = vld1_f16(b + 4 * 2); \
- v24 = vfma_lane_f16(v24, vb2, dv2, 0); \
- v25 = vfma_lane_f16(v25, vb2, dv2, 1); \
- v26 = vfma_lane_f16(v26, vb2, dv2, 2); \
- v27 = vfma_lane_f16(v27, vb2, dv2, 3); \
- dv3 = vld1_f16(a + 4 * 3); \
- vb3 = vld1_f16(b + 4 * 3); \
- v24 = vfma_lane_f16(v24, vb3, dv3, 0); \
- v25 = vfma_lane_f16(v25, vb3, dv3, 1); \
- v26 = vfma_lane_f16(v26, vb3, dv3, 2); \
- v27 = vfma_lane_f16(v27, vb3, dv3, 3); \
- dv4 = vld1_f16(a + 4 * 4); \
- vb4 = vld1_f16(b + 4 * 4); \
- v24 = vfma_lane_f16(v24, vb4, dv4, 0); \
- v25 = vfma_lane_f16(v25, vb4, dv4, 1); \
- v26 = vfma_lane_f16(v26, vb4, dv4, 2); \
- v27 = vfma_lane_f16(v27, vb4, dv4, 3); \
- dv5 = vld1_f16(a + 4 * 5); \
- vb5 = vld1_f16(b + 4 * 5); \
- v24 = vfma_lane_f16(v24, vb5, dv5, 0); \
- v25 = vfma_lane_f16(v25, vb5, dv5, 1); \
- v26 = vfma_lane_f16(v26, vb5, dv5, 2); \
- v27 = vfma_lane_f16(v27, vb5, dv5, 3); \
- dv6 = vld1_f16(a + 4 * 6); \
- vb6 = vld1_f16(b + 4 * 6); \
- v24 = vfma_lane_f16(v24, vb6, dv6, 0); \
- v25 = vfma_lane_f16(v25, vb6, dv6, 1); \
- v26 = vfma_lane_f16(v26, vb6, dv6, 2); \
- v27 = vfma_lane_f16(v27, vb6, dv6, 3); \
- dv7 = vld1_f16(a + 4 * 7); \
- vb7 = vld1_f16(b + 4 * 7); \
- v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
- v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
- v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
- v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
- dv7 = vld1_f16(a + 4 * 8); \
- vb7 = vld1_f16(b + 4 * 8); \
- v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
- v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
- v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
- v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
- dv7 = vld1_f16(a + 4 * 9); \
- vb7 = vld1_f16(b + 4 * 9); \
- v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
- v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
- v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
- v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
- dv7 = vld1_f16(a + 4 * 10); \
- vb7 = vld1_f16(b + 4 * 10); \
- v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
- v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
- v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
- v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
- dv7 = vld1_f16(a + 4 * 11); \
- vb7 = vld1_f16(b + 4 * 11); \
- v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
- v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
- v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
- v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
- dv7 = vld1_f16(a + 4 * 12); \
- vb7 = vld1_f16(b + 4 * 12); \
- v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
- v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
- v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
- v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
- dv7 = vld1_f16(a + 4 * 13); \
- vb7 = vld1_f16(b + 4 * 13); \
- v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
- v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
- v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
- v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
- dv7 = vld1_f16(a + 4 * 14); \
- vb7 = vld1_f16(b + 4 * 14); \
- v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
- v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
- v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
- v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
- dv7 = vld1_f16(a + 4 * 15); \
- vb7 = vld1_f16(b + 4 * 15); \
- v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
- v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
- v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
- v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
- l += 16; \
- __builtin_prefetch(b + 64, 0, 3); \
- __builtin_prefetch(a + 64, 0, 3); \
- b += 4 * 16; \
- a += 4 * 16; \
- } while (0)
-
-// 2. Partial sum 128 digits
-#define KERNEL_4x4_ACC8() \
- do { \
- dv0 = vld1_f16(a); \
- vb0 = vld1_f16(b); \
- v24 = vfma_lane_f16(v24, vb0, dv0, 0); \
- v25 = vfma_lane_f16(v25, vb0, dv0, 1); \
- v26 = vfma_lane_f16(v26, vb0, dv0, 2); \
- v27 = vfma_lane_f16(v27, vb0, dv0, 3); \
- dv1 = vld1_f16(a + 4); \
- vb1 = vld1_f16(b + 4); \
- v24 = vfma_lane_f16(v24, vb1, dv1, 0); \
- v25 = vfma_lane_f16(v25, vb1, dv1, 1); \
- v26 = vfma_lane_f16(v26, vb1, dv1, 2); \
- v27 = vfma_lane_f16(v27, vb1, dv1, 3); \
- dv2 = vld1_f16(a + 8); \
- vb2 = vld1_f16(b + 8); \
- v24 = vfma_lane_f16(v24, vb2, dv2, 0); \
- v25 = vfma_lane_f16(v25, vb2, dv2, 1); \
- v26 = vfma_lane_f16(v26, vb2, dv2, 2); \
- v27 = vfma_lane_f16(v27, vb2, dv2, 3); \
- dv3 = vld1_f16(a + 12); \
- vb3 = vld1_f16(b + 12); \
- v24 = vfma_lane_f16(v24, vb3, dv3, 0); \
- v25 = vfma_lane_f16(v25, vb3, dv3, 1); \
- v26 = vfma_lane_f16(v26, vb3, dv3, 2); \
- v27 = vfma_lane_f16(v27, vb3, dv3, 3); \
- dv4 = vld1_f16(a + 16); \
- vb4 = vld1_f16(b + 16); \
- v24 = vfma_lane_f16(v24, vb4, dv4, 0); \
- v25 = vfma_lane_f16(v25, vb4, dv4, 1); \
- v26 = vfma_lane_f16(v26, vb4, dv4, 2); \
- v27 = vfma_lane_f16(v27, vb4, dv4, 3); \
- dv5 = vld1_f16(a + 20); \
- vb5 = vld1_f16(b + 20); \
- v24 = vfma_lane_f16(v24, vb5, dv5, 0); \
- v25 = vfma_lane_f16(v25, vb5, dv5, 1); \
- v26 = vfma_lane_f16(v26, vb5, dv5, 2); \
- v27 = vfma_lane_f16(v27, vb5, dv5, 3); \
- dv6 = vld1_f16(a + 24); \
- vb6 = vld1_f16(b + 24); \
- v24 = vfma_lane_f16(v24, vb6, dv6, 0); \
- v25 = vfma_lane_f16(v25, vb6, dv6, 1); \
- v26 = vfma_lane_f16(v26, vb6, dv6, 2); \
- v27 = vfma_lane_f16(v27, vb6, dv6, 3); \
- dv7 = vld1_f16(a + 28); \
- vb7 = vld1_f16(b + 28); \
- v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
- v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
- v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
- v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
- l += 8; \
- __builtin_prefetch(b + 32, 0, 3); \
- __builtin_prefetch(a + 32, 0, 3); \
- b += 4 * 8; \
- a += 4 * 8; \
- } while (0)
-
-// 3. Partial sum 16 digits
-#define KERNEL_4x4_ACC1() \
- do { \
- dv0 = vld1_f16(a); \
- vb0 = vld1_f16(b); \
- v24 = vfma_lane_f16(v24, vb0, dv0, 0); \
- v25 = vfma_lane_f16(v25, vb0, dv0, 1); \
- v26 = vfma_lane_f16(v26, vb0, dv0, 2); \
- v27 = vfma_lane_f16(v27, vb0, dv0, 3); \
- l += 1; \
- __builtin_prefetch(b + 4, 0, 3); \
- __builtin_prefetch(a + 4, 0, 3); \
- b += 4 * 1; \
- a += 4 * 1; \
- } while (0)
-
-#define SAVE_KERNEL_4X4_F16_F32() \
- do { \
- vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(v24))); \
- vst1q_f32(c + ldc, vaddq_f32(vld1q_f32(c + ldc), vcvt_f32_f16(v25))); \
- vst1q_f32(c + 2 * ldc, \
- vaddq_f32(vld1q_f32(c + 2 * ldc), vcvt_f32_f16(v26))); \
- vst1q_f32(c + 3 * ldc, \
- vaddq_f32(vld1q_f32(c + 3 * ldc), vcvt_f32_f16(v27))); \
- } while (0)
-
-/**
- * @brief hgemm 4x4 kernel sc = sa * sb
- *
- * @param m length of the row of matrix A
- * @param n length of the col of matrix B
- * @param k length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading dimension of matrix C
- */
-void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
- __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
- assert(M > 0 && N > 0 && K > 0);
- assert(M % 4 == 0 && N % 4 == 0 && K % 4 == 0);
-
- __fp16 *a = sa, *b = sb, *c = sc;
- unsigned int i, j, l;
- for (i = 0; i < M; i += 4) {
- for (j = 0; j < N; j += 4) {
- __builtin_prefetch(b, 0, 3);
- __builtin_prefetch(a, 0, 3);
-
- float16x4_t v24;
- float16x4_t v25;
- float16x4_t v26;
- float16x4_t v27;
- INIT_KERNEL_4x4();
-
- for (l = 0; l < K; l += 4) {
- float16x4_t v0 = vld1_f16(b);
- float16x4_t v16 = vld1_f16(a);
-
- v24 = vfma_lane_f16(v24, v0, v16, 0);
- v25 = vfma_lane_f16(v25, v0, v16, 1);
- v26 = vfma_lane_f16(v26, v0, v16, 2);
- v27 = vfma_lane_f16(v27, v0, v16, 3);
-
- float16x4_t v1 = vld1_f16(b + 4);
- float16x4_t v17 = vld1_f16(a + 4);
-
- v24 = vfma_lane_f16(v24, v1, v17, 0);
- v25 = vfma_lane_f16(v25, v1, v17, 1);
- v26 = vfma_lane_f16(v26, v1, v17, 2);
- v27 = vfma_lane_f16(v27, v1, v17, 3);
-
- float16x4_t v2 = vld1_f16(b + 8);
- float16x4_t v18 = vld1_f16(a + 8);
-
- v24 = vfma_lane_f16(v24, v2, v18, 0);
- v25 = vfma_lane_f16(v25, v2, v18, 1);
- v26 = vfma_lane_f16(v26, v2, v18, 2);
- v27 = vfma_lane_f16(v27, v2, v18, 3);
-
- float16x4_t v3 = vld1_f16(b + 12);
- float16x4_t v19 = vld1_f16(a + 12);
-
- v24 = vfma_lane_f16(v24, v3, v19, 0);
- v25 = vfma_lane_f16(v25, v3, v19, 1);
- v26 = vfma_lane_f16(v26, v3, v19, 2);
- v27 = vfma_lane_f16(v27, v3, v19, 3);
-
- __builtin_prefetch(b + 16, 0, 3);
- __builtin_prefetch(a + 16, 0, 3);
-
- b += 16;
- a += 16;
- }
-
- v24 = vadd_f16(vld1_f16(c), v24);
- v25 = vadd_f16(vld1_f16(c + ldc), v25);
- v26 = vadd_f16(vld1_f16(c + 2 * ldc), v26);
- v27 = vadd_f16(vld1_f16(c + 3 * ldc), v27);
-
- vst1_f16(c, v24);
- vst1_f16(c + ldc, v25);
- vst1_f16(c + 2 * ldc, v26);
- vst1_f16(c + 3 * ldc, v27);
-
- c += 4;
- a -= 4 * K;
- }
- sc += ldc * 4;
- c = sc;
- a += 4 * K;
- b = sb;
- }
-}
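For unit-testing the NEON path, a scalar reference of the packed 4x4 micro-kernel can be useful. The sketch below assumes the packing this kernel expects: sa holds K steps of four A-row values, sb holds K steps of four B-column values, and C is accumulated in place. Names are illustrative, not part of the library:

static void ref_kernel_4x4(unsigned int K, const __fp16 *sa, const __fp16 *sb,
                           __fp16 *c, unsigned int ldc) {
  for (unsigned int k = 0; k < K; ++k)
    for (unsigned int r = 0; r < 4; ++r)
      for (unsigned int col = 0; col < 4; ++col)
        c[r * ldc + col] += sa[4 * k + r] * sb[4 * k + col];
}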
-
-/**
- * @brief hgemm 4x4 kernel sc = sa * sb
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading dimension of matrix C
- */
-void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
- __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
- assert(M > 0 && N > 0 && K > 0);
- assert(M % 4 == 0 && N % 4 == 0 && K % 4 == 0);
-
- __fp16 *a = sa, *b = sb;
- float *c = sc;
- unsigned int i, j, l;
- unsigned int K16 = (K >> 4) << 4;
- unsigned int K8 = (K >> 3) << 3;
- for (i = 0; i < M; i += 4) {
- for (j = 0; j < N; j += 4) {
- __builtin_prefetch(b, 0, 3);
- __builtin_prefetch(a, 0, 3);
-
- float16x4_t v24, v25, v26, v27;
- float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
- float16x4_t vb0, vb1, vb2, vb3, vb4, vb5, vb6, vb7;
- l = 0;
- for (; l < K16;) {
- INIT_KERNEL_4x4();
- KERNEL_4x4_ACC16();
- SAVE_KERNEL_4X4_F16_F32();
- }
- for (; l < K8;) {
- INIT_KERNEL_4x4();
- KERNEL_4x4_ACC8();
- SAVE_KERNEL_4X4_F16_F32();
- }
- for (; l < K;) {
- INIT_KERNEL_4x4();
- KERNEL_4x4_ACC1();
- SAVE_KERNEL_4X4_F16_F32();
- }
-
- c += 4;
- a -= 4 * K;
- }
- sc += ldc * 4;
- c = sc;
- a += 4 * K;
- b = sb;
- }
-}
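The fp32-output variant splits the K loop into K16/K8/K1 chunks and re-initializes the fp16 accumulators after every chunk, flushing the partial sums into fp32. This bounds how many products are ever summed in half precision. A scalar sketch of the same idea on a single dot product (CHUNK and the function name are illustrative):

static float chunked_dot_f16_f32(const __fp16 *a, const __fp16 *b,
                                 unsigned int K, unsigned int CHUNK) {
  float sum = 0.f;
  for (unsigned int l = 0; l < K;) {
    __fp16 partial = 0; /* INIT_KERNEL_* equivalent */
    unsigned int end = (K - l > CHUNK) ? l + CHUNK : K;
    for (; l < end; ++l)
      partial += a[l] * b[l]; /* KERNEL_*_ACC* equivalent: fp16 accumulation */
    sum += (float)partial; /* SAVE_KERNEL_*_F16_F32 equivalent: flush to fp32 */
  }
  return sum;
}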
+++ /dev/null
-// SPDX-License-Identifier: Apache-2.0
-/**
- * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
- *
- * @file hgemm_kernel_4x8.h
- * @date 03 April 2024
- * @see https://github.com/nnstreamer/nntrainer
- * @author Sungsik Kong <ss.kong@samsung.com>
- * @bug No known bugs except for NYI items
- * @brief This is half-precision GEMM 4x8 kernel
- *
- */
-
-#include <arm_neon.h>
-#include <assert.h>
-#include <stdlib.h>
-
-#define INIT_KERNEL_4X8() \
- do { \
- v0 = vdupq_n_f16(0.F); \
- v3 = vdupq_n_f16(0.F); \
- v6 = vdupq_n_f16(0.F); \
- v9 = vdupq_n_f16(0.F); \
- } while (0)
-
-// 1. Partial sum 512 digits
-#define KERNEL_4x8_ACC16() \
- do { \
- dv0 = vld1_f16(a); \
- v24 = vld1q_f16(b); \
- v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
- v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \
- v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \
- v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \
- dv1 = vld1_f16(a + 4); \
- v25 = vld1q_f16(b + 8); \
- v0 = vfmaq_lane_f16(v0, v25, dv1, 0); \
- v3 = vfmaq_lane_f16(v3, v25, dv1, 1); \
- v6 = vfmaq_lane_f16(v6, v25, dv1, 2); \
- v9 = vfmaq_lane_f16(v9, v25, dv1, 3); \
- dv2 = vld1_f16(a + 4 * 2); \
- v26 = vld1q_f16(b + 8 * 2); \
- v0 = vfmaq_lane_f16(v0, v26, dv2, 0); \
- v3 = vfmaq_lane_f16(v3, v26, dv2, 1); \
- v6 = vfmaq_lane_f16(v6, v26, dv2, 2); \
- v9 = vfmaq_lane_f16(v9, v26, dv2, 3); \
- dv3 = vld1_f16(a + 4 * 3); \
- v27 = vld1q_f16(b + 8 * 3); \
- v0 = vfmaq_lane_f16(v0, v27, dv3, 0); \
- v3 = vfmaq_lane_f16(v3, v27, dv3, 1); \
- v6 = vfmaq_lane_f16(v6, v27, dv3, 2); \
- v9 = vfmaq_lane_f16(v9, v27, dv3, 3); \
- dv4 = vld1_f16(a + 4 * 4); \
- v28 = vld1q_f16(b + 8 * 4); \
- v0 = vfmaq_lane_f16(v0, v28, dv4, 0); \
- v3 = vfmaq_lane_f16(v3, v28, dv4, 1); \
- v6 = vfmaq_lane_f16(v6, v28, dv4, 2); \
- v9 = vfmaq_lane_f16(v9, v28, dv4, 3); \
- dv5 = vld1_f16(a + 4 * 5); \
- v29 = vld1q_f16(b + 8 * 5); \
- v0 = vfmaq_lane_f16(v0, v29, dv5, 0); \
- v3 = vfmaq_lane_f16(v3, v29, dv5, 1); \
- v6 = vfmaq_lane_f16(v6, v29, dv5, 2); \
- v9 = vfmaq_lane_f16(v9, v29, dv5, 3); \
- dv6 = vld1_f16(a + 4 * 6); \
- v30 = vld1q_f16(b + 8 * 6); \
- v0 = vfmaq_lane_f16(v0, v30, dv6, 0); \
- v3 = vfmaq_lane_f16(v3, v30, dv6, 1); \
- v6 = vfmaq_lane_f16(v6, v30, dv6, 2); \
- v9 = vfmaq_lane_f16(v9, v30, dv6, 3); \
- dv7 = vld1_f16(a + 4 * 7); \
- v31 = vld1q_f16(b + 8 * 7); \
- v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
- v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
- v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
- v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
- dv7 = vld1_f16(a + 4 * 8); \
- v31 = vld1q_f16(b + 8 * 8); \
- v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
- v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
- v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
- v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
- dv7 = vld1_f16(a + 4 * 9); \
- v31 = vld1q_f16(b + 8 * 9); \
- v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
- v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
- v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
- v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
- dv7 = vld1_f16(a + 4 * 10); \
- v31 = vld1q_f16(b + 8 * 10); \
- v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
- v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
- v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
- v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
- dv7 = vld1_f16(a + 4 * 11); \
- v31 = vld1q_f16(b + 8 * 11); \
- v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
- v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
- v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
- v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
- dv7 = vld1_f16(a + 4 * 12); \
- v31 = vld1q_f16(b + 8 * 12); \
- v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
- v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
- v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
- v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
- dv7 = vld1_f16(a + 4 * 13); \
- v31 = vld1q_f16(b + 8 * 13); \
- v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
- v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
- v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
- v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
- dv7 = vld1_f16(a + 4 * 14); \
- v31 = vld1q_f16(b + 8 * 14); \
- v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
- v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
- v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
- v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
- dv7 = vld1_f16(a + 4 * 15); \
- v31 = vld1q_f16(b + 8 * 15); \
- v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
- v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
- v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
- v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
- l += 16; \
- __builtin_prefetch(b + 128, 0, 3); \
- __builtin_prefetch(a + 64, 0, 3); \
- b += 8 * 16; \
- a += 4 * 16; \
- } while (0)
-
-// 2. Partial sum 256 digits
-#define KERNEL_4x8_ACC8() \
- do { \
- dv0 = vld1_f16(a); \
- v24 = vld1q_f16(b); \
- v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
- v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \
- v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \
- v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \
- dv1 = vld1_f16(a + 4); \
- v25 = vld1q_f16(b + 8); \
- v0 = vfmaq_lane_f16(v0, v25, dv1, 0); \
- v3 = vfmaq_lane_f16(v3, v25, dv1, 1); \
- v6 = vfmaq_lane_f16(v6, v25, dv1, 2); \
- v9 = vfmaq_lane_f16(v9, v25, dv1, 3); \
- dv2 = vld1_f16(a + 8); \
- v26 = vld1q_f16(b + 16); \
- v0 = vfmaq_lane_f16(v0, v26, dv2, 0); \
- v3 = vfmaq_lane_f16(v3, v26, dv2, 1); \
- v6 = vfmaq_lane_f16(v6, v26, dv2, 2); \
- v9 = vfmaq_lane_f16(v9, v26, dv2, 3); \
- dv3 = vld1_f16(a + 12); \
- v27 = vld1q_f16(b + 24); \
- v0 = vfmaq_lane_f16(v0, v27, dv3, 0); \
- v3 = vfmaq_lane_f16(v3, v27, dv3, 1); \
- v6 = vfmaq_lane_f16(v6, v27, dv3, 2); \
- v9 = vfmaq_lane_f16(v9, v27, dv3, 3); \
- dv4 = vld1_f16(a + 16); \
- v28 = vld1q_f16(b + 32); \
- v0 = vfmaq_lane_f16(v0, v28, dv4, 0); \
- v3 = vfmaq_lane_f16(v3, v28, dv4, 1); \
- v6 = vfmaq_lane_f16(v6, v28, dv4, 2); \
- v9 = vfmaq_lane_f16(v9, v28, dv4, 3); \
- dv5 = vld1_f16(a + 20); \
- v29 = vld1q_f16(b + 40); \
- v0 = vfmaq_lane_f16(v0, v29, dv5, 0); \
- v3 = vfmaq_lane_f16(v3, v29, dv5, 1); \
- v6 = vfmaq_lane_f16(v6, v29, dv5, 2); \
- v9 = vfmaq_lane_f16(v9, v29, dv5, 3); \
- dv6 = vld1_f16(a + 24); \
- v30 = vld1q_f16(b + 48); \
- v0 = vfmaq_lane_f16(v0, v30, dv6, 0); \
- v3 = vfmaq_lane_f16(v3, v30, dv6, 1); \
- v6 = vfmaq_lane_f16(v6, v30, dv6, 2); \
- v9 = vfmaq_lane_f16(v9, v30, dv6, 3); \
- dv7 = vld1_f16(a + 28); \
- v31 = vld1q_f16(b + 56); \
- v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
- v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
- v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
- v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
- l += 8; \
- __builtin_prefetch(b + 64, 0, 3); \
- __builtin_prefetch(a + 32, 0, 3); \
- b += 8 * 8; \
- a += 4 * 8; \
- } while (0)
-
-// 3. Partial sum 128 digits
-#define KERNEL_4x8_ACC4() \
- do { \
- dv0 = vld1_f16(a); \
- v24 = vld1q_f16(b); \
- v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
- v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \
- v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \
- v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \
- dv1 = vld1_f16(a + 4); \
- v25 = vld1q_f16(b + 8); \
- v0 = vfmaq_lane_f16(v0, v25, dv1, 0); \
- v3 = vfmaq_lane_f16(v3, v25, dv1, 1); \
- v6 = vfmaq_lane_f16(v6, v25, dv1, 2); \
- v9 = vfmaq_lane_f16(v9, v25, dv1, 3); \
- dv2 = vld1_f16(a + 8); \
- v26 = vld1q_f16(b + 16); \
- v0 = vfmaq_lane_f16(v0, v26, dv2, 0); \
- v3 = vfmaq_lane_f16(v3, v26, dv2, 1); \
- v6 = vfmaq_lane_f16(v6, v26, dv2, 2); \
- v9 = vfmaq_lane_f16(v9, v26, dv2, 3); \
- dv3 = vld1_f16(a + 12); \
- v27 = vld1q_f16(b + 24); \
- v0 = vfmaq_lane_f16(v0, v27, dv3, 0); \
- v3 = vfmaq_lane_f16(v3, v27, dv3, 1); \
- v6 = vfmaq_lane_f16(v6, v27, dv3, 2); \
- v9 = vfmaq_lane_f16(v9, v27, dv3, 3); \
- l += 4; \
- __builtin_prefetch(b + 32, 0, 3); \
- __builtin_prefetch(a + 16, 0, 3); \
- b += 8 * 4; \
- a += 4 * 4; \
- } while (0)
-
-// 4. Partial sum 32 digits
-#define KERNEL_4x8_ACC1() \
- do { \
- dv0 = vld1_f16(a); \
- v24 = vld1q_f16(b); \
- v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
- v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \
- v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \
- v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \
- l += 1; \
- __builtin_prefetch(b + 8, 0, 3); \
- __builtin_prefetch(a + 4, 0, 3); \
- b += 8 * 1; \
- a += 4 * 1; \
- } while (0)
-
-#define SAVE_KERNEL_4X8_F16_F32() \
- do { \
- vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0)))); \
- vst1q_f32(c + ldc, \
- vaddq_f32(vld1q_f32(c + ldc), vcvt_f32_f16(vget_low_f16(v3)))); \
- vst1q_f32(c + 2 * ldc, vaddq_f32(vld1q_f32(c + 2 * ldc), \
- vcvt_f32_f16(vget_low_f16(v6)))); \
- vst1q_f32(c + 3 * ldc, vaddq_f32(vld1q_f32(c + 3 * ldc), \
- vcvt_f32_f16(vget_low_f16(v9)))); \
- \
- vst1q_f32(c + 4, \
- vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0)))); \
- vst1q_f32(c + 4 + ldc, vaddq_f32(vld1q_f32(c + 4 + ldc), \
- vcvt_f32_f16(vget_high_f16(v3)))); \
- vst1q_f32(c + 4 + 2 * ldc, vaddq_f32(vld1q_f32(c + 4 + 2 * ldc), \
- vcvt_f32_f16(vget_high_f16(v6)))); \
- vst1q_f32(c + 4 + 3 * ldc, vaddq_f32(vld1q_f32(c + 4 + 3 * ldc), \
- vcvt_f32_f16(vget_high_f16(v9)))); \
- } while (0)
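Because vcvt_f32_f16 widens only four lanes at a time, SAVE_KERNEL_4X8_F16_F32 handles the low and high halves of each 8-wide fp16 accumulator separately. A minimal sketch of the per-row pattern, assuming arm_neon.h and fp16 support:

#include <arm_neon.h>

static inline void save_row_f16x8_to_f32(float *c_row, float16x8_t acc) {
  /* widen each half of the fp16 accumulator row and add it into the fp32 row of C */
  vst1q_f32(c_row,
            vaddq_f32(vld1q_f32(c_row), vcvt_f32_f16(vget_low_f16(acc))));
  vst1q_f32(c_row + 4,
            vaddq_f32(vld1q_f32(c_row + 4), vcvt_f32_f16(vget_high_f16(acc))));
}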
-
-/**
- * @brief hgemm 4x8 kernel sc = sa * sb
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading-dimension of matrix C
- */
-void hgemm_kernel_4x8(unsigned int M, unsigned int N, unsigned int K,
- __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
- assert(M > 0 && N > 0 && K > 0);
- assert(M % 4 == 0 && N % 8 == 0);
-
- __fp16 *a = sa, *b = sb, *c = sc;
- unsigned int K8 = (K >> 3) << 3;
- unsigned int i, j, l;
- for (i = 0; i < M; i += 4) {
- for (j = 0; j < N; j += 8) {
- __builtin_prefetch(b, 0, 3);
- __builtin_prefetch(a, 0, 3);
- float16x8_t v0, v3, v6, v9;
- float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
- float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
- INIT_KERNEL_4X8();
- l = 0;
- for (; l < K8;) {
- KERNEL_4x8_ACC8();
- }
- for (; l < K;) {
- KERNEL_4x8_ACC1();
- }
- vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0));
- vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v3));
- vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v6));
- vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v9));
- c += 8;
- a -= 4 * K;
- }
- sc += ldc * 4;
- c = sc;
- a += 4 * K;
- b = sb;
- }
-}
-
-/**
- * @brief hgemm 4x8 kernel sc = sa * sb
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading-dimension of matrix C
- */
-void hgemm_kernel_4x8(unsigned int M, unsigned int N, unsigned int K,
- __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
- assert(M > 0 && N > 0 && K > 0);
- assert(M % 4 == 0 && N % 8 == 0);
-
- __fp16 *a = sa, *b = sb;
- float *c = sc;
- unsigned int K16 = (K >> 4) << 4;
- unsigned int K8 = (K >> 3) << 3;
- unsigned int K4 = (K >> 2) << 2;
- unsigned int i, j, l;
- for (i = 0; i < M; i += 4) {
- for (j = 0; j < N; j += 8) {
- __builtin_prefetch(b, 0, 3);
- __builtin_prefetch(a, 0, 3);
- float16x8_t v0, v3, v6, v9;
- float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
- float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
- l = 0;
- for (; l < K16;) {
- INIT_KERNEL_4X8();
- KERNEL_4x8_ACC16();
- SAVE_KERNEL_4X8_F16_F32();
- }
- for (; l < K8;) {
- INIT_KERNEL_4X8();
- KERNEL_4x8_ACC8();
- SAVE_KERNEL_4X8_F16_F32();
- }
- for (; l < K4;) {
- INIT_KERNEL_4X8();
- KERNEL_4x8_ACC4();
- SAVE_KERNEL_4X8_F16_F32();
- }
- for (; l < K;) {
- INIT_KERNEL_4X8();
- KERNEL_4x8_ACC1();
- SAVE_KERNEL_4X8_F16_F32();
- }
- c += 8;
- a -= 4 * K;
- }
- sc += ldc * 4;
- c = sc;
- a += 4 * K;
- b = sb;
- }
-}
+++ /dev/null
-// SPDX-License-Identifier: Apache-2.0
-/**
- * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
- *
- * @file hgemm_kernel_8x16.h
- * @date 04 April 2024
- * @see https://github.com/nnstreamer/nntrainer
- * @author Sungsik Kong <ss.kong@samsung.com>
- * @bug No known bugs except for NYI items
- * @brief This is half-precision GEMM 8x16 kernel
- *
- */
-
-#include <arm_neon.h>
-#include <assert.h>
-#include <iostream>
-#include <stdlib.h>
-
-#define INIT_KERNEL_8X16() \
- do { \
- v0_7 = vdupq_n_f16(0.F); \
- v8_15 = vdupq_n_f16(0.F); \
- v16_23 = vdupq_n_f16(0.F); \
- v24_31 = vdupq_n_f16(0.F); \
- v32_39 = vdupq_n_f16(0.F); \
- v40_47 = vdupq_n_f16(0.F); \
- v48_55 = vdupq_n_f16(0.F); \
- v56_63 = vdupq_n_f16(0.F); \
- v64_71 = vdupq_n_f16(0.F); \
- v72_79 = vdupq_n_f16(0.F); \
- v80_87 = vdupq_n_f16(0.F); \
- v88_95 = vdupq_n_f16(0.F); \
- v96_103 = vdupq_n_f16(0.F); \
- v104_111 = vdupq_n_f16(0.F); \
- v112_119 = vdupq_n_f16(0.F); \
- v120_127 = vdupq_n_f16(0.F); \
- } while (0)
-
-// 1. Partial sum 2048 digits
-#define KERNEL_8x16_ACC16() \
- do { \
- va0 = vld1q_f16(a + 8 * 0); \
- vb1 = vld1q_f16(b + 8 * 0); \
- vb2 = vld1q_f16(b + 8 * 1); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 8 * 1); \
- vb1 = vld1q_f16(b + 8 * 2); \
- vb2 = vld1q_f16(b + 8 * 3); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 8 * 2); \
- vb1 = vld1q_f16(b + 8 * 4); \
- vb2 = vld1q_f16(b + 8 * 5); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 8 * 3); \
- vb1 = vld1q_f16(b + 8 * 6); \
- vb2 = vld1q_f16(b + 8 * 7); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 8 * 4); \
- vb1 = vld1q_f16(b + 8 * 8); \
- vb2 = vld1q_f16(b + 8 * 9); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 8 * 5); \
- vb1 = vld1q_f16(b + 8 * 10); \
- vb2 = vld1q_f16(b + 8 * 11); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 8 * 6); \
- vb1 = vld1q_f16(b + 8 * 12); \
- vb2 = vld1q_f16(b + 8 * 13); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 8 * 7); \
- vb1 = vld1q_f16(b + 8 * 14); \
- vb2 = vld1q_f16(b + 8 * 15); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 8 * 8); \
- vb1 = vld1q_f16(b + 8 * 16); \
- vb2 = vld1q_f16(b + 8 * 17); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 8 * 9); \
- vb1 = vld1q_f16(b + 8 * 18); \
- vb2 = vld1q_f16(b + 8 * 19); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 8 * 10); \
- vb1 = vld1q_f16(b + 8 * 20); \
- vb2 = vld1q_f16(b + 8 * 21); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 8 * 11); \
- vb1 = vld1q_f16(b + 8 * 22); \
- vb2 = vld1q_f16(b + 8 * 23); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 8 * 12); \
- vb1 = vld1q_f16(b + 8 * 24); \
- vb2 = vld1q_f16(b + 8 * 25); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 8 * 13); \
- vb1 = vld1q_f16(b + 8 * 26); \
- vb2 = vld1q_f16(b + 8 * 27); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 8 * 14); \
- vb1 = vld1q_f16(b + 8 * 28); \
- vb2 = vld1q_f16(b + 8 * 29); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 8 * 15); \
- vb1 = vld1q_f16(b + 8 * 30); \
- vb2 = vld1q_f16(b + 8 * 31); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- __builtin_prefetch(b + 256, 0, 3); \
- __builtin_prefetch(a + 128, 0, 3); \
- l += 16; \
- b += 16 * 16; \
- a += 8 * 16; \
- } while (0)
-
-// 2. Partial sum 1024 digits
-#define KERNEL_8x16_ACC8() \
- do { \
- va0 = vld1q_f16(a); \
- vb1 = vld1q_f16(b); \
- vb2 = vld1q_f16(b + 8); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 8); \
- vb1 = vld1q_f16(b + 16); \
- vb2 = vld1q_f16(b + 24); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 16); \
- vb1 = vld1q_f16(b + 32); \
- vb2 = vld1q_f16(b + 40); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 24); \
- vb1 = vld1q_f16(b + 48); \
- vb2 = vld1q_f16(b + 56); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 32); \
- vb1 = vld1q_f16(b + 64); \
- vb2 = vld1q_f16(b + 72); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 40); \
- vb1 = vld1q_f16(b + 80); \
- vb2 = vld1q_f16(b + 88); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 48); \
- vb1 = vld1q_f16(b + 96); \
- vb2 = vld1q_f16(b + 104); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 56); \
- vb1 = vld1q_f16(b + 112); \
- vb2 = vld1q_f16(b + 120); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- l += 8; \
- __builtin_prefetch(b + 128, 0, 3); \
- __builtin_prefetch(a + 64, 0, 3); \
- b += 16 * 8; \
- a += 8 * 8; \
- } while (0)
-
-// 3. Partial sum 512 digits
-#define KERNEL_8x16_ACC4() \
- do { \
- va0 = vld1q_f16(a); \
- vb1 = vld1q_f16(b); \
- vb2 = vld1q_f16(b + 8); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 8); \
- vb1 = vld1q_f16(b + 16); \
- vb2 = vld1q_f16(b + 24); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 16); \
- vb1 = vld1q_f16(b + 32); \
- vb2 = vld1q_f16(b + 40); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- va0 = vld1q_f16(a + 24); \
- vb1 = vld1q_f16(b + 48); \
- vb2 = vld1q_f16(b + 56); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- l += 4; \
- __builtin_prefetch(b + 64, 0, 3); \
- __builtin_prefetch(a + 32, 0, 3); \
- b += 16 * 4; \
- a += 8 * 4; \
- } while (0)
-
-// 4. Partial sum 128 digits
-#define KERNEL_8x16_ACC1() \
- do { \
- va0 = vld1q_f16(a); \
- vb1 = vld1q_f16(b); \
- vb2 = vld1q_f16(b + 8); \
- v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0); \
- v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1); \
- v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2); \
- v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3); \
- v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4); \
- v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5); \
- v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6); \
- v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7); \
- v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0); \
- v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1); \
- v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2); \
- v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3); \
- v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4); \
- v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
- v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
- v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
- l += 1; \
- __builtin_prefetch(b + 16, 0, 3); \
- __builtin_prefetch(a + 8, 0, 3); \
- b += 16 * 1; \
- a += 8 * 1; \
- } while (0)
-
-#define SAVE_KERNEL_8X16_F16_F32() \
- do { \
- vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0_7)))); \
- vst1q_f32(c + 4, \
- vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0_7)))); \
- \
- vst1q_f32( \
- c + 8, vaddq_f32(vld1q_f32(c + 8), vcvt_f32_f16(vget_low_f16(v64_71)))); \
- vst1q_f32(c + 8 + 4, vaddq_f32(vld1q_f32(c + 8 + 4), \
- vcvt_f32_f16(vget_high_f16(v64_71)))); \
- \
- vst1q_f32(c + ldc, vaddq_f32(vld1q_f32(c + ldc), \
- vcvt_f32_f16(vget_low_f16(v8_15)))); \
- vst1q_f32(c + ldc + 4, vaddq_f32(vld1q_f32(c + ldc + 4), \
- vcvt_f32_f16(vget_high_f16(v8_15)))); \
- \
- vst1q_f32(c + ldc + 8, vaddq_f32(vld1q_f32(c + ldc + 8), \
- vcvt_f32_f16(vget_low_f16(v72_79)))); \
- vst1q_f32(c + ldc + 8 + 4, \
- vaddq_f32(vld1q_f32(c + ldc + 8 + 4), \
- vcvt_f32_f16(vget_high_f16(v72_79)))); \
- \
- vst1q_f32(c + 2 * ldc, vaddq_f32(vld1q_f32(c + 2 * ldc), \
- vcvt_f32_f16(vget_low_f16(v16_23)))); \
- vst1q_f32(c + 2 * ldc + 4, \
- vaddq_f32(vld1q_f32(c + 2 * ldc + 4), \
- vcvt_f32_f16(vget_high_f16(v16_23)))); \
- \
- vst1q_f32(c + 2 * ldc + 8, vaddq_f32(vld1q_f32(c + 2 * ldc + 8), \
- vcvt_f32_f16(vget_low_f16(v80_87)))); \
- vst1q_f32(c + 2 * ldc + 8 + 4, \
- vaddq_f32(vld1q_f32(c + 2 * ldc + 8 + 4), \
- vcvt_f32_f16(vget_high_f16(v80_87)))); \
- \
- vst1q_f32(c + 3 * ldc, vaddq_f32(vld1q_f32(c + 3 * ldc), \
- vcvt_f32_f16(vget_low_f16(v24_31)))); \
- vst1q_f32(c + 3 * ldc + 4, \
- vaddq_f32(vld1q_f32(c + 3 * ldc + 4), \
- vcvt_f32_f16(vget_high_f16(v24_31)))); \
- \
- vst1q_f32(c + 3 * ldc + 8, vaddq_f32(vld1q_f32(c + 3 * ldc + 8), \
- vcvt_f32_f16(vget_low_f16(v88_95)))); \
- vst1q_f32(c + 3 * ldc + 8 + 4, \
- vaddq_f32(vld1q_f32(c + 3 * ldc + 8 + 4), \
- vcvt_f32_f16(vget_high_f16(v88_95)))); \
- \
- vst1q_f32(c + 4 * ldc, vaddq_f32(vld1q_f32(c + 4 * ldc), \
- vcvt_f32_f16(vget_low_f16(v32_39)))); \
- vst1q_f32(c + 4 * ldc + 4, \
- vaddq_f32(vld1q_f32(c + 4 * ldc + 4), \
- vcvt_f32_f16(vget_high_f16(v32_39)))); \
- \
- vst1q_f32(c + 4 * ldc + 8, \
- vaddq_f32(vld1q_f32(c + 4 * ldc + 8), \
- vcvt_f32_f16(vget_low_f16(v96_103)))); \
- vst1q_f32(c + 4 * ldc + 8 + 4, \
- vaddq_f32(vld1q_f32(c + 4 * ldc + 8 + 4), \
- vcvt_f32_f16(vget_high_f16(v96_103)))); \
- \
- vst1q_f32(c + 5 * ldc, vaddq_f32(vld1q_f32(c + 5 * ldc), \
- vcvt_f32_f16(vget_low_f16(v40_47)))); \
- vst1q_f32(c + 5 * ldc + 4, \
- vaddq_f32(vld1q_f32(c + 5 * ldc + 4), \
- vcvt_f32_f16(vget_high_f16(v40_47)))); \
- vst1q_f32(c + 5 * ldc + 8, \
- vaddq_f32(vld1q_f32(c + 5 * ldc + 8), \
- vcvt_f32_f16(vget_low_f16(v104_111)))); \
- vst1q_f32(c + 5 * ldc + 8 + 4, \
- vaddq_f32(vld1q_f32(c + 5 * ldc + 8 + 4), \
- vcvt_f32_f16(vget_high_f16(v104_111)))); \
- \
- vst1q_f32(c + 6 * ldc, vaddq_f32(vld1q_f32(c + 6 * ldc), \
- vcvt_f32_f16(vget_low_f16(v48_55)))); \
- vst1q_f32(c + 6 * ldc + 4, \
- vaddq_f32(vld1q_f32(c + 6 * ldc + 4), \
- vcvt_f32_f16(vget_high_f16(v48_55)))); \
- \
- vst1q_f32(c + 6 * ldc + 8, \
- vaddq_f32(vld1q_f32(c + 6 * ldc + 8), \
- vcvt_f32_f16(vget_low_f16(v112_119)))); \
- vst1q_f32(c + 6 * ldc + 8 + 4, \
- vaddq_f32(vld1q_f32(c + 6 * ldc + 8 + 4), \
- vcvt_f32_f16(vget_high_f16(v112_119)))); \
- \
- vst1q_f32(c + 7 * ldc, vaddq_f32(vld1q_f32(c + 7 * ldc), \
- vcvt_f32_f16(vget_low_f16(v56_63)))); \
- vst1q_f32(c + 7 * ldc + 4, \
- vaddq_f32(vld1q_f32(c + 7 * ldc + 4), \
- vcvt_f32_f16(vget_high_f16(v56_63)))); \
- \
- vst1q_f32(c + 7 * ldc + 8, \
- vaddq_f32(vld1q_f32(c + 7 * ldc + 8), \
- vcvt_f32_f16(vget_low_f16(v120_127)))); \
- vst1q_f32(c + 7 * ldc + 8 + 4, \
- vaddq_f32(vld1q_f32(c + 7 * ldc + 8 + 4), \
- vcvt_f32_f16(vget_high_f16(v120_127)))); \
- } while (0)
-
-/**
- * @brief hgemm 8x16 kernel sc = sa * sb
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading-dimension of matrix C
- */
-void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
- __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
- assert(M > 0 && N > 0 && K > 0);
- assert(M % 8 == 0 && N % 16 == 0 && K % 8 == 0);
-
- __fp16 *a = sa, *b = sb, *c = sc;
- unsigned int i, j, l;
- for (i = 0; i < M; i += 8) {
- for (j = 0; j < N; j += 16) {
- __builtin_prefetch(b, 0, 3);
- __builtin_prefetch(a, 0, 3);
- // 8x16
- float16x8_t v0_7, v8_15;
- float16x8_t v16_23, v24_31;
- float16x8_t v32_39, v40_47;
- float16x8_t v48_55, v56_63;
- float16x8_t v64_71, v72_79;
- float16x8_t v80_87, v88_95;
- float16x8_t v96_103, v104_111;
- float16x8_t v112_119, v120_127;
- float16x8_t vb1, vb2;
- float16x8_t va0;
-
- INIT_KERNEL_8X16();
- l = 0;
- for (; l < K;) {
- KERNEL_8x16_ACC1();
- }
- vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0_7));
- vst1q_f16(c + 8, vaddq_f16(vld1q_f16(c + 8), v64_71));
- vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v8_15));
- vst1q_f16(c + ldc + 8, vaddq_f16(vld1q_f16(c + ldc + 8), v72_79));
- vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v16_23));
- vst1q_f16(c + 2 * ldc + 8, vaddq_f16(vld1q_f16(c + 2 * ldc + 8), v80_87));
- vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v24_31));
- vst1q_f16(c + 3 * ldc + 8, vaddq_f16(vld1q_f16(c + 3 * ldc + 8), v88_95));
- vst1q_f16(c + 4 * ldc, vaddq_f16(vld1q_f16(c + 4 * ldc), v32_39));
- vst1q_f16(c + 4 * ldc + 8,
- vaddq_f16(vld1q_f16(c + 4 * ldc + 8), v96_103));
- vst1q_f16(c + 5 * ldc, vaddq_f16(vld1q_f16(c + 5 * ldc), v40_47));
- vst1q_f16(c + 5 * ldc + 8,
- vaddq_f16(vld1q_f16(c + 5 * ldc + 8), v104_111));
- vst1q_f16(c + 6 * ldc, vaddq_f16(vld1q_f16(c + 6 * ldc), v48_55));
- vst1q_f16(c + 6 * ldc + 8,
- vaddq_f16(vld1q_f16(c + 6 * ldc + 8), v112_119));
- vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v56_63));
- vst1q_f16(c + 7 * ldc + 8,
- vaddq_f16(vld1q_f16(c + 7 * ldc + 8), v120_127));
- c += 16;
- a -= 8 * K;
- }
- sc += ldc * 8;
- c = sc;
- a += 8 * K;
- b = sb;
- }
-}
-
-/**
- * @brief hgemm 8x16 kernel sc = sa * sb
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading-dimension of matrix C
- */
-void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
- __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
- assert(M > 0 && N > 0 && K > 0);
- assert(M % 8 == 0 && N % 16 == 0 && K % 4 == 0);
-
- __fp16 *a = sa, *b = sb;
- float *c = sc;
- unsigned int i, j, l;
- unsigned int K4 = (K >> 2) << 2;
- unsigned int K8 = (K >> 3) << 3;
- unsigned int K16 = (K >> 4) << 4;
- for (i = 0; i < M; i += 8) {
- for (j = 0; j < N; j += 16) {
- __builtin_prefetch(b, 0, 3);
- __builtin_prefetch(a, 0, 3);
- float16x8_t v0_7, v8_15;
- float16x8_t v16_23, v24_31;
- float16x8_t v32_39, v40_47;
- float16x8_t v48_55, v56_63;
- float16x8_t v64_71, v72_79;
- float16x8_t v80_87, v88_95;
- float16x8_t v96_103, v104_111;
- float16x8_t v112_119, v120_127;
- float16x8_t vb1, vb2;
- float16x8_t va0;
- l = 0;
- for (; l < K16;) {
- INIT_KERNEL_8X16();
- KERNEL_8x16_ACC16();
- SAVE_KERNEL_8X16_F16_F32();
- }
- for (; l < K8;) {
- INIT_KERNEL_8X16();
- KERNEL_8x16_ACC8();
- SAVE_KERNEL_8X16_F16_F32();
- }
- for (; l < K4;) {
- INIT_KERNEL_8X16();
- KERNEL_8x16_ACC4();
- SAVE_KERNEL_8X16_F16_F32();
- }
- for (; l < K;) {
- INIT_KERNEL_8X16();
- KERNEL_8x16_ACC1();
- SAVE_KERNEL_8X16_F16_F32();
- }
- c += 16;
- a -= 8 * K;
- }
- sc += ldc * 8;
- c = sc;
- a += 8 * K;
- b = sb;
- }
-}
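The 8x16 tile is about as large as the AArch64 NEON register file allows for this scheme: the whole output tile lives in sixteen 8-lane fp16 accumulators, plus one A vector and two B vectors per step. A back-of-envelope check, with illustrative constant names:

constexpr unsigned kAccumulatorRegs = (8 * 16) / 8; /* 16 q-registers hold the C tile */
constexpr unsigned kOperandRegs = 1 /* va0 */ + 2 /* vb1, vb2 */;
static_assert(kAccumulatorRegs + kOperandRegs <= 32,
              "8x16 tile fits in the 32 NEON vector registers");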
+++ /dev/null
-// SPDX-License-Identifier: Apache-2.0
-/**
- * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
- *
- * @file hgemm_kernel_8x8.h
- * @date 01 April 2024
- * @see https://github.com/nnstreamer/nntrainer
- * @author Sungsik Kong <ss.kong@samsung.com>
- * @bug No known bugs except for NYI items
- * @brief This is half-precision GEMM 8x8 kernel
- *
- */
-
-#include <arm_neon.h>
-#include <assert.h>
-#include <stdlib.h>
-
-#define INIT_KERNEL_8x8() \
- do { \
- v24 = vdupq_n_f16(0.F); \
- v25 = vdupq_n_f16(0.F); \
- v26 = vdupq_n_f16(0.F); \
- v27 = vdupq_n_f16(0.F); \
- v28 = vdupq_n_f16(0.F); \
- v29 = vdupq_n_f16(0.F); \
- v30 = vdupq_n_f16(0.F); \
- v31 = vdupq_n_f16(0.F); \
- } while (0)
-
-// 1. Partial sum 1024 digits
-#define KERNEL_8x8_ACC16() \
- do { \
- va0 = vld1q_f16(a); \
- v16 = vld1q_f16(b); \
- v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
- v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \
- v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \
- v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \
- v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \
- v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \
- v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \
- v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \
- va0 = vld1q_f16(a + 8); \
- v17 = vld1q_f16(b + 8); \
- v24 = vfmaq_laneq_f16(v24, v17, va0, 0); \
- v25 = vfmaq_laneq_f16(v25, v17, va0, 1); \
- v26 = vfmaq_laneq_f16(v26, v17, va0, 2); \
- v27 = vfmaq_laneq_f16(v27, v17, va0, 3); \
- v28 = vfmaq_laneq_f16(v28, v17, va0, 4); \
- v29 = vfmaq_laneq_f16(v29, v17, va0, 5); \
- v30 = vfmaq_laneq_f16(v30, v17, va0, 6); \
- v31 = vfmaq_laneq_f16(v31, v17, va0, 7); \
- va0 = vld1q_f16(a + 8 * 2); \
- v18 = vld1q_f16(b + 8 * 2); \
- v24 = vfmaq_laneq_f16(v24, v18, va0, 0); \
- v25 = vfmaq_laneq_f16(v25, v18, va0, 1); \
- v26 = vfmaq_laneq_f16(v26, v18, va0, 2); \
- v27 = vfmaq_laneq_f16(v27, v18, va0, 3); \
- v28 = vfmaq_laneq_f16(v28, v18, va0, 4); \
- v29 = vfmaq_laneq_f16(v29, v18, va0, 5); \
- v30 = vfmaq_laneq_f16(v30, v18, va0, 6); \
- v31 = vfmaq_laneq_f16(v31, v18, va0, 7); \
- va0 = vld1q_f16(a + 8 * 3); \
- v19 = vld1q_f16(b + 8 * 3); \
- v24 = vfmaq_laneq_f16(v24, v19, va0, 0); \
- v25 = vfmaq_laneq_f16(v25, v19, va0, 1); \
- v26 = vfmaq_laneq_f16(v26, v19, va0, 2); \
- v27 = vfmaq_laneq_f16(v27, v19, va0, 3); \
- v28 = vfmaq_laneq_f16(v28, v19, va0, 4); \
- v29 = vfmaq_laneq_f16(v29, v19, va0, 5); \
- v30 = vfmaq_laneq_f16(v30, v19, va0, 6); \
- v31 = vfmaq_laneq_f16(v31, v19, va0, 7); \
- va0 = vld1q_f16(a + 8 * 4); \
- v20 = vld1q_f16(b + 8 * 4); \
- v24 = vfmaq_laneq_f16(v24, v20, va0, 0); \
- v25 = vfmaq_laneq_f16(v25, v20, va0, 1); \
- v26 = vfmaq_laneq_f16(v26, v20, va0, 2); \
- v27 = vfmaq_laneq_f16(v27, v20, va0, 3); \
- v28 = vfmaq_laneq_f16(v28, v20, va0, 4); \
- v29 = vfmaq_laneq_f16(v29, v20, va0, 5); \
- v30 = vfmaq_laneq_f16(v30, v20, va0, 6); \
- v31 = vfmaq_laneq_f16(v31, v20, va0, 7); \
- va0 = vld1q_f16(a + 8 * 5); \
- v21 = vld1q_f16(b + 8 * 5); \
- v24 = vfmaq_laneq_f16(v24, v21, va0, 0); \
- v25 = vfmaq_laneq_f16(v25, v21, va0, 1); \
- v26 = vfmaq_laneq_f16(v26, v21, va0, 2); \
- v27 = vfmaq_laneq_f16(v27, v21, va0, 3); \
- v28 = vfmaq_laneq_f16(v28, v21, va0, 4); \
- v29 = vfmaq_laneq_f16(v29, v21, va0, 5); \
- v30 = vfmaq_laneq_f16(v30, v21, va0, 6); \
- v31 = vfmaq_laneq_f16(v31, v21, va0, 7); \
- va0 = vld1q_f16(a + 8 * 6); \
- v22 = vld1q_f16(b + 8 * 6); \
- v24 = vfmaq_laneq_f16(v24, v22, va0, 0); \
- v25 = vfmaq_laneq_f16(v25, v22, va0, 1); \
- v26 = vfmaq_laneq_f16(v26, v22, va0, 2); \
- v27 = vfmaq_laneq_f16(v27, v22, va0, 3); \
- v28 = vfmaq_laneq_f16(v28, v22, va0, 4); \
- v29 = vfmaq_laneq_f16(v29, v22, va0, 5); \
- v30 = vfmaq_laneq_f16(v30, v22, va0, 6); \
- v31 = vfmaq_laneq_f16(v31, v22, va0, 7); \
- va0 = vld1q_f16(a + 8 * 7); \
- v23 = vld1q_f16(b + 8 * 7); \
- v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
- v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
- v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
- v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
- v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
- v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
- v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
- v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
- va0 = vld1q_f16(a + 8 * 8); \
- v23 = vld1q_f16(b + 8 * 8); \
- v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
- v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
- v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
- v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
- v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
- v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
- v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
- v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
- va0 = vld1q_f16(a + 8 * 9); \
- v23 = vld1q_f16(b + 8 * 9); \
- v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
- v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
- v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
- v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
- v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
- v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
- v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
- v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
- va0 = vld1q_f16(a + 8 * 10); \
- v23 = vld1q_f16(b + 8 * 10); \
- v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
- v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
- v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
- v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
- v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
- v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
- v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
- v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
- va0 = vld1q_f16(a + 8 * 11); \
- v23 = vld1q_f16(b + 8 * 11); \
- v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
- v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
- v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
- v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
- v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
- v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
- v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
- v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
- va0 = vld1q_f16(a + 8 * 12); \
- v23 = vld1q_f16(b + 8 * 12); \
- v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
- v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
- v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
- v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
- v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
- v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
- v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
- v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
- va0 = vld1q_f16(a + 8 * 13); \
- v23 = vld1q_f16(b + 8 * 13); \
- v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
- v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
- v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
- v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
- v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
- v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
- v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
- v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
- va0 = vld1q_f16(a + 8 * 14); \
- v23 = vld1q_f16(b + 8 * 14); \
- v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
- v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
- v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
- v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
- v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
- v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
- v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
- v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
- va0 = vld1q_f16(a + 8 * 15); \
- v23 = vld1q_f16(b + 8 * 15); \
- v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
- v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
- v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
- v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
- v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
- v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
- v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
- v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
- __builtin_prefetch(b + 128, 0, 3); \
- __builtin_prefetch(a + 128, 0, 3); \
- l += 16; \
- b += 8 * 16; \
- a += 8 * 16; \
- } while (0)
-
-// 2. Partial sum 512 digits
-#define KERNEL_8x8_ACC8() \
- do { \
- va0 = vld1q_f16(a); \
- v16 = vld1q_f16(b); \
- v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
- v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \
- v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \
- v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \
- v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \
- v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \
- v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \
- v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \
- va1 = vld1q_f16(a + 8); \
- v17 = vld1q_f16(b + 8); \
- v24 = vfmaq_laneq_f16(v24, v17, va1, 0); \
- v25 = vfmaq_laneq_f16(v25, v17, va1, 1); \
- v26 = vfmaq_laneq_f16(v26, v17, va1, 2); \
- v27 = vfmaq_laneq_f16(v27, v17, va1, 3); \
- v28 = vfmaq_laneq_f16(v28, v17, va1, 4); \
- v29 = vfmaq_laneq_f16(v29, v17, va1, 5); \
- v30 = vfmaq_laneq_f16(v30, v17, va1, 6); \
- v31 = vfmaq_laneq_f16(v31, v17, va1, 7); \
- va2 = vld1q_f16(a + 16); \
- v18 = vld1q_f16(b + 16); \
- v24 = vfmaq_laneq_f16(v24, v18, va2, 0); \
- v25 = vfmaq_laneq_f16(v25, v18, va2, 1); \
- v26 = vfmaq_laneq_f16(v26, v18, va2, 2); \
- v27 = vfmaq_laneq_f16(v27, v18, va2, 3); \
- v28 = vfmaq_laneq_f16(v28, v18, va2, 4); \
- v29 = vfmaq_laneq_f16(v29, v18, va2, 5); \
- v30 = vfmaq_laneq_f16(v30, v18, va2, 6); \
- v31 = vfmaq_laneq_f16(v31, v18, va2, 7); \
- va3 = vld1q_f16(a + 24); \
- v19 = vld1q_f16(b + 24); \
- v24 = vfmaq_laneq_f16(v24, v19, va3, 0); \
- v25 = vfmaq_laneq_f16(v25, v19, va3, 1); \
- v26 = vfmaq_laneq_f16(v26, v19, va3, 2); \
- v27 = vfmaq_laneq_f16(v27, v19, va3, 3); \
- v28 = vfmaq_laneq_f16(v28, v19, va3, 4); \
- v29 = vfmaq_laneq_f16(v29, v19, va3, 5); \
- v30 = vfmaq_laneq_f16(v30, v19, va3, 6); \
- v31 = vfmaq_laneq_f16(v31, v19, va3, 7); \
- va4 = vld1q_f16(a + 32); \
- v20 = vld1q_f16(b + 32); \
- v24 = vfmaq_laneq_f16(v24, v20, va4, 0); \
- v25 = vfmaq_laneq_f16(v25, v20, va4, 1); \
- v26 = vfmaq_laneq_f16(v26, v20, va4, 2); \
- v27 = vfmaq_laneq_f16(v27, v20, va4, 3); \
- v28 = vfmaq_laneq_f16(v28, v20, va4, 4); \
- v29 = vfmaq_laneq_f16(v29, v20, va4, 5); \
- v30 = vfmaq_laneq_f16(v30, v20, va4, 6); \
- v31 = vfmaq_laneq_f16(v31, v20, va4, 7); \
- va5 = vld1q_f16(a + 40); \
- v21 = vld1q_f16(b + 40); \
- v24 = vfmaq_laneq_f16(v24, v21, va5, 0); \
- v25 = vfmaq_laneq_f16(v25, v21, va5, 1); \
- v26 = vfmaq_laneq_f16(v26, v21, va5, 2); \
- v27 = vfmaq_laneq_f16(v27, v21, va5, 3); \
- v28 = vfmaq_laneq_f16(v28, v21, va5, 4); \
- v29 = vfmaq_laneq_f16(v29, v21, va5, 5); \
- v30 = vfmaq_laneq_f16(v30, v21, va5, 6); \
- v31 = vfmaq_laneq_f16(v31, v21, va5, 7); \
- va6 = vld1q_f16(a + 48); \
- v22 = vld1q_f16(b + 48); \
- v24 = vfmaq_laneq_f16(v24, v22, va6, 0); \
- v25 = vfmaq_laneq_f16(v25, v22, va6, 1); \
- v26 = vfmaq_laneq_f16(v26, v22, va6, 2); \
- v27 = vfmaq_laneq_f16(v27, v22, va6, 3); \
- v28 = vfmaq_laneq_f16(v28, v22, va6, 4); \
- v29 = vfmaq_laneq_f16(v29, v22, va6, 5); \
- v30 = vfmaq_laneq_f16(v30, v22, va6, 6); \
- v31 = vfmaq_laneq_f16(v31, v22, va6, 7); \
- va7 = vld1q_f16(a + 56); \
- v23 = vld1q_f16(b + 56); \
- v24 = vfmaq_laneq_f16(v24, v23, va7, 0); \
- v25 = vfmaq_laneq_f16(v25, v23, va7, 1); \
- v26 = vfmaq_laneq_f16(v26, v23, va7, 2); \
- v27 = vfmaq_laneq_f16(v27, v23, va7, 3); \
- v28 = vfmaq_laneq_f16(v28, v23, va7, 4); \
- v29 = vfmaq_laneq_f16(v29, v23, va7, 5); \
- v30 = vfmaq_laneq_f16(v30, v23, va7, 6); \
- v31 = vfmaq_laneq_f16(v31, v23, va7, 7); \
- __builtin_prefetch(b + 64, 0, 3); \
- __builtin_prefetch(a + 64, 0, 3); \
- l += 8; \
- b += 8 * 8; \
- a += 8 * 8; \
- } while (0)
-
-// 3. Partial sum 256 digits
-#define KERNEL_8x8_ACC4() \
- do { \
- va0 = vld1q_f16(a); \
- v16 = vld1q_f16(b); \
- v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
- v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \
- v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \
- v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \
- v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \
- v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \
- v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \
- v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \
- va1 = vld1q_f16(a + 8); \
- v17 = vld1q_f16(b + 8); \
- v24 = vfmaq_laneq_f16(v24, v17, va1, 0); \
- v25 = vfmaq_laneq_f16(v25, v17, va1, 1); \
- v26 = vfmaq_laneq_f16(v26, v17, va1, 2); \
- v27 = vfmaq_laneq_f16(v27, v17, va1, 3); \
- v28 = vfmaq_laneq_f16(v28, v17, va1, 4); \
- v29 = vfmaq_laneq_f16(v29, v17, va1, 5); \
- v30 = vfmaq_laneq_f16(v30, v17, va1, 6); \
- v31 = vfmaq_laneq_f16(v31, v17, va1, 7); \
- va2 = vld1q_f16(a + 16); \
- v18 = vld1q_f16(b + 16); \
- v24 = vfmaq_laneq_f16(v24, v18, va2, 0); \
- v25 = vfmaq_laneq_f16(v25, v18, va2, 1); \
- v26 = vfmaq_laneq_f16(v26, v18, va2, 2); \
- v27 = vfmaq_laneq_f16(v27, v18, va2, 3); \
- v28 = vfmaq_laneq_f16(v28, v18, va2, 4); \
- v29 = vfmaq_laneq_f16(v29, v18, va2, 5); \
- v30 = vfmaq_laneq_f16(v30, v18, va2, 6); \
- v31 = vfmaq_laneq_f16(v31, v18, va2, 7); \
- va3 = vld1q_f16(a + 24); \
- v19 = vld1q_f16(b + 24); \
- v24 = vfmaq_laneq_f16(v24, v19, va3, 0); \
- v25 = vfmaq_laneq_f16(v25, v19, va3, 1); \
- v26 = vfmaq_laneq_f16(v26, v19, va3, 2); \
- v27 = vfmaq_laneq_f16(v27, v19, va3, 3); \
- v28 = vfmaq_laneq_f16(v28, v19, va3, 4); \
- v29 = vfmaq_laneq_f16(v29, v19, va3, 5); \
- v30 = vfmaq_laneq_f16(v30, v19, va3, 6); \
- v31 = vfmaq_laneq_f16(v31, v19, va3, 7); \
- __builtin_prefetch(b + 32, 0, 3); \
- __builtin_prefetch(a + 32, 0, 3); \
- l += 4; \
- b += 8 * 4; \
- a += 8 * 4; \
- } while (0)
-
-// 4. Partial sum 64 digits
-#define KERNEL_8x8_ACC1() \
- do { \
- va0 = vld1q_f16(a); \
- v16 = vld1q_f16(b); \
- v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
- v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \
- v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \
- v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \
- v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \
- v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \
- v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \
- v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \
- __builtin_prefetch(b + 8, 0, 3); \
- __builtin_prefetch(a + 8, 0, 3); \
- l += 1; \
- b += 8 * 1; \
- a += 8 * 1; \
- } while (0)
-
-#define SAVE_KERNEL_8X8_F16_f32() \
- do { \
- vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v24)))); \
- vst1q_f32(c + 4, \
- vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v24)))); \
- \
- vst1q_f32(c + ldc, \
- vaddq_f32(vld1q_f32(c + ldc), vcvt_f32_f16(vget_low_f16(v25)))); \
- vst1q_f32(c + 4 + ldc, vaddq_f32(vld1q_f32(c + 4 + ldc), \
- vcvt_f32_f16(vget_high_f16(v25)))); \
- \
- vst1q_f32(c + 2 * ldc, vaddq_f32(vld1q_f32(c + 2 * ldc), \
- vcvt_f32_f16(vget_low_f16(v26)))); \
- vst1q_f32(c + 4 + 2 * ldc, vaddq_f32(vld1q_f32(c + 4 + 2 * ldc), \
- vcvt_f32_f16(vget_high_f16(v26)))); \
- \
- vst1q_f32(c + 3 * ldc, vaddq_f32(vld1q_f32(c + 3 * ldc), \
- vcvt_f32_f16(vget_low_f16(v27)))); \
- vst1q_f32(c + 4 + 3 * ldc, vaddq_f32(vld1q_f32(c + 4 + 3 * ldc), \
- vcvt_f32_f16(vget_high_f16(v27)))); \
- \
- vst1q_f32(c + 4 * ldc, vaddq_f32(vld1q_f32(c + 4 * ldc), \
- vcvt_f32_f16(vget_low_f16(v28)))); \
- vst1q_f32(c + 4 + 4 * ldc, vaddq_f32(vld1q_f32(c + 4 + 4 * ldc), \
- vcvt_f32_f16(vget_high_f16(v28)))); \
- \
- vst1q_f32(c + 5 * ldc, vaddq_f32(vld1q_f32(c + 5 * ldc), \
- vcvt_f32_f16(vget_low_f16(v29)))); \
- vst1q_f32(c + 4 + 5 * ldc, vaddq_f32(vld1q_f32(c + 4 + 5 * ldc), \
- vcvt_f32_f16(vget_high_f16(v29)))); \
- \
- vst1q_f32(c + 6 * ldc, vaddq_f32(vld1q_f32(c + 6 * ldc), \
- vcvt_f32_f16(vget_low_f16(v30)))); \
- vst1q_f32(c + 4 + 6 * ldc, vaddq_f32(vld1q_f32(c + 4 + 6 * ldc), \
- vcvt_f32_f16(vget_high_f16(v30)))); \
- \
- vst1q_f32(c + 7 * ldc, vaddq_f32(vld1q_f32(c + 7 * ldc), \
- vcvt_f32_f16(vget_low_f16(v31)))); \
- vst1q_f32(c + 4 + 7 * ldc, vaddq_f32(vld1q_f32(c + 4 + 7 * ldc), \
- vcvt_f32_f16(vget_high_f16(v31)))); \
- } while (0)
-
-/**
- * @brief hgemm 8x8 kernel sc = sa * sb
- *
- * @param m length of the row of matrix A
- * @param n length of the col of matrix B
- * @param k length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading-dimension of matrix C
- */
-void hgemm_kernel_8x8(unsigned int M, unsigned int N, unsigned int K,
- __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
- assert(M > 0 && N > 0 && K > 0);
- assert(M % 8 == 0 && N % 8 == 0 && K % 4 == 0);
-
- __fp16 *a = sa, *b = sb, *c = sc;
- unsigned int i, j, l;
- for (i = 0; i < M; i += 8) {
- for (j = 0; j < N; j += 8) {
- __builtin_prefetch(b, 0, 3);
- __builtin_prefetch(a, 0, 3);
-
- float16x8_t v16, v17, v18, v19, v20, v21, v22, v23;
- float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
- float16x8_t va0, va1, va2, va3, va4, va5, va6, va7;
- INIT_KERNEL_8x8();
- l = 0;
- for (; l < K;) {
- KERNEL_8x8_ACC1();
- }
- vst1q_f16(c, vaddq_f16(vld1q_f16(c), v24));
- vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v25));
- vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v26));
- vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v27));
- vst1q_f16(c + 4 * ldc, vaddq_f16(vld1q_f16(c + 4 * ldc), v28));
- vst1q_f16(c + 5 * ldc, vaddq_f16(vld1q_f16(c + 5 * ldc), v29));
- vst1q_f16(c + 6 * ldc, vaddq_f16(vld1q_f16(c + 6 * ldc), v30));
- vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v31));
- c += 8;
- a -= 8 * K;
- }
- sc += ldc * 8;
- c = sc;
- a += 8 * K;
- b = sb;
- }
-}
-
-/**
- * @brief hgemm 8x8 kernel sc = sa * sb
- *
- * @param m length of the row of matrix A
- * @param n length of the col of matrix B
- * @param k length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading-dimension of matrix C
- */
-void hgemm_kernel_8x8(unsigned int M, unsigned int N, unsigned int K,
- __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
- assert(M > 0 && N > 0 && K > 0);
- assert(M % 8 == 0 && N % 8 == 0 && K % 8 == 0);
-
- __fp16 *a = sa, *b = sb;
- float *c = sc;
- unsigned int i, j, l;
- unsigned int K4 = (K >> 2) << 2;
- unsigned int K8 = (K >> 3) << 3;
- unsigned int K16 = (K >> 4) << 4;
- for (i = 0; i < M; i += 8) {
- for (j = 0; j < N; j += 8) {
- __builtin_prefetch(b, 0, 3);
- __builtin_prefetch(a, 0, 3);
-
- float16x8_t v16, v17, v18, v19, v20, v21, v22, v23;
- float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
- float16x8_t va0, va1, va2, va3, va4, va5, va6, va7;
- l = 0;
- for (; l < K16;) {
- INIT_KERNEL_8x8();
- KERNEL_8x8_ACC16();
- SAVE_KERNEL_8X8_F16_f32();
- }
- for (; l < K8;) {
- INIT_KERNEL_8x8();
- KERNEL_8x8_ACC8();
- SAVE_KERNEL_8X8_F16_f32();
- }
- for (; l < K4;) {
- INIT_KERNEL_8x8();
- KERNEL_8x8_ACC4();
- SAVE_KERNEL_8X8_F16_f32();
- }
- for (; l < K;) {
- INIT_KERNEL_8x8();
- KERNEL_8x8_ACC1();
- SAVE_KERNEL_8X8_F16_f32();
- }
-
- c += 8;
- a -= 8 * K;
- }
- sc += ldc * 8;
- c = sc;
- a += 8 * K;
- b = sb;
- }
-}
+++ /dev/null
-// SPDX-License-Identifier: Apache-2.0
-/**
- * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
- *
- * @file hgemm_kernel_pack.cpp
- * @date 02 July 2024
- * @see https://github.com/nnstreamer/nntrainer
- * @author Sungsik Kong <ss.kong@samsung.com>
- * @bug No known bugs except for NYI items
- * @brief This is a source file for half-precision packing for the matrix
- * multiplication
- */
-
-#include <assert.h>
-#include <hgemm_common.h>
-#include <hgemm_kernel_pack.h>
-#include <matrix_transpose_neon.h>
-
-void packing_A1(unsigned int m, unsigned int k, const __fp16 *from,
- unsigned int lda, const __fp16 *to) {
-
- assert(k != 0 && m != 0 && k % 4 == 0 && m % 4 == 0);
- unsigned int i, j;
-
- __fp16 *a_offset, *a_offset1, *a_offset2, *a_offset3, *a_offset4;
- __fp16 *b_offset;
- __fp16 ctemp1, ctemp2, ctemp3, ctemp4;
-
- a_offset = (__fp16 *)from;
- b_offset = (__fp16 *)to;
-
- j = m;
- do {
- a_offset1 = a_offset;
- a_offset += lda;
-
- i = (k >> 2);
- do {
- ctemp1 = *(a_offset1 + 0);
- ctemp2 = *(a_offset1 + 1);
- ctemp3 = *(a_offset1 + 2);
- ctemp4 = *(a_offset1 + 3);
-
- *(b_offset + 0) = ctemp1;
- *(b_offset + 1) = ctemp2;
- *(b_offset + 2) = ctemp3;
- *(b_offset + 3) = ctemp4;
-
- a_offset1 += 4;
-
- b_offset += 4;
- i--;
- } while (i > 0);
- j--;
- } while (j > 0);
-}
-
-void packing_A4(unsigned int M, unsigned int K, const __fp16 *src,
- unsigned int lda, const __fp16 *dst) {
-
- assert(K != 0 && M != 0 && K % 4 == 0 && M % 4 == 0);
- unsigned int i, j;
-
- __fp16 *a_off, *a_off1, *a_off2, *a_off3, *a_off4;
- __fp16 *b_off;
- __fp16 c1, c2, c3, c4;
- __fp16 c5, c6, c7, c8;
- __fp16 c9, c10, c11, c12;
- __fp16 c13, c14, c15, c16;
-
- a_off = (__fp16 *)src;
- b_off = (__fp16 *)dst;
-
- j = (M >> 2);
- do {
- a_off1 = a_off;
- a_off2 = a_off1 + lda;
- a_off3 = a_off2 + lda;
- a_off4 = a_off3 + lda;
- a_off += 4 * lda;
-
- i = (K >> 2);
- do {
- c1 = *(a_off1 + 0);
- c2 = *(a_off1 + 1);
- c3 = *(a_off1 + 2);
- c4 = *(a_off1 + 3);
-
- c5 = *(a_off2 + 0);
- c6 = *(a_off2 + 1);
- c7 = *(a_off2 + 2);
- c8 = *(a_off2 + 3);
-
- c9 = *(a_off3 + 0);
- c10 = *(a_off3 + 1);
- c11 = *(a_off3 + 2);
- c12 = *(a_off3 + 3);
-
- c13 = *(a_off4 + 0);
- c14 = *(a_off4 + 1);
- c15 = *(a_off4 + 2);
- c16 = *(a_off4 + 3);
-
- *(b_off + 0) = c1;
- *(b_off + 1) = c5;
- *(b_off + 2) = c9;
- *(b_off + 3) = c13;
-
- *(b_off + 4) = c2;
- *(b_off + 5) = c6;
- *(b_off + 6) = c10;
- *(b_off + 7) = c14;
-
- *(b_off + 8) = c3;
- *(b_off + 9) = c7;
- *(b_off + 10) = c11;
- *(b_off + 11) = c15;
-
- *(b_off + 12) = c4;
- *(b_off + 13) = c8;
- *(b_off + 14) = c12;
- *(b_off + 15) = c16;
-
- a_off1 += 4;
- a_off2 += 4;
- a_off3 += 4;
- a_off4 += 4;
-
- b_off += 16;
- i--;
- } while (i > 0);
- j--;
- } while (j > 0);
-}
-
-void packing_A8(unsigned int M, unsigned int K, const __fp16 *src,
- unsigned int lda, const __fp16 *dst) {
-
- assert(K != 0 && M != 0 && K % 8 == 0 && M % 8 == 0);
-
- uint16x4_t msk = {0xFFFF, 0xFFFF, 0x0000, 0x0000};
- uint16x4_t inv_msk = {0x0000, 0x0000, 0xFFFF, 0xFFFF};
-
- const __fp16 *a_off = (__fp16 *)src;
- __fp16 *b_off = (__fp16 *)dst;
-
- for (unsigned int i = 0; i < M; i += 8) {
- const __fp16 *a_off1 = a_off;
- const __fp16 *a_off2 = a_off1 + lda;
- const __fp16 *a_off3 = a_off2 + lda;
- const __fp16 *a_off4 = a_off3 + lda;
- const __fp16 *a_off5 = a_off4 + lda;
- const __fp16 *a_off6 = a_off5 + lda;
- const __fp16 *a_off7 = a_off6 + lda;
- const __fp16 *a_off8 = a_off7 + lda;
- a_off += 8 * lda;
-
- for (unsigned int j = 0; j < K; j += 8) {
- float16x8_t _v0 = vld1q_f16(a_off1);
- float16x8_t _v1 = vld1q_f16(a_off2);
- float16x8_t _v2 = vld1q_f16(a_off3);
- float16x8_t _v3 = vld1q_f16(a_off4);
-
- float16x8_t _v4 = vld1q_f16(a_off5);
- float16x8_t _v5 = vld1q_f16(a_off6);
- float16x8_t _v6 = vld1q_f16(a_off7);
- float16x8_t _v7 = vld1q_f16(a_off8);
-
- a_off1 += 8;
- a_off2 += 8;
- a_off3 += 8;
- a_off4 += 8;
- a_off5 += 8;
- a_off6 += 8;
- a_off7 += 8;
- a_off8 += 8;
-
- float16x8x2_t _vv0 = vtrnq_f16(_v0, _v1);
- float16x8x2_t _vv1 = vtrnq_f16(_v2, _v3);
- float16x8x2_t _vv2 = vtrnq_f16(_v4, _v5);
- float16x8x2_t _vv3 = vtrnq_f16(_v6, _v7);
-
- float16x8_t _v8 =
- vcombine_f16(vget_low_f16(_vv0.val[0]), vget_low_f16(_vv1.val[0]));
- float16x8_t _v9 =
- vcombine_f16(vget_low_f16(_vv0.val[1]), vget_low_f16(_vv1.val[1]));
- float16x8_t _v10 =
- vcombine_f16(vget_high_f16(_vv0.val[0]), vget_high_f16(_vv1.val[0]));
- float16x8_t _v11 =
- vcombine_f16(vget_high_f16(_vv0.val[1]), vget_high_f16(_vv1.val[1]));
-
- float16x8_t _v12 =
- vcombine_f16(vget_low_f16(_vv2.val[0]), vget_low_f16(_vv3.val[0]));
- float16x8_t _v13 =
- vcombine_f16(vget_low_f16(_vv2.val[1]), vget_low_f16(_vv3.val[1]));
- float16x8_t _v14 =
- vcombine_f16(vget_high_f16(_vv2.val[0]), vget_high_f16(_vv3.val[0]));
- float16x8_t _v15 =
- vcombine_f16(vget_high_f16(_vv2.val[1]), vget_high_f16(_vv3.val[1]));
-
- // pack-in-pack
- float16x4_t tmp_low_v8 = vget_low_f16(_v8);
- float16x4_t tmp_high_v8 = vget_high_f16(_v8);
- float16x4_t mid_v8 = vext_f16(tmp_low_v8, tmp_high_v8, 2);
-
- float16x4_t tmp_low_v9 = vget_low_f16(_v9);
- float16x4_t tmp_high_v9 = vget_high_f16(_v9);
- float16x4_t mid_v9 = vext_f16(tmp_low_v9, tmp_high_v9, 2);
-
- float16x4_t tmp_low_v10 = vget_low_f16(_v10);
- float16x4_t tmp_high_v10 = vget_high_f16(_v10);
- float16x4_t mid_v10 = vext_f16(tmp_low_v10, tmp_high_v10, 2);
-
- float16x4_t tmp_low_v11 = vget_low_f16(_v11);
- float16x4_t tmp_high_v11 = vget_high_f16(_v11);
- float16x4_t mid_v11 = vext_f16(tmp_low_v11, tmp_high_v11, 2);
-
- float16x4_t tmp_low_v12 = vget_low_f16(_v12);
- float16x4_t tmp_high_v12 = vget_high_f16(_v12);
- float16x4_t mid_v12 = vext_f16(tmp_low_v12, tmp_high_v12, 2);
-
- float16x4_t tmp_low_v13 = vget_low_f16(_v13);
- float16x4_t tmp_high_v13 = vget_high_f16(_v13);
- float16x4_t mid_v13 = vext_f16(tmp_low_v13, tmp_high_v13, 2);
-
- float16x4_t tmp_low_v14 = vget_low_f16(_v14);
- float16x4_t tmp_high_v14 = vget_high_f16(_v14);
- float16x4_t mid_v14 = vext_f16(tmp_low_v14, tmp_high_v14, 2);
-
- float16x4_t tmp_low_v15 = vget_low_f16(_v15);
- float16x4_t tmp_high_v15 = vget_high_f16(_v15);
- float16x4_t mid_v15 = vext_f16(tmp_low_v15, tmp_high_v15, 2);
-
- _v8 = vcombine_f16(vbsl_f16(msk, tmp_low_v8, mid_v8),
- vbsl_f16(msk, tmp_low_v12, mid_v12));
- _v12 = vcombine_f16(vbsl_f16(msk, tmp_low_v9, mid_v9),
- vbsl_f16(msk, tmp_low_v13, mid_v13));
- _v9 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v8, mid_v8),
- vbsl_f16(inv_msk, tmp_high_v12, mid_v12));
- _v13 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v9, mid_v9),
- vbsl_f16(inv_msk, tmp_high_v13, mid_v13));
- _v10 = vcombine_f16(vbsl_f16(msk, tmp_low_v10, mid_v10),
- vbsl_f16(msk, tmp_low_v14, mid_v14));
- _v14 = vcombine_f16(vbsl_f16(msk, tmp_low_v11, mid_v11),
- vbsl_f16(msk, tmp_low_v15, mid_v15));
- _v11 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v10, mid_v10),
- vbsl_f16(inv_msk, tmp_high_v14, mid_v14));
- _v15 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v11, mid_v11),
- vbsl_f16(inv_msk, tmp_high_v15, mid_v15));
-
- vst1q_f16(b_off + 0, _v8);
- vst1q_f16(b_off + 8, _v12);
- vst1q_f16(b_off + 16, _v9);
- vst1q_f16(b_off + 24, _v13);
- vst1q_f16(b_off + 32, _v10);
- vst1q_f16(b_off + 40, _v14);
- vst1q_f16(b_off + 48, _v11);
- vst1q_f16(b_off + 56, _v15);
- b_off += 64;
- }
- }
-}
-
-void packing_B1(unsigned int K, unsigned int N, const __fp16 *src,
- unsigned int ldb, const __fp16 *dst) {
- assert(K != 0 && N != 0 && N % 8 == 0);
-
- for (int i = 0; i < K; i++) {
- const __fp16 *a_off = src + i * ldb;
- __fp16 *b_off = (__fp16 *)dst + i;
- for (int j = 0; j < N; j++) {
- float16_t v = *(a_off);
- a_off++;
-
- *b_off = v;
- b_off += K;
- }
- }
-}
-
-void packing_B4(unsigned int K, unsigned int N, const __fp16 *src,
- unsigned int ldb, const __fp16 *dst) {
- assert(K != 0 && N != 0 && K % 4 == 0 && N % 4 == 0);
- unsigned int i, j;
-
- __fp16 *a_off, *a_off1, *a_off2, *a_off3, *a_off4;
- __fp16 *b_off, *b_off1;
- __fp16 c1, c2, c3, c4;
- __fp16 c5, c6, c7, c8;
- __fp16 c9, c10, c11, c12;
- __fp16 c13, c14, c15, c16;
- a_off = (__fp16 *)src;
- b_off = (__fp16 *)dst;
-
- j = (K >> 2);
- do {
- a_off1 = a_off;
- a_off2 = a_off1 + ldb;
- a_off3 = a_off2 + ldb;
- a_off4 = a_off3 + ldb;
- a_off += 4 * ldb;
-
- b_off1 = b_off;
- b_off += 16;
-
- i = (N >> 2);
- do {
- c1 = *(a_off1 + 0);
- c2 = *(a_off1 + 1);
- c3 = *(a_off1 + 2);
- c4 = *(a_off1 + 3);
-
- c5 = *(a_off2 + 0);
- c6 = *(a_off2 + 1);
- c7 = *(a_off2 + 2);
- c8 = *(a_off2 + 3);
-
- c9 = *(a_off3 + 0);
- c10 = *(a_off3 + 1);
- c11 = *(a_off3 + 2);
- c12 = *(a_off3 + 3);
-
- c13 = *(a_off4 + 0);
- c14 = *(a_off4 + 1);
- c15 = *(a_off4 + 2);
- c16 = *(a_off4 + 3);
-
- a_off1 += 4;
- a_off2 += 4;
- a_off3 += 4;
- a_off4 += 4;
-
- *(b_off1 + 0) = c1;
- *(b_off1 + 1) = c2;
- *(b_off1 + 2) = c3;
- *(b_off1 + 3) = c4;
-
- *(b_off1 + 4) = c5;
- *(b_off1 + 5) = c6;
- *(b_off1 + 6) = c7;
- *(b_off1 + 7) = c8;
-
- *(b_off1 + 8) = c9;
- *(b_off1 + 9) = c10;
- *(b_off1 + 10) = c11;
- *(b_off1 + 11) = c12;
-
- *(b_off1 + 12) = c13;
- *(b_off1 + 13) = c14;
- *(b_off1 + 14) = c15;
- *(b_off1 + 15) = c16;
-
- b_off1 += K * 4;
- i--;
- } while (i > 0);
- j--;
- } while (j > 0);
-}
-
-void packing_B8(unsigned int K, unsigned int N, const __fp16 *src,
- unsigned int ldb, const __fp16 *dst) {
- assert(K != 0 && N != 0 && N % 8 == 0);
-
- for (int i = 0; i < K; i++) {
- const __fp16 *a_off = src + i * ldb;
- __fp16 *b_off = (__fp16 *)dst + i * 8;
- for (int j = 0; j < N; j += 8) {
- float16x8_t v = vld1q_f16(a_off);
- a_off += 8;
-
- vst1q_f16(b_off, v);
- b_off += 8 * K;
- }
- }
-}
-
-void packing_B16(unsigned int K, unsigned int N, const __fp16 *src,
- unsigned int ldb, const __fp16 *dst) {
- assert(K != 0 && N != 0 && N % 16 == 0);
-
- for (int i = 0; i < K; i++) {
- const __fp16 *a_off = src + i * ldb;
- __fp16 *b_off = (__fp16 *)dst + i * 16;
- for (int j = 0; j < N; j += 16) {
- float16x8_t v0_7 = vld1q_f16(a_off);
- float16x8_t v8_15 = vld1q_f16(a_off + 8);
- a_off += 16;
-
- vst1q_f16(b_off, v0_7);
- vst1q_f16(b_off + 8, v8_15);
- b_off += 16 * K;
- }
- }
-}
-
-void packing_transB16(unsigned int K, unsigned int N, const __fp16 *src,
- unsigned int ldb, const __fp16 *dst) {
- /// @note ldb = K for here
- assert(K != 0 && N != 0 && N % 16 == 0);
- unsigned int K8 = (K >> 3) << 3;
-
- const __fp16 *src_off = (__fp16 *)src;
- __fp16 *dst_off = (__fp16 *)dst;
-
- const unsigned int ld_tile_T = 16;
- __fp16 *tile_T = new __fp16[8 * ld_tile_T];
- // __fp16 *tile_T = alignedMalloc(8 * ld_tile_T);
-
- // 1. Do something like 8x16 transpose kernel
- // 2. Save linearized transposed output tile to dst
- for (unsigned int n = 0; n < N; n += 16) {
- const __fp16 *src_off1 = src_off;
- __fp16 *dst_off1 = dst_off;
- src_off += 16 * ldb;
- dst_off += (K8 * 16 + (K - K8)); // ?
- for (unsigned int k = 0; k < K8; k += 8) {
- // 16x8 tile -> 8x16
- transpose_neon<__fp16>(16, 8, src_off1, ldb, tile_T, ld_tile_T);
-
- // Store with correct packing order linearly
- vst1q_f16(&dst_off1[0], vld1q_f16(&tile_T[0 * ld_tile_T + 0]));
- vst1q_f16(&dst_off1[8], vld1q_f16(&tile_T[0 * ld_tile_T + 8]));
- vst1q_f16(&dst_off1[16], vld1q_f16(&tile_T[1 * ld_tile_T + 0]));
- vst1q_f16(&dst_off1[24], vld1q_f16(&tile_T[1 * ld_tile_T + 8]));
- vst1q_f16(&dst_off1[32], vld1q_f16(&tile_T[2 * ld_tile_T + 0]));
- vst1q_f16(&dst_off1[40], vld1q_f16(&tile_T[2 * ld_tile_T + 8]));
- vst1q_f16(&dst_off1[48], vld1q_f16(&tile_T[3 * ld_tile_T + 0]));
- vst1q_f16(&dst_off1[56], vld1q_f16(&tile_T[3 * ld_tile_T + 8]));
- vst1q_f16(&dst_off1[64], vld1q_f16(&tile_T[4 * ld_tile_T + 0]));
- vst1q_f16(&dst_off1[72], vld1q_f16(&tile_T[4 * ld_tile_T + 8]));
- vst1q_f16(&dst_off1[80], vld1q_f16(&tile_T[5 * ld_tile_T + 0]));
- vst1q_f16(&dst_off1[88], vld1q_f16(&tile_T[5 * ld_tile_T + 8]));
- vst1q_f16(&dst_off1[96], vld1q_f16(&tile_T[6 * ld_tile_T + 0]));
- vst1q_f16(&dst_off1[104], vld1q_f16(&tile_T[6 * ld_tile_T + 8]));
- vst1q_f16(&dst_off1[112], vld1q_f16(&tile_T[7 * ld_tile_T + 0]));
- vst1q_f16(&dst_off1[120], vld1q_f16(&tile_T[7 * ld_tile_T + 8]));
-
- dst_off1 += 16 * 8;
- src_off1 += 8;
- }
-
- // Do the equivalent of one by one for the rest
- for (unsigned int k = K8; k < K; ++k) {
- for (unsigned int _n = 0; _n < 16; ++_n) {
- dst_off1[_n] = src_off1[k];
- }
- }
- }
-}
+++ /dev/null
-// SPDX-License-Identifier: Apache-2.0
-/**
- * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
- *
- * @file hgemm_kernel_pack.h
- * @date 01 April 2024
- * @see https://github.com/nnstreamer/nntrainer
- * @author Sungsik Kong <ss.kong@samsung.com>
- * @author Debadri Samaddar <s.debadri@samsung.com>
- * @bug No known bugs except for NYI items
- * @brief This is for half-precision packing for kernel-based GEMM
- */
-
-
-/**
- * @brief packing function of input matrix A
- *
- * @param M length of the row of the matrix
- * @param K length of the col of the matrix
- * @param src input of original source of the matrix
- * @param lda leading dimension of the matrix
- * @param dst output of packed data of the matrix
- */
-void packing_A1(unsigned int m, unsigned int k, const __fp16 *from,
- unsigned int lda, const __fp16 *to);
-/**
- * @brief packing function of input matrix A
- *
- * @param M length of the row of the matrix
- * @param K length of the col of the matrix
- * @param src input of original source of the matrix
- * @param lda leading dimension of the matrix
- * @param dst output of packed data of the matrix
- */
-void packing_A4(unsigned int M, unsigned int K, const __fp16 *src,
- unsigned int lda, const __fp16 *dst);
-/**
- * @brief packing function of input matrix A
- *
- * @param M length of the row of the matrix
- * @param K length of the col of the matrix
- * @param src input of original source of the matrix
- * @param lda leading dimension of the matrix
- * @param dst output of packed data of the matrix
- */
-void packing_A8(unsigned int M, unsigned int K, const __fp16 *src,
- unsigned int lda, const __fp16 *dst);
-/**
- * @brief packing function of input matrix B
- *
- * @param M length of the row of the matrix
- * @param K length of the col of the matrix
- * @param src input of original source of the matrix
- * @param ldb leading dimension of the matrix
- * @param dst output of packed data of the matrix
- */
-void packing_B1(unsigned int K, unsigned int N, const __fp16 *src,
- unsigned int ldb, const __fp16 *dst);
-/**
- * @brief packing function of input matrix B
- *
- * @param M length of the row of the matrix
- * @param K length of the col of the matrix
- * @param src input of original source of the matrix
- * @param ldb leading dimension of the matrix
- * @param dst output of packed data of the matrix
- */
-void packing_B4(unsigned int K, unsigned int N, const __fp16 *src,
- unsigned int ldb, const __fp16 *dst);
-/**
- * @brief packing function of input matrix B
- *
- * @param M length of the row of the matrix
- * @param K length of the col of the matrix
- * @param src input of original source of the matrix
- * @param ldb leading dimension of the matrix
- * @param dst output of packed data of the matrix
- */
-void packing_B8(unsigned int K, unsigned int N, const __fp16 *src,
- unsigned int ldb, const __fp16 *dst);
-/**
- * @brief packing function of input matrix B
- *
- * @param M length of the row of the matrix
- * @param K length of the col of the matrix
- * @param src input of original source of the matrix
- * @param ldb leading dimension of the matrix
- * @param dst output of packed data of the matrix
- */
-void packing_B16(unsigned int K, unsigned int N, const __fp16 *src,
- unsigned int ldb, const __fp16 *dst);
-/**
- * @brief
- *
- * @param K
- * @param N
- * @param src
- * @param ldb
- * @param dst
- */
-void packing_transB16(unsigned int K, unsigned int N, const __fp16 *src,
- unsigned int ldb, const __fp16 *dst);
*
*/
+#include <arm_neon.h>
#include <cmath>
-
-#include <hgemm_kernel_pack.h>
+#include <hgemm_common.h>
+#include <hgemm_kernel.h>
#include <hgemm_noTrans.h>
+#include <hgemm_pack.h>
#include <hgemm_util.h>
#include <limits>
-// #include <hgemm_kernel.h>
-
#include <matrix_transpose_neon.h>
-#include <hgemm_common.h>
-
-#include <hgemm_kernel_1x4.h>
-#include <hgemm_kernel_1x8.h>
-#include <hgemm_kernel_4x4.h>
-#include <hgemm_kernel_4x8.h>
-#include <hgemm_kernel_8x16.h>
-#include <hgemm_kernel_8x8.h>
-
-
void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M,
unsigned int N, unsigned int K, float alpha, float beta) {
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file hgemm_pack.cpp
+ * @date 02 July 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This is a source file for half-precision packing for the matrix
+ * multiplication
+ */
+
+#include <arm_neon.h>
+#include <assert.h>
+#include <hgemm_common.h>
+#include <hgemm_pack.h>
+#include <matrix_transpose_neon.h>
+
+void packing_A1(unsigned int m, unsigned int k, const __fp16 *from,
+ unsigned int lda, const __fp16 *to) {
+
+ assert(k != 0 && m != 0 && k % 4 == 0 && m % 4 == 0);
+ unsigned int i, j;
+
+ __fp16 *a_offset, *a_offset1, *a_offset2, *a_offset3, *a_offset4;
+ __fp16 *b_offset;
+ __fp16 ctemp1, ctemp2, ctemp3, ctemp4;
+
+ a_offset = (__fp16 *)from;
+ b_offset = (__fp16 *)to;
+
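+  // packing_A1 copies A row by row in 4-element chunks without any
+  // interleaving; each packed row stays contiguous in the destination.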
+ j = m;
+ do {
+ a_offset1 = a_offset;
+ a_offset += lda;
+
+ i = (k >> 2);
+ do {
+ ctemp1 = *(a_offset1 + 0);
+ ctemp2 = *(a_offset1 + 1);
+ ctemp3 = *(a_offset1 + 2);
+ ctemp4 = *(a_offset1 + 3);
+
+ *(b_offset + 0) = ctemp1;
+ *(b_offset + 1) = ctemp2;
+ *(b_offset + 2) = ctemp3;
+ *(b_offset + 3) = ctemp4;
+
+ a_offset1 += 4;
+
+ b_offset += 4;
+ i--;
+ } while (i > 0);
+ j--;
+ } while (j > 0);
+}
+
+void packing_A4(unsigned int M, unsigned int K, const __fp16 *src,
+ unsigned int lda, const __fp16 *dst) {
+
+ assert(K != 0 && M != 0 && K % 4 == 0 && M % 4 == 0);
+ unsigned int i, j;
+
+ __fp16 *a_off, *a_off1, *a_off2, *a_off3, *a_off4;
+ __fp16 *b_off;
+ __fp16 c1, c2, c3, c4;
+ __fp16 c5, c6, c7, c8;
+ __fp16 c9, c10, c11, c12;
+ __fp16 c13, c14, c15, c16;
+
+ a_off = (__fp16 *)src;
+ b_off = (__fp16 *)dst;
+
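+  // Pack A in 4x4 tiles: each tile is read row-wise from A and stored
+  // column-interleaved, so the four values belonging to one K index across
+  // four rows of A end up contiguous in the packed buffer.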
+ j = (M >> 2);
+ do {
+ a_off1 = a_off;
+ a_off2 = a_off1 + lda;
+ a_off3 = a_off2 + lda;
+ a_off4 = a_off3 + lda;
+ a_off += 4 * lda;
+
+ i = (K >> 2);
+ do {
+ c1 = *(a_off1 + 0);
+ c2 = *(a_off1 + 1);
+ c3 = *(a_off1 + 2);
+ c4 = *(a_off1 + 3);
+
+ c5 = *(a_off2 + 0);
+ c6 = *(a_off2 + 1);
+ c7 = *(a_off2 + 2);
+ c8 = *(a_off2 + 3);
+
+ c9 = *(a_off3 + 0);
+ c10 = *(a_off3 + 1);
+ c11 = *(a_off3 + 2);
+ c12 = *(a_off3 + 3);
+
+ c13 = *(a_off4 + 0);
+ c14 = *(a_off4 + 1);
+ c15 = *(a_off4 + 2);
+ c16 = *(a_off4 + 3);
+
+ *(b_off + 0) = c1;
+ *(b_off + 1) = c5;
+ *(b_off + 2) = c9;
+ *(b_off + 3) = c13;
+
+ *(b_off + 4) = c2;
+ *(b_off + 5) = c6;
+ *(b_off + 6) = c10;
+ *(b_off + 7) = c14;
+
+ *(b_off + 8) = c3;
+ *(b_off + 9) = c7;
+ *(b_off + 10) = c11;
+ *(b_off + 11) = c15;
+
+ *(b_off + 12) = c4;
+ *(b_off + 13) = c8;
+ *(b_off + 14) = c12;
+ *(b_off + 15) = c16;
+
+ a_off1 += 4;
+ a_off2 += 4;
+ a_off3 += 4;
+ a_off4 += 4;
+
+ b_off += 16;
+ i--;
+ } while (i > 0);
+ j--;
+ } while (j > 0);
+}
+
+void packing_A8(unsigned int M, unsigned int K, const __fp16 *src,
+ unsigned int lda, const __fp16 *dst) {
+
+ assert(K != 0 && M != 0 && K % 8 == 0 && M % 8 == 0);
+
+ uint16x4_t msk = {0xFFFF, 0xFFFF, 0x0000, 0x0000};
+ uint16x4_t inv_msk = {0x0000, 0x0000, 0xFFFF, 0xFFFF};
+
+ const __fp16 *a_off = (__fp16 *)src;
+ __fp16 *b_off = (__fp16 *)dst;
+
+ for (unsigned int i = 0; i < M; i += 8) {
+ const __fp16 *a_off1 = a_off;
+ const __fp16 *a_off2 = a_off1 + lda;
+ const __fp16 *a_off3 = a_off2 + lda;
+ const __fp16 *a_off4 = a_off3 + lda;
+ const __fp16 *a_off5 = a_off4 + lda;
+ const __fp16 *a_off6 = a_off5 + lda;
+ const __fp16 *a_off7 = a_off6 + lda;
+ const __fp16 *a_off8 = a_off7 + lda;
+ a_off += 8 * lda;
+
+ for (unsigned int j = 0; j < K; j += 8) {
+ float16x8_t _v0 = vld1q_f16(a_off1);
+ float16x8_t _v1 = vld1q_f16(a_off2);
+ float16x8_t _v2 = vld1q_f16(a_off3);
+ float16x8_t _v3 = vld1q_f16(a_off4);
+
+ float16x8_t _v4 = vld1q_f16(a_off5);
+ float16x8_t _v5 = vld1q_f16(a_off6);
+ float16x8_t _v6 = vld1q_f16(a_off7);
+ float16x8_t _v7 = vld1q_f16(a_off8);
+
+ a_off1 += 8;
+ a_off2 += 8;
+ a_off3 += 8;
+ a_off4 += 8;
+ a_off5 += 8;
+ a_off6 += 8;
+ a_off7 += 8;
+ a_off8 += 8;
+
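+      // Transpose the 8x8 tile: vtrnq/vcombine interleave row pairs, and the
+      // vext/vbsl "pack-in-pack" step below finishes the reordering, so each
+      // packed group of eight values maps to one K index across eight rows.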
+ float16x8x2_t _vv0 = vtrnq_f16(_v0, _v1);
+ float16x8x2_t _vv1 = vtrnq_f16(_v2, _v3);
+ float16x8x2_t _vv2 = vtrnq_f16(_v4, _v5);
+ float16x8x2_t _vv3 = vtrnq_f16(_v6, _v7);
+
+ float16x8_t _v8 =
+ vcombine_f16(vget_low_f16(_vv0.val[0]), vget_low_f16(_vv1.val[0]));
+ float16x8_t _v9 =
+ vcombine_f16(vget_low_f16(_vv0.val[1]), vget_low_f16(_vv1.val[1]));
+ float16x8_t _v10 =
+ vcombine_f16(vget_high_f16(_vv0.val[0]), vget_high_f16(_vv1.val[0]));
+ float16x8_t _v11 =
+ vcombine_f16(vget_high_f16(_vv0.val[1]), vget_high_f16(_vv1.val[1]));
+
+ float16x8_t _v12 =
+ vcombine_f16(vget_low_f16(_vv2.val[0]), vget_low_f16(_vv3.val[0]));
+ float16x8_t _v13 =
+ vcombine_f16(vget_low_f16(_vv2.val[1]), vget_low_f16(_vv3.val[1]));
+ float16x8_t _v14 =
+ vcombine_f16(vget_high_f16(_vv2.val[0]), vget_high_f16(_vv3.val[0]));
+ float16x8_t _v15 =
+ vcombine_f16(vget_high_f16(_vv2.val[1]), vget_high_f16(_vv3.val[1]));
+
+ // pack-in-pack
+ float16x4_t tmp_low_v8 = vget_low_f16(_v8);
+ float16x4_t tmp_high_v8 = vget_high_f16(_v8);
+ float16x4_t mid_v8 = vext_f16(tmp_low_v8, tmp_high_v8, 2);
+
+ float16x4_t tmp_low_v9 = vget_low_f16(_v9);
+ float16x4_t tmp_high_v9 = vget_high_f16(_v9);
+ float16x4_t mid_v9 = vext_f16(tmp_low_v9, tmp_high_v9, 2);
+
+ float16x4_t tmp_low_v10 = vget_low_f16(_v10);
+ float16x4_t tmp_high_v10 = vget_high_f16(_v10);
+ float16x4_t mid_v10 = vext_f16(tmp_low_v10, tmp_high_v10, 2);
+
+ float16x4_t tmp_low_v11 = vget_low_f16(_v11);
+ float16x4_t tmp_high_v11 = vget_high_f16(_v11);
+ float16x4_t mid_v11 = vext_f16(tmp_low_v11, tmp_high_v11, 2);
+
+ float16x4_t tmp_low_v12 = vget_low_f16(_v12);
+ float16x4_t tmp_high_v12 = vget_high_f16(_v12);
+ float16x4_t mid_v12 = vext_f16(tmp_low_v12, tmp_high_v12, 2);
+
+ float16x4_t tmp_low_v13 = vget_low_f16(_v13);
+ float16x4_t tmp_high_v13 = vget_high_f16(_v13);
+ float16x4_t mid_v13 = vext_f16(tmp_low_v13, tmp_high_v13, 2);
+
+ float16x4_t tmp_low_v14 = vget_low_f16(_v14);
+ float16x4_t tmp_high_v14 = vget_high_f16(_v14);
+ float16x4_t mid_v14 = vext_f16(tmp_low_v14, tmp_high_v14, 2);
+
+ float16x4_t tmp_low_v15 = vget_low_f16(_v15);
+ float16x4_t tmp_high_v15 = vget_high_f16(_v15);
+ float16x4_t mid_v15 = vext_f16(tmp_low_v15, tmp_high_v15, 2);
+
+ _v8 = vcombine_f16(vbsl_f16(msk, tmp_low_v8, mid_v8),
+ vbsl_f16(msk, tmp_low_v12, mid_v12));
+ _v12 = vcombine_f16(vbsl_f16(msk, tmp_low_v9, mid_v9),
+ vbsl_f16(msk, tmp_low_v13, mid_v13));
+ _v9 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v8, mid_v8),
+ vbsl_f16(inv_msk, tmp_high_v12, mid_v12));
+ _v13 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v9, mid_v9),
+ vbsl_f16(inv_msk, tmp_high_v13, mid_v13));
+ _v10 = vcombine_f16(vbsl_f16(msk, tmp_low_v10, mid_v10),
+ vbsl_f16(msk, tmp_low_v14, mid_v14));
+ _v14 = vcombine_f16(vbsl_f16(msk, tmp_low_v11, mid_v11),
+ vbsl_f16(msk, tmp_low_v15, mid_v15));
+ _v11 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v10, mid_v10),
+ vbsl_f16(inv_msk, tmp_high_v14, mid_v14));
+ _v15 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v11, mid_v11),
+ vbsl_f16(inv_msk, tmp_high_v15, mid_v15));
+
+ vst1q_f16(b_off + 0, _v8);
+ vst1q_f16(b_off + 8, _v12);
+ vst1q_f16(b_off + 16, _v9);
+ vst1q_f16(b_off + 24, _v13);
+ vst1q_f16(b_off + 32, _v10);
+ vst1q_f16(b_off + 40, _v14);
+ vst1q_f16(b_off + 48, _v11);
+ vst1q_f16(b_off + 56, _v15);
+ b_off += 64;
+ }
+ }
+}
+
+void packing_B1(unsigned int K, unsigned int N, const __fp16 *src,
+ unsigned int ldb, const __fp16 *dst) {
+ assert(K != 0 && N != 0 && N % 8 == 0);
+
+ for (int i = 0; i < K; i++) {
+ const __fp16 *a_off = src + i * ldb;
+ __fp16 *b_off = (__fp16 *)dst + i;
+ for (int j = 0; j < N; j++) {
+ float16_t v = *(a_off);
+ a_off++;
+
+ *b_off = v;
+ b_off += K;
+ }
+ }
+}
+
+void packing_B4(unsigned int K, unsigned int N, const __fp16 *src,
+ unsigned int ldb, const __fp16 *dst) {
+ assert(K != 0 && N != 0 && K % 4 == 0 && N % 4 == 0);
+ unsigned int i, j;
+
+ __fp16 *a_off, *a_off1, *a_off2, *a_off3, *a_off4;
+ __fp16 *b_off, *b_off1;
+ __fp16 c1, c2, c3, c4;
+ __fp16 c5, c6, c7, c8;
+ __fp16 c9, c10, c11, c12;
+ __fp16 c13, c14, c15, c16;
+ a_off = (__fp16 *)src;
+ b_off = (__fp16 *)dst;
+
+ j = (K >> 2);
+ do {
+ a_off1 = a_off;
+ a_off2 = a_off1 + ldb;
+ a_off3 = a_off2 + ldb;
+ a_off4 = a_off3 + ldb;
+ a_off += 4 * ldb;
+
+ b_off1 = b_off;
+ b_off += 16;
+
+ i = (N >> 2);
+ do {
+ c1 = *(a_off1 + 0);
+ c2 = *(a_off1 + 1);
+ c3 = *(a_off1 + 2);
+ c4 = *(a_off1 + 3);
+
+ c5 = *(a_off2 + 0);
+ c6 = *(a_off2 + 1);
+ c7 = *(a_off2 + 2);
+ c8 = *(a_off2 + 3);
+
+ c9 = *(a_off3 + 0);
+ c10 = *(a_off3 + 1);
+ c11 = *(a_off3 + 2);
+ c12 = *(a_off3 + 3);
+
+ c13 = *(a_off4 + 0);
+ c14 = *(a_off4 + 1);
+ c15 = *(a_off4 + 2);
+ c16 = *(a_off4 + 3);
+
+ a_off1 += 4;
+ a_off2 += 4;
+ a_off3 += 4;
+ a_off4 += 4;
+
+ *(b_off1 + 0) = c1;
+ *(b_off1 + 1) = c2;
+ *(b_off1 + 2) = c3;
+ *(b_off1 + 3) = c4;
+
+ *(b_off1 + 4) = c5;
+ *(b_off1 + 5) = c6;
+ *(b_off1 + 6) = c7;
+ *(b_off1 + 7) = c8;
+
+ *(b_off1 + 8) = c9;
+ *(b_off1 + 9) = c10;
+ *(b_off1 + 10) = c11;
+ *(b_off1 + 11) = c12;
+
+ *(b_off1 + 12) = c13;
+ *(b_off1 + 13) = c14;
+ *(b_off1 + 14) = c15;
+ *(b_off1 + 15) = c16;
+
+ b_off1 += K * 4;
+ i--;
+ } while (i > 0);
+ j--;
+ } while (j > 0);
+}
+
+void packing_B8(unsigned int K, unsigned int N, const __fp16 *src,
+ unsigned int ldb, const __fp16 *dst) {
+ assert(K != 0 && N != 0 && N % 8 == 0);
+
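+  // Pack B into 8-column panels: row i of a panel is stored at offset i * 8,
+  // and consecutive panels along N are separated by a stride of 8 * K.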
+ for (int i = 0; i < K; i++) {
+ const __fp16 *a_off = src + i * ldb;
+ __fp16 *b_off = (__fp16 *)dst + i * 8;
+ for (int j = 0; j < N; j += 8) {
+ float16x8_t v = vld1q_f16(a_off);
+ a_off += 8;
+
+ vst1q_f16(b_off, v);
+ b_off += 8 * K;
+ }
+ }
+}
+
+void packing_B16(unsigned int K, unsigned int N, const __fp16 *src,
+ unsigned int ldb, const __fp16 *dst) {
+ assert(K != 0 && N != 0 && N % 16 == 0);
+
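+  // Same layout as packing_B8 but with 16-column panels: two float16x8_t
+  // loads/stores per K row and a panel stride of 16 * K.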
+ for (int i = 0; i < K; i++) {
+ const __fp16 *a_off = src + i * ldb;
+ __fp16 *b_off = (__fp16 *)dst + i * 16;
+ for (int j = 0; j < N; j += 16) {
+ float16x8_t v0_7 = vld1q_f16(a_off);
+ float16x8_t v8_15 = vld1q_f16(a_off + 8);
+ a_off += 16;
+
+ vst1q_f16(b_off, v0_7);
+ vst1q_f16(b_off + 8, v8_15);
+ b_off += 16 * K;
+ }
+ }
+}
+
+void packing_transB16(unsigned int K, unsigned int N, const __fp16 *src,
+ unsigned int ldb, const __fp16 *dst) {
+ /// @note ldb = K for here
+ assert(K != 0 && N != 0 && N % 16 == 0);
+ unsigned int K8 = (K >> 3) << 3;
+
+ const __fp16 *src_off = (__fp16 *)src;
+ __fp16 *dst_off = (__fp16 *)dst;
+
+ const unsigned int ld_tile_T = 16;
+ __fp16 *tile_T = new __fp16[8 * ld_tile_T];
+ // __fp16 *tile_T = alignedMalloc(8 * ld_tile_T);
+
+ // 1. Do something like 8x16 transpose kernel
+ // 2. Save linearized transposed output tile to dst
+ for (unsigned int n = 0; n < N; n += 16) {
+ const __fp16 *src_off1 = src_off;
+ __fp16 *dst_off1 = dst_off;
+ src_off += 16 * ldb;
+ dst_off += (K8 * 16 + (K - K8)); // ?
+ for (unsigned int k = 0; k < K8; k += 8) {
+ // 16x8 tile -> 8x16
+ transpose_neon<__fp16>(16, 8, src_off1, ldb, tile_T, ld_tile_T);
+
+ // Store with correct packing order linearly
+ vst1q_f16(&dst_off1[0], vld1q_f16(&tile_T[0 * ld_tile_T + 0]));
+ vst1q_f16(&dst_off1[8], vld1q_f16(&tile_T[0 * ld_tile_T + 8]));
+ vst1q_f16(&dst_off1[16], vld1q_f16(&tile_T[1 * ld_tile_T + 0]));
+ vst1q_f16(&dst_off1[24], vld1q_f16(&tile_T[1 * ld_tile_T + 8]));
+ vst1q_f16(&dst_off1[32], vld1q_f16(&tile_T[2 * ld_tile_T + 0]));
+ vst1q_f16(&dst_off1[40], vld1q_f16(&tile_T[2 * ld_tile_T + 8]));
+ vst1q_f16(&dst_off1[48], vld1q_f16(&tile_T[3 * ld_tile_T + 0]));
+ vst1q_f16(&dst_off1[56], vld1q_f16(&tile_T[3 * ld_tile_T + 8]));
+ vst1q_f16(&dst_off1[64], vld1q_f16(&tile_T[4 * ld_tile_T + 0]));
+ vst1q_f16(&dst_off1[72], vld1q_f16(&tile_T[4 * ld_tile_T + 8]));
+ vst1q_f16(&dst_off1[80], vld1q_f16(&tile_T[5 * ld_tile_T + 0]));
+ vst1q_f16(&dst_off1[88], vld1q_f16(&tile_T[5 * ld_tile_T + 8]));
+ vst1q_f16(&dst_off1[96], vld1q_f16(&tile_T[6 * ld_tile_T + 0]));
+ vst1q_f16(&dst_off1[104], vld1q_f16(&tile_T[6 * ld_tile_T + 8]));
+ vst1q_f16(&dst_off1[112], vld1q_f16(&tile_T[7 * ld_tile_T + 0]));
+ vst1q_f16(&dst_off1[120], vld1q_f16(&tile_T[7 * ld_tile_T + 8]));
+
+ dst_off1 += 16 * 8;
+ src_off1 += 8;
+ }
+
+ // Do the equivalent of one by one for the rest
+ for (unsigned int k = K8; k < K; ++k) {
+ for (unsigned int _n = 0; _n < 16; ++_n) {
+ dst_off1[_n] = src_off1[k];
+ }
+ }
+ }
+
+  delete[] tile_T;
+}
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file hgemm_pack.h
+ * @date 01 April 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @author Debadri Samaddar <s.debadri@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This is for half-precision packing for kernel-based GEMM
+ */
+
+/**
+ * @brief packing function of input matrix A
+ *
+ * @param m number of rows of the matrix
+ * @param k number of columns of the matrix
+ * @param from pointer to the source matrix data
+ * @param lda leading dimension of the matrix
+ * @param to pointer to the packed output buffer
+ */
+void packing_A1(unsigned int m, unsigned int k, const __fp16 *from,
+ unsigned int lda, const __fp16 *to);
+/**
+ * @brief packing function of input matrix A
+ *
+ * @param M number of rows of the matrix
+ * @param K number of columns of the matrix
+ * @param src pointer to the source matrix data
+ * @param lda leading dimension of the matrix
+ * @param dst pointer to the packed output buffer
+ */
+void packing_A4(unsigned int M, unsigned int K, const __fp16 *src,
+ unsigned int lda, const __fp16 *dst);
+/**
+ * @brief packing function of input matrix A
+ *
+ * @param M number of rows of the matrix
+ * @param K number of columns of the matrix
+ * @param src pointer to the source matrix data
+ * @param lda leading dimension of the matrix
+ * @param dst pointer to the packed output buffer
+ */
+void packing_A8(unsigned int M, unsigned int K, const __fp16 *src,
+ unsigned int lda, const __fp16 *dst);
+/**
+ * @brief packing function of input matrix B
+ *
+ * @param K number of rows of the matrix
+ * @param N number of columns of the matrix
+ * @param src pointer to the source matrix data
+ * @param ldb leading dimension of the matrix
+ * @param dst pointer to the packed output buffer
+ */
+void packing_B1(unsigned int K, unsigned int N, const __fp16 *src,
+ unsigned int ldb, const __fp16 *dst);
+/**
+ * @brief packing function of input matrix B
+ *
+ * @param K number of rows of the matrix
+ * @param N number of columns of the matrix
+ * @param src pointer to the source matrix data
+ * @param ldb leading dimension of the matrix
+ * @param dst pointer to the packed output buffer
+ */
+void packing_B4(unsigned int K, unsigned int N, const __fp16 *src,
+ unsigned int ldb, const __fp16 *dst);
+/**
+ * @brief packing function of input matrix B
+ *
+ * @param K number of rows of the matrix
+ * @param N number of columns of the matrix
+ * @param src pointer to the source matrix data
+ * @param ldb leading dimension of the matrix
+ * @param dst pointer to the packed output buffer
+ */
+void packing_B8(unsigned int K, unsigned int N, const __fp16 *src,
+ unsigned int ldb, const __fp16 *dst);
+/**
+ * @brief packing function of input matrix B
+ *
+ * @param K number of rows of the matrix
+ * @param N number of columns of the matrix
+ * @param src pointer to the source matrix data
+ * @param ldb leading dimension of the matrix
+ * @param dst pointer to the packed output buffer
+ */
+void packing_B16(unsigned int K, unsigned int N, const __fp16 *src,
+ unsigned int ldb, const __fp16 *dst);
+/**
+ * @brief packing function of transposed input matrix B (16-column panels)
+ *
+ * @param K length of the K (inner) dimension
+ * @param N length of the N dimension
+ * @param src pointer to the source matrix data (stored transposed; ldb = K here)
+ * @param ldb leading dimension of the matrix
+ * @param dst pointer to the packed output buffer
+ */
+void packing_transB16(unsigned int K, unsigned int N, const __fp16 *src,
+ unsigned int ldb, const __fp16 *dst);
std::cerr << "Error : hgemm_padding_B_noTrans_wrt_KN NYI!\n";
}
-
void hgemm_padding_B_Trans_wrt_N(const __fp16 *B, __fp16 *Bp, unsigned int K,
unsigned int N, unsigned int K8,
unsigned int N16) {
*/
#include <cmath>
-#include <hgemm_kernel_8x16.h>
#include <hgemm_common.h>
-// #include <hgemm_kernel.h>
-#include <hgemm_kernel_pack.h>
+#include <hgemm_kernel.h>
#include <hgemm_noTrans.h>
+#include <hgemm_pack.h>
#include <hgemm_transB.h>
#include <hgemm_util.h>
#include <limits>
#include <matrix_transpose_neon.h>
-// #define HGEMM_KERNEL_8x16 hgemm_kernel_8x16 /// @todo change to macro kernel
-// #if !defined(HGEMM_KERNEL_8x16) hgemm_kernel_8x16
-// #endif
-
void hgemm_transB_8x16(unsigned int M, unsigned int N, unsigned int K,
const __fp16 *A, unsigned int lda, const __fp16 *B,
unsigned int ldb, float *C, unsigned int ldc,
n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
}
packing_transB16(k_min, n_min, B + ks + ldb * ns, ldb, sB);
- hgemm_kernel_8x16(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns,
- ldc);
+ hgemm_kernel_8x16(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc);
}
}
}
hgemm_headers = [
'hgemm.h',
'hgemm_util.h',
- 'hgemm_kernel_pack.h',
- 'hgemm_kernel_4x4.h',
- 'hgemm_kernel_4x8.h',
- 'hgemm_kernel_8x8.h',
- 'hgemm_kernel_8x16.h',
+ 'hgemm_pack.h',
+ 'hgemm_common.h',
+ 'hgemm_padding.h',
]
+subdir('hgemm_kernel')
+nntrainer_inc += include_directories('hgemm_kernel')
+nntrainer_inc_abs += meson.current_source_dir() / 'hgemm_kernel'
+
hgemm_sources = [
'hgemm.cpp',
'hgemm_padding_a.cpp',
'hgemm_padding_b.cpp',
- 'hgemm_kernel_pack.cpp',
+ 'hgemm_pack.cpp',
'hgemm_noTrans.cpp',
'hgemm_transA.cpp',
'hgemm_transB.cpp',