Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_common_convolution.cpp
index 8767207..da07a52 100644 (file)
 * limitations under the License.
 *******************************************************************************/
 
-#include "mkldnn_types.h"
 #include "c_types_map.hpp"
-#include "jit_avx512_common_convolution.hpp"
 #include "mkldnn_thread.hpp"
 #include "type_helpers.hpp"
 #include "utils.hpp"
 
+#include "jit_avx512_common_convolution.hpp"
+
 namespace mkldnn {
 namespace impl {
 namespace cpu {
 
 using namespace mkldnn::impl::status;
 using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
 using namespace mkldnn::impl::utils;
 
 using namespace nstl;
@@ -127,25 +128,40 @@ void jit_conv_3d_ker_bwd_w_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p,
         ker(&p);
 }
 #define wht_blk_off(d, g, ...) \
-        (conf_.with_groups() \
+        (pd()->with_groups() \
          ? (d).blk_off((g), __VA_ARGS__) \
          : (d).blk_off(__VA_ARGS__))
 
-template <bool with_relu, data_type_t src_type, data_type_t wei_type,
+template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
+void jit_avx512_common_convolution_fwd_t<src_type, wei_type, dst_type>::
+prepare_padded_bias(const dst_data_t *&bias) const {
+    if (!pd()->wants_padded_bias()) return;
+
+    auto padded_bias = scratchpad().template get<dst_data_t>(
+            key_conv_padded_bias);
+    utils::array_copy(padded_bias, bias, pd()->jcp_.oc_without_padding);
+    utils::array_set(padded_bias + pd()->jcp_.oc_without_padding,
+            (dst_data_t)0, pd()->jcp_.oc - pd()->jcp_.oc_without_padding);
+    bias = padded_bias;
+}
+
+template <data_type_t src_type, data_type_t wei_type,
           data_type_t dst_type>
-void _jit_avx512_common_convolution_fwd_t
-    <with_relu, src_type, wei_type, dst_type>::execute_forward_1d()
+void jit_avx512_common_convolution_fwd_t
+    <src_type, wei_type, dst_type>::execute_forward_1d() const
 {
     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
     auto bias = reinterpret_cast<const dst_data_t *>(this->input_memory(2));
     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
 
-    const memory_desc_wrapper src_d(conf_.src_pd());
-    const memory_desc_wrapper dst_d(conf_.dst_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
+    prepare_padded_bias(bias);
 
-    const auto &jcp = kernel_->jcp;
+    const memory_desc_wrapper src_d(pd()->src_pd());
+    const memory_desc_wrapper dst_d(pd()->dst_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+
+    const auto &jcp = pd()->jcp_;
     assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
 
     int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
@@ -157,11 +173,6 @@ void _jit_avx512_common_convolution_fwd_t
     else
         nthr = mkldnn_get_max_threads();
 
-    if (conf_.want_padded_bias()) {
-        for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
-            padded_bias_[oc] = bias[oc];
-        bias = padded_bias_;
-    }
     parallel(nthr, [&](const int ithr, const int nthr) {
         int start{0}, end{0}, start_copy;
         balance211(work_amount, nthr, ithr, start, end);
@@ -191,7 +202,7 @@ void _jit_avx512_common_convolution_fwd_t
                 int ocb = occ * jcp.nb_oc_blocking;
                 int g_ocb = g * jcp.nb_oc + ocb;
                 int g_oc = g_ocb * jcp.oc_block;
-                int g_icb = g * jcp.nb_ic;
+                int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;
 
                 int ow_s =  owb * jcp.ow_block;
                 int iw_s =  ow_s * jcp.stride_w;
@@ -228,22 +239,24 @@ void _jit_avx512_common_convolution_fwd_t
     });
 }
 
-template <bool with_relu, data_type_t src_type, data_type_t wei_type,
+template <data_type_t src_type, data_type_t wei_type,
           data_type_t dst_type>
-void _jit_avx512_common_convolution_fwd_t
-    <with_relu, src_type, wei_type, dst_type>::execute_forward_2d()
+void jit_avx512_common_convolution_fwd_t
+    <src_type, wei_type, dst_type>::execute_forward_2d() const
 {
     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
     auto bias = reinterpret_cast<const dst_data_t *>(this->input_memory(2));
     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
 
-    const memory_desc_wrapper src_d(conf_.src_pd());
-    const memory_desc_wrapper dst_d(conf_.dst_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
+    prepare_padded_bias(bias);
 
-    const auto &jcp = kernel_->jcp;
-    const int MB = conf_.MB();
+    const memory_desc_wrapper src_d(pd()->src_pd());
+    const memory_desc_wrapper dst_d(pd()->dst_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+
+    const auto &jcp = pd()->jcp_;
+    const int MB = pd()->MB();
     assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
 
     int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
@@ -255,12 +268,6 @@ void _jit_avx512_common_convolution_fwd_t
     else
         nthr = mkldnn_get_max_threads();
 
-    if (conf_.want_padded_bias()) {
-        for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
-            padded_bias_[oc] = bias[oc];
-        bias = padded_bias_;
-    }
-
     parallel(nthr, [&](const int ithr, const int nthr) {
         int start{0}, end{0}, start_copy;
         balance211(work_amount, nthr, ithr, start, end);
@@ -290,7 +297,7 @@ void _jit_avx512_common_convolution_fwd_t
                 int ocb = occ * jcp.nb_oc_blocking;
                 int g_ocb = g * jcp.nb_oc + ocb;
                 int g_oc = g_ocb * jcp.oc_block;
-                int g_icb = g * jcp.nb_ic;
+                int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;
 
                 int work_rem = end - start;
 
@@ -357,30 +364,26 @@ void _jit_avx512_common_convolution_fwd_t
     });
 }
 
-template <bool with_relu, data_type_t src_type, data_type_t wei_type,
+template <data_type_t src_type, data_type_t wei_type,
           data_type_t dst_type>
-void _jit_avx512_common_convolution_fwd_t
-    <with_relu, src_type, wei_type, dst_type>::execute_forward_3d()
+void jit_avx512_common_convolution_fwd_t
+    <src_type, wei_type, dst_type>::execute_forward_3d() const
 {
     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
     auto bias = reinterpret_cast<const dst_data_t *>(this->input_memory(2));
     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
 
-    const memory_desc_wrapper src_d(conf_.src_pd());
-    const memory_desc_wrapper dst_d(conf_.dst_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
-    const memory_desc_wrapper bias_d(conf_.weights_pd(1));
+    prepare_padded_bias(bias);
 
-    const auto &jcp = kernel_->jcp;
-    const int MB = conf_.MB();
-    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
+    const memory_desc_wrapper src_d(pd()->src_pd());
+    const memory_desc_wrapper dst_d(pd()->dst_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+    const memory_desc_wrapper bias_d(pd()->weights_pd(1));
 
-    if (conf_.want_padded_bias()) {
-        for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
-            padded_bias_[oc] = bias[oc];
-        bias = padded_bias_;
-    }
+    const auto &jcp = pd()->jcp_;
+    const int MB = pd()->MB();
+    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
 
     parallel(0, [&](const int ithr, const int nthr) {
         int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
@@ -418,7 +421,7 @@ void _jit_avx512_common_convolution_fwd_t
                 int ocb = occ * jcp.nb_oc_blocking;
                 int g_ocb = g * jcp.nb_oc + ocb;
                 int g_oc = g_ocb * jcp.oc_block;
-                int g_icb = g * jcp.nb_ic;
+                int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;
 
                 int work_rem = end - start;
                 int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
@@ -491,25 +494,22 @@ void _jit_avx512_common_convolution_fwd_t
     });
 }
 
-template struct _jit_avx512_common_convolution_fwd_t<false, data_type::f32>;
-template struct _jit_avx512_common_convolution_fwd_t<true, data_type::f32>;
-template struct _jit_avx512_common_convolution_fwd_t<false, data_type::s16,
-        data_type::s16, data_type::s32>;
-template struct _jit_avx512_common_convolution_fwd_t<true, data_type::s16,
+template struct jit_avx512_common_convolution_fwd_t<data_type::f32>;
+template struct jit_avx512_common_convolution_fwd_t<data_type::s16,
         data_type::s16, data_type::s32>;
 
 template <data_type_t diff_dst_type, data_type_t wei_type,
           data_type_t diff_src_type>
 void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
-          diff_src_type>::execute_backward_data_1d() {
+          diff_src_type>::execute_backward_data_1d() const {
     auto diff_dst = reinterpret_cast<const diff_dst_data_t *>
                                                        (this->input_memory(0));
     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
     auto diff_src = reinterpret_cast<diff_src_data_t*>(this->memory());
 
-    const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
-    const memory_desc_wrapper diff_src_d(conf_.diff_src_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
+    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+    const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
 
     const auto &jcp = kernel_->jcp;
 
@@ -579,18 +579,18 @@ void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
 template <data_type_t diff_dst_type, data_type_t wei_type,
           data_type_t diff_src_type>
 void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
-          diff_src_type>::execute_backward_data_2d() {
+          diff_src_type>::execute_backward_data_2d() const {
     auto diff_dst = reinterpret_cast<const diff_dst_data_t *>
                                                        (this->input_memory(0));
     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
     auto diff_src = reinterpret_cast<diff_src_data_t*>(this->memory());
 
-    const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
-    const memory_desc_wrapper diff_src_d(conf_.diff_src_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
+    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+    const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
 
     const auto &jcp = kernel_->jcp;
-    const int MB = conf_.MB();
+    const int MB = pd()->MB();
 
     parallel(0, [&](const int ithr, const int nthr) {
         int start{0}, end{0}, start_copy;
@@ -704,18 +704,18 @@ void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
 template <data_type_t diff_dst_type, data_type_t wei_type,
           data_type_t diff_src_type>
 void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
-          diff_src_type>::execute_backward_data_3d() {
+          diff_src_type>::execute_backward_data_3d() const {
     auto diff_dst = reinterpret_cast<const diff_dst_data_t *>
                                                        (this->input_memory(0));
     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
     auto diff_src = reinterpret_cast<diff_src_data_t*>(this->memory());
 
-    const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
-    const memory_desc_wrapper diff_src_d(conf_.diff_src_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
+    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+    const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
 
     const auto &jcp = kernel_->jcp;
-    const int MB = conf_.MB();
+    const int MB = pd()->MB();
 
     parallel(0, [&](const int ithr, const int nthr) {
         int start{0}, end{0}, start_copy;
@@ -881,89 +881,33 @@ template <data_type_t src_type, data_type_t diff_dst_type,
           data_type_t diff_weights_type>
 jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
           diff_weights_type>::
-jit_avx512_common_convolution_bwd_weights_t(const pd_t *pd,
+jit_avx512_common_convolution_bwd_weights_t(const pd_t *apd,
         const input_vector &inputs, const output_vector &outputs)
-    : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), kernel_(nullptr)
+    : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr)
     , trans_kernel_(nullptr), trans_dst_kernel_(nullptr), acc_ker_(nullptr)
-    , reducer_bias_(nullptr), padded_bias_(nullptr), tr_src_(nullptr)
-    , tr_diff_dst_(nullptr), ws_reduction_(nullptr), tr_src_bctx_(nullptr)
-    , tr_diff_dst_bctx_(nullptr)
+    , reducer_bias_(nullptr)
 {
-    const auto &j = conf_.jcp_;
-    kernel_ = new jit_avx512_common_conv_bwd_weights_kernel_f32(j);
+    const auto &j = pd()->jcp_;
 
-    balance();
+    nthr_ = j.nthr;
+    nthr_mb_ = j.nthr_mb;
+    nthr_g_ = j.nthr_g;
+    nthr_oc_b_ = j.nthr_oc_b;
+    nthr_ic_b_ = j.nthr_ic_b;
+
+    kernel_ = new jit_avx512_common_conv_bwd_weights_kernel_f32(j);
 
     if (utils::one_of(j.ver, ver_4fma, ver_4vnni, ver_vnni)) {
         trans_kernel_ = create_trans_src(&j);
         if (utils::one_of(j.ver, ver_4vnni, ver_vnni))
             trans_dst_kernel_ = create_trans_dst(&j);
-        if (j.is_1stconv) {
-            const int tr_src_size =
-                nthr_ / nthr_oc_b_ * j.ih * j.stride_w * j.tr_ld;
-            tr_src_ = (src_data_t *)malloc(tr_src_size * sizeof(src_data_t), 64);
-        } else {
-            // XXX: See the comment about tr_iw and guarding elements in
-            // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf()
-            const int max_nthr = nthr_mb_ * j.ngroups * j.nb_ic;
-            const int min_tr_src_size_per_thr = j.ih * j.ic_block * j.tr_iw;
-            const int tr_src_size = max_nthr * min_tr_src_size_per_thr
-                + j.tr_src_num_guard_elems;
-            tr_src_ = (src_data_t *)malloc(tr_src_size * sizeof(src_data_t), 64);
-            /* to avoid NaNs in computations we zero tail num_guard_elems for
-             * each possible thread group */
-            for (int ithr = 1; ithr <= max_nthr; ++ithr) {
-                src_data_t *ts = &tr_src_[ithr * min_tr_src_size_per_thr];
-                for (int i = 0; i < j.tr_src_num_guard_elems; ++i)
-                    ts[i] = 0;
-            }
-        }
-
-        /* prepare synchronization contexts */
-        if (nthr_oc_b_ > 1) {
-            const int tr_src_bctx_size = nthr_ / nthr_oc_b_;
-            tr_src_bctx_ = (simple_barrier::ctx_t *)malloc(
-                    tr_src_bctx_size * sizeof(simple_barrier::ctx_t), 64);
-            for (int i = 0; i < tr_src_bctx_size; ++i)
-                simple_barrier::ctx_init(&tr_src_bctx_[i]);
-        }
-
-        if (utils::one_of(j.ver, ver_4vnni, ver_vnni)) {
-            const size_t tr_diff_dst_size =
-                nthr_mb_ * j.ngroups * j.nb_oc * j.oc_block * j.tr_ow * j.oh;
-            tr_diff_dst_ = (diff_dst_data_t *)malloc(
-                    tr_diff_dst_size * sizeof(diff_dst_data_t), 64);
-
-            /* prepare synchronization contexts */
-            if (nthr_ic_b_ > 1) {
-                const size_t tr_diff_dst_bctx_size = nthr_ / nthr_ic_b_;
-                tr_diff_dst_bctx_ = (simple_barrier::ctx_t *)malloc(
-                        tr_diff_dst_bctx_size * sizeof(simple_barrier::ctx_t),
-                        64);
-                for (size_t i = 0; i < tr_diff_dst_bctx_size; ++i)
-                    simple_barrier::ctx_init(&tr_diff_dst_bctx_[i]);
-            }
-        }
     }
 
-    if (nthr_mb_ > 1) {
-        const int wei_size = j.ngroups * j.oc * j.ic * j.kh * j.kw * j.kd;
-        const int bia_size = j.ngroups * j.oc;
-        ws_reduction_ = (diff_weights_data_t *)malloc((nthr_mb_ - 1)
-            * (wei_size + bia_size) * sizeof(diff_weights_data_t), 64);
+    if (nthr_mb_ > 1)
         acc_ker_ = new cpu_accumulator_1d_t<diff_weights_type>();
-        simple_barrier::ctx_init(&reduction_bctx_);
-    }
 
-    if (conf_.with_bias()) {
-        const size_t max_buffer_size = nthr_ * 3 * 5 * 5 * 16 * 16;
-        reducer_bias_ = new cpu_reducer_t<diff_weights_type>(reduce_balancer_t(
-                    nthr_, j.oc_block, j.ngroups * j.nb_oc, j.mb,
-                    max_buffer_size));
-        if (conf_.want_padded_bias())
-            padded_bias_ = (diff_weights_data_t *)
-                malloc(sizeof(diff_weights_data_t) * j.oc, 64);
-    }
+    reducer_bias_ =
+        new cpu_reducer_t<diff_weights_type>(pd()->reducer_bia_conf_);
 }
 
 template <data_type_t src_type, data_type_t diff_dst_type,
@@ -975,6 +919,17 @@ struct jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
     const diff_weights_data_t *diff_weights;
     diff_weights_data_t *diff_bias;
 
+    const memory_tracking::grantor_t scratchpad;
+
+    src_data_t *tr_src;
+    simple_barrier::ctx_t *tr_src_bctx;
+
+    diff_dst_data_t *tr_diff_dst;
+    simple_barrier::ctx_t *tr_diff_dst_bctx;
+
+    diff_weights_data_t *wei_bia_reduction;
+    simple_barrier::ctx_t *wei_bia_reduction_bctx;
+
     int ithr;
     int ithr_ic_b, ithr_oc_b, ithr_g, ithr_mb;
     int ithr_but_oc;
@@ -986,16 +941,30 @@ struct jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
     int ic_b_start = 0, ic_b_end = 0, ic_b_work;
 
     thread_info_t(const jit_avx512_common_convolution_bwd_weights_t *self,
-            int ithr): ithr(ithr) {
-
+            int ithr): scratchpad(self->scratchpad()), ithr(ithr) {
         src = reinterpret_cast<const src_data_t *>(self->input_memory(0));
         diff_dst = reinterpret_cast<const diff_dst_data_t *>(
             self->input_memory(1));
         diff_weights = reinterpret_cast<diff_weights_data_t *>(self->memory(0));
-        diff_bias = self->conf_.want_padded_bias()
-            ? self->padded_bias_
+        diff_bias = self->pd()->wants_padded_bias()
+            ? scratchpad.template get<diff_weights_data_t>(
+                    key_conv_padded_bias)
             : reinterpret_cast<diff_weights_data_t *>(self->memory(1));
 
+        tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);
+        tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
+                key_conv_tr_src_bctx);
+
+        tr_diff_dst = scratchpad.template get<diff_dst_data_t>(
+                key_conv_tr_diff_dst);
+        tr_diff_dst_bctx = scratchpad.template get<simple_barrier::ctx_t>(
+                key_conv_tr_diff_dst_bctx);
+
+        wei_bia_reduction = scratchpad.template get<diff_weights_data_t>(
+                key_conv_wei_bia_reduction);
+        wei_bia_reduction_bctx = scratchpad.template get<simple_barrier::ctx_t>(
+                key_conv_wei_bia_reduction_bctx);
+
         ithr_ic_b = ithr % self->nthr_ic_b_;
         ithr_oc_b = ithr / self->nthr_ic_b_ % self->nthr_oc_b_;
         ithr_g = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ % self->nthr_g_;
@@ -1030,20 +999,20 @@ struct jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
 template <data_type_t src_type, data_type_t diff_dst_type,
           data_type_t diff_weights_type>
 void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
-    diff_weights_type>::compute_diff_weights(const thread_info_t *ti) {
-    const memory_desc_wrapper src_d(conf_.src_pd(0));
-    const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
-    const memory_desc_wrapper diff_weights_d(conf_.diff_weights_pd(0));
+    diff_weights_type>::compute_diff_weights(const thread_info_t *ti) const {
+    const memory_desc_wrapper src_d(pd()->src_pd(0));
+    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
 
     const auto &jcp = kernel_->jcp;
     const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh*jcp.kw*jcp.kd;
 
     diff_weights_data_t *diff_wei = ti->ithr_mb == 0
         ? (diff_weights_data_t*)ti->diff_weights
-        : (diff_weights_data_t*)ws_reduction_ + (ti->ithr_mb - 1) * wei_size;
+        : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;
     diff_weights_data_t *diff_bia = ti->ithr_mb == 0
         ? (diff_weights_data_t*)ti->diff_bias
-        : (diff_weights_data_t*)ws_reduction_ + (nthr_mb_ - 1) * wei_size
+        : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size
           + (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc;
 
     // TODO: use memory descriptor with the same fmt as src (or use a macro :))
@@ -1069,7 +1038,7 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
 
         const int _ic = g * jcp.nb_ic + ic_b;
         src_data_t *src1 = (src_data_t*)&ti->src[src_d.blk_off(img, _ic, j)];
-        src_data_t *tr_src1 = &tr_src_[tr_src_off(ti->ithr_mb, _ic, j)];
+        src_data_t *tr_src1 = &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, j)];
 
         assert(jcp.ic_block == 16);
         const int src_stride = jcp.iw * jcp.ic_block;
@@ -1147,7 +1116,7 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
         const diff_dst_data_t *diff_dst1
             = &ti->diff_dst[diff_dst_d.blk_off(img, oc, j)];
         diff_dst_data_t *tr_diff_dst1
-            = &tr_diff_dst_[tr_diff_dst_off(img, oc, j)];
+            = &ti->tr_diff_dst[tr_diff_dst_off(img, oc, j)];
 
 
         assert(jcp.ic_block == 16);
@@ -1206,7 +1175,7 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
     if (jcp.is_1stconv && jcp.ver == ver_4fma) {
         /* prepare contexts */
         auto tr_ctx = jit_trans_src_t::ctx_t();
-        tr_ctx.tr_src = tr_src_
+        tr_ctx.tr_src = ti->tr_src
             + ti->ithr_but_oc * jcp.ih * jcp.stride_w * jcp.tr_ld;
 
         assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_oc_b_ == 1));
@@ -1215,7 +1184,7 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
         balance211(jcp.ih, nthr_oc_b_, ti->ithr_oc_b, ih_start, ih_end);
         tr_ctx.tr_src_ih_start = ih_start;
         tr_ctx.tr_src_ih_end = ih_end;
-        tr_ctx.tr_src_bctx = tr_src_bctx_ + ti->ithr_but_oc;
+        tr_ctx.tr_src_bctx = ti->tr_src_bctx + ti->ithr_but_oc;
 
         auto p = jit_conv_call_s();
         p.src = tr_ctx.tr_src;
@@ -1267,20 +1236,20 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
                 /* tr_src[nb_ic][ih][16][~iw~] <- src[nb_ic][ih][iw][16] */
                 using simple_barrier::barrier;
                 if (nthr_oc_b_ > 1)
-                    barrier(&tr_src_bctx_[ti->ithr_but_oc], nthr_oc_b_);
+                    barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
                 uker_trans(img);
                 if (nthr_oc_b_ > 1)
-                    barrier(&tr_src_bctx_[ti->ithr_but_oc], nthr_oc_b_);
+                    barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
             }
 
             if (utils::one_of(jcp.ver, ver_4vnni, ver_vnni)) {
                 /* tr_diff_dst[nb_oc][OW][oh][16c][2ow]
                  *  <- diff_dst[nb_oc][oh][ow][16c] */
                 if (nthr_ic_b_ > 1)
-                    barrier(&tr_diff_dst_bctx_[ti->ithr_but_ic], nthr_ic_b_);
+                    barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
                 diff_dst_trans(img);
                 if (nthr_ic_b_ > 1)
-                    barrier(&tr_diff_dst_bctx_[ti->ithr_but_ic], nthr_ic_b_);
+                    barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
             }
 
             for (int g = ti->g_start; g < ti->g_end; ++g) {
@@ -1291,10 +1260,10 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
 
                 jit_conv_ker_pipeline(kernel_->jit_ker, p,
                          (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)
-                         ? &tr_src_[tr_src_off(ti->ithr_mb, _ic, 0)]
+                         ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)]
                          : &ti->src[src_d.blk_off(img, _ic)]),
                          utils::one_of(jcp.ver, ver_4vnni, ver_vnni)
-                         ? &tr_diff_dst_[tr_diff_dst_off(ti->ithr_mb, _oc, 0)]
+                         ? &ti->tr_diff_dst[tr_diff_dst_off(ti->ithr_mb, _oc, 0)]
                          : &ti->diff_dst[diff_dst_d.blk_off(img, _oc)],
                         diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b),
                         0, (img == ti->img_start), 0, 0);
@@ -1307,10 +1276,10 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
             const int _ic = ti->g_start * jcp.nb_ic + ti->ic_b_start;
             jit_conv_ker_pipeline(kernel_->jit_ker, p,
                     (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)
-                     ? &tr_src_[tr_src_off(ti->ithr_mb, _ic, 0)]
+                     ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)]
                      : &ti->src[src_d.blk_off(img + 1, _ic)]),
                     utils::one_of(jcp.ver, ver_4vnni, ver_vnni)
-                    ? &tr_diff_dst_[tr_diff_dst_off(ti->ithr_mb, _oc, 0)]
+                    ? &ti->tr_diff_dst[tr_diff_dst_off(ti->ithr_mb, _oc, 0)]
                     : &ti->diff_dst[diff_dst_d.blk_off(img + 1, _oc)],
                     diff_wei + wht_blk_off(
                         diff_weights_d, ti->g_start,
@@ -1323,10 +1292,11 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
 template <data_type_t src_type, data_type_t diff_dst_type,
           data_type_t diff_weights_type>
 void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
-    diff_weights_type>::compute_diff_weights_3d(const thread_info_t *ti) {
-    const memory_desc_wrapper src_d(conf_.src_pd(0));
-    const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
-    const memory_desc_wrapper diff_weights_d(conf_.diff_weights_pd(0));
+    diff_weights_type>::compute_diff_weights_3d(const thread_info_t *ti) const
+{
+    const memory_desc_wrapper src_d(pd()->src_pd(0));
+    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
 
     const auto &jcp = kernel_->jcp;
     const int wei_size
@@ -1334,10 +1304,10 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
 
     diff_weights_data_t *diff_wei = ti->ithr_mb == 0
         ? (diff_weights_data_t*)ti->diff_weights
-        : (diff_weights_data_t*)ws_reduction_ + (ti->ithr_mb - 1) * wei_size;
+        : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;
     diff_weights_data_t *diff_bia = ti->ithr_mb == 0
         ? (diff_weights_data_t*)ti->diff_bias
-        : (diff_weights_data_t*)ws_reduction_ + (nthr_mb_ - 1) * wei_size
+        : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size
           + (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc;
 
     const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
@@ -1397,17 +1367,17 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
 template <data_type_t src_type, data_type_t diff_dst_type,
           data_type_t diff_weights_type>
 void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
-    diff_weights_type>::reduce_diff_weights(const thread_info_t *ti) {
-    const memory_desc_wrapper diff_weights_d(conf_.diff_weights_pd(0));
+    diff_weights_type>::reduce_diff_weights(const thread_info_t *ti) const {
+    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
 
     const auto &jcp = kernel_->jcp;
     const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw;
     const int bia_size = jcp.ngroups * jcp.oc;
     const diff_weights_data_t *diff_bias_ws
-        = ws_reduction_ + (nthr_mb_ - 1) * wei_size;
+        = ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size;
 
-    /* diff_weights[:] += sum(ws_reduction_[thr_mb][:]) */
-    simple_barrier::barrier(&reduction_bctx_, nthr_);
+    /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */
+    simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);
 
     const int ic_b_kh_work = ti->ic_b_work * jcp.kh;
     const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work;
@@ -1437,7 +1407,7 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
             diff_weights_data_t *d
                 = (diff_weights_data_t *)ti->diff_weights + off;
             diff_weights_data_t *s
-                = ws_reduction_ + (thr_mb - 1) * wei_size + off;
+                = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off;
 
             acc_ker_->accumulate(d, s, acc_size);
 
@@ -1457,15 +1427,15 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
 template <data_type_t src_type, data_type_t diff_dst_type,
           data_type_t diff_weights_type>
 void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
-    diff_weights_type>::reduce_diff_weights_3d(const thread_info_t *ti) {
-    const memory_desc_wrapper diff_weights_d(conf_.diff_weights_pd(0));
+    diff_weights_type>::reduce_diff_weights_3d(const thread_info_t *ti) const {
+    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
 
     const auto &jcp = kernel_->jcp;
     const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw
         * jcp.kd;
 
-    /* diff_weights[:] += sum(ws_reduction_[thr_mb][:]) */
-    simple_barrier::barrier(&reduction_bctx_, nthr_);
+    /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */
+    simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);
 
     const int ic_b_kh_work = ti->ic_b_work * jcp.kd;
     const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work;
@@ -1494,7 +1464,7 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
             diff_weights_data_t *d
                 = (diff_weights_data_t *)ti->diff_weights + off;
             diff_weights_data_t *s
-                = ws_reduction_ + (thr_mb - 1) * wei_size + off;
+                = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off;
             acc_ker_->accumulate(d, s, acc_size);
 
             nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start,
@@ -1506,25 +1476,28 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
 template <data_type_t src_type, data_type_t diff_dst_type,
           data_type_t diff_weights_type>
 void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
-    diff_weights_type>::compute_diff_bias(const thread_info_t *ti) {
-    const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
+    diff_weights_type>::compute_diff_bias(const thread_info_t *ti) const {
+    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
 
     auto rb = this->reducer_bias_;
-    assert(nthr_ == rb->balancer_.nthr_);
+    assert(nthr_ == rb->balancer().nthr_);
+
+    const auto reducer_bia_scratchpad = memory_tracking::grantor_t(
+            ti->scratchpad, prefix_reducer_bia);
 
     const auto &jcp = kernel_->jcp;
 
     if (jcp.with_bias && jcp.is_1stconv && jcp.ver == ver_4fma) return;
 
-    const int b_job_start = rb->balancer_.ithr_job_off(ti->ithr);
-    const int b_njobs = rb->balancer_.ithr_njobs(ti->ithr);
+    const int b_job_start = rb->balancer().ithr_job_off(ti->ithr);
+    const int b_njobs = rb->balancer().ithr_njobs(ti->ithr);
 
     if (b_njobs == 0) return;
 
     /* reduction dimension */
     int img_start{0}, img_end{0};
-    balance211(jcp.mb, rb->balancer_.nthr_per_group_,
-            rb->balancer_.id_in_group(ti->ithr), img_start, img_end);
+    balance211(jcp.mb, rb->balancer().nthr_per_group_,
+            rb->balancer().id_in_group(ti->ithr), img_start, img_end);
 
     /* jobs */
     int g_start{0}, ocb_start{0};
@@ -1536,9 +1509,9 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
 
             const diff_dst_data_t *d_dst
                 = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)];
-            diff_weights_data_t *d_bias = &rb->get_local_ptr(ti->ithr,
-                (diff_weights_data_t *)ti->diff_bias)[
-                b_job_loc * rb->balancer_.job_size_];
+            diff_weights_data_t *d_bias = rb->get_local_ptr(ti->ithr,
+                    ti->diff_bias, reducer_bia_scratchpad)
+                + b_job_loc * rb->balancer().job_size_;
 
             if (img == img_start)
                 for (int o = 0; o < 16; ++o)
@@ -1554,13 +1527,13 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
         }
     }
 
-    rb->reduce(ti->ithr, ti->diff_bias);
+    rb->reduce(ti->ithr, ti->diff_bias, reducer_bia_scratchpad);
 }
 
 template <data_type_t src_type, data_type_t diff_dst_type,
           data_type_t diff_weights_type>
 void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
-    diff_weights_type>::compute_diff_bias_3d(const thread_info_t *ti) {
+    diff_weights_type>::compute_diff_bias_3d(const thread_info_t *ti) const {
 
     const auto &jcp = kernel_->jcp;
 
@@ -1568,7 +1541,7 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
         * jcp.kw * jcp.kd;
     const int bia_size = jcp.ngroups * jcp.oc;
     const diff_weights_data_t *diff_bias_ws
-            = ws_reduction_ + (size_t)(nthr_mb_ - 1) * wei_size;
+            = ti->wei_bia_reduction + (size_t)(nthr_mb_ - 1) * wei_size;
 
     if (nthr_mb_ > 1) mkldnn_thr_barrier();
 
@@ -1584,161 +1557,91 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
 template <data_type_t src_type, data_type_t diff_dst_type,
           data_type_t diff_weights_type>
 void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
-    diff_weights_type>::execute_backward_weights() {
+    diff_weights_type>::prepare_scratchpad_data() const
+{
+    const auto &j = pd()->jcp_;
+    auto scratchpad = this->scratchpad();
+
+    if (utils::one_of(j.ver, ver_4fma, ver_4vnni, ver_vnni)) {
+        if (!j.is_1stconv) {
+            // XXX: See the comment about tr_iw and guarding elements in
+            // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf()
+            const int max_nthr = j.nthr_mb * j.ngroups * j.nb_ic;
+            const int min_tr_src_size_per_thr = j.ih * j.ic_block * j.tr_iw;
+
+            auto tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);
+            /* to avoid NaNs in computations we zero tail num_guard_elems for
+             * each possible thread group */
+
+            for (int ithr = 1; ithr <= max_nthr; ++ithr) {
+                src_data_t *ts = &tr_src[ithr * min_tr_src_size_per_thr];
+                for (int i = 0; i < j.tr_src_num_guard_elems; ++i)
+                    ts[i] = 0;
+            }
+        }
+
+        if (j.nthr_oc_b > 1) {
+            const int tr_src_bctx_size = j.nthr / j.nthr_oc_b;
+            auto tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
+                    key_conv_tr_src_bctx);
+            for (int i = 0; i < tr_src_bctx_size; ++i)
+                simple_barrier::ctx_init(&tr_src_bctx[i]);
+        }
+
+        if (utils::one_of(j.ver, ver_4vnni, ver_vnni) && j.nthr_ic_b > 1) {
+            const int tr_diff_dst_bctx_size = j.nthr / j.nthr_ic_b;
+            auto tr_diff_dst_bctx =
+                scratchpad.template get<simple_barrier::ctx_t>(
+                        key_conv_tr_diff_dst_bctx);
+                for (int i = 0; i < tr_diff_dst_bctx_size; ++i)
+                    simple_barrier::ctx_init(&tr_diff_dst_bctx[i]);
+        }
+    }
+
+    if (nthr_mb_ > 1) {
+        simple_barrier::ctx_init(scratchpad.template get<simple_barrier::ctx_t>(
+                    key_conv_wei_bia_reduction_bctx));
+    }
+
+    const auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
+            prefix_reducer_bia);
+    auto rb = this->reducer_bias_;
+    rb->init(reducer_bia_scratchpad);
+}
+
+template <data_type_t src_type, data_type_t diff_dst_type,
+          data_type_t diff_weights_type>
+void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
+    diff_weights_type>::execute_backward_weights() const {
+    prepare_scratchpad_data();
+
     parallel(nthr_, [&](const int ithr, const int nthr) {
         assert(nthr_ == nthr);
 
         thread_info_t thread_info(this, ithr);
 
-        if (utils::one_of(conf_.ndims(), 3, 4)) {
+        if (utils::one_of(pd()->ndims(), 3, 4)) {
             compute_diff_weights(&thread_info);
             if (nthr_mb_ > 1) reduce_diff_weights(&thread_info);
-            if (conf_.with_bias()) compute_diff_bias(&thread_info);
-        } else if (conf_.ndims() == 5) {
+            if (pd()->with_bias()) compute_diff_bias(&thread_info);
+        } else if (pd()->ndims() == 5) {
             compute_diff_weights_3d(&thread_info);
             if (nthr_mb_ > 1) reduce_diff_weights_3d(&thread_info);
-            if (conf_.with_bias()) compute_diff_bias_3d(&thread_info);
+            if (pd()->with_bias()) compute_diff_bias_3d(&thread_info);
         } else {
             assert(false);
         }
     });
 
     /* TODO: put that into compute_diff_bias() */
-    if (conf_.want_padded_bias()) {
+    if (pd()->wants_padded_bias()) {
+        auto diff_bias = scratchpad().template get<const diff_weights_data_t>(
+                key_conv_padded_bias);
         auto diff_bias_in
             = reinterpret_cast<diff_weights_data_t *>(this->memory(1));
-        for (int oc = 0; oc < conf_.jcp_.oc_without_padding; ++oc)
-            diff_bias_in[oc] = this->padded_bias_[oc];
-    }
-}
-
-template <data_type_t src_type, data_type_t diff_dst_type,
-          data_type_t diff_weights_type>
-void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
-    diff_weights_type>::balance() {
-    const int max_threads = mkldnn_get_max_threads();
-    const auto &j = conf_.jcp_;
-
-    nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1;
-
-    if (max_threads < j.ngroups) {
-        /* simplification... fortunately it doesn't hurt much */
-        return;
-    }
-
-    if (!mkldnn_thr_syncable()
-            && utils::one_of(j.ver, ver_4fma, ver_4vnni, ver_vnni)) {
-        // should not happen -- the driver is not ready
-        // for TBB-like non-synchronous threading yet
-        return;
+        for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc)
+            diff_bias_in[oc] = diff_bias[oc];
     }
-
-    if (j.ver == ver_4fma && j.is_1stconv) {
-        nthr_g_ = 1;
-        nthr_oc_b_ = 1;
-        nthr_ic_b_ = nstl::min(j.nb_ic, max_threads);
-        nthr_mb_ = nstl::min(max_threads / nthr_ic_b_, j.mb);
-        nthr_ = nthr_mb_ * nthr_oc_b_ * nthr_ic_b_ * nthr_g_;
-        return;
-    }
-
-    nthr_g_ = j.ngroups;
-    const int nthr = max_threads / nthr_g_;
-
-    auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
-        /* calculate per thread memory cost (read/write). high level optimizer
-         * tries to minimize memory consumption. few notes:
-         *  (n1) unclear why, but that essentially helps first convolution...
-         *  (n2) assuming the reduction over minibatch is always there:
-         *    - instead of 8 it should be 5 here (write ~= 2 read):
-         *      kernel: temporal workspace 1 write
-         *      reduction: 1 read from workspace and 1 write to the diff_wei
-         *    - but experiments showed 8 works better than 5 or 6... */
-
-        const int src_coef = j.ver == ver_4fma || j.ver == ver_vnni ? 4 : 1;
-        const int dst_coef = 1;
-        const int wei_coef = j.ver == ver_vnni ? 4 : 8;
-
-        return 0
-            + src_coef
-            * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_)
-            * div_up(j.nb_ic, nthr_ic_b) * j.ic_block * j.ih * j.iw * j.id
-            / j.stride_d / j.stride_h / j.stride_w /* (n1) */
-            + dst_coef
-            * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_)
-            * div_up(j.nb_oc, nthr_oc_b) * j.oc_block * j.oh * j.ow * j.od
-            + wei_coef /* (n2) */
-            * div_up(j.ngroups, nthr_g_)
-            * div_up(j.nb_oc, nthr_oc_b) * div_up(j.nb_ic, nthr_ic_b)
-            * j.kh * j.kw * j.kd * j.ic_block * j.oc_block;
-    };
-
-    int best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
-
-    /* step 1: find the best thread distribution with lowest memory cost */
-    const int nthr_mb_max = nstl::min(nthr, j.mb * j.od);
-    for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
-        const int nthr_par = nthr / nthr_mb;
-        const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
-        for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
-            int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
-
-            int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
-            if (mem_cost <= best_mem_cost) {
-                best_mem_cost = mem_cost;
-                nthr_mb_ = nthr_mb;
-                nthr_oc_b_ = nthr_oc_b;
-                nthr_ic_b_ = nthr_ic_b;
-            }
-        }
-
-        if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
-    }
-
-    if (j.ver != ver_vnni && !mayiuse(avx512_mic)) {
-        auto calc_comp_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
-            return 1
-                * div_up(j.mb, nthr_mb)
-                * div_up(j.ngroups, nthr_g_)
-                * div_up(j.nb_oc, nthr_oc_b)
-                * div_up(j.nb_ic, nthr_ic_b);
-        };
-
-        /* step 2: search for a thread distribution with lower compute cost.
-         * the constrains:
-         *  - memory cost cannot exceed 110% of the best found in the step 1
-         *  - unless compute cost is 133% lower than the current best case
-         * note: both constants were found empirically */
-        int best_comp_cost = calc_comp_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
-        for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
-            const int nthr_par = nthr / nthr_mb;
-            const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
-            for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
-                int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
-                int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
-                int comp_cost = calc_comp_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
-
-                const bool opt1 = comp_cost <= best_comp_cost
-                    && mem_cost < 1.1 * best_mem_cost;
-                const bool opt2 = 4 * comp_cost <= 3 * best_comp_cost;
-
-                if (opt1 || opt2) {
-                    best_comp_cost = comp_cost;
-                    nthr_mb_ = nthr_mb;
-                    nthr_oc_b_ = nthr_oc_b;
-                    nthr_ic_b_ = nthr_ic_b;
-                }
-            }
-
-            if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
-        }
-    }
-
-    if (nthr_mb_ > max_threads/2 && nthr_mb_ < max_threads)
-        nthr_mb_ = min(j.mb * j.od, max_threads);
-    nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_;
-    assert(nthr_ <= max_threads);
-    assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_mb_ == 1));
 }
 
 template struct jit_avx512_common_convolution_bwd_weights_t<data_type::f32>;