Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_shuffle.hpp
index 763bbaa..cd653dc 100644 (file)
@@ -53,15 +53,15 @@ struct ref_shuffle_t : public cpu_primitive_t {
         }
     };
 
-    ref_shuffle_t(const pd_t *pd, const input_vector &inputs,
+    ref_shuffle_t(const pd_t *apd, const input_vector &inputs,
             const output_vector &outputs)
-        : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd)
+        : cpu_primitive_t(apd, inputs, outputs)
     {
-        const int axis_size = conf_.axis_size();
-        const int group_size = conf_.group_size();
-        const int transpose_row = conf_.is_fwd() ? group_size
+        const int axis_size = pd()->axis_size();
+        const int group_size = pd()->group_size();
+        const int transpose_row = pd()->is_fwd() ? group_size
                                                  : axis_size / group_size;
-        const int transpose_col = conf_.is_fwd() ? axis_size / group_size
+        const int transpose_col = pd()->is_fwd() ? axis_size / group_size
                                                  : group_size;
         rev_transposed_ = (int *)malloc(axis_size * sizeof(int), 64);
         parallel_nd(transpose_col, transpose_row, [&](int i, int j) {
@@ -73,9 +73,9 @@ struct ref_shuffle_t : public cpu_primitive_t {
 
     typedef typename typesize_traits<data_type_size>::type data_t;
 
-    virtual void execute(event_t *e) {
+    virtual void execute(event_t *e) const {
         using namespace memory_format;
-        switch (conf_.data_pd()->desc()->format) {
+        switch (pd()->data_pd()->desc()->format) {
         case nCdhw16c: execute_<nCdhw16c>(); break;
         case nChw16c:  execute_<nChw16c>(); break;
         case nCdhw8c:  execute_<nCdhw8c>(); break;
@@ -91,8 +91,8 @@ struct ref_shuffle_t : public cpu_primitive_t {
     }
 
 private:
-    template<memory_format_t fmt>void execute_();
-    pd_t conf_;
+    template<memory_format_t fmt>void execute_() const;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
     int *rev_transposed_;
 };