Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / nspc_batch_normalization.hpp
index 168caf9..6c1ec25 100644 (file)
 #include <assert.h>
 
 #include "c_types_map.hpp"
-#include "cpu_batch_normalization_pd.hpp"
-#include "cpu_engine.hpp"
+#include "memory_tracking.hpp"
 #include "type_helpers.hpp"
 #include "utils.hpp"
 
+#include "cpu_batch_normalization_pd.hpp"
+
 namespace mkldnn {
 namespace impl {
 namespace cpu {
@@ -40,9 +41,11 @@ struct nspc_batch_normalization_fwd_t : public cpu_primitive_t {
         DECLARE_COMMON_PD_T("nspc_bnorm:any", nspc_batch_normalization_fwd_t);
 
         virtual status_t init() override {
-            using namespace prop_kind;
             using namespace data_type;
+            using namespace prop_kind;
+
             assert(engine()->kind() == engine_kind::cpu);
+
             bool ok = true
                 /* the algorithm requires barriers while switching
                  * between parallelization over N and C dimensions */
@@ -54,8 +57,7 @@ struct nspc_batch_normalization_fwd_t : public cpu_primitive_t {
                         desc()->data_scaleshift_desc.data_type == f32)
                 && utils::one_of(data_pd_.desc()->format, memory_format::nhwc)
                 && (attr()->has_default_values() || this->with_relu_post_op());
-            if (!ok)
-                return status::unimplemented;
+            if (!ok) return status::unimplemented;
 
             if (is_training() && fuse_bn_relu())
                 bn_init_default_ws(this, this->workspace_pd_, 8);
@@ -63,31 +65,45 @@ struct nspc_batch_normalization_fwd_t : public cpu_primitive_t {
             if (stats_is_src() || is_training()) {
                 memory_desc_t stats_d;
                 dims_t stats_dims = { C() };
-                mkldnn_memory_desc_init(&stats_d, 1, stats_dims, data_type::f32,
-                        memory_format::x);
+                mkldnn_memory_desc_init(&stats_d, 1, stats_dims,
+                        data_type::f32, memory_format::x);
                 mean_pd_ = cpu_memory_t::pd_t(engine_, &stats_d);
                 variance_pd_ = cpu_memory_t::pd_t(engine_, &stats_d);
             }
 
+            init_scratchpad();
+
             return status::success;
         }
+
+    private:
+        void init_scratchpad() { // registers the temporary buffers of the forward primitive with the scratchpad registry
+            using namespace memory_tracking::names; // presumably where the key_bnorm_* identifiers used below are declared
+            auto scratchpad = scratchpad_registry().registrar();
+            if (!stats_is_src()) { // nothing is booked when stats_is_src() is true
+                int sz = nstl::max(C(), 16) * mkldnn_get_max_threads(); // element count: max(C(), 16) times mkldnn_get_max_threads(); NOTE(review): computed in int
+                scratchpad.book(key_bnorm_reduction, sizeof(data_t) * sz); // three buffers, each booked with the same size sizeof(data_t) * sz
+                scratchpad.book(key_bnorm_tmp_mean, sizeof(data_t) * sz);
+                scratchpad.book(key_bnorm_tmp_var, sizeof(data_t) * sz);
+            }
+        }
     };
 
     typedef typename prec_traits<data_type::f32>::type data_t;
 
-    nspc_batch_normalization_fwd_t(const pd_t *pd, const input_vector &inputs,
-            const output_vector &outputs);
-    ~nspc_batch_normalization_fwd_t();
-    virtual void execute(event_t *e) {
+    nspc_batch_normalization_fwd_t(const pd_t *apd, const input_vector &inputs, // constructor: only forwards its arguments to cpu_primitive_t
+            const output_vector &outputs)
+        : cpu_primitive_t(apd, inputs, outputs) {}
+    ~nspc_batch_normalization_fwd_t() {} // empty destructor: this class declares no members of its own to release
+
+    virtual void execute(event_t *e) const { // calls execute_forward(), then sets the state of event e to event_t::ready
         execute_forward();
         e->set_state(event_t::ready);
     }
 
 private:
-    data_t *stats_reduction_;
-    data_t *tmp_mean_, *tmp_variance_;
-    void execute_forward();
-    pd_t conf_;
+    void execute_forward() const;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } // returns primitive_t::pd() cast to this primitive's own pd_t type
 };
 
 struct nspc_batch_normalization_bwd_t : public cpu_primitive_t {
@@ -101,9 +117,11 @@ struct nspc_batch_normalization_bwd_t : public cpu_primitive_t {
         DECLARE_COMMON_PD_T("nspc_bnorm:any", nspc_batch_normalization_bwd_t);
 
         virtual status_t init() override {
-            using namespace prop_kind;
             using namespace data_type;
+            using namespace prop_kind;
+
             assert(engine()->kind() == engine_kind::cpu);
+
             bool ok = true
                 /* the algorithm requires barriers while switching
                  * between parallelization over N and C dimensions */
@@ -115,42 +133,53 @@ struct nspc_batch_normalization_bwd_t : public cpu_primitive_t {
                         desc()->data_scaleshift_desc.data_type == f32)
                 && utils::one_of(data_pd_.desc()->format, memory_format::nhwc)
                 && (attr()->has_default_values() || this->with_relu_post_op());
-            if (!ok)
-                return status::unimplemented;
+            if (!ok) return status::unimplemented;
 
             if (fuse_bn_relu()) {
                 bn_init_default_ws(this, this->workspace_pd_, 8);
                 const size_t this_ws_sz
-                        = memory_desc_wrapper(this->workspace_pd()).size();
-
-                bool ws_ok = true && hint_fwd_pd_->workspace_pd()
-                        && memory_desc_wrapper(hint_fwd_pd_->workspace_pd())
-                                        .size()
-                                == this_ws_sz;
-                if (!ws_ok)
-                    return status::unimplemented;
+                    = memory_desc_wrapper(this->workspace_pd()).size();
+
+                bool ws_ok = true
+                    && hint_fwd_pd_->workspace_pd()
+                    && memory_desc_wrapper(hint_fwd_pd_->workspace_pd()).size()
+                    == this_ws_sz;
+                if (!ws_ok) return status::unimplemented;
             }
 
+            init_scratchpad();
+
             return status::success;
         }
+
+    private:
+        void init_scratchpad() { // registers the temporary buffers of the backward primitive with the scratchpad registry (unconditionally)
+            using namespace memory_tracking::names; // presumably where the key_bnorm_* identifiers used below are declared
+            auto scratchpad = scratchpad_registry().registrar();
+            scratchpad.book(key_bnorm_reduction, // size: sizeof(data_t) * 2 * C() * mkldnn_get_max_threads()
+                    sizeof(data_t) * 2 * C() * mkldnn_get_max_threads());
+            scratchpad.book(key_bnorm_tmp_diff_ss, sizeof(data_t) * 2 * C() // same 2 * C() factor, but multiplied by (mkldnn_get_max_threads() + 1)
+                    * (mkldnn_get_max_threads() + 1));
+        }
     };
 
     typedef typename prec_traits<data_type::f32>::type data_t;
 
-    nspc_batch_normalization_bwd_t(const pd_t *pd, const input_vector &inputs,
-            const output_vector &outputs);
-    ~nspc_batch_normalization_bwd_t();
-    virtual void execute(event_t *e) {
+    nspc_batch_normalization_bwd_t(const pd_t *apd, const input_vector &inputs, // constructor: only forwards its arguments to cpu_primitive_t
+            const output_vector &outputs)
+        : cpu_primitive_t(apd, inputs, outputs) {}
+    ~nspc_batch_normalization_bwd_t() {} // empty destructor: this class declares no members of its own to release
+
+    virtual void execute(event_t *e) const { // calls execute_backward(), then sets the state of event e to event_t::ready
         execute_backward();
         e->set_state(event_t::ready);
     }
 
 private:
-    data_t *stats_reduction_;
-    data_t *tmp_diff_scaleshift_;
-    void execute_backward();
-    pd_t conf_;
+    void execute_backward() const;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } // returns primitive_t::pd() cast to this primitive's own pd_t type
 };
+
 }
 }
 }