// Publishing 2019 R1 content
// platform/upstream/dldt.git: inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "mkldnn_thread.hpp"
19 #include "type_helpers.hpp"
20 #include "utils.hpp"
21
22 #include "jit_avx512_common_convolution.hpp"
23
24 namespace mkldnn {
25 namespace impl {
26 namespace cpu {
27
28 using namespace mkldnn::impl::status;
29 using namespace mkldnn::impl::memory_format;
30 using namespace mkldnn::impl::memory_tracking::names;
31 using namespace mkldnn::impl::utils;
32
33 using namespace nstl;
34
// Signature of the generated JIT convolution kernel entry point.
using jit_conv_ker_t = void (*)(jit_conv_call_s *);

// Software-pipelining helper: promote the stashed "prefetch" value of a
// jit_conv_call_s field into the current slot and stash the newly supplied
// argument as the next call's prefetch value. The kernel thus always sees
// both the current job and the next one (for prefetching).
#define PIPELINE(field) \
    do { \
        p.field = p.field ## _prf; \
        p.field ## _prf = field; \
    } while (0)
42
43 inline void jit_conv_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p,
44         const void *src, const void *dst, const void *filt, const void *bias,
45         int channel, int kh_padding, int oc_off)
46 {
47     PIPELINE(src);
48     PIPELINE(dst);
49     PIPELINE(filt);
50     PIPELINE(bias);
51     PIPELINE(channel);
52     PIPELINE(kh_padding);
53     PIPELINE(oc_off);
54
55     if (p.src)
56         ker(&p);
57 }
58 // The special case for the driver with ow-parallelization (FWD)
59 // TODO: implement it for BWD_D and BWD_W too
60 inline void jit_conv_ker_pipeline_ow_thr(jit_conv_ker_t ker, jit_conv_call_s &p,
61         const void *src, const void *dst, const void *filt, const void *bias,
62         int channel, int kh_padding, int owb, int oc_off)
63 {
64     PIPELINE(src);
65     PIPELINE(dst);
66     PIPELINE(filt);
67     PIPELINE(bias);
68     PIPELINE(channel);
69     PIPELINE(kh_padding);
70     PIPELINE(owb);
71     PIPELINE(oc_off);
72
73     if (p.src)
74         ker(&p);
75 }
76
77 inline void jit_conv_3d_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p,
78         const void *src, const void *dst, const void *filt, const void *bias,
79         int channel, int kh_padding, int kd_padding, int oc_off)
80 {
81     PIPELINE(src);
82     PIPELINE(dst);
83     PIPELINE(filt);
84     PIPELINE(bias);
85     PIPELINE(channel);
86     PIPELINE(kh_padding);
87     PIPELINE(kd_padding);
88     PIPELINE(oc_off);
89
90     if (p.src)
91         ker(&p);
92 }
93 // The special case for the driver with ow-parallelization (FWD)
94 // TODO: implement it for BWD_D and BWD_W too
95 inline void jit_conv_3d_ker_pipeline_ow_thr(jit_conv_ker_t ker,
96         jit_conv_call_s &p, const void *src, const void *dst, const void *filt,
97         const void *bias, int channel, int kh_padding, int kd_padding, int owb, int oc_off)
98 {
99     PIPELINE(src);
100     PIPELINE(dst);
101     PIPELINE(filt);
102     PIPELINE(bias);
103     PIPELINE(channel);
104     PIPELINE(kh_padding);
105     PIPELINE(kd_padding);
106     PIPELINE(owb);
107     PIPELINE(oc_off);
108
109     if (p.src)
110         ker(&p);
111 }
112
113 void jit_conv_3d_ker_bwd_w_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p,
114         const void *src, const void *dst, const void *filt, const void *bias,
115         int channel, int d_index, int d_worksize,
116         int kd_padding /* kd_work_size */, size_t kd_offset) {
117     PIPELINE(src);
118     PIPELINE(dst);
119     PIPELINE(filt);
120     PIPELINE(bias);
121     PIPELINE(channel);
122     PIPELINE(kd_padding);
123     PIPELINE(d_worksize);
124     PIPELINE(d_index);
125     PIPELINE(kd_offset);
126
127     if (p.src)
128         ker(&p);
129 }
// Weights offset helper: include the group index in blk_off() only when the
// weights descriptor actually carries a groups dimension.
#define wht_blk_off(d, g, ...) \
        (pd()->with_groups() \
         ? (d).blk_off((g), __VA_ARGS__) \
         : (d).blk_off(__VA_ARGS__))
134
135 template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
136 void jit_avx512_common_convolution_fwd_t<src_type, wei_type, dst_type>::
137 prepare_padded_bias(const dst_data_t *&bias) const {
138     if (!pd()->wants_padded_bias()) return;
139
140     auto padded_bias = scratchpad().template get<dst_data_t>(
141             key_conv_padded_bias);
142     utils::array_copy(padded_bias, bias, pd()->jcp_.oc_without_padding);
143     utils::array_set(padded_bias + pd()->jcp_.oc_without_padding,
144             (dst_data_t)0, pd()->jcp_.oc - pd()->jcp_.oc_without_padding);
145     bias = padded_bias;
146 }
147
// Forward 1D convolution driver. Work is flattened over
// (mb, groups, oc chunks, ow blocks) and distributed across threads with
// balance211; each thread feeds the JIT kernel through the software
// pipeline (jit_conv_ker_pipeline_ow_thr).
template <data_type_t src_type, data_type_t wei_type,
          data_type_t dst_type>
void jit_avx512_common_convolution_fwd_t
    <src_type, wei_type, dst_type>::execute_forward_1d() const
{
    auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
    auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
    auto bias = reinterpret_cast<const dst_data_t *>(this->input_memory(2));
    auto dst = reinterpret_cast<dst_data_t *>(this->memory());

    // May redirect `bias` to a zero-padded scratchpad copy.
    prepare_padded_bias(bias);

    const memory_desc_wrapper src_d(pd()->src_pd());
    const memory_desc_wrapper dst_d(pd()->dst_pd());
    const memory_desc_wrapper weights_d(pd()->weights_pd(0));

    const auto &jcp = pd()->jcp_;
    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);

    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.nb_ow;

    int nthr;
    if (jcp.aligned_threads)
        nthr = jcp.aligned_threads;
    else
        nthr = mkldnn_get_max_threads();

    parallel(nthr, [&](const int ithr, const int nthr) {
        int start{0}, end{0}, start_copy;
        balance211(work_amount, nthr, ithr, start, end);
        start_copy = start;

        auto par_conv = jit_conv_call_s();
        size_t src_c_stride = src_d.blk_off(0, 1);
        size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);

        // Outer blocking over input channels to keep the working set in L2;
        // each icb_l2 pass re-walks the same [start, end) work range.
        for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) {
            start = start_copy;
            int n{0}, g{0}, occ{0}, owb{0};

            // Decompose the flat work index according to the loop order.
            if (jcp.loop_order == loop_cwgn) {
                int dummy{0};
                nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow,
                        g, jcp.ngroups, n, jcp.mb, dummy, 1);
            } else if (jcp.loop_order == loop_gncw) {
                int dummy{0};
                nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, occ,
                        oc_chunks, owb, jcp.nb_ow, dummy, 1);
            } else {
                assert(!"unsupported loop order");
            }

            while (start < end) {
                int ocb = occ * jcp.nb_oc_blocking;
                int g_ocb = g * jcp.nb_oc + ocb;
                int g_oc = g_ocb * jcp.oc_block;
                int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;

                // Starting output/input width positions for this ow block.
                int ow_s =  owb * jcp.ow_block;
                int iw_s =  ow_s * jcp.stride_w;
                auto bias_w = bias ? bias + g_oc : nullptr;
                auto dst_w = dst + dst_d.blk_off(n, g_ocb, ow_s);
                auto src_w = src + src_d.blk_off(n, g_icb + icb_l2, iw_s);
                auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2);

                // Output-channel byte offset (used by the kernel, e.g. for
                // per-channel post-ops).
                int oc_off = g_oc * sizeof(dst_data_t);

                // Accumulate over the input-channel blocks of this L2 slice.
                for (int icb = icb_l2;
                     icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); ++icb) {
                     jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv,
                        src_w, dst_w, wht_w, bias_w, icb, 1, owb, oc_off);

                    src_w += src_c_stride;
                    wht_w += wht_ic_stride;
                }
                if (jcp.loop_order == loop_cwgn) {
                    int dummy{0};
                    nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
                            g, jcp.ngroups, n, jcp.mb, dummy, 1);
                } else if (jcp.loop_order == loop_gncw) {
                    int dummy{0};
                    nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb,
                            occ, oc_chunks, owb, jcp.nb_ow, dummy, 1);
                } else {
                    assert(!"unsupported loop order");
                }
            }
        }
        // Drain the software pipeline: one extra call promotes and executes
        // the last stashed job.
        jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv,
                src, dst, weights, bias, 0, 0, 0, 0);
    });
}
241
// Forward 2D convolution driver. Work is flattened over
// (mb, groups, oc chunks, oh, ow blocks); rows are additionally grouped in
// jcp.h_blocking chunks. Vertical padding is handled per output row by
// trimming the kernel height (kh_padding) and offsetting src/weights.
template <data_type_t src_type, data_type_t wei_type,
          data_type_t dst_type>
void jit_avx512_common_convolution_fwd_t
    <src_type, wei_type, dst_type>::execute_forward_2d() const
{
    auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
    auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
    auto bias = reinterpret_cast<const dst_data_t *>(this->input_memory(2));
    auto dst = reinterpret_cast<dst_data_t *>(this->memory());

    // May redirect `bias` to a zero-padded scratchpad copy.
    prepare_padded_bias(bias);

    const memory_desc_wrapper src_d(pd()->src_pd());
    const memory_desc_wrapper dst_d(pd()->dst_pd());
    const memory_desc_wrapper weights_d(pd()->weights_pd(0));

    const auto &jcp = pd()->jcp_;
    const int MB = pd()->MB();
    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);

    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    int work_amount = MB * jcp.ngroups * oc_chunks * jcp.oh * jcp.nb_ow;

    int nthr;
    if (jcp.aligned_threads)
        nthr = jcp.aligned_threads;
    else
        nthr = mkldnn_get_max_threads();

    parallel(nthr, [&](const int ithr, const int nthr) {
        int start{0}, end{0}, start_copy;
        balance211(work_amount, nthr, ithr, start, end);
        start_copy = start;

        auto par_conv = jit_conv_call_s();
        // Strides are computed relative to the logical origin (off_l(0)).
        size_t src_h_stride = src_d.blk_off(0, 0, 1) - src_d.off_l(0);
        size_t src_c_stride = src_d.blk_off(0, 1) - src_d.off_l(0);
        size_t dst_h_stride = dst_d.blk_off(0, 0, 1) - dst_d.off_l(0);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
        size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);

        // L2 blocking over input channels; re-walk [start, end) per slice.
        for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) {
            start = start_copy;
            int n{0}, g{0}, occ{0}, oh_s{0}, owb{0};

            if (jcp.loop_order == loop_cwgn)
                nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow,
                    g, jcp.ngroups, n, MB, oh_s, jcp.oh);
            else if (jcp.loop_order == loop_gncw)
                nd_iterator_init(start, g, jcp.ngroups, n, MB,
                    occ, oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
            else
                assert(!"unsupported loop order");

            while (start < end) {
                int ocb = occ * jcp.nb_oc_blocking;
                int g_ocb = g * jcp.nb_oc + ocb;
                int g_oc = g_ocb * jcp.oc_block;
                int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;

                int work_rem = end - start;

                int ow_s =  owb * jcp.ow_block;
                int iw_s =  ow_s * jcp.stride_w;
                // Clamp the row range to the rows still owned by this thread.
                int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
                auto bias_w = bias ? bias + g_oc : nullptr;

                // Process output rows in h_blocking groups.
                for (int oh_b = oh_s; oh_b < oh_e; oh_b += jcp.h_blocking) {
                    // Input row for this output row (may be negative: padding).
                    int ih_b = -jcp.t_pad + oh_b * jcp.stride_h;

                    auto dst_w = dst + dst_d.blk_off(n, g_ocb, oh_b, ow_s);
                    auto src_w
                        = src + src_d.blk_off(n, g_icb + icb_l2, ih_b, iw_s);
                    auto wht_w
                            = weights + wht_blk_off(weights_d, g, ocb, icb_l2);

                    for (int icb = icb_l2;
                            icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2);
                            ++icb) {
                        auto src_c = src_w;
                        auto dst_c = dst_w;
                        for (int oj = oh_b, ij = ih_b;
                                oj < min(oh_e, oh_b + jcp.h_blocking);
                                ++oj, ij += jcp.stride_h) {
                            int dilate_h = jcp.dilate_h + 1;
                            // Rows of the (dilated) filter falling above/below
                            // the input; shrink the effective kernel height.
                            int i_t_overflow = div_up(max(0, -ij), dilate_h);
                            int i_b_overflow = div_up(max(0, ij - jcp.ih
                                + (jcp.kh - 1) * dilate_h + 1), dilate_h);
                            int kh_padding = nstl::max(
                                    0, jcp.kh - i_t_overflow - i_b_overflow);

                            // Skip the clipped-away top of src and weights.
                            auto aux_src = src_c
                                    + i_t_overflow * dilate_h * src_h_stride;
                            auto aux_wht = wht_w + i_t_overflow * wht_h_stride;

                            int oc_off = g_oc * sizeof(dst_data_t);

                            jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker,
                                par_conv, aux_src, dst_c, aux_wht, bias_w, icb,
                                kh_padding, owb, oc_off);

                            src_c += src_h_stride * jcp.stride_h;
                            dst_c += dst_h_stride;
                        }
                        src_w += src_c_stride;
                        wht_w += wht_ic_stride;
                    }
                }

                if (jcp.loop_order == loop_cwgn)
                    nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
                        g, jcp.ngroups, n, MB, oh_s, jcp.oh);
                else if (jcp.loop_order == loop_gncw)
                    nd_iterator_jump(start, end, g, jcp.ngroups, n, MB, occ,
                        oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
                else
                    assert(!"unsupported loop order");
            }
        }

        // Drain the software pipeline (executes the last stashed job).
        jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv,
                src, dst, weights, bias, 0, 0, 0, 0);
    });
}
366
// Forward 3D convolution driver. Work is flattened over
// (mb, groups, oc chunks, od, oh, ow blocks). Depth (front/back) padding is
// resolved once per od position via kd_padding; height padding per row via
// kh_padding, mirroring the 2D driver.
template <data_type_t src_type, data_type_t wei_type,
          data_type_t dst_type>
void jit_avx512_common_convolution_fwd_t
    <src_type, wei_type, dst_type>::execute_forward_3d() const
{
    auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
    auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
    auto bias = reinterpret_cast<const dst_data_t *>(this->input_memory(2));
    auto dst = reinterpret_cast<dst_data_t *>(this->memory());

    // May redirect `bias` to a zero-padded scratchpad copy.
    prepare_padded_bias(bias);

    const memory_desc_wrapper src_d(pd()->src_pd());
    const memory_desc_wrapper dst_d(pd()->dst_pd());
    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
    const memory_desc_wrapper bias_d(pd()->weights_pd(1));

    const auto &jcp = pd()->jcp_;
    const int MB = pd()->MB();
    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);

    parallel(0, [&](const int ithr, const int nthr) {
        int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
        int start{0}, end{0}, start_copy;
        int work_amount = MB * jcp.ngroups * oc_chunks * jcp.od * jcp.oh
            * jcp.nb_ow;
        balance211(work_amount, nthr, ithr, start, end);
        start_copy = start;

        auto par_conv = jit_conv_call_s();
        size_t src_d_stride = src_d.blk_off(0, 0, 1);
        size_t src_h_stride = src_d.blk_off(0, 0, 0, 1);
        size_t src_c_stride = src_d.blk_off(0, 1);
        size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1);
        size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);
        size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);

        // L2 blocking over input channels; re-walk [start, end) per slice.
        for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) {
            start = start_copy;
            int n{0}, g{0}, occ{0}, oh_s{0}, od_s{0}, owb{0};

            if (jcp.loop_order == loop_cwgn)
                nd_iterator_init(start,
                    occ, oc_chunks, owb, jcp.nb_ow, g, jcp.ngroups, n, MB,
                    od_s, jcp.od, oh_s, jcp.oh);
            else if (jcp.loop_order == loop_gncw)
                nd_iterator_init(start,
                    g, jcp.ngroups, n, MB, occ, oc_chunks, owb, jcp.nb_ow,
                    od_s, jcp.od, oh_s, jcp.oh);
            else
                assert(!"unsupported loop order");

            while (start < end) {
                int ocb = occ * jcp.nb_oc_blocking;
                int g_ocb = g * jcp.nb_oc + ocb;
                int g_oc = g_ocb * jcp.oc_block;
                int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;

                int work_rem = end - start;
                int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
                int ow_s =  owb * jcp.ow_block;
                int iw_s =  ow_s * jcp.stride_w;
                // Clamp the row range to the rows still owned by this thread.
                int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;

                // Input depth position (may be negative: front padding).
                int id_s = -jcp.f_pad + od_s * jcp.stride_d;

                // Depth slices of the (dilated) filter falling outside the
                // input; shrink the effective kernel depth accordingly.
                int dilate_d = jcp.dilate_d + 1;
                int d_t_overflow = div_up(max(0, -id_s), dilate_d);
                int d_b_overflow = div_up(
                        max(0, id_s - jcp.id + (jcp.kd - 1) * dilate_d + 1),
                        dilate_d);
                int kd_padding = nstl::max(0,
                    jcp.kd - d_t_overflow - d_b_overflow);

                auto bias_w = bias ? bias + bias_d.blk_off(g_oc) : 0;
                auto dst_w = dst + dst_d.blk_off(n, g_ocb, od_s, oh_s, ow_s);
                // Skip the front-clipped depth slices of src and weights.
                auto src_w = src + src_d.blk_off(n, g_icb + icb_l2, id_s, ih_s,
                    iw_s) + d_t_overflow * dilate_d * src_d_stride;
                auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2)
                    + d_t_overflow * wht_d_stride;

                for (int icb = icb_l2;
                     icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); ++icb) {
                    auto src_c = src_w;
                    auto dst_c = dst_w;
                    for (int oj = oh_s, ij = ih_s;
                            oj < oh_e; ++oj, ij += jcp.stride_h)
                    {
                        // Per-row height clipping, same scheme as the depth
                        // handling above.
                        int dilate_h = jcp.dilate_h + 1;
                        int i_t_overflow = div_up(max(0, -ij), dilate_h);
                        int i_b_overflow = div_up(
                                max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h
                                                + 1),
                                dilate_h);
                        int kh_padding = nstl::max(0,
                            jcp.kh - i_t_overflow - i_b_overflow);

                        int oc_off = g_oc * sizeof(dst_data_t);

                        jit_conv_3d_ker_pipeline_ow_thr(kernel_->jit_ker,
                            par_conv,
                            src_c + i_t_overflow * dilate_h * src_h_stride,
                            dst_c, wht_w + i_t_overflow * wht_h_stride,
                            bias_w, icb, kh_padding, kd_padding, owb, oc_off);

                        src_c += src_h_stride * jcp.stride_h;
                        dst_c += dst_h_stride;
                    }
                    src_w += src_c_stride;
                    wht_w += wht_ic_stride;
                }

                if (jcp.loop_order == loop_cwgn)
                    nd_iterator_jump(start, end,
                      occ, oc_chunks, owb, jcp.nb_ow, g, jcp.ngroups, n, MB,
                      od_s, jcp.od, oh_s, jcp.oh);
                else if (jcp.loop_order == loop_gncw)
                    nd_iterator_jump(start, end,
                      g, jcp.ngroups, n, MB, occ, oc_chunks, owb, jcp.nb_ow,
                      od_s, jcp.od, oh_s, jcp.oh);
                else
                    assert(!"unsupported loop order");
            }
        }
        // Drain the software pipeline (executes the last stashed job).
        jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv,
                src, dst, weights, bias, 0, 0, 0, 0);
    });
}
496
497 template struct jit_avx512_common_convolution_fwd_t<data_type::f32>;
498 template struct jit_avx512_common_convolution_fwd_t<data_type::s16,
499         data_type::s16, data_type::s32>;
500
// Backward-data 1D driver: computes diff_src from diff_dst and weights.
// Work is flattened over (groups, mb, ic chunks); output channels are
// L2-blocked and accumulated via the pipelined kernel calls.
template <data_type_t diff_dst_type, data_type_t wei_type,
          data_type_t diff_src_type>
void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
          diff_src_type>::execute_backward_data_1d() const {
    auto diff_dst = reinterpret_cast<const diff_dst_data_t *>
                                                       (this->input_memory(0));
    auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
    auto diff_src = reinterpret_cast<diff_src_data_t*>(this->memory());

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
    const memory_desc_wrapper weights_d(pd()->weights_pd(0));

    const auto &jcp = kernel_->jcp;

    parallel(0, [&](const int ithr, const int nthr) {
        int start{0}, end{0}, start_copy;
        int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
        int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih;
        balance211(work_amount, nthr, ithr, start, end);
        start_copy = start;

        auto par_conv = jit_conv_call_s();
        size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
        size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);

        // L2 blocking over output channels; re-walk [start, end) per slice.
        for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) {
            start = start_copy;
            int n{0}, g{0}, icc{0};
            if (jcp.loop_order == loop_cgn) {
                int dummy{0};
                nd_iterator_init(start, icc, ic_chunks, g, jcp.ngroups, n,
                        jcp.mb, dummy, 1);
            } else if (jcp.loop_order == loop_gnc) {
                int dummy{0};
                nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, icc,
                        ic_chunks, dummy, 1);
            } else {
                assert(!"unsupported loop order");
            }

            while (start < end) {
                int icb = icc * jcp.nb_ic_blocking;
                int g_icb = g * jcp.nb_ic + icb;
                int g_ocb = g * jcp.nb_oc;

                auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb);
                auto diff_dst_w = diff_dst
                    + diff_dst_d.blk_off(n, g_ocb + ocb_l2);
                auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb);

                // Accumulate over the output-channel blocks of this slice.
                for (int ocb = ocb_l2;
                      ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) {
                    jit_conv_ker_pipeline(kernel_->jit_ker, par_conv,
                            diff_src_w, diff_dst_w, wht_w, 0, ocb, 1, 0);
                    diff_dst_w += diff_dst_c_stride;
                    wht_w += wht_oc_stride;
                }

                if (jcp.loop_order == loop_cgn) {
                    int dummy{0};
                    nd_iterator_jump(start, end, icc, ic_chunks, g, jcp.ngroups,
                            n, jcp.mb, dummy, 1);
                } else if (jcp.loop_order == loop_gnc) {
                    int dummy{0};
                    nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb, icc,
                            ic_chunks, dummy, 1);
                } else {
                    assert(!"unsupported loop order");
                }
            }
        }

        // Drain the software pipeline (executes the last stashed job).
        jit_conv_ker_pipeline(kernel_->jit_ker, par_conv,
                diff_src, diff_dst, weights, 0, 0, 1, 0);
    });
}
578
// Backward-data 2D driver: computes diff_src rows in parallel over
// (groups, mb, ic chunks, ih). For each input row it derives which kernel
// rows (k_lo .. k_lo + k_len) and which diff_dst row (oj) contribute,
// with three cases depending on stride/dilation.
template <data_type_t diff_dst_type, data_type_t wei_type,
          data_type_t diff_src_type>
void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
          diff_src_type>::execute_backward_data_2d() const {
    auto diff_dst = reinterpret_cast<const diff_dst_data_t *>
                                                       (this->input_memory(0));
    auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
    auto diff_src = reinterpret_cast<diff_src_data_t*>(this->memory());

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
    const memory_desc_wrapper weights_d(pd()->weights_pd(0));

    const auto &jcp = kernel_->jcp;
    const int MB = pd()->MB();

    parallel(0, [&](const int ithr, const int nthr) {
        int start{0}, end{0}, start_copy;
        int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
        int work_amount = jcp.ngroups * MB * ic_chunks * jcp.ih;
        balance211(work_amount, nthr, ithr, start, end);
        start_copy = start;

        auto par_conv = jit_conv_call_s();
        size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 1);
        size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 1);
        size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
        size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);

        // Unit stride + no dilation admits the simplest overlap formulas.
        bool is_fast_path = jcp.dilate_h == 0 && jcp.stride_h == 1;

        // L2 blocking over output channels; re-walk [start, end) per slice.
        for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) {
            start = start_copy;
            int n{0}, g{0}, icc{0}, ih_s{0};
            if (jcp.loop_order == loop_cgn)
                nd_iterator_init(start,
                    icc, ic_chunks, g, jcp.ngroups, n, MB, ih_s, jcp.ih);
            else if (jcp.loop_order == loop_gnc)
                nd_iterator_init(start,
                    g, jcp.ngroups, n, MB, icc, ic_chunks, ih_s, jcp.ih);
            else
                assert(!"unsupported loop order");

            while (start < end) {
                int icb = icc * jcp.nb_ic_blocking;
                int g_icb = g * jcp.nb_ic + icb;
                int g_ocb = g * jcp.nb_oc;

                int work_rem = end - start;
                // Clamp the row range to the rows owned by this thread.
                int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem;

                auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb);
                auto diff_dst_w = diff_dst
                    + diff_dst_d.blk_off(n, g_ocb + ocb_l2);
                auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb);

                for (int ocb = ocb_l2;
                      ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) {
                    for (int ij = ih_s; ij < ih_e; ++ij) {
                        // oj: first contributing diff_dst row;
                        // k_lo/k_len: first kernel row and row count to apply.
                        int oj, k_len, k_lo;
                        if (is_fast_path) { // dilate == 0 && stride == 1
                            int i_t_overflow = max(0, jcp.kh - 1 - ij
                                - jcp.t_pad);
                            int i_b_overflow = max(0, jcp.kh - jcp.ih + ij
                                - jcp.b_pad);
                            k_len = jcp.kh - i_t_overflow - i_b_overflow;
                            k_lo = i_b_overflow;
                            oj = ij + jcp.t_pad - i_b_overflow;
                        } else if (jcp.dilate_h != 0) { // stride == 1
                            int dilate_h = jcp.dilate_h + 1;
                            // Note: use div_up to account for "holes" in filter
                            int i_t_overflow
                                = div_up(max(0, (jcp.kh - 1) * dilate_h
                                        - ij - jcp.t_pad), dilate_h);
                            int i_b_overflow
                                = div_up(max(0, (jcp.kh - 1) * dilate_h + 1
                                        - jcp.ih + ij - jcp.b_pad), dilate_h);
                            k_len = jcp.kh - i_t_overflow - i_b_overflow;
                            k_lo = i_b_overflow;
                            oj = ij + jcp.t_pad - i_b_overflow * dilate_h;
                        } else { // dilate == 0
                            // Strided case: only every stride_h-th kernel row
                            // lines up with this input row.
                            int i_t_overflow = max(0, (jcp.kh - 1 - ij
                                - jcp.t_pad) / jcp.stride_h);
                            int i_b_overflow = max(0, (jcp.kh - jcp.ih + ij
                                - jcp.b_pad) / jcp.stride_h);
                            int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1
                                + jcp.b_pad - ij) % jcp.stride_h);
                            int overflow_kh_lo = (ij + jcp.t_pad)
                                % jcp.stride_h;

                            k_len = (overflow_kh_hi - overflow_kh_lo)
                                / jcp.stride_h + 1 - i_t_overflow
                                - i_b_overflow;
                            k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h;
                            oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h;
                        }
                        assert(k_len >= 0);

                        jit_conv_ker_pipeline(kernel_->jit_ker, par_conv,
                                diff_src_w + ij * diff_src_h_stride,
                                diff_dst_w + oj * diff_dst_h_stride,
                                wht_w + k_lo * wht_h_stride,
                                0, ocb, k_len, 0);
                    }
                    diff_dst_w += diff_dst_c_stride;
                    wht_w += wht_oc_stride;
                }

                if (jcp.loop_order == loop_cgn)
                    nd_iterator_jump(start, end,
                      icc, ic_chunks, g, jcp.ngroups, n, MB, ih_s, jcp.ih);
                else if (jcp.loop_order == loop_gnc)
                    nd_iterator_jump(start, end,
                      g, jcp.ngroups, n, MB, icc, ic_chunks, ih_s, jcp.ih);
                else
                    assert(!"unsupported loop order");
            }
        }

        // Drain the software pipeline (executes the last stashed job).
        jit_conv_ker_pipeline(kernel_->jit_ker, par_conv,
                diff_src, diff_dst, weights, 0, 0, 1, 0);
    });
}
703
template <data_type_t diff_dst_type, data_type_t wei_type,
          data_type_t diff_src_type>
/* Backward-by-data driver for 3D convolutions: computes diff_src from
 * diff_dst and weights by invoking the JIT kernel once per
 * (ic-chunk, group, minibatch, id, ih) work item. Kernel invocations go
 * through jit_conv_3d_ker_pipeline, which (like jit_conv_ker_pipeline /
 * the PIPELINE macro above) executes the *previously* queued argument set
 * on each call, so a final dummy call is needed to flush the last real
 * piece of work. */
void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
          diff_src_type>::execute_backward_data_3d() const {
    auto diff_dst = reinterpret_cast<const diff_dst_data_t *>
                                                       (this->input_memory(0));
    auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
    auto diff_src = reinterpret_cast<diff_src_data_t*>(this->memory());

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
    const memory_desc_wrapper weights_d(pd()->weights_pd(0));

    const auto &jcp = kernel_->jcp;
    const int MB = pd()->MB();

    parallel(0, [&](const int ithr, const int nthr) {
        /* each thread gets a contiguous [start, end) slice of the flattened
         * (ic-chunk, group, mb, id, ih) iteration space */
        int start{0}, end{0}, start_copy;
        int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
        int work_amount = jcp.ngroups * MB * ic_chunks * jcp.id * jcp.ih;
        balance211(work_amount, nthr, ithr, start, end);
        start_copy = start;

        auto par_conv = jit_conv_call_s();
        /* element strides derived from the blocked layouts: distance between
         * consecutive h rows / d slices / oc-block channels */
        size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 0, 1);
        size_t diff_src_d_stride = diff_src_d.blk_off(0, 0, 1);
        size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 0, 1);
        size_t diff_dst_d_stride = diff_dst_d.blk_off(0, 0, 1);
        size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);
        size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
        size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);

        /* unit stride + no dilation allows the simplest overlap formulas */
        bool is_fast_path_d = jcp.dilate_d == 0 && jcp.stride_d == 1;
        bool is_fast_path_h = jcp.dilate_h == 0 && jcp.stride_h == 1;

        /* outer blocking over oc so the weight slice stays L2-resident */
        for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) {
            start = start_copy;
            int n{0}, g{0}, icc{0}, ih_s{0}, id_s{0};
            if (jcp.loop_order == loop_cgn)
                nd_iterator_init(start,
                    icc, ic_chunks, g, jcp.ngroups, n, MB, id_s, jcp.id,
                    ih_s, jcp.ih);
            else if (jcp.loop_order == loop_gnc)
                nd_iterator_init(start,
                    g, jcp.ngroups, n, MB, icc, ic_chunks, id_s, jcp.id,
                    ih_s, jcp.ih);
            else
                assert(!"unsupported loop order");

            while (start < end) {
                int icb = icc * jcp.nb_ic_blocking;
                int g_icb = g * jcp.nb_ic + icb;
                int g_ocb = g * jcp.nb_oc;

                int work_rem = end - start;
                int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem;
                /* d_len: number of kd taps that actually hit the image at
                 * input depth id_s; d_lo: first contributing kd index;
                 * d_oj: matching output-depth coordinate */
                int d_len = 0, d_lo = 0, d_oj = 0;
                if (is_fast_path_d) { // dilate == 0 && stride == 1
                    int d_t_overflow = max(0, jcp.kd - 1 - id_s
                            - jcp.f_pad);
                    int d_b_overflow = max(0, jcp.kd - jcp.id + id_s
                            - jcp.back_pad);
                    d_len = jcp.kd - d_t_overflow - d_b_overflow;
                    d_lo = d_b_overflow;
                    d_oj = id_s + jcp.f_pad - d_b_overflow;
                } else if (jcp.dilate_d != 0) { // stride == 1
                    int dilate_d = jcp.dilate_d + 1;
                    // Note: use div_up to account for "holes" in filter
                    int d_t_overflow = div_up(max(0, (jcp.kd - 1) * dilate_d
                                - id_s - jcp.f_pad), dilate_d);
                    int d_b_overflow = div_up(max(0, (jcp.kd - 1) * dilate_d + 1
                                - jcp.id + id_s - jcp.back_pad), dilate_d);
                    d_len = jcp.kd - d_t_overflow - d_b_overflow;
                    d_lo = d_b_overflow;
                    d_oj = id_s + jcp.f_pad - d_b_overflow * dilate_d;
                } else { // dilate == 0
                    /* strided case: only every stride_d-th kd tap aligns with
                     * this input depth, hence the modulo arithmetic */
                    int d_t_overflow = max(0, (jcp.kd - 1 - id_s
                                - jcp.f_pad) / jcp.stride_d);
                    int d_b_overflow = max(0, (jcp.kd - jcp.id + id_s
                                - jcp.back_pad) / jcp.stride_d);
                    int overflow_kd_hi = jcp.kd - 1 - abs((jcp.id - 1
                                + jcp.back_pad - id_s) % jcp.stride_d);
                    int overflow_kd_lo = (id_s + jcp.f_pad)
                        % jcp.stride_d;

                    d_len = (overflow_kd_hi - overflow_kd_lo)
                        / jcp.stride_d + 1 - d_t_overflow
                        - d_b_overflow;
                    d_lo = overflow_kd_lo + d_b_overflow * jcp.stride_d;
                    d_oj = (id_s + jcp.f_pad - d_lo) / jcp.stride_d;
                }
                assert(d_len >= 0);

                auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb)
                    + id_s * diff_src_d_stride;
                auto diff_dst_w = diff_dst
                    + diff_dst_d.blk_off(n, g_ocb + ocb_l2)
                    + d_oj * diff_dst_d_stride;
                auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb)
                    + d_lo * wht_d_stride;

                for (int ocb = ocb_l2;
                      ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) {
                    for (int ij = ih_s; ij < ih_e; ++ij) {
                        /* same three-way overlap computation as for depth,
                         * now along the height dimension */
                        int oj, k_len, k_lo;
                        if (is_fast_path_h) { // dilate == 0 && stride == 1
                            int i_t_overflow = max(0, jcp.kh - 1 - ij
                                - jcp.t_pad);
                            int i_b_overflow = max(0, jcp.kh - jcp.ih + ij
                                - jcp.b_pad);
                            k_len = jcp.kh - i_t_overflow - i_b_overflow;
                            k_lo = i_b_overflow;
                            oj = ij + jcp.t_pad - i_b_overflow;
                        } else if (jcp.dilate_h != 0) { // stride == 1
                            int dilate_h = jcp.dilate_h + 1;
                            // Note: use div_up to account for "holes" in filter
                            int i_t_overflow
                                = div_up(max(0, (jcp.kh - 1) * dilate_h
                                        - ij - jcp.t_pad), dilate_h);
                            int i_b_overflow
                                = div_up(max(0, (jcp.kh - 1) * dilate_h + 1
                                        - jcp.ih + ij - jcp.b_pad), dilate_h);
                            k_len = jcp.kh - i_t_overflow - i_b_overflow;
                            k_lo = i_b_overflow;
                            oj = ij + jcp.t_pad - i_b_overflow * dilate_h;
                        } else { // dilate == 0
                            int i_t_overflow = max(0, (jcp.kh - 1 - ij
                                - jcp.t_pad) / jcp.stride_h);
                            int i_b_overflow = max(0, (jcp.kh - jcp.ih + ij
                                - jcp.b_pad) / jcp.stride_h);
                            int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1
                                + jcp.b_pad - ij) % jcp.stride_h);
                            int overflow_kh_lo = (ij + jcp.t_pad)
                                % jcp.stride_h;

                            k_len = (overflow_kh_hi - overflow_kh_lo)
                                / jcp.stride_h + 1 - i_t_overflow
                                - i_b_overflow;
                            k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h;
                            oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h;
                        }
                        assert(k_len >= 0);

                        /* queue one kernel invocation per diff_src row */
                        jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv,
                                diff_src_w + ij * diff_src_h_stride,
                                diff_dst_w + oj * diff_dst_h_stride,
                                wht_w + k_lo * wht_h_stride,
                                0, ocb, k_len, d_len, 0);
                    }
                    diff_dst_w += diff_dst_c_stride;
                    wht_w += wht_oc_stride;
                }

                if (jcp.loop_order == loop_cgn)
                    nd_iterator_jump(start, end,
                      icc, ic_chunks, g, jcp.ngroups, n, MB, id_s, jcp.id,
                      ih_s, jcp.ih);
                else if (jcp.loop_order == loop_gnc)
                    nd_iterator_jump(start, end,
                      g, jcp.ngroups, n, MB, icc, ic_chunks, id_s, jcp.id,
                      ih_s, jcp.ih);
                else
                    assert(!"unsupported loop order");
            }
        }

        /* dummy invocation: flushes the last queued work item out of the
         * software pipeline (its own arguments are never executed) */
        jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv,
                diff_src, diff_dst, weights, 0, 0, 1, 1, 0);
    });
}
875
876 template struct jit_avx512_common_convolution_bwd_data_t<data_type::f32>;
877 template struct jit_avx512_common_convolution_bwd_data_t<data_type::s16,
878     data_type::s16, data_type::s32>;
879
880 template <data_type_t src_type, data_type_t diff_dst_type,
881           data_type_t diff_weights_type>
882 jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
883           diff_weights_type>::
884 jit_avx512_common_convolution_bwd_weights_t(const pd_t *apd,
885         const input_vector &inputs, const output_vector &outputs)
886     : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr)
887     , trans_kernel_(nullptr), trans_dst_kernel_(nullptr), acc_ker_(nullptr)
888     , reducer_bias_(nullptr)
889 {
890     const auto &j = pd()->jcp_;
891
892     nthr_ = j.nthr;
893     nthr_mb_ = j.nthr_mb;
894     nthr_g_ = j.nthr_g;
895     nthr_oc_b_ = j.nthr_oc_b;
896     nthr_ic_b_ = j.nthr_ic_b;
897
898     kernel_ = new jit_avx512_common_conv_bwd_weights_kernel_f32(j);
899
900     if (utils::one_of(j.ver, ver_4fma, ver_4vnni, ver_vnni)) {
901         trans_kernel_ = create_trans_src(&j);
902         if (utils::one_of(j.ver, ver_4vnni, ver_vnni))
903             trans_dst_kernel_ = create_trans_dst(&j);
904     }
905
906     if (nthr_mb_ > 1)
907         acc_ker_ = new cpu_accumulator_1d_t<diff_weights_type>();
908
909     reducer_bias_ =
910         new cpu_reducer_t<diff_weights_type>(pd()->reducer_bia_conf_);
911 }
912
template <data_type_t src_type, data_type_t diff_dst_type,
          data_type_t diff_weights_type>
/* Per-thread view of the backward-weights problem: raw in/out pointers,
 * scratchpad buffers, the thread's coordinates in the 4D thread grid
 * (mb, g, oc_b, ic_b) and the sub-ranges of each dimension it owns. */
struct jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
    diff_weights_type>::thread_info_t {
    const src_data_t *src;
    const diff_dst_data_t *diff_dst;
    const diff_weights_data_t *diff_weights;
    diff_weights_data_t *diff_bias;

    const memory_tracking::grantor_t scratchpad;

    /* transposed copy of src (4fma/vnni paths) and its barrier contexts */
    src_data_t *tr_src;
    simple_barrier::ctx_t *tr_src_bctx;

    /* transposed copy of diff_dst (vnni paths) and its barrier contexts */
    diff_dst_data_t *tr_diff_dst;
    simple_barrier::ctx_t *tr_diff_dst_bctx;

    /* per-thread partial weight/bias gradients, reduced after the compute
     * phase, plus the barrier used to synchronize that reduction */
    diff_weights_data_t *wei_bia_reduction;
    simple_barrier::ctx_t *wei_bia_reduction_bctx;

    int ithr;
    int ithr_ic_b, ithr_oc_b, ithr_g, ithr_mb;
    int ithr_but_oc; // thread id with the oc_b coordinate projected out
    int ithr_but_ic; // thread id with the ic_b coordinate projected out

    int img_start = 0, img_end = 0, img_work;
    int g_start = 0, g_end = 0, g_work;
    int oc_b_start = 0, oc_b_end = 0, oc_b_work;
    int ic_b_start = 0, ic_b_end = 0, ic_b_work;

    thread_info_t(const jit_avx512_common_convolution_bwd_weights_t *self,
            int ithr): scratchpad(self->scratchpad()), ithr(ithr) {
        src = reinterpret_cast<const src_data_t *>(self->input_memory(0));
        diff_dst = reinterpret_cast<const diff_dst_data_t *>(
            self->input_memory(1));
        diff_weights = reinterpret_cast<diff_weights_data_t *>(self->memory(0));
        /* padded bias is accumulated in scratchpad and copied out later */
        diff_bias = self->pd()->wants_padded_bias()
            ? scratchpad.template get<diff_weights_data_t>(
                    key_conv_padded_bias)
            : reinterpret_cast<diff_weights_data_t *>(self->memory(1));

        tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);
        tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
                key_conv_tr_src_bctx);

        tr_diff_dst = scratchpad.template get<diff_dst_data_t>(
                key_conv_tr_diff_dst);
        tr_diff_dst_bctx = scratchpad.template get<simple_barrier::ctx_t>(
                key_conv_tr_diff_dst_bctx);

        wei_bia_reduction = scratchpad.template get<diff_weights_data_t>(
                key_conv_wei_bia_reduction);
        wei_bia_reduction_bctx = scratchpad.template get<simple_barrier::ctx_t>(
                key_conv_wei_bia_reduction_bctx);

        /* decompose the flat thread id into (mb, g, oc_b, ic_b) coordinates;
         * ic_b varies fastest, mb slowest */
        ithr_ic_b = ithr % self->nthr_ic_b_;
        ithr_oc_b = ithr / self->nthr_ic_b_ % self->nthr_oc_b_;
        ithr_g = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ % self->nthr_g_;
        ithr_mb = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ / self->nthr_g_;

        ithr_but_oc = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_ic_b_
            + ithr_ic_b;

        ithr_but_ic = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_oc_b_
            + ithr_oc_b;

        const auto &jcp = self->kernel_->jcp;

        /* reduction dimension */
        balance211(jcp.mb*jcp.od, self->nthr_mb_, ithr_mb, img_start, img_end);
        img_work = img_end - img_start;

        /* independent dimensions */
        balance211(jcp.ngroups, self->nthr_g_, ithr_g, g_start, g_end);
        g_work = g_end - g_start;

        balance211(jcp.nb_oc, self->nthr_oc_b_, ithr_oc_b, oc_b_start,
                oc_b_end);
        oc_b_work = oc_b_end - oc_b_start;

        balance211(jcp.nb_ic, self->nthr_ic_b_, ithr_ic_b, ic_b_start,
                ic_b_end);
        ic_b_work = ic_b_end - ic_b_start;
    }
};
998
999 template <data_type_t src_type, data_type_t diff_dst_type,
1000           data_type_t diff_weights_type>
1001 void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
1002     diff_weights_type>::compute_diff_weights(const thread_info_t *ti) const {
1003     const memory_desc_wrapper src_d(pd()->src_pd(0));
1004     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
1005     const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
1006
1007     const auto &jcp = kernel_->jcp;
1008     const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh*jcp.kw*jcp.kd;
1009
1010     diff_weights_data_t *diff_wei = ti->ithr_mb == 0
1011         ? (diff_weights_data_t*)ti->diff_weights
1012         : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;
1013     diff_weights_data_t *diff_bia = ti->ithr_mb == 0
1014         ? (diff_weights_data_t*)ti->diff_bias
1015         : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size
1016           + (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc;
1017
1018     // TODO: use memory descriptor with the same fmt as src (or use a macro :))
1019     auto tr_src_off = [&](int ithr_mb, int ic, int ij) {
1020         const size_t tr_row_size = jcp.tr_iw * jcp.ic_block;
1021         const size_t tr_chn_size = tr_row_size * jcp.ih;
1022         const size_t tr_img_size = tr_chn_size * jcp.nb_ic * jcp.ngroups;
1023
1024         return ti->ithr_mb * tr_img_size + ic * tr_chn_size + ij * tr_row_size;
1025     };
1026
1027     auto uker_trans = [&](int img) {
1028         const int work_amount = ti->g_work * ti->ic_b_work * jcp.ih;
1029
1030         int start{0}, end{0};
1031         balance211(work_amount, nthr_oc_b_, ti->ithr_oc_b, start, end);
1032         const int my_work = end - start;
1033
1034         int g{0}, ic_b{0}, j{0};
1035         nd_iterator_init(start, g, ti->g_work, ic_b, ti->ic_b_work, j, jcp.ih);
1036         g += ti->g_start;
1037         ic_b += ti->ic_b_start;
1038
1039         const int _ic = g * jcp.nb_ic + ic_b;
1040         src_data_t *src1 = (src_data_t*)&ti->src[src_d.blk_off(img, _ic, j)];
1041         src_data_t *tr_src1 = &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, j)];
1042
1043         assert(jcp.ic_block == 16);
1044         const int src_stride = jcp.iw * jcp.ic_block;
1045         const int tr_src_stride = jcp.tr_iw * jcp.ic_block;
1046
1047         const int pf_depth = 2;
1048         struct { src_data_t *src, *tr_src; } pf_circ_buf[pf_depth];
1049
1050         for (int iwork = 0; iwork < my_work + pf_depth - 1; iwork++) {
1051             pf_circ_buf[iwork % pf_depth] = {src1, tr_src1};
1052
1053             if (iwork >= pf_depth - 1) {
1054                 int old_idx = (iwork - pf_depth + 1) % pf_depth;
1055                 auto ctx = jit_trans_src_t::ctx_t();
1056                 ctx.src = pf_circ_buf[old_idx].src;
1057                 ctx.tr_src = pf_circ_buf[old_idx].tr_src;
1058                 ctx.src_prf = src1;
1059                 ctx.tr_src_prf = tr_src1;
1060                 (*trans_kernel_)(&ctx);
1061             }
1062             src1 += src_stride;
1063             tr_src1 += tr_src_stride;
1064         }
1065 #if 0
1066         // reference transposition
1067         const int l_pad = jcp.l_pad;
1068         const int iwlp = l_pad + jcp.iw;
1069         const int tr_iw = jcp.tr_iw;
1070
1071         for (size_t iwork = start; iwork < end; iwork++) {
1072             PRAGMA_OMP_SIMD()
1073 #           pragma unroll
1074             for (int i = 0; i < l_pad; i++)
1075                 for (int j = 0; j < jcp.ic_block; j++)
1076                     tr_src1[j * jcp.tr_iw + i] = (src_data_t)0.0;
1077
1078             PRAGMA_OMP_SIMD()
1079 #           pragma unroll
1080             for (int i = l_pad; i < iwlp; i++)
1081                 for (int j = 0; j < jcp.ic_block; j++)
1082                     tr_src1[j * jcp.tr_iw + i]
1083                         = (src_data_t)src1[(i - l_pad) * 16 + j];
1084
1085             PRAGMA_OMP_SIMD()
1086 #           pragma unroll
1087             for (int i = iwlp; i < tr_iw; i++)
1088                 for (int j = 0; j < jcp.ic_block; j++)
1089                     tr_src1[j * jcp.tr_iw + i] = (src_data_t)0.0;
1090
1091              src1 += src_stride;
1092              tr_src1 += tr_src_stride;
1093          }
1094 #endif
1095     };
1096
1097     auto tr_diff_dst_off = [&](int ithr_mb, int oc, int oj) {
1098         const size_t tr_row_size = jcp.tr_ow * jcp.oc_block;
1099         const size_t tr_chn_size = tr_row_size * jcp.oh;
1100         const size_t tr_img_size = tr_chn_size * jcp.nb_oc * jcp.ngroups;
1101         return ti->ithr_mb * tr_img_size + oc * tr_chn_size + oj * tr_row_size;
1102     };
1103
1104     auto diff_dst_trans = [&](int img) {
1105         const size_t work_amount = ti->g_work * ti->oc_b_work * jcp.oh;
1106
1107         size_t start{0}, end{0};
1108         balance211(work_amount, nthr_ic_b_, ti->ithr_ic_b, start, end);
1109         const int my_work = end - start;
1110
1111         int g{0}, oc_b{0}, j{0};
1112         nd_iterator_init(start, g, ti->g_work, oc_b, ti->oc_b_work, j, jcp.oh);
1113         g += ti->g_start;
1114         oc_b += ti->oc_b_start;
1115         const int oc = g * jcp.nb_oc + oc_b;
1116         const diff_dst_data_t *diff_dst1
1117             = &ti->diff_dst[diff_dst_d.blk_off(img, oc, j)];
1118         diff_dst_data_t *tr_diff_dst1
1119             = &ti->tr_diff_dst[tr_diff_dst_off(img, oc, j)];
1120
1121
1122         assert(jcp.ic_block == 16);
1123         const int diff_dst_stride = jcp.ow * jcp.oc_block;
1124         const int tr_diff_dst_stride = jcp.tr_ow * jcp.oc_block;
1125
1126         const int pf_depth = 2;
1127         struct { diff_dst_data_t *diff_dst, *tr_diff_dst; }
1128             pf_circ_buf[pf_depth];
1129
1130         for (int iwork = 0; iwork < my_work + pf_depth - 1; iwork++) {
1131             pf_circ_buf[iwork % pf_depth]
1132                 = {(diff_dst_data_t*)diff_dst1, tr_diff_dst1};
1133
1134             if (iwork >= pf_depth - 1) {
1135                 int old_idx = (iwork - pf_depth + 1) % pf_depth;
1136                 auto ctx = jit_trans_dst_t::ctx_t();
1137                 ctx.src = pf_circ_buf[old_idx].diff_dst;
1138                 ctx.tr_src = pf_circ_buf[old_idx].tr_diff_dst;
1139                 ctx.src_prf = diff_dst1;
1140                 ctx.tr_src_prf = tr_diff_dst1;
1141                 (*trans_dst_kernel_)(&ctx);
1142             }
1143             diff_dst1 += diff_dst_stride;
1144             tr_diff_dst1 += tr_diff_dst_stride;
1145         }
1146 #if 0
1147         // reference transposition
1148         int r_pad = jcp.ow % 2;
1149         for(size_t work = start; work < end; ++work) {
1150
1151             for (int j = 0; j < jcp.oc_block; ++j) {
1152 #               pragma unroll
1153                 for (int i = 0; i < jcp.ow / 2; i++) {
1154                     tr_diff_dst1[i*jcp.oc_block*2 + j*2] =
1155                        diff_dst1[2*i*jcp.oc_block + j];
1156                     tr_diff_dst1[i*jcp.oc_block*2 + j*2 + 1] =
1157                        diff_dst1[(2*i+1)*jcp.oc_block + j];
1158                 }
1159                 if (r_pad != 0) {
1160                     const int last_w = jcp.ow / 2;
1161                     tr_diff_dst1[last_w * jcp.oc_block * 2 + j * 2] =
1162                        diff_dst1[last_w * jcp.oc_block * 2 + j];
1163                     tr_diff_dst1[last_w * jcp.oc_block * 2 + j * 2 + 1] =
1164                         diff_dst_data_t{0};
1165                 }
1166
1167             }
1168
1169             diff_dst1 += diff_dst_stride;
1170             tr_diff_dst1 += tr_diff_dst_stride;
1171         }
1172 #endif
1173     };
1174
1175     if (jcp.is_1stconv && jcp.ver == ver_4fma) {
1176         /* prepare contexts */
1177         auto tr_ctx = jit_trans_src_t::ctx_t();
1178         tr_ctx.tr_src = ti->tr_src
1179             + ti->ithr_but_oc * jcp.ih * jcp.stride_w * jcp.tr_ld;
1180
1181         assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_oc_b_ == 1));
1182         tr_ctx.nthr_oc_b = nthr_oc_b_;
1183         int ih_start{0}, ih_end{0};
1184         balance211(jcp.ih, nthr_oc_b_, ti->ithr_oc_b, ih_start, ih_end);
1185         tr_ctx.tr_src_ih_start = ih_start;
1186         tr_ctx.tr_src_ih_end = ih_end;
1187         tr_ctx.tr_src_bctx = ti->tr_src_bctx + ti->ithr_but_oc;
1188
1189         auto p = jit_conv_call_s();
1190         p.src = tr_ctx.tr_src;
1191
1192         /* zero diff_bias if applicable */
1193         if (jcp.with_bias && ti->ithr_ic_b == 0) {
1194             assert(jcp.oc_block == 16);
1195             for (int oc_b = ti->ic_b_start; oc_b < ti->oc_b_end; ++oc_b) {
1196                 diff_weights_data_t *db = &diff_bia[oc_b * 16];
1197                 for (int o = 0; o < 16; ++o)
1198                     db[o] = 0;
1199             }
1200         }
1201
1202         for (int img = ti->img_start; img < ti->img_end; ++img) {
1203             p.flags = (img == ti->img_start) * FLAG_MB_FIRST;
1204
1205             for (int g = ti->g_start; g < ti->g_end; ++g) {
1206             for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) {
1207                 const int _ic = g * jcp.nb_ic + ic_b;
1208                 tr_ctx.src = &ti->src[src_d.blk_off(img, _ic)];
1209
1210                 (*trans_kernel_)(&tr_ctx);
1211
1212                 if (ic_b == 0)
1213                     p.flags |= FLAG_IC_FIRST;
1214                 else
1215                     p.flags &= ~FLAG_IC_FIRST;
1216
1217                 for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) {
1218                     const int _oc = g * jcp.nb_oc + oc_b;
1219                     p.dst = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)];
1220
1221                     const size_t off =
1222                         wht_blk_off(diff_weights_d, g, oc_b, ic_b);
1223                     p.filt = diff_wei + off;
1224                     p.bias = diff_bia + _oc * jcp.oc_block;
1225
1226                     kernel_->jit_ker(&p);
1227                 }
1228             }
1229             }
1230         }
1231     } else {
1232         for (int img = ti->img_start; img < ti->img_end; ++img) {
1233             auto p = jit_conv_call_s();
1234
1235             if (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)) {
1236                 /* tr_src[nb_ic][ih][16][~iw~] <- src[nb_ic][ih][iw][16] */
1237                 using simple_barrier::barrier;
1238                 if (nthr_oc_b_ > 1)
1239                     barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
1240                 uker_trans(img);
1241                 if (nthr_oc_b_ > 1)
1242                     barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
1243             }
1244
1245             if (utils::one_of(jcp.ver, ver_4vnni, ver_vnni)) {
1246                 /* tr_diff_dst[nb_oc][OW][oh][16c][2ow]
1247                  *  <- diff_dst[nb_oc][oh][ow][16c] */
1248                 if (nthr_ic_b_ > 1)
1249                     barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
1250                 diff_dst_trans(img);
1251                 if (nthr_ic_b_ > 1)
1252                     barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
1253             }
1254
1255             for (int g = ti->g_start; g < ti->g_end; ++g) {
1256             for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) {
1257             for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) {
1258                 const int _oc = g * jcp.nb_oc + oc_b;
1259                 const int _ic = g * jcp.nb_ic + ic_b;
1260
1261                 jit_conv_ker_pipeline(kernel_->jit_ker, p,
1262                          (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)
1263                          ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)]
1264                          : &ti->src[src_d.blk_off(img, _ic)]),
1265                          utils::one_of(jcp.ver, ver_4vnni, ver_vnni)
1266                          ? &ti->tr_diff_dst[tr_diff_dst_off(ti->ithr_mb, _oc, 0)]
1267                          : &ti->diff_dst[diff_dst_d.blk_off(img, _oc)],
1268                         diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b),
1269                         0, (img == ti->img_start), 0, 0);
1270
1271             }
1272             }
1273             }
1274
1275             const int _oc = ti->g_start * jcp.nb_oc + ti->oc_b_start;
1276             const int _ic = ti->g_start * jcp.nb_ic + ti->ic_b_start;
1277             jit_conv_ker_pipeline(kernel_->jit_ker, p,
1278                     (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)
1279                      ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)]
1280                      : &ti->src[src_d.blk_off(img + 1, _ic)]),
1281                     utils::one_of(jcp.ver, ver_4vnni, ver_vnni)
1282                     ? &ti->tr_diff_dst[tr_diff_dst_off(ti->ithr_mb, _oc, 0)]
1283                     : &ti->diff_dst[diff_dst_d.blk_off(img + 1, _oc)],
1284                     diff_wei + wht_blk_off(
1285                         diff_weights_d, ti->g_start,
1286                         ti->oc_b_start, ti->ic_b_start),
1287                     0, 0, 0, 0);
1288         }
1289     }
1290 }
1291
template <data_type_t src_type, data_type_t diff_dst_type,
          data_type_t diff_weights_type>
/* 3D variant of compute_diff_weights: accumulates this thread's share of
 * the weight gradients over its [img_start, img_end) slice of the flattened
 * (mb, od) reduction space, walking output-depth ranges per image. Kernel
 * invocations go through jit_conv_3d_ker_bwd_w_pipeline, which executes the
 * previously queued argument set, so a trailing dummy call flushes the
 * pipeline at the end of each (img, od) range. */
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
    diff_weights_type>::compute_diff_weights_3d(const thread_info_t *ti) const
{
    const memory_desc_wrapper src_d(pd()->src_pd(0));
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));

    const auto &jcp = kernel_->jcp;
    const int wei_size
            = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw * jcp.kd;

    /* mb-thread 0 writes straight into user memory; other mb-threads write
     * into their own scratchpad slice, reduced later */
    diff_weights_data_t *diff_wei = ti->ithr_mb == 0
        ? (diff_weights_data_t*)ti->diff_weights
        : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;
    diff_weights_data_t *diff_bia = ti->ithr_mb == 0
        ? (diff_weights_data_t*)ti->diff_bias
        : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size
          + (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc;

    /* element steps for advancing one output-depth position (input side
     * moves by stride_d full hw-planes) */
    const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
    const int input_step = jcp.stride_d * jcp.ih * jcp.iw * inp_mult;
    const int output_step = jcp.ow * jcp.oh * jcp.oc_block;
    int img{0}, od_s{0};
    int img_start = ti->img_start, img_end = ti->img_end;
    nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od);
    const int img_first = img;

    while (img_start < img_end) {
        auto p = jit_conv_call_s();

        /* [od_s, od_e): output-depth range handled this iteration; the kd
         * front/back padding determines how many filter taps are skipped
         * and the byte offset (kd_pad_off) of the first valid kd slice */
        int work_rem = img_end - img_start;
        const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem;
        const int id_s = od_s * jcp.stride_d;
        const int ik_overlap = nstl::max(0, id_s - jcp.f_pad);
        const int kd_front_pad = nstl::max(0, jcp.f_pad - id_s);
        const int kd_back_pad = nstl::max(0, id_s + 1 + jcp.back_pad - jcp.od);
        int kd_pad_off = kd_front_pad * jcp.kh * jcp.kw * jcp.ic_block
                * jcp.oc_block * jcp.typesize_out;

        for (int g = ti->g_start; g < ti->g_end; ++g) {
        for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) {
        for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) {
            const int _oc = g * jcp.nb_oc + oc_b;
            const int _ic = g * jcp.nb_ic + ic_b;

            auto src = &ti->src[src_d.blk_off(img, _ic)
                    + ik_overlap * input_step];
            auto dst = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)
                    + od_s * output_step];

            /* note: bias pointer uses the hard-coded oc_block of 16 */
            jit_conv_3d_ker_bwd_w_pipeline(kernel_->jit_ker, p, src, dst,
                    diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b),
                    diff_bia + _oc * 16, (img == img_first), od_s, od_e,
                    jcp.kd - nstl::max(kd_front_pad, kd_back_pad), kd_pad_off);

            /* p.flags is not pipelined: setting it here applies to the NEXT
             * queued invocation (flag 0 marks the first ic block of the
             * following iteration) — presumably intentional given the
             * one-deep pipeline; confirm against the kernel's flag use */
            if (ic_b == 0) p.flags = 0;
            else p.flags = 1;
        }
        }
        }

        /* dummy invocation flushes the last queued work item */
        const int _oc = ti->g_start * jcp.nb_oc + ti->oc_b_start;
        const int _ic = ti->g_start * jcp.nb_ic + ti->ic_b_start;
        jit_conv_3d_ker_bwd_w_pipeline(kernel_->jit_ker, p,
                &ti->src[src_d.blk_off(img + 1, _ic)],
                &ti->diff_dst[diff_dst_d.blk_off(img + 1, _oc)],
                diff_wei + wht_blk_off(diff_weights_d, ti->g_start,
                    ti->oc_b_start, ti->ic_b_start),
                diff_bia, 0, 0, 0, 0, 0);
        nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od);
    }
}
1366
1367 template <data_type_t src_type, data_type_t diff_dst_type,
1368           data_type_t diff_weights_type>
1369 void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
1370     diff_weights_type>::reduce_diff_weights(const thread_info_t *ti) const {
1371     const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
1372
1373     const auto &jcp = kernel_->jcp;
1374     const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw;
1375     const int bia_size = jcp.ngroups * jcp.oc;
1376     const diff_weights_data_t *diff_bias_ws
1377         = ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size;
1378
1379     /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */
1380     simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);
1381
1382     const int ic_b_kh_work = ti->ic_b_work * jcp.kh;
1383     const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work;
1384
1385     int start{0}, end{0};
1386     balance211(work, nthr_mb_, ti->ithr_mb, start, end);
1387     if (start == end) return;
1388
1389     for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
1390         int w = start;
1391         int sub_g_start{0}, sub_oc_b_start{0}, sub_ic_b_kh_start{0};
1392         nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start,
1393                 ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
1394         while (w < end) {
1395             const int g = ti->g_start + sub_g_start;
1396             const int oc_b = ti->oc_b_start + sub_oc_b_start;
1397             const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kh;
1398             const int kh = sub_ic_b_kh_start % jcp.kh;
1399
1400             const int acc_size
1401                 = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start)
1402                 * jcp.kw * jcp.ic_block * jcp.oc_block;
1403
1404             const size_t off
1405                 = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kh);
1406
1407             diff_weights_data_t *d
1408                 = (diff_weights_data_t *)ti->diff_weights + off;
1409             diff_weights_data_t *s
1410                 = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off;
1411
1412             acc_ker_->accumulate(d, s, acc_size);
1413
1414             nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start,
1415                     ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
1416         }
1417
1418         if (jcp.with_bias && jcp.is_1stconv && jcp.ver == ver_4fma) {
1419             if (ti->ithr == 0)
1420                 acc_ker_->accumulate((diff_weights_data_t *)ti->diff_bias,
1421                     diff_bias_ws, bia_size);
1422             diff_bias_ws += bia_size;
1423         }
1424     }
1425 }
1426
1427 template <data_type_t src_type, data_type_t diff_dst_type,
1428           data_type_t diff_weights_type>
1429 void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
1430     diff_weights_type>::reduce_diff_weights_3d(const thread_info_t *ti) const {
1431     const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
1432
1433     const auto &jcp = kernel_->jcp;
1434     const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw
1435         * jcp.kd;
1436
1437     /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */
1438     simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);
1439
1440     const int ic_b_kh_work = ti->ic_b_work * jcp.kd;
1441     const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work;
1442
1443     int start{0}, end{0};
1444     balance211(work, nthr_mb_, ti->ithr_mb, start, end);
1445     if (start == end) return;
1446
1447     for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
1448         int w = start;
1449         int sub_g_start{0}, sub_oc_b_start{0}, sub_ic_b_kh_start{0};
1450         nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start,
1451                 ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
1452         while (w < end) {
1453             const int g = ti->g_start + sub_g_start;
1454             const int oc_b = ti->oc_b_start + sub_oc_b_start;
1455             const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kd;
1456             const int kd = sub_ic_b_kh_start % jcp.kd;
1457
1458             const int acc_size
1459                 = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start)
1460                 * jcp.kw * jcp.ic_block * jcp.oc_block * jcp.kh;
1461
1462             const size_t off
1463                 = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kd);
1464             diff_weights_data_t *d
1465                 = (diff_weights_data_t *)ti->diff_weights + off;
1466             diff_weights_data_t *s
1467                 = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off;
1468             acc_ker_->accumulate(d, s, acc_size);
1469
1470             nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start,
1471                     ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
1472         }
1473     }
1474 }
1475
1476 template <data_type_t src_type, data_type_t diff_dst_type,
1477           data_type_t diff_weights_type>
1478 void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
1479     diff_weights_type>::compute_diff_bias(const thread_info_t *ti) const {
1480     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
1481
1482     auto rb = this->reducer_bias_;
1483     assert(nthr_ == rb->balancer().nthr_);
1484
1485     const auto reducer_bia_scratchpad = memory_tracking::grantor_t(
1486             ti->scratchpad, prefix_reducer_bia);
1487
1488     const auto &jcp = kernel_->jcp;
1489
1490     if (jcp.with_bias && jcp.is_1stconv && jcp.ver == ver_4fma) return;
1491
1492     const int b_job_start = rb->balancer().ithr_job_off(ti->ithr);
1493     const int b_njobs = rb->balancer().ithr_njobs(ti->ithr);
1494
1495     if (b_njobs == 0) return;
1496
1497     /* reduction dimension */
1498     int img_start{0}, img_end{0};
1499     balance211(jcp.mb, rb->balancer().nthr_per_group_,
1500             rb->balancer().id_in_group(ti->ithr), img_start, img_end);
1501
1502     /* jobs */
1503     int g_start{0}, ocb_start{0};
1504     nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_oc);
1505     for (int img = img_start; img < img_end; ++img) {
1506         int g = g_start, ocb = ocb_start;
1507         for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
1508             const size_t _oc = g * jcp.nb_oc + ocb;
1509
1510             const diff_dst_data_t *d_dst
1511                 = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)];
1512             diff_weights_data_t *d_bias = rb->get_local_ptr(ti->ithr,
1513                     ti->diff_bias, reducer_bia_scratchpad)
1514                 + b_job_loc * rb->balancer().job_size_;
1515
1516             if (img == img_start)
1517                 for (int o = 0; o < 16; ++o)
1518                     d_bias[o] = 0;
1519             for (int hw = 0; hw < jcp.oh * jcp.ow * jcp.od; ++hw) {
1520                 PRAGMA_OMP_SIMD()
1521                 for (int o = 0; o < 16; ++o)
1522                     d_bias[o] += d_dst[o];
1523                 d_dst += 16;
1524             }
1525
1526             nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc);
1527         }
1528     }
1529
1530     rb->reduce(ti->ithr, ti->diff_bias, reducer_bia_scratchpad);
1531 }
1532
1533 template <data_type_t src_type, data_type_t diff_dst_type,
1534           data_type_t diff_weights_type>
1535 void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
1536     diff_weights_type>::compute_diff_bias_3d(const thread_info_t *ti) const {
1537
1538     const auto &jcp = kernel_->jcp;
1539
1540     const size_t wei_size = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh
1541         * jcp.kw * jcp.kd;
1542     const int bia_size = jcp.ngroups * jcp.oc;
1543     const diff_weights_data_t *diff_bias_ws
1544             = ti->wei_bia_reduction + (size_t)(nthr_mb_ - 1) * wei_size;
1545
1546     if (nthr_mb_ > 1) mkldnn_thr_barrier();
1547
1548     if (ti->ithr == 0)
1549     {
1550         for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
1551             acc_ker_->accumulate(ti->diff_bias, diff_bias_ws, bia_size);
1552             diff_bias_ws += bia_size;
1553         }
1554     }
1555 }
1556
1557 template <data_type_t src_type, data_type_t diff_dst_type,
1558           data_type_t diff_weights_type>
1559 void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
1560     diff_weights_type>::prepare_scratchpad_data() const
1561 {
1562     const auto &j = pd()->jcp_;
1563     auto scratchpad = this->scratchpad();
1564
1565     if (utils::one_of(j.ver, ver_4fma, ver_4vnni, ver_vnni)) {
1566         if (!j.is_1stconv) {
1567             // XXX: See the comment about tr_iw and guarding elements in
1568             // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf()
1569             const int max_nthr = j.nthr_mb * j.ngroups * j.nb_ic;
1570             const int min_tr_src_size_per_thr = j.ih * j.ic_block * j.tr_iw;
1571
1572             auto tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);
1573             /* to avoid NaNs in computations we zero tail num_guard_elems for
1574              * each possible thread group */
1575
1576             for (int ithr = 1; ithr <= max_nthr; ++ithr) {
1577                 src_data_t *ts = &tr_src[ithr * min_tr_src_size_per_thr];
1578                 for (int i = 0; i < j.tr_src_num_guard_elems; ++i)
1579                     ts[i] = 0;
1580             }
1581         }
1582
1583         if (j.nthr_oc_b > 1) {
1584             const int tr_src_bctx_size = j.nthr / j.nthr_oc_b;
1585             auto tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
1586                     key_conv_tr_src_bctx);
1587             for (int i = 0; i < tr_src_bctx_size; ++i)
1588                 simple_barrier::ctx_init(&tr_src_bctx[i]);
1589         }
1590
1591         if (utils::one_of(j.ver, ver_4vnni, ver_vnni) && j.nthr_ic_b > 1) {
1592             const int tr_diff_dst_bctx_size = j.nthr / j.nthr_ic_b;
1593             auto tr_diff_dst_bctx =
1594                 scratchpad.template get<simple_barrier::ctx_t>(
1595                         key_conv_tr_diff_dst_bctx);
1596                 for (int i = 0; i < tr_diff_dst_bctx_size; ++i)
1597                     simple_barrier::ctx_init(&tr_diff_dst_bctx[i]);
1598         }
1599     }
1600
1601     if (nthr_mb_ > 1) {
1602         simple_barrier::ctx_init(scratchpad.template get<simple_barrier::ctx_t>(
1603                     key_conv_wei_bia_reduction_bctx));
1604     }
1605
1606     const auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
1607             prefix_reducer_bia);
1608     auto rb = this->reducer_bias_;
1609     rb->init(reducer_bia_scratchpad);
1610 }
1611
1612 template <data_type_t src_type, data_type_t diff_dst_type,
1613           data_type_t diff_weights_type>
1614 void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
1615     diff_weights_type>::execute_backward_weights() const {
1616     prepare_scratchpad_data();
1617
1618     parallel(nthr_, [&](const int ithr, const int nthr) {
1619         assert(nthr_ == nthr);
1620
1621         thread_info_t thread_info(this, ithr);
1622
1623         if (utils::one_of(pd()->ndims(), 3, 4)) {
1624             compute_diff_weights(&thread_info);
1625             if (nthr_mb_ > 1) reduce_diff_weights(&thread_info);
1626             if (pd()->with_bias()) compute_diff_bias(&thread_info);
1627         } else if (pd()->ndims() == 5) {
1628             compute_diff_weights_3d(&thread_info);
1629             if (nthr_mb_ > 1) reduce_diff_weights_3d(&thread_info);
1630             if (pd()->with_bias()) compute_diff_bias_3d(&thread_info);
1631         } else {
1632             assert(false);
1633         }
1634     });
1635
1636     /* TODO: put that into compute_diff_bias() */
1637     if (pd()->wants_padded_bias()) {
1638         auto diff_bias = scratchpad().template get<const diff_weights_data_t>(
1639                 key_conv_padded_bias);
1640         auto diff_bias_in
1641             = reinterpret_cast<diff_weights_data_t *>(this->memory(1));
1642         for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc)
1643             diff_bias_in[oc] = diff_bias[oc];
1644     }
1645 }
1646
/* explicit instantiations: pure f32, and s16 src/diff_dst with s32
 * diff_weights (remaining template arguments presumably default to the
 * first one -- see the class declaration in the header) */
template struct jit_avx512_common_convolution_bwd_weights_t<data_type::f32>;
template struct jit_avx512_common_convolution_bwd_weights_t<data_type::s16,
    data_type::s16, data_type::s32>;
1650
1651 }
1652 }
1653 }
1654
1655 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s