[ hgemm/refactor ] Refactor hgemm file structure
authorskykongkong8 <ss.kong@samsung.com>
Wed, 10 Jul 2024 08:34:43 +0000 (17:34 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 30 Jul 2024 22:45:30 +0000 (07:45 +0900)
- Kernel functions are used regardless of matrix transpose, does need to be included from separate file.
- For further optimal implemenation of matrix A / B / AB transpose blocking-kernel sequences, divide their file for convenience
- Function 'hgemm' itself is better to be reside in hgemm directory.

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
27 files changed:
nntrainer/tensor/blas_neon.h
nntrainer/tensor/hgemm/hgemm.cpp
nntrainer/tensor/hgemm/hgemm.h
nntrainer/tensor/hgemm/hgemm_common.h
nntrainer/tensor/hgemm/hgemm_kernel.h [deleted file]
nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel.h [new file with mode: 0644]
nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_1x4.cpp [new file with mode: 0644]
nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_1x8.cpp [new file with mode: 0644]
nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_4x4.cpp [new file with mode: 0644]
nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_4x8.cpp [new file with mode: 0644]
nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_8x16.cpp [new file with mode: 0644]
nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_8x8.cpp [new file with mode: 0644]
nntrainer/tensor/hgemm/hgemm_kernel/meson.build [new file with mode: 0644]
nntrainer/tensor/hgemm/hgemm_kernel_1x4.h [deleted file]
nntrainer/tensor/hgemm/hgemm_kernel_1x8.h [deleted file]
nntrainer/tensor/hgemm/hgemm_kernel_4x4.h [deleted file]
nntrainer/tensor/hgemm/hgemm_kernel_4x8.h [deleted file]
nntrainer/tensor/hgemm/hgemm_kernel_8x16.h [deleted file]
nntrainer/tensor/hgemm/hgemm_kernel_8x8.h [deleted file]
nntrainer/tensor/hgemm/hgemm_kernel_pack.cpp [deleted file]
nntrainer/tensor/hgemm/hgemm_kernel_pack.h [deleted file]
nntrainer/tensor/hgemm/hgemm_noTrans.cpp
nntrainer/tensor/hgemm/hgemm_pack.cpp [new file with mode: 0644]
nntrainer/tensor/hgemm/hgemm_pack.h [new file with mode: 0644]
nntrainer/tensor/hgemm/hgemm_padding_b.cpp
nntrainer/tensor/hgemm/hgemm_transB.cpp
nntrainer/tensor/hgemm/meson.build

index b7be55d48e0fb097473dbc50c5661a26f5ef00b0..81f8c060ed410c688759c883c61809b1d5100bc5 100644 (file)
@@ -327,9 +327,9 @@ unsigned int isamax(const unsigned int N, const __fp16 *X);
  * @param[in] alpha float number
  * @param[in] beta float number
  */
-void custom_hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, uint32_t N,
-           uint32_t K, float alpha, float beta, bool TransA, bool TransB);
-
+void custom_hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M,
+                  uint32_t N, uint32_t K, float alpha, float beta, bool TransA,
+                  bool TransB);
 
 /**
  * @brief squared root transformation with neon : X = sqrt(X)
index 81d22b22c4bbf4f16577195fb7215668c2618f04..cadb28bbe9ddbb09f56c143a0739698293cd7905 100644 (file)
  *
  */
 
+#include <arm_neon.h>
+#include <cmath>
 #include <hgemm.h>
+#include <hgemm_common.h>
 #include <hgemm_noTrans.h>
 #include <hgemm_padding.h>
 #include <hgemm_transA.h>
 #include <hgemm_transAB.h>
 #include <hgemm_transB.h>
 #include <hgemm_util.h>
-#include <hgemm_common.h>
-#include <cmath>
 
-
-void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M, unsigned int N,
-           unsigned int K, float alpha, float beta, bool TransA, bool TransB) {
+void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
+           unsigned int N, unsigned int K, float alpha, float beta, bool TransA,
+           bool TransB) {
   if (K == 1) {
     return hgemm_K1(A, B, C, M, N, K, alpha, beta, TransA, TransB);
   }
@@ -67,9 +68,9 @@ void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M, unsigned
   }
 
   hgemm_ensure_divisibility(A, B, C32, M, N, K, alpha, beta, TransA, TransB);
-  
-  unsigned int L = M*N;
-  unsigned int L8 = (L >> 3) <<3;
+
+  unsigned int L = M * N;
+  unsigned int L8 = (L >> 3) << 3;
 
   for (unsigned int idx = 0; idx < L8; idx += 8) {
     float32x4_t x1 = vld1q_f32(&C32[idx]);
@@ -151,11 +152,11 @@ void hgemm_classify(const __fp16 *A, const __fp16 *B, float *C32,
 }
 
 void hgemm_K1(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
-              unsigned int N, unsigned int K, float alpha, float beta, bool TransA,
-              bool TransB) {
+              unsigned int N, unsigned int K, float alpha, float beta,
+              bool TransA, bool TransB) {
   unsigned int lda = (TransA) ? M : K;
   unsigned int ldb = (TransB) ? K : N;
-  
+
   return hgemm_K1_noTrans(M, N, K, A, lda, B, ldb, C, N, alpha, beta);
 
   if (!TransA && TransB) {
index d2dd28941c25a8b9a2f8e5bada087cc379bf70be..a0c7b6f9d241910ebe37a8d89ae7a306a3d3cc88 100644 (file)
@@ -24,8 +24,9 @@
  * @param[in] alpha float number
  * @param[in] beta float number
  */
-void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M, unsigned int N,
-           unsigned int K, float alpha, float beta, bool TransA, bool TransB);
+void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
+           unsigned int N, unsigned int K, float alpha, float beta, bool TransA,
+           bool TransB);
 
 /**
  * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
@@ -54,9 +55,10 @@ void hgemm_ensure_divisibility(const __fp16 *A, const __fp16 *B, float *C32,
  * @param[in] alpha float number
  * @param[in] beta float number
  */
-void hgemm_classify(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M,
-                    unsigned int N, unsigned int K, float alpha = 1.F, float beta = 0.F,
-                    bool TransA = false, bool TransB = false);
+void hgemm_classify(const __fp16 *A, const __fp16 *B, float *C32,
+                    unsigned int M, unsigned int N, unsigned int K,
+                    float alpha = 1.F, float beta = 0.F, bool TransA = false,
+                    bool TransB = false);
 /**
  * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
  * where op(X) is one of X or X**T
@@ -70,5 +72,5 @@ void hgemm_classify(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M
  * @param[in] beta float number
  */
 void hgemm_K1(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
-              unsigned int N, unsigned int K, float alpha, float beta, bool TransA,
-              bool TransB);
+              unsigned int N, unsigned int K, float alpha, float beta,
+              bool TransA, bool TransB);
index bdf9bcbcace1f883a41a8658c84f400b79649dd9..a041e4319df654d458c9958d940264d4be737c02 100644 (file)
  * @brief  This is common settings for hgemm
  *
  */
-#include <arm_neon.h>
-#include <assert.h>
 
-
-#define A(i, j) a[(i) * lda + (j)]
-#define B(i, j) b[(i) * ldb + (j)]
-#define C(i, j) c[(i) * ldc + (j)]
+#define A(i, j) a[(i)*lda + (j)]
+#define B(i, j) b[(i)*ldb + (j)]
+#define C(i, j) c[(i)*ldc + (j)]
 
 #define N_BLOCKING (768)
 #define K_BLOCKING (256)
@@ -27,9 +24,3 @@
 #define GEMM_UNROLLING_1 (1)
 #define VL_FP16 (8)
 #define VL_FP16_HALF (4)
-
-
-
-/**
- * @todo Add macro for instructions in other CPU architectures
- */
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel.h b/nntrainer/tensor/hgemm/hgemm_kernel.h
deleted file mode 100644 (file)
index 4bcea0f..0000000
+++ /dev/null
@@ -1,13 +0,0 @@
-// #include <hgemm_kernel_1x4.h>
-// #include <hgemm_kernel_1x8.h>
-// #include <hgemm_kernel_4x4.h>
-// #include <hgemm_kernel_4x8.h>
-// #include <hgemm_kernel_8x16.h>
-// #include <hgemm_kernel_8x8.h>
-
-// #define HGEMM_KERNEL_1x4 hgemm_kernel_1x4
-// #define HGEMM_KERNEL_4x4 hgemm_kernel_4x4
-// #define HGEMM_KERNEL_1x8 hgemm_kernel_1x8
-// #define HGEMM_KERNEL_4x8 hgemm_kernel_4x8
-// #define HGEMM_KERNEL_8x8 hgemm_kernel_8x8
-// #define HGEMM_KERNEL_8x16 hgemm_kernel_8x16
\ No newline at end of file
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel.h b/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel.h
new file mode 100644 (file)
index 0000000..2ebc8b4
--- /dev/null
@@ -0,0 +1,169 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file   hgemm_kernel.h
+ * @date   10 July 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is a collection of all the KERNELs function for hgemm
+ *
+ */
+
+/**
+ * @brief hgemm_kernel_8x16 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
+                       __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc);
+/**
+ * @brief hgemm_kernel_8x16 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
+                       __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc);
+/**
+ * @brief hgemm_kernel_8x8 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_8x8(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc);
+/**
+ * @brief hgemm_kernel_8x8 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_8x8(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc);
+/**
+ * @brief hgemm_kernel_4x8 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_4x8(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc);
+/**
+ * @brief hgemm_kernel_4x8 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_4x8(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc);
+/**
+ * @brief hgemm_kernel_4x4 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc);
+/**
+ * @brief hgemm_kernel_4x4 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc);
+/**
+ * @brief hgemm_kernel_1x8 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_1x8(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc);
+/**
+ * @brief hgemm_kernel_1x8 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_1x8(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc);
+/**
+ * @brief hgemm_kernel_1x4 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc);
+/**
+ * @brief hgemm_kernel_1x4 KERNEL function
+ *
+ * @param M Length of blocked M
+ * @param N Length of blocked N
+ * @param K Length of blocked K
+ * @param sa Starting address of blocked A
+ * @param sb Starting address of blocked B
+ * @param sc Starting address of blocked C
+ * @param ldc Leading dimension of original matrix C
+ */
+void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc);
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_1x4.cpp b/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_1x4.cpp
new file mode 100644 (file)
index 0000000..2c301e5
--- /dev/null
@@ -0,0 +1,146 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Debadri Samaddar <s.debadri@samsung.com>
+ *
+ * @file   hgemm_kernel_1x4.cpp
+ * @date   23 April 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Debadri Samaddar <s.debadri@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is half-precision GEMM 1x4 kernel
+ *
+ */
+
+#include <stdlib.h>
+#include <arm_neon.h>
+#include <assert.h>
+#include <hgemm_kernel.h>
+
+/**
+ * @brief hgemm 1x4 kernel sc = sa * sb
+ *
+ * @param m length of the row of matrix A
+ * @param n length of the col of matrix B
+ * @param k length of the col of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading dimension of matrix C
+ */
+void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
+  assert(M > 0 && N > 0 && K > 0);
+  assert(N % 4 == 0);
+
+  __fp16 *a = sa, *b = sb, *c = sc;
+  unsigned int i, j, l;
+  for (i = 0; i < M; i++) {
+    for (j = 0; j < N; j += 4) {
+      __builtin_prefetch(b, 0, 3);
+      __builtin_prefetch(a, 0, 3);
+
+      for (l = 0; l < K; l += 4) {
+        float16x4_t v24 = {0.F};
+        float16x4_t v0 = vld1_f16(b);
+        float16_t v16 = *a;
+
+        v24 = vfma_n_f16(v24, v0, v16);
+
+        float16x4_t v1 = vld1_f16(b + 4);
+        float16_t v17 = *(a + 1);
+
+        v24 = vfma_n_f16(v24, v1, v17);
+
+        float16x4_t v2 = vld1_f16(b + 8);
+        float16_t v18 = *(a + 2);
+
+        v24 = vfma_n_f16(v24, v2, v18);
+
+        float16x4_t v3 = vld1_f16(b + 12);
+        float16_t v19 = *(a + 3);
+
+        v24 = vfma_n_f16(v24, v3, v19);
+
+        __builtin_prefetch(b + 16, 0, 3);
+        __builtin_prefetch(a + 4, 0, 3);
+
+        b += 16;
+        a += 4;
+
+        v24 = vadd_f16(vld1_f16(c), v24);
+
+        vst1_f16(c, v24);
+      }
+      c += 4;
+      a -= K;
+    }
+    sc += ldc;
+    c = sc;
+    a += K;
+    b = sb;
+  }
+}
+
+/**
+ * @brief hgemm 1x4 kernel sc = sa * sb
+ *
+ * @param m length of the row of matrix A
+ * @param n length of the col of matrix B
+ * @param k length of the col of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading dimension of matrix C
+ */
+void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
+  assert(M > 0 && N > 0 && K > 0);
+  assert(N % 4 == 0);
+
+  __fp16 *a = sa, *b = sb;
+  float *c = sc;
+  unsigned int i, j, l;
+  for (i = 0; i < M; i++) {
+    for (j = 0; j < N; j += 4) {
+      __builtin_prefetch(b, 0, 3);
+      __builtin_prefetch(a, 0, 3);
+
+      for (l = 0; l < K; l += 4) {
+        float16x4_t v24 = {0.F};
+        float16x4_t v0 = vld1_f16(b);
+        float16_t v16 = *a;
+
+        v24 = vfma_n_f16(v24, v0, v16);
+
+        float16x4_t v1 = vld1_f16(b + 4);
+        float16_t v17 = *(a + 1);
+
+        v24 = vfma_n_f16(v24, v1, v17);
+
+        float16x4_t v2 = vld1_f16(b + 8);
+        float16_t v18 = *(a + 2);
+
+        v24 = vfma_n_f16(v24, v2, v18);
+
+        float16x4_t v3 = vld1_f16(b + 12);
+        float16_t v19 = *(a + 3);
+
+        v24 = vfma_n_f16(v24, v3, v19);
+
+        __builtin_prefetch(b + 16, 0, 3);
+        __builtin_prefetch(a + 4, 0, 3);
+
+        b += 16;
+        a += 4;
+
+        vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(v24)));
+      }
+      c += 4;
+      a -= K;
+    }
+    sc += ldc;
+    c = sc;
+    a += K;
+    b = sb;
+  }
+}
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_1x8.cpp b/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_1x8.cpp
new file mode 100644 (file)
index 0000000..35927e5
--- /dev/null
@@ -0,0 +1,185 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Debadri Samaddar <s.debadri@samsung.com>
+ *
+ * @file   hgemm_kernel_1x8.cpp
+ * @date   05 April 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Debadri Samaddar <s.debadri@samsung.com>
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is half-precision GEMM 1x8 kernel
+ *
+ */
+
+#include <arm_neon.h>
+#include <assert.h>
+#include <hgemm_kernel.h>
+#include <stdlib.h>
+
+// 1. Partial sum 64 digits : worst accuracy, best latency
+#define KERNEL_1x8_ACC8()           \
+  do {                              \
+    v0 = vdupq_n_f16(0.F);          \
+    dv0 = *a;                       \
+    v24 = vld1q_f16(b);             \
+    v0 = vfmaq_n_f16(v0, v24, dv0); \
+    dv1 = *(a + 1);                 \
+    v25 = vld1q_f16(b + 8);         \
+    v0 = vfmaq_n_f16(v0, v25, dv1); \
+    dv2 = *(a + 2);                 \
+    v26 = vld1q_f16(b + 16);        \
+    v0 = vfmaq_n_f16(v0, v26, dv2); \
+    dv3 = *(a + 3);                 \
+    v27 = vld1q_f16(b + 24);        \
+    v0 = vfmaq_n_f16(v0, v27, dv3); \
+    dv4 = *(a + 4);                 \
+    v28 = vld1q_f16(b + 32);        \
+    v0 = vfmaq_n_f16(v0, v28, dv4); \
+    dv5 = *(a + 5);                 \
+    v29 = vld1q_f16(b + 40);        \
+    v0 = vfmaq_n_f16(v0, v29, dv5); \
+    dv6 = *(a + 6);                 \
+    v30 = vld1q_f16(b + 48);        \
+    v0 = vfmaq_n_f16(v0, v30, dv6); \
+    dv7 = *(a + 7);                 \
+    v31 = vld1q_f16(b + 56);        \
+    v0 = vfmaq_n_f16(v0, v31, dv7); \
+    l += 8;                         \
+    b += 8 * 8;                     \
+    a += 8;                         \
+  } while (0)
+
+// 2. Partial sum 32 digits : medium accuracy, medium latency
+#define KERNEL_1x8_ACC4()           \
+  do {                              \
+    v0 = vdupq_n_f16(0.F);          \
+    dv0 = *a;                       \
+    v24 = vld1q_f16(b);             \
+    v0 = vfmaq_n_f16(v0, v24, dv0); \
+    dv1 = *(a + 1);                 \
+    v25 = vld1q_f16(b + 8);         \
+    v0 = vfmaq_n_f16(v0, v25, dv1); \
+    dv2 = *(a + 2);                 \
+    v26 = vld1q_f16(b + 16);        \
+    v0 = vfmaq_n_f16(v0, v26, dv2); \
+    dv3 = *(a + 3);                 \
+    v27 = vld1q_f16(b + 24);        \
+    v0 = vfmaq_n_f16(v0, v27, dv3); \
+    l += 4;                         \
+    b += 8 * 4;                     \
+    a += 4;                         \
+  } while (0)
+
+// 3. Partial sum 8 digits : Best accuracy, worst latency
+#define KERNEL_1x8_ACC1()           \
+  do {                              \
+    v0 = vdupq_n_f16(0.F);          \
+    dv0 = *(a);                     \
+    v24 = vld1q_f16(b);             \
+    v0 = vfmaq_n_f16(v0, v24, dv0); \
+    l += 1;                         \
+    b += 8 * 1;                     \
+    a++;                            \
+  } while (0)
+
+/**
+ * @brief hgemm 1x8 kernel sc = sa * sb
+ *
+ * @param m length of the row of matrix A
+ * @param n length of the col of matrix B
+ * @param k length of the col of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading-dimension of matrix C
+ */
+void hgemm_kernel_1x8(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
+  assert(M > 0 && N > 0 && K > 0);
+  assert(N % 8 == 0);
+
+  __fp16 *a = sa, *b = sb, *c = sc;
+  unsigned int k8 = (K >> 3) << 3;
+  unsigned int i, j, l;
+  for (i = 0; i < M; i++) {
+    for (j = 0; j < N; j += 8) {
+      __builtin_prefetch(b, 0, 3);
+      __builtin_prefetch(a, 0, 3);
+      float16x8_t v0;
+      float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
+      float16_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
+      l = 0;
+      for (; l < k8;) {
+        KERNEL_1x8_ACC8();
+
+        vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0));
+      }
+      for (; l < K;) {
+        KERNEL_1x8_ACC1();
+
+        vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0));
+      }
+      c += 8;
+      a -= K;
+    }
+    sc += ldc;
+    c = sc;
+    a += K;
+    b = sb;
+  }
+}
+
+/**
+ * @brief hgemm 1x8 kernel sc = sa * sb
+ *
+ * @param m length of the row of matrix A
+ * @param n length of the col of matrix B
+ * @param k length of the col of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading-dimension of matrix C
+ */
+void hgemm_kernel_1x8(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
+  assert(M > 0 && N > 0 && K > 0);
+  assert(N % 8 == 0);
+
+  __fp16 *a = sa, *b = sb;
+  float *c = sc;
+  unsigned int k8 = (K >> 3) << 3;
+  unsigned int i, j, l;
+  for (i = 0; i < M; i++) {
+    for (j = 0; j < N; j += 8) {
+      __builtin_prefetch(b, 0, 3);
+      __builtin_prefetch(a, 0, 3);
+      float16x8_t v0;
+      float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
+      float16_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
+      l = 0;
+      for (; l < k8;) {
+        KERNEL_1x8_ACC8();
+
+        vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0))));
+
+        vst1q_f32(c + 4,
+                  vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0))));
+      }
+      for (; l < K;) {
+        KERNEL_1x8_ACC1();
+
+        vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0))));
+
+        vst1q_f32(c + 4,
+                  vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0))));
+      }
+      c += 8;
+      a -= K;
+    }
+    sc += ldc;
+    c = sc;
+    a += K;
+    b = sb;
+  }
+}
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_4x4.cpp b/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_4x4.cpp
new file mode 100644 (file)
index 0000000..40ab4ea
--- /dev/null
@@ -0,0 +1,360 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file   hgemm_kernel_4x4.cpp
+ * @date   01 April 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is half-precision GEMM 4x4 kernel
+ *
+ */
+
+#include <arm_neon.h>
+#include <assert.h>
+#include <hgemm_kernel.h>
+#include <stdlib.h>
+
+#define INIT_KERNEL_4x4()  \
+  do {                     \
+    v24 = vdup_n_f16(0.F); \
+    v25 = vdup_n_f16(0.F); \
+    v26 = vdup_n_f16(0.F); \
+    v27 = vdup_n_f16(0.F); \
+  } while (0)
+
+// 1. Partial sum 256 digits
+#define KERNEL_4x4_ACC16()                 \
+  do {                                     \
+    dv0 = vld1_f16(a);                     \
+    vb0 = vld1_f16(b);                     \
+    v24 = vfma_lane_f16(v24, vb0, dv0, 0); \
+    v25 = vfma_lane_f16(v25, vb0, dv0, 1); \
+    v26 = vfma_lane_f16(v26, vb0, dv0, 2); \
+    v27 = vfma_lane_f16(v27, vb0, dv0, 3); \
+    dv1 = vld1_f16(a + 4);                 \
+    vb1 = vld1_f16(b + 4);                 \
+    v24 = vfma_lane_f16(v24, vb1, dv1, 0); \
+    v25 = vfma_lane_f16(v25, vb1, dv1, 1); \
+    v26 = vfma_lane_f16(v26, vb1, dv1, 2); \
+    v27 = vfma_lane_f16(v27, vb1, dv1, 3); \
+    dv2 = vld1_f16(a + 4 * 2);             \
+    vb2 = vld1_f16(b + 4 * 2);             \
+    v24 = vfma_lane_f16(v24, vb2, dv2, 0); \
+    v25 = vfma_lane_f16(v25, vb2, dv2, 1); \
+    v26 = vfma_lane_f16(v26, vb2, dv2, 2); \
+    v27 = vfma_lane_f16(v27, vb2, dv2, 3); \
+    dv3 = vld1_f16(a + 4 * 3);             \
+    vb3 = vld1_f16(b + 4 * 3);             \
+    v24 = vfma_lane_f16(v24, vb3, dv3, 0); \
+    v25 = vfma_lane_f16(v25, vb3, dv3, 1); \
+    v26 = vfma_lane_f16(v26, vb3, dv3, 2); \
+    v27 = vfma_lane_f16(v27, vb3, dv3, 3); \
+    dv4 = vld1_f16(a + 4 * 4);             \
+    vb4 = vld1_f16(b + 4 * 4);             \
+    v24 = vfma_lane_f16(v24, vb4, dv4, 0); \
+    v25 = vfma_lane_f16(v25, vb4, dv4, 1); \
+    v26 = vfma_lane_f16(v26, vb4, dv4, 2); \
+    v27 = vfma_lane_f16(v27, vb4, dv4, 3); \
+    dv5 = vld1_f16(a + 4 * 5);             \
+    vb5 = vld1_f16(b + 4 * 5);             \
+    v24 = vfma_lane_f16(v24, vb5, dv5, 0); \
+    v25 = vfma_lane_f16(v25, vb5, dv5, 1); \
+    v26 = vfma_lane_f16(v26, vb5, dv5, 2); \
+    v27 = vfma_lane_f16(v27, vb5, dv5, 3); \
+    dv6 = vld1_f16(a + 4 * 6);             \
+    vb6 = vld1_f16(b + 4 * 6);             \
+    v24 = vfma_lane_f16(v24, vb6, dv6, 0); \
+    v25 = vfma_lane_f16(v25, vb6, dv6, 1); \
+    v26 = vfma_lane_f16(v26, vb6, dv6, 2); \
+    v27 = vfma_lane_f16(v27, vb6, dv6, 3); \
+    dv7 = vld1_f16(a + 4 * 7);             \
+    vb7 = vld1_f16(b + 4 * 7);             \
+    v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
+    v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
+    v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
+    v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
+    dv7 = vld1_f16(a + 4 * 8);             \
+    vb7 = vld1_f16(b + 4 * 8);             \
+    v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
+    v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
+    v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
+    v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
+    dv7 = vld1_f16(a + 4 * 9);             \
+    vb7 = vld1_f16(b + 4 * 9);             \
+    v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
+    v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
+    v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
+    v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
+    dv7 = vld1_f16(a + 4 * 10);            \
+    vb7 = vld1_f16(b + 4 * 10);            \
+    v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
+    v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
+    v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
+    v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
+    dv7 = vld1_f16(a + 4 * 11);            \
+    vb7 = vld1_f16(b + 4 * 11);            \
+    v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
+    v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
+    v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
+    v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
+    dv7 = vld1_f16(a + 4 * 12);            \
+    vb7 = vld1_f16(b + 4 * 12);            \
+    v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
+    v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
+    v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
+    v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
+    dv7 = vld1_f16(a + 4 * 13);            \
+    vb7 = vld1_f16(b + 4 * 13);            \
+    v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
+    v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
+    v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
+    v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
+    dv7 = vld1_f16(a + 4 * 14);            \
+    vb7 = vld1_f16(b + 4 * 14);            \
+    v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
+    v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
+    v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
+    v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
+    dv7 = vld1_f16(a + 4 * 15);            \
+    vb7 = vld1_f16(b + 4 * 15);            \
+    v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
+    v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
+    v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
+    v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
+    l += 16;                               \
+    __builtin_prefetch(b + 64, 0, 3);      \
+    __builtin_prefetch(a + 64, 0, 3);      \
+    b += 4 * 16;                           \
+    a += 4 * 16;                           \
+  } while (0)
+
+// 2. Partial sum 128 digits
+#define KERNEL_4x4_ACC8()                  \
+  do {                                     \
+    dv0 = vld1_f16(a);                     \
+    vb0 = vld1_f16(b);                     \
+    v24 = vfma_lane_f16(v24, vb0, dv0, 0); \
+    v25 = vfma_lane_f16(v25, vb0, dv0, 1); \
+    v26 = vfma_lane_f16(v26, vb0, dv0, 2); \
+    v27 = vfma_lane_f16(v27, vb0, dv0, 3); \
+    dv1 = vld1_f16(a + 4);                 \
+    vb1 = vld1_f16(b + 4);                 \
+    v24 = vfma_lane_f16(v24, vb1, dv1, 0); \
+    v25 = vfma_lane_f16(v25, vb1, dv1, 1); \
+    v26 = vfma_lane_f16(v26, vb1, dv1, 2); \
+    v27 = vfma_lane_f16(v27, vb1, dv1, 3); \
+    dv2 = vld1_f16(a + 8);                 \
+    vb2 = vld1_f16(b + 8);                 \
+    v24 = vfma_lane_f16(v24, vb2, dv2, 0); \
+    v25 = vfma_lane_f16(v25, vb2, dv2, 1); \
+    v26 = vfma_lane_f16(v26, vb2, dv2, 2); \
+    v27 = vfma_lane_f16(v27, vb2, dv2, 3); \
+    dv3 = vld1_f16(a + 12);                \
+    vb3 = vld1_f16(b + 12);                \
+    v24 = vfma_lane_f16(v24, vb3, dv3, 0); \
+    v25 = vfma_lane_f16(v25, vb3, dv3, 1); \
+    v26 = vfma_lane_f16(v26, vb3, dv3, 2); \
+    v27 = vfma_lane_f16(v27, vb3, dv3, 3); \
+    dv4 = vld1_f16(a + 16);                \
+    vb4 = vld1_f16(b + 16);                \
+    v24 = vfma_lane_f16(v24, vb4, dv4, 0); \
+    v25 = vfma_lane_f16(v25, vb4, dv4, 1); \
+    v26 = vfma_lane_f16(v26, vb4, dv4, 2); \
+    v27 = vfma_lane_f16(v27, vb4, dv4, 3); \
+    dv5 = vld1_f16(a + 20);                \
+    vb5 = vld1_f16(b + 20);                \
+    v24 = vfma_lane_f16(v24, vb5, dv5, 0); \
+    v25 = vfma_lane_f16(v25, vb5, dv5, 1); \
+    v26 = vfma_lane_f16(v26, vb5, dv5, 2); \
+    v27 = vfma_lane_f16(v27, vb5, dv5, 3); \
+    dv6 = vld1_f16(a + 24);                \
+    vb6 = vld1_f16(b + 24);                \
+    v24 = vfma_lane_f16(v24, vb6, dv6, 0); \
+    v25 = vfma_lane_f16(v25, vb6, dv6, 1); \
+    v26 = vfma_lane_f16(v26, vb6, dv6, 2); \
+    v27 = vfma_lane_f16(v27, vb6, dv6, 3); \
+    dv7 = vld1_f16(a + 28);                \
+    vb7 = vld1_f16(b + 28);                \
+    v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
+    v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
+    v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
+    v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
+    l += 8;                                \
+    __builtin_prefetch(b + 32, 0, 3);      \
+    __builtin_prefetch(a + 32, 0, 3);      \
+    b += 4 * 8;                            \
+    a += 4 * 8;                            \
+  } while (0)
+
+// 3. Partial sum 16 digits
+#define KERNEL_4x4_ACC1()                  \
+  do {                                     \
+    dv0 = vld1_f16(a);                     \
+    vb0 = vld1_f16(b);                     \
+    v24 = vfma_lane_f16(v24, vb0, dv0, 0); \
+    v25 = vfma_lane_f16(v25, vb0, dv0, 1); \
+    v26 = vfma_lane_f16(v26, vb0, dv0, 2); \
+    v27 = vfma_lane_f16(v27, vb0, dv0, 3); \
+    l += 1;                                \
+    __builtin_prefetch(b + 4, 0, 3);       \
+    __builtin_prefetch(a + 4, 0, 3);       \
+    b += 4 * 1;                            \
+    a += 4 * 1;                            \
+  } while (0)
+
+#define SAVE_KERNEL_4X4_F16_F32()                                         \
+  do {                                                                    \
+    vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(v24)));             \
+    vst1q_f32(c + ldc, vaddq_f32(vld1q_f32(c + ldc), vcvt_f32_f16(v25))); \
+    vst1q_f32(c + 2 * ldc,                                                \
+              vaddq_f32(vld1q_f32(c + 2 * ldc), vcvt_f32_f16(v26)));      \
+    vst1q_f32(c + 3 * ldc,                                                \
+              vaddq_f32(vld1q_f32(c + 3 * ldc), vcvt_f32_f16(v27)));      \
+  } while (0)
+
+/**
+ * @brief hgemm 4x4 kernel sc = sa * sb
+ *
+ * @param m length of the row of matrix A
+ * @param n length of the col of matrix B
+ * @param k length of the col of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading dimension of matrix C
+ */
+void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
+  assert(M > 0 && N > 0 && K > 0);
+  assert(M % 4 == 0 && N % 4 == 0 && K % 4 == 0);
+
+  __fp16 *a = sa, *b = sb, *c = sc;
+  unsigned int i, j, l;
+  for (i = 0; i < M; i += 4) {
+    for (j = 0; j < N; j += 4) {
+      __builtin_prefetch(b, 0, 3);
+      __builtin_prefetch(a, 0, 3);
+
+      float16x4_t v24;
+      float16x4_t v25;
+      float16x4_t v26;
+      float16x4_t v27;
+      INIT_KERNEL_4x4();
+
+      for (l = 0; l < K; l += 4) {
+        float16x4_t v0 = vld1_f16(b);
+        float16x4_t v16 = vld1_f16(a);
+
+        v24 = vfma_lane_f16(v24, v0, v16, 0);
+        v25 = vfma_lane_f16(v25, v0, v16, 1);
+        v26 = vfma_lane_f16(v26, v0, v16, 2);
+        v27 = vfma_lane_f16(v27, v0, v16, 3);
+
+        float16x4_t v1 = vld1_f16(b + 4);
+        float16x4_t v17 = vld1_f16(a + 4);
+
+        v24 = vfma_lane_f16(v24, v1, v17, 0);
+        v25 = vfma_lane_f16(v25, v1, v17, 1);
+        v26 = vfma_lane_f16(v26, v1, v17, 2);
+        v27 = vfma_lane_f16(v27, v1, v17, 3);
+
+        float16x4_t v2 = vld1_f16(b + 8);
+        float16x4_t v18 = vld1_f16(a + 8);
+
+        v24 = vfma_lane_f16(v24, v2, v18, 0);
+        v25 = vfma_lane_f16(v25, v2, v18, 1);
+        v26 = vfma_lane_f16(v26, v2, v18, 2);
+        v27 = vfma_lane_f16(v27, v2, v18, 3);
+
+        float16x4_t v3 = vld1_f16(b + 12);
+        float16x4_t v19 = vld1_f16(a + 12);
+
+        v24 = vfma_lane_f16(v24, v3, v19, 0);
+        v25 = vfma_lane_f16(v25, v3, v19, 1);
+        v26 = vfma_lane_f16(v26, v3, v19, 2);
+        v27 = vfma_lane_f16(v27, v3, v19, 3);
+
+        __builtin_prefetch(b + 16, 0, 3);
+        __builtin_prefetch(a + 16, 0, 3);
+
+        b += 16;
+        a += 16;
+      }
+
+      v24 = vadd_f16(vld1_f16(c), v24);
+      v25 = vadd_f16(vld1_f16(c + ldc), v25);
+      v26 = vadd_f16(vld1_f16(c + 2 * ldc), v26);
+      v27 = vadd_f16(vld1_f16(c + 3 * ldc), v27);
+
+      vst1_f16(c, v24);
+      vst1_f16(c + ldc, v25);
+      vst1_f16(c + 2 * ldc, v26);
+      vst1_f16(c + 3 * ldc, v27);
+
+      c += 4;
+      a -= 4 * K;
+    }
+    sc += ldc * 4;
+    c = sc;
+    a += 4 * K;
+    b = sb;
+  }
+}
+
+/**
+ * @brief hgemm 4x4 kernel sc = sa * sb
+ *
+ * @param m length of the row of matrix A
+ * @param n length of the col of matrix B
+ * @param k length of the col of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading dimension of matrix C
+ */
+void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
+  assert(M > 0 && N > 0 && K > 0);
+  assert(M % 4 == 0 && N % 4 == 0 && K % 4 == 0);
+
+  __fp16 *a = sa, *b = sb;
+  float *c = sc;
+  unsigned int i, j, l;
+  unsigned int K16 = (K >> 4) << 4;
+  unsigned int K8 = (K >> 3) << 3;
+  for (i = 0; i < M; i += 4) {
+    for (j = 0; j < N; j += 4) {
+      __builtin_prefetch(b, 0, 3);
+      __builtin_prefetch(a, 0, 3);
+
+      float16x4_t v24, v25, v26, v27;
+      float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
+      float16x4_t vb0, vb1, vb2, vb3, vb4, vb5, vb6, vb7;
+      l = 0;
+      for (; l < K16;) {
+        INIT_KERNEL_4x4();
+        KERNEL_4x4_ACC16();
+        SAVE_KERNEL_4X4_F16_F32();
+      }
+      for (; l < K8;) {
+        INIT_KERNEL_4x4();
+        KERNEL_4x4_ACC8();
+        SAVE_KERNEL_4X4_F16_F32();
+      }
+      for (; l < K;) {
+        INIT_KERNEL_4x4();
+        KERNEL_4x4_ACC1();
+        SAVE_KERNEL_4X4_F16_F32();
+      }
+
+      c += 4;
+      a -= 4 * K;
+    }
+    sc += ldc * 4;
+    c = sc;
+    a += 4 * K;
+    b = sb;
+  }
+}
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_4x8.cpp b/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_4x8.cpp
new file mode 100644 (file)
index 0000000..3cebee4
--- /dev/null
@@ -0,0 +1,367 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file   hgemm_kernel_4x8.cpp
+ * @date   03 April 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is half-precision GEMM 8x8 kernel
+ *
+ */
+
+#include <arm_neon.h>
+#include <assert.h>
+#include <hgemm_kernel.h>
+#include <stdlib.h>
+
+#define INIT_KERNEL_4X8()  \
+  do {                     \
+    v0 = vdupq_n_f16(0.F); \
+    v3 = vdupq_n_f16(0.F); \
+    v6 = vdupq_n_f16(0.F); \
+    v9 = vdupq_n_f16(0.F); \
+  } while (0)
+
+// 1. Partial sum 256 digits
+#define KERNEL_4x8_ACC16()                \
+  do {                                    \
+    dv0 = vld1_f16(a);                    \
+    v24 = vld1q_f16(b);                   \
+    v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
+    v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \
+    v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \
+    v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \
+    dv1 = vld1_f16(a + 4);                \
+    v25 = vld1q_f16(b + 8);               \
+    v0 = vfmaq_lane_f16(v0, v25, dv1, 0); \
+    v3 = vfmaq_lane_f16(v3, v25, dv1, 1); \
+    v6 = vfmaq_lane_f16(v6, v25, dv1, 2); \
+    v9 = vfmaq_lane_f16(v9, v25, dv1, 3); \
+    dv2 = vld1_f16(a + 4 * 2);            \
+    v26 = vld1q_f16(b + 8 * 2);           \
+    v0 = vfmaq_lane_f16(v0, v26, dv2, 0); \
+    v3 = vfmaq_lane_f16(v3, v26, dv2, 1); \
+    v6 = vfmaq_lane_f16(v6, v26, dv2, 2); \
+    v9 = vfmaq_lane_f16(v9, v26, dv2, 3); \
+    dv3 = vld1_f16(a + 4 * 3);            \
+    v27 = vld1q_f16(b + 8 * 3);           \
+    v0 = vfmaq_lane_f16(v0, v27, dv3, 0); \
+    v3 = vfmaq_lane_f16(v3, v27, dv3, 1); \
+    v6 = vfmaq_lane_f16(v6, v27, dv3, 2); \
+    v9 = vfmaq_lane_f16(v9, v27, dv3, 3); \
+    dv4 = vld1_f16(a + 4 * 4);            \
+    v28 = vld1q_f16(b + 8 * 4);           \
+    v0 = vfmaq_lane_f16(v0, v28, dv4, 0); \
+    v3 = vfmaq_lane_f16(v3, v28, dv4, 1); \
+    v6 = vfmaq_lane_f16(v6, v28, dv4, 2); \
+    v9 = vfmaq_lane_f16(v9, v28, dv4, 3); \
+    dv5 = vld1_f16(a + 4 * 5);            \
+    v29 = vld1q_f16(b + 8 * 5);           \
+    v0 = vfmaq_lane_f16(v0, v29, dv5, 0); \
+    v3 = vfmaq_lane_f16(v3, v29, dv5, 1); \
+    v6 = vfmaq_lane_f16(v6, v29, dv5, 2); \
+    v9 = vfmaq_lane_f16(v9, v29, dv5, 3); \
+    dv6 = vld1_f16(a + 4 * 6);            \
+    v30 = vld1q_f16(b + 8 * 6);           \
+    v0 = vfmaq_lane_f16(v0, v30, dv6, 0); \
+    v3 = vfmaq_lane_f16(v3, v30, dv6, 1); \
+    v6 = vfmaq_lane_f16(v6, v30, dv6, 2); \
+    v9 = vfmaq_lane_f16(v9, v30, dv6, 3); \
+    dv7 = vld1_f16(a + 4 * 7);            \
+    v31 = vld1q_f16(b + 8 * 7);           \
+    v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
+    v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
+    v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
+    v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
+    dv7 = vld1_f16(a + 4 * 8);            \
+    v31 = vld1q_f16(b + 8 * 8);           \
+    v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
+    v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
+    v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
+    v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
+    dv7 = vld1_f16(a + 4 * 9);            \
+    v31 = vld1q_f16(b + 8 * 9);           \
+    v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
+    v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
+    v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
+    v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
+    dv7 = vld1_f16(a + 4 * 10);           \
+    v31 = vld1q_f16(b + 8 * 10);          \
+    v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
+    v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
+    v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
+    v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
+    dv7 = vld1_f16(a + 4 * 11);           \
+    v31 = vld1q_f16(b + 8 * 11);          \
+    v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
+    v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
+    v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
+    v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
+    dv7 = vld1_f16(a + 4 * 12);           \
+    v31 = vld1q_f16(b + 8 * 12);          \
+    v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
+    v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
+    v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
+    v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
+    dv7 = vld1_f16(a + 4 * 13);           \
+    v31 = vld1q_f16(b + 8 * 13);          \
+    v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
+    v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
+    v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
+    v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
+    dv7 = vld1_f16(a + 4 * 14);           \
+    v31 = vld1q_f16(b + 8 * 14);          \
+    v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
+    v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
+    v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
+    v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
+    dv7 = vld1_f16(a + 4 * 15);           \
+    v31 = vld1q_f16(b + 8 * 15);          \
+    v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
+    v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
+    v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
+    v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
+    l += 16;                              \
+    __builtin_prefetch(b + 128, 0, 3);    \
+    __builtin_prefetch(a + 64, 0, 3);     \
+    b += 8 * 16;                          \
+    a += 4 * 16;                          \
+  } while (0)
+
+// 1. Partial sum 256 digits
+#define KERNEL_4x8_ACC8()                 \
+  do {                                    \
+    dv0 = vld1_f16(a);                    \
+    v24 = vld1q_f16(b);                   \
+    v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
+    v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \
+    v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \
+    v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \
+    dv1 = vld1_f16(a + 4);                \
+    v25 = vld1q_f16(b + 8);               \
+    v0 = vfmaq_lane_f16(v0, v25, dv1, 0); \
+    v3 = vfmaq_lane_f16(v3, v25, dv1, 1); \
+    v6 = vfmaq_lane_f16(v6, v25, dv1, 2); \
+    v9 = vfmaq_lane_f16(v9, v25, dv1, 3); \
+    dv2 = vld1_f16(a + 8);                \
+    v26 = vld1q_f16(b + 16);              \
+    v0 = vfmaq_lane_f16(v0, v26, dv2, 0); \
+    v3 = vfmaq_lane_f16(v3, v26, dv2, 1); \
+    v6 = vfmaq_lane_f16(v6, v26, dv2, 2); \
+    v9 = vfmaq_lane_f16(v9, v26, dv2, 3); \
+    dv3 = vld1_f16(a + 12);               \
+    v27 = vld1q_f16(b + 24);              \
+    v0 = vfmaq_lane_f16(v0, v27, dv3, 0); \
+    v3 = vfmaq_lane_f16(v3, v27, dv3, 1); \
+    v6 = vfmaq_lane_f16(v6, v27, dv3, 2); \
+    v9 = vfmaq_lane_f16(v9, v27, dv3, 3); \
+    dv4 = vld1_f16(a + 16);               \
+    v28 = vld1q_f16(b + 32);              \
+    v0 = vfmaq_lane_f16(v0, v28, dv4, 0); \
+    v3 = vfmaq_lane_f16(v3, v28, dv4, 1); \
+    v6 = vfmaq_lane_f16(v6, v28, dv4, 2); \
+    v9 = vfmaq_lane_f16(v9, v28, dv4, 3); \
+    dv5 = vld1_f16(a + 20);               \
+    v29 = vld1q_f16(b + 40);              \
+    v0 = vfmaq_lane_f16(v0, v29, dv5, 0); \
+    v3 = vfmaq_lane_f16(v3, v29, dv5, 1); \
+    v6 = vfmaq_lane_f16(v6, v29, dv5, 2); \
+    v9 = vfmaq_lane_f16(v9, v29, dv5, 3); \
+    dv6 = vld1_f16(a + 24);               \
+    v30 = vld1q_f16(b + 48);              \
+    v0 = vfmaq_lane_f16(v0, v30, dv6, 0); \
+    v3 = vfmaq_lane_f16(v3, v30, dv6, 1); \
+    v6 = vfmaq_lane_f16(v6, v30, dv6, 2); \
+    v9 = vfmaq_lane_f16(v9, v30, dv6, 3); \
+    dv7 = vld1_f16(a + 28);               \
+    v31 = vld1q_f16(b + 56);              \
+    v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
+    v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
+    v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
+    v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
+    l += 8;                               \
+    __builtin_prefetch(b + 64, 0, 3);     \
+    __builtin_prefetch(a + 32, 0, 3);     \
+    b += 8 * 8;                           \
+    a += 4 * 8;                           \
+  } while (0)
+
+// 2. Partial sum 128 digits
+#define KERNEL_4x8_ACC4()                 \
+  do {                                    \
+    dv0 = vld1_f16(a);                    \
+    v24 = vld1q_f16(b);                   \
+    v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
+    v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \
+    v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \
+    v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \
+    dv1 = vld1_f16(a + 4);                \
+    v25 = vld1q_f16(b + 8);               \
+    v0 = vfmaq_lane_f16(v0, v25, dv1, 0); \
+    v3 = vfmaq_lane_f16(v3, v25, dv1, 1); \
+    v6 = vfmaq_lane_f16(v6, v25, dv1, 2); \
+    v9 = vfmaq_lane_f16(v9, v25, dv1, 3); \
+    dv2 = vld1_f16(a + 8);                \
+    v26 = vld1q_f16(b + 16);              \
+    v0 = vfmaq_lane_f16(v0, v26, dv2, 0); \
+    v3 = vfmaq_lane_f16(v3, v26, dv2, 1); \
+    v6 = vfmaq_lane_f16(v6, v26, dv2, 2); \
+    v9 = vfmaq_lane_f16(v9, v26, dv2, 3); \
+    dv3 = vld1_f16(a + 12);               \
+    v27 = vld1q_f16(b + 24);              \
+    v0 = vfmaq_lane_f16(v0, v27, dv3, 0); \
+    v3 = vfmaq_lane_f16(v3, v27, dv3, 1); \
+    v6 = vfmaq_lane_f16(v6, v27, dv3, 2); \
+    v9 = vfmaq_lane_f16(v9, v27, dv3, 3); \
+    l += 4;                               \
+    __builtin_prefetch(b + 32, 0, 3);     \
+    __builtin_prefetch(a + 16, 0, 3);     \
+    b += 8 * 4;                           \
+    a += 4 * 4;                           \
+  } while (0)
+
+// 3. Partial sum 32 digits
+#define KERNEL_4x8_ACC1()                 \
+  do {                                    \
+    dv0 = vld1_f16(a);                    \
+    v24 = vld1q_f16(b);                   \
+    v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
+    v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \
+    v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \
+    v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \
+    l += 1;                               \
+    __builtin_prefetch(b + 8, 0, 3);      \
+    __builtin_prefetch(a + 4, 0, 3);      \
+    b += 8 * 1;                           \
+    a += 4 * 1;                           \
+  } while (0)
+
+#define SAVE_KERNEL_4X8_F16_F32()                                             \
+  do {                                                                        \
+    vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0))));    \
+    vst1q_f32(c + ldc,                                                        \
+              vaddq_f32(vld1q_f32(c + ldc), vcvt_f32_f16(vget_low_f16(v3)))); \
+    vst1q_f32(c + 2 * ldc, vaddq_f32(vld1q_f32(c + 2 * ldc),                  \
+                                     vcvt_f32_f16(vget_low_f16(v6))));        \
+    vst1q_f32(c + 3 * ldc, vaddq_f32(vld1q_f32(c + 3 * ldc),                  \
+                                     vcvt_f32_f16(vget_low_f16(v9))));        \
+                                                                              \
+    vst1q_f32(c + 4,                                                          \
+              vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0))));  \
+    vst1q_f32(c + 4 + ldc, vaddq_f32(vld1q_f32(c + 4 + ldc),                  \
+                                     vcvt_f32_f16(vget_high_f16(v3))));       \
+    vst1q_f32(c + 4 + 2 * ldc, vaddq_f32(vld1q_f32(c + 4 + 2 * ldc),          \
+                                         vcvt_f32_f16(vget_high_f16(v6))));   \
+    vst1q_f32(c + 4 + 3 * ldc, vaddq_f32(vld1q_f32(c + 4 + 3 * ldc),          \
+                                         vcvt_f32_f16(vget_high_f16(v9))));   \
+  } while (0)
+
+/**
+ * @brief hgemm 4x8 kernel sc = sa * sb
+ *
+ * @param m length of the row of matrix A
+ * @param n length of the col of matrix B
+ * @param k length of the col of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading-dimension of matrix C
+ */
+void hgemm_kernel_4x8(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
+  assert(M > 0 && N > 0 && K > 0);
+  assert(M % 4 == 0 && N % 8 == 0);
+
+  __fp16 *a = sa, *b = sb, *c = sc;
+  unsigned int K8 = (K >> 3) << 3;
+  unsigned int i, j, l;
+  for (i = 0; i < M; i += 4) {
+    for (j = 0; j < N; j += 8) {
+      __builtin_prefetch(b, 0, 3);
+      __builtin_prefetch(a, 0, 3);
+      float16x8_t v0, v3, v6, v9;
+      float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
+      float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
+      INIT_KERNEL_4X8();
+      l = 0;
+      for (; l < K8;) {
+        KERNEL_4x8_ACC8();
+      }
+      for (; l < K;) {
+        KERNEL_4x8_ACC1();
+      }
+      vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0));
+      vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v3));
+      vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v6));
+      vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v9));
+      c += 8;
+      a -= 4 * K;
+    }
+    sc += ldc * 4;
+    c = sc;
+    a += 4 * K;
+    b = sb;
+  }
+}
+
+/**
+ * @brief hgemm 4x8 kernel sc = sa * sb
+ *
+ * @param m length of the row of matrix A
+ * @param n length of the col of matrix B
+ * @param k length of the col of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading-dimension of matrix C
+ */
+void hgemm_kernel_4x8(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
+  assert(M > 0 && N > 0 && K > 0);
+  assert(M % 4 == 0 && N % 8 == 0);
+
+  __fp16 *a = sa, *b = sb;
+  float *c = sc;
+  unsigned int K16 = (K >> 4) << 4;
+  unsigned int K8 = (K >> 3) << 3;
+  unsigned int K4 = (K >> 2) << 2;
+  unsigned int i, j, l;
+  for (i = 0; i < M; i += 4) {
+    for (j = 0; j < N; j += 8) {
+      __builtin_prefetch(b, 0, 3);
+      __builtin_prefetch(a, 0, 3);
+      float16x8_t v0, v3, v6, v9;
+      float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
+      float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
+      l = 0;
+      for (; l < K16;) {
+        INIT_KERNEL_4X8();
+        KERNEL_4x8_ACC16();
+        SAVE_KERNEL_4X8_F16_F32();
+      }
+      for (; l < K8;) {
+        INIT_KERNEL_4X8();
+        KERNEL_4x8_ACC8();
+        SAVE_KERNEL_4X8_F16_F32();
+      }
+      for (; l < K4;) {
+        INIT_KERNEL_4X8();
+        KERNEL_4x8_ACC4();
+        SAVE_KERNEL_4X8_F16_F32();
+      }
+      for (; l < K;) {
+        INIT_KERNEL_4X8();
+        KERNEL_4x8_ACC1();
+        SAVE_KERNEL_4X8_F16_F32();
+      }
+      c += 8;
+      a -= 4 * K;
+    }
+    sc += ldc * 4;
+    c = sc;
+    a += 4 * K;
+    b = sb;
+  }
+}
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_8x16.cpp b/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_8x16.cpp
new file mode 100644 (file)
index 0000000..f8d6b56
--- /dev/null
@@ -0,0 +1,863 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file   hgemm_kernel_8x16.cpp
+ * @date   04 April 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is half-precision GEMM 8x16 kernel
+ *
+ */
+
+#include <arm_neon.h>
+#include <assert.h>
+#include <iostream>
+#include <hgemm_kernel.h>
+#include <stdlib.h>
+
+#define INIT_KERNEL_8X16()       \
+  do {                           \
+    v0_7 = vdupq_n_f16(0.F);     \
+    v8_15 = vdupq_n_f16(0.F);    \
+    v16_23 = vdupq_n_f16(0.F);   \
+    v24_31 = vdupq_n_f16(0.F);   \
+    v32_39 = vdupq_n_f16(0.F);   \
+    v40_47 = vdupq_n_f16(0.F);   \
+    v48_55 = vdupq_n_f16(0.F);   \
+    v56_63 = vdupq_n_f16(0.F);   \
+    v64_71 = vdupq_n_f16(0.F);   \
+    v72_79 = vdupq_n_f16(0.F);   \
+    v80_87 = vdupq_n_f16(0.F);   \
+    v88_95 = vdupq_n_f16(0.F);   \
+    v96_103 = vdupq_n_f16(0.F);  \
+    v104_111 = vdupq_n_f16(0.F); \
+    v112_119 = vdupq_n_f16(0.F); \
+    v120_127 = vdupq_n_f16(0.F); \
+  } while (0)
+
+// 1. Partial sum 2048 digits
+#define KERNEL_8x16_ACC16()                            \
+  do {                                                 \
+    va0 = vld1q_f16(a + 8 * 0);                        \
+    vb1 = vld1q_f16(b + 8 * 0);                        \
+    vb2 = vld1q_f16(b + 8 * 1);                        \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 1);                        \
+    vb1 = vld1q_f16(b + 8 * 2);                        \
+    vb2 = vld1q_f16(b + 8 * 3);                        \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 2);                        \
+    vb1 = vld1q_f16(b + 8 * 4);                        \
+    vb2 = vld1q_f16(b + 8 * 5);                        \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 3);                        \
+    vb1 = vld1q_f16(b + 8 * 6);                        \
+    vb2 = vld1q_f16(b + 8 * 7);                        \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 4);                        \
+    vb1 = vld1q_f16(b + 8 * 8);                        \
+    vb2 = vld1q_f16(b + 8 * 9);                        \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 5);                        \
+    vb1 = vld1q_f16(b + 8 * 10);                       \
+    vb2 = vld1q_f16(b + 8 * 11);                       \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 6);                        \
+    vb1 = vld1q_f16(b + 8 * 12);                       \
+    vb2 = vld1q_f16(b + 8 * 13);                       \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 7);                        \
+    vb1 = vld1q_f16(b + 8 * 14);                       \
+    vb2 = vld1q_f16(b + 8 * 15);                       \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 8);                        \
+    vb1 = vld1q_f16(b + 8 * 16);                       \
+    vb2 = vld1q_f16(b + 8 * 17);                       \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 9);                        \
+    vb1 = vld1q_f16(b + 8 * 18);                       \
+    vb2 = vld1q_f16(b + 8 * 19);                       \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 10);                       \
+    vb1 = vld1q_f16(b + 8 * 20);                       \
+    vb2 = vld1q_f16(b + 8 * 21);                       \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 11);                       \
+    vb1 = vld1q_f16(b + 8 * 22);                       \
+    vb2 = vld1q_f16(b + 8 * 23);                       \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 12);                       \
+    vb1 = vld1q_f16(b + 8 * 24);                       \
+    vb2 = vld1q_f16(b + 8 * 25);                       \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 13);                       \
+    vb1 = vld1q_f16(b + 8 * 26);                       \
+    vb2 = vld1q_f16(b + 8 * 27);                       \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 14);                       \
+    vb1 = vld1q_f16(b + 8 * 28);                       \
+    vb2 = vld1q_f16(b + 8 * 29);                       \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 15);                       \
+    vb1 = vld1q_f16(b + 8 * 30);                       \
+    vb2 = vld1q_f16(b + 8 * 31);                       \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    __builtin_prefetch(b + 256, 0, 3);                 \
+    __builtin_prefetch(a + 128, 0, 3);                 \
+    l += 16;                                           \
+    b += 16 * 16;                                      \
+    a += 8 * 16;                                       \
+  } while (0)
+
+// 2. Partial sum 1024 digits
+#define KERNEL_8x16_ACC8()                             \
+  do {                                                 \
+    va0 = vld1q_f16(a);                                \
+    vb1 = vld1q_f16(b);                                \
+    vb2 = vld1q_f16(b + 8);                            \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 8);                            \
+    vb1 = vld1q_f16(b + 16);                           \
+    vb2 = vld1q_f16(b + 24);                           \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 16);                           \
+    vb1 = vld1q_f16(b + 32);                           \
+    vb2 = vld1q_f16(b + 40);                           \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 24);                           \
+    vb1 = vld1q_f16(b + 48);                           \
+    vb2 = vld1q_f16(b + 56);                           \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 32);                           \
+    vb1 = vld1q_f16(b + 64);                           \
+    vb2 = vld1q_f16(b + 72);                           \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 40);                           \
+    vb1 = vld1q_f16(b + 80);                           \
+    vb2 = vld1q_f16(b + 88);                           \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 48);                           \
+    vb1 = vld1q_f16(b + 96);                           \
+    vb2 = vld1q_f16(b + 104);                          \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 56);                           \
+    vb1 = vld1q_f16(b + 112);                          \
+    vb2 = vld1q_f16(b + 120);                          \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    l += 8;                                            \
+    __builtin_prefetch(b + 128, 0, 3);                 \
+    __builtin_prefetch(a + 64, 0, 3);                  \
+    b += 16 * 8;                                       \
+    a += 8 * 8;                                        \
+  } while (0)
+
+// 3. Partial sum 512 digits
+#define KERNEL_8x16_ACC4()                             \
+  do {                                                 \
+    va0 = vld1q_f16(a);                                \
+    vb1 = vld1q_f16(b);                                \
+    vb2 = vld1q_f16(b + 8);                            \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 8);                            \
+    vb1 = vld1q_f16(b + 16);                           \
+    vb2 = vld1q_f16(b + 24);                           \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 16);                           \
+    vb1 = vld1q_f16(b + 32);                           \
+    vb2 = vld1q_f16(b + 40);                           \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    va0 = vld1q_f16(a + 24);                           \
+    vb1 = vld1q_f16(b + 48);                           \
+    vb2 = vld1q_f16(b + 56);                           \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    l += 4;                                            \
+    __builtin_prefetch(b + 64, 0, 3);                  \
+    __builtin_prefetch(a + 32, 0, 3);                  \
+    b += 16 * 4;                                       \
+    a += 8 * 4;                                        \
+  } while (0)
+
+// 4. Partial sum 128 digits
+#define KERNEL_8x16_ACC1()                             \
+  do {                                                 \
+    va0 = vld1q_f16(a);                                \
+    vb1 = vld1q_f16(b);                                \
+    vb2 = vld1q_f16(b + 8);                            \
+    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
+    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
+    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
+    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
+    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
+    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
+    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
+    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
+    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
+    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
+    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
+    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
+    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
+    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
+    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
+    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
+    l += 1;                                            \
+    __builtin_prefetch(b + 16, 0, 3);                  \
+    __builtin_prefetch(a + 8, 0, 3);                   \
+    b += 16 * 1;                                       \
+    a += 8 * 1;                                        \
+  } while (0)
+
+#define SAVE_KERNEL_8X16_F16_F32()                                             \
+  do {                                                                         \
+    vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0_7))));   \
+    vst1q_f32(c + 4,                                                           \
+              vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0_7)))); \
+                                                                               \
+    vst1q_f32(                                                                 \
+      c + 8, vaddq_f32(vld1q_f32(c + 8), vcvt_f32_f16(vget_low_f16(v64_71)))); \
+    vst1q_f32(c + 8 + 4, vaddq_f32(vld1q_f32(c + 8 + 4),                       \
+                                   vcvt_f32_f16(vget_high_f16(v64_71))));      \
+                                                                               \
+    vst1q_f32(c + ldc, vaddq_f32(vld1q_f32(c + ldc),                           \
+                                 vcvt_f32_f16(vget_low_f16(v8_15))));          \
+    vst1q_f32(c + ldc + 4, vaddq_f32(vld1q_f32(c + ldc + 4),                   \
+                                     vcvt_f32_f16(vget_high_f16(v8_15))));     \
+                                                                               \
+    vst1q_f32(c + ldc + 8, vaddq_f32(vld1q_f32(c + ldc + 8),                   \
+                                     vcvt_f32_f16(vget_low_f16(v72_79))));     \
+    vst1q_f32(c + ldc + 8 + 4,                                                 \
+              vaddq_f32(vld1q_f32(c + ldc + 8 + 4),                            \
+                        vcvt_f32_f16(vget_high_f16(v72_79))));                 \
+                                                                               \
+    vst1q_f32(c + 2 * ldc, vaddq_f32(vld1q_f32(c + 2 * ldc),                   \
+                                     vcvt_f32_f16(vget_low_f16(v16_23))));     \
+    vst1q_f32(c + 2 * ldc + 4,                                                 \
+              vaddq_f32(vld1q_f32(c + 2 * ldc + 4),                            \
+                        vcvt_f32_f16(vget_high_f16(v16_23))));                 \
+                                                                               \
+    vst1q_f32(c + 2 * ldc + 8, vaddq_f32(vld1q_f32(c + 2 * ldc + 8),           \
+                                         vcvt_f32_f16(vget_low_f16(v80_87)))); \
+    vst1q_f32(c + 2 * ldc + 8 + 4,                                             \
+              vaddq_f32(vld1q_f32(c + 2 * ldc + 8 + 4),                        \
+                        vcvt_f32_f16(vget_high_f16(v80_87))));                 \
+                                                                               \
+    vst1q_f32(c + 3 * ldc, vaddq_f32(vld1q_f32(c + 3 * ldc),                   \
+                                     vcvt_f32_f16(vget_low_f16(v24_31))));     \
+    vst1q_f32(c + 3 * ldc + 4,                                                 \
+              vaddq_f32(vld1q_f32(c + 3 * ldc + 4),                            \
+                        vcvt_f32_f16(vget_high_f16(v24_31))));                 \
+                                                                               \
+    vst1q_f32(c + 3 * ldc + 8, vaddq_f32(vld1q_f32(c + 3 * ldc + 8),           \
+                                         vcvt_f32_f16(vget_low_f16(v88_95)))); \
+    vst1q_f32(c + 3 * ldc + 8 + 4,                                             \
+              vaddq_f32(vld1q_f32(c + 3 * ldc + 8 + 4),                        \
+                        vcvt_f32_f16(vget_high_f16(v88_95))));                 \
+                                                                               \
+    vst1q_f32(c + 4 * ldc, vaddq_f32(vld1q_f32(c + 4 * ldc),                   \
+                                     vcvt_f32_f16(vget_low_f16(v32_39))));     \
+    vst1q_f32(c + 4 * ldc + 4,                                                 \
+              vaddq_f32(vld1q_f32(c + 4 * ldc + 4),                            \
+                        vcvt_f32_f16(vget_high_f16(v32_39))));                 \
+                                                                               \
+    vst1q_f32(c + 4 * ldc + 8,                                                 \
+              vaddq_f32(vld1q_f32(c + 4 * ldc + 8),                            \
+                        vcvt_f32_f16(vget_low_f16(v96_103))));                 \
+    vst1q_f32(c + 4 * ldc + 8 + 4,                                             \
+              vaddq_f32(vld1q_f32(c + 4 * ldc + 8 + 4),                        \
+                        vcvt_f32_f16(vget_high_f16(v96_103))));                \
+                                                                               \
+    vst1q_f32(c + 5 * ldc, vaddq_f32(vld1q_f32(c + 5 * ldc),                   \
+                                     vcvt_f32_f16(vget_low_f16(v40_47))));     \
+    vst1q_f32(c + 5 * ldc + 4,                                                 \
+              vaddq_f32(vld1q_f32(c + 5 * ldc + 4),                            \
+                        vcvt_f32_f16(vget_high_f16(v40_47))));                 \
+    vst1q_f32(c + 5 * ldc + 8,                                                 \
+              vaddq_f32(vld1q_f32(c + 5 * ldc + 8),                            \
+                        vcvt_f32_f16(vget_low_f16(v104_111))));                \
+    vst1q_f32(c + 5 * ldc + 8 + 4,                                             \
+              vaddq_f32(vld1q_f32(c + 5 * ldc + 8 + 4),                        \
+                        vcvt_f32_f16(vget_high_f16(v104_111))));               \
+                                                                               \
+    vst1q_f32(c + 6 * ldc, vaddq_f32(vld1q_f32(c + 6 * ldc),                   \
+                                     vcvt_f32_f16(vget_low_f16(v48_55))));     \
+    vst1q_f32(c + 6 * ldc + 4,                                                 \
+              vaddq_f32(vld1q_f32(c + 6 * ldc + 4),                            \
+                        vcvt_f32_f16(vget_high_f16(v48_55))));                 \
+                                                                               \
+    vst1q_f32(c + 6 * ldc + 8,                                                 \
+              vaddq_f32(vld1q_f32(c + 6 * ldc + 8),                            \
+                        vcvt_f32_f16(vget_low_f16(v112_119))));                \
+    vst1q_f32(c + 6 * ldc + 8 + 4,                                             \
+              vaddq_f32(vld1q_f32(c + 6 * ldc + 8 + 4),                        \
+                        vcvt_f32_f16(vget_high_f16(v112_119))));               \
+                                                                               \
+    vst1q_f32(c + 7 * ldc, vaddq_f32(vld1q_f32(c + 7 * ldc),                   \
+                                     vcvt_f32_f16(vget_low_f16(v56_63))));     \
+    vst1q_f32(c + 7 * ldc + 4,                                                 \
+              vaddq_f32(vld1q_f32(c + 7 * ldc + 4),                            \
+                        vcvt_f32_f16(vget_high_f16(v56_63))));                 \
+                                                                               \
+    vst1q_f32(c + 7 * ldc + 8,                                                 \
+              vaddq_f32(vld1q_f32(c + 7 * ldc + 8),                            \
+                        vcvt_f32_f16(vget_low_f16(v120_127))));                \
+    vst1q_f32(c + 7 * ldc + 8 + 4,                                             \
+              vaddq_f32(vld1q_f32(c + 7 * ldc + 8 + 4),                        \
+                        vcvt_f32_f16(vget_high_f16(v120_127))));               \
+  } while (0)
+
+/**
+ * @brief hgemm 8x16 kernel sc = sa * sb
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading-dimension of matrix C
+ */
+void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
+                       __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
+  assert(M > 0 && N > 0 && K > 0);
+  assert(M % 8 == 0 && N % 16 == 0 && K % 8 == 0);
+
+  __fp16 *a = sa, *b = sb, *c = sc;
+  unsigned int i, j, l;
+  for (i = 0; i < M; i += 8) {
+    for (j = 0; j < N; j += 16) {
+      __builtin_prefetch(b, 0, 3);
+      __builtin_prefetch(a, 0, 3);
+      // 8x16
+      float16x8_t v0_7, v8_15;
+      float16x8_t v16_23, v24_31;
+      float16x8_t v32_39, v40_47;
+      float16x8_t v48_55, v56_63;
+      float16x8_t v64_71, v72_79;
+      float16x8_t v80_87, v88_95;
+      float16x8_t v96_103, v104_111;
+      float16x8_t v112_119, v120_127;
+      float16x8_t vb1, vb2;
+      float16x8_t va0;
+
+      INIT_KERNEL_8X16();
+      l = 0;
+      for (; l < K;) {
+        KERNEL_8x16_ACC1();
+      }
+      vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0_7));
+      vst1q_f16(c + 8, vaddq_f16(vld1q_f16(c + 8), v64_71));
+      vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v8_15));
+      vst1q_f16(c + ldc + 8, vaddq_f16(vld1q_f16(c + ldc + 8), v72_79));
+      vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v16_23));
+      vst1q_f16(c + 2 * ldc + 8, vaddq_f16(vld1q_f16(c + 2 * ldc + 8), v80_87));
+      vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v24_31));
+      vst1q_f16(c + 3 * ldc + 8, vaddq_f16(vld1q_f16(c + 3 * ldc + 8), v88_95));
+      vst1q_f16(c + 4 * ldc, vaddq_f16(vld1q_f16(c + 4 * ldc), v32_39));
+      vst1q_f16(c + 4 * ldc + 8,
+                vaddq_f16(vld1q_f16(c + 4 * ldc + 8), v96_103));
+      vst1q_f16(c + 5 * ldc, vaddq_f16(vld1q_f16(c + 5 * ldc), v40_47));
+      vst1q_f16(c + 5 * ldc + 8,
+                vaddq_f16(vld1q_f16(c + 5 * ldc + 8), v104_111));
+      vst1q_f16(c + 6 * ldc, vaddq_f16(vld1q_f16(c + 6 * ldc), v48_55));
+      vst1q_f16(c + 6 * ldc + 8,
+                vaddq_f16(vld1q_f16(c + 6 * ldc + 8), v112_119));
+      vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v56_63));
+      vst1q_f16(c + 7 * ldc + 8,
+                vaddq_f16(vld1q_f16(c + 7 * ldc + 8), v120_127));
+      c += 16;
+      a -= 8 * K;
+    }
+    sc += ldc * 8;
+    c = sc;
+    a += 8 * K;
+    b = sb;
+  }
+}
+
+/**
+ * @brief hgemm 8x16 kernel sc = sa * sb
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading-dimension of matrix C
+ */
+void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
+                       __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
+  assert(M > 0 && N > 0 && K > 0);
+  assert(M % 8 == 0 && N % 16 == 0 && K % 4 == 0);
+
+  __fp16 *a = sa, *b = sb;
+  float *c = sc;
+  unsigned int i, j, l;
+  unsigned int K4 = (K >> 2) << 2;
+  unsigned int K8 = (K >> 3) << 3;
+  unsigned int K16 = (K >> 4) << 4;
+  for (i = 0; i < M; i += 8) {
+    for (j = 0; j < N; j += 16) {
+      __builtin_prefetch(b, 0, 3);
+      __builtin_prefetch(a, 0, 3);
+      float16x8_t v0_7, v8_15;
+      float16x8_t v16_23, v24_31;
+      float16x8_t v32_39, v40_47;
+      float16x8_t v48_55, v56_63;
+      float16x8_t v64_71, v72_79;
+      float16x8_t v80_87, v88_95;
+      float16x8_t v96_103, v104_111;
+      float16x8_t v112_119, v120_127;
+      float16x8_t vb1, vb2;
+      float16x8_t va0;
+      l = 0;
+      for (; l < K16;) {
+        INIT_KERNEL_8X16();
+        KERNEL_8x16_ACC16();
+        SAVE_KERNEL_8X16_F16_F32();
+      }
+      for (; l < K8;) {
+        INIT_KERNEL_8X16();
+        KERNEL_8x16_ACC8();
+        SAVE_KERNEL_8X16_F16_F32();
+      }
+      for (; l < K4;) {
+        INIT_KERNEL_8X16();
+        KERNEL_8x16_ACC4();
+        SAVE_KERNEL_8X16_F16_F32();
+      }
+      for (; l < K;) {
+        INIT_KERNEL_8X16();
+        KERNEL_8x16_ACC1();
+        SAVE_KERNEL_8X16_F16_F32();
+      }
+      c += 16;
+      a -= 8 * K;
+    }
+    sc += ldc * 8;
+    c = sc;
+    a += 8 * K;
+    b = sb;
+  }
+}
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_8x8.cpp b/nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_8x8.cpp
new file mode 100644 (file)
index 0000000..f799e52
--- /dev/null
@@ -0,0 +1,512 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file   hgemm_kernel_8x8.cpp
+ * @date   01 April 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is half-precision GEMM 8x8 kernel
+ *
+ */
+
+#include <arm_neon.h>
+#include <assert.h>
+#include <hgemm_kernel.h>
+#include <stdlib.h>
+
+#define INIT_KERNEL_8x8()   \
+  do {                      \
+    v24 = vdupq_n_f16(0.F); \
+    v25 = vdupq_n_f16(0.F); \
+    v26 = vdupq_n_f16(0.F); \
+    v27 = vdupq_n_f16(0.F); \
+    v28 = vdupq_n_f16(0.F); \
+    v29 = vdupq_n_f16(0.F); \
+    v30 = vdupq_n_f16(0.F); \
+    v31 = vdupq_n_f16(0.F); \
+  } while (0)
+
+// 1. Partial sum 1024 digits
+#define KERNEL_8x8_ACC16()                   \
+  do {                                       \
+    va0 = vld1q_f16(a);                      \
+    v16 = vld1q_f16(b);                      \
+    v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
+    v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \
+    v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \
+    v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \
+    v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \
+    v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \
+    v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \
+    v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \
+    va0 = vld1q_f16(a + 8);                  \
+    v17 = vld1q_f16(b + 8);                  \
+    v24 = vfmaq_laneq_f16(v24, v17, va0, 0); \
+    v25 = vfmaq_laneq_f16(v25, v17, va0, 1); \
+    v26 = vfmaq_laneq_f16(v26, v17, va0, 2); \
+    v27 = vfmaq_laneq_f16(v27, v17, va0, 3); \
+    v28 = vfmaq_laneq_f16(v28, v17, va0, 4); \
+    v29 = vfmaq_laneq_f16(v29, v17, va0, 5); \
+    v30 = vfmaq_laneq_f16(v30, v17, va0, 6); \
+    v31 = vfmaq_laneq_f16(v31, v17, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 2);              \
+    v18 = vld1q_f16(b + 8 * 2);              \
+    v24 = vfmaq_laneq_f16(v24, v18, va0, 0); \
+    v25 = vfmaq_laneq_f16(v25, v18, va0, 1); \
+    v26 = vfmaq_laneq_f16(v26, v18, va0, 2); \
+    v27 = vfmaq_laneq_f16(v27, v18, va0, 3); \
+    v28 = vfmaq_laneq_f16(v28, v18, va0, 4); \
+    v29 = vfmaq_laneq_f16(v29, v18, va0, 5); \
+    v30 = vfmaq_laneq_f16(v30, v18, va0, 6); \
+    v31 = vfmaq_laneq_f16(v31, v18, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 3);              \
+    v19 = vld1q_f16(b + 8 * 3);              \
+    v24 = vfmaq_laneq_f16(v24, v19, va0, 0); \
+    v25 = vfmaq_laneq_f16(v25, v19, va0, 1); \
+    v26 = vfmaq_laneq_f16(v26, v19, va0, 2); \
+    v27 = vfmaq_laneq_f16(v27, v19, va0, 3); \
+    v28 = vfmaq_laneq_f16(v28, v19, va0, 4); \
+    v29 = vfmaq_laneq_f16(v29, v19, va0, 5); \
+    v30 = vfmaq_laneq_f16(v30, v19, va0, 6); \
+    v31 = vfmaq_laneq_f16(v31, v19, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 4);              \
+    v20 = vld1q_f16(b + 8 * 4);              \
+    v24 = vfmaq_laneq_f16(v24, v20, va0, 0); \
+    v25 = vfmaq_laneq_f16(v25, v20, va0, 1); \
+    v26 = vfmaq_laneq_f16(v26, v20, va0, 2); \
+    v27 = vfmaq_laneq_f16(v27, v20, va0, 3); \
+    v28 = vfmaq_laneq_f16(v28, v20, va0, 4); \
+    v29 = vfmaq_laneq_f16(v29, v20, va0, 5); \
+    v30 = vfmaq_laneq_f16(v30, v20, va0, 6); \
+    v31 = vfmaq_laneq_f16(v31, v20, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 5);              \
+    v21 = vld1q_f16(b + 8 * 5);              \
+    v24 = vfmaq_laneq_f16(v24, v21, va0, 0); \
+    v25 = vfmaq_laneq_f16(v25, v21, va0, 1); \
+    v26 = vfmaq_laneq_f16(v26, v21, va0, 2); \
+    v27 = vfmaq_laneq_f16(v27, v21, va0, 3); \
+    v28 = vfmaq_laneq_f16(v28, v21, va0, 4); \
+    v29 = vfmaq_laneq_f16(v29, v21, va0, 5); \
+    v30 = vfmaq_laneq_f16(v30, v21, va0, 6); \
+    v31 = vfmaq_laneq_f16(v31, v21, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 6);              \
+    v22 = vld1q_f16(b + 8 * 6);              \
+    v24 = vfmaq_laneq_f16(v24, v22, va0, 0); \
+    v25 = vfmaq_laneq_f16(v25, v22, va0, 1); \
+    v26 = vfmaq_laneq_f16(v26, v22, va0, 2); \
+    v27 = vfmaq_laneq_f16(v27, v22, va0, 3); \
+    v28 = vfmaq_laneq_f16(v28, v22, va0, 4); \
+    v29 = vfmaq_laneq_f16(v29, v22, va0, 5); \
+    v30 = vfmaq_laneq_f16(v30, v22, va0, 6); \
+    v31 = vfmaq_laneq_f16(v31, v22, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 7);              \
+    v23 = vld1q_f16(b + 8 * 7);              \
+    v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
+    v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
+    v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
+    v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
+    v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
+    v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
+    v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
+    v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 8);              \
+    v23 = vld1q_f16(b + 8 * 8);              \
+    v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
+    v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
+    v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
+    v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
+    v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
+    v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
+    v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
+    v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 9);              \
+    v23 = vld1q_f16(b + 8 * 9);              \
+    v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
+    v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
+    v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
+    v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
+    v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
+    v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
+    v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
+    v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 10);             \
+    v23 = vld1q_f16(b + 8 * 10);             \
+    v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
+    v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
+    v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
+    v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
+    v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
+    v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
+    v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
+    v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 11);             \
+    v23 = vld1q_f16(b + 8 * 11);             \
+    v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
+    v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
+    v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
+    v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
+    v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
+    v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
+    v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
+    v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 12);             \
+    v23 = vld1q_f16(b + 8 * 12);             \
+    v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
+    v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
+    v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
+    v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
+    v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
+    v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
+    v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
+    v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 13);             \
+    v23 = vld1q_f16(b + 8 * 13);             \
+    v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
+    v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
+    v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
+    v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
+    v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
+    v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
+    v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
+    v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 14);             \
+    v23 = vld1q_f16(b + 8 * 14);             \
+    v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
+    v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
+    v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
+    v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
+    v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
+    v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
+    v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
+    v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
+    va0 = vld1q_f16(a + 8 * 15);             \
+    v23 = vld1q_f16(b + 8 * 15);             \
+    v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
+    v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
+    v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
+    v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
+    v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
+    v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
+    v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
+    v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
+    __builtin_prefetch(b + 128, 0, 3);       \
+    __builtin_prefetch(a + 128, 0, 3);       \
+    l += 16;                                 \
+    b += 8 * 16;                             \
+    a += 8 * 16;                             \
+  } while (0)
+
+// 2. Partial sum 512 digits
+#define KERNEL_8x8_ACC8()                    \
+  do {                                       \
+    va0 = vld1q_f16(a);                      \
+    v16 = vld1q_f16(b);                      \
+    v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
+    v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \
+    v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \
+    v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \
+    v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \
+    v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \
+    v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \
+    v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \
+    va1 = vld1q_f16(a + 8);                  \
+    v17 = vld1q_f16(b + 8);                  \
+    v24 = vfmaq_laneq_f16(v24, v17, va1, 0); \
+    v25 = vfmaq_laneq_f16(v25, v17, va1, 1); \
+    v26 = vfmaq_laneq_f16(v26, v17, va1, 2); \
+    v27 = vfmaq_laneq_f16(v27, v17, va1, 3); \
+    v28 = vfmaq_laneq_f16(v28, v17, va1, 4); \
+    v29 = vfmaq_laneq_f16(v29, v17, va1, 5); \
+    v30 = vfmaq_laneq_f16(v30, v17, va1, 6); \
+    v31 = vfmaq_laneq_f16(v31, v17, va1, 7); \
+    va2 = vld1q_f16(a + 16);                 \
+    v18 = vld1q_f16(b + 16);                 \
+    v24 = vfmaq_laneq_f16(v24, v18, va2, 0); \
+    v25 = vfmaq_laneq_f16(v25, v18, va2, 1); \
+    v26 = vfmaq_laneq_f16(v26, v18, va2, 2); \
+    v27 = vfmaq_laneq_f16(v27, v18, va2, 3); \
+    v28 = vfmaq_laneq_f16(v28, v18, va2, 4); \
+    v29 = vfmaq_laneq_f16(v29, v18, va2, 5); \
+    v30 = vfmaq_laneq_f16(v30, v18, va2, 6); \
+    v31 = vfmaq_laneq_f16(v31, v18, va2, 7); \
+    va3 = vld1q_f16(a + 24);                 \
+    v19 = vld1q_f16(b + 24);                 \
+    v24 = vfmaq_laneq_f16(v24, v19, va3, 0); \
+    v25 = vfmaq_laneq_f16(v25, v19, va3, 1); \
+    v26 = vfmaq_laneq_f16(v26, v19, va3, 2); \
+    v27 = vfmaq_laneq_f16(v27, v19, va3, 3); \
+    v28 = vfmaq_laneq_f16(v28, v19, va3, 4); \
+    v29 = vfmaq_laneq_f16(v29, v19, va3, 5); \
+    v30 = vfmaq_laneq_f16(v30, v19, va3, 6); \
+    v31 = vfmaq_laneq_f16(v31, v19, va3, 7); \
+    va4 = vld1q_f16(a + 32);                 \
+    v20 = vld1q_f16(b + 32);                 \
+    v24 = vfmaq_laneq_f16(v24, v20, va4, 0); \
+    v25 = vfmaq_laneq_f16(v25, v20, va4, 1); \
+    v26 = vfmaq_laneq_f16(v26, v20, va4, 2); \
+    v27 = vfmaq_laneq_f16(v27, v20, va4, 3); \
+    v28 = vfmaq_laneq_f16(v28, v20, va4, 4); \
+    v29 = vfmaq_laneq_f16(v29, v20, va4, 5); \
+    v30 = vfmaq_laneq_f16(v30, v20, va4, 6); \
+    v31 = vfmaq_laneq_f16(v31, v20, va4, 7); \
+    va5 = vld1q_f16(a + 40);                 \
+    v21 = vld1q_f16(b + 40);                 \
+    v24 = vfmaq_laneq_f16(v24, v21, va5, 0); \
+    v25 = vfmaq_laneq_f16(v25, v21, va5, 1); \
+    v26 = vfmaq_laneq_f16(v26, v21, va5, 2); \
+    v27 = vfmaq_laneq_f16(v27, v21, va5, 3); \
+    v28 = vfmaq_laneq_f16(v28, v21, va5, 4); \
+    v29 = vfmaq_laneq_f16(v29, v21, va5, 5); \
+    v30 = vfmaq_laneq_f16(v30, v21, va5, 6); \
+    v31 = vfmaq_laneq_f16(v31, v21, va5, 7); \
+    va6 = vld1q_f16(a + 48);                 \
+    v22 = vld1q_f16(b + 48);                 \
+    v24 = vfmaq_laneq_f16(v24, v22, va6, 0); \
+    v25 = vfmaq_laneq_f16(v25, v22, va6, 1); \
+    v26 = vfmaq_laneq_f16(v26, v22, va6, 2); \
+    v27 = vfmaq_laneq_f16(v27, v22, va6, 3); \
+    v28 = vfmaq_laneq_f16(v28, v22, va6, 4); \
+    v29 = vfmaq_laneq_f16(v29, v22, va6, 5); \
+    v30 = vfmaq_laneq_f16(v30, v22, va6, 6); \
+    v31 = vfmaq_laneq_f16(v31, v22, va6, 7); \
+    va7 = vld1q_f16(a + 56);                 \
+    v23 = vld1q_f16(b + 56);                 \
+    v24 = vfmaq_laneq_f16(v24, v23, va7, 0); \
+    v25 = vfmaq_laneq_f16(v25, v23, va7, 1); \
+    v26 = vfmaq_laneq_f16(v26, v23, va7, 2); \
+    v27 = vfmaq_laneq_f16(v27, v23, va7, 3); \
+    v28 = vfmaq_laneq_f16(v28, v23, va7, 4); \
+    v29 = vfmaq_laneq_f16(v29, v23, va7, 5); \
+    v30 = vfmaq_laneq_f16(v30, v23, va7, 6); \
+    v31 = vfmaq_laneq_f16(v31, v23, va7, 7); \
+    __builtin_prefetch(b + 64, 0, 3);        \
+    __builtin_prefetch(a + 64, 0, 3);        \
+    l += 8;                                  \
+    b += 8 * 8;                              \
+    a += 8 * 8;                              \
+  } while (0)
+
+// 3. Partial sum 256 digits
+#define KERNEL_8x8_ACC4()                    \
+  do {                                       \
+    va0 = vld1q_f16(a);                      \
+    v16 = vld1q_f16(b);                      \
+    v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
+    v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \
+    v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \
+    v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \
+    v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \
+    v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \
+    v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \
+    v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \
+    va1 = vld1q_f16(a + 8);                  \
+    v17 = vld1q_f16(b + 8);                  \
+    v24 = vfmaq_laneq_f16(v24, v17, va1, 0); \
+    v25 = vfmaq_laneq_f16(v25, v17, va1, 1); \
+    v26 = vfmaq_laneq_f16(v26, v17, va1, 2); \
+    v27 = vfmaq_laneq_f16(v27, v17, va1, 3); \
+    v28 = vfmaq_laneq_f16(v28, v17, va1, 4); \
+    v29 = vfmaq_laneq_f16(v29, v17, va1, 5); \
+    v30 = vfmaq_laneq_f16(v30, v17, va1, 6); \
+    v31 = vfmaq_laneq_f16(v31, v17, va1, 7); \
+    va2 = vld1q_f16(a + 16);                 \
+    v18 = vld1q_f16(b + 16);                 \
+    v24 = vfmaq_laneq_f16(v24, v18, va2, 0); \
+    v25 = vfmaq_laneq_f16(v25, v18, va2, 1); \
+    v26 = vfmaq_laneq_f16(v26, v18, va2, 2); \
+    v27 = vfmaq_laneq_f16(v27, v18, va2, 3); \
+    v28 = vfmaq_laneq_f16(v28, v18, va2, 4); \
+    v29 = vfmaq_laneq_f16(v29, v18, va2, 5); \
+    v30 = vfmaq_laneq_f16(v30, v18, va2, 6); \
+    v31 = vfmaq_laneq_f16(v31, v18, va2, 7); \
+    va3 = vld1q_f16(a + 24);                 \
+    v19 = vld1q_f16(b + 24);                 \
+    v24 = vfmaq_laneq_f16(v24, v19, va3, 0); \
+    v25 = vfmaq_laneq_f16(v25, v19, va3, 1); \
+    v26 = vfmaq_laneq_f16(v26, v19, va3, 2); \
+    v27 = vfmaq_laneq_f16(v27, v19, va3, 3); \
+    v28 = vfmaq_laneq_f16(v28, v19, va3, 4); \
+    v29 = vfmaq_laneq_f16(v29, v19, va3, 5); \
+    v30 = vfmaq_laneq_f16(v30, v19, va3, 6); \
+    v31 = vfmaq_laneq_f16(v31, v19, va3, 7); \
+    __builtin_prefetch(b + 32, 0, 3);        \
+    __builtin_prefetch(a + 32, 0, 3);        \
+    l += 4;                                  \
+    b += 8 * 4;                              \
+    a += 8 * 4;                              \
+  } while (0)
+
+// 4. Partial sum 64 digits
+#define KERNEL_8x8_ACC1()                    \
+  do {                                       \
+    va0 = vld1q_f16(a);                      \
+    v16 = vld1q_f16(b);                      \
+    v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
+    v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \
+    v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \
+    v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \
+    v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \
+    v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \
+    v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \
+    v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \
+    __builtin_prefetch(b + 8, 0, 3);         \
+    __builtin_prefetch(a + 8, 0, 3);         \
+    l += 1;                                  \
+    b += 8 * 1;                              \
+    a += 8 * 1;                              \
+  } while (0)
+
+#define SAVE_KERNEL_8X8_F16_f32()                                              \
+  do {                                                                         \
+    vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v24))));    \
+    vst1q_f32(c + 4,                                                           \
+              vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v24))));  \
+                                                                               \
+    vst1q_f32(c + ldc,                                                         \
+              vaddq_f32(vld1q_f32(c + ldc), vcvt_f32_f16(vget_low_f16(v25)))); \
+    vst1q_f32(c + 4 + ldc, vaddq_f32(vld1q_f32(c + 4 + ldc),                   \
+                                     vcvt_f32_f16(vget_high_f16(v25))));       \
+                                                                               \
+    vst1q_f32(c + 2 * ldc, vaddq_f32(vld1q_f32(c + 2 * ldc),                   \
+                                     vcvt_f32_f16(vget_low_f16(v26))));        \
+    vst1q_f32(c + 4 + 2 * ldc, vaddq_f32(vld1q_f32(c + 4 + 2 * ldc),           \
+                                         vcvt_f32_f16(vget_high_f16(v26))));   \
+                                                                               \
+    vst1q_f32(c + 3 * ldc, vaddq_f32(vld1q_f32(c + 3 * ldc),                   \
+                                     vcvt_f32_f16(vget_low_f16(v27))));        \
+    vst1q_f32(c + 4 + 3 * ldc, vaddq_f32(vld1q_f32(c + 4 + 3 * ldc),           \
+                                         vcvt_f32_f16(vget_high_f16(v27))));   \
+                                                                               \
+    vst1q_f32(c + 4 * ldc, vaddq_f32(vld1q_f32(c + 4 * ldc),                   \
+                                     vcvt_f32_f16(vget_low_f16(v28))));        \
+    vst1q_f32(c + 4 + 4 * ldc, vaddq_f32(vld1q_f32(c + 4 + 4 * ldc),           \
+                                         vcvt_f32_f16(vget_high_f16(v28))));   \
+                                                                               \
+    vst1q_f32(c + 5 * ldc, vaddq_f32(vld1q_f32(c + 5 * ldc),                   \
+                                     vcvt_f32_f16(vget_low_f16(v29))));        \
+    vst1q_f32(c + 4 + 5 * ldc, vaddq_f32(vld1q_f32(c + 4 + 5 * ldc),           \
+                                         vcvt_f32_f16(vget_high_f16(v29))));   \
+                                                                               \
+    vst1q_f32(c + 6 * ldc, vaddq_f32(vld1q_f32(c + 6 * ldc),                   \
+                                     vcvt_f32_f16(vget_low_f16(v30))));        \
+    vst1q_f32(c + 4 + 6 * ldc, vaddq_f32(vld1q_f32(c + 4 + 6 * ldc),           \
+                                         vcvt_f32_f16(vget_high_f16(v30))));   \
+                                                                               \
+    vst1q_f32(c + 7 * ldc, vaddq_f32(vld1q_f32(c + 7 * ldc),                   \
+                                     vcvt_f32_f16(vget_low_f16(v31))));        \
+    vst1q_f32(c + 4 + 7 * ldc, vaddq_f32(vld1q_f32(c + 4 + 7 * ldc),           \
+                                         vcvt_f32_f16(vget_high_f16(v31))));   \
+  } while (0)
+
+/**
+ * @brief hgemm 8x8 kernel sc = sa * sb
+ *
+ * @param m length of the row of matrix A
+ * @param n length of the col of matrix B
+ * @param k length of the col of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading-dimension of matrix C
+ */
+void hgemm_kernel_8x8(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
+  assert(M > 0 && N > 0 && K > 0);
+  assert(M % 8 == 0 && N % 8 == 0 && K % 4 == 0);
+
+  __fp16 *a = sa, *b = sb, *c = sc;
+  unsigned int i, j, l;
+  for (i = 0; i < M; i += 8) {
+    for (j = 0; j < N; j += 8) {
+      __builtin_prefetch(b, 0, 3);
+      __builtin_prefetch(a, 0, 3);
+
+      float16x8_t v16, v17, v18, v19, v20, v21, v22, v23;
+      float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
+      float16x8_t va0, va1, va2, va3, va4, va5, va6, va7;
+      INIT_KERNEL_8x8();
+      l = 0;
+      for (; l < K;) {
+        KERNEL_8x8_ACC1();
+      }
+      vst1q_f16(c, vaddq_f16(vld1q_f16(c), v24));
+      vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v25));
+      vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v26));
+      vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v27));
+      vst1q_f16(c + 4 * ldc, vaddq_f16(vld1q_f16(c + 4 * ldc), v28));
+      vst1q_f16(c + 5 * ldc, vaddq_f16(vld1q_f16(c + 5 * ldc), v29));
+      vst1q_f16(c + 6 * ldc, vaddq_f16(vld1q_f16(c + 6 * ldc), v30));
+      vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v31));
+      c += 8;
+      a -= 8 * K;
+    }
+    sc += ldc * 8;
+    c = sc;
+    a += 8 * K;
+    b = sb;
+  }
+}
+
+/**
+ * @brief hgemm 8x8 kernel sc = sa * sb
+ *
+ * @param m length of the row of matrix A
+ * @param n length of the col of matrix B
+ * @param k length of the col of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading-dimension of matrix C
+ */
+void hgemm_kernel_8x8(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
+  assert(M > 0 && N > 0 && K > 0);
+  assert(M % 8 == 0 && N % 8 == 0 && K % 8 == 0);
+
+  __fp16 *a = sa, *b = sb;
+  float *c = sc;
+  unsigned int i, j, l;
+  unsigned int K4 = (K >> 2) << 2;
+  unsigned int K8 = (K >> 3) << 3;
+  unsigned int K16 = (K >> 4) << 4;
+  for (i = 0; i < M; i += 8) {
+    for (j = 0; j < N; j += 8) {
+      __builtin_prefetch(b, 0, 3);
+      __builtin_prefetch(a, 0, 3);
+
+      float16x8_t v16, v17, v18, v19, v20, v21, v22, v23;
+      float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
+      float16x8_t va0, va1, va2, va3, va4, va5, va6, va7;
+      l = 0;
+      for (; l < K16;) {
+        INIT_KERNEL_8x8();
+        KERNEL_8x8_ACC16();
+        SAVE_KERNEL_8X8_F16_f32();
+      }
+      for (; l < K8;) {
+        INIT_KERNEL_8x8();
+        KERNEL_8x8_ACC8();
+        SAVE_KERNEL_8X8_F16_f32();
+      }
+      for (; l < K4;) {
+        INIT_KERNEL_8x8();
+        KERNEL_8x8_ACC4();
+        SAVE_KERNEL_8X8_F16_f32();
+      }
+      for (; l < K;) {
+        INIT_KERNEL_8x8();
+        KERNEL_8x8_ACC1();
+        SAVE_KERNEL_8X8_F16_f32();
+      }
+
+      c += 8;
+      a -= 8 * K;
+    }
+    sc += ldc * 8;
+    c = sc;
+    a += 8 * K;
+    b = sb;
+  }
+}
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel/meson.build b/nntrainer/tensor/hgemm/hgemm_kernel/meson.build
new file mode 100644 (file)
index 0000000..1b6cc50
--- /dev/null
@@ -0,0 +1,22 @@
+hgemm_kernel_headers = [
+  'hgemm_kernel.h',
+]
+
+
+hgemm_kernel_sources = [
+    'hgemm_kernel_1x4.cpp',
+    'hgemm_kernel_1x8.cpp',
+    'hgemm_kernel_4x4.cpp',
+    'hgemm_kernel_4x8.cpp',
+    'hgemm_kernel_8x8.cpp',
+    'hgemm_kernel_8x16.cpp',
+]
+
+foreach s : hgemm_kernel_sources
+  nntrainer_sources += meson.current_source_dir() / s
+endforeach
+
+foreach h : hgemm_kernel_headers
+  nntrainer_headers += meson.current_source_dir() / h
+endforeach
+
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_1x4.h b/nntrainer/tensor/hgemm/hgemm_kernel_1x4.h
deleted file mode 100644 (file)
index d001887..0000000
+++ /dev/null
@@ -1,145 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-/**
- * Copyright (C) 2024 Debadri Samaddar <s.debadri@samsung.com>
- *
- * @file   hgemm_kernel_1x4.h
- * @date   23 April 2024
- * @see    https://github.com/nnstreamer/nntrainer
- * @author Debadri Samaddar <s.debadri@samsung.com>
- * @bug    No known bugs except for NYI items
- * @brief  This is half-precision GEMM 1x4 kernel
- *
- */
-
-#include <stdlib.h>
-#include <arm_neon.h>
-#include <assert.h>
-
-/**
- * @brief hgemm 1x4 kernel sc = sa * sb
- *
- * @param m length of the row of matrix A
- * @param n length of the col of matrix B
- * @param k length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading dimension of matrix C
- */
-void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K,
-                      __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
-  assert(M > 0 && N > 0 && K > 0);
-  assert(N % 4 == 0);
-
-  __fp16 *a = sa, *b = sb, *c = sc;
-  unsigned int i, j, l;
-  for (i = 0; i < M; i++) {
-    for (j = 0; j < N; j += 4) {
-      __builtin_prefetch(b, 0, 3);
-      __builtin_prefetch(a, 0, 3);
-
-      for (l = 0; l < K; l += 4) {
-        float16x4_t v24 = {0.F};
-        float16x4_t v0 = vld1_f16(b);
-        float16_t v16 = *a;
-
-        v24 = vfma_n_f16(v24, v0, v16);
-
-        float16x4_t v1 = vld1_f16(b + 4);
-        float16_t v17 = *(a + 1);
-
-        v24 = vfma_n_f16(v24, v1, v17);
-
-        float16x4_t v2 = vld1_f16(b + 8);
-        float16_t v18 = *(a + 2);
-
-        v24 = vfma_n_f16(v24, v2, v18);
-
-        float16x4_t v3 = vld1_f16(b + 12);
-        float16_t v19 = *(a + 3);
-
-        v24 = vfma_n_f16(v24, v3, v19);
-
-        __builtin_prefetch(b + 16, 0, 3);
-        __builtin_prefetch(a + 4, 0, 3);
-
-        b += 16;
-        a += 4;
-
-        v24 = vadd_f16(vld1_f16(c), v24);
-
-        vst1_f16(c, v24);
-      }
-      c += 4;
-      a -= K;
-    }
-    sc += ldc;
-    c = sc;
-    a += K;
-    b = sb;
-  }
-}
-
-/**
- * @brief hgemm 1x4 kernel sc = sa * sb
- *
- * @param m length of the row of matrix A
- * @param n length of the col of matrix B
- * @param k length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading dimension of matrix C
- */
-void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K,
-                      __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
-  assert(M > 0 && N > 0 && K > 0);
-  assert(N % 4 == 0);
-
-  __fp16 *a = sa, *b = sb;
-  float *c = sc;
-  unsigned int i, j, l;
-  for (i = 0; i < M; i++) {
-    for (j = 0; j < N; j += 4) {
-      __builtin_prefetch(b, 0, 3);
-      __builtin_prefetch(a, 0, 3);
-
-      for (l = 0; l < K; l += 4) {
-        float16x4_t v24 = {0.F};
-        float16x4_t v0 = vld1_f16(b);
-        float16_t v16 = *a;
-
-        v24 = vfma_n_f16(v24, v0, v16);
-
-        float16x4_t v1 = vld1_f16(b + 4);
-        float16_t v17 = *(a + 1);
-
-        v24 = vfma_n_f16(v24, v1, v17);
-
-        float16x4_t v2 = vld1_f16(b + 8);
-        float16_t v18 = *(a + 2);
-
-        v24 = vfma_n_f16(v24, v2, v18);
-
-        float16x4_t v3 = vld1_f16(b + 12);
-        float16_t v19 = *(a + 3);
-
-        v24 = vfma_n_f16(v24, v3, v19);
-
-        __builtin_prefetch(b + 16, 0, 3);
-        __builtin_prefetch(a + 4, 0, 3);
-
-        b += 16;
-        a += 4;
-
-        vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(v24)));
-      }
-      c += 4;
-      a -= K;
-    }
-    sc += ldc;
-    c = sc;
-    a += K;
-    b = sb;
-  }
-}
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_1x8.h b/nntrainer/tensor/hgemm/hgemm_kernel_1x8.h
deleted file mode 100644 (file)
index 3114ca3..0000000
+++ /dev/null
@@ -1,184 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-/**
- * Copyright (C) 2024 Debadri Samaddar <s.debadri@samsung.com>
- *
- * @file   hgemm_kernel_1x8.h
- * @date   05 April 2024
- * @see    https://github.com/nnstreamer/nntrainer
- * @author Debadri Samaddar <s.debadri@samsung.com>
- * @author Sungsik Kong <ss.kong@samsung.com>
- * @bug    No known bugs except for NYI items
- * @brief  This is half-precision GEMM 1x8 kernel
- *
- */
-
-#include <arm_neon.h>
-#include <assert.h>
-#include <stdlib.h>
-
-// 1. Partial sum 64 digits : worst accuracy, best latency
-#define KERNEL_1x8_ACC8()           \
-  do {                              \
-    v0 = vdupq_n_f16(0.F);          \
-    dv0 = *a;                       \
-    v24 = vld1q_f16(b);             \
-    v0 = vfmaq_n_f16(v0, v24, dv0); \
-    dv1 = *(a + 1);                 \
-    v25 = vld1q_f16(b + 8);         \
-    v0 = vfmaq_n_f16(v0, v25, dv1); \
-    dv2 = *(a + 2);                 \
-    v26 = vld1q_f16(b + 16);        \
-    v0 = vfmaq_n_f16(v0, v26, dv2); \
-    dv3 = *(a + 3);                 \
-    v27 = vld1q_f16(b + 24);        \
-    v0 = vfmaq_n_f16(v0, v27, dv3); \
-    dv4 = *(a + 4);                 \
-    v28 = vld1q_f16(b + 32);        \
-    v0 = vfmaq_n_f16(v0, v28, dv4); \
-    dv5 = *(a + 5);                 \
-    v29 = vld1q_f16(b + 40);        \
-    v0 = vfmaq_n_f16(v0, v29, dv5); \
-    dv6 = *(a + 6);                 \
-    v30 = vld1q_f16(b + 48);        \
-    v0 = vfmaq_n_f16(v0, v30, dv6); \
-    dv7 = *(a + 7);                 \
-    v31 = vld1q_f16(b + 56);        \
-    v0 = vfmaq_n_f16(v0, v31, dv7); \
-    l += 8;                         \
-    b += 8 * 8;                     \
-    a += 8;                         \
-  } while (0)
-
-// 2. Partial sum 32 digits : medium accuracy, medium latency
-#define KERNEL_1x8_ACC4()           \
-  do {                              \
-    v0 = vdupq_n_f16(0.F);          \
-    dv0 = *a;                       \
-    v24 = vld1q_f16(b);             \
-    v0 = vfmaq_n_f16(v0, v24, dv0); \
-    dv1 = *(a + 1);                 \
-    v25 = vld1q_f16(b + 8);         \
-    v0 = vfmaq_n_f16(v0, v25, dv1); \
-    dv2 = *(a + 2);                 \
-    v26 = vld1q_f16(b + 16);        \
-    v0 = vfmaq_n_f16(v0, v26, dv2); \
-    dv3 = *(a + 3);                 \
-    v27 = vld1q_f16(b + 24);        \
-    v0 = vfmaq_n_f16(v0, v27, dv3); \
-    l += 4;                         \
-    b += 8 * 4;                     \
-    a += 4;                         \
-  } while (0)
-
-// 3. Partial sum 8 digits : Best accuracy, worst latency
-#define KERNEL_1x8_ACC1()           \
-  do {                              \
-    v0 = vdupq_n_f16(0.F);          \
-    dv0 = *(a);                     \
-    v24 = vld1q_f16(b);             \
-    v0 = vfmaq_n_f16(v0, v24, dv0); \
-    l += 1;                         \
-    b += 8 * 1;                     \
-    a++;                            \
-  } while (0)
-
-/**
- * @brief hgemm 1x8 kernel sc = sa * sb
- *
- * @param m length of the row of matrix A
- * @param n length of the col of matrix B
- * @param k length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading-dimension of matrix C
- */
-void hgemm_kernel_1x8(unsigned int M, unsigned int N, unsigned int K,
-                      __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
-  assert(M > 0 && N > 0 && K > 0);
-  assert(N % 8 == 0);
-
-  __fp16 *a = sa, *b = sb, *c = sc;
-  unsigned int k8 = (K >> 3) << 3;
-  unsigned int i, j, l;
-  for (i = 0; i < M; i++) {
-    for (j = 0; j < N; j += 8) {
-      __builtin_prefetch(b, 0, 3);
-      __builtin_prefetch(a, 0, 3);
-      float16x8_t v0;
-      float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
-      float16_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
-      l = 0;
-      for (; l < k8;) {
-        KERNEL_1x8_ACC8();
-
-        vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0));
-      }
-      for (; l < K;) {
-        KERNEL_1x8_ACC1();
-
-        vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0));
-      }
-      c += 8;
-      a -= K;
-    }
-    sc += ldc;
-    c = sc;
-    a += K;
-    b = sb;
-  }
-}
-
-/**
- * @brief hgemm 1x8 kernel sc = sa * sb
- *
- * @param m length of the row of matrix A
- * @param n length of the col of matrix B
- * @param k length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading-dimension of matrix C
- */
-void hgemm_kernel_1x8(unsigned int M, unsigned int N, unsigned int K,
-                      __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
-  assert(M > 0 && N > 0 && K > 0);
-  assert(N % 8 == 0);
-
-  __fp16 *a = sa, *b = sb;
-  float *c = sc;
-  unsigned int k8 = (K >> 3) << 3;
-  unsigned int i, j, l;
-  for (i = 0; i < M; i++) {
-    for (j = 0; j < N; j += 8) {
-      __builtin_prefetch(b, 0, 3);
-      __builtin_prefetch(a, 0, 3);
-      float16x8_t v0;
-      float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
-      float16_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
-      l = 0;
-      for (; l < k8;) {
-        KERNEL_1x8_ACC8();
-
-        vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0))));
-
-        vst1q_f32(c + 4,
-                  vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0))));
-      }
-      for (; l < K;) {
-        KERNEL_1x8_ACC1();
-
-        vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0))));
-
-        vst1q_f32(c + 4,
-                  vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0))));
-      }
-      c += 8;
-      a -= K;
-    }
-    sc += ldc;
-    c = sc;
-    a += K;
-    b = sb;
-  }
-}
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_4x4.h b/nntrainer/tensor/hgemm/hgemm_kernel_4x4.h
deleted file mode 100644 (file)
index 18e86cc..0000000
+++ /dev/null
@@ -1,359 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-/**
- * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
- *
- * @file   hgemm_kernel_4x4.h
- * @date   01 April 2024
- * @see    https://github.com/nnstreamer/nntrainer
- * @author Sungsik Kong <ss.kong@samsung.com>
- * @bug    No known bugs except for NYI items
- * @brief  This is half-precision GEMM 4x4 kernel
- *
- */
-
-#include <arm_neon.h>
-#include <assert.h>
-#include <stdlib.h>
-
-#define INIT_KERNEL_4x4()  \
-  do {                     \
-    v24 = vdup_n_f16(0.F); \
-    v25 = vdup_n_f16(0.F); \
-    v26 = vdup_n_f16(0.F); \
-    v27 = vdup_n_f16(0.F); \
-  } while (0)
-
-// 1. Partial sum 256 digits
-#define KERNEL_4x4_ACC16()                 \
-  do {                                     \
-    dv0 = vld1_f16(a);                     \
-    vb0 = vld1_f16(b);                     \
-    v24 = vfma_lane_f16(v24, vb0, dv0, 0); \
-    v25 = vfma_lane_f16(v25, vb0, dv0, 1); \
-    v26 = vfma_lane_f16(v26, vb0, dv0, 2); \
-    v27 = vfma_lane_f16(v27, vb0, dv0, 3); \
-    dv1 = vld1_f16(a + 4);                 \
-    vb1 = vld1_f16(b + 4);                 \
-    v24 = vfma_lane_f16(v24, vb1, dv1, 0); \
-    v25 = vfma_lane_f16(v25, vb1, dv1, 1); \
-    v26 = vfma_lane_f16(v26, vb1, dv1, 2); \
-    v27 = vfma_lane_f16(v27, vb1, dv1, 3); \
-    dv2 = vld1_f16(a + 4 * 2);             \
-    vb2 = vld1_f16(b + 4 * 2);             \
-    v24 = vfma_lane_f16(v24, vb2, dv2, 0); \
-    v25 = vfma_lane_f16(v25, vb2, dv2, 1); \
-    v26 = vfma_lane_f16(v26, vb2, dv2, 2); \
-    v27 = vfma_lane_f16(v27, vb2, dv2, 3); \
-    dv3 = vld1_f16(a + 4 * 3);             \
-    vb3 = vld1_f16(b + 4 * 3);             \
-    v24 = vfma_lane_f16(v24, vb3, dv3, 0); \
-    v25 = vfma_lane_f16(v25, vb3, dv3, 1); \
-    v26 = vfma_lane_f16(v26, vb3, dv3, 2); \
-    v27 = vfma_lane_f16(v27, vb3, dv3, 3); \
-    dv4 = vld1_f16(a + 4 * 4);             \
-    vb4 = vld1_f16(b + 4 * 4);             \
-    v24 = vfma_lane_f16(v24, vb4, dv4, 0); \
-    v25 = vfma_lane_f16(v25, vb4, dv4, 1); \
-    v26 = vfma_lane_f16(v26, vb4, dv4, 2); \
-    v27 = vfma_lane_f16(v27, vb4, dv4, 3); \
-    dv5 = vld1_f16(a + 4 * 5);             \
-    vb5 = vld1_f16(b + 4 * 5);             \
-    v24 = vfma_lane_f16(v24, vb5, dv5, 0); \
-    v25 = vfma_lane_f16(v25, vb5, dv5, 1); \
-    v26 = vfma_lane_f16(v26, vb5, dv5, 2); \
-    v27 = vfma_lane_f16(v27, vb5, dv5, 3); \
-    dv6 = vld1_f16(a + 4 * 6);             \
-    vb6 = vld1_f16(b + 4 * 6);             \
-    v24 = vfma_lane_f16(v24, vb6, dv6, 0); \
-    v25 = vfma_lane_f16(v25, vb6, dv6, 1); \
-    v26 = vfma_lane_f16(v26, vb6, dv6, 2); \
-    v27 = vfma_lane_f16(v27, vb6, dv6, 3); \
-    dv7 = vld1_f16(a + 4 * 7);             \
-    vb7 = vld1_f16(b + 4 * 7);             \
-    v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
-    v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
-    v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
-    v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
-    dv7 = vld1_f16(a + 4 * 8);             \
-    vb7 = vld1_f16(b + 4 * 8);             \
-    v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
-    v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
-    v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
-    v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
-    dv7 = vld1_f16(a + 4 * 9);             \
-    vb7 = vld1_f16(b + 4 * 9);             \
-    v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
-    v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
-    v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
-    v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
-    dv7 = vld1_f16(a + 4 * 10);            \
-    vb7 = vld1_f16(b + 4 * 10);            \
-    v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
-    v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
-    v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
-    v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
-    dv7 = vld1_f16(a + 4 * 11);            \
-    vb7 = vld1_f16(b + 4 * 11);            \
-    v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
-    v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
-    v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
-    v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
-    dv7 = vld1_f16(a + 4 * 12);            \
-    vb7 = vld1_f16(b + 4 * 12);            \
-    v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
-    v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
-    v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
-    v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
-    dv7 = vld1_f16(a + 4 * 13);            \
-    vb7 = vld1_f16(b + 4 * 13);            \
-    v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
-    v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
-    v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
-    v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
-    dv7 = vld1_f16(a + 4 * 14);            \
-    vb7 = vld1_f16(b + 4 * 14);            \
-    v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
-    v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
-    v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
-    v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
-    dv7 = vld1_f16(a + 4 * 15);            \
-    vb7 = vld1_f16(b + 4 * 15);            \
-    v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
-    v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
-    v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
-    v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
-    l += 16;                               \
-    __builtin_prefetch(b + 64, 0, 3);      \
-    __builtin_prefetch(a + 64, 0, 3);      \
-    b += 4 * 16;                           \
-    a += 4 * 16;                           \
-  } while (0)
-
-// 2. Partial sum 128 digits
-#define KERNEL_4x4_ACC8()                  \
-  do {                                     \
-    dv0 = vld1_f16(a);                     \
-    vb0 = vld1_f16(b);                     \
-    v24 = vfma_lane_f16(v24, vb0, dv0, 0); \
-    v25 = vfma_lane_f16(v25, vb0, dv0, 1); \
-    v26 = vfma_lane_f16(v26, vb0, dv0, 2); \
-    v27 = vfma_lane_f16(v27, vb0, dv0, 3); \
-    dv1 = vld1_f16(a + 4);                 \
-    vb1 = vld1_f16(b + 4);                 \
-    v24 = vfma_lane_f16(v24, vb1, dv1, 0); \
-    v25 = vfma_lane_f16(v25, vb1, dv1, 1); \
-    v26 = vfma_lane_f16(v26, vb1, dv1, 2); \
-    v27 = vfma_lane_f16(v27, vb1, dv1, 3); \
-    dv2 = vld1_f16(a + 8);                 \
-    vb2 = vld1_f16(b + 8);                 \
-    v24 = vfma_lane_f16(v24, vb2, dv2, 0); \
-    v25 = vfma_lane_f16(v25, vb2, dv2, 1); \
-    v26 = vfma_lane_f16(v26, vb2, dv2, 2); \
-    v27 = vfma_lane_f16(v27, vb2, dv2, 3); \
-    dv3 = vld1_f16(a + 12);                \
-    vb3 = vld1_f16(b + 12);                \
-    v24 = vfma_lane_f16(v24, vb3, dv3, 0); \
-    v25 = vfma_lane_f16(v25, vb3, dv3, 1); \
-    v26 = vfma_lane_f16(v26, vb3, dv3, 2); \
-    v27 = vfma_lane_f16(v27, vb3, dv3, 3); \
-    dv4 = vld1_f16(a + 16);                \
-    vb4 = vld1_f16(b + 16);                \
-    v24 = vfma_lane_f16(v24, vb4, dv4, 0); \
-    v25 = vfma_lane_f16(v25, vb4, dv4, 1); \
-    v26 = vfma_lane_f16(v26, vb4, dv4, 2); \
-    v27 = vfma_lane_f16(v27, vb4, dv4, 3); \
-    dv5 = vld1_f16(a + 20);                \
-    vb5 = vld1_f16(b + 20);                \
-    v24 = vfma_lane_f16(v24, vb5, dv5, 0); \
-    v25 = vfma_lane_f16(v25, vb5, dv5, 1); \
-    v26 = vfma_lane_f16(v26, vb5, dv5, 2); \
-    v27 = vfma_lane_f16(v27, vb5, dv5, 3); \
-    dv6 = vld1_f16(a + 24);                \
-    vb6 = vld1_f16(b + 24);                \
-    v24 = vfma_lane_f16(v24, vb6, dv6, 0); \
-    v25 = vfma_lane_f16(v25, vb6, dv6, 1); \
-    v26 = vfma_lane_f16(v26, vb6, dv6, 2); \
-    v27 = vfma_lane_f16(v27, vb6, dv6, 3); \
-    dv7 = vld1_f16(a + 28);                \
-    vb7 = vld1_f16(b + 28);                \
-    v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
-    v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
-    v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
-    v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
-    l += 8;                                \
-    __builtin_prefetch(b + 32, 0, 3);      \
-    __builtin_prefetch(a + 32, 0, 3);      \
-    b += 4 * 8;                            \
-    a += 4 * 8;                            \
-  } while (0)
-
-// 3. Partial sum 16 digits
-#define KERNEL_4x4_ACC1()                  \
-  do {                                     \
-    dv0 = vld1_f16(a);                     \
-    vb0 = vld1_f16(b);                     \
-    v24 = vfma_lane_f16(v24, vb0, dv0, 0); \
-    v25 = vfma_lane_f16(v25, vb0, dv0, 1); \
-    v26 = vfma_lane_f16(v26, vb0, dv0, 2); \
-    v27 = vfma_lane_f16(v27, vb0, dv0, 3); \
-    l += 1;                                \
-    __builtin_prefetch(b + 4, 0, 3);       \
-    __builtin_prefetch(a + 4, 0, 3);       \
-    b += 4 * 1;                            \
-    a += 4 * 1;                            \
-  } while (0)
-
-#define SAVE_KERNEL_4X4_F16_F32()                                         \
-  do {                                                                    \
-    vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(v24)));             \
-    vst1q_f32(c + ldc, vaddq_f32(vld1q_f32(c + ldc), vcvt_f32_f16(v25))); \
-    vst1q_f32(c + 2 * ldc,                                                \
-              vaddq_f32(vld1q_f32(c + 2 * ldc), vcvt_f32_f16(v26)));      \
-    vst1q_f32(c + 3 * ldc,                                                \
-              vaddq_f32(vld1q_f32(c + 3 * ldc), vcvt_f32_f16(v27)));      \
-  } while (0)
-
-/**
- * @brief hgemm 4x4 kernel sc = sa * sb
- *
- * @param m length of the row of matrix A
- * @param n length of the col of matrix B
- * @param k length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading dimension of matrix C
- */
-void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
-                      __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
-  assert(M > 0 && N > 0 && K > 0);
-  assert(M % 4 == 0 && N % 4 == 0 && K % 4 == 0);
-
-  __fp16 *a = sa, *b = sb, *c = sc;
-  unsigned int i, j, l;
-  for (i = 0; i < M; i += 4) {
-    for (j = 0; j < N; j += 4) {
-      __builtin_prefetch(b, 0, 3);
-      __builtin_prefetch(a, 0, 3);
-
-      float16x4_t v24;
-      float16x4_t v25;
-      float16x4_t v26;
-      float16x4_t v27;
-      INIT_KERNEL_4x4();
-
-      for (l = 0; l < K; l += 4) {
-        float16x4_t v0 = vld1_f16(b);
-        float16x4_t v16 = vld1_f16(a);
-
-        v24 = vfma_lane_f16(v24, v0, v16, 0);
-        v25 = vfma_lane_f16(v25, v0, v16, 1);
-        v26 = vfma_lane_f16(v26, v0, v16, 2);
-        v27 = vfma_lane_f16(v27, v0, v16, 3);
-
-        float16x4_t v1 = vld1_f16(b + 4);
-        float16x4_t v17 = vld1_f16(a + 4);
-
-        v24 = vfma_lane_f16(v24, v1, v17, 0);
-        v25 = vfma_lane_f16(v25, v1, v17, 1);
-        v26 = vfma_lane_f16(v26, v1, v17, 2);
-        v27 = vfma_lane_f16(v27, v1, v17, 3);
-
-        float16x4_t v2 = vld1_f16(b + 8);
-        float16x4_t v18 = vld1_f16(a + 8);
-
-        v24 = vfma_lane_f16(v24, v2, v18, 0);
-        v25 = vfma_lane_f16(v25, v2, v18, 1);
-        v26 = vfma_lane_f16(v26, v2, v18, 2);
-        v27 = vfma_lane_f16(v27, v2, v18, 3);
-
-        float16x4_t v3 = vld1_f16(b + 12);
-        float16x4_t v19 = vld1_f16(a + 12);
-
-        v24 = vfma_lane_f16(v24, v3, v19, 0);
-        v25 = vfma_lane_f16(v25, v3, v19, 1);
-        v26 = vfma_lane_f16(v26, v3, v19, 2);
-        v27 = vfma_lane_f16(v27, v3, v19, 3);
-
-        __builtin_prefetch(b + 16, 0, 3);
-        __builtin_prefetch(a + 16, 0, 3);
-
-        b += 16;
-        a += 16;
-      }
-
-      v24 = vadd_f16(vld1_f16(c), v24);
-      v25 = vadd_f16(vld1_f16(c + ldc), v25);
-      v26 = vadd_f16(vld1_f16(c + 2 * ldc), v26);
-      v27 = vadd_f16(vld1_f16(c + 3 * ldc), v27);
-
-      vst1_f16(c, v24);
-      vst1_f16(c + ldc, v25);
-      vst1_f16(c + 2 * ldc, v26);
-      vst1_f16(c + 3 * ldc, v27);
-
-      c += 4;
-      a -= 4 * K;
-    }
-    sc += ldc * 4;
-    c = sc;
-    a += 4 * K;
-    b = sb;
-  }
-}
-
-/**
- * @brief hgemm 4x4 kernel sc = sa * sb
- *
- * @param m length of the row of matrix A
- * @param n length of the col of matrix B
- * @param k length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading dimension of matrix C
- */
-void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
-                      __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
-  assert(M > 0 && N > 0 && K > 0);
-  assert(M % 4 == 0 && N % 4 == 0 && K % 4 == 0);
-
-  __fp16 *a = sa, *b = sb;
-  float *c = sc;
-  unsigned int i, j, l;
-  unsigned int K16 = (K >> 4) << 4;
-  unsigned int K8 = (K >> 3) << 3;
-  for (i = 0; i < M; i += 4) {
-    for (j = 0; j < N; j += 4) {
-      __builtin_prefetch(b, 0, 3);
-      __builtin_prefetch(a, 0, 3);
-
-      float16x4_t v24, v25, v26, v27;
-      float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
-      float16x4_t vb0, vb1, vb2, vb3, vb4, vb5, vb6, vb7;
-      l = 0;
-      for (; l < K16;) {
-        INIT_KERNEL_4x4();
-        KERNEL_4x4_ACC16();
-        SAVE_KERNEL_4X4_F16_F32();
-      }
-      for (; l < K8;) {
-        INIT_KERNEL_4x4();
-        KERNEL_4x4_ACC8();
-        SAVE_KERNEL_4X4_F16_F32();
-      }
-      for (; l < K;) {
-        INIT_KERNEL_4x4();
-        KERNEL_4x4_ACC1();
-        SAVE_KERNEL_4X4_F16_F32();
-      }
-
-      c += 4;
-      a -= 4 * K;
-    }
-    sc += ldc * 4;
-    c = sc;
-    a += 4 * K;
-    b = sb;
-  }
-}
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_4x8.h b/nntrainer/tensor/hgemm/hgemm_kernel_4x8.h
deleted file mode 100644 (file)
index b1757bb..0000000
+++ /dev/null
@@ -1,366 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-/**
- * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
- *
- * @file   hgemm_kernel_4x8.h
- * @date   03 April 2024
- * @see    https://github.com/nnstreamer/nntrainer
- * @author Sungsik Kong <ss.kong@samsung.com>
- * @bug    No known bugs except for NYI items
- * @brief  This is half-precision GEMM 8x8 kernel
- *
- */
-
-#include <arm_neon.h>
-#include <assert.h>
-#include <stdlib.h>
-
-#define INIT_KERNEL_4X8()  \
-  do {                     \
-    v0 = vdupq_n_f16(0.F); \
-    v3 = vdupq_n_f16(0.F); \
-    v6 = vdupq_n_f16(0.F); \
-    v9 = vdupq_n_f16(0.F); \
-  } while (0)
-
-// 1. Partial sum 256 digits
-#define KERNEL_4x8_ACC16()                \
-  do {                                    \
-    dv0 = vld1_f16(a);                    \
-    v24 = vld1q_f16(b);                   \
-    v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
-    v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \
-    v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \
-    v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \
-    dv1 = vld1_f16(a + 4);                \
-    v25 = vld1q_f16(b + 8);               \
-    v0 = vfmaq_lane_f16(v0, v25, dv1, 0); \
-    v3 = vfmaq_lane_f16(v3, v25, dv1, 1); \
-    v6 = vfmaq_lane_f16(v6, v25, dv1, 2); \
-    v9 = vfmaq_lane_f16(v9, v25, dv1, 3); \
-    dv2 = vld1_f16(a + 4 * 2);            \
-    v26 = vld1q_f16(b + 8 * 2);           \
-    v0 = vfmaq_lane_f16(v0, v26, dv2, 0); \
-    v3 = vfmaq_lane_f16(v3, v26, dv2, 1); \
-    v6 = vfmaq_lane_f16(v6, v26, dv2, 2); \
-    v9 = vfmaq_lane_f16(v9, v26, dv2, 3); \
-    dv3 = vld1_f16(a + 4 * 3);            \
-    v27 = vld1q_f16(b + 8 * 3);           \
-    v0 = vfmaq_lane_f16(v0, v27, dv3, 0); \
-    v3 = vfmaq_lane_f16(v3, v27, dv3, 1); \
-    v6 = vfmaq_lane_f16(v6, v27, dv3, 2); \
-    v9 = vfmaq_lane_f16(v9, v27, dv3, 3); \
-    dv4 = vld1_f16(a + 4 * 4);            \
-    v28 = vld1q_f16(b + 8 * 4);           \
-    v0 = vfmaq_lane_f16(v0, v28, dv4, 0); \
-    v3 = vfmaq_lane_f16(v3, v28, dv4, 1); \
-    v6 = vfmaq_lane_f16(v6, v28, dv4, 2); \
-    v9 = vfmaq_lane_f16(v9, v28, dv4, 3); \
-    dv5 = vld1_f16(a + 4 * 5);            \
-    v29 = vld1q_f16(b + 8 * 5);           \
-    v0 = vfmaq_lane_f16(v0, v29, dv5, 0); \
-    v3 = vfmaq_lane_f16(v3, v29, dv5, 1); \
-    v6 = vfmaq_lane_f16(v6, v29, dv5, 2); \
-    v9 = vfmaq_lane_f16(v9, v29, dv5, 3); \
-    dv6 = vld1_f16(a + 4 * 6);            \
-    v30 = vld1q_f16(b + 8 * 6);           \
-    v0 = vfmaq_lane_f16(v0, v30, dv6, 0); \
-    v3 = vfmaq_lane_f16(v3, v30, dv6, 1); \
-    v6 = vfmaq_lane_f16(v6, v30, dv6, 2); \
-    v9 = vfmaq_lane_f16(v9, v30, dv6, 3); \
-    dv7 = vld1_f16(a + 4 * 7);            \
-    v31 = vld1q_f16(b + 8 * 7);           \
-    v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
-    v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
-    v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
-    v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
-    dv7 = vld1_f16(a + 4 * 8);            \
-    v31 = vld1q_f16(b + 8 * 8);           \
-    v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
-    v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
-    v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
-    v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
-    dv7 = vld1_f16(a + 4 * 9);            \
-    v31 = vld1q_f16(b + 8 * 9);           \
-    v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
-    v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
-    v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
-    v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
-    dv7 = vld1_f16(a + 4 * 10);           \
-    v31 = vld1q_f16(b + 8 * 10);          \
-    v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
-    v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
-    v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
-    v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
-    dv7 = vld1_f16(a + 4 * 11);           \
-    v31 = vld1q_f16(b + 8 * 11);          \
-    v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
-    v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
-    v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
-    v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
-    dv7 = vld1_f16(a + 4 * 12);           \
-    v31 = vld1q_f16(b + 8 * 12);          \
-    v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
-    v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
-    v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
-    v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
-    dv7 = vld1_f16(a + 4 * 13);           \
-    v31 = vld1q_f16(b + 8 * 13);          \
-    v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
-    v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
-    v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
-    v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
-    dv7 = vld1_f16(a + 4 * 14);           \
-    v31 = vld1q_f16(b + 8 * 14);          \
-    v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
-    v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
-    v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
-    v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
-    dv7 = vld1_f16(a + 4 * 15);           \
-    v31 = vld1q_f16(b + 8 * 15);          \
-    v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
-    v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
-    v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
-    v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
-    l += 16;                              \
-    __builtin_prefetch(b + 128, 0, 3);    \
-    __builtin_prefetch(a + 64, 0, 3);     \
-    b += 8 * 16;                          \
-    a += 4 * 16;                          \
-  } while (0)
-
-// 1. Partial sum 256 digits
-#define KERNEL_4x8_ACC8()                 \
-  do {                                    \
-    dv0 = vld1_f16(a);                    \
-    v24 = vld1q_f16(b);                   \
-    v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
-    v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \
-    v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \
-    v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \
-    dv1 = vld1_f16(a + 4);                \
-    v25 = vld1q_f16(b + 8);               \
-    v0 = vfmaq_lane_f16(v0, v25, dv1, 0); \
-    v3 = vfmaq_lane_f16(v3, v25, dv1, 1); \
-    v6 = vfmaq_lane_f16(v6, v25, dv1, 2); \
-    v9 = vfmaq_lane_f16(v9, v25, dv1, 3); \
-    dv2 = vld1_f16(a + 8);                \
-    v26 = vld1q_f16(b + 16);              \
-    v0 = vfmaq_lane_f16(v0, v26, dv2, 0); \
-    v3 = vfmaq_lane_f16(v3, v26, dv2, 1); \
-    v6 = vfmaq_lane_f16(v6, v26, dv2, 2); \
-    v9 = vfmaq_lane_f16(v9, v26, dv2, 3); \
-    dv3 = vld1_f16(a + 12);               \
-    v27 = vld1q_f16(b + 24);              \
-    v0 = vfmaq_lane_f16(v0, v27, dv3, 0); \
-    v3 = vfmaq_lane_f16(v3, v27, dv3, 1); \
-    v6 = vfmaq_lane_f16(v6, v27, dv3, 2); \
-    v9 = vfmaq_lane_f16(v9, v27, dv3, 3); \
-    dv4 = vld1_f16(a + 16);               \
-    v28 = vld1q_f16(b + 32);              \
-    v0 = vfmaq_lane_f16(v0, v28, dv4, 0); \
-    v3 = vfmaq_lane_f16(v3, v28, dv4, 1); \
-    v6 = vfmaq_lane_f16(v6, v28, dv4, 2); \
-    v9 = vfmaq_lane_f16(v9, v28, dv4, 3); \
-    dv5 = vld1_f16(a + 20);               \
-    v29 = vld1q_f16(b + 40);              \
-    v0 = vfmaq_lane_f16(v0, v29, dv5, 0); \
-    v3 = vfmaq_lane_f16(v3, v29, dv5, 1); \
-    v6 = vfmaq_lane_f16(v6, v29, dv5, 2); \
-    v9 = vfmaq_lane_f16(v9, v29, dv5, 3); \
-    dv6 = vld1_f16(a + 24);               \
-    v30 = vld1q_f16(b + 48);              \
-    v0 = vfmaq_lane_f16(v0, v30, dv6, 0); \
-    v3 = vfmaq_lane_f16(v3, v30, dv6, 1); \
-    v6 = vfmaq_lane_f16(v6, v30, dv6, 2); \
-    v9 = vfmaq_lane_f16(v9, v30, dv6, 3); \
-    dv7 = vld1_f16(a + 28);               \
-    v31 = vld1q_f16(b + 56);              \
-    v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
-    v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
-    v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
-    v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
-    l += 8;                               \
-    __builtin_prefetch(b + 64, 0, 3);     \
-    __builtin_prefetch(a + 32, 0, 3);     \
-    b += 8 * 8;                           \
-    a += 4 * 8;                           \
-  } while (0)
-
-// 2. Partial sum 128 digits
-#define KERNEL_4x8_ACC4()                 \
-  do {                                    \
-    dv0 = vld1_f16(a);                    \
-    v24 = vld1q_f16(b);                   \
-    v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
-    v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \
-    v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \
-    v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \
-    dv1 = vld1_f16(a + 4);                \
-    v25 = vld1q_f16(b + 8);               \
-    v0 = vfmaq_lane_f16(v0, v25, dv1, 0); \
-    v3 = vfmaq_lane_f16(v3, v25, dv1, 1); \
-    v6 = vfmaq_lane_f16(v6, v25, dv1, 2); \
-    v9 = vfmaq_lane_f16(v9, v25, dv1, 3); \
-    dv2 = vld1_f16(a + 8);                \
-    v26 = vld1q_f16(b + 16);              \
-    v0 = vfmaq_lane_f16(v0, v26, dv2, 0); \
-    v3 = vfmaq_lane_f16(v3, v26, dv2, 1); \
-    v6 = vfmaq_lane_f16(v6, v26, dv2, 2); \
-    v9 = vfmaq_lane_f16(v9, v26, dv2, 3); \
-    dv3 = vld1_f16(a + 12);               \
-    v27 = vld1q_f16(b + 24);              \
-    v0 = vfmaq_lane_f16(v0, v27, dv3, 0); \
-    v3 = vfmaq_lane_f16(v3, v27, dv3, 1); \
-    v6 = vfmaq_lane_f16(v6, v27, dv3, 2); \
-    v9 = vfmaq_lane_f16(v9, v27, dv3, 3); \
-    l += 4;                               \
-    __builtin_prefetch(b + 32, 0, 3);     \
-    __builtin_prefetch(a + 16, 0, 3);     \
-    b += 8 * 4;                           \
-    a += 4 * 4;                           \
-  } while (0)
-
-// 3. Partial sum 32 digits
-#define KERNEL_4x8_ACC1()                 \
-  do {                                    \
-    dv0 = vld1_f16(a);                    \
-    v24 = vld1q_f16(b);                   \
-    v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
-    v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \
-    v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \
-    v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \
-    l += 1;                               \
-    __builtin_prefetch(b + 8, 0, 3);      \
-    __builtin_prefetch(a + 4, 0, 3);      \
-    b += 8 * 1;                           \
-    a += 4 * 1;                           \
-  } while (0)
-
-#define SAVE_KERNEL_4X8_F16_F32()                                             \
-  do {                                                                        \
-    vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0))));    \
-    vst1q_f32(c + ldc,                                                        \
-              vaddq_f32(vld1q_f32(c + ldc), vcvt_f32_f16(vget_low_f16(v3)))); \
-    vst1q_f32(c + 2 * ldc, vaddq_f32(vld1q_f32(c + 2 * ldc),                  \
-                                     vcvt_f32_f16(vget_low_f16(v6))));        \
-    vst1q_f32(c + 3 * ldc, vaddq_f32(vld1q_f32(c + 3 * ldc),                  \
-                                     vcvt_f32_f16(vget_low_f16(v9))));        \
-                                                                              \
-    vst1q_f32(c + 4,                                                          \
-              vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0))));  \
-    vst1q_f32(c + 4 + ldc, vaddq_f32(vld1q_f32(c + 4 + ldc),                  \
-                                     vcvt_f32_f16(vget_high_f16(v3))));       \
-    vst1q_f32(c + 4 + 2 * ldc, vaddq_f32(vld1q_f32(c + 4 + 2 * ldc),          \
-                                         vcvt_f32_f16(vget_high_f16(v6))));   \
-    vst1q_f32(c + 4 + 3 * ldc, vaddq_f32(vld1q_f32(c + 4 + 3 * ldc),          \
-                                         vcvt_f32_f16(vget_high_f16(v9))));   \
-  } while (0)
-
-/**
- * @brief hgemm 4x8 kernel sc = sa * sb
- *
- * @param m length of the row of matrix A
- * @param n length of the col of matrix B
- * @param k length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading-dimension of matrix C
- */
-void hgemm_kernel_4x8(unsigned int M, unsigned int N, unsigned int K,
-                      __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
-  assert(M > 0 && N > 0 && K > 0);
-  assert(M % 4 == 0 && N % 8 == 0);
-
-  __fp16 *a = sa, *b = sb, *c = sc;
-  unsigned int K8 = (K >> 3) << 3;
-  unsigned int i, j, l;
-  for (i = 0; i < M; i += 4) {
-    for (j = 0; j < N; j += 8) {
-      __builtin_prefetch(b, 0, 3);
-      __builtin_prefetch(a, 0, 3);
-      float16x8_t v0, v3, v6, v9;
-      float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
-      float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
-      INIT_KERNEL_4X8();
-      l = 0;
-      for (; l < K8;) {
-        KERNEL_4x8_ACC8();
-      }
-      for (; l < K;) {
-        KERNEL_4x8_ACC1();
-      }
-      vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0));
-      vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v3));
-      vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v6));
-      vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v9));
-      c += 8;
-      a -= 4 * K;
-    }
-    sc += ldc * 4;
-    c = sc;
-    a += 4 * K;
-    b = sb;
-  }
-}
-
-/**
- * @brief hgemm 4x8 kernel sc = sa * sb
- *
- * @param m length of the row of matrix A
- * @param n length of the col of matrix B
- * @param k length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading-dimension of matrix C
- */
-void hgemm_kernel_4x8(unsigned int M, unsigned int N, unsigned int K,
-                      __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
-  assert(M > 0 && N > 0 && K > 0);
-  assert(M % 4 == 0 && N % 8 == 0);
-
-  __fp16 *a = sa, *b = sb;
-  float *c = sc;
-  unsigned int K16 = (K >> 4) << 4;
-  unsigned int K8 = (K >> 3) << 3;
-  unsigned int K4 = (K >> 2) << 2;
-  unsigned int i, j, l;
-  for (i = 0; i < M; i += 4) {
-    for (j = 0; j < N; j += 8) {
-      __builtin_prefetch(b, 0, 3);
-      __builtin_prefetch(a, 0, 3);
-      float16x8_t v0, v3, v6, v9;
-      float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
-      float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
-      l = 0;
-      for (; l < K16;) {
-        INIT_KERNEL_4X8();
-        KERNEL_4x8_ACC16();
-        SAVE_KERNEL_4X8_F16_F32();
-      }
-      for (; l < K8;) {
-        INIT_KERNEL_4X8();
-        KERNEL_4x8_ACC8();
-        SAVE_KERNEL_4X8_F16_F32();
-      }
-      for (; l < K4;) {
-        INIT_KERNEL_4X8();
-        KERNEL_4x8_ACC4();
-        SAVE_KERNEL_4X8_F16_F32();
-      }
-      for (; l < K;) {
-        INIT_KERNEL_4X8();
-        KERNEL_4x8_ACC1();
-        SAVE_KERNEL_4X8_F16_F32();
-      }
-      c += 8;
-      a -= 4 * K;
-    }
-    sc += ldc * 4;
-    c = sc;
-    a += 4 * K;
-    b = sb;
-  }
-}
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_8x16.h b/nntrainer/tensor/hgemm/hgemm_kernel_8x16.h
deleted file mode 100644 (file)
index d29cbfc..0000000
+++ /dev/null
@@ -1,862 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-/**
- * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
- *
- * @file   hgemm_kernel_8x16.h
- * @date   04 April 2024
- * @see    https://github.com/nnstreamer/nntrainer
- * @author Sungsik Kong <ss.kong@samsung.com>
- * @bug    No known bugs except for NYI items
- * @brief  This is half-precision GEMM 8x16 kernel
- *
- */
-
-#include <arm_neon.h>
-#include <assert.h>
-#include <iostream>
-#include <stdlib.h>
-
-#define INIT_KERNEL_8X16()       \
-  do {                           \
-    v0_7 = vdupq_n_f16(0.F);     \
-    v8_15 = vdupq_n_f16(0.F);    \
-    v16_23 = vdupq_n_f16(0.F);   \
-    v24_31 = vdupq_n_f16(0.F);   \
-    v32_39 = vdupq_n_f16(0.F);   \
-    v40_47 = vdupq_n_f16(0.F);   \
-    v48_55 = vdupq_n_f16(0.F);   \
-    v56_63 = vdupq_n_f16(0.F);   \
-    v64_71 = vdupq_n_f16(0.F);   \
-    v72_79 = vdupq_n_f16(0.F);   \
-    v80_87 = vdupq_n_f16(0.F);   \
-    v88_95 = vdupq_n_f16(0.F);   \
-    v96_103 = vdupq_n_f16(0.F);  \
-    v104_111 = vdupq_n_f16(0.F); \
-    v112_119 = vdupq_n_f16(0.F); \
-    v120_127 = vdupq_n_f16(0.F); \
-  } while (0)
-
-// 1. Partial sum 2048 digits
-#define KERNEL_8x16_ACC16()                            \
-  do {                                                 \
-    va0 = vld1q_f16(a + 8 * 0);                        \
-    vb1 = vld1q_f16(b + 8 * 0);                        \
-    vb2 = vld1q_f16(b + 8 * 1);                        \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 1);                        \
-    vb1 = vld1q_f16(b + 8 * 2);                        \
-    vb2 = vld1q_f16(b + 8 * 3);                        \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 2);                        \
-    vb1 = vld1q_f16(b + 8 * 4);                        \
-    vb2 = vld1q_f16(b + 8 * 5);                        \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 3);                        \
-    vb1 = vld1q_f16(b + 8 * 6);                        \
-    vb2 = vld1q_f16(b + 8 * 7);                        \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 4);                        \
-    vb1 = vld1q_f16(b + 8 * 8);                        \
-    vb2 = vld1q_f16(b + 8 * 9);                        \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 5);                        \
-    vb1 = vld1q_f16(b + 8 * 10);                       \
-    vb2 = vld1q_f16(b + 8 * 11);                       \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 6);                        \
-    vb1 = vld1q_f16(b + 8 * 12);                       \
-    vb2 = vld1q_f16(b + 8 * 13);                       \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 7);                        \
-    vb1 = vld1q_f16(b + 8 * 14);                       \
-    vb2 = vld1q_f16(b + 8 * 15);                       \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 8);                        \
-    vb1 = vld1q_f16(b + 8 * 16);                       \
-    vb2 = vld1q_f16(b + 8 * 17);                       \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 9);                        \
-    vb1 = vld1q_f16(b + 8 * 18);                       \
-    vb2 = vld1q_f16(b + 8 * 19);                       \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 10);                       \
-    vb1 = vld1q_f16(b + 8 * 20);                       \
-    vb2 = vld1q_f16(b + 8 * 21);                       \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 11);                       \
-    vb1 = vld1q_f16(b + 8 * 22);                       \
-    vb2 = vld1q_f16(b + 8 * 23);                       \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 12);                       \
-    vb1 = vld1q_f16(b + 8 * 24);                       \
-    vb2 = vld1q_f16(b + 8 * 25);                       \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 13);                       \
-    vb1 = vld1q_f16(b + 8 * 26);                       \
-    vb2 = vld1q_f16(b + 8 * 27);                       \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 14);                       \
-    vb1 = vld1q_f16(b + 8 * 28);                       \
-    vb2 = vld1q_f16(b + 8 * 29);                       \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 15);                       \
-    vb1 = vld1q_f16(b + 8 * 30);                       \
-    vb2 = vld1q_f16(b + 8 * 31);                       \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    __builtin_prefetch(b + 256, 0, 3);                 \
-    __builtin_prefetch(a + 128, 0, 3);                 \
-    l += 16;                                           \
-    b += 16 * 16;                                      \
-    a += 8 * 16;                                       \
-  } while (0)
-
-// 2. Partial sum 1024 digits
-#define KERNEL_8x16_ACC8()                             \
-  do {                                                 \
-    va0 = vld1q_f16(a);                                \
-    vb1 = vld1q_f16(b);                                \
-    vb2 = vld1q_f16(b + 8);                            \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 8);                            \
-    vb1 = vld1q_f16(b + 16);                           \
-    vb2 = vld1q_f16(b + 24);                           \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 16);                           \
-    vb1 = vld1q_f16(b + 32);                           \
-    vb2 = vld1q_f16(b + 40);                           \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 24);                           \
-    vb1 = vld1q_f16(b + 48);                           \
-    vb2 = vld1q_f16(b + 56);                           \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 32);                           \
-    vb1 = vld1q_f16(b + 64);                           \
-    vb2 = vld1q_f16(b + 72);                           \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 40);                           \
-    vb1 = vld1q_f16(b + 80);                           \
-    vb2 = vld1q_f16(b + 88);                           \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 48);                           \
-    vb1 = vld1q_f16(b + 96);                           \
-    vb2 = vld1q_f16(b + 104);                          \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 56);                           \
-    vb1 = vld1q_f16(b + 112);                          \
-    vb2 = vld1q_f16(b + 120);                          \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    l += 8;                                            \
-    __builtin_prefetch(b + 128, 0, 3);                 \
-    __builtin_prefetch(a + 64, 0, 3);                  \
-    b += 16 * 8;                                       \
-    a += 8 * 8;                                        \
-  } while (0)
-
-// 3. Partial sum 512 digits
-#define KERNEL_8x16_ACC4()                             \
-  do {                                                 \
-    va0 = vld1q_f16(a);                                \
-    vb1 = vld1q_f16(b);                                \
-    vb2 = vld1q_f16(b + 8);                            \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 8);                            \
-    vb1 = vld1q_f16(b + 16);                           \
-    vb2 = vld1q_f16(b + 24);                           \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 16);                           \
-    vb1 = vld1q_f16(b + 32);                           \
-    vb2 = vld1q_f16(b + 40);                           \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    va0 = vld1q_f16(a + 24);                           \
-    vb1 = vld1q_f16(b + 48);                           \
-    vb2 = vld1q_f16(b + 56);                           \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    l += 4;                                            \
-    __builtin_prefetch(b + 64, 0, 3);                  \
-    __builtin_prefetch(a + 32, 0, 3);                  \
-    b += 16 * 4;                                       \
-    a += 8 * 4;                                        \
-  } while (0)
-
-// 4. Partial sum 128 digits
-#define KERNEL_8x16_ACC1()                             \
-  do {                                                 \
-    va0 = vld1q_f16(a);                                \
-    vb1 = vld1q_f16(b);                                \
-    vb2 = vld1q_f16(b + 8);                            \
-    v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
-    v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
-    v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
-    v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
-    v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
-    v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
-    v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
-    v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
-    v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
-    v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
-    v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
-    v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
-    v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
-    v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
-    v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
-    v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
-    l += 1;                                            \
-    __builtin_prefetch(b + 16, 0, 3);                  \
-    __builtin_prefetch(a + 8, 0, 3);                   \
-    b += 16 * 1;                                       \
-    a += 8 * 1;                                        \
-  } while (0)
-
-#define SAVE_KERNEL_8X16_F16_F32()                                             \
-  do {                                                                         \
-    vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0_7))));   \
-    vst1q_f32(c + 4,                                                           \
-              vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0_7)))); \
-                                                                               \
-    vst1q_f32(                                                                 \
-      c + 8, vaddq_f32(vld1q_f32(c + 8), vcvt_f32_f16(vget_low_f16(v64_71)))); \
-    vst1q_f32(c + 8 + 4, vaddq_f32(vld1q_f32(c + 8 + 4),                       \
-                                   vcvt_f32_f16(vget_high_f16(v64_71))));      \
-                                                                               \
-    vst1q_f32(c + ldc, vaddq_f32(vld1q_f32(c + ldc),                           \
-                                 vcvt_f32_f16(vget_low_f16(v8_15))));          \
-    vst1q_f32(c + ldc + 4, vaddq_f32(vld1q_f32(c + ldc + 4),                   \
-                                     vcvt_f32_f16(vget_high_f16(v8_15))));     \
-                                                                               \
-    vst1q_f32(c + ldc + 8, vaddq_f32(vld1q_f32(c + ldc + 8),                   \
-                                     vcvt_f32_f16(vget_low_f16(v72_79))));     \
-    vst1q_f32(c + ldc + 8 + 4,                                                 \
-              vaddq_f32(vld1q_f32(c + ldc + 8 + 4),                            \
-                        vcvt_f32_f16(vget_high_f16(v72_79))));                 \
-                                                                               \
-    vst1q_f32(c + 2 * ldc, vaddq_f32(vld1q_f32(c + 2 * ldc),                   \
-                                     vcvt_f32_f16(vget_low_f16(v16_23))));     \
-    vst1q_f32(c + 2 * ldc + 4,                                                 \
-              vaddq_f32(vld1q_f32(c + 2 * ldc + 4),                            \
-                        vcvt_f32_f16(vget_high_f16(v16_23))));                 \
-                                                                               \
-    vst1q_f32(c + 2 * ldc + 8, vaddq_f32(vld1q_f32(c + 2 * ldc + 8),           \
-                                         vcvt_f32_f16(vget_low_f16(v80_87)))); \
-    vst1q_f32(c + 2 * ldc + 8 + 4,                                             \
-              vaddq_f32(vld1q_f32(c + 2 * ldc + 8 + 4),                        \
-                        vcvt_f32_f16(vget_high_f16(v80_87))));                 \
-                                                                               \
-    vst1q_f32(c + 3 * ldc, vaddq_f32(vld1q_f32(c + 3 * ldc),                   \
-                                     vcvt_f32_f16(vget_low_f16(v24_31))));     \
-    vst1q_f32(c + 3 * ldc + 4,                                                 \
-              vaddq_f32(vld1q_f32(c + 3 * ldc + 4),                            \
-                        vcvt_f32_f16(vget_high_f16(v24_31))));                 \
-                                                                               \
-    vst1q_f32(c + 3 * ldc + 8, vaddq_f32(vld1q_f32(c + 3 * ldc + 8),           \
-                                         vcvt_f32_f16(vget_low_f16(v88_95)))); \
-    vst1q_f32(c + 3 * ldc + 8 + 4,                                             \
-              vaddq_f32(vld1q_f32(c + 3 * ldc + 8 + 4),                        \
-                        vcvt_f32_f16(vget_high_f16(v88_95))));                 \
-                                                                               \
-    vst1q_f32(c + 4 * ldc, vaddq_f32(vld1q_f32(c + 4 * ldc),                   \
-                                     vcvt_f32_f16(vget_low_f16(v32_39))));     \
-    vst1q_f32(c + 4 * ldc + 4,                                                 \
-              vaddq_f32(vld1q_f32(c + 4 * ldc + 4),                            \
-                        vcvt_f32_f16(vget_high_f16(v32_39))));                 \
-                                                                               \
-    vst1q_f32(c + 4 * ldc + 8,                                                 \
-              vaddq_f32(vld1q_f32(c + 4 * ldc + 8),                            \
-                        vcvt_f32_f16(vget_low_f16(v96_103))));                 \
-    vst1q_f32(c + 4 * ldc + 8 + 4,                                             \
-              vaddq_f32(vld1q_f32(c + 4 * ldc + 8 + 4),                        \
-                        vcvt_f32_f16(vget_high_f16(v96_103))));                \
-                                                                               \
-    vst1q_f32(c + 5 * ldc, vaddq_f32(vld1q_f32(c + 5 * ldc),                   \
-                                     vcvt_f32_f16(vget_low_f16(v40_47))));     \
-    vst1q_f32(c + 5 * ldc + 4,                                                 \
-              vaddq_f32(vld1q_f32(c + 5 * ldc + 4),                            \
-                        vcvt_f32_f16(vget_high_f16(v40_47))));                 \
-    vst1q_f32(c + 5 * ldc + 8,                                                 \
-              vaddq_f32(vld1q_f32(c + 5 * ldc + 8),                            \
-                        vcvt_f32_f16(vget_low_f16(v104_111))));                \
-    vst1q_f32(c + 5 * ldc + 8 + 4,                                             \
-              vaddq_f32(vld1q_f32(c + 5 * ldc + 8 + 4),                        \
-                        vcvt_f32_f16(vget_high_f16(v104_111))));               \
-                                                                               \
-    vst1q_f32(c + 6 * ldc, vaddq_f32(vld1q_f32(c + 6 * ldc),                   \
-                                     vcvt_f32_f16(vget_low_f16(v48_55))));     \
-    vst1q_f32(c + 6 * ldc + 4,                                                 \
-              vaddq_f32(vld1q_f32(c + 6 * ldc + 4),                            \
-                        vcvt_f32_f16(vget_high_f16(v48_55))));                 \
-                                                                               \
-    vst1q_f32(c + 6 * ldc + 8,                                                 \
-              vaddq_f32(vld1q_f32(c + 6 * ldc + 8),                            \
-                        vcvt_f32_f16(vget_low_f16(v112_119))));                \
-    vst1q_f32(c + 6 * ldc + 8 + 4,                                             \
-              vaddq_f32(vld1q_f32(c + 6 * ldc + 8 + 4),                        \
-                        vcvt_f32_f16(vget_high_f16(v112_119))));               \
-                                                                               \
-    vst1q_f32(c + 7 * ldc, vaddq_f32(vld1q_f32(c + 7 * ldc),                   \
-                                     vcvt_f32_f16(vget_low_f16(v56_63))));     \
-    vst1q_f32(c + 7 * ldc + 4,                                                 \
-              vaddq_f32(vld1q_f32(c + 7 * ldc + 4),                            \
-                        vcvt_f32_f16(vget_high_f16(v56_63))));                 \
-                                                                               \
-    vst1q_f32(c + 7 * ldc + 8,                                                 \
-              vaddq_f32(vld1q_f32(c + 7 * ldc + 8),                            \
-                        vcvt_f32_f16(vget_low_f16(v120_127))));                \
-    vst1q_f32(c + 7 * ldc + 8 + 4,                                             \
-              vaddq_f32(vld1q_f32(c + 7 * ldc + 8 + 4),                        \
-                        vcvt_f32_f16(vget_high_f16(v120_127))));               \
-  } while (0)
-
-/**
- * @brief hgemm 8x16 kernel sc = sa * sb
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading-dimension of matrix C
- */
-void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
-                       __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
-  assert(M > 0 && N > 0 && K > 0);
-  assert(M % 8 == 0 && N % 16 == 0 && K % 8 == 0);
-
-  __fp16 *a = sa, *b = sb, *c = sc;
-  unsigned int i, j, l;
-  for (i = 0; i < M; i += 8) {
-    for (j = 0; j < N; j += 16) {
-      __builtin_prefetch(b, 0, 3);
-      __builtin_prefetch(a, 0, 3);
-      // 8x16
-      float16x8_t v0_7, v8_15;
-      float16x8_t v16_23, v24_31;
-      float16x8_t v32_39, v40_47;
-      float16x8_t v48_55, v56_63;
-      float16x8_t v64_71, v72_79;
-      float16x8_t v80_87, v88_95;
-      float16x8_t v96_103, v104_111;
-      float16x8_t v112_119, v120_127;
-      float16x8_t vb1, vb2;
-      float16x8_t va0;
-
-      INIT_KERNEL_8X16();
-      l = 0;
-      for (; l < K;) {
-        KERNEL_8x16_ACC1();
-      }
-      vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0_7));
-      vst1q_f16(c + 8, vaddq_f16(vld1q_f16(c + 8), v64_71));
-      vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v8_15));
-      vst1q_f16(c + ldc + 8, vaddq_f16(vld1q_f16(c + ldc + 8), v72_79));
-      vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v16_23));
-      vst1q_f16(c + 2 * ldc + 8, vaddq_f16(vld1q_f16(c + 2 * ldc + 8), v80_87));
-      vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v24_31));
-      vst1q_f16(c + 3 * ldc + 8, vaddq_f16(vld1q_f16(c + 3 * ldc + 8), v88_95));
-      vst1q_f16(c + 4 * ldc, vaddq_f16(vld1q_f16(c + 4 * ldc), v32_39));
-      vst1q_f16(c + 4 * ldc + 8,
-                vaddq_f16(vld1q_f16(c + 4 * ldc + 8), v96_103));
-      vst1q_f16(c + 5 * ldc, vaddq_f16(vld1q_f16(c + 5 * ldc), v40_47));
-      vst1q_f16(c + 5 * ldc + 8,
-                vaddq_f16(vld1q_f16(c + 5 * ldc + 8), v104_111));
-      vst1q_f16(c + 6 * ldc, vaddq_f16(vld1q_f16(c + 6 * ldc), v48_55));
-      vst1q_f16(c + 6 * ldc + 8,
-                vaddq_f16(vld1q_f16(c + 6 * ldc + 8), v112_119));
-      vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v56_63));
-      vst1q_f16(c + 7 * ldc + 8,
-                vaddq_f16(vld1q_f16(c + 7 * ldc + 8), v120_127));
-      c += 16;
-      a -= 8 * K;
-    }
-    sc += ldc * 8;
-    c = sc;
-    a += 8 * K;
-    b = sb;
-  }
-}
-
-/**
- * @brief hgemm 8x16 kernel sc = sa * sb
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading-dimension of matrix C
- */
-void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
-                       __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
-  assert(M > 0 && N > 0 && K > 0);
-  assert(M % 8 == 0 && N % 16 == 0 && K % 4 == 0);
-
-  __fp16 *a = sa, *b = sb;
-  float *c = sc;
-  unsigned int i, j, l;
-  unsigned int K4 = (K >> 2) << 2;
-  unsigned int K8 = (K >> 3) << 3;
-  unsigned int K16 = (K >> 4) << 4;
-  for (i = 0; i < M; i += 8) {
-    for (j = 0; j < N; j += 16) {
-      __builtin_prefetch(b, 0, 3);
-      __builtin_prefetch(a, 0, 3);
-      float16x8_t v0_7, v8_15;
-      float16x8_t v16_23, v24_31;
-      float16x8_t v32_39, v40_47;
-      float16x8_t v48_55, v56_63;
-      float16x8_t v64_71, v72_79;
-      float16x8_t v80_87, v88_95;
-      float16x8_t v96_103, v104_111;
-      float16x8_t v112_119, v120_127;
-      float16x8_t vb1, vb2;
-      float16x8_t va0;
-      l = 0;
-      for (; l < K16;) {
-        INIT_KERNEL_8X16();
-        KERNEL_8x16_ACC16();
-        SAVE_KERNEL_8X16_F16_F32();
-      }
-      for (; l < K8;) {
-        INIT_KERNEL_8X16();
-        KERNEL_8x16_ACC8();
-        SAVE_KERNEL_8X16_F16_F32();
-      }
-      for (; l < K4;) {
-        INIT_KERNEL_8X16();
-        KERNEL_8x16_ACC4();
-        SAVE_KERNEL_8X16_F16_F32();
-      }
-      for (; l < K;) {
-        INIT_KERNEL_8X16();
-        KERNEL_8x16_ACC1();
-        SAVE_KERNEL_8X16_F16_F32();
-      }
-      c += 16;
-      a -= 8 * K;
-    }
-    sc += ldc * 8;
-    c = sc;
-    a += 8 * K;
-    b = sb;
-  }
-}
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_8x8.h b/nntrainer/tensor/hgemm/hgemm_kernel_8x8.h
deleted file mode 100644 (file)
index 2e3eb6a..0000000
+++ /dev/null
@@ -1,511 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-/**
- * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
- *
- * @file   hgemm_kernel_8x8.h
- * @date   01 April 2024
- * @see    https://github.com/nnstreamer/nntrainer
- * @author Sungsik Kong <ss.kong@samsung.com>
- * @bug    No known bugs except for NYI items
- * @brief  This is half-precision GEMM 8x8 kernel
- *
- */
-
-#include <arm_neon.h>
-#include <assert.h>
-#include <stdlib.h>
-
-#define INIT_KERNEL_8x8()   \
-  do {                      \
-    v24 = vdupq_n_f16(0.F); \
-    v25 = vdupq_n_f16(0.F); \
-    v26 = vdupq_n_f16(0.F); \
-    v27 = vdupq_n_f16(0.F); \
-    v28 = vdupq_n_f16(0.F); \
-    v29 = vdupq_n_f16(0.F); \
-    v30 = vdupq_n_f16(0.F); \
-    v31 = vdupq_n_f16(0.F); \
-  } while (0)
-
-// 1. Partial sum 1024 digits
-#define KERNEL_8x8_ACC16()                   \
-  do {                                       \
-    va0 = vld1q_f16(a);                      \
-    v16 = vld1q_f16(b);                      \
-    v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
-    v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \
-    v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \
-    v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \
-    v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \
-    v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \
-    v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \
-    v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \
-    va0 = vld1q_f16(a + 8);                  \
-    v17 = vld1q_f16(b + 8);                  \
-    v24 = vfmaq_laneq_f16(v24, v17, va0, 0); \
-    v25 = vfmaq_laneq_f16(v25, v17, va0, 1); \
-    v26 = vfmaq_laneq_f16(v26, v17, va0, 2); \
-    v27 = vfmaq_laneq_f16(v27, v17, va0, 3); \
-    v28 = vfmaq_laneq_f16(v28, v17, va0, 4); \
-    v29 = vfmaq_laneq_f16(v29, v17, va0, 5); \
-    v30 = vfmaq_laneq_f16(v30, v17, va0, 6); \
-    v31 = vfmaq_laneq_f16(v31, v17, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 2);              \
-    v18 = vld1q_f16(b + 8 * 2);              \
-    v24 = vfmaq_laneq_f16(v24, v18, va0, 0); \
-    v25 = vfmaq_laneq_f16(v25, v18, va0, 1); \
-    v26 = vfmaq_laneq_f16(v26, v18, va0, 2); \
-    v27 = vfmaq_laneq_f16(v27, v18, va0, 3); \
-    v28 = vfmaq_laneq_f16(v28, v18, va0, 4); \
-    v29 = vfmaq_laneq_f16(v29, v18, va0, 5); \
-    v30 = vfmaq_laneq_f16(v30, v18, va0, 6); \
-    v31 = vfmaq_laneq_f16(v31, v18, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 3);              \
-    v19 = vld1q_f16(b + 8 * 3);              \
-    v24 = vfmaq_laneq_f16(v24, v19, va0, 0); \
-    v25 = vfmaq_laneq_f16(v25, v19, va0, 1); \
-    v26 = vfmaq_laneq_f16(v26, v19, va0, 2); \
-    v27 = vfmaq_laneq_f16(v27, v19, va0, 3); \
-    v28 = vfmaq_laneq_f16(v28, v19, va0, 4); \
-    v29 = vfmaq_laneq_f16(v29, v19, va0, 5); \
-    v30 = vfmaq_laneq_f16(v30, v19, va0, 6); \
-    v31 = vfmaq_laneq_f16(v31, v19, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 4);              \
-    v20 = vld1q_f16(b + 8 * 4);              \
-    v24 = vfmaq_laneq_f16(v24, v20, va0, 0); \
-    v25 = vfmaq_laneq_f16(v25, v20, va0, 1); \
-    v26 = vfmaq_laneq_f16(v26, v20, va0, 2); \
-    v27 = vfmaq_laneq_f16(v27, v20, va0, 3); \
-    v28 = vfmaq_laneq_f16(v28, v20, va0, 4); \
-    v29 = vfmaq_laneq_f16(v29, v20, va0, 5); \
-    v30 = vfmaq_laneq_f16(v30, v20, va0, 6); \
-    v31 = vfmaq_laneq_f16(v31, v20, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 5);              \
-    v21 = vld1q_f16(b + 8 * 5);              \
-    v24 = vfmaq_laneq_f16(v24, v21, va0, 0); \
-    v25 = vfmaq_laneq_f16(v25, v21, va0, 1); \
-    v26 = vfmaq_laneq_f16(v26, v21, va0, 2); \
-    v27 = vfmaq_laneq_f16(v27, v21, va0, 3); \
-    v28 = vfmaq_laneq_f16(v28, v21, va0, 4); \
-    v29 = vfmaq_laneq_f16(v29, v21, va0, 5); \
-    v30 = vfmaq_laneq_f16(v30, v21, va0, 6); \
-    v31 = vfmaq_laneq_f16(v31, v21, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 6);              \
-    v22 = vld1q_f16(b + 8 * 6);              \
-    v24 = vfmaq_laneq_f16(v24, v22, va0, 0); \
-    v25 = vfmaq_laneq_f16(v25, v22, va0, 1); \
-    v26 = vfmaq_laneq_f16(v26, v22, va0, 2); \
-    v27 = vfmaq_laneq_f16(v27, v22, va0, 3); \
-    v28 = vfmaq_laneq_f16(v28, v22, va0, 4); \
-    v29 = vfmaq_laneq_f16(v29, v22, va0, 5); \
-    v30 = vfmaq_laneq_f16(v30, v22, va0, 6); \
-    v31 = vfmaq_laneq_f16(v31, v22, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 7);              \
-    v23 = vld1q_f16(b + 8 * 7);              \
-    v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
-    v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
-    v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
-    v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
-    v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
-    v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
-    v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
-    v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 8);              \
-    v23 = vld1q_f16(b + 8 * 8);              \
-    v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
-    v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
-    v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
-    v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
-    v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
-    v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
-    v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
-    v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 9);              \
-    v23 = vld1q_f16(b + 8 * 9);              \
-    v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
-    v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
-    v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
-    v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
-    v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
-    v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
-    v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
-    v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 10);             \
-    v23 = vld1q_f16(b + 8 * 10);             \
-    v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
-    v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
-    v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
-    v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
-    v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
-    v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
-    v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
-    v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 11);             \
-    v23 = vld1q_f16(b + 8 * 11);             \
-    v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
-    v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
-    v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
-    v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
-    v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
-    v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
-    v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
-    v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 12);             \
-    v23 = vld1q_f16(b + 8 * 12);             \
-    v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
-    v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
-    v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
-    v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
-    v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
-    v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
-    v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
-    v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 13);             \
-    v23 = vld1q_f16(b + 8 * 13);             \
-    v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
-    v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
-    v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
-    v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
-    v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
-    v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
-    v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
-    v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 14);             \
-    v23 = vld1q_f16(b + 8 * 14);             \
-    v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
-    v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
-    v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
-    v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
-    v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
-    v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
-    v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
-    v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
-    va0 = vld1q_f16(a + 8 * 15);             \
-    v23 = vld1q_f16(b + 8 * 15);             \
-    v24 = vfmaq_laneq_f16(v24, v23, va0, 0); \
-    v25 = vfmaq_laneq_f16(v25, v23, va0, 1); \
-    v26 = vfmaq_laneq_f16(v26, v23, va0, 2); \
-    v27 = vfmaq_laneq_f16(v27, v23, va0, 3); \
-    v28 = vfmaq_laneq_f16(v28, v23, va0, 4); \
-    v29 = vfmaq_laneq_f16(v29, v23, va0, 5); \
-    v30 = vfmaq_laneq_f16(v30, v23, va0, 6); \
-    v31 = vfmaq_laneq_f16(v31, v23, va0, 7); \
-    __builtin_prefetch(b + 128, 0, 3);       \
-    __builtin_prefetch(a + 128, 0, 3);       \
-    l += 16;                                 \
-    b += 8 * 16;                             \
-    a += 8 * 16;                             \
-  } while (0)
-
-// 2. Partial sum 512 digits
-#define KERNEL_8x8_ACC8()                    \
-  do {                                       \
-    va0 = vld1q_f16(a);                      \
-    v16 = vld1q_f16(b);                      \
-    v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
-    v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \
-    v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \
-    v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \
-    v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \
-    v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \
-    v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \
-    v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \
-    va1 = vld1q_f16(a + 8);                  \
-    v17 = vld1q_f16(b + 8);                  \
-    v24 = vfmaq_laneq_f16(v24, v17, va1, 0); \
-    v25 = vfmaq_laneq_f16(v25, v17, va1, 1); \
-    v26 = vfmaq_laneq_f16(v26, v17, va1, 2); \
-    v27 = vfmaq_laneq_f16(v27, v17, va1, 3); \
-    v28 = vfmaq_laneq_f16(v28, v17, va1, 4); \
-    v29 = vfmaq_laneq_f16(v29, v17, va1, 5); \
-    v30 = vfmaq_laneq_f16(v30, v17, va1, 6); \
-    v31 = vfmaq_laneq_f16(v31, v17, va1, 7); \
-    va2 = vld1q_f16(a + 16);                 \
-    v18 = vld1q_f16(b + 16);                 \
-    v24 = vfmaq_laneq_f16(v24, v18, va2, 0); \
-    v25 = vfmaq_laneq_f16(v25, v18, va2, 1); \
-    v26 = vfmaq_laneq_f16(v26, v18, va2, 2); \
-    v27 = vfmaq_laneq_f16(v27, v18, va2, 3); \
-    v28 = vfmaq_laneq_f16(v28, v18, va2, 4); \
-    v29 = vfmaq_laneq_f16(v29, v18, va2, 5); \
-    v30 = vfmaq_laneq_f16(v30, v18, va2, 6); \
-    v31 = vfmaq_laneq_f16(v31, v18, va2, 7); \
-    va3 = vld1q_f16(a + 24);                 \
-    v19 = vld1q_f16(b + 24);                 \
-    v24 = vfmaq_laneq_f16(v24, v19, va3, 0); \
-    v25 = vfmaq_laneq_f16(v25, v19, va3, 1); \
-    v26 = vfmaq_laneq_f16(v26, v19, va3, 2); \
-    v27 = vfmaq_laneq_f16(v27, v19, va3, 3); \
-    v28 = vfmaq_laneq_f16(v28, v19, va3, 4); \
-    v29 = vfmaq_laneq_f16(v29, v19, va3, 5); \
-    v30 = vfmaq_laneq_f16(v30, v19, va3, 6); \
-    v31 = vfmaq_laneq_f16(v31, v19, va3, 7); \
-    va4 = vld1q_f16(a + 32);                 \
-    v20 = vld1q_f16(b + 32);                 \
-    v24 = vfmaq_laneq_f16(v24, v20, va4, 0); \
-    v25 = vfmaq_laneq_f16(v25, v20, va4, 1); \
-    v26 = vfmaq_laneq_f16(v26, v20, va4, 2); \
-    v27 = vfmaq_laneq_f16(v27, v20, va4, 3); \
-    v28 = vfmaq_laneq_f16(v28, v20, va4, 4); \
-    v29 = vfmaq_laneq_f16(v29, v20, va4, 5); \
-    v30 = vfmaq_laneq_f16(v30, v20, va4, 6); \
-    v31 = vfmaq_laneq_f16(v31, v20, va4, 7); \
-    va5 = vld1q_f16(a + 40);                 \
-    v21 = vld1q_f16(b + 40);                 \
-    v24 = vfmaq_laneq_f16(v24, v21, va5, 0); \
-    v25 = vfmaq_laneq_f16(v25, v21, va5, 1); \
-    v26 = vfmaq_laneq_f16(v26, v21, va5, 2); \
-    v27 = vfmaq_laneq_f16(v27, v21, va5, 3); \
-    v28 = vfmaq_laneq_f16(v28, v21, va5, 4); \
-    v29 = vfmaq_laneq_f16(v29, v21, va5, 5); \
-    v30 = vfmaq_laneq_f16(v30, v21, va5, 6); \
-    v31 = vfmaq_laneq_f16(v31, v21, va5, 7); \
-    va6 = vld1q_f16(a + 48);                 \
-    v22 = vld1q_f16(b + 48);                 \
-    v24 = vfmaq_laneq_f16(v24, v22, va6, 0); \
-    v25 = vfmaq_laneq_f16(v25, v22, va6, 1); \
-    v26 = vfmaq_laneq_f16(v26, v22, va6, 2); \
-    v27 = vfmaq_laneq_f16(v27, v22, va6, 3); \
-    v28 = vfmaq_laneq_f16(v28, v22, va6, 4); \
-    v29 = vfmaq_laneq_f16(v29, v22, va6, 5); \
-    v30 = vfmaq_laneq_f16(v30, v22, va6, 6); \
-    v31 = vfmaq_laneq_f16(v31, v22, va6, 7); \
-    va7 = vld1q_f16(a + 56);                 \
-    v23 = vld1q_f16(b + 56);                 \
-    v24 = vfmaq_laneq_f16(v24, v23, va7, 0); \
-    v25 = vfmaq_laneq_f16(v25, v23, va7, 1); \
-    v26 = vfmaq_laneq_f16(v26, v23, va7, 2); \
-    v27 = vfmaq_laneq_f16(v27, v23, va7, 3); \
-    v28 = vfmaq_laneq_f16(v28, v23, va7, 4); \
-    v29 = vfmaq_laneq_f16(v29, v23, va7, 5); \
-    v30 = vfmaq_laneq_f16(v30, v23, va7, 6); \
-    v31 = vfmaq_laneq_f16(v31, v23, va7, 7); \
-    __builtin_prefetch(b + 64, 0, 3);        \
-    __builtin_prefetch(a + 64, 0, 3);        \
-    l += 8;                                  \
-    b += 8 * 8;                              \
-    a += 8 * 8;                              \
-  } while (0)
-
-// 3. Partial sum 256 digits
-#define KERNEL_8x8_ACC4()                    \
-  do {                                       \
-    va0 = vld1q_f16(a);                      \
-    v16 = vld1q_f16(b);                      \
-    v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
-    v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \
-    v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \
-    v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \
-    v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \
-    v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \
-    v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \
-    v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \
-    va1 = vld1q_f16(a + 8);                  \
-    v17 = vld1q_f16(b + 8);                  \
-    v24 = vfmaq_laneq_f16(v24, v17, va1, 0); \
-    v25 = vfmaq_laneq_f16(v25, v17, va1, 1); \
-    v26 = vfmaq_laneq_f16(v26, v17, va1, 2); \
-    v27 = vfmaq_laneq_f16(v27, v17, va1, 3); \
-    v28 = vfmaq_laneq_f16(v28, v17, va1, 4); \
-    v29 = vfmaq_laneq_f16(v29, v17, va1, 5); \
-    v30 = vfmaq_laneq_f16(v30, v17, va1, 6); \
-    v31 = vfmaq_laneq_f16(v31, v17, va1, 7); \
-    va2 = vld1q_f16(a + 16);                 \
-    v18 = vld1q_f16(b + 16);                 \
-    v24 = vfmaq_laneq_f16(v24, v18, va2, 0); \
-    v25 = vfmaq_laneq_f16(v25, v18, va2, 1); \
-    v26 = vfmaq_laneq_f16(v26, v18, va2, 2); \
-    v27 = vfmaq_laneq_f16(v27, v18, va2, 3); \
-    v28 = vfmaq_laneq_f16(v28, v18, va2, 4); \
-    v29 = vfmaq_laneq_f16(v29, v18, va2, 5); \
-    v30 = vfmaq_laneq_f16(v30, v18, va2, 6); \
-    v31 = vfmaq_laneq_f16(v31, v18, va2, 7); \
-    va3 = vld1q_f16(a + 24);                 \
-    v19 = vld1q_f16(b + 24);                 \
-    v24 = vfmaq_laneq_f16(v24, v19, va3, 0); \
-    v25 = vfmaq_laneq_f16(v25, v19, va3, 1); \
-    v26 = vfmaq_laneq_f16(v26, v19, va3, 2); \
-    v27 = vfmaq_laneq_f16(v27, v19, va3, 3); \
-    v28 = vfmaq_laneq_f16(v28, v19, va3, 4); \
-    v29 = vfmaq_laneq_f16(v29, v19, va3, 5); \
-    v30 = vfmaq_laneq_f16(v30, v19, va3, 6); \
-    v31 = vfmaq_laneq_f16(v31, v19, va3, 7); \
-    __builtin_prefetch(b + 32, 0, 3);        \
-    __builtin_prefetch(a + 32, 0, 3);        \
-    l += 4;                                  \
-    b += 8 * 4;                              \
-    a += 8 * 4;                              \
-  } while (0)
-
-// 4. Partial sum 64 digits
-#define KERNEL_8x8_ACC1()                    \
-  do {                                       \
-    va0 = vld1q_f16(a);                      \
-    v16 = vld1q_f16(b);                      \
-    v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
-    v25 = vfmaq_laneq_f16(v25, v16, va0, 1); \
-    v26 = vfmaq_laneq_f16(v26, v16, va0, 2); \
-    v27 = vfmaq_laneq_f16(v27, v16, va0, 3); \
-    v28 = vfmaq_laneq_f16(v28, v16, va0, 4); \
-    v29 = vfmaq_laneq_f16(v29, v16, va0, 5); \
-    v30 = vfmaq_laneq_f16(v30, v16, va0, 6); \
-    v31 = vfmaq_laneq_f16(v31, v16, va0, 7); \
-    __builtin_prefetch(b + 8, 0, 3);         \
-    __builtin_prefetch(a + 8, 0, 3);         \
-    l += 1;                                  \
-    b += 8 * 1;                              \
-    a += 8 * 1;                              \
-  } while (0)
-
-#define SAVE_KERNEL_8X8_F16_f32()                                              \
-  do {                                                                         \
-    vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v24))));    \
-    vst1q_f32(c + 4,                                                           \
-              vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v24))));  \
-                                                                               \
-    vst1q_f32(c + ldc,                                                         \
-              vaddq_f32(vld1q_f32(c + ldc), vcvt_f32_f16(vget_low_f16(v25)))); \
-    vst1q_f32(c + 4 + ldc, vaddq_f32(vld1q_f32(c + 4 + ldc),                   \
-                                     vcvt_f32_f16(vget_high_f16(v25))));       \
-                                                                               \
-    vst1q_f32(c + 2 * ldc, vaddq_f32(vld1q_f32(c + 2 * ldc),                   \
-                                     vcvt_f32_f16(vget_low_f16(v26))));        \
-    vst1q_f32(c + 4 + 2 * ldc, vaddq_f32(vld1q_f32(c + 4 + 2 * ldc),           \
-                                         vcvt_f32_f16(vget_high_f16(v26))));   \
-                                                                               \
-    vst1q_f32(c + 3 * ldc, vaddq_f32(vld1q_f32(c + 3 * ldc),                   \
-                                     vcvt_f32_f16(vget_low_f16(v27))));        \
-    vst1q_f32(c + 4 + 3 * ldc, vaddq_f32(vld1q_f32(c + 4 + 3 * ldc),           \
-                                         vcvt_f32_f16(vget_high_f16(v27))));   \
-                                                                               \
-    vst1q_f32(c + 4 * ldc, vaddq_f32(vld1q_f32(c + 4 * ldc),                   \
-                                     vcvt_f32_f16(vget_low_f16(v28))));        \
-    vst1q_f32(c + 4 + 4 * ldc, vaddq_f32(vld1q_f32(c + 4 + 4 * ldc),           \
-                                         vcvt_f32_f16(vget_high_f16(v28))));   \
-                                                                               \
-    vst1q_f32(c + 5 * ldc, vaddq_f32(vld1q_f32(c + 5 * ldc),                   \
-                                     vcvt_f32_f16(vget_low_f16(v29))));        \
-    vst1q_f32(c + 4 + 5 * ldc, vaddq_f32(vld1q_f32(c + 4 + 5 * ldc),           \
-                                         vcvt_f32_f16(vget_high_f16(v29))));   \
-                                                                               \
-    vst1q_f32(c + 6 * ldc, vaddq_f32(vld1q_f32(c + 6 * ldc),                   \
-                                     vcvt_f32_f16(vget_low_f16(v30))));        \
-    vst1q_f32(c + 4 + 6 * ldc, vaddq_f32(vld1q_f32(c + 4 + 6 * ldc),           \
-                                         vcvt_f32_f16(vget_high_f16(v30))));   \
-                                                                               \
-    vst1q_f32(c + 7 * ldc, vaddq_f32(vld1q_f32(c + 7 * ldc),                   \
-                                     vcvt_f32_f16(vget_low_f16(v31))));        \
-    vst1q_f32(c + 4 + 7 * ldc, vaddq_f32(vld1q_f32(c + 4 + 7 * ldc),           \
-                                         vcvt_f32_f16(vget_high_f16(v31))));   \
-  } while (0)
-
-/**
- * @brief hgemm 8x8 kernel sc = sa * sb
- *
- * @param m length of the row of matrix A
- * @param n length of the col of matrix B
- * @param k length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading-dimension of matrix C
- */
-void hgemm_kernel_8x8(unsigned int M, unsigned int N, unsigned int K,
-                      __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
-  assert(M > 0 && N > 0 && K > 0);
-  assert(M % 8 == 0 && N % 8 == 0 && K % 4 == 0);
-
-  __fp16 *a = sa, *b = sb, *c = sc;
-  unsigned int i, j, l;
-  for (i = 0; i < M; i += 8) {
-    for (j = 0; j < N; j += 8) {
-      __builtin_prefetch(b, 0, 3);
-      __builtin_prefetch(a, 0, 3);
-
-      float16x8_t v16, v17, v18, v19, v20, v21, v22, v23;
-      float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
-      float16x8_t va0, va1, va2, va3, va4, va5, va6, va7;
-      INIT_KERNEL_8x8();
-      l = 0;
-      for (; l < K;) {
-        KERNEL_8x8_ACC1();
-      }
-      vst1q_f16(c, vaddq_f16(vld1q_f16(c), v24));
-      vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v25));
-      vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v26));
-      vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v27));
-      vst1q_f16(c + 4 * ldc, vaddq_f16(vld1q_f16(c + 4 * ldc), v28));
-      vst1q_f16(c + 5 * ldc, vaddq_f16(vld1q_f16(c + 5 * ldc), v29));
-      vst1q_f16(c + 6 * ldc, vaddq_f16(vld1q_f16(c + 6 * ldc), v30));
-      vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v31));
-      c += 8;
-      a -= 8 * K;
-    }
-    sc += ldc * 8;
-    c = sc;
-    a += 8 * K;
-    b = sb;
-  }
-}
-
-/**
- * @brief hgemm 8x8 kernel sc = sa * sb
- *
- * @param m length of the row of matrix A
- * @param n length of the col of matrix B
- * @param k length of the col of matrix A
- * @param sa sub-matrix of input matrix A
- * @param sb sub-matrix of input matrix B
- * @param sc sub-matrix of output matrix C
- * @param ldc leading-dimension of matrix C
- */
-void hgemm_kernel_8x8(unsigned int M, unsigned int N, unsigned int K,
-                      __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
-  assert(M > 0 && N > 0 && K > 0);
-  assert(M % 8 == 0 && N % 8 == 0 && K % 8 == 0);
-
-  __fp16 *a = sa, *b = sb;
-  float *c = sc;
-  unsigned int i, j, l;
-  unsigned int K4 = (K >> 2) << 2;
-  unsigned int K8 = (K >> 3) << 3;
-  unsigned int K16 = (K >> 4) << 4;
-  for (i = 0; i < M; i += 8) {
-    for (j = 0; j < N; j += 8) {
-      __builtin_prefetch(b, 0, 3);
-      __builtin_prefetch(a, 0, 3);
-
-      float16x8_t v16, v17, v18, v19, v20, v21, v22, v23;
-      float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
-      float16x8_t va0, va1, va2, va3, va4, va5, va6, va7;
-      l = 0;
-      for (; l < K16;) {
-        INIT_KERNEL_8x8();
-        KERNEL_8x8_ACC16();
-        SAVE_KERNEL_8X8_F16_f32();
-      }
-      for (; l < K8;) {
-        INIT_KERNEL_8x8();
-        KERNEL_8x8_ACC8();
-        SAVE_KERNEL_8X8_F16_f32();
-      }
-      for (; l < K4;) {
-        INIT_KERNEL_8x8();
-        KERNEL_8x8_ACC4();
-        SAVE_KERNEL_8X8_F16_f32();
-      }
-      for (; l < K;) {
-        INIT_KERNEL_8x8();
-        KERNEL_8x8_ACC1();
-        SAVE_KERNEL_8X8_F16_f32();
-      }
-
-      c += 8;
-      a -= 8 * K;
-    }
-    sc += ldc * 8;
-    c = sc;
-    a += 8 * K;
-    b = sb;
-  }
-}
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_pack.cpp b/nntrainer/tensor/hgemm/hgemm_kernel_pack.cpp
deleted file mode 100644 (file)
index 649f6f3..0000000
+++ /dev/null
@@ -1,449 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-/**
- * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
- *
- * @file   hgemm_kernel_pack.cpp
- * @date   02 July 2024
- * @see    https://github.com/nnstreamer/nntrainer
- * @author Sungsik Kong <ss.kong@samsung.com>
- * @bug    No known bugs except for NYI items
- * @brief  This is a source file for half-precision packing for the matrix
- * multiplication
- */
-
-#include <assert.h>
-#include <hgemm_common.h>
-#include <hgemm_kernel_pack.h>
-#include <matrix_transpose_neon.h>
-
-void packing_A1(unsigned int m, unsigned int k, const __fp16 *from,
-                unsigned int lda, const __fp16 *to) {
-
-  assert(k != 0 && m != 0 && k % 4 == 0 && m % 4 == 0);
-  unsigned int i, j;
-
-  __fp16 *a_offset, *a_offset1, *a_offset2, *a_offset3, *a_offset4;
-  __fp16 *b_offset;
-  __fp16 ctemp1, ctemp2, ctemp3, ctemp4;
-
-  a_offset = (__fp16 *)from;
-  b_offset = (__fp16 *)to;
-
-  j = m;
-  do {
-    a_offset1 = a_offset;
-    a_offset += lda;
-
-    i = (k >> 2);
-    do {
-      ctemp1 = *(a_offset1 + 0);
-      ctemp2 = *(a_offset1 + 1);
-      ctemp3 = *(a_offset1 + 2);
-      ctemp4 = *(a_offset1 + 3);
-
-      *(b_offset + 0) = ctemp1;
-      *(b_offset + 1) = ctemp2;
-      *(b_offset + 2) = ctemp3;
-      *(b_offset + 3) = ctemp4;
-
-      a_offset1 += 4;
-
-      b_offset += 4;
-      i--;
-    } while (i > 0);
-    j--;
-  } while (j > 0);
-}
-
-void packing_A4(unsigned int M, unsigned int K, const __fp16 *src,
-                unsigned int lda, const __fp16 *dst) {
-
-  assert(K != 0 && M != 0 && K % 4 == 0 && M % 4 == 0);
-  unsigned int i, j;
-
-  __fp16 *a_off, *a_off1, *a_off2, *a_off3, *a_off4;
-  __fp16 *b_off;
-  __fp16 c1, c2, c3, c4;
-  __fp16 c5, c6, c7, c8;
-  __fp16 c9, c10, c11, c12;
-  __fp16 c13, c14, c15, c16;
-
-  a_off = (__fp16 *)src;
-  b_off = (__fp16 *)dst;
-
-  j = (M >> 2);
-  do {
-    a_off1 = a_off;
-    a_off2 = a_off1 + lda;
-    a_off3 = a_off2 + lda;
-    a_off4 = a_off3 + lda;
-    a_off += 4 * lda;
-
-    i = (K >> 2);
-    do {
-      c1 = *(a_off1 + 0);
-      c2 = *(a_off1 + 1);
-      c3 = *(a_off1 + 2);
-      c4 = *(a_off1 + 3);
-
-      c5 = *(a_off2 + 0);
-      c6 = *(a_off2 + 1);
-      c7 = *(a_off2 + 2);
-      c8 = *(a_off2 + 3);
-
-      c9 = *(a_off3 + 0);
-      c10 = *(a_off3 + 1);
-      c11 = *(a_off3 + 2);
-      c12 = *(a_off3 + 3);
-
-      c13 = *(a_off4 + 0);
-      c14 = *(a_off4 + 1);
-      c15 = *(a_off4 + 2);
-      c16 = *(a_off4 + 3);
-
-      *(b_off + 0) = c1;
-      *(b_off + 1) = c5;
-      *(b_off + 2) = c9;
-      *(b_off + 3) = c13;
-
-      *(b_off + 4) = c2;
-      *(b_off + 5) = c6;
-      *(b_off + 6) = c10;
-      *(b_off + 7) = c14;
-
-      *(b_off + 8) = c3;
-      *(b_off + 9) = c7;
-      *(b_off + 10) = c11;
-      *(b_off + 11) = c15;
-
-      *(b_off + 12) = c4;
-      *(b_off + 13) = c8;
-      *(b_off + 14) = c12;
-      *(b_off + 15) = c16;
-
-      a_off1 += 4;
-      a_off2 += 4;
-      a_off3 += 4;
-      a_off4 += 4;
-
-      b_off += 16;
-      i--;
-    } while (i > 0);
-    j--;
-  } while (j > 0);
-}
-
-void packing_A8(unsigned int M, unsigned int K, const __fp16 *src,
-                unsigned int lda, const __fp16 *dst) {
-
-  assert(K != 0 && M != 0 && K % 8 == 0 && M % 8 == 0);
-
-  uint16x4_t msk = {0xFFFF, 0xFFFF, 0x0000, 0x0000};
-  uint16x4_t inv_msk = {0x0000, 0x0000, 0xFFFF, 0xFFFF};
-
-  const __fp16 *a_off = (__fp16 *)src;
-  __fp16 *b_off = (__fp16 *)dst;
-
-  for (unsigned int i = 0; i < M; i += 8) {
-    const __fp16 *a_off1 = a_off;
-    const __fp16 *a_off2 = a_off1 + lda;
-    const __fp16 *a_off3 = a_off2 + lda;
-    const __fp16 *a_off4 = a_off3 + lda;
-    const __fp16 *a_off5 = a_off4 + lda;
-    const __fp16 *a_off6 = a_off5 + lda;
-    const __fp16 *a_off7 = a_off6 + lda;
-    const __fp16 *a_off8 = a_off7 + lda;
-    a_off += 8 * lda;
-
-    for (unsigned int j = 0; j < K; j += 8) {
-      float16x8_t _v0 = vld1q_f16(a_off1);
-      float16x8_t _v1 = vld1q_f16(a_off2);
-      float16x8_t _v2 = vld1q_f16(a_off3);
-      float16x8_t _v3 = vld1q_f16(a_off4);
-
-      float16x8_t _v4 = vld1q_f16(a_off5);
-      float16x8_t _v5 = vld1q_f16(a_off6);
-      float16x8_t _v6 = vld1q_f16(a_off7);
-      float16x8_t _v7 = vld1q_f16(a_off8);
-
-      a_off1 += 8;
-      a_off2 += 8;
-      a_off3 += 8;
-      a_off4 += 8;
-      a_off5 += 8;
-      a_off6 += 8;
-      a_off7 += 8;
-      a_off8 += 8;
-
-      float16x8x2_t _vv0 = vtrnq_f16(_v0, _v1);
-      float16x8x2_t _vv1 = vtrnq_f16(_v2, _v3);
-      float16x8x2_t _vv2 = vtrnq_f16(_v4, _v5);
-      float16x8x2_t _vv3 = vtrnq_f16(_v6, _v7);
-
-      float16x8_t _v8 =
-        vcombine_f16(vget_low_f16(_vv0.val[0]), vget_low_f16(_vv1.val[0]));
-      float16x8_t _v9 =
-        vcombine_f16(vget_low_f16(_vv0.val[1]), vget_low_f16(_vv1.val[1]));
-      float16x8_t _v10 =
-        vcombine_f16(vget_high_f16(_vv0.val[0]), vget_high_f16(_vv1.val[0]));
-      float16x8_t _v11 =
-        vcombine_f16(vget_high_f16(_vv0.val[1]), vget_high_f16(_vv1.val[1]));
-
-      float16x8_t _v12 =
-        vcombine_f16(vget_low_f16(_vv2.val[0]), vget_low_f16(_vv3.val[0]));
-      float16x8_t _v13 =
-        vcombine_f16(vget_low_f16(_vv2.val[1]), vget_low_f16(_vv3.val[1]));
-      float16x8_t _v14 =
-        vcombine_f16(vget_high_f16(_vv2.val[0]), vget_high_f16(_vv3.val[0]));
-      float16x8_t _v15 =
-        vcombine_f16(vget_high_f16(_vv2.val[1]), vget_high_f16(_vv3.val[1]));
-
-      // pack-in-pack
-      float16x4_t tmp_low_v8 = vget_low_f16(_v8);
-      float16x4_t tmp_high_v8 = vget_high_f16(_v8);
-      float16x4_t mid_v8 = vext_f16(tmp_low_v8, tmp_high_v8, 2);
-
-      float16x4_t tmp_low_v9 = vget_low_f16(_v9);
-      float16x4_t tmp_high_v9 = vget_high_f16(_v9);
-      float16x4_t mid_v9 = vext_f16(tmp_low_v9, tmp_high_v9, 2);
-
-      float16x4_t tmp_low_v10 = vget_low_f16(_v10);
-      float16x4_t tmp_high_v10 = vget_high_f16(_v10);
-      float16x4_t mid_v10 = vext_f16(tmp_low_v10, tmp_high_v10, 2);
-
-      float16x4_t tmp_low_v11 = vget_low_f16(_v11);
-      float16x4_t tmp_high_v11 = vget_high_f16(_v11);
-      float16x4_t mid_v11 = vext_f16(tmp_low_v11, tmp_high_v11, 2);
-
-      float16x4_t tmp_low_v12 = vget_low_f16(_v12);
-      float16x4_t tmp_high_v12 = vget_high_f16(_v12);
-      float16x4_t mid_v12 = vext_f16(tmp_low_v12, tmp_high_v12, 2);
-
-      float16x4_t tmp_low_v13 = vget_low_f16(_v13);
-      float16x4_t tmp_high_v13 = vget_high_f16(_v13);
-      float16x4_t mid_v13 = vext_f16(tmp_low_v13, tmp_high_v13, 2);
-
-      float16x4_t tmp_low_v14 = vget_low_f16(_v14);
-      float16x4_t tmp_high_v14 = vget_high_f16(_v14);
-      float16x4_t mid_v14 = vext_f16(tmp_low_v14, tmp_high_v14, 2);
-
-      float16x4_t tmp_low_v15 = vget_low_f16(_v15);
-      float16x4_t tmp_high_v15 = vget_high_f16(_v15);
-      float16x4_t mid_v15 = vext_f16(tmp_low_v15, tmp_high_v15, 2);
-
-      _v8 = vcombine_f16(vbsl_f16(msk, tmp_low_v8, mid_v8),
-                         vbsl_f16(msk, tmp_low_v12, mid_v12));
-      _v12 = vcombine_f16(vbsl_f16(msk, tmp_low_v9, mid_v9),
-                          vbsl_f16(msk, tmp_low_v13, mid_v13));
-      _v9 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v8, mid_v8),
-                         vbsl_f16(inv_msk, tmp_high_v12, mid_v12));
-      _v13 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v9, mid_v9),
-                          vbsl_f16(inv_msk, tmp_high_v13, mid_v13));
-      _v10 = vcombine_f16(vbsl_f16(msk, tmp_low_v10, mid_v10),
-                          vbsl_f16(msk, tmp_low_v14, mid_v14));
-      _v14 = vcombine_f16(vbsl_f16(msk, tmp_low_v11, mid_v11),
-                          vbsl_f16(msk, tmp_low_v15, mid_v15));
-      _v11 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v10, mid_v10),
-                          vbsl_f16(inv_msk, tmp_high_v14, mid_v14));
-      _v15 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v11, mid_v11),
-                          vbsl_f16(inv_msk, tmp_high_v15, mid_v15));
-
-      vst1q_f16(b_off + 0, _v8);
-      vst1q_f16(b_off + 8, _v12);
-      vst1q_f16(b_off + 16, _v9);
-      vst1q_f16(b_off + 24, _v13);
-      vst1q_f16(b_off + 32, _v10);
-      vst1q_f16(b_off + 40, _v14);
-      vst1q_f16(b_off + 48, _v11);
-      vst1q_f16(b_off + 56, _v15);
-      b_off += 64;
-    }
-  }
-}
-
-void packing_B1(unsigned int K, unsigned int N, const __fp16 *src,
-                unsigned int ldb, const __fp16 *dst) {
-  assert(K != 0 && N != 0 && N % 8 == 0);
-
-  for (int i = 0; i < K; i++) {
-    const __fp16 *a_off = src + i * ldb;
-    __fp16 *b_off = (__fp16 *)dst + i;
-    for (int j = 0; j < N; j++) {
-      float16_t v = *(a_off);
-      a_off++;
-
-      *b_off = v;
-      b_off += K;
-    }
-  }
-}
-
-void packing_B4(unsigned int K, unsigned int N, const __fp16 *src,
-                unsigned int ldb, const __fp16 *dst) {
-  assert(K != 0 && N != 0 && K % 4 == 0 && N % 4 == 0);
-  unsigned int i, j;
-
-  __fp16 *a_off, *a_off1, *a_off2, *a_off3, *a_off4;
-  __fp16 *b_off, *b_off1;
-  __fp16 c1, c2, c3, c4;
-  __fp16 c5, c6, c7, c8;
-  __fp16 c9, c10, c11, c12;
-  __fp16 c13, c14, c15, c16;
-  a_off = (__fp16 *)src;
-  b_off = (__fp16 *)dst;
-
-  j = (K >> 2);
-  do {
-    a_off1 = a_off;
-    a_off2 = a_off1 + ldb;
-    a_off3 = a_off2 + ldb;
-    a_off4 = a_off3 + ldb;
-    a_off += 4 * ldb;
-
-    b_off1 = b_off;
-    b_off += 16;
-
-    i = (N >> 2);
-    do {
-      c1 = *(a_off1 + 0);
-      c2 = *(a_off1 + 1);
-      c3 = *(a_off1 + 2);
-      c4 = *(a_off1 + 3);
-
-      c5 = *(a_off2 + 0);
-      c6 = *(a_off2 + 1);
-      c7 = *(a_off2 + 2);
-      c8 = *(a_off2 + 3);
-
-      c9 = *(a_off3 + 0);
-      c10 = *(a_off3 + 1);
-      c11 = *(a_off3 + 2);
-      c12 = *(a_off3 + 3);
-
-      c13 = *(a_off4 + 0);
-      c14 = *(a_off4 + 1);
-      c15 = *(a_off4 + 2);
-      c16 = *(a_off4 + 3);
-
-      a_off1 += 4;
-      a_off2 += 4;
-      a_off3 += 4;
-      a_off4 += 4;
-
-      *(b_off1 + 0) = c1;
-      *(b_off1 + 1) = c2;
-      *(b_off1 + 2) = c3;
-      *(b_off1 + 3) = c4;
-
-      *(b_off1 + 4) = c5;
-      *(b_off1 + 5) = c6;
-      *(b_off1 + 6) = c7;
-      *(b_off1 + 7) = c8;
-
-      *(b_off1 + 8) = c9;
-      *(b_off1 + 9) = c10;
-      *(b_off1 + 10) = c11;
-      *(b_off1 + 11) = c12;
-
-      *(b_off1 + 12) = c13;
-      *(b_off1 + 13) = c14;
-      *(b_off1 + 14) = c15;
-      *(b_off1 + 15) = c16;
-
-      b_off1 += K * 4;
-      i--;
-    } while (i > 0);
-    j--;
-  } while (j > 0);
-}
-
-void packing_B8(unsigned int K, unsigned int N, const __fp16 *src,
-                unsigned int ldb, const __fp16 *dst) {
-  assert(K != 0 && N != 0 && N % 8 == 0);
-
-  for (int i = 0; i < K; i++) {
-    const __fp16 *a_off = src + i * ldb;
-    __fp16 *b_off = (__fp16 *)dst + i * 8;
-    for (int j = 0; j < N; j += 8) {
-      float16x8_t v = vld1q_f16(a_off);
-      a_off += 8;
-
-      vst1q_f16(b_off, v);
-      b_off += 8 * K;
-    }
-  }
-}
-
-void packing_B16(unsigned int K, unsigned int N, const __fp16 *src,
-                 unsigned int ldb, const __fp16 *dst) {
-  assert(K != 0 && N != 0 && N % 16 == 0);
-
-  for (int i = 0; i < K; i++) {
-    const __fp16 *a_off = src + i * ldb;
-    __fp16 *b_off = (__fp16 *)dst + i * 16;
-    for (int j = 0; j < N; j += 16) {
-      float16x8_t v0_7 = vld1q_f16(a_off);
-      float16x8_t v8_15 = vld1q_f16(a_off + 8);
-      a_off += 16;
-
-      vst1q_f16(b_off, v0_7);
-      vst1q_f16(b_off + 8, v8_15);
-      b_off += 16 * K;
-    }
-  }
-}
-
-void packing_transB16(unsigned int K, unsigned int N, const __fp16 *src,
-                      unsigned int ldb, const __fp16 *dst) {
-  /// @note ldb = K for here
-  assert(K != 0 && N != 0 && N % 16 == 0);
-  unsigned int K8 = (K >> 3) << 3;
-
-  const __fp16 *src_off = (__fp16 *)src;
-  __fp16 *dst_off = (__fp16 *)dst;
-
-  const unsigned int ld_tile_T = 16;
-  __fp16 *tile_T = new __fp16[8 * ld_tile_T];
-  // __fp16 *tile_T = alignedMalloc(8 * ld_tile_T);
-
-  // 1. Do something like 8x16 transpose kernel
-  // 2. Save linearized transposed output tile to dst
-  for (unsigned int n = 0; n < N; n += 16) {
-    const __fp16 *src_off1 = src_off;
-    __fp16 *dst_off1 = dst_off;
-    src_off += 16 * ldb;
-    dst_off += (K8 * 16 + (K - K8)); // ?
-    for (unsigned int k = 0; k < K8; k += 8) {
-      // 16x8 tile -> 8x16
-      transpose_neon<__fp16>(16, 8, src_off1, ldb, tile_T, ld_tile_T);
-
-      // Store with correct packing order linearly
-      vst1q_f16(&dst_off1[0], vld1q_f16(&tile_T[0 * ld_tile_T + 0]));
-      vst1q_f16(&dst_off1[8], vld1q_f16(&tile_T[0 * ld_tile_T + 8]));
-      vst1q_f16(&dst_off1[16], vld1q_f16(&tile_T[1 * ld_tile_T + 0]));
-      vst1q_f16(&dst_off1[24], vld1q_f16(&tile_T[1 * ld_tile_T + 8]));
-      vst1q_f16(&dst_off1[32], vld1q_f16(&tile_T[2 * ld_tile_T + 0]));
-      vst1q_f16(&dst_off1[40], vld1q_f16(&tile_T[2 * ld_tile_T + 8]));
-      vst1q_f16(&dst_off1[48], vld1q_f16(&tile_T[3 * ld_tile_T + 0]));
-      vst1q_f16(&dst_off1[56], vld1q_f16(&tile_T[3 * ld_tile_T + 8]));
-      vst1q_f16(&dst_off1[64], vld1q_f16(&tile_T[4 * ld_tile_T + 0]));
-      vst1q_f16(&dst_off1[72], vld1q_f16(&tile_T[4 * ld_tile_T + 8]));
-      vst1q_f16(&dst_off1[80], vld1q_f16(&tile_T[5 * ld_tile_T + 0]));
-      vst1q_f16(&dst_off1[88], vld1q_f16(&tile_T[5 * ld_tile_T + 8]));
-      vst1q_f16(&dst_off1[96], vld1q_f16(&tile_T[6 * ld_tile_T + 0]));
-      vst1q_f16(&dst_off1[104], vld1q_f16(&tile_T[6 * ld_tile_T + 8]));
-      vst1q_f16(&dst_off1[112], vld1q_f16(&tile_T[7 * ld_tile_T + 0]));
-      vst1q_f16(&dst_off1[120], vld1q_f16(&tile_T[7 * ld_tile_T + 8]));
-
-      dst_off1 += 16 * 8;
-      src_off1 += 8;
-    }
-
-    // Do the equivalent of one by one for the rest
-    for (unsigned int k = K8; k < K; ++k) {
-      for (unsigned int _n = 0; _n < 16; ++_n) {
-        dst_off1[_n] = src_off1[k];
-      }
-    }
-  }
-}
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_pack.h b/nntrainer/tensor/hgemm/hgemm_kernel_pack.h
deleted file mode 100644 (file)
index fddc351..0000000
+++ /dev/null
@@ -1,102 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-/**
- * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
- *
- * @file   hgemm_kernel_pack.h
- * @date   01 April 2024
- * @see    https://github.com/nnstreamer/nntrainer
- * @author Sungsik Kong <ss.kong@samsung.com>
- * @author Debadri Samaddar <s.debadri@samsung.com>
- * @bug    No known bugs except for NYI items
- * @brief  This is for half-precision packing for kernel-based GEMM
- */
-
-
-/**
- * @brief packing function of input matrix A
- *
- * @param M length of the row of the matrix
- * @param K length of the col of the matrix
- * @param src input of original source of the matrix
- * @param lda leading dimension of the matrix
- * @param dst output of packed data of the matrix
- */
-void packing_A1(unsigned int m, unsigned int k, const __fp16 *from,
-                unsigned int lda, const __fp16 *to);
-/**
- * @brief packing function of input matrix A
- *
- * @param M length of the row of the matrix
- * @param K length of the col of the matrix
- * @param src input of original source of the matrix
- * @param lda leading dimension of the matrix
- * @param dst output of packed data of the matrix
- */
-void packing_A4(unsigned int M, unsigned int K, const __fp16 *src,
-                unsigned int lda, const __fp16 *dst);
-/**
- * @brief packing function of input matrix A
- *
- * @param M length of the row of the matrix
- * @param K length of the col of the matrix
- * @param src input of original source of the matrix
- * @param lda leading dimension of the matrix
- * @param dst output of packed data of the matrix
- */
-void packing_A8(unsigned int M, unsigned int K, const __fp16 *src,
-                unsigned int lda, const __fp16 *dst);
-/**
- * @brief packing function of input matrix B
- *
- * @param M length of the row of the matrix
- * @param K length of the col of the matrix
- * @param src input of original source of the matrix
- * @param ldb leading dimension of the matrix
- * @param dst output of packed data of the matrix
- */
-void packing_B1(unsigned int K, unsigned int N, const __fp16 *src,
-                unsigned int ldb, const __fp16 *dst);
-/**
- * @brief packing function of input matrix B
- *
- * @param M length of the row of the matrix
- * @param K length of the col of the matrix
- * @param src input of original source of the matrix
- * @param ldb leading dimension of the matrix
- * @param dst output of packed data of the matrix
- */
-void packing_B4(unsigned int K, unsigned int N, const __fp16 *src,
-                unsigned int ldb, const __fp16 *dst);
-/**
- * @brief packing function of input matrix B
- *
- * @param M length of the row of the matrix
- * @param K length of the col of the matrix
- * @param src input of original source of the matrix
- * @param ldb leading dimension of the matrix
- * @param dst output of packed data of the matrix
- */
-void packing_B8(unsigned int K, unsigned int N, const __fp16 *src,
-                unsigned int ldb, const __fp16 *dst);
-/**
- * @brief packing function of input matrix B
- *
- * @param M length of the row of the matrix
- * @param K length of the col of the matrix
- * @param src input of original source of the matrix
- * @param ldb leading dimension of the matrix
- * @param dst output of packed data of the matrix
- */
-void packing_B16(unsigned int K, unsigned int N, const __fp16 *src,
-                 unsigned int ldb, const __fp16 *dst);
-/**
- * @brief 
- * 
- * @param K 
- * @param N 
- * @param src 
- * @param ldb 
- * @param dst 
- */
-void packing_transB16(unsigned int K, unsigned int N, const __fp16 *src,
-                      unsigned int ldb, const __fp16 *dst);
index 64a32b385ee7957eca59e2d5f2a32d114cf4322f..bff0c308a282daed3b3dd6781686a18e196fe094 100644 (file)
  *
  */
 
+#include <arm_neon.h>
 #include <cmath>
-
-#include <hgemm_kernel_pack.h>
+#include <hgemm_common.h>
+#include <hgemm_kernel.h>
 #include <hgemm_noTrans.h>
+#include <hgemm_pack.h>
 #include <hgemm_util.h>
 #include <limits>
-// #include <hgemm_kernel.h>
-
 #include <matrix_transpose_neon.h>
-#include <hgemm_common.h>
-
-#include <hgemm_kernel_1x4.h>
-#include <hgemm_kernel_1x8.h>
-#include <hgemm_kernel_4x4.h>
-#include <hgemm_kernel_4x8.h>
-#include <hgemm_kernel_8x16.h>
-#include <hgemm_kernel_8x8.h>
-
-
 
 void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M,
                    unsigned int N, unsigned int K, float alpha, float beta) {
diff --git a/nntrainer/tensor/hgemm/hgemm_pack.cpp b/nntrainer/tensor/hgemm/hgemm_pack.cpp
new file mode 100644 (file)
index 0000000..0f4b147
--- /dev/null
@@ -0,0 +1,450 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file   hgemm_kernel_pack.cpp
+ * @date   02 July 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is a source file for half-precision packing for the matrix
+ * multiplication
+ */
+
+#include <arm_neon.h>
+#include <assert.h>
+#include <hgemm_common.h>
+#include <hgemm_pack.h>
+#include <matrix_transpose_neon.h>
+
+void packing_A1(unsigned int m, unsigned int k, const __fp16 *from,
+                unsigned int lda, const __fp16 *to) {
+
+  assert(k != 0 && m != 0 && k % 4 == 0 && m % 4 == 0);
+  unsigned int i, j;
+
+  __fp16 *a_offset, *a_offset1, *a_offset2, *a_offset3, *a_offset4;
+  __fp16 *b_offset;
+  __fp16 ctemp1, ctemp2, ctemp3, ctemp4;
+
+  a_offset = (__fp16 *)from;
+  b_offset = (__fp16 *)to;
+
+  j = m;
+  do {
+    a_offset1 = a_offset;
+    a_offset += lda;
+
+    i = (k >> 2);
+    do {
+      ctemp1 = *(a_offset1 + 0);
+      ctemp2 = *(a_offset1 + 1);
+      ctemp3 = *(a_offset1 + 2);
+      ctemp4 = *(a_offset1 + 3);
+
+      *(b_offset + 0) = ctemp1;
+      *(b_offset + 1) = ctemp2;
+      *(b_offset + 2) = ctemp3;
+      *(b_offset + 3) = ctemp4;
+
+      a_offset1 += 4;
+
+      b_offset += 4;
+      i--;
+    } while (i > 0);
+    j--;
+  } while (j > 0);
+}
+
+void packing_A4(unsigned int M, unsigned int K, const __fp16 *src,
+                unsigned int lda, const __fp16 *dst) {
+
+  assert(K != 0 && M != 0 && K % 4 == 0 && M % 4 == 0);
+  unsigned int i, j;
+
+  __fp16 *a_off, *a_off1, *a_off2, *a_off3, *a_off4;
+  __fp16 *b_off;
+  __fp16 c1, c2, c3, c4;
+  __fp16 c5, c6, c7, c8;
+  __fp16 c9, c10, c11, c12;
+  __fp16 c13, c14, c15, c16;
+
+  a_off = (__fp16 *)src;
+  b_off = (__fp16 *)dst;
+
+  j = (M >> 2);
+  do {
+    a_off1 = a_off;
+    a_off2 = a_off1 + lda;
+    a_off3 = a_off2 + lda;
+    a_off4 = a_off3 + lda;
+    a_off += 4 * lda;
+
+    i = (K >> 2);
+    do {
+      c1 = *(a_off1 + 0);
+      c2 = *(a_off1 + 1);
+      c3 = *(a_off1 + 2);
+      c4 = *(a_off1 + 3);
+
+      c5 = *(a_off2 + 0);
+      c6 = *(a_off2 + 1);
+      c7 = *(a_off2 + 2);
+      c8 = *(a_off2 + 3);
+
+      c9 = *(a_off3 + 0);
+      c10 = *(a_off3 + 1);
+      c11 = *(a_off3 + 2);
+      c12 = *(a_off3 + 3);
+
+      c13 = *(a_off4 + 0);
+      c14 = *(a_off4 + 1);
+      c15 = *(a_off4 + 2);
+      c16 = *(a_off4 + 3);
+
+      *(b_off + 0) = c1;
+      *(b_off + 1) = c5;
+      *(b_off + 2) = c9;
+      *(b_off + 3) = c13;
+
+      *(b_off + 4) = c2;
+      *(b_off + 5) = c6;
+      *(b_off + 6) = c10;
+      *(b_off + 7) = c14;
+
+      *(b_off + 8) = c3;
+      *(b_off + 9) = c7;
+      *(b_off + 10) = c11;
+      *(b_off + 11) = c15;
+
+      *(b_off + 12) = c4;
+      *(b_off + 13) = c8;
+      *(b_off + 14) = c12;
+      *(b_off + 15) = c16;
+
+      a_off1 += 4;
+      a_off2 += 4;
+      a_off3 += 4;
+      a_off4 += 4;
+
+      b_off += 16;
+      i--;
+    } while (i > 0);
+    j--;
+  } while (j > 0);
+}
+
+void packing_A8(unsigned int M, unsigned int K, const __fp16 *src,
+                unsigned int lda, const __fp16 *dst) {
+
+  assert(K != 0 && M != 0 && K % 8 == 0 && M % 8 == 0);
+
+  uint16x4_t msk = {0xFFFF, 0xFFFF, 0x0000, 0x0000};
+  uint16x4_t inv_msk = {0x0000, 0x0000, 0xFFFF, 0xFFFF};
+
+  const __fp16 *a_off = (__fp16 *)src;
+  __fp16 *b_off = (__fp16 *)dst;
+
+  for (unsigned int i = 0; i < M; i += 8) {
+    const __fp16 *a_off1 = a_off;
+    const __fp16 *a_off2 = a_off1 + lda;
+    const __fp16 *a_off3 = a_off2 + lda;
+    const __fp16 *a_off4 = a_off3 + lda;
+    const __fp16 *a_off5 = a_off4 + lda;
+    const __fp16 *a_off6 = a_off5 + lda;
+    const __fp16 *a_off7 = a_off6 + lda;
+    const __fp16 *a_off8 = a_off7 + lda;
+    a_off += 8 * lda;
+
+    for (unsigned int j = 0; j < K; j += 8) {
+      float16x8_t _v0 = vld1q_f16(a_off1);
+      float16x8_t _v1 = vld1q_f16(a_off2);
+      float16x8_t _v2 = vld1q_f16(a_off3);
+      float16x8_t _v3 = vld1q_f16(a_off4);
+
+      float16x8_t _v4 = vld1q_f16(a_off5);
+      float16x8_t _v5 = vld1q_f16(a_off6);
+      float16x8_t _v6 = vld1q_f16(a_off7);
+      float16x8_t _v7 = vld1q_f16(a_off8);
+
+      a_off1 += 8;
+      a_off2 += 8;
+      a_off3 += 8;
+      a_off4 += 8;
+      a_off5 += 8;
+      a_off6 += 8;
+      a_off7 += 8;
+      a_off8 += 8;
+
+      float16x8x2_t _vv0 = vtrnq_f16(_v0, _v1);
+      float16x8x2_t _vv1 = vtrnq_f16(_v2, _v3);
+      float16x8x2_t _vv2 = vtrnq_f16(_v4, _v5);
+      float16x8x2_t _vv3 = vtrnq_f16(_v6, _v7);
+
+      float16x8_t _v8 =
+        vcombine_f16(vget_low_f16(_vv0.val[0]), vget_low_f16(_vv1.val[0]));
+      float16x8_t _v9 =
+        vcombine_f16(vget_low_f16(_vv0.val[1]), vget_low_f16(_vv1.val[1]));
+      float16x8_t _v10 =
+        vcombine_f16(vget_high_f16(_vv0.val[0]), vget_high_f16(_vv1.val[0]));
+      float16x8_t _v11 =
+        vcombine_f16(vget_high_f16(_vv0.val[1]), vget_high_f16(_vv1.val[1]));
+
+      float16x8_t _v12 =
+        vcombine_f16(vget_low_f16(_vv2.val[0]), vget_low_f16(_vv3.val[0]));
+      float16x8_t _v13 =
+        vcombine_f16(vget_low_f16(_vv2.val[1]), vget_low_f16(_vv3.val[1]));
+      float16x8_t _v14 =
+        vcombine_f16(vget_high_f16(_vv2.val[0]), vget_high_f16(_vv3.val[0]));
+      float16x8_t _v15 =
+        vcombine_f16(vget_high_f16(_vv2.val[1]), vget_high_f16(_vv3.val[1]));
+
+      // pack-in-pack
+      float16x4_t tmp_low_v8 = vget_low_f16(_v8);
+      float16x4_t tmp_high_v8 = vget_high_f16(_v8);
+      float16x4_t mid_v8 = vext_f16(tmp_low_v8, tmp_high_v8, 2);
+
+      float16x4_t tmp_low_v9 = vget_low_f16(_v9);
+      float16x4_t tmp_high_v9 = vget_high_f16(_v9);
+      float16x4_t mid_v9 = vext_f16(tmp_low_v9, tmp_high_v9, 2);
+
+      float16x4_t tmp_low_v10 = vget_low_f16(_v10);
+      float16x4_t tmp_high_v10 = vget_high_f16(_v10);
+      float16x4_t mid_v10 = vext_f16(tmp_low_v10, tmp_high_v10, 2);
+
+      float16x4_t tmp_low_v11 = vget_low_f16(_v11);
+      float16x4_t tmp_high_v11 = vget_high_f16(_v11);
+      float16x4_t mid_v11 = vext_f16(tmp_low_v11, tmp_high_v11, 2);
+
+      float16x4_t tmp_low_v12 = vget_low_f16(_v12);
+      float16x4_t tmp_high_v12 = vget_high_f16(_v12);
+      float16x4_t mid_v12 = vext_f16(tmp_low_v12, tmp_high_v12, 2);
+
+      float16x4_t tmp_low_v13 = vget_low_f16(_v13);
+      float16x4_t tmp_high_v13 = vget_high_f16(_v13);
+      float16x4_t mid_v13 = vext_f16(tmp_low_v13, tmp_high_v13, 2);
+
+      float16x4_t tmp_low_v14 = vget_low_f16(_v14);
+      float16x4_t tmp_high_v14 = vget_high_f16(_v14);
+      float16x4_t mid_v14 = vext_f16(tmp_low_v14, tmp_high_v14, 2);
+
+      float16x4_t tmp_low_v15 = vget_low_f16(_v15);
+      float16x4_t tmp_high_v15 = vget_high_f16(_v15);
+      float16x4_t mid_v15 = vext_f16(tmp_low_v15, tmp_high_v15, 2);
+
+      _v8 = vcombine_f16(vbsl_f16(msk, tmp_low_v8, mid_v8),
+                         vbsl_f16(msk, tmp_low_v12, mid_v12));
+      _v12 = vcombine_f16(vbsl_f16(msk, tmp_low_v9, mid_v9),
+                          vbsl_f16(msk, tmp_low_v13, mid_v13));
+      _v9 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v8, mid_v8),
+                         vbsl_f16(inv_msk, tmp_high_v12, mid_v12));
+      _v13 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v9, mid_v9),
+                          vbsl_f16(inv_msk, tmp_high_v13, mid_v13));
+      _v10 = vcombine_f16(vbsl_f16(msk, tmp_low_v10, mid_v10),
+                          vbsl_f16(msk, tmp_low_v14, mid_v14));
+      _v14 = vcombine_f16(vbsl_f16(msk, tmp_low_v11, mid_v11),
+                          vbsl_f16(msk, tmp_low_v15, mid_v15));
+      _v11 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v10, mid_v10),
+                          vbsl_f16(inv_msk, tmp_high_v14, mid_v14));
+      _v15 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v11, mid_v11),
+                          vbsl_f16(inv_msk, tmp_high_v15, mid_v15));
+
+      vst1q_f16(b_off + 0, _v8);
+      vst1q_f16(b_off + 8, _v12);
+      vst1q_f16(b_off + 16, _v9);
+      vst1q_f16(b_off + 24, _v13);
+      vst1q_f16(b_off + 32, _v10);
+      vst1q_f16(b_off + 40, _v14);
+      vst1q_f16(b_off + 48, _v11);
+      vst1q_f16(b_off + 56, _v15);
+      b_off += 64;
+    }
+  }
+}
+
+void packing_B1(unsigned int K, unsigned int N, const __fp16 *src,
+                unsigned int ldb, const __fp16 *dst) {
+  assert(K != 0 && N != 0 && N % 8 == 0);
+
+  for (int i = 0; i < K; i++) {
+    const __fp16 *a_off = src + i * ldb;
+    __fp16 *b_off = (__fp16 *)dst + i;
+    for (int j = 0; j < N; j++) {
+      float16_t v = *(a_off);
+      a_off++;
+
+      *b_off = v;
+      b_off += K;
+    }
+  }
+}
+
+void packing_B4(unsigned int K, unsigned int N, const __fp16 *src,
+                unsigned int ldb, const __fp16 *dst) {
+  assert(K != 0 && N != 0 && K % 4 == 0 && N % 4 == 0);
+  unsigned int i, j;
+
+  __fp16 *a_off, *a_off1, *a_off2, *a_off3, *a_off4;
+  __fp16 *b_off, *b_off1;
+  __fp16 c1, c2, c3, c4;
+  __fp16 c5, c6, c7, c8;
+  __fp16 c9, c10, c11, c12;
+  __fp16 c13, c14, c15, c16;
+  a_off = (__fp16 *)src;
+  b_off = (__fp16 *)dst;
+
+  j = (K >> 2);
+  do {
+    a_off1 = a_off;
+    a_off2 = a_off1 + ldb;
+    a_off3 = a_off2 + ldb;
+    a_off4 = a_off3 + ldb;
+    a_off += 4 * ldb;
+
+    b_off1 = b_off;
+    b_off += 16;
+
+    i = (N >> 2);
+    do {
+      c1 = *(a_off1 + 0);
+      c2 = *(a_off1 + 1);
+      c3 = *(a_off1 + 2);
+      c4 = *(a_off1 + 3);
+
+      c5 = *(a_off2 + 0);
+      c6 = *(a_off2 + 1);
+      c7 = *(a_off2 + 2);
+      c8 = *(a_off2 + 3);
+
+      c9 = *(a_off3 + 0);
+      c10 = *(a_off3 + 1);
+      c11 = *(a_off3 + 2);
+      c12 = *(a_off3 + 3);
+
+      c13 = *(a_off4 + 0);
+      c14 = *(a_off4 + 1);
+      c15 = *(a_off4 + 2);
+      c16 = *(a_off4 + 3);
+
+      a_off1 += 4;
+      a_off2 += 4;
+      a_off3 += 4;
+      a_off4 += 4;
+
+      *(b_off1 + 0) = c1;
+      *(b_off1 + 1) = c2;
+      *(b_off1 + 2) = c3;
+      *(b_off1 + 3) = c4;
+
+      *(b_off1 + 4) = c5;
+      *(b_off1 + 5) = c6;
+      *(b_off1 + 6) = c7;
+      *(b_off1 + 7) = c8;
+
+      *(b_off1 + 8) = c9;
+      *(b_off1 + 9) = c10;
+      *(b_off1 + 10) = c11;
+      *(b_off1 + 11) = c12;
+
+      *(b_off1 + 12) = c13;
+      *(b_off1 + 13) = c14;
+      *(b_off1 + 14) = c15;
+      *(b_off1 + 15) = c16;
+
+      b_off1 += K * 4;
+      i--;
+    } while (i > 0);
+    j--;
+  } while (j > 0);
+}
+
+void packing_B8(unsigned int K, unsigned int N, const __fp16 *src,
+                unsigned int ldb, const __fp16 *dst) {
+  assert(K != 0 && N != 0 && N % 8 == 0);
+
+  for (int i = 0; i < K; i++) {
+    const __fp16 *a_off = src + i * ldb;
+    __fp16 *b_off = (__fp16 *)dst + i * 8;
+    for (int j = 0; j < N; j += 8) {
+      float16x8_t v = vld1q_f16(a_off);
+      a_off += 8;
+
+      vst1q_f16(b_off, v);
+      b_off += 8 * K;
+    }
+  }
+}
+
+void packing_B16(unsigned int K, unsigned int N, const __fp16 *src,
+                 unsigned int ldb, const __fp16 *dst) {
+  assert(K != 0 && N != 0 && N % 16 == 0);
+
+  for (int i = 0; i < K; i++) {
+    const __fp16 *a_off = src + i * ldb;
+    __fp16 *b_off = (__fp16 *)dst + i * 16;
+    for (int j = 0; j < N; j += 16) {
+      float16x8_t v0_7 = vld1q_f16(a_off);
+      float16x8_t v8_15 = vld1q_f16(a_off + 8);
+      a_off += 16;
+
+      vst1q_f16(b_off, v0_7);
+      vst1q_f16(b_off + 8, v8_15);
+      b_off += 16 * K;
+    }
+  }
+}
+
+void packing_transB16(unsigned int K, unsigned int N, const __fp16 *src,
+                      unsigned int ldb, const __fp16 *dst) {
+  /// @note ldb = K for here
+  assert(K != 0 && N != 0 && N % 16 == 0);
+  unsigned int K8 = (K >> 3) << 3;
+
+  const __fp16 *src_off = (__fp16 *)src;
+  __fp16 *dst_off = (__fp16 *)dst;
+
+  const unsigned int ld_tile_T = 16;
+  __fp16 *tile_T = new __fp16[8 * ld_tile_T];
+  // __fp16 *tile_T = alignedMalloc(8 * ld_tile_T);
+
+  // 1. Do something like 8x16 transpose kernel
+  // 2. Save linearized transposed output tile to dst
+  for (unsigned int n = 0; n < N; n += 16) {
+    const __fp16 *src_off1 = src_off;
+    __fp16 *dst_off1 = dst_off;
+    src_off += 16 * ldb;
+    dst_off += (K8 * 16 + (K - K8)); // ?
+    for (unsigned int k = 0; k < K8; k += 8) {
+      // 16x8 tile -> 8x16
+      transpose_neon<__fp16>(16, 8, src_off1, ldb, tile_T, ld_tile_T);
+
+      // Store with correct packing order linearly
+      vst1q_f16(&dst_off1[0], vld1q_f16(&tile_T[0 * ld_tile_T + 0]));
+      vst1q_f16(&dst_off1[8], vld1q_f16(&tile_T[0 * ld_tile_T + 8]));
+      vst1q_f16(&dst_off1[16], vld1q_f16(&tile_T[1 * ld_tile_T + 0]));
+      vst1q_f16(&dst_off1[24], vld1q_f16(&tile_T[1 * ld_tile_T + 8]));
+      vst1q_f16(&dst_off1[32], vld1q_f16(&tile_T[2 * ld_tile_T + 0]));
+      vst1q_f16(&dst_off1[40], vld1q_f16(&tile_T[2 * ld_tile_T + 8]));
+      vst1q_f16(&dst_off1[48], vld1q_f16(&tile_T[3 * ld_tile_T + 0]));
+      vst1q_f16(&dst_off1[56], vld1q_f16(&tile_T[3 * ld_tile_T + 8]));
+      vst1q_f16(&dst_off1[64], vld1q_f16(&tile_T[4 * ld_tile_T + 0]));
+      vst1q_f16(&dst_off1[72], vld1q_f16(&tile_T[4 * ld_tile_T + 8]));
+      vst1q_f16(&dst_off1[80], vld1q_f16(&tile_T[5 * ld_tile_T + 0]));
+      vst1q_f16(&dst_off1[88], vld1q_f16(&tile_T[5 * ld_tile_T + 8]));
+      vst1q_f16(&dst_off1[96], vld1q_f16(&tile_T[6 * ld_tile_T + 0]));
+      vst1q_f16(&dst_off1[104], vld1q_f16(&tile_T[6 * ld_tile_T + 8]));
+      vst1q_f16(&dst_off1[112], vld1q_f16(&tile_T[7 * ld_tile_T + 0]));
+      vst1q_f16(&dst_off1[120], vld1q_f16(&tile_T[7 * ld_tile_T + 8]));
+
+      dst_off1 += 16 * 8;
+      src_off1 += 8;
+    }
+
+    // Do the equivalent of one by one for the rest
+    for (unsigned int k = K8; k < K; ++k) {
+      for (unsigned int _n = 0; _n < 16; ++_n) {
+        dst_off1[_n] = src_off1[k];
+      }
+    }
+  }
+}
diff --git a/nntrainer/tensor/hgemm/hgemm_pack.h b/nntrainer/tensor/hgemm/hgemm_pack.h
new file mode 100644 (file)
index 0000000..7a671a5
--- /dev/null
@@ -0,0 +1,101 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file   hgemm_kernel_pack.h
+ * @date   01 April 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @author Debadri Samaddar <s.debadri@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is for half-precision packing for kernel-based GEMM
+ */
+
+/**
+ * @brief packing function of input matrix A
+ *
+ * @param M length of the row of the matrix
+ * @param K length of the col of the matrix
+ * @param src input of original source of the matrix
+ * @param lda leading dimension of the matrix
+ * @param dst output of packed data of the matrix
+ */
+void packing_A1(unsigned int m, unsigned int k, const __fp16 *from,
+                unsigned int lda, const __fp16 *to);
+/**
+ * @brief packing function of input matrix A
+ *
+ * @param M length of the row of the matrix
+ * @param K length of the col of the matrix
+ * @param src input of original source of the matrix
+ * @param lda leading dimension of the matrix
+ * @param dst output of packed data of the matrix
+ */
+void packing_A4(unsigned int M, unsigned int K, const __fp16 *src,
+                unsigned int lda, const __fp16 *dst);
+/**
+ * @brief packing function of input matrix A
+ *
+ * @param M length of the row of the matrix
+ * @param K length of the col of the matrix
+ * @param src input of original source of the matrix
+ * @param lda leading dimension of the matrix
+ * @param dst output of packed data of the matrix
+ */
+void packing_A8(unsigned int M, unsigned int K, const __fp16 *src,
+                unsigned int lda, const __fp16 *dst);
+/**
+ * @brief packing function of input matrix B
+ *
+ * @param M length of the row of the matrix
+ * @param K length of the col of the matrix
+ * @param src input of original source of the matrix
+ * @param ldb leading dimension of the matrix
+ * @param dst output of packed data of the matrix
+ */
+void packing_B1(unsigned int K, unsigned int N, const __fp16 *src,
+                unsigned int ldb, const __fp16 *dst);
+/**
+ * @brief packing function of input matrix B
+ *
+ * @param M length of the row of the matrix
+ * @param K length of the col of the matrix
+ * @param src input of original source of the matrix
+ * @param ldb leading dimension of the matrix
+ * @param dst output of packed data of the matrix
+ */
+void packing_B4(unsigned int K, unsigned int N, const __fp16 *src,
+                unsigned int ldb, const __fp16 *dst);
+/**
+ * @brief packing function of input matrix B
+ *
+ * @param M length of the row of the matrix
+ * @param K length of the col of the matrix
+ * @param src input of original source of the matrix
+ * @param ldb leading dimension of the matrix
+ * @param dst output of packed data of the matrix
+ */
+void packing_B8(unsigned int K, unsigned int N, const __fp16 *src,
+                unsigned int ldb, const __fp16 *dst);
+/**
+ * @brief packing function of input matrix B
+ *
+ * @param M length of the row of the matrix
+ * @param K length of the col of the matrix
+ * @param src input of original source of the matrix
+ * @param ldb leading dimension of the matrix
+ * @param dst output of packed data of the matrix
+ */
+void packing_B16(unsigned int K, unsigned int N, const __fp16 *src,
+                 unsigned int ldb, const __fp16 *dst);
+/**
+ * @brief
+ *
+ * @param K
+ * @param N
+ * @param src
+ * @param ldb
+ * @param dst
+ */
+void packing_transB16(unsigned int K, unsigned int N, const __fp16 *src,
+                      unsigned int ldb, const __fp16 *dst);
index 3eb4f188535a6e74bf2fabdb6be9960394198686..f0ef1052b2ef6128ebf796b308dd6ee72e44b6f6 100644 (file)
@@ -82,7 +82,6 @@ void hgemm_padding_B_noTrans_wrt_KN(const __fp16 *B, __fp16 *Bp, unsigned int K,
   std::cerr << "Error : hgemm_padding_B_noTrans_wrt_KN NYI!\n";
 }
 
-
 void hgemm_padding_B_Trans_wrt_N(const __fp16 *B, __fp16 *Bp, unsigned int K,
                                  unsigned int N, unsigned int K8,
                                  unsigned int N16) {
index adc7907fd3501aa106cb3334c300734febb5758b..f224e5f7b4cc55278e41af44c906b4b3535537c4 100644 (file)
  */
 
 #include <cmath>
-#include <hgemm_kernel_8x16.h>
 #include <hgemm_common.h>
-// #include <hgemm_kernel.h>
-#include <hgemm_kernel_pack.h>
+#include <hgemm_kernel.h>
 #include <hgemm_noTrans.h>
+#include <hgemm_pack.h>
 #include <hgemm_transB.h>
 #include <hgemm_util.h>
 #include <limits>
 #include <matrix_transpose_neon.h>
 
-// #define HGEMM_KERNEL_8x16 hgemm_kernel_8x16 /// @todo change to macro kernel
-// #if !defined(HGEMM_KERNEL_8x16) hgemm_kernel_8x16
-// #endif
-
 void hgemm_transB_8x16(unsigned int M, unsigned int N, unsigned int K,
                        const __fp16 *A, unsigned int lda, const __fp16 *B,
                        unsigned int ldb, float *C, unsigned int ldc,
@@ -87,8 +82,7 @@ void hgemm_transB_8x16(unsigned int M, unsigned int N, unsigned int K,
           n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
         }
         packing_transB16(k_min, n_min, B + ks + ldb * ns, ldb, sB);
-        hgemm_kernel_8x16(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns,
-        ldc);
+        hgemm_kernel_8x16(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc);
       }
     }
   }
index ec5af07bba8dda44ed3ad0591e49be2af532668b..90ef0a3774a7716c5c7e8865c370f13bc508875a 100644 (file)
@@ -1,18 +1,20 @@
 hgemm_headers = [
   'hgemm.h',
   'hgemm_util.h',
-  'hgemm_kernel_pack.h',
-  'hgemm_kernel_4x4.h',
-  'hgemm_kernel_4x8.h',
-  'hgemm_kernel_8x8.h',
-  'hgemm_kernel_8x16.h',
+  'hgemm_pack.h',
+  'hgemm_common.h',
+  'hgemm_padding.h',
 ]
 
+subdir('hgemm_kernel')
+nntrainer_inc += include_directories('hgemm_kernel')
+nntrainer_inc_abs += meson.current_source_dir() / 'hgemm_kernel'
+
 hgemm_sources = [
     'hgemm.cpp',
     'hgemm_padding_a.cpp',
     'hgemm_padding_b.cpp',
-    'hgemm_kernel_pack.cpp',
+    'hgemm_pack.cpp',
     'hgemm_noTrans.cpp',
     'hgemm_transA.cpp',
     'hgemm_transB.cpp',