namespace cpu {
template <cpu_isa_t isa>
-void jit_uni_pooling_fwd_t<isa>::execute_forward() {
+void jit_uni_pooling_fwd_t<isa>::execute_forward() const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto dst = reinterpret_cast<data_t*>(this->memory(0));
- auto indices = conf_.desc()->alg_kind == alg_kind::pooling_max ?
+ auto indices = pd()->desc()->alg_kind == alg_kind::pooling_max ?
reinterpret_cast<unsigned char *>(this->memory(1)) : nullptr;
- const memory_desc_wrapper src_d(conf_.src_pd());
- const memory_desc_wrapper dst_d(conf_.dst_pd());
- const memory_desc_wrapper indices_d(conf_.workspace_pd());
+ const memory_desc_wrapper src_d(pd()->src_pd());
+ const memory_desc_wrapper dst_d(pd()->dst_pd());
+ const memory_desc_wrapper indices_d(pd()->workspace_pd());
const size_t ind_dt_size = indices
? types::data_type_size(indices_d.data_type()) : 0;
- const auto &jpp = conf_.jpp_;
- int mb = conf_.MB();
+ const auto &jpp = pd()->jpp_;
+ int mb = pd()->MB();
auto ker = [&](int n, int b_c, int oh) {
auto arg = jit_pool_call_s();
arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
arg.kh_padding_shift = i_t_overflow*jpp.kw;
arg.kw_padding = 0;
- arg.ker_area_h = conf_.desc()->alg_kind == alg_kind::pooling_avg_exclude_padding
+ arg.ker_area_h = pd()->desc()->alg_kind == alg_kind::pooling_avg_exclude_padding
? (float)(jpp.kh - nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) -
nstl::max(0, jpp.t_pad - oh*jpp.stride_h))
: (float)(jpp.kh - nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih - jpp.b_pad));
}
template <cpu_isa_t isa>
-void jit_uni_pooling_fwd_t<isa>::execute_forward_3d() {
+void jit_uni_pooling_fwd_t<isa>::execute_forward_3d() const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto dst = reinterpret_cast<data_t*>(this->memory(0));
- auto indices = conf_.desc()->alg_kind == alg_kind::pooling_max ?
+ auto indices = pd()->desc()->alg_kind == alg_kind::pooling_max ?
reinterpret_cast<unsigned char *>(this->memory(1)) : nullptr;
- const memory_desc_wrapper src_d(conf_.src_pd());
- const memory_desc_wrapper dst_d(conf_.dst_pd());
- const memory_desc_wrapper indices_d(conf_.workspace_pd());
+ const memory_desc_wrapper src_d(pd()->src_pd());
+ const memory_desc_wrapper dst_d(pd()->dst_pd());
+ const memory_desc_wrapper indices_d(pd()->workspace_pd());
const size_t ind_dt_size = indices
? types::data_type_size(indices_d.data_type()) : 0;
- const auto &jpp = conf_.jpp_;
- int mb = conf_.MB();
+ const auto &jpp = pd()->jpp_;
+ int mb = pd()->MB();
auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow,
int d_b_overflow) {
template <cpu_isa_t isa>
-void jit_uni_pooling_bwd_t<isa>::execute_backward() {
+void jit_uni_pooling_bwd_t<isa>::execute_backward() const {
auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
- auto indices = conf_.desc()->alg_kind == alg_kind::pooling_max ?
+ auto indices = pd()->desc()->alg_kind == alg_kind::pooling_max ?
reinterpret_cast<const char*>(this->input_memory(1)) : nullptr;
- const memory_desc_wrapper diff_src_d(conf_.diff_src_pd());
- const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
- const memory_desc_wrapper indices_d(conf_.workspace_pd());
+ const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+ const memory_desc_wrapper indices_d(pd()->workspace_pd());
const size_t ind_dt_size = indices
? types::data_type_size(indices_d.data_type()) : 0;
- const auto &jpp = conf_.jpp_;
- int mb = conf_.MB();
+ const auto &jpp = pd()->jpp_;
+ int mb = pd()->MB();
auto ker = [&](int n, int b_c, int oh) {
auto arg = jit_pool_call_s();
}
template <cpu_isa_t isa>
-void jit_uni_pooling_bwd_t<isa>::execute_backward_3d() {
+void jit_uni_pooling_bwd_t<isa>::execute_backward_3d() const {
auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
- auto indices = conf_.desc()->alg_kind == alg_kind::pooling_max ?
+ auto indices = pd()->desc()->alg_kind == alg_kind::pooling_max ?
reinterpret_cast<const char*>(this->input_memory(1)) : nullptr;
- const memory_desc_wrapper diff_src_d(conf_.diff_src_pd());
- const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
- const memory_desc_wrapper indices_d(conf_.workspace_pd());
+ const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+ const memory_desc_wrapper indices_d(pd()->workspace_pd());
const size_t ind_dt_size = indices
? types::data_type_size(indices_d.data_type()) : 0;
- const auto &jpp = conf_.jpp_;
- int mb = conf_.MB();
+ const auto &jpp = pd()->jpp_;
+ int mb = pd()->MB();
auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow,
int d_b_overflow, int zero_size, int kd) {