* limitations under the License.
*******************************************************************************/
-#include <common/primitive_attr.hpp>
#include "c_types_map.hpp"
#include "nstl.hpp"
#include "type_helpers.hpp"
using namespace mkldnn::impl::prop_kind;
using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;
using namespace Xbyak;
vfmadd231ps(Ymm(ur_w * ii + jj),
Ymm(oc_blocks * ur_w + jj), ymm15);
else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support
- Ymm tmp = ymask;
- vmulps(tmp, ymm15, Ymm(oc_blocks * ur_w + jj));
- vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), tmp);
+ vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj));
+ vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp);
}
}
}
vfmadd231ps(Ymm(ur_w * ii + jj),
Ymm(oc_blocks * ur_w + jj), ymm15);
else { // Intel AVX support
- Ymm tmp = ymask;
- vmulps(tmp, ymm15, Ymm(oc_blocks * ur_w + jj));
- vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), tmp);
+ vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj));
+ vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp);
}
}
}
for (int jj = 0; jj < ur_w; jj++) {
size_t offt;
if (jcp.with_dw_conv)
- offt = sizeof(float) * ((size_t)ii * od * jcp.dw_conv_ker_h * ow + jj) * oc_blk;
+ offt = sizeof(float) * ((size_t)ii * od * jcp_dw.kh * ow + jj) * oc_blk;
else
offt = sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk;
vmovups(Ymm(ur_w * ii + jj),
mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
mov(aux_reg_inp_d, reg_input);
- if ((jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) {
+ if ((jcp.dilate_d >= jcp.id)
+ || (jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) {
cmp(reg_ki, 0);
je(skip_kd_loop, T_NEAR);
}
mov(aux_reg_kernel, aux_reg_ker_d);
}
- if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
+ if ((jcp.dilate_h >= jcp.ih)
+ || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
cmp(kj, 0);
je(skip_kh_loop, T_NEAR);
}
pop(reg_output);
}
-
- Label done, regular_store;
+ Label regular_store;
test(reg_ci_flag, FLAG_IC_LAST);
je(regular_store, T_NEAR);
int depthwise_inj_idx = 0;
const auto &p = attr_.post_ops_;
- if (p.len_ == 0 && eltwise_injectors.size() == 1) {
- eltwise_injectors[0]->compute_vector_range(0, oc_blocks * ur_w);
- }
-
int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
for (int i = 0; i < end_idx; i++) {
auto& post_op = p.entry_[i];
for (int jj = 0; jj < ur_w; jj++) {
size_t o_off;
if (jcp.with_dw_conv)
- o_off = sizeof(float) * ((size_t)ii * od * jcp.dw_conv_ker_h * ow + jj) * oc_blk;
+ o_off = sizeof(float) * ((size_t)ii * od * jcp_dw.kh * ow + jj) * oc_blk;
else
o_off = sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk;
Ymm reg_out = Ymm(ur_w * ii + jj);
vmovups(make_safe_addr(reg_output, o_off, reg_long_offt), reg_out);
}
}
- L(done);
}
inline void jit_avx2_conv_fwd_kernel_f32::solve_common(
void jit_avx2_conv_fwd_kernel_f32::generate()
{
- if (jcp.with_eltwise) {
- eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx2>(
- this, jcp.eltwise_alg, jcp.eltwise_alpha, 0
- ));
- }
-
const auto &p = attr_.post_ops_;
int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
for (int i = 0; i < end_idx; i++) {
auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
switch (p.len_) {
- case 0: return true; // no post_ops
- case 1:
- return true // sum OR eltwise OR dw_conv
- && !jcp.with_eltwise && (is_simple(0) || is_sum(0) || is_dw_conv(0));
- case 2:
- return true // sum->eltwise OR dw_conv->eltwise OR eltwise->dw_conv OR dw_conv->sum OR sum->depthwise OR
- // eltwise->depthwise OR depthwise->depthwise
- && !jcp.with_eltwise && ((is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
- (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
- (is_simple(0) && is_simple(1)));
- case 3:
- return true // eltwise->dw_conv->eltwise OR dw_conv->sum->eltwise OR sum->eltwise->depthwise OR
- // sum->depthwise->eltwise OR sum->depthwise->depthwise
- && !jcp.with_eltwise && ((is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
- (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
- (is_sum(0) && is_simple(1) && is_simple(2)));
- case 4: return true // eltwise->dw_conv->sum->eltwise
- && !jcp.with_eltwise && (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
- default: return false;
+ case 0: return true;
+ case 1: return is_simple(0) || is_sum(0) || is_dw_conv(0);
+ case 2: return (is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
+ (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
+ (is_simple(0) && is_simple(1));
+ case 3: return (is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
+ (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
+ (is_sum(0) && is_simple(1) && is_simple(2));
+ case 4: return (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
+ default: return false;
}
return false;
status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
- const primitive_attr_t &attr, bool with_relu, float relu_negative_slope)
+ const primitive_attr_t &attr)
{
if (!mayiuse(avx)) return status::unimplemented;
jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
jcp.dilate_w = cd.dilates[ndims-3];
- jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
- - (jcp.ih + jcp.t_pad - 1);
-
jcp.src_fmt = src_d.format();
jcp.with_bias = cd.bias_desc.format != memory_format::undef;
- jcp.with_eltwise = with_relu;
- jcp.eltwise_alg = mkldnn_eltwise_relu;
- jcp.eltwise_alpha = relu_negative_slope;
if (!post_ops_ok(jcp, attr))
return status::unimplemented;
const auto &p = attr.post_ops_;
- jcp.with_dw_conv = false;
+
int dw_conv_ind = p.find(primitive_kind::convolution);
- if (dw_conv_ind != -1) {
- jcp.with_dw_conv = true;
- jcp.dw_conv_in_h = p.entry_[dw_conv_ind].dw_conv.in_h;
- jcp.dw_conv_in_w = p.entry_[dw_conv_ind].dw_conv.in_w;
- jcp.dw_conv_ker_h = p.entry_[dw_conv_ind].dw_conv.ker_h;
- jcp.dw_conv_ker_w = p.entry_[dw_conv_ind].dw_conv.ker_w;
- jcp.dw_conv_str_h = p.entry_[dw_conv_ind].dw_conv.str_h;
- jcp.dw_conv_str_w = p.entry_[dw_conv_ind].dw_conv.str_w;
- jcp.dw_conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
- jcp.dw_conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;
+ jcp.with_dw_conv = dw_conv_ind != -1;
+ if (jcp.with_dw_conv) {
+ jcp.dw_conv_oh = jcp.oh;
+ jcp.dw_conv_ow = jcp.ow;
+ jcp.oh = p.entry_[dw_conv_ind].dw_conv.in_h;
+ jcp.ow = p.entry_[dw_conv_ind].dw_conv.in_w;
}
+ jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
+ - (jcp.ih + jcp.t_pad - 1);
+
if (jcp.with_dw_conv && !mayiuse(avx2))
return status::unimplemented;
if (jcp.with_dw_conv && jcp.ndims == 5)
return status::unimplemented;
- if (jcp.with_dw_conv) {
- int dw_conv_eltwise_ind = p.find(primitive_kind::eltwise, dw_conv_ind);
- if (dw_conv_eltwise_ind != -1) {
- jcp.dw_conv_with_eltwise = true;
- jcp.dw_conv_eltwise_alg = p.entry_[dw_conv_eltwise_ind].eltwise.alg;
- jcp.dw_conv_eltwise_alpha = p.entry_[dw_conv_eltwise_ind].eltwise.alpha;
- jcp.dw_conv_eltwise_beta = p.entry_[dw_conv_eltwise_ind].eltwise.beta;
+ if (!mayiuse(avx2)) {
+ for (int i = 0; i < p.len_; i++) {
+ auto &post_op = p.entry_[i];
+ if (post_op.is_eltwise()) {
+ if (post_op.eltwise.alg != alg_kind::eltwise_relu)
+ return status::unimplemented;
+ } else if (post_op.is_depthwise()) {
+ return status::unimplemented;
+ }
}
}
jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
- if (jcp.with_dw_conv) {
- jcp.dw_conv_with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;
- }
- if (jcp.with_dw_conv) {
- jcp.oh = jcp.dw_conv_in_h;
- jcp.ow = jcp.dw_conv_in_w;
- }
+ jcp.src_dt = cd.src_desc.data_type;
+ jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
+ jcp.dst_dt = cd.dst_desc.data_type;
const int simd_w = 8;
const bool flat = jcp.ic < simd_w;
const bool mimo = !flat;
+
+ /* Grouped channel offset to support 'non-blocked data' format for
+ * convolution sizes with '(input_channel / ngroups) < simd' */
+ jcp.nonblk_group_off
+ = (one_of(src_d.format(), ncw, nchw, ncdhw) && jcp.ngroups > 1) ?
+ jcp.ic :
+ 1;
+
bool ok_to_pad_channels = true
&& jcp.ngroups == 1;
return status::success;
}
-void jit_avx2_conv_bwd_data_kernel_f32::hsw_iter(int ur_w, int l_overflow,
- int r_overflow, int start_off)
+void jit_avx2_conv_fwd_kernel_f32::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw) {
+ if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
+ scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
+
+ if (jcp.with_dw_conv) {
+ const int nthreads = mkldnn_get_max_threads();
+ size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * jcp.nb_oc_blocking;
+ scratchpad.book(key_dw_conv_buffer, sizeof(float) * dw_conv_buffer_size_ * nthreads);
+
+ if (jcp.oc != jcp.oc_without_padding)
+ scratchpad.book(key_dw_conv_padded_bias, sizeof(float) * jcp.oc);
+ }
+}
+
+void jit_avx2_conv_bwd_data_kernel_f32::compute_loop(int ur_w, int l_overflow,
+ int r_overflow)
{
int kw = jcp.kw;
int kh = jcp.kh;
int ih = jcp.ih;
int id = jcp.id;
int ow = jcp.ow;
- int stride_w = jcp.stride_w;
- int stride_h = jcp.stride_h;
int ic_block = jcp.ic_block;
int oc_block = jcp.oc_block;
int nb_ic_block = jcp.nb_ic_blocking;
+ int stride_w = jcp.stride_w;
+ int stride_h = jcp.stride_h;
Label kd_loop, skip_kd_loop;
+ Label oc_loop, skip_oc_loop;
for (int ii = 0; ii < nb_ic_block; ii++)
for (int jj = 0; jj < ur_w; jj++) {
- size_t offt = sizeof(float) * ((size_t)ii * id * ih * iw + jj)
- * ic_block;
- vmovups(Ymm(ur_w * ii + jj),
- make_safe_addr(reg_dsrc, offt, reg_long_offt));
+ uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj),
+ Ymm(ur_w * ii + jj));
}
if (one_of(jcp.ndims, 3, 4)) {
- mov(aux_reg_ddst, reg_ddst);
- mov(aux_reg_kernel, reg_kernel);
+ cmp(reg_channel_work, 0);
+ jle(skip_oc_loop, T_NEAR);
+ xor_(reg_channel, reg_channel);
+
+ mov(aux_reg_ddst_oc_loop, reg_ddst);
+ mov(aux_reg_kernel_oc_loop, reg_kernel);
+
+ L(oc_loop);
+ mov(aux_reg_ddst, aux_reg_ddst_oc_loop);
+ mov(aux_reg_kernel, aux_reg_kernel_oc_loop);
}
if (jcp.ndims == 5) {
+ assert(jcp.nb_oc_blocking == 1);
push(oi_iter);
mov(reg_ki, ptr[this->param1 + GET_OFF(kd_padding)]);
mov(aux_reg_kernel, aux_reg_ker_d);
}
- mov(kj, reg_kh);
-
- Label kh_label;
-
- L(kh_label); {
+ Label kh_loop, skip_kh_loop;
+ cmp(kj, 0);
+ jle(skip_kh_loop, T_NEAR);
+ L(kh_loop); {
for (int ki = 0; ki < kw; ki++) {
- int jj_start = nstl::max(0, l_overflow - (kw - 1) + ki) ; // 0;
- int jj_end = ur_w - nstl::max(0, r_overflow - ki); // ur_w;
+ int jj_start = get_iw_start(ki, l_overflow); // 0;
+ int jj_end = get_iw_end(ur_w, ki, r_overflow); // ur_w;
for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++) {
- for (int jj = jj_start; jj < jj_end; jj++) {
- if ((jj - ki + jcp.l_pad + start_off) % stride_w == 0) {
- int aux_output_offset = ((jj - ki + jcp.l_pad + start_off) / stride_w) * jcp.oc_block + ofm2;
- vbroadcastss(Ymm(nb_ic_block * ur_w + jj), ptr[aux_reg_ddst + sizeof(float) * aux_output_offset]);
- }
+ for (int jj = jj_start ; jj < jj_end; jj += stride_w) {
+ int aux_output_offset
+ = (jj + jcp.l_pad - ki) / stride_w * jcp.oc_block + ofm2;
+ vbroadcastss(Ymm(nb_ic_block * ur_w + jj / stride_w),
+ ptr[aux_reg_ddst
+ + sizeof(float) * aux_output_offset]);
}
- for (int ii = 0; ii < nb_ic_block; ii++) {
- int aux_kernel_offset = ii * kd * kh * kw * jcp.ic_block * jcp.oc_block + ki * jcp.ic_block * jcp.oc_block + ofm2 * jcp.ic_block;
- vmovups(ymm15, ptr[aux_reg_kernel + sizeof(float) * aux_kernel_offset]);
-
- for (int jj = jj_start; jj < jj_end; jj++) {
- if ((jj - ki + jcp.l_pad + start_off) % stride_w == 0) {
- vfmadd231ps(Ymm(ur_w * ii + jj), Ymm(nb_ic_block * ur_w + jj), ymm15);
- }
- }
+ for (int ii = 0; ii < nb_ic_block; ii++) {
+ int aux_kernel_offset
+ = ii * kd * kh * kw * jcp.ic_block * jcp.oc_block
+ + ki * jcp.ic_block * jcp.oc_block
+ + ofm2 * jcp.ic_block;
+ vmovups(ymm15,
+ ptr[aux_reg_kernel
+ + sizeof(float) * aux_kernel_offset]);
+ for (int jj = jj_start; jj < jj_end; jj += stride_w)
+ vfmadd231ps(Ymm(ur_w * ii + jj),
+ Ymm(nb_ic_block * ur_w + jj / stride_w), ymm15);
}
}
}
- add(aux_reg_kernel, sizeof(float) * kw * oc_block * ic_block * stride_h);
+ add(aux_reg_kernel, sizeof(float) * stride_h * kw * oc_block
+ * ic_block);
sub(aux_reg_ddst, sizeof(float) * ow * oc_block);
- sub(kj, stride_h);
+ dec(kj);
cmp(kj, 0);
- jg(kh_label, T_NEAR);
+ jg(kh_loop, T_NEAR);
}
+ L(skip_kh_loop);
if (jcp.ndims == 5) {
sub(aux_reg_dst_d,
pop(oi_iter);
}
+ if (one_of(jcp.ndims, 3, 4)) {
+ int ddst_oc_shift = sizeof(float) * jcp.od * jcp.oh * jcp.ow
+ * jcp.oc_block;
+ int kernel_oc_shift = sizeof(float) * jcp.kd * jcp.kh * jcp.kw
+ * jcp.ic * jcp.oc_block;
+
+ add(aux_reg_ddst_oc_loop, ddst_oc_shift);
+ add(aux_reg_kernel_oc_loop, kernel_oc_shift);
+
+ inc(reg_channel);
+ cmp(reg_channel, reg_channel_work);
+ jl(oc_loop, T_NEAR);
+
+ L(skip_oc_loop);
+ mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
+ }
+
+ Label no_update_label;
+ cmp(reg_channel, 0);
+ je(no_update_label, T_NEAR);
+ for (int ii = 0; ii < nb_ic_block; ii++) {
+ for (int jj = 0; jj < ur_w; jj++) {
+ size_t offt =
+ sizeof(float) * ((size_t)ii * id * ih * iw + jj) * ic_block;
+ vmovups(Ymm(15),
+ make_safe_addr(reg_dsrc, offt, reg_long_offt));
+ vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj),
+ Ymm(15));
+
+ }
+ }
+ L(no_update_label);
+
for (int ii = 0; ii < nb_ic_block; ii++)
for (int jj = 0; jj < ur_w; jj++) {
size_t offt =
void jit_avx2_conv_bwd_data_kernel_f32::generate() {
preamble();
- auto hsw_iter_body = [=] (int ur_w, int l_overflow, int r_overflow) {
- if (jcp.stride_w == 1) {
- hsw_iter(ur_w, l_overflow, r_overflow, 0);
- add(reg_dsrc, sizeof(float) * jcp.ur_w * jcp.ic_block);
- add(reg_ddst, sizeof(float) * jcp.ur_w * jcp.oc_block);
- } else {
- Label hsw_iter_off_0;
- Label hsw_iter_off_1;
- Label hsw_iter_exit;
-
- int dst_off = jcp.ur_w / jcp.stride_w;
-
- and_(start_off_reg, 1);
-
- L(hsw_iter_off_0); {
- cmp(start_off_reg, 0);
- jg(hsw_iter_off_1, T_NEAR);
-
- hsw_iter(ur_w, l_overflow, r_overflow, 0);
- add(reg_dsrc, sizeof(float) * jcp.ur_w * jcp.ic_block);
- add(reg_ddst, sizeof(float) * dst_off * jcp.oc_block);
-
- jmp(hsw_iter_exit, T_NEAR);
- }
-
- L(hsw_iter_off_1); {
- hsw_iter(ur_w, l_overflow, r_overflow, 1);
- add(reg_dsrc, sizeof(float) * jcp.ur_w * jcp.ic_block);
- add(reg_ddst, sizeof(float) * (dst_off + 1) * jcp.oc_block);
- }
-
- L(hsw_iter_exit);
- add(start_off_reg, std::abs(jcp.ur_w - jcp.stride_w));
- }
- };
-
mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]);
mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]);
mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
+ mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
+ mov(reg_channel_work, ptr[param1 + GET_OFF(ch_blocks)]);
- int n_oi = jcp.iw / jcp.ur_w;
- xor_(oi_iter, oi_iter);
- xor_(start_off_reg, start_off_reg);
+ int ddst_shift = sizeof(float) * (jcp.ur_w / jcp.stride_w) * jcp.ic_block;
+ int dsrc_shift = sizeof(float) * jcp.ur_w * jcp.oc_block;
- int l_overflow = nstl::max(0, jcp.kw - 1 - jcp.l_pad);
- if (l_overflow > 0) {
- hsw_iter_body(jcp.ur_w, l_overflow, 0);
- inc(oi_iter);
- }
+ int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w);
+ int r_overflow = nstl::max(0, (jcp.kw - 1
+ - nstl::max(0, jcp.r_pad)) / jcp.stride_w);
+ int r_overflow1 = nstl::max(0, (jcp.kw - 1
+ - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
- int r_pad = jcp.iwp - jcp.iw - jcp.l_pad;
- int r_overflow1
- = nstl::max(0, jcp.kw - 1 - (jcp.iw - jcp.ur_w * n_oi) - r_pad);
- int r_overflow = nstl::max(0, jcp.kw - 1 - r_pad);
+ int n_oi = jcp.iw / jcp.ur_w;
if (r_overflow1 > 0)
n_oi--;
- if ((l_overflow <= 0 && n_oi > 0) || (l_overflow > 0 && n_oi > 1)) {
- Label ow_loop;
- L(ow_loop); {
- hsw_iter_body(jcp.ur_w, 0, 0);
+ if (jcp.ur_w == jcp.iw) {
+ compute_loop(jcp.ur_w, l_overflow, r_overflow);
+ } else if (n_oi == 0) {
+ compute_loop(jcp.ur_w, l_overflow, r_overflow1);
+ add(reg_dsrc, dsrc_shift);
+ add(reg_ddst, ddst_shift);
+ if (jcp.ur_w_tail != 0)
+ compute_loop(jcp.ur_w_tail, 0, r_overflow);
+ } else {
+ xor_(oi_iter, oi_iter);
+ if (l_overflow > 0) {
+ compute_loop(jcp.ur_w, l_overflow, 0);
+ add(reg_dsrc, dsrc_shift);
+ add(reg_ddst, ddst_shift);
inc(oi_iter);
- cmp(oi_iter, n_oi);
- jl(ow_loop, T_NEAR);
}
- }
- if (r_overflow1 > 0 )
- hsw_iter_body(jcp.ur_w, 0, r_overflow1);
+ if ((l_overflow <= 0 && n_oi > 0) || (l_overflow > 0 && n_oi > 1)) {
+ Label ow_loop;
+ L(ow_loop); {
+ compute_loop(jcp.ur_w, 0, 0);
+ add(reg_dsrc, dsrc_shift);
+ add(reg_ddst, ddst_shift);
+ inc(oi_iter);
+ cmp(oi_iter, n_oi); jl(ow_loop, T_NEAR);
+ }
+ }
- if (jcp.ur_w_tail != 0)
- hsw_iter_body(jcp.ur_w_tail, 0, r_overflow);
+ if (r_overflow1 > 0 ) {
+ compute_loop(jcp.ur_w, 0, r_overflow1);
+ add(reg_dsrc, dsrc_shift);
+ add(reg_ddst, ddst_shift);
+ }
+
+ if (jcp.ur_w_tail != 0)
+ compute_loop(jcp.ur_w_tail, 0, r_overflow);
+ }
this->postamble();
}
bool ok_to_pad_channels = true
&& jcp.ngroups == 1;
+ /* gemm-based convolution performs better in these cases */
+ if (jcp.ic < simd_w && jcp.kw > 3 && jcp.stride_w > 1)
+ return status::unimplemented;
+
if (ok_to_pad_channels) {
jcp.oc = rnd_up(jcp.oc, simd_w);
jcp.ic = rnd_up(jcp.ic, simd_w);
jcp.ur_h = 1; /* no code-unrolling by h so far */
jcp.nb_ic_blocking = 1;
jcp.nb_oc_blocking = 1;
+ jcp.ur_w = 1;
+
+ if(one_of(ndims, 3, 4) && jcp.ow < 40)
+ jcp.nb_oc_blocking = jcp.ow < 15 ? 4 : 2;
jcp.src_fmt = diff_src_d.format();
- jcp.with_eltwise = false;
bool args_ok = true
&& one_of(diff_src_d.format(), nCw8c, nChw8c, nCdhw8c)
&& one_of(weights_d.format(), gOIw8o8i, OIw8i8o, gOIhw8o8i, OIhw8o8i,
gOIdhw8o8i, OIdhw8o8i)
&& one_of(diff_dst_d.format(), nCw8c, nChw8c, nCdhw8c)
- && (jcp.stride_w == 1 || jcp.stride_w == 2)
+ && jcp.stride_w == jcp.stride_h
&& jcp.stride_d == 1
&& jcp.dilate_d == 0
&& jcp.dilate_h == 0
&& jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
&& jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1;
if (!args_ok) return status::unimplemented;
+ jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad;
+ jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad;
+ int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w);
+
+ const int max_regs = 15; /* Maximun number of registers available for
+ result accumulation and delta dst data.
+ One additional register is reserved for weights
+ data. */
+
+ /* Find the best blocking with maximum number of fma instructions
+ per ur_w * nb_ic_blocking compute loops. Number of required registers
+ is num_regs = ur_w * nb_ic_blocking + ur_w / stride_w <= max_regs.
+ ur_w must be divisible by stride_w */
+ if (jcp.stride_w + 1 > max_regs) /* Minimal possible registers
+ distribution exceeds max_regs */
+ return status::unimplemented;
- jcp.ur_w = 3;
-
- for (int b = 4; b > 1; b--)
+ int best_nfmas = 0;
+ for (int b = 1; b <= 4; b++)
{
- if (jcp.nb_ic % b == 0)
+ if (jcp.nb_ic % b != 0)
+ continue;
+
+ for (int u = jcp.stride_w;
+ u * b + u / jcp.stride_w <= max_regs && u < jcp.iw + jcp.stride_w;
+ u += jcp.stride_w)
{
- jcp.nb_ic_blocking = b;
- break;
+ int ur_w = nstl::min(u, jcp.iw);
+ /* maximum 1 step with l_overflow so far */
+ if (l_overflow * jcp.stride_w > ur_w && ur_w != jcp.iw)
+ continue;
+ int nfmas = utils::div_up(ur_w, jcp.stride_w) * b;
+ if (nfmas > best_nfmas
+ || (nfmas == best_nfmas && jcp.ur_w < ur_w)) {
+ jcp.ur_w = ur_w;
+ jcp.nb_ic_blocking = b;
+ best_nfmas = nfmas;
+ }
}
}
+ if (best_nfmas == 0) /* can't find appropriate blocking */
+ return status::unimplemented;
jcp.ur_w_tail = jcp.iw % jcp.ur_w;
- int l_overflow = nstl::max(0, jcp.kw - 1 - jcp.l_pad);
- if (l_overflow > jcp.ur_w) /* maximum 1 step with l_overflow so far */
- return status::unimplemented;
- int r_pad = jcp.iwp - jcp.iw - jcp.l_pad;
- int r_overflow_step0 = nstl::max(0, jcp.kw - 1 - (jcp.iw - jcp.ur_w) - r_pad);
- if (l_overflow > 0 && r_overflow_step0 > 0) /* no steps with both left and
- right overflow so far */
+
+ int r_overflow_no_tail = nstl::max(0, (jcp.kw - 1 - jcp.ur_w_tail
+ - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
+ /* maximum 1 ur_w block with r_overflow so far */
+ if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w)
return status::unimplemented;
- int r_overflow_no_tail = nstl::max(0,jcp.kw - 1 - jcp.ur_w_tail - r_pad);
- if (r_overflow_no_tail > jcp.ur_w) /* maximum 1 ur_w block with
- r_overflow so far */
+
+ if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
return status::unimplemented;
+
return status::success;
}
+void jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+ UNUSED(scratchpad);
+ UNUSED(jcp);
+}
+
void jit_avx2_conv_bwd_weights_kernel_f32::generate() {
this->preamble();
jcp.src_fmt = src_d.format();
jcp.with_bias = cd.diff_bias_desc.format != memory_format::undef;
- jcp.with_eltwise = false;
- jcp.eltwise_alpha = 0;
const bool flat = jcp.ic == 3;
const bool mimo = !flat;
jcp.oc_block = simd_w;
jcp.nb_oc = jcp.oc / jcp.oc_block;
jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
+
return status::success;
}
+void jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+ if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
+ scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
+}
+
inline void jit_avx2_conv_bwd_weights_kernel_f32::od_step_comeback_pointers()
{
Label kd_comeback_loop;