Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_depthwise.cpp
index 634e9f9..9aad4f1 100644 (file)
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2018 Intel Corporation
+* Copyright 2018-2019 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -56,7 +56,7 @@ struct jit_uni_depthwise_kernel_f32 : public c_compatible {
 template <cpu_isa_t isa>
 int jit_uni_depthwise_injector_f32<isa>::aux_vecs_count(alg_kind_t depthwise_alg) {
     switch (depthwise_alg) {
-        case alg_kind::depthwise_scale_shift: return 0;
+        case alg_kind::depthwise_scale_shift: return isa == sse42 ? 1 : 0;
         case alg_kind::depthwise_prelu: return 2;
         default: assert(!"unsupported depthwise algorithm");
     }
@@ -132,8 +132,15 @@ void jit_uni_depthwise_injector_f32<isa>::assign_regs() {
 template <cpu_isa_t isa>
 void jit_uni_depthwise_injector_f32<isa>::scale_shift_compute_vector(const Vmm &vmm_src,
         const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
-    h->uni_vmulps(vmm_src, vmm_src, h->ptr[p_weights]);
-    h->uni_vaddps(vmm_src, vmm_src, h->ptr[p_bias]);
+    if (isa == sse42) {
+        h->movups(vmm_mask, h->ptr[p_weights]);
+        h->mulps(vmm_src, vmm_mask);
+        h->movups(vmm_mask, h->ptr[p_bias]);
+        h->addps(vmm_src, vmm_mask);
+    } else {
+        h->uni_vmulps(vmm_src, vmm_src, h->ptr[p_weights]);
+        h->uni_vaddps(vmm_src, vmm_src, h->ptr[p_bias]);
+    };
 }
 
 template <cpu_isa_t isa>
@@ -145,8 +152,8 @@ void jit_uni_depthwise_injector_f32<isa>::prelu_compute_vector(const Vmm &vmm_sr
     if (isa == sse42) {
         h->pxor(vmm_mask, vmm_mask);
         h->cmpps(vmm_mask, vmm_src, _cmp_gt_os);
-        h->movups(vmm_aux0, vmm_src);
-        h->mulps(vmm_aux0, h->ptr[p_weights]);
+        h->movups(vmm_aux0, h->ptr[p_weights]);
+        h->mulps(vmm_aux0, vmm_src);
         h->blendvps(vmm_src, vmm_aux0);
     } else if (isa == avx2) {
         h->vxorps(vmm_mask, vmm_mask, vmm_mask);
@@ -202,7 +209,7 @@ struct jit_uni_scale_shift_kernel_f32 : public jit_uni_depthwise_kernel_f32,
         assert(desc.alg_kind == alg_kind::depthwise_scale_shift);
         assert(isa == sse42 || isa == avx2 || isa == avx512_common);
 
-        bool isFlat = desc.src_desc.format == nchw && desc.dst_desc.format == nchw ;
+        bool isFlat = desc.src_desc.format == nchw && desc.dst_desc.format == nchw;
 
         Reg64 param = abi_param1;
 
@@ -465,30 +472,30 @@ status_t jit_uni_depthwise_fwd_t<isa>::pd_t::init() {
 }
 
 template <cpu_isa_t isa>
-jit_uni_depthwise_fwd_t<isa>::jit_uni_depthwise_fwd_t(const pd_t *pd,
+jit_uni_depthwise_fwd_t<isa>::jit_uni_depthwise_fwd_t(const pd_t *apd,
         const input_vector &inputs, const output_vector &outputs)
-    : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), kernel_(nullptr),
+    : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr),
       padded_weights_(nullptr), padded_bias_(nullptr) {
-    const auto &desc = *conf_.desc();
+    const auto &desc = *pd()->desc();
     switch (desc.alg_kind) {
         case alg_kind::depthwise_scale_shift:
-            kernel_ = new jit_uni_scale_shift_kernel_f32<isa>(desc, pd->with_bias()); break;
+            kernel_ = new jit_uni_scale_shift_kernel_f32<isa>(desc, pd()->with_bias()); break;
         case alg_kind::depthwise_prelu:
-            kernel_ = new jit_uni_prelu_kernel_f32<isa>(desc, pd->with_bias()); break;
+            kernel_ = new jit_uni_prelu_kernel_f32<isa>(desc, pd()->with_bias()); break;
         default: assert(!"unknown depthwise alg_kind");
     }
 
     const int simd_w = isa == avx512_common ? 16 : 8;
-    const memory_desc_wrapper data_d(conf_.src_pd());
+    const memory_desc_wrapper data_d(pd()->src_pd());
     const int c_without_padding = data_d.dims()[1];
     const int c_padded = rnd_up(c_without_padding, simd_w);
 
-    if (conf_.want_padded_weights()) {
+    if (pd()->want_padded_weights()) {
         padded_weights_ = (data_t *)malloc(sizeof(data_t) * c_padded, 64);
         for (int oc = c_without_padding; oc < c_padded; ++oc)
             padded_weights_[oc] = 0;
 
-        if (conf_.with_bias()) {
+        if (pd()->with_bias()) {
             padded_bias_ = (data_t *)malloc(sizeof(data_t) * c_padded, 64);
             for (int oc = c_without_padding; oc < c_padded; ++oc)
                 padded_bias_[oc] = 0;
@@ -504,15 +511,15 @@ jit_uni_depthwise_fwd_t<isa>::~jit_uni_depthwise_fwd_t() {
 }
 
 template <cpu_isa_t isa>
-void jit_uni_depthwise_fwd_t<isa>::execute_forward() {
+void jit_uni_depthwise_fwd_t<isa>::execute_forward() const {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
     auto dst = reinterpret_cast<data_t *>(this->memory());
 
-    const memory_desc_wrapper data_d(conf_.src_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
-    const memory_desc_wrapper bias_d(conf_.weights_pd(1));
+    const memory_desc_wrapper data_d(pd()->src_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+    const memory_desc_wrapper bias_d(pd()->weights_pd(1));
 
     const int N = data_d.dims()[0];
     const int C = data_d.dims()[1];
@@ -523,12 +530,12 @@ void jit_uni_depthwise_fwd_t<isa>::execute_forward() {
     const int ch_block_size = data_d.format() == nchw ? 1 : simd_w;
     const int CB = div_up(C, ch_block_size);
 
-    if (conf_.want_padded_weights()) {
+    if (pd()->want_padded_weights()) {
         for (int oc = 0; oc < C; ++oc)
             padded_weights_[oc] = weights[oc];
         weights = padded_weights_;
 
-        if (conf_.with_bias()) {
+        if (pd()->with_bias()) {
             for (int oc = 0; oc < C; ++oc)
                 padded_bias_[oc] = bias[oc];
             bias = padded_bias_;
@@ -537,7 +544,7 @@ void jit_uni_depthwise_fwd_t<isa>::execute_forward() {
 
     parallel_nd(N, CB, H,
         [&](int n, int cb, int h) {
-        jit_args arg = {};
+        auto arg = jit_args();
 
         arg.from    = &src[data_d.blk_off(n, cb, h)];
         arg.to      = &dst[data_d.blk_off(n, cb, h)];
@@ -564,21 +571,38 @@ void jit_uni_dw_conv_row_f32<isa>::load_src(int ur_w) {
         for (int ow = 0; ow < ur_w; ow++) {
             Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
 
-            if (this->jcp.with_bias)
-                uni_vmovups(vmm_acc, vmmword[reg_bias + i*4*sizeof(float)]);
-            else
-                uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
-
-            int o_off = ow*jcp.ch_block + i*4;
-            if (this->jcp.with_sum)
-                uni_vaddps(vmm_acc, vmm_acc,
-                           vmmword[reg_output + o_off*sizeof(float)]);
+            uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
         }
     }
 }
 
 template <cpu_isa_t isa>
 void jit_uni_dw_conv_row_f32<isa>::apply_filter(int ur_w, int kw_size) {
+    auto load_src = [=](Vmm vmm_src, const Xbyak::Address &op) {
+        if (jcp.src_dt == data_type::u8) {
+            uni_vpmovzxbd(vmm_src, op);
+        } else {
+            uni_vmovups(vmm_src, op);
+        }
+    };
+
+    auto load_ker = [=](Vmm vmm_ker, const Xbyak::Address &op) {
+        if (jcp.src_dt == data_type::u8) {
+            uni_vpmovsxbd(vmm_ker, op);
+        } else {
+            uni_vmovups(vmm_ker, op);
+        }
+    };
+
+    auto compute = [=](Vmm vmm_acc, Vmm vmm_src, Vmm vmm_ker) {
+        if (jcp.src_dt == data_type::u8) {
+            uni_vpmulld(vmm_src, vmm_src, vmm_ker);
+            uni_vpaddd(vmm_acc, vmm_acc, vmm_src);
+        } else {
+            uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
+        }
+    };
+
     int ch_blk = jcp.ch_block;
     int stride_w = jcp.stride_w;
 
@@ -590,69 +614,63 @@ void jit_uni_dw_conv_row_f32<isa>::apply_filter(int ur_w, int kw_size) {
     jl(exit_label, T_NEAR);
     for (int i = 0; i < repeats; i++) {
         for (int kw = 0; kw < kw_size; kw++) {
-            int ker_off = kw * ch_blk + i*4;
+            int ker_off = kw * ch_blk + i*(jcp.ch_block / 2);
 
             Vmm vmm_ker = get_ker_reg(0);
-            uni_vmovups(vmm_ker, ptr[aux_reg_kernel
-                                     + ker_off * sizeof(float)]);
+            load_ker(vmm_ker, ptr[aux_reg_kernel + ker_off * jcp.typesize_in]);
 
             for (int ow = 0; ow < ur_w; ow++) {
-                int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*4;
+                int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*(jcp.ch_block / 2);
 
                 Vmm vmm_src = get_src_reg(0);
-                uni_vmovups(vmm_src, ptr[aux_reg_input0
-                                         + inp_off * sizeof(float)]);
+                load_src(vmm_src, ptr[aux_reg_input0 + inp_off * jcp.typesize_in]);
 
                 Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
-                uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
+                compute(vmm_acc, vmm_src, vmm_ker);
             }
         }
     }
-    add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float));
+    add(aux_reg_kernel, jcp.kw*ch_blk*jcp.typesize_in);
 
     cmp(reg_kh, 2);
     jl(exit_label, T_NEAR);
     for (int i = 0; i < repeats; i++) {
         for (int kw = 0; kw < kw_size; kw++) {
-            int ker_off = kw * ch_blk + i*4;
+            int ker_off = kw * ch_blk + i*(jcp.ch_block / 2);
 
             Vmm vmm_ker = get_ker_reg(0);
-            uni_vmovups(vmm_ker, ptr[aux_reg_kernel
-                                     + ker_off * sizeof(float)]);
+            load_ker(vmm_ker, ptr[aux_reg_kernel + ker_off * jcp.typesize_in]);
 
             for (int ow = 0; ow < ur_w; ow++) {
-                int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*4;
+                int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*(jcp.ch_block / 2);
 
                 Vmm vmm_src = get_src_reg(0);
-                uni_vmovups(vmm_src, ptr[aux_reg_input1
-                                         + inp_off * sizeof(float)]);
+                load_src(vmm_src, ptr[aux_reg_input1 + inp_off * jcp.typesize_in]);
 
                 Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
-                uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
+                compute(vmm_acc, vmm_src, vmm_ker);
             }
         }
     }
-    add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float));
+    add(aux_reg_kernel, jcp.kw*ch_blk*jcp.typesize_in);
 
     cmp(reg_kh, 3);
     jl(exit_label, T_NEAR);
     for (int i = 0; i < repeats; i++) {
         for (int kw = 0; kw < kw_size; kw++) {
-            int ker_off = kw * ch_blk + i*4;
+            int ker_off = kw * ch_blk + i*(jcp.ch_block / 2);
 
             Vmm vmm_ker = get_ker_reg(0);
-            uni_vmovups(vmm_ker, ptr[aux_reg_kernel
-                                     + ker_off * sizeof(float)]);
+            load_ker(vmm_ker, ptr[aux_reg_kernel + ker_off * jcp.typesize_in]);
 
             for (int ow = 0; ow < ur_w; ow++) {
-                int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*4;
+                int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*(jcp.ch_block / 2);
 
                 Vmm vmm_src = get_src_reg(0);
-                uni_vmovups(vmm_src, ptr[aux_reg_input2
-                                         + inp_off * sizeof(float)]);
+                load_src(vmm_src, ptr[aux_reg_input2 + inp_off * jcp.typesize_in]);
 
                 Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
-                uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
+                compute(vmm_acc, vmm_src, vmm_ker);
             }
         }
     }
@@ -661,34 +679,276 @@ void jit_uni_dw_conv_row_f32<isa>::apply_filter(int ur_w, int kw_size) {
 }
 
template <cpu_isa_t isa>
void jit_uni_dw_conv_row_f32<isa>::cvt2ps(data_type_t type_in, Vmm vmm_in, const Operand &op, bool scalar_load) {
    // Emits code that loads data of type `type_in` from `op` into vmm_in and
    // converts it to packed f32. With scalar_load == true a single element is
    // loaded through a GP register (tail handling); otherwise a full vector
    // load / sign- or zero-extending vector load is emitted.
    Xmm xmm_in = Xmm(vmm_in.getIdx());

    switch (type_in) {
        case data_type::f32:
        case data_type::s32:
            if (scalar_load) {
                // 32-bit value -> GP reg -> low lanes of the xmm register.
                mov(reg_tmp_32, op);
                movq(xmm_in, reg_tmp_64);
            } else {
                uni_vmovups(vmm_in, op);
            }
            break;
        case data_type::s8:
            if (scalar_load) {
                movsx(reg_tmp_32, op);  // sign-extend the byte
                movq(xmm_in, reg_tmp_64);
            } else {
                uni_vpmovsxbd(vmm_in, op);  // bytes -> sign-extended dwords
            }
            break;
        case data_type::u8:
            if (scalar_load) {
                movzx(reg_tmp_32, op);  // zero-extend the byte
                movq(xmm_in, reg_tmp_64);
            } else {
                uni_vpmovzxbd(vmm_in, op);  // bytes -> zero-extended dwords
            }
            break;
        default: assert(!"unsupported data type");
    }

    // Integer inputs still hold dwords at this point; convert them to f32.
    if (type_in != data_type::f32)
        uni_vcvtdq2ps(vmm_in, vmm_in);
}
 
template <cpu_isa_t isa>
void jit_uni_dw_conv_row_f32<isa>::apply_postprocessing(int ur_w, int oc_step) {
    // Emits post-processing over the `ur_w` accumulator registers:
    // int->f32 conversion (u8 path), bias add, optional sum with dst, and the
    // eltwise/depthwise post-ops that follow the fused dw-conv entry.
    // `oc_step` is the number of valid output channels (may be a tail).
    int repeats = isa == sse42 ? 2 : 1;  // sse42 processes the channel block in two halves

    for (int r = 0; r < repeats; r++) {
        for (int ow = 0; ow < ur_w; ow++) {
            if (jcp.src_dt == data_type::u8) {
                // u8 path accumulates in s32; convert accumulators to f32.
                uni_vcvtdq2ps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow));
            }

            if (jcp.with_bias) {
                int b_off = r * (jcp.ch_block / 2);
                cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias + b_off * jcp.typesize_bia], false);
                uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_bias);
            }
        }
    }

    if (jcp.with_sum) {
        for (int r = 0; r < repeats; r++) {
            // tail_size: number of valid channels in this half/block.
            int tail_size = isa == sse42 ? nstl::min(jcp.ch_block / 2, oc_step - r * jcp.ch_block / 2) : oc_step;
            bool is_scalar_store = isa == sse42 ? tail_size < jcp.ch_block / 2 : tail_size < jcp.ch_block;

            for (int ow = 0; ow < ur_w; ow++) {
                if (is_scalar_store) {
                    // Tail: load dst elements one by one and shift each into
                    // its channel lane before accumulating.
                    for (int oc = 0; oc < tail_size; oc++) {
                        int o_off = ow * ow_stride_ + r * (jcp.ch_block / 2) + oc;

                        uni_vpxor(vmm_sum, vmm_sum, vmm_sum);
                        cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], true);

                        if (oc >= jcp.ch_block / 2) {
                            // Move the scalar into the upper 128-bit half.
                            vperm2i128(Ymm(vmm_sum.getIdx()), Ymm(vmm_sum.getIdx()), Ymm(vmm_sum.getIdx()), 0x01);
                        }
                        // Byte-shift the scalar into lane (oc % half-block).
                        uni_vpslldq(vmm_sum, vmm_sum, jcp.typesize_out * (oc % (jcp.ch_block / 2)));

                        uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_sum);
                    }
                } else {
                    // Full vector: load dst and accumulate in one shot.
                    int o_off = ow * ow_stride_ + r * (jcp.ch_block / 2);

                    uni_vpxor(vmm_sum, vmm_sum, vmm_sum);
                    cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], false);

                    uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_sum);
                }
            }
        }
    }

    // Apply only the post-ops located AFTER the fused dw-conv entry; those
    // before it belong to the preceding convolution.
    const auto &p = attr_.post_ops_;
    int eltwise_inj_idx = 0;
    int depthwise_inj_idx = 0;
    int start_idx = p.find(primitive_kind::convolution) + 1;
    for (int i = start_idx; i < p.len_; i++) {
        auto& post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            // Accumulators live in vmm registers 4 .. 4 + repeats*ur_w - 1.
            eltwise_injectors[eltwise_inj_idx]->compute_vector_range(4, 4 + repeats * ur_w);
            eltwise_inj_idx++;
        } else if (post_op.is_depthwise()) {
            mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
            mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));

            // reg_oc_off selects the per-channel weight/bias slice.
            add(reg_d_weights, reg_oc_off);
            add(reg_d_bias, reg_oc_off);

            depthwise_injectors[depthwise_inj_idx]->compute_vector_range(4, 4 + ur_w, reg_d_weights, reg_d_bias);

            if (repeats == 2) {
                // sse42: second half-block uses weights/bias shifted by
                // half a channel block.
                add(reg_d_weights, (jcp.ch_block / 2) * sizeof(float));
                add(reg_d_bias, (jcp.ch_block / 2) * sizeof(float));

                depthwise_injectors[depthwise_inj_idx]->compute_vector_range(4 + ur_w, 4 + 2 * ur_w, reg_d_weights, reg_d_bias);
            }

            depthwise_inj_idx++;
        }
    }
}
+
template <cpu_isa_t isa>
void jit_uni_dw_conv_row_f32<isa>::store_dst_typed(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store) {
    // Emits a store of vmm_dst to `op`, down-converting to jcp.dst_dt.
    // scalar_store writes a single element through a GP register (tail path);
    // note the s8/u8 pack sequences clobber vmm_dst.
    Ymm ymm_dst = Ymm(vmm_dst.getIdx());
    Xmm xmm_dst = Xmm(vmm_dst.getIdx());

    switch (jcp.dst_dt) {
        case data_type::f32:
        case data_type::s32:
            if (scalar_store) {
                // Low 32 bits via GP register.
                movq(reg_tmp_64, xmm_dst);
                mov(op, reg_tmp_32);
            } else {
                uni_vmovups(op, vmm_dst);
            }
            break;
        case data_type::s8:
            // dwords -> words (signed saturation).
            uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst);

            // AVX pack works per 128-bit lane; gather the lanes together.
            if (isa != sse42 && !scalar_store)
                vpermq(ymm_dst, ymm_dst, 0x08);

            // words -> bytes (signed saturation).
            uni_vpacksswb(vmm_dst, vmm_dst, vmm_dst);

            if (scalar_store) {
                movq(reg_tmp_64, xmm_dst);
                mov(op, reg_tmp_8);
            } else {
                if (isa != sse42)
                    vmovq(op, xmm_dst);   // 8 bytes
                else
                    movd(op, xmm_dst);    // 4 bytes (half block)
            }
            break;
        case data_type::u8:
        case data_type::bin:
            // Same as s8 but with unsigned saturation.
            uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst);

            if (isa != sse42 && !scalar_store)
                vpermq(ymm_dst, ymm_dst, 0x08);

            uni_vpackuswb(vmm_dst, vmm_dst, vmm_dst);

            if (scalar_store) {
                movq(reg_tmp_64, xmm_dst);
                mov(op, reg_tmp_8);
            } else {
                if (isa != sse42)
                    vmovq(op, xmm_dst);
                else
                    movd(op, xmm_dst);
            }
            break;
        default:
            assert(!"unknown dst_dt");
    }
}
+
template <cpu_isa_t isa>
void jit_uni_dw_conv_row_f32<isa>::store_dst(int ur_w, int oc_step) {
    // Emits the final store of the `ur_w` accumulators: f32 -> int rounding
    // when dst is integral, then either binarization (packed bit output) or a
    // typed vector/scalar store. `oc_step` may be a channel tail.
    int repeats = isa == sse42 && oc_step > (jcp.ch_block / 2) ? 2 : 1;

    for (int i = 0; i < repeats; i++) {
        for (int ow = 0; ow < ur_w; ow++) {
            Vmm vmm_dst = get_acc_reg(i * ur_w + ow);
            if (jcp.dst_dt != data_type::f32 && jcp.dst_dt != data_type::bin) {
                // Convert f32 accumulators to s32 using the attr round mode.
                if (attr_.round_mode_ == round_mode::nearest)
                    uni_vcvtps2dq(vmm_dst, vmm_dst);
                else if (attr_.round_mode_ == round_mode::down) {
                    uni_vroundps(vmm_dst, vmm_dst, 1);  // 1 = round toward -inf
                    uni_vcvtps2dq(vmm_dst, vmm_dst);
                } else
                    assert(!"unimplemented");
            }
        }
    }

    if (jcp.with_binarization) {
        // Binarized output: one bit per channel, 8 channels per output byte.
        int output_step = div_up(ow_stride_, 8);

        const auto &p = attr_.post_ops_;
        int binarization_idx = p.find(primitive_kind::binarization);

        mov(reg_b_weights, reinterpret_cast<size_t>(p.entry_[binarization_idx].binarization.weights_data));
        add(reg_b_weights, reg_oc_off);

        for (int ow = 0; ow < ur_w; ow++) {
            for (int i = 0; i < repeats; i++) {
                int tail_size = isa == sse42 ? nstl::min(jcp.ch_block / 2, oc_step - i * jcp.ch_block / 2) : oc_step;
                // Mask off bits beyond the valid channel tail.
                mov(reg_b_mask, (1 << tail_size) - 1);
                uni_vmovups(vmm_thr, ptr[reg_b_weights + i * (jcp.ch_block / 2) * sizeof(float)]);

                Vmm vmm_dst = get_acc_reg(i * ur_w + ow);

                // Compare against per-channel thresholds, then gather the
                // sign mask into a GP register.
                uni_vcmpgtps(vmm_dst, vmm_dst, vmm_thr);

                if (i == 0) {
                    uni_vmovmskps(reg_tmp_32, vmm_dst);
                    and_(reg_tmp_64, reg_b_mask);
                } else {
                    // Second sse42 half contributes the upper 4 bits.
                    uni_vmovmskps(reg_tmp2_32, vmm_dst);
                    and_(reg_tmp2_64, reg_b_mask);
                    shl(reg_tmp2_32, 4);
                    or_(reg_tmp_32, reg_tmp2_32);
                }

                if (i == repeats - 1) {
                    const size_t o_off = ow * output_step;
                    mov(ptr[reg_output + o_off * jcp.typesize_out], reg_tmp_8);
                }
            }
        }
    } else {
        for (int i = 0; i < repeats; i++) {
            int tail_size = isa == sse42 ? nstl::min(jcp.ch_block / 2, oc_step - i * jcp.ch_block / 2) : oc_step;
            bool is_scalar_store = isa == sse42 ? tail_size < jcp.ch_block / 2 : tail_size < jcp.ch_block;
            if (is_scalar_store) {
                // Tail: store elements one by one, rotating the register so
                // the next element reaches the low lane.
                for (int ow = 0; ow < ur_w; ow++) {
                    Vmm vmm_dst = get_acc_reg(i * ur_w + ow);
                    Ymm ymm_dst = Ymm(vmm_dst.getIdx());

                    for (int oc = 0; oc < tail_size; oc++) {
                        int o_off = ow * ow_stride_ + i * (jcp.ch_block / 2) + oc;
                        store_dst_typed(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, true);

                        if (isa == sse42) {
                            psrldq(vmm_dst, jcp.typesize_out);
                        } else {
                            // 256-bit rotate right by one element.
                            vperm2i128(ymm_tmp, ymm_dst, ymm_dst, 0x01);
                            vpalignr(ymm_dst, vmm_tmp, ymm_dst, jcp.typesize_out);
                        }
                    }
                }
            } else {
                // Full (half-)block vector stores.
                for (int ow = 0; ow < ur_w; ow++) {
                    int o_off = ow * ow_stride_ + i * (jcp.ch_block / 2);
                    Vmm vmm_dst = get_acc_reg(i * ur_w + ow);

                    store_dst_typed(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, false);
                }
            }
        }
    }
}
 
 template <cpu_isa_t isa>
-void jit_uni_dw_conv_row_f32<isa>::loop_body() {
+void jit_uni_dw_conv_row_f32<isa>::loop_body(int oc_step) {
     Label left_pad_label;
     Label right_pad_label;
     Label unrolled_w_label;
     Label tail_w_label;
     Label exit_label;
 
+    int output_step = jcp.with_binarization ? div_up(ow_stride_, 8) : ow_stride_;
+
     L(left_pad_label); {
         int ur_w = 1;
         int kw = jcp.iw == 1 ? jcp.kw - 2 : jcp.kw - 1;
@@ -697,18 +957,17 @@ void jit_uni_dw_conv_row_f32<isa>::loop_body() {
         mov(aux_reg_input1, reg_input1);
         mov(aux_reg_input2, reg_input2);
         mov(aux_reg_kernel, reg_kernel);
-        add(aux_reg_kernel, jcp.ch_block*sizeof(float));
+        add(aux_reg_kernel, jcp.ch_block*jcp.typesize_in);
 
         load_src(ur_w);
         apply_filter(ur_w, kw);
-        apply_activation(ur_w);
-        store_dst(ur_w);
+        apply_postprocessing(ur_w, oc_step);
+        store_dst(ur_w, oc_step);
 
-        add(reg_input0, sizeof(float) * ur_w * jcp.ch_block * (jcp.stride_w-1));
-        add(reg_input1, sizeof(float) * ur_w * jcp.ch_block * (jcp.stride_w-1));
-        add(reg_input2, sizeof(float) * ur_w * jcp.ch_block * (jcp.stride_w-1));
-
-        add(reg_output, sizeof(float) * ur_w * jcp.ch_block);
+        add(reg_input0, jcp.typesize_in * ur_w * jcp.ch_block * (jcp.stride_w-1));
+        add(reg_input1, jcp.typesize_in * ur_w * jcp.ch_block * (jcp.stride_w-1));
+        add(reg_input2, jcp.typesize_in * ur_w * jcp.ch_block * (jcp.stride_w-1));
+        add(reg_output, jcp.typesize_out * ur_w * output_step);
 
         sub(reg_ur_w, ur_w);
     }
@@ -727,13 +986,13 @@ void jit_uni_dw_conv_row_f32<isa>::loop_body() {
 
         load_src(ur_w);
         apply_filter(ur_w, kw);
-        apply_activation(ur_w);
-        store_dst(ur_w);
+        apply_postprocessing(ur_w, oc_step);
+        store_dst(ur_w, oc_step);
 
-        add(reg_input0, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
-        add(reg_input1, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
-        add(reg_input2, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
-        add(reg_output, sizeof(float) * ur_w * jcp.ch_block);
+        add(reg_input0, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
+        add(reg_input1, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
+        add(reg_input2, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
+        add(reg_output, jcp.typesize_out * ur_w * output_step);
 
         sub(reg_ur_w, ur_w);
         jmp(unrolled_w_label, T_NEAR);
@@ -756,13 +1015,13 @@ void jit_uni_dw_conv_row_f32<isa>::loop_body() {
 
         load_src(ur_w);
         apply_filter(ur_w, kw);
-        apply_activation(ur_w);
-        store_dst(ur_w);
+        apply_postprocessing(ur_w, oc_step);
+        store_dst(ur_w, oc_step);
 
-        add(reg_input0, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
-        add(reg_input1, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
-        add(reg_input2, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
-        add(reg_output, sizeof(float) * ur_w * jcp.ch_block);
+        add(reg_input0, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
+        add(reg_input1, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
+        add(reg_input2, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
+        add(reg_output, jcp.typesize_out * ur_w * output_step);
 
         sub(reg_ur_w, ur_w);
         jmp(tail_w_label, T_NEAR);
@@ -780,8 +1039,8 @@ void jit_uni_dw_conv_row_f32<isa>::loop_body() {
 
             load_src(ur_w);
             apply_filter(ur_w, kw);
-            apply_activation(ur_w);
-            store_dst(ur_w);
+            apply_postprocessing(ur_w, oc_step);
+            store_dst(ur_w, oc_step);
 
             sub(reg_ur_w, ur_w);
         }
@@ -791,8 +1050,26 @@ void jit_uni_dw_conv_row_f32<isa>::loop_body() {
 }
 
 template <cpu_isa_t isa>
-void jit_uni_dw_conv_row_f32<isa>::generate()
-{
+void jit_uni_dw_conv_row_f32<isa>::generate() {
+    const auto &p = attr_.post_ops_;
+    int start_idx = p.find(primitive_kind::convolution) + 1;
+    for (int i = start_idx; i < p.len_; i++) {
+        auto &post_op = p.entry_[i];
+        if (post_op.is_eltwise()) {
+            eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
+                    this,
+                    post_op.eltwise.alg,
+                    post_op.eltwise.alpha,
+                    post_op.eltwise.beta
+            ));
+        } else if (post_op.is_depthwise()) {
+            depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>(
+                    this,
+                    post_op.depthwise.alg
+            ));
+        }
+    }
+
     this->preamble();
 
     mov(reg_input0, ptr[this->param1 + GET_OFF_DW(src_row0)]);
@@ -804,45 +1081,196 @@ void jit_uni_dw_conv_row_f32<isa>::generate()
         mov(reg_bias, ptr[this->param1 + GET_OFF_DW(bias)]);
     mov(reg_kh, ptr[this->param1 + GET_OFF_DW(kh_padding)]);
     mov(reg_ur_w, ptr[this->param1 + GET_OFF_DW(ur_w)]);
+    mov(reg_oc_work, ptr[this->param1 + GET_OFF_DW(oc_work)]);
+    mov(reg_oc_off, ptr[this->param1 + GET_OFF_DW(oc_off)]);
+
+    Label(tail_label);
+    Label(exit_label);
 
-    loop_body();
+    cmp(reg_oc_work, jcp.ch_block);
+    jl(tail_label, T_NEAR);
+
+    loop_body(jcp.ch_block);
+    jmp(exit_label, T_NEAR);
+
+    L(tail_label);
+
+    if (jcp.oc % jcp.ch_block != 0)
+        loop_body(jcp.oc % jcp.ch_block);
+
+    L(exit_label);
 
     this->postamble();
 
-    if (jcp.with_eltwise)
-        eltwise_injector->prepare_table();
+    for (auto& inj : eltwise_injectors)
+        inj->prepare_table();
+}
+
+template <cpu_isa_t isa>
+bool jit_uni_dw_conv_row_f32<isa>::post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
+    const auto &p = attr.post_ops_;
+
+    auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+    auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
+    auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
+    auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
+    auto is_binarization = [&](int idx) { return p.entry_[idx].is_binarization(); };
+
+    int start_idx = p.find(primitive_kind::convolution) + 1;
+
+    switch (p.len_ - start_idx) {
+    case 0: return true; // no post_ops
+    case 1: return is_simple(start_idx) || is_sum(start_idx) || is_binarization(start_idx);
+    case 2: return (is_sum(start_idx) && is_simple(start_idx+1)) || (is_simple(start_idx) && is_simple(start_idx+1)) ||
+                   (is_simple(start_idx) && is_binarization(start_idx+1));
+    case 3: return (is_sum(start_idx) && is_simple(start_idx+1) && is_simple(start_idx+2));
+    default: return false;
+    }
+
+    return false;
+}
+
template <cpu_isa_t isa>
status_t jit_uni_dw_conv_row_f32<isa>::init_conf(jit_1x1_conv_conf_t &jcp, jit_conv_conf_t &jcp_dw,
        const primitive_attr_t &attr) {
    // Fills jcp_dw for a 3x3 depthwise convolution fused as a post-op of a
    // 1x1 convolution described by jcp. Geometry and weight/bias pointers
    // come from the `convolution` entry of the post-ops chain.
    if (!mayiuse(isa)) return status::unimplemented;
    const int simd_w = isa == avx512_common ? 16 : 8;

    const auto &p = attr.post_ops_;

    int dw_conv_ind = p.find(primitive_kind::convolution);
    // A sum entry after the fused conv means dst accumulation is required.
    jcp_dw.with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;

    jcp_dw.ch_block = simd_w;
    jcp_dw.with_bias = true;

    jcp_dw.kh = p.entry_[dw_conv_ind].dw_conv.ker_h;
    jcp_dw.kw = p.entry_[dw_conv_ind].dw_conv.ker_w;
    // Depthwise: channel counts match the producing conv's output channels.
    jcp_dw.ic = jcp.oc;
    jcp_dw.oc = jcp.oc;
    jcp_dw.ih = p.entry_[dw_conv_ind].dw_conv.in_h;
    jcp_dw.iw = p.entry_[dw_conv_ind].dw_conv.in_w;
    jcp_dw.oh = jcp.dw_conv_oh;
    jcp_dw.ow = jcp.dw_conv_ow;
    jcp_dw.stride_h = p.entry_[dw_conv_ind].dw_conv.str_h;
    jcp_dw.stride_w = p.entry_[dw_conv_ind].dw_conv.str_w;
    jcp_dw.conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
    jcp_dw.conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;

    // Only the 3x3 kernel is implemented by this row kernel.
    if (jcp_dw.kh != 3 || jcp_dw.kw != 3)
        return status::unimplemented;

    if (!post_ops_ok(jcp_dw, attr))
        return status::unimplemented;

    jcp_dw.ur_w = 4;  // output-width unroll factor

    jcp_dw.src_dt = jcp.src_dt;
    jcp_dw.dst_dt = jcp.dst_dt;
    jcp_dw.bia_dt = jcp.bia_dt;
    jcp_dw.typesize_in = (int)types::data_type_size(jcp.src_dt);
    jcp_dw.typesize_bia = (int)types::data_type_size(jcp.bia_dt);
    jcp_dw.typesize_out = (int)types::data_type_size(jcp.dst_dt);

    // The row kernel only handles f32 and u8 inputs.
    if (jcp_dw.src_dt != mkldnn_f32 && jcp_dw.src_dt != mkldnn_u8)
        return status::unimplemented;

    return status::success;
}
+
+template <cpu_isa_t isa>
+status_t jit_uni_dw_conv_row_f32<isa>::init_conf(jit_conv_conf_t &jcp, jit_conv_conf_t &jcp_dw,
+        const primitive_attr_t &attr) {
+    if (!mayiuse(isa)) return status::unimplemented;
+    const int simd_w = isa == avx512_common ? 16 : 8;
+
+    const auto &p = attr.post_ops_;
+
+    int dw_conv_ind = p.find(primitive_kind::convolution);
+    jcp_dw.with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;
+
+    jcp_dw.ch_block = simd_w;
+    jcp_dw.with_bias = true;
+
+    jcp_dw.kh = p.entry_[dw_conv_ind].dw_conv.ker_h;
+    jcp_dw.kw = p.entry_[dw_conv_ind].dw_conv.ker_w;
+    jcp_dw.ic = jcp.oc;
+    jcp_dw.oc = jcp.oc;
+    jcp_dw.ih = p.entry_[dw_conv_ind].dw_conv.in_h;
+    jcp_dw.iw = p.entry_[dw_conv_ind].dw_conv.in_w;
+    jcp_dw.oh = jcp.dw_conv_oh;
+    jcp_dw.ow = jcp.dw_conv_ow;
+    jcp_dw.stride_h = p.entry_[dw_conv_ind].dw_conv.str_h;
+    jcp_dw.stride_w = p.entry_[dw_conv_ind].dw_conv.str_w;
+    jcp_dw.conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
+    jcp_dw.conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;
+
+    if (jcp_dw.kh != 3 || jcp_dw.kw != 3)
+        return status::unimplemented;
+
+    if (!post_ops_ok(jcp_dw, attr))
+        return status::unimplemented;
+
+    jcp_dw.ur_w = 4;
+
+    jcp_dw.src_dt = jcp.dst_dt;
+    jcp_dw.dst_dt = jcp.dst_dt;
+    jcp_dw.bia_dt = jcp.bia_dt;
+    jcp_dw.typesize_in = (int)types::data_type_size(jcp.src_dt);
+    jcp_dw.typesize_bia = (int)types::data_type_size(jcp.bia_dt);
+    jcp_dw.typesize_out = (int)types::data_type_size(jcp.dst_dt);
+
+    if (jcp_dw.src_dt != mkldnn_f32 && jcp_dw.src_dt != mkldnn_u8)
+        return status::unimplemented;
+
+    return status::success;
 }
 
template <cpu_isa_t isa>
status_t jit_uni_dw_conv_row_f32<isa>::init_conf(jit_bin_conv_conf_t &jcp, jit_conv_conf_t &jcp_dw,
        const primitive_attr_t &attr) {
    // Fills jcp_dw for a 3x3 depthwise convolution fused as a post-op of a
    // binary convolution described by jcp. Geometry and weight/bias pointers
    // come from the `convolution` entry of the post-ops chain.
    if (!mayiuse(isa)) return status::unimplemented;
    const int simd_w = isa == avx512_common ? 16 : 8;

    const auto &p = attr.post_ops_;

    int dw_conv_ind = p.find(primitive_kind::convolution);
    // A sum entry after the fused conv means dst accumulation is required;
    // a binarization entry makes the output packed bits.
    jcp_dw.with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;
    jcp_dw.with_binarization = p.find(primitive_kind::binarization, dw_conv_ind) != -1;

    jcp_dw.ch_block = simd_w;
    jcp_dw.with_bias = true;

    jcp_dw.kh = p.entry_[dw_conv_ind].dw_conv.ker_h;
    jcp_dw.kw = p.entry_[dw_conv_ind].dw_conv.ker_w;
    // Depthwise: channel counts match the producing conv's output channels.
    jcp_dw.ic = jcp.oc;
    jcp_dw.oc = jcp.oc;
    jcp_dw.ih = p.entry_[dw_conv_ind].dw_conv.in_h;
    jcp_dw.iw = p.entry_[dw_conv_ind].dw_conv.in_w;
    jcp_dw.oh = jcp.dw_conv_oh;
    jcp_dw.ow = jcp.dw_conv_ow;
    jcp_dw.stride_h = p.entry_[dw_conv_ind].dw_conv.str_h;
    jcp_dw.stride_w = p.entry_[dw_conv_ind].dw_conv.str_w;
    jcp_dw.conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
    jcp_dw.conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;

    // Only the 3x3 kernel is implemented by this row kernel.
    if (jcp_dw.kh != 3 || jcp_dw.kw != 3)
        return status::unimplemented;

    if (!post_ops_ok(jcp_dw, attr))
        return status::unimplemented;

    jcp_dw.ur_w = 4;  // output-width unroll factor

    // Binary-conv path: the dw conv always consumes f32 intermediates and
    // produces packed bits when binarization follows, f32 otherwise.
    jcp_dw.src_dt = mkldnn_f32;
    jcp_dw.dst_dt = jcp_dw.with_binarization ? mkldnn_bin : mkldnn_f32;
    jcp_dw.bia_dt = mkldnn_f32;
    jcp_dw.typesize_in = (int)types::data_type_size(jcp_dw.src_dt);
    jcp_dw.typesize_bia = (int)types::data_type_size(jcp_dw.bia_dt);
    jcp_dw.typesize_out = (int)types::data_type_size(jcp_dw.dst_dt);

    // NOTE(review): src_dt was hard-set to f32 above, so this check can never
    // fire here — presumably kept for symmetry with the other overloads.
    if (jcp_dw.src_dt != mkldnn_f32 && jcp_dw.src_dt != mkldnn_u8)
        return status::unimplemented;

    return status::success;
}