1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
20 #include "c_types_map.hpp"
21 #include "type_helpers.hpp"
22 #include "math_utils.hpp"
23 #include "mkldnn_thread.hpp"
26 #include "nhwc_pooling.hpp"
// MEM_D(name) names the memory_desc_wrapper variable declared for each
// memory argument (e.g. MEM_D(src) expands to src_d).
32 #define MEM_D(name) name##_d
// DECLARE_READ_STRIDES(name) declares the n/d/h/w element strides of
// `name`, read from its blocking descriptor.  A local `is_3d` bool must
// be in scope at the expansion site: 5D (3D-spatial) tensors read
// strides[0][2..4] for d/h/w, while 4D tensors read strides[0][2..3]
// for h/w.  NOTE(review): the true-branch of the d_stride ternary is
// elided in this listing — confirm against the full source.
34 #define DECLARE_READ_STRIDES(name) \
35     const size_t name##_n_stride = MEM_D(name).blocking_desc().strides[0][0]; \
36     const size_t name##_d_stride = (!is_3d) \
38         : MEM_D(name).blocking_desc().strides[0][2]; \
39     const size_t name##_h_stride = (!is_3d) \
40         ? MEM_D(name).blocking_desc().strides[0][2] \
41         : MEM_D(name).blocking_desc().strides[0][3]; \
42     const size_t name##_w_stride = (!is_3d) \
43         ? MEM_D(name).blocking_desc().strides[0][3] \
44         : MEM_D(name).blocking_desc().strides[0][4];
// Computes a flat element offset from four (index, stride) pairs:
// batch (_n/_sn), depth (_d/_sd), height (_h/_sh), width (_w/_sw).
// The function body is elided in this listing; presumably it returns
// the sum of index*stride products — TODO confirm against full source.
47 size_t strided_offset(const int _n, const size_t _sn,
48 const int _d, const size_t _sd,
49 const int _h, const size_t _sh,
50 const int _w, const size_t _sw)
59 using namespace alg_kind;
60 using namespace prop_kind;
61 using namespace bf16_cvt_utils;
// Forward NHWC pooling (max or average) for data type d_type.
// Reads src (input 0) and writes dst (output 0).  For max pooling in
// forward_training the argmax index of each output element is also
// recorded into the workspace (memory(1)) for use by the backward pass;
// otherwise `ws` is nullptr.
// NOTE(review): this listing is truncated — several original lines
// (e.g. some `continue;` statements, closing braces, and trailing
// strided_offset arguments) are elided; comments describe only what
// the visible code shows.
63 template <data_type_t d_type>
64 void nhwc_pooling_fwd_t<d_type>::execute_forward() const {
65 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
66 auto dst = reinterpret_cast<data_t *>(this->memory(0));
67 unsigned char * ws = reinterpret_cast<unsigned char *>(
68 pd()->desc()->alg_kind == pooling_max
69 && pd()->desc()->prop_kind == forward_training ?
70 this->memory(1) : nullptr
// Memory descriptors for dst/src/ws; ws_dt is undef when no workspace
// is produced (non-max or non-training runs).
73 const memory_desc_wrapper MEM_D(dst)(pd()->dst_pd());
74 const memory_desc_wrapper MEM_D(src)(pd()->src_pd());
75 const memory_desc_wrapper MEM_D(ws)(pd()->workspace_pd());
76 const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
// Problem shape: batch MB, channels C, output OD/OH/OW, input
// ID/IH/IW, kernel KD/KH/KW, strides SD/SH/SW, front/top/left padding.
78 const int MB = pd()->MB();
79 const int C = pd()->C();
80 const int OD = pd()->OD();
81 const int OH = pd()->OH();
82 const int OW = pd()->OW();
83 const int ID = pd()->ID();
84 const int IH = pd()->IH();
85 const int IW = pd()->IW();
86 const int KD = pd()->KD();
87 const int KH = pd()->KH();
88 const int KW = pd()->KW();
89 const int SD = pd()->KSD();
90 const int SH = pd()->KSH();
91 const int SW = pd()->KSW();
92 const int padF = pd()->padFront();
93 const int padT = pd()->padT();
94 const int padL = pd()->padL();
96 const bool is_3d = pd()->desc()->src_desc.ndims == 5;
97 auto alg = pd()->desc()->alg_kind;
// Declares {src,dst}_{n,d,h,w}_stride locals (uses `is_3d` above).
99 DECLARE_READ_STRIDES(src);
100 DECLARE_READ_STRIDES(dst);
// Clamp (index - offset) at zero: start of a pooling window after
// removing padding.
102 auto apply_offset = [=](int index, int offset) {
103 return (index > offset) ? index - offset : 0;
// Max-pooling kernel for one (mb, od, oh, ow) output position across
// all C channels; records per-channel argmax (kernel-local linear
// index kd*KH*KW + kh*KW + kw) into ws when present.
106 auto ker_max = [=](data_t *d, const data_t *s, int mb, int od, int oh, int ow) {
107 size_t ws_offset_init = 0;
110 DECLARE_READ_STRIDES(ws);
111 ws_offset_init = strided_offset(mb, ws_n_stride,
117 /* Note: GCC 4.8.5 won't vectorize below simple loops unless
118 * they are singled out into separate helper routines:
119 * array_initialize, array_max */
120 array_initialize(C, d,
121 ws, ws_offset_init, ws_dt);
// Iterate the kernel window, skipping positions that fall into the
// padding region (outside [0, ID/IH/IW)).
123 for (int kd = 0; kd < KD; ++kd)
124 for (int kh = 0; kh < KH; ++kh)
125 for (int kw = 0; kw < KW; ++kw) {
126 const int id = od * SD - padF + kd;
127 const int ih = oh * SH - padT + kh;
128 const int iw = ow * SW - padL + kw;
130 if (id < 0 || id >= ID)
132 if (ih < 0 || ih >= IH)
134 if (iw < 0 || iw >= IW)
137 size_t src_offset_init = strided_offset(mb, src_n_stride,
142 d, &s[src_offset_init],
145 kd * KH * KW + kh * KW + kw
// Average-pooling kernel for one output position: zero the C-channel
// accumulator, sum all in-bounds window elements, then divide by the
// summand count (full window size for avg_include_padding, in-bounds
// count otherwise).
150 auto ker_avg = [=](data_t *d, const data_t *s,
151 int mb, int od, int oh, int ow) {
152 utils::array_set(d, 0, C);
154 auto id_start = apply_offset(od * SD, padF);
155 auto ih_start = apply_offset(oh * SH, padT);
156 auto iw_start = apply_offset(ow * SW, padL);
157 auto id_end = nstl::min(od * SD - padF + KD, ID);
158 auto ih_end = nstl::min(oh * SH - padT + KH, IH);
159 auto iw_end = nstl::min(ow * SW - padL + KW, IW);
161 // it is cheaper to actually count this in a loop
162 // as the typical kernel is small
163 size_t num_summands = 0;
165 /* Note: GCC 4.8.5 won't vectorize below simple loops unless
166 * they are singled out into separate helper routines:
167 * array_add, array_div_by_const */
168 for (int id = id_start; id < id_end; ++id)
169 for (int ih = ih_start; ih < ih_end; ++ih)
170 for (int iw = iw_start; iw < iw_end; ++iw) {
171 size_t src_offset_init = strided_offset(mb, src_n_stride,
175 array_add(C, d, &s[src_offset_init]);
179 num_summands = (alg == pooling_avg_include_padding) ?
180 KW * KH * KD : num_summands;
182 array_div_by_const(C, d, num_summands, d);
// Parallel over all output positions; each iteration handles one
// (mb, od, oh, ow) point for all C channels (channels-last layout).
185 parallel_nd(MB, OD, OH, OW,
186 [&](int mb, int od, int oh, int ow) {
187 size_t dst_offset_init = strided_offset(mb, dst_n_stride,
191 data_t *d = reinterpret_cast<data_t*>(&dst[dst_offset_init]);
193 if (alg == pooling_max) {
194 ker_max(d, src, mb, od, oh, ow);
197 ker_avg(d, src, mb, od, oh, ow);
// Forward NHWC pooling specialization for bfloat16.  Mirrors the
// generic d_type version above, but accumulates in float: each thread
// converts the C-channel src/dst rows through per-thread float scratch
// buffers (scratchpad keys key_pool_{src,dst}_bf16cvt) and converts
// the result back to bf16 at the end of each kernel.
// NOTE(review): this listing is truncated — several original lines are
// elided; comments describe only what the visible code shows.
203 void nhwc_pooling_fwd_t<data_type::bf16>::execute_forward() const {
205 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
206 auto dst = reinterpret_cast<data_t *>(this->memory(0));
207 unsigned char * ws = reinterpret_cast<unsigned char *>(
208 pd()->desc()->alg_kind == pooling_max
209 && pd()->desc()->prop_kind == forward_training ?
210 this->memory(1) : nullptr
// Per-thread float conversion workspaces; each thread owns a C-sized
// slice indexed by mkldnn_get_thread_num() below.
213 auto scratchpad = this->scratchpad();
214 float *bf16cvt_src_wsp = scratchpad.template get<float>(
215 memory_tracking::names::key_pool_src_bf16cvt);
216 float *bf16cvt_dst_wsp = scratchpad.template get<float>(
217 memory_tracking::names::key_pool_dst_bf16cvt);
219 const memory_desc_wrapper MEM_D(dst)(pd()->dst_pd());
220 const memory_desc_wrapper MEM_D(src)(pd()->src_pd());
221 const memory_desc_wrapper MEM_D(ws)(pd()->workspace_pd());
222 const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
// Problem shape: identical naming to the generic overload above.
224 const int MB = pd()->MB();
225 const int C = pd()->C();
226 const int OD = pd()->OD();
227 const int OH = pd()->OH();
228 const int OW = pd()->OW();
229 const int ID = pd()->ID();
230 const int IH = pd()->IH();
231 const int IW = pd()->IW();
232 const int KD = pd()->KD();
233 const int KH = pd()->KH();
234 const int KW = pd()->KW();
235 const int SD = pd()->KSD();
236 const int SH = pd()->KSH();
237 const int SW = pd()->KSW();
238 const int padF = pd()->padFront();
239 const int padT = pd()->padT();
240 const int padL = pd()->padL();
242 const bool is_3d = pd()->desc()->src_desc.ndims == 5;
243 auto alg = pd()->desc()->alg_kind;
245 DECLARE_READ_STRIDES(src);
246 DECLARE_READ_STRIDES(dst);
// Clamp (index - offset) at zero, as in the generic overload.
248 auto apply_offset = [=](int index, int offset) {
249 return (index > offset) ? index - offset : 0;
// Max-pooling kernel: maximize in float precision in dst_f32_, then
// convert the C results back to bf16 into d.
252 auto ker_max = [=](mkldnn_bfloat16_t *d, const mkldnn_bfloat16_t *s,
253 int mb, int od, int oh, int ow) {
254 size_t ws_offset_init = 0;
257 DECLARE_READ_STRIDES(ws);
258 ws_offset_init = strided_offset(mb, ws_n_stride,
263 size_t ithr = mkldnn_get_thread_num();
264 float *dst_f32_ = &bf16cvt_dst_wsp[ithr * C];
265 float *src_f32_ = &bf16cvt_src_wsp[ithr * C];
267 /* Note: GCC 4.8.5 won't vectorize below simple loops unless
268 * they are singled out into separate helper routines:
269 * array_initialize, array_max */
270 array_initialize(C, dst_f32_,
271 ws, ws_offset_init, ws_dt);
// Iterate the kernel window, skipping padding positions.
273 for (int kd = 0; kd < KD; ++kd)
274 for (int kh = 0; kh < KH; ++kh)
275 for (int kw = 0; kw < KW; ++kw) {
276 const int id = od * SD - padF + kd;
277 const int ih = oh * SH - padT + kh;
278 const int iw = ow * SW - padL + kw;
280 if (id < 0 || id >= ID)
282 if (ih < 0 || ih >= IH)
284 if (iw < 0 || iw >= IW)
287 size_t src_offset_init = strided_offset(mb, src_n_stride,
// Widen the C-channel src row to float before comparing.
291 cvt_bfloat16_to_float(src_f32_, &s[src_offset_init], C);
296 kd * KH * KW + kh * KW + kw
299 cvt_float_to_bfloat16(d, dst_f32_, C);
// Average-pooling kernel: accumulate window sums in float, divide by
// the summand count, then narrow back to bf16.
302 auto ker_avg = [=](mkldnn_bfloat16_t *d, const mkldnn_bfloat16_t *s,
303 int mb, int od, int oh, int ow) {
304 size_t ithr = mkldnn_get_thread_num();
305 float *dst_f32_ = &bf16cvt_dst_wsp[ithr * C];
306 float *src_f32_ = &bf16cvt_src_wsp[ithr * C];
307 utils::array_set(dst_f32_, 0, C);
309 auto id_start = apply_offset(od * SD, padF);
310 auto ih_start = apply_offset(oh * SH, padT);
311 auto iw_start = apply_offset(ow * SW, padL);
312 auto id_end = nstl::min(od * SD - padF + KD, ID);
313 auto ih_end = nstl::min(oh * SH - padT + KH, IH);
314 auto iw_end = nstl::min(ow * SW - padL + KW, IW);
316 // it is cheaper to actually count this in a loop
317 // as the typical kernel is small
318 size_t num_summands = 0;
320 /* Note: GCC 4.8.5 won't vectorize below simple loops unless
321 * they are singled out into separate helper routines:
322 * array_add, array_div_by_const */
323 for (int id = id_start; id < id_end; ++id)
324 for (int ih = ih_start; ih < ih_end; ++ih)
325 for (int iw = iw_start; iw < iw_end; ++iw) {
326 size_t src_offset_init = strided_offset(mb, src_n_stride,
330 cvt_bfloat16_to_float(src_f32_, &s[src_offset_init], C);
332 array_add(C, dst_f32_, src_f32_);
336 num_summands = (alg == pooling_avg_include_padding) ?
337 KW * KH * KD : num_summands;
339 array_div_by_const(C, dst_f32_, num_summands, dst_f32_);
340 cvt_float_to_bfloat16(d, dst_f32_, C);
// Parallel over all output positions, dispatching to max/avg kernels.
343 parallel_nd(MB, OD, OH, OW,
344 [&](int mb, int od, int oh, int ow) {
345 size_t dst_offset_init = strided_offset(mb, dst_n_stride,
349 data_t *d = reinterpret_cast<data_t*>(&dst[dst_offset_init]);
351 if (alg == pooling_max)
352 ker_max((mkldnn_bfloat16_t *)d, (const mkldnn_bfloat16_t *)src,
355 ker_avg((mkldnn_bfloat16_t *)d, (const mkldnn_bfloat16_t *)src,
// Backward NHWC pooling: scatters diff_dst gradients back into
// diff_src.  For max pooling the workspace (input 1) supplies the
// per-channel argmax recorded by the forward pass; gradient flows only
// to the window position matching that index.  For average pooling the
// gradient is divided evenly among the window's summands.
// Optimization used throughout: when kernel equals stride in every
// dimension the windows are disjoint, so each diff_src element is
// written exactly once (plain `=`) and no zero-initialization pass is
// needed; otherwise values are accumulated with `+=`.
// NOTE(review): this listing is truncated — several original lines
// (parts of the disjoint-window conditions, `continue;` statements,
// closing braces) are elided; comments describe only the visible code.
360 template <impl::data_type_t d_type>
361 void nhwc_pooling_bwd_t<d_type>::execute_backward() const {
362 using namespace alg_kind;
363 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
364 auto ws = pd()->desc()->alg_kind != alg_kind::pooling_max ? nullptr
365 : reinterpret_cast<const unsigned char *>(this->input_memory(1));
366 auto diff_src = reinterpret_cast<data_t *>(this->memory(0));
368 const memory_desc_wrapper MEM_D(diff_dst)(pd()->diff_dst_pd());
369 const memory_desc_wrapper MEM_D(diff_src)(pd()->diff_src_pd());
// Problem shape, same naming convention as the forward pass.
371 const int MB = pd()->MB();
372 const int C = pd()->C();
373 const int OD = pd()->OD();
374 const int OH = pd()->OH();
375 const int OW = pd()->OW();
376 const int ID = pd()->ID();
377 const int IH = pd()->IH();
378 const int IW = pd()->IW();
379 const int KD = pd()->KD();
380 const int KH = pd()->KH();
381 const int KW = pd()->KW();
382 const int SD = pd()->KSD();
383 const int SH = pd()->KSH();
384 const int SW = pd()->KSW();
385 const int padF = pd()->padFront();
386 const int padT = pd()->padT();
387 const int padL = pd()->padL();
389 const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
390 auto alg = pd()->desc()->alg_kind;
392 DECLARE_READ_STRIDES(diff_src);
393 DECLARE_READ_STRIDES(diff_dst);
395 auto apply_offset = [=](int index, int offset) {
396 return (index > offset) ? index - offset : 0;
// Max-pooling backward kernel for one (output, kernel-offset) pair:
// routes dd[c] to ds[c] only where the workspace argmax matches this
// kernel position.  Workspace entries are u8 or int depending on the
// workspace data type.
399 auto ker_max = [=](data_t *ds, const data_t *dd,
400 int mb, int od, int oh, int ow,
401 int kd, int kh, int kw) {
402 const memory_desc_wrapper MEM_D(ws)(pd()->workspace_pd());
403 DECLARE_READ_STRIDES(ws);
404 size_t ws_offset_init = strided_offset(mb, ws_n_stride,
408 const int index = kd * KH * KW
412 for (int c = 0; c < C; ++c) {
413 const int index_from_ws =
414 (MEM_D(ws).data_type() == data_type::u8)
415 ? (int)ws[ws_offset_init + c]
416 : ((int *)ws)[ws_offset_init + c];
418 // Check if kernel windows are disjoint, in this case
419 // there's no update needed and we just write there once
420 // otherwise we add value to the contents.
422 KH == SH && KW == SW))
423 ds[c] += (index_from_ws == index)
424 ? dd[c] : data_type_t(0);
426 ds[c] = (index_from_ws == index)
427 ? dd[c] : data_type_t(0);
// Average-pooling backward kernel: distribute dd[c] divided by the
// summand count (full window for avg_include_padding, in-bounds count
// otherwise) to every window element.
432 auto ker_avg = [=](data_t *ds, const data_t *dd,
433 int mb, int od, int oh, int ow) {
434 auto id_start = apply_offset(od * SD, padF);
435 auto ih_start = apply_offset(oh * SH, padT);
436 auto iw_start = apply_offset(ow * SW, padL);
437 auto id_end = nstl::min(
438 od * SD - padF + KD, ID);
439 auto ih_end = nstl::min(
440 oh * SH - padT + KH, IH);
441 auto iw_end = nstl::min(
442 ow * SW - padL + KW, IW);
444 auto num_summands = (alg == pooling_avg_include_padding)
446 : (ih_end - ih_start) * (iw_end - iw_start) * (id_end - id_start);
449 for (int c = 0; c < C; ++c) {
450 const data_t d = dd[c];
451 // Check if kernel windows are disjoint, in this case
452 // there's no update needed and we just write there once
453 // otherwise we add value to the contents.
455 KH == SH && KW == SW))
456 ds[c] += d / num_summands;
458 ds[c] = d / num_summands;
// Parallel over INPUT positions: each iteration owns one diff_src
// point, so overlapping-window accumulation needs no atomics.
462 parallel_nd(MB, ID, IH, IW,
463 [&](int mb, int id, int ih, int iw) {
464 size_t src_offset_init = strided_offset(mb, diff_src_n_stride,
465 id, diff_src_d_stride,
466 ih, diff_src_h_stride,
467 iw, diff_src_w_stride);
469 // check if kernel windows are disjoint, in this case there's no
470 // update needed and we just write there once, no initialization
472 if (!(KD == SD && KH == SH
474 for (int c = 0; c < C; ++c)
475 diff_src[src_offset_init + c] = data_type_t(0);
477 // Find out which output cells may correspond to current
478 // input position. Current input postition divided by
479 // stride, with integer divide rounding down, is the
480 // right-most output.
481 // Left-most output may be computed if we decrement input
482 // by (kernel_size - 1) and then do the same division by
484 int od_left = nstl::max(
485 (id + padF - KD + 1) / SD, 0);
486 int oh_left = nstl::max(
487 (ih + padT - KH + 1) / SH, 0);
488 int ow_left = nstl::max(
489 (iw + padL - KW + 1) / SW, 0);
490 // Notice +1 here to preserve the C loop "less than"
491 // condition for continuing the for loop.
492 int od_right = nstl::min(
493 (id + padF) / SD + 1, OD);
494 int oh_right = nstl::min(
495 (ih + padT) / SH + 1, OH);
496 int ow_right = nstl::min(
497 (iw + padL) / SW + 1, OW);
// Visit every output window covering this input point; skip kernel
// offsets that fall outside the kernel extents.
499 for (int od = od_left; od < od_right; ++od)
500 for (int oh = oh_left; oh < oh_right; ++oh)
501 for (int ow = ow_left; ow < ow_right; ++ow) {
502 const int kd = id - od * SD + padF;
503 const int kh = ih - oh * SH + padT;
504 const int kw = iw - ow * SW + padL;
506 if (kd < 0 || kd >= KD)
508 if (kh < 0 || kh >= KH)
510 if (kw < 0 || kw >= KW)
513 size_t dst_offset_init = strided_offset(mb, diff_dst_n_stride,
514 od, diff_dst_d_stride,
515 oh, diff_dst_h_stride,
516 ow, diff_dst_w_stride);
517 if (alg == pooling_max) {
518 ker_max(&diff_src[src_offset_init], &diff_dst[dst_offset_init],
519 mb, od, oh, ow, kd, kh, kw);
521 ker_avg(&diff_src[src_offset_init], &diff_dst[dst_offset_init],
// Backward NHWC pooling specialization for bfloat16.  Mirrors the
// generic version above but computes gradients in float: diff_dst rows
// are widened into per-thread float buffers (scratchpad keys
// key_pool_{src,dst}_bf16cvt), the max/avg kernels operate on float,
// and the accumulated diff_src row is narrowed back to bf16 at the end
// of each input position.
// NOTE(review): this listing is truncated — several original lines are
// elided; comments describe only what the visible code shows.
529 void nhwc_pooling_bwd_t<data_type::bf16>::execute_backward() const {
530 using namespace alg_kind;
531 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
532 auto ws = pd()->desc()->alg_kind != alg_kind::pooling_max ? nullptr
533 : reinterpret_cast<const unsigned char *>(this->input_memory(1));
534 auto diff_src = reinterpret_cast<data_t *>(this->memory(0));
536 const memory_desc_wrapper MEM_D(diff_dst)(pd()->diff_dst_pd());
537 const memory_desc_wrapper MEM_D(diff_src)(pd()->diff_src_pd());
// Per-thread float conversion workspaces (C elements per thread).
539 auto scratchpad = this->scratchpad();
540 float *bf16cvt_dsrc_ = scratchpad.template get<float>(
541 memory_tracking::names::key_pool_src_bf16cvt);
542 float *bf16cvt_ddst_ = scratchpad.template get<float>(
543 memory_tracking::names::key_pool_dst_bf16cvt);
// Problem shape, same naming convention as the forward pass.
545 const int MB = pd()->MB();
546 const int C = pd()->C();
547 const int OD = pd()->OD();
548 const int OH = pd()->OH();
549 const int OW = pd()->OW();
550 const int ID = pd()->ID();
551 const int IH = pd()->IH();
552 const int IW = pd()->IW();
553 const int KD = pd()->KD();
554 const int KH = pd()->KH();
555 const int KW = pd()->KW();
556 const int SD = pd()->KSD();
557 const int SH = pd()->KSH();
558 const int SW = pd()->KSW();
559 const int padF = pd()->padFront();
560 const int padT = pd()->padT();
561 const int padL = pd()->padL();
563 const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
564 auto alg = pd()->desc()->alg_kind;
566 DECLARE_READ_STRIDES(diff_src);
567 DECLARE_READ_STRIDES(diff_dst);
569 auto apply_offset = [=](int index, int offset) {
570 return (index > offset) ? index - offset : 0;
// Max-pooling backward kernel in float: routes dd[c] to ds[c] only
// where the workspace argmax matches this kernel position.
573 auto ker_max = [=](float *ds, const float *dd,
574 int mb, int od, int oh, int ow,
575 int kd, int kh, int kw) {
576 const memory_desc_wrapper MEM_D(ws)(pd()->workspace_pd());
577 DECLARE_READ_STRIDES(ws);
578 size_t ws_offset_init = strided_offset(mb, ws_n_stride,
582 const int index = kd * KH * KW
586 for (int c = 0; c < C; ++c) {
587 const int index_from_ws =
588 (MEM_D(ws).data_type() == data_type::u8)
589 ? (int)ws[ws_offset_init + c]
590 : ((int *)ws)[ws_offset_init + c];
592 // Check if kernel windows are disjoint, in this case
593 // there's no update needed and we just write there once
594 // otherwise we add value to the contents.
596 KH == SH && KW == SW))
597 ds[c] += (index_from_ws == index)
600 ds[c] = (index_from_ws == index)
// Average-pooling backward kernel in float: distribute dd[c] divided
// by the summand count across the window.
606 auto ker_avg = [=](float *ds, const float *dd,
607 int mb, int od, int oh, int ow) {
608 auto id_start = apply_offset(od * SD, padF);
609 auto ih_start = apply_offset(oh * SH, padT);
610 auto iw_start = apply_offset(ow * SW, padL);
611 auto id_end = nstl::min(
612 od * SD - padF + KD, ID);
613 auto ih_end = nstl::min(
614 oh * SH - padT + KH, IH);
615 auto iw_end = nstl::min(
616 ow * SW - padL + KW, IW);
618 auto num_summands = (alg == pooling_avg_include_padding)
620 : (ih_end - ih_start) * (iw_end - iw_start) * (id_end - id_start);
623 for (int c = 0; c < C; ++c) {
624 // Check if kernel windows are disjoint, in this case
625 // there's no update needed and we just write there once
626 // otherwise we add value to the contents.
628 KH == SH && KW == SW))
629 ds[c] += dd[c] / num_summands;
631 ds[c] = dd[c] / num_summands;
// Parallel over INPUT positions; each iteration owns one diff_src
// point and its per-thread float buffers.
635 parallel_nd(MB, ID, IH, IW,
636 [&](int mb, int id, int ih, int iw) {
637 size_t src_offset_init = strided_offset(mb, diff_src_n_stride,
638 id, diff_src_d_stride,
639 ih, diff_src_h_stride,
640 iw, diff_src_w_stride);
642 float *ddst_fp32_ = &bf16cvt_ddst_[mkldnn_get_thread_num() * C];
643 float *dsrc_fp32_ = &bf16cvt_dsrc_[mkldnn_get_thread_num() * C];
644 // check if kernel windows are disjoint, in this case there's no
645 // update needed and we just write there once, no initialization
647 if (!(KD == SD && KH == SH
649 for (int c = 0; c < C; ++c)
650 dsrc_fp32_[c] = 0.0f;
652 // Find out which output cells may correspond to current
653 // input position. Current input postition divided by
654 // stride, with integer divide rounding down, is the
655 // right-most output.
656 // Left-most output may be computed if we decrement input
657 // by (kernel_size - 1) and then do the same division by
659 int od_left = nstl::max(
660 (id + padF - KD + 1) / SD, 0);
661 int oh_left = nstl::max(
662 (ih + padT - KH + 1) / SH, 0);
663 int ow_left = nstl::max(
664 (iw + padL - KW + 1) / SW, 0);
665 // Notice +1 here to preserve the C loop "less than"
666 // condition for continuing the for loop.
667 int od_right = nstl::min(
668 (id + padF) / SD + 1, OD);
669 int oh_right = nstl::min(
670 (ih + padT) / SH + 1, OH);
671 int ow_right = nstl::min(
672 (iw + padL) / SW + 1, OW);
// Visit every output window covering this input point; widen the
// diff_dst row to float before dispatching to the kernels.
674 for (int od = od_left; od < od_right; ++od)
675 for (int oh = oh_left; oh < oh_right; ++oh)
676 for (int ow = ow_left; ow < ow_right; ++ow) {
677 const int kd = id - od * SD + padF;
678 const int kh = ih - oh * SH + padT;
679 const int kw = iw - ow * SW + padL;
681 if (kd < 0 || kd >= KD)
683 if (kh < 0 || kh >= KH)
685 if (kw < 0 || kw >= KW)
688 size_t dst_offset_init = strided_offset(mb, diff_dst_n_stride,
689 od, diff_dst_d_stride,
690 oh, diff_dst_h_stride,
691 ow, diff_dst_w_stride);
692 cvt_bfloat16_to_float(ddst_fp32_, &diff_dst[dst_offset_init], C);
693 if (alg == pooling_max) {
694 ker_max(dsrc_fp32_, ddst_fp32_,
695 mb, od, oh, ow, kd, kh, kw);
697 ker_avg(dsrc_fp32_, ddst_fp32_,
// Narrow the accumulated float gradients back to bf16 diff_src.
701 cvt_float_to_bfloat16(&diff_src[src_offset_init],
// Explicit template instantiations for the supported data types.
705 template struct nhwc_pooling_fwd_t<data_type::f32>;
706 template struct nhwc_pooling_bwd_t<data_type::f32>;
707 template struct nhwc_pooling_fwd_t<data_type::bf16>;
708 template struct nhwc_pooling_bwd_t<data_type::bf16>;
714 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s