1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
20 #include "c_types_map.hpp"
21 #include "type_helpers.hpp"
22 #include "math_utils.hpp"
23 #include "mkldnn_thread.hpp"
26 #include "nhwc_pooling.hpp"
// Maps a tensor name to the identifier of its memory descriptor wrapper by
// appending the `_d` suffix, e.g. MEM_D(src) expands to src_d.
32 #define MEM_D(name) name##_d
// Declares the const size_t locals <name>_n_stride, <name>_d_stride,
// <name>_h_stride and <name>_w_stride, all read from
// MEM_D(name).blocking_desc().strides[0].
// A boolean `is_3d` must be in scope where the macro is expanded:
//  - the n stride always comes from index 0;
//  - when !is_3d, the h and w strides come from indices 2 and 3;
//  - when is_3d, the d, h and w strides come from indices 2, 3 and 4.
34 #define DECLARE_READ_STRIDES(name) \
35 const size_t name##_n_stride = MEM_D(name).blocking_desc().strides[0][0]; \
36 const size_t name##_d_stride = (!is_3d) \
38 : MEM_D(name).blocking_desc().strides[0][2]; \
39 const size_t name##_h_stride = (!is_3d) \
40 ? MEM_D(name).blocking_desc().strides[0][2] \
41 : MEM_D(name).blocking_desc().strides[0][3]; \
42 const size_t name##_w_stride = (!is_3d) \
43 ? MEM_D(name).blocking_desc().strides[0][3] \
44 : MEM_D(name).blocking_desc().strides[0][4];
46 namespace nhwc_pooling {
// Offset helper that takes an (index, stride) pair for each of the n, d, h
// and w dimensions and returns a size_t.
// NOTE(review): presumably returns the sum of index * stride over the four
// pairs - confirm against the function body.
47 size_t strided_offset(const int _n, const size_t _sn,
48 const int _d, const size_t _sd,
49 const int _h, const size_t _sh,
50 const int _w, const size_t _sw)
// Walks the first n elements of src: each one is converted to float and the
// result of math::out_round<data_t> is stored at the same index of dst.
// NOTE(review): the function name and the `num` parameter suggest the float
// value is divided by num before rounding - confirm.
59 template <impl::data_type_t data_type>
60 void nhwc_pooling_fwd_t<data_type>::array_div_by_const(const int n,
61 const data_t *src, const size_t num, data_t *dst) const
63 for (int i = 0; i < n; ++i)
65 float ftmp = (float)src[i];
67 dst[i] = math::out_round<data_t>(ftmp);
// Loops over n elements, with src as the read-only input array.
// NOTE(review): presumably adds src element-wise into a destination array -
// confirm against the remaining parameters and the loop body.
71 template <impl::data_type_t data_type>
72 void nhwc_pooling_fwd_t<data_type>::array_add(const int n, const data_t *src,
75 for (int i = 0; i < n; ++i)
// Forward pooling pass. Work is split with parallel_nd over every output
// position (mb, od, oh, ow); each invocation works on the output elements
// starting at dst + dst_offset_init and passes OC as the element count to
// the helper routines.
//  - alg == pooling_max: the output (together with the workspace arguments)
//    is set up by array_nhwc_initialize; the kernel taps (kd, kh, kw) are
//    then walked, their input coordinates are checked against the input
//    extents (ID, IH, IW), and array_nhwc_max is called with the flattened
//    kernel index kd * KH * KW + kh * KW + kw.
//  - otherwise: the OC outputs are zeroed with utils::array_set, the input
//    window clipped to the input extents is visited, and array_div_by_const
//    is applied with num_summands (KW * KH * KD for
//    pooling_avg_include_padding).
81 template <impl::data_type_t data_type>
82 void nhwc_pooling_fwd_t<data_type>::execute_forward() const {
83 using namespace alg_kind;
84 using namespace prop_kind;
85 using namespace nhwc_pooling;
87 auto alg = pd()->desc()->alg_kind;
89 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
90 auto dst = reinterpret_cast<data_t *>(this->memory(0));
// Workspace pointer: memory(1) when the condition below holds, otherwise
// nullptr.
91 unsigned char * ws = reinterpret_cast<unsigned char *>(
93 && pd()->desc()->prop_kind == forward_training ?
94 this->memory(1) : nullptr
97 const memory_desc_wrapper MEM_D(dst)(pd()->dst_pd());
98 const memory_desc_wrapper MEM_D(ws)(pd()->workspace_pd());
99 const memory_desc_wrapper MEM_D(src)(pd()->src_pd());
// Problem sizes, all taken from the primitive descriptor.
101 const int ID = pd()->ID();
102 const int IH = pd()->IH();
103 const int IW = pd()->IW();
104 const int KD = pd()->KD();
105 const int KH = pd()->KH();
106 const int KW = pd()->KW();
107 const int SD = pd()->KSD();
108 const int SH = pd()->KSH();
109 const int SW = pd()->KSW();
110 const int padF = pd()->padFront();
111 const int padT = pd()->padT();
112 const int padL = pd()->padL();
113 const int MB = pd()->MB();
114 const int OC = pd()->C();
115 const int OD = pd()->OD();
116 const int OH = pd()->OH();
117 const int OW = pd()->OW();
// is_3d is consumed by DECLARE_READ_STRIDES to pick the stride indices.
119 const bool is_3d = pd()->desc()->src_desc.ndims == 5;
// Workspace data type, or undef when there is no workspace pointer.
120 const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
122 DECLARE_READ_STRIDES(src);
123 DECLARE_READ_STRIDES(dst);
// Clamped subtraction: index - offset when index > offset, otherwise 0.
125 auto apply_offset = [=](int index, int offset) {
126 return (index > offset) ? index - offset : 0;
129 parallel_nd(MB, OD, OH, OW,
130 [&](int mb, int od, int oh, int ow) {
131 size_t dst_offset_init = strided_offset(mb, dst_n_stride,
135 if (alg == pooling_max) {
136 size_t ws_offset_init = 0;
139 DECLARE_READ_STRIDES(ws);
140 ws_offset_init = strided_offset(mb, ws_n_stride,
145 // Note: GCC 4.8.5 won't vectorize below
146 // simple loops unless they are singled out
147 // into separate helper routines:
148 // array_nhwc_initialize, array_nhwc_max
150 array_nhwc_initialize<false>(OC, dst + dst_offset_init,
151 ws, ws_offset_init, ws_dt);
153 array_nhwc_initialize<true>(OC, dst + dst_offset_init,
154 ws, ws_offset_init, ws_dt);
// Walk the kernel taps; id/ih/iw are the input coordinates of each tap and
// are compared against the input extents below.
157 for (int kd = 0; kd < KD; ++kd)
158 for (int kh = 0; kh < KH; ++kh)
159 for (int kw = 0; kw < KW; ++kw) {
160 const int id = od * SD - padF + kd;
161 const int ih = oh * SH - padT + kh;
162 const int iw = ow * SW - padL + kw;
164 if (id < 0 || id >= ID)
166 if (ih < 0 || ih >= IH)
168 if (iw < 0 || iw >= IW)
171 size_t src_offset_init = strided_offset(mb, src_n_stride,
177 array_nhwc_max<false>(OC,
178 dst + dst_offset_init,
179 src + src_offset_init,
182 kd * KH * KW + kh * KW + kw
185 array_nhwc_max<true>(OC,
186 dst + dst_offset_init,
187 src + src_offset_init,
190 kd * KH * KW + kh * KW + kw
195 auto d = dst + dst_offset_init;
197 utils::array_set(d, 0, OC);
// Input window of this output cell, clipped to [0, ID/IH/IW).
199 auto id_start = apply_offset(od * SD, padF);
200 auto ih_start = apply_offset(oh * SH, padT);
201 auto iw_start = apply_offset(ow * SW, padL);
202 auto id_end = nstl::min(od * SD - padF + KD, ID);
203 auto ih_end = nstl::min(oh * SH - padT + KH, IH);
204 auto iw_end = nstl::min(ow * SW - padL + KW, IW);
206 // it is cheaper to actually count this in a loop
207 // as the typical kernel is small
208 size_t num_summands = 0;
210 for (int id = id_start; id < id_end; ++id)
211 for (int ih = ih_start; ih < ih_end; ++ih)
212 for (int iw = iw_start; iw < iw_end; ++iw) {
213 size_t src_offset_init = strided_offset(mb, src_n_stride,
217 auto s = src + src_offset_init;
219 // need to move the loop to separate function
220 // for GCC 4.8.5 to vectorize
// With include-padding averaging the divisor is the full kernel volume;
// otherwise the counted value is kept.
226 num_summands = (alg == pooling_avg_include_padding) ?
227 KW * KH * KD : num_summands;
229 // need to move the loop to separate function
230 // for GCC 4.8.5 to vectorize
231 array_div_by_const(OC, d, num_summands, d);
// Backward pooling pass. Work is split with parallel_nd over every diff_src
// position (mb, id, ih, iw). For each position the range of output cells
// [o*_left, o*_right) is derived from the strides and paddings, the kernel
// coordinates (kd, kh, kw) of the position inside each candidate window are
// checked against the kernel extents, and the OC values starting at
// diff_src[src_offset_init] are produced from the matching diff_dst values.
//  - alg == pooling_max: the per-channel index read from the workspace (as
//    u8 or as int, depending on the workspace data type) is compared with
//    the flattened kernel index kd * KH * KW + kh * KW + kw.
//  - otherwise: the diff_dst value is divided by num_summands.
// Unless KD == SD, KH == SH and KW == SW all hold, diff_src is zeroed first
// and contributions are accumulated with +=; when they all hold the value
// is assigned with = instead.
236 template <impl::data_type_t data_type>
237 void nhwc_pooling_bwd_t<data_type>::execute_backward() const {
238 using namespace alg_kind;
239 using namespace nhwc_pooling;
241 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
// The workspace is only read for pooling_max; it is nullptr otherwise.
242 auto ws = pd()->desc()->alg_kind != alg_kind::pooling_max ? nullptr
243 : reinterpret_cast<const unsigned char *>(this->input_memory(1));
244 auto diff_src = reinterpret_cast<data_t *>(this->memory(0));
246 const memory_desc_wrapper MEM_D(diff_dst)(pd()->diff_dst_pd());
247 const memory_desc_wrapper MEM_D(ws)(pd()->workspace_pd());
248 const memory_desc_wrapper MEM_D(diff_src)(pd()->diff_src_pd());
// Problem sizes, all taken from the primitive descriptor.
250 const int ID = pd()->ID();
251 const int IH = pd()->IH();
252 const int IW = pd()->IW();
253 const int KD = pd()->KD();
254 const int KH = pd()->KH();
255 const int KW = pd()->KW();
256 const int SD = pd()->KSD();
257 const int SH = pd()->KSH();
258 const int SW = pd()->KSW();
259 const int OC = pd()->C();
260 const int padF = pd()->padFront();
261 const int padT = pd()->padT();
262 const int padL = pd()->padL();
263 const int OD = pd()->OD();
264 const int OH = pd()->OH();
265 const int OW = pd()->OW();
// is_3d is consumed by DECLARE_READ_STRIDES to pick the stride indices.
267 const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
268 auto alg = pd()->desc()->alg_kind;
270 DECLARE_READ_STRIDES(diff_src);
271 DECLARE_READ_STRIDES(diff_dst);
// Clamped subtraction: index - offset when index > offset, otherwise 0.
273 auto apply_offset = [=](int index, int offset) {
274 return (index > offset) ? index - offset : 0;
277 const int MB = pd()->MB();
279 parallel_nd(MB, ID, IH, IW,
280 [&](int mb, int id, int ih, int iw) {
281 size_t src_offset_init = strided_offset(mb, diff_src_n_stride,
282 id, diff_src_d_stride,
283 ih, diff_src_h_stride,
284 iw, diff_src_w_stride);
286 // check if kernel windows are disjoint, in this case there's no
287 // update needed and we just write there once, no initialization
289 if (!(KD == SD && KH == SH && KW == SW))
290 for (int oc = 0; oc < OC; ++oc)
291 diff_src[src_offset_init + oc] = data_type_t(0);
293 // Find out which output cells may correspond to current
294 // input position. Current input position divided by
295 // stride, with integer divide rounding down, is the
296 // right-most output.
297 // Left-most output may be computed if we decrement input
298 // by (kernel_size - 1) and then do the same division by
300 int od_left = nstl::max((id + padF - KD + 1) / SD, 0);
301 int oh_left = nstl::max((ih + padT - KH + 1) / SH, 0);
302 int ow_left = nstl::max((iw + padL - KW + 1) / SW, 0);
303 // Notice +1 here to preserve the C loop "less than"
304 // condition for continuing the for loop.
305 int od_right = nstl::min((id + padF) / SD + 1 , OD);
306 int oh_right = nstl::min((ih + padT) / SH + 1 , OH);
307 int ow_right = nstl::min((iw + padL) / SW + 1 , OW);
309 for (int od = od_left; od < od_right; ++od)
310 for (int oh = oh_left; oh < oh_right; ++oh)
311 for (int ow = ow_left; ow < ow_right; ++ow) {
// Kernel coordinates of the current input position inside the window of
// output cell (od, oh, ow); compared against the kernel extents below.
312 const int kd = id - od*SD + padF;
313 const int kh = ih - oh*SH + padT;
314 const int kw = iw - ow*SW + padL;
316 if (kd < 0 || kd >= KD)
318 if (kh < 0 || kh >= KH)
320 if (kw < 0 || kw >= KW)
323 size_t dst_offset_init = strided_offset(mb, diff_dst_n_stride,
324 od, diff_dst_d_stride,
325 oh, diff_dst_h_stride,
326 ow, diff_dst_w_stride);
328 if (alg == pooling_max) {
329 DECLARE_READ_STRIDES(ws);
330 size_t ws_offset_init = strided_offset(mb, ws_n_stride,
// Flattened kernel index of this tap, compared with the workspace entry.
334 const int index = kd * KH * KW + kh * KW + kw;
337 for (int oc = 0; oc < OC; ++oc) {
// Workspace entries are read as u8 or as int depending on its data type.
338 const int index_from_ws =
339 (MEM_D(ws).data_type() == data_type::u8)
340 ? (int)ws[ws_offset_init + oc]
341 : ((int *)ws)[ws_offset_init + oc];
343 const data_t d = diff_dst[dst_offset_init + oc];
345 // Check if kernel windows are disjoint, in this case
346 // there's no update needed and we just write there once
347 // otherwise we add value to the contents.
348 if (!(KD == SD && KH == SH && KW == SW))
349 diff_src[src_offset_init + oc] +=
350 (index_from_ws == index)
354 diff_src[src_offset_init + oc] =
355 (index_from_ws == index)
// Input window of output cell (od, oh, ow), clipped to the input extents.
361 auto id_start = apply_offset(od*SD, padF);
362 auto ih_start = apply_offset(oh*SH, padT);
363 auto iw_start = apply_offset(ow*SW, padL);
364 auto id_end = nstl::min(od*SD - padF + KD, ID);
365 auto ih_end = nstl::min(oh*SH - padT + KH, IH);
366 auto iw_end = nstl::min(ow*SW - padL + KW, IW);
// Divisor for averaging: when padding is not included it is the volume of
// the clipped window computed above.
368 auto num_summands = (alg == pooling_avg_include_padding)
370 : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start);
373 for (int oc = 0; oc < OC; ++oc) {
374 const data_t d = diff_dst[dst_offset_init + oc];
375 // Check if kernel windows are disjoint, in this case
376 // there's no update needed and we just write there once
377 // otherwise we add value to the contents.
378 if (!(KD == SD && KH == SH && KW == SW))
379 diff_src[src_offset_init + oc] += d / num_summands;
381 diff_src[src_offset_init + oc] = d / num_summands;
// Explicit instantiations of the forward and backward implementations for
// the f32 data type.
388 template struct nhwc_pooling_fwd_t<data_type::f32>;
389 template struct nhwc_pooling_bwd_t<data_type::f32>;
395 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s