--- /dev/null
+/***************************************************************************
+Copyright (c) 2014, The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+#ifndef __BF16_COMMON_MACROS
+#define __BF16_COMMON_MACROS
+
+#include <immintrin.h>
+
+#define EXTRACT_LOW_256_FROM_512_2X(reg256, reg512) \
+ reg256##_0 = _mm512_castps512_ps256(reg512##_0); \
+ reg256##_1 = _mm512_castps512_ps256(reg512##_1);
+
+
+#define BF16_MATRIX_LOAD_8x32(regArray, a, lda, idx_m, idx_n) \
+ regArray##_0 = _mm512_loadu_si512(&a[(idx_m+0)*lda + idx_n]); \
+ regArray##_1 = _mm512_loadu_si512(&a[(idx_m+1)*lda + idx_n]); \
+ regArray##_2 = _mm512_loadu_si512(&a[(idx_m+2)*lda + idx_n]); \
+ regArray##_3 = _mm512_loadu_si512(&a[(idx_m+3)*lda + idx_n]); \
+ regArray##_4 = _mm512_loadu_si512(&a[(idx_m+4)*lda + idx_n]); \
+ regArray##_5 = _mm512_loadu_si512(&a[(idx_m+5)*lda + idx_n]); \
+ regArray##_6 = _mm512_loadu_si512(&a[(idx_m+6)*lda + idx_n]); \
+ regArray##_7 = _mm512_loadu_si512(&a[(idx_m+7)*lda + idx_n]);
+
+
+#define BF16_MATRIX_LOAD_8x16(regArray, a, lda, idx_m, idx_n) \
+ regArray##_0 = _mm256_loadu_si256(&a[(idx_m+0)*lda + idx_n]); \
+ regArray##_1 = _mm256_loadu_si256(&a[(idx_m+1)*lda + idx_n]); \
+ regArray##_2 = _mm256_loadu_si256(&a[(idx_m+2)*lda + idx_n]); \
+ regArray##_3 = _mm256_loadu_si256(&a[(idx_m+3)*lda + idx_n]); \
+ regArray##_4 = _mm256_loadu_si256(&a[(idx_m+4)*lda + idx_n]); \
+ regArray##_5 = _mm256_loadu_si256(&a[(idx_m+5)*lda + idx_n]); \
+ regArray##_6 = _mm256_loadu_si256(&a[(idx_m+6)*lda + idx_n]); \
+ regArray##_7 = _mm256_loadu_si256(&a[(idx_m+7)*lda + idx_n]);
+
+
+#define BF16_MATRIX_LOAD_8x8(regArray, a, lda, idx_m, idx_n) \
+ regArray##_0 = _mm_loadu_si128(&a[(idx_m+0)*lda + idx_n]); \
+ regArray##_1 = _mm_loadu_si128(&a[(idx_m+1)*lda + idx_n]); \
+ regArray##_2 = _mm_loadu_si128(&a[(idx_m+2)*lda + idx_n]); \
+ regArray##_3 = _mm_loadu_si128(&a[(idx_m+3)*lda + idx_n]); \
+ regArray##_4 = _mm_loadu_si128(&a[(idx_m+4)*lda + idx_n]); \
+ regArray##_5 = _mm_loadu_si128(&a[(idx_m+5)*lda + idx_n]); \
+ regArray##_6 = _mm_loadu_si128(&a[(idx_m+6)*lda + idx_n]); \
+ regArray##_7 = _mm_loadu_si128(&a[(idx_m+7)*lda + idx_n]);
+
+
+#define BF16_MATRIX_LOAD_1x32(regArray, a, lda, idx_m, idx_n) \
+ regArray = _mm512_loadu_si512(&a[idx_m*lda + idx_n]);
+
+
+#define BF16_MATRIX_MASKZ_LOAD_8x32(regArray, a, lda, idx_m, idx_n, mask) \
+ regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
+ regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
+ regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
+ regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]); \
+ regArray##_4 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
+ regArray##_5 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+5)*lda + idx_n]); \
+ regArray##_6 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); \
+ regArray##_7 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+7)*lda + idx_n]);
+
+
+#define BF16_MATRIX_MASKZ_LOAD_8x16(regArray, a, lda, idx_m, idx_n, mask) \
+ regArray##_0 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
+ regArray##_1 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
+ regArray##_2 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
+ regArray##_3 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]); \
+ regArray##_4 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
+ regArray##_5 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+5)*lda + idx_n]); \
+ regArray##_6 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); \
+ regArray##_7 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+7)*lda + idx_n]);
+
+
+#define BF16_MATRIX_MASKZ_LOAD_8x8(regArray, a, lda, idx_m, idx_n, mask) \
+ regArray##_0 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
+ regArray##_1 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
+ regArray##_2 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
+ regArray##_3 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]); \
+ regArray##_4 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
+ regArray##_5 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+5)*lda + idx_n]); \
+ regArray##_6 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); \
+ regArray##_7 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+7)*lda + idx_n]);
+
+
+#define BF16_MATRIX_MASKZ_LOAD_4x32(regArray, a, lda, idx_m, idx_n, mask) \
+ regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
+ regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
+ regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
+ regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]);
+
+
+#define BF16_MATRIX_MASKZ_LOAD_4x16(regArray, a, lda, idx_m, idx_n, mask) \
+ regArray##_0 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
+ regArray##_1 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
+ regArray##_2 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
+ regArray##_3 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]);
+
+
+#define BF16_MATRIX_MASKZ_LOAD_8x32_2(regArray, a, lda, idx_m, idx_n, mask) \
+ regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
+ regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
+ regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
+ regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); \
+ regArray##_4 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+8)*lda + idx_n]); \
+ regArray##_5 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+10)*lda + idx_n]); \
+ regArray##_6 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+12)*lda + idx_n]); \
+ regArray##_7 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+14)*lda + idx_n]);
+
+
+#define BF16_MATRIX_MASKZ_LOAD_4x32_2(regArray, a, lda, idx_m, idx_n, mask) \
+ regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
+ regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
+ regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
+ regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]);
+
+#define BF16_MATRIX_MASKZ_LOAD_1x32(regArray, a, lda, idx_m, idx_n, mask) \
+ regArray = _mm512_maskz_loadu_epi16(mask, &a[idx_m*lda + idx_n]);
+
+#define BF16_VECTOR_LOAD_1x32(reg, x, idx_n) \
+ reg = _mm512_loadu_si512(x + idx_n);
+
+
+#define BF16_VECTOR_LOAD_1x16(reg, x, idx_n) \
+ reg = _mm256_loadu_si256(x + idx_n);
+
+
+#define BF16_VECTOR_LOAD_1x8(reg, x, idx_n) \
+ reg = _mm_loadu_si128(x + idx_n);
+
+
+#define BF16_VECTOR_MASKZ_LOAD_1x32(reg, x, idx_n, mask) \
+ reg = _mm512_maskz_loadu_epi16(mask, x + idx_n);
+
+
+#define BF16_VECTOR_MASKZ_LOAD_1x16(reg, x, idx_n, mask) \
+ reg = _mm256_maskz_loadu_epi16(mask, x + idx_n);
+
+
+#define BF16_VECTOR_MASKZ_LOAD_1x8(reg, x, idx_n, mask) \
+ reg = _mm_maskz_loadu_epi16(mask, x + idx_n);
+
+
+/* 2-step interleave for matrix against 8 rows with 32 BF16 elements per row
+ Input - register array of 8 rows of raw-major matrix
+ Output - the output of Step 2
+
+ Step 1: 2-element interleave for matrix
+ |a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11|a16|a17|b16|b17|a18|a19|b18|b19|a24|a25|b24|b25|a26|a27|b26|b27
+ |c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11|c16|c17|d16|d17|c18|c19|d18|d19|c24|c25|d24|d25|c26|c27|d26|d27
+ |e0|e1|f0|f1|e2|e3|f2|f3|e8 |e9 |f8 |f9 |e10|e11|f10|f11|e16|e17|f16|f17|e18|e19|f18|f19|e24|e25|f24|f25|e26|e27|f26|f27
+ |g0|g1|h0|h1|g2|g3|h2|h3|g8 |g9 |h8 |h9 |g10|g11|h10|h11|g16|g17|h16|h17|g18|g19|h18|h19|g24|g25|h24|h25|g26|g27|h26|h27
+ |a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15|a20|a21|b20|b21|a22|a23|b22|b23|a28|a29|b28|b29|a30|a31|b30|b31
+ |c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15|c20|c21|d20|d21|c22|c23|d22|d23|c28|c29|d28|d29|c30|c31|d30|d31
+ |e4|e5|f4|f5|e6|e7|f6|f7|e12|e13|f12|f13|e14|e15|f14|f15|e20|e21|f20|f21|e22|e23|f22|f23|e28|e29|f28|f29|e30|e31|f30|f31
+ |g4|g5|h4|h5|g6|g7|h6|h7|g12|g13|h12|h13|g14|g15|h14|h15|g20|g21|h20|h21|g22|g23|h22|h23|g28|g29|h28|h29|g30|g31|h30|h31
+
+ Step 2: 4-element interleave for matrix
+ |a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9 |a16|a17|b16|b17|c16|c17|d16|d17|a24|a25|b24|b25|c24|c25|d24|d25
+ |a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11|a18|a19|b18|b19|c18|c19|d18|d19|a26|a27|b26|b27|c26|c27|d26|d27
+ |e0|e1|f0|f1|g0|g1|h0|h1|e8 |e9 |f8 |f9 |g8 |g9 |h8 |h9 |e16|e17|f16|f17|g16|g17|h16|h17|e24|e25|f24|f25|g24|g25|h24|h25
+ |e2|e3|f2|f3|g2|g3|h2|h3|e10|e11|f10|f11|g10|g11|h10|h11|e18|e19|f18|f19|g18|g19|h18|h19|e26|e27|f26|f27|g26|g27|h26|h27
+ |a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13|a20|a21|b20|b21|c20|c21|d20|d21|a28|a29|b28|b29|c28|c29|d28|d29
+ |a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15|a22|a23|b22|b23|c22|c23|d22|d23|a30|a31|b30|b31|c30|c31|d30|d31
+ |e4|e5|f4|f5|g4|g5|h4|h5|e12|e13|f12|f13|g12|g13|h12|h13|e20|e21|f20|f21|g20|g21|h20|h21|e28|e29|f28|f29|g28|g29|h28|h29
+ |e6|e7|f6|f7|g6|g7|h6|h7|e14|e15|f14|f15|g14|g15|h14|h15|e22|e23|f22|f23|g22|g23|h22|h23|e30|e31|f30|f31|g30|g31|h30|h31
+*/
+#define BF16_INTERLEAVE_8x32(regArray) \
+ regArray##_8 = _mm512_unpacklo_epi32(regArray##_0, regArray##_1); \
+ regArray##_9 = _mm512_unpacklo_epi32(regArray##_2, regArray##_3); \
+ regArray##_10 = _mm512_unpacklo_epi32(regArray##_4, regArray##_5); \
+ regArray##_11 = _mm512_unpacklo_epi32(regArray##_6, regArray##_7); \
+ regArray##_12 = _mm512_unpackhi_epi32(regArray##_0, regArray##_1); \
+ regArray##_13 = _mm512_unpackhi_epi32(regArray##_2, regArray##_3); \
+ regArray##_14 = _mm512_unpackhi_epi32(regArray##_4, regArray##_5); \
+ regArray##_15 = _mm512_unpackhi_epi32(regArray##_6, regArray##_7); \
+ \
+ regArray##_0 = _mm512_unpacklo_epi64(regArray##_8, regArray##_9); \
+ regArray##_1 = _mm512_unpackhi_epi64(regArray##_8, regArray##_9); \
+ regArray##_2 = _mm512_unpacklo_epi64(regArray##_10, regArray##_11); \
+ regArray##_3 = _mm512_unpackhi_epi64(regArray##_10, regArray##_11); \
+ regArray##_4 = _mm512_unpacklo_epi64(regArray##_12, regArray##_13); \
+ regArray##_5 = _mm512_unpackhi_epi64(regArray##_12, regArray##_13); \
+ regArray##_6 = _mm512_unpacklo_epi64(regArray##_14, regArray##_15); \
+ regArray##_7 = _mm512_unpackhi_epi64(regArray##_14, regArray##_15);
+
+
+/* 2-step interleave for matrix against 8 rows with 16 BF16 elements per row
+ Input - register array of 8 rows of raw-major matrix
+ Output - the output of Step 2
+
+ Step 1: 2-element interleave for matrix
+ |a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11
+ |c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11
+ |e0|e1|f0|f1|e2|e3|f2|f3|e8 |e9 |f8 |f9 |e10|e11|f10|f11
+ |g0|g1|h0|h1|g2|g3|h2|h3|g8 |g9 |h8 |h9 |g10|g11|h10|h11
+ |a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15
+ |c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15
+ |e4|e5|f4|f5|e6|e7|f6|f7|e12|e13|f12|f13|e14|e15|f14|f15
+ |g4|g5|h4|h5|g6|g7|h6|h7|g12|g13|h12|h13|g14|g15|h14|h15
+
+ Step 2: 4-element interleave for matrix
+ |a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9
+ |a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11
+ |e0|e1|f0|f1|g0|g1|h0|h1|e8 |e9 |f8 |f9 |g8 |g9 |h8 |h9
+ |e2|e3|f2|f3|g2|g3|h2|h3|e10|e11|f10|f11|g10|g11|h10|h11
+ |a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13
+ |a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15
+ |e4|e5|f4|f5|g4|g5|h4|h5|e12|e13|f12|f13|g12|g13|h12|h13
+ |e6|e7|f6|f7|g6|g7|h6|h7|e14|e15|f14|f15|g14|g15|h14|h15
+*/
+#define BF16_INTERLEAVE_8x16(regArray) \
+ regArray##_8 = _mm256_unpacklo_epi32(regArray##_0, regArray##_1); \
+ regArray##_9 = _mm256_unpacklo_epi32(regArray##_2, regArray##_3); \
+ regArray##_10 = _mm256_unpacklo_epi32(regArray##_4, regArray##_5); \
+ regArray##_11 = _mm256_unpacklo_epi32(regArray##_6, regArray##_7); \
+ regArray##_12 = _mm256_unpackhi_epi32(regArray##_0, regArray##_1); \
+ regArray##_13 = _mm256_unpackhi_epi32(regArray##_2, regArray##_3); \
+ regArray##_14 = _mm256_unpackhi_epi32(regArray##_4, regArray##_5); \
+ regArray##_15 = _mm256_unpackhi_epi32(regArray##_6, regArray##_7); \
+ \
+ regArray##_0 = _mm256_unpacklo_epi64(regArray##_8, regArray##_9); \
+ regArray##_1 = _mm256_unpackhi_epi64(regArray##_8, regArray##_9); \
+ regArray##_2 = _mm256_unpacklo_epi64(regArray##_10, regArray##_11); \
+ regArray##_3 = _mm256_unpackhi_epi64(regArray##_10, regArray##_11); \
+ regArray##_4 = _mm256_unpacklo_epi64(regArray##_12, regArray##_13); \
+ regArray##_5 = _mm256_unpackhi_epi64(regArray##_12, regArray##_13); \
+ regArray##_6 = _mm256_unpacklo_epi64(regArray##_14, regArray##_15); \
+ regArray##_7 = _mm256_unpackhi_epi64(regArray##_14, regArray##_15);
+
+/* 2-step interleave for matrix against 8 rows with 32 BF16 elements per row
+ Input - register array of 8 rows of raw-major matrix
+ Output - the output of Step 2
+
+ Step 1: 2-element interleave for matrix
+ |a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11|a16|a17|b16|b17|a18|a19|b18|b19|a24|a25|b24|b25|a26|a27|b26|b27
+ |c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11|c16|c17|d16|d17|c18|c19|d18|d19|c24|c25|d24|d25|c26|c27|d26|d27
+ |a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15|a20|a21|b20|b21|a22|a23|b22|b23|a28|a29|b28|b29|a30|a31|b30|b31
+ |c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15|c20|c21|d20|d21|c22|c23|d22|d23|c28|c29|d28|d29|c30|c31|d30|d31
+
+ Step 2: 4-element interleave for matrix
+ |a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9 |a16|a17|b16|b17|c16|c17|d16|d17|a24|a25|b24|b25|c24|c25|d24|d25
+ |a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11|a18|a19|b18|b19|c18|c19|d18|d19|a26|a27|b26|b27|c26|c27|d26|d27
+ |a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13|a20|a21|b20|b21|c20|c21|d20|d21|a28|a29|b28|b29|c28|c29|d28|d29
+ |a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15|a22|a23|b22|b23|c22|c23|d22|d23|a30|a31|b30|b31|c30|c31|d30|d31
+*/
+#define BF16_INTERLEAVE_4x32(regArray) \
+ regArray##_4 = _mm512_unpacklo_epi32(regArray##_0, regArray##_1); \
+ regArray##_5 = _mm512_unpacklo_epi32(regArray##_2, regArray##_3); \
+ regArray##_6 = _mm512_unpackhi_epi32(regArray##_0, regArray##_1); \
+ regArray##_7 = _mm512_unpackhi_epi32(regArray##_2, regArray##_3); \
+ \
+ regArray##_0 = _mm512_unpacklo_epi64(regArray##_4, regArray##_5); \
+ regArray##_1 = _mm512_unpackhi_epi64(regArray##_4, regArray##_5); \
+ regArray##_2 = _mm512_unpacklo_epi64(regArray##_6, regArray##_7); \
+ regArray##_3 = _mm512_unpackhi_epi64(regArray##_6, regArray##_7);
+
+
+/* 2-step interleave for matrix against 8 rows with 16 BF16 elements per row
+ Input - register array of 8 rows of raw-major matrix
+ Output - the output of Step 2
+
+ Step 1: 2-element interleave for matrix
+ |a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11
+ |c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11
+ |a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15
+ |c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15
+
+ Step 2: 4-element interleave for matrix
+ |a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9
+ |a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11
+ |a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13
+ |a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15
+*/
+#define BF16_INTERLEAVE_4x16(regArray) \
+ regArray##_4 = _mm256_unpacklo_epi32(regArray##_0, regArray##_1); \
+ regArray##_5 = _mm256_unpacklo_epi32(regArray##_2, regArray##_3); \
+ regArray##_6 = _mm256_unpackhi_epi32(regArray##_0, regArray##_1); \
+ regArray##_7 = _mm256_unpackhi_epi32(regArray##_2, regArray##_3); \
+ \
+ regArray##_0 = _mm256_unpacklo_epi64(regArray##_4, regArray##_5); \
+ regArray##_1 = _mm256_unpackhi_epi64(regArray##_4, regArray##_5); \
+ regArray##_2 = _mm256_unpacklo_epi64(regArray##_6, regArray##_7); \
+ regArray##_3 = _mm256_unpackhi_epi64(regArray##_6, regArray##_7);
+
+
+/* 2-step interleave for x with 32 BF16 elements
+ Input - original vector
+ Output - the output of Step 2
+
+ Step 1: 2-element interleave for x:
+ |x0|x1|x0|x1|x2|x3|x2|x3|x8 |x9 |x8 |x9 |x10|x11|x10|x11|x16|x17|x16|x17|x18|x19|x18|x19|x24|x25|x24|x25|x26|x27|x26|x27
+ |x4|x5|x4|x5|x6|x7|x6|x7|x12|x13|x12|x13|x14|x15|x14|x15|x20|x21|x20|x21|x22|x23|x22|x23|x28|x29|x28|x29|x30|x31|x30|x31
+
+ Step 2: 4-element interleave for x:
+ |x0|x1|x0|x1|x0|x1|x0|x1|x8 |x9 |x8 |x9 |x8 |x9 |x8 |x9 |x16|x17|x16|x17|x16|x17|x16|x17|x24|x25|x24|x25|x24|x25|x24|x25
+ |x2|x3|x2|x3|x2|x3|x2|x3|x10|x11|x10|x11|x10|x11|x10|x11|x18|x19|x18|x19|x18|x19|x18|x19|x26|x27|x26|x27|x26|x27|x26|x27
+ |x4|x5|x4|x5|x4|x5|x4|x5|x12|x13|x12|x13|x12|x13|x12|x13|x20|x21|x20|x21|x20|x21|x20|x21|x28|x29|x28|x29|x28|x29|x28|x29
+ |x6|x7|x6|x7|x6|x7|x6|x7|x14|x15|x14|x15|x14|x15|x14|x15|x22|x23|x22|x23|x22|x23|x22|x23|x30|x31|x30|x31|x30|x31|x30|x31
+*/
+#define BF16_INTERLEAVE_1x32(regArray) \
+ regArray##_1 = _mm512_unpacklo_epi32(regArray##_0, regArray##_0); \
+ regArray##_3 = _mm512_unpackhi_epi32(regArray##_0, regArray##_0); \
+ \
+ regArray##_0 = _mm512_unpacklo_epi64(regArray##_1, regArray##_1); \
+ regArray##_1 = _mm512_unpackhi_epi64(regArray##_1, regArray##_1); \
+ regArray##_2 = _mm512_unpacklo_epi64(regArray##_3, regArray##_3); \
+ regArray##_3 = _mm512_unpackhi_epi64(regArray##_3, regArray##_3);
+
+
+/* 2-step interleave for x with 16 BF16 elements
+ Input - original vector
+ Output - the output of Step 2
+
+ Step 1: 2-element interleave for x:
+ |x0|x1|x0|x1|x2|x3|x2|x3|x8 |x9 |x8 |x9 |x10|x11|x10|x11
+ |x4|x5|x4|x5|x6|x7|x6|x7|x12|x13|x12|x13|x14|x15|x14|x15
+
+ Step 2: 4-element interleave for x:
+ |x0|x1|x0|x1|x0|x1|x0|x1|x8 |x9 |x8 |x9 |x8 |x9 |x8 |x9
+ |x2|x3|x2|x3|x2|x3|x2|x3|x10|x11|x10|x11|x10|x11|x10|x11
+ |x4|x5|x4|x5|x4|x5|x4|x5|x12|x13|x12|x13|x12|x13|x12|x13
+ |x6|x7|x6|x7|x6|x7|x6|x7|x14|x15|x14|x15|x14|x15|x14|x15
+*/
+#define BF16_INTERLEAVE_1x16(regArray) \
+ regArray##_1 = _mm256_unpacklo_epi32(regArray##_0, regArray##_0); \
+ regArray##_3 = _mm256_unpackhi_epi32(regArray##_0, regArray##_0); \
+ \
+ regArray##_0 = _mm256_unpacklo_epi64(regArray##_1, regArray##_1); \
+ regArray##_1 = _mm256_unpackhi_epi64(regArray##_1, regArray##_1); \
+ regArray##_2 = _mm256_unpacklo_epi64(regArray##_3, regArray##_3); \
+ regArray##_3 = _mm256_unpackhi_epi64(regArray##_3, regArray##_3);
+
+/* 1-step interleave to exchange the high-256s bit and low-256 bits of 4 pair of registers
+ |a0|a1|...|a14|a15|i0|i1|...|i14|i15|
+ |b0|b1|...|b14|b15|j0|j1|...|j14|j15|
+ |c0|c1|...|c14|c15|k0|k1|...|k14|k15|
+ |d0|d1|...|d14|d15|l0|l1|...|l14|l15|
+ |e0|e1|...|e14|e15|m0|m1|...|m14|m15|
+ |f0|f1|...|f14|f15|n0|n1|...|n14|n15|
+ |g0|g1|...|g14|g15|o0|o1|...|o14|o15|
+ |h0|h1|...|h14|h15|p0|p1|...|p14|p15|
+*/
+#define BF16_INTERLEAVE256_8x32(regArray) \
+ regArray##_0 = _mm512_shuffle_i32x4(regArray##_8, regArray##_12, 0x44); \
+ regArray##_1 = _mm512_shuffle_i32x4(regArray##_8, regArray##_12, 0xee); \
+ regArray##_2 = _mm512_shuffle_i32x4(regArray##_9, regArray##_13, 0x44); \
+ regArray##_3 = _mm512_shuffle_i32x4(regArray##_9, regArray##_13, 0xee); \
+ regArray##_4 = _mm512_shuffle_i32x4(regArray##_10, regArray##_14, 0x44); \
+ regArray##_5 = _mm512_shuffle_i32x4(regArray##_10, regArray##_14, 0xee); \
+ regArray##_6 = _mm512_shuffle_i32x4(regArray##_11, regArray##_15, 0x44); \
+ regArray##_7 = _mm512_shuffle_i32x4(regArray##_11, regArray##_15, 0xee);
+
+
+/* 1-step interleave to exchange the high-256s bit and low-256 bits of 2 pair of registers
+ |a0|a1|...|a14|a15|e0|e1|...|e14|e15|
+ |b0|b1|...|b14|b15|f0|f1|...|f14|f15|
+ |c0|c1|...|c14|c15|g0|g1|...|g14|g15|
+ |d0|d1|...|d14|d15|h0|h1|...|h14|h15|
+*/
+#define BF16_INTERLEAVE256_4x32(regArray) \
+ regArray##_0 = _mm512_shuffle_i32x4(regArray##_4, regArray##_6, 0x44); \
+ regArray##_1 = _mm512_shuffle_i32x4(regArray##_4, regArray##_6, 0xee); \
+ regArray##_2 = _mm512_shuffle_i32x4(regArray##_5, regArray##_7, 0x44); \
+ regArray##_3 = _mm512_shuffle_i32x4(regArray##_5, regArray##_7, 0xee);
+
+
+#define BF16_PERMUTE_8x32(idx, regArray) \
+ regArray##_8 = _mm512_permutexvar_epi16(idx, regArray##_0); \
+ regArray##_9 = _mm512_permutexvar_epi16(idx, regArray##_1); \
+ regArray##_10 = _mm512_permutexvar_epi16(idx, regArray##_2); \
+ regArray##_11 = _mm512_permutexvar_epi16(idx, regArray##_3); \
+ regArray##_12 = _mm512_permutexvar_epi16(idx, regArray##_4); \
+ regArray##_13 = _mm512_permutexvar_epi16(idx, regArray##_5); \
+ regArray##_14 = _mm512_permutexvar_epi16(idx, regArray##_6); \
+ regArray##_15 = _mm512_permutexvar_epi16(idx, regArray##_7);
+
+
+#define BF16_PERMUTE_8x32_2(idx, regArray) \
+ regArray##_8 = _mm512_permutexvar_epi32(idx, regArray##_0); \
+ regArray##_9 = _mm512_permutexvar_epi32(idx, regArray##_1); \
+ regArray##_10 = _mm512_permutexvar_epi32(idx, regArray##_2); \
+ regArray##_11 = _mm512_permutexvar_epi32(idx, regArray##_3); \
+ regArray##_12 = _mm512_permutexvar_epi32(idx, regArray##_4); \
+ regArray##_13 = _mm512_permutexvar_epi32(idx, regArray##_5); \
+ regArray##_14 = _mm512_permutexvar_epi32(idx, regArray##_6); \
+ regArray##_15 = _mm512_permutexvar_epi32(idx, regArray##_7);
+
+
+#define BF16_PERMUTE_4x32(idx, regArray) \
+ regArray##_4 = _mm512_permutexvar_epi16(idx, regArray##_0); \
+ regArray##_5 = _mm512_permutexvar_epi16(idx, regArray##_1); \
+ regArray##_6 = _mm512_permutexvar_epi16(idx, regArray##_2); \
+ regArray##_7 = _mm512_permutexvar_epi16(idx, regArray##_3);
+
+
+#define BF16_PERMUTE_4x32_2(idx, regArray) \
+ regArray##_4 = _mm512_permutexvar_epi32(idx, regArray##_0); \
+ regArray##_5 = _mm512_permutexvar_epi32(idx, regArray##_1); \
+ regArray##_6 = _mm512_permutexvar_epi32(idx, regArray##_2); \
+ regArray##_7 = _mm512_permutexvar_epi32(idx, regArray##_3);
+
+
+/* Calculate the dot result for 2-step interleaved matrix and vector
+ (Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform)
+*/
+#define BF16_2STEP_INTERLEAVED_DOT_8x32(accumArray, matArray, xArray) \
+ accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) xArray##_0); \
+ accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_2, (__m512bh) xArray##_0); \
+ accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_1, (__m512bh) xArray##_1); \
+ accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_3, (__m512bh) xArray##_1); \
+ accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_4, (__m512bh) xArray##_2); \
+ accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_6, (__m512bh) xArray##_2); \
+ accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_5, (__m512bh) xArray##_3); \
+ accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_7, (__m512bh) xArray##_3);
+
+
+/* Calculate the dot result for 2-step interleaved matrix and vector
+ (Assume throughput for _mm256_dpbf16_ps is 0.5, tunable per platform)
+*/
+#define BF16_2STEP_INTERLEAVED_DOT_8x16(accumArray, matArray, xArray) \
+ accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) xArray##_0); \
+ accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_2, (__m256bh) xArray##_0); \
+ accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_1, (__m256bh) xArray##_1); \
+ accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_3, (__m256bh) xArray##_1); \
+ accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_4, (__m256bh) xArray##_2); \
+ accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_6, (__m256bh) xArray##_2); \
+ accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_5, (__m256bh) xArray##_3); \
+ accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_7, (__m256bh) xArray##_3);
+
+/* Calculate the dot result for 2-step interleaved matrix and vector
+ (Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform)
+*/
+#define BF16_2STEP_INTERLEAVED_DOT_4x32(accumArray, matArray, xArray) \
+ accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) xArray##_0); \
+ accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_1, (__m512bh) xArray##_1); \
+ accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_2, (__m512bh) xArray##_2); \
+ accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_3, (__m512bh) xArray##_3);
+
+
+/* Calculate the dot result for 2-step interleaved matrix and vector
+ (Assume throughput for _mm256_dpbf16_ps is 0.5, tunable per platform)
+*/
+#define BF16_2STEP_INTERLEAVED_DOT_4x16(accumArray, matArray, xArray) \
+ accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) xArray##_0); \
+ accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_1, (__m256bh) xArray##_1); \
+ accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_2, (__m256bh) xArray##_2); \
+ accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_3, (__m256bh) xArray##_3);
+
+
+/* Calculate the dot result for matrix and vector at 32 elements per row
+ (Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform)
+*/
+#define BF16_DOT_8x32(accumArray, matArray, xArray) \
+ accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) xArray); \
+ accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_1, (__m512bh) xArray); \
+ accumArray##_2 = _mm512_dpbf16_ps(accumArray##_2, (__m512bh) matArray##_2, (__m512bh) xArray); \
+ accumArray##_3 = _mm512_dpbf16_ps(accumArray##_3, (__m512bh) matArray##_3, (__m512bh) xArray); \
+ accumArray##_4 = _mm512_dpbf16_ps(accumArray##_4, (__m512bh) matArray##_4, (__m512bh) xArray); \
+ accumArray##_5 = _mm512_dpbf16_ps(accumArray##_5, (__m512bh) matArray##_5, (__m512bh) xArray); \
+ accumArray##_6 = _mm512_dpbf16_ps(accumArray##_6, (__m512bh) matArray##_6, (__m512bh) xArray); \
+ accumArray##_7 = _mm512_dpbf16_ps(accumArray##_7, (__m512bh) matArray##_7, (__m512bh) xArray);
+
+/* Calculate the dot result for matrix and vector at 32 elements per row
+ (Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform)
+*/
+#define BF16_DOT_1x32(accumArray, matArray, xArray) \
+ accumArray = _mm512_dpbf16_ps(accumArray, (__m512bh) matArray, (__m512bh) xArray);
+
+/* Calculate the dot result for matrix and vector at 16 elements per row
+ (Assume throughput for _mm256_dpbf16_ps is 0.5, tunable per platform)
+*/
+#define BF16_DOT_8x16(accumArray, matArray, xArray) \
+ accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) xArray); \
+ accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_1, (__m256bh) xArray); \
+ accumArray##_2 = _mm256_dpbf16_ps(accumArray##_2, (__m256bh) matArray##_2, (__m256bh) xArray); \
+ accumArray##_3 = _mm256_dpbf16_ps(accumArray##_3, (__m256bh) matArray##_3, (__m256bh) xArray); \
+ accumArray##_4 = _mm256_dpbf16_ps(accumArray##_4, (__m256bh) matArray##_4, (__m256bh) xArray); \
+ accumArray##_5 = _mm256_dpbf16_ps(accumArray##_5, (__m256bh) matArray##_5, (__m256bh) xArray); \
+ accumArray##_6 = _mm256_dpbf16_ps(accumArray##_6, (__m256bh) matArray##_6, (__m256bh) xArray); \
+ accumArray##_7 = _mm256_dpbf16_ps(accumArray##_7, (__m256bh) matArray##_7, (__m256bh) xArray);
+
+
+/* 2-step interleave for matrix against 8 rows with 16 fp32 elements per row
+ Input - register array of 8 rows of raw-major matrix
+ Output - the output of Step 2
+
+ Step 1: 2-element interleave for matrix
+ |a0|b0|a1|b1|a4|b4|a5|b5|a8 |b8 |a9 |b9 |a12|b12|a13|b13|
+ |c0|d0|c1|d1|c4|d4|c5|d5|c8 |d8 |c9 |d9 |c12|d12|c13|d13|
+ |e0|f0|e1|f1|e4|f4|e5|f5|e8 |f8 |e9 |f9 |e12|f12|e13|f13|
+ |g0|h0|g1|h1|g4|h4|g5|h5|g8 |h8 |g9 |h9 |g12|h12|g13|h13|
+ |a2|b2|a3|b3|a6|b6|a7|b7|a10|b10|a11|b11|a14|b14|a15|b15|
+ |c2|d2|c3|d3|c6|d6|c7|d7|c10|d10|c11|d11|c14|d14|c15|d15|
+ |e2|f2|e3|f3|e6|f6|e7|f7|e10|f10|e11|f11|e14|f14|e15|f15|
+ |g2|h2|g3|h3|g6|h6|g7|h7|g10|h10|g11|h11|g14|h14|g15|h15|
+
+ Step 2: 4-element interleave for matrix
+ |a0|b0|c0|d0|a4|b4|c4|d4|a8 |b8 |c8 |d8 |a12|b12|c12|d12|
+ |a1|b1|c1|d1|a5|b5|c5|d5|a9 |b9 |c9 |d9 |a13|b13|c13|d13|
+ |e0|f0|g0|h0|e4|f4|g4|h4|e8 |f8 |g8 |h8 |e12|f12|g12|h12|
+ |e1|f1|g1|h1|e5|f5|g5|h5|e9 |f9 |g9 |h9 |e13|f13|g13|h13|
+ |a2|b2|c2|d2|a6|b6|c6|d6|a10|b10|c10|d10|a14|b14|c14|d14|
+ |a3|b3|c3|d3|a7|b7|c7|d7|a11|b11|c11|d11|a15|b15|c15|d15|
+ |e2|f2|g2|h2|e6|f6|g6|h6|e10|f10|g10|h10|e14|f14|g14|h14|
+ |e3|f3|g3|h3|e7|f7|g7|h7|e11|f11|g11|h11|e15|f15|g15|h15|
+*/
+#define FP32_INTERLEAVE_8x16(regArray) \
+ regArray##_8 = _mm512_unpacklo_ps(regArray##_0, regArray##_1); \
+ regArray##_9 = _mm512_unpacklo_ps(regArray##_2, regArray##_3); \
+ regArray##_10 = _mm512_unpacklo_ps(regArray##_4, regArray##_5); \
+ regArray##_11 = _mm512_unpacklo_ps(regArray##_6, regArray##_7); \
+ regArray##_12 = _mm512_unpackhi_ps(regArray##_0, regArray##_1); \
+ regArray##_13 = _mm512_unpackhi_ps(regArray##_2, regArray##_3); \
+ regArray##_14 = _mm512_unpackhi_ps(regArray##_4, regArray##_5); \
+ regArray##_15 = _mm512_unpackhi_ps(regArray##_6, regArray##_7); \
+ \
+ regArray##_0 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_8, (__m512d) regArray##_9); \
+ regArray##_1 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_8, (__m512d) regArray##_9); \
+ regArray##_4 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_10, (__m512d) regArray##_11); \
+ regArray##_5 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_10, (__m512d) regArray##_11); \
+ regArray##_2 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_12, (__m512d) regArray##_13); \
+ regArray##_3 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_12, (__m512d) regArray##_13); \
+ regArray##_6 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_14, (__m512d) regArray##_15); \
+ regArray##_7 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_14, (__m512d) regArray##_15);
+
+#define FP32_INTERLEAVE_8x16_ARRAY(regArray) \
+ regArray[8] = _mm512_unpacklo_ps(regArray[0], regArray[1]); \
+ regArray[9] = _mm512_unpacklo_ps(regArray[2], regArray[3]); \
+ regArray[10] = _mm512_unpacklo_ps(regArray[4], regArray[5]); \
+ regArray[11] = _mm512_unpacklo_ps(regArray[6], regArray[7]); \
+ regArray[12] = _mm512_unpackhi_ps(regArray[0], regArray[1]); \
+ regArray[13] = _mm512_unpackhi_ps(regArray[2], regArray[3]); \
+ regArray[14] = _mm512_unpackhi_ps(regArray[4], regArray[5]); \
+ regArray[15] = _mm512_unpackhi_ps(regArray[6], regArray[7]); \
+ \
+ regArray[0] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[8], (__m512d) regArray[9]); \
+ regArray[1] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[8], (__m512d) regArray[9]); \
+ regArray[4] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[10], (__m512d) regArray[11]); \
+ regArray[5] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[10], (__m512d) regArray[11]); \
+ regArray[2] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[12], (__m512d) regArray[13]); \
+ regArray[3] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[12], (__m512d) regArray[13]); \
+ regArray[6] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[14], (__m512d) regArray[15]); \
+ regArray[7] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[14], (__m512d) regArray[15]);
+
+/* 2-step interleave for matrix against 8 rows with 8 fp32 elements per row
+ Input - register array of 8 rows of raw-major matrix
+ Output - the output of Step 2
+
+ Step 1: 2-element interleave for matrix
+ |a0|b0|a1|b1|a4|b4|a5|b5|
+ |c0|d0|c1|d1|c4|d4|c5|d5|
+ |e0|f0|e1|f1|e4|f4|e5|f5|
+ |g0|h0|g1|h1|g4|h4|g5|h5|
+ |a2|b2|a3|b3|a6|b6|a7|b7|
+ |c2|d2|c3|d3|c6|d6|c7|d7|
+ |e2|f2|e3|f3|e6|f6|e7|f7|
+ |g2|h2|g3|h3|g6|h6|g7|h7|
+
+ Step 2: 4-element interleave for matrix
+ |a0|b0|c0|d0|a4|b4|c4|d4|
+ |a1|b1|c1|d1|a5|b5|c5|d5|
+ |e0|f0|g0|h0|e4|f4|g4|h4|
+ |e1|f1|g1|h1|e5|f5|g5|h5|
+ |a2|b2|c2|d2|a6|b6|c6|d6|
+ |a3|b3|c3|d3|a7|b7|c7|d7|
+ |e2|f2|g2|h2|e6|f6|g6|h6|
+ |e3|f3|g3|h3|e7|f7|g7|h7|
+*/
+#define FP32_INTERLEAVE_8x8(regArray) \
+ regArray##_8 = _mm256_unpacklo_ps(regArray##_0, regArray##_1); \
+ regArray##_9 = _mm256_unpacklo_ps(regArray##_2, regArray##_3); \
+ regArray##_10 = _mm256_unpacklo_ps(regArray##_4, regArray##_5); \
+ regArray##_11 = _mm256_unpacklo_ps(regArray##_6, regArray##_7); \
+ regArray##_12 = _mm256_unpackhi_ps(regArray##_0, regArray##_1); \
+ regArray##_13 = _mm256_unpackhi_ps(regArray##_2, regArray##_3); \
+ regArray##_14 = _mm256_unpackhi_ps(regArray##_4, regArray##_5); \
+ regArray##_15 = _mm256_unpackhi_ps(regArray##_6, regArray##_7); \
+ \
+ regArray##_0 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_8, (__m256d) regArray##_9); \
+ regArray##_1 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_8, (__m256d) regArray##_9); \
+ regArray##_4 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_10, (__m256d) regArray##_11); \
+ regArray##_5 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_10, (__m256d) regArray##_11); \
+ regArray##_2 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_12, (__m256d) regArray##_13); \
+ regArray##_3 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_12, (__m256d) regArray##_13); \
+ regArray##_6 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_14, (__m256d) regArray##_15); \
+ regArray##_7 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_14, (__m256d) regArray##_15);
+
+
+/* Accumulate the result for 2 batch of 4-registers
+*/
+#define FP32_ACCUM2_8x16(regArray) \
+ regArray##_0 = _mm512_add_ps(regArray##_0, regArray##_1); \
+ regArray##_2 = _mm512_add_ps(regArray##_2, regArray##_3); \
+ regArray##_4 = _mm512_add_ps(regArray##_4, regArray##_5); \
+ regArray##_6 = _mm512_add_ps(regArray##_6, regArray##_7); \
+ regArray##_0 = _mm512_add_ps(regArray##_0, regArray##_2); \
+ regArray##_4 = _mm512_add_ps(regArray##_4, regArray##_6);
+
+#define FP32_ACCUM2_8x16_ARRAY(regArray) \
+ regArray[0] = _mm512_add_ps(regArray[0], regArray[1]); \
+ regArray[2] = _mm512_add_ps(regArray[2], regArray[3]); \
+ regArray[4] = _mm512_add_ps(regArray[4], regArray[5]); \
+ regArray[6] = _mm512_add_ps(regArray[6], regArray[7]); \
+ regArray[0] = _mm512_add_ps(regArray[0], regArray[2]); \
+ regArray[4] = _mm512_add_ps(regArray[4], regArray[6]);
+
+/* Accumulate the result for 2 batch of 4-registers
+*/
+#define FP32_ACCUM2_8x8(regArray) \
+ regArray##_0 = _mm256_add_ps(regArray##_0, regArray##_1); \
+ regArray##_2 = _mm256_add_ps(regArray##_2, regArray##_3); \
+ regArray##_4 = _mm256_add_ps(regArray##_4, regArray##_5); \
+ regArray##_6 = _mm256_add_ps(regArray##_6, regArray##_7); \
+ regArray##_0 = _mm256_add_ps(regArray##_0, regArray##_2); \
+ regArray##_4 = _mm256_add_ps(regArray##_4, regArray##_6);
+
+
+/* Store 16 (alpha * result + beta * y) to y
+*/
+#define STORE16_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \
+ regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_mul_ps(BETAVECTOR, _mm512_loadu_ps(targetAddr))); \
+ _mm512_storeu_ps(targetAddr, regResult);
+
+
+/* Masked store 16 (alpha * result + beta * y) to y
+*/
+#define STORE16_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \
+ regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_mul_ps(BETAVECTOR, _mm512_maskz_loadu_ps(mask, targetAddr))); \
+ _mm512_mask_storeu_ps(targetAddr, mask, regResult);
+
+
+/* Store 8 (alpha * result + beta * y) to y
+*/
+#define STORE8_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \
+ regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_mul_ps(_mm512_castps512_ps256(BETAVECTOR), _mm256_loadu_ps(targetAddr))); \
+ _mm256_storeu_ps(targetAddr, regResult);
+
+
+/* Masked store 8 (alpha * result + beta * y) to y
+*/
+#define STORE8_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \
+ regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_mul_ps(_mm512_castps512_ps256(BETAVECTOR), _mm256_maskz_loadu_ps(mask, targetAddr))); \
+ _mm256_mask_storeu_ps(targetAddr, mask, regResult);
+
+
+/* Store 4 (alpha * result + beta * y) to y
+*/
+#define STORE4_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \
+ regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_mul_ps(_mm512_castps512_ps128(BETAVECTOR), _mm_loadu_ps(targetAddr))); \
+ _mm_storeu_ps(targetAddr, regResult);
+
+
+/* Masked store 4 (alpha * result + beta * y) to y
+*/
+#define STORE4_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \
+ regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_mul_ps(_mm512_castps512_ps128(BETAVECTOR), _mm_maskz_loadu_ps(mask, targetAddr))); \
+ _mm_mask_storeu_ps(targetAddr, mask, regResult);
+
+
+/* Store 16 (alpha * result + y) to y
+*/
+#define STORE16_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \
+ regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_loadu_ps(targetAddr)); \
+ _mm512_storeu_ps(targetAddr, regResult);
+
+
+/* Masked store 16 (alpha * result + y) to y
+*/
+#define STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \
+ regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_maskz_loadu_ps(mask, targetAddr)); \
+ _mm512_mask_storeu_ps(targetAddr, mask, regResult);
+
+
+/* Store 8 (alpha * result + y) to y
+*/
+#define STORE8_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \
+ regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_loadu_ps(targetAddr)); \
+ _mm256_storeu_ps(targetAddr, regResult);
+
+
+/* Masked store 8 (alpha * result + y) to y
+*/
+#define STORE8_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \
+ regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_maskz_loadu_ps(mask, targetAddr)); \
+ _mm256_mask_storeu_ps(targetAddr, mask, regResult);
+
+
+/* Store 4 (alpha * result + y) to y
+*/
+#define STORE4_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \
+ regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_loadu_ps(targetAddr)); \
+ _mm_storeu_ps(targetAddr, regResult);
+
+
+/* Masked store 4 (alpha * result + y) to y
+*/
+#define STORE4_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \
+ regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_maskz_loadu_ps(mask, targetAddr)); \
+ _mm_mask_storeu_ps(targetAddr, mask, regResult);
+
+
+/* Store 16 (alpha * result) to y
+*/
+#define STORE16_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \
+ _mm512_storeu_ps(targetAddr, _mm512_mul_ps(ALPHAVECTOR, regResult));
+
+
+/* Masked store 16 (alpha * result) to y
+*/
+#define STORE16_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \
+ _mm512_mask_storeu_ps(targetAddr, mask, _mm512_mul_ps(ALPHAVECTOR, regResult));
+
+
+/* Store 8 (alpha * result) to y
+*/
+#define STORE8_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \
+ _mm256_storeu_ps(targetAddr, _mm256_mul_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult));
+
+
+/* Masked store 8 (alpha * result) to y
+*/
+#define STORE8_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \
+ _mm256_mask_storeu_ps(targetAddr, mask, _mm256_mul_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult));
+
+
+/* Store 4 (alpha * result) to y
+*/
+#define STORE4_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \
+ _mm_storeu_ps(targetAddr, _mm_mul_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult));
+
+
+/* Masked store 4 (alpha * result) to y
+*/
+#define STORE4_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \
+ _mm_mask_storeu_ps(targetAddr, mask, _mm_mul_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult));
+
+
+/* Store 16 result to y
+*/
+#define STORE16_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \
+ _mm512_storeu_ps(targetAddr, regResult);
+
+
+/* Masked store 16 result to y
+*/
+#define STORE16_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \
+ _mm512_mask_storeu_ps(targetAddr, mask, regResult);
+
+
+/* Store 8 result to y
+*/
+#define STORE8_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \
+ _mm256_storeu_ps(targetAddr, regResult);
+
+
+/* Masked store 8 result to y
+*/
+#define STORE8_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \
+ _mm256_mask_storeu_ps(targetAddr, mask, regResult);
+
+
+/* Store 4 result to y
+*/
+#define STORE4_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \
+ _mm_storeu_ps(targetAddr, regResult);
+
+
+/* Masked store 4 result to y
+*/
+#define STORE4_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \
+ _mm_mask_storeu_ps(targetAddr, mask, regResult);
+
+#endif
--- /dev/null
+/***************************************************************************
+Copyright (c) 2014, The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+#include <immintrin.h>
+#include "common.h"
+// Include common macros for BF16 based operations with IA intrinsics
+#include "bf16_common_macros.h"
+
+#ifndef ZERO_BETA // Beta is non-zero
+
+#ifndef ONE_BETA // BETA is not ONE
+
+#define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ALPHA_BETA
+#define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA_BETA
+#define STORE8_COMPLETE_RESULT STORE8_COMPLETE_RESULT_ALPHA_BETA
+#define STORE8_MASK_COMPLETE_RESULT STORE8_MASK_COMPLETE_RESULT_ALPHA_BETA
+#define STORE4_COMPLETE_RESULT STORE4_COMPLETE_RESULT_ALPHA_BETA
+#define STORE4_MASK_COMPLETE_RESULT STORE4_MASK_COMPLETE_RESULT_ALPHA_BETA
+
+#else // BETA is ONE
+
+#define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ALPHA_ONE
+#define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE
+#define STORE8_COMPLETE_RESULT STORE8_COMPLETE_RESULT_ALPHA_ONE
+#define STORE8_MASK_COMPLETE_RESULT STORE8_MASK_COMPLETE_RESULT_ALPHA_ONE
+#define STORE4_COMPLETE_RESULT STORE4_COMPLETE_RESULT_ALPHA_ONE
+#define STORE4_MASK_COMPLETE_RESULT STORE4_MASK_COMPLETE_RESULT_ALPHA_ONE
+
+#endif
+
+#else // BETA is zero
+
+#ifndef ONE_ALPHA // ALPHA is not ONE
+
+#define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ALPHA
+#define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA
+#define STORE8_COMPLETE_RESULT STORE8_COMPLETE_RESULT_ALPHA
+#define STORE8_MASK_COMPLETE_RESULT STORE8_MASK_COMPLETE_RESULT_ALPHA
+#define STORE4_COMPLETE_RESULT STORE4_COMPLETE_RESULT_ALPHA
+#define STORE4_MASK_COMPLETE_RESULT STORE4_MASK_COMPLETE_RESULT_ALPHA
+
+#else // ALPHA is ONE
+
+#define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_DIRECT
+#define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_DIRECT
+#define STORE8_COMPLETE_RESULT STORE8_COMPLETE_RESULT_DIRECT
+#define STORE8_MASK_COMPLETE_RESULT STORE8_MASK_COMPLETE_RESULT_DIRECT
+#define STORE4_COMPLETE_RESULT STORE4_COMPLETE_RESULT_DIRECT
+#define STORE4_MASK_COMPLETE_RESULT STORE4_MASK_COMPLETE_RESULT_DIRECT
+
+#endif
+
+#endif
+
+
+// 32 rows parallel processing BF16 GEMV kernel for n=1 && lda ineffective scenario
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+static int sbgemv_kernel_32x1_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#else
+static int sbgemv_kernel_32x1_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#endif
+#else
+#ifndef ONE_ALPHA
+static int sbgemv_kernel_32x1_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#else
+static int sbgemv_kernel_32x1(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#endif
+#endif
+{
+ BLASLONG tag_m_32x = m & (~31);
+
+ __m512i matrixArray_0, matrixArray_1, matrixArray_2;
+ __m512i xArray;
+ __m512 result_0, result_1;
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+ __m512 BETAVECTOR = _mm512_set1_ps(beta);
+#endif
+#endif
+
+ __m512i load_idx_lo = _mm512_set_epi16(0, 15, 0, 14, 0, 13, 0, 12, 0, 11, 0, 10, 0, 9, 0, 8,\
+ 0, 7, 0, 6, 0, 5, 0, 4, 0, 3, 0, 2, 0, 1, 0, 0);
+ __m512i M512_EPI16_16 = _mm512_set1_epi16(16);
+ __m512i load_idx_hi = _mm512_add_epi16(load_idx_lo, M512_EPI16_16);
+
+ unsigned int interleve_mask_value = ((unsigned int) 0x55555555);
+ __mmask32 interleave_mask = *((__mmask32*) &interleve_mask_value);
+
+ xArray = _mm512_set1_epi16((short) x[0]);
+ xArray = _mm512_mask_blend_epi16(interleave_mask, _mm512_setzero_si512(), xArray);
+
+ if (tag_m_32x > 0) {
+ for (BLASLONG idx_m = 0; idx_m < tag_m_32x; idx_m+=32) {
+ result_0 = _mm512_setzero_ps();
+ result_1 = _mm512_setzero_ps();
+
+ matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)]); // Load 32 rows with n=1
+ matrixArray_1 = _mm512_permutexvar_epi16(load_idx_lo, matrixArray_0); // Expand the low 16 elements
+ matrixArray_2 = _mm512_permutexvar_epi16(load_idx_hi, matrixArray_0); // Expand the high 16 elements
+
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_1, (__m512bh) xArray);
+ result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_2, (__m512bh) xArray);
+
+ STORE16_COMPLETE_RESULT(result_0, y+idx_m)
+ STORE16_COMPLETE_RESULT(result_1, y+idx_m+16)
+ }
+ }
+
+ BLASLONG tail_num = m - tag_m_32x;
+ if (tail_num > 16) {
+ result_0 = _mm512_setzero_ps();
+ result_1 = _mm512_setzero_ps();
+
+ unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-tail_num));
+ __mmask32 tail_mask = *((__mmask32*) &tail_mask_value);
+ matrixArray_0 = _mm512_maskz_loadu_epi16(tail_mask, &a[(tag_m_32x)]); // Load 32 rows with n=1
+ matrixArray_1 = _mm512_permutexvar_epi16(load_idx_lo, matrixArray_0); // Expand the low 16 elements
+ matrixArray_2 = _mm512_permutexvar_epi16(load_idx_hi, matrixArray_0); // Expand the high 16 elements
+
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_1, (__m512bh) xArray);
+ result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_2, (__m512bh) xArray);
+
+ unsigned short store_mask_value = (((unsigned short)0xffff) >> (32-tail_num));
+ __mmask16 store_mask = *((__mmask16*) &store_mask_value);
+ STORE16_COMPLETE_RESULT(result_0, y+tag_m_32x)
+ STORE16_MASK_COMPLETE_RESULT(result_1, y+tag_m_32x+16, store_mask)
+ } else if (tail_num > 8) {
+ __m256 result256_0 = _mm256_setzero_ps();
+ __m256 result256_1 = _mm256_setzero_ps();
+
+ __m256i load_idx_lo256 = _mm512_castsi512_si256(load_idx_lo);
+ __m256i load_idx_hi256 = _mm512_extracti32x8_epi32(load_idx_lo, 0x1);
+ __m256i xArray256 = _mm512_castsi512_si256(xArray);
+
+ unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-tail_num));
+ __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
+ __m256i matrixArray256_0 = _mm256_maskz_loadu_epi16(tail_mask, &a[(tag_m_32x)]); // Load 16 rows with n=1
+ __m256i matrixArray256_1 = _mm256_permutexvar_epi16(load_idx_lo256, matrixArray256_0); // Expand the low 8 elements
+ __m256i matrixArray256_2 = _mm256_permutexvar_epi16(load_idx_hi256, matrixArray256_0); // Expand the high 8 elements
+
+ result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_1, (__m256bh) xArray256);
+ result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_2, (__m256bh) xArray256);
+
+ unsigned char store_mask_value = (((unsigned char)0xff) >> (16-tail_num));
+ __mmask8 store_mask = *((__mmask8*) &store_mask_value);
+ STORE8_COMPLETE_RESULT(result256_0, y+tag_m_32x)
+ STORE8_MASK_COMPLETE_RESULT(result256_1, y+tag_m_32x+8, store_mask)
+ } else {
+ __m128 result128_0 = _mm_setzero_ps();
+ __m128 result128_1 = _mm_setzero_ps();
+
+ __m128i load_idx_lo128 = _mm_set_epi16(0, 3, 0, 2, 0, 1, 0, 0);
+ __m128i M128_EPI16_4 = _mm_set1_epi16(4);
+ __m128i load_idx_hi128 = _mm_add_epi16(load_idx_lo128, M128_EPI16_4);
+
+ __m128i xArray128 = _mm512_castsi512_si128(xArray);
+
+ unsigned char tail_mask_value = (((unsigned char)0xff) >> (8-tail_num));
+ __mmask8 tail_mask = *((__mmask8*) &tail_mask_value);
+ __m128i matrixArray128_0 = _mm_maskz_loadu_epi16(tail_mask, &a[(tag_m_32x)]); // Load 8 rows with n=1
+ __m128i matrixArray128_1 = _mm_permutexvar_epi16(load_idx_lo128, matrixArray128_0); // Expand the low 4 elements
+ __m128i matrixArray128_2 = _mm_permutexvar_epi16(load_idx_hi128, matrixArray128_0); // Expand the high 4 elements
+
+ result128_0 = _mm_dpbf16_ps(result128_0, (__m128bh) matrixArray128_1, (__m128bh) xArray128);
+ result128_1 = _mm_dpbf16_ps(result128_1, (__m128bh) matrixArray128_2, (__m128bh) xArray128);
+
+ if (tail_num > 4) {
+ unsigned char store_mask_value = (((unsigned char)0xf) >> (8-tail_num));
+ __mmask8 store_mask = *((__mmask8*) &store_mask_value);
+ STORE4_COMPLETE_RESULT(result128_0, y+tag_m_32x)
+ STORE4_MASK_COMPLETE_RESULT(result128_1, y+tag_m_32x+4, store_mask)
+ } else {
+ unsigned char store_mask_value = (((unsigned char)0xf) >> (4-tail_num));
+ __mmask8 store_mask = *((__mmask8*) &store_mask_value);
+ STORE4_MASK_COMPLETE_RESULT(result128_0, y+tag_m_32x, store_mask)
+ }
+ }
+
+ return 0;
+}
+
+// 32 rows parallel processing BF16 GEMV kernel for n=2 && lda ineffective scenario
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+static int sbgemv_kernel_32x2_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#else
+static int sbgemv_kernel_32x2_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#endif
+#else
+#ifndef ONE_ALPHA
+static int sbgemv_kernel_32x2_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#else
+static int sbgemv_kernel_32x2(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#endif
+#endif
+{
+ BLASLONG tag_m_32x = m & (~31);
+
+ __m512i matrixArray_0, matrixArray_1;
+ __m512i xArray;
+ __m512 result_0, result_1;
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+#ifndef ZERO_BETA
+ __m512 BETAVECTOR = _mm512_set1_ps(beta);
+#endif
+
+ unsigned char load_mask_value = (((unsigned char)0xff) >> 6);
+ __mmask8 load_mask = *((__mmask8*) &load_mask_value);
+ xArray = _mm512_broadcastd_epi32(_mm_maskz_loadu_epi16(load_mask, x));
+
+ if (tag_m_32x > 0) {
+ for (BLASLONG idx_m = 0; idx_m < tag_m_32x; idx_m+=32) {
+ result_0 = _mm512_setzero_ps();
+ result_1 = _mm512_setzero_ps();
+
+ matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)*2]); // Load 16 rows as n=2
+ matrixArray_1 = _mm512_loadu_si512(&a[(idx_m+16)*2]); // Load 16 rows as n=2
+
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_0, (__m512bh) xArray);
+ result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_1, (__m512bh) xArray);
+
+ STORE16_COMPLETE_RESULT(result_0, y+idx_m)
+ STORE16_COMPLETE_RESULT(result_1, y+idx_m+16)
+ }
+ }
+
+ if (m - tag_m_32x >= 16) {
+ result_0 = _mm512_setzero_ps();
+
+ matrixArray_0 = _mm512_loadu_si512(&a[(tag_m_32x)*2]); // Load 16 rows with n=2
+
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_0, (__m512bh) xArray);
+
+ STORE16_COMPLETE_RESULT(result_0, y+tag_m_32x)
+
+ tag_m_32x += 16;
+ }
+
+ BLASLONG tail_num = m - tag_m_32x;
+ if (tail_num > 8) {
+ result_0 = _mm512_setzero_ps();
+
+ unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-(m&15)));
+ __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
+ matrixArray_0 = _mm512_maskz_loadu_epi32(tail_mask, &a[(tag_m_32x)*2]); // Load 16 rows with n=2
+
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_0, (__m512bh) xArray);
+
+ STORE16_MASK_COMPLETE_RESULT(result_0, y+tag_m_32x, tail_mask)
+ } else if (tail_num == 8) {
+ __m256 result256 = _mm256_setzero_ps();
+
+ __m256i matrixArray256 = _mm256_loadu_si256(&a[(tag_m_32x)*2]); // Load 8 rows with n=2
+ __m256i xArray256 = _mm512_castsi512_si256(xArray);
+ result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) xArray256);
+
+ STORE8_COMPLETE_RESULT(result256, y+tag_m_32x)
+ } else {
+ __m256 result256 = _mm256_setzero_ps();
+
+ unsigned char tail_mask_value = (((unsigned char)0xff) >> (8-(m&7)));
+ __mmask8 tail_mask = *((__mmask8*) &tail_mask_value);
+ __m256i matrixArray256 = _mm256_maskz_loadu_epi32(tail_mask, &a[(tag_m_32x)*2]); // Load 8 rows with n=2
+ __m256i xArray256 = _mm512_castsi512_si256(xArray);
+ result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) xArray256);
+
+ STORE8_MASK_COMPLETE_RESULT(result256, y+tag_m_32x, tail_mask)
+ }
+
+ return 0;
+}
+
+// 32 rows parallel processing BF16 GEMV kernel for n=3 && lda ineffective scenario
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+static int sbgemv_kernel_32x3_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#else
+static int sbgemv_kernel_32x3_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#endif
+#else
+#ifndef ONE_ALPHA
+static int sbgemv_kernel_32x3_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#else
+static int sbgemv_kernel_32x3(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#endif
+#endif
+{
+ BLASLONG tag_m_32x = m & (~31);
+
+ __m512 result_0, result_1;
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+#ifndef ZERO_BETA
+ __m512 BETAVECTOR = _mm512_set1_ps(beta);
+#endif
+
+ unsigned char x_load_mask_value = (((unsigned char)0xff) >> 5);
+ __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
+ __m128i xTmp = _mm_maskz_loadu_epi16(x_load_mask, x); // x0|x1|x2|0|0|0|0|0|
+ __m512i xArray_0 = _mm512_broadcastd_epi32(xTmp); // x0|x1|x0|x1|...|x0|x1|
+ __m512i xArray_1 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(xTmp, 0x1)); // x2| 0|x2| 0|...|x2| 0|
+
+ __m512i load_idx_base;
+ __m512i M512_EPI16_2, M512_EPI16_8, M512_EPI16_16;
+ M512_EPI16_2 = _mm512_set1_epi16(2);
+ M512_EPI16_8 = _mm512_add_epi16(M512_EPI16_2, M512_EPI16_2);
+ M512_EPI16_8 = _mm512_add_epi16(M512_EPI16_8, M512_EPI16_8);
+ M512_EPI16_16 = _mm512_add_epi16(M512_EPI16_8, M512_EPI16_8);
+ load_idx_base = _mm512_set_epi16(46, 45, 43, 42, 40, 39, 37, 36, 34, 33, 31, 30, 28, 27, 25, 24,
+ 22, 21, 19, 18, 16, 15, 13, 12, 10, 9, 7, 6, 4, 3, 1, 0);
+
+ if (tag_m_32x > 0) {
+ __m512i load_idx01_1st, load_idx01_2nd, load_idx2_1st, load_idx2_2nd;
+ __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6;
+
+ unsigned int idx_blend_mask_value = ((unsigned int)0x80000000);
+ __mmask32 idx_blend_mask = *((__mmask32*) &idx_blend_mask_value);
+
+ load_idx01_1st = load_idx_base;
+ load_idx01_2nd = _mm512_add_epi16(load_idx01_1st, M512_EPI16_16);
+ load_idx2_1st = _mm512_add_epi16(load_idx01_1st, M512_EPI16_2);
+ load_idx2_2nd = _mm512_add_epi16(load_idx01_2nd, M512_EPI16_2);
+ load_idx2_2nd = _mm512_mask_blend_epi16(idx_blend_mask, load_idx2_2nd, _mm512_setzero_si512());
+
+ for (BLASLONG idx_m = 0; idx_m < tag_m_32x; idx_m+=32) {
+ result_0 = _mm512_setzero_ps();
+ result_1 = _mm512_setzero_ps();
+
+ matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)*3]); // Load 10 rows with n=3 plus 2 element
+ matrixArray_1 = _mm512_loadu_si512(&a[((idx_m+10)*3 + 2)]); // Load 10 rows with n=3 plus 2 element
+ matrixArray_2 = _mm512_loadu_si512(&a[((idx_m+21)*3 + 1)]); // Load 10 rows with n=3 plus 2 element
+
+ matrixArray_3 = _mm512_permutex2var_epi16(matrixArray_0, load_idx01_1st, matrixArray_1); // Select the first 2 elements for each row
+ matrixArray_4 = _mm512_permutex2var_epi16(matrixArray_1, load_idx01_2nd, matrixArray_2); // Select the first 2 elements for each row
+ matrixArray_5 = _mm512_permutex2var_epi16(matrixArray_0, load_idx2_1st, matrixArray_1); // Select the third element for each row
+ matrixArray_6 = _mm512_permutex2var_epi16(matrixArray_1, load_idx2_2nd, matrixArray_2); // Select the third element for each row
+
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_3, (__m512bh) xArray_0);
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_5, (__m512bh) xArray_1);
+ result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_4, (__m512bh) xArray_0);
+ result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_6, (__m512bh) xArray_1);
+
+ STORE16_COMPLETE_RESULT(result_0, y+idx_m)
+ STORE16_COMPLETE_RESULT(result_1, y+idx_m+16)
+ }
+ }
+
+ if (tag_m_32x != m) {
+ __m256i load256_idx01_1st, load256_idx01_2nd, load256_idx2_1st, load256_idx2_2nd;
+ __m256i matrixArray256_0, matrixArray256_1, matrixArray256_2, matrixArray256_3, matrixArray256_4, matrixArray256_5, matrixArray256_6;
+ __m256 result256_0, result256_1;
+
+ unsigned short idx256_blend_mask_value = ((unsigned short)0x8000);
+ __mmask16 idx256_blend_mask = *((__mmask16*) &idx256_blend_mask_value);
+
+ load256_idx01_1st = _mm512_castsi512_si256(load_idx_base);
+ load256_idx01_2nd = _mm256_add_epi16(load256_idx01_1st, _mm512_castsi512_si256(M512_EPI16_8));
+ load256_idx2_1st = _mm256_add_epi16(load256_idx01_1st, _mm512_castsi512_si256(M512_EPI16_2));
+ load256_idx2_2nd = _mm256_add_epi16(load256_idx01_2nd, _mm512_castsi512_si256(M512_EPI16_2));
+ load256_idx2_2nd = _mm256_mask_blend_epi16(idx256_blend_mask, load256_idx2_2nd, _mm256_setzero_si256());
+
+ if (m - tag_m_32x > 15) {
+ result256_0 = _mm256_setzero_ps();
+ result256_1 = _mm256_setzero_ps();
+
+ matrixArray256_0 = _mm256_loadu_si256(&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element
+ matrixArray256_1 = _mm256_loadu_si256(&a[((tag_m_32x+5)*3 + 1)]); // Load 5 rows with n=3 plus 1 element
+ matrixArray256_2 = _mm256_loadu_si256(&a[((tag_m_32x+10)*3 + 2)]); // Load 5 rows with n=3 plus 1 element
+
+ matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1); // Select the first 2 elements for each row
+ matrixArray256_4 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx01_2nd, matrixArray256_2); // Select the first 2 elements for each row
+ matrixArray256_5 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx2_1st, matrixArray256_1); // Select the third element for each row
+ matrixArray256_6 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx2_2nd, matrixArray256_2); // Select the third element for each row
+
+ result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_3, (__m256bh) _mm512_castsi512_si256(xArray_0));
+ result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_5, (__m256bh) _mm512_castsi512_si256(xArray_1));
+ result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_4, (__m256bh) _mm512_castsi512_si256(xArray_0));
+ result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_6, (__m256bh) _mm512_castsi512_si256(xArray_1));
+
+ STORE8_COMPLETE_RESULT(result256_0, y+tag_m_32x)
+ STORE8_COMPLETE_RESULT(result256_1, y+tag_m_32x+8)
+
+ tag_m_32x += 16;
+ }
+
+ if (tag_m_32x != m) {
+ result256_0 = _mm256_setzero_ps();
+ result256_1 = _mm256_setzero_ps();
+ BLASLONG tail_num = m-tag_m_32x;
+
+ if (tail_num > 10) {
+ unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-((tail_num-10-1)*3+1)));
+ __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
+ matrixArray256_0 = _mm256_loadu_si256(&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element
+ matrixArray256_1 = _mm256_loadu_si256(&a[((tag_m_32x+5)*3 + 1)]); // Load 5 rows with n=3 plus 1 element
+ matrixArray256_2 = _mm256_maskz_loadu_epi16(tail_mask, &a[((tag_m_32x+10)*3 + 2)]); // Load m-tag_m_32x-10 rows
+
+ matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1); // Select the first 2 elements for each row
+ matrixArray256_4 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx01_2nd, matrixArray256_2); // Select the first 2 elements for each row
+ matrixArray256_5 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx2_1st, matrixArray256_1); // Select the third element for each row
+ matrixArray256_6 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx2_2nd, matrixArray256_2); // Select the third element for each row
+
+ result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_3, (__m256bh) _mm512_castsi512_si256(xArray_0));
+ result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_5, (__m256bh) _mm512_castsi512_si256(xArray_1));
+ result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_4, (__m256bh) _mm512_castsi512_si256(xArray_0));
+ result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_6, (__m256bh) _mm512_castsi512_si256(xArray_1));
+ } else if (tail_num > 5) {
+ unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-((tail_num-5-1)*3+2)));
+ __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
+ matrixArray256_0 = _mm256_loadu_si256(&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element
+ matrixArray256_1 = _mm256_maskz_loadu_epi16(tail_mask, &a[((tag_m_32x+5)*3+1)]); // Load m-tag_m_32x-5 rows
+ matrixArray256_2 = _mm256_setzero_si256();
+
+ matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1); // Select the first 2 elements for each row
+ matrixArray256_4 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx01_2nd, matrixArray256_2); // Select the first 2 elements for each row
+ matrixArray256_5 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx2_1st, matrixArray256_1); // Select the third element for each row
+ matrixArray256_6 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx2_2nd, matrixArray256_2); // Select the third element for each row
+
+ result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_3, (__m256bh) _mm512_castsi512_si256(xArray_0));
+ result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_5, (__m256bh) _mm512_castsi512_si256(xArray_1));
+ result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_4, (__m256bh) _mm512_castsi512_si256(xArray_0));
+ result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_6, (__m256bh) _mm512_castsi512_si256(xArray_1));
+ } else {
+ unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-(tail_num*3)));
+ __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
+ matrixArray256_0 = _mm256_maskz_loadu_epi16(tail_mask, &a[(tag_m_32x)*3]); // Load m-tag_m_32x rows
+ matrixArray256_1 = _mm256_setzero_si256();
+
+ matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1); // Select the first 2 elements for each row
+ matrixArray256_5 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx2_1st, matrixArray256_1); // Select the third element for each row
+
+ result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_3, (__m256bh) _mm512_castsi512_si256(xArray_0));
+ result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_5, (__m256bh) _mm512_castsi512_si256(xArray_1));
+ }
+
+ unsigned short store_tail_mask_value = (((unsigned short)0xffff) >> (16-(tail_num)));
+ __mmask16 store_tail_mask = *((__mmask16*) &store_tail_mask_value);
+ __m512 result512 = _mm512_insertf32x8(_mm512_castps256_ps512(result256_0), result256_1, 0x1);
+ STORE16_MASK_COMPLETE_RESULT(result512, y+tag_m_32x, store_tail_mask)
+ }
+ }
+
+ return 0;
+}
+
+// 16 rows parallel processing BF16 GEMV kernel for n=4 && lda ineffective scenario
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+static int sbgemv_kernel_16x4_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#else
+static int sbgemv_kernel_16x4_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#endif
+#else
+#ifndef ONE_ALPHA
+static int sbgemv_kernel_16x4_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#else
+static int sbgemv_kernel_16x4(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#endif
+#endif
+{
+ BLASLONG tag_m_16x = m & (~15);
+ __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3;
+ __m512i xArray_01, xArray_23, xArray_remix;
+ __m512 result;
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+#ifndef ZERO_BETA
+ __m512 BETAVECTOR = _mm512_set1_ps(beta);
+#endif
+
+ __m512i M512_EPI32_1 = _mm512_set1_epi32(1);
+ __m512i idx_base_0 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
+ __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_1);
+ __m512i idx_base_remix = _mm512_inserti32x8(idx_base_0, _mm512_castsi512_si256(idx_base_1), 0x1);
+
+ unsigned char x_load_mask_value = (((unsigned char)0xf) >> 2);
+ __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
+ __m128i xTmp = _mm_maskz_loadu_epi32(x_load_mask, x); // |x0|x1|x2|x3|0|0|0|0|
+ xArray_01 = _mm512_broadcastd_epi32(xTmp); // |x0|x1|x0|x1|...|x0|x1|
+ xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(xTmp, 0x1)); // |x2|x3|x2|x3|...|x2|x3|
+ unsigned short blend_mask_value = ((unsigned short)0xff00);
+ __mmask16 blend_mask = *((__mmask16*) &blend_mask_value);
+ xArray_remix = _mm512_mask_blend_epi32(blend_mask, xArray_01, xArray_23); // |x0|x1|x0|x1|x0|x1|x0|x1|...|x2|x3|x2|x3|x2|x3|x2|x3|
+
+ if (tag_m_16x > 0) {
+ for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
+ result = _mm512_setzero_ps();
+
+ matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)*4]); // Load 8 rows with n=4
+ matrixArray_1 = _mm512_loadu_si512(&a[(idx_m+8)*4]); // Load 8 rows with n=4
+
+ matrixArray_2 = _mm512_permutex2var_epi32(matrixArray_0, idx_base_0, matrixArray_1); // |a0|a1|...|h0|h1|i0|i1|...|p0|p1|
+ matrixArray_3 = _mm512_permutex2var_epi32(matrixArray_0, idx_base_1, matrixArray_1); // |a2|a3|...|h2|h3|i2|i3|...|p2|p3|
+
+ result = _mm512_dpbf16_ps(result, (__m512bh) matrixArray_2, (__m512bh) xArray_01);
+ result = _mm512_dpbf16_ps(result, (__m512bh) matrixArray_3, (__m512bh) xArray_23);
+
+ STORE16_COMPLETE_RESULT(result, y+idx_m)
+ }
+ }
+
+ if (m - tag_m_16x > 7) {
+ result = _mm512_setzero_ps();
+
+ matrixArray_0 = _mm512_loadu_si512(&a[(tag_m_16x)*4]); // Load 8 rows with n=4
+ matrixArray_2 = _mm512_permutexvar_epi32(idx_base_remix, matrixArray_0); // a0|a1|...|h0|h1|a2|a3|...|h2|h3|
+
+ result = _mm512_dpbf16_ps(result, (__m512bh) matrixArray_2, (__m512bh) xArray_remix);
+ __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(result), _mm512_extractf32x8_ps(result, 1));
+
+ STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
+ tag_m_16x += 8;
+ }
+
+ BLASLONG tail_num = m-tag_m_16x;
+ if (tail_num != 0) {
+ result = _mm512_setzero_ps();
+
+ unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-tail_num*2));
+ __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
+ matrixArray_0 = _mm512_maskz_loadu_epi32(tail_mask, &a[(tag_m_16x)*4]); // Load 8 rows with n=4
+ matrixArray_2 = _mm512_permutexvar_epi32(idx_base_remix, matrixArray_0); // a0|a1|...|h0|h1|a2|a3|...|h2|h3|
+
+ result = _mm512_dpbf16_ps(result, (__m512bh) matrixArray_2, (__m512bh) xArray_remix);
+ __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(result), _mm512_extractf32x8_ps(result, 1));
+
+ unsigned char store_tail_mask_value = (((unsigned char)0xff) >> (8-tail_num));
+ __mmask8 store_tail_mask = *((__mmask8*) &store_tail_mask_value);
+ STORE8_MASK_COMPLETE_RESULT(result256, y+tag_m_16x, store_tail_mask)
+ }
+
+ return 0;
+}
+
+// 30 rows parallel processing BF16 GEMV kernel for n=5 && lda ineffective scenario
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+static int sbgemv_kernel_30x5_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#else
+static int sbgemv_kernel_30x5_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#endif
+#else
+#ifndef ONE_ALPHA
+static int sbgemv_kernel_30x5_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#else
+static int sbgemv_kernel_30x5(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#endif
+#endif
+{
+ BLASLONG tag_m_30x = m - (m%30);
+
+ unsigned char x_load_mask_value = (((unsigned char)0xff) >> 3);
+ __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
+ __m128i x128 = _mm_maskz_loadu_epi16(x_load_mask, x); // x0|x1|x2|x3|x4|0|0|0|
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+#ifndef ZERO_BETA
+ __m512 BETAVECTOR = _mm512_set1_ps(beta);
+#endif
+
+ __m512 result_0, result_1;
+ __m512i xArray_01 = _mm512_broadcastd_epi32(x128); // x0|x1|x0|x1|...|x0|x1|
+ __m512i xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x1)); // x2|x3|x2|x3|...|x2|x3|
+ __m512i xArray_4 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x2)); // x4| 0|x4| 0|...|x4| 0|
+
+ __m512i M512_EPI16_2 = _mm512_set1_epi16(2);
+ __m512i load_idx01_stage1_1st = _mm512_set_epi16( 0, 0, 0, 0, 0, 0, 0, 0, 58, 57, 53, 52, 48, 47, 43, 42,
+ 38, 37, 33, 32, 26, 25, 21, 20, 16, 15, 11, 10, 6, 5, 1, 0);
+ __m512i load_idx01_stage1_2nd = _mm512_shuffle_i32x4(load_idx01_stage1_1st, load_idx01_stage1_1st, 0x39);
+ __m512i load_idx01_stage1_3rd = _mm512_shuffle_i32x4(load_idx01_stage1_1st, load_idx01_stage1_1st, 0x4f);
+
+ __m512i load_idx23_stage1_1st = _mm512_add_epi16(load_idx01_stage1_1st, M512_EPI16_2);
+ __m512i load_idx23_stage1_2nd = _mm512_add_epi16(load_idx01_stage1_2nd, M512_EPI16_2);
+ __m512i load_idx23_stage1_3rd = _mm512_add_epi16(load_idx01_stage1_3rd, M512_EPI16_2);
+
+ __m512i load_idx4_stage1_1st = _mm512_add_epi16(load_idx23_stage1_1st, M512_EPI16_2);
+ __m512i load_idx4_stage1_2nd = _mm512_add_epi16(load_idx23_stage1_2nd, M512_EPI16_2);
+ __m512i load_idx4_stage1_3rd = _mm512_add_epi16(load_idx23_stage1_3rd, M512_EPI16_2);
+
+ __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4;
+ __m512i matrixArray_stage1_0, matrixArray_stage1_1, matrixArray_stage1_2;
+ __m512i matrixArray_stage2_0, matrixArray_stage2_1;
+
+ unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 2);
+ __mmask32 load_mask = *((__mmask32*) &load_mask_value);
+ unsigned short store_mask_value = (((unsigned short)0xffff) >> 2);
+ __mmask16 store_mask = *((__mmask16*) &store_mask_value);
+
+ if (tag_m_30x > 0) {
+ unsigned short blend_mask_value_0 = ((unsigned short)0xf000);
+ __mmask16 blend_mask_0 = *((__mmask16*) &blend_mask_value_0);
+ unsigned short blend_mask_value_1 = ((unsigned short)0x3f00);
+ __mmask16 blend_mask_1 = *((__mmask16*) &blend_mask_value_1);
+ for (BLASLONG idx_m = 0; idx_m < tag_m_30x; idx_m+=30) {
+ result_0 = _mm512_setzero_ps();
+ result_1 = _mm512_setzero_ps();
+
+ matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m)*5]); // Load 6 rows with n=5
+ matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[((idx_m+6)*5)]); // Load 6 rows with n=5
+ matrixArray_2 = _mm512_maskz_loadu_epi16(load_mask, &a[((idx_m+12)*5)]); // Load 6 rows with n=5
+ matrixArray_3 = _mm512_maskz_loadu_epi16(load_mask, &a[((idx_m+18)*5)]); // Load 6 rows with n=5
+ matrixArray_4 = _mm512_maskz_loadu_epi16(load_mask, &a[((idx_m+24)*5)]); // Load 6 rows with n=5
+
+ // Process the 0|1 elements
+ // Stage 1: Select the 0|1 elements for each row
+ matrixArray_stage1_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx01_stage1_1st, matrixArray_1);
+ matrixArray_stage1_1 = _mm512_permutex2var_epi16(matrixArray_2, load_idx01_stage1_2nd, matrixArray_3);
+ matrixArray_stage1_2 = _mm512_permutexvar_epi16(load_idx01_stage1_3rd, matrixArray_4);
+ // Stage 2: Reorder and compress all the 0|1 elements
+ matrixArray_stage2_0 = _mm512_mask_blend_epi32(blend_mask_0, matrixArray_stage1_0, matrixArray_stage1_1);
+ matrixArray_stage2_1 = _mm512_mask_blend_epi32(blend_mask_1, matrixArray_stage1_1, matrixArray_stage1_2);
+ // Calculate the result of the 0|1 elements
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage2_0, (__m512bh) xArray_01);
+ result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage2_1, (__m512bh) xArray_01);
+
+ // Process the 2|3 elements
+ // Stage 1: Select the 2|3 elements for each row
+ matrixArray_stage1_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx23_stage1_1st, matrixArray_1);
+ matrixArray_stage1_1 = _mm512_permutex2var_epi16(matrixArray_2, load_idx23_stage1_2nd, matrixArray_3);
+ matrixArray_stage1_2 = _mm512_permutexvar_epi16(load_idx23_stage1_3rd, matrixArray_4);
+ // Stage 2: Reorder and compress all the 2|3 elements
+ matrixArray_stage2_0 = _mm512_mask_blend_epi32(blend_mask_0, matrixArray_stage1_0, matrixArray_stage1_1);
+ matrixArray_stage2_1 = _mm512_mask_blend_epi32(blend_mask_1, matrixArray_stage1_1, matrixArray_stage1_2);
+ // Calculate the result of the 2|3 elements and accumulate the result of 0|1 elements
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage2_0, (__m512bh) xArray_23);
+ result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage2_1, (__m512bh) xArray_23);
+
+ // Process the for 4 elements
+ // Stage 1: Select the 4 elements for each row
+ matrixArray_stage1_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx4_stage1_1st, matrixArray_1);
+ matrixArray_stage1_1 = _mm512_permutex2var_epi16(matrixArray_2, load_idx4_stage1_2nd, matrixArray_3);
+ matrixArray_stage1_2 = _mm512_permutexvar_epi16(load_idx4_stage1_3rd, matrixArray_4);
+ // Stage 2: Reorder and compress all the 4 elements
+ matrixArray_stage2_0 = _mm512_mask_blend_epi32(blend_mask_0, matrixArray_stage1_0, matrixArray_stage1_1);
+ matrixArray_stage2_1 = _mm512_mask_blend_epi32(blend_mask_1, matrixArray_stage1_1, matrixArray_stage1_2);
+ // Calculate the result of the 4 element and accumulate the result of 0|1 and 2|3 elements
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage2_0, (__m512bh) xArray_4);
+ result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage2_1, (__m512bh) xArray_4);
+
+ STORE16_COMPLETE_RESULT(result_0, y+idx_m)
+ STORE16_MASK_COMPLETE_RESULT(result_1, y+idx_m+16, store_mask)
+ }
+ }
+
+ if (m - tag_m_30x > 11) {
+ BLASLONG tag_m_12x = m - ((m-tag_m_30x)%12);
+ for (BLASLONG idx_m = tag_m_30x; idx_m < tag_m_12x; idx_m+=12) {
+ unsigned short store_less_mask_value = (((unsigned short)0xffff) >> 4);
+ __mmask16 store_less_mask = *((__mmask16*) &store_less_mask_value);
+ result_0 = _mm512_setzero_ps();
+
+ matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m)*5]); // Load 6 rows with n=5
+ matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[((idx_m+6)*5)]); // Load 6 rows with n=5
+
+ // Interleave the elements
+ matrixArray_stage1_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx01_stage1_1st, matrixArray_1);
+ matrixArray_stage1_1 = _mm512_permutex2var_epi16(matrixArray_0, load_idx23_stage1_1st, matrixArray_1);
+ matrixArray_stage1_2 = _mm512_permutex2var_epi16(matrixArray_0, load_idx4_stage1_1st, matrixArray_1);
+ // Calculate and accumulate the result
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_0, (__m512bh) xArray_01);
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_1, (__m512bh) xArray_23);
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_2, (__m512bh) xArray_4);
+
+ STORE16_MASK_COMPLETE_RESULT(result_0, y+idx_m, store_less_mask)
+ tag_m_30x += 12;
+ }
+ }
+
+ BLASLONG tail_num = m - tag_m_30x;
+ if (tail_num > 6) {
+ unsigned short store_less_mask_value = (((unsigned short)0xffff) >> (4+(12-tail_num)));
+ __mmask16 store_less_mask = *((__mmask16*) &store_less_mask_value);
+ unsigned int load_less_mask_value = (((unsigned int)0xffffffff) >> (2+(12-tail_num)*5));
+ __mmask32 load_less_mask = *((__mmask32*) &load_less_mask_value);
+ result_0 = _mm512_setzero_ps();
+
+ matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(tag_m_30x)*5]); // Load 6 rows with n=5
+ matrixArray_1 = _mm512_maskz_loadu_epi16(load_less_mask, &a[((tag_m_30x+6)*5)]); // Load x rows with n=5
+
+ // Interleave the elements
+ matrixArray_stage1_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx01_stage1_1st, matrixArray_1);
+ matrixArray_stage1_1 = _mm512_permutex2var_epi16(matrixArray_0, load_idx23_stage1_1st, matrixArray_1);
+ matrixArray_stage1_2 = _mm512_permutex2var_epi16(matrixArray_0, load_idx4_stage1_1st, matrixArray_1);
+ // Calculate and accumulate the result
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_0, (__m512bh) xArray_01);
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_1, (__m512bh) xArray_23);
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_2, (__m512bh) xArray_4);
+
+ STORE16_MASK_COMPLETE_RESULT(result_0, y+tag_m_30x, store_less_mask)
+ } else {
+ __m128i matrixArray128;
+ __m128 result128, tmp128;
+ for (BLASLONG i = tag_m_30x; i < m; i++) {
+ result128 = _mm_setzero_ps();
+ matrixArray128 = _mm_maskz_loadu_epi16(x_load_mask, &a[(i)*5]); // Load 1 rows with n=5
+ result128 = _mm_dpbf16_ps(result128, (__m128bh) matrixArray128, (__m128bh) x128);
+ tmp128 = _mm_shuffle_ps(result128, result128, 14);
+ result128 = _mm_add_ps(result128, tmp128);
+ tmp128 = _mm_shuffle_ps(result128, result128, 1);
+ result128 = _mm_add_ps(result128, tmp128);
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+ y[i] = alpha * result128[0] + beta * y[i];
+#else
+ y[i] = alpha * result128[0] + y[i];
+#endif
+#else
+#ifndef ONE_ALPHA
+ y[i] = result128[0] * alpha;
+#else
+ y[i] = result128[0];
+#endif
+#endif
+
+ }
+ }
+
+ return 0;
+}
+
+// 16 rows parallel processing BF16 GEMV kernel for n=6 && lda ineffective scenario
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+static int sbgemv_kernel_16x6_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#else
+static int sbgemv_kernel_16x6_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#endif
+#else
+#ifndef ONE_ALPHA
+static int sbgemv_kernel_16x6_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#else
+static int sbgemv_kernel_16x6(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#endif
+#endif
+{
+ BLASLONG tag_m_16x = m & (~15);
+
+ unsigned char x_load_mask_value = (((unsigned char)0xff) >> 2);
+ __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
+ __m128i x128 = _mm_maskz_loadu_epi16(x_load_mask, x); // x0|x1|x2|x3|x4|x5|0|0|
+
+ if (tag_m_16x > 0) {
+ __m512 result_0;
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+#ifndef ZERO_BETA
+ __m512 BETAVECTOR = _mm512_set1_ps(beta);
+#endif
+
+ __m512i M512_EPI32_1 = _mm512_set1_epi32(1);
+ __m512i load_idx01_1st = _mm512_set_epi32( 0, 0, 0, 0, 0, 30, 27, 24, 21, 18, 15, 12, 9, 6, 3, 0);
+ __m512i load_idx01_2nd = _mm512_set_epi32(13, 10, 7, 4, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
+
+ __m512i load_idx23_1st = _mm512_add_epi32(load_idx01_1st, M512_EPI32_1);
+ __m512i load_idx23_2nd = _mm512_add_epi32(load_idx01_2nd, M512_EPI32_1);
+
+ __m512i load_idx45_1st = _mm512_add_epi32(load_idx23_1st, M512_EPI32_1);
+ __m512i load_idx45_2nd = _mm512_add_epi32(load_idx23_2nd, M512_EPI32_1);
+
+ unsigned short blend_mask_value = ((unsigned short)0x0400);
+ __mmask16 blend_mask = *((__mmask16*) &blend_mask_value);
+ // Set the 11th element to be 0 as invalid index for a 512 bit epi32 register
+ load_idx45_1st = _mm512_mask_blend_epi32(blend_mask, load_idx45_1st, load_idx01_2nd);
+ // Set the 11th element to be 0 as 0 is the correct index
+ load_idx45_2nd = _mm512_mask_blend_epi32(blend_mask, load_idx45_2nd, load_idx01_2nd);
+
+ __m512i xArray_01 = _mm512_broadcastd_epi32(x128); // x0|x1|x0|x1|...|x0|x1|
+ __m512i xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x1)); // x2|x3|x2|x3|...|x2|x3|
+ __m512i xArray_45 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x2)); // x4|x5|x4|x5|...|x4|x5|
+
+ unsigned short permute_mask01_uint = (((unsigned short)0xf800));
+ __mmask16 permute_mask01 = *((__mmask16*) &permute_mask01_uint);
+ unsigned short permute_mask45_uint = (((unsigned short)0xfc00));
+ __mmask16 permute_mask45 = *((__mmask16*) &permute_mask45_uint);
+
+ __m512i matrixArray_0, matrixArray_1, matrixArray_2;
+ __m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2;
+ for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
+ result_0 = _mm512_setzero_ps();
+
+ matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)*6]); // Load 5 rows with n=6 plus 2 element
+ matrixArray_1 = _mm512_loadu_si512(&a[((idx_m+5)*6 + 2)]); // Load 5 rows with n=6 plus 2 element
+ matrixArray_2 = _mm512_loadu_si512(&a[((idx_m+10)*6 + 4)]); // Load 5 rows with n=6 plus 2 element
+
+ // Stage 1: interleave for the a..k elements
+ matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx01_1st, matrixArray_1);
+ matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx23_1st, matrixArray_1);
+ matrixArray_stage_2 = _mm512_permutex2var_epi32(matrixArray_0, load_idx45_1st, matrixArray_1);
+
+ // Stage 2: interleave for the l..p elements and remix together
+ matrixArray_stage_0 = _mm512_mask_permutexvar_epi32(matrixArray_stage_0, permute_mask01, load_idx01_2nd, matrixArray_2);
+ matrixArray_stage_1 = _mm512_mask_permutexvar_epi32(matrixArray_stage_1, permute_mask01, load_idx23_2nd, matrixArray_2);
+ matrixArray_stage_2 = _mm512_mask_permutexvar_epi32(matrixArray_stage_2, permute_mask45, load_idx45_2nd, matrixArray_2);
+
+ // Calculate the result of the 0|1 elements
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_01);
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_23);
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_2, (__m512bh) xArray_45);
+
+ STORE16_COMPLETE_RESULT(result_0, y+idx_m)
+ }
+
+ if (m - tag_m_16x > 7) {
+ __m256i M256_EPI32_1 = _mm512_castsi512_si256(M512_EPI32_1);
+ __m256i load_idx01_1st = _mm256_set_epi32( 0, 0, 15, 12, 9, 6, 3, 0);
+ __m256i load_idx01_2nd = _mm256_set_epi32( 5, 2, 0, 0, 0, 0, 0, 0);
+
+ __m256i load_idx23_1st = _mm256_add_epi32(load_idx01_1st, M256_EPI32_1);
+ __m256i load_idx23_2nd = _mm256_add_epi32(load_idx01_2nd, M256_EPI32_1);
+ unsigned char blend_mask_value = ((unsigned char)0x20);
+ __mmask8 blend_mask = *((__mmask8*) &blend_mask_value);
+ // Set the 6th element to be 0 as invalid index for a 512 bit epi32 register
+ load_idx23_1st = _mm256_mask_blend_epi32(blend_mask, load_idx23_1st, load_idx01_2nd);
+ // Set the 6th element to be 0 as 0 is the correct index
+ load_idx23_2nd = _mm256_mask_blend_epi32(blend_mask, load_idx23_2nd, load_idx01_2nd);
+
+ __m256i load_idx45_1st = _mm256_add_epi32(load_idx23_1st, M256_EPI32_1);
+ __m256i load_idx45_2nd = _mm256_add_epi32(load_idx23_2nd, M256_EPI32_1);
+
+ unsigned char permute_mask01_uint = (((unsigned char)0xc0));
+ __mmask8 permute_mask01 = *((__mmask8*) &permute_mask01_uint);
+ unsigned char permute_mask45_uint = (((unsigned char)0xe0));
+ __mmask8 permute_mask45 = *((__mmask8*) &permute_mask45_uint);
+
+ __m256i matrixArray_0, matrixArray_1, matrixArray_2;
+ __m256i matrixArray_stage_0;
+ __m256 result256_0;
+
+ result256_0 = _mm256_setzero_ps();
+
+ matrixArray_0 = _mm256_loadu_si256(&a[(tag_m_16x)*6]); // Load 2 rows with n=6 plus 4 element
+ matrixArray_1 = _mm256_loadu_si256(&a[((tag_m_16x+2)*6 + 4)]); // Load 2 rows with n=6 plus 4 element
+ matrixArray_2 = _mm256_loadu_si256(&a[((tag_m_16x+5)*6 + 2)]); // Load 2 rows with n=6 plus 4 element
+
+ // Process the 0|1 elements
+ // Select the 0|1 elements for each row
+ matrixArray_stage_0 = _mm256_permutex2var_epi32(matrixArray_0, load_idx01_1st, matrixArray_1);
+ matrixArray_stage_0 = _mm256_mask_permutexvar_epi32(matrixArray_stage_0, permute_mask01, load_idx01_2nd, matrixArray_2);
+ // Calculate the result of the 0|1 elements
+ result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray_stage_0, (__m256bh) _mm512_castsi512_si256(xArray_01));
+
+ // Process the 2|3 elements
+ // Select the 2|3 elements for each row
+ matrixArray_stage_0 = _mm256_permutex2var_epi32(matrixArray_0, load_idx23_1st, matrixArray_1);
+ matrixArray_stage_0 = _mm256_mask_permutexvar_epi32(matrixArray_stage_0, permute_mask45, load_idx23_2nd, matrixArray_2);
+ // Calculate the result of the 0|1 elements
+ result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray_stage_0, (__m256bh) _mm512_castsi512_si256(xArray_23));
+
+ // Process the for 4 elements
+ // Select the 4|5 elements for each row
+ matrixArray_stage_0 = _mm256_permutex2var_epi32(matrixArray_0, load_idx45_1st, matrixArray_1);
+ matrixArray_stage_0 = _mm256_mask_permutexvar_epi32(matrixArray_stage_0, permute_mask45, load_idx45_2nd, matrixArray_2);
+ // Calculate the result of the 0|1 elements
+ result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray_stage_0, (__m256bh) _mm512_castsi512_si256(xArray_45));
+
+ STORE8_COMPLETE_RESULT(result256_0, y+tag_m_16x)
+ tag_m_16x += 8;
+ }
+ }
+
+ if (tag_m_16x != m) {
+ __m128i matrixArray128;
+ __m128 result128, tmp128;
+ for (BLASLONG i = tag_m_16x; i < m; i++) {
+ result128 = _mm_setzero_ps();
+ matrixArray128 = _mm_maskz_loadu_epi16(x_load_mask, &a[(i)*6]); // Load 1 rows with n=6
+ result128 = _mm_dpbf16_ps(result128, (__m128bh) matrixArray128, (__m128bh) x128);
+ tmp128 = _mm_shuffle_ps(result128, result128, 14);
+ result128 = _mm_add_ps(result128, tmp128);
+ tmp128 = _mm_shuffle_ps(result128, result128, 1);
+ result128 = _mm_add_ps(result128, tmp128);
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+ y[i] = alpha * result128[0] + beta * y[i];
+#else
+ y[i] = alpha * result128[0] + y[i];
+#endif
+#else
+#ifndef ONE_ALPHA
+ y[i] = result128[0] * alpha;
+#else
+ y[i] = result128[0];
+#endif
+#endif
+ }
+ }
+
+ return 0;
+}
+
+// 16 rows parallel processing BF16 GEMV kernel for n=7 && lda ineffective scenario
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+static int sbgemv_kernel_16x7_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#else
+static int sbgemv_kernel_16x7_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#endif
+#else
+#ifndef ONE_ALPHA
+static int sbgemv_kernel_16x7_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#else
+static int sbgemv_kernel_16x7(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#endif
+#endif
+{
+ BLASLONG tag_m_16x = m & (~15);
+
+ unsigned char x_load_mask_value = (((unsigned char)0xff) >> 1);
+ __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
+ __m128i x128 = _mm_maskz_loadu_epi16(x_load_mask, x); // |x0|x1|x2|x3|x4|x5|x6|0|
+
+ if (tag_m_16x > 0) {
+ __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3;
+ __m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2, matrixArray_stage_3;
+ __m512i xArray_0123, xArray_4567;
+ __m512 result_0, result_1, result_2, result_3;
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+#ifndef ZERO_BETA
+ __m512 BETAVECTOR = _mm512_set1_ps(beta);
+#endif
+
+ __m512i M512_EPI32_2 = _mm512_set1_epi32(2);
+ __m512i load_idx_stage1_0 = _mm512_set_epi16(31, 27, 26, 25, 24, 23, 22, 21, 31, 20, 19, 18, 17, 16, 15, 14,
+ 31, 13, 12, 11, 10, 9, 8, 7, 31, 6, 5, 4, 3, 2, 1, 0);
+ __m512i load_idx_stage2_0 = _mm512_set_epi32(29, 25, 21, 17, 13, 9, 5, 1, 28, 24, 20, 16, 12, 8, 4, 0);
+ __m512i load_idx_stage2_1 = _mm512_add_epi32(load_idx_stage2_0, M512_EPI32_2);
+
+ unsigned short x_blend_mask_value = ((unsigned short)0xff00);
+ __mmask16 x_blend_mask = *((__mmask16*) &x_blend_mask_value);
+ xArray_0123 = _mm512_mask_blend_epi32(x_blend_mask, _mm512_broadcastd_epi32(x128), \
+ _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x1)));
+ xArray_4567 = _mm512_mask_blend_epi32(x_blend_mask, _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x2)), \
+ _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x3)));
+
+ unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 4);
+ __mmask32 load_mask = *((__mmask32*) &load_mask_value);
+ for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
+ result_0 = _mm512_setzero_ps();
+ result_1 = _mm512_setzero_ps();
+
+ matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m)*7]); // Load 4 rows with n=7
+ matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+4)*7]); // Load 4 rows with n=7
+ matrixArray_2 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+8)*7]); // Load 4 rows with n=7
+ matrixArray_3 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+12)*7]); // Load 4 rows with n=7
+
+ // Stage 1: padding
+ matrixArray_0 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_0); // |a0|a1|a2|a3|...|b6|b7|c0|c1|c2|c3|...|d6|d7|
+ matrixArray_1 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_1); // |e0|e1|e2|e3|...|f6|f7|g0|g1|g2|g3|...|h6|h7|
+ matrixArray_2 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_2); // |i0|i1|i2|i3|...|j6|j7|k0|k1|k2|k3|...|l6|l7|
+ matrixArray_3 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_3); // |m0|m1|m2|m3|...|n6|n7|o0|o1|o2|o3|...|p6|p7|
+
+ // Stage 2: interleave per 32 bits
+ matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1); // |a0|a1|...|h0|h1|a2|a3|...|h2|h3|
+ matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1); // |a4|a5|...|h4|h5|a6|a7|...|h6|h7|
+ matrixArray_stage_2 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage2_0, matrixArray_3); // |i0|i1|...|p0|p1|i2|i3|...|p2|p3|
+ matrixArray_stage_3 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage2_1, matrixArray_3); // |i4|i5|...|p4|p5|i6|i7|...|p6|p7|
+
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
+ result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage_2, (__m512bh) xArray_0123);
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);
+ result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage_3, (__m512bh) xArray_4567);
+
+ // Stage 3: interleave per 256 bits
+ result_2 = _mm512_shuffle_f32x4(result_0, result_1, 0x44);
+ result_3 = _mm512_shuffle_f32x4(result_0, result_1, 0xee);
+
+ result_2 = _mm512_add_ps(result_2, result_3);
+
+ STORE16_COMPLETE_RESULT(result_2, y+idx_m)
+ }
+
+ if (m - tag_m_16x > 7) {
+ result_0 = _mm512_setzero_ps();
+
+ matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(tag_m_16x)*7]); // Load 4 rows with n=7
+ matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[(tag_m_16x+4)*7]); // Load 4 rows with n=7
+
+ // Stage 1: padding
+ matrixArray_0 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_0); // |a0|a1|a2|a3|...|b6|b7|c0|c1|c2|c3|...|d6|d7|
+ matrixArray_1 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_1); // |e0|e1|e2|e3|...|f6|f7|g0|g1|g2|g3|...|h6|h7|
+
+ // Stage 2: interleave per 32 bits
+ matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1); // |a0|a1|b0|b1|...|h0|h1|a2|a3|b2|b3|...|h2|h3|
+ matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1); // |a4|a5|b4|b5|...|h4|h5|a6|a7|b6|b7|...|h6|h7|
+
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);
+
+ __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(result_0), _mm512_extractf32x8_ps(result_0, 0x1));
+
+ STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
+
+ tag_m_16x += 8;
+ }
+
+ BLASLONG tail_num = m - tag_m_16x;
+ if (tail_num > 3) {
+ result_0 = _mm512_setzero_ps();
+
+ matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(tag_m_16x)*7]); // Load 4 rows with n=7
+ unsigned int tail_load_mask_value = (((unsigned int)0xffffffff) >> (4+(8-tail_num)*7));
+ __mmask32 tail_load_mask = *((__mmask32*) &tail_load_mask_value);
+ matrixArray_1 = _mm512_maskz_loadu_epi16(tail_load_mask, &a[(tag_m_16x+4)*7]); // Load 4 rows with n=7
+
+ // Stage 1: padding
+ matrixArray_0 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_0); // |a0|a1|a2|a3|...|b6|b7|c0|c1|c2|c3|...|d6|d7|
+ matrixArray_1 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_1); // |e0|e1|e2|e3|...|f6|f7|g0|g1|g2|g3|...|h6|h7|
+
+ // Stage 2: interleave per 32 bits
+ matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1); // |a0|a1|b0|b1|...|h0|h1|a2|a3|b2|b3|...|h2|h3|
+ matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1); // |a4|a5|b4|b5|...|h4|h5|a6|a7|b6|b7|...|h6|h7|
+
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);
+
+ __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(result_0), _mm512_extractf32x8_ps(result_0, 0x1));
+
+ unsigned char tail_mask_value = (((unsigned char)0xff) >> (8-tail_num));
+ __mmask8 tail_mask = *((__mmask8*) &tail_mask_value);
+ STORE8_MASK_COMPLETE_RESULT(result256, y+tag_m_16x, tail_mask)
+ tag_m_16x = m;
+ }
+ }
+
+ if (tag_m_16x != m) {
+ __m128i matrixArray128;
+ __m128 result128, tmp128;
+ for (BLASLONG i = tag_m_16x; i < m; i++) {
+ result128 = _mm_setzero_ps();
+ matrixArray128 = _mm_maskz_loadu_epi16(x_load_mask, &a[(i)*7]); // Load 1 rows with n=7
+ result128 = _mm_dpbf16_ps(result128, (__m128bh) matrixArray128, (__m128bh) x128);
+ tmp128 = _mm_shuffle_ps(result128, result128, 14);
+ result128 = _mm_add_ps(result128, tmp128);
+ tmp128 = _mm_shuffle_ps(result128, result128, 1);
+ result128 = _mm_add_ps(result128, tmp128);
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+ y[i] = alpha * result128[0] + beta * y[i];
+#else
+ y[i] = alpha * result128[0] + y[i];
+#endif
+#else
+#ifndef ONE_ALPHA
+ y[i] = result128[0] * alpha;
+#else
+ y[i] = result128[0];
+#endif
+#endif
+ }
+ }
+
+ return 0;
+}
+
+// 16 rows parallel processing BF16 GEMV kernel for n=8 && lda ineffective scenario
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+static int sbgemv_kernel_16x8_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#else
+static int sbgemv_kernel_16x8_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#endif
+#else
+#ifndef ONE_ALPHA
+static int sbgemv_kernel_16x8_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#else
+static int sbgemv_kernel_16x8(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#endif
+#endif
+{
+ BLASLONG tag_m_16x = m & (~15);
+
+ __m128i x128 = _mm_loadu_si128(x); // |x0|x1|x2|x3|x4|x5|x6|x7|
+
+ if (tag_m_16x > 0) {
+ __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3;
+ __m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2, matrixArray_stage_3;
+ __m512i xArray_0123, xArray_4567;
+ __m512 result_0, result_1, result_2, result_3;
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+#ifndef ZERO_BETA
+ __m512 BETAVECTOR = _mm512_set1_ps(beta);
+#endif
+
+ __m512i M512_EPI32_2 = _mm512_set1_epi32(2);
+ __m512i load_idx_stage2_0 = _mm512_set_epi32(29, 25, 21, 17, 13, 9, 5, 1, 28, 24, 20, 16, 12, 8, 4, 0);
+ __m512i load_idx_stage2_1 = _mm512_add_epi32(load_idx_stage2_0, M512_EPI32_2);
+
+ unsigned short x_blend_mask_value = ((unsigned short)0xff00);
+ __mmask16 x_blend_mask = *((__mmask16*) &x_blend_mask_value);
+ xArray_0123 = _mm512_mask_blend_epi32(x_blend_mask, _mm512_broadcastd_epi32(x128), \
+ _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x1)));
+ xArray_4567 = _mm512_mask_blend_epi32(x_blend_mask, _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x2)), \
+ _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x3)));
+
+ for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
+ result_0 = _mm512_setzero_ps();
+ result_1 = _mm512_setzero_ps();
+
+ matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)*8]); // Load 4 rows with n=8
+ matrixArray_1 = _mm512_loadu_si512(&a[(idx_m+4)*8]); // Load 4 rows with n=8
+ matrixArray_2 = _mm512_loadu_si512(&a[(idx_m+8)*8]); // Load 4 rows with n=8
+ matrixArray_3 = _mm512_loadu_si512(&a[(idx_m+12)*8]); // Load 4 rows with n=8
+
+ // Stage 1: interleave per 32 bits
+ matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1); // |a0|a1|...|h0|h1|a2|a3|...|h2|h3|
+ matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1); // |a4|a5|...|h4|h5|a6|a7|...|h6|h7|
+ matrixArray_stage_2 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage2_0, matrixArray_3); // |i0|i1|...|p0|p1|i2|i3|...|p2|p3|
+ matrixArray_stage_3 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage2_1, matrixArray_3); // |i4|i5|...|p4|p5|i6|i7|...|p6|p7|
+
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
+ result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage_2, (__m512bh) xArray_0123);
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);
+ result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage_3, (__m512bh) xArray_4567);
+
+ // Stage 2: interleave per 256 bits
+ result_2 = _mm512_shuffle_f32x4(result_0, result_1, 0x44);
+ result_3 = _mm512_shuffle_f32x4(result_0, result_1, 0xee);
+
+ result_2 = _mm512_add_ps(result_2, result_3);
+
+ STORE16_COMPLETE_RESULT(result_2, y+idx_m)
+ }
+
+ if (m - tag_m_16x > 7) {
+ result_0 = _mm512_setzero_ps();
+
+ matrixArray_0 = _mm512_loadu_si512(&a[(tag_m_16x)*8]); // Load 4 rows with n=8
+ matrixArray_1 = _mm512_loadu_si512(&a[(tag_m_16x+4)*8]); // Load 4 rows with n=8
+
+ // Stage 1: interleave per 32 bits
+ matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1); // |a0|a1|b0|b1|...|h0|h1|a2|a3|b2|b3|...|h2|h3|
+ matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1); // |a4|a5|b4|b5|...|h4|h5|a6|a7|b6|b7|...|h6|h7|
+
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);
+
+ __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(result_0), _mm512_extractf32x8_ps(result_0, 0x1));
+
+ STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
+ tag_m_16x += 8;
+ }
+
+ BLASLONG tail_num = m - tag_m_16x;
+ if (tail_num > 3) {
+ result_0 = _mm512_setzero_ps();
+
+ matrixArray_0 = _mm512_loadu_si512(&a[(tag_m_16x)*8]); // Load 4 rows with n=8
+ unsigned short tail_load_mask_value = (((unsigned int)0xffff) >> ((8-tail_num)*4));
+ __mmask16 tail_load_mask = *((__mmask16*) &tail_load_mask_value);
+ matrixArray_1 = _mm512_maskz_loadu_epi32(tail_load_mask, &a[(tag_m_16x+4)*8]); // Load 4 rows with n=8
+
+ // Stage 1: interleave per 32 bits
+ matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1); // |a0|a1|b0|b1|...|h0|h1|a2|a3|b2|b3|...|h2|h3|
+ matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1); // |a4|a5|b4|b5|...|h4|h5|a6|a7|b6|b7|...|h6|h7|
+
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);
+
+ __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(result_0), _mm512_extractf32x8_ps(result_0, 0x1));
+
+ unsigned char tail_mask_value = (((unsigned char)0xff) >> (8-tail_num));
+ __mmask8 tail_mask = *((__mmask8*) &tail_mask_value);
+ STORE8_MASK_COMPLETE_RESULT(result256, y+tag_m_16x, tail_mask)
+ tag_m_16x = m;
+ }
+ }
+
+ if (tag_m_16x != m) {
+ __m128i matrixArray128;
+ __m128 result128, tmp128;
+ for (BLASLONG i = tag_m_16x; i < m; i++) {
+ result128 = _mm_setzero_ps();
+ matrixArray128 = _mm_loadu_si128(&a[(i)*8]); // Load 1 rows with n=8
+ result128 = _mm_dpbf16_ps(result128, (__m128bh) matrixArray128, (__m128bh) x128);
+ tmp128 = _mm_shuffle_ps(result128, result128, 14);
+ result128 = _mm_add_ps(result128, tmp128);
+ tmp128 = _mm_shuffle_ps(result128, result128, 1);
+ result128 = _mm_add_ps(result128, tmp128);
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+ y[i] = alpha * result128[0] + beta * y[i];
+#else
+ y[i] = alpha * result128[0] + y[i];
+#endif
+#else
+#ifndef ONE_ALPHA
+ y[i] = result128[0] * alpha;
+#else
+ y[i] = result128[0];
+#endif
+#endif
+ }
+ }
+
+ return 0;
+}
+
+// 14 rows parallel processing BF16 GEMV kernel for n=9 && lda ineffective scenario
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+static int sbgemv_kernel_14x9_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#else
+static int sbgemv_kernel_14x9_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#endif
+#else
+#ifndef ONE_ALPHA
+static int sbgemv_kernel_14x9_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#else
+static int sbgemv_kernel_14x9(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#endif
+#endif
+{
+ BLASLONG tag_m_14x = m - (m%14);
+
+ unsigned char x_load_mask_value = (((unsigned char)0xff) >> 7);
+ __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
+ __m128i x128_0 = _mm_loadu_si128(x); // |x0|x1|x2|x3|x4|x5|x6|x7|
+ __m128i x128_1 = _mm_maskz_loadu_epi16(x_load_mask, (x+8)); // |x8|0 |0 | 0| 0| 0| 0| 0|
+
+ if (tag_m_14x > 0) {
+ __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5;
+ __m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2, matrixArray_stage_3;
+ __m512i xArray_01, xArray_23, xArray_45, xArray_67, xArray_89;
+ __m512 result_0, result_1;
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+#ifndef ZERO_BETA
+ __m512 BETAVECTOR = _mm512_set1_ps(beta);
+#endif
+
+ __m256i M256_EPI16_2 = _mm256_set1_epi16(2);
+ __m256i idx_base_0 = _mm256_set_epi16( 0, 0, 55, 54, 46, 45, 37, 36, 28, 27, 19, 18, 10, 9, 1, 0);
+ __m256i idx_base_1 = _mm256_add_epi16(idx_base_0, M256_EPI16_2);
+ __m256i idx_base_2 = _mm256_add_epi16(idx_base_1, M256_EPI16_2);
+ __m256i idx_base_3 = _mm256_add_epi16(idx_base_2, M256_EPI16_2);
+ __m256i idx_base_4 = _mm256_add_epi16(idx_base_3, M256_EPI16_2);
+ __m512i idx_idx = _mm512_set_epi32( 0, 0, 22, 21, 20, 19, 18, 17, 16, 6, 5, 4, 3, 2, 1, 0);
+
+ __m512i load_idx_stage1_0 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_0), idx_idx, _mm512_castsi256_si512(idx_base_1));
+ __m512i load_idx_stage1_1 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_2), idx_idx, _mm512_castsi256_si512(idx_base_3));
+ __m512i load_idx_stage1_2 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_1), idx_idx, _mm512_castsi256_si512(idx_base_0));
+ __m512i load_idx_stage1_3 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_3), idx_idx, _mm512_castsi256_si512(idx_base_2));
+ __m512i load_idx_stage1_4 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_4), idx_idx, _mm512_castsi256_si512(idx_base_4));
+ __m512i load_idx_stage2_0 = _mm512_set_epi32( 0, 0, 22, 21, 20, 19, 18, 17, 16, 13, 12, 11, 10, 9, 8, 7);
+
+ xArray_01 = _mm512_broadcastd_epi32(x128_0); // |x0|x1|x0|x1| ... |x0|x1|
+ xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x1)); // |x2|x3|x2|x3| ... |x2|x3|
+ xArray_45 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x2)); // |x4|x5|x4|x5| ... |x4|x5|
+ xArray_67 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x3)); // |x6|x7|x6|x7| ... |x6|x7|
+ xArray_89 = _mm512_broadcastd_epi32(x128_1); // |x8|0 |x8| 0| ... |x8| 0|
+
+ unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 1);
+ __mmask32 load_mask = *((__mmask32*) &load_mask_value);
+ unsigned short blend_mask_value = ((unsigned short)0x3f80);
+ __mmask16 blend_mask = *((__mmask16*) &blend_mask_value);
+ unsigned short store_mask_value = (((unsigned short)0xffff) >> 2);
+ __mmask16 store_mask = *((__mmask16*) &store_mask_value);
+ for (BLASLONG idx_m = 0; idx_m < tag_m_14x; idx_m+=14) {
+ result_0 = _mm512_setzero_ps();
+ result_1 = _mm512_setzero_ps();
+
+ matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)*9]); // Load 3 rows with n=9 plus 5 elements
+ matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+3)*9 + 5]); // Load 3 rows with n=9 plus 4 elements
+ matrixArray_2 = _mm512_loadu_si512(&a[(idx_m+7)*9]); // Load 3 rows with n=9 plus 5 elements
+ matrixArray_3 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+10)*9 + 5]); // Load 3 rows with n=9 plus 4 elements
+
+ // Stage 1: interleave per 16 bits
+ matrixArray_stage_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx_stage1_0, matrixArray_1); // |a0|a1|...|g0|g1|a2|a3|...|g2|g3|x|x|x|x|
+ matrixArray_stage_1 = _mm512_permutex2var_epi16(matrixArray_0, load_idx_stage1_1, matrixArray_1); // |a4|a5|...|g4|g5|a6|a7|...|g6|g7|x|x|x|x|
+ matrixArray_stage_2 = _mm512_permutex2var_epi16(matrixArray_2, load_idx_stage1_2, matrixArray_3); // |h2|h3|...|n2|n3|h0|h1|...|n0|n1|x|x|x|x|
+ matrixArray_stage_3 = _mm512_permutex2var_epi16(matrixArray_2, load_idx_stage1_3, matrixArray_3); // |h6|h7|...|n6|n7|h4|h5|...|n4|n5|x|x|x|x|
+ matrixArray_4 = _mm512_permutex2var_epi16(matrixArray_0, load_idx_stage1_4, matrixArray_1); // |a8| x|...|g8| x| x| x|...| x| x|x|x|x|x|
+ matrixArray_5 = _mm512_permutex2var_epi16(matrixArray_2, load_idx_stage1_4, matrixArray_3); // | x| x|...| x| x|h8| x|...|n8| x|x|x|x|x|
+
+ // Stage 2: interleave per 32 bits
+ matrixArray_0 = _mm512_mask_blend_epi32(blend_mask, matrixArray_stage_0, matrixArray_stage_2); // |a0|a1|b0|b1|...|h0|h1|i0|i1|j0|j1|...|n0|n1|x|x|x|x|
+ matrixArray_1 = _mm512_permutex2var_epi32(matrixArray_stage_0, load_idx_stage2_0, matrixArray_stage_2); // |a2|a3|b2|b3|...|h2|h3|i2|i3|j2|j3|...|n2|n3|x|x|x|x|
+ matrixArray_2 = _mm512_mask_blend_epi32(blend_mask, matrixArray_stage_1, matrixArray_stage_3); // |a4|a5|b4|b5|...|h4|h5|i4|i5|j4|j5|...|n4|n5|x|x|x|x|
+ matrixArray_3 = _mm512_permutex2var_epi32(matrixArray_stage_1, load_idx_stage2_0, matrixArray_stage_3); // |a6|a7|b6|b7|...|h6|h7|i6|i7|j6|j7|...|n6|n7|x|x|x|x|
+ matrixArray_4 = _mm512_mask_blend_epi32(blend_mask, matrixArray_4, matrixArray_5); // |a8| x|b8| x|...|h8| x|i8| x|j8| x|...|n8| x|x|x|x|x|
+
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_0, (__m512bh) xArray_01);
+ result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_1, (__m512bh) xArray_23);
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_2, (__m512bh) xArray_45);
+ result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_3, (__m512bh) xArray_67);
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_4, (__m512bh) xArray_89);
+ result_0 = _mm512_add_ps(result_0, result_1);
+
+ STORE16_MASK_COMPLETE_RESULT(result_0, y+idx_m, store_mask)
+ }
+ }
+
+ if (tag_m_14x != m) {
+ __m256i matrixArray256;
+ __m256i x256 = _mm256_insertf128_si256(_mm256_castsi128_si256(x128_0), x128_1, 0x1);
+ __m256 result256;
+ __m128 result128, tmp128;
+ unsigned short load256_mask_value = (((unsigned short)0xffff) >> 7);
+ __mmask16 load256_mask = *((__mmask16*) &load256_mask_value);
+ for (BLASLONG i = tag_m_14x; i < m; i++) {
+ result256 = _mm256_setzero_ps();
+ matrixArray256 = _mm256_maskz_loadu_epi16(load256_mask, &a[(i)*9]);
+ result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) x256);
+ result128 = _mm_add_ps(_mm256_castps256_ps128(result256), _mm256_extractf128_ps(result256, 0x1));
+ tmp128 = _mm_shuffle_ps(result128, result128, 14);
+ result128 = _mm_add_ps(result128, tmp128);
+ tmp128 = _mm_shuffle_ps(result128, result128, 1);
+ result128 = _mm_add_ps(result128, tmp128);
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+ y[i] = alpha * result128[0] + beta * y[i];
+#else
+ y[i] = alpha * result128[0] + y[i];
+#endif
+#else
+#ifndef ONE_ALPHA
+ y[i] = result128[0] * alpha;
+#else
+ y[i] = result128[0];
+#endif
+#endif
+ }
+ }
+
+ return 0;
+}
+
+// 12 rows parallel processing BF16 GEMV kernel for n=10 && lda ineffective scenario
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+static int sbgemv_kernel_12x10_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#else
+static int sbgemv_kernel_12x10_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#endif
+#else
+#ifndef ONE_ALPHA
+static int sbgemv_kernel_12x10_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#else
+static int sbgemv_kernel_12x10(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#endif
+#endif
+{
+ BLASLONG tag_m_12x = m - (m%12);
+
+ unsigned char x_load_mask_value = (((unsigned char)0xf) >> 3);
+ __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
+ __m128i x128_0 = _mm_loadu_si128(x); // |x0|x1|x2|x3|x4|x5|x6|x7|
+ __m128i x128_1 = _mm_maskz_loadu_epi32(x_load_mask, (x+8)); // |x8|x9|0 | 0| 0| 0| 0| 0|
+
+ if (tag_m_12x > 0) {
+ __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4;
+ __m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2, matrixArray_stage_3, matrixArray_stage_4, matrixArray_stage_5;
+ __m512i xArray_01, xArray_23, xArray_45, xArray_67, xArray_89;
+ __m512 result_0, result_1;
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+#ifndef ZERO_BETA
+ __m512 BETAVECTOR = _mm512_set1_ps(beta);
+#endif
+
+ __m256i M256_EPI32_1 = _mm256_set1_epi32(1);
+ __m256i idx_base_0 = _mm256_set_epi32( 0, 0, 26, 21, 16, 10, 5, 0);
+ __m256i idx_base_1 = _mm256_add_epi32(idx_base_0, M256_EPI32_1);
+ __m256i idx_base_2 = _mm256_add_epi32(idx_base_1, M256_EPI32_1);
+ __m256i idx_base_3 = _mm256_add_epi32(idx_base_2, M256_EPI32_1);
+ __m256i idx_base_4 = _mm256_add_epi32(idx_base_3, M256_EPI32_1);
+ __m512i idx_idx = _mm512_set_epi32( 0, 0, 0, 0, 21, 20, 19, 18, 17, 16, 5, 4, 3, 2, 1, 0);
+
+ __m512i load_idx_stage1_0 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_0), idx_idx, _mm512_castsi256_si512(idx_base_1));
+ __m512i load_idx_stage1_1 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_2), idx_idx, _mm512_castsi256_si512(idx_base_3));
+ __m512i load_idx_stage1_2 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_1), idx_idx, _mm512_castsi256_si512(idx_base_0));
+ __m512i load_idx_stage1_3 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_3), idx_idx, _mm512_castsi256_si512(idx_base_2));
+ __m512i load_idx_stage1_4 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_4), idx_idx, _mm512_castsi256_si512(idx_base_4));
+ __m512i load_idx_stage2_0 = _mm512_set_epi32( 0, 0, 0, 0, 21, 20, 19, 18, 17, 16, 11, 10, 9, 8, 7, 6);
+
+ xArray_01 = _mm512_broadcastd_epi32(x128_0); // |x0|x1|x0|x1| ... |x0|x1|
+ xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x1)); // |x2|x3|x2|x3| ... |x2|x3|
+ xArray_45 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x2)); // |x4|x5|x4|x5| ... |x4|x5|
+ xArray_67 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x3)); // |x6|x7|x6|x7| ... |x6|x7|
+ xArray_89 = _mm512_broadcastd_epi32(x128_1); // |x8|x9|x8|x9| ... |x8|x9|
+
+ unsigned short blend_mask_value = ((unsigned short)0x0fc0);
+ __mmask16 blend_mask = *((__mmask16*) &blend_mask_value);
+ unsigned short load_mask_value = (((unsigned short)0xffff) >> 1);
+ __mmask16 load_mask = *((__mmask16*) &load_mask_value);
+ unsigned short store_mask_value = (((unsigned short)0xffff) >> 4);
+ __mmask16 store_mask = *((__mmask16*) &store_mask_value);
+ for (BLASLONG idx_m = 0; idx_m < tag_m_12x; idx_m+=12) {
+ result_0 = _mm512_setzero_ps();
+ result_1 = _mm512_setzero_ps();
+
+ matrixArray_0 = _mm512_maskz_loadu_epi32(load_mask, &a[(idx_m)*10]); // Load 3 rows with n=10
+ matrixArray_1 = _mm512_maskz_loadu_epi32(load_mask, &a[(idx_m+3)*10]); // Load 3 rows with n=10
+ matrixArray_2 = _mm512_maskz_loadu_epi32(load_mask, &a[(idx_m+6)*10]); // Load 3 rows with n=10
+ matrixArray_3 = _mm512_maskz_loadu_epi32(load_mask, &a[(idx_m+9)*10]); // Load 3 rows with n=10
+
+ // Stage 1: interleave per 32 bits
+ matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage1_0, matrixArray_1); // |a0|a1|...|f0|f1|a2|a3|...|f2|f3|x|x|x|x|x|x|x|x|
+ matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage1_1, matrixArray_1); // |a4|a5|...|f4|f5|a6|a7|...|f6|f7|x|x|x|x|x|x|x|x|
+ matrixArray_stage_2 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage1_2, matrixArray_3); // |g2|g3|...|l2|l3|g0|g1|...|l0|l1|x|x|x|x|x|x|x|x|
+ matrixArray_stage_3 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage1_3, matrixArray_3); // |g6|g7|...|l6|l7|g4|g5|...|l4|l5|x|x|x|x|x|x|x|x|
+ matrixArray_stage_4 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage1_4, matrixArray_1); // |a8|a9|...|f8|f9| x| x|...| x| x|x|x|x|x|x|x|x|x|
+ matrixArray_stage_5 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage1_4, matrixArray_3); // | x| x|...| x| x|g8|g9|...|l8|l9|x|x|x|x|x|x|x|x|
+
+ // Stage 3: interleave per 256 bits
+ matrixArray_0 = _mm512_mask_blend_epi32(blend_mask, matrixArray_stage_0, matrixArray_stage_2); // |a0|a1|...|l0|l1|x|x|x|x|x|x|x|x|
+ matrixArray_1 = _mm512_permutex2var_epi32(matrixArray_stage_0, load_idx_stage2_0, matrixArray_stage_2); // |a2|a3|...|l2|l3|x|x|x|x|x|x|x|x|
+ matrixArray_2 = _mm512_mask_blend_epi32(blend_mask, matrixArray_stage_1, matrixArray_stage_3); // |a4|a5|...|l4|l5|x|x|x|x|x|x|x|x|
+ matrixArray_3 = _mm512_permutex2var_epi32(matrixArray_stage_1, load_idx_stage2_0, matrixArray_stage_3); // |a6|a7|...|l6|l7|x|x|x|x|x|x|x|x|
+ matrixArray_4 = _mm512_mask_blend_epi32(blend_mask, matrixArray_stage_4, matrixArray_stage_5); // |a8|a9|...|l8|l9|x|x|x|x|x|x|x|x|
+
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_0, (__m512bh) xArray_01);
+ result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_1, (__m512bh) xArray_23);
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_2, (__m512bh) xArray_45);
+ result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_3, (__m512bh) xArray_67);
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_4, (__m512bh) xArray_89);
+ result_0 = _mm512_add_ps(result_0, result_1);
+
+ STORE16_MASK_COMPLETE_RESULT(result_0, y+idx_m, store_mask)
+ }
+ }
+
+ if (tag_m_12x != m) {
+ __m256i matrixArray256;
+ __m256i x256 = _mm256_insertf128_si256(_mm256_castsi128_si256(x128_0), x128_1, 0x1);
+ __m256 result256;
+ __m128 result128, tmp128;
+ unsigned char load256_mask_value = (((unsigned char)0xff) >> 3);
+ __mmask8 load256_mask = *((__mmask8*) &load256_mask_value);
+ for (BLASLONG i = tag_m_12x; i < m; i++) {
+ result256 = _mm256_setzero_ps();
+ matrixArray256 = _mm256_maskz_loadu_epi32(load256_mask, &a[(i)*10]);
+ result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) x256);
+ result128 = _mm_add_ps(_mm256_castps256_ps128(result256), _mm256_extractf128_ps(result256, 0x1));
+ tmp128 = _mm_shuffle_ps(result128, result128, 14);
+ result128 = _mm_add_ps(result128, tmp128);
+ tmp128 = _mm_shuffle_ps(result128, result128, 1);
+ result128 = _mm_add_ps(result128, tmp128);
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+ y[i] = alpha * result128[0] + beta * y[i];
+#else
+ y[i] = alpha * result128[0] + y[i];
+#endif
+#else
+#ifndef ONE_ALPHA
+ y[i] = result128[0] * alpha;
+#else
+ y[i] = result128[0];
+#endif
+#endif
+ }
+ }
+
+ return 0;
+}
+
+// 15 rows parallel processing BF16 GEMV kernel for n=11 && lda ineffective scenario
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+static int sbgemv_kernel_15x11_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#else
+static int sbgemv_kernel_15x11_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#endif
+#else
+#ifndef ONE_ALPHA
+static int sbgemv_kernel_15x11_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#else
+static int sbgemv_kernel_15x11(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#endif
+#endif
+{
+ BLASLONG tag_m_15x = m - (m%15);
+
+ unsigned char x_load_mask_value = (((unsigned char)0xff) >> 5);
+ __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
+ __m128i x128_0 = _mm_loadu_si128(x); // |x0|x1| x2|x3|x4|x5|x6|x7|
+ __m128i x128_1 = _mm_maskz_loadu_epi16(x_load_mask, (x+8)); // |x8|x9|x10| 0| 0| 0| 0| 0|
+
+ if (tag_m_15x > 0) {
+ __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5;
+ __m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2, matrixArray_stage_3, matrixArray_stage_4, matrixArray_stage_5;
+ __m512i xArray_01, xArray_23, xArray_45, xArray_67, xArray_89, xArray_10;
+ __m512 result_0, result_1;
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+#ifndef ZERO_BETA
+ __m512 BETAVECTOR = _mm512_set1_ps(beta);
+#endif
+
+ __m512i idx_stage1_base_0, idx_stage1_base_1, idx_stage1_base_2, idx_stage1_base_3, idx_stage1_base_4, idx_stage1_base_5;
+ __m512i idx_stage2_base_0, idx_stage2_base_1, idx_stage2_base_2, idx_stage2_base_3;
+
+ __m512i M512_EPI16_2, M512_EPI16_4, M512_EPI16_6, M512_EPI32_5;
+ M512_EPI16_2 = _mm512_set1_epi16(2);
+ M512_EPI16_4 = _mm512_add_epi16(M512_EPI16_2, M512_EPI16_2);
+ M512_EPI16_6 = _mm512_add_epi16(M512_EPI16_4, M512_EPI16_2);
+ M512_EPI32_5 = _mm512_set1_epi32(5);
+
+ unsigned int BASE_MASK_10_value = ((unsigned int)0x000003ff);
+ __mmask32 BASE_MASK_10 = *((__mmask32*) &BASE_MASK_10_value);
+ unsigned int BASE_MASK_20_value = ((unsigned int)0x000ffc00);
+ __mmask32 BASE_MASK_20 = *((__mmask32*) &BASE_MASK_20_value);
+ unsigned int BASE_MASK_30_value = ((unsigned int)0x3ff00000);
+ __mmask32 BASE_MASK_30 = *((__mmask32*) &BASE_MASK_30_value);
+
+ idx_stage1_base_0 = _mm512_set_epi16( 0, 0, 49, 48, 38, 37, 27, 26, 16, 15, 5, 4, 47, 46, 36, 35,
+ 25, 24, 14, 13, 3, 2, 45, 44, 34, 33, 23, 22, 12, 11, 1, 0);
+ idx_stage1_base_1 = _mm512_add_epi16(idx_stage1_base_0, M512_EPI16_6);
+
+ idx_stage1_base_2 = _mm512_mask_add_epi16(idx_stage1_base_0, BASE_MASK_10, idx_stage1_base_0, M512_EPI16_2);
+ idx_stage1_base_2 = _mm512_mask_sub_epi16(idx_stage1_base_2, BASE_MASK_20, idx_stage1_base_0, M512_EPI16_2);
+ idx_stage1_base_3 = _mm512_add_epi16(idx_stage1_base_2, M512_EPI16_6);
+
+ idx_stage1_base_4 = _mm512_mask_add_epi16(idx_stage1_base_2, BASE_MASK_10, idx_stage1_base_2, M512_EPI16_2);
+ idx_stage1_base_4 = _mm512_mask_add_epi16(idx_stage1_base_4, BASE_MASK_20, idx_stage1_base_2, M512_EPI16_2);
+ idx_stage1_base_4 = _mm512_mask_sub_epi16(idx_stage1_base_4, BASE_MASK_30, idx_stage1_base_2, M512_EPI16_4);
+ idx_stage1_base_5 = _mm512_add_epi16(idx_stage1_base_4, M512_EPI16_6);
+
+ unsigned short idx_stage2_mask_1_value = ((unsigned short)0x03e0);
+ __mmask16 idx_stage2_mask_1 = *((__mmask16*) &idx_stage2_mask_1_value);
+ unsigned short idx_stage2_mask_2_value = ((unsigned short)0x7c00);
+ __mmask16 idx_stage2_mask_2 = *((__mmask16*) &idx_stage2_mask_2_value);
+ idx_stage2_base_0 = _mm512_set_epi32( 0, 0, 0, 0, 0, 0, 20, 19, 18, 17, 16, 9, 8, 7, 6, 5);
+ idx_stage2_base_1 = _mm512_set_epi32( 0, 25, 24, 23, 22, 21, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
+ idx_stage2_base_2 = _mm512_add_epi32(idx_stage2_base_0, M512_EPI32_5);
+ idx_stage2_base_2 = _mm512_mask_add_epi32(idx_stage2_base_2, idx_stage2_mask_1, idx_stage2_base_2, M512_EPI32_5);
+ idx_stage2_base_3 = _mm512_mask_sub_epi32(idx_stage2_base_1, idx_stage2_mask_2, idx_stage2_base_1, M512_EPI32_5);
+
+ xArray_01 = _mm512_broadcastd_epi32(x128_0); // |x0 |x1 |x0 |x1 | ... |x0 |x1 |
+ xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x1)); // |x2 |x3 |x2 |x3 | ... |x2 |x3 |
+ xArray_45 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x2)); // |x4 |x5 |x4 |x5 | ... |x4 |x5 |
+ xArray_67 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x3)); // |x6 |x7 |x6 |x7 | ... |x6 |x7 |
+ xArray_89 = _mm512_broadcastd_epi32(x128_1); // |x8 |x9 |x8 |x9 | ... |x8 |x9 |
+ xArray_10 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_1, 0x1)); // |x10|0 |x10|0 | ... |x10|0 |
+
+ unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 9);
+ __mmask32 load_mask = *((__mmask32*) &load_mask_value);
+
+ unsigned short store_mask_value = (((unsigned short)0xffff) >> 1);
+ __mmask16 store_mask = *((__mmask16*) &store_mask_value);
+
+ for (BLASLONG idx_m = 0; idx_m < tag_m_15x; idx_m+=15) {
+ result_0 = _mm512_setzero_ps();
+ result_1 = _mm512_setzero_ps();
+
+ matrixArray_0 = _mm512_loadu_si512(&a[idx_m*11]); // Load 2 rows with n=11 plus 10 elements
+ matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[idx_m*11 + 32]); // Load 2 rows with n=11 plus 1 element
+ matrixArray_2 = _mm512_loadu_si512(&a[(idx_m+5)*11]); // Load 2 rows with n=11 plus 10 elements
+ matrixArray_3 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+5)*11 + 32]); // Load 2 rows with n=11 plus 1 element
+ matrixArray_4 = _mm512_loadu_si512(&a[(idx_m+10)*11]); // Load 2 rows with n=11 plus 10 elements
+ matrixArray_5 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+10)*11 + 32]); // Load 2 rows with n=11 plus 1 element
+
+ // Stage 1: interleave per 16 bits
+ matrixArray_stage_0 = _mm512_permutex2var_epi16(matrixArray_0, idx_stage1_base_0, matrixArray_1); // |a0|a1|...|e0|e1|a2|a3|...|e2|e3|a4 |a5|...|e4 |e5|
+ matrixArray_stage_1 = _mm512_permutex2var_epi16(matrixArray_0, idx_stage1_base_1, matrixArray_1); // |a6|a7|...|e6|e7|a8|a9|...|e8|e9|a10|x |...|e10|x |
+ matrixArray_stage_2 = _mm512_permutex2var_epi16(matrixArray_2, idx_stage1_base_2, matrixArray_3); // |f2|f3|...|j2|j3|f0|f1|...|j0|j1|f4 |f5|...|j4 |j5|
+ matrixArray_stage_3 = _mm512_permutex2var_epi16(matrixArray_2, idx_stage1_base_3, matrixArray_3); // |f8|f9|...|j8|j9|f6|f7|...|j6|j7|f10|x |...|j10|x |
+ matrixArray_stage_4 = _mm512_permutex2var_epi16(matrixArray_4, idx_stage1_base_4, matrixArray_5); // |k4|k5|...|o4|o5|k2|k3|...|o2|o3|k0 |k1|...|o0 |o1|
+ matrixArray_stage_5 = _mm512_permutex2var_epi16(matrixArray_4, idx_stage1_base_5, matrixArray_5); // |k10|x|...|o10|x|k8|k9|...|o8|o9|k6 |k7|...|o6 |o7|
+
+ // Stage 2: interleave per 32 bits
+ matrixArray_0 = _mm512_mask_blend_epi32(idx_stage2_mask_1, matrixArray_stage_0, matrixArray_stage_2); // |a0|a1|...|j0|j1|x|x|x|x|x|x|x|x|x|x|x|x|
+ matrixArray_3 = _mm512_mask_blend_epi32(idx_stage2_mask_1, matrixArray_stage_1, matrixArray_stage_3); // |a6|a7|...|j6|j7|x|x|x|x|x|x|x|x|x|x|x|x|
+ matrixArray_1 = _mm512_permutex2var_epi32(matrixArray_stage_0, idx_stage2_base_0, matrixArray_stage_2); // |a2|a3|...|j2|j3|x|x|x|x|x|x|x|x|x|x|x|x|
+ matrixArray_2 = _mm512_permutex2var_epi32(matrixArray_stage_0, idx_stage2_base_2, matrixArray_stage_2); // |a4|a5|...|j4|j5|x|x|x|x|x|x|x|x|x|x|x|x|
+ matrixArray_4 = _mm512_permutex2var_epi32(matrixArray_stage_1, idx_stage2_base_0, matrixArray_stage_3); // |a8|a9|...|j8|j9|x|x|x|x|x|x|x|x|x|x|x|x|
+ matrixArray_5 = _mm512_permutex2var_epi32(matrixArray_stage_1, idx_stage2_base_2, matrixArray_stage_3); // |a10|x|...|j10|x|x|x|x|x|x|x|x|x|x|x|x|x|
+
+ matrixArray_0 = _mm512_mask_blend_epi32(idx_stage2_mask_2, matrixArray_0, matrixArray_stage_4); // |a0|a1|.......................|o0|o1|x|x|
+ matrixArray_3 = _mm512_mask_blend_epi32(idx_stage2_mask_2, matrixArray_3, matrixArray_stage_5); // |a6|a7|.......................|o6|o7|x|x|
+ matrixArray_1 = _mm512_permutex2var_epi32(matrixArray_1 , idx_stage2_base_1, matrixArray_stage_4); // |a2|a3|.......................|o2|o3|x|x|
+ matrixArray_2 = _mm512_permutex2var_epi32(matrixArray_2 , idx_stage2_base_3, matrixArray_stage_4); // |a4|a5|.......................|o4|o5|x|x|
+ matrixArray_4 = _mm512_permutex2var_epi32(matrixArray_4 , idx_stage2_base_1, matrixArray_stage_5); // |a8|a9|.......................|o8|o9|x|x|
+ matrixArray_5 = _mm512_permutex2var_epi32(matrixArray_5 , idx_stage2_base_3, matrixArray_stage_5); // |a10|x|.......................|o10|x|x|x|
+
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_0, (__m512bh) xArray_01);
+ result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_1, (__m512bh) xArray_23);
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_2, (__m512bh) xArray_45);
+ result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_3, (__m512bh) xArray_67);
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_4, (__m512bh) xArray_89);
+ result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_5, (__m512bh) xArray_10);
+ result_0 = _mm512_add_ps(result_0, result_1);
+
+ STORE16_MASK_COMPLETE_RESULT(result_0, y+idx_m, store_mask)
+ }
+ }
+
+ if (tag_m_15x != m) {
+ __m256i matrixArray256;
+ __m256i x256 = _mm256_insertf128_si256(_mm256_castsi128_si256(x128_0), x128_1, 0x1);
+ __m256 result256;
+ __m128 result128, tmp128;
+ unsigned short load256_mask_value = (((unsigned short)0xffff) >> 5);
+ __mmask16 load256_mask = *((__mmask16*) &load256_mask_value);
+ for (BLASLONG i = tag_m_15x; i < m; i++) {
+ result256 = _mm256_setzero_ps();
+ matrixArray256 = _mm256_maskz_loadu_epi16(load256_mask, &a[(i)*11]);
+ result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) x256);
+ result128 = _mm_add_ps(_mm256_castps256_ps128(result256), _mm256_extractf128_ps(result256, 0x1));
+ tmp128 = _mm_shuffle_ps(result128, result128, 14);
+ result128 = _mm_add_ps(result128, tmp128);
+ tmp128 = _mm_shuffle_ps(result128, result128, 1);
+ result128 = _mm_add_ps(result128, tmp128);
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+ y[i] = alpha * result128[0] + beta * y[i];
+#else
+ y[i] = alpha * result128[0] + y[i];
+#endif
+#else
+#ifndef ONE_ALPHA
+ y[i] = result128[0] * alpha;
+#else
+ y[i] = result128[0];
+#endif
+#endif
+ }
+ }
+
+ return 0;
+}
+
+// 15 rows parallel processing BF16 GEMV kernel for n=12 && lda ineffective scenario
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+static int sbgemv_kernel_15x12_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#else
+static int sbgemv_kernel_15x12_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#endif
+#else
+#ifndef ONE_ALPHA
+static int sbgemv_kernel_15x12_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#else
+static int sbgemv_kernel_15x12(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#endif
+#endif
+{
+ BLASLONG tag_m_15x = m - (m%15);
+
+ unsigned char x_load_mask_value = (((unsigned char)0xff) >> 4);
+ __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
+ __m128i x128_0 = _mm_loadu_si128(x); // |x0|x1| x2| x3|x4|x5|x6|x7|
+ __m128i x128_1 = _mm_maskz_loadu_epi16(x_load_mask, (x+8)); // |x8|x9|x10|x11| 0| 0| 0| 0|
+
+ if (tag_m_15x > 0) {
+ __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5;
+ __m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2, matrixArray_stage_3, matrixArray_stage_4, matrixArray_stage_5;
+ __m512i xArray_01, xArray_23, xArray_45, xArray_67, xArray_89, xArray_10;
+ __m512 result_0, result_1;
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+#ifndef ZERO_BETA
+ __m512 BETAVECTOR = _mm512_set1_ps(beta);
+#endif
+
+ __m512i idx_stage1_base_0, idx_stage1_base_1, idx_stage1_base_2, idx_stage1_base_3, idx_stage1_base_4, idx_stage1_base_5;
+ __m512i idx_stage2_base_0, idx_stage2_base_1, idx_stage2_base_2, idx_stage2_base_3;
+
+ __m512i M512_EPI32_1, M512_EPI32_2, M512_EPI32_3, M512_EPI32_5;
+ M512_EPI32_1 = _mm512_set1_epi32(1);
+ M512_EPI32_2 = _mm512_add_epi32(M512_EPI32_1, M512_EPI32_1);
+ M512_EPI32_3 = _mm512_add_epi32(M512_EPI32_2, M512_EPI32_1);
+ M512_EPI32_5 = _mm512_add_epi32(M512_EPI32_3, M512_EPI32_2);
+
+ unsigned short BASE_MASK_10_value = ((unsigned short)0x001f);
+ __mmask16 BASE_MASK_10 = *((__mmask16*) &BASE_MASK_10_value);
+ unsigned short BASE_MASK_20_value = ((unsigned short)0x03e0);
+ __mmask16 BASE_MASK_20 = *((__mmask16*) &BASE_MASK_20_value);
+ unsigned short BASE_MASK_30_value = ((unsigned short)0xfc00);
+ __mmask16 BASE_MASK_30 = *((__mmask16*) &BASE_MASK_30_value);
+
+ idx_stage1_base_0 = _mm512_set_epi32( 0, 26, 20, 14, 8, 2, 25, 19, 13, 7, 1, 24, 18, 12, 6, 0);
+ idx_stage1_base_1 = _mm512_add_epi32(idx_stage1_base_0, M512_EPI32_3);
+
+ idx_stage1_base_2 = _mm512_mask_add_epi32(idx_stage1_base_0, BASE_MASK_10, idx_stage1_base_0, M512_EPI32_1);
+ idx_stage1_base_2 = _mm512_mask_sub_epi32(idx_stage1_base_2, BASE_MASK_20, idx_stage1_base_0, M512_EPI32_1);
+ idx_stage1_base_3 = _mm512_add_epi32(idx_stage1_base_2, M512_EPI32_3);
+
+ idx_stage1_base_4 = _mm512_mask_add_epi32(idx_stage1_base_2, BASE_MASK_10, idx_stage1_base_2, M512_EPI32_1);
+ idx_stage1_base_4 = _mm512_mask_add_epi32(idx_stage1_base_4, BASE_MASK_20, idx_stage1_base_2, M512_EPI32_1);
+ idx_stage1_base_4 = _mm512_mask_sub_epi32(idx_stage1_base_4, BASE_MASK_30, idx_stage1_base_2, M512_EPI32_2);
+ idx_stage1_base_5 = _mm512_add_epi32(idx_stage1_base_4, M512_EPI32_3);
+
+ unsigned short idx_stage2_mask_1_value = ((unsigned short)0x03e0);
+ __mmask16 idx_stage2_mask_1 = *((__mmask16*) &idx_stage2_mask_1_value);
+ unsigned short idx_stage2_mask_2_value = ((unsigned short)0x7c00);
+ __mmask16 idx_stage2_mask_2 = *((__mmask16*) &idx_stage2_mask_2_value);
+ idx_stage2_base_0 = _mm512_set_epi32( 0, 0, 0, 0, 0, 0, 20, 19, 18, 17, 16, 9, 8, 7, 6, 5);
+ idx_stage2_base_1 = _mm512_set_epi32( 0, 25, 24, 23, 22, 21, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
+ idx_stage2_base_2 = _mm512_add_epi32(idx_stage2_base_0, M512_EPI32_5);
+ idx_stage2_base_2 = _mm512_mask_add_epi32(idx_stage2_base_2, idx_stage2_mask_1, idx_stage2_base_2, M512_EPI32_5);
+ idx_stage2_base_3 = _mm512_mask_sub_epi32(idx_stage2_base_1, idx_stage2_mask_2, idx_stage2_base_1, M512_EPI32_5);
+
+ xArray_01 = _mm512_broadcastd_epi32(x128_0); // |x0 |x1 |x0 |x1 | ... |x0 |x1 |
+ xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x1)); // |x2 |x3 |x2 |x3 | ... |x2 |x3 |
+ xArray_45 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x2)); // |x4 |x5 |x4 |x5 | ... |x4 |x5 |
+ xArray_67 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x3)); // |x6 |x7 |x6 |x7 | ... |x6 |x7 |
+ xArray_89 = _mm512_broadcastd_epi32(x128_1); // |x8 |x9 |x8 |x9 | ... |x8 |x9 |
+ xArray_10 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_1, 0x1)); // |x10|x11|x10|x11| ... |x10|x11|
+
+ unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 4);
+ __mmask32 load_mask = *((__mmask32*) &load_mask_value);
+
+ unsigned short store_mask_value = (((unsigned short)0xffff) >> 1);
+ __mmask16 store_mask = *((__mmask16*) &store_mask_value);
+
+ for (BLASLONG idx_m = 0; idx_m < tag_m_15x; idx_m+=15) {
+ result_0 = _mm512_setzero_ps();
+ result_1 = _mm512_setzero_ps();
+
+ matrixArray_0 = _mm512_loadu_si512(&a[idx_m*12]); // Load 2 rows with n=12 plus 8 elements
+ matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[idx_m*12 + 32]); // Load 2 rows with n=12 plus 4 element
+ matrixArray_2 = _mm512_loadu_si512(&a[(idx_m+5)*12]); // Load 2 rows with n=12 plus 8 elements
+ matrixArray_3 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+5)*12 + 32]); // Load 2 rows with n=12 plus 4 element
+ matrixArray_4 = _mm512_loadu_si512(&a[(idx_m+10)*12]); // Load 2 rows with n=12 plus 8 elements
+ matrixArray_5 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+10)*12 + 32]); // Load 2 rows with n=12 plus 4 element
+
+ // Stage 1: interleave per 16 bits
+ matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, idx_stage1_base_0, matrixArray_1); // |a0 |a1 |...|e0 |e1 |a2|a3|...|e2|e3|a4 |a5 |...|e4 |e5 |
+ matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, idx_stage1_base_1, matrixArray_1); // |a6 |a7 |...|e6 |e7 |a8|a9|...|e8|e9|a10|a11|...|e10|e11|
+ matrixArray_stage_2 = _mm512_permutex2var_epi32(matrixArray_2, idx_stage1_base_2, matrixArray_3); // |f2 |f3 |...|j2 |j3 |f0|f1|...|j0|j1|f4 |f5 |...|j4 |j5 |
+ matrixArray_stage_3 = _mm512_permutex2var_epi32(matrixArray_2, idx_stage1_base_3, matrixArray_3); // |f8 |f9 |...|j8 |j9 |f6|f7|...|j6|j7|f10|f11|...|j10|j11|
+ matrixArray_stage_4 = _mm512_permutex2var_epi32(matrixArray_4, idx_stage1_base_4, matrixArray_5); // |k4 |k5 |...|o4 |o5 |k2|k3|...|o2|o3|k0 |k1 |...|o0 |o1 |
+ matrixArray_stage_5 = _mm512_permutex2var_epi32(matrixArray_4, idx_stage1_base_5, matrixArray_5); // |k10|k11|...|o10|o11|k8|k9|...|o8|o9|k6 |k7 |...|o6 |o7 |
+
+ // Stage 2: interleave per 32 bits
+ matrixArray_0 = _mm512_mask_blend_epi32(idx_stage2_mask_1, matrixArray_stage_0, matrixArray_stage_2); // |a0 |a1 |...|j0 |j1 |x|x|x|x|x|x|x|x|x|x|x|x|
+ matrixArray_3 = _mm512_mask_blend_epi32(idx_stage2_mask_1, matrixArray_stage_1, matrixArray_stage_3); // |a6 |a7 |...|j6 |j7 |x|x|x|x|x|x|x|x|x|x|x|x|
+ matrixArray_1 = _mm512_permutex2var_epi32(matrixArray_stage_0, idx_stage2_base_0, matrixArray_stage_2); // |a2 |a3 |...|j2 |j3 |x|x|x|x|x|x|x|x|x|x|x|x|
+ matrixArray_2 = _mm512_permutex2var_epi32(matrixArray_stage_0, idx_stage2_base_2, matrixArray_stage_2); // |a4 |a5 |...|j4 |j5 |x|x|x|x|x|x|x|x|x|x|x|x|
+ matrixArray_4 = _mm512_permutex2var_epi32(matrixArray_stage_1, idx_stage2_base_0, matrixArray_stage_3); // |a8 |a9 |...|j8 |j9 |x|x|x|x|x|x|x|x|x|x|x|x|
+ matrixArray_5 = _mm512_permutex2var_epi32(matrixArray_stage_1, idx_stage2_base_2, matrixArray_stage_3); // |a10|a11|...|j10|j11|x|x|x|x|x|x|x|x|x|x|x|x|
+
+ matrixArray_0 = _mm512_mask_blend_epi32(idx_stage2_mask_2, matrixArray_0, matrixArray_stage_4); // |a0|a1|.......................|o0|o1|x|x|
+ matrixArray_3 = _mm512_mask_blend_epi32(idx_stage2_mask_2, matrixArray_3, matrixArray_stage_5); // |a6|a7|.......................|o6|o7|x|x|
+ matrixArray_1 = _mm512_permutex2var_epi32(matrixArray_1 , idx_stage2_base_1, matrixArray_stage_4); // |a2|a3|.......................|o2|o3|x|x|
+ matrixArray_2 = _mm512_permutex2var_epi32(matrixArray_2 , idx_stage2_base_3, matrixArray_stage_4); // |a4|a5|.......................|o4|o5|x|x|
+ matrixArray_4 = _mm512_permutex2var_epi32(matrixArray_4 , idx_stage2_base_1, matrixArray_stage_5); // |a8|a9|.......................|o8|o9|x|x|
+ matrixArray_5 = _mm512_permutex2var_epi32(matrixArray_5 , idx_stage2_base_3, matrixArray_stage_5); // |a10|x|.......................|o10|x|x|x|
+
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_0, (__m512bh) xArray_01);
+ result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_1, (__m512bh) xArray_23);
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_2, (__m512bh) xArray_45);
+ result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_3, (__m512bh) xArray_67);
+ result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_4, (__m512bh) xArray_89);
+ result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_5, (__m512bh) xArray_10);
+ result_0 = _mm512_add_ps(result_0, result_1);
+
+ STORE16_MASK_COMPLETE_RESULT(result_0, y+idx_m, store_mask)
+ }
+ }
+
+ if (tag_m_15x != m) {
+ __m256i matrixArray256;
+ __m256i x256 = _mm256_insertf128_si256(_mm256_castsi128_si256(x128_0), x128_1, 0x1);
+ __m256 result256;
+ __m128 result128, tmp128;
+ unsigned short load256_mask_value = (((unsigned short)0xffff) >> 4);
+ __mmask16 load256_mask = *((__mmask16*) &load256_mask_value);
+ for (BLASLONG i = tag_m_15x; i < m; i++) {
+ result256 = _mm256_setzero_ps();
+ matrixArray256 = _mm256_maskz_loadu_epi16(load256_mask, &a[(i)*12]);
+ result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) x256);
+ result128 = _mm_add_ps(_mm256_castps256_ps128(result256), _mm256_extractf128_ps(result256, 0x1));
+ tmp128 = _mm_shuffle_ps(result128, result128, 14);
+ result128 = _mm_add_ps(result128, tmp128);
+ tmp128 = _mm_shuffle_ps(result128, result128, 1);
+ result128 = _mm_add_ps(result128, tmp128);
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+ y[i] = alpha * result128[0] + beta * y[i];
+#else
+ y[i] = alpha * result128[0] + y[i];
+#endif
+#else
+#ifndef ONE_ALPHA
+ y[i] = result128[0] * alpha;
+#else
+ y[i] = result128[0];
+#endif
+#endif
+ }
+ }
+
+ return 0;
+}
+
+
+// 16 rows parallel processing BF16 GEMV kernel for n=13 && lda ineffective scenario
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+static int sbgemv_kernel_16x13_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#else
+static int sbgemv_kernel_16x13_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#endif
+#else
+#ifndef ONE_ALPHA
+static int sbgemv_kernel_16x13_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#else
+static int sbgemv_kernel_16x13(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#endif
+#endif
+{
+ BLASLONG tag_m_16x = m & (~15);
+
+ unsigned short x_load_mask_value = (((unsigned short)0xffff) >> 3);
+ __mmask16 x_load_mask = *((__mmask16*) &x_load_mask_value);
+ __m256i x256 = _mm256_maskz_loadu_epi16(x_load_mask, x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|0|0|0|
+
+ if (tag_m_16x > 0) {
+ __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \
+ matrixArray_8, matrixArray_9, matrixArray_10, matrixArray_11, matrixArray_12, matrixArray_13, matrixArray_14, matrixArray_15;
+ __m512i xArray_0, xArray_1, xArray_2, xArray_3;
+ __m512 accum512_0, accum512_1;
+ __m512 result_0, result_1;
+
+ __m256i matrixArray256_0, matrixArray256_1, matrixArray256_2, matrixArray256_3, matrixArray256_4, matrixArray256_5, matrixArray256_6, matrixArray256_7;
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+#ifndef ZERO_BETA
+ __m512 BETAVECTOR = _mm512_set1_ps(beta);
+#endif
+
+ __m512i M512_EPI32_4 = _mm512_set1_epi32(4);
+ __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
+ __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
+
+ unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 6);
+ __mmask32 load_mask = *((__mmask32*) &load_mask_value);
+
+ // Prepare X with 2-step interleave way
+ xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1);
+ BF16_INTERLEAVE_1x32(xArray)
+
+ for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
+ accum512_0 = _mm512_setzero_ps();
+ accum512_1 = _mm512_setzero_ps();
+
+ // Load matrix
+ BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 13, idx_m, 0, x_load_mask)
+
+ matrixArray_8 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
+ matrixArray_9 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
+ matrixArray_10 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
+ matrixArray_11 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
+
+ BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 13, idx_m+8, 0, x_load_mask)
+
+ matrixArray_12 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
+ matrixArray_13 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
+ matrixArray_14 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
+ matrixArray_15 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
+
+ // interleave per 256 bits
+ BF16_INTERLEAVE256_8x32(matrixArray)
+
+ // 2-step interleave for matrix
+ BF16_INTERLEAVE_8x32(matrixArray)
+
+ // Calculate the temp result for a..p[0:15]
+ BF16_2STEP_INTERLEAVED_DOT_8x32(accum512, matrixArray, xArray)
+
+ // Reorder and add up the final result
+ result_0 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
+ result_1 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
+ result_0 = _mm512_add_ps(result_0, result_1);
+ STORE16_COMPLETE_RESULT(result_0, y+idx_m)
+ }
+
+ if (m - tag_m_16x > 7) {
+ __m512i permutevar_idx = _mm512_set_epi32(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
+ accum512_0 = _mm512_setzero_ps();
+ accum512_1 = _mm512_setzero_ps();
+
+ // Load matrix
+ BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 13, tag_m_16x, 0, x_load_mask)
+
+ matrixArray_8 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
+ matrixArray_9 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
+ matrixArray_10 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
+ matrixArray_11 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
+
+ // interleave per 256 bits
+ matrixArray_0 = _mm512_shuffle_i32x4(matrixArray_8, matrixArray_10, 0x44);
+ matrixArray_1 = _mm512_shuffle_i32x4(matrixArray_8, matrixArray_10, 0xee);
+ matrixArray_2 = _mm512_shuffle_i32x4(matrixArray_9, matrixArray_11, 0x44);
+ matrixArray_3 = _mm512_shuffle_i32x4(matrixArray_9, matrixArray_11, 0xee);
+
+ // 2-step interleave for matrix
+ BF16_INTERLEAVE_4x32(matrixArray)
+
+ // Calculate the temp result for a..h[0:15]
+ BF16_2STEP_INTERLEAVED_DOT_4x32(accum512, matrixArray, xArray)
+
+ accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
+ accum512_0 = _mm512_permutexvar_ps(permutevar_idx, accum512_0);
+ __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
+ STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
+ tag_m_16x += 8;
+ }
+
+ if (m - tag_m_16x > 3) {
+ __m256i xArray256_0, xArray256_1, xArray256_2, xArray256_3;
+ __m256 accum256_0, accum256_1;
+
+ xArray256_0 = _mm512_castsi512_si256(xArray_0);
+ xArray256_1 = _mm512_castsi512_si256(xArray_1);
+ xArray256_2 = _mm512_castsi512_si256(xArray_2);
+ xArray256_3 = _mm512_castsi512_si256(xArray_3);
+
+ accum256_0 = _mm256_setzero_ps();
+ accum256_1 = _mm256_setzero_ps();
+
+ BF16_MATRIX_MASKZ_LOAD_4x16(matrixArray256, a, 13, tag_m_16x, 0, x_load_mask)
+
+ // 2-step interleave for matrix
+ BF16_INTERLEAVE_4x16(matrixArray256)
+
+ // Calculate the temp result for a..d[0:15]
+ BF16_2STEP_INTERLEAVED_DOT_4x16(accum256, matrixArray256, xArray256)
+
+ accum256_0 = _mm256_add_ps(accum256_0, accum256_1);
+ __m128 result128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
+ STORE4_COMPLETE_RESULT(result128, y+tag_m_16x)
+ tag_m_16x += 4;
+ }
+ }
+
+ if (tag_m_16x != m) {
+ __m256i matrixArray256;
+ __m256 accum256;
+ __m128 accum128, tmp128;
+ for (BLASLONG i = tag_m_16x; i < m; i++) {
+ accum256 = _mm256_setzero_ps();
+ matrixArray256 = _mm256_maskz_loadu_epi16(x_load_mask, &a[(i)*13]); // Load 1 rows with n=13
+ accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) matrixArray256, (__m256bh) x256);
+ accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
+ tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
+ accum128 = _mm_add_ps(accum128, tmp128);
+ tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
+ accum128 = _mm_add_ps(accum128, tmp128);
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+ y[i] = alpha * accum128[0] + beta * y[i];
+#else
+ y[i] = alpha * accum128[0] + y[i];
+#endif
+#else
+#ifndef ONE_ALPHA
+ y[i] = accum128[0] * alpha;
+#else
+ y[i] = accum128[0];
+#endif
+#endif
+ }
+ }
+
+ return 0;
+}
+
+// 16 rows parallel processing BF16 GEMV kernel for n=14 && lda ineffective scenario
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+static int sbgemv_kernel_16x14_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#else
+static int sbgemv_kernel_16x14_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#endif
+#else
+#ifndef ONE_ALPHA
+static int sbgemv_kernel_16x14_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#else
+static int sbgemv_kernel_16x14(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#endif
+#endif
+{
+ BLASLONG tag_m_16x = m & (~15);
+
+ unsigned short x_load_mask_value = (((unsigned short)0xffff) >> 2);
+ __mmask16 x_load_mask = *((__mmask16*) &x_load_mask_value);
+ __m256i x256 = _mm256_maskz_loadu_epi16(x_load_mask, x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|0|0|
+
+ if (tag_m_16x > 0) {
+ __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \
+ matrixArray_8, matrixArray_9, matrixArray_10, matrixArray_11, matrixArray_12, matrixArray_13, matrixArray_14, matrixArray_15;
+ __m512i xArray_0, xArray_1, xArray_2, xArray_3;
+ __m512 accum512_0, accum512_1;
+ __m512 result_0, result_1;
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+#ifndef ZERO_BETA
+ __m512 BETAVECTOR = _mm512_set1_ps(beta);
+#endif
+
+ __m512i M512_EPI32_4 = _mm512_set1_epi32(4);
+ __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
+ __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
+ __m512i shift_idx = _mm512_set_epi32(0, 13, 12, 11, 10, 9, 8, 7, 0, 6, 5, 4, 3, 2, 1, 0);
+
+ unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 4);
+ __mmask32 load_mask = *((__mmask32*) &load_mask_value);
+
+ // Prepare X with 2-step interleave way
+ xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1);
+ BF16_INTERLEAVE_1x32(xArray)
+
+ for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
+ accum512_0 = _mm512_setzero_ps();
+ accum512_1 = _mm512_setzero_ps();
+
+ // Load matrix
+ BF16_MATRIX_MASKZ_LOAD_8x32_2(matrixArray, a, 14, idx_m, 0, load_mask)
+
+ // Pre-stage: shift the 2nd vector 1 position right for each register
+ BF16_PERMUTE_8x32_2(shift_idx, matrixArray)
+
+ // interleave per 256 bits
+ BF16_INTERLEAVE256_8x32(matrixArray)
+
+ // 2-step interleave for matrix
+ BF16_INTERLEAVE_8x32(matrixArray)
+
+ // Calculate the temp result for a..p[0:15]
+ BF16_2STEP_INTERLEAVED_DOT_8x32(accum512, matrixArray, xArray)
+
+ // Reorder and add up the final result
+ result_0 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
+ result_1 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
+ result_0 = _mm512_add_ps(result_0, result_1);
+ STORE16_COMPLETE_RESULT(result_0, y+idx_m)
+ }
+
+ if (m - tag_m_16x > 7) {
+ __m512i permutevar_idx = _mm512_set_epi32(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
+ accum512_0 = _mm512_setzero_ps();
+ accum512_1 = _mm512_setzero_ps();
+
+ // Load matrix
+ BF16_MATRIX_MASKZ_LOAD_4x32_2(matrixArray, a, 14, tag_m_16x, 0, load_mask)
+
+ // Pre-stage: shift the 2nd vector 1 position right for each register
+ BF16_PERMUTE_4x32_2(shift_idx, matrixArray)
+
+ // interleave per 256 bits
+ BF16_INTERLEAVE256_4x32(matrixArray)
+
+ // 2-step interleave for matrix
+ BF16_INTERLEAVE_4x32(matrixArray)
+
+ // Calculate the temp result for a..h[0:15]
+ BF16_2STEP_INTERLEAVED_DOT_4x32(accum512, matrixArray, xArray)
+
+ accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
+ accum512_0 = _mm512_permutexvar_ps(permutevar_idx, accum512_0);
+ __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
+ STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
+ tag_m_16x += 8;
+ }
+
+ if (m - tag_m_16x > 3) {
+ __m256i matrixArray256_0, matrixArray256_1, matrixArray256_2, matrixArray256_3, matrixArray256_4, matrixArray256_5, matrixArray256_6, matrixArray256_7;
+ __m256i xArray256_0, xArray256_1, xArray256_2, xArray256_3;
+ __m256 accum256_0, accum256_1;
+
+ xArray256_0 = _mm512_castsi512_si256(xArray_0);
+ xArray256_1 = _mm512_castsi512_si256(xArray_1);
+ xArray256_2 = _mm512_castsi512_si256(xArray_2);
+ xArray256_3 = _mm512_castsi512_si256(xArray_3);
+
+ accum256_0 = _mm256_setzero_ps();
+ accum256_1 = _mm256_setzero_ps();
+
+ BF16_MATRIX_MASKZ_LOAD_4x16(matrixArray256, a, 14, tag_m_16x, 0, x_load_mask)
+
+ // 2-step interleave for matrix
+ BF16_INTERLEAVE_4x16(matrixArray256)
+
+ // Calculate the temp result for a..d[0:15]
+ BF16_2STEP_INTERLEAVED_DOT_4x16(accum256, matrixArray256, xArray256)
+
+ accum256_0 = _mm256_add_ps(accum256_0, accum256_1);
+ __m128 result128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
+ STORE4_COMPLETE_RESULT(result128, y+tag_m_16x)
+ tag_m_16x += 4;
+ }
+ }
+
+ if (tag_m_16x != m) {
+ __m256i matrixArray256;
+ __m256 accum256;
+ __m128 accum128, tmp128;
+ for (BLASLONG i = tag_m_16x; i < m; i++) {
+ accum256 = _mm256_setzero_ps();
+ matrixArray256 = _mm256_maskz_loadu_epi16(x_load_mask, &a[(i)*14]); // Load 1 rows with n=14
+ accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) matrixArray256, (__m256bh) x256);
+ accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
+ tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
+ accum128 = _mm_add_ps(accum128, tmp128);
+ tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
+ accum128 = _mm_add_ps(accum128, tmp128);
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+ y[i] = alpha * accum128[0] + beta * y[i];
+#else
+ y[i] = alpha * accum128[0] + y[i];
+#endif
+#else
+#ifndef ONE_ALPHA
+ y[i] = accum128[0] * alpha;
+#else
+ y[i] = accum128[0];
+#endif
+#endif
+ }
+ }
+
+ return 0;
+}
+
+// 16 rows parallel processing BF16 GEMV kernel for n=15 && lda ineffective scenario
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+static int sbgemv_kernel_16x15_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#else
+static int sbgemv_kernel_16x15_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#endif
+#else
+#ifndef ONE_ALPHA
+static int sbgemv_kernel_16x15_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#else
+static int sbgemv_kernel_16x15(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#endif
+#endif
+{
+ BLASLONG tag_m_16x = m & (~15);
+
+ unsigned short x_load_mask_value = (((unsigned short)0xffff) >> 1);
+ __mmask16 x_load_mask = *((__mmask16*) &x_load_mask_value);
+ __m256i x256 = _mm256_maskz_loadu_epi16(x_load_mask, x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|x14|0|
+
+ if (tag_m_16x > 0) {
+ __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \
+ matrixArray_8, matrixArray_9, matrixArray_10, matrixArray_11, matrixArray_12, matrixArray_13, matrixArray_14, matrixArray_15;
+ __m512i xArray_0, xArray_1, xArray_2, xArray_3;
+ __m512 accum512_0, accum512_1;
+ __m512 result_0, result_1;
+
+ __m256i matrixArray256_0, matrixArray256_1, matrixArray256_2, matrixArray256_3, matrixArray256_4, matrixArray256_5, matrixArray256_6, matrixArray256_7;
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+#ifndef ZERO_BETA
+ __m512 BETAVECTOR = _mm512_set1_ps(beta);
+#endif
+
+ __m512i M512_EPI32_4 = _mm512_set1_epi32(4);
+ __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
+ __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
+
+ unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 2);
+ __mmask32 load_mask = *((__mmask32*) &load_mask_value);
+
+ // Prepare X with 2-step interleave way
+ xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1);
+ BF16_INTERLEAVE_1x32(xArray)
+
+ for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
+ accum512_0 = _mm512_setzero_ps();
+ accum512_1 = _mm512_setzero_ps();
+
+ // Load matrix
+ BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 15, idx_m, 0, x_load_mask)
+
+ matrixArray_8 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
+ matrixArray_9 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
+ matrixArray_10 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
+ matrixArray_11 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
+
+ BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 15, idx_m+8, 0, x_load_mask)
+
+ matrixArray_12 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
+ matrixArray_13 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
+ matrixArray_14 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
+ matrixArray_15 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
+
+ // interleave per 256 bits
+ BF16_INTERLEAVE256_8x32(matrixArray)
+
+ // 2-step interleave for matrix
+ BF16_INTERLEAVE_8x32(matrixArray)
+
+ // Calculate the temp result for a..p[0:15]
+ BF16_2STEP_INTERLEAVED_DOT_8x32(accum512, matrixArray, xArray)
+
+ // Reorder and add up the final result
+ result_0 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
+ result_1 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
+ result_0 = _mm512_add_ps(result_0, result_1);
+ STORE16_COMPLETE_RESULT(result_0, y+idx_m)
+ }
+
+ if (m - tag_m_16x > 7) {
+ __m512i permutevar_idx = _mm512_set_epi32(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
+ accum512_0 = _mm512_setzero_ps();
+ accum512_1 = _mm512_setzero_ps();
+
+ // Load matrix
+ BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 15, tag_m_16x, 0, x_load_mask)
+
+ matrixArray_8 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
+ matrixArray_9 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
+ matrixArray_10 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
+ matrixArray_11 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
+
+ // interleave per 256 bits
+ matrixArray_0 = _mm512_shuffle_i32x4(matrixArray_8, matrixArray_10, 0x44);
+ matrixArray_1 = _mm512_shuffle_i32x4(matrixArray_8, matrixArray_10, 0xee);
+ matrixArray_2 = _mm512_shuffle_i32x4(matrixArray_9, matrixArray_11, 0x44);
+ matrixArray_3 = _mm512_shuffle_i32x4(matrixArray_9, matrixArray_11, 0xee);
+
+ // 2-step interleave for matrix
+ BF16_INTERLEAVE_4x32(matrixArray)
+
+ // Calculate the temp result for a..h[0:15]
+ BF16_2STEP_INTERLEAVED_DOT_4x32(accum512, matrixArray, xArray)
+
+ accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
+ accum512_0 = _mm512_permutexvar_ps(permutevar_idx, accum512_0);
+ __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
+ STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
+ tag_m_16x += 8;
+ }
+
+ if (m - tag_m_16x > 3) {
+ __m256i xArray256_0, xArray256_1, xArray256_2, xArray256_3;
+ __m256 accum256_0, accum256_1;
+
+ xArray256_0 = _mm512_castsi512_si256(xArray_0);
+ xArray256_1 = _mm512_castsi512_si256(xArray_1);
+ xArray256_2 = _mm512_castsi512_si256(xArray_2);
+ xArray256_3 = _mm512_castsi512_si256(xArray_3);
+
+ accum256_0 = _mm256_setzero_ps();
+ accum256_1 = _mm256_setzero_ps();
+
+ BF16_MATRIX_MASKZ_LOAD_4x16(matrixArray256, a, 15, tag_m_16x, 0, x_load_mask)
+
+ // 2-step interleave for matrix
+ BF16_INTERLEAVE_4x16(matrixArray256)
+
+ // Calculate the temp result for a..d[0:15]
+ BF16_2STEP_INTERLEAVED_DOT_4x16(accum256, matrixArray256, xArray256)
+
+ accum256_0 = _mm256_add_ps(accum256_0, accum256_1);
+ __m128 result128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
+ STORE4_COMPLETE_RESULT(result128, y+tag_m_16x)
+ tag_m_16x += 4;
+ }
+ }
+
+ if (tag_m_16x != m) {
+ __m256i matrixArray256;
+ __m256 accum256;
+ __m128 accum128, tmp128;
+ for (BLASLONG i = tag_m_16x; i < m; i++) {
+ accum256 = _mm256_setzero_ps();
+ matrixArray256 = _mm256_maskz_loadu_epi16(x_load_mask, &a[(i)*15]); // Load 1 rows with n=15
+ accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) matrixArray256, (__m256bh) x256);
+ accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
+ tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
+ accum128 = _mm_add_ps(accum128, tmp128);
+ tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
+ accum128 = _mm_add_ps(accum128, tmp128);
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+ y[i] = alpha * accum128[0] + beta * y[i];
+#else
+ y[i] = alpha * accum128[0] + y[i];
+#endif
+#else
+#ifndef ONE_ALPHA
+ y[i] = accum128[0] * alpha;
+#else
+ y[i] = accum128[0];
+#endif
+#endif
+ }
+ }
+
+ return 0;
+}
+
+// 16 rows parallel processing BF16 GEMV kernel for n=16 && lda ineffective scenario
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+static int sbgemv_kernel_16x16_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#else
+static int sbgemv_kernel_16x16_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
+#endif
+#else
+#ifndef ONE_ALPHA
+static int sbgemv_kernel_16x16_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#else
+static int sbgemv_kernel_16x16(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
+#endif
+#endif
+{
+ BLASLONG tag_m_16x = m & (~15);
+
+ __m256i x256 = _mm256_loadu_si256(x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|x14|x15|
+
+ if (tag_m_16x > 0) {
+ __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \
+ matrixArray_8, matrixArray_9, matrixArray_10, matrixArray_11, matrixArray_12, matrixArray_13, matrixArray_14, matrixArray_15;
+ __m512i xArray_0, xArray_1, xArray_2, xArray_3;
+ __m512 accum512_0, accum512_1;
+ __m512 result_0, result_1;
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+#ifndef ZERO_BETA
+ __m512 BETAVECTOR = _mm512_set1_ps(beta);
+#endif
+
+ __m512i M512_EPI32_4 = _mm512_set1_epi32(4);
+ __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
+ __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
+
+ // Prepare X with 2-step interleave way
+ xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1);
+ BF16_INTERLEAVE_1x32(xArray)
+
+ for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
+ accum512_0 = _mm512_setzero_ps();
+ accum512_1 = _mm512_setzero_ps();
+
+ matrixArray_8 = _mm512_loadu_si512(&a[(idx_m )*16]); // Load 2 rows with n=16
+ matrixArray_9 = _mm512_loadu_si512(&a[(idx_m+2 )*16]); // Load 2 rows with n=16
+ matrixArray_10 = _mm512_loadu_si512(&a[(idx_m+4 )*16]); // Load 2 rows with n=16
+ matrixArray_11 = _mm512_loadu_si512(&a[(idx_m+6 )*16]); // Load 2 rows with n=16
+ matrixArray_12 = _mm512_loadu_si512(&a[(idx_m+8 )*16]); // Load 2 rows with n=16
+ matrixArray_13 = _mm512_loadu_si512(&a[(idx_m+10)*16]); // Load 2 rows with n=16
+ matrixArray_14 = _mm512_loadu_si512(&a[(idx_m+12)*16]); // Load 2 rows with n=16
+ matrixArray_15 = _mm512_loadu_si512(&a[(idx_m+14)*16]); // Load 2 rows with n=16
+
+ // interleave per 256 bits
+ BF16_INTERLEAVE256_8x32(matrixArray)
+
+ // 2-step interleave for matrix
+ BF16_INTERLEAVE_8x32(matrixArray)
+
+ // Calculate the temp result for a..p[0:15]
+ BF16_2STEP_INTERLEAVED_DOT_8x32(accum512, matrixArray, xArray)
+
+ // Reorder and add up the final result
+ result_0 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
+ result_1 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
+ result_0 = _mm512_add_ps(result_0, result_1);
+ STORE16_COMPLETE_RESULT(result_0, y+idx_m)
+ }
+
+ if (m - tag_m_16x > 7) {
+ __m512i permutevar_idx = _mm512_set_epi32(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
+ accum512_0 = _mm512_setzero_ps();
+ accum512_1 = _mm512_setzero_ps();
+
+ matrixArray_4 = _mm512_loadu_si512(&a[(tag_m_16x )*16]); // Load 2 rows with n=16
+ matrixArray_5 = _mm512_loadu_si512(&a[(tag_m_16x+2 )*16]); // Load 2 rows with n=16
+ matrixArray_6 = _mm512_loadu_si512(&a[(tag_m_16x+4 )*16]); // Load 2 rows with n=16
+ matrixArray_7 = _mm512_loadu_si512(&a[(tag_m_16x+6 )*16]); // Load 2 rows with n=16
+
+ // interleave per 256 bits
+ BF16_INTERLEAVE256_4x32(matrixArray)
+
+ // 2-step interleave for matrix
+ BF16_INTERLEAVE_4x32(matrixArray)
+
+ // Calculate the temp result for a..h[0:15]
+ BF16_2STEP_INTERLEAVED_DOT_4x32(accum512, matrixArray, xArray)
+
+ accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
+ accum512_0 = _mm512_permutexvar_ps(permutevar_idx, accum512_0);
+ __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
+ STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
+ tag_m_16x += 8;
+ }
+
+ if (m - tag_m_16x > 3) {
+ __m256i matrixArray256_0, matrixArray256_1, matrixArray256_2, matrixArray256_3, \
+ matrixArray256_4, matrixArray256_5, matrixArray256_6, matrixArray256_7;
+ __m256i xArray256_0, xArray256_1, xArray256_2, xArray256_3;
+ __m256 accum256_0, accum256_1;
+
+ xArray256_0 = _mm512_castsi512_si256(xArray_0);
+ xArray256_1 = _mm512_castsi512_si256(xArray_1);
+ xArray256_2 = _mm512_castsi512_si256(xArray_2);
+ xArray256_3 = _mm512_castsi512_si256(xArray_3);
+
+ accum256_0 = _mm256_setzero_ps();
+ accum256_1 = _mm256_setzero_ps();
+
+ matrixArray_0 = _mm512_loadu_si512(&a[(tag_m_16x )*16]); // Load 2 rows with n=16
+ matrixArray_1 = _mm512_loadu_si512(&a[(tag_m_16x+2 )*16]); // Load 2 rows with n=16
+
+ matrixArray256_0 = _mm512_castsi512_si256(matrixArray_0);
+ matrixArray256_1 = _mm512_extracti32x8_epi32(matrixArray_0, 0x1);
+ matrixArray256_2 = _mm512_castsi512_si256(matrixArray_1);
+ matrixArray256_3 = _mm512_extracti32x8_epi32(matrixArray_1, 0x1);
+
+ // 2-step interleave for matrix
+ BF16_INTERLEAVE_4x16(matrixArray256)
+
+ // Calculate the temp result for a..d[0:15]
+ BF16_2STEP_INTERLEAVED_DOT_4x16(accum256, matrixArray256, xArray256)
+
+ accum256_0 = _mm256_add_ps(accum256_0, accum256_1);
+ __m128 result128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
+ STORE4_COMPLETE_RESULT(result128, y+tag_m_16x)
+ tag_m_16x += 4;
+ }
+ }
+
+ if (tag_m_16x != m) {
+ __m256i matrixArray256;
+ __m256 accum256;
+ __m128 accum128, tmp128;
+ for (BLASLONG i = tag_m_16x; i < m; i++) {
+ accum256 = _mm256_setzero_ps();
+ matrixArray256 = _mm256_loadu_si256(&a[(i)*16]); // Load 1 rows with n=16
+ accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) matrixArray256, (__m256bh) x256);
+ accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
+ tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
+ accum128 = _mm_add_ps(accum128, tmp128);
+ tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
+ accum128 = _mm_add_ps(accum128, tmp128);
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+ y[i] = alpha * accum128[0] + beta * y[i];
+#else
+ y[i] = alpha * accum128[0] + y[i];
+#endif
+#else
+#ifndef ONE_ALPHA
+ y[i] = accum128[0] * alpha;
+#else
+ y[i] = accum128[0];
+#endif
+#endif
+ }
+ }
+
+ return 0;
+}
+
+// 8 rows parallel processing BF16 GEMV kernel for n>16 && lda effective scenario
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+static int sbgemv_kernel_8x16p_lda_alpha_beta(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
+#else
+static int sbgemv_kernel_8x16p_lda_alpha_one(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
+#endif
+#else
+#ifndef ONE_ALPHA
+static int sbgemv_kernel_8x16p_lda_alpha(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
+#else
+static int sbgemv_kernel_8x16p_lda(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
+#endif
+#endif
+{
+ BLASLONG tag_m_8x = m & (~7);
+
+ unsigned int load_mask_value = (((unsigned int)0xffffffff) >> (32-n));
+ __mmask32 load_mask = *((__mmask32*) &load_mask_value);
+ __m512i x512 = _mm512_maskz_loadu_epi16(load_mask, x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|x14|x15|...
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+#ifndef ZERO_BETA
+ __m512 BETAVECTOR = _mm512_set1_ps(beta);
+#endif
+
+ __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \
+ matrixArray_8, matrixArray_9, matrixArray_10, matrixArray_11, matrixArray_12, matrixArray_13, matrixArray_14, matrixArray_15;
+ __m512 accum512_0, accum512_1, accum512_2, accum512_3;
+ __m256 accum256;
+ __m128 accum128;
+
+ if (tag_m_8x > 0) {
+ __m512i xArray_0, xArray_1, xArray_2, xArray_3;
+
+ __m512i M512_EPI32_4 = _mm512_set1_epi32(4);
+ __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
+ __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
+
+ // Prepare X with 2-step interleave way
+ xArray_0 = x512;
+ BF16_INTERLEAVE_1x32(xArray)
+
+ for (BLASLONG idx_m = 0; idx_m < tag_m_8x; idx_m+=8) {
+ accum512_0 = _mm512_setzero_ps();
+ accum512_1 = _mm512_setzero_ps();
+
+ // Load 8 rows from matrix
+ BF16_MATRIX_MASKZ_LOAD_8x32(matrixArray, a, lda, idx_m, 0, load_mask)
+
+ // 2-step interleave for matrix
+ BF16_INTERLEAVE_8x32(matrixArray)
+
+ // Calculate the temp result for a..h[0:31]
+ BF16_2STEP_INTERLEAVED_DOT_8x32(accum512, matrixArray, xArray)
+
+ // Reorder and add up the final result
+ accum512_2 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
+ accum512_3 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
+ accum512_2 = _mm512_add_ps(accum512_2, accum512_3);
+ accum256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_2), _mm512_extractf32x8_ps(accum512_2, 1));
+ STORE8_COMPLETE_RESULT(accum256, y+idx_m)
+ }
+
+ if (m - tag_m_8x > 3) {
+ accum512_0 = _mm512_setzero_ps();
+ accum512_1 = _mm512_setzero_ps();
+
+ // Load 4 rows from matrix
+ BF16_MATRIX_MASKZ_LOAD_4x32(matrixArray, a, lda, tag_m_8x, 0, load_mask)
+
+ // 2-step interleave for matrix
+ BF16_INTERLEAVE_4x32(matrixArray)
+
+ // Calculate the temp result for a..d[0:31]
+ BF16_2STEP_INTERLEAVED_DOT_4x32(accum512, matrixArray, xArray)
+
+ accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
+ accum256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
+ accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
+ STORE4_COMPLETE_RESULT(accum128, y+tag_m_8x)
+ tag_m_8x += 4;
+ }
+ }
+
+ if (tag_m_8x != m) {
+ __m128 tmp128;
+ for (BLASLONG i = tag_m_8x; i < m; i++) {
+ accum512_0 = _mm512_setzero_ps();
+ matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(i)*lda]); // Load 1 rows with n=16
+ accum512_0 = _mm512_dpbf16_ps(accum512_0, (__m512bh) matrixArray_0, (__m512bh) x512);
+ accum256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
+ accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
+ tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
+ accum128 = _mm_add_ps(accum128, tmp128);
+ tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
+ accum128 = _mm_add_ps(accum128, tmp128);
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+ y[i] = alpha * accum128[0] + beta * y[i];
+#else
+ y[i] = alpha * accum128[0] + y[i];
+#endif
+#else
+#ifndef ONE_ALPHA
+ y[i] = accum128[0] * alpha;
+#else
+ y[i] = accum128[0];
+#endif
+#endif
+ }
+ }
+
+ return 0;
+}
+
+// 8 rows parallel processing BF16 GEMV kernel for big N && lda effective scenario (process before interleave)
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+static int sbgemv_kernel_1x128_lda_direct_alpha_beta(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
+#else
+static int sbgemv_kernel_1x128_lda_direct_alpha_one(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
+#endif
+#else
+#ifndef ONE_ALPHA
+static int sbgemv_kernel_1x128_lda_direct_alpha(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
+#else
+static int sbgemv_kernel_1x128_lda_direct(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
+#endif
+#endif
+{
+ BLASLONG tag_m_8x = m & (~7);
+ BLASLONG tag_n_32x = n & (~31);
+ BLASLONG tag_n_128x = n & (~127);
+
+ __m512 accum512_0, accum512_1, accum512_2, accum512_3, accum512_4, accum512_5, accum512_6, accum512_7, \
+ accum512_8, accum512_9, accum512_10, accum512_11, accum512_12, accum512_13, accum512_14, accum512_15;
+ __m512 accum512_bridge[8];
+ __m512 accum512_t_0, accum512_t_1, accum512_t_2, accum512_t_3;
+ __m256 accum256_0;
+ __m128 accum128;
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+#ifndef ZERO_BETA
+ __m512 BETAVECTOR = _mm512_set1_ps(beta);
+#endif
+
+ __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3;
+ __m512i xArray_0, xArray_1, xArray_2, xArray_3;
+
+ unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(n&31)));
+ __mmask32 tail_mask = *((__mmask32*) &tail_mask_value);
+
+ __m512i M512_EPI32_4 = _mm512_set1_epi32(4);
+ __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
+ __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
+
+ if (tag_m_8x > 0) {
+ for (BLASLONG idx_m = 0; idx_m < tag_m_8x; idx_m+=8) {
+ for (int j = idx_m; j < idx_m + 8; j++) {
+ accum512_t_0 = _mm512_setzero_ps();
+ accum512_t_1 = _mm512_setzero_ps();
+ accum512_t_2 = _mm512_setzero_ps();
+ accum512_t_3 = _mm512_setzero_ps();
+ /* Processing the main chunk with 128-elements per round */
+ for (long idx_n = 0; idx_n < tag_n_128x; idx_n += 128) {
+ BF16_MATRIX_LOAD_1x32(matrixArray_0, a, lda, j, idx_n + 0)
+ BF16_MATRIX_LOAD_1x32(matrixArray_1, a, lda, j, idx_n + 32)
+ BF16_MATRIX_LOAD_1x32(matrixArray_2, a, lda, j, idx_n + 64)
+ BF16_MATRIX_LOAD_1x32(matrixArray_3, a, lda, j, idx_n + 96)
+
+ BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n + 0)
+ BF16_VECTOR_LOAD_1x32(xArray_1, x, idx_n + 32)
+ BF16_VECTOR_LOAD_1x32(xArray_2, x, idx_n + 64)
+ BF16_VECTOR_LOAD_1x32(xArray_3, x, idx_n + 96)
+
+ BF16_DOT_1x32(accum512_t_0, matrixArray_0, xArray_0)
+ BF16_DOT_1x32(accum512_t_1, matrixArray_1, xArray_1)
+ BF16_DOT_1x32(accum512_t_2, matrixArray_2, xArray_2)
+ BF16_DOT_1x32(accum512_t_3, matrixArray_3, xArray_3)
+ }
+
+ /* Processing the remaining <128 chunk with 32-elements per round */
+ for (long idx_n = tag_n_128x; idx_n < tag_n_32x; idx_n += 32) {
+ BF16_MATRIX_LOAD_1x32(matrixArray_0, a, lda, j, idx_n)
+ BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n)
+ BF16_DOT_1x32(accum512_t_0, matrixArray_0, xArray_0)
+ }
+
+ /* Processing the remaining <32 chunk with masked 32-elements processing */
+ if ((n&31) != 0) {
+ BF16_MATRIX_MASKZ_LOAD_1x32(matrixArray_0, a, lda, j, tag_n_32x, tail_mask)
+ BF16_VECTOR_MASKZ_LOAD_1x32(xArray_0, x, tag_n_32x, tail_mask)
+ BF16_DOT_1x32(accum512_t_2, matrixArray_0, xArray_0)
+ }
+
+ /* Accumulate the 4 registers into 1 register */
+ accum512_t_0 = _mm512_add_ps(accum512_t_0, accum512_t_1);
+ accum512_t_2 = _mm512_add_ps(accum512_t_2, accum512_t_3);
+ accum512_t_0 = _mm512_add_ps(accum512_t_0, accum512_t_2);
+
+ // Temply save the result into a ZMM
+ accum512_bridge[j-idx_m] = accum512_t_0;
+ }
+
+ FP32_INTERLEAVE_8x16_ARRAY(accum512_bridge)
+ FP32_ACCUM2_8x16_ARRAY(accum512_bridge)
+ accum512_bridge[1] = _mm512_permutex2var_ps(accum512_bridge[0], idx_base_0, accum512_bridge[4]);
+ accum512_bridge[2] = _mm512_permutex2var_ps(accum512_bridge[0], idx_base_1, accum512_bridge[4]);
+ accum512_bridge[1] = _mm512_add_ps(accum512_bridge[1], accum512_bridge[2]);
+ accum256_0 = _mm256_add_ps(_mm512_castps512_ps256(accum512_bridge[1]), _mm512_extractf32x8_ps(accum512_bridge[1], 1));
+ STORE8_COMPLETE_RESULT(accum256_0, y+idx_m)
+ }
+ }
+
+ if (tag_m_8x != m) {
+ __m128 tmp128;
+ for (BLASLONG j = tag_m_8x; j < m; j++) {
+ accum512_t_0 = _mm512_setzero_ps();
+ accum512_t_1 = _mm512_setzero_ps();
+ accum512_t_2 = _mm512_setzero_ps();
+ accum512_t_3 = _mm512_setzero_ps();
+ /* Processing the main chunk with 128-elements per round */
+ for (long idx_n = 0; idx_n < tag_n_128x; idx_n += 128) {
+ BF16_MATRIX_LOAD_1x32(matrixArray_0, a, lda, j, idx_n + 0)
+ BF16_MATRIX_LOAD_1x32(matrixArray_1, a, lda, j, idx_n + 32)
+ BF16_MATRIX_LOAD_1x32(matrixArray_2, a, lda, j, idx_n + 64)
+ BF16_MATRIX_LOAD_1x32(matrixArray_3, a, lda, j, idx_n + 96)
+
+ BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n + 0)
+ BF16_VECTOR_LOAD_1x32(xArray_1, x, idx_n + 32)
+ BF16_VECTOR_LOAD_1x32(xArray_2, x, idx_n + 64)
+ BF16_VECTOR_LOAD_1x32(xArray_3, x, idx_n + 96)
+
+ BF16_DOT_1x32(accum512_t_0, matrixArray_0, xArray_0)
+ BF16_DOT_1x32(accum512_t_1, matrixArray_1, xArray_1)
+ BF16_DOT_1x32(accum512_t_2, matrixArray_2, xArray_2)
+ BF16_DOT_1x32(accum512_t_3, matrixArray_3, xArray_3)
+ }
+
+ /* Processing the remaining <128 chunk with 32-elements per round */
+ for (long idx_n = tag_n_128x; idx_n < tag_n_32x; idx_n += 32) {
+ BF16_MATRIX_LOAD_1x32(matrixArray_0, a, lda, j, idx_n)
+ BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n)
+ BF16_DOT_1x32(accum512_t_0, matrixArray_0, xArray_0)
+ }
+
+ /* Processing the remaining <32 chunk with masked 32-elements processing */
+ if ((n&31) != 0) {
+ BF16_MATRIX_MASKZ_LOAD_1x32(matrixArray_0, a, lda, j, tag_n_32x, tail_mask)
+ BF16_VECTOR_MASKZ_LOAD_1x32(xArray_0, x, tag_n_32x, tail_mask)
+ BF16_DOT_1x32(accum512_t_2, matrixArray_0, xArray_0)
+ }
+
+ /* Accumulate the 4 registers into 1 register */
+ accum512_t_0 = _mm512_add_ps(accum512_t_0, accum512_t_1);
+ accum512_t_2 = _mm512_add_ps(accum512_t_2, accum512_t_3);
+ accum512_t_0 = _mm512_add_ps(accum512_t_0, accum512_t_2);
+
+ accum256_0 = _mm256_add_ps(_mm512_castps512_ps256(accum512_t_0), _mm512_extractf32x8_ps(accum512_t_0, 1));
+ accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
+ tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
+ accum128 = _mm_add_ps(accum128, tmp128);
+ tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
+ accum128 = _mm_add_ps(accum128, tmp128);
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+ y[j] = alpha * accum128[0] + beta * y[j];
+#else
+ y[j] = alpha * accum128[0] + y[j];
+#endif
+#else
+#ifndef ONE_ALPHA
+ y[j] = accum128[0] * alpha;
+#else
+ y[j] = accum128[0];
+#endif
+#endif
+ }
+ }
+
+ return 0;
+}
+
+// 8 rows parallel processing BF16 GEMV kernel for n=32 && lda effective scenario (process before interleave)
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+static int sbgemv_kernel_8x32_lda_direct_alpha_beta(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
+#else
+static int sbgemv_kernel_8x32_lda_direct_alpha_one(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
+#endif
+#else
+#ifndef ONE_ALPHA
+static int sbgemv_kernel_8x32_lda_direct_alpha(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
+#else
+static int sbgemv_kernel_8x32_lda_direct(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
+#endif
+#endif
+{
+ BLASLONG tag_m_8x = m & (~7);
+ BLASLONG tag_n_32x = n & (~31);
+
+ __m512 accum512_0, accum512_1, accum512_2, accum512_3, accum512_4, accum512_5, accum512_6, accum512_7, \
+ accum512_8, accum512_9, accum512_10, accum512_11, accum512_12, accum512_13, accum512_14, accum512_15;
+ __m256 accum256_0;
+ __m128 accum128;
+
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+#endif
+#ifndef ZERO_BETA
+ __m512 BETAVECTOR = _mm512_set1_ps(beta);
+#endif
+
+ __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7;
+ __m512i xArray_0;
+
+ unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(n&31)));
+ __mmask32 tail_mask = *((__mmask32*) &tail_mask_value);
+
+ if (tag_m_8x > 0) {
+ __m512i M512_EPI32_4 = _mm512_set1_epi32(4);
+ __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
+ __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
+
+ for (BLASLONG idx_m = 0; idx_m < tag_m_8x; idx_m+=8) {
+ accum512_0 = _mm512_setzero_ps();
+ accum512_1 = _mm512_setzero_ps();
+ accum512_2 = _mm512_setzero_ps();
+ accum512_3 = _mm512_setzero_ps();
+ accum512_4 = _mm512_setzero_ps();
+ accum512_5 = _mm512_setzero_ps();
+ accum512_6 = _mm512_setzero_ps();
+ accum512_7 = _mm512_setzero_ps();
+
+ for (BLASLONG idx_n = 0; idx_n < tag_n_32x; idx_n+=32) {
+ // Load 8 rows from matrix
+ BF16_MATRIX_LOAD_8x32(matrixArray, a, lda, idx_m, idx_n)
+
+ // Load x
+ BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n)
+
+ // Calculate the temp result for a..h[0:31]
+ BF16_DOT_8x32(accum512, matrixArray, xArray_0)
+ }
+
+ if (tag_n_32x != n) { // Go with masked 512
+ // Load 8 rows from matrix
+ BF16_MATRIX_MASKZ_LOAD_8x32(matrixArray, a, lda, idx_m, tag_n_32x, tail_mask)
+
+ // Load x
+ BF16_VECTOR_MASKZ_LOAD_1x32(xArray_0, x, tag_n_32x, tail_mask)
+
+ // Calculate the temp result for a..h[0:31]
+ BF16_DOT_8x32(accum512, matrixArray, xArray_0)
+ }
+
+ // 2-step interleave for FP32 regsiter array
+ FP32_INTERLEAVE_8x16(accum512)
+
+ // Accumulate the 2 batch of registers into 2 register (0 and 4)
+ FP32_ACCUM2_8x16(accum512)
+
+ accum512_1 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_4);
+ accum512_2 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_4);
+ accum512_1 = _mm512_add_ps(accum512_1, accum512_2);
+ accum256_0 = _mm256_add_ps(_mm512_castps512_ps256(accum512_1), _mm512_extractf32x8_ps(accum512_1, 1));
+ STORE8_COMPLETE_RESULT(accum256_0, y+idx_m)
+ }
+ }
+
+ if (tag_m_8x != m) {
+ __m128 tmp128;
+ for (BLASLONG i = tag_m_8x; i < m; i++) {
+ accum512_0 = _mm512_setzero_ps();
+ for (BLASLONG idx_n = 0; idx_n < tag_n_32x; idx_n+=32) {
+ // Load 32 elements from matrix
+ BF16_MATRIX_LOAD_1x32(matrixArray_0, a, lda, i, idx_n)
+
+ // Load 32 elements from x
+ BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n)
+
+ // Calculate and accumulate the temp result
+ BF16_DOT_1x32(accum512_0, matrixArray_0, xArray_0)
+ }
+
+ if (tag_n_32x != n) {
+ // Load tail elements from matrix
+ BF16_MATRIX_MASKZ_LOAD_1x32(matrixArray_0, a, lda, i, tag_n_32x, tail_mask)
+
+ // Load 32 elements from x
+ BF16_VECTOR_MASKZ_LOAD_1x32(xArray_0, x, tag_n_32x, tail_mask)
+
+ // Calculate and accumulate the temp result
+ BF16_DOT_1x32(accum512_0, matrixArray_0, xArray_0)
+ }
+
+ accum256_0 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
+ accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
+ tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
+ accum128 = _mm_add_ps(accum128, tmp128);
+ tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
+ accum128 = _mm_add_ps(accum128, tmp128);
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+ y[i] = alpha * accum128[0] + beta * y[i];
+#else
+ y[i] = alpha * accum128[0] + y[i];
+#endif
+#else
+#ifndef ONE_ALPHA
+ y[i] = accum128[0] * alpha;
+#else
+ y[i] = accum128[0];
+#endif
+#endif
+ }
+ }
+
+ return 0;
+}
+
+// 8 rows parallel processing BF16 GEMV kernel for n<16 && lda effective scenario
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+static int sbgemv_kernel_8x16m_lda_alpha_beta(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
+#else
+static int sbgemv_kernel_8x16m_lda_alpha_one(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
+#endif
+#else
+#ifndef ONE_ALPHA
+static int sbgemv_kernel_8x16m_lda_alpha(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
+#else
+static int sbgemv_kernel_8x16m_lda(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
+#endif
+#endif
+{
+ BLASLONG tag_m_8x = m & (~7);
+
+ __m256i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7;
+ __m256i xArray256;
+
+ // Keep align with other kernels and macro definition, the high 256bit is never used
+#ifndef ONE_ALPHA
+ __m512 ALPHAVECTOR = _mm512_castps256_ps512(_mm256_set1_ps(alpha));
+#endif
+#ifndef ZERO_BETA
+ __m512 BETAVECTOR = _mm512_castps256_ps512(_mm256_set1_ps(beta));
+#endif
+
+ __m256 accum256_0, accum256_1, accum256_2, accum256_3, accum256_4, accum256_5, accum256_6, accum256_7, \
+ accum256_8, accum256_9, accum256_10, accum256_11, accum256_12, accum256_13, accum256_14, accum256_15;
+
+ __m256i M256_EPI32_4 = _mm256_set1_epi32(4);
+ __m256i idx_base_0 = _mm256_set_epi32(11, 10, 9, 8, 3, 2, 1, 0);
+ __m256i idx_base_1 = _mm256_add_epi32(idx_base_0, M256_EPI32_4);
+
+ unsigned short load_mask_value = (((unsigned short)0xffff) >> (16-n));
+ __mmask16 load_mask = *((__mmask16*) &load_mask_value);
+
+ if (n == 16) {
+ BF16_VECTOR_LOAD_1x16(xArray256, x, 0)
+ } else {
+ BF16_VECTOR_MASKZ_LOAD_1x16(xArray256, x, 0, load_mask)
+ }
+
+ if (n == 16) {
+ for (BLASLONG idx_m = 0; idx_m < tag_m_8x; idx_m+=8) {
+ accum256_0 = _mm256_setzero_ps();
+ accum256_1 = _mm256_setzero_ps();
+ accum256_2 = _mm256_setzero_ps();
+ accum256_3 = _mm256_setzero_ps();
+ accum256_4 = _mm256_setzero_ps();
+ accum256_5 = _mm256_setzero_ps();
+ accum256_6 = _mm256_setzero_ps();
+ accum256_7 = _mm256_setzero_ps();
+
+ BF16_MATRIX_LOAD_8x16(matrixArray, a, lda, idx_m, 0)
+
+ BF16_DOT_8x16(accum256, matrixArray, xArray256)
+
+ // 2-step interleave for FP32 regsiter array
+ FP32_INTERLEAVE_8x8(accum256)
+
+ // Accumulate the 2 batch of registers into 2 register (0 and 4)
+ FP32_ACCUM2_8x8(accum256)
+
+ accum256_1 = _mm256_permutex2var_ps(accum256_0, idx_base_0, accum256_4);
+ accum256_2 = _mm256_permutex2var_ps(accum256_0, idx_base_1, accum256_4);
+ accum256_1 = _mm256_add_ps(accum256_1, accum256_2);
+
+ STORE8_COMPLETE_RESULT(accum256_1, y+idx_m)
+ }
+
+ if (tag_m_8x != m) {
+ __m128 accum128, tmp128;
+ for (BLASLONG i = tag_m_8x; i < m; i++) {
+ accum256_0 = _mm256_setzero_ps();
+ matrixArray_0 = _mm256_loadu_si256(&a[(i)*lda]); // Load 1 rows with n=16
+ accum256_0 = _mm256_dpbf16_ps(accum256_0, (__m256bh) matrixArray_0, (__m256bh) xArray256);
+ accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
+ tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
+ accum128 = _mm_add_ps(accum128, tmp128);
+ tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
+ accum128 = _mm_add_ps(accum128, tmp128);
+ y[i] += accum128[0] * alpha;
+ }
+ }
+ } else {
+ for (BLASLONG idx_m = 0; idx_m < tag_m_8x; idx_m+=8) {
+ accum256_0 = _mm256_setzero_ps();
+ accum256_1 = _mm256_setzero_ps();
+ accum256_2 = _mm256_setzero_ps();
+ accum256_3 = _mm256_setzero_ps();
+ accum256_4 = _mm256_setzero_ps();
+ accum256_5 = _mm256_setzero_ps();
+ accum256_6 = _mm256_setzero_ps();
+ accum256_7 = _mm256_setzero_ps();
+
+ BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray, a, lda, idx_m, 0, load_mask)
+
+ BF16_DOT_8x16(accum256, matrixArray, xArray256)
+
+ // 2-step interleave for FP32 regsiter array
+ FP32_INTERLEAVE_8x8(accum256)
+
+ // Accumulate the 2 batch of registers into 2 register (0 and 4)
+ FP32_ACCUM2_8x8(accum256)
+
+ accum256_1 = _mm256_permutex2var_ps(accum256_0, idx_base_0, accum256_4);
+ accum256_2 = _mm256_permutex2var_ps(accum256_0, idx_base_1, accum256_4);
+ accum256_1 = _mm256_add_ps(accum256_1, accum256_2);
+
+ STORE8_COMPLETE_RESULT(accum256_1, y+idx_m)
+ }
+
+ if (tag_m_8x != m) {
+ __m128 accum128, tmp128;
+ for (BLASLONG i = tag_m_8x; i < m; i++) {
+ accum256_0 = _mm256_setzero_ps();
+ matrixArray_0 = _mm256_maskz_loadu_epi16(load_mask, &a[(i)*lda]); // Load 1 rows with n=16
+ accum256_0 = _mm256_dpbf16_ps(accum256_0, (__m256bh) matrixArray_0, (__m256bh) xArray256);
+ accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
+ tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
+ accum128 = _mm_add_ps(accum128, tmp128);
+ tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
+ accum128 = _mm_add_ps(accum128, tmp128);
+#ifndef ZERO_BETA
+#ifndef ONE_BETA
+ y[i] = alpha * accum128[0] + beta * y[i];
+#else
+ y[i] = alpha * accum128[0] + y[i];
+#endif
+#else
+#ifndef ONE_ALPHA
+ y[i] = accum128[0] * alpha;
+#else
+ y[i] = accum128[0];
+#endif
+#endif
+ }
+ }
+ }
+
+ return 0;
+}