1 /*******************************************************************************
2 * Copyright 2017-2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
20 #include "c_types_map.hpp"
21 #include "type_helpers.hpp"
22 #include "math_utils.hpp"
23 #include "mkldnn_thread.hpp"
26 #include "nchw_pooling.hpp"
33 using namespace alg_kind;
34 using namespace bf16_cvt_utils;
// Forward pooling for the generic data_t case.
// Reads the source from input_memory(0) and writes the result to memory(0).
// For pooling_max the buffer at memory(1) is additionally used as a workspace
// that receives one index per output point (see set_ws below).
// Source, destination and workspace are all addressed with plain linear
// offsets built from (mb, c, depth, height, width).
36 template <data_type_t d_type>
37 void nchw_pooling_fwd_t<d_type>::execute_forward() const {
39 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
40 auto dst = reinterpret_cast<data_t*>(this->memory(0));
// The workspace pointer is only taken for max pooling; otherwise it is null.
41 auto ws = pd()->desc()->alg_kind == alg_kind::pooling_max ?
42 reinterpret_cast<unsigned char *>(this->memory(1)) : nullptr;
44 const memory_desc_wrapper ws_d(pd()->workspace_pd());
45 const memory_desc_wrapper src_d(pd()->src_pd());
46 const memory_desc_wrapper dst_d(pd()->dst_pd());
// Without a workspace the workspace data type is reported as undef.
47 const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
// Advance both pointers by the offset the descriptors report for index 0.
49 src += src_d.off_l(0);
50 dst += dst_d.off_l(0);
// Problem sizes, kernel extents, strides and paddings, all read from pd().
52 const int MB = pd()->MB();
53 const int C = pd()->C();
54 const int OD = pd()->OD();
55 const int OH = pd()->OH();
56 const int OW = pd()->OW();
57 const int ID = pd()->ID();
58 const int IH = pd()->IH();
59 const int IW = pd()->IW();
60 const int KD = pd()->KD();
61 const int KH = pd()->KH();
62 const int KW = pd()->KW();
63 const int SD = pd()->KSD();
64 const int SH = pd()->KSH();
65 const int SW = pd()->KSW();
66 const int padF = pd()->padFront();
67 const int padT = pd()->padT();
68 const int padL = pd()->padL();
69 const int padB = pd()->padB();
70 const int padR = pd()->padR();
71 const int padBack = pd()->padBack();
73 auto alg = pd()->desc()->alg_kind;
// set_ws: stores `value` in the workspace slot of output point
// {mb, c, od, oh, ow}. The slot is addressed by a linear offset; the store is
// one unsigned char wide for u8 workspaces and one int wide otherwise.
75 auto set_ws = [=](int mb, int c, int od, int oh, int ow, int value) {
76 // value = -1 means that pool window is placed outside of source domain
77 // for current {od, oh, ow} point
79 assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
81 = (size_t)OW * OH * OD * C * mb
82 + (size_t)OW * OH * OD * c
83 + (size_t)OW * OH * od
86 if (ws_dt == data_type::u8) {
87 const int u8_max = numeric_limits<
88 typename prec_traits<data_type::u8>::type>::max();
// By the time the byte is stored the value must fit the u8 range.
91 assert(0 <= value && value <= u8_max);
92 ws[ws_offset] = value;
94 reinterpret_cast<int *>(ws)[ws_offset] = value;
// ker_max: max-pooling kernel for a single output point. `d` points at the
// destination element and `src_` at the source position of the window origin.
98 auto ker_max = [=](data_t *d, const data_t *src_, int mb, int c, int od, int oh, int ow) {
99 bool is_initialized = false;
100 int current_pool_size = 0;
// Visit every tap of the KD x KH x KW window.
101 for (int kd = 0; kd < KD; ++kd) {
102 for (int kh = 0; kh < KH; ++kh) {
103 for (int kw = 0; kw < KW; ++kw) {
// Input coordinates addressed by this tap.
104 const int id = od * SD - padF + kd;
105 const int ih = oh * SH - padT + kh;
106 const int iw = ow * SW - padL + kw;
// Taps that land outside the input (i.e. in the padding) are skipped.
108 if (id < 0 || id >= ID) continue;
109 if (ih < 0 || ih >= IH) continue;
110 if (iw < 0 || iw >= IW) continue;
113 + (size_t)IW * IH * kd
116 auto s = src_[src_offset];
// The first in-bounds tap initializes the result; the workspace receives the
// flattened kernel position kd*KH*KW + kh*KW + kw of the recorded tap.
117 if (!is_initialized) {
119 set_ws(mb, c, od, oh, ow, kd * KH * KW + kh * KW + kw);
120 is_initialized = true;
124 set_ws(mb, c, od, oh, ow, kd * KH * KW + kh * KW + kw);
132 // corner case: pool window is outside of real input domain
134 if (current_pool_size == 0)
135 set_ws(mb, c, od, oh, ow, -1);
// ker_avg: average-pooling kernel for a single output point. `src_` points at
// the start of the (mb, c) source slice; the sum is accumulated into d[0].
138 auto ker_avg = [=](data_t *d, const data_t *src_,
139 int mb, int c, int od, int oh, int ow) {
// Window bounds in input coordinates. The start may be negative here and the
// end is limited to the input extent plus the trailing padding.
140 auto id_start = od*SD - padF;
141 auto ih_start = oh*SH - padT;
142 auto iw_start = ow*SW - padL;
143 auto id_end = nstl::min(od*SD - padF + KD, ID + padBack);
144 auto ih_end = nstl::min(oh*SH - padT + KH, IH + padB);
145 auto iw_end = nstl::min(ow*SW - padL + KW, IW + padR);
// include_padding divides by the full kernel volume; any other algorithm
// starts from the volume of the window computed above.
147 auto num_summands = (alg == pooling_avg_include_padding) ? KD*KW*KH
148 : (id_end - id_start)*(ih_end - ih_start)*(iw_end - iw_start);
// Clamp the window to the real input before reading from it.
150 id_start = nstl::max(id_start, 0);
151 ih_start = nstl::max(ih_start, 0);
152 iw_start = nstl::max(iw_start, 0);
154 id_end = nstl::min(id_end, ID);
155 ih_end = nstl::min(ih_end, IH);
156 iw_end = nstl::min(iw_end, IW);
// exclude_padding divides by the number of in-bounds elements only.
158 if (alg == pooling_avg_exclude_padding)
159 num_summands = (id_end - id_start)*(ih_end - ih_start)*(iw_end - iw_start);
// A zero divisor means there is nothing to average for this point.
160 if (num_summands == 0) return;
162 for (int id = id_start; id < id_end; ++id) {
163 for (int ih = ih_start; ih < ih_end; ++ih) {
164 for (int iw = iw_start; iw < iw_end; ++iw) {
166 = (size_t)IW * IH * id
169 d[0] += src_[src_offset];
// Turn the accumulated sum into the average, rounded through math::out_round.
174 d[0] = math::out_round<data_t>((data_t)d[0] / num_summands);
// Max pooling: one parallel_nd work item per output point.
177 if (pd()->desc()->alg_kind == pooling_max) {
178 parallel_nd(MB, C, OD, OH, OW,
179 [&](int mb, int c, int od, int oh, int ow) {
182 = (size_t)OW * OH * OD * C * mb
183 + (size_t)OW * OH * OD * c
184 + (size_t)OW * OH * od
// Here the source offset already includes the window origin
// (od*SD - padF, oh*SH - padT, ow*SW - padL), so ker_max indexes by tap only.
188 = (size_t)IW * IH * ID * C * mb
189 + (size_t)IW * IH * ID * c
190 + (size_t)IW * IH * (od * SD - padF)
191 + (size_t)IW * (oh * SH - padT)
192 + (size_t)(ow * SW - padL);
// Write workspace value 0 for this point before running the kernel, which
// may overwrite it through its own set_ws calls.
194 set_ws(mb, c, od, oh, ow, 0);
196 data_t *d = reinterpret_cast<data_t*>(&dst[dst_offset]);
199 reinterpret_cast<const data_t*>(&src[src_offset]);
201 ker_max(d, src_, mb, c, od, oh, ow);
// Second parallel loop: runs ker_avg for every output point.
204 parallel_nd(MB, C, OD, OH, OW,
205 [&](int mb, int c, int od, int oh, int ow) {
207 = (size_t)OW * OH * OD * C * mb
208 + (size_t)OW * OH * OD * c
209 + (size_t)OW * OH * od
// Only the (mb, c) slice base is used here; ker_avg adds the spatial part.
213 = (size_t)IW * IH * ID * C * mb
214 + (size_t)IW * IH * ID * c;
216 data_t *d = reinterpret_cast<data_t*>(&dst[dst_offset]);
219 reinterpret_cast<const data_t*>(&src[src_offset]);
220 ker_avg(d, src_, mb, c, od, oh, ow);
// Forward pooling for bf16 data.
// The bf16 source is first converted to float into a scratchpad buffer, the
// kernels then run on float values, and every result is converted back with
// cvt_float_to_bfloat16 before being stored to the destination.
226 void nchw_pooling_fwd_t<data_type::bf16>::execute_forward() const {
227 auto src = reinterpret_cast<const mkldnn_bfloat16_t *>(this->input_memory(0));
228 auto dst = reinterpret_cast<mkldnn_bfloat16_t*>(this->memory(0));
// The workspace pointer is only taken for max pooling; otherwise it is null.
229 auto ws = pd()->desc()->alg_kind == alg_kind::pooling_max ?
230 reinterpret_cast<unsigned char *>(this->memory(1)) : nullptr;
// Float buffer from the scratchpad that receives the converted source.
232 auto scratchpad = this->scratchpad();
233 float *bf16cvt_wsp_ = scratchpad.template get<float>(
234 memory_tracking::names::key_pool_src_bf16cvt);
236 const memory_desc_wrapper ws_d(pd()->workspace_pd());
// Without a workspace the workspace data type is reported as undef.
237 const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
// Problem sizes, kernel extents, strides and paddings, all read from pd().
239 const int MB = pd()->MB();
240 const int C = pd()->C();
241 const int OD = pd()->OD();
242 const int OH = pd()->OH();
243 const int OW = pd()->OW();
244 const int ID = pd()->ID();
245 const int IH = pd()->IH();
246 const int IW = pd()->IW();
247 const int KD = pd()->KD();
248 const int KH = pd()->KH();
249 const int KW = pd()->KW();
250 const int SD = pd()->KSD();
251 const int SH = pd()->KSH();
252 const int SW = pd()->KSW();
253 const int padF = pd()->padFront();
254 const int padT = pd()->padT();
255 const int padL = pd()->padL();
256 const int padB = pd()->padB();
257 const int padR = pd()->padR();
258 const int padBack = pd()->padBack();
// The source is converted in groups of simd_w_ elements plus a shorter tail.
// NOTE(review): MB * C * ID * IH * IW is evaluated on int operands before the
// result is widened to size_t, so a very large source could overflow here.
260 const size_t simd_w_ = 16;
261 const size_t src_size_ = MB * C * ID * IH * IW;
262 const size_t blocked_size_ = src_size_ / simd_w_;
263 const size_t tail_size_ = src_size_ % simd_w_;
265 auto alg = pd()->desc()->alg_kind;
// set_ws: stores `value` in the workspace slot of output point
// {mb, c, od, oh, ow}. The slot is addressed by a linear offset; the store is
// one unsigned char wide for u8 workspaces and one int wide otherwise.
267 auto set_ws = [=](int mb, int c, int od, int oh, int ow, int value) {
268 // value = -1 means that pool window is placed outside of source domain
269 // for current {od, oh, ow} point
271 assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
273 = (size_t)OW * OH * OD * C * mb
274 + (size_t)OW * OH * OD * c
275 + (size_t)OW * OH * od
278 if (ws_dt == data_type::u8) {
279 const int u8_max = numeric_limits<
280 typename prec_traits<data_type::u8>::type>::max();
// By the time the byte is stored the value must fit the u8 range.
283 assert(0 <= value && value <= u8_max);
284 ws[ws_offset] = value;
286 reinterpret_cast<int *>(ws)[ws_offset] = value;
// ker_max: max-pooling kernel for a single output point, working on floats.
// `d` points at the float result and `src_` at the converted source position
// of the window origin.
290 auto ker_max = [=](float *d, const float *src_,
291 int mb, int c, int od, int oh, int ow) {
292 bool is_initialized = false;
293 int current_pool_size = 0;
// Visit every tap of the KD x KH x KW window.
294 for (int kd = 0; kd < KD; ++kd) {
295 for (int kh = 0; kh < KH; ++kh) {
296 for (int kw = 0; kw < KW; ++kw) {
// Input coordinates addressed by this tap.
297 const int id = od * SD - padF + kd;
298 const int ih = oh * SH - padT + kh;
299 const int iw = ow * SW - padL + kw;
// Taps that land outside the input (i.e. in the padding) are skipped.
301 if (id < 0 || id >= ID) continue;
302 if (ih < 0 || ih >= IH) continue;
303 if (iw < 0 || iw >= IW) continue;
306 + (size_t)IW * IH * kd
309 auto s = src_[src_offset];
// The first in-bounds tap initializes the result; the workspace receives the
// flattened kernel position kd*KH*KW + kh*KW + kw of the recorded tap.
310 if (!is_initialized) {
312 set_ws(mb, c, od, oh, ow, kd*KH*KW + kh*KW + kw);
313 is_initialized = true;
317 set_ws(mb, c, od, oh, ow, kd*KH*KW + kh*KW + kw);
325 // corner case: pool window is outside of real input domain
327 if (current_pool_size == 0)
328 set_ws(mb, c, od, oh, ow, -1);
// ker_avg: average-pooling kernel for a single output point, working on
// floats. `src_` points at the start of the converted (mb, c) source slice.
331 auto ker_avg = [=](float *d, const float *src_,
332 int mb, int c, int od, int oh, int ow) {
// Window bounds in input coordinates. The start may be negative here and the
// end is limited to the input extent plus the trailing padding.
333 auto id_start = od*SD - padF;
334 auto ih_start = oh*SH - padT;
335 auto iw_start = ow*SW - padL;
336 auto id_end = nstl::min(od*SD - padF + KD, ID + padBack);
337 auto ih_end = nstl::min(oh*SH - padT + KH, IH + padB);
338 auto iw_end = nstl::min(ow*SW - padL + KW, IW + padR);
// NOTE(review): the generic execute_forward in this file uses KD*KW*KH as the
// include_padding divisor, while this variant uses the volume of the window
// computed above. Confirm that the two are meant to differ.
340 // case alg == pooling_avg_include_padding
341 auto num_summands = (id_end - id_start)*(ih_end - ih_start)*(iw_end - iw_start);
// Clamp the window to the real input before reading from it.
343 id_start = nstl::max(id_start, 0);
344 ih_start = nstl::max(ih_start, 0);
345 iw_start = nstl::max(iw_start, 0);
347 id_end = nstl::min(id_end, ID);
348 ih_end = nstl::min(ih_end, IH);
349 iw_end = nstl::min(iw_end, IW);
// exclude_padding divides by the number of in-bounds elements only.
351 if (alg == pooling_avg_exclude_padding)
352 num_summands = (id_end - id_start)*(ih_end - ih_start)*(iw_end - iw_start);
// A zero divisor means there is nothing to average for this point.
353 if (num_summands == 0) return;
355 for (int id = id_start; id < id_end; ++id) {
356 for (int ih = ih_start; ih < ih_end; ++ih) {
357 for (int iw = iw_start; iw < iw_end; ++iw) {
359 = (size_t)IW * IH * id
362 d[0] += src_[src_offset];
// Turn the accumulated sum into the average, rounded through math::out_round.
367 d[0] = math::out_round<float>((float)d[0] / num_summands);
// Convert the whole source to float: full groups of simd_w_ elements are
// handled by parallel_nd, the remaining tail_size_ elements by one extra call.
370 parallel_nd(blocked_size_, [&](size_t i) {
371 cvt_bfloat16_to_float(&bf16cvt_wsp_[i * simd_w_],
372 &src[i * simd_w_], simd_w_);});
374 cvt_bfloat16_to_float(&bf16cvt_wsp_[blocked_size_ * simd_w_],
375 &src[blocked_size_ * simd_w_], tail_size_);
// Max pooling: one parallel_nd work item per output point.
377 if (pd()->desc()->alg_kind == pooling_max) {
378 parallel_nd(MB, C, OD, OH, OW,
379 [&](int mb, int c, int od, int oh, int ow) {
382 = (size_t)OW * OH * OD * C * mb
383 + (size_t)OW * OH * OD * c
384 + (size_t)OW * OH * od
// Here the source offset already includes the window origin
// (od*SD - padF, oh*SH - padT, ow*SW - padL), so ker_max indexes by tap only.
388 = (size_t)IW * IH * ID * C * mb
389 + (size_t)IW * IH * ID * c
390 + (size_t)IW * IH * (od * SD - padF)
391 + (size_t)IW * (oh * SH - padT)
392 + (size_t)(ow * SW - padL);
// Write workspace value 0 for this point before running the kernel, which
// may overwrite it through its own set_ws calls.
394 set_ws(mb, c, od, oh, ow, 0);
396 const float *src_ = &bf16cvt_wsp_[src_offset];
// The float result handed to ker_max starts at the float value of
// approx_bfloat16_lowest(); the final value is converted back to bf16.
398 float d_fp32 = cvt_bfloat16_to_float(approx_bfloat16_lowest());
399 ker_max(&d_fp32, src_, mb, c, od, oh, ow);
400 dst[dst_offset] = cvt_float_to_bfloat16(d_fp32);
// Second parallel loop: runs ker_avg for every output point.
403 parallel_nd(MB, C, OD, OH, OW,
404 [&](int mb, int c, int od, int oh, int ow) {
406 = (size_t)OW * OH * OD * C * mb
407 + (size_t)OW * OH * OD * c
408 + (size_t)OW * OH * od
// Only the (mb, c) slice base is used here; ker_avg adds the spatial part.
412 = (size_t)IW * IH * ID * C * mb
413 + (size_t)IW * IH * ID * c;
415 const float *src_ = &bf16cvt_wsp_[src_offset];
// The float average is converted back to bf16 when stored.
418 ker_avg(&d_fp32, src_, mb, c, od, oh, ow);
419 dst[dst_offset] = cvt_float_to_bfloat16(d_fp32);
// Backward pooling for the generic data_t case.
// Reads diff_dst from input_memory(0) and accumulates gradients into diff_src
// at memory(0). For pooling_max the workspace at input_memory(1) supplies the
// index stored for every output point.
424 template <data_type_t d_type>
425 void nchw_pooling_bwd_t<d_type>::execute_backward() const {
426 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
// The workspace is only read for max pooling; otherwise the pointer is null.
427 auto ws = pd()->desc()->alg_kind != alg_kind::pooling_max ? nullptr :
428 reinterpret_cast<const unsigned char *>(this->input_memory(1));
429 auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
431 const memory_desc_wrapper ws_d(pd()->workspace_pd());
// Problem sizes, kernel extents, strides and leading paddings from pd().
433 const int MB = pd()->MB();
434 const int C = pd()->C();
435 const int OD = pd()->OD();
436 const int OH = pd()->OH();
437 const int OW = pd()->OW();
438 const int ID = pd()->ID();
439 const int IH = pd()->IH();
440 const int IW = pd()->IW();
441 const int KD = pd()->KD();
442 const int KH = pd()->KH();
443 const int KW = pd()->KW();
444 const int SD = pd()->KSD();
445 const int SH = pd()->KSH();
446 const int SW = pd()->KSW();
447 const int padF = pd()->padFront();
448 const int padT = pd()->padT();
449 const int padL = pd()->padL();
// A diff_src descriptor with 5 dimensions selects the 3D workspace indexing.
451 const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
453 auto alg = pd()->desc()->alg_kind;
// apply_offset: index - offset, clamped so that the result is never negative.
455 auto apply_offset = [=](int index, int offset) {
456 return (index > offset) ? index - offset : 0;
// ker_zero: clears ID * IH * IW consecutive elements starting at diff_src.
459 auto ker_zero = [=](data_t *diff_src) {
460 size_t diff_src_offset = 0;
461 for (int id = 0; id < ID; ++id) {
462 for (int ih = 0; ih < IH; ++ih) {
463 for (int iw = 0; iw < IW; ++iw) {
464 diff_src[diff_src_offset++] = 0;
// ker_max: adds the gradient d[0] of one output point to the single input
// element whose kernel position was stored in the workspace.
470 auto ker_max = [=](const data_t *d, data_t *diff_src_,
471 int mb, int c, int od, int oh, int ow) {
// The workspace is addressed through its descriptor: channel block index
// c / b_c via blk_off, plus the position c % b_c inside the block.
472 auto b_c = ws_d.blocking_desc().block_dims[1];
473 auto ws_offset = is_3d
474 ? ws_d.blk_off(mb, c / b_c, od, oh, ow) + c % b_c
475 : ws_d.blk_off(mb, c / b_c, oh, ow) + c % b_c;
// Read the stored index with the width matching the workspace data type.
477 const int index = ws_d.data_type() == data_type::u8
478 ? (int)ws[ws_offset] : ((const int *)ws)[ws_offset];
// For u8 workspaces the invalid marker is the largest u8 value.
479 const int invalid_index_value = ws_d.data_type() == data_type::u8
480 ? numeric_limits<typename prec_traits<data_type::u8>::type>::max()
482 if (index == invalid_index_value)
483 return; // corner case: pool window is outside of real input domain
484 // for this point, do nothing
// Unflatten the index into kernel coordinates (inverse of
// kd*KH*KW + kh*KW + kw).
486 const int kw = index % KW;
487 const int kh = (index / KW) % KH;
488 const int kd = (index / KW) / KH;
// Input coordinates selected by that kernel position.
490 const int id = od * SD - padF + kd;
491 const int ih = oh * SH - padT + kh;
492 const int iw = ow * SW - padL + kw;
494 // If padding area could fit the kernel,
495 // then input displacement would be out of bounds.
496 // No need to back propagate there as padding is
497 // virtual in pooling_max case.
498 if (id < 0 || id >= ID) return;
499 if (ih < 0 || ih >= IH) return;
500 if (iw < 0 || iw >= IW) return;
502 size_t diff_src_offset
503 = (size_t)IH * IW * id
506 diff_src_[diff_src_offset] += d[0];
// ker_avg: spreads the gradient d[0] of one output point over every input
// element of its window, each receiving d[0] / num_summands.
509 auto ker_avg = [=](const data_t *d, data_t *diff_src_,
510 int mb, int c, int od, int oh, int ow) {
// Window bounds clamped to the real input: start >= 0, end <= input size.
511 auto id_start = apply_offset(od*SD, padF);
512 auto ih_start = apply_offset(oh*SH, padT);
513 auto iw_start = apply_offset(ow*SW, padL);
514 auto id_end = nstl::min(od*SD - padF + KD, ID);
515 auto ih_end = nstl::min(oh*SH - padT + KH, IH);
516 auto iw_end = nstl::min(ow*SW - padL + KW, IW);
// The divisor depends on the algorithm; when it is not include_padding the
// number of in-bounds window elements is used.
518 size_t num_summands = (alg == pooling_avg_include_padding)
520 : (size_t)(id_end - id_start)*(ih_end - ih_start)
521 *(iw_end - iw_start);
523 for (int id = id_start; id < id_end; ++id) {
524 for (int ih = ih_start; ih < ih_end; ++ih) {
525 for (int iw = iw_start; iw < iw_end; ++iw) {
526 size_t diff_src_offset
530 diff_src_[diff_src_offset] += d[0] / num_summands;
// Max pooling: parallel over (mb, c); each work item walks all output points
// of its slice in order and applies ker_max.
536 if (pd()->desc()->alg_kind == pooling_max) {
537 parallel_nd(MB, C, [&](int mb, int c) {
// Linear offsets of this (mb, c) slice in diff_src and diff_dst.
538 size_t diff_src_offset = (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW;
539 size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW + (size_t)c*OD*OH*OW;
541 reinterpret_cast<const data_t*>(&diff_dst[diff_dst_offset]);
543 reinterpret_cast<data_t*>(&diff_src[diff_src_offset]);
546 for (int od = 0; od < OD; ++od) {
547 for (int oh = 0; oh < OH; ++oh) {
548 for (int ow = 0; ow < OW; ++ow) {
// `count` steps through the diff_dst slice one output point at a time.
549 const data_t* local_d = &d[count++];
550 ker_max(local_d, diff_src_, mb, c, od, oh, ow);
// Second parallel loop: same traversal, applying ker_avg instead.
556 parallel_nd(MB, C, [&](int mb, int c) {
557 size_t diff_src_offset = (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW;
558 size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW + (size_t)c*OD*OH*OW;
560 reinterpret_cast<const data_t*>(&diff_dst[diff_dst_offset]);
562 reinterpret_cast<data_t*>(&diff_src[diff_src_offset]);
565 for (int od = 0; od < OD; ++od) {
566 for (int oh = 0; oh < OH; ++oh) {
567 for (int ow = 0; ow < OW; ++ow) {
568 const data_t* local_d = &d[count++];
569 ker_avg(local_d, diff_src_, mb, c, od, oh, ow);
// Backward pooling for bf16 data.
// Each (mb, c) slice of diff_dst is converted to float in a scratchpad
// buffer, the kernels accumulate into a float diff_src buffer, and that
// buffer is converted back to bf16 into diff_src at memory(0).
578 void nchw_pooling_bwd_t<data_type::bf16>::execute_backward() const {
579 auto diff_dst = reinterpret_cast<const mkldnn_bfloat16_t *>(this->input_memory(0));
// The workspace is only read for max pooling; otherwise the pointer is null.
580 auto ws = pd()->desc()->alg_kind != alg_kind::pooling_max ? nullptr :
581 reinterpret_cast<const unsigned char *>(this->input_memory(1));
582 auto diff_src = reinterpret_cast<mkldnn_bfloat16_t*>(this->memory(0));
// Two float staging buffers from the scratchpad, one keyed for the source
// side and one for the destination side.
584 auto scratchpad = this->scratchpad();
585 float *bf16cvt_src_ = scratchpad.template get<float>(
586 memory_tracking::names::key_pool_src_bf16cvt);
587 float *bf16cvt_dst_ = scratchpad.template get<float>(
588 memory_tracking::names::key_pool_dst_bf16cvt);
590 const memory_desc_wrapper ws_d(pd()->workspace_pd());
// Problem sizes, kernel extents, strides and leading paddings from pd().
592 const int MB = pd()->MB();
593 const int C = pd()->C();
594 const int OD = pd()->OD();
595 const int OH = pd()->OH();
596 const int OW = pd()->OW();
597 const int ID = pd()->ID();
598 const int IH = pd()->IH();
599 const int IW = pd()->IW();
600 const int KD = pd()->KD();
601 const int KH = pd()->KH();
602 const int KW = pd()->KW();
603 const int SD = pd()->KSD();
604 const int SH = pd()->KSH();
605 const int SW = pd()->KSW();
606 const int padF = pd()->padFront();
607 const int padT = pd()->padT();
608 const int padL = pd()->padL();
// Number of spatial elements in one (mb, c) slice of diff_dst / diff_src.
// NOTE(review): the products are formed before conversion to size_t; if the
// accessors return int this could overflow for very large shapes. Confirm.
609 const size_t dst_sp_sz = pd()->OD() * pd()->OH() * pd()->OW();
610 const size_t src_sp_sz = pd()->ID() * pd()->IH() * pd()->IW();
// A diff_src descriptor with 5 dimensions selects the 3D workspace indexing.
612 const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
614 auto alg = pd()->desc()->alg_kind;
// apply_offset: index - offset, clamped so that the result is never negative.
616 auto apply_offset = [=](int index, int offset) {
617 return (index > offset) ? index - offset : 0;
// ker_zero: clears ID * IH * IW consecutive floats starting at diff_src.
620 auto ker_zero = [=](float *diff_src) {
621 size_t diff_src_offset = 0;
622 for (int id = 0; id < ID; ++id) {
623 for (int ih = 0; ih < IH; ++ih) {
624 for (int iw = 0; iw < IW; ++iw) {
625 diff_src[diff_src_offset++] = 0.0f;
// ker_max: adds the gradient d[0] of one output point to the single input
// element whose kernel position was stored in the workspace.
631 auto ker_max = [=](const float *d, float *diff_src_,
632 int mb, int c, int od, int oh, int ow) {
// The workspace is addressed through its descriptor: channel block index
// c / b_c via blk_off, plus the position c % b_c inside the block.
633 auto b_c = ws_d.blocking_desc().block_dims[1];
634 auto ws_offset = is_3d
635 ? ws_d.blk_off(mb, c / b_c, od, oh, ow) + c % b_c
636 : ws_d.blk_off(mb, c / b_c, oh, ow) + c % b_c;
// Read the stored index with the width matching the workspace data type.
638 const int index = ws_d.data_type() == data_type::u8
639 ? (int)ws[ws_offset] : ((const int *)ws)[ws_offset];
// For u8 workspaces the invalid marker is the largest u8 value.
640 const int invalid_index_value = ws_d.data_type() == data_type::u8
641 ? numeric_limits<typename prec_traits<data_type::u8>::type>::max()
643 if (index == invalid_index_value)
644 return; // corner case: pool window is outside of real input domain
645 // for this point, do nothing
// Unflatten the index into kernel coordinates (inverse of
// kd*KH*KW + kh*KW + kw).
647 const int kw = index % KW;
648 const int kh = (index / KW) % KH;
649 const int kd = (index / KW) / KH;
// Input coordinates selected by that kernel position.
651 const int id = od * SD - padF + kd;
652 const int ih = oh * SH - padT + kh;
653 const int iw = ow * SW - padL + kw;
655 // If padding area could fit the kernel,
656 // then input displacement would be out of bounds.
657 // No need to back propagate there as padding is
658 // virtual in pooling_max case.
659 if (id < 0 || id >= ID) return;
660 if (ih < 0 || ih >= IH) return;
661 if (iw < 0 || iw >= IW) return;
663 size_t diff_src_offset
664 = (size_t)IH * IW * id
667 diff_src_[diff_src_offset] += d[0];
// ker_avg: spreads the gradient d[0] of one output point over every input
// element of its window, each receiving d[0] / num_summands.
670 auto ker_avg = [=](const float *d, float *diff_src_,
671 int mb, int c, int od, int oh, int ow) {
// Window bounds clamped to the real input: start >= 0, end <= input size.
672 auto id_start = apply_offset(od*SD, padF);
673 auto ih_start = apply_offset(oh*SH, padT);
674 auto iw_start = apply_offset(ow*SW, padL);
675 auto id_end = nstl::min(od*SD - padF + KD, ID);
676 auto ih_end = nstl::min(oh*SH - padT + KH, IH);
677 auto iw_end = nstl::min(ow*SW - padL + KW, IW);
// The divisor depends on the algorithm; when it is not include_padding the
// number of in-bounds window elements is used.
679 size_t num_summands = (alg == pooling_avg_include_padding)
681 : (size_t)(id_end - id_start)*(ih_end - ih_start)
682 *(iw_end - iw_start);
684 for (int id = id_start; id < id_end; ++id) {
685 for (int ih = ih_start; ih < ih_end; ++ih) {
686 for (int iw = iw_start; iw < iw_end; ++iw) {
687 size_t diff_src_offset
691 diff_src_[diff_src_offset] += d[0] / num_summands;
// Max pooling: parallel over (mb, c); each work item converts its diff_dst
// slice to float, runs ker_max over all output points, and converts the
// float diff_src buffer back to bf16.
697 if (pd()->desc()->alg_kind == pooling_max) {
698 parallel_nd(MB, C, [&](int mb, int c) {
// Linear offsets of this (mb, c) slice in diff_src and diff_dst.
699 size_t diff_src_offset = (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW;
700 size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW + (size_t)c*OD*OH*OW;
// Float staging buffers, located using mkldnn_get_thread_num().
701 float *src_fp32_ = &bf16cvt_src_[mkldnn_get_thread_num()
703 float *dst_fp32_ = &bf16cvt_dst_[mkldnn_get_thread_num()
707 cvt_bfloat16_to_float(dst_fp32_, &diff_dst[diff_dst_offset],
711 for (int od = 0; od < OD; ++od) {
712 for (int oh = 0; oh < OH; ++oh) {
713 for (int ow = 0; ow < OW; ++ow) {
// `idx` steps through the converted diff_dst slice one output point at a time.
714 ker_max(&dst_fp32_[idx++], src_fp32_, mb, c, od, oh, ow);
718 cvt_float_to_bfloat16(&diff_src[diff_src_offset], src_fp32_,
// Second parallel loop: same staging and traversal, applying ker_avg instead.
722 parallel_nd(MB, C, [&](int mb, int c) {
723 size_t diff_src_offset = (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW;
724 size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW + (size_t)c*OD*OH*OW;
725 float *src_fp32_ = &bf16cvt_src_[mkldnn_get_thread_num()
727 float *dst_fp32_ = &bf16cvt_dst_[mkldnn_get_thread_num()
731 cvt_bfloat16_to_float(dst_fp32_, &diff_dst[diff_dst_offset],
735 for (int od = 0; od < OD; ++od) {
736 for (int oh = 0; oh < OH; ++oh) {
737 for (int ow = 0; ow < OW; ++ow) {
738 ker_avg(&dst_fp32_[idx++], src_fp32_, mb, c, od, oh, ow);
742 cvt_float_to_bfloat16(&diff_src[diff_src_offset], src_fp32_,
// Explicit instantiations of the forward and backward pooling structs for the
// f32 and bf16 data types.
747 template struct nchw_pooling_fwd_t<data_type::f32>;
748 template struct nchw_pooling_bwd_t<data_type::f32>;
749 template struct nchw_pooling_fwd_t<data_type::bf16>;
750 template struct nchw_pooling_bwd_t<data_type::bf16>;
756 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s