Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp
index 9765de9..4e3ff51 100644 (file)
 #define JIT_AVX512_CORE_X8S8S32X_1X1_CONV_KERNEL_HPP
 
 #include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+
 #include "jit_generator.hpp"
 #include "jit_primitive_conf.hpp"
+#include "jit_uni_eltwise.hpp"
+#include "jit_uni_depthwise.hpp"
 
 namespace mkldnn {
 namespace impl {
@@ -34,38 +38,39 @@ struct jit_avx512_core_x8s8s32x_1x1_conv_kernel: public jit_generator {
         jit_ker = (void (*)(jit_1x1_conv_call_s *)) this->getCode();
     }
 
+    ~jit_avx512_core_x8s8s32x_1x1_conv_kernel() { // releases every injector referenced by the two pointer vectors
+        for (auto inj : eltwise_injectors) // elements are raw pointers; where they are allocated is not visible in this diff
+            delete inj;
+        eltwise_injectors.clear(); // drop the now-dangling pointers
+
+        for (auto inj : depthwise_injectors) // same delete-then-clear sequence for the depthwise injectors
+            delete inj;
+        depthwise_injectors.clear();
+    }
+
     static bool post_ops_ok(jit_1x1_conv_conf_t &jcp,
                                 const primitive_attr_t &attr);
 
     static status_t init_conf(jit_1x1_conv_conf_t &jcp,
-                                const convolution_desc_t &cd,
-                                const memory_desc_wrapper &src_d,
-                                const memory_desc_wrapper &weights_d,
-                                const memory_desc_wrapper &dst_d,
-                                const memory_desc_wrapper &bias_d,
-                                const primitive_attr_t &attr,
-                                bool with_relu, float relu_negative_slope,
-                                int nthreads, bool reduce_src);
+            const convolution_desc_t &cd,
+            const memory_desc_wrapper &src_d,
+            const memory_desc_wrapper &weights_d,
+            const memory_desc_wrapper &dst_d,
+            const memory_desc_wrapper &bias_d,
+            const primitive_attr_t &attr,
+            int nthreads, bool reduce_src);
 
-    static status_t init_conf(jit_1x1_conv_conf_t &jcp,
-                              const convolution_desc_t &cd,
-                              const memory_desc_wrapper &src_d,
-                              const memory_desc_wrapper &weights_d,
-                              const memory_desc_wrapper &dst_d,
-                              const memory_desc_wrapper &bias_d,
-                              const primitive_attr_t &attr,
-                              int nthreads, bool reduce_src)
-    {
-        return init_conf(jcp, cd, src_d, weights_d, dst_d, bias_d, attr, false,
-            0.0, nthreads, reduce_src);
-    }
-    bool maybe_relu(int position);
+    static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+            const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr);
 
     jit_1x1_conv_conf_t jcp;
     const primitive_attr_t &attr_;
     void (*jit_ker)(jit_1x1_conv_call_s *);
 
   private:
+    nstl::vector<jit_uni_eltwise_injector_f32<avx512_common>*> eltwise_injectors; // raw pointers; each element is deleted in the destructor
+    nstl::vector<jit_uni_depthwise_injector_f32<avx512_common>*> depthwise_injectors; // raw pointers; each element is deleted in the destructor
+
     using reg64_t = const Xbyak::Reg64;
     using zmm_t = const Xbyak::Zmm;
     using mask_t = const Xbyak::Opmask;
@@ -90,6 +95,10 @@ struct jit_avx512_core_x8s8s32x_1x1_conv_kernel: public jit_generator {
     reg64_t aux_reg_output_data = abi_not_param1;
     reg64_t reduce_loop_iter = abi_param1;
 
+    const Xbyak::Reg64 reg_d_weights = aux_reg_bcast_data; // second name for the register behind aux_reg_bcast_data; NOTE(review): presumably only used while that value is not live - confirm in generate()
+    const Xbyak::Reg64 reg_d_bias = reduce_loop_iter; // second name for reduce_loop_iter, which is abi_param1
+    const Xbyak::Reg64 reg_oc_off = aux_reg_load_data; // second name for the register behind aux_reg_load_data
+
     reg64_t reg_last_load = r8;
     mask_t ktail_mask = k6;
 
@@ -109,18 +118,17 @@ struct jit_avx512_core_x8s8s32x_1x1_conv_kernel: public jit_generator {
     int reg_bcast_data_off = 16;
     int reg_load_data_off = 24;
     int reg_ptr_sum_scale_off = 32;
-    int reg_last_load_off = 40;
-    int reg_comp_data_off = 48;
-    int stack_space_needed = 56;
+    int reg_comp_data_off = 40;
+    int stack_space_needed = 48;
 
     void bcast_loop(int load_loop_blk);
     void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound);
 
     void generate();
-    static void balance(jit_1x1_conv_conf_t &jcp, int nthreads);
     void cvt2ps(data_type_t type_in, zmm_t zmm_in, const Xbyak::Operand &op,
         bool mask_flag);
 };
+
 }
 }
 }