Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / common / deconvolution_pd.hpp
index ba699c5..a98a749 100644 (file)
@@ -39,7 +39,6 @@ struct deconvolution_fwd_pd_t : public primitive_desc_t {
     virtual ~deconvolution_fwd_pd_t() {}
 
     const deconvolution_desc_t *desc() const { return &desc_; }
-    inline const deconvolution_desc_t *cdesc() const { return &desc_; }
     virtual const op_desc_t *op_desc() const override {
         return reinterpret_cast<const op_desc_t *>(this->desc());
     }
@@ -118,6 +117,12 @@ struct deconvolution_fwd_pd_t : public primitive_desc_t {
     }
     inline int ndims() const { return desc_.src_desc.ndims; }
 
+    bool has_zero_dim_memory() const {
+        return false
+            || memory_desc_wrapper(desc_.src_desc).has_zero_dim()
+            || memory_desc_wrapper(desc_.dst_desc).has_zero_dim();
+    }
+
 protected:
     deconvolution_desc_t desc_;
     const deconvolution_fwd_pd_t *hint_fwd_pd_;
@@ -138,7 +143,6 @@ struct deconvolution_bwd_data_pd_t : public primitive_desc_t {
     virtual ~deconvolution_bwd_data_pd_t() {}
 
     const deconvolution_desc_t *desc() const { return &desc_; }
-    const deconvolution_desc_t *cdesc() const { return desc(); }
     virtual const op_desc_t *op_desc() const override {
         return reinterpret_cast<const op_desc_t *>(this->desc());
     }
@@ -214,7 +218,7 @@ struct deconvolution_bwd_data_pd_t : public primitive_desc_t {
     inline bool with_groups() const {
         return desc_.weights_desc.ndims == desc_.diff_src_desc.ndims + 1;
     }
-    inline int ndims() const { return desc_.src_desc.ndims; }
+    inline int ndims() const { return desc_.diff_src_desc.ndims; }
 
 protected:
     deconvolution_desc_t desc_;
@@ -236,7 +240,6 @@ struct deconvolution_bwd_weights_pd_t : public primitive_desc_t {
     virtual ~deconvolution_bwd_weights_pd_t() {}
 
     const deconvolution_desc_t *desc() const { return &desc_; }
-    const deconvolution_desc_t *cdesc() const { return desc(); }
     virtual const op_desc_t *op_desc() const override {
         return reinterpret_cast<const op_desc_t *>(this->desc());
     }