Add a section of how to link IE with CMake project (#99)
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm / jit_avx512_common_gemm_f32.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <math.h>
18
19 #include "mkldnn_thread.hpp"
20 #include "utils.hpp"
21
22 #include "gemm_utils.hpp"
23 #include "jit_avx512_common_gemm_f32.hpp"
24
25 #define CACHE_LINE_SIZE 64
26
27 namespace mkldnn {
28 namespace impl {
29 namespace cpu {
30
31 using namespace mkldnn::impl::memory_format;
32 using namespace mkldnn::impl::utils;
33
34 using namespace Xbyak;
35 #define STACKSIZE get_size_of_abi_save_regs()
36 #ifdef _WIN32
37 #define STACK_K_CAPACITY 32
38 #else
39 #define STACK_K_CAPACITY 2048
40 #endif
41 #define SIZE 4
42 #define OFFSET 128
43 #define BASE_SHIFT 2
44 #define SECOND_FETCH unroll_n
45 #define UNROLL_M 48
46 #define UNROLL_N 8
47
48 struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
49     xbyak_gemm(char transa, char transb, float beta, bool hasBias = false,
50             void *code_ptr = nullptr,
51             size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE)
52         : jit_generator(code_ptr, code_size)
53     {
54         enum { ver_avx512_core, ver_avx512_mic } ver =
55             mayiuse(avx512_core) ? ver_avx512_core : ver_avx512_mic;
56
57         bool isTransA = (transa == 'T' || transa == 't');
58         bool isTransB = (transb == 'T' || transb == 't');
59         bool isBeta0 = (beta == 0.0);
60         bool isBetaN = (!isBeta0 && beta != 1.0);
61
62         // various definitions for convenience
63         auto ARG_M = abi_param1;
64         auto ARG_N = abi_param2;
65         auto K = abi_param3;
66         auto ARG_ALPHA = abi_param4;
67 #ifdef _WIN32
68         auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE];
69         auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE +
70             sizeof(float *) + STACKSIZE];
71         const auto stackOffset = OFFSET_SHADOWSPACE +
72             sizeof(float *) + STACKSIZE;
73         auto A = rsi;
74         auto LDA = rdi;
75 #else
76         auto ARG_A = r8;
77         auto ARG_LDA = r9;
78         const auto stackOffset = STACKSIZE;
79         auto A = ARG_A;
80         auto LDA = ARG_LDA;
81 #endif
82         auto ARG_B = ptr[rsp + 8 + stackOffset];
83         auto ARG_LDB = ptr[rsp + 16 + stackOffset];
84         auto ARG_BETA = ptr[rsp + 24 + stackOffset];
85         auto ARG_C = ptr[rsp + 32 + stackOffset];
86         auto ARG_LDC = ptr[rsp + 40 + stackOffset];
87         auto ARG_BIAS = ptr[rsp + 48 + stackOffset];
88         auto ARG_WS = ptr[rsp + 56 + stackOffset];
89
90         auto B = r11;
91         auto LDB = rbx;
92         auto LDC = r13;
93         auto LL = rax;
94         auto AO1 = abi_param2;
95         auto BO1 = abi_param4;
96         auto BO2 = rbp;
97         auto CO1 = r14;
98         auto CO2 = r15;
99         auto LDB3 = r10;
100         auto LDA4 = abi_param1;
101         auto AA = r12;
102         auto BIAS1 = abi_param1;
103
104         auto M = qword[rsp + 0];
105         auto N = qword[rsp + 8];
106         auto FLAG = qword[rsp + 16];
107         auto I = qword[rsp + 24];
108         auto C = qword[rsp + 32];
109         auto BIAS = qword[rsp + 40];
110         auto ALPHA = qword[rsp + 48];
111         auto BETA = qword[rsp + 64];
112         auto ORIG_A = qword[rsp + 80];
113         auto ORIG_SP = qword[rsp + 120];
114
115         auto ZSTRIDE = zmm4;
116         auto VALPHA = zmm6;
117         auto VBETA = zmm7;
118         auto VBIAS1 = zmm1;
119         auto VBIAS2 = zmm2;
120         auto VBIAS3 = zmm3;
121
122         auto PREFETCHSIZEA = ver == ver_avx512_core ? 48 : 80;
123         auto PREFETCHSIZEB = 16;
124
125         Zmm regs[] = { zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15,
126             zmm16, zmm17, zmm18, zmm19, zmm20, zmm21, zmm22, zmm23, zmm24,
127             zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31 };
128
129         // Function for packing if needed
130         auto do_pack = [&](int unroll_m) {
                // Emits the instructions that copy a panel of A into the scratch
                // area addressed through AO1 (starting at rsp + 128 + OFFSET * SIZE).
                //
                //   unroll_m - number of values stored per k step; values above 16
                //              and above 32 enable the second (zmm1 / k2) and third
                //              (zmm2 / k3) 16-lane vector respectively.
                //
                // K is consumed four steps at a time in loop pack2; the K & 3
                // leftover steps are handled one at a time in loop pack4.
                // Opmasks k1, k2 and k3 are only read here, never written; k4 is
                // used as a scratch mask in the transposed path.
131             Label pack2, pack3, pack4, pack10;
132
133             mov(BO1, A);
134             lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]);
                // LL = K >> 2: number of 4-step groups. Skip the unrolled loop
                // when the shifted value is not positive.
135             mov(LL, K);
136             sar(LL, 2);
137             jle(pack3, T_NEAR);
138             align(16);
139
140             L(pack2);
141             if (!isTransA) {
                    // Non-transposed A: for each of the 4 steps do masked loads
                    // from BO1, advance BO1 by LDA, then masked stores to AO1 at
                    // an offset of unroll_m * i elements.
142                 for (int i = 0; i < 4; i++) {
143                     vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]);
144                     if (unroll_m > 16)
145                         vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]);
146                     if (unroll_m > 32)
147                         vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]);
148                     add(BO1, LDA);
149
150                     vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE]
151                                     | k1,
152                             zmm0);
153                     if (unroll_m > 16)
154                         vmovups(ptr[AO1
155                                         + (unroll_m * i + 1 * 16 - OFFSET)
156                                                 * SIZE]
157                                         | k2,
158                                 zmm1);
159                     if (unroll_m > 32)
160                         vmovups(ptr[AO1
161                                         + (unroll_m * i + 2 * 16 - OFFSET)
162                                                 * SIZE]
163                                         | k3,
164                                 zmm2);
165                 }
166             } else {
                    // Transposed A: each 16-lane vector is assembled from two
                    // 8-lane gathers (ymm5 and ymm6, indexed by ZSTRIDE) that are
                    // joined with vshuff64x2. k4 first receives a copy of the
                    // source mask (low 8 bits used) and then the same mask shifted
                    // right by 8 for the upper 8 lanes; presumably the copy exists
                    // because the gather instruction modifies its mask register.
                    // BO2 is stepped by LDA * 8 between consecutive gathers.
                    // NOTE(review): ymm6 is the low half of zmm6, which the
                    // enclosing scope names VALPHA - presumably VALPHA is loaded
                    // again after packing; confirm against the caller.
167                 for (int i = 0; i < 4; i++) {
168                     kmovw(k4, k1);
169                     vgatherqps(ymm5 | k4,
170                             ptr[BO1 + ZSTRIDE + (i - OFFSET) * SIZE]);
171                     lea(BO2, ptr[BO1 + LDA * 8]);
172                     kshiftrw(k4, k1, 8);
173                     vgatherqps(ymm6 | k4,
174                             ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
175                     vshuff64x2(zmm0, zmm5, zmm6, 0x44);
176
177                     if (unroll_m > 16) {
178                         lea(BO2, ptr[BO2 + LDA * 8]);
179                         kmovw(k4, k2);
180                         vgatherqps(ymm5 | k4,
181                                 ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
182                         lea(BO2, ptr[BO2 + LDA * 8]);
183                         kshiftrw(k4, k2, 8);
184                         vgatherqps(ymm6 | k4,
185                                 ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
186                         vshuff64x2(zmm1, zmm5, zmm6, 0x44);
187                     }
188
189                     if (unroll_m > 32) {
190                         lea(BO2, ptr[BO2 + LDA * 8]);
191                         kmovw(k4, k3);
192                         vgatherqps(ymm5 | k4,
193                                 ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
194                         lea(BO2, ptr[BO2 + LDA * 8]);
195                         kshiftrw(k4, k3, 8);
196                         vgatherqps(ymm6 | k4,
197                                 ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
                            // NOTE(review): within this lambda BO2 is rewritten
                            // from BO1 before it is read again, so the following
                            // lea looks redundant - confirm before removing.
198                         lea(BO2, ptr[BO2 + LDA * 8]);
199                         vshuff64x2(zmm2, zmm5, zmm6, 0x44);
200                     }
201
202                     vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE],
203                             zmm0 | k1);
204                     if (unroll_m > 16)
205                         vmovups(ptr[AO1
206                                         + (unroll_m * i + 1 * 16 - OFFSET)
207                                                 * SIZE],
208                                 zmm1 | k2);
209                     if (unroll_m > 32)
210                         vmovups(ptr[AO1
211                                         + (unroll_m * i + 2 * 16 - OFFSET)
212                                                 * SIZE],
213                                 zmm2 | k3);
214                 }
                    // The four steps above read displacements i = 0..3, so the
                    // source pointer moves on by four elements.
215                 add(BO1, 4 * SIZE);
216             }
                // Four k steps of unroll_m elements each were written.
217             add(AO1, unroll_m * 4 * SIZE);
218
219             sub(LL, 1);
220             jg(pack2, T_NEAR);
221             align(16);
222
223             L(pack3);
                // Remainder: LL = K & 3 single steps; jump to the end when zero.
224             mov(LL, K);
225             and_(LL, 3);
226             jle(pack10, T_NEAR);
227             align(16);
228
229             L(pack4);
                // One k step per iteration; same load / gather scheme as above,
                // with the stores shared by both branches after the if/else.
230             if (!isTransA) {
231                 vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]);
232                 if (unroll_m > 16)
233                     vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]);
234                 if (unroll_m > 32)
235                     vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]);
236                 add(BO1, LDA);
237             } else {
238                 kmovw(k4, k1);
239                 vgatherqps(ymm5 | k4, ptr[BO1 + ZSTRIDE + (0 - OFFSET) * SIZE]);
240                 lea(BO2, ptr[BO1 + LDA * 8]);
241                 kshiftrw(k4, k1, 8);
242                 vgatherqps(ymm6 | k4, ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
243                 vshuff64x2(zmm0, zmm5, zmm6, 0x44);
244
245                 if (unroll_m > 16) {
246                     lea(BO2, ptr[BO2 + LDA * 8]);
247                     kmovw(k4, k2);
248                     vgatherqps(ymm5 | k4,
249                             ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
250                     lea(BO2, ptr[BO2 + LDA * 8]);
251                     kshiftrw(k4, k2, 8);
252                     vgatherqps(ymm6 | k4,
253                             ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
254                     vshuff64x2(zmm1, zmm5, zmm6, 0x44);
255                 }
256
257                 if (unroll_m > 32) {
258                     lea(BO2, ptr[BO2 + LDA * 8]);
259                     kmovw(k4, k3);
260                     vgatherqps(ymm5 | k4,
261                             ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
262                     lea(BO2, ptr[BO2 + LDA * 8]);
263                     kshiftrw(k4, k3, 8);
264                     vgatherqps(ymm6 | k4,
265                             ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
                        // NOTE(review): as in the unrolled loop, this lea has no
                        // reader inside the lambda - confirm whether it is needed.
266                     lea(BO2, ptr[BO2 + LDA * 8]);
267                     vshuff64x2(zmm2, zmm5, zmm6, 0x44);
268                 }
269                 add(BO1, SIZE);
270             }
271
272             vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE],
273                     zmm0 | k1);
274             if (unroll_m > 16)
275                 vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 16 - OFFSET) * SIZE],
276                         zmm1 | k2);
277             if (unroll_m > 32)
278                 vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 16 - OFFSET) * SIZE],
279                         zmm2 | k3);
280
281             add(AO1, unroll_m * SIZE);
282             sub(LL, 1);
283             jg(pack4, T_NEAR);
284             align(16);
285
286             L(pack10);
287         };
288
289         // Function to update C, covering masking and other considerations
290         auto update = [&](Zmm reg, bool useCO1, int offset, int mask,
291                 bool useScale = false) {
                // Emits the write-back of one accumulator register to memory and
                // then clears that register. With X denoting the addressed memory:
                //   isBeta0            : X = reg * VALPHA            (X is not read)
                //   beta == 1          : X = reg * VALPHA + X
                //   isBetaN (otherwise): X = reg * VALPHA + X * VBETA
                // VALPHA / VBETA are presumably the broadcast alpha / beta values
                // set up by the caller - verify there.
                //
                //   reg      - accumulator; scaled by VALPHA in place, zeroed on exit.
                //   useCO1   - true: address is based on CO1; false: based on CO2.
                //   offset   - displacement, multiplied by SIZE to form bytes.
                //   mask     - 0: unmasked access; 1 / 2 / 3: access under opmask
                //              k1 / k2 / k3 (loads use zeroing via T_z). Any other
                //              value emits no load and no store, since the switches
                //              have no default case.
                //   useScale - when true, LDC is added to the address.
                //
                // zmm0 is used as a temporary whenever beta is not zero.
292             vmulps(reg, reg, VALPHA);
293             if (!isBeta0) {
                    // Step 1: load the current memory contents into zmm0.
294                 if (!useScale) {
295                     switch (mask) {
296                     case 0:
297                         if (useCO1)
298                             vmovups(zmm0, ptr[CO1 + offset * SIZE]);
299                         else
300                             vmovups(zmm0, ptr[CO2 + offset * SIZE]);
301                         break;
302                     case 1:
303                         if (useCO1)
304                             vmovups(zmm0 | k1 | T_z, ptr[CO1 + offset * SIZE]);
305                         else
306                             vmovups(zmm0 | k1 | T_z, ptr[CO2 + offset * SIZE]);
307                         break;
308                     case 2:
309                         if (useCO1)
310                             vmovups(zmm0 | k2 | T_z, ptr[CO1 + offset * SIZE]);
311                         else
312                             vmovups(zmm0 | k2 | T_z, ptr[CO2 + offset * SIZE]);
313                         break;
314                     case 3:
315                         if (useCO1)
316                             vmovups(zmm0 | k3 | T_z, ptr[CO1 + offset * SIZE]);
317                         else
318                             vmovups(zmm0 | k3 | T_z, ptr[CO2 + offset * SIZE]);
319                         break;
320                     }
321                 } else {
322                     switch (mask) {
323                     case 0:
324                         if (useCO1)
325                             vmovups(zmm0, ptr[CO1 + LDC + offset * SIZE]);
326                         else
327                             vmovups(zmm0, ptr[CO2 + LDC + offset * SIZE]);
328                         break;
329                     case 1:
330                         if (useCO1)
331                             vmovups(zmm0 | k1 | T_z,
332                                     ptr[CO1 + LDC + offset * SIZE]);
333                         else
334                             vmovups(zmm0 | k1 | T_z,
335                                     ptr[CO2 + LDC + offset * SIZE]);
336                         break;
337                     case 2:
338                         if (useCO1)
339                             vmovups(zmm0 | k2 | T_z,
340                                     ptr[CO1 + LDC + offset * SIZE]);
341                         else
342                             vmovups(zmm0 | k2 | T_z,
343                                     ptr[CO2 + LDC + offset * SIZE]);
344                         break;
345                     case 3:
346                         if (useCO1)
347                             vmovups(zmm0 | k3 | T_z,
348                                     ptr[CO1 + LDC + offset * SIZE]);
349                         else
350                             vmovups(zmm0 | k3 | T_z,
351                                     ptr[CO2 + LDC + offset * SIZE]);
352                         break;
353                     }
354                 }
                    // Step 2: combine. Inside this branch beta is non-zero, so
                    // !isBetaN means beta == 1 and a plain add suffices; otherwise
                    // zmm0 = zmm0 * VBETA + reg.
355                 if (!isBetaN) {
356                     vaddps(zmm0, reg, zmm0);
357                 } else {
358                     vfmadd132ps(zmm0, reg, VBETA);
359                 }
                    // Step 3: store zmm0 back to the same address.
360                 if (!useScale) {
361                     switch (mask) {
362                     case 0:
363                         if (useCO1)
364                             vmovups(ptr[CO1 + offset * SIZE], zmm0);
365                         else
366                             vmovups(ptr[CO2 + offset * SIZE], zmm0);
367                         break;
368                     case 1:
369                         if (useCO1)
370                             vmovups(ptr[CO1 + offset * SIZE], zmm0 | k1);
371                         else
372                             vmovups(ptr[CO2 + offset * SIZE], zmm0 | k1);
373                         break;
374                     case 2:
375                         if (useCO1)
376                             vmovups(ptr[CO1 + offset * SIZE], zmm0 | k2);
377                         else
378                             vmovups(ptr[CO2 + offset * SIZE], zmm0 | k2);
379                         break;
380                     case 3:
381                         if (useCO1)
382                             vmovups(ptr[CO1 + offset * SIZE], zmm0 | k3);
383                         else
384                             vmovups(ptr[CO2 + offset * SIZE], zmm0 | k3);
385                         break;
386                     }
387                 } else {
388                     switch (mask) {
389                     case 0:
390                         if (useCO1)
391                             vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0);
392                         else
393                             vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0);
394                         break;
395                     case 1:
396                         if (useCO1)
397                             vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k1);
398                         else
399                             vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k1);
400                         break;
401                     case 2:
402                         if (useCO1)
403                             vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k2);
404                         else
405                             vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k2);
406                         break;
407                     case 3:
408                         if (useCO1)
409                             vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k3);
410                         else
411                             vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k3);
412                         break;
413                     }
414                 }
415             } else {
                    // beta == 0: store the scaled accumulator directly without
                    // reading the destination.
416                 if (!useScale) {
417                     switch (mask) {
418                     case 0:
419                         if (useCO1)
420                             vmovups(ptr[CO1 + offset * SIZE], reg);
421                         else
422                             vmovups(ptr[CO2 + offset * SIZE], reg);
423                         break;
424                     case 1:
425                         if (useCO1)
426                             vmovups(ptr[CO1 + offset * SIZE], reg | k1);
427                         else
428                             vmovups(ptr[CO2 + offset * SIZE], reg | k1);
429                         break;
430                     case 2:
431                         if (useCO1)
432                             vmovups(ptr[CO1 + offset * SIZE], reg | k2);
433                         else
434                             vmovups(ptr[CO2 + offset * SIZE], reg | k2);
435                         break;
436                     case 3:
437                         if (useCO1)
438                             vmovups(ptr[CO1 + offset * SIZE], reg | k3);
439                         else
440                             vmovups(ptr[CO2 + offset * SIZE], reg | k3);
441                         break;
442                     }
443                 } else {
444                     switch (mask) {
445                     case 0:
446                         if (useCO1)
447                             vmovups(ptr[CO1 + LDC + offset * SIZE], reg);
448                         else
449                             vmovups(ptr[CO2 + LDC + offset * SIZE], reg);
450                         break;
451                     case 1:
452                         if (useCO1)
453                             vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k1);
454                         else
455                             vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k1);
456                         break;
457                     case 2:
458                         if (useCO1)
459                             vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k2);
460                         else
461                             vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k2);
462                         break;
463                     case 3:
464                         if (useCO1)
465                             vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k3);
466                         else
467                             vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k3);
468                         break;
469                     }
470                 }
471             }
                // Reset the accumulator to zero for its next use.
472             vpxorq(reg, reg, reg);
473         };
474
475         // Loop with unroll_n - 2 FMAs; called by innerkernel
476         auto fmaloop = [&](int unroll_m, int unroll_n, int iteration) {
                // Emits, for every i in [2, unroll_n), the fused multiply-adds
                //   regs[i]      += zmm0 * b_i
                //   regs[i + 8]  += zmm1 * b_i   (only when unroll_m >= 32)
                //   regs[i + 16] += zmm2 * b_i   (only when unroll_m >= 48)
                // where b_i is a single float read from B and broadcast to all
                // lanes. Indices 0 and 1 are not handled by this lambda.
                //
                //   unroll_m  - selects how many of zmm0 / zmm1 / zmm2 take part.
                //   unroll_n  - exclusive upper bound for i. regs has 24 entries, so
                //               i + 16 stays in range only for unroll_n <= 8, and the
                //               switches below cover i = 2..7 only; callers are
                //               presumably bound by UNROLL_N - verify.
                //   iteration - element displacement added to the B address; used
                //               only when B is not transposed.
                //
                // Address of b_i when B is not transposed:
                //   i = 2: BO1 + LDB * 2      i = 3: BO1 + LDB3
                //   i = 4: BO2                i = 5: BO2 + LDB
                //   i = 6: BO2 + LDB * 2      i = 7: BO2 + LDB3
                // (LDB3 is presumably 3 * LDB and BO2 presumably BO1 + 4 * LDB;
                // both are set up outside this lambda - confirm.)
                // When B is transposed the address is BO1 + (i - OFFSET) * SIZE.
477             for (int i = 2; i < unroll_n; i++) {
478                 if (ver == ver_avx512_core) {
                        // avx512_core: broadcast b_i into zmm3 once, then issue
                        // register-only FMAs.
479                     if (!isTransB) {
480                         switch (i) {
481                         case 2:
482                             vbroadcastss(
483                                     zmm3,
484                                     ptr[BO1 + LDB * 2
485                                             + (iteration - OFFSET) * SIZE]);
486                             break;
487                         case 3:
488                             vbroadcastss(
489                                     zmm3,
490                                     ptr[BO1 + LDB3
491                                             + (iteration - OFFSET) * SIZE]);
492                             break;
493                         case 4:
494                             vbroadcastss(zmm3,
495                                     ptr[BO2 + (iteration - OFFSET) * SIZE]);
496                             break;
497                         case 5:
498                             vbroadcastss(
499                                     zmm3,
500                                     ptr[BO2 + LDB * 1
501                                             + (iteration - OFFSET) * SIZE]);
502                             break;
503                         case 6:
504                             vbroadcastss(
505                                     zmm3,
506                                     ptr[BO2 + LDB * 2
507                                             + (iteration - OFFSET) * SIZE]);
508                             break;
509                         case 7:
510                             vbroadcastss(
511                                     zmm3,
512                                     ptr[BO2 + LDB3
513                                             + (iteration - OFFSET) * SIZE]);
514                             break;
515                         }
516                     } else {
517                         vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]);
518                     }
519                     vfmadd231ps(regs[i], zmm3, zmm0);
520                     if (unroll_m >= 32)
521                         vfmadd231ps(regs[i + 8], zmm3, zmm1);
522                     if (unroll_m >= 48)
523                         vfmadd231ps(regs[i + 16], zmm3, zmm2);
524                 } else {
                        // Other ISA version (ver_avx512_mic): no separate
                        // broadcast; every FMA takes b_i as a zword_b memory
                        // operand instead.
525                     if (!isTransB) {
526                         switch (i) {
527                         case 2:
528                             vfmadd231ps(regs[i], zmm0,
529                                     zword_b[BO1 + LDB * 2
530                                     + (iteration - OFFSET) * SIZE]);
531                             if (unroll_m >= 32)
532                                 vfmadd231ps(regs[i + 8], zmm1,
533                                         zword_b[BO1 + LDB * 2
534                                         + (iteration - OFFSET) * SIZE]);
535                             if (unroll_m >= 48)
536                                 vfmadd231ps(regs[i + 16], zmm2,
537                                         zword_b[BO1 + LDB * 2
538                                         + (iteration - OFFSET) * SIZE]);
539                             break;
540                         case 3:
541                             vfmadd231ps(regs[i], zmm0,
542                                     zword_b[BO1 + LDB3
543                                     + (iteration - OFFSET) * SIZE]);
544                             if (unroll_m >= 32)
545                                 vfmadd231ps(regs[i + 8], zmm1,
546                                         zword_b[BO1 + LDB3
547                                         + (iteration - OFFSET) * SIZE]);
548                             if (unroll_m >= 48)
549                                 vfmadd231ps(regs[i + 16], zmm2,
550                                         zword_b[BO1 + LDB3
551                                         + (iteration - OFFSET) * SIZE]);
552                             break;
553                         case 4:
554                             vfmadd231ps(regs[i], zmm0,
555                                     zword_b[BO2 + (iteration - OFFSET) * SIZE]);
556                             if (unroll_m >= 32)
557                                 vfmadd231ps(regs[i + 8], zmm1,
558                                         zword_b[BO2 + (iteration - OFFSET) * SIZE]);
559                             if (unroll_m >= 48)
560                                 vfmadd231ps(regs[i + 16], zmm2,
561                                         zword_b[BO2 + (iteration - OFFSET) * SIZE]);
562                             break;
563                         case 5:
564                             vfmadd231ps(regs[i], zmm0,
565                                     zword_b[BO2 + LDB * 1
566                                     + (iteration - OFFSET) * SIZE]);
567                             if (unroll_m >= 32)
568                                 vfmadd231ps(regs[i + 8], zmm1,
569                                         zword_b[BO2 + LDB * 1
570                                         + (iteration - OFFSET) * SIZE]);
571                             if (unroll_m >= 48)
572                                 vfmadd231ps(regs[i + 16], zmm2,
573                                         zword_b[BO2 + LDB * 1
574                                         + (iteration - OFFSET) * SIZE]);
575                             break;
576                         case 6:
577                             vfmadd231ps(regs[i], zmm0,
578                                     zword_b[BO2 + LDB * 2
579                                     + (iteration - OFFSET) * SIZE]);
580                             if (unroll_m >= 32)
581                                 vfmadd231ps(regs[i + 8], zmm1,
582                                         zword_b[BO2 + LDB * 2
583                                         + (iteration - OFFSET) * SIZE]);
584                             if (unroll_m >= 48)
585                                 vfmadd231ps(regs[i + 16], zmm2,
586                                         zword_b[BO2 + LDB * 2
587                                         + (iteration - OFFSET) * SIZE]);
588                             break;
589                         case 7:
590                             vfmadd231ps(regs[i], zmm0,
591                                     zword_b[BO2 + LDB3
592                                     + (iteration - OFFSET) * SIZE]);
593                             if (unroll_m >= 32)
594                                 vfmadd231ps(regs[i + 8], zmm1,
595                                         zword_b[BO2 + LDB3
596                                         + (iteration - OFFSET) * SIZE]);
597                             if (unroll_m >= 48)
598                                 vfmadd231ps(regs[i + 16], zmm2,
599                                         zword_b[BO2 + LDB3
600                                         + (iteration - OFFSET) * SIZE]);
601                             break;
602                         }
603                     } else {
604                         vfmadd231ps(
605                                 regs[i], zmm0, zword_b[BO1 + (i - OFFSET) * SIZE]);
606                         if (unroll_m >= 32)
607                             vfmadd231ps(regs[i + 8], zmm1,
608                                     zword_b[BO1 + (i - OFFSET) * SIZE]);
609                         if (unroll_m >= 48)
610                             vfmadd231ps(regs[i + 16], zmm2,
611                                     zword_b[BO1 + (i - OFFSET) * SIZE]);
612                     }
613                 }
614             }
615         };
616
617         // Innerkernel; called by kernel
618         auto innerkernel = [&](int unroll_m, int unroll_n, bool isDirect,
619                 bool isCopy, bool doCPrefetch, bool isUnmasked = true) {
620             for (int i = 0; i < 8; i++) {
621                 if (!isDirect) {
622                     prefetcht0(ptr[AO1
623                             + (PREFETCHSIZEA + i * unroll_m + 0 * 16 - OFFSET)
624                                     * SIZE]);
625                     if (unroll_m >= 32)
626                         prefetcht0(ptr[AO1
627                             + (PREFETCHSIZEA + i * unroll_m + 1 * 16 - OFFSET)
628                                     * SIZE]);
629                     if (unroll_m >= 48)
630                         prefetcht0(ptr[AO1
631                             + (PREFETCHSIZEA + i * unroll_m + 2 * 16 - OFFSET)
632                                     * SIZE]);
633                 } else {
634                     prefetcht0(ptr[AO1 + LDA4 + (16 * 0 * SIZE)]);
635                     if (unroll_m >= 32)
636                         prefetcht0(ptr[AO1 + LDA4 + (16 * 1 * SIZE)]);
637                     if (unroll_m >= 48)
638                         prefetcht0(ptr[AO1 + LDA4 + (16 * 2 * SIZE)]);
639                 }
640
641                 if (!isDirect) {
642                     if (i != 0) {
643                         if (isUnmasked || unroll_m > 16) {
644                             vmovups(zmm0,
645                                     ptr[AO1
646                                             + (unroll_m * i + 0 * 16 - OFFSET)
647                                                     * SIZE]);
648                         } else {
649                             vmovups(zmm0 | k1 | T_z,
650                                     ptr[AO1
651                                             + (unroll_m * i + 0 * 16 - OFFSET)
652                                                     * SIZE]);
653                         }
654                         if (unroll_m >= 32) {
655                             if (isUnmasked || unroll_m > 32) {
656                                 vmovups(zmm1, ptr[AO1
657                                                       + (unroll_m * i + 1 * 16
658                                                                 - OFFSET)
659                                                               * SIZE]);
660                             } else {
661                                 vmovups(zmm1 | k2 | T_z,
662                                         ptr[AO1
663                                                 + (unroll_m * i + 1 * 16
664                                                           - OFFSET)
665                                                         * SIZE]);
666                             }
667                         }
668                         if (unroll_m >= 48) {
669                             if (isUnmasked) {
670                                 vmovups(zmm2, ptr[AO1
671                                                       + (unroll_m * i + 2 * 16
672                                                                 - OFFSET)
673                                                               * SIZE]);
674                             } else {
675                                 vmovups(zmm2 | k3 | T_z,
676                                         ptr[AO1
677                                                 + (unroll_m * i + 2 * 16
678                                                           - OFFSET)
679                                                         * SIZE]);
680                             }
681                         }
682                     }
683                 } else {
684                     if (isUnmasked || unroll_m > 16) {
685                         vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
686                     } else {
687                         vmovups(zmm0 | k1 | T_z,
688                                 ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
689                     }
690                     if (unroll_m >= 32) {
691                         if (isUnmasked || unroll_m > 32) {
692                             vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
693                         } else {
694                             vmovups(zmm1 | k2 | T_z,
695                                     ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
696                         }
697                     }
698                     if (unroll_m >= 48) {
699                         if (isUnmasked) {
700                             vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
701                         } else {
702                             vmovups(zmm2 | k3 | T_z,
703                                     ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
704                         }
705                     }
706                     add(AO1, LDA);
707                 }
708
709                 if (ver == ver_avx512_core) {
710                     if (!isTransB) {
711                         vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]);
712                     } else {
713                         vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]);
714                     }
715                     vfmadd231ps(regs[0], zmm3, zmm0);
716                     if (unroll_m >= 32)
717                         vfmadd231ps(regs[0 + 8], zmm3, zmm1);
718                     if (unroll_m >= 48)
719                         vfmadd231ps(regs[0 + 16], zmm3, zmm2);
720                 } else {
721                     if (!isTransB) {
722                         vfmadd231ps(regs[0], zmm0,
723                                 zword_b[BO1 + (i - OFFSET) * SIZE]);
724                         if (unroll_m >= 32)
725                             vfmadd231ps(regs[0 + 8], zmm1,
726                                     zword_b[BO1 + (i - OFFSET) * SIZE]);
727                         if (unroll_m >= 48)
728                             vfmadd231ps(regs[0 + 16], zmm2,
729                                     zword_b[BO1 + (i - OFFSET) * SIZE]);
730                     } else {
731                         vfmadd231ps(regs[0], zmm0,
732                                 zword_b[BO1 + (0 - OFFSET) * SIZE]);
733                         if (unroll_m >= 32)
734                             vfmadd231ps(regs[0 + 8], zmm1,
735                                     zword_b[BO1 + (0 - OFFSET) * SIZE]);
736                         if (unroll_m >= 48)
737                             vfmadd231ps(regs[0 + 16], zmm2,
738                                     zword_b[BO1 + (0 - OFFSET) * SIZE]);
739                     }
740                 }
741
742                 if (unroll_n >= i + 1) {
743                     if (!isTransB) {
744                         switch (i) {
745                         case 0:
746                             prefetcht0(
747                                     ptr[BO1 + (PREFETCHSIZEB - OFFSET) * SIZE]);
748                             break;
749                         case 1:
750                             prefetcht0(ptr[BO1 + LDB
751                                     + (PREFETCHSIZEB - OFFSET) * SIZE]);
752                             break;
753                         case 2:
754                             prefetcht0(ptr[BO1 + LDB * 2
755                                     + (PREFETCHSIZEB - OFFSET) * SIZE]);
756                             break;
757                         case 3:
758                             prefetcht0(ptr[BO1 + LDB3
759                                     + (PREFETCHSIZEB - OFFSET) * SIZE]);
760                             break;
761                         case 4:
762                             prefetcht0(
763                                     ptr[BO2 + (PREFETCHSIZEB - OFFSET) * SIZE]);
764                             break;
765                         case 5:
766                             prefetcht0(ptr[BO2 + LDB
767                                     + (PREFETCHSIZEB - OFFSET) * SIZE]);
768                             break;
769                         case 6:
770                             prefetcht0(ptr[BO2 + LDB * 2
771                                     + (PREFETCHSIZEB - OFFSET) * SIZE]);
772                             break;
773                         case 7:
774                             prefetcht0(ptr[BO2 + LDB3
775                                     + (PREFETCHSIZEB - OFFSET) * SIZE]);
776                             break;
777                         }
778                     }
779                 }
780
781                 if (unroll_n >= 2) {
782                     if (ver == ver_avx512_core) {
783                         if (!isTransB) {
784                             vbroadcastss(zmm3,
785                                     ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
786                         } else {
787                             vbroadcastss(zmm3, ptr[BO1 + (1 - OFFSET) * SIZE]);
788                         }
789                         vfmadd231ps(regs[1], zmm3, zmm0);
790                         if (unroll_m >= 32)
791                             vfmadd231ps(regs[1 + 8], zmm3, zmm1);
792                         if (unroll_m >= 48)
793                             vfmadd231ps(regs[1 + 16], zmm3, zmm2);
794                     } else {
795                         if (!isTransB) {
796                             vfmadd231ps(regs[1], zmm0,
797                                     zword_b[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
798                             if (unroll_m >= 32)
799                                 vfmadd231ps(regs[1 + 8], zmm1,
800                                         zword_b[BO1 + LDB * 1
801                                         + (i - OFFSET) * SIZE]);
802                             if (unroll_m >= 48)
803                                 vfmadd231ps(regs[1 + 16], zmm2,
804                                         zword_b[BO1 + LDB * 1
805                                         + (i - OFFSET) * SIZE]);
806                         } else {
807                             vfmadd231ps(regs[1], zmm0,
808                                     zword_b[BO1 + (1 - OFFSET) * SIZE]);
809                             if (unroll_m >= 32)
810                                 vfmadd231ps(regs[1 + 8], zmm1,
811                                         zword_b[BO1 + (1 - OFFSET) * SIZE]);
812                             if (unroll_m >= 48)
813                                 vfmadd231ps(regs[1 + 16], zmm2,
814                                         zword_b[BO1 + (1 - OFFSET) * SIZE]);
815                         }
816                     }
817                 }
818
819                 if (isCopy) {
820                     if (isUnmasked || unroll_m > 16) {
821                         vmovups(ptr[LDA4
822                                         + (unroll_m * i + 0 * 16 - OFFSET)
823                                                 * SIZE],
824                                 zmm0);
825                     } else {
826                         vmovups(ptr[LDA4
827                                         + (unroll_m * i + 0 * 16 - OFFSET)
828                                                 * SIZE],
829                                 zmm0 | k1);
830                     }
831                     if (unroll_m >= 32) {
832                         if (isUnmasked || unroll_m > 32) {
833                             vmovups(ptr[LDA4
834                                             + (unroll_m * i + 1 * 16 - OFFSET)
835                                                     * SIZE],
836                                     zmm1);
837                         } else {
838                             vmovups(ptr[LDA4
839                                             + (unroll_m * i + 1 * 16 - OFFSET)
840                                                     * SIZE],
841                                     zmm1 | k2);
842                         }
843                     }
844                     if (unroll_m >= 48) {
845                         if (isUnmasked) {
846                             vmovups(ptr[LDA4
847                                             + (unroll_m * i + 2 * 16 - OFFSET)
848                                                     * SIZE],
849                                     zmm2);
850                         } else {
851                             vmovups(ptr[LDA4
852                                             + (unroll_m * i + 2 * 16 - OFFSET)
853                                                     * SIZE],
854                                     zmm2 | k3);
855                         }
856                     }
857                     if (i == 7)
858                         sub(LDA4, -unroll_m * 8 * SIZE);
859                 }
860                 fmaloop(unroll_m, unroll_n, i);
861
862                 if (i == 1) {
863                     if (doCPrefetch) {
864                         if (ver == ver_avx512_core)
865                             prefetchw(ptr[CO2 + 0 * 16 * SIZE]);
866                         else
867                             prefetcht0(ptr[CO2 + 0 * 16 * SIZE]);
868                     }
869                 }
870                 if (i == 3) {
871                     if (doCPrefetch && unroll_m >= 32) {
872                         if (ver == ver_avx512_core)
873                             prefetchw(ptr[CO2 + 1 * 16 * SIZE]);
874                         else
875                             prefetcht0(ptr[CO2 + 1 * 16 * SIZE]);
876                     }
877                     if (!isTransA) {
878                         if (ver == ver_avx512_core)
879                             prefetcht0(ptr[AA + 16 * 0 * SIZE]);
880                         else
881                             prefetcht2(ptr[AA + 16 * 0 * SIZE]);
882                     }
883                 }
884                 if (i == 5) {
885                     if (doCPrefetch) {
886                         if (unroll_m >= 48) {
887                             if (ver == ver_avx512_core)
888                                 prefetchw(ptr[CO2 + 2 * 16 * SIZE]);
889                             else
890                                 prefetcht0(ptr[CO2 + 2 * 16 * SIZE]);
891                         }
892                         add(CO2, LDC);
893                     }
894                     if (!isTransA) {
895                         if (unroll_m >= 32) {
896                             if (ver == ver_avx512_core)
897                                 prefetcht0(ptr[AA + 16 * 1 * SIZE]);
898                             else
899                                 prefetcht2(ptr[AA + 16 * 1 * SIZE]);
900                         }
901                     }
902                 }
903
904                 if (isTransB) {
905                     prefetcht0(ptr[BO1 + BO2]);
906                     add(BO1, LDB);
907                 }
908             } // end of for loop
909
910             if (!isTransB) {
911                 sub(BO1, -8 * SIZE);
912                 if (unroll_n >= 4)
913                     sub(BO2, -8 * SIZE);
914             }
915             if (!isTransA) {
916                 if (unroll_m >= 48) {
917                     if (ver == ver_avx512_core)
918                         prefetcht0(ptr[AA + 16 * 2 * SIZE]);
919                     else
920                         prefetcht2(ptr[AA + 16 * 2 * SIZE]);
921                 }
922                 lea(AA, ptr[AA + LDA]);
923             }
924
925             if (!isDirect) {
926                 if (isUnmasked || unroll_m > 16) {
927                     vmovups(zmm0,
928                             ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]);
929                 } else {
930                     vmovups(zmm0 | k1 | T_z,
931                             ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]);
932                 }
933                 if (unroll_m >= 32) {
934                     if (isUnmasked || unroll_m > 32) {
935                         vmovups(zmm1, ptr[AO1
936                                               + (unroll_m * 8 + 1 * 16 - OFFSET)
937                                                       * SIZE]);
938                     } else {
939                         vmovups(zmm1 | k2 | T_z,
940                                 ptr[AO1
941                                         + (unroll_m * 8 + 1 * 16 - OFFSET)
942                                                 * SIZE]);
943                     }
944                 }
945                 if (unroll_m >= 48) {
946                     if (isUnmasked) {
947                         vmovups(zmm2, ptr[AO1
948                                               + (unroll_m * 8 + 2 * 16 - OFFSET)
949                                                       * SIZE]);
950                     } else {
951                         vmovups(zmm2 | k3 | T_z,
952                                 ptr[AO1
953                                         + (unroll_m * 8 + 2 * 16 - OFFSET)
954                                                 * SIZE]);
955                     }
956                 }
957                 sub(AO1, -unroll_m * 8 * SIZE);
958             }
959
960             sub(LL, 1);
961         };
962
963         // Main kernel; does prefetching and calls innerkernel
964         // After calculating results in registers, writes back to C matrix by
965         // calling update
            //
            // This lambda does not compute anything itself: each call emits the
            // machine code for one unroll_m x unroll_n tile of C. All arguments
            // are code-generation-time constants that select which instructions
            // are emitted.
            //
            //   unroll_m   - rows of the tile. zmm0 is always used for A; zmm1
            //                is added when unroll_m >= 32 and zmm2 when
            //                unroll_m >= 48, with accumulators regs[i],
            //                regs[i + 8] and regs[i + 16] respectively.
            //   unroll_n   - columns of the tile (one accumulator column each).
            //   isDirect   - true: A is read through AO1 = A and AO1 advances by
            //                LDA per k step. false: A is read from the buffer at
            //                rsp + 128 and AO1 advances by unroll_m * SIZE per
            //                k step.
            //   isCopy     - true: every A vector loaded here is also stored to
            //                the buffer at rsp + 128 through LDA4.
            //   isUnmasked - false: the highest 16-row group in use is loaded
            //                (zero-masked) and stored under its opmask
            //                (k1, k2 or k3); lower groups are always full.
966         auto kernel = [&](int unroll_m, int unroll_n, bool isDirect,
967                 bool isCopy, bool isUnmasked = true) {
                // AO1 is the read pointer for A. In the buffered case it carries
                // a +OFFSET * SIZE bias that the "- OFFSET" term in every
                // displacement below cancels out again.
968             if (!isDirect) {
969                 lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]);
970             } else {
971                 mov(AO1, A);
972             }
973
                // When copying, LDA4 is the write pointer into that same buffer.
                // Otherwise it is loaded with a byte distance derived from LDA.
                // NOTE(review): presumably an A prefetch distance consumed by the
                // part of innerkernel that is not visible here - confirm.
974             if (isCopy) {
975                 lea(LDA4, ptr[rsp + 128 + OFFSET * SIZE]);
976             } else {
977                 auto step = ver == ver_avx512_core ? 2 : 4;
978                 lea(LDA4, ptr[LDA * step + (16 - 1 - OFFSET) * SIZE]);
979             }
980
                // With transposed B, BO2 holds a byte offset rather than a
                // pointer: innerkernel issues prefetcht0(ptr[BO1 + BO2]).
981             if (isTransB) {
982                 lea(BO2, ptr[LDB * 4 + (16 / 2 - 1 - OFFSET) * SIZE]);
983             }
984
                // Buffered A: preload the vectors for the first k step. The
                // loops below always load the vectors for the *next* step at
                // their end, so this primes that pipeline.
985             if (!isDirect) {
986                 if (isUnmasked || unroll_m > 16) {
987                     vmovups(zmm0,
988                             ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]);
989                 } else {
990                     vmovups(zmm0 | k1 | T_z,
991                             ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]);
992                 }
993                 if (unroll_m >= 32) {
994                     if (isUnmasked || unroll_m > 32) {
995                         vmovups(zmm1, ptr[AO1
996                                               + (unroll_m * 0 + 1 * 16 - OFFSET)
997                                                       * SIZE]);
998                     } else {
999                         vmovups(zmm1 | k2 | T_z,
1000                                 ptr[AO1
1001                                         + (unroll_m * 0 + 1 * 16 - OFFSET)
1002                                                 * SIZE]);
1003                     }
1004                 }
1005                 if (unroll_m >= 48) {
1006                     if (isUnmasked) {
1007                         vmovups(zmm2, ptr[AO1
1008                                               + (unroll_m * 0 + 2 * 16 - OFFSET)
1009                                                       * SIZE]);
1010                     } else {
1011                         vmovups(zmm2 | k3 | T_z,
1012                                 ptr[AO1
1013                                         + (unroll_m * 0 + 2 * 16 - OFFSET)
1014                                                 * SIZE]);
1015                     }
1016                 }
1017             }
1018
1019             Label kernel12, kernel13, kernel14, kernel15, kernel16, kernel18;
1020
                 // LL = K / 8 counts innerkernel invocations, run in two phases.
                 // Phase 1 (kernel12) covers all but the last SECOND_FETCH
                 // (= unroll_n) of them and passes false as the fifth
                 // innerkernel argument; it is skipped when K / 8 <= unroll_n.
                 // NOTE(review): the fifth argument is presumably doCPrefetch -
                 // the innerkernel signature is not visible here.
                 // innerkernel ends with sub(LL, 1); the jg after each call
                 // branches on the flags left by that instruction.
1021             mov(LL, K);
1022             sar(LL, 3);
1023             sub(LL, SECOND_FETCH);
1024             jle(kernel13, T_NEAR);
1025             align(16);
1026
1027             L(kernel12);
1028             innerkernel(
1029                     unroll_m, unroll_n, isDirect, isCopy, false, isUnmasked);
1030             jg(kernel12, T_NEAR);
1031             align(16);
1032
                 // Phase 2 (kernel14): add unroll_n back so that the remaining
                 // min(K / 8, unroll_n) invocations run with the fifth argument
                 // true. CO2 is pointed into the C tile first; innerkernel
                 // prefetches through CO2 and steps it by LDC in that mode.
1033             L(kernel13);
1034             lea(CO2, ptr[CO1 + (16 - 1) * SIZE]);
1035             add(LL, unroll_n);
1036             jle(kernel15, T_NEAR);
1037             align(16);
1038
1039             L(kernel14);
1040             innerkernel(unroll_m, unroll_n, isDirect, isCopy, true, isUnmasked);
1041             jg(kernel14, T_NEAR);
1042             align(16);
1043
                 // Remainder loop (kernel16): K % 8 single k steps. Skipped
                 // entirely when K is a multiple of 8.
1044             L(kernel15);
1045             mov(LL, K);
1046             and_(LL, 7);
1047             jle(kernel18, T_NEAR);
1048             align(16);
1049
1050             L(kernel16);
                 // Direct A: load this k step's vectors now and advance AO1 by
                 // LDA. (Buffered A was already loaded at the end of the
                 // previous step or by the preload above.)
1051             if (isDirect) {
1052                 if (isUnmasked || unroll_m > 16) {
1053                     vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
1054                 } else {
1055                     vmovups(zmm0 | k1 | T_z,
1056                             ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
1057                 }
1058                 if (unroll_m >= 32) {
1059                     if (isUnmasked || unroll_m > 32) {
1060                         vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
1061                     } else {
1062                         vmovups(zmm1 | k2 | T_z,
1063                                 ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
1064                     }
1065                 }
1066                 if (unroll_m >= 48) {
1067                     if (isUnmasked) {
1068                         vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
1069                     } else {
1070                         vmovups(zmm2 | k3 | T_z,
1071                                 ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
1072                     }
1073                 }
1074                 add(AO1, LDA);
1075             }
1076
                 // For each column i of the tile: broadcast one B element into
                 // zmm3 and accumulate A * B into that column's accumulators.
                 // Non-transposed B: column i is BO1 + i * LDB for i < 4 and
                 // BO2 + (i - 4) * LDB for i >= 4, which is why BO2 exists as a
                 // second base pointer. Transposed B: the elements are
                 // contiguous at BO1 + i * SIZE.
                 // NOTE(review): LDB3 is presumably 3 * LDB - its definition is
                 // not visible here.
1077             for (int i = 0; i < unroll_n; i++) {
1078                 if (!isTransB) {
1079                     switch (i) {
1080                     case 0:
1081                         vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]);
1082                         break;
1083                     case 1:
1084                         vbroadcastss(
1085                                 zmm3, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]);
1086                         break;
1087                     case 2:
1088                         vbroadcastss(
1089                                 zmm3, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]);
1090                         break;
1091                     case 3:
1092                         vbroadcastss(
1093                                 zmm3, ptr[BO1 + LDB3 + (0 - OFFSET) * SIZE]);
1094                         break;
1095                     case 4:
1096                         vbroadcastss(zmm3, ptr[BO2 + (0 - OFFSET) * SIZE]);
1097                         break;
1098                     case 5:
1099                         vbroadcastss(
1100                                 zmm3, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]);
1101                         break;
1102                     case 6:
1103                         vbroadcastss(
1104                                 zmm3, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]);
1105                         break;
1106                     case 7:
1107                         vbroadcastss(
1108                                 zmm3, ptr[BO2 + LDB3 + (0 - OFFSET) * SIZE]);
1109                         break;
1110                     }
1111                 } else {
1112                     vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]);
1113                 }
1114                 vfmadd231ps(regs[i], zmm3, zmm0);
1115                 if (unroll_m >= 32) {
1116                     vfmadd231ps(regs[i + 8], zmm3, zmm1);
1117                 }
1118                 if (unroll_m >= 48) {
1119                     vfmadd231ps(regs[i + 16], zmm3, zmm2);
1120                 }
1121             }
1122
                 // Copy mode: store the A vectors just used into the buffer and
                 // advance the write pointer by one k step (unroll_m floats).
                 // sub with a negated immediate is an add.
1123             if (isCopy) {
1124                 if (isUnmasked || unroll_m > 16) {
1125                     vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE],
1126                             zmm0);
1127                 } else {
1128                     vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE],
1129                             zmm0 | k1);
1130                 }
1131                 if (unroll_m >= 32) {
1132                     if (isUnmasked || unroll_m > 32) {
1133                         vmovups(ptr[LDA4
1134                                         + (unroll_m * 0 + 1 * 16 - OFFSET)
1135                                                 * SIZE],
1136                                 zmm1);
1137                     } else {
1138                         vmovups(ptr[LDA4
1139                                         + (unroll_m * 0 + 1 * 16 - OFFSET)
1140                                                 * SIZE],
1141                                 zmm1 | k2);
1142                     }
1143                 }
1144                 if (unroll_m >= 48) {
1145                     if (isUnmasked) {
1146                         vmovups(ptr[LDA4
1147                                         + (unroll_m * 0 + 2 * 16 - OFFSET)
1148                                                 * SIZE],
1149                                 zmm2);
1150                     } else {
1151                         vmovups(ptr[LDA4
1152                                         + (unroll_m * 0 + 2 * 16 - OFFSET)
1153                                                 * SIZE],
1154                                 zmm2 | k3);
1155                     }
1156                 }
1157                 sub(LDA4, -unroll_m * SIZE);
1158             }
1159
                 // Buffered A: load the vectors of the next k step (one
                 // unroll_m-float row further on) and advance AO1 to it.
1160             if (!isDirect) {
1161                 if (isUnmasked || unroll_m > 16) {
1162                     vmovups(zmm0,
1163                             ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]);
1164                 } else {
1165                     vmovups(zmm0 | k1 | T_z,
1166                             ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]);
1167                 }
1168                 if (unroll_m >= 32) {
1169                     if (isUnmasked || unroll_m > 32) {
1170                         vmovups(zmm1, ptr[AO1
1171                                               + (unroll_m * 1 + 1 * 16 - OFFSET)
1172                                                       * SIZE]);
1173                     } else {
1174                         vmovups(zmm1 | k2 | T_z,
1175                                 ptr[AO1
1176                                         + (unroll_m * 1 + 1 * 16 - OFFSET)
1177                                                 * SIZE]);
1178                     }
1179                 }
1180                 if (unroll_m >= 48) {
1181                     if (isUnmasked) {
1182                         vmovups(zmm2, ptr[AO1
1183                                               + (unroll_m * 1 + 2 * 16 - OFFSET)
1184                                                       * SIZE]);
1185                     } else {
1186                         vmovups(zmm2 | k3 | T_z,
1187                                 ptr[AO1
1188                                         + (unroll_m * 1 + 2 * 16 - OFFSET)
1189                                                 * SIZE]);
1190                     }
1191                 }
1192                 sub(AO1, -unroll_m * SIZE);
1193             }
1194
                 // Advance B by one k step: one element forward within each
                 // column (non-transposed) or one LDB stride (transposed).
1195             if (!isTransB) {
1196                 sub(BO1, -SIZE);
1197                 if (unroll_n >= 4) {
1198                     sub(BO2, -SIZE);
1199                 }
1200             } else {
1201                 add(BO1, LDB);
1202             }
1203
                 // The jg consumes the flags of the sub directly above it.
1204             sub(LL, 1);
1205             jg(kernel16, T_NEAR);
1206             align(16);
1207
                 // All k steps done. Broadcast the scalars ahead of the update
                 // calls below. NOTE(review): VALPHA / VBETA are presumably read
                 // inside update, which is defined outside this view.
1208             L(kernel18);
1209             vbroadcastss(VALPHA, ALPHA);
1210
1211             if (isBetaN) {
1212                 vbroadcastss(VBETA, BETA);
1213             }
1214
1215             // Write back the results; all beta cases need to be handled
                 // Bias: one vector per 16-row group in use, loaded with the same
                 // masking rule as the A vectors.
1216             if (hasBias) {
1217                 mov(BIAS1, BIAS);
1218                 if (isUnmasked || unroll_m > 16)
1219                     vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]);
1220                 else
1221                     vmovups(VBIAS1 | k1 | T_z, ptr[BIAS1 + 0 * SIZE]);
1222                 if (unroll_m >= 32) {
1223                     if (isUnmasked || unroll_m > 32)
1224                         vmovups(VBIAS2, ptr[BIAS1 + 16 * SIZE]);
1225                     else
1226                         vmovups(VBIAS2 | k2 | T_z, ptr[BIAS1 + 16 * SIZE]);
1227                 }
1228                 if (unroll_m >= 48) {
1229                     if (isUnmasked)
1230                         vmovups(VBIAS3, ptr[BIAS1 + 32 * SIZE]);
1231                     else
1232                         vmovups(VBIAS3 | k3 | T_z, ptr[BIAS1 + 32 * SIZE]);
1233                 }
1234             }
1235
                 // Hand every accumulator to update, column by column. Columns
                 // 0-1 use CO1 (useCO1); from column 2 on CO2 is moved forward
                 // by 2 * LDC before each even column. useScale is set for odd
                 // columns. The third argument is the row offset of the 16-row
                 // group (0 / 16 / 32); the fourth is 0 for a full group and
                 // 1 / 2 / 3 when the group is handled under k1 / k2 / k3.
                 // NOTE(review): presumably useScale makes update address the
                 // second column of the pair (base + LDC), and the fourth
                 // argument selects the opmask - confirm against update.
1236             for (int i = 0; i < unroll_n; i++) {
1237                 bool useScale = i % 2 != 0;
1238                 bool useCO1 = i < 2;
1239                 if (i == 2)
1240                     lea(CO2, ptr[CO1 + LDC * 2]);
1241                 if (i == 4 || i == 6)
1242                     lea(CO2, ptr[CO2 + LDC * 2]);
1243                 if (hasBias)
1244                     vaddps(regs[i], VBIAS1, regs[i]);
1245                 if (isUnmasked || unroll_m > 16) {
1246                     update(regs[i], useCO1, 0, 0, useScale);
1247                 } else {
1248                     update(regs[i], useCO1, 0, 1, useScale);
1249                 }
1250                 if (unroll_m >= 32) {
1251                     if (hasBias)
1252                         vaddps(regs[i + 8], VBIAS2, regs[i + 8]);
1253                     if (isUnmasked || unroll_m > 32) {
1254                         update(regs[i + 8], useCO1, 16, 0, useScale);
1255                     } else {
1256                         update(regs[i + 8], useCO1, 16, 2, useScale);
1257                     }
1258                 }
1259                 if (unroll_m >= 48) {
1260                     if (hasBias)
1261                         vaddps(regs[i + 16], VBIAS3, regs[i + 16]);
1262                     if (isUnmasked) {
1263                         update(regs[i + 16], useCO1, 32, 0, useScale);
1264                     } else {
1265                         update(regs[i + 16], useCO1, 32, 3, useScale);
1266                     }
1267                 }
1268             }
1269
                 // Move CO1 past the unroll_n columns just handled. For
                 // unroll_n >= 3 this is expressed relative to CO2, which the
                 // loop above left at the start of the last column pair; every
                 // case resolves to CO1 + unroll_n * LDC.
1270             switch (unroll_n) {
1271             case 1: add(CO1, LDC); break;
1272             case 2: lea(CO1, ptr[CO1 + LDC * 2]); break;
1273             case 3: lea(CO1, ptr[CO2 + LDC * 1]); break;
1274             case 4: lea(CO1, ptr[CO2 + LDC * 2]); break;
1275             case 5: lea(CO1, ptr[CO2 + LDC * 1]); break;
1276             case 6: lea(CO1, ptr[CO2 + LDC * 2]); break;
1277             case 7: lea(CO1, ptr[CO2 + LDC * 1]); break;
1278             case 8: lea(CO1, ptr[CO2 + LDC * 2]); break;
1279             }
1280
1281             // Compute next address of B
                 // Non-transposed: step BO1 and BO2 forward by unroll_n columns
                 // (unroll_n * LDB, built from the addressing forms available)
                 // and rewind the K * SIZE bytes walked during the k loops.
                 // Transposed: rewind the K strides of LDB walked during the
                 // k loops, then step unroll_n elements forward.
1282             if (!isTransB) {
1283                 lea(rax, ptr[K * SIZE]);
1284                 switch (unroll_n) {
1285                 case 1:
1286                     add(BO1, LDB);
1287                     add(BO2, LDB);
1288                     break;
1289                 case 2:
1290                     lea(BO1, ptr[BO1 + LDB * 2]);
1291                     lea(BO2, ptr[BO2 + LDB * 2]);
1292                     break;
1293                 case 3:
1294                     lea(BO1, ptr[BO1 + LDB3]);
1295                     lea(BO2, ptr[BO2 + LDB3]);
1296                     break;
1297                 case 4:
1298                     lea(BO1, ptr[BO1 + LDB * 4]);
1299                     lea(BO2, ptr[BO2 + LDB * 4]);
1300                     break;
1301                 case 5:
1302                     lea(BO1, ptr[BO1 + LDB * 4]);
1303                     add(BO1, LDB);
1304                     lea(BO2, ptr[BO2 + LDB * 4]);
1305                     add(BO2, LDB);
1306                     break;
1307                 case 6:
1308                     lea(BO1, ptr[BO1 + LDB3 * 2]);
1309                     lea(BO2, ptr[BO2 + LDB3 * 2]);
1310                     break;
1311                 case 7:
1312                     lea(BO1, ptr[BO1 + LDB * 8]);
1313                     sub(BO1, LDB);
1314                     lea(BO2, ptr[BO2 + LDB * 8]);
1315                     sub(BO2, LDB);
1316                     break;
1317                 case 8:
1318                     lea(BO1, ptr[BO1 + LDB * 8]);
1319                     lea(BO2, ptr[BO2 + LDB * 8]);
1320                     break;
1321                 }
1322                 sub(BO1, rax);
1323                 sub(BO2, rax);
1324             } else {
1325                 mov(rax, LDB);
1326                 imul(rax, K);
1327                 sub(BO1, rax);
1328                 add(BO1, unroll_n * SIZE);
1329             }
1330         };
1331
1332         // High-level subroutine; does packing if needed, then splits C matrix.
1333         // Operates on chunks of 48 rows, 8 columns at a time (handling tail
1334         // cases appropriately by doing 32 or 16 rows, and/or with masking,
1335         // and/or fewer columns).
1336         auto subloop = [&](int unroll_m) {
1337             Label l_subloop_20x[8], l_subloop_mask_20x[8];
1338             Label l_subloop_30x[8], l_subloop_mask_30x[8];
1339
1340             Label subloop11, subloop11mask;
1341             Label subloop30, subloop30mask;
1342             Label subloop31, subloop31mask;
1343             Label subloop96;
1344             Label subloop98, subloop98mask;
1345             Label subloop99;
1346
1347             // Create mask
1348             mov(BO1, rcx);
1349             mov(rcx, M);
1350             sub(rcx, unroll_m - 16);
1351             mov(CO1, 16);
1352             cmp(rcx, 16);
1353
1354             cmovg(rcx, CO1);
1355             mov(rax, 1);
1356             sal(rax, cl);
1357             sub(rax, 1);
1358             mov(rcx, 0xffff);
1359
1360             if (unroll_m == 16) {
1361                 kmovw(k1, eax);
1362             } else if (unroll_m == 32) {
1363                 kmovw(k1, ecx);
1364                 kmovw(k2, eax);
1365             } else {
1366                 kmovw(k1, ecx);
1367                 kmovw(k2, ecx);
1368                 kmovw(k3, eax);
1369             }
1370             mov(rcx, BO1);
1371
1372             and_(rax, 0xffff);
1373             cmp(rax, 0xffff);
1374             jne(subloop96, T_NEAR);
1375
1376             if (isTransA) {
1377                 do_pack(unroll_m);
1378             }
1379
1380             mov(CO1, C);
1381             add(C, unroll_m * SIZE);
1382
1383             mov(BO1, B);
1384             if (!isTransB) {
1385                 lea(BO2, ptr[B + LDB * 4]);
1386             }
1387
1388             if (!isTransA) {
1389                 lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]);
1390                 cmp(M, UNROLL_M);
1391                 jg(subloop98, T_NEAR);
1392
1393                 mov(AA, ORIG_A);
1394                 lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]);
1395                 L(subloop98);
1396             }
1397
1398             mov(LL, N);
1399             mov(I, LL);
1400             if (!isTransA) {
1401                 // If N is too small, skip copy operation
1402                 cmp(LL, UNROLL_N * 3);
1403                 jle(subloop30, T_NEAR);
1404
1405                 // If A is not aligned to cache line
1406                 cmp(FLAG, 0);
1407                 je(subloop30, T_NEAR);
1408             } else {
1409                 cmp(LL, UNROLL_N);
1410                 jl(l_subloop_20x[1], T_NEAR);
1411             }
1412             align(16);
1413
1414             if (!isTransA) {
1415                 kernel(unroll_m, UNROLL_N, true, true);
1416             } else {
1417                 kernel(unroll_m, UNROLL_N, false, false);
1418             }
1419
1420             sub(I, UNROLL_N);
1421             cmp(I, UNROLL_N);
1422             jl(l_subloop_20x[1], T_NEAR);
1423             align(16);
1424
1425             L(subloop11);
1426             kernel(unroll_m, UNROLL_N, false, false);
1427             sub(I, UNROLL_N);
1428             cmp(I, UNROLL_N);
1429             jge(subloop11, T_NEAR);
1430             align(16);
1431
1432             for (int i = 1; i <= 7; i++) {
1433                 L(l_subloop_20x[i]);
1434                 cmp(I, i);
1435                 if (i < 7) {
1436                     jne(l_subloop_20x[i + 1], T_NEAR);
1437                 } else {
1438                     jne(subloop99, T_NEAR);
1439                 }
1440                 kernel(unroll_m, i, false, false);
1441                 jmp(subloop99, T_NEAR);
1442                 align(16);
1443             }
1444
1445             if (!isTransA) {
1446                 L(subloop30);
1447                 cmp(I, UNROLL_N);
1448                 jl(l_subloop_30x[1], T_NEAR);
1449                 align(16);
1450
1451                 L(subloop31);
1452                 kernel(unroll_m, UNROLL_N, true, false);
1453                 sub(I, UNROLL_N);
1454                 cmp(I, UNROLL_N);
1455                 jge(subloop31, T_NEAR);
1456                 align(16);
1457
1458                 for (int i = 1; i <= 7; i++) {
1459                     L(l_subloop_30x[i]);
1460                     cmp(I, i);
1461                     if (i < 7) {
1462                         jne(l_subloop_30x[i + 1], T_NEAR);
1463                     } else {
1464                         jne(subloop99, T_NEAR);
1465                     }
1466                     kernel(unroll_m, i, true, false);
1467                     if (i < 7)
1468                         jmp(subloop99, T_NEAR);
1469                     align(16);
1470                 }
1471             }
1472             jmp(subloop99, T_NEAR);
1473             align(16);
1474
1475             L(subloop96);
1476             if (isTransA) {
1477                 do_pack(unroll_m);
1478             }
1479
1480             mov(CO1, C);
1481             add(C, unroll_m * SIZE);
1482             mov(BO1, B);
1483             if (!isTransB) {
1484                 lea(BO2, ptr[B + LDB * 4]);
1485             }
1486
1487             if (!isTransA) {
1488                 lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]);
1489                 cmp(M, UNROLL_M);
1490                 jg(subloop98mask, T_NEAR);
1491                 mov(AA, ORIG_A);
1492                 lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]);
1493                 L(subloop98mask);
1494             }
1495
1496             mov(LL, N);
1497             mov(I, LL);
1498             if (!isTransA) {
1499                 // If N is too small, skip copy operation
1500                 cmp(LL, UNROLL_N * 3);
1501                 jle(subloop30mask, T_NEAR);
1502
1503                 // If A is not aligned to cache line
1504                 cmp(FLAG, 0);
1505                 je(subloop30mask, T_NEAR);
1506             } else {
1507                 cmp(LL, UNROLL_N);
1508                 jl(l_subloop_mask_20x[1], T_NEAR);
1509             }
1510             align(16);
1511
1512             if (!isTransA) {
1513                 kernel(unroll_m, UNROLL_N, true, true, false);
1514             } else {
1515                 kernel(unroll_m, UNROLL_N, false, false, false);
1516             }
1517
1518             sub(I, UNROLL_N);
1519             cmp(I, UNROLL_N);
1520             jl(l_subloop_mask_20x[1], T_NEAR);
1521             align(16);
1522
1523             L(subloop11mask);
1524             kernel(unroll_m, UNROLL_N, false, false, false);
1525             sub(I, UNROLL_N);
1526             cmp(I, UNROLL_N);
1527             jge(subloop11mask, T_NEAR);
1528             align(16);
1529
1530             for (int i = 1; i <= 7; i++) {
1531                 L(l_subloop_mask_20x[i]);
1532                 cmp(I, i);
1533                 if (i < 7) {
1534                     jne(l_subloop_mask_20x[i + 1], T_NEAR);
1535                 } else {
1536                     jne(subloop99, T_NEAR);
1537                 }
1538                 kernel(unroll_m, i, false, false, false);
1539                 jmp(subloop99, T_NEAR);
1540                 align(16);
1541             }
1542
1543             if (!isTransA) {
1544                 L(subloop30mask);
1545                 cmp(I, UNROLL_N);
1546                 jl(l_subloop_mask_30x[1], T_NEAR);
1547                 align(16);
1548
1549                 L(subloop31mask);
1550                 kernel(unroll_m, UNROLL_N, true, false, false);
1551                 sub(I, UNROLL_N);
1552                 cmp(I, UNROLL_N);
1553                 jge(subloop31mask, T_NEAR);
1554                 align(16);
1555
1556                 for (int i = 1; i <= 7; i++) {
1557                     L(l_subloop_mask_30x[i]);
1558                     cmp(I, i);
1559                     if (i < 7) {
1560                         jne(l_subloop_mask_30x[i + 1], T_NEAR);
1561                     } else {
1562                         jne(subloop99, T_NEAR);
1563                     }
1564                     kernel(unroll_m, i, true, false, false);
1565                     if (i < 7)
1566                         jmp(subloop99, T_NEAR);
1567                     align(16);
1568                 }
1569             }
1570
1571             L(subloop99);
1572             // Compute address for A
1573             if (!isTransA) {
1574                 add(A, unroll_m * SIZE);
1575             } else {
1576                 mov(rax, LDA);
1577                 imul(rax, rax, unroll_m);
1578                 add(A, rax);
1579             }
1580
1581             // Compute next address of BIAS
1582             if (hasBias) {
1583                 add(BIAS, unroll_m * SIZE);
1584             }
1585         };
1586
1587         preamble();
1588
1589         Label buffer_in_ws, buffer_allocated;
1590
1591         // Get the registers
1592         mov(B, ARG_B);
1593         mov(LDB, ARG_LDB);
1594         mov(r15, ARG_BETA);
1595         mov(r12, ARG_C);
1596         if (hasBias)
1597             mov(r10, ARG_BIAS);
1598         mov(LDC, ARG_LDC);
1599         mov(rbp, rsp);
1600
1601         vmovss(xmm0, ptr[ARG_ALPHA]);
1602         vmovss(xmm1, ptr[r15]);
1603
1604 #if _WIN32
1605         mov(A, ARG_A);
1606         mov(LDA, ARG_LDA);
1607 #endif
1608
1609         cmp(K, STACK_K_CAPACITY);
1610         jg(buffer_in_ws, T_NEAR);
1611
1612         // Create buffer and align to 4kB page
1613         lea(rax, ptr[K * SIZE]);
1614         imul(rax, rax, 0x30);
1615         add(rax, 256);
1616         sub(rsp, rax);
1617         and_(rsp, -PAGE_4K);
1618         jmp(buffer_allocated, T_NEAR);
1619
1620         L(buffer_in_ws);
1621         mov(rsp, ARG_WS);
1622
1623         L(buffer_allocated);
1624
1625         mov(ORIG_SP, rbp);
1626         mov(M, ARG_M);
1627         mov(N, ARG_N);
1628         mov(C, r12);
1629         if (hasBias)
1630             mov(BIAS, r10);
1631         vmovss(ALPHA, xmm0);
1632         vmovss(BETA, xmm1);
1633         sub(A, -OFFSET * SIZE);
1634         sub(B, -OFFSET * SIZE);
1635         mov(ORIG_A, A);
1636         sal(LDA, BASE_SHIFT);
1637         sal(LDB, BASE_SHIFT);
1638         sal(LDC, BASE_SHIFT);
1639         lea(LDB3, ptr[LDB + LDB * 2]);
1640
1641         if (isTransA) {
1642             vpbroadcastq(zmm2, LDA);
1643             vpxorq(ZSTRIDE, ZSTRIDE, ZSTRIDE);
1644             mov(rax, -2);
1645             kmovw(k4, eax);
1646
1647             for (int i = 0; i < 6; i++) {
1648                 vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2);
1649                 kshiftlw(k4, k4, 1);
1650             }
1651             vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2);
1652         }
1653
1654         // Check A alignment and leading dimension; take copy-based path as
1655         // needed
1656         mov(rax, LDA);
1657         or_(rax, A);
1658         and_(rax, ver == ver_avx512_core ? 0x07 : 0x3f);
1659         mov(FLAG, rax);
1660
1661         for (int i = 8; i < 16; i++) {
1662             for (int j = 0; j < 3; j++) {
1663                 vpxorq(Zmm(i + 8 * j), Zmm(i + 8 * j), Zmm(i + 8 * j));
1664             }
1665         }
1666
1667         Label main0, main1, main2, main999;
1668
1669         cmp(M, 32);
1670         jle(main0, T_NEAR);
1671         align(16);
1672
1673         L(main1);
1674         subloop(48);
1675         sub(M, UNROLL_M);
1676         cmp(M, 32);
1677         jg(main1, T_NEAR);
1678         align(16);
1679
1680         L(main0);
1681         cmp(M, 16);
1682         jle(main2, T_NEAR);
1683
1684         subloop(32);
1685         jmp(main999, T_NEAR);
1686         align(16);
1687
1688         L(main2);
1689         cmp(M, 0);
1690         jle(main999, T_NEAR);
1691         subloop(16);
1692         align(16);
1693
1694         L(main999);
1695         // Restore original stack
1696         mov(rsp, ORIG_SP);
1697
1698         vzeroupper();
1699         postamble();
1700
1701         ker_ = reinterpret_cast<decltype(ker_)>(
1702                 const_cast<uint8_t *>(this->getCode()));
1703     }
1704
1705     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_gemm_f32_xbyak_gemm)
1706
1707     void operator()(long long int m, long long int n, long long int k,
1708             const float *alpha, const float *a, long long int lda,
1709             const float *b, long long int ldb, const float *beta, float *c,
1710             long long int ldc, const float *bias, float *ws)
1711     {
1712         (*ker_)(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws);
1713     }
1714
1715 private:
1716     void (*ker_)(long long int m, long long int n, long long int k,
1717             const float *alpha, const float *a, long long int lda,
1718             const float *b, long long int ldb, const float *beta, float *c,
1719             long long int ldc, const float *bias, float *ws);
1720 };
1721
// Function-pointer type taking thirteen positional arguments: three 64-bit
// integers, two float pointers, an integer, a float pointer, an integer, two
// float pointers, an integer and two float pointers.
// NOTE(review): the positions line up with xbyak_gemm::ker_ (m, n, k, alpha,
// a, lda, b, ldb, beta, c, ldc, bias, ws) but the pointers here are non-const;
// no use of this alias is visible in this part of the file -- confirm before
// relying on it.
using ker = void (*)(long long, long long, long long, float *, float *,
        long long, float *, long long, float *, float *, long long, float *,
        float *);
1725 void jit_avx512_common_gemm_f32::sgemm_nocopy_driver(const char *transa,
1726         const char *transb, int m, int n, int k, const float *alpha,
1727         const float *a, int lda, const float *b, int ldb, const float *beta,
1728         float *c, int ldc, const float *bias, float *ws)
1729 {
1730     bool isTransA = (*transa == 'T' || *transa == 't');
1731     bool isTransB = (*transb == 'T' || *transb == 't');
1732
1733     int Bm, sizeM, Bn, sizeN, Bk, sizeK;
1734
1735     int i, j;
1736
1737     if ((m <= 0) || (n <= 0))
1738         return;
1739
1740     if ((k <= 0) || (alpha[0] == 0.)) {
1741
1742         if (beta[0] == 0.) {
1743             for (j = 0; j < n; j++)
1744                 for (i = 0; i < m; i++)
1745                     c[i + j * ldc] = 0.0;
1746         } else if (beta[0] != 1.) {
1747             for (j = 0; j < n; j++)
1748                 for (i = 0; i < m; i++)
1749                     c[i + j * ldc] *= beta[0];
1750         }
1751
1752         return;
1753     }
1754
1755     int BM = 4032, BN, BK;
1756     if (mayiuse(avx512_core)) {
1757         BN = isTransA ? 384 : 64;
1758         BK = 384;
1759     } else {
1760         BN = isTransA ? 96 : 64;
1761         BK = isTransB ? 96 : 192;
1762         if (!isTransA && !isTransB)
1763             BK = 128;
1764     }
1765     const float *curA, *curB, *curBias = nullptr;
1766     float *curC;
1767
1768     for (Bk = 0; Bk < k; Bk += sizeK) {
1769         sizeK = k - Bk;
1770         if (sizeK >= BK * 2)
1771             sizeK = BK;
1772         else {
1773             if (sizeK > BK)
1774                 sizeK = (sizeK + 1) / 2;
1775         }
1776
1777         for (Bm = 0; Bm < m; Bm += sizeM) {
1778             sizeM = m - Bm;
1779             if (sizeM >= BM * 2)
1780                 sizeM = BM;
1781             else {
1782                 if (sizeM > BM + BM / 2)
1783                     sizeM = (sizeM + 1) / 2;
1784             }
1785
1786             for (Bn = 0; Bn < n; Bn += sizeN) {
1787                 sizeN = n - Bn;
1788                 if (sizeN >= BN * 2)
1789                     sizeN = BN;
1790                 else {
1791                     if (sizeN > BN + BN / 2)
1792                         sizeN = (sizeN + 1) / 2;
1793                 }
1794
1795                 if (!isTransA) {
1796                     curA = a + Bm + (size_t)Bk * lda;
1797                 } else {
1798                     curA = a + Bk + (size_t)Bm * lda;
1799                 }
1800                 if (!isTransB) {
1801                     curB = b + Bk + (size_t)Bn * ldb;
1802                 } else {
1803                     curB = b + Bn + (size_t)Bk * ldb;
1804                 }
1805                 curC = c + Bm + (size_t)Bn * ldc;
1806                 if (bias != nullptr) {
1807                     if (Bk == 0) {
1808                         curBias = bias + Bm;
1809                     } else {
1810                         curBias = nullptr;
1811                     }
1812                 }
1813                 if (Bk == 0) {
1814                     if (*beta == 0.0 && bias == nullptr)
1815                         (*ker_b0_)((long long int)sizeM, (long long int)sizeN,
1816                                 (long long int)sizeK, alpha, curA,
1817                                 (long long int)lda, curB, (long long int)ldb,
1818                                 beta, curC, (long long int)ldc, curBias, ws);
1819                     else
1820                         (*ker_bn_)((long long int)sizeM, (long long int)sizeN,
1821                                 (long long int)sizeK, alpha, curA,
1822                                 (long long int)lda, curB, (long long int)ldb,
1823                                 beta, curC, (long long int)ldc, curBias, ws);
1824                 } else {
1825                     (*ker_b1_)((long long int)sizeM, (long long int)sizeN,
1826                             (long long int)sizeK, alpha, curA,
1827                             (long long int)lda, curB, (long long int)ldb, beta,
1828                             curC, (long long int)ldc, curBias, ws);
1829                 }
1830             }
1831         }
1832     }
1833     return;
1834 }
1835
1836 void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb,
1837         const int *p_m, const int *p_n, const int *p_k, const float *p_alpha,
1838         const float *A, const int *p_lda, const float *B, const int *p_ldb,
1839         const float *p_beta, float *C, const int *p_ldc, const float *bias)
1840 {
1841     if (beta_ == 0. || beta_ == 1.)
1842         assert(*p_beta == beta_);
1843     assert((one_of(*transa, 'T', 't') == one_of(transa_, 'T', 't')));
1844
1845     int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
1846     int m = *p_m;
1847     int n = *p_n;
1848     int k = *p_k;
1849     int lda = *p_lda;
1850     int ldb = *p_ldb;
1851     int ldc = *p_ldc;
1852     float beta = *p_beta;
1853     int MB, NB, KB;
1854
1855     int nthr_m, nthr_n, nthr_k, nthr_mn;
1856
1857     assert(nthr <= nthrs_);
1858
1859     // Determine threading partitioning
1860     gemm_utils::calc_nthr_nocopy_avx512_common(
1861             m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
1862     assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
1863
1864     // May not happen, but just in case
1865     if (nthr < nthr_m * nthr_n * nthr_k)
1866         nthr = nthr_m * nthr_n * nthr_k;
1867
1868     nthr_mn = nthr_m * nthr_n;
1869
1870     unsigned char * ompstatus_ = nullptr;
1871     unsigned char volatile *ompstatus = nullptr;
1872
1873     float *c_buffers = nullptr;
1874     float *ws_buffers = nullptr;
1875
1876     if (nthr_k > 1) {
1877         ompstatus_ = (unsigned char *) malloc(
1878                 nthr * CACHE_LINE_SIZE,
1879                 CACHE_LINE_SIZE);
1880         ompstatus = (unsigned char volatile *) ompstatus_;
1881         assert(ompstatus);
1882         for (int i = 0; i < nthr; i++)
1883             ompstatus[i * CACHE_LINE_SIZE] = 0;
1884
1885         c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB
1886                 * sizeof(float), PAGE_4K);
1887     }
1888
1889     const size_t ws_elems_per_thr = k * 48 + 64;
1890     const size_t ws_size_per_thr
1891             = utils::rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
1892     if (k > STACK_K_CAPACITY) {
1893         ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K);
1894     }
1895
1896     parallel(nthr, [&](const int ithr, const int nthr) {
1897         int ithr_m, ithr_n, ithr_k, ithr_mn;
1898         int m_from, m_to, myM;
1899         int n_from, n_to, myN;
1900         int k_from, k_to, myK;
1901         int cbase, ibase;
1902         const float *myA, *myB, *myBias = nullptr;
1903         float *myC = C, myBeta;
1904         float *ws = ws_buffers ?
1905                 ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0;
1906         int ld = ldc;
1907
1908         if (ithr < nthr_m * nthr_n * nthr_k) {
1909
1910             ithr_mn = ithr % nthr_mn;
1911             ithr_m = ithr_mn % nthr_m;
1912             ithr_n = ithr_mn / nthr_m;
1913             ithr_k = ithr / nthr_mn;
1914
1915             /* swap ithr_k for performance improvement */
1916             if (ithr_k == 0)
1917                 ithr_k = nthr_k - 1;
1918             else if (ithr_k == nthr_k - 1)
1919                 ithr_k = 0;
1920
1921             m_from = MB * (ithr_m);
1922             m_to = MB * (ithr_m + 1);
1923             if (m_to > m)
1924                 m_to = m;
1925             myM = m_to - m_from;
1926
1927             n_from = NB * (ithr_n);
1928             n_to = NB * (ithr_n + 1);
1929             if (n_to > n)
1930                 n_to = n;
1931             myN = n_to - n_from;
1932
1933             k_from = KB * (ithr_k);
1934             k_to = KB * (ithr_k + 1);
1935             if (k_to > k)
1936                 k_to = k;
1937             myK = k_to - k_from;
1938
1939             cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
1940             ibase = (ithr_m + nthr_m * ithr_n) * nthr_k;
1941
1942             if ((myM > 0) && (myN > 0)) {
1943
1944                 if (*transa == 'N' || *transa == 'n') {
1945                     myA = &(A[m_from + k_from * lda]);
1946                 } else {
1947                     myA = &(A[k_from + m_from * lda]);
1948                 }
1949                 if (*transb == 'N' || *transb == 'n') {
1950                     myB = &(B[k_from + n_from * ldb]);
1951                 } else {
1952                     myB = &(B[n_from + k_from * ldb]);
1953                 }
1954                 if (ithr_k == 0) {
1955                     myC = &(C[m_from + n_from * ldc]);
1956                     myBeta = beta;
1957                     ld = ldc;
1958                     if (hasBias_)
1959                         myBias = &(bias[m_from]);
1960                 } else {
1961                     myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
1962                     myBeta = 0.0;
1963                     ld = MB;
1964                     myBias = nullptr;
1965                 }
1966
1967                 sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA,
1968                         lda, myB, ldb, &myBeta, myC, ld, myBias, ws);
1969
1970                 if (nthr_k > 1)
1971                     ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1;
1972             }
1973
1974             if (nthr_k > 1) {
1975
1976                 // sum matrices partitioned along K dimension
1977                 int n1, n2;
1978
1979                 gemm_utils::partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
1980
1981                 if (ithr_k > 0) {
1982
1983                     myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
1984                     myC = myC + n1 * MB;
1985                     /* need to wait until main thread finishes */
1986                     while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) {
1987                     };
1988
1989                     /* my cache is hot */
1990                     gemm_utils::sum_two_matrices(myM, n2, myC, MB,
1991                             &C[m_from + (n_from + n1) * ldc], ldc);
1992                 }
1993
1994                 for (int ik = 1; ik < nthr_k; ++ik) {
1995                     if (ik != ithr_k) {
1996
1997                         myC = c_buffers + MB * NB * (cbase + ik - 1);
1998                         myC = myC + n1 * MB;
1999
2000                         while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) {
2001                         };
2002
2003                         gemm_utils::sum_two_matrices(myM, n2, myC, MB,
2004                                 &C[m_from + (n_from + n1) * ldc], ldc);
2005                     }
2006                 }
2007             }
2008         }
2009     });
2010
2011     free(c_buffers);
2012     free(ompstatus_);
2013     free(ws_buffers);
2014 }
2015
2016 jit_avx512_common_gemm_f32::jit_avx512_common_gemm_f32(
2017         char transa, char transb, float beta, bool hasBias)
2018 {
2019     transa_ = transa;
2020     transb_ = transb;
2021     beta_ = beta;
2022     hasBias_ = hasBias;
2023     if (hasBias) {
2024         assert(beta == 0.0);
2025     }
2026     ker_bn_ = new xbyak_gemm(transa, transb, beta, hasBias);
2027     if (beta != 1.0) {
2028         ker_b1_ = new xbyak_gemm(transa, transb, 1.0);
2029     } else {
2030         ker_b1_ = ker_bn_;
2031     }
2032     if (beta != 0.0 || (beta == 0.0 && hasBias)) {
2033         ker_b0_ = new xbyak_gemm(transa, transb, 0.0);
2034     } else {
2035         ker_b0_ = ker_bn_;
2036     }
2037
2038     nthrs_ = mkldnn_get_max_threads();
2039 }
2040
2041 jit_avx512_common_gemm_f32::~jit_avx512_common_gemm_f32()
2042 {
2043     delete ker_bn_;
2044     if (beta_ != 1.0)
2045         delete ker_b1_;
2046     if (beta_ != 0.0 || (beta_ == 0.0 && hasBias_))
2047         delete ker_b0_;
2048 }
2049 }
2050 }
2051 }
2052
2053 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s