updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / nchw_pooling.cpp
1 /*******************************************************************************
2 * Copyright 2017-2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18 #include <math.h>
19
20 #include "c_types_map.hpp"
21 #include "type_helpers.hpp"
22 #include "math_utils.hpp"
23 #include "mkldnn_thread.hpp"
24 #include "nstl.hpp"
25
26 #include "nchw_pooling.hpp"
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 using namespace nstl;
33 using namespace alg_kind;
34 using namespace bf16_cvt_utils;
35
36 template <data_type_t d_type>
37 void nchw_pooling_fwd_t<d_type>::execute_forward() const {
38
39     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
40     auto dst = reinterpret_cast<data_t*>(this->memory(0));
41     auto ws = pd()->desc()->alg_kind == alg_kind::pooling_max ?
42         reinterpret_cast<unsigned char *>(this->memory(1)) : nullptr;
43
44     const memory_desc_wrapper ws_d(pd()->workspace_pd());
45     const memory_desc_wrapper src_d(pd()->src_pd());
46     const memory_desc_wrapper dst_d(pd()->dst_pd());
47     const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
48
49     src += src_d.off_l(0);
50     dst += dst_d.off_l(0);
51
52     const int MB = pd()->MB();
53     const int C = pd()->C();
54     const int OD = pd()->OD();
55     const int OH = pd()->OH();
56     const int OW = pd()->OW();
57     const int ID = pd()->ID();
58     const int IH = pd()->IH();
59     const int IW = pd()->IW();
60     const int KD = pd()->KD();
61     const int KH = pd()->KH();
62     const int KW = pd()->KW();
63     const int SD = pd()->KSD();
64     const int SH = pd()->KSH();
65     const int SW = pd()->KSW();
66     const int padF = pd()->padFront();
67     const int padT = pd()->padT();
68     const int padL = pd()->padL();
69     const int padB = pd()->padB();
70     const int padR = pd()->padR();
71     const int padBack = pd()->padBack();
72
73     auto alg = pd()->desc()->alg_kind;
74
75     auto set_ws = [=](int mb, int c, int od, int oh, int ow, int value) {
76         // value = -1 means that pool window is placed outside of source domain
77         // for current {od, oh, ow} point
78         if (ws) {
79             assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
80             size_t ws_offset
81                 = (size_t)OW * OH * OD * C * mb
82                 + (size_t)OW * OH * OD * c
83                 + (size_t)OW * OH * od
84                 + (size_t)OW * oh
85                 + (size_t)ow;
86             if (ws_dt == data_type::u8) {
87                 const int u8_max = numeric_limits<
88                     typename prec_traits<data_type::u8>::type>::max();
89                 if (value == -1)
90                     value = u8_max;
91                 assert(0 <= value && value <= u8_max);
92                 ws[ws_offset] = value;
93             } else
94                 reinterpret_cast<int *>(ws)[ws_offset] = value;
95         }
96     };
97
98     auto ker_max = [=](data_t *d, const data_t *src_, int mb, int c, int od, int oh, int ow) {
99         bool is_initialized = false;
100         int current_pool_size = 0;
101         for (int kd = 0; kd < KD; ++kd) {
102             for (int kh = 0; kh < KH; ++kh) {
103                 for (int kw = 0; kw < KW; ++kw) {
104                     const int id = od * SD - padF + kd;
105                     const int ih = oh * SH - padT + kh;
106                     const int iw = ow * SW - padL + kw;
107
108                     if (id < 0 || id >= ID) continue;
109                     if (ih < 0 || ih >= IH) continue;
110                     if (iw < 0 || iw >= IW) continue;
111
112                     auto src_offset =
113                         + (size_t)IW * IH * kd
114                         + (size_t)IW * kh
115                         + (size_t)kw;
116                     auto s = src_[src_offset];
117                     if (!is_initialized) {
118                         d[0] = s;
119                         set_ws(mb, c, od, oh, ow, kd * KH * KW + kh * KW + kw);
120                         is_initialized = true;
121                     } else {
122                         if (s > d[0]) {
123                             d[0] = s;
124                             set_ws(mb, c, od, oh, ow, kd * KH * KW + kh * KW + kw);
125                         }
126                     }
127                     current_pool_size++;
128                 }
129             }
130         }
131
132         // corner case: pool window is outside of real input domain
133         // for this point.
134         if (current_pool_size == 0)
135             set_ws(mb, c, od, oh, ow, -1);
136     };
137
138     auto ker_avg = [=](data_t *d, const data_t *src_,
139                        int mb, int c, int od, int oh, int ow) {
140         auto id_start = od*SD - padF;
141         auto ih_start = oh*SH - padT;
142         auto iw_start = ow*SW - padL;
143         auto id_end = nstl::min(od*SD - padF + KD, ID + padBack);
144         auto ih_end = nstl::min(oh*SH - padT + KH, IH + padB);
145         auto iw_end = nstl::min(ow*SW - padL + KW, IW + padR);
146
147         auto num_summands = (alg == pooling_avg_include_padding) ? KD*KW*KH
148             : (id_end - id_start)*(ih_end - ih_start)*(iw_end - iw_start);
149
150         id_start = nstl::max(id_start, 0);
151         ih_start = nstl::max(ih_start, 0);
152         iw_start = nstl::max(iw_start, 0);
153
154         id_end = nstl::min(id_end, ID);
155         ih_end = nstl::min(ih_end, IH);
156         iw_end = nstl::min(iw_end, IW);
157
158         if (alg == pooling_avg_exclude_padding)
159             num_summands = (id_end - id_start)*(ih_end - ih_start)*(iw_end - iw_start);
160         if (num_summands == 0) return;
161
162         for (int id = id_start; id < id_end; ++id) {
163             for (int ih = ih_start; ih < ih_end; ++ih) {
164                 for (int iw = iw_start; iw < iw_end; ++iw) {
165                     auto src_offset
166                         = (size_t)IW * IH * id
167                         + (size_t)IW * ih
168                         + (size_t)iw;
169                     d[0] += src_[src_offset];
170                 }
171             }
172         }
173
174         d[0] = math::out_round<data_t>((data_t)d[0] / num_summands);
175     };
176
177     if (pd()->desc()->alg_kind == pooling_max) {
178         parallel_nd(MB, C, OD, OH, OW,
179             [&](int mb, int c, int od, int oh, int ow) {
180
181             size_t dst_offset
182                 = (size_t)OW * OH * OD * C * mb
183                 + (size_t)OW * OH * OD * c
184                 + (size_t)OW * OH * od
185                 + (size_t)OW * oh
186                 + (size_t)ow;
187             auto src_offset
188                 = (size_t)IW * IH * ID * C * mb
189                 + (size_t)IW * IH * ID * c
190                 + (size_t)IW * IH * (od * SD - padF)
191                 + (size_t)IW * (oh * SH - padT)
192                 + (size_t)(ow * SW - padL);
193
194             set_ws(mb, c, od, oh, ow, 0);
195
196             data_t *d = reinterpret_cast<data_t*>(&dst[dst_offset]);
197             d[0] = (data_t)0;
198             const data_t *src_ =
199                       reinterpret_cast<const data_t*>(&src[src_offset]);
200
201             ker_max(d, src_,  mb, c, od, oh, ow);
202         });
203     } else {
204         parallel_nd(MB, C, OD, OH, OW,
205             [&](int mb, int c, int od, int oh, int ow) {
206             size_t dst_offset
207                 = (size_t)OW * OH * OD * C * mb
208                 + (size_t)OW * OH * OD * c
209                 + (size_t)OW * OH * od
210                 + (size_t)OW * oh
211                 + (size_t)ow;
212             auto src_offset
213                 = (size_t)IW * IH * ID * C * mb
214                 + (size_t)IW * IH * ID * c;
215
216             data_t *d = reinterpret_cast<data_t*>(&dst[dst_offset]);
217             d[0] = 0;
218             const data_t *src_ =
219                     reinterpret_cast<const data_t*>(&src[src_offset]);
220             ker_avg(d, src_, mb, c, od, oh, ow);
221         });
222     }
223 }
224
225 template <>
226 void nchw_pooling_fwd_t<data_type::bf16>::execute_forward() const {
227     auto src = reinterpret_cast<const mkldnn_bfloat16_t *>(this->input_memory(0));
228     auto dst = reinterpret_cast<mkldnn_bfloat16_t*>(this->memory(0));
229     auto ws = pd()->desc()->alg_kind == alg_kind::pooling_max ?
230         reinterpret_cast<unsigned char *>(this->memory(1)) : nullptr;
231
232     auto scratchpad = this->scratchpad();
233     float *bf16cvt_wsp_ = scratchpad.template get<float>(
234                           memory_tracking::names::key_pool_src_bf16cvt);
235
236     const memory_desc_wrapper ws_d(pd()->workspace_pd());
237     const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
238
239     const int MB = pd()->MB();
240     const int C = pd()->C();
241     const int OD = pd()->OD();
242     const int OH = pd()->OH();
243     const int OW = pd()->OW();
244     const int ID = pd()->ID();
245     const int IH = pd()->IH();
246     const int IW = pd()->IW();
247     const int KD = pd()->KD();
248     const int KH = pd()->KH();
249     const int KW = pd()->KW();
250     const int SD = pd()->KSD();
251     const int SH = pd()->KSH();
252     const int SW = pd()->KSW();
253     const int padF = pd()->padFront();
254     const int padT = pd()->padT();
255     const int padL = pd()->padL();
256     const int padB = pd()->padB();
257     const int padR = pd()->padR();
258     const int padBack = pd()->padBack();
259
260     const size_t simd_w_ = 16;
261     const size_t src_size_ = MB * C * ID * IH * IW;
262     const size_t blocked_size_ = src_size_ / simd_w_;
263     const size_t tail_size_ = src_size_ % simd_w_;
264
265     auto alg = pd()->desc()->alg_kind;
266
267     auto set_ws = [=](int mb, int c, int od, int oh, int ow, int value) {
268         // value = -1 means that pool window is placed outside of source domain
269         // for current {od, oh, ow} point
270         if (ws) {
271             assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
272             size_t ws_offset
273                 = (size_t)OW * OH * OD * C * mb
274                 + (size_t)OW * OH * OD * c
275                 + (size_t)OW * OH * od
276                 + (size_t)OW * oh
277                 + (size_t)ow;
278             if (ws_dt == data_type::u8) {
279                 const int u8_max = numeric_limits<
280                     typename prec_traits<data_type::u8>::type>::max();
281                 if (value == -1)
282                     value = u8_max;
283                 assert(0 <= value && value <= u8_max);
284                 ws[ws_offset] = value;
285             } else
286                 reinterpret_cast<int *>(ws)[ws_offset] = value;
287         }
288     };
289
290     auto ker_max = [=](float *d, const float *src_,
291             int mb, int c, int od, int oh, int ow) {
292         bool is_initialized = false;
293         int current_pool_size = 0;
294         for (int kd = 0; kd < KD; ++kd) {
295             for (int kh = 0; kh < KH; ++kh) {
296                 for (int kw = 0; kw < KW; ++kw) {
297                     const int id = od * SD - padF + kd;
298                     const int ih = oh * SH - padT + kh;
299                     const int iw = ow * SW - padL + kw;
300
301                     if (id < 0 || id >= ID) continue;
302                     if (ih < 0 || ih >= IH) continue;
303                     if (iw < 0 || iw >= IW) continue;
304
305                     auto src_offset =
306                         + (size_t)IW * IH * kd
307                         + (size_t)IW * kh
308                         + (size_t)kw;
309                     auto s = src_[src_offset];
310                     if (!is_initialized) {
311                         d[0] = s;
312                         set_ws(mb, c, od, oh, ow, kd*KH*KW + kh*KW + kw);
313                         is_initialized = true;
314                     } else {
315                         if (s > d[0]) {
316                             d[0] = s;
317                             set_ws(mb, c, od, oh, ow, kd*KH*KW + kh*KW + kw);
318                         }
319                     }
320                     current_pool_size++;
321                 }
322             }
323         }
324
325         // corner case: pool window is outside of real input domain
326         // for this point.
327         if (current_pool_size == 0)
328             set_ws(mb, c, od, oh, ow, -1);
329     };
330
331     auto ker_avg = [=](float *d, const float *src_,
332                        int mb, int c, int od, int oh, int ow) {
333         auto id_start = od*SD - padF;
334         auto ih_start = oh*SH - padT;
335         auto iw_start = ow*SW - padL;
336         auto id_end = nstl::min(od*SD - padF + KD, ID + padBack);
337         auto ih_end = nstl::min(oh*SH - padT + KH, IH + padB);
338         auto iw_end = nstl::min(ow*SW - padL + KW, IW + padR);
339
340         // case alg == pooling_avg_include_padding
341         auto num_summands = (id_end - id_start)*(ih_end - ih_start)*(iw_end - iw_start);
342
343         id_start = nstl::max(id_start, 0);
344         ih_start = nstl::max(ih_start, 0);
345         iw_start = nstl::max(iw_start, 0);
346
347         id_end = nstl::min(id_end, ID);
348         ih_end = nstl::min(ih_end, IH);
349         iw_end = nstl::min(iw_end, IW);
350
351         if (alg == pooling_avg_exclude_padding)
352             num_summands = (id_end - id_start)*(ih_end - ih_start)*(iw_end - iw_start);
353         if (num_summands == 0) return;
354
355         for (int id = id_start; id < id_end; ++id) {
356             for (int ih = ih_start; ih < ih_end; ++ih) {
357                 for (int iw = iw_start; iw < iw_end; ++iw) {
358                     auto src_offset
359                         = (size_t)IW * IH * id
360                         + (size_t)IW * ih
361                         + (size_t)iw;
362                     d[0] += src_[src_offset];
363                 }
364             }
365         }
366
367         d[0] = math::out_round<float>((float)d[0] / num_summands);
368     };
369
370     parallel_nd(blocked_size_, [&](size_t i) {
371         cvt_bfloat16_to_float(&bf16cvt_wsp_[i * simd_w_],
372             &src[i * simd_w_], simd_w_);});
373     if (tail_size_)
374         cvt_bfloat16_to_float(&bf16cvt_wsp_[blocked_size_ * simd_w_],
375             &src[blocked_size_ * simd_w_], tail_size_);
376
377     if (pd()->desc()->alg_kind == pooling_max) {
378         parallel_nd(MB, C, OD, OH, OW,
379             [&](int mb, int c, int od, int oh, int ow) {
380
381             size_t dst_offset
382                 = (size_t)OW * OH * OD * C * mb
383                 + (size_t)OW * OH * OD * c
384                 + (size_t)OW * OH * od
385                 + (size_t)OW * oh
386                 + (size_t)ow;
387             auto src_offset
388                 = (size_t)IW * IH * ID * C * mb
389                 + (size_t)IW * IH * ID * c
390                 + (size_t)IW * IH * (od * SD - padF)
391                 + (size_t)IW * (oh * SH - padT)
392                 + (size_t)(ow * SW - padL);
393
394             set_ws(mb, c, od, oh, ow, 0);
395
396             const float *src_ = &bf16cvt_wsp_[src_offset];
397
398             float d_fp32 = cvt_bfloat16_to_float(approx_bfloat16_lowest());
399             ker_max(&d_fp32, src_, mb, c, od, oh, ow);
400             dst[dst_offset] = cvt_float_to_bfloat16(d_fp32);
401         });
402     } else {
403         parallel_nd(MB, C, OD, OH, OW,
404             [&](int mb, int c, int od, int oh, int ow) {
405             size_t dst_offset
406                 = (size_t)OW * OH * OD * C * mb
407                 + (size_t)OW * OH * OD * c
408                 + (size_t)OW * OH * od
409                 + (size_t)OW * oh
410                 + (size_t)ow;
411             auto src_offset
412                 = (size_t)IW * IH * ID * C * mb
413                 + (size_t)IW * IH * ID * c;
414
415             const float *src_ = &bf16cvt_wsp_[src_offset];
416
417             float d_fp32 = 0.0f;
418             ker_avg(&d_fp32, src_, mb, c, od, oh, ow);
419             dst[dst_offset] = cvt_float_to_bfloat16(d_fp32);
420         });
421     }
422 }
423
424 template <data_type_t d_type>
425 void nchw_pooling_bwd_t<d_type>::execute_backward() const {
426     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
427     auto ws = pd()->desc()->alg_kind != alg_kind::pooling_max ? nullptr :
428         reinterpret_cast<const unsigned char *>(this->input_memory(1));
429     auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
430
431     const memory_desc_wrapper ws_d(pd()->workspace_pd());
432
433     const int MB = pd()->MB();
434     const int C = pd()->C();
435     const int OD = pd()->OD();
436     const int OH = pd()->OH();
437     const int OW = pd()->OW();
438     const int ID = pd()->ID();
439     const int IH = pd()->IH();
440     const int IW = pd()->IW();
441     const int KD = pd()->KD();
442     const int KH = pd()->KH();
443     const int KW = pd()->KW();
444     const int SD = pd()->KSD();
445     const int SH = pd()->KSH();
446     const int SW = pd()->KSW();
447     const int padF = pd()->padFront();
448     const int padT = pd()->padT();
449     const int padL = pd()->padL();
450
451     const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
452
453     auto alg = pd()->desc()->alg_kind;
454
455     auto apply_offset = [=](int index, int offset) {
456         return (index > offset) ? index - offset : 0;
457     };
458
459     auto ker_zero = [=](data_t *diff_src) {
460         size_t diff_src_offset = 0;
461         for (int id = 0; id < ID; ++id) {
462             for (int ih = 0; ih < IH; ++ih) {
463                 for (int iw = 0; iw < IW; ++iw) {
464                     diff_src[diff_src_offset++] = 0;
465                 }
466             }
467         }
468     };
469
470     auto ker_max = [=](const data_t *d,  data_t *diff_src_,
471             int mb, int c, int od, int oh, int ow) {
472         auto b_c = ws_d.blocking_desc().block_dims[1];
473         auto ws_offset = is_3d
474             ? ws_d.blk_off(mb, c / b_c, od, oh, ow) + c % b_c
475             : ws_d.blk_off(mb, c / b_c, oh, ow) + c % b_c;
476
477         const int index = ws_d.data_type() == data_type::u8
478             ? (int)ws[ws_offset] : ((const int *)ws)[ws_offset];
479         const int invalid_index_value = ws_d.data_type() == data_type::u8
480             ? numeric_limits<typename prec_traits<data_type::u8>::type>::max()
481             : -1;
482         if (index == invalid_index_value)
483            return; // corner case: pool window is outside of real input domain
484                    // for this point, do nothing
485
486         const int kw = index % KW;
487         const int kh = (index / KW) % KH;
488         const int kd = (index / KW) / KH;
489
490         const int id = od * SD - padF + kd;
491         const int ih = oh * SH - padT + kh;
492         const int iw = ow * SW - padL + kw;
493
494         // If padding area could fit the kernel,
495         // then input displacement would be out of bounds.
496         // No need to back propagate there as padding is
497         // virtual in pooling_max case.
498         if (id < 0 || id >= ID) return;
499         if (ih < 0 || ih >= IH) return;
500         if (iw < 0 || iw >= IW) return;
501
502         size_t diff_src_offset
503             = (size_t)IH * IW * id
504             + (size_t)IW * ih
505             + (size_t)iw;
506         diff_src_[diff_src_offset] += d[0];
507     };
508
509     auto ker_avg = [=](const data_t *d, data_t *diff_src_,
510                         int mb, int c, int od, int oh, int ow) {
511         auto id_start = apply_offset(od*SD, padF);
512         auto ih_start = apply_offset(oh*SH, padT);
513         auto iw_start = apply_offset(ow*SW, padL);
514         auto id_end = nstl::min(od*SD - padF + KD, ID);
515         auto ih_end = nstl::min(oh*SH - padT + KH, IH);
516         auto iw_end = nstl::min(ow*SW - padL + KW, IW);
517
518         size_t num_summands = (alg == pooling_avg_include_padding)
519             ? (size_t)KW*KH*KD
520             : (size_t)(id_end - id_start)*(ih_end - ih_start)
521                 *(iw_end - iw_start);
522
523         for (int id = id_start; id < id_end; ++id) {
524             for (int ih = ih_start; ih < ih_end; ++ih) {
525                 for (int iw = iw_start; iw < iw_end; ++iw) {
526                     size_t diff_src_offset
527                         = (size_t)id*IH*IW
528                         + (size_t)ih*IW
529                         + (size_t)iw;
530                     diff_src_[diff_src_offset] += d[0] / num_summands;
531                 }
532             }
533         }
534     };
535
536     if (pd()->desc()->alg_kind == pooling_max) {
537         parallel_nd(MB, C, [&](int mb, int c) {
538             size_t diff_src_offset = (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW;
539             size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW + (size_t)c*OD*OH*OW;
540             const data_t* d =
541                 reinterpret_cast<const data_t*>(&diff_dst[diff_dst_offset]);
542             data_t* diff_src_ =
543                 reinterpret_cast<data_t*>(&diff_src[diff_src_offset]);
544             ker_zero(diff_src_);
545             size_t count = 0;
546             for (int od = 0; od < OD; ++od) {
547                 for (int oh = 0; oh < OH; ++oh) {
548                     for (int ow = 0; ow < OW; ++ow) {
549                         const data_t* local_d = &d[count++];
550                         ker_max(local_d, diff_src_, mb, c, od, oh, ow);
551                     }
552                 }
553             }
554         });
555     } else {
556         parallel_nd(MB, C, [&](int mb, int c) {
557             size_t diff_src_offset = (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW;
558             size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW + (size_t)c*OD*OH*OW;
559             const data_t* d =
560                 reinterpret_cast<const data_t*>(&diff_dst[diff_dst_offset]);
561             data_t* diff_src_ =
562                 reinterpret_cast<data_t*>(&diff_src[diff_src_offset]);
563             ker_zero(diff_src_);
564             size_t count  = 0;
565             for (int od = 0; od < OD; ++od) {
566                 for (int oh = 0; oh < OH; ++oh) {
567                     for (int ow = 0; ow < OW; ++ow) {
568                         const data_t* local_d = &d[count++];
569                         ker_avg(local_d, diff_src_, mb, c, od, oh, ow);
570                     }
571                 }
572             }
573         });
574     }
575 }
576
577 template <>
578 void nchw_pooling_bwd_t<data_type::bf16>::execute_backward() const {
579     auto diff_dst = reinterpret_cast<const mkldnn_bfloat16_t *>(this->input_memory(0));
580     auto ws = pd()->desc()->alg_kind != alg_kind::pooling_max ? nullptr :
581         reinterpret_cast<const unsigned char *>(this->input_memory(1));
582     auto diff_src = reinterpret_cast<mkldnn_bfloat16_t*>(this->memory(0));
583
584     auto scratchpad = this->scratchpad();
585     float *bf16cvt_src_ = scratchpad.template get<float>(
586             memory_tracking::names::key_pool_src_bf16cvt);
587     float *bf16cvt_dst_ = scratchpad.template get<float>(
588             memory_tracking::names::key_pool_dst_bf16cvt);
589
590     const memory_desc_wrapper ws_d(pd()->workspace_pd());
591
592     const int MB = pd()->MB();
593     const int C = pd()->C();
594     const int OD = pd()->OD();
595     const int OH = pd()->OH();
596     const int OW = pd()->OW();
597     const int ID = pd()->ID();
598     const int IH = pd()->IH();
599     const int IW = pd()->IW();
600     const int KD = pd()->KD();
601     const int KH = pd()->KH();
602     const int KW = pd()->KW();
603     const int SD = pd()->KSD();
604     const int SH = pd()->KSH();
605     const int SW = pd()->KSW();
606     const int padF = pd()->padFront();
607     const int padT = pd()->padT();
608     const int padL = pd()->padL();
609     const size_t dst_sp_sz = pd()->OD() * pd()->OH() * pd()->OW();
610     const size_t src_sp_sz = pd()->ID() * pd()->IH() * pd()->IW();
611
612     const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
613
614     auto alg = pd()->desc()->alg_kind;
615
616     auto apply_offset = [=](int index, int offset) {
617         return (index > offset) ? index - offset : 0;
618     };
619
620     auto ker_zero = [=](float *diff_src) {
621         size_t diff_src_offset = 0;
622         for (int id = 0; id < ID; ++id) {
623             for (int ih = 0; ih < IH; ++ih) {
624                 for (int iw = 0; iw < IW; ++iw) {
625                     diff_src[diff_src_offset++] = 0.0f;
626                 }
627             }
628         }
629     };
630
631     auto ker_max = [=](const float *d,  float *diff_src_,
632                         int mb, int c, int od, int oh, int ow) {
633         auto b_c = ws_d.blocking_desc().block_dims[1];
634         auto ws_offset = is_3d
635             ? ws_d.blk_off(mb, c / b_c, od, oh, ow) + c % b_c
636             : ws_d.blk_off(mb, c / b_c, oh, ow) + c % b_c;
637
638         const int index = ws_d.data_type() == data_type::u8
639             ? (int)ws[ws_offset] : ((const int *)ws)[ws_offset];
640         const int invalid_index_value = ws_d.data_type() == data_type::u8
641             ? numeric_limits<typename prec_traits<data_type::u8>::type>::max()
642             : -1;
643         if (index == invalid_index_value)
644            return; // corner case: pool window is outside of real input domain
645                    // for this point, do nothing
646
647         const int kw = index % KW;
648         const int kh = (index / KW) % KH;
649         const int kd = (index / KW) / KH;
650
651         const int id = od * SD - padF + kd;
652         const int ih = oh * SH - padT + kh;
653         const int iw = ow * SW - padL + kw;
654
655         // If padding area could fit the kernel,
656         // then input displacement would be out of bounds.
657         // No need to back propagate there as padding is
658         // virtual in pooling_max case.
659         if (id < 0 || id >= ID) return;
660         if (ih < 0 || ih >= IH) return;
661         if (iw < 0 || iw >= IW) return;
662
663         size_t diff_src_offset
664             = (size_t)IH * IW * id
665             + (size_t)IW * ih
666             + (size_t)iw;
667         diff_src_[diff_src_offset] += d[0];
668     };
669
670     auto ker_avg = [=](const float *d, float *diff_src_,
671                         int mb, int c, int od, int oh, int ow) {
672         auto id_start = apply_offset(od*SD, padF);
673         auto ih_start = apply_offset(oh*SH, padT);
674         auto iw_start = apply_offset(ow*SW, padL);
675         auto id_end = nstl::min(od*SD - padF + KD, ID);
676         auto ih_end = nstl::min(oh*SH - padT + KH, IH);
677         auto iw_end = nstl::min(ow*SW - padL + KW, IW);
678
679         size_t num_summands = (alg == pooling_avg_include_padding)
680             ? (size_t)KW*KH*KD
681             : (size_t)(id_end - id_start)*(ih_end - ih_start)
682                 *(iw_end - iw_start);
683
684         for (int id = id_start; id < id_end; ++id) {
685             for (int ih = ih_start; ih < ih_end; ++ih) {
686                 for (int iw = iw_start; iw < iw_end; ++iw) {
687                     size_t diff_src_offset
688                         = (size_t)id*IH*IW
689                         + (size_t)ih*IW
690                         + (size_t)iw;
691                     diff_src_[diff_src_offset] += d[0] / num_summands;
692                 }
693             }
694         }
695     };
696
697     if (pd()->desc()->alg_kind == pooling_max) {
698         parallel_nd(MB, C, [&](int mb, int c) {
699             size_t diff_src_offset = (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW;
700             size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW + (size_t)c*OD*OH*OW;
701             float *src_fp32_ = &bf16cvt_src_[mkldnn_get_thread_num()
702                                                         * src_sp_sz];
703             float *dst_fp32_ = &bf16cvt_dst_[mkldnn_get_thread_num()
704                                                         * dst_sp_sz];
705             ker_zero(src_fp32_);
706
707             cvt_bfloat16_to_float(dst_fp32_, &diff_dst[diff_dst_offset],
708                 dst_sp_sz);
709
710             size_t idx = 0;
711             for (int od = 0; od < OD; ++od) {
712                 for (int oh = 0; oh < OH; ++oh) {
713                     for (int ow = 0; ow < OW; ++ow) {
714                         ker_max(&dst_fp32_[idx++], src_fp32_, mb, c, od, oh, ow);
715                     }
716                 }
717             }
718             cvt_float_to_bfloat16(&diff_src[diff_src_offset], src_fp32_,
719                 src_sp_sz);
720         });
721     } else {
722         parallel_nd(MB, C, [&](int mb, int c) {
723             size_t diff_src_offset = (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW;
724             size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW + (size_t)c*OD*OH*OW;
725             float *src_fp32_ = &bf16cvt_src_[mkldnn_get_thread_num()
726                                                         * src_sp_sz];
727             float *dst_fp32_ = &bf16cvt_dst_[mkldnn_get_thread_num()
728                                                         * dst_sp_sz];
729             ker_zero(src_fp32_);
730
731             cvt_bfloat16_to_float(dst_fp32_, &diff_dst[diff_dst_offset],
732                 dst_sp_sz);
733
734             size_t idx = 0;
735             for (int od = 0; od < OD; ++od) {
736                 for (int oh = 0; oh < OH; ++oh) {
737                     for (int ow = 0; ow < OW; ++ow) {
738                         ker_avg(&dst_fp32_[idx++], src_fp32_, mb, c, od, oh, ow);
739                     }
740                 }
741             }
742             cvt_float_to_bfloat16(&diff_src[diff_src_offset], src_fp32_,
743                 src_sp_sz);
744         });
745     }
746 }
747 template struct nchw_pooling_fwd_t<data_type::f32>;
748 template struct nchw_pooling_bwd_t<data_type::f32>;
749 template struct nchw_pooling_fwd_t<data_type::bf16>;
750 template struct nchw_pooling_bwd_t<data_type::bf16>;
751
752 }
753 }
754 }
755
756 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s