1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
19 #include "mkldnn_thread.hpp"
21 #include "gemm_utils.hpp"
22 #include "jit_avx_gemm_f32.hpp"
24 #define CACHE_LINE_SIZE 64
30 using namespace mkldnn::impl::memory_format;
31 using namespace mkldnn::impl::utils;
33 using namespace Xbyak;
34 #define STACKSIZE get_size_of_abi_save_regs()
36 #define STACK_K_CAPACITY 128
38 #define STACK_K_CAPACITY 8192
43 #define SECOND_FETCH 14
45 struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
46 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_gemm_f32_xbyak_gemm)
48 xbyak_gemm(char transa, char transb, float beta, bool hasBias = false,
49 void *code_ptr = nullptr,
50 size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE)
51 : jit_generator(code_ptr, code_size)
53 const bool is_avx2 = mayiuse(avx2);
54 assert(IMPLICATION(!is_avx2, mayiuse(avx)));
56 const int UNROLL_M = is_avx2 ? 16 : 8;
57 const int UNROLL_N = 6;
59 bool isTransA = (transa == 'T' || transa == 't');
60 bool isTransB = (transb == 'T' || transb == 't');
61 bool isBeta0 = (beta == 0.0);
62 bool isBetaN = (!isBeta0 && beta != 1.0);
64 // various definitions for convenience
65 auto ARG_M = abi_param1;
66 auto ARG_N = abi_param2;
68 auto ARG_ALPHA = abi_param4;
70 auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE];
71 auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE +
72 sizeof(float *) + STACKSIZE];
73 const auto stackOffset = OFFSET_SHADOWSPACE +
74 sizeof(float *) + STACKSIZE;
80 const auto stackOffset = STACKSIZE;
84 auto ARG_B = ptr[rsp + 8 + stackOffset];
85 auto ARG_LDB = ptr[rsp + 16 + stackOffset];
86 auto ARG_BETA = ptr[rsp + 24 + stackOffset];
87 auto ARG_C = ptr[rsp + 32 + stackOffset];
88 auto ARG_LDC = ptr[rsp + 40 + stackOffset];
89 auto ARG_BIAS = ptr[rsp + 48 + stackOffset];
90 auto ARG_WS = ptr[rsp + 56 + stackOffset];
96 auto AO1 = abi_param2;
97 auto BO1 = abi_param4;
102 auto LDA4 = abi_param1;
104 auto BIAS1 = abi_param1;
106 auto M = qword[rsp + 0];
107 auto N = qword[rsp + 8];
108 auto FLAG = qword[rsp + 16];
109 auto I = qword[rsp + 24];
110 auto C = qword[rsp + 32];
111 auto BIAS = qword[rsp + 40];
112 auto ALPHA = qword[rsp + 48];
113 auto BETA = qword[rsp + 64];
114 auto ORIG_A = qword[rsp + 80];
115 auto MASK = dword[rsp + 88];
116 auto STRIDE = qword[rsp + 120];
117 auto ORIG_SP = qword[rsp + 152];
125 auto PREFETCHSIZEA = 128;
126 auto PREFETCHSIZEB = (!isTransB) ? -16 : 0;
128 // Function for packing if needed
130 int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) {
131 Label pack2, pack3, pack4, pack10;
137 lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]);
140 lea(BO2, ptr[BO1 + LDA * 4]);
141 lea(CO1, ptr[LDA + LDA * 2]);
142 vmovupd(ymm7, STRIDE);
152 for (int i = 0; i < 4; i++) {
153 regIdx = (i % 2 == 0) ? 4 : 6;
154 if (isLoad1Unmasked) {
156 ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
158 vmaskmovps(Ymm(regIdx), VMASK,
159 ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
162 if (isLoad2Unmasked) {
163 vmovups(Ymm(regIdx + 1),
164 ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
166 vmaskmovps(Ymm(regIdx + 1), VMASK,
167 ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
172 vmovups(ptr[AO1 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
176 + (unroll_m * i + 1 * 8 - OFFSET)
183 if (isLoad1Unmasked) {
184 for (int i = 0; i < 2; i++) {
185 reg = (i % 2 == 0) ? BO1 : BO2;
186 vmovups(xmm0, ptr[reg + (0 * 8 - OFFSET) * SIZE]);
188 ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
189 lea(BO2, ptr[reg + LDA * 2]);
190 vunpcklps(xmm4, xmm0, xmm1);
191 vunpckhps(xmm5, xmm0, xmm1);
192 vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
194 ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
195 lea(BO2, ptr[BO2 + LDA * 2]);
196 vunpcklps(xmm6, xmm0, xmm1);
197 vunpckhps(xmm2, xmm0, xmm1);
199 vunpcklpd(xmm0, xmm4, xmm6);
200 vunpckhpd(xmm1, xmm4, xmm6);
202 + (unroll_m * 0 + i * 4 - OFFSET)
206 + (unroll_m * 1 + i * 4 - OFFSET)
209 vunpcklpd(xmm0, xmm5, xmm2);
210 vunpckhpd(xmm1, xmm5, xmm2);
212 + (unroll_m * 2 + i * 4 - OFFSET)
216 + (unroll_m * 3 + i * 4 - OFFSET)
220 } else if (is_avx2) {
221 for (int i = 0; i < 2; i++) {
224 ptr[BO1 + ymm7 + ((2 * i) - OFFSET) * SIZE],
228 ptr[BO1 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE],
232 + (unroll_m * (2 * i) + 0 * 4 - OFFSET)
236 + (unroll_m * (2 * i + 1) + 0 * 4
242 lea(BO2, ptr[BO1 + LDA * 4]);
244 for (int i = 0; i < 2; i++) {
245 vextractf128(xmm4, ymm3, 1);
247 ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
249 vextractf128(xmm4, ymm3, 1);
251 ptr[BO2 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE],
255 + (unroll_m * (2 * i) + 1 * 4 - OFFSET)
259 + (unroll_m * (2 * i + 1) + 1 * 4
265 lea(BO2, ptr[BO2 + LDA * 4]);
267 vxorps(xmm4, xmm4, xmm4);
268 lea(BO2, ptr[BO1 + LDA * 4]);
270 auto el_cp = [&](int section, int ld_step) {
271 RegExp src_addr = section == 0 ? BO1 : BO2;
272 if (ld_step == 1 || ld_step == 2)
273 src_addr = src_addr + LDA * ld_step;
274 else if (ld_step == 3)
275 src_addr = src_addr + CO1;
276 src_addr = src_addr - OFFSET * SIZE;
278 vmovups(Xmm(ld_step % 2), ptr[src_addr]);
279 RegExp dst_addr = AO1
280 + (ld_step + section * 4 - OFFSET) * SIZE;
281 for (int off = 0; off < 4; ++off)
282 pextrd(ptr[dst_addr + unroll_m * off * SIZE],
283 Xmm(ld_step % 2), off);
287 el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR);
288 el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR);
289 el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR);
290 el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR);
291 el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR);
292 el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR);
296 lea(BO2, ptr[BO2 + LDA * 4]);
299 if (unroll_m >= 16) {
301 if (isLoad2Unmasked) {
302 for (int i = 0; i < 2; i++) {
303 vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
304 vmovups(xmm1, ptr[BO2 + LDA * 1
305 + (0 * 8 - OFFSET) * SIZE]);
306 lea(BO2, ptr[BO2 + LDA * 2]);
307 vunpcklps(xmm4, xmm0, xmm1);
308 vunpckhps(xmm5, xmm0, xmm1);
309 vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
310 vmovups(xmm1, ptr[BO2 + LDA * 1
311 + (0 * 8 - OFFSET) * SIZE]);
313 lea(BO2, ptr[BO2 + LDA * 2]);
314 vunpcklps(xmm6, xmm0, xmm1);
315 vunpckhps(xmm2, xmm0, xmm1);
317 vunpcklpd(xmm0, xmm4, xmm6);
318 vunpckhpd(xmm1, xmm4, xmm6);
320 + (unroll_m * 0 + (i + 2) * 4
325 + (unroll_m * 1 + (i + 2) * 4
329 vunpcklpd(xmm0, xmm5, xmm2);
330 vunpckhpd(xmm1, xmm5, xmm2);
332 + (unroll_m * 2 + (i + 2) * 4
337 + (unroll_m * 3 + (i + 2) * 4
343 for (int i = 0; i < 2; i++) {
346 ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
351 + ((2 * i + 1) - OFFSET) * SIZE],
355 + (unroll_m * (2 * i) + 2 * 4
360 + (unroll_m * (2 * i + 1) + 2 * 4
366 lea(BO2, ptr[BO2 + LDA * 4]);
368 for (int i = 0; i < 2; i++) {
369 vextractf128(xmm4, ymm3, 1);
371 ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
373 vextractf128(xmm4, ymm3, 1);
376 + ((2 * i + 1) - OFFSET) * SIZE],
380 + (unroll_m * (2 * i) + 3 * 4
385 + (unroll_m * (2 * i + 1) + 3 * 4
391 lea(BO2, ptr[BO2 + LDA * 4]);
394 add(BO1, (4 * SIZE));
397 add(AO1, unroll_m * 4 * SIZE);
410 if (isLoad1Unmasked) {
411 vmovups(ymm4, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
413 vmaskmovps(ymm4, VMASK, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
416 if (isLoad2Unmasked) {
417 vmovups(ymm5, ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
419 vmaskmovps(ymm5, VMASK,
420 ptr[BO1 + (1 + 8 - OFFSET) * SIZE]);
424 vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
427 vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE],
431 if (isLoad1Unmasked) {
432 for (int i = 0; i < 2; i++) {
433 reg = (i % 2 == 0) ? BO1 : BO2;
434 vmovss(Xmm(i + 1), ptr[reg + (0 * 8 - OFFSET) * SIZE]);
436 ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
437 lea(BO2, ptr[reg + LDA * 2]);
438 vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
440 vunpcklpd(xmm1, xmm1, xmm2);
441 vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE],
444 for (int i = 0; i < 2; i++) {
445 vmovss(Xmm(i + 1), ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
447 ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
448 lea(BO2, ptr[BO2 + LDA * 2]);
449 vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
451 vunpcklpd(xmm1, xmm1, xmm2);
452 vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE],
454 } else if (is_avx2) {
456 vgatherqps(xmm1, ptr[BO1 + ymm7 + (0 * 8 - OFFSET) * SIZE],
458 lea(BO2, ptr[BO1 + LDA * 4]);
459 vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE],
462 vextractf128(xmm4, ymm3, 1);
463 vgatherqps(xmm1, ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
465 lea(BO2, ptr[BO2 + LDA * 4]);
466 vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE],
469 vxorps(xmm4, xmm4, xmm4);
470 lea(BO2, ptr[BO1 + LDA * 4]);
472 auto el_cp = [&](int section, int ld_step) {
473 RegExp src_addr = section == 0 ? BO1 : BO2;
474 if (ld_step == 1 || ld_step == 2)
475 src_addr = src_addr + LDA * ld_step;
476 else if (ld_step == 3)
477 src_addr = src_addr + CO1;
478 src_addr = src_addr - OFFSET * SIZE;
480 vmovss(xmm1, ptr[src_addr]);
481 RegExp dst_addr = AO1
482 + (ld_step + section * 4 - OFFSET) * SIZE;
483 movss(ptr[dst_addr], xmm1);
487 el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR);
488 el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR);
489 el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR);
490 el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR);
491 el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR);
492 el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR);
496 lea(BO2, ptr[BO2 + LDA * 4]);
499 if (unroll_m >= 16) {
501 if (isLoad2Unmasked) {
502 for (int i = 0; i < 2; i++) {
504 ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
505 vmovss(xmm0, ptr[BO2 + LDA * 1
506 + (0 * 8 - OFFSET) * SIZE]);
507 lea(BO2, ptr[BO2 + LDA * 2]);
508 vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
510 vunpcklpd(xmm1, xmm1, xmm2);
514 ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
516 lea(BO2, ptr[BO2 + LDA * 4]);
518 vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 4 - OFFSET) * SIZE],
521 if (isLoad2Unmasked) {
522 for (int i = 0; i < 2; i++) {
524 ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
525 vmovss(xmm0, ptr[BO2 + LDA * 1
526 + (0 * 8 - OFFSET) * SIZE]);
527 lea(BO2, ptr[BO2 + LDA * 2]);
528 vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
530 vunpcklpd(xmm1, xmm1, xmm2);
532 vextractf128(xmm4, ymm3, 1);
534 ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
537 vmovups(ptr[AO1 + (unroll_m * 0 + 3 * 4 - OFFSET) * SIZE],
543 add(AO1, unroll_m * SIZE);
551 // Fused multiply add; may become one or two instructions
552 auto fma = [&](bool useFma, Ymm reg0, Ymm reg1, Ymm reg2,
553 bool overWrite = false) {
556 vfmadd231ps(reg2, reg1, reg0);
558 assert(UNROLL_M == 8);
559 auto tent_vreg = overWrite ? reg1 : ymm1;
560 vmulps(tent_vreg, reg1, reg0);
561 vaddps(reg2, reg2, tent_vreg);
565 vmulps(ymm15, reg1, reg0);
566 vaddps(reg2, reg2, ymm15);
568 vmulps(reg1, reg1, reg0);
569 vaddps(reg2, reg2, reg1);
574 // Inner kernel with k=8
575 auto innerkernel8 = [&](int unroll_m, int unroll_n,
576 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
577 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
578 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
579 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
580 Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
581 Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
587 prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]);
589 prefetcht0(ptr[AO1 + LDA4]);
592 for (int i = 0; i < 8; i++) {
594 if (isLoad1Unmasked) {
595 vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
597 vmaskmovps(ymm0, VMASK,
598 ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
600 if (unroll_m >= 16) {
601 if (isLoad2Unmasked) {
602 vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
604 vmaskmovps(ymm1, VMASK,
605 ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
612 vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]);
614 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
616 fmareg = (i % 2 == 0) ? reg00 : reg12;
617 fma(useFma, ymm0, ymm2, fmareg);
618 if (unroll_m >= 16) {
619 fmareg = (i % 2 == 0) ? reg06 : reg18;
620 fma(useFma, ymm1, ymm2, fmareg);
624 prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]);
630 prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]);
633 ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
635 vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
637 fmareg = (i % 2 == 0) ? reg01 : reg13;
638 fma(useFma, ymm0, ymm2, fmareg);
639 if (unroll_m >= 16) {
640 fmareg = (i % 2 == 0) ? reg07 : reg19;
641 fma(useFma, ymm1, ymm2, fmareg);
646 vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
648 if (unroll_m >= 16) {
650 + (unroll_m * i + 1 * 8 - OFFSET)
655 sub(LDA4, -unroll_m * 8 * SIZE);
663 ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
666 ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]);
668 vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
670 fmareg = (i % 2 == 0) ? reg02 : reg14;
671 fma(useFma, ymm0, ymm2, fmareg);
672 if (unroll_m >= 16) {
673 fmareg = (i % 2 == 0) ? reg08 : reg20;
674 fma(useFma, ymm1, ymm2, fmareg);
687 prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]);
689 vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]);
691 vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
693 fmareg = (i % 2 == 0) ? reg03 : reg15;
694 fma(useFma, ymm0, ymm2, fmareg);
695 if (unroll_m >= 16) {
696 fmareg = (i % 2 == 0) ? reg09 : reg21;
697 fma(useFma, ymm1, ymm2, fmareg);
704 prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]);
707 ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]);
709 vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
711 fmareg = (i % 2 == 0) ? reg04 : reg16;
712 fma(useFma, ymm0, ymm2, fmareg);
713 if (unroll_m >= 16) {
714 fmareg = (i % 2 == 0) ? reg10 : reg22;
715 fma(useFma, ymm1, ymm2, fmareg);
723 ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]);
726 ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]);
728 vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
730 fmareg = (i % 2 == 0) ? reg05 : reg17;
731 fma(useFma, ymm0, ymm2, fmareg);
732 if (unroll_m >= 16) {
733 fmareg = (i % 2 == 0) ? reg11 : reg23;
734 fma(useFma, ymm1, ymm2, fmareg);
738 prefetcht0(ptr[BO1 + BO2]);
746 ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]);
748 prefetcht0(ptr[AO1 + LDA4]);
752 if (i == 1 || i == 2) {
756 + (PREFETCHSIZEA + (2 + 2 * i) * 8)
759 prefetcht0(ptr[AO1 + LDA4]);
763 if (i == 3 || i == 4 || i == 5 || i == 6) {
764 if (unroll_m >= 16) {
767 + (PREFETCHSIZEA + (2 + 2 * i) * 8)
770 prefetcht0(ptr[AO1 + LDA4]);
782 lea(AA, ptr[AA + LDA]);
787 if (isLoad1Unmasked) {
790 + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
796 + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
799 if (unroll_m >= 16) {
800 if (isLoad2Unmasked) {
801 vmovups(ymm1, ptr[AO1
802 + (unroll_m * (i + 1) + 1 * 8
806 vmaskmovps(ymm1, VMASK,
808 + (unroll_m * (i + 1) + 1 * 8
817 sub(AO1, -unroll_m * 8 * SIZE);
823 // Inner kernel with k=4
824 auto innerkernel4 = [&](int unroll_m, int unroll_n,
825 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
826 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
827 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
828 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
829 Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
830 Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
836 prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]);
838 prefetcht0(ptr[AO1 + LDA4]);
841 for (int i = 0; i < 4; i++) {
843 if (isLoad1Unmasked) {
844 vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
846 vmaskmovps(ymm0, VMASK,
847 ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
849 if (unroll_m >= 16) {
850 if (isLoad2Unmasked) {
851 vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
853 vmaskmovps(ymm1, VMASK,
854 ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
861 vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]);
863 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
865 fmareg = (i % 2 == 0) ? reg00 : reg12;
866 fma(useFma, ymm0, ymm2, fmareg);
867 if (unroll_m >= 16) {
868 fmareg = (i % 2 == 0) ? reg06 : reg18;
869 fma(useFma, ymm1, ymm2, fmareg);
873 prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]);
879 prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]);
882 ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
884 vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
886 fmareg = (i % 2 == 0) ? reg01 : reg13;
887 fma(useFma, ymm0, ymm2, fmareg);
888 if (unroll_m >= 16) {
889 fmareg = (i % 2 == 0) ? reg07 : reg19;
890 fma(useFma, ymm1, ymm2, fmareg);
895 vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
897 if (unroll_m >= 16) {
899 + (unroll_m * i + 1 * 8 - OFFSET)
904 sub(LDA4, -unroll_m * 4 * SIZE);
912 ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
915 ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]);
917 vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
919 fmareg = (i % 2 == 0) ? reg02 : reg14;
920 fma(useFma, ymm0, ymm2, fmareg);
921 if (unroll_m >= 16) {
922 fmareg = (i % 2 == 0) ? reg08 : reg20;
923 fma(useFma, ymm1, ymm2, fmareg);
936 prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]);
938 vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]);
940 vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
942 fmareg = (i % 2 == 0) ? reg03 : reg15;
943 fma(useFma, ymm0, ymm2, fmareg);
944 if (unroll_m >= 16) {
945 fmareg = (i % 2 == 0) ? reg09 : reg21;
946 fma(useFma, ymm1, ymm2, fmareg);
953 prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]);
956 ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]);
958 vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
960 fmareg = (i % 2 == 0) ? reg04 : reg16;
961 fma(useFma, ymm0, ymm2, fmareg);
962 if (unroll_m >= 16) {
963 fmareg = (i % 2 == 0) ? reg10 : reg22;
964 fma(useFma, ymm1, ymm2, fmareg);
972 ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]);
975 ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]);
977 vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
979 fmareg = (i % 2 == 0) ? reg05 : reg17;
980 fma(useFma, ymm0, ymm2, fmareg);
981 if (unroll_m >= 16) {
982 fmareg = (i % 2 == 0) ? reg11 : reg23;
983 fma(useFma, ymm1, ymm2, fmareg);
987 prefetcht0(ptr[BO1 + BO2]);
995 ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]);
997 prefetcht0(ptr[AO1 + LDA4]);
1001 if (i == 1 || i == 2) {
1002 if (unroll_m >= 8) {
1005 + (PREFETCHSIZEA + (2 + 2 * i) * 8)
1008 prefetcht0(ptr[AO1 + LDA4]);
1014 sub(BO1, -4 * SIZE);
1015 if (unroll_n >= 4) {
1016 sub(BO2, -4 * SIZE);
1022 if (isLoad1Unmasked) {
1025 + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
1031 + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
1034 if (unroll_m >= 16) {
1035 if (isLoad2Unmasked) {
1036 vmovups(ymm1, ptr[AO1
1037 + (unroll_m * (i + 1) + 1 * 8
1041 vmaskmovps(ymm1, VMASK,
1043 + (unroll_m * (i + 1) + 1 * 8
1052 sub(AO1, -unroll_m * 4 * SIZE);
1057 // Inner kernel with k=2
1058 auto innerkernel2 = [&](int unroll_m, int unroll_n,
1059 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
1060 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
1061 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
1062 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
1063 Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
1064 Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
1069 for (int i = 0; i < 2; i++) {
1071 if (isLoad1Unmasked) {
1072 vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1074 vmaskmovps(ymm0, VMASK,
1075 ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1077 if (unroll_m >= 16) {
1078 if (isLoad2Unmasked) {
1079 vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1081 vmaskmovps(ymm1, VMASK,
1082 ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1089 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1091 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1093 fmareg = (i % 2 == 0) ? reg00 : reg12;
1094 fma(useFma, ymm0, ymm2, fmareg);
1095 if (unroll_m >= 16) {
1096 fmareg = (i % 2 == 0) ? reg06 : reg18;
1097 fma(useFma, ymm1, ymm2, fmareg);
1099 if (unroll_n >= 2) {
1102 ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]);
1104 vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
1106 fmareg = (i % 2 == 0) ? reg01 : reg13;
1107 fma(useFma, ymm0, ymm2, fmareg);
1108 if (unroll_m >= 16) {
1109 fmareg = (i % 2 == 0) ? reg07 : reg19;
1110 fma(useFma, ymm1, ymm2, fmareg);
1114 if (unroll_n >= 3) {
1118 ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
1121 ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]);
1123 vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
1125 fmareg = (i % 2 == 0) ? reg02 : reg14;
1126 fma(useFma, ymm0, ymm2, fmareg);
1127 if (unroll_m >= 16) {
1128 fmareg = (i % 2 == 0) ? reg08 : reg20;
1129 fma(useFma, ymm1, ymm2, fmareg);
1133 if (unroll_n >= 4) {
1135 vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]);
1137 vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
1139 fmareg = (i % 2 == 0) ? reg03 : reg15;
1140 fma(useFma, ymm0, ymm2, fmareg);
1141 if (unroll_m >= 16) {
1142 fmareg = (i % 2 == 0) ? reg09 : reg21;
1143 fma(useFma, ymm1, ymm2, fmareg);
1147 if (unroll_n >= 5) {
1150 ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]);
1152 vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
1154 fmareg = (i % 2 == 0) ? reg04 : reg16;
1155 fma(useFma, ymm0, ymm2, fmareg);
1156 if (unroll_m >= 16) {
1157 fmareg = (i % 2 == 0) ? reg10 : reg22;
1158 fma(useFma, ymm1, ymm2, fmareg);
1162 if (unroll_n >= 6) {
1165 ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]);
1167 vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
1169 fmareg = (i % 2 == 0) ? reg05 : reg17;
1170 fma(useFma, ymm0, ymm2, fmareg);
1171 if (unroll_m >= 16) {
1172 fmareg = (i % 2 == 0) ? reg11 : reg23;
1173 fma(useFma, ymm1, ymm2, fmareg);
1178 vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
1180 if (unroll_m >= 16) {
1182 + (unroll_m * 0 + 1 * 8 - OFFSET)
1186 sub(LDA4, -unroll_m * SIZE);
1190 if (isLoad1Unmasked) {
1191 vmovups(ymm0, ptr[AO1
1192 + (unroll_m * 1 + 0 * 8 - OFFSET)
1195 vmaskmovps(ymm0, VMASK,
1197 + (unroll_m * 1 + 0 * 8 - OFFSET)
1200 if (unroll_m >= 16) {
1201 if (isLoad2Unmasked) {
1204 + (unroll_m * 1 + 1 * 8 - OFFSET)
1207 vmaskmovps(ymm1, VMASK,
1209 + (unroll_m * 1 + 1 * 8 - OFFSET)
1213 sub(AO1, -unroll_m * SIZE);
1218 if (unroll_n >= 4) {
1228 // Inner kernel with k=1
1229 auto innerkernel1 = [&](int unroll_m, int unroll_n,
1230 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
1231 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
1232 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
1233 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11) {
1236 if (isLoad1Unmasked) {
1237 vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1239 vmaskmovps(ymm0, VMASK, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1241 if (unroll_m >= 16) {
1242 if (isLoad2Unmasked) {
1243 vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1245 vmaskmovps(ymm1, VMASK,
1246 ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1253 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1255 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1257 fma(useFma, ymm0, ymm2, reg00);
1258 if (unroll_m >= 16) {
1259 fma(useFma, ymm1, ymm2, reg06);
1262 if (unroll_n >= 2) {
1265 ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]);
1267 vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
1269 fma(useFma, ymm0, ymm2, reg01);
1270 if (unroll_m >= 16) {
1271 fma(useFma, ymm1, ymm2, reg07);
1275 if (unroll_n >= 3) {
1278 ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]);
1280 vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
1282 fma(useFma, ymm0, ymm2, reg02);
1283 if (unroll_m >= 16) {
1284 fma(useFma, ymm1, ymm2, reg08);
1288 if (unroll_n >= 4) {
1290 vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]);
1292 vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
1294 fma(useFma, ymm0, ymm2, reg03);
1295 if (unroll_m >= 16) {
1296 fma(useFma, ymm1, ymm2, reg09);
1300 if (unroll_n >= 5) {
1303 ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]);
1305 vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
1307 fma(useFma, ymm0, ymm2, reg04);
1308 if (unroll_m >= 16) {
1309 fma(useFma, ymm1, ymm2, reg10);
1313 if (unroll_n >= 6) {
1316 ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]);
1318 vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
1320 fma(useFma, ymm0, ymm2, reg05);
1321 if (unroll_m >= 16) {
1322 fma(useFma, ymm1, ymm2, reg11);
1327 vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
1329 if (unroll_m >= 16) {
1330 vmovups(ptr[LDA4 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE],
1333 sub(LDA4, -unroll_m * SIZE);
1337 if (isLoad1Unmasked) {
1339 ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]);
1341 vmaskmovps(ymm0, VMASK,
1342 ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]);
1344 if (unroll_m >= 16) {
1345 if (isLoad2Unmasked) {
1346 vmovups(ymm1, ptr[AO1
1347 + (unroll_m * 1 + 1 * 8 - OFFSET)
1350 vmaskmovps(ymm1, VMASK,
1352 + (unroll_m * 1 + 1 * 8 - OFFSET)
1356 sub(AO1, -unroll_m * SIZE);
1361 if (unroll_n >= 4) {
1370 // Main kernel; does prefetching and calls innerkernel{1,2,4,8} as
1372 // After calculating results in registers, writes back to C matrix
1373 auto kernel = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1374 bool isLoad2Unmasked, bool isDirect, bool isCopy, bool useFma,
1375 Ymm reg00 = Ymm(4), Ymm reg01 = Ymm(5), Ymm reg02 = Ymm(6),
1376 Ymm reg03 = Ymm(7), Ymm reg04 = Ymm(8), Ymm reg05 = Ymm(9),
1377 Ymm reg06 = Ymm(10), Ymm reg07 = Ymm(11), Ymm reg08 = Ymm(12),
1378 Ymm reg09 = Ymm(13), Ymm reg10 = Ymm(14), Ymm reg11 = Ymm(15),
1379 Ymm reg12 = Ymm(4), Ymm reg13 = Ymm(5), Ymm reg14 = Ymm(6),
1380 Ymm reg15 = Ymm(7), Ymm reg16 = Ymm(8), Ymm reg17 = Ymm(9),
1381 Ymm reg18 = Ymm(10), Ymm reg19 = Ymm(11), Ymm reg20 = Ymm(12),
1382 Ymm reg21 = Ymm(13), Ymm reg22 = Ymm(14), Ymm reg23 = Ymm(15)) {
1384 lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]);
1390 lea(LDA4, ptr[rsp + 256 + OFFSET * SIZE]);
1392 lea(LDA4, ptr[LDA * 8 + (8 - 1 - OFFSET) * SIZE]);
1396 lea(BO2, ptr[LDB * 4 + (8 - 1 - OFFSET) * SIZE]);
1397 lea(BO2, ptr[BO2 + LDB * 2]);
1401 if (isLoad1Unmasked) {
1403 ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]);
1405 vmaskmovps(ymm0, VMASK,
1406 ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]);
1408 if (unroll_m >= 16) {
1409 if (isLoad2Unmasked) {
1410 vmovups(ymm1, ptr[AO1
1411 + (unroll_m * 0 + 1 * 8 - OFFSET)
1414 vmaskmovps(ymm1, VMASK,
1416 + (unroll_m * 0 + 1 * 8 - OFFSET)
1422 for (int i = 4; i < 10; i++) {
1423 vxorps(Ymm(i), Ymm(i), Ymm(i));
1424 vxorps(Ymm(i + 6), Ymm(i + 6), Ymm(i + 6));
1430 Label kernel12, kernel13, kernel14, kernel15;
1431 Label kernel16, kernel17, kernel18;
1433 sub(LL, SECOND_FETCH);
1434 jle(kernel13, T_NEAR);
1438 innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1439 isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1440 reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1441 reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1442 reg21, reg22, reg23);
1443 jg(kernel12, T_NEAR);
1447 prefetcht0(ptr[CO1 + (unroll_m - 1) * SIZE]);
1449 prefetcht0(ptr[CO1 + LDC + (unroll_m - 1) * SIZE]);
1451 prefetcht0(ptr[CO1 + LDC * 2 + (unroll_m - 1) * SIZE]);
1453 prefetcht0(ptr[CO2 + (unroll_m - 1) * SIZE]);
1455 prefetcht0(ptr[CO2 + LDC + (unroll_m - 1) * SIZE]);
1457 prefetcht0(ptr[CO2 + LDC * 2 + (unroll_m - 1) * SIZE]);
1459 add(LL, SECOND_FETCH);
1460 jle(kernel15, T_NEAR);
1464 innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1465 isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1466 reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1467 reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1468 reg21, reg22, reg23);
1469 jg(kernel14, T_NEAR);
1474 jle(kernel16, T_NEAR);
1475 innerkernel4(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1476 isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1477 reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1478 reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1479 reg21, reg22, reg23);
1483 jle(kernel17, T_NEAR);
1484 innerkernel2(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1485 isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1486 reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1487 reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1488 reg21, reg22, reg23);
1492 if (unroll_m == 16) {
1493 if (unroll_n <= 3) {
1494 vaddps(reg00, reg00, reg12);
1495 vaddps(reg01, reg01, reg13);
1496 vaddps(reg02, reg02, reg14);
1497 vaddps(reg06, reg06, reg18);
1498 vaddps(reg07, reg07, reg19);
1499 vaddps(reg08, reg08, reg20);
1503 if (unroll_m <= 8) {
1504 vaddps(reg00, reg00, reg12);
1505 vaddps(reg01, reg01, reg13);
1506 vaddps(reg02, reg02, reg14);
1507 vaddps(reg03, reg03, reg15);
1508 vaddps(reg04, reg04, reg16);
1509 vaddps(reg05, reg05, reg17);
1513 jle(kernel18, T_NEAR);
1514 innerkernel1(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1515 isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1516 reg05, reg06, reg07, reg08, reg09, reg10, reg11);
1520 vbroadcastss(VALPHA, ALPHA);
1523 vbroadcastss(VBETA, BETA);
1526 // Write back the results; all beta and bias cases need to be
1529 case 1: mov(rax, LDC); break;
1530 case 2: lea(rax, ptr[LDC * 2]); break;
1531 case 3: lea(rax, ptr[LDC + LDC * 2]); break;
1532 case 4: lea(rax, ptr[LDC + LDC * 4]); break;
1534 lea(rax, ptr[LDC * 4]);
1538 lea(rax, ptr[LDC + LDC * 2]);
1545 if (isLoad1Unmasked) {
1546 vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]);
1548 vmaskmovps(VBIAS1, VMASK, ptr[BIAS1 + 0 * SIZE]);
1552 for (int i = 0; i < unroll_n; i++) {
1553 vmulps(Ymm(i + 4), Ymm(i + 4), VALPHA);
1555 if (isLoad1Unmasked) {
1557 case 0: vmovups(ymm0, ptr[CO1 + 0 * SIZE]); break;
1558 case 1: vmovups(ymm0, ptr[CO1 + LDC + 0 * SIZE]); break;
1560 vmovups(ymm0, ptr[CO1 + LDC * 2 + 0 * SIZE]);
1562 case 3: vmovups(ymm0, ptr[CO2 + 0 * SIZE]); break;
1563 case 4: vmovups(ymm0, ptr[CO2 + LDC + 0 * SIZE]); break;
1565 vmovups(ymm0, ptr[CO2 + LDC * 2 + 0 * SIZE]);
1571 vmaskmovps(ymm0, VMASK, ptr[CO1 + 0 * SIZE]);
1574 vmaskmovps(ymm0, VMASK, ptr[CO1 + LDC + 0 * SIZE]);
1578 ymm0, VMASK, ptr[CO1 + LDC * 2 + 0 * SIZE]);
1581 vmaskmovps(ymm0, VMASK, ptr[CO2 + 0 * SIZE]);
1584 vmaskmovps(ymm0, VMASK, ptr[CO2 + LDC + 0 * SIZE]);
1588 ymm0, VMASK, ptr[CO2 + LDC * 2 + 0 * SIZE]);
1594 vaddps(Ymm(i + 4), ymm0, Ymm(i + 4));
1596 fma(useFma, VBETA, ymm0, Ymm(i + 4), true);
1600 vaddps(Ymm(i + 4), VBIAS1, Ymm(i + 4));
1602 if (isLoad1Unmasked) {
1604 case 0: vmovups(ptr[CO1 + 0 * SIZE], Ymm(i + 4)); break;
1606 vmovups(ptr[CO1 + LDC + 0 * SIZE], Ymm(i + 4));
1609 vmovups(ptr[CO1 + LDC * 2 + 0 * SIZE], Ymm(i + 4));
1611 case 3: vmovups(ptr[CO2 + 0 * SIZE], Ymm(i + 4)); break;
1613 vmovups(ptr[CO2 + LDC + 0 * SIZE], Ymm(i + 4));
1616 vmovups(ptr[CO2 + LDC * 2 + 0 * SIZE], Ymm(i + 4));
1622 vmaskmovps(ptr[CO1 + 0 * SIZE], VMASK, Ymm(i + 4));
1626 ptr[CO1 + LDC + 0 * SIZE], VMASK, Ymm(i + 4));
1629 vmaskmovps(ptr[CO1 + LDC * 2 + 0 * SIZE], VMASK,
1633 vmaskmovps(ptr[CO2 + 0 * SIZE], VMASK, Ymm(i + 4));
1637 ptr[CO2 + LDC + 0 * SIZE], VMASK, Ymm(i + 4));
1640 vmaskmovps(ptr[CO2 + LDC * 2 + 0 * SIZE], VMASK,
1646 if (unroll_m >= 16) {
1647 // Re-use ymm4 (VBIAS2)
1650 if (isLoad1Unmasked) {
1651 vmovups(VBIAS2, ptr[BIAS1 + 8 * SIZE]);
1654 VBIAS2, VMASK, ptr[BIAS1 + 8 * SIZE]);
1658 vmulps(Ymm(i + 10), Ymm(i + 10), VALPHA);
1660 if (isLoad2Unmasked) {
1662 case 0: vmovups(ymm0, ptr[CO1 + 8 * SIZE]); break;
1664 vmovups(ymm0, ptr[CO1 + LDC + 8 * SIZE]);
1667 vmovups(ymm0, ptr[CO1 + LDC * 2 + 8 * SIZE]);
1669 case 3: vmovups(ymm0, ptr[CO2 + 8 * SIZE]); break;
1671 vmovups(ymm0, ptr[CO2 + LDC + 8 * SIZE]);
1674 vmovups(ymm0, ptr[CO2 + LDC * 2 + 8 * SIZE]);
1680 vmaskmovps(ymm0, VMASK, ptr[CO1 + 8 * SIZE]);
1684 ymm0, VMASK, ptr[CO1 + LDC + 8 * SIZE]);
1687 vmaskmovps(ymm0, VMASK,
1688 ptr[CO1 + LDC * 2 + 8 * SIZE]);
1691 vmaskmovps(ymm0, VMASK, ptr[CO2 + 8 * SIZE]);
1695 ymm0, VMASK, ptr[CO2 + LDC + 8 * SIZE]);
1698 vmaskmovps(ymm0, VMASK,
1699 ptr[CO2 + LDC * 2 + 8 * SIZE]);
1704 vaddps(Ymm(i + 10), ymm0, Ymm(i + 10));
1706 fma(useFma, VBETA, ymm0, Ymm(i + 10), true);
1710 vaddps(Ymm(i + 10), VBIAS2, Ymm(i + 10));
1712 if (isLoad2Unmasked) {
1715 vmovups(ptr[CO1 + 8 * SIZE], Ymm(i + 10));
1718 vmovups(ptr[CO1 + LDC + 8 * SIZE], Ymm(i + 10));
1721 vmovups(ptr[CO1 + LDC * 2 + 8 * SIZE], Ymm(i + 10));
1724 vmovups(ptr[CO2 + 8 * SIZE], Ymm(i + 10));
1727 vmovups(ptr[CO2 + LDC + 8 * SIZE], Ymm(i + 10));
1730 vmovups(ptr[CO2 + LDC * 2 + 8 * SIZE], Ymm(i + 10));
1736 vmaskmovps(ptr[CO1 + 8 * SIZE], VMASK, Ymm(i + 10));
1739 vmaskmovps(ptr[CO1 + LDC + 8 * SIZE], VMASK,
1743 vmaskmovps(ptr[CO1 + LDC * 2 + 8 * SIZE], VMASK,
1747 vmaskmovps(ptr[CO2 + 8 * SIZE], VMASK, Ymm(i + 10));
1750 vmaskmovps(ptr[CO2 + LDC + 8 * SIZE], VMASK,
1754 vmaskmovps(ptr[CO2 + LDC * 2 + 8 * SIZE], VMASK,
1763 if (unroll_n >= 4) {
1767 // Compute next address of B
1769 lea(rax, ptr[K * SIZE]);
1776 lea(BO1, ptr[BO1 + LDB * 2]);
1777 lea(BO2, ptr[BO2 + LDB * 2]);
1780 lea(BO1, ptr[BO1 + LDB3]);
1781 lea(BO2, ptr[BO2 + LDB3]);
1784 lea(BO1, ptr[BO1 + LDB * 4]);
1785 lea(BO2, ptr[BO2 + LDB * 4]);
1788 lea(BO1, ptr[BO1 + LDB * 4]);
1790 lea(BO2, ptr[BO2 + LDB * 4]);
1794 lea(BO1, ptr[BO1 + LDB3 * 2]);
1795 lea(BO2, ptr[BO2 + LDB3 * 2]);
1804 add(BO1, unroll_n * SIZE);
1808 auto kernel_16x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1809 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1810 kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1811 isDirect, isCopy, true);
1814 auto kernel_16x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1815 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1816 kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1817 isDirect, isCopy, true);
1820 auto kernel_16x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1821 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1822 kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1823 isDirect, isCopy, true);
1826 auto kernel_16x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1827 bool isLoad2Unmasked, bool isDirect, bool isCopy,
1828 bool useFma = true) {
1829 kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1830 isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
1831 Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1832 Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9),
1833 Ymm(13), Ymm(14), Ymm(15));
1836 auto kernel_16x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1837 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1838 kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1839 isDirect, isCopy, false);
1842 auto kernel_16x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1843 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1844 kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1845 isDirect, isCopy, false);
1848 auto kernel_8x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1849 bool isLoad2Unmasked, bool isDirect, bool isCopy,
1850 bool useFma = true) {
1851 kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1852 isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
1853 Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1854 Ymm(15), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1858 auto kernel_8x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1859 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1860 kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1864 auto kernel_8x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1865 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1866 kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1870 auto kernel_8x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1871 bool isLoad2Unmasked, bool isDirect, bool isCopy,
1872 bool useFma = true) {
1873 kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1874 isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
1875 Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1876 Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9),
1877 Ymm(13), Ymm(14), Ymm(15));
1880 auto kernel_8x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1881 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1882 kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1883 isDirect, isCopy, false);
1886 auto kernel_8x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1887 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1888 kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1889 isDirect, isCopy, false);
1892 // High-level subroutine; does packing if needed, then splits C matrix.
1893 // Operates on chunks of 16 rows, 6 columns at a time (handling tail
1894 // cases appropriately).
1895 // Masking is used for tail cases where M is not divisible by 8.
1897 int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) {
1899 do_pack(unroll_m, isLoad1Unmasked, isLoad2Unmasked);
1902 Label subloop11, subloop11mask;
1903 Label subloop20, subloop21, subloop22, subloop23;
1904 Label subloop24, subloop25;
1905 Label subloop30, subloop31, subloop32, subloop33;
1906 Label subloop34, subloop35;
1907 Label subloop98, subloop98mask;
1908 Label subloop99, subloop99mask;
1911 lea(CO2, ptr[CO1 + LDC * 2]);
1913 add(C, unroll_m * SIZE);
1916 lea(BO2, qword[B + LDB3]);
1920 lea(AA, ptr[A + (unroll_m * 2 - 1 - OFFSET) * SIZE]);
1922 jg(subloop98, T_NEAR);
1925 lea(AA, ptr[AA + (unroll_m - 1 - OFFSET) * SIZE]);
1932 // If N is too small, skip copy operation
1933 cmp(LL, UNROLL_N * 3);
1934 jle(subloop30, T_NEAR);
1936 // If A is not aligned to cache line
1938 je(subloop30, T_NEAR);
1941 jl(subloop20, T_NEAR);
1946 if (unroll_m == 16) {
1947 kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1948 isLoad2Unmasked, true, true);
1950 kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1951 isLoad2Unmasked, true, true);
1954 if (unroll_m == 16) {
1955 kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1956 isLoad2Unmasked, false, false);
1958 kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1959 isLoad2Unmasked, false, false);
1965 jl(subloop20, T_NEAR);
1969 if (unroll_m == 16) {
1970 kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1971 isLoad2Unmasked, false, false);
1973 kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, isLoad2Unmasked,
1978 jge(subloop11, T_NEAR);
1983 jne(subloop21, T_NEAR);
1984 if (unroll_m == 16) {
1985 kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
1988 kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, false,
1991 jmp(subloop99, T_NEAR);
1996 jne(subloop22, T_NEAR);
1997 if (unroll_m == 16) {
1998 kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
2001 kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, false,
2004 jmp(subloop99, T_NEAR);
2009 jne(subloop23, T_NEAR);
2010 if (unroll_m == 16) {
2011 kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
2014 kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, false,
2017 jmp(subloop99, T_NEAR);
2022 jne(subloop24, T_NEAR);
2023 if (unroll_m == 16) {
2024 kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
2027 kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, false,
2030 jmp(subloop99, T_NEAR);
2035 jne(subloop99, T_NEAR);
2036 if (unroll_m == 16) {
2037 kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
2040 kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, false,
2043 jmp(subloop99, T_NEAR);
2049 jl(subloop25, T_NEAR);
2053 if (unroll_m == 16) {
2054 kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
2055 isLoad2Unmasked, true, false);
2057 kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
2058 isLoad2Unmasked, true, false);
2062 jge(subloop31, T_NEAR);
2067 jne(subloop32, T_NEAR);
2068 if (unroll_m == 16) {
2069 kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
2072 kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
2075 jmp(subloop99, T_NEAR);
2080 jne(subloop33, T_NEAR);
2081 if (unroll_m == 16) {
2082 kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
2085 kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
2088 jmp(subloop99, T_NEAR);
2093 jne(subloop34, T_NEAR);
2094 if (unroll_m == 16) {
2095 kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
2098 kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
2101 jmp(subloop99, T_NEAR);
2106 jne(subloop35, T_NEAR);
2107 if (unroll_m == 16) {
2108 kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
2111 kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
2114 jmp(subloop99, T_NEAR);
2119 jne(subloop99, T_NEAR);
2120 if (unroll_m == 16) {
2121 kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
2124 kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
2131 // Compute address for A
2133 add(A, unroll_m * SIZE);
2136 imul(rax, rax, unroll_m);
2140 // Compute next address of BIAS
2142 add(BIAS, unroll_m * SIZE);
2148 Label buffer_in_ws, buffer_allocated;
2150 // Get the registers
2160 vmovss(xmm0, ptr[ARG_ALPHA]);
2161 vmovss(xmm1, ptr[r15]);
2168 cmp(K, STACK_K_CAPACITY);
2169 jg(buffer_in_ws, T_NEAR);
2171 // Create buffer and align to 4kB page
2172 lea(rax, ptr[K * SIZE]);
2176 and_(rsp, -PAGE_4K);
2177 jmp(buffer_allocated, T_NEAR);
2182 L(buffer_allocated);
2190 vmovss(ALPHA, xmm0);
2192 sub(A, -OFFSET * SIZE);
2193 sub(B, -OFFSET * SIZE);
2195 sal(LDA, BASE_SHIFT);
2196 sal(LDB, BASE_SHIFT);
2197 sal(LDC, BASE_SHIFT);
2198 lea(LDB3, ptr[LDB + LDB * 2]);
2200 for (int i = 0; i < 8; i++) {
2201 mov(dword[rsp + 88 + i * 4], i);
2204 if (isTransA && is_avx2) {
2206 vpbroadcastq(ymm1, xmm0);
2207 vinsertf128(ymm0, ymm0, xmm0, 1);
2208 vpermilpd(ymm0, ymm0, 5);
2209 vpaddq(ymm1, ymm1, ymm1);
2210 vperm2f128(ymm1, ymm1, ymm1, 8);
2211 vpaddq(ymm0, ymm0, ymm1);
2212 vmovups(STRIDE, ymm0);
2215 // Check A alignment and leading dimension; take copy-based path as
2222 Label main0, main1, main2, main3, main999;
2229 subloop(UNROLL_M, true, true);
2237 jle(main999, T_NEAR);
2244 vbroadcastss(VMASK, M);
2245 vpcmpgtd(VMASK, VMASK, MASK);
2247 subloop(16, true, false);
2248 jmp(main999, T_NEAR);
2254 subloop(8, true, true);
2255 jmp(main999, T_NEAR);
2261 vbroadcastss(VMASK, M);
2263 vpcmpgtd(VMASK, VMASK, MASK);
2265 auto xmask = Xmm(VMASK.getIdx());
2266 auto xmm_tmp = xmm4;
2268 vextractf128(xmm_tmp, VMASK, 1);
2269 vpcmpgtd(xmask, xmask, MASK);
2270 vpcmpgtd(xmm_tmp, xmm_tmp, dword[rsp + 88 + 4 * 4]); // MASK + 4
2271 vinsertf128(VMASK, VMASK, xmm_tmp, 1);
2273 subloop(8, false, false);
2277 // Restore original stack
2284 ker_ = reinterpret_cast<decltype(ker_)>(
2285 const_cast<uint8_t *>(this->getCode()));
// Invoke the JIT-generated GEMM kernel (ker_) on the given operands.
// m/n/k: problem dimensions; alpha/beta: scale factors passed by pointer;
// a/b/c: column-major matrices with leading dimensions lda/ldb/ldc;
// bias: optional per-row bias vector; ws: scratch workspace
// (NOTE(review): used by the kernel when K exceeds STACK_K_CAPACITY — confirm).
2288 void operator()(long long int m, long long int n, long long int k,
2289 const float *alpha, const float *a, long long int lda,
2290 const float *b, long long int ldb, const float *beta, float *c,
2291 long long int ldc, const float *bias, float *ws)
2293 (*ker_)(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws);
// Function pointer to the generated machine code; set from getCode() at the
// end of the xbyak_gemm constructor and called via operator() above.
2297 void (*ker_)(long long int m, long long int n, long long int k,
2298 const float *alpha, const float *a, long long int lda,
2299 const float *b, long long int ldb, const float *beta, float *c,
2300 long long int ldc, const float *bias, float *ws);
// NOTE(review): this file-scope typedef uses non-const float* parameters and
// so does not match the const-correct signature of ker_ above — presumably a
// legacy alias; verify whether it is still referenced anywhere.
2303 typedef void (*ker)(long long int, long long int, long long int, float *,
2304 float *, long long int, float *, long long int, float *, float *,
2305 long long int, float *);
// Single-threaded GEMM driver: blocks the m/n/k iteration space into tiles
// (Bm/Bn/Bk with sizes sizeM/sizeN/sizeK) and dispatches each tile to one of
// the three JIT kernels (beta == 0, beta == 1, or general beta).
// Early-outs: empty problem, or k <= 0 / alpha == 0 which reduces to a pure
// scale (or zero) of C.
2306 void jit_avx_gemm_f32::sgemm_nocopy_driver(const char *transa,
2307 const char *transb, int m, int n, int k, const float *alpha,
2308 const float *a, int lda, const float *b, int ldb, const float *beta,
2309 float *c, int ldc, const float *bias, float *ws)
2311 bool isTransA = (*transa == 'T' || *transa == 't');
2312 bool isTransB = (*transb == 'T' || *transb == 't');
2314 int Bm, sizeM, Bn, sizeN, Bk, sizeK;
// Degenerate problem: nothing to do.
2318 if ((m <= 0) || (n <= 0))
// k <= 0 or alpha == 0: the A*B contribution is zero, so C only needs
// scaling by beta (zero-fill when beta == 0, no-op when beta == 1).
2321 if ((k <= 0) || (alpha[0] == 0.)) {
2323 if (beta[0] == 0.) {
2324 for (j = 0; j < n; j++)
2325 for (i = 0; i < m; i++)
2326 c[i + j * ldc] = 0.0;
2327 } else if (beta[0] != 1.) {
2328 for (j = 0; j < n; j++)
2329 for (i = 0; i < m; i++)
2330 c[i + j * ldc] *= beta[0];
// Cache-blocking sizes tuned per transpose case (BM presumably set on an
// elided line nearby — TODO confirm).
2337 int BN = isTransA ? 96 : 48;
2338 int BK = isTransB ? 96 : 256;
2339 const float *curA, *curB, *curBias = nullptr;
// Outer K blocking; oversized trailing blocks are split roughly in half to
// balance work.
2342 for (Bk = 0; Bk < k; Bk += sizeK) {
2344 if (sizeK >= BK * 2)
2348 sizeK = (sizeK + 1) / 2;
2351 for (Bm = 0; Bm < m; Bm += sizeM) {
2353 if (sizeM >= BM * 2)
2356 if (sizeM > BM + BM / 2)
2357 sizeM = (sizeM + 1) / 2;
2360 for (Bn = 0; Bn < n; Bn += sizeN) {
2362 if (sizeN >= BN * 2)
2365 if (sizeN > BN + BN / 2)
2366 sizeN = (sizeN + 1) / 2;
// Compute tile base pointers; indexing differs by transpose flag
// (column-major layout: element (i,j) at [i + j*ld]).
2370 curA = a + Bm + (size_t)Bk * lda;
2372 curA = a + Bk + (size_t)Bm * lda;
2375 curB = b + Bk + (size_t)Bn * ldb;
2377 curB = b + Bn + (size_t)Bk * ldb;
2379 curC = c + Bm + (size_t)Bn * ldc;
2380 if (bias != nullptr) {
2382 curBias = bias + Bm;
// Kernel dispatch: ker_b0_ (beta == 0, no bias), ker_bn_ (general beta),
// ker_b1_ (beta == 1; used for non-first K blocks, which accumulate).
2388 if (*beta == 0.0 && bias == nullptr)
2389 (*ker_b0_)((long long int)sizeM, (long long int)sizeN,
2390 (long long int)sizeK, alpha, curA,
2391 (long long int)lda, curB, (long long int)ldb,
2392 beta, curC, (long long int)ldc, curBias, ws);
2394 (*ker_bn_)((long long int)sizeM, (long long int)sizeN,
2395 (long long int)sizeK, alpha, curA,
2396 (long long int)lda, curB, (long long int)ldb,
2397 beta, curC, (long long int)ldc, curBias, ws);
2399 (*ker_b1_)((long long int)sizeM, (long long int)sizeN,
2400 (long long int)sizeK, alpha, curA,
2401 (long long int)lda, curB, (long long int)ldb, beta,
2402 curC, (long long int)ldc, curBias, ws);
// Multi-threaded SGEMM entry point (Fortran-style pointer arguments).
// Partitions the problem across threads in a 3D grid (nthr_m x nthr_n x
// nthr_k), runs sgemm_nocopy_driver per thread, then reduces the K-partitioned
// partial results into C. ompstatus flags (one cache line per thread) are used
// as completion markers for the K-reduction.
2409 void jit_avx_gemm_f32::sgemm(const char *transa, const char *transb,
2410 const int *p_m, const int *p_n, const int *p_k, const float *p_alpha,
2411 const float *A, const int *p_lda, const float *B, const int *p_ldb,
2412 const float *p_beta, float *C, const int *p_ldc, const float *bias)
// Sanity: the runtime arguments must match what the kernels were JIT-ed for.
2414 if (beta_ == 0. || beta_ == 1.)
2415 assert(*p_beta == beta_);
2416 assert((one_of(*transa, 'T', 't') == one_of(transa_, 'T', 't')));
// Run single-threaded when already inside a parallel region (no nesting).
2418 int nthr = mkldnn_in_parallel() ? 1 : mkldnn_get_max_threads();
2425 float beta = *p_beta;
2428 int nthr_m, nthr_n, nthr_k, nthr_mn;
2430 assert(nthr <= nthrs_);
2432 // Determine threading partitioning
2433 gemm_utils::calc_nthr_nocopy_avx(
2434 m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
// K-splitting requires barrier-capable threading (threads must wait on each
// other's ompstatus flags below).
2435 assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
2437 // May not happen, but just in case
2438 if (nthr < nthr_m * nthr_n * nthr_k)
2439 nthr = nthr_m * nthr_n * nthr_k;
2441 nthr_mn = nthr_m * nthr_n;
2443 unsigned char * ompstatus_ = nullptr;
2444 unsigned char volatile *ompstatus = nullptr;
2446 float *c_buffers = nullptr;
2447 float *ws_buffers = nullptr;
// One status byte per thread, strided by CACHE_LINE_SIZE to avoid false
// sharing between the polling threads.
2450 ompstatus_ = (unsigned char *) malloc(
2451 nthr * CACHE_LINE_SIZE,
2453 ompstatus = (unsigned char volatile *) ompstatus_;
2456 for (int i = 0; i < nthr; i++)
2457 ompstatus[i * CACHE_LINE_SIZE] = 0;
// Scratch C tiles for the non-first K partitions; the first K partition
// writes directly into C, the remaining (nthr_k - 1) accumulate here.
2459 c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB
2460 * sizeof(float), PAGE_4K);
// Per-thread packing workspace, only needed when K is too large for the
// kernel's on-stack buffer (STACK_K_CAPACITY).
2463 const size_t ws_elems_per_thr = k * 16 + 64;
2464 const size_t ws_size_per_thr
2465 = utils::rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
2466 if (k > STACK_K_CAPACITY) {
2467 ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K);
2470 parallel(nthr, [&](const int ithr, const int nthr) {
2471 int ithr_m, ithr_n, ithr_k, ithr_mn;
2472 int m_from, m_to, myM;
2473 int n_from, n_to, myN;
2474 int k_from, k_to, myK;
2476 const float *myA, *myB, *myBias = nullptr;
2477 float *myC = C, myBeta;
2478 float *ws = ws_buffers ?
2479 ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0;
// Only threads inside the computed grid do work; extras fall through.
2482 if (ithr < nthr_m * nthr_n * nthr_k) {
// Decompose the linear thread id into (m, n, k) grid coordinates.
2484 ithr_mn = ithr % nthr_mn;
2485 ithr_m = ithr_mn % nthr_m;
2486 ithr_n = ithr_mn / nthr_m;
2487 ithr_k = ithr / nthr_mn;
2489 /* swap ithr_k for performance improvement */
2491 ithr_k = nthr_k - 1;
2492 else if (ithr_k == nthr_k - 1)
// This thread's [m_from, m_to) x [n_from, n_to) x [k_from, k_to) slice.
2495 m_from = MB * (ithr_m);
2496 m_to = MB * (ithr_m + 1);
2499 myM = m_to - m_from;
2501 n_from = NB * (ithr_n);
2502 n_to = NB * (ithr_n + 1);
2505 myN = n_to - n_from;
2507 k_from = KB * (ithr_k);
2508 k_to = KB * (ithr_k + 1);
2511 myK = k_to - k_from;
// cbase indexes this (m,n) cell's scratch-C tiles; ibase its status flags.
2513 cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
2514 ibase = (ithr_m + nthr_m * ithr_n) * nthr_k;
2516 if ((myM > 0) && (myN > 0)) {
// Sub-matrix base pointers (column-major; layout depends on transpose).
2518 if (*transa == 'N' || *transa == 'n') {
2519 myA = &(A[m_from + k_from * lda]);
2521 myA = &(A[k_from + m_from * lda]);
2523 if (*transb == 'N' || *transb == 'n') {
2524 myB = &(B[k_from + n_from * ldb]);
2526 myB = &(B[n_from + k_from * ldb]);
2529 myC = &(C[m_from + n_from * ldc]);
2533 myBias = &(bias[m_from]);
// Non-first K partitions write into their private scratch tile instead
// of C; they are summed into C after the flags below are set.
2535 myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
2541 sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA,
2542 lda, myB, ldb, &myBeta, myC, ld, myBias, ws);
// Publish completion for the K-reduction consumers.
2545 ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1;
2550 // sum matrices partitioned along K dimension
// Each K-thread reduces a disjoint column slice [n1, n1+n2) so the
// reduction itself is parallel across the K group.
2553 gemm_utils::partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
2557 myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
2558 myC = myC + n1 * MB;
2559 /* need to wait until main thread finishes */
// Spin-wait on the volatile status byte of the direct-to-C thread.
2560 while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) {
2563 /* my cache is hot */
2564 gemm_utils::sum_two_matrices(myM, n2, myC, MB,
2565 &C[m_from + (n_from + n1) * ldc], ldc);
// Then fold in every other K partition's scratch tile as it completes.
2568 for (int ik = 1; ik < nthr_k; ++ik) {
2571 myC = c_buffers + MB * NB * (cbase + ik - 1);
2572 myC = myC + n1 * MB;
2574 while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) {
2577 gemm_utils::sum_two_matrices(myM, n2, myC, MB,
2578 &C[m_from + (n_from + n1) * ldc], ldc);
// Construct the three JIT kernels used by sgemm_nocopy_driver:
//   ker_bn_ — kernel for the requested beta (and bias, if any),
//   ker_b1_ — beta == 1 kernel for accumulating non-first K blocks,
//   ker_b0_ — beta == 0 kernel, built only when ker_bn_ cannot serve that
//             role (i.e. beta != 0, or beta == 0 with bias).
2590 jit_avx_gemm_f32::jit_avx_gemm_f32(
2591 char transa, char transb, float beta, bool hasBias)
// NOTE(review): the assert and the surrounding (elided) condition imply bias
// is only supported for beta == 0 — confirm against the full file.
2598 assert(beta == 0.0);
2600 ker_bn_ = new xbyak_gemm(transa, transb, beta, hasBias);
2602 ker_b1_ = new xbyak_gemm(transa, transb, 1.0);
2606 if (beta != 0.0 || (beta == 0.0 && hasBias)) {
2607 ker_b0_ = new xbyak_gemm(transa, transb, 0.0);
// Remember the thread count for the assert in sgemm().
2611 nthrs_ = mkldnn_get_max_threads();
// Destroy the JIT kernels; ker_b0_ is only deleted under the same condition
// that guarded its creation in the constructor.
2614 jit_avx_gemm_f32::~jit_avx_gemm_f32()
2619 if (beta_ != 0.0 || (beta_ == 0.0 && hasBias_))
2627 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s