Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_common_conv_kernel.hpp
index ec6e185..4641292 100644 (file)
@@ -18,8 +18,9 @@
 #define JIT_AVX512_COMMON_CONV_KERNEL_F32_HPP
 
 #include "c_types_map.hpp"
-#include "cpu_memory.hpp"
+#include "memory_tracking.hpp"
 
+#include "cpu_memory.hpp"
 #include "jit_generator.hpp"
 #include "jit_primitive_conf.hpp"
 #include "jit_uni_eltwise.hpp"
@@ -29,16 +30,18 @@ namespace mkldnn {
 namespace impl {
 namespace cpu {
 
-struct jit_avx512_common_conv_fwd_kernel : public jit_generator {
+template<typename Vmm>
+struct _jit_avx512_common_conv_fwd_kernel : public jit_generator {
 
-    jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp,
-            const primitive_attr_t &attr) : jcp(ajcp), attr_(attr)
+    _jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp,
+            const primitive_attr_t &attr)
+        : jcp(ajcp), attr_(attr)
     {
         generate();
-        jit_ker = (void (*)(jit_conv_call_s *))getCode();
+        jit_ker_ = (void (*)(jit_conv_call_s *))getCode();
     }
 
-    ~jit_avx512_common_conv_fwd_kernel() {
+    ~_jit_avx512_common_conv_fwd_kernel() {
         for (auto inj : eltwise_injectors)
             delete inj;
         eltwise_injectors.clear();
@@ -48,24 +51,11 @@ struct jit_avx512_common_conv_fwd_kernel : public jit_generator {
         depthwise_injectors.clear();
     }
 
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_fwd_kernel)
-
-    static bool post_ops_ok(jit_conv_conf_t &jcp,
-            const primitive_attr_t &attr);
-    static status_t init_conf(jit_conv_conf_t &jcp,
-            const convolution_desc_t &cd,
-            cpu_memory_t::pd_t &src_pd,
-            cpu_memory_t::pd_t &weights_pd,
-            cpu_memory_t::pd_t &dst_pd,
-            cpu_memory_t::pd_t &bias_pd,
-            const primitive_attr_t &attr,
-            int nthreads,
-            bool with_relu,
-            float relu_negative_slope);
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_fwd_kernel)
 
     jit_conv_conf_t jcp;
     const primitive_attr_t &attr_;
-    void (*jit_ker)(jit_conv_call_s *);
+    void (*jit_ker_)(jit_conv_call_s *);
 
 private:
     using reg64_t = const Xbyak::Reg64;
@@ -121,25 +111,25 @@ private:
     reg64_t reg_long_offt = r11;
     reg64_t reg_out_long_offt = r14;
 
-    inline Xbyak::Zmm zmm_ker(int i_ic) {
+    inline Vmm vmm_ker(int i_ic) {
         assert(i_ic < 4);
-        return Xbyak::Zmm(ker_reg_base_idx + i_ic);
+        return Vmm(ker_reg_base_idx + i_ic);
     }
 
-    inline Xbyak::Zmm zmm_out(int i_ur, int i_oc) {
+    inline Vmm vmm_out(int i_ur, int i_oc) {
         int idx = i_ur + i_oc * jcp.ur_w;
         assert(idx < ker_reg_base_idx);
-        return Xbyak::Zmm(idx);
+        return Vmm(idx);
     }
 
-    inline Xbyak::Zmm zmm_inp(int i_ic, int nb_x_blocking) {
+    inline Vmm vmm_inp(int i_ic, int nb_x_blocking) {
         int idx = i_ic + nb_x_blocking * jcp.ur_w;
         assert(idx < 31);
-        return Xbyak::Zmm(idx);
+        return Vmm(idx);
     }
 
     Xbyak::Reg64 imm_addr64 = r15;
-    Xbyak::Zmm zmm_wei = Xbyak::Zmm(31);
+    Vmm vmm_wei = Vmm(31);
 
     reg64_t reg_d_weights = imm_addr64;
     reg64_t reg_d_bias = reg_kj;
@@ -158,35 +148,11 @@ private:
 
     void generate();
 
-    inline void vpXdpwssd(Xbyak::Zmm zmm1, Xbyak::Zmm zmm2,
-        const Xbyak::Address& op) {
-        if (jcp.ver == ver_4vnni)
-            vp4dpwssd(zmm1, zmm2, op);
-        else
-            vpdpwssd(zmm1, zmm2, op);
-    }
-
-    inline void vadd(Xbyak::Zmm zmm, const Xbyak::Operand& op) {
+    inline void vadd(Vmm vmm, const Xbyak::Operand& op) {
         if (jcp.ver == ver_4vnni || jcp.ver == ver_vnni)
-            vpaddd(zmm, zmm, op);
+            vpaddd(vmm, vmm, op);
         else
-            vaddps(zmm, zmm, op);
-    }
-
-    inline void vcmp(Xbyak::Opmask kmask,
-        Xbyak::Zmm zmm_src1, Xbyak::Zmm zmm_src2, const unsigned char cmp) {
-        if (jcp.ver == ver_4vnni || jcp.ver == ver_vnni)
-            vpcmpd(kmask, zmm_src1, zmm_src2, cmp);
-        else
-            vcmpps(kmask, zmm_src1, zmm_src2, cmp);
-    }
-
-    inline void vmul(Xbyak::Zmm zmm_dst, Xbyak::Opmask kmask,
-                     Xbyak::Zmm zmm_src1, Xbyak::Zmm zmm_src2) {
-        if (jcp.ver == ver_4vnni || jcp.ver == ver_vnni)
-            vpmulld(zmm_dst | kmask, zmm_src1, zmm_src2);
-        else
-            vmulps(zmm_dst | kmask, zmm_src1, zmm_src2);
+            vaddps(vmm, vmm, op);
     }
 
     inline size_t get_output_offset(int oi, int n_oc_block) {
@@ -224,6 +190,59 @@ private:
     }
 };
 
+struct jit_avx512_common_conv_fwd_kernel {
+
+    jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp,
+        const primitive_attr_t &attr) :
+        jit_ker(nullptr),
+        zmm_kernel_(nullptr),
+        xmm_kernel_(nullptr) {
+        int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.oc_block;
+        switch (ch_block) {
+        case 16:
+            zmm_kernel_ =
+                new _jit_avx512_common_conv_fwd_kernel<Xbyak::Zmm>(
+                    ajcp, attr);
+            jit_ker = zmm_kernel_->jit_ker_;
+            return;
+        case 4:
+            xmm_kernel_ =
+                new _jit_avx512_common_conv_fwd_kernel<Xbyak::Xmm>(
+                    ajcp, attr);
+            jit_ker = xmm_kernel_->jit_ker_;
+            return;
+        default:
+            assert(!"invalid channel blocking");
+        }
+    }
+
+    ~jit_avx512_common_conv_fwd_kernel() {
+        delete xmm_kernel_;
+        delete zmm_kernel_;
+    }
+
+    enum {
+        typesize = sizeof(float)
+    };
+
+    static bool post_ops_ok(jit_conv_conf_t &jcp,
+        const primitive_attr_t &attr);
+    static status_t init_conf(jit_conv_conf_t &jcp,
+        const convolution_desc_t &cd,
+        cpu_memory_t::pd_t &src_pd,
+        cpu_memory_t::pd_t &weights_pd,
+        cpu_memory_t::pd_t &dst_pd,
+        cpu_memory_t::pd_t &bias_pd,
+        const primitive_attr_t &attr,
+        int nthreads);
+    static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+        const jit_conv_conf_t &jcp);
+
+    void(*jit_ker)(jit_conv_call_s *);
+    _jit_avx512_common_conv_fwd_kernel<Xbyak::Zmm> *zmm_kernel_;
+    _jit_avx512_common_conv_fwd_kernel<Xbyak::Xmm> *xmm_kernel_;
+};
+
 struct jit_avx512_common_conv_bwd_data_kernel_f32: public jit_generator {
 
     jit_avx512_common_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp)
@@ -239,6 +258,8 @@ struct jit_avx512_common_conv_bwd_data_kernel_f32: public jit_generator {
             const memory_desc_wrapper &diff_src_d,
             const memory_desc_wrapper &weights_d,
             const memory_desc_wrapper &diff_dst_d);
+    static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+            const jit_conv_conf_t &jcp);
 
     jit_conv_conf_t jcp;
     void (*jit_ker)(jit_conv_call_s *);
@@ -358,6 +379,8 @@ struct jit_avx512_common_conv_bwd_weights_kernel_f32 : public jit_generator {
             const convolution_desc_t &cd, cpu_memory_t::pd_t &src_pd,
             cpu_memory_t::pd_t &diff_weights_pd,
             cpu_memory_t::pd_t &diff_bias_pd, cpu_memory_t::pd_t &diff_dst_pd);
+    static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+            const jit_conv_conf_t &jcp);
 
     jit_conv_conf_t jcp;
     void (*jit_ker)(jit_conv_call_s *);
@@ -423,6 +446,9 @@ private:
     inline void compute_loop();
 
     void generate();
+
+    static void balance(const jit_conv_conf_t &j, int &nthr, int &nthr_mb,
+            int &nthr_g, int &nthr_oc_b, int &nthr_ic_b);
 };
 
 }