template <impl::data_type_t data_type>
void nhwc_pooling_fwd_t<data_type>::array_div_by_const(const int n,
- const data_t *src, const size_t num, data_t *dst)
+ const data_t *src, const size_t num, data_t *dst) const
{
for (int i = 0; i < n; ++i)
{
}
template <impl::data_type_t data_type>
-void nhwc_pooling_fwd_t<data_type>::array_add(const int n,
- const data_t *src, data_t *dst)
+void nhwc_pooling_fwd_t<data_type>::array_add(const int n, const data_t *src,
+ data_t *dst) const
{
for (int i = 0; i < n; ++i)
{
}
template <impl::data_type_t data_type>
-void nhwc_pooling_fwd_t<data_type>::execute_forward() {
+void nhwc_pooling_fwd_t<data_type>::execute_forward() const {
using namespace alg_kind;
using namespace prop_kind;
using namespace nhwc_pooling;
- auto alg = conf_.desc()->alg_kind;
+ auto alg = pd()->desc()->alg_kind;
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto dst = reinterpret_cast<data_t *>(this->memory(0));
unsigned char * ws = reinterpret_cast<unsigned char *>(
alg == pooling_max
- && conf_.desc()->prop_kind == forward_training ?
+ && pd()->desc()->prop_kind == forward_training ?
this->memory(1) : nullptr
);
- const memory_desc_wrapper MEM_D(dst)(conf_.dst_pd());
- const memory_desc_wrapper MEM_D(ws)(conf_.workspace_pd());
- const memory_desc_wrapper MEM_D(src)(conf_.src_pd());
-
- const int ID = conf_.ID();
- const int IH = conf_.IH();
- const int IW = conf_.IW();
- const int KD = conf_.KD();
- const int KH = conf_.KH();
- const int KW = conf_.KW();
- const int SD = conf_.KSD();
- const int SH = conf_.KSH();
- const int SW = conf_.KSW();
- const int padF = conf_.padFront();
- const int padT = conf_.padT();
- const int padL = conf_.padL();
- const int MB = conf_.MB();
- const int OC = conf_.C();
- const int OD = conf_.OD();
- const int OH = conf_.OH();
- const int OW = conf_.OW();
-
- const bool is_3d = conf_.desc()->src_desc.ndims == 5;
+ const memory_desc_wrapper MEM_D(dst)(pd()->dst_pd());
+ const memory_desc_wrapper MEM_D(ws)(pd()->workspace_pd());
+ const memory_desc_wrapper MEM_D(src)(pd()->src_pd());
+
+ const int ID = pd()->ID();
+ const int IH = pd()->IH();
+ const int IW = pd()->IW();
+ const int KD = pd()->KD();
+ const int KH = pd()->KH();
+ const int KW = pd()->KW();
+ const int SD = pd()->KSD();
+ const int SH = pd()->KSH();
+ const int SW = pd()->KSW();
+ const int padF = pd()->padFront();
+ const int padT = pd()->padT();
+ const int padL = pd()->padL();
+ const int MB = pd()->MB();
+ const int OC = pd()->C();
+ const int OD = pd()->OD();
+ const int OH = pd()->OH();
+ const int OW = pd()->OW();
+
+ const bool is_3d = pd()->desc()->src_desc.ndims == 5;
const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
DECLARE_READ_STRIDES(src);
}
template <impl::data_type_t data_type>
-void nhwc_pooling_bwd_t<data_type>::execute_backward() {
+void nhwc_pooling_bwd_t<data_type>::execute_backward() const {
using namespace alg_kind;
using namespace nhwc_pooling;
auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
- auto ws = conf_.desc()->alg_kind != alg_kind::pooling_max ? nullptr
+ auto ws = pd()->desc()->alg_kind != alg_kind::pooling_max ? nullptr
: reinterpret_cast<const unsigned char *>(this->input_memory(1));
auto diff_src = reinterpret_cast<data_t *>(this->memory(0));
- const memory_desc_wrapper MEM_D(diff_dst)(conf_.diff_dst_pd());
- const memory_desc_wrapper MEM_D(ws)(conf_.workspace_pd());
- const memory_desc_wrapper MEM_D(diff_src)(conf_.diff_src_pd());
-
- const int ID = conf_.ID();
- const int IH = conf_.IH();
- const int IW = conf_.IW();
- const int KD = conf_.KD();
- const int KH = conf_.KH();
- const int KW = conf_.KW();
- const int SD = conf_.KSD();
- const int SH = conf_.KSH();
- const int SW = conf_.KSW();
- const int OC = conf_.C();
- const int padF = conf_.padFront();
- const int padT = conf_.padT();
- const int padL = conf_.padL();
- const int OD = conf_.OD();
- const int OH = conf_.OH();
- const int OW = conf_.OW();
-
- const bool is_3d = conf_.desc()->diff_src_desc.ndims == 5;
- auto alg = conf_.desc()->alg_kind;
+ const memory_desc_wrapper MEM_D(diff_dst)(pd()->diff_dst_pd());
+ const memory_desc_wrapper MEM_D(ws)(pd()->workspace_pd());
+ const memory_desc_wrapper MEM_D(diff_src)(pd()->diff_src_pd());
+
+ const int ID = pd()->ID();
+ const int IH = pd()->IH();
+ const int IW = pd()->IW();
+ const int KD = pd()->KD();
+ const int KH = pd()->KH();
+ const int KW = pd()->KW();
+ const int SD = pd()->KSD();
+ const int SH = pd()->KSH();
+ const int SW = pd()->KSW();
+ const int OC = pd()->C();
+ const int padF = pd()->padFront();
+ const int padT = pd()->padT();
+ const int padL = pd()->padL();
+ const int OD = pd()->OD();
+ const int OH = pd()->OH();
+ const int OW = pd()->OW();
+
+ const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
+ auto alg = pd()->desc()->alg_kind;
DECLARE_READ_STRIDES(diff_src);
DECLARE_READ_STRIDES(diff_dst);
return (index > offset) ? index - offset : 0;
};
- const int MB = conf_.MB();
+ const int MB = pd()->MB();
parallel_nd(MB, ID, IH, IW,
[&](int mb, int id, int ih, int iw) {