Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_generator.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_JIT_AVX2_GENERATOR_HPP
18 #define CPU_JIT_AVX2_GENERATOR_HPP
19
20 #include <limits.h>
21 #include "cpu_isa_traits.hpp"
22
23 #include "utils.hpp"
24 #include "mkldnn_thread.hpp"
25
26 #ifdef JIT_PROFILING_VTUNE
27 #include "jitprofiling.h"
28 #endif
29
30 #if defined(_WIN32) && !defined(__GNUC__)
31 #   define STRUCT_ALIGN(al, ...) __declspec(align(al)) __VA_ARGS__
32 #else
33 #   define STRUCT_ALIGN(al, ...) __VA_ARGS__ __attribute__((__aligned__(al)))
34 #endif
35
36 #if defined(_WIN32)
37 #   define OFFSET_SHADOWSPACE 0x28
38 #endif
39
40 #define DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_name) \
41     const char *name() const override { return STRINGIFY(jit_name); } \
42     const char *source_file() const override { return __FILE__; \
43     }
44
45 namespace mkldnn {
46 namespace impl {
47 namespace cpu {
48
49 // TODO: move this to jit_generator class?
50 namespace {
51
52 typedef enum {
53     PAGE_4K = 4096,
54     PAGE_2M = 2097152,
55 } cpu_page_size_t;
56
57 // TODO: move this somewhere else? Although this is only used by jit kernels
58 // (Roma)
59 static inline int float2int(float x) {
60     union {
61         float vfloat;
62         int vint;
63     } cvt;
64     cvt.vfloat = x;
65     return cvt.vint;
66 }
67
68 // TODO: A GPR class that hides ABI details from the JIT kernels and allows
69 // numbering registers from 0 to 14 (x86_64) / 6 (x32) (gpr0, gpr1, ...) and
70 // stack register (sr).
71 //
72 // This will allow using syntax like this:
73 //
74 // param = gpr0;
75 // reg_input = gpr0;
76 // reg_output =  gpr1;
77 // ...
78 //
79 // #ifndef XBYAK64
80 // mov(param, ptr[sr])
81 // #endif
82 //
83 // (Roma)
84
85 #ifdef XBYAK64
86 constexpr Xbyak::Operand::Code abi_save_gpr_regs[] = {
87     Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12,
88     Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15,
89 #ifdef _WIN32
90     Xbyak::Operand::RDI, Xbyak::Operand::RSI,
91 #endif
92 };
93
94 #ifdef _WIN32
95 static const Xbyak::Reg64 abi_param1(Xbyak::Operand::RCX),
96              abi_param2(Xbyak::Operand::RDX),
97              abi_param3(Xbyak::Operand::R8),
98              abi_param4(Xbyak::Operand::R9),
99              abi_not_param1(Xbyak::Operand::RDI);
100 #else
101 static const Xbyak::Reg64 abi_param1(Xbyak::Operand::RDI),
102              abi_param2(Xbyak::Operand::RSI),
103              abi_param3(Xbyak::Operand::RDX),
104              abi_param4(Xbyak::Operand::RCX),
105              abi_param5(Xbyak::Operand::R8),
106              abi_param6(Xbyak::Operand::R9),
107              abi_not_param1(Xbyak::Operand::RCX);
108 #endif
109 #endif
110
111 inline unsigned int get_cache_size(int level, bool per_core = true){
112     unsigned int l = level - 1;
113     // Currently, if XByak is not able to fetch the cache topology
114     // we default to 32KB of L1, 512KB of L2 and 1MB of L3 per core.
115     if (cpu.getDataCacheLevels() == 0){
116         const int L1_cache_per_core = 32000;
117         const int L2_cache_per_core = 512000;
118         const int L3_cache_per_core = 1024000;
119         int num_cores = per_core ? 1 : mkldnn_get_max_threads();
120         switch(l){
121         case(0): return L1_cache_per_core * num_cores;
122         case(1): return L2_cache_per_core * num_cores;
123         case(2): return L3_cache_per_core * num_cores;
124         default: return 0;
125         }
126     }
127     if (l < cpu.getDataCacheLevels()) {
128         return cpu.getDataCacheSize(l)
129             / (per_core ? cpu.getCoresSharingDataCache(l) : 1);
130     } else
131         return 0;
132 }
133
134 }
135
136 class jit_generator : public Xbyak::CodeGenerator
137 {
138 private:
139     const size_t xmm_len = 16;
140 #ifdef _WIN32
141     const size_t xmm_to_preserve_start = 6;
142     const size_t xmm_to_preserve = 10;
143 #else
144     const size_t xmm_to_preserve_start = 0;
145     const size_t xmm_to_preserve = 0;
146 #endif
147
148     const size_t num_abi_save_gpr_regs
149         = sizeof(abi_save_gpr_regs) / sizeof(abi_save_gpr_regs[0]);
150
151     const size_t size_of_abi_save_regs
152         = num_abi_save_gpr_regs * rax.getBit() / 8
153         + xmm_to_preserve * xmm_len;
154
155 public:
156     enum {
157         _cmp_eq_oq = 0u,
158         _cmp_lt_os = 1u,
159         _cmp_le_os = 2u,
160         _cmp_neq_uq = 4u,
161         _cmp_nlt_us = 5u,
162         _cmp_nle_us = 6u,
163
164         _op_floor = 1u,
165     };
166
167     Xbyak::Reg64 param1 = abi_param1;
168     const int EVEX_max_8b_offt = 0x200;
169     const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp;
170
171     inline size_t get_size_of_abi_save_regs() {
172         return size_of_abi_save_regs;
173     }
174
175     void preamble() {
176         if (xmm_to_preserve) {
177             sub(rsp, xmm_to_preserve * xmm_len);
178             for (size_t i = 0; i < xmm_to_preserve; ++i)
179                 movdqu(ptr[rsp + i * xmm_len], Xbyak::Xmm(xmm_to_preserve_start + i));
180         }
181         for (size_t i = 0; i < num_abi_save_gpr_regs; ++i)
182             push(Xbyak::Reg64(abi_save_gpr_regs[i]));
183         if (mayiuse(avx512_common)) {
184             mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
185         }
186     }
187
188     void mic_prefetcht0(Xbyak::Address a) {
189         if (mayiuse(avx512_mic))
190             prefetcht0(a);
191     }
192
193     void mic_prefetcht1(Xbyak::Address a) {
194         if (mayiuse(avx512_mic))
195             prefetcht1(a);
196     }
197
198     void mic_prefetcht2(Xbyak::Address a) {
199         if (mayiuse(avx512_mic))
200             prefetcht2(a);
201     }
202
203     void uni_vzeroupper() {
204         if (mayiuse(avx) && !mayiuse(avx512_mic))
205             vzeroupper();
206     }
207
208     void postamble() {
209         for (size_t i = 0; i < num_abi_save_gpr_regs; ++i)
210             pop(Xbyak::Reg64(abi_save_gpr_regs[num_abi_save_gpr_regs - 1 - i]));
211         if (xmm_to_preserve) {
212             for (size_t i = 0; i < xmm_to_preserve; ++i)
213                 movdqu(Xbyak::Xmm(xmm_to_preserve_start + i), ptr[rsp + i * xmm_len]);
214             add(rsp, xmm_to_preserve * xmm_len);
215         }
216         uni_vzeroupper();
217         ret();
218     }
219
220     template<typename T>
221     Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base,
222             T raw_offt, bool bcast = false)
223     {
224         using Xbyak::Zmm;
225         using Xbyak::Reg64;
226         using Xbyak::Address;
227         using Xbyak::RegExp;
228
229         assert(raw_offt <= INT_MAX);
230         auto offt = static_cast<int>(raw_offt);
231
232         int scale = 0;
233
234         if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) {
235             offt = offt - 2 * EVEX_max_8b_offt;
236             scale = 1;
237         } else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) {
238             offt = offt - 4 * EVEX_max_8b_offt;
239             scale = 2;
240         }
241
242         auto re = RegExp() + base + offt;
243         if (scale)
244             re = re + reg_EVEX_max_8b_offt * scale;
245
246         if (bcast)
247             return zword_b [re];
248         else
249             return zword [re];
250     }
251
252     Xbyak::Address make_safe_addr(const Xbyak::Reg64 &reg_out, size_t offt,
253         const Xbyak::Reg64 &tmp_reg, bool bcast = false) {
254         if (offt > INT_MAX) {
255             mov(tmp_reg, offt);
256             return bcast ? ptr_b[reg_out + tmp_reg] : ptr[reg_out + tmp_reg];
257         } else {
258             return bcast ? ptr_b[reg_out + offt] : ptr[reg_out + offt];
259         }
260     }
261
262     Xbyak::Address EVEX_compress_addr_safe(const Xbyak::Reg64 &base,
263         size_t raw_offt, const Xbyak::Reg64 &reg_offt, bool bcast = false) {
264         if (raw_offt > INT_MAX) {
265             return make_safe_addr(base, raw_offt, reg_offt, bcast);
266         } else {
267             return EVEX_compress_addr(base, raw_offt, bcast);
268         }
269     }
270
271     void safe_add(const Xbyak::Reg64 &base, size_t raw_offt,
272         const Xbyak::Reg64 &reg_offt) {
273         if (raw_offt > INT_MAX) {
274             mov(reg_offt, raw_offt);
275             add(base, reg_offt);
276         } else {
277             add(base, raw_offt);
278         }
279     }
280
281     void safe_sub(const Xbyak::Reg64 &base, size_t raw_offt,
282         const Xbyak::Reg64 &reg_offt) {
283         if (raw_offt > INT_MAX) {
284             mov(reg_offt, raw_offt);
285             sub(base, reg_offt);
286         } else {
287             sub(base, raw_offt);
288         }
289     }
290
291     // Disallow char-based labels completely
292     void L(const char *label) = delete;
293     void L(Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); }
294
295     void uni_vpxor(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
296                    const Xbyak::Operand &op) {
297         assert(x1.getIdx() == x2.getIdx());
298         pxor(x2, op);
299     }
300     void uni_vpxor(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
301                    const Xbyak::Operand &op) {
302         if (mayiuse(avx2)) {
303             vpxor(x1, x2, op);
304         } else {
305             vxorps(x1, x2, op);
306         }
307     }
308     void uni_vpxor(const Xbyak::Zmm &x1, const Xbyak::Zmm &x2,
309                    const Xbyak::Operand &op) {
310         vpxord(x1, x2, op);
311     }
312
313     void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Xmm &x) {
314         movss(addr, x);
315     }
316     void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Ymm &x) {
317         vmovss(addr, x);
318     }
319     void uni_vmovss(const Xbyak::Xmm &x, const Xbyak::Address& addr) {
320         movss(x, addr);
321     }
322     void uni_vmovss(const Xbyak::Ymm &x, const Xbyak::Address& addr) {
323         vmovss(x, addr);
324     }
325
326     void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Xmm &x) {
327         movsd(addr, x);
328     }
329     void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Ymm &x) {
330         vmovsd(addr, x);
331     }
332     void uni_vmovsd(const Xbyak::Xmm &x, const Xbyak::Address& addr) {
333         movsd(x, addr);
334     }
335     void uni_vmovsd(const Xbyak::Ymm &x, const Xbyak::Address& addr) {
336         vmovsd(x, addr);
337     }
338
339     void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
340         movdqu(addr, x);
341     }
342     void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
343         vmovdqu(addr, x);
344     }
345     void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Zmm &x) {
346         vmovdqu32(addr, x);
347     }
348
349     void uni_vmovdqu(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
350         movdqu(x, addr);
351     }
352     void uni_vmovdqu(const Xbyak::Ymm &x, const Xbyak::Address &addr) {
353         vmovdqu(x, addr);
354     }
355     void uni_vmovdqu(const Xbyak::Zmm &x, const Xbyak::Address &addr) {
356         vmovdqu32(x, addr);
357     }
358
359     void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
360         movups(addr, x);
361     }
362     void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
363         vmovups(addr, x);
364     }
365
366     void uni_vmovups(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
367         movups(x, op);
368     }
369     void uni_vmovups(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
370         vmovups(x, op);
371     }
372
373     void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
374         movntps(addr, x);
375     }
376     void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
377         vmovntps(addr, x);
378     }
379
380     void uni_vbroadcastss(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
381         movss(x, op);
382         shufps(x, x, 0x0);
383     }
384     void uni_vbroadcastss(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
385         if (op.isMEM() || mayiuse(avx2)) {
386             vbroadcastss(x, op);
387         } else {
388             Xbyak::Xmm t(x.getIdx());
389             if (t.getIdx() != op.getIdx()) movss(t, op);
390             vinsertf128(x, x, t, 1);
391             vshufps(x, x, x, 0);
392         }
393     }
394
395     void uni_vpbroadcastd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
396         movsd(x, op);
397         pshufd(x, x, 0x0);
398     }
399     void uni_vpbroadcastd(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
400         if (mayiuse(avx2)) {
401             vpbroadcastd(x, op);
402         } else {
403             Xbyak::Xmm t(x.getIdx());
404             if (t.getIdx() != op.getIdx()) movsd(t, op);
405             vinsertf128(x, x, t, 1);
406             vshufps(x, x, x, 0);
407         }
408     }
409
410     void uni_vrcpss(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
411         rcpss(x, op);
412     }
413     void uni_vrcpss(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2) {
414         Xbyak::Xmm x1_(x1.getIdx());
415         Xbyak::Xmm x2_(x2.getIdx());
416         vrcpss(x1_, x1_, x2_);
417     }
418     void uni_vrcpss(const Xbyak::Ymm &x, const Xbyak::Address &op) {
419         Xbyak::Xmm x_(x.getIdx());
420         vrcpss(x_, x_, op);
421     }
422
423     void uni_vrcpps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
424         rcpps(x, op);
425     }
426     void uni_vrcpps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
427         vrcpps(x, op);
428     }
429     void uni_vrcpps(const Xbyak::Zmm &x, const Xbyak::Operand &op) {
430         vrcp14ps(x, op);
431     }
432
433     void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
434                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
435         assert(x.getIdx() == op1.getIdx());
436         divps(x, op2);
437     }
438     void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
439                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
440         vdivps(x, op1, op2);
441     }
442
443     void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
444                     const Xbyak::Operand &op2, const Xbyak::Xmm &buf) {
445         movups(buf, op1);
446         divps(buf, op2);
447         if (x.getIdx() != buf.getIdx()) {
448             movups(x, buf);
449         }
450     }
451
452     void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
453                     const Xbyak::Operand &op2, const Xbyak::Ymm &buf) {
454         vdivps(x, op1, op2);
455     }
456
457     void uni_vaddps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
458                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
459         assert(x.getIdx() == op1.getIdx());
460         addps(x, op2);
461     }
462     void uni_vaddps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
463                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
464         vaddps(x, op1, op2);
465     }
466
467     void uni_vpsignd(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2,
468                      const Xbyak::Operand& op) {
469         assert(x1.getIdx() == x2.getIdx());
470         psignd(x1, op);
471     }
472     void uni_vpsignd(const Xbyak::Ymm& x1, const Xbyak::Ymm& x2,
473                      const Xbyak::Operand& op) {
474         vpsignd(x1, x2, op);
475     }
476
477     void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
478                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
479         assert(x.getIdx() == op1.getIdx());
480         subps(x, op2);
481     }
482     void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
483                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
484         vsubps(x, op1, op2);
485     }
486
487     void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
488                     const Xbyak::Operand &op2, const Xbyak::Xmm &buf) {
489         movups(buf, op1);
490         subps(buf, op2);
491         if (x.getIdx() != buf.getIdx()) {
492             movups(x, buf);
493         }
494     }
495
496     void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
497                     const Xbyak::Operand &op2, const Xbyak::Ymm &buf) {
498         vsubps(x, op1, op2);
499     }
500
501     void uni_vmulps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
502                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
503         assert(x.getIdx() == op1.getIdx());
504         mulps(x, op2);
505     }
506     void uni_vmulps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
507                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
508         vmulps(x, op1, op2);
509     }
510
511     void uni_vfmadd213ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
512                          const Xbyak::Operand &op) {
513         mulps(x1, x2);
514         addps(x1, op);
515     }
516     void uni_vfmadd213ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
517                          const Xbyak::Operand &op) {
518         vfmadd213ps(x1, x2, op);
519     }
520
521     void uni_vfmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
522                          const Xbyak::Operand &op) {
523         mulps(x2, op);
524         addps(x1, x2);
525     }
526     void uni_vfmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
527                          const Xbyak::Operand &op) {
528         vfmadd231ps(x1, x2, op);
529     }
530
531     void uni_vfnmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
532                            const Xbyak::Operand &op) {
533         mulps(x2, op);
534         subps(x1, x2);
535     }
536
537     void uni_vfnmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
538                            const Xbyak::Operand &op) {
539         vfnmadd231ps(x1, x2, op);
540     }
541
542     void uni_vsqrtps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
543         sqrtps(x, op);
544     }
545     void uni_vsqrtps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
546         vsqrtps(x, op);
547     }
548
549     void uni_vpaddd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
550                     const Xbyak::Operand &op) {
551         assert(x1.getIdx() == x2.getIdx());
552         paddd(x2, op);
553     }
554     void uni_vpaddd(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2,
555                     const Xbyak::Operand &op) {
556         vpaddd(x1, x2, op);
557     }
558
559     void uni_vandps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
560                     const Xbyak::Operand &op = Xbyak::Operand()) {
561         assert(x1.getIdx() == x2.getIdx());
562         andps(x1, op);
563     }
564     void uni_vandps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
565                     const Xbyak::Operand &op = Xbyak::Operand()) {
566         if (!mayiuse(avx512_common) || x1.getBit() < 512)
567             vandps(x1, x2, op);
568         else
569             vpandd(x1, x2, op);
570     }
571
572     void uni_vorps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
573                     const Xbyak::Operand &op = Xbyak::Operand()) {
574         assert(x1.getIdx() == x2.getIdx());
575         orps(x1, op);
576     }
577     void uni_vorps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
578                     const Xbyak::Operand &op = Xbyak::Operand()) {
579         if (!mayiuse(avx512_common) || x1.getBit() < 512)
580             vorps(x1, x2, op);
581         else
582             vpord(x1, x2, op);
583     }
584
585     void uni_vpslld(const Xbyak::Xmm &x, const Xbyak::Operand &op,
586                     const int imm) {
587         assert(x.getIdx() == op.getIdx());
588         pslld(x, imm);
589     }
590     void uni_vpslld(const Xbyak::Ymm &x, const Xbyak::Operand &op,
591                     const int imm) {
592         vpslld(x, op, imm);
593     }
594
595     void uni_vpsrld(const Xbyak::Xmm &x, const Xbyak::Operand &op,
596                     const int imm) {
597         assert(x.getIdx() == op.getIdx());
598         psrld(x, imm);
599     }
600     void uni_vpsrld(const Xbyak::Ymm &x, const Xbyak::Operand &op,
601                     const int imm) {
602         vpsrld(x, op, imm);
603     }
604
605     void uni_vmaxps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
606                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
607         assert(x.getIdx() == op1.getIdx());
608         maxps(x, op2);
609     }
610     void uni_vmaxps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
611                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
612         vmaxps(x, op1, op2);
613     }
614
615     void uni_vminps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
616                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
617         assert(x.getIdx() == op1.getIdx());
618         minps(x, op2);
619     }
620     void uni_vminps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
621                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
622         vminps(x, op1, op2);
623     }
624
625     void uni_vcmpgtps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
626                       const Xbyak::Operand &op) {
627         assert(x1.getIdx() == x2.getIdx());
628         cmpps(x1, op, _cmp_nle_us);
629     }
630
631     void uni_vcmpgtps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
632                       const Xbyak::Operand &op) {
633         vcmpgtps(x1, x2, op);
634     }
635
636     void uni_vcmpgeps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
637                       const Xbyak::Operand &op) {
638         assert(x1.getIdx() == x2.getIdx());
639         cmpps(x1, op, _cmp_nlt_us);
640     }
641
642     void uni_vcmpgeps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
643                       const Xbyak::Operand &op) {
644         vcmpps(x1, x2, op, _cmp_nlt_us);
645     }
646
647     void uni_vtestps(const Xbyak::Xmm &x1, const Xbyak::Operand &op) {
648         ptest(x1, op);
649     }
650
651     void uni_vtestps(const Xbyak::Ymm &x1, const Xbyak::Operand &op) {
652         assert(!(x1.isZMM() || op.isZMM()));
653         vtestps(x1, op);
654     }
655
656     void uni_vblendvps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
657                        const Xbyak::Operand &op, const Xbyak::Xmm &msk) {
658         assert(x1.getIdx() == x2.getIdx());
659         assert(msk.getIdx() == 0);
660         blendvps(x1, op);
661     }
662     void uni_vblendvps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
663                        const Xbyak::Operand &op, const Xbyak::Ymm &msk) {
664         vblendvps(x1, x2, op, msk);
665     }
666
667     void uni_vroundps(const Xbyak::Xmm &x, const Xbyak::Operand &op,
668                       const int imm) {
669         roundps(x, op, imm);
670     }
671     void uni_vroundps(const Xbyak::Ymm &x, const Xbyak::Operand &op,
672                       const int imm) {
673         vroundps(x, op, imm);
674     }
675
676     void uni_vcvtps2dq(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
677         cvtps2dq(x, op);
678     }
679     void uni_vcvtps2dq(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
680         vcvtps2dq(x, op);
681     }
682
683     void uni_vcvtdq2ps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
684         cvtdq2ps(x, op);
685     }
686     void uni_vcvtdq2ps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
687         vcvtdq2ps(x, op);
688     }
689
690     void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Xmm &x2) {
691         movmskps(x1.cvt64(), x2);
692     }
693     void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Ymm &x2) {
694         vmovmskps(x1, x2);
695     }
696
697     void uni_vpackssdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){
698         assert(x1.getIdx() == x1.getIdx());
699         packssdw(x1, op);
700     }
701     void uni_vpackssdw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){
702         vpackssdw(x1, x2, op);
703     }
704
705     void uni_vpackuswb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){
706         assert(x1.getIdx() == x1.getIdx());
707         packuswb(x1, op);
708     }
709     void uni_vpackuswb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){
710         vpackuswb(x1, x2, op);
711     }
712
713     void uni_vpmovsxbd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
714         pmovsxbd(x, op);
715     }
716     void uni_vpmovsxbd(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
717         vpmovsxbd(x, op);
718     }
719
720     void uni_vpmovzxbd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
721         pmovzxbd(x, op);
722     }
723     void uni_vpmovzxbd(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
724         vpmovzxbd(x, op);
725     }
726
727     void uni_vpackusdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
728         assert(x1.getIdx() == x2.getIdx());
729         packusdw(x1, op);
730     }
731     void uni_vpackusdw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
732         vpackusdw(x1, x2, op);
733     }
734
735     void uni_vpacksswb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
736         assert(x1.getIdx() == x2.getIdx());
737         packsswb(x1, op);
738     }
739     void uni_vpacksswb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
740         vpacksswb(x1, x2, op);
741     }
742
743     void uni_vpmaxsd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
744         assert(x1.getIdx() == x2.getIdx());
745         pmaxsd(x1, op);
746     }
747     void uni_vpmaxsd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
748         vpmaxsd(x1, x2, op);
749     }
750
751     void uni_vpmaxsb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
752         assert(x1.getIdx() == x2.getIdx());
753         pmaxsb(x1, op);
754     }
755     void uni_vpmaxsb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
756         vpmaxsb(x1, x2, op);
757     }
758
759     void uni_vpmaxub(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
760         assert(x1.getIdx() == x2.getIdx());
761         pmaxub(x1, op);
762     }
763     void uni_vpmaxub(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
764         vpmaxub(x1, x2, op);
765     }
766
767     void uni_vpmaddubsw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
768         assert(x1.getIdx() == x2.getIdx());
769         pmaddubsw(x1, op);
770     }
771     void uni_vpmaddubsw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
772         vpmaddubsw(x1, x2, op);
773     }
774
775     void uni_vpmaddwd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
776         assert(x1.getIdx() == x2.getIdx());
777         pmaddwd(x1, op);
778     }
779     void uni_vpmaddwd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
780         vpmaddwd(x1, x2, op);
781     }
782
783     void uni_vpmulld(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
784         assert(x1.getIdx() == x2.getIdx());
785         pmulld(x1, op);
786     }
787     void uni_vpmulld(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
788         vpmulld(x1, x2, op);
789     }
790
791     void uni_vpsubb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
792         assert(x1.getIdx() == x2.getIdx());
793         psubb(x1, op);
794     }
795     void uni_vpsubb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
796         vpsubb(x1, x2, op);
797     }
798
799     void uni_vpslldq(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::uint8 &op) {
800         assert(x1.getIdx() == x2.getIdx());
801         pslldq(x1, op);
802     }
803     void uni_vpslldq(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::uint8 &op) {
804         vpslldq(x1, x2, op);
805     }
806
807     void uni_vpand(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
808                    const Xbyak::Operand &op = Xbyak::Operand()) {
809         assert(x1.getIdx() == x2.getIdx());
810         pand(x1, op);
811     }
812     void uni_vpand(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
813                     const Xbyak::Operand &op = Xbyak::Operand()) {
814         vpand(x1, x2, op);
815     }
816
817     void uni_vpaddb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
818                     const Xbyak::Operand &op) {
819         assert(x1.getIdx() == x2.getIdx());
820         paddb(x2, op);
821     }
822     void uni_vpaddb(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2,
823                     const Xbyak::Operand &op) {
824         vpaddb(x1, x2, op);
825     }
826
827     void uni_vpshufb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
828                      const Xbyak::Operand &op) {
829         assert(x1.getIdx() == x2.getIdx());
830         pshufb(x1, op);
831     }
832
833     void uni_vpshufb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
834                      const Xbyak::Operand &op) {
835         vpshufb(x1, x2, op);
836     }
837
838     void mul_by_const(const Xbyak::Reg &out,
839             const Xbyak::Reg64 &tmp, int value) {
840         // Generates a shift + add sequence for multiplicating contents of the
841         // out register by a known JIT-time value. Clobbers the tmp register.
842         //
843         // Pros compared to mul/imul:
844         // - does not require using known registers
845         // - not microcoded on Intel(R) Xeon Phi(TM) processors
846         // Still, there are probably a lot of cases when mul/imul is faster on
847         // Intel(R) Core(TM) processors. Not intended for critical path.
848
849         // TODO: detect when overflow is emminent (Roma)
850         // TODO: detect when using mul/imul is a better option (Roma)
851
852         int p = 0; // the current power of 2
853         int old_p = 0; // the last seen power of 2 such that value[old_p] != 0
854
855         xor_(tmp, tmp);
856         while (value) {
857             if (value & 1) {
858                 int shift = p - old_p;
859                 if (shift) {
860                     shl(out, shift);
861                     old_p = p;
862                 }
863                 add(tmp, out);
864             }
865             value >>= 1;
866             p++;
867         }
868         mov(out, tmp);
869     }
870
871     void dump_code(const Xbyak::uint8 *code) const {
872         if (code) {
873             static int counter = 0;
874 #define MAX_FNAME_LEN 256
875             char fname[MAX_FNAME_LEN + 1];
876             snprintf(fname, MAX_FNAME_LEN, "mkldnn_dump_%s.%d.bin", name(),
877                     counter);
878             counter++;
879
880             FILE *fp = mkldnn_fopen(fname, "w+");
881             // Failure to dump code is not fatal
882             if (fp) {
883                 size_t unused = fwrite(code, getSize(), 1, fp);
884                 UNUSED(unused);
885                 fclose(fp);
886             }
887         }
888 #undef MAX_FNAME_LEN
889     }
890
891     void register_code(const Xbyak::uint8 *code) const {
892 #ifdef JIT_PROFILING_VTUNE
893         if (iJIT_IsProfilingActive() == iJIT_SAMPLING_ON) {
894             auto jmethod = iJIT_Method_Load();
895             jmethod.method_id = iJIT_GetNewMethodID();
896             jmethod.method_name = (char *)name();
897             jmethod.class_file_name = NULL;
898             jmethod.source_file_name = (char *)source_file();
899             jmethod.method_load_address = (void *)code;
900             jmethod.method_size = getSize();
901
902             iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED,
903                     (void*)&jmethod);
904         }
905 #endif
906     }
907
908 public:
909     jit_generator(
910         void *code_ptr = nullptr,
911         size_t code_size = 256 * 1024
912         ) : Xbyak::CodeGenerator(code_size, code_ptr)
913     {
914     }
915     virtual ~jit_generator() {}
916
917     virtual const char *name() const = 0;
918     virtual const char *source_file() const = 0;
919
920     // XXX: use normal_case name and update all callees (?)
921     const Xbyak::uint8 *getCode() {
922         const Xbyak::uint8 *code = CodeGenerator::getCode();
923         register_code(code);
924
925         if (mkldnn_jit_dump())
926             dump_code(code);
927
928         return code;
929     }
930
931     template<typename F> const F getCode() {
932         // XXX (Roma): Xbyak code probably has a bug here
933         return (const F)getCode();
934     }
935 };
936
937 }
938 }
939 }
940
941 #endif