inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "c_types_map.hpp"
#include "mkldnn_thread.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"

#include "jit_generator.hpp"

#include "jit_avx512_common_1x1_convolution.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::status;
using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;

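/* data_blk_off() resolves a blocked-layout offset for either a 3D (n, c, w)
 * or a 4D (n, c, h, w) tensor, so the same driver code serves both 1D and
 * 2D spatial convolutions; it expects an `ndims` variable to be in scope at
 * the call site. */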
#define data_blk_off(f, n, c, h, w) \
    ((ndims == 3) \
    ? (f).blk_off(n, c, w) \
    : (f).blk_off(n, c, h, w))


namespace {
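/* balance2D() builds a two-level thread decomposition: the nthr threads are
 * split into at most nx_divider groups, the groups divide the nx dimension
 * via balance211(), and the threads within a group divide the ny dimension.
 * When nthr is not divisible by the group count, the first n_grp_big groups
 * get one extra thread. E.g. nthr = 8, nx_divider = 3 gives grp_count = 3,
 * grp_size_big = 3, grp_size_small = 2, n_grp_big = 2: threads {0,1,2} and
 * {3,4,5} form the two big groups and threads {6,7} the small one. */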
template <typename T, typename U>
void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end,
    T nx, T &nx_start, T &nx_end, T nx_divider)
{
    const int grp_count = nstl::min(nx_divider, nthr);
    const int grp_size_big = nthr / grp_count + 1;
    const int grp_size_small = nthr / grp_count;
    const int n_grp_big = nthr % grp_count;
    const int threads_in_big_groups = n_grp_big * grp_size_big;

    const int ithr_bound_distance = ithr - threads_in_big_groups;
    T grp, grp_ithr, grp_nthr;
    if (ithr_bound_distance < 0) { // ithr falls in one of the big groups
        grp = ithr / grp_size_big;
        grp_ithr = ithr % grp_size_big;
        grp_nthr = grp_size_big;
    } else { // ithr falls in one of the small groups
        grp = n_grp_big + ithr_bound_distance / grp_size_small;
        grp_ithr = ithr_bound_distance % grp_size_small;
        grp_nthr = grp_size_small;
    }

    balance211(nx, grp_count, grp, nx_start, nx_end);
    balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end);
}
}
/* convolution forward */

template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
void jit_avx512_common_1x1_convolution_fwd_t<src_type, wei_type, dst_type>::
execute_forward() const {
    auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
    auto weights =
        reinterpret_cast<const wei_data_t *>(this->input_memory(1));
    auto bias = reinterpret_cast<const dst_data_t *>(this->input_memory(2));
    auto dst = reinterpret_cast<dst_data_t *>(this->memory());

    auto scratchpad = this->scratchpad();

    auto &jcp = kernel_->jcp;
    if (pd()->wants_padded_bias()) {
        auto padded_bias = scratchpad.template get<dst_data_t>(
                key_conv_padded_bias);
        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
                jcp.oc - jcp.oc_without_padding);
        bias = padded_bias;
    }

    const int MB = pd()->MB();
    const int work_amount = MB * jcp.ngroups * jcp.nb_bcast * jcp.nb_load;

    parallel(0, (size_t)work_amount, [&](const int ithr, const int nthr) {
        execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad);
    });

    if (pd()->wants_zero_pad_dst())
        output_memory_primitive(0)->zero_pad();
}

template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
void jit_avx512_common_1x1_convolution_fwd_t<src_type, wei_type, dst_type>::
execute_forward_thr(const int ithr, const int nthr, const src_data_t *src,
        const wei_data_t *weights, const dst_data_t *bias, dst_data_t *dst,
        const memory_tracking::grantor_t &scratchpad) const {
    const memory_desc_wrapper src_d(pd()->src_pd());
    const memory_desc_wrapper dst_d(pd()->dst_pd());
    const memory_desc_wrapper weights_d(pd()->weights_pd(0));

    auto rtus_space = scratchpad.get<src_data_t>(key_conv_rtus_space);

    const int ndims = src_d.ndims();
    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
    const int stride_w = pd()->desc()->strides[ndims - 3];
    const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
    const int pad_l = pd()->desc()->padding[0][ndims - 3];

    const auto &jcp = kernel_->jcp;
    const int MB = pd()->MB();
    const int work_amount = MB * jcp.ngroups * jcp.nb_bcast;

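    /* step() returns the regular blocking step, unless fewer than tail_step
     * units remain, in which case the whole remainder is taken at once so
     * no tiny trailing block is left over. */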
    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

    auto p = jit_1x1_conv_call_s();

    auto rp = rtus_driver_t<avx512_common>::call_params_t();

    const int nb_oc = jcp.nb_load;
    const int nb_ic = jcp.nb_reduce;
    const int nb_ic_blocking = jcp.nb_reduce_blocking;
    const int os_block = jcp.bcast_block;

    int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0};
    balance2D(nthr, ithr, work_amount, bcast_start, bcast_end,
        jcp.nb_load, ocb_start, ocb_end, jcp.load_grp_count);

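    /* init_bcast() decodes the linear work index into (minibatch, group,
     * spatial block), derives the source/destination coordinates of the
     * block, and fills the broadcast-related fields of the kernel and rtus
     * call parameters. */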
    auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step,
            int &oh, int &ow, int &ih, int &iw)
    {
        int osb{0};
        nd_iterator_init(iwork, n, MB, g, jcp.ngroups, osb,
            jcp.nb_bcast);
        bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
                jcp.nb_bcast_blocking_max);
        bcast_step = nstl::min(bcast_step, bcast_end - iwork);

        const int os = osb * os_block;
        oh = os / jcp.ow;
        ow = os % jcp.ow;

        ih = nstl::max(oh * stride_h - pad_t, 0);
        iw = nstl::max(ow * stride_w - pad_l, 0);
        rp.iw_start = iw;

        p.bcast_dim = this_block_size(os, jcp.os,
            bcast_step * os_block);
        rp.os = p.bcast_dim;
    };

    auto init_load = [&](int ocb, int &load_step)
    {
        load_step = step(jcp.nb_load_blocking, ocb_end - ocb,
            jcp.nb_load_blocking_max);
        p.load_dim = this_block_size(ocb * jcp.oc_block,
            ocb_end * jcp.oc_block, load_step * jcp.oc_block);
    };

    auto init_reduce = [&](int icb)
    {
        const int nb_ic_blocking_step =
            nstl::min(icb + nb_ic_blocking, nb_ic) - icb;
        p.first_last_flag = 0
            | (icb == 0 ? FLAG_REDUCE_FIRST : 0)
            | (icb + nb_ic_blocking_step >= nb_ic
                    ? FLAG_REDUCE_LAST : 0);

        p.reduce_dim = this_block_size(icb * jcp.ic_block,
            jcp.ic, nb_ic_blocking_step * jcp.ic_block);
        rp.icb = p.reduce_dim / jcp.reduce_block;
    };

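    /* inner_ker() issues one jit kernel call for the current (load, reduce,
     * bcast) block. For a strided descriptor, the rtus ("reduce to unit
     * stride") driver first gathers the strided source rows into a dense
     * per-thread workspace so the kernel can read with unit stride; the
     * copy runs only once per bcast block (at ocb == ocb_start) and is
     * reused for the remaining ocb iterations. */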
    auto inner_ker = [&](int ocb, int icb, int n, int g, int oh, int ow,
        int ih, int iw)
    {
        const int _ocb = g * nb_oc + ocb;
        const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow);

        p.output_data = &dst[dst_off];
        p.bias_data = &bias[_ocb * jcp.oc_block];
        p.load_data = &weights[pd()->with_groups()
            ? weights_d.blk_off(g, ocb, icb)
            : weights_d.blk_off(ocb, icb)];

        const int _icb = g * nb_ic + icb;
        if (pd()->rtus_.reduce_src_) {
            rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_
                + _icb * jcp.is * jcp.ic_block;
            if (ocb == ocb_start) {
                rp.src = src + data_blk_off(src_d, n, _icb, ih, iw);
                rtus_driver_->ker_(&rp);
            }
            p.bcast_data = rp.ws;
        } else
            p.bcast_data = src + data_blk_off(src_d, n, _icb, ih, iw);

        p.oc_off = _ocb * jcp.oc_block * sizeof(dst_data_t);

        kernel_->jit_ker(&p);
    };

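    /* Loop-order naming: r = reduce (input-channel blocks), l = load
     * (output-channel blocks), b = bcast (minibatch x spatial blocks),
     * listed from the outermost loop to the innermost; e.g. loop_rlb runs
     * reduce outermost and bcast innermost. */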
    if (jcp.loop_order == loop_rlb) {
        for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
            init_reduce(icb);
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, load_step);
                int iwork = bcast_start;
                while (iwork < bcast_end) {
                    int n, g, bcast_step, oh, ow, ih, iw;
                    init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
                    inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
                    iwork += bcast_step;
                }
                ocb += load_step;
            }
        }
    } else if (jcp.loop_order == loop_lbr) {
        int ocb = ocb_start;
        while (ocb < ocb_end) {
            int load_step;
            init_load(ocb, load_step);
            int iwork = bcast_start;
            while (iwork < bcast_end) {
                int n, g, bcast_step, oh, ow, ih, iw;
                init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
                for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
                    init_reduce(icb);
                    inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
                }
                iwork += bcast_step;
            }
            ocb += load_step;
        }
    } else if (jcp.loop_order == loop_rbl) {
        for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
            init_reduce(icb);
            int iwork = bcast_start;
            while (iwork < bcast_end) {
                int n, g, bcast_step, oh, ow, ih, iw;
                init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
                int ocb = ocb_start;
                while (ocb < ocb_end) {
                    int load_step;
                    init_load(ocb, load_step);
                    inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
                    ocb += load_step;
                }
                iwork += bcast_step;
            }
        }
    } else if (jcp.loop_order == loop_blr) {
        int iwork = bcast_start;
        while (iwork < bcast_end) {
            int n, g, bcast_step, oh, ow, ih, iw;
            init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, load_step);
                for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
                    init_reduce(icb);
                    inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
                }
                ocb += load_step;
            }
            iwork += bcast_step;
        }
    } else {
        assert(!"unsupported loop order");
    }
}


template struct jit_avx512_common_1x1_convolution_fwd_t<data_type::f32>;
template struct jit_avx512_common_1x1_convolution_fwd_t<data_type::s16,
    data_type::s16, data_type::s32>;
/* convolution backward w.r.t. data */

template <data_type_t diff_dst_type, data_type_t wei_type,
         data_type_t diff_src_type>
void jit_avx512_common_1x1_convolution_bwd_data_t<diff_dst_type, wei_type,
     diff_src_type>::execute_backward_data() const {
    auto diff_dst = reinterpret_cast<const diff_dst_data_t *>
        (this->input_memory(0));
    auto weights = reinterpret_cast<const wei_data_t *>
        (this->input_memory(1));
    auto diff_src = reinterpret_cast<diff_src_data_t *>(this->memory());

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
    const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());

    auto rtus_space = scratchpad().template get<diff_src_data_t>(
            key_conv_rtus_space);

    const int ndims = diff_src_d.ndims();
    const auto &jcp = kernel_->jcp;
    const int MB = pd()->MB();

    assert(jcp.stride_w == 1 && jcp.stride_h == 1);

    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
    const int stride_w = pd()->desc()->strides[ndims - 3];
    const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
    const int pad_l = pd()->desc()->padding[0][ndims - 3];

    const int nb_ic = jcp.nb_load;
    const int nb_oc = jcp.nb_reduce;
    const int os_block = jcp.bcast_block;
    const int nb_oc_blocking = jcp.nb_reduce_blocking;

    const int work_amount = MB * jcp.ngroups * jcp.nb_bcast * jcp.nb_load;

    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

    parallel(0, (size_t)work_amount, [&](const int ithr, const int nthr) {
        auto p = jit_1x1_conv_call_s();
        auto rp = rtus_driver_t<avx512_common>::call_params_t();

        int bcast_start{0}, bcast_end{0}, icb_start{0}, icb_end{0};
        balance2D(nthr, ithr, work_amount, bcast_start, bcast_end,
            jcp.nb_load, icb_start, icb_end, jcp.load_grp_count);

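        /* jcp.loop_order decides whether the reduce (oc) blocks run as the
         * outer or the inner loop; the unused direction collapses to a
         * single iteration. */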
        bool reduce_outer = (jcp.loop_order == loop_rbl
            || jcp.loop_order == loop_rlb);
        int nboc_outer = reduce_outer ? nb_oc : 1;
        int ocb_outer_step = reduce_outer ? nb_oc_blocking : 1;

        int nboc_inner = reduce_outer ? 1 : nb_oc;
        int ocb_inner_step = reduce_outer ? 1 : nb_oc_blocking;

        for (int ocb_outer = 0; ocb_outer < nboc_outer;
            ocb_outer += ocb_outer_step) {
            size_t cur_ocb_outer =
                nstl::min(ocb_outer + ocb_outer_step, nboc_outer) - ocb_outer;

            int load_step = 0;
            for (int icb = icb_start; icb < icb_end; icb += load_step) {
                load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb,
                        jcp.nb_load_blocking_max);

                p.load_dim = this_block_size(icb * jcp.ic_block,
                    icb_end * jcp.ic_block, load_step * jcp.ic_block);
                rp.icb = p.load_dim / jcp.ic_block;

                int bcast_step;
                for (int iwork = bcast_start; iwork < bcast_end;
                    iwork += bcast_step)
                {
                    int n{0}, g{0}, osb{0};
                    nd_iterator_init(iwork, n, MB, g, jcp.ngroups, osb,
                            jcp.nb_bcast);

                    bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
                            jcp.nb_bcast_blocking_max);
                    bcast_step = nstl::min(bcast_step, bcast_end - iwork);

                    const int os = osb * os_block;
                    p.bcast_dim = this_block_size(os, jcp.os,
                            bcast_step * os_block);
                    rp.os = p.bcast_dim;

                    const int oh = os / jcp.ow;
                    const int ow = os % jcp.ow;
                    const int ih = nstl::max(oh * stride_h - pad_t, 0);
                    const int iw = nstl::max(ow * stride_w - pad_l, 0);
                    rp.iw_start = iw;

                    const int _icb = g * nb_ic + icb;
                    rp.src = diff_src + data_blk_off(diff_src_d, n, _icb, ih, iw);
                    if (pd()->rtus_.reduce_src_) {
                        rp.ws = rtus_space
                            + ithr * pd()->rtus_.space_per_thread_;
                        p.output_data = rp.ws;
                    } else
                        p.output_data = rp.src;

                    for (int ocb_inner = 0; ocb_inner < nboc_inner;
                        ocb_inner += ocb_inner_step) {
                        int cur_ocb_inner =
                            nstl::min(ocb_inner + ocb_inner_step, nboc_inner) -
                            ocb_inner;

                        int ocb = reduce_outer ? ocb_outer : ocb_inner;
                        int nb_oc_blocking_step = reduce_outer
                            ? cur_ocb_outer : cur_ocb_inner;
                        const int _ocb = g * nb_oc + ocb;
                        size_t diff_dst_off = data_blk_off(diff_dst_d, n, _ocb, oh, ow);
                        p.bcast_data = &diff_dst[diff_dst_off];

                        p.load_data = &weights[pd()->with_groups()
                            ? weights_d.blk_off(g, ocb, icb)
                            : weights_d.blk_off(ocb, icb)];

                        p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0;

                        p.reduce_dim = this_block_size(ocb * jcp.oc_block,
                            jcp.oc, nb_oc_blocking_step * jcp.oc_block);

                        kernel_->jit_ker(&p);
                    }
                    if (pd()->rtus_.reduce_src_)
                        rtus_driver_->ker_(&rp);
                }
            }
        }
    });
}

template struct jit_avx512_common_1x1_convolution_bwd_data_t<data_type::f32>;
template struct jit_avx512_common_1x1_convolution_bwd_data_t<data_type::s16,
    data_type::s16, data_type::s32>;

/* convolution backward w.r.t. weights */

#define wht_blk_off(d, g, ...) \
        (pd()->with_groups() \
         ? (d).blk_off((g), __VA_ARGS__) \
         : (d).blk_off(__VA_ARGS__))

jit_avx512_common_1x1_convolution_bwd_weights_t ::
        jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd,
                const input_vector &inputs, const output_vector &outputs)
    : cpu_primitive_t(apd, inputs, outputs)
    , kernel_(nullptr), acc_ker_(nullptr), reducer_bias_(nullptr)
    , trans_kernel_(nullptr), rtus_driver_(nullptr)
{
    kernel_ = new jit_avx512_common_1x1_conv_kernel(pd()->jcp_, *pd()->attr());
    acc_ker_ = new cpu_accumulator_1d_t<data_type::f32>();
    reducer_bias_ = new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_);
    init_rtus_driver<avx512_common>(this);

    const auto &jcp = kernel_->jcp;

    if (jcp.transpose_src) {
        auto tp = jit_transpose4x16_src_t();
        tp.src_pf0_distance = 4;
        tp.tr_src_pf0_distance = 0;
        tp.src_pf1 = true;
        tp.tr_src_pf1 = false;
        trans_kernel_ = new jit_transpose4x16_src(&jcp, &tp);
    }
}

void jit_avx512_common_1x1_convolution_bwd_weights_t
        ::execute_backward_weights() const
{
    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
    auto diff_weights = reinterpret_cast<data_t *>(this->memory(0));
    auto diff_bias_in = reinterpret_cast<data_t *>(this->memory(1));

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
    const memory_desc_wrapper src_d(pd()->src_pd());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));

    const auto &jcp = kernel_->jcp;

    const auto scratchpad = this->scratchpad();

    auto rtus_space = scratchpad.get<data_t>(key_conv_rtus_space);
    data_t *diff_bias = pd()->wants_padded_bias()
        ? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in;
    auto wei_reduction = scratchpad.get<data_t>(key_conv_wei_reduction);

    /* prepare src transposition barriers */
    auto tr_src = scratchpad.get<data_t>(key_conv_tr_src);
    auto tr_src_bctx = scratchpad.get<simple_barrier::ctx_t>(
            key_conv_tr_src_bctx);
    if (jcp.transpose_src) {
        for (int i = 0; i < jcp.nthr; ++i)
            simple_barrier::ctx_init(&tr_src_bctx[i]);
    }

    const int ndims = src_d.ndims();
    const int wei_size = jcp.ngroups * jcp.oc * jcp.ic;

    simple_barrier::ctx_t reduction_barrier;
    simple_barrier::ctx_init(&reduction_barrier);

    const auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
            prefix_reducer_bia);
    auto rb = this->reducer_bias_;
    rb->init(reducer_bia_scratchpad);

    // TODO (Roma): remove this restriction
    assert(jcp.stride_w == 1 && jcp.stride_h == 1);

    const int nb_ic = jcp.nb_bcast;
    const int nb_ic_blocking = jcp.nb_bcast_blocking;

    const int nb_oc = jcp.nb_load;
    const int nb_oc_blocking = jcp.nb_load_blocking;

    const int sp_nb = jcp.nb_reduce;
    const int mb_sp_work = jcp.mb * sp_nb;

    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
    const int stride_w = pd()->desc()->strides[ndims - 3];
    const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
    const int pad_l = pd()->desc()->padding[0][ndims - 3];

    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

    // TODO: use memory descriptor with the same fmt as src
    // (or use a macro :))
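    /* The transposed-src scratchpad is laid out as
     * [mb][ngroups * nb_ic][tr_is][ic_block]; tr_src_off() returns the
     * element offset of spatial position `is` within image `img` and
     * input-channel block `icb`. */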
    auto tr_src_off = [&](int img, int icb, int is) {
        const size_t tr_chn_size = jcp.tr_is * jcp.ic_block;
        const size_t tr_img_size = tr_chn_size * nb_ic * jcp.ngroups;
        return img * tr_img_size + icb * tr_chn_size + is * jcp.ic_block;
    };

    auto uker_trans = [&](int ithr_mb, int img, int sp_b_start, int sp_size,
        int g_start, int g_work, int ic_b_start, int ic_b_work,
        int ithr, int nthr, int first_ic_b)
    {
        const int work_amount = g_work * ic_b_work;

        int start{ 0 }, end{ 0 };
        balance211(work_amount, nthr, ithr, start, end);

        int g{ 0 }, ic_b{ 0 };
        nd_iterator_init(start, g, g_work, ic_b, ic_b_work);
        g += g_start;
        const int ic_b_tr = g * nb_ic + first_ic_b + ic_b;
        ic_b += ic_b_start;

        const int _ic = g * nb_ic + ic_b;

        const int is = sp_b_start * jcp.reduce_block;
        const int ih = is / jcp.iw;
        const int iw = is % jcp.iw;

        const int src1_off = data_blk_off(src_d, img, _ic, ih, iw);
        data_t *src1 = (data_t *)&src[src1_off];
        data_t *tr_src1 = &tr_src[tr_src_off(ithr_mb, ic_b_tr, is)];

        assert(jcp.ic_block == 16);
        const int src_stride = jcp.is * jcp.ic_block;
        const int tr_src_stride = jcp.tr_is * jcp.ic_block;

        const int my_work = end - start;
        for (int iwork = 0; iwork < my_work; iwork++) {
            auto par_trans = jit_src_transpose_s();
            assert(sp_size % 4 == 0 || sp_size % 4 == jcp.is % 4);
            par_trans.size = sp_size;
            par_trans.src = src1;
            par_trans.tr_src = tr_src1;
            par_trans.src_prf = src1 + 64 * 16;
            par_trans.tr_src_prf = tr_src1 + 80 * 16;
            trans_kernel_->jit_ker(&par_trans);

            src1 += src_stride;
            tr_src1 += tr_src_stride;
        }
    };

    auto ker = [&](const int ithr, const int nthr) {
        assert(nthr == jcp.nthr);
        assert(IMPLICATION(!mkldnn_thr_syncable(), jcp.nthr_mb == 1));

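        /* decode the flat thread id as a 4D index with ic_b innermost:
         * ithr = ((ithr_mb * nthr_g + ithr_g) * nthr_oc_b + ithr_oc_b)
         *        * nthr_ic_b + ithr_ic_b */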
        const int ithr_ic_b = ithr % jcp.nthr_ic_b;
        const int ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b;
        const int ithr_g = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g;
        const int ithr_mb = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b /
                            jcp.nthr_g;

        const int ithr_but_oc
                = (ithr_mb * jcp.nthr_g + ithr_g) * jcp.nthr_ic_b + ithr_ic_b;

        /* reduction dimension */
        int mb_sp_b_start{ 0 }, mb_sp_b_end{ 0 };
        if (jcp.transpose_src && jcp.nthr_mb < jcp.mb / 2) {
            // it's preferable to parallelize by mb if possible
            int img_start{ 0 }, img_end{ 0 };
            balance211(jcp.mb, jcp.nthr_mb, ithr_mb, img_start, img_end);
            mb_sp_b_start = img_start * sp_nb;
            mb_sp_b_end = img_end * sp_nb;
        } else {
            balance211(mb_sp_work, jcp.nthr_mb, ithr_mb, mb_sp_b_start,
                    mb_sp_b_end);
        }

        /* independent dimensions */
        int g_start{ 0 }, oc_b_start{ 0 }, ic_b_start{ 0 };
        int g_end{ 0 }, oc_b_end{ 0 }, ic_b_end{ 0 };

        balance211(jcp.ngroups, jcp.nthr_g, ithr_g, g_start, g_end);
        balance211(jcp.nb_load, jcp.nthr_oc_b, ithr_oc_b, oc_b_start,
                    oc_b_end);
        balance211(jcp.nb_bcast, jcp.nthr_ic_b, ithr_ic_b, ic_b_start,
                    ic_b_end);

        const int g_work = g_end - g_start;
        const int oc_b_work = oc_b_end - oc_b_start;
        const int ic_b_work = ic_b_end - ic_b_start;

        data_t *diff_wei = ithr_mb == 0
            ? diff_weights : wei_reduction + (ithr_mb - 1) * wei_size;

        int sp_b_step = 0;
        for (int mb_sp_b = mb_sp_b_start; mb_sp_b < mb_sp_b_end;
                mb_sp_b += sp_b_step) {
            int img{ 0 }, sp_b{ 0 };
            nd_iterator_init(mb_sp_b, img, jcp.mb, sp_b, sp_nb);
            sp_b_step = step(jcp.nb_reduce_blocking,
                    nstl::min(sp_nb - sp_b, mb_sp_b_end - mb_sp_b),
                    jcp.nb_reduce_blocking_max);

            for (int g = g_start; g < g_end; ++g) {
                int load_step = 0;
                int bcast_step = 0;
                for (int ic_b = ic_b_start; ic_b < ic_b_end;
                        ic_b += bcast_step) {
                    bcast_step = step(nb_ic_blocking, ic_b_end - ic_b,
                            jcp.nb_bcast_blocking_max);
                    if (jcp.transpose_src) {
                        if (jcp.nthr_oc_b > 1)
                            simple_barrier::barrier(
                                    &tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b);
                        const int sp_size
                                = nstl::min(sp_b_step * jcp.reduce_block,
                                        jcp.is - sp_b * jcp.reduce_block);
                        uker_trans(ithr_mb, img, sp_b, sp_size, g, 1, ic_b,
                            bcast_step, ithr_oc_b, jcp.nthr_oc_b, ic_b_start);
                        if (jcp.nthr_oc_b > 1)
                            simple_barrier::barrier(
                                    &tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b);
                    }

                    for (int oc_b = oc_b_start; oc_b < oc_b_end;
                            oc_b += load_step) {
                        load_step = step(nb_oc_blocking, oc_b_end - oc_b,
                                jcp.nb_load_blocking_max);
                        const int _ic_b = g * nb_ic + ic_b;
                        const int _ic_b_tr = g * nb_ic + ic_b_start;
                        const int _oc_b = g * nb_oc + oc_b;

                        data_t *store_to;

                        const size_t off
                                = wht_blk_off(diff_weights_d, g, oc_b, ic_b);
                        store_to = diff_wei + off;

                        const data_t *diff_src = jcp.transpose_src ?
                                &tr_src[tr_src_off(ithr_mb, _ic_b_tr, 0)] :
                                &src[src_d.blk_off(img, _ic_b)];

                        int sp_b_end = sp_b + sp_b_step;
                        const data_t *pdiff_dst
                                = &diff_dst[diff_dst_d.blk_off(img, _oc_b)];
                        const data_t *local_src = diff_src;

                        auto p = jit_1x1_conv_call_s();
                        auto rp = rtus_driver_t<avx512_common>::call_params_t();

                        p.output_stride
                                = jcp.ic * jcp.oc_block * jcp.typesize_out;

                        p.load_dim = load_step * jcp.oc_block;

                        p.bcast_dim = bcast_step * jcp.ic_block;
                        rp.icb = bcast_step;
                        p.output_data = store_to;

                        p.reduce_dim = sp_b_step * jcp.reduce_block;
                        rp.os = p.reduce_dim;

                        p.first_last_flag = 0
                            | (mb_sp_b == mb_sp_b_start ? FLAG_REDUCE_FIRST : 0)
                            | (sp_b_end == sp_nb ? FLAG_SP_LAST : 0);

                        int sp = sp_b * jcp.reduce_block;
                        p.load_data = pdiff_dst + sp * jcp.oc_block;

                        if (pd()->rtus_.reduce_src_) {
                            const int oh = sp / jcp.ow;
                            const int ow = sp % jcp.ow;

                            const int ih = nstl::max(oh * stride_h - pad_t, 0);
                            const int iw = nstl::max(ow * stride_w - pad_l, 0);
                            rp.iw_start = iw;

                            rp.ws = rtus_space
                                + ithr * pd()->rtus_.space_per_thread_
                                + sp * jcp.ic_block;

                            if (ndims == 3)
                                rp.src = local_src + iw
                                    * src_d.blocking_desc().strides[0][2];
                            else
                                rp.src = local_src + ih
                                    * src_d.blocking_desc().strides[0][2]
                                    + iw * src_d.blocking_desc().strides[0][3];
                            rtus_driver_->ker_(&rp);

                            p.bcast_data = rp.ws;
                        } else
                            p.bcast_data = local_src + sp * jcp.ic_block;

                        kernel_->jit_ker(&p);
                    }
                }
            }
        }

        /* diff_weights[:] += sum(wei_reduction[thr_mb][:]) */
        if (jcp.nthr_mb > 1) {
            simple_barrier::barrier(&reduction_barrier, jcp.nthr);
            const int work = g_work * oc_b_work * ic_b_work;
            int start{ 0 }, end{ 0 };
            balance211(work, jcp.nthr_mb, ithr_mb, start, end);
            if (start == end)
                return;

            for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) {
                int w = start;
                int sub_g_start{ 0 }, sub_oc_b_start{ 0 },
                        sub_ic_b_start{ 0 };
                nd_iterator_init(w, sub_g_start, g_work, sub_oc_b_start,
                        oc_b_work, sub_ic_b_start, ic_b_work);
                while (w < end) {
                    const int g = g_start + sub_g_start;
                    const int oc_b = oc_b_start + sub_oc_b_start;
                    const int ic_b = ic_b_start + sub_ic_b_start;

                    const int acc_size
                            = nstl::min(end - w, ic_b_work - sub_ic_b_start)
                            * jcp.ic_block * jcp.oc_block;

                    const size_t off
                            = wht_blk_off(diff_weights_d, g, oc_b, ic_b);
                    data_t *d = diff_weights + off;
                    data_t *s = wei_reduction + (thr_mb - 1) * wei_size + off;

                    acc_ker_->accumulate(d, s, acc_size);

                    nd_iterator_jump(w, end, sub_g_start, g_work,
                            sub_oc_b_start, oc_b_work, sub_ic_b_start,
                            ic_b_work);
                }
            }
        }
    };

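    /* ker_bias() accumulates diff_dst over the minibatch and spatial
     * dimensions into per-(group, oc-block) bias jobs; the literal 16 below
     * is the AVX-512 output-channel block size (cf. the ic_block == 16
     * assert in uker_trans above). */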
    auto ker_bias = [&](int ithr, int nthr) {
        assert(nthr == rb->balancer().nthr_);

        const int b_job_start = rb->balancer().ithr_job_off(ithr);
        const int b_njobs = rb->balancer().ithr_njobs(ithr);

        if (b_njobs == 0)
            return;

        /* reduction dimension */
        int img_start{ 0 }, img_end{ 0 };

        balance211(jcp.mb, rb->balancer().nthr_per_group_,
                rb->balancer().id_in_group(ithr), img_start, img_end);

        /* jobs */
        int g_start{ 0 }, ocb_start{ 0 };
        nd_iterator_init(
                b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_load);

        for (int img = img_start; img < img_end; ++img) {
            int g = g_start, ocb = ocb_start;
            for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
                const size_t _oc = g * jcp.nb_load + ocb;

                const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)];
                data_t *d_bias = rb->get_local_ptr(ithr, diff_bias,
                        reducer_bia_scratchpad)
                    + b_job_loc * rb->balancer().job_size_;

                if (img == img_start)
                    for (int o = 0; o < 16; ++o)
                        d_bias[o] = 0.;

                for (int hw = 0; hw < jcp.oh * jcp.ow; ++hw) {
                    PRAGMA_OMP_SIMD()
                    for (int o = 0; o < 16; ++o)
                        d_bias[o] += d_dst[o];
                    d_dst += 16;
                }

                nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_load);
            }
        }
        rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
    };

    parallel(jcp.nthr, (size_t)mkldnn_get_max_threads(),
            [&](const int ithr, const int nthr) {
        ker(ithr, jcp.nthr);
        if (pd()->with_bias())
            ker_bias(ithr, jcp.nthr);
    });

    /* TODO: put this in ker_bias */
    if (pd()->wants_padded_bias()) {
        assert(jcp.ngroups == 1);
        utils::array_copy(diff_bias_in, diff_bias, jcp.oc_without_padding);
    }
}

} // namespace cpu
} // namespace impl
} // namespace mkldnn