Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm / s8x8s32 / jit_avx512_core_u8_copy_bn_kern.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "jit_generator.hpp"
18 #include "common.hpp"
19
20 namespace mkldnn {
21 namespace impl {
22 namespace cpu {
23
24 jit_avx512_core_u8_copy_bn_kern::jit_avx512_core_u8_copy_bn_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
25 {
26
27 #ifndef _WIN32
28 #define M       rdi
29 #define N       rsi
30 #define A       rdx
31 #define LDA     rcx
32 #define ALPHA   r8
33 #define B       r9
34
35 #define I       rax
36 #define A1      r10
37 #define A2      r8
38 #define LDA3    r11
39
40 #else
41
42 #define M       rcx
43 #define N       rdx
44 #define A       r8
45 #define LDA     r9
46 #define ALPHA   rax
47 #define B       rdi
48
49 #define I       rax
50 #define A1      rsi
51 #define A2      r10
52 #define LDA3    r11
53
54 #define ARG_ALPHA       40+stacksize+rsp
55 #define ARG_B           48+stacksize+rsp
56
57 #endif
58
59 inLocalLabel();
60 {
61
62 Xbyak::Label l118;
63 Xbyak::Label l1a8;
64 Xbyak::Label l20;
65 Xbyak::Label l218;
66 Xbyak::Label l28c;
67 Xbyak::Label l2f8;
68 Xbyak::Label l308;
69 Xbyak::Label l314;
70 Xbyak::Label l32c;
71 Xbyak::Label l3a0;
72 Xbyak::Label l3c;
73 Xbyak::Label l3f0;
74 Xbyak::Label l434;
75 Xbyak::Label l47c;
76 Xbyak::Label l4bc;
77 Xbyak::Label l4cc;
78 Xbyak::Label l4d8;
79 Xbyak::Label l4f0;
80 Xbyak::Label l528;
81 Xbyak::Label l554;
82 Xbyak::Label l580;
83 Xbyak::Label l5b0;
84 Xbyak::Label l5d0;
85 Xbyak::Label l5de;
86 Xbyak::Label l5e8;
87 Xbyak::Label l5f8;
88 Xbyak::Label l614;
89 Xbyak::Label l634;
90 Xbyak::Label l654;
91 Xbyak::Label l670;
92 Xbyak::Label l688;
93 Xbyak::Label l698;
94
95         preamble();
96 #ifdef _WIN32
97         auto stacksize = get_size_of_abi_save_regs();
98         mov(ALPHA, ptr[ARG_ALPHA]);
99         mov(B, ptr[ARG_B]);
100 #endif
101
102         mov(N, qword[N]);
103         mov(M, qword[M]);
104         mov(LDA, qword[LDA]);
105         sub(A, -128);
106         sub(B, -128);
107         lea(LDA3, ptr[LDA+LDA*2]);
108         cmp(N, 0x8);
109         jl(l308, T_NEAR);
110         align(4);
111
112 L(l20);
113         mov(A1, A);
114         lea(A2, ptr[A1+LDA*4]);
115         lea(I, ptr[A1+LDA*8]);
116         mov(A, I);
117         mov(I, M);
118         sar(I, 0x4);
119         jle(l118, T_NEAR);
120         align(4);
121
122 L(l3c);
123         movdqu(xmm0, xword[A1-0x80]);
124         movdqu(xmm1, xword[A1+LDA*1-0x80]);
125         movdqu(xmm2, xword[A1+LDA*2-0x80]);
126         movdqu(xmm3, xword[A1+LDA3*1-0x80]);
127         sub(A1, -16);
128         movdqa(xmm4, xmm0);
129         punpckldq(xmm0, xmm1);
130         punpckhdq(xmm4, xmm1);
131         movdqa(xmm5, xmm2);
132         punpckldq(xmm2, xmm3);
133         punpckhdq(xmm5, xmm3);
134         movdqa(xmm1, xmm0);
135         punpcklqdq(xmm0, xmm2);
136         punpckhqdq(xmm1, xmm2);
137         movdqa(xmm3, xmm4);
138         punpcklqdq(xmm4, xmm5);
139         punpckhqdq(xmm3, xmm5);
140         movdqu(xword[B-0x80], xmm0);
141         movdqu(xword[B-0x60], xmm1);
142         movdqu(xword[B-0x40], xmm4);
143         movdqu(xword[B-0x20], xmm3);
144         movdqu(xmm0, xword[A2-0x80]);
145         movdqu(xmm1, xword[A2+LDA*1-0x80]);
146         movdqu(xmm2, xword[A2+LDA*2-0x80]);
147         movdqu(xmm3, xword[A2+LDA3*1-0x80]);
148         sub(A2, -16);
149         movdqa(xmm4, xmm0);
150         punpckldq(xmm0, xmm1);
151         punpckhdq(xmm4, xmm1);
152         movdqa(xmm5, xmm2);
153         punpckldq(xmm2, xmm3);
154         punpckhdq(xmm5, xmm3);
155         movdqa(xmm1, xmm0);
156         punpcklqdq(xmm0, xmm2);
157         punpckhqdq(xmm1, xmm2);
158         movdqa(xmm3, xmm4);
159         punpcklqdq(xmm4, xmm5);
160         punpckhqdq(xmm3, xmm5);
161         movdqu(xword[B-0x70], xmm0);
162         movdqu(xword[B-0x50], xmm1);
163         movdqu(xword[B-0x30], xmm4);
164         movdqu(xword[B-0x10], xmm3);
165         sub(B, -128);
166         dec(I);
167         jg(l3c, T_NEAR);
168         align(4);
169
170 L(l118);
171         test(M, 0x8);
172         jle(l1a8, T_NEAR);
173         movq(xmm0, qword[A1-0x80]);
174         movq(xmm1, qword[A1+LDA*1-0x80]);
175         movq(xmm2, qword[A1+LDA*2-0x80]);
176         movq(xmm3, qword[A1+LDA3*1-0x80]);
177         sub(A1, -8);
178         punpckldq(xmm0, xmm1);
179         punpckldq(xmm2, xmm3);
180         movdqa(xmm1, xmm0);
181         punpcklqdq(xmm0, xmm2);
182         punpckhqdq(xmm1, xmm2);
183         movdqu(xword[B-0x80], xmm0);
184         movdqu(xword[B-0x60], xmm1);
185         movq(xmm0, qword[A2-0x80]);
186         movq(xmm1, qword[A2+LDA*1-0x80]);
187         movq(xmm2, qword[A2+LDA*2-0x80]);
188         movq(xmm3, qword[A2+LDA3*1-0x80]);
189         sub(A2, -8);
190         punpckldq(xmm0, xmm1);
191         punpckldq(xmm2, xmm3);
192         movdqa(xmm1, xmm0);
193         punpcklqdq(xmm0, xmm2);
194         punpckhqdq(xmm1, xmm2);
195         movdqu(xword[B-0x70], xmm0);
196         movdqu(xword[B-0x50], xmm1);
197         sub(B, -64);
198         align(4);
199
200 L(l1a8);
201         test(M, 0x4);
202         jle(l218, T_NEAR);
203         movd(xmm0, dword[A1-0x80]);
204         movd(xmm1, dword[A1+LDA*1-0x80]);
205         movd(xmm2, dword[A1+LDA*2-0x80]);
206         movd(xmm3, dword[A1+LDA3*1-0x80]);
207         sub(A1, -4);
208         punpckldq(xmm0, xmm1);
209         punpckldq(xmm2, xmm3);
210         punpcklqdq(xmm0, xmm2);
211         movdqu(xword[B-0x80], xmm0);
212         movd(xmm0, dword[A2-0x80]);
213         movd(xmm1, dword[A2+LDA*1-0x80]);
214         movd(xmm2, dword[A2+LDA*2-0x80]);
215         movd(xmm3, dword[A2+LDA3*1-0x80]);
216         sub(A2, -4);
217         punpckldq(xmm0, xmm1);
218         punpckldq(xmm2, xmm3);
219         punpcklqdq(xmm0, xmm2);
220         movdqu(xword[B-0x70], xmm0);
221         sub(B, -32);
222         align(4);
223
224 L(l218);
225         test(M, 0x2);
226         jle(l28c, T_NEAR);
227         mov(ax, word[A1-0x80]);
228         pinsrw(xmm0, eax, 0x0);
229         mov(ax, word[A1+LDA*1-0x80]);
230         pinsrw(xmm0, eax, 0x1);
231         mov(ax, word[A1+LDA*2-0x80]);
232         pinsrw(xmm0, eax, 0x2);
233         mov(ax, word[A1+LDA3*1-0x80]);
234         sub(A1, -2);
235         pinsrw(xmm0, eax, 0x3);
236         mov(ax, word[A2-0x80]);
237         pinsrw(xmm0, eax, 0x4);
238         mov(ax, word[A2+LDA*1-0x80]);
239         pinsrw(xmm0, eax, 0x5);
240         mov(ax, word[A2+LDA*2-0x80]);
241         pinsrw(xmm0, eax, 0x6);
242         mov(ax, word[A2+LDA3*1-0x80]);
243         sub(A2, -2);
244         pinsrw(xmm0, eax, 0x7);
245         movdqu(xword[B-0x80], xmm0);
246         sub(B, -16);
247         align(4);
248
249 L(l28c);
250         test(M, 0x1);
251         jle(l2f8, T_NEAR);
252         mov(al, byte[A1-0x80]);
253         pinsrb(xmm0, eax, 0x0);
254         mov(al, byte[A1+LDA*1-0x80]);
255         pinsrb(xmm0, eax, 0x1);
256         mov(al, byte[A1+LDA*2-0x80]);
257         pinsrb(xmm0, eax, 0x2);
258         mov(al, byte[A1+LDA3*1-0x80]);
259         pinsrb(xmm0, eax, 0x3);
260         mov(al, byte[A2-0x80]);
261         pinsrb(xmm0, eax, 0x4);
262         mov(al, byte[A2+LDA*1-0x80]);
263         pinsrb(xmm0, eax, 0x5);
264         mov(al, byte[A2+LDA*2-0x80]);
265         pinsrb(xmm0, eax, 0x6);
266         mov(al, byte[A2+LDA3*1-0x80]);
267         pinsrb(xmm0, eax, 0x7);
268         movq(qword[B-0x80], xmm0);
269         sub(B, -8);
270         align(4);
271
272 L(l2f8);
273         sub(N, 0x8);
274         cmp(N, 0x8);
275         jge(l20, T_NEAR);
276         align(4);
277
278 L(l308);
279         cmp(N, 0x4);
280         jl(l4cc, T_NEAR);
281         align(4);
282
283 L(l314);
284         mov(A1, A);
285         lea(A2, ptr[A1+LDA*2]);
286         lea(I, ptr[A1+LDA*4]);
287         mov(A, I);
288         mov(I, M);
289         sar(I, 0x4);
290         jle(l3a0, T_NEAR);
291         align(4);
292
293 L(l32c);
294         movdqu(xmm0, xword[A1-0x80]);
295         movdqu(xmm1, xword[A1+LDA*1-0x80]);
296         sub(A1, -16);
297         movdqu(xmm2, xword[A2-0x80]);
298         movdqu(xmm3, xword[A2+LDA*1-0x80]);
299         sub(A2, -16);
300         movdqa(xmm4, xmm0);
301         punpckldq(xmm0, xmm1);
302         punpckhdq(xmm4, xmm1);
303         movdqa(xmm5, xmm2);
304         punpckldq(xmm2, xmm3);
305         punpckhdq(xmm5, xmm3);
306         movdqa(xmm1, xmm0);
307         punpcklqdq(xmm0, xmm2);
308         punpckhqdq(xmm1, xmm2);
309         movdqa(xmm3, xmm4);
310         punpcklqdq(xmm4, xmm5);
311         punpckhqdq(xmm3, xmm5);
312         movdqu(xword[B-0x80], xmm0);
313         movdqu(xword[B-0x70], xmm1);
314         movdqu(xword[B-0x60], xmm4);
315         movdqu(xword[B-0x50], xmm3);
316         sub(B, -64);
317         dec(I);
318         jg(l32c, T_NEAR);
319         align(4);
320
321 L(l3a0);
322         test(M, 0x8);
323         jle(l3f0, T_NEAR);
324         movq(xmm0, qword[A1-0x80]);
325         movq(xmm1, qword[A1+LDA*1-0x80]);
326         sub(A1, -8);
327         movq(xmm2, qword[A2-0x80]);
328         movq(xmm3, qword[A2+LDA*1-0x80]);
329         sub(A2, -8);
330         punpckldq(xmm0, xmm1);
331         punpckldq(xmm2, xmm3);
332         movdqa(xmm1, xmm0);
333         punpcklqdq(xmm0, xmm2);
334         punpckhqdq(xmm1, xmm2);
335         movdqu(xword[B-0x80], xmm0);
336         movdqu(xword[B-0x70], xmm1);
337         sub(B, -32);
338         align(4);
339
340 L(l3f0);
341         test(M, 0x4);
342         jle(l434, T_NEAR);
343         movd(xmm0, dword[A1-0x80]);
344         movd(xmm1, dword[A1+LDA*1-0x80]);
345         sub(A1, -4);
346         movd(xmm2, dword[A2-0x80]);
347         movd(xmm3, dword[A2+LDA*1-0x80]);
348         sub(A2, -4);
349         punpckldq(xmm0, xmm1);
350         punpckldq(xmm2, xmm3);
351         punpcklqdq(xmm0, xmm2);
352         movdqu(xword[B-0x80], xmm0);
353         sub(B, -16);
354         align(4);
355
356 L(l434);
357         test(M, 0x2);
358         jle(l47c, T_NEAR);
359         mov(ax, word[A1-0x80]);
360         pinsrw(xmm0, eax, 0x0);
361         mov(ax, word[A1+LDA*1-0x80]);
362         sub(A1, -2);
363         pinsrw(xmm0, eax, 0x1);
364         mov(ax, word[A2-0x80]);
365         pinsrw(xmm0, eax, 0x2);
366         mov(ax, word[A2+LDA*1-0x80]);
367         sub(A2, -2);
368         pinsrw(xmm0, eax, 0x3);
369         movq(qword[B-0x80], xmm0);
370         sub(B, -8);
371         align(4);
372
373 L(l47c);
374         test(M, 0x1);
375         jle(l4bc, T_NEAR);
376         mov(al, byte[A1-0x80]);
377         pinsrb(xmm0, eax, 0x0);
378         mov(al, byte[A1+LDA*1-0x80]);
379         pinsrb(xmm0, eax, 0x1);
380         mov(al, byte[A2-0x80]);
381         pinsrb(xmm0, eax, 0x2);
382         mov(al, byte[A2+LDA*1-0x80]);
383         pinsrb(xmm0, eax, 0x3);
384         movd(dword[B-0x80], xmm0);
385         sub(B, -4);
386         align(4);
387
388 L(l4bc);
389         sub(N, 0x4);
390         cmp(N, 0x4);
391         jge(l314, T_NEAR);
392         align(4);
393
394 L(l4cc);
395         cmp(N, 0x2);
396         jl(l5de, T_NEAR);
397         align(4);
398
399 L(l4d8);
400         mov(A1, A);
401         lea(A2, ptr[A1+LDA*1]);
402         lea(I, ptr[A1+LDA*2]);
403         mov(A, I);
404         mov(I, M);
405         sar(I, 0x4);
406         jle(l528, T_NEAR);
407         align(4);
408
409 L(l4f0);
410         movdqu(xmm0, xword[A1-0x80]);
411         sub(A1, -16);
412         movdqu(xmm1, xword[A2-0x80]);
413         sub(A2, -16);
414         movdqa(xmm2, xmm0);
415         punpckldq(xmm0, xmm1);
416         punpckhdq(xmm2, xmm1);
417         movdqu(xword[B-0x80], xmm0);
418         movdqu(xword[B-0x70], xmm2);
419         sub(B, -32);
420         dec(I);
421         jg(l4f0, T_NEAR);
422         align(4);
423
424 L(l528);
425         test(M, 0x8);
426         jle(l554, T_NEAR);
427         movq(xmm0, qword[A1-0x80]);
428         sub(A1, -8);
429         movq(xmm1, qword[A2-0x80]);
430         sub(A2, -8);
431         punpckldq(xmm0, xmm1);
432         movdqu(xword[B-0x80], xmm0);
433         sub(B, -16);
434         align(4);
435
436 L(l554);
437         test(M, 0x4);
438         jle(l580, T_NEAR);
439         movd(xmm0, dword[A1-0x80]);
440         sub(A1, -4);
441         movd(xmm1, dword[A2-0x80]);
442         sub(A2, -4);
443         punpckldq(xmm0, xmm1);
444         movq(qword[B-0x80], xmm0);
445         sub(B, -8);
446         align(4);
447
448 L(l580);
449         test(M, 0x2);
450         jle(l5b0, T_NEAR);
451         mov(ax, word[A1-0x80]);
452         sub(A1, -2);
453         pinsrw(xmm0, eax, 0x0);
454         mov(ax, word[A2-0x80]);
455         sub(A2, -2);
456         pinsrw(xmm0, eax, 0x1);
457         movd(dword[B-0x80], xmm0);
458         sub(B, -4);
459         align(4);
460
461 L(l5b0);
462         test(M, 0x1);
463         jle(l5d0, T_NEAR);
464         mov(al, byte[A1-0x80]);
465         mov(byte[B-0x80], al);
466         mov(al, byte[A2-0x80]);
467         mov(byte[B-0x7f], al);
468         sub(B, -2);
469         align(4);
470
471 L(l5d0);
472         sub(N, 0x2);
473         cmp(N, 0x2);
474         jge(l4d8, T_NEAR);
475         align(4);
476
477 L(l5de);
478         cmp(N, 0x1);
479         jl(l698, T_NEAR);
480         align(4);
481
482 L(l5e8);
483         mov(A1, A);
484         add(A, LDA);
485         mov(I, M);
486         sar(I, 0x4);
487         jle(l614, T_NEAR);
488         align(4);
489
490 L(l5f8);
491         movdqu(xmm0, xword[A1-0x80]);
492         sub(A1, -16);
493         movdqu(xword[B-0x80], xmm0);
494         sub(B, -16);
495         dec(I);
496         jg(l5f8, T_NEAR);
497         align(4);
498
499 L(l614);
500         test(M, 0x8);
501         jle(l634, T_NEAR);
502         movq(xmm0, qword[A1-0x80]);
503         sub(A1, -8);
504         movq(qword[B-0x80], xmm0);
505         sub(B, -8);
506         align(4);
507
508 L(l634);
509         test(M, 0x4);
510         jle(l654, T_NEAR);
511         movd(xmm0, dword[A1-0x80]);
512         sub(A1, -4);
513         movd(dword[B-0x80], xmm0);
514         sub(B, -4);
515         align(4);
516
517 L(l654);
518         test(M, 0x2);
519         jle(l670, T_NEAR);
520         mov(ax, word[A1-0x80]);
521         mov(word[B-0x80], ax);
522         sub(A1, -2);
523         sub(B, -2);
524         align(4);
525
526 L(l670);
527         test(M, 0x1);
528         jle(l688, T_NEAR);
529         mov(al, byte[A1-0x80]);
530         mov(byte[B-0x80], al);
531         sub(B, -1);
532         align(4);
533
534 L(l688);
535         sub(N, 0x1);
536         cmp(N, 0x1);
537         jge(l5e8, T_NEAR);
538         align(4);
539
540 L(l698);
541
542         postamble();
543 }
544 outLocalLabel();
545
546 #undef M
547 #undef N
548 #undef A
549 #undef LDA
550 #undef ALPHA
551 #undef B
552 #undef I
553 #undef A1
554 #undef A2
555 #undef LDA3
556 #ifdef _WIN32
557 #undef ARG_ALPHA
558 #undef ARG_B
559 #endif
560 }
561
562 }
563 }
564 }