#include <assert.h>
#include <math.h>
-#include "cpu_batch_normalization_utils.hpp"
#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+
+#include "cpu_batch_normalization_utils.hpp"
#include "jit_generator.hpp"
+
#include "ncsp_batch_normalization.hpp"
-#include "type_helpers.hpp"
// clang 6 and 7 generate incorrect code with OMP_SIMD in some particular cases
#if (defined __clang_major__) && (__clang_major__ >= 6)
namespace impl {
namespace cpu {
-typedef float data_t;
-ncsp_batch_normalization_fwd_t::ncsp_batch_normalization_fwd_t(const pd_t *pd,
- const input_vector &inputs, const output_vector &outputs)
- : cpu_primitive_t(&conf_, inputs, outputs), stats_reduction_(nullptr),
- tmp_mean_(nullptr), tmp_variance_(nullptr), conf_(*pd) {
- if (!conf_.stats_is_src()) {
- this->stats_reduction_ = (data_t *)malloc(
- conf_.C() * mkldnn_get_max_threads() * sizeof(data_t), 64);
- if (!conf_.is_training()) {
- this->tmp_mean_ = (data_t *)malloc(conf_.C() * sizeof(data_t), 64);
- this->tmp_variance_
- = (data_t *)malloc(conf_.C() * sizeof(data_t), 64);
- }
- }
-}
-ncsp_batch_normalization_fwd_t::~ncsp_batch_normalization_fwd_t() {
- if (!conf_.stats_is_src()) {
- free(this->stats_reduction_);
- if (!conf_.is_training()) {
- free(this->tmp_mean_);
- free(this->tmp_variance_);
- }
- }
-}
+using namespace memory_tracking::names;
-void ncsp_batch_normalization_fwd_t::execute_forward() {
+void ncsp_batch_normalization_fwd_t::execute_forward() const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto dst = reinterpret_cast<data_t *>(this->memory(0));
- const bool calculate_stats = !conf_.stats_is_src();
- const bool save_stats = conf_.is_training();
- const bool is_training = conf_.is_training();
- const bool fuse_bn_relu = conf_.fuse_bn_relu();
+ auto scratchpad = this->scratchpad();
+
+ const bool calculate_stats = !pd()->stats_is_src();
+ const bool save_stats = pd()->is_training();
+ const bool is_training = pd()->is_training();
+ const bool fuse_bn_relu = pd()->fuse_bn_relu();
data_t *mean, *variance;
if (!calculate_stats) {
mean = reinterpret_cast<data_t *>(this->memory(1));
variance = reinterpret_cast<data_t *>(this->memory(2));
} else {
- mean = this->tmp_mean_;
- variance = this->tmp_variance_;
+ mean = scratchpad.get<data_t>(key_bnorm_tmp_mean);
+ variance = scratchpad.get<data_t>(key_bnorm_tmp_var);
}
}
- auto idx_scale_shift = 1 + 2 * conf_.stats_is_src();
+ auto idx_scale_shift = 1 + 2 * pd()->stats_is_src();
auto scaleshift = reinterpret_cast<const data_t *>(
this->input_memory(idx_scale_shift));
- auto ws = reinterpret_cast<uint8_t *>(this->memory(conf_.ws_idx()));
- data_t *ws_reduce = this->stats_reduction_;
+ auto ws = reinterpret_cast<uint8_t *>(this->memory(pd()->ws_idx()));
+ auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction);
- const float eps = conf_.desc()->batch_norm_epsilon;
- const bool use_scaleshift = conf_.use_scaleshift();
- const bool with_relu = conf_.with_relu_post_op();
+ const float eps = pd()->desc()->batch_norm_epsilon;
+ const bool use_scaleshift = pd()->use_scaleshift();
+ const bool with_relu = pd()->with_relu_post_op();
auto maybe_post_op
= [&](data_t res) { return (with_relu && res < 0) ? 0 : res; };
- const bool has_spatial = utils::one_of(conf_.ndims(), 4, 5);
- int SP = (has_spatial) ? conf_.H() * conf_.W() * conf_.D() : 1;
- size_t N = conf_.MB();
- size_t C = conf_.C();
+ const bool has_spatial = utils::one_of(pd()->ndims(), 4, 5);
+ int SP = (has_spatial) ? pd()->H() * pd()->W() * pd()->D() : 1;
+ size_t N = pd()->MB();
+ size_t C = pd()->C();
int nthr = mkldnn_get_max_threads();
size_t l3_size_ = get_cache_size(3, true) * nthr / 2;
});
}
-ncsp_batch_normalization_bwd_t::ncsp_batch_normalization_bwd_t(const pd_t *pd,
- const input_vector &inputs, const output_vector &outputs)
- : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd)
- , stats_reduction_(nullptr), tmp_diff_scaleshift_(nullptr) {
- this->stats_reduction_ = (data_t *)malloc(
- conf_.C() * 2 * mkldnn_get_max_threads() * sizeof(data_t), 64);
- if (!(conf_.use_scaleshift()
- && conf_.desc()->prop_kind == prop_kind::backward))
- this->tmp_diff_scaleshift_
- = (data_t *)malloc(conf_.C() * 2 * sizeof(data_t), 64);
-}
-
-ncsp_batch_normalization_bwd_t::~ncsp_batch_normalization_bwd_t() {
- free(this->stats_reduction_);
- free(this->tmp_diff_scaleshift_);
-}
-
-void ncsp_batch_normalization_bwd_t::execute_backward() {
+void ncsp_batch_normalization_bwd_t::execute_backward() const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto mean = reinterpret_cast<const data_t *>(this->input_memory(1));
auto variance = reinterpret_cast<const data_t *>(this->input_memory(2));
auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(3));
auto scaleshift = reinterpret_cast<const data_t *>(this->input_memory(4));
auto diff_src = reinterpret_cast<data_t *>(this->memory(0));
- auto diff_scaleshift = (this->memory(1)) ?
- reinterpret_cast<data_t *>(this->memory(1)) :
- this->tmp_diff_scaleshift_;
+
+ auto scratchpad = this->scratchpad();
+
+ auto diff_scaleshift = this->memory(1)
+ ? reinterpret_cast<data_t *>(this->memory(1))
+ : scratchpad.get<data_t>(key_bnorm_tmp_diff_ss);
auto ws = reinterpret_cast<const uint8_t *>(
- this->input_memory(conf_.ws_idx()));
- data_t *ws_reduce = this->stats_reduction_;
-
- const bool has_spatial = utils::one_of(conf_.ndims(), 4, 5);
- int SP = (has_spatial) ? conf_.H() * conf_.W() * conf_.D() : 1;
- size_t C = conf_.C(), N = conf_.MB();
- const bool use_scaleshift = conf_.use_scaleshift();
- const float eps = conf_.desc()->batch_norm_epsilon;
- const bool calculate_diff_stats = !conf_.omit_stats();
- const bool fuse_bn_relu = conf_.fuse_bn_relu();
+ this->input_memory(pd()->ws_idx()));
+ auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction);
+
+ const bool has_spatial = utils::one_of(pd()->ndims(), 4, 5);
+ int SP = (has_spatial) ? pd()->H() * pd()->W() * pd()->D() : 1;
+ size_t C = pd()->C(), N = pd()->MB();
+ const bool use_scaleshift = pd()->use_scaleshift();
+ const float eps = pd()->desc()->batch_norm_epsilon;
+ const bool calculate_diff_stats = !pd()->use_global_stats();
+ const bool fuse_bn_relu = pd()->fuse_bn_relu();
int nthr = mkldnn_get_max_threads();
size_t l3_size_ = get_cache_size(3, true) * nthr / 2;