Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_x8s8s32x_conv_kernel.cpp
index b94295b..09c60dc 100644 (file)
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2018 Intel Corporation
+* Copyright 2018-2019 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
 * limitations under the License.
 *******************************************************************************/
 
+#include <common/memory_tracking.hpp>
 #include "c_types_map.hpp"
 #include "nstl.hpp"
 #include "type_helpers.hpp"
@@ -30,37 +31,12 @@ namespace cpu {
 
 using namespace mkldnn::impl::prop_kind;
 using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
 using namespace mkldnn::impl::utils;
 
 using namespace Xbyak;
 
 template <cpu_isa_t isa>
-bool jit_uni_x8s8s32x_conv_fwd_kernel<isa>::maybe_relu(int position) {
-    using namespace primitive_kind;
-    const auto &p = attr_.post_ops_;
-
-    if (position == 0) {
-        /* relu before sum */
-        return false
-               || jcp.with_eltwise
-               || p.contain(eltwise, 0)
-               || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0));
-    } else if (position == 1) {
-        /* relu after sum */
-        const int sum_idx = p.contain(sum, 0)
-                            ? 0 : (p.contain(sum, 1) ? 1 : -1);
-        if (sum_idx == -1)
-            return false;
-
-        return false
-               || p.contain(eltwise, sum_idx + 1)
-               || jcp.dst_dt == data_type::u8;
-    }
-
-    return false;
-}
-
-template <cpu_isa_t isa>
 void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::cvt2ps(data_type_t type_in, Vmm vmm_in,
         const Xbyak::Operand &op, bool scalar_load) {
     Xmm xmm_in = Xmm(vmm_in.getIdx());
@@ -118,7 +94,7 @@ void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::store_dst(const Xbyak::Address &op,
             if (isa != sse42 && !scalar_store)
                 vpermq(ymm_dst, ymm_dst, 0x08);
 
-            uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst);
+            uni_vpacksswb(vmm_dst, vmm_dst, vmm_dst);
 
             if (scalar_store) {
                 movq(reg_tmp_64, xmm_dst);
@@ -136,7 +112,7 @@ void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::store_dst(const Xbyak::Address &op,
             if (isa != sse42 && !scalar_store)
                 vpermq(ymm_dst, ymm_dst, 0x08);
 
-            uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst);
+            uni_vpackuswb(vmm_dst, vmm_dst, vmm_dst);
 
             if (scalar_store) {
                 movq(reg_tmp_64, xmm_dst);
@@ -177,32 +153,27 @@ void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::apply_filter(int ur_w, int pad_l, in
         for (int r = 0; r < repeats; r++) {
             for (int jj = _start; jj < _end; jj++) {
                 int inp_off = (ki * dilate_w + jj * stride_w - pad_l) * jcp.ic * jcp.ngroups;
-                    if (tail_size > 0) {
-                        if (h_padded || jj < jj_start || jj >= jj_end) {
-                            uni_vpxor(get_src_reg(jj), get_src_reg(jj), get_src_reg(jj));
-                            uni_vpsubb(get_src_reg(jj), get_src_reg(jj), vmm_shift);
-                            uni_vandps(get_src_reg(jj), get_src_reg(jj), vmm_mask);
-                            uni_vpbroadcastd(get_src_reg(jj), Xmm(get_src_reg(jj).getIdx()));
-                        } else {
-                            uni_vpbroadcastd(get_src_reg(jj), ptr[aux1_reg_input + jcp.typesize_in * inp_off]);
-
-                            if (jcp.signed_input) {
-                                uni_vpsubb(get_src_reg(jj), get_src_reg(jj), vmm_shift);
-                            }
-
-                            uni_vandps(get_src_reg(jj), get_src_reg(jj), vmm_mask);
-                            uni_vpbroadcastd(get_src_reg(jj), Xmm(get_src_reg(jj).getIdx()));
-                        }
+                if (tail_size > 0) {
+                    if (h_padded || jj < jj_start || jj >= jj_end) {
+                        uni_vpxor(get_src_reg(jj), get_src_reg(jj), get_src_reg(jj));
+                        uni_vpsubb(get_src_reg(jj), get_src_reg(jj), vmm_shift);
                     } else {
-                        if (h_padded || jj < jj_start || jj >= jj_end) {
-                            uni_vpxor(get_src_reg(jj), get_src_reg(jj), get_src_reg(jj));
-                        } else {
-                            uni_vpbroadcastd(get_src_reg(jj), ptr[aux1_reg_input + jcp.typesize_in * inp_off]);
-                        }
+                        uni_vpbroadcastd(get_src_reg(jj), ptr[aux1_reg_input + jcp.typesize_in * inp_off]);
 
-                        if (jcp.signed_input)
+                        if (jcp.signed_input) {
                             uni_vpsubb(get_src_reg(jj), get_src_reg(jj), vmm_shift);
+                        }
+                    }
+                } else {
+                    if (h_padded || jj < jj_start || jj >= jj_end) {
+                        uni_vpxor(get_src_reg(jj), get_src_reg(jj), get_src_reg(jj));
+                    } else {
+                        uni_vpbroadcastd(get_src_reg(jj), ptr[aux1_reg_input + jcp.typesize_in * inp_off]);
                     }
+
+                    if (jcp.signed_input)
+                        uni_vpsubb(get_src_reg(jj), get_src_reg(jj), vmm_shift);
+                }
             }
 
             for (int ii = 0; ii < oc_blocks; ii++) {
@@ -279,7 +250,6 @@ void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::kh_loop(int ur_w, int pad_l, int pad
     mov(imm_addr64, l_table);
     uni_vmovups(vmm_one,   ptr[imm_addr64 + 0 * vlen]);
     uni_vmovups(vmm_shift, ptr[imm_addr64 + 1 * vlen]);
-    uni_vmovups(vmm_mask, ptr[imm_addr64 + 4 * vlen]);
 
     if (jcp.signed_input) {
         mov(reg_overflow,  ptr[param1 + GET_OFF(t_overflow)]);
@@ -349,6 +319,7 @@ void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::width_blk_step(int ur_w, int pad_l,
 
     kh_loop(ur_w, pad_l, pad_r, oc_blocks, oc_step);
 
+    pop(reg_oc_off);
     pop(reg_scales_base);
 
     mov(imm_addr64, l_table);
@@ -359,140 +330,143 @@ void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::width_blk_step(int ur_w, int pad_l,
     const float p_sum_scale = (sum_idx != -1) ? p.entry_[sum_idx].sum.scale : 1.f;
 
     for (int r = 0; r < repeats; r++) {
+        auto get_dst_off = [=](int ii, int jj) {
+            if (jcp.with_dw_conv)
+                return (ii * jcp_dw.kh * jcp.ow + jj) * jcp.oc_block + r * (jcp.oc_block / 2);
+            else
+                return ii * jcp.oc_block + jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2);
+        };
+
         int tail_size = isa == avx2 ? oc_step : nstl::min(jcp.oc_block / 2, oc_step - r * jcp.oc_block / 2);
         bool is_scalar_store = isa == avx2 ? tail_size < jcp.oc_block : tail_size < jcp.oc_block / 2;
 
-        if (is_scalar_store) {
+        for (int ii = 0; ii < oc_blocks; ii++) {
+            if (jcp.with_bias) {
+                int b_off = ii * jcp.oc_block + r * (jcp.oc_block / 2);
+                cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias_base + b_off * jcp.typesize_bia], false);
+
+                if (jcp.signed_input)
+                    uni_vmulps(vmm_bias, vmm_bias, vmm_bias_alpha);
+            }
+
             for (int jj = 0; jj < ur_w; jj++) {
-                Vmm vmm_dst = get_acc_reg(r * jcp.ur_w * jcp.nb_oc_blocking + jj);
+                Vmm vmm_dst = get_acc_reg(r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);
                 uni_vcvtdq2ps(vmm_dst, vmm_dst);
-                uni_vmovups(vmm_reminder_dst, vmm_dst);
 
-                for (int oc = 0; oc < tail_size; oc++) {
-                    uni_vmovups(vmm_dst, vmm_reminder_dst);
+                if (jcp.signed_input) {
+                    int c_off = ii * jcp.oc_block + r * (jcp.oc_block / 2);
+                    cvt2ps(data_type::s32, vmm_comp, ptr[reg_compensation_base + c_off * sizeof(int32_t)], false);
+                }
 
-                    if (jcp.with_bias) {
-                        int b_off = r * (jcp.oc_block / 2) + oc;
-                        cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias_base + b_off * jcp.typesize_bia], true);
+                if (jcp.signed_input)
+                    uni_vaddps(vmm_dst, vmm_dst, vmm_comp);
+                if (jcp.with_bias)
+                    uni_vaddps(vmm_dst, vmm_dst, vmm_bias);
 
-                        if (jcp.signed_input)
-                            uni_vmulps(vmm_bias, vmm_bias, vmm_bias_alpha);
-                    }
-                    if (jcp.signed_input) {
-                        int c_off = r * (jcp.oc_block / 2) + oc;
-                        cvt2ps(data_type::s32, vmm_comp, ptr[reg_compensation_base + c_off * sizeof(int32_t)], true);
-                    }
+                int s_off = jcp.is_oc_scale * (ii * jcp.oc_block + r * (jcp.oc_block / 2));
+                cvt2ps(mkldnn_f32, vmm_scale, ptr[reg_scales_base + s_off * sizeof(float)], false);
+                uni_vmulps(vmm_dst, vmm_dst, vmm_scale);
+            }
+        }
 
-                    if (jcp.signed_input)
-                        uni_vaddps(vmm_dst, vmm_dst, vmm_comp);
-                    if (jcp.with_bias)
-                        uni_vaddps(vmm_dst, vmm_dst, vmm_bias);
+        int eltwise_inj_idx = 0;
+        int depthwise_inj_idx = 0;
+        int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
+        for (int i = 0; i < end_idx; i++) {
+            int start_idx = 1 + r * jcp.ur_w * jcp.nb_oc_blocking;
+
+            auto& post_op = p.entry_[i];
+            if (post_op.is_eltwise()) {
+                eltwise_injectors[eltwise_inj_idx]->compute_vector_range(start_idx, start_idx + oc_blocks * ur_w);
+                eltwise_inj_idx++;
+            } else if (post_op.is_depthwise()) {
+                mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
+                mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
+
+                add(reg_d_weights, reg_oc_off);
+                add(reg_d_bias, reg_oc_off);
+
+                if (r == 1) {
+                    add(reg_d_weights, (jcp.oc_block / 2) * sizeof(float));
+                    add(reg_d_bias, (jcp.oc_block / 2) * sizeof(float));
+                }
 
-                    int s_off = jcp.is_oc_scale * (r * (jcp.oc_block / 2) + oc);
-                    cvt2ps(mkldnn_f32, vmm_scale, ptr[reg_scales_base + s_off * sizeof(float)], true);
-                    uni_vmulps(vmm_dst, vmm_dst, vmm_scale);
+                for (int ii = 0; ii < oc_blocks; ii++) {
+                    depthwise_injectors[depthwise_inj_idx]->compute_vector_range(start_idx + ur_w * ii,
+                            start_idx + ur_w * ii + ur_w, reg_d_weights, reg_d_bias);
 
-                    int o_off = jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2) + oc;
-                    if (jcp.with_sum) {
-                        uni_vpxor(vmm_prev_dst, vmm_prev_dst, vmm_prev_dst);
-                        cvt2ps(jcp.dst_dt, vmm_prev_dst, ptr[reg_output + o_off * jcp.typesize_out], true);
+                    add(reg_d_weights, jcp.oc_block * sizeof(float));
+                    add(reg_d_bias, jcp.oc_block * sizeof(float));
+                }
 
-                        if (p_sum_scale == 1.f) {
-                            uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
+                depthwise_inj_idx++;
+            } else if (post_op.is_sum(false)) {
+                for (int ii = 0; ii < oc_blocks; ii++) {
+                    for (int jj = 0; jj < ur_w; jj++) {
+                        Vmm vmm_dst = get_acc_reg(r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);
+                        int o_off = get_dst_off(ii, jj);
+
+                        if (is_scalar_store) {
+                            for (int oc = 0; oc < tail_size; oc++) {
+                                uni_vpxor(vmm_prev_dst, vmm_prev_dst, vmm_prev_dst);
+                                cvt2ps(jcp.dst_dt, vmm_prev_dst, ptr[reg_output + (o_off + oc) * jcp.typesize_out], true);
+
+                                if (oc < jcp.oc_block / 2) {
+                                    uni_vpslldq(vmm_prev_dst, vmm_prev_dst, oc * sizeof(float));
+                                } else {
+                                    Ymm ymm_prev_dst = Ymm(vmm_prev_dst.getIdx());
+                                    vperm2i128(ymm_prev_dst, ymm_prev_dst, ymm_prev_dst, 0x01);
+                                    vpslldq(vmm_prev_dst, vmm_prev_dst, (oc - jcp.oc_block / 2) * sizeof(float));
+                                }
+
+                                if (p_sum_scale == 1.f) {
+                                    uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
+                                } else {
+                                    uni_vfmadd231ps(vmm_dst, vmm_prev_dst, ptr[imm_addr64 + 3 * vlen]);
+                                }
+                            }
                         } else {
-                            uni_vfmadd231ps(vmm_dst, vmm_prev_dst, ptr[imm_addr64 + 3 * vlen]);
-                        }
-                    }
-
-                    if (maybe_relu(0)) {
-                        uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
-                        uni_vmaxps(vmm_dst, vmm_dst, vmm_zero);
-                    }
-
-                    if (maybe_relu(1)) {
-                        uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
-                        uni_vmaxps(vmm_dst, vmm_dst, vmm_zero);
-                    }
-
-                    if (jcp.dst_dt != data_type::f32) {
-                        if (attr_.round_mode_ == round_mode::nearest)
-                            uni_vcvtps2dq(vmm_dst, vmm_dst);
-                        else if (attr_.round_mode_ == round_mode::down) {
-                            uni_vroundps(vmm_dst, vmm_dst, 1);
-                            uni_vcvtps2dq(vmm_dst, vmm_dst);
-                        } else
-                            assert(!"unimplemented");
-                    }
-
-                    store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, true);
+                            cvt2ps(jcp.dst_dt, vmm_prev_dst, ptr[reg_output + o_off * jcp.typesize_out], false);
 
-                    if (isa == avx2) {
-                        vperm2i128(ymm_tmp, ymm_reminder_dst, ymm_reminder_dst, 0x01);
-                        vpalignr(ymm_reminder_dst, ymm_tmp, ymm_reminder_dst, jcp.typesize_out);
-                    } else {
-                        psrldq(vmm_reminder_dst, jcp.typesize_out);
+                            if (p_sum_scale == 1.f) {
+                                uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
+                            } else {
+                                uni_vfmadd231ps(vmm_dst, vmm_prev_dst, ptr[imm_addr64 + 3 * vlen]);
+                            }
+                        }
                     }
                 }
             }
-        } else {
-            for (int ii = 0; ii < oc_blocks; ii++) {
-                if (jcp.with_bias) {
-                    int b_off = ii * jcp.oc_block + r * (jcp.oc_block / 2);
-                    cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias_base + b_off * jcp.typesize_bia], false);
+        }
 
-                    if (jcp.signed_input)
-                        uni_vmulps(vmm_bias, vmm_bias, vmm_bias_alpha);
+        for (int ii = 0; ii < oc_blocks; ii++) {
+            for (int jj = 0; jj < ur_w; jj++) {
+                Vmm vmm_dst = get_acc_reg(r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);
+                int o_off = get_dst_off(ii, jj);
+
+                if (jcp.dst_dt != data_type::f32) {
+                    if (attr_.round_mode_ == round_mode::nearest)
+                        uni_vcvtps2dq(vmm_dst, vmm_dst);
+                    else if (attr_.round_mode_ == round_mode::down) {
+                        uni_vroundps(vmm_dst, vmm_dst, 1);
+                        uni_vcvtps2dq(vmm_dst, vmm_dst);
+                    } else
+                        assert(!"unimplemented");
                 }
 
-                for (int jj = 0; jj < ur_w; jj++) {
-                    Vmm vmm_dst = get_acc_reg(r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);
-                    uni_vcvtdq2ps(vmm_dst, vmm_dst);
-
-                    if (jcp.signed_input) {
-                        int c_off = ii * jcp.oc_block + r * (jcp.oc_block / 2);
-                        cvt2ps(data_type::s32, vmm_comp, ptr[reg_compensation_base + c_off * sizeof(int32_t)], false);
-                    }
-
-                    if (jcp.signed_input)
-                        uni_vaddps(vmm_dst, vmm_dst, vmm_comp);
-                    if (jcp.with_bias)
-                        uni_vaddps(vmm_dst, vmm_dst, vmm_bias);
-
-                    int s_off = jcp.is_oc_scale * (ii * jcp.oc_block + r * (jcp.oc_block / 2));
-                    cvt2ps(mkldnn_f32, vmm_scale, ptr[reg_scales_base + s_off * sizeof(float)], false);
-                    uni_vmulps(vmm_dst, vmm_dst, vmm_scale);
-
-                    int o_off = ii * jcp.oc_block + jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2);
-                    if (jcp.with_sum) {
-                        cvt2ps(jcp.dst_dt, vmm_prev_dst, ptr[reg_output + o_off * jcp.typesize_out], false);
+                if (is_scalar_store) {
+                    for (int oc = 0; oc < tail_size; oc++) {
+                        store_dst(ptr[reg_output + (o_off + oc) * jcp.typesize_out], vmm_dst, true);
 
-                        if (p_sum_scale == 1.f) {
-                            uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
+                        if (isa == avx2) {
+                            Ymm ymm_dst = Ymm(vmm_dst.getIdx());
+                            vperm2i128(ymm_tmp, ymm_dst, ymm_dst, 0x01);
+                            vpalignr(ymm_dst, ymm_tmp, ymm_dst, jcp.typesize_out);
                         } else {
-                            uni_vfmadd231ps(vmm_dst, vmm_prev_dst, ptr[imm_addr64 + 3 * vlen]);
+                            psrldq(vmm_dst, jcp.typesize_out);
                         }
                     }
-
-                    if (maybe_relu(0)) {
-                        uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
-                        uni_vmaxps(vmm_dst, vmm_dst, vmm_zero);
-                    }
-
-                    if (maybe_relu(1)) {
-                        uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
-                        uni_vmaxps(vmm_dst, vmm_dst, vmm_zero);
-                    }
-
-                    if (jcp.dst_dt != data_type::f32) {
-                        if (attr_.round_mode_ == round_mode::nearest)
-                            uni_vcvtps2dq(vmm_dst, vmm_dst);
-                        else if (attr_.round_mode_ == round_mode::down) {
-                            uni_vroundps(vmm_dst, vmm_dst, 1);
-                            uni_vcvtps2dq(vmm_dst, vmm_dst);
-                        } else
-                            assert(!"unimplemented");
-                    }
-
+                } else {
                     store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, false);
                 }
             }
@@ -500,6 +474,7 @@ void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::width_blk_step(int ur_w, int pad_l,
     }
 
     push(reg_scales_base);
+    push(reg_oc_off);
 }
 
 template <cpu_isa_t isa>
@@ -513,6 +488,7 @@ inline void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::solve_common(int oc_blocks, i
     int dilate_w = jcp.dilate_w + 1;
     int str_w = jcp.stride_w;
     const int inp_mult = jcp.ic * jcp.ngroups;
+    const int out_mult = jcp.with_dw_conv ? jcp.oc_block : jcp.oc * jcp.ngroups;
 
     int l_pad = jcp.l_pad;
     int r_pad = nstl::max(0, (int(jcp.ow) - 1) * str_w + (kw - 1) * dilate_w
@@ -529,6 +505,7 @@ inline void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::solve_common(int oc_blocks, i
     push(reg_output_base);
     push(reg_kernel_base);
     push(reg_scales_base);
+    push(reg_oc_off);
 
     if (l_pad > 0) {
         n_oi--;
@@ -537,7 +514,7 @@ inline void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::solve_common(int oc_blocks, i
         else
             width_blk_step(ur_w, l_pad, 0, oc_blocks, oc_step); // "lpad"
         add(reg_input, jcp.typesize_in * (ur_w * str_w - l_pad) * inp_mult);
-        add(reg_output, jcp.typesize_out * ur_w * jcp.oc * jcp.ngroups);
+        add(reg_output, jcp.typesize_out * ur_w * out_mult);
     }
 
     Label ow_loop_label;
@@ -548,7 +525,7 @@ inline void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::solve_common(int oc_blocks, i
 
         width_blk_step(ur_w, 0, 0, oc_blocks, oc_step); // "middle"
         add(reg_input, jcp.typesize_in * ur_w * str_w * inp_mult);
-        add(reg_output, jcp.typesize_out * ur_w * jcp.oc * jcp.ngroups);
+        add(reg_output, jcp.typesize_out * ur_w * out_mult);
 
         inc(reg_oi_iter);
         cmp(reg_oi_iter, n_oi);
@@ -558,12 +535,13 @@ inline void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::solve_common(int oc_blocks, i
     if (r_pad1 > 0 && n_oi >=0) {
         width_blk_step(ur_w, 0, r_pad1, oc_blocks, oc_step); // "rpad"
         add(reg_input, jcp.typesize_in * ur_w * str_w * inp_mult);
-        add(reg_output, jcp.typesize_out * ur_w * jcp.oc * jcp.ngroups);
+        add(reg_output, jcp.typesize_out * ur_w * out_mult);
     }
 
     if (ur_w_tail != 0)
         width_blk_step(ur_w_tail, 0, r_pad, oc_blocks, oc_step); // "tail"
 
+    pop(reg_oc_off);
     pop(reg_scales_base);
     pop(reg_kernel_base);
     pop(reg_output_base);
@@ -573,56 +551,84 @@ inline void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::solve_common(int oc_blocks, i
 template <cpu_isa_t isa>
 void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::generate()
 {
+    const auto &p = attr_.post_ops_;
+    int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
+    for (int i = 0; i < end_idx; i++) {
+        auto &post_op = p.entry_[i];
+        if (post_op.is_eltwise()) {
+            eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
+                    this,
+                    post_op.eltwise.alg,
+                    post_op.eltwise.alpha,
+                    post_op.eltwise.beta
+            ));
+        } else if (post_op.is_depthwise()) {
+            depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>(
+                    this,
+                    post_op.depthwise.alg
+            ));
+        }
+    }
+
     this->preamble();
 
     mov(reg_kernel_base, ptr[this->param1 + GET_OFF(filt)]);
     mov(reg_input_base, ptr[this->param1 + GET_OFF(src)]);
     mov(reg_output_base, ptr[this->param1 + GET_OFF(dst)]);
-    mov(reg_oc, ptr[this->param1 + GET_OFF(oc_work)]);
+    mov(reg_oc_work, ptr[this->param1 + GET_OFF(oc_work)]);
     if (jcp.with_bias)
         mov(reg_bias_base, ptr[this->param1 + GET_OFF(bias)]);
     mov(reg_scales_base, ptr[this->param1 + GET_OFF(scales)]);
     if (jcp.signed_input)
         mov(reg_compensation_base, ptr[param1 + GET_OFF(compensation)]);
+    mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);
 
     Label main_loop_label;
     Label tail_label;
     Label exit_label;
 
-    cmp(reg_oc, jcp.nb_oc_blocking * jcp.oc_block);
+    cmp(reg_oc_work, jcp.nb_oc_blocking * jcp.oc_block);
     jne(main_loop_label, T_NEAR);
 
     solve_common(jcp.nb_oc_blocking, jcp.oc_block);
 
-    sub(reg_oc, jcp.nb_oc_blocking * jcp.oc_block);
+    sub(reg_oc_work, jcp.nb_oc_blocking * jcp.oc_block);
 
     jmp(exit_label, T_NEAR);
 
     L(main_loop_label); {
-        cmp(reg_oc, jcp.oc_block);
+        cmp(reg_oc_work, jcp.oc_block);
         jl(tail_label, T_NEAR);
 
         solve_common(1, jcp.oc_block);
 
-        sub(reg_oc, jcp.oc_block);
+        sub(reg_oc_work, jcp.oc_block);
         add(reg_kernel_base, jcp.oc_block * jcp.nb_ic * jcp.kh * jcp.kw * jcp.ic_block * jcp.typesize_in);
-        add(reg_output_base, jcp.oc_block * jcp.typesize_out);
+        if (jcp.with_dw_conv)
+            add(reg_output_base, jcp.oc_block * jcp_dw.kh * jcp.ow * jcp.typesize_out);
+        else
+            add(reg_output_base, jcp.oc_block * jcp.typesize_out);
         add(reg_bias_base, jcp.oc_block * jcp.typesize_bia);
         add(reg_scales_base, jcp.is_oc_scale * jcp.oc_block * sizeof(float));
         add(reg_compensation_base, jcp.oc_block * sizeof(int32_t));
+        add(reg_oc_off, jcp.oc_block * sizeof(float));
 
         jmp(main_loop_label, T_NEAR);
     }
 
     L(tail_label);
 
-    solve_common(1, jcp.oc % jcp.oc_block);
+    if (jcp.oc % jcp.oc_block != 0)
+        solve_common(1, jcp.oc % jcp.oc_block);
 
     L(exit_label);
 
     this->postamble();
 
     prepare_table();
+
+    for (auto& inj : eltwise_injectors)
+        inj->prepare_table();
 }
 
 template <cpu_isa_t isa>
@@ -672,43 +678,29 @@ void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::prepare_table() {
             dd(cvals_sum_scale[i]);
         }
     }
-
-    for (size_t i = 0; i < sizeof(cvals_shift) / sizeof(cvals_shift[0]); ++i) {
-        for (size_t d = 0; d < vlen / sizeof(int8_t); ++d) {
-            if ((int)d < jcp.ic % jcp.ic_block)
-                db(255);
-            else
-                db(0);
-        }
-    }
 }
 
 template <cpu_isa_t isa>
 bool jit_uni_x8s8s32x_conv_fwd_kernel<isa>::post_ops_ok(
         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
-    using namespace primitive_kind;
     const auto &p = attr.post_ops_;
 
-    auto is_relu = [&](int idx) {
-        return p.entry_[idx].kind == eltwise
-               && p.entry_[idx].eltwise.scale == 1.
-               && p.entry_[idx].eltwise.alg == alg_kind::eltwise_relu
-               && p.entry_[idx].eltwise.alpha == 0.;
-    };
+    auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+    auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
+    auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(false); };
+    auto is_dw_conv = [&](int idx) { return p.entry_[idx].is_dw_conv(); };
+    auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
 
     switch (p.len_) {
         case 0: return true;
-        case 1: return true
-                       && IMPLICATION(jcp.with_eltwise, p.contain(sum, 0))
-                       && IMPLICATION(!jcp.with_eltwise, is_relu(0) || p.contain(sum, 0));
-        case 2: return true
-                       && IMPLICATION(jcp.with_eltwise, p.contain(sum, 0) && is_relu(1))
-                       && IMPLICATION(!jcp.with_eltwise, false
-                                                         || (p.contain(sum, 0) && is_relu(1))
-                                                         || (p.contain(sum, 1) && is_relu(0)));
-        case 3: return true
-                       && jcp.with_eltwise == false
-                       && (is_relu(0) && p.contain(sum, 1) && is_relu(2));
+        case 1: return is_simple(0) || is_sum(0) || is_dw_conv(0);
+        case 2: return (is_sum(0) && is_simple(1)) || (is_simple(0) && is_sum(1)) ||
+                       (is_dw_conv(0) && is_simple(1)) || (is_simple(0) && is_dw_conv(1)) ||
+                       (is_simple(0) && is_simple(1));
+        case 3: return (is_simple(0) && is_sum(1) && is_simple(2)) ||
+                       (is_simple(0) && is_dw_conv(1) && is_simple(2)) ||
+                       (is_dw_conv(0) && is_simple(1) && is_simple(2));
+        case 4: return (is_simple(0) && is_dw_conv(1) && is_simple(2) && is_simple(3));
         default: return false;
     }
 
@@ -720,7 +712,7 @@ status_t jit_uni_x8s8s32x_conv_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jcp,
         const convolution_desc_t &cd, cpu_memory_t::pd_t &src_pd,
         cpu_memory_t::pd_t &weights_pd, cpu_memory_t::pd_t &dst_pd,
         cpu_memory_t::pd_t &bias_pd,
-        const primitive_attr_t &attr, bool with_relu, float relu_negative_slope)
+        const primitive_attr_t &attr)
 {
     if (!mayiuse(isa)) return status::unimplemented;
 
@@ -758,8 +750,6 @@ status_t jit_uni_x8s8s32x_conv_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jcp,
 
     jcp.src_fmt = src_d.format();
     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
-    jcp.with_eltwise = with_relu;
-    jcp.eltwise_alpha = relu_negative_slope;
 
     jcp.signed_input = src_d.data_type() == data_type::s8;
 
@@ -772,14 +762,23 @@ status_t jit_uni_x8s8s32x_conv_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jcp,
     jcp.oc_padded = rnd_up(jcp.oc, jcp.oc_block);
     jcp.nb_oc = div_up(jcp.oc, jcp.oc_block);
 
+    if (jcp.ngroups != 1) {
+        if (jcp.ic % jcp.ic_block != 0 || jcp.oc % jcp.oc_block != 0)
+            return status::unimplemented;
+    }
+
     if (!post_ops_ok(jcp, attr))
         return status::unimplemented;
 
     const auto &p = attr.post_ops_;
-    jcp.with_sum = p.find(primitive_kind::sum) != -1;
-    if (!jcp.with_eltwise) {
-        jcp.with_eltwise = p.find(primitive_kind::eltwise) != -1;
-        jcp.eltwise_alpha = 0.f;
+
+    int dw_conv_ind = p.find(primitive_kind::convolution);
+    jcp.with_dw_conv = dw_conv_ind != -1;
+    if (jcp.with_dw_conv) {
+        jcp.dw_conv_oh = jcp.oh;
+        jcp.dw_conv_ow = jcp.ow;
+        jcp.oh = p.entry_[dw_conv_ind].dw_conv.in_h;
+        jcp.ow = p.entry_[dw_conv_ind].dw_conv.in_w;
     }
 
     auto desired_act_fmt = nhwc;
@@ -808,6 +807,7 @@ status_t jit_uni_x8s8s32x_conv_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jcp,
             return status::unimplemented;
     }
 
+    jcp.src_dt = cd.src_desc.data_type;
     jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
     jcp.dst_dt = cd.dst_desc.data_type;
 
@@ -824,9 +824,15 @@ status_t jit_uni_x8s8s32x_conv_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jcp,
     assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
 
     jcp.ur_h = 1; /* no code-unrolling by h so far */
-    jcp.ur_w = isa == avx2 ? 3 : 2;
-    jcp.nb_oc_blocking = 2;
-    if (jcp.nb_oc % jcp.nb_oc_blocking != 0) jcp.nb_oc_blocking = 1;
+    jcp.ur_w = isa == avx2 ? 4 : 2;
+    jcp.nb_oc_blocking = nstl::min(2, jcp.nb_oc);
+    jcp.max_regs_ur = 12;
+
+    // WA to prevent fallback on gemm implementation
+    if (isa == sse42 && jcp.ic == 3) {
+        jcp.ur_w = 4;
+        jcp.nb_oc_blocking = 1;
+    }
 
     if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
     jcp.ur_w_tail = jcp.ow % jcp.ur_w;
@@ -839,24 +845,42 @@ status_t jit_uni_x8s8s32x_conv_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jcp,
 
     int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
         + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
+    if (r_pad_no_tail > jcp.ur_w)
+        return status::unimplemented;
 
-    if (r_pad_no_tail > jcp.ur_w) {
-        /* recalculate ur_w, nb_oc_blocking and ur_w_tail */
-        jcp.ur_w = r_pad_no_tail + 1;
-        jcp.ur_w_tail = jcp.ow % jcp.ur_w;
-        /* check again ... */
-        r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
-            + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
-        if ((r_pad_no_tail > jcp.ur_w) || (jcp.ow < jcp.ur_w))
-            return status::unimplemented;
-    }
-    if (jcp.l_pad > jcp.ur_w) return status::unimplemented;
+    if (jcp.l_pad > jcp.ur_w)
+        return status::unimplemented;
 
     jcp.wei_adj_scale = (jcp.signed_input) ? (1.0f / 2.0f) : 1.0f;
 
     return status::success;
 }
 
+template <cpu_isa_t isa>
+void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::init_scratchpad(
+        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw,
+        const primitive_attr_t &attr) {
+    if (jcp.oc != jcp.oc_padded)
+        scratchpad.book(key_conv_padded_bias, (size_t)jcp.typesize_bia * jcp.oc_padded);
+
+    if (jcp.signed_input) {
+        size_t count = nstl::max(attr.output_scales_.count_, 8);
+        scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count);
+
+        if (jcp.oc != jcp.oc_padded)
+            scratchpad.book(key_conv_padded_compensation, sizeof(int32_t) * jcp.oc_padded);
+    }
+
+    if (jcp.with_dw_conv) {
+        const int nthreads = mkldnn_get_max_threads();
+        size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * jcp.nb_oc_blocking;
+        scratchpad.book(key_dw_conv_buffer, jcp_dw.typesize_in * dw_conv_buffer_size_ * nthreads);
+
+        if (jcp.oc != jcp.oc_padded)
+            scratchpad.book(key_dw_conv_padded_bias, (size_t)jcp_dw.typesize_bia * jcp.oc_padded);
+    }
+}
+
 template struct jit_uni_x8s8s32x_conv_fwd_kernel<avx2>;
 template struct jit_uni_x8s8s32x_conv_fwd_kernel<sse42>;