Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx2_1x1_convolution.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "mkldnn_thread.hpp"
19 #include "type_helpers.hpp"
20 #include "utils.hpp"
21 #include <cstring>
22 #include "jit_generator.hpp"
23
24 #include "jit_avx2_1x1_convolution.hpp"
25
26 namespace mkldnn {
27 namespace impl {
28 namespace cpu {
29
30 using namespace mkldnn::impl::status;
31 using namespace mkldnn::impl::memory_format;
32 using namespace mkldnn::impl::memory_tracking::names;
33 using namespace mkldnn::impl::utils;
34
/* Offset helper used by the drivers below: calls the 3-argument
 * (f).blk_off(n, c, w) when `ndims == 3` (the `h` argument is then unused),
 * otherwise the 4-argument (f).blk_off(n, c, h, w).
 * The macro reads a variable named `ndims` that must be in scope at the
 * point of expansion. */
35 #define data_blk_off(f, n, c, h, w) \
36     ((ndims == 3) \
37     ? (f).blk_off(n, c, w) \
38     : (f).blk_off(n, c, h, w))
39
40 /* convolution forward */
41
/* Forward-pass driver for the AVX2 1x1 convolution.
 *
 * Splits MB * ngroups * nb_bcast work items across threads. For every
 * (spatial block, output-channel block, input-channel block) combination it
 * fills a jit_1x1_conv_call_s and invokes kernel_->jit_ker. When
 * pd()->rtus_.reduce_src_ is set, rtus_driver_ is run on the source first and
 * the kernel reads its input from a per-thread scratchpad workspace instead
 * of reading `src` directly. */
42 void jit_avx2_1x1_convolution_fwd_t::execute_forward() const {
       // Inputs 0/1/2 are source, weights and bias; memory() is the output.
43     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
44     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
45     auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
46     auto dst = reinterpret_cast<data_t *>(this->memory());
47
48     const memory_desc_wrapper src_d(pd()->src_pd());
49     const memory_desc_wrapper dst_d(pd()->dst_pd());
50     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
51
       // Workspace written by rtus_driver_ (only used when reduce_src_ is set).
52     auto rtus_space = scratchpad().get<data_t>(key_conv_rtus_space);
53
54     const auto &jcp = kernel_->jcp;
55     const int MB = pd()->MB();
56
       // Number of work items distributed over threads by balance211 below.
57     const int work_amount = MB * jcp.ngroups * jcp.nb_bcast;
       // `ndims` is also consumed implicitly by the data_blk_off macro.
58     const int ndims = dst_d.ndims();
59
       // For ndims == 3 there is no height: stride_h is 1 and pad_t is 0, and
       // the width stride/padding sits at index 0 (ndims - 3) instead of 1.
60     const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
61     const int stride_w = pd()->desc()->strides[ndims - 3];
62     const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
63     const int pad_l = pd()->desc()->padding[0][ndims - 3];
64
       // Step size for a blocked loop: once fewer than `tail_step` items
       // remain, all of them are taken at once; otherwise `default_step`.
65     auto step = [](int default_step, int remaining, int tail_step) {
66         assert(default_step <= tail_step);
67         return remaining < tail_step ? remaining : default_step;
68     };
69
       // Per-thread body. Captures by reference, so it observes the `bias`
       // pointer swap performed below (before parallel() is called).
70     auto ker = [&](const int ithr, const int nthr) {
71         // TODO (Roma): remove this restriction
72         assert(jcp.stride_w == 1 && jcp.stride_h == 1);
73
           // p: arguments for the 1x1 kernel; rp: arguments for rtus_driver_.
74             auto p = jit_1x1_conv_call_s();
75             auto rp = rtus_driver_t<avx2>::call_params_t();
76
           // In the forward pass: load dim = output channels,
           // reduce dim = input channels, bcast dim = output spatial.
77         const int nb_oc = jcp.nb_load;
78         const int nb_ic = jcp.nb_reduce;
79         const int nb_ic_blocking = jcp.nb_reduce_blocking;
80         const int os_block = jcp.bcast_block;
81
           // This thread processes work items in [start, end).
82         int start{0}, end{0};
83         balance211(work_amount, nthr, ithr, start, end);
84
85         int iwork = start;
86         while (iwork < end) {
               // Unpack the flat work index into minibatch index n, group g
               // and spatial block osb.
87             int n{0}, g{0}, osb{0};
88             nd_iterator_init(iwork, n, MB, g, jcp.ngroups, osb,
89                     jcp.nb_bcast);
90
91             int bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
92                     jcp.nb_bcast_blocking_max);
               // Do not run past the end of this thread's share of the work.
93             bcast_step = nstl::min(bcast_step, end - iwork);
94
               // Flat output spatial offset and its (row, column) position.
95             const int os = osb * os_block;
96             const int oh = os / jcp.ow;
97             const int ow = os % jcp.ow;
98
               // Corresponding input coordinates, clamped to be non-negative.
99             const int ih = nstl::max(oh * stride_h - pad_t, 0);
100             const int iw = nstl::max(ow * stride_w - pad_l, 0);
101             rp.iw_start = iw;
102
103             p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block);
104             rp.os = p.bcast_dim;
105
                // Walk the output-channel (load) blocks.
106             int ocb = 0;
107             while (ocb < jcp.nb_load) {
108                 const int load_step = step(jcp.nb_load_blocking,
109                         jcp.nb_load - ocb, jcp.nb_load_blocking_max);
110
                    // Output-channel block index including the group offset.
111                 const int _ocb = g * nb_oc + ocb;
112                 p.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc,
113                         load_step * jcp.oc_block);
114                 const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow);
115
116                 p.output_data = &dst[dst_off];
117
118                 p.bias_data = &bias[_ocb * jcp.oc_block];
119
                    // Walk the input-channel (reduce) blocks.
120                 for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
                        // Tell the kernel whether this is the first and/or
                        // the last chunk of the reduction.
121                     p.first_last_flag = 0
122                         | (icb == 0 ? FLAG_REDUCE_FIRST : 0)
123                         | (icb + nb_ic_blocking >= nb_ic
124                                 ? FLAG_REDUCE_LAST : 0);
125
126                     p.reduce_dim = this_block_size(icb * jcp.ic_block, jcp.ic,
127                             nb_ic_blocking * jcp.ic_block);
128                     rp.icb = p.reduce_dim / jcp.reduce_block;
129
130                     p.load_data = &weights[pd()->with_groups()
131                         ? weights_d.blk_off(g, ocb, icb)
132                         : weights_d.blk_off(ocb, icb)];
133
                        // Input-channel block index including the group offset.
134                     const int _icb = g * nb_ic + icb;
135                     if (pd()->rtus_.reduce_src_) {
                            // Workspace location for this thread and
                            // input-channel block.
136                         rp.ws = rtus_space
137                             + ithr * pd()->rtus_.space_per_thread_
138                             + _icb * jcp.is * jcp.ic_block;
139
                            // The driver is run only for the first
                            // output-channel block; later blocks reuse the
                            // workspace it filled.
140                         if (ocb == 0) {
141                             rp.src = src + data_blk_off(src_d, n, _icb, ih, iw);
142                             rtus_driver_->ker_(&rp);
143                         }
144
145                         p.bcast_data = rp.ws;
146                     } else
147                         p.bcast_data = src + data_blk_off(src_d, n, _icb, ih, iw);
148
                        // Output-channel offset expressed in bytes.
149                     p.oc_off = _ocb * jcp.oc_block * sizeof(float);
150
151                     kernel_->jit_ker(&p);
152                 }
153
154                 ocb += load_step;
155             }
156
157             iwork += bcast_step;
158         }
159     };
160
        // Replace the bias with a scratchpad copy whose
        // [oc_without_padding, oc) tail is zero-filled.
161     if (pd()->wants_padded_bias()) {
162         auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
163         utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
164         utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
165                 jcp.oc - jcp.oc_without_padding);
166         bias = padded_bias;
167     }
168
169     parallel(0, ker);
170
171     if (pd()->wants_zero_pad_dst())
172         output_memory_primitive(0)->zero_pad();
173 }
174
/* Forward 1x1 convolution fused with a depthwise convolution.
 *
 * Each thread owns a slice of the key_dw_conv_buffer scratchpad. Rows of the
 * 1x1 result are written into that slice by compute_block_1x1, in slots
 * indexed modulo jcp_dw.kh. compute_row_dw then feeds three of those rows to
 * kernel_dw_, which writes into `dst`.
 *
 * NOTE(review): compute_row_dw always passes exactly three source rows and
 * the caller fills `bcast_step + 2` rows, so this looks like it assumes
 * jcp_dw.kh == 3 -- confirm against the kernel configuration code. */
175 void jit_avx2_1x1_convolution_fwd_t::execute_forward_with_dw_conv() const {
176     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
177     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
178     auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
179     auto dst = reinterpret_cast<data_t *>(this->memory());
180
181     const memory_desc_wrapper src_d(pd()->src_pd());
182     const memory_desc_wrapper dst_d(pd()->dst_pd());
183     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
184
185     auto rtus_space = scratchpad().get<data_t>(key_conv_rtus_space);
186
        // jcp: configuration of the 1x1 kernel; jcp_dw: of the depthwise one.
187     const auto &jcp = kernel_->jcp;
188     const auto &jcp_dw = kernel_dw_->jcp;
189     const int MB = pd()->MB();
190
        // Depthwise bias comes from the kernel configuration, not from an
        // input memory; it may be swapped for a padded copy further down.
191     auto dw_bias = jcp_dw.conv_biases;
192
        // Output-channel blocks become part of the parallel work space here
        // (in units of nb_load_blocking), unlike in execute_forward().
193     int ocb_work = jcp.with_dw_conv ? utils::div_up(jcp.nb_load, jcp.nb_load_blocking) : 1;
194     const int work_amount = MB * jcp.ngroups * ocb_work * jcp.nb_bcast;
195
        // Step size for a blocked loop: once fewer than `tail_step` items
        // remain, all of them are taken at once; otherwise `default_step`.
196     auto step = [](int default_step, int remaining, int tail_step) {
197         assert(default_step <= tail_step);
198         return remaining < tail_step ? remaining : default_step;
199     };
200
        // Per-thread body. Captures by reference, so it observes the bias
        // pointer swaps performed below (before parallel() is called).
201     auto ker = [&](const int ithr, const int nthr) {
202         // TODO (Roma): remove this restriction
203         assert(jcp.stride_w == 1 && jcp.stride_h == 1);
204
            // Computes `num_rows` consecutive rows of the 1x1 convolution,
            // starting at output row `oh`, into the row buffer `ws_p`.
            // Row r is stored in slot (r + 1) % jcp_dw.kh. Rows outside
            // [0, jcp.ih) are zero-filled instead of computed.
            // The incoming `ih` value is not used: it is recomputed per row.
205         auto compute_block_1x1 = [&](float* ws_p, int n, int g, int oh, int ow, int ih, int iw, int os, int os_block, int bcast_step, int ocb, int load_step,
206                                     int num_rows) {
207             auto rp = rtus_driver_t<avx2>::call_params_t();
208             auto p = jit_1x1_conv_call_s();
209
210             for (int h = 0; h < num_rows; h++) {
211                 ih = nstl::max((oh + h) * jcp.stride_h - jcp.t_pad, 0);
212
213                 if ((oh + h) < 0 || (oh + h) >= jcp.ih) {
                        // Out-of-range row: clear its slot for every channel
                        // block handled by this call.
214                     for (int chb = ocb; chb < ocb + load_step; chb++) {
215                         memset(ws_p + (((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block +
216                                (chb - ocb) * jcp_dw.kh * jcp.ow * jcp.oc_block, 0, jcp.ow * jcp.oc_block * sizeof(float));
217                     }
218                 } else {
                        // Output-channel block index including the group offset.
219                     const int _ocb = g * jcp.nb_load + ocb;
220
221                     rp.iw_start = iw;
222                     p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block);
223
224                     rp.os = p.bcast_dim;
225                     p.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, load_step * jcp.oc_block);
226
                        // The kernel writes into the row buffer, not into dst.
227                     p.output_data = &ws_p[(((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block];
228
229                     p.bias_data = &bias[_ocb * jcp.oc_block];
230
                        // Walk the input-channel (reduce) blocks.
231                     for (int icb = 0; icb < jcp.nb_reduce; icb += jcp.nb_reduce_blocking) {
232                         p.first_last_flag = 0
233                                             | (icb == 0 ? FLAG_REDUCE_FIRST : 0)
234                                             | (icb + jcp.nb_reduce_blocking >= jcp.nb_reduce
235                                                ? FLAG_REDUCE_LAST : 0);
236
237                         p.reduce_dim = this_block_size(icb * jcp.ic_block, jcp.ic,
238                                                        jcp.nb_reduce_blocking * jcp.ic_block);
239                         rp.icb = p.reduce_dim / jcp.reduce_block;
240
241                         p.load_data = &weights[pd()->with_groups()
242                                                ? weights_d.blk_off(g, ocb, icb)
243                                                : weights_d.blk_off(ocb, icb)];
244
245                         const int _icb = g * jcp.nb_reduce + icb;
246                         if (pd()->rtus_.reduce_src_) {
247                             rp.ws = rtus_space
248                                     + ithr * pd()->rtus_.space_per_thread_
249                                     + _icb * jcp.is * jcp.ic_block;
250
                                // The driver is run only when ocb == 0.
251                             if (ocb == 0) {
252                                 rp.src = src + src_d.blk_off(n, _icb, ih, iw);
253                                 rtus_driver_->ker_(&rp);
254                             }
255
256                             p.bcast_data = rp.ws;
257                         } else {
258                             p.bcast_data = src + src_d.blk_off(n, _icb, ih, iw);
259                         }
260
                            // Output-channel offset expressed in bytes.
261                         p.oc_off = _ocb * jcp.oc_block * sizeof(float);
262
263                         kernel_->jit_ker(&p);
264                     }
265                 }
266             }
267         };
268
            // Runs the depthwise kernel for row index `dst_idx` on every
            // channel block in [ocb, ocb + load_step). src_row0/1/2 point at
            // buffer slots dst_idx, dst_idx + 1 and dst_idx + 2 (mod kh),
            // which under the (r + 1) % kh slot mapping used by
            // compute_block_1x1 hold rows dst_idx - 1, dst_idx, dst_idx + 1.
269         auto compute_row_dw = [&](const float* ws_p, int n, int ocb, int load_step, int dst_idx) {
270
271             for (int chb = ocb; chb < ocb + load_step; chb++) {
272                 auto par_conv_dw = jit_conv_call_s();
273
274                 par_conv_dw.src_row0 = &ws_p[(((dst_idx+1) - 1) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
275                                              (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
276                 par_conv_dw.src_row1 = &ws_p[(((dst_idx+1) - 0) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
277                                              (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
278                 par_conv_dw.src_row2 = &ws_p[(((dst_idx+1) + 1) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
279                                              (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
280
                    // Destination offset is computed by hand (image, channel
                    // block, output row dst_idx / stride_h) rather than via
                    // dst_d.blk_off().
281                 par_conv_dw.dst = &dst[n*jcp_dw.oc*jcp_dw.oh*jcp_dw.ow + chb*jcp_dw.ch_block*jcp_dw.oh*jcp_dw.ow +
282                                        dst_idx/jcp_dw.stride_h*jcp_dw.ow*jcp_dw.ch_block];
283
284                 par_conv_dw.kh_padding = jcp_dw.kh;
285                 par_conv_dw.filt = &jcp_dw.conv_weights[chb * jcp_dw.kh * jcp_dw.kw * jcp_dw.ch_block];
286                 par_conv_dw.bias = &dw_bias[chb * jcp_dw.ch_block];
287                 par_conv_dw.ur_w = (size_t)(jcp_dw.ow);
                    // Number of channels in this block, shorter for the last
                    // block when oc is not a multiple of ch_block.
288                 par_conv_dw.oc_work = nstl::min((chb + 1) * jcp_dw.ch_block, (int)jcp_dw.oc) - chb*jcp_dw.ch_block;
289                 par_conv_dw.oc_off = chb * jcp_dw.ch_block * sizeof(float);
290
291                 kernel_dw_->jit_ker(&par_conv_dw);
292             }
293         };
294
            // Same check as the assert at the top of this lambda.
295         assert(jcp.stride_w == 1 && jcp.stride_h == 1);
296
            // This thread processes work items in [start, end).
297         int start{0}, end{0};
298         balance211(work_amount, nthr, ithr, start, end);
299
            // Slice of the row buffer that belongs to this thread.
300         auto dw_conv_buffer = scratchpad().get<data_t>(key_dw_conv_buffer);
301         size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * (jcp.oc / jcp.oc_block);
302         auto pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
303
            // One spatial work item spans jcp.iw positions.
304         const int os_block = jcp.iw;
305
306         int iwork = start;
307         while (iwork < end) {
                // Unpack the flat work index into minibatch index n, group g,
                // output-channel super-block ocbb and spatial block osb.
308             int n{0}, g{0}, ocbb{0}, osb{0};
309             nd_iterator_init(iwork, n, MB, g, jcp.ngroups, ocbb, ocb_work, osb,
310                              jcp.nb_bcast);
311             int bcast_step = 1;
312
313             const int os = osb * os_block;
314             const int oh = os / jcp.ow;
315             const int ow = os % jcp.ow;
316
317             const int ih = nstl::max(oh * jcp.stride_h - jcp.t_pad, 0);
318             const int iw = nstl::max(ow * jcp.stride_w - jcp.l_pad, 0);
319
320             int ocb = ocbb * jcp.nb_load_blocking;
321
322             const int load_step = step(jcp.nb_load_blocking,
323                                        jcp.nb_load - ocb, jcp.nb_load_blocking_max);
324
                // At the first item of this thread's range, or at row 0, the
                // buffer is filled with rows oh - 1, oh and oh + 1
                // (bcast_step + 2 rows starting at oh - 1). Otherwise only
                // the newly needed row oh + 1 is computed.
                // bcast_step evaluates to 1 here because iwork < end.
325             if (iwork == start || oh == 0) {
326                 bcast_step = nstl::min(1, end - iwork);
327                 compute_block_1x1(pbuf, n, g, oh - 1, ow, ih, iw, os, os_block, bcast_step, ocb, load_step, bcast_step + 2);
328             } else {
329                 bcast_step = nstl::min(1, end - iwork);
330                 compute_block_1x1(pbuf, n, g, oh + 1, ow, ih, iw, os, os_block, bcast_step, ocb, load_step, bcast_step);
331             }
332
                // The depthwise kernel is run only for rows that are a
                // multiple of its vertical stride.
333             if ((oh % jcp_dw.stride_h == 0)) {
334                 compute_row_dw(pbuf, n, ocb, load_step, oh);
335             }
336
337             iwork += bcast_step;
338         }
339     };
340
        // Replace both biases with scratchpad copies whose
        // [oc_without_padding, oc) tails are zero-filled. Both copies use
        // the 1x1 kernel's jcp.oc / jcp.oc_without_padding sizes.
341     if (pd()->wants_padded_bias()) {
342         auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
343         utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
344         utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
345                 jcp.oc - jcp.oc_without_padding);
346         bias = padded_bias;
347
348         auto dw_padded_bias = scratchpad().get<data_t>(key_dw_conv_padded_bias);
349         utils::array_copy(dw_padded_bias, dw_bias, jcp.oc_without_padding);
350         utils::array_set(dw_padded_bias + jcp.oc_without_padding, 0.f,
351                          jcp.oc - jcp.oc_without_padding);
352         dw_bias = dw_padded_bias;
353     }
354
355     parallel(0, ker);
356
357     if (pd()->wants_zero_pad_dst())
358         output_memory_primitive(0)->zero_pad();
359 }
360
361 /* convolution backward wtr data */
362
/* Backward-by-data driver for the AVX2 1x1 convolution: produces diff_src
 * from diff_dst and weights.
 *
 * Compared with the forward pass the channel roles are swapped: the load
 * dimension walks input-channel blocks and the reduce dimension walks
 * output-channel blocks. When pd()->rtus_.reduce_src_ is set the kernel
 * writes into a per-thread workspace and rtus_driver_ is invoked afterwards
 * with rp.src pointing into diff_src. */
363 void jit_avx2_1x1_convolution_bwd_data_t::execute_backward_data() const {
        // Inputs 0/1 are diff_dst and weights; memory() is diff_src.
364     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
365     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
366     auto diff_src = reinterpret_cast<data_t *>(this->memory());
367
368     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
369     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
370     const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
371
372     auto rtus_space = scratchpad().get<data_t>(key_conv_rtus_space);
373
374     const auto &jcp = kernel_->jcp;
375     const int MB = pd()->MB();
376
377     // TODO (Roma): remove this restriction
378     assert(jcp.stride_w == 1 && jcp.stride_h == 1);
        // `ndims` is also consumed implicitly by the data_blk_off macro.
379     const int ndims = diff_dst_d.ndims();
380
        // For ndims == 3 there is no height: stride_h is 1 and pad_t is 0, and
        // the width stride/padding sits at index 0 (ndims - 3) instead of 1.
381     const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
382     const int stride_w = pd()->desc()->strides[ndims - 3];
383     const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
384     const int pad_l = pd()->desc()->padding[0][ndims - 3];
385
        // Here: load dim = input channels, reduce dim = output channels.
386     const int nb_ic = jcp.nb_load;
387     const int nb_oc = jcp.nb_reduce;
388     const int os_block = jcp.bcast_block;
389     const int nb_oc_blocking = jcp.nb_reduce_blocking;
390
        // Number of work items distributed over threads by balance211 below.
391     const int work_amount = MB * jcp.ngroups * jcp.nb_bcast;
392
        // Step size for a blocked loop: once fewer than `tail_step` items
        // remain, all of them are taken at once; otherwise `default_step`.
393     auto step = [](int default_step, int remaining, int tail_step) {
394         assert(default_step <= tail_step);
395         return remaining < tail_step ? remaining : default_step;
396     };
397
        // Per-thread body.
398     auto ker = [&](const int ithr, const int nthr) {
            // p: arguments for the 1x1 kernel; rp: arguments for rtus_driver_.
399         auto p = jit_1x1_conv_call_s();
400         auto rp = rtus_driver_t<avx2>::call_params_t();
401
            // This thread processes work items in [start, end).
402         int start{0}, end{0};
403         balance211(work_amount, nthr, ithr, start, end);
404
            // Outer loop: input-channel (load) blocks. The whole [start, end)
            // range is walked once per load block.
405         int load_step = 0;
406         for (int icb = 0; icb < jcp.nb_load; icb += load_step) {
407             load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb,
408                     jcp.nb_load_blocking_max);
409
410             p.load_dim = this_block_size(icb * jcp.ic_block, jcp.ic,
411                     load_step * jcp.ic_block);
412             rp.icb = p.load_dim / jcp.ic_block;
413
                // bcast_step is assigned inside the loop body before the
                // increment expression reads it.
414             int bcast_step;
415             for (int iwork = start; iwork < end; iwork += bcast_step) {
                    // Unpack the flat work index into minibatch index n,
                    // group g and spatial block osb.
416                 int n{0}, g{0}, osb{0};
417                 nd_iterator_init(iwork, n, MB, g, jcp.ngroups, osb,
418                         jcp.nb_bcast);
419
420                 bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
421                         jcp.nb_bcast_blocking_max);
                    // Do not run past the end of this thread's share.
422                 bcast_step = nstl::min(bcast_step, end - iwork);
423
424                 const int os = osb * os_block;
425                 p.bcast_dim = this_block_size(os, jcp.os,
426                         bcast_step * os_block);
427                 rp.os = p.bcast_dim;
428
                    // (row, column) of the diff_dst position and the matching
                    // diff_src coordinates, clamped to be non-negative.
429                 const int oh = os / jcp.ow;
430                 const int ow = os % jcp.ow;
431                 const int ih = nstl::max(oh * stride_h - pad_t, 0);
432                 const int iw = nstl::max(ow * stride_w - pad_l, 0);
433                 rp.iw_start = iw;
434
                    // Input-channel block index including the group offset.
435                 const int _icb = g * nb_ic + icb;
                    // rp.src points into diff_src, the tensor being produced.
436                 rp.src = diff_src + data_blk_off(diff_src_d, n, _icb, ih, iw);
437                 if (pd()->rtus_.reduce_src_) {
                        // Kernel output goes to this thread's workspace;
                        // rtus_driver_ is invoked after the reduction below.
438                     rp.ws = rtus_space
439                         + ithr * pd()->rtus_.space_per_thread_;
440                     p.output_data = rp.ws;
441                 } else
442                     p.output_data = rp.src;
443
                    // Reduction over output-channel blocks.
444                 for (int ocb = 0; ocb < jcp.nb_reduce;
445                         ocb += jcp.nb_reduce_blocking) {
446                     const int _ocb = g * nb_oc + ocb;
447                     size_t diff_dst_off = data_blk_off(diff_dst_d, n, _ocb, oh,
448                         ow);
449                     p.bcast_data = &diff_dst[diff_dst_off];
450
451                     p.load_data = &weights[pd()->with_groups()
452                         ? weights_d.blk_off(g, ocb, icb)
453                         : weights_d.blk_off(ocb, icb)];
454
                        // Only the first-chunk flag is set here; no
                        // FLAG_REDUCE_LAST is passed in this function.
455                     p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0;
456
457                     p.reduce_dim = this_block_size(ocb * jcp.oc_block, jcp.oc,
458                             nb_oc_blocking * jcp.oc_block);
459
460                     kernel_->jit_ker(&p);
461                 }
462
                    // Run the driver on the filled workspace
                    // (presumably writes it back to rp.src -- confirm in
                    // rtus_driver_t).
463                 if (pd()->rtus_.reduce_src_)
464                     rtus_driver_->ker_(&rp);
465             }
466         }
467     };
468
469     parallel(0, ker);
470 }
471
472 /* convolution backward wtr weights */
473
/* Constructs the backward-by-weights primitive: allocates the JIT 1x1 kernel
 * from pd()->jcp_ and the primitive attributes, allocates the weights and
 * bias reducers from their configurations in pd(), and then calls
 * init_rtus_driver<avx2>(this).
 *
 * NOTE(review): kernel_, reducer_weights_ and reducer_bias_ are allocated
 * with raw `new`; their release is not visible in this file (presumably done
 * by the destructor declared in the header -- confirm). */
474 jit_avx2_1x1_convolution_bwd_weights_t::jit_avx2_1x1_convolution_bwd_weights_t(
475         const pd_t *apd, const input_vector &inputs,
476         const output_vector &outputs)
477     : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr)
478     , rtus_driver_(nullptr)
479 {
        // The second argument is a default-constructed jit_conv_conf_t.
480     kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, jit_conv_conf_t(), *pd()->attr());
481     reducer_weights_ =
482         new cpu_reducer_2d_t<data_type::f32>(pd()->reducer_wei_conf_);
483     reducer_bias_ = new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_);
484     init_rtus_driver<avx2>(this);
485 }
486
/* Backward-by-weights driver for the AVX2 1x1 convolution: produces
 * diff_weights (and, when pd()->with_bias(), diff_bias) from src and diff_dst.
 *
 * Work is organised in two levels, both driven by the reducers' balancers:
 *  - independent jobs: (group, output-channel chunk, input-channel chunk)
 *    for the weights, (group, output-channel block) for the bias;
 *  - reduction work: (image, spatial position) for the weights and images
 *    for the bias, split between the threads of a group.
 * Partial results are combined by rw->reduce() / rb->reduce(). */
487 void jit_avx2_1x1_convolution_bwd_weights_t::execute_backward_weights() const {
        // Inputs 0/1 are src and diff_dst; outputs 0/1 are diff_weights and
        // diff_bias.
488     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
489     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
490     auto diff_weights = reinterpret_cast<data_t *>(this->memory(0));
491     auto diff_bias_in = reinterpret_cast<data_t *>(this->memory(1));
492
493     auto scratchpad = this->scratchpad();
494
495     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
496     const memory_desc_wrapper src_d(pd()->src_pd());
497     const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
498     const memory_desc_wrapper diff_bias_d(pd()->diff_weights_pd(1));
499
500     const auto &jcp = kernel_->jcp;
501     auto rtus_space = scratchpad.get<data_t>(key_conv_rtus_space);
502
        // With a padded bias the result is accumulated in the scratchpad and
        // copied to the user buffer (diff_bias_in) at the end of this function.
503     data_t *diff_bias = pd()->wants_padded_bias()
504         ? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in;
505
        // Each reducer gets its own scratchpad grantor.
506     auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
507             prefix_reducer_bia);
508     auto rb = this->reducer_bias_;
509     rb->init(reducer_bia_scratchpad);
510
511     auto reducer_wei_scratchpad = memory_tracking::grantor_t(scratchpad,
512             prefix_reducer_wei);
513     auto rw = this->reducer_weights_;
514     rw->init(reducer_wei_scratchpad);
515
516     const int ndims = diff_dst_d.ndims();
517     // TODO (Roma): remove this restriction
518     assert(jcp.stride_w == 1 && jcp.stride_h == 1);
519
        // Here: bcast dim = input channels, load dim = output channels,
        // reduce dim = spatial positions.
520     const int nb_ic = jcp.nb_bcast;
521     const int nb_ic_blocking = jcp.nb_bcast_blocking;
522     const int bcast_work = div_up(nb_ic, nb_ic_blocking);
523
524     const int nb_oc = jcp.nb_load;
525     const int nb_oc_blocking = jcp.nb_load_blocking;
526     const int load_work = div_up(nb_oc, nb_oc_blocking);
527
528     const int sp_dim = jcp.reduce_dim;
        // Total reduction work: every spatial position of every image.
529     const int mb_sp_work = jcp.mb * sp_dim;
530
        // For ndims == 3 there is no height: stride_h is 1 and pad_t is 0, and
        // the width stride/padding sits at index 0 (ndims - 3) instead of 1.
531     const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
532     const int stride_w = pd()->desc()->strides[ndims - 3];
533     const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
534     const int pad_l = pd()->desc()->padding[0][ndims - 3];
535
        // Step size for a blocked loop: once fewer than `tail_step` items
        // remain, all of them are taken at once; otherwise `default_step`.
536     auto step = [](int default_step, int remaining, int tail_step) {
537         assert(default_step <= tail_step);
538         return remaining < tail_step ? remaining : default_step;
539     };
540
        // Accumulates one (oc chunk, ic chunk) job over the spatial range
        // [sp_start, sp_end) of a single image.
        //   first_image  - FLAG_REDUCE_FIRST is passed only for the first
        //                  spatial step of a call made with first_image set
        //   store_to     - destination of the accumulated weights
        //   store_to_ld  - stride between output-channel blocks in store_to,
        //                  in elements (converted to bytes for the kernel)
        //   diff_dst/src - base pointers for the current image and chunk
        //                  (these parameters shadow the outer variables)
        //   ithr         - thread index, selects the rtus workspace slice
        // Note: this lambda captures by copy ([=]), unlike the others here.
541     auto oc_ic_sp_loop = [=](int sp_start, int sp_end, bool first_image,
542             data_t *store_to, size_t store_to_ld, const data_t *diff_dst,
543             const data_t *src, int ithr) {
544         auto p = jit_1x1_conv_call_s();
545         auto rp = rtus_driver_t<avx2>::call_params_t();
546
547         p.output_stride = store_to_ld * sizeof(float);
548         const int sp_step_def = jcp.nb_reduce_blocking * jcp.reduce_block;
549
            // NOTE(review): the step limits 12/18 (channel blocks) and 192
            // (spatial) below are hard-coded; their rationale is not visible
            // in this file.
550         int oc_b_step = 0;
551         for (int oc_b = 0; oc_b < nb_oc_blocking; oc_b += oc_b_step) {
552             oc_b_step = step(12, nb_oc_blocking - oc_b, 18);
553             p.load_dim = oc_b_step * jcp.oc_block;
554
555             int ic_b_step = 0;
556             for (int ic_b = 0; ic_b < nb_ic_blocking; ic_b += ic_b_step) {
557                 ic_b_step = step(12, nb_ic_blocking - ic_b, 18);
558                 p.bcast_dim = ic_b_step * jcp.ic_block;
559                 rp.icb = p.bcast_dim / jcp.ic_block;
560
561                 p.output_data = store_to + oc_b * store_to_ld
562                     + ic_b * jcp.ic_block * jcp.oc_block;
563
564                 /* spatial reduction */
565                 int sp_step = 0;
566                 for (int sp = sp_start; sp < sp_end; sp += sp_step) {
567                     sp_step = step(sp_step_def, sp_end - sp, 192);
568                     p.reduce_dim = sp_step;
569                     rp.os = p.reduce_dim;
570
571                     p.first_last_flag = sp == sp_start && first_image
572                         ? FLAG_REDUCE_FIRST : 0;
573
574                     p.load_data = diff_dst
575                         + (oc_b * jcp.reduce_dim + sp) * jcp.oc_block;
576
577                     if (pd()->rtus_.reduce_src_) {
                            // (row, column) of the spatial position and the
                            // matching source coordinates, clamped at 0.
578                         const int oh = sp / jcp.ow;
579                         const int ow = sp % jcp.ow;
580
581                         const int ih = nstl::max(oh * stride_h - pad_t, 0);
582                         const int iw = nstl::max(ow * stride_w - pad_l, 0);
583                         rp.iw_start = iw;
584
585                         rp.ws = rtus_space
586                             + ithr * pd()->rtus_.space_per_thread_
587                             + (ic_b * jcp.is + sp) * jcp.ic_block;
                            // Source offset built from the blocking strides:
                            // strides[0][2] is applied to iw in the 3D case
                            // and to ih (with strides[0][3] for iw) otherwise.
588                         if (ndims == 3)
589                             rp.src = src
590                                 + iw * src_d.blocking_desc().strides[0][2];
591                         else
592                             rp.src = src
593                                 + ih * src_d.blocking_desc().strides[0][2]
594                                 + iw * src_d.blocking_desc().strides[0][3];
595
                            // The driver is run only for the first
                            // output-channel step; later steps reuse the
                            // workspace it filled.
596                         if (oc_b == 0)
597                             rtus_driver_->ker_(&rp);
598
599                         p.bcast_data = rp.ws;
600                     } else
601                         p.bcast_data = src
602                             + (ic_b * jcp.reduce_dim + sp) * jcp.ic_block;
603
604                     kernel_->jit_ker(&p);
605                 }
606             }
607         }
608     };
609
        // Per-thread body for the weights gradient.
610     auto ker = [&](const int ithr, const int nthr) {
611         assert(nthr == rw->balancer().nthr_);
612
            // Threads without jobs return here, before rw->reduce().
613         const int w_njobs = rw->balancer().ithr_njobs(ithr);
614         if (w_njobs == 0) return;
615
616         /* setup: independent work (oc, ic) */
617         const int w_job_start = rw->balancer().ithr_job_off(ithr);
618         int g{0}, load_i{0}, bcast_i{0};
619         nd_iterator_init(w_job_start, g, jcp.ngroups, load_i, load_work,
620                 bcast_i, bcast_work);
621
622         /* setup: reduction work (mb, sp) */
623         int mb_sp_start{0}, mb_sp_end{0};
624         balance211(mb_sp_work, rw->balancer().nthr_per_group_,
625                 rw->balancer().id_in_group(ithr), mb_sp_start, mb_sp_end);
626         int img_start{0}, sp_start{0};
627         nd_iterator_init(mb_sp_start, img_start, jcp.mb, sp_start, sp_dim);
628
629         /* independent work */
630         for (int iwork = 0; iwork < w_njobs; ++iwork) {
                // First channel block of this job's oc / ic chunk.
631             const int oc_b = nb_oc_blocking * load_i;
632             const int ic_b = nb_ic_blocking * bcast_i;
633
                // The same block indices including the group offset.
634             const int _ic_b = g * nb_ic + ic_b;
635             const int _oc_b = g * nb_oc + oc_b;
636
637             data_t *store_to;
638             size_t store_to_ld;
639
                // With one thread per group the kernel accumulates straight
                // into diff_weights; otherwise into the pointer returned by
                // rw->get_local_ptr(), offset by this job's index.
640             if (rw->balancer().nthr_per_group_ == 1) {
641                 const size_t off = pd()->with_groups()
642                     ? diff_weights_d.blk_off(g, oc_b, ic_b)
643                     : diff_weights_d.blk_off(oc_b, ic_b);
644                 store_to = &diff_weights[off];
645                 store_to_ld = jcp.ic * jcp.oc_block;
646             } else {
647                 const size_t off = iwork * rw->balancer().job_size_;
648                 store_to =
649                     rw->get_local_ptr(ithr, reducer_wei_scratchpad) + off;
650                 store_to_ld = nb_ic_blocking * jcp.ic_block * jcp.oc_block;
651             }
652
653             /* reduction work */
654             int img = img_start;
655             int sp = sp_start;
656             int sp_step = 0;
657             for (int mb_sp = mb_sp_start; mb_sp < mb_sp_end; mb_sp += sp_step)
658             {
                    // Go to the end of the current image or to the end of
                    // this thread's range, whichever comes first.
659                 sp_step = nstl::min(sp_dim - sp, mb_sp_end - mb_sp);
660
661                 const bool first_image = img == img_start;
662                 oc_ic_sp_loop(sp, sp + sp_step, first_image, store_to,
663                         store_to_ld, &diff_dst[diff_dst_d.blk_off(img, _oc_b)],
664                         &src[src_d.blk_off(img, _ic_b)], ithr);
665
                    // Every following image starts at spatial position 0.
666                 sp = 0;
667                 img += 1;
668             }
669
                // Advance (g, load_i, bcast_i) to the next job.
670             nd_iterator_step(g, jcp.ngroups, load_i, load_work, bcast_i,
671                              bcast_work);
672         }
673         rw->reduce(ithr, diff_weights, reducer_wei_scratchpad);
674     };
675
        // Per-thread body for the bias gradient.
676     auto ker_bias = [&](int ithr, int nthr) {
677         assert(nthr == rb->balancer().nthr_);
678
679         const int b_job_start = rb->balancer().ithr_job_off(ithr);
680         const int b_njobs = rb->balancer().ithr_njobs(ithr);
681
            // Threads without jobs return here, before rb->reduce().
682         if (b_njobs == 0) return;
683
684         /* reduction dimension */
685         int img_start{0}, img_end{0};
686         balance211(jcp.mb, rb->balancer().nthr_per_group_,
687                 rb->balancer().id_in_group(ithr), img_start, img_end);
688
689         /* jobs */
690         int g_start{0}, ocb_start{0};
691         nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, nb_oc);
692
693         for (int img = img_start; img < img_end; ++img) {
                // Restart from this thread's first job for every image.
694             int g = g_start, ocb = ocb_start;
695             for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
696                 const size_t _oc = g * nb_oc + ocb;
697
698                 const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)];
699                 data_t *d_bias =
700                     rb->get_local_ptr(ithr, diff_bias, reducer_bia_scratchpad)
701                     + b_job_loc * rb->balancer().job_size_;
702
                    // NOTE(review): the literal 8 below is the number of
                    // channels handled per job; presumably it equals
                    // jcp.oc_block for this kernel -- confirm.
                    // The accumulator is cleared on the first image only.
703                 if (img == img_start)
704                     for (int o = 0; o < 8; ++o) d_bias[o] = 0.;
705
                    // Sum diff_dst over all oh * ow positions, 8 values each.
706                 for (int hw = 0; hw < jcp.oh * jcp.ow; ++hw) {
707                     PRAGMA_OMP_SIMD()
708                     for (int o = 0; o < 8; ++o)
709                         d_bias[o] += d_dst[o];
710                     d_dst += 8;
711                 }
712
713                 nd_iterator_step(g, jcp.ngroups, ocb, nb_oc);
714             }
715         }
716         rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
717     };
718
        // Each thread runs the weights part and then, if there is a bias,
        // the bias part.
719     parallel(0, [&](const int ithr, const int nthr) {
720         ker(ithr, nthr);
721         if (pd()->with_bias())
722             ker_bias(ithr, nthr);
723     });
724
725     /* TODO: put this in ker_bias */
        // Copy the first oc_without_padding values of the padded scratchpad
        // bias into the user's diff_bias buffer.
726     if (pd()->wants_padded_bias()) {
727         assert(jcp.ngroups == 1);
728         for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
729             diff_bias_in[oc] = diff_bias[oc];
730     }
731 }
732
733 }
734 }
735 }