1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef CPU_JIT_AVX2_GENERATOR_HPP
18 #define CPU_JIT_AVX2_GENERATOR_HPP
#include <cstring>

21 #include "cpu_isa_traits.hpp"
24 #include "mkldnn_thread.hpp"
26 #ifdef JIT_PROFILING_VTUNE
27 #include "jitprofiling.h"
// Alignment attribute spelled per-toolchain: MSVC uses __declspec(align),
// GCC/Clang use __attribute__((aligned)).
// NOTE(review): an #else separating the two STRUCT_ALIGN definitions is
// expected but not visible in this excerpt -- verify against the full file.
30 #if defined(_WIN32) && !defined(__GNUC__)
31 # define STRUCT_ALIGN(al, ...) __declspec(align(al)) __VA_ARGS__
33 # define STRUCT_ALIGN(al, ...) __VA_ARGS__ __attribute__((__aligned__(al)))
// 0x28 = 40 bytes; presumably the Win64 32-byte shadow space plus the 8-byte
// return address -- confirm against the prologue code that uses this offset.
37 # define OFFSET_SHADOWSPACE 0x28
// Injects the identification hooks every JIT kernel class must provide:
// name() and source_file(), used by the code-dump and VTune-profiling paths
// of jit_generator below.
// NOTE(review): the closing brace of source_file() is not visible in this
// excerpt -- verify the macro is complete in the full file.
40 #define DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_name) \
41 const char *name() const override { return STRINGIFY(jit_name); } \
42 const char *source_file() const override { return __FILE__; \
49 // TODO: move this to jit_generator class?
57 // TODO: move this somewhere else? Although this is only used by jit kernels
// Reinterprets the bit pattern of a 32-bit float as a signed int (bit cast).
// JIT kernels use this to embed float constants as integer immediates.
// memcpy is the well-defined way to type-pun in C++ (union punning is
// formally UB); compilers lower this to a single register move.
static inline int float2int(float x) {
    static_assert(sizeof(int) == sizeof(float), "int/float size mismatch");
    int y;
    std::memcpy(&y, &x, sizeof(x));
    return y;
}
68 // TODO: A GPR class that hides ABI details from the JIT kernels and allows
69 // numbering registers from 0 to 14 (x86_64) / 6 (x32) (gpr0, gpr1, ...) and
70 // stack register (sr).
72 // This will allow using syntax like this:
80 // mov(param, ptr[sr])
// Callee-saved GPRs that every JIT kernel preserves in its preamble.
// RDI/RSI are callee-saved only on Windows; the surrounding #ifdef _WIN32
// guard is not visible in this excerpt -- verify against the full file.
86 constexpr Xbyak::Operand::Code abi_save_gpr_regs[] = {
87 Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12,
88 Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15,
90 Xbyak::Operand::RDI, Xbyak::Operand::RSI,
// Windows x64 calling convention: first four integer args in RCX/RDX/R8/R9.
// abi_not_param1 is a scratch register guaranteed not to alias abi_param1.
95 static const Xbyak::Reg64 abi_param1(Xbyak::Operand::RCX),
96 abi_param2(Xbyak::Operand::RDX),
97 abi_param3(Xbyak::Operand::R8),
98 abi_param4(Xbyak::Operand::R9),
99 abi_not_param1(Xbyak::Operand::RDI);
// System V AMD64 ABI (Linux/macOS): six integer args RDI/RSI/RDX/RCX/R8/R9.
101 static const Xbyak::Reg64 abi_param1(Xbyak::Operand::RDI),
102 abi_param2(Xbyak::Operand::RSI),
103 abi_param3(Xbyak::Operand::RDX),
104 abi_param4(Xbyak::Operand::RCX),
105 abi_param5(Xbyak::Operand::R8),
106 abi_param6(Xbyak::Operand::R9),
107 abi_not_param1(Xbyak::Operand::RCX);
// Returns the data-cache size in bytes for the given 1-based cache level.
// per_core=true divides shared caches by the number of cores sharing them.
111 inline unsigned int get_cache_size(int level, bool per_core = true){
// NOTE(review): 'l' is unsigned, so level == 0 wraps to UINT_MAX and skips
// every branch below -- callers are expected to pass level >= 1.
112 unsigned int l = level - 1;
113 // Currently, if XByak is not able to fetch the cache topology
114 // we default to 32KB of L1, 512KB of L2 and 1MB of L3 per core.
// (The fallback constants are decimal approximations, not powers of two.)
115 if (cpu.getDataCacheLevels() == 0){
116 const int L1_cache_per_core = 32000;
117 const int L2_cache_per_core = 512000;
118 const int L3_cache_per_core = 1024000;
119 int num_cores = per_core ? 1 : mkldnn_get_max_threads();
// (switch over 'l'; the switch header itself is elided in this excerpt)
121 case(0): return L1_cache_per_core * num_cores;
122 case(1): return L2_cache_per_core * num_cores;
123 case(2): return L3_cache_per_core * num_cores;
// Topology is available: query Xbyak's CPUID-derived cache info directly.
127 if (l < cpu.getDataCacheLevels()) {
128 return cpu.getDataCacheSize(l)
129 / (per_core ? cpu.getCoresSharingDataCache(l) : 1);
// Base class for all JIT kernels: wraps Xbyak's code generator and adds ABI
// prologue/epilogue helpers plus ISA-dispatching "uni_" instruction wrappers
// (one call site, SSE vs AVX vs AVX-512 encoding chosen at generation time).
136 class jit_generator : public Xbyak::CodeGenerator
139 const size_t xmm_len = 16;
// Windows x64: xmm6..xmm15 are callee-saved and must be spilled in the
// preamble (the surrounding #ifdef is not visible in this excerpt).
141 const size_t xmm_to_preserve_start = 6;
142 const size_t xmm_to_preserve = 10;
// System V: all xmm registers are caller-saved -- nothing to preserve.
144 const size_t xmm_to_preserve_start = 0;
145 const size_t xmm_to_preserve = 0;
148 const size_t num_abi_save_gpr_regs
149 = sizeof(abi_save_gpr_regs) / sizeof(abi_save_gpr_regs[0]);
// Total bytes the preamble stores: pushed GPRs + spilled xmm registers.
151 const size_t size_of_abi_save_regs
152 = num_abi_save_gpr_regs * rax.getBit() / 8
153 + xmm_to_preserve * xmm_len;
167 Xbyak::Reg64 param1 = abi_param1;
// EVEX disp8*N compression support: rbp is kept loaded with
// 2 * EVEX_max_8b_offt so large displacements can be re-biased into the
// compressible window (see EVEX_compress_addr below).
168 const int EVEX_max_8b_offt = 0x200;
169 const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp;
// Bytes of stack the preamble consumes for saved registers; kernels use it
// to locate their stack arguments.
171 inline size_t get_size_of_abi_save_regs() {
172 return size_of_abi_save_regs;
// Function prologue: spill callee-saved xmm (Windows only), push callee-
// saved GPRs, and pre-load the EVEX offset-compression base for AVX-512.
176 if (xmm_to_preserve) {
177 sub(rsp, xmm_to_preserve * xmm_len);
178 for (size_t i = 0; i < xmm_to_preserve; ++i)
179 movdqu(ptr[rsp + i * xmm_len], Xbyak::Xmm(xmm_to_preserve_start + i));
181 for (size_t i = 0; i < num_abi_save_gpr_regs; ++i)
182 push(Xbyak::Reg64(abi_save_gpr_regs[i]));
183 if (mayiuse(avx512_common)) {
184 mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
// Software prefetches emitted only on Xeon Phi (avx512_mic); no-ops on
// other targets where they are not profitable.
188 void mic_prefetcht0(Xbyak::Address a) {
189 if (mayiuse(avx512_mic))
193 void mic_prefetcht1(Xbyak::Address a) {
194 if (mayiuse(avx512_mic))
198 void mic_prefetcht2(Xbyak::Address a) {
199 if (mayiuse(avx512_mic))
// vzeroupper avoids AVX<->SSE transition penalties; skipped on KNL where
// the instruction is expensive and the penalty does not apply.
203 void uni_vzeroupper() {
204 if (mayiuse(avx) && !mayiuse(avx512_mic))
// Function epilogue: pop GPRs in reverse push order, then reload any
// spilled xmm registers and release their stack slot.
209 for (size_t i = 0; i < num_abi_save_gpr_regs; ++i)
210 pop(Xbyak::Reg64(abi_save_gpr_regs[num_abi_save_gpr_regs - 1 - i]));
211 if (xmm_to_preserve) {
212 for (size_t i = 0; i < xmm_to_preserve; ++i)
213 movdqu(Xbyak::Xmm(xmm_to_preserve_start + i), ptr[rsp + i * xmm_len]);
214 add(rsp, xmm_to_preserve * xmm_len);
// Builds an address for EVEX-encoded accesses, re-biasing medium-sized
// immediate offsets against reg_EVEX_max_8b_offt (holds 2*0x200 at runtime)
// so they fit the EVEX disp8*N compressed-displacement form.
// bcast=true selects the {1toN} broadcast addressing (ptr_b).
221 Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base,
222 T raw_offt, bool bcast = false)
226 using Xbyak::Address;
// Offsets above INT_MAX must go through EVEX_compress_addr_safe instead.
229 assert(raw_offt <= INT_MAX);
230 auto offt = static_cast<int>(raw_offt);
// Offsets in [0x200, 0x600) are re-expressed relative to +2*0x200 (scale 1);
// offsets in [0x600, 0xA00) relative to +4*0x200 (scale 2).
234 if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) {
235 offt = offt - 2 * EVEX_max_8b_offt;
237 } else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) {
238 offt = offt - 4 * EVEX_max_8b_offt;
242 auto re = RegExp() + base + offt;
244 re = re + reg_EVEX_max_8b_offt * scale;
252 Xbyak::Address make_safe_addr(const Xbyak::Reg64 ®_out, size_t offt,
253 const Xbyak::Reg64 &tmp_reg, bool bcast = false) {
254 if (offt > INT_MAX) {
256 return bcast ? ptr_b[reg_out + tmp_reg] : ptr[reg_out + tmp_reg];
258 return bcast ? ptr_b[reg_out + offt] : ptr[reg_out + offt];
262 Xbyak::Address EVEX_compress_addr_safe(const Xbyak::Reg64 &base,
263 size_t raw_offt, const Xbyak::Reg64 ®_offt, bool bcast = false) {
264 if (raw_offt > INT_MAX) {
265 return make_safe_addr(base, raw_offt, reg_offt, bcast);
267 return EVEX_compress_addr(base, raw_offt, bcast);
271 void safe_add(const Xbyak::Reg64 &base, size_t raw_offt,
272 const Xbyak::Reg64 ®_offt) {
273 if (raw_offt > INT_MAX) {
274 mov(reg_offt, raw_offt);
281 void safe_sub(const Xbyak::Reg64 &base, size_t raw_offt,
282 const Xbyak::Reg64 ®_offt) {
283 if (raw_offt > INT_MAX) {
284 mov(reg_offt, raw_offt);
291 // Disallow char-based labels completely
292 void L(const char *label) = delete;
// Only Xbyak::Label objects may be bound; forwards to the base generator.
293 void L(Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); }
// Aligns the code stream (default 16 bytes) before binding the label.
295 void L_aligned(Xbyak::Label &label, int alignment = 16) {
// ---- "uni_" wrappers: one entry point per operation that emits either the
// destructive two-operand SSE encoding (hence the dst==src1 asserts) or the
// three-operand AVX/AVX-512 encoding, chosen by the runtime ISA.
// Method bodies are partially elided in this excerpt.
300 void uni_vpxor(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
301 const Xbyak::Operand &op) {
302 assert(x1.getIdx() == x2.getIdx());
305 void uni_vpxor(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
306 const Xbyak::Operand &op) {
313 void uni_vpxor(const Xbyak::Zmm &x1, const Xbyak::Zmm &x2,
314 const Xbyak::Operand &op) {
// Scalar single-precision load/store (store overloads, then loads).
318 void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Xmm &x) {
321 void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Ymm &x) {
324 void uni_vmovss(const Xbyak::Xmm &x, const Xbyak::Address& addr) {
327 void uni_vmovss(const Xbyak::Ymm &x, const Xbyak::Address& addr) {
// Scalar double-precision load/store.
331 void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Xmm &x) {
334 void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Ymm &x) {
337 void uni_vmovsd(const Xbyak::Xmm &x, const Xbyak::Address& addr) {
340 void uni_vmovsd(const Xbyak::Ymm &x, const Xbyak::Address& addr) {
// Unaligned integer vector load/store.
344 void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
347 void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
350 void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Zmm &x) {
354 void uni_vmovdqu(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
357 void uni_vmovdqu(const Xbyak::Ymm &x, const Xbyak::Address &addr) {
360 void uni_vmovdqu(const Xbyak::Zmm &x, const Xbyak::Address &addr) {
// Unaligned packed-float load/store.
364 void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
367 void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
371 void uni_vmovups(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
374 void uni_vmovups(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
// Non-temporal (cache-bypassing, streaming) packed-float store.
378 void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
381 void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
385 void uni_vbroadcastss(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
389 void uni_vbroadcastss(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
// AVX1 vbroadcastss accepts only memory sources; for a register source
// without AVX2, duplicate the low 128 bits into the high half manually.
390 if (op.isMEM() || mayiuse(avx2)) {
393 Xbyak::Xmm t(x.getIdx());
394 if (t.getIdx() != op.getIdx()) movss(t, op);
395 vinsertf128(x, x, t, 1);
400 void uni_vpbroadcastd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
404 void uni_vpbroadcastd(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
// Same AVX1 register-source fallback for the integer broadcast.
408 Xbyak::Xmm t(x.getIdx());
409 if (t.getIdx() != op.getIdx()) movsd(t, op);
410 vinsertf128(x, x, t, 1);
// Approximate reciprocal: scalar (rcpss) and packed (rcpps) families.
// Scalar Ymm overloads operate on the low lane via Xmm aliases.
415 void uni_vrcpss(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
418 void uni_vrcpss(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2) {
419 Xbyak::Xmm x1_(x1.getIdx());
420 Xbyak::Xmm x2_(x2.getIdx());
421 vrcpss(x1_, x1_, x2_);
423 void uni_vrcpss(const Xbyak::Ymm &x, const Xbyak::Address &op) {
424 Xbyak::Xmm x_(x.getIdx());
428 void uni_vrcpps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
431 void uni_vrcpps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
434 void uni_vrcpps(const Xbyak::Zmm &x, const Xbyak::Operand &op) {
// Packed/scalar float arithmetic. Two-operand SSE forms are destructive
// (dst must alias op1, asserted); the 'buf' overloads provide a
// non-destructive variant by staging through a scratch register.
438 void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
439 const Xbyak::Operand &op2 = Xbyak::Operand()) {
440 assert(x.getIdx() == op1.getIdx());
443 void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
444 const Xbyak::Operand &op2 = Xbyak::Operand()) {
448 void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
449 const Xbyak::Operand &op2, const Xbyak::Xmm &buf) {
452 if (x.getIdx() != buf.getIdx()) {
457 void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
458 const Xbyak::Operand &op2, const Xbyak::Ymm &buf) {
462 void uni_vaddps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
463 const Xbyak::Operand &op2 = Xbyak::Operand()) {
464 assert(x.getIdx() == op1.getIdx());
467 void uni_vaddps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
468 const Xbyak::Operand &op2 = Xbyak::Operand()) {
471 void uni_vaddss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
472 const Xbyak::Operand &op2 = Xbyak::Operand()) {
473 assert(x.getIdx() == op1.getIdx());
476 void uni_vaddss(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
477 const Xbyak::Operand &op2 = Xbyak::Operand()) {
481 void uni_vpsignd(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2,
482 const Xbyak::Operand& op) {
483 assert(x1.getIdx() == x2.getIdx());
486 void uni_vpsignd(const Xbyak::Ymm& x1, const Xbyak::Ymm& x2,
487 const Xbyak::Operand& op) {
491 void uni_vsubss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
492 const Xbyak::Operand &op2 = Xbyak::Operand()) {
493 assert(x.getIdx() == op1.getIdx());
496 void uni_vsubss(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
497 const Xbyak::Operand &op2 = Xbyak::Operand()) {
// Scalar op: only the low lane matters, so operands are re-typed as Xmm.
498 vsubss(x, Xbyak::Xmm(op1.getIdx()), Xbyak::Xmm(op2.getIdx()));
501 void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
502 const Xbyak::Operand &op2 = Xbyak::Operand()) {
503 assert(x.getIdx() == op1.getIdx());
506 void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
507 const Xbyak::Operand &op2 = Xbyak::Operand()) {
511 void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
512 const Xbyak::Operand &op2, const Xbyak::Xmm &buf) {
515 if (x.getIdx() != buf.getIdx()) {
520 void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
521 const Xbyak::Operand &op2, const Xbyak::Ymm &buf) {
525 void uni_vmulps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
526 const Xbyak::Operand &op2 = Xbyak::Operand()) {
527 assert(x.getIdx() == op1.getIdx());
530 void uni_vmulps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
531 const Xbyak::Operand &op2 = Xbyak::Operand()) {
535 void uni_vmulss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
536 const Xbyak::Operand &op2 = Xbyak::Operand()) {
537 assert(x.getIdx() == op1.getIdx());
540 void uni_vmulss(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
541 const Xbyak::Address &op2) {
542 vmulss(x, Xbyak::Xmm(op1.getIdx()), op2);
544 void uni_vmulss(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
545 const Xbyak::Ymm &op2) {
546 vmulss(x, Xbyak::Xmm(op1.getIdx()), Xbyak::Xmm(op2.getIdx()));
// Fused multiply-add wrappers. Operand-order suffixes: 213 => x1 = x2*x1+op;
// 231 => x1 = x2*op + x1; fnmadd231 => x1 = -(x2*op) + x1. The SSE
// fallbacks (elided in this excerpt) emulate these with mul+add sequences.
549 void uni_vfmadd213ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
550 const Xbyak::Operand &op) {
554 void uni_vfmadd213ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
555 const Xbyak::Operand &op) {
556 vfmadd213ps(x1, x2, op);
559 void uni_vfmadd213ss(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
560 const Xbyak::Operand &op) {
564 void uni_vfmadd213ss(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
565 const Xbyak::Operand &op) {
566 vfmadd213ss(x1, x2, op);
569 void uni_vfmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
570 const Xbyak::Operand &op) {
574 void uni_vfmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
575 const Xbyak::Operand &op) {
576 vfmadd231ps(x1, x2, op);
578 void uni_vfmadd231ss(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
579 const Xbyak::Operand &op) {
583 void uni_vfmadd231ss(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
584 const Xbyak::Operand &op) {
// Scalar: only the low lane is relevant, so operands are re-typed as Xmm.
585 vfmadd231ss(Xbyak::Xmm(x1.getIdx()), Xbyak::Xmm(x2.getIdx()), op);
588 void uni_vfnmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
589 const Xbyak::Operand &op) {
594 void uni_vfnmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
595 const Xbyak::Operand &op) {
596 vfnmadd231ps(x1, x2, op);
// Square root, integer add, bitwise and/or, immediate shifts, float
// min/max, and ordered compares -- same SSE-destructive / AVX 3-operand
// pattern as above.
599 void uni_vsqrtps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
602 void uni_vsqrtps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
606 void uni_vpaddd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
607 const Xbyak::Operand &op) {
608 assert(x1.getIdx() == x2.getIdx());
// NOTE(review): x2 is typed Xmm while x1 is Ymm -- looks like a typo for
// Xbyak::Ymm; verify against callers before changing.
611 void uni_vpaddd(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2,
612 const Xbyak::Operand &op) {
// andps/orps: plain [v]andps/[v]orps up to 256 bits; the 512-bit case needs
// the AVX-512 DQ forms (handled in the elided else branch).
616 void uni_vandps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
617 const Xbyak::Operand &op = Xbyak::Operand()) {
618 assert(x1.getIdx() == x2.getIdx());
621 void uni_vandps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
622 const Xbyak::Operand &op = Xbyak::Operand()) {
623 if (!mayiuse(avx512_common) || x1.getBit() < 512)
629 void uni_vorps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
630 const Xbyak::Operand &op = Xbyak::Operand()) {
631 assert(x1.getIdx() == x2.getIdx());
634 void uni_vorps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
635 const Xbyak::Operand &op = Xbyak::Operand()) {
636 if (!mayiuse(avx512_common) || x1.getBit() < 512)
// Immediate logical shifts on 32-bit lanes.
642 void uni_vpslld(const Xbyak::Xmm &x, const Xbyak::Operand &op,
644 assert(x.getIdx() == op.getIdx());
647 void uni_vpslld(const Xbyak::Ymm &x, const Xbyak::Operand &op,
652 void uni_vpsrld(const Xbyak::Xmm &x, const Xbyak::Operand &op,
654 assert(x.getIdx() == op.getIdx());
657 void uni_vpsrld(const Xbyak::Ymm &x, const Xbyak::Operand &op,
662 void uni_vmaxps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
663 const Xbyak::Operand &op2 = Xbyak::Operand()) {
664 assert(x.getIdx() == op1.getIdx());
667 void uni_vmaxps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
668 const Xbyak::Operand &op2 = Xbyak::Operand()) {
672 void uni_vminps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
673 const Xbyak::Operand &op2 = Xbyak::Operand()) {
674 assert(x.getIdx() == op1.getIdx());
677 void uni_vminps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
678 const Xbyak::Operand &op2 = Xbyak::Operand()) {
// Compares expressed via cmpps predicates: gt == not-less-or-equal
// (_cmp_nle_us), ge == not-less-than (_cmp_nlt_us).
682 void uni_vcmpgtps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
683 const Xbyak::Operand &op) {
684 assert(x1.getIdx() == x2.getIdx());
685 cmpps(x1, op, _cmp_nle_us);
688 void uni_vcmpgtps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
689 const Xbyak::Operand &op) {
690 vcmpgtps(x1, x2, op);
693 void uni_vcmpgeps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
694 const Xbyak::Operand &op) {
695 assert(x1.getIdx() == x2.getIdx());
696 cmpps(x1, op, _cmp_nlt_us);
699 void uni_vcmpgeps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
700 const Xbyak::Operand &op) {
701 vcmpps(x1, x2, op, _cmp_nlt_us);
// Mask tests, blends, rounding, int<->float conversion, and movemask.
704 void uni_vtestps(const Xbyak::Xmm &x1, const Xbyak::Operand &op) {
708 void uni_vtestps(const Xbyak::Ymm &x1, const Xbyak::Operand &op) {
// vtestps has no 512-bit encoding; ZMM arguments are rejected here.
709 assert(!(x1.isZMM() || op.isZMM()));
713 void uni_vblendvps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
714 const Xbyak::Operand &op, const Xbyak::Xmm &msk) {
715 assert(x1.getIdx() == x2.getIdx());
// SSE4.1 blendvps takes its mask implicitly in xmm0.
716 assert(msk.getIdx() == 0);
719 void uni_vblendvps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
720 const Xbyak::Operand &op, const Xbyak::Ymm &msk) {
721 vblendvps(x1, x2, op, msk);
724 void uni_vroundps(const Xbyak::Xmm &x, const Xbyak::Operand &op,
728 void uni_vroundps(const Xbyak::Ymm &x, const Xbyak::Operand &op,
730 vroundps(x, op, imm);
732 void uni_vroundps(const Xbyak::Zmm &x, const Xbyak::Operand &op,
// AVX-512 has no vroundps; vrndscaleps with the low two bits of imm
// reproduces the SSE/AVX rounding-mode selection.
734 vrndscaleps(x, op, imm & 0x3);
737 void uni_vcvtps2dq(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
740 void uni_vcvtps2dq(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
744 void uni_vcvtdq2ps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
747 void uni_vcvtdq2ps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
// Extracts the per-lane sign bits into a general-purpose register.
751 void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Xmm &x2) {
752 movmskps(x1.cvt64(), x2);
754 void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Ymm &x2) {
758 void uni_vpackssdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){
759 assert(x1.getIdx() == x1.getIdx());
762 void uni_vpackssdw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){
763 vpackssdw(x1, x2, op);
766 void uni_vpackuswb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){
767 assert(x1.getIdx() == x1.getIdx());
770 void uni_vpackuswb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){
771 vpackuswb(x1, x2, op);
// Byte-to-dword widening (sign-/zero-extending) and more saturating packs.
774 void uni_vpmovsxbd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
777 void uni_vpmovsxbd(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
781 void uni_vpmovzxbd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
784 void uni_vpmovzxbd(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
// SSE forms destructive (dst must alias src1, asserted); AVX 3-operand.
788 void uni_vpackusdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
789 assert(x1.getIdx() == x2.getIdx());
792 void uni_vpackusdw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
793 vpackusdw(x1, x2, op);
796 void uni_vpacksswb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
797 assert(x1.getIdx() == x2.getIdx());
800 void uni_vpacksswb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
801 vpacksswb(x1, x2, op);
// Integer max, multiply-add, multiply, subtract, byte shifts, bitwise ops,
// shuffles and equality compares -- same SSE-destructive / AVX 3-operand
// split as the rest of the uni_ family.
804 void uni_vpmaxsd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
805 assert(x1.getIdx() == x2.getIdx());
808 void uni_vpmaxsd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
812 void uni_vpmaxsb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
813 assert(x1.getIdx() == x2.getIdx());
816 void uni_vpmaxsb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
820 void uni_vpmaxub(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
821 assert(x1.getIdx() == x2.getIdx());
824 void uni_vpmaxub(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
828 void uni_vpmaddubsw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
829 assert(x1.getIdx() == x2.getIdx());
832 void uni_vpmaddubsw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
833 vpmaddubsw(x1, x2, op);
836 void uni_vpmaddwd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
837 assert(x1.getIdx() == x2.getIdx());
840 void uni_vpmaddwd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
841 vpmaddwd(x1, x2, op);
844 void uni_vpmulld(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
845 assert(x1.getIdx() == x2.getIdx());
848 void uni_vpmulld(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
852 void uni_vpsubb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
853 assert(x1.getIdx() == x2.getIdx());
856 void uni_vpsubb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
// Whole-register byte shift; immediate count.
860 void uni_vpslldq(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::uint8 &op) {
861 assert(x1.getIdx() == x2.getIdx());
864 void uni_vpslldq(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::uint8 &op) {
868 void uni_vpand(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
869 const Xbyak::Operand &op = Xbyak::Operand()) {
870 assert(x1.getIdx() == x2.getIdx());
873 void uni_vpand(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
874 const Xbyak::Operand &op = Xbyak::Operand()) {
878 void uni_vpaddb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
879 const Xbyak::Operand &op) {
880 assert(x1.getIdx() == x2.getIdx());
// NOTE(review): x2 is typed Xmm while x1 is Ymm -- likely a typo for
// Xbyak::Ymm (same pattern as uni_vpaddd above); verify against callers.
883 void uni_vpaddb(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2,
884 const Xbyak::Operand &op) {
888 void uni_vpshufb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
889 const Xbyak::Operand &op) {
890 assert(x1.getIdx() == x2.getIdx());
894 void uni_vpshufb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
895 const Xbyak::Operand &op) {
899 void uni_vpcmpeqd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
900 const Xbyak::Operand &op) {
901 assert(x1.getIdx() == x2.getIdx());
905 void uni_vpcmpeqd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
906 const Xbyak::Operand &op) {
907 vpcmpeqd(x1, x2, op);
// Multiplies 'out' by the JIT-time constant 'value' using a shift+add
// sequence derived from the set bits of 'value'. Clobbers 'tmp'.
910 void mul_by_const(const Xbyak::Reg &out,
911 const Xbyak::Reg64 &tmp, int value) {
// Pros compared to mul/imul:
// - does not require using fixed, known registers
// - not microcoded on Intel(R) Xeon Phi(TM) processors
// Still, there are probably a lot of cases when mul/imul is faster on
// Intel(R) Core(TM) processors. Not intended for the critical path.
//
// TODO: detect when overflow is imminent (Roma)
// TODO: detect when using mul/imul is a better option (Roma)
924 int p = 0; // the current power of 2
925 int old_p = 0; // the last seen power of 2 such that value[old_p] != 0
// Per set bit: shift the accumulated result left by the gap to the next
// set bit, then add (bit-scan loop elided in this excerpt).
930 int shift = p - old_p;
// Writes the generated machine code to mkldnn_dump_<name>.<counter>.bin in
// the working directory, for offline disassembly. Best-effort: failures
// are deliberately ignored.
943 void dump_code(const Xbyak::uint8 *code) const {
// NOTE(review): mutable function-local static -- not thread-safe if kernels
// are generated concurrently; verify callers serialize code generation.
945 static int counter = 0;
946 #define MAX_FNAME_LEN 256
947 char fname[MAX_FNAME_LEN + 1];
948 snprintf(fname, MAX_FNAME_LEN, "mkldnn_dump_%s.%d.bin", name(),
952 FILE *fp = mkldnn_fopen(fname, "w+");
953 // Failure to dump code is not fatal
// NOTE(review): a null-check on fp is presumably in the elided lines;
// fwrite on a null FILE* would be undefined -- confirm in the full file.
955 size_t unused = fwrite(code, getSize(), 1, fp);
// Registers the generated code region with VTune's JIT profiling API so
// samples inside JIT-ed kernels are attributed to name()/source_file().
// Compiles to a no-op unless JIT_PROFILING_VTUNE is defined.
963 void register_code(const Xbyak::uint8 *code) const {
964 #ifdef JIT_PROFILING_VTUNE
965 if (iJIT_IsProfilingActive() == iJIT_SAMPLING_ON) {
966 auto jmethod = iJIT_Method_Load();
967 jmethod.method_id = iJIT_GetNewMethodID();
// The VTune API takes non-const char*; the casts drop const on our
// static identification strings, which VTune only reads.
968 jmethod.method_name = (char *)name();
969 jmethod.class_file_name = NULL;
970 jmethod.source_file_name = (char *)source_file();
971 jmethod.method_load_address = (void *)code;
972 jmethod.method_size = getSize();
974 iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED,
// Constructor tail: allocates the code buffer (default 256 KiB) or adopts
// a caller-provided buffer via Xbyak.
982 void *code_ptr = nullptr,
983 size_t code_size = 256 * 1024
984 ) : Xbyak::CodeGenerator(code_size, code_ptr)
987 virtual ~jit_generator() {}
// Identification hooks, normally supplied by DECLARE_CPU_JIT_AUX_FUNCTIONS.
989 virtual const char *name() const = 0;
990 virtual const char *source_file() const = 0;
992 // XXX: use normal_case name and update all callees (?)
// Finalizes generation and returns the code pointer; also dumps the code
// to disk when mkldnn_jit_dump() is enabled (and registers it with the
// profiler in the elided lines).
993 const Xbyak::uint8 *getCode() {
994 const Xbyak::uint8 *code = CodeGenerator::getCode();
997 if (mkldnn_jit_dump())
// Typed convenience overload: casts the code pointer to callable type F.
// NOTE(review): C-style cast; reinterpret_cast<const F> would be clearer.
1003 template<typename F> const F getCode() {
1004 // XXX (Roma): Xbyak code probably has a bug here
1005 return (const F)getCode();