From 5d86becdaec262e8a2869ce909d94bec881fbfb6 Mon Sep 17 00:00:00 2001 From: "Chen, Guobing" Date: Thu, 5 Aug 2021 11:11:14 +0800 Subject: [PATCH] Add all SBGEMM kernels for IA AVX512-BF16 based platforms Added all SBGEMM kernels including NN/NT/TN/TT for both ColMajor and RowMajor, based on AVX512-BF16 ISA set on IA. Signed-off-by: Chen, Guobing --- kernel/x86_64/bf16_common_macros.h | 52 + kernel/x86_64/sbgemm_block_microk_cooperlake.c | 2024 ++++++++++++++++++--- kernel/x86_64/sbgemm_microk_cooperlake_template.c | 1737 +++++++++++++++--- 3 files changed, 3268 insertions(+), 545 deletions(-) diff --git a/kernel/x86_64/bf16_common_macros.h b/kernel/x86_64/bf16_common_macros.h index 1014ecc..78db7ab 100644 --- a/kernel/x86_64/bf16_common_macros.h +++ b/kernel/x86_64/bf16_common_macros.h @@ -29,6 +29,16 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#define _MM512_BROADCASTD_EPI32(addr, zmm) \ + __asm__ ("vpbroadcastd (%1), %0;" \ + : "=v" (zmm) \ + : "r" (addr) ) + +#define PREFETCH_T0(addr) \ + __asm__ ("prefetcht0 (%0);" \ + : \ + : "r" (addr) ) + #define EXTRACT_LOW_256_FROM_512_2X(reg256, reg512) \ reg256##_0 = _mm512_castps512_ps256(reg512##_0); \ reg256##_1 = _mm512_castps512_ps256(reg512##_1); @@ -721,6 +731,48 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. _mm_mask_storeu_ps(targetAddr, mask, regResult); +/* Store 16 (result + y) to y +*/ +#define STORE16_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr) \ + regResult = _mm512_add_ps(regResult, _mm512_loadu_ps(targetAddr)); \ + _mm512_storeu_ps(targetAddr, regResult); + + +/* Masked store 16 (result + y) to y +*/ +#define STORE16_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask) \ + regResult = _mm512_add_ps(regResult, _mm512_maskz_loadu_ps(mask, targetAddr)); \ + _mm512_mask_storeu_ps(targetAddr, mask, regResult); + + +/* Store 8 (result + y) to y +*/ +#define STORE8_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr) \ + regResult = _mm256_add_ps(regResult, _mm256_loadu_ps(targetAddr)); \ + _mm256_storeu_ps(targetAddr, regResult); + + +/* Masked store 8 (result + y) to y +*/ +#define STORE8_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask) \ + regResult = _mm256_add_ps(regResult, _mm256_maskz_loadu_ps(mask, targetAddr)); \ + _mm256_mask_storeu_ps(targetAddr, mask, regResult); + + +/* Store 4 (result + y) to y +*/ +#define STORE4_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr) \ + regResult = _mm_add_ps(regResult, _mm_loadu_ps(targetAddr)); \ + _mm_storeu_ps(targetAddr, regResult); + + +/* Masked store 4 (result + y) to y +*/ +#define STORE4_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask) \ + regResult = _mm_add_ps(regResult, _mm_maskz_loadu_ps(mask, targetAddr)); \ + _mm_mask_storeu_ps(targetAddr, mask, regResult); + + /* Store 16 (alpha * result) to y */ #define STORE16_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \ diff --git a/kernel/x86_64/sbgemm_block_microk_cooperlake.c b/kernel/x86_64/sbgemm_block_microk_cooperlake.c index 2376fed..147c5eb 100644 --- a/kernel/x86_64/sbgemm_block_microk_cooperlake.c +++ b/kernel/x86_64/sbgemm_block_microk_cooperlake.c @@ -1,4 +1,4 @@ -#include "sbgemm.h" +//#include "sbgemm.h" #include // Walk around those intrinsics that missed by compiler @@ -7,420 +7,1878 @@ #define MM256_STOREU_EPI16(addr, reg) \ _mm256_mask_storeu_epi16((addr), ~0, (reg)) -#include -void print_block(BLASLONG m, BLASLONG n, bfloat16 * mat) -{ - printf("---- BLOCK %ld x %ld ----\n", m, n); - for (BLASLONG i=0; i> (32-m)); __m512i 
array512_0, array512_1, array512_2, array512_3; - BLASLONG idx_src_base0, idx_src_base1; - BLASLONG idx_target_base0, idx_target_base1; + bfloat16 * src_addr0, * src_addr1; + bfloat16 * dst_addr0, * dst_addr1; BLASLONG LDA_2x = 2*lda; BLASLONG BF16_BLOCK_T_M_2x = 2*32; - idx_src_base0 = 0; - idx_src_base1 = lda; - idx_target_base0 = 0; - idx_target_base1 = 32; - for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) { - array512_0 = _mm512_loadu_si512(&A[idx_src_base0]); - array512_1 = _mm512_loadu_si512(&A[idx_src_base1]); - array512_2 = _mm512_unpacklo_epi16(array512_0, array512_1); - array512_3 = _mm512_unpackhi_epi16(array512_0, array512_1); - _mm512_storeu_si512(&block_A[idx_target_base0], array512_2); - _mm512_storeu_si512(&block_A[idx_target_base1], array512_3); - - idx_src_base0 += LDA_2x; - idx_src_base1 += LDA_2x; - idx_target_base0 += BF16_BLOCK_T_M_2x; - idx_target_base1 += BF16_BLOCK_T_M_2x; - } - - if (tag_k_2x != k) { - __m512i ZERO512 = _mm512_setzero_si512(); - array512_0 = _mm512_loadu_si512(&A[idx_src_base0]); - array512_2 = _mm512_unpacklo_epi16(array512_0, ZERO512); - array512_3 = _mm512_unpackhi_epi16(array512_0, ZERO512); - _mm512_storeu_si512(&block_A[idx_target_base0], array512_2); - _mm512_storeu_si512(&block_A[idx_target_base1], array512_3); - } - -#ifdef DEBUG_PROFILE - print_block(BF16_BLOCK_THRES_K, BF16_BLOCK_THRES_M, block_A); -#endif -} - -void COL_MAJOR_INCOPY_KERNEL_Kx32m(BLASLONG k, BLASLONG m, bfloat16 * A, BLASLONG lda, bfloat16 * block_A) -{ - BLASLONG tag_k_2x = k & (~1); - unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-m)); - __mmask32 tail_mask = *((__mmask32*) &tail_mask_value); - - __m512i array512_0, array512_1, array512_2, array512_3; - BLASLONG idx_src_base0, idx_src_base1; - BLASLONG idx_target_base0, idx_target_base1; + src_addr0 = A; + src_addr1 = A + lda; + dst_addr0 = block_A; + dst_addr1 = block_A + 32; - BLASLONG LDA_2x = 2*lda; - BLASLONG BF16_BLOCK_T_M_2x = 2*32; - idx_src_base0 = 0; - idx_src_base1 = lda; - idx_target_base0 = 0; - idx_target_base1 = 32; for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) { - array512_0 = _mm512_maskz_loadu_epi16(tail_mask, &A[idx_src_base0]); - array512_1 = _mm512_maskz_loadu_epi16(tail_mask, &A[idx_src_base1]); + array512_0 = _mm512_maskz_loadu_epi16(tail_mask, src_addr0); + array512_1 = _mm512_maskz_loadu_epi16(tail_mask, src_addr1); array512_2 = _mm512_unpacklo_epi16(array512_0, array512_1); array512_3 = _mm512_unpackhi_epi16(array512_0, array512_1); - _mm512_storeu_si512(&block_A[idx_target_base0], array512_2); - _mm512_storeu_si512(&block_A[idx_target_base1], array512_3); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); - idx_src_base0 += LDA_2x; - idx_src_base1 += LDA_2x; - idx_target_base0 += BF16_BLOCK_T_M_2x; - idx_target_base1 += BF16_BLOCK_T_M_2x; + src_addr0 += LDA_2x; + src_addr1 += LDA_2x; + dst_addr0 += BF16_BLOCK_T_M_2x; + dst_addr1 += BF16_BLOCK_T_M_2x; } if (tag_k_2x != k) { __m512i ZERO512 = _mm512_setzero_si512(); - array512_0 = _mm512_maskz_loadu_epi16(tail_mask, &A[idx_src_base0]); + array512_0 = _mm512_maskz_loadu_epi16(tail_mask, src_addr0); array512_2 = _mm512_unpacklo_epi16(array512_0, ZERO512); array512_3 = _mm512_unpackhi_epi16(array512_0, ZERO512); - _mm512_storeu_si512(&block_A[idx_target_base0], array512_2); - _mm512_storeu_si512(&block_A[idx_target_base1], array512_3); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); } - -#ifdef DEBUG_PROFILE - 
print_block(BF16_BLOCK_THRES_K, BF16_BLOCK_THRES_M, block_A); -#endif } +// INCOPY Kernel, 0> (16-m)); __m256i array256_0, array256_1, array256_2, array256_3; - BLASLONG idx_src_base0, idx_src_base1; - BLASLONG idx_target_base0; + bfloat16 * src_addr0, * src_addr1; + bfloat16 * dst_addr0; BLASLONG LDA_2x = 2*lda; - idx_src_base0 = 0; - idx_src_base1 = lda; - idx_target_base0 = 0; + + src_addr0 = A; + src_addr1 = A + lda; + dst_addr0 = block_A; + for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) { - array256_0 = MM256_LOADU_EPI16(&A[idx_src_base0]); - array256_1 = MM256_LOADU_EPI16(&A[idx_src_base1]); + array256_0 = _mm256_maskz_loadu_epi16(tail_mask, src_addr0); + array256_1 = _mm256_maskz_loadu_epi16(tail_mask, src_addr1); array256_2 = _mm256_unpacklo_epi16(array256_0, array256_1); array256_3 = _mm256_unpackhi_epi16(array256_0, array256_1); // Store in one row of block_B - MM256_STOREU_EPI16(&block_A[idx_target_base0], array256_2); - MM256_STOREU_EPI16(&block_A[idx_target_base0 + 16], array256_3); + MM256_STOREU_EPI16(dst_addr0, array256_2); + MM256_STOREU_EPI16(dst_addr0+16, array256_3); - idx_src_base0 += LDA_2x; - idx_src_base1 += LDA_2x; - idx_target_base0 += 32; + src_addr0 += LDA_2x; + src_addr1 += LDA_2x; + dst_addr0 += 32; } if (tag_k_2x != k) { __m256i ZERO256 = _mm256_setzero_si256(); - array256_0 = MM256_LOADU_EPI16(&A[idx_src_base0]); + array256_0 = _mm256_maskz_loadu_epi16(tail_mask, src_addr0); array256_2 = _mm256_unpacklo_epi16(array256_0, ZERO256); array256_3 = _mm256_unpackhi_epi16(array256_0, ZERO256); // Store in one row of block_B - MM256_STOREU_EPI16(&block_A[idx_target_base0], array256_2); - MM256_STOREU_EPI16(&block_A[idx_target_base0 + 16], array256_3); + MM256_STOREU_EPI16(dst_addr0, array256_2); + MM256_STOREU_EPI16(dst_addr0+16, array256_3); } +} -#ifdef DEBUG_PROFILE - print_block(BF16_BLOCK_THRES_K, BF16_BLOCK_THRES_M, block_A); -#endif +// K=32, M=16 +void COL_MAJOR_ITCOPY_KERNEL_32x16(bfloat16 * A, BLASLONG lda, bfloat16 * block_A) +{ + bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3; + bfloat16 * dst_addr0, * dst_addr1; + + BLASLONG LDA_4x = lda*4; + + src_addr0 = A; + src_addr1 = A + lda; + src_addr2 = A + lda*2; + src_addr3 = A + lda*3; + dst_addr0 = block_A; + dst_addr1 = block_A + 32*8; + + __m512i array512_0, array512_1, array512_2, array512_3; + __m512i array512_way0_0, array512_way0_1, array512_way0_2, array512_way0_3; + __m512i array512_way1_0, array512_way1_1, array512_way1_2, array512_way1_3; + __m512i array512_way2_0, array512_way2_1, array512_way2_2, array512_way2_3; + __m512i array512_way3_0, array512_way3_1, array512_way3_2, array512_way3_3; + + __m512i M512_EPI64_2 = _mm512_set1_epi64(2); + __m512i permute_lo_idx = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0); + __m512i permute_hi_idx = _mm512_add_epi64(permute_lo_idx, M512_EPI64_2); + + // Load and preprocess 1st 4 rows + array512_way0_0 = _mm512_loadu_si512(src_addr0); + array512_way0_1 = _mm512_loadu_si512(src_addr1); + array512_way0_2 = _mm512_loadu_si512(src_addr2); + array512_way0_3 = _mm512_loadu_si512(src_addr3); + array512_0 = _mm512_unpacklo_epi32(array512_way0_0, array512_way0_1); + array512_1 = _mm512_unpackhi_epi32(array512_way0_0, array512_way0_1); + array512_2 = _mm512_unpacklo_epi32(array512_way0_2, array512_way0_3); + array512_3 = _mm512_unpackhi_epi32(array512_way0_2, array512_way0_3); + array512_way0_0 = _mm512_unpacklo_epi64(array512_0, array512_2); + array512_way0_1 = _mm512_unpackhi_epi64(array512_0, array512_2); + array512_way0_2 = 
_mm512_unpacklo_epi64(array512_1, array512_3); + array512_way0_3 = _mm512_unpackhi_epi64(array512_1, array512_3); + src_addr0 += LDA_4x; + src_addr1 += LDA_4x; + src_addr2 += LDA_4x; + src_addr3 += LDA_4x; + + // Load and preprocess 2nd 4 rows + array512_way1_0 = _mm512_loadu_si512(src_addr0); + array512_way1_1 = _mm512_loadu_si512(src_addr1); + array512_way1_2 = _mm512_loadu_si512(src_addr2); + array512_way1_3 = _mm512_loadu_si512(src_addr3); + array512_0 = _mm512_unpacklo_epi32(array512_way1_0, array512_way1_1); + array512_1 = _mm512_unpackhi_epi32(array512_way1_0, array512_way1_1); + array512_2 = _mm512_unpacklo_epi32(array512_way1_2, array512_way1_3); + array512_3 = _mm512_unpackhi_epi32(array512_way1_2, array512_way1_3); + array512_way1_0 = _mm512_unpacklo_epi64(array512_0, array512_2); + array512_way1_1 = _mm512_unpackhi_epi64(array512_0, array512_2); + array512_way1_2 = _mm512_unpacklo_epi64(array512_1, array512_3); + array512_way1_3 = _mm512_unpackhi_epi64(array512_1, array512_3); + src_addr0 += LDA_4x; + src_addr1 += LDA_4x; + src_addr2 += LDA_4x; + src_addr3 += LDA_4x; + + // Load and preprocess 3rd 4 rows + array512_way2_0 = _mm512_loadu_si512(src_addr0); + array512_way2_1 = _mm512_loadu_si512(src_addr1); + array512_way2_2 = _mm512_loadu_si512(src_addr2); + array512_way2_3 = _mm512_loadu_si512(src_addr3); + array512_0 = _mm512_unpacklo_epi32(array512_way2_0, array512_way2_1); + array512_1 = _mm512_unpackhi_epi32(array512_way2_0, array512_way2_1); + array512_2 = _mm512_unpacklo_epi32(array512_way2_2, array512_way2_3); + array512_3 = _mm512_unpackhi_epi32(array512_way2_2, array512_way2_3); + array512_way2_0 = _mm512_unpacklo_epi64(array512_0, array512_2); + array512_way2_1 = _mm512_unpackhi_epi64(array512_0, array512_2); + array512_way2_2 = _mm512_unpacklo_epi64(array512_1, array512_3); + array512_way2_3 = _mm512_unpackhi_epi64(array512_1, array512_3); + src_addr0 += LDA_4x; + src_addr1 += LDA_4x; + src_addr2 += LDA_4x; + src_addr3 += LDA_4x; + + // Load and preprocess 4th 4 rows + array512_way3_0 = _mm512_loadu_si512(src_addr0); + array512_way3_1 = _mm512_loadu_si512(src_addr1); + array512_way3_2 = _mm512_loadu_si512(src_addr2); + array512_way3_3 = _mm512_loadu_si512(src_addr3); + array512_0 = _mm512_unpacklo_epi32(array512_way3_0, array512_way3_1); + array512_1 = _mm512_unpackhi_epi32(array512_way3_0, array512_way3_1); + array512_2 = _mm512_unpacklo_epi32(array512_way3_2, array512_way3_3); + array512_3 = _mm512_unpackhi_epi32(array512_way3_2, array512_way3_3); + array512_way3_0 = _mm512_unpacklo_epi64(array512_0, array512_2); + array512_way3_1 = _mm512_unpackhi_epi64(array512_0, array512_2); + array512_way3_2 = _mm512_unpacklo_epi64(array512_1, array512_3); + array512_way3_3 = _mm512_unpackhi_epi64(array512_1, array512_3); + + // Compose and store the 0/1 and 16/17 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_lo_idx, array512_way1_0); + array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_lo_idx, array512_way3_0); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_1, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + + // Compose and store the 2/3 and 18/19 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_lo_idx, array512_way1_1); + array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_lo_idx, 
array512_way3_1); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_1, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + + // Compose and store the 4/5 and 20/21 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_lo_idx, array512_way1_2); + array512_1 = _mm512_permutex2var_epi64(array512_way2_2, permute_lo_idx, array512_way3_2); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_1, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + + // Compose and store the 6/7 and 22/23 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_lo_idx, array512_way1_3); + array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_lo_idx, array512_way3_3); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_1, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + + // Compose and store the 8/9 and 24/25 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_hi_idx, array512_way1_0); + array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_hi_idx, array512_way3_0); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_1, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + + // Compose and store the 10/11 and 26/27 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_hi_idx, array512_way1_1); + array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_hi_idx, array512_way3_1); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_1, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + + // Compose and store the 12/13 and 28/29 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_hi_idx, array512_way1_2); + array512_1 = _mm512_permutex2var_epi64(array512_way2_2, permute_hi_idx, array512_way3_2); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_1, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + + // Compose and store the 14/15 and 30/31 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_hi_idx, array512_way1_3); + array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_hi_idx, array512_way3_3); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_1, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + 
_mm512_storeu_si512(dst_addr1, array512_3); } -void COL_MAJOR_INCOPY_KERNEL_Kx16m(BLASLONG k, BLASLONG m, bfloat16 * A, BLASLONG lda, bfloat16 * block_A) +// K=Any number but will be processed based on 32, M=32 +void COL_MAJOR_ITCOPY_KERNEL_Kx32(BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat16 * block_A) { - BLASLONG tag_k_2x = k & (~1); - unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-m)); - __mmask16 tail_mask = *((__mmask16*) &tail_mask_value); + bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3; + bfloat16 * dst_addr0, * dst_addr1; - __m256i array256_0, array256_1, array256_2, array256_3; + BLASLONG tag_k_32x = k & (~31); - BLASLONG idx_src_base0, idx_src_base1; - BLASLONG idx_target_base0; + BLASLONG LDA_4x = lda*4; + BLASLONG LDA_8x = lda*8; + BLASLONG LDA_12x = lda*12; + BLASLONG LDA_16x = lda*16; - BLASLONG LDA_2x = 2*lda; - idx_src_base0 = 0; - idx_src_base1 = lda; - idx_target_base0 = 0; - for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) { - array256_0 = _mm256_maskz_loadu_epi16(tail_mask, &A[idx_src_base0]); - array256_1 = _mm256_maskz_loadu_epi16(tail_mask, &A[idx_src_base1]); - array256_2 = _mm256_unpacklo_epi16(array256_0, array256_1); - array256_3 = _mm256_unpackhi_epi16(array256_0, array256_1); - // Store in one row of block_B - MM256_STOREU_EPI16(&block_A[idx_target_base0], array256_2); - MM256_STOREU_EPI16(&block_A[idx_target_base0 + 16], array256_3); + src_addr0 = A; + src_addr1 = A + lda; + src_addr2 = A + lda*2; + src_addr3 = A + lda*3; + dst_addr0 = block_A; + dst_addr1 = block_A + 32*16; - idx_src_base0 += LDA_2x; - idx_src_base1 += LDA_2x; - idx_target_base0 += 32; + __m512i array512_0, array512_1, array512_2, array512_3; + __m512i array512_way0_0, array512_way0_1, array512_way0_2, array512_way0_3; + __m512i array512_way1_0, array512_way1_1, array512_way1_2, array512_way1_3; + __m512i array512_way2_0, array512_way2_1, array512_way2_2, array512_way2_3; + __m512i array512_way3_0, array512_way3_1, array512_way3_2, array512_way3_3; + + __m512i M512_EPI64_2 = _mm512_set1_epi64(2); + __m512i permute_lo_idx = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0); + __m512i permute_hi_idx = _mm512_add_epi64(permute_lo_idx, M512_EPI64_2); + + for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) { + for (int i = 0; i < 2; i++) { + // Load and preprocess 1st 4 rows + array512_way0_0 = _mm512_loadu_si512(src_addr0+idx_k); + array512_way0_1 = _mm512_loadu_si512(src_addr1+idx_k); + array512_way0_2 = _mm512_loadu_si512(src_addr2+idx_k); + array512_way0_3 = _mm512_loadu_si512(src_addr3+idx_k); + array512_0 = _mm512_unpacklo_epi32(array512_way0_0, array512_way0_1); + array512_1 = _mm512_unpackhi_epi32(array512_way0_0, array512_way0_1); + array512_2 = _mm512_unpacklo_epi32(array512_way0_2, array512_way0_3); + array512_3 = _mm512_unpackhi_epi32(array512_way0_2, array512_way0_3); + array512_way0_0 = _mm512_unpacklo_epi64(array512_0, array512_2); + array512_way0_1 = _mm512_unpackhi_epi64(array512_0, array512_2); + array512_way0_2 = _mm512_unpacklo_epi64(array512_1, array512_3); + array512_way0_3 = _mm512_unpackhi_epi64(array512_1, array512_3); + + // Load and preprocess 2nd 4 rows + array512_way1_0 = _mm512_loadu_si512(src_addr0+LDA_4x+idx_k); + array512_way1_1 = _mm512_loadu_si512(src_addr1+LDA_4x+idx_k); + array512_way1_2 = _mm512_loadu_si512(src_addr2+LDA_4x+idx_k); + array512_way1_3 = _mm512_loadu_si512(src_addr3+LDA_4x+idx_k); + array512_0 = _mm512_unpacklo_epi32(array512_way1_0, array512_way1_1); + array512_1 = _mm512_unpackhi_epi32(array512_way1_0, 
array512_way1_1); + array512_2 = _mm512_unpacklo_epi32(array512_way1_2, array512_way1_3); + array512_3 = _mm512_unpackhi_epi32(array512_way1_2, array512_way1_3); + array512_way1_0 = _mm512_unpacklo_epi64(array512_0, array512_2); + array512_way1_1 = _mm512_unpackhi_epi64(array512_0, array512_2); + array512_way1_2 = _mm512_unpacklo_epi64(array512_1, array512_3); + array512_way1_3 = _mm512_unpackhi_epi64(array512_1, array512_3); + + // Load and preprocess 3rd 4 rows + array512_way2_0 = _mm512_loadu_si512(src_addr0+LDA_8x+idx_k); + array512_way2_1 = _mm512_loadu_si512(src_addr1+LDA_8x+idx_k); + array512_way2_2 = _mm512_loadu_si512(src_addr2+LDA_8x+idx_k); + array512_way2_3 = _mm512_loadu_si512(src_addr3+LDA_8x+idx_k); + array512_0 = _mm512_unpacklo_epi32(array512_way2_0, array512_way2_1); + array512_1 = _mm512_unpackhi_epi32(array512_way2_0, array512_way2_1); + array512_2 = _mm512_unpacklo_epi32(array512_way2_2, array512_way2_3); + array512_3 = _mm512_unpackhi_epi32(array512_way2_2, array512_way2_3); + array512_way2_0 = _mm512_unpacklo_epi64(array512_0, array512_2); + array512_way2_1 = _mm512_unpackhi_epi64(array512_0, array512_2); + array512_way2_2 = _mm512_unpacklo_epi64(array512_1, array512_3); + array512_way2_3 = _mm512_unpackhi_epi64(array512_1, array512_3); + + // Load and preprocess 4th 4 rows + array512_way3_0 = _mm512_loadu_si512(src_addr0+LDA_12x+idx_k); + array512_way3_1 = _mm512_loadu_si512(src_addr1+LDA_12x+idx_k); + array512_way3_2 = _mm512_loadu_si512(src_addr2+LDA_12x+idx_k); + array512_way3_3 = _mm512_loadu_si512(src_addr3+LDA_12x+idx_k); + array512_0 = _mm512_unpacklo_epi32(array512_way3_0, array512_way3_1); + array512_1 = _mm512_unpackhi_epi32(array512_way3_0, array512_way3_1); + array512_2 = _mm512_unpacklo_epi32(array512_way3_2, array512_way3_3); + array512_3 = _mm512_unpackhi_epi32(array512_way3_2, array512_way3_3); + array512_way3_0 = _mm512_unpacklo_epi64(array512_0, array512_2); + array512_way3_1 = _mm512_unpackhi_epi64(array512_0, array512_2); + array512_way3_2 = _mm512_unpacklo_epi64(array512_1, array512_3); + array512_way3_3 = _mm512_unpackhi_epi64(array512_1, array512_3); + + // Compose and store the 0/1 and 16/17 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_lo_idx, array512_way1_0); + array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_lo_idx, array512_way3_0); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 64; + dst_addr1 += 64; + + // Compose and store the 2/3 and 18/19 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_lo_idx, array512_way1_1); + array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_lo_idx, array512_way3_1); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 64; + dst_addr1 += 64; + + // Compose and store the 4/5 and 20/21 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_lo_idx, array512_way1_2); + array512_1 = _mm512_permutex2var_epi64(array512_way2_2, permute_lo_idx, array512_way3_2); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + 
array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 64; + dst_addr1 += 64; + + // Compose and store the 6/7 and 22/23 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_lo_idx, array512_way1_3); + array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_lo_idx, array512_way3_3); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 64; + dst_addr1 += 64; + + // Compose and store the 8/9 and 24/25 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_hi_idx, array512_way1_0); + array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_hi_idx, array512_way3_0); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 64; + dst_addr1 += 64; + + // Compose and store the 10/11 and 26/27 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_hi_idx, array512_way1_1); + array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_hi_idx, array512_way3_1); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 64; + dst_addr1 += 64; + + // Compose and store the 12/13 and 28/29 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_hi_idx, array512_way1_2); + array512_1 = _mm512_permutex2var_epi64(array512_way2_2, permute_hi_idx, array512_way3_2); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 64; + dst_addr1 += 64; + + // Compose and store the 14/15 and 30/31 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_hi_idx, array512_way1_3); + array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_hi_idx, array512_way3_3); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + + src_addr0 += LDA_16x; + src_addr1 += LDA_16x; + src_addr2 += LDA_16x; + src_addr3 += LDA_16x; + dst_addr0 -= (64*7 - 32); + dst_addr1 -= (64*7 - 32); + } + src_addr0 -= (LDA_16x*2); + src_addr1 -= (LDA_16x*2); + src_addr2 -= (LDA_16x*2); + src_addr3 -= (LDA_16x*2); + dst_addr0 += (32*30); + dst_addr1 += (32*30); } - if (tag_k_2x != k) { - __m256i ZERO256 = _mm256_setzero_si256(); - array256_0 = _mm256_maskz_loadu_epi16(tail_mask, &A[idx_src_base0]); - array256_2 = _mm256_unpacklo_epi16(array256_0, ZERO256); - array256_3 = _mm256_unpackhi_epi16(array256_0, ZERO256); - // Store in one row of block_B - 
MM256_STOREU_EPI16(&block_A[idx_target_base0], array256_2); - MM256_STOREU_EPI16(&block_A[idx_target_base0 + 16], array256_3); + if (tag_k_32x != k) { + int k_rem = k - tag_k_32x; + unsigned int tail_mask = (((unsigned int)0xffffffff) >> (32-k_rem)); + __m512i array512[16]; + + bfloat16 * dst_addr_tmp = dst_addr0; + + for (int i = 0; i < 2; i++) { + // Load and preprocess 1st 4 rows + array512[0] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x); + array512[1] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x); + array512[2] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x); + array512[3] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x); + array512_0 = _mm512_unpacklo_epi32(array512[0], array512[1]); + array512_1 = _mm512_unpackhi_epi32(array512[0], array512[1]); + array512_2 = _mm512_unpacklo_epi32(array512[2], array512[3]); + array512_3 = _mm512_unpackhi_epi32(array512[2], array512[3]); + array512[0] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[1] = _mm512_unpackhi_epi64(array512_0, array512_2); + array512[2] = _mm512_unpacklo_epi64(array512_1, array512_3); + array512[3] = _mm512_unpackhi_epi64(array512_1, array512_3); + src_addr0 += LDA_4x; + src_addr1 += LDA_4x; + src_addr2 += LDA_4x; + src_addr3 += LDA_4x; + + // Load and preprocess 2nd 4 rows + array512[4] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x); + array512[5] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x); + array512[6] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x); + array512[7] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x); + array512_0 = _mm512_unpacklo_epi32(array512[4], array512[5]); + array512_1 = _mm512_unpackhi_epi32(array512[4], array512[5]); + array512_2 = _mm512_unpacklo_epi32(array512[6], array512[7]); + array512_3 = _mm512_unpackhi_epi32(array512[6], array512[7]); + array512[4] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[5] = _mm512_unpackhi_epi64(array512_0, array512_2); + array512[6] = _mm512_unpacklo_epi64(array512_1, array512_3); + array512[7] = _mm512_unpackhi_epi64(array512_1, array512_3); + src_addr0 += LDA_4x; + src_addr1 += LDA_4x; + src_addr2 += LDA_4x; + src_addr3 += LDA_4x; + + // Load and preprocess 3rd 4 rows + array512[8] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x); + array512[9] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x); + array512[10] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x); + array512[11] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x); + array512_0 = _mm512_unpacklo_epi32(array512[8], array512[9]); + array512_1 = _mm512_unpackhi_epi32(array512[8], array512[9]); + array512_2 = _mm512_unpacklo_epi32(array512[10], array512[11]); + array512_3 = _mm512_unpackhi_epi32(array512[10], array512[11]); + array512[8] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[9] = _mm512_unpackhi_epi64(array512_0, array512_2); + array512[10] = _mm512_unpacklo_epi64(array512_1, array512_3); + array512[11] = _mm512_unpackhi_epi64(array512_1, array512_3); + src_addr0 += LDA_4x; + src_addr1 += LDA_4x; + src_addr2 += LDA_4x; + src_addr3 += LDA_4x; + + // Load and preprocess 4th 4 rows + array512[12] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x); + array512[13] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x); + array512[14] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x); + array512[15] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x); + array512_0 = _mm512_unpacklo_epi32(array512[12], 
array512[13]); + array512_1 = _mm512_unpackhi_epi32(array512[12], array512[13]); + array512_2 = _mm512_unpacklo_epi32(array512[14], array512[15]); + array512_3 = _mm512_unpackhi_epi32(array512[14], array512[15]); + array512[12] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[13] = _mm512_unpackhi_epi64(array512_0, array512_2); + array512[14] = _mm512_unpacklo_epi64(array512_1, array512_3); + array512[15] = _mm512_unpackhi_epi64(array512_1, array512_3); + src_addr0 += LDA_4x; + src_addr1 += LDA_4x; + src_addr2 += LDA_4x; + src_addr3 += LDA_4x; + + // array512_01_1617_0, array512_01_1617_1, array512_89_2425_0, array512_89_2425_1; + // Half-compose of 0/1, 16/17, 8/9, 24/25 cols + array512_0 = _mm512_permutex2var_epi64(array512[0], permute_lo_idx, array512[4]); + array512_1 = _mm512_permutex2var_epi64(array512[8], permute_lo_idx, array512[12]); + array512_2 = _mm512_permutex2var_epi64(array512[0], permute_hi_idx, array512[4]); + array512_3 = _mm512_permutex2var_epi64(array512[8], permute_hi_idx, array512[12]); + array512[0] = array512_0; // 1st 8 pairs of col 0/1, and 1st 8 pairs of col 16/17 + array512[4] = array512_1; // 2nd 8 pairs of col 0/1, and 2nd 8 pairs of col 16/17 + array512[8] = array512_2; // 1st 8 pairs of col 8/9, and 1st 8 pairs of col 24/25 + array512[12] = array512_3; // 2nd 8 pairs of col 8/9, and 2nd 8 pairs of col 24/25 + + // Half-compose of 2/3, 18/19, 10/11, 26/27 cols + array512_0 = _mm512_permutex2var_epi64(array512[1], permute_lo_idx, array512[5]); + array512_1 = _mm512_permutex2var_epi64(array512[9], permute_lo_idx, array512[13]); + array512_2 = _mm512_permutex2var_epi64(array512[1], permute_hi_idx, array512[5]); + array512_3 = _mm512_permutex2var_epi64(array512[9], permute_hi_idx, array512[13]); + array512[1] = array512_0; // 1st 8 pairs of col 2/3, and 1st 8 pairs of col 18/19 + array512[5] = array512_1; // 2nd 8 pairs of col 2/3, and 2nd 8 pairs of col 18/19 + array512[9] = array512_2; // 1st 8 pairs of col 10/11, and 1st 8 pairs of col 26/27 + array512[13] = array512_3; // 2nd 8 pairs of col 10/11, and 2nd 8 pairs of col 26/27 + + // Half-compose of 4/5, 20/21, 12/13, 28/29 cols + array512_0 = _mm512_permutex2var_epi64(array512[2], permute_lo_idx, array512[6]); + array512_1 = _mm512_permutex2var_epi64(array512[10], permute_lo_idx, array512[14]); + array512_2 = _mm512_permutex2var_epi64(array512[2], permute_hi_idx, array512[6]); + array512_3 = _mm512_permutex2var_epi64(array512[10], permute_hi_idx, array512[14]); + array512[2] = array512_0; // 1st 8 pairs of col 4/5, and 1st 8 pairs of col 20/21 + array512[6] = array512_1; // 2nd 8 pairs of col 4/5, and 2nd 8 pairs of col 20/21 + array512[10] = array512_2; // 1st 8 pairs of col 12/13, and 1st 8 pairs of col 28/29 + array512[14] = array512_3; // 2nd 8 pairs of col 12/13, and 2nd 8 pairs of col 28/29 + + // Half-compose of 6/7, 22/23, 14/15, 30/31 cols + array512_0 = _mm512_permutex2var_epi64(array512[3], permute_lo_idx, array512[7]); + array512_1 = _mm512_permutex2var_epi64(array512[11], permute_lo_idx, array512[15]); + array512_2 = _mm512_permutex2var_epi64(array512[3], permute_hi_idx, array512[7]); + array512_3 = _mm512_permutex2var_epi64(array512[11], permute_hi_idx, array512[15]); + array512[3] = array512_0; // 1st 8 pairs of col 6/7, and 1st 8 pairs of col 22/23 + array512[7] = array512_1; // 2nd 8 pairs of col 6/7, and 2nd 8 pairs of col 22/23 + array512[11] = array512_2; // 1st 8 pairs of col 14/15, and 1st 8 pairs of col 30/31 + array512[15] = array512_3; // 2nd 8 pairs of col 14/15, and 2nd 8 
pairs of col 30/31 + + // Compose and store the 0/1 cols + array512_0 = _mm512_inserti64x4(array512[0], _mm512_castsi512_si256(array512[4]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + + // Compose and store the 2/3 cols + array512_0 = _mm512_inserti64x4(array512[1], _mm512_castsi512_si256(array512[5]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + + // Compose and store the 4/5 cols + array512_0 = _mm512_inserti64x4(array512[2], _mm512_castsi512_si256(array512[6]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + + // Compose and store the 6/7 cols + array512_0 = _mm512_inserti64x4(array512[3], _mm512_castsi512_si256(array512[7]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + + // Compose and store the 8/9 cols + array512_0 = _mm512_inserti64x4(array512[8], _mm512_castsi512_si256(array512[12]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + + // Compose and store the 10/11 cols + array512_0 = _mm512_inserti64x4(array512[9], _mm512_castsi512_si256(array512[13]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + + // Compose and store the 12/13 cols + array512_0 = _mm512_inserti64x4(array512[10], _mm512_castsi512_si256(array512[14]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + + // Compose and store the 14/15 cols + array512_0 = _mm512_inserti64x4(array512[11], _mm512_castsi512_si256(array512[15]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + + // Compose and store 16 ~ k_rem cols + int idx_length = (k_rem + 1 - 16) >> 1; + if (idx_length > 4) { + for (int idx_k = 0; idx_k < 4; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + } + + for (int idx_k = 4; idx_k < idx_length; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+8], _mm512_extracti64x4_epi64(array512[idx_k+4], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + } + } else { + for (int idx_k = 0; idx_k < idx_length; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + } + } + + dst_addr0 = dst_addr_tmp + 32; + } + } +} + +// K=Any number but will be processed based on 32, 16> 1; + unsigned int tail_mask = (((unsigned int)0xffffffff) >> (32-k_rem)); + bfloat16 * dst_addr_tmp = dst_addr0; + + for (int j = 0; j < 4; j++) { + int array_idx = j*4; + // Load and preprocess 4 rows + array512[array_idx+0] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x); + array512[array_idx+1] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x); + array512[array_idx+2] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x); + array512[array_idx+3] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x); + array512_0 = _mm512_unpacklo_epi32(array512[array_idx+0], array512[array_idx+1]); + array512_1 = _mm512_unpackhi_epi32(array512[array_idx+0], array512[array_idx+1]); + array512_2 = _mm512_unpacklo_epi32(array512[array_idx+2], array512[array_idx+3]); + array512_3 = _mm512_unpackhi_epi32(array512[array_idx+2], array512[array_idx+3]); + array512[array_idx+0] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[array_idx+1] = _mm512_unpackhi_epi64(array512_0, array512_2); + array512[array_idx+2] = 
_mm512_unpacklo_epi64(array512_1, array512_3); + array512[array_idx+3] = _mm512_unpackhi_epi64(array512_1, array512_3); + src_addr0 += LDA_4x; + src_addr1 += LDA_4x; + src_addr2 += LDA_4x; + src_addr3 += LDA_4x; + } + + for (int j = 0; j < 4; j++) { + array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_lo_idx, array512[j+4]); + array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_lo_idx, array512[j+12]); + array512_2 = _mm512_permutex2var_epi64(array512[j+0], permute_hi_idx, array512[j+4]); + array512_3 = _mm512_permutex2var_epi64(array512[j+8], permute_hi_idx, array512[j+12]); + array512[j+0] = array512_0; // 1st 8 pairs of col 0/1|2/3|4/5|6/7, and 1st 8 pairs of col 16/17|18/19|20/21|22/23 + array512[j+4] = array512_1; // 2nd 8 pairs of col 0/1|2/3|4/5|6/7, and 2nd 8 pairs of col 16/17|18/19|20/21|22/23 + array512[j+8] = array512_2; // 1st 8 pairs of col 8/9|10/11|12/13|14/15, and 1st 8 pairs of col 24/25|26/27|28/29|30/31 + array512[j+12] = array512_3; // 2nd 8 pairs of col 8/9|10/11|12/13|14/15, and 2nd 8 pairs of col 24/25|26/27|28/29|30/31 + } + + for (int j = 0; j < 4; j++) { + // Compose and store the 0/1 cols + array512_0 = _mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + } + + for (int j = 8; j < 12; j++) { + array512_0 = _mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + } + + // Compose and store 16 ~ k_rem cols + if (idx_length > 4) { + for (int idx_k = 0; idx_k < 4; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + } + + for (int idx_k = 4; idx_k < idx_length; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+8], _mm512_extracti64x4_epi64(array512[idx_k+4], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + } + } else { + for (int idx_k = 0; idx_k < idx_length; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + } + } + + dst_addr0 = dst_addr_tmp + 32; + + for (int j = 0; j < m_rem; j++) { + array512[j] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+j*lda+tag_k_32x); + } + for (int j = m_rem; j < 16; j++) { + array512[j] = _mm512_setzero_si512(); + } + + for (int j = 0; j < 4; j++) { + int array_idx = j*4; + array512_0 = _mm512_unpacklo_epi32(array512[array_idx+0], array512[array_idx+1]); + array512_1 = _mm512_unpackhi_epi32(array512[array_idx+0], array512[array_idx+1]); + array512_2 = _mm512_unpacklo_epi32(array512[array_idx+2], array512[array_idx+3]); + array512_3 = _mm512_unpackhi_epi32(array512[array_idx+2], array512[array_idx+3]); + array512[array_idx+0] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[array_idx+1] = _mm512_unpackhi_epi64(array512_0, array512_2); + array512[array_idx+2] = _mm512_unpacklo_epi64(array512_1, array512_3); + array512[array_idx+3] = _mm512_unpackhi_epi64(array512_1, array512_3); + } + + for (int j = 0; j < 4; j++) { + array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_lo_idx, array512[j+4]); + array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_lo_idx, array512[j+12]); + array512_2 = _mm512_permutex2var_epi64(array512[j+0], permute_hi_idx, array512[j+4]); + array512_3 = 
_mm512_permutex2var_epi64(array512[j+8], permute_hi_idx, array512[j+12]); + array512[j+0] = array512_0; // 1st 8 pairs of col 0/1|2/3|4/5|6/7, and 1st 8 pairs of col 16/17|18/19|20/21|22/23 + array512[j+4] = array512_1; // 2nd 8 pairs of col 0/1|2/3|4/5|6/7, and 2nd 8 pairs of col 16/17|18/19|20/21|22/23 + array512[j+8] = array512_2; // 1st 8 pairs of col 8/9|10/11|12/13|14/15, and 1st 8 pairs of col 24/25|26/27|28/29|30/31 + array512[j+12] = array512_3; // 2nd 8 pairs of col 8/9|10/11|12/13|14/15, and 2nd 8 pairs of col 24/25|26/27|28/29|30/31 + } + + for (int j = 0; j < 4; j++) { + // Compose and store the 0/1 cols + array512_0 = _mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + } + + for (int j = 8; j < 12; j++) { + array512_0 = _mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + } + + // Compose and store 16 ~ k_rem cols + if (idx_length > 4) { + for (int idx_k = 0; idx_k < 4; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + } + + for (int idx_k = 4; idx_k < idx_length; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+8], _mm512_extracti64x4_epi64(array512[idx_k+4], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + } + } else { + for (int idx_k = 0; idx_k < idx_length; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + } + } + } +} + +// K=Any number but will be processed based on 32, M=16 +void COL_MAJOR_ITCOPY_KERNEL_Kx16(BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat16 * block_A) +{ + bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3; + bfloat16 * dst_addr0, * dst_addr1; + + BLASLONG tag_k_32x = k & (~31); + + BLASLONG LDA_4x = lda*4; + BLASLONG LDA_8x = lda*8; + BLASLONG LDA_12x = lda*12; + + src_addr0 = A; + src_addr1 = A + lda; + src_addr2 = A + lda*2; + src_addr3 = A + lda*3; + dst_addr0 = block_A; + dst_addr1 = block_A + 32*8; + + __m512i array512_0, array512_1, array512_2, array512_3; + __m512i array512_way0_0, array512_way0_1, array512_way0_2, array512_way0_3; + __m512i array512_way1_0, array512_way1_1, array512_way1_2, array512_way1_3; + __m512i array512_way2_0, array512_way2_1, array512_way2_2, array512_way2_3; + __m512i array512_way3_0, array512_way3_1, array512_way3_2, array512_way3_3; + + __m512i M512_EPI64_2 = _mm512_set1_epi64(2); + __m512i permute_lo_idx = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0); + __m512i permute_hi_idx = _mm512_add_epi64(permute_lo_idx, M512_EPI64_2); + + for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) { + // Load and preprocess 1st 4 rows + array512_way0_0 = _mm512_loadu_si512(src_addr0+idx_k); + array512_way0_1 = _mm512_loadu_si512(src_addr1+idx_k); + array512_way0_2 = _mm512_loadu_si512(src_addr2+idx_k); + array512_way0_3 = _mm512_loadu_si512(src_addr3+idx_k); + array512_0 = _mm512_unpacklo_epi32(array512_way0_0, array512_way0_1); + array512_1 = _mm512_unpackhi_epi32(array512_way0_0, array512_way0_1); + array512_2 = _mm512_unpacklo_epi32(array512_way0_2, array512_way0_3); + array512_3 = _mm512_unpackhi_epi32(array512_way0_2, array512_way0_3); + array512_way0_0 = _mm512_unpacklo_epi64(array512_0, array512_2); + array512_way0_1 = 
_mm512_unpackhi_epi64(array512_0, array512_2); + array512_way0_2 = _mm512_unpacklo_epi64(array512_1, array512_3); + array512_way0_3 = _mm512_unpackhi_epi64(array512_1, array512_3); + + // Load and preprocess 2nd 4 rows + array512_way1_0 = _mm512_loadu_si512(src_addr0+LDA_4x+idx_k); + array512_way1_1 = _mm512_loadu_si512(src_addr1+LDA_4x+idx_k); + array512_way1_2 = _mm512_loadu_si512(src_addr2+LDA_4x+idx_k); + array512_way1_3 = _mm512_loadu_si512(src_addr3+LDA_4x+idx_k); + array512_0 = _mm512_unpacklo_epi32(array512_way1_0, array512_way1_1); + array512_1 = _mm512_unpackhi_epi32(array512_way1_0, array512_way1_1); + array512_2 = _mm512_unpacklo_epi32(array512_way1_2, array512_way1_3); + array512_3 = _mm512_unpackhi_epi32(array512_way1_2, array512_way1_3); + array512_way1_0 = _mm512_unpacklo_epi64(array512_0, array512_2); + array512_way1_1 = _mm512_unpackhi_epi64(array512_0, array512_2); + array512_way1_2 = _mm512_unpacklo_epi64(array512_1, array512_3); + array512_way1_3 = _mm512_unpackhi_epi64(array512_1, array512_3); + + // Load and preprocess 3rd 4 rows + array512_way2_0 = _mm512_loadu_si512(src_addr0+LDA_8x+idx_k); + array512_way2_1 = _mm512_loadu_si512(src_addr1+LDA_8x+idx_k); + array512_way2_2 = _mm512_loadu_si512(src_addr2+LDA_8x+idx_k); + array512_way2_3 = _mm512_loadu_si512(src_addr3+LDA_8x+idx_k); + array512_0 = _mm512_unpacklo_epi32(array512_way2_0, array512_way2_1); + array512_1 = _mm512_unpackhi_epi32(array512_way2_0, array512_way2_1); + array512_2 = _mm512_unpacklo_epi32(array512_way2_2, array512_way2_3); + array512_3 = _mm512_unpackhi_epi32(array512_way2_2, array512_way2_3); + array512_way2_0 = _mm512_unpacklo_epi64(array512_0, array512_2); + array512_way2_1 = _mm512_unpackhi_epi64(array512_0, array512_2); + array512_way2_2 = _mm512_unpacklo_epi64(array512_1, array512_3); + array512_way2_3 = _mm512_unpackhi_epi64(array512_1, array512_3); + + // Load and preprocess 4th 4 rows + array512_way3_0 = _mm512_loadu_si512(src_addr0+LDA_12x+idx_k); + array512_way3_1 = _mm512_loadu_si512(src_addr1+LDA_12x+idx_k); + array512_way3_2 = _mm512_loadu_si512(src_addr2+LDA_12x+idx_k); + array512_way3_3 = _mm512_loadu_si512(src_addr3+LDA_12x+idx_k); + array512_0 = _mm512_unpacklo_epi32(array512_way3_0, array512_way3_1); + array512_1 = _mm512_unpackhi_epi32(array512_way3_0, array512_way3_1); + array512_2 = _mm512_unpacklo_epi32(array512_way3_2, array512_way3_3); + array512_3 = _mm512_unpackhi_epi32(array512_way3_2, array512_way3_3); + array512_way3_0 = _mm512_unpacklo_epi64(array512_0, array512_2); + array512_way3_1 = _mm512_unpackhi_epi64(array512_0, array512_2); + array512_way3_2 = _mm512_unpacklo_epi64(array512_1, array512_3); + array512_way3_3 = _mm512_unpackhi_epi64(array512_1, array512_3); + + // Compose and store the 0/1 and 16/17 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_lo_idx, array512_way1_0); + array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_lo_idx, array512_way3_0); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + + // Compose and store the 2/3 and 18/19 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_lo_idx, array512_way1_1); + array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_lo_idx, array512_way3_1); + array512_2 = 
_mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + + // Compose and store the 4/5 and 20/21 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_lo_idx, array512_way1_2); + array512_1 = _mm512_permutex2var_epi64(array512_way2_2, permute_lo_idx, array512_way3_2); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + + // Compose and store the 6/7 and 22/23 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_lo_idx, array512_way1_3); + array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_lo_idx, array512_way3_3); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + + // Compose and store the 8/9 and 24/25 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_hi_idx, array512_way1_0); + array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_hi_idx, array512_way3_0); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + + // Compose and store the 10/11 and 26/27 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_hi_idx, array512_way1_1); + array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_hi_idx, array512_way3_1); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + + // Compose and store the 12/13 and 28/29 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_hi_idx, array512_way1_2); + array512_1 = _mm512_permutex2var_epi64(array512_way2_2, permute_hi_idx, array512_way3_2); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + + // Compose and store the 14/15 and 30/31 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_hi_idx, array512_way1_3); + array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_hi_idx, array512_way3_3); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 
+= 32*9; + dst_addr1 += 32*9; + } + + if (tag_k_32x != k) { + int k_rem = k - tag_k_32x; + unsigned int tail_mask = (((unsigned int)0xffffffff) >> (32-k_rem)); + __m512i array512[16]; + + // Load and preprocess 1st 4 rows + array512[0] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x); + array512[1] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x); + array512[2] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x); + array512[3] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x); + array512_0 = _mm512_unpacklo_epi32(array512[0], array512[1]); + array512_1 = _mm512_unpackhi_epi32(array512[0], array512[1]); + array512_2 = _mm512_unpacklo_epi32(array512[2], array512[3]); + array512_3 = _mm512_unpackhi_epi32(array512[2], array512[3]); + array512[0] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[1] = _mm512_unpackhi_epi64(array512_0, array512_2); + array512[2] = _mm512_unpacklo_epi64(array512_1, array512_3); + array512[3] = _mm512_unpackhi_epi64(array512_1, array512_3); + src_addr0 += LDA_4x; + src_addr1 += LDA_4x; + src_addr2 += LDA_4x; + src_addr3 += LDA_4x; + + // Load and preprocess 2nd 4 rows + array512[4] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x); + array512[5] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x); + array512[6] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x); + array512[7] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x); + array512_0 = _mm512_unpacklo_epi32(array512[4], array512[5]); + array512_1 = _mm512_unpackhi_epi32(array512[4], array512[5]); + array512_2 = _mm512_unpacklo_epi32(array512[6], array512[7]); + array512_3 = _mm512_unpackhi_epi32(array512[6], array512[7]); + array512[4] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[5] = _mm512_unpackhi_epi64(array512_0, array512_2); + array512[6] = _mm512_unpacklo_epi64(array512_1, array512_3); + array512[7] = _mm512_unpackhi_epi64(array512_1, array512_3); + src_addr0 += LDA_4x; + src_addr1 += LDA_4x; + src_addr2 += LDA_4x; + src_addr3 += LDA_4x; + + // Load and preprocess 3rd 4 rows + array512[8] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x); + array512[9] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x); + array512[10] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x); + array512[11] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x); + array512_0 = _mm512_unpacklo_epi32(array512[8], array512[9]); + array512_1 = _mm512_unpackhi_epi32(array512[8], array512[9]); + array512_2 = _mm512_unpacklo_epi32(array512[10], array512[11]); + array512_3 = _mm512_unpackhi_epi32(array512[10], array512[11]); + array512[8] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[9] = _mm512_unpackhi_epi64(array512_0, array512_2); + array512[10] = _mm512_unpacklo_epi64(array512_1, array512_3); + array512[11] = _mm512_unpackhi_epi64(array512_1, array512_3); + src_addr0 += LDA_4x; + src_addr1 += LDA_4x; + src_addr2 += LDA_4x; + src_addr3 += LDA_4x; + + // Load and preprocess 4th 4 rows + array512[12] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x); + array512[13] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x); + array512[14] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x); + array512[15] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x); + array512_0 = _mm512_unpacklo_epi32(array512[12], array512[13]); + array512_1 = _mm512_unpackhi_epi32(array512[12], array512[13]); + array512_2 = _mm512_unpacklo_epi32(array512[14], array512[15]); + array512_3 = 
_mm512_unpackhi_epi32(array512[14], array512[15]); + array512[12] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[13] = _mm512_unpackhi_epi64(array512_0, array512_2); + array512[14] = _mm512_unpacklo_epi64(array512_1, array512_3); + array512[15] = _mm512_unpackhi_epi64(array512_1, array512_3); + + // array512_01_1617_0, array512_01_1617_1, array512_89_2425_0, array512_89_2425_1; + // Half-compose of 0/1, 16/17, 8/9, 24/25 cols + array512_0 = _mm512_permutex2var_epi64(array512[0], permute_lo_idx, array512[4]); + array512_1 = _mm512_permutex2var_epi64(array512[8], permute_lo_idx, array512[12]); + array512_2 = _mm512_permutex2var_epi64(array512[0], permute_hi_idx, array512[4]); + array512_3 = _mm512_permutex2var_epi64(array512[8], permute_hi_idx, array512[12]); + array512[0] = array512_0; // 1st 8 pairs of col 0/1, and 1st 8 pairs of col 16/17 + array512[4] = array512_1; // 2nd 8 pairs of col 0/1, and 2nd 8 pairs of col 16/17 + array512[8] = array512_2; // 1st 8 pairs of col 8/9, and 1st 8 pairs of col 24/25 + array512[12] = array512_3; // 2nd 8 pairs of col 8/9, and 2nd 8 pairs of col 24/25 + + // Half-compose of 2/3, 18/19, 10/11, 26/27 cols + array512_0 = _mm512_permutex2var_epi64(array512[1], permute_lo_idx, array512[5]); + array512_1 = _mm512_permutex2var_epi64(array512[9], permute_lo_idx, array512[13]); + array512_2 = _mm512_permutex2var_epi64(array512[1], permute_hi_idx, array512[5]); + array512_3 = _mm512_permutex2var_epi64(array512[9], permute_hi_idx, array512[13]); + array512[1] = array512_0; // 1st 8 pairs of col 2/3, and 1st 8 pairs of col 18/19 + array512[5] = array512_1; // 2nd 8 pairs of col 2/3, and 2nd 8 pairs of col 18/19 + array512[9] = array512_2; // 1st 8 pairs of col 10/11, and 1st 8 pairs of col 26/27 + array512[13] = array512_3; // 2nd 8 pairs of col 10/11, and 2nd 8 pairs of col 26/27 + + // Half-compose of 4/5, 20/21, 12/13, 28/29 cols + array512_0 = _mm512_permutex2var_epi64(array512[2], permute_lo_idx, array512[6]); + array512_1 = _mm512_permutex2var_epi64(array512[10], permute_lo_idx, array512[14]); + array512_2 = _mm512_permutex2var_epi64(array512[2], permute_hi_idx, array512[6]); + array512_3 = _mm512_permutex2var_epi64(array512[10], permute_hi_idx, array512[14]); + array512[2] = array512_0; // 1st 8 pairs of col 4/5, and 1st 8 pairs of col 20/21 + array512[6] = array512_1; // 2nd 8 pairs of col 4/5, and 2nd 8 pairs of col 20/21 + array512[10] = array512_2; // 1st 8 pairs of col 12/13, and 1st 8 pairs of col 28/29 + array512[14] = array512_3; // 2nd 8 pairs of col 12/13, and 2nd 8 pairs of col 28/29 + + // Half-compose of 6/7, 22/23, 14/15, 30/31 cols + array512_0 = _mm512_permutex2var_epi64(array512[3], permute_lo_idx, array512[7]); + array512_1 = _mm512_permutex2var_epi64(array512[11], permute_lo_idx, array512[15]); + array512_2 = _mm512_permutex2var_epi64(array512[3], permute_hi_idx, array512[7]); + array512_3 = _mm512_permutex2var_epi64(array512[11], permute_hi_idx, array512[15]); + array512[3] = array512_0; // 1st 8 pairs of col 6/7, and 1st 8 pairs of col 22/23 + array512[7] = array512_1; // 2nd 8 pairs of col 6/7, and 2nd 8 pairs of col 22/23 + array512[11] = array512_2; // 1st 8 pairs of col 14/15, and 1st 8 pairs of col 30/31 + array512[15] = array512_3; // 2nd 8 pairs of col 14/15, and 2nd 8 pairs of col 30/31 + + // Compose and store the 0/1 cols + array512_0 = _mm512_inserti64x4(array512[0], _mm512_castsi512_si256(array512[4]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + + // Compose and store the 2/3 cols + 
array512_0 = _mm512_inserti64x4(array512[1], _mm512_castsi512_si256(array512[5]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + + // Compose and store the 4/5 cols + array512_0 = _mm512_inserti64x4(array512[2], _mm512_castsi512_si256(array512[6]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + + // Compose and store the 6/7 cols + array512_0 = _mm512_inserti64x4(array512[3], _mm512_castsi512_si256(array512[7]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + + // Compose and store the 8/9 cols + array512_0 = _mm512_inserti64x4(array512[8], _mm512_castsi512_si256(array512[12]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + + // Compose and store the 10/11 cols + array512_0 = _mm512_inserti64x4(array512[9], _mm512_castsi512_si256(array512[13]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + + // Compose and store the 12/13 cols + array512_0 = _mm512_inserti64x4(array512[10], _mm512_castsi512_si256(array512[14]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + + // Compose and store the 14/15 cols + array512_0 = _mm512_inserti64x4(array512[11], _mm512_castsi512_si256(array512[15]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + + // Compose and store 16 ~ k_rem cols + int idx_length = (k_rem + 1 - 16) >> 1; + if (idx_length > 4) { + for (int idx_k = 0; idx_k < 4; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + } + + for (int idx_k = 4; idx_k < idx_length; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+8], _mm512_extracti64x4_epi64(array512[idx_k+4], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + } + } else { + for (int idx_k = 0; idx_k < idx_length; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + } + } + } +} + +// K=Any number but will be processed based on 32, M<=16 +void COL_MAJOR_ITCOPY_KERNEL_Kx16m(BLASLONG m, BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat16 * block_A) +{ + bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3; + bfloat16 * dst_addr0, * dst_addr1; + + BLASLONG tag_k_32x = k & (~31); + + src_addr0 = A; + dst_addr0 = block_A; + dst_addr1 = block_A + 32*8; + + __m512i array512_0, array512_1, array512_2, array512_3; + __m512i array512[16]; + + __m512i M512_EPI64_2 = _mm512_set1_epi64(2); + __m512i permute_lo_idx = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0); + __m512i permute_hi_idx = _mm512_add_epi64(permute_lo_idx, M512_EPI64_2); + + for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) { + for (int j = 0; j < m; j++) { + array512[j] = _mm512_loadu_si512(src_addr0+j*lda+idx_k); + } + for (int j = m; j < 16; j++) { + array512[j] = _mm512_setzero_si512(); + } + + for (int j = 0; j < 4; j++) { + int array_idx = j*4; + array512_0 = _mm512_unpacklo_epi32(array512[array_idx+0], array512[array_idx+1]); + array512_1 = _mm512_unpackhi_epi32(array512[array_idx+0], array512[array_idx+1]); + array512_2 = _mm512_unpacklo_epi32(array512[array_idx+2], array512[array_idx+3]); + array512_3 = _mm512_unpackhi_epi32(array512[array_idx+2], array512[array_idx+3]); + array512[array_idx+0] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[array_idx+1] = 
_mm512_unpackhi_epi64(array512_0, array512_2); + array512[array_idx+2] = _mm512_unpacklo_epi64(array512_1, array512_3); + array512[array_idx+3] = _mm512_unpackhi_epi64(array512_1, array512_3); + } + + // Compose and store the 0/1, 2/3, 4/5, 6/7 and 16/17, 18/19, 20/21, 22/23 cols + for (int j = 0; j < 4; j++) { + array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_lo_idx, array512[j+4]); + array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_lo_idx, array512[j+12]); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + } + + // Compose and store the 8/9, 10/11, 12/13, 14/15 and 24/25, 26/27, 28/29, 30/31 cols + for (int j = 0; j < 4; j++) { + array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_hi_idx, array512[j+4]); + array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_hi_idx, array512[j+12]); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + } + + dst_addr0 += 32*8; + dst_addr1 += 32*8; } -#ifdef DEBUG_PROFILE - print_block(BF16_BLOCK_THRES_K, BF16_BLOCK_THRES_M, block_A); -#endif + if (tag_k_32x != k) { + int k_rem = k - tag_k_32x; + unsigned int tail_mask = (((unsigned int)0xffffffff) >> (32-k_rem)); + + for (int j = 0; j < m; j++) { + array512[j] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+j*lda+tag_k_32x); + } + for (int j = m; j < 16; j++) { + array512[j] = _mm512_setzero_si512(); + } + + for (int j = 0; j < 4; j++) { + int array_idx = j*4; + array512_0 = _mm512_unpacklo_epi32(array512[array_idx+0], array512[array_idx+1]); + array512_1 = _mm512_unpackhi_epi32(array512[array_idx+0], array512[array_idx+1]); + array512_2 = _mm512_unpacklo_epi32(array512[array_idx+2], array512[array_idx+3]); + array512_3 = _mm512_unpackhi_epi32(array512[array_idx+2], array512[array_idx+3]); + array512[array_idx+0] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[array_idx+1] = _mm512_unpackhi_epi64(array512_0, array512_2); + array512[array_idx+2] = _mm512_unpacklo_epi64(array512_1, array512_3); + array512[array_idx+3] = _mm512_unpackhi_epi64(array512_1, array512_3); + } + + for (int j = 0; j < 4; j++) { + array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_lo_idx, array512[j+4]); + array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_lo_idx, array512[j+12]); + array512_2 = _mm512_permutex2var_epi64(array512[j+0], permute_hi_idx, array512[j+4]); + array512_3 = _mm512_permutex2var_epi64(array512[j+8], permute_hi_idx, array512[j+12]); + array512[j+0] = array512_0; // 1st 8 pairs of col 0/1|2/3|4/5|6/7, and 1st 8 pairs of col 16/17|18/19|20/21|22/23 + array512[j+4] = array512_1; // 2nd 8 pairs of col 0/1|2/3|4/5|6/7, and 2nd 8 pairs of col 16/17|18/19|20/21|22/23 + array512[j+8] = array512_2; // 1st 8 pairs of col 8/9|10/11|12/13|14/15, and 1st 8 pairs of col 24/25|26/27|28/29|30/31 + array512[j+12] = array512_3; // 2nd 8 pairs of col 8/9|10/11|12/13|14/15, and 2nd 8 pairs of col 24/25|26/27|28/29|30/31 + } + + for (int j = 0; j < 4; j++) { + // Compose and store the 0/1 cols + array512_0 = 
_mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + } + + for (int j = 8; j < 12; j++) { + array512_0 = _mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + } + + // Compose and store 16 ~ k_rem cols + int idx_length = (k_rem + 1 - 16) >> 1; + if (idx_length > 4) { + for (int idx_k = 0; idx_k < 4; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + } + + for (int idx_k = 4; idx_k < idx_length; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+8], _mm512_extracti64x4_epi64(array512[idx_k+4], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + } + } else { + for (int idx_k = 0; idx_k < idx_length; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + } + } + } } +// COL_MAJOR_ONCOPY_KERNEL_16x32 behaves exactly the same as COL_MAJOR_ITCOPY_KERNEL_Kx16 +#define COL_MAJOR_ONCOPY_KERNEL_16x32 COL_MAJOR_ITCOPY_KERNEL_Kx16 + void COL_MAJOR_ONCOPY_KERNEL_8x32(BLASLONG k, bfloat16 * B, BLASLONG ldb, bfloat16 * block_B) { BLASLONG tag_k_32x = k & (~31); - BLASLONG idx_src_base0, idx_src_base1, idx_src_base2, idx_src_base3, idx_src_base4, idx_src_base5, idx_src_base6, idx_src_base7; - BLASLONG idx_target_base0; - idx_src_base0 = 0; - idx_src_base1 = 1*ldb; - idx_src_base2 = 2*ldb; - idx_src_base3 = 3*ldb; - idx_src_base4 = 4*ldb; - idx_src_base5 = 5*ldb; - idx_src_base6 = 6*ldb; - idx_src_base7 = 7*ldb; - idx_target_base0 = 0; + bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3, * src_addr4, * src_addr5, * src_addr6, * src_addr7; + bfloat16 * dst_addr0; + + unsigned char blend_mask = (((unsigned char)0xcc)); + __m512i permute_idx = _mm512_set_epi64(13, 12, 7, 6, 9, 8, 3, 2); + + src_addr0 = B; + src_addr1 = src_addr0 + 1*ldb; + src_addr2 = src_addr0 + 2*ldb; + src_addr3 = src_addr0 + 3*ldb; + src_addr4 = src_addr0 + 4*ldb; + src_addr5 = src_addr0 + 5*ldb; + src_addr6 = src_addr0 + 6*ldb; + src_addr7 = src_addr0 + 7*ldb; + dst_addr0 = block_B; + + __m512i array512_0, array512_1, array512_2, array512_3; + __m512i array512_way0_0, array512_way0_1, array512_way0_2, array512_way0_3; + __m512i array512_way1_0, array512_way1_1, array512_way1_2, array512_way1_3; for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) { - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*0], _mm512_loadu_si512(&B[idx_src_base0+idx_k])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*1], _mm512_loadu_si512(&B[idx_src_base1+idx_k])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*2], _mm512_loadu_si512(&B[idx_src_base2+idx_k])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*3], _mm512_loadu_si512(&B[idx_src_base3+idx_k])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*4], _mm512_loadu_si512(&B[idx_src_base4+idx_k])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*5], _mm512_loadu_si512(&B[idx_src_base5+idx_k])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*6], _mm512_loadu_si512(&B[idx_src_base6+idx_k])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*7], _mm512_loadu_si512(&B[idx_src_base7+idx_k])); - idx_target_base0 += 32*8; + array512_0 = _mm512_loadu_si512(src_addr0+idx_k); + array512_1 = 
_mm512_loadu_si512(src_addr1+idx_k); + array512_2 = _mm512_loadu_si512(src_addr2+idx_k); + array512_3 = _mm512_loadu_si512(src_addr3+idx_k); + + array512_way0_0 = _mm512_unpacklo_epi32(array512_0, array512_1); + array512_way0_1 = _mm512_unpackhi_epi32(array512_0, array512_1); + array512_way0_2 = _mm512_unpacklo_epi32(array512_2, array512_3); + array512_way0_3 = _mm512_unpackhi_epi32(array512_2, array512_3); + + array512_0 = _mm512_unpacklo_epi64(array512_way0_0, array512_way0_2); + array512_1 = _mm512_unpackhi_epi64(array512_way0_0, array512_way0_2); + array512_2 = _mm512_unpacklo_epi64(array512_way0_1, array512_way0_3); + array512_3 = _mm512_unpackhi_epi64(array512_way0_1, array512_way0_3); + + array512_way0_0 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x88); + array512_way0_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0xdd); + array512_way0_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x88); + array512_way0_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0xdd); + + array512_0 = _mm512_loadu_si512(src_addr4+idx_k); + array512_1 = _mm512_loadu_si512(src_addr5+idx_k); + array512_2 = _mm512_loadu_si512(src_addr6+idx_k); + array512_3 = _mm512_loadu_si512(src_addr7+idx_k); + + array512_way1_0 = _mm512_unpacklo_epi32(array512_0, array512_1); + array512_way1_1 = _mm512_unpackhi_epi32(array512_0, array512_1); + array512_way1_2 = _mm512_unpacklo_epi32(array512_2, array512_3); + array512_way1_3 = _mm512_unpackhi_epi32(array512_2, array512_3); + + array512_0 = _mm512_unpacklo_epi64(array512_way1_0, array512_way1_2); + array512_1 = _mm512_unpackhi_epi64(array512_way1_0, array512_way1_2); + array512_2 = _mm512_unpacklo_epi64(array512_way1_1, array512_way1_3); + array512_3 = _mm512_unpackhi_epi64(array512_way1_1, array512_way1_3); + + array512_way1_0 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x22); + array512_way1_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x77); + array512_way1_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x22); + array512_way1_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x77); + + array512_0 = _mm512_mask_blend_epi64(blend_mask, array512_way0_0, array512_way1_0); + array512_1 = _mm512_mask_blend_epi64(blend_mask, array512_way0_1, array512_way1_1); + array512_2 = _mm512_mask_blend_epi64(blend_mask, array512_way0_2, array512_way1_2); + array512_3 = _mm512_mask_blend_epi64(blend_mask, array512_way0_3, array512_way1_3); + _mm512_storeu_si512(dst_addr0, array512_0); + _mm512_storeu_si512(dst_addr0+32, array512_1); + _mm512_storeu_si512(dst_addr0+64, array512_2); + _mm512_storeu_si512(dst_addr0+96, array512_3); + + array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_idx, array512_way1_0); + array512_1 = _mm512_permutex2var_epi64(array512_way0_1, permute_idx, array512_way1_1); + array512_2 = _mm512_permutex2var_epi64(array512_way0_2, permute_idx, array512_way1_2); + array512_3 = _mm512_permutex2var_epi64(array512_way0_3, permute_idx, array512_way1_3); + _mm512_storeu_si512(dst_addr0+128, array512_0); + _mm512_storeu_si512(dst_addr0+160, array512_1); + _mm512_storeu_si512(dst_addr0+192, array512_2); + _mm512_storeu_si512(dst_addr0+224, array512_3); + + dst_addr0 += 256; } if (tag_k_32x != k) { unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(k-tag_k_32x))); __mmask32 tail_mask = *((__mmask32*) &tail_mask_value); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*0], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base0+tag_k_32x])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*1], _mm512_maskz_loadu_epi16(tail_mask, 
&B[idx_src_base1+tag_k_32x])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*2], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base2+tag_k_32x])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*3], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base3+tag_k_32x])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*4], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base4+tag_k_32x])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*5], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base5+tag_k_32x])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*6], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base6+tag_k_32x])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*7], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base7+tag_k_32x])); + array512_0 = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x); + array512_1 = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x); + array512_2 = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x); + array512_3 = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x); + + array512_way0_0 = _mm512_unpacklo_epi32(array512_0, array512_1); + array512_way0_1 = _mm512_unpackhi_epi32(array512_0, array512_1); + array512_way0_2 = _mm512_unpacklo_epi32(array512_2, array512_3); + array512_way0_3 = _mm512_unpackhi_epi32(array512_2, array512_3); + + array512_0 = _mm512_unpacklo_epi64(array512_way0_0, array512_way0_2); + array512_1 = _mm512_unpackhi_epi64(array512_way0_0, array512_way0_2); + array512_2 = _mm512_unpacklo_epi64(array512_way0_1, array512_way0_3); + array512_3 = _mm512_unpackhi_epi64(array512_way0_1, array512_way0_3); + + array512_way0_0 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x88); + array512_way0_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0xdd); + array512_way0_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x88); + array512_way0_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0xdd); + + array512_0 = _mm512_maskz_loadu_epi16(tail_mask, src_addr4+tag_k_32x); + array512_1 = _mm512_maskz_loadu_epi16(tail_mask, src_addr5+tag_k_32x); + array512_2 = _mm512_maskz_loadu_epi16(tail_mask, src_addr6+tag_k_32x); + array512_3 = _mm512_maskz_loadu_epi16(tail_mask, src_addr7+tag_k_32x); + + array512_way1_0 = _mm512_unpacklo_epi32(array512_0, array512_1); + array512_way1_1 = _mm512_unpackhi_epi32(array512_0, array512_1); + array512_way1_2 = _mm512_unpacklo_epi32(array512_2, array512_3); + array512_way1_3 = _mm512_unpackhi_epi32(array512_2, array512_3); + + array512_0 = _mm512_unpacklo_epi64(array512_way1_0, array512_way1_2); + array512_1 = _mm512_unpackhi_epi64(array512_way1_0, array512_way1_2); + array512_2 = _mm512_unpacklo_epi64(array512_way1_1, array512_way1_3); + array512_3 = _mm512_unpackhi_epi64(array512_way1_1, array512_way1_3); + + array512_way1_0 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x22); + array512_way1_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x77); + array512_way1_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x22); + array512_way1_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x77); + + + array512_0 = _mm512_mask_blend_epi64(blend_mask, array512_way0_0, array512_way1_0); + array512_1 = _mm512_mask_blend_epi64(blend_mask, array512_way0_1, array512_way1_1); + array512_2 = _mm512_mask_blend_epi64(blend_mask, array512_way0_2, array512_way1_2); + array512_3 = _mm512_mask_blend_epi64(blend_mask, array512_way0_3, array512_way1_3); + _mm512_storeu_si512(dst_addr0, array512_0); + _mm512_storeu_si512(dst_addr0+32, array512_1); + _mm512_storeu_si512(dst_addr0+64, array512_2); + 
_mm512_storeu_si512(dst_addr0+96, array512_3); + + array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_idx, array512_way1_0); + array512_1 = _mm512_permutex2var_epi64(array512_way0_1, permute_idx, array512_way1_1); + array512_2 = _mm512_permutex2var_epi64(array512_way0_2, permute_idx, array512_way1_2); + array512_3 = _mm512_permutex2var_epi64(array512_way0_3, permute_idx, array512_way1_3); + _mm512_storeu_si512(dst_addr0+128, array512_0); + _mm512_storeu_si512(dst_addr0+160, array512_1); + _mm512_storeu_si512(dst_addr0+192, array512_2); + _mm512_storeu_si512(dst_addr0+224, array512_3); + } +} + +void COL_MAJOR_ONCOPY_KERNEL_4x32(BLASLONG k, bfloat16 * B, BLASLONG ldb, bfloat16 * block_B) +{ + BLASLONG tag_k_32x = k & (~31); + + bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3; + bfloat16 * dst_addr0; + + src_addr0 = B; + src_addr1 = src_addr0 + 1*ldb; + src_addr2 = src_addr0 + 2*ldb; + src_addr3 = src_addr0 + 3*ldb; + dst_addr0 = block_B; + + __m512i array512_0, array512_1, array512_2, array512_3; + __m512i array512_way0_0, array512_way0_1, array512_way0_2, array512_way0_3; + + for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) { + array512_0 = _mm512_loadu_si512(src_addr0+idx_k); + array512_1 = _mm512_loadu_si512(src_addr1+idx_k); + array512_2 = _mm512_loadu_si512(src_addr2+idx_k); + array512_3 = _mm512_loadu_si512(src_addr3+idx_k); + + array512_way0_0 = _mm512_unpacklo_epi32(array512_0, array512_1); + array512_way0_1 = _mm512_unpackhi_epi32(array512_0, array512_1); + array512_way0_2 = _mm512_unpacklo_epi32(array512_2, array512_3); + array512_way0_3 = _mm512_unpackhi_epi32(array512_2, array512_3); + + array512_0 = _mm512_unpacklo_epi64(array512_way0_0, array512_way0_2); + array512_1 = _mm512_unpackhi_epi64(array512_way0_0, array512_way0_2); + array512_2 = _mm512_unpacklo_epi64(array512_way0_1, array512_way0_3); + array512_3 = _mm512_unpackhi_epi64(array512_way0_1, array512_way0_3); + + array512_way0_0 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x88); + array512_way0_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0xdd); + array512_way0_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x88); + array512_way0_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0xdd); + + array512_0 = _mm512_shuffle_i32x4(array512_way0_0, array512_way0_1, 0x88); + array512_1 = _mm512_shuffle_i32x4(array512_way0_2, array512_way0_3, 0x88); + array512_2 = _mm512_shuffle_i32x4(array512_way0_0, array512_way0_1, 0xdd); + array512_3 = _mm512_shuffle_i32x4(array512_way0_2, array512_way0_3, 0xdd); + + _mm512_storeu_si512(dst_addr0, array512_0); + _mm512_storeu_si512(dst_addr0+32, array512_1); + _mm512_storeu_si512(dst_addr0+64, array512_2); + _mm512_storeu_si512(dst_addr0+96, array512_3); + + dst_addr0 += 128; } -#ifdef DEBUG_PROFILE - print_block(BF16_BLOCK_THRES_N, BF16_BLOCK_THRES_K, block_B); -#endif + if (tag_k_32x != k) { + unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(k-tag_k_32x))); + __mmask32 tail_mask = *((__mmask32*) &tail_mask_value); + array512_0 = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x); + array512_1 = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x); + array512_2 = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x); + array512_3 = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x); + + array512_way0_0 = _mm512_unpacklo_epi32(array512_0, array512_1); + array512_way0_1 = _mm512_unpackhi_epi32(array512_0, array512_1); + array512_way0_2 = _mm512_unpacklo_epi32(array512_2, array512_3); + array512_way0_3 = 
_mm512_unpackhi_epi32(array512_2, array512_3); + + array512_0 = _mm512_unpacklo_epi64(array512_way0_0, array512_way0_2); + array512_1 = _mm512_unpackhi_epi64(array512_way0_0, array512_way0_2); + array512_2 = _mm512_unpacklo_epi64(array512_way0_1, array512_way0_3); + array512_3 = _mm512_unpackhi_epi64(array512_way0_1, array512_way0_3); + + array512_way0_0 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x88); + array512_way0_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0xdd); + array512_way0_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x88); + array512_way0_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0xdd); + + array512_0 = _mm512_shuffle_i32x4(array512_way0_0, array512_way0_1, 0x88); + array512_1 = _mm512_shuffle_i32x4(array512_way0_2, array512_way0_3, 0x88); + array512_2 = _mm512_shuffle_i32x4(array512_way0_0, array512_way0_1, 0xdd); + array512_3 = _mm512_shuffle_i32x4(array512_way0_2, array512_way0_3, 0xdd); + + _mm512_storeu_si512(dst_addr0, array512_0); + _mm512_storeu_si512(dst_addr0+32, array512_1); + _mm512_storeu_si512(dst_addr0+64, array512_2); + _mm512_storeu_si512(dst_addr0+96, array512_3); + } } void COL_MAJOR_ONCOPY_KERNEL_Nx32(BLASLONG n, BLASLONG k, bfloat16 * B, BLASLONG ldb, bfloat16 * block_B) { BLASLONG tag_k_32x = k & (~31); BLASLONG tag_n_2x = n & (~1); - BLASLONG idx_src_base0; - BLASLONG idx_target_base0; + + bfloat16 * src_addr0; + bfloat16 * dst_addr0; BLASLONG LDB_2x = 2*ldb; - idx_target_base0 = 0; + src_addr0 = B; + dst_addr0 = block_B; for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) { - idx_src_base0 = 0; + src_addr0 = B; for (BLASLONG idx_n = 0; idx_n < tag_n_2x; idx_n += 2) { - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*0], _mm512_loadu_si512(&B[idx_src_base0 + idx_k])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*1], _mm512_loadu_si512(&B[idx_src_base0 + ldb + idx_k])); - idx_src_base0 += LDB_2x; - idx_target_base0 += 64; + _mm512_storeu_si512(dst_addr0, _mm512_loadu_si512(src_addr0 + idx_k)); + _mm512_storeu_si512(dst_addr0 + 32, _mm512_loadu_si512(src_addr0 + ldb + idx_k)); + src_addr0 += LDB_2x; + dst_addr0 += 64; } if (tag_n_2x != n) { - _mm512_storeu_si512(&block_B[idx_target_base0], _mm512_loadu_si512(&B[idx_src_base0 + idx_k])); - idx_target_base0 += 32; + _mm512_storeu_si512(dst_addr0, _mm512_loadu_si512(src_addr0 + idx_k)); + dst_addr0 += 32; } } if (tag_k_32x != k) { unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(k-tag_k_32x))); __mmask32 tail_mask = *((__mmask32*) &tail_mask_value); - idx_src_base0 = 0; + src_addr0 = B; for (BLASLONG idx_n = 0; idx_n < tag_n_2x; idx_n += 2) { - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*0], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base0 + tag_k_32x])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*1], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base0 + ldb + tag_k_32x])); - idx_src_base0 += LDB_2x; - idx_target_base0 += 64; + _mm512_storeu_si512(dst_addr0, _mm512_maskz_loadu_epi16(tail_mask, src_addr0 + tag_k_32x)); + _mm512_storeu_si512(dst_addr0 + 32, _mm512_maskz_loadu_epi16(tail_mask, src_addr0 + ldb + tag_k_32x)); + src_addr0 += LDB_2x; + dst_addr0 += 64; } if (tag_n_2x != n) { - _mm512_storeu_si512(&block_B[idx_target_base0], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base0 + tag_k_32x])); + _mm512_storeu_si512(dst_addr0, _mm512_maskz_loadu_epi16(tail_mask, src_addr0 + tag_k_32x)); } } +} + +void COL_MAJOR_OTCOPY_KERNEL_Kx8(BLASLONG k, bfloat16 * B, BLASLONG ldb, bfloat16 * block_B) +{ + BLASLONG tag_k_2x = k & (~1); + 
unsigned char tail_mask_value = (unsigned char) 0xff; + __mmask8 tail_mask = *((__mmask8*) &tail_mask_value); -#ifdef DEBUG_PROFILE - print_block(BF16_BLOCK_THRES_N, BF16_BLOCK_THRES_K, block_B); -#endif + __m128i array128_0, array128_1, array128_2, array128_3; + + BLASLONG idx_src_base0, idx_src_base1; + BLASLONG idx_target_base0, idx_target_base1; + + BLASLONG LDA_2x = 2*ldb; + BLASLONG BF16_BLOCK_T_M_2x = 2*8; + idx_src_base0 = 0; + idx_src_base1 = ldb; + idx_target_base0 = 0; + idx_target_base1 = 8; + for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) { + array128_0 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base0]); + array128_1 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base1]); + array128_2 = _mm_unpacklo_epi16(array128_0, array128_1); + array128_3 = _mm_unpackhi_epi16(array128_0, array128_1); + _mm_storeu_epi32(&block_B[idx_target_base0], array128_2); + _mm_storeu_epi32(&block_B[idx_target_base1], array128_3); + + idx_src_base0 += LDA_2x; + idx_src_base1 += LDA_2x; + idx_target_base0 += BF16_BLOCK_T_M_2x; + idx_target_base1 += BF16_BLOCK_T_M_2x; + } + + if (tag_k_2x != k) { + __m128i ZERO128 = _mm_setzero_si128(); + array128_0 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base0]); + array128_2 = _mm_unpacklo_epi16(array128_0, ZERO128); + array128_3 = _mm_unpackhi_epi16(array128_0, ZERO128); + _mm_storeu_epi32(&block_B[idx_target_base0], array128_2); + _mm_storeu_epi32(&block_B[idx_target_base1], array128_3); + } +} + +void COL_MAJOR_OTCOPY_KERNEL_Kx8m(BLASLONG k, BLASLONG n, bfloat16 * B, BLASLONG ldb, bfloat16 * block_B) +{ + BLASLONG tag_k_2x = k & (~1); + unsigned char tail_mask = (((unsigned char)0xff) >> (8-n)); + + __m128i array128_0, array128_1, array128_2, array128_3; + + BLASLONG idx_src_base0, idx_src_base1; + BLASLONG idx_target_base0, idx_target_base1; + + BLASLONG LDA_2x = 2*ldb; + BLASLONG BF16_BLOCK_T_M_2x = 2*8; + idx_src_base0 = 0; + idx_src_base1 = ldb; + idx_target_base0 = 0; + idx_target_base1 = 8; + for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) { + array128_0 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base0]); + array128_1 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base1]); + array128_2 = _mm_unpacklo_epi16(array128_0, array128_1); + array128_3 = _mm_unpackhi_epi16(array128_0, array128_1); + _mm_storeu_epi32(&block_B[idx_target_base0], array128_2); + _mm_storeu_epi32(&block_B[idx_target_base1], array128_3); + + idx_src_base0 += LDA_2x; + idx_src_base1 += LDA_2x; + idx_target_base0 += BF16_BLOCK_T_M_2x; + idx_target_base1 += BF16_BLOCK_T_M_2x; + } + + if (tag_k_2x != k) { + __m128i ZERO128 = _mm_setzero_si128(); + array128_0 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base0]); + array128_2 = _mm_unpacklo_epi16(array128_0, ZERO128); + array128_3 = _mm_unpackhi_epi16(array128_0, ZERO128); + _mm_storeu_epi32(&block_B[idx_target_base0], array128_2); + _mm_storeu_epi32(&block_B[idx_target_base1], array128_3); + } } -// Scale matrix C while beta is not ZERO or ONE +// Scale matrix C when beta is not ZERO or ONE void sbgemm_scal_operation(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc) { - BLASLONG tag_n_Nx = N & (~3); - BLASLONG tag_n_Mx = M & (~15); + float * C_addr0 = C; + float * C_addr1 = C + ldc; + float * C_addr2 = C + ldc*2; + float * C_addr3 = C + ldc*3; BLASLONG LDC4x = ldc*4; - BLASLONG idx_base_0 = 0; - BLASLONG idx_base_1 = ldc; - BLASLONG idx_base_2 = ldc*2; - BLASLONG idx_base_3 = ldc*3; - - unsigned short tail_mask_value 
= (((unsigned short)0xffff) >> (16-M+tag_n_Mx)); - __mmask16 tail_mask = *((__mmask16*) &tail_mask_value); __m512 array_512_0, array_512_1, array_512_2, array_512_3; + __m512 BETAVECTOR = _mm512_set1_ps(beta); - __m512 BETAVECTOR = _mm512_set1_ps(beta); + if (Order == CblasRowMajor) { + blasint tmp = M; + M = N; + N = tmp; + } - if (Order == CblasColMajor) { - for (BLASLONG idx_n = 0; idx_n < tag_n_Nx; idx_n += 4) { - for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) { - array_512_0 = _mm512_loadu_ps(&C[idx_base_0+idx_m]); - array_512_1 = _mm512_loadu_ps(&C[idx_base_1+idx_m]); - array_512_2 = _mm512_loadu_ps(&C[idx_base_2+idx_m]); - array_512_3 = _mm512_loadu_ps(&C[idx_base_3+idx_m]); + BLASLONG tag_n_Nx = N & (~3); + BLASLONG tag_n_Mx = M & (~15); + unsigned short tail_mask = (((unsigned short)0xffff) >> (16-M+tag_n_Mx)); + for (BLASLONG idx_n = 0; idx_n < tag_n_Nx; idx_n += 4) { + for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) { + array_512_0 = _mm512_loadu_ps(C_addr0 + idx_m); + array_512_1 = _mm512_loadu_ps(C_addr1 + idx_m); + array_512_2 = _mm512_loadu_ps(C_addr2 + idx_m); + array_512_3 = _mm512_loadu_ps(C_addr3 + idx_m); + + array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0); + array_512_1 = _mm512_mul_ps(BETAVECTOR, array_512_1); + array_512_2 = _mm512_mul_ps(BETAVECTOR, array_512_2); + array_512_3 = _mm512_mul_ps(BETAVECTOR, array_512_3); + + _mm512_storeu_ps(C_addr0 + idx_m, array_512_0); + _mm512_storeu_ps(C_addr1 + idx_m, array_512_1); + _mm512_storeu_ps(C_addr2 + idx_m, array_512_2); + _mm512_storeu_ps(C_addr3 + idx_m, array_512_3); + } - array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0); - array_512_1 = _mm512_mul_ps(BETAVECTOR, array_512_1); - array_512_2 = _mm512_mul_ps(BETAVECTOR, array_512_2); - array_512_3 = _mm512_mul_ps(BETAVECTOR, array_512_3); - - _mm512_storeu_ps(&C[idx_base_0+idx_m], array_512_0); - _mm512_storeu_ps(&C[idx_base_1+idx_m], array_512_1); - _mm512_storeu_ps(&C[idx_base_2+idx_m], array_512_2); - _mm512_storeu_ps(&C[idx_base_3+idx_m], array_512_3); - } + if (tag_n_Mx != M) { + array_512_0 = _mm512_maskz_loadu_ps(tail_mask, C_addr0 + tag_n_Mx); + array_512_1 = _mm512_maskz_loadu_ps(tail_mask, C_addr1 + tag_n_Mx); + array_512_2 = _mm512_maskz_loadu_ps(tail_mask, C_addr2 + tag_n_Mx); + array_512_3 = _mm512_maskz_loadu_ps(tail_mask, C_addr3 + tag_n_Mx); + + array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0); + array_512_1 = _mm512_mul_ps(BETAVECTOR, array_512_1); + array_512_2 = _mm512_mul_ps(BETAVECTOR, array_512_2); + array_512_3 = _mm512_mul_ps(BETAVECTOR, array_512_3); + + _mm512_mask_storeu_ps(C_addr0 + tag_n_Mx, tail_mask, array_512_0); + _mm512_mask_storeu_ps(C_addr1 + tag_n_Mx, tail_mask, array_512_1); + _mm512_mask_storeu_ps(C_addr2 + tag_n_Mx, tail_mask, array_512_2); + _mm512_mask_storeu_ps(C_addr3 + tag_n_Mx, tail_mask, array_512_3); + } - if (tag_n_Mx != M) { - array_512_0 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_0+tag_n_Mx]); - array_512_1 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_1+tag_n_Mx]); - array_512_2 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_2+tag_n_Mx]); - array_512_3 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_3+tag_n_Mx]); + C_addr0 += LDC4x; + C_addr1 += LDC4x; + C_addr2 += LDC4x; + C_addr3 += LDC4x; + } + if (tag_n_Nx != N) { + for (BLASLONG idx_n = tag_n_Nx; idx_n < N; idx_n++) { + for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) { + array_512_0 = _mm512_loadu_ps(C_addr0 + idx_m); array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0); - array_512_1 = _mm512_mul_ps(BETAVECTOR, 
array_512_1); - array_512_2 = _mm512_mul_ps(BETAVECTOR, array_512_2); - array_512_3 = _mm512_mul_ps(BETAVECTOR, array_512_3); - - _mm512_mask_storeu_ps(&C[idx_base_0+tag_n_Mx], tail_mask, array_512_0); - _mm512_mask_storeu_ps(&C[idx_base_1+tag_n_Mx], tail_mask, array_512_1); - _mm512_mask_storeu_ps(&C[idx_base_2+tag_n_Mx], tail_mask, array_512_2); - _mm512_mask_storeu_ps(&C[idx_base_3+tag_n_Mx], tail_mask, array_512_3); + _mm512_storeu_ps(C_addr0 + idx_m, array_512_0); } - idx_base_0 += LDC4x; - idx_base_1 += LDC4x; - idx_base_2 += LDC4x; - idx_base_3 += LDC4x; - } - - if (tag_n_Nx != N) { - for (BLASLONG idx_n = tag_n_Nx; idx_n < N; idx_n++) { - for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) { - array_512_0 = _mm512_loadu_ps(&C[idx_base_0+idx_m]); - array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0); - _mm512_storeu_ps(&C[idx_base_0+idx_m], array_512_0); - } - - if (tag_n_Mx != M) { - array_512_0 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_0+tag_n_Mx]); - array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0); - _mm512_mask_storeu_ps(&C[idx_base_0+tag_n_Mx], tail_mask, array_512_0); - } - idx_base_0 += ldc; + if (tag_n_Mx != M) { + array_512_0 = _mm512_maskz_loadu_ps(tail_mask, C_addr0 + tag_n_Mx); + array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0); + _mm512_mask_storeu_ps(C_addr0 + tag_n_Mx, tail_mask, array_512_0); } + C_addr0 += ldc; } - } else { - } } -// Scale matrix C while beta is not ZERO or ONE +// Zero C matrix when Beta is 0 void sbgemm_zero_operation(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, float *C, OPENBLAS_CONST blasint ldc) { - BLASLONG tag_n_Nx = N & (~3); - BLASLONG tag_n_Mx = M & (~15); + float * C_addr0 = C; + float * C_addr1 = C + ldc; + float * C_addr2 = C + ldc*2; + float * C_addr3 = C + ldc*3; BLASLONG LDC4x = ldc*4; - BLASLONG idx_base_0 = 0; - BLASLONG idx_base_1 = ldc; - BLASLONG idx_base_2 = ldc*2; - BLASLONG idx_base_3 = ldc*3; - - unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-M+tag_n_Mx)); - __mmask16 tail_mask = *((__mmask16*) &tail_mask_value); __m512 ZEROVECTOR = _mm512_setzero_ps(); - if (Order == CblasColMajor) { - for (BLASLONG idx_n = 0; idx_n < tag_n_Nx; idx_n += 4) { - for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) { - _mm512_storeu_ps(&C[idx_base_0+idx_m], ZEROVECTOR); - _mm512_storeu_ps(&C[idx_base_1+idx_m], ZEROVECTOR); - _mm512_storeu_ps(&C[idx_base_2+idx_m], ZEROVECTOR); - _mm512_storeu_ps(&C[idx_base_3+idx_m], ZEROVECTOR); - } + if (Order == CblasRowMajor) { + blasint tmp = M; + M = N; + N = tmp; + } - if (tag_n_Mx != M) { - _mm512_mask_storeu_ps(&C[idx_base_0+tag_n_Mx], tail_mask, ZEROVECTOR); - _mm512_mask_storeu_ps(&C[idx_base_1+tag_n_Mx], tail_mask, ZEROVECTOR); - _mm512_mask_storeu_ps(&C[idx_base_2+tag_n_Mx], tail_mask, ZEROVECTOR); - _mm512_mask_storeu_ps(&C[idx_base_3+tag_n_Mx], tail_mask, ZEROVECTOR); - } + BLASLONG tag_n_Nx = N & (~3); + BLASLONG tag_n_Mx = M & (~15); + unsigned short tail_mask = (((unsigned short)0xffff) >> (16-M+tag_n_Mx)); + for (BLASLONG idx_n = 0; idx_n < tag_n_Nx; idx_n += 4) { + for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) { + _mm512_storeu_ps(C_addr0 + idx_m, ZEROVECTOR); + _mm512_storeu_ps(C_addr1 + idx_m, ZEROVECTOR); + _mm512_storeu_ps(C_addr2 + idx_m, ZEROVECTOR); + _mm512_storeu_ps(C_addr3 + idx_m, ZEROVECTOR); + } - idx_base_0 += LDC4x; - idx_base_1 += LDC4x; - idx_base_2 += LDC4x; - idx_base_3 += LDC4x; + if (tag_n_Mx != M) { + _mm512_mask_storeu_ps(C_addr0 + tag_n_Mx, tail_mask, ZEROVECTOR); + 
_mm512_mask_storeu_ps(C_addr1 + tag_n_Mx, tail_mask, ZEROVECTOR); + _mm512_mask_storeu_ps(C_addr2 + tag_n_Mx, tail_mask, ZEROVECTOR); + _mm512_mask_storeu_ps(C_addr3 + tag_n_Mx, tail_mask, ZEROVECTOR); } - if (tag_n_Nx != N) { - for (BLASLONG idx_n = tag_n_Nx; idx_n < N; idx_n++) { - for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) { - _mm512_storeu_ps(&C[idx_base_0+idx_m], ZEROVECTOR); - } + C_addr0 += LDC4x; + C_addr1 += LDC4x; + C_addr2 += LDC4x; + C_addr3 += LDC4x; + } - if (tag_n_Mx != M) { - _mm512_mask_storeu_ps(&C[idx_base_0+tag_n_Mx], tail_mask, ZEROVECTOR); - } - idx_base_0 += ldc; + if (tag_n_Nx != N) { + for (BLASLONG idx_n = tag_n_Nx; idx_n < N; idx_n++) { + for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) { + _mm512_storeu_ps(C_addr0 + idx_m, ZEROVECTOR); } - } - } else { + if (tag_n_Mx != M) { + _mm512_mask_storeu_ps(C_addr0 + tag_n_Mx, tail_mask, ZEROVECTOR); + } + C_addr0 += ldc; + } } -} \ No newline at end of file +} diff --git a/kernel/x86_64/sbgemm_microk_cooperlake_template.c b/kernel/x86_64/sbgemm_microk_cooperlake_template.c index dd4cb44..c715958 100644 --- a/kernel/x86_64/sbgemm_microk_cooperlake_template.c +++ b/kernel/x86_64/sbgemm_microk_cooperlake_template.c @@ -2,45 +2,115 @@ #include "bf16_common_macros.h" #include +/* These macros are needed and should be placed at the right place +#define BF16_BLOCK_STEP_N 8 +#define BF16_BLOCK_THRES_K 1024 +#define BF16_BLOCK_THRES_M 32 +#define BF16_BLOCK_THRES_N 1024 + +#define A(i,j) A[(i)*lda+(j)] +#define B(i,j) B[(i)*ldb+(j)] +#define C(i,j) C[(i)*ldc+(j)] + +#define ONE 1.e0f +#define ZERO 0.e0f +*/ + #undef STORE16_COMPLETE_RESULT #undef STORE16_MASK_COMPLETE_RESULT -#undef SBGEMM_BLOCK_KERNEL_32x8x32 -#undef SBGEMM_BLOCK_KERNEL_16x8x32 -#undef SBGEMM_BLOCK_KERNEL_32xNx32 -#undef SBGEMM_BLOCK_KERNEL_16xNx32 -#undef SBGEMM_BLOCKING_KERNEL_2 +#undef SBGEMM_BLOCK_KERNEL_NN_32x8xK +#undef SBGEMM_BLOCK_KERNEL_NN_16x8xK +#undef SBGEMM_BLOCK_KERNEL_NN_32xNx32 +#undef SBGEMM_BLOCK_KERNEL_NN_16xNx32 +#undef SBGEMM_BLOCK_KERNEL_NT_32x8xK +#undef SBGEMM_BLOCK_KERNEL_NT_16x8xK +#undef SBGEMM_BLOCK_KERNEL_NT_32xNxK +#undef SBGEMM_BLOCK_KERNEL_NT_16xNxK +#undef SBGEMM_BLOCK_KERNEL_TN_32x8xK +#undef SBGEMM_BLOCK_KERNEL_TN_16x8xK +#undef SBGEMM_BLOCK_KERNEL_TN_32xNx32 +#undef SBGEMM_BLOCK_KERNEL_TN_16xNx32 +#undef SBGEMM_BLOCK_KERNEL_TT_32x8xK +#undef SBGEMM_BLOCK_KERNEL_TT_16x8xK +#undef SBGEMM_BLOCK_KERNEL_TT_32xNxK +#undef SBGEMM_BLOCK_KERNEL_TT_16xNxK +#undef SBGEMM_BLOCKING_KERNEL_NN +#undef SBGEMM_BLOCKING_KERNEL_NT +#undef SBGEMM_BLOCKING_KERNEL_TN +#undef SBGEMM_BLOCKING_KERNEL_TT #ifndef ONE_ALPHA // ALPHA is not ONE - #define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ALPHA_ONE - #define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE - #define SBGEMM_BLOCK_KERNEL_32x8x32 sbgemm_block_kernel_32x8x32_alpha - #define SBGEMM_BLOCK_KERNEL_16x8x32 sbgemm_block_kernel_16x8x32_alpha - #define SBGEMM_BLOCK_KERNEL_32xNx32 sbgemm_block_kernel_32xNx32_alpha - #define SBGEMM_BLOCK_KERNEL_16xNx32 sbgemm_block_kernel_16xNx32_alpha - #define SBGEMM_BLOCKING_KERNEL_2 sbgemm_blocking_kernel_2_alpha + #define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ALPHA_ONE + #define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE + + #define SBGEMM_BLOCK_KERNEL_NN_32x8xK sbgemm_block_kernel_nn_32x8xK_alpha + #define SBGEMM_BLOCK_KERNEL_NN_16x8xK sbgemm_block_kernel_nn_16x8xK_alpha + #define SBGEMM_BLOCK_KERNEL_NN_32xNx32 sbgemm_block_kernel_nn_32xNx32_alpha + #define 
SBGEMM_BLOCK_KERNEL_NN_16xNx32 sbgemm_block_kernel_nn_16xNx32_alpha + + #define SBGEMM_BLOCK_KERNEL_NT_32x8xK SBGEMM_BLOCK_KERNEL_NN_32x8xK + #define SBGEMM_BLOCK_KERNEL_NT_16x8xK SBGEMM_BLOCK_KERNEL_NN_16x8xK + #define SBGEMM_BLOCK_KERNEL_NT_32xNxK sbgemm_block_kernel_nt_32xNxK_alpha + #define SBGEMM_BLOCK_KERNEL_NT_16xNxK sbgemm_block_kernel_nt_16xNxK_alpha + + #define SBGEMM_BLOCK_KERNEL_TN_32x8xK sbgemm_block_kernel_tn_32x8xK_alpha + #define SBGEMM_BLOCK_KERNEL_TN_16x8xK sbgemm_block_kernel_tn_16x8xK_alpha + #define SBGEMM_BLOCK_KERNEL_TN_32xNx32 sbgemm_block_kernel_tn_32xNx32_alpha + #define SBGEMM_BLOCK_KERNEL_TN_16xNx32 sbgemm_block_kernel_tn_16xNx32_alpha + + #define SBGEMM_BLOCK_KERNEL_TT_32x8xK SBGEMM_BLOCK_KERNEL_TN_32x8xK + #define SBGEMM_BLOCK_KERNEL_TT_16x8xK SBGEMM_BLOCK_KERNEL_TN_16x8xK + #define SBGEMM_BLOCK_KERNEL_TT_32xNxK sbgemm_block_kernel_tt_32xNxK_alpha + #define SBGEMM_BLOCK_KERNEL_TT_16xNxK sbgemm_block_kernel_tt_16xNxK_alpha + + #define SBGEMM_BLOCKING_KERNEL_NN sbgemm_blocking_kernel_nn_alpha + #define SBGEMM_BLOCKING_KERNEL_NT sbgemm_blocking_kernel_nt_alpha + #define SBGEMM_BLOCKING_KERNEL_TN sbgemm_blocking_kernel_tn_alpha + #define SBGEMM_BLOCKING_KERNEL_TT sbgemm_blocking_kernel_tt_alpha #else // ALPHA is ONE - #define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ONE_ONE - #define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ONE_ONE - #define SBGEMM_BLOCK_KERNEL_32x8x32 sbgemm_block_kernel_32x8x32_one - #define SBGEMM_BLOCK_KERNEL_16x8x32 sbgemm_block_kernel_16x8x32_one - #define SBGEMM_BLOCK_KERNEL_32xNx32 sbgemm_block_kernel_32xNx32_one - #define SBGEMM_BLOCK_KERNEL_16xNx32 sbgemm_block_kernel_16xNx32_one - #define SBGEMM_BLOCKING_KERNEL_2 sbgemm_blocking_kernel_2_one + #define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ONE_ONE + #define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ONE_ONE + + #define SBGEMM_BLOCK_KERNEL_NN_32x8xK sbgemm_block_kernel_nn_32x8xK_one + #define SBGEMM_BLOCK_KERNEL_NN_16x8xK sbgemm_block_kernel_nn_16x8xK_one + #define SBGEMM_BLOCK_KERNEL_NN_32xNx32 sbgemm_block_kernel_nn_32xNx32_one + #define SBGEMM_BLOCK_KERNEL_NN_16xNx32 sbgemm_block_kernel_nn_16xNx32_one + + #define SBGEMM_BLOCK_KERNEL_NT_32x8xK SBGEMM_BLOCK_KERNEL_NN_32x8xK + #define SBGEMM_BLOCK_KERNEL_NT_16x8xK SBGEMM_BLOCK_KERNEL_NN_16x8xK + #define SBGEMM_BLOCK_KERNEL_NT_32xNxK sbgemm_block_kernel_nt_32xNxK_one + #define SBGEMM_BLOCK_KERNEL_NT_16xNxK sbgemm_block_kernel_nt_16xNxK_one + + #define SBGEMM_BLOCK_KERNEL_TN_32x8xK sbgemm_block_kernel_tn_32x8xK_one + #define SBGEMM_BLOCK_KERNEL_TN_16x8xK sbgemm_block_kernel_tn_16x8xK_one + #define SBGEMM_BLOCK_KERNEL_TN_32xNx32 sbgemm_block_kernel_tn_32xNx32_one + #define SBGEMM_BLOCK_KERNEL_TN_16xNx32 sbgemm_block_kernel_tn_16xNx32_one + + #define SBGEMM_BLOCK_KERNEL_TT_32x8xK SBGEMM_BLOCK_KERNEL_TN_32x8xK + #define SBGEMM_BLOCK_KERNEL_TT_16x8xK SBGEMM_BLOCK_KERNEL_TN_16x8xK + #define SBGEMM_BLOCK_KERNEL_TT_32xNxK sbgemm_block_kernel_tt_32xNxK_one + #define SBGEMM_BLOCK_KERNEL_TT_16xNxK sbgemm_block_kernel_tt_16xNxK_one + + #define SBGEMM_BLOCKING_KERNEL_NN sbgemm_blocking_kernel_nn_one + #define SBGEMM_BLOCKING_KERNEL_NT sbgemm_blocking_kernel_nt_one + #define SBGEMM_BLOCKING_KERNEL_TN sbgemm_blocking_kernel_tn_one + #define SBGEMM_BLOCKING_KERNEL_TT sbgemm_blocking_kernel_tt_one #endif +extern bfloat16 * block_A; +extern bfloat16 * block_B; +/* --------------------------------------------- NN kernels ------------------------------------------ */ // SBGEMM Kernel for 16> (16-m)); - 
__mmask16 tail_mask = *((__mmask16*) &tail_mask_value); + unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m)); result_512_0 = _mm512_shuffle_f32x4(result_512_0, result_512_0, 0xd8); result_512_1 = _mm512_shuffle_f32x4(result_512_1, result_512_1, 0xd8); result_512_2 = _mm512_shuffle_f32x4(result_512_2, result_512_2, 0xd8); result_512_3 = _mm512_shuffle_f32x4(result_512_3, result_512_3, 0xd8); - STORE16_MASK_COMPLETE_RESULT(result_512_0, (&C[ldc*0]), tail_mask) - STORE16_MASK_COMPLETE_RESULT(result_512_1, (&C[ldc*1]), tail_mask) - STORE16_MASK_COMPLETE_RESULT(result_512_2, (&C[ldc*2]), tail_mask) - STORE16_MASK_COMPLETE_RESULT(result_512_3, (&C[ldc*3]), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_0, (C_addr), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_1, (C_addr + ldc*1), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3), tail_mask) result_512_4 = _mm512_shuffle_f32x4(result_512_4, result_512_4, 0xd8); result_512_5 = _mm512_shuffle_f32x4(result_512_5, result_512_5, 0xd8); result_512_6 = _mm512_shuffle_f32x4(result_512_6, result_512_6, 0xd8); result_512_7 = _mm512_shuffle_f32x4(result_512_7, result_512_7, 0xd8); - STORE16_MASK_COMPLETE_RESULT(result_512_4, (&C[ldc*4]), tail_mask) - STORE16_MASK_COMPLETE_RESULT(result_512_5, (&C[ldc*5]), tail_mask) - STORE16_MASK_COMPLETE_RESULT(result_512_6, (&C[ldc*6]), tail_mask) - STORE16_MASK_COMPLETE_RESULT(result_512_7, (&C[ldc*7]), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7), tail_mask) } else { result_512_0 = _mm512_shuffle_f32x4(result_512_0, result_512_0, 0xd8); result_512_1 = _mm512_shuffle_f32x4(result_512_1, result_512_1, 0xd8); result_512_2 = _mm512_shuffle_f32x4(result_512_2, result_512_2, 0xd8); result_512_3 = _mm512_shuffle_f32x4(result_512_3, result_512_3, 0xd8); - STORE16_COMPLETE_RESULT(result_512_0, (&C[ldc*0])) - STORE16_COMPLETE_RESULT(result_512_1, (&C[ldc*1])) - STORE16_COMPLETE_RESULT(result_512_2, (&C[ldc*2])) - STORE16_COMPLETE_RESULT(result_512_3, (&C[ldc*3])) + STORE16_COMPLETE_RESULT(result_512_0, (C_addr)) + STORE16_COMPLETE_RESULT(result_512_1, (C_addr + ldc*1)) + STORE16_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2)) + STORE16_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3)) result_512_4 = _mm512_shuffle_f32x4(result_512_4, result_512_4, 0xd8); result_512_5 = _mm512_shuffle_f32x4(result_512_5, result_512_5, 0xd8); result_512_6 = _mm512_shuffle_f32x4(result_512_6, result_512_6, 0xd8); result_512_7 = _mm512_shuffle_f32x4(result_512_7, result_512_7, 0xd8); - STORE16_COMPLETE_RESULT(result_512_4, (&C[ldc*4])) - STORE16_COMPLETE_RESULT(result_512_5, (&C[ldc*5])) - STORE16_COMPLETE_RESULT(result_512_6, (&C[ldc*6])) - STORE16_COMPLETE_RESULT(result_512_7, (&C[ldc*7])) + STORE16_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4)) + STORE16_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5)) + STORE16_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6)) + STORE16_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7)) } } // SBGEMM Kernel for 16> (32-m)); - __mmask16 tail_mask = *((__mmask16*) &tail_mask_value); + unsigned short tail_mask = (((unsigned short)0xffff) >> (32-m)); for (int i = 0; i < n; i++) { result_512_tmp_0 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base0, 
result_512[i+8]); result_512_tmp_1 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base1, result_512[i+8]); - STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*i])) - STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*i+16]), tail_mask) + STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*i)) + STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*i + 16), tail_mask) } } else { for (int i = 0; i < n; i++) { result_512_tmp_0 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base0, result_512[i+8]); result_512_tmp_1 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base1, result_512[i+8]); - STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*i])) - STORE16_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*i+16])) + STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*i)) + STORE16_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*i + 16)) } } } // SBGEMM Kernel for 16<=M, N<8, K can be any number, but the processing will take 32 as a base #ifndef ONE_ALPHA // ALPHA is not ONE -void sbgemm_block_kernel_16xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) +void sbgemm_block_kernel_nn_16xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) #else // ALPHA is ONE -void sbgemm_block_kernel_16xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) +void sbgemm_block_kernel_nn_16xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) #endif { + bfloat16 * A_addr = A; + bfloat16 * B_addr = B; + float * C_addr = C; + int SHUFFLE_MAGIC_NO = 0x39; BLASLONG tag_k_32x = k & (~31); - BLASLONG idxB_base = 0; - BLASLONG width = 32; #ifndef ONE_ALPHA __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); @@ -432,21 +484,49 @@ void sbgemm_block_kernel_16xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float a result_512[i+1] = _mm512_setzero_ps(); } - for (BLASLONG idx_k = 0; idx_k < k; idx_k += 32) { + for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) { // Load B with unroll n - for (int i = 0; i < n; i ++) { - arrayB_512[i] = _mm512_loadu_si512(&B[idxB_base]); - idxB_base += 32; + for (int i = 0; i < n; i++) { + arrayB_512[i] = _mm512_loadu_si512(B_addr); + B_addr += 32; + } + + for (BLASLONG idx = 0; idx < 32;) { + // Each two rows are a group for 32-pair bf16 elements + // Load two rows into a 512 register + arrayA_512 = _mm512_loadu_si512(A_addr); + A_addr += 32; + + for (int i = 0; i < n; i ++) { + result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i]))); + arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO); + } + + idx += 2; + // Every 4 loops we need to switch to next 128 bits of arrayB registers + if ((idx & (~7)) == idx) { + for (int i = 0; i < n; i++) { + arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO); + } + } } + } - if (idx_k == tag_k_32x) {width = k - tag_k_32x;} + if (tag_k_32x != k) { + // Load B with unroll n + for (int i = 0; i < n; i++) { + arrayB_512[i] = _mm512_loadu_si512(B_addr); + B_addr += 32; + } + BLASLONG width = k - tag_k_32x; for (BLASLONG idx = 0; idx < width;) { // Each two rows are a group for 32-pair bf16 elements // Load two rows into a 512 register - arrayA_512 = _mm512_loadu_si512(&A[idx<<4]); + arrayA_512 = _mm512_loadu_si512(A_addr); + A_addr += 32; - for (int i = 0; i < n; i ++) { + for (int i = 0; i < n; i++) { 
result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i]))); arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO); } @@ -462,23 +542,24 @@ void sbgemm_block_kernel_16xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float a } if (m != 16) { - unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-m)); - __mmask16 tail_mask = *((__mmask16*) &tail_mask_value); + unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m)); for (int i = 0; i < n; i++) { result_512[i] = _mm512_shuffle_f32x4(result_512[i], result_512[i], 0xd8); - STORE16_MASK_COMPLETE_RESULT(result_512[i], (&C[ldc*i]), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i), tail_mask) } } else { for (int i = 0; i < n; i++) { result_512[i] = _mm512_shuffle_f32x4(result_512[i], result_512[i], 0xd8); - STORE16_COMPLETE_RESULT(result_512[i], (&C[ldc*i])) + STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i)) } } } + + #ifndef ONE_ALPHA // ALPHA is not ONE -void sbgemm_blocking_kernel_2_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B) +void sbgemm_blocking_kernel_nn_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B) #else // ALPHA is ONE -void sbgemm_blocking_kernel_2_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B) +void sbgemm_blocking_kernel_nn_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B) #endif { BLASLONG m_step, n_step, k_step, k_step_round32; @@ -499,63 +580,52 @@ void sbgemm_blocking_kernel_2_one(blasint M, blasint N, blasint K, float alpha, while (n_from < N) { for (BLASLONG idx_k = 0; idx_k < K;) { // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ... - COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, &A(idx_k, 0), lda, block_A); - // TODO: MT + COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, 32, &A(idx_k, 0), lda, block_A); for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... 
COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32); - SBGEMM_BLOCK_KERNEL_32x8x32(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); + SBGEMM_BLOCK_KERNEL_NN_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); } if (tag_n_Nx != n_to) { n_step = n_to - tag_n_Nx; COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); - SBGEMM_BLOCK_KERNEL_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + SBGEMM_BLOCK_KERNEL_NN_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); } for (BLASLONG idx_m = BF16_BLOCK_THRES_M; idx_m < tag_m_Nx; idx_m += BF16_BLOCK_THRES_M) { - COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, &A(idx_k, idx_m), lda, block_A); + COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, 32, &A(idx_k, idx_m), lda, block_A); for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { - SBGEMM_BLOCK_KERNEL_32x8x32(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc); + SBGEMM_BLOCK_KERNEL_NN_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc); } if (tag_n_Nx != n_to) { n_step = n_to - tag_n_Nx; - SBGEMM_BLOCK_KERNEL_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc); + SBGEMM_BLOCK_KERNEL_NN_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc); } } if (tag_m_Nx != M) { m_step = M - tag_m_Nx; if (m_step > 16) { - COL_MAJOR_INCOPY_KERNEL_Kx32m(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A); - for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { - SBGEMM_BLOCK_KERNEL_32x8x32(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc); - } - - if (tag_n_Nx != n_to) { - n_step = n_to - tag_n_Nx; - SBGEMM_BLOCK_KERNEL_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc); - } - } else if (m_step == 16) { - COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A); + COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A); for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { - SBGEMM_BLOCK_KERNEL_16x8x32(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc); + SBGEMM_BLOCK_KERNEL_NN_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc); } if (tag_n_Nx != n_to) { n_step = n_to - tag_n_Nx; - SBGEMM_BLOCK_KERNEL_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc); + SBGEMM_BLOCK_KERNEL_NN_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc); } } else { - COL_MAJOR_INCOPY_KERNEL_Kx16m(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A); + COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A); for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { - SBGEMM_BLOCK_KERNEL_16x8x32(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc); + SBGEMM_BLOCK_KERNEL_NN_16x8xK(m_step, 
k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc); } if (tag_n_Nx != n_to) { n_step = n_to - tag_n_Nx; - SBGEMM_BLOCK_KERNEL_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc); + SBGEMM_BLOCK_KERNEL_NN_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc); } } } @@ -573,22 +643,274 @@ void sbgemm_blocking_kernel_2_one(blasint M, blasint N, blasint K, float alpha, tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); } } else { - m_step = M - tag_m_Nx; + m_step = M; + if (m_step > 16) { + while (n_from < N) { + for (BLASLONG idx_k = 0; idx_k < K;) { + // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ... + COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, m_step, &A(idx_k, 0), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... + COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_NN_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_NN_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + } + + idx_k += k_step; + k_step = K - idx_k; + k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + } + n_from = n_to; + n_to += BF16_BLOCK_THRES_N; + n_to = (n_to > N) ? N : n_to; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + } + } else { + while (n_from < N) { + for (BLASLONG idx_k = 0; idx_k < K;) { + COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, 0), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... + COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_NN_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_NN_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + } + + idx_k += k_step; + k_step = K - idx_k; + k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + } + n_from = n_to; + n_to += BF16_BLOCK_THRES_N; + n_to = (n_to > N) ? 
N : n_to; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + } + } + } +} +/* ----------------------------------------- End of NN kernels --------------------------------------- */ + +/* --------------------------------------------- NT kernels ------------------------------------------ */ +// SBGEMM Kernel for 16> (32-m)); + for (int i = 0; i < n; i ++) { + result_512_tmp_0 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base0, result_512[i+8]); + result_512_tmp_1 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base1, result_512[i+8]); + STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*i)) + STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*i + 16), tail_mask) + } + } else { + for (int i = 0; i < n; i ++) { + result_512_tmp_0 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base0, result_512[i+8]); + result_512_tmp_1 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base1, result_512[i+8]); + STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*i)) + STORE16_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*i + 16)) + } + } +} + +// SBGEMM Kernel for M<=16, N<8, K can be any number +#ifndef ONE_ALPHA // ALPHA is not ONE +void sbgemm_block_kernel_nt_16xNxK_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) +#else // ALPHA is ONE +void sbgemm_block_kernel_nt_16xNxK_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) +#endif +{ + bfloat16 * A_addr = A; + bfloat16 * B_addr = B; + float * C_addr = C; + +#ifndef ONE_ALPHA + __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); +#endif + + __m512i arrayA_512_0; + __m512i arrayB_512[8]; + __m512 result_512[8]; + + result_512[0] = _mm512_setzero_ps(); + result_512[1] = _mm512_setzero_ps(); + result_512[2] = _mm512_setzero_ps(); + result_512[3] = _mm512_setzero_ps(); + result_512[4] = _mm512_setzero_ps(); + result_512[5] = _mm512_setzero_ps(); + result_512[6] = _mm512_setzero_ps(); + result_512[7] = _mm512_setzero_ps(); + + for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) { + // Each two rows are a group for 16-pair bf16 elements + // Load two rows into a 512 register + arrayA_512_0 = _mm512_loadu_si512(A_addr); + A_addr += 32; + + for (int i = 0; i < n; i ++) { + _MM512_BROADCASTD_EPI32(B_addr + i*2, arrayB_512[i]); + } + B_addr += 16; + + for (int i = 0; i < n; i ++) { + result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512_0, (__m512bh) arrayB_512[i]); + } + } + + if (m != 16) { + unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m)); + for (int i = 0; i < n; i++) { + result_512[i] = _mm512_shuffle_f32x4(result_512[i], result_512[i], 0xd8); + STORE16_MASK_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i), tail_mask) + } + } else { + for (int i = 0; i < n; i++) { + result_512[i] = _mm512_shuffle_f32x4(result_512[i], result_512[i], 0xd8); + STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i)) + } + } +} + +#ifndef ONE_ALPHA // ALPHA is not ONE +void sbgemm_blocking_kernel_nt_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B) +#else // ALPHA is ONE +void sbgemm_blocking_kernel_nt_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B) +#endif +{ + BLASLONG m_step, n_step, k_step, k_step_round32; + BLASLONG tag_m_Nx = M & (~(BF16_BLOCK_THRES_M-1)); + + BLASLONG n_from, n_to; + BLASLONG 
tag_n_Nx; + + n_from = 0; + n_to = (BF16_BLOCK_THRES_N > N) ? N : BF16_BLOCK_THRES_N; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + + k_step = (K > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : K; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + + if (M >= BF16_BLOCK_THRES_M) { while (n_from < N) { for (BLASLONG idx_k = 0; idx_k < K;) { // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ... - COL_MAJOR_INCOPY_KERNEL_Kx32m(k_step, m_step, &A(idx_k, 0), lda, block_A); - // TODO: MT + COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, 32, &A(idx_k, 0), lda, block_A); for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... - COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32); - SBGEMM_BLOCK_KERNEL_32x8x32(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); + COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_NT_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); } if (tag_n_Nx != n_to) { n_step = n_to - tag_n_Nx; - COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); - SBGEMM_BLOCK_KERNEL_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_NT_32xNxK(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + } + + for (BLASLONG idx_m = BF16_BLOCK_THRES_M; idx_m < tag_m_Nx; idx_m += BF16_BLOCK_THRES_M) { + COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, 32, &A(idx_k, idx_m), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + SBGEMM_BLOCK_KERNEL_NT_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + SBGEMM_BLOCK_KERNEL_NT_32xNxK(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc); + } + } + + if (tag_m_Nx != M) { + m_step = M - tag_m_Nx; + if (m_step > 16) { + COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + SBGEMM_BLOCK_KERNEL_NT_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + SBGEMM_BLOCK_KERNEL_NT_32xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc); + } + } else { + COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + SBGEMM_BLOCK_KERNEL_NT_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + SBGEMM_BLOCK_KERNEL_NT_16xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc); + } + } } idx_k += k_step; @@ 
-597,13 +919,884 @@ void sbgemm_blocking_kernel_2_one(blasint M, blasint N, blasint K, float alpha, k_step_round32 = k_step & (~31); k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; } + n_from = n_to; n_to += BF16_BLOCK_THRES_N; n_to = (n_to > N) ? N : n_to; tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); } + } else { + m_step = M; + if (m_step > 16) { + while (n_from < N) { + for (BLASLONG idx_k = 0; idx_k < K;) { + // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ... + COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, m_step, &A(idx_k, 0), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... + COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_NT_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_NT_32xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + } + + idx_k += k_step; + k_step = K - idx_k; + k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + } + n_from = n_to; + n_to += BF16_BLOCK_THRES_N; + n_to = (n_to > N) ? N : n_to; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + } + } else { + while (n_from < N) { + for (BLASLONG idx_k = 0; idx_k < K;) { + // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ... + COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, 0), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... + COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_NT_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_NT_16xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + } + + idx_k += k_step; + k_step = K - idx_k; + k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + } + n_from = n_to; + n_to += BF16_BLOCK_THRES_N; + n_to = (n_to > N) ? 
N : n_to; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + } + } } } +/* ----------------------------------------- End of NT kernels --------------------------------------- */ + +/* --------------------------------------------- TN kernels ------------------------------------------ */ +// SBGEMM Kernel for 16> (32-m)); + __mmask16 tail_mask = *((__mmask16*) &tail_mask_value); + STORE16_COMPLETE_RESULT(result_512_0, (C_addr)) + STORE16_MASK_COMPLETE_RESULT(result_512_8, (C_addr + 16), tail_mask) + STORE16_COMPLETE_RESULT(result_512_1, (C_addr + ldc)) + STORE16_MASK_COMPLETE_RESULT(result_512_9, (C_addr + ldc + 16), tail_mask) + STORE16_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2)) + STORE16_MASK_COMPLETE_RESULT(result_512_10, (C_addr + ldc*2 + 16), tail_mask) + STORE16_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3)) + STORE16_MASK_COMPLETE_RESULT(result_512_11, (C_addr + ldc*3 + 16), tail_mask) + STORE16_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4)) + STORE16_MASK_COMPLETE_RESULT(result_512_12, (C_addr + ldc*4 + 16), tail_mask) + STORE16_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5)) + STORE16_MASK_COMPLETE_RESULT(result_512_13, (C_addr + ldc*5 + 16), tail_mask) + STORE16_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6)) + STORE16_MASK_COMPLETE_RESULT(result_512_14, (C_addr + ldc*6 + 16), tail_mask) + STORE16_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7)) + STORE16_MASK_COMPLETE_RESULT(result_512_15, (C_addr + ldc*7 + 16), tail_mask) + } else { + STORE16_COMPLETE_RESULT(result_512_0, (C_addr)) + STORE16_COMPLETE_RESULT(result_512_8, (C_addr + 16)) + STORE16_COMPLETE_RESULT(result_512_1, (C_addr + ldc)) + STORE16_COMPLETE_RESULT(result_512_9, (C_addr + ldc + 16)) + STORE16_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2)) + STORE16_COMPLETE_RESULT(result_512_10, (C_addr + ldc*2 + 16)) + STORE16_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3)) + STORE16_COMPLETE_RESULT(result_512_11, (C_addr + ldc*3 + 16)) + STORE16_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4)) + STORE16_COMPLETE_RESULT(result_512_12, (C_addr + ldc*4 + 16)) + STORE16_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5)) + STORE16_COMPLETE_RESULT(result_512_13, (C_addr + ldc*5 + 16)) + STORE16_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6)) + STORE16_COMPLETE_RESULT(result_512_14, (C_addr + ldc*6 + 16)) + STORE16_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7)) + STORE16_COMPLETE_RESULT(result_512_15, (C_addr + ldc*7 + 16)) + } +} + +// SBGEMM Kernel for M=16, N=8, K=Any number +#ifndef ONE_ALPHA // ALPHA is not ONE +void sbgemm_block_kernel_tn_16x8xK_alpha(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) +#else // ALPHA is ONE +void sbgemm_block_kernel_tn_16x8xK_one(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) +#endif +{ + bfloat16 * A_addr = A; + bfloat16 * B_addr = B; + float * C_addr = C; + +#ifndef ONE_ALPHA + __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); +#endif + + __m512i arrayA_512_0; + __m512i arrayB_512_0, arrayB_512_1, arrayB_512_2, arrayB_512_3, arrayB_512_4, arrayB_512_5, arrayB_512_6, arrayB_512_7; + __m512 result_512_0, result_512_1, result_512_2, result_512_3, result_512_4, result_512_5, result_512_6, result_512_7; + + result_512_0 = _mm512_setzero_ps(); + result_512_1 = _mm512_setzero_ps(); + result_512_2 = _mm512_setzero_ps(); + result_512_3 = _mm512_setzero_ps(); + result_512_4 = _mm512_setzero_ps(); + result_512_5 = _mm512_setzero_ps(); + result_512_6 = _mm512_setzero_ps(); + result_512_7 = _mm512_setzero_ps(); + + for (BLASLONG idx_k = 0; idx_k 
< k; idx_k += 2) { + // Load 16 pair of BF16 elements from A (16 rows) + arrayA_512_0 = _mm512_loadu_si512(A_addr + 0); + + // Load 8 rows of B + _MM512_BROADCASTD_EPI32(B_addr + 0, arrayB_512_0); + _MM512_BROADCASTD_EPI32(B_addr + 2, arrayB_512_1); + _MM512_BROADCASTD_EPI32(B_addr + 4, arrayB_512_2); + _MM512_BROADCASTD_EPI32(B_addr + 6, arrayB_512_3); + _MM512_BROADCASTD_EPI32(B_addr + 8, arrayB_512_4); + _MM512_BROADCASTD_EPI32(B_addr + 10, arrayB_512_5); + _MM512_BROADCASTD_EPI32(B_addr + 12, arrayB_512_6); + _MM512_BROADCASTD_EPI32(B_addr + 14, arrayB_512_7); + + result_512_0 = _mm512_dpbf16_ps(result_512_0, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_0); + result_512_1 = _mm512_dpbf16_ps(result_512_1, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_1); + result_512_2 = _mm512_dpbf16_ps(result_512_2, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_2); + result_512_3 = _mm512_dpbf16_ps(result_512_3, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_3); + result_512_4 = _mm512_dpbf16_ps(result_512_4, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_4); + result_512_5 = _mm512_dpbf16_ps(result_512_5, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_5); + result_512_6 = _mm512_dpbf16_ps(result_512_6, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_6); + result_512_7 = _mm512_dpbf16_ps(result_512_7, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_7); + + // Load B with unroll 8 + B_addr += 16; + // Load A with unroll 32 + A_addr += 32; + } + + if (m != 16) { + unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-m)); + __mmask16 tail_mask = *((__mmask16*) &tail_mask_value); + STORE16_MASK_COMPLETE_RESULT(result_512_0, (C_addr), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_1, (C_addr + ldc), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7), tail_mask) + } else { + STORE16_COMPLETE_RESULT(result_512_0, (C_addr)) + STORE16_COMPLETE_RESULT(result_512_1, (C_addr + ldc)) + STORE16_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2)) + STORE16_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3)) + STORE16_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4)) + STORE16_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5)) + STORE16_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6)) + STORE16_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7)) + } +} + +// SBGEMM Kernel for 16> (32-m)); + for (int i = 0; i < n; i++) { + STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i)) + STORE16_MASK_COMPLETE_RESULT(result_512[i+8], (C_addr + ldc*i + 16), tail_mask) + } + } else { + for (int i = 0; i < n; i++) { + STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i)) + STORE16_COMPLETE_RESULT(result_512[i+8], (C_addr + ldc*i + 16)) + } + } +} + +// SBGEMM Kernel for M<=16, N<8, K=Any number but will be processed based on 32 +#ifndef ONE_ALPHA // ALPHA is not ONE +void sbgemm_block_kernel_tn_16xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) +#else // ALPHA is ONE +void sbgemm_block_kernel_tn_16xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) +#endif +{ + bfloat16 * A_addr = A; + bfloat16 * B_addr = B; + float * C_addr = C; + 
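+	/* 0x39 selects dword order {1,2,3,0}: _mm512_shuffle_epi32 uses it below to rotate each
+	   128-bit lane of the cached B registers by one 32-bit bf16 pair, and _mm512_shuffle_i32x4
+	   reuses the same pattern to step to the next 128-bit lane every four iterations. */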
+ int SHUFFLE_MAGIC_NO = 0x39; + BLASLONG tag_k_32x = k & (~31); + +#ifndef ONE_ALPHA + __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); +#endif + + __m512i arrayA_512; + __m512i arrayB_512[8]; + __m512 result_512[8]; + + for (int i = 0; i < 8; i++) { + result_512[i] = _mm512_setzero_ps(); + } + + for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) { + // Load B with unroll n + for (int i = 0; i < n; i ++) { + arrayB_512[i] = _mm512_loadu_si512(B_addr); + B_addr += 32; + } + + for (BLASLONG idx = 0; idx < 32;) { + // Each two rows are a group for 32-pair bf16 elements + arrayA_512 = _mm512_loadu_si512(A_addr); + A_addr += 32; + + for (int i = 0; i < n; i++) { + result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i]))); + arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO); + } + + idx += 2; + // Every 4 loops we need to switch to next 128 bits of arrayB registers + if ((idx & (~7)) == idx) { + for (int i = 0; i < n; i++) { + arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO); + } + } + } + } + + if (tag_k_32x != k) { + // Load B with unroll n + for (int i = 0; i < n; i ++) { + arrayB_512[i] = _mm512_loadu_si512(B_addr); + B_addr += 32; + } + + BLASLONG width = k - tag_k_32x; + for (BLASLONG idx = 0; idx < width;) { + // Each two rows are a group for 32-pair bf16 elements + arrayA_512 = _mm512_loadu_si512(A_addr); + A_addr += 32; + + for (int i = 0; i < n; i++) { + result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i]))); + arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO); + } + + idx += 2; + // Every 4 loops we need to switch to next 128 bits of arrayB registers + if ((idx & (~7)) == idx) { + for (int i = 0; i < n; i++) { + arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO); + } + } + } + } + + if (m != 16) { + unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m)); + for (int i = 0; i < n; i++) { + STORE16_MASK_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i), tail_mask) + } + } else { + for (int i = 0; i < n; i++) { + STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i)) + } + } +} + +#ifndef ONE_ALPHA // ALPHA is not ONE +void sbgemm_blocking_kernel_tn_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B) +#else // ALPHA is ONE +void sbgemm_blocking_kernel_tn_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B) +#endif +{ + BLASLONG m_step, n_step, k_step, k_step_round32; + BLASLONG tag_m_Nx = M & (~(BF16_BLOCK_THRES_M-1)); + + BLASLONG n_from, n_to; + BLASLONG tag_n_Nx; + + n_from = 0; + n_to = (BF16_BLOCK_THRES_N > N) ? N : BF16_BLOCK_THRES_N; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + + k_step = (K > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : K; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + + if (M >= BF16_BLOCK_THRES_M) { + while (n_from < N) { + for (BLASLONG idx_k = 0; idx_k < K;) { + // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ... 
+ COL_MAJOR_ITCOPY_KERNEL_Kx32(k_step, &A(0, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... + COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TN_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); // TODO how to process m + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TN_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + } + + for (BLASLONG idx_m = BF16_BLOCK_THRES_M; idx_m < tag_m_Nx; idx_m += BF16_BLOCK_THRES_M) { + COL_MAJOR_ITCOPY_KERNEL_Kx32(k_step, &A(idx_m, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + SBGEMM_BLOCK_KERNEL_TN_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + SBGEMM_BLOCK_KERNEL_TN_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc); + } + } + + if (tag_m_Nx != M) { + m_step = M - tag_m_Nx; + if (m_step > 16) { + COL_MAJOR_ITCOPY_KERNEL_Kx32m(m_step, k_step, &A(tag_m_Nx, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + SBGEMM_BLOCK_KERNEL_TN_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + SBGEMM_BLOCK_KERNEL_TN_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc); + } + } else { + COL_MAJOR_ITCOPY_KERNEL_Kx16m(m_step, k_step, &A(tag_m_Nx, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + SBGEMM_BLOCK_KERNEL_TN_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + SBGEMM_BLOCK_KERNEL_TN_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc); + } + } + } + + idx_k += k_step; + k_step = K - idx_k; + k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + } + + n_from = n_to; + n_to += BF16_BLOCK_THRES_N; + n_to = (n_to > N) ? N : n_to; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + } + } else { + m_step = M; + if (m_step > 16) { + while (n_from < N) { + for (BLASLONG idx_k = 0; idx_k < K;) { + // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ... + COL_MAJOR_ITCOPY_KERNEL_Kx32m(m_step, k_step, &A(0, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... 
+ COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TN_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TN_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + } + + idx_k += k_step; + k_step = K - idx_k; + k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + } + n_from = n_to; + n_to += BF16_BLOCK_THRES_N; + n_to = (n_to > N) ? N : n_to; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + } + } else { + while (n_from < N) { + for (BLASLONG idx_k = 0; idx_k < K;) { + // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ... + COL_MAJOR_ITCOPY_KERNEL_Kx16m(m_step, k_step, &A(0, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... + COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TN_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TN_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + } + + idx_k += k_step; + k_step = K - idx_k; + k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + } + n_from = n_to; + n_to += BF16_BLOCK_THRES_N; + n_to = (n_to > N) ? 
N : n_to; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + } + } + } +} +/* ----------------------------------------- End of TN kernels --------------------------------------- */ + +/* --------------------------------------------- TT kernels ------------------------------------------ */ +// SBGEMM Kernel for 16> (32-m)); + for (int i = 0; i < n; i ++) { + STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i)) + STORE16_MASK_COMPLETE_RESULT(result_512[i+8], (C_addr + ldc*i + 16), tail_mask) + } + } else { + for (int i = 0; i < n; i ++) { + STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i)) + STORE16_COMPLETE_RESULT(result_512[i+8], (C_addr + ldc*i + 16)) + } + } +} + +// SBGEMM Kernel for M<=16, N<8, K can be any number +#ifndef ONE_ALPHA // ALPHA is not ONE +void sbgemm_block_kernel_tt_16xNxK_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) +#else // ALPHA is ONE +void sbgemm_block_kernel_tt_16xNxK_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) +#endif +{ + bfloat16 * A_addr = A; + bfloat16 * B_addr = B; + float * C_addr = C; + +#ifndef ONE_ALPHA + __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); +#endif + + __m512i arrayA_512_0; + __m512i arrayB_512[8]; + __m512 result_512[8]; + + result_512[0] = _mm512_setzero_ps(); + result_512[1] = _mm512_setzero_ps(); + result_512[2] = _mm512_setzero_ps(); + result_512[3] = _mm512_setzero_ps(); + result_512[4] = _mm512_setzero_ps(); + result_512[5] = _mm512_setzero_ps(); + result_512[6] = _mm512_setzero_ps(); + result_512[7] = _mm512_setzero_ps(); + + for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) { + // Each two rows are a group for 16-pair bf16 elements + // Load two rows into a 512 register + arrayA_512_0 = _mm512_loadu_si512(A_addr); + A_addr += 32; + + for (int i = 0; i < n; i ++) { + _MM512_BROADCASTD_EPI32(B_addr + i*2, arrayB_512[i]); + } + B_addr += 16; + + for (int i = 0; i < n; i ++) { + result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512_0, (__m512bh) arrayB_512[i]); + } + } + + if (m != 16) { + unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m)); + for (int i = 0; i < n; i++) { + STORE16_MASK_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i), tail_mask) + } + } else { + for (int i = 0; i < n; i++) { + STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i)) + } + } +} + +#ifndef ONE_ALPHA // ALPHA is not ONE +void sbgemm_blocking_kernel_tt_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B) +#else // ALPHA is ONE +void sbgemm_blocking_kernel_tt_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B) +#endif +{ + BLASLONG m_step, n_step, k_step, k_step_round32; + BLASLONG tag_m_Nx = M & (~(BF16_BLOCK_THRES_M-1)); + + BLASLONG n_from, n_to; + BLASLONG tag_n_Nx; + + n_from = 0; + n_to = (BF16_BLOCK_THRES_N > N) ? N : BF16_BLOCK_THRES_N; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + + k_step = (K > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : K; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + + if (M >= BF16_BLOCK_THRES_M) { + while (n_from < N) { + for (BLASLONG idx_k = 0; idx_k < K;) { + // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ... 
+ COL_MAJOR_ITCOPY_KERNEL_Kx32(k_step, &A(0, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... + COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TT_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TT_32xNxK(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + } + + for (BLASLONG idx_m = BF16_BLOCK_THRES_M; idx_m < tag_m_Nx; idx_m += BF16_BLOCK_THRES_M) { + COL_MAJOR_ITCOPY_KERNEL_Kx32(k_step, &A(idx_m, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + SBGEMM_BLOCK_KERNEL_TT_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + SBGEMM_BLOCK_KERNEL_TT_32xNxK(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc); + } + } + + if (tag_m_Nx != M) { + m_step = M - tag_m_Nx; + if (m_step > 16) { + COL_MAJOR_ITCOPY_KERNEL_Kx32m(m_step, k_step, &A(tag_m_Nx, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + SBGEMM_BLOCK_KERNEL_TT_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + SBGEMM_BLOCK_KERNEL_TT_32xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc); + } + } else { + COL_MAJOR_ITCOPY_KERNEL_Kx16m(m_step, k_step, &A(tag_m_Nx, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + SBGEMM_BLOCK_KERNEL_TT_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + SBGEMM_BLOCK_KERNEL_TT_16xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc); + } + } + } + + idx_k += k_step; + k_step = K - idx_k; + k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + } + + n_from = n_to; + n_to += BF16_BLOCK_THRES_N; + n_to = (n_to > N) ? N : n_to; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + } + } else { + m_step = M; + if (m_step > 16) { + while (n_from < N) { + for (BLASLONG idx_k = 0; idx_k < K;) { + // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ... + COL_MAJOR_ITCOPY_KERNEL_Kx32m(m_step, k_step, &A(0, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... 
+ COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TT_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TT_32xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + } + + idx_k += k_step; + k_step = K - idx_k; + k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + } + n_from = n_to; + n_to += BF16_BLOCK_THRES_N; + n_to = (n_to > N) ? N : n_to; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + } + } else { + while (n_from < N) { + for (BLASLONG idx_k = 0; idx_k < K;) { + // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ... + COL_MAJOR_ITCOPY_KERNEL_Kx16m(m_step, k_step, &A(0, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... + COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TT_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TT_16xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + } + + idx_k += k_step; + k_step = K - idx_k; + k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + } + n_from = n_to; + n_to += BF16_BLOCK_THRES_N; + n_to = (n_to > N) ? 
N : n_to; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + } + } + } +} +/* ----------------------------------------- End of TT kernels --------------------------------------- */ #ifndef ONE_ALPHA // ALPHA is not ONE void sbgemm_internal_kernel_alpha(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, @@ -613,13 +1806,33 @@ void sbgemm_internal_kernel_one(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_ OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, float *C, OPENBLAS_CONST blasint ldc) #endif { - bfloat16 block_A[BF16_BLOCK_THRES_K * BF16_BLOCK_THRES_M]; - bfloat16 block_B[BF16_BLOCK_THRES_N * BF16_BLOCK_THRES_K]; - - // TODO: assume no trans for both A and B, to complement these scenarios later if (Order == CblasColMajor) { - SBGEMM_BLOCKING_KERNEL_2(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B); + if (TransA == CblasNoTrans) { + if (TransB == CblasNoTrans) { + SBGEMM_BLOCKING_KERNEL_NN(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B); + } else if (TransB == CblasTrans) { + SBGEMM_BLOCKING_KERNEL_NT(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B); + } + } else { + if (TransB == CblasNoTrans) { + SBGEMM_BLOCKING_KERNEL_TN(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B); + } else if (TransB == CblasTrans) { + SBGEMM_BLOCKING_KERNEL_TT(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B); + } + } } else { - + if (TransA == CblasNoTrans) { + if (TransB == CblasNoTrans) { + SBGEMM_BLOCKING_KERNEL_NN(N, M, K, alpha, B, ldb, A, lda, C, ldc, block_A, block_B); + } else if (TransB == CblasTrans) { + SBGEMM_BLOCKING_KERNEL_TN(N, M, K, alpha, B, ldb, A, lda, C, ldc, block_A, block_B); + } + } else { + if (TransB == CblasNoTrans) { + SBGEMM_BLOCKING_KERNEL_NT(N, M, K, alpha, B, ldb, A, lda, C, ldc, block_A, block_B); + } else if (TransB == CblasTrans) { + SBGEMM_BLOCKING_KERNEL_TT(N, M, K, alpha, B, ldb, A, lda, C, ldc, block_A, block_B); + } + } } -} \ No newline at end of file +} -- 2.7.4
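
The blocking loops in all four variants above share the same bookkeeping: k_step is clamped to
BF16_BLOCK_THRES_K, then rounded up to a multiple of 32 (k_step_round32) so packed panels in
block_A/block_B always cover whole 32-element bf16 groups, and partial panels are handled with a
low-bit tail mask. Below is a minimal standalone sketch of that bookkeeping; BF16_BLOCK_THRES_K,
the example K, and the printf driver are assumptions for illustration only and are not part of
the patch.

/* sketch.c -- illustrative only, not part of the patch */
#include <stdio.h>

#define BF16_BLOCK_THRES_K 1024		/* assumed blocking threshold, for illustration */

/* Tail mask with the low `m` bits set (1 <= m <= 32), mirroring the idiom used
 * by the masked loads/stores on partial 32-element panels. */
static unsigned int tail_mask32(int m)
{
	return ((unsigned int)0xffffffff) >> (32 - m);
}

int main(void)
{
	long K = 1500;			/* example problem size */
	long k_step = (K > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : K;

	printf("tail_mask32(5) = 0x%x\n", tail_mask32(5));	/* low 5 bits set */

	for (long idx_k = 0; idx_k < K;) {
		/* Round k_step up to a multiple of 32 so each packed panel
		 * occupies whole groups of 32 bf16 elements. */
		long k_step_round32 = k_step & (~31L);
		k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;

		printf("idx_k=%ld  k_step=%ld  k_step_round32=%ld\n", idx_k, k_step, k_step_round32);

		idx_k += k_step;
		k_step = K - idx_k;
		k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
	}
	return 0;
}

With K = 1500 and the assumed threshold of 1024, this walks K in chunks of 1024 and 476,
packed as 1024- and 480-element panels respectively.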