Add all SBGEMM kernels for IA AVX512-BF16 based platforms
authorChen, Guobing <guobing.chen@intel.com>
Thu, 5 Aug 2021 03:11:14 +0000 (11:11 +0800)
committerChen, Guobing <guobing.chen@intel.com>
Thu, 5 Aug 2021 03:11:29 +0000 (11:11 +0800)
Added all SBGEMM kernels including NN/NT/TN/TT for both ColMajor and
RowMajor, based on AVX512-BF16 ISA set on IA.

Signed-off-by: Chen, Guobing <guobing.chen@intel.com>
kernel/x86_64/bf16_common_macros.h
kernel/x86_64/sbgemm_block_microk_cooperlake.c
kernel/x86_64/sbgemm_microk_cooperlake_template.c

index 1014ecc..78db7ab 100644 (file)
@@ -29,6 +29,16 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include <immintrin.h>
 
+#define _MM512_BROADCASTD_EPI32(addr, zmm)             \
+                    __asm__ ("vpbroadcastd (%1), %0;"  \
+                            : "=v" (zmm)               \
+                            : "r"  (addr) )
+
+#define PREFETCH_T0(addr)             \
+                    __asm__ ("prefetcht0 (%0);"  \
+                            :                    \
+                            : "r"  (addr) )
+
 #define EXTRACT_LOW_256_FROM_512_2X(reg256, reg512)   \
     reg256##_0 = _mm512_castps512_ps256(reg512##_0);  \
     reg256##_1 = _mm512_castps512_ps256(reg512##_1);
@@ -721,6 +731,48 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     _mm_mask_storeu_ps(targetAddr, mask, regResult);
 
 
+/* Store 16 (result + y) to y
+*/
+#define STORE16_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr)          \
+    regResult = _mm512_add_ps(regResult, _mm512_loadu_ps(targetAddr));  \
+    _mm512_storeu_ps(targetAddr, regResult);
+
+
+/* Masked store 16 (result + y) to y
+*/
+#define STORE16_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask)           \
+    regResult = _mm512_add_ps(regResult, _mm512_maskz_loadu_ps(mask, targetAddr));  \
+    _mm512_mask_storeu_ps(targetAddr, mask, regResult);
+
+
+/* Store 8 (result + y) to y
+*/
+#define STORE8_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr)           \
+    regResult = _mm256_add_ps(regResult, _mm256_loadu_ps(targetAddr));  \
+    _mm256_storeu_ps(targetAddr, regResult);
+
+
+/* Masked store 8 (result + y) to y
+*/
+#define STORE8_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask)            \
+    regResult = _mm256_add_ps(regResult, _mm256_maskz_loadu_ps(mask, targetAddr));  \
+    _mm256_mask_storeu_ps(targetAddr, mask, regResult);
+
+
+/* Store 4 (result + y) to y
+*/
+#define STORE4_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr)     \
+    regResult = _mm_add_ps(regResult, _mm_loadu_ps(targetAddr));  \
+    _mm_storeu_ps(targetAddr, regResult);
+
+
+/* Masked store 4 (result + y) to y
+*/
+#define STORE4_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask)      \
+    regResult = _mm_add_ps(regResult, _mm_maskz_loadu_ps(mask, targetAddr));  \
+    _mm_mask_storeu_ps(targetAddr, mask, regResult);
+
+
 /* Store 16 (alpha * result) to y
 */
 #define STORE16_COMPLETE_RESULT_ALPHA(regResult, targetAddr)  \
index 2376fed..147c5eb 100644 (file)
@@ -1,4 +1,4 @@
-#include "sbgemm.h"
+//#include "sbgemm.h"
 
 #include <immintrin.h>
 // Walk around those intrinsics that missed by compiler
 #define MM256_STOREU_EPI16(addr, reg)  \
             _mm256_mask_storeu_epi16((addr), ~0, (reg))
 
-#include <stdio.h>
-void print_block(BLASLONG m, BLASLONG n, bfloat16 * mat)
-{
-    printf("---- BLOCK %ld x %ld ----\n", m, n);
-    for (BLASLONG i=0; i<m; i++) {
-        for (BLASLONG j=0; j<n; j++) {
-            printf("%-4X  ", *(mat + i*n +j));
-        }
-        printf("\n");
-    }
-    printf("---- End of BLOCK ----\n");
-}
-
-void COL_MAJOR_INCOPY_KERNEL_Kx32(BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
+// INCOPY Kernel, 16<M<=32, k can be any number
+void COL_MAJOR_INCOPY_KERNEL_Kx32(BLASLONG k, BLASLONG m, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
 {
     BLASLONG tag_k_2x = k & (~1);
+    unsigned int tail_mask = (((unsigned int)0xffffffff) >> (32-m));
 
     __m512i array512_0, array512_1, array512_2, array512_3;
 
-    BLASLONG idx_src_base0, idx_src_base1;
-    BLASLONG idx_target_base0, idx_target_base1;
+    bfloat16 * src_addr0, * src_addr1;
+    bfloat16 * dst_addr0, * dst_addr1;
 
     BLASLONG LDA_2x = 2*lda;
     BLASLONG BF16_BLOCK_T_M_2x = 2*32;
-    idx_src_base0 = 0;
-    idx_src_base1 = lda;
-    idx_target_base0 = 0;
-    idx_target_base1 = 32;
-    for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) {
-        array512_0 = _mm512_loadu_si512(&A[idx_src_base0]);
-        array512_1 = _mm512_loadu_si512(&A[idx_src_base1]);
-        array512_2 = _mm512_unpacklo_epi16(array512_0, array512_1);
-        array512_3 = _mm512_unpackhi_epi16(array512_0, array512_1);
-        _mm512_storeu_si512(&block_A[idx_target_base0], array512_2);
-        _mm512_storeu_si512(&block_A[idx_target_base1], array512_3);
-
-        idx_src_base0 += LDA_2x;
-        idx_src_base1 += LDA_2x;
-        idx_target_base0 += BF16_BLOCK_T_M_2x;
-        idx_target_base1 += BF16_BLOCK_T_M_2x;
-    }
-
-    if (tag_k_2x != k) {
-        __m512i ZERO512 = _mm512_setzero_si512();
-        array512_0 = _mm512_loadu_si512(&A[idx_src_base0]);
-        array512_2 = _mm512_unpacklo_epi16(array512_0, ZERO512);
-        array512_3 = _mm512_unpackhi_epi16(array512_0, ZERO512);
-        _mm512_storeu_si512(&block_A[idx_target_base0], array512_2);
-        _mm512_storeu_si512(&block_A[idx_target_base1], array512_3);
-   }
-
-#ifdef DEBUG_PROFILE
-   print_block(BF16_BLOCK_THRES_K, BF16_BLOCK_THRES_M, block_A);
-#endif
-}
-
-void COL_MAJOR_INCOPY_KERNEL_Kx32m(BLASLONG k, BLASLONG m, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
-{
-    BLASLONG tag_k_2x = k & (~1);
-    unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-m));
-    __mmask32 tail_mask = *((__mmask32*) &tail_mask_value);
-
-    __m512i array512_0, array512_1, array512_2, array512_3;
 
-    BLASLONG idx_src_base0, idx_src_base1;
-    BLASLONG idx_target_base0, idx_target_base1;
+    src_addr0 = A;
+    src_addr1 = A + lda;
+    dst_addr0 = block_A;
+    dst_addr1 = block_A + 32;
 
-    BLASLONG LDA_2x = 2*lda;
-    BLASLONG BF16_BLOCK_T_M_2x = 2*32;
-    idx_src_base0 = 0;
-    idx_src_base1 = lda;
-    idx_target_base0 = 0;
-    idx_target_base1 = 32;
     for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) {
-        array512_0 = _mm512_maskz_loadu_epi16(tail_mask, &A[idx_src_base0]);
-        array512_1 = _mm512_maskz_loadu_epi16(tail_mask, &A[idx_src_base1]);
+        array512_0 = _mm512_maskz_loadu_epi16(tail_mask, src_addr0);
+        array512_1 = _mm512_maskz_loadu_epi16(tail_mask, src_addr1);
         array512_2 = _mm512_unpacklo_epi16(array512_0, array512_1);
         array512_3 = _mm512_unpackhi_epi16(array512_0, array512_1);
-        _mm512_storeu_si512(&block_A[idx_target_base0], array512_2);
-        _mm512_storeu_si512(&block_A[idx_target_base1], array512_3);
+        _mm512_storeu_si512(dst_addr0, array512_2);
+        _mm512_storeu_si512(dst_addr1, array512_3);
 
-        idx_src_base0 += LDA_2x;
-        idx_src_base1 += LDA_2x;
-        idx_target_base0 += BF16_BLOCK_T_M_2x;
-        idx_target_base1 += BF16_BLOCK_T_M_2x;
+        src_addr0 += LDA_2x;
+        src_addr1 += LDA_2x;
+        dst_addr0 += BF16_BLOCK_T_M_2x;
+        dst_addr1 += BF16_BLOCK_T_M_2x;
     }
 
     if (tag_k_2x != k) {
         __m512i ZERO512 = _mm512_setzero_si512();
-        array512_0 = _mm512_maskz_loadu_epi16(tail_mask, &A[idx_src_base0]);
+        array512_0 = _mm512_maskz_loadu_epi16(tail_mask, src_addr0);
         array512_2 = _mm512_unpacklo_epi16(array512_0, ZERO512);
         array512_3 = _mm512_unpackhi_epi16(array512_0, ZERO512);
-        _mm512_storeu_si512(&block_A[idx_target_base0], array512_2);
-        _mm512_storeu_si512(&block_A[idx_target_base1], array512_3);
+        _mm512_storeu_si512(dst_addr0, array512_2);
+        _mm512_storeu_si512(dst_addr1, array512_3);
     }
-
-#ifdef DEBUG_PROFILE
-   print_block(BF16_BLOCK_THRES_K, BF16_BLOCK_THRES_M, block_A);
-#endif
 }
 
+// INCOPY Kernel, 0<M<=16, k can be any number
 void COL_MAJOR_INCOPY_KERNEL_Kx16(BLASLONG k, BLASLONG m, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
 {
     BLASLONG tag_k_2x = k & (~1);
+    unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m));
 
     __m256i array256_0, array256_1, array256_2, array256_3;
 
-    BLASLONG idx_src_base0, idx_src_base1;
-    BLASLONG idx_target_base0;
+    bfloat16 * src_addr0, * src_addr1;
+    bfloat16 * dst_addr0;
 
     BLASLONG LDA_2x = 2*lda;
-    idx_src_base0 = 0;
-    idx_src_base1 = lda;
-    idx_target_base0 = 0;
+
+    src_addr0 = A;
+    src_addr1 = A + lda;
+    dst_addr0 = block_A;
+
     for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) {
-        array256_0 = MM256_LOADU_EPI16(&A[idx_src_base0]);
-        array256_1 = MM256_LOADU_EPI16(&A[idx_src_base1]);
+        array256_0 = _mm256_maskz_loadu_epi16(tail_mask, src_addr0);
+        array256_1 = _mm256_maskz_loadu_epi16(tail_mask, src_addr1);
         array256_2 = _mm256_unpacklo_epi16(array256_0, array256_1);
         array256_3 = _mm256_unpackhi_epi16(array256_0, array256_1);
         // Store in one row of block_B
-        MM256_STOREU_EPI16(&block_A[idx_target_base0],      array256_2);
-        MM256_STOREU_EPI16(&block_A[idx_target_base0 + 16], array256_3);
+        MM256_STOREU_EPI16(dst_addr0,    array256_2);
+        MM256_STOREU_EPI16(dst_addr0+16, array256_3);
 
-        idx_src_base0 += LDA_2x;
-        idx_src_base1 += LDA_2x;
-        idx_target_base0 += 32;
+        src_addr0 += LDA_2x;
+        src_addr1 += LDA_2x;
+        dst_addr0 += 32;
     }
 
     if (tag_k_2x != k) {
         __m256i ZERO256 = _mm256_setzero_si256();
-        array256_0 = MM256_LOADU_EPI16(&A[idx_src_base0]);
+        array256_0 = _mm256_maskz_loadu_epi16(tail_mask, src_addr0);
         array256_2 = _mm256_unpacklo_epi16(array256_0, ZERO256);
         array256_3 = _mm256_unpackhi_epi16(array256_0, ZERO256);
         // Store in one row of block_B
-        MM256_STOREU_EPI16(&block_A[idx_target_base0],      array256_2);
-        MM256_STOREU_EPI16(&block_A[idx_target_base0 + 16], array256_3);
+        MM256_STOREU_EPI16(dst_addr0,    array256_2);
+        MM256_STOREU_EPI16(dst_addr0+16, array256_3);
     }
+}
 
-#ifdef DEBUG_PROFILE
-   print_block(BF16_BLOCK_THRES_K, BF16_BLOCK_THRES_M, block_A);
-#endif
+// K=32, M=16
+void COL_MAJOR_ITCOPY_KERNEL_32x16(bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
+{
+    bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3;
+    bfloat16 * dst_addr0, * dst_addr1;
+
+    BLASLONG LDA_4x = lda*4;
+
+    src_addr0 = A;
+    src_addr1 = A + lda;
+    src_addr2 = A + lda*2;
+    src_addr3 = A + lda*3;
+    dst_addr0 = block_A;
+    dst_addr1 = block_A + 32*8;
+
+    __m512i array512_0, array512_1, array512_2, array512_3;
+    __m512i array512_way0_0, array512_way0_1, array512_way0_2, array512_way0_3;
+    __m512i array512_way1_0, array512_way1_1, array512_way1_2, array512_way1_3;
+    __m512i array512_way2_0, array512_way2_1, array512_way2_2, array512_way2_3;
+    __m512i array512_way3_0, array512_way3_1, array512_way3_2, array512_way3_3;
+
+    __m512i M512_EPI64_2   = _mm512_set1_epi64(2);
+    __m512i permute_lo_idx = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
+    __m512i permute_hi_idx = _mm512_add_epi64(permute_lo_idx, M512_EPI64_2);
+
+    // Load and preprocess 1st 4 rows
+    array512_way0_0 = _mm512_loadu_si512(src_addr0);
+    array512_way0_1 = _mm512_loadu_si512(src_addr1);
+    array512_way0_2 = _mm512_loadu_si512(src_addr2);
+    array512_way0_3 = _mm512_loadu_si512(src_addr3);
+    array512_0 = _mm512_unpacklo_epi32(array512_way0_0, array512_way0_1);
+    array512_1 = _mm512_unpackhi_epi32(array512_way0_0, array512_way0_1);
+    array512_2 = _mm512_unpacklo_epi32(array512_way0_2, array512_way0_3);
+    array512_3 = _mm512_unpackhi_epi32(array512_way0_2, array512_way0_3);
+    array512_way0_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+    array512_way0_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+    array512_way0_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+    array512_way0_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+    src_addr0 += LDA_4x;
+    src_addr1 += LDA_4x;
+    src_addr2 += LDA_4x;
+    src_addr3 += LDA_4x;
+
+    // Load and preprocess 2nd 4 rows
+    array512_way1_0 = _mm512_loadu_si512(src_addr0);
+    array512_way1_1 = _mm512_loadu_si512(src_addr1);
+    array512_way1_2 = _mm512_loadu_si512(src_addr2);
+    array512_way1_3 = _mm512_loadu_si512(src_addr3);
+    array512_0 = _mm512_unpacklo_epi32(array512_way1_0, array512_way1_1);
+    array512_1 = _mm512_unpackhi_epi32(array512_way1_0, array512_way1_1);
+    array512_2 = _mm512_unpacklo_epi32(array512_way1_2, array512_way1_3);
+    array512_3 = _mm512_unpackhi_epi32(array512_way1_2, array512_way1_3);
+    array512_way1_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+    array512_way1_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+    array512_way1_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+    array512_way1_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+    src_addr0 += LDA_4x;
+    src_addr1 += LDA_4x;
+    src_addr2 += LDA_4x;
+    src_addr3 += LDA_4x;
+
+    // Load and preprocess 3rd 4 rows
+    array512_way2_0 = _mm512_loadu_si512(src_addr0);
+    array512_way2_1 = _mm512_loadu_si512(src_addr1);
+    array512_way2_2 = _mm512_loadu_si512(src_addr2);
+    array512_way2_3 = _mm512_loadu_si512(src_addr3);
+    array512_0 = _mm512_unpacklo_epi32(array512_way2_0, array512_way2_1);
+    array512_1 = _mm512_unpackhi_epi32(array512_way2_0, array512_way2_1);
+    array512_2 = _mm512_unpacklo_epi32(array512_way2_2, array512_way2_3);
+    array512_3 = _mm512_unpackhi_epi32(array512_way2_2, array512_way2_3);
+    array512_way2_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+    array512_way2_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+    array512_way2_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+    array512_way2_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+    src_addr0 += LDA_4x;
+    src_addr1 += LDA_4x;
+    src_addr2 += LDA_4x;
+    src_addr3 += LDA_4x;
+
+    // Load and preprocess 4th 4 rows
+    array512_way3_0 = _mm512_loadu_si512(src_addr0);
+    array512_way3_1 = _mm512_loadu_si512(src_addr1);
+    array512_way3_2 = _mm512_loadu_si512(src_addr2);
+    array512_way3_3 = _mm512_loadu_si512(src_addr3);
+    array512_0 = _mm512_unpacklo_epi32(array512_way3_0, array512_way3_1);
+    array512_1 = _mm512_unpackhi_epi32(array512_way3_0, array512_way3_1);
+    array512_2 = _mm512_unpacklo_epi32(array512_way3_2, array512_way3_3);
+    array512_3 = _mm512_unpackhi_epi32(array512_way3_2, array512_way3_3);
+    array512_way3_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+    array512_way3_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+    array512_way3_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+    array512_way3_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+
+    // Compose and store the 0/1 and 16/17 cols
+    array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_lo_idx, array512_way1_0);
+    array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_lo_idx, array512_way3_0);
+    array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+    array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_1, 0x1), 0x0);
+    _mm512_storeu_si512(dst_addr0, array512_2);
+    _mm512_storeu_si512(dst_addr1, array512_3);
+    dst_addr0 += 32;
+    dst_addr1 += 32;
+
+    // Compose and store the 2/3 and 18/19 cols
+    array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_lo_idx, array512_way1_1);
+    array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_lo_idx, array512_way3_1);
+    array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+    array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_1, 0x1), 0x0);
+    _mm512_storeu_si512(dst_addr0, array512_2);
+    _mm512_storeu_si512(dst_addr1, array512_3);
+    dst_addr0 += 32;
+    dst_addr1 += 32;
+
+    // Compose and store the 4/5 and 20/21 cols
+    array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_lo_idx, array512_way1_2);
+    array512_1 = _mm512_permutex2var_epi64(array512_way2_2, permute_lo_idx, array512_way3_2);
+    array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+    array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_1, 0x1), 0x0);
+    _mm512_storeu_si512(dst_addr0, array512_2);
+    _mm512_storeu_si512(dst_addr1, array512_3);
+    dst_addr0 += 32;
+    dst_addr1 += 32;
+
+    // Compose and store the 6/7 and 22/23 cols
+    array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_lo_idx, array512_way1_3);
+    array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_lo_idx, array512_way3_3);
+    array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+    array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_1, 0x1), 0x0);
+    _mm512_storeu_si512(dst_addr0, array512_2);
+    _mm512_storeu_si512(dst_addr1, array512_3);
+    dst_addr0 += 32;
+    dst_addr1 += 32;
+
+    // Compose and store the 8/9 and 24/25 cols
+    array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_hi_idx, array512_way1_0);
+    array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_hi_idx, array512_way3_0);
+    array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+    array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_1, 0x1), 0x0);
+    _mm512_storeu_si512(dst_addr0, array512_2);
+    _mm512_storeu_si512(dst_addr1, array512_3);
+    dst_addr0 += 32;
+    dst_addr1 += 32;
+
+    // Compose and store the 10/11 and 26/27 cols
+    array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_hi_idx, array512_way1_1);
+    array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_hi_idx, array512_way3_1);
+    array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+    array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_1, 0x1), 0x0);
+    _mm512_storeu_si512(dst_addr0, array512_2);
+    _mm512_storeu_si512(dst_addr1, array512_3);
+    dst_addr0 += 32;
+    dst_addr1 += 32;
+
+    // Compose and store the 12/13 and 28/29 cols
+    array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_hi_idx, array512_way1_2);
+    array512_1 = _mm512_permutex2var_epi64(array512_way2_2, permute_hi_idx, array512_way3_2);
+    array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+    array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_1, 0x1), 0x0);
+    _mm512_storeu_si512(dst_addr0, array512_2);
+    _mm512_storeu_si512(dst_addr1, array512_3);
+    dst_addr0 += 32;
+    dst_addr1 += 32;
+
+    // Compose and store the 14/15 and 30/31 cols
+    array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_hi_idx, array512_way1_3);
+    array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_hi_idx, array512_way3_3);
+    array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+    array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_1, 0x1), 0x0);
+    _mm512_storeu_si512(dst_addr0, array512_2);
+    _mm512_storeu_si512(dst_addr1, array512_3);
 }
 
-void COL_MAJOR_INCOPY_KERNEL_Kx16m(BLASLONG k, BLASLONG m, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
+// K=Any number but will be processed based on 32, M=32
+void COL_MAJOR_ITCOPY_KERNEL_Kx32(BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
 {
-    BLASLONG tag_k_2x = k & (~1);
-    unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-m));
-    __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
+    bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3;
+    bfloat16 * dst_addr0, * dst_addr1;
 
-    __m256i array256_0, array256_1, array256_2, array256_3;
+    BLASLONG tag_k_32x = k & (~31);
 
-    BLASLONG idx_src_base0, idx_src_base1;
-    BLASLONG idx_target_base0;
+    BLASLONG LDA_4x  = lda*4;
+    BLASLONG LDA_8x  = lda*8;
+    BLASLONG LDA_12x = lda*12;
+    BLASLONG LDA_16x = lda*16;
 
-    BLASLONG LDA_2x = 2*lda;
-    idx_src_base0 = 0;
-    idx_src_base1 = lda;
-    idx_target_base0 = 0;
-    for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) {
-        array256_0 = _mm256_maskz_loadu_epi16(tail_mask, &A[idx_src_base0]);
-        array256_1 = _mm256_maskz_loadu_epi16(tail_mask, &A[idx_src_base1]);
-        array256_2 = _mm256_unpacklo_epi16(array256_0, array256_1);
-        array256_3 = _mm256_unpackhi_epi16(array256_0, array256_1);
-        // Store in one row of block_B
-        MM256_STOREU_EPI16(&block_A[idx_target_base0],      array256_2);
-        MM256_STOREU_EPI16(&block_A[idx_target_base0 + 16], array256_3);
+    src_addr0 = A;
+    src_addr1 = A + lda;
+    src_addr2 = A + lda*2;
+    src_addr3 = A + lda*3;
+    dst_addr0 = block_A;
+    dst_addr1 = block_A + 32*16;
 
-        idx_src_base0 += LDA_2x;
-        idx_src_base1 += LDA_2x;
-        idx_target_base0 += 32;
+    __m512i array512_0, array512_1, array512_2, array512_3;
+    __m512i array512_way0_0, array512_way0_1, array512_way0_2, array512_way0_3;
+    __m512i array512_way1_0, array512_way1_1, array512_way1_2, array512_way1_3;
+    __m512i array512_way2_0, array512_way2_1, array512_way2_2, array512_way2_3;
+    __m512i array512_way3_0, array512_way3_1, array512_way3_2, array512_way3_3;
+
+    __m512i M512_EPI64_2   = _mm512_set1_epi64(2);
+    __m512i permute_lo_idx = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
+    __m512i permute_hi_idx = _mm512_add_epi64(permute_lo_idx, M512_EPI64_2);
+
+    for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
+        for (int i = 0; i < 2; i++) {
+            // Load and preprocess 1st 4 rows
+            array512_way0_0 = _mm512_loadu_si512(src_addr0+idx_k);
+            array512_way0_1 = _mm512_loadu_si512(src_addr1+idx_k);
+            array512_way0_2 = _mm512_loadu_si512(src_addr2+idx_k);
+            array512_way0_3 = _mm512_loadu_si512(src_addr3+idx_k);
+            array512_0 = _mm512_unpacklo_epi32(array512_way0_0, array512_way0_1);
+            array512_1 = _mm512_unpackhi_epi32(array512_way0_0, array512_way0_1);
+            array512_2 = _mm512_unpacklo_epi32(array512_way0_2, array512_way0_3);
+            array512_3 = _mm512_unpackhi_epi32(array512_way0_2, array512_way0_3);
+            array512_way0_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+            array512_way0_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+            array512_way0_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+            array512_way0_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+
+            // Load and preprocess 2nd 4 rows
+            array512_way1_0 = _mm512_loadu_si512(src_addr0+LDA_4x+idx_k);
+            array512_way1_1 = _mm512_loadu_si512(src_addr1+LDA_4x+idx_k);
+            array512_way1_2 = _mm512_loadu_si512(src_addr2+LDA_4x+idx_k);
+            array512_way1_3 = _mm512_loadu_si512(src_addr3+LDA_4x+idx_k);
+            array512_0 = _mm512_unpacklo_epi32(array512_way1_0, array512_way1_1);
+            array512_1 = _mm512_unpackhi_epi32(array512_way1_0, array512_way1_1);
+            array512_2 = _mm512_unpacklo_epi32(array512_way1_2, array512_way1_3);
+            array512_3 = _mm512_unpackhi_epi32(array512_way1_2, array512_way1_3);
+            array512_way1_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+            array512_way1_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+            array512_way1_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+            array512_way1_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+
+            // Load and preprocess 3rd 4 rows
+            array512_way2_0 = _mm512_loadu_si512(src_addr0+LDA_8x+idx_k);
+            array512_way2_1 = _mm512_loadu_si512(src_addr1+LDA_8x+idx_k);
+            array512_way2_2 = _mm512_loadu_si512(src_addr2+LDA_8x+idx_k);
+            array512_way2_3 = _mm512_loadu_si512(src_addr3+LDA_8x+idx_k);
+            array512_0 = _mm512_unpacklo_epi32(array512_way2_0, array512_way2_1);
+            array512_1 = _mm512_unpackhi_epi32(array512_way2_0, array512_way2_1);
+            array512_2 = _mm512_unpacklo_epi32(array512_way2_2, array512_way2_3);
+            array512_3 = _mm512_unpackhi_epi32(array512_way2_2, array512_way2_3);
+            array512_way2_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+            array512_way2_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+            array512_way2_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+            array512_way2_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+
+            // Load and preprocess 4th 4 rows
+            array512_way3_0 = _mm512_loadu_si512(src_addr0+LDA_12x+idx_k);
+            array512_way3_1 = _mm512_loadu_si512(src_addr1+LDA_12x+idx_k);
+            array512_way3_2 = _mm512_loadu_si512(src_addr2+LDA_12x+idx_k);
+            array512_way3_3 = _mm512_loadu_si512(src_addr3+LDA_12x+idx_k);
+            array512_0 = _mm512_unpacklo_epi32(array512_way3_0, array512_way3_1);
+            array512_1 = _mm512_unpackhi_epi32(array512_way3_0, array512_way3_1);
+            array512_2 = _mm512_unpacklo_epi32(array512_way3_2, array512_way3_3);
+            array512_3 = _mm512_unpackhi_epi32(array512_way3_2, array512_way3_3);
+            array512_way3_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+            array512_way3_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+            array512_way3_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+            array512_way3_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+
+            // Compose and store the 0/1 and 16/17 cols
+            array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_lo_idx, array512_way1_0);
+            array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_lo_idx, array512_way3_0);
+            array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+            array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+            _mm512_storeu_si512(dst_addr0, array512_2);
+            _mm512_storeu_si512(dst_addr1, array512_3);
+            dst_addr0 += 64;
+            dst_addr1 += 64;
+
+            // Compose and store the 2/3 and 18/19 cols
+            array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_lo_idx, array512_way1_1);
+            array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_lo_idx, array512_way3_1);
+            array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+            array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+            _mm512_storeu_si512(dst_addr0, array512_2);
+            _mm512_storeu_si512(dst_addr1, array512_3);
+            dst_addr0 += 64;
+            dst_addr1 += 64;
+
+            // Compose and store the 4/5 and 20/21 cols
+            array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_lo_idx, array512_way1_2);
+            array512_1 = _mm512_permutex2var_epi64(array512_way2_2, permute_lo_idx, array512_way3_2);
+            array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+            array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+            _mm512_storeu_si512(dst_addr0, array512_2);
+            _mm512_storeu_si512(dst_addr1, array512_3);
+            dst_addr0 += 64;
+            dst_addr1 += 64;
+
+            // Compose and store the 6/7 and 22/23 cols
+            array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_lo_idx, array512_way1_3);
+            array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_lo_idx, array512_way3_3);
+            array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+            array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+            _mm512_storeu_si512(dst_addr0, array512_2);
+            _mm512_storeu_si512(dst_addr1, array512_3);
+            dst_addr0 += 64;
+            dst_addr1 += 64;
+
+            // Compose and store the 8/9 and 24/25 cols
+            array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_hi_idx, array512_way1_0);
+            array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_hi_idx, array512_way3_0);
+            array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+            array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+            _mm512_storeu_si512(dst_addr0, array512_2);
+            _mm512_storeu_si512(dst_addr1, array512_3);
+            dst_addr0 += 64;
+            dst_addr1 += 64;
+
+            // Compose and store the 10/11 and 26/27 cols
+            array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_hi_idx, array512_way1_1);
+            array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_hi_idx, array512_way3_1);
+            array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+            array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+            _mm512_storeu_si512(dst_addr0, array512_2);
+            _mm512_storeu_si512(dst_addr1, array512_3);
+            dst_addr0 += 64;
+            dst_addr1 += 64;
+
+            // Compose and store the 12/13 and 28/29 cols
+            array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_hi_idx, array512_way1_2);
+            array512_1 = _mm512_permutex2var_epi64(array512_way2_2, permute_hi_idx, array512_way3_2);
+            array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+            array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+            _mm512_storeu_si512(dst_addr0, array512_2);
+            _mm512_storeu_si512(dst_addr1, array512_3);
+            dst_addr0 += 64;
+            dst_addr1 += 64;
+
+            // Compose and store the 14/15 and 30/31 cols
+            array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_hi_idx, array512_way1_3);
+            array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_hi_idx, array512_way3_3);
+            array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+            array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+            _mm512_storeu_si512(dst_addr0, array512_2);
+            _mm512_storeu_si512(dst_addr1, array512_3);
+
+            src_addr0 += LDA_16x;
+            src_addr1 += LDA_16x;
+            src_addr2 += LDA_16x;
+            src_addr3 += LDA_16x;
+            dst_addr0 -= (64*7 - 32);
+            dst_addr1 -= (64*7 - 32);
+        }
+        src_addr0 -= (LDA_16x*2);
+        src_addr1 -= (LDA_16x*2);
+        src_addr2 -= (LDA_16x*2);
+        src_addr3 -= (LDA_16x*2);
+        dst_addr0 += (32*30);
+        dst_addr1 += (32*30);
     }
 
-    if (tag_k_2x != k) {
-        __m256i ZERO256 = _mm256_setzero_si256();
-        array256_0 = _mm256_maskz_loadu_epi16(tail_mask, &A[idx_src_base0]);
-        array256_2 = _mm256_unpacklo_epi16(array256_0, ZERO256);
-        array256_3 = _mm256_unpackhi_epi16(array256_0, ZERO256);
-        // Store in one row of block_B
-        MM256_STOREU_EPI16(&block_A[idx_target_base0],      array256_2);
-        MM256_STOREU_EPI16(&block_A[idx_target_base0 + 16], array256_3);
+    if (tag_k_32x != k) {
+        int k_rem = k - tag_k_32x;
+        unsigned int tail_mask = (((unsigned int)0xffffffff) >> (32-k_rem));
+        __m512i array512[16];
+
+        bfloat16 * dst_addr_tmp = dst_addr0;
+
+        for (int i = 0; i < 2; i++) {
+            // Load and preprocess 1st 4 rows
+            array512[0] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x);
+            array512[1] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x);
+            array512[2] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x);
+            array512[3] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x);
+            array512_0 = _mm512_unpacklo_epi32(array512[0], array512[1]);
+            array512_1 = _mm512_unpackhi_epi32(array512[0], array512[1]);
+            array512_2 = _mm512_unpacklo_epi32(array512[2], array512[3]);
+            array512_3 = _mm512_unpackhi_epi32(array512[2], array512[3]);
+            array512[0] = _mm512_unpacklo_epi64(array512_0, array512_2);
+            array512[1] = _mm512_unpackhi_epi64(array512_0, array512_2);
+            array512[2] = _mm512_unpacklo_epi64(array512_1, array512_3);
+            array512[3] = _mm512_unpackhi_epi64(array512_1, array512_3);
+            src_addr0 += LDA_4x;
+            src_addr1 += LDA_4x;
+            src_addr2 += LDA_4x;
+            src_addr3 += LDA_4x;
+
+            // Load and preprocess 2nd 4 rows
+            array512[4] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x);
+            array512[5] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x);
+            array512[6] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x);
+            array512[7] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x);
+            array512_0 = _mm512_unpacklo_epi32(array512[4], array512[5]);
+            array512_1 = _mm512_unpackhi_epi32(array512[4], array512[5]);
+            array512_2 = _mm512_unpacklo_epi32(array512[6], array512[7]);
+            array512_3 = _mm512_unpackhi_epi32(array512[6], array512[7]);
+            array512[4] = _mm512_unpacklo_epi64(array512_0, array512_2);
+            array512[5] = _mm512_unpackhi_epi64(array512_0, array512_2);
+            array512[6] = _mm512_unpacklo_epi64(array512_1, array512_3);
+            array512[7] = _mm512_unpackhi_epi64(array512_1, array512_3);
+            src_addr0 += LDA_4x;
+            src_addr1 += LDA_4x;
+            src_addr2 += LDA_4x;
+            src_addr3 += LDA_4x;
+
+            // Load and preprocess 3rd 4 rows
+            array512[8]  = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x);
+            array512[9]  = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x);
+            array512[10] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x);
+            array512[11] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x);
+            array512_0 = _mm512_unpacklo_epi32(array512[8],  array512[9]);
+            array512_1 = _mm512_unpackhi_epi32(array512[8],  array512[9]);
+            array512_2 = _mm512_unpacklo_epi32(array512[10], array512[11]);
+            array512_3 = _mm512_unpackhi_epi32(array512[10], array512[11]);
+            array512[8]  = _mm512_unpacklo_epi64(array512_0, array512_2);
+            array512[9]  = _mm512_unpackhi_epi64(array512_0, array512_2);
+            array512[10] = _mm512_unpacklo_epi64(array512_1, array512_3);
+            array512[11] = _mm512_unpackhi_epi64(array512_1, array512_3);
+            src_addr0 += LDA_4x;
+            src_addr1 += LDA_4x;
+            src_addr2 += LDA_4x;
+            src_addr3 += LDA_4x;
+
+            // Load and preprocess 4th 4 rows
+            array512[12] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x);
+            array512[13] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x);
+            array512[14] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x);
+            array512[15] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x);
+            array512_0 = _mm512_unpacklo_epi32(array512[12], array512[13]);
+            array512_1 = _mm512_unpackhi_epi32(array512[12], array512[13]);
+            array512_2 = _mm512_unpacklo_epi32(array512[14], array512[15]);
+            array512_3 = _mm512_unpackhi_epi32(array512[14], array512[15]);
+            array512[12] = _mm512_unpacklo_epi64(array512_0, array512_2);
+            array512[13] = _mm512_unpackhi_epi64(array512_0, array512_2);
+            array512[14] = _mm512_unpacklo_epi64(array512_1, array512_3);
+            array512[15] = _mm512_unpackhi_epi64(array512_1, array512_3);
+            src_addr0 += LDA_4x;
+            src_addr1 += LDA_4x;
+            src_addr2 += LDA_4x;
+            src_addr3 += LDA_4x;
+
+            // array512_01_1617_0, array512_01_1617_1, array512_89_2425_0, array512_89_2425_1;
+            // Half-compose of 0/1, 16/17, 8/9, 24/25 cols
+            array512_0 = _mm512_permutex2var_epi64(array512[0], permute_lo_idx, array512[4]);
+            array512_1 = _mm512_permutex2var_epi64(array512[8], permute_lo_idx, array512[12]);
+            array512_2 = _mm512_permutex2var_epi64(array512[0], permute_hi_idx, array512[4]);
+            array512_3 = _mm512_permutex2var_epi64(array512[8], permute_hi_idx, array512[12]);
+            array512[0]  = array512_0;  // 1st 8 pairs of col 0/1,   and 1st 8 pairs of col 16/17
+            array512[4]  = array512_1;  // 2nd 8 pairs of col 0/1,   and 2nd 8 pairs of col 16/17
+            array512[8]  = array512_2;  // 1st 8 pairs of col 8/9,   and 1st 8 pairs of col 24/25
+            array512[12] = array512_3;  // 2nd 8 pairs of col 8/9,   and 2nd 8 pairs of col 24/25
+
+            // Half-compose of 2/3, 18/19, 10/11, 26/27 cols
+            array512_0 = _mm512_permutex2var_epi64(array512[1], permute_lo_idx, array512[5]);
+            array512_1 = _mm512_permutex2var_epi64(array512[9], permute_lo_idx, array512[13]);
+            array512_2 = _mm512_permutex2var_epi64(array512[1], permute_hi_idx, array512[5]);
+            array512_3 = _mm512_permutex2var_epi64(array512[9], permute_hi_idx, array512[13]);
+            array512[1]  = array512_0;  // 1st 8 pairs of col 2/3,   and 1st 8 pairs of col 18/19
+            array512[5]  = array512_1;  // 2nd 8 pairs of col 2/3,   and 2nd 8 pairs of col 18/19
+            array512[9]  = array512_2;  // 1st 8 pairs of col 10/11, and 1st 8 pairs of col 26/27
+            array512[13] = array512_3;  // 2nd 8 pairs of col 10/11, and 2nd 8 pairs of col 26/27
+
+            // Half-compose of 4/5, 20/21, 12/13, 28/29 cols
+            array512_0 = _mm512_permutex2var_epi64(array512[2],  permute_lo_idx, array512[6]);
+            array512_1 = _mm512_permutex2var_epi64(array512[10], permute_lo_idx, array512[14]);
+            array512_2 = _mm512_permutex2var_epi64(array512[2],  permute_hi_idx, array512[6]);
+            array512_3 = _mm512_permutex2var_epi64(array512[10], permute_hi_idx, array512[14]);
+            array512[2]  = array512_0;  // 1st 8 pairs of col 4/5,   and 1st 8 pairs of col 20/21
+            array512[6]  = array512_1;  // 2nd 8 pairs of col 4/5,   and 2nd 8 pairs of col 20/21
+            array512[10] = array512_2;  // 1st 8 pairs of col 12/13, and 1st 8 pairs of col 28/29
+            array512[14] = array512_3;  // 2nd 8 pairs of col 12/13, and 2nd 8 pairs of col 28/29
+
+            // Half-compose of 6/7, 22/23, 14/15, 30/31 cols
+            array512_0 = _mm512_permutex2var_epi64(array512[3],  permute_lo_idx, array512[7]);
+            array512_1 = _mm512_permutex2var_epi64(array512[11], permute_lo_idx, array512[15]);
+            array512_2 = _mm512_permutex2var_epi64(array512[3],  permute_hi_idx, array512[7]);
+            array512_3 = _mm512_permutex2var_epi64(array512[11], permute_hi_idx, array512[15]);
+            array512[3]  = array512_0;  // 1st 8 pairs of col 6/7,   and 1st 8 pairs of col 22/23
+            array512[7]  = array512_1;  // 2nd 8 pairs of col 6/7,   and 2nd 8 pairs of col 22/23
+            array512[11] = array512_2;  // 1st 8 pairs of col 14/15, and 1st 8 pairs of col 30/31
+            array512[15] = array512_3;  // 2nd 8 pairs of col 14/15, and 2nd 8 pairs of col 30/31
+
+            // Compose and store the 0/1 cols
+            array512_0 = _mm512_inserti64x4(array512[0], _mm512_castsi512_si256(array512[4]), 0x1);
+            _mm512_storeu_si512(dst_addr0, array512_0);
+            dst_addr0 += 64;
+
+            // Compose and store the 2/3 cols
+            array512_0 = _mm512_inserti64x4(array512[1], _mm512_castsi512_si256(array512[5]), 0x1);
+            _mm512_storeu_si512(dst_addr0, array512_0);
+            dst_addr0 += 64;
+
+            // Compose and store the 4/5 cols
+            array512_0 = _mm512_inserti64x4(array512[2], _mm512_castsi512_si256(array512[6]), 0x1);
+            _mm512_storeu_si512(dst_addr0, array512_0);
+            dst_addr0 += 64;
+
+            // Compose and store the 6/7 cols
+            array512_0 = _mm512_inserti64x4(array512[3], _mm512_castsi512_si256(array512[7]), 0x1);
+            _mm512_storeu_si512(dst_addr0, array512_0);
+            dst_addr0 += 64;
+
+            // Compose and store the 8/9 cols
+            array512_0 = _mm512_inserti64x4(array512[8], _mm512_castsi512_si256(array512[12]), 0x1);
+            _mm512_storeu_si512(dst_addr0, array512_0);
+            dst_addr0 += 64;
+
+            // Compose and store the 10/11 cols
+            array512_0 = _mm512_inserti64x4(array512[9], _mm512_castsi512_si256(array512[13]), 0x1);
+            _mm512_storeu_si512(dst_addr0, array512_0);
+            dst_addr0 += 64;
+
+            // Compose and store the 12/13 cols
+            array512_0 = _mm512_inserti64x4(array512[10], _mm512_castsi512_si256(array512[14]), 0x1);
+            _mm512_storeu_si512(dst_addr0, array512_0);
+            dst_addr0 += 64;
+
+            // Compose and store the 14/15 cols
+            array512_0 = _mm512_inserti64x4(array512[11], _mm512_castsi512_si256(array512[15]), 0x1);
+            _mm512_storeu_si512(dst_addr0, array512_0);
+            dst_addr0 += 64;
+
+            // Compose and store 16 ~ k_rem cols
+            int idx_length = (k_rem + 1 - 16) >> 1;
+            if (idx_length > 4) {
+                for (int idx_k = 0; idx_k < 4; idx_k++) {
+                    array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0);
+                    _mm512_storeu_si512(dst_addr0, array512_0);
+                    dst_addr0 += 64;
+                }
+
+                for (int idx_k = 4; idx_k < idx_length; idx_k++) {
+                    array512_0 = _mm512_inserti64x4(array512[idx_k+8], _mm512_extracti64x4_epi64(array512[idx_k+4], 0x1), 0x0);
+                    _mm512_storeu_si512(dst_addr0, array512_0);
+                    dst_addr0 += 64;
+                }
+            } else {
+                for (int idx_k = 0; idx_k < idx_length; idx_k++) {
+                    array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0);
+                    _mm512_storeu_si512(dst_addr0, array512_0);
+                    dst_addr0 += 64;
+                }
+            }
+
+            dst_addr0 = dst_addr_tmp + 32;
+        }
+    }
+}
+
+// K=Any number but will be processed based on 32, 16<M<32
+void COL_MAJOR_ITCOPY_KERNEL_Kx32m(BLASLONG m, BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
+{
+    bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3;
+    bfloat16 * dst_addr0, * dst_addr1;
+
+    BLASLONG tag_k_32x = k & (~31);
+
+    BLASLONG LDA_4x  = lda*4;
+
+    BLASLONG m_rem = m-16;
+
+    src_addr0 = A;
+    src_addr1 = A + lda;
+    src_addr2 = A + lda*2;
+    src_addr3 = A + lda*3;
+    dst_addr0 = block_A;
+    dst_addr1 = block_A + 32*16;
+
+    __m512i array512_0, array512_1, array512_2, array512_3;
+    __m512i array512[16];
+
+    __m512i M512_EPI64_2   = _mm512_set1_epi64(2);
+    __m512i permute_lo_idx = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
+    __m512i permute_hi_idx = _mm512_add_epi64(permute_lo_idx, M512_EPI64_2);
+
+    for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
+        for (int j = 0; j < 4; j++) {
+            int array_idx = j*4;
+            // Load and preprocess 4 rows
+            array512[array_idx+0] = _mm512_loadu_si512(src_addr0+idx_k);
+            array512[array_idx+1] = _mm512_loadu_si512(src_addr1+idx_k);
+            array512[array_idx+2] = _mm512_loadu_si512(src_addr2+idx_k);
+            array512[array_idx+3] = _mm512_loadu_si512(src_addr3+idx_k);
+            array512_0 = _mm512_unpacklo_epi32(array512[array_idx+0], array512[array_idx+1]);
+            array512_1 = _mm512_unpackhi_epi32(array512[array_idx+0], array512[array_idx+1]);
+            array512_2 = _mm512_unpacklo_epi32(array512[array_idx+2], array512[array_idx+3]);
+            array512_3 = _mm512_unpackhi_epi32(array512[array_idx+2], array512[array_idx+3]);
+            array512[array_idx+0] = _mm512_unpacklo_epi64(array512_0, array512_2);
+            array512[array_idx+1] = _mm512_unpackhi_epi64(array512_0, array512_2);
+            array512[array_idx+2] = _mm512_unpacklo_epi64(array512_1, array512_3);
+            array512[array_idx+3] = _mm512_unpackhi_epi64(array512_1, array512_3);
+            src_addr0 += LDA_4x;
+            src_addr1 += LDA_4x;
+            src_addr2 += LDA_4x;
+            src_addr3 += LDA_4x;
+        }
+
+        // Compose and store the 0/1, 2/3, 4/5, 6/7 and 16/17, 18/19, 20/21, 22/23 cols
+        for (int j = 0; j < 4; j++) {
+            array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_lo_idx, array512[j+4]);
+            array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_lo_idx, array512[j+12]);
+            array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+            array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+            _mm512_storeu_si512(dst_addr0, array512_2);
+            _mm512_storeu_si512(dst_addr1, array512_3);
+            dst_addr0 += 64;
+            dst_addr1 += 64;
+        }
+
+        // Compose and store the 8/9, 10/11, 12/13, 14/15 and 24/25, 26/27, 28/29, 30/31 cols
+        for (int j = 0; j < 4; j++) {
+            array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_hi_idx, array512[j+4]);
+            array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_hi_idx, array512[j+12]);
+            array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+            array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+            _mm512_storeu_si512(dst_addr0, array512_2);
+            _mm512_storeu_si512(dst_addr1, array512_3);
+            dst_addr0 += 64;
+            dst_addr1 += 64;
+        }
+
+        dst_addr0 -= (64*8 - 32);
+        dst_addr1 -= (64*8 - 32);
+
+        for (int j = 0; j < m_rem; j++) {
+            array512[j] = _mm512_loadu_si512(src_addr0+j*lda+idx_k);
+        }
+        for (int j = m_rem; j < 16; j++) {
+            array512[j] = _mm512_setzero_si512();
+        }
+
+        for (int j = 0; j < 4; j++) {
+            int array_idx = j*4;
+            array512_0 = _mm512_unpacklo_epi32(array512[array_idx+0], array512[array_idx+1]);
+            array512_1 = _mm512_unpackhi_epi32(array512[array_idx+0], array512[array_idx+1]);
+            array512_2 = _mm512_unpacklo_epi32(array512[array_idx+2], array512[array_idx+3]);
+            array512_3 = _mm512_unpackhi_epi32(array512[array_idx+2], array512[array_idx+3]);
+            array512[array_idx+0] = _mm512_unpacklo_epi64(array512_0, array512_2);
+            array512[array_idx+1] = _mm512_unpackhi_epi64(array512_0, array512_2);
+            array512[array_idx+2] = _mm512_unpacklo_epi64(array512_1, array512_3);
+            array512[array_idx+3] = _mm512_unpackhi_epi64(array512_1, array512_3);
+        }
+
+        // Compose and store the 0/1, 2/3, 4/5, 6/7 and 16/17, 18/19, 20/21, 22/23 cols
+        for (int j = 0; j < 4; j++) {
+            array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_lo_idx, array512[j+4]);
+            array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_lo_idx, array512[j+12]);
+            array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+            array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+            _mm512_storeu_si512(dst_addr0, array512_2);
+            _mm512_storeu_si512(dst_addr1, array512_3);
+            dst_addr0 += 64;
+            dst_addr1 += 64;
+        }
+
+        // Compose and store the 8/9, 10/11, 12/13, 14/15 and 24/25, 26/27, 28/29, 30/31 cols
+        for (int j = 0; j < 4; j++) {
+            array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_hi_idx, array512[j+4]);
+            array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_hi_idx, array512[j+12]);
+            array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+            array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+            _mm512_storeu_si512(dst_addr0, array512_2);
+            _mm512_storeu_si512(dst_addr1, array512_3);
+            dst_addr0 += 64;
+            dst_addr1 += 64;
+        }
+
+        src_addr0 -= (LDA_4x*4);
+        src_addr1 -= (LDA_4x*4);
+        src_addr2 -= (LDA_4x*4);
+        src_addr3 -= (LDA_4x*4);
+        dst_addr0 += (32*15);
+        dst_addr1 += (32*15);
+    }
+
+    if (tag_k_32x != k) {
+        int k_rem = k - tag_k_32x;
+        int idx_length = (k_rem + 1 - 16) >> 1;
+        unsigned int tail_mask = (((unsigned int)0xffffffff) >> (32-k_rem));
+        bfloat16 * dst_addr_tmp = dst_addr0;
+
+        for (int j = 0; j < 4; j++) {
+            int array_idx = j*4;
+            // Load and preprocess 4 rows
+            array512[array_idx+0] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x);
+            array512[array_idx+1] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x);
+            array512[array_idx+2] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x);
+            array512[array_idx+3] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x);
+            array512_0 = _mm512_unpacklo_epi32(array512[array_idx+0], array512[array_idx+1]);
+            array512_1 = _mm512_unpackhi_epi32(array512[array_idx+0], array512[array_idx+1]);
+            array512_2 = _mm512_unpacklo_epi32(array512[array_idx+2], array512[array_idx+3]);
+            array512_3 = _mm512_unpackhi_epi32(array512[array_idx+2], array512[array_idx+3]);
+            array512[array_idx+0] = _mm512_unpacklo_epi64(array512_0, array512_2);
+            array512[array_idx+1] = _mm512_unpackhi_epi64(array512_0, array512_2);
+            array512[array_idx+2] = _mm512_unpacklo_epi64(array512_1, array512_3);
+            array512[array_idx+3] = _mm512_unpackhi_epi64(array512_1, array512_3);
+            src_addr0 += LDA_4x;
+            src_addr1 += LDA_4x;
+            src_addr2 += LDA_4x;
+            src_addr3 += LDA_4x;
+        }
+
+        for (int j = 0; j < 4; j++) {
+            array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_lo_idx, array512[j+4]);
+            array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_lo_idx, array512[j+12]);
+            array512_2 = _mm512_permutex2var_epi64(array512[j+0], permute_hi_idx, array512[j+4]);
+            array512_3 = _mm512_permutex2var_epi64(array512[j+8], permute_hi_idx, array512[j+12]);
+            array512[j+0]  = array512_0;  // 1st 8 pairs of col 0/1|2/3|4/5|6/7,   and 1st 8 pairs of col 16/17|18/19|20/21|22/23
+            array512[j+4]  = array512_1;  // 2nd 8 pairs of col 0/1|2/3|4/5|6/7,   and 2nd 8 pairs of col 16/17|18/19|20/21|22/23
+            array512[j+8]  = array512_2;  // 1st 8 pairs of col 8/9|10/11|12/13|14/15,   and 1st 8 pairs of col 24/25|26/27|28/29|30/31
+            array512[j+12] = array512_3;  // 2nd 8 pairs of col 8/9|10/11|12/13|14/15,   and 2nd 8 pairs of col 24/25|26/27|28/29|30/31
+        }
+
+        for (int j = 0; j < 4; j++) {
+            // Compose and store the 0/1 cols
+            array512_0 = _mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1);
+            _mm512_storeu_si512(dst_addr0, array512_0);
+            dst_addr0 += 64;
+        }
+
+        for (int j = 8; j < 12; j++) {
+            array512_0 = _mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1);
+            _mm512_storeu_si512(dst_addr0, array512_0);
+            dst_addr0 += 64;
+        }
+
+        // Compose and store 16 ~ k_rem cols
+        if (idx_length > 4) {
+            for (int idx_k = 0; idx_k < 4; idx_k++) {
+                array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0);
+                _mm512_storeu_si512(dst_addr0, array512_0);
+                dst_addr0 += 64;
+            }
+
+            for (int idx_k = 4; idx_k < idx_length; idx_k++) {
+                array512_0 = _mm512_inserti64x4(array512[idx_k+8], _mm512_extracti64x4_epi64(array512[idx_k+4], 0x1), 0x0);
+                _mm512_storeu_si512(dst_addr0, array512_0);
+                dst_addr0 += 64;
+            }
+        } else {
+            for (int idx_k = 0; idx_k < idx_length; idx_k++) {
+                array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0);
+                _mm512_storeu_si512(dst_addr0, array512_0);
+                dst_addr0 += 64;
+            }
+        }
+
+        dst_addr0 = dst_addr_tmp + 32;
+
+        for (int j = 0; j < m_rem; j++) {
+            array512[j] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+j*lda+tag_k_32x);
+        }
+        for (int j = m_rem; j < 16; j++) {
+            array512[j] = _mm512_setzero_si512();
+        }
+
+        for (int j = 0; j < 4; j++) {
+            int array_idx = j*4;
+            array512_0 = _mm512_unpacklo_epi32(array512[array_idx+0], array512[array_idx+1]);
+            array512_1 = _mm512_unpackhi_epi32(array512[array_idx+0], array512[array_idx+1]);
+            array512_2 = _mm512_unpacklo_epi32(array512[array_idx+2], array512[array_idx+3]);
+            array512_3 = _mm512_unpackhi_epi32(array512[array_idx+2], array512[array_idx+3]);
+            array512[array_idx+0] = _mm512_unpacklo_epi64(array512_0, array512_2);
+            array512[array_idx+1] = _mm512_unpackhi_epi64(array512_0, array512_2);
+            array512[array_idx+2] = _mm512_unpacklo_epi64(array512_1, array512_3);
+            array512[array_idx+3] = _mm512_unpackhi_epi64(array512_1, array512_3);
+        }
+
+        for (int j = 0; j < 4; j++) {
+            array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_lo_idx, array512[j+4]);
+            array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_lo_idx, array512[j+12]);
+            array512_2 = _mm512_permutex2var_epi64(array512[j+0], permute_hi_idx, array512[j+4]);
+            array512_3 = _mm512_permutex2var_epi64(array512[j+8], permute_hi_idx, array512[j+12]);
+            array512[j+0]  = array512_0;  // 1st 8 pairs of col 0/1|2/3|4/5|6/7,   and 1st 8 pairs of col 16/17|18/19|20/21|22/23
+            array512[j+4]  = array512_1;  // 2nd 8 pairs of col 0/1|2/3|4/5|6/7,   and 2nd 8 pairs of col 16/17|18/19|20/21|22/23
+            array512[j+8]  = array512_2;  // 1st 8 pairs of col 8/9|10/11|12/13|14/15,   and 1st 8 pairs of col 24/25|26/27|28/29|30/31
+            array512[j+12] = array512_3;  // 2nd 8 pairs of col 8/9|10/11|12/13|14/15,   and 2nd 8 pairs of col 24/25|26/27|28/29|30/31
+        }
+
+        for (int j = 0; j < 4; j++) {
+            // Compose and store the 0/1 cols
+            array512_0 = _mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1);
+            _mm512_storeu_si512(dst_addr0, array512_0);
+            dst_addr0 += 64;
+        }
+
+        for (int j = 8; j < 12; j++) {
+            array512_0 = _mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1);
+            _mm512_storeu_si512(dst_addr0, array512_0);
+            dst_addr0 += 64;
+        }
+
+        // Compose and store 16 ~ k_rem cols
+        if (idx_length > 4) {
+            for (int idx_k = 0; idx_k < 4; idx_k++) {
+                array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0);
+                _mm512_storeu_si512(dst_addr0, array512_0);
+                dst_addr0 += 64;
+            }
+
+            for (int idx_k = 4; idx_k < idx_length; idx_k++) {
+                array512_0 = _mm512_inserti64x4(array512[idx_k+8], _mm512_extracti64x4_epi64(array512[idx_k+4], 0x1), 0x0);
+                _mm512_storeu_si512(dst_addr0, array512_0);
+                dst_addr0 += 64;
+            }
+        } else {
+            for (int idx_k = 0; idx_k < idx_length; idx_k++) {
+                array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0);
+                _mm512_storeu_si512(dst_addr0, array512_0);
+                dst_addr0 += 64;
+            }
+        }
+    }
+}
+
+// K=Any number but will be processed based on 32, M=16
+void COL_MAJOR_ITCOPY_KERNEL_Kx16(BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
+{
+    bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3;
+    bfloat16 * dst_addr0, * dst_addr1;
+
+    BLASLONG tag_k_32x = k & (~31);
+
+    BLASLONG LDA_4x  = lda*4;
+    BLASLONG LDA_8x  = lda*8;
+    BLASLONG LDA_12x = lda*12;
+
+    src_addr0 = A;
+    src_addr1 = A + lda;
+    src_addr2 = A + lda*2;
+    src_addr3 = A + lda*3;
+    dst_addr0 = block_A;
+    dst_addr1 = block_A + 32*8;
+
+    __m512i array512_0, array512_1, array512_2, array512_3;
+    __m512i array512_way0_0, array512_way0_1, array512_way0_2, array512_way0_3;
+    __m512i array512_way1_0, array512_way1_1, array512_way1_2, array512_way1_3;
+    __m512i array512_way2_0, array512_way2_1, array512_way2_2, array512_way2_3;
+    __m512i array512_way3_0, array512_way3_1, array512_way3_2, array512_way3_3;
+
+    __m512i M512_EPI64_2   = _mm512_set1_epi64(2);
+    __m512i permute_lo_idx = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
+    __m512i permute_hi_idx = _mm512_add_epi64(permute_lo_idx, M512_EPI64_2);
+
+    for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
+        // Load and preprocess 1st 4 rows
+        array512_way0_0 = _mm512_loadu_si512(src_addr0+idx_k);
+        array512_way0_1 = _mm512_loadu_si512(src_addr1+idx_k);
+        array512_way0_2 = _mm512_loadu_si512(src_addr2+idx_k);
+        array512_way0_3 = _mm512_loadu_si512(src_addr3+idx_k);
+        array512_0 = _mm512_unpacklo_epi32(array512_way0_0, array512_way0_1);
+        array512_1 = _mm512_unpackhi_epi32(array512_way0_0, array512_way0_1);
+        array512_2 = _mm512_unpacklo_epi32(array512_way0_2, array512_way0_3);
+        array512_3 = _mm512_unpackhi_epi32(array512_way0_2, array512_way0_3);
+        array512_way0_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+        array512_way0_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+        array512_way0_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+        array512_way0_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+
+        // Load and preprocess 2nd 4 rows
+        array512_way1_0 = _mm512_loadu_si512(src_addr0+LDA_4x+idx_k);
+        array512_way1_1 = _mm512_loadu_si512(src_addr1+LDA_4x+idx_k);
+        array512_way1_2 = _mm512_loadu_si512(src_addr2+LDA_4x+idx_k);
+        array512_way1_3 = _mm512_loadu_si512(src_addr3+LDA_4x+idx_k);
+        array512_0 = _mm512_unpacklo_epi32(array512_way1_0, array512_way1_1);
+        array512_1 = _mm512_unpackhi_epi32(array512_way1_0, array512_way1_1);
+        array512_2 = _mm512_unpacklo_epi32(array512_way1_2, array512_way1_3);
+        array512_3 = _mm512_unpackhi_epi32(array512_way1_2, array512_way1_3);
+        array512_way1_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+        array512_way1_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+        array512_way1_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+        array512_way1_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+
+        // Load and preprocess 3rd 4 rows
+        array512_way2_0 = _mm512_loadu_si512(src_addr0+LDA_8x+idx_k);
+        array512_way2_1 = _mm512_loadu_si512(src_addr1+LDA_8x+idx_k);
+        array512_way2_2 = _mm512_loadu_si512(src_addr2+LDA_8x+idx_k);
+        array512_way2_3 = _mm512_loadu_si512(src_addr3+LDA_8x+idx_k);
+        array512_0 = _mm512_unpacklo_epi32(array512_way2_0, array512_way2_1);
+        array512_1 = _mm512_unpackhi_epi32(array512_way2_0, array512_way2_1);
+        array512_2 = _mm512_unpacklo_epi32(array512_way2_2, array512_way2_3);
+        array512_3 = _mm512_unpackhi_epi32(array512_way2_2, array512_way2_3);
+        array512_way2_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+        array512_way2_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+        array512_way2_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+        array512_way2_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+
+        // Load and preprocess 4th 4 rows
+        array512_way3_0 = _mm512_loadu_si512(src_addr0+LDA_12x+idx_k);
+        array512_way3_1 = _mm512_loadu_si512(src_addr1+LDA_12x+idx_k);
+        array512_way3_2 = _mm512_loadu_si512(src_addr2+LDA_12x+idx_k);
+        array512_way3_3 = _mm512_loadu_si512(src_addr3+LDA_12x+idx_k);
+        array512_0 = _mm512_unpacklo_epi32(array512_way3_0, array512_way3_1);
+        array512_1 = _mm512_unpackhi_epi32(array512_way3_0, array512_way3_1);
+        array512_2 = _mm512_unpacklo_epi32(array512_way3_2, array512_way3_3);
+        array512_3 = _mm512_unpackhi_epi32(array512_way3_2, array512_way3_3);
+        array512_way3_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+        array512_way3_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+        array512_way3_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+        array512_way3_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+
+        // Compose and store the 0/1 and 16/17 cols
+        array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_lo_idx, array512_way1_0);
+        array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_lo_idx, array512_way3_0);
+        array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+        array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+        _mm512_storeu_si512(dst_addr0, array512_2);
+        _mm512_storeu_si512(dst_addr1, array512_3);
+        dst_addr0 += 32;
+        dst_addr1 += 32;
+
+        // Compose and store the 2/3 and 18/19 cols
+        array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_lo_idx, array512_way1_1);
+        array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_lo_idx, array512_way3_1);
+        array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+        array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+        _mm512_storeu_si512(dst_addr0, array512_2);
+        _mm512_storeu_si512(dst_addr1, array512_3);
+        dst_addr0 += 32;
+        dst_addr1 += 32;
+
+        // Compose and store the 4/5 and 20/21 cols
+        array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_lo_idx, array512_way1_2);
+        array512_1 = _mm512_permutex2var_epi64(array512_way2_2, permute_lo_idx, array512_way3_2);
+        array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+        array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+        _mm512_storeu_si512(dst_addr0, array512_2);
+        _mm512_storeu_si512(dst_addr1, array512_3);
+        dst_addr0 += 32;
+        dst_addr1 += 32;
+
+        // Compose and store the 6/7 and 22/23 cols
+        array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_lo_idx, array512_way1_3);
+        array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_lo_idx, array512_way3_3);
+        array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+        array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+        _mm512_storeu_si512(dst_addr0, array512_2);
+        _mm512_storeu_si512(dst_addr1, array512_3);
+        dst_addr0 += 32;
+        dst_addr1 += 32;
+
+        // Compose and store the 8/9 and 24/25 cols
+        array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_hi_idx, array512_way1_0);
+        array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_hi_idx, array512_way3_0);
+        array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+        array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+        _mm512_storeu_si512(dst_addr0, array512_2);
+        _mm512_storeu_si512(dst_addr1, array512_3);
+        dst_addr0 += 32;
+        dst_addr1 += 32;
+
+        // Compose and store the 10/11 and 26/27 cols
+        array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_hi_idx, array512_way1_1);
+        array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_hi_idx, array512_way3_1);
+        array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+        array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+        _mm512_storeu_si512(dst_addr0, array512_2);
+        _mm512_storeu_si512(dst_addr1, array512_3);
+        dst_addr0 += 32;
+        dst_addr1 += 32;
+
+        // Compose and store the 12/13 and 28/29 cols
+        array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_hi_idx, array512_way1_2);
+        array512_1 = _mm512_permutex2var_epi64(array512_way2_2, permute_hi_idx, array512_way3_2);
+        array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+        array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+        _mm512_storeu_si512(dst_addr0, array512_2);
+        _mm512_storeu_si512(dst_addr1, array512_3);
+        dst_addr0 += 32;
+        dst_addr1 += 32;
+
+        // Compose and store the 14/15 and 30/31 cols
+        array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_hi_idx, array512_way1_3);
+        array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_hi_idx, array512_way3_3);
+        array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+        array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+        _mm512_storeu_si512(dst_addr0, array512_2);
+        _mm512_storeu_si512(dst_addr1, array512_3);
+        dst_addr0 += 32*9;
+        dst_addr1 += 32*9;
+    }
+
+    if (tag_k_32x != k) {
+        int k_rem = k - tag_k_32x;
+        unsigned int tail_mask = (((unsigned int)0xffffffff) >> (32-k_rem));
+        __m512i array512[16];
+
+        // Load and preprocess 1st 4 rows
+        array512[0] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x);
+        array512[1] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x);
+        array512[2] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x);
+        array512[3] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x);
+        array512_0 = _mm512_unpacklo_epi32(array512[0], array512[1]);
+        array512_1 = _mm512_unpackhi_epi32(array512[0], array512[1]);
+        array512_2 = _mm512_unpacklo_epi32(array512[2], array512[3]);
+        array512_3 = _mm512_unpackhi_epi32(array512[2], array512[3]);
+        array512[0] = _mm512_unpacklo_epi64(array512_0, array512_2);
+        array512[1] = _mm512_unpackhi_epi64(array512_0, array512_2);
+        array512[2] = _mm512_unpacklo_epi64(array512_1, array512_3);
+        array512[3] = _mm512_unpackhi_epi64(array512_1, array512_3);
+        src_addr0 += LDA_4x;
+        src_addr1 += LDA_4x;
+        src_addr2 += LDA_4x;
+        src_addr3 += LDA_4x;
+
+        // Load and preprocess 2nd 4 rows
+        array512[4] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x);
+        array512[5] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x);
+        array512[6] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x);
+        array512[7] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x);
+        array512_0 = _mm512_unpacklo_epi32(array512[4], array512[5]);
+        array512_1 = _mm512_unpackhi_epi32(array512[4], array512[5]);
+        array512_2 = _mm512_unpacklo_epi32(array512[6], array512[7]);
+        array512_3 = _mm512_unpackhi_epi32(array512[6], array512[7]);
+        array512[4] = _mm512_unpacklo_epi64(array512_0, array512_2);
+        array512[5] = _mm512_unpackhi_epi64(array512_0, array512_2);
+        array512[6] = _mm512_unpacklo_epi64(array512_1, array512_3);
+        array512[7] = _mm512_unpackhi_epi64(array512_1, array512_3);
+        src_addr0 += LDA_4x;
+        src_addr1 += LDA_4x;
+        src_addr2 += LDA_4x;
+        src_addr3 += LDA_4x;
+
+        // Load and preprocess 3rd 4 rows
+        array512[8]  = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x);
+        array512[9]  = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x);
+        array512[10] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x);
+        array512[11] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x);
+        array512_0 = _mm512_unpacklo_epi32(array512[8],  array512[9]);
+        array512_1 = _mm512_unpackhi_epi32(array512[8],  array512[9]);
+        array512_2 = _mm512_unpacklo_epi32(array512[10], array512[11]);
+        array512_3 = _mm512_unpackhi_epi32(array512[10], array512[11]);
+        array512[8]  = _mm512_unpacklo_epi64(array512_0, array512_2);
+        array512[9]  = _mm512_unpackhi_epi64(array512_0, array512_2);
+        array512[10] = _mm512_unpacklo_epi64(array512_1, array512_3);
+        array512[11] = _mm512_unpackhi_epi64(array512_1, array512_3);
+        src_addr0 += LDA_4x;
+        src_addr1 += LDA_4x;
+        src_addr2 += LDA_4x;
+        src_addr3 += LDA_4x;
+
+        // Load and preprocess 4th 4 rows
+        array512[12] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x);
+        array512[13] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x);
+        array512[14] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x);
+        array512[15] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x);
+        array512_0 = _mm512_unpacklo_epi32(array512[12], array512[13]);
+        array512_1 = _mm512_unpackhi_epi32(array512[12], array512[13]);
+        array512_2 = _mm512_unpacklo_epi32(array512[14], array512[15]);
+        array512_3 = _mm512_unpackhi_epi32(array512[14], array512[15]);
+        array512[12] = _mm512_unpacklo_epi64(array512_0, array512_2);
+        array512[13] = _mm512_unpackhi_epi64(array512_0, array512_2);
+        array512[14] = _mm512_unpacklo_epi64(array512_1, array512_3);
+        array512[15] = _mm512_unpackhi_epi64(array512_1, array512_3);
+
+        // array512_01_1617_0, array512_01_1617_1, array512_89_2425_0, array512_89_2425_1;
+        // Half-compose of 0/1, 16/17, 8/9, 24/25 cols
+        array512_0 = _mm512_permutex2var_epi64(array512[0], permute_lo_idx, array512[4]);
+        array512_1 = _mm512_permutex2var_epi64(array512[8], permute_lo_idx, array512[12]);
+        array512_2 = _mm512_permutex2var_epi64(array512[0], permute_hi_idx, array512[4]);
+        array512_3 = _mm512_permutex2var_epi64(array512[8], permute_hi_idx, array512[12]);
+        array512[0]  = array512_0;  // 1st 8 pairs of col 0/1,   and 1st 8 pairs of col 16/17
+        array512[4]  = array512_1;  // 2nd 8 pairs of col 0/1,   and 2nd 8 pairs of col 16/17
+        array512[8]  = array512_2;  // 1st 8 pairs of col 8/9,   and 1st 8 pairs of col 24/25
+        array512[12] = array512_3;  // 2nd 8 pairs of col 8/9,   and 2nd 8 pairs of col 24/25
+
+        // Half-compose of 2/3, 18/19, 10/11, 26/27 cols
+        array512_0 = _mm512_permutex2var_epi64(array512[1], permute_lo_idx, array512[5]);
+        array512_1 = _mm512_permutex2var_epi64(array512[9], permute_lo_idx, array512[13]);
+        array512_2 = _mm512_permutex2var_epi64(array512[1], permute_hi_idx, array512[5]);
+        array512_3 = _mm512_permutex2var_epi64(array512[9], permute_hi_idx, array512[13]);
+        array512[1]  = array512_0;  // 1st 8 pairs of col 2/3,   and 1st 8 pairs of col 18/19
+        array512[5]  = array512_1;  // 2nd 8 pairs of col 2/3,   and 2nd 8 pairs of col 18/19
+        array512[9]  = array512_2;  // 1st 8 pairs of col 10/11, and 1st 8 pairs of col 26/27
+        array512[13] = array512_3;  // 2nd 8 pairs of col 10/11, and 2nd 8 pairs of col 26/27
+
+        // Half-compose of 4/5, 20/21, 12/13, 28/29 cols
+        array512_0 = _mm512_permutex2var_epi64(array512[2],  permute_lo_idx, array512[6]);
+        array512_1 = _mm512_permutex2var_epi64(array512[10], permute_lo_idx, array512[14]);
+        array512_2 = _mm512_permutex2var_epi64(array512[2],  permute_hi_idx, array512[6]);
+        array512_3 = _mm512_permutex2var_epi64(array512[10], permute_hi_idx, array512[14]);
+        array512[2]  = array512_0;  // 1st 8 pairs of col 4/5,   and 1st 8 pairs of col 20/21
+        array512[6]  = array512_1;  // 2nd 8 pairs of col 4/5,   and 2nd 8 pairs of col 20/21
+        array512[10] = array512_2;  // 1st 8 pairs of col 12/13, and 1st 8 pairs of col 28/29
+        array512[14] = array512_3;  // 2nd 8 pairs of col 12/13, and 2nd 8 pairs of col 28/29
+
+        // Half-compose of 6/7, 22/23, 14/15, 30/31 cols
+        array512_0 = _mm512_permutex2var_epi64(array512[3],  permute_lo_idx, array512[7]);
+        array512_1 = _mm512_permutex2var_epi64(array512[11], permute_lo_idx, array512[15]);
+        array512_2 = _mm512_permutex2var_epi64(array512[3],  permute_hi_idx, array512[7]);
+        array512_3 = _mm512_permutex2var_epi64(array512[11], permute_hi_idx, array512[15]);
+        array512[3]  = array512_0;  // 1st 8 pairs of col 6/7,   and 1st 8 pairs of col 22/23
+        array512[7]  = array512_1;  // 2nd 8 pairs of col 6/7,   and 2nd 8 pairs of col 22/23
+        array512[11] = array512_2;  // 1st 8 pairs of col 14/15, and 1st 8 pairs of col 30/31
+        array512[15] = array512_3;  // 2nd 8 pairs of col 14/15, and 2nd 8 pairs of col 30/31
+
+        // Compose and store the 0/1 cols
+        array512_0 = _mm512_inserti64x4(array512[0], _mm512_castsi512_si256(array512[4]), 0x1);
+        _mm512_storeu_si512(dst_addr0, array512_0);
+        dst_addr0 += 32;
+
+        // Compose and store the 2/3 cols
+        array512_0 = _mm512_inserti64x4(array512[1], _mm512_castsi512_si256(array512[5]), 0x1);
+        _mm512_storeu_si512(dst_addr0, array512_0);
+        dst_addr0 += 32;
+
+        // Compose and store the 4/5 cols
+        array512_0 = _mm512_inserti64x4(array512[2], _mm512_castsi512_si256(array512[6]), 0x1);
+        _mm512_storeu_si512(dst_addr0, array512_0);
+        dst_addr0 += 32;
+
+        // Compose and store the 6/7 cols
+        array512_0 = _mm512_inserti64x4(array512[3], _mm512_castsi512_si256(array512[7]), 0x1);
+        _mm512_storeu_si512(dst_addr0, array512_0);
+        dst_addr0 += 32;
+
+        // Compose and store the 8/9 cols
+        array512_0 = _mm512_inserti64x4(array512[8], _mm512_castsi512_si256(array512[12]), 0x1);
+        _mm512_storeu_si512(dst_addr0, array512_0);
+        dst_addr0 += 32;
+
+        // Compose and store the 10/11 cols
+        array512_0 = _mm512_inserti64x4(array512[9], _mm512_castsi512_si256(array512[13]), 0x1);
+        _mm512_storeu_si512(dst_addr0, array512_0);
+        dst_addr0 += 32;
+
+        // Compose and store the 12/13 cols
+        array512_0 = _mm512_inserti64x4(array512[10], _mm512_castsi512_si256(array512[14]), 0x1);
+        _mm512_storeu_si512(dst_addr0, array512_0);
+        dst_addr0 += 32;
+
+        // Compose and store the 14/15 cols
+        array512_0 = _mm512_inserti64x4(array512[11], _mm512_castsi512_si256(array512[15]), 0x1);
+        _mm512_storeu_si512(dst_addr0, array512_0);
+        dst_addr0 += 32;
+
+        // Compose and store 16 ~ k_rem cols
+        int idx_length = (k_rem + 1 - 16) >> 1;
+        if (idx_length > 4) {
+            for (int idx_k = 0; idx_k < 4; idx_k++) {
+                array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0);
+                _mm512_storeu_si512(dst_addr0, array512_0);
+                dst_addr0 += 32;
+            }
+
+            for (int idx_k = 4; idx_k < idx_length; idx_k++) {
+                array512_0 = _mm512_inserti64x4(array512[idx_k+8], _mm512_extracti64x4_epi64(array512[idx_k+4], 0x1), 0x0);
+                _mm512_storeu_si512(dst_addr0, array512_0);
+                dst_addr0 += 32;
+            }
+        } else {
+            for (int idx_k = 0; idx_k < idx_length; idx_k++) {
+                array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0);
+                _mm512_storeu_si512(dst_addr0, array512_0);
+                dst_addr0 += 32;
+            }
+        }
+    }
+}
+
+// K=Any number but will be processed based on 32, M<=16
+void COL_MAJOR_ITCOPY_KERNEL_Kx16m(BLASLONG m, BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
+{
+    bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3;
+    bfloat16 * dst_addr0, * dst_addr1;
+
+    BLASLONG tag_k_32x = k & (~31);
+
+    src_addr0 = A;
+    dst_addr0 = block_A;
+    dst_addr1 = block_A + 32*8;
+
+    __m512i array512_0, array512_1, array512_2, array512_3;
+    __m512i array512[16];
+
+    __m512i M512_EPI64_2   = _mm512_set1_epi64(2);
+    __m512i permute_lo_idx = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
+    __m512i permute_hi_idx = _mm512_add_epi64(permute_lo_idx, M512_EPI64_2);
+
+    for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
+        for (int j = 0; j < m; j++) {
+            array512[j] = _mm512_loadu_si512(src_addr0+j*lda+idx_k);
+        }
+        for (int j = m; j < 16; j++) {
+            array512[j] = _mm512_setzero_si512();
+        }
+
+        for (int j = 0; j < 4; j++) {
+            int array_idx = j*4;
+            array512_0 = _mm512_unpacklo_epi32(array512[array_idx+0], array512[array_idx+1]);
+            array512_1 = _mm512_unpackhi_epi32(array512[array_idx+0], array512[array_idx+1]);
+            array512_2 = _mm512_unpacklo_epi32(array512[array_idx+2], array512[array_idx+3]);
+            array512_3 = _mm512_unpackhi_epi32(array512[array_idx+2], array512[array_idx+3]);
+            array512[array_idx+0] = _mm512_unpacklo_epi64(array512_0, array512_2);
+            array512[array_idx+1] = _mm512_unpackhi_epi64(array512_0, array512_2);
+            array512[array_idx+2] = _mm512_unpacklo_epi64(array512_1, array512_3);
+            array512[array_idx+3] = _mm512_unpackhi_epi64(array512_1, array512_3);
+        }
+
+        // Compose and store the 0/1, 2/3, 4/5, 6/7 and 16/17, 18/19, 20/21, 22/23 cols
+        for (int j = 0; j < 4; j++) {
+            array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_lo_idx, array512[j+4]);
+            array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_lo_idx, array512[j+12]);
+            array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+            array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+            _mm512_storeu_si512(dst_addr0, array512_2);
+            _mm512_storeu_si512(dst_addr1, array512_3);
+            dst_addr0 += 32;
+            dst_addr1 += 32;
+        }
+
+        // Compose and store the 8/9, 10/11, 12/13, 14/15 and 24/25, 26/27, 28/29, 30/31 cols
+        for (int j = 0; j < 4; j++) {
+            array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_hi_idx, array512[j+4]);
+            array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_hi_idx, array512[j+12]);
+            array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+            array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+            _mm512_storeu_si512(dst_addr0, array512_2);
+            _mm512_storeu_si512(dst_addr1, array512_3);
+            dst_addr0 += 32;
+            dst_addr1 += 32;
+        }
+
+        dst_addr0 += 32*8;
+        dst_addr1 += 32*8;
     }
 
-#ifdef DEBUG_PROFILE
-   print_block(BF16_BLOCK_THRES_K, BF16_BLOCK_THRES_M, block_A);
-#endif
+    if (tag_k_32x != k) {
+        int k_rem = k - tag_k_32x;
+        unsigned int tail_mask = (((unsigned int)0xffffffff) >> (32-k_rem));
+
+        for (int j = 0; j < m; j++) {
+            array512[j] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+j*lda+tag_k_32x);
+        }
+        for (int j = m; j < 16; j++) {
+            array512[j] = _mm512_setzero_si512();
+        }
+
+        for (int j = 0; j < 4; j++) {
+            int array_idx = j*4;
+            array512_0 = _mm512_unpacklo_epi32(array512[array_idx+0], array512[array_idx+1]);
+            array512_1 = _mm512_unpackhi_epi32(array512[array_idx+0], array512[array_idx+1]);
+            array512_2 = _mm512_unpacklo_epi32(array512[array_idx+2], array512[array_idx+3]);
+            array512_3 = _mm512_unpackhi_epi32(array512[array_idx+2], array512[array_idx+3]);
+            array512[array_idx+0] = _mm512_unpacklo_epi64(array512_0, array512_2);
+            array512[array_idx+1] = _mm512_unpackhi_epi64(array512_0, array512_2);
+            array512[array_idx+2] = _mm512_unpacklo_epi64(array512_1, array512_3);
+            array512[array_idx+3] = _mm512_unpackhi_epi64(array512_1, array512_3);
+        }
+
+        for (int j = 0; j < 4; j++) {
+            array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_lo_idx, array512[j+4]);
+            array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_lo_idx, array512[j+12]);
+            array512_2 = _mm512_permutex2var_epi64(array512[j+0], permute_hi_idx, array512[j+4]);
+            array512_3 = _mm512_permutex2var_epi64(array512[j+8], permute_hi_idx, array512[j+12]);
+            array512[j+0]  = array512_0;  // 1st 8 pairs of col 0/1|2/3|4/5|6/7,   and 1st 8 pairs of col 16/17|18/19|20/21|22/23
+            array512[j+4]  = array512_1;  // 2nd 8 pairs of col 0/1|2/3|4/5|6/7,   and 2nd 8 pairs of col 16/17|18/19|20/21|22/23
+            array512[j+8]  = array512_2;  // 1st 8 pairs of col 8/9|10/11|12/13|14/15,   and 1st 8 pairs of col 24/25|26/27|28/29|30/31
+            array512[j+12] = array512_3;  // 2nd 8 pairs of col 8/9|10/11|12/13|14/15,   and 2nd 8 pairs of col 24/25|26/27|28/29|30/31
+        }
+
+        for (int j = 0; j < 4; j++) {
+            // Compose and store the 0/1 cols
+            array512_0 = _mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1);
+            _mm512_storeu_si512(dst_addr0, array512_0);
+            dst_addr0 += 32;
+        }
+
+        for (int j = 8; j < 12; j++) {
+            array512_0 = _mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1);
+            _mm512_storeu_si512(dst_addr0, array512_0);
+            dst_addr0 += 32;
+        }
+
+        // Compose and store 16 ~ k_rem cols
+        int idx_length = (k_rem + 1 - 16) >> 1;
+        if (idx_length > 4) {
+            for (int idx_k = 0; idx_k < 4; idx_k++) {
+                array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0);
+                _mm512_storeu_si512(dst_addr0, array512_0);
+                dst_addr0 += 32;
+            }
+
+            for (int idx_k = 4; idx_k < idx_length; idx_k++) {
+                array512_0 = _mm512_inserti64x4(array512[idx_k+8], _mm512_extracti64x4_epi64(array512[idx_k+4], 0x1), 0x0);
+                _mm512_storeu_si512(dst_addr0, array512_0);
+                dst_addr0 += 32;
+            }
+        } else {
+            for (int idx_k = 0; idx_k < idx_length; idx_k++) {
+                array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0);
+                _mm512_storeu_si512(dst_addr0, array512_0);
+                dst_addr0 += 32;
+            }
+        }
+    }
 }
 
+// COL_MAJOR_ONCOPY_KERNEL_16x32 behaves exactly the same as COL_MAJOR_ITCOPY_KERNEL_Kx16
+#define COL_MAJOR_ONCOPY_KERNEL_16x32 COL_MAJOR_ITCOPY_KERNEL_Kx16
+
 void COL_MAJOR_ONCOPY_KERNEL_8x32(BLASLONG k, bfloat16 * B, BLASLONG ldb, bfloat16 * block_B)
 {
     BLASLONG tag_k_32x = k & (~31);
-    BLASLONG idx_src_base0, idx_src_base1, idx_src_base2, idx_src_base3, idx_src_base4, idx_src_base5, idx_src_base6, idx_src_base7;
-    BLASLONG idx_target_base0;
 
-    idx_src_base0 = 0;
-    idx_src_base1 = 1*ldb;
-    idx_src_base2 = 2*ldb;
-    idx_src_base3 = 3*ldb;
-    idx_src_base4 = 4*ldb;
-    idx_src_base5 = 5*ldb;
-    idx_src_base6 = 6*ldb;
-    idx_src_base7 = 7*ldb;
-    idx_target_base0 = 0;
+    bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3, * src_addr4, * src_addr5, * src_addr6, * src_addr7;
+    bfloat16 * dst_addr0;
+
+    unsigned char blend_mask = (((unsigned char)0xcc));
+    __m512i permute_idx = _mm512_set_epi64(13, 12, 7, 6, 9, 8, 3, 2);
+
+    src_addr0 = B;
+    src_addr1 = src_addr0 + 1*ldb;
+    src_addr2 = src_addr0 + 2*ldb;
+    src_addr3 = src_addr0 + 3*ldb;
+    src_addr4 = src_addr0 + 4*ldb;
+    src_addr5 = src_addr0 + 5*ldb;
+    src_addr6 = src_addr0 + 6*ldb;
+    src_addr7 = src_addr0 + 7*ldb;
+    dst_addr0 = block_B;
+
+    __m512i array512_0, array512_1, array512_2, array512_3;
+    __m512i array512_way0_0, array512_way0_1, array512_way0_2, array512_way0_3;
+    __m512i array512_way1_0, array512_way1_1, array512_way1_2, array512_way1_3;
 
     for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
-        _mm512_storeu_si512(&block_B[idx_target_base0+ 32*0], _mm512_loadu_si512(&B[idx_src_base0+idx_k]));
-        _mm512_storeu_si512(&block_B[idx_target_base0+ 32*1], _mm512_loadu_si512(&B[idx_src_base1+idx_k]));
-        _mm512_storeu_si512(&block_B[idx_target_base0+ 32*2], _mm512_loadu_si512(&B[idx_src_base2+idx_k]));
-        _mm512_storeu_si512(&block_B[idx_target_base0+ 32*3], _mm512_loadu_si512(&B[idx_src_base3+idx_k]));
-        _mm512_storeu_si512(&block_B[idx_target_base0+ 32*4], _mm512_loadu_si512(&B[idx_src_base4+idx_k]));
-        _mm512_storeu_si512(&block_B[idx_target_base0+ 32*5], _mm512_loadu_si512(&B[idx_src_base5+idx_k]));
-        _mm512_storeu_si512(&block_B[idx_target_base0+ 32*6], _mm512_loadu_si512(&B[idx_src_base6+idx_k]));
-        _mm512_storeu_si512(&block_B[idx_target_base0+ 32*7], _mm512_loadu_si512(&B[idx_src_base7+idx_k]));
-        idx_target_base0 += 32*8;
+        array512_0 = _mm512_loadu_si512(src_addr0+idx_k);
+        array512_1 = _mm512_loadu_si512(src_addr1+idx_k);
+        array512_2 = _mm512_loadu_si512(src_addr2+idx_k);
+        array512_3 = _mm512_loadu_si512(src_addr3+idx_k);
+
+        array512_way0_0 = _mm512_unpacklo_epi32(array512_0, array512_1);
+        array512_way0_1 = _mm512_unpackhi_epi32(array512_0, array512_1);
+        array512_way0_2 = _mm512_unpacklo_epi32(array512_2, array512_3);
+        array512_way0_3 = _mm512_unpackhi_epi32(array512_2, array512_3);
+
+        array512_0 = _mm512_unpacklo_epi64(array512_way0_0, array512_way0_2);
+        array512_1 = _mm512_unpackhi_epi64(array512_way0_0, array512_way0_2);
+        array512_2 = _mm512_unpacklo_epi64(array512_way0_1, array512_way0_3);
+        array512_3 = _mm512_unpackhi_epi64(array512_way0_1, array512_way0_3);
+
+        array512_way0_0 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x88);
+        array512_way0_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0xdd);
+        array512_way0_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x88);
+        array512_way0_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0xdd);
+
+        array512_0 = _mm512_loadu_si512(src_addr4+idx_k);
+        array512_1 = _mm512_loadu_si512(src_addr5+idx_k);
+        array512_2 = _mm512_loadu_si512(src_addr6+idx_k);
+        array512_3 = _mm512_loadu_si512(src_addr7+idx_k);
+
+        array512_way1_0 = _mm512_unpacklo_epi32(array512_0, array512_1);
+        array512_way1_1 = _mm512_unpackhi_epi32(array512_0, array512_1);
+        array512_way1_2 = _mm512_unpacklo_epi32(array512_2, array512_3);
+        array512_way1_3 = _mm512_unpackhi_epi32(array512_2, array512_3);
+
+        array512_0 = _mm512_unpacklo_epi64(array512_way1_0, array512_way1_2);
+        array512_1 = _mm512_unpackhi_epi64(array512_way1_0, array512_way1_2);
+        array512_2 = _mm512_unpacklo_epi64(array512_way1_1, array512_way1_3);
+        array512_3 = _mm512_unpackhi_epi64(array512_way1_1, array512_way1_3);
+
+        array512_way1_0 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x22);
+        array512_way1_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x77);
+        array512_way1_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x22);
+        array512_way1_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x77);
+
+        array512_0 = _mm512_mask_blend_epi64(blend_mask, array512_way0_0, array512_way1_0);
+        array512_1 = _mm512_mask_blend_epi64(blend_mask, array512_way0_1, array512_way1_1);
+        array512_2 = _mm512_mask_blend_epi64(blend_mask, array512_way0_2, array512_way1_2);
+        array512_3 = _mm512_mask_blend_epi64(blend_mask, array512_way0_3, array512_way1_3);
+        _mm512_storeu_si512(dst_addr0,    array512_0);
+        _mm512_storeu_si512(dst_addr0+32, array512_1);
+        _mm512_storeu_si512(dst_addr0+64, array512_2);
+        _mm512_storeu_si512(dst_addr0+96, array512_3);
+
+        array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_idx, array512_way1_0);
+        array512_1 = _mm512_permutex2var_epi64(array512_way0_1, permute_idx, array512_way1_1);
+        array512_2 = _mm512_permutex2var_epi64(array512_way0_2, permute_idx, array512_way1_2);
+        array512_3 = _mm512_permutex2var_epi64(array512_way0_3, permute_idx, array512_way1_3);
+        _mm512_storeu_si512(dst_addr0+128, array512_0);
+        _mm512_storeu_si512(dst_addr0+160, array512_1);
+        _mm512_storeu_si512(dst_addr0+192, array512_2);
+        _mm512_storeu_si512(dst_addr0+224, array512_3);
+
+        dst_addr0 += 256;
     }
 
     if (tag_k_32x != k) {
         unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(k-tag_k_32x)));
         __mmask32 tail_mask = *((__mmask32*) &tail_mask_value);
-        _mm512_storeu_si512(&block_B[idx_target_base0+ 32*0], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base0+tag_k_32x]));
-        _mm512_storeu_si512(&block_B[idx_target_base0+ 32*1], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base1+tag_k_32x]));
-        _mm512_storeu_si512(&block_B[idx_target_base0+ 32*2], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base2+tag_k_32x]));
-        _mm512_storeu_si512(&block_B[idx_target_base0+ 32*3], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base3+tag_k_32x]));
-        _mm512_storeu_si512(&block_B[idx_target_base0+ 32*4], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base4+tag_k_32x]));
-        _mm512_storeu_si512(&block_B[idx_target_base0+ 32*5], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base5+tag_k_32x]));
-        _mm512_storeu_si512(&block_B[idx_target_base0+ 32*6], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base6+tag_k_32x]));
-        _mm512_storeu_si512(&block_B[idx_target_base0+ 32*7], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base7+tag_k_32x]));
+        array512_0 = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x);
+        array512_1 = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x);
+        array512_2 = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x);
+        array512_3 = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x);
+
+        array512_way0_0 = _mm512_unpacklo_epi32(array512_0, array512_1);
+        array512_way0_1 = _mm512_unpackhi_epi32(array512_0, array512_1);
+        array512_way0_2 = _mm512_unpacklo_epi32(array512_2, array512_3);
+        array512_way0_3 = _mm512_unpackhi_epi32(array512_2, array512_3);
+
+        array512_0 = _mm512_unpacklo_epi64(array512_way0_0, array512_way0_2);
+        array512_1 = _mm512_unpackhi_epi64(array512_way0_0, array512_way0_2);
+        array512_2 = _mm512_unpacklo_epi64(array512_way0_1, array512_way0_3);
+        array512_3 = _mm512_unpackhi_epi64(array512_way0_1, array512_way0_3);
+
+        array512_way0_0 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x88);
+        array512_way0_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0xdd);
+        array512_way0_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x88);
+        array512_way0_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0xdd);
+
+        array512_0 = _mm512_maskz_loadu_epi16(tail_mask, src_addr4+tag_k_32x);
+        array512_1 = _mm512_maskz_loadu_epi16(tail_mask, src_addr5+tag_k_32x);
+        array512_2 = _mm512_maskz_loadu_epi16(tail_mask, src_addr6+tag_k_32x);
+        array512_3 = _mm512_maskz_loadu_epi16(tail_mask, src_addr7+tag_k_32x);
+
+        array512_way1_0 = _mm512_unpacklo_epi32(array512_0, array512_1);
+        array512_way1_1 = _mm512_unpackhi_epi32(array512_0, array512_1);
+        array512_way1_2 = _mm512_unpacklo_epi32(array512_2, array512_3);
+        array512_way1_3 = _mm512_unpackhi_epi32(array512_2, array512_3);
+
+        array512_0 = _mm512_unpacklo_epi64(array512_way1_0, array512_way1_2);
+        array512_1 = _mm512_unpackhi_epi64(array512_way1_0, array512_way1_2);
+        array512_2 = _mm512_unpacklo_epi64(array512_way1_1, array512_way1_3);
+        array512_3 = _mm512_unpackhi_epi64(array512_way1_1, array512_way1_3);
+
+        array512_way1_0 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x22);
+        array512_way1_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x77);
+        array512_way1_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x22);
+        array512_way1_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x77);
+
+
+        array512_0 = _mm512_mask_blend_epi64(blend_mask, array512_way0_0, array512_way1_0);
+        array512_1 = _mm512_mask_blend_epi64(blend_mask, array512_way0_1, array512_way1_1);
+        array512_2 = _mm512_mask_blend_epi64(blend_mask, array512_way0_2, array512_way1_2);
+        array512_3 = _mm512_mask_blend_epi64(blend_mask, array512_way0_3, array512_way1_3);
+        _mm512_storeu_si512(dst_addr0,    array512_0);
+        _mm512_storeu_si512(dst_addr0+32, array512_1);
+        _mm512_storeu_si512(dst_addr0+64, array512_2);
+        _mm512_storeu_si512(dst_addr0+96, array512_3);
+
+        array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_idx, array512_way1_0);
+        array512_1 = _mm512_permutex2var_epi64(array512_way0_1, permute_idx, array512_way1_1);
+        array512_2 = _mm512_permutex2var_epi64(array512_way0_2, permute_idx, array512_way1_2);
+        array512_3 = _mm512_permutex2var_epi64(array512_way0_3, permute_idx, array512_way1_3);
+        _mm512_storeu_si512(dst_addr0+128, array512_0);
+        _mm512_storeu_si512(dst_addr0+160, array512_1);
+        _mm512_storeu_si512(dst_addr0+192, array512_2);
+        _mm512_storeu_si512(dst_addr0+224, array512_3);
+    }
+}
+
+void COL_MAJOR_ONCOPY_KERNEL_4x32(BLASLONG k, bfloat16 * B, BLASLONG ldb, bfloat16 * block_B)
+{
+    BLASLONG tag_k_32x = k & (~31);
+
+    bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3;
+    bfloat16 * dst_addr0;
+
+    src_addr0 = B;
+    src_addr1 = src_addr0 + 1*ldb;
+    src_addr2 = src_addr0 + 2*ldb;
+    src_addr3 = src_addr0 + 3*ldb;
+    dst_addr0 = block_B;
+
+    __m512i array512_0, array512_1, array512_2, array512_3;
+    __m512i array512_way0_0, array512_way0_1, array512_way0_2, array512_way0_3;
+
+    for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
+        array512_0 = _mm512_loadu_si512(src_addr0+idx_k);
+        array512_1 = _mm512_loadu_si512(src_addr1+idx_k);
+        array512_2 = _mm512_loadu_si512(src_addr2+idx_k);
+        array512_3 = _mm512_loadu_si512(src_addr3+idx_k);
+
+        array512_way0_0 = _mm512_unpacklo_epi32(array512_0, array512_1);
+        array512_way0_1 = _mm512_unpackhi_epi32(array512_0, array512_1);
+        array512_way0_2 = _mm512_unpacklo_epi32(array512_2, array512_3);
+        array512_way0_3 = _mm512_unpackhi_epi32(array512_2, array512_3);
+
+        array512_0 = _mm512_unpacklo_epi64(array512_way0_0, array512_way0_2);
+        array512_1 = _mm512_unpackhi_epi64(array512_way0_0, array512_way0_2);
+        array512_2 = _mm512_unpacklo_epi64(array512_way0_1, array512_way0_3);
+        array512_3 = _mm512_unpackhi_epi64(array512_way0_1, array512_way0_3);
+
+        array512_way0_0 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x88);
+        array512_way0_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0xdd);
+        array512_way0_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x88);
+        array512_way0_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0xdd);
+
+        array512_0 = _mm512_shuffle_i32x4(array512_way0_0, array512_way0_1, 0x88);
+        array512_1 = _mm512_shuffle_i32x4(array512_way0_2, array512_way0_3, 0x88);
+        array512_2 = _mm512_shuffle_i32x4(array512_way0_0, array512_way0_1, 0xdd);
+        array512_3 = _mm512_shuffle_i32x4(array512_way0_2, array512_way0_3, 0xdd);
+
+        _mm512_storeu_si512(dst_addr0,    array512_0);
+        _mm512_storeu_si512(dst_addr0+32, array512_1);
+        _mm512_storeu_si512(dst_addr0+64, array512_2);
+        _mm512_storeu_si512(dst_addr0+96, array512_3);
+
+        dst_addr0 += 128;
     }
 
-#ifdef DEBUG_PROFILE
-   print_block(BF16_BLOCK_THRES_N, BF16_BLOCK_THRES_K, block_B);
-#endif
+    if (tag_k_32x != k) {
+        unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(k-tag_k_32x)));
+        __mmask32 tail_mask = *((__mmask32*) &tail_mask_value);
+        array512_0 = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x);
+        array512_1 = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x);
+        array512_2 = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x);
+        array512_3 = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x);
+
+        array512_way0_0 = _mm512_unpacklo_epi32(array512_0, array512_1);
+        array512_way0_1 = _mm512_unpackhi_epi32(array512_0, array512_1);
+        array512_way0_2 = _mm512_unpacklo_epi32(array512_2, array512_3);
+        array512_way0_3 = _mm512_unpackhi_epi32(array512_2, array512_3);
+
+        array512_0 = _mm512_unpacklo_epi64(array512_way0_0, array512_way0_2);
+        array512_1 = _mm512_unpackhi_epi64(array512_way0_0, array512_way0_2);
+        array512_2 = _mm512_unpacklo_epi64(array512_way0_1, array512_way0_3);
+        array512_3 = _mm512_unpackhi_epi64(array512_way0_1, array512_way0_3);
+
+        array512_way0_0 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x88);
+        array512_way0_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0xdd);
+        array512_way0_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x88);
+        array512_way0_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0xdd);
+
+        array512_0 = _mm512_shuffle_i32x4(array512_way0_0, array512_way0_1, 0x88);
+        array512_1 = _mm512_shuffle_i32x4(array512_way0_2, array512_way0_3, 0x88);
+        array512_2 = _mm512_shuffle_i32x4(array512_way0_0, array512_way0_1, 0xdd);
+        array512_3 = _mm512_shuffle_i32x4(array512_way0_2, array512_way0_3, 0xdd);
+
+        _mm512_storeu_si512(dst_addr0,    array512_0);
+        _mm512_storeu_si512(dst_addr0+32, array512_1);
+        _mm512_storeu_si512(dst_addr0+64, array512_2);
+        _mm512_storeu_si512(dst_addr0+96, array512_3);
+    }
 }
 
 void COL_MAJOR_ONCOPY_KERNEL_Nx32(BLASLONG n, BLASLONG k, bfloat16 * B, BLASLONG ldb, bfloat16 * block_B)
 {
     BLASLONG tag_k_32x = k & (~31);
     BLASLONG tag_n_2x  = n & (~1);
-    BLASLONG idx_src_base0;
-    BLASLONG idx_target_base0;
+
+    bfloat16 * src_addr0;
+    bfloat16 * dst_addr0;
 
     BLASLONG LDB_2x = 2*ldb;
 
-    idx_target_base0 = 0;
+    src_addr0 = B;
+    dst_addr0 = block_B;
 
     for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
-        idx_src_base0 = 0;
+        src_addr0 = B;
         for (BLASLONG idx_n = 0; idx_n < tag_n_2x; idx_n += 2) {
-            _mm512_storeu_si512(&block_B[idx_target_base0+ 32*0], _mm512_loadu_si512(&B[idx_src_base0       + idx_k]));
-            _mm512_storeu_si512(&block_B[idx_target_base0+ 32*1], _mm512_loadu_si512(&B[idx_src_base0 + ldb + idx_k]));
-            idx_src_base0 += LDB_2x;
-            idx_target_base0 += 64;
+            _mm512_storeu_si512(dst_addr0,      _mm512_loadu_si512(src_addr0 + idx_k));
+            _mm512_storeu_si512(dst_addr0 + 32, _mm512_loadu_si512(src_addr0 + ldb + idx_k));
+            src_addr0 += LDB_2x;
+            dst_addr0 += 64;
         }
         
         if (tag_n_2x != n) {
-            _mm512_storeu_si512(&block_B[idx_target_base0],  _mm512_loadu_si512(&B[idx_src_base0       + idx_k]));
-            idx_target_base0 += 32;
+            _mm512_storeu_si512(dst_addr0,  _mm512_loadu_si512(src_addr0 + idx_k));
+            dst_addr0 += 32;
         }
     }
 
     if (tag_k_32x != k) {
         unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(k-tag_k_32x)));
         __mmask32 tail_mask = *((__mmask32*) &tail_mask_value);
-        idx_src_base0 = 0;
+        src_addr0 = B;
         for (BLASLONG idx_n = 0; idx_n < tag_n_2x; idx_n += 2) {
-            _mm512_storeu_si512(&block_B[idx_target_base0+ 32*0], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base0       + tag_k_32x]));
-            _mm512_storeu_si512(&block_B[idx_target_base0+ 32*1], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base0 + ldb + tag_k_32x]));
-            idx_src_base0 += LDB_2x;
-            idx_target_base0 += 64;
+            _mm512_storeu_si512(dst_addr0,      _mm512_maskz_loadu_epi16(tail_mask, src_addr0 + tag_k_32x));
+            _mm512_storeu_si512(dst_addr0 + 32, _mm512_maskz_loadu_epi16(tail_mask, src_addr0 + ldb + tag_k_32x));
+            src_addr0 += LDB_2x;
+            dst_addr0 += 64;
         }
         
         if (tag_n_2x != n) {
-            _mm512_storeu_si512(&block_B[idx_target_base0],  _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base0       + tag_k_32x]));
+            _mm512_storeu_si512(dst_addr0,  _mm512_maskz_loadu_epi16(tail_mask, src_addr0 + tag_k_32x));
         }
     }
+}
+
+void COL_MAJOR_OTCOPY_KERNEL_Kx8(BLASLONG k, bfloat16 * B, BLASLONG ldb, bfloat16 * block_B)
+{
+    BLASLONG tag_k_2x = k & (~1);
+    unsigned char tail_mask_value = (unsigned char) 0xff;
+    __mmask8 tail_mask = *((__mmask8*) &tail_mask_value);
 
-#ifdef DEBUG_PROFILE
-   print_block(BF16_BLOCK_THRES_N, BF16_BLOCK_THRES_K, block_B);
-#endif
+    __m128i array128_0, array128_1, array128_2, array128_3;
+
+    BLASLONG idx_src_base0, idx_src_base1;
+    BLASLONG idx_target_base0, idx_target_base1;
+
+    BLASLONG LDA_2x = 2*ldb;
+    BLASLONG BF16_BLOCK_T_M_2x = 2*8;
+    idx_src_base0 = 0;
+    idx_src_base1 = ldb;
+    idx_target_base0 = 0;
+    idx_target_base1 = 8;
+    for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) {
+        array128_0 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base0]);
+        array128_1 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base1]);
+        array128_2 = _mm_unpacklo_epi16(array128_0, array128_1);
+        array128_3 = _mm_unpackhi_epi16(array128_0, array128_1);
+        _mm_storeu_epi32(&block_B[idx_target_base0], array128_2);
+        _mm_storeu_epi32(&block_B[idx_target_base1], array128_3);
+
+        idx_src_base0 += LDA_2x;
+        idx_src_base1 += LDA_2x;
+        idx_target_base0 += BF16_BLOCK_T_M_2x;
+        idx_target_base1 += BF16_BLOCK_T_M_2x;
+    }
+
+    if (tag_k_2x != k) {
+        __m128i ZERO128 = _mm_setzero_si128();
+        array128_0 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base0]);
+        array128_2 = _mm_unpacklo_epi16(array128_0, ZERO128);
+        array128_3 = _mm_unpackhi_epi16(array128_0, ZERO128);
+        _mm_storeu_epi32(&block_B[idx_target_base0], array128_2);
+        _mm_storeu_epi32(&block_B[idx_target_base1], array128_3);
+   }
+}
+
+void COL_MAJOR_OTCOPY_KERNEL_Kx8m(BLASLONG k, BLASLONG n, bfloat16 * B, BLASLONG ldb, bfloat16 * block_B)
+{
+    BLASLONG tag_k_2x = k & (~1);
+    unsigned char tail_mask = (((unsigned char)0xff) >> (8-n));
+
+    __m128i array128_0, array128_1, array128_2, array128_3;
+
+    BLASLONG idx_src_base0, idx_src_base1;
+    BLASLONG idx_target_base0, idx_target_base1;
+
+    BLASLONG LDA_2x = 2*ldb;
+    BLASLONG BF16_BLOCK_T_M_2x = 2*8;
+    idx_src_base0 = 0;
+    idx_src_base1 = ldb;
+    idx_target_base0 = 0;
+    idx_target_base1 = 8;
+    for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) {
+        array128_0 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base0]);
+        array128_1 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base1]);
+        array128_2 = _mm_unpacklo_epi16(array128_0, array128_1);
+        array128_3 = _mm_unpackhi_epi16(array128_0, array128_1);
+        _mm_storeu_epi32(&block_B[idx_target_base0], array128_2);
+        _mm_storeu_epi32(&block_B[idx_target_base1], array128_3);
+
+        idx_src_base0 += LDA_2x;
+        idx_src_base1 += LDA_2x;
+        idx_target_base0 += BF16_BLOCK_T_M_2x;
+        idx_target_base1 += BF16_BLOCK_T_M_2x;
+    }
+
+    if (tag_k_2x != k) {
+        __m128i ZERO128 = _mm_setzero_si128();
+        array128_0 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base0]);
+        array128_2 = _mm_unpacklo_epi16(array128_0, ZERO128);
+        array128_3 = _mm_unpackhi_epi16(array128_0, ZERO128);
+        _mm_storeu_epi32(&block_B[idx_target_base0], array128_2);
+        _mm_storeu_epi32(&block_B[idx_target_base1], array128_3);
+   }
 }
 
-// Scale matrix C while beta is not ZERO or ONE
+// Scale matrix C when beta is not ZERO or ONE
 void sbgemm_scal_operation(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc)
 {
-    BLASLONG tag_n_Nx = N & (~3);
-    BLASLONG tag_n_Mx = M & (~15);
+    float * C_addr0 = C;
+    float * C_addr1 = C + ldc;
+    float * C_addr2 = C + ldc*2;
+    float * C_addr3 = C + ldc*3;
 
     BLASLONG LDC4x = ldc*4;
-    BLASLONG idx_base_0 = 0;
-    BLASLONG idx_base_1 = ldc;
-    BLASLONG idx_base_2 = ldc*2;
-    BLASLONG idx_base_3 = ldc*3;
-
-    unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-M+tag_n_Mx));
-    __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
 
     __m512 array_512_0, array_512_1, array_512_2, array_512_3;
+    __m512 BETAVECTOR  = _mm512_set1_ps(beta);
 
-    __m512  BETAVECTOR  = _mm512_set1_ps(beta);
+    if (Order == CblasRowMajor) {
+        blasint tmp = M;
+        M = N;
+        N = tmp;
+    }
 
-    if (Order == CblasColMajor) {
-        for (BLASLONG idx_n = 0; idx_n < tag_n_Nx; idx_n += 4) {
-            for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) {
-                array_512_0 = _mm512_loadu_ps(&C[idx_base_0+idx_m]);
-                array_512_1 = _mm512_loadu_ps(&C[idx_base_1+idx_m]);
-                array_512_2 = _mm512_loadu_ps(&C[idx_base_2+idx_m]);
-                array_512_3 = _mm512_loadu_ps(&C[idx_base_3+idx_m]);
+    BLASLONG tag_n_Nx = N & (~3);
+    BLASLONG tag_n_Mx = M & (~15);
+    unsigned short tail_mask = (((unsigned short)0xffff) >> (16-M+tag_n_Mx));
+    for (BLASLONG idx_n = 0; idx_n < tag_n_Nx; idx_n += 4) {
+        for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) {
+            array_512_0 = _mm512_loadu_ps(C_addr0 + idx_m);
+            array_512_1 = _mm512_loadu_ps(C_addr1 + idx_m);
+            array_512_2 = _mm512_loadu_ps(C_addr2 + idx_m);
+            array_512_3 = _mm512_loadu_ps(C_addr3 + idx_m);
+
+            array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0);
+            array_512_1 = _mm512_mul_ps(BETAVECTOR, array_512_1);
+            array_512_2 = _mm512_mul_ps(BETAVECTOR, array_512_2);
+            array_512_3 = _mm512_mul_ps(BETAVECTOR, array_512_3);
+
+            _mm512_storeu_ps(C_addr0 + idx_m, array_512_0);
+            _mm512_storeu_ps(C_addr1 + idx_m, array_512_1);
+            _mm512_storeu_ps(C_addr2 + idx_m, array_512_2);
+            _mm512_storeu_ps(C_addr3 + idx_m, array_512_3);
+        }
 
-                array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0);
-                array_512_1 = _mm512_mul_ps(BETAVECTOR, array_512_1);
-                array_512_2 = _mm512_mul_ps(BETAVECTOR, array_512_2);
-                array_512_3 = _mm512_mul_ps(BETAVECTOR, array_512_3);
-
-                _mm512_storeu_ps(&C[idx_base_0+idx_m], array_512_0);
-                _mm512_storeu_ps(&C[idx_base_1+idx_m], array_512_1);
-                _mm512_storeu_ps(&C[idx_base_2+idx_m], array_512_2);
-                _mm512_storeu_ps(&C[idx_base_3+idx_m], array_512_3);
-            }
+        if (tag_n_Mx != M) {
+            array_512_0 = _mm512_maskz_loadu_ps(tail_mask, C_addr0 + tag_n_Mx);
+            array_512_1 = _mm512_maskz_loadu_ps(tail_mask, C_addr1 + tag_n_Mx);
+            array_512_2 = _mm512_maskz_loadu_ps(tail_mask, C_addr2 + tag_n_Mx);
+            array_512_3 = _mm512_maskz_loadu_ps(tail_mask, C_addr3 + tag_n_Mx);
+
+            array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0);
+            array_512_1 = _mm512_mul_ps(BETAVECTOR, array_512_1);
+            array_512_2 = _mm512_mul_ps(BETAVECTOR, array_512_2);
+            array_512_3 = _mm512_mul_ps(BETAVECTOR, array_512_3);
+
+            _mm512_mask_storeu_ps(C_addr0 + tag_n_Mx, tail_mask, array_512_0);
+            _mm512_mask_storeu_ps(C_addr1 + tag_n_Mx, tail_mask, array_512_1);
+            _mm512_mask_storeu_ps(C_addr2 + tag_n_Mx, tail_mask, array_512_2);
+            _mm512_mask_storeu_ps(C_addr3 + tag_n_Mx, tail_mask, array_512_3);
+        }
 
-            if (tag_n_Mx != M) {
-                array_512_0 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_0+tag_n_Mx]);
-                array_512_1 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_1+tag_n_Mx]);
-                array_512_2 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_2+tag_n_Mx]);
-                array_512_3 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_3+tag_n_Mx]);
+        C_addr0 += LDC4x;
+        C_addr1 += LDC4x;
+        C_addr2 += LDC4x;
+        C_addr3 += LDC4x;
+    }
 
+    if (tag_n_Nx != N) {
+        for (BLASLONG idx_n = tag_n_Nx; idx_n < N; idx_n++) {
+            for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) {
+                array_512_0 = _mm512_loadu_ps(C_addr0 + idx_m);
                 array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0);
-                array_512_1 = _mm512_mul_ps(BETAVECTOR, array_512_1);
-                array_512_2 = _mm512_mul_ps(BETAVECTOR, array_512_2);
-                array_512_3 = _mm512_mul_ps(BETAVECTOR, array_512_3);
-
-                _mm512_mask_storeu_ps(&C[idx_base_0+tag_n_Mx], tail_mask, array_512_0);
-                _mm512_mask_storeu_ps(&C[idx_base_1+tag_n_Mx], tail_mask, array_512_1);
-                _mm512_mask_storeu_ps(&C[idx_base_2+tag_n_Mx], tail_mask, array_512_2);
-                _mm512_mask_storeu_ps(&C[idx_base_3+tag_n_Mx], tail_mask, array_512_3);
+                _mm512_storeu_ps(C_addr0 + idx_m, array_512_0);
             }
 
-            idx_base_0 += LDC4x;
-            idx_base_1 += LDC4x;
-            idx_base_2 += LDC4x;
-            idx_base_3 += LDC4x;
-        }
-
-        if (tag_n_Nx != N) {
-            for (BLASLONG idx_n = tag_n_Nx; idx_n < N; idx_n++) {
-                for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) {
-                    array_512_0 = _mm512_loadu_ps(&C[idx_base_0+idx_m]);
-                    array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0);
-                    _mm512_storeu_ps(&C[idx_base_0+idx_m], array_512_0);
-                }
-
-                if (tag_n_Mx != M) {
-                    array_512_0 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_0+tag_n_Mx]);
-                    array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0);
-                    _mm512_mask_storeu_ps(&C[idx_base_0+tag_n_Mx], tail_mask, array_512_0);
-                }
-                idx_base_0 += ldc;
+            if (tag_n_Mx != M) {
+                array_512_0 = _mm512_maskz_loadu_ps(tail_mask, C_addr0 + tag_n_Mx);
+                array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0);
+                _mm512_mask_storeu_ps(C_addr0 + tag_n_Mx, tail_mask, array_512_0);
             }
+            C_addr0 += ldc;
         }
-    } else {
-
     }
 }
 
-// Scale matrix C while beta is not ZERO or ONE
+// Zero C matrix when Beta is 0
 void sbgemm_zero_operation(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, float *C, OPENBLAS_CONST blasint ldc)
 {
-    BLASLONG tag_n_Nx = N & (~3);
-    BLASLONG tag_n_Mx = M & (~15);
+    float * C_addr0 = C;
+    float * C_addr1 = C + ldc;
+    float * C_addr2 = C + ldc*2;
+    float * C_addr3 = C + ldc*3;
 
     BLASLONG LDC4x = ldc*4;
-    BLASLONG idx_base_0 = 0;
-    BLASLONG idx_base_1 = ldc;
-    BLASLONG idx_base_2 = ldc*2;
-    BLASLONG idx_base_3 = ldc*3;
-
-    unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-M+tag_n_Mx));
-    __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
 
     __m512  ZEROVECTOR  = _mm512_setzero_ps();
 
-    if (Order == CblasColMajor) {
-        for (BLASLONG idx_n = 0; idx_n < tag_n_Nx; idx_n += 4) {
-            for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) {
-                _mm512_storeu_ps(&C[idx_base_0+idx_m], ZEROVECTOR);
-                _mm512_storeu_ps(&C[idx_base_1+idx_m], ZEROVECTOR);
-                _mm512_storeu_ps(&C[idx_base_2+idx_m], ZEROVECTOR);
-                _mm512_storeu_ps(&C[idx_base_3+idx_m], ZEROVECTOR);
-            }
+    if (Order == CblasRowMajor) {
+        blasint tmp = M;
+        M = N;
+        N = tmp;
+    }
 
-            if (tag_n_Mx != M) {
-                _mm512_mask_storeu_ps(&C[idx_base_0+tag_n_Mx], tail_mask, ZEROVECTOR);
-                _mm512_mask_storeu_ps(&C[idx_base_1+tag_n_Mx], tail_mask, ZEROVECTOR);
-                _mm512_mask_storeu_ps(&C[idx_base_2+tag_n_Mx], tail_mask, ZEROVECTOR);
-                _mm512_mask_storeu_ps(&C[idx_base_3+tag_n_Mx], tail_mask, ZEROVECTOR);
-            }
+    BLASLONG tag_n_Nx = N & (~3);
+    BLASLONG tag_n_Mx = M & (~15);
+    unsigned short tail_mask = (((unsigned short)0xffff) >> (16-M+tag_n_Mx));
+    for (BLASLONG idx_n = 0; idx_n < tag_n_Nx; idx_n += 4) {
+        for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) {
+            _mm512_storeu_ps(C_addr0 + idx_m, ZEROVECTOR);
+            _mm512_storeu_ps(C_addr1 + idx_m, ZEROVECTOR);
+            _mm512_storeu_ps(C_addr2 + idx_m, ZEROVECTOR);
+            _mm512_storeu_ps(C_addr3 + idx_m, ZEROVECTOR);
+        }
 
-            idx_base_0 += LDC4x;
-            idx_base_1 += LDC4x;
-            idx_base_2 += LDC4x;
-            idx_base_3 += LDC4x;
+        if (tag_n_Mx != M) {
+            _mm512_mask_storeu_ps(C_addr0 + tag_n_Mx, tail_mask, ZEROVECTOR);
+            _mm512_mask_storeu_ps(C_addr1 + tag_n_Mx, tail_mask, ZEROVECTOR);
+            _mm512_mask_storeu_ps(C_addr2 + tag_n_Mx, tail_mask, ZEROVECTOR);
+            _mm512_mask_storeu_ps(C_addr3 + tag_n_Mx, tail_mask, ZEROVECTOR);
         }
 
-        if (tag_n_Nx != N) {
-            for (BLASLONG idx_n = tag_n_Nx; idx_n < N; idx_n++) {
-                for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) {
-                    _mm512_storeu_ps(&C[idx_base_0+idx_m], ZEROVECTOR);
-                }
+        C_addr0 += LDC4x;
+        C_addr1 += LDC4x;
+        C_addr2 += LDC4x;
+        C_addr3 += LDC4x;
+    }
 
-                if (tag_n_Mx != M) {
-                    _mm512_mask_storeu_ps(&C[idx_base_0+tag_n_Mx], tail_mask, ZEROVECTOR);
-                }
-                idx_base_0 += ldc;
+    if (tag_n_Nx != N) {
+        for (BLASLONG idx_n = tag_n_Nx; idx_n < N; idx_n++) {
+            for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) {
+                _mm512_storeu_ps(C_addr0 + idx_m, ZEROVECTOR);
             }
-        }
-    } else {
 
+            if (tag_n_Mx != M) {
+                _mm512_mask_storeu_ps(C_addr0 + tag_n_Mx, tail_mask, ZEROVECTOR);
+            }
+            C_addr0 += ldc;
+        }
     }
-}
\ No newline at end of file
+}
index dd4cb44..c715958 100644 (file)
 #include "bf16_common_macros.h"
 #include <immintrin.h>
 
+/*  These macros are needed and should be placed at the right place
+#define BF16_BLOCK_STEP_N 8
+#define BF16_BLOCK_THRES_K 1024
+#define BF16_BLOCK_THRES_M 32
+#define BF16_BLOCK_THRES_N 1024
+
+#define A(i,j) A[(i)*lda+(j)]
+#define B(i,j) B[(i)*ldb+(j)]
+#define C(i,j) C[(i)*ldc+(j)]
+
+#define ONE  1.e0f
+#define ZERO  0.e0f
+*/
+
 #undef STORE16_COMPLETE_RESULT
 #undef STORE16_MASK_COMPLETE_RESULT
-#undef SBGEMM_BLOCK_KERNEL_32x8x32
-#undef SBGEMM_BLOCK_KERNEL_16x8x32
-#undef SBGEMM_BLOCK_KERNEL_32xNx32
-#undef SBGEMM_BLOCK_KERNEL_16xNx32
-#undef SBGEMM_BLOCKING_KERNEL_2
+#undef SBGEMM_BLOCK_KERNEL_NN_32x8xK
+#undef SBGEMM_BLOCK_KERNEL_NN_16x8xK
+#undef SBGEMM_BLOCK_KERNEL_NN_32xNx32
+#undef SBGEMM_BLOCK_KERNEL_NN_16xNx32
+#undef SBGEMM_BLOCK_KERNEL_NT_32x8xK
+#undef SBGEMM_BLOCK_KERNEL_NT_16x8xK
+#undef SBGEMM_BLOCK_KERNEL_NT_32xNxK
+#undef SBGEMM_BLOCK_KERNEL_NT_16xNxK
+#undef SBGEMM_BLOCK_KERNEL_TN_32x8xK
+#undef SBGEMM_BLOCK_KERNEL_TN_16x8xK
+#undef SBGEMM_BLOCK_KERNEL_TN_32xNx32
+#undef SBGEMM_BLOCK_KERNEL_TN_16xNx32
+#undef SBGEMM_BLOCK_KERNEL_TT_32x8xK
+#undef SBGEMM_BLOCK_KERNEL_TT_16x8xK
+#undef SBGEMM_BLOCK_KERNEL_TT_32xNxK
+#undef SBGEMM_BLOCK_KERNEL_TT_16xNxK
+#undef SBGEMM_BLOCKING_KERNEL_NN
+#undef SBGEMM_BLOCKING_KERNEL_NT
+#undef SBGEMM_BLOCKING_KERNEL_TN
+#undef SBGEMM_BLOCKING_KERNEL_TT
 
 #ifndef ONE_ALPHA      // ALPHA is not ONE
-    #define STORE16_COMPLETE_RESULT       STORE16_COMPLETE_RESULT_ALPHA_ONE
-    #define STORE16_MASK_COMPLETE_RESULT  STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE
-    #define SBGEMM_BLOCK_KERNEL_32x8x32   sbgemm_block_kernel_32x8x32_alpha
-    #define SBGEMM_BLOCK_KERNEL_16x8x32   sbgemm_block_kernel_16x8x32_alpha
-    #define SBGEMM_BLOCK_KERNEL_32xNx32   sbgemm_block_kernel_32xNx32_alpha
-    #define SBGEMM_BLOCK_KERNEL_16xNx32   sbgemm_block_kernel_16xNx32_alpha
-    #define SBGEMM_BLOCKING_KERNEL_2      sbgemm_blocking_kernel_2_alpha
+    #define STORE16_COMPLETE_RESULT          STORE16_COMPLETE_RESULT_ALPHA_ONE
+    #define STORE16_MASK_COMPLETE_RESULT     STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE
+
+    #define SBGEMM_BLOCK_KERNEL_NN_32x8xK    sbgemm_block_kernel_nn_32x8xK_alpha
+    #define SBGEMM_BLOCK_KERNEL_NN_16x8xK    sbgemm_block_kernel_nn_16x8xK_alpha
+    #define SBGEMM_BLOCK_KERNEL_NN_32xNx32   sbgemm_block_kernel_nn_32xNx32_alpha
+    #define SBGEMM_BLOCK_KERNEL_NN_16xNx32   sbgemm_block_kernel_nn_16xNx32_alpha
+
+    #define SBGEMM_BLOCK_KERNEL_NT_32x8xK    SBGEMM_BLOCK_KERNEL_NN_32x8xK
+    #define SBGEMM_BLOCK_KERNEL_NT_16x8xK    SBGEMM_BLOCK_KERNEL_NN_16x8xK
+    #define SBGEMM_BLOCK_KERNEL_NT_32xNxK    sbgemm_block_kernel_nt_32xNxK_alpha
+    #define SBGEMM_BLOCK_KERNEL_NT_16xNxK    sbgemm_block_kernel_nt_16xNxK_alpha
+
+    #define SBGEMM_BLOCK_KERNEL_TN_32x8xK    sbgemm_block_kernel_tn_32x8xK_alpha
+    #define SBGEMM_BLOCK_KERNEL_TN_16x8xK    sbgemm_block_kernel_tn_16x8xK_alpha
+    #define SBGEMM_BLOCK_KERNEL_TN_32xNx32   sbgemm_block_kernel_tn_32xNx32_alpha
+    #define SBGEMM_BLOCK_KERNEL_TN_16xNx32   sbgemm_block_kernel_tn_16xNx32_alpha
+
+    #define SBGEMM_BLOCK_KERNEL_TT_32x8xK    SBGEMM_BLOCK_KERNEL_TN_32x8xK
+    #define SBGEMM_BLOCK_KERNEL_TT_16x8xK    SBGEMM_BLOCK_KERNEL_TN_16x8xK
+    #define SBGEMM_BLOCK_KERNEL_TT_32xNxK    sbgemm_block_kernel_tt_32xNxK_alpha
+    #define SBGEMM_BLOCK_KERNEL_TT_16xNxK    sbgemm_block_kernel_tt_16xNxK_alpha
+
+    #define SBGEMM_BLOCKING_KERNEL_NN        sbgemm_blocking_kernel_nn_alpha
+    #define SBGEMM_BLOCKING_KERNEL_NT        sbgemm_blocking_kernel_nt_alpha
+    #define SBGEMM_BLOCKING_KERNEL_TN        sbgemm_blocking_kernel_tn_alpha
+    #define SBGEMM_BLOCKING_KERNEL_TT        sbgemm_blocking_kernel_tt_alpha
 #else                  // ALPHA is ONE
-    #define STORE16_COMPLETE_RESULT       STORE16_COMPLETE_RESULT_ONE_ONE
-    #define STORE16_MASK_COMPLETE_RESULT  STORE16_MASK_COMPLETE_RESULT_ONE_ONE
-    #define SBGEMM_BLOCK_KERNEL_32x8x32   sbgemm_block_kernel_32x8x32_one
-    #define SBGEMM_BLOCK_KERNEL_16x8x32   sbgemm_block_kernel_16x8x32_one
-    #define SBGEMM_BLOCK_KERNEL_32xNx32   sbgemm_block_kernel_32xNx32_one
-    #define SBGEMM_BLOCK_KERNEL_16xNx32   sbgemm_block_kernel_16xNx32_one
-    #define SBGEMM_BLOCKING_KERNEL_2      sbgemm_blocking_kernel_2_one
+    #define STORE16_COMPLETE_RESULT          STORE16_COMPLETE_RESULT_ONE_ONE
+    #define STORE16_MASK_COMPLETE_RESULT     STORE16_MASK_COMPLETE_RESULT_ONE_ONE
+
+    #define SBGEMM_BLOCK_KERNEL_NN_32x8xK    sbgemm_block_kernel_nn_32x8xK_one
+    #define SBGEMM_BLOCK_KERNEL_NN_16x8xK    sbgemm_block_kernel_nn_16x8xK_one
+    #define SBGEMM_BLOCK_KERNEL_NN_32xNx32   sbgemm_block_kernel_nn_32xNx32_one
+    #define SBGEMM_BLOCK_KERNEL_NN_16xNx32   sbgemm_block_kernel_nn_16xNx32_one
+
+    #define SBGEMM_BLOCK_KERNEL_NT_32x8xK    SBGEMM_BLOCK_KERNEL_NN_32x8xK
+    #define SBGEMM_BLOCK_KERNEL_NT_16x8xK    SBGEMM_BLOCK_KERNEL_NN_16x8xK
+    #define SBGEMM_BLOCK_KERNEL_NT_32xNxK    sbgemm_block_kernel_nt_32xNxK_one
+    #define SBGEMM_BLOCK_KERNEL_NT_16xNxK    sbgemm_block_kernel_nt_16xNxK_one
+
+    #define SBGEMM_BLOCK_KERNEL_TN_32x8xK    sbgemm_block_kernel_tn_32x8xK_one
+    #define SBGEMM_BLOCK_KERNEL_TN_16x8xK    sbgemm_block_kernel_tn_16x8xK_one
+    #define SBGEMM_BLOCK_KERNEL_TN_32xNx32   sbgemm_block_kernel_tn_32xNx32_one
+    #define SBGEMM_BLOCK_KERNEL_TN_16xNx32   sbgemm_block_kernel_tn_16xNx32_one
+
+    #define SBGEMM_BLOCK_KERNEL_TT_32x8xK    SBGEMM_BLOCK_KERNEL_TN_32x8xK
+    #define SBGEMM_BLOCK_KERNEL_TT_16x8xK    SBGEMM_BLOCK_KERNEL_TN_16x8xK
+    #define SBGEMM_BLOCK_KERNEL_TT_32xNxK    sbgemm_block_kernel_tt_32xNxK_one
+    #define SBGEMM_BLOCK_KERNEL_TT_16xNxK    sbgemm_block_kernel_tt_16xNxK_one
+
+    #define SBGEMM_BLOCKING_KERNEL_NN        sbgemm_blocking_kernel_nn_one
+    #define SBGEMM_BLOCKING_KERNEL_NT        sbgemm_blocking_kernel_nt_one
+    #define SBGEMM_BLOCKING_KERNEL_TN        sbgemm_blocking_kernel_tn_one
+    #define SBGEMM_BLOCKING_KERNEL_TT        sbgemm_blocking_kernel_tt_one
 #endif
 
+extern bfloat16 * block_A;
+extern bfloat16 * block_B;
 
+/* --------------------------------------------- NN kernels ------------------------------------------ */
 // SBGEMM Kernel for 16<M<=32, N=8, K can be any number, but the processing will take 32 as a base
 #ifndef ONE_ALPHA      // ALPHA is not ONE
-void sbgemm_block_kernel_32x8x32_alpha(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+void sbgemm_block_kernel_nn_32x8xK_alpha(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
 #else                  // ALPHA is ONE
-void sbgemm_block_kernel_32x8x32_one(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+void sbgemm_block_kernel_nn_32x8xK_one(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
 #endif
 {
-    int  SHUFFLE_MAGIC_NO = 0x39;
-    BLASLONG tag_k_32x = k & (~31);
-    BLASLONG idxA_base = 0;
-    BLASLONG idxB_base = 0;
-    BLASLONG width = 32;
+    bfloat16 * A_addr = A;
+    bfloat16 * B_addr = B;
+    float    * C_addr = C;
 
 #ifndef ONE_ALPHA
     __m512  ALPHAVECTOR = _mm512_set1_ps(alpha);
@@ -73,65 +143,42 @@ void sbgemm_block_kernel_32x8x32_one(BLASLONG m, BLASLONG k, float alpha, bfloat
     result_512_14 = _mm512_setzero_ps();
     result_512_15 = _mm512_setzero_ps();
 
-    for (BLASLONG idx_k = 0; idx_k < k; idx_k += 32) {
-        // Load B with unroll 8
-        idxB_base = idx_k << 3;
-        arrayB_512_0 = _mm512_loadu_si512(&B[idxB_base + 32*0]);
-        arrayB_512_1 = _mm512_loadu_si512(&B[idxB_base + 32*1]);
-        arrayB_512_2 = _mm512_loadu_si512(&B[idxB_base + 32*2]);
-        arrayB_512_3 = _mm512_loadu_si512(&B[idxB_base + 32*3]);
-        arrayB_512_4 = _mm512_loadu_si512(&B[idxB_base + 32*4]);
-        arrayB_512_5 = _mm512_loadu_si512(&B[idxB_base + 32*5]);
-        arrayB_512_6 = _mm512_loadu_si512(&B[idxB_base + 32*6]);
-        arrayB_512_7 = _mm512_loadu_si512(&B[idxB_base + 32*7]);
-
-        if (idx_k == tag_k_32x) {width = k - tag_k_32x;}
-
-        for (BLASLONG idx = 0; idx < width;) {
-            // Each two rows are a group for 32-pair bf16 elements
-            idxA_base = idx << 5;
-            arrayA_512_0 = _mm512_loadu_si512(&A[idxA_base]);
-            arrayA_512_1 = _mm512_loadu_si512(&A[idxA_base + 32]);
-
-            result_512_0  = _mm512_dpbf16_ps(result_512_0,  (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_0)));
-            result_512_1  = _mm512_dpbf16_ps(result_512_1,  (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_1)));
-            result_512_2  = _mm512_dpbf16_ps(result_512_2,  (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_2)));
-            result_512_3  = _mm512_dpbf16_ps(result_512_3,  (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_3)));
-            result_512_4  = _mm512_dpbf16_ps(result_512_4,  (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_4)));
-            result_512_5  = _mm512_dpbf16_ps(result_512_5,  (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_5)));
-            result_512_6  = _mm512_dpbf16_ps(result_512_6,  (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_6)));
-            result_512_7  = _mm512_dpbf16_ps(result_512_7,  (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_7)));
-            result_512_8  = _mm512_dpbf16_ps(result_512_8,  (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_0)));
-            result_512_9  = _mm512_dpbf16_ps(result_512_9,  (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_1)));
-            result_512_10 = _mm512_dpbf16_ps(result_512_10, (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_2)));
-            result_512_11 = _mm512_dpbf16_ps(result_512_11, (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_3)));
-            result_512_12 = _mm512_dpbf16_ps(result_512_12, (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_4)));
-            result_512_13 = _mm512_dpbf16_ps(result_512_13, (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_5)));
-            result_512_14 = _mm512_dpbf16_ps(result_512_14, (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_6)));
-            result_512_15 = _mm512_dpbf16_ps(result_512_15, (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_7)));
-
-            arrayB_512_0 = _mm512_shuffle_epi32(arrayB_512_0, SHUFFLE_MAGIC_NO);
-            arrayB_512_1 = _mm512_shuffle_epi32(arrayB_512_1, SHUFFLE_MAGIC_NO);
-            arrayB_512_2 = _mm512_shuffle_epi32(arrayB_512_2, SHUFFLE_MAGIC_NO);
-            arrayB_512_3 = _mm512_shuffle_epi32(arrayB_512_3, SHUFFLE_MAGIC_NO);
-            arrayB_512_4 = _mm512_shuffle_epi32(arrayB_512_4, SHUFFLE_MAGIC_NO);
-            arrayB_512_5 = _mm512_shuffle_epi32(arrayB_512_5, SHUFFLE_MAGIC_NO);
-            arrayB_512_6 = _mm512_shuffle_epi32(arrayB_512_6, SHUFFLE_MAGIC_NO);
-            arrayB_512_7 = _mm512_shuffle_epi32(arrayB_512_7, SHUFFLE_MAGIC_NO);
+    for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
+        // Each two rows are a group for 32-pair bf16 elements
+        arrayA_512_0 = _mm512_loadu_si512(A_addr);
+        arrayA_512_1 = _mm512_loadu_si512(A_addr + 32);
+
+        _MM512_BROADCASTD_EPI32(B_addr + 0,  arrayB_512_0);
+        _MM512_BROADCASTD_EPI32(B_addr + 2,  arrayB_512_1);
+        _MM512_BROADCASTD_EPI32(B_addr + 4,  arrayB_512_2);
+        _MM512_BROADCASTD_EPI32(B_addr + 6,  arrayB_512_3);
+        _MM512_BROADCASTD_EPI32(B_addr + 8,  arrayB_512_4);
+        _MM512_BROADCASTD_EPI32(B_addr + 10, arrayB_512_5);
+        _MM512_BROADCASTD_EPI32(B_addr + 12, arrayB_512_6);
+        _MM512_BROADCASTD_EPI32(B_addr + 14, arrayB_512_7);
+
+        result_512_0  = _mm512_dpbf16_ps(result_512_0,  (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_0);
+        result_512_1  = _mm512_dpbf16_ps(result_512_1,  (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_1);
+        result_512_2  = _mm512_dpbf16_ps(result_512_2,  (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_2);
+        result_512_3  = _mm512_dpbf16_ps(result_512_3,  (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_3);
+        result_512_4  = _mm512_dpbf16_ps(result_512_4,  (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_4);
+        result_512_5  = _mm512_dpbf16_ps(result_512_5,  (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_5);
+        result_512_6  = _mm512_dpbf16_ps(result_512_6,  (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_6);
+        result_512_7  = _mm512_dpbf16_ps(result_512_7,  (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_7);
+
+        result_512_8  = _mm512_dpbf16_ps(result_512_8,  (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_0);
+        result_512_9  = _mm512_dpbf16_ps(result_512_9,  (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_1);
+        result_512_10 = _mm512_dpbf16_ps(result_512_10, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_2);
+        result_512_11 = _mm512_dpbf16_ps(result_512_11, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_3);
+        result_512_12 = _mm512_dpbf16_ps(result_512_12, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_4);
+        result_512_13 = _mm512_dpbf16_ps(result_512_13, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_5);
+        result_512_14 = _mm512_dpbf16_ps(result_512_14, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_6);
+        result_512_15 = _mm512_dpbf16_ps(result_512_15, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_7);
 
-            idx += 2;
-            // Every 4 loops we need to switch to next 128 bits of arrayB registers
-            if ((idx & (~7)) == idx) {
-                arrayB_512_0 = _mm512_shuffle_i32x4(arrayB_512_0, arrayB_512_0, SHUFFLE_MAGIC_NO);
-                arrayB_512_1 = _mm512_shuffle_i32x4(arrayB_512_1, arrayB_512_1, SHUFFLE_MAGIC_NO);
-                arrayB_512_2 = _mm512_shuffle_i32x4(arrayB_512_2, arrayB_512_2, SHUFFLE_MAGIC_NO);
-                arrayB_512_3 = _mm512_shuffle_i32x4(arrayB_512_3, arrayB_512_3, SHUFFLE_MAGIC_NO);
-                arrayB_512_4 = _mm512_shuffle_i32x4(arrayB_512_4, arrayB_512_4, SHUFFLE_MAGIC_NO);
-                arrayB_512_5 = _mm512_shuffle_i32x4(arrayB_512_5, arrayB_512_5, SHUFFLE_MAGIC_NO);
-                arrayB_512_6 = _mm512_shuffle_i32x4(arrayB_512_6, arrayB_512_6, SHUFFLE_MAGIC_NO);
-                arrayB_512_7 = _mm512_shuffle_i32x4(arrayB_512_7, arrayB_512_7, SHUFFLE_MAGIC_NO);
-            }
-        }
+        // Load B with unroll 8
+        B_addr += 16;
+        // Load A with unroll 64
+        A_addr += 64;
     }
 
     if (m != 32) {
@@ -141,81 +188,80 @@ void sbgemm_block_kernel_32x8x32_one(BLASLONG m, BLASLONG k, float alpha, bfloat
         result_512_tmp_1 = _mm512_permutex2var_ps(result_512_0, shuffle_idx_base1, result_512_8);
         result_512_tmp_2 = _mm512_permutex2var_ps(result_512_1, shuffle_idx_base0, result_512_9);
         result_512_tmp_3 = _mm512_permutex2var_ps(result_512_1, shuffle_idx_base1, result_512_9);
-        STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*0]))
-        STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*0+16]), tail_mask)
-        STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*1]))
-        STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*1+16]), tail_mask)
+        STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr))
+        STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (C_addr + 16), tail_mask)
+        STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*1))
+        STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*1 + 16), tail_mask)
         result_512_tmp_0 = _mm512_permutex2var_ps(result_512_2, shuffle_idx_base0, result_512_10);
         result_512_tmp_1 = _mm512_permutex2var_ps(result_512_2, shuffle_idx_base1, result_512_10);
         result_512_tmp_2 = _mm512_permutex2var_ps(result_512_3, shuffle_idx_base0, result_512_11);
         result_512_tmp_3 = _mm512_permutex2var_ps(result_512_3, shuffle_idx_base1, result_512_11);
-        STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*2]))
-        STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*2+16]), tail_mask)
-        STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*3]))
-        STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*3+16]), tail_mask)
+        STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*2))
+        STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*2 + 16), tail_mask)
+        STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*3))
+        STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*3 + 16), tail_mask)
         result_512_tmp_0 = _mm512_permutex2var_ps(result_512_4, shuffle_idx_base0, result_512_12);
         result_512_tmp_1 = _mm512_permutex2var_ps(result_512_4, shuffle_idx_base1, result_512_12);
         result_512_tmp_2 = _mm512_permutex2var_ps(result_512_5, shuffle_idx_base0, result_512_13);
         result_512_tmp_3 = _mm512_permutex2var_ps(result_512_5, shuffle_idx_base1, result_512_13);
-        STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*4]))
-        STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*4+16]), tail_mask)
-        STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*5]))
-        STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*5+16]), tail_mask)
+        STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*4))
+        STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*4 + 16), tail_mask)
+        STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*5))
+        STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*5 + 16), tail_mask)
         result_512_tmp_0 = _mm512_permutex2var_ps(result_512_6, shuffle_idx_base0, result_512_14);
         result_512_tmp_1 = _mm512_permutex2var_ps(result_512_6, shuffle_idx_base1, result_512_14);
         result_512_tmp_2 = _mm512_permutex2var_ps(result_512_7, shuffle_idx_base0, result_512_15);
         result_512_tmp_3 = _mm512_permutex2var_ps(result_512_7, shuffle_idx_base1, result_512_15);
-        STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*6]))
-        STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*6+16]), tail_mask)
-        STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*7]))
-        STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*7+16]), tail_mask)
+        STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*6))
+        STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*6 + 16), tail_mask)
+        STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*7))
+        STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*7 + 16), tail_mask)
     } else {
         result_512_tmp_0 = _mm512_permutex2var_ps(result_512_0, shuffle_idx_base0, result_512_8);
         result_512_tmp_1 = _mm512_permutex2var_ps(result_512_0, shuffle_idx_base1, result_512_8);
         result_512_tmp_2 = _mm512_permutex2var_ps(result_512_1, shuffle_idx_base0, result_512_9);
         result_512_tmp_3 = _mm512_permutex2var_ps(result_512_1, shuffle_idx_base1, result_512_9);
-        STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*0]))
-        STORE16_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*0+16]))
-        STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*1]))
-        STORE16_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*1+16]))
+        STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr))
+        STORE16_COMPLETE_RESULT(result_512_tmp_1, (C_addr + 16))
+        STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*1))
+        STORE16_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*1 + 16))
         result_512_tmp_0 = _mm512_permutex2var_ps(result_512_2, shuffle_idx_base0, result_512_10);
         result_512_tmp_1 = _mm512_permutex2var_ps(result_512_2, shuffle_idx_base1, result_512_10);
         result_512_tmp_2 = _mm512_permutex2var_ps(result_512_3, shuffle_idx_base0, result_512_11);
         result_512_tmp_3 = _mm512_permutex2var_ps(result_512_3, shuffle_idx_base1, result_512_11);
-        STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*2]))
-        STORE16_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*2+16]))
-        STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*3]))
-        STORE16_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*3+16]))
+        STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*2))
+        STORE16_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*2 + 16))
+        STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*3))
+        STORE16_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*3 + 16))
         result_512_tmp_0 = _mm512_permutex2var_ps(result_512_4, shuffle_idx_base0, result_512_12);
         result_512_tmp_1 = _mm512_permutex2var_ps(result_512_4, shuffle_idx_base1, result_512_12);
         result_512_tmp_2 = _mm512_permutex2var_ps(result_512_5, shuffle_idx_base0, result_512_13);
         result_512_tmp_3 = _mm512_permutex2var_ps(result_512_5, shuffle_idx_base1, result_512_13);
-        STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*4]))
-        STORE16_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*4+16]))
-        STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*5]))
-        STORE16_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*5+16]))
+        STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*4))
+        STORE16_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*4 + 16))
+        STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*5))
+        STORE16_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*5 + 16))
         result_512_tmp_0 = _mm512_permutex2var_ps(result_512_6, shuffle_idx_base0, result_512_14);
         result_512_tmp_1 = _mm512_permutex2var_ps(result_512_6, shuffle_idx_base1, result_512_14);
         result_512_tmp_2 = _mm512_permutex2var_ps(result_512_7, shuffle_idx_base0, result_512_15);
         result_512_tmp_3 = _mm512_permutex2var_ps(result_512_7, shuffle_idx_base1, result_512_15);
-        STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*6]))
-        STORE16_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*6+16]))
-        STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*7]))
-        STORE16_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*7+16]))
+        STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*6))
+        STORE16_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*6 + 16))
+        STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*7))
+        STORE16_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*7 + 16))
     }
 }
 
-// SBGEMM Kernel for M<=16, N=8, K can be any number, but the processing will take 32 as a base
+// SBGEMM Kernel for M<=16, N=8, K can be any number
 #ifndef ONE_ALPHA      // ALPHA is not ONE
-void sbgemm_block_kernel_16x8x32_alpha(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+void sbgemm_block_kernel_nn_16x8xK_alpha(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
 #else                  // ALPHA is ONE
-void sbgemm_block_kernel_16x8x32_one(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+void sbgemm_block_kernel_nn_16x8xK_one(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
 #endif
 {
-    int  SHUFFLE_MAGIC_NO = 0x39;
-    BLASLONG tag_k_32x = k & (~31);
-    BLASLONG idxB_base = 0;
-    BLASLONG width = 32;
+    bfloat16 * A_addr = A;
+    bfloat16 * B_addr = B;
+    float    * C_addr = C;
 
 #ifndef ONE_ALPHA
     __m512  ALPHAVECTOR = _mm512_set1_ps(alpha);
@@ -234,110 +280,87 @@ void sbgemm_block_kernel_16x8x32_one(BLASLONG m, BLASLONG k, float alpha, bfloat
     result_512_6  = _mm512_setzero_ps();
     result_512_7  = _mm512_setzero_ps();
 
-    for (BLASLONG idx_k = 0; idx_k < k; idx_k += 32) {
-        // Load B with unroll 8
-        idxB_base = idx_k << 3;
-        arrayB_512_0 = _mm512_loadu_si512(&B[idxB_base + 32*0]);
-        arrayB_512_1 = _mm512_loadu_si512(&B[idxB_base + 32*1]);
-        arrayB_512_2 = _mm512_loadu_si512(&B[idxB_base + 32*2]);
-        arrayB_512_3 = _mm512_loadu_si512(&B[idxB_base + 32*3]);
-        arrayB_512_4 = _mm512_loadu_si512(&B[idxB_base + 32*4]);
-        arrayB_512_5 = _mm512_loadu_si512(&B[idxB_base + 32*5]);
-        arrayB_512_6 = _mm512_loadu_si512(&B[idxB_base + 32*6]);
-        arrayB_512_7 = _mm512_loadu_si512(&B[idxB_base + 32*7]);
-
-        if (idx_k == tag_k_32x) {width = k - tag_k_32x;}
-
-        for (BLASLONG idx = 0; idx < width;) {
-            // Each two rows are a group for 32-pair bf16 elements
-            // Load two rows into a 512 register
-            arrayA_512_0 = _mm512_loadu_si512(&A[idx<<4]);
-
-            result_512_0  = _mm512_dpbf16_ps(result_512_0,  (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_0)));
-            result_512_1  = _mm512_dpbf16_ps(result_512_1,  (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_1)));
-            result_512_2  = _mm512_dpbf16_ps(result_512_2,  (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_2)));
-            result_512_3  = _mm512_dpbf16_ps(result_512_3,  (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_3)));
-            result_512_4  = _mm512_dpbf16_ps(result_512_4,  (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_4)));
-            result_512_5  = _mm512_dpbf16_ps(result_512_5,  (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_5)));
-            result_512_6  = _mm512_dpbf16_ps(result_512_6,  (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_6)));
-            result_512_7  = _mm512_dpbf16_ps(result_512_7,  (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_7)));
-
-            arrayB_512_0 = _mm512_shuffle_epi32(arrayB_512_0, SHUFFLE_MAGIC_NO);
-            arrayB_512_1 = _mm512_shuffle_epi32(arrayB_512_1, SHUFFLE_MAGIC_NO);
-            arrayB_512_2 = _mm512_shuffle_epi32(arrayB_512_2, SHUFFLE_MAGIC_NO);
-            arrayB_512_3 = _mm512_shuffle_epi32(arrayB_512_3, SHUFFLE_MAGIC_NO);
-            arrayB_512_4 = _mm512_shuffle_epi32(arrayB_512_4, SHUFFLE_MAGIC_NO);
-            arrayB_512_5 = _mm512_shuffle_epi32(arrayB_512_5, SHUFFLE_MAGIC_NO);
-            arrayB_512_6 = _mm512_shuffle_epi32(arrayB_512_6, SHUFFLE_MAGIC_NO);
-            arrayB_512_7 = _mm512_shuffle_epi32(arrayB_512_7, SHUFFLE_MAGIC_NO);
+    for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
+        // Each two rows are a group for 32-pair bf16 elements
+        // Load two rows into a 512 register
+        arrayA_512_0 = _mm512_loadu_si512(A_addr);
+
+        _MM512_BROADCASTD_EPI32(B_addr + 0,  arrayB_512_0);
+        _MM512_BROADCASTD_EPI32(B_addr + 2,  arrayB_512_1);
+        _MM512_BROADCASTD_EPI32(B_addr + 4,  arrayB_512_2);
+        _MM512_BROADCASTD_EPI32(B_addr + 6,  arrayB_512_3);
+        _MM512_BROADCASTD_EPI32(B_addr + 8,  arrayB_512_4);
+        _MM512_BROADCASTD_EPI32(B_addr + 10, arrayB_512_5);
+        _MM512_BROADCASTD_EPI32(B_addr + 12, arrayB_512_6);
+        _MM512_BROADCASTD_EPI32(B_addr + 14, arrayB_512_7);
+
+        result_512_0  = _mm512_dpbf16_ps(result_512_0,  (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_0);
+        result_512_1  = _mm512_dpbf16_ps(result_512_1,  (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_1);
+        result_512_2  = _mm512_dpbf16_ps(result_512_2,  (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_2);
+        result_512_3  = _mm512_dpbf16_ps(result_512_3,  (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_3);
+        result_512_4  = _mm512_dpbf16_ps(result_512_4,  (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_4);
+        result_512_5  = _mm512_dpbf16_ps(result_512_5,  (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_5);
+        result_512_6  = _mm512_dpbf16_ps(result_512_6,  (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_6);
+        result_512_7  = _mm512_dpbf16_ps(result_512_7,  (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_7);
 
-            idx += 2;
-            // Every 4 loops we need to switch to next 128 bits of arrayB registers
-            if ((idx & (~7)) == idx) {
-                arrayB_512_0 = _mm512_shuffle_i32x4(arrayB_512_0, arrayB_512_0, SHUFFLE_MAGIC_NO);
-                arrayB_512_1 = _mm512_shuffle_i32x4(arrayB_512_1, arrayB_512_1, SHUFFLE_MAGIC_NO);
-                arrayB_512_2 = _mm512_shuffle_i32x4(arrayB_512_2, arrayB_512_2, SHUFFLE_MAGIC_NO);
-                arrayB_512_3 = _mm512_shuffle_i32x4(arrayB_512_3, arrayB_512_3, SHUFFLE_MAGIC_NO);
-                arrayB_512_4 = _mm512_shuffle_i32x4(arrayB_512_4, arrayB_512_4, SHUFFLE_MAGIC_NO);
-                arrayB_512_5 = _mm512_shuffle_i32x4(arrayB_512_5, arrayB_512_5, SHUFFLE_MAGIC_NO);
-                arrayB_512_6 = _mm512_shuffle_i32x4(arrayB_512_6, arrayB_512_6, SHUFFLE_MAGIC_NO);
-                arrayB_512_7 = _mm512_shuffle_i32x4(arrayB_512_7, arrayB_512_7, SHUFFLE_MAGIC_NO);
-            }
-        }
+        // Load B with unroll 8
+        B_addr += 16;
+        // Load A with unroll 16
+        A_addr += 32;
     }
 
     if (m != 16) {
-        unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-m));
-        __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
+        unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m));
 
         result_512_0 = _mm512_shuffle_f32x4(result_512_0, result_512_0, 0xd8);
         result_512_1 = _mm512_shuffle_f32x4(result_512_1, result_512_1, 0xd8);
         result_512_2 = _mm512_shuffle_f32x4(result_512_2, result_512_2, 0xd8);
         result_512_3 = _mm512_shuffle_f32x4(result_512_3, result_512_3, 0xd8);
-        STORE16_MASK_COMPLETE_RESULT(result_512_0, (&C[ldc*0]), tail_mask)
-        STORE16_MASK_COMPLETE_RESULT(result_512_1, (&C[ldc*1]), tail_mask)
-        STORE16_MASK_COMPLETE_RESULT(result_512_2, (&C[ldc*2]), tail_mask)
-        STORE16_MASK_COMPLETE_RESULT(result_512_3, (&C[ldc*3]), tail_mask)
+        STORE16_MASK_COMPLETE_RESULT(result_512_0, (C_addr), tail_mask)
+        STORE16_MASK_COMPLETE_RESULT(result_512_1, (C_addr + ldc*1), tail_mask)
+        STORE16_MASK_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2), tail_mask)
+        STORE16_MASK_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3), tail_mask)
         result_512_4 = _mm512_shuffle_f32x4(result_512_4, result_512_4, 0xd8);
         result_512_5 = _mm512_shuffle_f32x4(result_512_5, result_512_5, 0xd8);
         result_512_6 = _mm512_shuffle_f32x4(result_512_6, result_512_6, 0xd8);
         result_512_7 = _mm512_shuffle_f32x4(result_512_7, result_512_7, 0xd8);
-        STORE16_MASK_COMPLETE_RESULT(result_512_4, (&C[ldc*4]), tail_mask)
-        STORE16_MASK_COMPLETE_RESULT(result_512_5, (&C[ldc*5]), tail_mask)
-        STORE16_MASK_COMPLETE_RESULT(result_512_6, (&C[ldc*6]), tail_mask)
-        STORE16_MASK_COMPLETE_RESULT(result_512_7, (&C[ldc*7]), tail_mask)
+        STORE16_MASK_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4), tail_mask)
+        STORE16_MASK_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5), tail_mask)
+        STORE16_MASK_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6), tail_mask)
+        STORE16_MASK_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7), tail_mask)
     } else {
         result_512_0 = _mm512_shuffle_f32x4(result_512_0, result_512_0, 0xd8);
         result_512_1 = _mm512_shuffle_f32x4(result_512_1, result_512_1, 0xd8);
         result_512_2 = _mm512_shuffle_f32x4(result_512_2, result_512_2, 0xd8);
         result_512_3 = _mm512_shuffle_f32x4(result_512_3, result_512_3, 0xd8);
-        STORE16_COMPLETE_RESULT(result_512_0, (&C[ldc*0]))
-        STORE16_COMPLETE_RESULT(result_512_1, (&C[ldc*1]))
-        STORE16_COMPLETE_RESULT(result_512_2, (&C[ldc*2]))
-        STORE16_COMPLETE_RESULT(result_512_3, (&C[ldc*3]))
+        STORE16_COMPLETE_RESULT(result_512_0, (C_addr))
+        STORE16_COMPLETE_RESULT(result_512_1, (C_addr + ldc*1))
+        STORE16_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2))
+        STORE16_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3))
         result_512_4 = _mm512_shuffle_f32x4(result_512_4, result_512_4, 0xd8);
         result_512_5 = _mm512_shuffle_f32x4(result_512_5, result_512_5, 0xd8);
         result_512_6 = _mm512_shuffle_f32x4(result_512_6, result_512_6, 0xd8);
         result_512_7 = _mm512_shuffle_f32x4(result_512_7, result_512_7, 0xd8);
-        STORE16_COMPLETE_RESULT(result_512_4, (&C[ldc*4]))
-        STORE16_COMPLETE_RESULT(result_512_5, (&C[ldc*5]))
-        STORE16_COMPLETE_RESULT(result_512_6, (&C[ldc*6]))
-        STORE16_COMPLETE_RESULT(result_512_7, (&C[ldc*7]))
+        STORE16_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4))
+        STORE16_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5))
+        STORE16_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6))
+        STORE16_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7))
     }
 }
 
 // SBGEMM Kernel for 16<M<=32, N<8, K can be any number, but the processing will take 32 as a base
 #ifndef ONE_ALPHA      // ALPHA is not ONE
-void sbgemm_block_kernel_32xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+void sbgemm_block_kernel_nn_32xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
 #else                  // ALPHA is ONE
-void sbgemm_block_kernel_32xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+void sbgemm_block_kernel_nn_32xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
 #endif
 {
+    bfloat16 * A_addr = A;
+    bfloat16 * B_addr = B;
+    float    * C_addr = C;
+
     int  SHUFFLE_MAGIC_NO = 0x39;
     BLASLONG tag_k_32x = k & (~31);
-    BLASLONG idxA_base = 0;
-    BLASLONG idxB_base = 0;
-    BLASLONG width = 32;
 
 #ifndef ONE_ALPHA
     __m512  ALPHAVECTOR = _mm512_set1_ps(alpha);
@@ -357,20 +380,48 @@ void sbgemm_block_kernel_32xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float a
         result_512[i+1]  = _mm512_setzero_ps();
     }
 
-    for (BLASLONG idx_k = 0; idx_k < k; idx_k += 32) {
+    for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
         // Load B with unroll n
         for (int i = 0; i < n; i ++) {
-            arrayB_512[i] = _mm512_loadu_si512(&B[idxB_base]);
-            idxB_base += 32;
+            arrayB_512[i] = _mm512_loadu_si512(B_addr);
+            B_addr += 32;
+        }
+
+        for (BLASLONG idx = 0; idx < 32;) {
+            // Each two rows are a group for 32-pair bf16 elements
+            arrayA_512[0] = _mm512_loadu_si512(A_addr);
+            arrayA_512[1] = _mm512_loadu_si512(A_addr + 32);
+            A_addr += 64;
+
+            for (int i = 0; i < n; i++) {
+                result_512[i]   = _mm512_dpbf16_ps(result_512[i]  , (__m512bh) arrayA_512[0], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
+                result_512[i+8] = _mm512_dpbf16_ps(result_512[i+8], (__m512bh) arrayA_512[1], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
+                arrayB_512[i]   = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
+            }
+
+            idx += 2;
+            // Every 4 loops we need to switch to next 128 bits of arrayB registers
+            if ((idx & (~7)) == idx) {
+                for (int i = 0; i < n; i++) {
+                    arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO);
+                }
+            }
         }
+    }
 
-        if (idx_k == tag_k_32x) {width = k - tag_k_32x;}
+    if (tag_k_32x != k) {
+        // Load B with unroll n
+        for (int i = 0; i < n; i ++) {
+            arrayB_512[i] = _mm512_loadu_si512(B_addr);
+            B_addr += 32;
+        }
 
+        BLASLONG width = k - tag_k_32x;
         for (BLASLONG idx = 0; idx < width;) {
             // Each two rows are a group for 32-pair bf16 elements
-            idxA_base = idx << 5;
-            arrayA_512[0] = _mm512_loadu_si512(&A[idxA_base]);
-            arrayA_512[1] = _mm512_loadu_si512(&A[idxA_base + 32]);
+            arrayA_512[0] = _mm512_loadu_si512(A_addr);
+            arrayA_512[1] = _mm512_loadu_si512(A_addr + 32);
+            A_addr += 64;
 
             for (int i = 0; i < n; i++) {
                 result_512[i]   = _mm512_dpbf16_ps(result_512[i]  , (__m512bh) arrayA_512[0], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
@@ -389,35 +440,36 @@ void sbgemm_block_kernel_32xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float a
     }
 
     if (m != 32) {
-        unsigned short tail_mask_value = (((unsigned short)0xffff) >> (32-m));
-        __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
+        unsigned short tail_mask = (((unsigned short)0xffff) >> (32-m));
         for (int i = 0; i < n; i++) {
             result_512_tmp_0 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base0, result_512[i+8]);
             result_512_tmp_1 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base1, result_512[i+8]);
-            STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*i]))
-            STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*i+16]), tail_mask)
+            STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*i))
+            STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*i + 16), tail_mask)
         }
     } else {
         for (int i = 0; i < n; i++) {
             result_512_tmp_0 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base0, result_512[i+8]);
             result_512_tmp_1 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base1, result_512[i+8]);
-            STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*i]))
-            STORE16_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*i+16]))
+            STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*i))
+            STORE16_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*i + 16))
         }
     }
 }
 
 // SBGEMM Kernel for 16<=M, N<8, K can be any number, but the processing will take 32 as a base
 #ifndef ONE_ALPHA      // ALPHA is not ONE
-void sbgemm_block_kernel_16xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+void sbgemm_block_kernel_nn_16xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
 #else                  // ALPHA is ONE
-void sbgemm_block_kernel_16xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+void sbgemm_block_kernel_nn_16xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
 #endif
 {
+    bfloat16 * A_addr = A;
+    bfloat16 * B_addr = B;
+    float    * C_addr = C;
+
     int  SHUFFLE_MAGIC_NO = 0x39;
     BLASLONG tag_k_32x = k & (~31);
-    BLASLONG idxB_base = 0;
-    BLASLONG width = 32;
 
 #ifndef ONE_ALPHA
     __m512  ALPHAVECTOR = _mm512_set1_ps(alpha);
@@ -432,21 +484,49 @@ void sbgemm_block_kernel_16xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float a
         result_512[i+1]  = _mm512_setzero_ps();
     }
 
-    for (BLASLONG idx_k = 0; idx_k < k; idx_k += 32) {
+    for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
         // Load B with unroll n
-        for (int i = 0; i < n; i ++) {
-            arrayB_512[i] = _mm512_loadu_si512(&B[idxB_base]);
-            idxB_base += 32;
+        for (int i = 0; i < n; i++) {
+            arrayB_512[i] = _mm512_loadu_si512(B_addr);
+            B_addr += 32;
+        }
+
+        for (BLASLONG idx = 0; idx < 32;) {
+            // Each two rows are a group for 32-pair bf16 elements
+            // Load two rows into a 512 register
+            arrayA_512 = _mm512_loadu_si512(A_addr);
+            A_addr += 32;
+
+            for (int i = 0; i < n; i ++) {
+                result_512[i]  = _mm512_dpbf16_ps(result_512[i],  (__m512bh) arrayA_512, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
+                arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
+            }
+
+            idx += 2;
+            // Every 4 loops we need to switch to next 128 bits of arrayB registers
+            if ((idx & (~7)) == idx) {
+                for (int i = 0; i < n; i++) {
+                    arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO);
+                }
+            }
         }
+    }
 
-        if (idx_k == tag_k_32x) {width = k - tag_k_32x;}
+    if (tag_k_32x != k) {
+        // Load B with unroll n
+        for (int i = 0; i < n; i++) {
+            arrayB_512[i] = _mm512_loadu_si512(B_addr);
+            B_addr += 32;
+        }
 
+        BLASLONG width = k - tag_k_32x;
         for (BLASLONG idx = 0; idx < width;) {
             // Each two rows are a group for 32-pair bf16 elements
             // Load two rows into a 512 register
-            arrayA_512 = _mm512_loadu_si512(&A[idx<<4]);
+            arrayA_512 = _mm512_loadu_si512(A_addr);
+            A_addr += 32;
 
-            for (int i = 0; i < n; i ++) {
+            for (int i = 0; i < n; i++) {
                 result_512[i]  = _mm512_dpbf16_ps(result_512[i],  (__m512bh) arrayA_512, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
                 arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
             }
@@ -462,23 +542,24 @@ void sbgemm_block_kernel_16xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float a
     }
 
     if (m != 16) {
-        unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-m));
-        __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
+        unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m));
         for (int i = 0; i < n; i++) {
             result_512[i] = _mm512_shuffle_f32x4(result_512[i], result_512[i], 0xd8);
-            STORE16_MASK_COMPLETE_RESULT(result_512[i], (&C[ldc*i]), tail_mask)
+            STORE16_MASK_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i), tail_mask)
         }
     } else {
         for (int i = 0; i < n; i++) {
             result_512[i] = _mm512_shuffle_f32x4(result_512[i], result_512[i], 0xd8);
-            STORE16_COMPLETE_RESULT(result_512[i], (&C[ldc*i]))
+            STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
         }
     }
 }
+
+
 #ifndef ONE_ALPHA      // ALPHA is not ONE
-void sbgemm_blocking_kernel_2_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
+void sbgemm_blocking_kernel_nn_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
 #else                  // ALPHA is ONE
-void sbgemm_blocking_kernel_2_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
+void sbgemm_blocking_kernel_nn_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
 #endif
 {
     BLASLONG m_step, n_step, k_step, k_step_round32;
@@ -499,63 +580,52 @@ void sbgemm_blocking_kernel_2_one(blasint M, blasint N, blasint K, float alpha,
         while (n_from < N) {
             for (BLASLONG idx_k = 0; idx_k < K;) {
                 // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
-                COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, &A(idx_k, 0), lda, block_A);
-                // TODO: MT
+                COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, 32, &A(idx_k, 0), lda, block_A);
                 for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
                     // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
                     COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32);
-                    SBGEMM_BLOCK_KERNEL_32x8x32(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
+                    SBGEMM_BLOCK_KERNEL_NN_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
                 }
 
                 if (tag_n_Nx != n_to) {
                     n_step = n_to - tag_n_Nx;
                     COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
-                    SBGEMM_BLOCK_KERNEL_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+                    SBGEMM_BLOCK_KERNEL_NN_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
                 }
 
                 for (BLASLONG idx_m = BF16_BLOCK_THRES_M; idx_m < tag_m_Nx; idx_m += BF16_BLOCK_THRES_M) {
-                    COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, &A(idx_k, idx_m), lda, block_A);
+                    COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, 32, &A(idx_k, idx_m), lda, block_A);
                     for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
-                        SBGEMM_BLOCK_KERNEL_32x8x32(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc);
+                        SBGEMM_BLOCK_KERNEL_NN_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc);
                     }
 
                     if (tag_n_Nx != n_to) {
                         n_step = n_to - tag_n_Nx;
-                        SBGEMM_BLOCK_KERNEL_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc);
+                        SBGEMM_BLOCK_KERNEL_NN_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc);
                     }
                 }
 
                 if (tag_m_Nx != M) {
                     m_step = M - tag_m_Nx;
                     if (m_step > 16) {
-                        COL_MAJOR_INCOPY_KERNEL_Kx32m(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A);
-                        for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
-                            SBGEMM_BLOCK_KERNEL_32x8x32(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
-                        }
-
-                        if (tag_n_Nx != n_to) {
-                            n_step = n_to - tag_n_Nx;
-                            SBGEMM_BLOCK_KERNEL_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
-                        }
-                    } else if (m_step == 16) {
-                        COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A);
+                        COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A);
                         for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
-                            SBGEMM_BLOCK_KERNEL_16x8x32(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
+                            SBGEMM_BLOCK_KERNEL_NN_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
                         }
 
                         if (tag_n_Nx != n_to) {
                             n_step = n_to - tag_n_Nx;
-                            SBGEMM_BLOCK_KERNEL_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
+                            SBGEMM_BLOCK_KERNEL_NN_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
                         }
                     } else {
-                        COL_MAJOR_INCOPY_KERNEL_Kx16m(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A);
+                        COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A);
                         for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
-                            SBGEMM_BLOCK_KERNEL_16x8x32(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
+                            SBGEMM_BLOCK_KERNEL_NN_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
                         }
 
                         if (tag_n_Nx != n_to) {
                             n_step = n_to - tag_n_Nx;
-                            SBGEMM_BLOCK_KERNEL_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
+                            SBGEMM_BLOCK_KERNEL_NN_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
                         }
                     }
                 }
@@ -573,22 +643,274 @@ void sbgemm_blocking_kernel_2_one(blasint M, blasint N, blasint K, float alpha,
             tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));           
         }
     } else {
-        m_step = M - tag_m_Nx;
+        m_step = M;
+        if (m_step > 16) {
+            while (n_from < N) {
+                for (BLASLONG idx_k = 0; idx_k < K;) {
+                    // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
+                    COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, m_step, &A(idx_k, 0), lda, block_A);
+                    for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+                        // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
+                        COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32);
+                        SBGEMM_BLOCK_KERNEL_NN_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
+                    }
+
+                    if (tag_n_Nx != n_to) {
+                        n_step = n_to - tag_n_Nx;
+                        COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
+                        SBGEMM_BLOCK_KERNEL_NN_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+                    }
+
+                    idx_k += k_step;
+                    k_step = K - idx_k;
+                    k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
+                    k_step_round32 = k_step & (~31);
+                    k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+                }
+                n_from = n_to;
+                n_to += BF16_BLOCK_THRES_N;
+                n_to = (n_to > N) ? N : n_to;
+                tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+            }
+        } else {
+            while (n_from < N) {
+                for (BLASLONG idx_k = 0; idx_k < K;) {
+                    COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, 0), lda, block_A);
+                    for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+                        // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
+                        COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32);
+                        SBGEMM_BLOCK_KERNEL_NN_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
+                    }
+
+                    if (tag_n_Nx != n_to) {
+                        n_step = n_to - tag_n_Nx;
+                        COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
+                        SBGEMM_BLOCK_KERNEL_NN_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+                    }
+
+                    idx_k += k_step;
+                    k_step = K - idx_k;
+                    k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
+                    k_step_round32 = k_step & (~31);
+                    k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+                }
+                n_from = n_to;
+                n_to += BF16_BLOCK_THRES_N;
+                n_to = (n_to > N) ? N : n_to;
+                tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+            }
+        }
+    }
+}
+/* ----------------------------------------- End of NN kernels --------------------------------------- */
+
+/* --------------------------------------------- NT kernels ------------------------------------------ */
+// SBGEMM Kernel for 16<M<=32, N<8, K can be any number
+#ifndef ONE_ALPHA      // ALPHA is not ONE
+void sbgemm_block_kernel_nt_32xNxK_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#else                  // ALPHA is ONE
+void sbgemm_block_kernel_nt_32xNxK_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#endif
+{
+    bfloat16 * A_addr = A;
+    bfloat16 * B_addr = B;
+    float    * C_addr = C;
+
+#ifndef ONE_ALPHA
+    __m512  ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+
+    __m512i arrayA_512_0, arrayA_512_1;
+    __m512i arrayB_512[8];
+    __m512  result_512[16];
+    __m512  result_512_tmp_0, result_512_tmp_1;
+
+    __m512i M512_EPI32_8      = _mm512_set1_epi32(8);
+    __m512i shuffle_idx_base0 = _mm512_set_epi32(23, 22, 21, 20, 7, 6, 5, 4, 19, 18, 17, 16, 3, 2, 1, 0);
+    __m512i shuffle_idx_base1 = _mm512_add_epi32(shuffle_idx_base0, M512_EPI32_8);
+
+    result_512[0]  = _mm512_setzero_ps();
+    result_512[1]  = _mm512_setzero_ps();
+    result_512[2]  = _mm512_setzero_ps();
+    result_512[3]  = _mm512_setzero_ps();
+    result_512[4]  = _mm512_setzero_ps();
+    result_512[5]  = _mm512_setzero_ps();
+    result_512[6]  = _mm512_setzero_ps();
+    result_512[7]  = _mm512_setzero_ps();
+    result_512[8]  = _mm512_setzero_ps();
+    result_512[9]  = _mm512_setzero_ps();
+    result_512[10] = _mm512_setzero_ps();
+    result_512[11] = _mm512_setzero_ps();
+    result_512[12] = _mm512_setzero_ps();
+    result_512[13] = _mm512_setzero_ps();
+    result_512[14] = _mm512_setzero_ps();
+    result_512[15] = _mm512_setzero_ps();
+
+    for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
+        // Each two rows are a group for 32-pair bf16 elements
+        arrayA_512_0 = _mm512_loadu_si512(A_addr);
+        arrayA_512_1 = _mm512_loadu_si512(A_addr + 32);
+        A_addr += 64;
+
+        for (int i = 0; i < n; i ++) {
+            _MM512_BROADCASTD_EPI32(B_addr + i*2,  arrayB_512[i]);
+        }
+        B_addr += 16;
+
+        for (int i = 0; i < n; i ++) {
+            result_512[i] = _mm512_dpbf16_ps(result_512[i],  (__m512bh) arrayA_512_0, (__m512bh) arrayB_512[i]);
+            result_512[i+8] = _mm512_dpbf16_ps(result_512[i+8],  (__m512bh) arrayA_512_1, (__m512bh) arrayB_512[i]);
+        }
+    }
+
+    if (m != 32) {
+        unsigned short tail_mask = (((unsigned short)0xffff) >> (32-m));
+        for (int i = 0; i < n; i ++) {
+            result_512_tmp_0 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base0, result_512[i+8]);
+            result_512_tmp_1 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base1, result_512[i+8]);
+            STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*i))
+            STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*i + 16), tail_mask)
+        }
+    } else {
+        for (int i = 0; i < n; i ++) {
+            result_512_tmp_0 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base0, result_512[i+8]);
+            result_512_tmp_1 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base1, result_512[i+8]);
+            STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*i))
+            STORE16_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*i + 16))
+        }
+    }
+}
+
+// SBGEMM Kernel for M<=16, N<8, K can be any number
+#ifndef ONE_ALPHA      // ALPHA is not ONE
+void sbgemm_block_kernel_nt_16xNxK_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#else                  // ALPHA is ONE
+void sbgemm_block_kernel_nt_16xNxK_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#endif
+{
+    bfloat16 * A_addr = A;
+    bfloat16 * B_addr = B;
+    float    * C_addr = C;
+
+#ifndef ONE_ALPHA
+    __m512  ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+
+    __m512i arrayA_512_0;
+    __m512i arrayB_512[8];
+    __m512  result_512[8];
+
+    result_512[0]  = _mm512_setzero_ps();
+    result_512[1]  = _mm512_setzero_ps();
+    result_512[2]  = _mm512_setzero_ps();
+    result_512[3]  = _mm512_setzero_ps();
+    result_512[4]  = _mm512_setzero_ps();
+    result_512[5]  = _mm512_setzero_ps();
+    result_512[6]  = _mm512_setzero_ps();
+    result_512[7]  = _mm512_setzero_ps();
+
+    for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
+        // Each two rows are a group for 16-pair bf16 elements
+        // Load two rows into a 512 register
+        arrayA_512_0 = _mm512_loadu_si512(A_addr);
+        A_addr += 32;
+
+        for (int i = 0; i < n; i ++) {
+            _MM512_BROADCASTD_EPI32(B_addr + i*2,  arrayB_512[i]);
+        }
+        B_addr += 16;
+
+        for (int i = 0; i < n; i ++) {
+            result_512[i] = _mm512_dpbf16_ps(result_512[i],  (__m512bh) arrayA_512_0, (__m512bh) arrayB_512[i]);
+        }
+    }
+
+    if (m != 16) {
+        unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m));
+        for (int i = 0; i < n; i++) {
+            result_512[i] = _mm512_shuffle_f32x4(result_512[i], result_512[i], 0xd8);
+            STORE16_MASK_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i), tail_mask)
+        }
+    } else {
+        for (int i = 0; i < n; i++) {
+            result_512[i] = _mm512_shuffle_f32x4(result_512[i], result_512[i], 0xd8);
+            STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
+        }
+    }
+}
+
+#ifndef ONE_ALPHA      // ALPHA is not ONE
+void sbgemm_blocking_kernel_nt_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
+#else                  // ALPHA is ONE
+void sbgemm_blocking_kernel_nt_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
+#endif
+{
+    BLASLONG m_step, n_step, k_step, k_step_round32;
+    BLASLONG tag_m_Nx = M & (~(BF16_BLOCK_THRES_M-1));
+
+    BLASLONG n_from, n_to;
+    BLASLONG tag_n_Nx;
+
+    n_from = 0;
+    n_to = (BF16_BLOCK_THRES_N > N) ? N : BF16_BLOCK_THRES_N;
+    tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+
+    k_step = (K > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : K;
+    k_step_round32 = k_step & (~31);
+    k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+
+    if (M >= BF16_BLOCK_THRES_M) {
         while (n_from < N) {
             for (BLASLONG idx_k = 0; idx_k < K;) {
                 // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
-                COL_MAJOR_INCOPY_KERNEL_Kx32m(k_step, m_step, &A(idx_k, 0), lda, block_A);
-                // TODO: MT
+                COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, 32, &A(idx_k, 0), lda, block_A);
                 for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
                     // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
-                    COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32);
-                    SBGEMM_BLOCK_KERNEL_32x8x32(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
+                    COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32);
+                    SBGEMM_BLOCK_KERNEL_NT_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
                 }
 
                 if (tag_n_Nx != n_to) {
                     n_step = n_to - tag_n_Nx;
-                    COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
-                    SBGEMM_BLOCK_KERNEL_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+                    COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
+                    SBGEMM_BLOCK_KERNEL_NT_32xNxK(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+                }
+
+                for (BLASLONG idx_m = BF16_BLOCK_THRES_M; idx_m < tag_m_Nx; idx_m += BF16_BLOCK_THRES_M) {
+                    COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, 32, &A(idx_k, idx_m), lda, block_A);
+                    for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+                        SBGEMM_BLOCK_KERNEL_NT_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc);
+                    }
+
+                    if (tag_n_Nx != n_to) {
+                        n_step = n_to - tag_n_Nx;
+                        SBGEMM_BLOCK_KERNEL_NT_32xNxK(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc);
+                    }
+                }
+
+                if (tag_m_Nx != M) {
+                    m_step = M - tag_m_Nx;
+                    if (m_step > 16) {
+                        COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A);
+                        for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+                            SBGEMM_BLOCK_KERNEL_NT_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
+                        }
+
+                        if (tag_n_Nx != n_to) {
+                            n_step = n_to - tag_n_Nx;
+                            SBGEMM_BLOCK_KERNEL_NT_32xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
+                        }
+                    } else {
+                        COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A);
+                        for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+                            SBGEMM_BLOCK_KERNEL_NT_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
+                        }
+
+                        if (tag_n_Nx != n_to) {
+                            n_step = n_to - tag_n_Nx;
+                            SBGEMM_BLOCK_KERNEL_NT_16xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
+                        }
+                    }
                 }
 
                 idx_k += k_step;
@@ -597,13 +919,884 @@ void sbgemm_blocking_kernel_2_one(blasint M, blasint N, blasint K, float alpha,
                 k_step_round32 = k_step & (~31);
                 k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
             }
+
             n_from = n_to;
             n_to += BF16_BLOCK_THRES_N;
             n_to = (n_to > N) ? N : n_to;
             tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
         }
+    } else {
+        m_step = M;
+        if (m_step > 16) {
+            while (n_from < N) {
+                for (BLASLONG idx_k = 0; idx_k < K;) {
+                    // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
+                    COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, m_step, &A(idx_k, 0), lda, block_A);
+                    for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+                        // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
+                        COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32);
+                        SBGEMM_BLOCK_KERNEL_NT_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
+                    }
+
+                    if (tag_n_Nx != n_to) {
+                        n_step = n_to - tag_n_Nx;
+                        COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
+                        SBGEMM_BLOCK_KERNEL_NT_32xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+                    }
+
+                    idx_k += k_step;
+                    k_step = K - idx_k;
+                    k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
+                    k_step_round32 = k_step & (~31);
+                    k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+                }
+                n_from = n_to;
+                n_to += BF16_BLOCK_THRES_N;
+                n_to = (n_to > N) ? N : n_to;
+                tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+            }
+        } else {
+            while (n_from < N) {
+                for (BLASLONG idx_k = 0; idx_k < K;) {
+                    // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
+                    COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, 0), lda, block_A);
+                    for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+                        // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
+                        COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32);
+                        SBGEMM_BLOCK_KERNEL_NT_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
+                    }
+
+                    if (tag_n_Nx != n_to) {
+                        n_step = n_to - tag_n_Nx;
+                        COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
+                        SBGEMM_BLOCK_KERNEL_NT_16xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+                    }
+
+                    idx_k += k_step;
+                    k_step = K - idx_k;
+                    k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
+                    k_step_round32 = k_step & (~31);
+                    k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+                }
+                n_from = n_to;
+                n_to += BF16_BLOCK_THRES_N;
+                n_to = (n_to > N) ? N : n_to;
+                tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+            }
+        }
     }
 }
+/* ----------------------------------------- End of NT kernels --------------------------------------- */
+
+/* --------------------------------------------- TN kernels ------------------------------------------ */
+// SBGEMM Kernel for 16<M<=32, N=8, K=Any number
+#ifndef ONE_ALPHA      // ALPHA is not ONE
+void sbgemm_block_kernel_tn_32x8xK_alpha(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#else                  // ALPHA is ONE
+void sbgemm_block_kernel_tn_32x8xK_one(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#endif
+{
+    bfloat16 * A_addr = A;
+    bfloat16 * B_addr = B;
+    float    * C_addr = C;
+
+#ifndef ONE_ALPHA
+    __m512  ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+
+    __m512i arrayA_512_0, arrayA_512_1;
+    __m512i arrayB_512_0, arrayB_512_1, arrayB_512_2, arrayB_512_3, arrayB_512_4, arrayB_512_5, arrayB_512_6, arrayB_512_7;
+    __m512  result_512_0, result_512_1, result_512_2, result_512_3, result_512_4, result_512_5, result_512_6, result_512_7,
+            result_512_8, result_512_9, result_512_10, result_512_11, result_512_12, result_512_13, result_512_14, result_512_15;
+
+    result_512_0  = _mm512_setzero_ps();
+    result_512_1  = _mm512_setzero_ps();
+    result_512_2  = _mm512_setzero_ps();
+    result_512_3  = _mm512_setzero_ps();
+    result_512_4  = _mm512_setzero_ps();
+    result_512_5  = _mm512_setzero_ps();
+    result_512_6  = _mm512_setzero_ps();
+    result_512_7  = _mm512_setzero_ps();
+    result_512_8  = _mm512_setzero_ps();
+    result_512_9  = _mm512_setzero_ps();
+    result_512_10 = _mm512_setzero_ps();
+    result_512_11 = _mm512_setzero_ps();
+    result_512_12 = _mm512_setzero_ps();
+    result_512_13 = _mm512_setzero_ps();
+    result_512_14 = _mm512_setzero_ps();
+    result_512_15 = _mm512_setzero_ps();
+
+    for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
+        // Load 32 pair of BF16 elements from A (32 rows)
+        arrayA_512_0 = _mm512_loadu_si512(A_addr);
+        arrayA_512_1 = _mm512_loadu_si512(A_addr + 32);
+
+        // Load 8 rows of B
+        _MM512_BROADCASTD_EPI32(B_addr + 0,  arrayB_512_0);
+        _MM512_BROADCASTD_EPI32(B_addr + 2,  arrayB_512_1);
+        _MM512_BROADCASTD_EPI32(B_addr + 4,  arrayB_512_2);
+        _MM512_BROADCASTD_EPI32(B_addr + 6,  arrayB_512_3);
+        _MM512_BROADCASTD_EPI32(B_addr + 8,  arrayB_512_4);
+        _MM512_BROADCASTD_EPI32(B_addr + 10, arrayB_512_5);
+        _MM512_BROADCASTD_EPI32(B_addr + 12, arrayB_512_6);
+        _MM512_BROADCASTD_EPI32(B_addr + 14, arrayB_512_7);
+
+        result_512_0  = _mm512_dpbf16_ps(result_512_0, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_0);
+        result_512_1  = _mm512_dpbf16_ps(result_512_1, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_1);
+        result_512_2  = _mm512_dpbf16_ps(result_512_2, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_2);
+        result_512_3  = _mm512_dpbf16_ps(result_512_3, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_3);
+        result_512_4  = _mm512_dpbf16_ps(result_512_4, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_4);
+        result_512_5  = _mm512_dpbf16_ps(result_512_5, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_5);
+        result_512_6  = _mm512_dpbf16_ps(result_512_6, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_6);
+        result_512_7  = _mm512_dpbf16_ps(result_512_7, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_7);
+
+        result_512_8  = _mm512_dpbf16_ps(result_512_8,  (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_0);
+        result_512_9  = _mm512_dpbf16_ps(result_512_9,  (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_1);
+        result_512_10 = _mm512_dpbf16_ps(result_512_10, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_2);
+        result_512_11 = _mm512_dpbf16_ps(result_512_11, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_3);
+        result_512_12 = _mm512_dpbf16_ps(result_512_12, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_4);
+        result_512_13 = _mm512_dpbf16_ps(result_512_13, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_5);
+        result_512_14 = _mm512_dpbf16_ps(result_512_14, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_6);
+        result_512_15 = _mm512_dpbf16_ps(result_512_15, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_7);
+
+        // Load B with unroll 8
+        B_addr += 16;
+        // Load A with unroll 64
+        A_addr += 64;
+    }
+
+    if (m != 32) {
+        unsigned short tail_mask_value = (((unsigned short)0xffff) >> (32-m));
+        __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
+        STORE16_COMPLETE_RESULT(result_512_0,  (C_addr))
+        STORE16_MASK_COMPLETE_RESULT(result_512_8,  (C_addr + 16), tail_mask)
+        STORE16_COMPLETE_RESULT(result_512_1,  (C_addr + ldc))
+        STORE16_MASK_COMPLETE_RESULT(result_512_9,  (C_addr + ldc + 16), tail_mask)
+        STORE16_COMPLETE_RESULT(result_512_2,  (C_addr + ldc*2))
+        STORE16_MASK_COMPLETE_RESULT(result_512_10, (C_addr + ldc*2 + 16), tail_mask)
+        STORE16_COMPLETE_RESULT(result_512_3,  (C_addr + ldc*3))
+        STORE16_MASK_COMPLETE_RESULT(result_512_11, (C_addr + ldc*3 + 16), tail_mask)
+        STORE16_COMPLETE_RESULT(result_512_4,  (C_addr + ldc*4))
+        STORE16_MASK_COMPLETE_RESULT(result_512_12, (C_addr + ldc*4 + 16), tail_mask)
+        STORE16_COMPLETE_RESULT(result_512_5,  (C_addr + ldc*5))
+        STORE16_MASK_COMPLETE_RESULT(result_512_13, (C_addr + ldc*5 + 16), tail_mask)
+        STORE16_COMPLETE_RESULT(result_512_6,  (C_addr + ldc*6))
+        STORE16_MASK_COMPLETE_RESULT(result_512_14, (C_addr + ldc*6 + 16), tail_mask)
+        STORE16_COMPLETE_RESULT(result_512_7,  (C_addr + ldc*7))
+        STORE16_MASK_COMPLETE_RESULT(result_512_15, (C_addr + ldc*7 + 16), tail_mask)
+    } else {
+        STORE16_COMPLETE_RESULT(result_512_0,  (C_addr))
+        STORE16_COMPLETE_RESULT(result_512_8,  (C_addr + 16))
+        STORE16_COMPLETE_RESULT(result_512_1,  (C_addr + ldc))
+        STORE16_COMPLETE_RESULT(result_512_9,  (C_addr + ldc + 16))
+        STORE16_COMPLETE_RESULT(result_512_2,  (C_addr + ldc*2))
+        STORE16_COMPLETE_RESULT(result_512_10, (C_addr + ldc*2 + 16))
+        STORE16_COMPLETE_RESULT(result_512_3,  (C_addr + ldc*3))
+        STORE16_COMPLETE_RESULT(result_512_11, (C_addr + ldc*3 + 16))
+        STORE16_COMPLETE_RESULT(result_512_4,  (C_addr + ldc*4))
+        STORE16_COMPLETE_RESULT(result_512_12, (C_addr + ldc*4 + 16))
+        STORE16_COMPLETE_RESULT(result_512_5,  (C_addr + ldc*5))
+        STORE16_COMPLETE_RESULT(result_512_13, (C_addr + ldc*5 + 16))
+        STORE16_COMPLETE_RESULT(result_512_6,  (C_addr + ldc*6))
+        STORE16_COMPLETE_RESULT(result_512_14, (C_addr + ldc*6 + 16))
+        STORE16_COMPLETE_RESULT(result_512_7,  (C_addr + ldc*7))
+        STORE16_COMPLETE_RESULT(result_512_15, (C_addr + ldc*7 + 16))
+    }
+}
+
+// SBGEMM Kernel for M=16, N=8, K=Any number
+#ifndef ONE_ALPHA      // ALPHA is not ONE
+void sbgemm_block_kernel_tn_16x8xK_alpha(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#else                  // ALPHA is ONE
+void sbgemm_block_kernel_tn_16x8xK_one(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#endif
+{
+    bfloat16 * A_addr = A;
+    bfloat16 * B_addr = B;
+    float    * C_addr = C;
+
+#ifndef ONE_ALPHA
+    __m512  ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+
+    __m512i arrayA_512_0;
+    __m512i arrayB_512_0, arrayB_512_1, arrayB_512_2, arrayB_512_3, arrayB_512_4, arrayB_512_5, arrayB_512_6, arrayB_512_7;
+    __m512  result_512_0, result_512_1, result_512_2, result_512_3, result_512_4, result_512_5, result_512_6, result_512_7;
+
+    result_512_0  = _mm512_setzero_ps();
+    result_512_1  = _mm512_setzero_ps();
+    result_512_2  = _mm512_setzero_ps();
+    result_512_3  = _mm512_setzero_ps();
+    result_512_4  = _mm512_setzero_ps();
+    result_512_5  = _mm512_setzero_ps();
+    result_512_6  = _mm512_setzero_ps();
+    result_512_7  = _mm512_setzero_ps();
+
+    for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
+        // Load 16 pair of BF16 elements from A (16 rows)
+        arrayA_512_0 = _mm512_loadu_si512(A_addr + 0);
+
+        // Load 8 rows of B
+        _MM512_BROADCASTD_EPI32(B_addr + 0,  arrayB_512_0);
+        _MM512_BROADCASTD_EPI32(B_addr + 2,  arrayB_512_1);
+        _MM512_BROADCASTD_EPI32(B_addr + 4,  arrayB_512_2);
+        _MM512_BROADCASTD_EPI32(B_addr + 6,  arrayB_512_3);
+        _MM512_BROADCASTD_EPI32(B_addr + 8,  arrayB_512_4);
+        _MM512_BROADCASTD_EPI32(B_addr + 10, arrayB_512_5);
+        _MM512_BROADCASTD_EPI32(B_addr + 12, arrayB_512_6);
+        _MM512_BROADCASTD_EPI32(B_addr + 14, arrayB_512_7);
+
+        result_512_0  = _mm512_dpbf16_ps(result_512_0, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_0);
+        result_512_1  = _mm512_dpbf16_ps(result_512_1, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_1);
+        result_512_2  = _mm512_dpbf16_ps(result_512_2, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_2);
+        result_512_3  = _mm512_dpbf16_ps(result_512_3, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_3);
+        result_512_4  = _mm512_dpbf16_ps(result_512_4, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_4);
+        result_512_5  = _mm512_dpbf16_ps(result_512_5, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_5);
+        result_512_6  = _mm512_dpbf16_ps(result_512_6, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_6);
+        result_512_7  = _mm512_dpbf16_ps(result_512_7, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_7);
+
+        // Load B with unroll 8
+        B_addr += 16;
+        // Load A with unroll 32
+        A_addr += 32;
+    }
+
+    if (m != 16) {
+        unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-m));
+        __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
+        STORE16_MASK_COMPLETE_RESULT(result_512_0,  (C_addr), tail_mask)
+        STORE16_MASK_COMPLETE_RESULT(result_512_1,  (C_addr + ldc), tail_mask)
+        STORE16_MASK_COMPLETE_RESULT(result_512_2,  (C_addr + ldc*2), tail_mask)
+        STORE16_MASK_COMPLETE_RESULT(result_512_3,  (C_addr + ldc*3), tail_mask)
+        STORE16_MASK_COMPLETE_RESULT(result_512_4,  (C_addr + ldc*4), tail_mask)
+        STORE16_MASK_COMPLETE_RESULT(result_512_5,  (C_addr + ldc*5), tail_mask)
+        STORE16_MASK_COMPLETE_RESULT(result_512_6,  (C_addr + ldc*6), tail_mask)
+        STORE16_MASK_COMPLETE_RESULT(result_512_7,  (C_addr + ldc*7), tail_mask)
+    } else {
+        STORE16_COMPLETE_RESULT(result_512_0,  (C_addr))
+        STORE16_COMPLETE_RESULT(result_512_1,  (C_addr + ldc))
+        STORE16_COMPLETE_RESULT(result_512_2,  (C_addr + ldc*2))
+        STORE16_COMPLETE_RESULT(result_512_3,  (C_addr + ldc*3))
+        STORE16_COMPLETE_RESULT(result_512_4,  (C_addr + ldc*4))
+        STORE16_COMPLETE_RESULT(result_512_5,  (C_addr + ldc*5))
+        STORE16_COMPLETE_RESULT(result_512_6,  (C_addr + ldc*6))
+        STORE16_COMPLETE_RESULT(result_512_7,  (C_addr + ldc*7))
+    }
+}
+
+// SBGEMM Kernel for 16<M<=32, N<8, K=Any number but will be processed based on 32
+#ifndef ONE_ALPHA      // ALPHA is not ONE
+void sbgemm_block_kernel_tn_32xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#else                  // ALPHA is ONE
+void sbgemm_block_kernel_tn_32xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#endif
+{
+    bfloat16 * A_addr = A;
+    bfloat16 * B_addr = B;
+    float    * C_addr = C;
+
+    int  SHUFFLE_MAGIC_NO = 0x39;
+    BLASLONG tag_k_32x = k & (~31);
+
+#ifndef ONE_ALPHA
+    __m512  ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+
+    __m512i arrayA_512[2];
+    __m512i arrayB_512[8];
+    __m512  result_512[16];
+
+    for (int i = 0; i < 15; i++) {
+        result_512[i] = _mm512_setzero_ps();
+    }
+
+    for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
+        // Load B with unroll n
+        for (int i = 0; i < n; i ++) {
+            arrayB_512[i] = _mm512_loadu_si512(B_addr);
+            B_addr += 32;
+        }
+
+        for (BLASLONG idx = 0; idx < 32;) {
+            // Each two rows are a group for 32-pair bf16 elements
+            arrayA_512[0] = _mm512_loadu_si512(A_addr);
+            arrayA_512[1] = _mm512_loadu_si512(A_addr + 32);
+            A_addr += 64;
+
+            for (int i = 0; i < n; i++) {
+                result_512[i]   = _mm512_dpbf16_ps(result_512[i]  , (__m512bh) arrayA_512[0], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
+                result_512[i+8] = _mm512_dpbf16_ps(result_512[i+8], (__m512bh) arrayA_512[1], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
+                arrayB_512[i]   = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
+            }
+
+            idx += 2;
+            // Every 4 loops we need to switch to next 128 bits of arrayB registers
+            if ((idx & (~7)) == idx) {
+                for (int i = 0; i < n; i++) {
+                    arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO);
+                }
+            }
+        }
+    }
+
+    if (tag_k_32x != k) {
+            // Load B with unroll n
+        for (int i = 0; i < n; i ++) {
+            arrayB_512[i] = _mm512_loadu_si512(B_addr);
+            B_addr += 32;
+        }
+
+        BLASLONG width = k - tag_k_32x;
+        for (BLASLONG idx = 0; idx < width;) {
+            // Each two rows are a group for 32-pair bf16 elements
+            arrayA_512[0] = _mm512_loadu_si512(A_addr);
+            arrayA_512[1] = _mm512_loadu_si512(A_addr + 32);
+            A_addr += 64;
+
+            for (int i = 0; i < n; i++) {
+                result_512[i]   = _mm512_dpbf16_ps(result_512[i]  , (__m512bh) arrayA_512[0], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
+                result_512[i+8] = _mm512_dpbf16_ps(result_512[i+8], (__m512bh) arrayA_512[1], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
+                arrayB_512[i]   = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
+            }
+
+            idx += 2;
+            // Every 4 loops we need to switch to next 128 bits of arrayB registers
+            if ((idx & (~7)) == idx) {
+                for (int i = 0; i < n; i++) {
+                    arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO);
+                }
+            }
+        }
+    }
+
+    if (m != 32) {
+        unsigned short tail_mask = (((unsigned short)0xffff) >> (32-m));
+        for (int i = 0; i < n; i++) {
+            STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
+            STORE16_MASK_COMPLETE_RESULT(result_512[i+8], (C_addr + ldc*i + 16), tail_mask)
+        }
+    } else {
+        for (int i = 0; i < n; i++) {
+            STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
+            STORE16_COMPLETE_RESULT(result_512[i+8], (C_addr + ldc*i + 16))
+        }
+    }
+}
+
+// SBGEMM Kernel for M<=16, N<8, K=Any number but will be processed based on 32
+#ifndef ONE_ALPHA      // ALPHA is not ONE
+void sbgemm_block_kernel_tn_16xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#else                  // ALPHA is ONE
+void sbgemm_block_kernel_tn_16xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#endif
+{
+    bfloat16 * A_addr = A;
+    bfloat16 * B_addr = B;
+    float    * C_addr = C;
+
+    int  SHUFFLE_MAGIC_NO = 0x39;
+    BLASLONG tag_k_32x = k & (~31);
+
+#ifndef ONE_ALPHA
+    __m512  ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+
+    __m512i arrayA_512;
+    __m512i arrayB_512[8];
+    __m512  result_512[8];
+
+    for (int i = 0; i < 8; i++) {
+        result_512[i]    = _mm512_setzero_ps();
+    }
+
+    for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
+        // Load B with unroll n
+        for (int i = 0; i < n; i ++) {
+            arrayB_512[i] = _mm512_loadu_si512(B_addr);
+            B_addr += 32;
+        }
+
+        for (BLASLONG idx = 0; idx < 32;) {
+            // Each two rows are a group for 32-pair bf16 elements
+            arrayA_512 = _mm512_loadu_si512(A_addr);
+            A_addr += 32;
+
+            for (int i = 0; i < n; i++) {
+                result_512[i]   = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
+                arrayB_512[i]   = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
+            }
+
+            idx += 2;
+            // Every 4 loops we need to switch to next 128 bits of arrayB registers
+            if ((idx & (~7)) == idx) {
+                for (int i = 0; i < n; i++) {
+                    arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO);
+                }
+            }
+        }
+    }
+
+    if (tag_k_32x != k) {
+        // Load B with unroll n
+        for (int i = 0; i < n; i ++) {
+            arrayB_512[i] = _mm512_loadu_si512(B_addr);
+            B_addr += 32;
+        }
+
+        BLASLONG width = k - tag_k_32x;
+        for (BLASLONG idx = 0; idx < width;) {
+            // Each two rows are a group for 32-pair bf16 elements
+            arrayA_512 = _mm512_loadu_si512(A_addr);
+            A_addr += 32;
+
+            for (int i = 0; i < n; i++) {
+                result_512[i]   = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
+                arrayB_512[i]   = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
+            }
+
+            idx += 2;
+            // Every 4 loops we need to switch to next 128 bits of arrayB registers
+            if ((idx & (~7)) == idx) {
+                for (int i = 0; i < n; i++) {
+                    arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO);
+                }
+            }
+        }
+    }
+
+    if (m != 16) {
+        unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m));
+        for (int i = 0; i < n; i++) {
+            STORE16_MASK_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i), tail_mask)
+        }
+    } else {
+        for (int i = 0; i < n; i++) {
+            STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
+        }
+    }
+}
+
+#ifndef ONE_ALPHA      // ALPHA is not ONE
+void sbgemm_blocking_kernel_tn_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
+#else                  // ALPHA is ONE
+void sbgemm_blocking_kernel_tn_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
+#endif
+{
+    BLASLONG m_step, n_step, k_step, k_step_round32;
+    BLASLONG tag_m_Nx = M & (~(BF16_BLOCK_THRES_M-1));
+
+    BLASLONG n_from, n_to;
+    BLASLONG tag_n_Nx;
+
+    n_from = 0;
+    n_to = (BF16_BLOCK_THRES_N > N) ? N : BF16_BLOCK_THRES_N;
+    tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+
+    k_step = (K > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : K;
+    k_step_round32 = k_step & (~31);
+    k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+
+    if (M >= BF16_BLOCK_THRES_M) {
+        while (n_from < N) {
+            for (BLASLONG idx_k = 0; idx_k < K;) {
+                // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
+                COL_MAJOR_ITCOPY_KERNEL_Kx32(k_step, &A(0, idx_k), lda, block_A);
+                for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+                    // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
+                    COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32);
+                    SBGEMM_BLOCK_KERNEL_TN_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); // TODO how to process m
+                }
+
+                if (tag_n_Nx != n_to) {
+                    n_step = n_to - tag_n_Nx;
+                    COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
+                    SBGEMM_BLOCK_KERNEL_TN_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+                }
+
+                for (BLASLONG idx_m = BF16_BLOCK_THRES_M; idx_m < tag_m_Nx; idx_m += BF16_BLOCK_THRES_M) {
+                    COL_MAJOR_ITCOPY_KERNEL_Kx32(k_step, &A(idx_m, idx_k), lda, block_A);
+                    for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+                        SBGEMM_BLOCK_KERNEL_TN_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc);
+                    }
+
+                    if (tag_n_Nx != n_to) {
+                        n_step = n_to - tag_n_Nx;
+                        SBGEMM_BLOCK_KERNEL_TN_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc);
+                    }
+                }
+
+                if (tag_m_Nx != M) {
+                    m_step = M - tag_m_Nx;
+                    if (m_step > 16) {
+                        COL_MAJOR_ITCOPY_KERNEL_Kx32m(m_step, k_step, &A(tag_m_Nx, idx_k), lda, block_A);
+                        for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+                            SBGEMM_BLOCK_KERNEL_TN_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
+                        }
+
+                        if (tag_n_Nx != n_to) {
+                            n_step = n_to - tag_n_Nx;
+                            SBGEMM_BLOCK_KERNEL_TN_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
+                        }
+                    } else {
+                        COL_MAJOR_ITCOPY_KERNEL_Kx16m(m_step, k_step, &A(tag_m_Nx, idx_k), lda, block_A);
+                        for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+                            SBGEMM_BLOCK_KERNEL_TN_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
+                        }
+
+                        if (tag_n_Nx != n_to) {
+                            n_step = n_to - tag_n_Nx;
+                            SBGEMM_BLOCK_KERNEL_TN_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
+                        }
+                    }
+                }
+
+                idx_k += k_step;
+                k_step = K - idx_k;
+                k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
+                k_step_round32 = k_step & (~31);
+                k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+            }
+
+            n_from = n_to;
+            n_to += BF16_BLOCK_THRES_N;
+            n_to = (n_to > N) ? N : n_to;
+            tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+        }
+    } else {
+        m_step = M;
+        if (m_step > 16) {
+            while (n_from < N) {
+                for (BLASLONG idx_k = 0; idx_k < K;) {
+                    // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
+                    COL_MAJOR_ITCOPY_KERNEL_Kx32m(m_step, k_step, &A(0, idx_k), lda, block_A);
+                    for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+                        // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
+                        COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32);
+                        SBGEMM_BLOCK_KERNEL_TN_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
+                    }
+
+                    if (tag_n_Nx != n_to) {
+                        n_step = n_to - tag_n_Nx;
+                        COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
+                        SBGEMM_BLOCK_KERNEL_TN_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+                    }
+
+                    idx_k += k_step;
+                    k_step = K - idx_k;
+                    k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
+                    k_step_round32 = k_step & (~31);
+                    k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+                }
+                n_from = n_to;
+                n_to += BF16_BLOCK_THRES_N;
+                n_to = (n_to > N) ? N : n_to;
+                tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+            }
+        } else {
+            while (n_from < N) {
+                for (BLASLONG idx_k = 0; idx_k < K;) {
+                    // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
+                    COL_MAJOR_ITCOPY_KERNEL_Kx16m(m_step, k_step, &A(0, idx_k), lda, block_A);
+                    for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+                        // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
+                        COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32);
+                        SBGEMM_BLOCK_KERNEL_TN_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
+                    }
+
+                    if (tag_n_Nx != n_to) {
+                        n_step = n_to - tag_n_Nx;
+                        COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
+                        SBGEMM_BLOCK_KERNEL_TN_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+                    }
+
+                    idx_k += k_step;
+                    k_step = K - idx_k;
+                    k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
+                    k_step_round32 = k_step & (~31);
+                    k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+                }
+                n_from = n_to;
+                n_to += BF16_BLOCK_THRES_N;
+                n_to = (n_to > N) ? N : n_to;
+                tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+            }
+        }
+    }
+}
+/* ----------------------------------------- End of TN kernels --------------------------------------- */
+
+/* --------------------------------------------- TT kernels ------------------------------------------ */
+// SBGEMM Kernel for 16<M<=32, N<8, K can be any number
+#ifndef ONE_ALPHA      // ALPHA is not ONE
+void sbgemm_block_kernel_tt_32xNxK_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#else                  // ALPHA is ONE
+void sbgemm_block_kernel_tt_32xNxK_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#endif
+{
+    bfloat16 * A_addr = A;
+    bfloat16 * B_addr = B;
+    float    * C_addr = C;
+
+#ifndef ONE_ALPHA
+    __m512  ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+
+    __m512i arrayA_512_0, arrayA_512_1;
+    __m512i arrayB_512[8];
+    __m512  result_512[16];
+
+    result_512[0]  = _mm512_setzero_ps();
+    result_512[1]  = _mm512_setzero_ps();
+    result_512[2]  = _mm512_setzero_ps();
+    result_512[3]  = _mm512_setzero_ps();
+    result_512[4]  = _mm512_setzero_ps();
+    result_512[5]  = _mm512_setzero_ps();
+    result_512[6]  = _mm512_setzero_ps();
+    result_512[7]  = _mm512_setzero_ps();
+    result_512[8]  = _mm512_setzero_ps();
+    result_512[9]  = _mm512_setzero_ps();
+    result_512[10] = _mm512_setzero_ps();
+    result_512[11] = _mm512_setzero_ps();
+    result_512[12] = _mm512_setzero_ps();
+    result_512[13] = _mm512_setzero_ps();
+    result_512[14] = _mm512_setzero_ps();
+    result_512[15] = _mm512_setzero_ps();
+
+    for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
+        // Each two rows are a group for 32-pair bf16 elements
+        arrayA_512_0 = _mm512_loadu_si512(A_addr);
+        arrayA_512_1 = _mm512_loadu_si512(A_addr + 32);
+        A_addr += 64;
+
+        for (int i = 0; i < n; i ++) {
+            _MM512_BROADCASTD_EPI32(B_addr + i*2,  arrayB_512[i]);
+        }
+        B_addr += 16;
+
+        for (int i = 0; i < n; i ++) {
+            result_512[i] = _mm512_dpbf16_ps(result_512[i],  (__m512bh) arrayA_512_0, (__m512bh) arrayB_512[i]);
+            result_512[i+8] = _mm512_dpbf16_ps(result_512[i+8],  (__m512bh) arrayA_512_1, (__m512bh) arrayB_512[i]);
+        }
+    }
+
+    if (m != 32) {
+        unsigned short tail_mask = (((unsigned short)0xffff) >> (32-m));
+        for (int i = 0; i < n; i ++) {
+            STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
+            STORE16_MASK_COMPLETE_RESULT(result_512[i+8], (C_addr + ldc*i + 16), tail_mask)
+        }
+    } else {
+        for (int i = 0; i < n; i ++) {
+            STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
+            STORE16_COMPLETE_RESULT(result_512[i+8], (C_addr + ldc*i + 16))
+        }
+    }
+}
+
+// SBGEMM Kernel for M<=16, N<8, K can be any number
+#ifndef ONE_ALPHA      // ALPHA is not ONE
+void sbgemm_block_kernel_tt_16xNxK_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#else                  // ALPHA is ONE
+void sbgemm_block_kernel_tt_16xNxK_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#endif
+{
+    bfloat16 * A_addr = A;
+    bfloat16 * B_addr = B;
+    float    * C_addr = C;
+
+#ifndef ONE_ALPHA
+    __m512  ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+
+    __m512i arrayA_512_0;
+    __m512i arrayB_512[8];
+    __m512  result_512[8];
+
+    result_512[0]  = _mm512_setzero_ps();
+    result_512[1]  = _mm512_setzero_ps();
+    result_512[2]  = _mm512_setzero_ps();
+    result_512[3]  = _mm512_setzero_ps();
+    result_512[4]  = _mm512_setzero_ps();
+    result_512[5]  = _mm512_setzero_ps();
+    result_512[6]  = _mm512_setzero_ps();
+    result_512[7]  = _mm512_setzero_ps();
+
+    for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
+        // Each two rows are a group for 16-pair bf16 elements
+        // Load two rows into a 512 register
+        arrayA_512_0 = _mm512_loadu_si512(A_addr);
+        A_addr += 32;
+
+        for (int i = 0; i < n; i ++) {
+            _MM512_BROADCASTD_EPI32(B_addr + i*2,  arrayB_512[i]);
+        }
+        B_addr += 16;
+
+        for (int i = 0; i < n; i ++) {
+            result_512[i] = _mm512_dpbf16_ps(result_512[i],  (__m512bh) arrayA_512_0, (__m512bh) arrayB_512[i]);
+        }
+    }
+
+    if (m != 16) {
+        unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m));
+        for (int i = 0; i < n; i++) {
+            STORE16_MASK_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i), tail_mask)
+        }
+    } else {
+        for (int i = 0; i < n; i++) {
+            STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
+        }
+    }
+}
+
+#ifndef ONE_ALPHA      // ALPHA is not ONE
+void sbgemm_blocking_kernel_tt_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
+#else                  // ALPHA is ONE
+void sbgemm_blocking_kernel_tt_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
+#endif
+{
+    BLASLONG m_step, n_step, k_step, k_step_round32;
+    BLASLONG tag_m_Nx = M & (~(BF16_BLOCK_THRES_M-1));
+
+    BLASLONG n_from, n_to;
+    BLASLONG tag_n_Nx;
+
+    n_from = 0;
+    n_to = (BF16_BLOCK_THRES_N > N) ? N : BF16_BLOCK_THRES_N;
+    tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+
+    k_step = (K > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : K;
+    k_step_round32 = k_step & (~31);
+    k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+
+    if (M >= BF16_BLOCK_THRES_M) {
+        while (n_from < N) {
+            for (BLASLONG idx_k = 0; idx_k < K;) {
+                // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
+                COL_MAJOR_ITCOPY_KERNEL_Kx32(k_step, &A(0, idx_k), lda, block_A);
+                for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+                    // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
+                    COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32);
+                    SBGEMM_BLOCK_KERNEL_TT_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
+                }
+
+                if (tag_n_Nx != n_to) {
+                    n_step = n_to - tag_n_Nx;
+                    COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
+                    SBGEMM_BLOCK_KERNEL_TT_32xNxK(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+                }
+
+                for (BLASLONG idx_m = BF16_BLOCK_THRES_M; idx_m < tag_m_Nx; idx_m += BF16_BLOCK_THRES_M) {
+                    COL_MAJOR_ITCOPY_KERNEL_Kx32(k_step, &A(idx_m, idx_k), lda, block_A);
+                    for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+                        SBGEMM_BLOCK_KERNEL_TT_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc);
+                    }
+
+                    if (tag_n_Nx != n_to) {
+                        n_step = n_to - tag_n_Nx;
+                        SBGEMM_BLOCK_KERNEL_TT_32xNxK(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc);
+                    }
+                }
+
+                if (tag_m_Nx != M) {
+                    m_step = M - tag_m_Nx;
+                    if (m_step > 16) {
+                        COL_MAJOR_ITCOPY_KERNEL_Kx32m(m_step, k_step, &A(tag_m_Nx, idx_k), lda, block_A);
+                        for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+                            SBGEMM_BLOCK_KERNEL_TT_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
+                        }
+
+                        if (tag_n_Nx != n_to) {
+                            n_step = n_to - tag_n_Nx;
+                            SBGEMM_BLOCK_KERNEL_TT_32xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
+                        }
+                    } else {
+                        COL_MAJOR_ITCOPY_KERNEL_Kx16m(m_step, k_step, &A(tag_m_Nx, idx_k), lda, block_A);
+                        for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+                            SBGEMM_BLOCK_KERNEL_TT_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
+                        }
+
+                        if (tag_n_Nx != n_to) {
+                            n_step = n_to - tag_n_Nx;
+                            SBGEMM_BLOCK_KERNEL_TT_16xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
+                        }
+                    }
+                }
+
+                idx_k += k_step;
+                k_step = K - idx_k;
+                k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
+                k_step_round32 = k_step & (~31);
+                k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+            }
+
+            n_from = n_to;
+            n_to += BF16_BLOCK_THRES_N;
+            n_to = (n_to > N) ? N : n_to;
+            tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+        }
+    } else {
+        m_step = M;
+        if (m_step > 16) {
+            while (n_from < N) {
+                for (BLASLONG idx_k = 0; idx_k < K;) {
+                    // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
+                    COL_MAJOR_ITCOPY_KERNEL_Kx32m(m_step, k_step, &A(0, idx_k), lda, block_A);
+                    for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+                        // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
+                        COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32);
+                        SBGEMM_BLOCK_KERNEL_TT_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
+                    }
+
+                    if (tag_n_Nx != n_to) {
+                        n_step = n_to - tag_n_Nx;
+                        COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
+                        SBGEMM_BLOCK_KERNEL_TT_32xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+                    }
+
+                    idx_k += k_step;
+                    k_step = K - idx_k;
+                    k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
+                    k_step_round32 = k_step & (~31);
+                    k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+                }
+                n_from = n_to;
+                n_to += BF16_BLOCK_THRES_N;
+                n_to = (n_to > N) ? N : n_to;
+                tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+            }
+        } else {
+            while (n_from < N) {
+                for (BLASLONG idx_k = 0; idx_k < K;) {
+                    // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
+                    COL_MAJOR_ITCOPY_KERNEL_Kx16m(m_step, k_step, &A(0, idx_k), lda, block_A);
+                    for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+                        // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
+                        COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32);
+                        SBGEMM_BLOCK_KERNEL_TT_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
+                    }
+
+                    if (tag_n_Nx != n_to) {
+                        n_step = n_to - tag_n_Nx;
+                        COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
+                        SBGEMM_BLOCK_KERNEL_TT_16xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+                    }
+
+                    idx_k += k_step;
+                    k_step = K - idx_k;
+                    k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
+                    k_step_round32 = k_step & (~31);
+                    k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+                }
+                n_from = n_to;
+                n_to += BF16_BLOCK_THRES_N;
+                n_to = (n_to > N) ? N : n_to;
+                tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+            }
+        }
+    }
+}
+/* ----------------------------------------- End of TT kernels --------------------------------------- */
 
 #ifndef ONE_ALPHA      // ALPHA is not ONE
 void sbgemm_internal_kernel_alpha(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
@@ -613,13 +1806,33 @@ void sbgemm_internal_kernel_one(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_
                 OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, float *C, OPENBLAS_CONST blasint ldc)
 #endif
 {
-    bfloat16 block_A[BF16_BLOCK_THRES_K * BF16_BLOCK_THRES_M];
-    bfloat16 block_B[BF16_BLOCK_THRES_N * BF16_BLOCK_THRES_K];
-
-    // TODO: assume no trans for both A and B, to complement these scenarios later
     if (Order == CblasColMajor) {
-        SBGEMM_BLOCKING_KERNEL_2(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B);
+        if (TransA == CblasNoTrans) {
+            if (TransB == CblasNoTrans) {
+                SBGEMM_BLOCKING_KERNEL_NN(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B);
+            } else if (TransB == CblasTrans) {
+                SBGEMM_BLOCKING_KERNEL_NT(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B);
+            }
+        } else {
+            if (TransB == CblasNoTrans) {
+                SBGEMM_BLOCKING_KERNEL_TN(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B);
+            } else if (TransB == CblasTrans) {
+                SBGEMM_BLOCKING_KERNEL_TT(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B);
+            }
+        }
     } else {
-        
+        if (TransA == CblasNoTrans) {
+            if (TransB == CblasNoTrans) {
+                SBGEMM_BLOCKING_KERNEL_NN(N, M, K, alpha, B, ldb, A, lda, C, ldc, block_A, block_B);
+            } else if (TransB == CblasTrans) {
+                SBGEMM_BLOCKING_KERNEL_TN(N, M, K, alpha, B, ldb, A, lda, C, ldc, block_A, block_B);
+            }
+        } else {
+            if (TransB == CblasNoTrans) {
+                SBGEMM_BLOCKING_KERNEL_NT(N, M, K, alpha, B, ldb, A, lda, C, ldc, block_A, block_B);
+            } else if (TransB == CblasTrans) {
+                SBGEMM_BLOCKING_KERNEL_TT(N, M, K, alpha, B, ldb, A, lda, C, ldc, block_A, block_B);
+            }
+        }
     }
-}
\ No newline at end of file
+}