1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
18 #include "mkldnn_thread.hpp"
19 #include "type_helpers.hpp"
22 #include "jit_generator.hpp"
24 #include "jit_avx2_1x1_convolution.hpp"
30 using namespace mkldnn::impl::status;
31 using namespace mkldnn::impl::memory_format;
32 using namespace mkldnn::impl::memory_tracking::names;
33 using namespace mkldnn::impl::utils;
35 #define data_blk_off(f, n, c, h, w) \
37 ? (f).blk_off(n, c, w) \
38 : (f).blk_off(n, c, h, w))
40 /* convolution forward */
/* Forward pass of the AVX2 1x1 JIT convolution.
 * Work is partitioned as (minibatch x groups x bcast blocks) across threads
 * via balance211; each thread iterates over output-channel blocks and reduces
 * over input-channel blocks, calling the JIT kernel per (ocb, icb) tile.
 * NOTE(review): this excerpt is missing interleaved source lines (the leading
 * numbers are original file line numbers); code below is reproduced verbatim. */
42 void jit_avx2_1x1_convolution_fwd_t::execute_forward() const {
// Raw data pointers for the primitive's inputs (src, weights, bias) and output.
43 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
44 auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
45 auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
46 auto dst = reinterpret_cast<data_t *>(this->memory());
48 const memory_desc_wrapper src_d(pd()->src_pd());
49 const memory_desc_wrapper dst_d(pd()->dst_pd());
50 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
// Scratch area used by the reduce-to-unit-stride (rtus) driver when the
// source must be repacked before the 1x1 kernel can consume it.
52 auto rtus_space = scratchpad().get<data_t>(key_conv_rtus_space);
54 const auto &jcp = kernel_->jcp;
55 const int MB = pd()->MB();
57 const int work_amount = MB * jcp.ngroups * jcp.nb_bcast;
58 const int ndims = dst_d.ndims();
// 3D (no-height) case uses unit stride / zero top pad; [ndims - 3] picks the
// width component for both 3D and 4D descriptors.
60 const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
61 const int stride_w = pd()->desc()->strides[ndims - 3];
62 const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
63 const int pad_l = pd()->desc()->padding[0][ndims - 3];
// Clamp a default block step to whatever work remains (tail handling).
65 auto step = [](int default_step, int remaining, int tail_step) {
66 assert(default_step <= tail_step);
67 return remaining < tail_step ? remaining : default_step;
// Per-thread body; presumably invoked through a parallel() dispatch that is
// among the lines omitted from this excerpt — TODO confirm against full file.
70 auto ker = [&](const int ithr, const int nthr) {
71 // TODO (Roma): remove this restriction
72 assert(jcp.stride_w == 1 && jcp.stride_h == 1);
74 auto p = jit_1x1_conv_call_s();
75 auto rp = rtus_driver_t<avx2>::call_params_t();
77 const int nb_oc = jcp.nb_load;
78 const int nb_ic = jcp.nb_reduce;
79 const int nb_ic_blocking = jcp.nb_reduce_blocking;
80 const int os_block = jcp.bcast_block;
// Split [0, work_amount) evenly across nthr threads.
83 balance211(work_amount, nthr, ithr, start, end);
// Decompose the linear work index into (minibatch, group, bcast block).
87 int n{0}, g{0}, osb{0};
88 nd_iterator_init(iwork, n, MB, g, jcp.ngroups, osb,
91 int bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
92 jcp.nb_bcast_blocking_max);
93 bcast_step = nstl::min(bcast_step, end - iwork);
// Map the spatial block index to output (oh, ow) and input (ih, iw) origins.
95 const int os = osb * os_block;
96 const int oh = os / jcp.ow;
97 const int ow = os % jcp.ow;
99 const int ih = nstl::max(oh * stride_h - pad_t, 0);
100 const int iw = nstl::max(ow * stride_w - pad_l, 0);
103 p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block);
// Walk output-channel blocks in load_step-sized chunks.
107 while (ocb < jcp.nb_load) {
108 const int load_step = step(jcp.nb_load_blocking,
109 jcp.nb_load - ocb, jcp.nb_load_blocking_max);
111 const int _ocb = g * nb_oc + ocb;
112 p.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc,
113 load_step * jcp.oc_block);
114 const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow);
116 p.output_data = &dst[dst_off];
118 p.bias_data = &bias[_ocb * jcp.oc_block];
// Reduction over input-channel blocks; flags tell the kernel when to
// initialize the accumulator and when to apply post-ops/store.
120 for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
121 p.first_last_flag = 0
122 | (icb == 0 ? FLAG_REDUCE_FIRST : 0)
123 | (icb + nb_ic_blocking >= nb_ic
124 ? FLAG_REDUCE_LAST : 0);
126 p.reduce_dim = this_block_size(icb * jcp.ic_block, jcp.ic,
127 nb_ic_blocking * jcp.ic_block);
128 rp.icb = p.reduce_dim / jcp.reduce_block;
130 p.load_data = &weights[pd()->with_groups()
131 ? weights_d.blk_off(g, ocb, icb)
132 : weights_d.blk_off(ocb, icb)];
134 const int _icb = g * nb_ic + icb;
// If strided access must be flattened first, run the rtus repack kernel
// into this thread's slice of rtus_space and feed its output instead.
// NOTE(review): the `rp.ws = rtus_space ...` assignment head is one of
// the omitted lines here.
135 if (pd()->rtus_.reduce_src_) {
137 + ithr * pd()->rtus_.space_per_thread_
138 + _icb * jcp.is * jcp.ic_block;
141 rp.src = src + data_blk_off(src_d, n, _icb, ih, iw);
142 rtus_driver_->ker_(&rp);
145 p.bcast_data = rp.ws;
147 p.bcast_data = src + data_blk_off(src_d, n, _icb, ih, iw);
149 p.oc_off = _ocb * jcp.oc_block * sizeof(float);
151 kernel_->jit_ker(&p);
// When OC is padded up to a multiple of the block size, stage the bias in a
// zero-extended scratch copy so the kernel can read full blocks.
161 if (pd()->wants_padded_bias()) {
162 auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
163 utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
164 utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
165 jcp.oc - jcp.oc_without_padding);
// Zero the destination's padded region if the descriptor requires it.
171 if (pd()->wants_zero_pad_dst())
172 output_memory_primitive(0)->zero_pad();
/* Forward pass of the 1x1 convolution fused with a following depthwise (dw)
 * convolution. The 1x1 result rows are written into a small per-thread ring
 * buffer (pbuf, jcp_dw.kh rows deep per channel block) and consumed by the
 * depthwise kernel one output row at a time, avoiding a full intermediate
 * tensor.
 * NOTE(review): this excerpt omits interleaved source lines (the leading
 * numbers are original file line numbers); code below is reproduced verbatim. */
175 void jit_avx2_1x1_convolution_fwd_t::execute_forward_with_dw_conv() const {
176 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
177 auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
178 auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
179 auto dst = reinterpret_cast<data_t *>(this->memory());
181 const memory_desc_wrapper src_d(pd()->src_pd());
182 const memory_desc_wrapper dst_d(pd()->dst_pd());
183 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
185 auto rtus_space = scratchpad().get<data_t>(key_conv_rtus_space);
// jcp: 1x1 kernel config; jcp_dw: fused depthwise kernel config.
187 const auto &jcp = kernel_->jcp;
188 const auto &jcp_dw = kernel_dw_->jcp;
189 const int MB = pd()->MB();
// dw bias pointer; may be redirected to a padded scratch copy below.
191 auto dw_bias = jcp_dw.conv_biases;
// With fusion, output-channel blocks become an extra parallelization axis.
193 int ocb_work = jcp.with_dw_conv ? utils::div_up(jcp.nb_load, jcp.nb_load_blocking) : 1;
194 const int work_amount = MB * jcp.ngroups * ocb_work * jcp.nb_bcast;
// Clamp a default block step to the remaining work (tail handling).
196 auto step = [](int default_step, int remaining, int tail_step) {
197 assert(default_step <= tail_step);
198 return remaining < tail_step ? remaining : default_step;
// Per-thread body; presumably dispatched via parallel() in a line omitted
// from this excerpt — TODO confirm against the full file.
201 auto ker = [&](const int ithr, const int nthr) {
202 // TODO (Roma): remove this restriction
203 assert(jcp.stride_w == 1 && jcp.stride_h == 1);
// Compute `num_rows` rows of the 1x1 convolution into the row ring buffer
// ws_p. Rows outside the valid range are zero-filled so the depthwise
// kernel can read its full kh window. (Trailing parameter list of this
// lambda is cut by an omitted line.)
205 auto compute_block_1x1 = [&](float* ws_p, int n, int g, int oh, int ow, int ih, int iw, int os, int os_block, int bcast_step, int ocb, int load_step,
207 auto rp = rtus_driver_t<avx2>::call_params_t();
208 auto p = jit_1x1_conv_call_s();
210 for (int h = 0; h < num_rows; h++) {
211 ih = nstl::max((oh + h) * jcp.stride_h - jcp.t_pad, 0);
// Out-of-bounds row: zero this row's slot in each channel block's ring.
213 if ((oh + h) < 0 || (oh + h) >= jcp.ih) {
214 for (int chb = ocb; chb < ocb + load_step; chb++) {
215 memset(ws_p + (((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block +
216 (chb - ocb) * jcp_dw.kh * jcp.ow * jcp.oc_block, 0, jcp.ow * jcp.oc_block * sizeof(float));
219 const int _ocb = g * jcp.nb_load + ocb;
222 p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block);
225 p.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, load_step * jcp.oc_block);
// Output goes into the ring-buffer slot for row (oh + h), not into dst.
227 p.output_data = &ws_p[(((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block];
229 p.bias_data = &bias[_ocb * jcp.oc_block];
// Reduce over input-channel blocks, flagging first/last iterations.
231 for (int icb = 0; icb < jcp.nb_reduce; icb += jcp.nb_reduce_blocking) {
232 p.first_last_flag = 0
233 | (icb == 0 ? FLAG_REDUCE_FIRST : 0)
234 | (icb + jcp.nb_reduce_blocking >= jcp.nb_reduce
235 ? FLAG_REDUCE_LAST : 0);
237 p.reduce_dim = this_block_size(icb * jcp.ic_block, jcp.ic,
238 jcp.nb_reduce_blocking * jcp.ic_block);
239 rp.icb = p.reduce_dim / jcp.reduce_block;
241 p.load_data = &weights[pd()->with_groups()
242 ? weights_d.blk_off(g, ocb, icb)
243 : weights_d.blk_off(ocb, icb)];
245 const int _icb = g * jcp.nb_reduce + icb;
// Optional rtus repack into this thread's scratch slice; the `rp.ws =`
// assignment head is among the omitted lines here.
246 if (pd()->rtus_.reduce_src_) {
248 + ithr * pd()->rtus_.space_per_thread_
249 + _icb * jcp.is * jcp.ic_block;
252 rp.src = src + src_d.blk_off(n, _icb, ih, iw);
253 rtus_driver_->ker_(&rp);
256 p.bcast_data = rp.ws;
258 p.bcast_data = src + src_d.blk_off(n, _icb, ih, iw);
261 p.oc_off = _ocb * jcp.oc_block * sizeof(float);
263 kernel_->jit_ker(&p);
// Run the depthwise kernel for one output row: it reads three consecutive
// ring-buffer rows (kh presumably 3 here — TODO confirm) per channel block
// and writes the final row into dst.
269 auto compute_row_dw = [&](const float* ws_p, int n, int ocb, int load_step, int dst_idx) {
271 for (int chb = ocb; chb < ocb + load_step; chb++) {
272 auto par_conv_dw = jit_conv_call_s();
274 par_conv_dw.src_row0 = &ws_p[(((dst_idx+1) - 1) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
275 (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
276 par_conv_dw.src_row1 = &ws_p[(((dst_idx+1) - 0) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
277 (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
278 par_conv_dw.src_row2 = &ws_p[(((dst_idx+1) + 1) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
279 (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
// Destination offset assumes blocked-channel layout (n, C/8, oh, ow, 8).
281 par_conv_dw.dst = &dst[n*jcp_dw.oc*jcp_dw.oh*jcp_dw.ow + chb*jcp_dw.ch_block*jcp_dw.oh*jcp_dw.ow +
282 dst_idx/jcp_dw.stride_h*jcp_dw.ow*jcp_dw.ch_block];
284 par_conv_dw.kh_padding = jcp_dw.kh;
285 par_conv_dw.filt = &jcp_dw.conv_weights[chb * jcp_dw.kh * jcp_dw.kw * jcp_dw.ch_block];
286 par_conv_dw.bias = &dw_bias[chb * jcp_dw.ch_block];
287 par_conv_dw.ur_w = (size_t)(jcp_dw.ow);
// Handle a possibly partial last channel block.
288 par_conv_dw.oc_work = nstl::min((chb + 1) * jcp_dw.ch_block, (int)jcp_dw.oc) - chb*jcp_dw.ch_block;
289 par_conv_dw.oc_off = chb * jcp_dw.ch_block * sizeof(float);
291 kernel_dw_->jit_ker(&par_conv_dw);
295 assert(jcp.stride_w == 1 && jcp.stride_h == 1);
297 int start{0}, end{0};
298 balance211(work_amount, nthr, ithr, start, end);
// Per-thread ring buffer: kh rows x iw x ch_block per output-channel block.
300 auto dw_conv_buffer = scratchpad().get<data_t>(key_dw_conv_buffer);
301 size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * (jcp.oc / jcp.oc_block);
302 auto pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
304 const int os_block = jcp.iw;
307 while (iwork < end) {
308 int n{0}, g{0}, ocbb{0}, osb{0};
309 nd_iterator_init(iwork, n, MB, g, jcp.ngroups, ocbb, ocb_work, osb,
313 const int os = osb * os_block;
314 const int oh = os / jcp.ow;
315 const int ow = os % jcp.ow;
317 const int ih = nstl::max(oh * jcp.stride_h - jcp.t_pad, 0);
318 const int iw = nstl::max(ow * jcp.stride_w - jcp.l_pad, 0);
320 int ocb = ocbb * jcp.nb_load_blocking;
322 const int load_step = step(jcp.nb_load_blocking,
323 jcp.nb_load - ocb, jcp.nb_load_blocking_max);
// First row of a work chunk (or top of the image): prime the ring buffer
// with the rows above; otherwise only the next row is computed. The
// `else` branch keyword is among the omitted lines.
325 if (iwork == start || oh == 0) {
326 bcast_step = nstl::min(1, end - iwork);
327 compute_block_1x1(pbuf, n, g, oh - 1, ow, ih, iw, os, os_block, bcast_step, ocb, load_step, bcast_step + 2);
329 bcast_step = nstl::min(1, end - iwork);
330 compute_block_1x1(pbuf, n, g, oh + 1, ow, ih, iw, os, os_block, bcast_step, ocb, load_step, bcast_step);
// Emit a depthwise output row only on stride-aligned input rows.
333 if ((oh % jcp_dw.stride_h == 0)) {
334 compute_row_dw(pbuf, n, ocb, load_step, oh);
// Stage zero-extended copies of both bias tensors when OC is padded.
341 if (pd()->wants_padded_bias()) {
342 auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
343 utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
344 utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
345 jcp.oc - jcp.oc_without_padding);
348 auto dw_padded_bias = scratchpad().get<data_t>(key_dw_conv_padded_bias);
349 utils::array_copy(dw_padded_bias, dw_bias, jcp.oc_without_padding);
350 utils::array_set(dw_padded_bias + jcp.oc_without_padding, 0.f,
351 jcp.oc - jcp.oc_without_padding);
352 dw_bias = dw_padded_bias;
357 if (pd()->wants_zero_pad_dst())
358 output_memory_primitive(0)->zero_pad();
361 /* convolution backward wtr data */
/* Backward-data pass: computes diff_src from diff_dst and weights. Roles are
 * swapped relative to forward — input-channel blocks are the "load" axis and
 * output-channel blocks are the reduction axis.
 * NOTE(review): this excerpt omits interleaved source lines (the leading
 * numbers are original file line numbers); code below is reproduced verbatim. */
363 void jit_avx2_1x1_convolution_bwd_data_t::execute_backward_data() const {
364 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
365 auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
366 auto diff_src = reinterpret_cast<data_t *>(this->memory());
368 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
369 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
370 const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
// Scratch for the reduce-to-unit-stride driver (scatter-back on bwd data).
372 auto rtus_space = scratchpad().get<data_t>(key_conv_rtus_space);
374 const auto &jcp = kernel_->jcp;
375 const int MB = pd()->MB();
377 // TODO (Roma): remove this restriction
378 assert(jcp.stride_w == 1 && jcp.stride_h == 1);
379 const int ndims = diff_dst_d.ndims();
// 3D (no-height) case: unit h-stride, zero top pad; [ndims - 3] selects width.
381 const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
382 const int stride_w = pd()->desc()->strides[ndims - 3];
383 const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
384 const int pad_l = pd()->desc()->padding[0][ndims - 3];
// bwd data: "load" dimension is IC, "reduce" dimension is OC.
386 const int nb_ic = jcp.nb_load;
387 const int nb_oc = jcp.nb_reduce;
388 const int os_block = jcp.bcast_block;
389 const int nb_oc_blocking = jcp.nb_reduce_blocking;
391 const int work_amount = MB * jcp.ngroups * jcp.nb_bcast;
// Clamp a default block step to the remaining work (tail handling).
393 auto step = [](int default_step, int remaining, int tail_step) {
394 assert(default_step <= tail_step);
395 return remaining < tail_step ? remaining : default_step;
// Per-thread body; presumably dispatched via parallel() in an omitted line.
398 auto ker = [&](const int ithr, const int nthr) {
399 auto p = jit_1x1_conv_call_s();
400 auto rp = rtus_driver_t<avx2>::call_params_t();
402 int start{0}, end{0};
403 balance211(work_amount, nthr, ithr, start, end);
// Outer loop over input-channel (load) blocks.
406 for (int icb = 0; icb < jcp.nb_load; icb += load_step) {
407 load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb,
408 jcp.nb_load_blocking_max);
410 p.load_dim = this_block_size(icb * jcp.ic_block, jcp.ic,
411 load_step * jcp.ic_block);
412 rp.icb = p.load_dim / jcp.ic_block;
// Each thread walks its share of (mb, group, bcast-block) work items.
415 for (int iwork = start; iwork < end; iwork += bcast_step) {
416 int n{0}, g{0}, osb{0};
417 nd_iterator_init(iwork, n, MB, g, jcp.ngroups, osb,
420 bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
421 jcp.nb_bcast_blocking_max);
422 bcast_step = nstl::min(bcast_step, end - iwork);
424 const int os = osb * os_block;
425 p.bcast_dim = this_block_size(os, jcp.os,
426 bcast_step * os_block);
429 const int oh = os / jcp.ow;
430 const int ow = os % jcp.ow;
431 const int ih = nstl::max(oh * stride_h - pad_t, 0);
432 const int iw = nstl::max(ow * stride_w - pad_l, 0);
435 const int _icb = g * nb_ic + icb;
// Kernel writes either straight into diff_src, or into rtus workspace
// which is scattered back after the reduction. The `rp.ws = rtus_space`
// head and the `else` keyword are among the omitted lines.
436 rp.src = diff_src + data_blk_off(diff_src_d, n, _icb, ih, iw);
437 if (pd()->rtus_.reduce_src_) {
439 + ithr * pd()->rtus_.space_per_thread_;
440 p.output_data = rp.ws;
442 p.output_data = rp.src;
// Reduce over output-channel blocks; only the first iteration clears the
// accumulator (no FLAG_REDUCE_LAST here — kernel accumulates into output).
444 for (int ocb = 0; ocb < jcp.nb_reduce;
445 ocb += jcp.nb_reduce_blocking) {
446 const int _ocb = g * nb_oc + ocb;
447 size_t diff_dst_off = data_blk_off(diff_dst_d, n, _ocb, oh,
449 p.bcast_data = &diff_dst[diff_dst_off];
451 p.load_data = &weights[pd()->with_groups()
452 ? weights_d.blk_off(g, ocb, icb)
453 : weights_d.blk_off(ocb, icb)];
455 p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0;
457 p.reduce_dim = this_block_size(ocb * jcp.oc_block, jcp.oc,
458 nb_oc_blocking * jcp.oc_block);
460 kernel_->jit_ker(&p);
// Scatter the workspace result back into strided diff_src.
463 if (pd()->rtus_.reduce_src_)
464 rtus_driver_->ker_(&rp);
472 /* convolution backward wtr weights */
/* Backward-weights primitive constructor: builds the JIT 1x1 kernel, the
 * weights and bias reducers, and the rtus driver from the primitive
 * descriptor's precomputed configurations.
 * NOTE(review): the `reducer_weights_ =` assignment head (original line 481)
 * is among the lines omitted from this excerpt; code below is verbatim. */
474 jit_avx2_1x1_convolution_bwd_weights_t::jit_avx2_1x1_convolution_bwd_weights_t(
475 const pd_t *apd, const input_vector &inputs,
476 const output_vector &outputs)
477 : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr)
478 , rtus_driver_(nullptr)
480 kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, jit_conv_conf_t(), *pd()->attr());
482 new cpu_reducer_2d_t<data_type::f32>(pd()->reducer_wei_conf_);
483 reducer_bias_ = new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_);
484 init_rtus_driver<avx2>(this);
/* Backward-weights pass: computes diff_weights (and optionally diff_bias)
 * from src and diff_dst. Threads get independent (group, oc-block, ic-block)
 * jobs from the reducer balancer and cooperate on the (mb, spatial) reduction;
 * partial results are combined by cpu_reducer_2d_t / cpu_reducer_t.
 * NOTE(review): this excerpt omits interleaved source lines AND the end of
 * this function lies beyond the visible chunk; code below is verbatim. */
487 void jit_avx2_1x1_convolution_bwd_weights_t::execute_backward_weights() const {
488 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
489 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
490 auto diff_weights = reinterpret_cast<data_t *>(this->memory(0));
491 auto diff_bias_in = reinterpret_cast<data_t *>(this->memory(1));
493 auto scratchpad = this->scratchpad();
495 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
496 const memory_desc_wrapper src_d(pd()->src_pd());
497 const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
498 const memory_desc_wrapper diff_bias_d(pd()->diff_weights_pd(1));
500 const auto &jcp = kernel_->jcp;
501 auto rtus_space = scratchpad.get<data_t>(key_conv_rtus_space);
// If OC is padded, accumulate bias into scratch and copy the valid prefix
// out at the end (see tail of this function).
503 data_t *diff_bias = pd()->wants_padded_bias()
504 ? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in;
// Grantors scope the reducers' scratch regions; their key arguments are on
// omitted lines.
506 auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
508 auto rb = this->reducer_bias_;
509 rb->init(reducer_bia_scratchpad);
511 auto reducer_wei_scratchpad = memory_tracking::grantor_t(scratchpad,
513 auto rw = this->reducer_weights_;
514 rw->init(reducer_wei_scratchpad);
516 const int ndims = diff_dst_d.ndims();
517 // TODO (Roma): remove this restriction
518 assert(jcp.stride_w == 1 && jcp.stride_h == 1);
// For bwd weights: IC blocks are the "bcast" axis, OC blocks the "load" axis.
520 const int nb_ic = jcp.nb_bcast;
521 const int nb_ic_blocking = jcp.nb_bcast_blocking;
522 const int bcast_work = div_up(nb_ic, nb_ic_blocking);
524 const int nb_oc = jcp.nb_load;
525 const int nb_oc_blocking = jcp.nb_load_blocking;
526 const int load_work = div_up(nb_oc, nb_oc_blocking);
// Reduction space: every (minibatch, spatial) point contributes.
528 const int sp_dim = jcp.reduce_dim;
529 const int mb_sp_work = jcp.mb * sp_dim;
531 const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
532 const int stride_w = pd()->desc()->strides[ndims - 3];
533 const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
534 const int pad_l = pd()->desc()->padding[0][ndims - 3];
// Clamp a default block step to the remaining work (tail handling).
536 auto step = [](int default_step, int remaining, int tail_step) {
537 assert(default_step <= tail_step);
538 return remaining < tail_step ? remaining : default_step;
// Accumulate one image's contribution over [sp_start, sp_end) into store_to
// for a (nb_oc_blocking x nb_ic_blocking) tile of weight blocks.
// first_image controls FLAG_REDUCE_FIRST (accumulator initialization).
541 auto oc_ic_sp_loop = [=](int sp_start, int sp_end, bool first_image,
542 data_t *store_to, size_t store_to_ld, const data_t *diff_dst,
543 const data_t *src, int ithr) {
544 auto p = jit_1x1_conv_call_s();
545 auto rp = rtus_driver_t<avx2>::call_params_t();
547 p.output_stride = store_to_ld * sizeof(float);
548 const int sp_step_def = jcp.nb_reduce_blocking * jcp.reduce_block;
// Tile sizes 12/18 are hard-coded sub-blocking limits for this kernel —
// presumably tuned for AVX2 register pressure; TODO confirm.
551 for (int oc_b = 0; oc_b < nb_oc_blocking; oc_b += oc_b_step) {
552 oc_b_step = step(12, nb_oc_blocking - oc_b, 18);
553 p.load_dim = oc_b_step * jcp.oc_block;
556 for (int ic_b = 0; ic_b < nb_ic_blocking; ic_b += ic_b_step) {
557 ic_b_step = step(12, nb_ic_blocking - ic_b, 18);
558 p.bcast_dim = ic_b_step * jcp.ic_block;
559 rp.icb = p.bcast_dim / jcp.ic_block;
561 p.output_data = store_to + oc_b * store_to_ld
562 + ic_b * jcp.ic_block * jcp.oc_block;
564 /* spatial reduction */
566 for (int sp = sp_start; sp < sp_end; sp += sp_step) {
567 sp_step = step(sp_step_def, sp_end - sp, 192);
568 p.reduce_dim = sp_step;
569 rp.os = p.reduce_dim;
571 p.first_last_flag = sp == sp_start && first_image
572 ? FLAG_REDUCE_FIRST : 0;
574 p.load_data = diff_dst
575 + (oc_b * jcp.reduce_dim + sp) * jcp.oc_block;
// Strided source: repack via rtus into thread-local workspace first.
// Heads of the `rp.ws = ...` / `rp.src = ...` assignments are among
// the omitted lines; the stride-indexed terms below are their tails.
577 if (pd()->rtus_.reduce_src_) {
578 const int oh = sp / jcp.ow;
579 const int ow = sp % jcp.ow;
581 const int ih = nstl::max(oh * stride_h - pad_t, 0);
582 const int iw = nstl::max(ow * stride_w - pad_l, 0);
586 + ithr * pd()->rtus_.space_per_thread_
587 + (ic_b * jcp.is + sp) * jcp.ic_block;
590 + iw * src_d.blocking_desc().strides[0][2];
593 + ih * src_d.blocking_desc().strides[0][2]
594 + iw * src_d.blocking_desc().strides[0][3];
597 rtus_driver_->ker_(&rp);
599 p.bcast_data = rp.ws;
602 + (ic_b * jcp.reduce_dim + sp) * jcp.ic_block;
604 kernel_->jit_ker(&p);
// Weights part: each thread computes its balancer-assigned jobs, then joins
// the cross-thread reduction.
610 auto ker = [&](const int ithr, const int nthr) {
611 assert(nthr == rw->balancer().nthr_);
613 const int w_njobs = rw->balancer().ithr_njobs(ithr);
614 if (w_njobs == 0) return;
616 /* setup: independent work (oc, ic) */
617 const int w_job_start = rw->balancer().ithr_job_off(ithr);
618 int g{0}, load_i{0}, bcast_i{0};
619 nd_iterator_init(w_job_start, g, jcp.ngroups, load_i, load_work,
620 bcast_i, bcast_work);
622 /* setup: reduction work (mb, sp) */
623 int mb_sp_start{0}, mb_sp_end{0};
624 balance211(mb_sp_work, rw->balancer().nthr_per_group_,
625 rw->balancer().id_in_group(ithr), mb_sp_start, mb_sp_end);
626 int img_start{0}, sp_start{0};
627 nd_iterator_init(mb_sp_start, img_start, jcp.mb, sp_start, sp_dim);
629 /* independent work */
630 for (int iwork = 0; iwork < w_njobs; ++iwork) {
631 const int oc_b = nb_oc_blocking * load_i;
632 const int ic_b = nb_ic_blocking * bcast_i;
634 const int _ic_b = g * nb_ic + ic_b;
635 const int _oc_b = g * nb_oc + oc_b;
// Single reducer thread per group writes straight into diff_weights;
// otherwise partials go to a thread-local area and are reduced below.
640 if (rw->balancer().nthr_per_group_ == 1) {
641 const size_t off = pd()->with_groups()
642 ? diff_weights_d.blk_off(g, oc_b, ic_b)
643 : diff_weights_d.blk_off(oc_b, ic_b);
644 store_to = &diff_weights[off];
645 store_to_ld = jcp.ic * jcp.oc_block;
647 const size_t off = iwork * rw->balancer().job_size_;
649 rw->get_local_ptr(ithr, reducer_wei_scratchpad) + off;
650 store_to_ld = nb_ic_blocking * jcp.ic_block * jcp.oc_block;
// Walk this thread's (image, spatial) reduction range image by image.
657 for (int mb_sp = mb_sp_start; mb_sp < mb_sp_end; mb_sp += sp_step)
659 sp_step = nstl::min(sp_dim - sp, mb_sp_end - mb_sp);
661 const bool first_image = img == img_start;
662 oc_ic_sp_loop(sp, sp + sp_step, first_image, store_to,
663 store_to_ld, &diff_dst[diff_dst_d.blk_off(img, _oc_b)],
664 &src[src_d.blk_off(img, _ic_b)], ithr);
670 nd_iterator_step(g, jcp.ngroups, load_i, load_work, bcast_i,
// Combine per-thread partial weight gradients.
673 rw->reduce(ithr, diff_weights, reducer_wei_scratchpad);
// Bias part: sum diff_dst over the spatial domain per oc block.
676 auto ker_bias = [&](int ithr, int nthr) {
677 assert(nthr == rb->balancer().nthr_);
679 const int b_job_start = rb->balancer().ithr_job_off(ithr);
680 const int b_njobs = rb->balancer().ithr_njobs(ithr);
682 if (b_njobs == 0) return;
684 /* reduction dimension */
685 int img_start{0}, img_end{0};
686 balance211(jcp.mb, rb->balancer().nthr_per_group_,
687 rb->balancer().id_in_group(ithr), img_start, img_end);
690 int g_start{0}, ocb_start{0};
691 nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, nb_oc);
693 for (int img = img_start; img < img_end; ++img) {
694 int g = g_start, ocb = ocb_start;
695 for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
696 const size_t _oc = g * nb_oc + ocb;
698 const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)];
700 rb->get_local_ptr(ithr, diff_bias, reducer_bia_scratchpad)
701 + b_job_loc * rb->balancer().job_size_;
// Zero the accumulator only on the first image of this thread's range.
703 if (img == img_start)
704 for (int o = 0; o < 8; ++o) d_bias[o] = 0.;
// 8 = AVX2 oc block width; accumulate one block of bias gradient.
// (d_dst advance per hw iteration is among the omitted lines.)
706 for (int hw = 0; hw < jcp.oh * jcp.ow; ++hw) {
708 for (int o = 0; o < 8; ++o)
709 d_bias[o] += d_dst[o];
713 nd_iterator_step(g, jcp.ngroups, ocb, nb_oc);
716 rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
// Run weights (and optionally bias) kernels across the thread pool.
719 parallel(0, [&](const int ithr, const int nthr) {
721 if (pd()->with_bias())
722 ker_bias(ithr, nthr);
725 /* TODO: put this in ker_bias */
// Copy the valid (unpadded) bias prefix from scratch to the user buffer.
726 if (pd()->wants_padded_bias()) {
727 assert(jcp.ngroups == 1);
728 for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
729 diff_bias_in[oc] = diff_bias[oc];