namespace cpu {
template <impl::data_type_t data_type>
-void ref_batch_normalization_fwd_t<data_type>::execute_forward() {
+void ref_batch_normalization_fwd_t<data_type>::execute_forward() const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
/* FIXME: check this */
- data_t* mean = conf_.stats_is_src() ?
+ data_t* mean = pd()->stats_is_src() ?
const_cast<data_t*>(reinterpret_cast<const data_t*>(
this->input_memory(1))) :
reinterpret_cast<data_t*>(this->memory(1));
- data_t* variance = conf_.stats_is_src() ?
+ data_t* variance = pd()->stats_is_src() ?
const_cast<data_t*>(reinterpret_cast<const data_t*>(
this->input_memory(2))) :
reinterpret_cast<data_t*>(this->memory(2));
- auto idx_scaleshift = 1 + 2*conf_.stats_is_src();
+ auto idx_scaleshift = 1 + 2*pd()->stats_is_src();
auto scaleshift =
reinterpret_cast<const data_t *>(this->input_memory(idx_scaleshift));
auto dst = reinterpret_cast<data_t*>(this->memory(0));
- auto ws = reinterpret_cast<uint8_t *>(this->memory(conf_.ws_idx()));
+ auto ws = reinterpret_cast<uint8_t *>(this->memory(pd()->ws_idx()));
/* fast return */
- if (this->conf_.has_zero_dim_memory()) return;
+ if (this->pd()->has_zero_dim_memory()) return;
- const memory_desc_wrapper data_d(conf_.src_pd());
- const memory_desc_wrapper scaleshift_d(conf_.weights_pd());
+ const memory_desc_wrapper data_d(pd()->src_pd());
+ const memory_desc_wrapper scaleshift_d(pd()->weights_pd());
- const int N = conf_.MB();
- const int C = conf_.C();
+ const int N = pd()->MB();
+ const int C = pd()->C();
int H = 1, W = 1, D = 1;
const bool has_spatial = utils::one_of(data_d.ndims(), 4 ,5);
if (has_spatial)
{
- D = conf_.D();
- H = conf_.H();
- W = conf_.W();
+ D = pd()->D();
+ H = pd()->H();
+ W = pd()->W();
}
- const float eps = conf_.desc()->batch_norm_epsilon;
- const bool use_scaleshift = conf_.use_scaleshift();;
- const bool save_stats = conf_.is_training();
- const bool is_training = conf_.is_training();
- const bool fuse_bn_relu = conf_.fuse_bn_relu();
- const bool calculate_stats = !conf_.stats_is_src();
+ const float eps = pd()->desc()->batch_norm_epsilon;
+ const bool use_scaleshift = pd()->use_scaleshift();;
+ const bool save_stats = pd()->is_training();
+ const bool is_training = pd()->is_training();
+ const bool fuse_bn_relu = pd()->fuse_bn_relu();
+ const bool calculate_stats = !pd()->stats_is_src();
- const bool with_relu = conf_.with_relu_post_op();
+ const bool with_relu = pd()->with_relu_post_op();
auto maybe_post_op = [&](data_t res) {
return (with_relu && res < 0) ? 0 : res;
};
template struct ref_batch_normalization_fwd_t<data_type::f32>;
template <impl::data_type_t data_type>
-void ref_batch_normalization_bwd_t<data_type>::execute_backward() {
+void ref_batch_normalization_bwd_t<data_type>::execute_backward() const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto mean = reinterpret_cast<const data_t *>(this->input_memory(1));
auto variance = reinterpret_cast<const data_t *>(this->input_memory(2));
auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(3));
auto scaleshift = reinterpret_cast<const data_t *>(this->input_memory(4));
auto ws = reinterpret_cast<const uint8_t *>(
- this->input_memory(conf_.ws_idx()));
+ this->input_memory(pd()->ws_idx()));
auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
auto diff_scaleshift = reinterpret_cast<data_t *>(this->memory(1));
- const memory_desc_wrapper data_d(conf_.src_pd());
- const memory_desc_wrapper diff_data_d(conf_.diff_src_pd());
- const memory_desc_wrapper scaleshift_d(conf_.weights_pd());
- const memory_desc_wrapper diff_scaleshift_d(conf_.diff_weights_pd());
- const memory_desc_wrapper mean_d(conf_.mean_pd());
- const memory_desc_wrapper variance_d(conf_.variance_pd());
+ const memory_desc_wrapper data_d(pd()->src_pd());
+ const memory_desc_wrapper diff_data_d(pd()->diff_src_pd());
+ const memory_desc_wrapper scaleshift_d(pd()->weights_pd());
+ const memory_desc_wrapper diff_scaleshift_d(pd()->diff_weights_pd());
+ const memory_desc_wrapper mean_d(pd()->mean_pd());
+ const memory_desc_wrapper variance_d(pd()->variance_pd());
- const int C = conf_.C();
+ const int C = pd()->C();
/* fast return */
- if (this->conf_.has_zero_dim_memory()) {
+ if (this->pd()->has_zero_dim_memory()) {
if (diff_scaleshift) {
for (int c = 0; c < C; ++c) {
diff_scaleshift[diff_scaleshift_d.off(0, c)] = 0;
return;
}
- const int N = conf_.MB();
+ const int N = pd()->MB();
int H = 1, W = 1, D = 1;
const bool has_spatial = utils::one_of(data_d.ndims(), 4 ,5);
if (has_spatial)
{
- D = conf_.D();
- H = conf_.H();
- W = conf_.W();
+ D = pd()->D();
+ H = pd()->H();
+ W = pd()->W();
}
- const float eps = conf_.desc()->batch_norm_epsilon;
- const bool use_scaleshift = conf_.use_scaleshift();
- const bool calculate_diff_stats = !conf_.omit_stats();
- const bool fuse_bn_relu = conf_.fuse_bn_relu();
+ const float eps = pd()->desc()->batch_norm_epsilon;
+ const bool use_scaleshift = pd()->use_scaleshift();
+ const bool calculate_diff_stats = !pd()->use_global_stats();
+ const bool fuse_bn_relu = pd()->fuse_bn_relu();
const bool is_3d = data_d.ndims() == 5;