Publishing 2019 R1 content
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx2_convolution.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx2_convolution.cpp
index e9ccf6f..d7ea64b 100644
 * limitations under the License.
 *******************************************************************************/
 
-#include <cstring>
-#include "mkldnn_types.h"
-
 #include "c_types_map.hpp"
-#include "jit_avx2_convolution.hpp"
-#include "utils.hpp"
 #include "mkldnn_thread.hpp"
 #include "type_helpers.hpp"
+#include "utils.hpp"
+#include <cstring>
+
+#include "jit_avx2_convolution.hpp"
 
 namespace mkldnn {
 namespace impl {
@@ -29,39 +28,38 @@ namespace cpu {
 
 using namespace mkldnn::impl::status;
 using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
 using namespace mkldnn::impl::utils;
 
-
 #define src_blk_off(f, n, c, d, h, w) \
-    (conf_.ndims() == 3) \
+    (pd()->ndims() == 3) \
     ? (f).blk_off(n, c, w) \
-    : (conf_.ndims() == 4) \
+    : (pd()->ndims() == 4) \
     ? (f).blk_off(n, c, h, w) \
     : (f).blk_off(n, c, d, h, w)
 
 #define wht_blk_off_(f, g, ...) \
-    conf_.with_groups() ? (f).blk_off(g, __VA_ARGS__) : (f).blk_off(__VA_ARGS__)
+    pd()->with_groups() ? (f).blk_off(g, __VA_ARGS__) : (f).blk_off(__VA_ARGS__)
 #define wht_blk_off(f, g, oc, ic, kd, kh, kw) \
-    (conf_.ndims() == 3) \
+    (pd()->ndims() == 3) \
     ? wht_blk_off_(f, g, oc, ic, kw) \
-    : (conf_.ndims() == 4) \
+    : (pd()->ndims() == 4) \
     ? wht_blk_off_(f, g, oc, ic, kh, kw) \
     : wht_blk_off_(f, g, oc, ic, kd, kh, kw)
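
// Illustrative sketch (not from the patch itself) of what the macros above
// expand to: the blk_off() arity is picked from the number of tensor
// dimensions, so one call site covers 1D (ncw), 2D (nchw) and 3D (ncdhw)
// convolutions.  `layout_t` is a hypothetical stand-in for the
// memory_desc_wrapper passed as `f`.
template <typename layout_t>
inline size_t src_blk_off_sketch(const layout_t &f, int ndims,
        int n, int c, int d, int h, int w) {
    return ndims == 3 ? f.blk_off(n, c, w)
         : ndims == 4 ? f.blk_off(n, c, h, w)
         :              f.blk_off(n, c, d, h, w);
}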
 
-template <bool with_relu>
-void _jit_avx2_convolution_fwd_t<with_relu>::execute_forward() {
+void jit_avx2_convolution_fwd_t::execute_forward() const {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
     auto dst = reinterpret_cast<data_t *>(this->memory());
 
-    const memory_desc_wrapper src_d(conf_.src_pd());
-    const memory_desc_wrapper dst_d(conf_.dst_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
-    const memory_desc_wrapper bias_d(conf_.weights_pd(1));
+    const memory_desc_wrapper src_d(pd()->src_pd());
+    const memory_desc_wrapper dst_d(pd()->dst_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+    const memory_desc_wrapper bias_d(pd()->weights_pd(1));
 
     const auto &jcp = kernel_->jcp;
-    const int MB = conf_.MB();
+    const int MB = pd()->MB();
 
     int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
     const size_t work_amount = MB * jcp.ngroups * ocb_work * jcp.od
@@ -86,7 +84,7 @@ void _jit_avx2_convolution_fwd_t<with_relu>::execute_forward() {
                 int ocb_num = jcp.nb_oc_blocking;
 
                 for (int icb = icbb; icb < icbb + icb_step; ++icb) {
-                    jit_conv_call_s par_conv = {};
+                    auto par_conv = jit_conv_call_s();
 
                     const int ij = oh * jcp.stride_h;
                     const int i_t_overflow = nstl::max(0, jcp.t_pad - ij);
@@ -99,7 +97,7 @@ void _jit_avx2_convolution_fwd_t<with_relu>::execute_forward() {
                         + (jcp.kd-1) * (jcp.dilate_d+1) - jcp.f_pad+1) - jcp.id;
 
                     const size_t _oc = g * jcp.nb_oc + ocb;
-                    const size_t _ic = g * jcp.nb_ic + icb;
+                    const size_t _ic = g * jcp.nb_ic * jcp.nonblk_group_off + icb;
 
                     const int ih = nstl::max(ij - jcp.t_pad
                         + div_up(i_t_overflow,
@@ -155,31 +153,35 @@ void _jit_avx2_convolution_fwd_t<with_relu>::execute_forward() {
         }
     };
 
-    if (conf_.want_padded_bias()) {
-        for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
-            padded_bias_[oc] = bias[oc];
-        bias = padded_bias_;
+    if (pd()->wants_padded_bias()) {
+        auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
+        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
+        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+                jcp.oc - jcp.oc_without_padding);
+        bias = padded_bias;
     }
 
     parallel(0, ker);
+
+    if (pd()->wants_zero_pad_dst())
+        output_memory_primitive(0)->zero_pad();
 }
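
// A self-contained sketch (not from the patch itself, assuming the
// array_copy/array_set semantics used above) of the bias padding done in
// execute_forward(): the first oc_without_padding values are copied and the
// tail up to oc is zero-filled, so the jit kernel can always read full
// 8-float AVX vectors of bias without going out of bounds.
static inline void pad_bias_sketch(float *padded_bias, const float *bias,
        int oc_without_padding, int oc) {
    for (int i = 0; i < oc_without_padding; ++i)
        padded_bias[i] = bias[i];            // mirrors utils::array_copy
    for (int i = oc_without_padding; i < oc; ++i)
        padded_bias[i] = 0.f;                // mirrors utils::array_set
}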
 
-template <bool with_relu>
-void _jit_avx2_convolution_fwd_t<with_relu>::execute_forward_fusing() {
+void jit_avx2_convolution_fwd_t::execute_forward_with_dw_conv() const {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
     auto dst = reinterpret_cast<data_t *>(this->memory());
 
-    const memory_desc_wrapper src_d(conf_.src_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
-    const memory_desc_wrapper bias_d(conf_.weights_pd(1));
+    const memory_desc_wrapper src_d(pd()->src_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+    const memory_desc_wrapper bias_d(pd()->weights_pd(1));
 
     const auto &jcp = kernel_->jcp;
     const auto &jcp_dw = kernel_dw_->jcp;
-    const int MB = conf_.MB();
+    const int MB = pd()->MB();
 
-    auto dw_bias = jcp.dw_conv_biases;
+    auto dw_bias = jcp_dw.conv_biases;
 
     int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
     const size_t work_amount = MB * jcp.ngroups * ocb_work * jcp.oh;
@@ -189,8 +191,8 @@ void _jit_avx2_convolution_fwd_t<with_relu>::execute_forward_fusing() {
             for (int h = 0; h < num_rows; h++) {
                 if ((oh + h) < 0 || (oh + h) >= jcp.oh) {
                     for (int chb = ocb; chb < ocb + ocb_num; chb++) {
-                        memset(ws_p + (((oh + h) + 1) % jcp.dw_conv_ker_h) * jcp.ow * jcp.oc_block +
-                               (chb - ocb) * jcp.dw_conv_ker_h * jcp.ow * jcp.oc_block, 0, jcp.ow * jcp.oc_block * sizeof(float));
+                        memset(ws_p + (((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block +
+                               (chb - ocb) * jcp_dw.kh * jcp.ow * jcp.oc_block, 0, jcp.ow * jcp.oc_block * sizeof(float));
                     }
                 } else {
                     for (int icb = 0; icb < jcp.nb_ic; ++icb) {
@@ -211,11 +213,11 @@ void _jit_avx2_convolution_fwd_t<with_relu>::execute_forward_fusing() {
                         par_conv.src = &src[src_d.blk_off(n,
                                                           jcp.ic == 3 ? 0 : _ic, ih, 0)];
 
-                        par_conv.dst = &ws_p[(((oh + h) + 1) % jcp.dw_conv_ker_h) * jcp.ow *
+                        par_conv.dst = &ws_p[(((oh + h) + 1) % jcp_dw.kh) * jcp.ow *
                                              jcp.oc_block];
 
                         const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1));
-                        par_conv.filt = &weights[conf_.with_groups()
+                        par_conv.filt = &weights[pd()->with_groups()
                                                  ? weights_d.blk_off(g, ocb,
                                                                      jcp.ic == 3 ? 0 : icb, wh, 0)
                                                  : weights_d.blk_off(ocb,
@@ -264,9 +266,11 @@ void _jit_avx2_convolution_fwd_t<with_relu>::execute_forward_fusing() {
                                        dst_idx/jcp_dw.stride_h*jcp_dw.ow*jcp_dw.ch_block];
 
                 par_conv_dw.kh_padding = jcp_dw.kh;
-                par_conv_dw.filt = &jcp.dw_conv_weights[chb * jcp_dw.kh * jcp_dw.kw * jcp_dw.ch_block];
+                par_conv_dw.filt = &jcp_dw.conv_weights[chb * jcp_dw.kh * jcp_dw.kw * jcp_dw.ch_block];
                 par_conv_dw.bias = &dw_bias[chb * jcp_dw.ch_block];
                 par_conv_dw.ur_w = (size_t)(jcp_dw.ow);
+                par_conv_dw.oc_work = nstl::min((chb + 1) * jcp_dw.ch_block, (int)jcp_dw.oc) - chb*jcp_dw.ch_block;
+                par_conv_dw.oc_off = chb * jcp_dw.ch_block * sizeof(float);
 
                 kernel_dw_->jit_ker(&par_conv_dw);
             }
@@ -275,7 +279,9 @@ void _jit_avx2_convolution_fwd_t<with_relu>::execute_forward_fusing() {
         size_t start{0}, end{0};
         balance211(work_amount, nthr, ithr, start, end);
 
-        auto pbuf = dw_conv_buffer_ + ithr * dw_conv_buffer_size_;
+        auto dw_conv_buffer = scratchpad().get<data_t>(key_dw_conv_buffer);
+        size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * jcp.nb_oc_blocking;
+        auto pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
 
         size_t n{0}, g{0}, ocbb{0}, oh{0};
         nd_iterator_init(start, n, MB, g, jcp.ngroups, ocbb, ocb_work,
@@ -304,138 +310,156 @@ void _jit_avx2_convolution_fwd_t<with_relu>::execute_forward_fusing() {
         }
     };
 
-    if (conf_.want_padded_bias()) {
-        for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
-            padded_bias_[oc] = bias[oc];
-        bias = padded_bias_;
-
-        for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
-            dw_padded_bias_[oc] = dw_bias[oc];
-        dw_bias = dw_padded_bias_;
+    if (pd()->wants_padded_bias()) {
+        auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
+        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
+        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+                jcp.oc - jcp.oc_without_padding);
+        bias = padded_bias;
+
+        auto dw_padded_bias = scratchpad().get<data_t>(key_dw_conv_padded_bias);
+        utils::array_copy(dw_padded_bias, dw_bias, jcp.oc_without_padding);
+        utils::array_set(dw_padded_bias + jcp.oc_without_padding, 0.f,
+                         jcp.oc - jcp.oc_without_padding);
+        dw_bias = dw_padded_bias;
     }
 
     parallel(0, ker);
-}
 
-template void _jit_avx2_convolution_fwd_t<true>::execute_forward();
-template void _jit_avx2_convolution_fwd_t<false>::execute_forward();
-template void _jit_avx2_convolution_fwd_t<true>::execute_forward_fusing();
-template void _jit_avx2_convolution_fwd_t<false>::execute_forward_fusing();
+    if (pd()->wants_zero_pad_dst())
+        output_memory_primitive(0)->zero_pad();
+}
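
// A hedged sketch (not from the patch itself) of the row staging used by
// execute_forward_with_dw_conv(): the base convolution writes its output rows
// into a per-thread circular buffer holding jcp_dw.kh rows per output-channel
// block, and the depthwise kernel then consumes a kh-row window from it.
// Addressing one such row, mirroring the pattern
// ws_p[(((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block], looks like:
static inline float *dw_ring_row_sketch(float *buf, int row /* = oh + h */,
        int kh, int row_stride /* = ow * oc_block */) {
    return buf + ((row + 1) % kh) * row_stride;
}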
 
-void jit_avx2_convolution_bwd_data_t::execute_backward_data() {
+void jit_avx2_convolution_bwd_data_t::execute_backward_data() const {
     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto diff_src = reinterpret_cast<data_t *>(this->memory());
 
-    const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
-    const memory_desc_wrapper diff_src_d(conf_.diff_src_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
+    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+    const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
 
     const auto &jcp = kernel_->jcp;
-    const int MB = conf_.MB();
+    const int MB = pd()->MB();
 
     int icb_work = jcp.nb_ic / jcp.nb_ic_blocking;
-    const size_t work_amount = MB * jcp.ngroups * icb_work * jcp.ih;
+    int ih_block_size = jcp.ih;
+    int num_ih_blocks = utils::div_up(jcp.ih, ih_block_size);
+    size_t work_amount = MB * jcp.ngroups * icb_work * num_ih_blocks;
+    if (work_amount < (size_t)2 * mkldnn_get_max_threads()) {
+        ih_block_size = 1;
+        num_ih_blocks = utils::div_up(jcp.ih, ih_block_size);
+        work_amount *= num_ih_blocks;
+    }
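    // Note on the blocking heuristic above (not from the patch itself): a work
    // item normally covers the whole input height (ih_block_size == jcp.ih,
    // num_ih_blocks == 1).  If that leaves fewer than
    // 2 * mkldnn_get_max_threads() items, the height is split row by row
    // (ih_block_size = 1) so every thread still gets work.  With made-up sizes
    // MB=1, ngroups=1, icb_work=4, ih=56 on 28 threads: 4 items < 56, so the
    // fallback yields 4 * 56 = 224 items.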
 
     auto ker = [&](const int ithr, const int nthr) {
         size_t start{0}, end{0};
         balance211(work_amount, nthr, ithr, start, end);
 
-        size_t n{0}, g{0}, icbb{0}, ih{0};
-        nd_iterator_init(start, n, MB, g, jcp.ngroups, icbb, icb_work, ih, jcp.ih);
+        size_t n{0}, g{0}, icbb{0}, ihb{0};
+        nd_iterator_init(start, n, MB, g, jcp.ngroups, icbb, icb_work,
+                         ihb, num_ih_blocks);
+
         for (size_t iwork = start; iwork < end; ++iwork) {
-            for (int oc = 0; oc < jcp.nb_oc; ++oc)
+            for (int oc = 0; oc < jcp.nb_oc; oc += jcp.nb_oc_blocking)
             for (int id = 0; id < jcp.id; ++id) {
                 auto par_conv = jit_conv_call_s();
 
                 const int idp = jcp.id + 2 * jcp.f_pad;
                 const int d_t_overflow = nstl::max(0,
-                                                   jcp.kd - 1 - id - jcp.f_pad);
+                        jcp.kd - 1 - id - jcp.f_pad);
                 const int back_pad = idp - jcp.id - jcp.f_pad;
                 const int d_b_overflow = nstl::max(0,
-                                                   jcp.kd - 1 - (jcp.id - 1 - id) - back_pad);
+                        jcp.kd - 1 - (jcp.id - 1 - id) - back_pad);
                 const int od = id + jcp.f_pad - d_b_overflow;
 
-                const int simd_w = 8;
-
-                const int i_t_overflow = nstl::max(0,
-                                                   jcp.kh - 1 - (int)ih - jcp.t_pad);
-                const int b_pad = jcp.ihp - jcp.ih - jcp.t_pad;
-                const int i_b_overflow = nstl::max(0,
-                                                   jcp.kh - 1 - (jcp.ih - 1 - (int)ih) - b_pad);
-                int oh = ih + jcp.t_pad - i_b_overflow;
-
-                int stride_off_h = oh % jcp.stride_h;
-                oh /= jcp.stride_h;
-
-                par_conv.src = &diff_src[src_blk_off(diff_src_d, n,
-                                         /*jcp.ic == 3 ? 0 :*/
-                                                     g * jcp.nb_ic + jcp.nb_ic_blocking * icbb, id, ih, 0)];
-                par_conv.dst = &diff_dst[src_blk_off(diff_dst_d,
-                                                     n, g * jcp.nb_oc + oc, od, oh, 0)];
-                par_conv.filt = &weights[wht_blk_off(weights_d, g, oc,
-                                                     jcp.ic == 3 ? 0 : jcp.nb_ic_blocking * icbb,
-                                                     d_b_overflow, i_b_overflow + stride_off_h, 0)];
-
-                par_conv.src_prf = nullptr;
-                par_conv.dst_prf = nullptr;
-                par_conv.filt_prf = nullptr;
-                // TODO: move initialization into the kernel
-                if (oc == 0) {
-                    for (int iw = 0; iw < jcp.iw; iw++) {
-                        for (int b = 0; b < jcp.nb_ic_blocking; b++) {
-                            int current_ic =
-                                    (jcp.ic == 3 ? 0 : g * jcp.nb_ic)
-                                    + jcp.nb_ic_blocking * icbb + b;
-                            int current_idx =
-                                    src_blk_off(diff_src_d, n, current_ic,
-                                                id, ih, iw);
-                            for (int v = 0; v < simd_w; v++)
-                                diff_src[current_idx + v] = 0.0;
-                        }
-                    }
-                }
+                int ih_start = ihb * ih_block_size;
+                int ih_end = nstl::min(jcp.ih, ih_start + ih_block_size);
+                for (int ih = ih_start; ih < ih_end; ++ih) {
+
+                    const int i_t_overflow = nstl::max(0, (jcp.kh - 1
+                                        - ih - jcp.t_pad) / jcp.stride_h);
+                    const int i_b_overflow = nstl::max(0, (jcp.kh - jcp.ih
+                                        + ih - jcp.b_pad) / jcp.stride_h);
+                    int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1
+                                + jcp.b_pad - ih) % jcp.stride_h);
+                    int overflow_kh_lo = (ih + jcp.t_pad) % jcp.stride_h;
+
+                    par_conv.kd_padding = jcp.kd - d_t_overflow - d_b_overflow;
+                    par_conv.kh_padding = (overflow_kh_hi - overflow_kh_lo)
+                              / jcp.stride_h + 1 - i_t_overflow - i_b_overflow;
+                    par_conv.kw_padding = 0;
 
-                par_conv.kd_padding = jcp.kd - d_t_overflow - d_b_overflow;
-                par_conv.kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow - stride_off_h);
-                par_conv.kw_padding = 0;
+                    const int k_lo = overflow_kh_lo
+                                   + i_b_overflow * jcp.stride_h;
+                    const int oh = (ih + jcp.t_pad - k_lo) / jcp.stride_h;
+
+                    par_conv.src = &diff_src[src_blk_off(diff_src_d, n,
+                        /*jcp.ic == 3 ? 0 :*/
+                        g * jcp.nb_ic + jcp.nb_ic_blocking * icbb, id, ih, 0)];
+                    par_conv.dst = &diff_dst[src_blk_off(diff_dst_d,
+                            n, g * jcp.nb_oc + oc, od, oh, 0)];
+                    par_conv.filt = &weights[wht_blk_off(weights_d, g, oc,
+                                jcp.ic == 3 ? 0 : jcp.nb_ic_blocking * icbb,
+                                d_b_overflow, k_lo, 0)];
+
+                    par_conv.src_prf = nullptr;
+                    par_conv.dst_prf = nullptr;
+                    par_conv.filt_prf = nullptr;
+                    par_conv.channel = oc;
+                    par_conv.ch_blocks = nstl::min(jcp.nb_oc - oc,
+                                       jcp.nb_oc_blocking);
 
-                if (par_conv.kh_padding > 0)
                     kernel_->jit_ker(&par_conv);
+                }
             }
-            nd_iterator_step(n, MB, g, jcp.ngroups, icbb, icb_work, ih, jcp.ih);
+            nd_iterator_step(n, MB, g, jcp.ngroups, icbb, icb_work, ihb,
+                             num_ih_blocks);
         }
     };
 
     parallel(0, ker);
 }
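
// A minimal re-implementation sketch (not the mkl-dnn source) of the
// balance211() helper the ker lambdas rely on: it splits n work items into
// nthr contiguous chunks whose sizes differ by at most one and returns the
// [start, end) range owned by thread ithr.
static inline void balance211_sketch(size_t n, size_t nthr, size_t ithr,
        size_t &start, size_t &end) {
    const size_t base = n / nthr;
    const size_t rem = n % nthr;        // first `rem` threads get one extra
    start = ithr * base + (ithr < rem ? ithr : rem);
    end = start + base + (ithr < rem ? 1 : 0);
}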
 
-void jit_avx2_convolution_bwd_weights_t::execute_backward_weights() {
+void jit_avx2_convolution_bwd_weights_t::execute_backward_weights() const {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto diff_weights = reinterpret_cast<data_t *>(this->memory(0));
     auto diff_bias_in = reinterpret_cast<data_t *>(this->memory(1));
-    data_t *diff_bias = conf_.want_padded_bias() ? padded_bias_ : diff_bias_in;
 
-    const memory_desc_wrapper src_d(conf_.src_pd(0));
-    const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
-    const memory_desc_wrapper diff_weights_d(conf_.diff_weights_pd(0));
+    auto scratchpad = this->scratchpad();
+
+    data_t *diff_bias = pd()->wants_padded_bias()
+        ? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in;
+
+    const memory_desc_wrapper src_d(pd()->src_pd(0));
+    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
 
     const auto &jcp = kernel_->jcp;
 
+    auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
+            prefix_reducer_bia);
+    auto rb = this->reducer_bias_;
+    rb->init(reducer_bia_scratchpad);
+
+    auto reducer_wei_scratchpad = memory_tracking::grantor_t(scratchpad,
+            prefix_reducer_wei);
+    auto rw = this->reducer_weights_;
+    rw->init(reducer_wei_scratchpad);
+
     auto ker = [&](int ithr, int nthr) {
-        auto rw = this->reducer_weights_;
-        assert(nthr == rw->balancer_.nthr_);
+        assert(nthr == rw->balancer().nthr_);
 
-        const int w_job_start = rw->balancer_.ithr_job_off(ithr);
-        const int w_njobs = rw->balancer_.ithr_njobs(ithr);
+        const int w_job_start = rw->balancer().ithr_job_off(ithr);
+        const int w_njobs = rw->balancer().ithr_njobs(ithr);
 
         if (w_njobs == 0) return;
 
         /* reduction dimension */
         int img_od_start{0}, img_od_end{0}, img{0}, od_s{0};
-        balance211(jcp.mb * jcp.od, rw->balancer_.nthr_per_group_,
-                rw->balancer_.id_in_group(ithr), img_od_start, img_od_end);
+        balance211(jcp.mb * jcp.od, rw->balancer().nthr_per_group_,
+                rw->balancer().id_in_group(ithr), img_od_start, img_od_end);
 
         int img_start = img_od_start, img_end = img_od_end;
         nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od);
@@ -461,9 +485,10 @@ void jit_avx2_convolution_bwd_weights_t::execute_backward_weights() {
 
                 /* TODO: put dw <-- 0 in kernel */
                 if (img == img_first)
-                    array_set((data_t *)&rw->get_local_ptr(ithr, diff_weights)[
-                        w_job_loc * rw->balancer_.job_size_], 0,
-                            rw->balancer_.job_size_);
+                    array_set(rw->get_local_ptr(ithr, diff_weights,
+                                reducer_wei_scratchpad) +
+                            w_job_loc * rw->balancer().job_size_, 0,
+                            rw->balancer().job_size_);
 
                 for (int od = od_s; od < od_e; ++od) {
                     const int id = od * jcp.stride_d;
@@ -473,8 +498,9 @@ void jit_avx2_convolution_bwd_weights_t::execute_backward_weights() {
                     par_conv.src = &src[src_blk_off(src_d, img, _ic, id, 0, 0)];
                     par_conv.dst =
                         &diff_dst[src_blk_off(diff_dst_d, img, _oc, od, 0, 0)];
-                    par_conv.filt = &rw->get_local_ptr(ithr, diff_weights)[
-                        w_job_loc * rw->balancer_.job_size_];
+                    par_conv.filt = rw->get_local_ptr(ithr, diff_weights,
+                            reducer_wei_scratchpad) +
+                        w_job_loc * rw->balancer().job_size_;
 
                     kernel_->jit_ker(&par_conv);
                 }
@@ -483,22 +509,21 @@ void jit_avx2_convolution_bwd_weights_t::execute_backward_weights() {
             }
             nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od);
         }
-        rw->reduce(ithr, diff_weights);
+        rw->reduce(ithr, diff_weights, reducer_wei_scratchpad);
     };
 
     auto ker_bias = [&](int ithr, int nthr) {
-        auto rb = this->reducer_bias_;
-        assert(nthr == rb->balancer_.nthr_);
+        assert(nthr == rb->balancer().nthr_);
 
-        const int b_job_start = rb->balancer_.ithr_job_off(ithr);
-        const int b_njobs = rb->balancer_.ithr_njobs(ithr);
+        const int b_job_start = rb->balancer().ithr_job_off(ithr);
+        const int b_njobs = rb->balancer().ithr_njobs(ithr);
 
         if (b_njobs == 0) return;
 
         /* reduction dimension */
         int img_start{0}, img_end{0};
-        balance211(jcp.mb, rb->balancer_.nthr_per_group_,
-                rb->balancer_.id_in_group(ithr), img_start, img_end);
+        balance211(jcp.mb, rb->balancer().nthr_per_group_,
+                rb->balancer().id_in_group(ithr), img_start, img_end);
 
         /* jobs */
         int g_start{0}, ocb_start{0};
@@ -511,8 +536,9 @@ void jit_avx2_convolution_bwd_weights_t::execute_backward_weights() {
                 const size_t _oc = g * jcp.nb_oc + ocb;
 
                 const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)];
-                data_t *d_bias = &rb->get_local_ptr(ithr, diff_bias)[
-                    b_job_loc * rb->balancer_.job_size_];
+                data_t *d_bias = rb->get_local_ptr(ithr, diff_bias,
+                        reducer_bia_scratchpad) +
+                    b_job_loc * rb->balancer().job_size_;
 
                 if (img == img_start)
                     for (int o = 0; o < 8; ++o)
@@ -528,18 +554,17 @@ void jit_avx2_convolution_bwd_weights_t::execute_backward_weights() {
                 nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc);
             }
         }
-        rb->reduce(ithr, diff_bias);
+        rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
     };
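    // Note (not from the patch itself): ker and ker_bias follow the
    // cpu_reducer pattern used above, where each thread accumulates partial
    // results into its own slice of the reducer scratchpad via
    // get_local_ptr(ithr, ..., <grantor>), and reduce(ithr, ..., <grantor>)
    // later sums those partials into the user's diff_weights / diff_bias
    // buffers.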
 
-
     parallel(0, [&](const int ithr, const int nthr) {
         ker(ithr, nthr);
-        if (conf_.with_bias())
+        if (pd()->with_bias())
             ker_bias(ithr, nthr);
     });
 
     /* TODO: put this in ker_bias */
-    if (conf_.want_padded_bias()) {
+    if (pd()->wants_padded_bias()) {
         assert(jcp.ngroups == 1);
         for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
             diff_bias_in[oc] = diff_bias[oc];