Add a section on how to link IE with a CMake project (#99)
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm / f32 / jit_avx_gemm_f32.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <cmath>
18 #include <mutex>
19
20 #include "mkldnn_thread.hpp"
21 #include "utils.hpp"
22
23 #include "ref_gemm_f32.hpp"
24 #include "gemm_utils_f32.hpp"
25 #include "jit_avx_gemm_f32.hpp"
26
27 #include "jit_generator.hpp"
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
// Bytes in one cache line.
#define CACHE_LINE_SIZE 64

// Size of jit_generator's register-save area. The kernel adds it to rsp when
// addressing its stack-passed arguments (see stackOffset / ARG_B below).
#define STACKSIZE get_size_of_abi_save_regs()
// Platform-dependent K capacity consumed later in this file; Windows uses a
// much smaller value. NOTE(review): the rationale is not visible here --
// confirm before changing either value.
// `#ifdef` (not `#if`) matches the other _WIN32 checks in this file and stays
// well-formed (and -Wundef clean) whether _WIN32 is undefined or defined empty.
#ifdef _WIN32
#define STACK_K_CAPACITY 128
#else
#define STACK_K_CAPACITY 8192
#endif
// Bytes per matrix element (float).
#define SIZE 4
// Addresses in the kernels are written as base + (index - OFFSET) * SIZE with
// base pointers pre-biased by +OFFSET * SIZE; presumably this keeps more
// displacements within a signed 8-bit range -- confirm.
#define OFFSET 32
// log2(SIZE).
#define BASE_SHIFT 2
// Tuning constant consumed later in this file.
#define SECOND_FETCH 14
45
46 namespace avx_gemm_f32 {
47 using namespace gemm_utils;
48
49 struct xbyak_gemm : public jit_generator {
50     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_gemm_f32_xbyak_gemm)
51
52     xbyak_gemm(char isTransA, char isTransB, float beta, bool hasBias = false,
53             void *code_ptr = nullptr,
54             size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE)
55         : jit_generator(code_ptr, code_size)
56     {
57         using namespace Xbyak;
58
59         const bool is_avx2 = mayiuse(avx2);
60         assert(IMPLICATION(!is_avx2, mayiuse(avx)));
61
62         const int UNROLL_M = is_avx2 ? 16 : 8;
63         const int UNROLL_N = 6;
64
65         bool isBeta0 = (beta == 0.0);
66         bool isBetaN = (!isBeta0 && beta != 1.0);
67
68         // various definitions for convenience
69         auto ARG_M = abi_param1;
70         auto ARG_N = abi_param2;
71         auto K = abi_param3;
72         auto ARG_ALPHA = abi_param4;
73 #ifdef _WIN32
74         auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE];
75         auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE +
76             sizeof(float *) + STACKSIZE];
77         const auto stackOffset = OFFSET_SHADOWSPACE +
78             sizeof(float *) + STACKSIZE;
79         auto A = rsi;
80         auto LDA = rdi;
81 #else
82         auto ARG_A = r8;
83         auto ARG_LDA = r9;
84         const auto stackOffset = STACKSIZE;
85         auto A = ARG_A;
86         auto LDA = ARG_LDA;
87 #endif
88         auto ARG_B = ptr[rsp + 8 + stackOffset];
89         auto ARG_LDB = ptr[rsp + 16 + stackOffset];
90         auto ARG_BETA = ptr[rsp + 24 + stackOffset];
91         auto ARG_C = ptr[rsp + 32 + stackOffset];
92         auto ARG_LDC = ptr[rsp + 40 + stackOffset];
93         auto ARG_BIAS = ptr[rsp + 48 + stackOffset];
94         auto ARG_WS = ptr[rsp + 56 + stackOffset];
95
96         auto B = r11;
97         auto LDB = rbx;
98         auto LDC = r13;
99         auto LL = rax;
100         auto AO1 = abi_param2;
101         auto BO1 = abi_param4;
102         auto BO2 = rbp;
103         auto CO1 = r14;
104         auto CO2 = r15;
105         auto LDB3 = r10;
106         auto LDA4 = abi_param1;
107         auto AA = r12;
108         auto BIAS1 = abi_param1;
109
110         auto M = qword[rsp + 0];
111         auto N = qword[rsp + 8];
112         auto FLAG = qword[rsp + 16];
113         auto I = qword[rsp + 24];
114         auto C = qword[rsp + 32];
115         auto BIAS = qword[rsp + 40];
116         auto ALPHA = qword[rsp + 48];
117         auto BETA = qword[rsp + 64];
118         auto ORIG_A = qword[rsp + 80];
119         auto MASK = dword[rsp + 88];
120         auto STRIDE = qword[rsp + 120];
121         auto ORIG_SP = qword[rsp + 152];
122
123         auto VALPHA = ymm1;
124         auto VBETA = ymm2;
125         auto VMASK = ymm3;
126         auto VBIAS1 = ymm2;
127         auto VBIAS2 = ymm4;
128
129         auto PREFETCHSIZEA = 128;
130         auto PREFETCHSIZEB = (!isTransB) ? -16 : 0;
131
132         // Function for packing if needed
133         auto do_pack = [&](
134                 int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) {
135             Label pack2, pack3, pack4, pack10;
136
137             int regIdx;
138             Reg64 reg;
139
140             mov(BO1, A);
141             lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]);
142
143             if (isTransA) {
144                 lea(BO2, ptr[BO1 + LDA * 4]);
145                 lea(CO1, ptr[LDA + LDA * 2]);
146                 vmovupd(ymm7, STRIDE);
147             }
148
149             mov(LL, K);
150             sar(LL, 2);
151             jle(pack3, T_NEAR);
152             align(16);
153
154             L(pack2);
155             if (!isTransA) {
156                 for (int i = 0; i < 4; i++) {
157                     regIdx = (i % 2 == 0) ? 4 : 6;
158                     if (isLoad1Unmasked) {
159                         vmovups(Ymm(regIdx),
160                                 ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
161                     } else {
162                         vmaskmovps(Ymm(regIdx), VMASK,
163                                 ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
164                     }
165                     if (unroll_m > 8) {
166                         if (isLoad2Unmasked) {
167                             vmovups(Ymm(regIdx + 1),
168                                     ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
169                         } else {
170                             vmaskmovps(Ymm(regIdx + 1), VMASK,
171                                     ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
172                         }
173                     }
174                     add(BO1, LDA);
175
176                     vmovups(ptr[AO1 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
177                             Ymm(regIdx));
178                     if (unroll_m > 8) {
179                         vmovups(ptr[AO1
180                                         + (unroll_m * i + 1 * 8 - OFFSET)
181                                                 * SIZE],
182                                 Ymm(regIdx + 1));
183                     }
184                 }
185
186             } else {
187                 if (isLoad1Unmasked) {
188                     for (int i = 0; i < 2; i++) {
189                         reg = (i % 2 == 0) ? BO1 : BO2;
190                         vmovups(xmm0, ptr[reg + (0 * 8 - OFFSET) * SIZE]);
191                         vmovups(xmm1,
192                                 ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
193                         lea(BO2, ptr[reg + LDA * 2]);
194                         vunpcklps(xmm4, xmm0, xmm1);
195                         vunpckhps(xmm5, xmm0, xmm1);
196                         vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
197                         vmovups(xmm1,
198                                 ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
199                         lea(BO2, ptr[BO2 + LDA * 2]);
200                         vunpcklps(xmm6, xmm0, xmm1);
201                         vunpckhps(xmm2, xmm0, xmm1);
202
203                         vunpcklpd(xmm0, xmm4, xmm6);
204                         vunpckhpd(xmm1, xmm4, xmm6);
205                         vmovups(ptr[AO1
206                                         + (unroll_m * 0 + i * 4 - OFFSET)
207                                                 * SIZE],
208                                 xmm0);
209                         vmovups(ptr[AO1
210                                         + (unroll_m * 1 + i * 4 - OFFSET)
211                                                 * SIZE],
212                                 xmm1);
213                         vunpcklpd(xmm0, xmm5, xmm2);
214                         vunpckhpd(xmm1, xmm5, xmm2);
215                         vmovups(ptr[AO1
216                                         + (unroll_m * 2 + i * 4 - OFFSET)
217                                                 * SIZE],
218                                 xmm0);
219                         vmovups(ptr[AO1
220                                         + (unroll_m * 3 + i * 4 - OFFSET)
221                                                 * SIZE],
222                                 xmm1);
223                     }
224                 } else if (is_avx2) {
225                     for (int i = 0; i < 2; i++) {
226                         vmovaps(xmm4, xmm3);
227                         vgatherqps(xmm0,
228                                 ptr[BO1 + ymm7 + ((2 * i) - OFFSET) * SIZE],
229                                 xmm4);
230                         vmovaps(xmm4, xmm3);
231                         vgatherqps(xmm1,
232                                 ptr[BO1 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE],
233                                 xmm4);
234
235                         vmovups(ptr[AO1
236                                         + (unroll_m * (2 * i) + 0 * 4 - OFFSET)
237                                                 * SIZE],
238                                 xmm0);
239                         vmovups(ptr[AO1
240                                         + (unroll_m * (2 * i + 1) + 0 * 4
241                                                   - OFFSET)
242                                                 * SIZE],
243                                 xmm1);
244                     }
245
246                     lea(BO2, ptr[BO1 + LDA * 4]);
247
248                     for (int i = 0; i < 2; i++) {
249                         vextractf128(xmm4, ymm3, 1);
250                         vgatherqps(xmm0,
251                                 ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
252                                 xmm4);
253                         vextractf128(xmm4, ymm3, 1);
254                         vgatherqps(xmm1,
255                                 ptr[BO2 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE],
256                                 xmm4);
257
258                         vmovups(ptr[AO1
259                                         + (unroll_m * (2 * i) + 1 * 4 - OFFSET)
260                                                 * SIZE],
261                                 xmm0);
262                         vmovups(ptr[AO1
263                                         + (unroll_m * (2 * i + 1) + 1 * 4
264                                                   - OFFSET)
265                                                 * SIZE],
266                                 xmm1);
267                     }
268
269                     lea(BO2, ptr[BO2 + LDA * 4]);
270                 } else {
271                     vxorps(xmm4, xmm4, xmm4);
272                     lea(BO2, ptr[BO1 + LDA * 4]);
273
274                     auto el_cp = [&](int section, int ld_step) {
275                         RegExp src_addr = section == 0 ? BO1 : BO2;
276                         if (ld_step == 1 || ld_step == 2)
277                             src_addr = src_addr + LDA * ld_step;
278                         else if (ld_step == 3)
279                             src_addr = src_addr + CO1;
280                         src_addr = src_addr - OFFSET * SIZE;
281
282                         vmovups(Xmm(ld_step % 2), ptr[src_addr]);
283                         RegExp dst_addr = AO1
284                             + (ld_step + section * 4 - OFFSET) * SIZE;
285                         for (int off = 0; off < 4; ++off)
286                             pextrd(ptr[dst_addr + unroll_m * off * SIZE],
287                                     Xmm(ld_step % 2), off);
288                     };
289
290                     Label l_end;
291                     el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR);
292                     el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR);
293                     el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR);
294                     el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR);
295                     el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR);
296                     el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR);
297                     el_cp(1, 2);
298                     L(l_end);
299
300                     lea(BO2, ptr[BO2 + LDA * 4]);
301                 }
302
303                 if (unroll_m >= 16) {
304                     assert(is_avx2);
305                     if (isLoad2Unmasked) {
306                         for (int i = 0; i < 2; i++) {
307                             vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
308                             vmovups(xmm1, ptr[BO2 + LDA * 1
309                                                   + (0 * 8 - OFFSET) * SIZE]);
310                             lea(BO2, ptr[BO2 + LDA * 2]);
311                             vunpcklps(xmm4, xmm0, xmm1);
312                             vunpckhps(xmm5, xmm0, xmm1);
313                             vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
314                             vmovups(xmm1, ptr[BO2 + LDA * 1
315                                                   + (0 * 8 - OFFSET) * SIZE]);
316                             if (i == 0)
317                                 lea(BO2, ptr[BO2 + LDA * 2]);
318                             vunpcklps(xmm6, xmm0, xmm1);
319                             vunpckhps(xmm2, xmm0, xmm1);
320
321                             vunpcklpd(xmm0, xmm4, xmm6);
322                             vunpckhpd(xmm1, xmm4, xmm6);
323                             vmovups(ptr[AO1
324                                             + (unroll_m * 0 + (i + 2) * 4
325                                                       - OFFSET)
326                                                     * SIZE],
327                                     xmm0);
328                             vmovups(ptr[AO1
329                                             + (unroll_m * 1 + (i + 2) * 4
330                                                       - OFFSET)
331                                                     * SIZE],
332                                     xmm1);
333                             vunpcklpd(xmm0, xmm5, xmm2);
334                             vunpckhpd(xmm1, xmm5, xmm2);
335                             vmovups(ptr[AO1
336                                             + (unroll_m * 2 + (i + 2) * 4
337                                                       - OFFSET)
338                                                     * SIZE],
339                                     xmm0);
340                             vmovups(ptr[AO1
341                                             + (unroll_m * 3 + (i + 2) * 4
342                                                       - OFFSET)
343                                                     * SIZE],
344                                     xmm1);
345                         }
346                     } else {
347                         for (int i = 0; i < 2; i++) {
348                             vmovaps(xmm4, xmm3);
349                             vgatherqps(xmm0,
350                                     ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
351                                     xmm4);
352                             vmovaps(xmm4, xmm3);
353                             vgatherqps(xmm1,
354                                     ptr[BO2 + ymm7
355                                                + ((2 * i + 1) - OFFSET) * SIZE],
356                                     xmm4);
357
358                             vmovups(ptr[AO1
359                                             + (unroll_m * (2 * i) + 2 * 4
360                                                       - OFFSET)
361                                                     * SIZE],
362                                     xmm0);
363                             vmovups(ptr[AO1
364                                             + (unroll_m * (2 * i + 1) + 2 * 4
365                                                       - OFFSET)
366                                                     * SIZE],
367                                     xmm1);
368                         }
369
370                         lea(BO2, ptr[BO2 + LDA * 4]);
371
372                         for (int i = 0; i < 2; i++) {
373                             vextractf128(xmm4, ymm3, 1);
374                             vgatherqps(xmm0,
375                                     ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
376                                     xmm4);
377                             vextractf128(xmm4, ymm3, 1);
378                             vgatherqps(xmm1,
379                                     ptr[BO2 + ymm7
380                                                + ((2 * i + 1) - OFFSET) * SIZE],
381                                     xmm4);
382
383                             vmovups(ptr[AO1
384                                             + (unroll_m * (2 * i) + 3 * 4
385                                                       - OFFSET)
386                                                     * SIZE],
387                                     xmm0);
388                             vmovups(ptr[AO1
389                                             + (unroll_m * (2 * i + 1) + 3 * 4
390                                                       - OFFSET)
391                                                     * SIZE],
392                                     xmm1);
393                         }
394
395                         lea(BO2, ptr[BO2 + LDA * 4]);
396                     }
397                 }
398                 add(BO1, (4 * SIZE));
399             }
400
401             add(AO1, unroll_m * 4 * SIZE);
402             sub(LL, 1);
403             jg(pack2, T_NEAR);
404             align(16);
405
406             L(pack3);
407             mov(LL, K);
408             and_(LL, 3);
409             jle(pack10, T_NEAR);
410             align(16);
411
412             L(pack4);
413             if (!isTransA) {
414                 if (isLoad1Unmasked) {
415                     vmovups(ymm4, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
416                 } else {
417                     vmaskmovps(ymm4, VMASK, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
418                 }
419                 if (unroll_m > 8) {
420                     if (isLoad2Unmasked) {
421                         vmovups(ymm5, ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
422                     } else {
423                         vmaskmovps(ymm5, VMASK,
424                                 ptr[BO1 + (1 + 8 - OFFSET) * SIZE]);
425                     }
426                 }
427                 add(BO1, LDA);
428                 vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
429                         ymm4);
430                 if (unroll_m > 8) {
431                     vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE],
432                             ymm5);
433                 }
434             } else {
435                 if (isLoad1Unmasked) {
436                     for (int i = 0; i < 2; i++) {
437                         reg = (i % 2 == 0) ? BO1 : BO2;
438                         vmovss(Xmm(i + 1), ptr[reg + (0 * 8 - OFFSET) * SIZE]);
439                         vmovss(xmm0,
440                                 ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
441                         lea(BO2, ptr[reg + LDA * 2]);
442                         vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
443                     }
444                     vunpcklpd(xmm1, xmm1, xmm2);
445                     vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE],
446                             xmm1);
447
448                     for (int i = 0; i < 2; i++) {
449                         vmovss(Xmm(i + 1), ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
450                         vmovss(xmm0,
451                                 ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
452                         lea(BO2, ptr[BO2 + LDA * 2]);
453                         vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
454                     }
455                     vunpcklpd(xmm1, xmm1, xmm2);
456                     vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE],
457                             xmm1);
458                 } else if (is_avx2) {
459                     vmovaps(xmm4, xmm3);
460                     vgatherqps(xmm1, ptr[BO1 + ymm7 + (0 * 8 - OFFSET) * SIZE],
461                             xmm4);
462                     lea(BO2, ptr[BO1 + LDA * 4]);
463                     vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE],
464                             xmm1);
465
466                     vextractf128(xmm4, ymm3, 1);
467                     vgatherqps(xmm1, ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
468                             xmm4);
469                     lea(BO2, ptr[BO2 + LDA * 4]);
470                     vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE],
471                             xmm1);
472                 } else {
473                     vxorps(xmm4, xmm4, xmm4);
474                     lea(BO2, ptr[BO1 + LDA * 4]);
475
476                     auto el_cp = [&](int section, int ld_step) {
477                         RegExp src_addr = section == 0 ? BO1 : BO2;
478                         if (ld_step == 1 || ld_step == 2)
479                             src_addr = src_addr + LDA * ld_step;
480                         else if (ld_step == 3)
481                             src_addr = src_addr + CO1;
482                         src_addr = src_addr - OFFSET * SIZE;
483
484                         vmovss(xmm1, ptr[src_addr]);
485                         RegExp dst_addr = AO1
486                             + (ld_step + section * 4 - OFFSET) * SIZE;
487                         movss(ptr[dst_addr], xmm1);
488                     };
489
490                     Label l_end;
491                     el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR);
492                     el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR);
493                     el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR);
494                     el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR);
495                     el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR);
496                     el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR);
497                     el_cp(1, 2);
498                     L(l_end);
499
500                     lea(BO2, ptr[BO2 + LDA * 4]);
501                 }
502
503                 if (unroll_m >= 16) {
504                     assert(is_avx2);
505                     if (isLoad2Unmasked) {
506                         for (int i = 0; i < 2; i++) {
507                             vmovss(Xmm(i + 1),
508                                     ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
509                             vmovss(xmm0, ptr[BO2 + LDA * 1
510                                                  + (0 * 8 - OFFSET) * SIZE]);
511                             lea(BO2, ptr[BO2 + LDA * 2]);
512                             vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
513                         }
514                         vunpcklpd(xmm1, xmm1, xmm2);
515                     } else {
516                         vmovaps(xmm4, xmm3);
517                         vgatherqps(xmm1,
518                                 ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
519                                 xmm4);
520                         lea(BO2, ptr[BO2 + LDA * 4]);
521                     }
522                     vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 4 - OFFSET) * SIZE],
523                             xmm1);
524
525                     if (isLoad2Unmasked) {
526                         for (int i = 0; i < 2; i++) {
527                             vmovss(Xmm(i + 1),
528                                     ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
529                             vmovss(xmm0, ptr[BO2 + LDA * 1
530                                                  + (0 * 8 - OFFSET) * SIZE]);
531                             lea(BO2, ptr[BO2 + LDA * 2]);
532                             vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
533                         }
534                         vunpcklpd(xmm1, xmm1, xmm2);
535                     } else {
536                         vextractf128(xmm4, ymm3, 1);
537                         vgatherqps(xmm1,
538                                 ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
539                                 xmm4);
540                     }
541                     vmovups(ptr[AO1 + (unroll_m * 0 + 3 * 4 - OFFSET) * SIZE],
542                             xmm1);
543                 }
544                 add(BO1, SIZE);
545             }
546
547             add(AO1, unroll_m * SIZE);
548             sub(LL, 1);
549             jg(pack4, T_NEAR);
550             align(16);
551
552             L(pack10);
553         };
554
555         // Fused multiply add; may become one or two instructions
556         auto fma = [&](bool useFma, Ymm reg0, Ymm reg1, Ymm reg2,
557                 bool overWrite = false) {
558             if (useFma) {
559                 if (is_avx2) {
560                     vfmadd231ps(reg2, reg1, reg0);
561                 } else {
562                     assert(UNROLL_M == 8);
563                     auto tent_vreg = overWrite ? reg1 : ymm1;
564                     vmulps(tent_vreg, reg1, reg0);
565                     vaddps(reg2, reg2, tent_vreg);
566                 }
567             } else {
568                 if (!overWrite) {
569                     vmulps(ymm15, reg1, reg0);
570                     vaddps(reg2, reg2, ymm15);
571                 } else {
572                     vmulps(reg1, reg1, reg0);
573                     vaddps(reg2, reg2, reg1);
574                 }
575             }
576         };
577
578         // Inner kernel with k=8
579         auto innerkernel8 = [&](int unroll_m, int unroll_n,
580                 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
581                 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
582                 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
583                 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
584                 Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
585                 Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
586                 Ymm reg23) {
587
588             Ymm fmareg;
589
590             if (!isDirect) {
591                 prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]);
592             } else {
593                 prefetcht0(ptr[AO1 + LDA4]);
594             }
595
596             for (int i = 0; i < 8; i++) {
597                 if (isDirect) {
598                     if (isLoad1Unmasked) {
599                         vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
600                     } else {
601                         vmaskmovps(ymm0, VMASK,
602                                 ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
603                     }
604                     if (unroll_m >= 16) {
605                         if (isLoad2Unmasked) {
606                             vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
607                         } else {
608                             vmaskmovps(ymm1, VMASK,
609                                     ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
610                         }
611                     }
612                     add(AO1, LDA);
613                 }
614
615                 if (!isTransB) {
616                     vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]);
617                 } else {
618                     vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
619                 }
620                 fmareg = (i % 2 == 0) ? reg00 : reg12;
621                 fma(useFma, ymm0, ymm2, fmareg);
622                 if (unroll_m >= 16) {
623                     fmareg = (i % 2 == 0) ? reg06 : reg18;
624                     fma(useFma, ymm1, ymm2, fmareg);
625                 }
626                 if (i == 0) {
627                     if (!isTransB) {
628                         prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]);
629                     }
630                 }
631                 if (unroll_n >= 2) {
632                     if (!isTransB) {
633                         if (i == 1) {
634                             prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]);
635                         }
636                         vbroadcastss(
637                                 ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
638                     } else {
639                         vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
640                     }
641                     fmareg = (i % 2 == 0) ? reg01 : reg13;
642                     fma(useFma, ymm0, ymm2, fmareg);
643                     if (unroll_m >= 16) {
644                         fmareg = (i % 2 == 0) ? reg07 : reg19;
645                         fma(useFma, ymm1, ymm2, fmareg);
646                     }
647                 }
648
649                 if (isCopy) {
650                     vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
651                             ymm0);
652                     if (unroll_m >= 16) {
653                         vmovups(ptr[LDA4
654                                         + (unroll_m * i + 1 * 8 - OFFSET)
655                                                 * SIZE],
656                                 ymm1);
657                     }
658                     if (i == 7) {
659                         sub(LDA4, -unroll_m * 8 * SIZE);
660                     }
661                 }
662
663                 if (unroll_n >= 3) {
664                     if (!isTransB) {
665                         if (i == 2) {
666                             prefetcht0(
667                                     ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
668                         }
669                         vbroadcastss(
670                                 ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]);
671                     } else {
672                         vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
673                     }
674                     fmareg = (i % 2 == 0) ? reg02 : reg14;
675                     fma(useFma, ymm0, ymm2, fmareg);
676                     if (unroll_m >= 16) {
677                         fmareg = (i % 2 == 0) ? reg08 : reg20;
678                         fma(useFma, ymm1, ymm2, fmareg);
679                     }
680                 }
681
682                 if (i == 7) {
683                     if (!isTransB) {
684                         sub(BO1, -8 * SIZE);
685                     }
686                 }
687
688                 if (unroll_n >= 4) {
689                     if (!isTransB) {
690                         if (i == 3) {
691                             prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]);
692                         }
693                         vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]);
694                     } else {
695                         vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
696                     }
697                     fmareg = (i % 2 == 0) ? reg03 : reg15;
698                     fma(useFma, ymm0, ymm2, fmareg);
699                     if (unroll_m >= 16) {
700                         fmareg = (i % 2 == 0) ? reg09 : reg21;
701                         fma(useFma, ymm1, ymm2, fmareg);
702                     }
703                 }
704
705                 if (unroll_n >= 5) {
706                     if (!isTransB) {
707                         if (i == 4) {
708                             prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]);
709                         }
710                         vbroadcastss(
711                                 ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]);
712                     } else {
713                         vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
714                     }
715                     fmareg = (i % 2 == 0) ? reg04 : reg16;
716                     fma(useFma, ymm0, ymm2, fmareg);
717                     if (unroll_m >= 16) {
718                         fmareg = (i % 2 == 0) ? reg10 : reg22;
719                         fma(useFma, ymm1, ymm2, fmareg);
720                     }
721                 }
722
723                 if (unroll_n >= 6) {
724                     if (!isTransB) {
725                         if (i == 5) {
726                             prefetcht0(
727                                     ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]);
728                         }
729                         vbroadcastss(
730                                 ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]);
731                     } else {
732                         vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
733                     }
734                     fmareg = (i % 2 == 0) ? reg05 : reg17;
735                     fma(useFma, ymm0, ymm2, fmareg);
736                     if (unroll_m >= 16) {
737                         fmareg = (i % 2 == 0) ? reg11 : reg23;
738                         fma(useFma, ymm1, ymm2, fmareg);
739                     }
740                 }
741                 if (isTransB) {
742                     prefetcht0(ptr[BO1 + BO2]);
743                     add(BO1, LDB);
744                 }
745
746                 if (i == 0) {
747                     if (unroll_m >= 4) {
748                         if (!isDirect) {
749                             prefetcht0(
750                                     ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]);
751                         } else {
752                             prefetcht0(ptr[AO1 + LDA4]);
753                         }
754                     }
755                 }
756                 if (i == 1 || i == 2) {
757                     if (unroll_m >= 8) {
758                         if (!isDirect) {
759                             prefetcht0(ptr[AO1
760                                     + (PREFETCHSIZEA + (2 + 2 * i) * 8)
761                                             * SIZE]);
762                         } else {
763                             prefetcht0(ptr[AO1 + LDA4]);
764                         }
765                     }
766                 }
767                 if (i == 3 || i == 4 || i == 5 || i == 6) {
768                     if (unroll_m >= 16) {
769                         if (!isDirect) {
770                             prefetcht0(ptr[AO1
771                                     + (PREFETCHSIZEA + (2 + 2 * i) * 8)
772                                             * SIZE]);
773                         } else {
774                             prefetcht0(ptr[AO1 + LDA4]);
775                         }
776                     }
777                 }
778                 if (i == 7) {
779                     if (!isTransB) {
780                         if (unroll_n >= 4) {
781                             sub(BO2, -8 * SIZE);
782                         }
783                     }
784                     if (!isTransA) {
785                         prefetcht2(ptr[AA]);
786                         lea(AA, ptr[AA + LDA]);
787                     }
788                 }
789
790                 if (!isDirect) {
791                     if (isLoad1Unmasked) {
792                         vmovups(ymm0,
793                                 ptr[AO1
794                                         + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
795                                                 * SIZE]);
796                     } else {
797                         vmaskmovps(
798                                 ymm0, VMASK,
799                                 ptr[AO1
800                                         + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
801                                                 * SIZE]);
802                     }
803                     if (unroll_m >= 16) {
804                         if (isLoad2Unmasked) {
805                             vmovups(ymm1, ptr[AO1
806                                                   + (unroll_m * (i + 1) + 1 * 8
807                                                             - OFFSET)
808                                                           * SIZE]);
809                         } else {
810                             vmaskmovps(ymm1, VMASK,
811                                     ptr[AO1
812                                                + (unroll_m * (i + 1) + 1 * 8
813                                                          - OFFSET)
814                                                        * SIZE]);
815                         }
816                     }
817                 }
818             }
819
820             if (!isDirect) {
821                 sub(AO1, -unroll_m * 8 * SIZE);
822             }
823             sub(LL, 1);
824
825         };
826
827         // Inner kernel with k=4
828         auto innerkernel4 = [&](int unroll_m, int unroll_n,
829                 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
830                 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
831                 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
832                 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
833                 Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
834                 Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
835                 Ymm reg23) {
836
837             Ymm fmareg;
838
839             if (!isDirect) {
840                 prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]);
841             } else {
842                 prefetcht0(ptr[AO1 + LDA4]);
843             }
844
845             for (int i = 0; i < 4; i++) {
846                 if (isDirect) {
847                     if (isLoad1Unmasked) {
848                         vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
849                     } else {
850                         vmaskmovps(ymm0, VMASK,
851                                 ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
852                     }
853                     if (unroll_m >= 16) {
854                         if (isLoad2Unmasked) {
855                             vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
856                         } else {
857                             vmaskmovps(ymm1, VMASK,
858                                     ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
859                         }
860                     }
861                     add(AO1, LDA);
862                 }
863
864                 if (!isTransB) {
865                     vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]);
866                 } else {
867                     vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
868                 }
869                 fmareg = (i % 2 == 0) ? reg00 : reg12;
870                 fma(useFma, ymm0, ymm2, fmareg);
871                 if (unroll_m >= 16) {
872                     fmareg = (i % 2 == 0) ? reg06 : reg18;
873                     fma(useFma, ymm1, ymm2, fmareg);
874                 }
875                 if (i == 0) {
876                     if (!isTransB) {
877                         prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]);
878                     }
879                 }
880                 if (unroll_n >= 2) {
881                     if (!isTransB) {
882                         if (i == 1) {
883                             prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]);
884                         }
885                         vbroadcastss(
886                                 ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
887                     } else {
888                         vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
889                     }
890                     fmareg = (i % 2 == 0) ? reg01 : reg13;
891                     fma(useFma, ymm0, ymm2, fmareg);
892                     if (unroll_m >= 16) {
893                         fmareg = (i % 2 == 0) ? reg07 : reg19;
894                         fma(useFma, ymm1, ymm2, fmareg);
895                     }
896                 }
897
898                 if (isCopy) {
899                     vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
900                             ymm0);
901                     if (unroll_m >= 16) {
902                         vmovups(ptr[LDA4
903                                         + (unroll_m * i + 1 * 8 - OFFSET)
904                                                 * SIZE],
905                                 ymm1);
906                     }
907                     if (i == 3) {
908                         sub(LDA4, -unroll_m * 4 * SIZE);
909                     }
910                 }
911
912                 if (unroll_n >= 3) {
913                     if (!isTransB) {
914                         if (i == 2) {
915                             prefetcht0(
916                                     ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
917                         }
918                         vbroadcastss(
919                                 ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]);
920                     } else {
921                         vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
922                     }
923                     fmareg = (i % 2 == 0) ? reg02 : reg14;
924                     fma(useFma, ymm0, ymm2, fmareg);
925                     if (unroll_m >= 16) {
926                         fmareg = (i % 2 == 0) ? reg08 : reg20;
927                         fma(useFma, ymm1, ymm2, fmareg);
928                     }
929                 }
930
931                 if (i == 7) {
932                     if (!isTransB) {
933                         sub(BO1, -8 * SIZE);
934                     }
935                 }
936
937                 if (unroll_n >= 4) {
938                     if (!isTransB) {
939                         if (i == 3) {
940                             prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]);
941                         }
942                         vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]);
943                     } else {
944                         vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
945                     }
946                     fmareg = (i % 2 == 0) ? reg03 : reg15;
947                     fma(useFma, ymm0, ymm2, fmareg);
948                     if (unroll_m >= 16) {
949                         fmareg = (i % 2 == 0) ? reg09 : reg21;
950                         fma(useFma, ymm1, ymm2, fmareg);
951                     }
952                 }
953
954                 if (unroll_n >= 5) {
955                     if (!isTransB) {
956                         if (i == 4) {
957                             prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]);
958                         }
959                         vbroadcastss(
960                                 ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]);
961                     } else {
962                         vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
963                     }
964                     fmareg = (i % 2 == 0) ? reg04 : reg16;
965                     fma(useFma, ymm0, ymm2, fmareg);
966                     if (unroll_m >= 16) {
967                         fmareg = (i % 2 == 0) ? reg10 : reg22;
968                         fma(useFma, ymm1, ymm2, fmareg);
969                     }
970                 }
971
972                 if (unroll_n >= 6) {
973                     if (!isTransB) {
974                         if (i == 5) {
975                             prefetcht0(
976                                     ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]);
977                         }
978                         vbroadcastss(
979                                 ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]);
980                     } else {
981                         vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
982                     }
983                     fmareg = (i % 2 == 0) ? reg05 : reg17;
984                     fma(useFma, ymm0, ymm2, fmareg);
985                     if (unroll_m >= 16) {
986                         fmareg = (i % 2 == 0) ? reg11 : reg23;
987                         fma(useFma, ymm1, ymm2, fmareg);
988                     }
989                 }
990                 if (isTransB) {
991                     prefetcht0(ptr[BO1 + BO2]);
992                     add(BO1, LDB);
993                 }
994
995                 if (i == 0) {
996                     if (unroll_m >= 4) {
997                         if (!isDirect) {
998                             prefetcht0(
999                                     ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]);
1000                         } else {
1001                             prefetcht0(ptr[AO1 + LDA4]);
1002                         }
1003                     }
1004                 }
1005                 if (i == 1 || i == 2) {
1006                     if (unroll_m >= 8) {
1007                         if (!isDirect) {
1008                             prefetcht0(ptr[AO1
1009                                     + (PREFETCHSIZEA + (2 + 2 * i) * 8)
1010                                             * SIZE]);
1011                         } else {
1012                             prefetcht0(ptr[AO1 + LDA4]);
1013                         }
1014                     }
1015                 }
1016                 if (i == 3) {
1017                     if (!isTransB) {
1018                         sub(BO1, -4 * SIZE);
1019                         if (unroll_n >= 4) {
1020                             sub(BO2, -4 * SIZE);
1021                         }
1022                     }
1023                 }
1024
1025                 if (!isDirect) {
1026                     if (isLoad1Unmasked) {
1027                         vmovups(ymm0,
1028                                 ptr[AO1
1029                                         + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
1030                                                 * SIZE]);
1031                     } else {
1032                         vmaskmovps(
1033                                 ymm0, VMASK,
1034                                 ptr[AO1
1035                                         + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
1036                                                 * SIZE]);
1037                     }
1038                     if (unroll_m >= 16) {
1039                         if (isLoad2Unmasked) {
1040                             vmovups(ymm1, ptr[AO1
1041                                                   + (unroll_m * (i + 1) + 1 * 8
1042                                                             - OFFSET)
1043                                                           * SIZE]);
1044                         } else {
1045                             vmaskmovps(ymm1, VMASK,
1046                                     ptr[AO1
1047                                                + (unroll_m * (i + 1) + 1 * 8
1048                                                          - OFFSET)
1049                                                        * SIZE]);
1050                         }
1051                     }
1052                 }
1053             }
1054
1055             if (!isDirect) {
1056                 sub(AO1, -unroll_m * 4 * SIZE);
1057             }
1058
1059         };
1060
1061         // Inner kernel with k=2
1062         auto innerkernel2 = [&](int unroll_m, int unroll_n,
1063                 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
1064                 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
1065                 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
1066                 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
1067                 Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
1068                 Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
1069                 Ymm reg23) {
1070
1071             Ymm fmareg;
1072
1073             for (int i = 0; i < 2; i++) {
1074                 if (isDirect) {
1075                     if (isLoad1Unmasked) {
1076                         vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1077                     } else {
1078                         vmaskmovps(ymm0, VMASK,
1079                                 ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1080                     }
1081                     if (unroll_m >= 16) {
1082                         if (isLoad2Unmasked) {
1083                             vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1084                         } else {
1085                             vmaskmovps(ymm1, VMASK,
1086                                     ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1087                         }
1088                     }
1089                     add(AO1, LDA);
1090                 }
1091
1092                 if (!isTransB) {
1093                     vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1094                 } else {
1095                     vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1096                 }
1097                 fmareg = (i % 2 == 0) ? reg00 : reg12;
1098                 fma(useFma, ymm0, ymm2, fmareg);
1099                 if (unroll_m >= 16) {
1100                     fmareg = (i % 2 == 0) ? reg06 : reg18;
1101                     fma(useFma, ymm1, ymm2, fmareg);
1102                 }
1103                 if (unroll_n >= 2) {
1104                     if (!isTransB) {
1105                         vbroadcastss(
1106                                 ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]);
1107                     } else {
1108                         vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
1109                     }
1110                     fmareg = (i % 2 == 0) ? reg01 : reg13;
1111                     fma(useFma, ymm0, ymm2, fmareg);
1112                     if (unroll_m >= 16) {
1113                         fmareg = (i % 2 == 0) ? reg07 : reg19;
1114                         fma(useFma, ymm1, ymm2, fmareg);
1115                     }
1116                 }
1117
1118                 if (unroll_n >= 3) {
1119                     if (!isTransB) {
1120                         if (i == 2) {
1121                             prefetcht0(
1122                                     ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
1123                         }
1124                         vbroadcastss(
1125                                 ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]);
1126                     } else {
1127                         vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
1128                     }
1129                     fmareg = (i % 2 == 0) ? reg02 : reg14;
1130                     fma(useFma, ymm0, ymm2, fmareg);
1131                     if (unroll_m >= 16) {
1132                         fmareg = (i % 2 == 0) ? reg08 : reg20;
1133                         fma(useFma, ymm1, ymm2, fmareg);
1134                     }
1135                 }
1136
1137                 if (unroll_n >= 4) {
1138                     if (!isTransB) {
1139                         vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]);
1140                     } else {
1141                         vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
1142                     }
1143                     fmareg = (i % 2 == 0) ? reg03 : reg15;
1144                     fma(useFma, ymm0, ymm2, fmareg);
1145                     if (unroll_m >= 16) {
1146                         fmareg = (i % 2 == 0) ? reg09 : reg21;
1147                         fma(useFma, ymm1, ymm2, fmareg);
1148                     }
1149                 }
1150
1151                 if (unroll_n >= 5) {
1152                     if (!isTransB) {
1153                         vbroadcastss(
1154                                 ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]);
1155                     } else {
1156                         vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
1157                     }
1158                     fmareg = (i % 2 == 0) ? reg04 : reg16;
1159                     fma(useFma, ymm0, ymm2, fmareg);
1160                     if (unroll_m >= 16) {
1161                         fmareg = (i % 2 == 0) ? reg10 : reg22;
1162                         fma(useFma, ymm1, ymm2, fmareg);
1163                     }
1164                 }
1165
1166                 if (unroll_n >= 6) {
1167                     if (!isTransB) {
1168                         vbroadcastss(
1169                                 ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]);
1170                     } else {
1171                         vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
1172                     }
1173                     fmareg = (i % 2 == 0) ? reg05 : reg17;
1174                     fma(useFma, ymm0, ymm2, fmareg);
1175                     if (unroll_m >= 16) {
1176                         fmareg = (i % 2 == 0) ? reg11 : reg23;
1177                         fma(useFma, ymm1, ymm2, fmareg);
1178                     }
1179                 }
1180
1181                 if (isCopy) {
1182                     vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
1183                             ymm0);
1184                     if (unroll_m >= 16) {
1185                         vmovups(ptr[LDA4
1186                                         + (unroll_m * 0 + 1 * 8 - OFFSET)
1187                                                 * SIZE],
1188                                 ymm1);
1189                     }
1190                     sub(LDA4, -unroll_m * SIZE);
1191                 }
1192
1193                 if (!isDirect) {
1194                     if (isLoad1Unmasked) {
1195                         vmovups(ymm0, ptr[AO1
1196                                               + (unroll_m * 1 + 0 * 8 - OFFSET)
1197                                                       * SIZE]);
1198                     } else {
1199                         vmaskmovps(ymm0, VMASK,
1200                                 ptr[AO1
1201                                            + (unroll_m * 1 + 0 * 8 - OFFSET)
1202                                                    * SIZE]);
1203                     }
1204                     if (unroll_m >= 16) {
1205                         if (isLoad2Unmasked) {
1206                             vmovups(ymm1,
1207                                     ptr[AO1
1208                                             + (unroll_m * 1 + 1 * 8 - OFFSET)
1209                                                     * SIZE]);
1210                         } else {
1211                             vmaskmovps(ymm1, VMASK,
1212                                     ptr[AO1
1213                                                + (unroll_m * 1 + 1 * 8 - OFFSET)
1214                                                        * SIZE]);
1215                         }
1216                     }
1217                     sub(AO1, -unroll_m * SIZE);
1218                 }
1219
1220                 if (!isTransB) {
1221                     sub(BO1, -SIZE);
1222                     if (unroll_n >= 4) {
1223                         sub(BO2, -SIZE);
1224                     }
1225                 } else {
1226                     add(BO1, LDB);
1227                 }
1228             }
1229
1230         };
1231
1232         // Inner kernel with k=1
1233         auto innerkernel1 = [&](int unroll_m, int unroll_n,
1234                 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
1235                 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
1236                 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
1237                 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11) {
1238
1239             if (isDirect) {
1240                 if (isLoad1Unmasked) {
1241                     vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1242                 } else {
1243                     vmaskmovps(ymm0, VMASK, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1244                 }
1245                 if (unroll_m >= 16) {
1246                     if (isLoad2Unmasked) {
1247                         vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1248                     } else {
1249                         vmaskmovps(ymm1, VMASK,
1250                                 ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1251                     }
1252                 }
1253                 add(AO1, LDA);
1254             }
1255
1256             if (!isTransB) {
1257                 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1258             } else {
1259                 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1260             }
1261             fma(useFma, ymm0, ymm2, reg00);
1262             if (unroll_m >= 16) {
1263                 fma(useFma, ymm1, ymm2, reg06);
1264             }
1265
1266             if (unroll_n >= 2) {
1267                 if (!isTransB) {
1268                     vbroadcastss(
1269                             ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]);
1270                 } else {
1271                     vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
1272                 }
1273                 fma(useFma, ymm0, ymm2, reg01);
1274                 if (unroll_m >= 16) {
1275                     fma(useFma, ymm1, ymm2, reg07);
1276                 }
1277             }
1278
1279             if (unroll_n >= 3) {
1280                 if (!isTransB) {
1281                     vbroadcastss(
1282                             ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]);
1283                 } else {
1284                     vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
1285                 }
1286                 fma(useFma, ymm0, ymm2, reg02);
1287                 if (unroll_m >= 16) {
1288                     fma(useFma, ymm1, ymm2, reg08);
1289                 }
1290             }
1291
1292             if (unroll_n >= 4) {
1293                 if (!isTransB) {
1294                     vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]);
1295                 } else {
1296                     vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
1297                 }
1298                 fma(useFma, ymm0, ymm2, reg03);
1299                 if (unroll_m >= 16) {
1300                     fma(useFma, ymm1, ymm2, reg09);
1301                 }
1302             }
1303
1304             if (unroll_n >= 5) {
1305                 if (!isTransB) {
1306                     vbroadcastss(
1307                             ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]);
1308                 } else {
1309                     vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
1310                 }
1311                 fma(useFma, ymm0, ymm2, reg04);
1312                 if (unroll_m >= 16) {
1313                     fma(useFma, ymm1, ymm2, reg10);
1314                 }
1315             }
1316
1317             if (unroll_n >= 6) {
1318                 if (!isTransB) {
1319                     vbroadcastss(
1320                             ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]);
1321                 } else {
1322                     vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
1323                 }
1324                 fma(useFma, ymm0, ymm2, reg05);
1325                 if (unroll_m >= 16) {
1326                     fma(useFma, ymm1, ymm2, reg11);
1327                 }
1328             }
1329
1330             if (isCopy) {
1331                 vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
1332                         ymm0);
1333                 if (unroll_m >= 16) {
1334                     vmovups(ptr[LDA4 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE],
1335                             ymm1);
1336                 }
1337                 sub(LDA4, -unroll_m * SIZE);
1338             }
1339
1340             if (!isDirect) {
1341                 if (isLoad1Unmasked) {
1342                     vmovups(ymm0,
1343                             ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]);
1344                 } else {
1345                     vmaskmovps(ymm0, VMASK,
1346                             ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]);
1347                 }
1348                 if (unroll_m >= 16) {
1349                     if (isLoad2Unmasked) {
1350                         vmovups(ymm1, ptr[AO1
1351                                               + (unroll_m * 1 + 1 * 8 - OFFSET)
1352                                                       * SIZE]);
1353                     } else {
1354                         vmaskmovps(ymm1, VMASK,
1355                                 ptr[AO1
1356                                            + (unroll_m * 1 + 1 * 8 - OFFSET)
1357                                                    * SIZE]);
1358                     }
1359                 }
1360                 sub(AO1, -unroll_m * SIZE);
1361             }
1362
1363             if (!isTransB) {
1364                 sub(BO1, -SIZE);
1365                 if (unroll_n >= 4) {
1366                     sub(BO2, -SIZE);
1367                 }
1368             } else {
1369                 add(BO1, LDB);
1370             }
1371
1372         };
1373
1374         // Main kernel; does prefetching and calls innerkernel{1,2,4,8} as
1375         // appropriate
1376         // After calculating results in registers, writes back to C matrix
1377         auto kernel = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1378                 bool isLoad2Unmasked, bool isDirect, bool isCopy, bool useFma,
1379                 Ymm reg00 = Ymm(4), Ymm reg01 = Ymm(5), Ymm reg02 = Ymm(6),
1380                 Ymm reg03 = Ymm(7), Ymm reg04 = Ymm(8), Ymm reg05 = Ymm(9),
1381                 Ymm reg06 = Ymm(10), Ymm reg07 = Ymm(11), Ymm reg08 = Ymm(12),
1382                 Ymm reg09 = Ymm(13), Ymm reg10 = Ymm(14), Ymm reg11 = Ymm(15),
1383                 Ymm reg12 = Ymm(4), Ymm reg13 = Ymm(5), Ymm reg14 = Ymm(6),
1384                 Ymm reg15 = Ymm(7), Ymm reg16 = Ymm(8), Ymm reg17 = Ymm(9),
1385                 Ymm reg18 = Ymm(10), Ymm reg19 = Ymm(11), Ymm reg20 = Ymm(12),
1386                 Ymm reg21 = Ymm(13), Ymm reg22 = Ymm(14), Ymm reg23 = Ymm(15)) {
1387             if (!isDirect) {
1388                 lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]);
1389             } else {
1390                 mov(AO1, A);
1391             }
1392
1393             if (isCopy) {
1394                 lea(LDA4, ptr[rsp + 256 + OFFSET * SIZE]);
1395             } else {
1396                 lea(LDA4, ptr[LDA * 8 + (8 - 1 - OFFSET) * SIZE]);
1397             }
1398
1399             if (isTransB) {
1400                 lea(BO2, ptr[LDB * 4 + (8 - 1 - OFFSET) * SIZE]);
1401                 lea(BO2, ptr[BO2 + LDB * 2]);
1402             }
1403
1404             if (!isDirect) {
1405                 if (isLoad1Unmasked) {
1406                     vmovups(ymm0,
1407                             ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]);
1408                 } else {
1409                     vmaskmovps(ymm0, VMASK,
1410                             ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]);
1411                 }
1412                 if (unroll_m >= 16) {
1413                     if (isLoad2Unmasked) {
1414                         vmovups(ymm1, ptr[AO1
1415                                               + (unroll_m * 0 + 1 * 8 - OFFSET)
1416                                                       * SIZE]);
1417                     } else {
1418                         vmaskmovps(ymm1, VMASK,
1419                                 ptr[AO1
1420                                            + (unroll_m * 0 + 1 * 8 - OFFSET)
1421                                                    * SIZE]);
1422                     }
1423                 }
1424             }
1425
1426             for (int i = 4; i < 10; i++) {
1427                 vxorps(Ymm(i), Ymm(i), Ymm(i));
1428                 vxorps(Ymm(i + 6), Ymm(i + 6), Ymm(i + 6));
1429             }
1430
1431             mov(LL, K);
1432             sar(LL, 3);
1433
1434             Label kernel12, kernel13, kernel14, kernel15;
1435             Label kernel16, kernel17, kernel18;
1436
1437             sub(LL, SECOND_FETCH);
1438             jle(kernel13, T_NEAR);
1439             align(16);
1440
1441             L(kernel12);
1442             innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1443                     isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1444                     reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1445                     reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1446                     reg21, reg22, reg23);
1447             jg(kernel12, T_NEAR);
1448             align(16);
1449
1450             L(kernel13);
1451             prefetcht0(ptr[CO1 + (unroll_m - 1) * SIZE]);
1452             if (unroll_n >= 2)
1453                 prefetcht0(ptr[CO1 + LDC + (unroll_m - 1) * SIZE]);
1454             if (unroll_n >= 3)
1455                 prefetcht0(ptr[CO1 + LDC * 2 + (unroll_m - 1) * SIZE]);
1456             if (unroll_n >= 4)
1457                 prefetcht0(ptr[CO2 + (unroll_m - 1) * SIZE]);
1458             if (unroll_n >= 5)
1459                 prefetcht0(ptr[CO2 + LDC + (unroll_m - 1) * SIZE]);
1460             if (unroll_n >= 6)
1461                 prefetcht0(ptr[CO2 + LDC * 2 + (unroll_m - 1) * SIZE]);
1462
1463             add(LL, SECOND_FETCH);
1464             jle(kernel15, T_NEAR);
1465             align(16);
1466
1467             L(kernel14);
1468             innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1469                     isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1470                     reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1471                     reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1472                     reg21, reg22, reg23);
1473             jg(kernel14, T_NEAR);
1474             align(16);
1475
1476             L(kernel15);
1477             test(K, 4);
1478             jle(kernel16, T_NEAR);
1479             innerkernel4(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1480                     isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1481                     reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1482                     reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1483                     reg21, reg22, reg23);
1484
1485             L(kernel16);
1486             test(K, 2);
1487             jle(kernel17, T_NEAR);
1488             innerkernel2(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1489                     isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1490                     reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1491                     reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1492                     reg21, reg22, reg23);
1493             align(16);
1494
1495             L(kernel17);
1496             if (unroll_m == 16) {
1497                 if (unroll_n <= 3) {
1498                     vaddps(reg00, reg00, reg12);
1499                     vaddps(reg01, reg01, reg13);
1500                     vaddps(reg02, reg02, reg14);
1501                     vaddps(reg06, reg06, reg18);
1502                     vaddps(reg07, reg07, reg19);
1503                     vaddps(reg08, reg08, reg20);
1504                 }
1505             }
1506
1507             if (unroll_m <= 8) {
1508                 vaddps(reg00, reg00, reg12);
1509                 vaddps(reg01, reg01, reg13);
1510                 vaddps(reg02, reg02, reg14);
1511                 vaddps(reg03, reg03, reg15);
1512                 vaddps(reg04, reg04, reg16);
1513                 vaddps(reg05, reg05, reg17);
1514             }
1515
1516             test(K, 1);
1517             jle(kernel18, T_NEAR);
1518             innerkernel1(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1519                     isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1520                     reg05, reg06, reg07, reg08, reg09, reg10, reg11);
1521             align(16);
1522
1523             L(kernel18);
1524             vbroadcastss(VALPHA, ALPHA);
1525
1526             if (isBetaN) {
1527                 vbroadcastss(VBETA, BETA);
1528             }
1529
1530             // Write back the results; all beta and bias cases need to be
1531             // handled
1532             switch (unroll_n) {
1533             case 1: mov(rax, LDC); break;
1534             case 2: lea(rax, ptr[LDC * 2]); break;
1535             case 3: lea(rax, ptr[LDC + LDC * 2]); break;
1536             case 4: lea(rax, ptr[LDC + LDC * 4]); break;
1537             case 5:
1538                 lea(rax, ptr[LDC * 4]);
1539                 add(rax, LDC);
1540                 break;
1541             case 6:
1542                 lea(rax, ptr[LDC + LDC * 2]);
1543                 add(rax, rax);
1544                 break;
1545             }
1546
1547             if (hasBias) {
1548                 mov(BIAS1, BIAS);
1549                 if (isLoad1Unmasked) {
1550                     vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]);
1551                 } else {
1552                     vmaskmovps(VBIAS1, VMASK, ptr[BIAS1 + 0 * SIZE]);
1553                 }
1554             }
1555
1556             for (int i = 0; i < unroll_n; i++) {
1557                 vmulps(Ymm(i + 4), Ymm(i + 4), VALPHA);
1558                 if (!isBeta0) {
1559                     if (isLoad1Unmasked) {
1560                         switch (i) {
1561                         case 0: vmovups(ymm0, ptr[CO1 + 0 * SIZE]); break;
1562                         case 1: vmovups(ymm0, ptr[CO1 + LDC + 0 * SIZE]); break;
1563                         case 2:
1564                             vmovups(ymm0, ptr[CO1 + LDC * 2 + 0 * SIZE]);
1565                             break;
1566                         case 3: vmovups(ymm0, ptr[CO2 + 0 * SIZE]); break;
1567                         case 4: vmovups(ymm0, ptr[CO2 + LDC + 0 * SIZE]); break;
1568                         case 5:
1569                             vmovups(ymm0, ptr[CO2 + LDC * 2 + 0 * SIZE]);
1570                             break;
1571                         }
1572                     } else {
1573                         switch (i) {
1574                         case 0:
1575                             vmaskmovps(ymm0, VMASK, ptr[CO1 + 0 * SIZE]);
1576                             break;
1577                         case 1:
1578                             vmaskmovps(ymm0, VMASK, ptr[CO1 + LDC + 0 * SIZE]);
1579                             break;
1580                         case 2:
1581                             vmaskmovps(
1582                                     ymm0, VMASK, ptr[CO1 + LDC * 2 + 0 * SIZE]);
1583                             break;
1584                         case 3:
1585                             vmaskmovps(ymm0, VMASK, ptr[CO2 + 0 * SIZE]);
1586                             break;
1587                         case 4:
1588                             vmaskmovps(ymm0, VMASK, ptr[CO2 + LDC + 0 * SIZE]);
1589                             break;
1590                         case 5:
1591                             vmaskmovps(
1592                                     ymm0, VMASK, ptr[CO2 + LDC * 2 + 0 * SIZE]);
1593                             break;
1594                         }
1595                     }
1596
1597                     if (!isBetaN) {
1598                         vaddps(Ymm(i + 4), ymm0, Ymm(i + 4));
1599                     } else {
1600                         fma(useFma, VBETA, ymm0, Ymm(i + 4), true);
1601                     }
1602                 }
1603                 if (hasBias) {
1604                     vaddps(Ymm(i + 4), VBIAS1, Ymm(i + 4));
1605                 }
1606                 if (isLoad1Unmasked) {
1607                     switch (i) {
1608                     case 0: vmovups(ptr[CO1 + 0 * SIZE], Ymm(i + 4)); break;
1609                     case 1:
1610                         vmovups(ptr[CO1 + LDC + 0 * SIZE], Ymm(i + 4));
1611                         break;
1612                     case 2:
1613                         vmovups(ptr[CO1 + LDC * 2 + 0 * SIZE], Ymm(i + 4));
1614                         break;
1615                     case 3: vmovups(ptr[CO2 + 0 * SIZE], Ymm(i + 4)); break;
1616                     case 4:
1617                         vmovups(ptr[CO2 + LDC + 0 * SIZE], Ymm(i + 4));
1618                         break;
1619                     case 5:
1620                         vmovups(ptr[CO2 + LDC * 2 + 0 * SIZE], Ymm(i + 4));
1621                         break;
1622                     }
1623                 } else {
1624                     switch (i) {
1625                     case 0:
1626                         vmaskmovps(ptr[CO1 + 0 * SIZE], VMASK, Ymm(i + 4));
1627                         break;
1628                     case 1:
1629                         vmaskmovps(
1630                                 ptr[CO1 + LDC + 0 * SIZE], VMASK, Ymm(i + 4));
1631                         break;
1632                     case 2:
1633                         vmaskmovps(ptr[CO1 + LDC * 2 + 0 * SIZE], VMASK,
1634                                 Ymm(i + 4));
1635                         break;
1636                     case 3:
1637                         vmaskmovps(ptr[CO2 + 0 * SIZE], VMASK, Ymm(i + 4));
1638                         break;
1639                     case 4:
1640                         vmaskmovps(
1641                                 ptr[CO2 + LDC + 0 * SIZE], VMASK, Ymm(i + 4));
1642                         break;
1643                     case 5:
1644                         vmaskmovps(ptr[CO2 + LDC * 2 + 0 * SIZE], VMASK,
1645                                 Ymm(i + 4));
1646                         break;
1647                     }
1648                 }
1649
1650                 if (unroll_m >= 16) {
1651                     // Re-use ymm4 (VBIAS2)
1652                     if (i == 0) {
1653                         if (hasBias) {
1654                             if (isLoad1Unmasked) {
1655                                 vmovups(VBIAS2, ptr[BIAS1 + 8 * SIZE]);
1656                             } else {
1657                                 vmaskmovps(
1658                                         VBIAS2, VMASK, ptr[BIAS1 + 8 * SIZE]);
1659                             }
1660                         }
1661                     }
1662                     vmulps(Ymm(i + 10), Ymm(i + 10), VALPHA);
1663                     if (!isBeta0) {
1664                         if (isLoad2Unmasked) {
1665                             switch (i) {
1666                             case 0: vmovups(ymm0, ptr[CO1 + 8 * SIZE]); break;
1667                             case 1:
1668                                 vmovups(ymm0, ptr[CO1 + LDC + 8 * SIZE]);
1669                                 break;
1670                             case 2:
1671                                 vmovups(ymm0, ptr[CO1 + LDC * 2 + 8 * SIZE]);
1672                                 break;
1673                             case 3: vmovups(ymm0, ptr[CO2 + 8 * SIZE]); break;
1674                             case 4:
1675                                 vmovups(ymm0, ptr[CO2 + LDC + 8 * SIZE]);
1676                                 break;
1677                             case 5:
1678                                 vmovups(ymm0, ptr[CO2 + LDC * 2 + 8 * SIZE]);
1679                                 break;
1680                             }
1681                         } else {
1682                             switch (i) {
1683                             case 0:
1684                                 vmaskmovps(ymm0, VMASK, ptr[CO1 + 8 * SIZE]);
1685                                 break;
1686                             case 1:
1687                                 vmaskmovps(
1688                                         ymm0, VMASK, ptr[CO1 + LDC + 8 * SIZE]);
1689                                 break;
1690                             case 2:
1691                                 vmaskmovps(ymm0, VMASK,
1692                                         ptr[CO1 + LDC * 2 + 8 * SIZE]);
1693                                 break;
1694                             case 3:
1695                                 vmaskmovps(ymm0, VMASK, ptr[CO2 + 8 * SIZE]);
1696                                 break;
1697                             case 4:
1698                                 vmaskmovps(
1699                                         ymm0, VMASK, ptr[CO2 + LDC + 8 * SIZE]);
1700                                 break;
1701                             case 5:
1702                                 vmaskmovps(ymm0, VMASK,
1703                                         ptr[CO2 + LDC * 2 + 8 * SIZE]);
1704                                 break;
1705                             }
1706                         }
1707                         if (!isBetaN) {
1708                             vaddps(Ymm(i + 10), ymm0, Ymm(i + 10));
1709                         } else {
1710                             fma(useFma, VBETA, ymm0, Ymm(i + 10), true);
1711                         }
1712                     }
1713                     if (hasBias) {
1714                         vaddps(Ymm(i + 10), VBIAS2, Ymm(i + 10));
1715                     }
1716                     if (isLoad2Unmasked) {
1717                         switch (i) {
1718                         case 0:
1719                             vmovups(ptr[CO1 + 8 * SIZE], Ymm(i + 10));
1720                             break;
1721                         case 1:
1722                             vmovups(ptr[CO1 + LDC + 8 * SIZE], Ymm(i + 10));
1723                             break;
1724                         case 2:
1725                             vmovups(ptr[CO1 + LDC * 2 + 8 * SIZE], Ymm(i + 10));
1726                             break;
1727                         case 3:
1728                             vmovups(ptr[CO2 + 8 * SIZE], Ymm(i + 10));
1729                             break;
1730                         case 4:
1731                             vmovups(ptr[CO2 + LDC + 8 * SIZE], Ymm(i + 10));
1732                             break;
1733                         case 5:
1734                             vmovups(ptr[CO2 + LDC * 2 + 8 * SIZE], Ymm(i + 10));
1735                             break;
1736                         }
1737                     } else {
1738                         switch (i) {
1739                         case 0:
1740                             vmaskmovps(ptr[CO1 + 8 * SIZE], VMASK, Ymm(i + 10));
1741                             break;
1742                         case 1:
1743                             vmaskmovps(ptr[CO1 + LDC + 8 * SIZE], VMASK,
1744                                     Ymm(i + 10));
1745                             break;
1746                         case 2:
1747                             vmaskmovps(ptr[CO1 + LDC * 2 + 8 * SIZE], VMASK,
1748                                     Ymm(i + 10));
1749                             break;
1750                         case 3:
1751                             vmaskmovps(ptr[CO2 + 8 * SIZE], VMASK, Ymm(i + 10));
1752                             break;
1753                         case 4:
1754                             vmaskmovps(ptr[CO2 + LDC + 8 * SIZE], VMASK,
1755                                     Ymm(i + 10));
1756                             break;
1757                         case 5:
1758                             vmaskmovps(ptr[CO2 + LDC * 2 + 8 * SIZE], VMASK,
1759                                     Ymm(i + 10));
1760                             break;
1761                         }
1762                     }
1763                 }
1764                 if (i == 2)
1765                     add(CO1, rax);
1766             }
1767             if (unroll_n >= 4) {
1768                 add(CO2, rax);
1769             }
1770
1771             // Compute next address of B
1772             if (!isTransB) {
1773                 lea(rax, ptr[K * SIZE]);
1774                 switch (unroll_n) {
1775                 case 1:
1776                     add(BO1, LDB);
1777                     add(BO2, LDB);
1778                     break;
1779                 case 2:
1780                     lea(BO1, ptr[BO1 + LDB * 2]);
1781                     lea(BO2, ptr[BO2 + LDB * 2]);
1782                     break;
1783                 case 3:
1784                     lea(BO1, ptr[BO1 + LDB3]);
1785                     lea(BO2, ptr[BO2 + LDB3]);
1786                     break;
1787                 case 4:
1788                     lea(BO1, ptr[BO1 + LDB * 4]);
1789                     lea(BO2, ptr[BO2 + LDB * 4]);
1790                     break;
1791                 case 5:
1792                     lea(BO1, ptr[BO1 + LDB * 4]);
1793                     add(BO1, LDB);
1794                     lea(BO2, ptr[BO2 + LDB * 4]);
1795                     add(BO2, LDB);
1796                     break;
1797                 case 6:
1798                     lea(BO1, ptr[BO1 + LDB3 * 2]);
1799                     lea(BO2, ptr[BO2 + LDB3 * 2]);
1800                     break;
1801                 }
1802                 sub(BO1, rax);
1803                 sub(BO2, rax);
1804             } else {
1805                 mov(rax, LDB);
1806                 imul(rax, K);
1807                 sub(BO1, rax);
1808                 add(BO1, unroll_n * SIZE);
1809             }
1810         };
1811
1812         auto kernel_16x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1813                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1814             kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1815                     isDirect, isCopy, true);
1816         };
1817
1818         auto kernel_16x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1819                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1820             kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1821                     isDirect, isCopy, true);
1822         };
1823
1824         auto kernel_16x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1825                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1826             kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1827                     isDirect, isCopy, true);
1828         };
1829
1830         auto kernel_16x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1831                 bool isLoad2Unmasked, bool isDirect, bool isCopy,
1832                 bool useFma = true) {
1833             kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1834                     isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
1835                     Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1836                     Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9),
1837                     Ymm(13), Ymm(14), Ymm(15));
1838         };
1839
1840         auto kernel_16x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1841                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1842             kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1843                     isDirect, isCopy, false);
1844         };
1845
1846         auto kernel_16x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1847                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1848             kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1849                     isDirect, isCopy, false);
1850         };
1851
1852         auto kernel_8x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1853                 bool isLoad2Unmasked, bool isDirect, bool isCopy,
1854                 bool useFma = true) {
1855             kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1856                     isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
1857                     Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1858                     Ymm(15), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1859                     Ymm(15));
1860         };
1861
1862         auto kernel_8x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1863                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1864             kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1865                     isDirect, isCopy);
1866         };
1867
1868         auto kernel_8x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1869                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1870             kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1871                     isDirect, isCopy);
1872         };
1873
1874         auto kernel_8x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1875                 bool isLoad2Unmasked, bool isDirect, bool isCopy,
1876                 bool useFma = true) {
1877             kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1878                     isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
1879                     Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1880                     Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9),
1881                     Ymm(13), Ymm(14), Ymm(15));
1882         };
1883
1884         auto kernel_8x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1885                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1886             kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1887                     isDirect, isCopy, false);
1888         };
1889
1890         auto kernel_8x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1891                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1892             kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1893                     isDirect, isCopy, false);
1894         };
1895
1896         // High-level subroutine; does packing if needed, then splits C matrix.
1897         // Operates on chunks of 16 rows, 6 columns at a time (handling tail
1898         // cases appropriately).
1899         // Masking is used for tail cases where M is not divisible by 8.
1900         auto subloop = [&](
1901                 int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) {
1902             if (isTransA) {
1903                 do_pack(unroll_m, isLoad1Unmasked, isLoad2Unmasked);
1904             }
1905
1906             Label subloop11, subloop11mask;
1907             Label subloop20, subloop21, subloop22, subloop23;
1908             Label subloop24, subloop25;
1909             Label subloop30, subloop31, subloop32, subloop33;
1910             Label subloop34, subloop35;
1911             Label subloop98, subloop98mask;
1912             Label subloop99, subloop99mask;
1913
1914             mov(CO1, C);
1915             lea(CO2, ptr[CO1 + LDC * 2]);
1916             add(CO2, LDC);
1917             add(C, unroll_m * SIZE);
1918             mov(BO1, B);
1919             if (!isTransB) {
1920                 lea(BO2, qword[B + LDB3]);
1921             }
1922
1923             if (!isTransA) {
1924                 lea(AA, ptr[A + (unroll_m * 2 - 1 - OFFSET) * SIZE]);
1925                 cmp(M, UNROLL_M);
1926                 jg(subloop98, T_NEAR);
1927
1928                 mov(AA, ORIG_A);
1929                 lea(AA, ptr[AA + (unroll_m - 1 - OFFSET) * SIZE]);
1930                 L(subloop98);
1931             }
1932
1933             mov(LL, N);
1934             mov(I, LL);
1935             if (!isTransA) {
1936                 // If N is too small, skip copy operation
1937                 cmp(LL, UNROLL_N * 3);
1938                 jle(subloop30, T_NEAR);
1939
1940                 // If A is not aligned to cache line
1941                 cmp(FLAG, 0);
1942                 je(subloop30, T_NEAR);
1943             } else {
1944                 cmp(LL, UNROLL_N);
1945                 jl(subloop20, T_NEAR);
1946             }
1947             align(16);
1948
1949             if (!isTransA) {
1950                 if (unroll_m == 16) {
1951                     kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1952                             isLoad2Unmasked, true, true);
1953                 } else {
1954                     kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1955                             isLoad2Unmasked, true, true);
1956                 }
1957             } else {
1958                 if (unroll_m == 16) {
1959                     kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1960                             isLoad2Unmasked, false, false);
1961                 } else {
1962                     kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1963                             isLoad2Unmasked, false, false);
1964                 }
1965             }
1966
1967             sub(I, UNROLL_N);
1968             cmp(I, UNROLL_N);
1969             jl(subloop20, T_NEAR);
1970             align(16);
1971
1972             L(subloop11);
1973             if (unroll_m == 16) {
1974                 kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1975                         isLoad2Unmasked, false, false);
1976             } else {
1977                 kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, isLoad2Unmasked,
1978                         false, false);
1979             }
1980             sub(I, UNROLL_N);
1981             cmp(I, UNROLL_N);
1982             jge(subloop11, T_NEAR);
1983             align(16);
1984
1985             L(subloop20);
1986             cmp(I, 1);
1987             jne(subloop21, T_NEAR);
1988             if (unroll_m == 16) {
1989                 kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
1990                         false, false);
1991             } else {
1992                 kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, false,
1993                         false);
1994             }
1995             jmp(subloop99, T_NEAR);
1996             align(16);
1997
1998             L(subloop21);
1999             cmp(I, 2);
2000             jne(subloop22, T_NEAR);
2001             if (unroll_m == 16) {
2002                 kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
2003                         false, false);
2004             } else {
2005                 kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, false,
2006                         false);
2007             }
2008             jmp(subloop99, T_NEAR);
2009             align(16);
2010
2011             L(subloop22);
2012             cmp(I, 3);
2013             jne(subloop23, T_NEAR);
2014             if (unroll_m == 16) {
2015                 kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
2016                         false, false);
2017             } else {
2018                 kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, false,
2019                         false);
2020             }
2021             jmp(subloop99, T_NEAR);
2022             align(16);
2023
2024             L(subloop23);
2025             cmp(I, 4);
2026             jne(subloop24, T_NEAR);
2027             if (unroll_m == 16) {
2028                 kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
2029                         false, false);
2030             } else {
2031                 kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, false,
2032                         false);
2033             }
2034             jmp(subloop99, T_NEAR);
2035             align(16);
2036
2037             L(subloop24);
2038             cmp(I, 5);
2039             jne(subloop99, T_NEAR);
2040             if (unroll_m == 16) {
2041                 kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
2042                         false, false);
2043             } else {
2044                 kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, false,
2045                         false);
2046             }
2047             jmp(subloop99, T_NEAR);
2048             align(16);
2049
2050             if (!isTransA) {
2051                 L(subloop30);
2052                 cmp(I, UNROLL_N);
2053                 jl(subloop25, T_NEAR);
2054                 align(16);
2055
2056                 L(subloop31);
2057                 if (unroll_m == 16) {
2058                     kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
2059                             isLoad2Unmasked, true, false);
2060                 } else {
2061                     kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
2062                             isLoad2Unmasked, true, false);
2063                 }
2064                 sub(I, UNROLL_N);
2065                 cmp(I, UNROLL_N);
2066                 jge(subloop31, T_NEAR);
2067                 align(16);
2068
2069                 L(subloop25);
2070                 cmp(I, 1);
2071                 jne(subloop32, T_NEAR);
2072                 if (unroll_m == 16) {
2073                     kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
2074                             true, false);
2075                 } else {
2076                     kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
2077                             true, false);
2078                 }
2079                 jmp(subloop99, T_NEAR);
2080                 align(16);
2081
2082                 L(subloop32);
2083                 cmp(I, 2);
2084                 jne(subloop33, T_NEAR);
2085                 if (unroll_m == 16) {
2086                     kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
2087                             true, false);
2088                 } else {
2089                     kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
2090                             true, false);
2091                 }
2092                 jmp(subloop99, T_NEAR);
2093                 align(16);
2094
2095                 L(subloop33);
2096                 cmp(I, 3);
2097                 jne(subloop34, T_NEAR);
2098                 if (unroll_m == 16) {
2099                     kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
2100                             true, false);
2101                 } else {
2102                     kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
2103                             true, false);
2104                 }
2105                 jmp(subloop99, T_NEAR);
2106                 align(16);
2107
2108                 L(subloop34);
2109                 cmp(I, 4);
2110                 jne(subloop35, T_NEAR);
2111                 if (unroll_m == 16) {
2112                     kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
2113                             true, false);
2114                 } else {
2115                     kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
2116                             true, false);
2117                 }
2118                 jmp(subloop99, T_NEAR);
2119                 align(16);
2120
2121                 L(subloop35);
2122                 cmp(I, 5);
2123                 jne(subloop99, T_NEAR);
2124                 if (unroll_m == 16) {
2125                     kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
2126                             true, false);
2127                 } else {
2128                     kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
2129                             true, false);
2130                 }
2131                 align(16);
2132             }
2133
2134             L(subloop99);
2135             // Compute address for A
2136             if (!isTransA) {
2137                 add(A, unroll_m * SIZE);
2138             } else {
2139                 mov(rax, LDA);
2140                 imul(rax, rax, unroll_m);
2141                 add(A, rax);
2142             }
2143
2144             // Compute next address of BIAS
2145             if (hasBias) {
2146                 add(BIAS, unroll_m * SIZE);
2147             }
2148         };
2149
2150         preamble();
2151
2152         Label buffer_in_ws, buffer_allocated;
2153
2154         // Get the registers
2155         mov(B, ARG_B);
2156         mov(LDB, ARG_LDB);
2157         mov(r15, ARG_BETA);
2158         mov(r12, ARG_C);
2159         if (hasBias)
2160             mov(r10, ARG_BIAS);
2161         mov(LDC, ARG_LDC);
2162         mov(rbp, rsp);
2163
2164         vmovss(xmm0, ptr[ARG_ALPHA]);
2165         vmovss(xmm1, ptr[r15]);
2166
2167 #if _WIN32
2168         mov(A, ARG_A);
2169         mov(LDA, ARG_LDA);
2170 #endif
2171
2172         cmp(K, STACK_K_CAPACITY);
2173         jg(buffer_in_ws, T_NEAR);
2174
2175         // Create buffer and align to 4kB page
2176         lea(rax, ptr[K * SIZE]);
2177         sal(rax, 4);
2178         add(rax, 256);
2179         sub(rsp, rax);
2180         and_(rsp, -PAGE_4K);
2181         jmp(buffer_allocated, T_NEAR);
2182
2183         L(buffer_in_ws);
2184         mov(rsp, ARG_WS);
2185
2186         L(buffer_allocated);
2187
2188         mov(ORIG_SP, rbp);
2189         mov(M, ARG_M);
2190         mov(N, ARG_N);
2191         mov(C, r12);
2192         if (hasBias)
2193             mov(BIAS, r10);
2194         vmovss(ALPHA, xmm0);
2195         vmovss(BETA, xmm1);
2196         sub(A, -OFFSET * SIZE);
2197         sub(B, -OFFSET * SIZE);
2198         mov(ORIG_A, A);
2199         sal(LDA, BASE_SHIFT);
2200         sal(LDB, BASE_SHIFT);
2201         sal(LDC, BASE_SHIFT);
2202         lea(LDB3, ptr[LDB + LDB * 2]);
2203
2204         for (int i = 0; i < 8; i++) {
2205             mov(dword[rsp + 88 + i * 4], i);
2206         }
2207
2208         if (isTransA && is_avx2) {
2209             movq(xmm0, LDA);
2210             vpbroadcastq(ymm1, xmm0);
2211             vinsertf128(ymm0, ymm0, xmm0, 1);
2212             vpermilpd(ymm0, ymm0, 5);
2213             vpaddq(ymm1, ymm1, ymm1);
2214             vperm2f128(ymm1, ymm1, ymm1, 8);
2215             vpaddq(ymm0, ymm0, ymm1);
2216             vmovups(STRIDE, ymm0);
2217         }
2218
2219         // Check A alignment and leading dimension; take copy-based path as
2220         // needed
2221         mov(rax, LDA);
2222         or_(rax, A);
2223         and_(rax, 0x1f);
2224         mov(FLAG, rax);
2225
2226         Label main0, main1, main2, main3, main999;
2227
2228         cmp(M, UNROLL_M);
2229         jl(main0, T_NEAR);
2230         align(16);
2231
2232         L(main1);
2233         subloop(UNROLL_M, true, true);
2234         sub(M, UNROLL_M);
2235         cmp(M, UNROLL_M);
2236         jge(main1, T_NEAR);
2237         align(16);
2238
2239         L(main0);
2240         cmp(M, 0);
2241         jle(main999, T_NEAR);
2242
2243         if (UNROLL_M > 8) {
2244             cmp(M, 8);
2245             jle(main2, T_NEAR);
2246
2247             sub(M, 8);
2248             vbroadcastss(VMASK, M);
2249             vpcmpgtd(VMASK, VMASK, MASK);
2250
2251             subloop(16, true, false);
2252             jmp(main999, T_NEAR);
2253             align(16);
2254
2255             L(main2);
2256             cmp(M, 8);
2257             jne(main3, T_NEAR);
2258             subloop(8, true, true);
2259             jmp(main999, T_NEAR);
2260         }
2261
2262         align(16);
2263
2264         L(main3);
2265         vbroadcastss(VMASK, M);
2266         if (is_avx2) {
2267             vpcmpgtd(VMASK, VMASK, MASK);
2268         } else {
2269             auto xmask = Xmm(VMASK.getIdx());
2270             auto xmm_tmp = xmm4;
2271
2272             vextractf128(xmm_tmp, VMASK, 1);
2273             vpcmpgtd(xmask, xmask, MASK);
2274             vpcmpgtd(xmm_tmp, xmm_tmp, dword[rsp + 88 + 4 * 4]); // MASK + 4
2275             vinsertf128(VMASK, VMASK, xmm_tmp, 1);
2276         }
2277         subloop(8, false, false);
2278         align(16);
2279
2280         L(main999);
2281         // Restore original stack
2282         mov(rsp, ORIG_SP);
2283
2284         vzeroupper();
2285         postamble();
2286
2287         ker_ = this->getCode<ker_t>();
2288     }
2289
2290     typedef void (*ker_t)(dim_t m, dim_t n, dim_t k,
2291             const float *alpha, const float *a, dim_t lda,
2292             const float *b, dim_t ldb, const float *beta, float *c,
2293             dim_t ldc, const float *bias, float *ws);
2294
2295     void operator()(dim_t  m, dim_t n, dim_t k,
2296             const float *alpha, const float *a, dim_t lda,
2297             const float *b, dim_t ldb, const float *beta, float *c,
2298             dim_t ldc, const float *bias, float *ws) const
2299     {
2300         ker_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws);
2301     }
2302
2303 private:
2304     ker_t ker_;
2305 };
2306
2307 const xbyak_gemm *get_xbyak_gemm(
2308         bool isTransA, bool isTransB, float beta, bool hasBias) {
2309     auto beta_idx = [](float beta) {
2310         return (beta == 0.0) ? 0 : (beta == 1.0 ? 1 : 2);
2311     };
2312
2313     // Kernel table [isTransA][isTransB][hasBias][beta (0, 1, other)]
2314     static xbyak_gemm *kernel_table[2][2][2][3];
2315     static std::once_flag initialized;
2316     std::call_once(initialized, [=]{
2317             for (bool isTransA: {false, true})
2318             for (bool isTransB: {false, true})
2319             for (bool hasBias: {false, true})
2320             for (float beta: {0.0f, 1.0f, 2.0f}) {
2321                 // nocopy sgemm with bias for beta != 0.0 is not supported
2322                 if (hasBias && beta != 0.0)
2323                     continue;
2324                 kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)] =
2325                     new xbyak_gemm(isTransA, isTransB, beta, hasBias);
2326             }
2327     });
2328
2329     return kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)];
2330 }
2331
2332 void sgemm_nocopy_driver(const char *transa,
2333         const char *transb, int m, int n, int k, const float *alpha,
2334         const float *a, dim_t lda, const float *b, dim_t ldb, const float *beta,
2335         float *c, dim_t ldc, const float *bias, float *ws)
2336 {
2337     bool isTransA = (*transa == 'T' || *transa == 't');
2338     bool isTransB = (*transb == 'T' || *transb == 't');
2339
2340     int Bm, sizeM, Bn, sizeN, Bk, sizeK;
2341
2342     int i, j;
2343
2344     if ((m <= 0) || (n <= 0))
2345         return;
2346
2347     if ((k <= 0) || (alpha[0] == 0.)) {
2348
2349         if (beta[0] == 0.) {
2350             for (j = 0; j < n; j++)
2351                 for (i = 0; i < m; i++)
2352                     c[i + j * ldc] = 0.0;
2353         } else if (beta[0] != 1.) {
2354             for (j = 0; j < n; j++)
2355                 for (i = 0; i < m; i++)
2356                     c[i + j * ldc] *= beta[0];
2357         }
2358
2359         return;
2360     }
2361
2362     assert(IMPLICATION(bias != nullptr, *beta == 0.0));
2363
2364     // XXX: this happens on every thread...
2365     bool hasBias = (bias != nullptr);
2366     auto ker_bn = get_xbyak_gemm(isTransA, isTransB, *beta, hasBias);
2367     auto ker_b1 = get_xbyak_gemm(isTransA, isTransB, 1.0, false);
2368     auto ker_b0 = get_xbyak_gemm(isTransA, isTransB, 0.0, false);
2369     assert(ker_bn && ker_b1 && ker_b0);
2370
2371     int BM = 4032;
2372     int BN = isTransA ? 96 : 48;
2373     int BK = isTransB ? 96 : 256;
2374     const float *curA, *curB, *curBias = nullptr;
2375     float *curC;
2376
2377     for (Bk = 0; Bk < k; Bk += sizeK) {
2378         sizeK = k - Bk;
2379         if (sizeK >= BK * 2)
2380             sizeK = BK;
2381         else {
2382             if (sizeK > BK)
2383                 sizeK = (sizeK + 1) / 2;
2384         }
2385
2386         for (Bm = 0; Bm < m; Bm += sizeM) {
2387             sizeM = m - Bm;
2388             if (sizeM >= BM * 2)
2389                 sizeM = BM;
2390             else {
2391                 if (sizeM > BM + BM / 2)
2392                     sizeM = (sizeM + 1) / 2;
2393             }
2394
2395             for (Bn = 0; Bn < n; Bn += sizeN) {
2396                 sizeN = n - Bn;
2397                 if (sizeN >= BN * 2)
2398                     sizeN = BN;
2399                 else {
2400                     if (sizeN > BN + BN / 2)
2401                         sizeN = (sizeN + 1) / 2;
2402                 }
2403
2404                 if (!isTransA) {
2405                     curA = a + Bm + Bk * lda;
2406                 } else {
2407                     curA = a + Bk + Bm * lda;
2408                 }
2409                 if (!isTransB) {
2410                     curB = b + Bk + Bn * ldb;
2411                 } else {
2412                     curB = b + Bn + Bk * ldb;
2413                 }
2414                 curC = c + Bm + (size_t)Bn * ldc;
2415                 if (bias != nullptr) {
2416                     if (Bk == 0) {
2417                         curBias = bias + Bm;
2418                     } else {
2419                         curBias = nullptr;
2420                     }
2421                 }
2422                 if (Bk == 0) {
2423                     if (*beta == 0.0 && bias == nullptr)
2424                         (*ker_b0)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
2425                                 alpha, curA, lda, curB, ldb, beta, curC, ldc,
2426                                 curBias, ws);
2427                     else
2428                         (*ker_bn)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
2429                                 alpha, curA, lda, curB, ldb, beta, curC, ldc,
2430                                 curBias, ws);
2431                 } else {
2432                     (*ker_b1)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
2433                             alpha, curA, lda, curB, ldb, beta, curC, ldc,
2434                             curBias, ws);
2435                 }
2436             }
2437         }
2438     }
2439 }
2440
2441 }
2442
2443 mkldnn_status_t jit_avx_gemm_f32(
2444         const char *transa, const char *transb,
2445         const int *p_m, const int *p_n, const int *p_k, const float *p_alpha,
2446         const float *A, const int *p_lda, const float *B, const int *p_ldb,
2447         const float *p_beta, float *C, const int *p_ldc, const float *bias)
2448 {
2449     using namespace mkldnn::impl::utils;
2450     using namespace avx_gemm_f32;
2451     using namespace gemm_utils;
2452
2453     if (*p_beta != 0 && bias)
2454         return ref_gemm(transa, transb, p_m, p_n, p_k,
2455                 p_alpha, A, p_lda, B, p_lda, p_beta, C, p_ldc, bias);
2456
2457     int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
2458
2459     int m = *p_m;
2460     int n = *p_n;
2461     int k = *p_k;
2462     dim_t lda = *p_lda;
2463     dim_t ldb = *p_ldb;
2464     dim_t ldc = *p_ldc;
2465     float beta = *p_beta;
2466     int MB, NB, KB;
2467
2468     int nthr_m, nthr_n, nthr_k, nthr_mn;
2469
2470     // Determine threading partitioning
2471     calc_nthr_nocopy_avx(
2472             m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
2473     assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
2474
2475     // May not happen, but just in case
2476     if (nthr < nthr_m * nthr_n * nthr_k)
2477         nthr = nthr_m * nthr_n * nthr_k;
2478
2479     nthr_mn = nthr_m * nthr_n;
2480
2481     unsigned char * ompstatus_ = nullptr;
2482     unsigned char volatile *ompstatus = nullptr;
2483
2484     float *c_buffers = nullptr;
2485     float *ws_buffers = nullptr;
2486
2487     if (nthr_k > 1) {
2488         ompstatus_ = (unsigned char *) malloc(
2489                 nthr * CACHE_LINE_SIZE,
2490                 CACHE_LINE_SIZE);
2491         ompstatus = (unsigned char volatile *) ompstatus_;
2492         assert(ompstatus);
2493
2494         for (int i = 0; i < nthr; i++)
2495             ompstatus[i * CACHE_LINE_SIZE] = 0;
2496
2497         c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB
2498                 * sizeof(float), PAGE_4K);
2499     }
2500
2501     const size_t ws_elems_per_thr = (size_t)k * 16 + 64;
2502     const size_t ws_size_per_thr
2503             = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
2504     if (k > STACK_K_CAPACITY) {
2505         ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K);
2506     }
2507
2508     parallel_nd(nthr, [&](const int ithr) {
2509         int ithr_m, ithr_n, ithr_k, ithr_mn;
2510         int m_from, m_to, myM;
2511         int n_from, n_to, myN;
2512         int k_from, k_to, myK;
2513         int cbase, ibase;
2514         const float *myA, *myB, *myBias = nullptr;
2515         float *myC = C, myBeta;
2516         float *ws = ws_buffers ?
2517                 ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0;
2518         dim_t ld = ldc;
2519
2520         int sum_later = (mkldnn_get_num_threads() < nthr_m * nthr_n * nthr_k);
2521
2522         if (ithr < nthr_m * nthr_n * nthr_k) {
2523
2524             ithr_mn = ithr % nthr_mn;
2525             ithr_m = ithr_mn % nthr_m;
2526             ithr_n = ithr_mn / nthr_m;
2527             ithr_k = ithr / nthr_mn;
2528
2529             /* swap ithr_k for performance improvement */
2530             if (ithr_k == 0)
2531                 ithr_k = nthr_k - 1;
2532             else if (ithr_k == nthr_k - 1)
2533                 ithr_k = 0;
2534
2535             m_from = MB * (ithr_m);
2536             m_to = MB * (ithr_m + 1);
2537             if (m_to > m)
2538                 m_to = m;
2539             myM = m_to - m_from;
2540
2541             n_from = NB * (ithr_n);
2542             n_to = NB * (ithr_n + 1);
2543             if (n_to > n)
2544                 n_to = n;
2545             myN = n_to - n_from;
2546
2547             k_from = KB * (ithr_k);
2548             k_to = KB * (ithr_k + 1);
2549             if (k_to > k)
2550                 k_to = k;
2551             myK = k_to - k_from;
2552
2553             cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
2554             ibase = (ithr_m + nthr_m * ithr_n) * nthr_k;
2555
2556             if ((myM > 0) && (myN > 0)) {
2557
2558                 if (*transa == 'N' || *transa == 'n') {
2559                     myA = &(A[m_from + k_from * lda]);
2560                 } else {
2561                     myA = &(A[k_from + m_from * lda]);
2562                 }
2563                 if (*transb == 'N' || *transb == 'n') {
2564                     myB = &(B[k_from + n_from * ldb]);
2565                 } else {
2566                     myB = &(B[n_from + k_from * ldb]);
2567                 }
2568                 if (ithr_k == 0) {
2569                     myC = &(C[m_from + n_from * ldc]);
2570                     myBeta = beta;
2571                     ld = ldc;
2572                     if (bias)
2573                         myBias = &(bias[m_from]);
2574                 } else {
2575                     myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1);
2576                     myBeta = 0.0;
2577                     ld = MB;
2578                     myBias = nullptr;
2579                 }
2580
2581                 sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA,
2582                         lda, myB, ldb, &myBeta, myC, ld, myBias, ws);
2583
2584                 if (nthr_k > 1 && !sum_later)
2585                     ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1;
2586             }
2587
2588             if (nthr_k > 1 && !sum_later) {
2589
2590                 // sum matrices partitioned along K dimension
2591                 int n1, n2;
2592
2593                 partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
2594
2595                 if (ithr_k > 0) {
2596
2597                     myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
2598                         + (dim_t)n1 * MB;
2599                     /* need to wait until main thread finishes */
2600                     while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) {
2601                     };
2602
2603                     /* my cache is hot */
2604                     sum_two_matrices(myM, n2, myC, MB,
2605                             &C[m_from + (n_from + n1) * ldc], ldc);
2606                 }
2607
2608                 for (int ik = 1; ik < nthr_k; ++ik) {
2609                     if (ik != ithr_k) {
2610
2611                         myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
2612                             + (dim_t)n1 * MB;
2613
2614                         while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) {
2615                         };
2616
2617                         sum_two_matrices(myM, n2, myC, MB,
2618                                 &C[m_from + (n_from + n1) * ldc], ldc);
2619                     }
2620                 }
2621             }
2622         }
2623     });
2624
2625     // handle C summation later
2626     if (nthr_k > 1 && ompstatus[0] == 0) {
2627
2628         parallel_nd(nthr, [&](const int ithr) {
2629             int ithr_m, ithr_n, ithr_k, ithr_mn;
2630             int m_from, m_to, myM;
2631             int n_from, n_to, myN;
2632             int cbase;
2633             float *myC = C;
2634
2635             if (ithr < nthr_m * nthr_n * nthr_k) {
2636
2637                 ithr_mn = ithr % nthr_mn;
2638                 ithr_m = ithr_mn % nthr_m;
2639                 ithr_n = ithr_mn / nthr_m;
2640                 ithr_k = ithr / nthr_mn;
2641
2642                 /* swap ithr_k for performance improvement */
2643                 if (ithr_k == 0)
2644                     ithr_k = nthr_k - 1;
2645                 else if (ithr_k == nthr_k - 1)
2646                     ithr_k = 0;
2647
2648                 m_from = MB * (ithr_m);
2649                 m_to = MB * (ithr_m + 1);
2650                 if (m_to > m)
2651                     m_to = m;
2652                 myM = m_to - m_from;
2653
2654                 n_from = NB * (ithr_n);
2655                 n_to = NB * (ithr_n + 1);
2656                 if (n_to > n)
2657                     n_to = n;
2658                 myN = n_to - n_from;
2659
2660                 cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
2661
2662                 if (nthr_k > 1) {
2663                     // sum matrices partitioned along K dimension
2664                     int n1, n2;
2665
2666                     partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
2667
2668                     if (ithr_k > 0) {
2669
2670                         myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
2671                             + (dim_t)n1 * MB;
2672
2673                         /* my cache is hot */
2674                         sum_two_matrices(myM, n2, myC, MB,
2675                                          &C[m_from + (n_from + n1) * ldc], ldc);
2676                     }
2677
2678                     for (int ik = 1; ik < nthr_k; ++ik) {
2679                         if (ik != ithr_k) {
2680
2681                             myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
2682                                 + (dim_t)n1 * MB;
2683
2684                             sum_two_matrices(myM, n2, myC, MB,
2685                                              &C[m_from + (n_from + n1) * ldc], ldc);
2686                         }
2687                     }
2688                 }
2689             }
2690         });
2691     }
2692
2693
2694     free(c_buffers);
2695     free(ompstatus_);
2696     free(ws_buffers);
2697
2698     return mkldnn_success;
2699 }
2700
2701 }
2702 }
2703 }
2704
2705 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s