}
template <cpu_isa_t isa>
-bool jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::maybe_relu(int position) {
- using namespace primitive_kind;
- const auto &p = attr_.post_ops_;
-
- if (position == 0) {
- /* relu before sum */
- return false
- || jcp.with_eltwise
- || p.contain(eltwise, 0)
- || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0));
- } else if (position == 1) {
- /* relu after sum */
- const int sum_idx = p.contain(sum, 0)
- ? 0 : (p.contain(sum, 1) ? 1 : -1);
- if (sum_idx == -1)
- return false;
-
- return false
- || p.contain(eltwise, sum_idx + 1)
- || jcp.dst_dt == data_type::u8;
- }
-
- return false;
-}
-
-template <cpu_isa_t isa>
void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::store_dst(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store) {
Ymm ymm_dst = Ymm(vmm_dst.getIdx());
Xmm xmm_dst = Xmm(vmm_dst.getIdx());
if (isa != sse42 && !scalar_store)
vpermq(ymm_dst, ymm_dst, 0x08);
- uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst);
+ uni_vpacksswb(vmm_dst, vmm_dst, vmm_dst);
if (scalar_store) {
movq(reg_tmp_64, xmm_dst);
if (isa != sse42 && !scalar_store)
vpermq(ymm_dst, ymm_dst, 0x08);
- uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst);
+ uni_vpackuswb(vmm_dst, vmm_dst, vmm_dst);
if (scalar_store) {
movq(reg_tmp_64, xmm_dst);
void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::store_dst(int ur_ch_blocks, int ch_step, int ur_w) {
int repeats = isa == sse42 && ch_step > (jcp.ch_block / 2) ? 2 : 1;
+ pop(reg_oc_off);
pop(reg_scales_base);
- uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
+ mov(imm_addr64, l_table);
+
+ const auto &p = attr_.post_ops_;
+ const int sum_idx = p.find(primitive_kind::sum);
+ const float p_sum_scale = (sum_idx != -1) ? p.entry_[sum_idx].sum.scale : 1.f;
+
+ bool is_scalar_store = ch_step < jcp.ch_block;
+
for (int r = 0; r < repeats; r++) {
- if (ch_step < jcp.ch_block) {
+ for (int ii = 0; ii < ur_ch_blocks; ii++) {
+ if (jcp.with_bias) {
+ int b_off = ii * jcp.ch_block + r * (jcp.ch_block / 2);
+ cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias_base + b_off * jcp.typesize_bia], is_scalar_store);
+ }
+
for (int jj = 0; jj < ur_w; jj++) {
- Vmm vmm_dst = get_acc_reg(r * ur_w * ur_ch_blocks + jj);
+ Vmm vmm_dst = get_acc_reg(r * ur_ch_blocks * ur_w + ur_w * ii + jj);
uni_vcvtdq2ps(vmm_dst, vmm_dst);
- if (jcp.with_bias) {
- int b_off = r * (jcp.ch_block / 2);
- cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias_base + b_off * jcp.typesize_bia], true);
+ if (jcp.with_bias)
uni_vaddps(vmm_dst, vmm_dst, vmm_bias);
- }
- int s_off = jcp.is_oc_scale * (r * (jcp.ch_block / 2));
- cvt2ps(mkldnn_f32, vmm_scale, ptr[reg_scales_base + s_off * sizeof(float)], true);
+ int s_off = jcp.is_oc_scale * (ii * jcp.ch_block + r * (jcp.ch_block / 2));
+ cvt2ps(mkldnn_f32, vmm_scale, ptr[reg_scales_base + s_off * sizeof(float)], is_scalar_store);
uni_vmulps(vmm_dst, vmm_dst, vmm_scale);
+ }
+ }
- int o_off = jj * jcp.oc + r * (jcp.ch_block / 2);
- if (jcp.with_sum) {
- uni_vpxor(vmm_prev_dst, vmm_prev_dst, vmm_prev_dst);
- cvt2ps(jcp.dst_dt, vmm_prev_dst, ptr[reg_output + o_off * jcp.typesize_out], true);
- uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
+ int eltwise_inj_idx = 0;
+ int depthwise_inj_idx = 0;
+ for (int i = 0; i < p.len_; i++) {
+ int start_idx = 4 + r * ur_ch_blocks*ur_w;
+
+ auto& post_op = p.entry_[i];
+ if (post_op.is_eltwise()) {
+ eltwise_injectors[eltwise_inj_idx]->compute_vector_range(start_idx, start_idx + ur_ch_blocks * ur_w);
+ eltwise_inj_idx++;
+ } else if (post_op.is_depthwise()) {
+ mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
+ mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
+
+ add(reg_d_weights, reg_oc_off);
+ add(reg_d_bias, reg_oc_off);
+
+ if (r == 1) {
+ add(reg_d_weights, (jcp.ch_block / 2) * sizeof(float));
+ add(reg_d_bias, (jcp.ch_block / 2) * sizeof(float));
}
- if (maybe_relu(0))
- uni_vmaxps(vmm_dst, vmm_dst, vmm_zero);
+ for (int ii = 0; ii < ur_ch_blocks; ii++) {
+ depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
+ start_idx + ur_w * ii, start_idx + ur_w * ii + ur_w, reg_d_weights, reg_d_bias);
+
+ add(reg_d_weights, jcp.ch_block * sizeof(float));
+ add(reg_d_bias, jcp.ch_block * sizeof(float));
+ }
+
+ depthwise_inj_idx++;
+ } else if (post_op.is_sum(false)) {
+ for (int ii = 0; ii < ur_ch_blocks; ii++) {
+ for (int jj = 0; jj < ur_w; jj++) {
+ Vmm vmm_dst = get_acc_reg(r * ur_ch_blocks*ur_w + ur_w * ii + jj);
+ int o_off = ii * jcp.ch_block + jj * jcp.oc + r * (jcp.ch_block / 2);
+
+ cvt2ps(jcp.dst_dt, vmm_prev_dst, ptr[reg_output + o_off * jcp.typesize_out], is_scalar_store);
+
+ if (p_sum_scale == 1.f) {
+ uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
+ } else {
+ uni_vfmadd231ps(vmm_dst, vmm_prev_dst, ptr[imm_addr64 + 0 * vlen]);
+ }
+ }
+ }
+ }
+ }
- if (maybe_relu(1))
- uni_vmaxps(vmm_dst, vmm_dst, vmm_zero);
+ for (int ii = 0; ii < ur_ch_blocks; ii++) {
+ for (int jj = 0; jj < ur_w; jj++) {
+ Vmm vmm_dst = get_acc_reg(r * ur_ch_blocks * ur_w + ur_w * ii + jj);
+ int o_off = ii * jcp.ch_block + jj * jcp.oc + r * (jcp.ch_block / 2);
if (jcp.dst_dt != data_type::f32) {
if (attr_.round_mode_ == round_mode::nearest)
assert(!"unimplemented");
}
- store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, true);
- }
- } else {
- for (int ii = 0; ii < ur_ch_blocks; ii++) {
- if (jcp.with_bias) {
- int b_off = ii * jcp.ch_block + r * (jcp.ch_block / 2);
- cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias_base + b_off * jcp.typesize_bia], false);
- }
-
- for (int jj = 0; jj < ur_w; jj++) {
- Vmm vmm_dst = get_acc_reg(r * ur_ch_blocks*ur_w + ur_w * ii + jj);
- uni_vcvtdq2ps(vmm_dst, vmm_dst);
-
- if (jcp.with_bias)
- uni_vaddps(vmm_dst, vmm_dst, vmm_bias);
-
- int s_off = jcp.is_oc_scale * (ii * jcp.ch_block + r * (jcp.ch_block / 2));
- cvt2ps(mkldnn_f32, vmm_scale, ptr[reg_scales_base + s_off * sizeof(float)], false);
- uni_vmulps(vmm_dst, vmm_dst, vmm_scale);
-
- int o_off = ii * jcp.ch_block + jj * jcp.oc + r * (jcp.ch_block / 2);
- if (jcp.with_sum) {
- cvt2ps(jcp.dst_dt, vmm_prev_dst, ptr[reg_output + o_off * jcp.typesize_out], false);
- uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
- }
-
- if (maybe_relu(0))
- uni_vmaxps(vmm_dst, vmm_dst, vmm_zero);
-
- if (maybe_relu(1))
- uni_vmaxps(vmm_dst, vmm_dst, vmm_zero);
-
- if (jcp.dst_dt != data_type::f32) {
- if (attr_.round_mode_ == round_mode::nearest)
- uni_vcvtps2dq(vmm_dst, vmm_dst);
- else if (attr_.round_mode_ == round_mode::down) {
- uni_vroundps(vmm_dst, vmm_dst, 1);
- uni_vcvtps2dq(vmm_dst, vmm_dst);
- } else
- assert(!"unimplemented");
- }
-
- store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, false);
- }
+ store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, is_scalar_store);
}
}
}
push(reg_scales_base);
+ push(reg_oc_off);
}
template <cpu_isa_t isa>
push(reg_kernel_base);
push(reg_ch_work);
push(reg_scales_base);
+ push(reg_oc_off);
L(unrolled_w_label); {
int ur_w = jcp.ur_w;
L(exit_label);
+ pop(reg_oc_off);
pop(reg_scales_base);
pop(reg_ch_work);
pop(reg_kernel_base);
template <cpu_isa_t isa>
void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::generate() {
+ const auto &p = attr_.post_ops_;
+ for (int i = 0; i < p.len_; i++) {
+ auto &post_op = p.entry_[i];
+ if (post_op.is_eltwise()) {
+ eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
+ this,
+ post_op.eltwise.alg,
+ post_op.eltwise.alpha,
+ post_op.eltwise.beta
+ ));
+ } else if (post_op.is_depthwise()) {
+ depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>(
+ this,
+ post_op.depthwise.alg
+ ));
+ }
+ }
+
this->preamble();
mov(reg_input_base, ptr[this->param1 + GET_OFF(src)]);
mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]);
mov(reg_ch_work, ptr[this->param1 + GET_OFF(ch_work)]);
+ mov(reg_oc_off, ptr[this->param1 + GET_OFF(oc_off)]);
Label main_loop_label;
Label tail_loop_label;
add(reg_kernel_base, jcp.ch_block * jcp.kh * jcp.kw * jcp.typesize_in);
add(reg_bias_base, jcp.ch_block * jcp.typesize_bia);
add(reg_scales_base, jcp.is_oc_scale * jcp.ch_block * sizeof(float));
+ add(reg_oc_off, jcp.ch_block * sizeof(float));
jmp(main_loop_label, T_NEAR);
}
add(reg_kernel_base, 1 * jcp.typesize_in);
add(reg_bias_base, 1 * jcp.typesize_bia);
add(reg_scales_base, jcp.is_oc_scale * 1 * sizeof(float));
+ add(reg_oc_off, 1 * sizeof(float));
jmp(tail_loop_label, T_NEAR);
}
L(exit_label);
this->postamble();
+
+ prepare_table();
+
+ for (auto& inj : eltwise_injectors)
+ inj->prepare_table();
+}
+
+// Emits the constant data table labelled l_table. store_dst() loads the
+// address of l_table into imm_addr64 and uses the entry at offset 0 * vlen
+// as the multiplier for the previous destination value when the sum post-op
+// scale is not 1.f. generate() calls this after postamble().
+template <cpu_isa_t isa>
+void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::prepare_table() {
+ const auto &p = attr_.post_ops_;
+ // Index of the sum post-op as reported by find(); -1 means there is none.
+ const int sum_idx = p.find(primitive_kind::sum);
+ // Without a sum post-op the table still holds a neutral scale of 1.f.
+ const float p_sum_scale = (sum_idx != -1) ? p.entry_[sum_idx].sum.scale : 1.f;
+
+ // Table rows as 32-bit values; currently a single row (the sum scale).
+ // NOTE(review): float2int presumably reinterprets the float's bits as an
+ // int32 rather than converting the value -- confirm against its definition.
+ const int32_t cvals_sum_scale[] = {
+ float2int(p_sum_scale)
+ };
+
+ // NOTE(review): presumably pads the output so the table starts on a
+ // 64-byte boundary -- confirm against the code generator's align().
+ align(64);
+ L(l_table);
+ // Repeat each row vlen / sizeof(int32_t) times so that one row spans vlen
+ // bytes, matching the `i * vlen` row addressing used by store_dst().
+ for (size_t i = 0; i < sizeof(cvals_sum_scale) / sizeof(cvals_sum_scale[0]); ++i) {
+ for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
+ dd(cvals_sum_scale[i]);
+ }
+ }
 }
template <cpu_isa_t isa>
jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
const auto &p = attr.post_ops_;
- auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
- auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
+ auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+ auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
+ auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(false); };
+ auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
switch (p.len_) {
- case 0: return true; // no post_ops
- case 1: return !jcp.with_eltwise && (is_relu(0) || is_sum(0)); // sum OR relu
- case 2: return !jcp.with_eltwise && (is_sum(0) && is_relu(1)); // sum->relu
- default: return false;
+ case 0: return true;
+ case 1: return is_simple(0) || is_sum(0);
+ case 2: return (is_sum(0) && is_simple(1)) || (is_simple(0) && is_sum(1)) ||
+ (is_simple(0) && is_simple(1));
+ case 3: return (is_simple(0) && is_sum(1) && is_simple(2));
+ default: return false;
}
return false;
status_t jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jcp,
const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
- const memory_desc_wrapper &bias_pd, const primitive_attr_t &attr,
- bool with_relu, float relu_negative_slope)
+ const memory_desc_wrapper &bias_pd, const primitive_attr_t &attr)
{
if (!mayiuse(isa)) return status::unimplemented;
jcp.src_fmt = src_d.format();
jcp.with_bias = cd.bias_desc.format != memory_format::undef;
- jcp.with_eltwise = with_relu;
- jcp.eltwise_alpha = relu_negative_slope;
jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false;
const auto &p = attr.post_ops_;
jcp.with_sum = p.find(primitive_kind::sum) != -1;
- if (!jcp.with_eltwise) {
- int eltwise_ind = p.find(primitive_kind::eltwise);
- if (eltwise_ind != -1) {
- jcp.with_eltwise = true;
- jcp.eltwise_alpha = p.entry_[eltwise_ind].eltwise.alpha;
- }
- }
+ const int eltwise_ind = p.find(primitive_kind::eltwise);
+ jcp.with_eltwise = eltwise_ind != -1;
+ if (jcp.with_eltwise)
+ jcp.eltwise = p.entry_[eltwise_ind].eltwise;
auto desired_act_fmt = nhwc;
auto desired_wei_fmt = isa == avx512_common ? Goihw16g : Goihw8g;