1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
19 #include "mkldnn_thread.hpp"
21 #include "gemm_utils.hpp"
22 #include "jit_avx_gemm_f32.hpp"
24 #define CACHE_LINE_SIZE 16
30 using namespace mkldnn::impl::memory_format;
31 using namespace mkldnn::impl::utils;
33 using namespace Xbyak;
34 #define STACKSIZE get_size_of_abi_save_regs()
36 #define STACK_K_CAPACITY 128
38 #define STACK_K_CAPACITY 8192
43 #define SECOND_FETCH 14
45 struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
46 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_gemm_f32_xbyak_gemm)
48 xbyak_gemm(char transa, char transb, float beta, bool hasBias = false,
49 void *code_ptr = nullptr,
50 size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE)
51 : jit_generator(code_ptr, code_size)
53 const bool is_avx2 = mayiuse(avx2);
54 assert(implication(!is_avx2, mayiuse(avx)));
56 const int UNROLL_M = is_avx2 ? 16 : 8;
57 const int UNROLL_N = 6;
59 bool isTransA = (transa == 'T' || transa == 't');
60 bool isTransB = (transb == 'T' || transb == 't');
61 bool isBeta0 = (beta == 0.0);
62 bool isBetaN = (!isBeta0 && beta != 1.0);
64 // various definitions for convenience
65 auto ARG_M = abi_param1;
66 auto ARG_N = abi_param2;
68 auto ARG_ALPHA = abi_param4;
70 auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE];
71 auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE +
72 sizeof(float *) + STACKSIZE];
73 const auto stackOffset = OFFSET_SHADOWSPACE +
74 sizeof(float *) + STACKSIZE;
80 const auto stackOffset = STACKSIZE;
84 auto ARG_B = ptr[rsp + 8 + stackOffset];
85 auto ARG_LDB = ptr[rsp + 16 + stackOffset];
86 auto ARG_BETA = ptr[rsp + 24 + stackOffset];
87 auto ARG_C = ptr[rsp + 32 + stackOffset];
88 auto ARG_LDC = ptr[rsp + 40 + stackOffset];
89 auto ARG_BIAS = ptr[rsp + 48 + stackOffset];
90 auto ARG_WS = ptr[rsp + 56 + stackOffset];
96 auto AO1 = abi_param2;
97 auto BO1 = abi_param4;
102 auto LDA4 = abi_param1;
104 auto BIAS1 = abi_param1;
106 auto M = qword[rsp + 0];
107 auto N = qword[rsp + 8];
108 auto FLAG = qword[rsp + 16];
109 auto I = qword[rsp + 24];
110 auto C = qword[rsp + 32];
111 auto BIAS = qword[rsp + 40];
112 auto ALPHA = qword[rsp + 48];
113 auto BETA = qword[rsp + 64];
114 auto ORIG_A = qword[rsp + 80];
115 auto MASK = dword[rsp + 88];
116 auto STRIDE = qword[rsp + 120];
117 auto ORIG_SP = qword[rsp + 152];
125 auto PREFETCHSIZEA = 128;
126 auto PREFETCHSIZEB = (!isTransB) ? -16 : 0;
128 // Function for packing if needed
130 int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) {
137 lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]);
140 lea(BO2, ptr[BO1 + LDA * 4]);
141 lea(CO1, ptr[LDA + LDA * 2]);
142 vmovupd(ymm7, STRIDE);
147 jle(".pack3", T_NEAR);
152 for (int i = 0; i < 4; i++) {
153 regIdx = (i % 2 == 0) ? 4 : 6;
154 if (isLoad1Unmasked) {
156 ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
158 vmaskmovps(Ymm(regIdx), VMASK,
159 ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
162 if (isLoad2Unmasked) {
163 vmovups(Ymm(regIdx + 1),
164 ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
166 vmaskmovps(Ymm(regIdx + 1), VMASK,
167 ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
172 vmovups(ptr[AO1 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
176 + (unroll_m * i + 1 * 8 - OFFSET)
183 if (isLoad1Unmasked) {
184 for (int i = 0; i < 2; i++) {
185 reg = (i % 2 == 0) ? BO1 : BO2;
186 vmovups(xmm0, ptr[reg + (0 * 8 - OFFSET) * SIZE]);
188 ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
189 lea(BO2, ptr[reg + LDA * 2]);
190 vunpcklps(xmm4, xmm0, xmm1);
191 vunpckhps(xmm5, xmm0, xmm1);
192 vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
194 ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
195 lea(BO2, ptr[BO2 + LDA * 2]);
196 vunpcklps(xmm6, xmm0, xmm1);
197 vunpckhps(xmm2, xmm0, xmm1);
199 vunpcklpd(xmm0, xmm4, xmm6);
200 vunpckhpd(xmm1, xmm4, xmm6);
202 + (unroll_m * 0 + i * 4 - OFFSET)
206 + (unroll_m * 1 + i * 4 - OFFSET)
209 vunpcklpd(xmm0, xmm5, xmm2);
210 vunpckhpd(xmm1, xmm5, xmm2);
212 + (unroll_m * 2 + i * 4 - OFFSET)
216 + (unroll_m * 3 + i * 4 - OFFSET)
220 } else if (is_avx2) {
221 for (int i = 0; i < 2; i++) {
224 ptr[BO1 + ymm7 + ((2 * i) - OFFSET) * SIZE],
228 ptr[BO1 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE],
232 + (unroll_m * (2 * i) + 0 * 4 - OFFSET)
236 + (unroll_m * (2 * i + 1) + 0 * 4
242 lea(BO2, ptr[BO1 + LDA * 4]);
244 for (int i = 0; i < 2; i++) {
245 vextractf128(xmm4, ymm3, 1);
247 ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
249 vextractf128(xmm4, ymm3, 1);
251 ptr[BO2 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE],
255 + (unroll_m * (2 * i) + 1 * 4 - OFFSET)
259 + (unroll_m * (2 * i + 1) + 1 * 4
265 lea(BO2, ptr[BO2 + LDA * 4]);
267 vxorps(xmm4, xmm4, xmm4);
268 lea(BO2, ptr[BO1 + LDA * 4]);
270 auto el_cp = [&](int section, int ld_step) {
271 RegExp src_addr = section == 0 ? BO1 : BO2;
272 if (ld_step == 1 || ld_step == 2)
273 src_addr = src_addr + LDA * ld_step;
274 else if (ld_step == 3)
275 src_addr = src_addr + CO1;
276 src_addr = src_addr - OFFSET * SIZE;
278 vmovups(Xmm(ld_step % 2), ptr[src_addr]);
279 RegExp dst_addr = AO1
280 + (ld_step + section * 4 - OFFSET) * SIZE;
281 for (int off = 0; off < 4; ++off)
282 pextrd(ptr[dst_addr + unroll_m * off * SIZE],
283 Xmm(ld_step % 2), off);
287 el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR);
288 el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR);
289 el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR);
290 el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR);
291 el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR);
292 el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR);
296 lea(BO2, ptr[BO2 + LDA * 4]);
299 if (unroll_m >= 16) {
301 if (isLoad2Unmasked) {
302 for (int i = 0; i < 2; i++) {
303 vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
304 vmovups(xmm1, ptr[BO2 + LDA * 1
305 + (0 * 8 - OFFSET) * SIZE]);
306 lea(BO2, ptr[BO2 + LDA * 2]);
307 vunpcklps(xmm4, xmm0, xmm1);
308 vunpckhps(xmm5, xmm0, xmm1);
309 vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
310 vmovups(xmm1, ptr[BO2 + LDA * 1
311 + (0 * 8 - OFFSET) * SIZE]);
313 lea(BO2, ptr[BO2 + LDA * 2]);
314 vunpcklps(xmm6, xmm0, xmm1);
315 vunpckhps(xmm2, xmm0, xmm1);
317 vunpcklpd(xmm0, xmm4, xmm6);
318 vunpckhpd(xmm1, xmm4, xmm6);
320 + (unroll_m * 0 + (i + 2) * 4
325 + (unroll_m * 1 + (i + 2) * 4
329 vunpcklpd(xmm0, xmm5, xmm2);
330 vunpckhpd(xmm1, xmm5, xmm2);
332 + (unroll_m * 2 + (i + 2) * 4
337 + (unroll_m * 3 + (i + 2) * 4
343 for (int i = 0; i < 2; i++) {
346 ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
351 + ((2 * i + 1) - OFFSET) * SIZE],
355 + (unroll_m * (2 * i) + 2 * 4
360 + (unroll_m * (2 * i + 1) + 2 * 4
366 lea(BO2, ptr[BO2 + LDA * 4]);
368 for (int i = 0; i < 2; i++) {
369 vextractf128(xmm4, ymm3, 1);
371 ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
373 vextractf128(xmm4, ymm3, 1);
376 + ((2 * i + 1) - OFFSET) * SIZE],
380 + (unroll_m * (2 * i) + 3 * 4
385 + (unroll_m * (2 * i + 1) + 3 * 4
391 lea(BO2, ptr[BO2 + LDA * 4]);
394 add(BO1, (4 * SIZE));
397 add(AO1, unroll_m * 4 * SIZE);
399 jg(".pack2", T_NEAR);
405 jle(".pack10", T_NEAR);
410 if (isLoad1Unmasked) {
411 vmovups(ymm4, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
413 vmaskmovps(ymm4, VMASK, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
416 if (isLoad2Unmasked) {
417 vmovups(ymm5, ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
419 vmaskmovps(ymm5, VMASK,
420 ptr[BO1 + (1 + 8 - OFFSET) * SIZE]);
424 vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
427 vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE],
431 if (isLoad1Unmasked) {
432 for (int i = 0; i < 2; i++) {
433 reg = (i % 2 == 0) ? BO1 : BO2;
434 vmovss(Xmm(i + 1), ptr[reg + (0 * 8 - OFFSET) * SIZE]);
436 ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
437 lea(BO2, ptr[reg + LDA * 2]);
438 vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
440 vunpcklpd(xmm1, xmm1, xmm2);
441 vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE],
444 for (int i = 0; i < 2; i++) {
445 vmovss(Xmm(i + 1), ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
447 ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
448 lea(BO2, ptr[BO2 + LDA * 2]);
449 vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
451 vunpcklpd(xmm1, xmm1, xmm2);
452 vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE],
454 } else if (is_avx2) {
456 vgatherqps(xmm1, ptr[BO1 + ymm7 + (0 * 8 - OFFSET) * SIZE],
458 lea(BO2, ptr[BO1 + LDA * 4]);
459 vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE],
462 vextractf128(xmm4, ymm3, 1);
463 vgatherqps(xmm1, ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
465 lea(BO2, ptr[BO2 + LDA * 4]);
466 vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE],
469 vxorps(xmm4, xmm4, xmm4);
470 lea(BO2, ptr[BO1 + LDA * 4]);
472 auto el_cp = [&](int section, int ld_step) {
473 RegExp src_addr = section == 0 ? BO1 : BO2;
474 if (ld_step == 1 || ld_step == 2)
475 src_addr = src_addr + LDA * ld_step;
476 else if (ld_step == 3)
477 src_addr = src_addr + CO1;
478 src_addr = src_addr - OFFSET * SIZE;
480 vmovss(xmm1, ptr[src_addr]);
481 RegExp dst_addr = AO1
482 + (ld_step + section * 4 - OFFSET) * SIZE;
483 movss(ptr[dst_addr], xmm1);
487 el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR);
488 el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR);
489 el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR);
490 el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR);
491 el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR);
492 el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR);
496 lea(BO2, ptr[BO2 + LDA * 4]);
499 if (unroll_m >= 16) {
501 if (isLoad2Unmasked) {
502 for (int i = 0; i < 2; i++) {
504 ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
505 vmovss(xmm0, ptr[BO2 + LDA * 1
506 + (0 * 8 - OFFSET) * SIZE]);
507 lea(BO2, ptr[BO2 + LDA * 2]);
508 vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
510 vunpcklpd(xmm1, xmm1, xmm2);
514 ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
516 lea(BO2, ptr[BO2 + LDA * 4]);
518 vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 4 - OFFSET) * SIZE],
521 if (isLoad2Unmasked) {
522 for (int i = 0; i < 2; i++) {
524 ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
525 vmovss(xmm0, ptr[BO2 + LDA * 1
526 + (0 * 8 - OFFSET) * SIZE]);
527 lea(BO2, ptr[BO2 + LDA * 2]);
528 vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
530 vunpcklpd(xmm1, xmm1, xmm2);
532 vextractf128(xmm4, ymm3, 1);
534 ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
537 vmovups(ptr[AO1 + (unroll_m * 0 + 3 * 4 - OFFSET) * SIZE],
543 add(AO1, unroll_m * SIZE);
545 jg(".pack4", T_NEAR);
553 // Fused multiply add; may become one or two instructions
554 auto fma = [&](bool useFma, Ymm reg0, Ymm reg1, Ymm reg2,
555 bool overWrite = false) {
558 vfmadd231ps(reg2, reg1, reg0);
560 assert(UNROLL_M == 8);
561 auto tent_vreg = overWrite ? reg1 : ymm1;
562 vmulps(tent_vreg, reg1, reg0);
563 vaddps(reg2, reg2, tent_vreg);
567 vmulps(ymm15, reg1, reg0);
568 vaddps(reg2, reg2, ymm15);
570 vmulps(reg1, reg1, reg0);
571 vaddps(reg2, reg2, reg1);
576 // Inner kernel with k=8
577 auto innerkernel8 = [&](int unroll_m, int unroll_n,
578 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
579 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
580 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
581 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
582 Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
583 Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
589 prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]);
591 prefetcht0(ptr[AO1 + LDA4]);
594 for (int i = 0; i < 8; i++) {
596 if (isLoad1Unmasked) {
597 vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
599 vmaskmovps(ymm0, VMASK,
600 ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
602 if (unroll_m >= 16) {
603 if (isLoad2Unmasked) {
604 vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
606 vmaskmovps(ymm1, VMASK,
607 ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
614 vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]);
616 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
618 fmareg = (i % 2 == 0) ? reg00 : reg12;
619 fma(useFma, ymm0, ymm2, fmareg);
620 if (unroll_m >= 16) {
621 fmareg = (i % 2 == 0) ? reg06 : reg18;
622 fma(useFma, ymm1, ymm2, fmareg);
626 prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]);
632 prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]);
635 ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
637 vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
639 fmareg = (i % 2 == 0) ? reg01 : reg13;
640 fma(useFma, ymm0, ymm2, fmareg);
641 if (unroll_m >= 16) {
642 fmareg = (i % 2 == 0) ? reg07 : reg19;
643 fma(useFma, ymm1, ymm2, fmareg);
648 vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
650 if (unroll_m >= 16) {
652 + (unroll_m * i + 1 * 8 - OFFSET)
657 sub(LDA4, -unroll_m * 8 * SIZE);
665 ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
668 ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]);
670 vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
672 fmareg = (i % 2 == 0) ? reg02 : reg14;
673 fma(useFma, ymm0, ymm2, fmareg);
674 if (unroll_m >= 16) {
675 fmareg = (i % 2 == 0) ? reg08 : reg20;
676 fma(useFma, ymm1, ymm2, fmareg);
689 prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]);
691 vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]);
693 vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
695 fmareg = (i % 2 == 0) ? reg03 : reg15;
696 fma(useFma, ymm0, ymm2, fmareg);
697 if (unroll_m >= 16) {
698 fmareg = (i % 2 == 0) ? reg09 : reg21;
699 fma(useFma, ymm1, ymm2, fmareg);
706 prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]);
709 ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]);
711 vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
713 fmareg = (i % 2 == 0) ? reg04 : reg16;
714 fma(useFma, ymm0, ymm2, fmareg);
715 if (unroll_m >= 16) {
716 fmareg = (i % 2 == 0) ? reg10 : reg22;
717 fma(useFma, ymm1, ymm2, fmareg);
725 ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]);
728 ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]);
730 vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
732 fmareg = (i % 2 == 0) ? reg05 : reg17;
733 fma(useFma, ymm0, ymm2, fmareg);
734 if (unroll_m >= 16) {
735 fmareg = (i % 2 == 0) ? reg11 : reg23;
736 fma(useFma, ymm1, ymm2, fmareg);
740 prefetcht0(ptr[BO1 + BO2]);
748 ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]);
750 prefetcht0(ptr[AO1 + LDA4]);
754 if (i == 1 || i == 2) {
758 + (PREFETCHSIZEA + (2 + 2 * i) * 8)
761 prefetcht0(ptr[AO1 + LDA4]);
765 if (i == 3 || i == 4 || i == 5 || i == 6) {
766 if (unroll_m >= 16) {
769 + (PREFETCHSIZEA + (2 + 2 * i) * 8)
772 prefetcht0(ptr[AO1 + LDA4]);
784 lea(AA, ptr[AA + LDA]);
789 if (isLoad1Unmasked) {
792 + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
798 + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
801 if (unroll_m >= 16) {
802 if (isLoad2Unmasked) {
803 vmovups(ymm1, ptr[AO1
804 + (unroll_m * (i + 1) + 1 * 8
808 vmaskmovps(ymm1, VMASK,
810 + (unroll_m * (i + 1) + 1 * 8
819 sub(AO1, -unroll_m * 8 * SIZE);
825 // Inner kernel with k=4
826 auto innerkernel4 = [&](int unroll_m, int unroll_n,
827 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
828 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
829 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
830 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
831 Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
832 Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
838 prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]);
840 prefetcht0(ptr[AO1 + LDA4]);
843 for (int i = 0; i < 4; i++) {
845 if (isLoad1Unmasked) {
846 vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
848 vmaskmovps(ymm0, VMASK,
849 ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
851 if (unroll_m >= 16) {
852 if (isLoad2Unmasked) {
853 vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
855 vmaskmovps(ymm1, VMASK,
856 ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
863 vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]);
865 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
867 fmareg = (i % 2 == 0) ? reg00 : reg12;
868 fma(useFma, ymm0, ymm2, fmareg);
869 if (unroll_m >= 16) {
870 fmareg = (i % 2 == 0) ? reg06 : reg18;
871 fma(useFma, ymm1, ymm2, fmareg);
875 prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]);
881 prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]);
884 ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
886 vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
888 fmareg = (i % 2 == 0) ? reg01 : reg13;
889 fma(useFma, ymm0, ymm2, fmareg);
890 if (unroll_m >= 16) {
891 fmareg = (i % 2 == 0) ? reg07 : reg19;
892 fma(useFma, ymm1, ymm2, fmareg);
897 vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
899 if (unroll_m >= 16) {
901 + (unroll_m * i + 1 * 8 - OFFSET)
906 sub(LDA4, -unroll_m * 4 * SIZE);
914 ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
917 ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]);
919 vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
921 fmareg = (i % 2 == 0) ? reg02 : reg14;
922 fma(useFma, ymm0, ymm2, fmareg);
923 if (unroll_m >= 16) {
924 fmareg = (i % 2 == 0) ? reg08 : reg20;
925 fma(useFma, ymm1, ymm2, fmareg);
938 prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]);
940 vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]);
942 vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
944 fmareg = (i % 2 == 0) ? reg03 : reg15;
945 fma(useFma, ymm0, ymm2, fmareg);
946 if (unroll_m >= 16) {
947 fmareg = (i % 2 == 0) ? reg09 : reg21;
948 fma(useFma, ymm1, ymm2, fmareg);
955 prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]);
958 ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]);
960 vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
962 fmareg = (i % 2 == 0) ? reg04 : reg16;
963 fma(useFma, ymm0, ymm2, fmareg);
964 if (unroll_m >= 16) {
965 fmareg = (i % 2 == 0) ? reg10 : reg22;
966 fma(useFma, ymm1, ymm2, fmareg);
974 ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]);
977 ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]);
979 vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
981 fmareg = (i % 2 == 0) ? reg05 : reg17;
982 fma(useFma, ymm0, ymm2, fmareg);
983 if (unroll_m >= 16) {
984 fmareg = (i % 2 == 0) ? reg11 : reg23;
985 fma(useFma, ymm1, ymm2, fmareg);
989 prefetcht0(ptr[BO1 + BO2]);
997 ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]);
999 prefetcht0(ptr[AO1 + LDA4]);
1003 if (i == 1 || i == 2) {
1004 if (unroll_m >= 8) {
1007 + (PREFETCHSIZEA + (2 + 2 * i) * 8)
1010 prefetcht0(ptr[AO1 + LDA4]);
1016 sub(BO1, -4 * SIZE);
1017 if (unroll_n >= 4) {
1018 sub(BO2, -4 * SIZE);
1024 if (isLoad1Unmasked) {
1027 + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
1033 + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
1036 if (unroll_m >= 16) {
1037 if (isLoad2Unmasked) {
1038 vmovups(ymm1, ptr[AO1
1039 + (unroll_m * (i + 1) + 1 * 8
1043 vmaskmovps(ymm1, VMASK,
1045 + (unroll_m * (i + 1) + 1 * 8
1054 sub(AO1, -unroll_m * 4 * SIZE);
1059 // Inner kernel with k=2
1060 auto innerkernel2 = [&](int unroll_m, int unroll_n,
1061 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
1062 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
1063 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
1064 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
1065 Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
1066 Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
1071 for (int i = 0; i < 2; i++) {
1073 if (isLoad1Unmasked) {
1074 vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1076 vmaskmovps(ymm0, VMASK,
1077 ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1079 if (unroll_m >= 16) {
1080 if (isLoad2Unmasked) {
1081 vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1083 vmaskmovps(ymm1, VMASK,
1084 ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1091 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1093 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1095 fmareg = (i % 2 == 0) ? reg00 : reg12;
1096 fma(useFma, ymm0, ymm2, fmareg);
1097 if (unroll_m >= 16) {
1098 fmareg = (i % 2 == 0) ? reg06 : reg18;
1099 fma(useFma, ymm1, ymm2, fmareg);
1101 if (unroll_n >= 2) {
1104 ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]);
1106 vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
1108 fmareg = (i % 2 == 0) ? reg01 : reg13;
1109 fma(useFma, ymm0, ymm2, fmareg);
1110 if (unroll_m >= 16) {
1111 fmareg = (i % 2 == 0) ? reg07 : reg19;
1112 fma(useFma, ymm1, ymm2, fmareg);
1116 if (unroll_n >= 3) {
1120 ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
1123 ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]);
1125 vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
1127 fmareg = (i % 2 == 0) ? reg02 : reg14;
1128 fma(useFma, ymm0, ymm2, fmareg);
1129 if (unroll_m >= 16) {
1130 fmareg = (i % 2 == 0) ? reg08 : reg20;
1131 fma(useFma, ymm1, ymm2, fmareg);
1135 if (unroll_n >= 4) {
1137 vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]);
1139 vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
1141 fmareg = (i % 2 == 0) ? reg03 : reg15;
1142 fma(useFma, ymm0, ymm2, fmareg);
1143 if (unroll_m >= 16) {
1144 fmareg = (i % 2 == 0) ? reg09 : reg21;
1145 fma(useFma, ymm1, ymm2, fmareg);
1149 if (unroll_n >= 5) {
1152 ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]);
1154 vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
1156 fmareg = (i % 2 == 0) ? reg04 : reg16;
1157 fma(useFma, ymm0, ymm2, fmareg);
1158 if (unroll_m >= 16) {
1159 fmareg = (i % 2 == 0) ? reg10 : reg22;
1160 fma(useFma, ymm1, ymm2, fmareg);
1164 if (unroll_n >= 6) {
1167 ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]);
1169 vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
1171 fmareg = (i % 2 == 0) ? reg05 : reg17;
1172 fma(useFma, ymm0, ymm2, fmareg);
1173 if (unroll_m >= 16) {
1174 fmareg = (i % 2 == 0) ? reg11 : reg23;
1175 fma(useFma, ymm1, ymm2, fmareg);
1180 vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
1182 if (unroll_m >= 16) {
1184 + (unroll_m * 0 + 1 * 8 - OFFSET)
1188 sub(LDA4, -unroll_m * SIZE);
1192 if (isLoad1Unmasked) {
1193 vmovups(ymm0, ptr[AO1
1194 + (unroll_m * 1 + 0 * 8 - OFFSET)
1197 vmaskmovps(ymm0, VMASK,
1199 + (unroll_m * 1 + 0 * 8 - OFFSET)
1202 if (unroll_m >= 16) {
1203 if (isLoad2Unmasked) {
1206 + (unroll_m * 1 + 1 * 8 - OFFSET)
1209 vmaskmovps(ymm1, VMASK,
1211 + (unroll_m * 1 + 1 * 8 - OFFSET)
1215 sub(AO1, -unroll_m * SIZE);
1220 if (unroll_n >= 4) {
1230 // Inner kernel with k=1
1231 auto innerkernel1 = [&](int unroll_m, int unroll_n,
1232 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
1233 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
1234 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
1235 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11) {
1238 if (isLoad1Unmasked) {
1239 vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1241 vmaskmovps(ymm0, VMASK, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1243 if (unroll_m >= 16) {
1244 if (isLoad2Unmasked) {
1245 vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1247 vmaskmovps(ymm1, VMASK,
1248 ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1255 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1257 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1259 fma(useFma, ymm0, ymm2, reg00);
1260 if (unroll_m >= 16) {
1261 fma(useFma, ymm1, ymm2, reg06);
1264 if (unroll_n >= 2) {
1267 ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]);
1269 vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
1271 fma(useFma, ymm0, ymm2, reg01);
1272 if (unroll_m >= 16) {
1273 fma(useFma, ymm1, ymm2, reg07);
1277 if (unroll_n >= 3) {
1280 ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]);
1282 vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
1284 fma(useFma, ymm0, ymm2, reg02);
1285 if (unroll_m >= 16) {
1286 fma(useFma, ymm1, ymm2, reg08);
1290 if (unroll_n >= 4) {
1292 vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]);
1294 vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
1296 fma(useFma, ymm0, ymm2, reg03);
1297 if (unroll_m >= 16) {
1298 fma(useFma, ymm1, ymm2, reg09);
1302 if (unroll_n >= 5) {
1305 ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]);
1307 vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
1309 fma(useFma, ymm0, ymm2, reg04);
1310 if (unroll_m >= 16) {
1311 fma(useFma, ymm1, ymm2, reg10);
1315 if (unroll_n >= 6) {
1318 ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]);
1320 vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
1322 fma(useFma, ymm0, ymm2, reg05);
1323 if (unroll_m >= 16) {
1324 fma(useFma, ymm1, ymm2, reg11);
1329 vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
1331 if (unroll_m >= 16) {
1332 vmovups(ptr[LDA4 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE],
1335 sub(LDA4, -unroll_m * SIZE);
1339 if (isLoad1Unmasked) {
1341 ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]);
1343 vmaskmovps(ymm0, VMASK,
1344 ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]);
1346 if (unroll_m >= 16) {
1347 if (isLoad2Unmasked) {
1348 vmovups(ymm1, ptr[AO1
1349 + (unroll_m * 1 + 1 * 8 - OFFSET)
1352 vmaskmovps(ymm1, VMASK,
1354 + (unroll_m * 1 + 1 * 8 - OFFSET)
1358 sub(AO1, -unroll_m * SIZE);
1363 if (unroll_n >= 4) {
1372 // Main kernel; does prefetching and calls innerkernel{1,2,4,8} as
1374 // After calculating results in registers, writes back to C matrix
1375 auto kernel = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1376 bool isLoad2Unmasked, bool isDirect, bool isCopy, bool useFma,
1377 Ymm reg00 = Ymm(4), Ymm reg01 = Ymm(5), Ymm reg02 = Ymm(6),
1378 Ymm reg03 = Ymm(7), Ymm reg04 = Ymm(8), Ymm reg05 = Ymm(9),
1379 Ymm reg06 = Ymm(10), Ymm reg07 = Ymm(11), Ymm reg08 = Ymm(12),
1380 Ymm reg09 = Ymm(13), Ymm reg10 = Ymm(14), Ymm reg11 = Ymm(15),
1381 Ymm reg12 = Ymm(4), Ymm reg13 = Ymm(5), Ymm reg14 = Ymm(6),
1382 Ymm reg15 = Ymm(7), Ymm reg16 = Ymm(8), Ymm reg17 = Ymm(9),
1383 Ymm reg18 = Ymm(10), Ymm reg19 = Ymm(11), Ymm reg20 = Ymm(12),
1384 Ymm reg21 = Ymm(13), Ymm reg22 = Ymm(14), Ymm reg23 = Ymm(15)) {
1388 lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]);
1394 lea(LDA4, ptr[rsp + 256 + OFFSET * SIZE]);
1396 lea(LDA4, ptr[LDA * 8 + (8 - 1 - OFFSET) * SIZE]);
1400 lea(BO2, ptr[LDB * 4 + (8 - 1 - OFFSET) * SIZE]);
1401 lea(BO2, ptr[BO2 + LDB * 2]);
1405 if (isLoad1Unmasked) {
1407 ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]);
1409 vmaskmovps(ymm0, VMASK,
1410 ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]);
1412 if (unroll_m >= 16) {
1413 if (isLoad2Unmasked) {
1414 vmovups(ymm1, ptr[AO1
1415 + (unroll_m * 0 + 1 * 8 - OFFSET)
1418 vmaskmovps(ymm1, VMASK,
1420 + (unroll_m * 0 + 1 * 8 - OFFSET)
1426 for (int i = 4; i < 10; i++) {
1427 vxorps(Ymm(i), Ymm(i), Ymm(i));
1428 vxorps(Ymm(i + 6), Ymm(i + 6), Ymm(i + 6));
1434 sub(LL, SECOND_FETCH);
1435 jle(".kernel13", T_NEAR);
1439 innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1440 isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1441 reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1442 reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1443 reg21, reg22, reg23);
1444 jg(".kernel12", T_NEAR);
1448 prefetcht0(ptr[CO1 + (unroll_m - 1) * SIZE]);
1450 prefetcht0(ptr[CO1 + LDC + (unroll_m - 1) * SIZE]);
1452 prefetcht0(ptr[CO1 + LDC * 2 + (unroll_m - 1) * SIZE]);
1454 prefetcht0(ptr[CO2 + (unroll_m - 1) * SIZE]);
1456 prefetcht0(ptr[CO2 + LDC + (unroll_m - 1) * SIZE]);
1458 prefetcht0(ptr[CO2 + LDC * 2 + (unroll_m - 1) * SIZE]);
1460 add(LL, SECOND_FETCH);
1461 jle(".kernel15", T_NEAR);
1465 innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1466 isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1467 reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1468 reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1469 reg21, reg22, reg23);
1470 jg(".kernel14", T_NEAR);
1475 jle(".kernel16", T_NEAR);
1476 innerkernel4(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1477 isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1478 reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1479 reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1480 reg21, reg22, reg23);
1484 jle(".kernel17", T_NEAR);
1485 innerkernel2(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1486 isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1487 reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1488 reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1489 reg21, reg22, reg23);
1493 if (unroll_m == 16) {
1494 if (unroll_n <= 3) {
1495 vaddps(reg00, reg00, reg12);
1496 vaddps(reg01, reg01, reg13);
1497 vaddps(reg02, reg02, reg14);
1498 vaddps(reg06, reg06, reg18);
1499 vaddps(reg07, reg07, reg19);
1500 vaddps(reg08, reg08, reg20);
1504 if (unroll_m <= 8) {
1505 vaddps(reg00, reg00, reg12);
1506 vaddps(reg01, reg01, reg13);
1507 vaddps(reg02, reg02, reg14);
1508 vaddps(reg03, reg03, reg15);
1509 vaddps(reg04, reg04, reg16);
1510 vaddps(reg05, reg05, reg17);
1514 jle(".kernel18", T_NEAR);
1515 innerkernel1(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1516 isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1517 reg05, reg06, reg07, reg08, reg09, reg10, reg11);
1521 vbroadcastss(VALPHA, ALPHA);
1524 vbroadcastss(VBETA, BETA);
1527 // Write back the results; all beta and bias cases need to be
1530 case 1: mov(rax, LDC); break;
1531 case 2: lea(rax, ptr[LDC * 2]); break;
1532 case 3: lea(rax, ptr[LDC + LDC * 2]); break;
1533 case 4: lea(rax, ptr[LDC + LDC * 4]); break;
1535 lea(rax, ptr[LDC * 4]);
1539 lea(rax, ptr[LDC + LDC * 2]);
1546 if (isLoad1Unmasked) {
1547 vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]);
1549 vmaskmovps(VBIAS1, VMASK, ptr[BIAS1 + 0 * SIZE]);
1553 for (int i = 0; i < unroll_n; i++) {
1554 vmulps(Ymm(i + 4), Ymm(i + 4), VALPHA);
1556 if (isLoad1Unmasked) {
1558 case 0: vmovups(ymm0, ptr[CO1 + 0 * SIZE]); break;
1559 case 1: vmovups(ymm0, ptr[CO1 + LDC + 0 * SIZE]); break;
1561 vmovups(ymm0, ptr[CO1 + LDC * 2 + 0 * SIZE]);
1563 case 3: vmovups(ymm0, ptr[CO2 + 0 * SIZE]); break;
1564 case 4: vmovups(ymm0, ptr[CO2 + LDC + 0 * SIZE]); break;
1566 vmovups(ymm0, ptr[CO2 + LDC * 2 + 0 * SIZE]);
1572 vmaskmovps(ymm0, VMASK, ptr[CO1 + 0 * SIZE]);
1575 vmaskmovps(ymm0, VMASK, ptr[CO1 + LDC + 0 * SIZE]);
1579 ymm0, VMASK, ptr[CO1 + LDC * 2 + 0 * SIZE]);
1582 vmaskmovps(ymm0, VMASK, ptr[CO2 + 0 * SIZE]);
1585 vmaskmovps(ymm0, VMASK, ptr[CO2 + LDC + 0 * SIZE]);
1589 ymm0, VMASK, ptr[CO2 + LDC * 2 + 0 * SIZE]);
1595 vaddps(Ymm(i + 4), ymm0, Ymm(i + 4));
1597 fma(useFma, VBETA, ymm0, Ymm(i + 4), true);
1601 vaddps(Ymm(i + 4), VBIAS1, Ymm(i + 4));
1603 if (isLoad1Unmasked) {
1605 case 0: vmovups(ptr[CO1 + 0 * SIZE], Ymm(i + 4)); break;
1607 vmovups(ptr[CO1 + LDC + 0 * SIZE], Ymm(i + 4));
1610 vmovups(ptr[CO1 + LDC * 2 + 0 * SIZE], Ymm(i + 4));
1612 case 3: vmovups(ptr[CO2 + 0 * SIZE], Ymm(i + 4)); break;
1614 vmovups(ptr[CO2 + LDC + 0 * SIZE], Ymm(i + 4));
1617 vmovups(ptr[CO2 + LDC * 2 + 0 * SIZE], Ymm(i + 4));
1623 vmaskmovps(ptr[CO1 + 0 * SIZE], VMASK, Ymm(i + 4));
1627 ptr[CO1 + LDC + 0 * SIZE], VMASK, Ymm(i + 4));
1630 vmaskmovps(ptr[CO1 + LDC * 2 + 0 * SIZE], VMASK,
1634 vmaskmovps(ptr[CO2 + 0 * SIZE], VMASK, Ymm(i + 4));
1638 ptr[CO2 + LDC + 0 * SIZE], VMASK, Ymm(i + 4));
1641 vmaskmovps(ptr[CO2 + LDC * 2 + 0 * SIZE], VMASK,
1647 if (unroll_m >= 16) {
1648 // Re-use ymm4 (VBIAS2)
1651 if (isLoad1Unmasked) {
1652 vmovups(VBIAS2, ptr[BIAS1 + 8 * SIZE]);
1655 VBIAS2, VMASK, ptr[BIAS1 + 8 * SIZE]);
1659 vmulps(Ymm(i + 10), Ymm(i + 10), VALPHA);
1661 if (isLoad2Unmasked) {
1663 case 0: vmovups(ymm0, ptr[CO1 + 8 * SIZE]); break;
1665 vmovups(ymm0, ptr[CO1 + LDC + 8 * SIZE]);
1668 vmovups(ymm0, ptr[CO1 + LDC * 2 + 8 * SIZE]);
1670 case 3: vmovups(ymm0, ptr[CO2 + 8 * SIZE]); break;
1672 vmovups(ymm0, ptr[CO2 + LDC + 8 * SIZE]);
1675 vmovups(ymm0, ptr[CO2 + LDC * 2 + 8 * SIZE]);
1681 vmaskmovps(ymm0, VMASK, ptr[CO1 + 8 * SIZE]);
1685 ymm0, VMASK, ptr[CO1 + LDC + 8 * SIZE]);
1688 vmaskmovps(ymm0, VMASK,
1689 ptr[CO1 + LDC * 2 + 8 * SIZE]);
1692 vmaskmovps(ymm0, VMASK, ptr[CO2 + 8 * SIZE]);
1696 ymm0, VMASK, ptr[CO2 + LDC + 8 * SIZE]);
1699 vmaskmovps(ymm0, VMASK,
1700 ptr[CO2 + LDC * 2 + 8 * SIZE]);
1705 vaddps(Ymm(i + 10), ymm0, Ymm(i + 10));
1707 fma(useFma, VBETA, ymm0, Ymm(i + 10), true);
1711 vaddps(Ymm(i + 10), VBIAS2, Ymm(i + 10));
1713 if (isLoad2Unmasked) {
1716 vmovups(ptr[CO1 + 8 * SIZE], Ymm(i + 10));
1719 vmovups(ptr[CO1 + LDC + 8 * SIZE], Ymm(i + 10));
1722 vmovups(ptr[CO1 + LDC * 2 + 8 * SIZE], Ymm(i + 10));
1725 vmovups(ptr[CO2 + 8 * SIZE], Ymm(i + 10));
1728 vmovups(ptr[CO2 + LDC + 8 * SIZE], Ymm(i + 10));
1731 vmovups(ptr[CO2 + LDC * 2 + 8 * SIZE], Ymm(i + 10));
1737 vmaskmovps(ptr[CO1 + 8 * SIZE], VMASK, Ymm(i + 10));
1740 vmaskmovps(ptr[CO1 + LDC + 8 * SIZE], VMASK,
1744 vmaskmovps(ptr[CO1 + LDC * 2 + 8 * SIZE], VMASK,
1748 vmaskmovps(ptr[CO2 + 8 * SIZE], VMASK, Ymm(i + 10));
1751 vmaskmovps(ptr[CO2 + LDC + 8 * SIZE], VMASK,
1755 vmaskmovps(ptr[CO2 + LDC * 2 + 8 * SIZE], VMASK,
1764 if (unroll_n >= 4) {
1768 // Compute next address of B
1770 lea(rax, ptr[K * SIZE]);
1777 lea(BO1, ptr[BO1 + LDB * 2]);
1778 lea(BO2, ptr[BO2 + LDB * 2]);
1781 lea(BO1, ptr[BO1 + LDB3]);
1782 lea(BO2, ptr[BO2 + LDB3]);
1785 lea(BO1, ptr[BO1 + LDB * 4]);
1786 lea(BO2, ptr[BO2 + LDB * 4]);
1789 lea(BO1, ptr[BO1 + LDB * 4]);
1791 lea(BO2, ptr[BO2 + LDB * 4]);
1795 lea(BO1, ptr[BO1 + LDB3 * 2]);
1796 lea(BO2, ptr[BO2 + LDB3 * 2]);
1805 add(BO1, unroll_n * SIZE);
1811 auto kernel_16x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1812 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1813 kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1814 isDirect, isCopy, true);
1817 auto kernel_16x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1818 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1819 kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1820 isDirect, isCopy, true);
1823 auto kernel_16x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1824 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1825 kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1826 isDirect, isCopy, true);
1829 auto kernel_16x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1830 bool isLoad2Unmasked, bool isDirect, bool isCopy,
1831 bool useFma = true) {
1832 kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1833 isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
1834 Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1835 Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9),
1836 Ymm(13), Ymm(14), Ymm(15));
1839 auto kernel_16x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1840 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1841 kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1842 isDirect, isCopy, false);
1845 auto kernel_16x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1846 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1847 kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1848 isDirect, isCopy, false);
1851 auto kernel_8x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1852 bool isLoad2Unmasked, bool isDirect, bool isCopy,
1853 bool useFma = true) {
1854 kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1855 isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
1856 Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1857 Ymm(15), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1861 auto kernel_8x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1862 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1863 kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1867 auto kernel_8x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1868 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1869 kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1873 auto kernel_8x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1874 bool isLoad2Unmasked, bool isDirect, bool isCopy,
1875 bool useFma = true) {
1876 kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1877 isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
1878 Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1879 Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9),
1880 Ymm(13), Ymm(14), Ymm(15));
1883 auto kernel_8x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1884 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1885 kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1886 isDirect, isCopy, false);
1889 auto kernel_8x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1890 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1891 kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1892 isDirect, isCopy, false);
1895 // High-level subroutine; does packing if needed, then splits C matrix.
1896 // Operates on chunks of 16 rows, 6 columns at a time (handling tail
1897 // cases appropriately).
1898 // Masking is used for tail cases where M is not divisible by 8.
1900 int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) {
1904 do_pack(unroll_m, isLoad1Unmasked, isLoad2Unmasked);
1908 lea(CO2, ptr[CO1 + LDC * 2]);
1910 add(C, unroll_m * SIZE);
1913 lea(BO2, qword[B + LDB3]);
1917 lea(AA, ptr[A + (unroll_m * 2 - 1 - OFFSET) * SIZE]);
1919 jg(".subloop98", T_NEAR);
1922 lea(AA, ptr[AA + (unroll_m - 1 - OFFSET) * SIZE]);
1929 // If N is too small, skip copy operation
1930 cmp(LL, UNROLL_N * 3);
1931 jle(".subloop30", T_NEAR);
1933 // If A is not aligned to cache line
1935 je(".subloop30", T_NEAR);
1938 jl(".subloop20", T_NEAR);
1943 if (unroll_m == 16) {
1944 kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1945 isLoad2Unmasked, true, true);
1947 kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1948 isLoad2Unmasked, true, true);
1951 if (unroll_m == 16) {
1952 kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1953 isLoad2Unmasked, false, false);
1955 kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1956 isLoad2Unmasked, false, false);
1962 jl(".subloop20", T_NEAR);
1966 if (unroll_m == 16) {
1967 kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1968 isLoad2Unmasked, false, false);
1970 kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, isLoad2Unmasked,
1975 jge(".subloop11", T_NEAR);
1980 jne(".subloop21", T_NEAR);
1981 if (unroll_m == 16) {
1982 kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
1985 kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, false,
1988 jmp(".subloop99", T_NEAR);
1993 jne(".subloop22", T_NEAR);
1994 if (unroll_m == 16) {
1995 kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
1998 kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, false,
2001 jmp(".subloop99", T_NEAR);
2006 jne(".subloop23", T_NEAR);
2007 if (unroll_m == 16) {
2008 kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
2011 kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, false,
2014 jmp(".subloop99", T_NEAR);
2019 jne(".subloop24", T_NEAR);
2020 if (unroll_m == 16) {
2021 kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
2024 kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, false,
2027 jmp(".subloop99", T_NEAR);
2032 jne(".subloop99", T_NEAR);
2033 if (unroll_m == 16) {
2034 kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
2037 kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, false,
2040 jmp(".subloop99", T_NEAR);
2046 jl(".subloop25", T_NEAR);
2050 if (unroll_m == 16) {
2051 kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
2052 isLoad2Unmasked, true, false);
2054 kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
2055 isLoad2Unmasked, true, false);
2059 jge(".subloop31", T_NEAR);
2064 jne(".subloop32", T_NEAR);
2065 if (unroll_m == 16) {
2066 kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
2069 kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
2072 jmp(".subloop99", T_NEAR);
2077 jne(".subloop33", T_NEAR);
2078 if (unroll_m == 16) {
2079 kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
2082 kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
2085 jmp(".subloop99", T_NEAR);
2090 jne(".subloop34", T_NEAR);
2091 if (unroll_m == 16) {
2092 kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
2095 kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
2098 jmp(".subloop99", T_NEAR);
2103 jne(".subloop35", T_NEAR);
2104 if (unroll_m == 16) {
2105 kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
2108 kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
2111 jmp(".subloop99", T_NEAR);
2116 jne(".subloop99", T_NEAR);
2117 if (unroll_m == 16) {
2118 kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
2121 kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
2128 // Compute address for A
2130 add(A, unroll_m * SIZE);
2133 imul(rax, rax, unroll_m);
2137 // Compute next address of BIAS
2139 add(BIAS, unroll_m * SIZE);
2149 // Get the registers
2159 vmovss(xmm0, ptr[ARG_ALPHA]);
2160 vmovss(xmm1, ptr[r15]);
2167 cmp(K, STACK_K_CAPACITY);
2168 jg(".buffer_in_ws", T_NEAR);
2170 // Create buffer and align to 4kB page
2171 lea(rax, ptr[K * SIZE]);
2175 and_(rsp, -PAGE_4K);
2176 jmp(".buffer_allocated", T_NEAR);
2181 L(".buffer_allocated");
2189 vmovss(ALPHA, xmm0);
2191 sub(A, -OFFSET * SIZE);
2192 sub(B, -OFFSET * SIZE);
2194 sal(LDA, BASE_SHIFT);
2195 sal(LDB, BASE_SHIFT);
2196 sal(LDC, BASE_SHIFT);
2197 lea(LDB3, ptr[LDB + LDB * 2]);
2199 for (int i = 0; i < 8; i++) {
2200 mov(dword[rsp + 88 + i * 4], i);
2203 if (isTransA && is_avx2) {
2205 vpbroadcastq(ymm1, xmm0);
2206 vinsertf128(ymm0, ymm0, xmm0, 1);
2207 vpermilpd(ymm0, ymm0, 5);
2208 vpaddq(ymm1, ymm1, ymm1);
2209 vperm2f128(ymm1, ymm1, ymm1, 8);
2210 vpaddq(ymm0, ymm0, ymm1);
2211 vmovups(STRIDE, ymm0);
2214 // Check A alignment and leading dimension; take copy-based path as
2222 jl(".main0", T_NEAR);
2226 subloop(UNROLL_M, true, true);
2229 jge(".main1", T_NEAR);
2234 jle(".main999", T_NEAR);
2238 jle(".main2", T_NEAR);
2241 vbroadcastss(VMASK, M);
2242 vpcmpgtd(VMASK, VMASK, MASK);
2244 subloop(16, true, false);
2245 jmp(".main999", T_NEAR);
2250 jne(".main3", T_NEAR);
2251 subloop(8, true, true);
2252 jmp(".main999", T_NEAR);
2258 vbroadcastss(VMASK, M);
2260 vpcmpgtd(VMASK, VMASK, MASK);
2262 auto xmask = Xmm(VMASK.getIdx());
2263 auto xmm_tmp = xmm4;
2265 vextractf128(xmm_tmp, VMASK, 1);
2266 vpcmpgtd(xmask, xmask, MASK);
2267 vpcmpgtd(xmm_tmp, xmm_tmp, dword[rsp + 88 + 4 * 4]); // MASK + 4
2268 vinsertf128(VMASK, VMASK, xmm_tmp, 1);
2270 subloop(8, false, false);
2274 // Restore original stack
2283 ker_ = reinterpret_cast<decltype(ker_)>(
2284 const_cast<uint8_t *>(this->getCode()));
// Invoke the JIT-generated GEMM kernel with the given problem description.
// alpha/beta are passed by pointer; bias may be a row-bias vector and ws is
// scratch workspace (used by the kernel when K exceeds the on-stack
// capacity — see STACK_K_CAPACITY above).
2287 void operator()(long long int m, long long int n, long long int k,
2288         const float *alpha, const float *a, long long int lda,
2289         const float *b, long long int ldb, const float *beta, float *c,
2290         long long int ldc, const float *bias, float *ws)
// Forward all arguments unchanged to the generated code.
2292         (*ker_)(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws);
// Pointer to the JIT-generated kernel entry point; set once in the
// constructor from this->getCode() (see the reinterpret_cast above).
2296     void (*ker_)(long long int m, long long int n, long long int k,
2297             const float *alpha, const float *a, long long int lda,
2298             const float *b, long long int ldb, const float *beta, float *c,
2299             long long int ldc, const float *bias, float *ws);
// C-style alias for the generated kernel's signature.
// NOTE(review): unlike the ker_ member above, this alias uses non-const
// float* parameters — confirm whether it is still referenced anywhere.
2302 typedef void (*ker)(long long int, long long int, long long int, float *,
2303     float *, long long int, float *, long long int, float *, float *,
2304     long long int, float *);
// Single-threaded blocked SGEMM driver for one thread's portion of C.
// Handles quick returns (empty problem; alpha == 0 reduces to C = beta*C),
// then tiles the problem into cache-sized (M x N x K) blocks and dispatches
// each tile to the JIT kernel matching the effective beta:
//   ker_b0_ : beta == 0 and no bias (first K-block overwrites C),
//   ker_b1_ : beta == 1 (subsequent K-blocks accumulate into C),
//   ker_bn_ : general beta.
// NOTE(review): some interior lines are not visible in this view; the
// comments below describe only the visible statements.
2305 void jit_avx_gemm_f32::sgemm_nocopy_driver(const char *transa,
2306         const char *transb, int m, int n, int k, const float *alpha,
2307         const float *a, int lda, const float *b, int ldb, const float *beta,
2308         float *c, int ldc, const float *bias, float *ws)
// Decode transpose flags once ('T'/'t' means the operand is transposed).
2310     bool isTransA = (*transa == 'T' || *transa == 't');
2311     bool isTransB = (*transb == 'T' || *transb == 't');
2313     int Bm, sizeM, Bn, sizeN, Bk, sizeK;
// Empty output matrix: nothing to compute.
2317     if ((m <= 0) || (n <= 0))
// k <= 0 or alpha == 0: the product term vanishes, so only scale C by beta.
2320     if ((k <= 0) || (alpha[0] == 0.)) {
2322         if (beta[0] == 0.) {
2323             for (j = 0; j < n; j++)
2324                 for (i = 0; i < m; i++)
2325                     c[i + j * ldc] = 0.0;
2326         } else if (beta[0] != 1.) {
2327             for (j = 0; j < n; j++)
2328                 for (i = 0; i < m; i++)
2329                     c[i + j * ldc] *= beta[0];
// Cache-block sizes; tuned differently when the operand is transposed.
2336     int BN = isTransA ? 96 : 48;
2337     int BK = isTransB ? 96 : 256;
2338     const float *curA, *curB, *curBias = NULL;
// Triple-nested blocking loop (K outermost). For each dimension, a tail
// block close to twice the nominal size is halved to balance the work.
2341     for (Bk = 0; Bk < k; Bk += sizeK) {
2343         if (sizeK >= BK * 2)
2347             sizeK = (sizeK + 1) / 2;
2350         for (Bm = 0; Bm < m; Bm += sizeM) {
2352             if (sizeM >= BM * 2)
2355             if (sizeM > BM + BM / 2)
2356                 sizeM = (sizeM + 1) / 2;
2359             for (Bn = 0; Bn < n; Bn += sizeN) {
2361                 if (sizeN >= BN * 2)
2364                 if (sizeN > BN + BN / 2)
2365                     sizeN = (sizeN + 1) / 2;
// Top-left corners of the current A/B/C tiles; the offset formula depends
// on whether the matrix is stored transposed (column-major layout).
2369                     curA = a + Bm + (size_t)Bk * lda;
2371                     curA = a + Bk + (size_t)Bm * lda;
2374                     curB = b + Bk + (size_t)Bn * ldb;
2376                     curB = b + Bn + (size_t)Bk * ldb;
2378                 curC = c + Bm + (size_t)Bn * ldc;
// Bias is indexed by the output row block.
2381                         curBias = bias + Bm;
// Dispatch to the kernel specialized for the effective beta value; all
// integer dimensions are widened to long long for the JIT ABI.
2387                     if (*beta == 0.0 && bias == NULL)
2388                         (*ker_b0_)((long long int)sizeM, (long long int)sizeN,
2389                                 (long long int)sizeK, alpha, curA,
2390                                 (long long int)lda, curB, (long long int)ldb,
2391                                 beta, curC, (long long int)ldc, curBias, ws);
2393                         (*ker_bn_)((long long int)sizeM, (long long int)sizeN,
2394                                 (long long int)sizeK, alpha, curA,
2395                                 (long long int)lda, curB, (long long int)ldb,
2396                                 beta, curC, (long long int)ldc, curBias, ws);
2398                     (*ker_b1_)((long long int)sizeM, (long long int)sizeN,
2399                             (long long int)sizeK, alpha, curA,
2400                             (long long int)lda, curB, (long long int)ldb, beta,
2401                             curC, (long long int)ldc, curBias, ws);
// Threaded SGEMM entry point (Fortran-style pointer arguments, column-major
// storage). Partitions the problem across threads in the M, N and — when
// the threading runtime is syncable — K dimensions, runs
// sgemm_nocopy_driver on each partition, then reduces the per-thread
// partial C buffers along K into the final C.
// NOTE(review): some interior lines are not visible in this view; the
// comments below describe only the visible statements.
2408 void jit_avx_gemm_f32::sgemm(const char *transa, const char *transb,
2409         const int *p_m, const int *p_n, const int *p_k, const float *p_alpha,
2410         const float *A, const int *p_lda, const float *B, const int *p_ldb,
2411         const float *p_beta, float *C, const int *p_ldc, const float *bias)
// The kernels were generated for a fixed beta (0 or 1); verify the call
// matches the configuration this object was built with.
2413     if (beta_ == 0. || beta_ == 1.)
2414         assert(*p_beta == beta_);
2415     assert((one_of(*transa, 'T', 't') == one_of(transa_, 'T', 't')));
// Run single-threaded if we are already inside a parallel region.
2417     int nthr = mkldnn_in_parallel() ? 1 : mkldnn_get_max_threads();
2424     float beta = *p_beta;
2427     int nthr_m, nthr_n, nthr_k, nthr_mn;
2429     assert(nthr <= nthrs_);
2431     // Determine threading partitioning
2432     gemm_utils::calc_nthr_nocopy_avx(
2433             m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
// K-splitting needs a barrier-style wait on ompstatus below.
2434     assert(utils::implication(!mkldnn_thr_syncable(), nthr_k == 1));
2436     // May not happen, but just in case
2437     if (nthr < nthr_m * nthr_n * nthr_k)
2438         nthr = nthr_m * nthr_n * nthr_k;
2440     nthr_mn = nthr_m * nthr_n;
// Per-thread completion flags, one cache line apart to avoid false sharing.
2442     unsigned int volatile *ompstatus = (unsigned int volatile *)ompstatus_;
2443     if (!ompstatus) return;
2445     float *c_buffers = NULL;
2446     float *ws_buffers = NULL;
2449     for (int i = 0; i < nthr; i++)
2450         ompstatus[i * CACHE_LINE_SIZE] = 0;
// Scratch C buffers for the (nthr_k - 1) non-primary K partitions; their
// partial results are summed into C after the compute phase.
2452         c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB
2453                 * sizeof(float), PAGE_4K);
// Per-thread workspace, page-aligned; only needed when K is too large for
// the kernel's on-stack buffer (see STACK_K_CAPACITY).
2456     const size_t ws_elems_per_thr = k * 16 + 64;
2457     const size_t ws_size_per_thr
2458             = utils::rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
2459     if (k > STACK_K_CAPACITY) {
2460         ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K);
2463     parallel(nthr, [&](const int ithr, const int nthr) {
2464         int ithr_m, ithr_n, ithr_k, ithr_mn;
2465         int m_from, m_to, myM;
2466         int n_from, n_to, myN;
2467         int k_from, k_to, myK;
2469         const float *myA, *myB, *myBias = NULL;
2470         float *myC = C, myBeta;
2471         float *ws = ws_buffers ?
2472                 ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0;
2475         if (ithr < nthr_m * nthr_n * nthr_k) {
// Decompose the linear thread id into a 3-D (m, n, k) grid coordinate.
2477             ithr_mn = ithr % nthr_mn;
2478             ithr_m = ithr_mn % nthr_m;
2479             ithr_n = ithr_mn / nthr_m;
2480             ithr_k = ithr / nthr_mn;
2482             /* swap ithr_k for performance improvement */
2484                 ithr_k = nthr_k - 1;
2485             else if (ithr_k == nthr_k - 1)
// Per-thread sub-ranges in each dimension (clamped to m/n/k elsewhere).
2488             m_from = MB * (ithr_m);
2489             m_to = MB * (ithr_m + 1);
2492             myM = m_to - m_from;
2494             n_from = NB * (ithr_n);
2495             n_to = NB * (ithr_n + 1);
2498             myN = n_to - n_from;
2500             k_from = KB * (ithr_k);
2501             k_to = KB * (ithr_k + 1);
2504             myK = k_to - k_from;
// Base indices into the scratch-C buffer array and the status-flag array
// for this (m, n) grid cell.
2506             cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
2507             ibase = (ithr_m + nthr_m * ithr_n) * nthr_k;
2509             if ((myM > 0) && (myN > 0)) {
// Locate this thread's A/B/C sub-matrices (column-major; the offset
// formula depends on the transpose flags).
2511                 if (*transa == 'N' || *transa == 'n') {
2512                     myA = &(A[m_from + k_from * lda]);
2514                     myA = &(A[k_from + m_from * lda]);
2516                 if (*transb == 'N' || *transb == 'n') {
2517                     myB = &(B[k_from + n_from * ldb]);
2519                     myB = &(B[n_from + k_from * ldb]);
2522                     myC = &(C[m_from + n_from * ldc]);
2526                         myBias = &(bias[m_from]);
// Non-primary K partitions write into a private scratch buffer instead of
// C; their contributions are summed in later.
2528                     myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
2534                 sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA,
2535                         lda, myB, ldb, &myBeta, myC, ld, myBias, ws);
// Publish completion so threads waiting on this K partition can proceed.
2538                 ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1;
2543             // sum matrices partitioned along K dimension
// Each K-thread reduces a disjoint column slice [n1, n1 + n2) of the tile.
2546                 gemm_utils::partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
2550                     myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
2551                     myC = myC + n1 * MB;
2552                     /* need to wait until main thread finishes */
// Spin-wait on the primary K partition's completion flag.
2553                     while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) {
2556                     /* my cache is hot */
2557                     gemm_utils::sum_two_matrices(myM, n2, myC, MB,
2558                             &C[m_from + (n_from + n1) * ldc], ldc);
// Then accumulate every other K partition's scratch buffer into C,
// waiting for each producer as needed.
2561                 for (int ik = 1; ik < nthr_k; ++ik) {
2564                         myC = c_buffers + MB * NB * (cbase + ik - 1);
2565                         myC = myC + n1 * MB;
2567                         while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) {
2570                         gemm_utils::sum_two_matrices(myM, n2, myC, MB,
2571                                 &C[m_from + (n_from + n1) * ldc], ldc);
// Construct the JIT kernels for this (transa, transb, beta, bias)
// configuration:
//   ker_bn_ — handles the requested beta (and bias),
//   ker_b1_ — handles beta == 1, used when accumulating later K-blocks,
//   ker_b0_ — handles beta == 0, created only when it can be needed.
// Also allocates the cache-line-spaced per-thread status array used by
// sgemm() for the K-reduction handshake.
2583 jit_avx_gemm_f32::jit_avx_gemm_f32(
2584         char transa, char transb, float beta, bool hasBias)
// Bias is only supported together with beta == 0 in this implementation.
2591         assert(beta == 0.0);
2593     ker_bn_ = new xbyak_gemm(transa, transb, beta, hasBias);
2595     ker_b1_ = new xbyak_gemm(transa, transb, 1.0);
2599     if (beta != 0.0 || (beta == 0.0 && hasBias)) {
2600         ker_b0_ = new xbyak_gemm(transa, transb, 0.0);
2604     nthrs_ = mkldnn_get_max_threads();
// NOTE(review): sizeof(unsigned int *) looks like it should be
// sizeof(unsigned int); as written it over-allocates (harmless) — confirm
// intent before changing.
2605     ompstatus_ = (unsigned int *)malloc(
2606             sizeof(unsigned int *) * nthrs_ * CACHE_LINE_SIZE, 64);
// Release the JIT kernels created in the constructor. ker_b0_ exists only
// when (beta_ != 0) or (beta_ == 0 && hasBias_) — this guard mirrors the
// allocation condition in the constructor.
2610 jit_avx_gemm_f32::~jit_avx_gemm_f32()
2615     if (beta_ != 0.0 || (beta_ == 0.0 && hasBias_))
2624 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s