hgemm_noTrans_4x8(M, N, K, A, K, B, N, C, N, alpha, beta);
} else if ((M & 0x3) == 0 && (N & 0x3) == 0 && (K & 0x3) == 0) {
hgemm_noTrans_4x4(M, N, K, A, K, B, N, C, N, alpha, beta);
- } else if ((K & 0x7) == 0 && (N & 0x7) == 0) {
+ } else if ((N & 0x7) == 0 && (K & 0x7) == 0) {
hgemm_noTrans_1x8(M, N, K, A, K, B, N, C, N, alpha, beta);
- } else if ((K & 0x7) == 0 && (N & 0x3) == 0) {
+ } else if ((N & 0x3) == 0 && (K & 0x7) == 0) {
hgemm_noTrans_1x4(M, N, K, A, K, B, N, C, N, alpha, beta);
}
}
v112_119 = vdupq_n_f16(0.F); \
v120_127 = vdupq_n_f16(0.F);
-// 0. Partial sum 2048 digits : Best latency, worst accuracy.
+// 1. Partial sum 2048 digits
#define KERNEL_8x16_ACC16() \
va0 = vld1q_f16(a); \
v24 = vld1q_f16(b); \
b += 16 * 16; \
a += 8 * 16;
-// 1. Partial sum 1024 digits : Medium-high accuracy, medium latency
+// 2. Partial sum 1024 digits
#define KERNEL_8x16_ACC8() \
va0 = vld1q_f16(a); \
v24 = vld1q_f16(b); \
b += 16 * 8; \
a += 8 * 8;
-// 2. Partial sum 512 digits : Medium accuracy, medium latency
+// 3. Partial sum 512 digits
#define KERNEL_8x16_ACC4() \
va0 = vld1q_f16(a); \
v24 = vld1q_f16(b); \
b += 16 * 4; \
a += 8 * 4;
-// 3. Partial sum 128 digits : Best accuracy, worst latency
+// 3. Partial sum 128 digits
#define KERNEL_8x16_ACC1() \
va0 = vld1q_f16(a); \
v24 = vld1q_f16(b); \
for (; l < K;) {
KERNEL_8x16_ACC1();
}
- vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0_7));
- vst1q_f16(c + 8, vaddq_f16(vld1q_f16(c + 8), v64_71));
- vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v8_15));
- vst1q_f16(c + ldc + 8, vaddq_f16(vld1q_f16(c + ldc + 8), v72_79));
- vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v16_23));
- vst1q_f16(c + 2 * ldc + 8,
- vaddq_f16(vld1q_f16(c + 2 * ldc + 8), v80_87));
- vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v24_31));
- vst1q_f16(c + 3 * ldc + 8,
- vaddq_f16(vld1q_f16(c + 3 * ldc + 8), v88_95));
- vst1q_f16(c + 4 * ldc, vaddq_f16(vld1q_f16(c + 4 * ldc), v32_39));
- vst1q_f16(c + 4 * ldc + 8,
- vaddq_f16(vld1q_f16(c + 4 * ldc + 8), v96_103));
- vst1q_f16(c + 5 * ldc, vaddq_f16(vld1q_f16(c + 5 * ldc), v40_47));
- vst1q_f16(c + 5 * ldc + 8,
- vaddq_f16(vld1q_f16(c + 5 * ldc + 8), v104_111));
- vst1q_f16(c + 6 * ldc, vaddq_f16(vld1q_f16(c + 6 * ldc), v48_55));
- vst1q_f16(c + 6 * ldc + 8,
- vaddq_f16(vld1q_f16(c + 6 * ldc + 8), v112_119));
- vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v56_63));
- vst1q_f16(c + 7 * ldc + 8,
- vaddq_f16(vld1q_f16(c + 7 * ldc + 8), v120_127));
+ vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0_7));
+ vst1q_f16(c + 8, vaddq_f16(vld1q_f16(c + 8), v64_71));
+ vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v8_15));
+ vst1q_f16(c + ldc + 8, vaddq_f16(vld1q_f16(c + ldc + 8), v72_79));
+ vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v16_23));
+ vst1q_f16(c + 2 * ldc + 8, vaddq_f16(vld1q_f16(c + 2 * ldc + 8), v80_87));
+ vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v24_31));
+ vst1q_f16(c + 3 * ldc + 8, vaddq_f16(vld1q_f16(c + 3 * ldc + 8), v88_95));
+ vst1q_f16(c + 4 * ldc, vaddq_f16(vld1q_f16(c + 4 * ldc), v32_39));
+ vst1q_f16(c + 4 * ldc + 8,
+ vaddq_f16(vld1q_f16(c + 4 * ldc + 8), v96_103));
+ vst1q_f16(c + 5 * ldc, vaddq_f16(vld1q_f16(c + 5 * ldc), v40_47));
+ vst1q_f16(c + 5 * ldc + 8,
+ vaddq_f16(vld1q_f16(c + 5 * ldc + 8), v104_111));
+ vst1q_f16(c + 6 * ldc, vaddq_f16(vld1q_f16(c + 6 * ldc), v48_55));
+ vst1q_f16(c + 6 * ldc + 8,
+ vaddq_f16(vld1q_f16(c + 6 * ldc + 8), v112_119));
+ vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v56_63));
+ vst1q_f16(c + 7 * ldc + 8,
+ vaddq_f16(vld1q_f16(c + 7 * ldc + 8), v120_127));
c += 16;
a -= 8 * K;
}