/*******************************************************************************
-* Copyright 2018 Intel Corporation
+* Copyright 2018-2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
+#include <common/memory_tracking.hpp>
#include "c_types_map.hpp"
#include "nstl.hpp"
#include "type_helpers.hpp"
using namespace mkldnn::impl::prop_kind;
using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;
using namespace Xbyak;
template <cpu_isa_t isa>
-bool jit_uni_x8s8s32x_conv_fwd_kernel<isa>::maybe_relu(int position) {
- using namespace primitive_kind;
- const auto &p = attr_.post_ops_;
-
- if (position == 0) {
- /* relu before sum */
- return false
- || jcp.with_eltwise
- || p.contain(eltwise, 0)
- || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0));
- } else if (position == 1) {
- /* relu after sum */
- const int sum_idx = p.contain(sum, 0)
- ? 0 : (p.contain(sum, 1) ? 1 : -1);
- if (sum_idx == -1)
- return false;
-
- return false
- || p.contain(eltwise, sum_idx + 1)
- || jcp.dst_dt == data_type::u8;
- }
-
- return false;
-}
-
-template <cpu_isa_t isa>
void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::cvt2ps(data_type_t type_in, Vmm vmm_in,
const Xbyak::Operand &op, bool scalar_load) {
Xmm xmm_in = Xmm(vmm_in.getIdx());
if (isa != sse42 && !scalar_store)
vpermq(ymm_dst, ymm_dst, 0x08);
- uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst);
+ uni_vpacksswb(vmm_dst, vmm_dst, vmm_dst);
if (scalar_store) {
movq(reg_tmp_64, xmm_dst);
if (isa != sse42 && !scalar_store)
vpermq(ymm_dst, ymm_dst, 0x08);
- uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst);
+ uni_vpackuswb(vmm_dst, vmm_dst, vmm_dst);
if (scalar_store) {
movq(reg_tmp_64, xmm_dst);
for (int r = 0; r < repeats; r++) {
for (int jj = _start; jj < _end; jj++) {
int inp_off = (ki * dilate_w + jj * stride_w - pad_l) * jcp.ic * jcp.ngroups;
- if (tail_size > 0) {
- if (h_padded || jj < jj_start || jj >= jj_end) {
- uni_vpxor(get_src_reg(jj), get_src_reg(jj), get_src_reg(jj));
- uni_vpsubb(get_src_reg(jj), get_src_reg(jj), vmm_shift);
- uni_vandps(get_src_reg(jj), get_src_reg(jj), vmm_mask);
- uni_vpbroadcastd(get_src_reg(jj), Xmm(get_src_reg(jj).getIdx()));
- } else {
- uni_vpbroadcastd(get_src_reg(jj), ptr[aux1_reg_input + jcp.typesize_in * inp_off]);
-
- if (jcp.signed_input) {
- uni_vpsubb(get_src_reg(jj), get_src_reg(jj), vmm_shift);
- }
-
- uni_vandps(get_src_reg(jj), get_src_reg(jj), vmm_mask);
- uni_vpbroadcastd(get_src_reg(jj), Xmm(get_src_reg(jj).getIdx()));
- }
+ if (tail_size > 0) {
+ if (h_padded || jj < jj_start || jj >= jj_end) {
+ uni_vpxor(get_src_reg(jj), get_src_reg(jj), get_src_reg(jj));
+ uni_vpsubb(get_src_reg(jj), get_src_reg(jj), vmm_shift);
} else {
- if (h_padded || jj < jj_start || jj >= jj_end) {
- uni_vpxor(get_src_reg(jj), get_src_reg(jj), get_src_reg(jj));
- } else {
- uni_vpbroadcastd(get_src_reg(jj), ptr[aux1_reg_input + jcp.typesize_in * inp_off]);
- }
+ uni_vpbroadcastd(get_src_reg(jj), ptr[aux1_reg_input + jcp.typesize_in * inp_off]);
- if (jcp.signed_input)
+ if (jcp.signed_input) {
uni_vpsubb(get_src_reg(jj), get_src_reg(jj), vmm_shift);
+ }
+ }
+ } else {
+ if (h_padded || jj < jj_start || jj >= jj_end) {
+ uni_vpxor(get_src_reg(jj), get_src_reg(jj), get_src_reg(jj));
+ } else {
+ uni_vpbroadcastd(get_src_reg(jj), ptr[aux1_reg_input + jcp.typesize_in * inp_off]);
}
+
+ if (jcp.signed_input)
+ uni_vpsubb(get_src_reg(jj), get_src_reg(jj), vmm_shift);
+ }
}
for (int ii = 0; ii < oc_blocks; ii++) {
mov(imm_addr64, l_table);
uni_vmovups(vmm_one, ptr[imm_addr64 + 0 * vlen]);
uni_vmovups(vmm_shift, ptr[imm_addr64 + 1 * vlen]);
- uni_vmovups(vmm_mask, ptr[imm_addr64 + 4 * vlen]);
if (jcp.signed_input) {
mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
kh_loop(ur_w, pad_l, pad_r, oc_blocks, oc_step);
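+ // Restore reg_oc_off (per-channel byte offset into depthwise post-op
+ // weights/biases) and the scales pointer saved before the kh loop.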
+ pop(reg_oc_off);
pop(reg_scales_base);
mov(imm_addr64, l_table);
const float p_sum_scale = (sum_idx != -1) ? p.entry_[sum_idx].sum.scale : 1.f;
for (int r = 0; r < repeats; r++) {
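+ // Destination offset: with a fused depthwise conv, results are staged in an
+ // intermediate buffer of jcp_dw.kh rows per oc block that feeds the fused
+ // kernel; otherwise they go straight into the dst tensor.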
+ auto get_dst_off = [=](int ii, int jj) {
+ if (jcp.with_dw_conv)
+ return (ii * jcp_dw.kh * jcp.ow + jj) * jcp.oc_block + r * (jcp.oc_block / 2);
+ else
+ return ii * jcp.oc_block + jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2);
+ };
+
int tail_size = isa == avx2 ? oc_step : nstl::min(jcp.oc_block / 2, oc_step - r * jcp.oc_block / 2);
bool is_scalar_store = isa == avx2 ? tail_size < jcp.oc_block : tail_size < jcp.oc_block / 2;
- if (is_scalar_store) {
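+ // Convert the s32 accumulators to f32 and fold in the bias (scaled by
+ // vmm_bias_alpha for signed input), the signed-input compensation term,
+ // and the per-oc output scales.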
+ for (int ii = 0; ii < oc_blocks; ii++) {
+ if (jcp.with_bias) {
+ int b_off = ii * jcp.oc_block + r * (jcp.oc_block / 2);
+ cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias_base + b_off * jcp.typesize_bia], false);
+
+ if (jcp.signed_input)
+ uni_vmulps(vmm_bias, vmm_bias, vmm_bias_alpha);
+ }
+
for (int jj = 0; jj < ur_w; jj++) {
- Vmm vmm_dst = get_acc_reg(r * jcp.ur_w * jcp.nb_oc_blocking + jj);
+ Vmm vmm_dst = get_acc_reg(r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);
uni_vcvtdq2ps(vmm_dst, vmm_dst);
- uni_vmovups(vmm_reminder_dst, vmm_dst);
- for (int oc = 0; oc < tail_size; oc++) {
- uni_vmovups(vmm_dst, vmm_reminder_dst);
+ if (jcp.signed_input) {
+ int c_off = ii * jcp.oc_block + r * (jcp.oc_block / 2);
+ cvt2ps(data_type::s32, vmm_comp, ptr[reg_compensation_base + c_off * sizeof(int32_t)], false);
+ }
- if (jcp.with_bias) {
- int b_off = r * (jcp.oc_block / 2) + oc;
- cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias_base + b_off * jcp.typesize_bia], true);
+ if (jcp.signed_input)
+ uni_vaddps(vmm_dst, vmm_dst, vmm_comp);
+ if (jcp.with_bias)
+ uni_vaddps(vmm_dst, vmm_dst, vmm_bias);
- if (jcp.signed_input)
- uni_vmulps(vmm_bias, vmm_bias, vmm_bias_alpha);
- }
- if (jcp.signed_input) {
- int c_off = r * (jcp.oc_block / 2) + oc;
- cvt2ps(data_type::s32, vmm_comp, ptr[reg_compensation_base + c_off * sizeof(int32_t)], true);
- }
+ int s_off = jcp.is_oc_scale * (ii * jcp.oc_block + r * (jcp.oc_block / 2));
+ cvt2ps(mkldnn_f32, vmm_scale, ptr[reg_scales_base + s_off * sizeof(float)], false);
+ uni_vmulps(vmm_dst, vmm_dst, vmm_scale);
+ }
+ }
- if (jcp.signed_input)
- uni_vaddps(vmm_dst, vmm_dst, vmm_comp);
- if (jcp.with_bias)
- uni_vaddps(vmm_dst, vmm_dst, vmm_bias);
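+ // Walk the post-op chain with JIT injectors; with a fused dw conv only the
+ // post ops preceding it are applied by this kernel.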
+ int eltwise_inj_idx = 0;
+ int depthwise_inj_idx = 0;
+ int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
+ for (int i = 0; i < end_idx; i++) {
+ int start_idx = 1 + r * jcp.ur_w * jcp.nb_oc_blocking;
+
+ auto& post_op = p.entry_[i];
+ if (post_op.is_eltwise()) {
+ eltwise_injectors[eltwise_inj_idx]->compute_vector_range(start_idx, start_idx + oc_blocks * ur_w);
+ eltwise_inj_idx++;
+ } else if (post_op.is_depthwise()) {
+ mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
+ mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
+
+ add(reg_d_weights, reg_oc_off);
+ add(reg_d_bias, reg_oc_off);
+
+ if (r == 1) {
+ add(reg_d_weights, (jcp.oc_block / 2) * sizeof(float));
+ add(reg_d_bias, (jcp.oc_block / 2) * sizeof(float));
+ }
- int s_off = jcp.is_oc_scale * (r * (jcp.oc_block / 2) + oc);
- cvt2ps(mkldnn_f32, vmm_scale, ptr[reg_scales_base + s_off * sizeof(float)], true);
- uni_vmulps(vmm_dst, vmm_dst, vmm_scale);
+ for (int ii = 0; ii < oc_blocks; ii++) {
+ depthwise_injectors[depthwise_inj_idx]->compute_vector_range(start_idx + ur_w * ii,
+ start_idx + ur_w * ii + ur_w, reg_d_weights, reg_d_bias);
- int o_off = jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2) + oc;
- if (jcp.with_sum) {
- uni_vpxor(vmm_prev_dst, vmm_prev_dst, vmm_prev_dst);
- cvt2ps(jcp.dst_dt, vmm_prev_dst, ptr[reg_output + o_off * jcp.typesize_out], true);
+ add(reg_d_weights, jcp.oc_block * sizeof(float));
+ add(reg_d_bias, jcp.oc_block * sizeof(float));
+ }
- if (p_sum_scale == 1.f) {
- uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
+ depthwise_inj_idx++;
+ } else if (post_op.is_sum(false)) {
+ for (int ii = 0; ii < oc_blocks; ii++) {
+ for (int jj = 0; jj < ur_w; jj++) {
+ Vmm vmm_dst = get_acc_reg(r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);
+ int o_off = get_dst_off(ii, jj);
+
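+ // Sum tail: load each dst scalar, then shift it into its target lane
+ // (vperm2i128 to cross the 128-bit boundary) before accumulating with the
+ // optional sum scale.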
+ if (is_scalar_store) {
+ for (int oc = 0; oc < tail_size; oc++) {
+ uni_vpxor(vmm_prev_dst, vmm_prev_dst, vmm_prev_dst);
+ cvt2ps(jcp.dst_dt, vmm_prev_dst, ptr[reg_output + (o_off + oc) * jcp.typesize_out], true);
+
+ if (oc < jcp.oc_block / 2) {
+ uni_vpslldq(vmm_prev_dst, vmm_prev_dst, oc * sizeof(float));
+ } else {
+ Ymm ymm_prev_dst = Ymm(vmm_prev_dst.getIdx());
+ vperm2i128(ymm_prev_dst, ymm_prev_dst, ymm_prev_dst, 0x01);
+ vpslldq(vmm_prev_dst, vmm_prev_dst, (oc - jcp.oc_block / 2) * sizeof(float));
+ }
+
+ if (p_sum_scale == 1.f) {
+ uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
+ } else {
+ uni_vfmadd231ps(vmm_dst, vmm_prev_dst, ptr[imm_addr64 + 3 * vlen]);
+ }
+ }
} else {
- uni_vfmadd231ps(vmm_dst, vmm_prev_dst, ptr[imm_addr64 + 3 * vlen]);
- }
- }
-
- if (maybe_relu(0)) {
- uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
- uni_vmaxps(vmm_dst, vmm_dst, vmm_zero);
- }
-
- if (maybe_relu(1)) {
- uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
- uni_vmaxps(vmm_dst, vmm_dst, vmm_zero);
- }
-
- if (jcp.dst_dt != data_type::f32) {
- if (attr_.round_mode_ == round_mode::nearest)
- uni_vcvtps2dq(vmm_dst, vmm_dst);
- else if (attr_.round_mode_ == round_mode::down) {
- uni_vroundps(vmm_dst, vmm_dst, 1);
- uni_vcvtps2dq(vmm_dst, vmm_dst);
- } else
- assert(!"unimplemented");
- }
-
- store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, true);
+ cvt2ps(jcp.dst_dt, vmm_prev_dst, ptr[reg_output + o_off * jcp.typesize_out], false);
- if (isa == avx2) {
- vperm2i128(ymm_tmp, ymm_reminder_dst, ymm_reminder_dst, 0x01);
- vpalignr(ymm_reminder_dst, ymm_tmp, ymm_reminder_dst, jcp.typesize_out);
- } else {
- psrldq(vmm_reminder_dst, jcp.typesize_out);
+ if (p_sum_scale == 1.f) {
+ uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
+ } else {
+ uni_vfmadd231ps(vmm_dst, vmm_prev_dst, ptr[imm_addr64 + 3 * vlen]);
+ }
+ }
}
}
}
- } else {
- for (int ii = 0; ii < oc_blocks; ii++) {
- if (jcp.with_bias) {
- int b_off = ii * jcp.oc_block + r * (jcp.oc_block / 2);
- cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias_base + b_off * jcp.typesize_bia], false);
+ }
- if (jcp.signed_input)
- uni_vmulps(vmm_bias, vmm_bias, vmm_bias_alpha);
+ for (int ii = 0; ii < oc_blocks; ii++) {
+ for (int jj = 0; jj < ur_w; jj++) {
+ Vmm vmm_dst = get_acc_reg(r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);
+ int o_off = get_dst_off(ii, jj);
+
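+ // For integer dst types, round according to the attribute's rounding mode
+ // and convert the f32 results back to s32.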
+ if (jcp.dst_dt != data_type::f32) {
+ if (attr_.round_mode_ == round_mode::nearest)
+ uni_vcvtps2dq(vmm_dst, vmm_dst);
+ else if (attr_.round_mode_ == round_mode::down) {
+ uni_vroundps(vmm_dst, vmm_dst, 1);
+ uni_vcvtps2dq(vmm_dst, vmm_dst);
+ } else
+ assert(!"unimplemented");
}
- for (int jj = 0; jj < ur_w; jj++) {
- Vmm vmm_dst = get_acc_reg(r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);
- uni_vcvtdq2ps(vmm_dst, vmm_dst);
-
- if (jcp.signed_input) {
- int c_off = ii * jcp.oc_block + r * (jcp.oc_block / 2);
- cvt2ps(data_type::s32, vmm_comp, ptr[reg_compensation_base + c_off * sizeof(int32_t)], false);
- }
-
- if (jcp.signed_input)
- uni_vaddps(vmm_dst, vmm_dst, vmm_comp);
- if (jcp.with_bias)
- uni_vaddps(vmm_dst, vmm_dst, vmm_bias);
-
- int s_off = jcp.is_oc_scale * (ii * jcp.oc_block + r * (jcp.oc_block / 2));
- cvt2ps(mkldnn_f32, vmm_scale, ptr[reg_scales_base + s_off * sizeof(float)], false);
- uni_vmulps(vmm_dst, vmm_dst, vmm_scale);
-
- int o_off = ii * jcp.oc_block + jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2);
- if (jcp.with_sum) {
- cvt2ps(jcp.dst_dt, vmm_prev_dst, ptr[reg_output + o_off * jcp.typesize_out], false);
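+ // Scalar tail store: write one element at a time, shifting the register
+ // right by one dst element (jcp.typesize_out bytes) after each store so the
+ // next value moves into the stored position.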
+ if (is_scalar_store) {
+ for (int oc = 0; oc < tail_size; oc++) {
+ store_dst(ptr[reg_output + (o_off + oc) * jcp.typesize_out], vmm_dst, true);
- if (p_sum_scale == 1.f) {
- uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
+ if (isa == avx2) {
+ Ymm ymm_dst = Ymm(vmm_dst.getIdx());
+ vperm2i128(ymm_tmp, ymm_dst, ymm_dst, 0x01);
+ vpalignr(ymm_dst, ymm_tmp, ymm_dst, jcp.typesize_out);
} else {
- uni_vfmadd231ps(vmm_dst, vmm_prev_dst, ptr[imm_addr64 + 3 * vlen]);
+ psrldq(vmm_dst, jcp.typesize_out);
}
}
-
- if (maybe_relu(0)) {
- uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
- uni_vmaxps(vmm_dst, vmm_dst, vmm_zero);
- }
-
- if (maybe_relu(1)) {
- uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
- uni_vmaxps(vmm_dst, vmm_dst, vmm_zero);
- }
-
- if (jcp.dst_dt != data_type::f32) {
- if (attr_.round_mode_ == round_mode::nearest)
- uni_vcvtps2dq(vmm_dst, vmm_dst);
- else if (attr_.round_mode_ == round_mode::down) {
- uni_vroundps(vmm_dst, vmm_dst, 1);
- uni_vcvtps2dq(vmm_dst, vmm_dst);
- } else
- assert(!"unimplemented");
- }
-
+ } else {
store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, false);
}
}
}
push(reg_scales_base);
+ push(reg_oc_off);
}
template <cpu_isa_t isa>
int dilate_w = jcp.dilate_w + 1;
int str_w = jcp.stride_w;
const int inp_mult = jcp.ic * jcp.ngroups;
+ const int out_mult = jcp.with_dw_conv ? jcp.oc_block : jcp.oc * jcp.ngroups;
int l_pad = jcp.l_pad;
int r_pad = nstl::max(0, (int(jcp.ow) - 1) * str_w + (kw - 1) * dilate_w
push(reg_output_base);
push(reg_kernel_base);
push(reg_scales_base);
+ push(reg_oc_off);
if (l_pad > 0) {
n_oi--;
else
width_blk_step(ur_w, l_pad, 0, oc_blocks, oc_step); // "lpad"
add(reg_input, jcp.typesize_in * (ur_w * str_w - l_pad) * inp_mult);
- add(reg_output, jcp.typesize_out * ur_w * jcp.oc * jcp.ngroups);
+ add(reg_output, jcp.typesize_out * ur_w * out_mult);
}
Label ow_loop_label;
width_blk_step(ur_w, 0, 0, oc_blocks, oc_step); // "middle"
add(reg_input, jcp.typesize_in * ur_w * str_w * inp_mult);
- add(reg_output, jcp.typesize_out * ur_w * jcp.oc * jcp.ngroups);
+ add(reg_output, jcp.typesize_out * ur_w * out_mult);
inc(reg_oi_iter);
cmp(reg_oi_iter, n_oi);
if (r_pad1 > 0 && n_oi >=0) {
width_blk_step(ur_w, 0, r_pad1, oc_blocks, oc_step); // "rpad"
add(reg_input, jcp.typesize_in * ur_w * str_w * inp_mult);
- add(reg_output, jcp.typesize_out * ur_w * jcp.oc * jcp.ngroups);
+ add(reg_output, jcp.typesize_out * ur_w * out_mult);
}
if (ur_w_tail != 0)
width_blk_step(ur_w_tail, 0, r_pad, oc_blocks, oc_step); // "tail"
+ pop(reg_oc_off);
pop(reg_scales_base);
pop(reg_kernel_base);
pop(reg_output_base);
template <cpu_isa_t isa>
void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::generate()
{
+ const auto &p = attr_.post_ops_;
+ int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
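+ // Create one injector per eltwise/depthwise post op so that their code and
+ // lookup tables (emitted after the postamble) become part of this kernel.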
+ for (int i = 0; i < end_idx; i++) {
+ auto &post_op = p.entry_[i];
+ if (post_op.is_eltwise()) {
+ eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
+ this,
+ post_op.eltwise.alg,
+ post_op.eltwise.alpha,
+ post_op.eltwise.beta
+ ));
+ } else if (post_op.is_depthwise()) {
+ depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>(
+ this,
+ post_op.depthwise.alg
+ ));
+ }
+ }
+
this->preamble();
mov(reg_kernel_base, ptr[this->param1 + GET_OFF(filt)]);
mov(reg_input_base, ptr[this->param1 + GET_OFF(src)]);
mov(reg_output_base, ptr[this->param1 + GET_OFF(dst)]);
- mov(reg_oc, ptr[this->param1 + GET_OFF(oc_work)]);
+ mov(reg_oc_work, ptr[this->param1 + GET_OFF(oc_work)]);
if (jcp.with_bias)
mov(reg_bias_base, ptr[this->param1 + GET_OFF(bias)]);
mov(reg_scales_base, ptr[this->param1 + GET_OFF(scales)]);
if (jcp.signed_input)
mov(reg_compensation_base, ptr[param1 + GET_OFF(compensation)]);
+ mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);
Label main_loop_label;
Label tail_label;
Label exit_label;
- cmp(reg_oc, jcp.nb_oc_blocking * jcp.oc_block);
+ cmp(reg_oc_work, jcp.nb_oc_blocking * jcp.oc_block);
jne(main_loop_label, T_NEAR);
solve_common(jcp.nb_oc_blocking, jcp.oc_block);
- sub(reg_oc, jcp.nb_oc_blocking * jcp.oc_block);
+ sub(reg_oc_work, jcp.nb_oc_blocking * jcp.oc_block);
jmp(exit_label, T_NEAR);
L(main_loop_label); {
- cmp(reg_oc, jcp.oc_block);
+ cmp(reg_oc_work, jcp.oc_block);
jl(tail_label, T_NEAR);
solve_common(1, jcp.oc_block);
- sub(reg_oc, jcp.oc_block);
+ sub(reg_oc_work, jcp.oc_block);
add(reg_kernel_base, jcp.oc_block * jcp.nb_ic * jcp.kh * jcp.kw * jcp.ic_block * jcp.typesize_in);
- add(reg_output_base, jcp.oc_block * jcp.typesize_out);
+ if (jcp.with_dw_conv)
+ add(reg_output_base, jcp.oc_block * jcp_dw.kh * jcp.ow * jcp.typesize_out);
+ else
+ add(reg_output_base, jcp.oc_block * jcp.typesize_out);
add(reg_bias_base, jcp.oc_block * jcp.typesize_bia);
add(reg_scales_base, jcp.is_oc_scale * jcp.oc_block * sizeof(float));
add(reg_compensation_base, jcp.oc_block * sizeof(int32_t));
+ add(reg_oc_off, jcp.oc_block * sizeof(float));
jmp(main_loop_label, T_NEAR);
}
L(tail_label);
- solve_common(1, jcp.oc % jcp.oc_block);
+ if (jcp.oc % jcp.oc_block != 0)
+ solve_common(1, jcp.oc % jcp.oc_block);
L(exit_label);
this->postamble();
prepare_table();
+
+ for (auto& inj : eltwise_injectors)
+ inj->prepare_table();
}
template <cpu_isa_t isa>
dd(cvals_sum_scale[i]);
}
}
-
- for (size_t i = 0; i < sizeof(cvals_shift) / sizeof(cvals_shift[0]); ++i) {
- for (size_t d = 0; d < vlen / sizeof(int8_t); ++d) {
- if ((int)d < jcp.ic % jcp.ic_block)
- db(255);
- else
- db(0);
- }
- }
}
template <cpu_isa_t isa>
bool jit_uni_x8s8s32x_conv_fwd_kernel<isa>::post_ops_ok(
jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
- using namespace primitive_kind;
const auto &p = attr.post_ops_;
- auto is_relu = [&](int idx) {
- return p.entry_[idx].kind == eltwise
- && p.entry_[idx].eltwise.scale == 1.
- && p.entry_[idx].eltwise.alg == alg_kind::eltwise_relu
- && p.entry_[idx].eltwise.alpha == 0.;
- };
+ auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+ auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
+ auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(false); };
+ auto is_dw_conv = [&](int idx) { return p.entry_[idx].is_dw_conv(); };
+ auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
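+ // A chain may contain at most one sum or one fused dw conv, combined with
+ // eltwise/depthwise ("simple") post ops in the orders enumerated below.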
switch (p.len_) {
case 0: return true;
- case 1: return true
- && IMPLICATION(jcp.with_eltwise, p.contain(sum, 0))
- && IMPLICATION(!jcp.with_eltwise, is_relu(0) || p.contain(sum, 0));
- case 2: return true
- && IMPLICATION(jcp.with_eltwise, p.contain(sum, 0) && is_relu(1))
- && IMPLICATION(!jcp.with_eltwise, false
- || (p.contain(sum, 0) && is_relu(1))
- || (p.contain(sum, 1) && is_relu(0)));
- case 3: return true
- && jcp.with_eltwise == false
- && (is_relu(0) && p.contain(sum, 1) && is_relu(2));
+ case 1: return is_simple(0) || is_sum(0) || is_dw_conv(0);
+ case 2: return (is_sum(0) && is_simple(1)) || (is_simple(0) && is_sum(1)) ||
+ (is_dw_conv(0) && is_simple(1)) || (is_simple(0) && is_dw_conv(1)) ||
+ (is_simple(0) && is_simple(1));
+ case 3: return (is_simple(0) && is_sum(1) && is_simple(2)) ||
+ (is_simple(0) && is_dw_conv(1) && is_simple(2)) ||
+ (is_dw_conv(0) && is_simple(1) && is_simple(2));
+ case 4: return (is_simple(0) && is_dw_conv(1) && is_simple(2) && is_simple(3));
default: return false;
}
const convolution_desc_t &cd, cpu_memory_t::pd_t &src_pd,
cpu_memory_t::pd_t &weights_pd, cpu_memory_t::pd_t &dst_pd,
cpu_memory_t::pd_t &bias_pd,
- const primitive_attr_t &attr, bool with_relu, float relu_negative_slope)
+ const primitive_attr_t &attr)
{
if (!mayiuse(isa)) return status::unimplemented;
jcp.src_fmt = src_d.format();
jcp.with_bias = cd.bias_desc.format != memory_format::undef;
- jcp.with_eltwise = with_relu;
- jcp.eltwise_alpha = relu_negative_slope;
jcp.signed_input = src_d.data_type() == data_type::s8;
jcp.oc_padded = rnd_up(jcp.oc, jcp.oc_block);
jcp.nb_oc = div_up(jcp.oc, jcp.oc_block);
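+ // Grouped convolutions require both ic and oc of each group to be multiples
+ // of the channel block; channel tails are handled only for ngroups == 1.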
+ if (jcp.ngroups != 1) {
+ if (jcp.ic % jcp.ic_block != 0 || jcp.oc % jcp.oc_block != 0)
+ return status::unimplemented;
+ }
+
if (!post_ops_ok(jcp, attr))
return status::unimplemented;
const auto &p = attr.post_ops_;
- jcp.with_sum = p.find(primitive_kind::sum) != -1;
- if (!jcp.with_eltwise) {
- jcp.with_eltwise = p.find(primitive_kind::eltwise) != -1;
- jcp.eltwise_alpha = 0.f;
+
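+ // A fused depthwise conv consumes this kernel's output: remember the final
+ // output dims in dw_conv_oh/ow and compute at the fused conv's input size.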
+ int dw_conv_ind = p.find(primitive_kind::convolution);
+ jcp.with_dw_conv = dw_conv_ind != -1;
+ if (jcp.with_dw_conv) {
+ jcp.dw_conv_oh = jcp.oh;
+ jcp.dw_conv_ow = jcp.ow;
+ jcp.oh = p.entry_[dw_conv_ind].dw_conv.in_h;
+ jcp.ow = p.entry_[dw_conv_ind].dw_conv.in_w;
}
auto desired_act_fmt = nhwc;
return status::unimplemented;
}
+ jcp.src_dt = cd.src_desc.data_type;
jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
jcp.dst_dt = cd.dst_desc.data_type;
assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
jcp.ur_h = 1; /* no code-unrolling by h so far */
- jcp.ur_w = isa == avx2 ? 3 : 2;
- jcp.nb_oc_blocking = 2;
- if (jcp.nb_oc % jcp.nb_oc_blocking != 0) jcp.nb_oc_blocking = 1;
+ jcp.ur_w = isa == avx2 ? 4 : 2;
+ jcp.nb_oc_blocking = nstl::min(2, jcp.nb_oc);
+ jcp.max_regs_ur = 12;
+
+ // Workaround to prevent falling back to the gemm implementation
+ if (isa == sse42 && jcp.ic == 3) {
+ jcp.ur_w = 4;
+ jcp.nb_oc_blocking = 1;
+ }
if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
jcp.ur_w_tail = jcp.ow % jcp.ur_w;
int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
+ (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
+ if (r_pad_no_tail > jcp.ur_w)
+ return status::unimplemented;
- if (r_pad_no_tail > jcp.ur_w) {
- /* recalculate ur_w, nb_oc_blocking and ur_w_tail */
- jcp.ur_w = r_pad_no_tail + 1;
- jcp.ur_w_tail = jcp.ow % jcp.ur_w;
- /* check again ... */
- r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
- + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
- if ((r_pad_no_tail > jcp.ur_w) || (jcp.ow < jcp.ur_w))
- return status::unimplemented;
- }
- if (jcp.l_pad > jcp.ur_w) return status::unimplemented;
+ if (jcp.l_pad > jcp.ur_w)
+ return status::unimplemented;
jcp.wei_adj_scale = (jcp.signed_input) ? (1.0f / 2.0f) : 1.0f;
return status::success;
}
+template <cpu_isa_t isa>
+void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw,
+ const primitive_attr_t &attr) {
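+ // Book scratchpad for: padded bias/compensation when oc has a tail,
+ // adjusted output scales for signed input, and a per-thread row buffer
+ // that carries intermediate rows to the fused depthwise conv.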
+ if (jcp.oc != jcp.oc_padded)
+ scratchpad.book(key_conv_padded_bias, (size_t)jcp.typesize_bia * jcp.oc_padded);
+
+ if (jcp.signed_input) {
+ size_t count = nstl::max(attr.output_scales_.count_, 8);
+ scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count);
+
+ if (jcp.oc != jcp.oc_padded)
+ scratchpad.book(key_conv_padded_compensation, sizeof(int32_t) * jcp.oc_padded);
+ }
+
+ if (jcp.with_dw_conv) {
+ const int nthreads = mkldnn_get_max_threads();
+ size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * jcp.nb_oc_blocking;
+ scratchpad.book(key_dw_conv_buffer, jcp_dw.typesize_in * dw_conv_buffer_size_ * nthreads);
+
+ if (jcp.oc != jcp.oc_padded)
+ scratchpad.book(key_dw_conv_padded_bias, (size_t)jcp_dw.typesize_bia * jcp.oc_padded);
+ }
+}
+
template struct jit_uni_x8s8s32x_conv_fwd_kernel<avx2>;
template struct jit_uni_x8s8s32x_conv_fwd_kernel<sse42>;