Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_lrn.hpp
index ad89ed7..e2750f9 100644 (file)
@@ -57,14 +57,14 @@ struct ref_lrn_fwd_t: public cpu_primitive_t {
         }
     };
 
-    ref_lrn_fwd_t(const pd_t *pd, const input_vector &inputs,
+    ref_lrn_fwd_t(const pd_t *apd, const input_vector &inputs,
             const output_vector &outputs)
-        : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd) {}
+        : cpu_primitive_t(apd, inputs, outputs) {}
     typedef typename prec_traits<data_type>::type data_t;
 
-    virtual void execute(event_t *e) {
+    virtual void execute(event_t *e) const {
         using namespace memory_format;
-        switch (conf_.src_pd()->desc()->format) {
+        switch (pd()->src_pd()->desc()->format) {
         case nChw16c: execute_forward<nChw16c>(); break;
         case nChw8c: execute_forward<nChw8c>(); break;
         case nchw: execute_forward<nchw>(); break;
@@ -77,8 +77,8 @@ struct ref_lrn_fwd_t: public cpu_primitive_t {
     }
 
 private:
-    template<memory_format_t fmt>void execute_forward();
-    pd_t conf_;
+    template<memory_format_t fmt>void execute_forward() const;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
 };
 
 template <impl::data_type_t data_type>
@@ -106,14 +106,14 @@ struct ref_lrn_bwd_t: public cpu_primitive_t {
         }
     };
 
-    ref_lrn_bwd_t(const pd_t *pd, const input_vector &inputs,
+    ref_lrn_bwd_t(const pd_t *apd, const input_vector &inputs,
             const output_vector &outputs)
-        : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd) {}
+        : cpu_primitive_t(apd, inputs, outputs) {}
     typedef typename prec_traits<data_type>::type data_t;
 
-    virtual void execute(event_t *e) {
+    virtual void execute(event_t *e) const {
         using namespace memory_format;
-        switch (conf_.src_pd()->desc()->format) {
+        switch (pd()->src_pd()->desc()->format) {
         case nChw16c: execute_backward<nChw16c>(); break;
         case nChw8c: execute_backward<nChw8c>(); break;
         case nchw: execute_backward<nchw>(); break;
@@ -126,8 +126,8 @@ struct ref_lrn_bwd_t: public cpu_primitive_t {
     }
 
 private:
-    template<memory_format_t fmt>void execute_backward();
-    pd_t conf_;
+    template<memory_format_t fmt>void execute_backward() const;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
 };
 
 }