#include <immintrin.h>
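+// Inline-asm helper: vpbroadcastd loads one 32-bit element (a packed pair of
+// bf16 values) from addr and replicates it across all 16 dword lanes of zmm.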
+#define _MM512_BROADCASTD_EPI32(addr, zmm) \
+ __asm__ ("vpbroadcastd (%1), %0;" \
+ : "=v" (zmm) \
+ : "r" (addr) )
+
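+// Prefetch the cache line containing addr into all cache levels (T0 hint).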
+#define PREFETCH_T0(addr) \
+ __asm__ ("prefetcht0 (%0);" \
+ : \
+ : "r" (addr) )
+
#define EXTRACT_LOW_256_FROM_512_2X(reg256, reg512) \
reg256##_0 = _mm512_castps512_ps256(reg512##_0); \
reg256##_1 = _mm512_castps512_ps256(reg512##_1);
_mm_mask_storeu_ps(targetAddr, mask, regResult);
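+/* The *_ONE_ONE store variants below accumulate the computed tile into y
+   (y += result); they cover the path that needs no alpha scaling, unlike the
+   *_ALPHA variants further down.
+*/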
+/* Store 16 (result + y) to y
+*/
+#define STORE16_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr) \
+ regResult = _mm512_add_ps(regResult, _mm512_loadu_ps(targetAddr)); \
+ _mm512_storeu_ps(targetAddr, regResult);
+
+
+/* Masked store 16 (result + y) to y
+*/
+#define STORE16_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask) \
+ regResult = _mm512_add_ps(regResult, _mm512_maskz_loadu_ps(mask, targetAddr)); \
+ _mm512_mask_storeu_ps(targetAddr, mask, regResult);
+
+
+/* Store 8 (result + y) to y
+*/
+#define STORE8_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr) \
+ regResult = _mm256_add_ps(regResult, _mm256_loadu_ps(targetAddr)); \
+ _mm256_storeu_ps(targetAddr, regResult);
+
+
+/* Masked store 8 (result + y) to y
+*/
+#define STORE8_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask) \
+ regResult = _mm256_add_ps(regResult, _mm256_maskz_loadu_ps(mask, targetAddr)); \
+ _mm256_mask_storeu_ps(targetAddr, mask, regResult);
+
+
+/* Store 4 (result + y) to y
+*/
+#define STORE4_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr) \
+ regResult = _mm_add_ps(regResult, _mm_loadu_ps(targetAddr)); \
+ _mm_storeu_ps(targetAddr, regResult);
+
+
+/* Masked store 4 (result + y) to y
+*/
+#define STORE4_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask) \
+ regResult = _mm_add_ps(regResult, _mm_maskz_loadu_ps(mask, targetAddr)); \
+ _mm_mask_storeu_ps(targetAddr, mask, regResult);
+
+
/* Store 16 (alpha * result) to y
*/
#define STORE16_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \
-#include "sbgemm.h"
+//#include "sbgemm.h"
#include <immintrin.h>
// Work around intrinsics that are missing from some compilers
#define MM256_STOREU_EPI16(addr, reg) \
_mm256_mask_storeu_epi16((addr), ~0, (reg))
-#include <stdio.h>
-void print_block(BLASLONG m, BLASLONG n, bfloat16 * mat)
-{
- printf("---- BLOCK %ld x %ld ----\n", m, n);
- for (BLASLONG i=0; i<m; i++) {
- for (BLASLONG j=0; j<n; j++) {
- printf("%-4X ", *(mat + i*n +j));
- }
- printf("\n");
- }
- printf("---- End of BLOCK ----\n");
-}
-
-void COL_MAJOR_INCOPY_KERNEL_Kx32(BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
+// INCOPY Kernel, 16<M<=32, k can be any number
+void COL_MAJOR_INCOPY_KERNEL_Kx32(BLASLONG k, BLASLONG m, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
{
BLASLONG tag_k_2x = k & (~1);
+ unsigned int tail_mask = (((unsigned int)0xffffffff) >> (32-m));
__m512i array512_0, array512_1, array512_2, array512_3;
- BLASLONG idx_src_base0, idx_src_base1;
- BLASLONG idx_target_base0, idx_target_base1;
+ bfloat16 * src_addr0, * src_addr1;
+ bfloat16 * dst_addr0, * dst_addr1;
BLASLONG LDA_2x = 2*lda;
BLASLONG BF16_BLOCK_T_M_2x = 2*32;
- idx_src_base0 = 0;
- idx_src_base1 = lda;
- idx_target_base0 = 0;
- idx_target_base1 = 32;
- for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) {
- array512_0 = _mm512_loadu_si512(&A[idx_src_base0]);
- array512_1 = _mm512_loadu_si512(&A[idx_src_base1]);
- array512_2 = _mm512_unpacklo_epi16(array512_0, array512_1);
- array512_3 = _mm512_unpackhi_epi16(array512_0, array512_1);
- _mm512_storeu_si512(&block_A[idx_target_base0], array512_2);
- _mm512_storeu_si512(&block_A[idx_target_base1], array512_3);
-
- idx_src_base0 += LDA_2x;
- idx_src_base1 += LDA_2x;
- idx_target_base0 += BF16_BLOCK_T_M_2x;
- idx_target_base1 += BF16_BLOCK_T_M_2x;
- }
-
- if (tag_k_2x != k) {
- __m512i ZERO512 = _mm512_setzero_si512();
- array512_0 = _mm512_loadu_si512(&A[idx_src_base0]);
- array512_2 = _mm512_unpacklo_epi16(array512_0, ZERO512);
- array512_3 = _mm512_unpackhi_epi16(array512_0, ZERO512);
- _mm512_storeu_si512(&block_A[idx_target_base0], array512_2);
- _mm512_storeu_si512(&block_A[idx_target_base1], array512_3);
- }
-
-#ifdef DEBUG_PROFILE
- print_block(BF16_BLOCK_THRES_K, BF16_BLOCK_THRES_M, block_A);
-#endif
-}
-
-void COL_MAJOR_INCOPY_KERNEL_Kx32m(BLASLONG k, BLASLONG m, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
-{
- BLASLONG tag_k_2x = k & (~1);
- unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-m));
- __mmask32 tail_mask = *((__mmask32*) &tail_mask_value);
-
- __m512i array512_0, array512_1, array512_2, array512_3;
- BLASLONG idx_src_base0, idx_src_base1;
- BLASLONG idx_target_base0, idx_target_base1;
+ src_addr0 = A;
+ src_addr1 = A + lda;
+ dst_addr0 = block_A;
+ dst_addr1 = block_A + 32;
- BLASLONG LDA_2x = 2*lda;
- BLASLONG BF16_BLOCK_T_M_2x = 2*32;
- idx_src_base0 = 0;
- idx_src_base1 = lda;
- idx_target_base0 = 0;
- idx_target_base1 = 32;
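+ // Interleave rows k and k+1 element-wise so that block_A stores the pair
+ // (a[k][m], a[k+1][m]) contiguously for every m; an odd trailing k row is
+ // paired with zeros below.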
for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) {
- array512_0 = _mm512_maskz_loadu_epi16(tail_mask, &A[idx_src_base0]);
- array512_1 = _mm512_maskz_loadu_epi16(tail_mask, &A[idx_src_base1]);
+ array512_0 = _mm512_maskz_loadu_epi16(tail_mask, src_addr0);
+ array512_1 = _mm512_maskz_loadu_epi16(tail_mask, src_addr1);
array512_2 = _mm512_unpacklo_epi16(array512_0, array512_1);
array512_3 = _mm512_unpackhi_epi16(array512_0, array512_1);
- _mm512_storeu_si512(&block_A[idx_target_base0], array512_2);
- _mm512_storeu_si512(&block_A[idx_target_base1], array512_3);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
- idx_src_base0 += LDA_2x;
- idx_src_base1 += LDA_2x;
- idx_target_base0 += BF16_BLOCK_T_M_2x;
- idx_target_base1 += BF16_BLOCK_T_M_2x;
+ src_addr0 += LDA_2x;
+ src_addr1 += LDA_2x;
+ dst_addr0 += BF16_BLOCK_T_M_2x;
+ dst_addr1 += BF16_BLOCK_T_M_2x;
}
if (tag_k_2x != k) {
__m512i ZERO512 = _mm512_setzero_si512();
- array512_0 = _mm512_maskz_loadu_epi16(tail_mask, &A[idx_src_base0]);
+ array512_0 = _mm512_maskz_loadu_epi16(tail_mask, src_addr0);
array512_2 = _mm512_unpacklo_epi16(array512_0, ZERO512);
array512_3 = _mm512_unpackhi_epi16(array512_0, ZERO512);
- _mm512_storeu_si512(&block_A[idx_target_base0], array512_2);
- _mm512_storeu_si512(&block_A[idx_target_base1], array512_3);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
}
-
-#ifdef DEBUG_PROFILE
- print_block(BF16_BLOCK_THRES_K, BF16_BLOCK_THRES_M, block_A);
-#endif
}
+// INCOPY Kernel, 0<M<=16, k can be any number
void COL_MAJOR_INCOPY_KERNEL_Kx16(BLASLONG k, BLASLONG m, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
{
BLASLONG tag_k_2x = k & (~1);
+ unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m));
__m256i array256_0, array256_1, array256_2, array256_3;
- BLASLONG idx_src_base0, idx_src_base1;
- BLASLONG idx_target_base0;
+ bfloat16 * src_addr0, * src_addr1;
+ bfloat16 * dst_addr0;
BLASLONG LDA_2x = 2*lda;
- idx_src_base0 = 0;
- idx_src_base1 = lda;
- idx_target_base0 = 0;
+
+ src_addr0 = A;
+ src_addr1 = A + lda;
+ dst_addr0 = block_A;
+
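+ // Same pairing as the 32-wide kernel: rows k and k+1 are interleaved into one
+ // 32-element row of block_A, and an odd trailing k row is paired with zeros.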
for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) {
- array256_0 = MM256_LOADU_EPI16(&A[idx_src_base0]);
- array256_1 = MM256_LOADU_EPI16(&A[idx_src_base1]);
+ array256_0 = _mm256_maskz_loadu_epi16(tail_mask, src_addr0);
+ array256_1 = _mm256_maskz_loadu_epi16(tail_mask, src_addr1);
array256_2 = _mm256_unpacklo_epi16(array256_0, array256_1);
array256_3 = _mm256_unpackhi_epi16(array256_0, array256_1);
// Store in one row of block_A
- MM256_STOREU_EPI16(&block_A[idx_target_base0], array256_2);
- MM256_STOREU_EPI16(&block_A[idx_target_base0 + 16], array256_3);
+ MM256_STOREU_EPI16(dst_addr0, array256_2);
+ MM256_STOREU_EPI16(dst_addr0+16, array256_3);
- idx_src_base0 += LDA_2x;
- idx_src_base1 += LDA_2x;
- idx_target_base0 += 32;
+ src_addr0 += LDA_2x;
+ src_addr1 += LDA_2x;
+ dst_addr0 += 32;
}
if (tag_k_2x != k) {
__m256i ZERO256 = _mm256_setzero_si256();
- array256_0 = MM256_LOADU_EPI16(&A[idx_src_base0]);
+ array256_0 = _mm256_maskz_loadu_epi16(tail_mask, src_addr0);
array256_2 = _mm256_unpacklo_epi16(array256_0, ZERO256);
array256_3 = _mm256_unpackhi_epi16(array256_0, ZERO256);
// Store in one row of block_A
- MM256_STOREU_EPI16(&block_A[idx_target_base0], array256_2);
- MM256_STOREU_EPI16(&block_A[idx_target_base0 + 16], array256_3);
+ MM256_STOREU_EPI16(dst_addr0, array256_2);
+ MM256_STOREU_EPI16(dst_addr0+16, array256_3);
}
+}
-#ifdef DEBUG_PROFILE
- print_block(BF16_BLOCK_THRES_K, BF16_BLOCK_THRES_M, block_A);
-#endif
+// K=32, M=16
+void COL_MAJOR_ITCOPY_KERNEL_32x16(bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
+{
+ bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3;
+ bfloat16 * dst_addr0, * dst_addr1;
+
+ BLASLONG LDA_4x = lda*4;
+
+ src_addr0 = A;
+ src_addr1 = A + lda;
+ src_addr2 = A + lda*2;
+ src_addr3 = A + lda*3;
+ dst_addr0 = block_A;
+ dst_addr1 = block_A + 32*8;
+
+ __m512i array512_0, array512_1, array512_2, array512_3;
+ __m512i array512_way0_0, array512_way0_1, array512_way0_2, array512_way0_3;
+ __m512i array512_way1_0, array512_way1_1, array512_way1_2, array512_way1_3;
+ __m512i array512_way2_0, array512_way2_1, array512_way2_2, array512_way2_3;
+ __m512i array512_way3_0, array512_way3_1, array512_way3_2, array512_way3_3;
+
+ __m512i M512_EPI64_2 = _mm512_set1_epi64(2);
+ __m512i permute_lo_idx = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
+ __m512i permute_hi_idx = _mm512_add_epi64(permute_lo_idx, M512_EPI64_2);
+
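+ // permute_lo_idx picks quadwords {0,1,4,5} from each of its two sources and
+ // permute_hi_idx (same indices +2) picks {2,3,6,7}; each permute therefore
+ // merges the matching column pairs of two transposed 4-row groups into one
+ // register, ready to be composed into a full 16-row column pair.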
+ // Load and preprocess 1st 4 rows
+ array512_way0_0 = _mm512_loadu_si512(src_addr0);
+ array512_way0_1 = _mm512_loadu_si512(src_addr1);
+ array512_way0_2 = _mm512_loadu_si512(src_addr2);
+ array512_way0_3 = _mm512_loadu_si512(src_addr3);
+ array512_0 = _mm512_unpacklo_epi32(array512_way0_0, array512_way0_1);
+ array512_1 = _mm512_unpackhi_epi32(array512_way0_0, array512_way0_1);
+ array512_2 = _mm512_unpacklo_epi32(array512_way0_2, array512_way0_3);
+ array512_3 = _mm512_unpackhi_epi32(array512_way0_2, array512_way0_3);
+ array512_way0_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512_way0_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512_way0_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512_way0_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+ src_addr0 += LDA_4x;
+ src_addr1 += LDA_4x;
+ src_addr2 += LDA_4x;
+ src_addr3 += LDA_4x;
+
+ // Load and preprocess 2nd 4 rows
+ array512_way1_0 = _mm512_loadu_si512(src_addr0);
+ array512_way1_1 = _mm512_loadu_si512(src_addr1);
+ array512_way1_2 = _mm512_loadu_si512(src_addr2);
+ array512_way1_3 = _mm512_loadu_si512(src_addr3);
+ array512_0 = _mm512_unpacklo_epi32(array512_way1_0, array512_way1_1);
+ array512_1 = _mm512_unpackhi_epi32(array512_way1_0, array512_way1_1);
+ array512_2 = _mm512_unpacklo_epi32(array512_way1_2, array512_way1_3);
+ array512_3 = _mm512_unpackhi_epi32(array512_way1_2, array512_way1_3);
+ array512_way1_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512_way1_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512_way1_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512_way1_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+ src_addr0 += LDA_4x;
+ src_addr1 += LDA_4x;
+ src_addr2 += LDA_4x;
+ src_addr3 += LDA_4x;
+
+ // Load and preprocess 3rd 4 rows
+ array512_way2_0 = _mm512_loadu_si512(src_addr0);
+ array512_way2_1 = _mm512_loadu_si512(src_addr1);
+ array512_way2_2 = _mm512_loadu_si512(src_addr2);
+ array512_way2_3 = _mm512_loadu_si512(src_addr3);
+ array512_0 = _mm512_unpacklo_epi32(array512_way2_0, array512_way2_1);
+ array512_1 = _mm512_unpackhi_epi32(array512_way2_0, array512_way2_1);
+ array512_2 = _mm512_unpacklo_epi32(array512_way2_2, array512_way2_3);
+ array512_3 = _mm512_unpackhi_epi32(array512_way2_2, array512_way2_3);
+ array512_way2_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512_way2_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512_way2_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512_way2_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+ src_addr0 += LDA_4x;
+ src_addr1 += LDA_4x;
+ src_addr2 += LDA_4x;
+ src_addr3 += LDA_4x;
+
+ // Load and preprocess 4th 4 rows
+ array512_way3_0 = _mm512_loadu_si512(src_addr0);
+ array512_way3_1 = _mm512_loadu_si512(src_addr1);
+ array512_way3_2 = _mm512_loadu_si512(src_addr2);
+ array512_way3_3 = _mm512_loadu_si512(src_addr3);
+ array512_0 = _mm512_unpacklo_epi32(array512_way3_0, array512_way3_1);
+ array512_1 = _mm512_unpackhi_epi32(array512_way3_0, array512_way3_1);
+ array512_2 = _mm512_unpacklo_epi32(array512_way3_2, array512_way3_3);
+ array512_3 = _mm512_unpackhi_epi32(array512_way3_2, array512_way3_3);
+ array512_way3_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512_way3_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512_way3_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512_way3_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+
+ // Compose and store the 0/1 and 16/17 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_lo_idx, array512_way1_0);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_lo_idx, array512_way3_0);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 32;
+ dst_addr1 += 32;
+
+ // Compose and store the 2/3 and 18/19 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_lo_idx, array512_way1_1);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_lo_idx, array512_way3_1);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 32;
+ dst_addr1 += 32;
+
+ // Compose and store the 4/5 and 20/21 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_lo_idx, array512_way1_2);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_2, permute_lo_idx, array512_way3_2);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 32;
+ dst_addr1 += 32;
+
+ // Compose and store the 6/7 and 22/23 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_lo_idx, array512_way1_3);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_lo_idx, array512_way3_3);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 32;
+ dst_addr1 += 32;
+
+ // Compose and store the 8/9 and 24/25 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_hi_idx, array512_way1_0);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_hi_idx, array512_way3_0);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 32;
+ dst_addr1 += 32;
+
+ // Compose and store the 10/11 and 26/27 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_hi_idx, array512_way1_1);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_hi_idx, array512_way3_1);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 32;
+ dst_addr1 += 32;
+
+ // Compose and store the 12/13 and 28/29 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_hi_idx, array512_way1_2);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_2, permute_hi_idx, array512_way3_2);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 32;
+ dst_addr1 += 32;
+
+ // Compose and store the 14/15 and 30/31 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_hi_idx, array512_way1_3);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_hi_idx, array512_way3_3);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
}
-void COL_MAJOR_INCOPY_KERNEL_Kx16m(BLASLONG k, BLASLONG m, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
+// K can be any number but is processed in chunks of 32, M=32
+void COL_MAJOR_ITCOPY_KERNEL_Kx32(BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
{
- BLASLONG tag_k_2x = k & (~1);
- unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-m));
- __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
+ bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3;
+ bfloat16 * dst_addr0, * dst_addr1;
- __m256i array256_0, array256_1, array256_2, array256_3;
+ BLASLONG tag_k_32x = k & (~31);
- BLASLONG idx_src_base0, idx_src_base1;
- BLASLONG idx_target_base0;
+ BLASLONG LDA_4x = lda*4;
+ BLASLONG LDA_8x = lda*8;
+ BLASLONG LDA_12x = lda*12;
+ BLASLONG LDA_16x = lda*16;
- BLASLONG LDA_2x = 2*lda;
- idx_src_base0 = 0;
- idx_src_base1 = lda;
- idx_target_base0 = 0;
- for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) {
- array256_0 = _mm256_maskz_loadu_epi16(tail_mask, &A[idx_src_base0]);
- array256_1 = _mm256_maskz_loadu_epi16(tail_mask, &A[idx_src_base1]);
- array256_2 = _mm256_unpacklo_epi16(array256_0, array256_1);
- array256_3 = _mm256_unpackhi_epi16(array256_0, array256_1);
- // Store in one row of block_B
- MM256_STOREU_EPI16(&block_A[idx_target_base0], array256_2);
- MM256_STOREU_EPI16(&block_A[idx_target_base0 + 16], array256_3);
+ src_addr0 = A;
+ src_addr1 = A + lda;
+ src_addr2 = A + lda*2;
+ src_addr3 = A + lda*3;
+ dst_addr0 = block_A;
+ dst_addr1 = block_A + 32*16;
- idx_src_base0 += LDA_2x;
- idx_src_base1 += LDA_2x;
- idx_target_base0 += 32;
+ __m512i array512_0, array512_1, array512_2, array512_3;
+ __m512i array512_way0_0, array512_way0_1, array512_way0_2, array512_way0_3;
+ __m512i array512_way1_0, array512_way1_1, array512_way1_2, array512_way1_3;
+ __m512i array512_way2_0, array512_way2_1, array512_way2_2, array512_way2_3;
+ __m512i array512_way3_0, array512_way3_1, array512_way3_2, array512_way3_3;
+
+ __m512i M512_EPI64_2 = _mm512_set1_epi64(2);
+ __m512i permute_lo_idx = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
+ __m512i permute_hi_idx = _mm512_add_epi64(permute_lo_idx, M512_EPI64_2);
+
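+ // Each 32-column chunk of k is transposed in two passes: i=0 handles rows
+ // 0-15 and i=1 handles rows 16-31, which land 32 elements further into each
+ // 64-element output row. dst_addr0 receives column pairs 0-7 of the chunk,
+ // dst_addr1 (block_A + 32*16) receives pairs 8-15 (columns 16-31).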
+ for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
+ for (int i = 0; i < 2; i++) {
+ // Load and preprocess 1st 4 rows
+ array512_way0_0 = _mm512_loadu_si512(src_addr0+idx_k);
+ array512_way0_1 = _mm512_loadu_si512(src_addr1+idx_k);
+ array512_way0_2 = _mm512_loadu_si512(src_addr2+idx_k);
+ array512_way0_3 = _mm512_loadu_si512(src_addr3+idx_k);
+ array512_0 = _mm512_unpacklo_epi32(array512_way0_0, array512_way0_1);
+ array512_1 = _mm512_unpackhi_epi32(array512_way0_0, array512_way0_1);
+ array512_2 = _mm512_unpacklo_epi32(array512_way0_2, array512_way0_3);
+ array512_3 = _mm512_unpackhi_epi32(array512_way0_2, array512_way0_3);
+ array512_way0_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512_way0_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512_way0_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512_way0_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+
+ // Load and preprocess 2nd 4 rows
+ array512_way1_0 = _mm512_loadu_si512(src_addr0+LDA_4x+idx_k);
+ array512_way1_1 = _mm512_loadu_si512(src_addr1+LDA_4x+idx_k);
+ array512_way1_2 = _mm512_loadu_si512(src_addr2+LDA_4x+idx_k);
+ array512_way1_3 = _mm512_loadu_si512(src_addr3+LDA_4x+idx_k);
+ array512_0 = _mm512_unpacklo_epi32(array512_way1_0, array512_way1_1);
+ array512_1 = _mm512_unpackhi_epi32(array512_way1_0, array512_way1_1);
+ array512_2 = _mm512_unpacklo_epi32(array512_way1_2, array512_way1_3);
+ array512_3 = _mm512_unpackhi_epi32(array512_way1_2, array512_way1_3);
+ array512_way1_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512_way1_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512_way1_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512_way1_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+
+ // Load and preprocess 3rd 4 rows
+ array512_way2_0 = _mm512_loadu_si512(src_addr0+LDA_8x+idx_k);
+ array512_way2_1 = _mm512_loadu_si512(src_addr1+LDA_8x+idx_k);
+ array512_way2_2 = _mm512_loadu_si512(src_addr2+LDA_8x+idx_k);
+ array512_way2_3 = _mm512_loadu_si512(src_addr3+LDA_8x+idx_k);
+ array512_0 = _mm512_unpacklo_epi32(array512_way2_0, array512_way2_1);
+ array512_1 = _mm512_unpackhi_epi32(array512_way2_0, array512_way2_1);
+ array512_2 = _mm512_unpacklo_epi32(array512_way2_2, array512_way2_3);
+ array512_3 = _mm512_unpackhi_epi32(array512_way2_2, array512_way2_3);
+ array512_way2_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512_way2_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512_way2_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512_way2_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+
+ // Load and preprocess 4th 4 rows
+ array512_way3_0 = _mm512_loadu_si512(src_addr0+LDA_12x+idx_k);
+ array512_way3_1 = _mm512_loadu_si512(src_addr1+LDA_12x+idx_k);
+ array512_way3_2 = _mm512_loadu_si512(src_addr2+LDA_12x+idx_k);
+ array512_way3_3 = _mm512_loadu_si512(src_addr3+LDA_12x+idx_k);
+ array512_0 = _mm512_unpacklo_epi32(array512_way3_0, array512_way3_1);
+ array512_1 = _mm512_unpackhi_epi32(array512_way3_0, array512_way3_1);
+ array512_2 = _mm512_unpacklo_epi32(array512_way3_2, array512_way3_3);
+ array512_3 = _mm512_unpackhi_epi32(array512_way3_2, array512_way3_3);
+ array512_way3_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512_way3_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512_way3_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512_way3_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+
+ // Compose and store the 0/1 and 16/17 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_lo_idx, array512_way1_0);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_lo_idx, array512_way3_0);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 64;
+ dst_addr1 += 64;
+
+ // Compose and store the 2/3 and 18/19 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_lo_idx, array512_way1_1);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_lo_idx, array512_way3_1);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 64;
+ dst_addr1 += 64;
+
+ // Compose and store the 4/5 and 20/21 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_lo_idx, array512_way1_2);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_2, permute_lo_idx, array512_way3_2);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 64;
+ dst_addr1 += 64;
+
+ // Compose and store the 6/7 and 22/23 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_lo_idx, array512_way1_3);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_lo_idx, array512_way3_3);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 64;
+ dst_addr1 += 64;
+
+ // Compose and store the 8/9 and 24/25 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_hi_idx, array512_way1_0);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_hi_idx, array512_way3_0);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 64;
+ dst_addr1 += 64;
+
+ // Compose and store the 10/11 and 26/27 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_hi_idx, array512_way1_1);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_hi_idx, array512_way3_1);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 64;
+ dst_addr1 += 64;
+
+ // Compose and store the 12/13 and 28/29 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_hi_idx, array512_way1_2);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_2, permute_hi_idx, array512_way3_2);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 64;
+ dst_addr1 += 64;
+
+ // Compose and store the 14/15 and 30/31 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_hi_idx, array512_way1_3);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_hi_idx, array512_way3_3);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+
+ src_addr0 += LDA_16x;
+ src_addr1 += LDA_16x;
+ src_addr2 += LDA_16x;
+ src_addr3 += LDA_16x;
+ dst_addr0 -= (64*7 - 32);
+ dst_addr1 -= (64*7 - 32);
+ }
+ src_addr0 -= (LDA_16x*2);
+ src_addr1 -= (LDA_16x*2);
+ src_addr2 -= (LDA_16x*2);
+ src_addr3 -= (LDA_16x*2);
+ dst_addr0 += (32*30);
+ dst_addr1 += (32*30);
}
- if (tag_k_2x != k) {
- __m256i ZERO256 = _mm256_setzero_si256();
- array256_0 = _mm256_maskz_loadu_epi16(tail_mask, &A[idx_src_base0]);
- array256_2 = _mm256_unpacklo_epi16(array256_0, ZERO256);
- array256_3 = _mm256_unpackhi_epi16(array256_0, ZERO256);
- // Store in one row of block_B
- MM256_STOREU_EPI16(&block_A[idx_target_base0], array256_2);
- MM256_STOREU_EPI16(&block_A[idx_target_base0 + 16], array256_3);
+ if (tag_k_32x != k) {
+ int k_rem = k - tag_k_32x;
+ unsigned int tail_mask = (((unsigned int)0xffffffff) >> (32-k_rem));
+ __m512i array512[16];
+
+ bfloat16 * dst_addr_tmp = dst_addr0;
+
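+ // k tail: the remaining k_rem (<32) columns are brought in with a masked
+ // load and transposed the same way; the first 8 column pairs are stored
+ // unconditionally (zero padded when k_rem < 16), then idx_length more pairs
+ // covering columns 16 to k_rem-1.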
+ for (int i = 0; i < 2; i++) {
+ // Load and preprocess 1st 4 rows
+ array512[0] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x);
+ array512[1] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x);
+ array512[2] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x);
+ array512[3] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x);
+ array512_0 = _mm512_unpacklo_epi32(array512[0], array512[1]);
+ array512_1 = _mm512_unpackhi_epi32(array512[0], array512[1]);
+ array512_2 = _mm512_unpacklo_epi32(array512[2], array512[3]);
+ array512_3 = _mm512_unpackhi_epi32(array512[2], array512[3]);
+ array512[0] = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512[1] = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512[2] = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512[3] = _mm512_unpackhi_epi64(array512_1, array512_3);
+ src_addr0 += LDA_4x;
+ src_addr1 += LDA_4x;
+ src_addr2 += LDA_4x;
+ src_addr3 += LDA_4x;
+
+ // Load and preprocess 2nd 4 rows
+ array512[4] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x);
+ array512[5] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x);
+ array512[6] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x);
+ array512[7] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x);
+ array512_0 = _mm512_unpacklo_epi32(array512[4], array512[5]);
+ array512_1 = _mm512_unpackhi_epi32(array512[4], array512[5]);
+ array512_2 = _mm512_unpacklo_epi32(array512[6], array512[7]);
+ array512_3 = _mm512_unpackhi_epi32(array512[6], array512[7]);
+ array512[4] = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512[5] = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512[6] = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512[7] = _mm512_unpackhi_epi64(array512_1, array512_3);
+ src_addr0 += LDA_4x;
+ src_addr1 += LDA_4x;
+ src_addr2 += LDA_4x;
+ src_addr3 += LDA_4x;
+
+ // Load and preprocess 3rd 4 rows
+ array512[8] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x);
+ array512[9] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x);
+ array512[10] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x);
+ array512[11] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x);
+ array512_0 = _mm512_unpacklo_epi32(array512[8], array512[9]);
+ array512_1 = _mm512_unpackhi_epi32(array512[8], array512[9]);
+ array512_2 = _mm512_unpacklo_epi32(array512[10], array512[11]);
+ array512_3 = _mm512_unpackhi_epi32(array512[10], array512[11]);
+ array512[8] = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512[9] = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512[10] = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512[11] = _mm512_unpackhi_epi64(array512_1, array512_3);
+ src_addr0 += LDA_4x;
+ src_addr1 += LDA_4x;
+ src_addr2 += LDA_4x;
+ src_addr3 += LDA_4x;
+
+ // Load and preprocess 4th 4 rows
+ array512[12] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x);
+ array512[13] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x);
+ array512[14] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x);
+ array512[15] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x);
+ array512_0 = _mm512_unpacklo_epi32(array512[12], array512[13]);
+ array512_1 = _mm512_unpackhi_epi32(array512[12], array512[13]);
+ array512_2 = _mm512_unpacklo_epi32(array512[14], array512[15]);
+ array512_3 = _mm512_unpackhi_epi32(array512[14], array512[15]);
+ array512[12] = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512[13] = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512[14] = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512[15] = _mm512_unpackhi_epi64(array512_1, array512_3);
+ src_addr0 += LDA_4x;
+ src_addr1 += LDA_4x;
+ src_addr2 += LDA_4x;
+ src_addr3 += LDA_4x;
+
+ // array512_01_1617_0, array512_01_1617_1, array512_89_2425_0, array512_89_2425_1;
+ // Half-compose of 0/1, 16/17, 8/9, 24/25 cols
+ array512_0 = _mm512_permutex2var_epi64(array512[0], permute_lo_idx, array512[4]);
+ array512_1 = _mm512_permutex2var_epi64(array512[8], permute_lo_idx, array512[12]);
+ array512_2 = _mm512_permutex2var_epi64(array512[0], permute_hi_idx, array512[4]);
+ array512_3 = _mm512_permutex2var_epi64(array512[8], permute_hi_idx, array512[12]);
+ array512[0] = array512_0; // 1st 8 pairs of col 0/1, and 1st 8 pairs of col 16/17
+ array512[4] = array512_1; // 2nd 8 pairs of col 0/1, and 2nd 8 pairs of col 16/17
+ array512[8] = array512_2; // 1st 8 pairs of col 8/9, and 1st 8 pairs of col 24/25
+ array512[12] = array512_3; // 2nd 8 pairs of col 8/9, and 2nd 8 pairs of col 24/25
+
+ // Half-compose of 2/3, 18/19, 10/11, 26/27 cols
+ array512_0 = _mm512_permutex2var_epi64(array512[1], permute_lo_idx, array512[5]);
+ array512_1 = _mm512_permutex2var_epi64(array512[9], permute_lo_idx, array512[13]);
+ array512_2 = _mm512_permutex2var_epi64(array512[1], permute_hi_idx, array512[5]);
+ array512_3 = _mm512_permutex2var_epi64(array512[9], permute_hi_idx, array512[13]);
+ array512[1] = array512_0; // 1st 8 pairs of col 2/3, and 1st 8 pairs of col 18/19
+ array512[5] = array512_1; // 2nd 8 pairs of col 2/3, and 2nd 8 pairs of col 18/19
+ array512[9] = array512_2; // 1st 8 pairs of col 10/11, and 1st 8 pairs of col 26/27
+ array512[13] = array512_3; // 2nd 8 pairs of col 10/11, and 2nd 8 pairs of col 26/27
+
+ // Half-compose of 4/5, 20/21, 12/13, 28/29 cols
+ array512_0 = _mm512_permutex2var_epi64(array512[2], permute_lo_idx, array512[6]);
+ array512_1 = _mm512_permutex2var_epi64(array512[10], permute_lo_idx, array512[14]);
+ array512_2 = _mm512_permutex2var_epi64(array512[2], permute_hi_idx, array512[6]);
+ array512_3 = _mm512_permutex2var_epi64(array512[10], permute_hi_idx, array512[14]);
+ array512[2] = array512_0; // 1st 8 pairs of col 4/5, and 1st 8 pairs of col 20/21
+ array512[6] = array512_1; // 2nd 8 pairs of col 4/5, and 2nd 8 pairs of col 20/21
+ array512[10] = array512_2; // 1st 8 pairs of col 12/13, and 1st 8 pairs of col 28/29
+ array512[14] = array512_3; // 2nd 8 pairs of col 12/13, and 2nd 8 pairs of col 28/29
+
+ // Half-compose of 6/7, 22/23, 14/15, 30/31 cols
+ array512_0 = _mm512_permutex2var_epi64(array512[3], permute_lo_idx, array512[7]);
+ array512_1 = _mm512_permutex2var_epi64(array512[11], permute_lo_idx, array512[15]);
+ array512_2 = _mm512_permutex2var_epi64(array512[3], permute_hi_idx, array512[7]);
+ array512_3 = _mm512_permutex2var_epi64(array512[11], permute_hi_idx, array512[15]);
+ array512[3] = array512_0; // 1st 8 pairs of col 6/7, and 1st 8 pairs of col 22/23
+ array512[7] = array512_1; // 2nd 8 pairs of col 6/7, and 2nd 8 pairs of col 22/23
+ array512[11] = array512_2; // 1st 8 pairs of col 14/15, and 1st 8 pairs of col 30/31
+ array512[15] = array512_3; // 2nd 8 pairs of col 14/15, and 2nd 8 pairs of col 30/31
+
+ // Compose and store the 0/1 cols
+ array512_0 = _mm512_inserti64x4(array512[0], _mm512_castsi512_si256(array512[4]), 0x1);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 64;
+
+ // Compose and store the 2/3 cols
+ array512_0 = _mm512_inserti64x4(array512[1], _mm512_castsi512_si256(array512[5]), 0x1);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 64;
+
+ // Compose and store the 4/5 cols
+ array512_0 = _mm512_inserti64x4(array512[2], _mm512_castsi512_si256(array512[6]), 0x1);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 64;
+
+ // Compose and store the 6/7 cols
+ array512_0 = _mm512_inserti64x4(array512[3], _mm512_castsi512_si256(array512[7]), 0x1);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 64;
+
+ // Compose and store the 8/9 cols
+ array512_0 = _mm512_inserti64x4(array512[8], _mm512_castsi512_si256(array512[12]), 0x1);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 64;
+
+ // Compose and store the 10/11 cols
+ array512_0 = _mm512_inserti64x4(array512[9], _mm512_castsi512_si256(array512[13]), 0x1);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 64;
+
+ // Compose and store the 12/13 cols
+ array512_0 = _mm512_inserti64x4(array512[10], _mm512_castsi512_si256(array512[14]), 0x1);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 64;
+
+ // Compose and store the 14/15 cols
+ array512_0 = _mm512_inserti64x4(array512[11], _mm512_castsi512_si256(array512[15]), 0x1);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 64;
+
+ // Compose and store 16 ~ k_rem cols
+ int idx_length = (k_rem + 1 - 16) >> 1;
+ if (idx_length > 4) {
+ for (int idx_k = 0; idx_k < 4; idx_k++) {
+ array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 64;
+ }
+
+ for (int idx_k = 4; idx_k < idx_length; idx_k++) {
+ array512_0 = _mm512_inserti64x4(array512[idx_k+8], _mm512_extracti64x4_epi64(array512[idx_k+4], 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 64;
+ }
+ } else {
+ for (int idx_k = 0; idx_k < idx_length; idx_k++) {
+ array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 64;
+ }
+ }
+
+ dst_addr0 = dst_addr_tmp + 32;
+ }
+ }
+}
+
+// K can be any number but is processed in chunks of 32, 16<M<32
+void COL_MAJOR_ITCOPY_KERNEL_Kx32m(BLASLONG m, BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
+{
+ bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3;
+ bfloat16 * dst_addr0, * dst_addr1;
+
+ BLASLONG tag_k_32x = k & (~31);
+
+ BLASLONG LDA_4x = lda*4;
+
+ BLASLONG m_rem = m-16;
+
+ src_addr0 = A;
+ src_addr1 = A + lda;
+ src_addr2 = A + lda*2;
+ src_addr3 = A + lda*3;
+ dst_addr0 = block_A;
+ dst_addr1 = block_A + 32*16;
+
+ __m512i array512_0, array512_1, array512_2, array512_3;
+ __m512i array512[16];
+
+ __m512i M512_EPI64_2 = _mm512_set1_epi64(2);
+ __m512i permute_lo_idx = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
+ __m512i permute_hi_idx = _mm512_add_epi64(permute_lo_idx, M512_EPI64_2);
+
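+ // Rows 0-15 are transposed four rows at a time; the remaining m_rem = m-16
+ // rows are then loaded individually, padded with zero rows up to 16, and run
+ // through the same transpose before being stored 32 elements further along.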
+ for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
+ for (int j = 0; j < 4; j++) {
+ int array_idx = j*4;
+ // Load and preprocess 4 rows
+ array512[array_idx+0] = _mm512_loadu_si512(src_addr0+idx_k);
+ array512[array_idx+1] = _mm512_loadu_si512(src_addr1+idx_k);
+ array512[array_idx+2] = _mm512_loadu_si512(src_addr2+idx_k);
+ array512[array_idx+3] = _mm512_loadu_si512(src_addr3+idx_k);
+ array512_0 = _mm512_unpacklo_epi32(array512[array_idx+0], array512[array_idx+1]);
+ array512_1 = _mm512_unpackhi_epi32(array512[array_idx+0], array512[array_idx+1]);
+ array512_2 = _mm512_unpacklo_epi32(array512[array_idx+2], array512[array_idx+3]);
+ array512_3 = _mm512_unpackhi_epi32(array512[array_idx+2], array512[array_idx+3]);
+ array512[array_idx+0] = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512[array_idx+1] = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512[array_idx+2] = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512[array_idx+3] = _mm512_unpackhi_epi64(array512_1, array512_3);
+ src_addr0 += LDA_4x;
+ src_addr1 += LDA_4x;
+ src_addr2 += LDA_4x;
+ src_addr3 += LDA_4x;
+ }
+
+ // Compose and store the 0/1, 2/3, 4/5, 6/7 and 16/17, 18/19, 20/21, 22/23 cols
+ for (int j = 0; j < 4; j++) {
+ array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_lo_idx, array512[j+4]);
+ array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_lo_idx, array512[j+12]);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 64;
+ dst_addr1 += 64;
+ }
+
+ // Compose and store the 8/9, 10/11, 12/13, 14/15 and 24/25, 26/27, 28/29, 30/31 cols
+ for (int j = 0; j < 4; j++) {
+ array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_hi_idx, array512[j+4]);
+ array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_hi_idx, array512[j+12]);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 64;
+ dst_addr1 += 64;
+ }
+
+ dst_addr0 -= (64*8 - 32);
+ dst_addr1 -= (64*8 - 32);
+
+ for (int j = 0; j < m_rem; j++) {
+ array512[j] = _mm512_loadu_si512(src_addr0+j*lda+idx_k);
+ }
+ for (int j = m_rem; j < 16; j++) {
+ array512[j] = _mm512_setzero_si512();
+ }
+
+ for (int j = 0; j < 4; j++) {
+ int array_idx = j*4;
+ array512_0 = _mm512_unpacklo_epi32(array512[array_idx+0], array512[array_idx+1]);
+ array512_1 = _mm512_unpackhi_epi32(array512[array_idx+0], array512[array_idx+1]);
+ array512_2 = _mm512_unpacklo_epi32(array512[array_idx+2], array512[array_idx+3]);
+ array512_3 = _mm512_unpackhi_epi32(array512[array_idx+2], array512[array_idx+3]);
+ array512[array_idx+0] = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512[array_idx+1] = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512[array_idx+2] = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512[array_idx+3] = _mm512_unpackhi_epi64(array512_1, array512_3);
+ }
+
+ // Compose and store the 0/1, 2/3, 4/5, 6/7 and 16/17, 18/19, 20/21, 22/23 cols
+ for (int j = 0; j < 4; j++) {
+ array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_lo_idx, array512[j+4]);
+ array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_lo_idx, array512[j+12]);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 64;
+ dst_addr1 += 64;
+ }
+
+ // Compose and store the 8/9, 10/11, 12/13, 14/15 and 24/25, 26/27, 28/29, 30/31 cols
+ for (int j = 0; j < 4; j++) {
+ array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_hi_idx, array512[j+4]);
+ array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_hi_idx, array512[j+12]);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 64;
+ dst_addr1 += 64;
+ }
+
+ src_addr0 -= (LDA_4x*4);
+ src_addr1 -= (LDA_4x*4);
+ src_addr2 -= (LDA_4x*4);
+ src_addr3 -= (LDA_4x*4);
+ dst_addr0 += (32*15);
+ dst_addr1 += (32*15);
+ }
+
+ if (tag_k_32x != k) {
+ int k_rem = k - tag_k_32x;
+ int idx_length = (k_rem + 1 - 16) >> 1;
+ unsigned int tail_mask = (((unsigned int)0xffffffff) >> (32-k_rem));
+ bfloat16 * dst_addr_tmp = dst_addr0;
+
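+ // k tail for the partial-M case: the same masked-load transpose runs first
+ // for rows 0-15, then for the zero-padded remainder, whose pairs are written
+ // 32 elements further into each output row (starting at dst_addr_tmp + 32).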
+ for (int j = 0; j < 4; j++) {
+ int array_idx = j*4;
+ // Load and preprocess 4 rows
+ array512[array_idx+0] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x);
+ array512[array_idx+1] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x);
+ array512[array_idx+2] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x);
+ array512[array_idx+3] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x);
+ array512_0 = _mm512_unpacklo_epi32(array512[array_idx+0], array512[array_idx+1]);
+ array512_1 = _mm512_unpackhi_epi32(array512[array_idx+0], array512[array_idx+1]);
+ array512_2 = _mm512_unpacklo_epi32(array512[array_idx+2], array512[array_idx+3]);
+ array512_3 = _mm512_unpackhi_epi32(array512[array_idx+2], array512[array_idx+3]);
+ array512[array_idx+0] = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512[array_idx+1] = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512[array_idx+2] = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512[array_idx+3] = _mm512_unpackhi_epi64(array512_1, array512_3);
+ src_addr0 += LDA_4x;
+ src_addr1 += LDA_4x;
+ src_addr2 += LDA_4x;
+ src_addr3 += LDA_4x;
+ }
+
+ for (int j = 0; j < 4; j++) {
+ array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_lo_idx, array512[j+4]);
+ array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_lo_idx, array512[j+12]);
+ array512_2 = _mm512_permutex2var_epi64(array512[j+0], permute_hi_idx, array512[j+4]);
+ array512_3 = _mm512_permutex2var_epi64(array512[j+8], permute_hi_idx, array512[j+12]);
+ array512[j+0] = array512_0; // 1st 8 pairs of col 0/1|2/3|4/5|6/7, and 1st 8 pairs of col 16/17|18/19|20/21|22/23
+ array512[j+4] = array512_1; // 2nd 8 pairs of col 0/1|2/3|4/5|6/7, and 2nd 8 pairs of col 16/17|18/19|20/21|22/23
+ array512[j+8] = array512_2; // 1st 8 pairs of col 8/9|10/11|12/13|14/15, and 1st 8 pairs of col 24/25|26/27|28/29|30/31
+ array512[j+12] = array512_3; // 2nd 8 pairs of col 8/9|10/11|12/13|14/15, and 2nd 8 pairs of col 24/25|26/27|28/29|30/31
+ }
+
+ for (int j = 0; j < 4; j++) {
+ // Compose and store the 0/1 cols
+ array512_0 = _mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 64;
+ }
+
+ for (int j = 8; j < 12; j++) {
+ array512_0 = _mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 64;
+ }
+
+ // Compose and store 16 ~ k_rem cols
+ if (idx_length > 4) {
+ for (int idx_k = 0; idx_k < 4; idx_k++) {
+ array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 64;
+ }
+
+ for (int idx_k = 4; idx_k < idx_length; idx_k++) {
+ array512_0 = _mm512_inserti64x4(array512[idx_k+8], _mm512_extracti64x4_epi64(array512[idx_k+4], 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 64;
+ }
+ } else {
+ for (int idx_k = 0; idx_k < idx_length; idx_k++) {
+ array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 64;
+ }
+ }
+
+ dst_addr0 = dst_addr_tmp + 32;
+
+ for (int j = 0; j < m_rem; j++) {
+ array512[j] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+j*lda+tag_k_32x);
+ }
+ for (int j = m_rem; j < 16; j++) {
+ array512[j] = _mm512_setzero_si512();
+ }
+
+ for (int j = 0; j < 4; j++) {
+ int array_idx = j*4;
+ array512_0 = _mm512_unpacklo_epi32(array512[array_idx+0], array512[array_idx+1]);
+ array512_1 = _mm512_unpackhi_epi32(array512[array_idx+0], array512[array_idx+1]);
+ array512_2 = _mm512_unpacklo_epi32(array512[array_idx+2], array512[array_idx+3]);
+ array512_3 = _mm512_unpackhi_epi32(array512[array_idx+2], array512[array_idx+3]);
+ array512[array_idx+0] = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512[array_idx+1] = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512[array_idx+2] = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512[array_idx+3] = _mm512_unpackhi_epi64(array512_1, array512_3);
+ }
+
+ for (int j = 0; j < 4; j++) {
+ array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_lo_idx, array512[j+4]);
+ array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_lo_idx, array512[j+12]);
+ array512_2 = _mm512_permutex2var_epi64(array512[j+0], permute_hi_idx, array512[j+4]);
+ array512_3 = _mm512_permutex2var_epi64(array512[j+8], permute_hi_idx, array512[j+12]);
+ array512[j+0] = array512_0; // 1st 8 pairs of col 0/1|2/3|4/5|6/7, and 1st 8 pairs of col 16/17|18/19|20/21|22/23
+ array512[j+4] = array512_1; // 2nd 8 pairs of col 0/1|2/3|4/5|6/7, and 2nd 8 pairs of col 16/17|18/19|20/21|22/23
+ array512[j+8] = array512_2; // 1st 8 pairs of col 8/9|10/11|12/13|14/15, and 1st 8 pairs of col 24/25|26/27|28/29|30/31
+ array512[j+12] = array512_3; // 2nd 8 pairs of col 8/9|10/11|12/13|14/15, and 2nd 8 pairs of col 24/25|26/27|28/29|30/31
+ }
+
+ for (int j = 0; j < 4; j++) {
+ // Compose and store the 0/1 cols
+ array512_0 = _mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 64;
+ }
+
+ for (int j = 8; j < 12; j++) {
+ array512_0 = _mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 64;
+ }
+
+ // Compose and store 16 ~ k_rem cols
+ if (idx_length > 4) {
+ for (int idx_k = 0; idx_k < 4; idx_k++) {
+ array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 64;
+ }
+
+ for (int idx_k = 4; idx_k < idx_length; idx_k++) {
+ array512_0 = _mm512_inserti64x4(array512[idx_k+8], _mm512_extracti64x4_epi64(array512[idx_k+4], 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 64;
+ }
+ } else {
+ for (int idx_k = 0; idx_k < idx_length; idx_k++) {
+ array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 64;
+ }
+ }
+ }
+}
+
+// K can be any number but is processed in chunks of 32, M=16
+void COL_MAJOR_ITCOPY_KERNEL_Kx16(BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
+{
+ bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3;
+ bfloat16 * dst_addr0, * dst_addr1;
+
+ BLASLONG tag_k_32x = k & (~31);
+
+ BLASLONG LDA_4x = lda*4;
+ BLASLONG LDA_8x = lda*8;
+ BLASLONG LDA_12x = lda*12;
+
+ src_addr0 = A;
+ src_addr1 = A + lda;
+ src_addr2 = A + lda*2;
+ src_addr3 = A + lda*3;
+ dst_addr0 = block_A;
+ dst_addr1 = block_A + 32*8;
+
+ __m512i array512_0, array512_1, array512_2, array512_3;
+ __m512i array512_way0_0, array512_way0_1, array512_way0_2, array512_way0_3;
+ __m512i array512_way1_0, array512_way1_1, array512_way1_2, array512_way1_3;
+ __m512i array512_way2_0, array512_way2_1, array512_way2_2, array512_way2_3;
+ __m512i array512_way3_0, array512_way3_1, array512_way3_2, array512_way3_3;
+
+ __m512i M512_EPI64_2 = _mm512_set1_epi64(2);
+ __m512i permute_lo_idx = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
+ __m512i permute_hi_idx = _mm512_add_epi64(permute_lo_idx, M512_EPI64_2);
+
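+ // With M=16 each column pair occupies only 32 output elements, so dst_addr0
+ // advances by 32 per pair and jumps 32*9 after the last store to skip the
+ // pairs 8-15 region (columns 16-31) that is written through dst_addr1.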
+ for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
+ // Load and preprocess 1st 4 rows
+ array512_way0_0 = _mm512_loadu_si512(src_addr0+idx_k);
+ array512_way0_1 = _mm512_loadu_si512(src_addr1+idx_k);
+ array512_way0_2 = _mm512_loadu_si512(src_addr2+idx_k);
+ array512_way0_3 = _mm512_loadu_si512(src_addr3+idx_k);
+ array512_0 = _mm512_unpacklo_epi32(array512_way0_0, array512_way0_1);
+ array512_1 = _mm512_unpackhi_epi32(array512_way0_0, array512_way0_1);
+ array512_2 = _mm512_unpacklo_epi32(array512_way0_2, array512_way0_3);
+ array512_3 = _mm512_unpackhi_epi32(array512_way0_2, array512_way0_3);
+ array512_way0_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512_way0_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512_way0_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512_way0_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+
+ // Load and preprocess 2nd 4 rows
+ array512_way1_0 = _mm512_loadu_si512(src_addr0+LDA_4x+idx_k);
+ array512_way1_1 = _mm512_loadu_si512(src_addr1+LDA_4x+idx_k);
+ array512_way1_2 = _mm512_loadu_si512(src_addr2+LDA_4x+idx_k);
+ array512_way1_3 = _mm512_loadu_si512(src_addr3+LDA_4x+idx_k);
+ array512_0 = _mm512_unpacklo_epi32(array512_way1_0, array512_way1_1);
+ array512_1 = _mm512_unpackhi_epi32(array512_way1_0, array512_way1_1);
+ array512_2 = _mm512_unpacklo_epi32(array512_way1_2, array512_way1_3);
+ array512_3 = _mm512_unpackhi_epi32(array512_way1_2, array512_way1_3);
+ array512_way1_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512_way1_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512_way1_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512_way1_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+
+ // Load and preprocess 3rd 4 rows
+ array512_way2_0 = _mm512_loadu_si512(src_addr0+LDA_8x+idx_k);
+ array512_way2_1 = _mm512_loadu_si512(src_addr1+LDA_8x+idx_k);
+ array512_way2_2 = _mm512_loadu_si512(src_addr2+LDA_8x+idx_k);
+ array512_way2_3 = _mm512_loadu_si512(src_addr3+LDA_8x+idx_k);
+ array512_0 = _mm512_unpacklo_epi32(array512_way2_0, array512_way2_1);
+ array512_1 = _mm512_unpackhi_epi32(array512_way2_0, array512_way2_1);
+ array512_2 = _mm512_unpacklo_epi32(array512_way2_2, array512_way2_3);
+ array512_3 = _mm512_unpackhi_epi32(array512_way2_2, array512_way2_3);
+ array512_way2_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512_way2_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512_way2_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512_way2_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+
+ // Load and preprocess 4th 4 rows
+ array512_way3_0 = _mm512_loadu_si512(src_addr0+LDA_12x+idx_k);
+ array512_way3_1 = _mm512_loadu_si512(src_addr1+LDA_12x+idx_k);
+ array512_way3_2 = _mm512_loadu_si512(src_addr2+LDA_12x+idx_k);
+ array512_way3_3 = _mm512_loadu_si512(src_addr3+LDA_12x+idx_k);
+ array512_0 = _mm512_unpacklo_epi32(array512_way3_0, array512_way3_1);
+ array512_1 = _mm512_unpackhi_epi32(array512_way3_0, array512_way3_1);
+ array512_2 = _mm512_unpacklo_epi32(array512_way3_2, array512_way3_3);
+ array512_3 = _mm512_unpackhi_epi32(array512_way3_2, array512_way3_3);
+ array512_way3_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512_way3_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512_way3_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512_way3_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+
+ // Compose and store the 0/1 and 16/17 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_lo_idx, array512_way1_0);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_lo_idx, array512_way3_0);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 32;
+ dst_addr1 += 32;
+
+ // Compose and store the 2/3 and 18/19 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_lo_idx, array512_way1_1);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_lo_idx, array512_way3_1);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 32;
+ dst_addr1 += 32;
+
+ // Compose and store the 4/5 and 20/21 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_lo_idx, array512_way1_2);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_2, permute_lo_idx, array512_way3_2);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 32;
+ dst_addr1 += 32;
+
+ // Compose and store the 6/7 and 22/23 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_lo_idx, array512_way1_3);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_lo_idx, array512_way3_3);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 32;
+ dst_addr1 += 32;
+
+ // Compose and store the 8/9 and 24/25 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_hi_idx, array512_way1_0);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_hi_idx, array512_way3_0);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 32;
+ dst_addr1 += 32;
+
+ // Compose and store the 10/11 and 26/27 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_hi_idx, array512_way1_1);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_hi_idx, array512_way3_1);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 32;
+ dst_addr1 += 32;
+
+ // Compose and store the 12/13 and 28/29 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_hi_idx, array512_way1_2);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_2, permute_hi_idx, array512_way3_2);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 32;
+ dst_addr1 += 32;
+
+ // Compose and store the 14/15 and 30/31 cols
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_hi_idx, array512_way1_3);
+ array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_hi_idx, array512_way3_3);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 32*9;
+ dst_addr1 += 32*9;
+ }
+
+ if (tag_k_32x != k) {
+ int k_rem = k - tag_k_32x;
+ unsigned int tail_mask = (((unsigned int)0xffffffff) >> (32-k_rem));
+ __m512i array512[16];
+
+ // Load and preprocess 1st 4 rows
+ array512[0] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x);
+ array512[1] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x);
+ array512[2] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x);
+ array512[3] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x);
+ array512_0 = _mm512_unpacklo_epi32(array512[0], array512[1]);
+ array512_1 = _mm512_unpackhi_epi32(array512[0], array512[1]);
+ array512_2 = _mm512_unpacklo_epi32(array512[2], array512[3]);
+ array512_3 = _mm512_unpackhi_epi32(array512[2], array512[3]);
+ array512[0] = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512[1] = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512[2] = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512[3] = _mm512_unpackhi_epi64(array512_1, array512_3);
+ src_addr0 += LDA_4x;
+ src_addr1 += LDA_4x;
+ src_addr2 += LDA_4x;
+ src_addr3 += LDA_4x;
+
+ // Load and preprocess 2nd 4 rows
+ array512[4] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x);
+ array512[5] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x);
+ array512[6] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x);
+ array512[7] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x);
+ array512_0 = _mm512_unpacklo_epi32(array512[4], array512[5]);
+ array512_1 = _mm512_unpackhi_epi32(array512[4], array512[5]);
+ array512_2 = _mm512_unpacklo_epi32(array512[6], array512[7]);
+ array512_3 = _mm512_unpackhi_epi32(array512[6], array512[7]);
+ array512[4] = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512[5] = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512[6] = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512[7] = _mm512_unpackhi_epi64(array512_1, array512_3);
+ src_addr0 += LDA_4x;
+ src_addr1 += LDA_4x;
+ src_addr2 += LDA_4x;
+ src_addr3 += LDA_4x;
+
+ // Load and preprocess 3rd 4 rows
+ array512[8] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x);
+ array512[9] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x);
+ array512[10] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x);
+ array512[11] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x);
+ array512_0 = _mm512_unpacklo_epi32(array512[8], array512[9]);
+ array512_1 = _mm512_unpackhi_epi32(array512[8], array512[9]);
+ array512_2 = _mm512_unpacklo_epi32(array512[10], array512[11]);
+ array512_3 = _mm512_unpackhi_epi32(array512[10], array512[11]);
+ array512[8] = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512[9] = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512[10] = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512[11] = _mm512_unpackhi_epi64(array512_1, array512_3);
+ src_addr0 += LDA_4x;
+ src_addr1 += LDA_4x;
+ src_addr2 += LDA_4x;
+ src_addr3 += LDA_4x;
+
+ // Load and preprocess 4th 4 rows
+ array512[12] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x);
+ array512[13] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x);
+ array512[14] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x);
+ array512[15] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x);
+ array512_0 = _mm512_unpacklo_epi32(array512[12], array512[13]);
+ array512_1 = _mm512_unpackhi_epi32(array512[12], array512[13]);
+ array512_2 = _mm512_unpacklo_epi32(array512[14], array512[15]);
+ array512_3 = _mm512_unpackhi_epi32(array512[14], array512[15]);
+ array512[12] = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512[13] = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512[14] = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512[15] = _mm512_unpackhi_epi64(array512_1, array512_3);
+
+ // Half-compose of 0/1, 16/17, 8/9, 24/25 cols
+ array512_0 = _mm512_permutex2var_epi64(array512[0], permute_lo_idx, array512[4]);
+ array512_1 = _mm512_permutex2var_epi64(array512[8], permute_lo_idx, array512[12]);
+ array512_2 = _mm512_permutex2var_epi64(array512[0], permute_hi_idx, array512[4]);
+ array512_3 = _mm512_permutex2var_epi64(array512[8], permute_hi_idx, array512[12]);
+ array512[0] = array512_0; // 1st 8 pairs of col 0/1, and 1st 8 pairs of col 16/17
+ array512[4] = array512_1; // 2nd 8 pairs of col 0/1, and 2nd 8 pairs of col 16/17
+ array512[8] = array512_2; // 1st 8 pairs of col 8/9, and 1st 8 pairs of col 24/25
+ array512[12] = array512_3; // 2nd 8 pairs of col 8/9, and 2nd 8 pairs of col 24/25
+
+ // Half-compose of 2/3, 18/19, 10/11, 26/27 cols
+ array512_0 = _mm512_permutex2var_epi64(array512[1], permute_lo_idx, array512[5]);
+ array512_1 = _mm512_permutex2var_epi64(array512[9], permute_lo_idx, array512[13]);
+ array512_2 = _mm512_permutex2var_epi64(array512[1], permute_hi_idx, array512[5]);
+ array512_3 = _mm512_permutex2var_epi64(array512[9], permute_hi_idx, array512[13]);
+ array512[1] = array512_0; // 1st 8 pairs of col 2/3, and 1st 8 pairs of col 18/19
+ array512[5] = array512_1; // 2nd 8 pairs of col 2/3, and 2nd 8 pairs of col 18/19
+ array512[9] = array512_2; // 1st 8 pairs of col 10/11, and 1st 8 pairs of col 26/27
+ array512[13] = array512_3; // 2nd 8 pairs of col 10/11, and 2nd 8 pairs of col 26/27
+
+ // Half-compose of 4/5, 20/21, 12/13, 28/29 cols
+ array512_0 = _mm512_permutex2var_epi64(array512[2], permute_lo_idx, array512[6]);
+ array512_1 = _mm512_permutex2var_epi64(array512[10], permute_lo_idx, array512[14]);
+ array512_2 = _mm512_permutex2var_epi64(array512[2], permute_hi_idx, array512[6]);
+ array512_3 = _mm512_permutex2var_epi64(array512[10], permute_hi_idx, array512[14]);
+ array512[2] = array512_0; // 1st 8 pairs of col 4/5, and 1st 8 pairs of col 20/21
+ array512[6] = array512_1; // 2nd 8 pairs of col 4/5, and 2nd 8 pairs of col 20/21
+ array512[10] = array512_2; // 1st 8 pairs of col 12/13, and 1st 8 pairs of col 28/29
+ array512[14] = array512_3; // 2nd 8 pairs of col 12/13, and 2nd 8 pairs of col 28/29
+
+ // Half-compose of 6/7, 22/23, 14/15, 30/31 cols
+ array512_0 = _mm512_permutex2var_epi64(array512[3], permute_lo_idx, array512[7]);
+ array512_1 = _mm512_permutex2var_epi64(array512[11], permute_lo_idx, array512[15]);
+ array512_2 = _mm512_permutex2var_epi64(array512[3], permute_hi_idx, array512[7]);
+ array512_3 = _mm512_permutex2var_epi64(array512[11], permute_hi_idx, array512[15]);
+ array512[3] = array512_0; // 1st 8 pairs of col 6/7, and 1st 8 pairs of col 22/23
+ array512[7] = array512_1; // 2nd 8 pairs of col 6/7, and 2nd 8 pairs of col 22/23
+ array512[11] = array512_2; // 1st 8 pairs of col 14/15, and 1st 8 pairs of col 30/31
+ array512[15] = array512_3; // 2nd 8 pairs of col 14/15, and 2nd 8 pairs of col 30/31
+
+ // Compose and store the 0/1 cols
+ array512_0 = _mm512_inserti64x4(array512[0], _mm512_castsi512_si256(array512[4]), 0x1);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 32;
+
+ // Compose and store the 2/3 cols
+ array512_0 = _mm512_inserti64x4(array512[1], _mm512_castsi512_si256(array512[5]), 0x1);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 32;
+
+ // Compose and store the 4/5 cols
+ array512_0 = _mm512_inserti64x4(array512[2], _mm512_castsi512_si256(array512[6]), 0x1);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 32;
+
+ // Compose and store the 6/7 cols
+ array512_0 = _mm512_inserti64x4(array512[3], _mm512_castsi512_si256(array512[7]), 0x1);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 32;
+
+ // Compose and store the 8/9 cols
+ array512_0 = _mm512_inserti64x4(array512[8], _mm512_castsi512_si256(array512[12]), 0x1);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 32;
+
+ // Compose and store the 10/11 cols
+ array512_0 = _mm512_inserti64x4(array512[9], _mm512_castsi512_si256(array512[13]), 0x1);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 32;
+
+ // Compose and store the 12/13 cols
+ array512_0 = _mm512_inserti64x4(array512[10], _mm512_castsi512_si256(array512[14]), 0x1);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 32;
+
+ // Compose and store the 14/15 cols
+ array512_0 = _mm512_inserti64x4(array512[11], _mm512_castsi512_si256(array512[15]), 0x1);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 32;
+
+ // Compose and store 16 ~ k_rem cols
+ int idx_length = (k_rem + 1 - 16) >> 1;
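+        // idx_length: number of remaining 2-column groups past column 16, i.e. ceil((k_rem - 16) / 2)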
+ if (idx_length > 4) {
+ for (int idx_k = 0; idx_k < 4; idx_k++) {
+ array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 32;
+ }
+
+ for (int idx_k = 4; idx_k < idx_length; idx_k++) {
+ array512_0 = _mm512_inserti64x4(array512[idx_k+8], _mm512_extracti64x4_epi64(array512[idx_k+4], 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 32;
+ }
+ } else {
+ for (int idx_k = 0; idx_k < idx_length; idx_k++) {
+ array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 32;
+ }
+ }
+ }
+}
+
+// ITCOPY Kernel, M<=16, K can be any number but is processed in chunks of 32
+void COL_MAJOR_ITCOPY_KERNEL_Kx16m(BLASLONG m, BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
+{
+ bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3;
+ bfloat16 * dst_addr0, * dst_addr1;
+
+ BLASLONG tag_k_32x = k & (~31);
+
+ src_addr0 = A;
+ dst_addr0 = block_A;
+ dst_addr1 = block_A + 32*8;
+
+ __m512i array512_0, array512_1, array512_2, array512_3;
+ __m512i array512[16];
+
+ __m512i M512_EPI64_2 = _mm512_set1_epi64(2);
+ __m512i permute_lo_idx = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
+ __m512i permute_hi_idx = _mm512_add_epi64(permute_lo_idx, M512_EPI64_2);
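+    // permute_lo_idx interleaves the even 128-bit lanes (qwords 0,1 and 4,5) of its two sources; permute_hi_idx the odd lanes (qwords 2,3 and 6,7)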
+
+ for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
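+        // Load the m valid rows and zero the remaining rows so the 16-row transpose below stays regular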
+ for (int j = 0; j < m; j++) {
+ array512[j] = _mm512_loadu_si512(src_addr0+j*lda+idx_k);
+ }
+ for (int j = m; j < 16; j++) {
+ array512[j] = _mm512_setzero_si512();
+ }
+
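+        // Transpose each group of 4 rows at 32-bit (bf16-pair) granularity within 128-bit lanes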
+ for (int j = 0; j < 4; j++) {
+ int array_idx = j*4;
+ array512_0 = _mm512_unpacklo_epi32(array512[array_idx+0], array512[array_idx+1]);
+ array512_1 = _mm512_unpackhi_epi32(array512[array_idx+0], array512[array_idx+1]);
+ array512_2 = _mm512_unpacklo_epi32(array512[array_idx+2], array512[array_idx+3]);
+ array512_3 = _mm512_unpackhi_epi32(array512[array_idx+2], array512[array_idx+3]);
+ array512[array_idx+0] = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512[array_idx+1] = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512[array_idx+2] = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512[array_idx+3] = _mm512_unpackhi_epi64(array512_1, array512_3);
+ }
+
+ // Compose and store the 0/1, 2/3, 4/5, 6/7 and 16/17, 18/19, 20/21, 22/23 cols
+ for (int j = 0; j < 4; j++) {
+ array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_lo_idx, array512[j+4]);
+ array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_lo_idx, array512[j+12]);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 32;
+ dst_addr1 += 32;
+ }
+
+ // Compose and store the 8/9, 10/11, 12/13, 14/15 and 24/25, 26/27, 28/29, 30/31 cols
+ for (int j = 0; j < 4; j++) {
+ array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_hi_idx, array512[j+4]);
+ array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_hi_idx, array512[j+12]);
+ array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+ array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_2);
+ _mm512_storeu_si512(dst_addr1, array512_3);
+ dst_addr0 += 32;
+ dst_addr1 += 32;
+ }
+
+ dst_addr0 += 32*8;
+ dst_addr1 += 32*8;
}
-#ifdef DEBUG_PROFILE
- print_block(BF16_BLOCK_THRES_K, BF16_BLOCK_THRES_M, block_A);
-#endif
+ if (tag_k_32x != k) {
+ int k_rem = k - tag_k_32x;
+ unsigned int tail_mask = (((unsigned int)0xffffffff) >> (32-k_rem));
+
+ for (int j = 0; j < m; j++) {
+ array512[j] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+j*lda+tag_k_32x);
+ }
+ for (int j = m; j < 16; j++) {
+ array512[j] = _mm512_setzero_si512();
+ }
+
+ for (int j = 0; j < 4; j++) {
+ int array_idx = j*4;
+ array512_0 = _mm512_unpacklo_epi32(array512[array_idx+0], array512[array_idx+1]);
+ array512_1 = _mm512_unpackhi_epi32(array512[array_idx+0], array512[array_idx+1]);
+ array512_2 = _mm512_unpacklo_epi32(array512[array_idx+2], array512[array_idx+3]);
+ array512_3 = _mm512_unpackhi_epi32(array512[array_idx+2], array512[array_idx+3]);
+ array512[array_idx+0] = _mm512_unpacklo_epi64(array512_0, array512_2);
+ array512[array_idx+1] = _mm512_unpackhi_epi64(array512_0, array512_2);
+ array512[array_idx+2] = _mm512_unpacklo_epi64(array512_1, array512_3);
+ array512[array_idx+3] = _mm512_unpackhi_epi64(array512_1, array512_3);
+ }
+
+ for (int j = 0; j < 4; j++) {
+ array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_lo_idx, array512[j+4]);
+ array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_lo_idx, array512[j+12]);
+ array512_2 = _mm512_permutex2var_epi64(array512[j+0], permute_hi_idx, array512[j+4]);
+ array512_3 = _mm512_permutex2var_epi64(array512[j+8], permute_hi_idx, array512[j+12]);
+ array512[j+0] = array512_0; // 1st 8 pairs of col 0/1|2/3|4/5|6/7, and 1st 8 pairs of col 16/17|18/19|20/21|22/23
+ array512[j+4] = array512_1; // 2nd 8 pairs of col 0/1|2/3|4/5|6/7, and 2nd 8 pairs of col 16/17|18/19|20/21|22/23
+ array512[j+8] = array512_2; // 1st 8 pairs of col 8/9|10/11|12/13|14/15, and 1st 8 pairs of col 24/25|26/27|28/29|30/31
+ array512[j+12] = array512_3; // 2nd 8 pairs of col 8/9|10/11|12/13|14/15, and 2nd 8 pairs of col 24/25|26/27|28/29|30/31
+ }
+
+        // Compose and store the 0/1 ~ 6/7 cols
+        for (int j = 0; j < 4; j++) {
+            array512_0 = _mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1);
+            _mm512_storeu_si512(dst_addr0, array512_0);
+            dst_addr0 += 32;
+        }
+
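+        // Compose and store the 8/9 ~ 14/15 cols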
+ for (int j = 8; j < 12; j++) {
+ array512_0 = _mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 32;
+ }
+
+ // Compose and store 16 ~ k_rem cols
+ int idx_length = (k_rem + 1 - 16) >> 1;
+ if (idx_length > 4) {
+ for (int idx_k = 0; idx_k < 4; idx_k++) {
+ array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 32;
+ }
+
+ for (int idx_k = 4; idx_k < idx_length; idx_k++) {
+ array512_0 = _mm512_inserti64x4(array512[idx_k+8], _mm512_extracti64x4_epi64(array512[idx_k+4], 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 32;
+ }
+ } else {
+ for (int idx_k = 0; idx_k < idx_length; idx_k++) {
+ array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ dst_addr0 += 32;
+ }
+ }
+ }
}
+// COL_MAJOR_ONCOPY_KERNEL_16x32 behaves exactly the same as COL_MAJOR_ITCOPY_KERNEL_Kx16
+#define COL_MAJOR_ONCOPY_KERNEL_16x32 COL_MAJOR_ITCOPY_KERNEL_Kx16
+
void COL_MAJOR_ONCOPY_KERNEL_8x32(BLASLONG k, bfloat16 * B, BLASLONG ldb, bfloat16 * block_B)
{
BLASLONG tag_k_32x = k & (~31);
- BLASLONG idx_src_base0, idx_src_base1, idx_src_base2, idx_src_base3, idx_src_base4, idx_src_base5, idx_src_base6, idx_src_base7;
- BLASLONG idx_target_base0;
- idx_src_base0 = 0;
- idx_src_base1 = 1*ldb;
- idx_src_base2 = 2*ldb;
- idx_src_base3 = 3*ldb;
- idx_src_base4 = 4*ldb;
- idx_src_base5 = 5*ldb;
- idx_src_base6 = 6*ldb;
- idx_src_base7 = 7*ldb;
- idx_target_base0 = 0;
+ bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3, * src_addr4, * src_addr5, * src_addr6, * src_addr7;
+ bfloat16 * dst_addr0;
+
+ unsigned char blend_mask = (((unsigned char)0xcc));
+ __m512i permute_idx = _mm512_set_epi64(13, 12, 7, 6, 9, 8, 3, 2);
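+    // blend_mask 0xcc takes qwords 2,3 and 6,7 from the way1 operand; permute_idx gathers the complementary qwords (way0's 2,3/6,7 with way1's 0,1/4,5)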
+
+ src_addr0 = B;
+ src_addr1 = src_addr0 + 1*ldb;
+ src_addr2 = src_addr0 + 2*ldb;
+ src_addr3 = src_addr0 + 3*ldb;
+ src_addr4 = src_addr0 + 4*ldb;
+ src_addr5 = src_addr0 + 5*ldb;
+ src_addr6 = src_addr0 + 6*ldb;
+ src_addr7 = src_addr0 + 7*ldb;
+ dst_addr0 = block_B;
+
+ __m512i array512_0, array512_1, array512_2, array512_3;
+ __m512i array512_way0_0, array512_way0_1, array512_way0_2, array512_way0_3;
+ __m512i array512_way1_0, array512_way1_1, array512_way1_2, array512_way1_3;
for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
- _mm512_storeu_si512(&block_B[idx_target_base0+ 32*0], _mm512_loadu_si512(&B[idx_src_base0+idx_k]));
- _mm512_storeu_si512(&block_B[idx_target_base0+ 32*1], _mm512_loadu_si512(&B[idx_src_base1+idx_k]));
- _mm512_storeu_si512(&block_B[idx_target_base0+ 32*2], _mm512_loadu_si512(&B[idx_src_base2+idx_k]));
- _mm512_storeu_si512(&block_B[idx_target_base0+ 32*3], _mm512_loadu_si512(&B[idx_src_base3+idx_k]));
- _mm512_storeu_si512(&block_B[idx_target_base0+ 32*4], _mm512_loadu_si512(&B[idx_src_base4+idx_k]));
- _mm512_storeu_si512(&block_B[idx_target_base0+ 32*5], _mm512_loadu_si512(&B[idx_src_base5+idx_k]));
- _mm512_storeu_si512(&block_B[idx_target_base0+ 32*6], _mm512_loadu_si512(&B[idx_src_base6+idx_k]));
- _mm512_storeu_si512(&block_B[idx_target_base0+ 32*7], _mm512_loadu_si512(&B[idx_src_base7+idx_k]));
- idx_target_base0 += 32*8;
+ array512_0 = _mm512_loadu_si512(src_addr0+idx_k);
+ array512_1 = _mm512_loadu_si512(src_addr1+idx_k);
+ array512_2 = _mm512_loadu_si512(src_addr2+idx_k);
+ array512_3 = _mm512_loadu_si512(src_addr3+idx_k);
+
+ array512_way0_0 = _mm512_unpacklo_epi32(array512_0, array512_1);
+ array512_way0_1 = _mm512_unpackhi_epi32(array512_0, array512_1);
+ array512_way0_2 = _mm512_unpacklo_epi32(array512_2, array512_3);
+ array512_way0_3 = _mm512_unpackhi_epi32(array512_2, array512_3);
+
+ array512_0 = _mm512_unpacklo_epi64(array512_way0_0, array512_way0_2);
+ array512_1 = _mm512_unpackhi_epi64(array512_way0_0, array512_way0_2);
+ array512_2 = _mm512_unpacklo_epi64(array512_way0_1, array512_way0_3);
+ array512_3 = _mm512_unpackhi_epi64(array512_way0_1, array512_way0_3);
+
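+        // imm 0x88 gathers the even 128-bit lanes of the two inputs, 0xdd the odd lanes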
+ array512_way0_0 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x88);
+ array512_way0_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0xdd);
+ array512_way0_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x88);
+ array512_way0_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0xdd);
+
+ array512_0 = _mm512_loadu_si512(src_addr4+idx_k);
+ array512_1 = _mm512_loadu_si512(src_addr5+idx_k);
+ array512_2 = _mm512_loadu_si512(src_addr6+idx_k);
+ array512_3 = _mm512_loadu_si512(src_addr7+idx_k);
+
+ array512_way1_0 = _mm512_unpacklo_epi32(array512_0, array512_1);
+ array512_way1_1 = _mm512_unpackhi_epi32(array512_0, array512_1);
+ array512_way1_2 = _mm512_unpacklo_epi32(array512_2, array512_3);
+ array512_way1_3 = _mm512_unpackhi_epi32(array512_2, array512_3);
+
+ array512_0 = _mm512_unpacklo_epi64(array512_way1_0, array512_way1_2);
+ array512_1 = _mm512_unpackhi_epi64(array512_way1_0, array512_way1_2);
+ array512_2 = _mm512_unpacklo_epi64(array512_way1_1, array512_way1_3);
+ array512_3 = _mm512_unpackhi_epi64(array512_way1_1, array512_way1_3);
+
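+        // 0x22/0x77 pick the same lanes but swapped within each pair, so the blend/permute below can merge way0 and way1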
+ array512_way1_0 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x22);
+ array512_way1_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x77);
+ array512_way1_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x22);
+ array512_way1_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x77);
+
+ array512_0 = _mm512_mask_blend_epi64(blend_mask, array512_way0_0, array512_way1_0);
+ array512_1 = _mm512_mask_blend_epi64(blend_mask, array512_way0_1, array512_way1_1);
+ array512_2 = _mm512_mask_blend_epi64(blend_mask, array512_way0_2, array512_way1_2);
+ array512_3 = _mm512_mask_blend_epi64(blend_mask, array512_way0_3, array512_way1_3);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ _mm512_storeu_si512(dst_addr0+32, array512_1);
+ _mm512_storeu_si512(dst_addr0+64, array512_2);
+ _mm512_storeu_si512(dst_addr0+96, array512_3);
+
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_idx, array512_way1_0);
+ array512_1 = _mm512_permutex2var_epi64(array512_way0_1, permute_idx, array512_way1_1);
+ array512_2 = _mm512_permutex2var_epi64(array512_way0_2, permute_idx, array512_way1_2);
+ array512_3 = _mm512_permutex2var_epi64(array512_way0_3, permute_idx, array512_way1_3);
+ _mm512_storeu_si512(dst_addr0+128, array512_0);
+ _mm512_storeu_si512(dst_addr0+160, array512_1);
+ _mm512_storeu_si512(dst_addr0+192, array512_2);
+ _mm512_storeu_si512(dst_addr0+224, array512_3);
+
+ dst_addr0 += 256;
}
if (tag_k_32x != k) {
unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(k-tag_k_32x)));
__mmask32 tail_mask = *((__mmask32*) &tail_mask_value);
- _mm512_storeu_si512(&block_B[idx_target_base0+ 32*0], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base0+tag_k_32x]));
- _mm512_storeu_si512(&block_B[idx_target_base0+ 32*1], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base1+tag_k_32x]));
- _mm512_storeu_si512(&block_B[idx_target_base0+ 32*2], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base2+tag_k_32x]));
- _mm512_storeu_si512(&block_B[idx_target_base0+ 32*3], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base3+tag_k_32x]));
- _mm512_storeu_si512(&block_B[idx_target_base0+ 32*4], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base4+tag_k_32x]));
- _mm512_storeu_si512(&block_B[idx_target_base0+ 32*5], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base5+tag_k_32x]));
- _mm512_storeu_si512(&block_B[idx_target_base0+ 32*6], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base6+tag_k_32x]));
- _mm512_storeu_si512(&block_B[idx_target_base0+ 32*7], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base7+tag_k_32x]));
+ array512_0 = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x);
+ array512_1 = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x);
+ array512_2 = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x);
+ array512_3 = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x);
+
+ array512_way0_0 = _mm512_unpacklo_epi32(array512_0, array512_1);
+ array512_way0_1 = _mm512_unpackhi_epi32(array512_0, array512_1);
+ array512_way0_2 = _mm512_unpacklo_epi32(array512_2, array512_3);
+ array512_way0_3 = _mm512_unpackhi_epi32(array512_2, array512_3);
+
+ array512_0 = _mm512_unpacklo_epi64(array512_way0_0, array512_way0_2);
+ array512_1 = _mm512_unpackhi_epi64(array512_way0_0, array512_way0_2);
+ array512_2 = _mm512_unpacklo_epi64(array512_way0_1, array512_way0_3);
+ array512_3 = _mm512_unpackhi_epi64(array512_way0_1, array512_way0_3);
+
+ array512_way0_0 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x88);
+ array512_way0_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0xdd);
+ array512_way0_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x88);
+ array512_way0_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0xdd);
+
+ array512_0 = _mm512_maskz_loadu_epi16(tail_mask, src_addr4+tag_k_32x);
+ array512_1 = _mm512_maskz_loadu_epi16(tail_mask, src_addr5+tag_k_32x);
+ array512_2 = _mm512_maskz_loadu_epi16(tail_mask, src_addr6+tag_k_32x);
+ array512_3 = _mm512_maskz_loadu_epi16(tail_mask, src_addr7+tag_k_32x);
+
+ array512_way1_0 = _mm512_unpacklo_epi32(array512_0, array512_1);
+ array512_way1_1 = _mm512_unpackhi_epi32(array512_0, array512_1);
+ array512_way1_2 = _mm512_unpacklo_epi32(array512_2, array512_3);
+ array512_way1_3 = _mm512_unpackhi_epi32(array512_2, array512_3);
+
+ array512_0 = _mm512_unpacklo_epi64(array512_way1_0, array512_way1_2);
+ array512_1 = _mm512_unpackhi_epi64(array512_way1_0, array512_way1_2);
+ array512_2 = _mm512_unpacklo_epi64(array512_way1_1, array512_way1_3);
+ array512_3 = _mm512_unpackhi_epi64(array512_way1_1, array512_way1_3);
+
+ array512_way1_0 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x22);
+ array512_way1_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x77);
+ array512_way1_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x22);
+ array512_way1_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x77);
+
+ array512_0 = _mm512_mask_blend_epi64(blend_mask, array512_way0_0, array512_way1_0);
+ array512_1 = _mm512_mask_blend_epi64(blend_mask, array512_way0_1, array512_way1_1);
+ array512_2 = _mm512_mask_blend_epi64(blend_mask, array512_way0_2, array512_way1_2);
+ array512_3 = _mm512_mask_blend_epi64(blend_mask, array512_way0_3, array512_way1_3);
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ _mm512_storeu_si512(dst_addr0+32, array512_1);
+ _mm512_storeu_si512(dst_addr0+64, array512_2);
+ _mm512_storeu_si512(dst_addr0+96, array512_3);
+
+ array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_idx, array512_way1_0);
+ array512_1 = _mm512_permutex2var_epi64(array512_way0_1, permute_idx, array512_way1_1);
+ array512_2 = _mm512_permutex2var_epi64(array512_way0_2, permute_idx, array512_way1_2);
+ array512_3 = _mm512_permutex2var_epi64(array512_way0_3, permute_idx, array512_way1_3);
+ _mm512_storeu_si512(dst_addr0+128, array512_0);
+ _mm512_storeu_si512(dst_addr0+160, array512_1);
+ _mm512_storeu_si512(dst_addr0+192, array512_2);
+ _mm512_storeu_si512(dst_addr0+224, array512_3);
+ }
+}
+
+void COL_MAJOR_ONCOPY_KERNEL_4x32(BLASLONG k, bfloat16 * B, BLASLONG ldb, bfloat16 * block_B)
+{
+ BLASLONG tag_k_32x = k & (~31);
+
+ bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3;
+ bfloat16 * dst_addr0;
+
+ src_addr0 = B;
+ src_addr1 = src_addr0 + 1*ldb;
+ src_addr2 = src_addr0 + 2*ldb;
+ src_addr3 = src_addr0 + 3*ldb;
+ dst_addr0 = block_B;
+
+ __m512i array512_0, array512_1, array512_2, array512_3;
+ __m512i array512_way0_0, array512_way0_1, array512_way0_2, array512_way0_3;
+
+ for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
+ array512_0 = _mm512_loadu_si512(src_addr0+idx_k);
+ array512_1 = _mm512_loadu_si512(src_addr1+idx_k);
+ array512_2 = _mm512_loadu_si512(src_addr2+idx_k);
+ array512_3 = _mm512_loadu_si512(src_addr3+idx_k);
+
+ array512_way0_0 = _mm512_unpacklo_epi32(array512_0, array512_1);
+ array512_way0_1 = _mm512_unpackhi_epi32(array512_0, array512_1);
+ array512_way0_2 = _mm512_unpacklo_epi32(array512_2, array512_3);
+ array512_way0_3 = _mm512_unpackhi_epi32(array512_2, array512_3);
+
+ array512_0 = _mm512_unpacklo_epi64(array512_way0_0, array512_way0_2);
+ array512_1 = _mm512_unpackhi_epi64(array512_way0_0, array512_way0_2);
+ array512_2 = _mm512_unpacklo_epi64(array512_way0_1, array512_way0_3);
+ array512_3 = _mm512_unpackhi_epi64(array512_way0_1, array512_way0_3);
+
+ array512_way0_0 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x88);
+ array512_way0_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0xdd);
+ array512_way0_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x88);
+ array512_way0_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0xdd);
+
+ array512_0 = _mm512_shuffle_i32x4(array512_way0_0, array512_way0_1, 0x88);
+ array512_1 = _mm512_shuffle_i32x4(array512_way0_2, array512_way0_3, 0x88);
+ array512_2 = _mm512_shuffle_i32x4(array512_way0_0, array512_way0_1, 0xdd);
+ array512_3 = _mm512_shuffle_i32x4(array512_way0_2, array512_way0_3, 0xdd);
+
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ _mm512_storeu_si512(dst_addr0+32, array512_1);
+ _mm512_storeu_si512(dst_addr0+64, array512_2);
+ _mm512_storeu_si512(dst_addr0+96, array512_3);
+
+ dst_addr0 += 128;
}
-#ifdef DEBUG_PROFILE
- print_block(BF16_BLOCK_THRES_N, BF16_BLOCK_THRES_K, block_B);
-#endif
+ if (tag_k_32x != k) {
+ unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(k-tag_k_32x)));
+ __mmask32 tail_mask = *((__mmask32*) &tail_mask_value);
+ array512_0 = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x);
+ array512_1 = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x);
+ array512_2 = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x);
+ array512_3 = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x);
+
+ array512_way0_0 = _mm512_unpacklo_epi32(array512_0, array512_1);
+ array512_way0_1 = _mm512_unpackhi_epi32(array512_0, array512_1);
+ array512_way0_2 = _mm512_unpacklo_epi32(array512_2, array512_3);
+ array512_way0_3 = _mm512_unpackhi_epi32(array512_2, array512_3);
+
+ array512_0 = _mm512_unpacklo_epi64(array512_way0_0, array512_way0_2);
+ array512_1 = _mm512_unpackhi_epi64(array512_way0_0, array512_way0_2);
+ array512_2 = _mm512_unpacklo_epi64(array512_way0_1, array512_way0_3);
+ array512_3 = _mm512_unpackhi_epi64(array512_way0_1, array512_way0_3);
+
+ array512_way0_0 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x88);
+ array512_way0_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0xdd);
+ array512_way0_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x88);
+ array512_way0_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0xdd);
+
+ array512_0 = _mm512_shuffle_i32x4(array512_way0_0, array512_way0_1, 0x88);
+ array512_1 = _mm512_shuffle_i32x4(array512_way0_2, array512_way0_3, 0x88);
+ array512_2 = _mm512_shuffle_i32x4(array512_way0_0, array512_way0_1, 0xdd);
+ array512_3 = _mm512_shuffle_i32x4(array512_way0_2, array512_way0_3, 0xdd);
+
+ _mm512_storeu_si512(dst_addr0, array512_0);
+ _mm512_storeu_si512(dst_addr0+32, array512_1);
+ _mm512_storeu_si512(dst_addr0+64, array512_2);
+ _mm512_storeu_si512(dst_addr0+96, array512_3);
+ }
}
void COL_MAJOR_ONCOPY_KERNEL_Nx32(BLASLONG n, BLASLONG k, bfloat16 * B, BLASLONG ldb, bfloat16 * block_B)
{
BLASLONG tag_k_32x = k & (~31);
BLASLONG tag_n_2x = n & (~1);
- BLASLONG idx_src_base0;
- BLASLONG idx_target_base0;
+
+ bfloat16 * src_addr0;
+ bfloat16 * dst_addr0;
BLASLONG LDB_2x = 2*ldb;
- idx_target_base0 = 0;
+ src_addr0 = B;
+ dst_addr0 = block_B;
for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
- idx_src_base0 = 0;
+ src_addr0 = B;
for (BLASLONG idx_n = 0; idx_n < tag_n_2x; idx_n += 2) {
- _mm512_storeu_si512(&block_B[idx_target_base0+ 32*0], _mm512_loadu_si512(&B[idx_src_base0 + idx_k]));
- _mm512_storeu_si512(&block_B[idx_target_base0+ 32*1], _mm512_loadu_si512(&B[idx_src_base0 + ldb + idx_k]));
- idx_src_base0 += LDB_2x;
- idx_target_base0 += 64;
+ _mm512_storeu_si512(dst_addr0, _mm512_loadu_si512(src_addr0 + idx_k));
+ _mm512_storeu_si512(dst_addr0 + 32, _mm512_loadu_si512(src_addr0 + ldb + idx_k));
+ src_addr0 += LDB_2x;
+ dst_addr0 += 64;
}
if (tag_n_2x != n) {
- _mm512_storeu_si512(&block_B[idx_target_base0], _mm512_loadu_si512(&B[idx_src_base0 + idx_k]));
- idx_target_base0 += 32;
+ _mm512_storeu_si512(dst_addr0, _mm512_loadu_si512(src_addr0 + idx_k));
+ dst_addr0 += 32;
}
}
if (tag_k_32x != k) {
unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(k-tag_k_32x)));
__mmask32 tail_mask = *((__mmask32*) &tail_mask_value);
- idx_src_base0 = 0;
+ src_addr0 = B;
for (BLASLONG idx_n = 0; idx_n < tag_n_2x; idx_n += 2) {
- _mm512_storeu_si512(&block_B[idx_target_base0+ 32*0], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base0 + tag_k_32x]));
- _mm512_storeu_si512(&block_B[idx_target_base0+ 32*1], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base0 + ldb + tag_k_32x]));
- idx_src_base0 += LDB_2x;
- idx_target_base0 += 64;
+ _mm512_storeu_si512(dst_addr0, _mm512_maskz_loadu_epi16(tail_mask, src_addr0 + tag_k_32x));
+ _mm512_storeu_si512(dst_addr0 + 32, _mm512_maskz_loadu_epi16(tail_mask, src_addr0 + ldb + tag_k_32x));
+ src_addr0 += LDB_2x;
+ dst_addr0 += 64;
}
if (tag_n_2x != n) {
- _mm512_storeu_si512(&block_B[idx_target_base0], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base0 + tag_k_32x]));
+ _mm512_storeu_si512(dst_addr0, _mm512_maskz_loadu_epi16(tail_mask, src_addr0 + tag_k_32x));
}
}
+}
+
+void COL_MAJOR_OTCOPY_KERNEL_Kx8(BLASLONG k, bfloat16 * B, BLASLONG ldb, bfloat16 * block_B)
+{
+ BLASLONG tag_k_2x = k & (~1);
+    unsigned char tail_mask_value = (unsigned char) 0xff;
+    __mmask8 tail_mask = *((__mmask8*) &tail_mask_value);    // full 8-element mask: this kernel always copies 8 columns of B
-#ifdef DEBUG_PROFILE
- print_block(BF16_BLOCK_THRES_N, BF16_BLOCK_THRES_K, block_B);
-#endif
+ __m128i array128_0, array128_1, array128_2, array128_3;
+
+ BLASLONG idx_src_base0, idx_src_base1;
+ BLASLONG idx_target_base0, idx_target_base1;
+
+ BLASLONG LDA_2x = 2*ldb;
+ BLASLONG BF16_BLOCK_T_M_2x = 2*8;
+ idx_src_base0 = 0;
+ idx_src_base1 = ldb;
+ idx_target_base0 = 0;
+ idx_target_base1 = 8;
+ for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) {
+ array128_0 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base0]);
+ array128_1 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base1]);
+ array128_2 = _mm_unpacklo_epi16(array128_0, array128_1);
+ array128_3 = _mm_unpackhi_epi16(array128_0, array128_1);
+ _mm_storeu_epi32(&block_B[idx_target_base0], array128_2);
+ _mm_storeu_epi32(&block_B[idx_target_base1], array128_3);
+
+ idx_src_base0 += LDA_2x;
+ idx_src_base1 += LDA_2x;
+ idx_target_base0 += BF16_BLOCK_T_M_2x;
+ idx_target_base1 += BF16_BLOCK_T_M_2x;
+ }
+
+ if (tag_k_2x != k) {
+ __m128i ZERO128 = _mm_setzero_si128();
+ array128_0 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base0]);
+ array128_2 = _mm_unpacklo_epi16(array128_0, ZERO128);
+ array128_3 = _mm_unpackhi_epi16(array128_0, ZERO128);
+ _mm_storeu_epi32(&block_B[idx_target_base0], array128_2);
+ _mm_storeu_epi32(&block_B[idx_target_base1], array128_3);
+ }
+}
+
+void COL_MAJOR_OTCOPY_KERNEL_Kx8m(BLASLONG k, BLASLONG n, bfloat16 * B, BLASLONG ldb, bfloat16 * block_B)
+{
+ BLASLONG tag_k_2x = k & (~1);
+ unsigned char tail_mask = (((unsigned char)0xff) >> (8-n));
+
+ __m128i array128_0, array128_1, array128_2, array128_3;
+
+ BLASLONG idx_src_base0, idx_src_base1;
+ BLASLONG idx_target_base0, idx_target_base1;
+
+ BLASLONG LDA_2x = 2*ldb;
+ BLASLONG BF16_BLOCK_T_M_2x = 2*8;
+ idx_src_base0 = 0;
+ idx_src_base1 = ldb;
+ idx_target_base0 = 0;
+ idx_target_base1 = 8;
+ for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) {
+ array128_0 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base0]);
+ array128_1 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base1]);
+ array128_2 = _mm_unpacklo_epi16(array128_0, array128_1);
+ array128_3 = _mm_unpackhi_epi16(array128_0, array128_1);
+ _mm_storeu_epi32(&block_B[idx_target_base0], array128_2);
+ _mm_storeu_epi32(&block_B[idx_target_base1], array128_3);
+
+ idx_src_base0 += LDA_2x;
+ idx_src_base1 += LDA_2x;
+ idx_target_base0 += BF16_BLOCK_T_M_2x;
+ idx_target_base1 += BF16_BLOCK_T_M_2x;
+ }
+
+ if (tag_k_2x != k) {
+ __m128i ZERO128 = _mm_setzero_si128();
+ array128_0 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base0]);
+ array128_2 = _mm_unpacklo_epi16(array128_0, ZERO128);
+ array128_3 = _mm_unpackhi_epi16(array128_0, ZERO128);
+ _mm_storeu_epi32(&block_B[idx_target_base0], array128_2);
+ _mm_storeu_epi32(&block_B[idx_target_base1], array128_3);
+ }
}
-// Scale matrix C while beta is not ZERO or ONE
+// Scale matrix C when beta is not ZERO or ONE
void sbgemm_scal_operation(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc)
{
- BLASLONG tag_n_Nx = N & (~3);
- BLASLONG tag_n_Mx = M & (~15);
+ float * C_addr0 = C;
+ float * C_addr1 = C + ldc;
+ float * C_addr2 = C + ldc*2;
+ float * C_addr3 = C + ldc*3;
BLASLONG LDC4x = ldc*4;
- BLASLONG idx_base_0 = 0;
- BLASLONG idx_base_1 = ldc;
- BLASLONG idx_base_2 = ldc*2;
- BLASLONG idx_base_3 = ldc*3;
-
- unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-M+tag_n_Mx));
- __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
__m512 array_512_0, array_512_1, array_512_2, array_512_3;
+ __m512 BETAVECTOR = _mm512_set1_ps(beta);
- __m512 BETAVECTOR = _mm512_set1_ps(beta);
+ if (Order == CblasRowMajor) {
+ blasint tmp = M;
+ M = N;
+ N = tmp;
+ }
- if (Order == CblasColMajor) {
- for (BLASLONG idx_n = 0; idx_n < tag_n_Nx; idx_n += 4) {
- for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) {
- array_512_0 = _mm512_loadu_ps(&C[idx_base_0+idx_m]);
- array_512_1 = _mm512_loadu_ps(&C[idx_base_1+idx_m]);
- array_512_2 = _mm512_loadu_ps(&C[idx_base_2+idx_m]);
- array_512_3 = _mm512_loadu_ps(&C[idx_base_3+idx_m]);
+ BLASLONG tag_n_Nx = N & (~3);
+ BLASLONG tag_n_Mx = M & (~15);
+ unsigned short tail_mask = (((unsigned short)0xffff) >> (16-M+tag_n_Mx));
+ for (BLASLONG idx_n = 0; idx_n < tag_n_Nx; idx_n += 4) {
+ for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) {
+ array_512_0 = _mm512_loadu_ps(C_addr0 + idx_m);
+ array_512_1 = _mm512_loadu_ps(C_addr1 + idx_m);
+ array_512_2 = _mm512_loadu_ps(C_addr2 + idx_m);
+ array_512_3 = _mm512_loadu_ps(C_addr3 + idx_m);
+
+ array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0);
+ array_512_1 = _mm512_mul_ps(BETAVECTOR, array_512_1);
+ array_512_2 = _mm512_mul_ps(BETAVECTOR, array_512_2);
+ array_512_3 = _mm512_mul_ps(BETAVECTOR, array_512_3);
+
+ _mm512_storeu_ps(C_addr0 + idx_m, array_512_0);
+ _mm512_storeu_ps(C_addr1 + idx_m, array_512_1);
+ _mm512_storeu_ps(C_addr2 + idx_m, array_512_2);
+ _mm512_storeu_ps(C_addr3 + idx_m, array_512_3);
+ }
- array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0);
- array_512_1 = _mm512_mul_ps(BETAVECTOR, array_512_1);
- array_512_2 = _mm512_mul_ps(BETAVECTOR, array_512_2);
- array_512_3 = _mm512_mul_ps(BETAVECTOR, array_512_3);
-
- _mm512_storeu_ps(&C[idx_base_0+idx_m], array_512_0);
- _mm512_storeu_ps(&C[idx_base_1+idx_m], array_512_1);
- _mm512_storeu_ps(&C[idx_base_2+idx_m], array_512_2);
- _mm512_storeu_ps(&C[idx_base_3+idx_m], array_512_3);
- }
+ if (tag_n_Mx != M) {
+ array_512_0 = _mm512_maskz_loadu_ps(tail_mask, C_addr0 + tag_n_Mx);
+ array_512_1 = _mm512_maskz_loadu_ps(tail_mask, C_addr1 + tag_n_Mx);
+ array_512_2 = _mm512_maskz_loadu_ps(tail_mask, C_addr2 + tag_n_Mx);
+ array_512_3 = _mm512_maskz_loadu_ps(tail_mask, C_addr3 + tag_n_Mx);
+
+ array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0);
+ array_512_1 = _mm512_mul_ps(BETAVECTOR, array_512_1);
+ array_512_2 = _mm512_mul_ps(BETAVECTOR, array_512_2);
+ array_512_3 = _mm512_mul_ps(BETAVECTOR, array_512_3);
+
+ _mm512_mask_storeu_ps(C_addr0 + tag_n_Mx, tail_mask, array_512_0);
+ _mm512_mask_storeu_ps(C_addr1 + tag_n_Mx, tail_mask, array_512_1);
+ _mm512_mask_storeu_ps(C_addr2 + tag_n_Mx, tail_mask, array_512_2);
+ _mm512_mask_storeu_ps(C_addr3 + tag_n_Mx, tail_mask, array_512_3);
+ }
- if (tag_n_Mx != M) {
- array_512_0 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_0+tag_n_Mx]);
- array_512_1 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_1+tag_n_Mx]);
- array_512_2 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_2+tag_n_Mx]);
- array_512_3 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_3+tag_n_Mx]);
+ C_addr0 += LDC4x;
+ C_addr1 += LDC4x;
+ C_addr2 += LDC4x;
+ C_addr3 += LDC4x;
+ }
+ if (tag_n_Nx != N) {
+ for (BLASLONG idx_n = tag_n_Nx; idx_n < N; idx_n++) {
+ for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) {
+ array_512_0 = _mm512_loadu_ps(C_addr0 + idx_m);
array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0);
- array_512_1 = _mm512_mul_ps(BETAVECTOR, array_512_1);
- array_512_2 = _mm512_mul_ps(BETAVECTOR, array_512_2);
- array_512_3 = _mm512_mul_ps(BETAVECTOR, array_512_3);
-
- _mm512_mask_storeu_ps(&C[idx_base_0+tag_n_Mx], tail_mask, array_512_0);
- _mm512_mask_storeu_ps(&C[idx_base_1+tag_n_Mx], tail_mask, array_512_1);
- _mm512_mask_storeu_ps(&C[idx_base_2+tag_n_Mx], tail_mask, array_512_2);
- _mm512_mask_storeu_ps(&C[idx_base_3+tag_n_Mx], tail_mask, array_512_3);
+ _mm512_storeu_ps(C_addr0 + idx_m, array_512_0);
}
- idx_base_0 += LDC4x;
- idx_base_1 += LDC4x;
- idx_base_2 += LDC4x;
- idx_base_3 += LDC4x;
- }
-
- if (tag_n_Nx != N) {
- for (BLASLONG idx_n = tag_n_Nx; idx_n < N; idx_n++) {
- for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) {
- array_512_0 = _mm512_loadu_ps(&C[idx_base_0+idx_m]);
- array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0);
- _mm512_storeu_ps(&C[idx_base_0+idx_m], array_512_0);
- }
-
- if (tag_n_Mx != M) {
- array_512_0 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_0+tag_n_Mx]);
- array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0);
- _mm512_mask_storeu_ps(&C[idx_base_0+tag_n_Mx], tail_mask, array_512_0);
- }
- idx_base_0 += ldc;
+ if (tag_n_Mx != M) {
+ array_512_0 = _mm512_maskz_loadu_ps(tail_mask, C_addr0 + tag_n_Mx);
+ array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0);
+ _mm512_mask_storeu_ps(C_addr0 + tag_n_Mx, tail_mask, array_512_0);
}
+ C_addr0 += ldc;
}
- } else {
-
}
}
-// Scale matrix C while beta is not ZERO or ONE
+// Zero matrix C when beta is ZERO
void sbgemm_zero_operation(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, float *C, OPENBLAS_CONST blasint ldc)
{
- BLASLONG tag_n_Nx = N & (~3);
- BLASLONG tag_n_Mx = M & (~15);
+ float * C_addr0 = C;
+ float * C_addr1 = C + ldc;
+ float * C_addr2 = C + ldc*2;
+ float * C_addr3 = C + ldc*3;
BLASLONG LDC4x = ldc*4;
- BLASLONG idx_base_0 = 0;
- BLASLONG idx_base_1 = ldc;
- BLASLONG idx_base_2 = ldc*2;
- BLASLONG idx_base_3 = ldc*3;
-
- unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-M+tag_n_Mx));
- __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
__m512 ZEROVECTOR = _mm512_setzero_ps();
- if (Order == CblasColMajor) {
- for (BLASLONG idx_n = 0; idx_n < tag_n_Nx; idx_n += 4) {
- for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) {
- _mm512_storeu_ps(&C[idx_base_0+idx_m], ZEROVECTOR);
- _mm512_storeu_ps(&C[idx_base_1+idx_m], ZEROVECTOR);
- _mm512_storeu_ps(&C[idx_base_2+idx_m], ZEROVECTOR);
- _mm512_storeu_ps(&C[idx_base_3+idx_m], ZEROVECTOR);
- }
+ if (Order == CblasRowMajor) {
+ blasint tmp = M;
+ M = N;
+ N = tmp;
+ }
- if (tag_n_Mx != M) {
- _mm512_mask_storeu_ps(&C[idx_base_0+tag_n_Mx], tail_mask, ZEROVECTOR);
- _mm512_mask_storeu_ps(&C[idx_base_1+tag_n_Mx], tail_mask, ZEROVECTOR);
- _mm512_mask_storeu_ps(&C[idx_base_2+tag_n_Mx], tail_mask, ZEROVECTOR);
- _mm512_mask_storeu_ps(&C[idx_base_3+tag_n_Mx], tail_mask, ZEROVECTOR);
- }
+ BLASLONG tag_n_Nx = N & (~3);
+ BLASLONG tag_n_Mx = M & (~15);
+ unsigned short tail_mask = (((unsigned short)0xffff) >> (16-M+tag_n_Mx));
+ for (BLASLONG idx_n = 0; idx_n < tag_n_Nx; idx_n += 4) {
+ for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) {
+ _mm512_storeu_ps(C_addr0 + idx_m, ZEROVECTOR);
+ _mm512_storeu_ps(C_addr1 + idx_m, ZEROVECTOR);
+ _mm512_storeu_ps(C_addr2 + idx_m, ZEROVECTOR);
+ _mm512_storeu_ps(C_addr3 + idx_m, ZEROVECTOR);
+ }
- idx_base_0 += LDC4x;
- idx_base_1 += LDC4x;
- idx_base_2 += LDC4x;
- idx_base_3 += LDC4x;
+ if (tag_n_Mx != M) {
+ _mm512_mask_storeu_ps(C_addr0 + tag_n_Mx, tail_mask, ZEROVECTOR);
+ _mm512_mask_storeu_ps(C_addr1 + tag_n_Mx, tail_mask, ZEROVECTOR);
+ _mm512_mask_storeu_ps(C_addr2 + tag_n_Mx, tail_mask, ZEROVECTOR);
+ _mm512_mask_storeu_ps(C_addr3 + tag_n_Mx, tail_mask, ZEROVECTOR);
}
- if (tag_n_Nx != N) {
- for (BLASLONG idx_n = tag_n_Nx; idx_n < N; idx_n++) {
- for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) {
- _mm512_storeu_ps(&C[idx_base_0+idx_m], ZEROVECTOR);
- }
+ C_addr0 += LDC4x;
+ C_addr1 += LDC4x;
+ C_addr2 += LDC4x;
+ C_addr3 += LDC4x;
+ }
- if (tag_n_Mx != M) {
- _mm512_mask_storeu_ps(&C[idx_base_0+tag_n_Mx], tail_mask, ZEROVECTOR);
- }
- idx_base_0 += ldc;
+ if (tag_n_Nx != N) {
+ for (BLASLONG idx_n = tag_n_Nx; idx_n < N; idx_n++) {
+ for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) {
+ _mm512_storeu_ps(C_addr0 + idx_m, ZEROVECTOR);
}
- }
- } else {
+ if (tag_n_Mx != M) {
+ _mm512_mask_storeu_ps(C_addr0 + tag_n_Mx, tail_mask, ZEROVECTOR);
+ }
+ C_addr0 += ldc;
+ }
}
-}
\ No newline at end of file
+}
#include "bf16_common_macros.h"
#include <immintrin.h>
+/* These macros are required by the kernels below and should be defined in the appropriate place:
+#define BF16_BLOCK_STEP_N 8
+#define BF16_BLOCK_THRES_K 1024
+#define BF16_BLOCK_THRES_M 32
+#define BF16_BLOCK_THRES_N 1024
+
+#define A(i,j) A[(i)*lda+(j)]
+#define B(i,j) B[(i)*ldb+(j)]
+#define C(i,j) C[(i)*ldc+(j)]
+
+#define ONE 1.e0f
+#define ZERO 0.e0f
+*/
+
#undef STORE16_COMPLETE_RESULT
#undef STORE16_MASK_COMPLETE_RESULT
-#undef SBGEMM_BLOCK_KERNEL_32x8x32
-#undef SBGEMM_BLOCK_KERNEL_16x8x32
-#undef SBGEMM_BLOCK_KERNEL_32xNx32
-#undef SBGEMM_BLOCK_KERNEL_16xNx32
-#undef SBGEMM_BLOCKING_KERNEL_2
+#undef SBGEMM_BLOCK_KERNEL_NN_32x8xK
+#undef SBGEMM_BLOCK_KERNEL_NN_16x8xK
+#undef SBGEMM_BLOCK_KERNEL_NN_32xNx32
+#undef SBGEMM_BLOCK_KERNEL_NN_16xNx32
+#undef SBGEMM_BLOCK_KERNEL_NT_32x8xK
+#undef SBGEMM_BLOCK_KERNEL_NT_16x8xK
+#undef SBGEMM_BLOCK_KERNEL_NT_32xNxK
+#undef SBGEMM_BLOCK_KERNEL_NT_16xNxK
+#undef SBGEMM_BLOCK_KERNEL_TN_32x8xK
+#undef SBGEMM_BLOCK_KERNEL_TN_16x8xK
+#undef SBGEMM_BLOCK_KERNEL_TN_32xNx32
+#undef SBGEMM_BLOCK_KERNEL_TN_16xNx32
+#undef SBGEMM_BLOCK_KERNEL_TT_32x8xK
+#undef SBGEMM_BLOCK_KERNEL_TT_16x8xK
+#undef SBGEMM_BLOCK_KERNEL_TT_32xNxK
+#undef SBGEMM_BLOCK_KERNEL_TT_16xNxK
+#undef SBGEMM_BLOCKING_KERNEL_NN
+#undef SBGEMM_BLOCKING_KERNEL_NT
+#undef SBGEMM_BLOCKING_KERNEL_TN
+#undef SBGEMM_BLOCKING_KERNEL_TT
#ifndef ONE_ALPHA // ALPHA is not ONE
- #define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ALPHA_ONE
- #define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE
- #define SBGEMM_BLOCK_KERNEL_32x8x32 sbgemm_block_kernel_32x8x32_alpha
- #define SBGEMM_BLOCK_KERNEL_16x8x32 sbgemm_block_kernel_16x8x32_alpha
- #define SBGEMM_BLOCK_KERNEL_32xNx32 sbgemm_block_kernel_32xNx32_alpha
- #define SBGEMM_BLOCK_KERNEL_16xNx32 sbgemm_block_kernel_16xNx32_alpha
- #define SBGEMM_BLOCKING_KERNEL_2 sbgemm_blocking_kernel_2_alpha
+ #define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ALPHA_ONE
+ #define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE
+
+ #define SBGEMM_BLOCK_KERNEL_NN_32x8xK sbgemm_block_kernel_nn_32x8xK_alpha
+ #define SBGEMM_BLOCK_KERNEL_NN_16x8xK sbgemm_block_kernel_nn_16x8xK_alpha
+ #define SBGEMM_BLOCK_KERNEL_NN_32xNx32 sbgemm_block_kernel_nn_32xNx32_alpha
+ #define SBGEMM_BLOCK_KERNEL_NN_16xNx32 sbgemm_block_kernel_nn_16xNx32_alpha
+
+ #define SBGEMM_BLOCK_KERNEL_NT_32x8xK SBGEMM_BLOCK_KERNEL_NN_32x8xK
+ #define SBGEMM_BLOCK_KERNEL_NT_16x8xK SBGEMM_BLOCK_KERNEL_NN_16x8xK
+ #define SBGEMM_BLOCK_KERNEL_NT_32xNxK sbgemm_block_kernel_nt_32xNxK_alpha
+ #define SBGEMM_BLOCK_KERNEL_NT_16xNxK sbgemm_block_kernel_nt_16xNxK_alpha
+
+ #define SBGEMM_BLOCK_KERNEL_TN_32x8xK sbgemm_block_kernel_tn_32x8xK_alpha
+ #define SBGEMM_BLOCK_KERNEL_TN_16x8xK sbgemm_block_kernel_tn_16x8xK_alpha
+ #define SBGEMM_BLOCK_KERNEL_TN_32xNx32 sbgemm_block_kernel_tn_32xNx32_alpha
+ #define SBGEMM_BLOCK_KERNEL_TN_16xNx32 sbgemm_block_kernel_tn_16xNx32_alpha
+
+ #define SBGEMM_BLOCK_KERNEL_TT_32x8xK SBGEMM_BLOCK_KERNEL_TN_32x8xK
+ #define SBGEMM_BLOCK_KERNEL_TT_16x8xK SBGEMM_BLOCK_KERNEL_TN_16x8xK
+ #define SBGEMM_BLOCK_KERNEL_TT_32xNxK sbgemm_block_kernel_tt_32xNxK_alpha
+ #define SBGEMM_BLOCK_KERNEL_TT_16xNxK sbgemm_block_kernel_tt_16xNxK_alpha
+
+ #define SBGEMM_BLOCKING_KERNEL_NN sbgemm_blocking_kernel_nn_alpha
+ #define SBGEMM_BLOCKING_KERNEL_NT sbgemm_blocking_kernel_nt_alpha
+ #define SBGEMM_BLOCKING_KERNEL_TN sbgemm_blocking_kernel_tn_alpha
+ #define SBGEMM_BLOCKING_KERNEL_TT sbgemm_blocking_kernel_tt_alpha
#else // ALPHA is ONE
- #define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ONE_ONE
- #define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ONE_ONE
- #define SBGEMM_BLOCK_KERNEL_32x8x32 sbgemm_block_kernel_32x8x32_one
- #define SBGEMM_BLOCK_KERNEL_16x8x32 sbgemm_block_kernel_16x8x32_one
- #define SBGEMM_BLOCK_KERNEL_32xNx32 sbgemm_block_kernel_32xNx32_one
- #define SBGEMM_BLOCK_KERNEL_16xNx32 sbgemm_block_kernel_16xNx32_one
- #define SBGEMM_BLOCKING_KERNEL_2 sbgemm_blocking_kernel_2_one
+ #define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ONE_ONE
+ #define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ONE_ONE
+
+ #define SBGEMM_BLOCK_KERNEL_NN_32x8xK sbgemm_block_kernel_nn_32x8xK_one
+ #define SBGEMM_BLOCK_KERNEL_NN_16x8xK sbgemm_block_kernel_nn_16x8xK_one
+ #define SBGEMM_BLOCK_KERNEL_NN_32xNx32 sbgemm_block_kernel_nn_32xNx32_one
+ #define SBGEMM_BLOCK_KERNEL_NN_16xNx32 sbgemm_block_kernel_nn_16xNx32_one
+
+ #define SBGEMM_BLOCK_KERNEL_NT_32x8xK SBGEMM_BLOCK_KERNEL_NN_32x8xK
+ #define SBGEMM_BLOCK_KERNEL_NT_16x8xK SBGEMM_BLOCK_KERNEL_NN_16x8xK
+ #define SBGEMM_BLOCK_KERNEL_NT_32xNxK sbgemm_block_kernel_nt_32xNxK_one
+ #define SBGEMM_BLOCK_KERNEL_NT_16xNxK sbgemm_block_kernel_nt_16xNxK_one
+
+ #define SBGEMM_BLOCK_KERNEL_TN_32x8xK sbgemm_block_kernel_tn_32x8xK_one
+ #define SBGEMM_BLOCK_KERNEL_TN_16x8xK sbgemm_block_kernel_tn_16x8xK_one
+ #define SBGEMM_BLOCK_KERNEL_TN_32xNx32 sbgemm_block_kernel_tn_32xNx32_one
+ #define SBGEMM_BLOCK_KERNEL_TN_16xNx32 sbgemm_block_kernel_tn_16xNx32_one
+
+ #define SBGEMM_BLOCK_KERNEL_TT_32x8xK SBGEMM_BLOCK_KERNEL_TN_32x8xK
+ #define SBGEMM_BLOCK_KERNEL_TT_16x8xK SBGEMM_BLOCK_KERNEL_TN_16x8xK
+ #define SBGEMM_BLOCK_KERNEL_TT_32xNxK sbgemm_block_kernel_tt_32xNxK_one
+ #define SBGEMM_BLOCK_KERNEL_TT_16xNxK sbgemm_block_kernel_tt_16xNxK_one
+
+ #define SBGEMM_BLOCKING_KERNEL_NN sbgemm_blocking_kernel_nn_one
+ #define SBGEMM_BLOCKING_KERNEL_NT sbgemm_blocking_kernel_nt_one
+ #define SBGEMM_BLOCKING_KERNEL_TN sbgemm_blocking_kernel_tn_one
+ #define SBGEMM_BLOCKING_KERNEL_TT sbgemm_blocking_kernel_tt_one
#endif
+extern bfloat16 * block_A;
+extern bfloat16 * block_B;
+/* --------------------------------------------- NN kernels ------------------------------------------ */
// SBGEMM Kernel for 16<M<=32, N=8, K can be any number, but the processing will take 32 as a base
#ifndef ONE_ALPHA // ALPHA is not ONE
-void sbgemm_block_kernel_32x8x32_alpha(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+void sbgemm_block_kernel_nn_32x8xK_alpha(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
#else // ALPHA is ONE
-void sbgemm_block_kernel_32x8x32_one(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+void sbgemm_block_kernel_nn_32x8xK_one(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
#endif
{
- int SHUFFLE_MAGIC_NO = 0x39;
- BLASLONG tag_k_32x = k & (~31);
- BLASLONG idxA_base = 0;
- BLASLONG idxB_base = 0;
- BLASLONG width = 32;
+ bfloat16 * A_addr = A;
+ bfloat16 * B_addr = B;
+ float * C_addr = C;
#ifndef ONE_ALPHA
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
result_512_14 = _mm512_setzero_ps();
result_512_15 = _mm512_setzero_ps();
- for (BLASLONG idx_k = 0; idx_k < k; idx_k += 32) {
- // Load B with unroll 8
- idxB_base = idx_k << 3;
- arrayB_512_0 = _mm512_loadu_si512(&B[idxB_base + 32*0]);
- arrayB_512_1 = _mm512_loadu_si512(&B[idxB_base + 32*1]);
- arrayB_512_2 = _mm512_loadu_si512(&B[idxB_base + 32*2]);
- arrayB_512_3 = _mm512_loadu_si512(&B[idxB_base + 32*3]);
- arrayB_512_4 = _mm512_loadu_si512(&B[idxB_base + 32*4]);
- arrayB_512_5 = _mm512_loadu_si512(&B[idxB_base + 32*5]);
- arrayB_512_6 = _mm512_loadu_si512(&B[idxB_base + 32*6]);
- arrayB_512_7 = _mm512_loadu_si512(&B[idxB_base + 32*7]);
-
- if (idx_k == tag_k_32x) {width = k - tag_k_32x;}
-
- for (BLASLONG idx = 0; idx < width;) {
- // Each two rows are a group for 32-pair bf16 elements
- idxA_base = idx << 5;
- arrayA_512_0 = _mm512_loadu_si512(&A[idxA_base]);
- arrayA_512_1 = _mm512_loadu_si512(&A[idxA_base + 32]);
-
- result_512_0 = _mm512_dpbf16_ps(result_512_0, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_0)));
- result_512_1 = _mm512_dpbf16_ps(result_512_1, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_1)));
- result_512_2 = _mm512_dpbf16_ps(result_512_2, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_2)));
- result_512_3 = _mm512_dpbf16_ps(result_512_3, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_3)));
- result_512_4 = _mm512_dpbf16_ps(result_512_4, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_4)));
- result_512_5 = _mm512_dpbf16_ps(result_512_5, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_5)));
- result_512_6 = _mm512_dpbf16_ps(result_512_6, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_6)));
- result_512_7 = _mm512_dpbf16_ps(result_512_7, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_7)));
- result_512_8 = _mm512_dpbf16_ps(result_512_8, (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_0)));
- result_512_9 = _mm512_dpbf16_ps(result_512_9, (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_1)));
- result_512_10 = _mm512_dpbf16_ps(result_512_10, (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_2)));
- result_512_11 = _mm512_dpbf16_ps(result_512_11, (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_3)));
- result_512_12 = _mm512_dpbf16_ps(result_512_12, (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_4)));
- result_512_13 = _mm512_dpbf16_ps(result_512_13, (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_5)));
- result_512_14 = _mm512_dpbf16_ps(result_512_14, (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_6)));
- result_512_15 = _mm512_dpbf16_ps(result_512_15, (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_7)));
-
- arrayB_512_0 = _mm512_shuffle_epi32(arrayB_512_0, SHUFFLE_MAGIC_NO);
- arrayB_512_1 = _mm512_shuffle_epi32(arrayB_512_1, SHUFFLE_MAGIC_NO);
- arrayB_512_2 = _mm512_shuffle_epi32(arrayB_512_2, SHUFFLE_MAGIC_NO);
- arrayB_512_3 = _mm512_shuffle_epi32(arrayB_512_3, SHUFFLE_MAGIC_NO);
- arrayB_512_4 = _mm512_shuffle_epi32(arrayB_512_4, SHUFFLE_MAGIC_NO);
- arrayB_512_5 = _mm512_shuffle_epi32(arrayB_512_5, SHUFFLE_MAGIC_NO);
- arrayB_512_6 = _mm512_shuffle_epi32(arrayB_512_6, SHUFFLE_MAGIC_NO);
- arrayB_512_7 = _mm512_shuffle_epi32(arrayB_512_7, SHUFFLE_MAGIC_NO);
+ for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
+ // Each two rows are a group for 32-pair bf16 elements
+ arrayA_512_0 = _mm512_loadu_si512(A_addr);
+ arrayA_512_1 = _mm512_loadu_si512(A_addr + 32);
+
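+        // Broadcast 8 consecutive bf16 pairs of B (one pair per output column) for the dpbf16 accumulations below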
+ _MM512_BROADCASTD_EPI32(B_addr + 0, arrayB_512_0);
+ _MM512_BROADCASTD_EPI32(B_addr + 2, arrayB_512_1);
+ _MM512_BROADCASTD_EPI32(B_addr + 4, arrayB_512_2);
+ _MM512_BROADCASTD_EPI32(B_addr + 6, arrayB_512_3);
+ _MM512_BROADCASTD_EPI32(B_addr + 8, arrayB_512_4);
+ _MM512_BROADCASTD_EPI32(B_addr + 10, arrayB_512_5);
+ _MM512_BROADCASTD_EPI32(B_addr + 12, arrayB_512_6);
+ _MM512_BROADCASTD_EPI32(B_addr + 14, arrayB_512_7);
+
+ result_512_0 = _mm512_dpbf16_ps(result_512_0, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_0);
+ result_512_1 = _mm512_dpbf16_ps(result_512_1, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_1);
+ result_512_2 = _mm512_dpbf16_ps(result_512_2, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_2);
+ result_512_3 = _mm512_dpbf16_ps(result_512_3, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_3);
+ result_512_4 = _mm512_dpbf16_ps(result_512_4, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_4);
+ result_512_5 = _mm512_dpbf16_ps(result_512_5, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_5);
+ result_512_6 = _mm512_dpbf16_ps(result_512_6, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_6);
+ result_512_7 = _mm512_dpbf16_ps(result_512_7, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_7);
+
+ result_512_8 = _mm512_dpbf16_ps(result_512_8, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_0);
+ result_512_9 = _mm512_dpbf16_ps(result_512_9, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_1);
+ result_512_10 = _mm512_dpbf16_ps(result_512_10, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_2);
+ result_512_11 = _mm512_dpbf16_ps(result_512_11, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_3);
+ result_512_12 = _mm512_dpbf16_ps(result_512_12, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_4);
+ result_512_13 = _mm512_dpbf16_ps(result_512_13, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_5);
+ result_512_14 = _mm512_dpbf16_ps(result_512_14, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_6);
+ result_512_15 = _mm512_dpbf16_ps(result_512_15, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_7);
- idx += 2;
- // Every 4 loops we need to switch to next 128 bits of arrayB registers
- if ((idx & (~7)) == idx) {
- arrayB_512_0 = _mm512_shuffle_i32x4(arrayB_512_0, arrayB_512_0, SHUFFLE_MAGIC_NO);
- arrayB_512_1 = _mm512_shuffle_i32x4(arrayB_512_1, arrayB_512_1, SHUFFLE_MAGIC_NO);
- arrayB_512_2 = _mm512_shuffle_i32x4(arrayB_512_2, arrayB_512_2, SHUFFLE_MAGIC_NO);
- arrayB_512_3 = _mm512_shuffle_i32x4(arrayB_512_3, arrayB_512_3, SHUFFLE_MAGIC_NO);
- arrayB_512_4 = _mm512_shuffle_i32x4(arrayB_512_4, arrayB_512_4, SHUFFLE_MAGIC_NO);
- arrayB_512_5 = _mm512_shuffle_i32x4(arrayB_512_5, arrayB_512_5, SHUFFLE_MAGIC_NO);
- arrayB_512_6 = _mm512_shuffle_i32x4(arrayB_512_6, arrayB_512_6, SHUFFLE_MAGIC_NO);
- arrayB_512_7 = _mm512_shuffle_i32x4(arrayB_512_7, arrayB_512_7, SHUFFLE_MAGIC_NO);
- }
- }
+ // Advance packed B: 8 columns x 2 bf16 per k-pair
+ B_addr += 16;
+ // Advance packed A: 32 rows x 2 bf16 per k-pair
+ A_addr += 64;
}
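+ // permutex2var with shuffle_idx_base0/base1 undoes the 128-bit interleave from the copy kernel:
+ // the base0 result holds C rows 0..15 of a column, the base1 result holds rows 16..31 (tail-masked when m < 32)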
if (m != 32) {
result_512_tmp_1 = _mm512_permutex2var_ps(result_512_0, shuffle_idx_base1, result_512_8);
result_512_tmp_2 = _mm512_permutex2var_ps(result_512_1, shuffle_idx_base0, result_512_9);
result_512_tmp_3 = _mm512_permutex2var_ps(result_512_1, shuffle_idx_base1, result_512_9);
- STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*0]))
- STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*0+16]), tail_mask)
- STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*1]))
- STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*1+16]), tail_mask)
+ STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr))
+ STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (C_addr + 16), tail_mask)
+ STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*1))
+ STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*1 + 16), tail_mask)
result_512_tmp_0 = _mm512_permutex2var_ps(result_512_2, shuffle_idx_base0, result_512_10);
result_512_tmp_1 = _mm512_permutex2var_ps(result_512_2, shuffle_idx_base1, result_512_10);
result_512_tmp_2 = _mm512_permutex2var_ps(result_512_3, shuffle_idx_base0, result_512_11);
result_512_tmp_3 = _mm512_permutex2var_ps(result_512_3, shuffle_idx_base1, result_512_11);
- STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*2]))
- STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*2+16]), tail_mask)
- STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*3]))
- STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*3+16]), tail_mask)
+ STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*2))
+ STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*2 + 16), tail_mask)
+ STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*3))
+ STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*3 + 16), tail_mask)
result_512_tmp_0 = _mm512_permutex2var_ps(result_512_4, shuffle_idx_base0, result_512_12);
result_512_tmp_1 = _mm512_permutex2var_ps(result_512_4, shuffle_idx_base1, result_512_12);
result_512_tmp_2 = _mm512_permutex2var_ps(result_512_5, shuffle_idx_base0, result_512_13);
result_512_tmp_3 = _mm512_permutex2var_ps(result_512_5, shuffle_idx_base1, result_512_13);
- STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*4]))
- STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*4+16]), tail_mask)
- STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*5]))
- STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*5+16]), tail_mask)
+ STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*4))
+ STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*4 + 16), tail_mask)
+ STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*5))
+ STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*5 + 16), tail_mask)
result_512_tmp_0 = _mm512_permutex2var_ps(result_512_6, shuffle_idx_base0, result_512_14);
result_512_tmp_1 = _mm512_permutex2var_ps(result_512_6, shuffle_idx_base1, result_512_14);
result_512_tmp_2 = _mm512_permutex2var_ps(result_512_7, shuffle_idx_base0, result_512_15);
result_512_tmp_3 = _mm512_permutex2var_ps(result_512_7, shuffle_idx_base1, result_512_15);
- STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*6]))
- STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*6+16]), tail_mask)
- STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*7]))
- STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*7+16]), tail_mask)
+ STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*6))
+ STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*6 + 16), tail_mask)
+ STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*7))
+ STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*7 + 16), tail_mask)
} else {
result_512_tmp_0 = _mm512_permutex2var_ps(result_512_0, shuffle_idx_base0, result_512_8);
result_512_tmp_1 = _mm512_permutex2var_ps(result_512_0, shuffle_idx_base1, result_512_8);
result_512_tmp_2 = _mm512_permutex2var_ps(result_512_1, shuffle_idx_base0, result_512_9);
result_512_tmp_3 = _mm512_permutex2var_ps(result_512_1, shuffle_idx_base1, result_512_9);
- STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*0]))
- STORE16_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*0+16]))
- STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*1]))
- STORE16_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*1+16]))
+ STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr))
+ STORE16_COMPLETE_RESULT(result_512_tmp_1, (C_addr + 16))
+ STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*1))
+ STORE16_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*1 + 16))
result_512_tmp_0 = _mm512_permutex2var_ps(result_512_2, shuffle_idx_base0, result_512_10);
result_512_tmp_1 = _mm512_permutex2var_ps(result_512_2, shuffle_idx_base1, result_512_10);
result_512_tmp_2 = _mm512_permutex2var_ps(result_512_3, shuffle_idx_base0, result_512_11);
result_512_tmp_3 = _mm512_permutex2var_ps(result_512_3, shuffle_idx_base1, result_512_11);
- STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*2]))
- STORE16_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*2+16]))
- STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*3]))
- STORE16_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*3+16]))
+ STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*2))
+ STORE16_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*2 + 16))
+ STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*3))
+ STORE16_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*3 + 16))
result_512_tmp_0 = _mm512_permutex2var_ps(result_512_4, shuffle_idx_base0, result_512_12);
result_512_tmp_1 = _mm512_permutex2var_ps(result_512_4, shuffle_idx_base1, result_512_12);
result_512_tmp_2 = _mm512_permutex2var_ps(result_512_5, shuffle_idx_base0, result_512_13);
result_512_tmp_3 = _mm512_permutex2var_ps(result_512_5, shuffle_idx_base1, result_512_13);
- STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*4]))
- STORE16_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*4+16]))
- STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*5]))
- STORE16_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*5+16]))
+ STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*4))
+ STORE16_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*4 + 16))
+ STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*5))
+ STORE16_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*5 + 16))
result_512_tmp_0 = _mm512_permutex2var_ps(result_512_6, shuffle_idx_base0, result_512_14);
result_512_tmp_1 = _mm512_permutex2var_ps(result_512_6, shuffle_idx_base1, result_512_14);
result_512_tmp_2 = _mm512_permutex2var_ps(result_512_7, shuffle_idx_base0, result_512_15);
result_512_tmp_3 = _mm512_permutex2var_ps(result_512_7, shuffle_idx_base1, result_512_15);
- STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*6]))
- STORE16_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*6+16]))
- STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*7]))
- STORE16_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*7+16]))
+ STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*6))
+ STORE16_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*6 + 16))
+ STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*7))
+ STORE16_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*7 + 16))
}
}
-// SBGEMM Kernel for M<=16, N=8, K can be any number, but the processing will take 32 as a base
+// SBGEMM Kernel for M<=16, N=8, K can be any number
#ifndef ONE_ALPHA // ALPHA is not ONE
-void sbgemm_block_kernel_16x8x32_alpha(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+void sbgemm_block_kernel_nn_16x8xK_alpha(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
#else // ALPHA is ONE
-void sbgemm_block_kernel_16x8x32_one(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+void sbgemm_block_kernel_nn_16x8xK_one(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
#endif
{
- int SHUFFLE_MAGIC_NO = 0x39;
- BLASLONG tag_k_32x = k & (~31);
- BLASLONG idxB_base = 0;
- BLASLONG width = 32;
+ bfloat16 * A_addr = A;
+ bfloat16 * B_addr = B;
+ float * C_addr = C;
#ifndef ONE_ALPHA
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
result_512_6 = _mm512_setzero_ps();
result_512_7 = _mm512_setzero_ps();
- for (BLASLONG idx_k = 0; idx_k < k; idx_k += 32) {
- // Load B with unroll 8
- idxB_base = idx_k << 3;
- arrayB_512_0 = _mm512_loadu_si512(&B[idxB_base + 32*0]);
- arrayB_512_1 = _mm512_loadu_si512(&B[idxB_base + 32*1]);
- arrayB_512_2 = _mm512_loadu_si512(&B[idxB_base + 32*2]);
- arrayB_512_3 = _mm512_loadu_si512(&B[idxB_base + 32*3]);
- arrayB_512_4 = _mm512_loadu_si512(&B[idxB_base + 32*4]);
- arrayB_512_5 = _mm512_loadu_si512(&B[idxB_base + 32*5]);
- arrayB_512_6 = _mm512_loadu_si512(&B[idxB_base + 32*6]);
- arrayB_512_7 = _mm512_loadu_si512(&B[idxB_base + 32*7]);
-
- if (idx_k == tag_k_32x) {width = k - tag_k_32x;}
-
- for (BLASLONG idx = 0; idx < width;) {
- // Each two rows are a group for 32-pair bf16 elements
- // Load two rows into a 512 register
- arrayA_512_0 = _mm512_loadu_si512(&A[idx<<4]);
-
- result_512_0 = _mm512_dpbf16_ps(result_512_0, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_0)));
- result_512_1 = _mm512_dpbf16_ps(result_512_1, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_1)));
- result_512_2 = _mm512_dpbf16_ps(result_512_2, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_2)));
- result_512_3 = _mm512_dpbf16_ps(result_512_3, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_3)));
- result_512_4 = _mm512_dpbf16_ps(result_512_4, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_4)));
- result_512_5 = _mm512_dpbf16_ps(result_512_5, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_5)));
- result_512_6 = _mm512_dpbf16_ps(result_512_6, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_6)));
- result_512_7 = _mm512_dpbf16_ps(result_512_7, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_7)));
-
- arrayB_512_0 = _mm512_shuffle_epi32(arrayB_512_0, SHUFFLE_MAGIC_NO);
- arrayB_512_1 = _mm512_shuffle_epi32(arrayB_512_1, SHUFFLE_MAGIC_NO);
- arrayB_512_2 = _mm512_shuffle_epi32(arrayB_512_2, SHUFFLE_MAGIC_NO);
- arrayB_512_3 = _mm512_shuffle_epi32(arrayB_512_3, SHUFFLE_MAGIC_NO);
- arrayB_512_4 = _mm512_shuffle_epi32(arrayB_512_4, SHUFFLE_MAGIC_NO);
- arrayB_512_5 = _mm512_shuffle_epi32(arrayB_512_5, SHUFFLE_MAGIC_NO);
- arrayB_512_6 = _mm512_shuffle_epi32(arrayB_512_6, SHUFFLE_MAGIC_NO);
- arrayB_512_7 = _mm512_shuffle_epi32(arrayB_512_7, SHUFFLE_MAGIC_NO);
+ for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
+ // Each two rows are a group for 32-pair bf16 elements
+ // Load two rows into a 512 register
+ arrayA_512_0 = _mm512_loadu_si512(A_addr);
+
+ _MM512_BROADCASTD_EPI32(B_addr + 0, arrayB_512_0);
+ _MM512_BROADCASTD_EPI32(B_addr + 2, arrayB_512_1);
+ _MM512_BROADCASTD_EPI32(B_addr + 4, arrayB_512_2);
+ _MM512_BROADCASTD_EPI32(B_addr + 6, arrayB_512_3);
+ _MM512_BROADCASTD_EPI32(B_addr + 8, arrayB_512_4);
+ _MM512_BROADCASTD_EPI32(B_addr + 10, arrayB_512_5);
+ _MM512_BROADCASTD_EPI32(B_addr + 12, arrayB_512_6);
+ _MM512_BROADCASTD_EPI32(B_addr + 14, arrayB_512_7);
+
+ result_512_0 = _mm512_dpbf16_ps(result_512_0, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_0);
+ result_512_1 = _mm512_dpbf16_ps(result_512_1, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_1);
+ result_512_2 = _mm512_dpbf16_ps(result_512_2, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_2);
+ result_512_3 = _mm512_dpbf16_ps(result_512_3, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_3);
+ result_512_4 = _mm512_dpbf16_ps(result_512_4, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_4);
+ result_512_5 = _mm512_dpbf16_ps(result_512_5, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_5);
+ result_512_6 = _mm512_dpbf16_ps(result_512_6, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_6);
+ result_512_7 = _mm512_dpbf16_ps(result_512_7, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_7);
- idx += 2;
- // Every 4 loops we need to switch to next 128 bits of arrayB registers
- if ((idx & (~7)) == idx) {
- arrayB_512_0 = _mm512_shuffle_i32x4(arrayB_512_0, arrayB_512_0, SHUFFLE_MAGIC_NO);
- arrayB_512_1 = _mm512_shuffle_i32x4(arrayB_512_1, arrayB_512_1, SHUFFLE_MAGIC_NO);
- arrayB_512_2 = _mm512_shuffle_i32x4(arrayB_512_2, arrayB_512_2, SHUFFLE_MAGIC_NO);
- arrayB_512_3 = _mm512_shuffle_i32x4(arrayB_512_3, arrayB_512_3, SHUFFLE_MAGIC_NO);
- arrayB_512_4 = _mm512_shuffle_i32x4(arrayB_512_4, arrayB_512_4, SHUFFLE_MAGIC_NO);
- arrayB_512_5 = _mm512_shuffle_i32x4(arrayB_512_5, arrayB_512_5, SHUFFLE_MAGIC_NO);
- arrayB_512_6 = _mm512_shuffle_i32x4(arrayB_512_6, arrayB_512_6, SHUFFLE_MAGIC_NO);
- arrayB_512_7 = _mm512_shuffle_i32x4(arrayB_512_7, arrayB_512_7, SHUFFLE_MAGIC_NO);
- }
- }
+ // Advance packed B: 8 columns x 2 bf16 per k-pair
+ B_addr += 16;
+ // Advance packed A: 16 rows x 2 bf16 per k-pair
+ A_addr += 32;
}
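+ // shuffle_f32x4 with imm 0xd8 swaps the two middle 128-bit lanes, restoring natural row order before each 16-float store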
if (m != 16) {
- unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-m));
- __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
+ unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m));
result_512_0 = _mm512_shuffle_f32x4(result_512_0, result_512_0, 0xd8);
result_512_1 = _mm512_shuffle_f32x4(result_512_1, result_512_1, 0xd8);
result_512_2 = _mm512_shuffle_f32x4(result_512_2, result_512_2, 0xd8);
result_512_3 = _mm512_shuffle_f32x4(result_512_3, result_512_3, 0xd8);
- STORE16_MASK_COMPLETE_RESULT(result_512_0, (&C[ldc*0]), tail_mask)
- STORE16_MASK_COMPLETE_RESULT(result_512_1, (&C[ldc*1]), tail_mask)
- STORE16_MASK_COMPLETE_RESULT(result_512_2, (&C[ldc*2]), tail_mask)
- STORE16_MASK_COMPLETE_RESULT(result_512_3, (&C[ldc*3]), tail_mask)
+ STORE16_MASK_COMPLETE_RESULT(result_512_0, (C_addr), tail_mask)
+ STORE16_MASK_COMPLETE_RESULT(result_512_1, (C_addr + ldc*1), tail_mask)
+ STORE16_MASK_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2), tail_mask)
+ STORE16_MASK_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3), tail_mask)
result_512_4 = _mm512_shuffle_f32x4(result_512_4, result_512_4, 0xd8);
result_512_5 = _mm512_shuffle_f32x4(result_512_5, result_512_5, 0xd8);
result_512_6 = _mm512_shuffle_f32x4(result_512_6, result_512_6, 0xd8);
result_512_7 = _mm512_shuffle_f32x4(result_512_7, result_512_7, 0xd8);
- STORE16_MASK_COMPLETE_RESULT(result_512_4, (&C[ldc*4]), tail_mask)
- STORE16_MASK_COMPLETE_RESULT(result_512_5, (&C[ldc*5]), tail_mask)
- STORE16_MASK_COMPLETE_RESULT(result_512_6, (&C[ldc*6]), tail_mask)
- STORE16_MASK_COMPLETE_RESULT(result_512_7, (&C[ldc*7]), tail_mask)
+ STORE16_MASK_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4), tail_mask)
+ STORE16_MASK_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5), tail_mask)
+ STORE16_MASK_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6), tail_mask)
+ STORE16_MASK_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7), tail_mask)
} else {
result_512_0 = _mm512_shuffle_f32x4(result_512_0, result_512_0, 0xd8);
result_512_1 = _mm512_shuffle_f32x4(result_512_1, result_512_1, 0xd8);
result_512_2 = _mm512_shuffle_f32x4(result_512_2, result_512_2, 0xd8);
result_512_3 = _mm512_shuffle_f32x4(result_512_3, result_512_3, 0xd8);
- STORE16_COMPLETE_RESULT(result_512_0, (&C[ldc*0]))
- STORE16_COMPLETE_RESULT(result_512_1, (&C[ldc*1]))
- STORE16_COMPLETE_RESULT(result_512_2, (&C[ldc*2]))
- STORE16_COMPLETE_RESULT(result_512_3, (&C[ldc*3]))
+ STORE16_COMPLETE_RESULT(result_512_0, (C_addr))
+ STORE16_COMPLETE_RESULT(result_512_1, (C_addr + ldc*1))
+ STORE16_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2))
+ STORE16_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3))
result_512_4 = _mm512_shuffle_f32x4(result_512_4, result_512_4, 0xd8);
result_512_5 = _mm512_shuffle_f32x4(result_512_5, result_512_5, 0xd8);
result_512_6 = _mm512_shuffle_f32x4(result_512_6, result_512_6, 0xd8);
result_512_7 = _mm512_shuffle_f32x4(result_512_7, result_512_7, 0xd8);
- STORE16_COMPLETE_RESULT(result_512_4, (&C[ldc*4]))
- STORE16_COMPLETE_RESULT(result_512_5, (&C[ldc*5]))
- STORE16_COMPLETE_RESULT(result_512_6, (&C[ldc*6]))
- STORE16_COMPLETE_RESULT(result_512_7, (&C[ldc*7]))
+ STORE16_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4))
+ STORE16_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5))
+ STORE16_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6))
+ STORE16_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7))
}
}
// SBGEMM Kernel for 16<M<=32, N<8, K can be any number, but the processing will take 32 as a base
#ifndef ONE_ALPHA // ALPHA is not ONE
-void sbgemm_block_kernel_32xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+void sbgemm_block_kernel_nn_32xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
#else // ALPHA is ONE
-void sbgemm_block_kernel_32xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+void sbgemm_block_kernel_nn_32xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
#endif
{
+ bfloat16 * A_addr = A;
+ bfloat16 * B_addr = B;
+ float * C_addr = C;
+
int SHUFFLE_MAGIC_NO = 0x39;
BLASLONG tag_k_32x = k & (~31);
- BLASLONG idxA_base = 0;
- BLASLONG idxB_base = 0;
- BLASLONG width = 32;
#ifndef ONE_ALPHA
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
result_512[i+1] = _mm512_setzero_ps();
}
- for (BLASLONG idx_k = 0; idx_k < k; idx_k += 32) {
+ for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
// Load B with unroll n
for (int i = 0; i < n; i ++) {
- arrayB_512[i] = _mm512_loadu_si512(&B[idxB_base]);
- idxB_base += 32;
+ arrayB_512[i] = _mm512_loadu_si512(B_addr);
+ B_addr += 32;
+ }
+
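+ // Each arrayB_512[i] holds 16 k-pairs of column i: broadcast 32-bit element 0 for the current pair,
+ // rotate the 32-bit elements of every 128-bit lane (imm 0x39) to expose the next pair,
+ // and rotate whole 128-bit lanes after every four pairs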
+ for (BLASLONG idx = 0; idx < 32;) {
+ // Each two rows are a group for 32-pair bf16 elements
+ arrayA_512[0] = _mm512_loadu_si512(A_addr);
+ arrayA_512[1] = _mm512_loadu_si512(A_addr + 32);
+ A_addr += 64;
+
+ for (int i = 0; i < n; i++) {
+ result_512[i] = _mm512_dpbf16_ps(result_512[i] , (__m512bh) arrayA_512[0], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
+ result_512[i+8] = _mm512_dpbf16_ps(result_512[i+8], (__m512bh) arrayA_512[1], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
+ arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
+ }
+
+ idx += 2;
+ // Every 4 iterations, switch to the next 128 bits of the arrayB registers
+ if ((idx & (~7)) == idx) {
+ for (int i = 0; i < n; i++) {
+ arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO);
+ }
+ }
}
+ }
- if (idx_k == tag_k_32x) {width = k - tag_k_32x;}
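+ // K tail: k is not a multiple of 32, so process the remaining (k - tag_k_32x) values in one more pass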
+ if (tag_k_32x != k) {
+ // Load B with unroll n
+ for (int i = 0; i < n; i ++) {
+ arrayB_512[i] = _mm512_loadu_si512(B_addr);
+ B_addr += 32;
+ }
+ BLASLONG width = k - tag_k_32x;
for (BLASLONG idx = 0; idx < width;) {
// Each two rows are a group for 32-pair bf16 elements
- idxA_base = idx << 5;
- arrayA_512[0] = _mm512_loadu_si512(&A[idxA_base]);
- arrayA_512[1] = _mm512_loadu_si512(&A[idxA_base + 32]);
+ arrayA_512[0] = _mm512_loadu_si512(A_addr);
+ arrayA_512[1] = _mm512_loadu_si512(A_addr + 32);
+ A_addr += 64;
for (int i = 0; i < n; i++) {
result_512[i] = _mm512_dpbf16_ps(result_512[i] , (__m512bh) arrayA_512[0], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
}
if (m != 32) {
- unsigned short tail_mask_value = (((unsigned short)0xffff) >> (32-m));
- __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
+ unsigned short tail_mask = (((unsigned short)0xffff) >> (32-m));
for (int i = 0; i < n; i++) {
result_512_tmp_0 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base0, result_512[i+8]);
result_512_tmp_1 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base1, result_512[i+8]);
- STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*i]))
- STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*i+16]), tail_mask)
+ STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*i))
+ STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*i + 16), tail_mask)
}
} else {
for (int i = 0; i < n; i++) {
result_512_tmp_0 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base0, result_512[i+8]);
result_512_tmp_1 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base1, result_512[i+8]);
- STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*i]))
- STORE16_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*i+16]))
+ STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*i))
+ STORE16_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*i + 16))
}
}
}
// SBGEMM Kernel for M<=16, N<8, K can be any number, but the processing will take 32 as a base
#ifndef ONE_ALPHA // ALPHA is not ONE
-void sbgemm_block_kernel_16xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+void sbgemm_block_kernel_nn_16xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
#else // ALPHA is ONE
-void sbgemm_block_kernel_16xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+void sbgemm_block_kernel_nn_16xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
#endif
{
+ bfloat16 * A_addr = A;
+ bfloat16 * B_addr = B;
+ float * C_addr = C;
+
int SHUFFLE_MAGIC_NO = 0x39;
BLASLONG tag_k_32x = k & (~31);
- BLASLONG idxB_base = 0;
- BLASLONG width = 32;
#ifndef ONE_ALPHA
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
result_512[i+1] = _mm512_setzero_ps();
}
- for (BLASLONG idx_k = 0; idx_k < k; idx_k += 32) {
+ for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
// Load B with unroll n
- for (int i = 0; i < n; i ++) {
- arrayB_512[i] = _mm512_loadu_si512(&B[idxB_base]);
- idxB_base += 32;
+ for (int i = 0; i < n; i++) {
+ arrayB_512[i] = _mm512_loadu_si512(B_addr);
+ B_addr += 32;
+ }
+
+ for (BLASLONG idx = 0; idx < 32;) {
+ // Each two rows are a group for 32-pair bf16 elements
+ // Load two rows into a 512 register
+ arrayA_512 = _mm512_loadu_si512(A_addr);
+ A_addr += 32;
+
+ for (int i = 0; i < n; i ++) {
+ result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
+ arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
+ }
+
+ idx += 2;
+ // Every 4 iterations, switch to the next 128 bits of the arrayB registers
+ if ((idx & (~7)) == idx) {
+ for (int i = 0; i < n; i++) {
+ arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO);
+ }
+ }
}
+ }
- if (idx_k == tag_k_32x) {width = k - tag_k_32x;}
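+ // K tail: handle the remaining (k - tag_k_32x) values outside the unrolled 32-step loop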
+ if (tag_k_32x != k) {
+ // Load B with unroll n
+ for (int i = 0; i < n; i++) {
+ arrayB_512[i] = _mm512_loadu_si512(B_addr);
+ B_addr += 32;
+ }
+ BLASLONG width = k - tag_k_32x;
for (BLASLONG idx = 0; idx < width;) {
// Each two rows are a group for 32-pair bf16 elements
// Load two rows into a 512 register
- arrayA_512 = _mm512_loadu_si512(&A[idx<<4]);
+ arrayA_512 = _mm512_loadu_si512(A_addr);
+ A_addr += 32;
- for (int i = 0; i < n; i ++) {
+ for (int i = 0; i < n; i++) {
result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
}
}
if (m != 16) {
- unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-m));
- __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
+ unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m));
for (int i = 0; i < n; i++) {
result_512[i] = _mm512_shuffle_f32x4(result_512[i], result_512[i], 0xd8);
- STORE16_MASK_COMPLETE_RESULT(result_512[i], (&C[ldc*i]), tail_mask)
+ STORE16_MASK_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i), tail_mask)
}
} else {
for (int i = 0; i < n; i++) {
result_512[i] = _mm512_shuffle_f32x4(result_512[i], result_512[i], 0xd8);
- STORE16_COMPLETE_RESULT(result_512[i], (&C[ldc*i]))
+ STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
}
}
}
+
+
#ifndef ONE_ALPHA // ALPHA is not ONE
-void sbgemm_blocking_kernel_2_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
+void sbgemm_blocking_kernel_nn_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
#else // ALPHA is ONE
-void sbgemm_blocking_kernel_2_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
+void sbgemm_blocking_kernel_nn_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
#endif
{
BLASLONG m_step, n_step, k_step, k_step_round32;
while (n_from < N) {
for (BLASLONG idx_k = 0; idx_k < K;) {
// Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
- COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, &A(idx_k, 0), lda, block_A);
- // TODO: MT
+ COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, 32, &A(idx_k, 0), lda, block_A);
for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
// Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32);
- SBGEMM_BLOCK_KERNEL_32x8x32(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
+ SBGEMM_BLOCK_KERNEL_NN_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
}
if (tag_n_Nx != n_to) {
n_step = n_to - tag_n_Nx;
COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
- SBGEMM_BLOCK_KERNEL_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+ SBGEMM_BLOCK_KERNEL_NN_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
}
for (BLASLONG idx_m = BF16_BLOCK_THRES_M; idx_m < tag_m_Nx; idx_m += BF16_BLOCK_THRES_M) {
- COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, &A(idx_k, idx_m), lda, block_A);
+ COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, 32, &A(idx_k, idx_m), lda, block_A);
for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
- SBGEMM_BLOCK_KERNEL_32x8x32(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc);
+ SBGEMM_BLOCK_KERNEL_NN_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc);
}
if (tag_n_Nx != n_to) {
n_step = n_to - tag_n_Nx;
- SBGEMM_BLOCK_KERNEL_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc);
+ SBGEMM_BLOCK_KERNEL_NN_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc);
}
}
if (tag_m_Nx != M) {
m_step = M - tag_m_Nx;
if (m_step > 16) {
- COL_MAJOR_INCOPY_KERNEL_Kx32m(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A);
- for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
- SBGEMM_BLOCK_KERNEL_32x8x32(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
- }
-
- if (tag_n_Nx != n_to) {
- n_step = n_to - tag_n_Nx;
- SBGEMM_BLOCK_KERNEL_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
- }
- } else if (m_step == 16) {
- COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A);
+ COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A);
for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
- SBGEMM_BLOCK_KERNEL_16x8x32(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
+ SBGEMM_BLOCK_KERNEL_NN_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
}
if (tag_n_Nx != n_to) {
n_step = n_to - tag_n_Nx;
- SBGEMM_BLOCK_KERNEL_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
+ SBGEMM_BLOCK_KERNEL_NN_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
}
} else {
- COL_MAJOR_INCOPY_KERNEL_Kx16m(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A);
+ COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A);
for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
- SBGEMM_BLOCK_KERNEL_16x8x32(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
+ SBGEMM_BLOCK_KERNEL_NN_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
}
if (tag_n_Nx != n_to) {
n_step = n_to - tag_n_Nx;
- SBGEMM_BLOCK_KERNEL_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
+ SBGEMM_BLOCK_KERNEL_NN_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
}
}
}
tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
}
} else {
- m_step = M - tag_m_Nx;
+ m_step = M;
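+ // M < BF16_BLOCK_THRES_M: a single packed copy of A covers all rows (no M-block loop);
+ // pick the Kx32 or Kx16 copy and compute kernels by m_step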
+ if (m_step > 16) {
+ while (n_from < N) {
+ for (BLASLONG idx_k = 0; idx_k < K;) {
+ // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
+ COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, m_step, &A(idx_k, 0), lda, block_A);
+ for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+ // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
+ COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32);
+ SBGEMM_BLOCK_KERNEL_NN_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
+ }
+
+ if (tag_n_Nx != n_to) {
+ n_step = n_to - tag_n_Nx;
+ COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
+ SBGEMM_BLOCK_KERNEL_NN_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+ }
+
+ idx_k += k_step;
+ k_step = K - idx_k;
+ k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
+ k_step_round32 = k_step & (~31);
+ k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+ }
+ n_from = n_to;
+ n_to += BF16_BLOCK_THRES_N;
+ n_to = (n_to > N) ? N : n_to;
+ tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+ }
+ } else {
+ while (n_from < N) {
+ for (BLASLONG idx_k = 0; idx_k < K;) {
+ COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, 0), lda, block_A);
+ for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+ // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
+ COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32);
+ SBGEMM_BLOCK_KERNEL_NN_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
+ }
+
+ if (tag_n_Nx != n_to) {
+ n_step = n_to - tag_n_Nx;
+ COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
+ SBGEMM_BLOCK_KERNEL_NN_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+ }
+
+ idx_k += k_step;
+ k_step = K - idx_k;
+ k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
+ k_step_round32 = k_step & (~31);
+ k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+ }
+ n_from = n_to;
+ n_to += BF16_BLOCK_THRES_N;
+ n_to = (n_to > N) ? N : n_to;
+ tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+ }
+ }
+ }
+}
+/* ----------------------------------------- End of NN kernels --------------------------------------- */
+
+/* --------------------------------------------- NT kernels ------------------------------------------ */
+// SBGEMM Kernel for 16<M<=32, N<8, K can be any number
+#ifndef ONE_ALPHA // ALPHA is not ONE
+void sbgemm_block_kernel_nt_32xNxK_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#else // ALPHA is ONE
+void sbgemm_block_kernel_nt_32xNxK_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#endif
+{
+ bfloat16 * A_addr = A;
+ bfloat16 * B_addr = B;
+ float * C_addr = C;
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+
+ __m512i arrayA_512_0, arrayA_512_1;
+ __m512i arrayB_512[8];
+ __m512 result_512[16];
+ __m512 result_512_tmp_0, result_512_tmp_1;
+
+ __m512i M512_EPI32_8 = _mm512_set1_epi32(8);
+ __m512i shuffle_idx_base0 = _mm512_set_epi32(23, 22, 21, 20, 7, 6, 5, 4, 19, 18, 17, 16, 3, 2, 1, 0);
+ __m512i shuffle_idx_base1 = _mm512_add_epi32(shuffle_idx_base0, M512_EPI32_8);
+
+ result_512[0] = _mm512_setzero_ps();
+ result_512[1] = _mm512_setzero_ps();
+ result_512[2] = _mm512_setzero_ps();
+ result_512[3] = _mm512_setzero_ps();
+ result_512[4] = _mm512_setzero_ps();
+ result_512[5] = _mm512_setzero_ps();
+ result_512[6] = _mm512_setzero_ps();
+ result_512[7] = _mm512_setzero_ps();
+ result_512[8] = _mm512_setzero_ps();
+ result_512[9] = _mm512_setzero_ps();
+ result_512[10] = _mm512_setzero_ps();
+ result_512[11] = _mm512_setzero_ps();
+ result_512[12] = _mm512_setzero_ps();
+ result_512[13] = _mm512_setzero_ps();
+ result_512[14] = _mm512_setzero_ps();
+ result_512[15] = _mm512_setzero_ps();
+
+ for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
+ // Each two rows are a group for 32-pair bf16 elements
+ arrayA_512_0 = _mm512_loadu_si512(A_addr);
+ arrayA_512_1 = _mm512_loadu_si512(A_addr + 32);
+ A_addr += 64;
+
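+ // Packed B stores one 32-bit (2 x bf16) k-pair per column; broadcast pair i to all lanes for column i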
+ for (int i = 0; i < n; i ++) {
+ _MM512_BROADCASTD_EPI32(B_addr + i*2, arrayB_512[i]);
+ }
+ B_addr += 16;
+
+ for (int i = 0; i < n; i ++) {
+ result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512_0, (__m512bh) arrayB_512[i]);
+ result_512[i+8] = _mm512_dpbf16_ps(result_512[i+8], (__m512bh) arrayA_512_1, (__m512bh) arrayB_512[i]);
+ }
+ }
+
+ if (m != 32) {
+ unsigned short tail_mask = (((unsigned short)0xffff) >> (32-m));
+ for (int i = 0; i < n; i ++) {
+ result_512_tmp_0 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base0, result_512[i+8]);
+ result_512_tmp_1 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base1, result_512[i+8]);
+ STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*i))
+ STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*i + 16), tail_mask)
+ }
+ } else {
+ for (int i = 0; i < n; i ++) {
+ result_512_tmp_0 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base0, result_512[i+8]);
+ result_512_tmp_1 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base1, result_512[i+8]);
+ STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*i))
+ STORE16_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*i + 16))
+ }
+ }
+}
+
+// SBGEMM Kernel for M<=16, N<8, K can be any number
+#ifndef ONE_ALPHA // ALPHA is not ONE
+void sbgemm_block_kernel_nt_16xNxK_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#else // ALPHA is ONE
+void sbgemm_block_kernel_nt_16xNxK_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#endif
+{
+ bfloat16 * A_addr = A;
+ bfloat16 * B_addr = B;
+ float * C_addr = C;
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+
+ __m512i arrayA_512_0;
+ __m512i arrayB_512[8];
+ __m512 result_512[8];
+
+ result_512[0] = _mm512_setzero_ps();
+ result_512[1] = _mm512_setzero_ps();
+ result_512[2] = _mm512_setzero_ps();
+ result_512[3] = _mm512_setzero_ps();
+ result_512[4] = _mm512_setzero_ps();
+ result_512[5] = _mm512_setzero_ps();
+ result_512[6] = _mm512_setzero_ps();
+ result_512[7] = _mm512_setzero_ps();
+
+ for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
+ // Each two rows are a group for 16-pair bf16 elements
+ // Load two rows into a 512 register
+ arrayA_512_0 = _mm512_loadu_si512(A_addr);
+ A_addr += 32;
+
+ for (int i = 0; i < n; i ++) {
+ _MM512_BROADCASTD_EPI32(B_addr + i*2, arrayB_512[i]);
+ }
+ B_addr += 16;
+
+ for (int i = 0; i < n; i ++) {
+ result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512_0, (__m512bh) arrayB_512[i]);
+ }
+ }
+
+ if (m != 16) {
+ unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m));
+ for (int i = 0; i < n; i++) {
+ result_512[i] = _mm512_shuffle_f32x4(result_512[i], result_512[i], 0xd8);
+ STORE16_MASK_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i), tail_mask)
+ }
+ } else {
+ for (int i = 0; i < n; i++) {
+ result_512[i] = _mm512_shuffle_f32x4(result_512[i], result_512[i], 0xd8);
+ STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
+ }
+ }
+}
+
+#ifndef ONE_ALPHA // ALPHA is not ONE
+void sbgemm_blocking_kernel_nt_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
+#else // ALPHA is ONE
+void sbgemm_blocking_kernel_nt_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
+#endif
+{
+ BLASLONG m_step, n_step, k_step, k_step_round32;
+ BLASLONG tag_m_Nx = M & (~(BF16_BLOCK_THRES_M-1));
+
+ BLASLONG n_from, n_to;
+ BLASLONG tag_n_Nx;
+
+ n_from = 0;
+ n_to = (BF16_BLOCK_THRES_N > N) ? N : BF16_BLOCK_THRES_N;
+ tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+
+ k_step = (K > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : K;
+ k_step_round32 = k_step & (~31);
+ k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+
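+ // Large M: walk A in BF16_BLOCK_THRES_M-row blocks, handling the final partial block inside the loop;
+ // small M falls through to the single-block paths below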
+ if (M >= BF16_BLOCK_THRES_M) {
while (n_from < N) {
for (BLASLONG idx_k = 0; idx_k < K;) {
// Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
- COL_MAJOR_INCOPY_KERNEL_Kx32m(k_step, m_step, &A(idx_k, 0), lda, block_A);
- // TODO: MT
+ COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, 32, &A(idx_k, 0), lda, block_A);
for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
// Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
- COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32);
- SBGEMM_BLOCK_KERNEL_32x8x32(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
+ COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32);
+ SBGEMM_BLOCK_KERNEL_NT_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
}
if (tag_n_Nx != n_to) {
n_step = n_to - tag_n_Nx;
- COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
- SBGEMM_BLOCK_KERNEL_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+ COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
+ SBGEMM_BLOCK_KERNEL_NT_32xNxK(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+ }
+
+ for (BLASLONG idx_m = BF16_BLOCK_THRES_M; idx_m < tag_m_Nx; idx_m += BF16_BLOCK_THRES_M) {
+ COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, 32, &A(idx_k, idx_m), lda, block_A);
+ for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+ SBGEMM_BLOCK_KERNEL_NT_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc);
+ }
+
+ if (tag_n_Nx != n_to) {
+ n_step = n_to - tag_n_Nx;
+ SBGEMM_BLOCK_KERNEL_NT_32xNxK(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc);
+ }
+ }
+
+ if (tag_m_Nx != M) {
+ m_step = M - tag_m_Nx;
+ if (m_step > 16) {
+ COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A);
+ for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+ SBGEMM_BLOCK_KERNEL_NT_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
+ }
+
+ if (tag_n_Nx != n_to) {
+ n_step = n_to - tag_n_Nx;
+ SBGEMM_BLOCK_KERNEL_NT_32xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
+ }
+ } else {
+ COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A);
+ for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+ SBGEMM_BLOCK_KERNEL_NT_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
+ }
+
+ if (tag_n_Nx != n_to) {
+ n_step = n_to - tag_n_Nx;
+ SBGEMM_BLOCK_KERNEL_NT_16xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
+ }
+ }
}
idx_k += k_step;
k_step_round32 = k_step & (~31);
k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
}
+
n_from = n_to;
n_to += BF16_BLOCK_THRES_N;
n_to = (n_to > N) ? N : n_to;
tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
}
+ } else {
+ m_step = M;
+ if (m_step > 16) {
+ while (n_from < N) {
+ for (BLASLONG idx_k = 0; idx_k < K;) {
+ // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
+ COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, m_step, &A(idx_k, 0), lda, block_A);
+ for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+ // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
+ COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32);
+ SBGEMM_BLOCK_KERNEL_NT_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
+ }
+
+ if (tag_n_Nx != n_to) {
+ n_step = n_to - tag_n_Nx;
+ COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
+ SBGEMM_BLOCK_KERNEL_NT_32xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+ }
+
+ idx_k += k_step;
+ k_step = K - idx_k;
+ k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
+ k_step_round32 = k_step & (~31);
+ k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+ }
+ n_from = n_to;
+ n_to += BF16_BLOCK_THRES_N;
+ n_to = (n_to > N) ? N : n_to;
+ tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+ }
+ } else {
+ while (n_from < N) {
+ for (BLASLONG idx_k = 0; idx_k < K;) {
+ // m_step <= 16, so use the Kx16 copy kernel
+ COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, 0), lda, block_A);
+ for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+ // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
+ COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32);
+ SBGEMM_BLOCK_KERNEL_NT_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
+ }
+
+ if (tag_n_Nx != n_to) {
+ n_step = n_to - tag_n_Nx;
+ COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
+ SBGEMM_BLOCK_KERNEL_NT_16xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+ }
+
+ idx_k += k_step;
+ k_step = K - idx_k;
+ k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
+ k_step_round32 = k_step & (~31);
+ k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+ }
+ n_from = n_to;
+ n_to += BF16_BLOCK_THRES_N;
+ n_to = (n_to > N) ? N : n_to;
+ tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+ }
+ }
}
}
+/* ----------------------------------------- End of NT kernels --------------------------------------- */
+
+/* --------------------------------------------- TN kernels ------------------------------------------ */
+// SBGEMM Kernel for 16<M<=32, N=8, K can be any number
+#ifndef ONE_ALPHA // ALPHA is not ONE
+void sbgemm_block_kernel_tn_32x8xK_alpha(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#else // ALPHA is ONE
+void sbgemm_block_kernel_tn_32x8xK_one(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#endif
+{
+ bfloat16 * A_addr = A;
+ bfloat16 * B_addr = B;
+ float * C_addr = C;
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+
+ __m512i arrayA_512_0, arrayA_512_1;
+ __m512i arrayB_512_0, arrayB_512_1, arrayB_512_2, arrayB_512_3, arrayB_512_4, arrayB_512_5, arrayB_512_6, arrayB_512_7;
+ __m512 result_512_0, result_512_1, result_512_2, result_512_3, result_512_4, result_512_5, result_512_6, result_512_7,
+ result_512_8, result_512_9, result_512_10, result_512_11, result_512_12, result_512_13, result_512_14, result_512_15;
+
+ result_512_0 = _mm512_setzero_ps();
+ result_512_1 = _mm512_setzero_ps();
+ result_512_2 = _mm512_setzero_ps();
+ result_512_3 = _mm512_setzero_ps();
+ result_512_4 = _mm512_setzero_ps();
+ result_512_5 = _mm512_setzero_ps();
+ result_512_6 = _mm512_setzero_ps();
+ result_512_7 = _mm512_setzero_ps();
+ result_512_8 = _mm512_setzero_ps();
+ result_512_9 = _mm512_setzero_ps();
+ result_512_10 = _mm512_setzero_ps();
+ result_512_11 = _mm512_setzero_ps();
+ result_512_12 = _mm512_setzero_ps();
+ result_512_13 = _mm512_setzero_ps();
+ result_512_14 = _mm512_setzero_ps();
+ result_512_15 = _mm512_setzero_ps();
+
+ for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
+ // Load 32 pairs of BF16 elements from A (32 rows)
+ arrayA_512_0 = _mm512_loadu_si512(A_addr);
+ arrayA_512_1 = _mm512_loadu_si512(A_addr + 32);
+
+ // Load 8 rows of B
+ _MM512_BROADCASTD_EPI32(B_addr + 0, arrayB_512_0);
+ _MM512_BROADCASTD_EPI32(B_addr + 2, arrayB_512_1);
+ _MM512_BROADCASTD_EPI32(B_addr + 4, arrayB_512_2);
+ _MM512_BROADCASTD_EPI32(B_addr + 6, arrayB_512_3);
+ _MM512_BROADCASTD_EPI32(B_addr + 8, arrayB_512_4);
+ _MM512_BROADCASTD_EPI32(B_addr + 10, arrayB_512_5);
+ _MM512_BROADCASTD_EPI32(B_addr + 12, arrayB_512_6);
+ _MM512_BROADCASTD_EPI32(B_addr + 14, arrayB_512_7);
+
+ result_512_0 = _mm512_dpbf16_ps(result_512_0, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_0);
+ result_512_1 = _mm512_dpbf16_ps(result_512_1, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_1);
+ result_512_2 = _mm512_dpbf16_ps(result_512_2, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_2);
+ result_512_3 = _mm512_dpbf16_ps(result_512_3, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_3);
+ result_512_4 = _mm512_dpbf16_ps(result_512_4, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_4);
+ result_512_5 = _mm512_dpbf16_ps(result_512_5, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_5);
+ result_512_6 = _mm512_dpbf16_ps(result_512_6, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_6);
+ result_512_7 = _mm512_dpbf16_ps(result_512_7, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_7);
+
+ result_512_8 = _mm512_dpbf16_ps(result_512_8, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_0);
+ result_512_9 = _mm512_dpbf16_ps(result_512_9, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_1);
+ result_512_10 = _mm512_dpbf16_ps(result_512_10, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_2);
+ result_512_11 = _mm512_dpbf16_ps(result_512_11, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_3);
+ result_512_12 = _mm512_dpbf16_ps(result_512_12, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_4);
+ result_512_13 = _mm512_dpbf16_ps(result_512_13, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_5);
+ result_512_14 = _mm512_dpbf16_ps(result_512_14, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_6);
+ result_512_15 = _mm512_dpbf16_ps(result_512_15, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_7);
+
+ // Advance packed B: 8 columns x 2 bf16 per k-pair
+ B_addr += 16;
+ // Advance packed A: 32 rows x 2 bf16 per k-pair
+ A_addr += 64;
+ }
+
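+ // Unlike the NN kernels, no lane shuffle is needed before the store: result_512_0..7 hold rows 0..15
+ // and result_512_8..15 hold rows 16..31 of each column, written directly to C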
+ if (m != 32) {
+ unsigned short tail_mask_value = (((unsigned short)0xffff) >> (32-m));
+ __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
+ STORE16_COMPLETE_RESULT(result_512_0, (C_addr))
+ STORE16_MASK_COMPLETE_RESULT(result_512_8, (C_addr + 16), tail_mask)
+ STORE16_COMPLETE_RESULT(result_512_1, (C_addr + ldc))
+ STORE16_MASK_COMPLETE_RESULT(result_512_9, (C_addr + ldc + 16), tail_mask)
+ STORE16_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2))
+ STORE16_MASK_COMPLETE_RESULT(result_512_10, (C_addr + ldc*2 + 16), tail_mask)
+ STORE16_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3))
+ STORE16_MASK_COMPLETE_RESULT(result_512_11, (C_addr + ldc*3 + 16), tail_mask)
+ STORE16_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4))
+ STORE16_MASK_COMPLETE_RESULT(result_512_12, (C_addr + ldc*4 + 16), tail_mask)
+ STORE16_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5))
+ STORE16_MASK_COMPLETE_RESULT(result_512_13, (C_addr + ldc*5 + 16), tail_mask)
+ STORE16_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6))
+ STORE16_MASK_COMPLETE_RESULT(result_512_14, (C_addr + ldc*6 + 16), tail_mask)
+ STORE16_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7))
+ STORE16_MASK_COMPLETE_RESULT(result_512_15, (C_addr + ldc*7 + 16), tail_mask)
+ } else {
+ STORE16_COMPLETE_RESULT(result_512_0, (C_addr))
+ STORE16_COMPLETE_RESULT(result_512_8, (C_addr + 16))
+ STORE16_COMPLETE_RESULT(result_512_1, (C_addr + ldc))
+ STORE16_COMPLETE_RESULT(result_512_9, (C_addr + ldc + 16))
+ STORE16_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2))
+ STORE16_COMPLETE_RESULT(result_512_10, (C_addr + ldc*2 + 16))
+ STORE16_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3))
+ STORE16_COMPLETE_RESULT(result_512_11, (C_addr + ldc*3 + 16))
+ STORE16_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4))
+ STORE16_COMPLETE_RESULT(result_512_12, (C_addr + ldc*4 + 16))
+ STORE16_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5))
+ STORE16_COMPLETE_RESULT(result_512_13, (C_addr + ldc*5 + 16))
+ STORE16_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6))
+ STORE16_COMPLETE_RESULT(result_512_14, (C_addr + ldc*6 + 16))
+ STORE16_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7))
+ STORE16_COMPLETE_RESULT(result_512_15, (C_addr + ldc*7 + 16))
+ }
+}
+
+// SBGEMM Kernel for M<=16, N=8, K can be any number
+#ifndef ONE_ALPHA // ALPHA is not ONE
+void sbgemm_block_kernel_tn_16x8xK_alpha(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#else // ALPHA is ONE
+void sbgemm_block_kernel_tn_16x8xK_one(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#endif
+{
+ bfloat16 * A_addr = A;
+ bfloat16 * B_addr = B;
+ float * C_addr = C;
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+
+ __m512i arrayA_512_0;
+ __m512i arrayB_512_0, arrayB_512_1, arrayB_512_2, arrayB_512_3, arrayB_512_4, arrayB_512_5, arrayB_512_6, arrayB_512_7;
+ __m512 result_512_0, result_512_1, result_512_2, result_512_3, result_512_4, result_512_5, result_512_6, result_512_7;
+
+ result_512_0 = _mm512_setzero_ps();
+ result_512_1 = _mm512_setzero_ps();
+ result_512_2 = _mm512_setzero_ps();
+ result_512_3 = _mm512_setzero_ps();
+ result_512_4 = _mm512_setzero_ps();
+ result_512_5 = _mm512_setzero_ps();
+ result_512_6 = _mm512_setzero_ps();
+ result_512_7 = _mm512_setzero_ps();
+
+ for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
+ // Load 16 pairs of BF16 elements from A (16 rows)
+ arrayA_512_0 = _mm512_loadu_si512(A_addr + 0);
+
+ // Load 8 rows of B
+ _MM512_BROADCASTD_EPI32(B_addr + 0, arrayB_512_0);
+ _MM512_BROADCASTD_EPI32(B_addr + 2, arrayB_512_1);
+ _MM512_BROADCASTD_EPI32(B_addr + 4, arrayB_512_2);
+ _MM512_BROADCASTD_EPI32(B_addr + 6, arrayB_512_3);
+ _MM512_BROADCASTD_EPI32(B_addr + 8, arrayB_512_4);
+ _MM512_BROADCASTD_EPI32(B_addr + 10, arrayB_512_5);
+ _MM512_BROADCASTD_EPI32(B_addr + 12, arrayB_512_6);
+ _MM512_BROADCASTD_EPI32(B_addr + 14, arrayB_512_7);
+
+ result_512_0 = _mm512_dpbf16_ps(result_512_0, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_0);
+ result_512_1 = _mm512_dpbf16_ps(result_512_1, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_1);
+ result_512_2 = _mm512_dpbf16_ps(result_512_2, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_2);
+ result_512_3 = _mm512_dpbf16_ps(result_512_3, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_3);
+ result_512_4 = _mm512_dpbf16_ps(result_512_4, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_4);
+ result_512_5 = _mm512_dpbf16_ps(result_512_5, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_5);
+ result_512_6 = _mm512_dpbf16_ps(result_512_6, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_6);
+ result_512_7 = _mm512_dpbf16_ps(result_512_7, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_7);
+
+ // Advance packed B: 8 columns x 2 bf16 per k-pair
+ B_addr += 16;
+ // Advance packed A: 16 rows x 2 bf16 per k-pair
+ A_addr += 32;
+ }
+
+ if (m != 16) {
+ unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-m));
+ __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
+ STORE16_MASK_COMPLETE_RESULT(result_512_0, (C_addr), tail_mask)
+ STORE16_MASK_COMPLETE_RESULT(result_512_1, (C_addr + ldc), tail_mask)
+ STORE16_MASK_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2), tail_mask)
+ STORE16_MASK_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3), tail_mask)
+ STORE16_MASK_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4), tail_mask)
+ STORE16_MASK_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5), tail_mask)
+ STORE16_MASK_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6), tail_mask)
+ STORE16_MASK_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7), tail_mask)
+ } else {
+ STORE16_COMPLETE_RESULT(result_512_0, (C_addr))
+ STORE16_COMPLETE_RESULT(result_512_1, (C_addr + ldc))
+ STORE16_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2))
+ STORE16_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3))
+ STORE16_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4))
+ STORE16_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5))
+ STORE16_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6))
+ STORE16_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7))
+ }
+}
+
+// SBGEMM Kernel for 16<M<=32, N<8, K can be any number, but the processing will take 32 as a base
+#ifndef ONE_ALPHA // ALPHA is not ONE
+void sbgemm_block_kernel_tn_32xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#else // ALPHA is ONE
+void sbgemm_block_kernel_tn_32xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#endif
+{
+ bfloat16 * A_addr = A;
+ bfloat16 * B_addr = B;
+ float * C_addr = C;
+
+ int SHUFFLE_MAGIC_NO = 0x39;
+ BLASLONG tag_k_32x = k & (~31);
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+
+ __m512i arrayA_512[2];
+ __m512i arrayB_512[8];
+ __m512 result_512[16];
+
+ for (int i = 0; i < 16; i++) {
+ result_512[i] = _mm512_setzero_ps();
+ }
+
+ for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
+ // Load B with unroll n
+ for (int i = 0; i < n; i ++) {
+ arrayB_512[i] = _mm512_loadu_si512(B_addr);
+ B_addr += 32;
+ }
+
+ for (BLASLONG idx = 0; idx < 32;) {
+ // Each two rows are a group for 32-pair bf16 elements
+ arrayA_512[0] = _mm512_loadu_si512(A_addr);
+ arrayA_512[1] = _mm512_loadu_si512(A_addr + 32);
+ A_addr += 64;
+
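+ // _mm512_castsi512_si128 takes the low 128 bits of arrayB_512[i] and
+ // _mm512_broadcastd_epi32 replicates their lowest dword (the current bf16
+ // pair of B) across all 16 lanes.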
+ for (int i = 0; i < n; i++) {
+ result_512[i] = _mm512_dpbf16_ps(result_512[i] , (__m512bh) arrayA_512[0], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
+ result_512[i+8] = _mm512_dpbf16_ps(result_512[i+8], (__m512bh) arrayA_512[1], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
+ arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
+ }
+
+ idx += 2;
+ // Every 4 iterations (idx becomes a multiple of 8), rotate to the next 128-bit lane of the arrayB registers
+ if ((idx & (~7)) == idx) {
+ for (int i = 0; i < n; i++) {
+ arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO);
+ }
+ }
+ }
+ }
+
+ if (tag_k_32x != k) {
+ // Load B with unroll n
+ for (int i = 0; i < n; i ++) {
+ arrayB_512[i] = _mm512_loadu_si512(B_addr);
+ B_addr += 32;
+ }
+
+ BLASLONG width = k - tag_k_32x;
+ for (BLASLONG idx = 0; idx < width;) {
+ // Every two rows of A form one group of 32 bf16 pairs (64 elements)
+ arrayA_512[0] = _mm512_loadu_si512(A_addr);
+ arrayA_512[1] = _mm512_loadu_si512(A_addr + 32);
+ A_addr += 64;
+
+ for (int i = 0; i < n; i++) {
+ result_512[i] = _mm512_dpbf16_ps(result_512[i] , (__m512bh) arrayA_512[0], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
+ result_512[i+8] = _mm512_dpbf16_ps(result_512[i+8], (__m512bh) arrayA_512[1], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
+ arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
+ }
+
+ idx += 2;
+ // Every 4 iterations (idx becomes a multiple of 8), rotate to the next 128-bit lane of the arrayB registers
+ if ((idx & (~7)) == idx) {
+ for (int i = 0; i < n; i++) {
+ arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO);
+ }
+ }
+ }
+ }
+
+ if (m != 32) {
+ unsigned short tail_mask = (((unsigned short)0xffff) >> (32-m));
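+ // result_512[i] holds rows 0..15 and result_512[i+8] rows 16..31 of each
+ // output column, so only the upper store needs the (m-16)-bit tail mask.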
+ for (int i = 0; i < n; i++) {
+ STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
+ STORE16_MASK_COMPLETE_RESULT(result_512[i+8], (C_addr + ldc*i + 16), tail_mask)
+ }
+ } else {
+ for (int i = 0; i < n; i++) {
+ STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
+ STORE16_COMPLETE_RESULT(result_512[i+8], (C_addr + ldc*i + 16))
+ }
+ }
+}
+
+// SBGEMM Kernel for M<=16, N<8; K can be any number but is processed in chunks of 32
+#ifndef ONE_ALPHA // ALPHA is not ONE
+void sbgemm_block_kernel_tn_16xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#else // ALPHA is ONE
+void sbgemm_block_kernel_tn_16xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#endif
+{
+ bfloat16 * A_addr = A;
+ bfloat16 * B_addr = B;
+ float * C_addr = C;
+
+ int SHUFFLE_MAGIC_NO = 0x39;
+ BLASLONG tag_k_32x = k & (~31);
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+
+ __m512i arrayA_512;
+ __m512i arrayB_512[8];
+ __m512 result_512[8];
+
+ for (int i = 0; i < 8; i++) {
+ result_512[i] = _mm512_setzero_ps();
+ }
+
+ for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
+ // Load B with unroll n
+ for (int i = 0; i < n; i ++) {
+ arrayB_512[i] = _mm512_loadu_si512(B_addr);
+ B_addr += 32;
+ }
+
+ for (BLASLONG idx = 0; idx < 32;) {
+ // Every two rows of A form one group of 16 bf16 pairs (32 elements)
+ arrayA_512 = _mm512_loadu_si512(A_addr);
+ A_addr += 32;
+
+ for (int i = 0; i < n; i++) {
+ result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
+ arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
+ }
+
+ idx += 2;
+ // Every 4 iterations (idx becomes a multiple of 8), rotate to the next 128-bit lane of the arrayB registers
+ if ((idx & (~7)) == idx) {
+ for (int i = 0; i < n; i++) {
+ arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO);
+ }
+ }
+ }
+ }
+
+ if (tag_k_32x != k) {
+ // Load B with unroll n
+ for (int i = 0; i < n; i ++) {
+ arrayB_512[i] = _mm512_loadu_si512(B_addr);
+ B_addr += 32;
+ }
+
+ BLASLONG width = k - tag_k_32x;
+ for (BLASLONG idx = 0; idx < width;) {
+ // Every two rows of A form one group of 16 bf16 pairs (32 elements)
+ arrayA_512 = _mm512_loadu_si512(A_addr);
+ A_addr += 32;
+
+ for (int i = 0; i < n; i++) {
+ result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
+ arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
+ }
+
+ idx += 2;
+ // Every 4 iterations (idx becomes a multiple of 8), rotate to the next 128-bit lane of the arrayB registers
+ if ((idx & (~7)) == idx) {
+ for (int i = 0; i < n; i++) {
+ arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO);
+ }
+ }
+ }
+ }
+
+ if (m != 16) {
+ unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m));
+ for (int i = 0; i < n; i++) {
+ STORE16_MASK_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i), tail_mask)
+ }
+ } else {
+ for (int i = 0; i < n; i++) {
+ STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
+ }
+ }
+}
+
+#ifndef ONE_ALPHA // ALPHA is not ONE
+void sbgemm_blocking_kernel_tn_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
+#else // ALPHA is ONE
+void sbgemm_blocking_kernel_tn_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
+#endif
+{
+ BLASLONG m_step, n_step, k_step, k_step_round32;
+ BLASLONG tag_m_Nx = M & (~(BF16_BLOCK_THRES_M-1));
+
+ BLASLONG n_from, n_to;
+ BLASLONG tag_n_Nx;
+
+ n_from = 0;
+ n_to = (BF16_BLOCK_THRES_N > N) ? N : BF16_BLOCK_THRES_N;
+ tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+
+ k_step = (K > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : K;
+ k_step_round32 = k_step & (~31);
+ k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
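+ // k_step_round32 is k_step rounded up to the next multiple of 32; each packed
+ // column of B occupies k_step_round32 bf16 elements of block_B, hence the
+ // (idx_n - n_from) * k_step_round32 offsets below.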
+
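+ // Blocking: loop over N in blocks of BF16_BLOCK_THRES_N columns and K in
+ // blocks of up to BF16_BLOCK_THRES_K. Within one (N, K) block, B is packed
+ // once per 8-column step and reused for every 32-row panel of A, while A is
+ // repacked for each panel.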
+ if (M >= BF16_BLOCK_THRES_M) {
+ while (n_from < N) {
+ for (BLASLONG idx_k = 0; idx_k < K;) {
+ // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
+ COL_MAJOR_ITCOPY_KERNEL_Kx32(k_step, &A(0, idx_k), lda, block_A);
+ for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+ // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
+ COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32);
+ SBGEMM_BLOCK_KERNEL_TN_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); // TODO how to process m
+ }
+
+ if (tag_n_Nx != n_to) {
+ n_step = n_to - tag_n_Nx;
+ COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
+ SBGEMM_BLOCK_KERNEL_TN_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+ }
+
+ for (BLASLONG idx_m = BF16_BLOCK_THRES_M; idx_m < tag_m_Nx; idx_m += BF16_BLOCK_THRES_M) {
+ COL_MAJOR_ITCOPY_KERNEL_Kx32(k_step, &A(idx_m, idx_k), lda, block_A);
+ for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+ SBGEMM_BLOCK_KERNEL_TN_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc);
+ }
+
+ if (tag_n_Nx != n_to) {
+ n_step = n_to - tag_n_Nx;
+ SBGEMM_BLOCK_KERNEL_TN_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc);
+ }
+ }
+
+ if (tag_m_Nx != M) {
+ m_step = M - tag_m_Nx;
+ if (m_step > 16) {
+ COL_MAJOR_ITCOPY_KERNEL_Kx32m(m_step, k_step, &A(tag_m_Nx, idx_k), lda, block_A);
+ for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+ SBGEMM_BLOCK_KERNEL_TN_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
+ }
+
+ if (tag_n_Nx != n_to) {
+ n_step = n_to - tag_n_Nx;
+ SBGEMM_BLOCK_KERNEL_TN_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
+ }
+ } else {
+ COL_MAJOR_ITCOPY_KERNEL_Kx16m(m_step, k_step, &A(tag_m_Nx, idx_k), lda, block_A);
+ for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+ SBGEMM_BLOCK_KERNEL_TN_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
+ }
+
+ if (tag_n_Nx != n_to) {
+ n_step = n_to - tag_n_Nx;
+ SBGEMM_BLOCK_KERNEL_TN_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
+ }
+ }
+ }
+
+ idx_k += k_step;
+ k_step = K - idx_k;
+ k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
+ k_step_round32 = k_step & (~31);
+ k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+ }
+
+ n_from = n_to;
+ n_to += BF16_BLOCK_THRES_N;
+ n_to = (n_to > N) ? N : n_to;
+ tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+ }
+ } else {
+ m_step = M;
+ if (m_step > 16) {
+ while (n_from < N) {
+ for (BLASLONG idx_k = 0; idx_k < K;) {
+ // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
+ COL_MAJOR_ITCOPY_KERNEL_Kx32m(m_step, k_step, &A(0, idx_k), lda, block_A);
+ for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+ // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
+ COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32);
+ SBGEMM_BLOCK_KERNEL_TN_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
+ }
+
+ if (tag_n_Nx != n_to) {
+ n_step = n_to - tag_n_Nx;
+ COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
+ SBGEMM_BLOCK_KERNEL_TN_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+ }
+
+ idx_k += k_step;
+ k_step = K - idx_k;
+ k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
+ k_step_round32 = k_step & (~31);
+ k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+ }
+ n_from = n_to;
+ n_to += BF16_BLOCK_THRES_N;
+ n_to = (n_to > N) ? N : n_to;
+ tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+ }
+ } else {
+ while (n_from < N) {
+ for (BLASLONG idx_k = 0; idx_k < K;) {
+ // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
+ COL_MAJOR_ITCOPY_KERNEL_Kx16m(m_step, k_step, &A(0, idx_k), lda, block_A);
+ for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+ // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
+ COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32);
+ SBGEMM_BLOCK_KERNEL_TN_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
+ }
+
+ if (tag_n_Nx != n_to) {
+ n_step = n_to - tag_n_Nx;
+ COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
+ SBGEMM_BLOCK_KERNEL_TN_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+ }
+
+ idx_k += k_step;
+ k_step = K - idx_k;
+ k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
+ k_step_round32 = k_step & (~31);
+ k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+ }
+ n_from = n_to;
+ n_to += BF16_BLOCK_THRES_N;
+ n_to = (n_to > N) ? N : n_to;
+ tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+ }
+ }
+ }
+}
+/* ----------------------------------------- End of TN kernels --------------------------------------- */
+
+/* --------------------------------------------- TT kernels ------------------------------------------ */
+// SBGEMM Kernel for 16<M<=32, N<8, K can be any number
+#ifndef ONE_ALPHA // ALPHA is not ONE
+void sbgemm_block_kernel_tt_32xNxK_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#else // ALPHA is ONE
+void sbgemm_block_kernel_tt_32xNxK_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#endif
+{
+ bfloat16 * A_addr = A;
+ bfloat16 * B_addr = B;
+ float * C_addr = C;
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+
+ __m512i arrayA_512_0, arrayA_512_1;
+ __m512i arrayB_512[8];
+ __m512 result_512[16];
+
+ for (int i = 0; i < 16; i++) {
+ result_512[i] = _mm512_setzero_ps();
+ }
+
+ for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
+ // Every two rows of A form one group of 32 bf16 pairs (64 elements)
+ arrayA_512_0 = _mm512_loadu_si512(A_addr);
+ arrayA_512_1 = _mm512_loadu_si512(A_addr + 32);
+ A_addr += 64;
+
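+ // Each broadcast replicates one dword of packed B (the bf16 pair of one
+ // column for the current two k values) across all 16 lanes of arrayB_512[i].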
+ for (int i = 0; i < n; i ++) {
+ _MM512_BROADCASTD_EPI32(B_addr + i*2, arrayB_512[i]);
+ }
+ B_addr += 16;
+
+ for (int i = 0; i < n; i ++) {
+ result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512_0, (__m512bh) arrayB_512[i]);
+ result_512[i+8] = _mm512_dpbf16_ps(result_512[i+8], (__m512bh) arrayA_512_1, (__m512bh) arrayB_512[i]);
+ }
+ }
+
+ if (m != 32) {
+ unsigned short tail_mask = (((unsigned short)0xffff) >> (32-m));
+ for (int i = 0; i < n; i ++) {
+ STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
+ STORE16_MASK_COMPLETE_RESULT(result_512[i+8], (C_addr + ldc*i + 16), tail_mask)
+ }
+ } else {
+ for (int i = 0; i < n; i ++) {
+ STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
+ STORE16_COMPLETE_RESULT(result_512[i+8], (C_addr + ldc*i + 16))
+ }
+ }
+}
+
+// SBGEMM Kernel for M<=16, N<8, K can be any number
+#ifndef ONE_ALPHA // ALPHA is not ONE
+void sbgemm_block_kernel_tt_16xNxK_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#else // ALPHA is ONE
+void sbgemm_block_kernel_tt_16xNxK_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
+#endif
+{
+ bfloat16 * A_addr = A;
+ bfloat16 * B_addr = B;
+ float * C_addr = C;
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+
+ __m512i arrayA_512_0;
+ __m512i arrayB_512[8];
+ __m512 result_512[8];
+
+ for (int i = 0; i < 8; i++) {
+ result_512[i] = _mm512_setzero_ps();
+ }
+
+ for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
+ // Every two rows of A form one group of 16 bf16 pairs, loaded as a single 512-bit register
+ arrayA_512_0 = _mm512_loadu_si512(A_addr);
+ A_addr += 32;
+
+ for (int i = 0; i < n; i ++) {
+ _MM512_BROADCASTD_EPI32(B_addr + i*2, arrayB_512[i]);
+ }
+ B_addr += 16;
+
+ for (int i = 0; i < n; i ++) {
+ result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512_0, (__m512bh) arrayB_512[i]);
+ }
+ }
+
+ if (m != 16) {
+ unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m));
+ for (int i = 0; i < n; i++) {
+ STORE16_MASK_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i), tail_mask)
+ }
+ } else {
+ for (int i = 0; i < n; i++) {
+ STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
+ }
+ }
+}
+
+#ifndef ONE_ALPHA // ALPHA is not ONE
+void sbgemm_blocking_kernel_tt_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
+#else // ALPHA is ONE
+void sbgemm_blocking_kernel_tt_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
+#endif
+{
+ BLASLONG m_step, n_step, k_step, k_step_round32;
+ BLASLONG tag_m_Nx = M & (~(BF16_BLOCK_THRES_M-1));
+
+ BLASLONG n_from, n_to;
+ BLASLONG tag_n_Nx;
+
+ n_from = 0;
+ n_to = (BF16_BLOCK_THRES_N > N) ? N : BF16_BLOCK_THRES_N;
+ tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+
+ k_step = (K > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : K;
+ k_step_round32 = k_step & (~31);
+ k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+
+ if (M >= BF16_BLOCK_THRES_M) {
+ while (n_from < N) {
+ for (BLASLONG idx_k = 0; idx_k < K;) {
+ // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
+ COL_MAJOR_ITCOPY_KERNEL_Kx32(k_step, &A(0, idx_k), lda, block_A);
+ for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+ // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
+ COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32);
+ SBGEMM_BLOCK_KERNEL_TT_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
+ }
+
+ if (tag_n_Nx != n_to) {
+ n_step = n_to - tag_n_Nx;
+ COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
+ SBGEMM_BLOCK_KERNEL_TT_32xNxK(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+ }
+
+ for (BLASLONG idx_m = BF16_BLOCK_THRES_M; idx_m < tag_m_Nx; idx_m += BF16_BLOCK_THRES_M) {
+ COL_MAJOR_ITCOPY_KERNEL_Kx32(k_step, &A(idx_m, idx_k), lda, block_A);
+ for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+ SBGEMM_BLOCK_KERNEL_TT_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc);
+ }
+
+ if (tag_n_Nx != n_to) {
+ n_step = n_to - tag_n_Nx;
+ SBGEMM_BLOCK_KERNEL_TT_32xNxK(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc);
+ }
+ }
+
+ if (tag_m_Nx != M) {
+ m_step = M - tag_m_Nx;
+ if (m_step > 16) {
+ COL_MAJOR_ITCOPY_KERNEL_Kx32m(m_step, k_step, &A(tag_m_Nx, idx_k), lda, block_A);
+ for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+ SBGEMM_BLOCK_KERNEL_TT_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
+ }
+
+ if (tag_n_Nx != n_to) {
+ n_step = n_to - tag_n_Nx;
+ SBGEMM_BLOCK_KERNEL_TT_32xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
+ }
+ } else {
+ COL_MAJOR_ITCOPY_KERNEL_Kx16m(m_step, k_step, &A(tag_m_Nx, idx_k), lda, block_A);
+ for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+ SBGEMM_BLOCK_KERNEL_TT_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
+ }
+
+ if (tag_n_Nx != n_to) {
+ n_step = n_to - tag_n_Nx;
+ SBGEMM_BLOCK_KERNEL_TT_16xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
+ }
+ }
+ }
+
+ idx_k += k_step;
+ k_step = K - idx_k;
+ k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
+ k_step_round32 = k_step & (~31);
+ k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+ }
+
+ n_from = n_to;
+ n_to += BF16_BLOCK_THRES_N;
+ n_to = (n_to > N) ? N : n_to;
+ tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+ }
+ } else {
+ m_step = M;
+ if (m_step > 16) {
+ while (n_from < N) {
+ for (BLASLONG idx_k = 0; idx_k < K;) {
+ // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
+ COL_MAJOR_ITCOPY_KERNEL_Kx32m(m_step, k_step, &A(0, idx_k), lda, block_A);
+ for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+ // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
+ COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32);
+ SBGEMM_BLOCK_KERNEL_TT_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
+ }
+
+ if (tag_n_Nx != n_to) {
+ n_step = n_to - tag_n_Nx;
+ COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
+ SBGEMM_BLOCK_KERNEL_TT_32xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+ }
+
+ idx_k += k_step;
+ k_step = K - idx_k;
+ k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
+ k_step_round32 = k_step & (~31);
+ k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+ }
+ n_from = n_to;
+ n_to += BF16_BLOCK_THRES_N;
+ n_to = (n_to > N) ? N : n_to;
+ tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+ }
+ } else {
+ while (n_from < N) {
+ for (BLASLONG idx_k = 0; idx_k < K;) {
+ // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
+ COL_MAJOR_ITCOPY_KERNEL_Kx16m(m_step, k_step, &A(0, idx_k), lda, block_A);
+ for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
+ // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
+ COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32);
+ SBGEMM_BLOCK_KERNEL_TT_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
+ }
+
+ if (tag_n_Nx != n_to) {
+ n_step = n_to - tag_n_Nx;
+ COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
+ SBGEMM_BLOCK_KERNEL_TT_16xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
+ }
+
+ idx_k += k_step;
+ k_step = K - idx_k;
+ k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
+ k_step_round32 = k_step & (~31);
+ k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
+ }
+ n_from = n_to;
+ n_to += BF16_BLOCK_THRES_N;
+ n_to = (n_to > N) ? N : n_to;
+ tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
+ }
+ }
+ }
+}
+/* ----------------------------------------- End of TT kernels --------------------------------------- */
#ifndef ONE_ALPHA // ALPHA is not ONE
void sbgemm_internal_kernel_alpha(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, float *C, OPENBLAS_CONST blasint ldc)
#endif
{
- bfloat16 block_A[BF16_BLOCK_THRES_K * BF16_BLOCK_THRES_M];
- bfloat16 block_B[BF16_BLOCK_THRES_N * BF16_BLOCK_THRES_K];
-
- // TODO: assume no trans for both A and B, to complement these scenarios later
if (Order == CblasColMajor) {
- SBGEMM_BLOCKING_KERNEL_2(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B);
+ if (TransA == CblasNoTrans) {
+ if (TransB == CblasNoTrans) {
+ SBGEMM_BLOCKING_KERNEL_NN(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B);
+ } else if (TransB == CblasTrans) {
+ SBGEMM_BLOCKING_KERNEL_NT(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B);
+ }
+ } else {
+ if (TransB == CblasNoTrans) {
+ SBGEMM_BLOCKING_KERNEL_TN(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B);
+ } else if (TransB == CblasTrans) {
+ SBGEMM_BLOCKING_KERNEL_TT(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B);
+ }
+ }
} else {
-
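+ // A row-major matrix is the transpose of its column-major view, so the
+ // row-major cases run the column-major kernels with A/B and M/N swapped:
+ // TransB picks the first letter of the kernel, TransA the second.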
+ if (TransA == CblasNoTrans) {
+ if (TransB == CblasNoTrans) {
+ SBGEMM_BLOCKING_KERNEL_NN(N, M, K, alpha, B, ldb, A, lda, C, ldc, block_A, block_B);
+ } else if (TransB == CblasTrans) {
+ SBGEMM_BLOCKING_KERNEL_TN(N, M, K, alpha, B, ldb, A, lda, C, ldc, block_A, block_B);
+ }
+ } else {
+ if (TransB == CblasNoTrans) {
+ SBGEMM_BLOCKING_KERNEL_NT(N, M, K, alpha, B, ldb, A, lda, C, ldc, block_A, block_B);
+ } else if (TransB == CblasTrans) {
+ SBGEMM_BLOCKING_KERNEL_TT(N, M, K, alpha, B, ldb, A, lda, C, ldc, block_A, block_B);
+ }
+ }
}
-}
\ No newline at end of file
+}