Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm / s8x8s32 / jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp
1 /*******************************************************************************
2  * Copyright 2019 Intel Corporation
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *******************************************************************************/
16
17 #include "jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp"
18
19 #ifdef _WIN32
20 #define is_windows 1
21 #else
22 #define is_windows 0
23 #endif
24
25 namespace mkldnn {
26 namespace impl {
27 namespace cpu {
28
29 void jit_avx512_core_gemv_s8u8s32_kern::vnni(Xbyak::Zmm acc, Xbyak::Zmm b,
30                                              Xbyak::Zmm a, Xbyak::Zmm tmp,
31                                              Xbyak::Zmm one, bool swap,
32                                              int use_vnni) {
33
34     if (use_vnni) {
35         if (swap)
36             vpdpbusd(acc, a, b);
37         else
38             vpdpbusd(acc, b, a);
39     }
40
41     else {
42         if (swap)
43             vpmaddubsw(tmp, a, b);
44         else
45             vpmaddubsw(tmp, b, a);
46         vpmaddwd(tmp, tmp, one);
47         vpaddd(acc, tmp, acc);
48     }
49
50 }
51
52 void jit_avx512_core_gemv_s8u8s32_kern::n_loop_body(int start_a_idx, int start_acc_idx,
53                                                     int b_idx, int nreg_acc,
54                                                     Xbyak::Reg64 A, Xbyak::Reg64 lda,
55                                                     Xbyak::Reg64 X, Xbyak::Zmm tmp,
56                                                     Xbyak::Zmm one, bool swap, int use_vnni,
57                                                     int use_mask, Xbyak::Opmask mask_n) {
58
59     int i;
60     int nreg_A = nreg_acc / 2 + (nreg_acc % 2);
61
62     // load X + j
63     if (use_mask)
64         vmovdqu8(Xbyak::Zmm(b_idx) | mask_n | T_z, ptr[X]);
65     else
66         vmovdqu8(Xbyak::Zmm(b_idx), ptr[X]);
67
68     xor_(r14, r14);
69     // load values of A
70     for (i = 0; i < nreg_A; i++) {
71         if (use_mask)
72             vmovdqu8(Xbyak::Zmm(start_a_idx + i) | mask_n | T_z, ptr[A + r14]);
73         else
74             vmovdqu8(Xbyak::Zmm(start_a_idx + i), ptr[A + r14]);
75         add(r14, lda);
76     }
77
78     for (i = 0; i < nreg_A; i++) {
79         // vnni (acc, b, a, tmp, one, swap, use_vnni)
80         vnni(Xbyak::Zmm(start_acc_idx + i), Xbyak::Zmm(b_idx),
81              Xbyak::Zmm(start_a_idx + i), tmp, one, swap, use_vnni);
82     }
83
84     for (i = 0; i < nreg_A - (nreg_acc % 2); i++) {
85         if (use_mask)
86             vmovdqu8(Xbyak::Zmm(start_a_idx + i) | mask_n | T_z, ptr[A + r14]);
87         else
88             vmovdqu8(Xbyak::Zmm(start_a_idx + i), ptr[A + r14]);
89         add(r14, lda);
90     }
91
92     for (i = 0; i < nreg_A - (nreg_acc % 2); i++) {
93         vnni(Xbyak::Zmm(start_acc_idx + i + nreg_A), Xbyak::Zmm(b_idx),
94              Xbyak::Zmm(start_a_idx + i), tmp, one, swap, use_vnni);
95     }
96
97 }
98
99 void jit_avx512_core_gemv_s8u8s32_kern::shuffle_and_add(Xbyak::Zmm dest, Xbyak::Zmm A,
100                                                         Xbyak::Zmm B, Xbyak::Zmm C,
101                                                         Xbyak::Zmm D) {
102
103     vshufi32x4(dest, A, C, 0x44);
104     vshufi32x4(A, A, C, 0xEE);
105     vpaddd(C, dest, A); // C = A0 + A2|A1 + A3|C0 + C2|C1 + C3
106
107     vshufi32x4(dest, B, D, 0x44);
108     vshufi32x4(B, B, D, 0xEE);
109     vpaddd(D, dest, B); // D = B0 + B2|B1 + B3|D0 + D2|D1 + D3
110
111     vshufi32x4(A, C, D, 0x88);
112     vshufi32x4(B, C, D, 0xDD);
113     vpaddd(dest, A, B); // dest = SAi|SBi|SCi|SDi
114
115 }
116
117 void jit_avx512_core_gemv_s8u8s32_kern::update_c(int nreg_acc, Xbyak::Reg64 Y,
118                                                  int start_a_idx, int start_acc_idx,
119                                                  Xbyak::Xmm beta, int use_mask,
120                                                  Xbyak::Opmask mask_m) {
121
122     int l, i, k, j, last_it;
123     Xbyak::Label store_label;
124
125     l = 0;
126     for (k = 0; k < nreg_acc; k += 8) {
127         for (i = 0, j = k; i < 8; i += 4, j += 2) {
128             if (j < nreg_acc) {
129                 // shuffle per block of 4 registers
130                 shuffle_and_add(Xbyak::Zmm(start_a_idx + l), // dest
131                                 Xbyak::Zmm(start_acc_idx + j), // A = acc0
132                                 Xbyak::Zmm(start_acc_idx + 1 + j), // B = acc1
133                                 Xbyak::Zmm(start_acc_idx + 4 + j), // C = acc4
134                                 Xbyak::Zmm(start_acc_idx + 5 + j)); // D = acc5
135
136                 // extract low and high from dest and hadd
137                 vextracti32x8(Xbyak::Ymm(start_a_idx + l + 1), Xbyak::Zmm(start_a_idx + l), 0);
138                 vextracti32x8(Xbyak::Ymm(start_a_idx + l + 2), Xbyak::Zmm(start_a_idx + l), 1);
139                 vphaddd(Xbyak::Ymm(start_a_idx + l),
140                         Xbyak::Ymm(start_a_idx + l + 1),
141                         Xbyak::Ymm(start_a_idx + l + 2));
142             }
143             l++;
144         }
145
146         vphaddd(Xbyak::Ymm(start_a_idx + l),
147                 Xbyak::Ymm(start_a_idx + l - 2),
148                 Xbyak::Ymm(start_a_idx + l - 1));
149
150         l++;
151     }
152
153     // eventually add with C and store new value
154     vxorps(Xbyak::Ymm(start_a_idx),
155            Xbyak::Ymm(start_a_idx),
156            Xbyak::Ymm(start_a_idx));
157     vucomiss(beta, Xbyak::Ymm(start_a_idx));
158     je(store_label, T_NEAR);
159
160     // beta = 1
161     for (k = 0, l = 2; k < nreg_acc; k += 8, l += 3) {
162         // load Y and add
163         last_it = (k + 8) > nreg_acc;
164         if (use_mask && last_it)
165             vmovdqu32(Xbyak::Ymm(start_a_idx + k / 8) | mask_m | T_z, ptr[Y + (k / 8) * 32]);
166         else
167             vmovdqu32(Xbyak::Ymm(start_a_idx + k / 8), ptr[Y + (k / 8) * 32]);
168
169         vpaddd(Xbyak::Ymm(start_a_idx + l),
170                Xbyak::Ymm(start_a_idx + l),
171                Xbyak::Ymm(start_a_idx + k / 8));
172     }
173
174     // store
175     aligned_label(store_label);
176     for (k = 0, l = 2; k < nreg_acc; k += 8, l += 3) {
177         last_it = (k + 8) > nreg_acc;
178         if (use_mask && last_it)
179             vmovdqu32(ptr[Y + (k / 8) * 32], Xbyak::Ymm(start_a_idx + l) | mask_m);
180         else
181             vmovdqu32(ptr[Y + (k / 8) * 32], Xbyak::Ymm(start_a_idx + l));
182     }
183
184 }
185
186 template <typename T>
187 T jit_avx512_core_gemv_s8u8s32_kern::generate(int use_vnni) {
188
189     Xbyak::Opmask mask_n = k1, mask_m = k2;
190     Xbyak::Label one_label, m_tail_label, m_loop_label, n_loop_label;
191     Xbyak::Label n_tail_label, update_c_label, end_label;
192     constexpr unsigned int n_labels = (1 << unroll_m) - 1;
193     Xbyak::Label m_tail_label_case[n_labels];
194     Xbyak::Label n_loop_label_case[n_labels];
195     Xbyak::Label n_tail_label_case[n_labels];
196     Xbyak::Label update_c_label_case[n_labels];
197
198     int i, ii;
199
200     Xbyak::Zmm one, tmp;
201     Xbyak::Reg64 n = abi_param2, m = abi_param1;
202     Xbyak::Reg64 A = is_windows ? abi_param4 : abi_param3;
203     Xbyak::Reg64 lda = is_windows ? abi_param3 : abi_param4;
204     Xbyak::Reg64 X = is_windows ? rdi : r8;
205     Xbyak::Xmm beta = xmm1;
206     Xbyak::Reg64 Y = is_windows ? rsi : r9;
207
208     bool swap = !std::is_same<T, gemv_s8u8s32_kernel_t>::value;
209
210     // Windows: read on the stack lda, X, beta, Y
211
212     int zmm_idx = 1;
213     int nreg_acc = 1 << unroll_m;
214     int nreg_A = 1 << (unroll_m - 1);
215     int nreg_A_acc = nreg_acc + nreg_A;
216
217     if (!use_vnni) {
218         // set a zmm register to one
219         tmp = Xbyak::Zmm(0);
220         one = Xbyak::Zmm(zmm_idx + 1);
221         zmm_idx += 2; // one + tmp
222     }
223     else {
224         beta = xmm0;
225     }
226
227     preamble();
228
229     if (is_windows) {
230         mov(lda, ptr[rsp + get_size_of_abi_save_regs() + 40]);
231         mov(X, ptr[rsp + get_size_of_abi_save_regs() + 48]);
232         movss(beta, ptr[rsp + get_size_of_abi_save_regs() + 56]);
233         mov(Y, ptr[rsp + get_size_of_abi_save_regs() + 64]);
234     }
235
236     if (use_vnni && !is_windows) {
237         movaps(beta, xmm1);
238     }
239
240     mov(rax, (1 << unroll_n) - 1);
241     kmovq(k3, rax);
242
243     and_(rax, n); // rax contains n & ((1 << unroll_n) - 1)
244     mov(rbx, 1);
245     shlx(rbx, rbx, rax);
246     sub(rbx, 1);
247     kmovq(mask_n, rbx);
248     // mask_n set (AVX512 only), can use rax and rbx again
249
250     // set mask_m for update of the C matrix
251     // load/store on the C matrix use Ymm so tail according to Ymm size
252     mov(rax, 7); // 8 * 32 = 256 Ymm size
253     and_(rax, m); // rax contains m & 7
254     mov(rbx, 1);
255     shlx(rbx, rbx, rax);
256     sub(rbx, 1);
257     kmovq(mask_m, rbx);
258     // mask_m set (AVX512 only), can use rax and rbx again
259
260     // setup register of ones when VNNI instructions not available
261     if (!use_vnni) {
262         vmovdqu16(one, ptr[rip + one_label]);
263     }
264
265     // M loop
266     // base pointer for A rax contains a + i * lda
267     // Loop stop when rax >= a + (m & mask_um) * lda = rbx
268     // loop increment r10 = um * lda
269     // rbp = Y + i
270     mov(rax, A); // i = 0
271     mov(rbx, m);
272     and_(rbx, mask_um);
273     imul(rbx, lda);
274     add(rbx, A);
275     mov(r10, lda);
276     sal(r10, unroll_m);
277     mov(rbp, Y);
278
279     // N loop
280     // base pointer for X r11 contains x + j
281     // Loop stop when r11 >= x + n & mask_un = r12
282     // loop increment un
283     // r13 = rax + j = A + i * lda + j
284     mov(r12, n);
285     and_(r12, mask_un);
286     add(r12, X);
287
288     // M loop
289     aligned_label(m_loop_label);
290     cmp(rax, rbx);
291     jge(m_tail_label, T_NEAR);
292
293     // enter M loop
294     for(i = 0; i < nreg_acc; i++) {
295         vpxorq(Xbyak::Zmm(i + zmm_idx + nreg_A),
296                Xbyak::Zmm(i + zmm_idx + nreg_A),
297                Xbyak::Zmm(i + zmm_idx + nreg_A));
298     }
299
300     // N loop
301     mov(r11, X); // j = 0
302     mov(r13, rax);
303     aligned_label(n_loop_label);
304     cmp(r11, r12);
305     jge(n_tail_label, T_NEAR);
306
307     // enter N loop
308
309     n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, nreg_acc,
310                 r13, lda, r11, tmp, one, swap, use_vnni, 0, mask_n);
311
312     // increment rax with un
313     add(r11, 1 << unroll_n);
314     add(r13, 1 << unroll_n);
315     jmp(n_loop_label, T_NEAR);
316     // end N loop
317
318     // N tail
319     aligned_label(n_tail_label);
320
321     ktestq(mask_n, k3);
322     je(update_c_label, T_NEAR);
323     n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, nreg_acc,
324                 r13, lda, r11, tmp, one, swap, use_vnni, 1, mask_n);
325
326     // update C matrix
327     aligned_label(update_c_label);
328
329     update_c(nreg_acc, rbp, zmm_idx, zmm_idx + nreg_A, beta, 0, mask_m);
330
331     // increment rax with um * lda
332     add(rax, r10);
333     add(rbp, 1 << (unroll_m + 2));
334     jmp(m_loop_label, T_NEAR);
335     // end M loop
336
337     // M tail
338     aligned_label(m_tail_label);
339
340     // r10 will contain m_tail = m % unroll_m = m & (1 << unroll_m) - 1
341     mov(r10, m);
342     and_(r10, (1 << unroll_m) - 1);
343     for (ii = 1; ii < 1 << unroll_m; ii++) {
344         aligned_label(m_tail_label_case[ii-1]);
345         cmp(r10, ii);
346         if (ii == (1 << unroll_m) - 1)
347             jne(end_label, T_NEAR);
348         else
349             jne(m_tail_label_case[ii], T_NEAR);
350
351         // m_tail = i, use i accumulators
352
353         for(i = 0; i < ii; i++) {
354             vpxorq(Xbyak::Zmm(i + zmm_idx + nreg_A),
355                    Xbyak::Zmm(i + zmm_idx + nreg_A),
356                    Xbyak::Zmm(i + zmm_idx + nreg_A));
357         }
358
359         // N loop
360         mov(r11, X); // j = 0
361         mov(r13, rax);
362         aligned_label(n_loop_label_case[ii - 1]);
363         cmp(r11, r12);
364         jge(n_tail_label_case[ii - 1], T_NEAR);
365
366         n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, ii, r13,
367                     lda, r11, tmp, one, swap, use_vnni, 0, mask_n);
368
369         // increment rax with un
370         add(r11, 1 << unroll_n);
371         add(r13, 1 << unroll_n);
372         jmp(n_loop_label_case[ii - 1], T_NEAR);
373         // end N loop
374
375         // N tail
376         aligned_label(n_tail_label_case[ii - 1]);
377         ktestq(mask_n, k3);
378         je(update_c_label_case[ii - 1], T_NEAR);
379         n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, ii, r13,
380                     lda, r11, tmp, one, swap, use_vnni, 1, mask_n);
381
382         // update C matrix
383         aligned_label(update_c_label_case[ii - 1]);
384         update_c(ii, rbp, zmm_idx, zmm_idx + nreg_A, beta, 1, mask_m);
385
386         if (ii < ((1 << unroll_m) - 1))
387             jmp(end_label, T_NEAR);
388     }
389
390     aligned_label(end_label);
391
392     postamble();
393
394     if (!use_vnni) {
395         aligned_label(one_label);
396         for (i = 0; i < size_vec_reg/8; i++)
397             dq(0x0001000100010001);
398     }
399
400     return (T) getCode();
401 }
402
403 template jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t
404 jit_avx512_core_gemv_s8u8s32_kern::generate<jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t>(int);
405
406 template jit_avx512_core_gemv_s8u8s32_kern::gemv_u8s8s32_kernel_t
407 jit_avx512_core_gemv_s8u8s32_kern::generate<jit_avx512_core_gemv_s8u8s32_kern::gemv_u8s8s32_kernel_t>(int);
408
409 }
410 }
411 }