Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm_convolution_utils.hpp
index c2ebc45..1bcfcc3 100644 (file)
 #define CPU_JIT_GEMM_CONVOLUTION_UTILS_HPP
 
 #include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "mkldnn_thread.hpp"
+
 #include "cpu_convolution_pd.hpp"
 #include "cpu_engine.hpp"
 #include "jit_primitive_conf.hpp"
-#include "mkldnn_thread.hpp"
-#include "scratchpad.hpp"
 
 namespace mkldnn {
 namespace impl {
@@ -30,32 +31,32 @@ namespace cpu {
 
 namespace jit_gemm_convolution_utils {
 
-    void im2col_3d(jit_gemm_conv_conf_t &jcp, const float *im, float *col,
+void im2col_3d(const jit_gemm_conv_conf_t &jcp, const float *im, float *col,
         int od);
-    void im2col(jit_gemm_conv_conf_t &jcp, const float *im, float *col);
-    template <typename T>
-    void im2col_u8(jit_gemm_conv_conf_t &jcp, const T *im, uint8_t *col);
+void im2col(const jit_gemm_conv_conf_t &jcp, const float *__restrict im,
+       float *__restrict col, int hs, int hb, int ws, int wb);
+template <typename T>
+void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im,
+        uint8_t *__restrict col);
 
-    void col2im_s32(jit_gemm_conv_conf_t &jcp, const int32_t *col, int32_t *im);
-    void col2im_3d(jit_gemm_conv_conf_t &jcp, const float *col, float *im,
+void col2im_s32(const jit_gemm_conv_conf_t &jcp, const int32_t *__restrict col,
+        int32_t *__restrict im);
+void col2im_3d(const jit_gemm_conv_conf_t &jcp, const float *col, float *im,
         int od);
-    void col2im(jit_gemm_conv_conf_t &jcp, const float *col, float *im);
+void col2im(const jit_gemm_conv_conf_t &jcp, const float *col, float *im);
 
-    void init_conf(jit_gemm_conv_conf_t &jcp,
-        const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
-        const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
-        int max_threads, bool with_relu = false, float relu_negative_slope = -1.0);
+status_t init_conf(jit_gemm_conv_conf_t &jcp,
+        memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd,
+        const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
+        const memory_desc_wrapper &dst_d, int max_threads);
 
-    status_t prepare_scratchpad(jit_gemm_conv_conf_t &jcp,
-                scratchpad_t **col_scratchpad_, size_t size, const int nthr);
-
-    void bwd_weights_balance(int ithr, int nthr,
-        int ngroups, int mb, int &ithr_g, int &nthr_g, int &ithr_mb,
-            int &nthr_mb);
-    void bwd_weights_reduction_par(int ithr, int nthr,
+void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb,
+        int &ithr_g, int &nthr_g, int &ithr_mb, int &nthr_mb);
+void bwd_weights_reduction_par(int ithr, int nthr,
         const jit_gemm_conv_conf_t &jcp, const float *weights_reduce_ws,
-            float *weights);
-};
+        float *weights);
+
+}
 
 }
 }