Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_lrn.cpp
index 38b81dd..de9a1d9 100644 (file)
@@ -47,7 +47,7 @@ static inline float fast_negative_powf(float omega, float beta) {
 
 template <impl::data_type_t data_type>
 template <mkldnn_memory_format_t fmt>
-void ref_lrn_fwd_t<data_type>::execute_forward() {
+void ref_lrn_fwd_t<data_type>::execute_forward() const {
     using namespace alg_kind;
     using namespace memory_format;
 
@@ -55,15 +55,15 @@ void ref_lrn_fwd_t<data_type>::execute_forward() {
     auto dst = reinterpret_cast<data_t*>(this->memory(0));
     auto ws = reinterpret_cast<data_t*>(this->memory(1));
 
-    const memory_desc_wrapper data_d(conf_.src_pd());
-    const memory_desc_wrapper ws_d(conf_.workspace_pd());
+    const memory_desc_wrapper data_d(pd()->src_pd());
+    const memory_desc_wrapper ws_d(pd()->workspace_pd());
     MAYBE_UNUSED(ws_d);
 
-    const int C = conf_.C();
-    const int H = conf_.H();
-    const int W = conf_.W();
+    const int C = pd()->C();
+    const int H = pd()->H();
+    const int W = pd()->W();
     const size_t stride_mb = data_d.blocking_desc().strides[0][0];
-    const bool across_channels = conf_.desc()->alg_kind == lrn_across_channels;
+    const bool across_channels = pd()->desc()->alg_kind == lrn_across_channels;
     constexpr int blksize = fmt == nChw16c ? 16 : 8;
 
     auto data_off = [&](int mb, int c, int h, int w) -> size_t {
@@ -78,11 +78,11 @@ void ref_lrn_fwd_t<data_type>::execute_forward() {
     };
 
     auto ker = [=](data_t *d, int mb, int oc, int oh, int ow) {
-        const float alpha = static_cast<float>(conf_.desc()->lrn_alpha);
-        const float beta = static_cast<float>(conf_.desc()->lrn_beta);
-        const float k = static_cast<float>(conf_.desc()->lrn_k);
+        const float alpha = static_cast<float>(pd()->desc()->lrn_alpha);
+        const float beta = static_cast<float>(pd()->desc()->lrn_beta);
+        const float k = static_cast<float>(pd()->desc()->lrn_k);
 
-        const int size = conf_.desc()->local_size;
+        const int size = pd()->desc()->local_size;
         const int half_size = (size - 1) / 2;
 
         float sum = 0;
@@ -114,7 +114,7 @@ void ref_lrn_fwd_t<data_type>::execute_forward() {
         d[0] = static_cast<data_t>(src[off] * fast_negative_powf(sum, beta));
     };
 
-    const int MB = conf_.MB();
+    const int MB = pd()->MB();
     if (fmt == nChw16c || fmt == nChw8c) {
         parallel_nd(MB, utils::div_up(C, blksize), H, W,
             [&](int mb, int c_blk, int h, int w) {
@@ -142,7 +142,7 @@ void ref_lrn_fwd_t<data_type>::execute_forward() {
 
 template <impl::data_type_t data_type>
 template <mkldnn_memory_format_t fmt>
-void ref_lrn_bwd_t<data_type>::execute_backward() {
+void ref_lrn_bwd_t<data_type>::execute_backward() const {
     using namespace alg_kind;
     using namespace memory_format;
 
@@ -150,21 +150,21 @@ void ref_lrn_bwd_t<data_type>::execute_backward() {
     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
 
-    const memory_desc_wrapper data_d(conf_.src_pd());
-    const memory_desc_wrapper diff_data_d(conf_.diff_dst_pd());
+    const memory_desc_wrapper data_d(pd()->src_pd());
+    const memory_desc_wrapper diff_data_d(pd()->diff_dst_pd());
     MAYBE_UNUSED(diff_data_d);
 
-    const int MB = conf_.MB();
-    const int C = conf_.C();
-    const int H = conf_.H();
-    const int W = conf_.W();
+    const int MB = pd()->MB();
+    const int C = pd()->C();
+    const int H = pd()->H();
+    const int W = pd()->W();
     const size_t stride_mb = data_d.blocking_desc().strides[0][0];
     constexpr int blksize = fmt == nChw16c ? 16 : 8;
 
-    const float alpha = static_cast<float>(conf_.desc()->lrn_alpha);
-    const float beta = static_cast<float>(conf_.desc()->lrn_beta);
-    const float k = static_cast<float>(conf_.desc()->lrn_k);
-    const int kernel_size = conf_.desc()->local_size;
+    const float alpha = static_cast<float>(pd()->desc()->lrn_alpha);
+    const float beta = static_cast<float>(pd()->desc()->lrn_beta);
+    const float k = static_cast<float>(pd()->desc()->lrn_k);
+    const int kernel_size = pd()->desc()->local_size;
     const int half_ksize = (kernel_size - 1) / 2;
 
     auto data_off = [&](int mb, int c, int h, int w) -> size_t {
@@ -231,16 +231,16 @@ void ref_lrn_bwd_t<data_type>::execute_backward() {
     }
 }
 
-template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::nChw16c>();
-template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::nChw8c>();
-template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::nchw>();
-template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::nhwc>();
-template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::any>();
-template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::nChw16c>();
-template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::nChw8c>();
-template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::nchw>();
-template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::nhwc>();
-template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::any>();
+template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::nChw16c>() const;
+template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::nChw8c>() const;
+template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::nchw>() const;
+template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::nhwc>() const;
+template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::any>() const;
+template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::nChw16c>() const;
+template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::nChw8c>() const;
+template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::nchw>() const;
+template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::nhwc>() const;
+template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::any>() const;
 
 }
 }