#include <assert.h>
#include "c_types_map.hpp"
+#include "math_utils.hpp"
+#include "memory_tracking.hpp"
+#include "mkldnn_thread.hpp"
#include "nstl.hpp"
#include "type_helpers.hpp"
-#include "mkldnn_thread.hpp"
-#include "math_utils.hpp"
#include "utils.hpp"
-#include "jit_generator.hpp"
#include "cpu_barrier.hpp"
+#include "cpu_batch_normalization_utils.hpp"
+#include "jit_generator.hpp"
#include "jit_uni_batch_normalization.hpp"
-#include "cpu_batch_normalization_utils.hpp"
namespace mkldnn {
namespace impl {
namespace {
+using namespace memory_tracking::names;
+
using namespace Xbyak;
namespace barrier = simple_barrier;
const int vlen = isa == sse42 ? 32 : cpu_isa_traits<isa>::vlen;
const batch_normalization_pd_t *bdesc_;
- int is_spatial_thr_;
+ bool is_spatial_thr_;
void (*ker)(const call_params_t *);
void operator()(const call_params_t *p) { (*ker)(p); }
else
assert(false);
}
- if (!bdesc_->omit_stats()) {
+ if (!bdesc_->use_global_stats()) {
uni_vsubps(v, v, vdiff_beta);
uni_vmovups(t, vmmword[reg_src + reg_soff
+ offt]);
}
}
- jit_bnorm_t(const batch_normalization_pd_t *bdesc, int is_spatial_thr):
- bdesc_(bdesc), is_spatial_thr_(is_spatial_thr) {
+ jit_bnorm_t(const batch_normalization_pd_t *bdesc): bdesc_(bdesc) {
static_assert(isa == sse42 || isa == avx2 || isa == avx512_common
|| isa == avx512_mic, "unsupported isa");
+ const int simd_w = isa == sse42 ? 8 :
+ cpu_isa_traits<isa>::vlen / sizeof(data_t);
+ is_spatial_thr_ =
+ bnorm_utils::is_spatial_thr(bdesc_, simd_w, sizeof(data_t));
+
unroll_blocks = isa == avx512_common && !is_spatial_thr_ ? 4 : 1;
unroll_regs = isa == avx512_common && !is_spatial_thr_ ? 4 : 1;
template <cpu_isa_t isa>
struct uni_bnorm_driver_t: public c_compatible {
- uni_bnorm_driver_t(const batch_normalization_pd_t *bdesc,
- int is_spatial_thr) : bdesc_(bdesc), ker_(bdesc_,is_spatial_thr),
- buf_(nullptr), barriers_(nullptr)
+ uni_bnorm_driver_t(const batch_normalization_pd_t *bdesc)
+ : bdesc_(bdesc), ker_(bdesc_)
{
- use_tmp_stats_ = !bdesc_->stats_is_src()
- && bdesc_->desc()->prop_kind == prop_kind::forward_inference;
- use_tmp_diff_scale_shift_ = false
- || (bdesc_->is_bwd() && !bdesc_->use_scaleshift())
- || bdesc_->desc()->prop_kind == prop_kind::backward_data;
- int num_sbufs = 2 * use_tmp_stats_;
- int num_pbufs = 2 * use_tmp_diff_scale_shift_;
- int num_rbufs = bdesc_->is_fwd() ? 1 : 2;
+ const int nthrs = mkldnn_get_max_threads();
+ const int C_PADDED = get_c_padded(bdesc_);
+
+ size_t data_size = sizeof(data_t) * bdesc_->MB() * C_PADDED
+ * bdesc_->D() * bdesc_->H() * bdesc_->W();
+ l3_size_ = get_cache_size(3, true) * nthrs / 2;
+ do_blocking_ = (data_size >= l3_size_ / 2 && l3_size_ > 0);
+ }
+
+ ~uni_bnorm_driver_t() {}
+
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const batch_normalization_pd_t *bdesc) {
int nthrs = mkldnn_get_max_threads();
- int C_PADDED = memory_desc_wrapper(bdesc_->src_pd()).blocking_desc()
- .padding_dims[1];
+ int C_PADDED = get_c_padded(bdesc);
- int buf_size = (num_sbufs + num_pbufs + num_rbufs * nthrs) * C_PADDED;
- buf_ = (data_t *)malloc(buf_size * sizeof(data_t), 64);
+ int sbuf_sz = use_tmp_stats(bdesc) * 2 * C_PADDED;
+ int pbuf_sz = use_tmp_diff_scale_shift(bdesc) * 2 * C_PADDED;
+ int rbuf_sz = (bdesc->is_fwd() ? 1 : 2) * C_PADDED * nthrs;
- sbuf_ = buf_;
- pbuf_ = sbuf_ + num_sbufs * C_PADDED;
- rbuf_ = pbuf_ + num_pbufs * C_PADDED;
+ scratchpad.book(key_bnorm_tmp_stats, sizeof(data_t) * sbuf_sz);
+ scratchpad.book(key_bnorm_tmp_diff_ss, sizeof(data_t) * pbuf_sz);
+ scratchpad.book(key_bnorm_reduction, sizeof(data_t) * rbuf_sz);
- int num_barriers = C_PADDED / simd_w;
if (mkldnn_thr_syncable()) {
- barriers_ = (barrier::ctx_t *)malloc(
- num_barriers * sizeof(barrier::ctx_t), 64);
- for (int i = 0; i < num_barriers; ++i)
- barrier::ctx_init(&barriers_[i]);
+ int n_barriers = C_PADDED / simd_w;
+ scratchpad.book(key_barrier, sizeof(barrier::ctx_t) * n_barriers);
}
-
- size_t data_size = bdesc_->MB() * C_PADDED * bdesc_->H()
- * bdesc_->W() * bdesc_->D() * sizeof(data_t);
- l3_size_ = get_cache_size(3, true) * nthrs / 2;
- do_blocking_ = (data_size >= l3_size_ / 2 && l3_size_ > 0);
}
- ~uni_bnorm_driver_t() { free(buf_); free(barriers_); }
void exec(int ithr, int nthr, const data_t *src, data_t *diff_src,
data_t *dst, const data_t *diff_dst, const data_t *scale_shift,
data_t *diff_scale_shift, const data_t *mean, const data_t *var,
- const uint8_t *ws) {
+ const uint8_t *ws, const memory_tracking::grantor_t &scratchpad) {
+ auto sbuf = scratchpad.get<data_t>(key_bnorm_tmp_stats);
+ auto pbuf = scratchpad.get<data_t>(key_bnorm_tmp_diff_ss);
+ auto rbuf = scratchpad.get<data_t>(key_bnorm_reduction);
+ auto barriers = scratchpad.get<barrier::ctx_t>(key_barrier);
+
size_t N = bdesc_->MB();
size_t C = bdesc_->C();
- size_t C_PADDED = memory_desc_wrapper(bdesc_->src_pd()).blocking_desc()
- .padding_dims[1];
+ size_t C_PADDED = get_c_padded(bdesc_);
size_t D = bdesc_->D();
size_t H = bdesc_->H();
size_t W = bdesc_->W();
p.S_s = S_s * vlen;
p.S_tail = (p.spat_size - S_e) * vlen;
p.coff_max = C_blks_thr * simd_w;
- p.mean = (use_tmp_stats_ ? sbuf_ : mean) + coff_base;
- p.var = (use_tmp_stats_ ? sbuf_ + C_PADDED : var) + coff_base;
+ p.mean = (use_tmp_stats(bdesc_) ? sbuf : mean) + coff_base;
+ p.var = (use_tmp_stats(bdesc_) ? sbuf + C_PADDED : var) + coff_base;
p.scale_shift = scale_shift + coff_base;
- p.diff_scale_shift
- = (use_tmp_diff_scale_shift_ ? pbuf_ : diff_scale_shift)
- + coff_base;
+ p.diff_scale_shift = (use_tmp_diff_scale_shift(bdesc_)
+ ? pbuf : diff_scale_shift) + coff_base;
p.soff_max = N_thr * img_size;
p.src = src + soff_base;
// use SP_N_nthr which is the same as p.N_nthr except maybe for
// the last iteration.
- p.rbuf1 = rbuf_
- + ((it * C_blks_per_iter) * SP_N_nthr + C_blk_s * p.N_nthr
- + p.N_ithr * C_blks_thr)
- * simd_w;
+ p.rbuf1 = rbuf + ((it * C_blks_per_iter) * SP_N_nthr
+ + C_blk_s * p.N_nthr + p.N_ithr * C_blks_thr) * simd_w;
// rbuf1 and rbuf2 have to be disjoint
p.rbuf2 = p.rbuf1 + C_PADDED * nthr;
p.is_cblk_tail =
size_t iter_bariers
= do_blocking_ ? it * global_barriers_per_iter : 0;
- p.barrier = barriers_ + C_ithr + iter_bariers;
+ p.barrier = barriers + C_ithr + iter_bariers;
if (p.soff_max != 0 && p.coff_max != 0)
ker_(&p);
}
}
+ void init_barriers(const memory_tracking::grantor_t &scratchpad) {
+ auto barriers = scratchpad.get<barrier::ctx_t>(key_barrier);
+ if (barriers) {
+ const int n_barriers = get_c_padded(bdesc_) / simd_w;
+ for (int i = 0; i < n_barriers; ++i)
+ barrier::ctx_init(&barriers[i]);
+ }
+ }
+
private:
- const int simd_w = isa == sse42 ? 8 :
- cpu_isa_traits<isa>::vlen / sizeof(data_t);
+ enum {
+ simd_w = isa == sse42 ? 8 : cpu_isa_traits<isa>::vlen / sizeof(data_t)
+ };
+
+ static bool use_tmp_stats(const batch_normalization_pd_t *bdesc) {
+ return true
+ && !bdesc->stats_is_src()
+ && bdesc->desc()->prop_kind == prop_kind::forward_inference;
+ }
+
+ static bool use_tmp_diff_scale_shift(const batch_normalization_pd_t *bdesc)
+ {
+ return false
+ || (bdesc->is_bwd() && !bdesc->use_scaleshift())
+ || bdesc->desc()->prop_kind == prop_kind::backward_data;
+ }
+
+ static int get_c_padded(const batch_normalization_pd_t *bdesc)
+ { return bdesc->src_pd()->desc()->layout_desc.blocking.padding_dims[1]; }
const batch_normalization_pd_t *bdesc_;
- jit_bnorm_t<isa> ker_;
- bool use_tmp_stats_, use_tmp_diff_scale_shift_;
bool do_blocking_;
size_t l3_size_;
- data_t *buf_, *sbuf_, *rbuf_, *pbuf_;
-
- barrier::ctx_t *barriers_;
+ jit_bnorm_t<isa> ker_;
};
}
+using namespace data_type;
+using namespace memory_format;
+using namespace utils;
+
+/* fwd */
+
template <cpu_isa_t isa>
-jit_uni_batch_normalization_fwd_t<isa>::jit_uni_batch_normalization_fwd_t(
- const pd_t *pd, const input_vector &inputs,
- const output_vector &outputs)
- : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd)
-{
- int is_spatial_thr = 0;
- const int simd_w = isa == sse42 ? 8 :
- cpu_isa_traits<isa>::vlen / sizeof(data_t);
+status_t jit_uni_batch_normalization_fwd_t<isa>::pd_t::init() {
+ assert(engine()->kind() == engine_kind::cpu);
+ auto desired_fmt = (ndims() == 4)
+ ? isa == avx512_common ? nChw16c : nChw8c
+ : isa == avx512_common ? nCdhw16c : nCdhw8c;
+
+ bool ok = true
+ && mayiuse(isa)
+ && is_fwd()
+ && !has_zero_dim_memory()
+ && one_of(ndims(), 4, 5)
+ && desc()->data_desc.data_type == f32
+ && IMPLICATION(use_scaleshift(),
+ desc()->data_scaleshift_desc.data_type == f32)
+ && desc()->data_desc.format == desired_fmt
+ && (attr()->has_default_values() || this->with_relu_post_op());
+ if (!ok) return status::unimplemented;
+
+ if (is_training() && fuse_bn_relu()) {
+ if (isa < avx2) return status::unimplemented;
+ bn_init_default_ws(this, this->workspace_pd_, 1);
+ }
- bnorm_utils::set_spatial_thr(&conf_,simd_w,sizeof(data_t),is_spatial_thr);
+ if (memory_desc_wrapper(&data_pd_).blocking_desc().padding_dims[1]
+ != this->C() && isa < avx2)
+ return status::unimplemented;
- bnorm_driver_ = new uni_bnorm_driver_t<isa>(&conf_,is_spatial_thr);
+ if (stats_is_src() || is_training()) {
+ memory_desc_t stats_d;
+ dims_t stats_dims = { C() };
+ mkldnn_memory_desc_init(&stats_d, 1, stats_dims, f32, x);
+ mean_pd_ = cpu_memory_t::pd_t(engine_, &stats_d);
+ variance_pd_ = cpu_memory_t::pd_t(engine_, &stats_d);
+ }
+
+ auto scratchpad = scratchpad_registry().registrar();
+ uni_bnorm_driver_t<isa>::init_scratchpad(scratchpad, this);
+
+ return status::success;
}
template <cpu_isa_t isa>
-void jit_uni_batch_normalization_fwd_t<isa>::execute(event_t *e) {
+jit_uni_batch_normalization_fwd_t<isa>::jit_uni_batch_normalization_fwd_t(
+ const pd_t *apd, const input_vector &inputs,
+ const output_vector &outputs)
+ : cpu_primitive_t(apd, inputs, outputs)
+{ bnorm_driver_ = new uni_bnorm_driver_t<isa>(pd()); }
+
+template <cpu_isa_t isa>
+void jit_uni_batch_normalization_fwd_t<isa>::execute(event_t *e) const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto dst = reinterpret_cast<data_t*>(this->memory(0));
- auto mean = reinterpret_cast<data_t*>(conf_.stats_is_src()
+ auto mean = reinterpret_cast<data_t*>(pd()->stats_is_src()
? const_cast<char*>(this->input_memory(1))
: this->memory(1));
- auto var = reinterpret_cast<data_t*>(conf_.stats_is_src()
+ auto var = reinterpret_cast<data_t*>(pd()->stats_is_src()
? const_cast<char*>(this->input_memory(2))
: this->memory(2));
- auto idx_scale_shift = 1 + 2*conf_.stats_is_src();
+ auto idx_scale_shift = 1 + 2*pd()->stats_is_src();
auto scale_shift =
reinterpret_cast<const data_t *>(this->input_memory(idx_scale_shift));
- auto ws = reinterpret_cast<uint8_t *>(this->memory(conf_.ws_idx()));
+ auto ws = reinterpret_cast<uint8_t *>(this->memory(pd()->ws_idx()));
+
+ auto scratchpad = this->scratchpad();
+
+ bnorm_driver_->init_barriers(scratchpad);
parallel(0, [&](const int ithr, const int nthr) {
- bnorm_driver_->exec(ithr, nthr, src,
- nullptr, dst, nullptr, scale_shift, nullptr, mean, var, ws);
+ bnorm_driver_->exec(ithr, nthr, src, nullptr, dst, nullptr,
+ scale_shift, nullptr, mean, var, ws, scratchpad);
});
e->set_state(event_t::ready);
}
template <cpu_isa_t isa>
-jit_uni_batch_normalization_fwd_t<isa>::~jit_uni_batch_normalization_fwd_t() {
- delete bnorm_driver_;
-}
+jit_uni_batch_normalization_fwd_t<isa>::~jit_uni_batch_normalization_fwd_t()
+{ delete bnorm_driver_; }
+
+/* bwd */
template <cpu_isa_t isa>
-jit_uni_batch_normalization_bwd_t<isa>::jit_uni_batch_normalization_bwd_t(
- const pd_t *pd, const input_vector &inputs,
- const output_vector &outputs)
- : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd)
-{
- int is_spatial_thr = 0;
- const int simd_w = isa == sse42 ? 8 :
- cpu_isa_traits<isa>::vlen / sizeof(data_t);
+status_t jit_uni_batch_normalization_bwd_t<isa>::pd_t::init() {
+ assert(engine()->kind() == engine_kind::cpu);
+ auto desired_fmt = (ndims() == 4)
+ ? one_of(isa, sse42, avx2) ? nChw8c : nChw16c
+ : one_of(isa, sse42, avx2) ? nCdhw8c : nCdhw16c;
+
+ bool ok = true
+ && mayiuse(isa)
+ && is_bwd()
+ && !has_zero_dim_memory()
+ && one_of(ndims(), 4, 5)
+ && everyone_is(f32, desc()->data_desc.data_type,
+ desc()->diff_data_desc.data_type)
+ && IMPLICATION(use_scaleshift(),
+ desc()->data_scaleshift_desc.data_type == f32)
+ && everyone_is(desired_fmt, desc()->diff_data_desc.format,
+ desc()->data_desc.format)
+ && attr()->has_default_values();
+ if (!ok) return status::unimplemented;
+
+ if (memory_desc_wrapper(&data_pd_).blocking_desc()
+ .padding_dims[1] != this->C() && isa < avx2)
+ return status::unimplemented;
+
+ if (fuse_bn_relu()) {
+ if (isa < avx2) return status::unimplemented;
+ bn_init_default_ws(this, this->workspace_pd_, 1);
+ size_t this_ws_sz = memory_desc_wrapper(this->workspace_pd()).size();
+
+ bool ws_ok = true
+ && hint_fwd_pd_->workspace_pd()
+ && memory_desc_wrapper(hint_fwd_pd_->workspace_pd()).size()
+ == this_ws_sz;
+ if (!ws_ok) return status::unimplemented;
+ }
+
+ /* TODO: extra checks required */
- bnorm_utils::set_spatial_thr(&conf_,simd_w,sizeof(data_t),is_spatial_thr);
+ auto scratchpad = scratchpad_registry().registrar();
+ uni_bnorm_driver_t<isa>::init_scratchpad(scratchpad, this);
- bnorm_driver_ = new uni_bnorm_driver_t<isa>(&conf_,is_spatial_thr);
+ return status::success;
}
template <cpu_isa_t isa>
-void jit_uni_batch_normalization_bwd_t<isa>::execute(event_t *e) {
+jit_uni_batch_normalization_bwd_t<isa>::jit_uni_batch_normalization_bwd_t(
+ const pd_t *apd, const input_vector &inputs,
+ const output_vector &outputs)
+ : cpu_primitive_t(apd, inputs, outputs)
+{ bnorm_driver_ = new uni_bnorm_driver_t<isa>(pd()); }
+
+template <cpu_isa_t isa>
+void jit_uni_batch_normalization_bwd_t<isa>::execute(event_t *e) const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto mean = reinterpret_cast<const data_t *>(this->input_memory(1));
auto var = reinterpret_cast<const data_t *>(this->input_memory(2));
auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
auto diff_scale_shift = reinterpret_cast<data_t *>(this->memory(1));
auto ws = reinterpret_cast<const uint8_t *>(
- this->input_memory(conf_.ws_idx()));
+ this->input_memory(pd()->ws_idx()));
+
+ auto scratchpad = this->scratchpad();
+
+ bnorm_driver_->init_barriers(scratchpad);
parallel(0, [&](const int ithr, const int nthr) {
- bnorm_driver_->exec(ithr, nthr, src,
- diff_src, nullptr, diff_dst, scale_shift, diff_scale_shift,
- mean, var, ws);
+ bnorm_driver_->exec(ithr, nthr, src, diff_src, nullptr, diff_dst,
+ scale_shift, diff_scale_shift, mean, var, ws, scratchpad);
});
e->set_state(event_t::ready);
}
template <cpu_isa_t isa>
-jit_uni_batch_normalization_bwd_t<isa>::~jit_uni_batch_normalization_bwd_t() {
- delete bnorm_driver_;
-}
+jit_uni_batch_normalization_bwd_t<isa>::~jit_uni_batch_normalization_bwd_t()
+{ delete bnorm_driver_; }
/* struct instantiation */
template struct jit_uni_batch_normalization_fwd_t<sse42>;