for (int ow = 0; ow < ur_w; ow++) {
if (is_scalar_store) {
- for (int oc = 0; oc < tail_size; oc++) {
- int o_off = ow * ow_stride_ + r * (jcp.ch_block / 2) + oc;
+ if (isa == avx512_common) {
+ int o_off = ow * ow_stride_;
- uni_vpxor(vmm_sum, vmm_sum, vmm_sum);
- cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], true);
-
- if (oc >= jcp.ch_block / 2) {
- vperm2i128(Ymm(vmm_sum.getIdx()), Ymm(vmm_sum.getIdx()), Ymm(vmm_sum.getIdx()), 0x01);
- }
- uni_vpslldq(vmm_sum, vmm_sum, jcp.typesize_out * (oc % (jcp.ch_block / 2)));
+ Vmm vmm_in = vmm_sum | ktail_mask | T_z;
+ cvt2ps(jcp.dst_dt, vmm_in, ptr[reg_output + o_off * jcp.typesize_out], false);
uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_sum);
+ } else {
+ for (int oc = 0; oc < tail_size; oc++) {
+ int o_off = ow * ow_stride_ + r * (jcp.ch_block / 2) + oc;
+
+ uni_vpxor(vmm_sum, vmm_sum, vmm_sum);
+ cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], true);
+
+ if (oc >= jcp.ch_block / 2) {
+ vperm2i128(Ymm(vmm_sum.getIdx()), Ymm(vmm_sum.getIdx()), Ymm(vmm_sum.getIdx()), 0x01);
+ }
+ uni_vpslldq(vmm_sum, vmm_sum, jcp.typesize_out * (oc % (jcp.ch_block / 2)));
+
+ uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_sum);
+ }
}
} else {
int o_off = ow * ow_stride_ + r * (jcp.ch_block / 2);
template <cpu_isa_t isa>
void jit_uni_dw_conv_row_f32<isa>::store_dst(int ur_w, int oc_step) {
+ int nbits = 8;
int repeats = isa == sse42 && oc_step > (jcp.ch_block / 2) ? 2 : 1;
+ if (isa == avx512_common && oc_step != jcp.ch_block) {
+ int mask = (1 << oc_step) - 1;
+ mov(reg_tmp_32, mask);
+ kmovw(ktail_mask, reg_tmp_32);
+ }
+
for (int i = 0; i < repeats; i++) {
for (int ow = 0; ow < ur_w; ow++) {
Vmm vmm_dst = get_acc_reg(i * ur_w + ow);
}
if (jcp.with_binarization) {
- int output_step = div_up(ow_stride_, 8);
+ int output_step = div_up(ow_stride_, nbits);
const auto &p = attr_.post_ops_;
int binarization_idx = p.find(primitive_kind::binarization);
+ push(reg_bias);
+
mov(reg_b_weights, reinterpret_cast<size_t>(p.entry_[binarization_idx].binarization.weights_data));
+ mov(reg_b_out_mask, reinterpret_cast<size_t>(p.entry_[binarization_idx].binarization.output_mask_data));
add(reg_b_weights, reg_oc_off);
+ add(reg_b_out_mask, reg_oc_off);
for (int ow = 0; ow < ur_w; ow++) {
for (int i = 0; i < repeats; i++) {
int tail_size = isa == sse42 ? nstl::min(jcp.ch_block / 2, oc_step - i * jcp.ch_block / 2) : oc_step;
mov(reg_b_mask, (1 << tail_size) - 1);
uni_vmovups(vmm_thr, ptr[reg_b_weights + i * (jcp.ch_block / 2) * sizeof(float)]);
+ uni_vmovups(vmm_out_mask, ptr[reg_b_out_mask + i * (jcp.ch_block / 2) * sizeof(float)]);
Vmm vmm_dst = get_acc_reg(i * ur_w + ow);
- uni_vcmpgtps(vmm_dst, vmm_dst, vmm_thr);
+ if (isa == avx512_common) {
+ vcmpps(bin_mask0, vmm_dst, vmm_thr, _cmp_gt_os);
+ vptestmd(bin_mask1, vmm_out_mask, vmm_out_mask);
+ kxnorw(bin_mask0, bin_mask0, bin_mask1);
+ } else {
+ uni_vcmpgtps(vmm_dst, vmm_dst, vmm_thr);
+ uni_vpcmpeqd(vmm_dst, vmm_dst, vmm_out_mask);
+ }
if (i == 0) {
- uni_vmovmskps(reg_tmp_32, vmm_dst);
+ if (isa == avx512_common) {
+ kmovw(reg_tmp_32, bin_mask0);
+ } else {
+ uni_vmovmskps(reg_tmp_32, vmm_dst);
+ }
and_(reg_tmp_64, reg_b_mask);
} else {
uni_vmovmskps(reg_tmp2_32, vmm_dst);
if (i == repeats - 1) {
const size_t o_off = ow * output_step;
- mov(ptr[reg_output + o_off * jcp.typesize_out], reg_tmp_8);
+ if (isa == avx512_common && oc_step > nbits) {
+ mov(ptr[reg_output + o_off * jcp.typesize_out], reg_tmp_16);
+ } else {
+ mov(ptr[reg_output + o_off * jcp.typesize_out], reg_tmp_8);
+ }
}
}
}
+
+ pop(reg_bias);
} else {
for (int i = 0; i < repeats; i++) {
int tail_size = isa == sse42 ? nstl::min(jcp.ch_block / 2, oc_step - i * jcp.ch_block / 2) : oc_step;
if (is_scalar_store) {
for (int ow = 0; ow < ur_w; ow++) {
Vmm vmm_dst = get_acc_reg(i * ur_w + ow);
- Ymm ymm_dst = Ymm(vmm_dst.getIdx());
- for (int oc = 0; oc < tail_size; oc++) {
- int o_off = ow * ow_stride_ + i * (jcp.ch_block / 2) + oc;
- store_dst_typed(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, true);
+ if (isa == avx512_common) {
+ int o_off = ow * ow_stride_;
+
+ store_dst_typed(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst | ktail_mask, false);
+ } else {
+ for (int oc = 0; oc < tail_size; oc++) {
+ int o_off = ow * ow_stride_ + i * (jcp.ch_block / 2) + oc;
+ store_dst_typed(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, true);
+
+ if (isa == sse42) {
+ psrldq(vmm_dst, jcp.typesize_out);
+ } else {
+ Ymm ymm_dst = Ymm(vmm_dst.getIdx());
- if (isa == sse42) {
- psrldq(vmm_dst, jcp.typesize_out);
- } else {
- vperm2i128(ymm_tmp, ymm_dst, ymm_dst, 0x01);
- vpalignr(ymm_dst, vmm_tmp, ymm_dst, jcp.typesize_out);
+ vperm2i128(ymm_tmp, ymm_dst, ymm_dst, 0x01);
+ vpalignr(ymm_dst, vmm_tmp, ymm_dst, jcp.typesize_out);
+ }
}
}
}