updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_generator.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_JIT_AVX2_GENERATOR_HPP
18 #define CPU_JIT_AVX2_GENERATOR_HPP
19
20 #include <limits.h>
21 #include "cpu_isa_traits.hpp"
22
23 #include "utils.hpp"
24 #include "mkldnn_thread.hpp"
25
26 #ifdef JIT_PROFILING_VTUNE
27 #include "jitprofiling.h"
28 #endif
29
30 #if defined(_WIN32) && !defined(__GNUC__)
31 #   define STRUCT_ALIGN(al, ...) __declspec(align(al)) __VA_ARGS__
32 #else
33 #   define STRUCT_ALIGN(al, ...) __VA_ARGS__ __attribute__((__aligned__(al)))
34 #endif
35
36 #if defined(_WIN32)
37 #   define OFFSET_SHADOWSPACE 0x28
38 #endif
39
40 #define DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_name) \
41     const char *name() const override { return STRINGIFY(jit_name); } \
42     const char *source_file() const override { return __FILE__; \
43     }
44
45 namespace mkldnn {
46 namespace impl {
47 namespace cpu {
48
49 // TODO: move this to jit_generator class?
50 namespace {
51
52 typedef enum {
53     PAGE_4K = 4096,
54     PAGE_2M = 2097152,
55 } cpu_page_size_t;
56
57 // TODO: move this somewhere else? Although this is only used by jit kernels
58 // (Roma)
59 static inline int float2int(float x) {
60     union {
61         float vfloat;
62         int vint;
63     } cvt;
64     cvt.vfloat = x;
65     return cvt.vint;
66 }
67
68 // TODO: A GPR class that hides ABI details from the JIT kernels and allows
69 // numbering registers from 0 to 14 (x86_64) / 6 (x32) (gpr0, gpr1, ...) and
70 // stack register (sr).
71 //
72 // This will allow using syntax like this:
73 //
74 // param = gpr0;
75 // reg_input = gpr0;
76 // reg_output =  gpr1;
77 // ...
78 //
79 // #ifndef XBYAK64
80 // mov(param, ptr[sr])
81 // #endif
82 //
83 // (Roma)
84
85 #ifdef XBYAK64
86 constexpr Xbyak::Operand::Code abi_save_gpr_regs[] = {
87     Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12,
88     Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15,
89 #ifdef _WIN32
90     Xbyak::Operand::RDI, Xbyak::Operand::RSI,
91 #endif
92 };
93
94 #ifdef _WIN32
95 static const Xbyak::Reg64 abi_param1(Xbyak::Operand::RCX),
96              abi_param2(Xbyak::Operand::RDX),
97              abi_param3(Xbyak::Operand::R8),
98              abi_param4(Xbyak::Operand::R9),
99              abi_not_param1(Xbyak::Operand::RDI);
100 #else
101 static const Xbyak::Reg64 abi_param1(Xbyak::Operand::RDI),
102              abi_param2(Xbyak::Operand::RSI),
103              abi_param3(Xbyak::Operand::RDX),
104              abi_param4(Xbyak::Operand::RCX),
105              abi_param5(Xbyak::Operand::R8),
106              abi_param6(Xbyak::Operand::R9),
107              abi_not_param1(Xbyak::Operand::RCX);
108 #endif
109 #endif
110
111 inline unsigned int get_cache_size(int level, bool per_core = true){
112     unsigned int l = level - 1;
113     // Currently, if XByak is not able to fetch the cache topology
114     // we default to 32KB of L1, 512KB of L2 and 1MB of L3 per core.
115     if (cpu.getDataCacheLevels() == 0){
116         const int L1_cache_per_core = 32000;
117         const int L2_cache_per_core = 512000;
118         const int L3_cache_per_core = 1024000;
119         int num_cores = per_core ? 1 : mkldnn_get_max_threads();
120         switch(l){
121         case(0): return L1_cache_per_core * num_cores;
122         case(1): return L2_cache_per_core * num_cores;
123         case(2): return L3_cache_per_core * num_cores;
124         default: return 0;
125         }
126     }
127     if (l < cpu.getDataCacheLevels()) {
128         return cpu.getDataCacheSize(l)
129             / (per_core ? cpu.getCoresSharingDataCache(l) : 1);
130     } else
131         return 0;
132 }
133
134 }
135
136 class jit_generator : public Xbyak::CodeGenerator
137 {
138 private:
139     const size_t xmm_len = 16;
140 #ifdef _WIN32
141     const size_t xmm_to_preserve_start = 6;
142     const size_t xmm_to_preserve = 10;
143 #else
144     const size_t xmm_to_preserve_start = 0;
145     const size_t xmm_to_preserve = 0;
146 #endif
147
148     const size_t num_abi_save_gpr_regs
149         = sizeof(abi_save_gpr_regs) / sizeof(abi_save_gpr_regs[0]);
150
151     const size_t size_of_abi_save_regs
152         = num_abi_save_gpr_regs * rax.getBit() / 8
153         + xmm_to_preserve * xmm_len;
154
155 public:
156     enum {
157         _cmp_eq_oq = 0u,
158         _cmp_lt_os = 1u,
159         _cmp_le_os = 2u,
160         _cmp_neq_uq = 4u,
161         _cmp_nlt_us = 5u,
162         _cmp_nle_us = 6u,
163
164         _op_floor = 1u,
165     };
166
167     Xbyak::Reg64 param1 = abi_param1;
168     const int EVEX_max_8b_offt = 0x200;
169     const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp;
170
171     inline size_t get_size_of_abi_save_regs() {
172         return size_of_abi_save_regs;
173     }
174
175     void preamble() {
176         if (xmm_to_preserve) {
177             sub(rsp, xmm_to_preserve * xmm_len);
178             for (size_t i = 0; i < xmm_to_preserve; ++i)
179                 movdqu(ptr[rsp + i * xmm_len], Xbyak::Xmm(xmm_to_preserve_start + i));
180         }
181         for (size_t i = 0; i < num_abi_save_gpr_regs; ++i)
182             push(Xbyak::Reg64(abi_save_gpr_regs[i]));
183         if (mayiuse(avx512_common)) {
184             mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
185         }
186     }
187
188     void mic_prefetcht0(Xbyak::Address a) {
189         if (mayiuse(avx512_mic))
190             prefetcht0(a);
191     }
192
193     void mic_prefetcht1(Xbyak::Address a) {
194         if (mayiuse(avx512_mic))
195             prefetcht1(a);
196     }
197
198     void mic_prefetcht2(Xbyak::Address a) {
199         if (mayiuse(avx512_mic))
200             prefetcht2(a);
201     }
202
203     void uni_vzeroupper() {
204         if (mayiuse(avx) && !mayiuse(avx512_mic))
205             vzeroupper();
206     }
207
208     void postamble() {
209         for (size_t i = 0; i < num_abi_save_gpr_regs; ++i)
210             pop(Xbyak::Reg64(abi_save_gpr_regs[num_abi_save_gpr_regs - 1 - i]));
211         if (xmm_to_preserve) {
212             for (size_t i = 0; i < xmm_to_preserve; ++i)
213                 movdqu(Xbyak::Xmm(xmm_to_preserve_start + i), ptr[rsp + i * xmm_len]);
214             add(rsp, xmm_to_preserve * xmm_len);
215         }
216         uni_vzeroupper();
217         ret();
218     }
219
220     template<typename T>
221     Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base,
222             T raw_offt, bool bcast = false)
223     {
224         using Xbyak::Zmm;
225         using Xbyak::Reg64;
226         using Xbyak::Address;
227         using Xbyak::RegExp;
228
229         assert(raw_offt <= INT_MAX);
230         auto offt = static_cast<int>(raw_offt);
231
232         int scale = 0;
233
234         if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) {
235             offt = offt - 2 * EVEX_max_8b_offt;
236             scale = 1;
237         } else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) {
238             offt = offt - 4 * EVEX_max_8b_offt;
239             scale = 2;
240         }
241
242         auto re = RegExp() + base + offt;
243         if (scale)
244             re = re + reg_EVEX_max_8b_offt * scale;
245
246         if (bcast)
247             return zword_b [re];
248         else
249             return zword [re];
250     }
251
252     Xbyak::Address make_safe_addr(const Xbyak::Reg64 &reg_out, size_t offt,
253         const Xbyak::Reg64 &tmp_reg, bool bcast = false) {
254         if (offt > INT_MAX) {
255             mov(tmp_reg, offt);
256             return bcast ? ptr_b[reg_out + tmp_reg] : ptr[reg_out + tmp_reg];
257         } else {
258             return bcast ? ptr_b[reg_out + offt] : ptr[reg_out + offt];
259         }
260     }
261
262     Xbyak::Address EVEX_compress_addr_safe(const Xbyak::Reg64 &base,
263         size_t raw_offt, const Xbyak::Reg64 &reg_offt, bool bcast = false) {
264         if (raw_offt > INT_MAX) {
265             return make_safe_addr(base, raw_offt, reg_offt, bcast);
266         } else {
267             return EVEX_compress_addr(base, raw_offt, bcast);
268         }
269     }
270
271     void safe_add(const Xbyak::Reg64 &base, size_t raw_offt,
272         const Xbyak::Reg64 &reg_offt) {
273         if (raw_offt > INT_MAX) {
274             mov(reg_offt, raw_offt);
275             add(base, reg_offt);
276         } else {
277             add(base, raw_offt);
278         }
279     }
280
281     void safe_sub(const Xbyak::Reg64 &base, size_t raw_offt,
282         const Xbyak::Reg64 &reg_offt) {
283         if (raw_offt > INT_MAX) {
284             mov(reg_offt, raw_offt);
285             sub(base, reg_offt);
286         } else {
287             sub(base, raw_offt);
288         }
289     }
290
291     // Disallow char-based labels completely
292     void L(const char *label) = delete;
293     void L(Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); }
294
295     void L_aligned(Xbyak::Label &label, int alignment = 16) {
296         align(alignment);
297         L(label);
298     }
299
300     void uni_vpxor(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
301                    const Xbyak::Operand &op) {
302         assert(x1.getIdx() == x2.getIdx());
303         pxor(x2, op);
304     }
305     void uni_vpxor(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
306                    const Xbyak::Operand &op) {
307         if (mayiuse(avx2)) {
308             vpxor(x1, x2, op);
309         } else {
310             vxorps(x1, x2, op);
311         }
312     }
313     void uni_vpxor(const Xbyak::Zmm &x1, const Xbyak::Zmm &x2,
314                    const Xbyak::Operand &op) {
315         vpxord(x1, x2, op);
316     }
317
318     void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Xmm &x) {
319         movss(addr, x);
320     }
321     void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Ymm &x) {
322         vmovss(addr, x);
323     }
324     void uni_vmovss(const Xbyak::Xmm &x, const Xbyak::Address& addr) {
325         movss(x, addr);
326     }
327     void uni_vmovss(const Xbyak::Ymm &x, const Xbyak::Address& addr) {
328         vmovss(x, addr);
329     }
330
331     void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Xmm &x) {
332         movsd(addr, x);
333     }
334     void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Ymm &x) {
335         vmovsd(addr, x);
336     }
337     void uni_vmovsd(const Xbyak::Xmm &x, const Xbyak::Address& addr) {
338         movsd(x, addr);
339     }
340     void uni_vmovsd(const Xbyak::Ymm &x, const Xbyak::Address& addr) {
341         vmovsd(x, addr);
342     }
343
344     void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
345         movdqu(addr, x);
346     }
347     void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
348         vmovdqu(addr, x);
349     }
350     void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Zmm &x) {
351         vmovdqu32(addr, x);
352     }
353
354     void uni_vmovdqu(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
355         movdqu(x, addr);
356     }
357     void uni_vmovdqu(const Xbyak::Ymm &x, const Xbyak::Address &addr) {
358         vmovdqu(x, addr);
359     }
360     void uni_vmovdqu(const Xbyak::Zmm &x, const Xbyak::Address &addr) {
361         vmovdqu32(x, addr);
362     }
363
364     void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
365         movups(addr, x);
366     }
367     void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
368         vmovups(addr, x);
369     }
370
371     void uni_vmovups(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
372         movups(x, op);
373     }
374     void uni_vmovups(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
375         vmovups(x, op);
376     }
377
378     void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
379         movntps(addr, x);
380     }
381     void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
382         vmovntps(addr, x);
383     }
384
385     void uni_vbroadcastss(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
386         movss(x, op);
387         shufps(x, x, 0x0);
388     }
389     void uni_vbroadcastss(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
390         if (op.isMEM() || mayiuse(avx2)) {
391             vbroadcastss(x, op);
392         } else {
393             Xbyak::Xmm t(x.getIdx());
394             if (t.getIdx() != op.getIdx()) movss(t, op);
395             vinsertf128(x, x, t, 1);
396             vshufps(x, x, x, 0);
397         }
398     }
399
400     void uni_vpbroadcastd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
401         movsd(x, op);
402         pshufd(x, x, 0x0);
403     }
404     void uni_vpbroadcastd(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
405         if (mayiuse(avx2)) {
406             vpbroadcastd(x, op);
407         } else {
408             Xbyak::Xmm t(x.getIdx());
409             if (t.getIdx() != op.getIdx()) movsd(t, op);
410             vinsertf128(x, x, t, 1);
411             vshufps(x, x, x, 0);
412         }
413     }
414
415     void uni_vrcpss(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
416         rcpss(x, op);
417     }
418     void uni_vrcpss(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2) {
419         Xbyak::Xmm x1_(x1.getIdx());
420         Xbyak::Xmm x2_(x2.getIdx());
421         vrcpss(x1_, x1_, x2_);
422     }
423     void uni_vrcpss(const Xbyak::Ymm &x, const Xbyak::Address &op) {
424         Xbyak::Xmm x_(x.getIdx());
425         vrcpss(x_, x_, op);
426     }
427
428     void uni_vrcpps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
429         rcpps(x, op);
430     }
431     void uni_vrcpps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
432         vrcpps(x, op);
433     }
434     void uni_vrcpps(const Xbyak::Zmm &x, const Xbyak::Operand &op) {
435         vrcp14ps(x, op);
436     }
437
438     void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
439                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
440         assert(x.getIdx() == op1.getIdx());
441         divps(x, op2);
442     }
443     void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
444                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
445         vdivps(x, op1, op2);
446     }
447
448     void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
449                     const Xbyak::Operand &op2, const Xbyak::Xmm &buf) {
450         movups(buf, op1);
451         divps(buf, op2);
452         if (x.getIdx() != buf.getIdx()) {
453             movups(x, buf);
454         }
455     }
456
457     void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
458                     const Xbyak::Operand &op2, const Xbyak::Ymm &buf) {
459         vdivps(x, op1, op2);
460     }
461
462     void uni_vaddps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
463                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
464         assert(x.getIdx() == op1.getIdx());
465         addps(x, op2);
466     }
467     void uni_vaddps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
468                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
469         vaddps(x, op1, op2);
470     }
471     void uni_vaddss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
472                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
473         assert(x.getIdx() == op1.getIdx());
474         addss(x, op2);
475     }
476     void uni_vaddss(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
477                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
478         vaddss(x, op1, op2);
479     }
480
481     void uni_vpsignd(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2,
482                      const Xbyak::Operand& op) {
483         assert(x1.getIdx() == x2.getIdx());
484         psignd(x1, op);
485     }
486     void uni_vpsignd(const Xbyak::Ymm& x1, const Xbyak::Ymm& x2,
487                      const Xbyak::Operand& op) {
488         vpsignd(x1, x2, op);
489     }
490
491     void uni_vsubss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
492                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
493         assert(x.getIdx() == op1.getIdx());
494         subps(x, op2);
495     }
496     void uni_vsubss(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
497                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
498         vsubss(x, Xbyak::Xmm(op1.getIdx()), Xbyak::Xmm(op2.getIdx()));
499     }
500
501     void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
502                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
503         assert(x.getIdx() == op1.getIdx());
504         subps(x, op2);
505     }
506     void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
507                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
508         vsubps(x, op1, op2);
509     }
510
511     void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
512                     const Xbyak::Operand &op2, const Xbyak::Xmm &buf) {
513         movups(buf, op1);
514         subps(buf, op2);
515         if (x.getIdx() != buf.getIdx()) {
516             movups(x, buf);
517         }
518     }
519
520     void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
521                     const Xbyak::Operand &op2, const Xbyak::Ymm &buf) {
522         vsubps(x, op1, op2);
523     }
524
525     void uni_vmulps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
526                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
527         assert(x.getIdx() == op1.getIdx());
528         mulps(x, op2);
529     }
530     void uni_vmulps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
531                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
532         vmulps(x, op1, op2);
533     }
534
535     void uni_vmulss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
536                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
537         assert(x.getIdx() == op1.getIdx());
538         mulss(x, op2);
539     }
540     void uni_vmulss(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
541                     const Xbyak::Address &op2) {
542         vmulss(x, Xbyak::Xmm(op1.getIdx()), op2);
543     }
544     void uni_vmulss(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
545                     const Xbyak::Ymm &op2) {
546         vmulss(x, Xbyak::Xmm(op1.getIdx()), Xbyak::Xmm(op2.getIdx()));
547     }
548
549     void uni_vfmadd213ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
550                          const Xbyak::Operand &op) {
551         mulps(x1, x2);
552         addps(x1, op);
553     }
554     void uni_vfmadd213ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
555                          const Xbyak::Operand &op) {
556         vfmadd213ps(x1, x2, op);
557     }
558
559     void uni_vfmadd213ss(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
560                          const Xbyak::Operand &op) {
561         mulss(x1, x2);
562         addss(x1, op);
563     }
564     void uni_vfmadd213ss(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
565                          const Xbyak::Operand &op) {
566         vfmadd213ss(x1, x2, op);
567     }
568
569     void uni_vfmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
570                          const Xbyak::Operand &op) {
571         mulps(x2, op);
572         addps(x1, x2);
573     }
574     void uni_vfmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
575                          const Xbyak::Operand &op) {
576         vfmadd231ps(x1, x2, op);
577     }
578     void uni_vfmadd231ss(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
579                          const Xbyak::Operand &op) {
580         mulss(x2, op);
581         addss(x1, x2);
582     }
583     void uni_vfmadd231ss(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
584                          const Xbyak::Operand &op) {
585         vfmadd231ss(Xbyak::Xmm(x1.getIdx()), Xbyak::Xmm(x2.getIdx()), op);
586     }
587
588     void uni_vfnmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
589                            const Xbyak::Operand &op) {
590         mulps(x2, op);
591         subps(x1, x2);
592     }
593
594     void uni_vfnmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
595                            const Xbyak::Operand &op) {
596         vfnmadd231ps(x1, x2, op);
597     }
598
599     void uni_vsqrtps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
600         sqrtps(x, op);
601     }
602     void uni_vsqrtps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
603         vsqrtps(x, op);
604     }
605
606     void uni_vpaddd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
607                     const Xbyak::Operand &op) {
608         assert(x1.getIdx() == x2.getIdx());
609         paddd(x2, op);
610     }
611     void uni_vpaddd(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2,
612                     const Xbyak::Operand &op) {
613         vpaddd(x1, x2, op);
614     }
615
616     void uni_vandps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
617                     const Xbyak::Operand &op = Xbyak::Operand()) {
618         assert(x1.getIdx() == x2.getIdx());
619         andps(x1, op);
620     }
621     void uni_vandps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
622                     const Xbyak::Operand &op = Xbyak::Operand()) {
623         if (!mayiuse(avx512_common) || x1.getBit() < 512)
624             vandps(x1, x2, op);
625         else
626             vpandd(x1, x2, op);
627     }
628
629     void uni_vorps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
630                     const Xbyak::Operand &op = Xbyak::Operand()) {
631         assert(x1.getIdx() == x2.getIdx());
632         orps(x1, op);
633     }
634     void uni_vorps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
635                     const Xbyak::Operand &op = Xbyak::Operand()) {
636         if (!mayiuse(avx512_common) || x1.getBit() < 512)
637             vorps(x1, x2, op);
638         else
639             vpord(x1, x2, op);
640     }
641
642     void uni_vpslld(const Xbyak::Xmm &x, const Xbyak::Operand &op,
643                     const int imm) {
644         assert(x.getIdx() == op.getIdx());
645         pslld(x, imm);
646     }
647     void uni_vpslld(const Xbyak::Ymm &x, const Xbyak::Operand &op,
648                     const int imm) {
649         vpslld(x, op, imm);
650     }
651
652     void uni_vpsrld(const Xbyak::Xmm &x, const Xbyak::Operand &op,
653                     const int imm) {
654         assert(x.getIdx() == op.getIdx());
655         psrld(x, imm);
656     }
657     void uni_vpsrld(const Xbyak::Ymm &x, const Xbyak::Operand &op,
658                     const int imm) {
659         vpsrld(x, op, imm);
660     }
661
662     void uni_vmaxps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
663                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
664         assert(x.getIdx() == op1.getIdx());
665         maxps(x, op2);
666     }
667     void uni_vmaxps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
668                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
669         vmaxps(x, op1, op2);
670     }
671
672     void uni_vminps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
673                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
674         assert(x.getIdx() == op1.getIdx());
675         minps(x, op2);
676     }
677     void uni_vminps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
678                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
679         vminps(x, op1, op2);
680     }
681
682     void uni_vcmpgtps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
683                       const Xbyak::Operand &op) {
684         assert(x1.getIdx() == x2.getIdx());
685         cmpps(x1, op, _cmp_nle_us);
686     }
687
688     void uni_vcmpgtps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
689                       const Xbyak::Operand &op) {
690         vcmpgtps(x1, x2, op);
691     }
692
693     void uni_vcmpgeps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
694                       const Xbyak::Operand &op) {
695         assert(x1.getIdx() == x2.getIdx());
696         cmpps(x1, op, _cmp_nlt_us);
697     }
698
699     void uni_vcmpgeps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
700                       const Xbyak::Operand &op) {
701         vcmpps(x1, x2, op, _cmp_nlt_us);
702     }
703
704     void uni_vtestps(const Xbyak::Xmm &x1, const Xbyak::Operand &op) {
705         ptest(x1, op);
706     }
707
708     void uni_vtestps(const Xbyak::Ymm &x1, const Xbyak::Operand &op) {
709         assert(!(x1.isZMM() || op.isZMM()));
710         vtestps(x1, op);
711     }
712
713     void uni_vblendvps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
714                        const Xbyak::Operand &op, const Xbyak::Xmm &msk) {
715         assert(x1.getIdx() == x2.getIdx());
716         assert(msk.getIdx() == 0);
717         blendvps(x1, op);
718     }
719     void uni_vblendvps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
720                        const Xbyak::Operand &op, const Xbyak::Ymm &msk) {
721         vblendvps(x1, x2, op, msk);
722     }
723
724     void uni_vroundps(const Xbyak::Xmm &x, const Xbyak::Operand &op,
725                       const int imm) {
726         roundps(x, op, imm);
727     }
728     void uni_vroundps(const Xbyak::Ymm &x, const Xbyak::Operand &op,
729                       const int imm) {
730         vroundps(x, op, imm);
731     }
732     void uni_vroundps(const Xbyak::Zmm &x, const Xbyak::Operand &op,
733                       const int imm) {
734         vrndscaleps(x, op, imm & 0x3);
735     }
736
737     void uni_vcvtps2dq(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
738         cvtps2dq(x, op);
739     }
740     void uni_vcvtps2dq(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
741         vcvtps2dq(x, op);
742     }
743
744     void uni_vcvtdq2ps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
745         cvtdq2ps(x, op);
746     }
747     void uni_vcvtdq2ps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
748         vcvtdq2ps(x, op);
749     }
750
751     void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Xmm &x2) {
752         movmskps(x1.cvt64(), x2);
753     }
754     void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Ymm &x2) {
755         vmovmskps(x1, x2);
756     }
757
758     void uni_vpackssdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){
759         assert(x1.getIdx() == x1.getIdx());
760         packssdw(x1, op);
761     }
762     void uni_vpackssdw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){
763         vpackssdw(x1, x2, op);
764     }
765
766     void uni_vpackuswb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){
767         assert(x1.getIdx() == x1.getIdx());
768         packuswb(x1, op);
769     }
770     void uni_vpackuswb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){
771         vpackuswb(x1, x2, op);
772     }
773
774     void uni_vpmovsxbd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
775         pmovsxbd(x, op);
776     }
777     void uni_vpmovsxbd(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
778         vpmovsxbd(x, op);
779     }
780
781     void uni_vpmovzxbd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
782         pmovzxbd(x, op);
783     }
784     void uni_vpmovzxbd(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
785         vpmovzxbd(x, op);
786     }
787
788     void uni_vpackusdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
789         assert(x1.getIdx() == x2.getIdx());
790         packusdw(x1, op);
791     }
792     void uni_vpackusdw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
793         vpackusdw(x1, x2, op);
794     }
795
796     void uni_vpacksswb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
797         assert(x1.getIdx() == x2.getIdx());
798         packsswb(x1, op);
799     }
800     void uni_vpacksswb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
801         vpacksswb(x1, x2, op);
802     }
803
804     void uni_vpmaxsd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
805         assert(x1.getIdx() == x2.getIdx());
806         pmaxsd(x1, op);
807     }
808     void uni_vpmaxsd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
809         vpmaxsd(x1, x2, op);
810     }
811
812     void uni_vpmaxsb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
813         assert(x1.getIdx() == x2.getIdx());
814         pmaxsb(x1, op);
815     }
816     void uni_vpmaxsb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
817         vpmaxsb(x1, x2, op);
818     }
819
820     void uni_vpmaxub(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
821         assert(x1.getIdx() == x2.getIdx());
822         pmaxub(x1, op);
823     }
824     void uni_vpmaxub(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
825         vpmaxub(x1, x2, op);
826     }
827
828     void uni_vpmaddubsw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
829         assert(x1.getIdx() == x2.getIdx());
830         pmaddubsw(x1, op);
831     }
832     void uni_vpmaddubsw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
833         vpmaddubsw(x1, x2, op);
834     }
835
836     void uni_vpmaddwd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
837         assert(x1.getIdx() == x2.getIdx());
838         pmaddwd(x1, op);
839     }
840     void uni_vpmaddwd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
841         vpmaddwd(x1, x2, op);
842     }
843
844     void uni_vpmulld(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
845         assert(x1.getIdx() == x2.getIdx());
846         pmulld(x1, op);
847     }
848     void uni_vpmulld(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
849         vpmulld(x1, x2, op);
850     }
851
852     void uni_vpsubb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
853         assert(x1.getIdx() == x2.getIdx());
854         psubb(x1, op);
855     }
856     void uni_vpsubb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
857         vpsubb(x1, x2, op);
858     }
859
860     void uni_vpslldq(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::uint8 &op) {
861         assert(x1.getIdx() == x2.getIdx());
862         pslldq(x1, op);
863     }
864     void uni_vpslldq(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::uint8 &op) {
865         vpslldq(x1, x2, op);
866     }
867
868     void uni_vpand(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
869                    const Xbyak::Operand &op = Xbyak::Operand()) {
870         assert(x1.getIdx() == x2.getIdx());
871         pand(x1, op);
872     }
873     void uni_vpand(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
874                     const Xbyak::Operand &op = Xbyak::Operand()) {
875         vpand(x1, x2, op);
876     }
877
878     void uni_vpaddb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
879                     const Xbyak::Operand &op) {
880         assert(x1.getIdx() == x2.getIdx());
881         paddb(x2, op);
882     }
883     void uni_vpaddb(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2,
884                     const Xbyak::Operand &op) {
885         vpaddb(x1, x2, op);
886     }
887
888     void uni_vpshufb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
889                      const Xbyak::Operand &op) {
890         assert(x1.getIdx() == x2.getIdx());
891         pshufb(x1, op);
892     }
893
894     void uni_vpshufb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
895                      const Xbyak::Operand &op) {
896         vpshufb(x1, x2, op);
897     }
898
899     void uni_vpcmpeqd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
900                       const Xbyak::Operand &op) {
901         assert(x1.getIdx() == x2.getIdx());
902         pcmpeqd(x1, op);
903     }
904
905     void uni_vpcmpeqd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
906                       const Xbyak::Operand &op) {
907         vpcmpeqd(x1, x2, op);
908     }
909
910     void mul_by_const(const Xbyak::Reg &out,
911             const Xbyak::Reg64 &tmp, int value) {
912         // Generates a shift + add sequence for multiplicating contents of the
913         // out register by a known JIT-time value. Clobbers the tmp register.
914         //
915         // Pros compared to mul/imul:
916         // - does not require using known registers
917         // - not microcoded on Intel(R) Xeon Phi(TM) processors
918         // Still, there are probably a lot of cases when mul/imul is faster on
919         // Intel(R) Core(TM) processors. Not intended for critical path.
920
921         // TODO: detect when overflow is emminent (Roma)
922         // TODO: detect when using mul/imul is a better option (Roma)
923
924         int p = 0; // the current power of 2
925         int old_p = 0; // the last seen power of 2 such that value[old_p] != 0
926
927         xor_(tmp, tmp);
928         while (value) {
929             if (value & 1) {
930                 int shift = p - old_p;
931                 if (shift) {
932                     shl(out, shift);
933                     old_p = p;
934                 }
935                 add(tmp, out);
936             }
937             value >>= 1;
938             p++;
939         }
940         mov(out, tmp);
941     }
942
943     void dump_code(const Xbyak::uint8 *code) const {
944         if (code) {
945             static int counter = 0;
946 #define MAX_FNAME_LEN 256
947             char fname[MAX_FNAME_LEN + 1];
948             snprintf(fname, MAX_FNAME_LEN, "mkldnn_dump_%s.%d.bin", name(),
949                     counter);
950             counter++;
951
952             FILE *fp = mkldnn_fopen(fname, "w+");
953             // Failure to dump code is not fatal
954             if (fp) {
955                 size_t unused = fwrite(code, getSize(), 1, fp);
956                 UNUSED(unused);
957                 fclose(fp);
958             }
959         }
960 #undef MAX_FNAME_LEN
961     }
962
963     void register_code(const Xbyak::uint8 *code) const {
964 #ifdef JIT_PROFILING_VTUNE
965         if (iJIT_IsProfilingActive() == iJIT_SAMPLING_ON) {
966             auto jmethod = iJIT_Method_Load();
967             jmethod.method_id = iJIT_GetNewMethodID();
968             jmethod.method_name = (char *)name();
969             jmethod.class_file_name = NULL;
970             jmethod.source_file_name = (char *)source_file();
971             jmethod.method_load_address = (void *)code;
972             jmethod.method_size = getSize();
973
974             iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED,
975                     (void*)&jmethod);
976         }
977 #endif
978     }
979
980 public:
981     jit_generator(
982         void *code_ptr = nullptr,
983         size_t code_size = 256 * 1024
984         ) : Xbyak::CodeGenerator(code_size, code_ptr)
985     {
986     }
987     virtual ~jit_generator() {}
988
989     virtual const char *name() const = 0;
990     virtual const char *source_file() const = 0;
991
992     // XXX: use normal_case name and update all callees (?)
993     const Xbyak::uint8 *getCode() {
994         const Xbyak::uint8 *code = CodeGenerator::getCode();
995         register_code(code);
996
997         if (mkldnn_jit_dump())
998             dump_code(code);
999
1000         return code;
1001     }
1002
1003     template<typename F> const F getCode() {
1004         // XXX (Roma): Xbyak code probably has a bug here
1005         return (const F)getCode();
1006     }
1007 };
1008
1009 }
1010 }
1011 }
1012
1013 #endif