Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine/thirdparty/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp
index 0d97cce..db6454c 100644
@@ -30,6 +30,7 @@ namespace cpu {
 
 using namespace mkldnn::impl::prop_kind;
 using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
 using namespace mkldnn::impl::utils;
 
 using namespace Xbyak;
@@ -183,13 +184,6 @@ void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_postprocess(int ur_ch_blocks, in
     int depthwise_inj_idx = 0;
     const auto &p = attr_.post_ops_;
 
-    if (p.len_ == 0 && eltwise_injectors.size() == 1) {
-        int start_idx = get_acc_reg(0).getIdx();
-        int end_idx = get_acc_reg(repeats * ur_w * ur_ch_blocks).getIdx();
-
-        eltwise_injectors[0]->compute_vector_range(start_idx, end_idx);
-    }
-
     for (int i = 0; i < p.len_; i++) {
         auto& post_op = p.entry_[i];
         if (post_op.is_eltwise()) {
@@ -293,14 +287,7 @@ void jit_uni_dw_conv_fwd_kernel_f32<isa>::loop_body(int ur_ch_blocks) {
 }
 
 template <cpu_isa_t isa>
-void jit_uni_dw_conv_fwd_kernel_f32<isa>::generate()
-{
-    if (jcp.with_eltwise) {
-        eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
-                this, jcp.eltwise_alg, jcp.eltwise_alpha, 0
-        ));
-    }
-
+void jit_uni_dw_conv_fwd_kernel_f32<isa>::generate() {
     const auto &p = attr_.post_ops_;
     for (int i = 0; i < p.len_; i++) {
         auto &post_op = p.entry_[i];
@@ -369,14 +356,10 @@ bool jit_uni_dw_conv_fwd_kernel_f32<isa>::post_ops_ok(
     auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
 
     switch (p.len_) {
-    case 0: return true; // no post_ops
-    case 1: return true  // sum OR eltwise OR deptwise
-                    && !jcp.with_eltwise && (is_simple(0) || is_sum(0));
-    case 2: return true // sum->relu OR sum->depthwise OR eltwise->depthwise OR depthwise->depthwise
-                    && !jcp.with_eltwise && ((is_sum(0) && is_simple(1)) ||
-                                             (is_simple(0) && is_simple(1)));
-    case 3: return true // sum->eltwise->depthwise OR sum->depthwise->eltwise OR sum->depthwise->depthwise
-                   && !jcp.with_eltwise && ((is_sum(0) && is_simple(1) && is_simple(2)));
+    case 0: return true;
+    case 1: return is_simple(0) || is_sum(0);
+    case 2: return (is_sum(0) && is_simple(1)) || (is_simple(0) && is_simple(1));
+    case 3: return is_sum(0) && is_simple(1) && is_simple(2);
     default: return false;
     }
 
@@ -387,7 +370,7 @@ template <cpu_isa_t isa>
 status_t jit_uni_dw_conv_fwd_kernel_f32<isa>::init_conf(jit_conv_conf_t &jcp,
         const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
         const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
-        const primitive_attr_t &attr, bool with_relu, float relu_negative_slope)
+        const primitive_attr_t &attr)
 {
     if (!mayiuse(isa)) return status::unimplemented;
 
@@ -426,9 +409,6 @@ status_t jit_uni_dw_conv_fwd_kernel_f32<isa>::init_conf(jit_conv_conf_t &jcp,
 
     jcp.src_fmt = src_d.format();
     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
-    jcp.with_eltwise = with_relu;
-    jcp.eltwise_alg = mkldnn_eltwise_relu;
-    jcp.eltwise_alpha = relu_negative_slope;
 
     if (!post_ops_ok(jcp, attr))
         return status::unimplemented;
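
With the `with_relu`/`relu_negative_slope` parameters gone from init_conf, a fused ReLU now reaches the kernel only through the attribute's post-op chain, which post_ops_ok validates above. A minimal sketch of how a caller would express the old with_relu behaviour through attributes (assuming the mkldnn v0.x C++ API; treat the exact names as an assumption):

    #include "mkldnn.hpp"

    // Sketch: the former `with_relu + relu_negative_slope` arguments become
    // an eltwise post-op on the primitive attributes (assumed v0.x C++ API).
    mkldnn::primitive_attr make_relu_attr(float negative_slope) {
        mkldnn::post_ops ops;
        // scale = 1.0; alpha carries the negative slope; beta is unused
        ops.append_eltwise(1.0f, mkldnn::algorithm::eltwise_relu,
                negative_slope, 0.0f);
        mkldnn::primitive_attr attr;
        attr.set_post_ops(ops);
        return attr;
    }
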
@@ -473,6 +453,13 @@ status_t jit_uni_dw_conv_fwd_kernel_f32<isa>::init_conf(jit_conv_conf_t &jcp,
     return status::success;
 }
 
+template <cpu_isa_t isa>
+void jit_uni_dw_conv_fwd_kernel_f32<isa>::init_scratchpad(
+        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+    if (jcp.with_bias && jcp.oc_without_padding != jcp.oc)
+        scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
+}
+
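
The forward kernel books a padded-bias buffer only when the output channels are padded up to the channel block. A plausible driver-side sketch of how such a buffer gets filled (hypothetical helper, not the verbatim driver code):

    #include <algorithm>
    #include <cstring>

    // Copy the user's oc_without_padding bias values and zero the padded
    // tail so vector loads past the real channels read benign zeros.
    void pad_bias(float *padded_bias, const float *bias,
            int oc_without_padding, int oc) {
        std::memcpy(padded_bias, bias, sizeof(float) * oc_without_padding);
        std::fill(padded_bias + oc_without_padding, padded_bias + oc, 0.f);
    }
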
 template struct jit_uni_dw_conv_fwd_kernel_f32<avx512_common>;
 template struct jit_uni_dw_conv_fwd_kernel_f32<avx2>;
 template struct jit_uni_dw_conv_fwd_kernel_f32<sse42>;
@@ -754,6 +741,13 @@ status_t jit_uni_dw_conv_bwd_data_kernel_f32<isa>::init_conf(
     return status::success;
 }
 
+template <cpu_isa_t isa>
+void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::init_scratchpad(
+        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+    UNUSED(scratchpad);
+    UNUSED(jcp);
+}
+
 template struct jit_uni_dw_conv_bwd_data_kernel_f32<avx512_common>;
 template struct jit_uni_dw_conv_bwd_data_kernel_f32<avx2>;
 template struct jit_uni_dw_conv_bwd_data_kernel_f32<sse42>;
@@ -776,7 +770,7 @@ inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::load_filter() {
             int off_filter = (reg_set + i) * simd_w;
             Vmm vmm_acc = get_acc_reg(reg_set + i);
             uni_vmovups(vmm_acc,
-                    vmmword[tmp_reg_filter + off_filter * sizeof(float)]);
+                    vmmword[reg_tmp_filter + off_filter * sizeof(float)]);
         }
     }
 }
@@ -800,58 +794,59 @@ inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::load_bias() {
 
 template <cpu_isa_t isa>
 inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_step_unroll(
-        int l_pad, int r_pad, int pad_offset, int ow_block) {
-    const int pad = nstl::max(jcp.l_pad, jcp.r_pad);
-    const int iw_overlap = jcp.iw + jcp.kw - 1 - jcp.l_pad - jcp.r_pad;
-    const int unroll_w = nstl::min(jcp.ur_w, iw_overlap);
-    const int right_border = iw_overlap - ow_block;
+        int unroll_w, int l_pad, int pad_offset, int ow_block) {
+
+    const int iw_block = ow_block * jcp.stride_w;
+    const int right_border = jcp.iw - iw_block;
+
+    const int cascade_input = nstl::min(jcp.stride_w, jcp.kw);
 
 /* preamble count for the number of cascaded LOAD + FMA operations */
-    const int input_preamble_count
-            = nstl::max(jcp.kw - jcp.stride_w - l_pad, 0);
+    const int input_overlap = nstl::max(jcp.kw - l_pad, 0);
 
 /* LOAD initial input registers, then cascade LOADs and FMAs */
     for (int r = 0; r < reg_repeats; ++r) {
-        for (int i = 0; i < input_preamble_count; i++) {
-            int off_input = ((i - pad_offset) * reg_repeats + r) * simd_w;
-            Vmm vmm_input = get_input_reg((i + l_pad) * reg_repeats + r);
-            uni_vmovups(vmm_input,
-                    ptr[tmp_reg_idx_input + off_input * sizeof(float)]);
-        }
-
-        for (int i = 0; i < unroll_w; ++i) {
-            int off_output = (i * reg_repeats + r) * simd_w;
+        for (int i_ur = 0; i_ur < unroll_w; ++i_ur) {
+            int off_output = (i_ur * reg_repeats + r) * simd_w;
             Vmm vmm_output = get_output_reg(r);
             uni_vmovups(vmm_output,
-                    ptr[tmp_reg_idx_output + off_output * sizeof(float)]);
-
-            int input_load_overlap = i * jcp.stride_w + input_preamble_count;
-
-            /* Cascade 'input' loads for the corresponding FMAs */
-            const int cascade_input = nstl::min(jcp.stride_w, jcp.kw);
-            for (int c = 0; c < cascade_input; ++c) {
-                int off_input
-                        = ((c + input_load_overlap - pad_offset) * reg_repeats
-                                  + r)
-                        * simd_w;
-                Vmm vmm_input = get_input_reg(
-                        ((c + input_load_overlap + l_pad) % jcp.kw)
-                                * reg_repeats
-                        + r);
-                uni_vmovups(vmm_input,
-                        ptr[tmp_reg_idx_input + off_input * sizeof(float)]);
+                    ptr[reg_tmp_output + off_output * sizeof(float)]);
+            if (i_ur == 0) {
+                for (int c = 0; c < input_overlap; ++c) {
+                    int off_input
+                            = ((c - pad_offset) * reg_repeats + r) * simd_w;
+                    Vmm vmm_input
+                            = get_input_reg((c % jcp.kw) * reg_repeats + r);
+                    uni_vmovups(vmm_input,
+                            ptr[reg_tmp_input + off_input * sizeof(float)]);
+                }
+            } else {
+                for (int c = 0; c < cascade_input; ++c) {
+                    int overlap = (i_ur - 1) * jcp.stride_w + input_overlap;
+                    int off_input
+                            = ((overlap + c - pad_offset) * reg_repeats + r)
+                            * simd_w;
+                    Vmm vmm_input = get_input_reg(
+                            ((overlap + c) % jcp.kw) * reg_repeats + r);
+                    uni_vmovups(vmm_input,
+                            ptr[reg_tmp_input + off_input * sizeof(float)]);
+                }
             }
 
-            for (int j = 0; j < jcp.kw; ++j) {
+            for (int i_kw = 0; i_kw < jcp.kw; ++i_kw) {
+                int io_overlap = i_kw + (i_ur * jcp.stride_w);
 
                 /* Don't apply FMAs that fall into the padded region */
-                if (i + j < l_pad || i + j - pad >= right_border)
+                if (io_overlap - l_pad < 0
+                        || io_overlap - jcp.l_pad >= right_border)
                     continue;
+
                 Vmm vmm_input = get_input_reg(
-                        ((i * jcp.stride_w + j) % jcp.kw) * reg_repeats + r);
-                Vmm vmm_acc = get_acc_reg(j * reg_repeats + r);
+                        ((io_overlap - l_pad) % jcp.kw) * reg_repeats + r);
+                Vmm vmm_acc = get_acc_reg(i_kw * reg_repeats + r);
                 Vmm vmm_aux = isa == sse42 ? get_aux_reg() : vmm_input;
-                if( isa == sse42 ) uni_vmovups(vmm_aux, vmm_input);
+                if (isa == sse42)
+                    uni_vmovups(vmm_aux, vmm_input);
                 uni_vfmadd231ps(vmm_acc, vmm_aux, vmm_output);
             }
         }
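
The register cascade above is easier to check against a scalar model. A sketch of what one unrolled output row contributes to the filter gradient, with the same padded-region guard as the FMA loop (one channel lane, illustrative only; the pad_offset/base-pointer bookkeeping is folded into `input`):

    // Scalar reference: dw[i_kw] += input[i_ur*stride + i_kw - l_pad]
    //                               * diff_dst[i_ur]
    // for every tap that lands on a valid input column.
    void ow_step_reference(float *dw, const float *input,
            const float *diff_dst, int unroll_w, int kw, int stride_w,
            int l_pad, int right_border) {
        for (int i_ur = 0; i_ur < unroll_w; ++i_ur)
            for (int i_kw = 0; i_kw < kw; ++i_kw) {
                int io_overlap = i_kw + i_ur * stride_w;
                // mirrors the padded-region guard in compute_ow_step_unroll
                if (io_overlap - l_pad < 0
                        || io_overlap - l_pad >= right_border)
                    continue;
                dw[i_kw] += input[io_overlap - l_pad] * diff_dst[i_ur];
            }
    }
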
@@ -866,8 +861,16 @@ jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_bias_step_unroll(
         for (int i = 0; i < unroll_w; ++i) {
             Vmm vmm_bias = get_bias_reg(r);
             int off_output = (i * reg_repeats + r) * simd_w;
-            uni_vaddps(vmm_bias, vmm_bias,
-                    vmmword[tmp_reg_idx_output + off_output * sizeof(float)]);
+            if (isa == sse42) {
+                /* Need to support unaligned address loads for SSE42 */
+                Vmm vmm_output = get_output_reg(1 + r);
+                uni_vmovups(vmm_output,
+                        ptr[reg_tmp_output + off_output * sizeof(float)]);
+                uni_vaddps(vmm_bias, vmm_bias, vmm_output);
+            } else {
+                uni_vaddps(vmm_bias, vmm_bias,
+                        vmmword[reg_tmp_output + off_output * sizeof(float)]);
+            }
         }
     }
 }
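
On SSE42, `uni_vaddps` with a memory operand would require an aligned address, hence the extra `uni_vmovups` into a spare output register first. Functionally, the whole step is just an accumulation of diff_dst over the unrolled output positions; a scalar sketch (illustrative, per channel-block lane):

    // bias[c] accumulates diff_dst over the unrolled output positions.
    void bias_step_reference(float *bias, const float *diff_dst,
            int unroll_w, int ch_block) {
        for (int i = 0; i < unroll_w; ++i)
            for (int c = 0; c < ch_block; ++c)
                bias[c] += diff_dst[i * ch_block + c];
    }
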
@@ -879,7 +882,7 @@ inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::store_filter() {
         for (int i = 0; i < jcp.kw; ++i) {
             int off_filter = (i + reg_set) * simd_w;
             Vmm vmm_acc = get_acc_reg(i + reg_set);
-            uni_vmovups(vmmword[tmp_reg_filter + off_filter * sizeof(float)],
+            uni_vmovups(vmmword[reg_tmp_filter + off_filter * sizeof(float)],
                     vmm_acc);
         }
     }
@@ -895,343 +898,304 @@ inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::store_bias() {
 }
 
 template <cpu_isa_t isa>
-inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::create_h_bounds_table() {
-    /* Bounds are stored on an 8-bit sized element.
-     * XXX: potential issues if bounds exceed 255.
-     */
-    const bool handle_padding = (jcp.t_pad > 0) || (jcp.b_pad > 0);
-    if (handle_padding) {
-
-        /* Calculate how many 'h_start' bounds are needed */
-        const int h_bounds_count = get_loop_bounds_count(
-                nstl::max(jcp.t_pad, jcp.b_pad), jcp.oh, jcp.oh_blk_size);
-
-        align(64);
-        L(bound_start_table);
-        /* Generate starting bounds for 'oh' loop. This value also determines
-         * the overlap (computed as an address offset) between the output over
-         * the input for that loop iteration. */
-        for (int oh_block = 0; oh_block < h_bounds_count; ++oh_block) {
-            for (int kh = 0; kh < jcp.kh; ++kh) {
-                te_size start_bound = nstl::max(
-                        jcp.t_pad - oh_block * jcp.oh_blk_size - kh, 0);
-                write_table(start_bound);
-            }
-        }
-        /* Write offset count for 'input' address calculation. The offset for
-         * the input address is conditioned by the 'h' padding intersection over
-         * the output rows. */
-        for (int kh = 1; kh < jcp.kh; ++kh) {
-            te_size kh_accum_value = nstl::max(nstl::min(kh - jcp.t_pad, 1), 0);
-            write_table(kh_accum_value);
-        }
-        /* Last value is not used for offset calculation, write 'nop'
-         * equivalent*/
-        write_table(0);
-
-        /* Non-padded blocks always increment 'kh' dimension */
-        for (int oh_block = 0; oh_block < h_bounds_count - 1; oh_block++) {
-            for (int kh = 0; kh < jcp.kh; ++kh) {
-                te_size kh_accum_value = 1;
-                write_table(kh_accum_value);
-            }
-        }
-
-        /* number of input elements that overlap over output */
-        int ih_overlap = jcp.oh_blk_size + jcp.kh - 1 - jcp.t_pad - jcp.b_pad;
-
-        /* End Bounds for 'oh' default to 'OH' or OH_BLOCK_SIZE, unless
-         * the 'oh_block' is within the 'bottom_padding' region. */
-        int oh_end_blk = 0;
-        for (; oh_end_blk < h_bounds_count - 1; ++oh_end_blk) {
-            for (int kh = 0; kh < jcp.kh; ++kh) {
-                te_size end_bound = nstl::min((jcp.ih / jcp.stride_h)
-                                - jcp.oh_blk_size - oh_end_blk * jcp.oh_blk_size
-                                + ih_overlap + 1 - kh,
-                        jcp.oh_blk_size);
-                write_table(end_bound);
-            }
-        }
-        /* Write bounds for the special case of when 'oh_block' falls within the
-         * 'bottom_paddin' region - this always executes since at least 1 row of
-         * bounds should exist. */
-        const int pad = nstl::max(jcp.b_pad, jcp.t_pad);
-        ih_overlap
-                = (jcp.ih / jcp.stride_h + jcp.kh - 1 - jcp.t_pad - jcp.b_pad);
-        oh_end_blk = jcp.oh - jcp.oh_blk_size;
-        for (int kh = 0; kh < jcp.kh; ++kh) {
-            te_size end_bound = nstl::min(
-                    jcp.oh_blk_size, ih_overlap - oh_end_blk + pad - kh);
-            write_table(end_bound);
-        }
-    }
-}
-
-template <cpu_isa_t isa>
-inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_bias_loop() {
-
+inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_bias_loop(
+        const int block_size) {
     Label oh_label;
     Label ow_blk_label;
 
-    const int oh_block_size = jcp.oh_blk_size;
-    const int ow_unroll = jcp.ur_w;
-    const int ow_block_count = jcp.ow / ow_unroll;
+    const int unroll_w = nstl::min(block_size, jcp.ow);
+    const int unroll_w_trips = jcp.ow / unroll_w;
+    const int tail_w = jcp.ow > block_size ? jcp.ow % block_size : 0;
+
     const int ch_offset = jcp.ch_block;
 
-    mov(tmp_reg_idx_output, reg_output_baddr);
+    mov(reg_oh, ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_index)]);
+    mov(reg_oh_worksize,
+            ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_count)]);
 
-    xor_(iter_oh, iter_oh);
+    mov(reg_tmp_output, reg_output_baddr);
     L(oh_label);
     {
 
-        xor_(iter_ow_blk, iter_ow_blk);
+        mov(iter_ow_blk, unroll_w_trips);
         L(ow_blk_label);
         {
 
-            compute_bias_step_unroll(ow_unroll);
+            compute_bias_step_unroll(unroll_w);
+            add(reg_tmp_output, unroll_w * ch_offset * sizeof(float));
 
-            add(tmp_reg_idx_output, ow_unroll * ch_offset * sizeof(float));
+            dec(iter_ow_blk);
+            cmp(iter_ow_blk, 0);
+            jg(ow_blk_label, T_NEAR);
+        }
 
-            inc(iter_ow_blk);
-            cmp(iter_ow_blk, ow_block_count);
-            jl(ow_blk_label, T_NEAR);
+        if (tail_w > 0) {
+            compute_bias_step_unroll(tail_w);
+            add(reg_tmp_output, tail_w * ch_offset * sizeof(float));
         }
 
-        inc(iter_oh);
-        cmp(iter_oh, oh_block_size);
+        inc(reg_oh);
+        cmp(reg_oh, reg_oh_worksize);
         jl(oh_label, T_NEAR);
     }
 }
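
The loop structure follows a simple full-trips-plus-tail split of `ow`. A standalone sketch of that arithmetic (values mirror the JIT constants; hypothetical helper):

    #include <algorithm>

    struct bias_blocking { int unroll_w, trips, tail; };

    bias_blocking split_ow(int ow, int block_size) {
        bias_blocking b;
        b.unroll_w = std::min(block_size, ow); // full-trip width
        b.trips = ow / b.unroll_w;             // number of full trips
        b.tail = ow > block_size ? ow % block_size : 0;
        return b;
    }
    // e.g. ow = 17, block_size = 15 -> {15, 1, 2}: one full trip, tail of 2.
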
 
 template <cpu_isa_t isa>
-inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_kh_loop(
-        int l_pad, int r_pad, int pad_offset, bool first_iteration,
-        int ow_block) {
+inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_zero_filter() {
 
-    Label kh_label;
-    Label oh_label;
-    Label exit_innerloop_label;
-    Label skip_load_acc;
+    const int ch_offset = jcp.ch_block;
 
-    const int table_row_count = get_loop_bounds_count(
-            nstl::max(jcp.t_pad, jcp.b_pad), jcp.oh, jcp.oh_blk_size);
-    const int ih_table_off = 1 * table_row_count * jcp.kh * sizeof(te_size);
-    const int end_bound_table_off
-            = 2 * table_row_count * jcp.kh * sizeof(te_size);
+    Label kh_loop_label, skip_zeroing_label;
+
+    mov(reg_exec_flags,
+            ptr[this->param1 + offsetof(jit_dw_conv_call_s, exec_flags)]);
+    and_(reg_exec_flags, FLAG_ZERO_FILTER);
+    test(reg_exec_flags, reg_exec_flags);
+    je(skip_zeroing_label);
+
+    zero_filter();
+
+    mov(reg_tmp_filter, reg_filter_baddr);
+    mov(reg_kh, jcp.kh);
+    L(kh_loop_label);
+    {
+        store_filter();
+
+        add(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float));
+        dec(reg_kh);
+        cmp(reg_kh, 0);
+        jg(kh_loop_label);
+    }
+
+    /* Rewind the filter pointer */
+    sub(reg_tmp_filter, jcp.kh * jcp.kw * ch_offset * sizeof(float));
+
+    L(skip_zeroing_label);
+}
+
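
The FLAG_ZERO_FILTER test above means the filter accumulators are cleared and stored back only when the caller asks for it, i.e. on a thread's first pass over a filter block; later passes accumulate on top. A hedged driver-side sketch (the helper is hypothetical; see jit_dw_conv_call_s for the real interface):

    // Hypothetical: request filter zeroing only for the first chunk a
    // thread processes; subsequent chunks accumulate into the same block.
    unsigned exec_flags_for_chunk(int chunk_idx, unsigned FLAG_ZERO_FILTER) {
        return chunk_idx == 0 ? FLAG_ZERO_FILTER : 0u;
    }
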
+template <cpu_isa_t isa>
+inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_h_step(
+        int unroll_w, int l_pad, int pad_offset, int ow_block) {
 
     const int ch_offset = jcp.ch_block;
 
-    const bool handle_padding = (jcp.t_pad > 0) || (jcp.b_pad > 0);
+    Label kh_loop_label, skip_loop_label;
 
-    mov(tmp_reg_filter, reg_filter_baddr);
-    mov(tmp_reg_kh_input, reg_input_baddr);
-    xor_(reg_tmp_off, reg_tmp_off);
+    cmp(reg_kh_count, 0);
+    je(skip_loop_label, T_NEAR);
 
-    if (handle_padding) {
-        mov(reg_bound_table_addr, bound_start_table);
+    mov(reg_kh, reg_kh_count);
+    L(kh_loop_label);
+    {
+        load_filter();
+        compute_ow_step_unroll(unroll_w, l_pad, pad_offset, ow_block);
+        store_filter();
 
-        /* move to the row containing the indices for the current 'h' block */
-        mov(reg_tmp_off, reg_table_idx);
-        imul(reg_tmp_off, reg_tmp_off, jcp.kh * sizeof(unsigned char));
-        add(reg_bound_table_addr, reg_tmp_off);
+        add(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float));
+        add(reg_tmp_input, jcp.iw * ch_offset * sizeof(float));
+        dec(reg_kh);
+        cmp(reg_kh, 0);
+        jg(kh_loop_label);
     }
 
-    xor_(iter_kh, iter_kh);
-    L(kh_label);
+    /* Rewind the input and filter pointers */
+    Label kh_comeback_label;
+    mov(reg_kh, reg_kh_count);
+    L(kh_comeback_label);
     {
+        sub(reg_tmp_input, jcp.iw * ch_offset * sizeof(float));
+        sub(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float));
+        dec(reg_kh);
+        cmp(reg_kh, 0);
+        jg(kh_comeback_label, T_NEAR);
+    }
 
-        mov(tmp_reg_idx_output, reg_output_baddr);
-        mov(tmp_reg_idx_input, tmp_reg_kh_input);
+    L(skip_loop_label);
+}
 
-        if (first_iteration) {
+template <cpu_isa_t isa>
+inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_h_loop(
+        int unroll_w, int l_pad, int pad_offset, int ow_block) {
 
-            /* apply zero filter */
-            zero_filter();
+    const size_t io_overlap = jcp.ih / jcp.stride_h < jcp.oh ?
+            jcp.ih / jcp.stride_h - 1 :
+            jcp.oh - jcp.b_pad - 1;
+    const int ch_offset = jcp.ch_block;
+    const int t_overlap_off = jcp.t_pad % jcp.stride_h == 0 ? jcp.stride_h : 1;
+    const int b_overlap_off = jcp.b_pad % jcp.stride_h == 0 ? jcp.stride_h : 1;
 
-            /* if zero_filter_flag is set to '1', load filter memory into
-             * reg_accum */
-            if (jcp.with_bias) {
-                mov(reg_tmp_al, reg_exec_flag);
-                and_(reg_tmp_al, FLAG_ZERO_FILTER);
-                cmp(reg_tmp_al, 0);
-            } else {
-                /* none of the other flags are active, so we can use the
-                 * register directly */
-                cmp(reg_exec_flag, 0);
-            }
-            je(skip_load_acc);
-            load_filter();
-            L(skip_load_acc);
+    Label tpad_loop_label, h_loop_label, skip_tpad_label, skip_bpad_label,
+            end_h_loop_label;
 
-        } else {
-            load_filter();
-        }
+    mov(reg_oh, ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_index)]);
+    mov(reg_oh_worksize,
+            ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_count)]);
+    mov(reg_kh_count,
+            ptr[this->param1 + offsetof(jit_dw_conv_call_s, kh_count)]);
 
-        xor_(iter_oh, iter_oh);
+    mov(reg_tmp_output, reg_output_baddr);
+    mov(reg_tmp_input, reg_input_baddr);
+    mov(reg_tmp_filter, reg_filter_baddr);
 
-        if (handle_padding) {
+    L(h_loop_label);
+    {
 
-            /* 'oh loop' initial bounds are stored in bound_table */
-            mov(iter_oh_lb, byte[reg_bound_table_addr]);
+        compute_h_step(unroll_w, l_pad, pad_offset, ow_block);
 
-            /* skip 'oh' row that intersects with top padding */
-            xor_(reg_tmp_off, reg_tmp_off);
-            mov(reg_tmp_off, iter_oh);
-            imul(reg_tmp_off, reg_tmp_off, jcp.ow * ch_offset * sizeof(float));
-            add(tmp_reg_idx_output, reg_tmp_off);
+        add(reg_tmp_output, jcp.ow * ch_offset * sizeof(float));
 
-            /* forward the input address by 'stride_h' */
-            if (jcp.stride_h > 1) {
-                xor_(reg_tmp_off, reg_tmp_off);
-                mov(reg_tmp_off, iter_oh);
-                imul(reg_tmp_off, reg_tmp_off,
-                        (jcp.stride_h - 1) * jcp.iw * ch_offset * sizeof(float));
-                add(tmp_reg_idx_input, reg_tmp_off);
-            }
-        }
-
-        L(oh_label);
-        {
+        /* If within the top_pad region */
+        if (jcp.t_pad > 0) {
+            /* Skip t_pad area if no longer in initial h_block */
+            cmp(reg_oh, jcp.t_pad);
+            jg(skip_tpad_label, T_NEAR);
 
-            compute_ow_step_unroll(l_pad, r_pad, pad_offset, ow_block);
+            cmp(reg_kh_count, jcp.kh);
+            jge(skip_tpad_label, T_NEAR);
 
-            add(tmp_reg_idx_input,
-                    jcp.stride_h * jcp.iw * ch_offset * sizeof(float));
-            add(tmp_reg_idx_output, jcp.ow * ch_offset * sizeof(float));
+            add(reg_kh_count, t_overlap_off);
+            sub(reg_tmp_filter,
+                    t_overlap_off * jcp.kw * ch_offset * sizeof(float));
 
-            inc(iter_oh);
-            if (handle_padding) {
-                /* 'oh loop' end bounds are stored in bound_table (precomputed
-                 * during JIT generation) */
-                cmp(iter_oh_lb,
-                        byte[reg_bound_table_addr + end_bound_table_off]);
-            } else {
-                cmp(iter_oh, jcp.oh_blk_size);
+            /* kernel has moved beyond padding (adjust for stride effects) */
+            if (jcp.t_pad % jcp.stride_h != 0) {
+                int inp_corr = jcp.stride_h - jcp.t_pad % jcp.stride_h;
+                add(reg_tmp_input,
+                        inp_corr * jcp.iw * ch_offset * sizeof(float));
             }
-            jl(oh_label, T_NEAR);
+            jmp(tpad_loop_label, T_NEAR);
         }
 
-        store_filter();
+        L(skip_tpad_label);
 
-        add(tmp_reg_filter, jcp.kw * ch_offset * sizeof(float));
+        cmp(reg_oh, io_overlap);
+        jl(skip_bpad_label, T_NEAR);
+        sub(reg_kh_count, b_overlap_off);
 
-        if (handle_padding) {
-            xor_(kh_offset, kh_offset);
-            mov(kh_offset_lb, byte[reg_bound_table_addr + ih_table_off]);
-            /* increase 'ih' row in regards to 'kh'. */
-            imul(kh_offset, kh_offset, jcp.iw * ch_offset * sizeof(float));
-            add(tmp_reg_kh_input, kh_offset);
+        L(skip_bpad_label);
+        add(reg_tmp_input, jcp.stride_h * jcp.iw * ch_offset * sizeof(float));
 
-            /* increase bound_table idx for the next 'kh' value in table*/
-            add(reg_bound_table_addr, sizeof(te_size));
-        } else {
-            add(tmp_reg_kh_input, jcp.iw * ch_offset * sizeof(float));
-        }
+        L(tpad_loop_label);
+
+        cmp(reg_oh, jcp.ih / jcp.stride_h);
+        jge(end_h_loop_label, T_NEAR);
 
-        inc(iter_kh);
-        cmp(iter_kh, jcp.kh);
-        jl(kh_label, T_NEAR);
+        inc(reg_oh);
+
+        cmp(reg_oh, reg_oh_worksize);
+        jl(h_loop_label, T_NEAR);
     }
+    L(end_h_loop_label);
 }
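
The top/bottom padding handling above amounts to tracking, per output row, how many kernel rows land on valid input. A host-side reference of that count (illustrative; the JIT keeps it incrementally in reg_kh_count via the t_overlap_off/b_overlap_off steps):

    #include <algorithm>

    // For output row `oh`: number of kernel rows that fall on valid input.
    int valid_kh_count(int oh, int kh, int ih, int stride_h, int t_pad) {
        int ih_start = oh * stride_h - t_pad;  // first input row touched
        int lo = std::max(0, -ih_start);       // rows lost to top padding
        int hi = std::min(kh, ih - ih_start);  // rows clipped at the bottom
        return std::max(0, hi - lo);
    }
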
 
 template <cpu_isa_t isa>
 inline void
 jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_block_unroll() {
 
-    Label skip_load_bias;
-
-    /* Only apply zero_filter (xor'ing accum_reg) on the left edge */
-    bool zero_filter_1st_iter = true;
-
     const int ch_offset = jcp.ch_block;
-
-    const int ow_block_size = jcp.ow_blk_size;
-    const int iw_block_size = jcp.ow_blk_size * jcp.stride_w;
-
-    int w_unrolled_loop_count = jcp.ow / ow_block_size;
-
-    const bool handle_padding = (jcp.l_pad > 0) || (jcp.r_pad > 0);
-
-    int pad_offset = jcp.l_pad;
-
-    int ow_block = 0;
-
+    int ow = jcp.ow;
+    int pad_offset = 0;
+    int l_pad = jcp.l_pad;
+
+    /* Calculate effective padding */
+    int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
+                    + (jcp.kw - 1) * (jcp.dilate_w + 1)
+                    - (jcp.iw + jcp.l_pad - 1));
+
+    /* Is this limit strictly defined by:
+     *  - code size (?)
+     *  - address size (?) */
+    const int max_unroll_w = 30;
+    const int block_size = 15;
+
+    int unroll_w_tail = 0;
+    int unroll_w = 0;
+    int unroll_w_trips = 0;
+
+    if (jcp.ow > max_unroll_w) {
+        unroll_w = nstl::min(block_size, jcp.ow);
+        unroll_w_trips = ow / unroll_w;
+        /* calculate tail */
+        unroll_w_tail = ow % unroll_w;
+        /* Perform some rebalancing if the tail is too small */
+        if ((unroll_w_tail == 0 && r_pad != 0)
+                || (r_pad > 0 && r_pad >= unroll_w_tail)) {
+            if (unroll_w_trips > 1) {
+                unroll_w_tail += unroll_w;
+                unroll_w_trips--;
+            } else {
+                /* Ideally, this case shouldn't happen */
+                unroll_w_tail += (unroll_w - unroll_w / 2);
+                unroll_w = unroll_w / 2;
+            }
+        }
+    } else {
+        unroll_w = jcp.ow;
+        unroll_w_trips = nstl::max(1, ow / unroll_w);
+    }
     if (jcp.with_bias) {
+        Label skip_load_bias;
+        mov(reg_bias_baddr,
+                ptr[this->param1 + offsetof(jit_dw_conv_call_s, bias)]);
 
         zero_bias();
 
-        /* if zero_bias is '1', load bias accumulator from memory. This happens
-         * after the first iteration is executed  */
-        mov(reg_tmp_al, reg_exec_flag);
-        and_(reg_tmp_al, FLAG_ZERO_BIAS);
-        cmp(reg_tmp_al, 0);
-        je(skip_load_bias);
+        mov(reg_exec_flags,
+                ptr[this->param1 + offsetof(jit_dw_conv_call_s, exec_flags)]);
+        and_(reg_exec_flags, FLAG_ZERO_BIAS);
+        test(reg_exec_flags, reg_exec_flags);
+        jne(skip_load_bias);
+
         load_bias();
-        L(skip_load_bias);
 
-        compute_bias_loop();
+        L(skip_load_bias);
+        compute_bias_loop(block_size);
 
         store_bias();
     }
 
-    /* compute left padded block */
-    if (handle_padding) {
-
-        const int r_pad = jcp.iw - ow_block_size > 0 ? 0 : jcp.r_pad;
-
-        compute_kh_loop(jcp.l_pad, r_pad, 0, zero_filter_1st_iter, ow_block);
-        zero_filter_1st_iter = false;
+    /* Pass the filter base address, then apply the offset for h padding. */
+    compute_zero_filter();
+    mov(reg_kh_offset,
+            ptr[this->param1 + offsetof(jit_dw_conv_call_s, filter_pad_off)]);
+    add(reg_filter_baddr, reg_kh_offset);
 
-        w_unrolled_loop_count--;
-
-        if (w_unrolled_loop_count >= 1) {
-            add(reg_output_baddr, ow_block_size * ch_offset * sizeof(float));
-            add(reg_input_baddr, iw_block_size * ch_offset * sizeof(float));
-        }
+    /* compute left padded block */
+    if (l_pad) {
+        compute_h_loop(unroll_w, l_pad, 0, 0);
+        add(reg_output_baddr, unroll_w * ch_offset * sizeof(float));
+        add(reg_input_baddr,
+                unroll_w * jcp.stride_w * ch_offset * sizeof(float));
+        unroll_w_trips--;
+        pad_offset = l_pad;
+        l_pad = 0;
     }
 
-    /* This block may execute under 2 different scenarios:
-     * 1) When padding is present, this executes the middle loop (if any).
-     * 2) With no padding, it writes the full loop of the micro-kernel. */
-    int middle_loop_count = handle_padding ? w_unrolled_loop_count - 1 :
-                                             w_unrolled_loop_count;
-    if (middle_loop_count >= 1) {
-        Label ow_blk_label;
-
-        /* Insert loop for 'ow' block when middle block needs to execute more
-         * than once */
-        bool do_ow_blk_loop = middle_loop_count > 1;
-        if (do_ow_blk_loop) {
-            mov(iter_ow_blk, middle_loop_count);
-            L(ow_blk_label);
-        }
-
-        compute_kh_loop(0, 0, pad_offset, zero_filter_1st_iter);
-        /* disable zero_filter for the rest of the iterations i.e. from now on
-         * load contents of 'filter' from memory */
-        mov(reg_exec_flag, FLAG_ZERO_FILTER);
-
-        if (do_ow_blk_loop || handle_padding) {
-            add(reg_output_baddr, ow_block_size * ch_offset * sizeof(float));
-            add(reg_input_baddr, iw_block_size * ch_offset * sizeof(float));
-        }
-
-        if (do_ow_blk_loop) {
-            dec(iter_ow_blk);
-            cmp(iter_ow_blk, 0);
-            jg(ow_blk_label, T_NEAR);
-        }
+    /* compute middle block */
+    Label ow_blk_label;
 
-        w_unrolled_loop_count -= middle_loop_count;
+    /* Insert loop for 'ow' block when middle block needs to execute more
+     * than once */
+    bool do_ow_blk_loop = unroll_w_trips > 1;
+    if (do_ow_blk_loop) {
+        mov(iter_ow_blk, unroll_w_trips);
+        L(ow_blk_label);
+    }
+    if (unroll_w_trips > 0) {
+        compute_h_loop(unroll_w, l_pad, pad_offset, 0);
+        add(reg_output_baddr, unroll_w * ch_offset * sizeof(float));
+        add(reg_input_baddr,
+                unroll_w * jcp.stride_w * ch_offset * sizeof(float));
+    }
+    if (do_ow_blk_loop) {
+        dec(iter_ow_blk);
+        cmp(iter_ow_blk, 0);
+        jg(ow_blk_label, T_NEAR);
     }
 
-    /* compute right padded block: ow_blk = LAST */
-    if (handle_padding && w_unrolled_loop_count >= 1) {
-        ow_block = jcp.ow - ow_block_size;
-        compute_kh_loop(
-                0, jcp.r_pad, pad_offset, zero_filter_1st_iter, ow_block);
-
-        w_unrolled_loop_count--;
+    /* compute right padded block */
+    if (unroll_w_tail) {
+        compute_h_loop(unroll_w_tail, 0, pad_offset, jcp.ow - unroll_w_tail);
     }
 }
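
The ow-blocking decision above, including the tail rebalancing, fits in a few lines of host code. A standalone sketch (mirrors the JIT constants max_unroll_w = 30 and block_size = 15; hypothetical helper):

    #include <algorithm>

    struct ow_blocking { int unroll_w, trips, tail; };

    ow_blocking split_ow_unroll(int ow, int iw, int kw, int stride_w,
            int dilate_w, int l_pad) {
        const int max_unroll_w = 30, block_size = 15;
        // effective right padding, as computed in compute_ow_block_unroll
        int r_pad = std::max(0, (ow - 1) * stride_w
                + (kw - 1) * (dilate_w + 1) - (iw + l_pad - 1));
        ow_blocking b;
        if (ow > max_unroll_w) {
            b.unroll_w = std::min(block_size, ow);
            b.trips = ow / b.unroll_w;
            b.tail = ow % b.unroll_w;
            // tail too small to absorb r_pad: steal a full trip, or,
            // failing that, halve the unroll
            if ((b.tail == 0 && r_pad != 0)
                    || (r_pad > 0 && r_pad >= b.tail)) {
                if (b.trips > 1) {
                    b.tail += b.unroll_w;
                    b.trips--;
                } else {
                    b.tail += b.unroll_w - b.unroll_w / 2;
                    b.unroll_w /= 2;
                }
            }
        } else {
            b.unroll_w = ow;
            b.trips = 1;
            b.tail = 0;
        }
        return b;
    }
    // e.g. ow = 32, kw = 3, stride 1, l_pad = 1, iw = 32 -> r_pad = 1,
    // unroll_w = 15, trips = 2, tail = 2; r_pad < tail, so no rebalancing.
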
 
@@ -1245,17 +1209,10 @@ void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::generate() {
             ptr[this->param1 + offsetof(jit_dw_conv_call_s, output)]);
     mov(reg_filter_baddr,
             ptr[this->param1 + offsetof(jit_dw_conv_call_s, filter)]);
-    if (jcp.with_bias)
-        mov(reg_bias_baddr,
-                ptr[this->param1 + offsetof(jit_dw_conv_call_s, bias)]);
-    mov(reg_table_flags,
-            ptr[this->param1 + offsetof(jit_dw_conv_call_s, table_flags)]);
 
     compute_ow_block_unroll();
 
     this->postamble();
-
-    create_h_bounds_table();
 }
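
All per-call state now arrives through jit_dw_conv_call_s rather than JIT-time tables. A reconstruction of the fields this diff actually reads via offsetof (field order and types are assumptions; the real definition presumably lives in jit_primitive_conf.hpp):

    #include <cstddef>

    // Assumed shape of the call struct, inferred from the offsetof() uses
    // in this file only; illustrative, not the authoritative definition.
    struct jit_dw_conv_call_s_sketch {
        const void *input, *output, *filter, *bias;
        size_t oh_index, oh_count, kh_count;
        size_t filter_pad_off;
        unsigned char exec_flags;
    };
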
 
 template <cpu_isa_t isa>
@@ -1263,8 +1220,7 @@ status_t jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::init_conf(
         jit_conv_conf_t &jcp, const convolution_desc_t &cd,
         const memory_desc_wrapper &src_d,
         const memory_desc_wrapper &diff_weights_d,
-        const memory_desc_wrapper &diff_dst_d) {
-
+        const memory_desc_wrapper &diff_dst_d, int nthreads) {
     if (!mayiuse(isa))
         return status::unimplemented;
 
@@ -1295,8 +1251,6 @@ status_t jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::init_conf(
     jcp.stride_w = cd.strides[1];
 
     jcp.t_pad = cd.padding[0][0];
-    /* bottom padding should equal top padding to generate the proper 'h' loop
-     * bounds. */
     jcp.b_pad = cd.padding[1][0];
 
     jcp.l_pad = cd.padding[0][1];
@@ -1315,53 +1269,71 @@ status_t jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::init_conf(
     auto desired_act_fmt = isa == avx512_common ? nChw16c : nChw8c;
     auto desired_wei_fmt = isa == avx512_common ? Goihw16g : Goihw8g;
 
-    bool args_ok = true
-                   && src_d.format() == desired_act_fmt
-                   && diff_weights_d.format() == desired_wei_fmt
-                   && diff_dst_d.format() == desired_act_fmt
-                   && one_of(cd.bias_desc.format, memory_format::undef, any, x)
-                   //&& jcp.ngroups % simd_w == 0
-                   && jcp.ngroups % jcp.ch_block == 0
-                   && jcp.dilate_h == 0
-                   && jcp.dilate_w == 0
-                   && jcp.kw <= 3
-                   && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
-                   && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1;
-    if (!args_ok) return status::unimplemented;
-
-    /* Note: this IMPLICATION-check does not allow 'negative padding' execution
-     */
-    bool ok = true && IMPLICATION(jcp.r_pad > 0, jcp.r_pad == jcp.l_pad)
-            && IMPLICATION(jcp.b_pad > 0, jcp.b_pad == jcp.t_pad);
-    if (!ok)
+    bool args_ok = true && src_d.format() == desired_act_fmt
+            && diff_weights_d.format() == desired_wei_fmt
+            && diff_dst_d.format() == desired_act_fmt
+            && one_of(cd.bias_desc.format, memory_format::undef, any, x)
+            && jcp.ngroups % jcp.ch_block == 0 && jcp.dilate_h == 0
+            && jcp.dilate_w == 0 && jcp.kw <= 3
+            && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
+            && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1;
+    if (!args_ok)
         return status::unimplemented;
 
     jcp.nb_ch = jcp.ngroups / jcp.ch_block;
 
-    /* Values for block size to try; order gives priority */
-    constexpr int BLOCK_SIZE[] = { 14, 16, 7, 8 };
-
-    int block_size_h = 1;
-    int block_size_w = 1;
+    /* Kernel applicability check w.r.t. boundaries:
+     * the conditions are quite general across the kernels we have,
+     * but ideally the check should belong to a specific kernel... */
+    const int max_hpad = (jcp.kh - 1 + 1) / 2;
+    const int max_wpad = (jcp.kw - 1 + 1) / 2;
+    const bool boundaries_ok = true && jcp.t_pad <= max_hpad
+            && jcp.b_pad <= max_hpad && jcp.l_pad <= max_wpad
+            && jcp.r_pad <= max_wpad;
+    if (!boundaries_ok)
+        return status::unimplemented;
 
-    /* *Try different block sizes for convolution */
-    for (int block : BLOCK_SIZE) {
+    balance(jcp, nthreads);
 
-        block_size_h = block / jcp.stride_h;
-        block_size_w = block / jcp.stride_w;
+    return status::success;
+}
 
-        if ((jcp.oh % block_size_h == 0) && (jcp.ow % block_size_w == 0))
-            break;
+template <cpu_isa_t isa>
+void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::init_scratchpad(
+        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+    /* Note: if the thread work is split over 'mb', a reduction has to take
+     * place. Hence, book a per-thread local weights buffer for the
+     * reduction. */
+    if (jcp.nthr_mb > 1) {
+        const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw;
+        scratchpad.book(key_conv_wei_reduction,
+                sizeof(float) * wei_size * (jcp.nthr_mb - 1));
+
+        if (jcp.with_bias)
+            scratchpad.book(key_conv_bia_reduction,
+                    sizeof(float) * jcp.ngroups * (jcp.nthr_mb - 1));
     }
+}
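
A sketch of the reduction this buffer serves: each extra mb-thread accumulates into a private copy, which the driver later folds into the user's diff_weights (plausible pattern, not the verbatim driver code):

    #include <cstddef>

    // Fold the (nthr_mb - 1) private weight copies into diff_weights.
    void reduce_weights(float *diff_weights, const float *wei_reduction,
            int nthr_mb, size_t wei_size) {
        for (int t = 0; t < nthr_mb - 1; ++t)
            for (size_t i = 0; i < wei_size; ++i)
                diff_weights[i] += wei_reduction[t * wei_size + i];
    }
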
 
-    if (jcp.oh % block_size_h != 0 || jcp.ow % block_size_w != 0)
-        return status::unimplemented;
-
-    jcp.oh_blk_size = block_size_h;
-
-    jcp.ur_w = jcp.ow_blk_size = block_size_w;
+template <cpu_isa_t isa>
+void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::balance(jit_conv_conf_t &jcp,
+        int nthreads) {
+    jcp.nthr = nthreads;
+    jcp.nthr_g = jcp.nthr_mb = 1;
+
+    /* Basic heuristics for the parallel strategy:
+     * 1) Parallelize over the number of groups (g), where tasks are
+     * independent. Otherwise,
+     * 2) Split the work across g and the minibatch (mb).
+     * Parallelizing over mb requires a reduction for the weights.
+     *
+     * NOTE: because of this task-partitioning scheme, the per-thread load
+     * will be unbalanced when the number of threads is high (e.g. > 16).
+     */
+    jcp.nthr_g = nstl::min(jcp.nb_ch, jcp.nthr);
+    jcp.nthr_mb = nstl::min(nstl::max(1, jcp.nthr / jcp.nthr_g), jcp.mb);
 
-    return status::success;
+    jcp.nthr = jcp.nthr_g * jcp.nthr_mb;
 }
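
For example, with nthreads = 8, nb_ch = 4 and mb = 16: nthr_g = min(4, 8) = 4 and nthr_mb = min(max(1, 8 / 4), 16) = 2, so jcp.nthr = 8 and each group's work is additionally split two ways over the minibatch, at the cost of one weights-reduction copy in the scratchpad.
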
 
 template struct jit_uni_dw_conv_bwd_weights_kernel_f32<avx512_common>;