Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx2_1x1_convolution.cpp
index 7a6e17c..5f888a2 100644 (file)
 * limitations under the License.
 *******************************************************************************/
 
-#include <cstring>
-#include <mkldnn_types.h>
-#include <iostream>
-#include "mkldnn_types.h"
-
 #include "c_types_map.hpp"
-#include "jit_avx2_1x1_convolution.hpp"
-#include "utils.hpp"
 #include "mkldnn_thread.hpp"
 #include "type_helpers.hpp"
-
+#include "utils.hpp"
+#include <cstring>
 #include "jit_generator.hpp"
 
+#include "jit_avx2_1x1_convolution.hpp"
+
 namespace mkldnn {
 namespace impl {
 namespace cpu {
 
 using namespace mkldnn::impl::status;
 using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
 using namespace mkldnn::impl::utils;
 
 #define data_blk_off(f, n, c, h, w) \
@@ -42,27 +39,28 @@ using namespace mkldnn::impl::utils;
 
 /* convolution forward */
 
-template <bool with_relu>
-void _jit_avx2_1x1_convolution_fwd_t<with_relu>::execute_forward() {
+void jit_avx2_1x1_convolution_fwd_t::execute_forward() const {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
     auto dst = reinterpret_cast<data_t *>(this->memory());
 
-    const memory_desc_wrapper src_d(conf_.src_pd());
-    const memory_desc_wrapper dst_d(conf_.dst_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
+    const memory_desc_wrapper src_d(pd()->src_pd());
+    const memory_desc_wrapper dst_d(pd()->dst_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+
+    auto rtus_space = scratchpad().get<data_t>(key_conv_rtus_space);
 
     const auto &jcp = kernel_->jcp;
-    const int MB = conf_.MB();
+    const int MB = pd()->MB();
 
     const int work_amount = MB * jcp.ngroups * jcp.nb_bcast;
     const int ndims = dst_d.ndims();
 
-    const int stride_h = (ndims == 3) ? 1 : conf_.cdesc()->strides[0];
-    const int stride_w = conf_.cdesc()->strides[ndims - 3];
-    const int pad_t = (ndims == 3) ? 0 : conf_.cdesc()->padding[0][0];
-    const int pad_l = conf_.cdesc()->padding[0][ndims - 3];
+    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
+    const int stride_w = pd()->desc()->strides[ndims - 3];
+    const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
+    const int pad_l = pd()->desc()->padding[0][ndims - 3];
 
     auto step = [](int default_step, int remaining, int tail_step) {
         assert(default_step <= tail_step);
@@ -73,8 +71,8 @@ void _jit_avx2_1x1_convolution_fwd_t<with_relu>::execute_forward() {
         // TODO (Roma): remove this restriction
         assert(jcp.stride_w == 1 && jcp.stride_h == 1);
 
-        jit_1x1_conv_call_s p = {};
-        rtus_driver_t<avx2>::call_params_t rp = {};
+           auto p = jit_1x1_conv_call_s();
+           auto rp = rtus_driver_t<avx2>::call_params_t();
 
         const int nb_oc = jcp.nb_load;
         const int nb_ic = jcp.nb_reduce;
@@ -129,13 +127,14 @@ void _jit_avx2_1x1_convolution_fwd_t<with_relu>::execute_forward() {
                             nb_ic_blocking * jcp.ic_block);
                     rp.icb = p.reduce_dim / jcp.reduce_block;
 
-                    p.load_data = &weights[conf_.with_groups()
+                    p.load_data = &weights[pd()->with_groups()
                         ? weights_d.blk_off(g, ocb, icb)
                         : weights_d.blk_off(ocb, icb)];
 
                     const int _icb = g * nb_ic + icb;
-                    if (conf_.rtus_.reduce_src_) {
-                        rp.ws = scratch_ + ithr * ws_per_thread_
+                    if (pd()->rtus_.reduce_src_) {
+                        rp.ws = rtus_space
+                            + ithr * pd()->rtus_.space_per_thread_
                             + _icb * jcp.is * jcp.ic_block;
 
                         if (ocb == 0) {
@@ -159,29 +158,37 @@ void _jit_avx2_1x1_convolution_fwd_t<with_relu>::execute_forward() {
         }
     };
 
-    if (conf_.want_padded_bias()) {
-        for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
-            padded_bias_[oc] = bias[oc];
-        bias = padded_bias_;
+    if (pd()->wants_padded_bias()) {
+        auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
+        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
+        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+                jcp.oc - jcp.oc_without_padding);
+        bias = padded_bias;
     }
 
     parallel(0, ker);
+
+    if (pd()->wants_zero_pad_dst())
+        output_memory_primitive(0)->zero_pad();
 }
 
-template <bool with_relu>
-void _jit_avx2_1x1_convolution_fwd_t<with_relu>::execute_forward_fusing() {
+void jit_avx2_1x1_convolution_fwd_t::execute_forward_with_dw_conv() const {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
     auto dst = reinterpret_cast<data_t *>(this->memory());
 
-    const memory_desc_wrapper src_d(conf_.src_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
+    const memory_desc_wrapper src_d(pd()->src_pd());
+    const memory_desc_wrapper dst_d(pd()->dst_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+
+    auto rtus_space = scratchpad().get<data_t>(key_conv_rtus_space);
 
     const auto &jcp = kernel_->jcp;
-    const int MB = conf_.MB();
+    const auto &jcp_dw = kernel_dw_->jcp;
+    const int MB = pd()->MB();
 
-    auto dw_bias = jcp.dw_conv_biases;
+    auto dw_bias = jcp_dw.conv_biases;
 
     int ocb_work = jcp.with_dw_conv ? utils::div_up(jcp.nb_load, jcp.nb_load_blocking) : 1;
     const int work_amount = MB * jcp.ngroups * ocb_work * jcp.nb_bcast;
@@ -205,8 +212,8 @@ void _jit_avx2_1x1_convolution_fwd_t<with_relu>::execute_forward_fusing() {
 
                 if ((oh + h) < 0 || (oh + h) >= jcp.ih) {
                     for (int chb = ocb; chb < ocb + load_step; chb++) {
-                        memset(ws_p + (((oh + h) + 1) % jcp.dw_conv_ker_h) * jcp.ow * jcp.oc_block +
-                               (chb - ocb) * jcp.dw_conv_ker_h * jcp.ow * jcp.oc_block, 0, jcp.ow * jcp.oc_block * sizeof(float));
+                        memset(ws_p + (((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block +
+                               (chb - ocb) * jcp_dw.kh * jcp.ow * jcp.oc_block, 0, jcp.ow * jcp.oc_block * sizeof(float));
                     }
                 } else {
                     const int _ocb = g * jcp.nb_load + ocb;
@@ -217,7 +224,7 @@ void _jit_avx2_1x1_convolution_fwd_t<with_relu>::execute_forward_fusing() {
                     rp.os = p.bcast_dim;
                     p.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, load_step * jcp.oc_block);
 
-                    p.output_data = &ws_p[(((oh + h) + 1) % jcp.dw_conv_ker_h) * jcp.ow * jcp.oc_block];
+                    p.output_data = &ws_p[(((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block];
 
                     p.bias_data = &bias[_ocb * jcp.oc_block];
 
@@ -231,13 +238,14 @@ void _jit_avx2_1x1_convolution_fwd_t<with_relu>::execute_forward_fusing() {
                                                        jcp.nb_reduce_blocking * jcp.ic_block);
                         rp.icb = p.reduce_dim / jcp.reduce_block;
 
-                        p.load_data = &weights[conf_.with_groups()
+                        p.load_data = &weights[pd()->with_groups()
                                                ? weights_d.blk_off(g, ocb, icb)
                                                : weights_d.blk_off(ocb, icb)];
 
                         const int _icb = g * jcp.nb_reduce + icb;
-                        if (conf_.rtus_.reduce_src_) {
-                            rp.ws = scratch_ + ithr * ws_per_thread_
+                        if (pd()->rtus_.reduce_src_) {
+                            rp.ws = rtus_space
+                                    + ithr * pd()->rtus_.space_per_thread_
                                     + _icb * jcp.is * jcp.ic_block;
 
                             if (ocb == 0) {
@@ -259,7 +267,6 @@ void _jit_avx2_1x1_convolution_fwd_t<with_relu>::execute_forward_fusing() {
         };
 
         auto compute_row_dw = [&](const float* ws_p, int n, int ocb, int load_step, int dst_idx) {
-            const auto &jcp_dw = kernel_dw_->jcp;
 
             for (int chb = ocb; chb < ocb + load_step; chb++) {
                 auto par_conv_dw = jit_conv_call_s();
@@ -275,9 +282,11 @@ void _jit_avx2_1x1_convolution_fwd_t<with_relu>::execute_forward_fusing() {
                                        dst_idx/jcp_dw.stride_h*jcp_dw.ow*jcp_dw.ch_block];
 
                 par_conv_dw.kh_padding = jcp_dw.kh;
-                par_conv_dw.filt = &jcp.dw_conv_weights[chb * jcp_dw.kh * jcp_dw.kw * jcp_dw.ch_block];
+                par_conv_dw.filt = &jcp_dw.conv_weights[chb * jcp_dw.kh * jcp_dw.kw * jcp_dw.ch_block];
                 par_conv_dw.bias = &dw_bias[chb * jcp_dw.ch_block];
                 par_conv_dw.ur_w = (size_t)(jcp_dw.ow);
+                par_conv_dw.oc_work = nstl::min((chb + 1) * jcp_dw.ch_block, (int)jcp_dw.oc) - chb*jcp_dw.ch_block;
+                par_conv_dw.oc_off = chb * jcp_dw.ch_block * sizeof(float);
 
                 kernel_dw_->jit_ker(&par_conv_dw);
             }
@@ -288,7 +297,9 @@ void _jit_avx2_1x1_convolution_fwd_t<with_relu>::execute_forward_fusing() {
         int start{0}, end{0};
         balance211(work_amount, nthr, ithr, start, end);
 
-        auto pbuf = dw_conv_buffer_ + ithr * dw_conv_buffer_size_;
+        auto dw_conv_buffer = scratchpad().get<data_t>(key_dw_conv_buffer);
+        size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * (jcp.oc / jcp.oc_block);
+        auto pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
 
         const int os_block = jcp.iw;
 
@@ -319,7 +330,7 @@ void _jit_avx2_1x1_convolution_fwd_t<with_relu>::execute_forward_fusing() {
                 compute_block_1x1(pbuf, n, g, oh + 1, ow, ih, iw, os, os_block, bcast_step, ocb, load_step, bcast_step);
             }
 
-            if ((oh % jcp.dw_conv_str_h == 0)) {
+            if ((oh % jcp_dw.stride_h == 0)) {
                 compute_row_dw(pbuf, n, ocb, load_step, oh);
             }
 
@@ -327,44 +338,50 @@ void _jit_avx2_1x1_convolution_fwd_t<with_relu>::execute_forward_fusing() {
         }
     };
 
-    if (conf_.want_padded_bias()) {
-        for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
-            padded_bias_[oc] = bias[oc];
-        bias = padded_bias_;
-
-        for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
-            dw_padded_bias_[oc] = dw_bias[oc];
-        dw_bias = dw_padded_bias_;
+    if (pd()->wants_padded_bias()) {
+        auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
+        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
+        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+                jcp.oc - jcp.oc_without_padding);
+        bias = padded_bias;
+
+        auto dw_padded_bias = scratchpad().get<data_t>(key_dw_conv_padded_bias);
+        utils::array_copy(dw_padded_bias, dw_bias, jcp.oc_without_padding);
+        utils::array_set(dw_padded_bias + jcp.oc_without_padding, 0.f,
+                         jcp.oc - jcp.oc_without_padding);
+        dw_bias = dw_padded_bias;
     }
 
     parallel(0, ker);
-}
 
-template struct _jit_avx2_1x1_convolution_fwd_t<true>;
-template struct _jit_avx2_1x1_convolution_fwd_t<false>;
+    if (pd()->wants_zero_pad_dst())
+        output_memory_primitive(0)->zero_pad();
+}
 
 /* convolution backward wtr data */
 
-void jit_avx2_1x1_convolution_bwd_data_t::execute_backward_data() {
+void jit_avx2_1x1_convolution_bwd_data_t::execute_backward_data() const {
     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto diff_src = reinterpret_cast<data_t *>(this->memory());
 
-    const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
-    const memory_desc_wrapper diff_src_d(conf_.diff_src_pd());
+    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+    const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
+
+    auto rtus_space = scratchpad().get<data_t>(key_conv_rtus_space);
 
     const auto &jcp = kernel_->jcp;
-    const int MB = conf_.MB();
+    const int MB = pd()->MB();
 
     // TODO (Roma): remove this restriction
     assert(jcp.stride_w == 1 && jcp.stride_h == 1);
     const int ndims = diff_dst_d.ndims();
 
-    const int stride_h = (ndims == 3) ? 1 : conf_.desc()->strides[0];
-    const int stride_w = conf_.desc()->strides[ndims - 3];
-    const int pad_t = (ndims == 3) ? 0 : conf_.desc()->padding[0][0];
-    const int pad_l = conf_.desc()->padding[0][ndims - 3];
+    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
+    const int stride_w = pd()->desc()->strides[ndims - 3];
+    const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
+    const int pad_l = pd()->desc()->padding[0][ndims - 3];
 
     const int nb_ic = jcp.nb_load;
     const int nb_oc = jcp.nb_reduce;
@@ -417,8 +434,9 @@ void jit_avx2_1x1_convolution_bwd_data_t::execute_backward_data() {
 
                 const int _icb = g * nb_ic + icb;
                 rp.src = diff_src + data_blk_off(diff_src_d, n, _icb, ih, iw);
-                if (conf_.rtus_.reduce_src_) {
-                    rp.ws = scratch_ + ithr * ws_per_thread_;
+                if (pd()->rtus_.reduce_src_) {
+                    rp.ws = rtus_space
+                        + ithr * pd()->rtus_.space_per_thread_;
                     p.output_data = rp.ws;
                 } else
                     p.output_data = rp.src;
@@ -430,7 +448,7 @@ void jit_avx2_1x1_convolution_bwd_data_t::execute_backward_data() {
                         ow);
                     p.bcast_data = &diff_dst[diff_dst_off];
 
-                    p.load_data = &weights[conf_.with_groups()
+                    p.load_data = &weights[pd()->with_groups()
                         ? weights_d.blk_off(g, ocb, icb)
                         : weights_d.blk_off(ocb, icb)];
 
@@ -442,7 +460,7 @@ void jit_avx2_1x1_convolution_bwd_data_t::execute_backward_data() {
                     kernel_->jit_ker(&p);
                 }
 
-                if (conf_.rtus_.reduce_src_)
+                if (pd()->rtus_.reduce_src_)
                     rtus_driver_->ker_(&rp);
             }
         }
@@ -454,64 +472,46 @@ void jit_avx2_1x1_convolution_bwd_data_t::execute_backward_data() {
 /* convolution backward wtr weights */
 
 jit_avx2_1x1_convolution_bwd_weights_t::jit_avx2_1x1_convolution_bwd_weights_t(
-        const pd_t *pd, const input_vector &inputs,
+        const pd_t *apd, const input_vector &inputs,
         const output_vector &outputs)
-    : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), kernel_(nullptr)
-    , rtus_driver_(nullptr), ws_per_thread_(0), scratch_(nullptr)
-    , padded_bias_(nullptr)
+    : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr)
+    , rtus_driver_(nullptr)
 {
-    kernel_ = new jit_avx2_1x1_conv_kernel_f32(conf_.jcp_, *conf_.attr());
-
-    const auto &jcp = kernel_->jcp;
-
-    const int ic_block = jcp.bcast_block;
-    const int nb_ic = jcp.nb_bcast;
-    const int nb_ic_blocking = jcp.nb_bcast_blocking;
-    const int bcast_work = utils::div_up(nb_ic, nb_ic_blocking);
-
-    const int oc_block = jcp.load_block;
-    const int nb_oc = jcp.nb_load;
-    const int nb_oc_blocking = jcp.nb_load_blocking;
-    const int load_work = utils::div_up(nb_oc, nb_oc_blocking);
-
-    const int job_size
-        = nb_oc_blocking * nb_ic_blocking * ic_block * oc_block;
-    const int njobs_x = bcast_work;
-    const int njobs_y = jcp.ngroups * load_work;
-
-    const int max_threads = mkldnn_get_max_threads();
-    const size_t max_buffer_size = max_threads * job_size * 8;
-
-    reducer_weights_ = new cpu_reducer_2d_t<data_type::f32>(
-            reduce_balancer_t(max_threads, job_size, njobs_y * njobs_x,
-                jcp.mb * jcp.nb_reduce, max_buffer_size),
-            job_size / nb_oc_blocking, nb_oc_blocking, ic_block,
-            nb_ic * ic_block * oc_block, nb_oc, false);
-
-    reducer_bias_ = !conf_.with_bias() ? nullptr
-        : new cpu_reducer_t<data_type::f32>(reduce_balancer_t(max_threads,
-                    oc_block, jcp.ngroups * jcp.oc / oc_block,
-                    jcp.mb, max_buffer_size));
-
-    if (conf_.want_padded_bias())
-        padded_bias_ = (data_t *)malloc(sizeof(data_t) * jcp.oc, 64);
-
+    kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, jit_conv_conf_t(), *pd()->attr());
+    reducer_weights_ =
+        new cpu_reducer_2d_t<data_type::f32>(pd()->reducer_wei_conf_);
+    reducer_bias_ = new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_);
     init_rtus_driver<avx2>(this);
 }
 
-void jit_avx2_1x1_convolution_bwd_weights_t::execute_backward_weights() {
+void jit_avx2_1x1_convolution_bwd_weights_t::execute_backward_weights() const {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto diff_weights = reinterpret_cast<data_t *>(this->memory(0));
     auto diff_bias_in = reinterpret_cast<data_t *>(this->memory(1));
-    data_t *diff_bias = conf_.want_padded_bias() ? padded_bias_ : diff_bias_in;
 
-    const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
-    const memory_desc_wrapper src_d(conf_.src_pd());
-    const memory_desc_wrapper diff_weights_d(conf_.diff_weights_pd(0));
-    const memory_desc_wrapper diff_bias_d(conf_.diff_weights_pd(1));
+    auto scratchpad = this->scratchpad();
+
+    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+    const memory_desc_wrapper src_d(pd()->src_pd());
+    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
+    const memory_desc_wrapper diff_bias_d(pd()->diff_weights_pd(1));
 
     const auto &jcp = kernel_->jcp;
+    auto rtus_space = scratchpad.get<data_t>(key_conv_rtus_space);
+
+    data_t *diff_bias = pd()->wants_padded_bias()
+        ? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in;
+
+    auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
+            prefix_reducer_bia);
+    auto rb = this->reducer_bias_;
+    rb->init(reducer_bia_scratchpad);
+
+    auto reducer_wei_scratchpad = memory_tracking::grantor_t(scratchpad,
+            prefix_reducer_wei);
+    auto rw = this->reducer_weights_;
+    rw->init(reducer_wei_scratchpad);
 
     const int ndims = diff_dst_d.ndims();
     // TODO (Roma): remove this restriction
@@ -528,10 +528,10 @@ void jit_avx2_1x1_convolution_bwd_weights_t::execute_backward_weights() {
     const int sp_dim = jcp.reduce_dim;
     const int mb_sp_work = jcp.mb * sp_dim;
 
-    const int stride_h = (ndims == 3) ? 1 : conf_.desc()->strides[0];
-    const int stride_w = conf_.desc()->strides[ndims - 3];
-    const int pad_t = (ndims == 3) ? 0 : conf_.desc()->padding[0][0];
-    const int pad_l = conf_.desc()->padding[0][ndims - 3];
+    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
+    const int stride_w = pd()->desc()->strides[ndims - 3];
+    const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
+    const int pad_l = pd()->desc()->padding[0][ndims - 3];
 
     auto step = [](int default_step, int remaining, int tail_step) {
         assert(default_step <= tail_step);
@@ -574,7 +574,7 @@ void jit_avx2_1x1_convolution_bwd_weights_t::execute_backward_weights() {
                     p.load_data = diff_dst
                         + (oc_b * jcp.reduce_dim + sp) * jcp.oc_block;
 
-                    if (conf_.rtus_.reduce_src_) {
+                    if (pd()->rtus_.reduce_src_) {
                         const int oh = sp / jcp.ow;
                         const int ow = sp % jcp.ow;
 
@@ -582,7 +582,8 @@ void jit_avx2_1x1_convolution_bwd_weights_t::execute_backward_weights() {
                         const int iw = nstl::max(ow * stride_w - pad_l, 0);
                         rp.iw_start = iw;
 
-                        rp.ws = scratch_ + ithr * ws_per_thread_
+                        rp.ws = rtus_space
+                            + ithr * pd()->rtus_.space_per_thread_
                             + (ic_b * jcp.is + sp) * jcp.ic_block;
                         if (ndims == 3)
                             rp.src = src
@@ -607,22 +608,21 @@ void jit_avx2_1x1_convolution_bwd_weights_t::execute_backward_weights() {
     };
 
     auto ker = [&](const int ithr, const int nthr) {
-        auto rw = this->reducer_weights_;
-        assert(nthr == rw->balancer_.nthr_);
+        assert(nthr == rw->balancer().nthr_);
 
-        const int w_njobs = rw->balancer_.ithr_njobs(ithr);
+        const int w_njobs = rw->balancer().ithr_njobs(ithr);
         if (w_njobs == 0) return;
 
         /* setup: independent work (oc, ic) */
-        const int w_job_start = rw->balancer_.ithr_job_off(ithr);
+        const int w_job_start = rw->balancer().ithr_job_off(ithr);
         int g{0}, load_i{0}, bcast_i{0};
         nd_iterator_init(w_job_start, g, jcp.ngroups, load_i, load_work,
                 bcast_i, bcast_work);
 
         /* setup: reduction work (mb, sp) */
         int mb_sp_start{0}, mb_sp_end{0};
-        balance211(mb_sp_work, rw->balancer_.nthr_per_group_,
-                rw->balancer_.id_in_group(ithr), mb_sp_start, mb_sp_end);
+        balance211(mb_sp_work, rw->balancer().nthr_per_group_,
+                rw->balancer().id_in_group(ithr), mb_sp_start, mb_sp_end);
         int img_start{0}, sp_start{0};
         nd_iterator_init(mb_sp_start, img_start, jcp.mb, sp_start, sp_dim);
 
@@ -637,16 +637,16 @@ void jit_avx2_1x1_convolution_bwd_weights_t::execute_backward_weights() {
             data_t *store_to;
             size_t store_to_ld;
 
-            if (rw->balancer_.nthr_per_group_ == 1 ||
-                    (rw->balancer_.master(ithr) && rw->master_uses_dst_)) {
-                const size_t off = conf_.with_groups()
+            if (rw->balancer().nthr_per_group_ == 1) {
+                const size_t off = pd()->with_groups()
                     ? diff_weights_d.blk_off(g, oc_b, ic_b)
                     : diff_weights_d.blk_off(oc_b, ic_b);
                 store_to = &diff_weights[off];
                 store_to_ld = jcp.ic * jcp.oc_block;
             } else {
-                const size_t off = iwork * rw->balancer_.job_size_;
-                store_to = &rw->get_local_ptr(ithr, nullptr)[off];
+                const size_t off = iwork * rw->balancer().job_size_;
+                store_to =
+                    rw->get_local_ptr(ithr, reducer_wei_scratchpad) + off;
                 store_to_ld = nb_ic_blocking * jcp.ic_block * jcp.oc_block;
             }
 
@@ -670,22 +670,21 @@ void jit_avx2_1x1_convolution_bwd_weights_t::execute_backward_weights() {
             nd_iterator_step(g, jcp.ngroups, load_i, load_work, bcast_i,
                              bcast_work);
         }
-        rw->reduce(ithr, diff_weights);
+        rw->reduce(ithr, diff_weights, reducer_wei_scratchpad);
     };
 
     auto ker_bias = [&](int ithr, int nthr) {
-        auto rb = this->reducer_bias_;
-        assert(nthr == rb->balancer_.nthr_);
+        assert(nthr == rb->balancer().nthr_);
 
-        const int b_job_start = rb->balancer_.ithr_job_off(ithr);
-        const int b_njobs = rb->balancer_.ithr_njobs(ithr);
+        const int b_job_start = rb->balancer().ithr_job_off(ithr);
+        const int b_njobs = rb->balancer().ithr_njobs(ithr);
 
         if (b_njobs == 0) return;
 
         /* reduction dimension */
         int img_start{0}, img_end{0};
-        balance211(jcp.mb, rb->balancer_.nthr_per_group_,
-                rb->balancer_.id_in_group(ithr), img_start, img_end);
+        balance211(jcp.mb, rb->balancer().nthr_per_group_,
+                rb->balancer().id_in_group(ithr), img_start, img_end);
 
         /* jobs */
         int g_start{0}, ocb_start{0};
@@ -697,8 +696,9 @@ void jit_avx2_1x1_convolution_bwd_weights_t::execute_backward_weights() {
                 const size_t _oc = g * nb_oc + ocb;
 
                 const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)];
-                data_t *d_bias = &rb->get_local_ptr(ithr, diff_bias)[
-                    b_job_loc * rb->balancer_.job_size_];
+                data_t *d_bias =
+                    rb->get_local_ptr(ithr, diff_bias, reducer_bia_scratchpad)
+                    + b_job_loc * rb->balancer().job_size_;
 
                 if (img == img_start)
                     for (int o = 0; o < 8; ++o) d_bias[o] = 0.;
@@ -713,17 +713,17 @@ void jit_avx2_1x1_convolution_bwd_weights_t::execute_backward_weights() {
                 nd_iterator_step(g, jcp.ngroups, ocb, nb_oc);
             }
         }
-        rb->reduce(ithr, diff_bias);
+        rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
     };
 
     parallel(0, [&](const int ithr, const int nthr) {
         ker(ithr, nthr);
-        if (conf_.with_bias())
+        if (pd()->with_bias())
             ker_bias(ithr, nthr);
     });
 
     /* TODO: put this in ker_bias */
-    if (conf_.want_padded_bias()) {
+    if (pd()->wants_padded_bias()) {
         assert(jcp.ngroups == 1);
         for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
             diff_bias_in[oc] = diff_bias[oc];