Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm / f32 / jit_avx512_common_gemm_f32.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <cmath>
18 #include <mutex>
19
20 #include "mkldnn_thread.hpp"
21 #include "utils.hpp"
22
23 #include "ref_gemm_f32.hpp"
24 #include "gemm_utils_f32.hpp"
25 #include "jit_avx512_common_gemm_f32.hpp"
26
27 #include "jit_generator.hpp"
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 #define CACHE_LINE_SIZE 64
34
   // Expands to a call to get_size_of_abi_save_regs(); the generator adds it to
   // the rsp-relative offsets of the stack-passed arguments.
35 #define STACKSIZE get_size_of_abi_save_regs()
   // Platform-dependent K capacity: 32 when _WIN32 is defined, 2048 otherwise.
   // NOTE(review): presumably bounds how many k steps fit in the on-stack
   // packing buffer -- confirm where it is used.
36 #ifdef _WIN32
37 #define STACK_K_CAPACITY 32
38 #else
39 #define STACK_K_CAPACITY 2048
40 #endif
   // Multiplier that turns element offsets into byte offsets in address
   // expressions (4 bytes per single-precision element).
41 #define SIZE 4
   // Element bias: address expressions are written as (index - OFFSET) * SIZE,
   // and do_pack pre-biases its destination pointer by OFFSET * SIZE.
42 #define OFFSET 128
43 #define BASE_SHIFT 2
   // Expands to the identifier unroll_n, so it is only usable where a variable
   // of that name is in scope.
44 #define SECOND_FETCH unroll_n
   // NOTE(review): presumably the maximum tile extents handed to the lambdas
   // below as unroll_m / unroll_n -- confirm against the kernel dispatch code.
45 #define UNROLL_M 48
46 #define UNROLL_N 8
47
48 namespace avx512_common_gemm_f32 {
49 using namespace gemm_utils;
50
51 struct xbyak_gemm : public jit_generator {
52     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_gemm_f32_xbyak_gemm)
53
54     xbyak_gemm(char isTransA, char isTransB, float beta, bool hasBias = false,
55             void *code_ptr = nullptr,
56             size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE)
57         : jit_generator(code_ptr, code_size)
58     {
59         using namespace Xbyak;
60
61         enum { ver_avx512_core, ver_avx512_mic } ver =
62             mayiuse(avx512_core) ? ver_avx512_core : ver_avx512_mic;
63
64         bool isBeta0 = (beta == 0.0);
65         bool isBetaN = (!isBeta0 && beta != 1.0);
66
67         // various definitions for convenience
68         auto ARG_M = abi_param1;
69         auto ARG_N = abi_param2;
70         auto K = abi_param3;
71         auto ARG_ALPHA = abi_param4;
72 #ifdef _WIN32
73         auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE];
74         auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE +
75             sizeof(float *) + STACKSIZE];
76         const auto stackOffset = OFFSET_SHADOWSPACE +
77             sizeof(float *) + STACKSIZE;
78         auto A = rsi;
79         auto LDA = rdi;
80 #else
81         auto ARG_A = r8;
82         auto ARG_LDA = r9;
83         const auto stackOffset = STACKSIZE;
84         auto A = ARG_A;
85         auto LDA = ARG_LDA;
86 #endif
87         auto ARG_B = ptr[rsp + 8 + stackOffset];
88         auto ARG_LDB = ptr[rsp + 16 + stackOffset];
89         auto ARG_BETA = ptr[rsp + 24 + stackOffset];
90         auto ARG_C = ptr[rsp + 32 + stackOffset];
91         auto ARG_LDC = ptr[rsp + 40 + stackOffset];
92         auto ARG_BIAS = ptr[rsp + 48 + stackOffset];
93         auto ARG_WS = ptr[rsp + 56 + stackOffset];
94
95         auto B = r11;
96         auto LDB = rbx;
97         auto LDC = r13;
98         auto LL = rax;
99         auto AO1 = abi_param2;
100         auto BO1 = abi_param4;
101         auto BO2 = rbp;
102         auto CO1 = r14;
103         auto CO2 = r15;
104         auto LDB3 = r10;
105         auto LDA4 = abi_param1;
106         auto AA = r12;
107         auto BIAS1 = abi_param1;
108
109         auto M = qword[rsp + 0];
110         auto N = qword[rsp + 8];
111         auto FLAG = qword[rsp + 16];
112         auto I = qword[rsp + 24];
113         auto C = qword[rsp + 32];
114         auto BIAS = qword[rsp + 40];
115         auto ALPHA = qword[rsp + 48];
116         auto BETA = qword[rsp + 64];
117         auto ORIG_A = qword[rsp + 80];
118         auto ORIG_SP = qword[rsp + 120];
119
120         auto ZSTRIDE = zmm4;
121         auto VALPHA = zmm6;
122         auto VBETA = zmm7;
123         auto VBIAS1 = zmm1;
124         auto VBIAS2 = zmm2;
125         auto VBIAS3 = zmm3;
126
127         auto PREFETCHSIZEA = ver == ver_avx512_core ? 48 : 80;
128         auto PREFETCHSIZEB = 16;
129
130         Zmm regs[] = { zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15,
131             zmm16, zmm17, zmm18, zmm19, zmm20, zmm21, zmm22, zmm23, zmm24,
132             zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31 };
133
134         // Function for packing if needed
            //
            // Emits the code that copies a panel of A into the rsp-relative
            // scratch buffer so that every k step occupies unroll_m consecutive
            // floats. K is processed four steps at a time (K >> 2 iterations of
            // pack2), then one step at a time for the remaining K & 3 (pack4).
            //
            // unroll_m: panel width; one, two or three 16-float vectors are
            //           moved per k step (thresholds: > 16 and > 32).
            //
            // Lanes 0-15 / 16-31 / 32-47 are guarded by the opmasks k1 / k2 /
            // k3, which are set up by code outside this lambda.
            // Registers written: BO1, AO1, LL, zmm0-zmm2 and, for transposed A,
            // BO2, k4, zmm5 and zmm6.
            // NOTE(review): zmm6 is also VALPHA, so the transposed path
            // overwrites it -- confirm the caller (re)loads VALPHA after packing.
135         auto do_pack = [&](int unroll_m) {
136             Label pack2, pack3, pack4, pack10;
137
                // BO1 walks the source (A); AO1 walks the packed buffer, which
                // starts at rsp + 128. AO1 is pre-biased by OFFSET * SIZE because
                // every access below subtracts OFFSET from its element index.
138             mov(BO1, A);
139             lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]);
                // LL = K / 4; skip the unrolled loop when it is not positive.
140             mov(LL, K);
141             sar(LL, 2);
142             jle(pack3, T_NEAR);
143             align(16);
144
                // Main loop: four k steps per iteration.
145             L(pack2);
146             if (!isTransA) {
                    // Non-transposed A: the unroll_m elements of one k step are
                    // contiguous, so plain masked vector moves suffice; BO1
                    // advances by LDA after each k step.
147                 for (int i = 0; i < 4; i++) {
148                     vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]);
149                     if (unroll_m > 16)
150                         vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]);
151                     if (unroll_m > 32)
152                         vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]);
153                     add(BO1, LDA);
154
155                     vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE]
156                                     | k1,
157                             zmm0);
158                     if (unroll_m > 16)
159                         vmovups(ptr[AO1
160                                         + (unroll_m * i + 1 * 16 - OFFSET)
161                                                 * SIZE]
162                                         | k2,
163                                 zmm1);
164                     if (unroll_m > 32)
165                         vmovups(ptr[AO1
166                                         + (unroll_m * i + 2 * 16 - OFFSET)
167                                                 * SIZE]
168                                         | k3,
169                                 zmm2);
170                 }
171             } else {
                    // Transposed A: the elements of one k step are strided, so
                    // each 16-float vector is built from two 8-lane gathers
                    // indexed by ZSTRIDE (low half into ymm5, high half into
                    // ymm6) and then joined with vshuff64x2 (imm 0x44 keeps the
                    // low 256 bits of each source).
                    // k4 is a scratch copy of the lane mask, reloaded before
                    // every gather (presumably because the gather consumes its
                    // mask -- confirm against the ISA reference).
                    // Each lea moves BO2 a further 8 * LDA along A.
172                 for (int i = 0; i < 4; i++) {
173                     kmovw(k4, k1);
174                     vgatherqps(ymm5 | k4,
175                             ptr[BO1 + ZSTRIDE + (i - OFFSET) * SIZE]);
176                     lea(BO2, ptr[BO1 + LDA * 8]);
                        // Upper eight bits of the lane mask guard the second half.
177                     kshiftrw(k4, k1, 8);
178                     vgatherqps(ymm6 | k4,
179                             ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
180                     vshuff64x2(zmm0, zmm5, zmm6, 0x44);
181
182                     if (unroll_m > 16) {
183                         lea(BO2, ptr[BO2 + LDA * 8]);
184                         kmovw(k4, k2);
185                         vgatherqps(ymm5 | k4,
186                                 ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
187                         lea(BO2, ptr[BO2 + LDA * 8]);
188                         kshiftrw(k4, k2, 8);
189                         vgatherqps(ymm6 | k4,
190                                 ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
191                         vshuff64x2(zmm1, zmm5, zmm6, 0x44);
192                     }
193
194                     if (unroll_m > 32) {
195                         lea(BO2, ptr[BO2 + LDA * 8]);
196                         kmovw(k4, k3);
197                         vgatherqps(ymm5 | k4,
198                                 ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
199                         lea(BO2, ptr[BO2 + LDA * 8]);
200                         kshiftrw(k4, k3, 8);
201                         vgatherqps(ymm6 | k4,
202                                 ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
                            // NOTE(review): BO2 is not read again inside do_pack
                            // after this lea; it looks redundant -- confirm that
                            // no later code depends on the value.
203                         lea(BO2, ptr[BO2 + LDA * 8]);
204                         vshuff64x2(zmm2, zmm5, zmm6, 0x44);
205                     }
206
207                     vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE],
208                             zmm0 | k1);
209                     if (unroll_m > 16)
210                         vmovups(ptr[AO1
211                                         + (unroll_m * i + 1 * 16 - OFFSET)
212                                                 * SIZE],
213                                 zmm1 | k2);
214                     if (unroll_m > 32)
215                         vmovups(ptr[AO1
216                                         + (unroll_m * i + 2 * 16 - OFFSET)
217                                                 * SIZE],
218                                 zmm2 | k3);
219                 }
                    // The four k steps were addressed through the displacement
                    // i, so the source pointer moves on by four floats here.
220                 add(BO1, 4 * SIZE);
221             }
                // Four k steps of unroll_m floats each were written.
222             add(AO1, unroll_m * 4 * SIZE);
223
224             sub(LL, 1);
225             jg(pack2, T_NEAR);
226             align(16);
227
                // Remainder: K & 3 single k steps.
228             L(pack3);
229             mov(LL, K);
230             and_(LL, 3);
231             jle(pack10, T_NEAR);
232             align(16);
233
                // Same load / gather sequence as above for a single k step.
234             L(pack4);
235             if (!isTransA) {
236                 vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]);
237                 if (unroll_m > 16)
238                     vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]);
239                 if (unroll_m > 32)
240                     vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]);
241                 add(BO1, LDA);
242             } else {
243                 kmovw(k4, k1);
244                 vgatherqps(ymm5 | k4, ptr[BO1 + ZSTRIDE + (0 - OFFSET) * SIZE]);
245                 lea(BO2, ptr[BO1 + LDA * 8]);
246                 kshiftrw(k4, k1, 8);
247                 vgatherqps(ymm6 | k4, ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
248                 vshuff64x2(zmm0, zmm5, zmm6, 0x44);
249
250                 if (unroll_m > 16) {
251                     lea(BO2, ptr[BO2 + LDA * 8]);
252                     kmovw(k4, k2);
253                     vgatherqps(ymm5 | k4,
254                             ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
255                     lea(BO2, ptr[BO2 + LDA * 8]);
256                     kshiftrw(k4, k2, 8);
257                     vgatherqps(ymm6 | k4,
258                             ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
259                     vshuff64x2(zmm1, zmm5, zmm6, 0x44);
260                 }
261
262                 if (unroll_m > 32) {
263                     lea(BO2, ptr[BO2 + LDA * 8]);
264                     kmovw(k4, k3);
265                     vgatherqps(ymm5 | k4,
266                             ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
267                     lea(BO2, ptr[BO2 + LDA * 8]);
268                     kshiftrw(k4, k3, 8);
269                     vgatherqps(ymm6 | k4,
270                             ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
                        // NOTE(review): as in the unrolled loop, this final lea
                        // result is not used inside do_pack -- confirm.
271                     lea(BO2, ptr[BO2 + LDA * 8]);
272                     vshuff64x2(zmm2, zmm5, zmm6, 0x44);
273                 }
                    // One k step consumed: advance the source by one float.
274                 add(BO1, SIZE);
275             }
276
                // Store the single k step (shared by both layouts).
277             vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE],
278                     zmm0 | k1);
279             if (unroll_m > 16)
280                 vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 16 - OFFSET) * SIZE],
281                         zmm1 | k2);
282             if (unroll_m > 32)
283                 vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 16 - OFFSET) * SIZE],
284                         zmm2 | k3);
285
286             add(AO1, unroll_m * SIZE);
287             sub(LL, 1);
288             jg(pack4, T_NEAR);
289             align(16);
290
                // Exit label: reached directly when K & 3 == 0.
291             L(pack10);
292         };
293
294         // Function to update C, covering masking and other considerations
            //
            // Emits the code that folds one accumulator register into C:
            //   beta == 0 : C  = alpha * reg                (C is never read)
            //   beta == 1 : C  = alpha * reg + C
            //   otherwise : C  = alpha * reg + beta * C
            // and then clears the accumulator.
            //
            // reg      : accumulator; scaled in place by VALPHA, zeroed on exit.
            // useCO1   : true -> address relative to CO1, false -> CO2.
            // offset   : element offset from the base (scaled by SIZE).
            // mask     : 0 = full-width access, 1 / 2 / 3 = access guarded by
            //            k1 / k2 / k3. The switches have no default case, so
            //            any other value emits no load and no store.
            // useScale : when true, LDC is added to the address.
            //
            // zmm0 is used as scratch whenever C has to be read.
295         auto update = [&](Zmm reg, bool useCO1, int offset, int mask,
296                 bool useScale = false) {
                // reg *= alpha
297             vmulps(reg, reg, VALPHA);
298             if (!isBeta0) {
                    // Step 1: load the current C values into zmm0. Masked loads
                    // use zeroing (T_z), so disabled lanes read as 0.
299                 if (!useScale) {
300                     switch (mask) {
301                     case 0:
302                         if (useCO1)
303                             vmovups(zmm0, ptr[CO1 + offset * SIZE]);
304                         else
305                             vmovups(zmm0, ptr[CO2 + offset * SIZE]);
306                         break;
307                     case 1:
308                         if (useCO1)
309                             vmovups(zmm0 | k1 | T_z, ptr[CO1 + offset * SIZE]);
310                         else
311                             vmovups(zmm0 | k1 | T_z, ptr[CO2 + offset * SIZE]);
312                         break;
313                     case 2:
314                         if (useCO1)
315                             vmovups(zmm0 | k2 | T_z, ptr[CO1 + offset * SIZE]);
316                         else
317                             vmovups(zmm0 | k2 | T_z, ptr[CO2 + offset * SIZE]);
318                         break;
319                     case 3:
320                         if (useCO1)
321                             vmovups(zmm0 | k3 | T_z, ptr[CO1 + offset * SIZE]);
322                         else
323                             vmovups(zmm0 | k3 | T_z, ptr[CO2 + offset * SIZE]);
324                         break;
325                     }
326                 } else {
                        // Same loads, with LDC added to the address.
327                     switch (mask) {
328                     case 0:
329                         if (useCO1)
330                             vmovups(zmm0, ptr[CO1 + LDC + offset * SIZE]);
331                         else
332                             vmovups(zmm0, ptr[CO2 + LDC + offset * SIZE]);
333                         break;
334                     case 1:
335                         if (useCO1)
336                             vmovups(zmm0 | k1 | T_z,
337                                     ptr[CO1 + LDC + offset * SIZE]);
338                         else
339                             vmovups(zmm0 | k1 | T_z,
340                                     ptr[CO2 + LDC + offset * SIZE]);
341                         break;
342                     case 2:
343                         if (useCO1)
344                             vmovups(zmm0 | k2 | T_z,
345                                     ptr[CO1 + LDC + offset * SIZE]);
346                         else
347                             vmovups(zmm0 | k2 | T_z,
348                                     ptr[CO2 + LDC + offset * SIZE]);
349                         break;
350                     case 3:
351                         if (useCO1)
352                             vmovups(zmm0 | k3 | T_z,
353                                     ptr[CO1 + LDC + offset * SIZE]);
354                         else
355                             vmovups(zmm0 | k3 | T_z,
356                                     ptr[CO2 + LDC + offset * SIZE]);
357                         break;
358                     }
359                 }
                    // Step 2: combine. Inside this branch beta != 0, so
                    // !isBetaN means beta == 1 and a plain add is enough;
                    // otherwise zmm0 = zmm0 * VBETA + reg.
360                 if (!isBetaN) {
361                     vaddps(zmm0, reg, zmm0);
362                 } else {
363                     vfmadd132ps(zmm0, reg, VBETA);
364                 }
                    // Step 3: write zmm0 back to the address it was loaded from,
                    // under the same mask.
365                 if (!useScale) {
366                     switch (mask) {
367                     case 0:
368                         if (useCO1)
369                             vmovups(ptr[CO1 + offset * SIZE], zmm0);
370                         else
371                             vmovups(ptr[CO2 + offset * SIZE], zmm0);
372                         break;
373                     case 1:
374                         if (useCO1)
375                             vmovups(ptr[CO1 + offset * SIZE], zmm0 | k1);
376                         else
377                             vmovups(ptr[CO2 + offset * SIZE], zmm0 | k1);
378                         break;
379                     case 2:
380                         if (useCO1)
381                             vmovups(ptr[CO1 + offset * SIZE], zmm0 | k2);
382                         else
383                             vmovups(ptr[CO2 + offset * SIZE], zmm0 | k2);
384                         break;
385                     case 3:
386                         if (useCO1)
387                             vmovups(ptr[CO1 + offset * SIZE], zmm0 | k3);
388                         else
389                             vmovups(ptr[CO2 + offset * SIZE], zmm0 | k3);
390                         break;
391                     }
392                 } else {
393                     switch (mask) {
394                     case 0:
395                         if (useCO1)
396                             vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0);
397                         else
398                             vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0);
399                         break;
400                     case 1:
401                         if (useCO1)
402                             vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k1);
403                         else
404                             vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k1);
405                         break;
406                     case 2:
407                         if (useCO1)
408                             vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k2);
409                         else
410                             vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k2);
411                         break;
412                     case 3:
413                         if (useCO1)
414                             vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k3);
415                         else
416                             vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k3);
417                         break;
418                     }
419                 }
420             } else {
                    // beta == 0: C is not read at all; the scaled accumulator is
                    // stored directly.
421                 if (!useScale) {
422                     switch (mask) {
423                     case 0:
424                         if (useCO1)
425                             vmovups(ptr[CO1 + offset * SIZE], reg);
426                         else
427                             vmovups(ptr[CO2 + offset * SIZE], reg);
428                         break;
429                     case 1:
430                         if (useCO1)
431                             vmovups(ptr[CO1 + offset * SIZE], reg | k1);
432                         else
433                             vmovups(ptr[CO2 + offset * SIZE], reg | k1);
434                         break;
435                     case 2:
436                         if (useCO1)
437                             vmovups(ptr[CO1 + offset * SIZE], reg | k2);
438                         else
439                             vmovups(ptr[CO2 + offset * SIZE], reg | k2);
440                         break;
441                     case 3:
442                         if (useCO1)
443                             vmovups(ptr[CO1 + offset * SIZE], reg | k3);
444                         else
445                             vmovups(ptr[CO2 + offset * SIZE], reg | k3);
446                         break;
447                     }
448                 } else {
449                     switch (mask) {
450                     case 0:
451                         if (useCO1)
452                             vmovups(ptr[CO1 + LDC + offset * SIZE], reg);
453                         else
454                             vmovups(ptr[CO2 + LDC + offset * SIZE], reg);
455                         break;
456                     case 1:
457                         if (useCO1)
458                             vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k1);
459                         else
460                             vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k1);
461                         break;
462                     case 2:
463                         if (useCO1)
464                             vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k2);
465                         else
466                             vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k2);
467                         break;
468                     case 3:
469                         if (useCO1)
470                             vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k3);
471                         else
472                             vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k3);
473                         break;
474                     }
475                 }
476             }
                // Clear the accumulator so it can be reused.
477             vpxorq(reg, reg, reg);
478         };
479
480         // Loop with unroll_n - 2 FMAs; called by innerkernel
            //
            // Emits, for one k step, the multiply-accumulates for columns
            // 2 .. unroll_n - 1 of the tile (columns 0 and 1 are not handled
            // here). Column i accumulates into regs[i], regs[i + 8] and
            // regs[i + 16], using the A vectors held in zmm0, zmm1 and zmm2.
            //
            // unroll_m  : tile height; the second and third accumulator rows
            //             are emitted only for unroll_m >= 32 and >= 48.
            // unroll_n  : number of columns in the tile. regs has 24 entries
            //             and the switches only cover i = 2..7, so callers are
            //             presumably expected to pass unroll_n <= 8 -- confirm.
            // iteration : element index of the current k step within B; used
            //             only on the non-transposed-B paths.
481         auto fmaloop = [&](int unroll_m, int unroll_n, int iteration) {
482             for (int i = 2; i < unroll_n; i++) {
483                 if (ver == ver_avx512_core) {
                        // avx512_core: broadcast the B element into zmm3 first,
                        // then issue register-register FMAs.
                        // NOTE(review): zmm3 is also VBIAS3 -- confirm the bias
                        // vectors are loaded only after the FMA loops.
484                     if (!isTransB) {
                            // Column i is addressed as a multiple of LDB from
                            // BO1 (i = 2, 3) or from BO2 (i = 4..7).
                            // NOTE(review): presumably BO2 == BO1 + 4 * LDB and
                            // LDB3 == 3 * LDB; both are set up outside this
                            // lambda -- confirm.
485                         switch (i) {
486                         case 2:
487                             vbroadcastss(
488                                     zmm3,
489                                     ptr[BO1 + LDB * 2
490                                             + (iteration - OFFSET) * SIZE]);
491                             break;
492                         case 3:
493                             vbroadcastss(
494                                     zmm3,
495                                     ptr[BO1 + LDB3
496                                             + (iteration - OFFSET) * SIZE]);
497                             break;
498                         case 4:
499                             vbroadcastss(zmm3,
500                                     ptr[BO2 + (iteration - OFFSET) * SIZE]);
501                             break;
502                         case 5:
503                             vbroadcastss(
504                                     zmm3,
505                                     ptr[BO2 + LDB * 1
506                                             + (iteration - OFFSET) * SIZE]);
507                             break;
508                         case 6:
509                             vbroadcastss(
510                                     zmm3,
511                                     ptr[BO2 + LDB * 2
512                                             + (iteration - OFFSET) * SIZE]);
513                             break;
514                         case 7:
515                             vbroadcastss(
516                                     zmm3,
517                                     ptr[BO2 + LDB3
518                                             + (iteration - OFFSET) * SIZE]);
519                             break;
520                         }
521                     } else {
                            // Transposed B: the columns of one k step are
                            // adjacent, so column i is simply element i from BO1.
522                         vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]);
523                     }
524                     vfmadd231ps(regs[i], zmm3, zmm0);
525                     if (unroll_m >= 32)
526                         vfmadd231ps(regs[i + 8], zmm3, zmm1);
527                     if (unroll_m >= 48)
528                         vfmadd231ps(regs[i + 16], zmm3, zmm2);
529                 } else {
                        // Other target (ver_avx512_mic): no separate broadcast;
                        // each FMA takes the B element as a broadcast memory
                        // operand (zword_b). Addressing matches the branch above.
530                     if (!isTransB) {
531                         switch (i) {
532                         case 2:
533                             vfmadd231ps(regs[i], zmm0,
534                                     zword_b[BO1 + LDB * 2
535                                     + (iteration - OFFSET) * SIZE]);
536                             if (unroll_m >= 32)
537                                 vfmadd231ps(regs[i + 8], zmm1,
538                                         zword_b[BO1 + LDB * 2
539                                         + (iteration - OFFSET) * SIZE]);
540                             if (unroll_m >= 48)
541                                 vfmadd231ps(regs[i + 16], zmm2,
542                                         zword_b[BO1 + LDB * 2
543                                         + (iteration - OFFSET) * SIZE]);
544                             break;
545                         case 3:
546                             vfmadd231ps(regs[i], zmm0,
547                                     zword_b[BO1 + LDB3
548                                     + (iteration - OFFSET) * SIZE]);
549                             if (unroll_m >= 32)
550                                 vfmadd231ps(regs[i + 8], zmm1,
551                                         zword_b[BO1 + LDB3
552                                         + (iteration - OFFSET) * SIZE]);
553                             if (unroll_m >= 48)
554                                 vfmadd231ps(regs[i + 16], zmm2,
555                                         zword_b[BO1 + LDB3
556                                         + (iteration - OFFSET) * SIZE]);
557                             break;
558                         case 4:
559                             vfmadd231ps(regs[i], zmm0,
560                                     zword_b[BO2 + (iteration - OFFSET) * SIZE]);
561                             if (unroll_m >= 32)
562                                 vfmadd231ps(regs[i + 8], zmm1,
563                                         zword_b[BO2 + (iteration - OFFSET) * SIZE]);
564                             if (unroll_m >= 48)
565                                 vfmadd231ps(regs[i + 16], zmm2,
566                                         zword_b[BO2 + (iteration - OFFSET) * SIZE]);
567                             break;
568                         case 5:
569                             vfmadd231ps(regs[i], zmm0,
570                                     zword_b[BO2 + LDB * 1
571                                     + (iteration - OFFSET) * SIZE]);
572                             if (unroll_m >= 32)
573                                 vfmadd231ps(regs[i + 8], zmm1,
574                                         zword_b[BO2 + LDB * 1
575                                         + (iteration - OFFSET) * SIZE]);
576                             if (unroll_m >= 48)
577                                 vfmadd231ps(regs[i + 16], zmm2,
578                                         zword_b[BO2 + LDB * 1
579                                         + (iteration - OFFSET) * SIZE]);
580                             break;
581                         case 6:
582                             vfmadd231ps(regs[i], zmm0,
583                                     zword_b[BO2 + LDB * 2
584                                     + (iteration - OFFSET) * SIZE]);
585                             if (unroll_m >= 32)
586                                 vfmadd231ps(regs[i + 8], zmm1,
587                                         zword_b[BO2 + LDB * 2
588                                         + (iteration - OFFSET) * SIZE]);
589                             if (unroll_m >= 48)
590                                 vfmadd231ps(regs[i + 16], zmm2,
591                                         zword_b[BO2 + LDB * 2
592                                         + (iteration - OFFSET) * SIZE]);
593                             break;
594                         case 7:
595                             vfmadd231ps(regs[i], zmm0,
596                                     zword_b[BO2 + LDB3
597                                     + (iteration - OFFSET) * SIZE]);
598                             if (unroll_m >= 32)
599                                 vfmadd231ps(regs[i + 8], zmm1,
600                                         zword_b[BO2 + LDB3
601                                         + (iteration - OFFSET) * SIZE]);
602                             if (unroll_m >= 48)
603                                 vfmadd231ps(regs[i + 16], zmm2,
604                                         zword_b[BO2 + LDB3
605                                         + (iteration - OFFSET) * SIZE]);
606                             break;
607                         }
608                     } else {
609                         vfmadd231ps(
610                                 regs[i], zmm0, zword_b[BO1 + (i - OFFSET) * SIZE]);
611                         if (unroll_m >= 32)
612                             vfmadd231ps(regs[i + 8], zmm1,
613                                     zword_b[BO1 + (i - OFFSET) * SIZE]);
614                         if (unroll_m >= 48)
615                             vfmadd231ps(regs[i + 16], zmm2,
616                                     zword_b[BO1 + (i - OFFSET) * SIZE]);
617                     }
618                 }
619             }
620         };
621
622         // Innerkernel; called by kernel
623         auto innerkernel = [&](int unroll_m, int unroll_n, bool isDirect,
624                 bool isCopy, bool doCPrefetch, bool isUnmasked = true) {
625             for (int i = 0; i < 8; i++) {
626                 if (!isDirect) {
627                     prefetcht0(ptr[AO1
628                             + (PREFETCHSIZEA + i * unroll_m + 0 * 16 - OFFSET)
629                                     * SIZE]);
630                     if (unroll_m >= 32)
631                         prefetcht0(ptr[AO1
632                             + (PREFETCHSIZEA + i * unroll_m + 1 * 16 - OFFSET)
633                                     * SIZE]);
634                     if (unroll_m >= 48)
635                         prefetcht0(ptr[AO1
636                             + (PREFETCHSIZEA + i * unroll_m + 2 * 16 - OFFSET)
637                                     * SIZE]);
638                 } else {
639                     prefetcht0(ptr[AO1 + LDA4 + (16 * 0 * SIZE)]);
640                     if (unroll_m >= 32)
641                         prefetcht0(ptr[AO1 + LDA4 + (16 * 1 * SIZE)]);
642                     if (unroll_m >= 48)
643                         prefetcht0(ptr[AO1 + LDA4 + (16 * 2 * SIZE)]);
644                 }
645
646                 if (!isDirect) {
647                     if (i != 0) {
648                         if (isUnmasked || unroll_m > 16) {
649                             vmovups(zmm0,
650                                     ptr[AO1
651                                             + (unroll_m * i + 0 * 16 - OFFSET)
652                                                     * SIZE]);
653                         } else {
654                             vmovups(zmm0 | k1 | T_z,
655                                     ptr[AO1
656                                             + (unroll_m * i + 0 * 16 - OFFSET)
657                                                     * SIZE]);
658                         }
659                         if (unroll_m >= 32) {
660                             if (isUnmasked || unroll_m > 32) {
661                                 vmovups(zmm1, ptr[AO1
662                                                       + (unroll_m * i + 1 * 16
663                                                                 - OFFSET)
664                                                               * SIZE]);
665                             } else {
666                                 vmovups(zmm1 | k2 | T_z,
667                                         ptr[AO1
668                                                 + (unroll_m * i + 1 * 16
669                                                           - OFFSET)
670                                                         * SIZE]);
671                             }
672                         }
673                         if (unroll_m >= 48) {
674                             if (isUnmasked) {
675                                 vmovups(zmm2, ptr[AO1
676                                                       + (unroll_m * i + 2 * 16
677                                                                 - OFFSET)
678                                                               * SIZE]);
679                             } else {
680                                 vmovups(zmm2 | k3 | T_z,
681                                         ptr[AO1
682                                                 + (unroll_m * i + 2 * 16
683                                                           - OFFSET)
684                                                         * SIZE]);
685                             }
686                         }
687                     }
688                 } else {
689                     if (isUnmasked || unroll_m > 16) {
690                         vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
691                     } else {
692                         vmovups(zmm0 | k1 | T_z,
693                                 ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
694                     }
695                     if (unroll_m >= 32) {
696                         if (isUnmasked || unroll_m > 32) {
697                             vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
698                         } else {
699                             vmovups(zmm1 | k2 | T_z,
700                                     ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
701                         }
702                     }
703                     if (unroll_m >= 48) {
704                         if (isUnmasked) {
705                             vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
706                         } else {
707                             vmovups(zmm2 | k3 | T_z,
708                                     ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
709                         }
710                     }
711                     add(AO1, LDA);
712                 }
713
714                 if (ver == ver_avx512_core) {
715                     if (!isTransB) {
716                         vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]);
717                     } else {
718                         vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]);
719                     }
720                     vfmadd231ps(regs[0], zmm3, zmm0);
721                     if (unroll_m >= 32)
722                         vfmadd231ps(regs[0 + 8], zmm3, zmm1);
723                     if (unroll_m >= 48)
724                         vfmadd231ps(regs[0 + 16], zmm3, zmm2);
725                 } else {
726                     if (!isTransB) {
727                         vfmadd231ps(regs[0], zmm0,
728                                 zword_b[BO1 + (i - OFFSET) * SIZE]);
729                         if (unroll_m >= 32)
730                             vfmadd231ps(regs[0 + 8], zmm1,
731                                     zword_b[BO1 + (i - OFFSET) * SIZE]);
732                         if (unroll_m >= 48)
733                             vfmadd231ps(regs[0 + 16], zmm2,
734                                     zword_b[BO1 + (i - OFFSET) * SIZE]);
735                     } else {
736                         vfmadd231ps(regs[0], zmm0,
737                                 zword_b[BO1 + (0 - OFFSET) * SIZE]);
738                         if (unroll_m >= 32)
739                             vfmadd231ps(regs[0 + 8], zmm1,
740                                     zword_b[BO1 + (0 - OFFSET) * SIZE]);
741                         if (unroll_m >= 48)
742                             vfmadd231ps(regs[0 + 16], zmm2,
743                                     zword_b[BO1 + (0 - OFFSET) * SIZE]);
744                     }
745                 }
746
747                 if (unroll_n >= i + 1) {
748                     if (!isTransB) {
749                         switch (i) {
750                         case 0:
751                             prefetcht0(
752                                     ptr[BO1 + (PREFETCHSIZEB - OFFSET) * SIZE]);
753                             break;
754                         case 1:
755                             prefetcht0(ptr[BO1 + LDB
756                                     + (PREFETCHSIZEB - OFFSET) * SIZE]);
757                             break;
758                         case 2:
759                             prefetcht0(ptr[BO1 + LDB * 2
760                                     + (PREFETCHSIZEB - OFFSET) * SIZE]);
761                             break;
762                         case 3:
763                             prefetcht0(ptr[BO1 + LDB3
764                                     + (PREFETCHSIZEB - OFFSET) * SIZE]);
765                             break;
766                         case 4:
767                             prefetcht0(
768                                     ptr[BO2 + (PREFETCHSIZEB - OFFSET) * SIZE]);
769                             break;
770                         case 5:
771                             prefetcht0(ptr[BO2 + LDB
772                                     + (PREFETCHSIZEB - OFFSET) * SIZE]);
773                             break;
774                         case 6:
775                             prefetcht0(ptr[BO2 + LDB * 2
776                                     + (PREFETCHSIZEB - OFFSET) * SIZE]);
777                             break;
778                         case 7:
779                             prefetcht0(ptr[BO2 + LDB3
780                                     + (PREFETCHSIZEB - OFFSET) * SIZE]);
781                             break;
782                         }
783                     }
784                 }
785
786                 if (unroll_n >= 2) {
787                     if (ver == ver_avx512_core) {
788                         if (!isTransB) {
789                             vbroadcastss(zmm3,
790                                     ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
791                         } else {
792                             vbroadcastss(zmm3, ptr[BO1 + (1 - OFFSET) * SIZE]);
793                         }
794                         vfmadd231ps(regs[1], zmm3, zmm0);
795                         if (unroll_m >= 32)
796                             vfmadd231ps(regs[1 + 8], zmm3, zmm1);
797                         if (unroll_m >= 48)
798                             vfmadd231ps(regs[1 + 16], zmm3, zmm2);
799                     } else {
800                         if (!isTransB) {
801                             vfmadd231ps(regs[1], zmm0,
802                                     zword_b[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
803                             if (unroll_m >= 32)
804                                 vfmadd231ps(regs[1 + 8], zmm1,
805                                         zword_b[BO1 + LDB * 1
806                                         + (i - OFFSET) * SIZE]);
807                             if (unroll_m >= 48)
808                                 vfmadd231ps(regs[1 + 16], zmm2,
809                                         zword_b[BO1 + LDB * 1
810                                         + (i - OFFSET) * SIZE]);
811                         } else {
812                             vfmadd231ps(regs[1], zmm0,
813                                     zword_b[BO1 + (1 - OFFSET) * SIZE]);
814                             if (unroll_m >= 32)
815                                 vfmadd231ps(regs[1 + 8], zmm1,
816                                         zword_b[BO1 + (1 - OFFSET) * SIZE]);
817                             if (unroll_m >= 48)
818                                 vfmadd231ps(regs[1 + 16], zmm2,
819                                         zword_b[BO1 + (1 - OFFSET) * SIZE]);
820                         }
821                     }
822                 }
823
824                 if (isCopy) {
825                     if (isUnmasked || unroll_m > 16) {
826                         vmovups(ptr[LDA4
827                                         + (unroll_m * i + 0 * 16 - OFFSET)
828                                                 * SIZE],
829                                 zmm0);
830                     } else {
831                         vmovups(ptr[LDA4
832                                         + (unroll_m * i + 0 * 16 - OFFSET)
833                                                 * SIZE],
834                                 zmm0 | k1);
835                     }
836                     if (unroll_m >= 32) {
837                         if (isUnmasked || unroll_m > 32) {
838                             vmovups(ptr[LDA4
839                                             + (unroll_m * i + 1 * 16 - OFFSET)
840                                                     * SIZE],
841                                     zmm1);
842                         } else {
843                             vmovups(ptr[LDA4
844                                             + (unroll_m * i + 1 * 16 - OFFSET)
845                                                     * SIZE],
846                                     zmm1 | k2);
847                         }
848                     }
849                     if (unroll_m >= 48) {
850                         if (isUnmasked) {
851                             vmovups(ptr[LDA4
852                                             + (unroll_m * i + 2 * 16 - OFFSET)
853                                                     * SIZE],
854                                     zmm2);
855                         } else {
856                             vmovups(ptr[LDA4
857                                             + (unroll_m * i + 2 * 16 - OFFSET)
858                                                     * SIZE],
859                                     zmm2 | k3);
860                         }
861                     }
862                     if (i == 7)
863                         sub(LDA4, -unroll_m * 8 * SIZE);
864                 }
865                 fmaloop(unroll_m, unroll_n, i);
866
867                 if (i == 1) {
868                     if (doCPrefetch) {
869                         if (ver == ver_avx512_core)
870                             prefetchw(ptr[CO2 + 0 * 16 * SIZE]);
871                         else
872                             prefetcht0(ptr[CO2 + 0 * 16 * SIZE]);
873                     }
874                 }
875                 if (i == 3) {
876                     if (doCPrefetch && unroll_m >= 32) {
877                         if (ver == ver_avx512_core)
878                             prefetchw(ptr[CO2 + 1 * 16 * SIZE]);
879                         else
880                             prefetcht0(ptr[CO2 + 1 * 16 * SIZE]);
881                     }
882                     if (!isTransA) {
883                         if (ver == ver_avx512_core)
884                             prefetcht0(ptr[AA + 16 * 0 * SIZE]);
885                         else
886                             prefetcht2(ptr[AA + 16 * 0 * SIZE]);
887                     }
888                 }
889                 if (i == 5) {
890                     if (doCPrefetch) {
891                         if (unroll_m >= 48) {
892                             if (ver == ver_avx512_core)
893                                 prefetchw(ptr[CO2 + 2 * 16 * SIZE]);
894                             else
895                                 prefetcht0(ptr[CO2 + 2 * 16 * SIZE]);
896                         }
897                         add(CO2, LDC);
898                     }
899                     if (!isTransA) {
900                         if (unroll_m >= 32) {
901                             if (ver == ver_avx512_core)
902                                 prefetcht0(ptr[AA + 16 * 1 * SIZE]);
903                             else
904                                 prefetcht2(ptr[AA + 16 * 1 * SIZE]);
905                         }
906                     }
907                 }
908
909                 if (isTransB) {
910                     prefetcht0(ptr[BO1 + BO2]);
911                     add(BO1, LDB);
912                 }
913             } // end of for loop
914
915             if (!isTransB) {
916                 sub(BO1, -8 * SIZE);
917                 if (unroll_n >= 4)
918                     sub(BO2, -8 * SIZE);
919             }
920             if (!isTransA) {
921                 if (unroll_m >= 48) {
922                     if (ver == ver_avx512_core)
923                         prefetcht0(ptr[AA + 16 * 2 * SIZE]);
924                     else
925                         prefetcht2(ptr[AA + 16 * 2 * SIZE]);
926                 }
927                 lea(AA, ptr[AA + LDA]);
928             }
929
930             if (!isDirect) {
931                 if (isUnmasked || unroll_m > 16) {
932                     vmovups(zmm0,
933                             ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]);
934                 } else {
935                     vmovups(zmm0 | k1 | T_z,
936                             ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]);
937                 }
938                 if (unroll_m >= 32) {
939                     if (isUnmasked || unroll_m > 32) {
940                         vmovups(zmm1, ptr[AO1
941                                               + (unroll_m * 8 + 1 * 16 - OFFSET)
942                                                       * SIZE]);
943                     } else {
944                         vmovups(zmm1 | k2 | T_z,
945                                 ptr[AO1
946                                         + (unroll_m * 8 + 1 * 16 - OFFSET)
947                                                 * SIZE]);
948                     }
949                 }
950                 if (unroll_m >= 48) {
951                     if (isUnmasked) {
952                         vmovups(zmm2, ptr[AO1
953                                               + (unroll_m * 8 + 2 * 16 - OFFSET)
954                                                       * SIZE]);
955                     } else {
956                         vmovups(zmm2 | k3 | T_z,
957                                 ptr[AO1
958                                         + (unroll_m * 8 + 2 * 16 - OFFSET)
959                                                 * SIZE]);
960                     }
961                 }
962                 sub(AO1, -unroll_m * 8 * SIZE);
963             }
964
965             sub(LL, 1);
966         };
967
968         // Main kernel; does prefetching and calls innerkernel
969         // After calculating results in registers, writes back to C matrix by
970         // calling update
            //
            // Emits the code for one unroll_m x unroll_n tile of C. All
            // parameters are code-generation-time values: they select which
            // instructions are emitted, they are not read by the emitted code.
            //   unroll_m   - tile height in floats; the body branches on the
            //                16/32/48 thresholds and uses one zmm per 16 floats.
            //   unroll_n   - tile width; the switches below cover 1..8.
            //   isDirect   - true: A vectors are loaded from A and AO1 advances
            //                by LDA per k step; false: A vectors are loaded
            //                contiguously from the buffer at rsp + 128.
            //   isCopy     - true: each loaded A vector is also stored to the
            //                buffer at rsp + 128 through LDA4.
            //   isUnmasked - false: the last 16-float group of the tile is
            //                accessed through an opmask (k1, k2 or k3,
            //                depending on unroll_m); earlier groups are always
            //                accessed unmasked.
            //
            // Register roles in the emitted code: zmm0..zmm2 hold the A
            // vectors, zmm3 a broadcast B element, and regs[i], regs[i + 8],
            // regs[i + 16] accumulate column i for the first, second and third
            // 16-float group respectively.
            //
            // Emitted structure:
            //   kernel12 - 8-deep k blocks, innerkernel with 5th argument false
            //   kernel14 - last unroll_n 8-deep k blocks, 5th argument true
            //              (presumably doCPrefetch - confirm against the
            //              innerkernel signature)
            //   kernel16 - K % 8 remaining k steps, one per iteration
            //   kernel18 - alpha/beta/bias handling, write-back through update,
            //              then CO1/BO1/BO2 are moved to the next tile.
971         auto kernel = [&](int unroll_m, int unroll_n, bool isDirect,
972                 bool isCopy, bool isUnmasked = true) {
                // Select the A source. All A accesses below subtract OFFSET
                // again, so in the buffered case the effective base is
                // rsp + 128.
973             if (!isDirect) {
974                 lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]);
975             } else {
976                 mov(AO1, A);
977             }
978
                // LDA4 is either the store cursor into the rsp + 128 buffer
                // (isCopy) or an LDA-derived offset; the latter is presumably
                // an A prefetch distance consumed by innerkernel - TODO confirm.
979             if (isCopy) {
980                 lea(LDA4, ptr[rsp + 128 + OFFSET * SIZE]);
981             } else {
982                 auto step = ver == ver_avx512_core ? 2 : 4;
983                 lea(LDA4, ptr[LDA * step + (16 - 1 - OFFSET) * SIZE]);
984             }
985
                // With transposed B, BO2 is not a second B pointer but an
                // offset: innerkernel issues prefetcht0(ptr[BO1 + BO2]).
986             if (isTransB) {
987                 lea(BO2, ptr[LDB * 4 + (16 / 2 - 1 - OFFSET) * SIZE]);
988             }
989
                // Buffered A: preload the vectors for the first k step. Each
                // loop below then loads the vectors for the following step
                // before advancing AO1.
990             if (!isDirect) {
991                 if (isUnmasked || unroll_m > 16) {
992                     vmovups(zmm0,
993                             ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]);
994                 } else {
995                     vmovups(zmm0 | k1 | T_z,
996                             ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]);
997                 }
998                 if (unroll_m >= 32) {
999                     if (isUnmasked || unroll_m > 32) {
1000                         vmovups(zmm1, ptr[AO1
1001                                               + (unroll_m * 0 + 1 * 16 - OFFSET)
1002                                                       * SIZE]);
1003                     } else {
1004                         vmovups(zmm1 | k2 | T_z,
1005                                 ptr[AO1
1006                                         + (unroll_m * 0 + 1 * 16 - OFFSET)
1007                                                 * SIZE]);
1008                     }
1009                 }
1010                 if (unroll_m >= 48) {
1011                     if (isUnmasked) {
1012                         vmovups(zmm2, ptr[AO1
1013                                               + (unroll_m * 0 + 2 * 16 - OFFSET)
1014                                                       * SIZE]);
1015                     } else {
1016                         vmovups(zmm2 | k3 | T_z,
1017                                 ptr[AO1
1018                                         + (unroll_m * 0 + 2 * 16 - OFFSET)
1019                                                 * SIZE]);
1020                     }
1021                 }
1022             }
1023
1024             Label kernel12, kernel13, kernel14, kernel15, kernel16, kernel18;
1025
                 // LL = K / 8 = number of 8-deep k blocks. SECOND_FETCH expands
                 // to unroll_n, so the last unroll_n blocks are held back for
                 // the kernel14 loop; skip kernel12 if nothing is left.
1026             mov(LL, K);
1027             sar(LL, 3);
1028             sub(LL, SECOND_FETCH);
1029             jle(kernel13, T_NEAR);
1030             align(16);
1031
                 // innerkernel ends by emitting sub(LL, 1); the jg that follows
                 // tests the flags of that sub and loops while LL > 0.
1032             L(kernel12);
1033             innerkernel(
1034                     unroll_m, unroll_n, isDirect, isCopy, false, isUnmasked);
1035             jg(kernel12, T_NEAR);
1036             align(16);
1037
                 // CO2 becomes the address innerkernel prefetches C from (it
                 // prefetches at CO2 and adds LDC to it when C prefetching is
                 // on). Adding unroll_n restores the held-back block count.
1038             L(kernel13);
1039             lea(CO2, ptr[CO1 + (16 - 1) * SIZE]);
1040             add(LL, unroll_n);
1041             jle(kernel15, T_NEAR);
1042             align(16);
1043
1044             L(kernel14);
1045             innerkernel(unroll_m, unroll_n, isDirect, isCopy, true, isUnmasked);
1046             jg(kernel14, T_NEAR);
1047             align(16);
1048
                 // Remainder: LL = K % 8 single k steps, handled by kernel16.
1049             L(kernel15);
1050             mov(LL, K);
1051             and_(LL, 7);
1052             jle(kernel18, T_NEAR);
1053             align(16);
1054
1055             L(kernel16);
                 // Direct A: load the vectors for this k step, then step AO1 by
                 // LDA.
1056             if (isDirect) {
1057                 if (isUnmasked || unroll_m > 16) {
1058                     vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
1059                 } else {
1060                     vmovups(zmm0 | k1 | T_z,
1061                             ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
1062                 }
1063                 if (unroll_m >= 32) {
1064                     if (isUnmasked || unroll_m > 32) {
1065                         vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
1066                     } else {
1067                         vmovups(zmm1 | k2 | T_z,
1068                                 ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
1069                     }
1070                 }
1071                 if (unroll_m >= 48) {
1072                     if (isUnmasked) {
1073                         vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
1074                     } else {
1075                         vmovups(zmm2 | k3 | T_z,
1076                                 ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
1077                     }
1078                 }
1079                 add(AO1, LDA);
1080             }
1081
                 // One FMA per 16-float group and column. For non-transposed B
                 // columns 0..3 are addressed from BO1 and columns 4..7 from
                 // BO2, in steps of LDB; for transposed B the columns are
                 // consecutive floats at BO1.
1082             for (int i = 0; i < unroll_n; i++) {
1083                 if (!isTransB) {
1084                     switch (i) {
1085                     case 0:
1086                         vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]);
1087                         break;
1088                     case 1:
1089                         vbroadcastss(
1090                                 zmm3, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]);
1091                         break;
1092                     case 2:
1093                         vbroadcastss(
1094                                 zmm3, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]);
1095                         break;
1096                     case 3:
1097                         vbroadcastss(
1098                                 zmm3, ptr[BO1 + LDB3 + (0 - OFFSET) * SIZE]);
1099                         break;
1100                     case 4:
1101                         vbroadcastss(zmm3, ptr[BO2 + (0 - OFFSET) * SIZE]);
1102                         break;
1103                     case 5:
1104                         vbroadcastss(
1105                                 zmm3, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]);
1106                         break;
1107                     case 6:
1108                         vbroadcastss(
1109                                 zmm3, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]);
1110                         break;
1111                     case 7:
1112                         vbroadcastss(
1113                                 zmm3, ptr[BO2 + LDB3 + (0 - OFFSET) * SIZE]);
1114                         break;
1115                     }
1116                 } else {
1117                     vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]);
1118                 }
1119                 vfmadd231ps(regs[i], zmm3, zmm0);
1120                 if (unroll_m >= 32) {
1121                     vfmadd231ps(regs[i + 8], zmm3, zmm1);
1122                 }
1123                 if (unroll_m >= 48) {
1124                     vfmadd231ps(regs[i + 16], zmm3, zmm2);
1125                 }
1126             }
1127
                 // Copy mode: store the A vectors just used to the buffer at
                 // LDA4 (masked stores for the last group when !isUnmasked) and
                 // advance the cursor by one k step (unroll_m floats).
1128             if (isCopy) {
1129                 if (isUnmasked || unroll_m > 16) {
1130                     vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE],
1131                             zmm0);
1132                 } else {
1133                     vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE],
1134                             zmm0 | k1);
1135                 }
1136                 if (unroll_m >= 32) {
1137                     if (isUnmasked || unroll_m > 32) {
1138                         vmovups(ptr[LDA4
1139                                         + (unroll_m * 0 + 1 * 16 - OFFSET)
1140                                                 * SIZE],
1141                                 zmm1);
1142                     } else {
1143                         vmovups(ptr[LDA4
1144                                         + (unroll_m * 0 + 1 * 16 - OFFSET)
1145                                                 * SIZE],
1146                                 zmm1 | k2);
1147                     }
1148                 }
1149                 if (unroll_m >= 48) {
1150                     if (isUnmasked) {
1151                         vmovups(ptr[LDA4
1152                                         + (unroll_m * 0 + 2 * 16 - OFFSET)
1153                                                 * SIZE],
1154                                 zmm2);
1155                     } else {
1156                         vmovups(ptr[LDA4
1157                                         + (unroll_m * 0 + 2 * 16 - OFFSET)
1158                                                 * SIZE],
1159                                 zmm2 | k3);
1160                     }
1161                 }
1162                 sub(LDA4, -unroll_m * SIZE);
1163             }
1164
                 // Buffered A: load the vectors of the next k step (offset
                 // unroll_m * 1) for the following iteration, then advance AO1
                 // by one k step.
1165             if (!isDirect) {
1166                 if (isUnmasked || unroll_m > 16) {
1167                     vmovups(zmm0,
1168                             ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]);
1169                 } else {
1170                     vmovups(zmm0 | k1 | T_z,
1171                             ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]);
1172                 }
1173                 if (unroll_m >= 32) {
1174                     if (isUnmasked || unroll_m > 32) {
1175                         vmovups(zmm1, ptr[AO1
1176                                               + (unroll_m * 1 + 1 * 16 - OFFSET)
1177                                                       * SIZE]);
1178                     } else {
1179                         vmovups(zmm1 | k2 | T_z,
1180                                 ptr[AO1
1181                                         + (unroll_m * 1 + 1 * 16 - OFFSET)
1182                                                 * SIZE]);
1183                     }
1184                 }
1185                 if (unroll_m >= 48) {
1186                     if (isUnmasked) {
1187                         vmovups(zmm2, ptr[AO1
1188                                               + (unroll_m * 1 + 2 * 16 - OFFSET)
1189                                                       * SIZE]);
1190                     } else {
1191                         vmovups(zmm2 | k3 | T_z,
1192                                 ptr[AO1
1193                                         + (unroll_m * 1 + 2 * 16 - OFFSET)
1194                                                 * SIZE]);
1195                     }
1196                 }
1197                 sub(AO1, -unroll_m * SIZE);
1198             }
1199
                 // Advance B by one k step: one float for non-transposed B
                 // (BO2 only when unroll_n >= 4), one LDB stride for
                 // transposed B.
1200             if (!isTransB) {
1201                 sub(BO1, -SIZE);
1202                 if (unroll_n >= 4) {
1203                     sub(BO2, -SIZE);
1204                 }
1205             } else {
1206                 add(BO1, LDB);
1207             }
1208
1209             sub(LL, 1);
1210             jg(kernel16, T_NEAR);
1211             align(16);
1212
                 // All k steps done: the accumulators hold the tile's A*B
                 // products.
1213             L(kernel18);
1214             vbroadcastss(VALPHA, ALPHA);
1215
1216             if (isBetaN) {
1217                 vbroadcastss(VBETA, BETA);
1218             }
1219
1220             // Write back the results; all beta cases need to be handled
                 // Bias is loaded once per tile, one vector per 16-float group
                 // (masked with zeroing for the last group when !isUnmasked),
                 // and the same vectors are added to every column below.
1221             if (hasBias) {
1222                 mov(BIAS1, BIAS);
1223                 if (isUnmasked || unroll_m > 16)
1224                     vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]);
1225                 else
1226                     vmovups(VBIAS1 | k1 | T_z, ptr[BIAS1 + 0 * SIZE]);
1227                 if (unroll_m >= 32) {
1228                     if (isUnmasked || unroll_m > 32)
1229                         vmovups(VBIAS2, ptr[BIAS1 + 16 * SIZE]);
1230                     else
1231                         vmovups(VBIAS2 | k2 | T_z, ptr[BIAS1 + 16 * SIZE]);
1232                 }
1233                 if (unroll_m >= 48) {
1234                     if (isUnmasked)
1235                         vmovups(VBIAS3, ptr[BIAS1 + 32 * SIZE]);
1236                     else
1237                         vmovups(VBIAS3 | k3 | T_z, ptr[BIAS1 + 32 * SIZE]);
1238                 }
1239             }
1240
                 // Column i is written through CO1 for i < 2 and through CO2
                 // otherwise; CO2 is moved to CO1 + 2 * LDC at i == 2 and two
                 // more columns further at i == 4 and i == 6. useScale is set
                 // for odd i - presumably update then addresses base + LDC;
                 // confirm against update. The third update argument is the
                 // float offset of the group (0/16/32); the fourth is 0 for an
                 // unmasked group and 1..3 otherwise (presumably selecting
                 // k1..k3 - confirm against update).
1241             for (int i = 0; i < unroll_n; i++) {
1242                 bool useScale = i % 2 != 0;
1243                 bool useCO1 = i < 2;
1244                 if (i == 2)
1245                     lea(CO2, ptr[CO1 + LDC * 2]);
1246                 if (i == 4 || i == 6)
1247                     lea(CO2, ptr[CO2 + LDC * 2]);
1248                 if (hasBias)
1249                     vaddps(regs[i], VBIAS1, regs[i]);
1250                 if (isUnmasked || unroll_m > 16) {
1251                     update(regs[i], useCO1, 0, 0, useScale);
1252                 } else {
1253                     update(regs[i], useCO1, 0, 1, useScale);
1254                 }
1255                 if (unroll_m >= 32) {
1256                     if (hasBias)
1257                         vaddps(regs[i + 8], VBIAS2, regs[i + 8]);
1258                     if (isUnmasked || unroll_m > 32) {
1259                         update(regs[i + 8], useCO1, 16, 0, useScale);
1260                     } else {
1261                         update(regs[i + 8], useCO1, 16, 2, useScale);
1262                     }
1263                 }
1264                 if (unroll_m >= 48) {
1265                     if (hasBias)
1266                         vaddps(regs[i + 16], VBIAS3, regs[i + 16]);
1267                     if (isUnmasked) {
1268                         update(regs[i + 16], useCO1, 32, 0, useScale);
1269                     } else {
1270                         update(regs[i + 16], useCO1, 32, 3, useScale);
1271                     }
1272                 }
1273             }
1274
                 // Advance CO1 by unroll_n columns. For unroll_n >= 3 this is
                 // derived from the value CO2 was left at by the loop above
                 // (CO1 + 2, 4 or 6 columns) plus one or two more columns.
1275             switch (unroll_n) {
1276             case 1: add(CO1, LDC); break;
1277             case 2: lea(CO1, ptr[CO1 + LDC * 2]); break;
1278             case 3: lea(CO1, ptr[CO2 + LDC * 1]); break;
1279             case 4: lea(CO1, ptr[CO2 + LDC * 2]); break;
1280             case 5: lea(CO1, ptr[CO2 + LDC * 1]); break;
1281             case 6: lea(CO1, ptr[CO2 + LDC * 2]); break;
1282             case 7: lea(CO1, ptr[CO2 + LDC * 1]); break;
1283             case 8: lea(CO1, ptr[CO2 + LDC * 2]); break;
1284             }
1285
1286             // Compute next address of B
                 // Non-transposed B: move BO1/BO2 forward by unroll_n * LDB and
                 // back by K * SIZE, i.e. undo the per-k-step advance made by
                 // the loops above. Transposed B: move BO1 back by LDB * K and
                 // forward by unroll_n floats.
                 // NOTE(review): BO2 is only advanced along k when
                 // unroll_n >= 4, yet K * SIZE is subtracted from it
                 // unconditionally here. This kernel never reads BO2 for
                 // unroll_n < 4 - confirm no caller relies on BO2 afterwards.
1287             if (!isTransB) {
1288                 lea(rax, ptr[K * SIZE]);
1289                 switch (unroll_n) {
1290                 case 1:
1291                     add(BO1, LDB);
1292                     add(BO2, LDB);
1293                     break;
1294                 case 2:
1295                     lea(BO1, ptr[BO1 + LDB * 2]);
1296                     lea(BO2, ptr[BO2 + LDB * 2]);
1297                     break;
1298                 case 3:
1299                     lea(BO1, ptr[BO1 + LDB3]);
1300                     lea(BO2, ptr[BO2 + LDB3]);
1301                     break;
1302                 case 4:
1303                     lea(BO1, ptr[BO1 + LDB * 4]);
1304                     lea(BO2, ptr[BO2 + LDB * 4]);
1305                     break;
1306                 case 5:
1307                     lea(BO1, ptr[BO1 + LDB * 4]);
1308                     add(BO1, LDB);
1309                     lea(BO2, ptr[BO2 + LDB * 4]);
1310                     add(BO2, LDB);
1311                     break;
1312                 case 6:
1313                     lea(BO1, ptr[BO1 + LDB3 * 2]);
1314                     lea(BO2, ptr[BO2 + LDB3 * 2]);
1315                     break;
1316                 case 7:
1317                     lea(BO1, ptr[BO1 + LDB * 8]);
1318                     sub(BO1, LDB);
1319                     lea(BO2, ptr[BO2 + LDB * 8]);
1320                     sub(BO2, LDB);
1321                     break;
1322                 case 8:
1323                     lea(BO1, ptr[BO1 + LDB * 8]);
1324                     lea(BO2, ptr[BO2 + LDB * 8]);
1325                     break;
1326                 }
1327                 sub(BO1, rax);
1328                 sub(BO2, rax);
1329             } else {
1330                 mov(rax, LDB);
1331                 imul(rax, K);
1332                 sub(BO1, rax);
1333                 add(BO1, unroll_n * SIZE);
1334             }
1335         };
1336
1337         // High-level subroutine; does packing if needed, then splits C matrix.
1338         // Operates on chunks of 48 rows, 8 columns at a time (handling tail
1339         // cases appropriately by doing 32 or 16 rows, and/or with masking,
1340         // and/or fewer columns).
1341         auto subloop = [&](int unroll_m) {
1342             Label l_subloop_20x[8], l_subloop_mask_20x[8];
1343             Label l_subloop_30x[8], l_subloop_mask_30x[8];
1344
1345             Label subloop11, subloop11mask;
1346             Label subloop30, subloop30mask;
1347             Label subloop31, subloop31mask;
1348             Label subloop96;
1349             Label subloop98, subloop98mask;
1350             Label subloop99;
1351
1352             // Create mask
1353             mov(BO1, rcx);
1354             mov(rcx, M);
1355             sub(rcx, unroll_m - 16);
1356             mov(CO1, 16);
1357             cmp(rcx, 16);
1358
1359             cmovg(rcx, CO1);
1360             mov(rax, 1);
1361             sal(rax, cl);
1362             sub(rax, 1);
1363             mov(rcx, 0xffff);
1364
1365             if (unroll_m == 16) {
1366                 kmovw(k1, eax);
1367             } else if (unroll_m == 32) {
1368                 kmovw(k1, ecx);
1369                 kmovw(k2, eax);
1370             } else {
1371                 kmovw(k1, ecx);
1372                 kmovw(k2, ecx);
1373                 kmovw(k3, eax);
1374             }
1375             mov(rcx, BO1);
1376
1377             and_(rax, 0xffff);
1378             cmp(rax, 0xffff);
1379             jne(subloop96, T_NEAR);
1380
1381             if (isTransA) {
1382                 do_pack(unroll_m);
1383             }
1384
1385             mov(CO1, C);
1386             add(C, unroll_m * SIZE);
1387
1388             mov(BO1, B);
1389             if (!isTransB) {
1390                 lea(BO2, ptr[B + LDB * 4]);
1391             }
1392
1393             if (!isTransA) {
1394                 lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]);
1395                 cmp(M, UNROLL_M);
1396                 jg(subloop98, T_NEAR);
1397
1398                 mov(AA, ORIG_A);
1399                 lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]);
1400                 L(subloop98);
1401             }
1402
1403             mov(LL, N);
1404             mov(I, LL);
1405             if (!isTransA) {
1406                 // If N is too small, skip copy operation
1407                 cmp(LL, UNROLL_N * 3);
1408                 jle(subloop30, T_NEAR);
1409
1410                 // If A is not aligned to cache line
1411                 cmp(FLAG, 0);
1412                 je(subloop30, T_NEAR);
1413             } else {
1414                 cmp(LL, UNROLL_N);
1415                 jl(l_subloop_20x[1], T_NEAR);
1416             }
1417             align(16);
1418
1419             if (!isTransA) {
1420                 kernel(unroll_m, UNROLL_N, true, true);
1421             } else {
1422                 kernel(unroll_m, UNROLL_N, false, false);
1423             }
1424
1425             sub(I, UNROLL_N);
1426             cmp(I, UNROLL_N);
1427             jl(l_subloop_20x[1], T_NEAR);
1428             align(16);
1429
1430             L(subloop11);
1431             kernel(unroll_m, UNROLL_N, false, false);
1432             sub(I, UNROLL_N);
1433             cmp(I, UNROLL_N);
1434             jge(subloop11, T_NEAR);
1435             align(16);
1436
1437             for (int i = 1; i <= 7; i++) {
1438                 L(l_subloop_20x[i]);
1439                 cmp(I, i);
1440                 if (i < 7) {
1441                     jne(l_subloop_20x[i + 1], T_NEAR);
1442                 } else {
1443                     jne(subloop99, T_NEAR);
1444                 }
1445                 kernel(unroll_m, i, false, false);
1446                 jmp(subloop99, T_NEAR);
1447                 align(16);
1448             }
1449
1450             if (!isTransA) {
1451                 L(subloop30);
1452                 cmp(I, UNROLL_N);
1453                 jl(l_subloop_30x[1], T_NEAR);
1454                 align(16);
1455
1456                 L(subloop31);
1457                 kernel(unroll_m, UNROLL_N, true, false);
1458                 sub(I, UNROLL_N);
1459                 cmp(I, UNROLL_N);
1460                 jge(subloop31, T_NEAR);
1461                 align(16);
1462
1463                 for (int i = 1; i <= 7; i++) {
1464                     L(l_subloop_30x[i]);
1465                     cmp(I, i);
1466                     if (i < 7) {
1467                         jne(l_subloop_30x[i + 1], T_NEAR);
1468                     } else {
1469                         jne(subloop99, T_NEAR);
1470                     }
1471                     kernel(unroll_m, i, true, false);
1472                     if (i < 7)
1473                         jmp(subloop99, T_NEAR);
1474                     align(16);
1475                 }
1476             }
1477             jmp(subloop99, T_NEAR);
1478             align(16);
1479
1480             L(subloop96);
1481             if (isTransA) {
1482                 do_pack(unroll_m);
1483             }
1484
1485             mov(CO1, C);
1486             add(C, unroll_m * SIZE);
1487             mov(BO1, B);
1488             if (!isTransB) {
1489                 lea(BO2, ptr[B + LDB * 4]);
1490             }
1491
1492             if (!isTransA) {
1493                 lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]);
1494                 cmp(M, UNROLL_M);
1495                 jg(subloop98mask, T_NEAR);
1496                 mov(AA, ORIG_A);
1497                 lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]);
1498                 L(subloop98mask);
1499             }
1500
1501             mov(LL, N);
1502             mov(I, LL);
1503             if (!isTransA) {
1504                 // If N is too small, skip copy operation
1505                 cmp(LL, UNROLL_N * 3);
1506                 jle(subloop30mask, T_NEAR);
1507
1508                 // If A is not aligned to cache line
1509                 cmp(FLAG, 0);
1510                 je(subloop30mask, T_NEAR);
1511             } else {
1512                 cmp(LL, UNROLL_N);
1513                 jl(l_subloop_mask_20x[1], T_NEAR);
1514             }
1515             align(16);
1516
1517             if (!isTransA) {
1518                 kernel(unroll_m, UNROLL_N, true, true, false);
1519             } else {
1520                 kernel(unroll_m, UNROLL_N, false, false, false);
1521             }
1522
1523             sub(I, UNROLL_N);
1524             cmp(I, UNROLL_N);
1525             jl(l_subloop_mask_20x[1], T_NEAR);
1526             align(16);
1527
1528             L(subloop11mask);
1529             kernel(unroll_m, UNROLL_N, false, false, false);
1530             sub(I, UNROLL_N);
1531             cmp(I, UNROLL_N);
1532             jge(subloop11mask, T_NEAR);
1533             align(16);
1534
1535             for (int i = 1; i <= 7; i++) {
1536                 L(l_subloop_mask_20x[i]);
1537                 cmp(I, i);
1538                 if (i < 7) {
1539                     jne(l_subloop_mask_20x[i + 1], T_NEAR);
1540                 } else {
1541                     jne(subloop99, T_NEAR);
1542                 }
1543                 kernel(unroll_m, i, false, false, false);
1544                 jmp(subloop99, T_NEAR);
1545                 align(16);
1546             }
1547
1548             if (!isTransA) {
1549                 L(subloop30mask);
1550                 cmp(I, UNROLL_N);
1551                 jl(l_subloop_mask_30x[1], T_NEAR);
1552                 align(16);
1553
1554                 L(subloop31mask);
1555                 kernel(unroll_m, UNROLL_N, true, false, false);
1556                 sub(I, UNROLL_N);
1557                 cmp(I, UNROLL_N);
1558                 jge(subloop31mask, T_NEAR);
1559                 align(16);
1560
1561                 for (int i = 1; i <= 7; i++) {
1562                     L(l_subloop_mask_30x[i]);
1563                     cmp(I, i);
1564                     if (i < 7) {
1565                         jne(l_subloop_mask_30x[i + 1], T_NEAR);
1566                     } else {
1567                         jne(subloop99, T_NEAR);
1568                     }
1569                     kernel(unroll_m, i, true, false, false);
1570                     if (i < 7)
1571                         jmp(subloop99, T_NEAR);
1572                     align(16);
1573                 }
1574             }
1575
1576             L(subloop99);
1577             // Compute address for A
1578             if (!isTransA) {
1579                 add(A, unroll_m * SIZE);
1580             } else {
1581                 mov(rax, LDA);
1582                 imul(rax, rax, unroll_m);
1583                 add(A, rax);
1584             }
1585
1586             // Compute next address of BIAS
1587             if (hasBias) {
1588                 add(BIAS, unroll_m * SIZE);
1589             }
1590         };
1591
1592         preamble();
1593
1594         Label buffer_in_ws, buffer_allocated;
1595
1596         // Get the registers
1597         mov(B, ARG_B);
1598         mov(LDB, ARG_LDB);
1599         mov(r15, ARG_BETA);
1600         mov(r12, ARG_C);
1601         if (hasBias)
1602             mov(r10, ARG_BIAS);
1603         mov(LDC, ARG_LDC);
1604         mov(rbp, rsp);
1605
1606         vmovss(xmm0, ptr[ARG_ALPHA]);
1607         vmovss(xmm1, ptr[r15]);
1608
1609 #if _WIN32
1610         mov(A, ARG_A);
1611         mov(LDA, ARG_LDA);
1612 #endif
1613
1614         cmp(K, STACK_K_CAPACITY);
1615         jg(buffer_in_ws, T_NEAR);
1616
1617         // Create buffer and align to 4kB page
1618         lea(rax, ptr[K * SIZE]);
1619         imul(rax, rax, 0x30);
1620         add(rax, 256);
1621         sub(rsp, rax);
1622         and_(rsp, -PAGE_4K);
1623         jmp(buffer_allocated, T_NEAR);
1624
1625         L(buffer_in_ws);
1626         mov(rsp, ARG_WS);
1627
1628         L(buffer_allocated);
1629
1630         mov(ORIG_SP, rbp);
1631         mov(M, ARG_M);
1632         mov(N, ARG_N);
1633         mov(C, r12);
1634         if (hasBias)
1635             mov(BIAS, r10);
1636         vmovss(ALPHA, xmm0);
1637         vmovss(BETA, xmm1);
1638         sub(A, -OFFSET * SIZE);
1639         sub(B, -OFFSET * SIZE);
1640         mov(ORIG_A, A);
1641         sal(LDA, BASE_SHIFT);
1642         sal(LDB, BASE_SHIFT);
1643         sal(LDC, BASE_SHIFT);
1644         lea(LDB3, ptr[LDB + LDB * 2]);
1645
1646         if (isTransA) {
1647             vpbroadcastq(zmm2, LDA);
1648             vpxorq(ZSTRIDE, ZSTRIDE, ZSTRIDE);
1649             mov(rax, -2);
1650             kmovw(k4, eax);
1651
1652             for (int i = 0; i < 6; i++) {
1653                 vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2);
1654                 kshiftlw(k4, k4, 1);
1655             }
1656             vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2);
1657         }
1658
1659         // Check A alignment and leading dimension; take copy-based path as
1660         // needed
1661         mov(rax, LDA);
1662         or_(rax, A);
1663         and_(rax, ver == ver_avx512_core ? 0x07 : 0x3f);
1664         mov(FLAG, rax);
1665
1666         for (int i = 8; i < 16; i++) {
1667             for (int j = 0; j < 3; j++) {
1668                 vpxorq(Zmm(i + 8 * j), Zmm(i + 8 * j), Zmm(i + 8 * j));
1669             }
1670         }
1671
1672         Label main0, main1, main2, main999;
1673
1674         cmp(M, 32);
1675         jle(main0, T_NEAR);
1676         align(16);
1677
1678         L(main1);
1679         subloop(48);
1680         sub(M, UNROLL_M);
1681         cmp(M, 32);
1682         jg(main1, T_NEAR);
1683         align(16);
1684
1685         L(main0);
1686         cmp(M, 16);
1687         jle(main2, T_NEAR);
1688
1689         subloop(32);
1690         jmp(main999, T_NEAR);
1691         align(16);
1692
1693         L(main2);
1694         cmp(M, 0);
1695         jle(main999, T_NEAR);
1696         subloop(16);
1697         align(16);
1698
1699         L(main999);
1700         // Restore original stack
1701         mov(rsp, ORIG_SP);
1702
1703         vzeroupper();
1704         postamble();
1705
1706         ker_ = this->getCode<ker_t>();
1707     }
1708
1709     typedef void (*ker_t)(dim_t m, dim_t n, dim_t k,
1710             const float *alpha, const float *a, dim_t lda,
1711             const float *b, dim_t ldb, const float *beta, float *c,
1712             dim_t ldc, const float *bias, float *ws);
1713
1714     void operator()(dim_t m, dim_t n, dim_t k,
1715             const float *alpha, const float *a, dim_t lda,
1716             const float *b, dim_t ldb, const float *beta, float *c,
1717             dim_t ldc, const float *bias, float *ws) const
1718     {
1719         ker_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws);
1720     }
1721
1722 private:
1723     ker_t ker_;
1724 };
1725
1726 const xbyak_gemm *get_xbyak_gemm(
1727         bool isTransA, bool isTransB, float beta, bool hasBias) {
1728     auto beta_idx = [](float beta) {
1729         return (beta == 0.0) ? 0 : (beta == 1.0 ? 1 : 2);
1730     };
1731
1732     // Kernel table [isTransA][isTransB][hasBias][beta (0, 1, other)]
1733     static xbyak_gemm *kernel_table[2][2][2][3];
1734     static std::once_flag initialized;
1735     std::call_once(initialized, [=]{
1736             for (bool isTransA: {false, true})
1737             for (bool isTransB: {false, true})
1738             for (bool hasBias: {false, true})
1739             for (float beta: {0.0f, 1.0f, 2.0f}) {
1740                 // nocopy sgemm with bias for beta != 0.0 is not supported
1741                 if (hasBias && beta != 0.0)
1742                     continue;
1743                 kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)] =
1744                     new xbyak_gemm(isTransA, isTransB, beta, hasBias);
1745             }
1746     });
1747
1748     return kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)];
1749 }
1750
1751 void sgemm_nocopy_driver(const char *transa,
1752         const char *transb, int m, int n, int k, const float *alpha,
1753         const float *a, dim_t lda, const float *b, dim_t ldb, const float *beta,
1754         float *c, dim_t ldc, const float *bias, float *ws)
1755 {
1756     bool isTransA = (*transa == 'T' || *transa == 't');
1757     bool isTransB = (*transb == 'T' || *transb == 't');
1758
1759     int Bm, sizeM, Bn, sizeN, Bk, sizeK;
1760
1761     int i, j;
1762
1763     if ((m <= 0) || (n <= 0))
1764         return;
1765
1766     if ((k <= 0) || (alpha[0] == 0.)) {
1767
1768         if (beta[0] == 0.) {
1769             for (j = 0; j < n; j++)
1770                 for (i = 0; i < m; i++)
1771                     c[i + j * ldc] = 0.0;
1772         } else if (beta[0] != 1.) {
1773             for (j = 0; j < n; j++)
1774                 for (i = 0; i < m; i++)
1775                     c[i + j * ldc] *= beta[0];
1776         }
1777
1778         return;
1779     }
1780
1781     assert(IMPLICATION(bias != nullptr, *beta == 0.0));
1782
1783     // XXX: this happens on every thread...
1784     bool hasBias = (bias != nullptr);
1785     auto ker_bn = get_xbyak_gemm(isTransA, isTransB, *beta, hasBias);
1786     auto ker_b1 = get_xbyak_gemm(isTransA, isTransB, 1.0, false);
1787     auto ker_b0 = get_xbyak_gemm(isTransA, isTransB, 0.0, false);
1788     assert(ker_bn && ker_b1 && ker_b0);
1789
1790     int BM = 4032, BN, BK;
1791     if (mayiuse(avx512_core)) {
1792         BN = isTransA ? 384 : 64;
1793         BK = 384;
1794     } else {
1795         BN = isTransA ? 96 : 64;
1796         BK = isTransB ? 96 : 192;
1797         if (!isTransA && !isTransB)
1798             BK = 128;
1799     }
1800     const float *curA, *curB, *curBias = nullptr;
1801     float *curC;
1802
1803     for (Bk = 0; Bk < k; Bk += sizeK) {
1804         sizeK = k - Bk;
1805         if (sizeK >= BK * 2)
1806             sizeK = BK;
1807         else {
1808             if (sizeK > BK)
1809                 sizeK = (sizeK + 1) / 2;
1810         }
1811
1812         for (Bm = 0; Bm < m; Bm += sizeM) {
1813             sizeM = m - Bm;
1814             if (sizeM >= BM * 2)
1815                 sizeM = BM;
1816             else {
1817                 if (sizeM > BM + BM / 2)
1818                     sizeM = (sizeM + 1) / 2;
1819             }
1820
1821             for (Bn = 0; Bn < n; Bn += sizeN) {
1822                 sizeN = n - Bn;
1823                 if (sizeN >= BN * 2)
1824                     sizeN = BN;
1825                 else {
1826                     if (sizeN > BN + BN / 2)
1827                         sizeN = (sizeN + 1) / 2;
1828                 }
1829
1830                 if (!isTransA) {
1831                     curA = a + Bm + Bk * lda;
1832                 } else {
1833                     curA = a + Bk + Bm * lda;
1834                 }
1835                 if (!isTransB) {
1836                     curB = b + Bk + Bn * ldb;
1837                 } else {
1838                     curB = b + Bn + Bk * ldb;
1839                 }
1840                 curC = c + Bm + (size_t)Bn * ldc;
1841                 if (bias != nullptr) {
1842                     if (Bk == 0) {
1843                         curBias = bias + Bm;
1844                     } else {
1845                         curBias = nullptr;
1846                     }
1847                 }
1848                 if (Bk == 0) {
1849                     if (*beta == 0.0 && bias == nullptr)
1850                         (*ker_b0)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
1851                                 alpha, curA, lda, curB, ldb, beta, curC, ldc,
1852                                 curBias, ws);
1853                     else
1854                         (*ker_bn)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
1855                                 alpha, curA, lda, curB, ldb, beta, curC, ldc,
1856                                 curBias, ws);
1857                 } else {
1858                     (*ker_b1)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
1859                             alpha, curA, lda, curB, ldb, beta, curC, ldc,
1860                             curBias, ws);
1861                 }
1862             }
1863         }
1864     }
1865 }
1866
1867 }
1868
1869 mkldnn_status_t jit_avx512_common_gemm_f32(
1870         const char *transa, const char *transb,
1871         const int *p_m, const int *p_n, const int *p_k, const float *p_alpha,
1872         const float *A, const int *p_lda, const float *B, const int *p_ldb,
1873         const float *p_beta, float *C, const int *p_ldc, const float *bias)
1874 {
1875     using namespace mkldnn::impl::utils;
1876     using namespace avx512_common_gemm_f32;
1877     using namespace gemm_utils;
1878
1879     if (*p_beta != 0 && bias)
1880         return ref_gemm(transa, transb, p_m, p_n, p_k,
1881                 p_alpha, A, p_lda, B, p_lda, p_beta, C, p_ldc, bias);
1882
1883     int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
1884
1885     int m = *p_m;
1886     int n = *p_n;
1887     int k = *p_k;
1888     dim_t lda = *p_lda;
1889     dim_t ldb = *p_ldb;
1890     dim_t ldc = *p_ldc;
1891     float beta = *p_beta;
1892     int MB, NB, KB;
1893
1894     int nthr_m, nthr_n, nthr_k, nthr_mn;
1895
1896     // Determine threading partitioning
1897     calc_nthr_nocopy_avx512_common(
1898             m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
1899     assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
1900
1901     // May not happen, but just in case
1902     if (nthr < nthr_m * nthr_n * nthr_k)
1903         nthr = nthr_m * nthr_n * nthr_k;
1904
1905     nthr_mn = nthr_m * nthr_n;
1906
1907     unsigned char * ompstatus_ = nullptr;
1908     unsigned char volatile *ompstatus = nullptr;
1909
1910     float *c_buffers = nullptr;
1911     float *ws_buffers = nullptr;
1912
1913     if (nthr_k > 1) {
1914         ompstatus_ = (unsigned char *) malloc(
1915                 nthr * CACHE_LINE_SIZE,
1916                 CACHE_LINE_SIZE);
1917         ompstatus = (unsigned char volatile *) ompstatus_;
1918         assert(ompstatus);
1919
1920         for (int i = 0; i < nthr; i++)
1921             ompstatus[i * CACHE_LINE_SIZE] = 0;
1922
1923         c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB
1924                 * sizeof(float), PAGE_4K);
1925     }
1926
1927     const size_t ws_elems_per_thr = (size_t)k * 48 + 64;
1928     const size_t ws_size_per_thr
1929             = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
1930     if (k > STACK_K_CAPACITY) {
1931         ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K);
1932     }
1933
1934     parallel_nd(nthr, [&](const int ithr) {
1935         int ithr_m, ithr_n, ithr_k, ithr_mn;
1936         int m_from, m_to, myM;
1937         int n_from, n_to, myN;
1938         int k_from, k_to, myK;
1939         int cbase, ibase;
1940         const float *myA, *myB, *myBias = nullptr;
1941         float *myC = C, myBeta;
1942         float *ws = ws_buffers ?
1943                 ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0;
1944         dim_t ld = ldc;
1945
1946         int sum_later = (mkldnn_get_num_threads() < nthr_m * nthr_n * nthr_k);
1947
1948         if (ithr < nthr_m * nthr_n * nthr_k) {
1949
1950             ithr_mn = ithr % nthr_mn;
1951             ithr_m = ithr_mn % nthr_m;
1952             ithr_n = ithr_mn / nthr_m;
1953             ithr_k = ithr / nthr_mn;
1954
1955             /* swap ithr_k for performance improvement */
1956             if (ithr_k == 0)
1957                 ithr_k = nthr_k - 1;
1958             else if (ithr_k == nthr_k - 1)
1959                 ithr_k = 0;
1960
1961             m_from = MB * (ithr_m);
1962             m_to = MB * (ithr_m + 1);
1963             if (m_to > m)
1964                 m_to = m;
1965             myM = m_to - m_from;
1966
1967             n_from = NB * (ithr_n);
1968             n_to = NB * (ithr_n + 1);
1969             if (n_to > n)
1970                 n_to = n;
1971             myN = n_to - n_from;
1972
1973             k_from = KB * (ithr_k);
1974             k_to = KB * (ithr_k + 1);
1975             if (k_to > k)
1976                 k_to = k;
1977             myK = k_to - k_from;
1978
1979             cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
1980             ibase = (ithr_m + nthr_m * ithr_n) * nthr_k;
1981
1982             if ((myM > 0) && (myN > 0)) {
1983
1984                 if (*transa == 'N' || *transa == 'n') {
1985                     myA = &(A[m_from + k_from * lda]);
1986                 } else {
1987                     myA = &(A[k_from + m_from * lda]);
1988                 }
1989                 if (*transb == 'N' || *transb == 'n') {
1990                     myB = &(B[k_from + n_from * ldb]);
1991                 } else {
1992                     myB = &(B[n_from + k_from * ldb]);
1993                 }
1994                 if (ithr_k == 0) {
1995                     myC = &(C[m_from + n_from * ldc]);
1996                     myBeta = beta;
1997                     ld = ldc;
1998                     if (bias)
1999                         myBias = &(bias[m_from]);
2000                 } else {
2001                     myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1);
2002                     myBeta = 0.0;
2003                     ld = MB;
2004                     myBias = nullptr;
2005                 }
2006
2007                 sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA,
2008                         lda, myB, ldb, &myBeta, myC, ld, myBias, ws);
2009
2010                 if (nthr_k > 1 && !sum_later)
2011                     ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1;
2012             }
2013
2014             if (nthr_k > 1 && !sum_later) {
2015
2016                 // sum matrices partitioned along K dimension
2017                 int n1, n2;
2018
2019                 partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
2020
2021                 if (ithr_k > 0) {
2022
2023                     myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
2024                         + (dim_t)n1 * MB;
2025                     /* need to wait until main thread finishes */
2026                     while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) {
2027                     };
2028
2029                     /* my cache is hot */
2030                     sum_two_matrices(myM, n2, myC, MB,
2031                             &C[m_from + (n_from + n1) * ldc], ldc);
2032                 }
2033
2034                 for (int ik = 1; ik < nthr_k; ++ik) {
2035                     if (ik != ithr_k) {
2036
2037                         myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
2038                             + (dim_t)n1 * MB;
2039
2040                         while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) {
2041                         };
2042
2043                         sum_two_matrices(myM, n2, myC, MB,
2044                                 &C[m_from + (n_from + n1) * ldc], ldc);
2045                     }
2046                 }
2047             }
2048         }
2049     });
2050
2051
2052     // handle C summation later
2053     if (nthr_k > 1 && ompstatus[0] == 0) {
2054
2055         parallel_nd(nthr, [&](const int ithr) {
2056             int ithr_m, ithr_n, ithr_k, ithr_mn;
2057             int m_from, m_to, myM;
2058             int n_from, n_to, myN;
2059             int cbase;
2060             float *myC = C;
2061
2062             if (ithr < nthr_m * nthr_n * nthr_k) {
2063
2064                 ithr_mn = ithr % nthr_mn;
2065                 ithr_m = ithr_mn % nthr_m;
2066                 ithr_n = ithr_mn / nthr_m;
2067                 ithr_k = ithr / nthr_mn;
2068
2069                 /* swap ithr_k for performance improvement */
2070                 if (ithr_k == 0)
2071                     ithr_k = nthr_k - 1;
2072                 else if (ithr_k == nthr_k - 1)
2073                     ithr_k = 0;
2074
2075                 m_from = MB * (ithr_m);
2076                 m_to = MB * (ithr_m + 1);
2077                 if (m_to > m)
2078                     m_to = m;
2079                 myM = m_to - m_from;
2080
2081                 n_from = NB * (ithr_n);
2082                 n_to = NB * (ithr_n + 1);
2083                 if (n_to > n)
2084                     n_to = n;
2085                 myN = n_to - n_from;
2086
2087                 cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
2088
2089                 if (nthr_k > 1) {
2090                     // sum matrices partitioned along K dimension
2091                     int n1, n2;
2092
2093                     partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
2094
2095                     if (ithr_k > 0) {
2096
2097                         myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
2098                             + (dim_t)n1 * MB;
2099
2100                         /* my cache is hot */
2101                         sum_two_matrices(myM, n2, myC, MB,
2102                                          &C[m_from + (n_from + n1) * ldc], ldc);
2103                     }
2104
2105                     for (int ik = 1; ik < nthr_k; ++ik) {
2106                         if (ik != ithr_k) {
2107
2108                             myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
2109                                 + (dim_t)n1 * MB;
2110
2111                             sum_two_matrices(myM, n2, myC, MB,
2112                                              &C[m_from + (n_from + n1) * ldc], ldc);
2113                         }
2114                     }
2115                 }
2116             }
2117         });
2118     }
2119
2120     free(c_buffers);
2121     free(ompstatus_);
2122     free(ws_buffers);
2123
2124     return mkldnn_success;
2125 }
2126
2127 }
2128 }
2129 }
2130
2131 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s