using namespace mkldnn::impl::prop_kind;
using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;
using namespace Xbyak;
int depthwise_inj_idx = 0;
const auto &p = attr_.post_ops_;
- if (p.len_ == 0 && eltwise_injectors.size() == 1) {
- int start_idx = get_acc_reg(0).getIdx();
- int end_idx = get_acc_reg(repeats * ur_w * ur_ch_blocks).getIdx();
-
- eltwise_injectors[0]->compute_vector_range(start_idx, end_idx);
- }
-
for (int i = 0; i < p.len_; i++) {
auto& post_op = p.entry_[i];
if (post_op.is_eltwise()) {
}
template <cpu_isa_t isa>
-void jit_uni_dw_conv_fwd_kernel_f32<isa>::generate()
-{
- if (jcp.with_eltwise) {
- eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
- this, jcp.eltwise_alg, jcp.eltwise_alpha, 0
- ));
- }
-
+void jit_uni_dw_conv_fwd_kernel_f32<isa>::generate() {
const auto &p = attr_.post_ops_;
for (int i = 0; i < p.len_; i++) {
auto &post_op = p.entry_[i];
auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
switch (p.len_) {
- case 0: return true; // no post_ops
- case 1: return true // sum OR eltwise OR deptwise
- && !jcp.with_eltwise && (is_simple(0) || is_sum(0));
- case 2: return true // sum->relu OR sum->depthwise OR eltwise->depthwise OR depthwise->depthwise
- && !jcp.with_eltwise && ((is_sum(0) && is_simple(1)) ||
- (is_simple(0) && is_simple(1)));
- case 3: return true // sum->eltwise->depthwise OR sum->depthwise->eltwise OR sum->depthwise->depthwise
- && !jcp.with_eltwise && ((is_sum(0) && is_simple(1) && is_simple(2)));
+ case 0: return true;
+ case 1: return is_simple(0) || is_sum(0);
+ case 2: return (is_sum(0) && is_simple(1)) || (is_simple(0) && is_simple(1));
+ case 3: return is_sum(0) && is_simple(1) && is_simple(2);
default: return false;
}
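+/* Illustrative examples of chains accepted by the switch above (a sketch
+ * derived from the cases, not an exhaustive list):
+ *   {}                        -> true  (case 0)
+ *   {sum} / {eltwise}         -> true  (case 1)
+ *   {sum, eltwise}            -> true  (case 2, sum first)
+ *   {eltwise, depthwise}      -> true  (case 2, two 'simple' ops)
+ *   {eltwise, sum}            -> false (sum is only allowed first)
+ *   {sum, eltwise, depthwise} -> true  (case 3)
+ */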
status_t jit_uni_dw_conv_fwd_kernel_f32<isa>::init_conf(jit_conv_conf_t &jcp,
const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
- const primitive_attr_t &attr, bool with_relu, float relu_negative_slope)
+ const primitive_attr_t &attr)
{
if (!mayiuse(isa)) return status::unimplemented;
jcp.src_fmt = src_d.format();
jcp.with_bias = cd.bias_desc.format != memory_format::undef;
- jcp.with_eltwise = with_relu;
- jcp.eltwise_alg = mkldnn_eltwise_relu;
- jcp.eltwise_alpha = relu_negative_slope;
if (!post_ops_ok(jcp, attr))
return status::unimplemented;
return status::success;
}
+template <cpu_isa_t isa>
+void jit_uni_dw_conv_fwd_kernel_f32<isa>::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+ if (jcp.with_bias && jcp.oc_without_padding != jcp.oc)
+ scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
+}
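+/* Worked example (illustrative; assumes output channels are padded up to
+ * the vector block): with 17 output channels and an 8-wide channel block,
+ * jcp.oc = 24 while jcp.oc_without_padding = 17, so 24 floats are booked
+ * under key_conv_padded_bias. */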
+
template struct jit_uni_dw_conv_fwd_kernel_f32<avx512_common>;
template struct jit_uni_dw_conv_fwd_kernel_f32<avx2>;
template struct jit_uni_dw_conv_fwd_kernel_f32<sse42>;
return status::success;
}
+template <cpu_isa_t isa>
+void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+ UNUSED(scratchpad);
+ UNUSED(jcp);
+}
+
template struct jit_uni_dw_conv_bwd_data_kernel_f32<avx512_common>;
template struct jit_uni_dw_conv_bwd_data_kernel_f32<avx2>;
template struct jit_uni_dw_conv_bwd_data_kernel_f32<sse42>;
int off_filter = (reg_set + i) * simd_w;
Vmm vmm_acc = get_acc_reg(reg_set + i);
uni_vmovups(vmm_acc,
- vmmword[tmp_reg_filter + off_filter * sizeof(float)]);
+ vmmword[reg_tmp_filter + off_filter * sizeof(float)]);
}
}
}
template <cpu_isa_t isa>
inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_step_unroll(
- int l_pad, int r_pad, int pad_offset, int ow_block) {
- const int pad = nstl::max(jcp.l_pad, jcp.r_pad);
- const int iw_overlap = jcp.iw + jcp.kw - 1 - jcp.l_pad - jcp.r_pad;
- const int unroll_w = nstl::min(jcp.ur_w, iw_overlap);
- const int right_border = iw_overlap - ow_block;
+ int unroll_w, int l_pad, int pad_offset, int ow_block) {
+
+ const int iw_block = ow_block * jcp.stride_w;
+ const int right_border = jcp.iw - iw_block;
+
+ const int cascade_input = nstl::min(jcp.stride_w, jcp.kw);
/* preamble count for the number of cascaded LOAD + FMA operations */
- const int input_preamble_count
- = nstl::max(jcp.kw - jcp.stride_w - l_pad, 0);
+ const int input_overlap = nstl::max(jcp.kw - l_pad, 0);
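+    /* e.g. (illustrative) kw = 3, stride_w = 1, l_pad = 1:
+     * cascade_input = min(1, 3) = 1 input load per unrolled step, and
+     * input_overlap = max(3 - 1, 0) = 2 initial loads at i_ur == 0. */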
/* LOAD initial input registers, then cascade LOADs and FMAs */
for (int r = 0; r < reg_repeats; ++r) {
- for (int i = 0; i < input_preamble_count; i++) {
- int off_input = ((i - pad_offset) * reg_repeats + r) * simd_w;
- Vmm vmm_input = get_input_reg((i + l_pad) * reg_repeats + r);
- uni_vmovups(vmm_input,
- ptr[tmp_reg_idx_input + off_input * sizeof(float)]);
- }
-
- for (int i = 0; i < unroll_w; ++i) {
- int off_output = (i * reg_repeats + r) * simd_w;
+ for (int i_ur = 0; i_ur < unroll_w; ++i_ur) {
+ int off_output = (i_ur * reg_repeats + r) * simd_w;
Vmm vmm_output = get_output_reg(r);
uni_vmovups(vmm_output,
- ptr[tmp_reg_idx_output + off_output * sizeof(float)]);
-
- int input_load_overlap = i * jcp.stride_w + input_preamble_count;
-
- /* Cascade 'input' loads for the corresponding FMAs */
- const int cascade_input = nstl::min(jcp.stride_w, jcp.kw);
- for (int c = 0; c < cascade_input; ++c) {
- int off_input
- = ((c + input_load_overlap - pad_offset) * reg_repeats
- + r)
- * simd_w;
- Vmm vmm_input = get_input_reg(
- ((c + input_load_overlap + l_pad) % jcp.kw)
- * reg_repeats
- + r);
- uni_vmovups(vmm_input,
- ptr[tmp_reg_idx_input + off_input * sizeof(float)]);
+ ptr[reg_tmp_output + off_output * sizeof(float)]);
+ if (i_ur == 0) {
+ for (int c = 0; c < input_overlap; ++c) {
+ int off_input
+ = ((c - pad_offset) * reg_repeats + r) * simd_w;
+ Vmm vmm_input
+ = get_input_reg((c % jcp.kw) * reg_repeats + r);
+ uni_vmovups(vmm_input,
+ ptr[reg_tmp_input + off_input * sizeof(float)]);
+ }
+ } else {
+ for (int c = 0; c < cascade_input; ++c) {
+ int overlap = (i_ur - 1) * jcp.stride_w + input_overlap;
+ int off_input
+ = ((overlap + c - pad_offset) * reg_repeats + r)
+ * simd_w;
+ Vmm vmm_input = get_input_reg(
+ ((overlap + c) % jcp.kw) * reg_repeats + r);
+ uni_vmovups(vmm_input,
+ ptr[reg_tmp_input + off_input * sizeof(float)]);
+ }
}
- for (int j = 0; j < jcp.kw; ++j) {
+ for (int i_kw = 0; i_kw < jcp.kw; ++i_kw) {
+ int io_overlap = i_kw + (i_ur * jcp.stride_w);
/* Don't apply FMAs that fall into the padded region */
- if (i + j < l_pad || i + j - pad >= right_border)
+ if (io_overlap - l_pad < 0
+ || io_overlap - jcp.l_pad >= right_border)
continue;
+
Vmm vmm_input = get_input_reg(
- ((i * jcp.stride_w + j) % jcp.kw) * reg_repeats + r);
- Vmm vmm_acc = get_acc_reg(j * reg_repeats + r);
+ ((io_overlap - l_pad) % jcp.kw) * reg_repeats + r);
+ Vmm vmm_acc = get_acc_reg(i_kw * reg_repeats + r);
Vmm vmm_aux = isa == sse42 ? get_aux_reg() : vmm_input;
- if( isa == sse42 ) uni_vmovups(vmm_aux, vmm_input);
+ if (isa == sse42)
+ uni_vmovups(vmm_aux, vmm_input);
uni_vfmadd231ps(vmm_acc, vmm_aux, vmm_output);
}
}
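+    /* Worked example of the padding guard above (a sketch, assuming
+     * kw = 3, stride_w = 1, l_pad = jcp.l_pad = 1, ow_block = 0): at
+     * i_ur = 0, i_kw = 0, io_overlap = 0 and io_overlap - l_pad = -1 < 0,
+     * so the FMA for the left-padded pixel is skipped; interior taps
+     * proceed normally. */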
for (int i = 0; i < unroll_w; ++i) {
Vmm vmm_bias = get_bias_reg(r);
int off_output = (i * reg_repeats + r) * simd_w;
- uni_vaddps(vmm_bias, vmm_bias,
- vmmword[tmp_reg_idx_output + off_output * sizeof(float)]);
+ if (isa == sse42) {
+                /* SSE42 needs an explicit unaligned load here: the
+                 * packed-add memory-operand form assumes 16-byte
+                 * alignment */
+ Vmm vmm_output = get_output_reg(1 + r);
+ uni_vmovups(vmm_output,
+ ptr[reg_tmp_output + off_output * sizeof(float)]);
+ uni_vaddps(vmm_bias, vmm_bias, vmm_output);
+ } else {
+ uni_vaddps(vmm_bias, vmm_bias,
+ vmmword[reg_tmp_output + off_output * sizeof(float)]);
+ }
}
}
}
for (int i = 0; i < jcp.kw; ++i) {
int off_filter = (i + reg_set) * simd_w;
Vmm vmm_acc = get_acc_reg(i + reg_set);
- uni_vmovups(vmmword[tmp_reg_filter + off_filter * sizeof(float)],
+ uni_vmovups(vmmword[reg_tmp_filter + off_filter * sizeof(float)],
vmm_acc);
}
}
}
template <cpu_isa_t isa>
-inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::create_h_bounds_table() {
- /* Bounds are stored on an 8-bit sized element.
- * XXX: potential issues if bounds exceed 255.
- */
- const bool handle_padding = (jcp.t_pad > 0) || (jcp.b_pad > 0);
- if (handle_padding) {
-
- /* Calculate how many 'h_start' bounds are needed */
- const int h_bounds_count = get_loop_bounds_count(
- nstl::max(jcp.t_pad, jcp.b_pad), jcp.oh, jcp.oh_blk_size);
-
- align(64);
- L(bound_start_table);
- /* Generate starting bounds for 'oh' loop. This value also determines
- * the overlap (computed as an address offset) between the output over
- * the input for that loop iteration. */
- for (int oh_block = 0; oh_block < h_bounds_count; ++oh_block) {
- for (int kh = 0; kh < jcp.kh; ++kh) {
- te_size start_bound = nstl::max(
- jcp.t_pad - oh_block * jcp.oh_blk_size - kh, 0);
- write_table(start_bound);
- }
- }
- /* Write offset count for 'input' address calculation. The offset for
- * the input address is conditioned by the 'h' padding intersection over
- * the output rows. */
- for (int kh = 1; kh < jcp.kh; ++kh) {
- te_size kh_accum_value = nstl::max(nstl::min(kh - jcp.t_pad, 1), 0);
- write_table(kh_accum_value);
- }
- /* Last value is not used for offset calculation, write 'nop'
- * equivalent*/
- write_table(0);
-
- /* Non-padded blocks always increment 'kh' dimension */
- for (int oh_block = 0; oh_block < h_bounds_count - 1; oh_block++) {
- for (int kh = 0; kh < jcp.kh; ++kh) {
- te_size kh_accum_value = 1;
- write_table(kh_accum_value);
- }
- }
-
- /* number of input elements that overlap over output */
- int ih_overlap = jcp.oh_blk_size + jcp.kh - 1 - jcp.t_pad - jcp.b_pad;
-
- /* End Bounds for 'oh' default to 'OH' or OH_BLOCK_SIZE, unless
- * the 'oh_block' is within the 'bottom_padding' region. */
- int oh_end_blk = 0;
- for (; oh_end_blk < h_bounds_count - 1; ++oh_end_blk) {
- for (int kh = 0; kh < jcp.kh; ++kh) {
- te_size end_bound = nstl::min((jcp.ih / jcp.stride_h)
- - jcp.oh_blk_size - oh_end_blk * jcp.oh_blk_size
- + ih_overlap + 1 - kh,
- jcp.oh_blk_size);
- write_table(end_bound);
- }
- }
- /* Write bounds for the special case of when 'oh_block' falls within the
- * 'bottom_paddin' region - this always executes since at least 1 row of
- * bounds should exist. */
- const int pad = nstl::max(jcp.b_pad, jcp.t_pad);
- ih_overlap
- = (jcp.ih / jcp.stride_h + jcp.kh - 1 - jcp.t_pad - jcp.b_pad);
- oh_end_blk = jcp.oh - jcp.oh_blk_size;
- for (int kh = 0; kh < jcp.kh; ++kh) {
- te_size end_bound = nstl::min(
- jcp.oh_blk_size, ih_overlap - oh_end_blk + pad - kh);
- write_table(end_bound);
- }
- }
-}
-
-template <cpu_isa_t isa>
-inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_bias_loop() {
-
+inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_bias_loop(
+ const int block_size) {
Label oh_label;
Label ow_blk_label;
- const int oh_block_size = jcp.oh_blk_size;
- const int ow_unroll = jcp.ur_w;
- const int ow_block_count = jcp.ow / ow_unroll;
+ const int unroll_w = nstl::min(block_size, jcp.ow);
+ const int unroll_w_trips = jcp.ow / unroll_w;
+ const int tail_w = jcp.ow > block_size ? jcp.ow % block_size : 0;
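+    /* e.g. (illustrative, with block_size = 15 as passed in from
+     * compute_ow_block_unroll) jcp.ow = 20: unroll_w = 15,
+     * unroll_w_trips = 1 and tail_w = 5, i.e. one full 15-wide pass per
+     * output row followed by a 5-wide tail pass. */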
+
const int ch_offset = jcp.ch_block;
- mov(tmp_reg_idx_output, reg_output_baddr);
+ mov(reg_oh, ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_index)]);
+ mov(reg_oh_worksize,
+ ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_count)]);
- xor_(iter_oh, iter_oh);
+ mov(reg_tmp_output, reg_output_baddr);
L(oh_label);
{
- xor_(iter_ow_blk, iter_ow_blk);
+ mov(iter_ow_blk, unroll_w_trips);
L(ow_blk_label);
{
- compute_bias_step_unroll(ow_unroll);
+ compute_bias_step_unroll(unroll_w);
+ add(reg_tmp_output, unroll_w * ch_offset * sizeof(float));
- add(tmp_reg_idx_output, ow_unroll * ch_offset * sizeof(float));
+ dec(iter_ow_blk);
+ cmp(iter_ow_blk, 0);
+ jg(ow_blk_label, T_NEAR);
+ }
- inc(iter_ow_blk);
- cmp(iter_ow_blk, ow_block_count);
- jl(ow_blk_label, T_NEAR);
+ if (tail_w > 0) {
+ compute_bias_step_unroll(tail_w);
+ add(reg_tmp_output, tail_w * ch_offset * sizeof(float));
}
- inc(iter_oh);
- cmp(iter_oh, oh_block_size);
+ inc(reg_oh);
+ cmp(reg_oh, reg_oh_worksize);
jl(oh_label, T_NEAR);
}
}
template <cpu_isa_t isa>
-inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_kh_loop(
- int l_pad, int r_pad, int pad_offset, bool first_iteration,
- int ow_block) {
+inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_zero_filter() {
- Label kh_label;
- Label oh_label;
- Label exit_innerloop_label;
- Label skip_load_acc;
+ const int ch_offset = jcp.ch_block;
- const int table_row_count = get_loop_bounds_count(
- nstl::max(jcp.t_pad, jcp.b_pad), jcp.oh, jcp.oh_blk_size);
- const int ih_table_off = 1 * table_row_count * jcp.kh * sizeof(te_size);
- const int end_bound_table_off
- = 2 * table_row_count * jcp.kh * sizeof(te_size);
+ Label kh_loop_label, skip_zeroing_label;
+
+ mov(reg_exec_flags,
+ ptr[this->param1 + offsetof(jit_dw_conv_call_s, exec_flags)]);
+ and_(reg_exec_flags, FLAG_ZERO_FILTER);
+ test(reg_exec_flags, reg_exec_flags);
+ je(skip_zeroing_label);
+
+ zero_filter();
+
+ mov(reg_tmp_filter, reg_filter_baddr);
+ mov(reg_kh, jcp.kh);
+ L(kh_loop_label);
+ {
+ store_filter();
+
+ add(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float));
+ dec(reg_kh);
+ cmp(reg_kh, 0);
+ jg(kh_loop_label);
+ }
+
+    /* Restore the pointers */
+ sub(reg_tmp_filter, jcp.kh * jcp.kw * ch_offset * sizeof(float));
+
+ L(skip_zeroing_label);
+}
+
+template <cpu_isa_t isa>
+inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_h_step(
+ int unroll_w, int l_pad, int pad_offset, int ow_block) {
const int ch_offset = jcp.ch_block;
- const bool handle_padding = (jcp.t_pad > 0) || (jcp.b_pad > 0);
+ Label kh_loop_label, skip_loop_label;
- mov(tmp_reg_filter, reg_filter_baddr);
- mov(tmp_reg_kh_input, reg_input_baddr);
- xor_(reg_tmp_off, reg_tmp_off);
+ cmp(reg_kh_count, 0);
+ je(skip_loop_label, T_NEAR);
- if (handle_padding) {
- mov(reg_bound_table_addr, bound_start_table);
+ mov(reg_kh, reg_kh_count);
+ L(kh_loop_label);
+ {
+ load_filter();
+ compute_ow_step_unroll(unroll_w, l_pad, pad_offset, ow_block);
+ store_filter();
- /* move to the row containing the indices for the current 'h' block */
- mov(reg_tmp_off, reg_table_idx);
- imul(reg_tmp_off, reg_tmp_off, jcp.kh * sizeof(unsigned char));
- add(reg_bound_table_addr, reg_tmp_off);
+ add(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float));
+ add(reg_tmp_input, jcp.iw * ch_offset * sizeof(float));
+ dec(reg_kh);
+ cmp(reg_kh, 0);
+ jg(kh_loop_label);
}
- xor_(iter_kh, iter_kh);
- L(kh_label);
+    /* Restore the pointers */
+ Label kh_comeback_label;
+ mov(reg_kh, reg_kh_count);
+ L(kh_comeback_label);
{
+ sub(reg_tmp_input, jcp.iw * ch_offset * sizeof(float));
+ sub(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float));
+ dec(reg_kh);
+ cmp(reg_kh, 0);
+ jg(kh_comeback_label, T_NEAR);
+ }
- mov(tmp_reg_idx_output, reg_output_baddr);
- mov(tmp_reg_idx_input, tmp_reg_kh_input);
+ L(skip_loop_label);
+}
- if (first_iteration) {
+template <cpu_isa_t isa>
+inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_h_loop(
+ int unroll_w, int l_pad, int pad_offset, int ow_block) {
- /* apply zero filter */
- zero_filter();
+ const size_t io_overlap = jcp.ih / jcp.stride_h < jcp.oh ?
+ jcp.ih / jcp.stride_h - 1 :
+ jcp.oh - jcp.b_pad - 1;
+ const int ch_offset = jcp.ch_block;
+ const int t_overlap_off = jcp.t_pad % jcp.stride_h == 0 ? jcp.stride_h : 1;
+ const int b_overlap_off = jcp.b_pad % jcp.stride_h == 0 ? jcp.stride_h : 1;
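+    /* e.g. (illustrative) ih = 10, stride_h = 2, oh = 5, kh = 3,
+     * t_pad = b_pad = 1: ih / stride_h = 5 is not < oh, so
+     * io_overlap = oh - b_pad - 1 = 3, and t_pad % stride_h = 1 gives
+     * t_overlap_off = b_overlap_off = 1. */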
- /* if zero_filter_flag is set to '1', load filter memory into
- * reg_accum */
- if (jcp.with_bias) {
- mov(reg_tmp_al, reg_exec_flag);
- and_(reg_tmp_al, FLAG_ZERO_FILTER);
- cmp(reg_tmp_al, 0);
- } else {
- /* none of the other flags are active, so we can use the
- * register directly */
- cmp(reg_exec_flag, 0);
- }
- je(skip_load_acc);
- load_filter();
- L(skip_load_acc);
+ Label tpad_loop_label, h_loop_label, skip_tpad_label, skip_bpad_label,
+ end_h_loop_label;
- } else {
- load_filter();
- }
+ mov(reg_oh, ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_index)]);
+ mov(reg_oh_worksize,
+ ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_count)]);
+ mov(reg_kh_count,
+ ptr[this->param1 + offsetof(jit_dw_conv_call_s, kh_count)]);
- xor_(iter_oh, iter_oh);
+ mov(reg_tmp_output, reg_output_baddr);
+ mov(reg_tmp_input, reg_input_baddr);
+ mov(reg_tmp_filter, reg_filter_baddr);
- if (handle_padding) {
+ L(h_loop_label);
+ {
- /* 'oh loop' initial bounds are stored in bound_table */
- mov(iter_oh_lb, byte[reg_bound_table_addr]);
+ compute_h_step(unroll_w, l_pad, pad_offset, ow_block);
- /* skip 'oh' row that intersects with top padding */
- xor_(reg_tmp_off, reg_tmp_off);
- mov(reg_tmp_off, iter_oh);
- imul(reg_tmp_off, reg_tmp_off, jcp.ow * ch_offset * sizeof(float));
- add(tmp_reg_idx_output, reg_tmp_off);
+ add(reg_tmp_output, jcp.ow * ch_offset * sizeof(float));
- /* forward the input address by 'stride_h' */
- if (jcp.stride_h > 1) {
- xor_(reg_tmp_off, reg_tmp_off);
- mov(reg_tmp_off, iter_oh);
- imul(reg_tmp_off, reg_tmp_off,
- (jcp.stride_h - 1) * jcp.iw * ch_offset * sizeof(float));
- add(tmp_reg_idx_input, reg_tmp_off);
- }
- }
-
- L(oh_label);
- {
+ /* If within the top_pad region */
+ if (jcp.t_pad > 0) {
+ /* Skip t_pad area if no longer in initial h_block */
+ cmp(reg_oh, jcp.t_pad);
+ jg(skip_tpad_label, T_NEAR);
- compute_ow_step_unroll(l_pad, r_pad, pad_offset, ow_block);
+ cmp(reg_kh_count, jcp.kh);
+ jge(skip_tpad_label, T_NEAR);
- add(tmp_reg_idx_input,
- jcp.stride_h * jcp.iw * ch_offset * sizeof(float));
- add(tmp_reg_idx_output, jcp.ow * ch_offset * sizeof(float));
+ add(reg_kh_count, t_overlap_off);
+ sub(reg_tmp_filter,
+ t_overlap_off * jcp.kw * ch_offset * sizeof(float));
- inc(iter_oh);
- if (handle_padding) {
- /* 'oh loop' end bounds are stored in bound_table (precomputed
- * during JIT generation) */
- cmp(iter_oh_lb,
- byte[reg_bound_table_addr + end_bound_table_off]);
- } else {
- cmp(iter_oh, jcp.oh_blk_size);
+ /* kernel has moved beyond padding (adjust for stride effects) */
+ if (jcp.t_pad % jcp.stride_h != 0) {
+ int inp_corr = jcp.stride_h - jcp.t_pad % jcp.stride_h;
+ add(reg_tmp_input,
+ inp_corr * jcp.iw * ch_offset * sizeof(float));
}
- jl(oh_label, T_NEAR);
+ jmp(tpad_loop_label, T_NEAR);
}
- store_filter();
+ L(skip_tpad_label);
- add(tmp_reg_filter, jcp.kw * ch_offset * sizeof(float));
+ cmp(reg_oh, io_overlap);
+ jl(skip_bpad_label, T_NEAR);
+ sub(reg_kh_count, b_overlap_off);
- if (handle_padding) {
- xor_(kh_offset, kh_offset);
- mov(kh_offset_lb, byte[reg_bound_table_addr + ih_table_off]);
- /* increase 'ih' row in regards to 'kh'. */
- imul(kh_offset, kh_offset, jcp.iw * ch_offset * sizeof(float));
- add(tmp_reg_kh_input, kh_offset);
+ L(skip_bpad_label);
+ add(reg_tmp_input, jcp.stride_h * jcp.iw * ch_offset * sizeof(float));
- /* increase bound_table idx for the next 'kh' value in table*/
- add(reg_bound_table_addr, sizeof(te_size));
- } else {
- add(tmp_reg_kh_input, jcp.iw * ch_offset * sizeof(float));
- }
+ L(tpad_loop_label);
+
+ cmp(reg_oh, jcp.ih / jcp.stride_h);
+ jge(end_h_loop_label, T_NEAR);
- inc(iter_kh);
- cmp(iter_kh, jcp.kh);
- jl(kh_label, T_NEAR);
+ inc(reg_oh);
+
+ cmp(reg_oh, reg_oh_worksize);
+ jl(h_loop_label, T_NEAR);
}
+ L(end_h_loop_label);
}
template <cpu_isa_t isa>
inline void
jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_block_unroll() {
- Label skip_load_bias;
-
- /* Only apply zero_filter (xor'ing accum_reg) on the left edge */
- bool zero_filter_1st_iter = true;
-
const int ch_offset = jcp.ch_block;
-
- const int ow_block_size = jcp.ow_blk_size;
- const int iw_block_size = jcp.ow_blk_size * jcp.stride_w;
-
- int w_unrolled_loop_count = jcp.ow / ow_block_size;
-
- const bool handle_padding = (jcp.l_pad > 0) || (jcp.r_pad > 0);
-
- int pad_offset = jcp.l_pad;
-
- int ow_block = 0;
-
+ int ow = jcp.ow;
+ int pad_offset = 0;
+ int l_pad = jcp.l_pad;
+
+ /* Calculate effective padding */
+ int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
+ + (jcp.kw - 1) * (jcp.dilate_w + 1)
+ - (jcp.iw + jcp.l_pad - 1));
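+    /* e.g. (illustrative) iw = 10, kw = 3, stride_w = 1, dilate_w = 0,
+     * l_pad = 1, ow = 10: r_pad = max(0, 9 + 2 - 10) = 1. */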
+
+    /* The unroll bound below is empirical; it is presumably constrained by
+     * the generated code size and/or the reach of the immediate address
+     * offsets used by the unrolled loads. */
+ const int max_unroll_w = 30;
+ const int block_size = 15;
+
+ int unroll_w_tail = 0;
+ int unroll_w = 0;
+ int unroll_w_trips = 0;
+
+ if (jcp.ow > max_unroll_w) {
+ unroll_w = nstl::min(block_size, jcp.ow);
+ unroll_w_trips = ow / unroll_w;
+ /* calculate tail */
+ unroll_w_tail = ow % unroll_w;
+        /* Perform some rebalancing if the tail is too small */
+ if ((unroll_w_tail == 0 && r_pad != 0)
+ || (r_pad > 0 && r_pad >= unroll_w_tail)) {
+ if (unroll_w_trips > 1) {
+ unroll_w_tail += unroll_w;
+ unroll_w_trips--;
+ } else {
+                /* Ideally, this case shouldn't happen */
+ unroll_w_tail += (unroll_w - unroll_w / 2);
+ unroll_w = unroll_w / 2;
+ }
+ }
+ } else {
+ unroll_w = jcp.ow;
+ unroll_w_trips = nstl::max(1, ow / unroll_w);
+ }
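+    /* e.g. (illustrative) jcp.ow = 32, r_pad = 2: unroll_w = 15,
+     * unroll_w_trips = 2, unroll_w_tail = 2; since r_pad >= unroll_w_tail,
+     * one trip is folded into the tail, giving trips = 1 and tail = 17, so
+     * the right-padded region is handled entirely by the tail kernel. */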
if (jcp.with_bias) {
+ Label skip_load_bias;
+ mov(reg_bias_baddr,
+ ptr[this->param1 + offsetof(jit_dw_conv_call_s, bias)]);
zero_bias();
- /* if zero_bias is '1', load bias accumulator from memory. This happens
- * after the first iteration is executed */
- mov(reg_tmp_al, reg_exec_flag);
- and_(reg_tmp_al, FLAG_ZERO_BIAS);
- cmp(reg_tmp_al, 0);
- je(skip_load_bias);
+ mov(reg_exec_flags,
+ ptr[this->param1 + offsetof(jit_dw_conv_call_s, exec_flags)]);
+ and_(reg_exec_flags, FLAG_ZERO_BIAS);
+ test(reg_exec_flags, reg_exec_flags);
+ jne(skip_load_bias);
+
load_bias();
- L(skip_load_bias);
- compute_bias_loop();
+ L(skip_load_bias);
+ compute_bias_loop(block_size);
store_bias();
}
- /* compute left padded block */
- if (handle_padding) {
-
- const int r_pad = jcp.iw - ow_block_size > 0 ? 0 : jcp.r_pad;
-
- compute_kh_loop(jcp.l_pad, r_pad, 0, zero_filter_1st_iter, ow_block);
- zero_filter_1st_iter = false;
+    /* Zero-initialize the filter if requested, then apply the 'h'-padding
+     * offset to the filter address. */
+ compute_zero_filter();
+ mov(reg_kh_offset,
+ ptr[this->param1 + offsetof(jit_dw_conv_call_s, filter_pad_off)]);
+ add(reg_filter_baddr, reg_kh_offset);
- w_unrolled_loop_count--;
-
- if (w_unrolled_loop_count >= 1) {
- add(reg_output_baddr, ow_block_size * ch_offset * sizeof(float));
- add(reg_input_baddr, iw_block_size * ch_offset * sizeof(float));
- }
+ /* compute left padded block */
+ if (l_pad) {
+ compute_h_loop(unroll_w, l_pad, 0, 0);
+ add(reg_output_baddr, unroll_w * ch_offset * sizeof(float));
+ add(reg_input_baddr,
+ unroll_w * jcp.stride_w * ch_offset * sizeof(float));
+ unroll_w_trips--;
+ pad_offset = l_pad;
+ l_pad = 0;
}
- /* This block may execute under 2 different scenarios:
- * 1) When padding is present, this executes the middle loop (if any).
- * 2) With no padding, it writes the full loop of the micro-kernel. */
- int middle_loop_count = handle_padding ? w_unrolled_loop_count - 1 :
- w_unrolled_loop_count;
- if (middle_loop_count >= 1) {
- Label ow_blk_label;
-
- /* Insert loop for 'ow' block when middle block needs to execute more
- * than once */
- bool do_ow_blk_loop = middle_loop_count > 1;
- if (do_ow_blk_loop) {
- mov(iter_ow_blk, middle_loop_count);
- L(ow_blk_label);
- }
-
- compute_kh_loop(0, 0, pad_offset, zero_filter_1st_iter);
- /* disable zero_filter for the rest of the iterations i.e. from now on
- * load contents of 'filter' from memory */
- mov(reg_exec_flag, FLAG_ZERO_FILTER);
-
- if (do_ow_blk_loop || handle_padding) {
- add(reg_output_baddr, ow_block_size * ch_offset * sizeof(float));
- add(reg_input_baddr, iw_block_size * ch_offset * sizeof(float));
- }
-
- if (do_ow_blk_loop) {
- dec(iter_ow_blk);
- cmp(iter_ow_blk, 0);
- jg(ow_blk_label, T_NEAR);
- }
+ /* compute middle block */
+ Label ow_blk_label;
- w_unrolled_loop_count -= middle_loop_count;
+ /* Insert loop for 'ow' block when middle block needs to execute more
+ * than once */
+ bool do_ow_blk_loop = unroll_w_trips > 1;
+ if (do_ow_blk_loop) {
+ mov(iter_ow_blk, unroll_w_trips);
+ L(ow_blk_label);
+ }
+ if (unroll_w_trips > 0) {
+ compute_h_loop(unroll_w, l_pad, pad_offset, 0);
+ add(reg_output_baddr, unroll_w * ch_offset * sizeof(float));
+ add(reg_input_baddr,
+ unroll_w * jcp.stride_w * ch_offset * sizeof(float));
+ }
+ if (do_ow_blk_loop) {
+ dec(iter_ow_blk);
+ cmp(iter_ow_blk, 0);
+ jg(ow_blk_label, T_NEAR);
}
- /* compute right padded block: ow_blk = LAST */
- if (handle_padding && w_unrolled_loop_count >= 1) {
- ow_block = jcp.ow - ow_block_size;
- compute_kh_loop(
- 0, jcp.r_pad, pad_offset, zero_filter_1st_iter, ow_block);
-
- w_unrolled_loop_count--;
+ /* compute right padded block */
+ if (unroll_w_tail) {
+ compute_h_loop(unroll_w_tail, 0, pad_offset, jcp.ow - unroll_w_tail);
}
}
ptr[this->param1 + offsetof(jit_dw_conv_call_s, output)]);
mov(reg_filter_baddr,
ptr[this->param1 + offsetof(jit_dw_conv_call_s, filter)]);
- if (jcp.with_bias)
- mov(reg_bias_baddr,
- ptr[this->param1 + offsetof(jit_dw_conv_call_s, bias)]);
- mov(reg_table_flags,
- ptr[this->param1 + offsetof(jit_dw_conv_call_s, table_flags)]);
compute_ow_block_unroll();
this->postamble();
-
- create_h_bounds_table();
}
template <cpu_isa_t isa>
jit_conv_conf_t &jcp, const convolution_desc_t &cd,
const memory_desc_wrapper &src_d,
const memory_desc_wrapper &diff_weights_d,
- const memory_desc_wrapper &diff_dst_d) {
-
+ const memory_desc_wrapper &diff_dst_d, int nthreads) {
if (!mayiuse(isa))
return status::unimplemented;
jcp.stride_w = cd.strides[1];
jcp.t_pad = cd.padding[0][0];
- /* bottom padding should equal top padding to generate the proper 'h' loop
- * bounds. */
jcp.b_pad = cd.padding[1][0];
jcp.l_pad = cd.padding[0][1];
auto desired_act_fmt = isa == avx512_common ? nChw16c : nChw8c;
auto desired_wei_fmt = isa == avx512_common ? Goihw16g : Goihw8g;
- bool args_ok = true
- && src_d.format() == desired_act_fmt
- && diff_weights_d.format() == desired_wei_fmt
- && diff_dst_d.format() == desired_act_fmt
- && one_of(cd.bias_desc.format, memory_format::undef, any, x)
- //&& jcp.ngroups % simd_w == 0
- && jcp.ngroups % jcp.ch_block == 0
- && jcp.dilate_h == 0
- && jcp.dilate_w == 0
- && jcp.kw <= 3
- && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
- && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1;
- if (!args_ok) return status::unimplemented;
-
- /* Note: this IMPLICATION-check does not allow 'negative padding' execution
- */
- bool ok = true && IMPLICATION(jcp.r_pad > 0, jcp.r_pad == jcp.l_pad)
- && IMPLICATION(jcp.b_pad > 0, jcp.b_pad == jcp.t_pad);
- if (!ok)
+ bool args_ok = true && src_d.format() == desired_act_fmt
+ && diff_weights_d.format() == desired_wei_fmt
+ && diff_dst_d.format() == desired_act_fmt
+ && one_of(cd.bias_desc.format, memory_format::undef, any, x)
+ && jcp.ngroups % jcp.ch_block == 0 && jcp.dilate_h == 0
+ && jcp.dilate_w == 0 && jcp.kw <= 3
+ && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
+ && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1;
+ if (!args_ok)
return status::unimplemented;
jcp.nb_ch = jcp.ngroups / jcp.ch_block;
- /* Values for block size to try; order gives priority */
- constexpr int BLOCK_SIZE[] = { 14, 16, 7, 8 };
-
- int block_size_h = 1;
- int block_size_w = 1;
+    /* Kernel applicability check w.r.t. boundaries: the conditions are
+     * quite general across the kernels we have, but ideally the check
+     * should belong to a specific kernel... */
+ const int max_hpad = (jcp.kh - 1 + 1) / 2;
+ const int max_wpad = (jcp.kw - 1 + 1) / 2;
+ const bool boundaries_ok = true && jcp.t_pad <= max_hpad
+ && jcp.b_pad <= max_hpad && jcp.l_pad <= max_wpad
+ && jcp.r_pad <= max_wpad;
+ if (!boundaries_ok)
+ return status::unimplemented;
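+    /* e.g. (illustrative) kh = kw = 3: max_hpad = max_wpad = 1, so only
+     * 'SAME'-style single-pixel padding is accepted by this kernel. */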
- /* *Try different block sizes for convolution */
- for (int block : BLOCK_SIZE) {
+ balance(jcp, nthreads);
- block_size_h = block / jcp.stride_h;
- block_size_w = block / jcp.stride_w;
+ return status::success;
+}
- if ((jcp.oh % block_size_h == 0) && (jcp.ow % block_size_w == 0))
- break;
+template <cpu_isa_t isa>
+void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+    /* Note: if the thread work is split over 'mb', a reduction has to take
+     * place. Hence, book a per-thread, local weights buffer for the
+     * reduction. */
+ if (jcp.nthr_mb > 1) {
+ const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw;
+ scratchpad.book(key_conv_wei_reduction,
+ sizeof(float) * wei_size * (jcp.nthr_mb - 1));
+
+ if (jcp.with_bias)
+ scratchpad.book(key_conv_bia_reduction,
+ sizeof(float) * jcp.ngroups * (jcp.nthr_mb - 1));
}
+}
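+/* e.g. (illustrative) ngroups = 16, kh = kw = 3, nthr_mb = 2: wei_size
+ * = 144, so one extra 144-float weights buffer (and, with bias, one extra
+ * 16-float bias buffer) is booked for the reduction of the second
+ * mb-thread. */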
- if (jcp.oh % block_size_h != 0 || jcp.ow % block_size_w != 0)
- return status::unimplemented;
-
- jcp.oh_blk_size = block_size_h;
-
- jcp.ur_w = jcp.ow_blk_size = block_size_w;
+template <cpu_isa_t isa>
+void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::balance(jit_conv_conf_t &jcp,
+ int nthreads) {
+ jcp.nthr = nthreads;
+ jcp.nthr_g = jcp.nthr_mb = 1;
+
+    /* Basic heuristics for the parallel strategy:
+     * 1) Try to parallelize over the number of groups (g), where the tasks
+     *    are independent. Otherwise,
+     * 2) Try to split the work across g and the minibatch (mb).
+     * Parallelizing over mb requires computing a reduction for the weights.
+     *
+     * NOTE: because of the 'task partitioning' scheme, the per-thread load
+     * is unbalanced when the number of threads is high (e.g. > 16).
+     */
+ jcp.nthr_g = nstl::min(jcp.nb_ch, jcp.nthr);
+ jcp.nthr_mb = nstl::min(nstl::max(1, jcp.nthr / jcp.nthr_g), jcp.mb);
- return status::success;
+ jcp.nthr = jcp.nthr_g * jcp.nthr_mb;
}
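+/* A minimal standalone sketch of the heuristic above (assumes only the
+ * standard library; nstl::min/max mirror std::min/max):
+ *
+ *   #include <algorithm>
+ *   #include <utility>
+ *   // Returns {nthr_g, nthr_mb} for a given thread budget.
+ *   std::pair<int, int> balance_sketch(int nthr, int nb_ch, int mb) {
+ *       const int nthr_g = std::min(nb_ch, nthr);
+ *       const int nthr_mb = std::min(std::max(1, nthr / nthr_g), mb);
+ *       return {nthr_g, nthr_mb};
+ *   }
+ *
+ * e.g. nthr = 8, nb_ch = 4, mb = 16 -> {4, 2}: 4 group-threads times
+ * 2 mb-threads uses all 8 threads, with one weights reduction over mb. */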
template struct jit_uni_dw_conv_bwd_weights_kernel_f32<avx512_common>;