Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / nhwc_pooling.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18 #include <math.h>
19
20 #include "c_types_map.hpp"
21 #include "type_helpers.hpp"
22 #include "math_utils.hpp"
23 #include "mkldnn_thread.hpp"
24 #include "nstl.hpp"
25
26 #include "nhwc_pooling.hpp"
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 #define MEM_D(name) name##_d
33
34 #define DECLARE_READ_STRIDES(name)                                             \
35     const size_t name##_n_stride = MEM_D(name).blocking_desc().strides[0][0];  \
36     const size_t name##_d_stride = (!is_3d)                                    \
37                                  ? 0                                           \
38                                  : MEM_D(name).blocking_desc().strides[0][2];  \
39     const size_t name##_h_stride = (!is_3d)                                    \
40                                  ? MEM_D(name).blocking_desc().strides[0][2]   \
41                                  : MEM_D(name).blocking_desc().strides[0][3];  \
42     const size_t name##_w_stride = (!is_3d)                                    \
43                                  ? MEM_D(name).blocking_desc().strides[0][3]   \
44                                  : MEM_D(name).blocking_desc().strides[0][4];
45
46 namespace nhwc_pooling {
47     size_t strided_offset(const int _n, const size_t _sn,
48                           const int _d, const size_t _sd,
49                           const int _h, const size_t _sh,
50                           const int _w, const size_t _sw)
51     {
52         return   _n * _sn
53                + _d * _sd
54                + _h * _sh
55                + _w * _sw;
56     }
57 }
58
59 template <impl::data_type_t data_type>
60 void nhwc_pooling_fwd_t<data_type>::array_div_by_const(const int n,
61         const data_t *src, const size_t num, data_t *dst) const
62 {
63     for (int i = 0; i < n; ++i)
64     {
65         float ftmp = (float)src[i];
66         ftmp = ftmp / num;
67         dst[i] = math::out_round<data_t>(ftmp);
68     }
69 }
70
71 template <impl::data_type_t data_type>
72 void nhwc_pooling_fwd_t<data_type>::array_add(const int n, const data_t *src,
73         data_t *dst) const
74 {
75     for (int i = 0;  i < n; ++i)
76     {
77         dst[i] += src[i];
78     }
79 }
80
81 template <impl::data_type_t data_type>
82 void nhwc_pooling_fwd_t<data_type>::execute_forward() const {
83     using namespace alg_kind;
84     using namespace prop_kind;
85     using namespace nhwc_pooling;
86
87     auto alg = pd()->desc()->alg_kind;
88
89     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
90     auto dst = reinterpret_cast<data_t *>(this->memory(0));
91     unsigned char * ws = reinterpret_cast<unsigned char *>(
92                   alg == pooling_max
93                       && pd()->desc()->prop_kind == forward_training ?
94                   this->memory(1) : nullptr
95               );
96
97     const memory_desc_wrapper MEM_D(dst)(pd()->dst_pd());
98     const memory_desc_wrapper MEM_D(ws)(pd()->workspace_pd());
99     const memory_desc_wrapper MEM_D(src)(pd()->src_pd());
100
101     const int ID = pd()->ID();
102     const int IH = pd()->IH();
103     const int IW = pd()->IW();
104     const int KD = pd()->KD();
105     const int KH = pd()->KH();
106     const int KW = pd()->KW();
107     const int SD = pd()->KSD();
108     const int SH = pd()->KSH();
109     const int SW = pd()->KSW();
110     const int padF = pd()->padFront();
111     const int padT = pd()->padT();
112     const int padL = pd()->padL();
113     const int MB = pd()->MB();
114     const int OC = pd()->C();
115     const int OD = pd()->OD();
116     const int OH = pd()->OH();
117     const int OW = pd()->OW();
118
119     const bool is_3d = pd()->desc()->src_desc.ndims == 5;
120     const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
121
122     DECLARE_READ_STRIDES(src);
123     DECLARE_READ_STRIDES(dst);
124
125     auto apply_offset = [=](int index, int offset) {
126         return (index > offset) ? index - offset : 0;
127     };
128
129     parallel_nd(MB, OD, OH, OW,
130         [&](int mb, int od, int oh, int ow) {
131         size_t dst_offset_init = strided_offset(mb, dst_n_stride,
132                                                 od, dst_d_stride,
133                                                 oh, dst_h_stride,
134                                                 ow, dst_w_stride);
135         if (alg == pooling_max) {
136             size_t ws_offset_init = 0;
137             if (ws)
138             {
139                 DECLARE_READ_STRIDES(ws);
140                 ws_offset_init = strided_offset(mb, ws_n_stride,
141                                                 od, ws_d_stride,
142                                                 oh, ws_h_stride,
143                                                 ow, ws_w_stride);
144             }
145             // Note: GCC 4.8.5 won't vectorize below
146             // simple loops unless they are singled out
147             // into separate helper routines:
148             //    array_nhwc_initialize, array_nhwc_max
149             if (!ws)
150                 array_nhwc_initialize<false>(OC, dst + dst_offset_init,
151                                     ws, ws_offset_init, ws_dt);
152             else
153                 array_nhwc_initialize<true>(OC, dst + dst_offset_init,
154                                     ws, ws_offset_init, ws_dt);
155
156
157             for (int kd = 0; kd < KD; ++kd)
158             for (int kh = 0; kh < KH; ++kh)
159             for (int kw = 0; kw < KW; ++kw) {
160                 const int id = od * SD - padF + kd;
161                 const int ih = oh * SH - padT + kh;
162                 const int iw = ow * SW - padL + kw;
163
164                 if (id < 0 || id >= ID)
165                     continue;
166                 if (ih < 0 || ih >= IH)
167                     continue;
168                 if (iw < 0 || iw >= IW)
169                     continue;
170
171                 size_t src_offset_init = strided_offset(mb, src_n_stride,
172                                                         id, src_d_stride,
173                                                         ih, src_h_stride,
174                                                         iw, src_w_stride);
175
176                 if (!ws)
177                     array_nhwc_max<false>(OC,
178                        dst + dst_offset_init,
179                        src + src_offset_init,
180                        ws, ws_offset_init,
181                        ws_dt,
182                        kd * KH * KW + kh * KW + kw
183                     );
184                 else
185                     array_nhwc_max<true>(OC,
186                        dst + dst_offset_init,
187                        src + src_offset_init,
188                        ws, ws_offset_init,
189                        ws_dt,
190                        kd * KH * KW + kh * KW + kw
191                     );
192             }
193         } else {
194             // pooling_avg
195             auto d = dst + dst_offset_init;
196
197             utils::array_set(d, 0, OC);
198
199             auto id_start = apply_offset(od * SD, padF);
200             auto ih_start = apply_offset(oh * SH, padT);
201             auto iw_start = apply_offset(ow * SW, padL);
202             auto id_end = nstl::min(od * SD - padF + KD, ID);
203             auto ih_end = nstl::min(oh * SH - padT + KH, IH);
204             auto iw_end = nstl::min(ow * SW - padL + KW, IW);
205
206             // it is cheaper to actually count this in a loop
207             // as the typical kernel is small
208             size_t num_summands = 0;
209
210             for (int id = id_start; id < id_end; ++id)
211             for (int ih = ih_start; ih < ih_end; ++ih)
212             for (int iw = iw_start; iw < iw_end; ++iw) {
213                 size_t src_offset_init = strided_offset(mb, src_n_stride,
214                                                         id, src_d_stride,
215                                                         ih, src_h_stride,
216                                                         iw, src_w_stride);
217                 auto s = src + src_offset_init;
218
219                 // need to move the loop to separate function
220                 // for GCC 4.8.5 to vectorize
221                 array_add(OC, s, d);
222
223                 num_summands++;
224             }
225
226             num_summands = (alg == pooling_avg_include_padding) ?
227                     KW * KH * KD : num_summands;
228
229             // need to move the loop to separate function
230             // for GCC 4.8.5 to vectorize
231             array_div_by_const(OC, d, num_summands, d);
232         }
233     });
234 }
235
236 template <impl::data_type_t data_type>
237 void nhwc_pooling_bwd_t<data_type>::execute_backward() const {
238     using namespace alg_kind;
239     using namespace nhwc_pooling;
240
241     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
242     auto ws = pd()->desc()->alg_kind != alg_kind::pooling_max ? nullptr
243         : reinterpret_cast<const unsigned char *>(this->input_memory(1));
244     auto diff_src = reinterpret_cast<data_t *>(this->memory(0));
245
246     const memory_desc_wrapper MEM_D(diff_dst)(pd()->diff_dst_pd());
247     const memory_desc_wrapper MEM_D(ws)(pd()->workspace_pd());
248     const memory_desc_wrapper MEM_D(diff_src)(pd()->diff_src_pd());
249
250     const int ID = pd()->ID();
251     const int IH = pd()->IH();
252     const int IW = pd()->IW();
253     const int KD = pd()->KD();
254     const int KH = pd()->KH();
255     const int KW = pd()->KW();
256     const int SD = pd()->KSD();
257     const int SH = pd()->KSH();
258     const int SW = pd()->KSW();
259     const int OC = pd()->C();
260     const int padF = pd()->padFront();
261     const int padT = pd()->padT();
262     const int padL = pd()->padL();
263     const int OD = pd()->OD();
264     const int OH = pd()->OH();
265     const int OW = pd()->OW();
266
267     const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
268     auto alg = pd()->desc()->alg_kind;
269
270     DECLARE_READ_STRIDES(diff_src);
271     DECLARE_READ_STRIDES(diff_dst);
272
273     auto apply_offset = [=](int index, int offset) {
274         return (index > offset) ? index - offset : 0;
275     };
276
277     const int MB = pd()->MB();
278
279     parallel_nd(MB, ID, IH, IW,
280         [&](int mb, int id, int ih, int iw) {
281         size_t src_offset_init = strided_offset(mb, diff_src_n_stride,
282                                                 id, diff_src_d_stride,
283                                                 ih, diff_src_h_stride,
284                                                 iw, diff_src_w_stride);
285
286         // check if kernel windows are disjoint, in this case there's no
287         // update needed and we just write there once, no initialization
288         // required.
289         if (!(KD == SD && KH == SH && KW == SW))
290             for (int oc = 0; oc < OC; ++oc)
291                 diff_src[src_offset_init + oc] = data_type_t(0);
292
293         // Find out which output cells may correspond to current
294         // input position. Current input postition divided by
295         // stride, with integer divide rounding down, is the
296         // right-most output.
297         // Left-most output may be computed if we decrement input
298         // by (kernel_size - 1) and then do the same division by
299         // stride.
300         int od_left  = nstl::max((id + padF - KD + 1) / SD,  0);
301         int oh_left  = nstl::max((ih + padT - KH + 1) / SH,  0);
302         int ow_left  = nstl::max((iw + padL - KW + 1) / SW,  0);
303         // Notice +1 here to preserve the C loop "less than"
304         // condition for continuing the for loop.
305         int od_right = nstl::min((id + padF) / SD + 1     , OD);
306         int oh_right = nstl::min((ih + padT) / SH + 1     , OH);
307         int ow_right = nstl::min((iw + padL) / SW + 1     , OW);
308
309         for (int od = od_left; od < od_right; ++od)
310         for (int oh = oh_left; oh < oh_right; ++oh)
311         for (int ow = ow_left; ow < ow_right; ++ow) {
312             const int kd = id - od*SD + padF;
313             const int kh = ih - oh*SH + padT;
314             const int kw = iw - ow*SW + padL;
315
316             if (kd < 0 || kd >= KD)
317                 continue;
318             if (kh < 0 || kh >= KH)
319                 continue;
320             if (kw < 0 || kw >= KW)
321                 continue;
322
323             size_t dst_offset_init = strided_offset(mb, diff_dst_n_stride,
324                                                     od, diff_dst_d_stride,
325                                                     oh, diff_dst_h_stride,
326                                                     ow, diff_dst_w_stride);
327
328             if (alg == pooling_max) {
329                 DECLARE_READ_STRIDES(ws);
330                 size_t ws_offset_init = strided_offset(mb, ws_n_stride,
331                                                        od, ws_d_stride,
332                                                        oh, ws_h_stride,
333                                                        ow, ws_w_stride);
334                 const int index = kd * KH * KW + kh * KW + kw;
335
336                 PRAGMA_OMP_SIMD()
337                 for (int oc = 0; oc < OC; ++oc) {
338                     const int index_from_ws =
339                                     (MEM_D(ws).data_type() == data_type::u8)
340                                     ? (int)ws[ws_offset_init + oc]
341                                     : ((int *)ws)[ws_offset_init + oc];
342
343                     const data_t d = diff_dst[dst_offset_init + oc];
344
345                     // Check if kernel windows are disjoint, in this case
346                     // there's no update needed and we just write there once
347                     // otherwise we add value to the contents.
348                     if (!(KD == SD && KH == SH && KW == SW))
349                         diff_src[src_offset_init + oc] +=
350                                                    (index_from_ws == index)
351                                                    ? d
352                                                    : data_type_t(0);
353                     else
354                         diff_src[src_offset_init + oc] =
355                                                    (index_from_ws == index)
356                                                    ? d
357                                                    : data_type_t(0);
358                 }
359             } else {
360                 // pooling_avg
361                 auto id_start = apply_offset(od*SD, padF);
362                 auto ih_start = apply_offset(oh*SH, padT);
363                 auto iw_start = apply_offset(ow*SW, padL);
364                 auto id_end = nstl::min(od*SD - padF + KD, ID);
365                 auto ih_end = nstl::min(oh*SH - padT + KH, IH);
366                 auto iw_end = nstl::min(ow*SW - padL + KW, IW);
367
368                 auto num_summands = (alg == pooling_avg_include_padding)
369                   ? KW*KH*KD
370                   : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start);
371
372                 PRAGMA_OMP_SIMD()
373                 for (int oc = 0; oc < OC; ++oc) {
374                     const data_t d = diff_dst[dst_offset_init + oc];
375                     // Check if kernel windows are disjoint, in this case
376                     // there's no update needed and we just write there once
377                     // otherwise we add value to the contents.
378                     if (!(KD == SD && KH == SH && KW == SW))
379                       diff_src[src_offset_init + oc] += d / num_summands;
380                     else
381                       diff_src[src_offset_init + oc] = d / num_summands;
382                 }
383             }
384         }
385     });
386 }
387
388 template struct nhwc_pooling_fwd_t<data_type::f32>;
389 template struct nhwc_pooling_bwd_t<data_type::f32>;
390
391 }
392 }
393 }
394
395 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s