Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / nchw_pooling.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18 #include <math.h>
19
20 #include "c_types_map.hpp"
21 #include "type_helpers.hpp"
22 #include "math_utils.hpp"
23 #include "mkldnn_thread.hpp"
24 #include "nstl.hpp"
25
26 #include "nchw_pooling.hpp"
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 template <impl::data_type_t data_type>
33 void nchw_pooling_fwd_t<data_type>::execute_forward() const {
34     using namespace alg_kind;
35
36     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
37     auto dst = reinterpret_cast<data_t*>(this->memory(0));
38     auto ws = pd()->desc()->alg_kind == alg_kind::pooling_max ?
39         reinterpret_cast<unsigned char *>(this->memory(1)) : nullptr;
40
41     const memory_desc_wrapper ws_d(pd()->workspace_pd());
42     const memory_desc_wrapper src_d(pd()->src_pd());
43     const memory_desc_wrapper dst_d(pd()->dst_pd());
44     const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
45
46     src += src_d.off_l(0);
47     dst += dst_d.off_l(0);
48
49     const int MB = pd()->MB();
50     const int C = pd()->C();
51     const int OD = pd()->OD();
52     const int OH = pd()->OH();
53     const int OW = pd()->OW();
54     const int ID = pd()->ID();
55     const int IH = pd()->IH();
56     const int IW = pd()->IW();
57     const int KD = pd()->KD();
58     const int KH = pd()->KH();
59     const int KW = pd()->KW();
60     const int SD = pd()->KSD();
61     const int SH = pd()->KSH();
62     const int SW = pd()->KSW();
63     const int padF = pd()->padFront();
64     const int padT = pd()->padT();
65     const int padL = pd()->padL();
66     const int padBack = pd()->padBack();
67     const int padB = pd()->padB();
68     const int padR = pd()->padR();
69
70     auto alg = pd()->desc()->alg_kind;
71     
72     auto set_ws = [=](int mb, int c, int od, int oh, int ow, int value) {
73         if (ws) {
74             assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
75             size_t ws_offset
76                 = (size_t)OW * OH * OD * C * mb
77                 + (size_t)OW * OH * OD * c
78                 + (size_t)OW * OH * od
79                 + (size_t)OW * oh
80                 + (size_t)ow;
81             if (ws_dt == data_type::u8) {
82                 assert(0 <= value && value <= 255);
83                 ws[ws_offset] = value;
84             } else
85                 reinterpret_cast<int *>(ws)[ws_offset] = value;
86         }
87     };
88
89     auto ker_max = [=](data_t *d, int mb, int c, int od, int oh, int ow) {
90         bool is_initialized = false;
91         for (int kd = 0; kd < KD; ++kd) {
92             for (int kh = 0; kh < KH; ++kh) {
93                 for (int kw = 0; kw < KW; ++kw) {
94                     const int id = od * SD - padF + kd;
95                     const int ih = oh * SH - padT + kh;
96                     const int iw = ow * SW - padL + kw;
97
98                     if (id < 0 || id >= ID) continue;
99                     if (ih < 0 || ih >= IH) continue;
100                     if (iw < 0 || iw >= IW) continue;
101
102                     auto src_offset
103                         = (size_t)IW * IH * ID * C * mb
104                         + (size_t)IW * IH * ID * c
105                         + (size_t)IW * IH * id
106                         + (size_t)IW * ih
107                         + (size_t)iw;
108                     auto s = src[src_offset];
109                     if (!is_initialized) {
110                         d[0] = s;
111                         set_ws(mb, c, od, oh, ow, kd*KH*KW + kh*KW + kw);
112                         is_initialized = true;
113                     } else {
114                         if (d[0] < s)
115                             d[0] = s;
116                             set_ws(mb, c, od, oh, ow, kd*KH*KW + kh*KW + kw);
117                     }
118                 }
119             }
120         }
121     };
122
123     auto ker_avg = [=](data_t *d, int mb, int c, int od, int oh, int ow) {
124         auto id_start = od*SD - padF;
125         auto ih_start = oh*SH - padT;
126         auto iw_start = ow*SW - padL;
127         auto id_end = nstl::min(od*SD - padF + KD, ID + padBack);
128         auto ih_end = nstl::min(oh*SH - padT + KH, IH + padB);
129         auto iw_end = nstl::min(ow*SW - padL + KW, IW + padR);
130
131         // case alg == pooling_avg_include_padding
132         auto num_summands = (id_end - id_start)*(ih_end - ih_start)*(iw_end - iw_start);
133
134         id_start = nstl::max(id_start, 0);
135         ih_start = nstl::max(ih_start, 0);
136         iw_start = nstl::max(iw_start, 0);
137         id_end = nstl::min(id_end, ID);
138         ih_end = nstl::min(ih_end, IH);
139         iw_end = nstl::min(iw_end, IW);
140
141         if (alg == pooling_avg_exclude_padding)
142             num_summands = (id_end - id_start)*(ih_end - ih_start)*(iw_end - iw_start);
143         if (num_summands == 0) return;
144
145         for (int id = id_start; id < id_end; ++id) {
146             for (int ih = ih_start; ih < ih_end; ++ih) {
147                 for (int iw = iw_start; iw < iw_end; ++iw) {
148                     auto src_offset
149                         = (size_t)IW * IH * ID * C * mb
150                         + (size_t)IW * IH * ID * c
151                         + (size_t)IW * IH * id
152                         + (size_t)IW * ih
153                         + (size_t)iw;
154                     d[0] += src[src_offset];
155                 }
156             }
157         }
158
159         d[0] = math::out_round<data_t>((float)d[0] / num_summands);
160     };
161
162
163     if (pd()->desc()->alg_kind == pooling_max) {
164         parallel_nd(MB, C, OD, OH, OW,
165             [&](int mb, int c, int od, int oh, int ow) {
166             size_t dst_offset
167                 = (size_t)OW * OH * OD * C * mb
168                 + (size_t)OW * OH * OD * c
169                 + (size_t)OW * OH * od
170                 + (size_t)OW * oh
171                 + (size_t)ow;
172             data_t *d = &dst[dst_offset];
173             d[0] = (data_t)0;
174             set_ws(mb, c, od, oh, ow, 0);
175             ker_max(d, mb, c, od, oh, ow);
176         });
177     } else {
178         parallel_nd(MB, C, OD, OH, OW,
179             [&](int mb, int c, int od, int oh, int ow) {
180             size_t dst_offset
181                 = (size_t)OW * OH * OD * C * mb
182                 + (size_t)OW * OH * OD * c
183                 + (size_t)OW * OH * od
184                 + (size_t)OW * oh
185                 + (size_t)ow;
186             data_t *d = &dst[dst_offset];
187             d[0] = (data_t)0;
188             ker_avg(d, mb, c, od, oh, ow);
189         });
190     }
191 }
192
193 template <impl::data_type_t data_type>
194 void nchw_pooling_bwd_t<data_type>::execute_backward() const {
195     using namespace alg_kind;
196
197     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
198     auto ws = pd()->desc()->alg_kind != alg_kind::pooling_max ? nullptr :
199         reinterpret_cast<const unsigned char *>(this->input_memory(1));
200     auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
201
202     const memory_desc_wrapper ws_d(pd()->workspace_pd());
203
204     const int MB = pd()->MB();
205     const int C = pd()->C();
206     const int OD = pd()->OD();
207     const int OH = pd()->OH();
208     const int OW = pd()->OW();
209     const int ID = pd()->ID();
210     const int IH = pd()->IH();
211     const int IW = pd()->IW();
212     const int KD = pd()->KD();
213     const int KH = pd()->KH();
214     const int KW = pd()->KW();
215     const int SD = pd()->KSD();
216     const int SH = pd()->KSH();
217     const int SW = pd()->KSW();
218     const int padF = pd()->padFront();
219     const int padT = pd()->padT();
220     const int padL = pd()->padL();
221
222     const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
223
224     auto alg = pd()->desc()->alg_kind;
225
226     auto apply_offset = [=](int index, int offset) {
227         return (index > offset) ? index - offset : 0;
228     };
229
230     auto ker_zero = [=](int mb, int c) {
231         size_t diff_src_offset = (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW;
232         for (int id = 0; id < ID; ++id) {
233             for (int ih = 0; ih < IH; ++ih) {
234                 for (int iw = 0; iw < IW; ++iw) {
235                     diff_src[diff_src_offset++] = 0;
236                 }
237             }
238         }
239     };
240
241     auto ker_max = [=](const data_t *d, int mb, int c, int od, int oh, int ow) {
242         auto b_c = ws_d.blocking_desc().block_dims[1];
243         auto ws_offset = is_3d
244             ? ws_d.blk_off(mb, c / b_c, od, oh, ow) + c % b_c
245             : ws_d.blk_off(mb, c / b_c, oh, ow) + c % b_c;
246
247         const int index = ws_d.data_type() == data_type::u8
248             ? (int)ws[ws_offset] : ((const int *)ws)[ws_offset];
249         const int kw = index % KW;
250         const int kh = (index / KW) % KH;
251         const int kd = (index / KW) / KH;
252
253         const int id = od * SD - padF + kd;
254         const int ih = oh * SH - padT + kh;
255         const int iw = ow * SW - padL + kw;
256
257         // If padding area could fit the kernel,
258         // then input displacement would be out of bounds.
259         // No need to back propagate there as padding is
260         // virtual in pooling_max case.
261         if (id < 0 || id >= ID)
262             return;
263         if (ih < 0 || ih >= IH)
264             return;
265         if (iw < 0 || iw >= IW)
266             return;
267
268         size_t diff_src_offset =
269             (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW + (size_t)id*IH*IW
270             + (size_t)ih*IW + (size_t)iw;
271         diff_src[diff_src_offset] += d[0];
272     };
273
274     auto ker_avg = [=](const data_t *d, int mb, int c, int od, int oh, int ow) {
275         auto id_start = apply_offset(od*SD, padF);
276         auto ih_start = apply_offset(oh*SH, padT);
277         auto iw_start = apply_offset(ow*SW, padL);
278         auto id_end = nstl::min(od*SD - padF + KD, ID);
279         auto ih_end = nstl::min(oh*SH - padT + KH, IH);
280         auto iw_end = nstl::min(ow*SW - padL + KW, IW);
281
282         size_t num_summands = (alg == pooling_avg_include_padding)
283             ? (size_t)KW*KH*KD
284             : (size_t)(id_end - id_start)*(ih_end - ih_start)
285                 *(iw_end - iw_start);
286
287         for (int id = id_start; id < id_end; ++id) {
288             for (int ih = ih_start; ih < ih_end; ++ih) {
289                 for (int iw = iw_start; iw < iw_end; ++iw) {
290                     size_t diff_src_offset = (size_t)mb*C*ID*IH*IW
291                         + (size_t)c*ID*IH*IW + (size_t)id*IH*IW
292                         + (size_t)ih*IW + (size_t)iw;
293                     diff_src[diff_src_offset] += d[0] / num_summands;
294                 }
295             }
296         }
297     };
298
299     if (pd()->desc()->alg_kind == pooling_max) {
300         parallel_nd(MB, C, [&](int mb, int c) {
301             size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW
302                 + (size_t)c*OD*OH*OW;
303             ker_zero(mb, c);
304             for (int od = 0; od < OD; ++od) {
305                 for (int oh = 0; oh < OH; ++oh) {
306                     for (int ow = 0; ow < OW; ++ow) {
307                         const data_t *d = &diff_dst[diff_dst_offset++];
308                         ker_max(d, mb, c, od, oh, ow);
309                     }
310                 }
311             }
312         });
313     } else {
314         parallel_nd(MB, C, [&](int mb, int c) {
315             size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW
316                 + (size_t)c*OD*OH*OW;
317             ker_zero(mb, c);
318             for (int od = 0; od < OD; ++od) {
319                 for (int oh = 0; oh < OH; ++oh) {
320                     for (int ow = 0; ow < OW; ++ow) {
321                         const data_t *d = &diff_dst[diff_dst_offset++];
322                         ker_avg(d, mb, c, od, oh, ow);
323                     }
324                 }
325             }
326         });
327     }
328 }
329
330 template struct nchw_pooling_fwd_t<data_type::f32>;
331 template struct nchw_pooling_bwd_t<data_type::f32>;
332
333 }
334 }
335 }
336
337 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s