/*******************************************************************************
* Copyright 2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "jit_avx512_core_gemv_s8u8s32_kern.hpp"

#ifdef _WIN32
static const bool is_windows = true;
#else
static const bool is_windows = false;
#endif
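// For orientation, a minimal scalar sketch of the operation this kernel
// computes (an assumption drawn from the kernel structure, not part of the
// build). The s8u8 variant takes signed A and unsigned X, as the vpdpbusd
// operand order in vnni() below suggests; the swapped (u8s8) instantiation
// reverses the signedness. As in update_c(), beta is only tested against zero.
#if 0
#include <cstdint>

static void gemv_s8u8s32_ref(int64_t m, int64_t n, const int8_t *a,
                             int64_t lda, const uint8_t *x, float beta,
                             int32_t *y) {
    for (int64_t i = 0; i < m; i++) {
        int32_t acc = 0;
        for (int64_t j = 0; j < n; j++)
            acc += (int32_t)a[i * lda + j] * (int32_t)x[j];
        y[i] = (beta == 0.f) ? acc : y[i] + acc;
    }
}
#endif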
void jit_avx512_core_gemv_s8u8s32_kern::vnni(Xbyak::Zmm acc, Xbyak::Zmm b,
                                             Xbyak::Zmm a, Xbyak::Zmm tmp,
                                             Xbyak::Zmm one, bool swap,
                                             int use_vnni) {
    if (use_vnni) {
        // single-instruction path: acc += 4 adjacent u8*s8 products per lane
        if (swap)
            vpdpbusd(acc, a, b);
        else
            vpdpbusd(acc, b, a);
    } else {
        // emulation path for machines without VNNI
        if (swap)
            vpmaddubsw(tmp, a, b);
        else
            vpmaddubsw(tmp, b, a);
        vpmaddwd(tmp, tmp, one);
        vpaddd(acc, tmp, acc);
    }
}
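// A scalar sketch of what one 32-bit lane of the sequence above computes,
// whether via vpdpbusd or via the vpmaddubsw/vpmaddwd(one)/vpaddd emulation
// (illustrative only; u holds 4 unsigned bytes, s holds 4 signed bytes):
#if 0
#include <cstdint>

static int32_t dp4a_lane(int32_t acc, const uint8_t u[4], const int8_t s[4]) {
    // Note: the emulation path can saturate in vpmaddubsw's intermediate
    // int16 pair sums, so it is not bit-exact with VNNI for extreme inputs.
    for (int k = 0; k < 4; k++)
        acc += (int32_t)u[k] * (int32_t)s[k];
    return acc;
}
#endif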
void jit_avx512_core_gemv_s8u8s32_kern::n_loop_body(int start_a_idx, int start_acc_idx,
                                                    int b_idx, int nreg_acc,
                                                    Xbyak::Reg64 A, Xbyak::Reg64 lda,
                                                    Xbyak::Reg64 X, Xbyak::Zmm tmp,
                                                    Xbyak::Zmm one, bool swap, int use_vnni,
                                                    int use_mask, Xbyak::Opmask mask_n) {
    int i;
    int nreg_A = nreg_acc / 2 + (nreg_acc % 2);

    // load X + j
    if (use_mask)
        vmovdqu8(Xbyak::Zmm(b_idx) | mask_n | T_z, ptr[X]);
    else
        vmovdqu8(Xbyak::Zmm(b_idx), ptr[X]);

    xor_(r14, r14);
    // load the first nreg_A rows of A
    for (i = 0; i < nreg_A; i++) {
        if (use_mask)
            vmovdqu8(Xbyak::Zmm(start_a_idx + i) | mask_n | T_z, ptr[A + r14]);
        else
            vmovdqu8(Xbyak::Zmm(start_a_idx + i), ptr[A + r14]);
        add(r14, lda);
    }

    for (i = 0; i < nreg_A; i++) {
        // vnni (acc, b, a, tmp, one, swap, use_vnni)
        vnni(Xbyak::Zmm(start_acc_idx + i), Xbyak::Zmm(b_idx),
             Xbyak::Zmm(start_a_idx + i), tmp, one, swap, use_vnni);
    }

    // the second half of the A rows reuses the same A registers
    for (i = 0; i < nreg_A - (nreg_acc % 2); i++) {
        if (use_mask)
            vmovdqu8(Xbyak::Zmm(start_a_idx + i) | mask_n | T_z, ptr[A + r14]);
        else
            vmovdqu8(Xbyak::Zmm(start_a_idx + i), ptr[A + r14]);
        add(r14, lda);
    }

    for (i = 0; i < nreg_A - (nreg_acc % 2); i++) {
        vnni(Xbyak::Zmm(start_acc_idx + i + nreg_A), Xbyak::Zmm(b_idx),
             Xbyak::Zmm(start_a_idx + i), tmp, one, swap, use_vnni);
    }
}
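// A rough scalar model of one n_loop_body step (illustrative, assuming the
// non-swap s8u8 variant; register blocking and the masked tail are omitted):
// each accumulator register gathers per-lane partial dot products of one row
// of A against the same 64-byte chunk of X; lanes are reduced in update_c.
#if 0
#include <cstdint>

static void n_step_ref(int rows, const uint8_t *x, const int8_t *a,
                       int64_t lda, int32_t acc[][16]) {
    for (int r = 0; r < rows; r++)
        for (int lane = 0; lane < 16; lane++)
            for (int k = 0; k < 4; k++)
                acc[r][lane] += (int32_t)x[lane * 4 + k]
                        * (int32_t)a[r * lda + lane * 4 + k];
}
#endif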
void jit_avx512_core_gemv_s8u8s32_kern::shuffle_and_add(Xbyak::Zmm dest, Xbyak::Zmm A,
                                                        Xbyak::Zmm B, Xbyak::Zmm C,
                                                        Xbyak::Zmm D) {
    vshufi32x4(dest, A, C, 0x44);
    vshufi32x4(A, A, C, 0xEE);
    vpaddd(C, dest, A); // C = A0 + A2|A1 + A3|C0 + C2|C1 + C3

    vshufi32x4(dest, B, D, 0x44);
    vshufi32x4(B, B, D, 0xEE);
    vpaddd(D, dest, B); // D = B0 + B2|B1 + B3|D0 + D2|D1 + D3

    vshufi32x4(A, C, D, 0x88);
    vshufi32x4(B, C, D, 0xDD);
    vpaddd(dest, A, B); // dest = SAi|SBi|SCi|SDi
}
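// Illustrative scalar model of shuffle_and_add (not part of the build):
// treating each Zmm as four 128-bit blocks, the three shuffle/add rounds
// reduce the four input registers so that dest holds, per 128-bit block, the
// block-wise sum of one input register (see the inline comments above for
// the exact placement).
#if 0
#include <cstdint>

struct blk128 { int32_t v[4]; };       // one 128-bit block of a Zmm

// block-wise sum of the four blocks of one register
static blk128 sum_blocks(const blk128 r[4]) {
    blk128 s = r[0];
    for (int b = 1; b < 4; b++)
        for (int i = 0; i < 4; i++)
            s.v[i] += r[b].v[i];
    return s;
}
#endif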
void jit_avx512_core_gemv_s8u8s32_kern::update_c(int nreg_acc, Xbyak::Reg64 Y,
                                                 int start_a_idx, int start_acc_idx,
                                                 Xbyak::Xmm beta, int use_mask,
                                                 Xbyak::Opmask mask_m) {
    int l, i, k, j, last_it;
    Xbyak::Label store_label;

    l = 0;
    for (k = 0; k < nreg_acc; k += 8) {
        for (i = 0, j = k; i < 8; i += 4, j += 2) {
            if (j < nreg_acc) {
                // shuffle per block of 4 registers
                shuffle_and_add(Xbyak::Zmm(start_a_idx + l),        // dest
                                Xbyak::Zmm(start_acc_idx + j),      // A = acc0
                                Xbyak::Zmm(start_acc_idx + 1 + j),  // B = acc1
                                Xbyak::Zmm(start_acc_idx + 4 + j),  // C = acc4
                                Xbyak::Zmm(start_acc_idx + 5 + j)); // D = acc5

                // extract low and high from dest and hadd
                vextracti32x8(Xbyak::Ymm(start_a_idx + l + 1), Xbyak::Zmm(start_a_idx + l), 0);
                vextracti32x8(Xbyak::Ymm(start_a_idx + l + 2), Xbyak::Zmm(start_a_idx + l), 1);
                vphaddd(Xbyak::Ymm(start_a_idx + l),
                        Xbyak::Ymm(start_a_idx + l + 1),
                        Xbyak::Ymm(start_a_idx + l + 2));
            }
            l++;
        }

        vphaddd(Xbyak::Ymm(start_a_idx + l),
                Xbyak::Ymm(start_a_idx + l - 2),
                Xbyak::Ymm(start_a_idx + l - 1));
        l++;
    }

    // if beta != 0, add the current values of C before storing
    vxorps(Xbyak::Ymm(start_a_idx),
           Xbyak::Ymm(start_a_idx),
           Xbyak::Ymm(start_a_idx));
    vucomiss(beta, Xbyak::Ymm(start_a_idx));
    je(store_label, T_NEAR);

    // beta != 0: load Y and add the accumulators to it
    for (k = 0, l = 2; k < nreg_acc; k += 8, l += 3) {
        last_it = (k + 8) > nreg_acc;
        if (use_mask && last_it)
            vmovdqu32(Xbyak::Ymm(start_a_idx + k / 8) | mask_m | T_z, ptr[Y + (k / 8) * 32]);
        else
            vmovdqu32(Xbyak::Ymm(start_a_idx + k / 8), ptr[Y + (k / 8) * 32]);

        vpaddd(Xbyak::Ymm(start_a_idx + l),
               Xbyak::Ymm(start_a_idx + l),
               Xbyak::Ymm(start_a_idx + k / 8));
    }

    // store the result to Y
    aligned_label(store_label);
    for (k = 0, l = 2; k < nreg_acc; k += 8, l += 3) {
        last_it = (k + 8) > nreg_acc;
        if (use_mask && last_it)
            vmovdqu32(ptr[Y + (k / 8) * 32], Xbyak::Ymm(start_a_idx + l) | mask_m);
        else
            vmovdqu32(ptr[Y + (k / 8) * 32], Xbyak::Ymm(start_a_idx + l));
    }
}
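// The tail of update_c implements, per output element (a sketch; beta is
// only compared against zero above, so any nonzero beta behaves as 1):
#if 0
#include <cstdint>

static int32_t c_update_ref(int32_t y_old, int32_t acc, float beta) {
    return (beta == 0.f) ? acc : y_old + acc;
}
#endif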
template <typename T>
T jit_avx512_core_gemv_s8u8s32_kern::generate(int use_vnni) {

    Xbyak::Opmask mask_n = k1, mask_m = k2;
    Xbyak::Label one_label, m_tail_label, m_loop_label, n_loop_label;
    Xbyak::Label n_tail_label, update_c_label, end_label;
    constexpr unsigned int n_labels = (1 << unroll_m) - 1;
    Xbyak::Label m_tail_label_case[n_labels];
    Xbyak::Label n_loop_label_case[n_labels];
    Xbyak::Label n_tail_label_case[n_labels];
    Xbyak::Label update_c_label_case[n_labels];

    int i, ii;
    Xbyak::Zmm one, tmp;

    Xbyak::Reg64 n = abi_param2, m = abi_param1;
    Xbyak::Reg64 A = is_windows ? abi_param4 : abi_param3;
    Xbyak::Reg64 lda = is_windows ? abi_param3 : abi_param4;
    Xbyak::Reg64 X = is_windows ? rdi : r8;
    Xbyak::Xmm beta = xmm1;
    Xbyak::Reg64 Y = is_windows ? rsi : r9;

    // swap == true generates the u8s8 variant (signedness of A and X swapped)
    bool swap = !std::is_same<T, gemv_s8u8s32_kernel_t>::value;
    int zmm_idx = 0;
    int nreg_acc = 1 << unroll_m;
    int nreg_A = 1 << (unroll_m - 1);
    int nreg_A_acc = nreg_acc + nreg_A;

    if (!use_vnni) {
        // reserve tmp and a zmm register of ones for the emulation path
        tmp = Xbyak::Zmm(zmm_idx);
        one = Xbyak::Zmm(zmm_idx + 1);
        zmm_idx += 2; // one + tmp
    }
    preamble();

    if (is_windows) {
        // Windows ABI: lda, X, beta and Y are passed on the stack
        mov(lda, ptr[rsp + get_size_of_abi_save_regs() + 40]);
        mov(X, ptr[rsp + get_size_of_abi_save_regs() + 48]);
        movss(beta, ptr[rsp + get_size_of_abi_save_regs() + 56]);
        mov(Y, ptr[rsp + get_size_of_abi_save_regs() + 64]);
    }
    if (use_vnni && !is_windows) {
        // assumption from the surrounding loads: beta arrives in xmm0 here
        movaps(beta, xmm0);
    }
    // set mask_n for loads/stores on the X vector (tail of the N dimension)
    mov(rax, (1 << unroll_n) - 1);
    kmovq(k3, rax); // k3 = full un mask, used by ktestq in the N tails

    and_(rax, n); // rax contains n & ((1 << unroll_n) - 1)
    mov(rbx, 1);
    shlx(rbx, rbx, rax);
    sub(rbx, 1); // rbx = (1 << (n % un)) - 1
    kmovq(mask_n, rbx);
    // mask_n set (AVX512 only), can use rax and rbx again

    // set mask_m for update of the C matrix
    // load/store on the C matrix use Ymm so tail according to Ymm size
    mov(rax, 7); // 8 * 32 = 256 Ymm size
    and_(rax, m); // rax contains m & 7
    mov(rbx, 1);
    shlx(rbx, rbx, rax);
    sub(rbx, 1); // e.g. m % 8 == 3 -> mask_m = 0b00000111
    kmovq(mask_m, rbx);
    // mask_m set (AVX512 only), can use rax and rbx again
    // setup the register of ones when VNNI instructions are not available
    if (!use_vnni)
        vmovdqu16(one, ptr[rip + one_label]);
    // M loop
    // base pointer for A: rax contains a + i * lda
    // loop stop when rax >= a + (m & mask_um) * lda = rbx
    // loop increment r10 = um * lda
    // rbp = Y + i
    mov(rax, A); // i = 0
    mov(rbx, m);
    and_(rbx, mask_um);
    imul(rbx, lda);
    add(rbx, A);
    mov(r10, lda);
    sal(r10, unroll_m);
    mov(rbp, Y);

    // N loop
    // base pointer for X: r11 contains x + j
    // loop stop when r11 >= x + (n & mask_un) = r12
    // loop increment un
    // r13 = rax + j = A + i * lda + j
    mov(r12, n);
    and_(r12, mask_un);
    add(r12, X);

    aligned_label(m_loop_label);
    cmp(rax, rbx);
    jge(m_tail_label, T_NEAR);

    // enter M loop: zero the accumulators
    for (i = 0; i < nreg_acc; i++) {
        vpxorq(Xbyak::Zmm(i + zmm_idx + nreg_A),
               Xbyak::Zmm(i + zmm_idx + nreg_A),
               Xbyak::Zmm(i + zmm_idx + nreg_A));
    }

    mov(r11, X); // j = 0
    mov(r13, rax);

    aligned_label(n_loop_label);
    cmp(r11, r12);
    jge(n_tail_label, T_NEAR);

    // enter N loop
    n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, nreg_acc,
                r13, lda, r11, tmp, one, swap, use_vnni, 0, mask_n);

    // increment the X and A pointers with un
    add(r11, 1 << unroll_n);
    add(r13, 1 << unroll_n);
    jmp(n_loop_label, T_NEAR);
    // end N loop
    // N tail
    aligned_label(n_tail_label);
    ktestq(mask_n, k3);
    je(update_c_label, T_NEAR);
    n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, nreg_acc,
                r13, lda, r11, tmp, one, swap, use_vnni, 1, mask_n);

    // update C matrix
    aligned_label(update_c_label);
    update_c(nreg_acc, rbp, zmm_idx, zmm_idx + nreg_A, beta, 0, mask_m);

    // increment rax with um * lda and rbp with um * sizeof(int32)
    add(rax, r10);
    add(rbp, 1 << (unroll_m + 2));
    jmp(m_loop_label, T_NEAR);
    // end M loop
    // M tail: one case per possible m_tail value, using m_tail accumulators
    aligned_label(m_tail_label);

    // r10 will contain m_tail = m % um = m & ((1 << unroll_m) - 1)
    mov(r10, m);
    and_(r10, (1 << unroll_m) - 1);
    for (ii = 1; ii < (1 << unroll_m); ii++) {
        aligned_label(m_tail_label_case[ii - 1]);
        cmp(r10, ii);
        if (ii == (1 << unroll_m) - 1)
            jne(end_label, T_NEAR);
        else
            jne(m_tail_label_case[ii], T_NEAR);

        // m_tail = ii, use ii accumulators
        for (i = 0; i < ii; i++) {
            vpxorq(Xbyak::Zmm(i + zmm_idx + nreg_A),
                   Xbyak::Zmm(i + zmm_idx + nreg_A),
                   Xbyak::Zmm(i + zmm_idx + nreg_A));
        }

        // N loop
        mov(r11, X); // j = 0
        mov(r13, rax);

        aligned_label(n_loop_label_case[ii - 1]);
        cmp(r11, r12);
        jge(n_tail_label_case[ii - 1], T_NEAR);

        n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, ii, r13,
                    lda, r11, tmp, one, swap, use_vnni, 0, mask_n);

        // increment the X and A pointers with un
        add(r11, 1 << unroll_n);
        add(r13, 1 << unroll_n);
        jmp(n_loop_label_case[ii - 1], T_NEAR);

        // N tail
        aligned_label(n_tail_label_case[ii - 1]);
        ktestq(mask_n, k3);
        je(update_c_label_case[ii - 1], T_NEAR);
        n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, ii, r13,
                    lda, r11, tmp, one, swap, use_vnni, 1, mask_n);

        // update C matrix
        aligned_label(update_c_label_case[ii - 1]);
        update_c(ii, rbp, zmm_idx, zmm_idx + nreg_A, beta, 1, mask_m);

        if (ii < ((1 << unroll_m) - 1))
            jmp(end_label, T_NEAR);
    }
    aligned_label(end_label);

    postamble();

    if (!use_vnni) {
        // table of int16 ones for the vpmaddwd-based emulation path
        aligned_label(one_label);
        for (i = 0; i < size_vec_reg / 8; i++)
            dq(0x0001000100010001);
    }

    return (T) getCode();
}
template jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t
jit_avx512_core_gemv_s8u8s32_kern::generate<jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t>(int);

template jit_avx512_core_gemv_s8u8s32_kern::gemv_u8s8s32_kernel_t
jit_avx512_core_gemv_s8u8s32_kern::generate<jit_avx512_core_gemv_s8u8s32_kern::gemv_u8s8s32_kernel_t>(int);
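// Hypothetical usage sketch (illustrative, not from this file): generate()
// JITs the kernel and returns it as a callable of the requested kernel type;
// the exact parameter list of that type is declared in the header.
#if 0
jit_avx512_core_gemv_s8u8s32_kern jitter;
auto gemv_s8u8s32 = jitter.generate<
        jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t>(
        /* use_vnni = */ 1);
#endif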