/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "c_types_map.hpp"
#include "mkldnn_thread.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
#include <cstring>

#include "jit_avx2_convolution.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::status;
using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;

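/* Offset helpers that dispatch on the number of dimensions of the problem
 * (3 = 1D, 4 = 2D, 5 = 3D convolution); wht_blk_off additionally prepends
 * the group index when the convolution is grouped. */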
#define src_blk_off(f, n, c, d, h, w) \
    (pd()->ndims() == 3) \
    ? (f).blk_off(n, c, w) \
    : (pd()->ndims() == 4) \
    ? (f).blk_off(n, c, h, w) \
    : (f).blk_off(n, c, d, h, w)

#define wht_blk_off_(f, g, ...) \
    pd()->with_groups() ? (f).blk_off(g, __VA_ARGS__) : (f).blk_off(__VA_ARGS__)
#define wht_blk_off(f, g, oc, ic, kd, kh, kw) \
    (pd()->ndims() == 3) \
    ? wht_blk_off_(f, g, oc, ic, kw) \
    : (pd()->ndims() == 4) \
    ? wht_blk_off_(f, g, oc, ic, kh, kw) \
    : wht_blk_off_(f, g, oc, ic, kd, kh, kw)

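/* Forward convolution. The work (minibatch x groups x output-channel blocks
 * x OD x OH rows) is split across threads with balance211(); input-channel
 * blocks are swept in an outer loop, with FLAG_IC_FIRST / FLAG_IC_LAST
 * telling the JIT kernel whether it is processing the first or the last
 * input-channel block (bias and post-processing are handled on those passes). */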
void jit_avx2_convolution_fwd_t::execute_forward() const {
    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
    auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
    auto dst = reinterpret_cast<data_t *>(this->memory());

    const memory_desc_wrapper src_d(pd()->src_pd());
    const memory_desc_wrapper dst_d(pd()->dst_pd());
    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
    const memory_desc_wrapper bias_d(pd()->weights_pd(1));

    const auto &jcp = kernel_->jcp;
    const int MB = pd()->MB();

    int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
    const size_t work_amount = MB * jcp.ngroups * ocb_work * jcp.od
        * jcp.oh;

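    /* Per-thread body: the [start, end) slice of the work is walked once per
     * input-channel block step, accumulating partial sums over the input
     * channels into dst between passes. */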
    auto ker = [&](const int ithr, const int nthr) {
        size_t start{0}, end{0};
        balance211(work_amount, nthr, ithr, start, end);

        int icbb = 0;
        while (icbb < jcp.nb_ic) {
            int icb_step = jcp.nb_ic_blocking;
            int icb_step_rem = jcp.nb_ic - icbb;
            if (icb_step_rem < jcp.nb_ic_blocking_max)
                icb_step = icb_step_rem;

            size_t n{0}, g{0}, ocbb{0}, oh{0}, od{0};
            nd_iterator_init(start, n, MB, g, jcp.ngroups, ocbb, ocb_work,
                             od, jcp.od, oh, jcp.oh);
            for (size_t iwork = start; iwork < end; ++iwork) {
                int ocb = ocbb * jcp.nb_oc_blocking;
                int ocb_num = jcp.nb_oc_blocking;

                for (int icb = icbb; icb < icbb + icb_step; ++icb) {
                    auto par_conv = jit_conv_call_s();

                    const int ij = oh * jcp.stride_h;
                    const int i_t_overflow = nstl::max(0, jcp.t_pad - ij);
                    const int i_b_overflow = nstl::max(jcp.ih, ij
                        + (jcp.kh-1) * (jcp.dilate_h+1) - jcp.t_pad+1) - jcp.ih;

                    const int dj = od * jcp.stride_d;
                    const int d_t_overflow = nstl::max(0, jcp.f_pad - dj);
                    const int d_b_overflow = nstl::max(jcp.id, dj
                        + (jcp.kd-1) * (jcp.dilate_d+1) - jcp.f_pad+1) - jcp.id;

                    const size_t _oc = g * jcp.nb_oc + ocb;
                    const size_t _ic = g * jcp.nb_ic * jcp.nonblk_group_off + icb;

                    const int ih = nstl::max(ij - jcp.t_pad
                        + div_up(i_t_overflow,
                                 (jcp.dilate_h+1)) * (jcp.dilate_h + 1), 0);

                    const int id = nstl::max(dj - jcp.f_pad
                        + div_up(d_t_overflow,
                                 (jcp.dilate_d+1)) * (jcp.dilate_d + 1), 0);

                    par_conv.src = &src[src_blk_off(src_d, n,
                        jcp.ic == 3 ? 0 : _ic, id, ih, 0)];

                    par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, od, oh, 0)];

                    const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1));
                    const int wd = div_up(d_t_overflow, (jcp.dilate_d + 1));
                    par_conv.filt = &weights[wht_blk_off(weights_d, g, ocb,
                            jcp.ic == 3 ? 0 : icb, wd, wh, 0)];

                    if (icb == 0) {
                        if (bias)
                            par_conv.bias =
                                    &bias[bias_d.blk_off(_oc * jcp.oc_block)];
                        par_conv.flags |= FLAG_IC_FIRST;
                    }

                    if (icb + 1 == jcp.nb_ic) {
                        par_conv.flags |= FLAG_IC_LAST;
                    }

                    par_conv.oc_off = _oc * jcp.oc_block * sizeof(float);

                    par_conv.oc_blocks =
                            nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb;

                    par_conv.kw_padding = 0;
                    const int kh_padding = jcp.kh
                        - div_up(i_t_overflow, (jcp.dilate_h + 1))
                        - div_up(i_b_overflow, (jcp.dilate_h + 1));
                    par_conv.kh_padding = nstl::max(0, kh_padding);

                    const int kd_padding = jcp.kd
                        - div_up(d_t_overflow, (jcp.dilate_d + 1))
                        - div_up(d_b_overflow, (jcp.dilate_d + 1));
                    par_conv.kd_padding = nstl::max(0, kd_padding);

                    kernel_->jit_ker(&par_conv);
                }
                nd_iterator_step(n, MB, g, jcp.ngroups, ocbb, ocb_work,
                                od, jcp.od, oh, jcp.oh);
            }
            icbb += icb_step;
        }
    };

    if (pd()->wants_padded_bias()) {
        auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
                jcp.oc - jcp.oc_without_padding);
        bias = padded_bias;
    }

    parallel(0, ker);

    if (pd()->wants_zero_pad_dst())
        output_memory_primitive(0)->zero_pad();
}

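/* Forward convolution fused with a depthwise convolution post-op. Rows of the
 * intermediate (base convolution) output are produced into a per-thread ring
 * buffer holding jcp_dw.kh rows per output-channel block; the depthwise
 * kernel then consumes three consecutive rows to emit one final dst row. */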
void jit_avx2_convolution_fwd_t::execute_forward_with_dw_conv() const {
    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
    auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
    auto dst = reinterpret_cast<data_t *>(this->memory());

    const memory_desc_wrapper src_d(pd()->src_pd());
    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
    const memory_desc_wrapper bias_d(pd()->weights_pd(1));

    const auto &jcp = kernel_->jcp;
    const auto &jcp_dw = kernel_dw_->jcp;
    const int MB = pd()->MB();

    auto dw_bias = jcp_dw.conv_biases;

    int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
    const size_t work_amount = MB * jcp.ngroups * ocb_work * jcp.oh;

    auto ker = [&](const int ithr, const int nthr) {
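        /* Compute num_rows rows of the intermediate convolution output,
         * starting at row oh, into the ring buffer ws_p; rows that fall
         * outside [0, jcp.oh) are zero-filled padding for the depthwise
         * kernel. */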
        auto compute_row_gen = [&](float* ws_p, int n, int g, int ocb, int ocb_num, int oh, int num_rows) {
            for (int h = 0; h < num_rows; h++) {
                if ((oh + h) < 0 || (oh + h) >= jcp.oh) {
                    for (int chb = ocb; chb < ocb + ocb_num; chb++) {
                        memset(ws_p + (((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block +
                               (chb - ocb) * jcp_dw.kh * jcp.ow * jcp.oc_block, 0, jcp.ow * jcp.oc_block * sizeof(float));
                    }
                } else {
                    for (int icb = 0; icb < jcp.nb_ic; ++icb) {
                        auto par_conv = jit_conv_call_s();

                        const int ij = (oh + h) * jcp.stride_h;
                        const int i_t_overflow = nstl::max(0, jcp.t_pad - ij);
                        const int i_b_overflow = nstl::max(jcp.ih, ij
                                                                   + (jcp.kh - 1) * (jcp.dilate_h + 1) - jcp.t_pad +
                                                                   1) - jcp.ih;

                        const size_t _oc = g * jcp.nb_oc + ocb;
                        const size_t _ic = g * jcp.nb_ic + icb;

                        const int ih = nstl::max(ij - jcp.t_pad
                                                 + div_up(i_t_overflow,
                                                          (jcp.dilate_h + 1)) * (jcp.dilate_h + 1), 0);
                        par_conv.src = &src[src_d.blk_off(n,
                                                          jcp.ic == 3 ? 0 : _ic, ih, 0)];

                        par_conv.dst = &ws_p[(((oh + h) + 1) % jcp_dw.kh) * jcp.ow *
                                             jcp.oc_block];

                        const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1));
                        par_conv.filt = &weights[pd()->with_groups()
                                                 ? weights_d.blk_off(g, ocb,
                                                                     jcp.ic == 3 ? 0 : icb, wh, 0)
                                                 : weights_d.blk_off(ocb,
                                                                     jcp.ic == 3 ? 0 : icb, wh, 0)];

                        if (icb == 0) {
                            if (bias)
                                par_conv.bias =
                                        &bias[bias_d.blk_off(_oc * jcp.oc_block)];
                            par_conv.flags |= FLAG_IC_FIRST;
                        }

                        if (icb + 1 == jcp.nb_ic) {
                            par_conv.flags |= FLAG_IC_LAST;
                        }

                        par_conv.oc_off = _oc * jcp.oc_block * sizeof(float);

                        par_conv.oc_blocks =
                                nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb;

                        par_conv.kw_padding = 0;
                        const int kh_padding = jcp.kh
                                               - div_up(i_t_overflow, (jcp.dilate_h + 1))
                                               - div_up(i_b_overflow, (jcp.dilate_h + 1));
                        par_conv.kh_padding = nstl::max(0, kh_padding);
                        kernel_->jit_ker(&par_conv);
                    }
                }
            }
        };

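        /* Run the depthwise kernel on three consecutive ring-buffer rows
         * (src_row0/1/2) and write the final output row for every
         * output-channel block in [ocb, ocb + ocb_num). */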
        auto compute_row_dw = [&](const float* ws_p, int n, int ocb, int ocb_num,
                                  int dst_idx) {
            for (int chb = ocb; chb < nstl::min(ocb + ocb_num, jcp.nb_oc); chb++) {
                auto par_conv_dw = jit_conv_call_s();

                par_conv_dw.src_row0 = &ws_p[(((dst_idx+1) - 1) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
                                             (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
                par_conv_dw.src_row1 = &ws_p[(((dst_idx+1) - 0) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
                                             (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
                par_conv_dw.src_row2 = &ws_p[(((dst_idx+1) + 1) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
                                             (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];

                par_conv_dw.dst = &dst[n*jcp_dw.oc*jcp_dw.oh*jcp_dw.ow + chb*jcp_dw.ch_block*jcp_dw.oh*jcp_dw.ow +
                                       dst_idx/jcp_dw.stride_h*jcp_dw.ow*jcp_dw.ch_block];

                par_conv_dw.kh_padding = jcp_dw.kh;
                par_conv_dw.filt = &jcp_dw.conv_weights[chb * jcp_dw.kh * jcp_dw.kw * jcp_dw.ch_block];
                par_conv_dw.bias = &dw_bias[chb * jcp_dw.ch_block];
                par_conv_dw.ur_w = (size_t)(jcp_dw.ow);
                par_conv_dw.oc_work = nstl::min((chb + 1) * jcp_dw.ch_block, (int)jcp_dw.oc) - chb*jcp_dw.ch_block;
                par_conv_dw.oc_off = chb * jcp_dw.ch_block * sizeof(float);

                kernel_dw_->jit_ker(&par_conv_dw);
            }
        };

        size_t start{0}, end{0};
        balance211(work_amount, nthr, ithr, start, end);

        auto dw_conv_buffer = scratchpad().get<data_t>(key_dw_conv_buffer);
        size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * jcp.nb_oc_blocking;
        auto pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;

        size_t n{0}, g{0}, ocbb{0}, oh{0};
        nd_iterator_init(start, n, MB, g, jcp.ngroups, ocbb, ocb_work,
                         oh, jcp.oh);
        for (size_t iwork = start; iwork < end; ++iwork) {
            int ocb = ocbb * jcp.nb_oc_blocking;
            int ocb_num = jcp.nb_oc_blocking;

            if (iwork == start || oh == 0) {
                compute_row_gen(pbuf, n, g, ocb, ocb_num, oh - 1, 2);
            } else {
                compute_row_gen(pbuf, n, g, ocb, ocb_num, oh, 1);
            }

            if (iwork > start && ((oh - 1) % jcp_dw.stride_h == 0) && oh > 0) {
                compute_row_dw(pbuf, n, ocb, ocb_num, oh - 1);
            }

            if ((iwork == end - 1 || (int) oh == jcp.oh - 1) && ((oh) % jcp_dw.stride_h == 0)) {
                compute_row_gen(pbuf, n, g, ocb, ocb_num, oh + 1, 1);
                compute_row_dw(pbuf, n, ocb, ocb_num, oh);
            }

            nd_iterator_step(n, MB, g, jcp.ngroups, ocbb, ocb_work,
                             oh, jcp.oh);
        }
    };

    if (pd()->wants_padded_bias()) {
        auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
                jcp.oc - jcp.oc_without_padding);
        bias = padded_bias;

        auto dw_padded_bias = scratchpad().get<data_t>(key_dw_conv_padded_bias);
        utils::array_copy(dw_padded_bias, dw_bias, jcp.oc_without_padding);
        utils::array_set(dw_padded_bias + jcp.oc_without_padding, 0.f,
                         jcp.oc - jcp.oc_without_padding);
        dw_bias = dw_padded_bias;
    }

    parallel(0, ker);

    if (pd()->wants_zero_pad_dst())
        output_memory_primitive(0)->zero_pad();
}

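/* Backward by data: diff_src is computed from diff_dst and weights. The work
 * is split over minibatch x groups x input-channel blocks x IH blocks; for
 * each input row the valid kernel-tap range is derived from padding and
 * stride, and the JIT kernel accumulates over all output-channel blocks. */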
void jit_avx2_convolution_bwd_data_t::execute_backward_data() const {
    auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
    auto diff_src = reinterpret_cast<data_t *>(this->memory());

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
    const memory_desc_wrapper weights_d(pd()->weights_pd(0));

    const auto &jcp = kernel_->jcp;
    const int MB = pd()->MB();

    int icb_work = jcp.nb_ic / jcp.nb_ic_blocking;
    int ih_block_size = jcp.ih;
    int num_ih_blocks = utils::div_up(jcp.ih, ih_block_size);
    size_t work_amount = MB * jcp.ngroups * icb_work * num_ih_blocks;
    if (work_amount < (size_t)2 * mkldnn_get_max_threads()) {
        ih_block_size = 1;
        num_ih_blocks = utils::div_up(jcp.ih, ih_block_size);
        work_amount *= num_ih_blocks;
    }

    auto ker = [&](const int ithr, const int nthr) {
        size_t start{0}, end{0};
        balance211(work_amount, nthr, ithr, start, end);

        size_t n{0}, g{0}, icbb{0}, ihb{0};
        nd_iterator_init(start, n, MB, g, jcp.ngroups, icbb, icb_work,
                         ihb, num_ih_blocks);

        for (size_t iwork = start; iwork < end; ++iwork) {
            for (int oc = 0; oc < jcp.nb_oc; oc += jcp.nb_oc_blocking)
            for (int id = 0; id < jcp.id; ++id) {
                auto par_conv = jit_conv_call_s();

                const int idp = jcp.id + 2 * jcp.f_pad;
                const int d_t_overflow = nstl::max(0,
                        jcp.kd - 1 - id - jcp.f_pad);
                const int back_pad = idp - jcp.id - jcp.f_pad;
                const int d_b_overflow = nstl::max(0,
                        jcp.kd - 1 - (jcp.id - 1 - id) - back_pad);
                const int od = id + jcp.f_pad - d_b_overflow;

                int ih_start = ihb * ih_block_size;
                int ih_end = nstl::min(jcp.ih, ih_start + ih_block_size);
                for (int ih = ih_start; ih < ih_end; ++ih) {

                    const int i_t_overflow = nstl::max(0, (jcp.kh - 1
                                        - ih - jcp.t_pad) / jcp.stride_h);
                    const int i_b_overflow = nstl::max(0, (jcp.kh - jcp.ih
                                        + ih - jcp.b_pad) / jcp.stride_h);
                    int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1
                                + jcp.b_pad - ih) % jcp.stride_h);
                    int overflow_kh_lo = (ih + jcp.t_pad) % jcp.stride_h;

                    par_conv.kd_padding = jcp.kd - d_t_overflow - d_b_overflow;
                    par_conv.kh_padding = (overflow_kh_hi - overflow_kh_lo)
                              / jcp.stride_h + 1 - i_t_overflow - i_b_overflow;
                    par_conv.kw_padding = 0;

                    const int k_lo = overflow_kh_lo
                                   + i_b_overflow * jcp.stride_h;
                    const int oh = (ih + jcp.t_pad - k_lo) / jcp.stride_h;

                    par_conv.src = &diff_src[src_blk_off(diff_src_d, n,
                        /*jcp.ic == 3 ? 0 :*/
                        g * jcp.nb_ic + jcp.nb_ic_blocking * icbb, id, ih, 0)];
                    par_conv.dst = &diff_dst[src_blk_off(diff_dst_d,
                            n, g * jcp.nb_oc + oc, od, oh, 0)];
                    par_conv.filt = &weights[wht_blk_off(weights_d, g, oc,
                                jcp.ic == 3 ? 0 : jcp.nb_ic_blocking * icbb,
                                d_b_overflow, k_lo, 0)];

                    par_conv.src_prf = nullptr;
                    par_conv.dst_prf = nullptr;
                    par_conv.filt_prf = nullptr;
                    par_conv.channel = oc;
                    par_conv.ch_blocks = nstl::min(jcp.nb_oc - oc,
                                       jcp.nb_oc_blocking);

                    kernel_->jit_ker(&par_conv);
                }
            }
            nd_iterator_step(n, MB, g, jcp.ngroups, icbb, icb_work, ihb,
                             num_ih_blocks);
        }
    };

    parallel(0, ker);
}

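/* Backward by weights: every thread accumulates partial diff_weights (and,
 * when a bias is present, partial diff_bias) for its share of the reduction
 * dimension into per-thread reducer scratchpad buffers, which are combined
 * afterwards by rw->reduce() / rb->reduce(). */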
void jit_avx2_convolution_bwd_weights_t::execute_backward_weights() const {
    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
    auto diff_weights = reinterpret_cast<data_t *>(this->memory(0));
    auto diff_bias_in = reinterpret_cast<data_t *>(this->memory(1));

    auto scratchpad = this->scratchpad();

    data_t *diff_bias = pd()->wants_padded_bias()
        ? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in;

    const memory_desc_wrapper src_d(pd()->src_pd(0));
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));

    const auto &jcp = kernel_->jcp;

    auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
            prefix_reducer_bia);
    auto rb = this->reducer_bias_;
    rb->init(reducer_bia_scratchpad);

    auto reducer_wei_scratchpad = memory_tracking::grantor_t(scratchpad,
            prefix_reducer_wei);
    auto rw = this->reducer_weights_;
    rw->init(reducer_wei_scratchpad);

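    /* Weights reduction: jobs are (group, oc block, ic block) triples; the
     * reduction dimension (minibatch x OD) is split among the threads of a
     * reduction group, each writing into its local scratchpad copy. */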
    auto ker = [&](int ithr, int nthr) {
        assert(nthr == rw->balancer().nthr_);

        const int w_job_start = rw->balancer().ithr_job_off(ithr);
        const int w_njobs = rw->balancer().ithr_njobs(ithr);

        if (w_njobs == 0) return;

        /* reduction dimension */
        int img_od_start{0}, img_od_end{0}, img{0}, od_s{0};
        balance211(jcp.mb * jcp.od, rw->balancer().nthr_per_group_,
                rw->balancer().id_in_group(ithr), img_od_start, img_od_end);

        int img_start = img_od_start, img_end = img_od_end;
        nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od);
        const int img_first = img;

        /* jobs */
        int g_start{0}, ocb_start{0}, icb_start{0};
        nd_iterator_init(w_job_start, g_start, jcp.ngroups, ocb_start,
                jcp.nb_oc, icb_start, jcp.nb_ic);

        while (img_start < img_end) {
            int g = g_start, ocb = ocb_start, icb = icb_start;

            const int work_rem = img_end - img_start;
            const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem;
            const int id_s = od_s * jcp.stride_d;
            const int idp = jcp.id + jcp.f_pad + jcp.back_pad;

            if (id_s < idp - jcp.back_pad - jcp.kd + 1)
            for (int w_job_loc = 0; w_job_loc < w_njobs; ++w_job_loc) {
                const size_t _oc = g * jcp.nb_oc + ocb;
                const size_t _ic = g * jcp.nb_ic + icb;

                /* TODO: put dw <-- 0 in kernel */
                if (img == img_first)
                    array_set(rw->get_local_ptr(ithr, diff_weights,
                                reducer_wei_scratchpad) +
                            w_job_loc * rw->balancer().job_size_, 0,
                            rw->balancer().job_size_);

                for (int od = od_s; od < od_e; ++od) {
                    const int id = od * jcp.stride_d;
                    if (id >= jcp.id - jcp.back_pad - jcp.kd + 1) break;

                    auto par_conv = jit_conv_call_s();
                    par_conv.src = &src[src_blk_off(src_d, img, _ic, id, 0, 0)];
                    par_conv.dst =
                        &diff_dst[src_blk_off(diff_dst_d, img, _oc, od, 0, 0)];
                    par_conv.filt = rw->get_local_ptr(ithr, diff_weights,
                            reducer_wei_scratchpad) +
                        w_job_loc * rw->balancer().job_size_;

                    kernel_->jit_ker(&par_conv);
                }
                nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc, icb,
                        jcp.nb_ic);
            }
            nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od);
        }
        rw->reduce(ithr, diff_weights, reducer_wei_scratchpad);
    };

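    /* Bias reduction: for each (group, oc block) job, diff_dst is summed over
     * the spatial dimensions and this thread's share of the minibatch; 8 is
     * the AVX2 output-channel block width (one ymm register of floats). */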
    auto ker_bias = [&](int ithr, int nthr) {
        assert(nthr == rb->balancer().nthr_);

        const int b_job_start = rb->balancer().ithr_job_off(ithr);
        const int b_njobs = rb->balancer().ithr_njobs(ithr);

        if (b_njobs == 0) return;

        /* reduction dimension */
        int img_start{0}, img_end{0};
        balance211(jcp.mb, rb->balancer().nthr_per_group_,
                rb->balancer().id_in_group(ithr), img_start, img_end);

        /* jobs */
        int g_start{0}, ocb_start{0};
        nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start,
                jcp.nb_oc);

        for (int img = img_start; img < img_end; ++img) {
            int g = g_start, ocb = ocb_start;
            for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
                const size_t _oc = g * jcp.nb_oc + ocb;

                const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)];
                data_t *d_bias = rb->get_local_ptr(ithr, diff_bias,
                        reducer_bia_scratchpad) +
                    b_job_loc * rb->balancer().job_size_;

                if (img == img_start)
                    for (int o = 0; o < 8; ++o)
                        d_bias[o] = 0.;

                for (int dhw = 0; dhw < jcp.od * jcp.oh * jcp.ow; ++dhw) {
                    PRAGMA_OMP_SIMD()
                    for (int o = 0; o < 8; ++o)
                        d_bias[o] += d_dst[o];
                    d_dst += 8;
                }

                nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc);
            }
        }
        rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
    };

    parallel(0, [&](const int ithr, const int nthr) {
        ker(ithr, nthr);
        if (pd()->with_bias())
            ker_bias(ithr, nthr);
    });

    /* TODO: put this in ker_bias */
    if (pd()->wants_padded_bias()) {
        assert(jcp.ngroups == 1);
        for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
            diff_bias_in[oc] = diff_bias[oc];
    }
}

}
}
}

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s