- I found that the accumulator matrices were re-initialized inside every fused multiply-add kernel macro.
- With the initialization factored out into a separate macro, we get:
1. Cleaner code that is reusable for both the f16 and f16-f32 kernels
2. Minimal, non-redundant initialization in the f16 kernel: better latency with the SAME accuracy (see the sketch below).
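To make the intent concrete, here is a minimal sketch of the pattern in plain C. The scalar accumulators and the names `INIT_ACC`, `KERNEL_ACC1`, `kernel_full`, and `kernel_partial` are illustrative stand-ins, not the actual hgemm macros or NEON registers.

```c
#include <stdio.h>

/* Accumulator init, factored out of the compute macro (illustrative). */
#define INIT_ACC() \
  acc0 = 0.0F;     \
  acc1 = 0.0F;

/* One accumulation step of a tiny 1x2 micro-kernel (illustrative). */
#define KERNEL_ACC1()          \
  acc0 += a[l] * b[2 * l];     \
  acc1 += a[l] * b[2 * l + 1]; \
  l += 1;

/* f16-style kernel: init once, accumulate over all of K, store once. */
static void kernel_full(const float *a, const float *b, float *c,
                        unsigned int K) {
  float acc0, acc1;
  unsigned int l = 0;
  INIT_ACC();
  for (; l < K;) {
    KERNEL_ACC1();
  }
  c[0] += acc0;
  c[1] += acc1;
}

/* f16-f32-style kernel: re-init before every partial-sum block and flush
 * each block into the output (stand-in for SAVE_KERNEL_*_F16_F32()). */
static void kernel_partial(const float *a, const float *b, float *c,
                           unsigned int K) {
  float acc0, acc1;
  unsigned int l = 0;
  for (; l < K;) {
    INIT_ACC(); /* partial sums restart here to bound f16 error growth */
    KERNEL_ACC1();
    c[0] += acc0;
    c[1] += acc1;
  }
}

int main(void) {
  const float a[4] = {1.F, 2.F, 3.F, 4.F};
  const float b[8] = {1.F, 2.F, 3.F, 4.F, 5.F, 6.F, 7.F, 8.F};
  float c1[2] = {0.F, 0.F}, c2[2] = {0.F, 0.F};
  kernel_full(a, b, c1, 4);
  kernel_partial(a, b, c2, 4);
  printf("full: %.1f %.1f | partial: %.1f %.1f\n", c1[0], c1[1], c2[0], c2[1]);
  return 0;
}
```

In the actual patch, the f16 kernels call INIT_KERNEL_*() once before the accumulation loop and store once after it, while the f16-f32 kernels call it at the top of each K16/K8/K4/K1 block, right before the corresponding KERNEL_*_ACC*() and SAVE_KERNEL_*_F16_F32().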
**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped
Signed-off-by: skykongkong8 <ss.kong@samsung.com>
#include <hgemm_common.h>
#include <stdlib.h>
+#define INIT_KERNEL_4x4() \
+ v24 = vdup_n_f16(0.F); \
+ v25 = vdup_n_f16(0.F); \
+ v26 = vdup_n_f16(0.F); \
+ v27 = vdup_n_f16(0.F);
+
// 1. Partial sum 256 digits : Worst accuracy, best latency
#define KERNEL_4x4_ACC16() \
- v24 = vdup_n_f16(0.F); \
- v25 = vdup_n_f16(0.F); \
- v26 = vdup_n_f16(0.F); \
- v27 = vdup_n_f16(0.F); \
dv0 = vld1_f16(a); \
vb0 = vld1_f16(b); \
v24 = vfma_lane_f16(v24, vb0, dv0, 0); \
v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
- l += 16; \
+ l += 16; \
__builtin_prefetch(b + 64, 0, 3); \
__builtin_prefetch(a + 64, 0, 3); \
- b += 4 * 16; \
+ b += 4 * 16; \
a += 4 * 16;
// 2. Partial sum 128 digits : Medium accuracy, medium latency
#define KERNEL_4x4_ACC8() \
- v24 = vdup_n_f16(0.F); \
- v25 = vdup_n_f16(0.F); \
- v26 = vdup_n_f16(0.F); \
- v27 = vdup_n_f16(0.F); \
dv0 = vld1_f16(a); \
vb0 = vld1_f16(b); \
v24 = vfma_lane_f16(v24, vb0, dv0, 0); \
// 3. Partial sum 16 digits : Best accuracy, worst latency
#define KERNEL_4x4_ACC1() \
- v24 = vdup_n_f16(0.F); \
- v25 = vdup_n_f16(0.F); \
- v26 = vdup_n_f16(0.F); \
- v27 = vdup_n_f16(0.F); \
dv0 = vld1_f16(a); \
vb0 = vld1_f16(b); \
v24 = vfma_lane_f16(v24, vb0, dv0, 0); \
__builtin_prefetch(b, 0, 3); \
__builtin_prefetch(a, 0, 3);
- float16x4_t v24 = {0};
- float16x4_t v25 = {0};
- float16x4_t v26 = {0};
- float16x4_t v27 = {0};
+ float16x4_t v24;
+ float16x4_t v25;
+ float16x4_t v26;
+ float16x4_t v27;
+ INIT_KERNEL_4x4();
for (l = 0; l < K; l += VL_FP16_HALF) {
float16x4_t v0 = vld1_f16(b);
float16x4_t vb0, vb1, vb2, vb3, vb4, vb5, vb6, vb7;
l = 0;
for (; l < K16;) {
+ INIT_KERNEL_4x4();
KERNEL_4x4_ACC16();
SAVE_KERNEL_4X4_F16_F32();
}
for (; l < K8;) {
+ INIT_KERNEL_4x4();
KERNEL_4x4_ACC8();
SAVE_KERNEL_4X4_F16_F32();
}
for (; l < K;) {
+ INIT_KERNEL_4x4();
KERNEL_4x4_ACC1();
SAVE_KERNEL_4X4_F16_F32();
}
#include <hgemm_common.h>
#include <stdlib.h>
-/// @note Following KERNELs are the combinations of accuracy-latency
-/// tradeoff. User can select which kernel to use by replacing them.
+#define INIT_KERNEL_4X8() \
+ v0 = vdupq_n_f16(0.F); \
+ v3 = vdupq_n_f16(0.F); \
+ v6 = vdupq_n_f16(0.F); \
+ v9 = vdupq_n_f16(0.F);
// 1. Partial sum 512 digits : Worst accuracy, best latency
#define KERNEL_4x8_ACC16() \
- v0 = vdupq_n_f16(0.F); \
- v3 = vdupq_n_f16(0.F); \
- v6 = vdupq_n_f16(0.F); \
- v9 = vdupq_n_f16(0.F); \
dv0 = vld1_f16(a); \
v24 = vld1q_f16(b); \
v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
// 2. Partial sum 256 digits : Medium-high accuracy, medium latency
#define KERNEL_4x8_ACC8() \
- v0 = vdupq_n_f16(0.F); \
- v3 = vdupq_n_f16(0.F); \
- v6 = vdupq_n_f16(0.F); \
- v9 = vdupq_n_f16(0.F); \
dv0 = vld1_f16(a); \
v24 = vld1q_f16(b); \
v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
// 3. Partial sum 128 digits : Medium accuracy, medium latency
#define KERNEL_4x8_ACC4() \
- v0 = vdupq_n_f16(0.F); \
- v3 = vdupq_n_f16(0.F); \
- v6 = vdupq_n_f16(0.F); \
- v9 = vdupq_n_f16(0.F); \
dv0 = vld1_f16(a); \
v24 = vld1q_f16(b); \
v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
// 4. Partial sum 32 digits : Best accuracy, worst latency
#define KERNEL_4x8_ACC1() \
- v0 = vdupq_n_f16(0.F); \
- v3 = vdupq_n_f16(0.F); \
- v6 = vdupq_n_f16(0.F); \
- v9 = vdupq_n_f16(0.F); \
dv0 = vld1_f16(a); \
v24 = vld1q_f16(b); \
v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
assert(M % 4 == 0 && N % 8 == 0);
__fp16 *a = sa, *b = sb, *c = sc;
- unsigned int k8 = (K >> 3) << 3;
+ unsigned int K8 = (K >> 3) << 3;
unsigned int i, j, l;
for (i = 0; i < M; i += 4) {
for (j = 0; j < N; j += 8) {
float16x8_t v0, v3, v6, v9;
float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
+ INIT_KERNEL_4X8();
l = 0;
- for (; l < k8;) {
+ for (; l < K8;) {
KERNEL_4x8_ACC8();
-
- vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0));
- vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v3));
- vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v6));
- vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v9));
}
for (; l < K;) {
KERNEL_4x8_ACC1();
-
- vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0));
- vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v3));
- vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v6));
- vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v9));
}
+ vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0));
+ vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v3));
+ vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v6));
+ vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v9));
c += 8;
a -= 4 * K;
}
float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
l = 0;
for (; l < K16;) {
+ INIT_KERNEL_4X8();
KERNEL_4x8_ACC16();
SAVE_KERNEL_4X8_F16_F32();
}
for (; l < K8;) {
+ INIT_KERNEL_4X8();
KERNEL_4x8_ACC8();
SAVE_KERNEL_4X8_F16_F32();
}
for (; l < K4;) {
+ INIT_KERNEL_4X8();
KERNEL_4x8_ACC4();
SAVE_KERNEL_4X8_F16_F32();
}
for (; l < K;) {
+ INIT_KERNEL_4X8();
KERNEL_4x8_ACC1();
SAVE_KERNEL_4X8_F16_F32();
}
#include <hgemm_common.h>
#include <stdlib.h>
-/// @note Following KERNELs are the combinations of accuracy-latency
-/// tradeoff. User can select which kernel to use by replacing them.
+#define INIT_KERNEL_8X16() \
+ v0_7 = vdupq_n_f16(0.F); \
+ v8_15 = vdupq_n_f16(0.F); \
+ v16_23 = vdupq_n_f16(0.F); \
+ v24_31 = vdupq_n_f16(0.F); \
+ v32_39 = vdupq_n_f16(0.F); \
+ v40_47 = vdupq_n_f16(0.F); \
+ v48_55 = vdupq_n_f16(0.F); \
+ v56_63 = vdupq_n_f16(0.F); \
+ v64_71 = vdupq_n_f16(0.F); \
+ v72_79 = vdupq_n_f16(0.F); \
+ v80_87 = vdupq_n_f16(0.F); \
+ v88_95 = vdupq_n_f16(0.F); \
+ v96_103 = vdupq_n_f16(0.F); \
+ v104_111 = vdupq_n_f16(0.F); \
+ v112_119 = vdupq_n_f16(0.F); \
+ v120_127 = vdupq_n_f16(0.F);
// 0. Partial sum 2048 digits : Best latency, worst accuracy.
#define KERNEL_8x16_ACC16() \
- v0_7 = vdupq_n_f16(0.F); \
- v8_15 = vdupq_n_f16(0.F); \
- v16_23 = vdupq_n_f16(0.F); \
- v24_31 = vdupq_n_f16(0.F); \
- v32_39 = vdupq_n_f16(0.F); \
- v40_47 = vdupq_n_f16(0.F); \
- v48_55 = vdupq_n_f16(0.F); \
- v56_63 = vdupq_n_f16(0.F); \
- v64_71 = vdupq_n_f16(0.F); \
- v72_79 = vdupq_n_f16(0.F); \
- v80_87 = vdupq_n_f16(0.F); \
- v88_95 = vdupq_n_f16(0.F); \
- v96_103 = vdupq_n_f16(0.F); \
- v104_111 = vdupq_n_f16(0.F); \
- v112_119 = vdupq_n_f16(0.F); \
- v120_127 = vdupq_n_f16(0.F); \
va0 = vld1q_f16(a); \
v24 = vld1q_f16(b); \
v25 = vld1q_f16(b + 8); \
// 1. Partial sum 1024 digits : Medium-high accuracy, medium latency
#define KERNEL_8x16_ACC8() \
- v0_7 = vdupq_n_f16(0.F); \
- v8_15 = vdupq_n_f16(0.F); \
- v16_23 = vdupq_n_f16(0.F); \
- v24_31 = vdupq_n_f16(0.F); \
- v32_39 = vdupq_n_f16(0.F); \
- v40_47 = vdupq_n_f16(0.F); \
- v48_55 = vdupq_n_f16(0.F); \
- v56_63 = vdupq_n_f16(0.F); \
- v64_71 = vdupq_n_f16(0.F); \
- v72_79 = vdupq_n_f16(0.F); \
- v80_87 = vdupq_n_f16(0.F); \
- v88_95 = vdupq_n_f16(0.F); \
- v96_103 = vdupq_n_f16(0.F); \
- v104_111 = vdupq_n_f16(0.F); \
- v112_119 = vdupq_n_f16(0.F); \
- v120_127 = vdupq_n_f16(0.F); \
va0 = vld1q_f16(a); \
v24 = vld1q_f16(b); \
v25 = vld1q_f16(b + 8); \
// 2. Partial sum 512 digits : Medium accuracy, medium latency
#define KERNEL_8x16_ACC4() \
- v0_7 = vdupq_n_f16(0.F); \
- v8_15 = vdupq_n_f16(0.F); \
- v16_23 = vdupq_n_f16(0.F); \
- v24_31 = vdupq_n_f16(0.F); \
- v32_39 = vdupq_n_f16(0.F); \
- v40_47 = vdupq_n_f16(0.F); \
- v48_55 = vdupq_n_f16(0.F); \
- v56_63 = vdupq_n_f16(0.F); \
- v64_71 = vdupq_n_f16(0.F); \
- v72_79 = vdupq_n_f16(0.F); \
- v80_87 = vdupq_n_f16(0.F); \
- v88_95 = vdupq_n_f16(0.F); \
- v96_103 = vdupq_n_f16(0.F); \
- v104_111 = vdupq_n_f16(0.F); \
- v112_119 = vdupq_n_f16(0.F); \
- v120_127 = vdupq_n_f16(0.F); \
va0 = vld1q_f16(a); \
v24 = vld1q_f16(b); \
v25 = vld1q_f16(b + 8); \
// 3. Partial sum 128 digits : Best accuracy, worst latency
#define KERNEL_8x16_ACC1() \
- v0_7 = vdupq_n_f16(0.F); \
- v8_15 = vdupq_n_f16(0.F); \
- v16_23 = vdupq_n_f16(0.F); \
- v24_31 = vdupq_n_f16(0.F); \
- v32_39 = vdupq_n_f16(0.F); \
- v40_47 = vdupq_n_f16(0.F); \
- v48_55 = vdupq_n_f16(0.F); \
- v56_63 = vdupq_n_f16(0.F); \
- v64_71 = vdupq_n_f16(0.F); \
- v72_79 = vdupq_n_f16(0.F); \
- v80_87 = vdupq_n_f16(0.F); \
- v88_95 = vdupq_n_f16(0.F); \
- v96_103 = vdupq_n_f16(0.F); \
- v104_111 = vdupq_n_f16(0.F); \
- v112_119 = vdupq_n_f16(0.F); \
- v120_127 = vdupq_n_f16(0.F); \
va0 = vld1q_f16(a); \
v24 = vld1q_f16(b); \
v25 = vld1q_f16(b + 8); \
float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
float16x8_t va0, va1, va2, va3;
+
+ INIT_KERNEL_8X16();
l = 0;
for (; l < K;) {
- KERNEL_8x16_ACC4();
- vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0_7));
+ KERNEL_8x16_ACC1();
+ }
+ vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0_7));
vst1q_f16(c + 8, vaddq_f16(vld1q_f16(c + 8), v64_71));
vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v8_15));
vst1q_f16(c + ldc + 8, vaddq_f16(vld1q_f16(c + ldc + 8), v72_79));
vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v56_63));
vst1q_f16(c + 7 * ldc + 8,
vaddq_f16(vld1q_f16(c + 7 * ldc + 8), v120_127));
- }
c += 16;
a -= 8 * K;
}
float16x8_t va0, va1, va2, va3, va4, va5, va6, va7;
l = 0;
for (; l < K16;) {
+ INIT_KERNEL_8X16();
KERNEL_8x16_ACC16();
SAVE_KERNEL_8X16_F16_F32();
}
for (; l < K8;) {
+ INIT_KERNEL_8X16();
KERNEL_8x16_ACC8();
SAVE_KERNEL_8X16_F16_F32();
}
for (; l < K4;) {
+ INIT_KERNEL_8X16();
KERNEL_8x16_ACC4();
SAVE_KERNEL_8X16_F16_F32();
}
for (; l < K;) {
+ INIT_KERNEL_8X16();
KERNEL_8x16_ACC1();
SAVE_KERNEL_8X16_F16_F32();
}
#include <hgemm_common.h>
#include <stdlib.h>
-/// @note Following KERNELs are the combinations of accuracy-latency
-/// tradeoff. User can select which kernel to use by replacing them.
+#define INIT_KERNEL_8x8() \
+ v24 = vdupq_n_f16(0.F); \
+ v25 = vdupq_n_f16(0.F); \
+ v26 = vdupq_n_f16(0.F); \
+ v27 = vdupq_n_f16(0.F); \
+ v28 = vdupq_n_f16(0.F); \
+ v29 = vdupq_n_f16(0.F); \
+ v30 = vdupq_n_f16(0.F); \
+ v31 = vdupq_n_f16(0.F);
// 1. Partial sum 1024 digits : Worst accuracy, best latency
#define KERNEL_8x8_ACC16() \
- v24 = vdupq_n_f16(0.F); \
- v25 = vdupq_n_f16(0.F); \
- v26 = vdupq_n_f16(0.F); \
- v27 = vdupq_n_f16(0.F); \
- v28 = vdupq_n_f16(0.F); \
- v29 = vdupq_n_f16(0.F); \
- v30 = vdupq_n_f16(0.F); \
- v31 = vdupq_n_f16(0.F); \
va0 = vld1q_f16(a); \
v16 = vld1q_f16(b); \
v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
// 2. Partial sum 512 digits : Medium accuracy, medium latency
#define KERNEL_8x8_ACC8() \
- v24 = vdupq_n_f16(0.F); \
- v25 = vdupq_n_f16(0.F); \
- v26 = vdupq_n_f16(0.F); \
- v27 = vdupq_n_f16(0.F); \
- v28 = vdupq_n_f16(0.F); \
- v29 = vdupq_n_f16(0.F); \
- v30 = vdupq_n_f16(0.F); \
- v31 = vdupq_n_f16(0.F); \
va0 = vld1q_f16(a); \
v16 = vld1q_f16(b); \
v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
// 3. Partial sum 256 digits : Medium accuracy, medium latency
#define KERNEL_8x8_ACC4() \
- v24 = vdupq_n_f16(0.F); \
- v25 = vdupq_n_f16(0.F); \
- v26 = vdupq_n_f16(0.F); \
- v27 = vdupq_n_f16(0.F); \
- v28 = vdupq_n_f16(0.F); \
- v29 = vdupq_n_f16(0.F); \
- v30 = vdupq_n_f16(0.F); \
- v31 = vdupq_n_f16(0.F); \
va0 = vld1q_f16(a); \
v16 = vld1q_f16(b); \
v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
// 4. Partial sum 64 digits : Best accuracy, worst latency
#define KERNEL_8x8_ACC1() \
- v24 = vdupq_n_f16(0.F); \
- v25 = vdupq_n_f16(0.F); \
- v26 = vdupq_n_f16(0.F); \
- v27 = vdupq_n_f16(0.F); \
- v28 = vdupq_n_f16(0.F); \
- v29 = vdupq_n_f16(0.F); \
- v30 = vdupq_n_f16(0.F); \
- v31 = vdupq_n_f16(0.F); \
va0 = vld1q_f16(a); \
v16 = vld1q_f16(b); \
v24 = vfmaq_laneq_f16(v24, v16, va0, 0); \
float16x8_t v16, v17, v18, v19, v20, v21, v22, v23;
float16x8_t v24, v25, v26, v27, v28, v29, v30, v31;
float16x8_t va0, va1, va2, va3, va4, va5, va6, va7;
+ INIT_KERNEL_8x8();
l = 0;
for (; l < K;) {
- KERNEL_8x8_ACC8();
-
- vst1q_f16(c, vaddq_f16(vld1q_f16(c), v24));
- vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v25));
- vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v26));
- vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v27));
- vst1q_f16(c + 4 * ldc, vaddq_f16(vld1q_f16(c + 4 * ldc), v28));
- vst1q_f16(c + 5 * ldc, vaddq_f16(vld1q_f16(c + 5 * ldc), v29));
- vst1q_f16(c + 6 * ldc, vaddq_f16(vld1q_f16(c + 6 * ldc), v30));
- vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v31));
+ KERNEL_8x8_ACC1();
}
+ vst1q_f16(c, vaddq_f16(vld1q_f16(c), v24));
+ vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v25));
+ vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v26));
+ vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v27));
+ vst1q_f16(c + 4 * ldc, vaddq_f16(vld1q_f16(c + 4 * ldc), v28));
+ vst1q_f16(c + 5 * ldc, vaddq_f16(vld1q_f16(c + 5 * ldc), v29));
+ vst1q_f16(c + 6 * ldc, vaddq_f16(vld1q_f16(c + 6 * ldc), v30));
+ vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v31));
c += 8;
a -= 8 * K;
}
float16x8_t va0, va1, va2, va3, va4, va5, va6, va7;
l = 0;
for (; l < K16;) {
+ INIT_KERNEL_8x8();
KERNEL_8x8_ACC16();
SAVE_KERNEL_8X8_F16_f32();
}
for (; l < K8;) {
+ INIT_KERNEL_8x8();
KERNEL_8x8_ACC8();
SAVE_KERNEL_8X8_F16_f32();
}
for (; l < K4;) {
+ INIT_KERNEL_8x8();
KERNEL_8x8_ACC4();
SAVE_KERNEL_8X8_F16_f32();
}
for (; l < K;) {
+ INIT_KERNEL_8x8();
KERNEL_8x8_ACC1();
SAVE_KERNEL_8X8_F16_f32();
}