Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_common_1x1_conv_kernel.hpp
index 31d5b62..af7ca95 100644 (file)
@@ -18,6 +18,8 @@
 #define JIT_AVX512_COMMON_1x1_CONV_KERNEL_HPP
 
 #include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+
 #include "jit_generator.hpp"
 #include "jit_primitive_conf.hpp"
 #include "jit_uni_eltwise.hpp"
@@ -29,7 +31,8 @@ namespace cpu {
 
 struct jit_avx512_common_1x1_conv_kernel : public jit_generator {
     jit_avx512_common_1x1_conv_kernel(jit_1x1_conv_conf_t ajcp,
-            const primitive_attr_t &attr) : jcp(ajcp), attr_(attr)
+            const primitive_attr_t &attr)
+        : jcp(ajcp), attr_(attr)
     {
         this->generate();
         jit_ker = (void (*)(jit_1x1_conv_call_s *)) this->getCode();
@@ -51,25 +54,15 @@ struct jit_avx512_common_1x1_conv_kernel : public jit_generator {
                                 const primitive_attr_t &attr);
 
     static status_t init_conf(jit_1x1_conv_conf_t &jcp,
-                                const convolution_desc_t &cd,
-                                const memory_desc_wrapper &src_d,
-                                const memory_desc_wrapper &weights_d,
-                                const memory_desc_wrapper &dst_d,
-                                const primitive_attr_t &attr,
-                                bool with_relu, float relu_negative_slope,
-                                int nthreads, bool reduce_src);
+            const convolution_desc_t &cd,
+            const memory_desc_wrapper &src_d,
+            const memory_desc_wrapper &weights_d,
+            const memory_desc_wrapper &dst_d,
+            const primitive_attr_t &attr,
+            int nthreads, bool reduce_src);
 
-    static status_t init_conf(jit_1x1_conv_conf_t &jcp,
-                              const convolution_desc_t &cd,
-                              const memory_desc_wrapper &src_d,
-                              const memory_desc_wrapper &weights_d,
-                              const memory_desc_wrapper &dst_d,
-                              const primitive_attr_t &attr,
-                              int nthreads, bool reduce_src)
-    {
-        return init_conf(jcp, cd, src_d, weights_d, dst_d, attr, false, 0.0,
-        nthreads, reduce_src);
-    }
+    static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+            const jit_1x1_conv_conf_t &jcp);
 
     jit_1x1_conv_conf_t jcp;
     const primitive_attr_t &attr_;
@@ -78,7 +71,6 @@ struct jit_avx512_common_1x1_conv_kernel : public jit_generator {
   private:
     using reg64_t = const Xbyak::Reg64;
     using zmm_t = const Xbyak::Zmm;
-    using mask_t = const Xbyak::Opmask;
 
     reg64_t reg_bcast_data = r8;
     reg64_t reg_load_data = r10;
@@ -95,6 +87,7 @@ struct jit_avx512_common_1x1_conv_kernel : public jit_generator {
     reg64_t reg_reduce_pos_flag = rax;
     reg64_t reg_output_stride = r13;
     reg64_t reg_bias_data = r12;
+    reg64_t reg_relu_ns = r13;
     reg64_t reg_bcast_loop_work = aux1_reg_bcast_data;
 
     Xbyak::Zmm vreg_bcast = Xbyak::Zmm(31);
@@ -115,6 +108,7 @@ struct jit_avx512_common_1x1_conv_kernel : public jit_generator {
     void generate();
     static void balance(jit_1x1_conv_conf_t &jcp, int nthreads);
 };
+
 }
 }
 }