updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / nhwc_pooling.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18 #include <math.h>
19
20 #include "c_types_map.hpp"
21 #include "type_helpers.hpp"
22 #include "math_utils.hpp"
23 #include "mkldnn_thread.hpp"
24 #include "nstl.hpp"
25
26 #include "nhwc_pooling.hpp"
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 #define MEM_D(name) name##_d
33
34 #define DECLARE_READ_STRIDES(name)                                             \
35     const size_t name##_n_stride = MEM_D(name).blocking_desc().strides[0][0];  \
36     const size_t name##_d_stride = (!is_3d)                                    \
37                                  ? 0                                           \
38                                  : MEM_D(name).blocking_desc().strides[0][2];  \
39     const size_t name##_h_stride = (!is_3d)                                    \
40                                  ? MEM_D(name).blocking_desc().strides[0][2]   \
41                                  : MEM_D(name).blocking_desc().strides[0][3];  \
42     const size_t name##_w_stride = (!is_3d)                                    \
43                                  ? MEM_D(name).blocking_desc().strides[0][3]   \
44                                  : MEM_D(name).blocking_desc().strides[0][4];
45
46 namespace {
47     size_t strided_offset(const int _n, const size_t _sn,
48                           const int _d, const size_t _sd,
49                           const int _h, const size_t _sh,
50                           const int _w, const size_t _sw)
51     {
52         return   _n * _sn
53                + _d * _sd
54                + _h * _sh
55                + _w * _sw;
56     }
57 }
58
59 using namespace alg_kind;
60 using namespace prop_kind;
61 using namespace bf16_cvt_utils;
62
63 template <data_type_t d_type>
64 void nhwc_pooling_fwd_t<d_type>::execute_forward() const {
65     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
66     auto dst = reinterpret_cast<data_t *>(this->memory(0));
67     unsigned char * ws = reinterpret_cast<unsigned char *>(
68                   pd()->desc()->alg_kind == pooling_max
69                       && pd()->desc()->prop_kind == forward_training ?
70                   this->memory(1) : nullptr
71               );
72
73     const memory_desc_wrapper MEM_D(dst)(pd()->dst_pd());
74     const memory_desc_wrapper MEM_D(src)(pd()->src_pd());
75     const memory_desc_wrapper MEM_D(ws)(pd()->workspace_pd());
76     const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
77
78     const int MB = pd()->MB();
79     const int C = pd()->C();
80     const int OD = pd()->OD();
81     const int OH = pd()->OH();
82     const int OW = pd()->OW();
83     const int ID = pd()->ID();
84     const int IH = pd()->IH();
85     const int IW = pd()->IW();
86     const int KD = pd()->KD();
87     const int KH = pd()->KH();
88     const int KW = pd()->KW();
89     const int SD = pd()->KSD();
90     const int SH = pd()->KSH();
91     const int SW = pd()->KSW();
92     const int padF = pd()->padFront();
93     const int padT = pd()->padT();
94     const int padL = pd()->padL();
95
96     const bool is_3d = pd()->desc()->src_desc.ndims == 5;
97     auto alg = pd()->desc()->alg_kind;
98
99     DECLARE_READ_STRIDES(src);
100     DECLARE_READ_STRIDES(dst);
101
102     auto apply_offset = [=](int index, int offset) {
103         return (index > offset) ? index - offset : 0;
104     };
105
106     auto ker_max = [=](data_t *d, const data_t *s, int mb, int od, int oh, int ow) {
107         size_t ws_offset_init = 0;
108         if (ws)
109         {
110             DECLARE_READ_STRIDES(ws);
111             ws_offset_init = strided_offset(mb, ws_n_stride,
112                                             od, ws_d_stride,
113                                             oh, ws_h_stride,
114                                             ow, ws_w_stride);
115         }
116
117         /* Note: GCC 4.8.5 won't vectorize below simple loops unless
118          * they are singled out into separate helper routines:
119          *  array_initialize, array_max */
120         array_initialize(C, d,
121                 ws, ws_offset_init, ws_dt);
122
123         for (int kd = 0; kd < KD; ++kd)
124         for (int kh = 0; kh < KH; ++kh)
125         for (int kw = 0; kw < KW; ++kw) {
126             const int id = od * SD - padF + kd;
127             const int ih = oh * SH - padT + kh;
128             const int iw = ow * SW - padL + kw;
129
130             if (id < 0 || id >= ID)
131                 continue;
132             if (ih < 0 || ih >= IH)
133                 continue;
134             if (iw < 0 || iw >= IW)
135                 continue;
136
137             size_t src_offset_init = strided_offset(mb, src_n_stride,
138                                                     id, src_d_stride,
139                                                     ih, src_h_stride,
140                                                     iw, src_w_stride);
141             array_max(C,
142                d, &s[src_offset_init],
143                ws, ws_offset_init,
144                ws_dt,
145                kd * KH * KW + kh * KW + kw
146             );
147         }
148     };
149
150     auto ker_avg = [=](data_t *d, const data_t *s,
151             int mb, int od, int oh, int ow) {
152         utils::array_set(d, 0, C);
153
154         auto id_start = apply_offset(od * SD, padF);
155         auto ih_start = apply_offset(oh * SH, padT);
156         auto iw_start = apply_offset(ow * SW, padL);
157         auto id_end = nstl::min(od * SD - padF + KD, ID);
158         auto ih_end = nstl::min(oh * SH - padT + KH, IH);
159         auto iw_end = nstl::min(ow * SW - padL + KW, IW);
160
161         // it is cheaper to actually count this in a loop
162         // as the typical kernel is small
163         size_t num_summands = 0;
164
165         /* Note: GCC 4.8.5 won't vectorize below simple loops unless
166          * they are singled out into separate helper routines:
167          *  array_add, array_div_by_const */
168         for (int id = id_start; id < id_end; ++id)
169         for (int ih = ih_start; ih < ih_end; ++ih)
170         for (int iw = iw_start; iw < iw_end; ++iw) {
171             size_t src_offset_init = strided_offset(mb, src_n_stride,
172                                                     id, src_d_stride,
173                                                     ih, src_h_stride,
174                                                     iw, src_w_stride);
175             array_add(C, d, &s[src_offset_init]);
176             num_summands++;
177         }
178
179         num_summands = (alg == pooling_avg_include_padding) ?
180                 KW * KH * KD : num_summands;
181
182         array_div_by_const(C, d, num_summands, d);
183     };
184
185     parallel_nd(MB, OD, OH, OW,
186         [&](int mb, int od, int oh, int ow) {
187         size_t dst_offset_init = strided_offset(mb, dst_n_stride,
188                                                 od, dst_d_stride,
189                                                 oh, dst_h_stride,
190                                                 ow, dst_w_stride);
191         data_t *d = reinterpret_cast<data_t*>(&dst[dst_offset_init]);
192
193         if (alg == pooling_max) {
194             ker_max(d, src, mb, od, oh, ow);
195
196         } else {
197             ker_avg(d, src, mb, od, oh, ow);
198         }
199     });
200 }
201
202 template <>
203 void nhwc_pooling_fwd_t<data_type::bf16>::execute_forward() const {
204
205     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
206     auto dst = reinterpret_cast<data_t *>(this->memory(0));
207     unsigned char * ws = reinterpret_cast<unsigned char *>(
208                   pd()->desc()->alg_kind == pooling_max
209                       && pd()->desc()->prop_kind == forward_training ?
210                   this->memory(1) : nullptr
211               );
212
213     auto scratchpad = this->scratchpad();
214     float *bf16cvt_src_wsp = scratchpad.template get<float>(
215             memory_tracking::names::key_pool_src_bf16cvt);
216     float *bf16cvt_dst_wsp = scratchpad.template get<float>(
217             memory_tracking::names::key_pool_dst_bf16cvt);
218
219     const memory_desc_wrapper MEM_D(dst)(pd()->dst_pd());
220     const memory_desc_wrapper MEM_D(src)(pd()->src_pd());
221     const memory_desc_wrapper MEM_D(ws)(pd()->workspace_pd());
222     const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
223
224     const int MB = pd()->MB();
225     const int C = pd()->C();
226     const int OD = pd()->OD();
227     const int OH = pd()->OH();
228     const int OW = pd()->OW();
229     const int ID = pd()->ID();
230     const int IH = pd()->IH();
231     const int IW = pd()->IW();
232     const int KD = pd()->KD();
233     const int KH = pd()->KH();
234     const int KW = pd()->KW();
235     const int SD = pd()->KSD();
236     const int SH = pd()->KSH();
237     const int SW = pd()->KSW();
238     const int padF = pd()->padFront();
239     const int padT = pd()->padT();
240     const int padL = pd()->padL();
241
242     const bool is_3d = pd()->desc()->src_desc.ndims == 5;
243     auto alg = pd()->desc()->alg_kind;
244
245     DECLARE_READ_STRIDES(src);
246     DECLARE_READ_STRIDES(dst);
247
248     auto apply_offset = [=](int index, int offset) {
249         return (index > offset) ? index - offset : 0;
250     };
251
252     auto ker_max = [=](mkldnn_bfloat16_t *d, const mkldnn_bfloat16_t *s,
253         int mb, int od, int oh, int ow) {
254         size_t ws_offset_init = 0;
255         if (ws)
256         {
257             DECLARE_READ_STRIDES(ws);
258             ws_offset_init = strided_offset(mb, ws_n_stride,
259                                             od, ws_d_stride,
260                                             oh, ws_h_stride,
261                                             ow, ws_w_stride);
262         }
263         size_t ithr = mkldnn_get_thread_num();
264         float *dst_f32_ = &bf16cvt_dst_wsp[ithr * C];
265         float *src_f32_ = &bf16cvt_src_wsp[ithr * C];
266
267         /* Note: GCC 4.8.5 won't vectorize below simple loops unless
268          * they are singled out into separate helper routines:
269          *  array_initialize, array_max */
270         array_initialize(C, dst_f32_,
271                 ws, ws_offset_init, ws_dt);
272
273         for (int kd = 0; kd < KD; ++kd)
274         for (int kh = 0; kh < KH; ++kh)
275         for (int kw = 0; kw < KW; ++kw) {
276             const int id = od * SD - padF + kd;
277             const int ih = oh * SH - padT + kh;
278             const int iw = ow * SW - padL + kw;
279
280             if (id < 0 || id >= ID)
281                 continue;
282             if (ih < 0 || ih >= IH)
283                 continue;
284             if (iw < 0 || iw >= IW)
285                 continue;
286
287             size_t src_offset_init = strided_offset(mb, src_n_stride,
288                                                     id, src_d_stride,
289                                                     ih, src_h_stride,
290                                                     iw, src_w_stride);
291             cvt_bfloat16_to_float(src_f32_, &s[src_offset_init], C);
292             array_max(C,
293                dst_f32_, src_f32_,
294                ws, ws_offset_init,
295                ws_dt,
296                kd * KH * KW + kh * KW + kw
297             );
298         }
299         cvt_float_to_bfloat16(d, dst_f32_, C);
300     };
301
302     auto ker_avg = [=](mkldnn_bfloat16_t *d, const mkldnn_bfloat16_t *s,
303             int mb, int od, int oh, int ow) {
304         size_t ithr = mkldnn_get_thread_num();
305         float *dst_f32_ = &bf16cvt_dst_wsp[ithr * C];
306         float *src_f32_ = &bf16cvt_src_wsp[ithr * C];
307         utils::array_set(dst_f32_, 0, C);
308
309         auto id_start = apply_offset(od * SD, padF);
310         auto ih_start = apply_offset(oh * SH, padT);
311         auto iw_start = apply_offset(ow * SW, padL);
312         auto id_end = nstl::min(od * SD - padF + KD, ID);
313         auto ih_end = nstl::min(oh * SH - padT + KH, IH);
314         auto iw_end = nstl::min(ow * SW - padL + KW, IW);
315
316         // it is cheaper to actually count this in a loop
317         // as the typical kernel is small
318         size_t num_summands = 0;
319
320         /* Note: GCC 4.8.5 won't vectorize below simple loops unless
321          * they are singled out into separate helper routines:
322          *  array_add, array_div_by_const */
323         for (int id = id_start; id < id_end; ++id)
324         for (int ih = ih_start; ih < ih_end; ++ih)
325         for (int iw = iw_start; iw < iw_end; ++iw) {
326             size_t src_offset_init = strided_offset(mb, src_n_stride,
327                                                     id, src_d_stride,
328                                                     ih, src_h_stride,
329                                                     iw, src_w_stride);
330             cvt_bfloat16_to_float(src_f32_, &s[src_offset_init], C);
331
332             array_add(C, dst_f32_, src_f32_);
333             num_summands++;
334         }
335
336         num_summands = (alg == pooling_avg_include_padding) ?
337                 KW * KH * KD : num_summands;
338
339         array_div_by_const(C, dst_f32_, num_summands, dst_f32_);
340         cvt_float_to_bfloat16(d, dst_f32_, C);
341     };
342
343     parallel_nd(MB, OD, OH, OW,
344         [&](int mb, int od, int oh, int ow) {
345         size_t dst_offset_init = strided_offset(mb, dst_n_stride,
346                                                 od, dst_d_stride,
347                                                 oh, dst_h_stride,
348                                                 ow, dst_w_stride);
349         data_t *d = reinterpret_cast<data_t*>(&dst[dst_offset_init]);
350
351         if (alg == pooling_max)
352             ker_max((mkldnn_bfloat16_t *)d, (const mkldnn_bfloat16_t *)src,
353                     mb, od, oh, ow);
354         else
355             ker_avg((mkldnn_bfloat16_t *)d, (const mkldnn_bfloat16_t *)src,
356                     mb, od, oh, ow);
357     });
358 }
359
360 template <impl::data_type_t d_type>
361 void nhwc_pooling_bwd_t<d_type>::execute_backward() const {
362     using namespace alg_kind;
363     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
364     auto ws = pd()->desc()->alg_kind != alg_kind::pooling_max ? nullptr
365         : reinterpret_cast<const unsigned char *>(this->input_memory(1));
366     auto diff_src = reinterpret_cast<data_t *>(this->memory(0));
367
368     const memory_desc_wrapper MEM_D(diff_dst)(pd()->diff_dst_pd());
369     const memory_desc_wrapper MEM_D(diff_src)(pd()->diff_src_pd());
370
371     const int MB = pd()->MB();
372     const int C = pd()->C();
373     const int OD = pd()->OD();
374     const int OH = pd()->OH();
375     const int OW = pd()->OW();
376     const int ID = pd()->ID();
377     const int IH = pd()->IH();
378     const int IW = pd()->IW();
379     const int KD = pd()->KD();
380     const int KH = pd()->KH();
381     const int KW = pd()->KW();
382     const int SD = pd()->KSD();
383     const int SH = pd()->KSH();
384     const int SW = pd()->KSW();
385     const int padF = pd()->padFront();
386     const int padT = pd()->padT();
387     const int padL = pd()->padL();
388
389     const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
390     auto alg = pd()->desc()->alg_kind;
391
392     DECLARE_READ_STRIDES(diff_src);
393     DECLARE_READ_STRIDES(diff_dst);
394
395     auto apply_offset = [=](int index, int offset) {
396         return (index > offset) ? index - offset : 0;
397     };
398
399     auto ker_max = [=](data_t *ds, const data_t *dd,
400             int mb, int od, int oh, int ow,
401             int kd, int kh, int kw) {
402         const memory_desc_wrapper MEM_D(ws)(pd()->workspace_pd());
403         DECLARE_READ_STRIDES(ws);
404         size_t ws_offset_init = strided_offset(mb, ws_n_stride,
405                                                od, ws_d_stride,
406                                                oh, ws_h_stride,
407                                                ow, ws_w_stride);
408         const int index = kd * KH * KW
409             + kh * KW + kw;
410
411         PRAGMA_OMP_SIMD()
412         for (int c = 0; c < C; ++c) {
413             const int index_from_ws =
414                             (MEM_D(ws).data_type() == data_type::u8)
415                             ? (int)ws[ws_offset_init + c]
416                             : ((int *)ws)[ws_offset_init + c];
417
418             // Check if kernel windows are disjoint, in this case
419             // there's no update needed and we just write there once
420             // otherwise we add value to the contents.
421             if (!(KD == SD &&
422                         KH == SH && KW == SW))
423                 ds[c] += (index_from_ws == index)
424                         ? dd[c] : data_type_t(0);
425             else
426                 ds[c] = (index_from_ws == index)
427                        ? dd[c] : data_type_t(0);
428         }
429
430     };
431
432     auto ker_avg = [=](data_t *ds, const data_t *dd,
433             int mb, int od, int oh, int ow) {
434         auto id_start = apply_offset(od * SD, padF);
435         auto ih_start = apply_offset(oh * SH, padT);
436         auto iw_start = apply_offset(ow * SW, padL);
437         auto id_end = nstl::min(
438                 od * SD - padF + KD, ID);
439         auto ih_end = nstl::min(
440                 oh * SH - padT + KH, IH);
441         auto iw_end = nstl::min(
442                 ow * SW - padL + KW, IW);
443
444         auto num_summands = (alg == pooling_avg_include_padding)
445           ? KW * KH * KD
446           : (ih_end - ih_start) * (iw_end - iw_start) * (id_end - id_start);
447
448         PRAGMA_OMP_SIMD()
449         for (int c = 0; c < C; ++c) {
450             const data_t d = dd[c];
451             // Check if kernel windows are disjoint, in this case
452             // there's no update needed and we just write there once
453             // otherwise we add value to the contents.
454             if (!(KD == SD &&
455                         KH == SH && KW == SW))
456               ds[c] += d / num_summands;
457             else
458               ds[c] = d / num_summands;
459         }
460     };
461
462     parallel_nd(MB, ID, IH, IW,
463         [&](int mb, int id, int ih, int iw) {
464         size_t src_offset_init = strided_offset(mb, diff_src_n_stride,
465                                                 id, diff_src_d_stride,
466                                                 ih, diff_src_h_stride,
467                                                 iw, diff_src_w_stride);
468
469         // check if kernel windows are disjoint, in this case there's no
470         // update needed and we just write there once, no initialization
471         // required.
472         if (!(KD == SD && KH == SH
473                 && KW == SW))
474             for (int c = 0; c < C; ++c)
475                 diff_src[src_offset_init + c] = data_type_t(0);
476
477         // Find out which output cells may correspond to current
478         // input position. Current input postition divided by
479         // stride, with integer divide rounding down, is the
480         // right-most output.
481         // Left-most output may be computed if we decrement input
482         // by (kernel_size - 1) and then do the same division by
483         // stride.
484         int od_left  = nstl::max(
485                 (id + padF - KD + 1) / SD,  0);
486         int oh_left  = nstl::max(
487                 (ih + padT - KH + 1) / SH,  0);
488         int ow_left  = nstl::max(
489                 (iw + padL - KW + 1) / SW,  0);
490         // Notice +1 here to preserve the C loop "less than"
491         // condition for continuing the for loop.
492         int od_right = nstl::min(
493                 (id + padF) / SD + 1, OD);
494         int oh_right = nstl::min(
495                 (ih + padT) / SH + 1, OH);
496         int ow_right = nstl::min(
497                 (iw + padL) / SW + 1, OW);
498
499         for (int od = od_left; od < od_right; ++od)
500         for (int oh = oh_left; oh < oh_right; ++oh)
501         for (int ow = ow_left; ow < ow_right; ++ow) {
502             const int kd = id - od * SD + padF;
503             const int kh = ih - oh * SH + padT;
504             const int kw = iw - ow * SW + padL;
505
506             if (kd < 0 || kd >= KD)
507                 continue;
508             if (kh < 0 || kh >= KH)
509                 continue;
510             if (kw < 0 || kw >= KW)
511                 continue;
512
513             size_t dst_offset_init = strided_offset(mb, diff_dst_n_stride,
514                                                     od, diff_dst_d_stride,
515                                                     oh, diff_dst_h_stride,
516                                                     ow, diff_dst_w_stride);
517             if (alg == pooling_max) {
518                 ker_max(&diff_src[src_offset_init], &diff_dst[dst_offset_init],
519                         mb, od, oh, ow, kd, kh, kw);
520             } else {
521                 ker_avg(&diff_src[src_offset_init], &diff_dst[dst_offset_init],
522                         mb, od, oh, ow);
523             }
524         }
525     });
526 }
527
528 template <>
529 void nhwc_pooling_bwd_t<data_type::bf16>::execute_backward() const {
530     using namespace alg_kind;
531     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
532     auto ws = pd()->desc()->alg_kind != alg_kind::pooling_max ? nullptr
533         : reinterpret_cast<const unsigned char *>(this->input_memory(1));
534     auto diff_src = reinterpret_cast<data_t *>(this->memory(0));
535
536     const memory_desc_wrapper MEM_D(diff_dst)(pd()->diff_dst_pd());
537     const memory_desc_wrapper MEM_D(diff_src)(pd()->diff_src_pd());
538
539     auto scratchpad = this->scratchpad();
540     float *bf16cvt_dsrc_ = scratchpad.template get<float>(
541             memory_tracking::names::key_pool_src_bf16cvt);
542     float *bf16cvt_ddst_ = scratchpad.template get<float>(
543             memory_tracking::names::key_pool_dst_bf16cvt);
544
545     const int MB = pd()->MB();
546     const int C = pd()->C();
547     const int OD = pd()->OD();
548     const int OH = pd()->OH();
549     const int OW = pd()->OW();
550     const int ID = pd()->ID();
551     const int IH = pd()->IH();
552     const int IW = pd()->IW();
553     const int KD = pd()->KD();
554     const int KH = pd()->KH();
555     const int KW = pd()->KW();
556     const int SD = pd()->KSD();
557     const int SH = pd()->KSH();
558     const int SW = pd()->KSW();
559     const int padF = pd()->padFront();
560     const int padT = pd()->padT();
561     const int padL = pd()->padL();
562
563     const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
564     auto alg = pd()->desc()->alg_kind;
565
566     DECLARE_READ_STRIDES(diff_src);
567     DECLARE_READ_STRIDES(diff_dst);
568
569     auto apply_offset = [=](int index, int offset) {
570         return (index > offset) ? index - offset : 0;
571     };
572
573     auto ker_max = [=](float *ds, const float *dd,
574             int mb, int od, int oh, int ow,
575             int kd, int kh, int kw) {
576         const memory_desc_wrapper MEM_D(ws)(pd()->workspace_pd());
577         DECLARE_READ_STRIDES(ws);
578         size_t ws_offset_init = strided_offset(mb, ws_n_stride,
579                                                od, ws_d_stride,
580                                                oh, ws_h_stride,
581                                                ow, ws_w_stride);
582         const int index = kd * KH * KW
583             + kh * KW + kw;
584
585         PRAGMA_OMP_SIMD()
586         for (int c = 0; c < C; ++c) {
587             const int index_from_ws =
588                             (MEM_D(ws).data_type() == data_type::u8)
589                             ? (int)ws[ws_offset_init + c]
590                             : ((int *)ws)[ws_offset_init + c];
591
592             // Check if kernel windows are disjoint, in this case
593             // there's no update needed and we just write there once
594             // otherwise we add value to the contents.
595             if (!(KD == SD &&
596                         KH == SH && KW == SW))
597                 ds[c] += (index_from_ws == index)
598                         ? dd[c] : 0.0f;
599             else
600                 ds[c] = (index_from_ws == index)
601                        ? dd[c] : 0.0f;
602         }
603
604     };
605
606     auto ker_avg = [=](float *ds, const float *dd,
607             int mb, int od, int oh, int ow) {
608         auto id_start = apply_offset(od * SD, padF);
609         auto ih_start = apply_offset(oh * SH, padT);
610         auto iw_start = apply_offset(ow * SW, padL);
611         auto id_end = nstl::min(
612                 od * SD - padF + KD, ID);
613         auto ih_end = nstl::min(
614                 oh * SH - padT + KH, IH);
615         auto iw_end = nstl::min(
616                 ow * SW - padL + KW, IW);
617
618         auto num_summands = (alg == pooling_avg_include_padding)
619           ? KW * KH * KD
620           : (ih_end - ih_start) * (iw_end - iw_start) * (id_end - id_start);
621
622         PRAGMA_OMP_SIMD()
623         for (int c = 0; c < C; ++c) {
624             // Check if kernel windows are disjoint, in this case
625             // there's no update needed and we just write there once
626             // otherwise we add value to the contents.
627             if (!(KD == SD &&
628                         KH == SH && KW == SW))
629               ds[c] += dd[c] / num_summands;
630             else
631               ds[c] = dd[c] / num_summands;
632         }
633     };
634
635     parallel_nd(MB, ID, IH, IW,
636         [&](int mb, int id, int ih, int iw) {
637         size_t src_offset_init = strided_offset(mb, diff_src_n_stride,
638                                                 id, diff_src_d_stride,
639                                                 ih, diff_src_h_stride,
640                                                 iw, diff_src_w_stride);
641
642         float *ddst_fp32_ = &bf16cvt_ddst_[mkldnn_get_thread_num() * C];
643         float *dsrc_fp32_ = &bf16cvt_dsrc_[mkldnn_get_thread_num() * C];
644         // check if kernel windows are disjoint, in this case there's no
645         // update needed and we just write there once, no initialization
646         // required.
647         if (!(KD == SD && KH == SH
648                 && KW == SW))
649             for (int c = 0; c < C; ++c)
650                 dsrc_fp32_[c] = 0.0f;
651
652         // Find out which output cells may correspond to current
653         // input position. Current input postition divided by
654         // stride, with integer divide rounding down, is the
655         // right-most output.
656         // Left-most output may be computed if we decrement input
657         // by (kernel_size - 1) and then do the same division by
658         // stride.
659         int od_left  = nstl::max(
660                 (id + padF - KD + 1) / SD,  0);
661         int oh_left  = nstl::max(
662                 (ih + padT - KH + 1) / SH,  0);
663         int ow_left  = nstl::max(
664                 (iw + padL - KW + 1) / SW,  0);
665         // Notice +1 here to preserve the C loop "less than"
666         // condition for continuing the for loop.
667         int od_right = nstl::min(
668                 (id + padF) / SD + 1, OD);
669         int oh_right = nstl::min(
670                 (ih + padT) / SH + 1, OH);
671         int ow_right = nstl::min(
672                 (iw + padL) / SW + 1, OW);
673
674         for (int od = od_left; od < od_right; ++od)
675         for (int oh = oh_left; oh < oh_right; ++oh)
676         for (int ow = ow_left; ow < ow_right; ++ow) {
677             const int kd = id - od * SD + padF;
678             const int kh = ih - oh * SH + padT;
679             const int kw = iw - ow * SW + padL;
680
681             if (kd < 0 || kd >= KD)
682                 continue;
683             if (kh < 0 || kh >= KH)
684                 continue;
685             if (kw < 0 || kw >= KW)
686                 continue;
687
688             size_t dst_offset_init = strided_offset(mb, diff_dst_n_stride,
689                                                     od, diff_dst_d_stride,
690                                                     oh, diff_dst_h_stride,
691                                                     ow, diff_dst_w_stride);
692             cvt_bfloat16_to_float(ddst_fp32_, &diff_dst[dst_offset_init], C);
693             if (alg == pooling_max) {
694                 ker_max(dsrc_fp32_, ddst_fp32_,
695                         mb, od, oh, ow, kd, kh, kw);
696             } else {
697                 ker_avg(dsrc_fp32_, ddst_fp32_,
698                         mb, od, oh, ow);
699             }
700         }
701         cvt_float_to_bfloat16(&diff_src[src_offset_init],
702                 dsrc_fp32_, C);
703     });
704 }
705 template struct nhwc_pooling_fwd_t<data_type::f32>;
706 template struct nhwc_pooling_bwd_t<data_type::f32>;
707 template struct nhwc_pooling_fwd_t<data_type::bf16>;
708 template struct nhwc_pooling_bwd_t<data_type::bf16>;
709
710 }
711 }
712 }
713
714 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s