Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_dw_convolution.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "memory_tracking.hpp"
19 #include "mkldnn_thread.hpp"
20
21 #include "jit_uni_dw_convolution.hpp"
22
23 namespace mkldnn {
24 namespace impl {
25 namespace cpu {
26
27 using namespace mkldnn::impl::status;
28 using namespace mkldnn::impl::memory_format;
29 using namespace mkldnn::impl::memory_tracking::names;
30 using namespace mkldnn::impl::utils;
31
// Forward depthwise convolution driver.
//
// Work is partitioned with parallel_nd over (mini-batch, channel-block
// groups, output rows).  For each output row the columns are processed in
// three phases: a left border (one output pixel per kernel call), a main
// region (ur_w_step pixels in one call, no horizontal padding), and a
// right border (one pixel per call again).  All arithmetic is done by the
// JIT-generated kernel invoked through kernel_->jit_ker().
32 template <cpu_isa_t isa>
33 void _jit_uni_dw_convolution_fwd_t<isa>::execute_forward() const {
34     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
35     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
36     auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
37     auto dst = reinterpret_cast<data_t *>(this->memory());
38
39     const memory_desc_wrapper src_d(pd()->src_pd());
40     const memory_desc_wrapper dst_d(pd()->dst_pd());
41     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
42     const memory_desc_wrapper bias_d(pd()->weights_pd(1));
43
44     const auto &jcp = kernel_->jcp;
45
    // When the channel count was rounded up to the block size, copy the
    // real bias values into a zero-padded scratchpad buffer so the kernel
    // can safely read a whole block.
46     if (pd()->wants_padded_bias()) {
47         auto padded_bias = this->scratchpad().template get<data_t>(
48                 key_conv_padded_bias);
49         utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
50         utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
51                 jcp.oc - jcp.oc_without_padding);
52         bias = padded_bias;
53     }
54
    // Effective dilation factors: jcp.dilate_* == 0 means a dense kernel.
55     int dil_h = jcp.dilate_h + 1;
56     int dil_w = jcp.dilate_w + 1;
57     int str_h = jcp.stride_h;
58     int str_w = jcp.stride_w;
59
    // Builds the argument structure for one JIT kernel invocation covering
    // ur_w_step output pixels starting at (oh, ow) for channel block 'ch'
    // of batch image 'n'.  Horizontal clipping (iw/kw/kw_padding) is
    // recomputed here; vertical clipping (ih/kh/kh_padding) is passed in
    // by the caller, which computes it once per output row.
60     auto kernel_params = [&](int ur_w_step, int ow, int oh, int ih, int kh,
61             int kh_padding, int ch, int ch_num, int n) {
62         auto par_conv = jit_conv_call_s();
63
        // Input columns falling off the left / right edge for this output
        // position (both expressed as non-negative pixel counts).
64         const int i_l_overflow = nstl::max(0, (jcp.l_pad - ow * str_w));
65         const int i_r_overflow = nstl::max(jcp.iw, (ow * str_w
66             + (jcp.kw - 1)*dil_w - jcp.l_pad + 1)) - jcp.iw;
67
        // First valid input column and the matching first kernel column.
68         const int iw = nstl::max((ow*str_w - jcp.l_pad
69             + div_up(i_l_overflow, dil_w)*dil_w), 0);
70         const int kw = div_up(i_l_overflow, dil_w);
71
        // Number of kernel columns that land inside the input row.
72         const int kw_padding = jcp.kw - div_up(i_l_overflow, dil_w)
73             - div_up(i_r_overflow, dil_w);
74
75         par_conv.src = &src[src_d.blk_off(n, ch, ih, iw)];
76         par_conv.dst = &dst[dst_d.blk_off(n, ch, oh, ow)];
77
78         par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, kh, kw)];
79         if (bias) par_conv.bias = &bias[bias_d.blk_off(ch*jcp.ch_block)];
80
81         par_conv.kh_padding = (size_t)nstl::max(0, kh_padding);
82         par_conv.kw_padding = (size_t)nstl::max(0, kw_padding);
83
84         par_conv.ur_w = (size_t)ur_w_step;
85
        // Clamp the channel-block count at the tail of the channel range.
86         par_conv.ch_blocks = nstl::min(ch + ch_num, jcp.nb_ch) - ch;
87         par_conv.oc_off = ch * jcp.ch_block * sizeof(float);
88
89         return par_conv;
90     };
91
92     int MB = pd()->MB();
93     const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking);
94     parallel_nd(MB, chb_work, jcp.oh,
95             [&](int n, int chb, int oh) {
96         int ch = chb * jcp.nb_ch_blocking;
97         int ch_num = jcp.nb_ch_blocking;
98
        // Vertical clipping for this output row: kernel rows above the top
        // edge (i_t_overflow) and below the bottom edge (i_b_overflow).
99         const int i_t_overflow = nstl::max(0, (int)(jcp.t_pad - oh*str_h));
100         const int i_b_overflow = nstl::max(jcp.ih,
101             (int)(oh*str_h + (jcp.kh - 1)*dil_h - jcp.t_pad + 1)) - jcp.ih;
102
        // First valid input row, first kernel row, and the number of
        // kernel rows that stay inside the input for this output row.
103         const int ih = nstl::max((int)(oh*str_h - jcp.t_pad
104             + div_up(i_t_overflow, dil_h)*dil_h), 0);
105         const int kh = div_up(i_t_overflow, dil_h);
106         const int kh_padding = jcp.kh - div_up(i_t_overflow, dil_h)
107             - div_up(i_b_overflow, dil_h);
108
109         // left border
110         int ow = 0;
111         int l_border = nstl::min(div_up(jcp.l_pad, str_w), jcp.ow);
112         int ur_w_step = 1;
113         for (; ow < l_border; ow++) {
114             jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih,
115                                         kh, kh_padding, ch, ch_num, n);
116
117             kernel_->jit_ker(&par_conv);
118         }
119
120         // main loop
        // Widest run of output pixels whose receptive field lies fully
        // inside the input row, handled by a single kernel call.
121         ur_w_step = (jcp.iw - (jcp.kw - 1)*dil_w + jcp.l_pad - 1)
122             / jcp.stride_w - ow + 1;
123         if (ur_w_step > 0) {
124             jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih,
125                                         kh, kh_padding, ch, ch_num, n);
126
127             kernel_->jit_ker(&par_conv);
128
129             ow += ur_w_step;
130         }
131
132         // right border
133         ur_w_step = 1;
134         for (; ow < jcp.ow; ow++) {
135             jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih,
136                                         kh, kh_padding, ch, ch_num, n);
137
138             kernel_->jit_ker(&par_conv);
139         }
140     });
141
    // Zero the physically padded tail of the destination when the memory
    // format carries channel padding.
142     if (pd()->wants_zero_pad_dst())
143         output_memory_primitive(0)->zero_pad();
144 }
145
// Explicit instantiations for every supported ISA.
146 template struct _jit_uni_dw_convolution_fwd_t<avx512_common>;
147 template struct _jit_uni_dw_convolution_fwd_t<avx2>;
148 template struct _jit_uni_dw_convolution_fwd_t<sse42>;
149
// Backward-data depthwise convolution driver.
//
// Computes diff_src from diff_dst and the weights.  Work is partitioned
// with parallel_nd over (mini-batch, channel-block groups, input rows).
// Because of striding, each input row is processed in jcp.stride_w
// phases (i_str_w): within a phase, input columns iw advance by
// jcp.stride_w, again split into left border / main region / right
// border, each dispatched to the JIT kernel.
150 template <cpu_isa_t isa>
151 void _jit_uni_dw_convolution_bwd_data_t<isa>::execute_backward_data() const {
152     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
153     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
154     auto diff_src = reinterpret_cast<data_t *>(this->memory());
155
156     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
157     const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
158     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
159
160     const auto &jcp = kernel_->jcp;
161
    // Builds the JIT kernel arguments for ur_str_w strided input pixels
    // starting at input column iw of input row ih.  Vertical clipping
    // (i_t_overflow / i_b_overflow / stride_off_h) is computed once per
    // row by the caller; horizontal clipping is derived here.
162     auto kernel_params = [&](int ur_str_w, int iw, int oh, int ih,
163             int i_t_overflow, int i_b_overflow, int stride_off_h,
164             int ch, int ch_num, int n) {
165         auto par_conv = jit_conv_call_s();
166
        // Kernel columns clipped away on the left / right for this input
        // column (non-negative counts).
167         const int i_l_overflow = nstl::max(0, (jcp.kw - 1 - iw - jcp.l_pad));
168         const int i_r_overflow = nstl::max(0, (jcp.kw - 1 - (jcp.iw - 1 - iw)
169             - jcp.r_pad));
170
        // Map the input column back to the first contributing output
        // column; stride_off_w is the sub-stride remainder folded into
        // the filter offset below.
171         int ow = iw + jcp.l_pad - i_r_overflow;
172         int stride_off_w = ow % jcp.stride_w;
173         ow /= jcp.stride_w;
174
        // Note: for backward data, par_conv.src is the gradient being
        // written (diff_src) and par_conv.dst is the gradient being read
        // (diff_dst).
175         par_conv.src = &diff_src[diff_src_d.blk_off(n, ch, ih, iw)];
176         par_conv.dst = &diff_dst[diff_dst_d.blk_off(n, ch, oh, ow)];
177         par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, i_b_overflow
178             + stride_off_h, i_r_overflow + stride_off_w)];
179
        // Count of kernel rows / columns that actually participate.
180         par_conv.kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow
181             - stride_off_h);
182         par_conv.kw_padding = nstl::max(0, jcp.kw - i_l_overflow - i_r_overflow
183             - stride_off_w);
184
185         par_conv.ur_str_w = ur_str_w;
186
        // Clamp the channel-block count at the tail of the channel range.
187         par_conv.ch_blocks = nstl::min(ch + ch_num, jcp.nb_ch) - ch;
188
189         return par_conv;
190     };
191
192     int MB = pd()->MB();
193     const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking);
194     parallel_nd(MB, chb_work, jcp.ih,
195         [&](int n, int chb, int ih) {
196         int ch = chb * jcp.nb_ch_blocking;
197         int ch_num = jcp.nb_ch_blocking;
198
        // Kernel rows clipped at the top / bottom for this input row.
199         const int i_t_overflow = nstl::max(0, (int)(jcp.kh - 1 - ih
200             - jcp.t_pad));
201         const int i_b_overflow = nstl::max(0, (int)(jcp.kh - 1
202             - (jcp.ih - 1 - ih) - jcp.b_pad));
203
        // First contributing output row plus the sub-stride remainder
        // (folded into the filter row offset inside kernel_params).
204         int oh = ih + jcp.t_pad - i_b_overflow;
205         int stride_off_h = oh % jcp.stride_h;
206         oh /= jcp.stride_h;
207
        // One pass per horizontal stride phase; iw stays in the same
        // residue class modulo jcp.stride_w throughout a pass.
208         for (int i_str_w = 0; i_str_w < jcp.stride_w; i_str_w++) {
209             // left border
210             int iw = i_str_w;
211             int l_border = nstl::min(jcp.kw - 1 - jcp.l_pad, jcp.iw);
212             int ur_str_w = 1;
213             for (; iw < l_border; iw += jcp.stride_w) {
214                 jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh,
215                                              ih, i_t_overflow, i_b_overflow,
216                                              stride_off_h, ch, ch_num, n);
217
218                 kernel_->jit_ker(&par_conv);
219             }
220
221             // main loop
            // Longest run of strided input pixels with no horizontal
            // clipping, handled by a single kernel call.
222             ur_str_w = nstl::min((jcp.iw - jcp.kw + jcp.r_pad - iw)
223                  / jcp.stride_w, jcp.iw);
224             if (ur_str_w > 0) {
225                 jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh,
226                                              ih, i_t_overflow, i_b_overflow,
227                                              stride_off_h, ch, ch_num, n);
228
229                 kernel_->jit_ker(&par_conv);
230
231                 iw += ur_str_w * jcp.stride_w;
232             }
233
234             // right border
235             ur_str_w = 1;
236             for (; iw < jcp.iw; iw += jcp.stride_w) {
237                 jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh,
238                                              ih, i_t_overflow, i_b_overflow,
239                                              stride_off_h, ch, ch_num, n);
240
241                 kernel_->jit_ker(&par_conv);
242             }
243         }
244     });
245 }
246
// Explicit instantiations for every supported ISA.
247 template struct _jit_uni_dw_convolution_bwd_data_t<avx512_common>;
248 template struct _jit_uni_dw_convolution_bwd_data_t<avx2>;
249 template struct _jit_uni_dw_convolution_bwd_data_t<sse42>;
250
251 template <cpu_isa_t isa>
252 _jit_uni_dw_convolution_bwd_weights_t<isa>::
253 _jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd,
254         const input_vector &inputs, const output_vector &outputs)
255     : cpu_primitive_t(apd, inputs, outputs)
256     , kernel_(nullptr), acc_ker_(nullptr)
257 {
258     kernel_ = new jit_uni_dw_conv_bwd_weights_kernel_f32<isa>(pd()->jcp_);
259     if (pd()->jcp_.nthr_mb > 1 && do_parallel_reduction())
260         acc_ker_ = new cpu_accumulator_1d_t<data_type::f32>();
261 }
262
// Backward-weights depthwise convolution driver.
//
// Threads are split over channel groups (jcp.nthr_g) and minibatch
// slices (jcp.nthr_mb).  The mb-slice-0 thread of each group writes its
// partial gradients directly into diff_weights/diff_bias; every other
// mb slice writes into per-thread scratchpad reduction buffers.  The
// partial results are then summed either in parallel by the compute
// threads themselves (do_parallel_reduction(), barrier-synchronized) or
// single-threaded after the parallel region.  Bias reduction is always
// done in the single-threaded pass.
263 template <cpu_isa_t isa>
264 void _jit_uni_dw_convolution_bwd_weights_t<isa>::execute_backward_weights() const {
265     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
266     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
267     auto diff_weights = reinterpret_cast<data_t *>(this->memory(0));
268     auto diff_bias = reinterpret_cast<data_t *>(this->memory(1));
269
    // Scratch buffers holding the partial gradients of mb slices 1..n-1.
270     auto diff_wei_reduction_buf =
271         scratchpad().template get<data_t>(key_conv_wei_reduction);
272     auto diff_bia_reduction_buf =
273         scratchpad().template get<data_t>(key_conv_bia_reduction);
274
275     const auto &jcp = kernel_->jcp;
276
277     /* Used when executing a parallel reduction */
278     simple_barrier::ctx_t reduction_bctx;
279     simple_barrier::ctx_init(&reduction_bctx);
280
    // Stride (in elements) between consecutive per-thread copies in the
    // reduction buffers.
281     const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw;
282     const size_t bias_size = jcp.with_bias ? jcp.ngroups : 0;
283
284     const int ch_block = jcp.ch_block;
285
    // Fills *conv_params for one kernel call that processes 'work_size'
    // output rows starting at 'oh_start' for (batch, group).  'exec_flag'
    // carries the FLAG_ZERO_* bits telling the kernel to zero-initialize
    // its accumulation targets on the first call; 'filter_off' shifts the
    // filter pointer to skip kernel rows clipped by top padding.
286     auto set_kernel_params = [&](jit_dw_conv_call_s *conv_params,
287             const int batch, const int group, const int oh_start,
288             const int work_size, const unsigned char exec_flag,
289             const size_t kh_padding, const size_t filter_off) {
290         const int tpad_underflow_off = jcp.t_pad - filter_off;
291
292         conv_params->exec_flags = exec_flag;
293         conv_params->kh_count = jcp.kh - kh_padding;
294
295         const int oh_s = oh_start;
296         const int oh_e = oh_start + work_size;
297         const int ih_s = oh_s * jcp.stride_h;
298
299         conv_params->filter_pad_off
300                 = filter_off * jcp.kw * ch_block * sizeof(float);
301         conv_params->oh_index = oh_s;
302         conv_params->oh_count = oh_e;
303
        // Flat element offsets assuming blocked layouts with ch_block
        // channels per block (multiplied out below).
304         size_t diff_dst_off
305                 = ((batch * (jcp.ngroups / ch_block) + group) * jcp.oh
306                           + oh_start)
307                 * jcp.ow;
308
309         size_t src_off = ((batch * (jcp.ngroups / ch_block) + group) * jcp.ih
310                 + ih_s - tpad_underflow_off) * jcp.iw;
311
312         conv_params->output = &diff_dst[diff_dst_off * ch_block];
313         conv_params->input = &src[src_off * ch_block];
314     };
315
316     parallel(jcp.nthr, [&](const int ithr, const int nthr) {
317         assert(nthr == jcp.nthr);
318
319         auto conv_params = jit_dw_conv_call_s();
        // Output rows handled per kernel invocation.
320         const int h_block_size = 15;
321
322         /* assign iteration space to thread */
323         const int ithr_g = ithr % jcp.nthr_g;
324         const int ithr_mb = (ithr / jcp.nthr_g) % jcp.nthr_mb;
325
326         /* split dimensions */
327         int g_start{ 0 }, g_end{ 0 };
328         balance211(jcp.nb_ch, jcp.nthr_g, ithr_g, g_start, g_end);
329
330         int mb_start{ 0 }, mb_end{ 0 };
331         balance211(jcp.mb, jcp.nthr_mb, ithr_mb, mb_start, mb_end);
332
        // mb slice 0 accumulates straight into the final gradients; the
        // others get a private slot in the reduction scratchpad.
333         auto diff_wei = ithr_mb == 0
334             ? diff_weights : diff_wei_reduction_buf + (ithr_mb - 1) * wei_size;
335         auto diff_bia = ithr_mb == 0
336             ? diff_bias : diff_bia_reduction_buf + (ithr_mb - 1) * bias_size;
337
338         for (int g = g_start; g < g_end; ++g) {
            // Zero-init flags are set for the first kernel call of each
            // group and cleared afterwards, so later calls accumulate.
339             unsigned char zero_filter_flag = FLAG_ZERO_FILTER;
340             unsigned char zero_bias_flag = jcp.with_bias ? FLAG_ZERO_BIAS : 0;
341
342             size_t diff_wei_off = g * jcp.kh * jcp.kw;
343             conv_params.filter = &diff_wei[diff_wei_off * ch_block];
344
345             if (jcp.with_bias)
346                 conv_params.bias = &diff_bia[g * ch_block];
347
348             for (int mb = mb_start; mb < mb_end; ++mb) {
349                 int oh = 0;
350                 while (oh < jcp.oh) {
                    // Process output rows in chunks of h_block_size,
                    // clipping kernel rows against top/bottom padding.
351                     const int h_work = nstl::min(h_block_size, jcp.oh - oh);
352                     auto kh_t_padding = nstl::max(0, jcp.t_pad - oh);
353                     auto kh_b_padding
354                             = (oh * jcp.stride_h + jcp.kh - 1 > jcp.ih) ?
355                             jcp.b_pad - (h_work - 1) :
356                             0;
357
358                     set_kernel_params(&conv_params, mb, g, oh, h_work,
359                             zero_filter_flag | zero_bias_flag,
360                             kh_t_padding + kh_b_padding, kh_t_padding);
361                     kernel_->jit_ker(&conv_params);
362
363                     zero_bias_flag &= ~FLAG_ZERO_BIAS;
364                     zero_filter_flag &= ~FLAG_ZERO_FILTER;
365                     oh += h_work;
366                 }
367             }
368         }
369
        // Parallel weight reduction: every thread waits until all partial
        // gradients exist, then accumulates its own contiguous slice of
        // the weight tensor from each mb slice's buffer.
370         if (do_parallel_reduction() && jcp.nthr_mb > 1) {
371             size_t reduct_start{ 0 }, reduct_end{ 0 };
372             balance211(wei_size, nthr, ithr, reduct_start, reduct_end);
373
374             const int acc_size = reduct_end - reduct_start;
375             const size_t reduct_off = reduct_start;
376             auto *acc_data = diff_weights + reduct_off;
377
378             simple_barrier::barrier(&reduction_bctx, nthr);
379
380             for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) {
381                 auto *src_data = diff_wei_reduction_buf
382                         + (thr_mb - 1) * wei_size + reduct_off;
383                 acc_ker_->accumulate(acc_data, src_data, acc_size);
384             }
385         }
386     });
387
    // Nothing left to reduce when a single mb slice produced the result.
388     if (jcp.nthr_mb <= 1) return;
389
390     /* Apply single-threaded 'mb' reduction */
391     for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) {
392         size_t mb_accum_offset = (thr_mb - 1) * wei_size;
393         size_t b_accum_offset = (thr_mb - 1) * bias_size;
394
395         for (int g = 0; g < jcp.nb_ch; ++g) {
396             /* Reduction on Bias */
397             if (jcp.with_bias) {
398                 PRAGMA_OMP_SIMD()
399                 for (int g_block = 0; g_block < ch_block; ++g_block) {
400                     size_t bias_offset = g * ch_block + g_block;
401                     diff_bias[bias_offset] += diff_bia_reduction_buf[
402                         b_accum_offset + bias_offset];
403                 }
404             }
405
            // Weight reduction was already performed inside the parallel
            // region in that case; only the bias still needed summing.
406             if (do_parallel_reduction()) continue;
407
408             for (int kh = 0; kh < jcp.kh; ++kh)
409             for (int kw = 0; kw < jcp.kw; ++kw)
410             {
411                 size_t wei_offset = (g * jcp.kh + kh) * jcp.kw + kw;
412                 PRAGMA_OMP_SIMD()
413                 for (int g_block = 0; g_block < ch_block; ++g_block) {
414                     const size_t off = wei_offset * ch_block + g_block;
415                     diff_weights[off] +=
416                         diff_wei_reduction_buf[mb_accum_offset + off];
417                 }
418             }
419         }
420     }
421 }
422
// Explicit instantiations for every supported ISA.
423 template struct _jit_uni_dw_convolution_bwd_weights_t<avx512_common>;
424 template struct _jit_uni_dw_convolution_bwd_weights_t<avx2>;
425 template struct _jit_uni_dw_convolution_bwd_weights_t<sse42>;
426
427 }
428 }
429 }