namespace impl {
namespace cpu {
+using namespace memory_tracking::names;
+
namespace {
unsigned int LLC_cache_size = get_cache_size(3, false);
}
template <bool is_fwd>
-void input_transform_tileblock_data(int tile_block,
- const jit_conv_winograd_conf_t &jcp,
- float *inp, float *tinp)
-{
- const int inph = is_fwd ? jcp.ih : jcp.oh;
- const int inpw = is_fwd ? jcp.iw : jcp.ow;
- const int t_pad = is_fwd ? jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh;
- const int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow;
- const int wp_max = inpw + l_pad;
- const int hp_max = inph + t_pad;
- float Iw[alpha][alpha][simd_w];
- float I[alpha][alpha][simd_w];
-
- array_offset_calculator<float, 5> input(inp,
- jcp.mb, jcp.dimK/simd_w, inph, inpw, simd_w);
- array_offset_calculator<float, 7> output(tinp,
- alpha, alpha,
- jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block,
- jcp.dimN_reg_block, jcp.dimK_reg_block);
-
- int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur;
-
- for (int nb_tile_block_ur = 0;
- nb_tile_block_ur < jcp.nb_tile_block_ur;
- nb_tile_block_ur++) {
- for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur;
- tile_block_ur++) {
-
- int img = tile_index / (jcp.jtiles * jcp.itiles);
- int ti = tile_index % jcp.itiles;
- int tj = (tile_index / jcp.itiles) % jcp.jtiles;
- float *pinp_b = &(input(img, 0, 0, 0, 0));
-
- for (int j = 0; j < alpha; j++) {
- int ydim = tj * tile_size + j;
- if ((t_pad <= ydim) && (ydim < hp_max)) {
- float *pinp_j = pinp_b + (ydim - t_pad) * inpw * simd_w;
- for (int i = 0; i < alpha; i++) {
- int xdim = ti * tile_size + i;
- if ((l_pad <= xdim) && (xdim < wp_max)) {
- float *pinp_i = pinp_j + (xdim - l_pad) * simd_w;
- load_ps(I[j][i], pinp_i);
- } else {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- I[j][i][v] = 0.0f;
- }
- }
- }
- } else {
- for (int i = 0; i < alpha; i++) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- I[j][i][v] = 0.0f;
- }
- }
- }
- }
-
- trans_I_4x4_3x3(Iw, I);
- for (int j = 0; j < alpha; j++) {
- for (int i = 0; i < alpha; i++) {
- store_output(&(output(j, i,
- nb_tile_block_ur, 0, 0,
- tile_block_ur, 0)),
- Iw[j][i], false);
- }
- }
- tile_index++;
- }
- }
-}
-
-template <bool is_fwd>
void weight_transform_data(const jit_conv_winograd_conf_t &jcp,
float *wp, float *twp)
{
O[j][i][v] = true
&& with_relu_presum && O[j][i][v] < 0.f
? O[j][i][v]
- * jcp.eltwise_alpha
+ * jcp.eltwise.alpha
: O[j][i][v];
}
}
}
}
-template <bool is_fwd, bool with_bias, bool with_relu_presum, bool with_sum>
-void output_transform_tileblock_data(int tile_block,
- const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops,
- float *toutp, float *outp, float *bias, bool streamout) {
- float Ow[alpha][alpha][simd_w];
- float O[tile_size][tile_size][simd_w];
- int outw = is_fwd ? jcp.ow : jcp.iw;
- int outh = is_fwd ? jcp.oh : jcp.ih;
-
- /* Prepare for PostOps */
- bool with_relu_postsum = p_ops.find(primitive_kind::eltwise, 1) != -1;
-
- array_offset_calculator<float, 6> input(toutp,
- alpha, alpha,
- jcp.dimN_block, jcp.dimM_block,
- jcp.dimN_reg_block, jcp.dimM_simd_block);
- array_offset_calculator<float, 5> output(outp,
- jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw,
- jcp.dimM_simd_block);
-
- int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur;
-
- for (int nb_tile_block_ur = 0;
- nb_tile_block_ur < jcp.nb_tile_block_ur;
- nb_tile_block_ur++) {
-
- for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur;
- tile_block_ur++) {
- int img = tile_index / (jcp.jtiles * jcp.itiles);
- int ti = tile_index % jcp.itiles;
- int tj = (tile_index / jcp.itiles) % jcp.jtiles;
-
- for (int j = 0; j < alpha; j++) {
- for (int i = 0; i < alpha; i++) {
- float *pinp_tile = &(input(j, i, nb_tile_block_ur, 0,
- tile_block_ur, 0));
- load_ps(Ow[j][i], pinp_tile);
- }
- }
-
- trans_O_4x4_3x3(Ow, O);
-
- float *pout_b = &(output(img, 0, 0, 0, 0));
- for (int j = 0; j < tile_size; j++) {
- int ydim = tj * tile_size + j;
- if (ydim < outh) {
- float *pout_j = pout_b + ydim * outw * simd_w;
- for (int i = 0; i < tile_size; i++) {
- int xdim = ti * tile_size + i;
- if (xdim < outw) {
- float *pout_i = pout_j + xdim * simd_w;
- if (is_fwd) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- O[j][i][v] += with_bias ? bias[v] : 0.f;
- O[j][i][v] = true
- && with_relu_presum && O[j][i][v] < 0.f
- ? O[j][i][v]
- * jcp.eltwise_alpha
- : O[j][i][v];
-
- }
- }
- if (with_sum)
- accum_output(pout_i, O[j][i], streamout,
- with_relu_postsum);
- else
- store_output(pout_i, O[j][i], streamout);
- }
- }
- }
- }
- tile_index++;
- }
- }
-}
-
template <bool ver_4fma>
void diff_src_transform_bwd_weights(int image, jit_conv_winograd_conf_t conv,
float *inp, float *tinp, float *Iw_temp,
template <bool is_fwd>
void _jit_avx512_common_convolution_winograd_t<is_fwd>::_execute_data_W_S_G_D(
- const int MB, float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr) {
+ const int MB, float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
+ const memory_tracking::grantor_t &scratchpad) const{
const auto &jcp = kernel_->jcp;
const auto &p_ops = attr_->post_ops_;
const int outh = is_fwd ? jcp.oh : jcp.ih;
const int outw = is_fwd ? jcp.ow : jcp.iw;
- /* Note that jcp.with_relu is true for both fused conv+relu primitive
+ /* Note that jcp.with_eltwise is true for both fused conv+relu primitive
* and conv primitive with PostOps with relu before sum
* (PostOps relu after sum is handled later) */
auto output_transform = jcp.with_bias
array_offset_calculator<float, 2> bias(bias_ptr,
jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block);
- array_offset_calculator<float, 8> M(
- (float *)((is_fwd
- ? (this->scratchpad_)->M_ptr()
- : (this->scratchpad_)->V_ptr())),
+ array_offset_calculator<float, 8> M(is_fwd
+ ? scratchpad.template get<float>(key_wino_M)
+ : scratchpad.template get<float>(key_wino_V),
jcp.dimN_nb_block, jcp.dimM_nb_block,
alpha, alpha,
jcp.dimN_block, jcp.dimM_block,
jcp.dimN_reg_block, jcp.dimM_simd_block);
- array_offset_calculator<float, 8> U((float *)((this->scratchpad_)->U_ptr()),
+ array_offset_calculator<float, 8> U(
+ scratchpad.template get<float>(key_wino_U),
jcp.dimM_nb_block,
alpha, alpha,
jcp.dimK_nb_block,
jcp.dimM_block, jcp.dimK_block,
jcp.dimK_reg_block, jcp.dimM_simd_block);
- array_offset_calculator<float, 8> V(
- (float *)((is_fwd
- ? (this->scratchpad_)->V_ptr()
- : (this->scratchpad_)->M_ptr())),
+ array_offset_calculator<float, 8> V(is_fwd
+ ? scratchpad.template get<float>(key_wino_V)
+ : scratchpad.template get<float>(key_wino_M),
jcp.dimN_nb_block, alpha, alpha,
jcp.dimN_block, jcp.dimK_nb_block,
jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block);
const bool output_is_aligned = ((size_t)out_ptr & (64 - 1)) == 0;
- const bool want_padded_bias = jcp.with_bias
+ const bool wants_padded_bias = jcp.with_bias
&& jcp.oc_without_padding != jcp.oc;
float last_slice_bias[simd_w] = {0};
- if (want_padded_bias) {
+ if (wants_padded_bias) {
for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
}
-#pragma omp parallel
+PRAGMA_OMP(parallel)
{
parallel_nd_in_omp(MB, jcp.dimK_nb_block, jcp.dimK_block,
[&](int img, int K_blk1, int K_blk2) {
ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)), U_base_ptr);
});
-#pragma omp barrier
+PRAGMA_OMP(barrier)
parallel_nd_in_omp(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block, jcp.dimN_block,
[&](int N_blk1, int oj, int oi, int M_blk1, int N_blk2) {
});
-#pragma omp barrier
+PRAGMA_OMP(barrier)
parallel_nd_in_omp(MB, jcp.dimM_nb_block, jcp.dimM_block,
[&](int img, int M_blk1, int M_blk2) {
const int M_blk = M_blk1 * jcp.dimM_block + M_blk2;
- float *bias_ptr = want_padded_bias
+ float *bias_ptr = wants_padded_bias
&& M_blk == jcp.dimM / jcp.dimM_simd_block - 1
? last_slice_bias : &bias(M_blk, 0);
}
}
-template void
-_jit_avx512_common_convolution_winograd_t<true>::_execute_data_W_S_G_D(
- const int, float *, float *, float *, float *);
-template void
-_jit_avx512_common_convolution_winograd_t<false>::_execute_data_W_S_G_D(
- const int, float *, float *, float *, float *);
-
-template <bool is_fwd>
-void _jit_avx512_common_convolution_winograd_t<is_fwd>::_execute_data_W_SGD(
- const int MB, float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr) {
- const auto &jcp = kernel_->jcp;
- const auto &p_ops = attr_->post_ops_;
-
- const int inph = is_fwd ? jcp.ih : jcp.oh;
- const int inpw = is_fwd ? jcp.iw : jcp.ow;
- const int outh = is_fwd ? jcp.oh : jcp.ih;
- const int outw = is_fwd ? jcp.ow : jcp.iw;
-
- /* Note that jcp.with_relu is true for both fused conv+relu primitive
- * and conv primitive with PostOps with relu before sum
- * (PostOps relu after sum is handled later) */
- auto output_transform_tileblock = jcp.with_bias
- ? (jcp.with_eltwise
- ? (jcp.with_sum
- ? output_transform_tileblock_data<is_fwd, true, true, true>
- : output_transform_tileblock_data<is_fwd, true, true, false>)
- : (jcp.with_sum
- ? output_transform_tileblock_data<is_fwd, true, false, true>
- : output_transform_tileblock_data<is_fwd, true, false, false>))
- : (jcp.with_eltwise
- ? (jcp.with_sum
- ? output_transform_tileblock_data<is_fwd, false, true, true>
- : output_transform_tileblock_data<is_fwd, false, true, false>)
- : (jcp.with_sum
- ? output_transform_tileblock_data<is_fwd, false, false, true>
- : output_transform_tileblock_data<is_fwd, false, false, false>));
-
- array_offset_calculator<float, 5> input(inp_ptr,
- MB, jcp.dimK/jcp.dimK_reg_block, inph, inpw, jcp.dimK_reg_block);
- array_offset_calculator<float, 5> output(out_ptr,
- MB, jcp.dimM/jcp.dimM_simd_block, outh, outw, jcp.dimM_simd_block);
- array_offset_calculator<float, 6> weights(wei_ptr,
- jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw,
- jcp.ic_simd_block, jcp.oc_simd_block);
- array_offset_calculator<float, 2> bias(bias_ptr,
- jcp.oc/jcp.oc_simd_block, jcp.oc_simd_block);
-
- array_offset_calculator<float, 8> U((float *)((this->scratchpad_)->U_ptr()),
- jcp.dimM_nb_block,
- alpha, alpha,
- jcp.dimK_nb_block,
- jcp.dimM_block, jcp.dimK_block,
- jcp.dimK_reg_block, jcp.dimM_simd_block);
-
- array_offset_calculator<float, 8> M(
- (float *)((is_fwd
- ? (this->scratchpad_)->M_ptr()
- : (this->scratchpad_)->V_ptr())),
- 0, jcp.dimM_nb_block, alpha, alpha,
- jcp.dimN_block, jcp.dimM_block,
- jcp.dimN_reg_block, jcp.dimM_simd_block);
-
- array_offset_calculator<float, 8> V(
- (float *)((is_fwd
- ? (this->scratchpad_)->V_ptr()
- : (this->scratchpad_)->M_ptr())),
- 0, alpha, alpha, jcp.dimN_block,
- jcp.dimK_nb_block, jcp.dimK_block,
- jcp.dimN_reg_block, jcp.dimK_reg_block);
-
- const bool output_is_aligned = ((size_t)out_ptr & (64 - 1)) == 0;
-
- const bool want_padded_bias = jcp.with_bias
- && jcp.oc_without_padding != jcp.oc;
- float last_slice_bias[simd_w] = {0};
- if (want_padded_bias) {
- for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
- last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
- }
-
-#pragma omp parallel
- {
- parallel_nd_in_omp(jcp.nb_oc, jcp.nb_ic, jcp.oc_block, jcp.ic_block,
- [&](int ofm1, int ifm1, int ofm2, int ifm2) {
-
- float *U_base_ptr = is_fwd
- ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
- : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
- weight_transform_data<is_fwd>(jcp,
- &(weights(ofm1 * jcp.oc_block + ofm2,
- ifm1 * jcp.ic_block + ifm2,
- 0, 0, 0, 0)),
- U_base_ptr);
- });
-
-#pragma omp barrier
-
- int ithr = mkldnn_get_thread_num();
-
-#pragma omp for schedule(static)
- for (int tile_block = 0; tile_block < jcp.tile_block; tile_block++) {
- for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++) {
- for (int K_blk2 = 0; K_blk2 < jcp.dimK_block; K_blk2++) {
- input_transform_tileblock_data<is_fwd>(
- tile_block, jcp,
- &(input(0, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)),
- &(V(ithr, 0, 0, 0, K_blk1, K_blk2, 0, 0)));
- }
- }
-
- for (int oj = 0; oj < alpha; oj++) {
- for (int oi = 0; oi < alpha; oi++) {
- for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) {
- for (int N_blk = 0; N_blk < jcp.dimN_block; N_blk++) {
- kernel_->gemm_loop_ker_first_iter(
- (float *)&(M(ithr, M_blk1, oj, oi,
- N_blk, 0, 0, 0)),
- (const float *)&(U(M_blk1, oj, oi, 0,
- 0, 0, 0, 0)),
- (const float *)&(V(ithr, oj, oi,
- N_blk, 0, 0, 0, 0)));
- for (int K_blk1 = 1; K_blk1 < jcp.dimK_nb_block; K_blk1++) {
- kernel_->gemm_loop_ker(
- (float *)&(M(ithr, M_blk1, oj, oi,
- N_blk, 0, 0, 0)),
- (const float *)&(U(M_blk1, oj, oi, K_blk1,
- 0, 0, 0, 0)),
- (const float *)&(V(ithr, oj, oi,
- N_blk, K_blk1, 0, 0, 0)));
- }
- }
- }
- }
- }
-
- for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) {
- for (int M_blk2 = 0; M_blk2 < jcp.dimM_block; M_blk2++) {
- const int M_blk = M_blk1 * jcp.dimM_block + M_blk2;
-
- float *bias_ptr = want_padded_bias
- && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
- ? last_slice_bias : &bias(M_blk, 0);
-
- output_transform_tileblock(tile_block, jcp, p_ops,
- &(M(ithr, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
- &(output(0, M_blk, 0, 0, 0)),
- bias_ptr, output_is_aligned);
- }
- }
- }
- }
-}
-
-template void
-_jit_avx512_common_convolution_winograd_t<true>::_execute_data_W_SGD(
- const int, float *, float *, float *, float *);
-template void
-_jit_avx512_common_convolution_winograd_t<false>::_execute_data_W_SGD(
- const int, float *, float *, float *, float *);
+template struct _jit_avx512_common_convolution_winograd_t<true>;
+template struct _jit_avx512_common_convolution_winograd_t<false>;
void jit_avx512_common_convolution_winograd_bwd_weights_t::
-_maybe_execute_diff_bias_copy() {
- if (conf_.want_padded_bias()) {
+_maybe_execute_diff_bias_copy(
+ const memory_tracking::grantor_t &scratchpad) const {
+ if (pd()->wants_padded_bias()) {
+ auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
float *diff_bias = (float *)this->memory(1);
- for (int oc = 0; oc < conf_.jcp_.oc_without_padding; ++oc)
- diff_bias[oc] = this->padded_bias_[oc];
+ for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc)
+ diff_bias[oc] = padded_bias[oc];
}
}
void jit_avx512_common_convolution_winograd_bwd_weights_t::
-_execute_backward_weights_S_D_G_W()
-{
+_execute_backward_weights_S_D_G_W(
+ const memory_tracking::grantor_t &scratchpad) const {
const auto &jcp = kernel_->jcp;
- const int nthreads = scratchpad_->num_threads();
+ const int nthreads = jcp.nthr;
auto diff_src_transform_bwd_weights_ver = jcp.ver == ver_4fma ?
diff_src_transform_bwd_weights<true> :
jcp.mb, jcp.oc/simd_w, jcp.oh, jcp.ow, simd_w);
array_offset_calculator<float, 6> diff_weights((float *)this->memory(0),
jcp.oc/simd_w, jcp.ic/simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
- array_offset_calculator<float, 2> diff_bias(
- conf_.want_padded_bias() ? padded_bias_ : (float *)this->memory(1),
- jcp.oc/simd_w, simd_w);
+ array_offset_calculator<float, 2> diff_bias(pd()->wants_padded_bias()
+ ? scratchpad.get<float>(key_conv_padded_bias)
+ : (float *)this->memory(1), jcp.oc/simd_w, simd_w);
array_offset_calculator<float, 8> U(
- (float *)(scratchpad_->U_ptr()),
+ scratchpad.get<float>(key_wino_U),
jcp.nb_ic, jcp.nb_oc,
alpha, alpha,
jcp.oc_block, jcp.ic_block,
jcp.ic_simd_block, jcp.oc_simd_block);
array_offset_calculator<float, 8> M(
- (float *)(scratchpad_->M_ptr()),
+ scratchpad.get<float>(key_wino_M),
jcp.nb_oc, alpha, alpha,
jcp.tile_block, jcp.oc_block,
jcp.nb_tile_block_ur, jcp.tile_block_ur * jcp.tile_4fma,
jcp.oc_simd_block);
array_offset_calculator<float, 8> V(
- (float *)(scratchpad_->V_ptr()),
+ scratchpad.get<float>(key_wino_V),
jcp.nb_ic, alpha, alpha,
jcp.tile_block, jcp.ic_block,
jcp.nb_tile_block_ur, jcp.tile_block_ur,
const int trans_buffer_size = alpha * alpha * jcp.tile_4fma
* jcp.ic_simd_block;
array_offset_calculator<float, 2> trans_buffer(
- (float *)(scratchpad_->src_transpose_ptr()),
+ scratchpad.get<float>(key_conv_tr_src),
nthreads,
trans_buffer_size);
array_offset_calculator<float, 2> diff_bias_prv(
- (float *)(scratchpad_->bias_ptr()),
- mkldnn_get_max_threads(),
+ scratchpad.get<float>(key_conv_bia_reduction),
+ nthreads,
jcp.oc);
-#pragma omp parallel num_threads(nthreads)
+PRAGMA_OMP(parallel num_threads(nthreads))
{
if (jcp.with_bias) {
parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) {
diff_bias_prv(ithr, ofm) = 0.0f;
});
-#pragma omp for nowait
+PRAGMA_OMP(for nowait)
for (int bofm = 0; bofm < jcp.oc / simd_w; bofm++) {
PRAGMA_OMP_SIMD()
for (int v = 0; v < simd_w; v++)
dbias);
});
-#pragma omp barrier
+PRAGMA_OMP(barrier)
for (int ifm1 = 0; ifm1 < jcp.nb_ic; ifm1++) {
parallel_nd_in_omp(alpha, alpha, jcp.nb_oc,
});
}
-#pragma omp barrier
+PRAGMA_OMP(barrier)
parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block,
[&](int ifm1, int ofm1, int ofm2, int ifm2) {
});
if (jcp.with_bias) {
-#pragma omp for
+PRAGMA_OMP(for)
for (int ofm1 = 0; ofm1 < jcp.oc / simd_w; ofm1++) {
for (int ithr = 0; ithr < nthreads; ithr++) {
float* base_bias_ptr = &(diff_bias(ofm1, 0));
}
}
- _maybe_execute_diff_bias_copy();
+ _maybe_execute_diff_bias_copy(scratchpad);
}
-namespace {
-
-const int max_threads_number = 1024;
-
-template <bool ver_4fma>
-void diff_src_transform_bwd_weights_tile(int tile_block,
- jit_conv_winograd_conf_t conv, float *inp, float *tinp,
- void(*transpose_4fma_ker)(float *, float *))
-{
- const int ifwp = conv.iw + conv.l_pad;
- const int ifhp = conv.ih + conv.t_pad;
- float I[alpha][alpha][simd_w];
- float Iw[alpha][alpha][simd_w];
-
- float *Iw_buffer = nullptr;
- if (ver_4fma) {
- Iw_buffer = (float *)malloc(alpha * alpha * conv.tile_4fma
- * simd_w * sizeof(float), 64);
- }
- array_offset_calculator<float, 4> Iw_scratchpad(Iw_buffer,
- alpha, alpha, conv.tile_4fma, simd_w);
- array_offset_calculator<float, 5> input(inp,
- conv.mb, conv.ic / simd_w, conv.ih, conv.iw, simd_w);
- array_offset_calculator<float, 7> output(tinp,
- 0, alpha, alpha,
- conv.ic_block,
- conv.nb_tile_block_ur, conv.tile_block_ur,
- conv.ic_simd_block * conv.tile_4fma);
-
- int tile_4fma = 0;
-
- int n_tiles = tile_block * conv.nb_tile_block_ur * conv.tile_block_ur;
- for (int nb_tile_block_ur = 0; nb_tile_block_ur < conv.nb_tile_block_ur;
- nb_tile_block_ur++) {
- for (int tile_block_ur = 0; tile_block_ur < conv.tile_block_ur;
- tile_block_ur++) {
-
- int img = n_tiles / (conv.jtiles * conv.itiles);
- int no_tile = n_tiles % (conv.jtiles * conv.itiles);
- int ti = no_tile % conv.itiles;
- int tj = no_tile / conv.itiles;
-
- for (int j = 0; j < alpha; j++) {
- int ydim = tj * tile_size + j;
- if ((conv.t_pad <= ydim) && ydim < ifhp) {
- for (int i = 0; i < alpha; i++) {
- int xdim = ti * tile_size + i;
- if ((conv.l_pad <= xdim) && xdim < ifwp) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- I[j][i][v] = input(img, 0,
- ydim - conv.t_pad,
- xdim - conv.l_pad, v);
- }
- }
- else {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- I[j][i][v] = 0.0f;
- }
- }
- }
- }
- else {
- for (int i = 0; i < alpha; i++) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- I[j][i][v] = 0.0f;
- }
- }
- }
- }
-
- trans_I_4x4_3x3(Iw, I);
-
- if (ver_4fma) {
- for (int j = 0; j < alpha; j++) {
- for (int i = 0; i < alpha; i++) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- Iw_scratchpad(j, i, tile_4fma, v) = Iw[j][i][v];
- }
- }
- }
- tile_4fma++;
- if (tile_4fma == conv.tile_4fma) {
- float *outp = &(output(0, 0, 0, 0,
- nb_tile_block_ur, tile_block_ur, 0));
- transpose_4fma_ker(outp, (float *)Iw_buffer);
- tile_4fma = 0;
- }
- }
- else {
- for (int j = 0; j < alpha; j++) {
- for (int i = 0; i < alpha; i++) {
- store_output(
- &(output(0, j, i, 0,
- nb_tile_block_ur, tile_block_ur, 0)),
- Iw[j][i], false);
-
- }
- }
- }
- n_tiles++;
- }
- }
-}
-
-template <bool with_bias>
-void diff_dst_transform_bwd_weights_tile(int tile_block,
- jit_conv_winograd_conf_t conv, float *inp, float *tinp, float *dbias)
-{
- float I[alpha][alpha][simd_w];
- float Iw[alpha][alpha][simd_w];
-
- array_offset_calculator<float, 5> input(inp,
- conv.mb, conv.oc / simd_w, conv.oh, conv.ow, conv.oc_simd_block);
- array_offset_calculator<float, 7> output(tinp,
- conv.nb_oc, alpha, alpha,
- conv.oc_block,
- conv.nb_tile_block_ur,
- conv.tile_block_ur * conv.tile_4fma, conv.oc_simd_block);
-
- int n_tiles = tile_block * conv.nb_tile_block_ur * conv.tile_block_ur;
- for (int nb_tile_block_ur = 0; nb_tile_block_ur < conv.nb_tile_block_ur;
- nb_tile_block_ur++) {
- for (int tile_block_ur = 0; tile_block_ur < conv.tile_block_ur;
- tile_block_ur++) {
-
- int img = n_tiles / (conv.jtiles * conv.itiles);
- int no_tile = n_tiles % (conv.jtiles * conv.itiles);
- int ti = no_tile % conv.itiles;
- int tj = no_tile / conv.itiles;
-
- for (int j = 0; j < alpha; j++) {
- int ydim = tj * tile_size + j;
- if (ydim < conv.oh) {
- for (int i = 0; i < alpha; i++) {
- int xdim = ti * tile_size + i;
- if (xdim < conv.ow) {
- float *input_base = &input(img, 0, ydim, xdim, 0);
-
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- I[j][i][v] = input_base[v];
- }
- if (with_bias && j < tile_size && i < tile_size) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- dbias[v] += input_base[v];
- }
- }
- }
- else {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- I[j][i][v] = 0.0f;
- }
- }
- }
- }
- else {
- for (int i = 0; i < alpha; i++) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- I[j][i][v] = 0.0f;
- }
- }
- }
- }
-
- trans_W_3x3_4x4_wu(Iw, I);
-
- for (int j = 0; j < alpha; j++) {
- for (int i = 0; i < alpha; i++) {
- /*TODO: Try instrinsic for casting into __m512*/
- store_output(&(output(0, j, i, 0,
- nb_tile_block_ur, tile_block_ur, 0)),
- Iw[j][i], false);
- }
- }
- n_tiles++;
- }
- }
-}
-
-// Sum to the first buffer array
-void array_sum(int num_arrs, float *output,
- size_t nelems, float *input_ptrs[], bool reduce_to_first = true)
-{
- const size_t block_size = 16 * 1024 / sizeof(float);
- const size_t blocks_number = nelems / block_size;
- const size_t tail = nelems % block_size;
-
-#pragma omp parallel
- {
- const int ithr = mkldnn_get_thread_num();
- const int nthr = mkldnn_get_num_threads();
- size_t start{ 0 }, end{ 0 };
- balance211(blocks_number, nthr, ithr, start, end);
-
- for (size_t nb = start; nb < end; ++nb) {
- size_t start_e = nb * block_size;
- size_t end_e = start_e + block_size;
- if (!reduce_to_first) {
- PRAGMA_OMP_SIMD()
- for (size_t e = start_e; e < end_e; e++) {
- output[e] = input_ptrs[0][e];
- }
- }
- for (int a = 1; a < num_arrs; a++) {
- PRAGMA_OMP_SIMD()
- for (size_t e = start_e; e < end_e; e++) {
- output[e] += input_ptrs[a][e];
- }
- }
- }
-
- if (tail != 0 && ithr == nthr - 1) {
- size_t start_e = nelems - tail;
- size_t end_e = nelems;
- if (!reduce_to_first) {
- PRAGMA_OMP_SIMD()
- for (size_t e = start_e; e < end_e; e++) {
- output[e] = input_ptrs[0][e];
- }
- }
- for (int a = 1; a < num_arrs; a++) {
- PRAGMA_OMP_SIMD()
- for (size_t e = start_e; e < end_e; e++) {
- output[e] += input_ptrs[a][e];
- }
- }
- }
- }
-}
-
-void subarray_sum(int num_arrs, float *output, size_t nelems,
- float *input_ptrs[], size_t input_starts[], size_t input_ends[])
-{
- using namespace nstl;
- const size_t block_size = 16 * 1024 / sizeof(float);
- const size_t blocks_number = nelems / block_size;
- const size_t tail = nelems % block_size;
-
-#pragma omp parallel
- {
- const int ithr = mkldnn_get_thread_num();
- const int nthr = mkldnn_get_num_threads();
- size_t start{ 0 }, end{ 0 };
- balance211(blocks_number, nthr, ithr, start, end);
-
- for (size_t nb = start; nb < end; ++nb) {
- size_t start_e = nb * block_size;
- size_t end_e = start_e + block_size;
- size_t input_start = max(start_e, min(input_starts[0], end_e));
- size_t input_end = max(start_e, min(input_ends[0], end_e));
-
- PRAGMA_OMP_SIMD()
- for (size_t e = start_e; e < input_start; e++) {
- output[e] = 0.f;
- }
-
- PRAGMA_OMP_SIMD()
- for (size_t e = input_start; e < input_end; e++) {
- output[e] = input_ptrs[0][e];
- }
-
- PRAGMA_OMP_SIMD()
- for (size_t e = input_end; e < end_e; e++) {
- output[e] = 0.f;
- }
- for (int a = 1; a < num_arrs; a++) {
- input_start = max(start_e, input_starts[a]);
- input_end = min(input_ends[a], end_e);
-
- PRAGMA_OMP_SIMD()
- for (size_t e = input_start; e < input_end; e++) {
- output[e] += input_ptrs[a][e];
- }
- }
- }
-
- if (tail != 0 && ithr == nthr - 1) {
- size_t start_e = nelems - tail;
- size_t end_e = nelems;
- size_t input_start = max(start_e, min(input_starts[0], end_e));
- size_t input_end = max(start_e, min(input_ends[0], end_e));
-
- PRAGMA_OMP_SIMD()
- for (size_t e = start_e; e < input_start; e++) {
- output[e] = 0.f;
- }
-
- PRAGMA_OMP_SIMD()
- for (size_t e = input_start; e < input_end; e++) {
- output[e] = input_ptrs[0][e];
- }
-
- PRAGMA_OMP_SIMD()
- for (size_t e = input_end; e < end_e; e++) {
- output[e] = 0.f;
- }
- for (int a = 1; a < num_arrs; a++) {
- input_start = max(start_e, input_starts[a]);
- input_end = min(input_ends[a], end_e);
-
- PRAGMA_OMP_SIMD()
- for (size_t e = start_e; e < end_e; e++) {
- output[e] += input_ptrs[a][e];
- }
- }
- }
- }
-}
-} // namespace
-
-void jit_avx512_common_convolution_winograd_bwd_weights_t::
-_execute_backward_weights_S_D_Giot_W()
-{
- const auto &jcp = kernel_->jcp;
- const int nthreads = scratchpad_->num_threads();
- int U_size = jcp.oc * jcp.ic * alpha * alpha * sizeof(float);
-
- auto diff_src_transform_bwd_weights_ver = jcp.ver == ver_4fma ?
- diff_src_transform_bwd_weights<true> :
- diff_src_transform_bwd_weights<false>;
- auto diff_dst_transform_bwd_weights_ver = jcp.with_bias
- ? diff_dst_transform_bwd_weights<true>
- : diff_dst_transform_bwd_weights<false>;
-
- array_offset_calculator<float, 5> diff_src((float *)this->input_memory(0),
- jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
- array_offset_calculator<float, 5> diff_dst((float *)this->input_memory(1),
- jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
- array_offset_calculator<float, 6> diff_weights((float *)this->memory(0),
- jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
- array_offset_calculator<float, 2> diff_bias(
- conf_.want_padded_bias() ? padded_bias_ : (float *)this->memory(1),
- jcp.oc / simd_w, simd_w);
-
- array_offset_calculator<float, 8> U((float *)(scratchpad_->U_ptr()),
- jcp.nb_ic, jcp.nb_oc,
- alpha, alpha,
- jcp.oc_block, jcp.ic_block,
- jcp.ic_simd_block, jcp.oc_simd_block);
-
- array_offset_calculator<float, 9> Us(
- (float *)(scratchpad_->U_ptr() + U_size),
- 0, jcp.nb_ic, jcp.nb_oc,
- alpha, alpha,
- jcp.oc_block, jcp.ic_block,
- jcp.ic_simd_block, jcp.oc_simd_block);
-
- array_offset_calculator<float, 8> M((float *)(scratchpad_->M_ptr()),
- jcp.nb_oc, alpha, alpha,
- jcp.tile_block, jcp.oc_block,
- jcp.nb_tile_block_ur, jcp.tile_block_ur * jcp.tile_4fma,
- jcp.oc_simd_block);
-
- array_offset_calculator<float, 8> V((float *)(scratchpad_->V_ptr()),
- jcp.nb_ic, alpha, alpha,
- jcp.tile_block, jcp.ic_block,
- jcp.nb_tile_block_ur, jcp.tile_block_ur,
- jcp.ic_simd_block * jcp.tile_4fma);
-
- const int trans_buffer_size = alpha * alpha * jcp.tile_4fma
- * jcp.ic_simd_block;
- array_offset_calculator<float, 2> trans_buffer(
- (float *)(scratchpad_->src_transpose_ptr()),
- nthreads,
- trans_buffer_size);
-
- array_offset_calculator<float, 2> diff_bias_prv(
- (float *)(scratchpad_->bias_ptr()), nthreads, jcp.oc);
-
-#pragma omp parallel
- {
- if (jcp.with_bias) {
- parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) {
- diff_bias_prv(ithr, ofm) = 0.0f;
- });
-#pragma omp for nowait
- for (int bofm = 0; bofm < jcp.oc / simd_w; bofm++) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++)
- diff_bias(bofm, v) = 0.0f;
- }
- }
- }
-
-#pragma omp parallel
- {
- const int ithread = mkldnn_get_thread_num();
- parallel_nd_in_omp(jcp.mb, jcp.nb_ic, jcp.ic_block,
- [&](int img, int ifm1, int ifm2) {
- float *transb = jcp.ver == ver_4fma
- ? &(trans_buffer(ithread, 0))
- : NULL;
- diff_src_transform_bwd_weights_ver(img, jcp,
- &(diff_src(img, ifm1 * jcp.ic_block + ifm2,
- 0, 0, 0)),
- &(V(ifm1, 0, 0, 0, ifm2, 0, 0, 0)),
- transb,
- kernel_->transpose_4fma_ker);
- });
- }
-
-#pragma omp parallel num_threads(nthreads)
- {
- parallel_nd_in_omp(jcp.mb, jcp.nb_oc, jcp.oc_block,
- [&](int img, int ofm1, int ofm2) {
- const int ithread = mkldnn_get_thread_num();
- float *dbias = jcp.with_bias
- ? &(diff_bias_prv(ithread,
- simd_w * (ofm1 * jcp.oc_block + ofm2)))
- : NULL;
- diff_dst_transform_bwd_weights_ver(img, jcp,
- &(diff_dst(img, ofm1 * jcp.oc_block + ofm2, 0, 0, 0)),
- &(M(ofm1, 0, 0, 0, ofm2, 0, 0, 0)), dbias);
- });
- }
-
- size_t input_starts[max_threads_number];
- size_t input_ends[max_threads_number];
- int th_counter = 0;
-#pragma omp parallel firstprivate(th_counter) num_threads(nthreads)
- {
- parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, alpha, alpha, jcp.tile_block,
- [&](int ifm1, int ofm1, int oj, int oi, int tile_block) {
- int ithr = mkldnn_get_thread_num();
- if (th_counter == 0) {
- input_starts[ithr] = (float *)&(Us(ithr, ifm1, ofm1,
- oj, oi, 0, 0, 0, 0)) - (float *)&(Us(ithr, 0, 0,
- 0, 0, 0, 0, 0, 0));
- input_ends[ithr] = input_starts[ithr]
- + jcp.oc_block * jcp.ic_block
- * jcp.ic_simd_block * jcp.oc_simd_block;
- }
- else if (tile_block == 0) {
- input_ends[ithr] += jcp.oc_block * jcp.ic_block
- * jcp.ic_simd_block * jcp.oc_simd_block;
- }
-
- if (th_counter == 0 || tile_block == 0) {
- kernel_->gemm_loop_ker_first_iter(
- &(Us(ithr, ifm1, ofm1, oj, oi, 0, 0, 0, 0)),
- &(M(ofm1, oj, oi, tile_block, 0, 0, 0, 0)),
- &(V(ifm1, oj, oi, tile_block, 0, 0, 0, 0)));
- } else {
- kernel_->gemm_loop_ker(
- &(Us(ithr, ifm1, ofm1, oj, oi, 0, 0, 0, 0)),
- &(M(ofm1, oj, oi, tile_block, 0, 0, 0, 0)),
- &(V(ifm1, oj, oi, tile_block, 0, 0, 0, 0)));
- }
- th_counter++;
- });
- }
-
-
- // Reduce diff-weights
- {
- float *output = &(U(0, 0, 0, 0, 0, 0, 0, 0));
- size_t nelems = jcp.ic * jcp.oc * alpha * alpha;
- float *input_ptrs[max_threads_number];
- for (int i = 0; i < nthreads; i++)
- input_ptrs[i] = output + nelems * (i + 1);
- subarray_sum(
- nthreads, output, nelems, input_ptrs, input_starts, input_ends);
- }
-
- parallel_nd(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block,
- [&](int ifm1, int ofm1, int ofm2, int ifm2) {
- diff_weights_transform_bwd_weights(jcp,
- &(diff_weights(ofm1 * jcp.oc_block + ofm2,
- ifm1 * jcp.ic_block + ifm2,
- 0, 0, 0, 0)),
- &(U(ifm1, ofm1, 0, 0, ofm2, ifm2, 0, 0)));
- });
-
-#pragma omp parallel
- if (jcp.with_bias) {
-#pragma omp for
- for (int ofm1 = 0; ofm1 < jcp.oc / simd_w; ofm1++) {
- for (int ithr = 0; ithr < nthreads; ithr++) {
- float* base_bias_ptr = &(diff_bias(ofm1, 0));
- float* base_bias_prv_ptr = &(diff_bias_prv(
- ithr * jcp.oc + ofm1 * simd_w));
- PRAGMA_OMP_SIMD()
- for (int ofm2 = 0; ofm2 < simd_w; ofm2++) {
- base_bias_ptr[ofm2] += base_bias_prv_ptr[ofm2];
- }
- }
- }
- }
-
- _maybe_execute_diff_bias_copy();
-}
-
-void jit_avx512_common_convolution_winograd_bwd_weights_t::
-_execute_backward_weights_SDGtWo()
-{
- const auto &jcp = kernel_->jcp;
- const int nthreads = scratchpad_->num_threads();
-
- auto diff_src_transform_bwd_weights_ver_tile = jcp.ver == ver_4fma ?
- diff_src_transform_bwd_weights_tile<true> :
- diff_src_transform_bwd_weights_tile<false>;
- auto diff_dst_transform_bwd_weights_ver = jcp.with_bias
- ? diff_dst_transform_bwd_weights_tile<true>
- : diff_dst_transform_bwd_weights_tile<false>;
-
- array_offset_calculator<float, 5> diff_src((float *)this->input_memory(0),
- jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
- array_offset_calculator<float, 5> diff_dst((float *)this->input_memory(1),
- jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
- array_offset_calculator<float, 6> diff_weights((float *)this->memory(0),
- jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
- array_offset_calculator<float, 3> diff_bias(
- conf_.want_padded_bias() ? padded_bias_ : (float *)this->memory(1),
- jcp.nb_oc, jcp.oc_block, simd_w);
-
- array_offset_calculator<float, 8> Us((float *)(scratchpad_->U_ptr()),
- 0, jcp.nb_ic, alpha, alpha,
- jcp.oc_block, jcp.ic_block,
- jcp.ic_simd_block, jcp.oc_simd_block);
-
- array_offset_calculator<float, 7> M((float *)(scratchpad_->M_ptr()),
- 0, alpha, alpha,
- jcp.oc_block,
- jcp.nb_tile_block_ur, jcp.tile_block_ur * jcp.tile_4fma,
- jcp.oc_simd_block);
-
- array_offset_calculator<float, 8> V((float *)(scratchpad_->V_ptr()),
- 0, jcp.nb_ic, alpha, alpha,
- jcp.ic_block,
- jcp.nb_tile_block_ur, jcp.tile_block_ur,
- jcp.ic_simd_block * jcp.tile_4fma);
-
- array_offset_calculator<float, 2> diff_bias_prv(
- (float *)(scratchpad_->bias_ptr()),
- nthreads, jcp.oc / jcp.nb_oc);
-
- for (int ofm1 = 0; ofm1 < jcp.nb_oc; ++ofm1) {
- int th_counter = 0;
-
-#pragma omp parallel
- {
- if (jcp.with_bias) {
- parallel_nd_in_omp(nthreads, jcp.oc / jcp.nb_oc,
- [&](int ithr, int ofm) {
- diff_bias_prv(ithr, ofm) = 0.0f;
- });
-#pragma omp for nowait
- for (int bofm = 0; bofm < jcp.oc_block; bofm++) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++)
- diff_bias(ofm1, bofm, v) = 0.0f;
- }
- }
- }
-
-#pragma omp parallel firstprivate(th_counter) num_threads(nthreads)
-#pragma omp for nowait
- for (int tile_block = 0; tile_block < jcp.tile_block; tile_block++) {
- int ithr = mkldnn_get_thread_num();
- for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) {
- for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
- diff_src_transform_bwd_weights_ver_tile(tile_block, jcp,
- &(diff_src(0, ifm1 * jcp.ic_block + ifm2, 0, 0, 0)),
- &(V(ithr, ifm1, 0, 0, ifm2, 0, 0, 0)),
- kernel_->transpose_4fma_ker);
- }
- }
-
- for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++) {
- float *dbias = jcp.with_bias
- ? &(diff_bias_prv(ithr, simd_w * ofm2))
- : NULL;
- diff_dst_transform_bwd_weights_ver(tile_block, jcp,
- &(diff_dst(0, ofm1 * jcp.oc_block + ofm2, 0, 0, 0)),
- &(M(ithr, 0, 0, ofm2, 0, 0, 0)),
- dbias);
- }
-
- for (int ifm1 = 0; ifm1 < jcp.nb_ic; ifm1++) {
- for (int oj = 0; oj < alpha; oj++) {
- for (int oi = 0; oi < alpha; oi++) {
- if (th_counter == 0)
- kernel_->gemm_loop_ker_first_iter(
- &(Us(ithr, ifm1, oj, oi, 0, 0, 0, 0)),
- &(M(ithr, oj, oi, 0, 0, 0, 0)),
- &(V(ithr, ifm1, oj, oi, 0, 0, 0, 0)));
- else
- kernel_->gemm_loop_ker(
- &(Us(ithr, ifm1, oj, oi, 0, 0, 0, 0)),
- &(M(ithr, oj, oi, 0, 0, 0, 0)),
- &(V(ithr, ifm1, oj, oi, 0, 0, 0, 0)));
- }
- }
- }
- th_counter++;
- }
- // Reduce diff-weights
- {
- float *output = (float *)(scratchpad_->U_ptr());
- size_t nelems
- = jcp.ic * (jcp.oc / jcp.nb_oc) * alpha * alpha;
- float *input_ptrs[max_threads_number];
- for (int i = 0; i < nthreads; i++) {
- input_ptrs[i] = output + nelems * i;
- }
- array_sum(nthreads, output, nelems, input_ptrs);
- }
-
- parallel_nd(jcp.nb_ic, jcp.oc_block, jcp.ic_block,
- [&](int ifm1, int ofm2, int ifm2) {
- diff_weights_transform_bwd_weights(jcp,
- &(diff_weights(ofm1 * jcp.oc_block + ofm2,
- ifm1 * jcp.ic_block + ifm2,
- 0, 0, 0, 0)),
- &(Us(0, ifm1, 0, 0, ofm2, ifm2, 0, 0)));
- });
-
-#pragma omp parallel
- if (jcp.with_bias) {
-#pragma omp for
- for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++) {
- for (int ithr = 0; ithr < nthreads; ithr++) {
- float* base_bias_ptr = &(diff_bias(ofm1, ofm2, 0));
- float* base_bias_prv_ptr = &(diff_bias_prv(
- ithr * jcp.oc_block * simd_w + ofm2 * simd_w));
- PRAGMA_OMP_SIMD()
- for (int ofm3 = 0; ofm3 < simd_w; ofm3++) {
- base_bias_ptr[ofm3] += base_bias_prv_ptr[ofm3];
- }
- }
- }
- }
- }
-
- _maybe_execute_diff_bias_copy();
-}
-
-void jit_avx512_common_convolution_winograd_bwd_weights_t::
-_execute_backward_weights_SDGt_W()
-{
- const auto &jcp = kernel_->jcp;
- const int nthreads = scratchpad_->num_threads();
-
- auto diff_src_transform_bwd_weights_ver_tile = jcp.ver == ver_4fma ?
- diff_src_transform_bwd_weights_tile<true> :
- diff_src_transform_bwd_weights_tile<false>;
- auto diff_dst_transform_bwd_weights_ver = jcp.with_bias
- ? diff_dst_transform_bwd_weights_tile<true>
- : diff_dst_transform_bwd_weights_tile<false>;
-
- array_offset_calculator<float, 5> diff_src((float *)this->input_memory(0),
- jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
- array_offset_calculator<float, 5> diff_dst((float *)this->input_memory(1),
- jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
- array_offset_calculator<float, 6> diff_weights((float *)this->memory(0),
- jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
- array_offset_calculator<float, 2> diff_bias(
- conf_.want_padded_bias() ? padded_bias_ : (float *)this->memory(1),
- jcp.oc / simd_w, simd_w);
-
- array_offset_calculator<float, 8> U((float *)(scratchpad_->U_ptr()),
- jcp.nb_oc, jcp.nb_ic,
- alpha, alpha,
- jcp.oc_block, jcp.ic_block,
- jcp.ic_simd_block, jcp.oc_simd_block);
-
- array_offset_calculator<float, 9> Us((float *)(scratchpad_->U_ptr()),
- 0, jcp.nb_oc, jcp.nb_ic,
- alpha, alpha,
- jcp.oc_block, jcp.ic_block,
- jcp.ic_simd_block, jcp.oc_simd_block);
-
- array_offset_calculator<float, 8> M((float *)(scratchpad_->M_ptr()),
- 0, jcp.nb_oc, alpha, alpha, jcp.oc_block,
- jcp.nb_tile_block_ur, jcp.tile_block_ur * jcp.tile_4fma,
- jcp.oc_simd_block);
-
- array_offset_calculator<float, 8> V((float *)(scratchpad_->V_ptr()),
- 0, jcp.nb_ic, alpha, alpha, jcp.ic_block,
- jcp.nb_tile_block_ur, jcp.tile_block_ur,
- jcp.ic_simd_block * jcp.tile_4fma);
-
- array_offset_calculator<float, 2> diff_bias_prv(
- (float *)(scratchpad_->bias_ptr()),
- nthreads, jcp.oc);
-
-#pragma omp parallel
- {
- if (jcp.with_bias) {
- parallel_nd_in_omp(nthreads, jcp.oc,
- [&](int ithr, int ofm) {
- diff_bias_prv(ithr, ofm) = 0.0f;
- });
-#pragma omp for nowait
- for (int bofm = 0; bofm < jcp.oc / simd_w; bofm++) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++)
- diff_bias(bofm, v) = 0.0f;
- }
- }
- }
-
- int th_counter = 0;
-#pragma omp parallel firstprivate(th_counter) num_threads(nthreads)
-#pragma omp for nowait
- for (int tile_block = 0; tile_block < jcp.tile_block; tile_block++) {
- int ithr = mkldnn_get_thread_num();
-
- for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) {
- for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
- diff_src_transform_bwd_weights_ver_tile(tile_block, jcp,
- &(diff_src(0, ifm1 * jcp.ic_block + ifm2,
- 0, 0, 0)),
- &(V(ithr, ifm1, 0, 0, ifm2, 0, 0, 0)),
- kernel_->transpose_4fma_ker);
- }
- }
-
- for (int ofm1 = 0; ofm1 < jcp.nb_oc; ofm1++) {
- for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++) {
- float *dbias = jcp.with_bias
- ? &(diff_bias_prv(ithr,
- simd_w * (ofm1 * jcp.oc_block + ofm2)))
- : NULL;
- diff_dst_transform_bwd_weights_ver(tile_block, jcp,
- &(diff_dst(0, ofm1 * jcp.oc_block + ofm2,
- 0, 0, 0)),
- &(M(ithr, ofm1, 0, 0, ofm2, 0, 0, 0)),
- dbias);
- }
- }
-
- for (int ofm1 = 0; ofm1 < jcp.nb_oc; ofm1++) {
- for (int oj = 0; oj < alpha; oj++) {
- for (int oi = 0; oi < alpha; oi++) {
- for (int ifm1 = 0; ifm1 < jcp.nb_ic; ifm1++) {
- if (th_counter == 0)
- kernel_->gemm_loop_ker_first_iter(
- &(Us(ithr, ofm1, ifm1, oj, oi, 0, 0, 0, 0)),
- &(M(ithr, ofm1, oj, oi, 0, 0, 0, 0)),
- &(V(ithr, ifm1, oj, oi, 0, 0, 0, 0)));
- else
- kernel_->gemm_loop_ker(
- &(Us(ithr, ofm1, ifm1, oj, oi, 0, 0, 0, 0)),
- &(M(ithr, ofm1, oj, oi, 0, 0, 0, 0)),
- &(V(ithr, ifm1, oj, oi, 0, 0, 0, 0)));
- }
- }
- }
- }
- th_counter++;
- }
-
- // Reduce diff-weights
- {
- float *output = (float *)(scratchpad_->U_ptr());
- size_t nelems = jcp.ic * jcp.oc * alpha * alpha;
- float *input_ptrs[max_threads_number];
- for (int i = 0; i < nthreads; i++) {
- input_ptrs[i] = output + nelems * i;
- }
- array_sum(nthreads, output, nelems, input_ptrs);
- }
-
- parallel_nd(jcp.nb_oc, jcp.nb_ic, jcp.oc_block, jcp.ic_block,
- [&](int ofm1, int ifm1, int ofm2, int ifm2) {
- diff_weights_transform_bwd_weights(jcp,
- &(diff_weights(ofm1 * jcp.oc_block + ofm2,
- ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)),
- &(U(ofm1, ifm1, 0, 0, ofm2, ifm2, 0, 0)));
- });
-
-#pragma omp parallel
- if (jcp.with_bias) {
-#pragma omp for
- for (int ofm1 = 0; ofm1 < jcp.oc / simd_w; ofm1++) {
- for (int ithr = 0; ithr < nthreads; ithr++) {
- float* base_bias_ptr = &(diff_bias(ofm1, 0));
- float* base_bias_prv_ptr = &(diff_bias_prv(
- ithr * jcp.oc + ofm1 * simd_w));
- PRAGMA_OMP_SIMD()
- for (int ofm2 = 0; ofm2 < simd_w; ofm2++) {
- base_bias_ptr[ofm2] += base_bias_prv_ptr[ofm2];
- }
- }
- }
- }
-
- _maybe_execute_diff_bias_copy();
-}
}
}
}