Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_lrn.cpp
index f774d44..00bea07 100644 (file)
@@ -26,23 +26,23 @@ namespace cpu {
 
 template <cpu_isa_t isa>
 jit_uni_lrn_fwd_t<isa>::jit_uni_lrn_fwd_t(
-    const pd_t *pd,
+    const pd_t *apd,
     const input_vector &inputs, const output_vector &outputs)
-    : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), ker_(nullptr)
+    : cpu_primitive_t(apd, inputs, outputs), ker_(nullptr)
     , ker_first_(nullptr), ker_last_(nullptr)
 {
     using namespace alg_kind;
 
-    const int C = conf_.C();
-    const int H = conf_.H();
-    const int W = conf_.W();
-    const int ls = conf_.desc()->local_size;
-    float A = conf_.desc()->lrn_alpha / ls;
-    float K = conf_.desc()->lrn_k;
+    const int C = pd()->C();
+    const int H = pd()->H();
+    const int W = pd()->W();
+    const int ls = pd()->desc()->local_size;
+    float A = pd()->desc()->lrn_alpha / ls;
+    float K = pd()->desc()->lrn_k;
 
-    auto pk = conf_.desc()->prop_kind;
-    auto ak = conf_.desc()->alg_kind;
-    auto dfmt = conf_.src_pd()->desc()->format;
+    auto pk = pd()->desc()->prop_kind;
+    auto ak = pd()->desc()->alg_kind;
+    auto dfmt = pd()->src_pd()->desc()->format;
 
     if (dfmt == nChw8c && ls == 5 && ak == lrn_across_channels) {
         ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
@@ -74,20 +74,20 @@ jit_uni_lrn_fwd_t<isa>::~jit_uni_lrn_fwd_t()
 { delete ker_; delete ker_first_; delete ker_last_; }
 
 template <cpu_isa_t isa>
-void jit_uni_lrn_fwd_t<isa>::execute_forward() {
+void jit_uni_lrn_fwd_t<isa>::execute_forward() const {
     using namespace alg_kind;
 
     auto src = reinterpret_cast<const data_t*>(this->input_memory(0));
     auto dst = reinterpret_cast<data_t*>(this->memory(0));
     auto ws = reinterpret_cast<data_t*>(this->memory(1));
 
-    const int N = conf_.MB();
-    const int C = conf_.C();
-    const int HW = conf_.H() * conf_.W();
-    const int ls = conf_.desc()->local_size;
+    const int N = pd()->MB();
+    const int C = pd()->C();
+    const int HW = pd()->H() * pd()->W();
+    const int ls = pd()->desc()->local_size;
 
-    auto ak = conf_.desc()->alg_kind;
-    auto dfmt = conf_.src_pd()->desc()->format;
+    auto ak = pd()->desc()->alg_kind;
+    auto dfmt = pd()->src_pd()->desc()->format;
 
     if (dfmt == nChw8c && ls == 5 && ak == lrn_across_channels) {
         parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) {
@@ -177,18 +177,18 @@ status_t jit_uni_lrn_fwd_t<isa>::pd_t::init() {
 }
 
 template <cpu_isa_t isa>
-jit_uni_lrn_bwd_t<isa>::jit_uni_lrn_bwd_t(const pd_t *pd,
+jit_uni_lrn_bwd_t<isa>::jit_uni_lrn_bwd_t(const pd_t *apd,
     const input_vector &inputs, const output_vector &outputs)
-    : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd)
+    : cpu_primitive_t(apd, inputs, outputs)
     , ker_(nullptr), ker_first_(nullptr), ker_last_(nullptr)
 {
     using namespace alg_kind;
-    const int C = conf_.C();
-    const int H = conf_.H();
-    const int W = conf_.W();
-    const int ls = conf_.desc()->local_size;
-    float A = conf_.desc()->lrn_alpha / ls;
-    float B = conf_.desc()->lrn_beta;
+    const int C = pd()->C();
+    const int H = pd()->H();
+    const int W = pd()->W();
+    const int ls = pd()->desc()->local_size;
+    float A = pd()->desc()->lrn_alpha / ls;
+    float B = pd()->desc()->lrn_beta;
 
     int use_h_parallelizm = 0;// XXX
     if (C / VECTOR_LENGTH == 1) {
@@ -212,16 +212,16 @@ jit_uni_lrn_bwd_t<isa>::~jit_uni_lrn_bwd_t()
 }
 
 template <cpu_isa_t isa>
-void jit_uni_lrn_bwd_t<isa>::execute_backward() {
+void jit_uni_lrn_bwd_t<isa>::execute_backward() const {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto ws = reinterpret_cast<const data_t*>(this->input_memory(2));
     auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
 
-    const int N = conf_.MB();
-    const int C = conf_.C();
-    const int H = conf_.H();
-    const int W = conf_.W();
+    const int N = pd()->MB();
+    const int C = pd()->C();
+    const int H = pd()->H();
+    const int W = pd()->W();
 
     int use_h_parallelizm = 0; // XXX
     if (use_h_parallelizm) {