1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
18 #include "mkldnn_thread.hpp"
19 #include "type_helpers.hpp"
23 #include "jit_avx2_convolution.hpp"
// Pull frequently used impl-level namespaces into scope for this
// translation unit (status codes, memory formats, scratchpad key names,
// and misc utilities such as div_up / nstl helpers).
29 using namespace mkldnn::impl::status;
30 using namespace mkldnn::impl::memory_format;
31 using namespace mkldnn::impl::memory_tracking::names;
32 using namespace mkldnn::impl::utils;
// Compute the blocked-layout offset for a src/dst tensor element,
// dispatching on spatial rank: 3D uses (n, c, w), 4D uses (n, c, h, w),
// 5D uses (n, c, d, h, w).  `f` is a memory_desc_wrapper.
34 #define src_blk_off(f, n, c, d, h, w) \
35 (pd()->ndims() == 3) \
36 ? (f).blk_off(n, c, w) \
37 : (pd()->ndims() == 4) \
38 ? (f).blk_off(n, c, h, w) \
39 : (f).blk_off(n, c, d, h, w)
// Weights offset helper: prepends the group index `g` only when the
// convolution uses grouped weights (pd()->with_groups()).
41 #define wht_blk_off_(f, g, ...) \
42 pd()->with_groups() ? (f).blk_off(g, __VA_ARGS__) : (f).blk_off(__VA_ARGS__)
// Rank-dispatching weights offset, mirroring src_blk_off: 3D drops kh/kd,
// 4D drops kd, 5D keeps the full (oc, ic, kd, kh, kw) index.
43 #define wht_blk_off(f, g, oc, ic, kd, kh, kw) \
44 (pd()->ndims() == 3) \
45 ? wht_blk_off_(f, g, oc, ic, kw) \
46 : (pd()->ndims() == 4) \
47 ? wht_blk_off_(f, g, oc, ic, kh, kw) \
48 : wht_blk_off_(f, g, oc, ic, kd, kh, kw)
// Forward convolution driver.  Balances the (MB, group, oc-block,
// od, oh) iteration space over threads (balance211 + nd_iterator) and
// calls the JIT kernel once per output row and ic-block step, passing
// precomputed top/bottom padding overflows so the kernel can skip
// out-of-image kernel rows.
// NOTE(review): this excerpt has lines elided (original line numbers
// skip, e.g. 65->68, 147->149): the initialization of `icbb`, several
// closing braces, the `par_conv.bias =` left-hand side (orig ~122), and
// the final `parallel(...)` launch are not visible here.
50 void jit_avx2_convolution_fwd_t::execute_forward() const {
// Raw primitive I/O: src, weights, optional bias in; dst out.
51 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
52 auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
53 auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
54 auto dst = reinterpret_cast<data_t *>(this->memory());
56 const memory_desc_wrapper src_d(pd()->src_pd());
57 const memory_desc_wrapper dst_d(pd()->dst_pd());
58 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
59 const memory_desc_wrapper bias_d(pd()->weights_pd(1));
// jcp holds the JIT-compiled kernel's configuration (blocking, pads, ...).
61 const auto &jcp = kernel_->jcp;
62 const int MB = pd()->MB();
// Number of oc-block chunks each thread job covers.
64 int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
65 const size_t work_amount = MB * jcp.ngroups * ocb_work * jcp.od
68 auto ker = [&](const int ithr, const int nthr) {
69 size_t start{0}, end{0};
70 balance211(work_amount, nthr, ithr, start, end);
// Outer loop over input-channel blocks in steps of nb_ic_blocking.
// NOTE(review): `icbb`'s declaration/initialization is elided from this
// excerpt (orig lines 71-72) -- presumably `int icbb = 0;` before the loop.
73 while (icbb < jcp.nb_ic) {
74 int icb_step = jcp.nb_ic_blocking;
75 int icb_step_rem = jcp.nb_ic - icbb;
76 if (icb_step_rem < jcp.nb_ic_blocking_max)
77 icb_step = icb_step_rem;
79 size_t n{0}, g{0}, ocbb{0}, oh{0}, od{0};
80 nd_iterator_init(start, n, MB, g, jcp.ngroups, ocbb, ocb_work,
81 od, jcp.od, oh, jcp.oh);
82 for (size_t iwork = start; iwork < end; ++iwork) {
83 int ocb = ocbb * jcp.nb_oc_blocking;
84 int ocb_num = jcp.nb_oc_blocking;
86 for (int icb = icbb; icb < icbb + icb_step; ++icb) {
87 auto par_conv = jit_conv_call_s();
// How far the kernel window overhangs the image at top/bottom (h axis),
// accounting for stride and dilation.
89 const int ij = oh * jcp.stride_h;
90 const int i_t_overflow = nstl::max(0, jcp.t_pad - ij);
91 const int i_b_overflow = nstl::max(jcp.ih, ij
92 + (jcp.kh-1) * (jcp.dilate_h+1) - jcp.t_pad+1) - jcp.ih;
// Same for the depth (d) axis in 5D convolutions.
94 const int dj = od * jcp.stride_d;
95 const int d_t_overflow = nstl::max(0, jcp.f_pad - dj);
96 const int d_b_overflow = nstl::max(jcp.id, dj
97 + (jcp.kd-1) * (jcp.dilate_d+1) - jcp.f_pad+1) - jcp.id;
// Global (group-absolute) oc/ic block indices.
99 const size_t _oc = g * jcp.nb_oc + ocb;
100 const size_t _ic = g * jcp.nb_ic * jcp.nonblk_group_off + icb;
// First valid input row/depth for this output position (pads clipped).
102 const int ih = nstl::max(ij - jcp.t_pad
103 + div_up(i_t_overflow,
104 (jcp.dilate_h+1)) * (jcp.dilate_h + 1), 0);
106 const int id = nstl::max(dj - jcp.f_pad
107 + div_up(d_t_overflow,
108 (jcp.dilate_d+1)) * (jcp.dilate_d + 1), 0);
// `jcp.ic == 3 ? 0 : _ic`: for 3-channel (plain-layout) input the ic
// block index is folded into the kernel, so offset channel 0 instead.
110 par_conv.src = &src[src_blk_off(src_d, n,
111 jcp.ic == 3 ? 0 : _ic, id, ih, 0)];
113 par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, od, oh, 0)];
// Skip the kernel rows/planes that fell entirely into the top padding.
115 const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1));
116 const int wd = div_up(d_t_overflow, (jcp.dilate_d + 1));
117 par_conv.filt = &weights[wht_blk_off(weights_d, g, ocb,
118 jcp.ic == 3 ? 0 : icb, wd, wh, 0)];
// NOTE(review): the left-hand side of this bias assignment (orig ~122,
// presumably `par_conv.bias =`) is elided from this excerpt.
123 &bias[bias_d.blk_off(_oc * jcp.oc_block)];
124 par_conv.flags |= FLAG_IC_FIRST;
127 if (icb + 1 == jcp.nb_ic) {
128 par_conv.flags |= FLAG_IC_LAST;
131 par_conv.oc_off = _oc * jcp.oc_block * sizeof(float);
// NOTE(review): LHS elided (orig ~133); this clamps the oc-block count
// for the tail chunk.
134 nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb;
136 par_conv.kw_padding = 0;
// Effective kernel height/depth after removing padded-out rows.
137 const int kh_padding = jcp.kh
138 - div_up(i_t_overflow, (jcp.dilate_h + 1))
139 - div_up(i_b_overflow, (jcp.dilate_h + 1));
140 par_conv.kh_padding = nstl::max(0, kh_padding);
142 const int kd_padding = jcp.kd
143 - div_up(d_t_overflow, (jcp.dilate_d + 1))
144 - div_up(d_b_overflow, (jcp.dilate_d + 1));
145 par_conv.kd_padding = nstl::max(0, kd_padding);
147 kernel_->jit_ker(&par_conv);
149 nd_iterator_step(n, MB, g, jcp.ngroups, ocbb, ocb_work,
150 od, jcp.od, oh, jcp.oh);
// If the bias was padded to a full block, copy the real values into the
// scratchpad buffer and zero the padded tail before the kernel runs.
// NOTE(review): the `parallel(0, ker);` launch between the lambda and
// this epilogue (orig ~151-155) is elided from this excerpt.
156 if (pd()->wants_padded_bias()) {
157 auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
158 utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
159 utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
160 jcp.oc - jcp.oc_without_padding);
// Zero-fill dst's physical padding if the descriptor requires it.
166 if (pd()->wants_zero_pad_dst())
167 output_memory_primitive(0)->zero_pad();
// Fused convolution + depthwise-convolution forward pass.  The main
// convolution writes rows into a small per-thread ring buffer (pbuf,
// jcp_dw.kh rows deep, indexed modulo jcp_dw.kh) instead of dst; the
// depthwise kernel then consumes 3 adjacent buffered rows to produce
// each final dst row.  This keeps the intermediate result in cache.
// NOTE(review): this excerpt has elided lines (original numbering skips);
// closing braces, some assignment LHSs, and the `parallel(...)` launch
// are not visible.
170 void jit_avx2_convolution_fwd_t::execute_forward_with_dw_conv() const {
171 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
172 auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
173 auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
174 auto dst = reinterpret_cast<data_t *>(this->memory());
176 const memory_desc_wrapper src_d(pd()->src_pd());
177 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
178 const memory_desc_wrapper bias_d(pd()->weights_pd(1));
// Two kernel configs: the main conv (jcp) and the fused depthwise (jcp_dw).
180 const auto &jcp = kernel_->jcp;
181 const auto &jcp_dw = kernel_dw_->jcp;
182 const int MB = pd()->MB();
// Depthwise bias pointer; may be redirected to a padded copy below.
184 auto dw_bias = jcp_dw.conv_biases;
186 int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
187 const size_t work_amount = MB * jcp.ngroups * ocb_work * jcp.oh;
189 auto ker = [&](const int ithr, const int nthr) {
// Compute `num_rows` main-conv output rows starting at `oh` into the
// ring buffer ws_p; rows outside [0, jcp.oh) are zero-filled so the
// depthwise kernel sees zero padding.
190 auto compute_row_gen = [&](float* ws_p, int n, int g, int ocb, int ocb_num, int oh, int num_rows) {
191 for (int h = 0; h < num_rows; h++) {
192 if ((oh + h) < 0 || (oh + h) >= jcp.oh) {
193 for (int chb = ocb; chb < ocb + ocb_num; chb++) {
// Ring-buffer slot for this row: ((oh + h) + 1) mod jcp_dw.kh.
194 memset(ws_p + (((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block +
195 (chb - ocb) * jcp_dw.kh * jcp.ow * jcp.oc_block, 0, jcp.ow * jcp.oc_block * sizeof(float));
198 for (int icb = 0; icb < jcp.nb_ic; ++icb) {
199 auto par_conv = jit_conv_call_s();
// Top/bottom padding overflow for this output row (h axis only; the
// fused path is 2D).
201 const int ij = (oh + h) * jcp.stride_h;
202 const int i_t_overflow = nstl::max(0, jcp.t_pad - ij);
203 const int i_b_overflow = nstl::max(jcp.ih, ij
204 + (jcp.kh - 1) * (jcp.dilate_h + 1) - jcp.t_pad +
207 const size_t _oc = g * jcp.nb_oc + ocb;
208 const size_t _ic = g * jcp.nb_ic + icb;
210 const int ih = nstl::max(ij - jcp.t_pad
211 + div_up(i_t_overflow,
212 (jcp.dilate_h + 1)) * (jcp.dilate_h + 1), 0);
213 par_conv.src = &src[src_d.blk_off(n,
214 jcp.ic == 3 ? 0 : _ic, ih, 0)];
// Destination is the ring buffer, not the primitive's dst.
216 par_conv.dst = &ws_p[(((oh + h) + 1) % jcp_dw.kh) * jcp.ow *
219 const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1));
220 par_conv.filt = &weights[pd()->with_groups()
221 ? weights_d.blk_off(g, ocb,
222 jcp.ic == 3 ? 0 : icb, wh, 0)
223 : weights_d.blk_off(ocb,
224 jcp.ic == 3 ? 0 : icb, wh, 0)];
// NOTE(review): bias-assignment LHS elided (orig ~228).
229 &bias[bias_d.blk_off(_oc * jcp.oc_block)];
230 par_conv.flags |= FLAG_IC_FIRST;
233 if (icb + 1 == jcp.nb_ic) {
234 par_conv.flags |= FLAG_IC_LAST;
237 par_conv.oc_off = _oc * jcp.oc_block * sizeof(float);
// NOTE(review): LHS elided (orig ~239); clamps oc-block count for tails.
240 nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb;
242 par_conv.kw_padding = 0;
243 const int kh_padding = jcp.kh
244 - div_up(i_t_overflow, (jcp.dilate_h + 1))
245 - div_up(i_b_overflow, (jcp.dilate_h + 1));
246 par_conv.kh_padding = nstl::max(0, kh_padding);
247 kernel_->jit_ker(&par_conv);
// Run the depthwise kernel on the 3 buffered rows centered at dst_idx
// and write the finished row into the real dst tensor.
253 auto compute_row_dw = [&](const float* ws_p, int n, int ocb, int ocb_num,
255 for (int chb = ocb; chb < nstl::min(ocb + ocb_num, jcp.nb_oc); chb++) {
256 auto par_conv_dw = jit_conv_call_s();
// The three input rows: previous, current, next slot of the ring buffer.
258 par_conv_dw.src_row0 = &ws_p[(((dst_idx+1) - 1) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
259 (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
260 par_conv_dw.src_row1 = &ws_p[(((dst_idx+1) - 0) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
261 (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
262 par_conv_dw.src_row2 = &ws_p[(((dst_idx+1) + 1) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
263 (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
265 par_conv_dw.dst = &dst[n*jcp_dw.oc*jcp_dw.oh*jcp_dw.ow + chb*jcp_dw.ch_block*jcp_dw.oh*jcp_dw.ow +
266 dst_idx/jcp_dw.stride_h*jcp_dw.ow*jcp_dw.ch_block];
268 par_conv_dw.kh_padding = jcp_dw.kh;
269 par_conv_dw.filt = &jcp_dw.conv_weights[chb * jcp_dw.kh * jcp_dw.kw * jcp_dw.ch_block];
270 par_conv_dw.bias = &dw_bias[chb * jcp_dw.ch_block];
271 par_conv_dw.ur_w = (size_t)(jcp_dw.ow);
272 par_conv_dw.oc_work = nstl::min((chb + 1) * jcp_dw.ch_block, (int)jcp_dw.oc) - chb*jcp_dw.ch_block;
273 par_conv_dw.oc_off = chb * jcp_dw.ch_block * sizeof(float);
275 kernel_dw_->jit_ker(&par_conv_dw);
279 size_t start{0}, end{0};
280 balance211(work_amount, nthr, ithr, start, end);
// Per-thread slice of the shared intermediate-row scratchpad buffer.
282 auto dw_conv_buffer = scratchpad().get<data_t>(key_dw_conv_buffer);
283 size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * jcp.nb_oc_blocking;
284 auto pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
286 size_t n{0}, g{0}, ocbb{0}, oh{0};
287 nd_iterator_init(start, n, MB, g, jcp.ngroups, ocbb, ocb_work,
289 for (size_t iwork = start; iwork < end; ++iwork) {
290 int ocb = ocbb * jcp.nb_oc_blocking;
291 int ocb_num = jcp.nb_oc_blocking;
// Prime the ring buffer at the start of a thread's range or at row 0
// (two rows: oh-1 which may be zero padding, and oh).
293 if (iwork == start || oh == 0) {
294 compute_row_gen(pbuf, n, g, ocb, ocb_num, oh - 1, 2);
296 compute_row_gen(pbuf, n, g, ocb, ocb_num, oh, 1);
// Emit the depthwise row for oh-1 once its bottom neighbor (oh) exists.
299 if (iwork > start && ((oh - 1) % jcp_dw.stride_h == 0) && oh > 0) {
300 compute_row_dw(pbuf, n, ocb, ocb_num, oh - 1);
// At the end of the range (or last image row) flush the final row,
// generating oh+1 (zero padding when past the image) first.
303 if ((iwork == end - 1 || (int) oh == jcp.oh - 1) && ((oh) % jcp_dw.stride_h == 0)) {
304 compute_row_gen(pbuf, n, g, ocb, ocb_num, oh + 1, 1);
305 compute_row_dw(pbuf, n, ocb, ocb_num, oh);
308 nd_iterator_step(n, MB, g, jcp.ngroups, ocbb, ocb_work,
// Epilogue: pad main-conv bias and depthwise bias into scratchpad
// copies when their oc is not a multiple of the block size.
// NOTE(review): the `parallel(...)` launch before this (orig ~309-312)
// is elided from this excerpt.
313 if (pd()->wants_padded_bias()) {
314 auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
315 utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
316 utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
317 jcp.oc - jcp.oc_without_padding);
320 auto dw_padded_bias = scratchpad().get<data_t>(key_dw_conv_padded_bias);
321 utils::array_copy(dw_padded_bias, dw_bias, jcp.oc_without_padding);
322 utils::array_set(dw_padded_bias + jcp.oc_without_padding, 0.f,
323 jcp.oc - jcp.oc_without_padding);
324 dw_bias = dw_padded_bias;
329 if (pd()->wants_zero_pad_dst())
330 output_memory_primitive(0)->zero_pad();
// Backward-data convolution driver: computes diff_src from diff_dst and
// weights.  Work is balanced over (MB, group, ic-block chunk, ih block);
// each job loops over oc blocks and output depth, calling the JIT kernel
// per input row with stride-aware kernel-row windowing.
// NOTE(review): this excerpt has elided lines (original numbering skips);
// the ih_block_size refinement (orig ~350), several closing braces, some
// assignment continuations, and the `parallel(...)` launch are not visible.
333 void jit_avx2_convolution_bwd_data_t::execute_backward_data() const {
334 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
335 auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
336 auto diff_src = reinterpret_cast<data_t *>(this->memory());
338 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
339 const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
340 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
342 const auto &jcp = kernel_->jcp;
343 const int MB = pd()->MB();
345 int icb_work = jcp.nb_ic / jcp.nb_ic_blocking;
// Start with one ih block covering the whole height; if that yields too
// little parallelism, split further (refinement line elided, orig ~350).
346 int ih_block_size = jcp.ih;
347 int num_ih_blocks = utils::div_up(jcp.ih, ih_block_size);
348 size_t work_amount = MB * jcp.ngroups * icb_work * num_ih_blocks;
349 if (work_amount < (size_t)2 * mkldnn_get_max_threads()) {
351 num_ih_blocks = utils::div_up(jcp.ih, ih_block_size);
352 work_amount *= num_ih_blocks;
355 auto ker = [&](const int ithr, const int nthr) {
356 size_t start{0}, end{0};
357 balance211(work_amount, nthr, ithr, start, end);
359 size_t n{0}, g{0}, icbb{0}, ihb{0};
360 nd_iterator_init(start, n, MB, g, jcp.ngroups, icbb, icb_work,
363 for (size_t iwork = start; iwork < end; ++iwork) {
364 for (int oc = 0; oc < jcp.nb_oc; oc += jcp.nb_oc_blocking)
365 for (int id = 0; id < jcp.id; ++id) {
366 auto par_conv = jit_conv_call_s();
// Depth-axis overflow of the (flipped) kernel for this input depth.
368 const int idp = jcp.id + 2 * jcp.f_pad;
369 const int d_t_overflow = nstl::max(0,
370 jcp.kd - 1 - id - jcp.f_pad);
371 const int back_pad = idp - jcp.id - jcp.f_pad;
372 const int d_b_overflow = nstl::max(0,
373 jcp.kd - 1 - (jcp.id - 1 - id) - back_pad);
374 const int od = id + jcp.f_pad - d_b_overflow;
376 int ih_start = ihb * ih_block_size;
377 int ih_end = nstl::min(jcp.ih, ih_start + ih_block_size);
378 for (int ih = ih_start; ih < ih_end; ++ih) {
// Number of kernel rows falling outside the output at top/bottom,
// in units of stride_h steps.
380 const int i_t_overflow = nstl::max(0, (jcp.kh - 1
381 - ih - jcp.t_pad) / jcp.stride_h);
382 const int i_b_overflow = nstl::max(0, (jcp.kh - jcp.ih
383 + ih - jcp.b_pad) / jcp.stride_h);
// With stride > 1 only kernel rows congruent to (ih + t_pad) mod
// stride_h contribute; these give the usable kh range.
384 int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1
385 + jcp.b_pad - ih) % jcp.stride_h);
386 int overflow_kh_lo = (ih + jcp.t_pad) % jcp.stride_h;
388 par_conv.kd_padding = jcp.kd - d_t_overflow - d_b_overflow;
389 par_conv.kh_padding = (overflow_kh_hi - overflow_kh_lo)
390 / jcp.stride_h + 1 - i_t_overflow - i_b_overflow;
391 par_conv.kw_padding = 0;
// First contributing kernel row and the matching output row.
393 const int k_lo = overflow_kh_lo
394 + i_b_overflow * jcp.stride_h;
395 const int oh = (ih + jcp.t_pad - k_lo) / jcp.stride_h;
// Note: for bwd-data, par_conv.src is diff_src (the output of this
// pass) and par_conv.dst is diff_dst (its input).
397 par_conv.src = &diff_src[src_blk_off(diff_src_d, n,
398 /*jcp.ic == 3 ? 0 :*/
399 g * jcp.nb_ic + jcp.nb_ic_blocking * icbb, id, ih, 0)];
400 par_conv.dst = &diff_dst[src_blk_off(diff_dst_d,
401 n, g * jcp.nb_oc + oc, od, oh, 0)];
402 par_conv.filt = &weights[wht_blk_off(weights_d, g, oc,
403 jcp.ic == 3 ? 0 : jcp.nb_ic_blocking * icbb,
404 d_b_overflow, k_lo, 0)];
406 par_conv.src_prf = nullptr;
407 par_conv.dst_prf = nullptr;
408 par_conv.filt_prf = nullptr;
// channel / ch_blocks tell the kernel which oc block and how many
// (clamped for the tail; continuation line elided, orig ~411).
409 par_conv.channel = oc;
410 par_conv.ch_blocks = nstl::min(jcp.nb_oc - oc,
413 kernel_->jit_ker(&par_conv);
416 nd_iterator_step(n, MB, g, jcp.ngroups, icbb, icb_work, ihb,
// Backward-weights convolution driver: computes diff_weights (and
// optionally diff_bias) from src and diff_dst.  Each thread accumulates
// into a private scratchpad buffer managed by the reducer_weights_ /
// reducer_bias_ helpers, then the reducers combine the partial results.
// NOTE(review): this excerpt has elided lines (original numbering skips);
// grantor_t constructor arguments, some closing braces, zeroing of
// d_bias (orig ~545), and the tail of the function are not visible.
424 void jit_avx2_convolution_bwd_weights_t::execute_backward_weights() const {
425 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
426 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
427 auto diff_weights = reinterpret_cast<data_t *>(this->memory(0));
428 auto diff_bias_in = reinterpret_cast<data_t *>(this->memory(1));
430 auto scratchpad = this->scratchpad();
// Accumulate bias into a padded scratchpad buffer when oc is not a
// whole number of blocks; copied back to diff_bias_in at the end.
432 data_t *diff_bias = pd()->wants_padded_bias()
433 ? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in;
435 const memory_desc_wrapper src_d(pd()->src_pd(0));
436 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
437 const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
439 const auto &jcp = kernel_->jcp;
// Per-reducer scratchpad grantors (ctor argument lines elided here).
441 auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
443 auto rb = this->reducer_bias_;
444 rb->init(reducer_bia_scratchpad);
446 auto reducer_wei_scratchpad = memory_tracking::grantor_t(scratchpad,
448 auto rw = this->reducer_weights_;
449 rw->init(reducer_wei_scratchpad);
// Per-thread weights-gradient accumulation.
451 auto ker = [&](int ithr, int nthr) {
452 assert(nthr == rw->balancer().nthr_);
454 const int w_job_start = rw->balancer().ithr_job_off(ithr);
455 const int w_njobs = rw->balancer().ithr_njobs(ithr);
457 if (w_njobs == 0) return;
459 /* reduction dimension */
460 int img_od_start{0}, img_od_end{0}, img{0}, od_s{0};
461 balance211(jcp.mb * jcp.od, rw->balancer().nthr_per_group_,
462 rw->balancer().id_in_group(ithr), img_od_start, img_od_end);
464 int img_start = img_od_start, img_end = img_od_end;
465 nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od);
466 const int img_first = img;
// Job dimension: (group, oc block, ic block) tuples.
469 int g_start{0}, ocb_start{0}, icb_start{0};
470 nd_iterator_init(w_job_start, g_start, jcp.ngroups, ocb_start,
471 jcp.nb_oc, icb_start, jcp.nb_ic);
473 while (img_start < img_end) {
474 int g = g_start, ocb = ocb_start, icb = icb_start;
476 const int work_rem = img_end - img_start;
477 const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem;
478 const int id_s = od_s * jcp.stride_d;
479 const int idp = jcp.id + jcp.f_pad + jcp.back_pad;
// Skip depth positions whose kernel window lies entirely in back padding.
481 if (id_s < idp - jcp.back_pad - jcp.kd + 1)
482 for (int w_job_loc = 0; w_job_loc < w_njobs; ++w_job_loc) {
483 const size_t _oc = g * jcp.nb_oc + ocb;
484 const size_t _ic = g * jcp.nb_ic + icb;
486 /* TODO: put dw <-- 0 in kernel */
// Zero this job's private accumulation buffer on the first image.
487 if (img == img_first)
488 array_set(rw->get_local_ptr(ithr, diff_weights,
489 reducer_wei_scratchpad) +
490 w_job_loc * rw->balancer().job_size_, 0,
491 rw->balancer().job_size_);
493 for (int od = od_s; od < od_e; ++od) {
494 const int id = od * jcp.stride_d;
495 if (id >= jcp.id - jcp.back_pad - jcp.kd + 1) break;
497 auto par_conv = jit_conv_call_s();
498 par_conv.src = &src[src_blk_off(src_d, img, _ic, id, 0, 0)];
// NOTE(review): `par_conv.dst =` LHS elided (orig ~499).
500 &diff_dst[src_blk_off(diff_dst_d, img, _oc, od, 0, 0)];
501 par_conv.filt = rw->get_local_ptr(ithr, diff_weights,
502 reducer_wei_scratchpad) +
503 w_job_loc * rw->balancer().job_size_;
505 kernel_->jit_ker(&par_conv);
507 nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc, icb,
510 nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od);
// Combine per-thread partial weight gradients.
512 rw->reduce(ithr, diff_weights, reducer_wei_scratchpad);
// Per-thread bias-gradient accumulation (plain C++ loops, no JIT).
515 auto ker_bias = [&](int ithr, int nthr) {
516 assert(nthr == rb->balancer().nthr_);
518 const int b_job_start = rb->balancer().ithr_job_off(ithr);
519 const int b_njobs = rb->balancer().ithr_njobs(ithr);
521 if (b_njobs == 0) return;
523 /* reduction dimension */
524 int img_start{0}, img_end{0};
525 balance211(jcp.mb, rb->balancer().nthr_per_group_,
526 rb->balancer().id_in_group(ithr), img_start, img_end);
529 int g_start{0}, ocb_start{0};
530 nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start,
533 for (int img = img_start; img < img_end; ++img) {
534 int g = g_start, ocb = ocb_start;
535 for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
536 const size_t _oc = g * jcp.nb_oc + ocb;
538 const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)];
539 data_t *d_bias = rb->get_local_ptr(ithr, diff_bias,
540 reducer_bia_scratchpad) +
541 b_job_loc * rb->balancer().job_size_;
// Zero the local buffer on the first image (the zeroing statement at
// orig ~545 is elided from this excerpt).
543 if (img == img_start)
544 for (int o = 0; o < 8; ++o)
// Sum diff_dst over all spatial positions; the 8-wide inner loop
// matches the avx2 oc block width.
547 for (int dhw = 0; dhw < jcp.od * jcp.oh * jcp.ow; ++dhw) {
549 for (int o = 0; o < 8; ++o)
550 d_bias[o] += d_dst[o];
554 nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc);
557 rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
560 parallel(0, [&](const int ithr, const int nthr) {
562 if (pd()->with_bias())
563 ker_bias(ithr, nthr);
566 /* TODO: put this in ker_bias */
// Copy the unpadded prefix of the accumulated bias back to the user's
// diff_bias buffer (only valid when there is a single group).
567 if (pd()->wants_padded_bias()) {
568 assert(jcp.ngroups == 1);
569 for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
570 diff_bias_in[oc] = diff_bias[oc];
578 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s