Publishing 2019 R1 content
platform/upstream/dldt.git: inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp
index 4b9fbd6..60e2a69 100644
@@ -25,7 +25,6 @@
 #include "type_helpers.hpp"
 #include "utils.hpp"
 
-#include "jit_avx512_common_convolution_winograd.hpp"
 #include "jit_avx512_core_fp32_wino_conv_4x3.hpp"
 
 #ifndef _MSC_VER
@@ -41,12 +40,13 @@ namespace cpu {
 
 using namespace mkldnn::impl::status;
 using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
 using namespace mkldnn::impl::utils;
 
 template <bool is_fwd>
 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
 ::weight_transform_data(const jit_conv_winograd_conf_t &jcp,
-        float *wp, float *twp)
+        float *wp, float *twp) const
 {
     float G[] = {0.26890756302521f, 0.688403361344538f, 0.119514472455649f,
                  1.13777777777778f, 0.430252100840336f, 0.179271708683473f};
@@ -70,7 +70,7 @@ void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
 template<bool is_fwd>
 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::output_transform_data
 (int image, const jit_conv_winograd_conf_t &jcp,
-    const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias) {
+    const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias) const {
 
     float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f};
     float Ow[alpha][alpha][simd_w];
@@ -121,7 +121,7 @@ template<bool is_fwd>
 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
 ::output_transform_tileblock_data(int tile_block,
     const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops,
-    float *toutp, float *outp, float *bias) {
+    float *toutp, float *outp, float *bias) const {
 
     float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f};
     float Ow[alpha][alpha][simd_w];
@@ -171,7 +171,7 @@ void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
 template<bool is_fwd>
 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
     ::input_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
-        float *inp, float *tinp)
+        float *inp, float *tinp) const
 {
     float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
                  0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
@@ -224,7 +224,7 @@ template <bool is_fwd>
 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
     ::input_transform_tileblock_data(int tile_block,
         const jit_conv_winograd_conf_t &jcp,
-        float *inp, float *tinp)
+        float *inp, float *tinp) const
 {
     float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
                0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
@@ -280,7 +280,8 @@ void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
 
 template <bool is_fwd>
 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_S_G_D(
-        const int MB, float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr) {
+        const int MB, float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
+        const memory_tracking::grantor_t &scratchpad) const {
     const auto &jcp = kernel_->jcp;
     const auto &p_ops = attr_->post_ops_;
 
@@ -306,10 +307,9 @@ void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_S_G_D(
     array_offset_calculator<float, 2> bias(bias_ptr,
             jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block);
 
-    array_offset_calculator<float, 8> M(
-            (float *)((is_fwd
-                    ? (this->scratchpad_)->M_ptr()
-                    : (this->scratchpad_)->V_ptr())),
+    array_offset_calculator<float, 8> M(is_fwd
+            ? scratchpad.template get<float>(key_wino_M)
+            : scratchpad.template get<float>(key_wino_V),
             jcp.dimN_nb_block, jcp.dimM_nb_block,
             alpha, alpha,
             jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block,
@@ -317,7 +317,7 @@ void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_S_G_D(
 
     auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference)
             ? wei_ptr
-            : (float *)(this->scratchpad_)->U_ptr();
+            : scratchpad.template get<float>(key_wino_U);
 
     array_offset_calculator<float, 8> U(wino_wei,
             jcp.dimM_nb_block,
@@ -325,23 +325,22 @@ void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_S_G_D(
             jcp.dimK_nb_block,
             jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block,
             jcp.dimK_reg_block, jcp.dimM_simd_block);
-    array_offset_calculator<float, 8> V(
-            (float *)((is_fwd
-                    ? (this->scratchpad_)->V_ptr()
-                    : (this->scratchpad_)->M_ptr())),
+    array_offset_calculator<float, 8> V(is_fwd
+            ? scratchpad.template get<float>(key_wino_V)
+            : scratchpad.template get<float>(key_wino_M),
             jcp.dimN_nb_block, alpha, alpha,
             jcp.dimN_block, jcp.dimK_nb_block,
             jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block);
 
-    const bool want_padded_bias = jcp.with_bias
+    const bool wants_padded_bias = jcp.with_bias
         && jcp.oc_without_padding != jcp.oc;
     float last_slice_bias[simd_w] = {0};
-    if (want_padded_bias) {
+    if (wants_padded_bias) {
         for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
             last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
     }
 
-#pragma omp parallel
+PRAGMA_OMP(parallel)
     {
         parallel_nd_in_omp(MB, jcp.dimK_nb_block, jcp.dimK_block,
                 [&](int img, int K_blk1, int K_blk2) {
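Note on the threading change in the hunk above (and in the hunks that follow): raw "#pragma omp ..." directives are replaced by the PRAGMA_OMP(...) wrapper from mkl-dnn's threading layer, so the OpenMP directives compile away cleanly when the library is built against a different threading runtime. Below is a minimal, self-contained sketch of such a _Pragma-based wrapper; WRAP_PRAGMA and SKETCH_OMP are names invented for this sketch, not mkl-dnn's actual definitions.

    // Illustrative sketch only: a _Pragma-based wrapper in the spirit of
    // PRAGMA_OMP. The real definitions live in mkl-dnn's threading header;
    // WRAP_PRAGMA and SKETCH_OMP below are assumptions of this sketch.
    #include <cstdio>

    #if defined(_OPENMP)
    #define WRAP_PRAGMA(x) _Pragma(#x)                    // stringize, then emit the pragma
    #define SKETCH_OMP(...) WRAP_PRAGMA(omp __VA_ARGS__)  // SKETCH_OMP(parallel for) -> #pragma omp parallel for
    #else
    #define SKETCH_OMP(...)                               // compiles to nothing without OpenMP
    #endif

    int main() {
        int sum = 0;
    SKETCH_OMP(parallel for reduction(+ : sum))
        for (int i = 0; i < 1000; ++i)
            sum += i;
        std::printf("%d\n", sum); // prints 499500 with or without OpenMP
        return 0;
    }

The macro form keeps the directive at the call site (as the patched code does, at column 0) while letting non-OpenMP builds of the library ignore it entirely.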
@@ -367,7 +366,7 @@ void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_S_G_D(
             });
         }
 
-#pragma omp barrier
+PRAGMA_OMP(barrier)
 
         parallel_nd_in_omp(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block,
             [&](int N_blk1, int oj, int oi, int M_blk1) {
@@ -383,14 +382,14 @@ void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_S_G_D(
                             N_blk2, K_blk1, 0, 0, 0)), K_blk1);
         });
 
-#pragma omp barrier
+PRAGMA_OMP(barrier)
 
         parallel_nd_in_omp(MB, jcp.dimM_nb_block, (jcp.dimM_block * jcp.dimM_reg_block),
                     [&](int img, int M_blk1, int M_blk2) {
             const int M_blk =
                 M_blk1 * jcp.dimM_block  * jcp.dimM_reg_block + M_blk2;
 
-            float *bias_ptr = want_padded_bias
+            float *bias_ptr = wants_padded_bias
                 && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
                 ? last_slice_bias : &bias(M_blk, 0);
             output_transform_data(img, jcp, p_ops,
@@ -400,16 +399,11 @@ void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_S_G_D(
     }
 }
 
-template void
-_jit_avx512_core_fp32_wino_conv_4x3_t<true>::_execute_data_W_S_G_D(
-        const int, float *, float *, float *, float *);
-template void
-_jit_avx512_core_fp32_wino_conv_4x3_t<false>::_execute_data_W_S_G_D(
-        const int, float *, float *, float *, float *);
-
 template <bool is_fwd>
-void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_SGD(
-        const int MB, float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr) {
+void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_SGD(const int MB,
+        float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
+        const memory_tracking::grantor_t &scratchpad) const {
+
     const auto &jcp = kernel_->jcp;
     const auto &p_ops = attr_->post_ops_;
 
@@ -430,7 +424,7 @@ void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_SGD(
 
     auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference)
                 ? wei_ptr
-                : (float *)(this->scratchpad_)->U_ptr();
+                : scratchpad.template get<float>(key_wino_U);
 
     array_offset_calculator<float, 8> U(wino_wei,
             jcp.dimM_nb_block,
@@ -439,25 +433,23 @@ void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_SGD(
             jcp.dimM_block  * jcp.dimM_reg_block, jcp.dimK_block,
             jcp.dimK_reg_block, jcp.dimM_simd_block);
 
-    array_offset_calculator<float, 8> M(
-            (float *)((is_fwd
-                    ? (this->scratchpad_)->M_ptr()
-                    : (this->scratchpad_)->V_ptr())),
+    array_offset_calculator<float, 8> M(is_fwd
+            ? scratchpad.template get<float>(key_wino_M)
+            : scratchpad.template get<float>(key_wino_V),
             0, jcp.dimM_nb_block, alpha, alpha,
             jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block,
             jcp.dimN_reg_block, jcp.dimM_simd_block);
-    array_offset_calculator<float, 8> V(
-            (float *)((is_fwd
-                    ? (this->scratchpad_)->V_ptr()
-                    : (this->scratchpad_)->M_ptr())),
+    array_offset_calculator<float, 8> V(is_fwd
+            ? scratchpad.template get<float>(key_wino_V)
+            : scratchpad.template get<float>(key_wino_M),
             0, alpha, alpha, jcp.dimN_block,
             jcp.dimK_nb_block, jcp.dimK_block,
             jcp.dimN_reg_block, jcp.dimK_reg_block);
 
-    const bool want_padded_bias = jcp.with_bias
+    const bool wants_padded_bias = jcp.with_bias
         && jcp.oc_without_padding != jcp.oc;
     float last_slice_bias[simd_w] = {0};
-    if (want_padded_bias) {
+    if (wants_padded_bias) {
         for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
             last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
     }
@@ -478,12 +470,12 @@ void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_SGD(
         });
     }
 
-#pragma omp parallel
+PRAGMA_OMP(parallel)
     {
 
     int ithr = mkldnn_get_thread_num();
 
-#pragma omp for schedule(static)
+PRAGMA_OMP(for schedule(static))
     for (int tile_block = 0; tile_block < jcp.tile_block; tile_block++) {
         for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++) {
             for (int K_blk2 = 0; K_blk2 < jcp.dimK_block; K_blk2++) {
@@ -516,7 +508,7 @@ void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_SGD(
                 const int M_blk =
                     M_blk1 * jcp.dimM_block  * jcp.dimM_reg_block + M_blk2;
 
-                float *bias_ptr = want_padded_bias
+                float *bias_ptr = wants_padded_bias
                     && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
                     ? last_slice_bias : &bias(M_blk, 0);
 
@@ -529,12 +521,8 @@ void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_SGD(
     }
 }
 
-template void
-_jit_avx512_core_fp32_wino_conv_4x3_t<true>::_execute_data_W_SGD(
-        const int, float *, float *, float *, float *);
-template void
-_jit_avx512_core_fp32_wino_conv_4x3_t<false>::_execute_data_W_SGD(
-        const int, float *, float *, float *, float *);
+template struct _jit_avx512_core_fp32_wino_conv_4x3_t<true>;
+template struct _jit_avx512_core_fp32_wino_conv_4x3_t<false>;
 
 namespace {
 
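A few lines above, the per-method explicit instantiations of _execute_data_W_S_G_D and _execute_data_W_SGD are dropped in favor of explicitly instantiating the whole class template once. A tiny illustration of why that is sufficient (conv_t here is a hypothetical stand-in, not the real class):

    // Illustrative stand-in: explicitly instantiating the class template emits
    // definitions for all of its member functions in this translation unit, so
    // per-method "template void ...::f(...)" instantiations are not needed.
    template <bool is_fwd> struct conv_t {
        void execute() const {}
        void execute_blocked() const {}
    };
    template struct conv_t<true>;   // instantiates every member of conv_t<true>
    template struct conv_t<false>;  // likewise for the is_fwd == false specialization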
@@ -545,7 +533,7 @@ void subarray_sum(size_t num_arrs, float *output, size_t nelems,
     const size_t blocks_number = nelems / block_size;
     const size_t tail = nelems % block_size;
 
-#pragma omp parallel
+PRAGMA_OMP(parallel)
     {
         const int ithr = mkldnn_get_thread_num();
         const int nthr = mkldnn_get_num_threads();
@@ -627,7 +615,7 @@ void array_sum(size_t num_arrs, float *output,
     const size_t blocks_number = nelems / block_size;
     const size_t tail = nelems % block_size;
 
-#pragma omp parallel
+PRAGMA_OMP(parallel)
     {
         const size_t ithr = mkldnn_get_thread_num();
         const size_t nthr = mkldnn_get_num_threads();
@@ -672,9 +660,10 @@ void array_sum(size_t num_arrs, float *output,
 } //bwdw namespace
 
 void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t::
-_execute_backward_weights_SDGtWo() {
+_execute_backward_weights_SDGtWo(
+        const memory_tracking::grantor_t &scratchpad) const {
     const auto &jcp = kernel_->jcp;
-    const int nthreads = scratchpad_->num_threads();
+    const int nthreads = jcp.nthr;
 
     array_offset_calculator<float, 5> src((float *)this->input_memory(0),
             jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
@@ -683,20 +672,20 @@ _execute_backward_weights_SDGtWo() {
     array_offset_calculator<float, 6> diff_weights((float *)this->memory(0),
             jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
 
-    array_offset_calculator<float, 8> Us((float *)(scratchpad_->U_ptr()),
+    array_offset_calculator<float, 8> Us(scratchpad.get<float>(key_wino_U),
             0, alpha, alpha,
             jcp.oc_block, jcp.ic_block,
             jcp.ic_simd_block,
             jcp.oc_reg_block,
             jcp.oc_simd_block);
 
-    int U_sz = nthreads * alpha * alpha * jcp.oc / jcp.nb_oc
-        * jcp.ic / jcp.nb_ic * sizeof(float);
+    const int U_sz = nthreads * alpha * alpha * jcp.oc / jcp.nb_oc
+        * jcp.ic / jcp.nb_ic;
     array_offset_calculator<float, 7>diff_weights_prv(
-            (float *)(scratchpad_->U_ptr() + U_sz),
+            scratchpad.get<float>(key_wino_U) + U_sz,
             0, jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
 
-    array_offset_calculator<float, 8> M((float *)(scratchpad_->M_ptr()),
+    array_offset_calculator<float, 8> M(scratchpad.get<float>(key_wino_M),
             0, alpha, alpha,
             jcp.oc_block,
             jcp.nb_tile_block_ur,
@@ -704,7 +693,7 @@ _execute_backward_weights_SDGtWo() {
             jcp.oc_reg_block,
             jcp.oc_simd_block);
 
-    array_offset_calculator<float, 7> V((float *)(scratchpad_->V_ptr()),
+    array_offset_calculator<float, 7> V(scratchpad.get<float>(key_wino_V),
             0, alpha, alpha,
             jcp.ic_block,
             jcp.nb_tile_block_ur,
@@ -712,7 +701,7 @@ _execute_backward_weights_SDGtWo() {
             jcp.ic_simd_block);
 
     array_offset_calculator<float, 2> diff_bias_prv(
-            (float *)(scratchpad_->bias_ptr()), nthreads, jcp.oc);
+            scratchpad.get<float>(key_conv_bia_reduction), nthreads, jcp.oc);
 
     auto trans_ker_p = jit_wino_transform_call_s();
     float I[alpha][alpha][simd_w];
@@ -724,7 +713,7 @@ _execute_backward_weights_SDGtWo() {
        1.13777777777778f};
     float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f};
 
-#pragma omp parallel num_threads(nthreads) firstprivate(trans_ker_p, I, T)
+PRAGMA_OMP(parallel num_threads(nthreads) firstprivate(trans_ker_p, I, T))
 {
     if (jcp.with_bias) {
         parallel_nd_in_omp(nthreads, jcp.oc / simd_w,
@@ -740,7 +729,7 @@ _execute_backward_weights_SDGtWo() {
     int ithr = mkldnn_get_thread_num();
     for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) {
         int first_tblk = 0;
-#pragma omp for
+PRAGMA_OMP(for)
         for (int tblk1 = 0; tblk1 < jcp.tile_block; ++tblk1) {
             int tile_index = tblk1 * jcp.nb_tile_block_ur * jcp.tile_block_ur;
             int img = tile_index / (jcp.itiles * jcp.jtiles);
@@ -806,7 +795,7 @@ _execute_backward_weights_SDGtWo() {
     // Reduce diff-weights
     {
         float *output = (float *)(this->memory(0));
-        float *input_base = (float *)(scratchpad_->U_ptr() + U_sz);
+        float *input_base = scratchpad.get<float>(key_wino_U) + U_sz;
         int nelems = jcp.oc * jcp.ic * jcp.kh * jcp.kw;
         float *input_ptrs[max_threads_number];
         for (int i = 0; i < nthreads; ++i) {
@@ -816,7 +805,7 @@ _execute_backward_weights_SDGtWo() {
 
         if (jcp.with_bias) {
             output = (float *)(this->memory(1));
-            input_base = (float *)(scratchpad_->bias_ptr());
+            input_base = scratchpad.get<float>(key_conv_bia_reduction);
             for (int i = 0; i < nthreads; ++i) {
                 input_ptrs[i] = input_base + jcp.oc * i;
             }
@@ -827,9 +816,10 @@ _execute_backward_weights_SDGtWo() {
 }
 
 void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t::
-_execute_backward_weights_S_D_Giot_W() {
+_execute_backward_weights_S_D_Giot_W(
+        const memory_tracking::grantor_t &scratchpad) const {
     const auto &jcp = kernel_->jcp;
-    const int nthreads = scratchpad_->num_threads();
+    const int nthreads = jcp.nthr;
 
     array_offset_calculator<float, 5> src((float *)this->input_memory(0),
             jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
@@ -839,7 +829,7 @@ _execute_backward_weights_S_D_Giot_W() {
             jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
     array_offset_calculator<float, 1> diff_bias((float *)this->memory(1), jcp.oc);
 
-    array_offset_calculator<float, 9> U((float *)(scratchpad_->U_ptr()),
+    array_offset_calculator<float, 9> U(scratchpad.get<float>(key_wino_U),
             jcp.nb_ic, jcp.nb_oc,
             alpha, alpha,
             jcp.oc_block, jcp.ic_block,
@@ -847,9 +837,9 @@ _execute_backward_weights_S_D_Giot_W() {
             jcp.oc_reg_block,
             jcp.oc_simd_block);
 
-    int U_size = jcp.oc * jcp.ic * alpha * alpha * sizeof(float);
+    const int U_size = jcp.oc * jcp.ic * alpha * alpha;
     array_offset_calculator<float, 10> Us(
-            (float *)(scratchpad_->U_ptr() + U_size),
+            scratchpad.get<float>(key_wino_U) + U_size,
             0, jcp.nb_ic, jcp.nb_oc,
             alpha, alpha,
             jcp.oc_block, jcp.ic_block,
@@ -857,7 +847,7 @@ _execute_backward_weights_S_D_Giot_W() {
             jcp.oc_reg_block,
             jcp.oc_simd_block);
 
-    array_offset_calculator<float, 9> M((float *)(scratchpad_->M_ptr()),
+    array_offset_calculator<float, 9> M(scratchpad.get<float>(key_wino_M),
             jcp.nb_oc,
             jcp.tile_block,
             alpha, alpha,
@@ -867,7 +857,7 @@ _execute_backward_weights_S_D_Giot_W() {
             jcp.oc_reg_block,
             jcp.oc_simd_block);
 
-    array_offset_calculator<float, 8> V((float *)(scratchpad_->V_ptr()),
+    array_offset_calculator<float, 8> V(scratchpad.get<float>(key_wino_V),
             jcp.nb_ic,
             jcp.tile_block,
             alpha, alpha,
@@ -876,7 +866,7 @@ _execute_backward_weights_S_D_Giot_W() {
             jcp.ic_simd_block);
 
     array_offset_calculator<float, 2> diff_bias_prv(
-            (float *)(scratchpad_->bias_ptr()), nthreads, jcp.oc);
+            scratchpad.get<float>(key_conv_bia_reduction), nthreads, jcp.oc);
 
     size_t input_starts[max_threads_number] = {0};
     size_t input_ends[max_threads_number] = {0};
@@ -892,7 +882,7 @@ _execute_backward_weights_S_D_Giot_W() {
     float I[alpha][alpha][simd_w];
     float T[alpha][alpha][simd_w];
 
-#pragma omp parallel firstprivate(first_tblk, trans_ker_p, I, T)
+PRAGMA_OMP(parallel firstprivate(first_tblk, trans_ker_p, I, T))
 {
     if (jcp.with_bias) {
         parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) {
@@ -941,7 +931,7 @@ _execute_backward_weights_S_D_Giot_W() {
         }
     });
 
-    #pragma omp barrier
+    PRAGMA_OMP(barrier)
 
     parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, alpha, alpha, jcp.tile_block,
         [&](int ifm1, int ofm1, int oj, int oi, int tblk1){
@@ -991,7 +981,7 @@ _execute_backward_weights_S_D_Giot_W() {
     }
 
     trans_ker_p.G = G_O_3x3_4x4;
-#pragma omp parallel firstprivate(trans_ker_p)
+PRAGMA_OMP(parallel firstprivate(trans_ker_p))
     {
         parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block, jcp.oc_reg_block,
             [&](int ifm1, int ofm1, int ofm2, int ifm2, int ofm3){
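Across all of these hunks, the execution functions stop pulling workspace buffers from the old scratchpad_ member (U_ptr(), V_ptr(), M_ptr(), bias_ptr()) and instead take a memory_tracking::grantor_t argument, fetching typed pointers by key (key_wino_U, key_wino_V, key_wino_M, key_conv_bia_reduction). The ".template get<float>" spelling is only needed because those call sites sit inside a class template, where get is a dependent name. The sketch below mimics the book-at-creation / get-at-execution pattern with small stand-in types; it is illustrative only and does not reproduce mkl-dnn's real memory_tracking classes.

    // Toy model of the book-at-creation / get-at-execution scratchpad pattern.
    // All names here (scratchpad_key_t, toy_registrar_t, toy_grantor_t) are
    // hypothetical; only the call shape of get<float>(key) mirrors the diff.
    #include <cstddef>
    #include <map>
    #include <vector>

    enum scratchpad_key_t { key_wino_U, key_wino_V };

    struct toy_registrar_t {   // creation side: declare how much space each key needs
        void book(scratchpad_key_t key, std::size_t bytes) {
            offsets_[key] = total_;
            total_ += bytes;
        }
        std::size_t total() const { return total_; }
        std::map<scratchpad_key_t, std::size_t> offsets_;
        std::size_t total_ = 0;
    };

    struct toy_grantor_t {     // execution side: hand out typed, non-overlapping slices
        toy_grantor_t(const toy_registrar_t &r, void *base)
            : r_(r), base_(static_cast<char *>(base)) {}
        template <typename T> T *get(scratchpad_key_t key) const {
            return reinterpret_cast<T *>(base_ + r_.offsets_.at(key));
        }
        const toy_registrar_t &r_;
        char *base_;
    };

    int main() {
        toy_registrar_t reg;
        reg.book(key_wino_U, 1024 * sizeof(float));   // sizes would come from jcp in the real code
        reg.book(key_wino_V, 2048 * sizeof(float));

        std::vector<char> storage(reg.total());       // one backing allocation for all keys
        toy_grantor_t scratchpad(reg, storage.data());

        float *U = scratchpad.get<float>(key_wino_U); // same call shape as in the patched code
        float *V = scratchpad.get<float>(key_wino_V);
        U[0] = 1.f;
        V[0] = 2.f;
        return 0;
    }

Splitting the responsibilities this way lets the primitive declare its scratch requirements up front while the caller owns the single allocation handed out at execution time, which is consistent with the execute methods becoming const in this patch.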