Publishing 2019 R1.1 content and Myriad plugin sources (#162)
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_depthwise.cpp
index 9aad4f1..ac49aad 100644 (file)
@@ -740,18 +740,27 @@ void jit_uni_dw_conv_row_f32<isa>::apply_postprocessing(int ur_w, int oc_step) {
 
             for (int ow = 0; ow < ur_w; ow++) {
                 if (is_scalar_store) {
-                    for (int oc = 0; oc < tail_size; oc++) {
-                        int o_off = ow * ow_stride_ + r * (jcp.ch_block / 2) + oc;
+                    if (isa == avx512_common) {
+                        int o_off = ow * ow_stride_;
 
-                        uni_vpxor(vmm_sum, vmm_sum, vmm_sum);
-                        cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], true);
-
-                        if (oc >= jcp.ch_block / 2) {
-                            vperm2i128(Ymm(vmm_sum.getIdx()), Ymm(vmm_sum.getIdx()), Ymm(vmm_sum.getIdx()), 0x01);
-                        }
-                        uni_vpslldq(vmm_sum, vmm_sum, jcp.typesize_out * (oc % (jcp.ch_block / 2)));
+                        Vmm vmm_in = vmm_sum | ktail_mask | T_z;
 
+                        cvt2ps(jcp.dst_dt, vmm_in, ptr[reg_output + o_off * jcp.typesize_out], false);
                         uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_sum);
+                    } else {
+                        for (int oc = 0; oc < tail_size; oc++) {
+                            int o_off = ow * ow_stride_ + r * (jcp.ch_block / 2) + oc;
+
+                            uni_vpxor(vmm_sum, vmm_sum, vmm_sum);
+                            cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], true);
+
+                            if (oc >= jcp.ch_block / 2) {
+                                vperm2i128(Ymm(vmm_sum.getIdx()), Ymm(vmm_sum.getIdx()), Ymm(vmm_sum.getIdx()), 0x01);
+                            }
+                            uni_vpslldq(vmm_sum, vmm_sum, jcp.typesize_out * (oc % (jcp.ch_block / 2)));
+
+                            uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_sum);
+                        }
                     }
                 } else {
                     int o_off = ow * ow_stride_ + r * (jcp.ch_block / 2);
@@ -854,8 +863,15 @@ void jit_uni_dw_conv_row_f32<isa>::store_dst_typed(const Xbyak::Address &op, Vmm
 
 template <cpu_isa_t isa>
 void jit_uni_dw_conv_row_f32<isa>::store_dst(int ur_w, int oc_step) {
+    int nbits = 8;
     int repeats = isa == sse42 && oc_step > (jcp.ch_block / 2) ? 2 : 1;
 
+    if (isa == avx512_common && oc_step != jcp.ch_block) {
+        int mask = (1 << oc_step) - 1;
+        mov(reg_tmp_32, mask);
+        kmovw(ktail_mask, reg_tmp_32);
+    }
+
     for (int i = 0; i < repeats; i++) {
         for (int ow = 0; ow < ur_w; ow++) {
             Vmm vmm_dst = get_acc_reg(i * ur_w + ow);
@@ -872,26 +888,42 @@ void jit_uni_dw_conv_row_f32<isa>::store_dst(int ur_w, int oc_step) {
     }
 
     if (jcp.with_binarization) {
-        int output_step = div_up(ow_stride_, 8);
+        int output_step = div_up(ow_stride_, nbits);
 
         const auto &p = attr_.post_ops_;
         int binarization_idx = p.find(primitive_kind::binarization);
 
+        push(reg_bias);
+
         mov(reg_b_weights, reinterpret_cast<size_t>(p.entry_[binarization_idx].binarization.weights_data));
+        mov(reg_b_out_mask, reinterpret_cast<size_t>(p.entry_[binarization_idx].binarization.output_mask_data));
         add(reg_b_weights, reg_oc_off);
+        add(reg_b_out_mask, reg_oc_off);
 
         for (int ow = 0; ow < ur_w; ow++) {
             for (int i = 0; i < repeats; i++) {
                 int tail_size = isa == sse42 ? nstl::min(jcp.ch_block / 2, oc_step - i * jcp.ch_block / 2) : oc_step;
                 mov(reg_b_mask, (1 << tail_size) - 1);
                 uni_vmovups(vmm_thr, ptr[reg_b_weights + i * (jcp.ch_block / 2) * sizeof(float)]);
+                uni_vmovups(vmm_out_mask, ptr[reg_b_out_mask + i * (jcp.ch_block / 2) * sizeof(float)]);
 
                 Vmm vmm_dst = get_acc_reg(i * ur_w + ow);
 
-                uni_vcmpgtps(vmm_dst, vmm_dst, vmm_thr);
+                if (isa == avx512_common) {
+                    vcmpps(bin_mask0, vmm_dst, vmm_thr, _cmp_gt_os);
+                    vptestmd(bin_mask1, vmm_out_mask, vmm_out_mask);
+                    kxnorw(bin_mask0, bin_mask0, bin_mask1);
+                } else {
+                    uni_vcmpgtps(vmm_dst, vmm_dst, vmm_thr);
+                    uni_vpcmpeqd(vmm_dst, vmm_dst, vmm_out_mask);
+                }
 
                 if (i == 0) {
-                    uni_vmovmskps(reg_tmp_32, vmm_dst);
+                    if (isa == avx512_common) {
+                        kmovw(reg_tmp_32, bin_mask0);
+                    } else {
+                        uni_vmovmskps(reg_tmp_32, vmm_dst);
+                    }
                     and_(reg_tmp_64, reg_b_mask);
                 } else {
                     uni_vmovmskps(reg_tmp2_32, vmm_dst);
@@ -902,10 +934,16 @@ void jit_uni_dw_conv_row_f32<isa>::store_dst(int ur_w, int oc_step) {
 
                 if (i == repeats - 1) {
                     const size_t o_off = ow * output_step;
-                    mov(ptr[reg_output + o_off * jcp.typesize_out], reg_tmp_8);
+                    if (isa == avx512_common && oc_step > nbits) {
+                        mov(ptr[reg_output + o_off * jcp.typesize_out], reg_tmp_16);
+                    } else {
+                        mov(ptr[reg_output + o_off * jcp.typesize_out], reg_tmp_8);
+                    }
                 }
             }
         }
+
+        pop(reg_bias);
     } else {
         for (int i = 0; i < repeats; i++) {
             int tail_size = isa == sse42 ? nstl::min(jcp.ch_block / 2, oc_step - i * jcp.ch_block / 2) : oc_step;
@@ -913,17 +951,24 @@ void jit_uni_dw_conv_row_f32<isa>::store_dst(int ur_w, int oc_step) {
             if (is_scalar_store) {
                 for (int ow = 0; ow < ur_w; ow++) {
                     Vmm vmm_dst = get_acc_reg(i * ur_w + ow);
-                    Ymm ymm_dst = Ymm(vmm_dst.getIdx());
 
-                    for (int oc = 0; oc < tail_size; oc++) {
-                        int o_off = ow * ow_stride_ + i * (jcp.ch_block / 2) + oc;
-                        store_dst_typed(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, true);
+                    if (isa == avx512_common) {
+                        int o_off = ow * ow_stride_;
+
+                        store_dst_typed(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst | ktail_mask, false);
+                    } else {
+                        for (int oc = 0; oc < tail_size; oc++) {
+                            int o_off = ow * ow_stride_ + i * (jcp.ch_block / 2) + oc;
+                            store_dst_typed(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, true);
+
+                            if (isa == sse42) {
+                                psrldq(vmm_dst, jcp.typesize_out);
+                            } else {
+                                Ymm ymm_dst = Ymm(vmm_dst.getIdx());
 
-                        if (isa == sse42) {
-                            psrldq(vmm_dst, jcp.typesize_out);
-                        } else {
-                            vperm2i128(ymm_tmp, ymm_dst, ymm_dst, 0x01);
-                            vpalignr(ymm_dst, vmm_tmp, ymm_dst, jcp.typesize_out);
+                                vperm2i128(ymm_tmp, ymm_dst, ymm_dst, 0x01);
+                                vpalignr(ymm_dst, vmm_tmp, ymm_dst, jcp.typesize_out);
+                            }
                         }
                     }
                 }