#include <assert.h>
#include "c_types_map.hpp"
-#include "cpu_batch_normalization_pd.hpp"
-#include "cpu_engine.hpp"
+#include "memory_tracking.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
+#include "cpu_batch_normalization_pd.hpp"
+
namespace mkldnn {
namespace impl {
namespace cpu {
DECLARE_COMMON_PD_T("ncsp_bnorm:any", ncsp_batch_normalization_fwd_t);
virtual status_t init() override {
- using namespace prop_kind;
using namespace data_type;
+ using namespace prop_kind;
+
assert(engine()->kind() == engine_kind::cpu);
+
bool ok = true
&& is_fwd()
&& !has_zero_dim_memory()
&& utils::one_of(data_pd_.desc()->format, memory_format::nchw,
memory_format::ncdhw, memory_format::nc)
&& (attr()->has_default_values() || this->with_relu_post_op());
- if (!ok)
- return status::unimplemented;
+ if (!ok) return status::unimplemented;
- if (is_training() && fuse_bn_relu()) {
+ if (is_training() && fuse_bn_relu())
bn_init_default_ws(this, this->workspace_pd_, 8);
- }
if (stats_is_src() || is_training()) {
memory_desc_t stats_d;
dims_t stats_dims = { C() };
- mkldnn_memory_desc_init(&stats_d, 1, stats_dims, data_type::f32,
- memory_format::x);
+ mkldnn_memory_desc_init(&stats_d, 1, stats_dims,
+ data_type::f32, memory_format::x);
mean_pd_ = cpu_memory_t::pd_t(engine_, &stats_d);
variance_pd_ = cpu_memory_t::pd_t(engine_, &stats_d);
}
+ init_scratchpad();
+
return success;
}
+
+ private:
+ void init_scratchpad() {
+ using namespace memory_tracking::names;
+ auto scratchpad = scratchpad_registry().registrar();
+ if (!stats_is_src()) {
+ scratchpad.book(key_bnorm_reduction,
+ sizeof(data_t) * C() * mkldnn_get_max_threads());
+
+ if (!is_training()) {
+ scratchpad.book(key_bnorm_tmp_mean, sizeof(data_t) * C());
+ scratchpad.book(key_bnorm_tmp_var, sizeof(data_t) * C());
+ }
+ }
+ }
};
typedef typename prec_traits<data_type::f32>::type data_t;
- ncsp_batch_normalization_fwd_t(const pd_t *pd, const input_vector &inputs,
- const output_vector &outputs);
- ~ncsp_batch_normalization_fwd_t();
+ ncsp_batch_normalization_fwd_t(const pd_t *apd, const input_vector &inputs,
+ const output_vector &outputs)
+ : cpu_primitive_t(apd, inputs, outputs) {}
+ ~ncsp_batch_normalization_fwd_t() {}
- virtual void execute(event_t *e) {
+ virtual void execute(event_t *e) const {
execute_forward();
e->set_state(event_t::ready);
}
private:
- data_t *stats_reduction_, *tmp_mean_, *tmp_variance_;
- void execute_forward();
- pd_t conf_;
+ void execute_forward() const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
};
struct ncsp_batch_normalization_bwd_t : public cpu_primitive_t {
const primitive_attr_t *attr,
const batch_normalization_fwd_pd_t *hint_fwd_pd)
: cpu_batch_normalization_bwd_pd_t(
- engine, adesc, attr, hint_fwd_pd) {}
+ engine, adesc, attr, hint_fwd_pd) {}
DECLARE_COMMON_PD_T("ncsp_bnorm:any", ncsp_batch_normalization_bwd_t);
virtual status_t init() override {
- using namespace prop_kind;
using namespace data_type;
assert(engine()->kind() == engine_kind::cpu);
+
bool ok = true
&& is_bwd()
&& !has_zero_dim_memory()
&& utils::one_of(data_pd_.desc()->format, memory_format::nchw,
memory_format::ncdhw, memory_format::nc)
&& attr()->has_default_values();
- if (!ok)
- return status::unimplemented;
+ if (!ok) return status::unimplemented;
if (fuse_bn_relu()) {
bn_init_default_ws(this, this->workspace_pd_, 8);
const size_t this_ws_sz
- = memory_desc_wrapper(this->workspace_pd()).size();
-
- bool ws_ok = true && hint_fwd_pd_->workspace_pd()
- && memory_desc_wrapper(hint_fwd_pd_->workspace_pd())
- .size()
- == this_ws_sz;
- if (!ws_ok)
- return status::unimplemented;
+ = memory_desc_wrapper(this->workspace_pd()).size();
+
+ bool ws_ok = true
+ && hint_fwd_pd_->workspace_pd()
+ && memory_desc_wrapper(hint_fwd_pd_->workspace_pd()).size()
+ == this_ws_sz;
+ if (!ws_ok) return status::unimplemented;
}
+ init_scratchpad();
+
return success;
}
+
+ private:
+ void init_scratchpad() {
+ using namespace memory_tracking::names;
+ auto scratchpad = scratchpad_registry().registrar();
+ scratchpad.book(key_bnorm_reduction,
+ sizeof(data_t) * 2 * C() * mkldnn_get_max_threads());
+ if (!(use_scaleshift() && desc()->prop_kind == prop_kind::backward))
+ scratchpad.book(key_bnorm_tmp_diff_ss,
+ sizeof(data_t) * 2 * C());
+ }
};
typedef typename prec_traits<data_type::f32>::type data_t;
- ncsp_batch_normalization_bwd_t(const pd_t *pd, const input_vector &inputs,
- const output_vector &outputs);
- ~ncsp_batch_normalization_bwd_t();
- virtual void execute(event_t *e) {
+ ncsp_batch_normalization_bwd_t(const pd_t *apd, const input_vector &inputs,
+ const output_vector &outputs)
+ : cpu_primitive_t(apd, inputs, outputs) {}
+ ~ncsp_batch_normalization_bwd_t() {}
+
+ virtual void execute(event_t *e) const {
execute_backward();
e->set_state(event_t::ready);
}
private:
- void execute_backward();
- pd_t conf_;
-
- data_t *stats_reduction_, *tmp_diff_scaleshift_;
+ void execute_backward() const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
};
+
}
}
}