/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <cassert>
#include <mutex>

#include "mkldnn_thread.hpp"

#include "ref_gemm_f32.hpp"
#include "gemm_utils_f32.hpp"
#include "jit_avx512_common_gemm_f32.hpp"

#include "jit_generator.hpp"
#define CACHE_LINE_SIZE 64

#define STACKSIZE get_size_of_abi_save_regs()
#ifdef _WIN32
#define STACK_K_CAPACITY 32
#else
#define STACK_K_CAPACITY 2048
#endif
#define SIZE 4 // sizeof(float); all displacements below are in f32 elements
#define OFFSET 128
#define BASE_SHIFT 2 // log2(SIZE), used to scale leading dimensions to bytes
#define SECOND_FETCH unroll_n
#define UNROLL_M 48
#define UNROLL_N 8
namespace avx512_common_gemm_f32 {
using namespace gemm_utils;
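
// Single-threaded AVX-512 SGEMM kernel generated at runtime with Xbyak.
// Each generated function computes C = alpha * op(A) * op(B) + beta * C
// (plus an optional bias row) for one fixed (isTransA, isTransB, beta,
// hasBias) configuration; cache blocking and threading are handled by the
// driver routines further down in this file.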
struct xbyak_gemm : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_gemm_f32_xbyak_gemm)

    xbyak_gemm(char isTransA, char isTransB, float beta, bool hasBias = false,
            void *code_ptr = nullptr,
            size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE)
        : jit_generator(code_ptr, code_size)
    {
        using namespace Xbyak;

        enum { ver_avx512_core, ver_avx512_mic } ver =
                mayiuse(avx512_core) ? ver_avx512_core : ver_avx512_mic;

        bool isBeta0 = (beta == 0.0);
        bool isBetaN = (!isBeta0 && beta != 1.0);
        // various definitions for convenience
        auto ARG_M = abi_param1;
        auto ARG_N = abi_param2;
        auto K = abi_param3;
        auto ARG_ALPHA = abi_param4;
#ifdef _WIN32
        auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE];
        auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE
                + sizeof(float *) + STACKSIZE];
        const auto stackOffset = OFFSET_SHADOWSPACE
                + sizeof(float *) + STACKSIZE;
#else
        const auto stackOffset = STACKSIZE;
#endif
        auto ARG_B = ptr[rsp + 8 + stackOffset];
        auto ARG_LDB = ptr[rsp + 16 + stackOffset];
        auto ARG_BETA = ptr[rsp + 24 + stackOffset];
        auto ARG_C = ptr[rsp + 32 + stackOffset];
        auto ARG_LDC = ptr[rsp + 40 + stackOffset];
        auto ARG_BIAS = ptr[rsp + 48 + stackOffset];
        auto ARG_WS = ptr[rsp + 56 + stackOffset];
        // argument registers are reused once the arguments are stashed
        auto AO1 = abi_param2;
        auto BO1 = abi_param4;
        auto LDA4 = abi_param1;
        auto BIAS1 = abi_param1;

        // on-stack scratch variables
        auto M = qword[rsp + 0];
        auto N = qword[rsp + 8];
        auto FLAG = qword[rsp + 16];
        auto I = qword[rsp + 24];
        auto C = qword[rsp + 32];
        auto BIAS = qword[rsp + 40];
        auto ALPHA = qword[rsp + 48];
        auto BETA = qword[rsp + 64];
        auto ORIG_A = qword[rsp + 80];
        auto ORIG_SP = qword[rsp + 120];
        auto PREFETCHSIZEA = ver == ver_avx512_core ? 48 : 80;
        auto PREFETCHSIZEB = 16;

        Zmm regs[] = { zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15,
                zmm16, zmm17, zmm18, zmm19, zmm20, zmm21, zmm22, zmm23, zmm24,
                zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31 };
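
        // Accumulator layout: regs[j] holds rows 0..15 of column j of the
        // C tile, regs[j + 8] rows 16..31, and regs[j + 16] rows 32..47,
        // so a full 48x8 tile occupies all 24 registers zmm8..zmm31.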
        // Function for packing if needed
        auto do_pack = [&](int unroll_m) {
            Label pack2, pack3, pack4, pack10;

            lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]);
            // non-transposed A: masked copy, four columns per iteration
            for (int i = 0; i < 4; i++) {
                vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]);
                vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]);
                vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]);
                vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE]
                                | k1,
                        zmm0);
                vmovups(ptr[AO1 + (unroll_m * i + 1 * 16 - OFFSET) * SIZE]
                                | k2,
                        zmm1);
                vmovups(ptr[AO1 + (unroll_m * i + 2 * 16 - OFFSET) * SIZE]
                                | k3,
                        zmm2);
            }
            // transposed A: gather eight rows at a time with vgatherqps
            // (re-arming k4 before each gather, since gathers consume it),
            // then merge the two 8-element halves with vshuff64x2
            for (int i = 0; i < 4; i++) {
                kxnorw(k4, k4, k4);
                vgatherqps(ymm5 | k4,
                        ptr[BO1 + ZSTRIDE + (i - OFFSET) * SIZE]);
                lea(BO2, ptr[BO1 + LDA * 8]);
                kxnorw(k4, k4, k4);
                vgatherqps(ymm6 | k4,
                        ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
                vshuff64x2(zmm0, zmm5, zmm6, 0x44);

                lea(BO2, ptr[BO2 + LDA * 8]);
                kxnorw(k4, k4, k4);
                vgatherqps(ymm5 | k4,
                        ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
                lea(BO2, ptr[BO2 + LDA * 8]);
                kxnorw(k4, k4, k4);
                vgatherqps(ymm6 | k4,
                        ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
                vshuff64x2(zmm1, zmm5, zmm6, 0x44);

                lea(BO2, ptr[BO2 + LDA * 8]);
                kxnorw(k4, k4, k4);
                vgatherqps(ymm5 | k4,
                        ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
                lea(BO2, ptr[BO2 + LDA * 8]);
                kxnorw(k4, k4, k4);
                vgatherqps(ymm6 | k4,
                        ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
                lea(BO2, ptr[BO2 + LDA * 8]);
                vshuff64x2(zmm2, zmm5, zmm6, 0x44);

                vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE],
                        zmm0);
                vmovups(ptr[AO1 + (unroll_m * i + 1 * 16 - OFFSET) * SIZE],
                        zmm1);
                vmovups(ptr[AO1 + (unroll_m * i + 2 * 16 - OFFSET) * SIZE],
                        zmm2);
            }
            add(AO1, unroll_m * 4 * SIZE);
            // remainder: one column at a time
            vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]);
            vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]);
            vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]);

            kxnorw(k4, k4, k4);
            vgatherqps(ymm5 | k4, ptr[BO1 + ZSTRIDE + (0 - OFFSET) * SIZE]);
            lea(BO2, ptr[BO1 + LDA * 8]);
            kxnorw(k4, k4, k4);
            vgatherqps(ymm6 | k4, ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
            vshuff64x2(zmm0, zmm5, zmm6, 0x44);

            lea(BO2, ptr[BO2 + LDA * 8]);
            kxnorw(k4, k4, k4);
            vgatherqps(ymm5 | k4,
                    ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
            lea(BO2, ptr[BO2 + LDA * 8]);
            kxnorw(k4, k4, k4);
            vgatherqps(ymm6 | k4,
                    ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
            vshuff64x2(zmm1, zmm5, zmm6, 0x44);

            lea(BO2, ptr[BO2 + LDA * 8]);
            kxnorw(k4, k4, k4);
            vgatherqps(ymm5 | k4,
                    ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
            lea(BO2, ptr[BO2 + LDA * 8]);
            kxnorw(k4, k4, k4);
            vgatherqps(ymm6 | k4,
                    ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
            lea(BO2, ptr[BO2 + LDA * 8]);
            vshuff64x2(zmm2, zmm5, zmm6, 0x44);

            vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE],
                    zmm0);
            vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 16 - OFFSET) * SIZE],
                    zmm1);
            vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 16 - OFFSET) * SIZE],
                    zmm2);

            add(AO1, unroll_m * SIZE);
        };
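
        // Net effect of do_pack: one unroll_m-row panel of A is copied (or
        // gathered, when A is transposed) into the aligned stack buffer so
        // the compute loops can stream it with contiguous vmovups loads.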
        // Function to update C, covering masking and other considerations
        auto update = [&](Zmm reg, bool useCO1, int offset, int mask,
                bool useScale = false) {
            vmulps(reg, reg, VALPHA);
            // Exactly one variant from each group below is emitted,
            // selected by useCO1 (write through CO1 vs. CO2), mask
            // (0: full 16 rows, 1/2/3: row tail under k1/k2/k3) and
            // useScale (second column of the pair, at base + LDC).

            // beta != 0: read the current C values back ...
            vmovups(zmm0, ptr[CO1 + offset * SIZE]);
            vmovups(zmm0, ptr[CO2 + offset * SIZE]);
            vmovups(zmm0 | k1 | T_z, ptr[CO1 + offset * SIZE]);
            vmovups(zmm0 | k1 | T_z, ptr[CO2 + offset * SIZE]);
            vmovups(zmm0 | k2 | T_z, ptr[CO1 + offset * SIZE]);
            vmovups(zmm0 | k2 | T_z, ptr[CO2 + offset * SIZE]);
            vmovups(zmm0 | k3 | T_z, ptr[CO1 + offset * SIZE]);
            vmovups(zmm0 | k3 | T_z, ptr[CO2 + offset * SIZE]);

            vmovups(zmm0, ptr[CO1 + LDC + offset * SIZE]);
            vmovups(zmm0, ptr[CO2 + LDC + offset * SIZE]);
            vmovups(zmm0 | k1 | T_z, ptr[CO1 + LDC + offset * SIZE]);
            vmovups(zmm0 | k1 | T_z, ptr[CO2 + LDC + offset * SIZE]);
            vmovups(zmm0 | k2 | T_z, ptr[CO1 + LDC + offset * SIZE]);
            vmovups(zmm0 | k2 | T_z, ptr[CO2 + LDC + offset * SIZE]);
            vmovups(zmm0 | k3 | T_z, ptr[CO1 + LDC + offset * SIZE]);
            vmovups(zmm0 | k3 | T_z, ptr[CO2 + LDC + offset * SIZE]);

            // ... combine: beta == 1 adds, other nonzero beta folds beta
            // in with an FMA against the broadcast VBETA ...
            vaddps(zmm0, reg, zmm0);
            vfmadd132ps(zmm0, reg, VBETA);

            // ... and store the merged result
            vmovups(ptr[CO1 + offset * SIZE], zmm0);
            vmovups(ptr[CO2 + offset * SIZE], zmm0);
            vmovups(ptr[CO1 + offset * SIZE], zmm0 | k1);
            vmovups(ptr[CO2 + offset * SIZE], zmm0 | k1);
            vmovups(ptr[CO1 + offset * SIZE], zmm0 | k2);
            vmovups(ptr[CO2 + offset * SIZE], zmm0 | k2);
            vmovups(ptr[CO1 + offset * SIZE], zmm0 | k3);
            vmovups(ptr[CO2 + offset * SIZE], zmm0 | k3);

            vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0);
            vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0);
            vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k1);
            vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k1);
            vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k2);
            vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k2);
            vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k3);
            vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k3);

            // beta == 0: store the alpha-scaled accumulator directly
            vmovups(ptr[CO1 + offset * SIZE], reg);
            vmovups(ptr[CO2 + offset * SIZE], reg);
            vmovups(ptr[CO1 + offset * SIZE], reg | k1);
            vmovups(ptr[CO2 + offset * SIZE], reg | k1);
            vmovups(ptr[CO1 + offset * SIZE], reg | k2);
            vmovups(ptr[CO2 + offset * SIZE], reg | k2);
            vmovups(ptr[CO1 + offset * SIZE], reg | k3);
            vmovups(ptr[CO2 + offset * SIZE], reg | k3);

            vmovups(ptr[CO1 + LDC + offset * SIZE], reg);
            vmovups(ptr[CO2 + LDC + offset * SIZE], reg);
            vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k1);
            vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k1);
            vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k2);
            vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k2);
            vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k3);
            vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k3);

            // clear the accumulator for the next tile
            vpxorq(reg, reg, reg);
        };
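
        // update() realizes all three beta cases prepared above: with
        // isBeta0 the scaled tile is stored directly, with beta == 1 the
        // old C is added via vaddps, and with isBetaN it is merged via
        // vfmadd132ps against VBETA, always through the same masked-store
        // variants.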
        // Loop with unroll_n - 2 FMAs; called by innerkernel
        auto fmaloop = [&](int unroll_m, int unroll_n, int iteration) {
            for (int i = 2; i < unroll_n; i++) {
                if (ver == ver_avx512_core) {
                    // broadcast B(iteration, i) once and reuse it
                    if (!isTransB) {
                        switch (i) {
                        case 2: vbroadcastss(zmm3, ptr[BO1 + LDB * 2 + (iteration - OFFSET) * SIZE]); break;
                        case 3: vbroadcastss(zmm3, ptr[BO1 + LDB3 + (iteration - OFFSET) * SIZE]); break;
                        case 4: vbroadcastss(zmm3, ptr[BO2 + (iteration - OFFSET) * SIZE]); break;
                        case 5: vbroadcastss(zmm3, ptr[BO2 + LDB * 1 + (iteration - OFFSET) * SIZE]); break;
                        case 6: vbroadcastss(zmm3, ptr[BO2 + LDB * 2 + (iteration - OFFSET) * SIZE]); break;
                        case 7: vbroadcastss(zmm3, ptr[BO2 + LDB3 + (iteration - OFFSET) * SIZE]); break;
                        }
                    } else {
                        vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]);
                    }
                    vfmadd231ps(regs[i], zmm3, zmm0);
                    if (unroll_m >= 32)
                        vfmadd231ps(regs[i + 8], zmm3, zmm1);
                    if (unroll_m >= 48)
                        vfmadd231ps(regs[i + 16], zmm3, zmm2);
                } else {
                    // avx512_mic: embedded-broadcast operand form
                    RegExp bexp;
                    if (!isTransB) {
                        switch (i) {
                        case 2: bexp = BO1 + LDB * 2 + (iteration - OFFSET) * SIZE; break;
                        case 3: bexp = BO1 + LDB3 + (iteration - OFFSET) * SIZE; break;
                        case 4: bexp = BO2 + (iteration - OFFSET) * SIZE; break;
                        case 5: bexp = BO2 + LDB * 1 + (iteration - OFFSET) * SIZE; break;
                        case 6: bexp = BO2 + LDB * 2 + (iteration - OFFSET) * SIZE; break;
                        case 7: bexp = BO2 + LDB3 + (iteration - OFFSET) * SIZE; break;
                        }
                    } else {
                        bexp = BO1 + (i - OFFSET) * SIZE;
                    }
                    vfmadd231ps(regs[i], zmm0, zword_b[bexp]);
                    if (unroll_m >= 32)
                        vfmadd231ps(regs[i + 8], zmm1, zword_b[bexp]);
                    if (unroll_m >= 48)
                        vfmadd231ps(regs[i + 16], zmm2, zword_b[bexp]);
                }
            }
        };
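
        // Note the two ISA strategies above: on avx512_core each B element
        // is broadcast once into zmm3 and reused by up to three FMAs, while
        // the avx512_mic path uses the embedded-broadcast (zword_b) operand
        // form on every vfmadd231ps instead.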
        // Innerkernel; called by kernel
        auto innerkernel = [&](int unroll_m, int unroll_n, bool isDirect,
                bool isCopy, bool doCPrefetch, bool isUnmasked = true) {
            for (int i = 0; i < 8; i++) {
                // prefetch A a few iterations ahead: within the packed
                // buffer, or within the strided source when isDirect
                if (!isDirect) {
                    prefetcht0(ptr[AO1 + (PREFETCHSIZEA + i * unroll_m + 0 * 16 - OFFSET) * SIZE]);
                    if (unroll_m >= 32)
                        prefetcht0(ptr[AO1 + (PREFETCHSIZEA + i * unroll_m + 1 * 16 - OFFSET) * SIZE]);
                    if (unroll_m >= 48)
                        prefetcht0(ptr[AO1 + (PREFETCHSIZEA + i * unroll_m + 2 * 16 - OFFSET) * SIZE]);
                } else {
                    prefetcht0(ptr[AO1 + LDA4 + (16 * 0 * SIZE)]);
                    if (unroll_m >= 32)
                        prefetcht0(ptr[AO1 + LDA4 + (16 * 1 * SIZE)]);
                    if (unroll_m >= 48)
                        prefetcht0(ptr[AO1 + LDA4 + (16 * 2 * SIZE)]);
                }

                // load the A tile for this k-step (tail row-block masked)
                if (!isDirect) {
                    if (isUnmasked || unroll_m > 16)
                        vmovups(zmm0, ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE]);
                    else
                        vmovups(zmm0 | k1 | T_z, ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE]);
                    if (unroll_m >= 32) {
                        if (isUnmasked || unroll_m > 32)
                            vmovups(zmm1, ptr[AO1 + (unroll_m * i + 1 * 16 - OFFSET) * SIZE]);
                        else
                            vmovups(zmm1 | k2 | T_z, ptr[AO1 + (unroll_m * i + 1 * 16 - OFFSET) * SIZE]);
                    }
                    if (unroll_m >= 48) {
                        if (isUnmasked)
                            vmovups(zmm2, ptr[AO1 + (unroll_m * i + 2 * 16 - OFFSET) * SIZE]);
                        else
                            vmovups(zmm2 | k3 | T_z, ptr[AO1 + (unroll_m * i + 2 * 16 - OFFSET) * SIZE]);
                    }
                } else {
                    if (isUnmasked || unroll_m > 16)
                        vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
                    else
                        vmovups(zmm0 | k1 | T_z, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
                    if (unroll_m >= 32) {
                        if (isUnmasked || unroll_m > 32)
                            vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
                        else
                            vmovups(zmm1 | k2 | T_z, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
                    }
                    if (unroll_m >= 48) {
                        if (isUnmasked)
                            vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
                        else
                            vmovups(zmm2 | k3 | T_z, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
                    }
                }

                // column 0
                if (ver == ver_avx512_core) {
                    if (!isTransB)
                        vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]);
                    else
                        vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]);
                    vfmadd231ps(regs[0], zmm3, zmm0);
                    if (unroll_m >= 32)
                        vfmadd231ps(regs[0 + 8], zmm3, zmm1);
                    if (unroll_m >= 48)
                        vfmadd231ps(regs[0 + 16], zmm3, zmm2);
                } else {
                    RegExp bexp = !isTransB ? BO1 + (i - OFFSET) * SIZE
                                            : BO1 + (0 - OFFSET) * SIZE;
                    vfmadd231ps(regs[0], zmm0, zword_b[bexp]);
                    if (unroll_m >= 32)
                        vfmadd231ps(regs[0 + 8], zmm1, zword_b[bexp]);
                    if (unroll_m >= 48)
                        vfmadd231ps(regs[0 + 16], zmm2, zword_b[bexp]);
                }
                // each k-step prefetches the next cache line of one B
                // column; column i is only touched when it is in play
                if (unroll_n >= i + 1) {
                    RegExp pexp;
                    switch (i) {
                    case 0: pexp = BO1 + (PREFETCHSIZEB - OFFSET) * SIZE; break;
                    case 1: pexp = BO1 + LDB + (PREFETCHSIZEB - OFFSET) * SIZE; break;
                    case 2: pexp = BO1 + LDB * 2 + (PREFETCHSIZEB - OFFSET) * SIZE; break;
                    case 3: pexp = BO1 + LDB3 + (PREFETCHSIZEB - OFFSET) * SIZE; break;
                    case 4: pexp = BO2 + (PREFETCHSIZEB - OFFSET) * SIZE; break;
                    case 5: pexp = BO2 + LDB + (PREFETCHSIZEB - OFFSET) * SIZE; break;
                    case 6: pexp = BO2 + LDB * 2 + (PREFETCHSIZEB - OFFSET) * SIZE; break;
                    case 7: pexp = BO2 + LDB3 + (PREFETCHSIZEB - OFFSET) * SIZE; break;
                    }
                    prefetcht0(ptr[pexp]);
                }
                // column 1
                if (ver == ver_avx512_core) {
                    if (!isTransB)
                        vbroadcastss(zmm3,
                                ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
                    else
                        vbroadcastss(zmm3, ptr[BO1 + (1 - OFFSET) * SIZE]);
                    vfmadd231ps(regs[1], zmm3, zmm0);
                    if (unroll_m >= 32)
                        vfmadd231ps(regs[1 + 8], zmm3, zmm1);
                    if (unroll_m >= 48)
                        vfmadd231ps(regs[1 + 16], zmm3, zmm2);
                } else {
                    RegExp bexp = !isTransB
                            ? BO1 + LDB * 1 + (i - OFFSET) * SIZE
                            : BO1 + (1 - OFFSET) * SIZE;
                    vfmadd231ps(regs[1], zmm0, zword_b[bexp]);
                    if (unroll_m >= 32)
                        vfmadd231ps(regs[1 + 8], zmm1, zword_b[bexp]);
                    if (unroll_m >= 48)
                        vfmadd231ps(regs[1 + 16], zmm2, zword_b[bexp]);
                }
                // isCopy: spill the loaded A tile into the packing buffer
                // so later n-blocks can reuse the packed panel
                if (isCopy) {
                    if (isUnmasked || unroll_m > 16)
                        vmovups(ptr[LDA4 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE], zmm0);
                    else
                        vmovups(ptr[LDA4 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE], zmm0 | k1);
                    if (unroll_m >= 32) {
                        if (isUnmasked || unroll_m > 32)
                            vmovups(ptr[LDA4 + (unroll_m * i + 1 * 16 - OFFSET) * SIZE], zmm1);
                        else
                            vmovups(ptr[LDA4 + (unroll_m * i + 1 * 16 - OFFSET) * SIZE], zmm1 | k2);
                    }
                    if (unroll_m >= 48) {
                        if (isUnmasked)
                            vmovups(ptr[LDA4 + (unroll_m * i + 2 * 16 - OFFSET) * SIZE], zmm2);
                        else
                            vmovups(ptr[LDA4 + (unroll_m * i + 2 * 16 - OFFSET) * SIZE], zmm2 | k3);
                    }
                    if (i == 7)
                        sub(LDA4, -unroll_m * 8 * SIZE);
                }

                fmaloop(unroll_m, unroll_n, i);
            }
            // issued once per eight k-steps: C-tile prefetch (write-intent
            // prefetchw on avx512_core) and streaming prefetch of AA
            if (ver == ver_avx512_core)
                prefetchw(ptr[CO2 + 0 * 16 * SIZE]);
            else
                prefetcht0(ptr[CO2 + 0 * 16 * SIZE]);

            if (doCPrefetch && unroll_m >= 32) {
                if (ver == ver_avx512_core)
                    prefetchw(ptr[CO2 + 1 * 16 * SIZE]);
                else
                    prefetcht0(ptr[CO2 + 1 * 16 * SIZE]);
            }
            if (ver == ver_avx512_core)
                prefetcht0(ptr[AA + 16 * 0 * SIZE]);
            else
                prefetcht2(ptr[AA + 16 * 0 * SIZE]);

            if (unroll_m >= 48) {
                if (ver == ver_avx512_core)
                    prefetchw(ptr[CO2 + 2 * 16 * SIZE]);
                else
                    prefetcht0(ptr[CO2 + 2 * 16 * SIZE]);
            }
            if (unroll_m >= 32) {
                if (ver == ver_avx512_core)
                    prefetcht0(ptr[AA + 16 * 1 * SIZE]);
                else
                    prefetcht2(ptr[AA + 16 * 1 * SIZE]);
            }

            prefetcht0(ptr[BO1 + BO2]);

            if (unroll_m >= 48) {
                if (ver == ver_avx512_core)
                    prefetcht0(ptr[AA + 16 * 2 * SIZE]);
                else
                    prefetcht2(ptr[AA + 16 * 2 * SIZE]);
            }
            lea(AA, ptr[AA + LDA]);

            // preload the A tile for the next innerkernel round
            if (isUnmasked || unroll_m > 16)
                vmovups(zmm0, ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]);
            else
                vmovups(zmm0 | k1 | T_z, ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]);
            if (unroll_m >= 32) {
                if (isUnmasked || unroll_m > 32)
                    vmovups(zmm1, ptr[AO1 + (unroll_m * 8 + 1 * 16 - OFFSET) * SIZE]);
                else
                    vmovups(zmm1 | k2 | T_z, ptr[AO1 + (unroll_m * 8 + 1 * 16 - OFFSET) * SIZE]);
            }
            if (unroll_m >= 48) {
                if (isUnmasked)
                    vmovups(zmm2, ptr[AO1 + (unroll_m * 8 + 2 * 16 - OFFSET) * SIZE]);
                else
                    vmovups(zmm2 | k3 | T_z, ptr[AO1 + (unroll_m * 8 + 2 * 16 - OFFSET) * SIZE]);
            }
            sub(AO1, -unroll_m * 8 * SIZE);
        };
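
        // One innerkernel call retires eight k-iterations: columns 0 and 1
        // are multiplied inline, fmaloop() covers columns 2..unroll_n-1,
        // and A/B/C software prefetches are interleaved so the next tiles
        // are already in cache when the FMAs reach them.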
        // Main kernel; does prefetching and calls innerkernel
        // After calculating results in registers, writes back to C matrix
        // by calling update
        auto kernel = [&](int unroll_m, int unroll_n, bool isDirect,
                bool isCopy, bool isUnmasked = true) {
            if (!isDirect)
                lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]);
            else
                mov(AO1, A);

            if (isCopy)
                lea(LDA4, ptr[rsp + 128 + OFFSET * SIZE]);
            else {
                auto step = ver == ver_avx512_core ? 2 : 4;
                lea(LDA4, ptr[LDA * step + (16 - 1 - OFFSET) * SIZE]);
            }

            if (isTransB)
                lea(BO2, ptr[LDB * 4 + (16 / 2 - 1 - OFFSET) * SIZE]);

            // preload the first A tile
            if (isUnmasked || unroll_m > 16)
                vmovups(zmm0, ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]);
            else
                vmovups(zmm0 | k1 | T_z, ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]);
            if (unroll_m >= 32) {
                if (isUnmasked || unroll_m > 32)
                    vmovups(zmm1, ptr[AO1 + (unroll_m * 0 + 1 * 16 - OFFSET) * SIZE]);
                else
                    vmovups(zmm1 | k2 | T_z, ptr[AO1 + (unroll_m * 0 + 1 * 16 - OFFSET) * SIZE]);
            }
            if (unroll_m >= 48) {
                if (isUnmasked)
                    vmovups(zmm2, ptr[AO1 + (unroll_m * 0 + 2 * 16 - OFFSET) * SIZE]);
                else
                    vmovups(zmm2 | k3 | T_z, ptr[AO1 + (unroll_m * 0 + 2 * 16 - OFFSET) * SIZE]);
            }
            Label kernel12, kernel13, kernel14, kernel15, kernel16, kernel18;

            // bulk k loop first; the last SECOND_FETCH iterations also
            // prefetch the C tile; any k remainder is handled below
            mov(LL, K);
            sub(LL, SECOND_FETCH);
            jle(kernel13, T_NEAR);

            L(kernel12);
            innerkernel(
                    unroll_m, unroll_n, isDirect, isCopy, false, isUnmasked);
            jg(kernel12, T_NEAR);

            L(kernel13);
            lea(CO2, ptr[CO1 + (16 - 1) * SIZE]);
            jle(kernel15, T_NEAR);

            L(kernel14);
            innerkernel(unroll_m, unroll_n, isDirect, isCopy, true, isUnmasked);
            jg(kernel14, T_NEAR);

            L(kernel15);
            jle(kernel18, T_NEAR);
            // k remainder loop: one k-step per pass
            L(kernel16);
            if (isUnmasked || unroll_m > 16)
                vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
            else
                vmovups(zmm0 | k1 | T_z,
                        ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
            if (unroll_m >= 32) {
                if (isUnmasked || unroll_m > 32)
                    vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
                else
                    vmovups(zmm1 | k2 | T_z,
                            ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
            }
            if (unroll_m >= 48) {
                if (isUnmasked)
                    vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
                else
                    vmovups(zmm2 | k3 | T_z,
                            ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
            }
            // one broadcast + FMA set per live column
            for (int i = 0; i < unroll_n; i++) {
                if (!isTransB) {
                    switch (i) {
                    case 0: vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]); break;
                    case 1: vbroadcastss(zmm3, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]); break;
                    case 2: vbroadcastss(zmm3, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]); break;
                    case 3: vbroadcastss(zmm3, ptr[BO1 + LDB3 + (0 - OFFSET) * SIZE]); break;
                    case 4: vbroadcastss(zmm3, ptr[BO2 + (0 - OFFSET) * SIZE]); break;
                    case 5: vbroadcastss(zmm3, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]); break;
                    case 6: vbroadcastss(zmm3, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]); break;
                    case 7: vbroadcastss(zmm3, ptr[BO2 + LDB3 + (0 - OFFSET) * SIZE]); break;
                    }
                } else {
                    vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]);
                }
                vfmadd231ps(regs[i], zmm3, zmm0);
                if (unroll_m >= 32) {
                    vfmadd231ps(regs[i + 8], zmm3, zmm1);
                }
                if (unroll_m >= 48) {
                    vfmadd231ps(regs[i + 16], zmm3, zmm2);
                }
            }
            // isCopy: store the remainder A tile to the packing buffer
            if (isCopy) {
                if (isUnmasked || unroll_m > 16)
                    vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE], zmm0);
                else
                    vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE], zmm0 | k1);
                if (unroll_m >= 32) {
                    if (isUnmasked || unroll_m > 32)
                        vmovups(ptr[LDA4 + (unroll_m * 0 + 1 * 16 - OFFSET) * SIZE], zmm1);
                    else
                        vmovups(ptr[LDA4 + (unroll_m * 0 + 1 * 16 - OFFSET) * SIZE], zmm1 | k2);
                }
                if (unroll_m >= 48) {
                    if (isUnmasked)
                        vmovups(ptr[LDA4 + (unroll_m * 0 + 2 * 16 - OFFSET) * SIZE], zmm2);
                    else
                        vmovups(ptr[LDA4 + (unroll_m * 0 + 2 * 16 - OFFSET) * SIZE], zmm2 | k3);
                }
                sub(LDA4, -unroll_m * SIZE);
            }

            // load the A tile for the next remainder step
            if (isUnmasked || unroll_m > 16)
                vmovups(zmm0, ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]);
            else
                vmovups(zmm0 | k1 | T_z, ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]);
            if (unroll_m >= 32) {
                if (isUnmasked || unroll_m > 32)
                    vmovups(zmm1, ptr[AO1 + (unroll_m * 1 + 1 * 16 - OFFSET) * SIZE]);
                else
                    vmovups(zmm1 | k2 | T_z, ptr[AO1 + (unroll_m * 1 + 1 * 16 - OFFSET) * SIZE]);
            }
            if (unroll_m >= 48) {
                if (isUnmasked)
                    vmovups(zmm2, ptr[AO1 + (unroll_m * 1 + 2 * 16 - OFFSET) * SIZE]);
                else
                    vmovups(zmm2 | k3 | T_z, ptr[AO1 + (unroll_m * 1 + 2 * 16 - OFFSET) * SIZE]);
            }
            sub(AO1, -unroll_m * SIZE);
            if (unroll_n >= 4) {
                sub(BO2, -SIZE);
            }
            jg(kernel16, T_NEAR);

            // writeback: broadcast the scalar alpha and beta
            L(kernel18);
            vbroadcastss(VALPHA, ALPHA);
            vbroadcastss(VBETA, BETA);
            // Write back the results; all beta cases need to be handled
            if (hasBias) {
                if (isUnmasked || unroll_m > 16)
                    vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]);
                else
                    vmovups(VBIAS1 | k1 | T_z, ptr[BIAS1 + 0 * SIZE]);
                if (unroll_m >= 32) {
                    if (isUnmasked || unroll_m > 32)
                        vmovups(VBIAS2, ptr[BIAS1 + 16 * SIZE]);
                    else
                        vmovups(VBIAS2 | k2 | T_z, ptr[BIAS1 + 16 * SIZE]);
                }
                if (unroll_m >= 48) {
                    if (isUnmasked)
                        vmovups(VBIAS3, ptr[BIAS1 + 32 * SIZE]);
                    else
                        vmovups(VBIAS3 | k3 | T_z, ptr[BIAS1 + 32 * SIZE]);
                }
            }
            // retire two columns per CO1/CO2 pointer pair
            for (int i = 0; i < unroll_n; i++) {
                bool useScale = i % 2 != 0;
                bool useCO1 = i < 2;
                if (i == 2)
                    lea(CO2, ptr[CO1 + LDC * 2]);
                if (i == 4 || i == 6)
                    lea(CO2, ptr[CO2 + LDC * 2]);

                if (hasBias)
                    vaddps(regs[i], VBIAS1, regs[i]);
                if (isUnmasked || unroll_m > 16) {
                    update(regs[i], useCO1, 0, 0, useScale);
                } else {
                    update(regs[i], useCO1, 0, 1, useScale);
                }
                if (unroll_m >= 32) {
                    if (hasBias)
                        vaddps(regs[i + 8], VBIAS2, regs[i + 8]);
                    if (isUnmasked || unroll_m > 32) {
                        update(regs[i + 8], useCO1, 16, 0, useScale);
                    } else {
                        update(regs[i + 8], useCO1, 16, 2, useScale);
                    }
                }
                if (unroll_m >= 48) {
                    if (hasBias)
                        vaddps(regs[i + 16], VBIAS3, regs[i + 16]);
                    if (isUnmasked) {
                        update(regs[i + 16], useCO1, 32, 0, useScale);
                    } else {
                        update(regs[i + 16], useCO1, 32, 3, useScale);
                    }
                }
            }

            // advance CO1 past the unroll_n columns just written
            switch (unroll_n) {
            case 1: add(CO1, LDC); break;
            case 2: lea(CO1, ptr[CO1 + LDC * 2]); break;
            case 3: lea(CO1, ptr[CO2 + LDC * 1]); break;
            case 4: lea(CO1, ptr[CO2 + LDC * 2]); break;
            case 5: lea(CO1, ptr[CO2 + LDC * 1]); break;
            case 6: lea(CO1, ptr[CO2 + LDC * 2]); break;
            case 7: lea(CO1, ptr[CO2 + LDC * 1]); break;
            case 8: lea(CO1, ptr[CO2 + LDC * 2]); break;
            }
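
            // Columns are written in CO1/CO2 pairs: update() targets the
            // column at the base pointer when useScale is false and the one
            // at base + LDC when it is true, so eight columns need only two
            // base registers (note how CO2 is bumped by 2 * LDC at i == 4
            // and i == 6 above).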
            // Compute next address of B: depending on unroll_n, exactly
            // one of the following column steps is emitted for BO1/BO2
            lea(rax, ptr[K * SIZE]);

            lea(BO1, ptr[BO1 + LDB * 2]);
            lea(BO2, ptr[BO2 + LDB * 2]);

            lea(BO1, ptr[BO1 + LDB3]);
            lea(BO2, ptr[BO2 + LDB3]);

            lea(BO1, ptr[BO1 + LDB * 4]);
            lea(BO2, ptr[BO2 + LDB * 4]);

            lea(BO1, ptr[BO1 + LDB * 4]);
            lea(BO2, ptr[BO2 + LDB * 4]);

            lea(BO1, ptr[BO1 + LDB3 * 2]);
            lea(BO2, ptr[BO2 + LDB3 * 2]);

            lea(BO1, ptr[BO1 + LDB * 8]);
            lea(BO2, ptr[BO2 + LDB * 8]);

            lea(BO1, ptr[BO1 + LDB * 8]);
            lea(BO2, ptr[BO2 + LDB * 8]);

            // packed (transposed) B advances element-wise instead
            add(BO1, unroll_n * SIZE);
        };
        // High-level subroutine; does packing if needed, then splits C
        // matrix. Operates on chunks of 48 rows, 8 columns at a time
        // (handling tail cases appropriately by doing 32 or 16 rows,
        // and/or with masking, and/or fewer columns).
        auto subloop = [&](int unroll_m) {
            Label l_subloop_20x[9], l_subloop_mask_20x[9];
            Label l_subloop_30x[9], l_subloop_mask_30x[9];

            Label subloop11, subloop11mask;
            Label subloop30, subloop30mask;
            Label subloop31, subloop31mask;
            Label subloop96;
            Label subloop98, subloop98mask;
            Label subloop99;

            // compute the row-tail width for the k1/k2/k3 masks
            sub(rcx, unroll_m - 16);
            if (unroll_m == 16) {
                // (row-tail mask goes to k1)
            } else if (unroll_m == 32) {
                // (row-tail mask goes to k2; k3 handles unroll_m == 48)
            }

            // full-width path: every 16-row block is complete
            jne(subloop96, T_NEAR);

            add(C, unroll_m * SIZE);

            lea(BO2, ptr[B + LDB * 4]);

            lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]);
            jg(subloop98, T_NEAR);

            lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]);
            L(subloop98);

            // If N is too small, skip copy operation
            cmp(LL, UNROLL_N * 3);
            jle(subloop30, T_NEAR);

            // If A is not aligned to cache line
            je(subloop30, T_NEAR);

            jl(l_subloop_20x[1], T_NEAR);

            // first block packs A into the buffer while computing
            kernel(unroll_m, UNROLL_N, true, true);

            kernel(unroll_m, UNROLL_N, false, false);

            L(subloop11);
            jl(l_subloop_20x[1], T_NEAR);

            kernel(unroll_m, UNROLL_N, false, false);

            jge(subloop11, T_NEAR);
            // 1..7 column tail (packed-A path)
            for (int i = 1; i <= 7; i++) {
                L(l_subloop_20x[i]);
                jne(l_subloop_20x[i + 1], T_NEAR);
                jne(subloop99, T_NEAR);

                kernel(unroll_m, i, false, false);
                jmp(subloop99, T_NEAR);
            }

            // A not aligned to a cache line: use the direct kernels
            L(subloop30);
            jl(l_subloop_30x[1], T_NEAR);

            L(subloop31);
            kernel(unroll_m, UNROLL_N, true, false);

            jge(subloop31, T_NEAR);

            // 1..7 column tail (direct path)
            for (int i = 1; i <= 7; i++) {
                L(l_subloop_30x[i]);
                jne(l_subloop_30x[i + 1], T_NEAR);
                jne(subloop99, T_NEAR);

                kernel(unroll_m, i, true, false);
                jmp(subloop99, T_NEAR);
            }

            jmp(subloop99, T_NEAR);
            // masked path: the last row block is partial
            L(subloop96);
            add(C, unroll_m * SIZE);

            lea(BO2, ptr[B + LDB * 4]);

            lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]);
            jg(subloop98mask, T_NEAR);

            lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]);
            L(subloop98mask);

            // If N is too small, skip copy operation
            cmp(LL, UNROLL_N * 3);
            jle(subloop30mask, T_NEAR);

            // If A is not aligned to cache line
            je(subloop30mask, T_NEAR);

            jl(l_subloop_mask_20x[1], T_NEAR);

            kernel(unroll_m, UNROLL_N, true, true, false);

            kernel(unroll_m, UNROLL_N, false, false, false);

            L(subloop11mask);
            jl(l_subloop_mask_20x[1], T_NEAR);

            kernel(unroll_m, UNROLL_N, false, false, false);

            jge(subloop11mask, T_NEAR);

            for (int i = 1; i <= 7; i++) {
                L(l_subloop_mask_20x[i]);
                jne(l_subloop_mask_20x[i + 1], T_NEAR);
                jne(subloop99, T_NEAR);

                kernel(unroll_m, i, false, false, false);
                jmp(subloop99, T_NEAR);
            }

            L(subloop30mask);
            jl(l_subloop_mask_30x[1], T_NEAR);

            L(subloop31mask);
            kernel(unroll_m, UNROLL_N, true, false, false);

            jge(subloop31mask, T_NEAR);

            for (int i = 1; i <= 7; i++) {
                L(l_subloop_mask_30x[i]);
                jne(l_subloop_mask_30x[i + 1], T_NEAR);
                jne(subloop99, T_NEAR);

                kernel(unroll_m, i, true, false, false);
                jmp(subloop99, T_NEAR);
            }

            jmp(subloop99, T_NEAR);
            L(subloop99);

            // Compute address for A: advance by unroll_m rows (for
            // transposed A the step is rax = unroll_m columns times LDA)
            add(A, unroll_m * SIZE);
            imul(rax, rax, unroll_m);

            // Compute next address of BIAS
            if (hasBias)
                add(BIAS, unroll_m * SIZE);
        };
        // main body of the generated routine
        Label buffer_in_ws, buffer_allocated;

        // Get the registers
        vmovss(xmm0, ptr[ARG_ALPHA]);
        vmovss(xmm1, ptr[r15]);

        // pack A on the stack when K fits; otherwise fall back to the
        // caller-provided workspace
        cmp(K, STACK_K_CAPACITY);
        jg(buffer_in_ws, T_NEAR);

        // Create buffer and align to 4kB page
        lea(rax, ptr[K * SIZE]);
        imul(rax, rax, 0x30);
        sub(rsp, rax);
        and_(rsp, -PAGE_4K);
        jmp(buffer_allocated, T_NEAR);

        L(buffer_in_ws);
        // (use the caller-provided workspace ws as the packing buffer)
        L(buffer_allocated);
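
        // The on-stack buffer is rax = K * SIZE * 0x30 bytes, i.e. room for
        // one UNROLL_M(=48)-row packed panel of A spanning the whole K
        // range, carved out of the stack and rounded down to a 4 KB page.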
        vmovss(ALPHA, xmm0);
        vmovss(BETA, xmm1);

        // pre-bias A and B by OFFSET elements; every displacement in the
        // kernels subtracts OFFSET again, which helps keep the effective
        // displacements within the compressed disp8*N range of EVEX
        sub(A, -OFFSET * SIZE);
        sub(B, -OFFSET * SIZE);

        // leading dimensions in bytes
        sal(LDA, BASE_SHIFT);
        sal(LDB, BASE_SHIFT);
        sal(LDC, BASE_SHIFT);

        lea(LDB3, ptr[LDB + LDB * 2]);

        // per-lane row offsets for the transposed-A gathers
        vpbroadcastq(zmm2, LDA);
        vpxorq(ZSTRIDE, ZSTRIDE, ZSTRIDE);
        for (int i = 0; i < 6; i++) {
            vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2);
            kshiftlw(k4, k4, 1);
        }
        vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2);
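
        // With k4 seeded so that lane i receives exactly i of these masked
        // additions, lane i of ZSTRIDE ends up holding i * LDA (in bytes):
        // precisely the per-row offsets vgatherqps needs in do_pack.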
        // Check A alignment and leading dimension; take copy-based path as
        // needed
        and_(rax, ver == ver_avx512_core ? 0x07 : 0x3f);

        // clear all 24 accumulator registers (zmm8..zmm31)
        for (int i = 8; i < 16; i++) {
            for (int j = 0; j < 3; j++) {
                vpxorq(Zmm(i + 8 * j), Zmm(i + 8 * j), Zmm(i + 8 * j));
            }
        }
        // dispatch on M: full 48-row chunks first, then 32/16-row tails
        Label main0, main1, main2, main999;

        jmp(main999, T_NEAR);

        jle(main999, T_NEAR);

        L(main999);

        // Restore original stack
        mov(rsp, ORIG_SP);

        postamble();

        ker_ = this->getCode<ker_t>();
    }
    typedef void (*ker_t)(dim_t m, dim_t n, dim_t k,
            const float *alpha, const float *a, dim_t lda,
            const float *b, dim_t ldb, const float *beta, float *c,
            dim_t ldc, const float *bias, float *ws);

    void operator()(dim_t m, dim_t n, dim_t k,
            const float *alpha, const float *a, dim_t lda,
            const float *b, dim_t ldb, const float *beta, float *c,
            dim_t ldc, const float *bias, float *ws) const
    {
        ker_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws);
    }

private:
    ker_t ker_;
};
const xbyak_gemm *get_xbyak_gemm(
        bool isTransA, bool isTransB, float beta, bool hasBias) {
    auto beta_idx = [](float beta) {
        return (beta == 0.0) ? 0 : (beta == 1.0 ? 1 : 2);
    };

    // Kernel table [isTransA][isTransB][hasBias][beta (0, 1, other)]
    static xbyak_gemm *kernel_table[2][2][2][3];
    static std::once_flag initialized;
    std::call_once(initialized, [=] {
        for (bool isTransA : { false, true })
        for (bool isTransB : { false, true })
        for (bool hasBias : { false, true })
        for (float beta : { 0.0f, 1.0f, 2.0f }) {
            // nocopy sgemm with bias for beta != 0.0 is not supported
            if (hasBias && beta != 0.0)
                continue;
            kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)] =
                    new xbyak_gemm(isTransA, isTransB, beta, hasBias);
        }
    });

    return kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)];
}
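
// The kernels above are JIT-compiled at most once per process (guarded by
// the call_once) and then cached for every (isTransA, isTransB, beta-class,
// hasBias) combination, so code generation cost is never paid on the hot
// path of a GEMM call.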
void sgemm_nocopy_driver(const char *transa,
        const char *transb, int m, int n, int k, const float *alpha,
        const float *a, dim_t lda, const float *b, dim_t ldb,
        const float *beta, float *c, dim_t ldc, const float *bias, float *ws)
{
    bool isTransA = (*transa == 'T' || *transa == 't');
    bool isTransB = (*transb == 'T' || *transb == 't');

    int Bm, sizeM, Bn, sizeN, Bk, sizeK;

    dim_t i, j;

    if ((m <= 0) || (n <= 0))
        return;

    if ((k <= 0) || (alpha[0] == 0.)) {
        if (beta[0] == 0.) {
            for (j = 0; j < n; j++)
                for (i = 0; i < m; i++)
                    c[i + j * ldc] = 0.0;
        } else if (beta[0] != 1.) {
            for (j = 0; j < n; j++)
                for (i = 0; i < m; i++)
                    c[i + j * ldc] *= beta[0];
        }
        return;
    }

    assert(IMPLICATION(bias != nullptr, *beta == 0.0));

    // XXX: this happens on every thread...
    bool hasBias = (bias != nullptr);
    auto ker_bn = get_xbyak_gemm(isTransA, isTransB, *beta, hasBias);
    auto ker_b1 = get_xbyak_gemm(isTransA, isTransB, 1.0, false);
    auto ker_b0 = get_xbyak_gemm(isTransA, isTransB, 0.0, false);
    assert(ker_bn && ker_b1 && ker_b0);
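
    // Cache blocking: the problem is cut into panels of at most BM x BK
    // (A) and BK x BN (B) so each kernel invocation works out of cache;
    // the block sizes differ between the avx512_core and avx512_mic
    // targets and between the transposed/non-transposed layouts.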
    int BM = 4032, BN, BK;
    if (mayiuse(avx512_core)) {
        BN = isTransA ? 384 : 64;
        // (BK for the avx512_core target is chosen here)
    } else {
        BN = isTransA ? 96 : 64;
        BK = isTransB ? 96 : 192;
        if (!isTransA && !isTransB)
            BK = 128;
    }

    const float *curA, *curB, *curBias = nullptr;
    float *curC;

    for (Bk = 0; Bk < k; Bk += sizeK) {
        sizeK = k - Bk;
        if (sizeK >= BK * 2)
            sizeK = BK;
        // near-double blocks are split evenly to keep the two halves
        // balanced instead of leaving a thin tail
        else if (sizeK > BK + BK / 2)
            sizeK = (sizeK + 1) / 2;

        for (Bm = 0; Bm < m; Bm += sizeM) {
            sizeM = m - Bm;
            if (sizeM >= BM * 2)
                sizeM = BM;
            else if (sizeM > BM + BM / 2)
                sizeM = (sizeM + 1) / 2;

            for (Bn = 0; Bn < n; Bn += sizeN) {
                sizeN = n - Bn;
                if (sizeN >= BN * 2)
                    sizeN = BN;
                else if (sizeN > BN + BN / 2)
                    sizeN = (sizeN + 1) / 2;

                if (!isTransA)
                    curA = a + Bm + Bk * lda;
                else
                    curA = a + Bk + Bm * lda;
                if (!isTransB)
                    curB = b + Bk + Bn * ldb;
                else
                    curB = b + Bn + Bk * ldb;
                curC = c + Bm + (size_t)Bn * ldc;
                if (bias != nullptr)
                    curBias = bias + Bm;

                // beta kernel only on the first K panel; later panels
                // accumulate with the beta == 1 kernel
                if (Bk == 0) {
                    if (*beta == 0.0 && bias == nullptr)
                        (*ker_b0)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
                                alpha, curA, lda, curB, ldb, beta, curC, ldc,
                                curBias, ws);
                    else
                        (*ker_bn)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
                                alpha, curA, lda, curB, ldb, beta, curC, ldc,
                                curBias, ws);
                } else {
                    (*ker_b1)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
                            alpha, curA, lda, curB, ldb, beta, curC, ldc,
                            curBias, ws);
                }
            }
        }
    }
}
mkldnn_status_t jit_avx512_common_gemm_f32(
        const char *transa, const char *transb,
        const int *p_m, const int *p_n, const int *p_k, const float *p_alpha,
        const float *A, const int *p_lda, const float *B, const int *p_ldb,
        const float *p_beta, float *C, const int *p_ldc, const float *bias)
{
    using namespace mkldnn::impl::utils;
    using namespace avx512_common_gemm_f32;
    using namespace gemm_utils;

    // nocopy sgemm with bias for beta != 0.0 is not supported: fall back
    // (note: B's leading dimension is p_ldb, not p_lda)
    if (*p_beta != 0 && bias)
        return ref_gemm(transa, transb, p_m, p_n, p_k,
                p_alpha, A, p_lda, B, p_ldb, p_beta, C, p_ldc, bias);
    int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();

    int m = *p_m;
    int n = *p_n;
    int k = *p_k;
    dim_t lda = *p_lda;
    dim_t ldb = *p_ldb;
    dim_t ldc = *p_ldc;
    float beta = *p_beta;
    int MB, NB, KB;

    int nthr_m, nthr_n, nthr_k, nthr_mn;

    // Determine threading partitioning
    calc_nthr_nocopy_avx512_common(
            m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
    assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));

    // May not happen, but just in case
    if (nthr < nthr_m * nthr_n * nthr_k)
        nthr = nthr_m * nthr_n * nthr_k;

    nthr_mn = nthr_m * nthr_n;

    unsigned char *ompstatus_ = nullptr;
    unsigned char volatile *ompstatus = nullptr;

    float *c_buffers = nullptr;
    float *ws_buffers = nullptr;

    if (nthr_k > 1) {
        ompstatus_ = (unsigned char *)malloc(
                nthr * CACHE_LINE_SIZE, CACHE_LINE_SIZE);
        ompstatus = (unsigned char volatile *)ompstatus_;

        for (int i = 0; i < nthr; i++)
            ompstatus[i * CACHE_LINE_SIZE] = 0;

        c_buffers = (float *)malloc(
                nthr_m * nthr_n * (nthr_k - 1) * MB * NB * sizeof(float),
                PAGE_4K);
    }

    const size_t ws_elems_per_thr = (size_t)k * 48 + 64;
    const size_t ws_size_per_thr
            = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
    if (k > STACK_K_CAPACITY)
        ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K);
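
    // The generated kernels pack A into an on-stack buffer whenever k fits
    // into STACK_K_CAPACITY (see the generator above); only for larger k
    // is this page-aligned per-thread heap workspace allocated and handed
    // to the kernels as ws.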
    parallel_nd(nthr, [&](const int ithr) {
        int ithr_m, ithr_n, ithr_k, ithr_mn;
        int m_from, m_to, myM;
        int n_from, n_to, myN;
        int k_from, k_to, myK;
        int cbase, ibase;
        const float *myA, *myB, *myBias = nullptr;
        float *myC = C, myBeta;
        dim_t ld = ldc;
        float *ws = ws_buffers
                ? ws_buffers + ithr * ws_size_per_thr / sizeof(float)
                : nullptr;

        int sum_later = (mkldnn_get_num_threads() < nthr_m * nthr_n * nthr_k);

        if (ithr < nthr_m * nthr_n * nthr_k) {
            ithr_mn = ithr % nthr_mn;
            ithr_m = ithr_mn % nthr_m;
            ithr_n = ithr_mn / nthr_m;
            ithr_k = ithr / nthr_mn;

            /* swap ithr_k for performance improvement */
            if (ithr_k == 0)
                ithr_k = nthr_k - 1;
            else if (ithr_k == nthr_k - 1)
                ithr_k = 0;

            m_from = MB * (ithr_m);
            m_to = MB * (ithr_m + 1);
            if (m_to > m)
                m_to = m;
            myM = m_to - m_from;
            n_from = NB * (ithr_n);
            n_to = NB * (ithr_n + 1);
            if (n_to > n)
                n_to = n;
            myN = n_to - n_from;
            k_from = KB * (ithr_k);
            k_to = KB * (ithr_k + 1);
            if (k_to > k)
                k_to = k;
            myK = k_to - k_from;

            cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
            ibase = (ithr_m + nthr_m * ithr_n) * nthr_k;

            if ((myM > 0) && (myN > 0)) {
                if (*transa == 'N' || *transa == 'n') {
                    myA = &(A[m_from + k_from * lda]);
                } else {
                    myA = &(A[k_from + m_from * lda]);
                }
                if (*transb == 'N' || *transb == 'n') {
                    myB = &(B[k_from + n_from * ldb]);
                } else {
                    myB = &(B[n_from + k_from * ldb]);
                }
                if (ithr_k == 0) {
                    // the k == 0 slice accumulates straight into C
                    myC = &(C[m_from + n_from * ldc]);
                    myBeta = beta;
                    ld = ldc;
                    if (bias)
                        myBias = &(bias[m_from]);
                } else {
                    // other k slices go to private buffers, summed below
                    myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1);
                    myBeta = 0.0f;
                    ld = MB;
                    myBias = nullptr;
                }

                sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha,
                        myA, lda, myB, ldb, &myBeta, myC, ld, myBias, ws);

                if (nthr_k > 1 && !sum_later)
                    ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1;
            }
            if (nthr_k > 1 && !sum_later) {
                // sum matrices partitioned along K dimension
                int n1, n2;
                partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
                if (ithr_k > 0) {
                    myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
                            + (dim_t)n1 * MB;
                    /* need to wait until main thread finishes */
                    while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) {
                    };
                    /* my cache is hot */
                    sum_two_matrices(myM, n2, myC, MB,
                            &C[m_from + (n_from + n1) * ldc], ldc);
                }
                for (int ik = 1; ik < nthr_k; ++ik) {
                    if (ik != ithr_k) {
                        myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
                                + (dim_t)n1 * MB;
                        while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE]
                                != 1) {
                        };
                        sum_two_matrices(myM, n2, myC, MB,
                                &C[m_from + (n_from + n1) * ldc], ldc);
                    }
                }
            }
        }
    });
2053 if (nthr_k > 1 && ompstatus[0] == 0) {
2055 parallel_nd(nthr, [&](const int ithr) {
2056 int ithr_m, ithr_n, ithr_k, ithr_mn;
2057 int m_from, m_to, myM;
2058 int n_from, n_to, myN;
2062 if (ithr < nthr_m * nthr_n * nthr_k) {
2064 ithr_mn = ithr % nthr_mn;
2065 ithr_m = ithr_mn % nthr_m;
2066 ithr_n = ithr_mn / nthr_m;
2067 ithr_k = ithr / nthr_mn;
2069 /* swap ithr_k for performance improvement */
2071 ithr_k = nthr_k - 1;
2072 else if (ithr_k == nthr_k - 1)
2075 m_from = MB * (ithr_m);
2076 m_to = MB * (ithr_m + 1);
2079 myM = m_to - m_from;
2081 n_from = NB * (ithr_n);
2082 n_to = NB * (ithr_n + 1);
2085 myN = n_to - n_from;
2087 cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
2090 // sum matrices partitioned along K dimension
2093 partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
2097 myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
2100 /* my cache is hot */
2101 sum_two_matrices(myM, n2, myC, MB,
2102 &C[m_from + (n_from + n1) * ldc], ldc);
2105 for (int ik = 1; ik < nthr_k; ++ik) {
2108 myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
2111 sum_two_matrices(myM, n2, myC, MB,
2112 &C[m_from + (n_from + n1) * ldc], ldc);
2124 return mkldnn_success;
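
// Illustrative use (hypothetical sizes; BLAS-style column-major storage):
// computing C = alpha * A * B + beta * C without a bias term would look
// like
//
//     int m = 256, n = 128, k = 512;
//     int lda = m, ldb = k, ldc = m;
//     float alpha = 1.0f, beta = 0.0f;
//     jit_avx512_common_gemm_f32("N", "N", &m, &n, &k, &alpha,
//             A, &lda, B, &ldb, &beta, C, &ldc, nullptr);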
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s