sbgemm: spr: kernel works for NN case when alpha is 1.0
author Wangyang Guo <wangyang.guo@intel.com>
Mon, 13 Sep 2021 02:22:58 +0000 (19:22 -0700)
committer Wangyang Guo <wangyang.guo@intel.com>
Mon, 18 Oct 2021 02:08:03 +0000 (19:08 -0700)
kernel/x86_64/sbgemm_kernel_16x16_spr.c
kernel/x86_64/sbgemm_oncopy_16_spr.c

index 41d2634d54d82ac88bf9a1bcc21b13ea2c73f44c..b7b4e36a397d447e753dcd7d1e497e4327123197 100644 (file)
@@ -82,6 +82,12 @@ typedef struct {
        _mm512_storeu_epi16(tail_a + 16 * M, zmm); \
        _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \
 }
+#define MASK_LOAD_A_TAIL(M, N) { /* masked variant of the A tail load: reads ptr_a##M under amask and stages it through tail_a before the tile load; N is unused */ \
+       __m256i ymm = _mm256_maskz_loadu_epi16(amask, ptr_a##M); /* 16-bit lanes whose amask bit is clear are zeroed, not read */ \
+       __m512i zmm = _mm512_cvtepu16_epi32(ymm); /* zero-extend each 16-bit lane to 32 bits */ \
+       _mm512_storeu_epi16(tail_a + 16 * M, zmm); /* NOTE(review): store offset is 16 * M but the tile load below reads from 16 * 2 * M; these agree only for M == 0 -- confirm before using with M != 0 */ \
+       _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); /* load tile T_A##M from the staging buffer, stride argument 2 * 2 */ \
+}
 #define LOAD_B(M, N) _tile_loadd(T_B##N, ptr_b##N, ldb * 2)
 #define LOAD_B_TAIL(M, N) {\
        __m256i ymm = _mm256_loadu_epi16(ptr_b##N); \
@@ -111,7 +117,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
        A = iB;
        B = iA;
 
-       printf("kernel: m %d, n %d, k %d, ldc: %d\n", m, n, k, ldc);
        IFLOAT *ptr_a = A, *ptr_b = B;
        IFLOAT *ptr_b0, *ptr_b1;
        IFLOAT *ptr_a0, *ptr_a1;
@@ -279,5 +284,133 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
                }
                ptr_a += 32 * k;
        }
+       for (; m_count > 0; m_count -= 16) {
+               // process at most 16 m at a time
+               int tail_m = (m_count > 16) ? 16: m_count;
+               __mmask16 amask = (1UL << tail_m) - 1;
+
+               ptr_b = B;
+
+               ptr_c00 = ptr_c;
+               ptr_c01 = ptr_c00 + 16;
+               ptr_c += tail_m * ldc;
+               n_count = n;
+               for (; n_count > 31; n_count -= 32) {
+                       ptr_a0 = ptr_a;
+
+                       ptr_b0 = ptr_b;
+                       ptr_b1 = ptr_b + 16 * k;
+                       ptr_b += 32 * k;
+
+                       lda = 32;
+                       ldb = 32;
+                       TCONF(cfg, tail_m, 16, 32);
+                       LOAD_C(0, 0); LOAD_C(0, 1);
+                       k_count = k;
+                       for (; k_count > 31; k_count -= 32) {
+                               LOAD_A(0, x);
+                               ptr_a0 += tail_m * 32;
+                               LOAD_B(x, 0); LOAD_B(x, 1);
+                               ptr_b0 += 16 * 32;
+                               ptr_b1 += 16 * 32;
+
+                               MATMUL(0, 0); MATMUL(0, 1);
+                       }
+                       STORE_C(0, 0); STORE_C(0, 1);
+                       if (k_count > 1) {
+                               /* still have more than 2*k */
+                               int remain_k2 = k_count & ~1;
+                               k_count -= remain_k2;
+                               lda = remain_k2;
+                               TCONF(cfg, tail_m, 16, remain_k2);
+                               /* reconfig will clear all tiles,
+                                * need to store/load again
+                                */
+                               LOAD_C(0, 0); LOAD_C(0, 1);
+
+                               LOAD_A(0, x);
+                               ptr_a0 += tail_m * remain_k2;
+                               LOAD_B(x, 0); LOAD_B(x, 1);
+                               ptr_b0 += 16 * remain_k2;
+                               ptr_b1 += 16 * remain_k2;
+
+                               MATMUL(0, 0); MATMUL(0, 1);
+
+                               STORE_C(0, 0); STORE_C(0, 1);
+                       }
+                       if (k_count > 0) {
+                               /* still have odd tail k, need to transform into 2*k */
+                               TCONF(cfg, tail_m, 16, 2);
+
+                               LOAD_C(0, 0); LOAD_C(0, 1);
+
+                               MASK_LOAD_A_TAIL(0, x);
+                               LOAD_B_TAIL(x, 0); LOAD_B_TAIL(x, 1);
+
+                               MATMUL(0, 0); MATMUL(0, 1);
+
+                               STORE_C(0, 0); STORE_C(0, 1);
+                       }
+                       ptr_c00 += 32;
+                       ptr_c01 += 32;
+               }
+               for (; n_count > 0; n_count -= 16) {
+                       int tail_n = (n_count > 16) ? 16: n_count;
+                       __mmask16 bmask = (1UL << tail_n) - 1;
+                       ptr_a0 = ptr_a;
+
+                       ptr_b0 = ptr_b;
+                       ptr_b += tail_n * k;
+
+                       lda = 32;
+                       ldb = 2 * tail_n;
+                       TCONF(cfg, tail_m, tail_n, 32);
+                       LOAD_C(0, 0);
+                       k_count = k;
+                       for (; k_count > 31; k_count -= 32) {
+                               LOAD_A(0, x);
+                               ptr_a0 += tail_m * 32;
+                               LOAD_B(x, 0);
+                               ptr_b0 += tail_n * 32;
+
+                               MATMUL(0, 0);
+                       }
+                       STORE_C(0, 0);
+                       if (k_count > 1) {
+                               /* still have more than 2*k */
+                               int remain_k2 = k_count & ~1;
+                               k_count -= remain_k2;
+                               lda = remain_k2;
+                               TCONF(cfg, tail_m, tail_n, remain_k2);
+                               /* reconfig will clear all tiles,
+                                * need to store/load again
+                                */
+                               LOAD_C(0, 0);
+
+                               LOAD_A(0, x);
+                               ptr_a0 += tail_m * remain_k2;
+                               LOAD_B(x, 0);
+                               ptr_b0 += tail_n * remain_k2;
+
+                               MATMUL(0, 0);
+
+                               STORE_C(0, 0);
+                       }
+                       if (k_count > 0) {
+                               /* still have odd tail k, need to transform into 2*k */
+                               TCONF(cfg, tail_m, tail_n, 2);
+
+                               LOAD_C(0, 0);
+
+                               MASK_LOAD_A_TAIL(0, x);
+                               MASK_LOAD_B_TAIL(x, 0);
+                               MATMUL(0, 0);
+
+                               STORE_C(0, 0);
+                       }
+                       ptr_c00 += tail_n;
+               }
+               ptr_a += tail_m * k;
+       }
        return 0;
 }
index da353d2c7a195083a0067bc595602e96b2725f11..f5668e26e6eb3fd6dee3e2f1f14f8020202ac3bc 100644 (file)
@@ -32,8 +32,8 @@
 #define MASK_COPY_32(N) _mm512_mask_storeu_epi16(boffset + tail_m * N, mmask, _mm512_maskz_loadu_epi16(mmask, aoffset##N + i))
 #define COPY_ODD_TAIL(N) *(boffset + N) = *(aoffset##N + i);
 
+
 int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) {
-       printf("ONCOPY: m %d, n %d, lda %d\n", m, n, lda);
        BLASLONG i, j;
        IFLOAT *aoffset, *boffset;
        IFLOAT *aoffset0, *aoffset1, *aoffset2, *aoffset3;