Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_deconvolution.hpp
index 6890c1c..e185172 100644 (file)
 #include "utils.hpp"
 #include "primitive_iterator.hpp"
 
-#define DECLARE_DECONVOLUTION_PD_t(impl_name, ...) \
-    virtual pd_t *clone() const override { return new pd_t(*this); } \
-    virtual status_t create_primitive(primitive_t **primitive, \
-                const primitive_at_t *inputs, \
-                const primitive_t **outputs) const override { \
-        double ms = get_msec(); \
-        using namespace prop_kind;\
-        primitive_t::input_vector ins(inputs, inputs + this->n_inputs()); \
-        primitive_t::output_vector outs(outputs, outputs + this->n_outputs()); \
-        auto ret = safe_ptr_assign<primitive_t>(*primitive, \
-                                new (__VA_ARGS__)(this, ins, outs)); \
-        primitive_t *conv_primitive; \
-        if (this->desc()->prop_kind == backward_weights) {\
-            primitive_at_t conv_inputs[2];\
-            conv_inputs[0] = inputs[1];\
-            conv_inputs[1] = inputs[0];\
-            conv_pd_->create_primitive((&conv_primitive), conv_inputs, outputs);\
-        } \
-        else conv_pd_->create_primitive((&conv_primitive), inputs, outputs);\
-        ((__VA_ARGS__ *)(*primitive))->conv_p_ = conv_primitive;\
-        ms = get_msec() - ms; \
-        if (mkldnn_verbose()->level >= 2) { \
-                        printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
-                        fflush(0); \
-                    } \
-        return ret; \
-    } \
-virtual const char *name() const override { return impl_name; }
-
-#define DECLARE_DECONVOLUTION_PD_T(impl_name, ...) \
-        DECLARE_DECONVOLUTION_PD_t(impl_name,  __VA_ARGS__)
-
-
 namespace mkldnn {
 namespace impl {
 namespace cpu {
@@ -146,7 +113,7 @@ struct ref_deconvolution_fwd_t: public cpu_primitive_t {
 
         ~pd_t() { delete conv_pd_; }
 
-        DECLARE_DECONVOLUTION_PD_T("ref:any", ref_deconvolution_fwd_t);
+        DECLARE_DECONVOLUTION_PD_T(ref_deconvolution_fwd_t);
 
         status_t init_convolution(){
             using namespace memory_format;
@@ -154,7 +121,7 @@ struct ref_deconvolution_fwd_t: public cpu_primitive_t {
             convolution_desc_t cd;
             status_t status;
 
-            status = conv_descr_create(this->cdesc(), &cd);
+            status = conv_descr_create(this->desc(), &cd);
             if (status != status::success) return status;
 
             mkldnn_primitive_desc_iterator it(this->engine_, (op_desc_t *)&cd,
@@ -216,19 +183,19 @@ struct ref_deconvolution_fwd_t: public cpu_primitive_t {
         bool conv_supports_bias_;
     };
 
-    ref_deconvolution_fwd_t(const pd_t *pd, const input_vector &inputs,
+    ref_deconvolution_fwd_t(const pd_t *apd, const input_vector &inputs,
             const output_vector &outputs)
-        : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), conv_p_(nullptr) {}
+        : cpu_primitive_t(apd, inputs, outputs), conv_p_(nullptr) {}
 
     ~ref_deconvolution_fwd_t() { delete this->conv_p_; }
 
-    virtual void execute(event_t *e) {
-        switch (conf_.desc()->prop_kind) {
+    virtual void execute(event_t *e) const {
+        switch (pd()->desc()->prop_kind) {
         case prop_kind::forward_training:
         case prop_kind::forward_inference:
             (conv_p_)->execute(e);
-            if (conf_.with_bias() && !conf_.conv_supports_bias_) {
-                switch (conf_.dst_pd()->desc()->format) {
+            if (pd()->with_bias() && !pd()->conv_supports_bias_) {
+                switch (pd()->dst_pd()->desc()->format) {
                     case memory_format::nchw :
                     case memory_format::ncdhw :
                         compute_fwd_bias_ncdhw();
@@ -254,10 +221,10 @@ struct ref_deconvolution_fwd_t: public cpu_primitive_t {
     }
 
 private:
-    void compute_fwd_bias();
-    void compute_fwd_bias_ncdhw();
-    template <int blksize> void compute_fwd_bias_nCdhwXc();
-    pd_t conf_;
+    void compute_fwd_bias() const;
+    void compute_fwd_bias_ncdhw() const;
+    template <int blksize> void compute_fwd_bias_nCdhwXc() const;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
     primitive_t *conv_p_;
 };
 
@@ -277,7 +244,7 @@ struct ref_deconvolution_bwd_data_t: public cpu_primitive_t {
 
         ~pd_t() { delete conv_pd_; }
 
-        DECLARE_DECONVOLUTION_PD_T("ref:any", ref_deconvolution_bwd_data_t);
+        DECLARE_DECONVOLUTION_PD_T(ref_deconvolution_bwd_data_t);
 
         status_t init_convolution(){
             using namespace memory_format;
@@ -285,7 +252,7 @@ struct ref_deconvolution_bwd_data_t: public cpu_primitive_t {
             convolution_desc_t cd;
             status_t status;
 
-            status = conv_descr_create(this->cdesc(), &cd);
+            status = conv_descr_create(this->desc(), &cd);
             if (status != status::success) return status;
 
              mkldnn_primitive_desc_iterator it(this->engine_, (op_desc_t *)&cd,
@@ -336,13 +303,13 @@ struct ref_deconvolution_bwd_data_t: public cpu_primitive_t {
         }
         primitive_desc_t *conv_pd_;
     };
-    ref_deconvolution_bwd_data_t(const pd_t *pd, const input_vector &inputs,
+    ref_deconvolution_bwd_data_t(const pd_t *apd, const input_vector &inputs,
             const output_vector &outputs)
-        : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), conv_p_(nullptr) {}
+        : cpu_primitive_t(apd, inputs, outputs), conv_p_(nullptr) {}
     ~ref_deconvolution_bwd_data_t() { delete this->conv_p_; }
 
-    virtual void execute(event_t *e) {
-        switch (conf_.desc()->prop_kind) {
+    virtual void execute(event_t *e) const {
+        switch (pd()->desc()->prop_kind) {
         case prop_kind::backward_data:
             (conv_p_)->execute(e);
             break;
@@ -353,7 +320,7 @@ struct ref_deconvolution_bwd_data_t: public cpu_primitive_t {
     }
 
 private:
-    pd_t conf_;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
     primitive_t *conv_p_;
 };
 
@@ -373,7 +340,7 @@ struct ref_deconvolution_bwd_weights_t: public cpu_primitive_t {
 
         ~pd_t() { delete conv_pd_; }
 
-        DECLARE_DECONVOLUTION_PD_T("ref:any", ref_deconvolution_bwd_weights_t);
+        DECLARE_DECONVOLUTION_PD_T(ref_deconvolution_bwd_weights_t);
 
         status_t init_convolution(){
             using namespace memory_format;
@@ -381,7 +348,7 @@ struct ref_deconvolution_bwd_weights_t: public cpu_primitive_t {
             convolution_desc_t cd;
             status_t status;
 
-            status = conv_descr_create(this->cdesc(), &cd);
+            status = conv_descr_create(this->desc(), &cd);
             if (status != status::success) return status;
 
              mkldnn_primitive_desc_iterator it(this->engine_, (op_desc_t *)&cd,
@@ -434,20 +401,20 @@ struct ref_deconvolution_bwd_weights_t: public cpu_primitive_t {
         primitive_desc_t *conv_pd_;
     };
 
-    ref_deconvolution_bwd_weights_t(const pd_t *pd, const input_vector &inputs,
+    ref_deconvolution_bwd_weights_t(const pd_t *apd, const input_vector &inputs,
             const output_vector &outputs)
-        : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), conv_p_(nullptr) {}
+        : cpu_primitive_t(apd, inputs, outputs), conv_p_(nullptr) {}
 
     ~ref_deconvolution_bwd_weights_t() { delete this->conv_p_; }
 
     typedef typename prec_traits<data_type::f32>::type data_t;
 
-    virtual void execute(event_t *e) {
-        switch (conf_.desc()->prop_kind) {
+    virtual void execute(event_t *e) const {
+        switch (pd()->desc()->prop_kind) {
         case prop_kind::backward_weights:
             (conv_p_)->execute(e);
-            if (conf_.with_bias()) {
-                switch (conf_.diff_dst_pd()->desc()->format) {
+            if (pd()->with_bias()) {
+                switch (pd()->diff_dst_pd()->desc()->format) {
                     case memory_format::nchw :
                     case memory_format::ncdhw :
                         compute_bwd_bias_ncdhw();
@@ -472,11 +439,11 @@ struct ref_deconvolution_bwd_weights_t: public cpu_primitive_t {
     }
 
 private:
-    pd_t conf_;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
     primitive_t *conv_p_;
-    void compute_bwd_bias();
-    void compute_bwd_bias_ncdhw();
-    template <int blksize> void compute_bwd_bias_nCdhwXc();
+    void compute_bwd_bias() const;
+    void compute_bwd_bias_ncdhw() const;
+    template <int blksize> void compute_bwd_bias_nCdhwXc() const;
 };
 
 }