Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm / s8x8s32 / jit_avx512_core_gemm_s8u8s32_kern.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "jit_avx512_core_gemm_s8u8s32_kern.hpp"
18
19
20 #ifdef _WIN32
21 static const bool is_windows = 1;
22 #else
23 static const bool is_windows = 0;
24 #endif
25
26
27 namespace mkldnn {
28 namespace impl {
29 namespace cpu {
30
31 using namespace Xbyak;
32
33
34
35
36 // Convert between vector register lengths.
37 static inline Xmm make_xmm(const Xmm &v) { return Xmm(v.getIdx()); }
38 static inline Ymm make_ymm(const Xmm &v) { return Ymm(v.getIdx()); }
39
40 // Load from or store to C.
41 void jit_avx512_core_gemm_s8u8s32_kern::c_load(const Xbyak::Xmm &dst,
42     const Xbyak::Address &src, int nelems)
43 {
44     switch (nelems) {
45     default: vmovups(dst, src); break;
46     case 8:  vmovups(make_ymm(dst), src); break;
47     case 4:  vmovups(make_xmm(dst), src); break;
48     case 2:  vmovlps(make_xmm(dst), src); break;
49     case 1:  vmovss(make_xmm(dst), src); break;
50     }
51 }
52 void jit_avx512_core_gemm_s8u8s32_kern::c_store(const Xbyak::Address &dst,
53     const Xbyak::Xmm &src, int nelems)
54 {
55     switch (nelems) {
56     default: vmovups(dst, src); break;
57     case 8:  vmovups(dst, make_ymm(src)); break;
58     case 4:  vmovups(dst, make_xmm(src)); break;
59     case 2:  vmovsd(dst, make_xmm(src)); break;
60     case 1:  vmovss(dst, make_xmm(src)); break;
61     }
62 }
63
64 // Perform length-4 dot product accumulations of unsigned and signed bytes
65 //  in parallel.
66 // Use vpdpbusd if VNNI available, otherwise emulate.
67 void jit_avx512_core_gemm_s8u8s32_kern::dot_product(const Xmm &dst,
68     const Xmm &src1, const Xmm &src2)
69 {
70     if (vnni)
71         vpdpbusd(dst, src1, src2);
72     else {
73         vpmaddubsw(dp_scratch, src1, src2);
74         vpmaddwd(dp_scratch, ones, dp_scratch);
75         vpaddd(dst, dst, dp_scratch);
76     }
77 }
78
79 // Inner kernel.
80 void jit_avx512_core_gemm_s8u8s32_kern::kernel_loop(int unroll_m, int unroll_n,
81         bool cfetch)
82 {
83     int um_vecs = (unroll_m + 15) >> 4;
84     Label label_kernel_loop;
85
86     L_aligned(label_kernel_loop); {
87         for (int h = 0; h < 4; h++) {
88             for (int j = 0; j < unroll_n; j++) {
89                 const Zmm b = b_regs[j & 1];
90
91                 vpbroadcastd(b, ptr[BO + isize *
92                     (2 * j + 2 * h * unroll_n - offset_b)]);
93                 dot_product(c_regs[0][j], b, a_regs[0]);
94
95                 if (j == 1 && !(h & 1))
96                     prefetch_b(ptr[BO + isize * (prefetch_size_b
97                         + 2 * h * unroll_n - offset_b)]);
98                 else if (j % 3 == 0)
99                     prefetch_a(ptr[AO + isize * (prefetch_size_a
100                         + 32 * (j / 3) + 2 * h * unroll_m - offset_a)]);
101
102                 for (int i = 1; i < um_vecs; i++)
103                     dot_product(c_regs[i][j], b, a_regs[i]);
104
105                 if (cfetch && (j == std::min(1, unroll_n - 1))) {
106                     if (h == 3)
107                         lea(CO2, ptr[CO2 + LDC]);
108                     else if (h < um_vecs)
109                         prefetch_c(ptr[CO2 + (16 * h * size)]);
110                 }
111
112                 if (h == 3 && j == std::min(3, unroll_n - 1))
113                     lea(AA, ptr[AA + (32 * isize)]);
114             }
115
116             for (int i = 0; i < um_vecs; i++)
117                 vmovups(a_regs[i], ptr[AO + isize *
118                 (32 * i + 2 * (h + 1) * unroll_m - offset_a)]);
119
120             if (h == 2)
121                 prefetch_x(ptr[AA - (offset_a * isize)]);
122         }
123
124         add(AO, 8 * isize * unroll_m);
125         add(BO, 8 * isize * unroll_n);
126         sub(LoopCount, 1);
127         jg(label_kernel_loop, T_NEAR);
128     }
129 }
130
131 // k remainder loop for kernel.
132 void jit_avx512_core_gemm_s8u8s32_kern::remainder_kernel(int unroll_m,
133         int unroll_n, int unroll_k, int bwidth)
134 {
135     if ((unroll_m > IGEMM_UNROLL_M) || (unroll_n > IGEMM_UNROLL_N)
136             || (unroll_m < 0)  || (unroll_n < 0))
137         return;
138
139     int um_vecs = (unroll_m + 15) >> 4;
140
141     for (int h = 0; h < unroll_k; h++) {
142         for (int j = 0; j < unroll_n; j++) {
143             Zmm b = b_regs[j & 1];
144             auto b_src = ptr[BO + (-isize * offset_b
145                 + bwidth * (j + h * unroll_n))];
146
147             switch (bwidth) {
148             case 4:
149                 vpbroadcastd(b, b_src);
150                 break;
151             case 2:
152                 vpbroadcastw(b, b_src);
153                 break;
154             case 1:
155                 vpbroadcastb(b, b_src);
156                 break;
157             }
158             for (int i = 0; i < um_vecs; i++)
159                 dot_product(c_regs[i][j], b, a_regs[i]);
160         }
161
162         if (unroll_k > 1) {
163             for (int i = 0; i < um_vecs; i++)
164                 vmovups(a_regs[i], ptr[AO + isize * (32 * i
165                     + (h + 1) * 2 * unroll_m - offset_a)]);
166         }
167     }
168
169     add(AO, unroll_k * unroll_m * bwidth);
170     add(BO, unroll_k * unroll_n * bwidth);
171 }
172
173 // Inner loop.
174 void jit_avx512_core_gemm_s8u8s32_kern::innerloop(int unroll_m, int unroll_n)
175 {
176     if ((unroll_m > IGEMM_UNROLL_M) || (unroll_n > IGEMM_UNROLL_N)
177             || (unroll_m < 0)  || (unroll_n < 0))
178         return;
179
180     int um_vecs = (unroll_m + 15) >> 4;
181     int stage1 = unroll_n, stage2 = unroll_n;
182
183     Label label_kernel_loop_1, label_k_main_loop_2, label_kernel_loop_2;
184     Label label_k_main_loop_3, label_kernel_loop_3;
185     Label label_k_remainder_loop_begin, label_k_rem_4, label_k_rem_2;
186     Label label_k_rem_1, label_update_begin;
187
188     mov(AO, A);
189     for (int i = 0; i < um_vecs; i++)
190         vmovups(a_regs[i], ptr[AO + isize * (32 * i - offset_a)]);
191
192     mov(LoopCount, K);
193     sar(LoopCount, 4);
194     jle(label_k_remainder_loop_begin, T_NEAR);
195
196     // Main k loops, broken into three parts to time C prefetching.
197     sub(LoopCount, stage1 + stage2);
198     jle(label_k_main_loop_2, T_NEAR);
199
200     kernel_loop(unroll_m, unroll_n, false);
201
202     L_aligned(label_k_main_loop_2);
203     lea(CO2, ptr[CO1 + size * (std::min(unroll_m, 16) - 1)]);
204     add(LoopCount, stage1);
205     jle(label_k_main_loop_3, T_NEAR);
206
207     kernel_loop(unroll_m, unroll_n, true);
208
209     L_aligned(label_k_main_loop_3);
210     lea(CO2, ptr[CO1 + size * (std::min(unroll_m, 16) - 1)]);
211     add(LoopCount, stage2);
212     jle(label_k_remainder_loop_begin, T_NEAR);
213
214     kernel_loop(unroll_m, unroll_n, true);
215
216     // k remainder handling
217     L_aligned(label_k_remainder_loop_begin);
218     mov(LoopCount, K);
219     test(LoopCount, 8);
220     je(label_k_rem_4, T_NEAR);
221
222     remainder_kernel(unroll_m, unroll_n, 2, 4);
223
224     L_aligned(label_k_rem_4);
225     mov(LoopCount, K);
226     test(LoopCount, 4);
227     je(label_k_rem_2, T_NEAR);
228
229     remainder_kernel(unroll_m, unroll_n, 1, 4);
230
231     L_aligned(label_k_rem_2);
232     mov(LoopCount, K);
233     test(LoopCount, 2);
234     je(label_k_rem_1, T_NEAR);
235
236     Zmm zero = zmm6;
237     Zmm tmp = zmm5;
238
239     vpxorq(zero, zero, zero);
240     for (int i = 0; i < um_vecs; i++) {
241         Zmm a = a_regs[i];
242         vbroadcasti64x4(a, ptr[AO + isize * (16 * i - offset_a)]);
243         vpunpcklwd(tmp, a, zero);
244         vpunpckhwd(a, a, zero);
245         vshufi32x4(a, tmp, a, 0x44);
246         vshufi32x4(a, a, a, 0xD8);
247     }
248
249     remainder_kernel(unroll_m, unroll_n, 1, 2);
250
251     L_aligned(label_k_rem_1);
252     mov(LoopCount, K);
253     test(LoopCount, 1);
254     je(label_update_begin, T_NEAR);
255
256     vpxorq(zero, zero, zero);
257     for (int i = 0; i < um_vecs; i++) {
258         Zmm a = a_regs[i];
259         vbroadcasti32x4(a, ptr[AO + isize * (8 * i - offset_a)]);
260         vpunpcklbw(tmp, a, zero);
261         vpunpckhbw(a, a, zero);
262         vinsertf128(make_ymm(a), make_ymm(tmp), make_xmm(a), 1);
263         vpunpcklwd(tmp, a, zero);
264         vpunpckhwd(a, a, zero);
265         vshufi32x4(a, tmp, a, 0x44);
266         vshufi32x4(a, a, a, 0xD8);
267     }
268
269     remainder_kernel(unroll_m, unroll_n, 1, 1);
270
271     // Add offsets and update C.
272     L_aligned(label_update_begin);
273
274     if (enable_offset_r) {
275         // Add row offsets.
276         mov(rax, coffset_ry);
277         for (int j = 0; j < unroll_n; j++) {
278             Zmm row_offset = zmm0;
279
280             vbroadcastss(row_offset, ptr[rax + size * j]);
281
282             for (int i = 0; i < um_vecs; i++)
283                 vpaddd(c_regs[i][j], c_regs[i][j], row_offset);
284         }
285         add(coffset_ry, size * unroll_n);
286     }
287
288     if (enable_offset_c) {
289         // Add column offsets.
290         mov(rax, coffset_cy);
291         for (int i = 0; i < um_vecs; i++) {
292             Zmm col_offset = zmm0;
293
294             c_load(col_offset, ptr[rax + size * 16 * i], unroll_m);
295
296             for (int j = 0; j < unroll_n; j++)
297                 vpaddd(c_regs[i][j], c_regs[i][j], col_offset);
298         }
299     }
300
301     Reg64 LDC3 = rax;
302     lea(LDC3, ptr[LDC + LDC * 2]);
303
304     // C updates.
305     int c_off_j = 0;
306     for (int j = 0; j < unroll_n; j++) {
307         if (j > 0 && (j & 3) == 0) {
308             lea(CO1, ptr[CO1 + LDC * 4]);
309             c_off_j += 4;
310         }
311
312         int jj = j - c_off_j;
313
314         for (int i = 0; i < um_vecs; i++) {
315             Zmm c = c_regs[i][j];
316             Zmm c_old = zmm0;
317             decltype(LDC * jj) ldc_mult = (jj == 3) ? LDC3 : LDC * jj;
318
319             auto c_mem = ptr[CO1 + ldc_mult + size * 16 * i];
320
321             if (beta_zero)
322                 c_store(c_mem, c, unroll_m);
323             else {
324                 c_load(c_old, c_mem, unroll_m);
325                 vpaddd(c_old, c, c_old);
326                 c_store(c_mem, c_old, unroll_m);
327             }
328
329             vpxorq(c, c, c);
330         }
331     }
332
333     lea(CO1, ptr[CO1 + LDC * (unroll_n - c_off_j)]);
334 }
335
336 // Outer loop.
337 void jit_avx512_core_gemm_s8u8s32_kern::outerloop(int unroll_x, int unroll_y,
338     Label *&cur_outerloop_label)
339 {
340     Label label_m_loop, label_n_loop, label_n_remainder_loops[6];
341
342     L(*cur_outerloop_label);
343     cur_outerloop_label++;
344     if (unroll_x >= IGEMM_UNROLL_M) {
345         mov(J, M);
346         cmp(J, unroll_x);
347         jl(*cur_outerloop_label, T_NEAR);    // Jump to next outerloop label.
348     } else {
349         test(J, unroll_x);
350         jle(*cur_outerloop_label, T_NEAR);
351     }
352
353     L_aligned(label_m_loop); {
354         mov(CO1, C);
355         add(C, unroll_x * size);
356
357         mov(BO, B);
358
359         mov(AA, K);
360         imul(AA, AA, unroll_x * isize);
361         lea(AA, ptr[A + AA + isize * prefetch_size_a]);
362
363         if (enable_offset_c) {
364             mov(rax, coffset_cx);
365             mov(coffset_cy, rax);
366             add(rax, unroll_x * size);
367             mov(coffset_cx, rax);
368         }
369
370         if (enable_offset_r) {
371             mov(rax, coffset_rx);
372             mov(coffset_ry, rax);
373         }
374
375         mov(I, N);
376         cmp(I, unroll_y);
377         jl(label_n_remainder_loops[0], T_NEAR);
378
379         L_aligned(label_n_loop); {
380             innerloop(unroll_x, unroll_y);
381             sub(I, unroll_y);
382             cmp(I, unroll_y);
383             jge(label_n_loop, T_NEAR);
384         }
385
386         align(16);
387
388         int label_idx = 0;
389         for (int uy = 16; uy > 0; uy >>= 1) {
390             L(label_n_remainder_loops[label_idx++]);
391             if (unroll_y > uy) {
392                 test(I, uy);
393                 jle(label_n_remainder_loops[label_idx], T_NEAR);
394
395                 innerloop(unroll_x, uy);
396                 align(16);
397             }
398         }
399         L(label_n_remainder_loops[label_idx]);
400
401         mov(A, AO);
402         if (unroll_x >= IGEMM_UNROLL_M) {
403             sub(J, unroll_x);
404             cmp(J, unroll_x);
405             jge(label_m_loop);
406         }
407     }
408
409     align(16);
410 }
411
412 void jit_avx512_core_gemm_s8u8s32_kern::generate()
413 {
414     // Prologue
415     preamble();
416     sub(rsp, stack_alloc_size);
417
418     if (is_windows) {
419         mov(A, arg_a);
420         mov(B, arg_b);
421     }
422
423     mov(C, arg_c);
424     mov(LDC, arg_ldc);
425
426     sub(A, -offset_a * isize);
427     sub(B, -offset_b * isize);
428
429     mov(M, qword[M]);
430     mov(N, qword[N]);
431     mov(K, qword[K]);
432
433     lea(LDC, ptr[LDC * size]);
434
435     if (enable_offset_c) {
436         mov(rax, arg_coffset_c);
437         mov(coffset_cx, rax);
438     }
439     if (enable_offset_r) {
440         mov(rax, arg_coffset_r);
441         mov(coffset_rx, rax);
442     }
443
444     for (int i = 0; i < (max_unroll_m >> 4); i++) {
445         for (int j = 0; j < max_unroll_n; j++) {
446             auto &c = c_regs[i][j];
447             vpxorq(c, c, c);
448         }
449     }
450
451     if (!vnni) {
452         mov(rax, 1);
453         movq(make_xmm(ones), rax);
454         vpbroadcastw(ones, make_xmm(ones));
455     }
456
457     Label outerloop_labels[8];
458     Label *cur_outerloop_label = &outerloop_labels[0];
459
460     // Main m loop.
461     outerloop(IGEMM_UNROLL_M, IGEMM_UNROLL_N, cur_outerloop_label);
462
463     // m remainder loops.
464     for (int um = 32; um > 0; um >>= 1)
465         if (IGEMM_UNROLL_M > um)
466             outerloop(um, IGEMM_UNROLL_N, cur_outerloop_label);
467
468     L(*cur_outerloop_label);
469
470     // Epilogue.
471     add(rsp, stack_alloc_size);
472     postamble();
473 }
474
475
476 jit_avx512_core_gemm_s8u8s32_kern::jit_avx512_core_gemm_s8u8s32_kern(bool
477         beta_zero_, bool enable_offset_c_, bool enable_offset_r_) :
478     jit_generator(nullptr, 100000), arg_a(0), arg_b(0), arg_c(0), arg_ldc(0),
479     arg_coffset_c(0), arg_coffset_r(0), coffset_cx(0), coffset_cy(0),
480     coffset_rx(0), coffset_ry(0)
481 {
482     beta_zero = beta_zero_;
483     enable_offset_c = enable_offset_c_;
484     enable_offset_r = enable_offset_r_;
485     vnni = mayiuse(avx512_core_vnni);
486
487     // Assign integer registers
488     M = is_windows ? rcx : rdi;
489     N = is_windows ? rdx : rsi;
490     K = is_windows ? r8 : rdx;
491     A = is_windows ? rsi : r8;
492     B = r9;
493     C = r10;
494     LDC = r11;
495     I = r12;
496     J = r13;
497     LoopCount = rax;
498     AO = r14;
499     BO = r15;
500     CO1 = rbx;
501     CO2 = rbp;
502     AA = is_windows ? rdi : rcx;
503
504     // Assign vector registers
505     dp_scratch = zmm6;
506     ones = zmm7;
507     for (int i = 0; i < (max_unroll_m >> 4); i++)
508         a_regs[i] = Zmm(i);
509     b_regs[0] = zmm4;
510     b_regs[1] = zmm5;
511
512     int rn = 0;
513     for (int i = 0; i < (max_unroll_m >> 4); i++)
514         for (int j = 0; j < max_unroll_n; j++)
515             c_regs[i][j] = Zmm(8 + rn++);
516
517     // Assign stack variables.
518     stack_alloc_size = 32;
519     auto args_offset = stack_alloc_size + get_size_of_abi_save_regs()
520         + 8 + (is_windows ? 48 : 0);
521
522     arg_a         = ptr[rsp + (args_offset - 16)];
523     arg_b         = ptr[rsp + (args_offset - 8)];
524     arg_c         = ptr[rsp + (args_offset + 0)];
525     arg_ldc       = ptr[rsp + (args_offset + 8)];
526     arg_coffset_c = ptr[rsp + (args_offset + 16)];
527     arg_coffset_r = ptr[rsp + (args_offset + 24)];
528
529     coffset_cx = qword[rsp + 0];
530     coffset_cy = qword[rsp + 8];
531     coffset_rx = qword[rsp + 16];
532     coffset_ry = qword[rsp + 24];
533
534     generate();
535 }
536
537 }
538 }
539 }