Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_generator.hpp
index b72ed2d..b247724 100644 (file)
@@ -102,6 +102,8 @@ static const Xbyak::Reg64 abi_param1(Xbyak::Operand::RDI),
              abi_param2(Xbyak::Operand::RSI),
              abi_param3(Xbyak::Operand::RDX),
              abi_param4(Xbyak::Operand::RCX),
+             abi_param5(Xbyak::Operand::R8),
+             abi_param6(Xbyak::Operand::R9),
              abi_not_param1(Xbyak::Operand::RCX);
 #endif
 #endif
@@ -110,7 +112,7 @@ inline unsigned int get_cache_size(int level, bool per_core = true){
     unsigned int l = level - 1;
     // Currently, if XByak is not able to fetch the cache topology
     // we default to 32KB of L1, 512KB of L2 and 1MB of L3 per core.
-    if (cpu.data_cache_levels == 0){
+    if (cpu.getDataCacheLevels() == 0){
         const int L1_cache_per_core = 32000;
         const int L2_cache_per_core = 512000;
         const int L3_cache_per_core = 1024000;
@@ -122,31 +124,15 @@ inline unsigned int get_cache_size(int level, bool per_core = true){
         default: return 0;
         }
     }
-    if (l < cpu.data_cache_levels) {
-        return cpu.data_cache_size[l]
-            / (per_core ? cpu.cores_sharing_data_cache[l] : 1);
+    if (l < cpu.getDataCacheLevels()) {
+        return cpu.getDataCacheSize(l)
+            / (per_core ? cpu.getCoresSharingDataCache(l) : 1);
     } else
         return 0;
 }
 
 }
 
-// TODO (Roma): move all_same to a more appropriate location
-
-template <typename T, typename U, typename... Us>
-struct all_same : std::false_type {};
-
-template <typename T, typename... Us>
-struct all_same<T, T, Us...> : all_same<T, Us...> { };
-
-template <typename T>
-struct all_same<T, T> : std::true_type {};
-
-struct jit_code_injection {
-    const Xbyak::uint8* code;
-    size_t size;
-};
-
 class jit_generator : public Xbyak::CodeGenerator
 {
 private:
@@ -174,6 +160,8 @@ public:
         _cmp_neq_uq = 4u,
         _cmp_nlt_us = 5u,
         _cmp_nle_us = 6u,
+
+        _op_floor = 1u,
     };
 
     Xbyak::Reg64 param1 = abi_param1;
@@ -302,7 +290,7 @@ public:
 
     // Disallow char-based labels completely
     void L(const char *label) = delete;
-    void L(const Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); }
+    void L(Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); }
 
     void uni_vpxor(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
                    const Xbyak::Operand &op) {
@@ -322,6 +310,32 @@ public:
         vpxord(x1, x2, op);
     }
 
+    void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Xmm &x) {
+        movss(addr, x);
+    }
+    void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Ymm &x) {
+        vmovss(addr, x);
+    }
+    void uni_vmovss(const Xbyak::Xmm &x, const Xbyak::Address& addr) {
+        movss(x, addr);
+    }
+    void uni_vmovss(const Xbyak::Ymm &x, const Xbyak::Address& addr) {
+        vmovss(x, addr);
+    }
+
+    void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Xmm &x) {
+        movsd(addr, x);
+    }
+    void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Ymm &x) {
+        vmovsd(addr, x);
+    }
+    void uni_vmovsd(const Xbyak::Xmm &x, const Xbyak::Address& addr) {
+        movsd(x, addr);
+    }
+    void uni_vmovsd(const Xbyak::Ymm &x, const Xbyak::Address& addr) {
+        vmovsd(x, addr);
+    }
+
     void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
         movdqu(addr, x);
     }
@@ -393,6 +407,29 @@ public:
         }
     }
 
+    void uni_vrcpss(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
+        rcpss(x, op);
+    }
+    void uni_vrcpss(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2) {
+        Xbyak::Xmm x1_(x1.getIdx());
+        Xbyak::Xmm x2_(x2.getIdx());
+        vrcpss(x1_, x1_, x2_);
+    }
+    void uni_vrcpss(const Xbyak::Ymm &x, const Xbyak::Address &op) {
+        Xbyak::Xmm x_(x.getIdx());
+        vrcpss(x_, x_, op);
+    }
+
+    void uni_vrcpps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
+        rcpps(x, op);
+    }
+    void uni_vrcpps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
+        vrcpps(x, op);
+    }
+    void uni_vrcpps(const Xbyak::Zmm &x, const Xbyak::Operand &op) {
+        vrcp14ps(x, op);
+    }
+
     void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
                     const Xbyak::Operand &op2 = Xbyak::Operand()) {
         assert(x.getIdx() == op1.getIdx());
@@ -519,24 +556,30 @@ public:
         vpaddd(x1, x2, op);
     }
 
-    void uni_vandps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
-                    const Xbyak::Operand &op2 = Xbyak::Operand()) {
-        assert(x.getIdx() == op1.getIdx());
-        andps(x, op2);
+    void uni_vandps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
+                    const Xbyak::Operand &op = Xbyak::Operand()) {
+        assert(x1.getIdx() == x2.getIdx());
+        andps(x1, op);
     }
-    void uni_vandps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
-                    const Xbyak::Operand &op2 = Xbyak::Operand()) {
-        vandps(x, op1, op2);
+    void uni_vandps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
+                    const Xbyak::Operand &op = Xbyak::Operand()) {
+        if (!mayiuse(avx512_common) || x1.getBit() < 512)
+            vandps(x1, x2, op);
+        else
+            vpandd(x1, x2, op);
     }
 
-    void uni_vorps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
-                    const Xbyak::Operand &op2 = Xbyak::Operand()) {
-        assert(x.getIdx() == op1.getIdx());
-        orps(x, op2);
+    void uni_vorps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
+                    const Xbyak::Operand &op = Xbyak::Operand()) {
+        assert(x1.getIdx() == x2.getIdx());
+        orps(x1, op);
     }
-    void uni_vorps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
-                    const Xbyak::Operand &op2 = Xbyak::Operand()) {
-        vorps(x, op1, op2);
+    void uni_vorps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
+                    const Xbyak::Operand &op = Xbyak::Operand()) {
+        if (!mayiuse(avx512_common) || x1.getBit() < 512)
+            vorps(x1, x2, op);
+        else
+            vpord(x1, x2, op);
     }
 
     void uni_vpslld(const Xbyak::Xmm &x, const Xbyak::Operand &op,
@@ -582,16 +625,38 @@ public:
     void uni_vcmpgtps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
                       const Xbyak::Operand &op) {
         assert(x1.getIdx() == x2.getIdx());
-        cmpps(x1, op, 0x6);
+        cmpps(x1, op, _cmp_nle_us);
     }
+
     void uni_vcmpgtps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
                       const Xbyak::Operand &op) {
         vcmpgtps(x1, x2, op);
     }
 
+    void uni_vcmpgeps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
+                      const Xbyak::Operand &op) {
+        assert(x1.getIdx() == x2.getIdx());
+        cmpps(x1, op, _cmp_nlt_us);
+    }
+
+    void uni_vcmpgeps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
+                      const Xbyak::Operand &op) {
+        vcmpps(x1, x2, op, _cmp_nlt_us);
+    }
+
+    void uni_vtestps(const Xbyak::Xmm &x1, const Xbyak::Operand &op) {
+        ptest(x1, op);
+    }
+
+    void uni_vtestps(const Xbyak::Ymm &x1, const Xbyak::Operand &op) {
+        assert(!(x1.isZMM() || op.isZMM()));
+        vtestps(x1, op);
+    }
+
     void uni_vblendvps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
                        const Xbyak::Operand &op, const Xbyak::Xmm &msk) {
         assert(x1.getIdx() == x2.getIdx());
+        assert(msk.getIdx() == 0);
         blendvps(x1, op);
     }
     void uni_vblendvps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
@@ -629,6 +694,22 @@ public:
         vmovmskps(x1, x2);
     }
 
+    void uni_vpackssdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){
+        assert(x1.getIdx() == x1.getIdx());
+        packssdw(x1, op);
+    }
+    void uni_vpackssdw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){
+        vpackssdw(x1, x2, op);
+    }
+
+    void uni_vpackuswb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){
+        assert(x1.getIdx() == x1.getIdx());
+        packuswb(x1, op);
+    }
+    void uni_vpackuswb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){
+        vpackuswb(x1, x2, op);
+    }
+
     void uni_vpmovsxbd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
         pmovsxbd(x, op);
     }
@@ -643,14 +724,6 @@ public:
         vpmovzxbd(x, op);
     }
 
-    void uni_vpackssdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
-        assert(x1.getIdx() == x2.getIdx());
-        packssdw(x1, op);
-    }
-    void uni_vpackssdw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
-        vpackssdw(x1, x2, op);
-    }
-
     void uni_vpackusdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
         assert(x1.getIdx() == x2.getIdx());
         packusdw(x1, op);
@@ -667,14 +740,6 @@ public:
         vpacksswb(x1, x2, op);
     }
 
-    void uni_vpackuswb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
-        assert(x1.getIdx() == x2.getIdx());
-        packuswb(x1, op);
-    }
-    void uni_vpackuswb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
-        vpackuswb(x1, x2, op);
-    }
-
     void uni_vpmaxsd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
         assert(x1.getIdx() == x2.getIdx());
         pmaxsd(x1, op);
@@ -731,6 +796,45 @@ public:
         vpsubb(x1, x2, op);
     }
 
+    void uni_vpslldq(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::uint8 &op) {
+        assert(x1.getIdx() == x2.getIdx());
+        pslldq(x1, op);
+    }
+    void uni_vpslldq(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::uint8 &op) {
+        vpslldq(x1, x2, op);
+    }
+
+    void uni_vpand(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
+                   const Xbyak::Operand &op = Xbyak::Operand()) {
+        assert(x1.getIdx() == x2.getIdx());
+        pand(x1, op);
+    }
+    void uni_vpand(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
+                    const Xbyak::Operand &op = Xbyak::Operand()) {
+        vpand(x1, x2, op);
+    }
+
+    void uni_vpaddb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
+                    const Xbyak::Operand &op) {
+        assert(x1.getIdx() == x2.getIdx());
+        paddb(x2, op);
+    }
+    void uni_vpaddb(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2,
+                    const Xbyak::Operand &op) {
+        vpaddb(x1, x2, op);
+    }
+
+    void uni_vpshufb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
+                     const Xbyak::Operand &op) {
+        assert(x1.getIdx() == x2.getIdx());
+        pshufb(x1, op);
+    }
+
+    void uni_vpshufb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
+                     const Xbyak::Operand &op) {
+        vpshufb(x1, x2, op);
+    }
+
     void mul_by_const(const Xbyak::Reg &out,
             const Xbyak::Reg64 &tmp, int value) {
         // Generates a shift + add sequence for multiplicating contents of the
@@ -764,10 +868,6 @@ public:
         mov(out, tmp);
     }
 
-    void inject(jit_code_injection&& in) {
-        db(in.code, in.size);
-    }
-
     void dump_code(const Xbyak::uint8 *code) const {
         if (code) {
             static int counter = 0;