1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef CPU_JIT_AVX2_GENERATOR_HPP
18 #define CPU_JIT_AVX2_GENERATOR_HPP
21 #include "cpu_isa_traits.hpp"
24 #include "mkldnn_thread.hpp"
26 #ifdef JIT_PROFILING_VTUNE
27 #include "jitprofiling.h"
30 #if defined(_WIN32) && !defined(__GNUC__)
31 # define STRUCT_ALIGN(al, ...) __declspec(align(al)) __VA_ARGS__
33 # define STRUCT_ALIGN(al, ...) __VA_ARGS__ __attribute__((__aligned__(al)))
37 # define OFFSET_SHADOWSPACE 0x28
// Convenience macro: stamps out the name()/source_file() identification
// overrides required by jit_generator, so generated kernels can be labelled
// in code dumps and profiler (VTune) listings.
// NOTE(review): the closing "} \" of source_file() is not visible in this
// listing (lines appear elided) -- verify against the full header.
40 #define DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_name) \
41 const char *name() const override { return STRINGIFY(jit_name); } \
42 const char *source_file() const override { return __FILE__; \
49 // TODO: move this to jit_generator class?
57 // TODO: move this somewhere else? Although this is only used by jit kernels
// float2int: returns an int carrying the value of `x` for use as a JIT-time
// immediate. NOTE(review): body elided in this listing -- presumably a
// bit-pattern reinterpretation (memcpy/union style), not a value cast; confirm.
59 static inline int float2int(float x) {
68 // TODO: A GPR class that hides ABI details from the JIT kernels and allows
69 // numbering registers from 0 to 14 (x86_64) / 6 (x32) (gpr0, gpr1, ...) and
70 // stack register (sr).
72 // This will allow using syntax like this:
80 // mov(param, ptr[sr])
// Callee-saved GPRs the JIT prologue must preserve. RBX/RBP/R12-R15 are
// callee-saved on both major x86_64 ABIs; RDI/RSI are additionally
// callee-saved only on Win64.
// NOTE(review): the #ifdef _WIN32 guard around the RDI/RSI entries and the
// closing "};" appear elided from this listing.
86 constexpr Xbyak::Operand::Code abi_save_gpr_regs[] = {
87 Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12,
88 Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15,
90 Xbyak::Operand::RDI, Xbyak::Operand::RSI,
// Win64 calling convention: integer args in RCX, RDX, R8, R9.
// abi_not_param1 = a scratch register guaranteed not to alias param1.
// NOTE(review): this set and the System V set below are mutually exclusive;
// the #ifdef _WIN32 / #else / #endif separating them is elided here.
95 static const Xbyak::Reg64 abi_param1(Xbyak::Operand::RCX),
96 abi_param2(Xbyak::Operand::RDX),
97 abi_param3(Xbyak::Operand::R8),
98 abi_param4(Xbyak::Operand::R9),
99 abi_not_param1(Xbyak::Operand::RDI);
// System V AMD64 calling convention: args in RDI, RSI, RDX, RCX, R8, R9.
101 static const Xbyak::Reg64 abi_param1(Xbyak::Operand::RDI),
102 abi_param2(Xbyak::Operand::RSI),
103 abi_param3(Xbyak::Operand::RDX),
104 abi_param4(Xbyak::Operand::RCX),
105 abi_param5(Xbyak::Operand::R8),
106 abi_param6(Xbyak::Operand::R9),
107 abi_not_param1(Xbyak::Operand::RCX);
// Returns the data-cache size in bytes for the given 1-based cache level.
// per_core == true  -> size available to a single core
// per_core == false -> aggregate size across mkldnn_get_max_threads() cores
111 inline unsigned int get_cache_size(int level, bool per_core = true){
112 unsigned int l = level - 1;
113 // Currently, if XByak is not able to fetch the cache topology
114 // we default to 32KB of L1, 512KB of L2 and 1MB of L3 per core.
// NOTE(review): the fallback constants are decimal (32000/512000/1024000),
// not the binary KiB sizes the comment above suggests -- confirm intent.
115 if (cpu.getDataCacheLevels() == 0){
116 const int L1_cache_per_core = 32000;
117 const int L2_cache_per_core = 512000;
118 const int L3_cache_per_core = 1024000;
119 int num_cores = per_core ? 1 : mkldnn_get_max_threads();
// NOTE(review): the "switch (l) {" line between the previous and next lines
// is elided in this listing.
121 case(0): return L1_cache_per_core * num_cores;
122 case(1): return L2_cache_per_core * num_cores;
123 case(2): return L3_cache_per_core * num_cores;
// Real topology is available: divide the per-cache size by the number of
// cores sharing it when a per-core figure is requested.
127 if (l < cpu.getDataCacheLevels()) {
128 return cpu.getDataCacheSize(l)
129 / (per_core ? cpu.getCoresSharingDataCache(l) : 1);
// Base class for all JIT kernels: wraps Xbyak's CodeGenerator with ABI
// prologue/epilogue helpers, ISA-dispatching "uni_*" instruction wrappers,
// EVEX addressing helpers, and code dump / VTune registration hooks.
136 class jit_generator : public Xbyak::CodeGenerator
// Byte size of one XMM slot in the stack save area.
139 const size_t xmm_len = 16;
// Win64: xmm6..xmm15 (10 registers, starting at index 6) are callee-saved.
// NOTE(review): the #ifdef _WIN32 / #else guards separating this pair from
// the zero pair below are elided in this listing.
141 const size_t xmm_to_preserve_start = 6;
142 const size_t xmm_to_preserve = 10;
// System V: no XMM registers are callee-saved.
144 const size_t xmm_to_preserve_start = 0;
145 const size_t xmm_to_preserve = 0;
// Number of callee-saved GPRs, taken from the abi_save_gpr_regs table.
148 const size_t num_abi_save_gpr_regs
149 = sizeof(abi_save_gpr_regs) / sizeof(abi_save_gpr_regs[0]);
// Total bytes the prologue stores: one 8-byte slot per saved GPR
// (rax.getBit()/8 == 8) plus the preserved XMM area.
151 const size_t size_of_abi_save_regs
152 = num_abi_save_gpr_regs * rax.getBit() / 8
153 + xmm_to_preserve * xmm_len;
// Alias of the first integer-argument register for kernel readability.
167 Xbyak::Reg64 param1 = abi_param1;
// Base bias used by EVEX_compress_addr() to keep displacements within the
// disp8*N-compressible range; reg_EVEX_max_8b_offt holds 2*EVEX_max_8b_offt.
168 const int EVEX_max_8b_offt = 0x200;
169 const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp;
// Size of the prologue save area; kernels use it to compute rsp-relative
// offsets past the saved registers.
171 inline size_t get_size_of_abi_save_regs() {
172 return size_of_abi_save_regs;
// -- prologue fragment (its enclosing function header is elided in this
//    listing; presumably "void preamble()" -- confirm) --
// Spill the callee-saved XMM range to freshly reserved stack space.
176 if (xmm_to_preserve) {
177 sub(rsp, xmm_to_preserve * xmm_len);
178 for (size_t i = 0; i < xmm_to_preserve; ++i)
179 movdqu(ptr[rsp + i * xmm_len], Xbyak::Xmm(xmm_to_preserve_start + i));
// Push every callee-saved GPR in table order (popped in reverse later).
181 for (size_t i = 0; i < num_abi_save_gpr_regs; ++i)
182 push(Xbyak::Reg64(abi_save_gpr_regs[i]));
// Preload the EVEX displacement-bias register used by EVEX_compress_addr.
183 if (mayiuse(avx512_common)) {
184 mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
// Software-prefetch helpers that emit prefetcht0/t1/t2 only on AVX-512 MIC
// (Knights) parts, where manual prefetching pays off; no-ops elsewhere.
// (Bodies partly elided in this listing.)
188 void mic_prefetcht0(Xbyak::Address a) {
189 if (mayiuse(avx512_mic))
193 void mic_prefetcht1(Xbyak::Address a) {
194 if (mayiuse(avx512_mic))
198 void mic_prefetcht2(Xbyak::Address a) {
199 if (mayiuse(avx512_mic))
// Emit vzeroupper to avoid AVX/SSE transition penalties; skipped on MIC
// where the instruction is not beneficial.
203 void uni_vzeroupper() {
204 if (mayiuse(avx) && !mayiuse(avx512_mic))
// -- epilogue fragment (enclosing function header elided; presumably
//    "void postamble()" -- confirm). Restores GPRs in reverse push order,
//    then the preserved XMM range, then releases the stack area.
209 for (size_t i = 0; i < num_abi_save_gpr_regs; ++i)
210 pop(Xbyak::Reg64(abi_save_gpr_regs[num_abi_save_gpr_regs - 1 - i]));
211 if (xmm_to_preserve) {
212 for (size_t i = 0; i < xmm_to_preserve; ++i)
213 movdqu(Xbyak::Xmm(xmm_to_preserve_start + i), ptr[rsp + i * xmm_len]);
214 add(rsp, xmm_to_preserve * xmm_len);
// Builds an EVEX-friendly address: large displacements are rebased around
// reg_EVEX_max_8b_offt (preloaded with 2*EVEX_max_8b_offt by the prologue)
// so the residual offset fits the compressed disp8*N encoding.
// bcast selects the {1toN} broadcast form. (Template header/scale lines
// partly elided in this listing.)
221 Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base,
222 T raw_offt, bool bcast = false)
226 using Xbyak::Address;
229 assert(raw_offt <= INT_MAX);
230 auto offt = static_cast<int>(raw_offt);
// Fold offsets in [0x200, 0x600) around 2*0x200, and [0x600, 0xA00)
// around 4*0x200, adding the bias back via reg_EVEX_max_8b_offt * scale.
234 if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) {
235 offt = offt - 2 * EVEX_max_8b_offt;
237 } else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) {
238 offt = offt - 4 * EVEX_max_8b_offt;
242 auto re = RegExp() + base + offt;
244 re = re + reg_EVEX_max_8b_offt * scale;
// Builds [reg_out + offt] (or its {1toN} broadcast form when bcast is set).
// For offsets beyond INT_MAX the offset is materialized in tmp_reg
// (instruction elided in this listing) and added as a register instead of an
// immediate displacement.
// Fixed: "®_out" / "®_offt" below were mojibake -- an HTML-entity
// corruption of "&reg_out" / "&reg_offt" (the bodies reference reg_out and
// reg_offt).
252 Xbyak::Address make_safe_addr(const Xbyak::Reg64 &reg_out, size_t offt,
253 const Xbyak::Reg64 &tmp_reg, bool bcast = false) {
254 if (offt > INT_MAX) {
256 return bcast ? ptr_b[reg_out + tmp_reg] : ptr[reg_out + tmp_reg];
258 return bcast ? ptr_b[reg_out + offt] : ptr[reg_out + offt];
// Size-safe front end for EVEX_compress_addr: falls back to make_safe_addr
// (clobbering reg_offt) when the offset does not fit in a 32-bit
// displacement.
262 Xbyak::Address EVEX_compress_addr_safe(const Xbyak::Reg64 &base,
263 size_t raw_offt, const Xbyak::Reg64 &reg_offt, bool bcast = false) {
264 if (raw_offt > INT_MAX) {
265 return make_safe_addr(base, raw_offt, reg_offt, bcast);
267 return EVEX_compress_addr(base, raw_offt, bcast);
// add/sub of a size_t offset to a base register: uses an immediate when the
// offset fits 32 bits, otherwise materializes it in reg_offt first
// (clobbering reg_offt).
// Fixed: "®_offt" below (twice) was mojibake -- an HTML-entity corruption
// of "&reg_offt" (both bodies reference reg_offt).
271 void safe_add(const Xbyak::Reg64 &base, size_t raw_offt,
272 const Xbyak::Reg64 &reg_offt) {
273 if (raw_offt > INT_MAX) {
274 mov(reg_offt, raw_offt);
281 void safe_sub(const Xbyak::Reg64 &base, size_t raw_offt,
282 const Xbyak::Reg64 &reg_offt) {
283 if (raw_offt > INT_MAX) {
284 mov(reg_offt, raw_offt);
// Labels: only type-safe Xbyak::Label objects may be bound; the char* form
// is deleted to prevent accidental string-named labels.
291 // Disallow char-based labels completely
292 void L(const char *label) = delete;
293 void L(Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); }
// ---- "uni_*" wrappers: one name, ISA-appropriate encoding -----------------
// Xmm overloads emit the legacy SSE form (destructive: dst must equal src1,
// hence the index asserts); Ymm/Zmm overloads emit the VEX/EVEX form.
// Bodies are partly elided in this listing.
295 void uni_vpxor(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
296 const Xbyak::Operand &op) {
297 assert(x1.getIdx() == x2.getIdx());
300 void uni_vpxor(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
301 const Xbyak::Operand &op) {
308 void uni_vpxor(const Xbyak::Zmm &x1, const Xbyak::Zmm &x2,
309 const Xbyak::Operand &op) {
// Scalar float load/store (movss vs vmovss).
313 void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Xmm &x) {
316 void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Ymm &x) {
319 void uni_vmovss(const Xbyak::Xmm &x, const Xbyak::Address& addr) {
322 void uni_vmovss(const Xbyak::Ymm &x, const Xbyak::Address& addr) {
// Scalar double load/store.
326 void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Xmm &x) {
329 void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Ymm &x) {
332 void uni_vmovsd(const Xbyak::Xmm &x, const Xbyak::Address& addr) {
335 void uni_vmovsd(const Xbyak::Ymm &x, const Xbyak::Address& addr) {
// Unaligned integer-vector load/store.
339 void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
342 void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
345 void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Zmm &x) {
349 void uni_vmovdqu(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
352 void uni_vmovdqu(const Xbyak::Ymm &x, const Xbyak::Address &addr) {
355 void uni_vmovdqu(const Xbyak::Zmm &x, const Xbyak::Address &addr) {
// Unaligned packed-float load/store.
359 void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
362 void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
366 void uni_vmovups(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
369 void uni_vmovups(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
// Non-temporal (cache-bypassing) packed-float store.
373 void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
376 void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
// Broadcast one float to all lanes.
380 void uni_vbroadcastss(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
// Ymm form: vbroadcastss handles memory sources (and any source on AVX2);
// plain AVX with a register source has no reg-reg broadcast, so emulate by
// copying into the low Xmm half and duplicating it into the high half.
384 void uni_vbroadcastss(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
385 if (op.isMEM() || mayiuse(avx2)) {
388 Xbyak::Xmm t(x.getIdx());
389 if (t.getIdx() != op.getIdx()) movss(t, op);
390 vinsertf128(x, x, t, 1);
// Broadcast one dword integer to all lanes; same AVX fallback strategy.
395 void uni_vpbroadcastd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
399 void uni_vpbroadcastd(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
403 Xbyak::Xmm t(x.getIdx());
404 if (t.getIdx() != op.getIdx()) movsd(t, op);
405 vinsertf128(x, x, t, 1);
// Scalar approximate reciprocal. The Ymm overloads operate on the low Xmm
// lane of the register (re-wrapped via the same register index).
410 void uni_vrcpss(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
413 void uni_vrcpss(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2) {
414 Xbyak::Xmm x1_(x1.getIdx());
415 Xbyak::Xmm x2_(x2.getIdx());
416 vrcpss(x1_, x1_, x2_);
418 void uni_vrcpss(const Xbyak::Ymm &x, const Xbyak::Address &op) {
419 Xbyak::Xmm x_(x.getIdx());
// Packed approximate reciprocal.
423 void uni_vrcpps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
426 void uni_vrcpps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
429 void uni_vrcpps(const Xbyak::Zmm &x, const Xbyak::Operand &op) {
// ---- packed-float arithmetic wrappers -------------------------------------
// Two-operand Xmm forms are destructive (x must equal op1, asserted);
// Ymm forms use the non-destructive VEX encoding. The (op1, op2, buf)
// overloads use buf as scratch so x may alias an input. Bodies partly
// elided in this listing.
433 void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
434 const Xbyak::Operand &op2 = Xbyak::Operand()) {
435 assert(x.getIdx() == op1.getIdx());
438 void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
439 const Xbyak::Operand &op2 = Xbyak::Operand()) {
443 void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
444 const Xbyak::Operand &op2, const Xbyak::Xmm &buf) {
447 if (x.getIdx() != buf.getIdx()) {
452 void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
453 const Xbyak::Operand &op2, const Xbyak::Ymm &buf) {
457 void uni_vaddps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
458 const Xbyak::Operand &op2 = Xbyak::Operand()) {
459 assert(x.getIdx() == op1.getIdx());
462 void uni_vaddps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
463 const Xbyak::Operand &op2 = Xbyak::Operand()) {
// psignd: copy x2 with the signs of op applied.
467 void uni_vpsignd(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2,
468 const Xbyak::Operand& op) {
469 assert(x1.getIdx() == x2.getIdx());
472 void uni_vpsignd(const Xbyak::Ymm& x1, const Xbyak::Ymm& x2,
473 const Xbyak::Operand& op) {
477 void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
478 const Xbyak::Operand &op2 = Xbyak::Operand()) {
479 assert(x.getIdx() == op1.getIdx());
482 void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
483 const Xbyak::Operand &op2 = Xbyak::Operand()) {
487 void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
488 const Xbyak::Operand &op2, const Xbyak::Xmm &buf) {
491 if (x.getIdx() != buf.getIdx()) {
496 void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
497 const Xbyak::Operand &op2, const Xbyak::Ymm &buf) {
501 void uni_vmulps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
502 const Xbyak::Operand &op2 = Xbyak::Operand()) {
503 assert(x.getIdx() == op1.getIdx());
506 void uni_vmulps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
507 const Xbyak::Operand &op2 = Xbyak::Operand()) {
// Fused multiply-add wrappers:
//   213: x1 = x2 * x1 + op;  231: x1 = x1 + x2 * op;
//   fnmadd231: x1 = x1 - x2 * op.
// Xmm forms presumably emulate FMA with mul+add on pre-FMA ISAs (bodies
// elided -- confirm).
511 void uni_vfmadd213ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
512 const Xbyak::Operand &op) {
516 void uni_vfmadd213ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
517 const Xbyak::Operand &op) {
518 vfmadd213ps(x1, x2, op);
521 void uni_vfmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
522 const Xbyak::Operand &op) {
526 void uni_vfmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
527 const Xbyak::Operand &op) {
528 vfmadd231ps(x1, x2, op);
531 void uni_vfnmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
532 const Xbyak::Operand &op) {
537 void uni_vfnmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
538 const Xbyak::Operand &op) {
539 vfnmadd231ps(x1, x2, op);
// Packed square root.
542 void uni_vsqrtps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
545 void uni_vsqrtps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
// Packed dword add. SSE form is destructive (asserted).
549 void uni_vpaddd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
550 const Xbyak::Operand &op) {
551 assert(x1.getIdx() == x2.getIdx());
// NOTE(review): x2 is typed Xmm in this Ymm overload -- likely a typo for
// Ymm (compare the sibling Ymm wrappers); verify before relying on it.
554 void uni_vpaddd(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2,
555 const Xbyak::Operand &op) {
// Bitwise AND of packed floats. The Ymm overload avoids the EVEX form on
// sub-512-bit registers (vandps on Zmm requires AVX-512 DQ).
559 void uni_vandps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
560 const Xbyak::Operand &op = Xbyak::Operand()) {
561 assert(x1.getIdx() == x2.getIdx());
564 void uni_vandps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
565 const Xbyak::Operand &op = Xbyak::Operand()) {
566 if (!mayiuse(avx512_common) || x1.getBit() < 512)
// Bitwise OR of packed floats; same AVX-512 dispatch as uni_vandps.
572 void uni_vorps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
573 const Xbyak::Operand &op = Xbyak::Operand()) {
574 assert(x1.getIdx() == x2.getIdx());
577 void uni_vorps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
578 const Xbyak::Operand &op = Xbyak::Operand()) {
579 if (!mayiuse(avx512_common) || x1.getBit() < 512)
// Packed dword shifts by immediate (imm parameter lines elided in this
// listing). SSE forms are destructive: x must alias op.
585 void uni_vpslld(const Xbyak::Xmm &x, const Xbyak::Operand &op,
587 assert(x.getIdx() == op.getIdx());
590 void uni_vpslld(const Xbyak::Ymm &x, const Xbyak::Operand &op,
595 void uni_vpsrld(const Xbyak::Xmm &x, const Xbyak::Operand &op,
597 assert(x.getIdx() == op.getIdx());
600 void uni_vpsrld(const Xbyak::Ymm &x, const Xbyak::Operand &op,
// Packed float max/min; destructive SSE forms asserted.
605 void uni_vmaxps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
606 const Xbyak::Operand &op2 = Xbyak::Operand()) {
607 assert(x.getIdx() == op1.getIdx());
610 void uni_vmaxps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
611 const Xbyak::Operand &op2 = Xbyak::Operand()) {
615 void uni_vminps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
616 const Xbyak::Operand &op2 = Xbyak::Operand()) {
617 assert(x.getIdx() == op1.getIdx());
620 void uni_vminps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
621 const Xbyak::Operand &op2 = Xbyak::Operand()) {
// "greater than": encoded as NLE (not-less-or-equal, unordered-signaling)
// so NaN operands compare true, matching vcmpgtps semantics.
625 void uni_vcmpgtps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
626 const Xbyak::Operand &op) {
627 assert(x1.getIdx() == x2.getIdx());
628 cmpps(x1, op, _cmp_nle_us);
631 void uni_vcmpgtps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
632 const Xbyak::Operand &op) {
633 vcmpgtps(x1, x2, op);
// "greater or equal": encoded as NLT (not-less-than, unordered-signaling).
636 void uni_vcmpgeps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
637 const Xbyak::Operand &op) {
638 assert(x1.getIdx() == x2.getIdx());
639 cmpps(x1, op, _cmp_nlt_us);
642 void uni_vcmpgeps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
643 const Xbyak::Operand &op) {
644 vcmpps(x1, x2, op, _cmp_nlt_us);
// ptest/vptest wrapper; vptest has no Zmm form, hence the assert.
647 void uni_vtestps(const Xbyak::Xmm &x1, const Xbyak::Operand &op) {
651 void uni_vtestps(const Xbyak::Ymm &x1, const Xbyak::Operand &op) {
652 assert(!(x1.isZMM() || op.isZMM()));
// Variable blend. SSE blendvps hard-wires xmm0 as the mask register,
// hence the msk.getIdx() == 0 assert on the Xmm path.
656 void uni_vblendvps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
657 const Xbyak::Operand &op, const Xbyak::Xmm &msk) {
658 assert(x1.getIdx() == x2.getIdx());
659 assert(msk.getIdx() == 0);
662 void uni_vblendvps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
663 const Xbyak::Operand &op, const Xbyak::Ymm &msk) {
664 vblendvps(x1, x2, op, msk);
// Rounding with an immediate mode (imm parameter lines elided).
667 void uni_vroundps(const Xbyak::Xmm &x, const Xbyak::Operand &op,
671 void uni_vroundps(const Xbyak::Ymm &x, const Xbyak::Operand &op,
673 vroundps(x, op, imm);
// float -> int32 conversion (rounding per MXCSR).
676 void uni_vcvtps2dq(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
679 void uni_vcvtps2dq(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
// int32 -> float conversion.
683 void uni_vcvtdq2ps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
686 void uni_vcvtdq2ps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
// Extract per-lane sign bits into a GPR (widened to 64 bits).
690 void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Xmm &x2) {
691 movmskps(x1.cvt64(), x2);
693 void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Ymm &x2) {
// Pack dwords->words (signed saturation) and words->bytes (unsigned
// saturation). The SSE forms are destructive (dst must equal src1).
697 void uni_vpackssdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){
// Fixed: this assert compared x1 against itself (always true). The SSE
// destructive-destination requirement is x1 == x2, exactly as the sibling
// wrappers (e.g. uni_vpackusdw) assert.
698 assert(x1.getIdx() == x2.getIdx());
701 void uni_vpackssdw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){
702 vpackssdw(x1, x2, op);
705 void uni_vpackuswb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){
// Fixed: same self-comparison typo as uni_vpackssdw above.
706 assert(x1.getIdx() == x2.getIdx());
709 void uni_vpackuswb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){
710 vpackuswb(x1, x2, op);
// Sign-/zero-extend packed bytes to dwords.
713 void uni_vpmovsxbd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
716 void uni_vpmovsxbd(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
720 void uni_vpmovzxbd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
723 void uni_vpmovzxbd(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
// Pack dwords->words with unsigned saturation; destructive SSE form.
727 void uni_vpackusdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
728 assert(x1.getIdx() == x2.getIdx());
731 void uni_vpackusdw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
732 vpackusdw(x1, x2, op);
// Pack words->bytes with signed saturation; destructive SSE form.
735 void uni_vpacksswb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
736 assert(x1.getIdx() == x2.getIdx());
739 void uni_vpacksswb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
740 vpacksswb(x1, x2, op);
// Per-lane integer max: signed dword, signed byte, unsigned byte.
743 void uni_vpmaxsd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
744 assert(x1.getIdx() == x2.getIdx());
747 void uni_vpmaxsd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
751 void uni_vpmaxsb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
752 assert(x1.getIdx() == x2.getIdx());
755 void uni_vpmaxsb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
759 void uni_vpmaxub(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
760 assert(x1.getIdx() == x2.getIdx());
763 void uni_vpmaxub(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
// Multiply unsigned x signed bytes, horizontally add pairs into words.
767 void uni_vpmaddubsw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
768 assert(x1.getIdx() == x2.getIdx());
771 void uni_vpmaddubsw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
772 vpmaddubsw(x1, x2, op);
// Multiply packed words, horizontally add pairs into dwords.
775 void uni_vpmaddwd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
776 assert(x1.getIdx() == x2.getIdx());
779 void uni_vpmaddwd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
780 vpmaddwd(x1, x2, op);
// Packed dword multiply (low 32 bits of each product).
783 void uni_vpmulld(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
784 assert(x1.getIdx() == x2.getIdx());
787 void uni_vpmulld(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
// Packed byte subtract.
791 void uni_vpsubb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
792 assert(x1.getIdx() == x2.getIdx());
795 void uni_vpsubb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
// Byte-wise shift of the whole register by an immediate byte count.
799 void uni_vpslldq(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::uint8 &op) {
800 assert(x1.getIdx() == x2.getIdx());
803 void uni_vpslldq(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::uint8 &op) {
// Bitwise AND on integer vectors.
807 void uni_vpand(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
808 const Xbyak::Operand &op = Xbyak::Operand()) {
809 assert(x1.getIdx() == x2.getIdx());
812 void uni_vpand(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
813 const Xbyak::Operand &op = Xbyak::Operand()) {
// Packed byte add.
817 void uni_vpaddb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
818 const Xbyak::Operand &op) {
819 assert(x1.getIdx() == x2.getIdx());
// NOTE(review): x2 is typed Xmm in this Ymm overload -- likely a typo for
// Ymm (compare the sibling Ymm wrappers); verify before relying on it.
822 void uni_vpaddb(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2,
823 const Xbyak::Operand &op) {
// Byte shuffle with per-lane control mask.
827 void uni_vpshufb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
828 const Xbyak::Operand &op) {
829 assert(x1.getIdx() == x2.getIdx());
833 void uni_vpshufb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
834 const Xbyak::Operand &op) {
// Multiply `out` by a compile-time (JIT-time) constant using shifts/adds.
838 void mul_by_const(const Xbyak::Reg &out,
839 const Xbyak::Reg64 &tmp, int value) {
840 // Generates a shift + add sequence for multiplying contents of the
841 // out register by a known JIT-time value. Clobbers the tmp register.
843 // Pros compared to mul/imul:
844 // - does not require using known registers
845 // - not microcoded on Intel(R) Xeon Phi(TM) processors
846 // Still, there are probably a lot of cases when mul/imul is faster on
847 // Intel(R) Core(TM) processors. Not intended for critical path.
849 // TODO: detect when overflow is imminent (Roma)
850 // TODO: detect when using mul/imul is a better option (Roma)
// Walk the bits of `value`; for each set bit emit a shift by the distance
// from the previous set bit, then accumulate. (Loop body partly elided.)
852 int p = 0; // the current power of 2
853 int old_p = 0; // the last seen power of 2 such that value[old_p] != 0
858 int shift = p - old_p;
// Writes the generated machine code to "mkldnn_dump_<name>.<counter>.bin"
// for offline disassembly; a static counter disambiguates multiple kernels
// with the same name.
871 void dump_code(const Xbyak::uint8 *code) const {
873 static int counter = 0;
874 #define MAX_FNAME_LEN 256
875 char fname[MAX_FNAME_LEN + 1];
876 snprintf(fname, MAX_FNAME_LEN, "mkldnn_dump_%s.%d.bin", name(),
// NOTE(review): mode "w+" is text mode on Windows; a binary dump usually
// wants "wb" to avoid newline translation -- confirm intended platforms.
880 FILE *fp = mkldnn_fopen(fname, "w+");
881 // Failure to dump code is not fatal
883 size_t unused = fwrite(code, getSize(), 1, fp);
// Registers the generated code region with Intel VTune's JIT Profiling API
// so samples in JIT-ed kernels are attributed to name()/source_file().
// Compiled out unless JIT_PROFILING_VTUNE is defined.
891 void register_code(const Xbyak::uint8 *code) const {
892 #ifdef JIT_PROFILING_VTUNE
893 if (iJIT_IsProfilingActive() == iJIT_SAMPLING_ON) {
894 auto jmethod = iJIT_Method_Load();
895 jmethod.method_id = iJIT_GetNewMethodID();
896 jmethod.method_name = (char *)name();
897 jmethod.class_file_name = NULL;
898 jmethod.source_file_name = (char *)source_file();
899 jmethod.method_load_address = (void *)code;
900 jmethod.method_size = getSize();
902 iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED,
// -- constructor parameter list (opening line elided in this listing):
//    optional user buffer and code-area size forwarded to Xbyak.
910 void *code_ptr = nullptr,
911 size_t code_size = 256 * 1024
912 ) : Xbyak::CodeGenerator(code_size, code_ptr)
915 virtual ~jit_generator() {}
// Identification hooks implemented via DECLARE_CPU_JIT_AUX_FUNCTIONS.
917 virtual const char *name() const = 0;
918 virtual const char *source_file() const = 0;
// Finalizes generation and returns the code pointer; also dumps the binary
// when the mkldnn_jit_dump runtime switch is set.
920 // XXX: use normal_case name and update all callees (?)
921 const Xbyak::uint8 *getCode() {
922 const Xbyak::uint8 *code = CodeGenerator::getCode();
925 if (mkldnn_jit_dump())
// Typed convenience overload: casts the code pointer to a callable type F.
931 template<typename F> const F getCode() {
932 // XXX (Roma): Xbyak code probably has a bug here
933 return (const F)getCode();