Publishing 2019 R1 content

diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp
index da38121..099f1bd 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp
 * limitations under the License.
 *******************************************************************************/
 
-#include "mkldnn_types.h"
-
 #include "c_types_map.hpp"
-#include "jit_avx512_common_1x1_convolution.hpp"
-#include "utils.hpp"
 #include "mkldnn_thread.hpp"
 #include "type_helpers.hpp"
+#include "utils.hpp"
 
 #include "jit_generator.hpp"
 
+#include "jit_avx512_common_1x1_convolution.hpp"
+
 namespace mkldnn {
 namespace impl {
 namespace cpu {
 
 using namespace mkldnn::impl::status;
 using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
 using namespace mkldnn::impl::utils;
 
 #define data_blk_off(f, n, c, h, w) \
@@ -37,74 +37,84 @@ using namespace mkldnn::impl::utils;
     ? (f).blk_off(n, c, w) \
     : (f).blk_off(n, c, h, w))
 
+
 namespace {
 template <typename T, typename U>
 void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end,
     T nx, T &nx_start, T &nx_end, T nx_divider)
 {
-    const T grp_size = utils::div_up(nthr, nx_divider);
-    const T grp_count = utils::div_up(nthr, grp_size);
-
-    T grp = ithr / grp_size;
-    T grp_ithr = ithr % grp_size;
-    T grp_nthr = grp_size;
-    T first_grps = nthr % grp_count;
-    if (first_grps > 0 && grp >= first_grps) {
-        ithr -= first_grps * grp_size;
-        grp_nthr--;
-        grp = ithr / grp_nthr + first_grps;
-        grp_ithr = ithr % grp_nthr;
+    const int grp_count = nstl::min(nx_divider, nthr);
+    const int grp_size_big = nthr / grp_count + 1;
+    const int grp_size_small = nthr / grp_count;
+    const int n_grp_big = nthr % grp_count;
+    const int threads_in_big_groups = n_grp_big * grp_size_big;
+
+    const int ithr_bound_distance = ithr - threads_in_big_groups;
+    T grp, grp_ithr, grp_nthr;
+    if (ithr_bound_distance < 0) { // ithr in first groups
+        grp = ithr / grp_size_big;
+        grp_ithr = ithr % grp_size_big;
+        grp_nthr = grp_size_big;
+    } else { // ithr in last groups
+        grp = n_grp_big + ithr_bound_distance / grp_size_small;
+        grp_ithr = ithr_bound_distance % grp_size_small;
+        grp_nthr = grp_size_small;
     }
+
     balance211(nx, grp_count, grp, nx_start, nx_end);
     balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end);
 }
 }
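
The old balance2D derived a group size with div_up and then patched up the
remainder groups in place; the rewritten version above fixes the group count
at min(nx_divider, nthr) up front and gives the first nthr % grp_count groups
one extra thread each, so every thread lands in exactly one slot. A minimal
standalone sketch of that arithmetic with an assert-based check; split() and
the driver in main() are illustrative only, the constant names mirror the
patch:

    #include <algorithm>
    #include <cassert>

    // Same arithmetic as the patched balance2D, isolated for inspection.
    static void split(int nthr, int nx_divider, int ithr,
            int &grp, int &grp_ithr, int &grp_nthr) {
        const int grp_count = std::min(nx_divider, nthr);
        const int grp_size_big = nthr / grp_count + 1;
        const int grp_size_small = nthr / grp_count;
        const int n_grp_big = nthr % grp_count;
        const int d = ithr - n_grp_big * grp_size_big; // ithr_bound_distance
        if (d < 0) { // thread belongs to one of the bigger, leading groups
            grp = ithr / grp_size_big;
            grp_ithr = ithr % grp_size_big;
            grp_nthr = grp_size_big;
        } else {     // thread belongs to one of the smaller, trailing groups
            grp = n_grp_big + d / grp_size_small;
            grp_ithr = d % grp_size_small;
            grp_nthr = grp_size_small;
        }
    }

    int main() {
        // e.g. 10 threads over nx_divider = 4 -> group sizes 3, 3, 2, 2
        for (int ithr = 0; ithr < 10; ++ithr) {
            int grp, grp_ithr, grp_nthr;
            split(10, 4, ithr, grp, grp_ithr, grp_nthr);
            assert(grp < 4 && grp_ithr < grp_nthr);
        }
        return 0;
    }
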
 /* convolution forward */
 
-template <bool with_relu, data_type_t src_type, data_type_t wei_type,
-        data_type_t dst_type>
-void _jit_avx512_common_1x1_convolution_fwd_t
-    <with_relu, src_type, wei_type, dst_type>::execute_forward()
-{
+template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
+void jit_avx512_common_1x1_convolution_fwd_t<src_type, wei_type, dst_type>::
+execute_forward() const {
     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
     auto weights =
         reinterpret_cast<const wei_data_t *>(this->input_memory(1));
     auto bias = reinterpret_cast<const dst_data_t *>(this->input_memory(2));
     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
 
+    auto scratchpad = this->scratchpad();
+
     auto &jcp = kernel_->jcp;
-    if (conf_.want_padded_bias()) {
-        assert(jcp.ngroups == 1);
-        for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
-            padded_bias_[oc] = bias[oc];
-        bias = padded_bias_;
+    if (pd()->wants_padded_bias()) {
+        auto padded_bias = scratchpad.template get<dst_data_t>(
+                key_conv_padded_bias);
+        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
+        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+                jcp.oc - jcp.oc_without_padding);
+        bias = padded_bias;
     }
 
     parallel(0, [&](const int ithr, const int nthr) {
-        execute_forward_thr(ithr, nthr, src, weights, bias, dst);
+        execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad);
     });
+
+    if (pd()->wants_zero_pad_dst())
+        output_memory_primitive(0)->zero_pad();
 }
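
For reference, the wants_padded_bias() branch above amounts to copying the
user-visible channels into the scratchpad and zero-filling the padded tail of
the block-rounded bias. A pointer-level sketch; pad_bias is a hypothetical
helper, not part of the library:

    // oc        = jcp.oc_without_padding (user-visible channels)
    // oc_padded = jcp.oc (rounded up to the 16-channel block)
    static void pad_bias(float *padded, const float *bias,
            int oc, int oc_padded) {
        for (int i = 0; i < oc; ++i) padded[i] = bias[i];      // array_copy
        for (int i = oc; i < oc_padded; ++i) padded[i] = 0.f;  // array_set
    }
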
 
-template <bool with_relu, data_type_t src_type, data_type_t wei_type,
-        data_type_t dst_type>
-void _jit_avx512_common_1x1_convolution_fwd_t
-    <with_relu, src_type, wei_type, dst_type>::execute_forward_thr(
-            const int ithr, const int nthr,
-            const src_data_t *src, const wei_data_t *weights,
-            const dst_data_t *bias, dst_data_t *dst)
-{
-    const memory_desc_wrapper src_d(conf_.src_pd());
-    const memory_desc_wrapper dst_d(conf_.dst_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
+template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
+void jit_avx512_common_1x1_convolution_fwd_t<src_type, wei_type, dst_type>::
+execute_forward_thr(const int ithr, const int nthr, const src_data_t *src,
+        const wei_data_t *weights, const dst_data_t *bias, dst_data_t *dst,
+        const memory_tracking::grantor_t &scratchpad) const {
+    const memory_desc_wrapper src_d(pd()->src_pd());
+    const memory_desc_wrapper dst_d(pd()->dst_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+
+    auto rtus_space = scratchpad.get<src_data_t>(key_conv_rtus_space);
 
     const int ndims = src_d.ndims();
-    const int stride_h = (ndims == 3) ? 1 : conf_.cdesc()->strides[0];
-    const int stride_w = conf_.cdesc()->strides[ndims - 3];
-    const int pad_t = (ndims == 3) ? 0 : conf_.cdesc()->padding[0][0];
-    const int pad_l = conf_.cdesc()->padding[0][ndims - 3];
+    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
+    const int stride_w = pd()->desc()->strides[ndims - 3];
+    const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
+    const int pad_l = pd()->desc()->padding[0][ndims - 3];
 
-    auto &jcp = kernel_->jcp;
-    const int MB = conf_.MB();
+    const auto &jcp = kernel_->jcp;
+    const int MB = pd()->MB();
     const int work_amount = MB * jcp.ngroups * jcp.nb_bcast;
 
     auto step = [](int default_step, int remaining, int tail_step) {
@@ -179,13 +189,13 @@ void _jit_avx512_common_1x1_convolution_fwd_t
 
         p.output_data = &dst[dst_off];
         p.bias_data = &bias[_ocb * jcp.oc_block];
-        p.load_data = &weights[conf_.with_groups()
+        p.load_data = &weights[pd()->with_groups()
             ? weights_d.blk_off(g, ocb, icb)
             : weights_d.blk_off(ocb, icb)];
 
         const int _icb = g * nb_ic + icb;
-        if (conf_.rtus_.reduce_src_) {
-            rp.ws = scratch_ + ithr * ws_per_thread_
+        if (pd()->rtus_.reduce_src_) {
+            rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_
                 + _icb * jcp.is * jcp.ic_block;
             if (ocb == ocb_start) {
                 rp.src = src + data_blk_off(src_d, n, _icb, ih, iw);
@@ -274,40 +284,39 @@ void _jit_avx512_common_1x1_convolution_fwd_t
 }
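
Where pd()->rtus_.reduce_src_ is set, the kernel reads from a dense per-thread
workspace (rtus_space) rather than the strided source, and the rtus driver
fills that workspace. A conceptual sketch of the gather for one 16-channel
nChw16c block, assuming the workspace keeps the (spatial, channel-in-block)
layout the kernel expects; the function name and loop structure are
illustrative, not the JIT driver itself:

    // Copy the oh*ow strided pixels of one 16-channel block into a dense,
    // unit-stride workspace. 'src' points at the start of the block's
    // spatial plane in nChw16c layout ((h*iw + w)*16 + c).
    static void rtus_gather(const float *src, float *ws, int oh, int ow,
            int iw, int stride_h, int stride_w, int ic_block) {
        for (int y = 0; y < oh; ++y)
            for (int x = 0; x < ow; ++x)
                for (int c = 0; c < ic_block; ++c)
                    ws[(y * ow + x) * ic_block + c] =
                            src[(y * stride_h * iw + x * stride_w)
                                    * ic_block + c];
    }
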
 
 
-template struct _jit_avx512_common_1x1_convolution_fwd_t<true, data_type::f32>;
-template struct _jit_avx512_common_1x1_convolution_fwd_t<false, data_type::f32>;
-template struct _jit_avx512_common_1x1_convolution_fwd_t<false, data_type::s16,
-    data_type::s16, data_type::s32>;
-template struct _jit_avx512_common_1x1_convolution_fwd_t<true, data_type::s16,
+template struct jit_avx512_common_1x1_convolution_fwd_t<data_type::f32>;
+template struct jit_avx512_common_1x1_convolution_fwd_t<data_type::s16,
     data_type::s16, data_type::s32>;
 /* convolution backward w.r.t. data */
 
 template <data_type_t diff_dst_type, data_type_t wei_type,
-    data_type_t diff_src_type>
-void _jit_avx512_common_1x1_convolution_bwd_data_t
-    <diff_dst_type, wei_type, diff_src_type>::execute_backward_data()
-{
+         data_type_t diff_src_type>
+void jit_avx512_common_1x1_convolution_bwd_data_t<diff_dst_type, wei_type,
+     diff_src_type>::execute_backward_data() const {
     auto diff_dst = reinterpret_cast<const diff_dst_data_t *>
         (this->input_memory(0));
     auto weights = reinterpret_cast<const wei_data_t *>
         (this->input_memory(1));
     auto diff_src = reinterpret_cast<diff_src_data_t *>(this->memory());
 
-    const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
-    const memory_desc_wrapper diff_src_d(conf_.diff_src_pd());
+    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+    const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
+
+    auto rtus_space = scratchpad().template get<diff_src_data_t>(
+            key_conv_rtus_space);
 
     const int ndims = diff_src_d.ndims();
     const auto &jcp = kernel_->jcp;
-    const int MB = conf_.MB();
+    const int MB = pd()->MB();
 
     // TODO (Roma): remove this restriction
     assert(jcp.stride_w == 1 && jcp.stride_h == 1);
 
-    const int stride_h = (ndims == 3) ? 1 : conf_.desc()->strides[0];
-    const int stride_w = conf_.desc()->strides[ndims - 3];
-    const int pad_t = (ndims == 3) ? 0 : conf_.desc()->padding[0][0];
-    const int pad_l = conf_.desc()->padding[0][ndims - 3];
+    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
+    const int stride_w = pd()->desc()->strides[ndims - 3];
+    const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
+    const int pad_l = pd()->desc()->padding[0][ndims - 3];
 
     const int nb_ic = jcp.nb_load;
     const int nb_oc = jcp.nb_reduce;
@@ -376,8 +385,9 @@ void _jit_avx512_common_1x1_convolution_bwd_data_t
 
                     const int _icb = g * nb_ic + icb;
                     rp.src = diff_src + data_blk_off(diff_src_d, n, _icb, ih, iw);
-                    if (conf_.rtus_.reduce_src_) {
-                        rp.ws = scratch_ + ithr * ws_per_thread_;
+                    if (pd()->rtus_.reduce_src_) {
+                        rp.ws = rtus_space
+                            + ithr * pd()->rtus_.space_per_thread_;
                         p.output_data = rp.ws;
                     } else
                         p.output_data = rp.src;
@@ -395,7 +405,7 @@ void _jit_avx512_common_1x1_convolution_bwd_data_t
                         size_t diff_dst_off = data_blk_off(diff_dst_d, n, _ocb, oh, ow);
                         p.bcast_data = &diff_dst[diff_dst_off];
 
-                        p.load_data = &weights[conf_.with_groups()
+                        p.load_data = &weights[pd()->with_groups()
                             ? weights_d.blk_off(g, ocb, icb)
                             : weights_d.blk_off(ocb, icb)];
 
@@ -406,7 +416,7 @@ void _jit_avx512_common_1x1_convolution_bwd_data_t
 
                         kernel_->jit_ker(&p);
                     }
-                    if (conf_.rtus_.reduce_src_)
+                    if (pd()->rtus_.reduce_src_)
                         rtus_driver_->ker_(&rp);
                 }
             }
@@ -414,87 +424,81 @@ void _jit_avx512_common_1x1_convolution_bwd_data_t
     });
 }
 
-template struct _jit_avx512_common_1x1_convolution_bwd_data_t<data_type::f32>;
-template struct _jit_avx512_common_1x1_convolution_bwd_data_t<data_type::s16,
+template struct jit_avx512_common_1x1_convolution_bwd_data_t<data_type::f32>;
+template struct jit_avx512_common_1x1_convolution_bwd_data_t<data_type::s16,
     data_type::s16, data_type::s32>;
 
 /* convolution backward w.r.t. weights */
 
 #define wht_blk_off(d, g, ...) \
-        (conf_.with_groups() \
+        (pd()->with_groups() \
          ? (d).blk_off((g), __VA_ARGS__) \
          : (d).blk_off(__VA_ARGS__))
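
Usage, as it appears further down in this patch: the macro resolves a weights
offset with or without the leading groups axis, so call sites never branch on
pd()->with_groups() themselves:

    const size_t off = wht_blk_off(diff_weights_d, g, oc_b, ic_b);
    data_t *d = diff_weights + off;
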
 
 jit_avx512_common_1x1_convolution_bwd_weights_t ::
-        jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *pd,
+        jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd,
                 const input_vector &inputs, const output_vector &outputs)
-    : cpu_primitive_t(&conf_, inputs, outputs)
-    , conf_(*pd), kernel_(nullptr), acc_ker_(nullptr), reducer_bias_(nullptr)
-    , trans_kernel_(nullptr), rtus_driver_(nullptr), ws_per_thread_(0)
-    , scratch_(nullptr), padded_bias_(nullptr), bctx_(nullptr)
-    , tr_src_(nullptr), ws_reduction_(nullptr)
+    : cpu_primitive_t(apd, inputs, outputs)
+    , kernel_(nullptr), acc_ker_(nullptr), reducer_bias_(nullptr)
+    , trans_kernel_(nullptr), rtus_driver_(nullptr)
 {
-    kernel_ = new jit_avx512_common_1x1_conv_kernel(conf_.jcp_, *conf_.attr());
-
-    const auto &jcp = kernel_->jcp;
-
-    const int wei_size = jcp.ngroups * jcp.oc * jcp.ic;
-    ws_reduction_ =
-        (data_t *)malloc((jcp.nthr_mb - 1) * wei_size * sizeof(data_t), 64);
+    kernel_ = new jit_avx512_common_1x1_conv_kernel(pd()->jcp_, *pd()->attr());
     acc_ker_ = new cpu_accumulator_1d_t<data_type::f32>();
+    reducer_bias_ = new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_);
+    init_rtus_driver<avx512_common>(this);
 
-    if (conf_.with_bias()) {
-        const size_t max_buffer_size = jcp.nthr * 3 * 5 * 5 * 16 * 16;
-        reducer_bias_ = new cpu_reducer_t<data_type::f32>(
-                reduce_balancer_t(jcp.nthr, jcp.oc_block,
-                        jcp.ngroups * jcp.nb_load, jcp.mb, max_buffer_size));
-
-        if (conf_.want_padded_bias()) {
-            assert(jcp.ngroups == 1);
-            padded_bias_ = (data_t *)malloc(sizeof(data_t) * jcp.oc, 64);
-        }
-    }
+    const auto &jcp = kernel_->jcp;
 
     if (jcp.transpose_src) {
-        const ptrdiff_t tr_src_size = (ptrdiff_t)jcp.nthr_mb
-            * (ptrdiff_t)jcp.ngroups * (ptrdiff_t)jcp.ic * jcp.tr_is;
-        tr_src_ = (data_t *)malloc(tr_src_size * sizeof(data_t), 64);
-        parallel_nd(tr_src_size, [&](ptrdiff_t i) { tr_src_[i] = 0; });
         auto tp = jit_transpose4x16_src_t();
         tp.src_pf0_distance = 4;
         tp.tr_src_pf0_distance = 0;
         tp.src_pf1 = true;
         tp.tr_src_pf1 = false;
         trans_kernel_ = new jit_transpose4x16_src(&jcp, &tp);
-
-        bctx_ = (simple_barrier::ctx_t *)malloc(
-                jcp.nthr * sizeof(simple_barrier::ctx_t), 64);
-        for (int i = 0; i < jcp.nthr; ++i)
-            simple_barrier::ctx_init(&bctx_[i]);
     }
-
-    init_rtus_driver<avx512_common>(this);
 }
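
The mallocs deleted from this constructor (wei_reduction, the padded bias,
tr_src, and the transposition barriers) do not disappear: in this version of
the library they are booked on the primitive descriptor's scratchpad registry
and fetched by key at execution time. A hedged sketch of what that booking
looks like, with sizes carried over from the removed allocations; the real
registration lives in the pd's init path, which this hunk does not show:

    // Sketch only -- sizes mirror the mallocs deleted above.
    auto scratchpad = scratchpad_registry().registrar();
    if (jcp_.nthr_mb > 1)
        scratchpad.book(key_conv_wei_reduction, sizeof(data_t)
                * (jcp_.nthr_mb - 1) * jcp_.ngroups * jcp_.oc * jcp_.ic);
    if (jcp_.transpose_src) {
        scratchpad.book(key_conv_tr_src, sizeof(data_t) * jcp_.nthr_mb
                * jcp_.ngroups * jcp_.ic * jcp_.tr_is);
        scratchpad.book(key_conv_tr_src_bctx,
                sizeof(simple_barrier::ctx_t) * jcp_.nthr);
    }
    if (wants_padded_bias())
        scratchpad.book(key_conv_padded_bias, sizeof(data_t) * jcp_.oc);
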
 
-void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights()
+void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights() const
 {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto diff_weights = reinterpret_cast<data_t *>(this->memory(0));
     auto diff_bias_in = reinterpret_cast<data_t *>(this->memory(1));
-    data_t *diff_bias = conf_.want_padded_bias() ? padded_bias_ : diff_bias_in;
 
-    const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
-    const memory_desc_wrapper src_d(conf_.src_pd());
-    const memory_desc_wrapper diff_weights_d(conf_.diff_weights_pd(0));
+    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+    const memory_desc_wrapper src_d(pd()->src_pd());
+    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
 
     const auto &jcp = kernel_->jcp;
+
+    const auto scratchpad = this->scratchpad();
+
+    auto rtus_space = scratchpad.get<data_t>(key_conv_rtus_space);
+    data_t *diff_bias = pd()->wants_padded_bias()
+        ? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in;
+    auto wei_reduction = scratchpad.get<data_t>(key_conv_wei_reduction);
+
+    /* prepare src transposition barriers */
+    auto tr_src = scratchpad.get<data_t>(key_conv_tr_src);
+    auto tr_src_bctx = scratchpad.get<simple_barrier::ctx_t>(
+            key_conv_tr_src_bctx);
+    if (jcp.transpose_src) {
+        for (int i = 0; i < jcp.nthr; ++i)
+            simple_barrier::ctx_init(&tr_src_bctx[i]);
+    }
+
     const int ndims = src_d.ndims();
     const int wei_size = jcp.ngroups * jcp.oc * jcp.ic;
 
     simple_barrier::ctx_t reduction_barrier;
     simple_barrier::ctx_init(&reduction_barrier);
 
+    const auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
+            prefix_reducer_bia);
+    auto rb = this->reducer_bias_;
+    rb->init(reducer_bia_scratchpad);
+
     // TODO (Roma): remove this restriction
     assert(jcp.stride_w == 1 && jcp.stride_h == 1);
 
@@ -507,10 +511,10 @@ void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights()
     const int sp_nb = jcp.nb_reduce;
     const int mb_sp_work = jcp.mb * sp_nb;
 
-    const int stride_h = (ndims == 3) ? 1 : conf_.desc()->strides[0];
-    const int stride_w = conf_.desc()->strides[ndims - 3];
-    const int pad_t = (ndims == 3) ? 0 : conf_.desc()->padding[0][0];
-    const int pad_l = conf_.desc()->padding[0][ndims - 3];
+    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
+    const int stride_w = pd()->desc()->strides[ndims - 3];
+    const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
+    const int pad_l = pd()->desc()->padding[0][ndims - 3];
 
     auto step = [](int default_step, int remaining, int tail_step) {
         assert(default_step <= tail_step);
@@ -548,7 +552,7 @@ void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights()
 
         const int src1_off = data_blk_off(src_d, img, _ic, ih, iw);
         data_t *src1 = (data_t *)&src[src1_off];
-        data_t *tr_src1 = &tr_src_[tr_src_off(ithr_mb, ic_b_tr, is)];
+        data_t *tr_src1 = &tr_src[tr_src_off(ithr_mb, ic_b_tr, is)];
 
         assert(jcp.ic_block == 16);
         const int src_stride = jcp.is * jcp.ic_block;
@@ -611,9 +615,8 @@ void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights()
         const int oc_b_work = oc_b_end - oc_b_start;
         const int ic_b_work = ic_b_end - ic_b_start;
 
-        data_t *diff_wei = ithr_mb == 0 ?
-                diff_weights :
-                ws_reduction_ + (ithr_mb - 1) * wei_size;
+        data_t *diff_wei = ithr_mb == 0
+            ? diff_weights : wei_reduction + (ithr_mb - 1) * wei_size;
 
         int sp_b_step = 0;
         for (int mb_sp_b = mb_sp_b_start; mb_sp_b < mb_sp_b_end;
@@ -634,7 +637,7 @@ void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights()
                     if (jcp.transpose_src) {
                         if (jcp.nthr_oc_b > 1)
                             simple_barrier::barrier(
-                                    &bctx_[ithr_but_oc], jcp.nthr_oc_b);
+                                    &tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b);
                         const int sp_size
                                 = nstl::min(sp_b_step * jcp.reduce_block,
                                         jcp.is - sp_b * jcp.reduce_block);
@@ -642,7 +645,7 @@ void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights()
                             bcast_step, ithr_oc_b, jcp.nthr_oc_b, ic_b_start);
                         if (jcp.nthr_oc_b > 1)
                             simple_barrier::barrier(
-                                    &bctx_[ithr_but_oc], jcp.nthr_oc_b);
+                                    &tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b);
                     }
 
                     for (int oc_b = oc_b_start; oc_b < oc_b_end;
@@ -660,7 +663,7 @@ void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights()
                         store_to = diff_wei + off;
 
                         const data_t *diff_src = jcp.transpose_src ?
-                                &tr_src_[tr_src_off(ithr_mb, _ic_b_tr, 0)] :
+                                &tr_src[tr_src_off(ithr_mb, _ic_b_tr, 0)] :
                                 &src[src_d.blk_off(img, _ic_b)];
 
                         int sp_b_end = sp_b + sp_b_step;
@@ -690,7 +693,7 @@ void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights()
                         int sp = sp_b * jcp.reduce_block;
                         p.load_data = pdiff_dst + sp * jcp.oc_block;
 
-                        if (conf_.rtus_.reduce_src_) {
+                        if (pd()->rtus_.reduce_src_) {
                             const int oh = sp / jcp.ow;
                             const int ow = sp % jcp.ow;
 
@@ -698,8 +701,9 @@ void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights()
                             const int iw = nstl::max(ow * stride_w - pad_l, 0);
                             rp.iw_start = iw;
 
-                            rp.ws = scratch_ + ithr * ws_per_thread_
-                                    + sp * jcp.ic_block;
+                            rp.ws = rtus_space
+                                + ithr * pd()->rtus_.space_per_thread_
+                                + sp * jcp.ic_block;
 
                             if (ndims == 3)
                                 rp.src = local_src + iw
@@ -720,7 +724,7 @@ void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights()
             }
         }
 
-        /* diff_weights[:] += sum(ws_reduction_[thr_mb][:]) */
+        /* diff_weights[:] += sum(wei_reduction[thr_mb][:]) */
         if (jcp.nthr_mb > 1) {
             simple_barrier::barrier(&reduction_barrier, jcp.nthr);
             const int work = g_work * oc_b_work * ic_b_work;
@@ -747,7 +751,7 @@ void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights()
                     const size_t off
                             = wht_blk_off(diff_weights_d, g, oc_b, ic_b);
                     data_t *d = diff_weights + off;
-                    data_t *s = ws_reduction_ + (thr_mb - 1) * wei_size + off;
+                    data_t *s = wei_reduction + (thr_mb - 1) * wei_size + off;
 
                     acc_ker_->accumulate(d, s, acc_size);
 
@@ -760,11 +764,10 @@ void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights()
     };
 
     auto ker_bias = [&](int ithr, int nthr) {
-        auto rb = this->reducer_bias_;
-        assert(nthr == rb->balancer_.nthr_);
+        assert(nthr == rb->balancer().nthr_);
 
-        const int b_job_start = rb->balancer_.ithr_job_off(ithr);
-        const int b_njobs = rb->balancer_.ithr_njobs(ithr);
+        const int b_job_start = rb->balancer().ithr_job_off(ithr);
+        const int b_njobs = rb->balancer().ithr_njobs(ithr);
 
         if (b_njobs == 0)
             return;
@@ -772,8 +775,8 @@ void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights()
         /* reduction dimension */
         int img_start{ 0 }, img_end{ 0 };
 
-        balance211(jcp.mb, rb->balancer_.nthr_per_group_,
-                rb->balancer_.id_in_group(ithr), img_start, img_end);
+        balance211(jcp.mb, rb->balancer().nthr_per_group_,
+                rb->balancer().id_in_group(ithr), img_start, img_end);
 
         /* jobs */
         int g_start{ 0 }, ocb_start{ 0 };
@@ -786,8 +789,9 @@ void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights()
                 const size_t _oc = g * jcp.nb_load + ocb;
 
                 const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)];
-                data_t *d_bias = &rb->get_local_ptr(
-                        ithr, diff_bias)[b_job_loc * rb->balancer_.job_size_];
+                data_t *d_bias = rb->get_local_ptr(ithr, diff_bias,
+                        reducer_bia_scratchpad)
+                    + b_job_loc * rb->balancer().job_size_;
 
                 if (img == img_start)
                     for (int o = 0; o < 16; ++o)
@@ -803,20 +807,19 @@ void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights()
                 nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_load);
             }
         }
-        rb->reduce(ithr, diff_bias);
+        rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
     };
 
     parallel(jcp.nthr, [&](const int ithr, const int nthr) {
         ker(ithr, jcp.nthr);
-        if (conf_.with_bias())
+        if (pd()->with_bias())
             ker_bias(ithr, jcp.nthr);
     });
 
     /* TODO: put this in ker_bias */
-    if (conf_.want_padded_bias()) {
+    if (pd()->wants_padded_bias()) {
         assert(jcp.ngroups == 1);
-        for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
-            diff_bias_in[oc] = diff_bias[oc];
+        utils::array_copy(diff_bias_in, diff_bias, jcp.oc_without_padding);
     }
 }