9766a46d76592a68c9306fb26512f615fd0c2fef
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm / jit_avx_gemm_f32.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <math.h>
18
19 #include "mkldnn_thread.hpp"
20 #include "utils.hpp"
21 #include "gemm_utils.hpp"
22 #include "jit_avx_gemm_f32.hpp"
23
// NOTE(review): not referenced in the visible part of this file. If the value
// counts float elements, 16 * 4 bytes = 64 bytes -- TODO confirm the unit
// before relying on it.
#define CACHE_LINE_SIZE 16
25
26 namespace mkldnn {
27 namespace impl {
28 namespace cpu {
29
30 using namespace mkldnn::impl::memory_format;
31 using namespace mkldnn::impl::utils;
32
33 using namespace Xbyak;
// Size of the register-save area that the jit_generator prologue presumably
// pushes (get_size_of_abi_save_regs() comes from the jit_generator base --
// TODO confirm). xbyak_gemm adds it to the rsp-relative offsets of the
// stack-passed GEMM arguments.
#define STACKSIZE get_size_of_abi_save_regs()
// NOTE(review): not referenced in the visible part of this file; presumably a
// per-platform bound on K for stack storage -- confirm before relying on it.
// Written with #ifdef to match the `#ifdef _WIN32` check inside xbyak_gemm
// and so that -Wundef does not warn when _WIN32 is not defined.
#ifdef _WIN32
#define STACK_K_CAPACITY 128
#else
#define STACK_K_CAPACITY 8192
#endif
// Size in bytes of one float element; element offsets are scaled by it.
#define SIZE 4
// Element bias: generated addresses are written as (x - OFFSET) * SIZE, and
// the packed-A pointer is set up with a matching +OFFSET * SIZE bias.
#define OFFSET 32
// log2(SIZE). Not referenced in the visible part of this file.
#define BASE_SHIFT 2
// NOTE(review): not referenced in the visible part of this file.
#define SECOND_FETCH 14
44
45 struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
46     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_gemm_f32_xbyak_gemm)
47
48     xbyak_gemm(char transa, char transb, float beta, bool hasBias = false,
49             void *code_ptr = nullptr,
50             size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE)
51         : jit_generator(code_ptr, code_size)
52     {
53         const bool is_avx2 = mayiuse(avx2);
54         assert(implication(!is_avx2, mayiuse(avx)));
55
56         const int UNROLL_M = is_avx2 ? 16 : 8;
57         const int UNROLL_N = 6;
58
59         bool isTransA = (transa == 'T' || transa == 't');
60         bool isTransB = (transb == 'T' || transb == 't');
61         bool isBeta0 = (beta == 0.0);
62         bool isBetaN = (!isBeta0 && beta != 1.0);
63
64         // various definitions for convenience
65         auto ARG_M = abi_param1;
66         auto ARG_N = abi_param2;
67         auto K = abi_param3;
68         auto ARG_ALPHA = abi_param4;
69 #ifdef _WIN32
70         auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE];
71         auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE +
72             sizeof(float *) + STACKSIZE];
73         const auto stackOffset = OFFSET_SHADOWSPACE +
74             sizeof(float *) + STACKSIZE;
75         auto A = rsi;
76         auto LDA = rdi;
77 #else
78         auto ARG_A = r8;
79         auto ARG_LDA = r9;
80         const auto stackOffset = STACKSIZE;
81         auto A = ARG_A;
82         auto LDA = ARG_LDA;
83 #endif
84         auto ARG_B = ptr[rsp + 8 + stackOffset];
85         auto ARG_LDB = ptr[rsp + 16 + stackOffset];
86         auto ARG_BETA = ptr[rsp + 24 + stackOffset];
87         auto ARG_C = ptr[rsp + 32 + stackOffset];
88         auto ARG_LDC = ptr[rsp + 40 + stackOffset];
89         auto ARG_BIAS = ptr[rsp + 48 + stackOffset];
90         auto ARG_WS = ptr[rsp + 56 + stackOffset];
91
92         auto B = r11;
93         auto LDB = rbx;
94         auto LDC = r13;
95         auto LL = rax;
96         auto AO1 = abi_param2;
97         auto BO1 = abi_param4;
98         auto BO2 = rbp;
99         auto CO1 = r14;
100         auto CO2 = r15;
101         auto LDB3 = r10;
102         auto LDA4 = abi_param1;
103         auto AA = r12;
104         auto BIAS1 = abi_param1;
105
106         auto M = qword[rsp + 0];
107         auto N = qword[rsp + 8];
108         auto FLAG = qword[rsp + 16];
109         auto I = qword[rsp + 24];
110         auto C = qword[rsp + 32];
111         auto BIAS = qword[rsp + 40];
112         auto ALPHA = qword[rsp + 48];
113         auto BETA = qword[rsp + 64];
114         auto ORIG_A = qword[rsp + 80];
115         auto MASK = dword[rsp + 88];
116         auto STRIDE = qword[rsp + 120];
117         auto ORIG_SP = qword[rsp + 152];
118
119         auto VALPHA = ymm1;
120         auto VBETA = ymm2;
121         auto VMASK = ymm3;
122         auto VBIAS1 = ymm2;
123         auto VBIAS2 = ymm4;
124
125         auto PREFETCHSIZEA = 128;
126         auto PREFETCHSIZEB = (!isTransB) ? -16 : 0;
127
128         // Function for packing if needed
129         auto do_pack = [&](
130                 int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) {
131
132             int regIdx;
133             Reg64 reg;
134             inLocalLabel();
135
136             mov(BO1, A);
137             lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]);
138
139             if (isTransA) {
140                 lea(BO2, ptr[BO1 + LDA * 4]);
141                 lea(CO1, ptr[LDA + LDA * 2]);
142                 vmovupd(ymm7, STRIDE);
143             }
144
145             mov(LL, K);
146             sar(LL, 2);
147             jle(".pack3", T_NEAR);
148             align(16);
149
150             L(".pack2");
151             if (!isTransA) {
152                 for (int i = 0; i < 4; i++) {
153                     regIdx = (i % 2 == 0) ? 4 : 6;
154                     if (isLoad1Unmasked) {
155                         vmovups(Ymm(regIdx),
156                                 ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
157                     } else {
158                         vmaskmovps(Ymm(regIdx), VMASK,
159                                 ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
160                     }
161                     if (unroll_m > 8) {
162                         if (isLoad2Unmasked) {
163                             vmovups(Ymm(regIdx + 1),
164                                     ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
165                         } else {
166                             vmaskmovps(Ymm(regIdx + 1), VMASK,
167                                     ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
168                         }
169                     }
170                     add(BO1, LDA);
171
172                     vmovups(ptr[AO1 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
173                             Ymm(regIdx));
174                     if (unroll_m > 8) {
175                         vmovups(ptr[AO1
176                                         + (unroll_m * i + 1 * 8 - OFFSET)
177                                                 * SIZE],
178                                 Ymm(regIdx + 1));
179                     }
180                 }
181
182             } else {
183                 if (isLoad1Unmasked) {
184                     for (int i = 0; i < 2; i++) {
185                         reg = (i % 2 == 0) ? BO1 : BO2;
186                         vmovups(xmm0, ptr[reg + (0 * 8 - OFFSET) * SIZE]);
187                         vmovups(xmm1,
188                                 ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
189                         lea(BO2, ptr[reg + LDA * 2]);
190                         vunpcklps(xmm4, xmm0, xmm1);
191                         vunpckhps(xmm5, xmm0, xmm1);
192                         vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
193                         vmovups(xmm1,
194                                 ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
195                         lea(BO2, ptr[BO2 + LDA * 2]);
196                         vunpcklps(xmm6, xmm0, xmm1);
197                         vunpckhps(xmm2, xmm0, xmm1);
198
199                         vunpcklpd(xmm0, xmm4, xmm6);
200                         vunpckhpd(xmm1, xmm4, xmm6);
201                         vmovups(ptr[AO1
202                                         + (unroll_m * 0 + i * 4 - OFFSET)
203                                                 * SIZE],
204                                 xmm0);
205                         vmovups(ptr[AO1
206                                         + (unroll_m * 1 + i * 4 - OFFSET)
207                                                 * SIZE],
208                                 xmm1);
209                         vunpcklpd(xmm0, xmm5, xmm2);
210                         vunpckhpd(xmm1, xmm5, xmm2);
211                         vmovups(ptr[AO1
212                                         + (unroll_m * 2 + i * 4 - OFFSET)
213                                                 * SIZE],
214                                 xmm0);
215                         vmovups(ptr[AO1
216                                         + (unroll_m * 3 + i * 4 - OFFSET)
217                                                 * SIZE],
218                                 xmm1);
219                     }
220                 } else if (is_avx2) {
221                     for (int i = 0; i < 2; i++) {
222                         vmovaps(xmm4, xmm3);
223                         vgatherqps(xmm0,
224                                 ptr[BO1 + ymm7 + ((2 * i) - OFFSET) * SIZE],
225                                 xmm4);
226                         vmovaps(xmm4, xmm3);
227                         vgatherqps(xmm1,
228                                 ptr[BO1 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE],
229                                 xmm4);
230
231                         vmovups(ptr[AO1
232                                         + (unroll_m * (2 * i) + 0 * 4 - OFFSET)
233                                                 * SIZE],
234                                 xmm0);
235                         vmovups(ptr[AO1
236                                         + (unroll_m * (2 * i + 1) + 0 * 4
237                                                   - OFFSET)
238                                                 * SIZE],
239                                 xmm1);
240                     }
241
242                     lea(BO2, ptr[BO1 + LDA * 4]);
243
244                     for (int i = 0; i < 2; i++) {
245                         vextractf128(xmm4, ymm3, 1);
246                         vgatherqps(xmm0,
247                                 ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
248                                 xmm4);
249                         vextractf128(xmm4, ymm3, 1);
250                         vgatherqps(xmm1,
251                                 ptr[BO2 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE],
252                                 xmm4);
253
254                         vmovups(ptr[AO1
255                                         + (unroll_m * (2 * i) + 1 * 4 - OFFSET)
256                                                 * SIZE],
257                                 xmm0);
258                         vmovups(ptr[AO1
259                                         + (unroll_m * (2 * i + 1) + 1 * 4
260                                                   - OFFSET)
261                                                 * SIZE],
262                                 xmm1);
263                     }
264
265                     lea(BO2, ptr[BO2 + LDA * 4]);
266                 } else {
267                     vxorps(xmm4, xmm4, xmm4);
268                     lea(BO2, ptr[BO1 + LDA * 4]);
269
270                     auto el_cp = [&](int section, int ld_step) {
271                         RegExp src_addr = section == 0 ? BO1 : BO2;
272                         if (ld_step == 1 || ld_step == 2)
273                             src_addr = src_addr + LDA * ld_step;
274                         else if (ld_step == 3)
275                             src_addr = src_addr + CO1;
276                         src_addr = src_addr - OFFSET * SIZE;
277
278                         vmovups(Xmm(ld_step % 2), ptr[src_addr]);
279                         RegExp dst_addr = AO1
280                             + (ld_step + section * 4 - OFFSET) * SIZE;
281                         for (int off = 0; off < 4; ++off)
282                             pextrd(ptr[dst_addr + unroll_m * off * SIZE],
283                                     Xmm(ld_step % 2), off);
284                     };
285
286                     Label l_end;
287                     el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR);
288                     el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR);
289                     el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR);
290                     el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR);
291                     el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR);
292                     el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR);
293                     el_cp(1, 2);
294                     L(l_end);
295
296                     lea(BO2, ptr[BO2 + LDA * 4]);
297                 }
298
299                 if (unroll_m >= 16) {
300                     assert(is_avx2);
301                     if (isLoad2Unmasked) {
302                         for (int i = 0; i < 2; i++) {
303                             vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
304                             vmovups(xmm1, ptr[BO2 + LDA * 1
305                                                   + (0 * 8 - OFFSET) * SIZE]);
306                             lea(BO2, ptr[BO2 + LDA * 2]);
307                             vunpcklps(xmm4, xmm0, xmm1);
308                             vunpckhps(xmm5, xmm0, xmm1);
309                             vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
310                             vmovups(xmm1, ptr[BO2 + LDA * 1
311                                                   + (0 * 8 - OFFSET) * SIZE]);
312                             if (i == 0)
313                                 lea(BO2, ptr[BO2 + LDA * 2]);
314                             vunpcklps(xmm6, xmm0, xmm1);
315                             vunpckhps(xmm2, xmm0, xmm1);
316
317                             vunpcklpd(xmm0, xmm4, xmm6);
318                             vunpckhpd(xmm1, xmm4, xmm6);
319                             vmovups(ptr[AO1
320                                             + (unroll_m * 0 + (i + 2) * 4
321                                                       - OFFSET)
322                                                     * SIZE],
323                                     xmm0);
324                             vmovups(ptr[AO1
325                                             + (unroll_m * 1 + (i + 2) * 4
326                                                       - OFFSET)
327                                                     * SIZE],
328                                     xmm1);
329                             vunpcklpd(xmm0, xmm5, xmm2);
330                             vunpckhpd(xmm1, xmm5, xmm2);
331                             vmovups(ptr[AO1
332                                             + (unroll_m * 2 + (i + 2) * 4
333                                                       - OFFSET)
334                                                     * SIZE],
335                                     xmm0);
336                             vmovups(ptr[AO1
337                                             + (unroll_m * 3 + (i + 2) * 4
338                                                       - OFFSET)
339                                                     * SIZE],
340                                     xmm1);
341                         }
342                     } else {
343                         for (int i = 0; i < 2; i++) {
344                             vmovaps(xmm4, xmm3);
345                             vgatherqps(xmm0,
346                                     ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
347                                     xmm4);
348                             vmovaps(xmm4, xmm3);
349                             vgatherqps(xmm1,
350                                     ptr[BO2 + ymm7
351                                                + ((2 * i + 1) - OFFSET) * SIZE],
352                                     xmm4);
353
354                             vmovups(ptr[AO1
355                                             + (unroll_m * (2 * i) + 2 * 4
356                                                       - OFFSET)
357                                                     * SIZE],
358                                     xmm0);
359                             vmovups(ptr[AO1
360                                             + (unroll_m * (2 * i + 1) + 2 * 4
361                                                       - OFFSET)
362                                                     * SIZE],
363                                     xmm1);
364                         }
365
366                         lea(BO2, ptr[BO2 + LDA * 4]);
367
368                         for (int i = 0; i < 2; i++) {
369                             vextractf128(xmm4, ymm3, 1);
370                             vgatherqps(xmm0,
371                                     ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
372                                     xmm4);
373                             vextractf128(xmm4, ymm3, 1);
374                             vgatherqps(xmm1,
375                                     ptr[BO2 + ymm7
376                                                + ((2 * i + 1) - OFFSET) * SIZE],
377                                     xmm4);
378
379                             vmovups(ptr[AO1
380                                             + (unroll_m * (2 * i) + 3 * 4
381                                                       - OFFSET)
382                                                     * SIZE],
383                                     xmm0);
384                             vmovups(ptr[AO1
385                                             + (unroll_m * (2 * i + 1) + 3 * 4
386                                                       - OFFSET)
387                                                     * SIZE],
388                                     xmm1);
389                         }
390
391                         lea(BO2, ptr[BO2 + LDA * 4]);
392                     }
393                 }
394                 add(BO1, (4 * SIZE));
395             }
396
397             add(AO1, unroll_m * 4 * SIZE);
398             sub(LL, 1);
399             jg(".pack2", T_NEAR);
400             align(16);
401
402             L(".pack3");
403             mov(LL, K);
404             and_(LL, 3);
405             jle(".pack10", T_NEAR);
406             align(16);
407
408             L(".pack4");
409             if (!isTransA) {
410                 if (isLoad1Unmasked) {
411                     vmovups(ymm4, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
412                 } else {
413                     vmaskmovps(ymm4, VMASK, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
414                 }
415                 if (unroll_m > 8) {
416                     if (isLoad2Unmasked) {
417                         vmovups(ymm5, ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
418                     } else {
419                         vmaskmovps(ymm5, VMASK,
420                                 ptr[BO1 + (1 + 8 - OFFSET) * SIZE]);
421                     }
422                 }
423                 add(BO1, LDA);
424                 vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
425                         ymm4);
426                 if (unroll_m > 8) {
427                     vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE],
428                             ymm5);
429                 }
430             } else {
431                 if (isLoad1Unmasked) {
432                     for (int i = 0; i < 2; i++) {
433                         reg = (i % 2 == 0) ? BO1 : BO2;
434                         vmovss(Xmm(i + 1), ptr[reg + (0 * 8 - OFFSET) * SIZE]);
435                         vmovss(xmm0,
436                                 ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
437                         lea(BO2, ptr[reg + LDA * 2]);
438                         vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
439                     }
440                     vunpcklpd(xmm1, xmm1, xmm2);
441                     vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE],
442                             xmm1);
443
444                     for (int i = 0; i < 2; i++) {
445                         vmovss(Xmm(i + 1), ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
446                         vmovss(xmm0,
447                                 ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
448                         lea(BO2, ptr[BO2 + LDA * 2]);
449                         vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
450                     }
451                     vunpcklpd(xmm1, xmm1, xmm2);
452                     vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE],
453                             xmm1);
454                 } else if (is_avx2) {
455                     vmovaps(xmm4, xmm3);
456                     vgatherqps(xmm1, ptr[BO1 + ymm7 + (0 * 8 - OFFSET) * SIZE],
457                             xmm4);
458                     lea(BO2, ptr[BO1 + LDA * 4]);
459                     vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE],
460                             xmm1);
461
462                     vextractf128(xmm4, ymm3, 1);
463                     vgatherqps(xmm1, ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
464                             xmm4);
465                     lea(BO2, ptr[BO2 + LDA * 4]);
466                     vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE],
467                             xmm1);
468                 } else {
469                     vxorps(xmm4, xmm4, xmm4);
470                     lea(BO2, ptr[BO1 + LDA * 4]);
471
472                     auto el_cp = [&](int section, int ld_step) {
473                         RegExp src_addr = section == 0 ? BO1 : BO2;
474                         if (ld_step == 1 || ld_step == 2)
475                             src_addr = src_addr + LDA * ld_step;
476                         else if (ld_step == 3)
477                             src_addr = src_addr + CO1;
478                         src_addr = src_addr - OFFSET * SIZE;
479
480                         vmovss(xmm1, ptr[src_addr]);
481                         RegExp dst_addr = AO1
482                             + (ld_step + section * 4 - OFFSET) * SIZE;
483                         movss(ptr[dst_addr], xmm1);
484                     };
485
486                     Label l_end;
487                     el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR);
488                     el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR);
489                     el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR);
490                     el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR);
491                     el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR);
492                     el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR);
493                     el_cp(1, 2);
494                     L(l_end);
495
496                     lea(BO2, ptr[BO2 + LDA * 4]);
497                 }
498
499                 if (unroll_m >= 16) {
500                     assert(is_avx2);
501                     if (isLoad2Unmasked) {
502                         for (int i = 0; i < 2; i++) {
503                             vmovss(Xmm(i + 1),
504                                     ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
505                             vmovss(xmm0, ptr[BO2 + LDA * 1
506                                                  + (0 * 8 - OFFSET) * SIZE]);
507                             lea(BO2, ptr[BO2 + LDA * 2]);
508                             vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
509                         }
510                         vunpcklpd(xmm1, xmm1, xmm2);
511                     } else {
512                         vmovaps(xmm4, xmm3);
513                         vgatherqps(xmm1,
514                                 ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
515                                 xmm4);
516                         lea(BO2, ptr[BO2 + LDA * 4]);
517                     }
518                     vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 4 - OFFSET) * SIZE],
519                             xmm1);
520
521                     if (isLoad2Unmasked) {
522                         for (int i = 0; i < 2; i++) {
523                             vmovss(Xmm(i + 1),
524                                     ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
525                             vmovss(xmm0, ptr[BO2 + LDA * 1
526                                                  + (0 * 8 - OFFSET) * SIZE]);
527                             lea(BO2, ptr[BO2 + LDA * 2]);
528                             vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
529                         }
530                         vunpcklpd(xmm1, xmm1, xmm2);
531                     } else {
532                         vextractf128(xmm4, ymm3, 1);
533                         vgatherqps(xmm1,
534                                 ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
535                                 xmm4);
536                     }
537                     vmovups(ptr[AO1 + (unroll_m * 0 + 3 * 4 - OFFSET) * SIZE],
538                             xmm1);
539                 }
540                 add(BO1, SIZE);
541             }
542
543             add(AO1, unroll_m * SIZE);
544             sub(LL, 1);
545             jg(".pack4", T_NEAR);
546             align(16);
547
548             L(".pack10");
549
550             outLocalLabel();
551         };
552
553         // Fused multiply add; may become one or two instructions
554         auto fma = [&](bool useFma, Ymm reg0, Ymm reg1, Ymm reg2,
555                 bool overWrite = false) {
556             if (useFma) {
557                 if (is_avx2) {
558                     vfmadd231ps(reg2, reg1, reg0);
559                 } else {
560                     assert(UNROLL_M == 8);
561                     auto tent_vreg = overWrite ? reg1 : ymm1;
562                     vmulps(tent_vreg, reg1, reg0);
563                     vaddps(reg2, reg2, tent_vreg);
564                 }
565             } else {
566                 if (!overWrite) {
567                     vmulps(ymm15, reg1, reg0);
568                     vaddps(reg2, reg2, ymm15);
569                 } else {
570                     vmulps(reg1, reg1, reg0);
571                     vaddps(reg2, reg2, reg1);
572                 }
573             }
574         };
575
576         // Inner kernel with k=8
577         auto innerkernel8 = [&](int unroll_m, int unroll_n,
578                 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
579                 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
580                 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
581                 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
582                 Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
583                 Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
584                 Ymm reg23) {
585
586             Ymm fmareg;
587
588             if (!isDirect) {
589                 prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]);
590             } else {
591                 prefetcht0(ptr[AO1 + LDA4]);
592             }
593
594             for (int i = 0; i < 8; i++) {
595                 if (isDirect) {
596                     if (isLoad1Unmasked) {
597                         vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
598                     } else {
599                         vmaskmovps(ymm0, VMASK,
600                                 ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
601                     }
602                     if (unroll_m >= 16) {
603                         if (isLoad2Unmasked) {
604                             vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
605                         } else {
606                             vmaskmovps(ymm1, VMASK,
607                                     ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
608                         }
609                     }
610                     add(AO1, LDA);
611                 }
612
613                 if (!isTransB) {
614                     vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]);
615                 } else {
616                     vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
617                 }
618                 fmareg = (i % 2 == 0) ? reg00 : reg12;
619                 fma(useFma, ymm0, ymm2, fmareg);
620                 if (unroll_m >= 16) {
621                     fmareg = (i % 2 == 0) ? reg06 : reg18;
622                     fma(useFma, ymm1, ymm2, fmareg);
623                 }
624                 if (i == 0) {
625                     if (!isTransB) {
626                         prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]);
627                     }
628                 }
629                 if (unroll_n >= 2) {
630                     if (!isTransB) {
631                         if (i == 1) {
632                             prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]);
633                         }
634                         vbroadcastss(
635                                 ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
636                     } else {
637                         vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
638                     }
639                     fmareg = (i % 2 == 0) ? reg01 : reg13;
640                     fma(useFma, ymm0, ymm2, fmareg);
641                     if (unroll_m >= 16) {
642                         fmareg = (i % 2 == 0) ? reg07 : reg19;
643                         fma(useFma, ymm1, ymm2, fmareg);
644                     }
645                 }
646
647                 if (isCopy) {
648                     vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
649                             ymm0);
650                     if (unroll_m >= 16) {
651                         vmovups(ptr[LDA4
652                                         + (unroll_m * i + 1 * 8 - OFFSET)
653                                                 * SIZE],
654                                 ymm1);
655                     }
656                     if (i == 7) {
657                         sub(LDA4, -unroll_m * 8 * SIZE);
658                     }
659                 }
660
661                 if (unroll_n >= 3) {
662                     if (!isTransB) {
663                         if (i == 2) {
664                             prefetcht0(
665                                     ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
666                         }
667                         vbroadcastss(
668                                 ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]);
669                     } else {
670                         vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
671                     }
672                     fmareg = (i % 2 == 0) ? reg02 : reg14;
673                     fma(useFma, ymm0, ymm2, fmareg);
674                     if (unroll_m >= 16) {
675                         fmareg = (i % 2 == 0) ? reg08 : reg20;
676                         fma(useFma, ymm1, ymm2, fmareg);
677                     }
678                 }
679
680                 if (i == 7) {
681                     if (!isTransB) {
682                         sub(BO1, -8 * SIZE);
683                     }
684                 }
685
686                 if (unroll_n >= 4) {
687                     if (!isTransB) {
688                         if (i == 3) {
689                             prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]);
690                         }
691                         vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]);
692                     } else {
693                         vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
694                     }
695                     fmareg = (i % 2 == 0) ? reg03 : reg15;
696                     fma(useFma, ymm0, ymm2, fmareg);
697                     if (unroll_m >= 16) {
698                         fmareg = (i % 2 == 0) ? reg09 : reg21;
699                         fma(useFma, ymm1, ymm2, fmareg);
700                     }
701                 }
702
703                 if (unroll_n >= 5) {
704                     if (!isTransB) {
705                         if (i == 4) {
706                             prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]);
707                         }
708                         vbroadcastss(
709                                 ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]);
710                     } else {
711                         vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
712                     }
713                     fmareg = (i % 2 == 0) ? reg04 : reg16;
714                     fma(useFma, ymm0, ymm2, fmareg);
715                     if (unroll_m >= 16) {
716                         fmareg = (i % 2 == 0) ? reg10 : reg22;
717                         fma(useFma, ymm1, ymm2, fmareg);
718                     }
719                 }
720
721                 if (unroll_n >= 6) {
722                     if (!isTransB) {
723                         if (i == 5) {
724                             prefetcht0(
725                                     ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]);
726                         }
727                         vbroadcastss(
728                                 ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]);
729                     } else {
730                         vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
731                     }
732                     fmareg = (i % 2 == 0) ? reg05 : reg17;
733                     fma(useFma, ymm0, ymm2, fmareg);
734                     if (unroll_m >= 16) {
735                         fmareg = (i % 2 == 0) ? reg11 : reg23;
736                         fma(useFma, ymm1, ymm2, fmareg);
737                     }
738                 }
739                 if (isTransB) {
740                     prefetcht0(ptr[BO1 + BO2]);
741                     add(BO1, LDB);
742                 }
743
744                 if (i == 0) {
745                     if (unroll_m >= 4) {
746                         if (!isDirect) {
747                             prefetcht0(
748                                     ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]);
749                         } else {
750                             prefetcht0(ptr[AO1 + LDA4]);
751                         }
752                     }
753                 }
754                 if (i == 1 || i == 2) {
755                     if (unroll_m >= 8) {
756                         if (!isDirect) {
757                             prefetcht0(ptr[AO1
758                                     + (PREFETCHSIZEA + (2 + 2 * i) * 8)
759                                             * SIZE]);
760                         } else {
761                             prefetcht0(ptr[AO1 + LDA4]);
762                         }
763                     }
764                 }
765                 if (i == 3 || i == 4 || i == 5 || i == 6) {
766                     if (unroll_m >= 16) {
767                         if (!isDirect) {
768                             prefetcht0(ptr[AO1
769                                     + (PREFETCHSIZEA + (2 + 2 * i) * 8)
770                                             * SIZE]);
771                         } else {
772                             prefetcht0(ptr[AO1 + LDA4]);
773                         }
774                     }
775                 }
776                 if (i == 7) {
777                     if (!isTransB) {
778                         if (unroll_n >= 4) {
779                             sub(BO2, -8 * SIZE);
780                         }
781                     }
782                     if (!isTransA) {
783                         prefetcht2(ptr[AA]);
784                         lea(AA, ptr[AA + LDA]);
785                     }
786                 }
787
788                 if (!isDirect) {
789                     if (isLoad1Unmasked) {
790                         vmovups(ymm0,
791                                 ptr[AO1
792                                         + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
793                                                 * SIZE]);
794                     } else {
795                         vmaskmovps(
796                                 ymm0, VMASK,
797                                 ptr[AO1
798                                         + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
799                                                 * SIZE]);
800                     }
801                     if (unroll_m >= 16) {
802                         if (isLoad2Unmasked) {
803                             vmovups(ymm1, ptr[AO1
804                                                   + (unroll_m * (i + 1) + 1 * 8
805                                                             - OFFSET)
806                                                           * SIZE]);
807                         } else {
808                             vmaskmovps(ymm1, VMASK,
809                                     ptr[AO1
810                                                + (unroll_m * (i + 1) + 1 * 8
811                                                          - OFFSET)
812                                                        * SIZE]);
813                         }
814                     }
815                 }
816             }
817
818             if (!isDirect) {
819                 sub(AO1, -unroll_m * 8 * SIZE);
820             }
821             sub(LL, 1);
822
823         };
824
825         // Inner kernel with k=4
826         auto innerkernel4 = [&](int unroll_m, int unroll_n,
827                 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
828                 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
829                 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
830                 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
831                 Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
832                 Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
833                 Ymm reg23) {
834
835             Ymm fmareg;
836
837             if (!isDirect) {
838                 prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]);
839             } else {
840                 prefetcht0(ptr[AO1 + LDA4]);
841             }
842
843             for (int i = 0; i < 4; i++) {
844                 if (isDirect) {
845                     if (isLoad1Unmasked) {
846                         vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
847                     } else {
848                         vmaskmovps(ymm0, VMASK,
849                                 ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
850                     }
851                     if (unroll_m >= 16) {
852                         if (isLoad2Unmasked) {
853                             vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
854                         } else {
855                             vmaskmovps(ymm1, VMASK,
856                                     ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
857                         }
858                     }
859                     add(AO1, LDA);
860                 }
861
862                 if (!isTransB) {
863                     vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]);
864                 } else {
865                     vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
866                 }
867                 fmareg = (i % 2 == 0) ? reg00 : reg12;
868                 fma(useFma, ymm0, ymm2, fmareg);
869                 if (unroll_m >= 16) {
870                     fmareg = (i % 2 == 0) ? reg06 : reg18;
871                     fma(useFma, ymm1, ymm2, fmareg);
872                 }
873                 if (i == 0) {
874                     if (!isTransB) {
875                         prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]);
876                     }
877                 }
878                 if (unroll_n >= 2) {
879                     if (!isTransB) {
880                         if (i == 1) {
881                             prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]);
882                         }
883                         vbroadcastss(
884                                 ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
885                     } else {
886                         vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
887                     }
888                     fmareg = (i % 2 == 0) ? reg01 : reg13;
889                     fma(useFma, ymm0, ymm2, fmareg);
890                     if (unroll_m >= 16) {
891                         fmareg = (i % 2 == 0) ? reg07 : reg19;
892                         fma(useFma, ymm1, ymm2, fmareg);
893                     }
894                 }
895
896                 if (isCopy) {
897                     vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
898                             ymm0);
899                     if (unroll_m >= 16) {
900                         vmovups(ptr[LDA4
901                                         + (unroll_m * i + 1 * 8 - OFFSET)
902                                                 * SIZE],
903                                 ymm1);
904                     }
905                     if (i == 3) {
906                         sub(LDA4, -unroll_m * 4 * SIZE);
907                     }
908                 }
909
910                 if (unroll_n >= 3) {
911                     if (!isTransB) {
912                         if (i == 2) {
913                             prefetcht0(
914                                     ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
915                         }
916                         vbroadcastss(
917                                 ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]);
918                     } else {
919                         vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
920                     }
921                     fmareg = (i % 2 == 0) ? reg02 : reg14;
922                     fma(useFma, ymm0, ymm2, fmareg);
923                     if (unroll_m >= 16) {
924                         fmareg = (i % 2 == 0) ? reg08 : reg20;
925                         fma(useFma, ymm1, ymm2, fmareg);
926                     }
927                 }
928
929                 if (i == 7) {
930                     if (!isTransB) {
931                         sub(BO1, -8 * SIZE);
932                     }
933                 }
934
935                 if (unroll_n >= 4) {
936                     if (!isTransB) {
937                         if (i == 3) {
938                             prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]);
939                         }
940                         vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]);
941                     } else {
942                         vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
943                     }
944                     fmareg = (i % 2 == 0) ? reg03 : reg15;
945                     fma(useFma, ymm0, ymm2, fmareg);
946                     if (unroll_m >= 16) {
947                         fmareg = (i % 2 == 0) ? reg09 : reg21;
948                         fma(useFma, ymm1, ymm2, fmareg);
949                     }
950                 }
951
952                 if (unroll_n >= 5) {
953                     if (!isTransB) {
954                         if (i == 4) {
955                             prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]);
956                         }
957                         vbroadcastss(
958                                 ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]);
959                     } else {
960                         vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
961                     }
962                     fmareg = (i % 2 == 0) ? reg04 : reg16;
963                     fma(useFma, ymm0, ymm2, fmareg);
964                     if (unroll_m >= 16) {
965                         fmareg = (i % 2 == 0) ? reg10 : reg22;
966                         fma(useFma, ymm1, ymm2, fmareg);
967                     }
968                 }
969
970                 if (unroll_n >= 6) {
971                     if (!isTransB) {
972                         if (i == 5) {
973                             prefetcht0(
974                                     ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]);
975                         }
976                         vbroadcastss(
977                                 ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]);
978                     } else {
979                         vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
980                     }
981                     fmareg = (i % 2 == 0) ? reg05 : reg17;
982                     fma(useFma, ymm0, ymm2, fmareg);
983                     if (unroll_m >= 16) {
984                         fmareg = (i % 2 == 0) ? reg11 : reg23;
985                         fma(useFma, ymm1, ymm2, fmareg);
986                     }
987                 }
988                 if (isTransB) {
989                     prefetcht0(ptr[BO1 + BO2]);
990                     add(BO1, LDB);
991                 }
992
993                 if (i == 0) {
994                     if (unroll_m >= 4) {
995                         if (!isDirect) {
996                             prefetcht0(
997                                     ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]);
998                         } else {
999                             prefetcht0(ptr[AO1 + LDA4]);
1000                         }
1001                     }
1002                 }
1003                 if (i == 1 || i == 2) {
1004                     if (unroll_m >= 8) {
1005                         if (!isDirect) {
1006                             prefetcht0(ptr[AO1
1007                                     + (PREFETCHSIZEA + (2 + 2 * i) * 8)
1008                                             * SIZE]);
1009                         } else {
1010                             prefetcht0(ptr[AO1 + LDA4]);
1011                         }
1012                     }
1013                 }
1014                 if (i == 3) {
1015                     if (!isTransB) {
1016                         sub(BO1, -4 * SIZE);
1017                         if (unroll_n >= 4) {
1018                             sub(BO2, -4 * SIZE);
1019                         }
1020                     }
1021                 }
1022
1023                 if (!isDirect) {
1024                     if (isLoad1Unmasked) {
1025                         vmovups(ymm0,
1026                                 ptr[AO1
1027                                         + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
1028                                                 * SIZE]);
1029                     } else {
1030                         vmaskmovps(
1031                                 ymm0, VMASK,
1032                                 ptr[AO1
1033                                         + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
1034                                                 * SIZE]);
1035                     }
1036                     if (unroll_m >= 16) {
1037                         if (isLoad2Unmasked) {
1038                             vmovups(ymm1, ptr[AO1
1039                                                   + (unroll_m * (i + 1) + 1 * 8
1040                                                             - OFFSET)
1041                                                           * SIZE]);
1042                         } else {
1043                             vmaskmovps(ymm1, VMASK,
1044                                     ptr[AO1
1045                                                + (unroll_m * (i + 1) + 1 * 8
1046                                                          - OFFSET)
1047                                                        * SIZE]);
1048                         }
1049                     }
1050                 }
1051             }
1052
1053             if (!isDirect) {
1054                 sub(AO1, -unroll_m * 4 * SIZE);
1055             }
1056
1057         };
1058
1059         // Inner kernel with k=2
1060         auto innerkernel2 = [&](int unroll_m, int unroll_n,
1061                 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
1062                 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
1063                 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
1064                 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
1065                 Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
1066                 Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
1067                 Ymm reg23) {
1068
1069             Ymm fmareg;
1070
1071             for (int i = 0; i < 2; i++) {
1072                 if (isDirect) {
1073                     if (isLoad1Unmasked) {
1074                         vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1075                     } else {
1076                         vmaskmovps(ymm0, VMASK,
1077                                 ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1078                     }
1079                     if (unroll_m >= 16) {
1080                         if (isLoad2Unmasked) {
1081                             vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1082                         } else {
1083                             vmaskmovps(ymm1, VMASK,
1084                                     ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1085                         }
1086                     }
1087                     add(AO1, LDA);
1088                 }
1089
1090                 if (!isTransB) {
1091                     vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1092                 } else {
1093                     vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1094                 }
1095                 fmareg = (i % 2 == 0) ? reg00 : reg12;
1096                 fma(useFma, ymm0, ymm2, fmareg);
1097                 if (unroll_m >= 16) {
1098                     fmareg = (i % 2 == 0) ? reg06 : reg18;
1099                     fma(useFma, ymm1, ymm2, fmareg);
1100                 }
1101                 if (unroll_n >= 2) {
1102                     if (!isTransB) {
1103                         vbroadcastss(
1104                                 ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]);
1105                     } else {
1106                         vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
1107                     }
1108                     fmareg = (i % 2 == 0) ? reg01 : reg13;
1109                     fma(useFma, ymm0, ymm2, fmareg);
1110                     if (unroll_m >= 16) {
1111                         fmareg = (i % 2 == 0) ? reg07 : reg19;
1112                         fma(useFma, ymm1, ymm2, fmareg);
1113                     }
1114                 }
1115
1116                 if (unroll_n >= 3) {
1117                     if (!isTransB) {
1118                         if (i == 2) {
1119                             prefetcht0(
1120                                     ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
1121                         }
1122                         vbroadcastss(
1123                                 ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]);
1124                     } else {
1125                         vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
1126                     }
1127                     fmareg = (i % 2 == 0) ? reg02 : reg14;
1128                     fma(useFma, ymm0, ymm2, fmareg);
1129                     if (unroll_m >= 16) {
1130                         fmareg = (i % 2 == 0) ? reg08 : reg20;
1131                         fma(useFma, ymm1, ymm2, fmareg);
1132                     }
1133                 }
1134
1135                 if (unroll_n >= 4) {
1136                     if (!isTransB) {
1137                         vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]);
1138                     } else {
1139                         vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
1140                     }
1141                     fmareg = (i % 2 == 0) ? reg03 : reg15;
1142                     fma(useFma, ymm0, ymm2, fmareg);
1143                     if (unroll_m >= 16) {
1144                         fmareg = (i % 2 == 0) ? reg09 : reg21;
1145                         fma(useFma, ymm1, ymm2, fmareg);
1146                     }
1147                 }
1148
1149                 if (unroll_n >= 5) {
1150                     if (!isTransB) {
1151                         vbroadcastss(
1152                                 ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]);
1153                     } else {
1154                         vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
1155                     }
1156                     fmareg = (i % 2 == 0) ? reg04 : reg16;
1157                     fma(useFma, ymm0, ymm2, fmareg);
1158                     if (unroll_m >= 16) {
1159                         fmareg = (i % 2 == 0) ? reg10 : reg22;
1160                         fma(useFma, ymm1, ymm2, fmareg);
1161                     }
1162                 }
1163
1164                 if (unroll_n >= 6) {
1165                     if (!isTransB) {
1166                         vbroadcastss(
1167                                 ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]);
1168                     } else {
1169                         vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
1170                     }
1171                     fmareg = (i % 2 == 0) ? reg05 : reg17;
1172                     fma(useFma, ymm0, ymm2, fmareg);
1173                     if (unroll_m >= 16) {
1174                         fmareg = (i % 2 == 0) ? reg11 : reg23;
1175                         fma(useFma, ymm1, ymm2, fmareg);
1176                     }
1177                 }
1178
1179                 if (isCopy) {
1180                     vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
1181                             ymm0);
1182                     if (unroll_m >= 16) {
1183                         vmovups(ptr[LDA4
1184                                         + (unroll_m * 0 + 1 * 8 - OFFSET)
1185                                                 * SIZE],
1186                                 ymm1);
1187                     }
1188                     sub(LDA4, -unroll_m * SIZE);
1189                 }
1190
1191                 if (!isDirect) {
1192                     if (isLoad1Unmasked) {
1193                         vmovups(ymm0, ptr[AO1
1194                                               + (unroll_m * 1 + 0 * 8 - OFFSET)
1195                                                       * SIZE]);
1196                     } else {
1197                         vmaskmovps(ymm0, VMASK,
1198                                 ptr[AO1
1199                                            + (unroll_m * 1 + 0 * 8 - OFFSET)
1200                                                    * SIZE]);
1201                     }
1202                     if (unroll_m >= 16) {
1203                         if (isLoad2Unmasked) {
1204                             vmovups(ymm1,
1205                                     ptr[AO1
1206                                             + (unroll_m * 1 + 1 * 8 - OFFSET)
1207                                                     * SIZE]);
1208                         } else {
1209                             vmaskmovps(ymm1, VMASK,
1210                                     ptr[AO1
1211                                                + (unroll_m * 1 + 1 * 8 - OFFSET)
1212                                                        * SIZE]);
1213                         }
1214                     }
1215                     sub(AO1, -unroll_m * SIZE);
1216                 }
1217
1218                 if (!isTransB) {
1219                     sub(BO1, -SIZE);
1220                     if (unroll_n >= 4) {
1221                         sub(BO2, -SIZE);
1222                     }
1223                 } else {
1224                     add(BO1, LDB);
1225                 }
1226             }
1227
1228         };
1229
1230         // Inner kernel with k=1
1231         auto innerkernel1 = [&](int unroll_m, int unroll_n,
1232                 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
1233                 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
1234                 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
1235                 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11) {
1236
1237             if (isDirect) {
1238                 if (isLoad1Unmasked) {
1239                     vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1240                 } else {
1241                     vmaskmovps(ymm0, VMASK, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1242                 }
1243                 if (unroll_m >= 16) {
1244                     if (isLoad2Unmasked) {
1245                         vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1246                     } else {
1247                         vmaskmovps(ymm1, VMASK,
1248                                 ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1249                     }
1250                 }
1251                 add(AO1, LDA);
1252             }
1253
1254             if (!isTransB) {
1255                 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1256             } else {
1257                 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1258             }
1259             fma(useFma, ymm0, ymm2, reg00);
1260             if (unroll_m >= 16) {
1261                 fma(useFma, ymm1, ymm2, reg06);
1262             }
1263
1264             if (unroll_n >= 2) {
1265                 if (!isTransB) {
1266                     vbroadcastss(
1267                             ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]);
1268                 } else {
1269                     vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
1270                 }
1271                 fma(useFma, ymm0, ymm2, reg01);
1272                 if (unroll_m >= 16) {
1273                     fma(useFma, ymm1, ymm2, reg07);
1274                 }
1275             }
1276
1277             if (unroll_n >= 3) {
1278                 if (!isTransB) {
1279                     vbroadcastss(
1280                             ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]);
1281                 } else {
1282                     vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
1283                 }
1284                 fma(useFma, ymm0, ymm2, reg02);
1285                 if (unroll_m >= 16) {
1286                     fma(useFma, ymm1, ymm2, reg08);
1287                 }
1288             }
1289
1290             if (unroll_n >= 4) {
1291                 if (!isTransB) {
1292                     vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]);
1293                 } else {
1294                     vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
1295                 }
1296                 fma(useFma, ymm0, ymm2, reg03);
1297                 if (unroll_m >= 16) {
1298                     fma(useFma, ymm1, ymm2, reg09);
1299                 }
1300             }
1301
1302             if (unroll_n >= 5) {
1303                 if (!isTransB) {
1304                     vbroadcastss(
1305                             ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]);
1306                 } else {
1307                     vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
1308                 }
1309                 fma(useFma, ymm0, ymm2, reg04);
1310                 if (unroll_m >= 16) {
1311                     fma(useFma, ymm1, ymm2, reg10);
1312                 }
1313             }
1314
1315             if (unroll_n >= 6) {
1316                 if (!isTransB) {
1317                     vbroadcastss(
1318                             ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]);
1319                 } else {
1320                     vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
1321                 }
1322                 fma(useFma, ymm0, ymm2, reg05);
1323                 if (unroll_m >= 16) {
1324                     fma(useFma, ymm1, ymm2, reg11);
1325                 }
1326             }
1327
1328             if (isCopy) {
1329                 vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
1330                         ymm0);
1331                 if (unroll_m >= 16) {
1332                     vmovups(ptr[LDA4 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE],
1333                             ymm1);
1334                 }
1335                 sub(LDA4, -unroll_m * SIZE);
1336             }
1337
1338             if (!isDirect) {
1339                 if (isLoad1Unmasked) {
1340                     vmovups(ymm0,
1341                             ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]);
1342                 } else {
1343                     vmaskmovps(ymm0, VMASK,
1344                             ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]);
1345                 }
1346                 if (unroll_m >= 16) {
1347                     if (isLoad2Unmasked) {
1348                         vmovups(ymm1, ptr[AO1
1349                                               + (unroll_m * 1 + 1 * 8 - OFFSET)
1350                                                       * SIZE]);
1351                     } else {
1352                         vmaskmovps(ymm1, VMASK,
1353                                 ptr[AO1
1354                                            + (unroll_m * 1 + 1 * 8 - OFFSET)
1355                                                    * SIZE]);
1356                     }
1357                 }
1358                 sub(AO1, -unroll_m * SIZE);
1359             }
1360
1361             if (!isTransB) {
1362                 sub(BO1, -SIZE);
1363                 if (unroll_n >= 4) {
1364                     sub(BO2, -SIZE);
1365                 }
1366             } else {
1367                 add(BO1, LDB);
1368             }
1369
1370         };
1371
1372         // Main kernel; does prefetching and calls innerkernel{1,2,4,8} as
1373         // appropriate
1374         // After calculating results in registers, writes back to C matrix
1375         auto kernel = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1376                 bool isLoad2Unmasked, bool isDirect, bool isCopy, bool useFma,
1377                 Ymm reg00 = Ymm(4), Ymm reg01 = Ymm(5), Ymm reg02 = Ymm(6),
1378                 Ymm reg03 = Ymm(7), Ymm reg04 = Ymm(8), Ymm reg05 = Ymm(9),
1379                 Ymm reg06 = Ymm(10), Ymm reg07 = Ymm(11), Ymm reg08 = Ymm(12),
1380                 Ymm reg09 = Ymm(13), Ymm reg10 = Ymm(14), Ymm reg11 = Ymm(15),
1381                 Ymm reg12 = Ymm(4), Ymm reg13 = Ymm(5), Ymm reg14 = Ymm(6),
1382                 Ymm reg15 = Ymm(7), Ymm reg16 = Ymm(8), Ymm reg17 = Ymm(9),
1383                 Ymm reg18 = Ymm(10), Ymm reg19 = Ymm(11), Ymm reg20 = Ymm(12),
1384                 Ymm reg21 = Ymm(13), Ymm reg22 = Ymm(14), Ymm reg23 = Ymm(15)) {
1385             inLocalLabel();
1386
1387             if (!isDirect) {
1388                 lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]);
1389             } else {
1390                 mov(AO1, A);
1391             }
1392
1393             if (isCopy) {
1394                 lea(LDA4, ptr[rsp + 256 + OFFSET * SIZE]);
1395             } else {
1396                 lea(LDA4, ptr[LDA * 8 + (8 - 1 - OFFSET) * SIZE]);
1397             }
1398
1399             if (isTransB) {
1400                 lea(BO2, ptr[LDB * 4 + (8 - 1 - OFFSET) * SIZE]);
1401                 lea(BO2, ptr[BO2 + LDB * 2]);
1402             }
1403
1404             if (!isDirect) {
1405                 if (isLoad1Unmasked) {
1406                     vmovups(ymm0,
1407                             ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]);
1408                 } else {
1409                     vmaskmovps(ymm0, VMASK,
1410                             ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]);
1411                 }
1412                 if (unroll_m >= 16) {
1413                     if (isLoad2Unmasked) {
1414                         vmovups(ymm1, ptr[AO1
1415                                               + (unroll_m * 0 + 1 * 8 - OFFSET)
1416                                                       * SIZE]);
1417                     } else {
1418                         vmaskmovps(ymm1, VMASK,
1419                                 ptr[AO1
1420                                            + (unroll_m * 0 + 1 * 8 - OFFSET)
1421                                                    * SIZE]);
1422                     }
1423                 }
1424             }
1425
1426             for (int i = 4; i < 10; i++) {
1427                 vxorps(Ymm(i), Ymm(i), Ymm(i));
1428                 vxorps(Ymm(i + 6), Ymm(i + 6), Ymm(i + 6));
1429             }
1430
1431             mov(LL, K);
1432             sar(LL, 3);
1433
1434             sub(LL, SECOND_FETCH);
1435             jle(".kernel13", T_NEAR);
1436             align(16);
1437
1438             L(".kernel12");
1439             innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1440                     isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1441                     reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1442                     reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1443                     reg21, reg22, reg23);
1444             jg(".kernel12", T_NEAR);
1445             align(16);
1446
1447             L(".kernel13");
1448             prefetcht0(ptr[CO1 + (unroll_m - 1) * SIZE]);
1449             if (unroll_n >= 2)
1450                 prefetcht0(ptr[CO1 + LDC + (unroll_m - 1) * SIZE]);
1451             if (unroll_n >= 3)
1452                 prefetcht0(ptr[CO1 + LDC * 2 + (unroll_m - 1) * SIZE]);
1453             if (unroll_n >= 4)
1454                 prefetcht0(ptr[CO2 + (unroll_m - 1) * SIZE]);
1455             if (unroll_n >= 5)
1456                 prefetcht0(ptr[CO2 + LDC + (unroll_m - 1) * SIZE]);
1457             if (unroll_n >= 6)
1458                 prefetcht0(ptr[CO2 + LDC * 2 + (unroll_m - 1) * SIZE]);
1459
1460             add(LL, SECOND_FETCH);
1461             jle(".kernel15", T_NEAR);
1462             align(16);
1463
1464             L(".kernel14");
1465             innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1466                     isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1467                     reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1468                     reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1469                     reg21, reg22, reg23);
1470             jg(".kernel14", T_NEAR);
1471             align(16);
1472
1473             L(".kernel15");
1474             test(K, 4);
1475             jle(".kernel16", T_NEAR);
1476             innerkernel4(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1477                     isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1478                     reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1479                     reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1480                     reg21, reg22, reg23);
1481
1482             L(".kernel16");
1483             test(K, 2);
1484             jle(".kernel17", T_NEAR);
1485             innerkernel2(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1486                     isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1487                     reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1488                     reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1489                     reg21, reg22, reg23);
1490             align(16);
1491
1492             L(".kernel17");
1493             if (unroll_m == 16) {
1494                 if (unroll_n <= 3) {
1495                     vaddps(reg00, reg00, reg12);
1496                     vaddps(reg01, reg01, reg13);
1497                     vaddps(reg02, reg02, reg14);
1498                     vaddps(reg06, reg06, reg18);
1499                     vaddps(reg07, reg07, reg19);
1500                     vaddps(reg08, reg08, reg20);
1501                 }
1502             }
1503
1504             if (unroll_m <= 8) {
1505                 vaddps(reg00, reg00, reg12);
1506                 vaddps(reg01, reg01, reg13);
1507                 vaddps(reg02, reg02, reg14);
1508                 vaddps(reg03, reg03, reg15);
1509                 vaddps(reg04, reg04, reg16);
1510                 vaddps(reg05, reg05, reg17);
1511             }
1512
1513             test(K, 1);
1514             jle(".kernel18", T_NEAR);
1515             innerkernel1(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1516                     isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1517                     reg05, reg06, reg07, reg08, reg09, reg10, reg11);
1518             align(16);
1519
1520             L(".kernel18");
1521             vbroadcastss(VALPHA, ALPHA);
1522
1523             if (isBetaN) {
1524                 vbroadcastss(VBETA, BETA);
1525             }
1526
1527             // Write back the results; all beta and bias cases need to be
1528             // handled
1529             switch (unroll_n) {
1530             case 1: mov(rax, LDC); break;
1531             case 2: lea(rax, ptr[LDC * 2]); break;
1532             case 3: lea(rax, ptr[LDC + LDC * 2]); break;
1533             case 4: lea(rax, ptr[LDC + LDC * 4]); break;
1534             case 5:
1535                 lea(rax, ptr[LDC * 4]);
1536                 add(rax, LDC);
1537                 break;
1538             case 6:
1539                 lea(rax, ptr[LDC + LDC * 2]);
1540                 add(rax, rax);
1541                 break;
1542             }
1543
1544             if (hasBias) {
1545                 mov(BIAS1, BIAS);
1546                 if (isLoad1Unmasked) {
1547                     vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]);
1548                 } else {
1549                     vmaskmovps(VBIAS1, VMASK, ptr[BIAS1 + 0 * SIZE]);
1550                 }
1551             }
1552
1553             for (int i = 0; i < unroll_n; i++) {
1554                 vmulps(Ymm(i + 4), Ymm(i + 4), VALPHA);
1555                 if (!isBeta0) {
1556                     if (isLoad1Unmasked) {
1557                         switch (i) {
1558                         case 0: vmovups(ymm0, ptr[CO1 + 0 * SIZE]); break;
1559                         case 1: vmovups(ymm0, ptr[CO1 + LDC + 0 * SIZE]); break;
1560                         case 2:
1561                             vmovups(ymm0, ptr[CO1 + LDC * 2 + 0 * SIZE]);
1562                             break;
1563                         case 3: vmovups(ymm0, ptr[CO2 + 0 * SIZE]); break;
1564                         case 4: vmovups(ymm0, ptr[CO2 + LDC + 0 * SIZE]); break;
1565                         case 5:
1566                             vmovups(ymm0, ptr[CO2 + LDC * 2 + 0 * SIZE]);
1567                             break;
1568                         }
1569                     } else {
1570                         switch (i) {
1571                         case 0:
1572                             vmaskmovps(ymm0, VMASK, ptr[CO1 + 0 * SIZE]);
1573                             break;
1574                         case 1:
1575                             vmaskmovps(ymm0, VMASK, ptr[CO1 + LDC + 0 * SIZE]);
1576                             break;
1577                         case 2:
1578                             vmaskmovps(
1579                                     ymm0, VMASK, ptr[CO1 + LDC * 2 + 0 * SIZE]);
1580                             break;
1581                         case 3:
1582                             vmaskmovps(ymm0, VMASK, ptr[CO2 + 0 * SIZE]);
1583                             break;
1584                         case 4:
1585                             vmaskmovps(ymm0, VMASK, ptr[CO2 + LDC + 0 * SIZE]);
1586                             break;
1587                         case 5:
1588                             vmaskmovps(
1589                                     ymm0, VMASK, ptr[CO2 + LDC * 2 + 0 * SIZE]);
1590                             break;
1591                         }
1592                     }
1593
1594                     if (!isBetaN) {
1595                         vaddps(Ymm(i + 4), ymm0, Ymm(i + 4));
1596                     } else {
1597                         fma(useFma, VBETA, ymm0, Ymm(i + 4), true);
1598                     }
1599                 }
1600                 if (hasBias) {
1601                     vaddps(Ymm(i + 4), VBIAS1, Ymm(i + 4));
1602                 }
1603                 if (isLoad1Unmasked) {
1604                     switch (i) {
1605                     case 0: vmovups(ptr[CO1 + 0 * SIZE], Ymm(i + 4)); break;
1606                     case 1:
1607                         vmovups(ptr[CO1 + LDC + 0 * SIZE], Ymm(i + 4));
1608                         break;
1609                     case 2:
1610                         vmovups(ptr[CO1 + LDC * 2 + 0 * SIZE], Ymm(i + 4));
1611                         break;
1612                     case 3: vmovups(ptr[CO2 + 0 * SIZE], Ymm(i + 4)); break;
1613                     case 4:
1614                         vmovups(ptr[CO2 + LDC + 0 * SIZE], Ymm(i + 4));
1615                         break;
1616                     case 5:
1617                         vmovups(ptr[CO2 + LDC * 2 + 0 * SIZE], Ymm(i + 4));
1618                         break;
1619                     }
1620                 } else {
1621                     switch (i) {
1622                     case 0:
1623                         vmaskmovps(ptr[CO1 + 0 * SIZE], VMASK, Ymm(i + 4));
1624                         break;
1625                     case 1:
1626                         vmaskmovps(
1627                                 ptr[CO1 + LDC + 0 * SIZE], VMASK, Ymm(i + 4));
1628                         break;
1629                     case 2:
1630                         vmaskmovps(ptr[CO1 + LDC * 2 + 0 * SIZE], VMASK,
1631                                 Ymm(i + 4));
1632                         break;
1633                     case 3:
1634                         vmaskmovps(ptr[CO2 + 0 * SIZE], VMASK, Ymm(i + 4));
1635                         break;
1636                     case 4:
1637                         vmaskmovps(
1638                                 ptr[CO2 + LDC + 0 * SIZE], VMASK, Ymm(i + 4));
1639                         break;
1640                     case 5:
1641                         vmaskmovps(ptr[CO2 + LDC * 2 + 0 * SIZE], VMASK,
1642                                 Ymm(i + 4));
1643                         break;
1644                     }
1645                 }
1646
1647                 if (unroll_m >= 16) {
1648                     // Re-use ymm4 (VBIAS2)
1649                     if (i == 0) {
1650                         if (hasBias) {
1651                             if (isLoad1Unmasked) {
1652                                 vmovups(VBIAS2, ptr[BIAS1 + 8 * SIZE]);
1653                             } else {
1654                                 vmaskmovps(
1655                                         VBIAS2, VMASK, ptr[BIAS1 + 8 * SIZE]);
1656                             }
1657                         }
1658                     }
1659                     vmulps(Ymm(i + 10), Ymm(i + 10), VALPHA);
1660                     if (!isBeta0) {
1661                         if (isLoad2Unmasked) {
1662                             switch (i) {
1663                             case 0: vmovups(ymm0, ptr[CO1 + 8 * SIZE]); break;
1664                             case 1:
1665                                 vmovups(ymm0, ptr[CO1 + LDC + 8 * SIZE]);
1666                                 break;
1667                             case 2:
1668                                 vmovups(ymm0, ptr[CO1 + LDC * 2 + 8 * SIZE]);
1669                                 break;
1670                             case 3: vmovups(ymm0, ptr[CO2 + 8 * SIZE]); break;
1671                             case 4:
1672                                 vmovups(ymm0, ptr[CO2 + LDC + 8 * SIZE]);
1673                                 break;
1674                             case 5:
1675                                 vmovups(ymm0, ptr[CO2 + LDC * 2 + 8 * SIZE]);
1676                                 break;
1677                             }
1678                         } else {
1679                             switch (i) {
1680                             case 0:
1681                                 vmaskmovps(ymm0, VMASK, ptr[CO1 + 8 * SIZE]);
1682                                 break;
1683                             case 1:
1684                                 vmaskmovps(
1685                                         ymm0, VMASK, ptr[CO1 + LDC + 8 * SIZE]);
1686                                 break;
1687                             case 2:
1688                                 vmaskmovps(ymm0, VMASK,
1689                                         ptr[CO1 + LDC * 2 + 8 * SIZE]);
1690                                 break;
1691                             case 3:
1692                                 vmaskmovps(ymm0, VMASK, ptr[CO2 + 8 * SIZE]);
1693                                 break;
1694                             case 4:
1695                                 vmaskmovps(
1696                                         ymm0, VMASK, ptr[CO2 + LDC + 8 * SIZE]);
1697                                 break;
1698                             case 5:
1699                                 vmaskmovps(ymm0, VMASK,
1700                                         ptr[CO2 + LDC * 2 + 8 * SIZE]);
1701                                 break;
1702                             }
1703                         }
1704                         if (!isBetaN) {
1705                             vaddps(Ymm(i + 10), ymm0, Ymm(i + 10));
1706                         } else {
1707                             fma(useFma, VBETA, ymm0, Ymm(i + 10), true);
1708                         }
1709                     }
1710                     if (hasBias) {
1711                         vaddps(Ymm(i + 10), VBIAS2, Ymm(i + 10));
1712                     }
1713                     if (isLoad2Unmasked) {
1714                         switch (i) {
1715                         case 0:
1716                             vmovups(ptr[CO1 + 8 * SIZE], Ymm(i + 10));
1717                             break;
1718                         case 1:
1719                             vmovups(ptr[CO1 + LDC + 8 * SIZE], Ymm(i + 10));
1720                             break;
1721                         case 2:
1722                             vmovups(ptr[CO1 + LDC * 2 + 8 * SIZE], Ymm(i + 10));
1723                             break;
1724                         case 3:
1725                             vmovups(ptr[CO2 + 8 * SIZE], Ymm(i + 10));
1726                             break;
1727                         case 4:
1728                             vmovups(ptr[CO2 + LDC + 8 * SIZE], Ymm(i + 10));
1729                             break;
1730                         case 5:
1731                             vmovups(ptr[CO2 + LDC * 2 + 8 * SIZE], Ymm(i + 10));
1732                             break;
1733                         }
1734                     } else {
1735                         switch (i) {
1736                         case 0:
1737                             vmaskmovps(ptr[CO1 + 8 * SIZE], VMASK, Ymm(i + 10));
1738                             break;
1739                         case 1:
1740                             vmaskmovps(ptr[CO1 + LDC + 8 * SIZE], VMASK,
1741                                     Ymm(i + 10));
1742                             break;
1743                         case 2:
1744                             vmaskmovps(ptr[CO1 + LDC * 2 + 8 * SIZE], VMASK,
1745                                     Ymm(i + 10));
1746                             break;
1747                         case 3:
1748                             vmaskmovps(ptr[CO2 + 8 * SIZE], VMASK, Ymm(i + 10));
1749                             break;
1750                         case 4:
1751                             vmaskmovps(ptr[CO2 + LDC + 8 * SIZE], VMASK,
1752                                     Ymm(i + 10));
1753                             break;
1754                         case 5:
1755                             vmaskmovps(ptr[CO2 + LDC * 2 + 8 * SIZE], VMASK,
1756                                     Ymm(i + 10));
1757                             break;
1758                         }
1759                     }
1760                 }
1761                 if (i == 2)
1762                     add(CO1, rax);
1763             }
1764             if (unroll_n >= 4) {
1765                 add(CO2, rax);
1766             }
1767
1768             // Compute next address of B
1769             if (!isTransB) {
1770                 lea(rax, ptr[K * SIZE]);
1771                 switch (unroll_n) {
1772                 case 1:
1773                     add(BO1, LDB);
1774                     add(BO2, LDB);
1775                     break;
1776                 case 2:
1777                     lea(BO1, ptr[BO1 + LDB * 2]);
1778                     lea(BO2, ptr[BO2 + LDB * 2]);
1779                     break;
1780                 case 3:
1781                     lea(BO1, ptr[BO1 + LDB3]);
1782                     lea(BO2, ptr[BO2 + LDB3]);
1783                     break;
1784                 case 4:
1785                     lea(BO1, ptr[BO1 + LDB * 4]);
1786                     lea(BO2, ptr[BO2 + LDB * 4]);
1787                     break;
1788                 case 5:
1789                     lea(BO1, ptr[BO1 + LDB * 4]);
1790                     add(BO1, LDB);
1791                     lea(BO2, ptr[BO2 + LDB * 4]);
1792                     add(BO2, LDB);
1793                     break;
1794                 case 6:
1795                     lea(BO1, ptr[BO1 + LDB3 * 2]);
1796                     lea(BO2, ptr[BO2 + LDB3 * 2]);
1797                     break;
1798                 }
1799                 sub(BO1, rax);
1800                 sub(BO2, rax);
1801             } else {
1802                 mov(rax, LDB);
1803                 imul(rax, K);
1804                 sub(BO1, rax);
1805                 add(BO1, unroll_n * SIZE);
1806             }
1807
1808             outLocalLabel();
1809         };
1810
1811         auto kernel_16x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1812                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1813             kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1814                     isDirect, isCopy, true);
1815         };
1816
1817         auto kernel_16x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1818                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1819             kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1820                     isDirect, isCopy, true);
1821         };
1822
1823         auto kernel_16x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1824                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1825             kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1826                     isDirect, isCopy, true);
1827         };
1828
1829         auto kernel_16x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1830                 bool isLoad2Unmasked, bool isDirect, bool isCopy,
1831                 bool useFma = true) {
1832             kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1833                     isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
1834                     Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1835                     Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9),
1836                     Ymm(13), Ymm(14), Ymm(15));
1837         };
1838
1839         auto kernel_16x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1840                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1841             kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1842                     isDirect, isCopy, false);
1843         };
1844
1845         auto kernel_16x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1846                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1847             kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1848                     isDirect, isCopy, false);
1849         };
1850
1851         auto kernel_8x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1852                 bool isLoad2Unmasked, bool isDirect, bool isCopy,
1853                 bool useFma = true) {
1854             kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1855                     isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
1856                     Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1857                     Ymm(15), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1858                     Ymm(15));
1859         };
1860
1861         auto kernel_8x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1862                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1863             kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1864                     isDirect, isCopy);
1865         };
1866
1867         auto kernel_8x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1868                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1869             kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1870                     isDirect, isCopy);
1871         };
1872
1873         auto kernel_8x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1874                 bool isLoad2Unmasked, bool isDirect, bool isCopy,
1875                 bool useFma = true) {
1876             kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1877                     isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
1878                     Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1879                     Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9),
1880                     Ymm(13), Ymm(14), Ymm(15));
1881         };
1882
1883         auto kernel_8x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1884                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1885             kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1886                     isDirect, isCopy, false);
1887         };
1888
1889         auto kernel_8x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1890                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1891             kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1892                     isDirect, isCopy, false);
1893         };
1894
1895         // High-level subroutine; does packing if needed, then splits C matrix.
1896         // Operates on chunks of 16 rows, 6 columns at a time (handling tail
1897         // cases appropriately).
1898         // Masking is used for tail cases where M is not divisible by 8.
1899         auto subloop = [&](
1900                 int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) {
1901             inLocalLabel();
1902
1903             if (isTransA) {
1904                 do_pack(unroll_m, isLoad1Unmasked, isLoad2Unmasked);
1905             }
1906
1907             mov(CO1, C);
1908             lea(CO2, ptr[CO1 + LDC * 2]);
1909             add(CO2, LDC);
1910             add(C, unroll_m * SIZE);
1911             mov(BO1, B);
1912             if (!isTransB) {
1913                 lea(BO2, qword[B + LDB3]);
1914             }
1915
1916             if (!isTransA) {
1917                 lea(AA, ptr[A + (unroll_m * 2 - 1 - OFFSET) * SIZE]);
1918                 cmp(M, UNROLL_M);
1919                 jg(".subloop98", T_NEAR);
1920
1921                 mov(AA, ORIG_A);
1922                 lea(AA, ptr[AA + (unroll_m - 1 - OFFSET) * SIZE]);
1923                 L(".subloop98");
1924             }
1925
1926             mov(LL, N);
1927             mov(I, LL);
1928             if (!isTransA) {
1929                 // If N is too small, skip copy operation
1930                 cmp(LL, UNROLL_N * 3);
1931                 jle(".subloop30", T_NEAR);
1932
1933                 // If A is not aligned to cache line
1934                 cmp(FLAG, 0);
1935                 je(".subloop30", T_NEAR);
1936             } else {
1937                 cmp(LL, UNROLL_N);
1938                 jl(".subloop20", T_NEAR);
1939             }
1940             align(16);
1941
1942             if (!isTransA) {
1943                 if (unroll_m == 16) {
1944                     kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1945                             isLoad2Unmasked, true, true);
1946                 } else {
1947                     kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1948                             isLoad2Unmasked, true, true);
1949                 }
1950             } else {
1951                 if (unroll_m == 16) {
1952                     kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1953                             isLoad2Unmasked, false, false);
1954                 } else {
1955                     kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1956                             isLoad2Unmasked, false, false);
1957                 }
1958             }
1959
1960             sub(I, UNROLL_N);
1961             cmp(I, UNROLL_N);
1962             jl(".subloop20", T_NEAR);
1963             align(16);
1964
1965             L(".subloop11");
1966             if (unroll_m == 16) {
1967                 kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1968                         isLoad2Unmasked, false, false);
1969             } else {
1970                 kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, isLoad2Unmasked,
1971                         false, false);
1972             }
1973             sub(I, UNROLL_N);
1974             cmp(I, UNROLL_N);
1975             jge(".subloop11", T_NEAR);
1976             align(16);
1977
1978             L(".subloop20");
1979             cmp(I, 1);
1980             jne(".subloop21", T_NEAR);
1981             if (unroll_m == 16) {
1982                 kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
1983                         false, false);
1984             } else {
1985                 kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, false,
1986                         false);
1987             }
1988             jmp(".subloop99", T_NEAR);
1989             align(16);
1990
1991             L(".subloop21");
1992             cmp(I, 2);
1993             jne(".subloop22", T_NEAR);
1994             if (unroll_m == 16) {
1995                 kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
1996                         false, false);
1997             } else {
1998                 kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, false,
1999                         false);
2000             }
2001             jmp(".subloop99", T_NEAR);
2002             align(16);
2003
2004             L(".subloop22");
2005             cmp(I, 3);
2006             jne(".subloop23", T_NEAR);
2007             if (unroll_m == 16) {
2008                 kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
2009                         false, false);
2010             } else {
2011                 kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, false,
2012                         false);
2013             }
2014             jmp(".subloop99", T_NEAR);
2015             align(16);
2016
2017             L(".subloop23");
2018             cmp(I, 4);
2019             jne(".subloop24", T_NEAR);
2020             if (unroll_m == 16) {
2021                 kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
2022                         false, false);
2023             } else {
2024                 kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, false,
2025                         false);
2026             }
2027             jmp(".subloop99", T_NEAR);
2028             align(16);
2029
2030             L(".subloop24");
2031             cmp(I, 5);
2032             jne(".subloop99", T_NEAR);
2033             if (unroll_m == 16) {
2034                 kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
2035                         false, false);
2036             } else {
2037                 kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, false,
2038                         false);
2039             }
2040             jmp(".subloop99", T_NEAR);
2041             align(16);
2042
2043             if (!isTransA) {
2044                 L(".subloop30");
2045                 cmp(I, UNROLL_N);
2046                 jl(".subloop25", T_NEAR);
2047                 align(16);
2048
2049                 L(".subloop31");
2050                 if (unroll_m == 16) {
2051                     kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
2052                             isLoad2Unmasked, true, false);
2053                 } else {
2054                     kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
2055                             isLoad2Unmasked, true, false);
2056                 }
2057                 sub(I, UNROLL_N);
2058                 cmp(I, UNROLL_N);
2059                 jge(".subloop31", T_NEAR);
2060                 align(16);
2061
2062                 L(".subloop25");
2063                 cmp(I, 1);
2064                 jne(".subloop32", T_NEAR);
2065                 if (unroll_m == 16) {
2066                     kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
2067                             true, false);
2068                 } else {
2069                     kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
2070                             true, false);
2071                 }
2072                 jmp(".subloop99", T_NEAR);
2073                 align(16);
2074
2075                 L(".subloop32");
2076                 cmp(I, 2);
2077                 jne(".subloop33", T_NEAR);
2078                 if (unroll_m == 16) {
2079                     kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
2080                             true, false);
2081                 } else {
2082                     kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
2083                             true, false);
2084                 }
2085                 jmp(".subloop99", T_NEAR);
2086                 align(16);
2087
2088                 L(".subloop33");
2089                 cmp(I, 3);
2090                 jne(".subloop34", T_NEAR);
2091                 if (unroll_m == 16) {
2092                     kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
2093                             true, false);
2094                 } else {
2095                     kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
2096                             true, false);
2097                 }
2098                 jmp(".subloop99", T_NEAR);
2099                 align(16);
2100
2101                 L(".subloop34");
2102                 cmp(I, 4);
2103                 jne(".subloop35", T_NEAR);
2104                 if (unroll_m == 16) {
2105                     kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
2106                             true, false);
2107                 } else {
2108                     kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
2109                             true, false);
2110                 }
2111                 jmp(".subloop99", T_NEAR);
2112                 align(16);
2113
2114                 L(".subloop35");
2115                 cmp(I, 5);
2116                 jne(".subloop99", T_NEAR);
2117                 if (unroll_m == 16) {
2118                     kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
2119                             true, false);
2120                 } else {
2121                     kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
2122                             true, false);
2123                 }
2124                 align(16);
2125             }
2126
2127             L(".subloop99");
2128             // Compute address for A
2129             if (!isTransA) {
2130                 add(A, unroll_m * SIZE);
2131             } else {
2132                 mov(rax, LDA);
2133                 imul(rax, rax, unroll_m);
2134                 add(A, rax);
2135             }
2136
2137             // Compute next address of BIAS
2138             if (hasBias) {
2139                 add(BIAS, unroll_m * SIZE);
2140             }
2141
2142             outLocalLabel();
2143         };
2144
2145         inLocalLabel();
2146
2147         preamble();
2148
2149         // Get the registers
2150         mov(B, ARG_B);
2151         mov(LDB, ARG_LDB);
2152         mov(r15, ARG_BETA);
2153         mov(r12, ARG_C);
2154         if (hasBias)
2155             mov(r10, ARG_BIAS);
2156         mov(LDC, ARG_LDC);
2157         mov(rbp, rsp);
2158
2159         vmovss(xmm0, ptr[ARG_ALPHA]);
2160         vmovss(xmm1, ptr[r15]);
2161
2162 #if _WIN32
2163         mov(A, ARG_A);
2164         mov(LDA, ARG_LDA);
2165 #endif
2166
2167         cmp(K, STACK_K_CAPACITY);
2168         jg(".buffer_in_ws", T_NEAR);
2169
2170         // Create buffer and align to 4kB page
2171         lea(rax, ptr[K * SIZE]);
2172         sal(rax, 4);
2173         add(rax, 256);
2174         sub(rsp, rax);
2175         and_(rsp, -PAGE_4K);
2176         jmp(".buffer_allocated", T_NEAR);
2177
2178         L(".buffer_in_ws");
2179         mov(rsp, ARG_WS);
2180
2181         L(".buffer_allocated");
2182
2183         mov(ORIG_SP, rbp);
2184         mov(M, ARG_M);
2185         mov(N, ARG_N);
2186         mov(C, r12);
2187         if (hasBias)
2188             mov(BIAS, r10);
2189         vmovss(ALPHA, xmm0);
2190         vmovss(BETA, xmm1);
2191         sub(A, -OFFSET * SIZE);
2192         sub(B, -OFFSET * SIZE);
2193         mov(ORIG_A, A);
2194         sal(LDA, BASE_SHIFT);
2195         sal(LDB, BASE_SHIFT);
2196         sal(LDC, BASE_SHIFT);
2197         lea(LDB3, ptr[LDB + LDB * 2]);
2198
2199         for (int i = 0; i < 8; i++) {
2200             mov(dword[rsp + 88 + i * 4], i);
2201         }
2202
2203         if (isTransA && is_avx2) {
2204             movq(xmm0, LDA);
2205             vpbroadcastq(ymm1, xmm0);
2206             vinsertf128(ymm0, ymm0, xmm0, 1);
2207             vpermilpd(ymm0, ymm0, 5);
2208             vpaddq(ymm1, ymm1, ymm1);
2209             vperm2f128(ymm1, ymm1, ymm1, 8);
2210             vpaddq(ymm0, ymm0, ymm1);
2211             vmovups(STRIDE, ymm0);
2212         }
2213
2214         // Check A alignment and leading dimension; take copy-based path as
2215         // needed
2216         mov(rax, LDA);
2217         or_(rax, A);
2218         and_(rax, 0x1f);
2219         mov(FLAG, rax);
2220
2221         cmp(M, UNROLL_M);
2222         jl(".main0", T_NEAR);
2223         align(16);
2224
2225         L(".main1");
2226         subloop(UNROLL_M, true, true);
2227         sub(M, UNROLL_M);
2228         cmp(M, UNROLL_M);
2229         jge(".main1", T_NEAR);
2230         align(16);
2231
2232         L(".main0");
2233         cmp(M, 0);
2234         jle(".main999", T_NEAR);
2235
2236         if (UNROLL_M > 8) {
2237             cmp(M, 8);
2238             jle(".main2", T_NEAR);
2239
2240             sub(M, 8);
2241             vbroadcastss(VMASK, M);
2242             vpcmpgtd(VMASK, VMASK, MASK);
2243
2244             subloop(16, true, false);
2245             jmp(".main999", T_NEAR);
2246             align(16);
2247
2248             L(".main2");
2249             cmp(M, 8);
2250             jne(".main3", T_NEAR);
2251             subloop(8, true, true);
2252             jmp(".main999", T_NEAR);
2253         }
2254
2255         align(16);
2256
2257         L(".main3");
2258         vbroadcastss(VMASK, M);
2259         if (is_avx2) {
2260             vpcmpgtd(VMASK, VMASK, MASK);
2261         } else {
2262             auto xmask = Xmm(VMASK.getIdx());
2263             auto xmm_tmp = xmm4;
2264
2265             vextractf128(xmm_tmp, VMASK, 1);
2266             vpcmpgtd(xmask, xmask, MASK);
2267             vpcmpgtd(xmm_tmp, xmm_tmp, dword[rsp + 88 + 4 * 4]); // MASK + 4
2268             vinsertf128(VMASK, VMASK, xmm_tmp, 1);
2269         }
2270         subloop(8, false, false);
2271         align(16);
2272
2273         L(".main999");
2274         // Restore original stack
2275         mov(rax, ORIG_SP);
2276         mov(rsp, rax);
2277
2278         vzeroupper();
2279         postamble();
2280
2281         outLocalLabel();
2282
2283         ker_ = reinterpret_cast<decltype(ker_)>(
2284                 const_cast<uint8_t *>(this->getCode()));
2285     }
2286
2287     void operator()(long long int m, long long int n, long long int k,
2288             const float *alpha, const float *a, long long int lda,
2289             const float *b, long long int ldb, const float *beta, float *c,
2290             long long int ldc, const float *bias, float *ws)
2291     {
2292         (*ker_)(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws);
2293     }
2294
private:
    // Entry point of the generated code. Assigned via reinterpret_cast from
    // getCode() at the end of the constructor; its parameter list matches
    // operator(), which simply forwards to it.
    void (*ker_)(long long int m, long long int n, long long int k,
            const float *alpha, const float *a, long long int lda,
            const float *b, long long int ldb, const float *beta, float *c,
            long long int ldc, const float *bias, float *ws);
2300 };
2301
// Function-pointer type for a GEMM kernel entry point. Its parameter order
// appears to mirror xbyak_gemm::ker_ (m, n, k, alpha, a, lda, b, ldb, beta,
// c, ldc, bias).
// NOTE(review): not referenced in this part of the file; unlike ker_ it has
// no trailing `float *ws` parameter and takes non-const pointers. Confirm
// whether it is still used anywhere.
using ker = void (*)(long long int, long long int, long long int, float *,
        float *, long long int, float *, long long int, float *, float *,
        long long int, float *);
2305 void jit_avx_gemm_f32::sgemm_nocopy_driver(const char *transa,
2306         const char *transb, int m, int n, int k, const float *alpha,
2307         const float *a, int lda, const float *b, int ldb, const float *beta,
2308         float *c, int ldc, const float *bias, float *ws)
2309 {
2310     bool isTransA = (*transa == 'T' || *transa == 't');
2311     bool isTransB = (*transb == 'T' || *transb == 't');
2312
2313     int Bm, sizeM, Bn, sizeN, Bk, sizeK;
2314
2315     int i, j;
2316
2317     if ((m <= 0) || (n <= 0))
2318         return;
2319
2320     if ((k <= 0) || (alpha[0] == 0.)) {
2321
2322         if (beta[0] == 0.) {
2323             for (j = 0; j < n; j++)
2324                 for (i = 0; i < m; i++)
2325                     c[i + j * ldc] = 0.0;
2326         } else if (beta[0] != 1.) {
2327             for (j = 0; j < n; j++)
2328                 for (i = 0; i < m; i++)
2329                     c[i + j * ldc] *= beta[0];
2330         }
2331
2332         return;
2333     }
2334
2335     int BM = 4032;
2336     int BN = isTransA ? 96 : 48;
2337     int BK = isTransB ? 96 : 256;
2338     const float *curA, *curB, *curBias = NULL;
2339     float *curC;
2340
2341     for (Bk = 0; Bk < k; Bk += sizeK) {
2342         sizeK = k - Bk;
2343         if (sizeK >= BK * 2)
2344             sizeK = BK;
2345         else {
2346             if (sizeK > BK)
2347                 sizeK = (sizeK + 1) / 2;
2348         }
2349
2350         for (Bm = 0; Bm < m; Bm += sizeM) {
2351             sizeM = m - Bm;
2352             if (sizeM >= BM * 2)
2353                 sizeM = BM;
2354             else {
2355                 if (sizeM > BM + BM / 2)
2356                     sizeM = (sizeM + 1) / 2;
2357             }
2358
2359             for (Bn = 0; Bn < n; Bn += sizeN) {
2360                 sizeN = n - Bn;
2361                 if (sizeN >= BN * 2)
2362                     sizeN = BN;
2363                 else {
2364                     if (sizeN > BN + BN / 2)
2365                         sizeN = (sizeN + 1) / 2;
2366                 }
2367
2368                 if (!isTransA) {
2369                     curA = a + Bm + (size_t)Bk * lda;
2370                 } else {
2371                     curA = a + Bk + (size_t)Bm * lda;
2372                 }
2373                 if (!isTransB) {
2374                     curB = b + Bk + (size_t)Bn * ldb;
2375                 } else {
2376                     curB = b + Bn + (size_t)Bk * ldb;
2377                 }
2378                 curC = c + Bm + (size_t)Bn * ldc;
2379                 if (bias != NULL) {
2380                     if (Bk == 0) {
2381                         curBias = bias + Bm;
2382                     } else {
2383                         curBias = NULL;
2384                     }
2385                 }
2386                 if (Bk == 0) {
2387                     if (*beta == 0.0 && bias == NULL)
2388                         (*ker_b0_)((long long int)sizeM, (long long int)sizeN,
2389                                 (long long int)sizeK, alpha, curA,
2390                                 (long long int)lda, curB, (long long int)ldb,
2391                                 beta, curC, (long long int)ldc, curBias, ws);
2392                     else
2393                         (*ker_bn_)((long long int)sizeM, (long long int)sizeN,
2394                                 (long long int)sizeK, alpha, curA,
2395                                 (long long int)lda, curB, (long long int)ldb,
2396                                 beta, curC, (long long int)ldc, curBias, ws);
2397                 } else {
2398                     (*ker_b1_)((long long int)sizeM, (long long int)sizeN,
2399                             (long long int)sizeK, alpha, curA,
2400                             (long long int)lda, curB, (long long int)ldb, beta,
2401                             curC, (long long int)ldc, curBias, ws);
2402                 }
2403             }
2404         }
2405     }
2406     return;
2407 }
2408 void jit_avx_gemm_f32::sgemm(const char *transa, const char *transb,
2409         const int *p_m, const int *p_n, const int *p_k, const float *p_alpha,
2410         const float *A, const int *p_lda, const float *B, const int *p_ldb,
2411         const float *p_beta, float *C, const int *p_ldc, const float *bias)
2412 {
2413     if (beta_ == 0. || beta_ == 1.)
2414         assert(*p_beta == beta_);
2415     assert((one_of(*transa, 'T', 't') == one_of(transa_, 'T', 't')));
2416
2417     int nthr = mkldnn_in_parallel() ? 1 : mkldnn_get_max_threads();
2418     int m = *p_m;
2419     int n = *p_n;
2420     int k = *p_k;
2421     int lda = *p_lda;
2422     int ldb = *p_ldb;
2423     int ldc = *p_ldc;
2424     float beta = *p_beta;
2425     int MB, NB, KB;
2426
2427     int nthr_m, nthr_n, nthr_k, nthr_mn;
2428
2429     assert(nthr <= nthrs_);
2430
2431     // Determine threading partitioning
2432     gemm_utils::calc_nthr_nocopy_avx(
2433             m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
2434     assert(utils::implication(!mkldnn_thr_syncable(), nthr_k == 1));
2435
2436     // May not happen, but just in case
2437     if (nthr < nthr_m * nthr_n * nthr_k)
2438         nthr = nthr_m * nthr_n * nthr_k;
2439
2440     nthr_mn = nthr_m * nthr_n;
2441
2442     unsigned int volatile *ompstatus = (unsigned int volatile *)ompstatus_;
2443     if (!ompstatus) return;
2444
2445     float *c_buffers = NULL;
2446     float *ws_buffers = NULL;
2447
2448     if (nthr_k > 1) {
2449         for (int i = 0; i < nthr; i++)
2450             ompstatus[i * CACHE_LINE_SIZE] = 0;
2451
2452         c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB
2453                 * sizeof(float), PAGE_4K);
2454     }
2455
2456     const size_t ws_elems_per_thr = k * 16 + 64;
2457     const size_t ws_size_per_thr
2458             = utils::rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
2459     if (k > STACK_K_CAPACITY) {
2460         ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K);
2461     }
2462
2463     parallel(nthr, [&](const int ithr, const int nthr) {
2464         int ithr_m, ithr_n, ithr_k, ithr_mn;
2465         int m_from, m_to, myM;
2466         int n_from, n_to, myN;
2467         int k_from, k_to, myK;
2468         int cbase, ibase;
2469         const float *myA, *myB, *myBias = NULL;
2470         float *myC = C, myBeta;
2471         float *ws = ws_buffers ?
2472                 ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0;
2473         int ld = ldc;
2474
2475         if (ithr < nthr_m * nthr_n * nthr_k) {
2476
2477             ithr_mn = ithr % nthr_mn;
2478             ithr_m = ithr_mn % nthr_m;
2479             ithr_n = ithr_mn / nthr_m;
2480             ithr_k = ithr / nthr_mn;
2481
2482             /* swap ithr_k for performance improvement */
2483             if (ithr_k == 0)
2484                 ithr_k = nthr_k - 1;
2485             else if (ithr_k == nthr_k - 1)
2486                 ithr_k = 0;
2487
2488             m_from = MB * (ithr_m);
2489             m_to = MB * (ithr_m + 1);
2490             if (m_to > m)
2491                 m_to = m;
2492             myM = m_to - m_from;
2493
2494             n_from = NB * (ithr_n);
2495             n_to = NB * (ithr_n + 1);
2496             if (n_to > n)
2497                 n_to = n;
2498             myN = n_to - n_from;
2499
2500             k_from = KB * (ithr_k);
2501             k_to = KB * (ithr_k + 1);
2502             if (k_to > k)
2503                 k_to = k;
2504             myK = k_to - k_from;
2505
2506             cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
2507             ibase = (ithr_m + nthr_m * ithr_n) * nthr_k;
2508
2509             if ((myM > 0) && (myN > 0)) {
2510
2511                 if (*transa == 'N' || *transa == 'n') {
2512                     myA = &(A[m_from + k_from * lda]);
2513                 } else {
2514                     myA = &(A[k_from + m_from * lda]);
2515                 }
2516                 if (*transb == 'N' || *transb == 'n') {
2517                     myB = &(B[k_from + n_from * ldb]);
2518                 } else {
2519                     myB = &(B[n_from + k_from * ldb]);
2520                 }
2521                 if (ithr_k == 0) {
2522                     myC = &(C[m_from + n_from * ldc]);
2523                     myBeta = beta;
2524                     ld = ldc;
2525                     if (hasBias_)
2526                         myBias = &(bias[m_from]);
2527                 } else {
2528                     myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
2529                     myBeta = 0.0;
2530                     ld = MB;
2531                     myBias = NULL;
2532                 }
2533
2534                 sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA,
2535                         lda, myB, ldb, &myBeta, myC, ld, myBias, ws);
2536
2537                 if (nthr_k > 1)
2538                     ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1;
2539             }
2540
2541             if (nthr_k > 1) {
2542
2543                 // sum matrices partitioned along K dimension
2544                 int n1, n2;
2545
2546                 gemm_utils::partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
2547
2548                 if (ithr_k > 0) {
2549
2550                     myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
2551                     myC = myC + n1 * MB;
2552                     /* need to wait until main thread finishes */
2553                     while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) {
2554                     };
2555
2556                     /* my cache is hot */
2557                     gemm_utils::sum_two_matrices(myM, n2, myC, MB,
2558                             &C[m_from + (n_from + n1) * ldc], ldc);
2559                 }
2560
2561                 for (int ik = 1; ik < nthr_k; ++ik) {
2562                     if (ik != ithr_k) {
2563
2564                         myC = c_buffers + MB * NB * (cbase + ik - 1);
2565                         myC = myC + n1 * MB;
2566
2567                         while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) {
2568                         };
2569
2570                         gemm_utils::sum_two_matrices(myM, n2, myC, MB,
2571                                 &C[m_from + (n_from + n1) * ldc], ldc);
2572                     }
2573                 }
2574             }
2575         }
2576     });
2577
2578     if (nthr_k > 1)
2579         free(c_buffers);
2580     free(ws_buffers);
2581 }
2582
2583 jit_avx_gemm_f32::jit_avx_gemm_f32(
2584         char transa, char transb, float beta, bool hasBias)
2585 {
2586     transa_ = transa;
2587     transb_ = transb;
2588     beta_ = beta;
2589     hasBias_ = hasBias;
2590     if (hasBias) {
2591         assert(beta == 0.0);
2592     }
2593     ker_bn_ = new xbyak_gemm(transa, transb, beta, hasBias);
2594     if (beta != 1.0) {
2595         ker_b1_ = new xbyak_gemm(transa, transb, 1.0);
2596     } else {
2597         ker_b1_ = ker_bn_;
2598     }
2599     if (beta != 0.0 || (beta == 0.0 && hasBias)) {
2600         ker_b0_ = new xbyak_gemm(transa, transb, 0.0);
2601     } else {
2602         ker_b0_ = ker_bn_;
2603     }
2604     nthrs_ = mkldnn_get_max_threads();
2605     ompstatus_ = (unsigned int *)malloc(
2606         sizeof(unsigned int *) * nthrs_ * CACHE_LINE_SIZE, 64);
2607     assert(ompstatus_);
2608 }
2609
2610 jit_avx_gemm_f32::~jit_avx_gemm_f32()
2611 {
2612     delete ker_bn_;
2613     if (beta_ != 1.0)
2614         delete ker_b1_;
2615     if (beta_ != 0.0 || (beta_ == 0.0 && hasBias_))
2616         delete ker_b0_;
2617     free(ompstatus_);
2618 }
2619
2620 }
2621 }
2622 }
2623
2624 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s