_mm512_storeu_epi16(tail_a + 16 * M, zmm); \
_tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \
}
+#define MASK_LOAD_A_TAIL(M, N) {/* masked A load for the odd-k tail: stages 16-bit A elements through tail_a into tile T_A##M; N is not used here */\
+ __m256i ymm = _mm256_maskz_loadu_epi16(amask, ptr_a##M); /* read up to 16 16-bit elements; lanes whose amask bit is clear become zero */\
+ __m512i zmm = _mm512_cvtepu16_epi32(ymm); /* zero-extend every 16-bit element to 32 bits, so each value gets a zero upper half */\
+ _mm512_storeu_epi16(tail_a + 16 * M, zmm); /* spill the widened 64-byte vector into tail_a */\
+ _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); /* tile load with a 4-byte row stride. NOTE(review): the store above uses offset 16 * M but this load uses 16 * 2 * M; they coincide only for M == 0 - confirm before using another M */\
+}
#define LOAD_B(M, N) _tile_loadd(T_B##N, ptr_b##N, ldb * 2)
#define LOAD_B_TAIL(M, N) {\
__m256i ymm = _mm256_loadu_epi16(ptr_b##N); \
A = iB;
B = iA;
- printf("kernel: m %d, n %d, k %d, ldc: %d\n", m, n, k, ldc);
IFLOAT *ptr_a = A, *ptr_b = B;
IFLOAT *ptr_b0, *ptr_b1;
IFLOAT *ptr_a0, *ptr_a1;
}
ptr_a += 32 * k;
}
+ for (; m_count > 0; m_count -= 16) {
+ // process at most 16 m at a time
+ int tail_m = (m_count > 16) ? 16: m_count;
+ __mmask16 amask = (1UL << tail_m) - 1;
+
+ ptr_b = B;
+
+ ptr_c00 = ptr_c;
+ ptr_c01 = ptr_c00 + 16;
+ ptr_c += tail_m * ldc;
+ n_count = n;
+ for (; n_count > 31; n_count -= 32) {
+ ptr_a0 = ptr_a;
+
+ ptr_b0 = ptr_b;
+ ptr_b1 = ptr_b + 16 * k;
+ ptr_b += 32 * k;
+
+ lda = 32;
+ ldb = 32;
+ TCONF(cfg, tail_m, 16, 32);
+ LOAD_C(0, 0); LOAD_C(0, 1);
+ k_count = k;
+ for (; k_count > 31; k_count -= 32) {
+ LOAD_A(0, x);
+ ptr_a0 += tail_m * 32;
+ LOAD_B(x, 0); LOAD_B(x, 1);
+ ptr_b0 += 16 * 32;
+ ptr_b1 += 16 * 32;
+
+ MATMUL(0, 0); MATMUL(0, 1);
+ }
+ STORE_C(0, 0); STORE_C(0, 1);
+ if (k_count > 1) {
+ /* at least two k values remain: process the even part of the k tail */
+ int remain_k2 = k_count & ~1;
+ k_count -= remain_k2;
+ lda = remain_k2;
+ TCONF(cfg, tail_m, 16, remain_k2);
+ /* reconfig will clear all tiles,
+ * need to store/load again
+ */
+ LOAD_C(0, 0); LOAD_C(0, 1);
+
+ LOAD_A(0, x);
+ ptr_a0 += tail_m * remain_k2;
+ LOAD_B(x, 0); LOAD_B(x, 1);
+ ptr_b0 += 16 * remain_k2;
+ ptr_b1 += 16 * remain_k2;
+
+ MATMUL(0, 0); MATMUL(0, 1);
+
+ STORE_C(0, 0); STORE_C(0, 1);
+ }
+ if (k_count > 0) {
+ /* one odd k value is left; it has to be expanded into a pair of k (tiles configured for k = 2) */
+ TCONF(cfg, tail_m, 16, 2);
+
+ LOAD_C(0, 0); LOAD_C(0, 1);
+
+ MASK_LOAD_A_TAIL(0, x);
+ LOAD_B_TAIL(x, 0); LOAD_B_TAIL(x, 1);
+
+ MATMUL(0, 0); MATMUL(0, 1);
+
+ STORE_C(0, 0); STORE_C(0, 1);
+ }
+ ptr_c00 += 32;
+ ptr_c01 += 32;
+ }
+ for (; n_count > 0; n_count -= 16) {
+ int tail_n = (n_count > 16) ? 16: n_count;
+ __mmask16 bmask = (1UL << tail_n) - 1;
+ ptr_a0 = ptr_a;
+
+ ptr_b0 = ptr_b;
+ ptr_b += tail_n * k;
+
+ lda = 32;
+ ldb = 2 * tail_n;
+ TCONF(cfg, tail_m, tail_n, 32);
+ LOAD_C(0, 0);
+ k_count = k;
+ for (; k_count > 31; k_count -= 32) {
+ LOAD_A(0, x);
+ ptr_a0 += tail_m * 32;
+ LOAD_B(x, 0);
+ ptr_b0 += tail_n * 32;
+
+ MATMUL(0, 0);
+ }
+ STORE_C(0, 0);
+ if (k_count > 1) {
+ /* at least two k values remain: process the even part of the k tail */
+ int remain_k2 = k_count & ~1;
+ k_count -= remain_k2;
+ lda = remain_k2;
+ TCONF(cfg, tail_m, tail_n, remain_k2);
+ /* reconfig will clear all tiles,
+ * need to store/load again
+ */
+ LOAD_C(0, 0);
+
+ LOAD_A(0, x);
+ ptr_a0 += tail_m * remain_k2;
+ LOAD_B(x, 0);
+ ptr_b0 += tail_n * remain_k2;
+
+ MATMUL(0, 0);
+
+ STORE_C(0, 0);
+ }
+ if (k_count > 0) {
+ /* one odd k value is left; it has to be expanded into a pair of k (tiles configured for k = 2) */
+ TCONF(cfg, tail_m, tail_n, 2);
+
+ LOAD_C(0, 0);
+
+ MASK_LOAD_A_TAIL(0, x);
+ MASK_LOAD_B_TAIL(x, 0);
+ MATMUL(0, 0);
+
+ STORE_C(0, 0);
+ }
+ ptr_c00 += tail_n;
+ }
+ ptr_a += tail_m * k;
+ }
return 0;
}