Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ncsp_batch_normalization.cpp
index d755538..66523a6 100644 (file)
 #include <assert.h>
 #include <math.h>
 
-#include "cpu_batch_normalization_utils.hpp"
 #include "c_types_map.hpp"
+#include "type_helpers.hpp"
+
+#include "cpu_batch_normalization_utils.hpp"
 #include "jit_generator.hpp"
+
 #include "ncsp_batch_normalization.hpp"
-#include "type_helpers.hpp"
 
 // clang 6 and 7 generate incorrect code with OMP_SIMD in some particular cases
 #if (defined __clang_major__) && (__clang_major__ >= 6)
@@ -34,38 +36,17 @@ namespace mkldnn {
 namespace impl {
 namespace cpu {
 
-typedef float data_t;
-ncsp_batch_normalization_fwd_t::ncsp_batch_normalization_fwd_t(const pd_t *pd,
-        const input_vector &inputs, const output_vector &outputs)
-    : cpu_primitive_t(&conf_, inputs, outputs), stats_reduction_(nullptr),
-    tmp_mean_(nullptr), tmp_variance_(nullptr), conf_(*pd) {
-    if (!conf_.stats_is_src()) {
-        this->stats_reduction_ = (data_t *)malloc(
-                conf_.C() * mkldnn_get_max_threads() * sizeof(data_t), 64);
-        if (!conf_.is_training()) {
-            this->tmp_mean_ = (data_t *)malloc(conf_.C() * sizeof(data_t), 64);
-            this->tmp_variance_
-                    = (data_t *)malloc(conf_.C() * sizeof(data_t), 64);
-        }
-    }
-}
-ncsp_batch_normalization_fwd_t::~ncsp_batch_normalization_fwd_t() {
-    if (!conf_.stats_is_src()) {
-        free(this->stats_reduction_);
-        if (!conf_.is_training()) {
-            free(this->tmp_mean_);
-            free(this->tmp_variance_);
-        }
-    }
-}
+using namespace memory_tracking::names;
 
-void ncsp_batch_normalization_fwd_t::execute_forward() {
+void ncsp_batch_normalization_fwd_t::execute_forward() const {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto dst = reinterpret_cast<data_t *>(this->memory(0));
-    const bool calculate_stats = !conf_.stats_is_src();
-    const bool save_stats = conf_.is_training();
-    const bool is_training = conf_.is_training();
-    const bool fuse_bn_relu = conf_.fuse_bn_relu();
+    auto scratchpad = this->scratchpad();
+
+    const bool calculate_stats = !pd()->stats_is_src();
+    const bool save_stats = pd()->is_training();
+    const bool is_training = pd()->is_training();
+    const bool fuse_bn_relu = pd()->fuse_bn_relu();
 
     data_t *mean, *variance;
     if (!calculate_stats) {
@@ -78,25 +59,25 @@ void ncsp_batch_normalization_fwd_t::execute_forward() {
             mean = reinterpret_cast<data_t *>(this->memory(1));
             variance = reinterpret_cast<data_t *>(this->memory(2));
         } else {
-            mean = this->tmp_mean_;
-            variance = this->tmp_variance_;
+            mean = scratchpad.get<data_t>(key_bnorm_tmp_mean);
+            variance = scratchpad.get<data_t>(key_bnorm_tmp_var);
         }
     }
-    auto idx_scale_shift = 1 + 2 * conf_.stats_is_src();
+    auto idx_scale_shift = 1 + 2 * pd()->stats_is_src();
     auto scaleshift = reinterpret_cast<const data_t *>(
             this->input_memory(idx_scale_shift));
-    auto ws = reinterpret_cast<uint8_t *>(this->memory(conf_.ws_idx()));
-    data_t *ws_reduce = this->stats_reduction_;
+    auto ws = reinterpret_cast<uint8_t *>(this->memory(pd()->ws_idx()));
+    auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction);
 
-    const float eps = conf_.desc()->batch_norm_epsilon;
-    const bool use_scaleshift = conf_.use_scaleshift();
-    const bool with_relu = conf_.with_relu_post_op();
+    const float eps = pd()->desc()->batch_norm_epsilon;
+    const bool use_scaleshift = pd()->use_scaleshift();
+    const bool with_relu = pd()->with_relu_post_op();
     auto maybe_post_op
             = [&](data_t res) { return (with_relu && res < 0) ? 0 : res; };
-    const bool has_spatial = utils::one_of(conf_.ndims(), 4, 5);
-    int SP = (has_spatial) ? conf_.H() * conf_.W() * conf_.D() : 1;
-    size_t N = conf_.MB();
-    size_t C = conf_.C();
+    const bool has_spatial = utils::one_of(pd()->ndims(), 4, 5);
+    int SP = (has_spatial) ? pd()->H() * pd()->W() * pd()->D() : 1;
+    size_t N = pd()->MB();
+    size_t C = pd()->C();
 
     int nthr = mkldnn_get_max_threads();
     size_t l3_size_ = get_cache_size(3, true) * nthr / 2;
@@ -232,44 +213,30 @@ void ncsp_batch_normalization_fwd_t::execute_forward() {
     });
 }
 
-ncsp_batch_normalization_bwd_t::ncsp_batch_normalization_bwd_t(const pd_t *pd,
-        const input_vector &inputs, const output_vector &outputs)
-    : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd)
-    , stats_reduction_(nullptr), tmp_diff_scaleshift_(nullptr) {
-    this->stats_reduction_ = (data_t *)malloc(
-            conf_.C() * 2 * mkldnn_get_max_threads() * sizeof(data_t), 64);
-    if (!(conf_.use_scaleshift()
-                && conf_.desc()->prop_kind == prop_kind::backward))
-        this->tmp_diff_scaleshift_
-                = (data_t *)malloc(conf_.C() * 2 * sizeof(data_t), 64);
-}
-
-ncsp_batch_normalization_bwd_t::~ncsp_batch_normalization_bwd_t() {
-    free(this->stats_reduction_);
-    free(this->tmp_diff_scaleshift_);
-}
-
-void ncsp_batch_normalization_bwd_t::execute_backward() {
+void ncsp_batch_normalization_bwd_t::execute_backward() const {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto mean = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto variance = reinterpret_cast<const data_t *>(this->input_memory(2));
     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(3));
     auto scaleshift = reinterpret_cast<const data_t *>(this->input_memory(4));
     auto diff_src = reinterpret_cast<data_t *>(this->memory(0));
-    auto diff_scaleshift = (this->memory(1)) ?
-            reinterpret_cast<data_t *>(this->memory(1)) :
-            this->tmp_diff_scaleshift_;
+
+    auto scratchpad = this->scratchpad();
+
+    auto diff_scaleshift = this->memory(1)
+        ? reinterpret_cast<data_t *>(this->memory(1))
+        : scratchpad.get<data_t>(key_bnorm_tmp_diff_ss);
     auto ws = reinterpret_cast<const uint8_t *>(
-            this->input_memory(conf_.ws_idx()));
-    data_t *ws_reduce = this->stats_reduction_;
-
-    const bool has_spatial = utils::one_of(conf_.ndims(), 4, 5);
-    int SP = (has_spatial) ? conf_.H() * conf_.W() * conf_.D() : 1;
-    size_t C = conf_.C(), N = conf_.MB();
-    const bool use_scaleshift = conf_.use_scaleshift();
-    const float eps = conf_.desc()->batch_norm_epsilon;
-    const bool calculate_diff_stats = !conf_.omit_stats();
-    const bool fuse_bn_relu = conf_.fuse_bn_relu();
+            this->input_memory(pd()->ws_idx()));
+    auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction);
+
+    const bool has_spatial = utils::one_of(pd()->ndims(), 4, 5);
+    int SP = (has_spatial) ? pd()->H() * pd()->W() * pd()->D() : 1;
+    size_t C = pd()->C(), N = pd()->MB();
+    const bool use_scaleshift = pd()->use_scaleshift();
+    const float eps = pd()->desc()->batch_norm_epsilon;
+    const bool calculate_diff_stats = !pd()->use_global_stats();
+    const bool fuse_bn_relu = pd()->fuse_bn_relu();
 
     int nthr = mkldnn_get_max_threads();
     size_t l3_size_ = get_cache_size(3, true) * nthr / 2;