/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "jit_avx512_core_gemm_s8u8s32_kern.hpp"

#ifdef _WIN32
static const bool is_windows = 1;
#else
static const bool is_windows = 0;
#endif

using namespace Xbyak;

// Convert between vector register lengths.
static inline Xmm make_xmm(const Xmm &v) { return Xmm(v.getIdx()); }
static inline Ymm make_ymm(const Xmm &v) { return Ymm(v.getIdx()); }

// Load from or store to C.
void jit_avx512_core_gemm_s8u8s32_kern::c_load(const Xbyak::Xmm &dst,
        const Xbyak::Address &src, int nelems)
{
    switch (nelems) {
    default: vmovups(dst, src); break;
    case 8: vmovups(make_ymm(dst), src); break;
    case 4: vmovups(make_xmm(dst), src); break;
    case 2: vmovlps(make_xmm(dst), src); break;
    case 1: vmovss(make_xmm(dst), src); break;
    }
}

void jit_avx512_core_gemm_s8u8s32_kern::c_store(const Xbyak::Address &dst,
        const Xbyak::Xmm &src, int nelems)
{
    switch (nelems) {
    default: vmovups(dst, src); break;
    case 8: vmovups(dst, make_ymm(src)); break;
    case 4: vmovups(dst, make_xmm(src)); break;
    case 2: vmovsd(dst, make_xmm(src)); break;
    case 1: vmovss(dst, make_xmm(src)); break;
    }
}

// Perform length-4 dot product accumulations of unsigned and signed bytes.
// Use vpdpbusd if VNNI available, otherwise emulate.
void jit_avx512_core_gemm_s8u8s32_kern::dot_product(const Xmm &dst,
        const Xmm &src1, const Xmm &src2)
{
    if (vnni)
        vpdpbusd(dst, src1, src2);
    else {
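        // Emulation: vpmaddubsw multiplies u8 (src1) by s8 (src2) and sums
        // adjacent pairs to 16 bits (with saturation); vpmaddwd against a
        // vector of 16-bit ones widens adjacent pairs to 32 bits; vpaddd
        // accumulates the result into dst.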
        vpmaddubsw(dp_scratch, src1, src2);
        vpmaddwd(dp_scratch, ones, dp_scratch);
        vpaddd(dst, dst, dp_scratch);
    }
}
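
// Inner kernel loop. Each iteration accumulates a 16-wide slice of k into
// the C accumulators: four dot-product steps of 4 k values each, with the
// A registers reloaded between steps and A, B and C prefetched on the way.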
void jit_avx512_core_gemm_s8u8s32_kern::kernel_loop(int unroll_m, int unroll_n,
        bool cfetch)
{
    int um_vecs = (unroll_m + 15) >> 4;
    Label label_kernel_loop;

    L_aligned(label_kernel_loop); {
        for (int h = 0; h < 4; h++) {
            for (int j = 0; j < unroll_n; j++) {
                const Zmm b = b_regs[j & 1];

                vpbroadcastd(b, ptr[BO + isize *
                    (2 * j + 2 * h * unroll_n - offset_b)]);
                dot_product(c_regs[0][j], b, a_regs[0]);

                if (j == 1 && !(h & 1))
                    prefetch_b(ptr[BO + isize * (prefetch_size_b
                        + 2 * h * unroll_n - offset_b)]);
                else
                    prefetch_a(ptr[AO + isize * (prefetch_size_a
                        + 32 * (j / 3) + 2 * h * unroll_m - offset_a)]);

                for (int i = 1; i < um_vecs; i++)
                    dot_product(c_regs[i][j], b, a_regs[i]);
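
                // With cfetch set, touch the next lines of C (through CO2)
                // so the block is already cached for the final update.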
                if (cfetch && (j == std::min(1, unroll_n - 1))) {
                    if (h == 3)
                        lea(CO2, ptr[CO2 + LDC]);
                    else if (h < um_vecs)
                        prefetch_c(ptr[CO2 + (16 * h * size)]);
                }

                if (h == 3 && j == std::min(3, unroll_n - 1))
                    lea(AA, ptr[AA + (32 * isize)]);
            }

            for (int i = 0; i < um_vecs; i++)
                vmovups(a_regs[i], ptr[AO + isize *
                    (32 * i + 2 * (h + 1) * unroll_m - offset_a)]);
        }

        prefetch_x(ptr[AA - (offset_a * isize)]);

        add(AO, 8 * isize * unroll_m);
        add(BO, 8 * isize * unroll_n);
        sub(LoopCount, 1);
        jg(label_kernel_loop, T_NEAR);
    }
}

// k remainder loop for kernel.
void jit_avx512_core_gemm_s8u8s32_kern::remainder_kernel(int unroll_m,
        int unroll_n, int unroll_k, int bwidth)
{
    if ((unroll_m > IGEMM_UNROLL_M) || (unroll_n > IGEMM_UNROLL_N)
            || (unroll_m < 0) || (unroll_n < 0))
        return;

    int um_vecs = (unroll_m + 15) >> 4;
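
    // bwidth is the byte width of each broadcast element of B (dword,
    // word, or byte) and equals the number of k values covered by one
    // dot-product step in this remainder pass.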
    for (int h = 0; h < unroll_k; h++) {
        for (int j = 0; j < unroll_n; j++) {
            Zmm b = b_regs[j & 1];
            auto b_src = ptr[BO + (-isize * offset_b
                + bwidth * (j + h * unroll_n))];

            switch (bwidth) {
            case 4:
                vpbroadcastd(b, b_src);
                break;
            case 2:
                vpbroadcastw(b, b_src);
                break;
            case 1:
                vpbroadcastb(b, b_src);
                break;
            }

            for (int i = 0; i < um_vecs; i++)
                dot_product(c_regs[i][j], b, a_regs[i]);
        }

        if (unroll_k > 1)
            for (int i = 0; i < um_vecs; i++)
                vmovups(a_regs[i], ptr[AO + isize * (32 * i
                    + (h + 1) * 2 * unroll_m - offset_a)]);
    }

    add(AO, unroll_k * unroll_m * bwidth);
    add(BO, unroll_k * unroll_n * bwidth);
}
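
// Inner loop: accumulate one unroll_m x unroll_n block of C over the full
// k range, then apply any offsets and write the block back to memory.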
void jit_avx512_core_gemm_s8u8s32_kern::innerloop(int unroll_m, int unroll_n)
{
    if ((unroll_m > IGEMM_UNROLL_M) || (unroll_n > IGEMM_UNROLL_N)
            || (unroll_m < 0) || (unroll_n < 0))
        return;

    int um_vecs = (unroll_m + 15) >> 4;
    int stage1 = unroll_n, stage2 = unroll_n;

    Label label_kernel_loop_1, label_k_main_loop_2, label_kernel_loop_2;
    Label label_k_main_loop_3, label_kernel_loop_3;
    Label label_k_remainder_loop_begin, label_k_rem_4, label_k_rem_2;
    Label label_k_rem_1, label_update_begin;

    mov(AO, A);
    for (int i = 0; i < um_vecs; i++)
        vmovups(a_regs[i], ptr[AO + isize * (32 * i - offset_a)]);

    mov(LoopCount, K);
    sar(LoopCount, 4);
    jle(label_k_remainder_loop_begin, T_NEAR);

    // Main k loops, broken into three parts to time C prefetching.
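    // The first part runs without prefetching; the last stage1 + stage2
    // iterations run with cfetch enabled so the C block is in cache by
    // the time it is updated.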
    sub(LoopCount, stage1 + stage2);
    jle(label_k_main_loop_2, T_NEAR);

    kernel_loop(unroll_m, unroll_n, false);

    L_aligned(label_k_main_loop_2);
    lea(CO2, ptr[CO1 + size * (std::min(unroll_m, 16) - 1)]);
    add(LoopCount, stage1);
    jle(label_k_main_loop_3, T_NEAR);

    kernel_loop(unroll_m, unroll_n, true);

    L_aligned(label_k_main_loop_3);
    lea(CO2, ptr[CO1 + size * (std::min(unroll_m, 16) - 1)]);
    add(LoopCount, stage2);
    jle(label_k_remainder_loop_begin, T_NEAR);

    kernel_loop(unroll_m, unroll_n, true);

    // k remainder handling
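    // The k remainder is peeled in cascading stages of 8, 4, 2 and 1
    // values, each guarded by a bit test on the count.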
    L_aligned(label_k_remainder_loop_begin);
    mov(LoopCount, K);
    test(LoopCount, 8);
    je(label_k_rem_4, T_NEAR);

    remainder_kernel(unroll_m, unroll_n, 2, 4);

    L_aligned(label_k_rem_4);
    test(LoopCount, 4);
    je(label_k_rem_2, T_NEAR);

    remainder_kernel(unroll_m, unroll_n, 1, 4);

    L_aligned(label_k_rem_2);
    test(LoopCount, 2);
    je(label_k_rem_1, T_NEAR);
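
    // Two k values remain: zero-extend each 16-bit pair of packed A to a
    // dword so the 4-byte dot-product path can be reused with word-wide
    // broadcasts of B. (zero and tmp below are scratch registers that are
    // otherwise idle at this point.)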
    Zmm zero = zmm4;
    Zmm tmp = zmm5;

    vpxorq(zero, zero, zero);
    for (int i = 0; i < um_vecs; i++) {
        const Zmm a = a_regs[i];

        vbroadcasti64x4(a, ptr[AO + isize * (16 * i - offset_a)]);
        vpunpcklwd(tmp, a, zero);
        vpunpckhwd(a, a, zero);
        vshufi32x4(a, tmp, a, 0x44);
        vshufi32x4(a, a, a, 0xD8);
    }

    remainder_kernel(unroll_m, unroll_n, 1, 2);

    L_aligned(label_k_rem_1);
    test(LoopCount, 1);
    je(label_update_begin, T_NEAR);
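
    // One k value remains: zero-extend each remaining byte of packed A to
    // a dword and pair it with byte-wide broadcasts of B.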
    vpxorq(zero, zero, zero);
    for (int i = 0; i < um_vecs; i++) {
        const Zmm a = a_regs[i];

        vbroadcasti32x4(a, ptr[AO + isize * (8 * i - offset_a)]);
        vpunpcklbw(tmp, a, zero);
        vpunpckhbw(a, a, zero);
        vinsertf128(make_ymm(a), make_ymm(tmp), make_xmm(a), 1);
        vpunpcklwd(tmp, a, zero);
        vpunpckhwd(a, a, zero);
        vshufi32x4(a, tmp, a, 0x44);
        vshufi32x4(a, a, a, 0xD8);
    }

    remainder_kernel(unroll_m, unroll_n, 1, 1);

    // Add offsets and update C.
    L_aligned(label_update_begin);

    if (enable_offset_r) {
        // Add row offsets.
        mov(rax, coffset_ry);
        for (int j = 0; j < unroll_n; j++) {
            Zmm row_offset = zmm0;

            vbroadcastss(row_offset, ptr[rax + size * j]);

            for (int i = 0; i < um_vecs; i++)
                vpaddd(c_regs[i][j], c_regs[i][j], row_offset);
        }
        add(coffset_ry, size * unroll_n);
    }

    if (enable_offset_c) {
        // Add column offsets.
        mov(rax, coffset_cy);
        for (int i = 0; i < um_vecs; i++) {
            Zmm col_offset = zmm0;

            c_load(col_offset, ptr[rax + size * 16 * i], unroll_m);

            for (int j = 0; j < unroll_n; j++)
                vpaddd(c_regs[i][j], c_regs[i][j], col_offset);
        }
    }

    lea(LDC3, ptr[LDC + LDC * 2]);
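
    // C is updated four columns at a time: CO1 advances by 4 * LDC per
    // group and LDC3 (3 * LDC) addresses the last column of each group.
    // (c_old below reuses zmm0, which is free once the offset adds are done.)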
    int c_off_j = 0;
    Zmm c_old = zmm0;

    for (int j = 0; j < unroll_n; j++) {
        if (j > 0 && (j & 3) == 0) {
            lea(CO1, ptr[CO1 + LDC * 4]);
            c_off_j += 4;
        }

        int jj = j - c_off_j;

        for (int i = 0; i < um_vecs; i++) {
            Zmm c = c_regs[i][j];

            decltype(LDC * jj) ldc_mult = (jj == 3) ? LDC3 : LDC * jj;

            auto c_mem = ptr[CO1 + ldc_mult + size * 16 * i];

            if (beta_zero)
                c_store(c_mem, c, unroll_m);
            else {
                c_load(c_old, c_mem, unroll_m);
                vpaddd(c_old, c, c_old);
                c_store(c_mem, c_old, unroll_m);
            }

            vpxorq(c, c, c);
        }
    }

    lea(CO1, ptr[CO1 + LDC * (unroll_n - c_off_j)]);
}
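
// Outer loop: blocks the matrix over M (by unroll_x) and N (by unroll_y)
// and calls innerloop on each block; cur_outerloop_label chains to the
// entry point of the next (smaller) m-remainder version.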
void jit_avx512_core_gemm_s8u8s32_kern::outerloop(int unroll_x, int unroll_y,
        Label *&cur_outerloop_label)
{
    Label label_m_loop, label_n_loop, label_n_remainder_loops[6];

    L(*cur_outerloop_label);
    cur_outerloop_label++;
    if (unroll_x >= IGEMM_UNROLL_M) {
        cmp(M, unroll_x);
        jl(*cur_outerloop_label, T_NEAR); // Jump to next outerloop label.
    } else {
        test(M, unroll_x);
        jle(*cur_outerloop_label, T_NEAR);
    }

    L_aligned(label_m_loop); {
        mov(CO1, C);
        add(C, unroll_x * size);

        mov(BO, B);

        mov(AA, K);
        imul(AA, AA, unroll_x * isize);
        lea(AA, ptr[A + AA + isize * prefetch_size_a]);

        if (enable_offset_c) {
            mov(rax, coffset_cx);
            mov(coffset_cy, rax);
            add(rax, unroll_x * size);
            mov(coffset_cx, rax);
        }

        if (enable_offset_r) {
            mov(rax, coffset_rx);
            mov(coffset_ry, rax);
        }

        mov(I, N);
        cmp(I, unroll_y);
        jl(label_n_remainder_loops[0], T_NEAR);

        L_aligned(label_n_loop); {
            innerloop(unroll_x, unroll_y);
            sub(I, unroll_y);
            cmp(I, unroll_y);
            jge(label_n_loop, T_NEAR);
        }

        int label_idx = 0;
        for (int uy = 16; uy > 0; uy >>= 1) {
            L(label_n_remainder_loops[label_idx++]);
            if (unroll_y > uy) {
                test(I, uy);
                jle(label_n_remainder_loops[label_idx], T_NEAR);

                innerloop(unroll_x, uy);
            }
        }
        L(label_n_remainder_loops[label_idx]);

        imul(rax, K, unroll_x * isize);
        add(A, rax);

        if (unroll_x >= IGEMM_UNROLL_M) {
            sub(M, unroll_x);
            cmp(M, unroll_x);
            jge(label_m_loop, T_NEAR);
        }
    }
}

void jit_avx512_core_gemm_s8u8s32_kern::generate()
{
    // Prologue.
    preamble();
    sub(rsp, stack_alloc_size);

    if (is_windows) {
        mov(A, arg_a);
        mov(B, arg_b);
    }

    mov(C, arg_c);
    mov(LDC, arg_ldc);

    sub(A, -offset_a * isize);
    sub(B, -offset_b * isize);

    lea(LDC, ptr[LDC * size]);

    if (enable_offset_c) {
        mov(rax, arg_coffset_c);
        mov(coffset_cx, rax);
    }
    if (enable_offset_r) {
        mov(rax, arg_coffset_r);
        mov(coffset_rx, rax);
    }
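
    // Clear the C accumulator registers.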
    for (int i = 0; i < (max_unroll_m >> 4); i++) {
        for (int j = 0; j < max_unroll_n; j++) {
            auto &c = c_regs[i][j];
            vpxorq(c, c, c);
        }
    }
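
    // Build a vector of 16-bit ones for the pre-VNNI dot-product emulation.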
    mov(rax, 1);
    movq(make_xmm(ones), rax);
    vpbroadcastw(ones, make_xmm(ones));

    Label outerloop_labels[8];
    Label *cur_outerloop_label = &outerloop_labels[0];

    outerloop(IGEMM_UNROLL_M, IGEMM_UNROLL_N, cur_outerloop_label);

    // m remainder loops.
    for (int um = 32; um > 0; um >>= 1)
        if (IGEMM_UNROLL_M > um)
            outerloop(um, IGEMM_UNROLL_N, cur_outerloop_label);

    L(*cur_outerloop_label);

    // Epilogue.
    add(rsp, stack_alloc_size);
    postamble();
}

jit_avx512_core_gemm_s8u8s32_kern::jit_avx512_core_gemm_s8u8s32_kern(
        bool beta_zero_, bool enable_offset_c_, bool enable_offset_r_)
    : jit_generator(nullptr, 100000), arg_a(0), arg_b(0), arg_c(0),
    arg_ldc(0), arg_coffset_c(0), arg_coffset_r(0), coffset_cx(0),
    coffset_cy(0), coffset_rx(0), coffset_ry(0)
{
    beta_zero = beta_zero_;
    enable_offset_c = enable_offset_c_;
    enable_offset_r = enable_offset_r_;
    vnni = mayiuse(avx512_core_vnni);

    // Assign integer registers
    M = is_windows ? rcx : rdi;
    N = is_windows ? rdx : rsi;
    K = is_windows ? r8 : rdx;
    A = is_windows ? rsi : r8;
    // (The remaining pointers and counters: any non-conflicting GPR
    //  assignment works.)
    B = r9;
    C = r10;
    LDC = r11;
    I = r12;
    AO = r13;
    BO = r14;
    CO1 = r15;
    CO2 = rbx;
    LDC3 = rbp;
    LoopCount = rax;
    AA = is_windows ? rdi : rcx;

    // Assign vector registers
    dp_scratch = zmm6;
    ones = zmm7;
    for (int i = 0; i < (max_unroll_m >> 4); i++)
        a_regs[i] = Zmm(i);
    b_regs[0] = zmm3;
    b_regs[1] = zmm4;

    int rn = 0;
    for (int i = 0; i < (max_unroll_m >> 4); i++)
        for (int j = 0; j < max_unroll_n; j++)
            c_regs[i][j] = Zmm(8 + rn++);
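
    // The accumulators occupy zmm8 upward, one register per (i, j) sub-tile.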

    // Assign stack variables.
    stack_alloc_size = 32;
    auto args_offset = stack_alloc_size + get_size_of_abi_save_regs()
        + 8 + (is_windows ? 48 : 0);
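
    // Stack arguments live above the saved ABI registers and the return
    // address; on Windows the caller additionally reserves shadow space.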
    arg_a = ptr[rsp + (args_offset - 16)];
    arg_b = ptr[rsp + (args_offset - 8)];
    arg_c = ptr[rsp + (args_offset + 0)];
    arg_ldc = ptr[rsp + (args_offset + 8)];
    arg_coffset_c = ptr[rsp + (args_offset + 16)];
    arg_coffset_r = ptr[rsp + (args_offset + 24)];

    coffset_cx = qword[rsp + 0];
    coffset_cy = qword[rsp + 8];
    coffset_rx = qword[rsp + 16];
    coffset_ry = qword[rsp + 24];
}