1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
18 #include "mkldnn_thread.hpp"
19 #include "type_helpers.hpp"
22 #include "jit_generator.hpp"
24 #include "jit_avx512_common_1x1_convolution.hpp"
30 using namespace mkldnn::impl::status;
31 using namespace mkldnn::impl::memory_format;
32 using namespace mkldnn::impl::memory_tracking::names;
33 using namespace mkldnn::impl::utils;
// data_blk_off(f, n, c, h, w): expands to either (f).blk_off(n, c, w) or
// (f).blk_off(n, c, h, w); the callers in this file pass a
// memory_desc_wrapper as `f`.
// NOTE(review): the 3-argument form drops `h`, so it presumably serves
// tensors with a single spatial dimension -- confirm against the condition
// that selects between the two forms.
35 #define data_blk_off(f, n, c, h, w) \
37 ? (f).blk_off(n, c, w) \
38 : (f).blk_off(n, c, h, w))
// balance2D: two-level work split for thread `ithr` out of `nthr`.
//
// The threads are partitioned into grp_count = min(nx_divider, nthr) groups.
// nthr % grp_count of them ("big" groups) hold nthr / grp_count + 1 threads,
// the remaining groups hold nthr / grp_count threads; the big groups own the
// lowest thread ids. The nx range is then shared out per group and the ny
// range per thread inside its group, both through balance211 (defined
// elsewhere), which receives nx_start/nx_end and ny_start/ny_end as its
// output arguments.
//
// NOTE(review): grp_count is used as a divisor, so nx_divider and nthr are
// presumably expected to be positive -- confirm against callers.
42 template <typename T, typename U>
43 void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end,
44 T nx, T &nx_start, T &nx_end, T nx_divider)
46 const int grp_count = nstl::min(nx_divider, nthr);
47 const int grp_size_big = nthr / grp_count + 1;
48 const int grp_size_small = nthr / grp_count;
49 const int n_grp_big = nthr % grp_count;
50 const int threads_in_big_groups = n_grp_big * grp_size_big;
// Negative while ithr lies inside the big groups; otherwise the position of
// ithr counted from the first thread of the small groups.
52 const int ithr_bound_distance = ithr - threads_in_big_groups;
// grp: group index, grp_ithr: thread index inside that group,
// grp_nthr: number of threads in that group.
53 T grp, grp_ithr, grp_nthr;
54 if (ithr_bound_distance < 0) { // ithr in first groups
55 grp = ithr / grp_size_big;
56 grp_ithr = ithr % grp_size_big;
57 grp_nthr = grp_size_big;
58 } else { // ithr in last groups
59 grp = n_grp_big + ithr_bound_distance / grp_size_small;
60 grp_ithr = ithr_bound_distance % grp_size_small;
61 grp_nthr = grp_size_small;
// The group index selects the nx share; the in-group thread index selects
// the ny share.
64 balance211(nx, grp_count, grp, nx_start, nx_end);
65 balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end);
68 /* convolution forward */
// Forward entry point: collects the raw data pointers, optionally prepares a
// zero-padded copy of the bias in the scratchpad, runs execute_forward_thr
// for every (ithr, nthr) handed out by parallel(), and finally calls
// zero_pad() on output memory primitive 0 when the descriptor requests it.
70 template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
71 void jit_avx512_common_1x1_convolution_fwd_t<src_type, wei_type, dst_type>::
72 execute_forward() const {
// Inputs 0, 1 and 2 are read as src, weights and bias; the output buffer is
// dst. The bias is accessed through dst_data_t.
73 auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
75 reinterpret_cast<const wei_data_t *>(this->input_memory(1));
76 auto bias = reinterpret_cast<const dst_data_t *>(this->input_memory(2));
77 auto dst = reinterpret_cast<dst_data_t *>(this->memory());
79 auto scratchpad = this->scratchpad();
81 auto &jcp = kernel_->jcp;
// Padded bias: the first jcp.oc_without_padding values are copied from the
// user bias into the key_conv_padded_bias scratchpad buffer and the
// remaining jcp.oc - jcp.oc_without_padding entries are set to zero.
82 if (pd()->wants_padded_bias()) {
83 auto padded_bias = scratchpad.template get<dst_data_t>(
84 key_conv_padded_bias);
85 utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
86 utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
87 jcp.oc - jcp.oc_without_padding);
// Per-thread execution; the lambda captures the pointers above by reference.
// NOTE(review): the first argument 0 presumably lets the runtime choose the
// number of threads -- confirm against parallel().
91 parallel(0, [&](const int ithr, const int nthr) {
92 execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad);
95 if (pd()->wants_zero_pad_dst())
96 output_memory_primitive(0)->zero_pad();
// Per-thread forward driver, called with this thread's (ithr, nthr), the raw
// src/weights/bias/dst pointers and the scratchpad grantor.
// It obtains a share of the MB * ngroups * nb_bcast work items and of the
// jcp.nb_load output-channel blocks from balance2D, then walks the
// reduce (icb) / load (ocb) / bcast (iwork) loops in the nesting selected by
// jcp.loop_order, filling `p` and invoking kernel_->jit_ker for each step.
99 template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
100 void jit_avx512_common_1x1_convolution_fwd_t<src_type, wei_type, dst_type>::
101 execute_forward_thr(const int ithr, const int nthr, const src_data_t *src,
102 const wei_data_t *weights, const dst_data_t *bias, dst_data_t *dst,
103 const memory_tracking::grantor_t &scratchpad) const {
104 const memory_desc_wrapper src_d(pd()->src_pd());
105 const memory_desc_wrapper dst_d(pd()->dst_pd());
106 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
// Scratchpad buffer registered under key_conv_rtus_space; used below as the
// rtus_driver_ workspace when pd()->rtus_.reduce_src_ is set.
108 auto rtus_space = scratchpad.get<src_data_t>(key_conv_rtus_space);
// For ndims == 3 the height stride/padding are fixed to 1/0 and the width
// values come from index ndims - 3 == 0; otherwise index 0 is used for the
// height and index ndims - 3 for the width.
110 const int ndims = src_d.ndims();
111 const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
112 const int stride_w = pd()->desc()->strides[ndims - 3];
113 const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
114 const int pad_l = pd()->desc()->padding[0][ndims - 3];
116 const auto &jcp = kernel_->jcp;
117 const int MB = pd()->MB();
118 const int work_amount = MB * jcp.ngroups * jcp.nb_bcast;
// step(): returns `remaining` once fewer than tail_step items are left,
// otherwise default_step.
120 auto step = [](int default_step, int remaining, int tail_step) {
121 assert(default_step <= tail_step);
122 return remaining < tail_step ? remaining : default_step;
// Argument structures handed to kernel_->jit_ker (p) and
// rtus_driver_->ker_ (rp).
125 auto p = jit_1x1_conv_call_s();
127 auto rp = rtus_driver_t<avx512_common>::call_params_t();
// Local aliases: load dimension = output-channel blocks, reduce dimension =
// input-channel blocks, bcast block = output-spatial block.
129 const int nb_oc = jcp.nb_load;
130 const int nb_ic = jcp.nb_reduce;
131 const int nb_ic_blocking = jcp.nb_reduce_blocking;
132 const int os_block = jcp.bcast_block;
// This thread's ranges: bcast_start..bcast_end over work_amount and
// ocb_start..ocb_end over jcp.nb_load (the loops below stop before *_end).
134 int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0};
135 balance2D(nthr, ithr, work_amount, bcast_start, bcast_end,
136 jcp.nb_load, ocb_start, ocb_end, jcp.load_grp_count);
// init_bcast(): decodes iwork through nd_iterator_init into n, g and the
// spatial block osb, picks bcast_step (never reaching past bcast_end),
// derives the input position ih/iw from oh/ow and stores the block size in
// p.bcast_dim.
138 auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step,
139 int &oh, int &ow, int &ih, int &iw)
142 nd_iterator_init(iwork, n, MB, g, jcp.ngroups, osb,
144 bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
145 jcp.nb_bcast_blocking_max);
146 bcast_step = nstl::min(bcast_step, bcast_end - iwork);
148 const int os = osb * os_block;
// Input coordinate = output coordinate * stride - padding, clamped at 0.
152 ih = nstl::max(oh * stride_h - pad_t, 0);
153 iw = nstl::max(ow * stride_w - pad_l, 0);
156 p.bcast_dim = this_block_size(os, jcp.os,
157 bcast_step * os_block);
// init_load(): chooses load_step from the blocks left before ocb_end and
// sets p.load_dim accordingly.
161 auto init_load = [&](int ocb, int &load_step)
163 load_step = step(jcp.nb_load_blocking, ocb_end - ocb,
164 jcp.nb_load_blocking_max);
165 p.load_dim = this_block_size(ocb * jcp.oc_block,
166 ocb_end * jcp.oc_block, load_step * jcp.oc_block);
// init_reduce(): prepares the reduce step starting at icb.
// FLAG_REDUCE_FIRST is raised for icb == 0 and FLAG_REDUCE_LAST when this
// step reaches nb_ic; p.reduce_dim and rp.icb are set from the step size.
169 auto init_reduce = [&](int icb)
171 const int nb_ic_blocking_step =
172 nstl::min(icb + nb_ic_blocking, nb_ic) - icb;
173 p.first_last_flag = 0
174 | (icb == 0 ? FLAG_REDUCE_FIRST : 0)
175 | (icb + nb_ic_blocking_step >= nb_ic
176 ? FLAG_REDUCE_LAST : 0);
178 p.reduce_dim = this_block_size(icb * jcp.ic_block,
179 jcp.ic, nb_ic_blocking_step * jcp.ic_block);
180 rp.icb = p.reduce_dim / jcp.reduce_block;
// inner_ker(): fills the data pointers of p for block (ocb, icb, n, g) at
// the given spatial position and runs the JIT kernel.
183 auto inner_ker = [&](int ocb, int icb, int n, int g, int oh, int ow,
// Output-channel block index counted across groups.
187 const int _ocb = g * nb_oc + ocb;
188 const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow);
190 p.output_data = &dst[dst_off];
191 p.bias_data = &bias[_ocb * jcp.oc_block];
192 p.load_data = &weights[pd()->with_groups()
193 ? weights_d.blk_off(g, ocb, icb)
194 : weights_d.blk_off(ocb, icb)];
// Input-channel block index counted across groups.
196 const int _icb = g * nb_ic + icb;
// With rtus_.reduce_src_ the kernel reads its bcast data from this thread's
// workspace slot. rtus_driver_->ker_ is invoked only while ocb == ocb_start
// (presumably to fill the workspace from src -- confirm), so later ocb
// iterations reuse the same workspace. Otherwise src is read directly.
197 if (pd()->rtus_.reduce_src_) {
198 rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_
199 + _icb * jcp.is * jcp.ic_block;
200 if (ocb == ocb_start) {
201 rp.src = src + data_blk_off(src_d, n, _icb, ih, iw);
202 rtus_driver_->ker_(&rp);
204 p.bcast_data = rp.ws;
206 p.bcast_data = src + data_blk_off(src_d, n, _icb, ih, iw);
// Offset of this output-channel block expressed in bytes of dst_data_t.
208 p.oc_off = _ocb * jcp.oc_block * sizeof(dst_data_t);
210 kernel_->jit_ker(&p);
// Loop nests per jcp.loop_order. The visible nesting (outer -> inner) is:
//   loop_rlb: icb, ocb, iwork
//   loop_lbr: ocb, iwork, icb
//   loop_rbl: icb, iwork, ocb
//   loop_blr: iwork, ocb, icb
// Any other value trips the assert at the end.
213 if (jcp.loop_order == loop_rlb) {
214 for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
217 while (ocb < ocb_end) {
219 init_load(ocb, load_step);
220 int iwork = bcast_start;
221 while (iwork < bcast_end) {
222 int n, g, bcast_step, oh, ow, ih, iw;
223 init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
224 inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
230 } else if (jcp.loop_order == loop_lbr) {
232 while (ocb < ocb_end) {
234 init_load(ocb, load_step);
235 int iwork = bcast_start;
236 while (iwork < bcast_end) {
237 int n, g, bcast_step, oh, ow, ih, iw;
238 init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
239 for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
241 inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
247 } else if (jcp.loop_order == loop_rbl) {
248 for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
250 int iwork = bcast_start;
251 while (iwork < bcast_end) {
252 int n, g, bcast_step, oh, ow, ih, iw;
253 init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
255 while (ocb < ocb_end) {
257 init_load(ocb, load_step);
258 inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
264 } else if (jcp.loop_order == loop_blr) {
265 int iwork = bcast_start;
266 while (iwork < bcast_end) {
267 int n, g, bcast_step, oh, ow, ih, iw;
268 init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
270 while (ocb < ocb_end) {
272 init_load(ocb, load_step);
273 for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
275 inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
282 assert(!"unsupported loop order");
// Explicit instantiations of the forward primitive: one with only
// data_type::f32 given (the remaining template arguments come from the
// declaration elsewhere), and one with s16 source, s16 weights and s32
// destination.
287 template struct jit_avx512_common_1x1_convolution_fwd_t<data_type::f32>;
288 template struct jit_avx512_common_1x1_convolution_fwd_t<data_type::s16,
289 data_type::s16, data_type::s32>;
290 /* convolution backward wrt data */
// Backward-by-data driver: input 0 is read as diff_dst, input 1 as weights,
// and the output buffer is written as diff_src. All kernel calls happen
// inside a parallel() region whose per-thread share of the
// MB * ngroups * nb_bcast work items and of the jcp.nb_load input-channel
// blocks comes from balance2D.
292 template <data_type_t diff_dst_type, data_type_t wei_type,
293 data_type_t diff_src_type>
294 void jit_avx512_common_1x1_convolution_bwd_data_t<diff_dst_type, wei_type,
295 diff_src_type>::execute_backward_data() const {
296 auto diff_dst = reinterpret_cast<const diff_dst_data_t *>
297 (this->input_memory(0));
298 auto weights = reinterpret_cast<const wei_data_t *>
299 (this->input_memory(1));
300 auto diff_src = reinterpret_cast<diff_src_data_t *>(this->memory());
302 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
303 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
304 const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
// Scratchpad buffer registered under key_conv_rtus_space; used as the
// kernel output workspace when pd()->rtus_.reduce_src_ is set.
306 auto rtus_space = scratchpad().template get<diff_src_data_t>(
307 key_conv_rtus_space);
309 const int ndims = diff_src_d.ndims();
310 const auto &jcp = kernel_->jcp;
311 const int MB = pd()->MB();
313 // TODO (Roma): remove this restriction
314 assert(jcp.stride_w == 1 && jcp.stride_h == 1);
// For ndims == 3 the height stride/padding are fixed to 1/0 and the width
// values come from index ndims - 3 == 0; otherwise index 0 is used for the
// height and index ndims - 3 for the width.
316 const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
317 const int stride_w = pd()->desc()->strides[ndims - 3];
318 const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
319 const int pad_l = pd()->desc()->padding[0][ndims - 3];
// Local aliases: here the load dimension is input-channel blocks and the
// reduce dimension is output-channel blocks.
321 const int nb_ic = jcp.nb_load;
322 const int nb_oc = jcp.nb_reduce;
323 const int os_block = jcp.bcast_block;
324 const int nb_oc_blocking = jcp.nb_reduce_blocking;
326 const int work_amount = MB * jcp.ngroups * jcp.nb_bcast;
// step(): returns `remaining` once fewer than tail_step items are left,
// otherwise default_step.
328 auto step = [](int default_step, int remaining, int tail_step) {
329 assert(default_step <= tail_step);
330 return remaining < tail_step ? remaining : default_step;
333 parallel(0, [&](const int ithr, const int nthr) {
// Per-thread argument structures for kernel_->jit_ker (p) and
// rtus_driver_->ker_ (rp).
334 auto p = jit_1x1_conv_call_s();
335 auto rp = rtus_driver_t<avx512_common>::call_params_t();
// This thread's ranges: bcast_start..bcast_end over work_amount and
// icb_start..icb_end over jcp.nb_load.
337 int bcast_start{0}, bcast_end{0}, icb_start{0}, icb_end{0};
338 balance2D(nthr, ithr, work_amount, bcast_start, bcast_end,
339 jcp.nb_load, icb_start, icb_end, jcp.load_grp_count);
// reduce_outer: for loop_rbl / loop_rlb the reduction over oc blocks is the
// outermost loop (steps of nb_oc_blocking) and the inner oc loop runs once;
// for the other loop orders the roles are swapped.
341 bool reduce_outer = (jcp.loop_order == loop_rbl
342 || jcp.loop_order == loop_rlb);
343 int nboc_outer = reduce_outer ? nb_oc : 1;
344 int ocb_outer_step = reduce_outer ? nb_oc_blocking : 1;
346 int nboc_inner = reduce_outer ? 1 : nb_oc;
347 int ocb_inner_step = reduce_outer ? 1 : nb_oc_blocking;
349 for (int ocb_outer = 0; ocb_outer < nboc_outer;
350 ocb_outer += ocb_outer_step) {
351 size_t cur_ocb_outer =
352 nstl::min(ocb_outer + ocb_outer_step, nboc_outer) - ocb_outer;
// Load loop over this thread's input-channel blocks.
// NOTE(review): `remaining` is jcp.nb_load - icb here, whereas the forward
// driver passes ocb_end - ocb; p.load_dim is still computed against
// icb_end -- confirm this asymmetry is intended.
355 for (int icb = icb_start; icb < icb_end; icb += load_step) {
356 load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb,
357 jcp.nb_load_blocking_max);
359 p.load_dim = this_block_size(icb * jcp.ic_block,
360 icb_end * jcp.ic_block, load_step * jcp.ic_block);
361 rp.icb = p.load_dim / jcp.ic_block;
364 for (int iwork = bcast_start; iwork < bcast_end;
367 int n{0}, g{0}, osb{0};
368 nd_iterator_init(iwork, n, MB, g, jcp.ngroups, osb,
371 bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
372 jcp.nb_bcast_blocking_max);
373 bcast_step = nstl::min(bcast_step, bcast_end - iwork);
375 const int os = osb * os_block;
376 p.bcast_dim = this_block_size(os, jcp.os,
377 bcast_step * os_block);
// Spatial position of this block in diff_dst (oh, ow) and the matching
// diff_src position: output * stride - padding, clamped at 0.
380 const int oh = os / jcp.ow;
381 const int ow = os % jcp.ow;
382 const int ih = nstl::max(oh * stride_h - pad_t, 0);
383 const int iw = nstl::max(ow * stride_w - pad_l, 0);
// rp.src always points at the diff_src block; with rtus_.reduce_src_ the
// kernel output goes to the workspace rp.ws instead of rp.src.
386 const int _icb = g * nb_ic + icb;
387 rp.src = diff_src + data_blk_off(diff_src_d, n, _icb, ih, iw);
388 if (pd()->rtus_.reduce_src_) {
390 + ithr * pd()->rtus_.space_per_thread_;
391 p.output_data = rp.ws;
393 p.output_data = rp.src;
395 for (int ocb_inner = 0; ocb_inner < nboc_inner;
396 ocb_inner += ocb_inner_step) {
398 nstl::min(ocb_inner + ocb_inner_step, nboc_inner) -
// Take the oc block index and step from whichever oc loop is the active
// reduction loop.
401 int ocb = reduce_outer ? ocb_outer : ocb_inner;
402 int nb_oc_blocking_step = reduce_outer
403 ? cur_ocb_outer : cur_ocb_inner;
404 const int _ocb = g * nb_oc + ocb;
405 size_t diff_dst_off = data_blk_off(diff_dst_d, n, _ocb, oh, ow);
406 p.bcast_data = &diff_dst[diff_dst_off];
408 p.load_data = &weights[pd()->with_groups()
409 ? weights_d.blk_off(g, ocb, icb)
410 : weights_d.blk_off(ocb, icb)];
// Only FLAG_REDUCE_FIRST is used here: set when the reduction step starts
// at oc block 0.
412 p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0;
414 p.reduce_dim = this_block_size(ocb * jcp.oc_block,
415 jcp.oc, nb_oc_blocking_step * jcp.oc_block);
417 kernel_->jit_ker(&p);
// With rtus_.reduce_src_, rtus_driver_->ker_ is run on rp after the kernel
// call -- presumably to move the workspace contents into diff_src
// (rp.src); confirm against the rtus driver.
419 if (pd()->rtus_.reduce_src_)
420 rtus_driver_->ker_(&rp);
// Explicit instantiations of the backward-data primitive: one with only
// data_type::f32 given (the remaining template arguments come from the
// declaration elsewhere), and one with s16 diff_dst, s16 weights and s32
// diff_src.
427 template struct jit_avx512_common_1x1_convolution_bwd_data_t<data_type::f32>;
428 template struct jit_avx512_common_1x1_convolution_bwd_data_t<data_type::s16,
429 data_type::s16, data_type::s32>;
431 /* convolution backward wrt weights */
// wht_blk_off(d, g, ...): offset lookup on descriptor `d` (used below with
// the diff_weights descriptor). With pd()->with_groups() the group index g
// is passed as the first blk_off argument; otherwise g is dropped and only
// the remaining arguments are forwarded. The expansion calls pd(), so the
// macro is usable only where pd() is in scope.
433 #define wht_blk_off(d, g, ...) \
434 (pd()->with_groups() \
435 ? (d).blk_off((g), __VA_ARGS__) \
436 : (d).blk_off(__VA_ARGS__))
// Constructor: forwards (apd, inputs, outputs) to cpu_primitive_t, starts
// with all helper pointers null, then allocates the JIT kernel, the f32
// accumulator and the f32 bias reducer with `new` and calls
// init_rtus_driver<avx512_common>(this). When jcp.transpose_src is set, a
// jit_transpose4x16_src kernel is created as well.
// NOTE(review): these are raw owning pointers and their release is not part
// of this block -- presumably handled by the destructor; confirm.
438 jit_avx512_common_1x1_convolution_bwd_weights_t ::
439 jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd,
440 const input_vector &inputs, const output_vector &outputs)
441 : cpu_primitive_t(apd, inputs, outputs)
442 , kernel_(nullptr), acc_ker_(nullptr), reducer_bias_(nullptr)
443 , trans_kernel_(nullptr), rtus_driver_(nullptr)
445 kernel_ = new jit_avx512_common_1x1_conv_kernel(pd()->jcp_, *pd()->attr());
446 acc_ker_ = new cpu_accumulator_1d_t<data_type::f32>();
447 reducer_bias_ = new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_);
448 init_rtus_driver<avx512_common>(this);
450 const auto &jcp = kernel_->jcp;
// Source-transposition kernel, created only when the kernel configuration
// requests it. The tp fields look like prefetch settings judging by their
// names -- TODO confirm against jit_transpose4x16_src_t.
452 if (jcp.transpose_src) {
453 auto tp = jit_transpose4x16_src_t();
454 tp.src_pf0_distance = 4;
455 tp.tr_src_pf0_distance = 0;
457 tp.tr_src_pf1 = false;
458 trans_kernel_ = new jit_transpose4x16_src(&jcp, &tp);
462 void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights() const
464 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
465 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
466 auto diff_weights = reinterpret_cast<data_t *>(this->memory(0));
467 auto diff_bias_in = reinterpret_cast<data_t *>(this->memory(1));
469 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
470 const memory_desc_wrapper src_d(pd()->src_pd());
471 const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
473 const auto &jcp = kernel_->jcp;
475 const auto scratchpad = this->scratchpad();
477 auto rtus_space = scratchpad.get<data_t>(key_conv_rtus_space);
478 data_t *diff_bias = pd()->wants_padded_bias()
479 ? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in;
480 auto wei_reduction = scratchpad.get<data_t>(key_conv_wei_reduction);
482 /* prepare src transposition barriers */
483 auto tr_src = scratchpad.get<data_t>(key_conv_tr_src);
484 auto tr_src_bctx = scratchpad.get<simple_barrier::ctx_t>(
485 key_conv_tr_src_bctx);
486 if (jcp.transpose_src) {
487 for (int i = 0; i < jcp.nthr; ++i)
488 simple_barrier::ctx_init(&tr_src_bctx[i]);
491 const int ndims = src_d.ndims();
492 const int wei_size = jcp.ngroups * jcp.oc * jcp.ic;
494 simple_barrier::ctx_t reduction_barrier;
495 simple_barrier::ctx_init(&reduction_barrier);
497 const auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
499 auto rb = this->reducer_bias_;
500 rb->init(reducer_bia_scratchpad);
502 // TODO (Roma): remove this restriction
503 assert(jcp.stride_w == 1 && jcp.stride_h == 1);
505 const int nb_ic = jcp.nb_bcast;
506 const int nb_ic_blocking = jcp.nb_bcast_blocking;
508 const int nb_oc = jcp.nb_load;
509 const int nb_oc_blocking = jcp.nb_load_blocking;
511 const int sp_nb = jcp.nb_reduce;
512 const int mb_sp_work = jcp.mb * sp_nb;
514 const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
515 const int stride_w = pd()->desc()->strides[ndims - 3];
516 const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
517 const int pad_l = pd()->desc()->padding[0][ndims - 3];
519 auto step = [](int default_step, int remaining, int tail_step) {
520 assert(default_step <= tail_step);
521 return remaining < tail_step ? remaining : default_step;
524 // TODO: use memory descriptor with the same fmt as src
525 // (or use a macro :))
526 auto tr_src_off = [&](int img, int icb, int is) {
527 const size_t tr_chn_size = jcp.tr_is * jcp.ic_block;
528 const size_t tr_img_size = tr_chn_size * nb_ic * jcp.ngroups;
529 return img * tr_img_size + icb * tr_chn_size + is * jcp.ic_block;
532 auto uker_trans = [&](int ithr_mb, int img, int sp_b_start, int sp_size,
533 int g_start, int g_work, int ic_b_start, int ic_b_work,
534 int ithr, int nthr, int first_ic_b)
536 const int work_amount = g_work * ic_b_work;
538 int start{ 0 }, end{ 0 };
539 balance211(work_amount, nthr, ithr, start, end);
541 int g{ 0 }, ic_b{ 0 };
542 nd_iterator_init(start, g, g_work, ic_b, ic_b_work);
544 const int ic_b_tr = g * nb_ic + first_ic_b + ic_b;
547 const int _ic = g * nb_ic + ic_b;
549 const int is = sp_b_start * jcp.reduce_block;
550 const int ih = is / jcp.iw;
551 const int iw = is % jcp.iw;
553 const int src1_off = data_blk_off(src_d, img, _ic, ih, iw);
554 data_t *src1 = (data_t *)&src[src1_off];
555 data_t *tr_src1 = &tr_src[tr_src_off(ithr_mb, ic_b_tr, is)];
557 assert(jcp.ic_block == 16);
558 const int src_stride = jcp.is * jcp.ic_block;
559 const int tr_src_stride = jcp.tr_is * jcp.ic_block;
561 const int my_work = end - start;
562 for (int iwork = 0; iwork < my_work; iwork++) {
563 auto par_trans = jit_src_transpose_s();
564 assert(sp_size % 4 == 0 || sp_size % 4 == jcp.is % 4);
565 par_trans.size = sp_size;
566 par_trans.src = src1;
567 par_trans.tr_src = tr_src1;
568 par_trans.src_prf = src1 + 64 * 16;
569 par_trans.tr_src_prf = tr_src1 + 80 * 16;
570 trans_kernel_->jit_ker(&par_trans);
573 tr_src1 += tr_src_stride;
577 auto ker = [&](const int ithr, const int nthr) {
578 assert(nthr == jcp.nthr);
579 assert(IMPLICATION(!mkldnn_thr_syncable(), jcp.nthr_mb == 1));
581 const int ithr_ic_b = ithr % jcp.nthr_ic_b;
582 const int ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b;
583 const int ithr_g = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g;
584 const int ithr_mb = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b /
587 const int ithr_but_oc
588 = (ithr_mb * jcp.nthr_g + ithr_g) * jcp.nthr_ic_b + ithr_ic_b;
590 /* reduction dimension */
591 int mb_sp_b_start{ 0 }, mb_sp_b_end{ 0 };
592 if (jcp.transpose_src && jcp.nthr_mb < jcp.mb / 2) {
593 // it's preferable to parallelize by mb if possible
594 int img_start{ 0 }, img_end{ 0 };
595 balance211(jcp.mb, jcp.nthr_mb, ithr_mb, img_start, img_end);
596 mb_sp_b_start = img_start * sp_nb;
597 mb_sp_b_end = img_end * sp_nb;
600 balance211(mb_sp_work, jcp.nthr_mb, ithr_mb, mb_sp_b_start,
604 /* independent dimensions */
605 int g_start{ 0 }, oc_b_start{ 0 }, ic_b_start{ 0 };
606 int g_end{ 0 }, oc_b_end{ 0 }, ic_b_end{ 0 };
608 balance211(jcp.ngroups, jcp.nthr_g, ithr_g, g_start, g_end);
609 balance211(jcp.nb_load, jcp.nthr_oc_b, ithr_oc_b, oc_b_start,
611 balance211(jcp.nb_bcast, jcp.nthr_ic_b, ithr_ic_b, ic_b_start,
614 const int g_work = g_end - g_start;
615 const int oc_b_work = oc_b_end - oc_b_start;
616 const int ic_b_work = ic_b_end - ic_b_start;
618 data_t *diff_wei = ithr_mb == 0
619 ? diff_weights : wei_reduction + (ithr_mb - 1) * wei_size;
622 for (int mb_sp_b = mb_sp_b_start; mb_sp_b < mb_sp_b_end;
623 mb_sp_b += sp_b_step) {
624 int img{ 0 }, sp_b{ 0 };
625 nd_iterator_init(mb_sp_b, img, jcp.mb, sp_b, sp_nb);
626 sp_b_step = step(jcp.nb_reduce_blocking,
627 nstl::min(sp_nb - sp_b, mb_sp_b_end - mb_sp_b),
628 jcp.nb_reduce_blocking_max);
630 for (int g = g_start; g < g_end; ++g) {
633 for (int ic_b = ic_b_start; ic_b < ic_b_end;
634 ic_b += bcast_step) {
635 bcast_step = step(nb_ic_blocking, ic_b_end - ic_b,
636 jcp.nb_bcast_blocking_max);
637 if (jcp.transpose_src) {
638 if (jcp.nthr_oc_b > 1)
639 simple_barrier::barrier(
640 &tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b);
642 = nstl::min(sp_b_step * jcp.reduce_block,
643 jcp.is - sp_b * jcp.reduce_block);
644 uker_trans(ithr_mb, img, sp_b, sp_size, g, 1, ic_b,
645 bcast_step, ithr_oc_b, jcp.nthr_oc_b, ic_b_start);
646 if (jcp.nthr_oc_b > 1)
647 simple_barrier::barrier(
648 &tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b);
651 for (int oc_b = oc_b_start; oc_b < oc_b_end;
653 load_step = step(nb_oc_blocking, oc_b_end - oc_b,
654 jcp.nb_load_blocking_max);
655 const int _ic_b = g * nb_ic + ic_b;
656 const int _ic_b_tr = g * nb_ic + ic_b_start;
657 const int _oc_b = g * nb_oc + oc_b;
662 = wht_blk_off(diff_weights_d, g, oc_b, ic_b);
663 store_to = diff_wei + off;
665 const data_t *diff_src = jcp.transpose_src ?
666 &tr_src[tr_src_off(ithr_mb, _ic_b_tr, 0)] :
667 &src[src_d.blk_off(img, _ic_b)];
669 int sp_b_end = sp_b + sp_b_step;
670 const data_t *pdiff_dst
671 = &diff_dst[diff_dst_d.blk_off(img, _oc_b)];
672 const data_t *local_src = diff_src;
674 auto p = jit_1x1_conv_call_s();
675 auto rp = rtus_driver_t<avx512_common>::call_params_t();
678 = jcp.ic * jcp.oc_block * jcp.typesize_out;
680 p.load_dim = load_step * jcp.oc_block;
682 p.bcast_dim = bcast_step * jcp.ic_block;
684 p.output_data = store_to;
686 p.reduce_dim = sp_b_step * jcp.reduce_block;
687 rp.os = p.reduce_dim;
689 p.first_last_flag = 0
690 | (mb_sp_b == mb_sp_b_start ? FLAG_REDUCE_FIRST : 0)
691 | (sp_b_end == sp_nb ? FLAG_SP_LAST : 0);
693 int sp = sp_b * jcp.reduce_block;
694 p.load_data = pdiff_dst + sp * jcp.oc_block;
696 if (pd()->rtus_.reduce_src_) {
697 const int oh = sp / jcp.ow;
698 const int ow = sp % jcp.ow;
700 const int ih = nstl::max(oh * stride_h - pad_t, 0);
701 const int iw = nstl::max(ow * stride_w - pad_l, 0);
705 + ithr * pd()->rtus_.space_per_thread_
709 rp.src = local_src + iw
710 * src_d.blocking_desc().strides[0][2];
712 rp.src = local_src + ih
713 * src_d.blocking_desc().strides[0][2]
714 + iw * src_d.blocking_desc().strides[0][3];
715 rtus_driver_->ker_(&rp);
717 p.bcast_data = rp.ws;
719 p.bcast_data = local_src + sp * jcp.ic_block;
721 kernel_->jit_ker(&p);
727 /* diff_weights[:] += sum(wei_reduction[thr_mb][:]) */
728 if (jcp.nthr_mb > 1) {
729 simple_barrier::barrier(&reduction_barrier, jcp.nthr);
730 const int work = g_work * oc_b_work * ic_b_work;
731 int start{ 0 }, end{ 0 };
732 balance211(work, jcp.nthr_mb, ithr_mb, start, end);
736 for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) {
738 int sub_g_start{ 0 }, sub_oc_b_start{ 0 },
740 nd_iterator_init(w, sub_g_start, g_work, sub_oc_b_start,
741 oc_b_work, sub_ic_b_start, ic_b_work);
743 const int g = g_start + sub_g_start;
744 const int oc_b = oc_b_start + sub_oc_b_start;
745 const int ic_b = ic_b_start + sub_ic_b_start;
748 = nstl::min(end - w, ic_b_work - sub_ic_b_start)
749 * jcp.ic_block * jcp.oc_block;
752 = wht_blk_off(diff_weights_d, g, oc_b, ic_b);
753 data_t *d = diff_weights + off;
754 data_t *s = wei_reduction + (thr_mb - 1) * wei_size + off;
756 acc_ker_->accumulate(d, s, acc_size);
758 nd_iterator_jump(w, end, sub_g_start, g_work,
759 sub_oc_b_start, oc_b_work, sub_ic_b_start,
766 auto ker_bias = [&](int ithr, int nthr) {
767 assert(nthr == rb->balancer().nthr_);
769 const int b_job_start = rb->balancer().ithr_job_off(ithr);
770 const int b_njobs = rb->balancer().ithr_njobs(ithr);
775 /* reduction dimension */
776 int img_start{ 0 }, img_end{ 0 };
778 balance211(jcp.mb, rb->balancer().nthr_per_group_,
779 rb->balancer().id_in_group(ithr), img_start, img_end);
782 int g_start{ 0 }, ocb_start{ 0 };
784 b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_load);
786 for (int img = img_start; img < img_end; ++img) {
787 int g = g_start, ocb = ocb_start;
788 for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
789 const size_t _oc = g * jcp.nb_load + ocb;
791 const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)];
792 data_t *d_bias = rb->get_local_ptr(ithr, diff_bias,
793 reducer_bia_scratchpad)
794 + b_job_loc * rb->balancer().job_size_;
796 if (img == img_start)
797 for (int o = 0; o < 16; ++o)
800 for (int hw = 0; hw < jcp.oh * jcp.ow; ++hw) {
802 for (int o = 0; o < 16; ++o)
803 d_bias[o] += d_dst[o];
807 nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_load);
810 rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
813 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
815 if (pd()->with_bias())
816 ker_bias(ithr, jcp.nthr);
819 /* TODO: put this in ker_bias */
820 if (pd()->wants_padded_bias()) {
821 assert(jcp.ngroups == 1);
822 utils::array_copy(diff_bias_in, diff_bias, jcp.oc_without_padding);