sbgemm: spr: kernel handle alpha != 1.0
authorWangyang Guo <wangyang.guo@intel.com>
Fri, 17 Sep 2021 07:48:52 +0000 (00:48 -0700)
committerWangyang Guo <wangyang.guo@intel.com>
Mon, 18 Oct 2021 02:08:03 +0000 (19:08 -0700)
kernel/x86_64/sbgemm_kernel_16x16_spr.c
kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c [new file with mode: 0644]

index b340358..955db31 100644 (file)
  * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  * *****************************************************************************/
 
-#include <immintrin.h>
-#include <string.h>
 #include "common.h"
 
-typedef struct {
-       char palette_id;
-       char start_row;
-       char dummy0[14];  // bytes 2-15 reserved, must be zero
-       short tile_colsb[8];
-       char dummy1[16];  // bytes 32-47 reserved, must be zero
-       char tile_rows[8];
-       char dummy2[16];  // bytes 56-63 reserved, must be zero
-} tilecfg;
-
-/* tile0/tile1 -- A (m x 2k)
- * tile2/tile3 -- B (2k x n)
- * tile4-7 -- C (m x n)
- */
-#define TCONF(cfg, m, n, k2) \
-       memset(&cfg, 0, sizeof(tilecfg)); \
-       cfg.palette_id = 1; \
-       cfg.tile_rows[0] = m; \
-       cfg.tile_rows[1] = m; \
-       cfg.tile_rows[2] = k2>>1; \
-       cfg.tile_rows[3] = k2>>1; \
-       cfg.tile_rows[4] = m; \
-       cfg.tile_rows[5] = m; \
-       cfg.tile_rows[6] = m; \
-       cfg.tile_rows[7] = m; \
-       cfg.tile_colsb[0] = k2<<1; \
-       cfg.tile_colsb[1] = k2<<1; \
-       cfg.tile_colsb[2] = n * 4; \
-       cfg.tile_colsb[3] = n * 4; \
-       cfg.tile_colsb[4] = n * 4; \
-       cfg.tile_colsb[5] = n * 4; \
-       cfg.tile_colsb[6] = n * 4; \
-       cfg.tile_colsb[7] = n * 4; \
-       _tile_loadconfig(&cfg);
-
-/* CONFIG for handling k2 and odd tail at the same time
- * tile0 -- A (m x 2k)
- * tile1 -- A (m x 1)
- * tile2 -- B (2k x n)
- * tile3 -- B (1 x n)
- * tile4 -- C (m x n)
- */
-#define TCONF_TAIL(cfg, m, n, k2) \
-       memset(&cfg, 0, sizeof(tilecfg)); \
-       cfg.palette_id = 1; \
-       cfg.tile_rows[0] = m; \
-       cfg.tile_rows[1] = m; \
-       cfg.tile_rows[2] = k2>>1; \
-       cfg.tile_rows[3] = 1; \
-       cfg.tile_rows[4] = m; \
-       cfg.tile_colsb[0] = k2<<1; \
-       cfg.tile_colsb[1] = 4; \
-       cfg.tile_colsb[2] = n * 4; \
-       cfg.tile_colsb[3] = n * 4; \
-       cfg.tile_colsb[4] = n * 4; \
-       _tile_loadconfig(&cfg);
-
-#define T_A0   0
-#define T_A1   1
-#define T_B0   2
-#define T_B1   3
-#define T_C00  4
-#define T_C01  5
-#define T_C10  6
-#define T_C11  7
-
-// FIXME: gcc11 seem have problem in tile load/store address calc,
-// need to multiply with element size (2 or 4) here.
-#define LOAD_A(M, N) _tile_loadd(T_A##M, ptr_a##M, lda * 2)
-#define LOAD_A_TAIL(M, N) {\
-       __m256i ymm = _mm256_loadu_epi16(ptr_a##M); \
-       __m512i zmm = _mm512_cvtepu16_epi32(ymm); \
-       _mm512_storeu_epi16(tail_a + 16 * M, zmm); \
-       _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \
-}
-#define MASK_LOAD_A_TAIL(M, N) {\
-       __m256i ymm = _mm256_maskz_loadu_epi16(amask, ptr_a##M); \
-       __m512i zmm = _mm512_cvtepu16_epi32(ymm); \
-       _mm512_storeu_epi16(tail_a + 16 * M, zmm); \
-       _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \
-}
-#define LOAD_B(M, N) _tile_loadd(T_B##N, ptr_b##N, ldb * 2)
-#define LOAD_B_TAIL(M, N) {\
-       __m256i ymm = _mm256_loadu_epi16(ptr_b##N); \
-       __m512i zmm = _mm512_cvtepu16_epi32(ymm); \
-       _mm512_storeu_epi16(tail_b + 16 * N, zmm); \
-       _tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \
-}
-#define MASK_LOAD_B_TAIL(M, N) {\
-       __m256i ymm = _mm256_maskz_loadu_epi16(bmask, ptr_b##N); \
-       __m512i zmm = _mm512_cvtepu16_epi32(ymm); \
-       _mm512_storeu_epi16(tail_b + 16 * N, zmm); \
-       _tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \
-}
-
-#define LOAD_C(M, N) _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4)
-#define MATMUL(M, N) _tile_dpbf16ps(T_C##M##N, T_A##M, T_B##N)
-#define MATMUL_TAIL(M, N) _tile_dpbf16ps(T_C00, T_A##M, T_B##N)
-#define STORE_C(M, N) _tile_stored(T_C##M##N, ptr_c##M##N, ldc * 4)
+#define ALPHA_ONE
+#include "sbgemm_kernel_16x16_spr_tmpl.c"
+#undef ALPHA_ONE
+#include "sbgemm_kernel_16x16_spr_tmpl.c"
 
 
 int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOAT * iB, FLOAT * C, BLASLONG ldc)
@@ -140,287 +43,8 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
        A = iB;
        B = iA;
 
-       IFLOAT *ptr_a = A, *ptr_b = B;
-       IFLOAT *ptr_b0, *ptr_b1;
-       IFLOAT *ptr_a0, *ptr_a1;
-       FLOAT *ptr_c = C;
-       FLOAT *ptr_c00, *ptr_c01, *ptr_c10, *ptr_c11;
-
-       BLASLONG lda, ldb;
-       BLASLONG m_count = m;
-       BLASLONG n_count, k_count;
-
-
-       IFLOAT tail_a[32 * 2] __attribute__ ((aligned (64)));
-       IFLOAT tail_b[32 * 2] __attribute__ ((aligned (64)));
-       tilecfg cfg;
-
-       if (k < 32)
-               goto tail_k;
-
-       for (; m_count > 31; m_count -= 32) {
-               ptr_b = B;
-
-               ptr_c00 = ptr_c;
-               ptr_c01 = ptr_c00 + 16;
-               ptr_c10 = ptr_c + 16 * ldc;
-               ptr_c11 = ptr_c10 + 16;
-               ptr_c += 32 * ldc;
-               n_count = n;
-               TCONF(cfg, 16, 16, 32);
-               for (; n_count > 31; n_count -= 32) {
-                       ptr_a0 = ptr_a;
-                       ptr_a1 = ptr_a + 16 * k;
-
-                       ptr_b0 = ptr_b;
-                       ptr_b1 = ptr_b + 16 * k;
-                       ptr_b += 32 * k;
-
-                       lda = 32;
-                       ldb = 32;
-                       LOAD_C(0, 0); LOAD_C(0, 1);
-                       LOAD_C(1, 0); LOAD_C(1, 1);
-                       k_count = k;
-                       for (; k_count > 31; k_count -= 32) {
-                               LOAD_A(0, x); LOAD_A(1, x);
-                               ptr_a0 += 16 * 32;
-                               ptr_a1 += 16 * 32;
-                               LOAD_B(x, 0); LOAD_B(x, 1);
-                               ptr_b0 += 16 * 32;
-                               ptr_b1 += 16 * 32;
-
-                               MATMUL(0, 0); MATMUL(0, 1);
-                               MATMUL(1, 0); MATMUL(1, 1);
-                       }
-                       STORE_C(0, 0); STORE_C(0, 1);
-                       STORE_C(1, 0); STORE_C(1, 1);
-                       ptr_c00 += 32;
-                       ptr_c01 += 32;
-                       ptr_c10 += 32;
-                       ptr_c11 += 32;
-               }
-               for (; n_count > 0; n_count -= 16) {
-                       int tail_n = (n_count > 16) ? 16: n_count;
-                       __mmask16 bmask = (1UL << tail_n) - 1;
-                       ptr_a0 = ptr_a;
-                       ptr_a1 = ptr_a + 16 * k;
-
-                       ptr_b0 = ptr_b;
-                       ptr_b += tail_n * k;
-
-                       lda = 32;
-                       ldb = 2 * tail_n;
-                       TCONF(cfg, 16, tail_n, 32);
-                       LOAD_C(0, 0);
-                       LOAD_C(1, 0);
-                       k_count = k;
-                       for (; k_count > 31; k_count -= 32) {
-                               LOAD_A(0, x); LOAD_A(1, x);
-                               ptr_a0 += 16 * 32;
-                               ptr_a1 += 16 * 32;
-                               LOAD_B(x, 0);
-                               ptr_b0 += tail_n * 32;
-
-                               MATMUL(0, 0);
-                               MATMUL(1, 0);
-                       }
-                       STORE_C(0, 0);
-                       STORE_C(1, 0);
-                       ptr_c00 += tail_n;
-                       ptr_c10 += tail_n;
-               }
-               ptr_a += 32 * k;
-       }
-       for (; m_count > 0; m_count -= 16) {
-               // process at most 16 m at a time
-               int tail_m = (m_count > 16) ? 16: m_count;
-               __mmask16 amask = (1UL << tail_m) - 1;
-
-               ptr_b = B;
-
-               ptr_c00 = ptr_c;
-               ptr_c01 = ptr_c00 + 16;
-               ptr_c += tail_m * ldc;
-               n_count = n;
-               TCONF(cfg, tail_m, 16, 32);
-               for (; n_count > 31; n_count -= 32) {
-                       ptr_a0 = ptr_a;
-
-                       ptr_b0 = ptr_b;
-                       ptr_b1 = ptr_b + 16 * k;
-                       ptr_b += 32 * k;
-
-                       lda = 32;
-                       ldb = 32;
-                       LOAD_C(0, 0); LOAD_C(0, 1);
-                       k_count = k;
-                       for (; k_count > 31; k_count -= 32) {
-                               LOAD_A(0, x);
-                               ptr_a0 += tail_m * 32;
-                               LOAD_B(x, 0); LOAD_B(x, 1);
-                               ptr_b0 += 16 * 32;
-                               ptr_b1 += 16 * 32;
-
-                               MATMUL(0, 0); MATMUL(0, 1);
-                       }
-                       STORE_C(0, 0); STORE_C(0, 1);
-                       ptr_c00 += 32;
-                       ptr_c01 += 32;
-               }
-               for (; n_count > 0; n_count -= 16) {
-                       int tail_n = (n_count > 16) ? 16: n_count;
-                       __mmask16 bmask = (1UL << tail_n) - 1;
-                       ptr_a0 = ptr_a;
-
-                       ptr_b0 = ptr_b;
-                       ptr_b += tail_n * k;
-
-                       lda = 32;
-                       ldb = 2 * tail_n;
-                       TCONF(cfg, tail_m, tail_n, 32);
-                       LOAD_C(0, 0);
-                       k_count = k;
-                       for (; k_count > 31; k_count -= 32) {
-                               LOAD_A(0, x);
-                               ptr_a0 += tail_m * 32;
-                               LOAD_B(x, 0);
-                               ptr_b0 += tail_n * 32;
-
-                               MATMUL(0, 0);
-                       }
-                       STORE_C(0, 0);
-                       ptr_c00 += tail_n;
-               }
-               ptr_a += tail_m * k;
-       }
-
-tail_k:
-       // process for k < 32
-       BLASLONG k32 = k & ~31;
-       BLASLONG k2 = k & ~1;
-       if (k32 != k) {
-               int remain_k2 = k2 - k32;
-               m_count = m;
-               ptr_a = A;
-               ptr_c = C;
-               if (remain_k2 > 0 && k2 != k) { // k%32 = 2x + 1 (x != 0)
-                       for (; m_count > 0; m_count -= 16) {
-                               int tail_m = (m_count > 16) ? 16: m_count;
-                               __mmask16 amask = (1UL << tail_m) - 1;
-
-                               ptr_a0 = ptr_a + tail_m * k32;
-                               ptr_a1 = ptr_a + tail_m * k2;
-                               ptr_a += tail_m * k;
-                               ptr_b = B;
-                               ptr_c00 = ptr_c;
-                               ptr_c += tail_m * ldc;
-                               n_count = n;
-                               lda = remain_k2;
-                               ldb = 32;
-                               if (n_count > 15) {
-                                       TCONF_TAIL(cfg, tail_m, 16, remain_k2);
-                                       LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x);
-                                       for (; n_count > 15; n_count -= 16) {
-                                               ptr_b0 = ptr_b + 16 * k32;
-                                               ptr_b1 = ptr_b + 16 * k2;
-                                               LOAD_C(0, 0);
-                                               LOAD_B(x, 0); LOAD_B_TAIL(x, 1);
-                                               MATMUL(0, 0); MATMUL_TAIL(1, 1);
-                                               STORE_C(0, 0);
-                                               ptr_b += 16 * k;
-                                               ptr_c00 += 16;
-                                       }
-                               }
-                               if (n_count > 0) {
-                                       int tail_n = (n_count > 16) ? 16: n_count;
-                                       __mmask16 bmask = (1UL << tail_n) - 1;
-                                       ptr_b0 =  ptr_b + tail_n * k32;
-                                       ptr_b1 =  ptr_b + tail_n * k2;
-                                       ldb = 2 * tail_n;
-                                       TCONF_TAIL(cfg, tail_m, tail_n, remain_k2);
-                                       LOAD_C(0, 0);
-                                       LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x);
-                                       LOAD_B(x, 0); MASK_LOAD_B_TAIL(x, 1);
-                                       MATMUL(0, 0); MATMUL_TAIL(1, 1);
-                                       STORE_C(0, 0);
-                               }
-                       }
-
-               } else if (remain_k2 > 0) { // k%32 = 2x
-                       for (; m_count > 0; m_count -= 16) {
-                               int tail_m = (m_count > 16) ? 16: m_count;
-
-                               ptr_a0 = ptr_a + tail_m * k32;
-                               ptr_a += tail_m * k;
-                               ptr_b = B;
-                               ptr_c00 = ptr_c;
-                               ptr_c += tail_m * ldc;
-                               n_count = n;
-                               lda = remain_k2;
-                               ldb = 32;
-                               if (n_count > 15) {
-                                       TCONF(cfg, tail_m, 16, remain_k2);
-                                       LOAD_A(0, x);
-                                       for (; n_count > 15; n_count -= 16) {
-                                               ptr_b0 = ptr_b + 16 * k32;
-                                               LOAD_C(0, 0);
-                                               LOAD_B(x, 0);
-                                               MATMUL(0, 0);
-                                               STORE_C(0, 0);
-                                               ptr_b += 16 * k;
-                                               ptr_c00 += 16;
-                                       }
-                               }
-                               if (n_count > 0) {
-                                       int tail_n = (n_count > 16) ? 16: n_count;
-                                       ptr_b0 =  ptr_b + tail_n * k32;
-                                       ldb = 2 * tail_n;
-                                       TCONF(cfg, tail_m, tail_n, remain_k2);
-                                       LOAD_C(0, 0);
-                                       LOAD_A(0, x);
-                                       LOAD_B(x, 0);
-                                       MATMUL(0, 0);
-                                       STORE_C(0, 0);
-                               }
-                       }
-               } else { // k%32 = 1
-                       for (; m_count > 0; m_count -= 16) {
-                               int tail_m = (m_count > 16) ? 16: m_count;
-                               __mmask16 amask = (1UL << tail_m) - 1;
-
-                               ptr_a0 = ptr_a + tail_m * k2;
-                               ptr_a += tail_m * k;
-                               ptr_b = B;
-                               ptr_c00 = ptr_c;
-                               ptr_c += tail_m * ldc;
-                               n_count = n;
-                               if (n_count > 15) {
-                                       TCONF(cfg, tail_m, 16, 2);
-                                       MASK_LOAD_A_TAIL(0, x);
-                                       for (; n_count > 15; n_count -= 16) {
-                                               ptr_b0 = ptr_b + 16 * k2;
-                                               LOAD_C(0, 0);
-                                               LOAD_B_TAIL(x, 0);
-                                               MATMUL(0, 0);
-                                               STORE_C(0, 0);
-                                               ptr_b += 16 * k;
-                                               ptr_c00 += 16;
-                                       }
-                               }
-                               if (n_count > 0) {
-                                       int tail_n = (n_count > 16) ? 16: n_count;
-                                       __mmask16 bmask = (1UL << tail_n) - 1;
-                                       ptr_b0 =  ptr_b + tail_n * k2;
-                                       TCONF(cfg, tail_m, tail_n, 2);
-                                       LOAD_C(0, 0);
-                                       MASK_LOAD_A_TAIL(0, x);
-                                       MASK_LOAD_B_TAIL(x, 0);
-                                       MATMUL(0, 0);
-                                       STORE_C(0, 0);
-                               }
-                       }
-
-               }
-       }
-       return 0;
+       if (alpha == 1.0f)
+               return sbgemm_kernel_spr_alpha_one(m, n, k, alpha, A, B, C, ldc);
+       else
+               return sbgemm_kernel_spr_alpha(m, n, k, alpha, A, B, C, ldc);
 }
diff --git a/kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c b/kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c
new file mode 100644 (file)
index 0000000..465b9eb
--- /dev/null
@@ -0,0 +1,521 @@
+/***************************************************************************
+ * Copyright (c) 2021, The OpenBLAS Project
+ * All rights reserved.
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ * 1. Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in
+ * the documentation and/or other materials provided with the
+ * distribution.
+ * 3. Neither the name of the OpenBLAS project nor the names of
+ * its contributors may be used to endorse or promote products
+ * derived from this software without specific prior written permission.
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+ * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+ * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ * *****************************************************************************/
+
+#include <immintrin.h>
+#include <string.h>
+#include "common.h"
+
+#ifndef SBGEMM_KERNEL_SPR
+#define SBGEMM_KERNEL_SPR
+typedef struct {
+       char palette_id;
+       char start_row;
+       char dummy0[14];  // bytes 2-15 reserved, must be zero
+       short tile_colsb[8];
+       char dummy1[16];  // bytes 32-47 reserved, must be zero
+       char tile_rows[8];
+       char dummy2[16];  // bytes 56-63 reserved, must be zero
+} tilecfg;
+
+/* tile0/tile1 -- A (m x 2k)
+ * tile2/tile3 -- B (2k x n)
+ * tile4-7 -- C (m x n)
+ */
+#define TCONF(cfg, m, n, k2) \
+       memset(&cfg, 0, sizeof(tilecfg)); \
+       cfg.palette_id = 1; \
+       cfg.tile_rows[0] = m; \
+       cfg.tile_rows[1] = m; \
+       cfg.tile_rows[2] = k2>>1; \
+       cfg.tile_rows[3] = k2>>1; \
+       cfg.tile_rows[4] = m; \
+       cfg.tile_rows[5] = m; \
+       cfg.tile_rows[6] = m; \
+       cfg.tile_rows[7] = m; \
+       cfg.tile_colsb[0] = k2<<1; \
+       cfg.tile_colsb[1] = k2<<1; \
+       cfg.tile_colsb[2] = n * 4; \
+       cfg.tile_colsb[3] = n * 4; \
+       cfg.tile_colsb[4] = n * 4; \
+       cfg.tile_colsb[5] = n * 4; \
+       cfg.tile_colsb[6] = n * 4; \
+       cfg.tile_colsb[7] = n * 4; \
+       _tile_loadconfig(&cfg);
+
+/* CONFIG for handling k2 and odd tail at the same time
+ * tile0 -- A (m x 2k)
+ * tile1 -- A (m x 1)
+ * tile2 -- B (2k x n)
+ * tile3 -- B (1 x n)
+ * tile4 -- C (m x n)
+ */
+#define TCONF_TAIL(cfg, m, n, k2) \
+       memset(&cfg, 0, sizeof(tilecfg)); \
+       cfg.palette_id = 1; \
+       cfg.tile_rows[0] = m; \
+       cfg.tile_rows[1] = m; \
+       cfg.tile_rows[2] = k2>>1; \
+       cfg.tile_rows[3] = 1; \
+       cfg.tile_rows[4] = m; \
+       cfg.tile_colsb[0] = k2<<1; \
+       cfg.tile_colsb[1] = 4; \
+       cfg.tile_colsb[2] = n * 4; \
+       cfg.tile_colsb[3] = n * 4; \
+       cfg.tile_colsb[4] = n * 4; \
+       _tile_loadconfig(&cfg);
+
+#define T_A0   0
+#define T_A1   1
+#define T_B0   2
+#define T_B1   3
+#define T_C00  4
+#define T_C01  5
+#define T_C10  6
+#define T_C11  7
+
+// FIXME: gcc11 seem have problem in tile load/store address calc,
+// need to multiply with element size (2 or 4) here.
+#define LOAD_A(M, N) _tile_loadd(T_A##M, ptr_a##M, lda * 2)
+#define LOAD_A_TAIL(M, N) {\
+       __m256i ymm = _mm256_loadu_epi16(ptr_a##M); \
+       __m512i zmm = _mm512_cvtepu16_epi32(ymm); \
+       _mm512_storeu_epi16(tail_a + 16 * M, zmm); \
+       _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \
+}
+#define MASK_LOAD_A_TAIL(M, N) {\
+       __m256i ymm = _mm256_maskz_loadu_epi16(amask, ptr_a##M); \
+       __m512i zmm = _mm512_cvtepu16_epi32(ymm); \
+       _mm512_storeu_epi16(tail_a + 16 * M, zmm); \
+       _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \
+}
+#define LOAD_B(M, N) _tile_loadd(T_B##N, ptr_b##N, ldb * 2)
+#define LOAD_B_TAIL(M, N) {\
+       __m256i ymm = _mm256_loadu_epi16(ptr_b##N); \
+       __m512i zmm = _mm512_cvtepu16_epi32(ymm); \
+       _mm512_storeu_epi16(tail_b + 16 * N, zmm); \
+       _tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \
+}
+#define MASK_LOAD_B_TAIL(M, N) {\
+       __m256i ymm = _mm256_maskz_loadu_epi16(bmask, ptr_b##N); \
+       __m512i zmm = _mm512_cvtepu16_epi32(ymm); \
+       _mm512_storeu_epi16(tail_b + 16 * N, zmm); \
+       _tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \
+}
+
+#define MATMUL(M, N) _tile_dpbf16ps(T_C##M##N, T_A##M, T_B##N)
+#define MATMUL_TAIL(M, N) _tile_dpbf16ps(T_C00, T_A##M, T_B##N)
+#define STORE_C(M, N) _tile_stored(T_C##M##N, ptr_c##M##N, ldc * 4)
+#define LOAD_C_F(M, N) _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4)
+
+#endif  // end of SBGEMM_KERNEL_SPR
+
+#ifdef ALPHA_ONE
+#undef LOAD_C
+#define LOAD_C(M, N) _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4)
+#else
+#undef LOAD_C
+#define LOAD_C(M, N) _tile_zero(T_C##M##N)
+#define ALPHA_STORE(N) \
+       __m512 zmm_d##N = _mm512_loadu_ps(dst##N + noffset); \
+       __m512 zmm_s##N = _mm512_loadu_ps(src##N + noffset); \
+       zmm_d##N = _mm512_fmadd_ps(alpha_512, zmm_s##N, zmm_d##N); \
+       _mm512_storeu_ps(dst##N + noffset, zmm_d##N);
+#define MASK_APLPHA_STORE(N) \
+       __m512 zmm_d##N = _mm512_maskz_loadu_ps(mask, dst##N + noffset); \
+       __m512 zmm_s##N = _mm512_maskz_loadu_ps(mask, src##N + noffset); \
+       zmm_d##N = _mm512_fmadd_ps(alpha_512, zmm_s##N, zmm_d##N); \
+       _mm512_mask_storeu_ps(dst##N + noffset, mask, zmm_d##N);
+#endif // end of ALPHA_ONE
+
+
+#ifdef ALPHA_ONE
+int sbgemm_kernel_spr_alpha_one(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc)
+#else
+int sbgemm_kernel_spr_alpha(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc)
+#endif
+{
+       /* Row Major matrix for AMX requirement */
+       IFLOAT *ptr_a = A, *ptr_b = B;
+       IFLOAT *ptr_b0, *ptr_b1;
+       IFLOAT *ptr_a0, *ptr_a1;
+       FLOAT *ptr_c = C;
+       FLOAT *ptr_c00, *ptr_c01, *ptr_c10, *ptr_c11;
+
+       BLASLONG lda, ldb;
+       BLASLONG m_count = m;
+       BLASLONG n_count, k_count;
+
+#ifndef ALPHA_ONE
+       FLOAT *tmp_c = malloc(sizeof(FLOAT) * m * n);
+       memset(tmp_c, 0, sizeof(FLOAT) * m * n);
+       ptr_c = tmp_c;
+       BLASLONG ldc_o = ldc;
+       ldc = n;
+#endif
+       IFLOAT tail_a[32 * 2] __attribute__ ((aligned (64)));
+       IFLOAT tail_b[32 * 2] __attribute__ ((aligned (64)));
+       tilecfg cfg;
+
+       if (k > 31) {
+               for (; m_count > 31; m_count -= 32) {
+                       ptr_b = B;
+
+                       ptr_c00 = ptr_c;
+                       ptr_c01 = ptr_c00 + 16;
+                       ptr_c10 = ptr_c + 16 * ldc;
+                       ptr_c11 = ptr_c10 + 16;
+                       ptr_c += 32 * ldc;
+                       n_count = n;
+                       TCONF(cfg, 16, 16, 32);
+                       for (; n_count > 31; n_count -= 32) {
+                               ptr_a0 = ptr_a;
+                               ptr_a1 = ptr_a + 16 * k;
+
+                               ptr_b0 = ptr_b;
+                               ptr_b1 = ptr_b + 16 * k;
+                               ptr_b += 32 * k;
+
+                               lda = 32;
+                               ldb = 32;
+                               LOAD_C(0, 0); LOAD_C(0, 1);
+                               LOAD_C(1, 0); LOAD_C(1, 1);
+                               k_count = k;
+                               for (; k_count > 31; k_count -= 32) {
+                                       LOAD_A(0, x); LOAD_A(1, x);
+                                       ptr_a0 += 16 * 32;
+                                       ptr_a1 += 16 * 32;
+                                       LOAD_B(x, 0); LOAD_B(x, 1);
+                                       ptr_b0 += 16 * 32;
+                                       ptr_b1 += 16 * 32;
+
+                                       MATMUL(0, 0); MATMUL(0, 1);
+                                       MATMUL(1, 0); MATMUL(1, 1);
+                               }
+                               STORE_C(0, 0); STORE_C(0, 1);
+                               STORE_C(1, 0); STORE_C(1, 1);
+                               ptr_c00 += 32;
+                               ptr_c01 += 32;
+                               ptr_c10 += 32;
+                               ptr_c11 += 32;
+                       }
+                       for (; n_count > 0; n_count -= 16) {
+                               int tail_n = (n_count > 16) ? 16: n_count;
+                               ptr_a0 = ptr_a;
+                               ptr_a1 = ptr_a + 16 * k;
+
+                               ptr_b0 = ptr_b;
+                               ptr_b += tail_n * k;
+
+                               lda = 32;
+                               ldb = 2 * tail_n;
+                               TCONF(cfg, 16, tail_n, 32);
+                               LOAD_C(0, 0);
+                               LOAD_C(1, 0);
+                               k_count = k;
+                               for (; k_count > 31; k_count -= 32) {
+                                       LOAD_A(0, x); LOAD_A(1, x);
+                                       ptr_a0 += 16 * 32;
+                                       ptr_a1 += 16 * 32;
+                                       LOAD_B(x, 0);
+                                       ptr_b0 += tail_n * 32;
+
+                                       MATMUL(0, 0);
+                                       MATMUL(1, 0);
+                               }
+                               STORE_C(0, 0);
+                               STORE_C(1, 0);
+                               ptr_c00 += tail_n;
+                               ptr_c10 += tail_n;
+                       }
+                       ptr_a += 32 * k;
+               }
+               for (; m_count > 0; m_count -= 16) {
+                       // process at most 16 m at a time
+                       int tail_m = (m_count > 16) ? 16: m_count;
+
+                       ptr_b = B;
+
+                       ptr_c00 = ptr_c;
+                       ptr_c01 = ptr_c00 + 16;
+                       ptr_c += tail_m * ldc;
+                       n_count = n;
+                       TCONF(cfg, tail_m, 16, 32);
+                       for (; n_count > 31; n_count -= 32) {
+                               ptr_a0 = ptr_a;
+
+                               ptr_b0 = ptr_b;
+                               ptr_b1 = ptr_b + 16 * k;
+                               ptr_b += 32 * k;
+
+                               lda = 32;
+                               ldb = 32;
+                               LOAD_C(0, 0); LOAD_C(0, 1);
+                               k_count = k;
+                               for (; k_count > 31; k_count -= 32) {
+                                       LOAD_A(0, x);
+                                       ptr_a0 += tail_m * 32;
+                                       LOAD_B(x, 0); LOAD_B(x, 1);
+                                       ptr_b0 += 16 * 32;
+                                       ptr_b1 += 16 * 32;
+
+                                       MATMUL(0, 0); MATMUL(0, 1);
+                               }
+                               STORE_C(0, 0); STORE_C(0, 1);
+                               ptr_c00 += 32;
+                               ptr_c01 += 32;
+                       }
+                       for (; n_count > 0; n_count -= 16) {
+                               int tail_n = (n_count > 16) ? 16: n_count;
+                               ptr_a0 = ptr_a;
+
+                               ptr_b0 = ptr_b;
+                               ptr_b += tail_n * k;
+
+                               lda = 32;
+                               ldb = 2 * tail_n;
+                               TCONF(cfg, tail_m, tail_n, 32);
+                               LOAD_C(0, 0);
+                               k_count = k;
+                               for (; k_count > 31; k_count -= 32) {
+                                       LOAD_A(0, x);
+                                       ptr_a0 += tail_m * 32;
+                                       LOAD_B(x, 0);
+                                       ptr_b0 += tail_n * 32;
+
+                                       MATMUL(0, 0);
+                               }
+                               STORE_C(0, 0);
+                               ptr_c00 += tail_n;
+                       }
+                       ptr_a += tail_m * k;
+               }
+       }
+
+       // process for k < 32
+       BLASLONG k32 = k & ~31;
+       BLASLONG k2 = k & ~1;
+       if (k32 != k) {
+               int remain_k2 = k2 - k32;
+               m_count = m;
+               ptr_a = A;
+#ifndef ALPHA_ONE
+               ptr_c = tmp_c;
+#else
+               ptr_c = C;
+#endif
+               if (remain_k2 > 0 && k2 != k) { // k%32 = 2x + 1 (x != 0)
+                       for (; m_count > 0; m_count -= 16) {
+                               int tail_m = (m_count > 16) ? 16: m_count;
+                               __mmask16 amask = (1UL << tail_m) - 1;
+
+                               ptr_a0 = ptr_a + tail_m * k32;
+                               ptr_a1 = ptr_a + tail_m * k2;
+                               ptr_a += tail_m * k;
+                               ptr_b = B;
+                               ptr_c00 = ptr_c;
+                               ptr_c += tail_m * ldc;
+                               n_count = n;
+                               lda = remain_k2;
+                               ldb = 32;
+                               if (n_count > 15) {
+                                       TCONF_TAIL(cfg, tail_m, 16, remain_k2);
+                                       LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x);
+                                       for (; n_count > 15; n_count -= 16) {
+                                               ptr_b0 = ptr_b + 16 * k32;
+                                               ptr_b1 = ptr_b + 16 * k2;
+                                               LOAD_C_F(0, 0);
+                                               LOAD_B(x, 0); LOAD_B_TAIL(x, 1);
+                                               MATMUL(0, 0); MATMUL_TAIL(1, 1);
+                                               STORE_C(0, 0);
+                                               ptr_b += 16 * k;
+                                               ptr_c00 += 16;
+                                       }
+                               }
+                               if (n_count > 0) {
+                                       int tail_n = (n_count > 16) ? 16: n_count;
+                                       __mmask16 bmask = (1UL << tail_n) - 1;
+                                       ptr_b0 =  ptr_b + tail_n * k32;
+                                       ptr_b1 =  ptr_b + tail_n * k2;
+                                       ldb = 2 * tail_n;
+                                       TCONF_TAIL(cfg, tail_m, tail_n, remain_k2);
+                                       LOAD_C_F(0, 0);
+                                       LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x);
+                                       LOAD_B(x, 0); MASK_LOAD_B_TAIL(x, 1);
+                                       MATMUL(0, 0); MATMUL_TAIL(1, 1);
+                                       STORE_C(0, 0);
+                               }
+                       }
+
+               } else if (remain_k2 > 0) { // k%32 = 2x
+                       for (; m_count > 0; m_count -= 16) {
+                               int tail_m = (m_count > 16) ? 16: m_count;
+
+                               ptr_a0 = ptr_a + tail_m * k32;
+                               ptr_a += tail_m * k;
+                               ptr_b = B;
+                               ptr_c00 = ptr_c;
+                               ptr_c += tail_m * ldc;
+                               n_count = n;
+                               lda = remain_k2;
+                               ldb = 32;
+                               if (n_count > 15) {
+                                       TCONF(cfg, tail_m, 16, remain_k2);
+                                       LOAD_A(0, x);
+                                       for (; n_count > 15; n_count -= 16) {
+                                               ptr_b0 = ptr_b + 16 * k32;
+                                               LOAD_C_F(0, 0);
+                                               LOAD_B(x, 0);
+                                               MATMUL(0, 0);
+                                               STORE_C(0, 0);
+                                               ptr_b += 16 * k;
+                                               ptr_c00 += 16;
+                                       }
+                               }
+                               if (n_count > 0) {
+                                       int tail_n = (n_count > 16) ? 16: n_count;
+                                       ptr_b0 =  ptr_b + tail_n * k32;
+                                       ldb = 2 * tail_n;
+                                       TCONF(cfg, tail_m, tail_n, remain_k2);
+                                       LOAD_C_F(0, 0);
+                                       LOAD_A(0, x);
+                                       LOAD_B(x, 0);
+                                       MATMUL(0, 0);
+                                       STORE_C(0, 0);
+                               }
+                       }
+               } else { // k%32 = 1
+                       for (; m_count > 0; m_count -= 16) {
+                               int tail_m = (m_count > 16) ? 16: m_count;
+                               __mmask16 amask = (1UL << tail_m) - 1;
+
+                               ptr_a0 = ptr_a + tail_m * k2;
+                               ptr_a += tail_m * k;
+                               ptr_b = B;
+                               ptr_c00 = ptr_c;
+                               ptr_c += tail_m * ldc;
+                               n_count = n;
+                               if (n_count > 15) {
+                                       TCONF(cfg, tail_m, 16, 2);
+                                       MASK_LOAD_A_TAIL(0, x);
+                                       for (; n_count > 15; n_count -= 16) {
+                                               ptr_b0 = ptr_b + 16 * k2;
+                                               LOAD_C_F(0, 0);
+                                               LOAD_B_TAIL(x, 0);
+                                               MATMUL(0, 0);
+                                               STORE_C(0, 0);
+                                               ptr_b += 16 * k;
+                                               ptr_c00 += 16;
+                                       }
+                               }
+                               if (n_count > 0) {
+                                       int tail_n = (n_count > 16) ? 16: n_count;
+                                       __mmask16 bmask = (1UL << tail_n) - 1;
+                                       ptr_b0 =  ptr_b + tail_n * k2;
+                                       TCONF(cfg, tail_m, tail_n, 2);
+                                       LOAD_C_F(0, 0);
+                                       MASK_LOAD_A_TAIL(0, x);
+                                       MASK_LOAD_B_TAIL(x, 0);
+                                       MATMUL(0, 0);
+                                       STORE_C(0, 0);
+                               }
+                       }
+
+               }
+       }
+#ifndef ALPHA_ONE
+       __m512 alpha_512 = _mm512_broadcastss_ps(_mm_load_ss(&alpha));
+       BLASLONG n16 = n & ~15;
+       BLASLONG noffset;
+       FLOAT *src0, *src1, *src2, *src3;
+       FLOAT *dst0, *dst1, *dst2, *dst3;
+       FLOAT *src = tmp_c;
+       FLOAT *dst = C;
+       m_count = m;
+       for (; m_count > 3; m_count -= 4) {
+               src0 = src;
+               src1 = src0 + ldc;
+               src2 = src1 + ldc;
+               src3 = src2 + ldc;
+               src += 4 * ldc;
+
+               dst0 = dst;
+               dst1 = dst0 + ldc_o;
+               dst2 = dst1 + ldc_o;
+               dst3 = dst2 + ldc_o;
+               dst += 4 * ldc_o;
+
+               noffset = 0;
+               for (; noffset < n16; noffset += 16) {
+                       ALPHA_STORE(0);
+                       ALPHA_STORE(1);
+                       ALPHA_STORE(2);
+                       ALPHA_STORE(3);
+               }
+               if (noffset < n) {
+                       __mmask16 mask = (1UL << (n - noffset)) - 1;
+                       MASK_APLPHA_STORE(0);
+                       MASK_APLPHA_STORE(1);
+                       MASK_APLPHA_STORE(2);
+                       MASK_APLPHA_STORE(3);
+               }
+       }
+       for (; m_count > 1; m_count -= 2) {
+               src0 = src;
+               src1 = src0 + ldc;
+               src += 2 * ldc;
+
+               dst0 = dst;
+               dst1 = dst0 + ldc_o;
+               dst += 2 * ldc_o;
+
+               noffset = 0;
+               for (; noffset < n16; noffset += 16) {
+                       ALPHA_STORE(0);
+                       ALPHA_STORE(1);
+               }
+               if (noffset < n) {
+                       __mmask16 mask = (1UL << (n - noffset)) - 1;
+                       MASK_APLPHA_STORE(0);
+                       MASK_APLPHA_STORE(1);
+               }
+       }
+       for (; m_count > 0; m_count -= 1) {
+               src0 = src;
+               dst0 = dst;
+               noffset = 0;
+               for (; noffset < n16; noffset += 16) {
+                       ALPHA_STORE(0);
+               }
+               if (noffset < n) {
+                       __mmask16 mask = (1UL << (n - noffset)) - 1;
+                       MASK_APLPHA_STORE(0);
+               }
+       }
+       free(tmp_c);
+#endif
+       return 0;
+}