Publishing 2019 R1 content
[platform/upstream/dldt.git] inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp
index 7f00356..3206270 100644
@@ -18,6 +18,8 @@
 #include "nstl.hpp"
 #include "type_helpers.hpp"
 #include "utils.hpp"
+
+#include "cpu_barrier.hpp"
 #include "cpu_memory.hpp"
 
 #include "jit_avx512_common_conv_kernel.hpp"
@@ -30,6 +32,7 @@ namespace impl {
 namespace cpu {
 
 using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
 using namespace mkldnn::impl::utils;
 using namespace Xbyak;
 
@@ -59,32 +62,29 @@ inline void pick_loop_order(jit_conv_conf_t &jcp) {
 
 inline bool is_1stconv(const jit_conv_conf_t &jcp) {
     if (mayiuse(avx512_core) && !mayiuse(avx512_core_vnni))
-        return jcp.ic < 16;
+        return (jcp.ic < 16 && jcp.ngroups == 1);
     else
         return one_of(jcp.ic, 1, 3);
 }
-inline bool is_1D_conv(const jit_conv_conf_t &jcp) {
-    return (jcp.ih == 1 && jcp.kh == 1);
-}
-inline bool is_ow_threading_available(const jit_conv_conf_t &jcp) {
-    return (is_1D_conv(jcp) && one_of(jcp.ndims, 3, 4)
-        && !(jcp.ver == ver_fma && mayiuse(avx512_mic)));
-}
+
 inline bool is_ow_threading_on(const jit_conv_conf_t &jcp) {
     return (jcp.nb_ow > 1);
 }
-inline bool is_1D_prefetching(const jit_conv_conf_t &jcp) {
-    return (jcp.ver == ver_4fma && is_1D_conv(jcp) && is_ow_threading_on(jcp));
+
+inline bool is_owb_prefetching(const jit_conv_conf_t &jcp) {
+    return (jcp.ver == ver_4fma && is_ow_threading_on(jcp));
 }
+
 }
 
-void jit_avx512_common_conv_fwd_kernel::prepare_output(int ur_w)
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::prepare_output(int ur_w)
 {
     for (int k = 0; k < jcp.nb_oc_blocking; k++)
         for (int j = 0; j < ur_w; j++) {
-            Zmm zmm = zmm_out(j, k);
-            vpxord(zmm, zmm, zmm);
-            if (!is_1D_prefetching(jcp)) {
+            Vmm vmm = vmm_out(j, k);
+            vpxord(vmm, vmm, vmm);
+            if (!is_owb_prefetching(jcp)) {
                 size_t aux_output_offset = get_output_offset(j, k);
                 mic_prefetcht1(EVEX_compress_addr_safe(reg_out_prf,
                             aux_output_offset, reg_out_long_offt));
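
Note: the hunk above turns prepare_output() into a member of the class template _jit_avx512_common_conv_fwd_kernel<Vmm>, so one generator body can be instantiated for 16-lane Zmm registers or, with the jcp.simd_w == 4 path added later in init_conf(), for 4-lane Xmm registers. A minimal, hedged sketch of that pattern; the kernel_t/Xmm/Zmm names below are illustrative stand-ins, not the mkl-dnn or Xbyak types:

    #include <cstdio>

    // Illustrative stand-ins for Xbyak's register classes; only the lane count matters here.
    struct Xmm { static constexpr int lanes = 4;  };
    struct Zmm { static constexpr int lanes = 16; };

    // One generator body, parameterized on the vector register type, covers both
    // the 16-lane (nChw16c) and the 4-lane (nChw4c) code paths.
    template <typename Vmm>
    struct kernel_t {
        void emit_prepare_output(int ur_w, int nb_oc_blocking) const {
            for (int k = 0; k < nb_oc_blocking; k++)
                for (int j = 0; j < ur_w; j++)
                    std::printf("vpxord v%d, v%d, v%d   ; %d-lane register\n",
                            j + k * ur_w, j + k * ur_w, j + k * ur_w, Vmm::lanes);
        }
    };

    int main() {
        kernel_t<Zmm>().emit_prepare_output(/*ur_w=*/2, /*nb_oc_blocking=*/1);
        kernel_t<Xmm>().emit_prepare_output(2, 1);
        return 0;
    }
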
@@ -92,7 +92,8 @@ void jit_avx512_common_conv_fwd_kernel::prepare_output(int ur_w)
         }
 }
 
-void jit_avx512_common_conv_fwd_kernel::store_output(int ur_w)
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::store_output(int ur_w)
 {
     Label no_update_label, store_label, postproc_label;
 
@@ -108,9 +109,9 @@ void jit_avx512_common_conv_fwd_kernel::store_output(int ur_w)
 
     for (int k = 0; k < jcp.nb_oc_blocking; k++)
         for (int j = 0; j < ur_w; j++) {
-            Zmm zmm = zmm_out(j, k);
+            Vmm vmm = vmm_out(j, k);
             size_t aux_output_offset = get_output_offset(j, k);
-            vadd(zmm,
+            vadd(vmm,
                 make_safe_addr(reg_out, aux_output_offset, reg_out_long_offt));
         }
 
@@ -126,8 +127,8 @@ void jit_avx512_common_conv_fwd_kernel::store_output(int ur_w)
         for (int k = 0; k < jcp.nb_oc_blocking; k++) {
             int bias_offset = jcp.typesize_out * k * jcp.oc_block;
             for (int j = 0; j < ur_w; j++) {
-                Zmm zmm = zmm_out(j, k);
-                vadd(zmm, EVEX_compress_addr(reg_bias, bias_offset));
+                Vmm vmm = vmm_out(j, k);
+                vadd(vmm, EVEX_compress_addr(reg_bias, bias_offset));
             }
             mic_prefetcht1(EVEX_compress_addr(reg_bias, bias_offset + 64));
         }
@@ -142,18 +143,29 @@ void jit_avx512_common_conv_fwd_kernel::store_output(int ur_w)
     int depthwise_inj_idx = 0;
     const auto &p = attr_.post_ops_;
 
-    if (p.len_ == 0 && eltwise_injectors.size() == 1) {
-        for (int k = 0; k < jcp.nb_oc_blocking; k++)
-            eltwise_injectors[0]->compute_vector_range(
-                    k*jcp.ur_w, k*jcp.ur_w + ur_w);
-    }
-
     for (int i = 0; i < p.len_; i++) {
         auto& post_op = p.entry_[i];
         if (post_op.is_eltwise()) {
-            for (int k = 0; k < jcp.nb_oc_blocking; k++)
-                eltwise_injectors[eltwise_inj_idx]->compute_vector_range(
-                        k*jcp.ur_w, k*jcp.ur_w + ur_w);
+            if (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) {
+                Vmm vmm_zero = vmm_wei;
+                vpxord(vmm_zero, vmm_zero, vmm_zero);
+
+                for (int k = 0; k < jcp.nb_oc_blocking; k++)
+                    for (int j = 0; j < ur_w; j++) {
+                        Vmm vmm = vmm_out(j, k);
+                        vpcmpd(k1, vmm, vmm_zero, _cmp_lt_os);
+                        vpmulld(vmm | k1, vmm, vmm_zero);
+                    }
+            } else {
+                if (ur_w == jcp.ur_w) {
+                    eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0,
+                                                            jcp.nb_oc_blocking * jcp.ur_w);
+                } else {
+                    for (int k = 0; k < jcp.nb_oc_blocking; k++)
+                        eltwise_injectors[eltwise_inj_idx]->compute_vector_range(k * jcp.ur_w,
+                                                                                 k * jcp.ur_w + ur_w);
+                }
+            }
 
             eltwise_inj_idx++;
         } else if (post_op.is_depthwise()) {
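
Note: for the int16/VNNI code paths (ver_4vnni, ver_vnni) the eltwise post-op is applied directly to the int32 accumulators with a compare mask and a masked multiply against a zero vector, i.e. a ReLU, instead of going through the float eltwise injector. A scalar sketch of what the vpcmpd/vpmulld pair above computes per lane (not the JIT code itself):

    #include <cstdint>
    #include <cstdio>

    // vpcmpd k1, v, zero, _cmp_lt_os  -> mask of lanes where v < 0
    // vpmulld v{k1}, v, zero          -> selected lanes become v * 0, i.e. ReLU on int32
    static void relu_s32(int32_t *acc, int n) {
        for (int i = 0; i < n; i++)
            if (acc[i] < 0)   // lane selected by the k1 mask
                acc[i] *= 0;  // masked multiply by the zero vector
    }

    int main() {
        int32_t acc[4] = { -3, 7, 0, -1 };
        relu_s32(acc, 4);
        for (int i = 0; i < 4; i++) std::printf("%d ", acc[i]);  // prints: 0 7 0 0
        std::printf("\n");
        return 0;
    }
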
@@ -178,18 +190,25 @@ void jit_avx512_common_conv_fwd_kernel::store_output(int ur_w)
     L(store_label);
     for (int k = 0; k < jcp.nb_oc_blocking; k++)
         for (int j = 0; j < ur_w; j++) {
-            Zmm zmm = zmm_out(j, k);
+            Vmm vmm = vmm_out(j, k);
             size_t aux_output_offset = (size_t)typesize *
                 ((size_t)k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block;
             vmovups(EVEX_compress_addr_safe(reg_out, aux_output_offset,
-                        reg_out_long_offt), zmm);
-            if (!is_1D_prefetching(jcp))
+                        reg_out_long_offt), vmm);
+            if (!is_owb_prefetching(jcp))
                 mic_prefetcht0(EVEX_compress_addr_safe(reg_out_prf,
                             aux_output_offset, reg_out_long_offt));
         }
 }
 
-void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w,
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_4fma_1st(int ur_w,
+    int pad_l, int pad_r)
+{
+}
+
+template<>
+void _jit_avx512_common_conv_fwd_kernel<Zmm>::compute_loop_4fma_1st(int ur_w,
         int pad_l, int pad_r)
 {
     assert(jcp.dilate_d == 0 && jcp.dilate_h == 0 && jcp.dilate_w == 0);
@@ -201,9 +220,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w,
     int ic_block = jcp.ic_block;
     int oc_block = jcp.oc_block;
 
-    Label kh_label, kd_label, skip_kd_loop;
-
-    prepare_output(ur_w);
+    Label kh_label, kd_label;
 
     if (one_of(jcp.ndims, 3, 4)) {
         mov(aux_reg_inp, reg_inp);
@@ -226,18 +243,9 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w,
         mov(aux_reg_inp_d, reg_inp);
         mov(aux_reg_inp_d_prf, reg_inp_prf);
 
-        if ((jcp.kd - 1) < nstl::max(jcp.f_pad, jcp.back_pad)) {
-            cmp(reg_ki, 0);
-            je(skip_kd_loop, T_NEAR);
-        }
         L(kd_label);
     }
     mov(reg_kj, reg_kh);
-    Label skip_kh_loop;
-    if ((jcp.kh - 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
-        cmp(reg_kj, 0);
-        je(skip_kh_loop, T_NEAR);
-    }
     if (jcp.ndims == 5) {
         mov(aux_reg_inp, aux_reg_inp_d);
         mov(aux_reg_ker, aux_reg_ker_d);
@@ -253,10 +261,10 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w,
                         * ((ki + i) * oc_block
                                   + ic * kw * jcp.kh * jcp.kd * oc_block);
                 if (ki + i < kw)
-                    vmovups(zmm_ker(i),
+                    vmovups(vmm_ker(i),
                         EVEX_compress_addr(aux_reg_ker, aux_ker_offset));
                 else
-                    vpxord(zmm_ker(i), zmm_ker(i), zmm_ker(i));
+                    vpxord(vmm_ker(i), vmm_ker(i), vmm_ker(i));
             }
 
             int j_start = get_ow_start(ki, pad_l);
@@ -266,7 +274,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w,
                 size_t aux_input_offset = (size_t)jcp.typesize_in
                         * ((size_t)(ki + j * stride_w
                             - pad_l) + (size_t)ic * iw * ih * jcp.id);
-                v4fmaddps(zmm_out(j, 0), zmm_ker(0),
+                v4fmaddps(vmm_out(j, 0), vmm_ker(0),
                         EVEX_compress_addr_safe(aux_reg_inp, aux_input_offset,
                         reg_long_offt));
                 if (ki + prf_count < kw && prf_count < 4
@@ -299,8 +307,6 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w,
     cmp(reg_kj, 0);
     jg(kh_label, T_NEAR);
 
-    L(skip_kh_loop);
-
     if (jcp.ndims == 5) {
         add(aux_reg_inp_d, typesize * jcp.ih * jcp.iw);
         add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block);
@@ -309,23 +315,28 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w,
         dec(reg_ki);
         cmp(reg_ki, 0);
         jg(kd_label, T_NEAR);
-        L(skip_kd_loop);
 
         pop(reg_out);
         pop(reg_out_prf);
     }
 
-    store_output(ur_w);
     if (max_input_offset > INT_MAX) pop(reg_inp_prf);
 }
 
-void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w,
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_4fma(int ur_w,
+    int pad_l, int pad_r)
+{
+}
+
+template<>
+void _jit_avx512_common_conv_fwd_kernel<Zmm>::compute_loop_4fma(int ur_w,
         int pad_l, int pad_r)
 {
     int stride_w = jcp.stride_w;
     int ic_block = jcp.ic_block;
     int oc_block = jcp.oc_block;
-    Label kh_label, last_iter_label, loop_end_label, kd_label, skip_kd_loop;
+    Label kh_label, last_iter_label, loop_end_label, kd_label;
     int ker_load_number = 4;
     int shift_kernel_ptr = typesize * jcp.kw * jcp.oc_block * jcp.ic_block;
     int shift_input_ptr = typesize * (jcp.dilate_h + 1) * jcp.iw * jcp.ic_block;
@@ -347,7 +358,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w,
     auto kernel_loads = [=](int ki, int ic, int kk) {
         for (int ii = 0; ii < ker_load_number; ii++) {
             int aux_kernel_offset = kernel_offset(kk, ic + ii, ki);
-            vmovups(zmm_ker(ii),
+            vmovups(vmm_ker(ii),
                 EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
         }
     };
@@ -364,8 +375,6 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w,
         }
     };
 
-    prepare_output(ur_w);
-
     if (one_of(jcp.ndims, 3, 4)) {
         mov(aux_reg_inp, reg_inp);
         mov(aux_reg_ker, reg_ker);
@@ -382,21 +391,11 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w,
         mov(aux_reg_inp_d, reg_inp);
         mov(aux_reg_inp_d_prf, reg_inp_prf);
         mov(aux_reg_ker_d_prf, reg_ker_prf);
-
-        if ((jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) {
-            cmp(reg_ki, 0);
-            je(skip_kd_loop, T_NEAR);
-        }
         L(kd_label);
         mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
     } else {
         mov(reg_kj, reg_kh);
     }
-    Label skip_kh_loop;
-    if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
-        cmp(reg_kj, 0);
-        je(skip_kh_loop, T_NEAR);
-    }
     if (jcp.ndims == 5) {
         mov(aux_reg_inp, aux_reg_inp_d);
         mov(aux_reg_ker, aux_reg_ker_d);
@@ -427,7 +426,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w,
                                 * ((ki * (jcp.dilate_w + 1) + oi * stride_w
                                            - pad_l) * ic_block
                                                        + ic);
-                        v4fmaddps(zmm_out(oi, kk), zmm_ker(0),
+                        v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
                             EVEX_compress_addr(aux_reg_inp, aux_input_offset));
 
                         if (oi % 2) {
@@ -468,7 +467,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w,
                                     * ((ki * (jcp.dilate_w + 1) + oi * stride_w
                                                - pad_l) * ic_block
                                                            + ic);
-                            v4fmaddps(zmm_out(oi, kk), zmm_ker(0),
+                            v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
                                 EVEX_compress_addr(aux_reg_inp,
                                     aux_input_offset));
                             if (oi % 2) {
@@ -499,11 +498,11 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w,
                         int aux_input_offset = typesize
                                 * ((ki * (jcp.dilate_w + 1) + oi * stride_w
                                 - pad_l) * ic_block + ic);
-                        v4fmaddps(zmm_out(oi, kk), zmm_ker(0),
+                        v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
                             EVEX_compress_addr(aux_reg_inp,
                                 aux_input_offset));
 
-                        if (!is_1D_prefetching(jcp)) {
+                        if (!is_owb_prefetching(jcp)) {
                             if ((oi % 2) && (prf_count_t1 < 4)) {
                                 mic_prefetcht1(EVEX_compress_addr(
                                     aux_reg_ker_prf, kernel_offset(kk,
@@ -521,7 +520,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w,
                                 prf_count_t0++;
                             }
                         }
-                        if (!is_1D_prefetching(jcp)) {
+                        if (!is_owb_prefetching(jcp)) {
                             if (pref_current_inp) {
                                 if (ki == 0 && ic == 0 && kk == 0)
                                     mic_prefetcht0(EVEX_compress_addr(
@@ -560,8 +559,6 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w,
     cmp(reg_kj, 0);
     jg(kh_label, T_NEAR);
 
-    L(skip_kh_loop);
-
     if (jcp.ndims == 5) {
         add(aux_reg_inp_d,
                 typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block);
@@ -575,16 +572,14 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w,
         dec(reg_ki);
         cmp(reg_ki, 0);
         jg(kd_label, T_NEAR);
-        L(skip_kd_loop);
 
         pop(reg_out);
         pop(reg_out_prf);
     }
-
-    store_output(ur_w);
 }
 
-void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w,
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_fma(int ur_w,
         int pad_l, int pad_r)
 {
     bool prf_ker = true;
@@ -597,20 +592,19 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w,
     int ic_block = jcp.ic_block;
     int oc_block = jcp.oc_block;
     int nb_oc_block = jcp.nb_oc_blocking;
-    Label kh_label, kd_label, skip_kd_loop;
+    Label kh_label, kd_label;
 
     int ker_pipeline_depth = 4;
     assert(ker_reg_base_idx + ker_pipeline_depth <= 32);
     assert(oc_block >= ker_pipeline_depth);
 
     int num_ker_loads = ic_block * nb_oc_block * kw;
-    const int simd_w = 16;
     int num_ker_prfs = prf_ker ? num_ker_loads : 0;
     int num_inp_prfs = prf_inp ?
             ur_w * nstl::min(kw, stride_w) + nstl::max(0, kw - stride_w) :
             0;
     if (jcp.is_1stconv && prf_inp) {
-        num_inp_prfs = div_up(num_inp_prfs, simd_w) * ic_block;
+        num_inp_prfs = div_up(num_inp_prfs, jcp.simd_w) * ic_block;
     }
     int num_prfs = num_ker_prfs + num_inp_prfs;
     int num_fmas = num_ker_loads * ur_w;
@@ -619,8 +613,6 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w,
     int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2;
     int inp_mul = !jcp.is_1stconv ? ic_block : 1;
 
-    prepare_output(ur_w);
-
     if (one_of(jcp.ndims, 3, 4)) {
         mov(aux_reg_inp, reg_inp);
         mov(aux_reg_ker, reg_ker);
@@ -643,20 +635,11 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w,
         mov(aux_reg_inp_d_prf, reg_inp_prf);
         mov(aux_reg_ker_d_prf, reg_ker_prf);
 
-        if ((jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) {
-            cmp(reg_ki, 0);
-            je(skip_kd_loop, T_NEAR);
-        }
         L(kd_label);
         mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
     } else {
         mov(reg_kj, reg_kh);
     }
-    Label skip_kh_loop;
-    if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
-        cmp(reg_kj, 0);
-        je(skip_kh_loop, T_NEAR);
-    }
 
     if (jcp.ndims == 5) {
         mov(aux_reg_inp, aux_reg_inp_d);
@@ -676,7 +659,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w,
                 if (step == 0) {
                     for (int i = 0; i < ker_pipeline_depth; i++) {
                         aux_kernel_offset = get_kernel_offset(ki, ic, 0, i);
-                        vmovups(zmm_ker(i), EVEX_compress_addr(
+                        vmovups(vmm_ker(i), EVEX_compress_addr(
                                         aux_reg_ker, aux_kernel_offset));
                     }
                 } else if (step < num_ker_loads - ker_pipeline_depth + 1) {
@@ -685,19 +668,19 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w,
                         = (step + load_offset) % ker_pipeline_depth;
                     aux_kernel_offset
                             = get_kernel_offset(ki, ic, 0, load_offset);
-                    vmovups(zmm_ker(ker_load_reg_idx),
+                    vmovups(vmm_ker(ker_load_reg_idx),
                             EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
                 }
 
                 bool ker_prf_inserted = false;
-                Zmm zmm_kernel = zmm_ker(step % ker_pipeline_depth);
+                Vmm vmm_kernel = vmm_ker(step % ker_pipeline_depth);
                 int j_start = get_ow_start(ki, pad_l);
                 int j_end = get_ow_end(ur_w, ki, pad_r);
                 for (int j = j_start; j < j_end; j++) {
                     size_t aux_input_offset = get_input_offset(ki, ic, j, pad_l);
                     auto addr = EVEX_compress_addr_safe(aux_reg_inp,
                             aux_input_offset, reg_long_offt, true);
-                    vfmadd231ps(zmm_out(j, 0), zmm_kernel, addr);
+                    vfmadd231ps(vmm_out(j, 0), vmm_kernel, addr);
                     int fma_idx = step * ur_w + j;
                     int prf_slot_idx = fma_idx / prf_inst_spacing;
                     if (fma_idx % prf_inst_spacing == prf_inst_trigger) {
@@ -724,7 +707,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w,
                                     size_t ic_prf_stride =
                                         (size_t)jcp.typesize_in * iw * ih * id;
                                     size_t iw_prf_stride
-                                            = jcp.typesize_in * simd_w;
+                                            = jcp.typesize_in * jcp.simd_w;
                                     inp_prf_offset = ((inp_prf_idx / ic_block)
                                             * iw_prf_stride
                                             + (inp_prf_idx % ic_block)
@@ -752,7 +735,6 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w,
         jg(kh_label, T_NEAR);
     }
 
-    L(skip_kh_loop);
 
     if (jcp.ndims == 5) {
         add(aux_reg_inp_d,
@@ -767,16 +749,15 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w,
         dec(reg_ki);
         cmp(reg_ki, 0);
         jg(kd_label, T_NEAR);
-        L(skip_kd_loop);
 
         pop(reg_out);
         pop(reg_out_prf);
     }
     if (max_input_offset > INT_MAX) pop(reg_inp_prf);
-    store_output(ur_w);
 }
 
-void jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core(int ur_w,
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_fma_core(int ur_w,
     int pad_l, int pad_r)
 {
     int kw = jcp.kw;
@@ -784,7 +765,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core(int ur_w,
     int ic_block = jcp.ic_block;
     int oc_block = jcp.oc_block;
     int nb_oc_block = jcp.nb_oc_blocking;
-    Label kh_label, skip_kh_loop, kd_label, skip_kd_loop;
+    Label kh_label, kd_label;
     int shift_kernel_ptr = jcp.typesize_in * jcp.kw * jcp.oc_block
         * jcp.ic_block;
     int inp_mul = !jcp.is_1stconv ? ic_block : 1;
@@ -799,8 +780,6 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core(int ur_w,
                 * (!jcp.is_1stconv ? 1 : (size_t)jcp.iw * jcp.ih * jcp.id));
     };
 
-    prepare_output(ur_w);
-
     if (one_of(jcp.ndims, 3, 4)) {
         mov(aux_reg_inp, reg_inp);
         mov(aux_reg_ker, reg_ker);
@@ -813,19 +792,11 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core(int ur_w,
         mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
         mov(aux_reg_inp_d, reg_inp);
 
-        if ((jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) {
-            cmp(reg_ki, 0);
-            je(skip_kd_loop, T_NEAR);
-        }
         L(kd_label);
         mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
     } else {
         mov(reg_kj, reg_kh);
     }
-    if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
-        cmp(reg_kj, 0);
-        je(skip_kh_loop, T_NEAR);
-    }
 
     if (jcp.ndims == 5) {
         mov(aux_reg_inp, aux_reg_inp_d);
@@ -841,7 +812,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core(int ur_w,
                 if (jcp.kernel_kind == expl_bcast) {
                     for (int jj = jj_start; jj < jj_end; jj++) {
                         size_t aux_input_offset = input_offset(jj, ic, ki);
-                        vbroadcastss(zmm_inp(jj, nb_oc_block),
+                        vbroadcastss(vmm_inp(jj, nb_oc_block),
                             EVEX_compress_addr_safe(aux_reg_inp,
                             aux_input_offset, reg_long_offt));
                     }
@@ -851,15 +822,15 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core(int ur_w,
                         * (ii * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd * ic_block
                         * oc_block + ki * ic_block * oc_block + ic * oc_block);
                     if (jj_end - jj_start > 0)
-                        vmovups(zmm_wei, EVEX_compress_addr(aux_reg_ker,
+                        vmovups(vmm_wei, EVEX_compress_addr(aux_reg_ker,
                             aux_kernel_offset));
                     for (int jj = jj_start; jj < jj_end; jj++)
                         if (jcp.kernel_kind == expl_bcast)
-                            vfmadd231ps(zmm_out(jj, ii),
-                                zmm_inp(jj, nb_oc_block), zmm_wei);
+                            vfmadd231ps(vmm_out(jj, ii),
+                                vmm_inp(jj, nb_oc_block), vmm_wei);
                         else {
                             size_t aux_input_offset = input_offset(jj, ic, ki);
-                            vfmadd231ps(zmm_out(jj, ii), zmm_wei,
+                            vfmadd231ps(vmm_out(jj, ii), vmm_wei,
                                 EVEX_compress_addr_safe(aux_reg_inp,
                                 aux_input_offset, reg_long_offt, true));
                         }
@@ -872,7 +843,6 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core(int ur_w,
         cmp(reg_kj, 0);
         jg(kh_label, T_NEAR);
     }
-    L(skip_kh_loop);
 
     if (jcp.ndims == 5) {
         add(aux_reg_inp_d,
@@ -883,15 +853,19 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core(int ur_w,
         dec(reg_ki);
         cmp(reg_ki, 0);
         jg(kd_label, T_NEAR);
-        L(skip_kd_loop);
 
         pop(reg_out);
     }
+}
 
-    store_output(ur_w);
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_vnni(
+    int ur_w, int pad_l, int pad_r)
+{
 }
 
-void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni(
+template<>
+void _jit_avx512_common_conv_fwd_kernel<Zmm>::compute_loop_vnni(
         int ur_w, int pad_l, int pad_r)
 {
     Label kh_label, kd_label;
@@ -908,7 +882,6 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni(
     assert(reg_inp_prf == reg_long_offt);
     if (max_input_offset > INT_MAX) push(reg_inp_prf);
 
-    prepare_output(ur_w);
 
     if (one_of(jcp.ndims, 3, 4)) {
         mov(aux_reg_inp, reg_inp);
@@ -917,8 +890,6 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni(
         mov(aux_reg_inp_prf, reg_inp_prf);
     }
 
-    Label skip_kh_loop, skip_kd_loop;
-
     if (jcp.ndims == 5) {
         push(reg_out_prf);
         push(reg_out);
@@ -929,19 +900,11 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni(
         mov(aux_reg_inp_d_prf, reg_inp_prf);
         mov(aux_reg_ker_d_prf, reg_ker_prf);
 
-        if ((jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) {
-            cmp(reg_ki, 0);
-            je(skip_kd_loop, T_NEAR);
-        }
         L(kd_label);
         mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
     } else {
         mov(reg_kj, reg_kh);
     }
-    if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
-        cmp(reg_kj, 0);
-        je(skip_kh_loop, T_NEAR);
-    }
     if (jcp.ndims == 5) {
         mov(aux_reg_inp, aux_reg_inp_d);
         mov(aux_reg_ker, aux_reg_ker_d);
@@ -957,7 +920,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni(
                 if (jcp.kernel_kind == expl_bcast) {
                     for (int oi = ow_start; oi < ow_end; oi++) {
                         size_t input_offset = get_input_offset(ki, ic, oi, pad_l);
-                        vpbroadcastd(zmm_inp(oi, jcp.nb_oc_blocking),
+                        vpbroadcastd(vmm_inp(oi, jcp.nb_oc_blocking),
                             EVEX_compress_addr_safe(aux_reg_inp, input_offset,
                             reg_long_offt));
                     }
@@ -965,7 +928,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni(
                 for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) {
                     if (jcp.kernel_kind == expl_bcast) {
                         int kernel_offset = get_kernel_offset(ki, ic, kk, 0);
-                        vmovups(zmm_wei,
+                        vmovups(vmm_wei,
                             EVEX_compress_addr(aux_reg_ker, kernel_offset));
                     } else {
                         for (int ii = 0; ii < ker_load_number; ii++) {
@@ -979,12 +942,17 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni(
                     for (int oi = ow_start, prf_count = 0; oi < ow_end; oi++) {
                         size_t input_offset = get_input_offset(ki, ic, oi, pad_l);
                         if (jcp.kernel_kind == expl_bcast) {
-                            vpdpwssd(zmm_out(oi, kk), zmm_wei,
-                                zmm_inp(oi, jcp.nb_oc_blocking));
+                            vpdpwssd(vmm_out(oi, kk), vmm_wei,
+                                vmm_inp(oi, jcp.nb_oc_blocking));
                         } else {
-                            vpXdpwssd(zmm_out(oi, kk), Zmm(ker_reg_base_idx),
-                            EVEX_compress_addr_safe(aux_reg_inp, input_offset,
-                            reg_long_offt, jcp.ver != ver_4vnni));
+                            if (jcp.ver == ver_4vnni)
+                                vp4dpwssd(vmm_out(oi, kk), Zmm(ker_reg_base_idx),
+                                EVEX_compress_addr_safe(aux_reg_inp,
+                                    input_offset, reg_long_offt, false));
+                            else
+                                vpdpwssd(vmm_out(oi, kk), Zmm(ker_reg_base_idx),
+                                EVEX_compress_addr_safe(aux_reg_inp,
+                                    input_offset, reg_long_offt, true));
                         }
                         if ((oi % 2) && (prf_count < ker_load_number)) {
                             int kernel_offset = get_kernel_offset(
@@ -1014,8 +982,6 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni(
         jg(kh_label, T_NEAR);
     }
 
-    L(skip_kh_loop);
-
     if (jcp.ndims == 5) {
         add(aux_reg_inp_d, jcp.typesize_in * jcp.ih * jcp.iw * jcp.ic_block);
         add(aux_reg_ker_d, jcp.typesize_in * jcp.kw * jcp.kh * jcp.oc_block
@@ -1027,19 +993,37 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni(
         dec(reg_ki);
         cmp(reg_ki, 0);
         jg(kd_label, T_NEAR);
-        L(skip_kd_loop);
 
         pop(reg_out);
         pop(reg_out_prf);
     }
     if (max_input_offset > INT_MAX) pop(reg_inp_prf);
-    store_output(ur_w);
 }
 
-void jit_avx512_common_conv_fwd_kernel::compute_loop(int ur_w,
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop(int ur_w,
         int pad_l, int pad_r)
 {
     if (jcp.ndims == 5) push(reg_oi);
+
+    prepare_output(ur_w);
+
+    Label skip_compute_loop;
+    if (jcp.ndims == 5) {
+        if ((jcp.dilate_d >= jcp.id)
+                || (jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) {
+            mov(reg_kj, ptr[param1 + GET_OFF(kd_padding)]);
+            cmp(reg_kj, 0);
+            je(skip_compute_loop, T_NEAR);
+        }
+    }
+    if ((jcp.dilate_h >= jcp.ih)
+            || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
+        mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
+        cmp(reg_kj, 0);
+        je(skip_compute_loop, T_NEAR);
+    }
+
     if (jcp.ver == ver_4vnni || jcp.ver == ver_vnni)
         compute_loop_vnni(ur_w, pad_l, pad_r);
     else if (jcp.ver == ver_4fma)
@@ -1058,17 +1042,15 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop(int ur_w,
                 compute_loop_fma_core(ur_w, pad_l, pad_r);
     else
         assert(!"unknown convolution version");
+
+    L(skip_compute_loop);
+    store_output(ur_w);
     if (jcp.ndims == 5) pop(reg_oi);
 }
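
Note: the padding checks that each compute_loop_* variant used to carry (the skip_kh_loop / skip_kd_loop labels, plus their own prepare_output()/store_output() calls) are hoisted into compute_loop() above: the accumulators are zeroed first, the whole compute body is skipped when the runtime kh/kd padding count is zero, and store_output() always runs, so fully padded rows still produce results. A hedged control-flow sketch with illustrative stub names; in the JIT code the runtime test is only emitted when the kernel/dilation/padding geometry makes a zero count possible:

    #include <cstdio>

    // Stubs standing in for the real kernel stages; names are illustrative only.
    static void prepare_output()   { std::printf("zero accumulators\n"); }
    static void run_compute_body() { std::printf("fma/4fma/vnni inner loops\n"); }
    static void store_output()     { std::printf("write (possibly zero) results\n"); }

    static void compute_loop_sketch(int kd_padding, int kh_padding, bool is_3d) {
        prepare_output();                                   // always
        const bool skip = (is_3d && kd_padding == 0) || kh_padding == 0;
        if (!skip)
            run_compute_body();                             // variant selected by jcp.ver
        store_output();                                     // always
    }

    int main() {
        compute_loop_sketch(/*kd_padding=*/0, /*kh_padding=*/3, /*is_3d=*/true);  // body skipped
        compute_loop_sketch(1, 3, true);                                          // body executed
        return 0;
    }
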
 
-void jit_avx512_common_conv_fwd_kernel::generate()
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::generate()
 {
-    if (jcp.with_eltwise) {
-        eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx512_common>(
-                this, jcp.eltwise_alg, jcp.eltwise_alpha, 0
-        ));
-    }
-
     const auto &p = attr_.post_ops_;
     for (int i = 0; i < p.len_; i++) {
         auto &post_op = p.entry_[i];
@@ -1318,17 +1300,10 @@ bool jit_avx512_common_conv_fwd_kernel::post_ops_ok(
     auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
 
     switch (p.len_) {
-    case 0: return true; // no post_ops
-    case 1:
-        return true // sum OR eltwise OR depthwise
-                && !jcp.with_eltwise && (is_simple(0) || is_sum(0));
-    case 2:
-        return true // sum->relu
-                && !jcp.with_eltwise && ((is_sum(0) && is_simple(1)) ||
-                                         (is_simple(0) && is_simple(1)));
-    case 3:
-        return true // sum->relu
-                && !jcp.with_eltwise && (is_sum(0) && is_simple(1) && is_simple(2));
+    case 0: return true;
+    case 1: return is_simple(0) || is_sum(0);
+    case 2: return (is_sum(0) && is_simple(1)) || (is_simple(0) && is_simple(1));
+    case 3: return is_sum(0) && is_simple(1) && is_simple(2);
     default: return false;
     }
 
@@ -1336,25 +1311,22 @@ bool jit_avx512_common_conv_fwd_kernel::post_ops_ok(
 }
 
 status_t jit_avx512_common_conv_fwd_kernel::init_conf(
-            jit_conv_conf_t &jcp,
-            const convolution_desc_t &cd, cpu_memory_t::pd_t &src_pd,
-            cpu_memory_t::pd_t &weights_pd, cpu_memory_t::pd_t &dst_pd,
-            cpu_memory_t::pd_t &bias_pd, const primitive_attr_t &attr,
-            int nthreads, bool with_relu, float relu_negative_slope)
+            jit_conv_conf_t &jcp, const convolution_desc_t &cd,
+            cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &weights_pd,
+            cpu_memory_t::pd_t &dst_pd, cpu_memory_t::pd_t &bias_pd,
+            const primitive_attr_t &attr, int nthreads)
 {
     using namespace prop_kind;
 
     if (!mayiuse(avx512_common))
         return status::unimplemented;
 
-    const int simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
-
     const memory_desc_wrapper src_d(&src_pd);
     const memory_desc_wrapper weights_d(&weights_pd);
     const memory_desc_wrapper dst_d(&dst_pd);
     const memory_desc_wrapper bias_d(&bias_pd);
 
-    int regs = 28;
+    const int regs = 28;
     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
     int ndims = src_d.ndims();
 
@@ -1382,9 +1354,6 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(
     jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
     jcp.stride_w = cd.strides[ndims-3];
     jcp.src_fmt = src_d.format();
-    jcp.with_eltwise = with_relu;
-    jcp.eltwise_alg = mkldnn_eltwise_relu;
-    jcp.eltwise_alpha = relu_negative_slope;
 
     jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
     jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
@@ -1397,14 +1366,26 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(
 
     jcp.is_1stconv = is_1stconv(jcp);
 
-    jcp.oc_block = simd_w;
-    jcp.ic_block = jcp.is_1stconv ? jcp.ic : simd_w;
-    jcp.aligned_threads = 0;
-
     bool ok_to_pad_channels = true
         && jcp.ngroups == 1
         && src_d.data_type() == data_type::f32;
 
+    const int full_simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
+    jcp.simd_w = full_simd_w;
+    bool ok_to_try_xmm = true
+        && mayiuse(avx512_core)
+        && src_d.data_type() == data_type::f32
+        && !jcp.is_1stconv
+        && !ok_to_pad_channels
+        && (jcp.ic % jcp.simd_w != 0 || jcp.oc % jcp.simd_w != 0)
+        && (jcp.ic % 8 != 0 || jcp.oc % 8 != 0);
+    if (ok_to_try_xmm)
+        jcp.simd_w = 4;
+
+    jcp.oc_block = jcp.simd_w;
+    jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w;
+    jcp.aligned_threads = 0;
+
     if (ok_to_pad_channels) {
         jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
         jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
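
Note: the block above introduces a 4-lane fallback. When the channels cannot be padded and neither ic nor oc fits the 16- or 8-lane blockings, jcp.simd_w drops to 4 and the nChw4c / OIhw4i4o formats chosen below are used. A hedged sketch of the selection predicate; the boolean inputs are simplifications of the checks above (the f32 data-type requirement is omitted for brevity):

    #include <cstdio>

    static int pick_simd_w(bool has_avx512_core, bool is_1stconv,
            bool ok_to_pad_channels, int ic, int oc) {
        const int full_simd_w = 16;  // 512-bit vector / sizeof(float)
        const bool try_xmm = has_avx512_core
                && !is_1stconv
                && !ok_to_pad_channels
                && (ic % full_simd_w != 0 || oc % full_simd_w != 0)
                && (ic % 8 != 0 || oc % 8 != 0);
        return try_xmm ? 4 : full_simd_w;
    }

    int main() {
        // e.g. a grouped convolution with 12 channels per group cannot be padded
        std::printf("simd_w = %d\n", pick_simd_w(true, false, false, 12, 12));  // 4
        std::printf("simd_w = %d\n", pick_simd_w(true, false, false, 64, 64));  // 16
        return 0;
    }
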
@@ -1420,14 +1401,28 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(
 
     const auto &p = attr.post_ops_;
     jcp.with_sum = p.find(primitive_kind::sum) != -1;
+    const int eltwise_ind = p.find(primitive_kind::eltwise);
+    jcp.with_eltwise = eltwise_ind != -1;
+    if (jcp.with_eltwise) {
+        jcp.eltwise = p.entry_[eltwise_ind].eltwise;
+        if (dst_d.data_type() == data_type::s32) return status::unimplemented;
+    }
 
     auto src_format = jcp.is_1stconv
         ? pick(ndims - 3, ncw, nchw, ncdhw)
+        : ((jcp.simd_w == 4)
+            ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c)
+            : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c));
+    auto dst_format = (jcp.simd_w == 4)
+        ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c)
         : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
-    auto dst_format = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
     auto wei_format = with_groups
-        ? pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o)
-        : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o);
+        ? ((jcp.simd_w == 4)
+            ? pick(ndims - 3, gOIw4i4o, gOIhw4i4o, gOIdhw4i4o)
+            : pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o))
+        : ((jcp.simd_w == 4)
+            ? pick(ndims - 3, OIw4i4o, OIhw4i4o, OIdhw4i4o)
+            : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o));
 
     if (src_d.format() == any)
         CHECK(src_pd.set_format(src_format));
@@ -1491,16 +1486,24 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(
                 jcp.ver = ver_fma;
             if (jcp.ver == ver_4fma) {
                 const auto w_format = with_groups
-                    ? pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o)
-                    : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o);
+                    ? ((jcp.simd_w == 4)
+                        ? pick(ndims - 3, gOiw4o, gOihw4o, gOidhw4o)
+                        : pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o))
+                    : ((jcp.simd_w == 4)
+                        ? pick(ndims - 3, Oiw4o, Oihw4o, Oidhw4o)
+                        : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o));
                 if (weights_d.format() == any)
                     CHECK(weights_pd.set_format(w_format));
                 if (weights_d.format() != w_format)
                     return status::unimplemented;
             } else {
                 const auto w_format = with_groups
-                    ? pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)
-                    : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o);
+                    ? ((jcp.simd_w == 4)
+                        ? pick(ndims - 3, gOwi4o, gOhwi4o, gOdhwi4o)
+                        : pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o))
+                    : ((jcp.simd_w == 4)
+                        ? pick(ndims - 3, Owi4o, Ohwi4o, Odhwi4o)
+                        : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o));
                 if (weights_d.format() == any)
                     CHECK(weights_pd.set_format(w_format));
                 if (weights_d.format() != w_format)
@@ -1561,10 +1564,25 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(
         }
     }
 
+    /* Grouped channel offset to support 'non-blocked data' format for
+     * convolution sizes with '(input_channel / ngroups) < simd' */
+    jcp.nonblk_group_off
+            = (jcp.ngroups > 1 && one_of(src_d.format(), ncw, nchw, ncdhw)) ?
+            jcp.ic :
+            1;
+
     jcp.nb_ic = jcp.ic / jcp.ic_block;
     jcp.nb_oc = jcp.oc / jcp.oc_block;
     jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
 
+    auto is_ow_threading_applicable = [=]() {
+        return (true && !jcp.is_1stconv && one_of(jcp.ndims, 3, 4)
+                && IMPLICATION(mayiuse(avx512_mic),
+                           jcp.ver == ver_4fma
+                                   && IMPLICATION(jcp.mb != 1,
+                                              jcp.ih == 1 && jcp.kh == 1)));
+    };
+
     if (jcp.ver == ver_4vnni) {
         jcp.kernel_kind = embd_bcast;
     }
@@ -1593,9 +1611,13 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(
     }
 
     if (one_of(jcp.ver, ver_4vnni, ver_4fma) && !jcp.is_1stconv) {
-        if (jcp.kw == 3 && jcp.kh == 3 && jcp.ow == 7 && jcp.oh == 7) {
-            if (jcp.nb_oc % 2 == 0)
+        if ((jcp.kw <= 5 && jcp.kh <= 5 && jcp.kw == jcp.kh && jcp.ow <= 8
+                    && jcp.oh <= 8 && jcp.ow == jcp.oh)
+                || (jcp.stride_h != 1 && jcp.ur_w < jcp.ow)) {
+            if (jcp.nb_oc % 2 == 0) {
                 jcp.nb_oc_blocking = 2;
+                jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking);
+            }
         } else {
             for (int i = jcp.nb_oc; i > 0; i--)
                 if (i * jcp.ur_w <= regs && jcp.nb_oc % i == 0) {
@@ -1603,15 +1625,74 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(
                     break;
                 }
         }
-        if (jcp.ver == ver_4fma
-            && is_1D_conv(jcp) && one_of(jcp.ndims, 3, 4)) {
-            if (jcp.nb_oc % 2 == 0) {
+        if (jcp.ver == ver_4fma && is_ow_threading_applicable()) {
+            if (jcp.nb_oc % 2 == 0 && jcp.ur_w < jcp.ow
+                    && jcp.ow != 2 * jcp.ur_w) {
                 jcp.nb_oc_blocking = 2;
                 jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking);
             }
         }
     }
 
+    jcp.ow_block = jcp.ow;
+
+    auto get_thr_eff = [=](int nb_oc_blocking, int ow_block) {
+        int nb_ow = div_up(jcp.ow, ow_block);
+        int nb_oc_chunks = div_up(jcp.nb_oc, nb_oc_blocking);
+        int work_amount = jcp.mb * jcp.oh * nb_oc_chunks * nb_ow;
+        float disbalance = (float)jcp.ow / rnd_up(jcp.ow, ow_block);
+        float thr_eff = disbalance * (float)work_amount
+            / rnd_up(work_amount, nthreads);
+        return thr_eff;
+    };
+
+    auto get_ow_block = [=](int nb_oc_blocking, int ur_w, float &eff) {
+        int res_ow_block = jcp.ow;
+        eff = get_thr_eff(nb_oc_blocking, res_ow_block);
+        if (!is_ow_threading_applicable())
+            return res_ow_block;
+
+        int L2_part = (get_cache_size(2) * 7 / 8) / typesize;
+        if (jcp.ver == ver_4fma)
+            L2_part /= 2;
+        int size_src_chunk = jcp.ic_block * ur_w * jcp.kh;
+        int size_dst_chunk = jcp.oc_block * nb_oc_blocking * ur_w;
+        int size_wei_chunk = jcp.oc_block * nb_oc_blocking * jcp.ic_block
+            * jcp.kw * jcp.kh;
+        int nurw_cache = (L2_part - 2 * size_wei_chunk)
+            / (2 * size_dst_chunk + 2 * size_src_chunk);
+        // current design of generate() requires ow_block >= 2 * ur_w
+        int ow_block_cache = ur_w * nstl::max(2, nurw_cache);
+
+        int ow_block_thr = ow_block_cache;
+        eff = get_thr_eff(nb_oc_blocking, ow_block_thr);
+
+        int max_nb_ow = div_up(jcp.ow, 2 * ur_w);
+        int start_nb_ow = div_up(jcp.ow, ow_block_thr);
+        for (int nb_ow = start_nb_ow; nb_ow <= max_nb_ow; nb_ow++) {
+            int ow_block
+                = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), ur_w), jcp.ow);
+            float eff_threshold = (jcp.ver == ver_4fma) ? 0.8f : 0.9f;
+            if (ow_block < nb_oc_blocking * jcp.oc_block && eff > eff_threshold)
+                break;
+            if (div_up(jcp.ow, ow_block) != nb_ow)
+                continue;
+            float thr_eff = get_thr_eff(nb_oc_blocking, ow_block);
+            float eff_step = (jcp.ver == ver_4fma) ? 1.1f : 1.f;
+            if (ow_block >= 2 * ur_w && thr_eff > eff_step * eff) {
+                ow_block_thr = ow_block;
+                eff = thr_eff;
+            }
+            eff_threshold = (jcp.ver == ver_4fma) ? 0.9f : 0.98f;
+            if (eff > eff_threshold)
+                break;
+        }
+        res_ow_block = nstl::min(jcp.ow, nstl::max(2 * ur_w, ow_block_thr));
+        eff = get_thr_eff(nb_oc_blocking, res_ow_block);
+        return res_ow_block;
+    };
+
+
     if (jcp.ver == ver_fma && mayiuse(avx512_core)) {
         int try_nb_oc_blocking = 2;
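
Note: the get_thr_eff() lambda added above scores an output-width block as the product of the tail imbalance over ow and the fraction of nthreads that receive a work item; get_ow_block() then searches candidate blocks between an L2-capacity estimate and the 2 * ur_w minimum required by generate(). A hedged standalone restatement with a worked example; the shape numbers are illustrative:

    #include <cstdio>

    static int div_up(int a, int b) { return (a + b - 1) / b; }
    static int rnd_up(int a, int b) { return div_up(a, b) * b; }

    // Restatement of get_thr_eff(): work items are (mb, oh, oc-chunk, ow-block)
    // tuples; efficiency = ow-tail balance * thread-slot utilization.
    static float thr_eff(int mb, int oh, int ow, int nb_oc, int nb_oc_blocking,
            int ow_block, int nthreads) {
        const int nb_ow = div_up(ow, ow_block);
        const int nb_oc_chunks = div_up(nb_oc, nb_oc_blocking);
        const int work_amount = mb * oh * nb_oc_chunks * nb_ow;
        const float disbalance = (float)ow / rnd_up(ow, ow_block);
        return disbalance * (float)work_amount / rnd_up(work_amount, nthreads);
    }

    int main() {
        // mb=1, oh=ow=7, nb_oc=4, nb_oc_blocking=2, 28 threads:
        // splitting ow doubles the number of work items at a small tail cost.
        std::printf("ow_block=7: eff=%.3f\n", thr_eff(1, 7, 7, 4, 2, 7, 28));  // 0.500
        std::printf("ow_block=4: eff=%.3f\n", thr_eff(1, 7, 7, 4, 2, 4, 28));  // 0.875
        return 0;
    }
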
         unsigned int ker_inp_size = typesize * div_up(jcp.iw, jcp.stride_w)
@@ -1629,7 +1710,6 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(
             && !(jcp.kw == 3 && jcp.ow == 28 && jcp.ic >= 512);
 
         if (jcp.mb == 1) {
-            jcp.kernel_kind = embd_bcast;
             unsigned int inp_size = jcp.mb * div_up(jcp.ih, jcp.stride_h)
                     * div_up(jcp.iw, jcp.stride_w) * jcp.ic;
             unsigned int wei_size = jcp.ic * jcp.oc * jcp.kh * jcp.kw;
@@ -1662,59 +1742,52 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(
                     }
                 }
             }
-        } else if (jcp.kw > 3
-            || (jcp.stride_w == 1 && jcp.stride_h == 1
-                && embd_bcast_condition)
-            || ((jcp.stride_w != 1 || jcp.stride_h != 1)
-                && ((jcp.mb <= 16 && (jcp.oc <= 192 || jcp.oh <= 10)
-                     && embd_bcast_condition)))
-            ) {
+        }
+
+        if (jcp.kw > 3
+                || (jcp.stride_w == 1 && jcp.stride_h == 1
+                           && embd_bcast_condition)
+                || ((jcp.stride_w != 1 || jcp.stride_h != 1)
+                           && ((jcp.mb <= 16 && (jcp.oc <= 192 || jcp.oh <= 10)
+                                      && embd_bcast_condition)))
+                || (jcp.mb == 1
+                           && (jcp.ur_w >= jcp.ow || jcp.is_1stconv
+                                      || (jcp.ow <= 147 && jcp.oc <= 96)))) {
             jcp.kernel_kind = embd_bcast;
             jcp.ur_w = nstl::min(jcp.ow, regs);
             jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
             if (ker_total_size < L1_cache_size && jcp.ow <= 8 && jcp.kh <= 3
-                && jcp.kw <= 3) {
-                if (jcp.nb_oc % try_nb_oc_blocking == 0 && !jcp.is_1stconv) {
-                    jcp.nb_oc_blocking = try_nb_oc_blocking;
-                    jcp.ur_w = 31 / (jcp.nb_oc_blocking + 1);
-                    if (jcp.ow < jcp.ur_w)
-                        jcp.ur_w = jcp.ow;
-                }
+                    && jcp.kw <= 3 && jcp.nb_oc % try_nb_oc_blocking == 0
+                    && IMPLICATION(jcp.is_1stconv, jcp.mb == 1)
+                    && IMPLICATION(jcp.mb == 1, jcp.ur_w < jcp.ow)) {
+                jcp.nb_oc_blocking = try_nb_oc_blocking;
+                jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1));
             }
         } else {
             jcp.kernel_kind = expl_bcast;
             jcp.nb_ic_blocking = 1;
-            jcp.nb_oc_blocking = 4;
-            if (jcp.nb_oc < jcp.nb_oc_blocking) jcp.nb_oc_blocking = jcp.nb_oc;
-            if (jcp.nb_oc % jcp.nb_oc_blocking != 0)
-                for (int i = jcp.nb_oc_blocking; i > 0; i--)
+            if (IMPLICATION(jcp.is_1stconv, jcp.mb > 1)) {
+                float best_thr_eff = 0.f;
+                int best_nb_oc_blocking = 1;
+                for (int i = nstl::min(jcp.nb_oc, 5); i > 0; i--) {
                     if (jcp.nb_oc % i == 0) {
-                        jcp.nb_oc_blocking = i;
-                        break;
+                        float thr_eff;
+                        int ur_w = nstl::min(jcp.ow, 31 / (i + 1));
+                        get_ow_block(i, ur_w, thr_eff);
+                        if (thr_eff > 1.05f * best_thr_eff) {
+                            best_nb_oc_blocking = i;
+                            best_thr_eff = thr_eff;
+                        }
                     }
-            jcp.ur_w = 31 / (jcp.nb_oc_blocking + 1);
-            if (jcp.ow < jcp.ur_w)
-                jcp.ur_w = jcp.ow;
+                }
+                jcp.nb_oc_blocking = best_nb_oc_blocking;
+                jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1));
+            }
         }
     }
 
     jcp.ur_w_tail = jcp.ow % jcp.ur_w;
 
-    jcp.ow_block = jcp.ow;
-    if (is_ow_threading_available(jcp)) {
-        const int L1_part = get_cache_size(1) * 5 / 8;
-        int size_src_chunk = typesize * jcp.ic_block * jcp.ur_w;
-        int size_dst_chunk = typesize
-            * jcp.oc_block * jcp.nb_oc_blocking * jcp.ur_w;
-        int size_wei_chunk = typesize
-            * jcp.oc_block * jcp.ic_block * jcp.nb_oc_blocking * jcp.kw;
-        int nurw = (L1_part - size_wei_chunk)
-            / (size_dst_chunk + size_src_chunk);
-        // current design of generate() requires ow_block >= 2 * ur_w
-        jcp.ow_block = jcp.ur_w * nstl::max(2, nurw);
-    }
-    jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
-
     args_ok = true
         && jcp.l_pad <= jcp.ur_w
         && jcp.ic <= src_d.blocking_desc().padding_dims[1]
@@ -1734,10 +1807,14 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(
 
     jcp.nb_ic_L2 = jcp.nb_ic;
 
+    float thr_eff;
+    jcp.ow_block = get_ow_block(jcp.nb_oc_blocking, jcp.ur_w, thr_eff);
+    jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
+
     const int L2_size = get_cache_size(2, true) / sizeof(float);
     // Source and output data needs to fit in L2,
     // leaving some space for weights and prefetching.
-    int h_L2 = int(((0.6f * L2_size) / simd_w
+    int h_L2 = int(((0.6f * L2_size) / jcp.simd_w
                            - nstl::min(0, jcp.kh - jcp.stride_h) * jcp.iw)
             / (jcp.stride_h * jcp.iw + jcp.ow));
     jcp.h_blocking = nstl::max(1, nstl::min(jcp.oh, h_L2));
@@ -1765,7 +1842,7 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(
                     break;
                 }
             }
-        } else {
+        } else if (jcp.ic > 64) {
             jcp.nb_ic_L2 = 2; /* according to performance data*/
         }
     }
@@ -1773,6 +1850,12 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(
     return status::success;
 }
 
+void jit_avx512_common_conv_fwd_kernel::init_scratchpad(
+        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+    if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
+        scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc);
+}
+
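
Note: init_scratchpad() above books key_conv_padded_bias only when the output channels were rounded up (jcp.oc != jcp.oc_without_padding). The kernel reads jcp.oc bias values, so, presumably, the driver copies the user's oc_without_padding values into a padded scratch buffer and zero-fills the tail. A hedged illustration of that idea; pad_bias is not an mkl-dnn API:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Copy the user-visible bias into a buffer padded to the blocked oc,
    // zero-filling the tail that the kernel will also read.
    static std::vector<float> pad_bias(const std::vector<float> &bias, int padded_oc) {
        std::vector<float> padded(padded_oc, 0.f);
        std::copy(bias.begin(), bias.end(), padded.begin());
        return padded;
    }

    int main() {
        const std::vector<float> bias = {0.1f, 0.2f, 0.3f};  // oc_without_padding = 3
        const auto padded = pad_bias(bias, 16);               // oc rounded up to oc_block
        std::printf("padded size = %zu, tail = %.1f\n", padded.size(), (double)padded.back());
        return 0;
    }
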
 void jit_avx512_common_conv_bwd_data_kernel_f32::prepare_output(int ur_w)
 {
     for (int k = 0; k < jcp.nb_ic_blocking; k++) {
@@ -1826,7 +1909,7 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_4fma(
     int kw = jcp.kw;
     int ic_block = jcp.ic_block;
     int oc_block = jcp.oc_block;
-    Label kh_label, last_iter_label, loop_end_label, kd_label, skip_kd_loop;
+    Label kh_label, last_iter_label, loop_end_label, kd_label;
     int ker_load_number = 4;
     int shift_ker_ptr = typesize * kw * oc_block * ic_block;
     int shift_dst_ptr = typesize * ow * oc_block;
@@ -1857,8 +1940,6 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_4fma(
         }
     };
 
-    prepare_output(ur_w);
-
     if (one_of(jcp.ndims, 3, 4)) {
         mov(aux_reg_dst, reg_dst);
         mov(aux_reg_ker, reg_ker);
@@ -2004,13 +2085,10 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_4fma(
         dec(reg_ki);
         cmp(reg_ki, 0);
         jg(kd_label, T_NEAR);
-        L(skip_kd_loop);
 
         pop(reg_src);
         pop(reg_src_prf);
     }
-
-    store_output(ur_w);
 }
 
 void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_vnni(
@@ -2031,8 +2109,6 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_vnni(
         return jcp.typesize_in * (blk_offset + oc_offset);
     };
 
-    prepare_output(ur_w);
-
     mov(aux_reg_dst, reg_dst);
     mov(aux_reg_ker, reg_ker);
     mov(aux_reg_dst_prf, reg_dst_prf);
@@ -2108,15 +2184,12 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_vnni(
         cmp(reg_kj, 0);
         jg(kh_label, T_NEAR);
     }
-
-    store_output(ur_w);
 }
 
 void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma(
         int ur_w, int l_overflow, int r_overflow)
 {
-    Label kh_label, kd_label, skip_kd_loop;
-    Label store_output_label;
+    Label kh_label, kd_label;
     int kw = jcp.kw;
     int ow = jcp.ow;
 
@@ -2139,8 +2212,6 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma(
     int prf_inst_spacing = nstl::max(1, num_fmas / num_prfs);
     int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2;
 
-    prepare_output(ur_w);
-
     if (one_of(jcp.ndims, 3, 4)) {
         mov(aux_reg_dst, reg_dst);
         mov(aux_reg_ker, reg_ker);
@@ -2154,9 +2225,6 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma(
         push(reg_src);
 
         mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
-        cmp(reg_ki, 0);
-        je(store_output_label, T_NEAR);
-
         mov(aux_reg_dst_d, reg_dst);
         mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]);
         mov(aux_reg_dst_d_prf, reg_dst_prf);
@@ -2167,8 +2235,6 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma(
     } else {
         mov(reg_kj, reg_kh);
     }
-    cmp(reg_kj, 0);
-    je(store_output_label, T_NEAR);
 
     if (jcp.ndims == 5) {
         mov(aux_reg_dst, aux_reg_dst_d);
@@ -2268,16 +2334,12 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma(
         dec(reg_ki);
         cmp(reg_ki, 0);
         jg(kd_label, T_NEAR);
-        L(skip_kd_loop);
     }
 
-    L(store_output_label); {
-        if (jcp.ndims == 5)
-        {
-            pop(reg_src);
-            pop(reg_src_prf);
-        }
-        store_output(ur_w);
+    if (jcp.ndims == 5)
+    {
+        pop(reg_src);
+        pop(reg_src_prf);
     }
 }
 
@@ -2291,7 +2353,7 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core(
     int ic_block = jcp.ic_block;
     int oc_block = jcp.oc_block;
     int nb_ic_block = jcp.nb_ic_blocking;
-    Label kh_label, skip_kh_loop, kd_label, skip_kd_loop;
+    Label kh_label, kd_label;
 
     int shift_ker_ptr = typesize * kw * oc_block * ic_block;
     int shift_dst_ptr = typesize * (jcp.dilate_h + 1) * ow * oc_block;
@@ -2307,8 +2369,6 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core(
         return typesize * (blk_offset + oc_offset);
     };
 
-    prepare_output(ur_w);
-
     if (one_of(jcp.ndims, 3, 4)) {
         mov(aux_reg_dst, reg_dst);
         mov(aux_reg_ker, reg_ker);
@@ -2327,8 +2387,6 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core(
     } else {
         mov(reg_kj, reg_kh);
     }
-    cmp(reg_kj, 0);
-    je(skip_kh_loop, T_NEAR);
 
     if (jcp.ndims == 5) {
         mov(aux_reg_dst, aux_reg_dst_d);
@@ -2370,7 +2428,6 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core(
         cmp(reg_kj, 0);
         jg(kh_label, T_NEAR);
     }
-    L(skip_kh_loop);
 
     if (jcp.ndims == 5) {
         sub(aux_reg_dst_d,
@@ -2380,19 +2437,29 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core(
         dec(reg_ki);
         cmp(reg_ki, 0);
         jg(kd_label, T_NEAR);
-        L(skip_kd_loop);
 
         pop(reg_src);
         pop(reg_src_prf);
     }
-
-    store_output(ur_w);
 }
 
 inline void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop(
         int ur_w, int l_overflow, int r_overflow)
 {
     if (jcp.ndims == 5) push(reg_oi);
+
+    prepare_output(ur_w);
+
+    Label skip_compute_loop;
+    if (jcp.ndims == 5) {
+        mov(reg_kj, ptr[param + GET_OFF(kd_padding)]);
+        cmp(reg_kj, 0);
+        je(skip_compute_loop, T_NEAR);
+    }
+    mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
+    cmp(reg_kj, 0);
+    je(skip_compute_loop, T_NEAR);
+
     if (jcp.ver == ver_4vnni || jcp.ver == ver_vnni)
         compute_loop_vnni(ur_w, l_overflow, r_overflow);
     else if (jcp.ver == ver_4fma)
@@ -2407,6 +2474,9 @@ inline void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop(
               compute_loop_fma_core(ur_w, l_overflow, r_overflow);
     else
         assert("!unknown convolution version");
+
+    L(skip_compute_loop);
+    store_output(ur_w);
     if (jcp.ndims == 5) pop(reg_oi);
 }
 
@@ -2504,7 +2574,9 @@ status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(
 {
     if (!mayiuse(avx512_common)) return status::unimplemented;
 
-    const int simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
+    jcp = zero<decltype(jcp)>();
+
+    jcp.simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
     const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
     int ndims = diff_src_d.ndims();
 
@@ -2556,8 +2628,8 @@ status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(
 
     jcp.is_1stconv = false;
 
-    jcp.oc_block = simd_w;
-    jcp.ic_block = jcp.is_1stconv ? jcp.ic : simd_w;
+    jcp.oc_block = jcp.simd_w;
+    jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w;
 
     bool ok_to_pad_channels = true
         && jcp.ngroups == 1
@@ -2777,8 +2849,15 @@ status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(
         && jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
         && jcp.ic <= weights_d.blocking_desc().padding_dims[with_groups + 1]
         && jcp.oc <= weights_d.blocking_desc().padding_dims[with_groups + 0];
+    if (!args_ok) return status::unimplemented;
 
-    return args_ok ? status::success : status::unimplemented;
+    return status::success;
+}
+
+void jit_avx512_common_conv_bwd_data_kernel_f32::init_scratchpad(
+        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+    UNUSED(scratchpad);
+    UNUSED(jcp);
 }
 
 const int jit_avx512_common_conv_bwd_weights_kernel_f32::max_ur_w = 28;
@@ -4464,13 +4543,10 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32::generate()
 status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf(
     jit_conv_conf_t &jcp, const convolution_desc_t &cd,
     cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &diff_weights_pd,
-    cpu_memory_t::pd_t &diff_bias_pd, cpu_memory_t::pd_t &diff_dst_pd)
-{
+    cpu_memory_t::pd_t &diff_bias_pd, cpu_memory_t::pd_t &diff_dst_pd) {
     if (!mayiuse(avx512_common))
         return status::unimplemented;
 
-    const int simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
-
     const memory_desc_wrapper src_d(&src_pd);
     const memory_desc_wrapper diff_weights_d(&diff_weights_pd);
     const memory_desc_wrapper diff_bias_d(&diff_bias_pd);
@@ -4480,6 +4556,8 @@ status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf(
     int ndims = src_d.ndims();
 
     jcp = zero<decltype(jcp)>();
+
+    jcp.simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
     jcp.ndims = ndims;
     jcp.prop_kind = cd.prop_kind;
 
@@ -4545,14 +4623,14 @@ status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf(
     /* check for the 1st convolution */
     jcp.is_1stconv = is_1stconv(jcp);
 
-    jcp.oc_block = simd_w;
+    jcp.oc_block = jcp.simd_w;
 
     bool ok_to_pad_channels = true
         && jcp.ngroups == 1
         && src_d.data_type() == data_type::f32;
 
     if (ok_to_pad_channels)
-        jcp.oc = rnd_up(jcp.oc, simd_w);
+        jcp.oc = rnd_up(jcp.oc, jcp.simd_w);
 
     if (jcp.oc % jcp.oc_block)
         return status::unimplemented;
@@ -4628,7 +4706,7 @@ status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf(
             && everyone_is(0, jcp.l_pad, jcp.r_pad, jcp.t_pad, jcp.b_pad)
             && jcp.kw <= 28 - jcp.with_bias
             && jcp.stride_w == 4
-            && tr_ld / simd_w <= 4 /* [bwd_w:tr_src:r1] */
+            && tr_ld / jcp.simd_w <= 4 /* [bwd_w:tr_src:r1] */
             && IMPLICATION(jcp.with_bias, kh_step_rem == 1) /* [bwd_w:b:r1] */
             && IMPLICATION(diff_weights_d.format() != any,
                     diff_weights_d.format() == want_4fma_wfmt);
@@ -4667,7 +4745,7 @@ status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf(
         if (!ok)
             return status::unimplemented;
 
-        jcp.ic_block = simd_w;
+        jcp.ic_block = jcp.simd_w;
         if (ok_to_pad_channels)
             jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
         jcp.nb_ic = jcp.ic / jcp.ic_block;
@@ -4735,10 +4813,209 @@ status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf(
         && jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
         && jcp.ic <= diff_weights_d.blocking_desc().padding_dims[with_groups + 1]
         && jcp.oc <= diff_weights_d.blocking_desc().padding_dims[with_groups + 0];
+    if (!args_ok) return status::unimplemented;
+
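+    // decide upfront how threads are split across minibatch, groups, and
+    // oc/ic blocks; init_scratchpad() relies on this partition to size the
+    // per-thread transpose and reduction buffers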
+    {   // balancing
+        int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
+        balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b);
+        jcp.nthr = nthr;
+        jcp.nthr_mb = nthr_mb;
+        jcp.nthr_g = nthr_g;
+        jcp.nthr_oc_b = nthr_oc_b;
+        jcp.nthr_ic_b = nthr_ic_b;
+    }
+
+    return status::success;
+}
+
+void jit_avx512_common_conv_bwd_weights_kernel_f32::init_scratchpad(
+        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+    if (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)) {
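+        // these versions work on a transposed copy of the source (and, for
+        // vnni, of diff_dst), so the per-thread transpose buffers and the
+        // barrier contexts that synchronize the transposition are booked here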
+        if (jcp.is_1stconv) {
+            const size_t tr_src_size =
+                jcp.nthr / jcp.nthr_oc_b * jcp.ih * jcp.stride_w * jcp.tr_ld;
+            scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size);
+        } else {
+            // XXX: See the comment about tr_iw and guarding elements in
+            // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf()
+            const size_t max_nthr = jcp.nthr_mb * jcp.ngroups * jcp.nb_ic;
+            const size_t min_tr_src_size_per_thr
+                = jcp.ih * jcp.ic_block * jcp.tr_iw;
+            const size_t tr_src_size = max_nthr * min_tr_src_size_per_thr
+                + jcp.tr_src_num_guard_elems;
+            scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size);
+        }
+
+        /* prepare synchronization contexts */
+        if (jcp.nthr_oc_b > 1) {
+            const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b;
+            scratchpad.book(key_conv_tr_src_bctx,
+                    sizeof(simple_barrier::ctx_t) * tr_src_bctx_size);
+        }
+
+        if (utils::one_of(jcp.ver, ver_4vnni, ver_vnni)) {
+            const size_t tr_diff_dst_size = jcp.nthr_mb * jcp.ngroups
+                * jcp.nb_oc * jcp.oc_block * jcp.tr_ow * jcp.oh;
+            scratchpad.book(key_conv_tr_diff_dst,
+                    jcp.typesize_in * tr_diff_dst_size);
+
+            /* prepare synchronization contexts */
+            if (jcp.nthr_ic_b > 1) {
+                const size_t tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b;
+                scratchpad.book(key_conv_tr_diff_dst_bctx,
+                        sizeof(simple_barrier::ctx_t) * tr_diff_dst_bctx_size);
+            }
+        }
+    }
+
+    if (jcp.nthr_mb > 1) {
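+        // when the minibatch is split across threads, every thread but the
+        // first accumulates into a private weights + bias copy; the copies
+        // are reduced afterwards, guarded by a single barrier context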
+        const int wei_size = jcp.ngroups * jcp.oc * jcp.ic
+            * jcp.kh * jcp.kw * jcp.kd;
+        const int bia_size = jcp.ngroups * jcp.oc;
+        const size_t wei_bia_reduction_size = wei_size + bia_size;
+
+        scratchpad.book(key_conv_wei_bia_reduction,
+                jcp.typesize_out * wei_bia_reduction_size * (jcp.nthr_mb - 1));
+        scratchpad.book(key_conv_wei_bia_reduction_bctx,
+                sizeof(simple_barrier::ctx_t));
+    }
+
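+    // if oc was padded up to the block size, the bias is accumulated into a
+    // padded scratch buffer so the kernel never writes past the user's
+    // (unpadded) diff_bias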
+    if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
+        scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc);
+}
 
-    return args_ok ? status::success : status::unimplemented;
+void jit_avx512_common_conv_bwd_weights_kernel_f32::balance(
+        const jit_conv_conf_t &j, int &nthr_, int &nthr_mb_, int &nthr_g_,
+        int &nthr_oc_b_, int &nthr_ic_b_)
+{
+    nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1;
+
+    const int max_threads = mkldnn_get_max_threads();
+
+    if (max_threads < j.ngroups) {
+        /* simplification... fortunately it doesn't hurt much */
+        return;
+    }
+
+    if (!mkldnn_thr_syncable()
+            && utils::one_of(j.ver, ver_4fma, ver_4vnni, ver_vnni)) {
+        // should not happen -- the driver is not ready
+        // for TBB-like non-synchronous threading yet
+        return;
+    }
+
+    if (j.ver == ver_4fma && j.is_1stconv) {
+        nthr_g_ = 1;
+        nthr_oc_b_ = 1;
+        nthr_ic_b_ = nstl::min(j.nb_ic, max_threads);
+        nthr_mb_ = nstl::min(max_threads / nthr_ic_b_, j.mb);
+        nthr_ = nthr_mb_ * nthr_oc_b_ * nthr_ic_b_ * nthr_g_;
+        return;
+    }
+
+    nthr_g_ = j.ngroups;
+    const int nthr = max_threads / nthr_g_;
+
+    auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
+        /* calculate the per-thread memory cost (read/write). the high-level
+         * optimizer tries to minimize memory consumption. a few notes:
+         *  (n1) unclear why, but this essentially helps the first convolution
+         *  (n2) assuming the reduction over the minibatch is always there:
+         *    - instead of 8 the coefficient should be 5 (a write costs ~2 reads):
+         *      kernel: 1 write to the temporary workspace
+         *      reduction: 1 read from the workspace and 1 write to diff_wei
+         *    - but experiments showed 8 works better than 5 or 6 */
+
+        const int src_coef = j.ver == ver_4fma || j.ver == ver_vnni ? 4 : 1;
+        const int dst_coef = 1;
+        const int wei_coef = j.ver == ver_vnni ? 4 : 8;
+
+        return 0
+            + src_coef
+            * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_)
+            * div_up(j.nb_ic, nthr_ic_b) * j.ic_block * j.ih * j.iw * j.id
+            / j.stride_d / j.stride_h / j.stride_w /* (n1) */
+            + dst_coef
+            * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_)
+            * div_up(j.nb_oc, nthr_oc_b) * j.oc_block * j.oh * j.ow * j.od
+            + wei_coef /* (n2) */
+            * div_up(j.ngroups, nthr_g_)
+            * div_up(j.nb_oc, nthr_oc_b) * div_up(j.nb_ic, nthr_ic_b)
+            * j.kh * j.kw * j.kd * j.ic_block * j.oc_block;
+    };
+
+    int best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
+
+    /* step 1: find the best thread distribution with lowest memory cost */
+    const int nthr_mb_max = nstl::min(nthr, j.mb * j.od);
+    for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
+        const int nthr_par = nthr / nthr_mb;
+        const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
+        for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
+            int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
+
+            int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
+            if (mem_cost <= best_mem_cost) {
+                best_mem_cost = mem_cost;
+                nthr_mb_ = nthr_mb;
+                nthr_oc_b_ = nthr_oc_b;
+                nthr_ic_b_ = nthr_ic_b;
+            }
+        }
+
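+        // without syncable threading the minibatch reduction cannot be used,
+        // so only the nthr_mb == 1 distribution is considered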
+        if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
+    }
+
+    if (j.ver != ver_vnni && !mayiuse(avx512_mic)) {
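+        // the compute cost is simply the number of (mb x group x oc-block x
+        // ic-block) work items a single thread executes; fewer is better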
+        auto calc_comp_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
+            return 1
+                * div_up(j.mb, nthr_mb)
+                * div_up(j.ngroups, nthr_g_)
+                * div_up(j.nb_oc, nthr_oc_b)
+                * div_up(j.nb_ic, nthr_ic_b);
+        };
+
+        /* step 2: search for a thread distribution with a lower compute cost.
+         * the constraints:
+         *  - the memory cost cannot exceed 110% of the best found in step 1
+         *  - unless the compute cost is at least 25% lower than the current
+         *    best case (i.e. 4 * comp_cost <= 3 * best_comp_cost)
+         * note: both constants were found empirically */
+        int best_comp_cost = calc_comp_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
+        for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
+            const int nthr_par = nthr / nthr_mb;
+            const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
+            for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
+                int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
+                int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
+                int comp_cost = calc_comp_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
+
+                const bool opt1 = comp_cost <= best_comp_cost
+                    && mem_cost < 1.1 * best_mem_cost;
+                const bool opt2 = 4 * comp_cost <= 3 * best_comp_cost;
+
+                if (opt1 || opt2) {
+                    best_comp_cost = comp_cost;
+                    nthr_mb_ = nthr_mb;
+                    nthr_oc_b_ = nthr_oc_b;
+                    nthr_ic_b_ = nthr_ic_b;
+                }
+            }
+
+            if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
+        }
+    }
+
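+    // if the search almost saturated the threads with the minibatch dimension
+    // alone, prefer using all of them for the minibatch reduction (bounded by
+    // the available mb * od work)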
+    if (nthr_mb_ > max_threads/2 && nthr_mb_ < max_threads)
+        nthr_mb_ = nstl::min(j.mb * j.od, max_threads);
+    nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_;
+
+    assert(nthr_ <= max_threads);
+    assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_mb_ == 1));
 }
 
+template struct _jit_avx512_common_conv_fwd_kernel<Zmm>;
+template struct _jit_avx512_common_conv_fwd_kernel<Xmm>;
+
 }
 }
 }