/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "mkldnn_thread.hpp"
#include "utils.hpp"

#include "gemm_utils.hpp"
#include "jit_avx512_common_gemm_f32.hpp"

#define CACHE_LINE_SIZE 64
namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::utils;

using namespace Xbyak;
#define STACKSIZE get_size_of_abi_save_regs()
#ifdef _WIN32
#define STACK_K_CAPACITY 32
#else
#define STACK_K_CAPACITY 2048
#endif
#define SIZE 4
#define OFFSET 32
#define BASE_SHIFT 2
#define SECOND_FETCH unroll_n
#define UNROLL_M 48
#define UNROLL_N 8
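// STACK_K_CAPACITY bounds the K extent whose packed A panel still fits in
// the on-stack buffer (kept much smaller on Windows, where the default
// thread stack is tighter); larger K falls back to the heap workspace
// allocated in sgemm() below. SECOND_FETCH is the tail of the K loop during
// which the kernel starts prefetching the C tile it is about to write.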
struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
    xbyak_gemm(char transa, char transb, float beta, bool hasBias = false,
            void *code_ptr = nullptr,
            size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE)
        : jit_generator(code_ptr, code_size)
    {
        enum { ver_avx512_core, ver_avx512_mic } ver =
            mayiuse(avx512_core) ? ver_avx512_core : ver_avx512_mic;
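        // Both flavors share the register blocking below; they differ only
        // in scheduling details: avx512_core broadcasts B through zmm3 and
        // prefetches C with prefetchw, while the avx512_mic path uses
        // embedded-broadcast FMA operands (zword_b) and a longer A prefetch
        // distance (80 vs. 48 elements).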
        bool isTransA = (transa == 'T' || transa == 't');
        bool isTransB = (transb == 'T' || transb == 't');
        bool isBeta0 = (beta == 0.0);
        bool isBetaN = (!isBeta0 && beta != 1.0);

        // various definitions for convenience
        auto ARG_M = abi_param1;
        auto ARG_N = abi_param2;
        auto ARG_ALPHA = abi_param4;
#ifdef _WIN32
        auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE];
        auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE +
            sizeof(float *) + STACKSIZE];
        const auto stackOffset = OFFSET_SHADOWSPACE +
            sizeof(float *) + STACKSIZE;
#else
        const auto stackOffset = STACKSIZE;
#endif
        auto ARG_B = ptr[rsp + 8 + stackOffset];
        auto ARG_LDB = ptr[rsp + 16 + stackOffset];
        auto ARG_BETA = ptr[rsp + 24 + stackOffset];
        auto ARG_C = ptr[rsp + 32 + stackOffset];
        auto ARG_LDC = ptr[rsp + 40 + stackOffset];
        auto ARG_BIAS = ptr[rsp + 48 + stackOffset];
        auto ARG_WS = ptr[rsp + 56 + stackOffset];
        auto AO1 = abi_param2;
        auto BO1 = abi_param4;
        auto LDA4 = abi_param1;
        auto BIAS1 = abi_param1;

        auto M = qword[rsp + 0];
        auto N = qword[rsp + 8];
        auto FLAG = qword[rsp + 16];
        auto I = qword[rsp + 24];
        auto C = qword[rsp + 32];
        auto BIAS = qword[rsp + 40];
        auto ALPHA = qword[rsp + 48];
        auto BETA = qword[rsp + 64];
        auto ORIG_A = qword[rsp + 80];
        auto ORIG_SP = qword[rsp + 120];
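        // Spill slots at the bottom of the JIT frame; the on-stack packing
        // buffer for A starts just above them, at rsp + 128 (see the
        // lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]) calls below).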
        auto PREFETCHSIZEA = ver == ver_avx512_core ? 48 : 80;
        auto PREFETCHSIZEB = 16;

        Zmm regs[] = { zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15,
            zmm16, zmm17, zmm18, zmm19, zmm20, zmm21, zmm22, zmm23, zmm24,
            zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31 };
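        // 24 accumulators for the C tile: regs[i], regs[i + 8] and
        // regs[i + 16] hold column i of the first, second and third 16-row
        // block, so a full 48 x 8 tile uses all of zmm8..zmm31.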
        // Function for packing if needed
        auto do_pack = [&](int unroll_m) {
            Label pack2, pack3, pack4, pack10;

            mov(BO1, A);
            lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]);
                for (int i = 0; i < 4; i++) {
                    vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]);
                    vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]);
                    vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]);
                    vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE]
                            | k1, zmm0);
                    vmovups(ptr[AO1 + (unroll_m * i + 1 * 16 - OFFSET) * SIZE]
                            | k2, zmm1);
                    vmovups(ptr[AO1 + (unroll_m * i + 2 * 16 - OFFSET) * SIZE]
                            | k3, zmm2);
                }
                for (int i = 0; i < 4; i++) {
                    vgatherqps(ymm5 | k4,
                            ptr[BO1 + ZSTRIDE + (i - OFFSET) * SIZE]);
                    lea(BO2, ptr[BO1 + LDA * 8]);
                    vgatherqps(ymm6 | k4,
                            ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
                    vshuff64x2(zmm0, zmm5, zmm6, 0x44);

                    lea(BO2, ptr[BO2 + LDA * 8]);
                    vgatherqps(ymm5 | k4,
                            ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
                    lea(BO2, ptr[BO2 + LDA * 8]);
                    vgatherqps(ymm6 | k4,
                            ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
                    vshuff64x2(zmm1, zmm5, zmm6, 0x44);

                    lea(BO2, ptr[BO2 + LDA * 8]);
                    vgatherqps(ymm5 | k4,
                            ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
                    lea(BO2, ptr[BO2 + LDA * 8]);
                    vgatherqps(ymm6 | k4,
                            ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
                    lea(BO2, ptr[BO2 + LDA * 8]);
                    vshuff64x2(zmm2, zmm5, zmm6, 0x44);

                    vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE],
                            zmm0);
                    vmovups(ptr[AO1 + (unroll_m * i + 1 * 16 - OFFSET) * SIZE],
                            zmm1);
                    vmovups(ptr[AO1 + (unroll_m * i + 2 * 16 - OFFSET) * SIZE],
                            zmm2);
                }
                add(AO1, unroll_m * 4 * SIZE);
                vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]);
                vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]);
                vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]);

                vgatherqps(ymm5 | k4, ptr[BO1 + ZSTRIDE + (0 - OFFSET) * SIZE]);
                lea(BO2, ptr[BO1 + LDA * 8]);
                vgatherqps(ymm6 | k4, ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
                vshuff64x2(zmm0, zmm5, zmm6, 0x44);

                lea(BO2, ptr[BO2 + LDA * 8]);
                vgatherqps(ymm5 | k4,
                        ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
                lea(BO2, ptr[BO2 + LDA * 8]);
                vgatherqps(ymm6 | k4,
                        ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
                vshuff64x2(zmm1, zmm5, zmm6, 0x44);

                lea(BO2, ptr[BO2 + LDA * 8]);
                vgatherqps(ymm5 | k4,
                        ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
                lea(BO2, ptr[BO2 + LDA * 8]);
                vgatherqps(ymm6 | k4,
                        ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
                lea(BO2, ptr[BO2 + LDA * 8]);
                vshuff64x2(zmm2, zmm5, zmm6, 0x44);

                vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE],
                        zmm0);
                vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 16 - OFFSET) * SIZE],
                        zmm1);
                vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 16 - OFFSET) * SIZE],
                        zmm2);
                add(AO1, unroll_m * SIZE);
        };
        // Function to update C, covering masking and other considerations
        auto update = [&](Zmm reg, bool useCO1, int offset, int mask,
                bool useScale = false) {
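            // reg arrives holding a not-yet-scaled partial product. The body
            // scales it by VALPHA, optionally blends in beta * C (k1/k2/k3
            // cover the masked row tails; mask == 0 means a full 16-row
            // store), and writes through CO1 or CO2, with useScale selecting
            // the + LDC (next column) variant.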
            vmulps(reg, reg, VALPHA);
            auto co = useCO1 ? CO1 : CO2;
            auto kmask = mask == 2 ? k2 : mask == 3 ? k3 : k1;
            if (!isBeta0) {
                if (!useScale) {
                    if (mask == 0)
                        vmovups(zmm0, ptr[co + offset * SIZE]);
                    else
                        vmovups(zmm0 | kmask | T_z, ptr[co + offset * SIZE]);
                } else {
                    if (mask == 0)
                        vmovups(zmm0, ptr[co + LDC + offset * SIZE]);
                    else
                        vmovups(zmm0 | kmask | T_z,
                                ptr[co + LDC + offset * SIZE]);
                }
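                // beta == 1 folds C in with a plain add; any other nonzero
                // beta needs one FMA: zmm0 = VBETA * C + alpha * AB.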
                if (!isBetaN)
                    vaddps(zmm0, reg, zmm0);
                else
                    vfmadd132ps(zmm0, reg, VBETA);
                if (!useScale) {
                    if (mask == 0)
                        vmovups(ptr[co + offset * SIZE], zmm0);
                    else
                        vmovups(ptr[co + offset * SIZE], zmm0 | kmask);
                } else {
                    if (mask == 0)
                        vmovups(ptr[co + LDC + offset * SIZE], zmm0);
                    else
                        vmovups(ptr[co + LDC + offset * SIZE], zmm0 | kmask);
                }
            } else {
                if (!useScale) {
                    if (mask == 0)
                        vmovups(ptr[co + offset * SIZE], reg);
                    else
                        vmovups(ptr[co + offset * SIZE], reg | kmask);
                } else {
                    if (mask == 0)
                        vmovups(ptr[co + LDC + offset * SIZE], reg);
                    else
                        vmovups(ptr[co + LDC + offset * SIZE], reg | kmask);
                }
            }
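            // Clear the accumulator so the next C tile starts from zero.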
            vpxorq(reg, reg, reg);
        };
        // Loop with unroll_n - 2 FMAs; called by innerkernel
        auto fmaloop = [&](int unroll_m, int unroll_n, int iteration) {
            for (int i = 2; i < unroll_n; i++) {
                if (ver == ver_avx512_core) {
                    if (!isTransB) {
                        switch (i) {
                        case 2: vbroadcastss(zmm3, ptr[BO1 + LDB * 2 + (iteration - OFFSET) * SIZE]); break;
                        case 3: vbroadcastss(zmm3, ptr[BO1 + LDB3 + (iteration - OFFSET) * SIZE]); break;
                        case 4: vbroadcastss(zmm3, ptr[BO2 + (iteration - OFFSET) * SIZE]); break;
                        case 5: vbroadcastss(zmm3, ptr[BO2 + LDB * 1 + (iteration - OFFSET) * SIZE]); break;
                        case 6: vbroadcastss(zmm3, ptr[BO2 + LDB * 2 + (iteration - OFFSET) * SIZE]); break;
                        case 7: vbroadcastss(zmm3, ptr[BO2 + LDB3 + (iteration - OFFSET) * SIZE]); break;
                        }
                    } else {
                        vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]);
                    }
                    vfmadd231ps(regs[i], zmm3, zmm0);
                    if (unroll_m >= 32)
                        vfmadd231ps(regs[i + 8], zmm3, zmm1);
                    if (unroll_m >= 48)
                        vfmadd231ps(regs[i + 16], zmm3, zmm2);
                } else {
                    if (!isTransB) {
                        switch (i) {
                        case 2:
                            vfmadd231ps(regs[i], zmm0,
                                    zword_b[BO1 + LDB * 2
                                            + (iteration - OFFSET) * SIZE]);
                            vfmadd231ps(regs[i + 8], zmm1,
                                    zword_b[BO1 + LDB * 2
                                            + (iteration - OFFSET) * SIZE]);
                            vfmadd231ps(regs[i + 16], zmm2,
                                    zword_b[BO1 + LDB * 2
                                            + (iteration - OFFSET) * SIZE]);
                            break;
                        case 3:
                            vfmadd231ps(regs[i], zmm0,
                                    zword_b[BO1 + LDB3
                                            + (iteration - OFFSET) * SIZE]);
                            vfmadd231ps(regs[i + 8], zmm1,
                                    zword_b[BO1 + LDB3
                                            + (iteration - OFFSET) * SIZE]);
                            vfmadd231ps(regs[i + 16], zmm2,
                                    zword_b[BO1 + LDB3
                                            + (iteration - OFFSET) * SIZE]);
                            break;
                        case 4:
                            vfmadd231ps(regs[i], zmm0,
                                    zword_b[BO2 + (iteration - OFFSET) * SIZE]);
                            vfmadd231ps(regs[i + 8], zmm1,
                                    zword_b[BO2 + (iteration - OFFSET) * SIZE]);
                            vfmadd231ps(regs[i + 16], zmm2,
                                    zword_b[BO2 + (iteration - OFFSET) * SIZE]);
                            break;
                        case 5:
                            vfmadd231ps(regs[i], zmm0,
                                    zword_b[BO2 + LDB * 1
                                            + (iteration - OFFSET) * SIZE]);
                            vfmadd231ps(regs[i + 8], zmm1,
                                    zword_b[BO2 + LDB * 1
                                            + (iteration - OFFSET) * SIZE]);
                            vfmadd231ps(regs[i + 16], zmm2,
                                    zword_b[BO2 + LDB * 1
                                            + (iteration - OFFSET) * SIZE]);
                            break;
                        case 6:
                            vfmadd231ps(regs[i], zmm0,
                                    zword_b[BO2 + LDB * 2
                                            + (iteration - OFFSET) * SIZE]);
                            vfmadd231ps(regs[i + 8], zmm1,
                                    zword_b[BO2 + LDB * 2
                                            + (iteration - OFFSET) * SIZE]);
                            vfmadd231ps(regs[i + 16], zmm2,
                                    zword_b[BO2 + LDB * 2
                                            + (iteration - OFFSET) * SIZE]);
                            break;
                        case 7:
                            vfmadd231ps(regs[i], zmm0,
                                    zword_b[BO2 + LDB3
                                            + (iteration - OFFSET) * SIZE]);
                            vfmadd231ps(regs[i + 8], zmm1,
                                    zword_b[BO2 + LDB3
                                            + (iteration - OFFSET) * SIZE]);
                            vfmadd231ps(regs[i + 16], zmm2,
                                    zword_b[BO2 + LDB3
                                            + (iteration - OFFSET) * SIZE]);
                            break;
                        }
                    } else {
                        vfmadd231ps(
                                regs[i], zmm0, zword_b[BO1 + (i - OFFSET) * SIZE]);
                        vfmadd231ps(regs[i + 8], zmm1,
                                zword_b[BO1 + (i - OFFSET) * SIZE]);
                        vfmadd231ps(regs[i + 16], zmm2,
                                zword_b[BO1 + (i - OFFSET) * SIZE]);
                    }
                }
            }
        };
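        // innerkernel consumes eight K iterations per call: per iteration it
        // prefetches A ahead (packed buffer or original matrix, depending on
        // isDirect), loads up to three 16-wide blocks of an A column, runs
        // unroll_n FMAs against broadcast B elements (columns 2+ via
        // fmaloop), optionally streams A into the pack buffer (isCopy), and
        // near the end of the K loop (doCPrefetch) starts prefetching the
        // destination C tile.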
        // Innerkernel; called by kernel
        auto innerkernel = [&](int unroll_m, int unroll_n, bool isDirect,
                bool isCopy, bool doCPrefetch, bool isUnmasked = true) {
            for (int i = 0; i < 8; i++) {
                if (!isDirect) {
                    prefetcht0(ptr[AO1
                            + (PREFETCHSIZEA + i * unroll_m + 0 * 16 - OFFSET)
                                    * SIZE]);
                    prefetcht0(ptr[AO1
                            + (PREFETCHSIZEA + i * unroll_m + 1 * 16 - OFFSET)
                                    * SIZE]);
                    prefetcht0(ptr[AO1
                            + (PREFETCHSIZEA + i * unroll_m + 2 * 16 - OFFSET)
                                    * SIZE]);
                } else {
                    prefetcht0(ptr[AO1 + LDA4 + (16 * 0 * SIZE)]);
                    prefetcht0(ptr[AO1 + LDA4 + (16 * 1 * SIZE)]);
                    prefetcht0(ptr[AO1 + LDA4 + (16 * 2 * SIZE)]);
                }
                if (!isDirect) {
                    if (isUnmasked || unroll_m > 16) {
                        vmovups(zmm0, ptr[AO1
                                + (unroll_m * i + 0 * 16 - OFFSET) * SIZE]);
                    } else {
                        vmovups(zmm0 | k1 | T_z, ptr[AO1
                                + (unroll_m * i + 0 * 16 - OFFSET) * SIZE]);
                    }
                    if (unroll_m >= 32) {
                        if (isUnmasked || unroll_m > 32) {
                            vmovups(zmm1, ptr[AO1
                                    + (unroll_m * i + 1 * 16 - OFFSET) * SIZE]);
                        } else {
                            vmovups(zmm1 | k2 | T_z, ptr[AO1
                                    + (unroll_m * i + 1 * 16 - OFFSET) * SIZE]);
                        }
                    }
                    if (unroll_m >= 48) {
                        if (isUnmasked) {
                            vmovups(zmm2, ptr[AO1
                                    + (unroll_m * i + 2 * 16 - OFFSET) * SIZE]);
                        } else {
                            vmovups(zmm2 | k3 | T_z, ptr[AO1
                                    + (unroll_m * i + 2 * 16 - OFFSET) * SIZE]);
                        }
                    }
                } else {
                    if (isUnmasked || unroll_m > 16) {
                        vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
                    } else {
                        vmovups(zmm0 | k1 | T_z,
                                ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
                    }
                    if (unroll_m >= 32) {
                        if (isUnmasked || unroll_m > 32) {
                            vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
                        } else {
                            vmovups(zmm1 | k2 | T_z,
                                    ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
                        }
                    }
                    if (unroll_m >= 48) {
                        if (isUnmasked) {
                            vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
                        } else {
                            vmovups(zmm2 | k3 | T_z,
                                    ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
                        }
                    }
                }
                if (ver == ver_avx512_core) {
                    if (!isTransB) {
                        vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]);
                    } else {
                        vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]);
                    }
                    vfmadd231ps(regs[0], zmm3, zmm0);
                    if (unroll_m >= 32)
                        vfmadd231ps(regs[0 + 8], zmm3, zmm1);
                    if (unroll_m >= 48)
                        vfmadd231ps(regs[0 + 16], zmm3, zmm2);
                } else {
                    if (!isTransB) {
                        vfmadd231ps(regs[0], zmm0,
                                zword_b[BO1 + (i - OFFSET) * SIZE]);
                        if (unroll_m >= 32)
                            vfmadd231ps(regs[0 + 8], zmm1,
                                    zword_b[BO1 + (i - OFFSET) * SIZE]);
                        if (unroll_m >= 48)
                            vfmadd231ps(regs[0 + 16], zmm2,
                                    zword_b[BO1 + (i - OFFSET) * SIZE]);
                    } else {
                        vfmadd231ps(regs[0], zmm0,
                                zword_b[BO1 + (0 - OFFSET) * SIZE]);
                        if (unroll_m >= 32)
                            vfmadd231ps(regs[0 + 8], zmm1,
                                    zword_b[BO1 + (0 - OFFSET) * SIZE]);
                        if (unroll_m >= 48)
                            vfmadd231ps(regs[0 + 16], zmm2,
                                    zword_b[BO1 + (0 - OFFSET) * SIZE]);
                    }
                }
                if (unroll_n >= i + 1) {
                    switch (i) {
                    case 0: prefetcht0(ptr[BO1 + (PREFETCHSIZEB - OFFSET) * SIZE]); break;
                    case 1: prefetcht0(ptr[BO1 + LDB + (PREFETCHSIZEB - OFFSET) * SIZE]); break;
                    case 2: prefetcht0(ptr[BO1 + LDB * 2 + (PREFETCHSIZEB - OFFSET) * SIZE]); break;
                    case 3: prefetcht0(ptr[BO1 + LDB3 + (PREFETCHSIZEB - OFFSET) * SIZE]); break;
                    case 4: prefetcht0(ptr[BO2 + (PREFETCHSIZEB - OFFSET) * SIZE]); break;
                    case 5: prefetcht0(ptr[BO2 + LDB + (PREFETCHSIZEB - OFFSET) * SIZE]); break;
                    case 6: prefetcht0(ptr[BO2 + LDB * 2 + (PREFETCHSIZEB - OFFSET) * SIZE]); break;
                    case 7: prefetcht0(ptr[BO2 + LDB3 + (PREFETCHSIZEB - OFFSET) * SIZE]); break;
                    }
                }
                if (ver == ver_avx512_core) {
                    if (!isTransB) {
                        vbroadcastss(zmm3,
                                ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
                    } else {
                        vbroadcastss(zmm3, ptr[BO1 + (1 - OFFSET) * SIZE]);
                    }
                    vfmadd231ps(regs[1], zmm3, zmm0);
                    if (unroll_m >= 32)
                        vfmadd231ps(regs[1 + 8], zmm3, zmm1);
                    if (unroll_m >= 48)
                        vfmadd231ps(regs[1 + 16], zmm3, zmm2);
                } else {
                    if (!isTransB) {
                        vfmadd231ps(regs[1], zmm0,
                                zword_b[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
                        if (unroll_m >= 32)
                            vfmadd231ps(regs[1 + 8], zmm1,
                                    zword_b[BO1 + LDB * 1
                                            + (i - OFFSET) * SIZE]);
                        if (unroll_m >= 48)
                            vfmadd231ps(regs[1 + 16], zmm2,
                                    zword_b[BO1 + LDB * 1
                                            + (i - OFFSET) * SIZE]);
                    } else {
                        vfmadd231ps(regs[1], zmm0,
                                zword_b[BO1 + (1 - OFFSET) * SIZE]);
                        if (unroll_m >= 32)
                            vfmadd231ps(regs[1 + 8], zmm1,
                                    zword_b[BO1 + (1 - OFFSET) * SIZE]);
                        if (unroll_m >= 48)
                            vfmadd231ps(regs[1 + 16], zmm2,
                                    zword_b[BO1 + (1 - OFFSET) * SIZE]);
                    }
                }
                if (isCopy) {
                    if (isUnmasked || unroll_m > 16) {
                        vmovups(ptr[LDA4 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE], zmm0);
                    } else {
                        vmovups(ptr[LDA4 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE] | k1, zmm0);
                    }
                    if (unroll_m >= 32) {
                        if (isUnmasked || unroll_m > 32)
                            vmovups(ptr[LDA4 + (unroll_m * i + 1 * 16 - OFFSET) * SIZE], zmm1);
                        else
                            vmovups(ptr[LDA4 + (unroll_m * i + 1 * 16 - OFFSET) * SIZE] | k2, zmm1);
                    }
                    if (unroll_m >= 48) {
                        if (isUnmasked)
                            vmovups(ptr[LDA4 + (unroll_m * i + 2 * 16 - OFFSET) * SIZE], zmm2);
                        else
                            vmovups(ptr[LDA4 + (unroll_m * i + 2 * 16 - OFFSET) * SIZE] | k3, zmm2);
                    }
                    if (i == 7)
                        sub(LDA4, -unroll_m * 8 * SIZE);
                }
                fmaloop(unroll_m, unroll_n, i);
                if (doCPrefetch) {
                    if (ver == ver_avx512_core)
                        prefetchw(ptr[CO2 + 0 * 16 * SIZE]);
                    else
                        prefetcht0(ptr[CO2 + 0 * 16 * SIZE]);
                }
                if (doCPrefetch && unroll_m >= 32) {
                    if (ver == ver_avx512_core)
                        prefetchw(ptr[CO2 + 1 * 16 * SIZE]);
                    else
                        prefetcht0(ptr[CO2 + 1 * 16 * SIZE]);
                }
                if (ver == ver_avx512_core)
                    prefetcht0(ptr[AA + 16 * 0 * SIZE]);
                else
                    prefetcht2(ptr[AA + 16 * 0 * SIZE]);
                if (unroll_m >= 48) {
                    if (ver == ver_avx512_core)
                        prefetchw(ptr[CO2 + 2 * 16 * SIZE]);
                    else
                        prefetcht0(ptr[CO2 + 2 * 16 * SIZE]);
                }
                if (unroll_m >= 32) {
                    if (ver == ver_avx512_core)
                        prefetcht0(ptr[AA + 16 * 1 * SIZE]);
                    else
                        prefetcht2(ptr[AA + 16 * 1 * SIZE]);
                }
                prefetcht0(ptr[BO1 + BO2]);
                if (unroll_m >= 48) {
                    if (ver == ver_avx512_core)
                        prefetcht0(ptr[AA + 16 * 2 * SIZE]);
                    else
                        prefetcht2(ptr[AA + 16 * 2 * SIZE]);
                }
                lea(AA, ptr[AA + LDA]);
            }
            if (isUnmasked || unroll_m > 16) {
                vmovups(zmm0,
                        ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]);
            } else {
                vmovups(zmm0 | k1 | T_z,
                        ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]);
            }
            if (unroll_m >= 32) {
                if (isUnmasked || unroll_m > 32) {
                    vmovups(zmm1, ptr[AO1
                            + (unroll_m * 8 + 1 * 16 - OFFSET) * SIZE]);
                } else {
                    vmovups(zmm1 | k2 | T_z, ptr[AO1
                            + (unroll_m * 8 + 1 * 16 - OFFSET) * SIZE]);
                }
            }
            if (unroll_m >= 48) {
                if (isUnmasked) {
                    vmovups(zmm2, ptr[AO1
                            + (unroll_m * 8 + 2 * 16 - OFFSET) * SIZE]);
                } else {
                    vmovups(zmm2 | k3 | T_z, ptr[AO1
                            + (unroll_m * 8 + 2 * 16 - OFFSET) * SIZE]);
                }
            }
            sub(AO1, -unroll_m * 8 * SIZE);
            sub(LL, 1);
        };
        // Main kernel; does prefetching and calls innerkernel
        // After calculating results in registers, writes back to C matrix by
        // calling update
        auto kernel = [&](int unroll_m, int unroll_n, bool isDirect,
                bool isCopy, bool isUnmasked = true) {
            if (!isDirect)
                lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]);
            if (isCopy) {
                lea(LDA4, ptr[rsp + 128 + OFFSET * SIZE]);
            } else {
                auto step = ver == ver_avx512_core ? 2 : 4;
                lea(LDA4, ptr[LDA * step + (16 - 1 - OFFSET) * SIZE]);
            }
            lea(BO2, ptr[LDB * 4 + (16 / 2 - 1 - OFFSET) * SIZE]);
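            // In the non-copy case LDA4 acts as a prefetch displacement: it
            // points `step` columns (2 for avx512_core, 4 otherwise) ahead
            // of the A element being read in the direct path.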
            if (isUnmasked || unroll_m > 16) {
                vmovups(zmm0,
                        ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]);
            } else {
                vmovups(zmm0 | k1 | T_z,
                        ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]);
            }
            if (unroll_m >= 32) {
                if (isUnmasked || unroll_m > 32) {
                    vmovups(zmm1, ptr[AO1
                            + (unroll_m * 0 + 1 * 16 - OFFSET) * SIZE]);
                } else {
                    vmovups(zmm1 | k2 | T_z, ptr[AO1
                            + (unroll_m * 0 + 1 * 16 - OFFSET) * SIZE]);
                }
            }
            if (unroll_m >= 48) {
                if (isUnmasked) {
                    vmovups(zmm2, ptr[AO1
                            + (unroll_m * 0 + 2 * 16 - OFFSET) * SIZE]);
                } else {
                    vmovups(zmm2 | k3 | T_z, ptr[AO1
                            + (unroll_m * 0 + 2 * 16 - OFFSET) * SIZE]);
                }
            }
            Label kernel12, kernel13, kernel14, kernel15, kernel16, kernel18;

            sub(LL, SECOND_FETCH);
            jle(kernel13, T_NEAR);

            L(kernel12);
            innerkernel(
                    unroll_m, unroll_n, isDirect, isCopy, false, isUnmasked);
            jg(kernel12, T_NEAR);

            L(kernel13);
            lea(CO2, ptr[CO1 + (16 - 1) * SIZE]);
            add(LL, SECOND_FETCH);
            jle(kernel15, T_NEAR);

            L(kernel14);
            innerkernel(unroll_m, unroll_n, isDirect, isCopy, true, isUnmasked);
            jg(kernel14, T_NEAR);

            L(kernel15);
            jle(kernel18, T_NEAR);
            L(kernel16);
            if (isUnmasked || unroll_m > 16) {
                vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
            } else {
                vmovups(zmm0 | k1 | T_z,
                        ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
            }
            if (unroll_m >= 32) {
                if (isUnmasked || unroll_m > 32) {
                    vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
                } else {
                    vmovups(zmm1 | k2 | T_z,
                            ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
                }
            }
            if (unroll_m >= 48) {
                if (isUnmasked) {
                    vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
                } else {
                    vmovups(zmm2 | k3 | T_z,
                            ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
                }
            }
            for (int i = 0; i < unroll_n; i++) {
                if (!isTransB) {
                    switch (i) {
                    case 0: vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]); break;
                    case 1: vbroadcastss(zmm3, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]); break;
                    case 2: vbroadcastss(zmm3, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]); break;
                    case 3: vbroadcastss(zmm3, ptr[BO1 + LDB3 + (0 - OFFSET) * SIZE]); break;
                    case 4: vbroadcastss(zmm3, ptr[BO2 + (0 - OFFSET) * SIZE]); break;
                    case 5: vbroadcastss(zmm3, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]); break;
                    case 6: vbroadcastss(zmm3, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]); break;
                    case 7: vbroadcastss(zmm3, ptr[BO2 + LDB3 + (0 - OFFSET) * SIZE]); break;
                    }
                } else {
                    vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]);
                }
                vfmadd231ps(regs[i], zmm3, zmm0);
                if (unroll_m >= 32) {
                    vfmadd231ps(regs[i + 8], zmm3, zmm1);
                }
                if (unroll_m >= 48) {
                    vfmadd231ps(regs[i + 16], zmm3, zmm2);
                }
            }
            if (isCopy) {
                if (isUnmasked || unroll_m > 16)
                    vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE], zmm0);
                else
                    vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE] | k1, zmm0);
                if (unroll_m >= 32) {
                    if (isUnmasked || unroll_m > 32)
                        vmovups(ptr[LDA4 + (unroll_m * 0 + 1 * 16 - OFFSET) * SIZE], zmm1);
                    else
                        vmovups(ptr[LDA4 + (unroll_m * 0 + 1 * 16 - OFFSET) * SIZE] | k2, zmm1);
                }
                if (unroll_m >= 48) {
                    if (isUnmasked)
                        vmovups(ptr[LDA4 + (unroll_m * 0 + 2 * 16 - OFFSET) * SIZE], zmm2);
                    else
                        vmovups(ptr[LDA4 + (unroll_m * 0 + 2 * 16 - OFFSET) * SIZE] | k3, zmm2);
                }
                sub(LDA4, -unroll_m * SIZE);
            }
            if (isUnmasked || unroll_m > 16) {
                vmovups(zmm0,
                        ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]);
            } else {
                vmovups(zmm0 | k1 | T_z,
                        ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]);
            }
            if (unroll_m >= 32) {
                if (isUnmasked || unroll_m > 32) {
                    vmovups(zmm1, ptr[AO1
                            + (unroll_m * 1 + 1 * 16 - OFFSET) * SIZE]);
                } else {
                    vmovups(zmm1 | k2 | T_z, ptr[AO1
                            + (unroll_m * 1 + 1 * 16 - OFFSET) * SIZE]);
                }
            }
            if (unroll_m >= 48) {
                if (isUnmasked) {
                    vmovups(zmm2, ptr[AO1
                            + (unroll_m * 1 + 2 * 16 - OFFSET) * SIZE]);
                } else {
                    vmovups(zmm2 | k3 | T_z, ptr[AO1
                            + (unroll_m * 1 + 2 * 16 - OFFSET) * SIZE]);
                }
            }
            sub(AO1, -unroll_m * SIZE);

            add(BO1, SIZE);
            if (unroll_n >= 4) {
                add(BO2, SIZE);
            }
            sub(LL, 1);
            jg(kernel16, T_NEAR);
            L(kernel18);
            vbroadcastss(VALPHA, ALPHA);
            if (isBetaN)
                vbroadcastss(VBETA, BETA);
            // Write back the results; all beta cases need to be handled
            if (hasBias) {
                if (isUnmasked || unroll_m > 16)
                    vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]);
                else
                    vmovups(VBIAS1 | k1 | T_z, ptr[BIAS1 + 0 * SIZE]);
                if (unroll_m >= 32) {
                    if (isUnmasked || unroll_m > 32)
                        vmovups(VBIAS2, ptr[BIAS1 + 16 * SIZE]);
                    else
                        vmovups(VBIAS2 | k2 | T_z, ptr[BIAS1 + 16 * SIZE]);
                }
                if (unroll_m >= 48) {
                    if (isUnmasked)
                        vmovups(VBIAS3, ptr[BIAS1 + 32 * SIZE]);
                    else
                        vmovups(VBIAS3 | k3 | T_z, ptr[BIAS1 + 32 * SIZE]);
                }
            }
            for (int i = 0; i < unroll_n; i++) {
                bool useScale = i % 2 != 0;
                bool useCO1 = i < 2;
                if (i == 2)
                    lea(CO2, ptr[CO1 + LDC * 2]);
                if (i == 4 || i == 6)
                    lea(CO2, ptr[CO2 + LDC * 2]);
                if (hasBias)
                    vaddps(regs[i], VBIAS1, regs[i]);
                if (isUnmasked || unroll_m > 16) {
                    update(regs[i], useCO1, 0, 0, useScale);
                } else {
                    update(regs[i], useCO1, 0, 1, useScale);
                }
                if (unroll_m >= 32) {
                    if (hasBias)
                        vaddps(regs[i + 8], VBIAS2, regs[i + 8]);
                    if (isUnmasked || unroll_m > 32) {
                        update(regs[i + 8], useCO1, 16, 0, useScale);
                    } else {
                        update(regs[i + 8], useCO1, 16, 2, useScale);
                    }
                }
                if (unroll_m >= 48) {
                    if (hasBias)
                        vaddps(regs[i + 16], VBIAS3, regs[i + 16]);
                    if (isUnmasked) {
                        update(regs[i + 16], useCO1, 32, 0, useScale);
                    } else {
                        update(regs[i + 16], useCO1, 32, 3, useScale);
                    }
                }
            }
            switch (unroll_n) {
            case 1: add(CO1, LDC); break;
            case 2: lea(CO1, ptr[CO1 + LDC * 2]); break;
            case 3: lea(CO1, ptr[CO2 + LDC * 1]); break;
            case 4: lea(CO1, ptr[CO2 + LDC * 2]); break;
            case 5: lea(CO1, ptr[CO2 + LDC * 1]); break;
            case 6: lea(CO1, ptr[CO2 + LDC * 2]); break;
            case 7: lea(CO1, ptr[CO2 + LDC * 1]); break;
            case 8: lea(CO1, ptr[CO2 + LDC * 2]); break;
            }
            // Compute next address of B
            if (!isTransB) {
                lea(rax, ptr[K * SIZE]);
                switch (unroll_n) {
                case 2:
                    lea(BO1, ptr[BO1 + LDB * 2]);
                    lea(BO2, ptr[BO2 + LDB * 2]);
                    break;
                case 3:
                    lea(BO1, ptr[BO1 + LDB3]);
                    lea(BO2, ptr[BO2 + LDB3]);
                    break;
                case 4:
                case 5:
                    lea(BO1, ptr[BO1 + LDB * 4]);
                    lea(BO2, ptr[BO2 + LDB * 4]);
                    break;
                case 6:
                    lea(BO1, ptr[BO1 + LDB3 * 2]);
                    lea(BO2, ptr[BO2 + LDB3 * 2]);
                    break;
                case 7:
                case 8:
                    lea(BO1, ptr[BO1 + LDB * 8]);
                    lea(BO2, ptr[BO2 + LDB * 8]);
                    break;
                }
                sub(BO1, rax);
                sub(BO2, rax);
            } else {
                add(BO1, unroll_n * SIZE);
            }
        };
        // High-level subroutine; does packing if needed, then splits C matrix.
        // Operates on chunks of 48 rows, 8 columns at a time (handling tail
        // cases appropriately by doing 32 or 16 rows, and/or with masking,
        // and/or fewer columns).
        auto subloop = [&](int unroll_m) {
            Label l_subloop_20x[8], l_subloop_mask_20x[8];
            Label l_subloop_30x[8], l_subloop_mask_30x[8];

            Label subloop11, subloop11mask;
            Label subloop30, subloop30mask;
            Label subloop31, subloop31mask;
            Label subloop96, subloop99;
            Label subloop98, subloop98mask;
            sub(rcx, unroll_m - 16);

            if (unroll_m == 16) {
            } else if (unroll_m == 32) {
            }

            jne(subloop96, T_NEAR);

            add(C, unroll_m * SIZE);
            lea(BO2, ptr[B + LDB * 4]);
            lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]);
            jg(subloop98, T_NEAR);

            lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]);
            L(subloop98);
            // If N is too small, skip copy operation
            cmp(LL, UNROLL_N * 3);
            jle(subloop30, T_NEAR);

            // If A is not aligned to cache line
            je(subloop30, T_NEAR);

            jl(l_subloop_20x[1], T_NEAR);

            kernel(unroll_m, UNROLL_N, true, true);

            kernel(unroll_m, UNROLL_N, false, false);

            L(subloop11);
            jl(l_subloop_20x[1], T_NEAR);

            kernel(unroll_m, UNROLL_N, false, false);

            jge(subloop11, T_NEAR);
            for (int i = 1; i <= 7; i++) {
                L(l_subloop_20x[i]);
                if (i < 7) {
                    jne(l_subloop_20x[i + 1], T_NEAR);
                } else {
                    jne(subloop99, T_NEAR);
                }
                kernel(unroll_m, i, false, false);
                jmp(subloop99, T_NEAR);
            }
            L(subloop30);
            jl(l_subloop_30x[1], T_NEAR);

            L(subloop31);
            kernel(unroll_m, UNROLL_N, true, false);
            jge(subloop31, T_NEAR);

            for (int i = 1; i <= 7; i++) {
                L(l_subloop_30x[i]);
                if (i < 7) {
                    jne(l_subloop_30x[i + 1], T_NEAR);
                } else {
                    jne(subloop99, T_NEAR);
                }
                kernel(unroll_m, i, true, false);
                jmp(subloop99, T_NEAR);
            }

            jmp(subloop99, T_NEAR);
            L(subloop96);
            add(C, unroll_m * SIZE);
            lea(BO2, ptr[B + LDB * 4]);
            lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]);
            jg(subloop98mask, T_NEAR);

            lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]);
            L(subloop98mask);
            // If N is too small, skip copy operation
            cmp(LL, UNROLL_N * 3);
            jle(subloop30mask, T_NEAR);

            // If A is not aligned to cache line
            je(subloop30mask, T_NEAR);

            jl(l_subloop_mask_20x[1], T_NEAR);

            kernel(unroll_m, UNROLL_N, true, true, false);

            kernel(unroll_m, UNROLL_N, false, false, false);

            L(subloop11mask);
            jl(l_subloop_mask_20x[1], T_NEAR);

            kernel(unroll_m, UNROLL_N, false, false, false);

            jge(subloop11mask, T_NEAR);
            for (int i = 1; i <= 7; i++) {
                L(l_subloop_mask_20x[i]);
                if (i < 7) {
                    jne(l_subloop_mask_20x[i + 1], T_NEAR);
                } else {
                    jne(subloop99, T_NEAR);
                }
                kernel(unroll_m, i, false, false, false);
                jmp(subloop99, T_NEAR);
            }
            L(subloop30mask);
            jl(l_subloop_mask_30x[1], T_NEAR);

            L(subloop31mask);
            kernel(unroll_m, UNROLL_N, true, false, false);
            jge(subloop31mask, T_NEAR);

            for (int i = 1; i <= 7; i++) {
                L(l_subloop_mask_30x[i]);
                if (i < 7) {
                    jne(l_subloop_mask_30x[i + 1], T_NEAR);
                } else {
                    jne(subloop99, T_NEAR);
                }
                kernel(unroll_m, i, true, false, false);
                jmp(subloop99, T_NEAR);
            }

            jmp(subloop99, T_NEAR);
            L(subloop99);
            // Compute address for A
            if (!isTransA)
                add(A, unroll_m * SIZE);
            else {
                mov(rax, LDA);
                imul(rax, rax, unroll_m);
                add(A, rax);
            }
            // Compute next address of BIAS
            if (hasBias)
                add(BIAS, unroll_m * SIZE);
        };
        Label buffer_in_ws, buffer_allocated;

        // Get the registers
        mov(r15, ARG_BETA);
        vmovss(xmm0, ptr[ARG_ALPHA]);
        vmovss(xmm1, ptr[r15]);
        cmp(K, STACK_K_CAPACITY);
        jg(buffer_in_ws, T_NEAR);

        // Create buffer and align to 4kB page
        lea(rax, ptr[K * SIZE]);
        imul(rax, rax, 0x30);
        sub(rsp, rax);
        and_(rsp, -PAGE_4K);
        jmp(buffer_allocated, T_NEAR);

        L(buffer_in_ws);
        mov(rsp, ARG_WS);

        L(buffer_allocated);
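        // The buffer holds the packed 48-row panel of A (hence the 0x30
        // multiplier: 48 floats per K step). For K > STACK_K_CAPACITY the
        // caller-provided workspace (ARG_WS) is used instead of the stack.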
        vmovss(ALPHA, xmm0);
        vmovss(BETA, xmm1);

        sub(A, -OFFSET * SIZE);
        sub(B, -OFFSET * SIZE);

        sal(LDA, BASE_SHIFT);
        sal(LDB, BASE_SHIFT);
        sal(LDC, BASE_SHIFT);
        lea(LDB3, ptr[LDB + LDB * 2]);

        vpbroadcastq(zmm2, LDA);
        vpxorq(ZSTRIDE, ZSTRIDE, ZSTRIDE);
        for (int i = 0; i < 6; i++) {
            vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2);
            kshiftlw(k4, k4, 1);
        }
        vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2);
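        // Builds per-lane multiples of LDA in ZSTRIDE (lane i ends up
        // holding i * LDA in bytes), which the vgatherqps packing code above
        // uses as row offsets when gathering transposed A.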
        // Check A alignment and leading dimension; take copy-based path as
        // needed
        and_(rax, ver == ver_avx512_core ? 0x07 : 0x3f);

        for (int i = 8; i < 16; i++) {
            for (int j = 0; j < 3; j++) {
                vpxorq(Zmm(i + 8 * j), Zmm(i + 8 * j), Zmm(i + 8 * j));
            }
        }
        Label main0, main1, main2, main999;

        jmp(main999, T_NEAR);

        jle(main999, T_NEAR);
        L(main999);
        // Restore original stack
        mov(rsp, ORIG_SP);

        postamble();

        ker_ = reinterpret_cast<decltype(ker_)>(
                const_cast<uint8_t *>(this->getCode()));
    }
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_gemm_f32_xbyak_gemm)

    void operator()(long long int m, long long int n, long long int k,
            const float *alpha, const float *a, long long int lda,
            const float *b, long long int ldb, const float *beta, float *c,
            long long int ldc, const float *bias, float *ws)
    {
        (*ker_)(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws);
    }

private:
    void (*ker_)(long long int m, long long int n, long long int k,
            const float *alpha, const float *a, long long int lda,
            const float *b, long long int ldb, const float *beta, float *c,
            long long int ldc, const float *bias, float *ws);
};
typedef void (*ker)(long long int, long long int, long long int, float *,
        float *, long long int, float *, long long int, float *, float *,
        long long int, float *, float *);

void jit_avx512_common_gemm_f32::sgemm_nocopy_driver(const char *transa,
        const char *transb, int m, int n, int k, const float *alpha,
        const float *a, int lda, const float *b, int ldb, const float *beta,
        float *c, int ldc, const float *bias, float *ws)
{
    bool isTransA = (*transa == 'T' || *transa == 't');
    bool isTransB = (*transb == 'T' || *transb == 't');

    int Bm, sizeM, Bn, sizeN, Bk, sizeK;
    int i, j;

    if ((m <= 0) || (n <= 0))
        return;

    if ((k <= 0) || (alpha[0] == 0.)) {
        if (beta[0] == 0.) {
            for (j = 0; j < n; j++)
                for (i = 0; i < m; i++)
                    c[i + j * ldc] = 0.0;
        } else if (beta[0] != 1.) {
            for (j = 0; j < n; j++)
                for (i = 0; i < m; i++)
                    c[i + j * ldc] *= beta[0];
        }
        return;
    }
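    // Cache-blocking parameters: the driver walks BM x BK panels of A and
    // BK x BN panels of B so each kernel call runs from cache-resident
    // data; BN/BK depend on the ISA and the transposition flags.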
    int BM = 4032, BN, BK;
    if (mayiuse(avx512_core)) {
        BN = isTransA ? 384 : 64;
        BK = 384;
    } else {
        BN = isTransA ? 96 : 64;
        BK = isTransB ? 96 : 192;
        if (!isTransA && !isTransB)
            BK = 128;
    }
    const float *curA, *curB, *curBias = nullptr;
    float *curC;

    for (Bk = 0; Bk < k; Bk += sizeK) {
        sizeK = k - Bk;
        if (sizeK >= BK * 2)
            sizeK = BK;
        else if (sizeK > BK)
            sizeK = (sizeK + 1) / 2;

        for (Bm = 0; Bm < m; Bm += sizeM) {
            sizeM = m - Bm;
            if (sizeM >= BM * 2)
                sizeM = BM;
            else if (sizeM > BM + BM / 2)
                sizeM = (sizeM + 1) / 2;

            for (Bn = 0; Bn < n; Bn += sizeN) {
                sizeN = n - Bn;
                if (sizeN >= BN * 2)
                    sizeN = BN;
                else if (sizeN > BN + BN / 2)
                    sizeN = (sizeN + 1) / 2;
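                // The halving above applies when a trailing block would be
                // between 1.5x and 2x the nominal size: splitting it evenly
                // balances the last two blocks instead of leaving a tiny
                // remainder panel.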
                if (!isTransA)
                    curA = a + Bm + (size_t)Bk * lda;
                else
                    curA = a + Bk + (size_t)Bm * lda;
                if (!isTransB)
                    curB = b + Bk + (size_t)Bn * ldb;
                else
                    curB = b + Bn + (size_t)Bk * ldb;
                curC = c + Bm + (size_t)Bn * ldc;
                if (bias != nullptr) {
                    curBias = (Bk == 0) ? bias + Bm : nullptr;
                }
                if (Bk == 0) {
                    if (*beta == 0.0 && bias == nullptr)
                        (*ker_b0_)((long long int)sizeM, (long long int)sizeN,
                                (long long int)sizeK, alpha, curA,
                                (long long int)lda, curB, (long long int)ldb,
                                beta, curC, (long long int)ldc, curBias, ws);
                    else
                        (*ker_bn_)((long long int)sizeM, (long long int)sizeN,
                                (long long int)sizeK, alpha, curA,
                                (long long int)lda, curB, (long long int)ldb,
                                beta, curC, (long long int)ldc, curBias, ws);
                } else {
                    (*ker_b1_)((long long int)sizeM, (long long int)sizeN,
                            (long long int)sizeK, alpha, curA,
                            (long long int)lda, curB, (long long int)ldb, beta,
                            curC, (long long int)ldc, curBias, ws);
                }
            }
        }
    }
}
void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb,
        const int *p_m, const int *p_n, const int *p_k, const float *p_alpha,
        const float *A, const int *p_lda, const float *B, const int *p_ldb,
        const float *p_beta, float *C, const int *p_ldc, const float *bias)
{
    if (beta_ == 0. || beta_ == 1.)
        assert(*p_beta == beta_);
    assert((one_of(*transa, 'T', 't') == one_of(transa_, 'T', 't')));

    int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
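    // If we are already inside a parallel region, run single-threaded to
    // avoid oversubscription; otherwise use all available threads.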
    int m = *p_m, n = *p_n, k = *p_k;
    int lda = *p_lda, ldb = *p_ldb, ldc = *p_ldc;
    float beta = *p_beta;

    int MB, NB, KB;
    int nthr_m, nthr_n, nthr_k, nthr_mn;

    assert(nthr <= nthrs_);

    // Determine threading partitioning
    gemm_utils::calc_nthr_nocopy_avx512_common(
            m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
    assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));

    // May not happen, but just in case
    if (nthr < nthr_m * nthr_n * nthr_k)
        nthr = nthr_m * nthr_n * nthr_k;

    nthr_mn = nthr_m * nthr_n;
    unsigned char *ompstatus_ = nullptr;
    unsigned char volatile *ompstatus = nullptr;

    float *c_buffers = nullptr;
    float *ws_buffers = nullptr;

    if (nthr_k > 1) {
        ompstatus_ = (unsigned char *) malloc(
                nthr * CACHE_LINE_SIZE, CACHE_LINE_SIZE);
        ompstatus = (unsigned char volatile *) ompstatus_;
        assert(ompstatus);

        for (int i = 0; i < nthr; i++)
            ompstatus[i * CACHE_LINE_SIZE] = 0;

        c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB
                * sizeof(float), PAGE_4K);
    }
    const size_t ws_elems_per_thr = k * 48 + 64;
    const size_t ws_size_per_thr
            = utils::rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
    if (k > STACK_K_CAPACITY) {
        ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K);
    }
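    // 48 floats per K step matches the widest (48-row) packed panel of A;
    // the extra 64 elements leave slack for the 4kB alignment the kernel
    // applies to the buffer start.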
    parallel(nthr, [&](const int ithr, const int nthr) {
        int ithr_m, ithr_n, ithr_k, ithr_mn;
        int m_from, m_to, myM;
        int n_from, n_to, myN;
        int k_from, k_to, myK;
        int cbase, ibase;
        int ld = 0;

        const float *myA, *myB, *myBias = nullptr;
        float *myC = C, myBeta;
        float *ws = ws_buffers ?
                ws_buffers + ithr * ws_size_per_thr / sizeof(float) : nullptr;
        if (ithr < nthr_m * nthr_n * nthr_k) {

            ithr_mn = ithr % nthr_mn;
            ithr_m = ithr_mn % nthr_m;
            ithr_n = ithr_mn / nthr_m;
            ithr_k = ithr / nthr_mn;

            /* swap ithr_k for performance improvement */
            if (ithr_k == 0)
                ithr_k = nthr_k - 1;
            else if (ithr_k == nthr_k - 1)
                ithr_k = 0;

            m_from = MB * (ithr_m);
            m_to = MB * (ithr_m + 1);
            if (m_to > m)
                m_to = m;
            myM = m_to - m_from;

            n_from = NB * (ithr_n);
            n_to = NB * (ithr_n + 1);
            if (n_to > n)
                n_to = n;
            myN = n_to - n_from;

            k_from = KB * (ithr_k);
            k_to = KB * (ithr_k + 1);
            if (k_to > k)
                k_to = k;
            myK = k_to - k_from;

            cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
            ibase = (ithr_m + nthr_m * ithr_n) * nthr_k;
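            // cbase indexes this (m, n) block's span of temporary C buffers
            // (nthr_k - 1 of them); ibase indexes its span of nthr_k
            // completion flags in ompstatus.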
            if ((myM > 0) && (myN > 0)) {
                if (*transa == 'N' || *transa == 'n')
                    myA = &(A[m_from + k_from * lda]);
                else
                    myA = &(A[k_from + m_from * lda]);
                if (*transb == 'N' || *transb == 'n')
                    myB = &(B[k_from + n_from * ldb]);
                else
                    myB = &(B[n_from + k_from * ldb]);

                if (ithr_k == 0) {
                    myC = &(C[m_from + n_from * ldc]);
                    myBeta = beta;
                    ld = ldc;
                    if (hasBias_)
                        myBias = &(bias[m_from]);
                } else {
                    myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
                    myBeta = 0.0;
                    ld = MB;
                }

                sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha,
                        myA, lda, myB, ldb, &myBeta, myC, ld, myBias, ws);

                if (nthr_k > 1)
                    ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1;
            }
            // sum matrices partitioned along K dimension
            if ((nthr_k > 1) && (myM > 0) && (myN > 0)) {
                int n1, n2;
                gemm_utils::partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);

                if (ithr_k > 0) {
                    myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
                    myC = myC + n1 * MB;
                    /* need to wait until main thread finishes */
                    while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) {
                    };
                    /* my cache is hot */
                    gemm_utils::sum_two_matrices(myM, n2, myC, MB,
                            &C[m_from + (n_from + n1) * ldc], ldc);
                }

                for (int ik = 1; ik < nthr_k; ++ik) {
                    if (ik != ithr_k) {
                        myC = c_buffers + MB * NB * (cbase + ik - 1);
                        myC = myC + n1 * MB;
                        while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) {
                        };
                        gemm_utils::sum_two_matrices(myM, n2, myC, MB,
                                &C[m_from + (n_from + n1) * ldc], ldc);
                    }
                }
            }
        }
    });

    free(ws_buffers);
    free(c_buffers);
    free(ompstatus_);
}
jit_avx512_common_gemm_f32::jit_avx512_common_gemm_f32(
        char transa, char transb, float beta, bool hasBias)
    : transa_(transa), transb_(transb), beta_(beta), hasBias_(hasBias)
{
    if (hasBias)
        assert(beta == 0.0);
    ker_bn_ = new xbyak_gemm(transa, transb, beta, hasBias);
    ker_b1_ = new xbyak_gemm(transa, transb, 1.0);
    if (beta != 0.0 || (beta == 0.0 && hasBias))
        ker_b0_ = new xbyak_gemm(transa, transb, 0.0);
    else
        ker_b0_ = ker_bn_;
    nthrs_ = mkldnn_get_max_threads();
}
jit_avx512_common_gemm_f32::~jit_avx512_common_gemm_f32()
{
    delete ker_bn_;
    delete ker_b1_;
    if (beta_ != 0.0 || (beta_ == 0.0 && hasBias_))
        delete ker_b0_;
}

}
}
}

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s