Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_common_convolution_winograd.cpp
index 93db55e..eb45ba9 100644 (file)
@@ -37,6 +37,8 @@ namespace mkldnn {
 namespace impl {
 namespace cpu {
 
+using namespace memory_tracking::names;
+
 namespace {
 
 unsigned int LLC_cache_size = get_cache_size(3, false);
@@ -511,80 +513,6 @@ void input_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
 }
 
 template <bool is_fwd>
-void input_transform_tileblock_data(int tile_block,
-        const jit_conv_winograd_conf_t &jcp,
-        float *inp, float *tinp)
-{
-    const int inph = is_fwd ? jcp.ih : jcp.oh;
-    const int inpw = is_fwd ? jcp.iw : jcp.ow;
-    const int t_pad = is_fwd ? jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh;
-    const int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow;
-    const int wp_max = inpw + l_pad;
-    const int hp_max = inph + t_pad;
-    float Iw[alpha][alpha][simd_w];
-    float I[alpha][alpha][simd_w];
-
-    array_offset_calculator<float, 5> input(inp,
-            jcp.mb, jcp.dimK/simd_w, inph, inpw, simd_w);
-    array_offset_calculator<float, 7> output(tinp,
-            alpha, alpha,
-            jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block,
-            jcp.dimN_reg_block, jcp.dimK_reg_block);
-
-    int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur;
-
-    for (int nb_tile_block_ur = 0;
-            nb_tile_block_ur < jcp.nb_tile_block_ur;
-            nb_tile_block_ur++) {
-        for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur;
-                tile_block_ur++) {
-
-            int img = tile_index / (jcp.jtiles * jcp.itiles);
-            int ti = tile_index % jcp.itiles;
-            int tj = (tile_index / jcp.itiles) % jcp.jtiles;
-            float *pinp_b = &(input(img, 0, 0, 0, 0));
-
-            for (int j = 0; j < alpha; j++) {
-                int ydim = tj * tile_size + j;
-                if ((t_pad <= ydim) && (ydim < hp_max)) {
-                    float *pinp_j = pinp_b + (ydim - t_pad) * inpw * simd_w;
-                    for (int i = 0; i < alpha; i++) {
-                        int xdim = ti * tile_size + i;
-                        if ((l_pad <= xdim) && (xdim < wp_max)) {
-                            float *pinp_i = pinp_j + (xdim - l_pad) * simd_w;
-                            load_ps(I[j][i], pinp_i);
-                        } else {
-                            PRAGMA_OMP_SIMD()
-                            for (int v = 0; v < simd_w; v++) {
-                                I[j][i][v] = 0.0f;
-                            }
-                        }
-                    }
-                } else {
-                    for (int i = 0; i < alpha; i++) {
-                        PRAGMA_OMP_SIMD()
-                        for (int v = 0; v < simd_w; v++) {
-                            I[j][i][v] = 0.0f;
-                        }
-                    }
-                }
-            }
-
-            trans_I_4x4_3x3(Iw, I);
-            for (int j = 0; j < alpha; j++) {
-                for (int i = 0; i < alpha; i++) {
-                    store_output(&(output(j, i,
-                                    nb_tile_block_ur, 0, 0,
-                                    tile_block_ur, 0)),
-                                 Iw[j][i], false);
-                }
-            }
-            tile_index++;
-        }
-    }
-}
-
-template <bool is_fwd>
 void weight_transform_data(const jit_conv_winograd_conf_t &jcp,
         float *wp, float *twp)
 {
@@ -691,7 +619,7 @@ void output_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
                                     O[j][i][v] = true
                                         && with_relu_presum && O[j][i][v] < 0.f
                                                 ? O[j][i][v]
-                                                * jcp.eltwise_alpha
+                                                * jcp.eltwise.alpha
                                                 : O[j][i][v];
                                 }
                             }
@@ -717,83 +645,6 @@ void output_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
     }
 }
 
-template <bool is_fwd, bool with_bias, bool with_relu_presum, bool with_sum>
-void output_transform_tileblock_data(int tile_block,
-        const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops,
-        float *toutp, float *outp, float *bias, bool streamout) {
-    float Ow[alpha][alpha][simd_w];
-    float O[tile_size][tile_size][simd_w];
-    int outw = is_fwd ? jcp.ow : jcp.iw;
-    int outh = is_fwd ? jcp.oh : jcp.ih;
-
-    /* Prepare for PostOps */
-    bool with_relu_postsum = p_ops.find(primitive_kind::eltwise, 1) != -1;
-
-    array_offset_calculator<float, 6> input(toutp,
-            alpha, alpha,
-            jcp.dimN_block, jcp.dimM_block,
-            jcp.dimN_reg_block, jcp.dimM_simd_block);
-    array_offset_calculator<float, 5> output(outp,
-            jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw,
-            jcp.dimM_simd_block);
-
-    int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur;
-
-    for (int nb_tile_block_ur = 0;
-            nb_tile_block_ur < jcp.nb_tile_block_ur;
-            nb_tile_block_ur++) {
-
-        for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur;
-                tile_block_ur++) {
-            int img = tile_index / (jcp.jtiles * jcp.itiles);
-            int ti = tile_index % jcp.itiles;
-            int tj = (tile_index / jcp.itiles) % jcp.jtiles;
-
-            for (int j = 0; j < alpha; j++) {
-                for (int i = 0; i < alpha; i++) {
-                    float *pinp_tile = &(input(j, i, nb_tile_block_ur, 0,
-                                tile_block_ur, 0));
-                    load_ps(Ow[j][i], pinp_tile);
-                }
-            }
-
-            trans_O_4x4_3x3(Ow, O);
-
-            float *pout_b = &(output(img, 0, 0, 0, 0));
-            for (int j = 0; j < tile_size; j++) {
-                int ydim = tj * tile_size + j;
-                if (ydim < outh) {
-                    float *pout_j = pout_b + ydim * outw * simd_w;
-                    for (int i = 0; i < tile_size; i++) {
-                        int xdim = ti * tile_size + i;
-                        if (xdim < outw) {
-                            float *pout_i = pout_j + xdim * simd_w;
-                            if (is_fwd) {
-                                PRAGMA_OMP_SIMD()
-                                for (int v = 0; v < simd_w; v++) {
-                                    O[j][i][v] += with_bias ? bias[v] : 0.f;
-                                    O[j][i][v] = true
-                                        && with_relu_presum && O[j][i][v] < 0.f
-                                                ? O[j][i][v]
-                                                * jcp.eltwise_alpha
-                                                : O[j][i][v];
-
-                                }
-                            }
-                            if (with_sum)
-                                accum_output(pout_i, O[j][i], streamout,
-                                        with_relu_postsum);
-                            else
-                                store_output(pout_i, O[j][i], streamout);
-                        }
-                    }
-                }
-            }
-            tile_index++;
-        }
-    }
-}
-
 template <bool ver_4fma>
 void diff_src_transform_bwd_weights(int image, jit_conv_winograd_conf_t conv,
         float *inp, float *tinp, float *Iw_temp,
@@ -1049,7 +900,8 @@ void diff_weights_transform_bwd_weights(jit_conv_winograd_conf_t conv,
 
 template <bool is_fwd>
 void _jit_avx512_common_convolution_winograd_t<is_fwd>::_execute_data_W_S_G_D(
-        const int MB, float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr) {
+        const int MB, float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
+        const memory_tracking::grantor_t &scratchpad) const{
     const auto &jcp = kernel_->jcp;
     const auto &p_ops = attr_->post_ops_;
 
@@ -1058,7 +910,7 @@ void _jit_avx512_common_convolution_winograd_t<is_fwd>::_execute_data_W_S_G_D(
     const int outh = is_fwd ? jcp.oh : jcp.ih;
     const int outw = is_fwd ? jcp.ow : jcp.iw;
 
-    /* Note that jcp.with_relu is true for both fused conv+relu primitive
+    /* Note that jcp.with_eltwise is true for both fused conv+relu primitive
      * and conv primitive with PostOps with relu before sum
      * (PostOps relu after sum is handled later) */
     auto output_transform = jcp.with_bias
@@ -1094,24 +946,23 @@ void _jit_avx512_common_convolution_winograd_t<is_fwd>::_execute_data_W_S_G_D(
     array_offset_calculator<float, 2> bias(bias_ptr,
             jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block);
 
-    array_offset_calculator<float, 8> M(
-            (float *)((is_fwd
-                    ? (this->scratchpad_)->M_ptr()
-                    : (this->scratchpad_)->V_ptr())),
+    array_offset_calculator<float, 8> M(is_fwd
+            ? scratchpad.template get<float>(key_wino_M)
+            : scratchpad.template get<float>(key_wino_V),
             jcp.dimN_nb_block, jcp.dimM_nb_block,
             alpha, alpha,
             jcp.dimN_block, jcp.dimM_block,
             jcp.dimN_reg_block, jcp.dimM_simd_block);
-    array_offset_calculator<float, 8> U((float *)((this->scratchpad_)->U_ptr()),
+    array_offset_calculator<float, 8> U(
+            scratchpad.template get<float>(key_wino_U),
             jcp.dimM_nb_block,
             alpha, alpha,
             jcp.dimK_nb_block,
             jcp.dimM_block, jcp.dimK_block,
             jcp.dimK_reg_block, jcp.dimM_simd_block);
-    array_offset_calculator<float, 8> V(
-            (float *)((is_fwd
-                    ? (this->scratchpad_)->V_ptr()
-                    : (this->scratchpad_)->M_ptr())),
+    array_offset_calculator<float, 8> V(is_fwd
+            ? scratchpad.template get<float>(key_wino_V)
+            : scratchpad.template get<float>(key_wino_M),
             jcp.dimN_nb_block, alpha, alpha,
             jcp.dimN_block, jcp.dimK_nb_block,
             jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block);
@@ -1121,15 +972,15 @@ void _jit_avx512_common_convolution_winograd_t<is_fwd>::_execute_data_W_S_G_D(
 
     const bool output_is_aligned = ((size_t)out_ptr & (64 - 1)) == 0;
 
-    const bool want_padded_bias = jcp.with_bias
+    const bool wants_padded_bias = jcp.with_bias
         && jcp.oc_without_padding != jcp.oc;
     float last_slice_bias[simd_w] = {0};
-    if (want_padded_bias) {
+    if (wants_padded_bias) {
         for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
             last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
     }
 
-#pragma omp parallel
+PRAGMA_OMP(parallel)
     {
         parallel_nd_in_omp(MB, jcp.dimK_nb_block, jcp.dimK_block,
             [&](int img, int K_blk1, int K_blk2) {
@@ -1148,7 +999,7 @@ void _jit_avx512_common_convolution_winograd_t<is_fwd>::_execute_data_W_S_G_D(
                 ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)), U_base_ptr);
         });
 
-#pragma omp barrier
+PRAGMA_OMP(barrier)
 
         parallel_nd_in_omp(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block, jcp.dimN_block,
             [&](int N_blk1, int oj, int oi, int M_blk1, int N_blk2) {
@@ -1174,14 +1025,14 @@ void _jit_avx512_common_convolution_winograd_t<is_fwd>::_execute_data_W_S_G_D(
         });
 
 
-#pragma omp barrier
+PRAGMA_OMP(barrier)
 
         parallel_nd_in_omp(MB, jcp.dimM_nb_block, jcp.dimM_block,
                     [&](int img, int M_blk1, int M_blk2) {
 
             const int M_blk = M_blk1 * jcp.dimM_block + M_blk2;
 
-            float *bias_ptr = want_padded_bias
+            float *bias_ptr = wants_padded_bias
                 && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
                 ? last_slice_bias : &bias(M_blk, 0);
 
@@ -1194,180 +1045,25 @@ void _jit_avx512_common_convolution_winograd_t<is_fwd>::_execute_data_W_S_G_D(
     }
 }
 
-template void
-_jit_avx512_common_convolution_winograd_t<true>::_execute_data_W_S_G_D(
-        const int, float *, float *, float *, float *);
-template void
-_jit_avx512_common_convolution_winograd_t<false>::_execute_data_W_S_G_D(
-        const int, float *, float *, float *, float *);
-
-template <bool is_fwd>
-void _jit_avx512_common_convolution_winograd_t<is_fwd>::_execute_data_W_SGD(
-        const int MB, float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr) {
-    const auto &jcp = kernel_->jcp;
-    const auto &p_ops = attr_->post_ops_;
-
-    const int inph = is_fwd ? jcp.ih : jcp.oh;
-    const int inpw = is_fwd ? jcp.iw : jcp.ow;
-    const int outh = is_fwd ? jcp.oh : jcp.ih;
-    const int outw = is_fwd ? jcp.ow : jcp.iw;
-
-    /* Note that jcp.with_relu is true for both fused conv+relu primitive
-     * and conv primitive with PostOps with relu before sum
-     * (PostOps relu after sum is handled later) */
-    auto output_transform_tileblock = jcp.with_bias
-            ? (jcp.with_eltwise
-                ? (jcp.with_sum
-                    ? output_transform_tileblock_data<is_fwd, true, true, true>
-                    : output_transform_tileblock_data<is_fwd, true, true, false>)
-                : (jcp.with_sum
-                    ? output_transform_tileblock_data<is_fwd, true, false, true>
-                    : output_transform_tileblock_data<is_fwd, true, false, false>))
-            : (jcp.with_eltwise
-                ? (jcp.with_sum
-                    ? output_transform_tileblock_data<is_fwd, false, true, true>
-                    : output_transform_tileblock_data<is_fwd, false, true, false>)
-                : (jcp.with_sum
-                    ? output_transform_tileblock_data<is_fwd, false, false, true>
-                    : output_transform_tileblock_data<is_fwd, false, false, false>));
-
-    array_offset_calculator<float, 5> input(inp_ptr,
-            MB, jcp.dimK/jcp.dimK_reg_block, inph, inpw, jcp.dimK_reg_block);
-    array_offset_calculator<float, 5> output(out_ptr,
-            MB, jcp.dimM/jcp.dimM_simd_block, outh, outw, jcp.dimM_simd_block);
-    array_offset_calculator<float, 6> weights(wei_ptr,
-            jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw,
-            jcp.ic_simd_block, jcp.oc_simd_block);
-    array_offset_calculator<float, 2> bias(bias_ptr,
-            jcp.oc/jcp.oc_simd_block, jcp.oc_simd_block);
-
-    array_offset_calculator<float, 8> U((float *)((this->scratchpad_)->U_ptr()),
-            jcp.dimM_nb_block,
-            alpha, alpha,
-            jcp.dimK_nb_block,
-            jcp.dimM_block, jcp.dimK_block,
-            jcp.dimK_reg_block, jcp.dimM_simd_block);
-
-    array_offset_calculator<float, 8> M(
-            (float *)((is_fwd
-                    ? (this->scratchpad_)->M_ptr()
-                    : (this->scratchpad_)->V_ptr())),
-            0, jcp.dimM_nb_block, alpha, alpha,
-            jcp.dimN_block, jcp.dimM_block,
-            jcp.dimN_reg_block, jcp.dimM_simd_block);
-
-    array_offset_calculator<float, 8> V(
-            (float *)((is_fwd
-                    ? (this->scratchpad_)->V_ptr()
-                    : (this->scratchpad_)->M_ptr())),
-            0, alpha, alpha, jcp.dimN_block,
-            jcp.dimK_nb_block, jcp.dimK_block,
-            jcp.dimN_reg_block, jcp.dimK_reg_block);
-
-    const bool output_is_aligned = ((size_t)out_ptr & (64 - 1)) == 0;
-
-    const bool want_padded_bias = jcp.with_bias
-        && jcp.oc_without_padding != jcp.oc;
-    float last_slice_bias[simd_w] = {0};
-    if (want_padded_bias) {
-        for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
-            last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
-    }
-
-#pragma omp parallel
-    {
-        parallel_nd_in_omp(jcp.nb_oc, jcp.nb_ic, jcp.oc_block, jcp.ic_block,
-            [&](int ofm1, int ifm1, int ofm2, int ifm2) {
-
-            float *U_base_ptr = is_fwd
-                              ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
-                              : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
-            weight_transform_data<is_fwd>(jcp,
-                    &(weights(ofm1 * jcp.oc_block + ofm2,
-                            ifm1 * jcp.ic_block + ifm2,
-                            0, 0, 0, 0)),
-                    U_base_ptr);
-        });
-
-#pragma omp barrier
-
-    int ithr = mkldnn_get_thread_num();
-
-#pragma omp for schedule(static)
-    for (int tile_block = 0; tile_block < jcp.tile_block; tile_block++) {
-        for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++) {
-            for (int K_blk2 = 0; K_blk2 < jcp.dimK_block; K_blk2++) {
-                input_transform_tileblock_data<is_fwd>(
-                        tile_block, jcp,
-                        &(input(0, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)),
-                        &(V(ithr, 0, 0, 0, K_blk1, K_blk2, 0, 0)));
-            }
-        }
-
-        for (int oj = 0; oj < alpha; oj++) {
-            for (int oi = 0; oi < alpha; oi++) {
-                for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) {
-                    for (int N_blk = 0; N_blk < jcp.dimN_block; N_blk++) {
-                        kernel_->gemm_loop_ker_first_iter(
-                                (float *)&(M(ithr, M_blk1, oj, oi,
-                                        N_blk, 0, 0, 0)),
-                                (const float *)&(U(M_blk1, oj, oi, 0,
-                                        0, 0, 0, 0)),
-                                (const float *)&(V(ithr, oj, oi,
-                                        N_blk, 0, 0, 0, 0)));
-                        for (int K_blk1 = 1; K_blk1 < jcp.dimK_nb_block; K_blk1++) {
-                            kernel_->gemm_loop_ker(
-                                    (float *)&(M(ithr, M_blk1, oj, oi,
-                                            N_blk, 0, 0, 0)),
-                                    (const float *)&(U(M_blk1, oj, oi, K_blk1,
-                                            0, 0, 0, 0)),
-                                    (const float *)&(V(ithr, oj, oi,
-                                            N_blk, K_blk1, 0, 0, 0)));
-                        }
-                    }
-                }
-            }
-        }
-
-        for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) {
-            for (int M_blk2 = 0; M_blk2 < jcp.dimM_block; M_blk2++) {
-                const int M_blk = M_blk1 * jcp.dimM_block + M_blk2;
-
-                float *bias_ptr = want_padded_bias
-                    && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
-                    ? last_slice_bias : &bias(M_blk, 0);
-
-                output_transform_tileblock(tile_block, jcp, p_ops,
-                        &(M(ithr, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
-                        &(output(0, M_blk, 0, 0, 0)),
-                        bias_ptr, output_is_aligned);
-            }
-        }
-    }
-    }
-}
-
-template void
-_jit_avx512_common_convolution_winograd_t<true>::_execute_data_W_SGD(
-        const int, float *, float *, float *, float *);
-template void
-_jit_avx512_common_convolution_winograd_t<false>::_execute_data_W_SGD(
-        const int, float *, float *, float *, float *);
+template struct _jit_avx512_common_convolution_winograd_t<true>;
+template struct _jit_avx512_common_convolution_winograd_t<false>;
 
 void jit_avx512_common_convolution_winograd_bwd_weights_t::
-_maybe_execute_diff_bias_copy() {
-    if (conf_.want_padded_bias()) {
+_maybe_execute_diff_bias_copy(
+        const memory_tracking::grantor_t &scratchpad) const {
+    if (pd()->wants_padded_bias()) {
+        auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
         float *diff_bias = (float *)this->memory(1);
-        for (int oc = 0; oc < conf_.jcp_.oc_without_padding; ++oc)
-            diff_bias[oc] = this->padded_bias_[oc];
+        for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc)
+            diff_bias[oc] = padded_bias[oc];
     }
 }
 
 void jit_avx512_common_convolution_winograd_bwd_weights_t::
-_execute_backward_weights_S_D_G_W()
-{
+_execute_backward_weights_S_D_G_W(
+        const memory_tracking::grantor_t &scratchpad) const {
     const auto &jcp = kernel_->jcp;
-    const int nthreads = scratchpad_->num_threads();
+    const int nthreads = jcp.nthr;
 
     auto diff_src_transform_bwd_weights_ver = jcp.ver == ver_4fma ?
             diff_src_transform_bwd_weights<true> :
@@ -1382,25 +1078,25 @@ _execute_backward_weights_S_D_G_W()
             jcp.mb, jcp.oc/simd_w, jcp.oh, jcp.ow, simd_w);
     array_offset_calculator<float, 6> diff_weights((float *)this->memory(0),
             jcp.oc/simd_w, jcp.ic/simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
-    array_offset_calculator<float, 2> diff_bias(
-            conf_.want_padded_bias() ? padded_bias_ : (float *)this->memory(1),
-            jcp.oc/simd_w, simd_w);
+    array_offset_calculator<float, 2> diff_bias(pd()->wants_padded_bias()
+            ? scratchpad.get<float>(key_conv_padded_bias)
+            : (float *)this->memory(1), jcp.oc/simd_w, simd_w);
 
     array_offset_calculator<float, 8> U(
-            (float *)(scratchpad_->U_ptr()),
+            scratchpad.get<float>(key_wino_U),
             jcp.nb_ic, jcp.nb_oc,
             alpha, alpha,
             jcp.oc_block, jcp.ic_block,
             jcp.ic_simd_block, jcp.oc_simd_block);
 
     array_offset_calculator<float, 8> M(
-            (float *)(scratchpad_->M_ptr()),
+            scratchpad.get<float>(key_wino_M),
             jcp.nb_oc, alpha, alpha,
             jcp.tile_block, jcp.oc_block,
             jcp.nb_tile_block_ur, jcp.tile_block_ur * jcp.tile_4fma,
             jcp.oc_simd_block);
     array_offset_calculator<float, 8> V(
-            (float *)(scratchpad_->V_ptr()),
+            scratchpad.get<float>(key_wino_V),
             jcp.nb_ic, alpha, alpha,
             jcp.tile_block, jcp.ic_block,
             jcp.nb_tile_block_ur, jcp.tile_block_ur,
@@ -1409,23 +1105,23 @@ _execute_backward_weights_S_D_G_W()
     const int trans_buffer_size = alpha * alpha * jcp.tile_4fma
                                 * jcp.ic_simd_block;
     array_offset_calculator<float, 2> trans_buffer(
-            (float *)(scratchpad_->src_transpose_ptr()),
+            scratchpad.get<float>(key_conv_tr_src),
             nthreads,
             trans_buffer_size);
 
     array_offset_calculator<float, 2> diff_bias_prv(
-            (float *)(scratchpad_->bias_ptr()),
-            mkldnn_get_max_threads(),
+            scratchpad.get<float>(key_conv_bia_reduction),
+            nthreads,
             jcp.oc);
 
-#pragma omp parallel num_threads(nthreads)
+PRAGMA_OMP(parallel num_threads(nthreads))
     {
         if (jcp.with_bias) {
             parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) {
                 diff_bias_prv(ithr, ofm) = 0.0f;
             });
 
-#pragma omp for nowait
+PRAGMA_OMP(for nowait)
             for (int bofm = 0; bofm < jcp.oc / simd_w; bofm++) {
                 PRAGMA_OMP_SIMD()
                 for (int v = 0; v < simd_w; v++)
@@ -1461,7 +1157,7 @@ _execute_backward_weights_S_D_G_W()
                     dbias);
         });
 
-#pragma omp barrier
+PRAGMA_OMP(barrier)
 
         for (int ifm1 = 0; ifm1 < jcp.nb_ic; ifm1++) {
             parallel_nd_in_omp(alpha, alpha, jcp.nb_oc,
@@ -1486,7 +1182,7 @@ _execute_backward_weights_S_D_G_W()
             });
         }
 
-#pragma omp barrier
+PRAGMA_OMP(barrier)
 
         parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block,
             [&](int ifm1, int ofm1, int ofm2, int ifm2) {
@@ -1497,7 +1193,7 @@ _execute_backward_weights_S_D_G_W()
         });
 
         if (jcp.with_bias) {
-#pragma omp for
+PRAGMA_OMP(for)
             for (int ofm1 = 0; ofm1 < jcp.oc / simd_w; ofm1++) {
                 for (int ithr = 0; ithr < nthreads; ithr++) {
                     float* base_bias_ptr = &(diff_bias(ofm1, 0));
@@ -1512,806 +1208,9 @@ _execute_backward_weights_S_D_G_W()
         }
     }
 
-    _maybe_execute_diff_bias_copy();
+    _maybe_execute_diff_bias_copy(scratchpad);
 }
 
-namespace {
-
-const int max_threads_number = 1024;
-
-template <bool ver_4fma>
-void diff_src_transform_bwd_weights_tile(int tile_block,
-    jit_conv_winograd_conf_t conv, float *inp, float *tinp,
-    void(*transpose_4fma_ker)(float *, float *))
-{
-    const int ifwp = conv.iw + conv.l_pad;
-    const int ifhp = conv.ih + conv.t_pad;
-    float I[alpha][alpha][simd_w];
-    float Iw[alpha][alpha][simd_w];
-
-    float *Iw_buffer = nullptr;
-    if (ver_4fma) {
-        Iw_buffer = (float *)malloc(alpha * alpha * conv.tile_4fma
-            * simd_w * sizeof(float), 64);
-    }
-    array_offset_calculator<float, 4> Iw_scratchpad(Iw_buffer,
-        alpha, alpha, conv.tile_4fma, simd_w);
-    array_offset_calculator<float, 5> input(inp,
-        conv.mb, conv.ic / simd_w, conv.ih, conv.iw, simd_w);
-    array_offset_calculator<float, 7> output(tinp,
-        0, alpha, alpha,
-        conv.ic_block,
-        conv.nb_tile_block_ur, conv.tile_block_ur,
-        conv.ic_simd_block * conv.tile_4fma);
-
-    int tile_4fma = 0;
-
-    int n_tiles = tile_block * conv.nb_tile_block_ur * conv.tile_block_ur;
-    for (int nb_tile_block_ur = 0; nb_tile_block_ur < conv.nb_tile_block_ur;
-        nb_tile_block_ur++) {
-        for (int tile_block_ur = 0; tile_block_ur < conv.tile_block_ur;
-            tile_block_ur++) {
-
-            int img = n_tiles / (conv.jtiles * conv.itiles);
-            int no_tile = n_tiles % (conv.jtiles * conv.itiles);
-            int ti = no_tile % conv.itiles;
-            int tj = no_tile / conv.itiles;
-
-            for (int j = 0; j < alpha; j++) {
-                int ydim = tj * tile_size + j;
-                if ((conv.t_pad <= ydim) && ydim < ifhp) {
-                    for (int i = 0; i < alpha; i++) {
-                        int xdim = ti * tile_size + i;
-                        if ((conv.l_pad <= xdim) && xdim < ifwp) {
-                            PRAGMA_OMP_SIMD()
-                            for (int v = 0; v < simd_w; v++) {
-                                I[j][i][v] = input(img, 0,
-                                    ydim - conv.t_pad,
-                                    xdim - conv.l_pad, v);
-                            }
-                        }
-                        else {
-                            PRAGMA_OMP_SIMD()
-                            for (int v = 0; v < simd_w; v++) {
-                                I[j][i][v] = 0.0f;
-                            }
-                        }
-                    }
-                }
-                else {
-                    for (int i = 0; i < alpha; i++) {
-                        PRAGMA_OMP_SIMD()
-                        for (int v = 0; v < simd_w; v++) {
-                            I[j][i][v] = 0.0f;
-                        }
-                    }
-                }
-            }
-
-            trans_I_4x4_3x3(Iw, I);
-
-            if (ver_4fma) {
-                for (int j = 0; j < alpha; j++) {
-                    for (int i = 0; i < alpha; i++) {
-                        PRAGMA_OMP_SIMD()
-                        for (int v = 0; v < simd_w; v++) {
-                            Iw_scratchpad(j, i, tile_4fma, v) = Iw[j][i][v];
-                        }
-                    }
-                }
-                tile_4fma++;
-                if (tile_4fma == conv.tile_4fma) {
-                    float *outp = &(output(0, 0, 0, 0,
-                        nb_tile_block_ur, tile_block_ur, 0));
-                    transpose_4fma_ker(outp, (float *)Iw_buffer);
-                    tile_4fma = 0;
-                }
-            }
-            else {
-                for (int j = 0; j < alpha; j++) {
-                    for (int i = 0; i < alpha; i++) {
-                        store_output(
-                            &(output(0, j, i, 0,
-                                nb_tile_block_ur, tile_block_ur, 0)),
-                            Iw[j][i], false);
-
-                    }
-                }
-            }
-            n_tiles++;
-        }
-    }
-}
-
-template <bool with_bias>
-void diff_dst_transform_bwd_weights_tile(int tile_block,
-    jit_conv_winograd_conf_t conv, float *inp, float *tinp, float *dbias)
-{
-    float I[alpha][alpha][simd_w];
-    float Iw[alpha][alpha][simd_w];
-
-    array_offset_calculator<float, 5> input(inp,
-        conv.mb, conv.oc / simd_w, conv.oh, conv.ow, conv.oc_simd_block);
-    array_offset_calculator<float, 7> output(tinp,
-        conv.nb_oc, alpha, alpha,
-        conv.oc_block,
-        conv.nb_tile_block_ur,
-        conv.tile_block_ur * conv.tile_4fma, conv.oc_simd_block);
-
-    int n_tiles = tile_block * conv.nb_tile_block_ur * conv.tile_block_ur;
-    for (int nb_tile_block_ur = 0; nb_tile_block_ur < conv.nb_tile_block_ur;
-        nb_tile_block_ur++) {
-        for (int tile_block_ur = 0; tile_block_ur < conv.tile_block_ur;
-            tile_block_ur++) {
-
-            int img = n_tiles / (conv.jtiles * conv.itiles);
-            int no_tile = n_tiles % (conv.jtiles * conv.itiles);
-            int ti = no_tile % conv.itiles;
-            int tj = no_tile / conv.itiles;
-
-            for (int j = 0; j < alpha; j++) {
-                int ydim = tj * tile_size + j;
-                if (ydim < conv.oh) {
-                    for (int i = 0; i < alpha; i++) {
-                        int xdim = ti * tile_size + i;
-                        if (xdim < conv.ow) {
-                            float *input_base = &input(img, 0, ydim, xdim, 0);
-
-                            PRAGMA_OMP_SIMD()
-                            for (int v = 0; v < simd_w; v++) {
-                                I[j][i][v] = input_base[v];
-                            }
-                            if (with_bias && j < tile_size && i < tile_size) {
-                                PRAGMA_OMP_SIMD()
-                                for (int v = 0; v < simd_w; v++) {
-                                    dbias[v] += input_base[v];
-                                }
-                            }
-                        }
-                        else {
-                            PRAGMA_OMP_SIMD()
-                            for (int v = 0; v < simd_w; v++) {
-                                I[j][i][v] = 0.0f;
-                            }
-                        }
-                    }
-                }
-                else {
-                    for (int i = 0; i < alpha; i++) {
-                        PRAGMA_OMP_SIMD()
-                        for (int v = 0; v < simd_w; v++) {
-                            I[j][i][v] = 0.0f;
-                        }
-                    }
-                }
-            }
-
-            trans_W_3x3_4x4_wu(Iw, I);
-
-            for (int j = 0; j < alpha; j++) {
-                for (int i = 0; i < alpha; i++) {
-                    /*TODO: Try instrinsic for casting into __m512*/
-                    store_output(&(output(0, j, i, 0,
-                        nb_tile_block_ur, tile_block_ur, 0)),
-                        Iw[j][i], false);
-                }
-            }
-            n_tiles++;
-        }
-    }
-}
-
-// Sum to the first buffer array
-void array_sum(int num_arrs, float *output,
-    size_t nelems, float *input_ptrs[], bool reduce_to_first = true)
-{
-    const size_t block_size = 16 * 1024 / sizeof(float);
-    const size_t blocks_number = nelems / block_size;
-    const size_t tail = nelems % block_size;
-
-#pragma omp parallel
-    {
-        const int ithr = mkldnn_get_thread_num();
-        const int nthr = mkldnn_get_num_threads();
-        size_t start{ 0 }, end{ 0 };
-        balance211(blocks_number, nthr, ithr, start, end);
-
-        for (size_t nb = start; nb < end; ++nb) {
-            size_t start_e = nb * block_size;
-            size_t end_e = start_e + block_size;
-            if (!reduce_to_first) {
-                PRAGMA_OMP_SIMD()
-                for (size_t e = start_e; e < end_e; e++) {
-                    output[e] = input_ptrs[0][e];
-                }
-            }
-            for (int a = 1; a < num_arrs; a++) {
-                PRAGMA_OMP_SIMD()
-                for (size_t e = start_e; e < end_e; e++) {
-                    output[e] += input_ptrs[a][e];
-                }
-            }
-        }
-
-        if (tail != 0 && ithr == nthr - 1) {
-            size_t start_e = nelems - tail;
-            size_t end_e = nelems;
-            if (!reduce_to_first) {
-                PRAGMA_OMP_SIMD()
-                for (size_t e = start_e; e < end_e; e++) {
-                    output[e] = input_ptrs[0][e];
-                }
-            }
-            for (int a = 1; a < num_arrs; a++) {
-                PRAGMA_OMP_SIMD()
-                for (size_t e = start_e; e < end_e; e++) {
-                    output[e] += input_ptrs[a][e];
-                }
-            }
-        }
-    }
-}
-
-void subarray_sum(int num_arrs, float *output, size_t nelems,
-        float *input_ptrs[], size_t input_starts[], size_t input_ends[])
-{
-    using namespace nstl;
-    const size_t block_size = 16 * 1024 / sizeof(float);
-    const size_t blocks_number = nelems / block_size;
-    const size_t tail = nelems % block_size;
-
-#pragma omp parallel
-    {
-        const int ithr = mkldnn_get_thread_num();
-        const int nthr = mkldnn_get_num_threads();
-        size_t start{ 0 }, end{ 0 };
-        balance211(blocks_number, nthr, ithr, start, end);
-
-        for (size_t nb = start; nb < end; ++nb) {
-            size_t start_e = nb * block_size;
-            size_t end_e = start_e + block_size;
-            size_t input_start = max(start_e, min(input_starts[0], end_e));
-            size_t input_end = max(start_e, min(input_ends[0], end_e));
-
-            PRAGMA_OMP_SIMD()
-            for (size_t e = start_e; e < input_start; e++) {
-                output[e] = 0.f;
-            }
-
-            PRAGMA_OMP_SIMD()
-            for (size_t e = input_start; e < input_end; e++) {
-                output[e] = input_ptrs[0][e];
-            }
-
-            PRAGMA_OMP_SIMD()
-            for (size_t e = input_end; e < end_e; e++) {
-                output[e] = 0.f;
-            }
-            for (int a = 1; a < num_arrs; a++) {
-                input_start = max(start_e, input_starts[a]);
-                input_end = min(input_ends[a], end_e);
-
-                PRAGMA_OMP_SIMD()
-                for (size_t e = input_start; e < input_end; e++) {
-                    output[e] += input_ptrs[a][e];
-                }
-            }
-        }
-
-        if (tail != 0 && ithr == nthr - 1) {
-            size_t start_e = nelems - tail;
-            size_t end_e = nelems;
-            size_t input_start = max(start_e, min(input_starts[0], end_e));
-            size_t input_end = max(start_e, min(input_ends[0], end_e));
-
-            PRAGMA_OMP_SIMD()
-            for (size_t e = start_e; e < input_start; e++) {
-                output[e] = 0.f;
-            }
-
-            PRAGMA_OMP_SIMD()
-            for (size_t e = input_start; e < input_end; e++) {
-                output[e] = input_ptrs[0][e];
-            }
-
-            PRAGMA_OMP_SIMD()
-            for (size_t e = input_end; e < end_e; e++) {
-                output[e] = 0.f;
-            }
-            for (int a = 1; a < num_arrs; a++) {
-                input_start = max(start_e, input_starts[a]);
-                input_end = min(input_ends[a], end_e);
-
-                PRAGMA_OMP_SIMD()
-                for (size_t e = start_e; e < end_e; e++) {
-                    output[e] += input_ptrs[a][e];
-                }
-            }
-        }
-    }
-}
-} // namespace
-
-void jit_avx512_common_convolution_winograd_bwd_weights_t::
-_execute_backward_weights_S_D_Giot_W()
-{
-    const auto &jcp = kernel_->jcp;
-    const int nthreads = scratchpad_->num_threads();
-    int U_size = jcp.oc * jcp.ic * alpha * alpha * sizeof(float);
-
-    auto diff_src_transform_bwd_weights_ver = jcp.ver == ver_4fma ?
-            diff_src_transform_bwd_weights<true> :
-            diff_src_transform_bwd_weights<false>;
-    auto diff_dst_transform_bwd_weights_ver = jcp.with_bias
-                                        ? diff_dst_transform_bwd_weights<true>
-                                        : diff_dst_transform_bwd_weights<false>;
-
-    array_offset_calculator<float, 5> diff_src((float *)this->input_memory(0),
-            jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
-    array_offset_calculator<float, 5> diff_dst((float *)this->input_memory(1),
-            jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
-    array_offset_calculator<float, 6> diff_weights((float *)this->memory(0),
-            jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
-    array_offset_calculator<float, 2> diff_bias(
-            conf_.want_padded_bias() ? padded_bias_ : (float *)this->memory(1),
-            jcp.oc / simd_w, simd_w);
-
-    array_offset_calculator<float, 8> U((float *)(scratchpad_->U_ptr()),
-            jcp.nb_ic, jcp.nb_oc,
-            alpha, alpha,
-            jcp.oc_block, jcp.ic_block,
-            jcp.ic_simd_block, jcp.oc_simd_block);
-
-    array_offset_calculator<float, 9> Us(
-            (float *)(scratchpad_->U_ptr() + U_size),
-            0, jcp.nb_ic, jcp.nb_oc,
-            alpha, alpha,
-            jcp.oc_block, jcp.ic_block,
-            jcp.ic_simd_block, jcp.oc_simd_block);
-
-    array_offset_calculator<float, 8> M((float *)(scratchpad_->M_ptr()),
-            jcp.nb_oc, alpha, alpha,
-            jcp.tile_block, jcp.oc_block,
-            jcp.nb_tile_block_ur, jcp.tile_block_ur * jcp.tile_4fma,
-            jcp.oc_simd_block);
-
-    array_offset_calculator<float, 8> V((float *)(scratchpad_->V_ptr()),
-            jcp.nb_ic, alpha, alpha,
-            jcp.tile_block, jcp.ic_block,
-            jcp.nb_tile_block_ur, jcp.tile_block_ur,
-            jcp.ic_simd_block * jcp.tile_4fma);
-
-    const int trans_buffer_size = alpha * alpha * jcp.tile_4fma
-        * jcp.ic_simd_block;
-    array_offset_calculator<float, 2> trans_buffer(
-        (float *)(scratchpad_->src_transpose_ptr()),
-        nthreads,
-        trans_buffer_size);
-
-    array_offset_calculator<float, 2> diff_bias_prv(
-            (float *)(scratchpad_->bias_ptr()), nthreads, jcp.oc);
-
-#pragma omp parallel
-    {
-        if (jcp.with_bias) {
-            parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) {
-                diff_bias_prv(ithr, ofm) = 0.0f;
-            });
-#pragma omp for nowait
-            for (int bofm = 0; bofm < jcp.oc / simd_w; bofm++) {
-                PRAGMA_OMP_SIMD()
-                for (int v = 0; v < simd_w; v++)
-                    diff_bias(bofm, v) = 0.0f;
-            }
-        }
-    }
-
-#pragma omp parallel
-    {
-        const int ithread = mkldnn_get_thread_num();
-        parallel_nd_in_omp(jcp.mb, jcp.nb_ic, jcp.ic_block,
-            [&](int img, int ifm1, int ifm2) {
-                float *transb = jcp.ver == ver_4fma
-                    ? &(trans_buffer(ithread, 0))
-                    : NULL;
-                diff_src_transform_bwd_weights_ver(img, jcp,
-                    &(diff_src(img, ifm1 * jcp.ic_block + ifm2,
-                            0, 0, 0)),
-                    &(V(ifm1, 0, 0, 0, ifm2, 0, 0, 0)),
-                    transb,
-                    kernel_->transpose_4fma_ker);
-        });
-    }
-
-#pragma omp parallel num_threads(nthreads)
-    {
-        parallel_nd_in_omp(jcp.mb, jcp.nb_oc, jcp.oc_block,
-            [&](int img, int ofm1, int ofm2) {
-                const int ithread = mkldnn_get_thread_num();
-                float *dbias = jcp.with_bias
-                    ? &(diff_bias_prv(ithread,
-                        simd_w * (ofm1 * jcp.oc_block + ofm2)))
-                    : NULL;
-                diff_dst_transform_bwd_weights_ver(img, jcp,
-                    &(diff_dst(img, ofm1 * jcp.oc_block + ofm2, 0, 0, 0)),
-                    &(M(ofm1, 0, 0, 0, ofm2, 0, 0, 0)), dbias);
-        });
-    }
-
-    size_t input_starts[max_threads_number];
-    size_t input_ends[max_threads_number];
-    int th_counter = 0;
-#pragma omp parallel firstprivate(th_counter) num_threads(nthreads)
-    {
-        parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, alpha, alpha, jcp.tile_block,
-            [&](int ifm1, int ofm1, int oj, int oi, int tile_block) {
-                int ithr = mkldnn_get_thread_num();
-                if (th_counter == 0) {
-                    input_starts[ithr] = (float *)&(Us(ithr, ifm1, ofm1,
-                            oj, oi, 0, 0, 0, 0)) - (float *)&(Us(ithr, 0, 0,
-                            0, 0, 0, 0, 0, 0));
-                    input_ends[ithr] = input_starts[ithr]
-                        + jcp.oc_block * jcp.ic_block
-                        * jcp.ic_simd_block * jcp.oc_simd_block;
-                }
-                else if (tile_block == 0) {
-                    input_ends[ithr] += jcp.oc_block * jcp.ic_block
-                        * jcp.ic_simd_block * jcp.oc_simd_block;
-                }
-
-                if (th_counter == 0 || tile_block == 0) {
-                    kernel_->gemm_loop_ker_first_iter(
-                        &(Us(ithr, ifm1, ofm1, oj, oi, 0, 0, 0, 0)),
-                        &(M(ofm1, oj, oi, tile_block, 0, 0, 0, 0)),
-                        &(V(ifm1, oj, oi, tile_block, 0, 0, 0, 0)));
-                } else {
-                    kernel_->gemm_loop_ker(
-                        &(Us(ithr, ifm1, ofm1, oj, oi, 0, 0, 0, 0)),
-                        &(M(ofm1, oj, oi, tile_block, 0, 0, 0, 0)),
-                        &(V(ifm1, oj, oi, tile_block, 0, 0, 0, 0)));
-                }
-                th_counter++;
-        });
-    }
-
-
-    // Reduce diff-weights
-    {
-        float *output = &(U(0, 0, 0, 0, 0, 0, 0, 0));
-        size_t nelems = jcp.ic * jcp.oc * alpha * alpha;
-        float *input_ptrs[max_threads_number];
-        for (int i = 0; i < nthreads; i++)
-            input_ptrs[i] = output + nelems * (i + 1);
-        subarray_sum(
-                nthreads, output, nelems, input_ptrs, input_starts, input_ends);
-    }
-
-    parallel_nd(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block,
-        [&](int ifm1, int ofm1, int ofm2, int ifm2) {
-            diff_weights_transform_bwd_weights(jcp,
-                    &(diff_weights(ofm1 * jcp.oc_block + ofm2,
-                            ifm1 * jcp.ic_block + ifm2,
-                            0, 0, 0, 0)),
-                    &(U(ifm1, ofm1, 0, 0, ofm2, ifm2, 0, 0)));
-    });
-
-#pragma omp parallel
-    if (jcp.with_bias) {
-#pragma omp for
-        for (int ofm1 = 0; ofm1 < jcp.oc / simd_w; ofm1++) {
-            for (int ithr = 0; ithr < nthreads; ithr++) {
-                float* base_bias_ptr = &(diff_bias(ofm1, 0));
-                float* base_bias_prv_ptr = &(diff_bias_prv(
-                            ithr * jcp.oc + ofm1 * simd_w));
-                PRAGMA_OMP_SIMD()
-                for (int ofm2 = 0; ofm2 < simd_w; ofm2++) {
-                    base_bias_ptr[ofm2] += base_bias_prv_ptr[ofm2];
-                }
-            }
-        }
-    }
-
-    _maybe_execute_diff_bias_copy();
-}
-
-void jit_avx512_common_convolution_winograd_bwd_weights_t::
-_execute_backward_weights_SDGtWo()
-{
-    const auto &jcp = kernel_->jcp;
-    const int nthreads = scratchpad_->num_threads();
-
-    auto diff_src_transform_bwd_weights_ver_tile = jcp.ver == ver_4fma ?
-            diff_src_transform_bwd_weights_tile<true> :
-            diff_src_transform_bwd_weights_tile<false>;
-    auto diff_dst_transform_bwd_weights_ver = jcp.with_bias
-                                  ? diff_dst_transform_bwd_weights_tile<true>
-                                  : diff_dst_transform_bwd_weights_tile<false>;
-
-    array_offset_calculator<float, 5> diff_src((float *)this->input_memory(0),
-            jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
-    array_offset_calculator<float, 5> diff_dst((float *)this->input_memory(1),
-            jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
-    array_offset_calculator<float, 6> diff_weights((float *)this->memory(0),
-            jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
-    array_offset_calculator<float, 3> diff_bias(
-            conf_.want_padded_bias() ? padded_bias_ : (float *)this->memory(1),
-            jcp.nb_oc, jcp.oc_block, simd_w);
-
-    array_offset_calculator<float, 8> Us((float *)(scratchpad_->U_ptr()),
-            0, jcp.nb_ic, alpha, alpha,
-            jcp.oc_block, jcp.ic_block,
-            jcp.ic_simd_block, jcp.oc_simd_block);
-
-    array_offset_calculator<float, 7> M((float *)(scratchpad_->M_ptr()),
-            0, alpha, alpha,
-            jcp.oc_block,
-            jcp.nb_tile_block_ur, jcp.tile_block_ur * jcp.tile_4fma,
-            jcp.oc_simd_block);
-
-    array_offset_calculator<float, 8> V((float *)(scratchpad_->V_ptr()),
-            0, jcp.nb_ic, alpha, alpha,
-            jcp.ic_block,
-            jcp.nb_tile_block_ur, jcp.tile_block_ur,
-            jcp.ic_simd_block * jcp.tile_4fma);
-
-    array_offset_calculator<float, 2> diff_bias_prv(
-            (float *)(scratchpad_->bias_ptr()),
-            nthreads, jcp.oc / jcp.nb_oc);
-
-    for (int ofm1 = 0; ofm1 < jcp.nb_oc; ++ofm1) {
-        int th_counter = 0;
-
-#pragma omp parallel
-        {
-            if (jcp.with_bias) {
-                parallel_nd_in_omp(nthreads, jcp.oc / jcp.nb_oc,
-                    [&](int ithr, int ofm) {
-                        diff_bias_prv(ithr, ofm) = 0.0f;
-                });
-#pragma omp for nowait
-                for (int bofm = 0; bofm < jcp.oc_block; bofm++) {
-                    PRAGMA_OMP_SIMD()
-                    for (int v = 0; v < simd_w; v++)
-                        diff_bias(ofm1, bofm, v) = 0.0f;
-                }
-            }
-        }
-
-#pragma omp parallel firstprivate(th_counter) num_threads(nthreads)
-#pragma omp for nowait
-        for (int tile_block = 0; tile_block < jcp.tile_block; tile_block++) {
-            int ithr = mkldnn_get_thread_num();
-            for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) {
-                for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
-                    diff_src_transform_bwd_weights_ver_tile(tile_block, jcp,
-                            &(diff_src(0, ifm1 * jcp.ic_block + ifm2, 0, 0, 0)),
-                            &(V(ithr, ifm1, 0, 0, ifm2, 0, 0, 0)),
-                            kernel_->transpose_4fma_ker);
-                }
-            }
-
-            for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++) {
-                float *dbias = jcp.with_bias
-                    ? &(diff_bias_prv(ithr, simd_w * ofm2))
-                    : NULL;
-                diff_dst_transform_bwd_weights_ver(tile_block, jcp,
-                        &(diff_dst(0, ofm1 * jcp.oc_block + ofm2, 0, 0, 0)),
-                        &(M(ithr, 0, 0, ofm2, 0, 0, 0)),
-                        dbias);
-            }
-
-            for (int ifm1 = 0; ifm1 < jcp.nb_ic; ifm1++) {
-                for (int oj = 0; oj < alpha; oj++) {
-                    for (int oi = 0; oi < alpha; oi++) {
-                        if (th_counter == 0)
-                            kernel_->gemm_loop_ker_first_iter(
-                                    &(Us(ithr, ifm1, oj, oi, 0, 0, 0, 0)),
-                                    &(M(ithr, oj, oi, 0, 0, 0, 0)),
-                                    &(V(ithr, ifm1, oj, oi, 0, 0, 0, 0)));
-                        else
-                            kernel_->gemm_loop_ker(
-                                    &(Us(ithr, ifm1, oj, oi, 0, 0, 0, 0)),
-                                    &(M(ithr, oj, oi, 0, 0, 0, 0)),
-                                    &(V(ithr, ifm1, oj, oi, 0, 0, 0, 0)));
-                    }
-                }
-            }
-            th_counter++;
-        }
-        // Reduce diff-weights
-        {
-            float *output = (float *)(scratchpad_->U_ptr());
-            size_t nelems
-                    = jcp.ic * (jcp.oc / jcp.nb_oc) * alpha * alpha;
-            float *input_ptrs[max_threads_number];
-            for (int i = 0; i < nthreads; i++) {
-                input_ptrs[i] = output + nelems * i;
-            }
-            array_sum(nthreads, output, nelems, input_ptrs);
-        }
-
-        parallel_nd(jcp.nb_ic, jcp.oc_block, jcp.ic_block,
-            [&](int ifm1, int ofm2, int ifm2) {
-            diff_weights_transform_bwd_weights(jcp,
-                    &(diff_weights(ofm1 * jcp.oc_block + ofm2,
-                            ifm1 * jcp.ic_block + ifm2,
-                            0, 0, 0, 0)),
-                    &(Us(0, ifm1, 0, 0, ofm2, ifm2, 0, 0)));
-        });
-
-#pragma omp parallel
-        if (jcp.with_bias) {
-#pragma omp for
-            for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++) {
-                for (int ithr = 0; ithr < nthreads; ithr++) {
-                    float* base_bias_ptr = &(diff_bias(ofm1, ofm2, 0));
-                    float* base_bias_prv_ptr = &(diff_bias_prv(
-                                ithr * jcp.oc_block * simd_w + ofm2 * simd_w));
-                    PRAGMA_OMP_SIMD()
-                    for (int ofm3 = 0; ofm3 < simd_w; ofm3++) {
-                        base_bias_ptr[ofm3] += base_bias_prv_ptr[ofm3];
-                    }
-                }
-            }
-        }
-    }
-
-    _maybe_execute_diff_bias_copy();
-}
-
-void jit_avx512_common_convolution_winograd_bwd_weights_t::
-_execute_backward_weights_SDGt_W()
-{
-    const auto &jcp = kernel_->jcp;
-    const int nthreads = scratchpad_->num_threads();
-
-    auto diff_src_transform_bwd_weights_ver_tile = jcp.ver == ver_4fma ?
-            diff_src_transform_bwd_weights_tile<true> :
-            diff_src_transform_bwd_weights_tile<false>;
-    auto diff_dst_transform_bwd_weights_ver = jcp.with_bias
-                                  ? diff_dst_transform_bwd_weights_tile<true>
-                                  : diff_dst_transform_bwd_weights_tile<false>;
-
-    array_offset_calculator<float, 5> diff_src((float *)this->input_memory(0),
-            jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
-    array_offset_calculator<float, 5> diff_dst((float *)this->input_memory(1),
-            jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
-    array_offset_calculator<float, 6> diff_weights((float *)this->memory(0),
-            jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
-    array_offset_calculator<float, 2> diff_bias(
-            conf_.want_padded_bias() ? padded_bias_ : (float *)this->memory(1),
-            jcp.oc / simd_w, simd_w);
-
-    array_offset_calculator<float, 8> U((float *)(scratchpad_->U_ptr()),
-            jcp.nb_oc, jcp.nb_ic,
-            alpha, alpha,
-            jcp.oc_block, jcp.ic_block,
-            jcp.ic_simd_block, jcp.oc_simd_block);
-
-    array_offset_calculator<float, 9> Us((float *)(scratchpad_->U_ptr()),
-            0, jcp.nb_oc, jcp.nb_ic,
-            alpha, alpha,
-            jcp.oc_block, jcp.ic_block,
-            jcp.ic_simd_block, jcp.oc_simd_block);
-
-    array_offset_calculator<float, 8> M((float *)(scratchpad_->M_ptr()),
-            0, jcp.nb_oc, alpha, alpha, jcp.oc_block,
-            jcp.nb_tile_block_ur, jcp.tile_block_ur * jcp.tile_4fma,
-            jcp.oc_simd_block);
-
-    array_offset_calculator<float, 8> V((float *)(scratchpad_->V_ptr()),
-            0, jcp.nb_ic, alpha, alpha, jcp.ic_block,
-            jcp.nb_tile_block_ur, jcp.tile_block_ur,
-            jcp.ic_simd_block * jcp.tile_4fma);
-
-    array_offset_calculator<float, 2> diff_bias_prv(
-            (float *)(scratchpad_->bias_ptr()),
-            nthreads, jcp.oc);
-
-#pragma omp parallel
-    {
-        if (jcp.with_bias) {
-            parallel_nd_in_omp(nthreads, jcp.oc,
-                [&](int ithr, int ofm) {
-                    diff_bias_prv(ithr, ofm) = 0.0f;
-            });
-#pragma omp for nowait
-            for (int bofm = 0; bofm < jcp.oc / simd_w; bofm++) {
-                PRAGMA_OMP_SIMD()
-                for (int v = 0; v < simd_w; v++)
-                    diff_bias(bofm, v) = 0.0f;
-            }
-        }
-    }
-
-    int th_counter = 0;
-#pragma omp parallel firstprivate(th_counter) num_threads(nthreads)
-#pragma omp for nowait
-    for (int tile_block = 0; tile_block < jcp.tile_block; tile_block++) {
-        int ithr = mkldnn_get_thread_num();
-
-        for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) {
-            for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
-                diff_src_transform_bwd_weights_ver_tile(tile_block, jcp,
-                        &(diff_src(0, ifm1 * jcp.ic_block + ifm2,
-                                0, 0, 0)),
-                        &(V(ithr, ifm1, 0, 0, ifm2, 0, 0, 0)),
-                        kernel_->transpose_4fma_ker);
-            }
-        }
-
-        for (int ofm1 = 0; ofm1 < jcp.nb_oc; ofm1++) {
-            for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++) {
-                float *dbias = jcp.with_bias
-                    ? &(diff_bias_prv(ithr,
-                                simd_w * (ofm1 * jcp.oc_block + ofm2)))
-                    : NULL;
-                diff_dst_transform_bwd_weights_ver(tile_block, jcp,
-                        &(diff_dst(0, ofm1 * jcp.oc_block + ofm2,
-                                0, 0, 0)),
-                        &(M(ithr, ofm1, 0, 0, ofm2, 0, 0, 0)),
-                        dbias);
-            }
-        }
-
-        for (int ofm1 = 0; ofm1 < jcp.nb_oc; ofm1++) {
-            for (int oj = 0; oj < alpha; oj++) {
-                for (int oi = 0; oi < alpha; oi++) {
-                    for (int ifm1 = 0; ifm1 < jcp.nb_ic; ifm1++) {
-                        if (th_counter == 0)
-                            kernel_->gemm_loop_ker_first_iter(
-                                    &(Us(ithr, ofm1, ifm1, oj, oi, 0, 0, 0, 0)),
-                                    &(M(ithr, ofm1, oj, oi, 0, 0, 0, 0)),
-                                    &(V(ithr, ifm1, oj, oi, 0, 0, 0, 0)));
-                        else
-                            kernel_->gemm_loop_ker(
-                                    &(Us(ithr, ofm1, ifm1, oj, oi, 0, 0, 0, 0)),
-                                    &(M(ithr, ofm1, oj, oi, 0, 0, 0, 0)),
-                                    &(V(ithr, ifm1, oj, oi, 0, 0, 0, 0)));
-                    }
-                }
-            }
-        }
-        th_counter++;
-    }
-
-    // Reduce diff-weights
-    {
-        float *output = (float *)(scratchpad_->U_ptr());
-        size_t nelems = jcp.ic * jcp.oc * alpha * alpha;
-        float *input_ptrs[max_threads_number];
-        for (int i = 0; i < nthreads; i++) {
-            input_ptrs[i] = output + nelems * i;
-        }
-        array_sum(nthreads, output, nelems, input_ptrs);
-    }
-
-    parallel_nd(jcp.nb_oc, jcp.nb_ic, jcp.oc_block, jcp.ic_block,
-        [&](int ofm1, int ifm1, int ofm2, int ifm2) {
-        diff_weights_transform_bwd_weights(jcp,
-                &(diff_weights(ofm1 * jcp.oc_block + ofm2,
-                        ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)),
-                &(U(ofm1, ifm1, 0, 0, ofm2, ifm2, 0, 0)));
-    });
-
-#pragma omp parallel
-    if (jcp.with_bias) {
-#pragma omp for
-        for (int ofm1 = 0; ofm1 < jcp.oc / simd_w; ofm1++) {
-            for (int ithr = 0; ithr < nthreads; ithr++) {
-                float* base_bias_ptr = &(diff_bias(ofm1, 0));
-                float* base_bias_prv_ptr = &(diff_bias_prv(
-                            ithr * jcp.oc + ofm1 * simd_w));
-                PRAGMA_OMP_SIMD()
-                for (int ofm2 = 0; ofm2 < simd_w; ofm2++) {
-                    base_bias_ptr[ofm2] += base_bias_prv_ptr[ofm2];
-                }
-            }
-        }
-    }
-
-    _maybe_execute_diff_bias_copy();
-}
 }
 }
 }