Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_batch_normalization.cpp
index 65570f1..f009d85 100644 (file)
@@ -28,51 +28,51 @@ namespace impl {
 namespace cpu {
 
 template <impl::data_type_t data_type>
-void ref_batch_normalization_fwd_t<data_type>::execute_forward() {
+void ref_batch_normalization_fwd_t<data_type>::execute_forward() const {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     /* FIXME: check this */
-    data_t* mean = conf_.stats_is_src() ?
+    data_t* mean = pd()->stats_is_src() ?
         const_cast<data_t*>(reinterpret_cast<const data_t*>(
                this->input_memory(1))) :
         reinterpret_cast<data_t*>(this->memory(1));
 
-    data_t* variance = conf_.stats_is_src() ?
+    data_t* variance = pd()->stats_is_src() ?
         const_cast<data_t*>(reinterpret_cast<const data_t*>(
                 this->input_memory(2))) :
         reinterpret_cast<data_t*>(this->memory(2));
 
-    auto idx_scaleshift = 1 + 2*conf_.stats_is_src();
+    auto idx_scaleshift = 1 + 2*pd()->stats_is_src();
     auto scaleshift =
         reinterpret_cast<const data_t *>(this->input_memory(idx_scaleshift));
 
     auto dst = reinterpret_cast<data_t*>(this->memory(0));
-    auto ws = reinterpret_cast<uint8_t *>(this->memory(conf_.ws_idx()));
+    auto ws = reinterpret_cast<uint8_t *>(this->memory(pd()->ws_idx()));
 
     /* fast return */
-    if (this->conf_.has_zero_dim_memory()) return;
+    if (this->pd()->has_zero_dim_memory()) return;
 
-    const memory_desc_wrapper data_d(conf_.src_pd());
-    const memory_desc_wrapper scaleshift_d(conf_.weights_pd());
+    const memory_desc_wrapper data_d(pd()->src_pd());
+    const memory_desc_wrapper scaleshift_d(pd()->weights_pd());
 
-    const int N = conf_.MB();
-    const int C = conf_.C();
+    const int N = pd()->MB();
+    const int C = pd()->C();
     int H = 1, W = 1, D = 1;
     const bool has_spatial = utils::one_of(data_d.ndims(), 4 ,5);
     if (has_spatial)
     {
-        D = conf_.D();
-        H = conf_.H();
-        W = conf_.W();
+        D = pd()->D();
+        H = pd()->H();
+        W = pd()->W();
     }
 
-    const float eps = conf_.desc()->batch_norm_epsilon;
-    const bool use_scaleshift = conf_.use_scaleshift();;
-    const bool save_stats = conf_.is_training();
-    const bool is_training = conf_.is_training();
-    const bool fuse_bn_relu = conf_.fuse_bn_relu();
-    const bool calculate_stats = !conf_.stats_is_src();
+    const float eps = pd()->desc()->batch_norm_epsilon;
+    const bool use_scaleshift = pd()->use_scaleshift();;
+    const bool save_stats = pd()->is_training();
+    const bool is_training = pd()->is_training();
+    const bool fuse_bn_relu = pd()->fuse_bn_relu();
+    const bool calculate_stats = !pd()->stats_is_src();
 
-    const bool with_relu = conf_.with_relu_post_op();
+    const bool with_relu = pd()->with_relu_post_op();
     auto maybe_post_op = [&](data_t res) {
         return (with_relu && res < 0) ? 0 : res;
     };
@@ -146,29 +146,29 @@ void ref_batch_normalization_fwd_t<data_type>::execute_forward() {
 template struct ref_batch_normalization_fwd_t<data_type::f32>;
 
 template <impl::data_type_t data_type>
-void ref_batch_normalization_bwd_t<data_type>::execute_backward() {
+void ref_batch_normalization_bwd_t<data_type>::execute_backward() const {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto mean = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto variance = reinterpret_cast<const data_t *>(this->input_memory(2));
     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(3));
     auto scaleshift = reinterpret_cast<const data_t *>(this->input_memory(4));
     auto ws = reinterpret_cast<const uint8_t *>(
-            this->input_memory(conf_.ws_idx()));
+            this->input_memory(pd()->ws_idx()));
 
     auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
     auto diff_scaleshift = reinterpret_cast<data_t *>(this->memory(1));
 
-    const memory_desc_wrapper data_d(conf_.src_pd());
-    const memory_desc_wrapper diff_data_d(conf_.diff_src_pd());
-    const memory_desc_wrapper scaleshift_d(conf_.weights_pd());
-    const memory_desc_wrapper diff_scaleshift_d(conf_.diff_weights_pd());
-    const memory_desc_wrapper mean_d(conf_.mean_pd());
-    const memory_desc_wrapper variance_d(conf_.variance_pd());
+    const memory_desc_wrapper data_d(pd()->src_pd());
+    const memory_desc_wrapper diff_data_d(pd()->diff_src_pd());
+    const memory_desc_wrapper scaleshift_d(pd()->weights_pd());
+    const memory_desc_wrapper diff_scaleshift_d(pd()->diff_weights_pd());
+    const memory_desc_wrapper mean_d(pd()->mean_pd());
+    const memory_desc_wrapper variance_d(pd()->variance_pd());
 
-    const int C = conf_.C();
+    const int C = pd()->C();
 
     /* fast return */
-    if (this->conf_.has_zero_dim_memory()) {
+    if (this->pd()->has_zero_dim_memory()) {
         if (diff_scaleshift) {
             for (int c = 0; c < C; ++c) {
                 diff_scaleshift[diff_scaleshift_d.off(0, c)] = 0;
@@ -178,20 +178,20 @@ void ref_batch_normalization_bwd_t<data_type>::execute_backward() {
         return;
     }
 
-    const int N = conf_.MB();
+    const int N = pd()->MB();
     int H = 1, W = 1, D = 1;
     const bool has_spatial = utils::one_of(data_d.ndims(), 4 ,5);
     if (has_spatial)
     {
-        D = conf_.D();
-        H = conf_.H();
-        W = conf_.W();
+        D = pd()->D();
+        H = pd()->H();
+        W = pd()->W();
     }
 
-    const float eps = conf_.desc()->batch_norm_epsilon;
-    const bool use_scaleshift = conf_.use_scaleshift();
-    const bool calculate_diff_stats = !conf_.omit_stats();
-    const bool fuse_bn_relu = conf_.fuse_bn_relu();
+    const float eps = pd()->desc()->batch_norm_epsilon;
+    const bool use_scaleshift = pd()->use_scaleshift();
+    const bool calculate_diff_stats = !pd()->use_global_stats();
+    const bool fuse_bn_relu = pd()->fuse_bn_relu();
 
     const bool is_3d = data_d.ndims() == 5;