Add a section of how to link IE with CMake project (#99)
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm / jit_avx_gemm_f32.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <math.h>
18
19 #include "mkldnn_thread.hpp"
20 #include "utils.hpp"
21 #include "gemm_utils.hpp"
22 #include "jit_avx_gemm_f32.hpp"
23
24 #define CACHE_LINE_SIZE 64
25
26 namespace mkldnn {
27 namespace impl {
28 namespace cpu {
29
30 using namespace mkldnn::impl::memory_format;
31 using namespace mkldnn::impl::utils;
32
33 using namespace Xbyak;
// Size in bytes of the jit_generator register-save area (presumably the ABI
// callee-saved registers -- confirm against get_size_of_abi_save_regs()).
// Offsets of stack-passed kernel arguments are computed past this area.
#define STACKSIZE get_size_of_abi_save_regs()
// NOTE(review): presumably the largest K extent handled with stack-resident
// scratch; deliberately smaller on Windows. Not referenced in this part of
// the file -- confirm against its uses. Uses #ifdef (like the argument-layout
// check in the kernel) so non-Windows builds do not trip -Wundef.
#ifdef _WIN32
#define STACK_K_CAPACITY 128
#else
#define STACK_K_CAPACITY 8192
#endif
// Bytes per matrix element (single-precision float).
#define SIZE 4
// Element bias used in generated addresses: operands are written as
// base + (index - OFFSET) * SIZE, and the packing buffer pointer is set
// OFFSET * SIZE bytes past its start to compensate.
#define OFFSET 32
// log2(SIZE).
#define BASE_SHIFT 2
// NOTE(review): tuning constant, presumably a prefetch distance; not
// referenced in this part of the file.
#define SECOND_FETCH 14
44
45 struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
46     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_gemm_f32_xbyak_gemm)
47
48     xbyak_gemm(char transa, char transb, float beta, bool hasBias = false,
49             void *code_ptr = nullptr,
50             size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE)
51         : jit_generator(code_ptr, code_size)
52     {
53         const bool is_avx2 = mayiuse(avx2);
54         assert(IMPLICATION(!is_avx2, mayiuse(avx)));
55
56         const int UNROLL_M = is_avx2 ? 16 : 8;
57         const int UNROLL_N = 6;
58
59         bool isTransA = (transa == 'T' || transa == 't');
60         bool isTransB = (transb == 'T' || transb == 't');
61         bool isBeta0 = (beta == 0.0);
62         bool isBetaN = (!isBeta0 && beta != 1.0);
63
64         // various definitions for convenience
65         auto ARG_M = abi_param1;
66         auto ARG_N = abi_param2;
67         auto K = abi_param3;
68         auto ARG_ALPHA = abi_param4;
69 #ifdef _WIN32
70         auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE];
71         auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE +
72             sizeof(float *) + STACKSIZE];
73         const auto stackOffset = OFFSET_SHADOWSPACE +
74             sizeof(float *) + STACKSIZE;
75         auto A = rsi;
76         auto LDA = rdi;
77 #else
78         auto ARG_A = r8;
79         auto ARG_LDA = r9;
80         const auto stackOffset = STACKSIZE;
81         auto A = ARG_A;
82         auto LDA = ARG_LDA;
83 #endif
84         auto ARG_B = ptr[rsp + 8 + stackOffset];
85         auto ARG_LDB = ptr[rsp + 16 + stackOffset];
86         auto ARG_BETA = ptr[rsp + 24 + stackOffset];
87         auto ARG_C = ptr[rsp + 32 + stackOffset];
88         auto ARG_LDC = ptr[rsp + 40 + stackOffset];
89         auto ARG_BIAS = ptr[rsp + 48 + stackOffset];
90         auto ARG_WS = ptr[rsp + 56 + stackOffset];
91
92         auto B = r11;
93         auto LDB = rbx;
94         auto LDC = r13;
95         auto LL = rax;
96         auto AO1 = abi_param2;
97         auto BO1 = abi_param4;
98         auto BO2 = rbp;
99         auto CO1 = r14;
100         auto CO2 = r15;
101         auto LDB3 = r10;
102         auto LDA4 = abi_param1;
103         auto AA = r12;
104         auto BIAS1 = abi_param1;
105
106         auto M = qword[rsp + 0];
107         auto N = qword[rsp + 8];
108         auto FLAG = qword[rsp + 16];
109         auto I = qword[rsp + 24];
110         auto C = qword[rsp + 32];
111         auto BIAS = qword[rsp + 40];
112         auto ALPHA = qword[rsp + 48];
113         auto BETA = qword[rsp + 64];
114         auto ORIG_A = qword[rsp + 80];
115         auto MASK = dword[rsp + 88];
116         auto STRIDE = qword[rsp + 120];
117         auto ORIG_SP = qword[rsp + 152];
118
119         auto VALPHA = ymm1;
120         auto VBETA = ymm2;
121         auto VMASK = ymm3;
122         auto VBIAS1 = ymm2;
123         auto VBIAS2 = ymm4;
124
125         auto PREFETCHSIZEA = 128;
126         auto PREFETCHSIZEB = (!isTransB) ? -16 : 0;
127
128         // Function for packing if needed
129         auto do_pack = [&](
130                 int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) {
131             Label pack2, pack3, pack4, pack10;
132
133             int regIdx;
134             Reg64 reg;
135
136             mov(BO1, A);
137             lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]);
138
139             if (isTransA) {
140                 lea(BO2, ptr[BO1 + LDA * 4]);
141                 lea(CO1, ptr[LDA + LDA * 2]);
142                 vmovupd(ymm7, STRIDE);
143             }
144
145             mov(LL, K);
146             sar(LL, 2);
147             jle(pack3, T_NEAR);
148             align(16);
149
150             L(pack2);
151             if (!isTransA) {
152                 for (int i = 0; i < 4; i++) {
153                     regIdx = (i % 2 == 0) ? 4 : 6;
154                     if (isLoad1Unmasked) {
155                         vmovups(Ymm(regIdx),
156                                 ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
157                     } else {
158                         vmaskmovps(Ymm(regIdx), VMASK,
159                                 ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
160                     }
161                     if (unroll_m > 8) {
162                         if (isLoad2Unmasked) {
163                             vmovups(Ymm(regIdx + 1),
164                                     ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
165                         } else {
166                             vmaskmovps(Ymm(regIdx + 1), VMASK,
167                                     ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
168                         }
169                     }
170                     add(BO1, LDA);
171
172                     vmovups(ptr[AO1 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
173                             Ymm(regIdx));
174                     if (unroll_m > 8) {
175                         vmovups(ptr[AO1
176                                         + (unroll_m * i + 1 * 8 - OFFSET)
177                                                 * SIZE],
178                                 Ymm(regIdx + 1));
179                     }
180                 }
181
182             } else {
183                 if (isLoad1Unmasked) {
184                     for (int i = 0; i < 2; i++) {
185                         reg = (i % 2 == 0) ? BO1 : BO2;
186                         vmovups(xmm0, ptr[reg + (0 * 8 - OFFSET) * SIZE]);
187                         vmovups(xmm1,
188                                 ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
189                         lea(BO2, ptr[reg + LDA * 2]);
190                         vunpcklps(xmm4, xmm0, xmm1);
191                         vunpckhps(xmm5, xmm0, xmm1);
192                         vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
193                         vmovups(xmm1,
194                                 ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
195                         lea(BO2, ptr[BO2 + LDA * 2]);
196                         vunpcklps(xmm6, xmm0, xmm1);
197                         vunpckhps(xmm2, xmm0, xmm1);
198
199                         vunpcklpd(xmm0, xmm4, xmm6);
200                         vunpckhpd(xmm1, xmm4, xmm6);
201                         vmovups(ptr[AO1
202                                         + (unroll_m * 0 + i * 4 - OFFSET)
203                                                 * SIZE],
204                                 xmm0);
205                         vmovups(ptr[AO1
206                                         + (unroll_m * 1 + i * 4 - OFFSET)
207                                                 * SIZE],
208                                 xmm1);
209                         vunpcklpd(xmm0, xmm5, xmm2);
210                         vunpckhpd(xmm1, xmm5, xmm2);
211                         vmovups(ptr[AO1
212                                         + (unroll_m * 2 + i * 4 - OFFSET)
213                                                 * SIZE],
214                                 xmm0);
215                         vmovups(ptr[AO1
216                                         + (unroll_m * 3 + i * 4 - OFFSET)
217                                                 * SIZE],
218                                 xmm1);
219                     }
220                 } else if (is_avx2) {
221                     for (int i = 0; i < 2; i++) {
222                         vmovaps(xmm4, xmm3);
223                         vgatherqps(xmm0,
224                                 ptr[BO1 + ymm7 + ((2 * i) - OFFSET) * SIZE],
225                                 xmm4);
226                         vmovaps(xmm4, xmm3);
227                         vgatherqps(xmm1,
228                                 ptr[BO1 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE],
229                                 xmm4);
230
231                         vmovups(ptr[AO1
232                                         + (unroll_m * (2 * i) + 0 * 4 - OFFSET)
233                                                 * SIZE],
234                                 xmm0);
235                         vmovups(ptr[AO1
236                                         + (unroll_m * (2 * i + 1) + 0 * 4
237                                                   - OFFSET)
238                                                 * SIZE],
239                                 xmm1);
240                     }
241
242                     lea(BO2, ptr[BO1 + LDA * 4]);
243
244                     for (int i = 0; i < 2; i++) {
245                         vextractf128(xmm4, ymm3, 1);
246                         vgatherqps(xmm0,
247                                 ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
248                                 xmm4);
249                         vextractf128(xmm4, ymm3, 1);
250                         vgatherqps(xmm1,
251                                 ptr[BO2 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE],
252                                 xmm4);
253
254                         vmovups(ptr[AO1
255                                         + (unroll_m * (2 * i) + 1 * 4 - OFFSET)
256                                                 * SIZE],
257                                 xmm0);
258                         vmovups(ptr[AO1
259                                         + (unroll_m * (2 * i + 1) + 1 * 4
260                                                   - OFFSET)
261                                                 * SIZE],
262                                 xmm1);
263                     }
264
265                     lea(BO2, ptr[BO2 + LDA * 4]);
266                 } else {
267                     vxorps(xmm4, xmm4, xmm4);
268                     lea(BO2, ptr[BO1 + LDA * 4]);
269
270                     auto el_cp = [&](int section, int ld_step) {
271                         RegExp src_addr = section == 0 ? BO1 : BO2;
272                         if (ld_step == 1 || ld_step == 2)
273                             src_addr = src_addr + LDA * ld_step;
274                         else if (ld_step == 3)
275                             src_addr = src_addr + CO1;
276                         src_addr = src_addr - OFFSET * SIZE;
277
278                         vmovups(Xmm(ld_step % 2), ptr[src_addr]);
279                         RegExp dst_addr = AO1
280                             + (ld_step + section * 4 - OFFSET) * SIZE;
281                         for (int off = 0; off < 4; ++off)
282                             pextrd(ptr[dst_addr + unroll_m * off * SIZE],
283                                     Xmm(ld_step % 2), off);
284                     };
285
286                     Label l_end;
287                     el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR);
288                     el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR);
289                     el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR);
290                     el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR);
291                     el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR);
292                     el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR);
293                     el_cp(1, 2);
294                     L(l_end);
295
296                     lea(BO2, ptr[BO2 + LDA * 4]);
297                 }
298
299                 if (unroll_m >= 16) {
300                     assert(is_avx2);
301                     if (isLoad2Unmasked) {
302                         for (int i = 0; i < 2; i++) {
303                             vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
304                             vmovups(xmm1, ptr[BO2 + LDA * 1
305                                                   + (0 * 8 - OFFSET) * SIZE]);
306                             lea(BO2, ptr[BO2 + LDA * 2]);
307                             vunpcklps(xmm4, xmm0, xmm1);
308                             vunpckhps(xmm5, xmm0, xmm1);
309                             vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
310                             vmovups(xmm1, ptr[BO2 + LDA * 1
311                                                   + (0 * 8 - OFFSET) * SIZE]);
312                             if (i == 0)
313                                 lea(BO2, ptr[BO2 + LDA * 2]);
314                             vunpcklps(xmm6, xmm0, xmm1);
315                             vunpckhps(xmm2, xmm0, xmm1);
316
317                             vunpcklpd(xmm0, xmm4, xmm6);
318                             vunpckhpd(xmm1, xmm4, xmm6);
319                             vmovups(ptr[AO1
320                                             + (unroll_m * 0 + (i + 2) * 4
321                                                       - OFFSET)
322                                                     * SIZE],
323                                     xmm0);
324                             vmovups(ptr[AO1
325                                             + (unroll_m * 1 + (i + 2) * 4
326                                                       - OFFSET)
327                                                     * SIZE],
328                                     xmm1);
329                             vunpcklpd(xmm0, xmm5, xmm2);
330                             vunpckhpd(xmm1, xmm5, xmm2);
331                             vmovups(ptr[AO1
332                                             + (unroll_m * 2 + (i + 2) * 4
333                                                       - OFFSET)
334                                                     * SIZE],
335                                     xmm0);
336                             vmovups(ptr[AO1
337                                             + (unroll_m * 3 + (i + 2) * 4
338                                                       - OFFSET)
339                                                     * SIZE],
340                                     xmm1);
341                         }
342                     } else {
343                         for (int i = 0; i < 2; i++) {
344                             vmovaps(xmm4, xmm3);
345                             vgatherqps(xmm0,
346                                     ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
347                                     xmm4);
348                             vmovaps(xmm4, xmm3);
349                             vgatherqps(xmm1,
350                                     ptr[BO2 + ymm7
351                                                + ((2 * i + 1) - OFFSET) * SIZE],
352                                     xmm4);
353
354                             vmovups(ptr[AO1
355                                             + (unroll_m * (2 * i) + 2 * 4
356                                                       - OFFSET)
357                                                     * SIZE],
358                                     xmm0);
359                             vmovups(ptr[AO1
360                                             + (unroll_m * (2 * i + 1) + 2 * 4
361                                                       - OFFSET)
362                                                     * SIZE],
363                                     xmm1);
364                         }
365
366                         lea(BO2, ptr[BO2 + LDA * 4]);
367
368                         for (int i = 0; i < 2; i++) {
369                             vextractf128(xmm4, ymm3, 1);
370                             vgatherqps(xmm0,
371                                     ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
372                                     xmm4);
373                             vextractf128(xmm4, ymm3, 1);
374                             vgatherqps(xmm1,
375                                     ptr[BO2 + ymm7
376                                                + ((2 * i + 1) - OFFSET) * SIZE],
377                                     xmm4);
378
379                             vmovups(ptr[AO1
380                                             + (unroll_m * (2 * i) + 3 * 4
381                                                       - OFFSET)
382                                                     * SIZE],
383                                     xmm0);
384                             vmovups(ptr[AO1
385                                             + (unroll_m * (2 * i + 1) + 3 * 4
386                                                       - OFFSET)
387                                                     * SIZE],
388                                     xmm1);
389                         }
390
391                         lea(BO2, ptr[BO2 + LDA * 4]);
392                     }
393                 }
394                 add(BO1, (4 * SIZE));
395             }
396
397             add(AO1, unroll_m * 4 * SIZE);
398             sub(LL, 1);
399             jg(pack2, T_NEAR);
400             align(16);
401
402             L(pack3);
403             mov(LL, K);
404             and_(LL, 3);
405             jle(pack10, T_NEAR);
406             align(16);
407
408             L(pack4);
409             if (!isTransA) {
410                 if (isLoad1Unmasked) {
411                     vmovups(ymm4, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
412                 } else {
413                     vmaskmovps(ymm4, VMASK, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
414                 }
415                 if (unroll_m > 8) {
416                     if (isLoad2Unmasked) {
417                         vmovups(ymm5, ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
418                     } else {
419                         vmaskmovps(ymm5, VMASK,
420                                 ptr[BO1 + (1 + 8 - OFFSET) * SIZE]);
421                     }
422                 }
423                 add(BO1, LDA);
424                 vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
425                         ymm4);
426                 if (unroll_m > 8) {
427                     vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE],
428                             ymm5);
429                 }
430             } else {
431                 if (isLoad1Unmasked) {
432                     for (int i = 0; i < 2; i++) {
433                         reg = (i % 2 == 0) ? BO1 : BO2;
434                         vmovss(Xmm(i + 1), ptr[reg + (0 * 8 - OFFSET) * SIZE]);
435                         vmovss(xmm0,
436                                 ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
437                         lea(BO2, ptr[reg + LDA * 2]);
438                         vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
439                     }
440                     vunpcklpd(xmm1, xmm1, xmm2);
441                     vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE],
442                             xmm1);
443
444                     for (int i = 0; i < 2; i++) {
445                         vmovss(Xmm(i + 1), ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
446                         vmovss(xmm0,
447                                 ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
448                         lea(BO2, ptr[BO2 + LDA * 2]);
449                         vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
450                     }
451                     vunpcklpd(xmm1, xmm1, xmm2);
452                     vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE],
453                             xmm1);
454                 } else if (is_avx2) {
455                     vmovaps(xmm4, xmm3);
456                     vgatherqps(xmm1, ptr[BO1 + ymm7 + (0 * 8 - OFFSET) * SIZE],
457                             xmm4);
458                     lea(BO2, ptr[BO1 + LDA * 4]);
459                     vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE],
460                             xmm1);
461
462                     vextractf128(xmm4, ymm3, 1);
463                     vgatherqps(xmm1, ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
464                             xmm4);
465                     lea(BO2, ptr[BO2 + LDA * 4]);
466                     vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE],
467                             xmm1);
468                 } else {
469                     vxorps(xmm4, xmm4, xmm4);
470                     lea(BO2, ptr[BO1 + LDA * 4]);
471
472                     auto el_cp = [&](int section, int ld_step) {
473                         RegExp src_addr = section == 0 ? BO1 : BO2;
474                         if (ld_step == 1 || ld_step == 2)
475                             src_addr = src_addr + LDA * ld_step;
476                         else if (ld_step == 3)
477                             src_addr = src_addr + CO1;
478                         src_addr = src_addr - OFFSET * SIZE;
479
480                         vmovss(xmm1, ptr[src_addr]);
481                         RegExp dst_addr = AO1
482                             + (ld_step + section * 4 - OFFSET) * SIZE;
483                         movss(ptr[dst_addr], xmm1);
484                     };
485
486                     Label l_end;
487                     el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR);
488                     el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR);
489                     el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR);
490                     el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR);
491                     el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR);
492                     el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR);
493                     el_cp(1, 2);
494                     L(l_end);
495
496                     lea(BO2, ptr[BO2 + LDA * 4]);
497                 }
498
499                 if (unroll_m >= 16) {
500                     assert(is_avx2);
501                     if (isLoad2Unmasked) {
502                         for (int i = 0; i < 2; i++) {
503                             vmovss(Xmm(i + 1),
504                                     ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
505                             vmovss(xmm0, ptr[BO2 + LDA * 1
506                                                  + (0 * 8 - OFFSET) * SIZE]);
507                             lea(BO2, ptr[BO2 + LDA * 2]);
508                             vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
509                         }
510                         vunpcklpd(xmm1, xmm1, xmm2);
511                     } else {
512                         vmovaps(xmm4, xmm3);
513                         vgatherqps(xmm1,
514                                 ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
515                                 xmm4);
516                         lea(BO2, ptr[BO2 + LDA * 4]);
517                     }
518                     vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 4 - OFFSET) * SIZE],
519                             xmm1);
520
521                     if (isLoad2Unmasked) {
522                         for (int i = 0; i < 2; i++) {
523                             vmovss(Xmm(i + 1),
524                                     ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
525                             vmovss(xmm0, ptr[BO2 + LDA * 1
526                                                  + (0 * 8 - OFFSET) * SIZE]);
527                             lea(BO2, ptr[BO2 + LDA * 2]);
528                             vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
529                         }
530                         vunpcklpd(xmm1, xmm1, xmm2);
531                     } else {
532                         vextractf128(xmm4, ymm3, 1);
533                         vgatherqps(xmm1,
534                                 ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
535                                 xmm4);
536                     }
537                     vmovups(ptr[AO1 + (unroll_m * 0 + 3 * 4 - OFFSET) * SIZE],
538                             xmm1);
539                 }
540                 add(BO1, SIZE);
541             }
542
543             add(AO1, unroll_m * SIZE);
544             sub(LL, 1);
545             jg(pack4, T_NEAR);
546             align(16);
547
548             L(pack10);
549         };
550
551         // Fused multiply add; may become one or two instructions
552         auto fma = [&](bool useFma, Ymm reg0, Ymm reg1, Ymm reg2,
553                 bool overWrite = false) {
554             if (useFma) {
555                 if (is_avx2) {
556                     vfmadd231ps(reg2, reg1, reg0);
557                 } else {
558                     assert(UNROLL_M == 8);
559                     auto tent_vreg = overWrite ? reg1 : ymm1;
560                     vmulps(tent_vreg, reg1, reg0);
561                     vaddps(reg2, reg2, tent_vreg);
562                 }
563             } else {
564                 if (!overWrite) {
565                     vmulps(ymm15, reg1, reg0);
566                     vaddps(reg2, reg2, ymm15);
567                 } else {
568                     vmulps(reg1, reg1, reg0);
569                     vaddps(reg2, reg2, reg1);
570                 }
571             }
572         };
573
574         // Inner kernel with k=8
575         auto innerkernel8 = [&](int unroll_m, int unroll_n,
576                 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
577                 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
578                 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
579                 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
580                 Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
581                 Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
582                 Ymm reg23) {
583
584             Ymm fmareg;
585
586             if (!isDirect) {
587                 prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]);
588             } else {
589                 prefetcht0(ptr[AO1 + LDA4]);
590             }
591
592             for (int i = 0; i < 8; i++) {
593                 if (isDirect) {
594                     if (isLoad1Unmasked) {
595                         vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
596                     } else {
597                         vmaskmovps(ymm0, VMASK,
598                                 ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
599                     }
600                     if (unroll_m >= 16) {
601                         if (isLoad2Unmasked) {
602                             vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
603                         } else {
604                             vmaskmovps(ymm1, VMASK,
605                                     ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
606                         }
607                     }
608                     add(AO1, LDA);
609                 }
610
611                 if (!isTransB) {
612                     vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]);
613                 } else {
614                     vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
615                 }
616                 fmareg = (i % 2 == 0) ? reg00 : reg12;
617                 fma(useFma, ymm0, ymm2, fmareg);
618                 if (unroll_m >= 16) {
619                     fmareg = (i % 2 == 0) ? reg06 : reg18;
620                     fma(useFma, ymm1, ymm2, fmareg);
621                 }
622                 if (i == 0) {
623                     if (!isTransB) {
624                         prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]);
625                     }
626                 }
627                 if (unroll_n >= 2) {
628                     if (!isTransB) {
629                         if (i == 1) {
630                             prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]);
631                         }
632                         vbroadcastss(
633                                 ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
634                     } else {
635                         vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
636                     }
637                     fmareg = (i % 2 == 0) ? reg01 : reg13;
638                     fma(useFma, ymm0, ymm2, fmareg);
639                     if (unroll_m >= 16) {
640                         fmareg = (i % 2 == 0) ? reg07 : reg19;
641                         fma(useFma, ymm1, ymm2, fmareg);
642                     }
643                 }
644
645                 if (isCopy) {
646                     vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
647                             ymm0);
648                     if (unroll_m >= 16) {
649                         vmovups(ptr[LDA4
650                                         + (unroll_m * i + 1 * 8 - OFFSET)
651                                                 * SIZE],
652                                 ymm1);
653                     }
654                     if (i == 7) {
655                         sub(LDA4, -unroll_m * 8 * SIZE);
656                     }
657                 }
658
659                 if (unroll_n >= 3) {
660                     if (!isTransB) {
661                         if (i == 2) {
662                             prefetcht0(
663                                     ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
664                         }
665                         vbroadcastss(
666                                 ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]);
667                     } else {
668                         vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
669                     }
670                     fmareg = (i % 2 == 0) ? reg02 : reg14;
671                     fma(useFma, ymm0, ymm2, fmareg);
672                     if (unroll_m >= 16) {
673                         fmareg = (i % 2 == 0) ? reg08 : reg20;
674                         fma(useFma, ymm1, ymm2, fmareg);
675                     }
676                 }
677
678                 if (i == 7) {
679                     if (!isTransB) {
680                         sub(BO1, -8 * SIZE);
681                     }
682                 }
683
684                 if (unroll_n >= 4) {
685                     if (!isTransB) {
686                         if (i == 3) {
687                             prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]);
688                         }
689                         vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]);
690                     } else {
691                         vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
692                     }
693                     fmareg = (i % 2 == 0) ? reg03 : reg15;
694                     fma(useFma, ymm0, ymm2, fmareg);
695                     if (unroll_m >= 16) {
696                         fmareg = (i % 2 == 0) ? reg09 : reg21;
697                         fma(useFma, ymm1, ymm2, fmareg);
698                     }
699                 }
700
701                 if (unroll_n >= 5) {
702                     if (!isTransB) {
703                         if (i == 4) {
704                             prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]);
705                         }
706                         vbroadcastss(
707                                 ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]);
708                     } else {
709                         vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
710                     }
711                     fmareg = (i % 2 == 0) ? reg04 : reg16;
712                     fma(useFma, ymm0, ymm2, fmareg);
713                     if (unroll_m >= 16) {
714                         fmareg = (i % 2 == 0) ? reg10 : reg22;
715                         fma(useFma, ymm1, ymm2, fmareg);
716                     }
717                 }
718
719                 if (unroll_n >= 6) {
720                     if (!isTransB) {
721                         if (i == 5) {
722                             prefetcht0(
723                                     ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]);
724                         }
725                         vbroadcastss(
726                                 ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]);
727                     } else {
728                         vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
729                     }
730                     fmareg = (i % 2 == 0) ? reg05 : reg17;
731                     fma(useFma, ymm0, ymm2, fmareg);
732                     if (unroll_m >= 16) {
733                         fmareg = (i % 2 == 0) ? reg11 : reg23;
734                         fma(useFma, ymm1, ymm2, fmareg);
735                     }
736                 }
737                 if (isTransB) {
738                     prefetcht0(ptr[BO1 + BO2]);
739                     add(BO1, LDB);
740                 }
741
742                 if (i == 0) {
743                     if (unroll_m >= 4) {
744                         if (!isDirect) {
745                             prefetcht0(
746                                     ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]);
747                         } else {
748                             prefetcht0(ptr[AO1 + LDA4]);
749                         }
750                     }
751                 }
752                 if (i == 1 || i == 2) {
753                     if (unroll_m >= 8) {
754                         if (!isDirect) {
755                             prefetcht0(ptr[AO1
756                                     + (PREFETCHSIZEA + (2 + 2 * i) * 8)
757                                             * SIZE]);
758                         } else {
759                             prefetcht0(ptr[AO1 + LDA4]);
760                         }
761                     }
762                 }
763                 if (i == 3 || i == 4 || i == 5 || i == 6) {
764                     if (unroll_m >= 16) {
765                         if (!isDirect) {
766                             prefetcht0(ptr[AO1
767                                     + (PREFETCHSIZEA + (2 + 2 * i) * 8)
768                                             * SIZE]);
769                         } else {
770                             prefetcht0(ptr[AO1 + LDA4]);
771                         }
772                     }
773                 }
774                 if (i == 7) {
775                     if (!isTransB) {
776                         if (unroll_n >= 4) {
777                             sub(BO2, -8 * SIZE);
778                         }
779                     }
780                     if (!isTransA) {
781                         prefetcht2(ptr[AA]);
782                         lea(AA, ptr[AA + LDA]);
783                     }
784                 }
785
786                 if (!isDirect) {
787                     if (isLoad1Unmasked) {
788                         vmovups(ymm0,
789                                 ptr[AO1
790                                         + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
791                                                 * SIZE]);
792                     } else {
793                         vmaskmovps(
794                                 ymm0, VMASK,
795                                 ptr[AO1
796                                         + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
797                                                 * SIZE]);
798                     }
799                     if (unroll_m >= 16) {
800                         if (isLoad2Unmasked) {
801                             vmovups(ymm1, ptr[AO1
802                                                   + (unroll_m * (i + 1) + 1 * 8
803                                                             - OFFSET)
804                                                           * SIZE]);
805                         } else {
806                             vmaskmovps(ymm1, VMASK,
807                                     ptr[AO1
808                                                + (unroll_m * (i + 1) + 1 * 8
809                                                          - OFFSET)
810                                                        * SIZE]);
811                         }
812                     }
813                 }
814             }
815
816             if (!isDirect) {
817                 sub(AO1, -unroll_m * 8 * SIZE);
818             }
819             sub(LL, 1);
820
821         };
822
823         // Inner kernel with k=4
824         auto innerkernel4 = [&](int unroll_m, int unroll_n,
825                 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
826                 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
827                 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
828                 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
829                 Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
830                 Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
831                 Ymm reg23) {
832
833             Ymm fmareg;
834
835             if (!isDirect) {
836                 prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]);
837             } else {
838                 prefetcht0(ptr[AO1 + LDA4]);
839             }
840
841             for (int i = 0; i < 4; i++) {
842                 if (isDirect) {
843                     if (isLoad1Unmasked) {
844                         vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
845                     } else {
846                         vmaskmovps(ymm0, VMASK,
847                                 ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
848                     }
849                     if (unroll_m >= 16) {
850                         if (isLoad2Unmasked) {
851                             vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
852                         } else {
853                             vmaskmovps(ymm1, VMASK,
854                                     ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
855                         }
856                     }
857                     add(AO1, LDA);
858                 }
859
860                 if (!isTransB) {
861                     vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]);
862                 } else {
863                     vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
864                 }
865                 fmareg = (i % 2 == 0) ? reg00 : reg12;
866                 fma(useFma, ymm0, ymm2, fmareg);
867                 if (unroll_m >= 16) {
868                     fmareg = (i % 2 == 0) ? reg06 : reg18;
869                     fma(useFma, ymm1, ymm2, fmareg);
870                 }
871                 if (i == 0) {
872                     if (!isTransB) {
873                         prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]);
874                     }
875                 }
876                 if (unroll_n >= 2) {
877                     if (!isTransB) {
878                         if (i == 1) {
879                             prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]);
880                         }
881                         vbroadcastss(
882                                 ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
883                     } else {
884                         vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
885                     }
886                     fmareg = (i % 2 == 0) ? reg01 : reg13;
887                     fma(useFma, ymm0, ymm2, fmareg);
888                     if (unroll_m >= 16) {
889                         fmareg = (i % 2 == 0) ? reg07 : reg19;
890                         fma(useFma, ymm1, ymm2, fmareg);
891                     }
892                 }
893
894                 if (isCopy) {
895                     vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
896                             ymm0);
897                     if (unroll_m >= 16) {
898                         vmovups(ptr[LDA4
899                                         + (unroll_m * i + 1 * 8 - OFFSET)
900                                                 * SIZE],
901                                 ymm1);
902                     }
903                     if (i == 3) {
904                         sub(LDA4, -unroll_m * 4 * SIZE);
905                     }
906                 }
907
908                 if (unroll_n >= 3) {
909                     if (!isTransB) {
910                         if (i == 2) {
911                             prefetcht0(
912                                     ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
913                         }
914                         vbroadcastss(
915                                 ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]);
916                     } else {
917                         vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
918                     }
919                     fmareg = (i % 2 == 0) ? reg02 : reg14;
920                     fma(useFma, ymm0, ymm2, fmareg);
921                     if (unroll_m >= 16) {
922                         fmareg = (i % 2 == 0) ? reg08 : reg20;
923                         fma(useFma, ymm1, ymm2, fmareg);
924                     }
925                 }
926
927                 if (i == 7) {
928                     if (!isTransB) {
929                         sub(BO1, -8 * SIZE);
930                     }
931                 }
932
933                 if (unroll_n >= 4) {
934                     if (!isTransB) {
935                         if (i == 3) {
936                             prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]);
937                         }
938                         vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]);
939                     } else {
940                         vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
941                     }
942                     fmareg = (i % 2 == 0) ? reg03 : reg15;
943                     fma(useFma, ymm0, ymm2, fmareg);
944                     if (unroll_m >= 16) {
945                         fmareg = (i % 2 == 0) ? reg09 : reg21;
946                         fma(useFma, ymm1, ymm2, fmareg);
947                     }
948                 }
949
950                 if (unroll_n >= 5) {
951                     if (!isTransB) {
952                         if (i == 4) {
953                             prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]);
954                         }
955                         vbroadcastss(
956                                 ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]);
957                     } else {
958                         vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
959                     }
960                     fmareg = (i % 2 == 0) ? reg04 : reg16;
961                     fma(useFma, ymm0, ymm2, fmareg);
962                     if (unroll_m >= 16) {
963                         fmareg = (i % 2 == 0) ? reg10 : reg22;
964                         fma(useFma, ymm1, ymm2, fmareg);
965                     }
966                 }
967
968                 if (unroll_n >= 6) {
969                     if (!isTransB) {
970                         if (i == 5) {
971                             prefetcht0(
972                                     ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]);
973                         }
974                         vbroadcastss(
975                                 ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]);
976                     } else {
977                         vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
978                     }
979                     fmareg = (i % 2 == 0) ? reg05 : reg17;
980                     fma(useFma, ymm0, ymm2, fmareg);
981                     if (unroll_m >= 16) {
982                         fmareg = (i % 2 == 0) ? reg11 : reg23;
983                         fma(useFma, ymm1, ymm2, fmareg);
984                     }
985                 }
986                 if (isTransB) {
987                     prefetcht0(ptr[BO1 + BO2]);
988                     add(BO1, LDB);
989                 }
990
991                 if (i == 0) {
992                     if (unroll_m >= 4) {
993                         if (!isDirect) {
994                             prefetcht0(
995                                     ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]);
996                         } else {
997                             prefetcht0(ptr[AO1 + LDA4]);
998                         }
999                     }
1000                 }
1001                 if (i == 1 || i == 2) {
1002                     if (unroll_m >= 8) {
1003                         if (!isDirect) {
1004                             prefetcht0(ptr[AO1
1005                                     + (PREFETCHSIZEA + (2 + 2 * i) * 8)
1006                                             * SIZE]);
1007                         } else {
1008                             prefetcht0(ptr[AO1 + LDA4]);
1009                         }
1010                     }
1011                 }
1012                 if (i == 3) {
1013                     if (!isTransB) {
1014                         sub(BO1, -4 * SIZE);
1015                         if (unroll_n >= 4) {
1016                             sub(BO2, -4 * SIZE);
1017                         }
1018                     }
1019                 }
1020
1021                 if (!isDirect) {
1022                     if (isLoad1Unmasked) {
1023                         vmovups(ymm0,
1024                                 ptr[AO1
1025                                         + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
1026                                                 * SIZE]);
1027                     } else {
1028                         vmaskmovps(
1029                                 ymm0, VMASK,
1030                                 ptr[AO1
1031                                         + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
1032                                                 * SIZE]);
1033                     }
1034                     if (unroll_m >= 16) {
1035                         if (isLoad2Unmasked) {
1036                             vmovups(ymm1, ptr[AO1
1037                                                   + (unroll_m * (i + 1) + 1 * 8
1038                                                             - OFFSET)
1039                                                           * SIZE]);
1040                         } else {
1041                             vmaskmovps(ymm1, VMASK,
1042                                     ptr[AO1
1043                                                + (unroll_m * (i + 1) + 1 * 8
1044                                                          - OFFSET)
1045                                                        * SIZE]);
1046                         }
1047                     }
1048                 }
1049             }
1050
1051             if (!isDirect) {
1052                 sub(AO1, -unroll_m * 4 * SIZE);
1053             }
1054
1055         };
1056
1057         // Inner kernel with k=2
1058         auto innerkernel2 = [&](int unroll_m, int unroll_n,
1059                 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
1060                 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
1061                 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
1062                 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
1063                 Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
1064                 Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
1065                 Ymm reg23) {
1066
1067             Ymm fmareg;
1068
1069             for (int i = 0; i < 2; i++) {
1070                 if (isDirect) {
1071                     if (isLoad1Unmasked) {
1072                         vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1073                     } else {
1074                         vmaskmovps(ymm0, VMASK,
1075                                 ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1076                     }
1077                     if (unroll_m >= 16) {
1078                         if (isLoad2Unmasked) {
1079                             vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1080                         } else {
1081                             vmaskmovps(ymm1, VMASK,
1082                                     ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1083                         }
1084                     }
1085                     add(AO1, LDA);
1086                 }
1087
1088                 if (!isTransB) {
1089                     vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1090                 } else {
1091                     vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1092                 }
1093                 fmareg = (i % 2 == 0) ? reg00 : reg12;
1094                 fma(useFma, ymm0, ymm2, fmareg);
1095                 if (unroll_m >= 16) {
1096                     fmareg = (i % 2 == 0) ? reg06 : reg18;
1097                     fma(useFma, ymm1, ymm2, fmareg);
1098                 }
1099                 if (unroll_n >= 2) {
1100                     if (!isTransB) {
1101                         vbroadcastss(
1102                                 ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]);
1103                     } else {
1104                         vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
1105                     }
1106                     fmareg = (i % 2 == 0) ? reg01 : reg13;
1107                     fma(useFma, ymm0, ymm2, fmareg);
1108                     if (unroll_m >= 16) {
1109                         fmareg = (i % 2 == 0) ? reg07 : reg19;
1110                         fma(useFma, ymm1, ymm2, fmareg);
1111                     }
1112                 }
1113
1114                 if (unroll_n >= 3) {
1115                     if (!isTransB) {
1116                         if (i == 2) {
1117                             prefetcht0(
1118                                     ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
1119                         }
1120                         vbroadcastss(
1121                                 ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]);
1122                     } else {
1123                         vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
1124                     }
1125                     fmareg = (i % 2 == 0) ? reg02 : reg14;
1126                     fma(useFma, ymm0, ymm2, fmareg);
1127                     if (unroll_m >= 16) {
1128                         fmareg = (i % 2 == 0) ? reg08 : reg20;
1129                         fma(useFma, ymm1, ymm2, fmareg);
1130                     }
1131                 }
1132
1133                 if (unroll_n >= 4) {
1134                     if (!isTransB) {
1135                         vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]);
1136                     } else {
1137                         vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
1138                     }
1139                     fmareg = (i % 2 == 0) ? reg03 : reg15;
1140                     fma(useFma, ymm0, ymm2, fmareg);
1141                     if (unroll_m >= 16) {
1142                         fmareg = (i % 2 == 0) ? reg09 : reg21;
1143                         fma(useFma, ymm1, ymm2, fmareg);
1144                     }
1145                 }
1146
1147                 if (unroll_n >= 5) {
1148                     if (!isTransB) {
1149                         vbroadcastss(
1150                                 ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]);
1151                     } else {
1152                         vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
1153                     }
1154                     fmareg = (i % 2 == 0) ? reg04 : reg16;
1155                     fma(useFma, ymm0, ymm2, fmareg);
1156                     if (unroll_m >= 16) {
1157                         fmareg = (i % 2 == 0) ? reg10 : reg22;
1158                         fma(useFma, ymm1, ymm2, fmareg);
1159                     }
1160                 }
1161
1162                 if (unroll_n >= 6) {
1163                     if (!isTransB) {
1164                         vbroadcastss(
1165                                 ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]);
1166                     } else {
1167                         vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
1168                     }
1169                     fmareg = (i % 2 == 0) ? reg05 : reg17;
1170                     fma(useFma, ymm0, ymm2, fmareg);
1171                     if (unroll_m >= 16) {
1172                         fmareg = (i % 2 == 0) ? reg11 : reg23;
1173                         fma(useFma, ymm1, ymm2, fmareg);
1174                     }
1175                 }
1176
1177                 if (isCopy) {
1178                     vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
1179                             ymm0);
1180                     if (unroll_m >= 16) {
1181                         vmovups(ptr[LDA4
1182                                         + (unroll_m * 0 + 1 * 8 - OFFSET)
1183                                                 * SIZE],
1184                                 ymm1);
1185                     }
1186                     sub(LDA4, -unroll_m * SIZE);
1187                 }
1188
1189                 if (!isDirect) {
1190                     if (isLoad1Unmasked) {
1191                         vmovups(ymm0, ptr[AO1
1192                                               + (unroll_m * 1 + 0 * 8 - OFFSET)
1193                                                       * SIZE]);
1194                     } else {
1195                         vmaskmovps(ymm0, VMASK,
1196                                 ptr[AO1
1197                                            + (unroll_m * 1 + 0 * 8 - OFFSET)
1198                                                    * SIZE]);
1199                     }
1200                     if (unroll_m >= 16) {
1201                         if (isLoad2Unmasked) {
1202                             vmovups(ymm1,
1203                                     ptr[AO1
1204                                             + (unroll_m * 1 + 1 * 8 - OFFSET)
1205                                                     * SIZE]);
1206                         } else {
1207                             vmaskmovps(ymm1, VMASK,
1208                                     ptr[AO1
1209                                                + (unroll_m * 1 + 1 * 8 - OFFSET)
1210                                                        * SIZE]);
1211                         }
1212                     }
1213                     sub(AO1, -unroll_m * SIZE);
1214                 }
1215
1216                 if (!isTransB) {
1217                     sub(BO1, -SIZE);
1218                     if (unroll_n >= 4) {
1219                         sub(BO2, -SIZE);
1220                     }
1221                 } else {
1222                     add(BO1, LDB);
1223                 }
1224             }
1225
1226         };
1227
1228         // Inner kernel with k=1
1229         auto innerkernel1 = [&](int unroll_m, int unroll_n,
1230                 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
1231                 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
1232                 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
1233                 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11) {
1234
1235             if (isDirect) {
1236                 if (isLoad1Unmasked) {
1237                     vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1238                 } else {
1239                     vmaskmovps(ymm0, VMASK, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1240                 }
1241                 if (unroll_m >= 16) {
1242                     if (isLoad2Unmasked) {
1243                         vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1244                     } else {
1245                         vmaskmovps(ymm1, VMASK,
1246                                 ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1247                     }
1248                 }
1249                 add(AO1, LDA);
1250             }
1251
1252             if (!isTransB) {
1253                 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1254             } else {
1255                 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1256             }
1257             fma(useFma, ymm0, ymm2, reg00);
1258             if (unroll_m >= 16) {
1259                 fma(useFma, ymm1, ymm2, reg06);
1260             }
1261
1262             if (unroll_n >= 2) {
1263                 if (!isTransB) {
1264                     vbroadcastss(
1265                             ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]);
1266                 } else {
1267                     vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
1268                 }
1269                 fma(useFma, ymm0, ymm2, reg01);
1270                 if (unroll_m >= 16) {
1271                     fma(useFma, ymm1, ymm2, reg07);
1272                 }
1273             }
1274
1275             if (unroll_n >= 3) {
1276                 if (!isTransB) {
1277                     vbroadcastss(
1278                             ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]);
1279                 } else {
1280                     vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
1281                 }
1282                 fma(useFma, ymm0, ymm2, reg02);
1283                 if (unroll_m >= 16) {
1284                     fma(useFma, ymm1, ymm2, reg08);
1285                 }
1286             }
1287
1288             if (unroll_n >= 4) {
1289                 if (!isTransB) {
1290                     vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]);
1291                 } else {
1292                     vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
1293                 }
1294                 fma(useFma, ymm0, ymm2, reg03);
1295                 if (unroll_m >= 16) {
1296                     fma(useFma, ymm1, ymm2, reg09);
1297                 }
1298             }
1299
1300             if (unroll_n >= 5) {
1301                 if (!isTransB) {
1302                     vbroadcastss(
1303                             ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]);
1304                 } else {
1305                     vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
1306                 }
1307                 fma(useFma, ymm0, ymm2, reg04);
1308                 if (unroll_m >= 16) {
1309                     fma(useFma, ymm1, ymm2, reg10);
1310                 }
1311             }
1312
1313             if (unroll_n >= 6) {
1314                 if (!isTransB) {
1315                     vbroadcastss(
1316                             ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]);
1317                 } else {
1318                     vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
1319                 }
1320                 fma(useFma, ymm0, ymm2, reg05);
1321                 if (unroll_m >= 16) {
1322                     fma(useFma, ymm1, ymm2, reg11);
1323                 }
1324             }
1325
1326             if (isCopy) {
1327                 vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
1328                         ymm0);
1329                 if (unroll_m >= 16) {
1330                     vmovups(ptr[LDA4 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE],
1331                             ymm1);
1332                 }
1333                 sub(LDA4, -unroll_m * SIZE);
1334             }
1335
1336             if (!isDirect) {
1337                 if (isLoad1Unmasked) {
1338                     vmovups(ymm0,
1339                             ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]);
1340                 } else {
1341                     vmaskmovps(ymm0, VMASK,
1342                             ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]);
1343                 }
1344                 if (unroll_m >= 16) {
1345                     if (isLoad2Unmasked) {
1346                         vmovups(ymm1, ptr[AO1
1347                                               + (unroll_m * 1 + 1 * 8 - OFFSET)
1348                                                       * SIZE]);
1349                     } else {
1350                         vmaskmovps(ymm1, VMASK,
1351                                 ptr[AO1
1352                                            + (unroll_m * 1 + 1 * 8 - OFFSET)
1353                                                    * SIZE]);
1354                     }
1355                 }
1356                 sub(AO1, -unroll_m * SIZE);
1357             }
1358
1359             if (!isTransB) {
1360                 sub(BO1, -SIZE);
1361                 if (unroll_n >= 4) {
1362                     sub(BO2, -SIZE);
1363                 }
1364             } else {
1365                 add(BO1, LDB);
1366             }
1367
1368         };
1369
1370         // Main kernel; does prefetching and calls innerkernel{1,2,4,8} as
1371         // appropriate
1372         // After calculating results in registers, writes back to C matrix
1373         auto kernel = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1374                 bool isLoad2Unmasked, bool isDirect, bool isCopy, bool useFma,
1375                 Ymm reg00 = Ymm(4), Ymm reg01 = Ymm(5), Ymm reg02 = Ymm(6),
1376                 Ymm reg03 = Ymm(7), Ymm reg04 = Ymm(8), Ymm reg05 = Ymm(9),
1377                 Ymm reg06 = Ymm(10), Ymm reg07 = Ymm(11), Ymm reg08 = Ymm(12),
1378                 Ymm reg09 = Ymm(13), Ymm reg10 = Ymm(14), Ymm reg11 = Ymm(15),
1379                 Ymm reg12 = Ymm(4), Ymm reg13 = Ymm(5), Ymm reg14 = Ymm(6),
1380                 Ymm reg15 = Ymm(7), Ymm reg16 = Ymm(8), Ymm reg17 = Ymm(9),
1381                 Ymm reg18 = Ymm(10), Ymm reg19 = Ymm(11), Ymm reg20 = Ymm(12),
1382                 Ymm reg21 = Ymm(13), Ymm reg22 = Ymm(14), Ymm reg23 = Ymm(15)) {
1383             if (!isDirect) {
1384                 lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]);
1385             } else {
1386                 mov(AO1, A);
1387             }
1388
1389             if (isCopy) {
1390                 lea(LDA4, ptr[rsp + 256 + OFFSET * SIZE]);
1391             } else {
1392                 lea(LDA4, ptr[LDA * 8 + (8 - 1 - OFFSET) * SIZE]);
1393             }
1394
1395             if (isTransB) {
1396                 lea(BO2, ptr[LDB * 4 + (8 - 1 - OFFSET) * SIZE]);
1397                 lea(BO2, ptr[BO2 + LDB * 2]);
1398             }
1399
1400             if (!isDirect) {
1401                 if (isLoad1Unmasked) {
1402                     vmovups(ymm0,
1403                             ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]);
1404                 } else {
1405                     vmaskmovps(ymm0, VMASK,
1406                             ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]);
1407                 }
1408                 if (unroll_m >= 16) {
1409                     if (isLoad2Unmasked) {
1410                         vmovups(ymm1, ptr[AO1
1411                                               + (unroll_m * 0 + 1 * 8 - OFFSET)
1412                                                       * SIZE]);
1413                     } else {
1414                         vmaskmovps(ymm1, VMASK,
1415                                 ptr[AO1
1416                                            + (unroll_m * 0 + 1 * 8 - OFFSET)
1417                                                    * SIZE]);
1418                     }
1419                 }
1420             }
1421
1422             for (int i = 4; i < 10; i++) {
1423                 vxorps(Ymm(i), Ymm(i), Ymm(i));
1424                 vxorps(Ymm(i + 6), Ymm(i + 6), Ymm(i + 6));
1425             }
1426
1427             mov(LL, K);
1428             sar(LL, 3);
1429
1430             Label kernel12, kernel13, kernel14, kernel15;
1431             Label kernel16, kernel17, kernel18;
1432
1433             sub(LL, SECOND_FETCH);
1434             jle(kernel13, T_NEAR);
1435             align(16);
1436
1437             L(kernel12);
1438             innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1439                     isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1440                     reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1441                     reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1442                     reg21, reg22, reg23);
1443             jg(kernel12, T_NEAR);
1444             align(16);
1445
1446             L(kernel13);
1447             prefetcht0(ptr[CO1 + (unroll_m - 1) * SIZE]);
1448             if (unroll_n >= 2)
1449                 prefetcht0(ptr[CO1 + LDC + (unroll_m - 1) * SIZE]);
1450             if (unroll_n >= 3)
1451                 prefetcht0(ptr[CO1 + LDC * 2 + (unroll_m - 1) * SIZE]);
1452             if (unroll_n >= 4)
1453                 prefetcht0(ptr[CO2 + (unroll_m - 1) * SIZE]);
1454             if (unroll_n >= 5)
1455                 prefetcht0(ptr[CO2 + LDC + (unroll_m - 1) * SIZE]);
1456             if (unroll_n >= 6)
1457                 prefetcht0(ptr[CO2 + LDC * 2 + (unroll_m - 1) * SIZE]);
1458
1459             add(LL, SECOND_FETCH);
1460             jle(kernel15, T_NEAR);
1461             align(16);
1462
1463             L(kernel14);
1464             innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1465                     isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1466                     reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1467                     reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1468                     reg21, reg22, reg23);
1469             jg(kernel14, T_NEAR);
1470             align(16);
1471
1472             L(kernel15);
1473             test(K, 4);
1474             jle(kernel16, T_NEAR);
1475             innerkernel4(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1476                     isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1477                     reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1478                     reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1479                     reg21, reg22, reg23);
1480
1481             L(kernel16);
1482             test(K, 2);
1483             jle(kernel17, T_NEAR);
1484             innerkernel2(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1485                     isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1486                     reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1487                     reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1488                     reg21, reg22, reg23);
1489             align(16);
1490
1491             L(kernel17);
1492             if (unroll_m == 16) {
1493                 if (unroll_n <= 3) {
1494                     vaddps(reg00, reg00, reg12);
1495                     vaddps(reg01, reg01, reg13);
1496                     vaddps(reg02, reg02, reg14);
1497                     vaddps(reg06, reg06, reg18);
1498                     vaddps(reg07, reg07, reg19);
1499                     vaddps(reg08, reg08, reg20);
1500                 }
1501             }
1502
1503             if (unroll_m <= 8) {
1504                 vaddps(reg00, reg00, reg12);
1505                 vaddps(reg01, reg01, reg13);
1506                 vaddps(reg02, reg02, reg14);
1507                 vaddps(reg03, reg03, reg15);
1508                 vaddps(reg04, reg04, reg16);
1509                 vaddps(reg05, reg05, reg17);
1510             }
1511
1512             test(K, 1);
1513             jle(kernel18, T_NEAR);
1514             innerkernel1(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1515                     isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1516                     reg05, reg06, reg07, reg08, reg09, reg10, reg11);
1517             align(16);
1518
1519             L(kernel18);
1520             vbroadcastss(VALPHA, ALPHA);
1521
1522             if (isBetaN) {
1523                 vbroadcastss(VBETA, BETA);
1524             }
1525
1526             // Write back the results; all beta and bias cases need to be
1527             // handled
1528             switch (unroll_n) {
1529             case 1: mov(rax, LDC); break;
1530             case 2: lea(rax, ptr[LDC * 2]); break;
1531             case 3: lea(rax, ptr[LDC + LDC * 2]); break;
1532             case 4: lea(rax, ptr[LDC + LDC * 4]); break;
1533             case 5:
1534                 lea(rax, ptr[LDC * 4]);
1535                 add(rax, LDC);
1536                 break;
1537             case 6:
1538                 lea(rax, ptr[LDC + LDC * 2]);
1539                 add(rax, rax);
1540                 break;
1541             }
1542
1543             if (hasBias) {
1544                 mov(BIAS1, BIAS);
1545                 if (isLoad1Unmasked) {
1546                     vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]);
1547                 } else {
1548                     vmaskmovps(VBIAS1, VMASK, ptr[BIAS1 + 0 * SIZE]);
1549                 }
1550             }
1551
1552             for (int i = 0; i < unroll_n; i++) {
1553                 vmulps(Ymm(i + 4), Ymm(i + 4), VALPHA);
1554                 if (!isBeta0) {
1555                     if (isLoad1Unmasked) {
1556                         switch (i) {
1557                         case 0: vmovups(ymm0, ptr[CO1 + 0 * SIZE]); break;
1558                         case 1: vmovups(ymm0, ptr[CO1 + LDC + 0 * SIZE]); break;
1559                         case 2:
1560                             vmovups(ymm0, ptr[CO1 + LDC * 2 + 0 * SIZE]);
1561                             break;
1562                         case 3: vmovups(ymm0, ptr[CO2 + 0 * SIZE]); break;
1563                         case 4: vmovups(ymm0, ptr[CO2 + LDC + 0 * SIZE]); break;
1564                         case 5:
1565                             vmovups(ymm0, ptr[CO2 + LDC * 2 + 0 * SIZE]);
1566                             break;
1567                         }
1568                     } else {
1569                         switch (i) {
1570                         case 0:
1571                             vmaskmovps(ymm0, VMASK, ptr[CO1 + 0 * SIZE]);
1572                             break;
1573                         case 1:
1574                             vmaskmovps(ymm0, VMASK, ptr[CO1 + LDC + 0 * SIZE]);
1575                             break;
1576                         case 2:
1577                             vmaskmovps(
1578                                     ymm0, VMASK, ptr[CO1 + LDC * 2 + 0 * SIZE]);
1579                             break;
1580                         case 3:
1581                             vmaskmovps(ymm0, VMASK, ptr[CO2 + 0 * SIZE]);
1582                             break;
1583                         case 4:
1584                             vmaskmovps(ymm0, VMASK, ptr[CO2 + LDC + 0 * SIZE]);
1585                             break;
1586                         case 5:
1587                             vmaskmovps(
1588                                     ymm0, VMASK, ptr[CO2 + LDC * 2 + 0 * SIZE]);
1589                             break;
1590                         }
1591                     }
1592
1593                     if (!isBetaN) {
1594                         vaddps(Ymm(i + 4), ymm0, Ymm(i + 4));
1595                     } else {
1596                         fma(useFma, VBETA, ymm0, Ymm(i + 4), true);
1597                     }
1598                 }
1599                 if (hasBias) {
1600                     vaddps(Ymm(i + 4), VBIAS1, Ymm(i + 4));
1601                 }
1602                 if (isLoad1Unmasked) {
1603                     switch (i) {
1604                     case 0: vmovups(ptr[CO1 + 0 * SIZE], Ymm(i + 4)); break;
1605                     case 1:
1606                         vmovups(ptr[CO1 + LDC + 0 * SIZE], Ymm(i + 4));
1607                         break;
1608                     case 2:
1609                         vmovups(ptr[CO1 + LDC * 2 + 0 * SIZE], Ymm(i + 4));
1610                         break;
1611                     case 3: vmovups(ptr[CO2 + 0 * SIZE], Ymm(i + 4)); break;
1612                     case 4:
1613                         vmovups(ptr[CO2 + LDC + 0 * SIZE], Ymm(i + 4));
1614                         break;
1615                     case 5:
1616                         vmovups(ptr[CO2 + LDC * 2 + 0 * SIZE], Ymm(i + 4));
1617                         break;
1618                     }
1619                 } else {
1620                     switch (i) {
1621                     case 0:
1622                         vmaskmovps(ptr[CO1 + 0 * SIZE], VMASK, Ymm(i + 4));
1623                         break;
1624                     case 1:
1625                         vmaskmovps(
1626                                 ptr[CO1 + LDC + 0 * SIZE], VMASK, Ymm(i + 4));
1627                         break;
1628                     case 2:
1629                         vmaskmovps(ptr[CO1 + LDC * 2 + 0 * SIZE], VMASK,
1630                                 Ymm(i + 4));
1631                         break;
1632                     case 3:
1633                         vmaskmovps(ptr[CO2 + 0 * SIZE], VMASK, Ymm(i + 4));
1634                         break;
1635                     case 4:
1636                         vmaskmovps(
1637                                 ptr[CO2 + LDC + 0 * SIZE], VMASK, Ymm(i + 4));
1638                         break;
1639                     case 5:
1640                         vmaskmovps(ptr[CO2 + LDC * 2 + 0 * SIZE], VMASK,
1641                                 Ymm(i + 4));
1642                         break;
1643                     }
1644                 }
1645
1646                 if (unroll_m >= 16) {
1647                     // Re-use ymm4 (VBIAS2)
1648                     if (i == 0) {
1649                         if (hasBias) {
1650                             if (isLoad1Unmasked) {
1651                                 vmovups(VBIAS2, ptr[BIAS1 + 8 * SIZE]);
1652                             } else {
1653                                 vmaskmovps(
1654                                         VBIAS2, VMASK, ptr[BIAS1 + 8 * SIZE]);
1655                             }
1656                         }
1657                     }
1658                     vmulps(Ymm(i + 10), Ymm(i + 10), VALPHA);
1659                     if (!isBeta0) {
1660                         if (isLoad2Unmasked) {
1661                             switch (i) {
1662                             case 0: vmovups(ymm0, ptr[CO1 + 8 * SIZE]); break;
1663                             case 1:
1664                                 vmovups(ymm0, ptr[CO1 + LDC + 8 * SIZE]);
1665                                 break;
1666                             case 2:
1667                                 vmovups(ymm0, ptr[CO1 + LDC * 2 + 8 * SIZE]);
1668                                 break;
1669                             case 3: vmovups(ymm0, ptr[CO2 + 8 * SIZE]); break;
1670                             case 4:
1671                                 vmovups(ymm0, ptr[CO2 + LDC + 8 * SIZE]);
1672                                 break;
1673                             case 5:
1674                                 vmovups(ymm0, ptr[CO2 + LDC * 2 + 8 * SIZE]);
1675                                 break;
1676                             }
1677                         } else {
1678                             switch (i) {
1679                             case 0:
1680                                 vmaskmovps(ymm0, VMASK, ptr[CO1 + 8 * SIZE]);
1681                                 break;
1682                             case 1:
1683                                 vmaskmovps(
1684                                         ymm0, VMASK, ptr[CO1 + LDC + 8 * SIZE]);
1685                                 break;
1686                             case 2:
1687                                 vmaskmovps(ymm0, VMASK,
1688                                         ptr[CO1 + LDC * 2 + 8 * SIZE]);
1689                                 break;
1690                             case 3:
1691                                 vmaskmovps(ymm0, VMASK, ptr[CO2 + 8 * SIZE]);
1692                                 break;
1693                             case 4:
1694                                 vmaskmovps(
1695                                         ymm0, VMASK, ptr[CO2 + LDC + 8 * SIZE]);
1696                                 break;
1697                             case 5:
1698                                 vmaskmovps(ymm0, VMASK,
1699                                         ptr[CO2 + LDC * 2 + 8 * SIZE]);
1700                                 break;
1701                             }
1702                         }
1703                         if (!isBetaN) {
1704                             vaddps(Ymm(i + 10), ymm0, Ymm(i + 10));
1705                         } else {
1706                             fma(useFma, VBETA, ymm0, Ymm(i + 10), true);
1707                         }
1708                     }
1709                     if (hasBias) {
1710                         vaddps(Ymm(i + 10), VBIAS2, Ymm(i + 10));
1711                     }
1712                     if (isLoad2Unmasked) {
1713                         switch (i) {
1714                         case 0:
1715                             vmovups(ptr[CO1 + 8 * SIZE], Ymm(i + 10));
1716                             break;
1717                         case 1:
1718                             vmovups(ptr[CO1 + LDC + 8 * SIZE], Ymm(i + 10));
1719                             break;
1720                         case 2:
1721                             vmovups(ptr[CO1 + LDC * 2 + 8 * SIZE], Ymm(i + 10));
1722                             break;
1723                         case 3:
1724                             vmovups(ptr[CO2 + 8 * SIZE], Ymm(i + 10));
1725                             break;
1726                         case 4:
1727                             vmovups(ptr[CO2 + LDC + 8 * SIZE], Ymm(i + 10));
1728                             break;
1729                         case 5:
1730                             vmovups(ptr[CO2 + LDC * 2 + 8 * SIZE], Ymm(i + 10));
1731                             break;
1732                         }
1733                     } else {
1734                         switch (i) {
1735                         case 0:
1736                             vmaskmovps(ptr[CO1 + 8 * SIZE], VMASK, Ymm(i + 10));
1737                             break;
1738                         case 1:
1739                             vmaskmovps(ptr[CO1 + LDC + 8 * SIZE], VMASK,
1740                                     Ymm(i + 10));
1741                             break;
1742                         case 2:
1743                             vmaskmovps(ptr[CO1 + LDC * 2 + 8 * SIZE], VMASK,
1744                                     Ymm(i + 10));
1745                             break;
1746                         case 3:
1747                             vmaskmovps(ptr[CO2 + 8 * SIZE], VMASK, Ymm(i + 10));
1748                             break;
1749                         case 4:
1750                             vmaskmovps(ptr[CO2 + LDC + 8 * SIZE], VMASK,
1751                                     Ymm(i + 10));
1752                             break;
1753                         case 5:
1754                             vmaskmovps(ptr[CO2 + LDC * 2 + 8 * SIZE], VMASK,
1755                                     Ymm(i + 10));
1756                             break;
1757                         }
1758                     }
1759                 }
1760                 if (i == 2)
1761                     add(CO1, rax);
1762             }
1763             if (unroll_n >= 4) {
1764                 add(CO2, rax);
1765             }
1766
1767             // Compute next address of B
1768             if (!isTransB) {
1769                 lea(rax, ptr[K * SIZE]);
1770                 switch (unroll_n) {
1771                 case 1:
1772                     add(BO1, LDB);
1773                     add(BO2, LDB);
1774                     break;
1775                 case 2:
1776                     lea(BO1, ptr[BO1 + LDB * 2]);
1777                     lea(BO2, ptr[BO2 + LDB * 2]);
1778                     break;
1779                 case 3:
1780                     lea(BO1, ptr[BO1 + LDB3]);
1781                     lea(BO2, ptr[BO2 + LDB3]);
1782                     break;
1783                 case 4:
1784                     lea(BO1, ptr[BO1 + LDB * 4]);
1785                     lea(BO2, ptr[BO2 + LDB * 4]);
1786                     break;
1787                 case 5:
1788                     lea(BO1, ptr[BO1 + LDB * 4]);
1789                     add(BO1, LDB);
1790                     lea(BO2, ptr[BO2 + LDB * 4]);
1791                     add(BO2, LDB);
1792                     break;
1793                 case 6:
1794                     lea(BO1, ptr[BO1 + LDB3 * 2]);
1795                     lea(BO2, ptr[BO2 + LDB3 * 2]);
1796                     break;
1797                 }
1798                 sub(BO1, rax);
1799                 sub(BO2, rax);
1800             } else {
1801                 mov(rax, LDB);
1802                 imul(rax, K);
1803                 sub(BO1, rax);
1804                 add(BO1, unroll_n * SIZE);
1805             }
1806         };
1807
1808         auto kernel_16x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1809                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1810             kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1811                     isDirect, isCopy, true);
1812         };
1813
1814         auto kernel_16x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1815                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1816             kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1817                     isDirect, isCopy, true);
1818         };
1819
1820         auto kernel_16x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1821                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1822             kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1823                     isDirect, isCopy, true);
1824         };
1825
1826         auto kernel_16x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1827                 bool isLoad2Unmasked, bool isDirect, bool isCopy,
1828                 bool useFma = true) {
1829             kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1830                     isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
1831                     Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1832                     Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9),
1833                     Ymm(13), Ymm(14), Ymm(15));
1834         };
1835
1836         auto kernel_16x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1837                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1838             kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1839                     isDirect, isCopy, false);
1840         };
1841
1842         auto kernel_16x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1843                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1844             kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1845                     isDirect, isCopy, false);
1846         };
1847
1848         auto kernel_8x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1849                 bool isLoad2Unmasked, bool isDirect, bool isCopy,
1850                 bool useFma = true) {
1851             kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1852                     isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
1853                     Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1854                     Ymm(15), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1855                     Ymm(15));
1856         };
1857
1858         auto kernel_8x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1859                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1860             kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1861                     isDirect, isCopy);
1862         };
1863
1864         auto kernel_8x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1865                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1866             kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1867                     isDirect, isCopy);
1868         };
1869
1870         auto kernel_8x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1871                 bool isLoad2Unmasked, bool isDirect, bool isCopy,
1872                 bool useFma = true) {
1873             kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1874                     isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
1875                     Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1876                     Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9),
1877                     Ymm(13), Ymm(14), Ymm(15));
1878         };
1879
1880         auto kernel_8x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1881                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1882             kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1883                     isDirect, isCopy, false);
1884         };
1885
1886         auto kernel_8x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1887                 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1888             kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1889                     isDirect, isCopy, false);
1890         };
1891
1892         // High-level subroutine; does packing if needed, then splits C matrix.
1893         // Operates on chunks of 16 rows, 6 columns at a time (handling tail
1894         // cases appropriately).
1895         // Masking is used for tail cases where M is not divisible by 8.
1896         auto subloop = [&](
1897                 int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) {
1898             if (isTransA) {
1899                 do_pack(unroll_m, isLoad1Unmasked, isLoad2Unmasked);
1900             }
1901
1902             Label subloop11, subloop11mask;
1903             Label subloop20, subloop21, subloop22, subloop23;
1904             Label subloop24, subloop25;
1905             Label subloop30, subloop31, subloop32, subloop33;
1906             Label subloop34, subloop35;
1907             Label subloop98, subloop98mask;
1908             Label subloop99, subloop99mask;
1909
1910             mov(CO1, C);
1911             lea(CO2, ptr[CO1 + LDC * 2]);
1912             add(CO2, LDC);
1913             add(C, unroll_m * SIZE);
1914             mov(BO1, B);
1915             if (!isTransB) {
1916                 lea(BO2, qword[B + LDB3]);
1917             }
1918
1919             if (!isTransA) {
1920                 lea(AA, ptr[A + (unroll_m * 2 - 1 - OFFSET) * SIZE]);
1921                 cmp(M, UNROLL_M);
1922                 jg(subloop98, T_NEAR);
1923
1924                 mov(AA, ORIG_A);
1925                 lea(AA, ptr[AA + (unroll_m - 1 - OFFSET) * SIZE]);
1926                 L(subloop98);
1927             }
1928
1929             mov(LL, N);
1930             mov(I, LL);
1931             if (!isTransA) {
1932                 // If N is too small, skip copy operation
1933                 cmp(LL, UNROLL_N * 3);
1934                 jle(subloop30, T_NEAR);
1935
1936                 // If A is not aligned to cache line
1937                 cmp(FLAG, 0);
1938                 je(subloop30, T_NEAR);
1939             } else {
1940                 cmp(LL, UNROLL_N);
1941                 jl(subloop20, T_NEAR);
1942             }
1943             align(16);
1944
1945             if (!isTransA) {
1946                 if (unroll_m == 16) {
1947                     kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1948                             isLoad2Unmasked, true, true);
1949                 } else {
1950                     kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1951                             isLoad2Unmasked, true, true);
1952                 }
1953             } else {
1954                 if (unroll_m == 16) {
1955                     kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1956                             isLoad2Unmasked, false, false);
1957                 } else {
1958                     kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1959                             isLoad2Unmasked, false, false);
1960                 }
1961             }
1962
1963             sub(I, UNROLL_N);
1964             cmp(I, UNROLL_N);
1965             jl(subloop20, T_NEAR);
1966             align(16);
1967
1968             L(subloop11);
1969             if (unroll_m == 16) {
1970                 kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1971                         isLoad2Unmasked, false, false);
1972             } else {
1973                 kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, isLoad2Unmasked,
1974                         false, false);
1975             }
1976             sub(I, UNROLL_N);
1977             cmp(I, UNROLL_N);
1978             jge(subloop11, T_NEAR);
1979             align(16);
1980
1981             L(subloop20);
1982             cmp(I, 1);
1983             jne(subloop21, T_NEAR);
1984             if (unroll_m == 16) {
1985                 kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
1986                         false, false);
1987             } else {
1988                 kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, false,
1989                         false);
1990             }
1991             jmp(subloop99, T_NEAR);
1992             align(16);
1993
1994             L(subloop21);
1995             cmp(I, 2);
1996             jne(subloop22, T_NEAR);
1997             if (unroll_m == 16) {
1998                 kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
1999                         false, false);
2000             } else {
2001                 kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, false,
2002                         false);
2003             }
2004             jmp(subloop99, T_NEAR);
2005             align(16);
2006
2007             L(subloop22);
2008             cmp(I, 3);
2009             jne(subloop23, T_NEAR);
2010             if (unroll_m == 16) {
2011                 kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
2012                         false, false);
2013             } else {
2014                 kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, false,
2015                         false);
2016             }
2017             jmp(subloop99, T_NEAR);
2018             align(16);
2019
2020             L(subloop23);
2021             cmp(I, 4);
2022             jne(subloop24, T_NEAR);
2023             if (unroll_m == 16) {
2024                 kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
2025                         false, false);
2026             } else {
2027                 kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, false,
2028                         false);
2029             }
2030             jmp(subloop99, T_NEAR);
2031             align(16);
2032
2033             L(subloop24);
2034             cmp(I, 5);
2035             jne(subloop99, T_NEAR);
2036             if (unroll_m == 16) {
2037                 kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
2038                         false, false);
2039             } else {
2040                 kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, false,
2041                         false);
2042             }
2043             jmp(subloop99, T_NEAR);
2044             align(16);
2045
2046             if (!isTransA) {
2047                 L(subloop30);
2048                 cmp(I, UNROLL_N);
2049                 jl(subloop25, T_NEAR);
2050                 align(16);
2051
2052                 L(subloop31);
2053                 if (unroll_m == 16) {
2054                     kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
2055                             isLoad2Unmasked, true, false);
2056                 } else {
2057                     kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
2058                             isLoad2Unmasked, true, false);
2059                 }
2060                 sub(I, UNROLL_N);
2061                 cmp(I, UNROLL_N);
2062                 jge(subloop31, T_NEAR);
2063                 align(16);
2064
2065                 L(subloop25);
2066                 cmp(I, 1);
2067                 jne(subloop32, T_NEAR);
2068                 if (unroll_m == 16) {
2069                     kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
2070                             true, false);
2071                 } else {
2072                     kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
2073                             true, false);
2074                 }
2075                 jmp(subloop99, T_NEAR);
2076                 align(16);
2077
2078                 L(subloop32);
2079                 cmp(I, 2);
2080                 jne(subloop33, T_NEAR);
2081                 if (unroll_m == 16) {
2082                     kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
2083                             true, false);
2084                 } else {
2085                     kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
2086                             true, false);
2087                 }
2088                 jmp(subloop99, T_NEAR);
2089                 align(16);
2090
2091                 L(subloop33);
2092                 cmp(I, 3);
2093                 jne(subloop34, T_NEAR);
2094                 if (unroll_m == 16) {
2095                     kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
2096                             true, false);
2097                 } else {
2098                     kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
2099                             true, false);
2100                 }
2101                 jmp(subloop99, T_NEAR);
2102                 align(16);
2103
2104                 L(subloop34);
2105                 cmp(I, 4);
2106                 jne(subloop35, T_NEAR);
2107                 if (unroll_m == 16) {
2108                     kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
2109                             true, false);
2110                 } else {
2111                     kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
2112                             true, false);
2113                 }
2114                 jmp(subloop99, T_NEAR);
2115                 align(16);
2116
2117                 L(subloop35);
2118                 cmp(I, 5);
2119                 jne(subloop99, T_NEAR);
2120                 if (unroll_m == 16) {
2121                     kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
2122                             true, false);
2123                 } else {
2124                     kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
2125                             true, false);
2126                 }
2127                 align(16);
2128             }
2129
2130             L(subloop99);
2131             // Compute address for A
2132             if (!isTransA) {
2133                 add(A, unroll_m * SIZE);
2134             } else {
2135                 mov(rax, LDA);
2136                 imul(rax, rax, unroll_m);
2137                 add(A, rax);
2138             }
2139
2140             // Compute next address of BIAS
2141             if (hasBias) {
2142                 add(BIAS, unroll_m * SIZE);
2143             }
2144         };
2145
2146         preamble();
2147
2148         Label buffer_in_ws, buffer_allocated;
2149
2150         // Get the registers
2151         mov(B, ARG_B);
2152         mov(LDB, ARG_LDB);
2153         mov(r15, ARG_BETA);
2154         mov(r12, ARG_C);
2155         if (hasBias)
2156             mov(r10, ARG_BIAS);
2157         mov(LDC, ARG_LDC);
2158         mov(rbp, rsp);
2159
2160         vmovss(xmm0, ptr[ARG_ALPHA]);
2161         vmovss(xmm1, ptr[r15]);
2162
2163 #if _WIN32
2164         mov(A, ARG_A);
2165         mov(LDA, ARG_LDA);
2166 #endif
2167
2168         cmp(K, STACK_K_CAPACITY);
2169         jg(buffer_in_ws, T_NEAR);
2170
2171         // Create buffer and align to 4kB page
2172         lea(rax, ptr[K * SIZE]);
2173         sal(rax, 4);
2174         add(rax, 256);
2175         sub(rsp, rax);
2176         and_(rsp, -PAGE_4K);
2177         jmp(buffer_allocated, T_NEAR);
2178
2179         L(buffer_in_ws);
2180         mov(rsp, ARG_WS);
2181
2182         L(buffer_allocated);
2183
2184         mov(ORIG_SP, rbp);
2185         mov(M, ARG_M);
2186         mov(N, ARG_N);
2187         mov(C, r12);
2188         if (hasBias)
2189             mov(BIAS, r10);
2190         vmovss(ALPHA, xmm0);
2191         vmovss(BETA, xmm1);
2192         sub(A, -OFFSET * SIZE);
2193         sub(B, -OFFSET * SIZE);
2194         mov(ORIG_A, A);
2195         sal(LDA, BASE_SHIFT);
2196         sal(LDB, BASE_SHIFT);
2197         sal(LDC, BASE_SHIFT);
2198         lea(LDB3, ptr[LDB + LDB * 2]);
2199
2200         for (int i = 0; i < 8; i++) {
2201             mov(dword[rsp + 88 + i * 4], i);
2202         }
2203
2204         if (isTransA && is_avx2) {
2205             movq(xmm0, LDA);
2206             vpbroadcastq(ymm1, xmm0);
2207             vinsertf128(ymm0, ymm0, xmm0, 1);
2208             vpermilpd(ymm0, ymm0, 5);
2209             vpaddq(ymm1, ymm1, ymm1);
2210             vperm2f128(ymm1, ymm1, ymm1, 8);
2211             vpaddq(ymm0, ymm0, ymm1);
2212             vmovups(STRIDE, ymm0);
2213         }
2214
2215         // Check A alignment and leading dimension; take copy-based path as
2216         // needed
2217         mov(rax, LDA);
2218         or_(rax, A);
2219         and_(rax, 0x1f);
2220         mov(FLAG, rax);
2221
2222         Label main0, main1, main2, main3, main999;
2223
2224         cmp(M, UNROLL_M);
2225         jl(main0, T_NEAR);
2226         align(16);
2227
2228         L(main1);
2229         subloop(UNROLL_M, true, true);
2230         sub(M, UNROLL_M);
2231         cmp(M, UNROLL_M);
2232         jge(main1, T_NEAR);
2233         align(16);
2234
2235         L(main0);
2236         cmp(M, 0);
2237         jle(main999, T_NEAR);
2238
2239         if (UNROLL_M > 8) {
2240             cmp(M, 8);
2241             jle(main2, T_NEAR);
2242
2243             sub(M, 8);
2244             vbroadcastss(VMASK, M);
2245             vpcmpgtd(VMASK, VMASK, MASK);
2246
2247             subloop(16, true, false);
2248             jmp(main999, T_NEAR);
2249             align(16);
2250
2251             L(main2);
2252             cmp(M, 8);
2253             jne(main3, T_NEAR);
2254             subloop(8, true, true);
2255             jmp(main999, T_NEAR);
2256         }
2257
2258         align(16);
2259
2260         L(main3);
2261         vbroadcastss(VMASK, M);
2262         if (is_avx2) {
2263             vpcmpgtd(VMASK, VMASK, MASK);
2264         } else {
2265             auto xmask = Xmm(VMASK.getIdx());
2266             auto xmm_tmp = xmm4;
2267
2268             vextractf128(xmm_tmp, VMASK, 1);
2269             vpcmpgtd(xmask, xmask, MASK);
2270             vpcmpgtd(xmm_tmp, xmm_tmp, dword[rsp + 88 + 4 * 4]); // MASK + 4
2271             vinsertf128(VMASK, VMASK, xmm_tmp, 1);
2272         }
2273         subloop(8, false, false);
2274         align(16);
2275
2276         L(main999);
2277         // Restore original stack
2278         mov(rax, ORIG_SP);
2279         mov(rsp, rax);
2280
2281         vzeroupper();
2282         postamble();
2283
2284         ker_ = reinterpret_cast<decltype(ker_)>(
2285                 const_cast<uint8_t *>(this->getCode()));
2286     }
2287
2288     void operator()(long long int m, long long int n, long long int k,
2289             const float *alpha, const float *a, long long int lda,
2290             const float *b, long long int ldb, const float *beta, float *c,
2291             long long int ldc, const float *bias, float *ws)
2292     {
2293         (*ker_)(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws);
2294     }
2295
private:
    // Entry point of the JIT-generated code. It is assigned at the end of
    // the constructor from getCode() and invoked by operator().
    void (*ker_)(long long int m, long long int n, long long int k,
            const float *alpha, const float *a, long long int lda,
            const float *b, long long int ldb, const float *beta, float *c,
            long long int ldc, const float *bias, float *ws);
2301 };
2302
// Function-pointer type for a GEMM kernel that takes integer sizes and
// leading dimensions plus float pointers.
// NOTE(review): this alias has 12 parameters, whereas xbyak_gemm::ker_ takes
// 13 (with const-qualified inputs and a trailing `ws`). Nothing in this part
// of the file uses it; confirm whether it is still needed.
using ker = void (*)(long long int, long long int, long long int, float *,
        float *, long long int, float *, long long int, float *, float *,
        long long int, float *);
2306 void jit_avx_gemm_f32::sgemm_nocopy_driver(const char *transa,
2307         const char *transb, int m, int n, int k, const float *alpha,
2308         const float *a, int lda, const float *b, int ldb, const float *beta,
2309         float *c, int ldc, const float *bias, float *ws)
2310 {
2311     bool isTransA = (*transa == 'T' || *transa == 't');
2312     bool isTransB = (*transb == 'T' || *transb == 't');
2313
2314     int Bm, sizeM, Bn, sizeN, Bk, sizeK;
2315
2316     int i, j;
2317
2318     if ((m <= 0) || (n <= 0))
2319         return;
2320
2321     if ((k <= 0) || (alpha[0] == 0.)) {
2322
2323         if (beta[0] == 0.) {
2324             for (j = 0; j < n; j++)
2325                 for (i = 0; i < m; i++)
2326                     c[i + j * ldc] = 0.0;
2327         } else if (beta[0] != 1.) {
2328             for (j = 0; j < n; j++)
2329                 for (i = 0; i < m; i++)
2330                     c[i + j * ldc] *= beta[0];
2331         }
2332
2333         return;
2334     }
2335
2336     int BM = 4032;
2337     int BN = isTransA ? 96 : 48;
2338     int BK = isTransB ? 96 : 256;
2339     const float *curA, *curB, *curBias = nullptr;
2340     float *curC;
2341
2342     for (Bk = 0; Bk < k; Bk += sizeK) {
2343         sizeK = k - Bk;
2344         if (sizeK >= BK * 2)
2345             sizeK = BK;
2346         else {
2347             if (sizeK > BK)
2348                 sizeK = (sizeK + 1) / 2;
2349         }
2350
2351         for (Bm = 0; Bm < m; Bm += sizeM) {
2352             sizeM = m - Bm;
2353             if (sizeM >= BM * 2)
2354                 sizeM = BM;
2355             else {
2356                 if (sizeM > BM + BM / 2)
2357                     sizeM = (sizeM + 1) / 2;
2358             }
2359
2360             for (Bn = 0; Bn < n; Bn += sizeN) {
2361                 sizeN = n - Bn;
2362                 if (sizeN >= BN * 2)
2363                     sizeN = BN;
2364                 else {
2365                     if (sizeN > BN + BN / 2)
2366                         sizeN = (sizeN + 1) / 2;
2367                 }
2368
2369                 if (!isTransA) {
2370                     curA = a + Bm + (size_t)Bk * lda;
2371                 } else {
2372                     curA = a + Bk + (size_t)Bm * lda;
2373                 }
2374                 if (!isTransB) {
2375                     curB = b + Bk + (size_t)Bn * ldb;
2376                 } else {
2377                     curB = b + Bn + (size_t)Bk * ldb;
2378                 }
2379                 curC = c + Bm + (size_t)Bn * ldc;
2380                 if (bias != nullptr) {
2381                     if (Bk == 0) {
2382                         curBias = bias + Bm;
2383                     } else {
2384                         curBias = nullptr;
2385                     }
2386                 }
2387                 if (Bk == 0) {
2388                     if (*beta == 0.0 && bias == nullptr)
2389                         (*ker_b0_)((long long int)sizeM, (long long int)sizeN,
2390                                 (long long int)sizeK, alpha, curA,
2391                                 (long long int)lda, curB, (long long int)ldb,
2392                                 beta, curC, (long long int)ldc, curBias, ws);
2393                     else
2394                         (*ker_bn_)((long long int)sizeM, (long long int)sizeN,
2395                                 (long long int)sizeK, alpha, curA,
2396                                 (long long int)lda, curB, (long long int)ldb,
2397                                 beta, curC, (long long int)ldc, curBias, ws);
2398                 } else {
2399                     (*ker_b1_)((long long int)sizeM, (long long int)sizeN,
2400                             (long long int)sizeK, alpha, curA,
2401                             (long long int)lda, curB, (long long int)ldb, beta,
2402                             curC, (long long int)ldc, curBias, ws);
2403                 }
2404             }
2405         }
2406     }
2407     return;
2408 }
2409 void jit_avx_gemm_f32::sgemm(const char *transa, const char *transb,
2410         const int *p_m, const int *p_n, const int *p_k, const float *p_alpha,
2411         const float *A, const int *p_lda, const float *B, const int *p_ldb,
2412         const float *p_beta, float *C, const int *p_ldc, const float *bias)
2413 {
2414     if (beta_ == 0. || beta_ == 1.)
2415         assert(*p_beta == beta_);
2416     assert((one_of(*transa, 'T', 't') == one_of(transa_, 'T', 't')));
2417
2418     int nthr = mkldnn_in_parallel() ? 1 : mkldnn_get_max_threads();
2419     int m = *p_m;
2420     int n = *p_n;
2421     int k = *p_k;
2422     int lda = *p_lda;
2423     int ldb = *p_ldb;
2424     int ldc = *p_ldc;
2425     float beta = *p_beta;
2426     int MB, NB, KB;
2427
2428     int nthr_m, nthr_n, nthr_k, nthr_mn;
2429
2430     assert(nthr <= nthrs_);
2431
2432     // Determine threading partitioning
2433     gemm_utils::calc_nthr_nocopy_avx(
2434             m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
2435     assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
2436
2437     // May not happen, but just in case
2438     if (nthr < nthr_m * nthr_n * nthr_k)
2439         nthr = nthr_m * nthr_n * nthr_k;
2440
2441     nthr_mn = nthr_m * nthr_n;
2442
2443     unsigned char * ompstatus_ = nullptr;
2444     unsigned char volatile *ompstatus = nullptr;
2445
2446     float *c_buffers = nullptr;
2447     float *ws_buffers = nullptr;
2448
2449     if (nthr_k > 1) {
2450         ompstatus_ = (unsigned char *) malloc(
2451                 nthr * CACHE_LINE_SIZE,
2452                 CACHE_LINE_SIZE);
2453         ompstatus = (unsigned char volatile *) ompstatus_;
2454         assert(ompstatus);
2455
2456         for (int i = 0; i < nthr; i++)
2457             ompstatus[i * CACHE_LINE_SIZE] = 0;
2458
2459         c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB
2460                 * sizeof(float), PAGE_4K);
2461     }
2462
2463     const size_t ws_elems_per_thr = k * 16 + 64;
2464     const size_t ws_size_per_thr
2465             = utils::rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
2466     if (k > STACK_K_CAPACITY) {
2467         ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K);
2468     }
2469
2470     parallel(nthr, [&](const int ithr, const int nthr) {
2471         int ithr_m, ithr_n, ithr_k, ithr_mn;
2472         int m_from, m_to, myM;
2473         int n_from, n_to, myN;
2474         int k_from, k_to, myK;
2475         int cbase, ibase;
2476         const float *myA, *myB, *myBias = nullptr;
2477         float *myC = C, myBeta;
2478         float *ws = ws_buffers ?
2479                 ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0;
2480         int ld = ldc;
2481
2482         if (ithr < nthr_m * nthr_n * nthr_k) {
2483
2484             ithr_mn = ithr % nthr_mn;
2485             ithr_m = ithr_mn % nthr_m;
2486             ithr_n = ithr_mn / nthr_m;
2487             ithr_k = ithr / nthr_mn;
2488
2489             /* swap ithr_k for performance improvement */
2490             if (ithr_k == 0)
2491                 ithr_k = nthr_k - 1;
2492             else if (ithr_k == nthr_k - 1)
2493                 ithr_k = 0;
2494
2495             m_from = MB * (ithr_m);
2496             m_to = MB * (ithr_m + 1);
2497             if (m_to > m)
2498                 m_to = m;
2499             myM = m_to - m_from;
2500
2501             n_from = NB * (ithr_n);
2502             n_to = NB * (ithr_n + 1);
2503             if (n_to > n)
2504                 n_to = n;
2505             myN = n_to - n_from;
2506
2507             k_from = KB * (ithr_k);
2508             k_to = KB * (ithr_k + 1);
2509             if (k_to > k)
2510                 k_to = k;
2511             myK = k_to - k_from;
2512
2513             cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
2514             ibase = (ithr_m + nthr_m * ithr_n) * nthr_k;
2515
2516             if ((myM > 0) && (myN > 0)) {
2517
2518                 if (*transa == 'N' || *transa == 'n') {
2519                     myA = &(A[m_from + k_from * lda]);
2520                 } else {
2521                     myA = &(A[k_from + m_from * lda]);
2522                 }
2523                 if (*transb == 'N' || *transb == 'n') {
2524                     myB = &(B[k_from + n_from * ldb]);
2525                 } else {
2526                     myB = &(B[n_from + k_from * ldb]);
2527                 }
2528                 if (ithr_k == 0) {
2529                     myC = &(C[m_from + n_from * ldc]);
2530                     myBeta = beta;
2531                     ld = ldc;
2532                     if (hasBias_)
2533                         myBias = &(bias[m_from]);
2534                 } else {
2535                     myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
2536                     myBeta = 0.0;
2537                     ld = MB;
2538                     myBias = nullptr;
2539                 }
2540
2541                 sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA,
2542                         lda, myB, ldb, &myBeta, myC, ld, myBias, ws);
2543
2544                 if (nthr_k > 1)
2545                     ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1;
2546             }
2547
2548             if (nthr_k > 1) {
2549
2550                 // sum matrices partitioned along K dimension
2551                 int n1, n2;
2552
2553                 gemm_utils::partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
2554
2555                 if (ithr_k > 0) {
2556
2557                     myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
2558                     myC = myC + n1 * MB;
2559                     /* need to wait until main thread finishes */
2560                     while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) {
2561                     };
2562
2563                     /* my cache is hot */
2564                     gemm_utils::sum_two_matrices(myM, n2, myC, MB,
2565                             &C[m_from + (n_from + n1) * ldc], ldc);
2566                 }
2567
2568                 for (int ik = 1; ik < nthr_k; ++ik) {
2569                     if (ik != ithr_k) {
2570
2571                         myC = c_buffers + MB * NB * (cbase + ik - 1);
2572                         myC = myC + n1 * MB;
2573
2574                         while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) {
2575                         };
2576
2577                         gemm_utils::sum_two_matrices(myM, n2, myC, MB,
2578                                 &C[m_from + (n_from + n1) * ldc], ldc);
2579                     }
2580                 }
2581             }
2582         }
2583     });
2584
2585     free(c_buffers);
2586     free(ompstatus_);
2587     free(ws_buffers);
2588 }
2589
2590 jit_avx_gemm_f32::jit_avx_gemm_f32(
2591         char transa, char transb, float beta, bool hasBias)
2592 {
2593     transa_ = transa;
2594     transb_ = transb;
2595     beta_ = beta;
2596     hasBias_ = hasBias;
2597     if (hasBias) {
2598         assert(beta == 0.0);
2599     }
2600     ker_bn_ = new xbyak_gemm(transa, transb, beta, hasBias);
2601     if (beta != 1.0) {
2602         ker_b1_ = new xbyak_gemm(transa, transb, 1.0);
2603     } else {
2604         ker_b1_ = ker_bn_;
2605     }
2606     if (beta != 0.0 || (beta == 0.0 && hasBias)) {
2607         ker_b0_ = new xbyak_gemm(transa, transb, 0.0);
2608     } else {
2609         ker_b0_ = ker_bn_;
2610     }
2611     nthrs_ = mkldnn_get_max_threads();
2612 }
2613
2614 jit_avx_gemm_f32::~jit_avx_gemm_f32()
2615 {
2616     delete ker_bn_;
2617     if (beta_ != 1.0)
2618         delete ker_b1_;
2619     if (beta_ != 0.0 || (beta_ == 0.0 && hasBias_))
2620         delete ker_b0_;
2621 }
2622
2623 }
2624 }
2625 }
2626
2627 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s