#include <assert.h>
#include "c_types_map.hpp"
-#include "cpu_batch_normalization_pd.hpp"
-#include "cpu_engine.hpp"
+#include "memory_tracking.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
+#include "cpu_batch_normalization_pd.hpp"
+
namespace mkldnn {
namespace impl {
namespace cpu {
DECLARE_COMMON_PD_T("nspc_bnorm:any", nspc_batch_normalization_fwd_t);
// Decides whether this forward batch-normalization implementation can handle
// the requested descriptor and, if so, prepares the workspace, statistics
// descriptors and scratchpad bookings.
// Returns status::unimplemented when the support check fails and
// status::success otherwise.
// NOTE(review): this excerpt is a diff with elided context lines, so several
// operands of the `ok` expression below are not visible here.
virtual status_t init() override {
- using namespace prop_kind;
using namespace data_type;
+ using namespace prop_kind;
+
assert(engine()->kind() == engine_kind::cpu);
+
// Visible requirements: f32 scale/shift data, nhwc data layout, and either
// default attributes or a ReLU post-op.
bool ok = true
/* the algorithm requires barriers while switching
* between parallelization over N and C dimensions */
desc()->data_scaleshift_desc.data_type == f32)
&& utils::one_of(data_pd_.desc()->format, memory_format::nhwc)
&& (attr()->has_default_values() || this->with_relu_post_op());
- if (!ok)
- return status::unimplemented;
+ if (!ok) return status::unimplemented;
// Training with fused batch-norm + ReLU initializes the default workspace.
// NOTE(review): the meaning of the literal 8 is defined by
// bn_init_default_ws, which is not visible here -- confirm.
if (is_training() && fuse_bn_relu())
bn_init_default_ws(this, this->workspace_pd_, 8);
// Mean and variance share one 1-D f32 descriptor (format x) with C()
// elements; it is only built when statistics are an input or when training.
if (stats_is_src() || is_training()) {
memory_desc_t stats_d;
dims_t stats_dims = { C() };
- mkldnn_memory_desc_init(&stats_d, 1, stats_dims, data_type::f32,
- memory_format::x);
+ mkldnn_memory_desc_init(&stats_d, 1, stats_dims,
+ data_type::f32, memory_format::x);
mean_pd_ = cpu_memory_t::pd_t(engine_, &stats_d);
variance_pd_ = cpu_memory_t::pd_t(engine_, &stats_d);
}
// Scratchpad is booked only after every support check has passed.
+ init_scratchpad();
+
return status::success;
}
+
+ private:
// Books the scratchpad buffers for the forward pass.
// Nothing is booked when stats_is_src() is true. Otherwise three buffers of
// equal size are registered under key_bnorm_reduction, key_bnorm_tmp_mean and
// key_bnorm_tmp_var, each holding
// max(C(), 16) * mkldnn_get_max_threads() elements of data_t.
// NOTE(review): the lower bound of 16 presumably pads small channel counts
// (e.g. to keep per-thread slices apart) -- confirm against the kernel.
+ void init_scratchpad() {
+ using namespace memory_tracking::names;
+ auto scratchpad = scratchpad_registry().registrar();
+ if (!stats_is_src()) {
+ int sz = nstl::max(C(), 16) * mkldnn_get_max_threads();
+ scratchpad.book(key_bnorm_reduction, sizeof(data_t) * sz);
+ scratchpad.book(key_bnorm_tmp_mean, sizeof(data_t) * sz);
+ scratchpad.book(key_bnorm_tmp_var, sizeof(data_t) * sz);
+ }
+ }
};
typedef typename prec_traits<data_type::f32>::type data_t;
- nspc_batch_normalization_fwd_t(const pd_t *pd, const input_vector &inputs,
- const output_vector &outputs);
- ~nspc_batch_normalization_fwd_t();
- virtual void execute(event_t *e) {
// Constructs the forward primitive by forwarding the descriptor and the
// input/output lists to cpu_primitive_t; the constructor body does no other
// work.
+ nspc_batch_normalization_fwd_t(const pd_t *apd, const input_vector &inputs,
+ const output_vector &outputs)
+ : cpu_primitive_t(apd, inputs, outputs) {}
// Empty destructor: this class declares no members that need releasing in
// the lines shown here.
+ ~nspc_batch_normalization_fwd_t() {}
+
// Runs the forward computation and then marks the given event as ready.
// e: event whose state is set to event_t::ready once execute_forward()
// has returned.
+ virtual void execute(event_t *e) const {
execute_forward();
e->set_state(event_t::ready);
}
private:
- data_t *stats_reduction_;
- data_t *tmp_mean_, *tmp_variance_;
- void execute_forward();
- pd_t conf_;
+ void execute_forward() const;
// Returns primitive_t::pd() cast to this implementation's pd_t.
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
};
struct nspc_batch_normalization_bwd_t : public cpu_primitive_t {
DECLARE_COMMON_PD_T("nspc_bnorm:any", nspc_batch_normalization_bwd_t);
// Decides whether this backward batch-normalization implementation can handle
// the requested descriptor and, if so, validates the workspace and books the
// scratchpad.
// Returns status::unimplemented when the support check or the workspace
// check fails and status::success otherwise.
// NOTE(review): this excerpt is a diff with elided context lines, so several
// operands of the `ok` expression below are not visible here.
virtual status_t init() override {
- using namespace prop_kind;
using namespace data_type;
+ using namespace prop_kind;
+
assert(engine()->kind() == engine_kind::cpu);
+
// Visible requirements: f32 scale/shift data, nhwc data layout, and either
// default attributes or a ReLU post-op.
bool ok = true
/* the algorithm requires barriers while switching
* between parallelization over N and C dimensions */
desc()->data_scaleshift_desc.data_type == f32)
&& utils::one_of(data_pd_.desc()->format, memory_format::nhwc)
&& (attr()->has_default_values() || this->with_relu_post_op());
- if (!ok)
- return status::unimplemented;
+ if (!ok) return status::unimplemented;
// With fused ReLU, the forward hint must provide a workspace whose size
// equals the size of the workspace initialized here.
// NOTE(review): hint_fwd_pd_ is dereferenced without a visible null check --
// confirm that a forward hint is always supplied on this path.
// NOTE(review): the meaning of the literal 8 is defined by
// bn_init_default_ws, which is not visible here -- confirm.
if (fuse_bn_relu()) {
bn_init_default_ws(this, this->workspace_pd_, 8);
const size_t this_ws_sz
- = memory_desc_wrapper(this->workspace_pd()).size();
-
- bool ws_ok = true && hint_fwd_pd_->workspace_pd()
- && memory_desc_wrapper(hint_fwd_pd_->workspace_pd())
- .size()
- == this_ws_sz;
- if (!ws_ok)
- return status::unimplemented;
+ = memory_desc_wrapper(this->workspace_pd()).size();
+
+ bool ws_ok = true
+ && hint_fwd_pd_->workspace_pd()
+ && memory_desc_wrapper(hint_fwd_pd_->workspace_pd()).size()
+ == this_ws_sz;
+ if (!ws_ok) return status::unimplemented;
}
// Scratchpad is booked only after every check above has passed.
+ init_scratchpad();
+
return status::success;
}
+
+ private:
// Books the scratchpad buffers for the backward pass (unconditionally):
//  - key_bnorm_reduction:   2 * C() * mkldnn_get_max_threads() elements;
//  - key_bnorm_tmp_diff_ss: 2 * C() * (mkldnn_get_max_threads() + 1) elements.
// Both sizes are expressed in bytes via sizeof(data_t).
// NOTE(review): the factor 2 presumably covers the two scale/shift gradient
// rows and the "+ 1" an extra slot for the reduced result -- confirm against
// execute_backward().
+ void init_scratchpad() {
+ using namespace memory_tracking::names;
+ auto scratchpad = scratchpad_registry().registrar();
+ scratchpad.book(key_bnorm_reduction,
+ sizeof(data_t) * 2 * C() * mkldnn_get_max_threads());
+ scratchpad.book(key_bnorm_tmp_diff_ss, sizeof(data_t) * 2 * C()
+ * (mkldnn_get_max_threads() + 1));
+ }
};
typedef typename prec_traits<data_type::f32>::type data_t;
- nspc_batch_normalization_bwd_t(const pd_t *pd, const input_vector &inputs,
- const output_vector &outputs);
- ~nspc_batch_normalization_bwd_t();
- virtual void execute(event_t *e) {
// Constructs the backward primitive by forwarding the descriptor and the
// input/output lists to cpu_primitive_t; the constructor body does no other
// work.
+ nspc_batch_normalization_bwd_t(const pd_t *apd, const input_vector &inputs,
+ const output_vector &outputs)
+ : cpu_primitive_t(apd, inputs, outputs) {}
// Empty destructor: this class declares no members that need releasing in
// the lines shown here.
+ ~nspc_batch_normalization_bwd_t() {}
+
// Runs the backward computation and then marks the given event as ready.
// e: event whose state is set to event_t::ready once execute_backward()
// has returned.
+ virtual void execute(event_t *e) const {
execute_backward();
e->set_state(event_t::ready);
}
private:
- data_t *stats_reduction_;
- data_t *tmp_diff_scaleshift_;
- void execute_backward();
- pd_t conf_;
+ void execute_backward() const;
// Returns primitive_t::pd() cast to this implementation's pd_t.
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
};
+
}
}
}