Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm / s8x8s32 / jit_avx512_core_u8_copy_sum_bt_kern.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "jit_generator.hpp"
18 #include "common.hpp"
19
20 namespace mkldnn {
21 namespace impl {
22 namespace cpu {
23
24 jit_avx512_core_u8_copy_sum_bt_kern::jit_avx512_core_u8_copy_sum_bt_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
25 {
26
27 #ifndef _WIN32
28 #define M       rdi
29 #define N       rsi
30 #define A       rdx
31 #define LDA     rcx
32 #define ALPHA   r8
33 #define B       r9
34
35 #define I       rax
36 #define A1      r10
37 #define A2      r8
38 #define LDA3    r11
39
40 #define ARG_BIAS        24+stacksize+rsp
41
42 #else
43
44 #define M       rcx
45 #define N       rdx
46 #define A       r8
47 #define LDA     r9
48 #define ALPHA   rax
49 #define B       rdi
50
51 #define I       rax
52 #define A1      rsi
53 #define A2      r10
54 #define LDA3    r11
55
56 #define ARG_ALPHA       40+stacksize+rsp
57 #define ARG_B           48+stacksize+rsp
58 #define ARG_BIAS        72+stacksize+rsp
59
60 #endif
61
62 inLocalLabel();
63 {
64
65 Xbyak::Label l15c;
66 Xbyak::Label l1f4;
67 Xbyak::Label l20;
68 Xbyak::Label l248;
69 Xbyak::Label l280;
70 Xbyak::Label l2a4;
71 Xbyak::Label l2b0;
72 Xbyak::Label l2c8;
73 Xbyak::Label l384;
74 Xbyak::Label l3e8;
75 Xbyak::Label l40;
76 Xbyak::Label l424;
77 Xbyak::Label l448;
78 Xbyak::Label l468;
79 Xbyak::Label l474;
80 Xbyak::Label l48c;
81 Xbyak::Label l550;
82 Xbyak::Label l5bc;
83 Xbyak::Label l600;
84 Xbyak::Label l628;
85 Xbyak::Label l646;
86 Xbyak::Label l650;
87 Xbyak::Label l668;
88 Xbyak::Label l700;
89 Xbyak::Label l760;
90 Xbyak::Label l7a4;
91 Xbyak::Label l7c8;
92 Xbyak::Label l7e8;
93
94         preamble();
95         auto stacksize = get_size_of_abi_save_regs();
96 #ifdef _WIN32
97         mov(ALPHA, ptr[ARG_ALPHA]);
98         mov(B, ptr[ARG_B]);
99 #endif
100
101         mov(M, qword[M]);
102         mov(N, qword[N]);
103         mov(LDA, qword[LDA]);
104         lea(LDA3, ptr[LDA+LDA*2]);
105         sub(A, -128);
106         sub(B, -128);
107         cmp(N, 0x8);
108         jl(l2a4, T_NEAR);
109         align(4);
110
111 L(l20);
112         mov(A1, A);
113         add(A, 0x8);
114         pxor(xmm8, xmm8);
115         pxor(xmm9, xmm9);
116         mov(I, M);
117         sar(I, 0x3);
118         jle(l15c, T_NEAR);
119         align(4);
120
121 L(l40);
122         movq(xmm0, qword[A1-0x80]);
123         add(A1, LDA);
124         movq(xmm1, qword[A1-0x80]);
125         add(A1, LDA);
126         movq(xmm2, qword[A1-0x80]);
127         add(A1, LDA);
128         movq(xmm3, qword[A1-0x80]);
129         add(A1, LDA);
130         punpcklbw(xmm0, xmm1);
131         punpcklbw(xmm2, xmm3);
132         movdqa(xmm1, xmm0);
133         punpcklwd(xmm0, xmm2);
134         punpckhwd(xmm1, xmm2);
135         pmovsxbw(xmm5, xmm0);
136         movhlps(xmm6, xmm0);
137         pmovsxbw(xmm6, xmm6);
138         phaddw(xmm5, xmm6);
139         phaddw(xmm5, xmm5);
140         pmovsxwd(xmm5, xmm5);
141         paddd(xmm8, xmm5);
142         pmovsxbw(xmm5, xmm1);
143         movhlps(xmm6, xmm1);
144         pmovsxbw(xmm6, xmm6);
145         phaddw(xmm5, xmm6);
146         phaddw(xmm5, xmm5);
147         pmovsxwd(xmm5, xmm5);
148         paddd(xmm9, xmm5);
149         movdqu(xword[B-0x80], xmm0);
150         movdqu(xword[B-0x70], xmm1);
151         movq(xmm0, qword[A1-0x80]);
152         add(A1, LDA);
153         movq(xmm1, qword[A1-0x80]);
154         add(A1, LDA);
155         movq(xmm2, qword[A1-0x80]);
156         add(A1, LDA);
157         movq(xmm3, qword[A1-0x80]);
158         add(A1, LDA);
159         punpcklbw(xmm0, xmm1);
160         punpcklbw(xmm2, xmm3);
161         movdqa(xmm1, xmm0);
162         punpcklwd(xmm0, xmm2);
163         punpckhwd(xmm1, xmm2);
164         pmovsxbw(xmm5, xmm0);
165         movhlps(xmm6, xmm0);
166         pmovsxbw(xmm6, xmm6);
167         phaddw(xmm5, xmm6);
168         phaddw(xmm5, xmm5);
169         pmovsxwd(xmm5, xmm5);
170         paddd(xmm8, xmm5);
171         pmovsxbw(xmm5, xmm1);
172         movhlps(xmm6, xmm1);
173         pmovsxbw(xmm6, xmm6);
174         phaddw(xmm5, xmm6);
175         phaddw(xmm5, xmm5);
176         pmovsxwd(xmm5, xmm5);
177         paddd(xmm9, xmm5);
178         movdqu(xword[B-0x60], xmm0);
179         movdqu(xword[B-0x50], xmm1);
180         sub(B, -64);
181         dec(I);
182         jg(l40, T_NEAR);
183         align(4);
184
185 L(l15c);
186         test(M, 0x4);
187         jle(l1f4, T_NEAR);
188         movq(xmm0, qword[A1-0x80]);
189         add(A1, LDA);
190         movq(xmm1, qword[A1-0x80]);
191         add(A1, LDA);
192         movq(xmm2, qword[A1-0x80]);
193         add(A1, LDA);
194         movq(xmm3, qword[A1-0x80]);
195         add(A1, LDA);
196         punpcklbw(xmm0, xmm1);
197         punpcklbw(xmm2, xmm3);
198         movdqa(xmm1, xmm0);
199         punpcklwd(xmm0, xmm2);
200         punpckhwd(xmm1, xmm2);
201         pmovsxbw(xmm5, xmm0);
202         movhlps(xmm6, xmm0);
203         pmovsxbw(xmm6, xmm6);
204         phaddw(xmm5, xmm6);
205         phaddw(xmm5, xmm5);
206         pmovsxwd(xmm5, xmm5);
207         paddd(xmm8, xmm5);
208         pmovsxbw(xmm5, xmm1);
209         movhlps(xmm6, xmm1);
210         pmovsxbw(xmm6, xmm6);
211         phaddw(xmm5, xmm6);
212         phaddw(xmm5, xmm5);
213         pmovsxwd(xmm5, xmm5);
214         paddd(xmm9, xmm5);
215         movdqu(xword[B-0x80], xmm0);
216         movdqu(xword[B-0x70], xmm1);
217         sub(B, -32);
218         align(4);
219
220 L(l1f4);
221         test(M, 0x2);
222         jle(l248, T_NEAR);
223         movq(xmm0, qword[A1-0x80]);
224         add(A1, LDA);
225         movq(xmm1, qword[A1-0x80]);
226         add(A1, LDA);
227         punpcklbw(xmm0, xmm1);
228         pmovsxbw(xmm5, xmm0);
229         phaddw(xmm5, xmm5);
230         pmovsxwd(xmm5, xmm5);
231         paddd(xmm8, xmm5);
232         movhlps(xmm6, xmm0);
233         pmovsxbw(xmm6, xmm6);
234         phaddw(xmm6, xmm6);
235         pmovsxwd(xmm6, xmm6);
236         paddd(xmm9, xmm6);
237         movdqu(xword[B-0x80], xmm0);
238         sub(B, -16);
239         align(4);
240
241 L(l248);
242         test(M, 0x1);
243         jle(l280, T_NEAR);
244         movq(xmm0, qword[A1-0x80]);
245         add(A1, LDA);
246         pmovsxbd(xmm5, xmm0);
247         pshufd(xmm6, xmm0, 0x55);
248         pmovsxbd(xmm6, xmm6);
249         paddd(xmm8, xmm5);
250         paddd(xmm9, xmm6);
251         movq(qword[B-0x80], xmm0);
252         sub(B, -8);
253         align(4);
254
255 L(l280);
256         mov(A1, qword[ARG_BIAS]);
257         movdqu(xword[A1], xmm8);
258         movdqu(xword[A1+0x10], xmm9);
259         add(qword[ARG_BIAS], 0x20);
260         sub(N, 0x8);
261         cmp(N, 0x8);
262         jge(l20, T_NEAR);
263         align(4);
264
265 L(l2a4);
266         cmp(N, 0x4);
267         jl(l468, T_NEAR);
268         align(4);
269
270 L(l2b0);
271         mov(A1, A);
272         add(A, 0x4);
273         pxor(xmm7, xmm7);
274         mov(I, M);
275         sar(I, 0x3);
276         jle(l384, T_NEAR);
277         align(4);
278
279 L(l2c8);
280         movd(xmm0, dword[A1-0x80]);
281         add(A1, LDA);
282         movd(xmm1, dword[A1-0x80]);
283         add(A1, LDA);
284         movd(xmm2, dword[A1-0x80]);
285         add(A1, LDA);
286         movd(xmm3, dword[A1-0x80]);
287         add(A1, LDA);
288         punpcklbw(xmm0, xmm1);
289         punpcklbw(xmm2, xmm3);
290         punpcklwd(xmm0, xmm2);
291         pmovsxbw(xmm5, xmm0);
292         movhlps(xmm6, xmm0);
293         pmovsxbw(xmm6, xmm6);
294         phaddw(xmm5, xmm6);
295         phaddw(xmm5, xmm5);
296         pmovsxwd(xmm5, xmm5);
297         paddd(xmm7, xmm5);
298         movdqu(xword[B-0x80], xmm0);
299         movd(xmm0, dword[A1-0x80]);
300         add(A1, LDA);
301         movd(xmm1, dword[A1-0x80]);
302         add(A1, LDA);
303         movd(xmm2, dword[A1-0x80]);
304         add(A1, LDA);
305         movd(xmm3, dword[A1-0x80]);
306         add(A1, LDA);
307         punpcklbw(xmm0, xmm1);
308         punpcklbw(xmm2, xmm3);
309         punpcklwd(xmm0, xmm2);
310         pmovsxbw(xmm5, xmm0);
311         movhlps(xmm6, xmm0);
312         pmovsxbw(xmm6, xmm6);
313         phaddw(xmm5, xmm6);
314         phaddw(xmm5, xmm5);
315         pmovsxwd(xmm5, xmm5);
316         paddd(xmm7, xmm5);
317         movdqu(xword[B-0x70], xmm0);
318         sub(B, -32);
319         dec(I);
320         jg(l2c8, T_NEAR);
321         align(4);
322
323 L(l384);
324         test(M, 0x4);
325         jle(l3e8, T_NEAR);
326         movd(xmm0, dword[A1-0x80]);
327         add(A1, LDA);
328         movd(xmm1, dword[A1-0x80]);
329         add(A1, LDA);
330         movd(xmm2, dword[A1-0x80]);
331         add(A1, LDA);
332         movd(xmm3, dword[A1-0x80]);
333         add(A1, LDA);
334         punpcklbw(xmm0, xmm1);
335         punpcklbw(xmm2, xmm3);
336         punpcklwd(xmm0, xmm2);
337         pmovsxbw(xmm5, xmm0);
338         movhlps(xmm6, xmm0);
339         pmovsxbw(xmm6, xmm6);
340         phaddw(xmm5, xmm6);
341         phaddw(xmm5, xmm5);
342         pmovsxwd(xmm5, xmm5);
343         paddd(xmm7, xmm5);
344         movdqu(xword[B-0x80], xmm0);
345         sub(B, -16);
346         align(4);
347
348 L(l3e8);
349         test(M, 0x2);
350         jle(l424, T_NEAR);
351         movd(xmm0, dword[A1-0x80]);
352         add(A1, LDA);
353         movd(xmm1, dword[A1-0x80]);
354         add(A1, LDA);
355         punpcklbw(xmm0, xmm1);
356         pmovsxbw(xmm5, xmm0);
357         phaddw(xmm5, xmm5);
358         pmovsxwd(xmm5, xmm5);
359         paddd(xmm7, xmm5);
360         movq(qword[B-0x80], xmm0);
361         sub(B, -8);
362         align(4);
363
364 L(l424);
365         test(M, 0x1);
366         jle(l448, T_NEAR);
367         movd(xmm0, dword[A1-0x80]);
368         pmovsxbd(xmm5, xmm0);
369         paddd(xmm7, xmm5);
370         movd(dword[B-0x80], xmm0);
371         sub(B, -4);
372         align(4);
373
374 L(l448);
375         mov(A1, qword[ARG_BIAS]);
376         movdqu(xword[A1], xmm7);
377         add(qword[ARG_BIAS], 0x10);
378         sub(N, 0x4);
379         cmp(N, 0x4);
380         jge(l2b0, T_NEAR);
381         align(4);
382
383 L(l468);
384         cmp(N, 0x2);
385         jl(l646, T_NEAR);
386         align(4);
387
388 L(l474);
389         mov(A1, A);
390         add(A, 0x2);
391         pxor(xmm7, xmm7);
392         mov(LDA3, M);
393         sar(LDA3, 0x3);
394         jle(l550, T_NEAR);
395         align(4);
396
397 L(l48c);
398         mov(ax, word[A1-0x80]);
399         add(A1, LDA);
400         pinsrw(xmm0, eax, 0x0);
401         mov(ax, word[A1-0x80]);
402         add(A1, LDA);
403         pinsrw(xmm1, eax, 0x0);
404         mov(ax, word[A1-0x80]);
405         add(A1, LDA);
406         pinsrw(xmm2, eax, 0x0);
407         mov(ax, word[A1-0x80]);
408         add(A1, LDA);
409         pinsrw(xmm3, eax, 0x0);
410         punpcklbw(xmm0, xmm1);
411         punpcklbw(xmm2, xmm3);
412         punpcklwd(xmm0, xmm2);
413         mov(ax, word[A1-0x80]);
414         add(A1, LDA);
415         pinsrw(xmm1, eax, 0x0);
416         mov(ax, word[A1-0x80]);
417         add(A1, LDA);
418         pinsrw(xmm2, eax, 0x0);
419         mov(ax, word[A1-0x80]);
420         add(A1, LDA);
421         pinsrw(xmm3, eax, 0x0);
422         mov(ax, word[A1-0x80]);
423         add(A1, LDA);
424         pinsrw(xmm4, eax, 0x0);
425         punpcklbw(xmm1, xmm2);
426         punpcklbw(xmm3, xmm4);
427         punpcklwd(xmm1, xmm3);
428         punpcklqdq(xmm0, xmm1);
429         pshufd(xmm6, xmm0, 0xd8);
430         pmovsxbw(xmm5, xmm6);
431         movhlps(xmm6, xmm6);
432         pmovsxbw(xmm6, xmm6);
433         phaddw(xmm5, xmm6);
434         phaddw(xmm5, xmm5);
435         phaddw(xmm5, xmm5);
436         pmovsxwd(xmm5, xmm5);
437         paddd(xmm7, xmm5);
438         movdqu(xword[B-0x80], xmm0);
439         sub(B, -16);
440         dec(LDA3);
441         jg(l48c, T_NEAR);
442         align(4);
443
444 L(l550);
445         test(M, 0x4);
446         jle(l5bc, T_NEAR);
447         mov(ax, word[A1-0x80]);
448         add(A1, LDA);
449         pinsrw(xmm0, eax, 0x0);
450         mov(ax, word[A1-0x80]);
451         add(A1, LDA);
452         pinsrw(xmm1, eax, 0x0);
453         mov(ax, word[A1-0x80]);
454         add(A1, LDA);
455         pinsrw(xmm2, eax, 0x0);
456         mov(ax, word[A1-0x80]);
457         add(A1, LDA);
458         pinsrw(xmm3, eax, 0x0);
459         punpcklbw(xmm0, xmm1);
460         punpcklbw(xmm2, xmm3);
461         punpcklwd(xmm0, xmm2);
462         pmovsxbw(xmm5, xmm0);
463         phaddw(xmm5, xmm5);
464         phaddw(xmm5, xmm5);
465         pmovsxwd(xmm5, xmm5);
466         paddd(xmm7, xmm5);
467         movq(qword[B-0x80], xmm0);
468         sub(B, -8);
469         align(4);
470
471 L(l5bc);
472         test(M, 0x2);
473         jle(l600, T_NEAR);
474         mov(ax, word[A1-0x80]);
475         add(A1, LDA);
476         pinsrw(xmm0, eax, 0x0);
477         mov(ax, word[A1-0x80]);
478         add(A1, LDA);
479         pinsrw(xmm1, eax, 0x0);
480         punpcklbw(xmm0, xmm1);
481         pmovsxbw(xmm5, xmm0);
482         phaddw(xmm5, xmm5);
483         pmovsxwd(xmm5, xmm5);
484         paddd(xmm7, xmm5);
485         movd(dword[B-0x80], xmm0);
486         sub(B, -4);
487         align(4);
488
489 L(l600);
490         test(M, 0x1);
491         jle(l628, T_NEAR);
492         mov(ax, word[A1-0x80]);
493         pinsrw(xmm0, eax, 0x0);
494         pmovsxbd(xmm5, xmm0);
495         paddd(xmm7, xmm5);
496         mov(word[B-0x80], ax);
497         sub(B, -2);
498         align(4);
499
500 L(l628);
501         mov(A1, qword[ARG_BIAS]);
502         movq(qword[A1], xmm7);
503         add(qword[ARG_BIAS], 0x8);
504         sub(N, 0x2);
505         cmp(N, 0x2);
506         jge(l474, T_NEAR);
507         align(4);
508
509 L(l646);
510         cmp(N, 0x1);
511         jl(l7e8, T_NEAR);
512         align(4);
513
514 L(l650);
515         mov(A1, A);
516         add(A, 0x1);
517         pxor(xmm7, xmm7);
518         mov(LDA3, M);
519         sar(LDA3, 0x3);
520         jle(l700, T_NEAR);
521         align(4);
522
523 L(l668);
524         mov(al, byte[A1-0x80]);
525         add(A1, LDA);
526         pinsrb(xmm0, eax, 0x0);
527         mov(al, byte[A1-0x80]);
528         add(A1, LDA);
529         pinsrb(xmm0, eax, 0x1);
530         mov(al, byte[A1-0x80]);
531         add(A1, LDA);
532         pinsrb(xmm0, eax, 0x2);
533         mov(al, byte[A1-0x80]);
534         add(A1, LDA);
535         pinsrb(xmm0, eax, 0x3);
536         mov(al, byte[A1-0x80]);
537         add(A1, LDA);
538         pinsrb(xmm0, eax, 0x4);
539         mov(al, byte[A1-0x80]);
540         add(A1, LDA);
541         pinsrb(xmm0, eax, 0x5);
542         mov(al, byte[A1-0x80]);
543         add(A1, LDA);
544         pinsrb(xmm0, eax, 0x6);
545         mov(al, byte[A1-0x80]);
546         add(A1, LDA);
547         pinsrb(xmm0, eax, 0x7);
548         pmovsxbw(xmm5, xmm0);
549         phaddw(xmm5, xmm6);
550         phaddw(xmm5, xmm5);
551         phaddw(xmm5, xmm5);
552         pmovsxwd(xmm5, xmm5);
553         paddd(xmm7, xmm5);
554         movq(qword[B-0x80], xmm0);
555         sub(B, -8);
556         dec(LDA3);
557         jg(l668, T_NEAR);
558         align(4);
559
560 L(l700);
561         test(M, 0x4);
562         jle(l760, T_NEAR);
563         mov(al, byte[A1-0x80]);
564         add(A1, LDA);
565         pinsrb(xmm0, eax, 0x0);
566         mov(al, byte[A1-0x80]);
567         add(A1, LDA);
568         pinsrb(xmm0, eax, 0x1);
569         mov(al, byte[A1-0x80]);
570         add(A1, LDA);
571         pinsrb(xmm0, eax, 0x2);
572         mov(al, byte[A1-0x80]);
573         add(A1, LDA);
574         pinsrb(xmm0, eax, 0x3);
575         pmovsxbw(xmm5, xmm0);
576         phaddw(xmm5, xmm5);
577         phaddw(xmm5, xmm5);
578         pmovsxwd(xmm5, xmm5);
579         paddd(xmm7, xmm5);
580         movd(dword[B-0x80], xmm0);
581         sub(B, -4);
582         align(4);
583
584 L(l760);
585         test(M, 0x2);
586         jle(l7a4, T_NEAR);
587         mov(al, byte[A1-0x80]);
588         add(A1, LDA);
589         pinsrb(xmm0, eax, 0x0);
590         mov(byte[B-0x80], al);
591         mov(al, byte[A1-0x80]);
592         add(A1, LDA);
593         pinsrb(xmm0, eax, 0x1);
594         pmovsxbw(xmm5, xmm0);
595         phaddw(xmm5, xmm5);
596         pmovsxwd(xmm5, xmm5);
597         paddd(xmm7, xmm5);
598         mov(byte[B-0x7f], al);
599         sub(B, -2);
600         align(4);
601
602 L(l7a4);
603         test(M, 0x1);
604         jle(l7c8, T_NEAR);
605         mov(al, byte[A1-0x80]);
606         pinsrw(xmm0, eax, 0x0);
607         pmovsxbd(xmm5, xmm0);
608         paddd(xmm7, xmm5);
609         mov(byte[B-0x80], al);
610         sub(B, -1);
611         align(4);
612
613 L(l7c8);
614         mov(A1, qword[ARG_BIAS]);
615         movd(dword[A1], xmm7);
616         add(qword[ARG_BIAS], 0x4);
617         sub(N, 0x1);
618         cmp(N, 0x1);
619         jge(l650, T_NEAR);
620         align(4);
621
622 L(l7e8);
623
624         postamble();
625 }
626 outLocalLabel();
627
628 #undef M
629 #undef N
630 #undef A
631 #undef LDA
632 #undef ALPHA
633 #undef B
634 #undef I
635 #undef A1
636 #undef A2
637 #undef LDA3
638 #ifdef _WIN32
639 #undef ARG_ALPHA
640 #undef ARG_B
641 #endif
642 #undef ARG_BIAS
643 }
644
645 }
646 }
647 }