[ hgemm ] Apply ACC16 partial sum strategy & adaptive macro use in 8x16 kernel
author	skykongkong8 <ss.kong@samsung.com>
Mon, 15 Apr 2024 01:19:24 +0000 (10:19 +0900)
committer	Jijoong Moon <jijoong.moon@samsung.com>
Wed, 22 May 2024 23:13:42 +0000 (08:13 +0900)
- With more digits accumulated in fp16 (in this case, 1024 -> 2048 per partial sum), I observed a latency improvement at the cost of some accuracy loss. The result is still acceptable under the current accuracy measurement criteria, but it should be verified against model output once more.
- With a variety of partial-sum kernels, we can adaptively apply the internal macro kernels without being constrained to K being divisible by 4, 8, or 16 (see the sketch below).
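
For clarity, here is a minimal scalar sketch of the idea (a hypothetical illustration, not the committed NEON code): each chunk of the K dimension is summed in fp16 and only the chunk result is flushed into the fp32 output, and leftover ranks fall through from the 16-wide chunk to the 8-, 4-, and 1-wide ones, so K needs no particular divisibility.

```c
/*
 * Hypothetical scalar model of the ACC16/8/4/1 partial-sum strategy
 * (illustration only; the real kernel uses NEON fp16 vectors).
 * Each chunk is accumulated in fp16 and then flushed into an fp32
 * accumulator, mirroring KERNEL_8x16_ACCx + SAVE_KERNEL_8X16_F16_F32.
 * Wider chunks mean fewer fp32 flushes (lower latency) but longer fp16
 * accumulation chains (more rounding error).
 * Assumes an AArch64 toolchain where __fp16 is available.
 */
#include <stdio.h>

static float dot_fp16_partial_sums(const __fp16 *a, const __fp16 *b,
                                   unsigned int K) {
  const unsigned int K16 = (K >> 4) << 4; /* largest multiple of 16 <= K */
  const unsigned int K8 = (K >> 3) << 3;  /* largest multiple of 8  <= K */
  const unsigned int K4 = (K >> 2) << 2;  /* largest multiple of 4  <= K */
  float c = 0.f;    /* fp32 destination, like the C tile */
  unsigned int l = 0;
  while (l < K16) { /* "ACC16": 16 ranks per fp16 partial sum */
    __fp16 acc = 0;
    for (int i = 0; i < 16; ++i, ++l)
      acc += a[l] * b[l];
    c += (float)acc; /* flush partial sum to fp32 */
  }
  while (l < K8) {  /* "ACC8" handles a leftover block of 8 */
    __fp16 acc = 0;
    for (int i = 0; i < 8; ++i, ++l)
      acc += a[l] * b[l];
    c += (float)acc;
  }
  while (l < K4) {  /* "ACC4" */
    __fp16 acc = 0;
    for (int i = 0; i < 4; ++i, ++l)
      acc += a[l] * b[l];
    c += (float)acc;
  }
  while (l < K) {   /* "ACC1": remaining ranks one by one */
    c += (float)(__fp16)(a[l] * b[l]);
    ++l;
  }
  return c;
}

int main(void) {
  __fp16 a[37], b[37];
  for (int i = 0; i < 37; ++i) {
    a[i] = (__fp16)(0.01f * i);
    b[i] = (__fp16)1.f;
  }
  /* K = 37 is not divisible by 4, 8, or 16; the tails are still covered. */
  printf("dot = %f\n", dot_fp16_partial_sums(a, b, 37));
  return 0;
}
```

In the actual kernel, the KERNEL_8x16_ACC16/8/4/1 macros play the role of the fp16 chunk loops and SAVE_KERNEL_8X16_F16_F32 plays the role of the fp32 flush.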

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
nntrainer/tensor/hgemm/hgemm_kernel_8x16.h

index 7cac545809c2eddae65606b2c3515b92a684e797..188f106655dd5aeda1d9a34fc9aab4db18068c7d 100644 (file)
 /// @note Following KERNELs are the combinations of accuracy-latency
 /// tradeoff. User can select which kernel to use by replacing them.
 
-// 1. Partial sum 1024 digits : Worst accuracy, best latency
+// 0. Partial sum 2048 digits : Best latency, worst accuracy.
+#define KERNEL_8x16_ACC16()                          \
+  v0_7 = vdupq_n_f16(0.F);                           \
+  v8_15 = vdupq_n_f16(0.F);                          \
+  v16_23 = vdupq_n_f16(0.F);                         \
+  v24_31 = vdupq_n_f16(0.F);                         \
+  v32_39 = vdupq_n_f16(0.F);                         \
+  v40_47 = vdupq_n_f16(0.F);                         \
+  v48_55 = vdupq_n_f16(0.F);                         \
+  v56_63 = vdupq_n_f16(0.F);                         \
+  v64_71 = vdupq_n_f16(0.F);                         \
+  v72_79 = vdupq_n_f16(0.F);                         \
+  v80_87 = vdupq_n_f16(0.F);                         \
+  v88_95 = vdupq_n_f16(0.F);                         \
+  v96_103 = vdupq_n_f16(0.F);                        \
+  v104_111 = vdupq_n_f16(0.F);                       \
+  v112_119 = vdupq_n_f16(0.F);                       \
+  v120_127 = vdupq_n_f16(0.F);                       \
+  va0 = vld1q_f16(a);                                \
+  v24 = vld1q_f16(b);                                \
+  v25 = vld1q_f16(b + 8);                            \
+  v0_7 = vfmaq_laneq_f16(v0_7, v24, va0, 0);         \
+  v8_15 = vfmaq_laneq_f16(v8_15, v24, va0, 1);       \
+  v16_23 = vfmaq_laneq_f16(v16_23, v24, va0, 2);     \
+  v24_31 = vfmaq_laneq_f16(v24_31, v24, va0, 3);     \
+  v32_39 = vfmaq_laneq_f16(v32_39, v24, va0, 4);     \
+  v40_47 = vfmaq_laneq_f16(v40_47, v24, va0, 5);     \
+  v48_55 = vfmaq_laneq_f16(v48_55, v24, va0, 6);     \
+  v56_63 = vfmaq_laneq_f16(v56_63, v24, va0, 7);     \
+  v64_71 = vfmaq_laneq_f16(v64_71, v25, va0, 0);     \
+  v72_79 = vfmaq_laneq_f16(v72_79, v25, va0, 1);     \
+  v80_87 = vfmaq_laneq_f16(v80_87, v25, va0, 2);     \
+  v88_95 = vfmaq_laneq_f16(v88_95, v25, va0, 3);     \
+  v96_103 = vfmaq_laneq_f16(v96_103, v25, va0, 4);   \
+  v104_111 = vfmaq_laneq_f16(v104_111, v25, va0, 5); \
+  v112_119 = vfmaq_laneq_f16(v112_119, v25, va0, 6); \
+  v120_127 = vfmaq_laneq_f16(v120_127, v25, va0, 7); \
+  va1 = vld1q_f16(a + 8);                            \
+  v26 = vld1q_f16(b + 8 * 2);                        \
+  v27 = vld1q_f16(b + 8 * 3);                        \
+  v0_7 = vfmaq_laneq_f16(v0_7, v26, va1, 0);         \
+  v8_15 = vfmaq_laneq_f16(v8_15, v26, va1, 1);       \
+  v16_23 = vfmaq_laneq_f16(v16_23, v26, va1, 2);     \
+  v24_31 = vfmaq_laneq_f16(v24_31, v26, va1, 3);     \
+  v32_39 = vfmaq_laneq_f16(v32_39, v26, va1, 4);     \
+  v40_47 = vfmaq_laneq_f16(v40_47, v26, va1, 5);     \
+  v48_55 = vfmaq_laneq_f16(v48_55, v26, va1, 6);     \
+  v56_63 = vfmaq_laneq_f16(v56_63, v26, va1, 7);     \
+  v64_71 = vfmaq_laneq_f16(v64_71, v27, va1, 0);     \
+  v72_79 = vfmaq_laneq_f16(v72_79, v27, va1, 1);     \
+  v80_87 = vfmaq_laneq_f16(v80_87, v27, va1, 2);     \
+  v88_95 = vfmaq_laneq_f16(v88_95, v27, va1, 3);     \
+  v96_103 = vfmaq_laneq_f16(v96_103, v27, va1, 4);   \
+  v104_111 = vfmaq_laneq_f16(v104_111, v27, va1, 5); \
+  v112_119 = vfmaq_laneq_f16(v112_119, v27, va1, 6); \
+  v120_127 = vfmaq_laneq_f16(v120_127, v27, va1, 7); \
+  va2 = vld1q_f16(a + 8 * 2);                        \
+  v28 = vld1q_f16(b + 8 * 4);                        \
+  v29 = vld1q_f16(b + 8 * 5);                        \
+  v0_7 = vfmaq_laneq_f16(v0_7, v28, va2, 0);         \
+  v8_15 = vfmaq_laneq_f16(v8_15, v28, va2, 1);       \
+  v16_23 = vfmaq_laneq_f16(v16_23, v28, va2, 2);     \
+  v24_31 = vfmaq_laneq_f16(v24_31, v28, va2, 3);     \
+  v32_39 = vfmaq_laneq_f16(v32_39, v28, va2, 4);     \
+  v40_47 = vfmaq_laneq_f16(v40_47, v28, va2, 5);     \
+  v48_55 = vfmaq_laneq_f16(v48_55, v28, va2, 6);     \
+  v56_63 = vfmaq_laneq_f16(v56_63, v28, va2, 7);     \
+  v64_71 = vfmaq_laneq_f16(v64_71, v29, va2, 0);     \
+  v72_79 = vfmaq_laneq_f16(v72_79, v29, va2, 1);     \
+  v80_87 = vfmaq_laneq_f16(v80_87, v29, va2, 2);     \
+  v88_95 = vfmaq_laneq_f16(v88_95, v29, va2, 3);     \
+  v96_103 = vfmaq_laneq_f16(v96_103, v29, va2, 4);   \
+  v104_111 = vfmaq_laneq_f16(v104_111, v29, va2, 5); \
+  v112_119 = vfmaq_laneq_f16(v112_119, v29, va2, 6); \
+  v120_127 = vfmaq_laneq_f16(v120_127, v29, va2, 7); \
+  va3 = vld1q_f16(a + 8 * 3);                        \
+  v30 = vld1q_f16(b + 8 * 6);                        \
+  v31 = vld1q_f16(b + 8 * 7);                        \
+  v0_7 = vfmaq_laneq_f16(v0_7, v30, va3, 0);         \
+  v8_15 = vfmaq_laneq_f16(v8_15, v30, va3, 1);       \
+  v16_23 = vfmaq_laneq_f16(v16_23, v30, va3, 2);     \
+  v24_31 = vfmaq_laneq_f16(v24_31, v30, va3, 3);     \
+  v32_39 = vfmaq_laneq_f16(v32_39, v30, va3, 4);     \
+  v40_47 = vfmaq_laneq_f16(v40_47, v30, va3, 5);     \
+  v48_55 = vfmaq_laneq_f16(v48_55, v30, va3, 6);     \
+  v56_63 = vfmaq_laneq_f16(v56_63, v30, va3, 7);     \
+  v64_71 = vfmaq_laneq_f16(v64_71, v31, va3, 0);     \
+  v72_79 = vfmaq_laneq_f16(v72_79, v31, va3, 1);     \
+  v80_87 = vfmaq_laneq_f16(v80_87, v31, va3, 2);     \
+  v88_95 = vfmaq_laneq_f16(v88_95, v31, va3, 3);     \
+  v96_103 = vfmaq_laneq_f16(v96_103, v31, va3, 4);   \
+  v104_111 = vfmaq_laneq_f16(v104_111, v31, va3, 5); \
+  v112_119 = vfmaq_laneq_f16(v112_119, v31, va3, 6); \
+  v120_127 = vfmaq_laneq_f16(v120_127, v31, va3, 7); \
+  va4 = vld1q_f16(a + 8 * 4);                        \
+  v24 = vld1q_f16(b + 8 * 8);                        \
+  v25 = vld1q_f16(b + 8 * 9);                        \
+  v0_7 = vfmaq_laneq_f16(v0_7, v24, va4, 0);         \
+  v8_15 = vfmaq_laneq_f16(v8_15, v24, va4, 1);       \
+  v16_23 = vfmaq_laneq_f16(v16_23, v24, va4, 2);     \
+  v24_31 = vfmaq_laneq_f16(v24_31, v24, va4, 3);     \
+  v32_39 = vfmaq_laneq_f16(v32_39, v24, va4, 4);     \
+  v40_47 = vfmaq_laneq_f16(v40_47, v24, va4, 5);     \
+  v48_55 = vfmaq_laneq_f16(v48_55, v24, va4, 6);     \
+  v56_63 = vfmaq_laneq_f16(v56_63, v24, va4, 7);     \
+  v64_71 = vfmaq_laneq_f16(v64_71, v25, va4, 0);     \
+  v72_79 = vfmaq_laneq_f16(v72_79, v25, va4, 1);     \
+  v80_87 = vfmaq_laneq_f16(v80_87, v25, va4, 2);     \
+  v88_95 = vfmaq_laneq_f16(v88_95, v25, va4, 3);     \
+  v96_103 = vfmaq_laneq_f16(v96_103, v25, va4, 4);   \
+  v104_111 = vfmaq_laneq_f16(v104_111, v25, va4, 5); \
+  v112_119 = vfmaq_laneq_f16(v112_119, v25, va4, 6); \
+  v120_127 = vfmaq_laneq_f16(v120_127, v25, va4, 7); \
+  va5 = vld1q_f16(a + 8 * 5);                        \
+  v26 = vld1q_f16(b + 8 * 10);                       \
+  v27 = vld1q_f16(b + 8 * 11);                       \
+  v0_7 = vfmaq_laneq_f16(v0_7, v26, va5, 0);         \
+  v8_15 = vfmaq_laneq_f16(v8_15, v26, va5, 1);       \
+  v16_23 = vfmaq_laneq_f16(v16_23, v26, va5, 2);     \
+  v24_31 = vfmaq_laneq_f16(v24_31, v26, va5, 3);     \
+  v32_39 = vfmaq_laneq_f16(v32_39, v26, va5, 4);     \
+  v40_47 = vfmaq_laneq_f16(v40_47, v26, va5, 5);     \
+  v48_55 = vfmaq_laneq_f16(v48_55, v26, va5, 6);     \
+  v56_63 = vfmaq_laneq_f16(v56_63, v26, va5, 7);     \
+  v64_71 = vfmaq_laneq_f16(v64_71, v27, va5, 0);     \
+  v72_79 = vfmaq_laneq_f16(v72_79, v27, va5, 1);     \
+  v80_87 = vfmaq_laneq_f16(v80_87, v27, va5, 2);     \
+  v88_95 = vfmaq_laneq_f16(v88_95, v27, va5, 3);     \
+  v96_103 = vfmaq_laneq_f16(v96_103, v27, va5, 4);   \
+  v104_111 = vfmaq_laneq_f16(v104_111, v27, va5, 5); \
+  v112_119 = vfmaq_laneq_f16(v112_119, v27, va5, 6); \
+  v120_127 = vfmaq_laneq_f16(v120_127, v27, va5, 7); \
+  va6 = vld1q_f16(a + 8 * 6);                        \
+  v28 = vld1q_f16(b + 8 * 12);                       \
+  v29 = vld1q_f16(b + 8 * 13);                       \
+  v0_7 = vfmaq_laneq_f16(v0_7, v28, va6, 0);         \
+  v8_15 = vfmaq_laneq_f16(v8_15, v28, va6, 1);       \
+  v16_23 = vfmaq_laneq_f16(v16_23, v28, va6, 2);     \
+  v24_31 = vfmaq_laneq_f16(v24_31, v28, va6, 3);     \
+  v32_39 = vfmaq_laneq_f16(v32_39, v28, va6, 4);     \
+  v40_47 = vfmaq_laneq_f16(v40_47, v28, va6, 5);     \
+  v48_55 = vfmaq_laneq_f16(v48_55, v28, va6, 6);     \
+  v56_63 = vfmaq_laneq_f16(v56_63, v28, va6, 7);     \
+  v64_71 = vfmaq_laneq_f16(v64_71, v29, va6, 0);     \
+  v72_79 = vfmaq_laneq_f16(v72_79, v29, va6, 1);     \
+  v80_87 = vfmaq_laneq_f16(v80_87, v29, va6, 2);     \
+  v88_95 = vfmaq_laneq_f16(v88_95, v29, va6, 3);     \
+  v96_103 = vfmaq_laneq_f16(v96_103, v29, va6, 4);   \
+  v104_111 = vfmaq_laneq_f16(v104_111, v29, va6, 5); \
+  v112_119 = vfmaq_laneq_f16(v112_119, v29, va6, 6); \
+  v120_127 = vfmaq_laneq_f16(v120_127, v29, va6, 7); \
+  va7 = vld1q_f16(a + 8 * 7);                        \
+  v30 = vld1q_f16(b + 8 * 14);                       \
+  v31 = vld1q_f16(b + 8 * 15);                       \
+  v0_7 = vfmaq_laneq_f16(v0_7, v30, va7, 0);         \
+  v8_15 = vfmaq_laneq_f16(v8_15, v30, va7, 1);       \
+  v16_23 = vfmaq_laneq_f16(v16_23, v30, va7, 2);     \
+  v24_31 = vfmaq_laneq_f16(v24_31, v30, va7, 3);     \
+  v32_39 = vfmaq_laneq_f16(v32_39, v30, va7, 4);     \
+  v40_47 = vfmaq_laneq_f16(v40_47, v30, va7, 5);     \
+  v48_55 = vfmaq_laneq_f16(v48_55, v30, va7, 6);     \
+  v56_63 = vfmaq_laneq_f16(v56_63, v30, va7, 7);     \
+  v64_71 = vfmaq_laneq_f16(v64_71, v31, va7, 0);     \
+  v72_79 = vfmaq_laneq_f16(v72_79, v31, va7, 1);     \
+  v80_87 = vfmaq_laneq_f16(v80_87, v31, va7, 2);     \
+  v88_95 = vfmaq_laneq_f16(v88_95, v31, va7, 3);     \
+  v96_103 = vfmaq_laneq_f16(v96_103, v31, va7, 4);   \
+  v104_111 = vfmaq_laneq_f16(v104_111, v31, va7, 5); \
+  v112_119 = vfmaq_laneq_f16(v112_119, v31, va7, 6); \
+  v120_127 = vfmaq_laneq_f16(v120_127, v31, va7, 7); \
+  va7 = vld1q_f16(a + 8 * 8);                        \
+  v30 = vld1q_f16(b + 8 * 16);                       \
+  v31 = vld1q_f16(b + 8 * 17);                       \
+  v0_7 = vfmaq_laneq_f16(v0_7, v30, va7, 0);         \
+  v8_15 = vfmaq_laneq_f16(v8_15, v30, va7, 1);       \
+  v16_23 = vfmaq_laneq_f16(v16_23, v30, va7, 2);     \
+  v24_31 = vfmaq_laneq_f16(v24_31, v30, va7, 3);     \
+  v32_39 = vfmaq_laneq_f16(v32_39, v30, va7, 4);     \
+  v40_47 = vfmaq_laneq_f16(v40_47, v30, va7, 5);     \
+  v48_55 = vfmaq_laneq_f16(v48_55, v30, va7, 6);     \
+  v56_63 = vfmaq_laneq_f16(v56_63, v30, va7, 7);     \
+  v64_71 = vfmaq_laneq_f16(v64_71, v31, va7, 0);     \
+  v72_79 = vfmaq_laneq_f16(v72_79, v31, va7, 1);     \
+  v80_87 = vfmaq_laneq_f16(v80_87, v31, va7, 2);     \
+  v88_95 = vfmaq_laneq_f16(v88_95, v31, va7, 3);     \
+  v96_103 = vfmaq_laneq_f16(v96_103, v31, va7, 4);   \
+  v104_111 = vfmaq_laneq_f16(v104_111, v31, va7, 5); \
+  v112_119 = vfmaq_laneq_f16(v112_119, v31, va7, 6); \
+  v120_127 = vfmaq_laneq_f16(v120_127, v31, va7, 7); \
+  va7 = vld1q_f16(a + 8 * 9);                        \
+  v30 = vld1q_f16(b + 8 * 18);                       \
+  v31 = vld1q_f16(b + 8 * 19);                       \
+  v0_7 = vfmaq_laneq_f16(v0_7, v30, va7, 0);         \
+  v8_15 = vfmaq_laneq_f16(v8_15, v30, va7, 1);       \
+  v16_23 = vfmaq_laneq_f16(v16_23, v30, va7, 2);     \
+  v24_31 = vfmaq_laneq_f16(v24_31, v30, va7, 3);     \
+  v32_39 = vfmaq_laneq_f16(v32_39, v30, va7, 4);     \
+  v40_47 = vfmaq_laneq_f16(v40_47, v30, va7, 5);     \
+  v48_55 = vfmaq_laneq_f16(v48_55, v30, va7, 6);     \
+  v56_63 = vfmaq_laneq_f16(v56_63, v30, va7, 7);     \
+  v64_71 = vfmaq_laneq_f16(v64_71, v31, va7, 0);     \
+  v72_79 = vfmaq_laneq_f16(v72_79, v31, va7, 1);     \
+  v80_87 = vfmaq_laneq_f16(v80_87, v31, va7, 2);     \
+  v88_95 = vfmaq_laneq_f16(v88_95, v31, va7, 3);     \
+  v96_103 = vfmaq_laneq_f16(v96_103, v31, va7, 4);   \
+  v104_111 = vfmaq_laneq_f16(v104_111, v31, va7, 5); \
+  v112_119 = vfmaq_laneq_f16(v112_119, v31, va7, 6); \
+  v120_127 = vfmaq_laneq_f16(v120_127, v31, va7, 7); \
+  va7 = vld1q_f16(a + 8 * 10);                       \
+  v30 = vld1q_f16(b + 8 * 20);                       \
+  v31 = vld1q_f16(b + 8 * 21);                       \
+  v0_7 = vfmaq_laneq_f16(v0_7, v30, va7, 0);         \
+  v8_15 = vfmaq_laneq_f16(v8_15, v30, va7, 1);       \
+  v16_23 = vfmaq_laneq_f16(v16_23, v30, va7, 2);     \
+  v24_31 = vfmaq_laneq_f16(v24_31, v30, va7, 3);     \
+  v32_39 = vfmaq_laneq_f16(v32_39, v30, va7, 4);     \
+  v40_47 = vfmaq_laneq_f16(v40_47, v30, va7, 5);     \
+  v48_55 = vfmaq_laneq_f16(v48_55, v30, va7, 6);     \
+  v56_63 = vfmaq_laneq_f16(v56_63, v30, va7, 7);     \
+  v64_71 = vfmaq_laneq_f16(v64_71, v31, va7, 0);     \
+  v72_79 = vfmaq_laneq_f16(v72_79, v31, va7, 1);     \
+  v80_87 = vfmaq_laneq_f16(v80_87, v31, va7, 2);     \
+  v88_95 = vfmaq_laneq_f16(v88_95, v31, va7, 3);     \
+  v96_103 = vfmaq_laneq_f16(v96_103, v31, va7, 4);   \
+  v104_111 = vfmaq_laneq_f16(v104_111, v31, va7, 5); \
+  v112_119 = vfmaq_laneq_f16(v112_119, v31, va7, 6); \
+  v120_127 = vfmaq_laneq_f16(v120_127, v31, va7, 7); \
+  va7 = vld1q_f16(a + 8 * 11);                       \
+  v30 = vld1q_f16(b + 8 * 22);                       \
+  v31 = vld1q_f16(b + 8 * 23);                       \
+  v0_7 = vfmaq_laneq_f16(v0_7, v30, va7, 0);         \
+  v8_15 = vfmaq_laneq_f16(v8_15, v30, va7, 1);       \
+  v16_23 = vfmaq_laneq_f16(v16_23, v30, va7, 2);     \
+  v24_31 = vfmaq_laneq_f16(v24_31, v30, va7, 3);     \
+  v32_39 = vfmaq_laneq_f16(v32_39, v30, va7, 4);     \
+  v40_47 = vfmaq_laneq_f16(v40_47, v30, va7, 5);     \
+  v48_55 = vfmaq_laneq_f16(v48_55, v30, va7, 6);     \
+  v56_63 = vfmaq_laneq_f16(v56_63, v30, va7, 7);     \
+  v64_71 = vfmaq_laneq_f16(v64_71, v31, va7, 0);     \
+  v72_79 = vfmaq_laneq_f16(v72_79, v31, va7, 1);     \
+  v80_87 = vfmaq_laneq_f16(v80_87, v31, va7, 2);     \
+  v88_95 = vfmaq_laneq_f16(v88_95, v31, va7, 3);     \
+  v96_103 = vfmaq_laneq_f16(v96_103, v31, va7, 4);   \
+  v104_111 = vfmaq_laneq_f16(v104_111, v31, va7, 5); \
+  v112_119 = vfmaq_laneq_f16(v112_119, v31, va7, 6); \
+  v120_127 = vfmaq_laneq_f16(v120_127, v31, va7, 7); \
+  va7 = vld1q_f16(a + 8 * 12);                       \
+  v30 = vld1q_f16(b + 8 * 24);                       \
+  v31 = vld1q_f16(b + 8 * 25);                       \
+  v0_7 = vfmaq_laneq_f16(v0_7, v30, va7, 0);         \
+  v8_15 = vfmaq_laneq_f16(v8_15, v30, va7, 1);       \
+  v16_23 = vfmaq_laneq_f16(v16_23, v30, va7, 2);     \
+  v24_31 = vfmaq_laneq_f16(v24_31, v30, va7, 3);     \
+  v32_39 = vfmaq_laneq_f16(v32_39, v30, va7, 4);     \
+  v40_47 = vfmaq_laneq_f16(v40_47, v30, va7, 5);     \
+  v48_55 = vfmaq_laneq_f16(v48_55, v30, va7, 6);     \
+  v56_63 = vfmaq_laneq_f16(v56_63, v30, va7, 7);     \
+  v64_71 = vfmaq_laneq_f16(v64_71, v31, va7, 0);     \
+  v72_79 = vfmaq_laneq_f16(v72_79, v31, va7, 1);     \
+  v80_87 = vfmaq_laneq_f16(v80_87, v31, va7, 2);     \
+  v88_95 = vfmaq_laneq_f16(v88_95, v31, va7, 3);     \
+  v96_103 = vfmaq_laneq_f16(v96_103, v31, va7, 4);   \
+  v104_111 = vfmaq_laneq_f16(v104_111, v31, va7, 5); \
+  v112_119 = vfmaq_laneq_f16(v112_119, v31, va7, 6); \
+  v120_127 = vfmaq_laneq_f16(v120_127, v31, va7, 7); \
+  va7 = vld1q_f16(a + 8 * 13);                       \
+  v30 = vld1q_f16(b + 8 * 26);                       \
+  v31 = vld1q_f16(b + 8 * 27);                       \
+  v0_7 = vfmaq_laneq_f16(v0_7, v30, va7, 0);         \
+  v8_15 = vfmaq_laneq_f16(v8_15, v30, va7, 1);       \
+  v16_23 = vfmaq_laneq_f16(v16_23, v30, va7, 2);     \
+  v24_31 = vfmaq_laneq_f16(v24_31, v30, va7, 3);     \
+  v32_39 = vfmaq_laneq_f16(v32_39, v30, va7, 4);     \
+  v40_47 = vfmaq_laneq_f16(v40_47, v30, va7, 5);     \
+  v48_55 = vfmaq_laneq_f16(v48_55, v30, va7, 6);     \
+  v56_63 = vfmaq_laneq_f16(v56_63, v30, va7, 7);     \
+  v64_71 = vfmaq_laneq_f16(v64_71, v31, va7, 0);     \
+  v72_79 = vfmaq_laneq_f16(v72_79, v31, va7, 1);     \
+  v80_87 = vfmaq_laneq_f16(v80_87, v31, va7, 2);     \
+  v88_95 = vfmaq_laneq_f16(v88_95, v31, va7, 3);     \
+  v96_103 = vfmaq_laneq_f16(v96_103, v31, va7, 4);   \
+  v104_111 = vfmaq_laneq_f16(v104_111, v31, va7, 5); \
+  v112_119 = vfmaq_laneq_f16(v112_119, v31, va7, 6); \
+  v120_127 = vfmaq_laneq_f16(v120_127, v31, va7, 7); \
+  va7 = vld1q_f16(a + 8 * 14);                       \
+  v30 = vld1q_f16(b + 8 * 28);                       \
+  v31 = vld1q_f16(b + 8 * 29);                       \
+  v0_7 = vfmaq_laneq_f16(v0_7, v30, va7, 0);         \
+  v8_15 = vfmaq_laneq_f16(v8_15, v30, va7, 1);       \
+  v16_23 = vfmaq_laneq_f16(v16_23, v30, va7, 2);     \
+  v24_31 = vfmaq_laneq_f16(v24_31, v30, va7, 3);     \
+  v32_39 = vfmaq_laneq_f16(v32_39, v30, va7, 4);     \
+  v40_47 = vfmaq_laneq_f16(v40_47, v30, va7, 5);     \
+  v48_55 = vfmaq_laneq_f16(v48_55, v30, va7, 6);     \
+  v56_63 = vfmaq_laneq_f16(v56_63, v30, va7, 7);     \
+  v64_71 = vfmaq_laneq_f16(v64_71, v31, va7, 0);     \
+  v72_79 = vfmaq_laneq_f16(v72_79, v31, va7, 1);     \
+  v80_87 = vfmaq_laneq_f16(v80_87, v31, va7, 2);     \
+  v88_95 = vfmaq_laneq_f16(v88_95, v31, va7, 3);     \
+  v96_103 = vfmaq_laneq_f16(v96_103, v31, va7, 4);   \
+  v104_111 = vfmaq_laneq_f16(v104_111, v31, va7, 5); \
+  v112_119 = vfmaq_laneq_f16(v112_119, v31, va7, 6); \
+  v120_127 = vfmaq_laneq_f16(v120_127, v31, va7, 7); \
+  va7 = vld1q_f16(a + 8 * 15);                       \
+  v30 = vld1q_f16(b + 8 * 30);                       \
+  v31 = vld1q_f16(b + 8 * 31);                       \
+  v0_7 = vfmaq_laneq_f16(v0_7, v30, va7, 0);         \
+  v8_15 = vfmaq_laneq_f16(v8_15, v30, va7, 1);       \
+  v16_23 = vfmaq_laneq_f16(v16_23, v30, va7, 2);     \
+  v24_31 = vfmaq_laneq_f16(v24_31, v30, va7, 3);     \
+  v32_39 = vfmaq_laneq_f16(v32_39, v30, va7, 4);     \
+  v40_47 = vfmaq_laneq_f16(v40_47, v30, va7, 5);     \
+  v48_55 = vfmaq_laneq_f16(v48_55, v30, va7, 6);     \
+  v56_63 = vfmaq_laneq_f16(v56_63, v30, va7, 7);     \
+  v64_71 = vfmaq_laneq_f16(v64_71, v31, va7, 0);     \
+  v72_79 = vfmaq_laneq_f16(v72_79, v31, va7, 1);     \
+  v80_87 = vfmaq_laneq_f16(v80_87, v31, va7, 2);     \
+  v88_95 = vfmaq_laneq_f16(v88_95, v31, va7, 3);     \
+  v96_103 = vfmaq_laneq_f16(v96_103, v31, va7, 4);   \
+  v104_111 = vfmaq_laneq_f16(v104_111, v31, va7, 5); \
+  v112_119 = vfmaq_laneq_f16(v112_119, v31, va7, 6); \
+  v120_127 = vfmaq_laneq_f16(v120_127, v31, va7, 7); \
+  l += 16;                                           \
+  __builtin_prefetch(b + 256, 0, 3);                 \
+  __builtin_prefetch(a + 128, 0, 3);                 \
+  b += 16 * 16;                                      \
+  a += 8 * 16;
+
+// 1. Partial sum 1024 digits : Medium-high accuracy, medium latency
 #define KERNEL_8x16_ACC8()                           \
   v0_7 = vdupq_n_f16(0.F);                           \
   v8_15 = vdupq_n_f16(0.F);                          \
   b += 16 * 1;                                       \
   a += 8 * 1;
 
+#define SAVE_KERNEL_8X16_F16_F32()                                             \
+  vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0_7))));     \
+  vst1q_f32(c + 4,                                                             \
+            vaddq_f32(vld1q_f32(c + 4), vcvt_f32_f16(vget_high_f16(v0_7))));   \
+                                                                               \
+  vst1q_f32(c + 8,                                                             \
+            vaddq_f32(vld1q_f32(c + 8), vcvt_f32_f16(vget_low_f16(v64_71))));  \
+  vst1q_f32(c + 8 + 4, vaddq_f32(vld1q_f32(c + 8 + 4),                         \
+                                 vcvt_f32_f16(vget_high_f16(v64_71))));        \
+                                                                               \
+  vst1q_f32(c + ldc,                                                           \
+            vaddq_f32(vld1q_f32(c + ldc), vcvt_f32_f16(vget_low_f16(v8_15)))); \
+  vst1q_f32(c + ldc + 4, vaddq_f32(vld1q_f32(c + ldc + 4),                     \
+                                   vcvt_f32_f16(vget_high_f16(v8_15))));       \
+                                                                               \
+  vst1q_f32(c + ldc + 8, vaddq_f32(vld1q_f32(c + ldc + 8),                     \
+                                   vcvt_f32_f16(vget_low_f16(v72_79))));       \
+  vst1q_f32(c + ldc + 8 + 4, vaddq_f32(vld1q_f32(c + ldc + 8 + 4),             \
+                                       vcvt_f32_f16(vget_high_f16(v72_79))));  \
+                                                                               \
+  vst1q_f32(c + 2 * ldc, vaddq_f32(vld1q_f32(c + 2 * ldc),                     \
+                                   vcvt_f32_f16(vget_low_f16(v16_23))));       \
+  vst1q_f32(c + 2 * ldc + 4, vaddq_f32(vld1q_f32(c + 2 * ldc + 4),             \
+                                       vcvt_f32_f16(vget_high_f16(v16_23))));  \
+                                                                               \
+  vst1q_f32(c + 2 * ldc + 8, vaddq_f32(vld1q_f32(c + 2 * ldc + 8),             \
+                                       vcvt_f32_f16(vget_low_f16(v80_87))));   \
+  vst1q_f32(c + 2 * ldc + 8 + 4,                                               \
+            vaddq_f32(vld1q_f32(c + 2 * ldc + 8 + 4),                          \
+                      vcvt_f32_f16(vget_high_f16(v80_87))));                   \
+                                                                               \
+  vst1q_f32(c + 3 * ldc, vaddq_f32(vld1q_f32(c + 3 * ldc),                     \
+                                   vcvt_f32_f16(vget_low_f16(v24_31))));       \
+  vst1q_f32(c + 3 * ldc + 4, vaddq_f32(vld1q_f32(c + 3 * ldc + 4),             \
+                                       vcvt_f32_f16(vget_high_f16(v24_31))));  \
+                                                                               \
+  vst1q_f32(c + 3 * ldc + 8, vaddq_f32(vld1q_f32(c + 3 * ldc + 8),             \
+                                       vcvt_f32_f16(vget_low_f16(v88_95))));   \
+  vst1q_f32(c + 3 * ldc + 8 + 4,                                               \
+            vaddq_f32(vld1q_f32(c + 3 * ldc + 8 + 4),                          \
+                      vcvt_f32_f16(vget_high_f16(v88_95))));                   \
+                                                                               \
+  vst1q_f32(c + 4 * ldc, vaddq_f32(vld1q_f32(c + 4 * ldc),                     \
+                                   vcvt_f32_f16(vget_low_f16(v32_39))));       \
+  vst1q_f32(c + 4 * ldc + 4, vaddq_f32(vld1q_f32(c + 4 * ldc + 4),             \
+                                       vcvt_f32_f16(vget_high_f16(v32_39))));  \
+                                                                               \
+  vst1q_f32(c + 4 * ldc + 8, vaddq_f32(vld1q_f32(c + 4 * ldc + 8),             \
+                                       vcvt_f32_f16(vget_low_f16(v96_103))));  \
+  vst1q_f32(c + 4 * ldc + 8 + 4,                                               \
+            vaddq_f32(vld1q_f32(c + 4 * ldc + 8 + 4),                          \
+                      vcvt_f32_f16(vget_high_f16(v96_103))));                  \
+                                                                               \
+  vst1q_f32(c + 5 * ldc, vaddq_f32(vld1q_f32(c + 5 * ldc),                     \
+                                   vcvt_f32_f16(vget_low_f16(v40_47))));       \
+  vst1q_f32(c + 5 * ldc + 4, vaddq_f32(vld1q_f32(c + 5 * ldc + 4),             \
+                                       vcvt_f32_f16(vget_high_f16(v40_47))));  \
+  vst1q_f32(c + 5 * ldc + 8, vaddq_f32(vld1q_f32(c + 5 * ldc + 8),             \
+                                       vcvt_f32_f16(vget_low_f16(v104_111)))); \
+  vst1q_f32(c + 5 * ldc + 8 + 4,                                               \
+            vaddq_f32(vld1q_f32(c + 5 * ldc + 8 + 4),                          \
+                      vcvt_f32_f16(vget_high_f16(v104_111))));                 \
+                                                                               \
+  vst1q_f32(c + 6 * ldc, vaddq_f32(vld1q_f32(c + 6 * ldc),                     \
+                                   vcvt_f32_f16(vget_low_f16(v48_55))));       \
+  vst1q_f32(c + 6 * ldc + 4, vaddq_f32(vld1q_f32(c + 6 * ldc + 4),             \
+                                       vcvt_f32_f16(vget_high_f16(v48_55))));  \
+                                                                               \
+  vst1q_f32(c + 6 * ldc + 8, vaddq_f32(vld1q_f32(c + 6 * ldc + 8),             \
+                                       vcvt_f32_f16(vget_low_f16(v112_119)))); \
+  vst1q_f32(c + 6 * ldc + 8 + 4,                                               \
+            vaddq_f32(vld1q_f32(c + 6 * ldc + 8 + 4),                          \
+                      vcvt_f32_f16(vget_high_f16(v112_119))));                 \
+                                                                               \
+  vst1q_f32(c + 7 * ldc, vaddq_f32(vld1q_f32(c + 7 * ldc),                     \
+                                   vcvt_f32_f16(vget_low_f16(v56_63))));       \
+  vst1q_f32(c + 7 * ldc + 4, vaddq_f32(vld1q_f32(c + 7 * ldc + 4),             \
+                                       vcvt_f32_f16(vget_high_f16(v56_63))));  \
+                                                                               \
+  vst1q_f32(c + 7 * ldc + 8, vaddq_f32(vld1q_f32(c + 7 * ldc + 8),             \
+                                       vcvt_f32_f16(vget_low_f16(v120_127)))); \
+  vst1q_f32(c + 7 * ldc + 8 + 4,                                               \
+            vaddq_f32(vld1q_f32(c + 7 * ldc + 8 + 4),                          \
+                      vcvt_f32_f16(vget_high_f16(v120_127))));
+
 /**
  * @brief hgemm 8x16 kernel sc = sa * sb
  *
@@ -425,6 +838,9 @@ void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
   __fp16 *a = sa, *b = sb;
   float *c = sc;
   unsigned int i, j, l;
+  unsigned int K4 = (K >> 2) << 2;
+  unsigned int K8 = (K >> 3) << 3;
+  unsigned int K16 = (K >> 4) << 4;
   for (i = 0; i < M; i += 8) {
     for (j = 0; j < N; j += 16) {
       __builtin_prefetch(b, 0, 3);
@@ -440,106 +856,21 @@ void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
       float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
       float16x8_t va0, va1, va2, va3, va4, va5, va6, va7;
       l = 0;
-      for (; l < K;) {
+      for (; l < K16;) {
+        KERNEL_8x16_ACC16();
+        SAVE_KERNEL_8X16_F16_F32();
+      }
+      for (; l < K8;) {
         KERNEL_8x16_ACC8();
-
-        vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(vget_low_f16(v0_7))));
-        vst1q_f32(c + 4, vaddq_f32(vld1q_f32(c + 4),
-                                   vcvt_f32_f16(vget_high_f16(v0_7))));
-
-        vst1q_f32(c + 8, vaddq_f32(vld1q_f32(c + 8),
-                                   vcvt_f32_f16(vget_low_f16(v64_71))));
-        vst1q_f32(c + 8 + 4, vaddq_f32(vld1q_f32(c + 8 + 4),
-                                       vcvt_f32_f16(vget_high_f16(v64_71))));
-
-        vst1q_f32(c + ldc, vaddq_f32(vld1q_f32(c + ldc),
-                                     vcvt_f32_f16(vget_low_f16(v8_15))));
-        vst1q_f32(c + ldc + 4, vaddq_f32(vld1q_f32(c + ldc + 4),
-                                         vcvt_f32_f16(vget_high_f16(v8_15))));
-
-        vst1q_f32(c + ldc + 8, vaddq_f32(vld1q_f32(c + ldc + 8),
-                                         vcvt_f32_f16(vget_low_f16(v72_79))));
-        vst1q_f32(c + ldc + 8 + 4,
-                  vaddq_f32(vld1q_f32(c + ldc + 8 + 4),
-                            vcvt_f32_f16(vget_high_f16(v72_79))));
-
-        vst1q_f32(c + 2 * ldc, vaddq_f32(vld1q_f32(c + 2 * ldc),
-                                         vcvt_f32_f16(vget_low_f16(v16_23))));
-        vst1q_f32(c + 2 * ldc + 4,
-                  vaddq_f32(vld1q_f32(c + 2 * ldc + 4),
-                            vcvt_f32_f16(vget_high_f16(v16_23))));
-
-        vst1q_f32(c + 2 * ldc + 8,
-                  vaddq_f32(vld1q_f32(c + 2 * ldc + 8),
-                            vcvt_f32_f16(vget_low_f16(v80_87))));
-        vst1q_f32(c + 2 * ldc + 8 + 4,
-                  vaddq_f32(vld1q_f32(c + 2 * ldc + 8 + 4),
-                            vcvt_f32_f16(vget_high_f16(v80_87))));
-
-        vst1q_f32(c + 3 * ldc, vaddq_f32(vld1q_f32(c + 3 * ldc),
-                                         vcvt_f32_f16(vget_low_f16(v24_31))));
-        vst1q_f32(c + 3 * ldc + 4,
-                  vaddq_f32(vld1q_f32(c + 3 * ldc + 4),
-                            vcvt_f32_f16(vget_high_f16(v24_31))));
-
-        vst1q_f32(c + 3 * ldc + 8,
-                  vaddq_f32(vld1q_f32(c + 3 * ldc + 8),
-                            vcvt_f32_f16(vget_low_f16(v88_95))));
-        vst1q_f32(c + 3 * ldc + 8 + 4,
-                  vaddq_f32(vld1q_f32(c + 3 * ldc + 8 + 4),
-                            vcvt_f32_f16(vget_high_f16(v88_95))));
-
-        vst1q_f32(c + 4 * ldc, vaddq_f32(vld1q_f32(c + 4 * ldc),
-                                         vcvt_f32_f16(vget_low_f16(v32_39))));
-        vst1q_f32(c + 4 * ldc + 4,
-                  vaddq_f32(vld1q_f32(c + 4 * ldc + 4),
-                            vcvt_f32_f16(vget_high_f16(v32_39))));
-
-        vst1q_f32(c + 4 * ldc + 8,
-                  vaddq_f32(vld1q_f32(c + 4 * ldc + 8),
-                            vcvt_f32_f16(vget_low_f16(v96_103))));
-        vst1q_f32(c + 4 * ldc + 8 + 4,
-                  vaddq_f32(vld1q_f32(c + 4 * ldc + 8 + 4),
-                            vcvt_f32_f16(vget_high_f16(v96_103))));
-
-        vst1q_f32(c + 5 * ldc, vaddq_f32(vld1q_f32(c + 5 * ldc),
-                                         vcvt_f32_f16(vget_low_f16(v40_47))));
-        vst1q_f32(c + 5 * ldc + 4,
-                  vaddq_f32(vld1q_f32(c + 5 * ldc + 4),
-                            vcvt_f32_f16(vget_high_f16(v40_47))));
-
-        vst1q_f32(c + 5 * ldc + 8,
-                  vaddq_f32(vld1q_f32(c + 5 * ldc + 8),
-                            vcvt_f32_f16(vget_low_f16(v104_111))));
-        vst1q_f32(c + 5 * ldc + 8 + 4,
-                  vaddq_f32(vld1q_f32(c + 5 * ldc + 8 + 4),
-                            vcvt_f32_f16(vget_high_f16(v104_111))));
-
-        vst1q_f32(c + 6 * ldc, vaddq_f32(vld1q_f32(c + 6 * ldc),
-                                         vcvt_f32_f16(vget_low_f16(v48_55))));
-        vst1q_f32(c + 6 * ldc + 4,
-                  vaddq_f32(vld1q_f32(c + 6 * ldc + 4),
-                            vcvt_f32_f16(vget_high_f16(v48_55))));
-
-        vst1q_f32(c + 6 * ldc + 8,
-                  vaddq_f32(vld1q_f32(c + 6 * ldc + 8),
-                            vcvt_f32_f16(vget_low_f16(v112_119))));
-        vst1q_f32(c + 6 * ldc + 8 + 4,
-                  vaddq_f32(vld1q_f32(c + 6 * ldc + 8 + 4),
-                            vcvt_f32_f16(vget_high_f16(v112_119))));
-
-        vst1q_f32(c + 7 * ldc, vaddq_f32(vld1q_f32(c + 7 * ldc),
-                                         vcvt_f32_f16(vget_low_f16(v56_63))));
-        vst1q_f32(c + 7 * ldc + 4,
-                  vaddq_f32(vld1q_f32(c + 7 * ldc + 4),
-                            vcvt_f32_f16(vget_high_f16(v56_63))));
-
-        vst1q_f32(c + 7 * ldc + 8,
-                  vaddq_f32(vld1q_f32(c + 7 * ldc + 8),
-                            vcvt_f32_f16(vget_low_f16(v120_127))));
-        vst1q_f32(c + 7 * ldc + 8 + 4,
-                  vaddq_f32(vld1q_f32(c + 7 * ldc + 8 + 4),
-                            vcvt_f32_f16(vget_high_f16(v120_127))));
+        SAVE_KERNEL_8X16_F16_F32();
+      }
+      for (; l < K4;) {
+        KERNEL_8x16_ACC4();
+        SAVE_KERNEL_8X16_F16_F32();
+      }
+      for (; l < K;) {
+        KERNEL_8x16_ACC1();
+        SAVE_KERNEL_8X16_F16_F32();
       }
       c += 16;
       a -= 8 * K;