Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_x8s8s32x_conv_kernel.hpp
index d243004..0e8e7ca 100644 (file)
 #define CPU_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP
 
 #include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+
 #include "cpu_memory.hpp"
 
 #include "jit_generator.hpp"
 #include "jit_primitive_conf.hpp"
+#include "jit_uni_eltwise.hpp"
+#include "jit_uni_depthwise.hpp"
 
 namespace mkldnn {
 namespace impl {
 namespace cpu {
 
-struct jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_conv_fwd_ker_t)
+template<typename Vmm>
+struct _jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_x8s8s32x_conv_fwd_ker_t)
 
     enum { STATE_FIRST_DST_LOAD = 0x1U };
 
-    jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp,
+    _jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp,
             const primitive_attr_t &attr) : jcp(ajcp), attr_(attr)
     {
         generate();
-        jit_ker = (void (*)(jit_conv_call_s *))getCode();
+        jit_ker_ = (void (*)(jit_conv_call_s *))getCode();
+    }
+
+    ~_jit_avx512_core_x8s8s32x_fwd_kernel() {
+        for (auto inj : eltwise_injectors)
+            delete inj;
+        eltwise_injectors.clear();
+
+        for (auto inj : depthwise_injectors)
+            delete inj;
+        depthwise_injectors.clear();
     }
-    static bool post_ops_ok(jit_conv_conf_t &jcp,
-            const primitive_attr_t &attr);
-    static status_t init_conf(jit_conv_conf_t &jcp,
-            const convolution_desc_t &cd,
-            cpu_memory_t::pd_t &src_pd,
-            cpu_memory_t::pd_t &weights_pd,
-            cpu_memory_t::pd_t &dst_pd,
-            cpu_memory_t::pd_t &bias_pd,
-            const primitive_attr_t &attr,
-            int nthreads,
-            bool with_relu = false,
-            float relu_negative_slope = 0.);
 
     jit_conv_conf_t jcp;
     const primitive_attr_t &attr_;
-    void (*jit_ker)(jit_conv_call_s *);
+    void (*jit_ker_)(jit_conv_call_s *);
 
 private:
-    using reg64_t = const Xbyak::Reg64;
-    using zmm_t = const Xbyak::Zmm;
-    using xmm_t = const Xbyak::Xmm;
+    nstl::vector<jit_uni_eltwise_injector_f32<avx512_common>*> eltwise_injectors;
+    nstl::vector<jit_uni_depthwise_injector_f32<avx512_common>*> depthwise_injectors;
+
     enum {
         typesize = sizeof(float),
         ker_reg_base_idx = 28,
+        ker_dw_reg_base_idx = 30,
     };
-    enum {
+    typedef enum {
         no_last_block,
         last_ic_block,
         last_sp_block,
-    };
-
-    reg64_t reg_inp = r8;
-    reg64_t reg_ker = r9;
-    reg64_t reg_out = r10;
-    reg64_t aux_reg_inp = r11;
-    reg64_t reg_ptr_sum_scale = r11;
-    reg64_t aux_reg_ker = r12;
-    reg64_t reg_owb = r12;
-
-    reg64_t reg_scratch = r14;
-    reg64_t reg_kj = rax;
-    reg64_t reg_overflow = rax;
-    reg64_t reg_ptr_scales = rax;
-    reg64_t reg_oi = rbx;
-    reg64_t reg_bias = rdx;
-    reg64_t reg_compensation = reg_scratch;
-    reg64_t reg_kh = abi_not_param1;
-    reg64_t param = abi_param1;
-    reg64_t reg_tmp = rbp;
-    reg64_t imm_addr64 = r15;
-    reg64_t reg_oc_blocks = rsi;
-    reg64_t reg_icb = reg_bias;
-    reg64_t reg_bias_alpha = reg_kh;
-
-    Xbyak::Opmask ktail_mask = Xbyak::Opmask(2);
-
-    zmm_t zmm_tmp = zmm_t(28);
-    zmm_t zmm_one = zmm_t(29);
-    zmm_t zmm_scales = zmm_t(30);
-    zmm_t zmm_shift = zmm_t(30);
-    zmm_t zmm_zero = zmm_t(31);
-    zmm_t zmm_wei = zmm_t(31);
-
-    zmm_t zmm_out(int i_ur, int i_oc) {
+    } ic_block_t;
+
+    /* data regs */
+    const Xbyak::Reg64 reg_ptr_scales = rax;
+    const Xbyak::Reg64 reg_inp = r8;
+    const Xbyak::Reg64 reg_ker = r9;
+    const Xbyak::Reg64 reg_out = r10;
+    const Xbyak::Reg64 aux_reg_inp = r11;
+    const Xbyak::Reg64 reg_ptr_sum_scale = r11;
+    const Xbyak::Reg64 aux_reg_ker = r12;
+    const Xbyak::Reg64 reg_compensation = r14;
+    /* counter regs */
+    const Xbyak::Reg64 reg_bias_alpha = abi_not_param1;
+    const Xbyak::Reg64 reg_oi = rbx;
+    const Xbyak::Reg64 reg_bias = rdx;
+    const Xbyak::Reg64 reg_oc_blocks = rsi;
+    const Xbyak::Reg64 reg_owb = aux_reg_ker;
+    const Xbyak::Reg64 reg_scratch = reg_compensation;
+    const Xbyak::Reg64 reg_kj = reg_ptr_scales;
+    const Xbyak::Reg64 reg_overflow = reg_ptr_scales;
+    const Xbyak::Reg64 reg_icb = reg_bias;
+
+    const Xbyak::Reg64 reg_d_weights = r15;
+    const Xbyak::Reg64 reg_d_bias = r13;
+
+    const Xbyak::Opmask ktail_mask = Xbyak::Opmask(2);
+    const Xbyak::Opmask kblend_mask = Xbyak::Opmask(3);
+
+    const Vmm vmm_wei = Vmm(31);
+    /* used during bias section of store_output */
+    const Vmm vmm_comp = Vmm(30); // only for signed input
+    const Vmm vmm_bias = Vmm(31);
+    /* used during post_op sum section of store_output */
+    const Vmm vmm_prev_dst = Vmm(31);
+    /* used during write-out section of store_output */
+    const Vmm vmm_zero = Vmm(31);
+
+    /* used in compute_ker (but set during prepare_output) */
+    const Vmm vmm_shift = vmm_comp; // only for signed input
+    /* used in compute_ker (but only for pre-VNNI machines) */
+    const Vmm vmm_tmp = Vmm(28); // not used for depthwise
+    const Vmm vmm_one = Vmm(29); // set at start of kernel, not used for depthwise.
+
+    /* registers use only for depthwise
+       groups are always blocked by 16(padded if needed),
+       hence use only Zmm registers */
+    const Xbyak::Zmm zmm_wei = Xbyak::Zmm(31);
+    Xbyak::Zmm zmm_src;
+    Xbyak::Zmm zmm_permute;
+    Xbyak::Zmm zmm_zero_blend; // used only for fast depthwise
+
+    Vmm vmm_out(int i_ur, int i_oc) {
         int idx = i_ur + i_oc * jcp.ur_w;
-        assert(idx < ker_reg_base_idx);
-        return zmm_t(idx);
+        assert(idx < (jcp.is_depthwise
+                    ? ker_dw_reg_base_idx : ker_reg_base_idx));
+        return Vmm(idx);
     }
-    xmm_t xmm_out(int i_ur, int i_oc) {
+    Xbyak::Zmm zmm_out(int i_ur, int i_oc) {
         int idx = i_ur + i_oc * jcp.ur_w;
-        assert(idx < ker_reg_base_idx);
-        return xmm_t(idx);
+        assert(idx < (jcp.is_depthwise
+                    ? ker_dw_reg_base_idx : ker_reg_base_idx));
+        return Xbyak::Zmm(idx);
     }
-    zmm_t zmm_inp(int i_ic, int nb_x_blocking) {
+    Vmm vmm_inp(int i_ic, int nb_x_blocking) {
         int idx = i_ic + nb_x_blocking * jcp.ur_w;
         assert(idx < 31);
-        return zmm_t(idx);
+        return Vmm(idx);
     }
-    zmm_t zmm_bias_alpha() {
-        return zmm_t(jcp.nb_oc_blocking * jcp.ur_w);
+    Vmm vmm_bias_alpha() {
+        int nb_c_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
+        return Vmm(nb_c_block * jcp.ur_w);
     }
-    xmm_t xmm_bias_alpha() {
-        return xmm_t(jcp.nb_oc_blocking * jcp.ur_w);
+    Xbyak::Xmm xmm_bias_alpha() {
+        int nb_c_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
+        return Xbyak::Xmm(nb_c_block * jcp.ur_w);
     }
     int get_ow_start(int ki, int pad_l) {
         return nstl::max(0,
@@ -132,17 +157,79 @@ private:
                                                            * (jcp.dilate_w + 1),
                                            jcp.stride_w));
     }
-    bool maybe_relu(int position);
+
     void prepare_output(int ur_w);
-    void store_output(int ur_w, int last_oc_block_flag);
-    void compute_ker(int ur_w, int pad_l, int pad_r, int last_ic_block_flag,
-                                                        bool h_padded = false);
-    void kh_loop(int ur_w, int pad_l, int pad_r, int last_ic_block_flag);
+    void store_output(int ur_w, bool last_oc_block_flag);
+    void compute_ker_dw(
+            int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded);
+    void compute_ker(int ur_w, int pad_l, int pad_r,
+            ic_block_t last_ic_block_flag, bool h_padded = false);
+    void kh_loop(int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag);
     void icb_loop(
             int ur_w, int pad_l, int pad_r, bool is_last_spatial_block);
     void generate();
-    void cvt2ps(data_type_t type_in, zmm_t zmm_in, const Xbyak::Operand &op,
+    void cvt2ps(data_type_t type_in, Vmm ymm_in, const Xbyak::Operand &op,
         bool mask_flag);
+    const Vmm vmm_mask(const Vmm vmm_in, bool mask_flag, bool store = false);
+};
+
+struct jit_avx512_core_x8s8s32x_fwd_kernel {
+
+    jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp,
+            const primitive_attr_t &attr) :
+        jit_ker(nullptr),
+        zmm_kernel_(nullptr),
+        ymm_kernel_(nullptr),
+        xmm_kernel_(nullptr) {
+            int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.ic_block;
+            switch (ch_block) {
+                case 16:
+                    zmm_kernel_ =
+                        new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Zmm>(
+                                ajcp, attr);
+                    jit_ker = zmm_kernel_->jit_ker_;
+                    return;
+                case 8:
+                    ymm_kernel_ =
+                        new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Ymm>(
+                                ajcp, attr);
+                    jit_ker = ymm_kernel_->jit_ker_;
+                    return;
+                case 4:
+                    xmm_kernel_ =
+                        new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Xmm>(
+                                ajcp, attr);
+                    jit_ker = xmm_kernel_->jit_ker_;
+                    return;
+                default:
+                    assert(!"invalid channel blocking");
+            }
+    }
+
+    ~jit_avx512_core_x8s8s32x_fwd_kernel() {
+        delete xmm_kernel_;
+        delete ymm_kernel_;
+        delete zmm_kernel_;
+    }
+
+    static bool post_ops_ok(jit_conv_conf_t &jcp,
+            const primitive_attr_t &attr);
+
+    static status_t init_conf(jit_conv_conf_t &jcp,
+            const convolution_desc_t &cd,
+            cpu_memory_t::pd_t &src_pd,
+            cpu_memory_t::pd_t &weights_pd,
+            cpu_memory_t::pd_t &dst_pd,
+            cpu_memory_t::pd_t &bias_pd,
+            const primitive_attr_t &attr,
+            int nthreads);
+    static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+            const jit_conv_conf_t &jcp, const primitive_attr_t &attr);
+
+    void (*jit_ker)(jit_conv_call_s *);
+    _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Zmm> *zmm_kernel_;
+    _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Ymm> *ymm_kernel_;
+    _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Xmm> *xmm_kernel_;
 };
 
 }