[ hgemm ] Refactor kernel init process
author skykongkong8 <ss.kong@samsung.com>
Mon, 15 Apr 2024 04:01:04 +0000 (13:01 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Wed, 22 May 2024 23:13:42 +0000 (08:13 +0900)
- I found repeated accumulator-initialization code before the fused multiply-add operations in every kernel macro.
- With the initialization split into a separate macro, we can enjoy:
1. Cleaner code that is reusable for both the f16 and f16-f32 kernels
2. A minimized redundant init process in the f16 kernels: better latency with the SAME accuracy (see the sketch below).
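
A minimal sketch of the pattern (illustrative one-row micro-kernel with hypothetical names; assumes AArch64 NEON with native FP16 arithmetic and packed k-major inputs; the real kernels live in the hgemm_kernel_*.h files below):

  #include <arm_neon.h>

  /* Before: every KERNEL_* macro re-zeroed its accumulators, so the f16
   * path had to flush a partial sum into C on each macro invocation.
   * After: INIT_* runs once, the FMA loop accumulates across all of K,
   * and C is updated a single time. Assumes K % 4 == 0 for brevity. */
  static void sketch_f16_path(const __fp16 *a, const __fp16 *b, __fp16 *c,
                              unsigned int K) {
    float16x4_t acc = vdup_n_f16(0.F); /* hoisted init (was inside the macro) */
    for (unsigned int l = 0; l < K; l += 4) {
      float16x4_t va = vld1_f16(a);                     /* a[l..l+3]            */
      acc = vfma_lane_f16(acc, vld1_f16(b), va, 0);     /* += b_row(l)   * a[l] */
      acc = vfma_lane_f16(acc, vld1_f16(b + 4), va, 1); /* += b_row(l+1)        */
      acc = vfma_lane_f16(acc, vld1_f16(b + 8), va, 2);
      acc = vfma_lane_f16(acc, vld1_f16(b + 12), va, 3);
      a += 4;
      b += 16;
    }
    /* one accumulate-store instead of one per kernel call */
    vst1_f16(c, vadd_f16(vld1_f16(c), acc));
  }

Since C is itself f16 in this path, the running sum is rounded to f16 either way; keeping it in a register only removes the per-call re-init and store, which is why latency improves at the same accuracy.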

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
nntrainer/tensor/hgemm/hgemm_kernel_4x4.h
nntrainer/tensor/hgemm/hgemm_kernel_4x8.h
nntrainer/tensor/hgemm/hgemm_kernel_8x16.h
nntrainer/tensor/hgemm/hgemm_kernel_8x8.h
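
For the mixed-precision (f16-f32) paths below, the per-block re-initialization now happens explicitly: INIT_KERNEL_* zeroes the f16 accumulators, KERNEL_*_ACCn sums at most n iterations in f16, and SAVE_KERNEL_*_F16_F32 widens the partial sum into an f32 C. That bounded f16 sum length is the accuracy-latency tradeoff named in the ACCn comments. A hedged sketch of the idea (hypothetical helper with a runtime block size; the real code unrolls fixed K16/K8/K4/K1 loops instead):

  #include <arm_neon.h>

  /* Bound the f16 partial-sum length to `block` FMAs, then widen and
   * accumulate into f32 so precision does not degrade as K grows.
   * Larger block => fewer f32 flushes (faster) but longer f16 sums
   * (less accurate); block == 1 matches the ACC1 kernels. */
  static void sketch_f16_f32_path(const __fp16 *a, const __fp16 *b, float *c,
                                  unsigned int K, unsigned int block) {
    float32x4_t acc32 = vld1q_f32(c);
    for (unsigned int l = 0; l < K; l += block) {
      float16x4_t acc16 = vdup_n_f16(0.F); /* INIT_KERNEL_* per block */
      for (unsigned int i = 0; i < block && l + i < K; ++i) {
        acc16 = vfma_f16(acc16, vld1_f16(b), vdup_n_f16(*a));
        a += 1; /* illustrative packing; real kernels step by tile */
        b += 4;
      }
      acc32 = vaddq_f32(acc32, vcvt_f32_f16(acc16)); /* SAVE_KERNEL_* */
    }
    vst1q_f32(c, acc32);
  }

Because every KERNEL_*_ACCn block must start from zero, the zeroing cannot stay inside the kernel macro once the f16 path accumulates across the whole K; it moves into INIT_KERNEL_* and is invoked per block only where the f16-f32 save requires it.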

index db560db04dccb3a4546ef3553be8d7a58782a6cd..ab49faaca7d6b16157eca62b54263a870ee7443b 100644 (file)
 #include <hgemm_common.h>
 #include <stdlib.h>
 
+#define INIT_KERNEL_4x4() \
+  v24 = vdup_n_f16(0.F);  \
+  v25 = vdup_n_f16(0.F);  \
+  v26 = vdup_n_f16(0.F);  \
+  v27 = vdup_n_f16(0.F);
+
 // 1. Partial sum 256 digits : Worst accuracy, best latency
 #define KERNEL_4x4_ACC16()               \
-  v24 = vdup_n_f16(0.F);                 \
-  v25 = vdup_n_f16(0.F);                 \
-  v26 = vdup_n_f16(0.F);                 \
-  v27 = vdup_n_f16(0.F);                 \
   dv0 = vld1_f16(a);                     \
   vb0 = vld1_f16(b);                     \
   v24 = vfma_lane_f16(v24, vb0, dv0, 0); \
   v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
   v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
   v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
-  l += 16;                                \
+  l += 16;                               \
   __builtin_prefetch(b + 64, 0, 3);      \
   __builtin_prefetch(a + 64, 0, 3);      \
-  b += 4 * 16;                            \
+  b += 4 * 16;                           \
   a += 4 * 16;
 
 // 2. Partial sum 128 digits : Medium accuracy, medium latency
 #define KERNEL_4x4_ACC8()                \
-  v24 = vdup_n_f16(0.F);                 \
-  v25 = vdup_n_f16(0.F);                 \
-  v26 = vdup_n_f16(0.F);                 \
-  v27 = vdup_n_f16(0.F);                 \
   dv0 = vld1_f16(a);                     \
   vb0 = vld1_f16(b);                     \
   v24 = vfma_lane_f16(v24, vb0, dv0, 0); \
 
 // 3. Partial sum 16 digits : Best accuracy, worst latency
 #define KERNEL_4x4_ACC1()                \
-  v24 = vdup_n_f16(0.F);                 \
-  v25 = vdup_n_f16(0.F);                 \
-  v26 = vdup_n_f16(0.F);                 \
-  v27 = vdup_n_f16(0.F);                 \
   dv0 = vld1_f16(a);                     \
   vb0 = vld1_f16(b);                     \
   v24 = vfma_lane_f16(v24, vb0, dv0, 0); \
@@ -230,10 +224,11 @@ void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
       __builtin_prefetch(b, 0, 3);
       __builtin_prefetch(a, 0, 3);
 
-      float16x4_t v24 = {0};
-      float16x4_t v25 = {0};
-      float16x4_t v26 = {0};
-      float16x4_t v27 = {0};
+      float16x4_t v24;
+      float16x4_t v25;
+      float16x4_t v26;
+      float16x4_t v27;
+      INIT_KERNEL_4x4();
 
       for (l = 0; l < K; l += VL_FP16_HALF) {
         float16x4_t v0 = vld1_f16(b);
@@ -326,14 +321,17 @@ void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
       float16x4_t vb0, vb1, vb2, vb3, vb4, vb5, vb6, vb7;
       l = 0;
       for (; l < K16;) {
+        INIT_KERNEL_4x4();
         KERNEL_4x4_ACC16();
         SAVE_KERNEL_4X4_F16_F32();
       }
       for (; l < K8;) {
+        INIT_KERNEL_4x4();
         KERNEL_4x4_ACC8();
         SAVE_KERNEL_4X4_F16_F32();
       }
       for (; l < K;) {
+        INIT_KERNEL_4x4();
         KERNEL_4x4_ACC1();
         SAVE_KERNEL_4X4_F16_F32();
       }
index 6ee1b100598c269112129098ac5d558e0f3de649..064f0a7b73a66bacd0e6a43b4726cc60022ed0bb 100644 (file)
 #include <hgemm_common.h>
 #include <stdlib.h>
 
-/// @note Following KERNELs are the combinations of accuracy-latency
-/// tradeoff. User can select which kernel to use by replacing them.
+#define INIT_KERNEL_4X8() \
+  v0 = vdupq_n_f16(0.F);  \
+  v3 = vdupq_n_f16(0.F);  \
+  v6 = vdupq_n_f16(0.F);  \
+  v9 = vdupq_n_f16(0.F);
 
 // 0. Partial sum 512 digits : worst accuracy, best latency
 #define KERNEL_4x8_ACC16()              \
-  v0 = vdupq_n_f16(0.F);                \
-  v3 = vdupq_n_f16(0.F);                \
-  v6 = vdupq_n_f16(0.F);                \
-  v9 = vdupq_n_f16(0.F);                \
   dv0 = vld1_f16(a);                    \
   v24 = vld1q_f16(b);                   \
   v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
 
 // 1. Partial sum 256 digits : medium-high accuracy, medium latency
 #define KERNEL_4x8_ACC8()               \
-  v0 = vdupq_n_f16(0.F);                \
-  v3 = vdupq_n_f16(0.F);                \
-  v6 = vdupq_n_f16(0.F);                \
-  v9 = vdupq_n_f16(0.F);                \
   dv0 = vld1_f16(a);                    \
   v24 = vld1q_f16(b);                   \
   v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
 
 // 2. Partial sum 128 digits : medium accuracy, medium latency
 #define KERNEL_4x8_ACC4()               \
-  v0 = vdupq_n_f16(0.F);                \
-  v3 = vdupq_n_f16(0.F);                \
-  v6 = vdupq_n_f16(0.F);                \
-  v9 = vdupq_n_f16(0.F);                \
   dv0 = vld1_f16(a);                    \
   v24 = vld1q_f16(b);                   \
   v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
 
 // 3. Partial sum 32 digits : Best accuracy, worst latency
 #define KERNEL_4x8_ACC1()               \
-  v0 = vdupq_n_f16(0.F);                \
-  v3 = vdupq_n_f16(0.F);                \
-  v6 = vdupq_n_f16(0.F);                \
-  v9 = vdupq_n_f16(0.F);                \
   dv0 = vld1_f16(a);                    \
   v24 = vld1q_f16(b);                   \
   v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
@@ -274,7 +261,7 @@ void hgemm_kernel_4x8(unsigned int M, unsigned int N, unsigned int K,
   assert(M % 4 == 0 && N % 8 == 0);
 
   __fp16 *a = sa, *b = sb, *c = sc;
-  unsigned int k8 = (K >> 3) << 3;
+  unsigned int K8 = (K >> 3) << 3;
   unsigned int i, j, l;
   for (i = 0; i < M; i += 4) {
     for (j = 0; j < N; j += 8) {
@@ -283,23 +270,18 @@ void hgemm_kernel_4x8(unsigned int M, unsigned int N, unsigned int K,
       float16x8_t v0, v3, v6, v9;
       float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
       float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
+      INIT_KERNEL_4X8();
       l = 0;
-      for (; l < k8;) {
+      for (; l < K8;) {
         KERNEL_4x8_ACC8();
-
-        vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0));
-        vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v3));
-        vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v6));
-        vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v9));
       }
       for (; l < K;) {
         KERNEL_4x8_ACC1();
-
-        vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0));
-        vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v3));
-        vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v6));
-        vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v9));
       }
+      vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0));
+      vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v3));
+      vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v6));
+      vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v9));
       c += 8;
       a -= 4 * K;
     }
@@ -341,18 +323,22 @@ void hgemm_kernel_4x8(unsigned int M, unsigned int N, unsigned int K,
       float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
       l = 0;
       for (; l < K16;) {
+        INIT_KERNEL_4X8();
         KERNEL_4x8_ACC16();
         SAVE_KERNEL_4X8_F16_F32();
       }
       for (; l < K8;) {
+        INIT_KERNEL_4X8();
         KERNEL_4x8_ACC8();
         SAVE_KERNEL_4X8_F16_F32();
       }
       for (; l < K4;) {
+        INIT_KERNEL_4X8();
         KERNEL_4x8_ACC4();
         SAVE_KERNEL_4X8_F16_F32();
       }
       for (; l < K;) {
+        INIT_KERNEL_4X8();
         KERNEL_4x8_ACC1();
         SAVE_KERNEL_4X8_F16_F32();
       }
index 188f106655dd5aeda1d9a34fc9aab4db18068c7d..38778ea8f39324ce4c5e80b15c35d2b51d7eb19d 100644 (file)
 #include <hgemm_common.h>
 #include <stdlib.h>
 
-/// @note Following KERNELs are the combinations of accuracy-latency
-/// tradeoff. User can select which kernel to use by replacing them.
+#define INIT_KERNEL_8X16()     \
+  v0_7 = vdupq_n_f16(0.F);     \
+  v8_15 = vdupq_n_f16(0.F);    \
+  v16_23 = vdupq_n_f16(0.F);   \
+  v24_31 = vdupq_n_f16(0.F);   \
+  v32_39 = vdupq_n_f16(0.F);   \
+  v40_47 = vdupq_n_f16(0.F);   \
+  v48_55 = vdupq_n_f16(0.F);   \
+  v56_63 = vdupq_n_f16(0.F);   \
+  v64_71 = vdupq_n_f16(0.F);   \
+  v72_79 = vdupq_n_f16(0.F);   \
+  v80_87 = vdupq_n_f16(0.F);   \
+  v88_95 = vdupq_n_f16(0.F);   \
+  v96_103 = vdupq_n_f16(0.F);  \
+  v104_111 = vdupq_n_f16(0.F); \
+  v112_119 = vdupq_n_f16(0.F); \
+  v120_127 = vdupq_n_f16(0.F);
 
 // 0. Partial sum 2048 digits : Best latency, worst accuracy.
 #define KERNEL_8x16_ACC16()                          \
-  v0_7 = vdupq_n_f16(0.F);                           \
-  v8_15 = vdupq_n_f16(0.F);                          \
-  v16_23 = vdupq_n_f16(0.F);                         \
-  v24_31 = vdupq_n_f16(0.F);                         \
-  v32_39 = vdupq_n_f16(0.F);                         \
-  v40_47 = vdupq_n_f16(0.F);                         \
-  v48_55 = vdupq_n_f16(0.F);                         \
-  v56_63 = vdupq_n_f16(0.F);                         \
-  v64_71 = vdupq_n_f16(0.F);                         \
-  v72_79 = vdupq_n_f16(0.F);                         \
-  v80_87 = vdupq_n_f16(0.F);                         \
-  v88_95 = vdupq_n_f16(0.F);                         \
-  v96_103 = vdupq_n_f16(0.F);                        \
-  v104_111 = vdupq_n_f16(0.F);                       \
-  v112_119 = vdupq_n_f16(0.F);                       \
-  v120_127 = vdupq_n_f16(0.F);                       \
   va0 = vld1q_f16(a);                                \
   v24 = vld1q_f16(b);                                \
   v25 = vld1q_f16(b + 8);                            \
 
 // 1. Partial sum 1024 digits : Medium-high accuracy, medium latency
 #define KERNEL_8x16_ACC8()                           \
-  v0_7 = vdupq_n_f16(0.F);                           \
-  v8_15 = vdupq_n_f16(0.F);                          \
-  v16_23 = vdupq_n_f16(0.F);                         \
-  v24_31 = vdupq_n_f16(0.F);                         \
-  v32_39 = vdupq_n_f16(0.F);                         \
-  v40_47 = vdupq_n_f16(0.F);                         \
-  v48_55 = vdupq_n_f16(0.F);                         \
-  v56_63 = vdupq_n_f16(0.F);                         \
-  v64_71 = vdupq_n_f16(0.F);                         \
-  v72_79 = vdupq_n_f16(0.F);                         \
-  v80_87 = vdupq_n_f16(0.F);                         \
-  v88_95 = vdupq_n_f16(0.F);                         \
-  v96_103 = vdupq_n_f16(0.F);                        \
-  v104_111 = vdupq_n_f16(0.F);                       \
-  v112_119 = vdupq_n_f16(0.F);                       \
-  v120_127 = vdupq_n_f16(0.F);                       \
   va0 = vld1q_f16(a);                                \
   v24 = vld1q_f16(b);                                \
   v25 = vld1q_f16(b + 8);                            \
 
 // 2. Partial sum 512 digits : Medium accuracy, medium latency
 #define KERNEL_8x16_ACC4()                           \
-  v0_7 = vdupq_n_f16(0.F);                           \
-  v8_15 = vdupq_n_f16(0.F);                          \
-  v16_23 = vdupq_n_f16(0.F);                         \
-  v24_31 = vdupq_n_f16(0.F);                         \
-  v32_39 = vdupq_n_f16(0.F);                         \
-  v40_47 = vdupq_n_f16(0.F);                         \
-  v48_55 = vdupq_n_f16(0.F);                         \
-  v56_63 = vdupq_n_f16(0.F);                         \
-  v64_71 = vdupq_n_f16(0.F);                         \
-  v72_79 = vdupq_n_f16(0.F);                         \
-  v80_87 = vdupq_n_f16(0.F);                         \
-  v88_95 = vdupq_n_f16(0.F);                         \
-  v96_103 = vdupq_n_f16(0.F);                        \
-  v104_111 = vdupq_n_f16(0.F);                       \
-  v112_119 = vdupq_n_f16(0.F);                       \
-  v120_127 = vdupq_n_f16(0.F);                       \
   va0 = vld1q_f16(a);                                \
   v24 = vld1q_f16(b);                                \
   v25 = vld1q_f16(b + 8);                            \
 
 // 3. Partial sum 128 digits : Best accuracy, worst latency
 #define KERNEL_8x16_ACC1()                           \
-  v0_7 = vdupq_n_f16(0.F);                           \
-  v8_15 = vdupq_n_f16(0.F);                          \
-  v16_23 = vdupq_n_f16(0.F);                         \
-  v24_31 = vdupq_n_f16(0.F);                         \
-  v32_39 = vdupq_n_f16(0.F);                         \
-  v40_47 = vdupq_n_f16(0.F);                         \
-  v48_55 = vdupq_n_f16(0.F);                         \
-  v56_63 = vdupq_n_f16(0.F);                         \
-  v64_71 = vdupq_n_f16(0.F);                         \
-  v72_79 = vdupq_n_f16(0.F);                         \
-  v80_87 = vdupq_n_f16(0.F);                         \
-  v88_95 = vdupq_n_f16(0.F);                         \
-  v96_103 = vdupq_n_f16(0.F);                        \
-  v104_111 = vdupq_n_f16(0.F);                       \
-  v112_119 = vdupq_n_f16(0.F);                       \
-  v120_127 = vdupq_n_f16(0.F);                       \
   va0 = vld1q_f16(a);                                \
   v24 = vld1q_f16(b);                                \
   v25 = vld1q_f16(b + 8);                            \
@@ -783,10 +734,13 @@ void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
 
       float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
       float16x8_t va0, va1, va2, va3;
+
+      INIT_KERNEL_8X16();
       l = 0;
       for (; l < K;) {
-        KERNEL_8x16_ACC4();
-        vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0_7));
+        KERNEL_8x16_ACC1();
+      }
+      vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0_7));
         vst1q_f16(c + 8, vaddq_f16(vld1q_f16(c + 8), v64_71));
         vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v8_15));
         vst1q_f16(c + ldc + 8, vaddq_f16(vld1q_f16(c + ldc + 8), v72_79));
@@ -808,7 +762,6 @@ void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
         vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v56_63));
         vst1q_f16(c + 7 * ldc + 8,
                   vaddq_f16(vld1q_f16(c + 7 * ldc + 8), v120_127));
-      }
       c += 16;
       a -= 8 * K;
     }
@@ -857,18 +810,22 @@ void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
       float16x8_t va0, va1, va2, va3, va4, va5, va6, va7;
       l = 0;
       for (; l < K16;) {
+        INIT_KERNEL_8X16();
         KERNEL_8x16_ACC16();
         SAVE_KERNEL_8X16_F16_F32();
       }
       for (; l < K8;) {
+        INIT_KERNEL_8X16();
         KERNEL_8x16_ACC8();
         SAVE_KERNEL_8X16_F16_F32();
       }
       for (; l < K4;) {
+        INIT_KERNEL_8X16();
         KERNEL_8x16_ACC4();
         SAVE_KERNEL_8X16_F16_F32();
       }
       for (; l < K;) {
+        INIT_KERNEL_8X16();
         KERNEL_8x16_ACC1();
         SAVE_KERNEL_8X16_F16_F32();
       }
index 90d640d79258a09df4946a6fb665e78d17b35240..c913bdd040c93668431b8e385dcd275e21365ff6 100644 (file)
 #include <hgemm_common.h>
 #include <stdlib.h>
 
-/// @note Following KERNELs are the combinations of accuracy-latency
-/// tradeoff. User can select which kernel to use by replacing them.
+#define INIT_KERNEL_8x8() \
+  v24 = vdupq_n_f16(0.F); \
+  v25 = vdupq_n_f16(0.F); \
+  v26 = vdupq_n_f16(0.F); \
+  v27 = vdupq_n_f16(0.F); \
+  v28 = vdupq_n_f16(0.F); \
+  v29 = vdupq_n_f16(0.F); \
+  v30 = vdupq_n_f16(0.F); \
+  v31 = vdupq_n_f16(0.F);
 
 // 1. Partial sum 1024 digits : Worst accuracy, best latency
 #define KERNEL_8x8_ACC16()                 \
-  v24 = vdupq_n_f16(0.F);                  \
-  v25 = vdupq_n_f16(0.F);                  \
-  v26 = vdupq_n_f16(0.F);                  \
-  v27 = vdupq_n_f16(0.F);                  \
-  v28 = vdupq_n_f16(0.F);                  \
-  v29 = vdupq_n_f16(0.F);                  \
-  v30 = vdupq_n_f16(0.F);                  \
-  v31 = vdupq_n_f16(0.F);                  \
   va0 = vld1q_f16(a);                      \
   v16 = vld1q_f16(b);                      \
   v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
 
 // 2. Partial sum 512 digits : Medium accuracy, medium latency
 #define KERNEL_8x8_ACC8()                  \
-  v24 = vdupq_n_f16(0.F);                  \
-  v25 = vdupq_n_f16(0.F);                  \
-  v26 = vdupq_n_f16(0.F);                  \
-  v27 = vdupq_n_f16(0.F);                  \
-  v28 = vdupq_n_f16(0.F);                  \
-  v29 = vdupq_n_f16(0.F);                  \
-  v30 = vdupq_n_f16(0.F);                  \
-  v31 = vdupq_n_f16(0.F);                  \
   va0 = vld1q_f16(a);                      \
   v16 = vld1q_f16(b);                      \
   v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
 
 // 3. Partial sum 256 digits : Medium accuracy, medium latency
 #define KERNEL_8x8_ACC4()                  \
-  v24 = vdupq_n_f16(0.F);                  \
-  v25 = vdupq_n_f16(0.F);                  \
-  v26 = vdupq_n_f16(0.F);                  \
-  v27 = vdupq_n_f16(0.F);                  \
-  v28 = vdupq_n_f16(0.F);                  \
-  v29 = vdupq_n_f16(0.F);                  \
-  v30 = vdupq_n_f16(0.F);                  \
-  v31 = vdupq_n_f16(0.F);                  \
   va0 = vld1q_f16(a);                      \
   v16 = vld1q_f16(b);                      \
   v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
 
 // 4. Partial sum 64 digits : Best accuracy, worst latency
 #define KERNEL_8x8_ACC1()                  \
-  v24 = vdupq_n_f16(0.F);                  \
-  v25 = vdupq_n_f16(0.F);                  \
-  v26 = vdupq_n_f16(0.F);                  \
-  v27 = vdupq_n_f16(0.F);                  \
-  v28 = vdupq_n_f16(0.F);                  \
-  v29 = vdupq_n_f16(0.F);                  \
-  v30 = vdupq_n_f16(0.F);                  \
-  v31 = vdupq_n_f16(0.F);                  \
   va0 = vld1q_f16(a);                      \
   v16 = vld1q_f16(b);                      \
   v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
@@ -437,19 +412,19 @@ void hgemm_kernel_8x8(unsigned int M, unsigned int N, unsigned int K,
       float16x8_t v16, v17, v18, v19, v20, v21, v22, v23;
       float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
       float16x8_t va0, va1, va2, va3, va4, va5, va6, va7;
+      INIT_KERNEL_8x8();
       l = 0;
       for (; l < K;) {
-        KERNEL_8x8_ACC8();
-
-        vst1q_f16(c, vaddq_f16(vld1q_f16(c), v24));
-        vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v25));
-        vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v26));
-        vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v27));
-        vst1q_f16(c + 4 * ldc, vaddq_f16(vld1q_f16(c + 4 * ldc), v28));
-        vst1q_f16(c + 5 * ldc, vaddq_f16(vld1q_f16(c + 5 * ldc), v29));
-        vst1q_f16(c + 6 * ldc, vaddq_f16(vld1q_f16(c + 6 * ldc), v30));
-        vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v31));
+        KERNEL_8x8_ACC1();
       }
+      vst1q_f16(c, vaddq_f16(vld1q_f16(c), v24));
+      vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v25));
+      vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v26));
+      vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v27));
+      vst1q_f16(c + 4 * ldc, vaddq_f16(vld1q_f16(c + 4 * ldc), v28));
+      vst1q_f16(c + 5 * ldc, vaddq_f16(vld1q_f16(c + 5 * ldc), v29));
+      vst1q_f16(c + 6 * ldc, vaddq_f16(vld1q_f16(c + 6 * ldc), v30));
+      vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v31));
       c += 8;
       a -= 8 * K;
     }
@@ -492,18 +467,22 @@ void hgemm_kernel_8x8(unsigned int M, unsigned int N, unsigned int K,
       float16x8_t va0, va1, va2, va3, va4, va5, va6, va7;
       l = 0;
       for (; l < K16;) {
+        INIT_KERNEL_8x8();
         KERNEL_8x8_ACC16();
         SAVE_KERNEL_8X8_F16_f32();
       }
       for (; l < K8;) {
+        INIT_KERNEL_8x8();
         KERNEL_8x8_ACC8();
         SAVE_KERNEL_8X8_F16_f32();
       }
       for (; l < K4;) {
+        INIT_KERNEL_8x8();
         KERNEL_8x8_ACC4();
         SAVE_KERNEL_8X8_F16_f32();
       }
       for (; l < K;) {
+        INIT_KERNEL_8x8();
         KERNEL_8x8_ACC1();
         SAVE_KERNEL_8X8_F16_f32();
       }