Publishing 2019 R1 content
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp
index 3a667ac..38e4f48 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp
 #include <assert.h>
 
 #include "c_types_map.hpp"
+#include "math_utils.hpp"
+#include "memory_tracking.hpp"
+#include "mkldnn_thread.hpp"
 #include "nstl.hpp"
 #include "type_helpers.hpp"
-#include "mkldnn_thread.hpp"
-#include "math_utils.hpp"
 #include "utils.hpp"
 
-#include "jit_generator.hpp"
 #include "cpu_barrier.hpp"
+#include "cpu_batch_normalization_utils.hpp"
+#include "jit_generator.hpp"
 
 #include "jit_uni_batch_normalization.hpp"
-#include "cpu_batch_normalization_utils.hpp"
 
 namespace mkldnn {
 namespace impl {
@@ -35,6 +36,8 @@ namespace cpu {
 
 namespace {
 
+using namespace memory_tracking::names;
+
 using namespace Xbyak;
 namespace barrier = simple_barrier;
 
@@ -71,7 +74,7 @@ struct jit_bnorm_t: public jit_generator {
     const int vlen = isa == sse42 ? 32 : cpu_isa_traits<isa>::vlen;
 
     const batch_normalization_pd_t *bdesc_;
-    int is_spatial_thr_;
+    bool is_spatial_thr_;
 
     void (*ker)(const call_params_t *);
     void operator()(const call_params_t *p) { (*ker)(p); }
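
For context on the two lines above: the generated kernel is exposed through a plain function pointer, and operator() simply forwards to it, so the driver can call ker_(&p) like an ordinary function object. A minimal standalone sketch of that pattern (hypothetical names, no Xbyak/jit_generator involved):

    #include <cstddef>
    #include <cstdio>

    struct call_params_t { const float *src; float *dst; std::size_t len; };

    // Stand-in for a JIT'ed routine: in the real code `ker` points at machine
    // code emitted through jit_generator/Xbyak when the kernel is constructed.
    static void scale_by_two(const call_params_t *p) {
        for (std::size_t i = 0; i < p->len; ++i) p->dst[i] = 2.f * p->src[i];
    }

    struct kernel_t {
        void (*ker)(const call_params_t *) = nullptr;
        kernel_t() { ker = scale_by_two; } // real code: address of generated code
        void operator()(const call_params_t *p) const { (*ker)(p); }
    };

    int main() {
        float in[4] = {1, 2, 3, 4}, out[4];
        kernel_t k;
        call_params_t p{in, out, 4};
        k(&p); // invoked exactly like ker_(&p) in the driver's exec()
        std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);
    }
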
@@ -846,7 +849,7 @@ struct jit_bnorm_t: public jit_generator {
                                 else
                                     assert(false);
                             }
-                            if (!bdesc_->omit_stats()) {
+                            if (!bdesc_->use_global_stats()) {
                                 uni_vsubps(v, v, vdiff_beta);
                                 uni_vmovups(t, vmmword[reg_src + reg_soff
                                         + offt]);
@@ -1006,11 +1009,15 @@ struct jit_bnorm_t: public jit_generator {
         }
     }
 
-    jit_bnorm_t(const batch_normalization_pd_t *bdesc, int is_spatial_thr):
-        bdesc_(bdesc), is_spatial_thr_(is_spatial_thr) {
+    jit_bnorm_t(const batch_normalization_pd_t *bdesc): bdesc_(bdesc) {
         static_assert(isa == sse42 || isa == avx2 || isa == avx512_common
                 || isa == avx512_mic, "unsupported isa");
 
+        const int simd_w = isa == sse42 ? 8 :
+            cpu_isa_traits<isa>::vlen / sizeof(data_t);
+        is_spatial_thr_ =
+            bnorm_utils::is_spatial_thr(bdesc_, simd_w, sizeof(data_t));
+
         unroll_blocks = isa == avx512_common && !is_spatial_thr_ ? 4 : 1;
         unroll_regs = isa == avx512_common && !is_spatial_thr_ ? 4 : 1;
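
The constructor now derives simd_w from the ISA itself: 8 floats for sse42 (the kernel treats it as two 128-bit halves, hence the vlen of 32 further up), otherwise vector length in bytes over sizeof(data_t), i.e. 8 for avx2 and 16 for avx512_common. A standalone sketch of that arithmetic, with a hypothetical vlen() standing in for cpu_isa_traits<isa>::vlen:

    #include <cstdio>

    enum cpu_isa_t { sse42, avx2, avx512_common };

    // Hypothetical stand-in for cpu_isa_traits<isa>::vlen (vector size, bytes).
    constexpr int vlen(cpu_isa_t isa) {
        return isa == sse42 ? 16 : isa == avx2 ? 32 : 64;
    }

    // Mirrors the constructor's expression: sse42 still advances 8 floats per
    // step because the kernel processes two 128-bit registers at a time.
    constexpr int simd_w(cpu_isa_t isa) {
        return isa == sse42 ? 8 : vlen(isa) / int(sizeof(float));
    }

    static_assert(simd_w(sse42) == 8, "sse42: 2 x 4 floats");
    static_assert(simd_w(avx2) == 8, "avx2: 32 B / 4 B");
    static_assert(simd_w(avx512_common) == 16, "avx512: 64 B / 4 B");

    int main() {
        std::printf("%d %d %d\n",
                simd_w(sse42), simd_w(avx2), simd_w(avx512_common));
    }
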
 
@@ -1044,52 +1051,51 @@ struct jit_bnorm_t: public jit_generator {
 
 template <cpu_isa_t isa>
 struct uni_bnorm_driver_t: public c_compatible {
-    uni_bnorm_driver_t(const batch_normalization_pd_t *bdesc,
-        int is_spatial_thr) : bdesc_(bdesc), ker_(bdesc_,is_spatial_thr),
-        buf_(nullptr), barriers_(nullptr)
+    uni_bnorm_driver_t(const batch_normalization_pd_t *bdesc)
+        : bdesc_(bdesc), ker_(bdesc_)
     {
-        use_tmp_stats_ = !bdesc_->stats_is_src()
-            && bdesc_->desc()->prop_kind == prop_kind::forward_inference;
-        use_tmp_diff_scale_shift_ = false
-            || (bdesc_->is_bwd() && !bdesc_->use_scaleshift())
-            || bdesc_->desc()->prop_kind == prop_kind::backward_data;
-        int num_sbufs = 2 * use_tmp_stats_;
-        int num_pbufs = 2 * use_tmp_diff_scale_shift_;
-        int num_rbufs = bdesc_->is_fwd() ? 1 : 2;
+        const int nthrs = mkldnn_get_max_threads();
+        const int C_PADDED = get_c_padded(bdesc_);
+
+        size_t data_size = sizeof(data_t) * bdesc_->MB() * C_PADDED
+            * bdesc_->D() * bdesc_->H() * bdesc_->W();
+        l3_size_ = get_cache_size(3, true) * nthrs / 2;
+        do_blocking_ = (data_size >= l3_size_ / 2 && l3_size_ > 0);
+    }
+
+    ~uni_bnorm_driver_t() {}
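
The constructor keeps only the blocking decision: if the tensor is large relative to the threads' aggregate share of L3, the driver later walks the channel blocks in chunks. A rough standalone sketch of the heuristic, with get_cache_size(3, true) replaced by an assumed per-core constant:

    #include <cstddef>
    #include <cstdio>

    // Assumed per-core L3 share in bytes; the real code queries
    // get_cache_size(3, true) at runtime.
    constexpr std::size_t l3_per_core = std::size_t(1) << 20; // 1 MiB, assumption

    bool do_blocking(std::size_t MB, std::size_t C_padded, std::size_t D,
            std::size_t H, std::size_t W, int nthrs) {
        std::size_t data_size = sizeof(float) * MB * C_padded * D * H * W;
        std::size_t l3_size = l3_per_core * nthrs / 2; // same scaling as above
        return data_size >= l3_size / 2 && l3_size > 0;
    }

    int main() {
        // e.g. a 32 x 64 x 56 x 56 fp32 tensor on 8 threads -> blocking kicks in
        std::printf("%s\n",
                do_blocking(32, 64, 1, 56, 56, 8) ? "block" : "no block");
    }
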
+
+    static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+            const batch_normalization_pd_t *bdesc) {
         int nthrs = mkldnn_get_max_threads();
-        int C_PADDED = memory_desc_wrapper(bdesc_->src_pd()).blocking_desc()
-            .padding_dims[1];
+        int C_PADDED = get_c_padded(bdesc);
 
-        int buf_size = (num_sbufs + num_pbufs + num_rbufs * nthrs) * C_PADDED;
-        buf_ = (data_t *)malloc(buf_size * sizeof(data_t), 64);
+        int sbuf_sz = use_tmp_stats(bdesc) * 2 * C_PADDED;
+        int pbuf_sz = use_tmp_diff_scale_shift(bdesc) * 2 * C_PADDED;
+        int rbuf_sz = (bdesc->is_fwd() ? 1 : 2) * C_PADDED * nthrs;
 
-        sbuf_ = buf_;
-        pbuf_ = sbuf_ + num_sbufs * C_PADDED;
-        rbuf_ = pbuf_ + num_pbufs * C_PADDED;
+        scratchpad.book(key_bnorm_tmp_stats, sizeof(data_t) * sbuf_sz);
+        scratchpad.book(key_bnorm_tmp_diff_ss, sizeof(data_t) * pbuf_sz);
+        scratchpad.book(key_bnorm_reduction, sizeof(data_t) * rbuf_sz);
 
-        int num_barriers = C_PADDED / simd_w;
         if (mkldnn_thr_syncable()) {
-            barriers_ = (barrier::ctx_t *)malloc(
-                    num_barriers * sizeof(barrier::ctx_t), 64);
-            for (int i = 0; i < num_barriers; ++i)
-                barrier::ctx_init(&barriers_[i]);
+            int n_barriers = C_PADDED / simd_w;
+            scratchpad.book(key_barrier, sizeof(barrier::ctx_t) * n_barriers);
         }
-
-        size_t data_size = bdesc_->MB() * C_PADDED * bdesc_->H()
-                * bdesc_->W() * bdesc_->D() * sizeof(data_t);
-        l3_size_ = get_cache_size(3, true) * nthrs / 2;
-        do_blocking_ = (data_size >= l3_size_ / 2 && l3_size_ > 0);
     }
-    ~uni_bnorm_driver_t() { free(buf_); free(barriers_); }
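
The hand-rolled malloc/free of buf_ and barriers_ is replaced by the scratchpad: init_scratchpad() above only books how many bytes each named region needs at pd-creation time, and exec() below asks the grantor for typed pointers into memory the framework provides. A much simplified, self-contained sketch of that book/get idea (std::map based, not the real memory_tracking API):

    #include <cstddef>
    #include <cstdio>
    #include <map>
    #include <vector>

    // Simplified stand-ins for memory_tracking::registrar_t / grantor_t.
    struct registrar_t {
        std::map<int, std::size_t> sizes;
        void book(int key, std::size_t bytes) { sizes[key] += bytes; }
    };

    struct grantor_t {
        std::map<int, std::vector<char>> storage;
        explicit grantor_t(const registrar_t &r) {
            for (auto &kv : r.sizes) storage[kv.first].resize(kv.second);
        }
        template <typename T> T *get(int key) {
            auto it = storage.find(key);
            return it == storage.end()
                ? nullptr : reinterpret_cast<T *>(it->second.data());
        }
    };

    enum { key_bnorm_tmp_stats, key_bnorm_reduction };

    int main() {
        const int C_padded = 64, nthrs = 4;
        registrar_t scratchpad;
        // pd-time: declare sizes only (compare init_scratchpad() above).
        scratchpad.book(key_bnorm_tmp_stats, sizeof(float) * 2 * C_padded);
        scratchpad.book(key_bnorm_reduction, sizeof(float) * C_padded * nthrs);

        // exec-time: the grantor hands out typed pointers into the storage.
        grantor_t grantor(scratchpad);
        float *sbuf = grantor.get<float>(key_bnorm_tmp_stats);
        float *rbuf = grantor.get<float>(key_bnorm_reduction);
        std::printf("%p %p\n", (void *)sbuf, (void *)rbuf);
    }
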
 
     void exec(int ithr, int nthr, const data_t *src, data_t *diff_src,
             data_t *dst, const data_t *diff_dst, const data_t *scale_shift,
             data_t *diff_scale_shift, const data_t *mean, const data_t *var,
-            const uint8_t *ws) {
+            const uint8_t *ws, const memory_tracking::grantor_t &scratchpad) {
+        auto sbuf = scratchpad.get<data_t>(key_bnorm_tmp_stats);
+        auto pbuf = scratchpad.get<data_t>(key_bnorm_tmp_diff_ss);
+        auto rbuf = scratchpad.get<data_t>(key_bnorm_reduction);
+        auto barriers = scratchpad.get<barrier::ctx_t>(key_barrier);
+
         size_t N = bdesc_->MB();
         size_t C = bdesc_->C();
-        size_t C_PADDED = memory_desc_wrapper(bdesc_->src_pd()).blocking_desc()
-            .padding_dims[1];
+        size_t C_PADDED = get_c_padded(bdesc_);
         size_t D = bdesc_->D();
         size_t H = bdesc_->H();
         size_t W = bdesc_->W();
@@ -1162,12 +1168,11 @@ struct uni_bnorm_driver_t: public c_compatible {
             p.S_s = S_s * vlen;
             p.S_tail = (p.spat_size - S_e) * vlen;
             p.coff_max = C_blks_thr * simd_w;
-            p.mean = (use_tmp_stats_ ? sbuf_ : mean) + coff_base;
-            p.var = (use_tmp_stats_ ? sbuf_ + C_PADDED : var) + coff_base;
+            p.mean = (use_tmp_stats(bdesc_) ? sbuf : mean) + coff_base;
+            p.var = (use_tmp_stats(bdesc_) ? sbuf + C_PADDED : var) + coff_base;
             p.scale_shift = scale_shift + coff_base;
-            p.diff_scale_shift
-                    = (use_tmp_diff_scale_shift_ ? pbuf_ : diff_scale_shift)
-                    + coff_base;
+            p.diff_scale_shift = (use_tmp_diff_scale_shift(bdesc_)
+                    ? pbuf : diff_scale_shift) + coff_base;
 
             p.soff_max = N_thr * img_size;
             p.src = src + soff_base;
@@ -1180,10 +1185,8 @@ struct uni_bnorm_driver_t: public c_compatible {
 
             // use SP_N_nthr which is the same as p.N_nthr except maybe for
             // the last iteration.
-            p.rbuf1 = rbuf_
-                    + ((it * C_blks_per_iter) * SP_N_nthr + C_blk_s * p.N_nthr
-                              + p.N_ithr * C_blks_thr)
-                            * simd_w;
+            p.rbuf1 = rbuf + ((it * C_blks_per_iter) * SP_N_nthr
+                    + C_blk_s * p.N_nthr + p.N_ithr * C_blks_thr) * simd_w;
             // rbuf1 and rbuf2 have to be disjoint
             p.rbuf2 = p.rbuf1 + C_PADDED * nthr;
             p.is_cblk_tail =
@@ -1191,89 +1194,193 @@ struct uni_bnorm_driver_t: public c_compatible {
 
             size_t iter_bariers
                     = do_blocking_ ? it * global_barriers_per_iter : 0;
-            p.barrier = barriers_ + C_ithr + iter_bariers;
+            p.barrier = barriers + C_ithr + iter_bariers;
             if (p.soff_max != 0 && p.coff_max != 0)
                 ker_(&p);
         }
     }
 
+    void init_barriers(const memory_tracking::grantor_t &scratchpad) {
+        auto barriers = scratchpad.get<barrier::ctx_t>(key_barrier);
+        if (barriers) {
+            const int n_barriers = get_c_padded(bdesc_) / simd_w;
+            for (int i = 0; i < n_barriers; ++i)
+                barrier::ctx_init(&barriers[i]);
+        }
+    }
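
init_barriers() fetches the barrier area from the scratchpad and initializes one context per channel block once, before the parallel region runs. The real simple_barrier contexts are consumed by JIT-generated code; purely as a conceptual stand-in, a phase-counting barrier in standard C++ might look like:

    #include <atomic>
    #include <cstdio>
    #include <thread>
    #include <vector>

    // Conceptual stand-in for simple_barrier::ctx_t; "initializing" it here is
    // just default construction, loosely analogous to barrier::ctx_init().
    struct ctx_t {
        std::atomic<int> arrived{0};
        std::atomic<int> phase{0};
    };

    void barrier_wait(ctx_t &c, int nthr) {
        int my_phase = c.phase.load();
        if (c.arrived.fetch_add(1) + 1 == nthr) {
            c.arrived.store(0);   // reset before releasing the others
            c.phase.fetch_add(1); // release
        } else {
            while (c.phase.load() == my_phase) std::this_thread::yield();
        }
    }

    int main() {
        const int nthr = 4;
        ctx_t ctx; // set up once, outside the parallel region
        std::vector<std::thread> pool;
        for (int i = 0; i < nthr; ++i)
            pool.emplace_back([&, i] {
                std::printf("thr %d: before sync\n", i);
                barrier_wait(ctx, nthr);
                std::printf("thr %d: after sync\n", i);
            });
        for (auto &t : pool) t.join();
    }
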
+
 private:
-    const int simd_w = isa == sse42 ? 8 :
-        cpu_isa_traits<isa>::vlen / sizeof(data_t);
+    enum {
+        simd_w = isa == sse42 ? 8 : cpu_isa_traits<isa>::vlen / sizeof(data_t)
+    };
+
+    static bool use_tmp_stats(const batch_normalization_pd_t *bdesc) {
+        return true
+            && !bdesc->stats_is_src()
+            && bdesc->desc()->prop_kind == prop_kind::forward_inference;
+    }
+
+    static bool use_tmp_diff_scale_shift(const batch_normalization_pd_t *bdesc)
+    {
+        return false
+            || (bdesc->is_bwd() && !bdesc->use_scaleshift())
+            || bdesc->desc()->prop_kind == prop_kind::backward_data;
+    }
+
+    static int get_c_padded(const batch_normalization_pd_t *bdesc)
+    { return bdesc->src_pd()->desc()->layout_desc.blocking.padding_dims[1]; }
 
     const batch_normalization_pd_t *bdesc_;
-    jit_bnorm_t<isa> ker_;
-    bool use_tmp_stats_, use_tmp_diff_scale_shift_;
     bool do_blocking_;
     size_t l3_size_;
 
-    data_t *buf_, *sbuf_, *rbuf_, *pbuf_;
-
-    barrier::ctx_t *barriers_;
+    jit_bnorm_t<isa> ker_;
 };
 
 }
 
+using namespace data_type;
+using namespace memory_format;
+using namespace utils;
+
+/* fwd */
+
 template <cpu_isa_t isa>
-jit_uni_batch_normalization_fwd_t<isa>::jit_uni_batch_normalization_fwd_t(
-        const pd_t *pd, const input_vector &inputs,
-        const output_vector &outputs)
-    : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd)
-{
-    int is_spatial_thr = 0;
-    const int simd_w = isa == sse42 ? 8 :
-        cpu_isa_traits<isa>::vlen / sizeof(data_t);
+status_t jit_uni_batch_normalization_fwd_t<isa>::pd_t::init() {
+    assert(engine()->kind() == engine_kind::cpu);
+    auto desired_fmt = (ndims() == 4)
+        ? isa == avx512_common ? nChw16c : nChw8c
+        : isa == avx512_common ? nCdhw16c : nCdhw8c;
+
+    bool ok = true
+        && mayiuse(isa)
+        && is_fwd()
+        && !has_zero_dim_memory()
+        && one_of(ndims(), 4, 5)
+        && desc()->data_desc.data_type == f32
+        && IMPLICATION(use_scaleshift(),
+                desc()->data_scaleshift_desc.data_type == f32)
+        && desc()->data_desc.format == desired_fmt
+        && (attr()->has_default_values() || this->with_relu_post_op());
+    if (!ok) return status::unimplemented;
+
+    if (is_training() && fuse_bn_relu()) {
+        if (isa < avx2) return status::unimplemented;
+        bn_init_default_ws(this, this->workspace_pd_, 1);
+    }
 
-    bnorm_utils::set_spatial_thr(&conf_,simd_w,sizeof(data_t),is_spatial_thr);
+    if (memory_desc_wrapper(&data_pd_).blocking_desc().padding_dims[1]
+            != this->C() && isa < avx2)
+        return status::unimplemented;
 
-    bnorm_driver_ = new uni_bnorm_driver_t<isa>(&conf_,is_spatial_thr);
+    if (stats_is_src() || is_training()) {
+        memory_desc_t stats_d;
+        dims_t stats_dims = { C() };
+        mkldnn_memory_desc_init(&stats_d, 1, stats_dims, f32, x);
+        mean_pd_ = cpu_memory_t::pd_t(engine_, &stats_d);
+        variance_pd_ = cpu_memory_t::pd_t(engine_, &stats_d);
+    }
+
+    auto scratchpad = scratchpad_registry().registrar();
+    uni_bnorm_driver_t<isa>::init_scratchpad(scratchpad, this);
+
+    return status::success;
 }
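
pd_t::init() follows the usual mkl-dnn pattern of folding every eligibility requirement into a single ok conjunction, where IMPLICATION(a, b) reads "if a holds then b must hold" (logically !a || b). A tiny standalone illustration of that style:

    #include <cstdio>

    // Same logical meaning as the IMPLICATION(cause, effect) helper used above.
    constexpr bool implication(bool cause, bool effect) { return !cause || effect; }

    int main() {
        bool use_scaleshift = true;
        bool scaleshift_is_f32 = false; // pretend the descriptor says otherwise

        // Mirrors the style of the `ok` chain in pd_t::init(): every
        // requirement is &&-ed together; one failure rejects the implementation.
        bool ok = true
            && implication(use_scaleshift, scaleshift_is_f32)
            && true; /* ...further checks would follow... */

        std::puts(ok ? "supported" : "status::unimplemented");
    }
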
 
 template <cpu_isa_t isa>
-void jit_uni_batch_normalization_fwd_t<isa>::execute(event_t *e) {
+jit_uni_batch_normalization_fwd_t<isa>::jit_uni_batch_normalization_fwd_t(
+        const pd_t *apd, const input_vector &inputs,
+        const output_vector &outputs)
+    : cpu_primitive_t(apd, inputs, outputs)
+{ bnorm_driver_ = new uni_bnorm_driver_t<isa>(pd()); }
+
+template <cpu_isa_t isa>
+void jit_uni_batch_normalization_fwd_t<isa>::execute(event_t *e) const {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto dst = reinterpret_cast<data_t*>(this->memory(0));
-    auto mean = reinterpret_cast<data_t*>(conf_.stats_is_src()
+    auto mean = reinterpret_cast<data_t*>(pd()->stats_is_src()
             ? const_cast<char*>(this->input_memory(1))
             : this->memory(1));
-    auto var = reinterpret_cast<data_t*>(conf_.stats_is_src()
+    auto var = reinterpret_cast<data_t*>(pd()->stats_is_src()
             ? const_cast<char*>(this->input_memory(2))
             : this->memory(2));
 
-    auto idx_scale_shift = 1 + 2*conf_.stats_is_src();
+    auto idx_scale_shift = 1 + 2*pd()->stats_is_src();
     auto scale_shift =
         reinterpret_cast<const data_t *>(this->input_memory(idx_scale_shift));
-    auto ws = reinterpret_cast<uint8_t *>(this->memory(conf_.ws_idx()));
+    auto ws = reinterpret_cast<uint8_t *>(this->memory(pd()->ws_idx()));
+
+    auto scratchpad = this->scratchpad();
+
+    bnorm_driver_->init_barriers(scratchpad);
 
     parallel(0, [&](const int ithr, const int nthr) {
-        bnorm_driver_->exec(ithr, nthr, src,
-                nullptr, dst, nullptr, scale_shift, nullptr, mean, var, ws);
+        bnorm_driver_->exec(ithr, nthr, src, nullptr, dst, nullptr,
+                scale_shift, nullptr, mean, var, ws, scratchpad);
     });
     e->set_state(event_t::ready);
 }
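
execute() launches the driver inside parallel(0, ...), which hands each worker its thread index and the thread count (0 here is taken to mean "use all available threads"); exec() then carves the batch and channel blocks up from (ithr, nthr). A standalone stand-in using std::thread, with the splitting shown only schematically:

    #include <cstdio>
    #include <thread>
    #include <vector>

    // Standalone stand-in for parallel(0, f): run f(ithr, nthr) on every thread.
    // The real helper dispatches to OpenMP/TBB rather than raw std::thread.
    template <typename F> void run_parallel(int nthr, F f) {
        if (nthr <= 0) nthr = (int)std::thread::hardware_concurrency();
        if (nthr == 0) nthr = 1;
        std::vector<std::thread> pool;
        for (int ithr = 0; ithr < nthr; ++ithr)
            pool.emplace_back([=] { f(ithr, nthr); });
        for (auto &t : pool) t.join();
    }

    int main() {
        const int n_blocks = 103; // e.g. channel blocks to distribute
        run_parallel(0, [&](int ithr, int nthr) {
            // Even split with the remainder spread over the first threads,
            // similar in spirit to the balancing exec() performs internally.
            int chunk = n_blocks / nthr, rem = n_blocks % nthr;
            int start = ithr * chunk + (ithr < rem ? ithr : rem);
            int end = start + chunk + (ithr < rem ? 1 : 0);
            std::printf("thr %d/%d -> blocks [%d, %d)\n", ithr, nthr, start, end);
        });
    }
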
 
 template <cpu_isa_t isa>
-jit_uni_batch_normalization_fwd_t<isa>::~jit_uni_batch_normalization_fwd_t() {
-    delete bnorm_driver_;
-}
+jit_uni_batch_normalization_fwd_t<isa>::~jit_uni_batch_normalization_fwd_t()
+{ delete bnorm_driver_; }
+
+/* bwd */
 
 template <cpu_isa_t isa>
-jit_uni_batch_normalization_bwd_t<isa>::jit_uni_batch_normalization_bwd_t(
-        const pd_t *pd, const input_vector &inputs,
-        const output_vector &outputs)
-    : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd)
-{
-    int is_spatial_thr = 0;
-    const int simd_w = isa == sse42 ? 8 :
-        cpu_isa_traits<isa>::vlen / sizeof(data_t);
+status_t jit_uni_batch_normalization_bwd_t<isa>::pd_t::init() {
+    assert(engine()->kind() == engine_kind::cpu);
+    auto desired_fmt = (ndims() == 4)
+        ? one_of(isa, sse42, avx2) ? nChw8c : nChw16c
+        : one_of(isa, sse42, avx2) ? nCdhw8c : nCdhw16c;
+
+    bool ok = true
+        && mayiuse(isa)
+        && is_bwd()
+        && !has_zero_dim_memory()
+        && one_of(ndims(), 4, 5)
+        && everyone_is(f32, desc()->data_desc.data_type,
+                desc()->diff_data_desc.data_type)
+        && IMPLICATION(use_scaleshift(),
+                desc()->data_scaleshift_desc.data_type == f32)
+        && everyone_is(desired_fmt, desc()->diff_data_desc.format,
+                desc()->data_desc.format)
+        && attr()->has_default_values();
+    if (!ok) return status::unimplemented;
+
+    if (memory_desc_wrapper(&data_pd_).blocking_desc()
+            .padding_dims[1] != this->C() && isa < avx2)
+        return status::unimplemented;
+
+    if (fuse_bn_relu()) {
+        if (isa < avx2) return status::unimplemented;
+        bn_init_default_ws(this, this->workspace_pd_, 1);
+        size_t this_ws_sz = memory_desc_wrapper(this->workspace_pd()).size();
+
+        bool ws_ok = true
+            && hint_fwd_pd_->workspace_pd()
+            && memory_desc_wrapper(hint_fwd_pd_->workspace_pd()).size()
+            == this_ws_sz;
+        if (!ws_ok) return status::unimplemented;
+    }
+
+    /* TODO: extra checks required */
 
-    bnorm_utils::set_spatial_thr(&conf_,simd_w,sizeof(data_t),is_spatial_thr);
+    auto scratchpad = scratchpad_registry().registrar();
+    uni_bnorm_driver_t<isa>::init_scratchpad(scratchpad, this);
 
-    bnorm_driver_ = new uni_bnorm_driver_t<isa>(&conf_,is_spatial_thr);
+    return status::success;
 }
 
 template <cpu_isa_t isa>
-void jit_uni_batch_normalization_bwd_t<isa>::execute(event_t *e) {
+jit_uni_batch_normalization_bwd_t<isa>::jit_uni_batch_normalization_bwd_t(
+        const pd_t *apd, const input_vector &inputs,
+        const output_vector &outputs)
+    : cpu_primitive_t(apd, inputs, outputs)
+{ bnorm_driver_ = new uni_bnorm_driver_t<isa>(pd()); }
+
+template <cpu_isa_t isa>
+void jit_uni_batch_normalization_bwd_t<isa>::execute(event_t *e) const {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto mean = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto var = reinterpret_cast<const data_t *>(this->input_memory(2));
@@ -1282,20 +1389,22 @@ void jit_uni_batch_normalization_bwd_t<isa>::execute(event_t *e) {
     auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
     auto diff_scale_shift = reinterpret_cast<data_t *>(this->memory(1));
     auto ws = reinterpret_cast<const uint8_t *>(
-            this->input_memory(conf_.ws_idx()));
+            this->input_memory(pd()->ws_idx()));
+
+    auto scratchpad = this->scratchpad();
+
+    bnorm_driver_->init_barriers(scratchpad);
 
     parallel(0, [&](const int ithr, const int nthr) {
-        bnorm_driver_->exec(ithr, nthr, src,
-                diff_src, nullptr, diff_dst, scale_shift, diff_scale_shift,
-                mean, var, ws);
+        bnorm_driver_->exec(ithr, nthr, src, diff_src, nullptr, diff_dst,
+                scale_shift, diff_scale_shift, mean, var, ws, scratchpad);
     });
     e->set_state(event_t::ready);
 }
 
 template <cpu_isa_t isa>
-jit_uni_batch_normalization_bwd_t<isa>::~jit_uni_batch_normalization_bwd_t() {
-    delete bnorm_driver_;
-}
+jit_uni_batch_normalization_bwd_t<isa>::~jit_uni_batch_normalization_bwd_t()
+{ delete bnorm_driver_; }
 
 /* struct instantiation */
 template struct jit_uni_batch_normalization_fwd_t<sse42>;