Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx2_conv_kernel_f32.cpp
index 392622a..0caa4b4 100644 (file)
@@ -15,7 +15,6 @@
 * limitations under the License.
 *******************************************************************************/
 
-#include <common/primitive_attr.hpp>
 #include "c_types_map.hpp"
 #include "nstl.hpp"
 #include "type_helpers.hpp"
@@ -32,6 +31,7 @@ namespace cpu {
 
 using namespace mkldnn::impl::prop_kind;
 using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
 using namespace mkldnn::impl::utils;
 
 using namespace Xbyak;
@@ -77,9 +77,8 @@ void jit_avx2_conv_fwd_kernel_f32::oh_step_unroll_kw(int ur_w,
                         vfmadd231ps(Ymm(ur_w * ii + jj),
                                 Ymm(oc_blocks * ur_w + jj), ymm15);
                     else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support
-                        Ymm tmp = ymask;
-                        vmulps(tmp, ymm15, Ymm(oc_blocks * ur_w + jj));
-                        vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), tmp);
+                        vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj));
+                        vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp);
                     }
             }
         }
@@ -131,9 +130,8 @@ void jit_avx2_conv_fwd_kernel_f32::oh_step_nopad(int ur_w,
                         vfmadd231ps(Ymm(ur_w * ii + jj),
                                 Ymm(oc_blocks * ur_w + jj), ymm15);
                     else { // Intel AVX support
-                        Ymm tmp = ymask;
-                        vmulps(tmp, ymm15, Ymm(oc_blocks * ur_w + jj));
-                        vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), tmp);
+                        vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj));
+                        vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp);
                     }
             }
         }
@@ -176,7 +174,7 @@ void jit_avx2_conv_fwd_kernel_f32::width_blk_step(int ur_w,
         for (int jj = 0; jj < ur_w; jj++) {
             size_t offt;
             if (jcp.with_dw_conv)
-                offt = sizeof(float) * ((size_t)ii * od * jcp.dw_conv_ker_h * ow + jj) * oc_blk;
+                offt = sizeof(float) * ((size_t)ii * od * jcp_dw.kh * ow + jj) * oc_blk;
             else
                 offt = sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk;
             vmovups(Ymm(ur_w * ii + jj),
@@ -224,7 +222,8 @@ void jit_avx2_conv_fwd_kernel_f32::width_blk_step(int ur_w,
         mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
         mov(aux_reg_inp_d, reg_input);
 
-        if ((jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) {
+        if ((jcp.dilate_d >= jcp.id)
+                || (jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) {
             cmp(reg_ki, 0);
             je(skip_kd_loop, T_NEAR);
         }
@@ -239,7 +238,8 @@ void jit_avx2_conv_fwd_kernel_f32::width_blk_step(int ur_w,
         mov(aux_reg_kernel, aux_reg_ker_d);
     }
 
-    if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
+    if ((jcp.dilate_h >= jcp.ih)
+            || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
         cmp(kj, 0);
         je(skip_kh_loop, T_NEAR);
     }
@@ -279,8 +279,7 @@ void jit_avx2_conv_fwd_kernel_f32::width_blk_step(int ur_w,
         pop(reg_output);
     }
 
-
-    Label done, regular_store;
+    Label regular_store;
 
     test(reg_ci_flag, FLAG_IC_LAST);
     je(regular_store, T_NEAR);
@@ -289,10 +288,6 @@ void jit_avx2_conv_fwd_kernel_f32::width_blk_step(int ur_w,
     int depthwise_inj_idx = 0;
     const auto &p = attr_.post_ops_;
 
-    if (p.len_ == 0 && eltwise_injectors.size() == 1) {
-        eltwise_injectors[0]->compute_vector_range(0, oc_blocks * ur_w);
-    }
-
     int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
     for (int i = 0; i < end_idx; i++) {
         auto& post_op = p.entry_[i];
@@ -324,14 +319,13 @@ void jit_avx2_conv_fwd_kernel_f32::width_blk_step(int ur_w,
         for (int jj = 0; jj < ur_w; jj++) {
             size_t o_off;
             if (jcp.with_dw_conv)
-                o_off = sizeof(float) * ((size_t)ii * od * jcp.dw_conv_ker_h * ow + jj) * oc_blk;
+                o_off = sizeof(float) * ((size_t)ii * od * jcp_dw.kh * ow + jj) * oc_blk;
             else
                 o_off = sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk;
             Ymm reg_out = Ymm(ur_w * ii + jj);
             vmovups(make_safe_addr(reg_output, o_off, reg_long_offt), reg_out);
         }
     }
-    L(done);
 }
 
 inline void jit_avx2_conv_fwd_kernel_f32::solve_common(
@@ -397,12 +391,6 @@ inline void jit_avx2_conv_fwd_kernel_f32::solve_common(
 
 void jit_avx2_conv_fwd_kernel_f32::generate()
 {
-    if (jcp.with_eltwise) {
-        eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx2>(
-                this, jcp.eltwise_alg, jcp.eltwise_alpha, 0
-        ));
-    }
-
     const auto &p = attr_.post_ops_;
     int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
     for (int i = 0; i < end_idx; i++) {
@@ -474,25 +462,16 @@ bool jit_avx2_conv_fwd_kernel_f32::post_ops_ok(
     auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
 
     switch (p.len_) {
-    case 0: return true; // no post_ops
-    case 1:
-        return true // sum OR eltwise OR dw_conv
-                && !jcp.with_eltwise && (is_simple(0) || is_sum(0) || is_dw_conv(0));
-    case 2:
-        return true // sum->eltwise OR dw_conv->eltwise OR eltwise->dw_conv OR dw_conv->sum OR sum->depthwise OR
-                    // eltwise->depthwise OR depthwise->depthwise
-                && !jcp.with_eltwise && ((is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
-                                         (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
-                                         (is_simple(0) && is_simple(1)));
-    case 3:
-        return true // eltwise->dw_conv->eltwise OR dw_conv->sum->eltwise OR sum->eltwise->depthwise OR
-                    // sum->depthwise->eltwise OR sum->depthwise->depthwise
-                && !jcp.with_eltwise && ((is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
-                                         (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
-                                         (is_sum(0) && is_simple(1) && is_simple(2)));
-    case 4: return true // eltwise->dw_conv->sum->eltwise
-            && !jcp.with_eltwise && (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
-    default: return false;
+        case 0: return true;
+        case 1: return is_simple(0) || is_sum(0) || is_dw_conv(0);
+        case 2: return (is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
+                       (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
+                       (is_simple(0) && is_simple(1));
+        case 3: return (is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
+                       (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
+                       (is_sum(0) && is_simple(1) && is_simple(2));
+        case 4: return (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
+        default: return false;
     }
 
     return false;
@@ -501,7 +480,7 @@ bool jit_avx2_conv_fwd_kernel_f32::post_ops_ok(
 status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
         const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
         const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
-        const primitive_attr_t &attr, bool with_relu, float relu_negative_slope)
+        const primitive_attr_t &attr)
 {
     if (!mayiuse(avx)) return status::unimplemented;
 
@@ -539,63 +518,62 @@ status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
     jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
     jcp.dilate_w = cd.dilates[ndims-3];
 
-    jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
-            - (jcp.ih + jcp.t_pad - 1);
-
     jcp.src_fmt = src_d.format();
     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
-    jcp.with_eltwise = with_relu;
-    jcp.eltwise_alg = mkldnn_eltwise_relu;
-    jcp.eltwise_alpha = relu_negative_slope;
 
     if (!post_ops_ok(jcp, attr))
         return status::unimplemented;
 
     const auto &p = attr.post_ops_;
-    jcp.with_dw_conv = false;
+
     int dw_conv_ind = p.find(primitive_kind::convolution);
-    if (dw_conv_ind != -1) {
-        jcp.with_dw_conv = true;
-        jcp.dw_conv_in_h = p.entry_[dw_conv_ind].dw_conv.in_h;
-        jcp.dw_conv_in_w = p.entry_[dw_conv_ind].dw_conv.in_w;
-        jcp.dw_conv_ker_h = p.entry_[dw_conv_ind].dw_conv.ker_h;
-        jcp.dw_conv_ker_w = p.entry_[dw_conv_ind].dw_conv.ker_w;
-        jcp.dw_conv_str_h = p.entry_[dw_conv_ind].dw_conv.str_h;
-        jcp.dw_conv_str_w = p.entry_[dw_conv_ind].dw_conv.str_w;
-        jcp.dw_conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
-        jcp.dw_conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;
+    jcp.with_dw_conv = dw_conv_ind != -1;
+    if (jcp.with_dw_conv) {
+        jcp.dw_conv_oh = jcp.oh;
+        jcp.dw_conv_ow = jcp.ow;
+        jcp.oh = p.entry_[dw_conv_ind].dw_conv.in_h;
+        jcp.ow = p.entry_[dw_conv_ind].dw_conv.in_w;
     }
 
+    jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
+                - (jcp.ih + jcp.t_pad - 1);
+
     if (jcp.with_dw_conv && !mayiuse(avx2))
         return status::unimplemented;
 
     if (jcp.with_dw_conv && jcp.ndims == 5)
         return status::unimplemented;
 
-    if (jcp.with_dw_conv) {
-        int dw_conv_eltwise_ind = p.find(primitive_kind::eltwise, dw_conv_ind);
-        if (dw_conv_eltwise_ind != -1) {
-            jcp.dw_conv_with_eltwise = true;
-            jcp.dw_conv_eltwise_alg = p.entry_[dw_conv_eltwise_ind].eltwise.alg;
-            jcp.dw_conv_eltwise_alpha = p.entry_[dw_conv_eltwise_ind].eltwise.alpha;
-            jcp.dw_conv_eltwise_beta = p.entry_[dw_conv_eltwise_ind].eltwise.beta;
+    if (!mayiuse(avx2)) {
+        for (int i = 0; i < p.len_; i++) {
+            auto &post_op = p.entry_[i];
+            if (post_op.is_eltwise()) {
+                if (post_op.eltwise.alg != alg_kind::eltwise_relu)
+                    return status::unimplemented;
+            } else if (post_op.is_depthwise()) {
+                return status::unimplemented;
+            }
         }
     }
 
     jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
-    if (jcp.with_dw_conv) {
-        jcp.dw_conv_with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;
-    }
 
-    if (jcp.with_dw_conv) {
-        jcp.oh = jcp.dw_conv_in_h;
-        jcp.ow = jcp.dw_conv_in_w;
-    }
+    jcp.src_dt = cd.src_desc.data_type;
+    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
+    jcp.dst_dt = cd.dst_desc.data_type;
 
     const int simd_w = 8;
     const bool flat = jcp.ic < simd_w;
     const bool mimo = !flat;
 
+
+    /* Grouped channel offset to support 'non-blocked data' format for
+     * convolution sizes with '(input_channel / ngroups) < simd' */
+    jcp.nonblk_group_off
+            = (one_of(src_d.format(), ncw, nchw, ncdhw) && jcp.ngroups > 1) ?
+            jcp.ic :
+            1;
+
     bool ok_to_pad_channels = true
         && jcp.ngroups == 1;
 
@@ -686,8 +664,23 @@ status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
     return status::success;
 }
 
-void jit_avx2_conv_bwd_data_kernel_f32::hsw_iter(int ur_w, int l_overflow,
-        int r_overflow, int start_off)
+void jit_avx2_conv_fwd_kernel_f32::init_scratchpad(
+        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw) {
+    if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
+        scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
+
+    if (jcp.with_dw_conv) {
+        const int nthreads = mkldnn_get_max_threads();
+        size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * jcp.nb_oc_blocking;
+        scratchpad.book(key_dw_conv_buffer, sizeof(float) * dw_conv_buffer_size_ * nthreads);
+
+        if (jcp.oc != jcp.oc_without_padding)
+            scratchpad.book(key_dw_conv_padded_bias, sizeof(float) * jcp.oc);
+    }
+}
+
+void jit_avx2_conv_bwd_data_kernel_f32::compute_loop(int ur_w, int l_overflow,
+        int r_overflow)
 {
     int kw = jcp.kw;
     int kh = jcp.kh;
@@ -696,29 +689,37 @@ void jit_avx2_conv_bwd_data_kernel_f32::hsw_iter(int ur_w, int l_overflow,
     int ih = jcp.ih;
     int id = jcp.id;
     int ow = jcp.ow;
-    int stride_w = jcp.stride_w;
-    int stride_h = jcp.stride_h;
 
     int ic_block = jcp.ic_block;
     int oc_block = jcp.oc_block;
     int nb_ic_block = jcp.nb_ic_blocking;
+    int stride_w = jcp.stride_w;
+    int stride_h = jcp.stride_h;
 
     Label kd_loop, skip_kd_loop;
+    Label oc_loop, skip_oc_loop;
 
     for (int ii = 0; ii < nb_ic_block; ii++)
         for (int jj = 0; jj < ur_w; jj++) {
-            size_t offt = sizeof(float) * ((size_t)ii * id * ih * iw + jj)
-                * ic_block;
-            vmovups(Ymm(ur_w * ii + jj),
-                    make_safe_addr(reg_dsrc, offt, reg_long_offt));
+            uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj),
+                      Ymm(ur_w * ii + jj));
         }
 
     if (one_of(jcp.ndims, 3, 4)) {
-        mov(aux_reg_ddst, reg_ddst);
-        mov(aux_reg_kernel, reg_kernel);
+        cmp(reg_channel_work, 0);
+        jle(skip_oc_loop, T_NEAR);
+        xor_(reg_channel, reg_channel);
+
+        mov(aux_reg_ddst_oc_loop, reg_ddst);
+        mov(aux_reg_kernel_oc_loop, reg_kernel);
+
+        L(oc_loop);
+        mov(aux_reg_ddst, aux_reg_ddst_oc_loop);
+        mov(aux_reg_kernel, aux_reg_kernel_oc_loop);
     }
 
     if (jcp.ndims == 5) {
+        assert(jcp.nb_oc_blocking == 1);
         push(oi_iter);
 
         mov(reg_ki, ptr[this->param1 + GET_OFF(kd_padding)]);
@@ -736,42 +737,46 @@ void jit_avx2_conv_bwd_data_kernel_f32::hsw_iter(int ur_w, int l_overflow,
         mov(aux_reg_kernel, aux_reg_ker_d);
     }
 
-    mov(kj, reg_kh);
-
-    Label kh_label;
-
-    L(kh_label); {
+    Label kh_loop, skip_kh_loop;
+    cmp(kj, 0);
+    jle(skip_kh_loop, T_NEAR);
+    L(kh_loop); {
         for (int ki = 0; ki < kw; ki++) {
-            int jj_start = nstl::max(0, l_overflow - (kw - 1) + ki) ; // 0;
-            int jj_end = ur_w - nstl::max(0, r_overflow - ki); // ur_w;
+            int jj_start = get_iw_start(ki, l_overflow); // 0;
+            int jj_end = get_iw_end(ur_w, ki, r_overflow); // ur_w;
             for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++) {
 
-                for (int jj = jj_start; jj < jj_end; jj++) {
-                   if ((jj - ki + jcp.l_pad + start_off) % stride_w == 0) {
-                        int aux_output_offset = ((jj - ki + jcp.l_pad + start_off) / stride_w) * jcp.oc_block + ofm2;
-                        vbroadcastss(Ymm(nb_ic_block * ur_w + jj), ptr[aux_reg_ddst + sizeof(float) * aux_output_offset]);
-                   }
+                for (int jj = jj_start ; jj < jj_end; jj += stride_w) {
+                    int aux_output_offset
+                      = (jj + jcp.l_pad - ki) / stride_w * jcp.oc_block + ofm2;
+                    vbroadcastss(Ymm(nb_ic_block * ur_w + jj / stride_w),
+                            ptr[aux_reg_ddst
+                            + sizeof(float) * aux_output_offset]);
                 }
 
-                for (int ii = 0; ii < nb_ic_block; ii++) {
-                    int aux_kernel_offset = ii * kd * kh * kw * jcp.ic_block * jcp.oc_block + ki * jcp.ic_block * jcp.oc_block + ofm2 * jcp.ic_block;
-                    vmovups(ymm15, ptr[aux_reg_kernel + sizeof(float) * aux_kernel_offset]);
-
-                    for (int jj = jj_start; jj < jj_end; jj++) {
-                       if ((jj - ki + jcp.l_pad + start_off) % stride_w == 0) {
-                            vfmadd231ps(Ymm(ur_w * ii + jj), Ymm(nb_ic_block * ur_w + jj), ymm15);
-                       }
-                    }
+                for (int ii = 0; ii  < nb_ic_block; ii++) {
+                    int aux_kernel_offset
+                        = ii * kd * kh * kw * jcp.ic_block * jcp.oc_block
+                        + ki * jcp.ic_block * jcp.oc_block
+                        + ofm2 * jcp.ic_block;
+                    vmovups(ymm15,
+                            ptr[aux_reg_kernel
+                            + sizeof(float) * aux_kernel_offset]);
+                    for (int jj = jj_start; jj  < jj_end; jj += stride_w)
+                        vfmadd231ps(Ymm(ur_w * ii + jj),
+                                Ymm(nb_ic_block * ur_w + jj / stride_w), ymm15);
                 }
             }
         }
-        add(aux_reg_kernel, sizeof(float) * kw  * oc_block * ic_block * stride_h);
+        add(aux_reg_kernel, sizeof(float) * stride_h * kw  * oc_block
+                                          * ic_block);
         sub(aux_reg_ddst, sizeof(float) * ow * oc_block);
 
-        sub(kj, stride_h);
+        dec(kj);
         cmp(kj, 0);
-        jg(kh_label, T_NEAR);
+        jg(kh_loop, T_NEAR);
     }
+    L(skip_kh_loop);
 
     if (jcp.ndims == 5) {
         sub(aux_reg_dst_d,
@@ -787,6 +792,39 @@ void jit_avx2_conv_bwd_data_kernel_f32::hsw_iter(int ur_w, int l_overflow,
         pop(oi_iter);
     }
 
+    if (one_of(jcp.ndims, 3, 4)) {
+        int ddst_oc_shift = sizeof(float) * jcp.od * jcp.oh * jcp.ow
+                          * jcp.oc_block;
+        int kernel_oc_shift = sizeof(float) * jcp.kd * jcp.kh * jcp.kw
+                          * jcp.ic * jcp.oc_block;
+
+        add(aux_reg_ddst_oc_loop, ddst_oc_shift);
+        add(aux_reg_kernel_oc_loop, kernel_oc_shift);
+
+        inc(reg_channel);
+        cmp(reg_channel, reg_channel_work);
+        jl(oc_loop, T_NEAR);
+
+        L(skip_oc_loop);
+        mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
+    }
+
+    Label no_update_label;
+    cmp(reg_channel, 0);
+    je(no_update_label, T_NEAR);
+    for (int ii = 0; ii < nb_ic_block; ii++) {
+        for (int jj = 0; jj < ur_w; jj++) {
+            size_t offt =
+                sizeof(float) * ((size_t)ii * id * ih * iw + jj) * ic_block;
+            vmovups(Ymm(15),
+                    make_safe_addr(reg_dsrc, offt, reg_long_offt));
+            vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj),
+                    Ymm(15));
+
+        }
+    }
+    L(no_update_label);
+
     for (int ii = 0; ii < nb_ic_block; ii++)
         for (int jj = 0; jj < ur_w; jj++) {
             size_t offt =
@@ -799,79 +837,63 @@ void jit_avx2_conv_bwd_data_kernel_f32::hsw_iter(int ur_w, int l_overflow,
 void jit_avx2_conv_bwd_data_kernel_f32::generate() {
     preamble();
 
-    auto hsw_iter_body = [=] (int ur_w, int l_overflow, int r_overflow) {
-        if (jcp.stride_w == 1) {
-            hsw_iter(ur_w, l_overflow, r_overflow, 0);
-            add(reg_dsrc, sizeof(float) * jcp.ur_w * jcp.ic_block);
-            add(reg_ddst, sizeof(float) * jcp.ur_w * jcp.oc_block);
-        } else {
-            Label hsw_iter_off_0;
-            Label hsw_iter_off_1;
-            Label hsw_iter_exit;
-
-            int dst_off = jcp.ur_w / jcp.stride_w;
-
-            and_(start_off_reg, 1);
-
-            L(hsw_iter_off_0); {
-                cmp(start_off_reg, 0);
-                jg(hsw_iter_off_1, T_NEAR);
-
-                hsw_iter(ur_w, l_overflow, r_overflow, 0);
-                add(reg_dsrc, sizeof(float) * jcp.ur_w * jcp.ic_block);
-                add(reg_ddst, sizeof(float) * dst_off * jcp.oc_block);
-
-                jmp(hsw_iter_exit, T_NEAR);
-            }
-
-            L(hsw_iter_off_1); {
-                hsw_iter(ur_w, l_overflow, r_overflow, 1);
-                add(reg_dsrc, sizeof(float) * jcp.ur_w * jcp.ic_block);
-                add(reg_ddst, sizeof(float) * (dst_off + 1) * jcp.oc_block);
-            }
-
-            L(hsw_iter_exit);
-            add(start_off_reg, std::abs(jcp.ur_w - jcp.stride_w));
-        }
-    };
-
     mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]);
     mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]);
     mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
     mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
+    mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
+    mov(reg_channel_work, ptr[param1 + GET_OFF(ch_blocks)]);
 
-    int n_oi = jcp.iw / jcp.ur_w;
-    xor_(oi_iter, oi_iter);
-    xor_(start_off_reg, start_off_reg);
+    int ddst_shift = sizeof(float) * (jcp.ur_w / jcp.stride_w) * jcp.ic_block;
+    int dsrc_shift = sizeof(float) * jcp.ur_w * jcp.oc_block;
 
-    int l_overflow = nstl::max(0, jcp.kw - 1 - jcp.l_pad);
-    if (l_overflow > 0) {
-        hsw_iter_body(jcp.ur_w, l_overflow, 0);
-        inc(oi_iter);
-    }
+    int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w);
+    int r_overflow = nstl::max(0, (jcp.kw - 1
+                    - nstl::max(0, jcp.r_pad)) / jcp.stride_w);
+    int r_overflow1 = nstl::max(0, (jcp.kw - 1
+                    - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
 
-    int r_pad = jcp.iwp - jcp.iw - jcp.l_pad;
-    int r_overflow1
-        = nstl::max(0, jcp.kw - 1 - (jcp.iw - jcp.ur_w * n_oi) - r_pad);
-    int r_overflow = nstl::max(0, jcp.kw - 1 - r_pad);
+    int n_oi = jcp.iw / jcp.ur_w;
     if (r_overflow1 > 0)
         n_oi--;
 
-    if ((l_overflow <= 0 && n_oi > 0) || (l_overflow >  0 && n_oi > 1)) {
-        Label ow_loop;
-        L(ow_loop); {
-            hsw_iter_body(jcp.ur_w, 0, 0);
+    if (jcp.ur_w == jcp.iw) {
+        compute_loop(jcp.ur_w, l_overflow, r_overflow);
+    } else if (n_oi == 0) {
+        compute_loop(jcp.ur_w, l_overflow, r_overflow1);
+        add(reg_dsrc, dsrc_shift);
+        add(reg_ddst, ddst_shift);
+        if (jcp.ur_w_tail != 0)
+            compute_loop(jcp.ur_w_tail, 0, r_overflow);
+    } else {
+        xor_(oi_iter, oi_iter);
+        if (l_overflow > 0) {
+            compute_loop(jcp.ur_w, l_overflow, 0);
+            add(reg_dsrc, dsrc_shift);
+            add(reg_ddst, ddst_shift);
             inc(oi_iter);
-            cmp(oi_iter, n_oi);
-            jl(ow_loop, T_NEAR);
         }
-    }
 
-    if (r_overflow1 > 0 )
-        hsw_iter_body(jcp.ur_w, 0, r_overflow1);
+        if ((l_overflow <= 0 && n_oi > 0) || (l_overflow >  0 && n_oi > 1)) {
+            Label ow_loop;
+            L(ow_loop); {
+                compute_loop(jcp.ur_w, 0, 0);
+                add(reg_dsrc, dsrc_shift);
+                add(reg_ddst, ddst_shift);
+                inc(oi_iter);
+                cmp(oi_iter, n_oi); jl(ow_loop, T_NEAR);
+            }
+        }
 
-    if (jcp.ur_w_tail != 0)
-        hsw_iter_body(jcp.ur_w_tail, 0, r_overflow);
+        if (r_overflow1 > 0 ) {
+            compute_loop(jcp.ur_w, 0, r_overflow1);
+            add(reg_dsrc, dsrc_shift);
+            add(reg_ddst, ddst_shift);
+        }
+
+        if (jcp.ur_w_tail != 0)
+            compute_loop(jcp.ur_w_tail, 0, r_overflow);
+    }
 
     this->postamble();
 }
@@ -930,6 +952,10 @@ status_t jit_avx2_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp,
     bool ok_to_pad_channels = true
         && jcp.ngroups == 1;
 
+    /* gemm-based convolution performs better in these cases */
+    if (jcp.ic < simd_w && jcp.kw > 3 && jcp.stride_w > 1)
+        return status::unimplemented;
+
     if (ok_to_pad_channels) {
         jcp.oc = rnd_up(jcp.oc, simd_w);
         jcp.ic = rnd_up(jcp.ic, simd_w);
@@ -945,16 +971,19 @@ status_t jit_avx2_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp,
     jcp.ur_h = 1; /* no code-unrolling by h so far */
     jcp.nb_ic_blocking = 1;
     jcp.nb_oc_blocking = 1;
+    jcp.ur_w = 1;
+
+    if(one_of(ndims, 3, 4) && jcp.ow < 40)
+        jcp.nb_oc_blocking = jcp.ow < 15 ? 4 : 2;
 
     jcp.src_fmt = diff_src_d.format();
-    jcp.with_eltwise = false;
 
     bool args_ok = true
         && one_of(diff_src_d.format(), nCw8c, nChw8c, nCdhw8c)
         && one_of(weights_d.format(), gOIw8o8i, OIw8i8o, gOIhw8o8i, OIhw8o8i,
                 gOIdhw8o8i, OIdhw8o8i)
         && one_of(diff_dst_d.format(), nCw8c, nChw8c, nCdhw8c)
-        && (jcp.stride_w == 1 || jcp.stride_w == 2)
+        && jcp.stride_w == jcp.stride_h
         && jcp.stride_d == 1
         && jcp.dilate_d == 0
         && jcp.dilate_h == 0
@@ -965,34 +994,69 @@ status_t jit_avx2_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp,
         && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
         && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1;
     if (!args_ok) return status::unimplemented;
+    jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad;
+    jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad;
+    int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w);
+
+    const int max_regs = 15; /* Maximun number of registers available for
+                                result accumulation and delta dst data.
+                                One additional register is reserved for weights
+                                data. */
+
+    /* Find the best blocking with maximum number of fma instructions
+       per ur_w * nb_ic_blocking compute loops. Number of required registers
+       is num_regs = ur_w * nb_ic_blocking + ur_w / stride_w <= max_regs.
+       ur_w must be divisible by stride_w */
+    if (jcp.stride_w + 1 > max_regs)  /* Minimal possible registers
+                                         distribution exceeds max_regs */
+        return status::unimplemented;
 
-    jcp.ur_w = 3;
-
-    for (int b = 4; b > 1; b--)
+    int best_nfmas = 0;
+    for (int b = 1; b <= 4; b++)
     {
-        if (jcp.nb_ic % b == 0)
+        if (jcp.nb_ic % b != 0)
+            continue;
+
+        for (int u = jcp.stride_w;
+             u * b + u / jcp.stride_w <= max_regs && u < jcp.iw + jcp.stride_w;
+             u += jcp.stride_w)
         {
-            jcp.nb_ic_blocking = b;
-            break;
+            int ur_w = nstl::min(u, jcp.iw);
+            /* maximum 1 step with l_overflow so far */
+            if (l_overflow * jcp.stride_w > ur_w && ur_w != jcp.iw)
+                continue;
+            int nfmas = utils::div_up(ur_w, jcp.stride_w) * b;
+            if (nfmas > best_nfmas
+               || (nfmas == best_nfmas && jcp.ur_w < ur_w)) {
+                jcp.ur_w = ur_w;
+                jcp.nb_ic_blocking = b;
+                best_nfmas = nfmas;
+            }
         }
     }
+    if (best_nfmas == 0) /* can't find appropriate blocking */
+        return status::unimplemented;
 
     jcp.ur_w_tail = jcp.iw % jcp.ur_w;
-    int l_overflow = nstl::max(0, jcp.kw - 1 - jcp.l_pad);
-    if (l_overflow > jcp.ur_w) /* maximum 1 step with l_overflow so far */
-        return status::unimplemented;
-    int r_pad = jcp.iwp - jcp.iw - jcp.l_pad;
-    int r_overflow_step0 = nstl::max(0, jcp.kw - 1 - (jcp.iw - jcp.ur_w) - r_pad);
-    if (l_overflow > 0 && r_overflow_step0 > 0) /* no steps with both left and
-                                                   right overflow so far */
+
+    int r_overflow_no_tail = nstl::max(0, (jcp.kw - 1 - jcp.ur_w_tail
+                    - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
+    /* maximum 1 ur_w block with r_overflow so far */
+    if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w)
         return status::unimplemented;
-    int r_overflow_no_tail = nstl::max(0,jcp.kw - 1 - jcp.ur_w_tail - r_pad);
-    if (r_overflow_no_tail > jcp.ur_w) /* maximum 1 ur_w block with
-                                          r_overflow so far */
+
+    if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
         return status::unimplemented;
+
     return status::success;
 }
 
+void jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad(
+        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+    UNUSED(scratchpad);
+    UNUSED(jcp);
+}
+
 void jit_avx2_conv_bwd_weights_kernel_f32::generate() {
     this->preamble();
 
@@ -1045,8 +1109,6 @@ status_t jit_avx2_conv_bwd_weights_kernel_f32::init_conf(jit_conv_conf_t &jcp,
 
     jcp.src_fmt = src_d.format();
     jcp.with_bias = cd.diff_bias_desc.format != memory_format::undef;
-    jcp.with_eltwise = false;
-    jcp.eltwise_alpha = 0;
 
     const bool flat = jcp.ic == 3;
     const bool mimo = !flat;
@@ -1097,9 +1159,16 @@ status_t jit_avx2_conv_bwd_weights_kernel_f32::init_conf(jit_conv_conf_t &jcp,
     jcp.oc_block = simd_w;
     jcp.nb_oc = jcp.oc / jcp.oc_block;
     jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
+
     return status::success;
 }
 
+void jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad(
+        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+    if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
+        scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
+}
+
 inline void jit_avx2_conv_bwd_weights_kernel_f32::od_step_comeback_pointers()
 {
     Label kd_comeback_loop;