/*******************************************************************************
-* Copyright 2018 Intel Corporation
+* Copyright 2018-2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
template <cpu_isa_t isa>
// Returns how many auxiliary vector registers the injector needs for the
// given depthwise algorithm.
//  - scale_shift: sse42 needs 1 scratch vec to stage the unaligned memory
//    operand (mulps/addps require aligned memory operands); avx+ needs none.
//  - prelu: 2 (mask + scaled copy).
int jit_uni_depthwise_injector_f32<isa>::aux_vecs_count(alg_kind_t depthwise_alg) {
switch (depthwise_alg) {
case alg_kind::depthwise_scale_shift: return isa == sse42 ? 1 : 0;
case alg_kind::depthwise_prelu: return 2;
default: assert(!"unsupported depthwise algorithm");
}
// Fix: the switch previously fell off the end of the function after the
// default branch (UB in release builds where assert is a no-op).
return 0;
template <cpu_isa_t isa>
// Emits y = x * weights + bias element-wise on one vector register.
// p_weights / p_bias hold the per-channel scale and shift pointers.
void jit_uni_depthwise_injector_f32<isa>::scale_shift_compute_vector(const Vmm &vmm_src,
const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
if (isa == sse42) {
// SSE mulps/addps with a memory operand require 16-byte alignment, which
// these pointers do not guarantee (presumably the reason for staging) —
// load through vmm_mask with the unaligned movups first.
h->movups(vmm_mask, h->ptr[p_weights]);
h->mulps(vmm_src, vmm_mask);
h->movups(vmm_mask, h->ptr[p_bias]);
h->addps(vmm_src, vmm_mask);
} else {
// AVX forms accept unaligned memory operands directly.
h->uni_vmulps(vmm_src, vmm_src, h->ptr[p_weights]);
h->uni_vaddps(vmm_src, vmm_src, h->ptr[p_bias]);
} // fix: dropped stray ';' that followed this brace
}
// NOTE(review): this span fuses fragments of several diff hunks; the enclosing
// function signatures are not visible, so only annotations are added here.
template <cpu_isa_t isa>
// Fragment of a prelu-style compute: mask = (0 > src), aux0 = src * weights,
// then blend the scaled value into src where the mask selects.
// The '+' lines load weights first and multiply by the register — presumably
// to avoid mulps's aligned-memory-operand requirement on sse42 (TODO confirm).
if (isa == sse42) {
h->pxor(vmm_mask, vmm_mask);
h->cmpps(vmm_mask, vmm_src, _cmp_gt_os);
- h->movups(vmm_aux0, vmm_src);
- h->mulps(vmm_aux0, h->ptr[p_weights]);
+ h->movups(vmm_aux0, h->ptr[p_weights]);
+ h->mulps(vmm_aux0, vmm_src);
// blendvps uses xmm0 as the implicit selector — presumably vmm_mask is
// xmm0 here; verify against the register assignments.
h->blendvps(vmm_src, vmm_aux0);
} else if (isa == avx2) {
h->vxorps(vmm_mask, vmm_mask, vmm_mask);
// Fragment from a different hunk (kernel ctor): sanity checks plus flat
// (nchw-in, nchw-out) layout detection; the '+' line only trims whitespace.
assert(desc.alg_kind == alg_kind::depthwise_scale_shift);
assert(isa == sse42 || isa == avx2 || isa == avx512_common);
- bool isFlat = desc.src_desc.format == nchw && desc.dst_desc.format == nchw ;
+ bool isFlat = desc.src_desc.format == nchw && desc.dst_desc.format == nchw;
Reg64 param = abi_param1;
}
// Review note: diff hunk ('-' old / '+' new) migrating the primitive from a
// locally stored conf_ copy to the base-class pd() accessor; the constructor
// body continues past this view (closing braces elided).
template <cpu_isa_t isa>
-jit_uni_depthwise_fwd_t<isa>::jit_uni_depthwise_fwd_t(const pd_t *pd,
+jit_uni_depthwise_fwd_t<isa>::jit_uni_depthwise_fwd_t(const pd_t *apd,
const input_vector &inputs, const output_vector &outputs)
- : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), kernel_(nullptr),
+ : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr),
padded_weights_(nullptr), padded_bias_(nullptr) {
-    const auto &desc = *conf_.desc();
+    const auto &desc = *pd()->desc();
// Pick the JIT kernel matching the requested depthwise algorithm.
switch (desc.alg_kind) {
case alg_kind::depthwise_scale_shift:
-            kernel_ = new jit_uni_scale_shift_kernel_f32<isa>(desc, pd->with_bias()); break;
+            kernel_ = new jit_uni_scale_shift_kernel_f32<isa>(desc, pd()->with_bias()); break;
case alg_kind::depthwise_prelu:
-            kernel_ = new jit_uni_prelu_kernel_f32<isa>(desc, pd->with_bias()); break;
+            kernel_ = new jit_uni_prelu_kernel_f32<isa>(desc, pd()->with_bias()); break;
default: assert(!"unknown depthwise alg_kind");
}
// Zero-pad per-channel weights (and bias) up to a multiple of simd_w so the
// kernel can always read full vectors; tail channels see zeros.
const int simd_w = isa == avx512_common ? 16 : 8;
-    const memory_desc_wrapper data_d(conf_.src_pd());
+    const memory_desc_wrapper data_d(pd()->src_pd());
const int c_without_padding = data_d.dims()[1];
const int c_padded = rnd_up(c_without_padding, simd_w);
-    if (conf_.want_padded_weights()) {
+    if (pd()->want_padded_weights()) {
// NOTE(review): two-argument malloc — presumably the project's aligned
// allocator (size, alignment); confirm it is paired with the matching free.
padded_weights_ = (data_t *)malloc(sizeof(data_t) * c_padded, 64);
for (int oc = c_without_padding; oc < c_padded; ++oc)
padded_weights_[oc] = 0;
-        if (conf_.with_bias()) {
+        if (pd()->with_bias()) {
padded_bias_ = (data_t *)malloc(sizeof(data_t) * c_padded, 64);
for (int oc = c_without_padding; oc < c_padded; ++oc)
padded_bias_[oc] = 0;
}
// Diff hunk: execute_forward becomes const and reads descriptors through pd()
// instead of the removed conf_ member ('-' old / '+' new).
template <cpu_isa_t isa>
-void jit_uni_depthwise_fwd_t<isa>::execute_forward() {
+void jit_uni_depthwise_fwd_t<isa>::execute_forward() const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
auto dst = reinterpret_cast<data_t *>(this->memory());
-    const memory_desc_wrapper data_d(conf_.src_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
-    const memory_desc_wrapper bias_d(conf_.weights_pd(1));
+    const memory_desc_wrapper data_d(pd()->src_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+    const memory_desc_wrapper bias_d(pd()->weights_pd(1));
const int N = data_d.dims()[0];
const int C = data_d.dims()[1];
// nchw ("flat") layout is processed one channel at a time; blocked layouts
// advance by whole simd blocks.
const int ch_block_size = data_d.format() == nchw ? 1 : simd_w;
const int CB = div_up(C, ch_block_size);
// When C is not simd-aligned, copy weights/bias into the zero-padded
// buffers prepared in the constructor and use those instead.
-    if (conf_.want_padded_weights()) {
+    if (pd()->want_padded_weights()) {
for (int oc = 0; oc < C; ++oc)
padded_weights_[oc] = weights[oc];
weights = padded_weights_;
-        if (conf_.with_bias()) {
+        if (pd()->with_bias()) {
for (int oc = 0; oc < C; ++oc)
padded_bias_[oc] = bias[oc];
bias = padded_bias_;
// One kernel invocation per (batch, channel-block, row).
parallel_nd(N, CB, H,
[&](int n, int cb, int h) {
-        jit_args arg = {};
+        auto arg = jit_args();
arg.from = &src[data_d.blk_off(n, cb, h)];
arg.to = &dst[data_d.blk_off(n, cb, h)];
// NOTE(review): the lines below belong to a different hunk (a JIT kernel's
// accumulator-init loop; its enclosing function is not visible here). The
// '-' lines move bias/sum handling out of init — presumably into
// apply_postprocessing — leaving only the zeroing of the accumulators.
for (int ow = 0; ow < ur_w; ow++) {
Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
- if (this->jcp.with_bias)
- uni_vmovups(vmm_acc, vmmword[reg_bias + i*4*sizeof(float)]);
- else
- uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
-
- int o_off = ow*jcp.ch_block + i*4;
- if (this->jcp.with_sum)
- uni_vaddps(vmm_acc, vmm_acc,
- vmmword[reg_output + o_off*sizeof(float)]);
+ uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
}
}
}
// Accumulates a kw_size-wide row of the depthwise filter into the accumulator
// registers for up to three input rows (aux_reg_input0..2); reg_kh is compared
// against 2 and 3 to skip rows clipped by vertical padding.
// Diff hunk: '+' lines add u8 support — byte loads are zero-extended (src) or
// sign-extended (kernel) to dwords and multiply-accumulated as integers —
// behind the load_src/load_ker/compute lambdas; address arithmetic switches
// from sizeof(float) to jcp.typesize_in.
template <cpu_isa_t isa>
void jit_uni_dw_conv_row_f32<isa>::apply_filter(int ur_w, int kw_size) {
+ // Load one vector of input: u8 -> zero-extend bytes to dwords, else f32.
+ auto load_src = [=](Vmm vmm_src, const Xbyak::Address &op) {
+ if (jcp.src_dt == data_type::u8) {
+ uni_vpmovzxbd(vmm_src, op);
+ } else {
+ uni_vmovups(vmm_src, op);
+ }
+ };
+
+ // Load one vector of weights: s8 sign-extended on the u8 path, else f32.
+ auto load_ker = [=](Vmm vmm_ker, const Xbyak::Address &op) {
+ if (jcp.src_dt == data_type::u8) {
+ uni_vpmovsxbd(vmm_ker, op);
+ } else {
+ uni_vmovups(vmm_ker, op);
+ }
+ };
+
+ // acc += src * ker: integer mul/add on the u8 path (clobbers vmm_src),
+ // fused multiply-add on the f32 path.
+ auto compute = [=](Vmm vmm_acc, Vmm vmm_src, Vmm vmm_ker) {
+ if (jcp.src_dt == data_type::u8) {
+ uni_vpmulld(vmm_src, vmm_src, vmm_ker);
+ uni_vpaddd(vmm_acc, vmm_acc, vmm_src);
+ } else {
+ uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
+ }
+ };
+
int ch_blk = jcp.ch_block;
int stride_w = jcp.stride_w;
// NOTE(review): `repeats` and `exit_label` are defined outside this hunk —
// presumably repeats == 2 on sse42, processing the channel block in halves
// (hence the i*(jcp.ch_block / 2) half-block offsets). TODO confirm.
jl(exit_label, T_NEAR);
// Row 0.
for (int i = 0; i < repeats; i++) {
for (int kw = 0; kw < kw_size; kw++) {
- int ker_off = kw * ch_blk + i*4;
+ int ker_off = kw * ch_blk + i*(jcp.ch_block / 2);
Vmm vmm_ker = get_ker_reg(0);
- uni_vmovups(vmm_ker, ptr[aux_reg_kernel
- + ker_off * sizeof(float)]);
+ load_ker(vmm_ker, ptr[aux_reg_kernel + ker_off * jcp.typesize_in]);
for (int ow = 0; ow < ur_w; ow++) {
- int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*4;
+ int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*(jcp.ch_block / 2);
Vmm vmm_src = get_src_reg(0);
- uni_vmovups(vmm_src, ptr[aux_reg_input0
- + inp_off * sizeof(float)]);
+ load_src(vmm_src, ptr[aux_reg_input0 + inp_off * jcp.typesize_in]);
Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
- uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
+ compute(vmm_acc, vmm_src, vmm_ker);
}
}
}
// Advance to the next filter row; skip row 1 if kh padding leaves < 2 rows.
- add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float));
+ add(aux_reg_kernel, jcp.kw*ch_blk*jcp.typesize_in);
cmp(reg_kh, 2);
jl(exit_label, T_NEAR);
// Row 1 (same pattern, reading aux_reg_input1).
for (int i = 0; i < repeats; i++) {
for (int kw = 0; kw < kw_size; kw++) {
- int ker_off = kw * ch_blk + i*4;
+ int ker_off = kw * ch_blk + i*(jcp.ch_block / 2);
Vmm vmm_ker = get_ker_reg(0);
- uni_vmovups(vmm_ker, ptr[aux_reg_kernel
- + ker_off * sizeof(float)]);
+ load_ker(vmm_ker, ptr[aux_reg_kernel + ker_off * jcp.typesize_in]);
for (int ow = 0; ow < ur_w; ow++) {
- int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*4;
+ int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*(jcp.ch_block / 2);
Vmm vmm_src = get_src_reg(0);
- uni_vmovups(vmm_src, ptr[aux_reg_input1
- + inp_off * sizeof(float)]);
+ load_src(vmm_src, ptr[aux_reg_input1 + inp_off * jcp.typesize_in]);
Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
- uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
+ compute(vmm_acc, vmm_src, vmm_ker);
}
}
}
// Skip row 2 if kh padding leaves < 3 rows.
- add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float));
+ add(aux_reg_kernel, jcp.kw*ch_blk*jcp.typesize_in);
cmp(reg_kh, 3);
jl(exit_label, T_NEAR);
// Row 2 (same pattern, reading aux_reg_input2).
for (int i = 0; i < repeats; i++) {
for (int kw = 0; kw < kw_size; kw++) {
- int ker_off = kw * ch_blk + i*4;
+ int ker_off = kw * ch_blk + i*(jcp.ch_block / 2);
Vmm vmm_ker = get_ker_reg(0);
- uni_vmovups(vmm_ker, ptr[aux_reg_kernel
- + ker_off * sizeof(float)]);
+ load_ker(vmm_ker, ptr[aux_reg_kernel + ker_off * jcp.typesize_in]);
for (int ow = 0; ow < ur_w; ow++) {
- int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*4;
+ int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*(jcp.ch_block / 2);
Vmm vmm_src = get_src_reg(0);
- uni_vmovups(vmm_src, ptr[aux_reg_input2
- + inp_off * sizeof(float)]);
+ load_src(vmm_src, ptr[aux_reg_input2 + inp_off * jcp.typesize_in]);
Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
- uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
+ compute(vmm_acc, vmm_src, vmm_ker);
}
}
}
}
// Diff hunk: apply_activation ('-' lines; replaced by the injector-based
// apply_postprocessing elsewhere) gives way to cvt2ps, which loads `op` into
// vmm_in and converts it to f32 lanes.
template <cpu_isa_t isa>
-void jit_uni_dw_conv_row_f32<isa>::apply_activation(int ur_w) {
- if (this->jcp.with_eltwise) {
- int repeats = isa == sse42 ? 2 : 1;
- eltwise_injector->compute_vector_range(4, repeats * ur_w + 4);
+void jit_uni_dw_conv_row_f32<isa>::cvt2ps(data_type_t type_in, Vmm vmm_in, const Operand &op, bool scalar_load) {
+ Xmm xmm_in = Xmm(vmm_in.getIdx());
+
+ switch (type_in) {
+ case data_type::f32:
+ case data_type::s32:
+ if (scalar_load) {
+ // Scalar path: pull one element through a GPR into the low lanes;
+ // a mov into the 32-bit reg zero-extends the full 64-bit register.
+ mov(reg_tmp_32, op);
+ movq(xmm_in, reg_tmp_64);
+ } else {
+ uni_vmovups(vmm_in, op);
+ }
+ break;
+ case data_type::s8:
+ if (scalar_load) {
+ // Sign-extend a single s8 element.
+ movsx(reg_tmp_32, op);
+ movq(xmm_in, reg_tmp_64);
+ } else {
+ uni_vpmovsxbd(vmm_in, op);
+ }
+ break;
+ case data_type::u8:
+ if (scalar_load) {
+ // Zero-extend a single u8 element.
+ movzx(reg_tmp_32, op);
+ movq(xmm_in, reg_tmp_64);
+ } else {
+ uni_vpmovzxbd(vmm_in, op);
+ }
+ break;
+ default: assert(!"unsupported data type");
}
+
+ // Non-f32 inputs are integer dwords at this point; convert lanes to f32.
+ if (type_in != data_type::f32)
+ uni_vcvtdq2ps(vmm_in, vmm_in);
}
// Diff hunk: the old store_dst signature ('-') is superseded; the new
// apply_postprocessing converts u8-path integer accumulators to f32, adds the
// per-channel bias, folds in the optional `sum` post-op (accumulating dst),
// then runs the eltwise/depthwise injectors for post-ops after the fused conv.
template <cpu_isa_t isa>
-void jit_uni_dw_conv_row_f32<isa>::store_dst(int ur_w) {
+void jit_uni_dw_conv_row_f32<isa>::apply_postprocessing(int ur_w, int oc_step) {
// sse42 processes the channel block in two half-block passes.
int repeats = isa == sse42 ? 2 : 1;
+
+ for (int r = 0; r < repeats; r++) {
+ for (int ow = 0; ow < ur_w; ow++) {
+ // u8 path accumulated in s32; convert to f32 before bias/post-ops.
+ if (jcp.src_dt == data_type::u8) {
+ uni_vcvtdq2ps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow));
+ }
+
+ if (jcp.with_bias) {
+ int b_off = r * (jcp.ch_block / 2);
+ cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias + b_off * jcp.typesize_bia], false);
+ uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_bias);
+ }
+ }
+ }
+
+ if (jcp.with_sum) {
+ for (int r = 0; r < repeats; r++) {
+ int tail_size = isa == sse42 ? nstl::min(jcp.ch_block / 2, oc_step - r * jcp.ch_block / 2) : oc_step;
+ bool is_scalar_store = isa == sse42 ? tail_size < jcp.ch_block / 2 : tail_size < jcp.ch_block;
+
+ for (int ow = 0; ow < ur_w; ow++) {
+ if (is_scalar_store) {
+ // Channel tail: load dst elements one at a time so lanes past
+ // tail_size are never read, then byte-shift each loaded scalar
+ // into its destination lane before adding.
+ for (int oc = 0; oc < tail_size; oc++) {
+ int o_off = ow * ow_stride_ + r * (jcp.ch_block / 2) + oc;
+
+ uni_vpxor(vmm_sum, vmm_sum, vmm_sum);
+ cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], true);
+
+ // Elements beyond the low 128-bit lane must first be swapped
+ // into the upper lane (vpslldq shifts within 128-bit lanes).
+ if (oc >= jcp.ch_block / 2) {
+ vperm2i128(Ymm(vmm_sum.getIdx()), Ymm(vmm_sum.getIdx()), Ymm(vmm_sum.getIdx()), 0x01);
+ }
+ uni_vpslldq(vmm_sum, vmm_sum, jcp.typesize_out * (oc % (jcp.ch_block / 2)));
+
+ uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_sum);
+ }
+ } else {
+ int o_off = ow * ow_stride_ + r * (jcp.ch_block / 2);
+
+ uni_vpxor(vmm_sum, vmm_sum, vmm_sum);
+ cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], false);
+
+ uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_sum);
+ }
+ }
+ }
+ }
+
+ // Post-ops before (and including) the fused dw-conv entry belong to the
+ // parent convolution; only the tail after it is applied here.
+ const auto &p = attr_.post_ops_;
+ int eltwise_inj_idx = 0;
+ int depthwise_inj_idx = 0;
+ int start_idx = p.find(primitive_kind::convolution) + 1;
+ for (int i = start_idx; i < p.len_; i++) {
+ auto& post_op = p.entry_[i];
+ if (post_op.is_eltwise()) {
+ // Eltwise covers regs 4..4+repeats*ur_w in one call; depthwise below
+ // handles the two sse42 half-blocks separately because each half
+ // needs its own weights/bias pointer.
+ eltwise_injectors[eltwise_inj_idx]->compute_vector_range(4, 4 + repeats * ur_w);
+ eltwise_inj_idx++;
+ } else if (post_op.is_depthwise()) {
+ mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
+ mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
+
+ add(reg_d_weights, reg_oc_off);
+ add(reg_d_bias, reg_oc_off);
+
+ depthwise_injectors[depthwise_inj_idx]->compute_vector_range(4, 4 + ur_w, reg_d_weights, reg_d_bias);
+
+ if (repeats == 2) {
+ add(reg_d_weights, (jcp.ch_block / 2) * sizeof(float));
+ add(reg_d_bias, (jcp.ch_block / 2) * sizeof(float));
+
+ depthwise_injectors[depthwise_inj_idx]->compute_vector_range(4 + ur_w, 4 + 2 * ur_w, reg_d_weights, reg_d_bias);
+ }
+
+ depthwise_inj_idx++;
+ }
+ }
+}
+
+// Stores one vector (or, with scalar_store, the low element) of vmm_dst to
+// `op` in jcp.dst_dt format. For s8/u8 the lanes are expected to already be
+// integers (the caller converts with vcvtps2dq — see store_dst); they are
+// packed dword->word->byte before the store.
+template <cpu_isa_t isa>
+void jit_uni_dw_conv_row_f32<isa>::store_dst_typed(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store) {
+ Ymm ymm_dst = Ymm(vmm_dst.getIdx());
+ Xmm xmm_dst = Xmm(vmm_dst.getIdx());
+
+ switch (jcp.dst_dt) {
+ case data_type::f32:
+ case data_type::s32:
+ if (scalar_store) {
+ // Scalar path mirrors cvt2ps: move the low lane through a GPR and
+ // store 4 bytes.
+ movq(reg_tmp_64, xmm_dst);
+ mov(op, reg_tmp_32);
+ } else {
+ uni_vmovups(op, vmm_dst);
+ }
+ break;
+ case data_type::s8:
+ // Saturating pack s32 -> s16.
+ uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst);
+
+ // On AVX the pack works per 128-bit lane; gather both lanes' low
+ // 64-bit halves into the bottom of the register before the byte pack.
+ if (isa != sse42 && !scalar_store)
+ vpermq(ymm_dst, ymm_dst, 0x08);
+
+ // Saturating pack s16 -> s8.
+ uni_vpacksswb(vmm_dst, vmm_dst, vmm_dst);
+
+ if (scalar_store) {
+ movq(reg_tmp_64, xmm_dst);
+ mov(op, reg_tmp_8);
+ } else {
+ if (isa != sse42)
+ vmovq(op, xmm_dst);
+ else
+ movd(op, xmm_dst);
+ }
+ break;
+ case data_type::u8:
+ case data_type::bin:
+ // Same structure with unsigned saturating packs.
+ uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst);
+
+ if (isa != sse42 && !scalar_store)
+ vpermq(ymm_dst, ymm_dst, 0x08);
+
+ uni_vpackuswb(vmm_dst, vmm_dst, vmm_dst);
+
+ if (scalar_store) {
+ movq(reg_tmp_64, xmm_dst);
+ mov(op, reg_tmp_8);
+ } else {
+ if (isa != sse42)
+ vmovq(op, xmm_dst);
+ else
+ movd(op, xmm_dst);
+ }
+ break;
+ default:
+ assert(!"unknown dst_dt");
+ }
+}
+
+// Converts the accumulators to the destination type and writes them out.
+// Three paths: binarization (threshold compare packed into a bitmask byte),
+// scalar tail stores (element by element so no out-of-range channel is
+// touched), and full-vector stores via store_dst_typed. Integer destinations
+// are rounded per attr_.round_mode_ (nearest or down).
+template <cpu_isa_t isa>
+void jit_uni_dw_conv_row_f32<isa>::store_dst(int ur_w, int oc_step) {
+ int repeats = isa == sse42 && oc_step > (jcp.ch_block / 2) ? 2 : 1;
+
+ // f32 -> s32 conversion for the integer destination types.
for (int i = 0; i < repeats; i++) {
for (int ow = 0; ow < ur_w; ow++) {
- int o_off = ow*jcp.ch_block + i*4;
- Vmm vmm_dst = get_acc_reg(i*ur_w + ow);
+ Vmm vmm_dst = get_acc_reg(i * ur_w + ow);
+ if (jcp.dst_dt != data_type::f32 && jcp.dst_dt != data_type::bin) {
+ if (attr_.round_mode_ == round_mode::nearest)
+ uni_vcvtps2dq(vmm_dst, vmm_dst);
+ else if (attr_.round_mode_ == round_mode::down) {
+ // roundps imm 1 = round toward -inf, then convert.
+ uni_vroundps(vmm_dst, vmm_dst, 1);
+ uni_vcvtps2dq(vmm_dst, vmm_dst);
+ } else
+ assert(!"unimplemented");
+ }
+ }
+ }
+
+ if (jcp.with_binarization) {
+ // Each output byte packs up to 8 channel bits; ow_stride_ is in bits.
+ int output_step = div_up(ow_stride_, 8);
+
+ const auto &p = attr_.post_ops_;
+ int binarization_idx = p.find(primitive_kind::binarization);
+
+ mov(reg_b_weights, reinterpret_cast<size_t>(p.entry_[binarization_idx].binarization.weights_data));
+ add(reg_b_weights, reg_oc_off);
+
+ for (int ow = 0; ow < ur_w; ow++) {
+ for (int i = 0; i < repeats; i++) {
+ int tail_size = isa == sse42 ? nstl::min(jcp.ch_block / 2, oc_step - i * jcp.ch_block / 2) : oc_step;
+ mov(reg_b_mask, (1 << tail_size) - 1);
+ uni_vmovups(vmm_thr, ptr[reg_b_weights + i * (jcp.ch_block / 2) * sizeof(float)]);
+
+ Vmm vmm_dst = get_acc_reg(i * ur_w + ow);
+
+ // Per-channel threshold compare; vmovmskps then collapses the
+ // lane sign bits into a GPR bitmask.
+ uni_vcmpgtps(vmm_dst, vmm_dst, vmm_thr);
+
+ if (i == 0) {
+ uni_vmovmskps(reg_tmp_32, vmm_dst);
+ and_(reg_tmp_64, reg_b_mask);
+ } else {
+ // Second half-block: mask, shift into bits 4..7, OR into the
+ // low half's bits before the single-byte store.
+ uni_vmovmskps(reg_tmp2_32, vmm_dst);
+ and_(reg_tmp2_64, reg_b_mask);
+ shl(reg_tmp2_32, 4);
+ or_(reg_tmp_32, reg_tmp2_32);
+ }
+
+ if (i == repeats - 1) {
+ const size_t o_off = ow * output_step;
+ mov(ptr[reg_output + o_off * jcp.typesize_out], reg_tmp_8);
+ }
+ }
+ }
+ } else {
+ for (int i = 0; i < repeats; i++) {
+ int tail_size = isa == sse42 ? nstl::min(jcp.ch_block / 2, oc_step - i * jcp.ch_block / 2) : oc_step;
+ bool is_scalar_store = isa == sse42 ? tail_size < jcp.ch_block / 2 : tail_size < jcp.ch_block;
+ if (is_scalar_store) {
+ for (int ow = 0; ow < ur_w; ow++) {
+ Vmm vmm_dst = get_acc_reg(i * ur_w + ow);
+ Ymm ymm_dst = Ymm(vmm_dst.getIdx());
+
+ for (int oc = 0; oc < tail_size; oc++) {
+ int o_off = ow * ow_stride_ + i * (jcp.ch_block / 2) + oc;
+ store_dst_typed(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, true);
+
+ // Rotate the next element down into lane 0 for the following
+ // scalar store (cross-lane via vperm2i128+vpalignr on AVX).
+ if (isa == sse42) {
+ psrldq(vmm_dst, jcp.typesize_out);
+ } else {
+ vperm2i128(ymm_tmp, ymm_dst, ymm_dst, 0x01);
+ vpalignr(ymm_dst, vmm_tmp, ymm_dst, jcp.typesize_out);
+ }
+ }
+ }
+ } else {
+ for (int ow = 0; ow < ur_w; ow++) {
+ int o_off = ow * ow_stride_ + i * (jcp.ch_block / 2);
+ Vmm vmm_dst = get_acc_reg(i * ur_w + ow);
- uni_vmovups(vmmword[reg_output + o_off*sizeof(float)], vmm_dst);
+ store_dst_typed(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, false);
+ }
+ }
+ }
}
}
// Review note: the hunks below elide code between sections (several label
// bodies, e.g. L(unrolled_w_label)/L(tail_w_label), are only partially
// visible), so only annotations are added. The diff threads oc_step through
// to apply_postprocessing/store_dst and replaces sizeof(float) address math
// with jcp.typesize_in/typesize_out; reg_output now advances by output_step
// (bit-packed bytes when binarization is fused).
template <cpu_isa_t isa>
-void jit_uni_dw_conv_row_f32<isa>::loop_body() {
+void jit_uni_dw_conv_row_f32<isa>::loop_body(int oc_step) {
Label left_pad_label;
Label right_pad_label;
Label unrolled_w_label;
Label tail_w_label;
Label exit_label;
+ int output_step = jcp.with_binarization ? div_up(ow_stride_, 8) : ow_stride_;
+
// Left-padding column: one output, filter narrowed by the clipped column(s),
// kernel pointer advanced past the clipped first column.
L(left_pad_label); {
int ur_w = 1;
int kw = jcp.iw == 1 ? jcp.kw - 2 : jcp.kw - 1;
mov(aux_reg_input1, reg_input1);
mov(aux_reg_input2, reg_input2);
mov(aux_reg_kernel, reg_kernel);
- add(aux_reg_kernel, jcp.ch_block*sizeof(float));
+ add(aux_reg_kernel, jcp.ch_block*jcp.typesize_in);
load_src(ur_w);
apply_filter(ur_w, kw);
- apply_activation(ur_w);
- store_dst(ur_w);
+ apply_postprocessing(ur_w, oc_step);
+ store_dst(ur_w, oc_step);
- add(reg_input0, sizeof(float) * ur_w * jcp.ch_block * (jcp.stride_w-1));
- add(reg_input1, sizeof(float) * ur_w * jcp.ch_block * (jcp.stride_w-1));
- add(reg_input2, sizeof(float) * ur_w * jcp.ch_block * (jcp.stride_w-1));
-
- add(reg_output, sizeof(float) * ur_w * jcp.ch_block);
+ add(reg_input0, jcp.typesize_in * ur_w * jcp.ch_block * (jcp.stride_w-1));
+ add(reg_input1, jcp.typesize_in * ur_w * jcp.ch_block * (jcp.stride_w-1));
+ add(reg_input2, jcp.typesize_in * ur_w * jcp.ch_block * (jcp.stride_w-1));
+ add(reg_output, jcp.typesize_out * ur_w * output_step);
sub(reg_ur_w, ur_w);
}
// Unrolled main section (label body partially elided above this hunk).
load_src(ur_w);
apply_filter(ur_w, kw);
- apply_activation(ur_w);
- store_dst(ur_w);
+ apply_postprocessing(ur_w, oc_step);
+ store_dst(ur_w, oc_step);
- add(reg_input0, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
- add(reg_input1, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
- add(reg_input2, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
- add(reg_output, sizeof(float) * ur_w * jcp.ch_block);
+ add(reg_input0, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
+ add(reg_input1, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
+ add(reg_input2, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
+ add(reg_output, jcp.typesize_out * ur_w * output_step);
sub(reg_ur_w, ur_w);
jmp(unrolled_w_label, T_NEAR);
// Single-output tail section.
load_src(ur_w);
apply_filter(ur_w, kw);
- apply_activation(ur_w);
- store_dst(ur_w);
+ apply_postprocessing(ur_w, oc_step);
+ store_dst(ur_w, oc_step);
- add(reg_input0, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
- add(reg_input1, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
- add(reg_input2, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
- add(reg_output, sizeof(float) * ur_w * jcp.ch_block);
+ add(reg_input0, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
+ add(reg_input1, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
+ add(reg_input2, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
+ add(reg_output, jcp.typesize_out * ur_w * output_step);
sub(reg_ur_w, ur_w);
jmp(tail_w_label, T_NEAR);
// Right-padding section.
load_src(ur_w);
apply_filter(ur_w, kw);
- apply_activation(ur_w);
- store_dst(ur_w);
+ apply_postprocessing(ur_w, oc_step);
+ store_dst(ur_w, oc_step);
sub(reg_ur_w, ur_w);
}
}
// Kernel entry point: builds the post-op injectors, loads the runtime
// arguments, then emits the full-block body and (if oc is not a multiple of
// ch_block) a tail body selected at runtime via reg_oc_work.
template <cpu_isa_t isa>
-void jit_uni_dw_conv_row_f32<isa>::generate()
-{
+void jit_uni_dw_conv_row_f32<isa>::generate() {
+ // Instantiate injectors for every post-op after the fused dw-conv entry.
+ // NOTE(review): allocated with `new` into raw-pointer containers — confirm
+ // they are released in the kernel's destructor.
+ const auto &p = attr_.post_ops_;
+ int start_idx = p.find(primitive_kind::convolution) + 1;
+ for (int i = start_idx; i < p.len_; i++) {
+ auto &post_op = p.entry_[i];
+ if (post_op.is_eltwise()) {
+ eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
+ this,
+ post_op.eltwise.alg,
+ post_op.eltwise.alpha,
+ post_op.eltwise.beta
+ ));
+ } else if (post_op.is_depthwise()) {
+ depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>(
+ this,
+ post_op.depthwise.alg
+ ));
+ }
+ }
+
this->preamble();
mov(reg_input0, ptr[this->param1 + GET_OFF_DW(src_row0)]);
mov(reg_bias, ptr[this->param1 + GET_OFF_DW(bias)]);
mov(reg_kh, ptr[this->param1 + GET_OFF_DW(kh_padding)]);
mov(reg_ur_w, ptr[this->param1 + GET_OFF_DW(ur_w)]);
+ mov(reg_oc_work, ptr[this->param1 + GET_OFF_DW(oc_work)]);
+ mov(reg_oc_off, ptr[this->param1 + GET_OFF_DW(oc_off)]);
+
+ // NOTE(review): "Label(tail_label);" parses as a declaration of a Label
+ // variable named tail_label (likewise exit_label) — it works, but the
+ // conventional spelling is "Label tail_label;".
+ Label(tail_label);
+ Label(exit_label);
+ // Full channel block when at least ch_block channels remain; otherwise
+ // fall through to the tail variant specialized for oc % ch_block.
- loop_body();
+ cmp(reg_oc_work, jcp.ch_block);
+ jl(tail_label, T_NEAR);
+
+ loop_body(jcp.ch_block);
+ jmp(exit_label, T_NEAR);
+
+ L(tail_label);
+
+ if (jcp.oc % jcp.ch_block != 0)
+ loop_body(jcp.oc % jcp.ch_block);
+
+ L(exit_label);
this->postamble();
// Eltwise lookup tables must be emitted after the code body.
- if (jcp.with_eltwise)
- eltwise_injector->prepare_table();
+ for (auto& inj : eltwise_injectors)
+ inj->prepare_table();
+}
+
+template <cpu_isa_t isa>
+bool jit_uni_dw_conv_row_f32<isa>::post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
+ const auto &p = attr.post_ops_;
+
+ auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+ auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
+ auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
+ auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
+ auto is_binarization = [&](int idx) { return p.entry_[idx].is_binarization(); };
+
+ int start_idx = p.find(primitive_kind::convolution) + 1;
+
+ switch (p.len_ - start_idx) {
+ case 0: return true; // no post_ops
+ case 1: return is_simple(start_idx) || is_sum(start_idx) || is_binarization(start_idx);
+ case 2: return (is_sum(start_idx) && is_simple(start_idx+1)) || (is_simple(start_idx) && is_simple(start_idx+1)) ||
+ (is_simple(start_idx) && is_binarization(start_idx+1));
+ case 3: return (is_sum(start_idx) && is_simple(start_idx+1) && is_simple(start_idx+2));
+ default: return false;
+ }
+
+ return false;
+}
+
+// Fills jcp_dw for a depthwise conv fused after a 1x1 convolution, reading
+// the geometry from the `convolution` post-op entry. Only the 3x3 kernel and
+// f32/u8 sources are supported.
+template <cpu_isa_t isa>
+status_t jit_uni_dw_conv_row_f32<isa>::init_conf(jit_1x1_conv_conf_t &jcp, jit_conv_conf_t &jcp_dw,
+ const primitive_attr_t &attr) {
+ if (!mayiuse(isa)) return status::unimplemented;
+ const int simd_w = isa == avx512_common ? 16 : 8;
+
+ const auto &p = attr.post_ops_;
+
+ // A `sum` post-op after the fused conv means dst accumulation.
+ int dw_conv_ind = p.find(primitive_kind::convolution);
+ jcp_dw.with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;
+
+ jcp_dw.ch_block = simd_w;
+ jcp_dw.with_bias = true;
+
+ jcp_dw.kh = p.entry_[dw_conv_ind].dw_conv.ker_h;
+ jcp_dw.kw = p.entry_[dw_conv_ind].dw_conv.ker_w;
+ jcp_dw.ic = jcp.oc;
+ jcp_dw.oc = jcp.oc;
+ jcp_dw.ih = p.entry_[dw_conv_ind].dw_conv.in_h;
+ jcp_dw.iw = p.entry_[dw_conv_ind].dw_conv.in_w;
+ jcp_dw.oh = jcp.dw_conv_oh;
+ jcp_dw.ow = jcp.dw_conv_ow;
+ jcp_dw.stride_h = p.entry_[dw_conv_ind].dw_conv.str_h;
+ jcp_dw.stride_w = p.entry_[dw_conv_ind].dw_conv.str_w;
+ jcp_dw.conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
+ jcp_dw.conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;
+
+ if (jcp_dw.kh != 3 || jcp_dw.kw != 3)
+ return status::unimplemented;
+
+ if (!post_ops_ok(jcp_dw, attr))
+ return status::unimplemented;
+
+ jcp_dw.ur_w = 4;
+
+ // NOTE(review): src_dt is taken from the 1x1's *src* here, while the
+ // jit_conv_conf_t overload uses jcp.dst_dt (the tensor actually feeding
+ // the dw conv) — confirm which is intended for the 1x1 fusion path.
+ jcp_dw.src_dt = jcp.src_dt;
+ jcp_dw.dst_dt = jcp.dst_dt;
+ jcp_dw.bia_dt = jcp.bia_dt;
+ jcp_dw.typesize_in = (int)types::data_type_size(jcp.src_dt);
+ jcp_dw.typesize_bia = (int)types::data_type_size(jcp.bia_dt);
+ jcp_dw.typesize_out = (int)types::data_type_size(jcp.dst_dt);
+
+ if (jcp_dw.src_dt != mkldnn_f32 && jcp_dw.src_dt != mkldnn_u8)
+ return status::unimplemented;
+
+ return status::success;
+}
+
+template <cpu_isa_t isa>
+status_t jit_uni_dw_conv_row_f32<isa>::init_conf(jit_conv_conf_t &jcp, jit_conv_conf_t &jcp_dw,
+ const primitive_attr_t &attr) {
+ if (!mayiuse(isa)) return status::unimplemented;
+ const int simd_w = isa == avx512_common ? 16 : 8;
+
+ const auto &p = attr.post_ops_;
+
+ int dw_conv_ind = p.find(primitive_kind::convolution);
+ jcp_dw.with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;
+
+ jcp_dw.ch_block = simd_w;
+ jcp_dw.with_bias = true;
+
+ jcp_dw.kh = p.entry_[dw_conv_ind].dw_conv.ker_h;
+ jcp_dw.kw = p.entry_[dw_conv_ind].dw_conv.ker_w;
+ jcp_dw.ic = jcp.oc;
+ jcp_dw.oc = jcp.oc;
+ jcp_dw.ih = p.entry_[dw_conv_ind].dw_conv.in_h;
+ jcp_dw.iw = p.entry_[dw_conv_ind].dw_conv.in_w;
+ jcp_dw.oh = jcp.dw_conv_oh;
+ jcp_dw.ow = jcp.dw_conv_ow;
+ jcp_dw.stride_h = p.entry_[dw_conv_ind].dw_conv.str_h;
+ jcp_dw.stride_w = p.entry_[dw_conv_ind].dw_conv.str_w;
+ jcp_dw.conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
+ jcp_dw.conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;
+
+ if (jcp_dw.kh != 3 || jcp_dw.kw != 3)
+ return status::unimplemented;
+
+ if (!post_ops_ok(jcp_dw, attr))
+ return status::unimplemented;
+
+ jcp_dw.ur_w = 4;
+
+ jcp_dw.src_dt = jcp.dst_dt;
+ jcp_dw.dst_dt = jcp.dst_dt;
+ jcp_dw.bia_dt = jcp.bia_dt;
+ jcp_dw.typesize_in = (int)types::data_type_size(jcp.src_dt);
+ jcp_dw.typesize_bia = (int)types::data_type_size(jcp.bia_dt);
+ jcp_dw.typesize_out = (int)types::data_type_size(jcp.dst_dt);
+
+ if (jcp_dw.src_dt != mkldnn_f32 && jcp_dw.src_dt != mkldnn_u8)
+ return status::unimplemented;
+
+ return status::success;
}
// Diff hunk: the old explicit-parameter init_conf ('-' lines) is replaced by
// the attr/post-op based overload for fusion after a *binary* convolution:
// the dw conv runs in f32 and, when a binarization post-op follows, emits a
// bit-packed (mkldnn_bin) destination.
template <cpu_isa_t isa>
-status_t jit_uni_dw_conv_row_f32<isa>::init_conf(jit_conv_conf_t &jcp,
- int ic, int ih, int iw, int oh, int ow, int ker_h, int ker_w, int str_h, int str_w, alg_kind_t eltwise_alg,
- float eltwise_alpha, float eltwise_beta, bool with_sum) {
+status_t jit_uni_dw_conv_row_f32<isa>::init_conf(jit_bin_conv_conf_t &jcp, jit_conv_conf_t &jcp_dw,
+ const primitive_attr_t &attr) {
if (!mayiuse(isa)) return status::unimplemented;
const int simd_w = isa == avx512_common ? 16 : 8;
- jcp.kh = ker_h;
- jcp.kw = ker_w;
- jcp.ch_block = simd_w;
- jcp.with_bias = true;
- jcp.ic = ic;
- jcp.oc = ic;
- jcp.ih = ih;
- jcp.iw = iw;
- jcp.oh = oh;
- jcp.ow = ow;
- jcp.stride_h = str_h;
- jcp.stride_w = str_w;
-
- if (jcp.kh != 3 || jcp.kw != 3)
- return status::unimplemented;
-
- jcp.ur_w = 4;
-
- jcp.with_eltwise = eltwise_alg != mkldnn_alg_kind_undef;
- jcp.eltwise_alg = eltwise_alg;
- jcp.eltwise_alpha = eltwise_alpha;
- jcp.eltwise_beta = eltwise_beta;
- jcp.with_sum = with_sum;
+ const auto &p = attr.post_ops_;
+
+ // Geometry comes from the `convolution` post-op entry; sum/binarization
+ // entries after it select dst accumulation and bit-packed output.
+ int dw_conv_ind = p.find(primitive_kind::convolution);
+ jcp_dw.with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;
+ jcp_dw.with_binarization = p.find(primitive_kind::binarization, dw_conv_ind) != -1;
+
+ jcp_dw.ch_block = simd_w;
+ jcp_dw.with_bias = true;
+
+ jcp_dw.kh = p.entry_[dw_conv_ind].dw_conv.ker_h;
+ jcp_dw.kw = p.entry_[dw_conv_ind].dw_conv.ker_w;
+ jcp_dw.ic = jcp.oc;
+ jcp_dw.oc = jcp.oc;
+ jcp_dw.ih = p.entry_[dw_conv_ind].dw_conv.in_h;
+ jcp_dw.iw = p.entry_[dw_conv_ind].dw_conv.in_w;
+ jcp_dw.oh = jcp.dw_conv_oh;
+ jcp_dw.ow = jcp.dw_conv_ow;
+ jcp_dw.stride_h = p.entry_[dw_conv_ind].dw_conv.str_h;
+ jcp_dw.stride_w = p.entry_[dw_conv_ind].dw_conv.str_w;
+ jcp_dw.conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
+ jcp_dw.conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;
+
+ // Only the 3x3 kernel is implemented.
+ if (jcp_dw.kh != 3 || jcp_dw.kw != 3)
+ return status::unimplemented;
+
+ if (!post_ops_ok(jcp_dw, attr))
+ return status::unimplemented;
+
+ jcp_dw.ur_w = 4;
+
+ // All typesizes here are derived from jcp_dw's own types (consistent,
+ // unlike the jit_conv_conf_t overload above).
+ jcp_dw.src_dt = mkldnn_f32;
+ jcp_dw.dst_dt = jcp_dw.with_binarization ? mkldnn_bin : mkldnn_f32;
+ jcp_dw.bia_dt = mkldnn_f32;
+ jcp_dw.typesize_in = (int)types::data_type_size(jcp_dw.src_dt);
+ jcp_dw.typesize_bia = (int)types::data_type_size(jcp_dw.bia_dt);
+ jcp_dw.typesize_out = (int)types::data_type_size(jcp_dw.dst_dt);
+
+ // NOTE(review): src_dt is hard-coded to f32 two lines above, so this check
+ // can never fail here — presumably kept for symmetry with the overloads.
+ if (jcp_dw.src_dt != mkldnn_f32 && jcp_dw.src_dt != mkldnn_u8)
+ return status::unimplemented;
return status::success;
}