1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
20 #include "mkldnn_thread.hpp"
23 #include "ref_gemm_f32.hpp"
24 #include "gemm_utils_f32.hpp"
25 #include "jit_avx_gemm_f32.hpp"
27 #include "jit_generator.hpp"
33 #define CACHE_LINE_SIZE 64
35 #define STACKSIZE get_size_of_abi_save_regs()
37 #define STACK_K_CAPACITY 128
39 #define STACK_K_CAPACITY 8192
44 #define SECOND_FETCH 14
46 namespace avx_gemm_f32 {
47 using namespace gemm_utils;
49 struct xbyak_gemm : public jit_generator {
50 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_gemm_f32_xbyak_gemm)
52 xbyak_gemm(char isTransA, char isTransB, float beta, bool hasBias = false,
53 void *code_ptr = nullptr,
54 size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE)
55 : jit_generator(code_ptr, code_size)
57 using namespace Xbyak;
59 const bool is_avx2 = mayiuse(avx2);
60 assert(IMPLICATION(!is_avx2, mayiuse(avx)));
62 const int UNROLL_M = is_avx2 ? 16 : 8;
63 const int UNROLL_N = 6;
65 bool isBeta0 = (beta == 0.0);
66 bool isBetaN = (!isBeta0 && beta != 1.0);
68 // various definitions for convenience
69 auto ARG_M = abi_param1;
70 auto ARG_N = abi_param2;
72 auto ARG_ALPHA = abi_param4;
74 auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE];
75 auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE +
76 sizeof(float *) + STACKSIZE];
77 const auto stackOffset = OFFSET_SHADOWSPACE +
78 sizeof(float *) + STACKSIZE;
84 const auto stackOffset = STACKSIZE;
88 auto ARG_B = ptr[rsp + 8 + stackOffset];
89 auto ARG_LDB = ptr[rsp + 16 + stackOffset];
90 auto ARG_BETA = ptr[rsp + 24 + stackOffset];
91 auto ARG_C = ptr[rsp + 32 + stackOffset];
92 auto ARG_LDC = ptr[rsp + 40 + stackOffset];
93 auto ARG_BIAS = ptr[rsp + 48 + stackOffset];
94 auto ARG_WS = ptr[rsp + 56 + stackOffset];
100 auto AO1 = abi_param2;
101 auto BO1 = abi_param4;
106 auto LDA4 = abi_param1;
108 auto BIAS1 = abi_param1;
110 auto M = qword[rsp + 0];
111 auto N = qword[rsp + 8];
112 auto FLAG = qword[rsp + 16];
113 auto I = qword[rsp + 24];
114 auto C = qword[rsp + 32];
115 auto BIAS = qword[rsp + 40];
116 auto ALPHA = qword[rsp + 48];
117 auto BETA = qword[rsp + 64];
118 auto ORIG_A = qword[rsp + 80];
119 auto MASK = dword[rsp + 88];
120 auto STRIDE = qword[rsp + 120];
121 auto ORIG_SP = qword[rsp + 152];
129 auto PREFETCHSIZEA = 128;
130 auto PREFETCHSIZEB = (!isTransB) ? -16 : 0;
132 // Function for packing if needed
134 int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) {
135 Label pack2, pack3, pack4, pack10;
141 lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]);
144 lea(BO2, ptr[BO1 + LDA * 4]);
145 lea(CO1, ptr[LDA + LDA * 2]);
146 vmovupd(ymm7, STRIDE);
156 for (int i = 0; i < 4; i++) {
157 regIdx = (i % 2 == 0) ? 4 : 6;
158 if (isLoad1Unmasked) {
160 ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
162 vmaskmovps(Ymm(regIdx), VMASK,
163 ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
166 if (isLoad2Unmasked) {
167 vmovups(Ymm(regIdx + 1),
168 ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
170 vmaskmovps(Ymm(regIdx + 1), VMASK,
171 ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
176 vmovups(ptr[AO1 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
180 + (unroll_m * i + 1 * 8 - OFFSET)
187 if (isLoad1Unmasked) {
188 for (int i = 0; i < 2; i++) {
189 reg = (i % 2 == 0) ? BO1 : BO2;
190 vmovups(xmm0, ptr[reg + (0 * 8 - OFFSET) * SIZE]);
192 ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
193 lea(BO2, ptr[reg + LDA * 2]);
194 vunpcklps(xmm4, xmm0, xmm1);
195 vunpckhps(xmm5, xmm0, xmm1);
196 vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
198 ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
199 lea(BO2, ptr[BO2 + LDA * 2]);
200 vunpcklps(xmm6, xmm0, xmm1);
201 vunpckhps(xmm2, xmm0, xmm1);
203 vunpcklpd(xmm0, xmm4, xmm6);
204 vunpckhpd(xmm1, xmm4, xmm6);
206 + (unroll_m * 0 + i * 4 - OFFSET)
210 + (unroll_m * 1 + i * 4 - OFFSET)
213 vunpcklpd(xmm0, xmm5, xmm2);
214 vunpckhpd(xmm1, xmm5, xmm2);
216 + (unroll_m * 2 + i * 4 - OFFSET)
220 + (unroll_m * 3 + i * 4 - OFFSET)
224 } else if (is_avx2) {
225 for (int i = 0; i < 2; i++) {
228 ptr[BO1 + ymm7 + ((2 * i) - OFFSET) * SIZE],
232 ptr[BO1 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE],
236 + (unroll_m * (2 * i) + 0 * 4 - OFFSET)
240 + (unroll_m * (2 * i + 1) + 0 * 4
246 lea(BO2, ptr[BO1 + LDA * 4]);
248 for (int i = 0; i < 2; i++) {
249 vextractf128(xmm4, ymm3, 1);
251 ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
253 vextractf128(xmm4, ymm3, 1);
255 ptr[BO2 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE],
259 + (unroll_m * (2 * i) + 1 * 4 - OFFSET)
263 + (unroll_m * (2 * i + 1) + 1 * 4
269 lea(BO2, ptr[BO2 + LDA * 4]);
271 vxorps(xmm4, xmm4, xmm4);
272 lea(BO2, ptr[BO1 + LDA * 4]);
274 auto el_cp = [&](int section, int ld_step) {
275 RegExp src_addr = section == 0 ? BO1 : BO2;
276 if (ld_step == 1 || ld_step == 2)
277 src_addr = src_addr + LDA * ld_step;
278 else if (ld_step == 3)
279 src_addr = src_addr + CO1;
280 src_addr = src_addr - OFFSET * SIZE;
282 vmovups(Xmm(ld_step % 2), ptr[src_addr]);
283 RegExp dst_addr = AO1
284 + (ld_step + section * 4 - OFFSET) * SIZE;
285 for (int off = 0; off < 4; ++off)
286 pextrd(ptr[dst_addr + unroll_m * off * SIZE],
287 Xmm(ld_step % 2), off);
291 el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR);
292 el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR);
293 el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR);
294 el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR);
295 el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR);
296 el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR);
300 lea(BO2, ptr[BO2 + LDA * 4]);
303 if (unroll_m >= 16) {
305 if (isLoad2Unmasked) {
306 for (int i = 0; i < 2; i++) {
307 vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
308 vmovups(xmm1, ptr[BO2 + LDA * 1
309 + (0 * 8 - OFFSET) * SIZE]);
310 lea(BO2, ptr[BO2 + LDA * 2]);
311 vunpcklps(xmm4, xmm0, xmm1);
312 vunpckhps(xmm5, xmm0, xmm1);
313 vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
314 vmovups(xmm1, ptr[BO2 + LDA * 1
315 + (0 * 8 - OFFSET) * SIZE]);
317 lea(BO2, ptr[BO2 + LDA * 2]);
318 vunpcklps(xmm6, xmm0, xmm1);
319 vunpckhps(xmm2, xmm0, xmm1);
321 vunpcklpd(xmm0, xmm4, xmm6);
322 vunpckhpd(xmm1, xmm4, xmm6);
324 + (unroll_m * 0 + (i + 2) * 4
329 + (unroll_m * 1 + (i + 2) * 4
333 vunpcklpd(xmm0, xmm5, xmm2);
334 vunpckhpd(xmm1, xmm5, xmm2);
336 + (unroll_m * 2 + (i + 2) * 4
341 + (unroll_m * 3 + (i + 2) * 4
347 for (int i = 0; i < 2; i++) {
350 ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
355 + ((2 * i + 1) - OFFSET) * SIZE],
359 + (unroll_m * (2 * i) + 2 * 4
364 + (unroll_m * (2 * i + 1) + 2 * 4
370 lea(BO2, ptr[BO2 + LDA * 4]);
372 for (int i = 0; i < 2; i++) {
373 vextractf128(xmm4, ymm3, 1);
375 ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
377 vextractf128(xmm4, ymm3, 1);
380 + ((2 * i + 1) - OFFSET) * SIZE],
384 + (unroll_m * (2 * i) + 3 * 4
389 + (unroll_m * (2 * i + 1) + 3 * 4
395 lea(BO2, ptr[BO2 + LDA * 4]);
398 add(BO1, (4 * SIZE));
401 add(AO1, unroll_m * 4 * SIZE);
414 if (isLoad1Unmasked) {
415 vmovups(ymm4, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
417 vmaskmovps(ymm4, VMASK, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
420 if (isLoad2Unmasked) {
421 vmovups(ymm5, ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
423 vmaskmovps(ymm5, VMASK,
424 ptr[BO1 + (1 + 8 - OFFSET) * SIZE]);
428 vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
431 vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE],
435 if (isLoad1Unmasked) {
436 for (int i = 0; i < 2; i++) {
437 reg = (i % 2 == 0) ? BO1 : BO2;
438 vmovss(Xmm(i + 1), ptr[reg + (0 * 8 - OFFSET) * SIZE]);
440 ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
441 lea(BO2, ptr[reg + LDA * 2]);
442 vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
444 vunpcklpd(xmm1, xmm1, xmm2);
445 vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE],
448 for (int i = 0; i < 2; i++) {
449 vmovss(Xmm(i + 1), ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
451 ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
452 lea(BO2, ptr[BO2 + LDA * 2]);
453 vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
455 vunpcklpd(xmm1, xmm1, xmm2);
456 vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE],
458 } else if (is_avx2) {
460 vgatherqps(xmm1, ptr[BO1 + ymm7 + (0 * 8 - OFFSET) * SIZE],
462 lea(BO2, ptr[BO1 + LDA * 4]);
463 vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE],
466 vextractf128(xmm4, ymm3, 1);
467 vgatherqps(xmm1, ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
469 lea(BO2, ptr[BO2 + LDA * 4]);
470 vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE],
473 vxorps(xmm4, xmm4, xmm4);
474 lea(BO2, ptr[BO1 + LDA * 4]);
476 auto el_cp = [&](int section, int ld_step) {
477 RegExp src_addr = section == 0 ? BO1 : BO2;
478 if (ld_step == 1 || ld_step == 2)
479 src_addr = src_addr + LDA * ld_step;
480 else if (ld_step == 3)
481 src_addr = src_addr + CO1;
482 src_addr = src_addr - OFFSET * SIZE;
484 vmovss(xmm1, ptr[src_addr]);
485 RegExp dst_addr = AO1
486 + (ld_step + section * 4 - OFFSET) * SIZE;
487 movss(ptr[dst_addr], xmm1);
491 el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR);
492 el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR);
493 el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR);
494 el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR);
495 el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR);
496 el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR);
500 lea(BO2, ptr[BO2 + LDA * 4]);
503 if (unroll_m >= 16) {
505 if (isLoad2Unmasked) {
506 for (int i = 0; i < 2; i++) {
508 ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
509 vmovss(xmm0, ptr[BO2 + LDA * 1
510 + (0 * 8 - OFFSET) * SIZE]);
511 lea(BO2, ptr[BO2 + LDA * 2]);
512 vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
514 vunpcklpd(xmm1, xmm1, xmm2);
518 ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
520 lea(BO2, ptr[BO2 + LDA * 4]);
522 vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 4 - OFFSET) * SIZE],
525 if (isLoad2Unmasked) {
526 for (int i = 0; i < 2; i++) {
528 ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
529 vmovss(xmm0, ptr[BO2 + LDA * 1
530 + (0 * 8 - OFFSET) * SIZE]);
531 lea(BO2, ptr[BO2 + LDA * 2]);
532 vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
534 vunpcklpd(xmm1, xmm1, xmm2);
536 vextractf128(xmm4, ymm3, 1);
538 ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
541 vmovups(ptr[AO1 + (unroll_m * 0 + 3 * 4 - OFFSET) * SIZE],
547 add(AO1, unroll_m * SIZE);
555 // Fused multiply add; may become one or two instructions
556 auto fma = [&](bool useFma, Ymm reg0, Ymm reg1, Ymm reg2,
557 bool overWrite = false) {
560 vfmadd231ps(reg2, reg1, reg0);
562 assert(UNROLL_M == 8);
563 auto tent_vreg = overWrite ? reg1 : ymm1;
564 vmulps(tent_vreg, reg1, reg0);
565 vaddps(reg2, reg2, tent_vreg);
569 vmulps(ymm15, reg1, reg0);
570 vaddps(reg2, reg2, ymm15);
572 vmulps(reg1, reg1, reg0);
573 vaddps(reg2, reg2, reg1);
578 // Inner kernel with k=8
579 auto innerkernel8 = [&](int unroll_m, int unroll_n,
580 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
581 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
582 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
583 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
584 Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
585 Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
591 prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]);
593 prefetcht0(ptr[AO1 + LDA4]);
596 for (int i = 0; i < 8; i++) {
598 if (isLoad1Unmasked) {
599 vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
601 vmaskmovps(ymm0, VMASK,
602 ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
604 if (unroll_m >= 16) {
605 if (isLoad2Unmasked) {
606 vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
608 vmaskmovps(ymm1, VMASK,
609 ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
616 vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]);
618 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
620 fmareg = (i % 2 == 0) ? reg00 : reg12;
621 fma(useFma, ymm0, ymm2, fmareg);
622 if (unroll_m >= 16) {
623 fmareg = (i % 2 == 0) ? reg06 : reg18;
624 fma(useFma, ymm1, ymm2, fmareg);
628 prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]);
634 prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]);
637 ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
639 vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
641 fmareg = (i % 2 == 0) ? reg01 : reg13;
642 fma(useFma, ymm0, ymm2, fmareg);
643 if (unroll_m >= 16) {
644 fmareg = (i % 2 == 0) ? reg07 : reg19;
645 fma(useFma, ymm1, ymm2, fmareg);
650 vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
652 if (unroll_m >= 16) {
654 + (unroll_m * i + 1 * 8 - OFFSET)
659 sub(LDA4, -unroll_m * 8 * SIZE);
667 ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
670 ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]);
672 vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
674 fmareg = (i % 2 == 0) ? reg02 : reg14;
675 fma(useFma, ymm0, ymm2, fmareg);
676 if (unroll_m >= 16) {
677 fmareg = (i % 2 == 0) ? reg08 : reg20;
678 fma(useFma, ymm1, ymm2, fmareg);
691 prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]);
693 vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]);
695 vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
697 fmareg = (i % 2 == 0) ? reg03 : reg15;
698 fma(useFma, ymm0, ymm2, fmareg);
699 if (unroll_m >= 16) {
700 fmareg = (i % 2 == 0) ? reg09 : reg21;
701 fma(useFma, ymm1, ymm2, fmareg);
708 prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]);
711 ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]);
713 vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
715 fmareg = (i % 2 == 0) ? reg04 : reg16;
716 fma(useFma, ymm0, ymm2, fmareg);
717 if (unroll_m >= 16) {
718 fmareg = (i % 2 == 0) ? reg10 : reg22;
719 fma(useFma, ymm1, ymm2, fmareg);
727 ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]);
730 ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]);
732 vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
734 fmareg = (i % 2 == 0) ? reg05 : reg17;
735 fma(useFma, ymm0, ymm2, fmareg);
736 if (unroll_m >= 16) {
737 fmareg = (i % 2 == 0) ? reg11 : reg23;
738 fma(useFma, ymm1, ymm2, fmareg);
742 prefetcht0(ptr[BO1 + BO2]);
750 ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]);
752 prefetcht0(ptr[AO1 + LDA4]);
756 if (i == 1 || i == 2) {
760 + (PREFETCHSIZEA + (2 + 2 * i) * 8)
763 prefetcht0(ptr[AO1 + LDA4]);
767 if (i == 3 || i == 4 || i == 5 || i == 6) {
768 if (unroll_m >= 16) {
771 + (PREFETCHSIZEA + (2 + 2 * i) * 8)
774 prefetcht0(ptr[AO1 + LDA4]);
786 lea(AA, ptr[AA + LDA]);
791 if (isLoad1Unmasked) {
794 + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
800 + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
803 if (unroll_m >= 16) {
804 if (isLoad2Unmasked) {
805 vmovups(ymm1, ptr[AO1
806 + (unroll_m * (i + 1) + 1 * 8
810 vmaskmovps(ymm1, VMASK,
812 + (unroll_m * (i + 1) + 1 * 8
821 sub(AO1, -unroll_m * 8 * SIZE);
827 // Inner kernel with k=4
828 auto innerkernel4 = [&](int unroll_m, int unroll_n,
829 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
830 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
831 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
832 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
833 Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
834 Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
840 prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]);
842 prefetcht0(ptr[AO1 + LDA4]);
845 for (int i = 0; i < 4; i++) {
847 if (isLoad1Unmasked) {
848 vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
850 vmaskmovps(ymm0, VMASK,
851 ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
853 if (unroll_m >= 16) {
854 if (isLoad2Unmasked) {
855 vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
857 vmaskmovps(ymm1, VMASK,
858 ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
865 vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]);
867 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
869 fmareg = (i % 2 == 0) ? reg00 : reg12;
870 fma(useFma, ymm0, ymm2, fmareg);
871 if (unroll_m >= 16) {
872 fmareg = (i % 2 == 0) ? reg06 : reg18;
873 fma(useFma, ymm1, ymm2, fmareg);
877 prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]);
883 prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]);
886 ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
888 vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
890 fmareg = (i % 2 == 0) ? reg01 : reg13;
891 fma(useFma, ymm0, ymm2, fmareg);
892 if (unroll_m >= 16) {
893 fmareg = (i % 2 == 0) ? reg07 : reg19;
894 fma(useFma, ymm1, ymm2, fmareg);
899 vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
901 if (unroll_m >= 16) {
903 + (unroll_m * i + 1 * 8 - OFFSET)
908 sub(LDA4, -unroll_m * 4 * SIZE);
916 ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
919 ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]);
921 vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
923 fmareg = (i % 2 == 0) ? reg02 : reg14;
924 fma(useFma, ymm0, ymm2, fmareg);
925 if (unroll_m >= 16) {
926 fmareg = (i % 2 == 0) ? reg08 : reg20;
927 fma(useFma, ymm1, ymm2, fmareg);
940 prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]);
942 vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]);
944 vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
946 fmareg = (i % 2 == 0) ? reg03 : reg15;
947 fma(useFma, ymm0, ymm2, fmareg);
948 if (unroll_m >= 16) {
949 fmareg = (i % 2 == 0) ? reg09 : reg21;
950 fma(useFma, ymm1, ymm2, fmareg);
957 prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]);
960 ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]);
962 vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
964 fmareg = (i % 2 == 0) ? reg04 : reg16;
965 fma(useFma, ymm0, ymm2, fmareg);
966 if (unroll_m >= 16) {
967 fmareg = (i % 2 == 0) ? reg10 : reg22;
968 fma(useFma, ymm1, ymm2, fmareg);
976 ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]);
979 ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]);
981 vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
983 fmareg = (i % 2 == 0) ? reg05 : reg17;
984 fma(useFma, ymm0, ymm2, fmareg);
985 if (unroll_m >= 16) {
986 fmareg = (i % 2 == 0) ? reg11 : reg23;
987 fma(useFma, ymm1, ymm2, fmareg);
991 prefetcht0(ptr[BO1 + BO2]);
999 ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]);
1001 prefetcht0(ptr[AO1 + LDA4]);
1005 if (i == 1 || i == 2) {
1006 if (unroll_m >= 8) {
1009 + (PREFETCHSIZEA + (2 + 2 * i) * 8)
1012 prefetcht0(ptr[AO1 + LDA4]);
1018 sub(BO1, -4 * SIZE);
1019 if (unroll_n >= 4) {
1020 sub(BO2, -4 * SIZE);
1026 if (isLoad1Unmasked) {
1029 + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
1035 + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
1038 if (unroll_m >= 16) {
1039 if (isLoad2Unmasked) {
1040 vmovups(ymm1, ptr[AO1
1041 + (unroll_m * (i + 1) + 1 * 8
1045 vmaskmovps(ymm1, VMASK,
1047 + (unroll_m * (i + 1) + 1 * 8
1056 sub(AO1, -unroll_m * 4 * SIZE);
1061 // Inner kernel with k=2
1062 auto innerkernel2 = [&](int unroll_m, int unroll_n,
1063 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
1064 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
1065 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
1066 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
1067 Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
1068 Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
1073 for (int i = 0; i < 2; i++) {
1075 if (isLoad1Unmasked) {
1076 vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1078 vmaskmovps(ymm0, VMASK,
1079 ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1081 if (unroll_m >= 16) {
1082 if (isLoad2Unmasked) {
1083 vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1085 vmaskmovps(ymm1, VMASK,
1086 ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1093 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1095 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1097 fmareg = (i % 2 == 0) ? reg00 : reg12;
1098 fma(useFma, ymm0, ymm2, fmareg);
1099 if (unroll_m >= 16) {
1100 fmareg = (i % 2 == 0) ? reg06 : reg18;
1101 fma(useFma, ymm1, ymm2, fmareg);
1103 if (unroll_n >= 2) {
1106 ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]);
1108 vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
1110 fmareg = (i % 2 == 0) ? reg01 : reg13;
1111 fma(useFma, ymm0, ymm2, fmareg);
1112 if (unroll_m >= 16) {
1113 fmareg = (i % 2 == 0) ? reg07 : reg19;
1114 fma(useFma, ymm1, ymm2, fmareg);
1118 if (unroll_n >= 3) {
1122 ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
1125 ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]);
1127 vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
1129 fmareg = (i % 2 == 0) ? reg02 : reg14;
1130 fma(useFma, ymm0, ymm2, fmareg);
1131 if (unroll_m >= 16) {
1132 fmareg = (i % 2 == 0) ? reg08 : reg20;
1133 fma(useFma, ymm1, ymm2, fmareg);
1137 if (unroll_n >= 4) {
1139 vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]);
1141 vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
1143 fmareg = (i % 2 == 0) ? reg03 : reg15;
1144 fma(useFma, ymm0, ymm2, fmareg);
1145 if (unroll_m >= 16) {
1146 fmareg = (i % 2 == 0) ? reg09 : reg21;
1147 fma(useFma, ymm1, ymm2, fmareg);
1151 if (unroll_n >= 5) {
1154 ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]);
1156 vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
1158 fmareg = (i % 2 == 0) ? reg04 : reg16;
1159 fma(useFma, ymm0, ymm2, fmareg);
1160 if (unroll_m >= 16) {
1161 fmareg = (i % 2 == 0) ? reg10 : reg22;
1162 fma(useFma, ymm1, ymm2, fmareg);
1166 if (unroll_n >= 6) {
1169 ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]);
1171 vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
1173 fmareg = (i % 2 == 0) ? reg05 : reg17;
1174 fma(useFma, ymm0, ymm2, fmareg);
1175 if (unroll_m >= 16) {
1176 fmareg = (i % 2 == 0) ? reg11 : reg23;
1177 fma(useFma, ymm1, ymm2, fmareg);
1182 vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
1184 if (unroll_m >= 16) {
1186 + (unroll_m * 0 + 1 * 8 - OFFSET)
1190 sub(LDA4, -unroll_m * SIZE);
1194 if (isLoad1Unmasked) {
1195 vmovups(ymm0, ptr[AO1
1196 + (unroll_m * 1 + 0 * 8 - OFFSET)
1199 vmaskmovps(ymm0, VMASK,
1201 + (unroll_m * 1 + 0 * 8 - OFFSET)
1204 if (unroll_m >= 16) {
1205 if (isLoad2Unmasked) {
1208 + (unroll_m * 1 + 1 * 8 - OFFSET)
1211 vmaskmovps(ymm1, VMASK,
1213 + (unroll_m * 1 + 1 * 8 - OFFSET)
1217 sub(AO1, -unroll_m * SIZE);
1222 if (unroll_n >= 4) {
1232 // Inner kernel with k=1
1233 auto innerkernel1 = [&](int unroll_m, int unroll_n,
1234 bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
1235 bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
1236 Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
1237 Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11) {
1240 if (isLoad1Unmasked) {
1241 vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1243 vmaskmovps(ymm0, VMASK, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
1245 if (unroll_m >= 16) {
1246 if (isLoad2Unmasked) {
1247 vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1249 vmaskmovps(ymm1, VMASK,
1250 ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
1257 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1259 vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
1261 fma(useFma, ymm0, ymm2, reg00);
1262 if (unroll_m >= 16) {
1263 fma(useFma, ymm1, ymm2, reg06);
1266 if (unroll_n >= 2) {
1269 ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]);
1271 vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
1273 fma(useFma, ymm0, ymm2, reg01);
1274 if (unroll_m >= 16) {
1275 fma(useFma, ymm1, ymm2, reg07);
1279 if (unroll_n >= 3) {
1282 ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]);
1284 vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
1286 fma(useFma, ymm0, ymm2, reg02);
1287 if (unroll_m >= 16) {
1288 fma(useFma, ymm1, ymm2, reg08);
1292 if (unroll_n >= 4) {
1294 vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]);
1296 vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
1298 fma(useFma, ymm0, ymm2, reg03);
1299 if (unroll_m >= 16) {
1300 fma(useFma, ymm1, ymm2, reg09);
1304 if (unroll_n >= 5) {
1307 ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]);
1309 vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
1311 fma(useFma, ymm0, ymm2, reg04);
1312 if (unroll_m >= 16) {
1313 fma(useFma, ymm1, ymm2, reg10);
1317 if (unroll_n >= 6) {
1320 ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]);
1322 vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
1324 fma(useFma, ymm0, ymm2, reg05);
1325 if (unroll_m >= 16) {
1326 fma(useFma, ymm1, ymm2, reg11);
1331 vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
1333 if (unroll_m >= 16) {
1334 vmovups(ptr[LDA4 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE],
1337 sub(LDA4, -unroll_m * SIZE);
1341 if (isLoad1Unmasked) {
1343 ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]);
1345 vmaskmovps(ymm0, VMASK,
1346 ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]);
1348 if (unroll_m >= 16) {
1349 if (isLoad2Unmasked) {
1350 vmovups(ymm1, ptr[AO1
1351 + (unroll_m * 1 + 1 * 8 - OFFSET)
1354 vmaskmovps(ymm1, VMASK,
1356 + (unroll_m * 1 + 1 * 8 - OFFSET)
1360 sub(AO1, -unroll_m * SIZE);
1365 if (unroll_n >= 4) {
1374 // Main kernel; does prefetching and calls innerkernel{1,2,4,8} as
1376 // After calculating results in registers, writes back to C matrix
1377 auto kernel = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1378 bool isLoad2Unmasked, bool isDirect, bool isCopy, bool useFma,
1379 Ymm reg00 = Ymm(4), Ymm reg01 = Ymm(5), Ymm reg02 = Ymm(6),
1380 Ymm reg03 = Ymm(7), Ymm reg04 = Ymm(8), Ymm reg05 = Ymm(9),
1381 Ymm reg06 = Ymm(10), Ymm reg07 = Ymm(11), Ymm reg08 = Ymm(12),
1382 Ymm reg09 = Ymm(13), Ymm reg10 = Ymm(14), Ymm reg11 = Ymm(15),
1383 Ymm reg12 = Ymm(4), Ymm reg13 = Ymm(5), Ymm reg14 = Ymm(6),
1384 Ymm reg15 = Ymm(7), Ymm reg16 = Ymm(8), Ymm reg17 = Ymm(9),
1385 Ymm reg18 = Ymm(10), Ymm reg19 = Ymm(11), Ymm reg20 = Ymm(12),
1386 Ymm reg21 = Ymm(13), Ymm reg22 = Ymm(14), Ymm reg23 = Ymm(15)) {
1388 lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]);
1394 lea(LDA4, ptr[rsp + 256 + OFFSET * SIZE]);
1396 lea(LDA4, ptr[LDA * 8 + (8 - 1 - OFFSET) * SIZE]);
1400 lea(BO2, ptr[LDB * 4 + (8 - 1 - OFFSET) * SIZE]);
1401 lea(BO2, ptr[BO2 + LDB * 2]);
1405 if (isLoad1Unmasked) {
1407 ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]);
1409 vmaskmovps(ymm0, VMASK,
1410 ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]);
1412 if (unroll_m >= 16) {
1413 if (isLoad2Unmasked) {
1414 vmovups(ymm1, ptr[AO1
1415 + (unroll_m * 0 + 1 * 8 - OFFSET)
1418 vmaskmovps(ymm1, VMASK,
1420 + (unroll_m * 0 + 1 * 8 - OFFSET)
1426 for (int i = 4; i < 10; i++) {
1427 vxorps(Ymm(i), Ymm(i), Ymm(i));
1428 vxorps(Ymm(i + 6), Ymm(i + 6), Ymm(i + 6));
1434 Label kernel12, kernel13, kernel14, kernel15;
1435 Label kernel16, kernel17, kernel18;
1437 sub(LL, SECOND_FETCH);
1438 jle(kernel13, T_NEAR);
1442 innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1443 isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1444 reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1445 reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1446 reg21, reg22, reg23);
1447 jg(kernel12, T_NEAR);
1451 prefetcht0(ptr[CO1 + (unroll_m - 1) * SIZE]);
1453 prefetcht0(ptr[CO1 + LDC + (unroll_m - 1) * SIZE]);
1455 prefetcht0(ptr[CO1 + LDC * 2 + (unroll_m - 1) * SIZE]);
1457 prefetcht0(ptr[CO2 + (unroll_m - 1) * SIZE]);
1459 prefetcht0(ptr[CO2 + LDC + (unroll_m - 1) * SIZE]);
1461 prefetcht0(ptr[CO2 + LDC * 2 + (unroll_m - 1) * SIZE]);
1463 add(LL, SECOND_FETCH);
1464 jle(kernel15, T_NEAR);
1468 innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1469 isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1470 reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1471 reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1472 reg21, reg22, reg23);
1473 jg(kernel14, T_NEAR);
1478 jle(kernel16, T_NEAR);
1479 innerkernel4(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1480 isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1481 reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1482 reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1483 reg21, reg22, reg23);
1487 jle(kernel17, T_NEAR);
1488 innerkernel2(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1489 isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1490 reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
1491 reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
1492 reg21, reg22, reg23);
1496 if (unroll_m == 16) {
1497 if (unroll_n <= 3) {
1498 vaddps(reg00, reg00, reg12);
1499 vaddps(reg01, reg01, reg13);
1500 vaddps(reg02, reg02, reg14);
1501 vaddps(reg06, reg06, reg18);
1502 vaddps(reg07, reg07, reg19);
1503 vaddps(reg08, reg08, reg20);
1507 if (unroll_m <= 8) {
1508 vaddps(reg00, reg00, reg12);
1509 vaddps(reg01, reg01, reg13);
1510 vaddps(reg02, reg02, reg14);
1511 vaddps(reg03, reg03, reg15);
1512 vaddps(reg04, reg04, reg16);
1513 vaddps(reg05, reg05, reg17);
1517 jle(kernel18, T_NEAR);
1518 innerkernel1(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1519 isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
1520 reg05, reg06, reg07, reg08, reg09, reg10, reg11);
1524 vbroadcastss(VALPHA, ALPHA);
1527 vbroadcastss(VBETA, BETA);
1530 // Write back the results; all beta and bias cases need to be
1533 case 1: mov(rax, LDC); break;
1534 case 2: lea(rax, ptr[LDC * 2]); break;
1535 case 3: lea(rax, ptr[LDC + LDC * 2]); break;
1536 case 4: lea(rax, ptr[LDC + LDC * 4]); break;
1538 lea(rax, ptr[LDC * 4]);
1542 lea(rax, ptr[LDC + LDC * 2]);
1549 if (isLoad1Unmasked) {
1550 vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]);
1552 vmaskmovps(VBIAS1, VMASK, ptr[BIAS1 + 0 * SIZE]);
1556 for (int i = 0; i < unroll_n; i++) {
1557 vmulps(Ymm(i + 4), Ymm(i + 4), VALPHA);
1559 if (isLoad1Unmasked) {
1561 case 0: vmovups(ymm0, ptr[CO1 + 0 * SIZE]); break;
1562 case 1: vmovups(ymm0, ptr[CO1 + LDC + 0 * SIZE]); break;
1564 vmovups(ymm0, ptr[CO1 + LDC * 2 + 0 * SIZE]);
1566 case 3: vmovups(ymm0, ptr[CO2 + 0 * SIZE]); break;
1567 case 4: vmovups(ymm0, ptr[CO2 + LDC + 0 * SIZE]); break;
1569 vmovups(ymm0, ptr[CO2 + LDC * 2 + 0 * SIZE]);
// ---- kernel() epilogue (visible tail): merge the accumulator with the
// existing C tile and store it back.  First the low 8 rows of the tile
// (accumulators Ymm(i + 4)), then, for unroll_m >= 16, the high 8 rows
// (accumulators Ymm(i + 10)).  CO1/CO2 address two groups of C columns;
// LDC is the byte-scaled leading dimension; VMASK holds the M-tail mask.
// Masked loads of the current C column when the row block is partial.
1575 vmaskmovps(ymm0, VMASK, ptr[CO1 + 0 * SIZE]);
1578 vmaskmovps(ymm0, VMASK, ptr[CO1 + LDC + 0 * SIZE]);
1582 ymm0, VMASK, ptr[CO1 + LDC * 2 + 0 * SIZE]);
1585 vmaskmovps(ymm0, VMASK, ptr[CO2 + 0 * SIZE]);
1588 vmaskmovps(ymm0, VMASK, ptr[CO2 + LDC + 0 * SIZE]);
1592 ymm0, VMASK, ptr[CO2 + LDC * 2 + 0 * SIZE]);
// Merge old C into the result: plain add (presumably the beta == 1
// path) or multiply-accumulate with VBETA via the fma() helper --
// the guarding branches fall on lines elided from this excerpt.
1598 vaddps(Ymm(i + 4), ymm0, Ymm(i + 4));
1600 fma(useFma, VBETA, ymm0, Ymm(i + 4), true);
// Add the bias column held in VBIAS1 (hasBias path).
1604 vaddps(Ymm(i + 4), VBIAS1, Ymm(i + 4));
// Store back: unmasked vmovups when the first 8 rows are full; the
// case labels select the column (0..5) within the unroll_n tile.
1606 if (isLoad1Unmasked) {
1608 case 0: vmovups(ptr[CO1 + 0 * SIZE], Ymm(i + 4)); break;
1610 vmovups(ptr[CO1 + LDC + 0 * SIZE], Ymm(i + 4));
1613 vmovups(ptr[CO1 + LDC * 2 + 0 * SIZE], Ymm(i + 4));
1615 case 3: vmovups(ptr[CO2 + 0 * SIZE], Ymm(i + 4)); break;
1617 vmovups(ptr[CO2 + LDC + 0 * SIZE], Ymm(i + 4));
1620 vmovups(ptr[CO2 + LDC * 2 + 0 * SIZE], Ymm(i + 4));
// Otherwise masked stores so rows beyond M are left untouched.
1626 vmaskmovps(ptr[CO1 + 0 * SIZE], VMASK, Ymm(i + 4));
1630 ptr[CO1 + LDC + 0 * SIZE], VMASK, Ymm(i + 4));
1633 vmaskmovps(ptr[CO1 + LDC * 2 + 0 * SIZE], VMASK,
1637 vmaskmovps(ptr[CO2 + 0 * SIZE], VMASK, Ymm(i + 4));
1641 ptr[CO2 + LDC + 0 * SIZE], VMASK, Ymm(i + 4));
1644 vmaskmovps(ptr[CO2 + LDC * 2 + 0 * SIZE], VMASK,
// Second 8-row half (rows 8..15) for the 16-row kernels; mirrors the
// sequence above with Ymm(i + 10) accumulators and VBIAS2.
1650 if (unroll_m >= 16) {
1651 // Re-use ymm4 (VBIAS2)
1654 if (isLoad1Unmasked) {
1655 vmovups(VBIAS2, ptr[BIAS1 + 8 * SIZE]);
1658 VBIAS2, VMASK, ptr[BIAS1 + 8 * SIZE]);
// Scale the accumulator by alpha before merging with C.
1662 vmulps(Ymm(i + 10), Ymm(i + 10), VALPHA);
// Load old C for rows 8..15 (unmasked or masked, as above).
1664 if (isLoad2Unmasked) {
1666 case 0: vmovups(ymm0, ptr[CO1 + 8 * SIZE]); break;
1668 vmovups(ymm0, ptr[CO1 + LDC + 8 * SIZE]);
1671 vmovups(ymm0, ptr[CO1 + LDC * 2 + 8 * SIZE]);
1673 case 3: vmovups(ymm0, ptr[CO2 + 8 * SIZE]); break;
1675 vmovups(ymm0, ptr[CO2 + LDC + 8 * SIZE]);
1678 vmovups(ymm0, ptr[CO2 + LDC * 2 + 8 * SIZE]);
1684 vmaskmovps(ymm0, VMASK, ptr[CO1 + 8 * SIZE]);
1688 ymm0, VMASK, ptr[CO1 + LDC + 8 * SIZE]);
1691 vmaskmovps(ymm0, VMASK,
1692 ptr[CO1 + LDC * 2 + 8 * SIZE]);
1695 vmaskmovps(ymm0, VMASK, ptr[CO2 + 8 * SIZE]);
1699 ymm0, VMASK, ptr[CO2 + LDC + 8 * SIZE]);
1702 vmaskmovps(ymm0, VMASK,
1703 ptr[CO2 + LDC * 2 + 8 * SIZE]);
// Same beta merge and bias add as for the low half.
1708 vaddps(Ymm(i + 10), ymm0, Ymm(i + 10));
1710 fma(useFma, VBETA, ymm0, Ymm(i + 10), true);
1714 vaddps(Ymm(i + 10), VBIAS2, Ymm(i + 10));
// Store rows 8..15 back to C.
1716 if (isLoad2Unmasked) {
1719 vmovups(ptr[CO1 + 8 * SIZE], Ymm(i + 10));
1722 vmovups(ptr[CO1 + LDC + 8 * SIZE], Ymm(i + 10));
1725 vmovups(ptr[CO1 + LDC * 2 + 8 * SIZE], Ymm(i + 10));
1728 vmovups(ptr[CO2 + 8 * SIZE], Ymm(i + 10));
1731 vmovups(ptr[CO2 + LDC + 8 * SIZE], Ymm(i + 10));
1734 vmovups(ptr[CO2 + LDC * 2 + 8 * SIZE], Ymm(i + 10));
1740 vmaskmovps(ptr[CO1 + 8 * SIZE], VMASK, Ymm(i + 10));
1743 vmaskmovps(ptr[CO1 + LDC + 8 * SIZE], VMASK,
1747 vmaskmovps(ptr[CO1 + LDC * 2 + 8 * SIZE], VMASK,
1751 vmaskmovps(ptr[CO2 + 8 * SIZE], VMASK, Ymm(i + 10));
1754 vmaskmovps(ptr[CO2 + LDC + 8 * SIZE], VMASK,
1758 vmaskmovps(ptr[CO2 + LDC * 2 + 8 * SIZE], VMASK,
// ---- Advance the B panel pointers to the next unroll_n-wide panel.
1767 if (unroll_n >= 4) {
1771 // Compute next address of B
1773 lea(rax, ptr[K * SIZE]);
// One-column-stride path: BO1/BO2 step by whole columns in multiples
// of LDB (LDB3 = 3 * LDB, set up at generation start).  Which pair of
// leas runs depends on unroll_n; the selecting branches fall on lines
// elided from this excerpt.
1780 lea(BO1, ptr[BO1 + LDB * 2]);
1781 lea(BO2, ptr[BO2 + LDB * 2]);
1784 lea(BO1, ptr[BO1 + LDB3]);
1785 lea(BO2, ptr[BO2 + LDB3]);
1788 lea(BO1, ptr[BO1 + LDB * 4]);
1789 lea(BO2, ptr[BO2 + LDB * 4]);
1792 lea(BO1, ptr[BO1 + LDB * 4]);
1794 lea(BO2, ptr[BO2 + LDB * 4]);
1798 lea(BO1, ptr[BO1 + LDB3 * 2]);
1799 lea(BO2, ptr[BO2 + LDB3 * 2]);
// Contiguous path: step unroll_n floats (presumably the layout where
// consecutive panel entries are adjacent -- confirm against the
// elided branch above).
1808 add(BO1, unroll_n * SIZE);
// ---- Thin wrappers binding kernel() to each (unroll_m x unroll_n)
// shape.  The x6/x5/x4 variants use kernel()'s default accumulator
// register assignment with FMA enabled; the x3 variants pass an
// explicit Ymm register list (fewer live accumulators, so registers
// double up), and the x2/x1 variants reuse x3 with useFma = false.
1812 auto kernel_16x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1813 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1814 kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1815 isDirect, isCopy, true);
1818 auto kernel_16x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1819 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1820 kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1821 isDirect, isCopy, true);
1824 auto kernel_16x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1825 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1826 kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1827 isDirect, isCopy, true);
// 16-row, 3-column kernel with explicit register assignment.
1830 auto kernel_16x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1831 bool isLoad2Unmasked, bool isDirect, bool isCopy,
1832 bool useFma = true) {
1833 kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1834 isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
1835 Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1836 Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9),
1837 Ymm(13), Ymm(14), Ymm(15));
1840 auto kernel_16x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1841 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1842 kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1843 isDirect, isCopy, false);
1846 auto kernel_16x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1847 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1848 kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1849 isDirect, isCopy, false);
// 8-row variants mirror the 16-row family with their own register
// lists.
1852 auto kernel_8x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1853 bool isLoad2Unmasked, bool isDirect, bool isCopy,
1854 bool useFma = true) {
1855 kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1856 isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
1857 Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1858 Ymm(15), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1862 auto kernel_8x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1863 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1864 kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1868 auto kernel_8x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1869 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1870 kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1874 auto kernel_8x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1875 bool isLoad2Unmasked, bool isDirect, bool isCopy,
1876 bool useFma = true) {
1877 kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1878 isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
1879 Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
1880 Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9),
1881 Ymm(13), Ymm(14), Ymm(15));
1884 auto kernel_8x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1885 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1886 kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1887 isDirect, isCopy, false);
1890 auto kernel_8x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
1891 bool isLoad2Unmasked, bool isDirect, bool isCopy) {
1892 kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
1893 isDirect, isCopy, false);
1896 // High-level subroutine; does packing if needed, then splits C matrix.
1897 // Operates on chunks of 16 rows, 6 columns at a time (handling tail
1898 // cases appropriately).
1899 // Masking is used for tail cases where M is not divisible by 8.
1901 int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) {
1903 do_pack(unroll_m, isLoad1Unmasked, isLoad2Unmasked);
// Local jump labels: subloop2x = first-pass N-remainder dispatch,
// subloop3x = no-copy-path N-remainder dispatch, 98/99 = exits.
1906 Label subloop11, subloop11mask;
1907 Label subloop20, subloop21, subloop22, subloop23;
1908 Label subloop24, subloop25;
1909 Label subloop30, subloop31, subloop32, subloop33;
1910 Label subloop34, subloop35;
1911 Label subloop98, subloop98mask;
1912 Label subloop99, subloop99mask;
// CO1/CO2 address two groups of three C columns; C itself advances by
// one row-tile so the next subloop call starts at the next row block.
1915 lea(CO2, ptr[CO1 + LDC * 2]);
1917 add(C, unroll_m * SIZE);
1920 lea(BO2, qword[B + LDB3]);
// AA: auxiliary pointer into A (presumably a prefetch base -- its
// consumers are on elided lines; confirm before relying on this).
1924 lea(AA, ptr[A + (unroll_m * 2 - 1 - OFFSET) * SIZE]);
1926 jg(subloop98, T_NEAR);
1929 lea(AA, ptr[AA + (unroll_m - 1 - OFFSET) * SIZE]);
1936 // If N is too small, skip copy operation
1937 cmp(LL, UNROLL_N * 3);
1938 jle(subloop30, T_NEAR);
1940 // If A is not aligned to cache line
1942 je(subloop30, T_NEAR);
1945 jl(subloop20, T_NEAR);
// Copy path: first full-width tile reads A directly and packs it
// (isDirect = true, isCopy = true), subsequent tiles use the packed
// copy (false, false).
1950 if (unroll_m == 16) {
1951 kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1952 isLoad2Unmasked, true, true);
1954 kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1955 isLoad2Unmasked, true, true);
1958 if (unroll_m == 16) {
1959 kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1960 isLoad2Unmasked, false, false);
1962 kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1963 isLoad2Unmasked, false, false);
1969 jl(subloop20, T_NEAR);
// Main loop over full UNROLL_N-column tiles from the packed copy.
1973 if (unroll_m == 16) {
1974 kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
1975 isLoad2Unmasked, false, false);
1977 kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, isLoad2Unmasked,
1982 jge(subloop11, T_NEAR);
// N-remainder dispatch (copy path): handle 1..5 leftover columns.
1987 jne(subloop21, T_NEAR);
1988 if (unroll_m == 16) {
1989 kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
1992 kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, false,
1995 jmp(subloop99, T_NEAR);
2000 jne(subloop22, T_NEAR);
2001 if (unroll_m == 16) {
2002 kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
2005 kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, false,
2008 jmp(subloop99, T_NEAR);
2013 jne(subloop23, T_NEAR);
2014 if (unroll_m == 16) {
2015 kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
2018 kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, false,
2021 jmp(subloop99, T_NEAR);
2026 jne(subloop24, T_NEAR);
2027 if (unroll_m == 16) {
2028 kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
2031 kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, false,
2034 jmp(subloop99, T_NEAR);
2039 jne(subloop99, T_NEAR);
2040 if (unroll_m == 16) {
2041 kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
2044 kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, false,
2047 jmp(subloop99, T_NEAR);
// No-copy path: read A directly for every tile (isDirect = true).
2053 jl(subloop25, T_NEAR);
2057 if (unroll_m == 16) {
2058 kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
2059 isLoad2Unmasked, true, false);
2061 kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
2062 isLoad2Unmasked, true, false);
2066 jge(subloop31, T_NEAR);
// N-remainder dispatch (no-copy path), again for 1..5 columns.
2071 jne(subloop32, T_NEAR);
2072 if (unroll_m == 16) {
2073 kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
2076 kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
2079 jmp(subloop99, T_NEAR);
2084 jne(subloop33, T_NEAR);
2085 if (unroll_m == 16) {
2086 kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
2089 kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
2092 jmp(subloop99, T_NEAR);
2097 jne(subloop34, T_NEAR);
2098 if (unroll_m == 16) {
2099 kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
2102 kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
2105 jmp(subloop99, T_NEAR);
2110 jne(subloop35, T_NEAR);
2111 if (unroll_m == 16) {
2112 kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
2115 kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
2118 jmp(subloop99, T_NEAR);
2123 jne(subloop99, T_NEAR);
2124 if (unroll_m == 16) {
2125 kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
2128 kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
2135 // Compute address for A
// Advance A (and BIAS) past the row block just processed.
2137 add(A, unroll_m * SIZE);
2140 imul(rax, rax, unroll_m);
2144 // Compute next address of BIAS
2146 add(BIAS, unroll_m * SIZE);
// ---- Emitted-function prologue and top-level dispatch: set up the
// packing buffer, load scalar arguments, pre-scale leading dimensions
// to bytes, build the tail-mask table, then run subloop() over the
// M-tail variants.  Finally materialize the generated code (getCode).
2152 Label buffer_in_ws, buffer_allocated;
2154 // Get the registers
// Load alpha (and, via r15, presumably the beta pointer -- the mov
// that sets r15 is on an elided line) before rsp is repurposed.
2164 vmovss(xmm0, ptr[ARG_ALPHA]);
2165 vmovss(xmm1, ptr[r15]);
// Small K: the pack buffer fits on the stack; otherwise jump to the
// caller-provided workspace path (buffer_in_ws).
2172 cmp(K, STACK_K_CAPACITY);
2173 jg(buffer_in_ws, T_NEAR);
2175 // Create buffer and align to 4kB page
2176 lea(rax, ptr[K * SIZE]);
2180 and_(rsp, -PAGE_4K);
2181 jmp(buffer_allocated, T_NEAR);
2186 L(buffer_allocated);
2194 vmovss(ALPHA, xmm0);
// sub(reg, -x) == add(reg, x): bias A/B by the OFFSET elements the
// kernels subtract when addressing.
2196 sub(A, -OFFSET * SIZE);
2197 sub(B, -OFFSET * SIZE);
// Convert element strides to byte strides; LDB3 = 3 * LDB.
2199 sal(LDA, BASE_SHIFT);
2200 sal(LDB, BASE_SHIFT);
2201 sal(LDC, BASE_SHIFT);
2202 lea(LDB3, ptr[LDB + LDB * 2]);
// Integer ramp 0..7 at [rsp + 88]: compare operand for building the
// M-tail masks with vpcmpgtd below (apparently addressed via MASK).
2204 for (int i = 0; i < 8; i++) {
2205 mov(dword[rsp + 88 + i * 4], i);
// Transposed A on AVX2: precompute a per-lane stride vector (STRIDE),
// presumably used as gather indices -- confirm against the elided
// kernel body.
2208 if (isTransA && is_avx2) {
2210 vpbroadcastq(ymm1, xmm0);
2211 vinsertf128(ymm0, ymm0, xmm0, 1);
2212 vpermilpd(ymm0, ymm0, 5);
2213 vpaddq(ymm1, ymm1, ymm1);
2214 vperm2f128(ymm1, ymm1, ymm1, 8);
2215 vpaddq(ymm0, ymm0, ymm1);
2216 vmovups(STRIDE, ymm0);
2219 // Check A alignment and leading dimension; take copy-based path as
// Dispatch on M: full UNROLL_M rows, 16 rows with a tail mask,
// plain 8 rows, or 8 rows with a tail mask.
2226 Label main0, main1, main2, main3, main999;
2233 subloop(UNROLL_M, true, true);
2241 jle(main999, T_NEAR);
// AVX2 path: one 256-bit compare (M > lane index) builds the mask.
2248 vbroadcastss(VMASK, M);
2249 vpcmpgtd(VMASK, VMASK, MASK);
2251 subloop(16, true, false);
2252 jmp(main999, T_NEAR);
2258 subloop(8, true, true);
2259 jmp(main999, T_NEAR);
2265 vbroadcastss(VMASK, M);
2267 vpcmpgtd(VMASK, VMASK, MASK);
// AVX-only path: vpcmpgtd is 128-bit, so build the 8-lane mask from
// two halves (low lanes vs MASK, high lanes vs MASK + 4) and re-join.
2269 auto xmask = Xmm(VMASK.getIdx());
2270 auto xmm_tmp = xmm4;
2272 vextractf128(xmm_tmp, VMASK, 1);
2273 vpcmpgtd(xmask, xmask, MASK);
2274 vpcmpgtd(xmm_tmp, xmm_tmp, dword[rsp + 88 + 4 * 4]); // MASK + 4
2275 vinsertf128(VMASK, VMASK, xmm_tmp, 1);
2277 subloop(8, false, false);
2281 // Restore original stack
// Generation done: JIT-assemble and keep the callable entry point.
2287 ker_ = this->getCode<ker_t>();
// Signature of the generated code: a plain-C entry point taking the
// full sgemm argument set plus a bias vector and a scratch workspace.
2290 typedef void (*ker_t)(dim_t m, dim_t n, dim_t k,
2291 const float *alpha, const float *a, dim_t lda,
2292 const float *b, dim_t ldb, const float *beta, float *c,
2293 dim_t ldc, const float *bias, float *ws);
// Invoke the JIT-compiled kernel (ker_ is set via getCode() once the
// constructor finishes emitting code).  `ws` is scratch space for the
// packed-copy path.
2295 void operator()(dim_t m, dim_t n, dim_t k,
2296 const float *alpha, const float *a, dim_t lda,
2297 const float *b, dim_t ldb, const float *beta, float *c,
2298 dim_t ldc, const float *bias, float *ws) const
2300 ker_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws);
// Returns the JIT-generated sgemm kernel matching the given
// transA/transB/beta/bias configuration.  All supported kernels are
// built exactly once per process (std::call_once) and intentionally
// never freed -- a process-lifetime cache.
2307 const xbyak_gemm *get_xbyak_gemm(
2308 bool isTransA, bool isTransB, float beta, bool hasBias) {
// Map beta onto one of three table slots: 0.0, 1.0, or "other".
2309 auto beta_idx = [](float beta) {
2310 return (beta == 0.0) ? 0 : (beta == 1.0 ? 1 : 2);
2313 // Kernel table [isTransA][isTransB][hasBias][beta (0, 1, other)]
2314 static xbyak_gemm *kernel_table[2][2][2][3];
2315 static std::once_flag initialized;
2316 std::call_once(initialized, [=]{
2317 for (bool isTransA: {false, true})
2318 for (bool isTransB: {false, true})
2319 for (bool hasBias: {false, true})
2320 for (float beta: {0.0f, 1.0f, 2.0f}) {
2321 // nocopy sgemm with bias for beta != 0.0 is not supported
// (these table entries stay nullptr and must never be requested).
2322 if (hasBias && beta != 0.0)
2324 kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)] =
2325 new xbyak_gemm(isTransA, isTransB, beta, hasBias);
2329 return kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)];
// Single-threaded blocked SGEMM driver over the JIT kernels:
// C := alpha * op(A) * op(B) + beta * C (+ bias), column-major.
// Splits the (m x n x k) problem into cache-friendly tiles and
// dispatches each tile to the kernel matching its beta/bias needs.
2332 void sgemm_nocopy_driver(const char *transa,
2333 const char *transb, int m, int n, int k, const float *alpha,
2334 const float *a, dim_t lda, const float *b, dim_t ldb, const float *beta,
2335 float *c, dim_t ldc, const float *bias, float *ws)
2337 bool isTransA = (*transa == 'T' || *transa == 't');
2338 bool isTransB = (*transb == 'T' || *transb == 't');
2340 int Bm, sizeM, Bn, sizeN, Bk, sizeK;
// Degenerate cases: empty output -> nothing to do ...
2344 if ((m <= 0) || (n <= 0))
// ... and for k <= 0 or alpha == 0 only the beta scaling of C remains.
2347 if ((k <= 0) || (alpha[0] == 0.)) {
2349 if (beta[0] == 0.) {
2350 for (j = 0; j < n; j++)
2351 for (i = 0; i < m; i++)
2352 c[i + j * ldc] = 0.0;
2353 } else if (beta[0] != 1.) {
2354 for (j = 0; j < n; j++)
2355 for (i = 0; i < m; i++)
2356 c[i + j * ldc] *= beta[0];
// Kernels combining bias with beta != 0 are never generated (see
// get_xbyak_gemm), so the caller must not request that combination.
2362 assert(IMPLICATION(bias != nullptr, *beta == 0.0));
2364 // XXX: this happens on every thread...
2365 bool hasBias = (bias != nullptr);
2366 auto ker_bn = get_xbyak_gemm(isTransA, isTransB, *beta, hasBias);
2367 auto ker_b1 = get_xbyak_gemm(isTransA, isTransB, 1.0, false);
2368 auto ker_b0 = get_xbyak_gemm(isTransA, isTransB, 0.0, false);
2369 assert(ker_bn && ker_b1 && ker_b0);
// Tile sizes tuned per transposition mode.
2372 int BN = isTransA ? 96 : 48;
2373 int BK = isTransB ? 96 : 256;
2374 const float *curA, *curB, *curBias = nullptr;
// Tile loops over K, M, N.  An oversized tail tile is split into two
// roughly equal halves to keep tile sizes balanced.
2377 for (Bk = 0; Bk < k; Bk += sizeK) {
2379 if (sizeK >= BK * 2)
2383 sizeK = (sizeK + 1) / 2;
2386 for (Bm = 0; Bm < m; Bm += sizeM) {
2388 if (sizeM >= BM * 2)
2391 if (sizeM > BM + BM / 2)
2392 sizeM = (sizeM + 1) / 2;
2395 for (Bn = 0; Bn < n; Bn += sizeN) {
2397 if (sizeN >= BN * 2)
2400 if (sizeN > BN + BN / 2)
2401 sizeN = (sizeN + 1) / 2;
// Tile base pointers, honoring A/B transposition.
2405 curA = a + Bm + Bk * lda;
2407 curA = a + Bk + Bm * lda;
2410 curB = b + Bk + Bn * ldb;
2412 curB = b + Bn + Bk * ldb;
2414 curC = c + Bm + (size_t)Bn * ldc;
2415 if (bias != nullptr) {
2417 curBias = bias + Bm;
// Kernel choice: presumably the first K-slice applies the caller's
// beta (ker_b0 / ker_bn) and later K-slices accumulate with
// ker_b1 -- the Bk-dependent branches fall on elided lines.
2423 if (*beta == 0.0 && bias == nullptr)
2424 (*ker_b0)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
2425 alpha, curA, lda, curB, ldb, beta, curC, ldc,
2428 (*ker_bn)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
2429 alpha, curA, lda, curB, ldb, beta, curC, ldc,
2432 (*ker_b1)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
2433 alpha, curA, lda, curB, ldb, beta, curC, ldc,
// Public entry point for the AVX/AVX2 no-copy SGEMM:
//   C := alpha * op(A) * op(B) + beta * C (+ bias)
// Column-major storage with leading dimensions *p_lda / *p_ldb /
// *p_ldc; scalar arguments arrive by pointer (BLAS-style calling
// convention).
2443 mkldnn_status_t jit_avx_gemm_f32(
2444 const char *transa, const char *transb,
2445 const int *p_m, const int *p_n, const int *p_k, const float *p_alpha,
2446 const float *A, const int *p_lda, const float *B, const int *p_ldb,
2447 const float *p_beta, float *C, const int *p_ldc, const float *bias)
2449 using namespace mkldnn::impl::utils;
2450 using namespace avx_gemm_f32;
2451 using namespace gemm_utils;
2453 if (*p_beta != 0 && bias)
2454 return ref_gemm(transa, transb, p_m, p_n, p_k,
2455 p_alpha, A, p_lda, B, p_lda, p_beta, C, p_ldc, bias);
// Use all threads unless we are already inside a parallel region, in
// which case run single-threaded.
2457 int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
2465 float beta = *p_beta;
2468 int nthr_m, nthr_n, nthr_k, nthr_mn;
2470 // Determine threading partitioning
2471 calc_nthr_nocopy_avx(
2472 m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
// Splitting along K needs cross-thread synchronization below.
2473 assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
2475 // May not happen, but just in case
2476 if (nthr < nthr_m * nthr_n * nthr_k)
2477 nthr = nthr_m * nthr_n * nthr_k;
2479 nthr_mn = nthr_m * nthr_n;
// Per-thread "partial result ready" flags, spaced one cache line
// apart to avoid false sharing.
2481 unsigned char * ompstatus_ = nullptr;
2482 unsigned char volatile *ompstatus = nullptr;
2484 float *c_buffers = nullptr;
2485 float *ws_buffers = nullptr;
2488 ompstatus_ = (unsigned char *) malloc(
2489 nthr * CACHE_LINE_SIZE,
2491 ompstatus = (unsigned char volatile *) ompstatus_;
2494 for (int i = 0; i < nthr; i++)
2495 ompstatus[i * CACHE_LINE_SIZE] = 0;
// When K is split, every K-slice beyond the first accumulates into a
// private MB x NB scratch tile, summed into C afterwards.
2497 c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB
2498 * sizeof(float), PAGE_4K);
// Packing workspace: heap-allocated only when K exceeds what the
// generated kernel can place on its own stack (STACK_K_CAPACITY).
2501 const size_t ws_elems_per_thr = (size_t)k * 16 + 64;
2502 const size_t ws_size_per_thr
2503 = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
2504 if (k > STACK_K_CAPACITY) {
2505 ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K);
2508 parallel_nd(nthr, [&](const int ithr) {
2509 int ithr_m, ithr_n, ithr_k, ithr_mn;
2510 int m_from, m_to, myM;
2511 int n_from, n_to, myN;
2512 int k_from, k_to, myK;
2514 const float *myA, *myB, *myBias = nullptr;
2515 float *myC = C, myBeta;
2516 float *ws = ws_buffers ?
2517 ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0;
// If the runtime supplies fewer threads than planned, the in-loop
// spin-wait reduction would deadlock; defer it to the second
// parallel section below.
2520 int sum_later = (mkldnn_get_num_threads() < nthr_m * nthr_n * nthr_k);
2522 if (ithr < nthr_m * nthr_n * nthr_k) {
// Decompose the linear thread id into an (m, n, k) grid position.
2524 ithr_mn = ithr % nthr_mn;
2525 ithr_m = ithr_mn % nthr_m;
2526 ithr_n = ithr_mn / nthr_m;
2527 ithr_k = ithr / nthr_mn;
2529 /* swap ithr_k for performance improvement */
2531 ithr_k = nthr_k - 1;
2532 else if (ithr_k == nthr_k - 1)
// This thread's [m_from, m_to) x [n_from, n_to) x [k_from, k_to)
// sub-problem (clamping to m/n/k happens on elided lines).
2535 m_from = MB * (ithr_m);
2536 m_to = MB * (ithr_m + 1);
2539 myM = m_to - m_from;
2541 n_from = NB * (ithr_n);
2542 n_to = NB * (ithr_n + 1);
2545 myN = n_to - n_from;
2547 k_from = KB * (ithr_k);
2548 k_to = KB * (ithr_k + 1);
2551 myK = k_to - k_from;
// cbase/ibase index this (m, n) cell's scratch tiles and status
// flags across the K dimension.
2553 cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
2554 ibase = (ithr_m + nthr_m * ithr_n) * nthr_k;
2556 if ((myM > 0) && (myN > 0)) {
// Locate this thread's A/B/C sub-blocks, honoring transposition.
2558 if (*transa == 'N' || *transa == 'n') {
2559 myA = &(A[m_from + k_from * lda]);
2561 myA = &(A[k_from + m_from * lda]);
2563 if (*transb == 'N' || *transb == 'n') {
2564 myB = &(B[k_from + n_from * ldb]);
2566 myB = &(B[n_from + k_from * ldb]);
// K-slice 0 writes straight into C; other slices write into their
// private scratch tile (summed into C below).
2569 myC = &(C[m_from + n_from * ldc]);
2573 myBias = &(bias[m_from]);
2575 myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1);
2581 sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA,
2582 lda, myB, ldb, &myBeta, myC, ld, myBias, ws);
// Publish "my partial tile is ready" for consumers spinning below.
2584 if (nthr_k > 1 && !sum_later)
2585 ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1;
// In-loop reduction: sum the K-partial scratch tiles into C, spinning
// on each producer's status flag.
2588 if (nthr_k > 1 && !sum_later) {
2590 // sum matrices partitioned along K dimension
2593 partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
2597 myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
2599 /* need to wait until main thread finishes */
2600 while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) {
2603 /* my cache is hot */
2604 sum_two_matrices(myM, n2, myC, MB,
2605 &C[m_from + (n_from + n1) * ldc], ldc);
2608 for (int ik = 1; ik < nthr_k; ++ik) {
2611 myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
2614 while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) {
2617 sum_two_matrices(myM, n2, myC, MB,
2618 &C[m_from + (n_from + n1) * ldc], ldc);
2625 // handle C summation later
// Deferred reduction pass for the sum_later case: repeat the grid
// partitioning math and fold the scratch tiles into C here, after
// all producers have finished.
2626 if (nthr_k > 1 && ompstatus[0] == 0) {
2628 parallel_nd(nthr, [&](const int ithr) {
2629 int ithr_m, ithr_n, ithr_k, ithr_mn;
2630 int m_from, m_to, myM;
2631 int n_from, n_to, myN;
2635 if (ithr < nthr_m * nthr_n * nthr_k) {
2637 ithr_mn = ithr % nthr_mn;
2638 ithr_m = ithr_mn % nthr_m;
2639 ithr_n = ithr_mn / nthr_m;
2640 ithr_k = ithr / nthr_mn;
2642 /* swap ithr_k for performance improvement */
2644 ithr_k = nthr_k - 1;
2645 else if (ithr_k == nthr_k - 1)
2648 m_from = MB * (ithr_m);
2649 m_to = MB * (ithr_m + 1);
2652 myM = m_to - m_from;
2654 n_from = NB * (ithr_n);
2655 n_to = NB * (ithr_n + 1);
2658 myN = n_to - n_from;
2660 cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
2663 // sum matrices partitioned along K dimension
2666 partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
2670 myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
2673 /* my cache is hot */
2674 sum_two_matrices(myM, n2, myC, MB,
2675 &C[m_from + (n_from + n1) * ldc], ldc);
2678 for (int ik = 1; ik < nthr_k; ++ik) {
2681 myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
2684 sum_two_matrices(myM, n2, myC, MB,
2685 &C[m_from + (n_from + n1) * ldc], ldc);
2698 return mkldnn_success;
2705 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s