Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / common / convolution_pd.hpp
index 90b6629..99e6e32 100644 (file)
@@ -35,25 +35,28 @@ status_t conv_desc_init(convolution_desc_t *conv_desc,
         const dims_t padding_l, const dims_t padding_r,
         padding_kind_t padding_kind);
 
-template <bool with_relu>
-struct _convolution_fwd_pd_t: public primitive_desc_t {
-    typedef _convolution_fwd_pd_t base_class;
-    typedef _convolution_fwd_pd_t hint_class;
-    typedef typename utils::conditional<with_relu,
-            convolution_relu_desc_t, convolution_desc_t>::type base_desc_t;
-    static constexpr auto base_pkind =
-        utils::conditional_v<with_relu, primitive_kind_t,
-        primitive_kind::convolution_relu, primitive_kind::convolution>::value;
-
-    _convolution_fwd_pd_t(mkldnn::impl::engine_t *engine,
-            const base_desc_t *adesc, const primitive_attr_t *attr,
-            const _convolution_fwd_pd_t *hint_fwd_pd)
+memory_desc_t *conv_prop_agnostic_src_d(convolution_desc_t *desc);
+memory_desc_t *conv_prop_agnostic_wei_d(convolution_desc_t *desc);
+memory_desc_t *conv_prop_agnostic_bia_d(convolution_desc_t *desc);
+memory_desc_t *conv_prop_agnostic_dst_d(convolution_desc_t *desc);
+const memory_desc_t *conv_prop_agnostic_src_d(const convolution_desc_t *desc);
+const memory_desc_t *conv_prop_agnostic_wei_d(const convolution_desc_t *desc);
+const memory_desc_t *conv_prop_agnostic_bia_d(const convolution_desc_t *desc);
+const memory_desc_t *conv_prop_agnostic_dst_d(const convolution_desc_t *desc);
+
+struct convolution_fwd_pd_t: public primitive_desc_t {
+    typedef convolution_fwd_pd_t base_class;
+    typedef convolution_fwd_pd_t hint_class;
+    static constexpr auto base_pkind = primitive_kind::convolution;
+
+    convolution_fwd_pd_t(mkldnn::impl::engine_t *engine,
+            const convolution_desc_t *adesc, const primitive_attr_t *attr,
+            const convolution_fwd_pd_t *hint_fwd_pd)
         : primitive_desc_t(engine, attr, base_pkind), desc_(*adesc)
         , hint_fwd_pd_(hint_fwd_pd) {}
-    virtual ~_convolution_fwd_pd_t() {}
+    virtual ~convolution_fwd_pd_t() {}
 
-    const base_desc_t *desc() const { return &desc_; }
-    inline const convolution_desc_t *cdesc() const { return &cdesc_(); }
+    const convolution_desc_t *desc() const { return &desc_; }
     virtual const op_desc_t *op_desc() const override
     { return reinterpret_cast<const op_desc_t *>(this->desc()); }
     virtual void init_info() override { init_info_conv(this, this->info_); }
@@ -75,7 +78,7 @@ struct _convolution_fwd_pd_t: public primitive_desc_t {
     {
         switch (what) {
         case pkind_traits<base_pkind>::query_d:
-            *(const base_desc_t**)result = desc(); break;
+            *(const convolution_desc_t**)result = desc(); break;
         default: return primitive_desc_t::query(what, idx, result);
         }
         return status::success;
@@ -88,7 +91,7 @@ struct _convolution_fwd_pd_t: public primitive_desc_t {
     inline int IC() const { return input_pd()->desc()->dims[1]; }
     inline int OC() const { return output_pd()->desc()->dims[1]; }
     inline int G() const
-    { return with_groups() ? cdesc_().weights_desc.dims[0] : 1; }
+    { return with_groups() ? desc_.weights_desc.dims[0] : 1; }
 
     inline int ID() const { return (ndims() == 5) ? input_pd()->desc()->dims[2] : 1; }
     inline int IH() const { return (ndims() == 3) ? 1 : input_pd()->desc()->dims[ndims()-2]; }
@@ -97,73 +100,61 @@ struct _convolution_fwd_pd_t: public primitive_desc_t {
     inline int OH() const { return (ndims() == 3) ? 1 : output_pd()->desc()->dims[ndims()-2]; }
     inline int OW() const { return output_pd()->desc()->dims[ndims()-1]; }
     inline int KD() const { return (ndims() == 5)
-        ? cdesc_().weights_desc.dims[2 + with_groups()] : 1; }
+        ? desc_.weights_desc.dims[2 + with_groups()] : 1; }
     inline int KH() const
     { return (ndims() == 3)
-        ? 1 : cdesc_().weights_desc.dims[ndims() - (2 - with_groups())]; }
+        ? 1 : desc_.weights_desc.dims[ndims() - (2 - with_groups())]; }
     inline int KW() const
-    { return cdesc_().weights_desc.dims[ndims() - (1 - with_groups())]; }
+    { return desc_.weights_desc.dims[ndims() - (1 - with_groups())]; }
 
-    inline int KSD() const { return (ndims() == 5) ? cdesc_().strides[0] : 1; }
+    inline int KSD() const { return (ndims() == 5) ? desc_.strides[0] : 1; }
     inline int KSH() const { return (ndims() == 3)
-        ? 1 : cdesc_().strides[ndims()-4]; }
-    inline int KSW() const { return cdesc_().strides[ndims()-3]; }
+        ? 1 : desc_.strides[ndims()-4]; }
+    inline int KSW() const { return desc_.strides[ndims()-3]; }
 
-    inline int KDD() const { return (ndims() == 5) ? cdesc_().dilates[0] : 0; }
+    inline int KDD() const { return (ndims() == 5) ? desc_.dilates[0] : 0; }
     inline int KDH() const { return (ndims() == 3)
-        ? 0 : cdesc_().dilates[ndims()-4]; }
-    inline int KDW() const { return cdesc_().dilates[ndims()-3]; }
+        ? 0 : desc_.dilates[ndims()-4]; }
+    inline int KDW() const { return desc_.dilates[ndims()-3]; }
 
     inline int padFront() const
-        { return (ndims() == 5) ? cdesc_().padding[0][0] : 0; }
+        { return (ndims() == 5) ? desc_.padding[0][0] : 0; }
     inline int padBack() const
-        { return (ndims() == 5) ? cdesc_().padding[1][0] : 0; }
+        { return (ndims() == 5) ? desc_.padding[1][0] : 0; }
     inline int padT() const { return (ndims() == 3)
-        ? 0 : cdesc_().padding[0][ndims()-4]; }
+        ? 0 : desc_.padding[0][ndims()-4]; }
     inline int padB() const { return (ndims() == 3)
-        ? 0 : cdesc_().padding[1][ndims()-4]; }
-    inline int padL() const { return cdesc_().padding[0][ndims()-3]; }
-    inline int padR() const { return cdesc_().padding[1][ndims()-3]; }
-
-    inline float negative_slope() const;
+        ? 0 : desc_.padding[1][ndims()-4]; }
+    inline int padL() const { return desc_.padding[0][ndims()-3]; }
+    inline int padR() const { return desc_.padding[1][ndims()-3]; }
 
     inline bool with_bias() const
-    { return !memory_desc_wrapper(cdesc_().bias_desc).is_zero(); }
+    { return !memory_desc_wrapper(desc_.bias_desc).is_zero(); }
     inline bool with_groups() const
-    { return cdesc_().weights_desc.ndims == cdesc_().src_desc.ndims + 1; }
+    { return desc_.weights_desc.ndims == desc_.src_desc.ndims + 1; }
 
-    inline int ndims() const { return cdesc_().src_desc.ndims; }
+    inline int ndims() const { return desc_.src_desc.ndims; }
+
+    virtual status_t set_alg_kind(alg_kind_t alg) {
+        if (alg == alg_kind::undef) return status::invalid_arguments;
+        desc_.alg_kind = alg;
+        return status::success;
+    }
 
     bool has_zero_dim_memory() const {
         return false
-            || memory_desc_wrapper(cdesc_().src_desc).has_zero_dim()
-            || memory_desc_wrapper(cdesc_().dst_desc).has_zero_dim();
+            || memory_desc_wrapper(desc_.src_desc).has_zero_dim()
+            || memory_desc_wrapper(desc_.dst_desc).has_zero_dim();
     }
 
-protected:
-    base_desc_t desc_;
-    const _convolution_fwd_pd_t *hint_fwd_pd_;
 
-    inline const convolution_desc_t &cdesc_() const;
+protected:
+    convolution_desc_t desc_;
+    const convolution_fwd_pd_t *hint_fwd_pd_;
 
     virtual status_t init() = 0;
 };
 
-using convolution_fwd_pd_t = mkldnn::impl::_convolution_fwd_pd_t<false>;
-using convolution_relu_fwd_pd_t = mkldnn::impl::_convolution_fwd_pd_t<true>;
-
-template<> inline float convolution_fwd_pd_t::negative_slope() const
-{ return 0.; }
-template<> inline float convolution_relu_fwd_pd_t::negative_slope() const
-{ return desc()->negative_slope; }
-
-template<bool with_relu> inline const
-convolution_desc_t &_convolution_fwd_pd_t<with_relu>::cdesc_() const
-{ return desc_; }
-template<>
-inline const convolution_desc_t &convolution_relu_fwd_pd_t::cdesc_() const
-{ return desc_.convolution_desc; }
-
 struct convolution_bwd_data_pd_t: public primitive_desc_t {
     typedef convolution_bwd_data_pd_t base_class;
     typedef convolution_fwd_pd_t hint_class;
@@ -178,7 +169,6 @@ struct convolution_bwd_data_pd_t: public primitive_desc_t {
     virtual ~convolution_bwd_data_pd_t() {}
 
     const convolution_desc_t *desc() const { return &desc_; }
-    const convolution_desc_t *cdesc() const { return desc(); }
     virtual const op_desc_t *op_desc() const override
     { return reinterpret_cast<const op_desc_t *>(this->desc()); }
     virtual void init_info() override { init_info_conv(this, this->info_); }
@@ -257,6 +247,12 @@ struct convolution_bwd_data_pd_t: public primitive_desc_t {
     inline int ndims() const { return desc_.diff_src_desc.ndims; }
     virtual bool support_bias() const { return false; }
 
+    virtual status_t set_alg_kind(alg_kind_t alg) {
+        if (alg == alg_kind::undef) return status::invalid_arguments;
+        desc_.alg_kind = alg;
+        return status::success;
+    }
+
     bool has_zero_dim_memory() const {
         return false
             || memory_desc_wrapper(desc_.diff_src_desc).has_zero_dim()
@@ -284,7 +280,6 @@ struct convolution_bwd_weights_pd_t: public primitive_desc_t {
     virtual ~convolution_bwd_weights_pd_t() {}
 
     const convolution_desc_t *desc() const { return &desc_; }
-    const convolution_desc_t *cdesc() const { return desc(); }
     virtual const op_desc_t *op_desc() const override
     { return reinterpret_cast<const op_desc_t *>(this->desc()); }
     virtual void init_info() override { init_info_conv(this, this->info_); }
@@ -372,6 +367,12 @@ struct convolution_bwd_weights_pd_t: public primitive_desc_t {
 
     inline int ndims() const { return desc_.src_desc.ndims; }
 
+    virtual status_t set_alg_kind(alg_kind_t alg) {
+        if (alg == alg_kind::undef) return status::invalid_arguments;
+        desc_.alg_kind = alg;
+        return status::success;
+    }
+
     bool has_zero_dim_memory() const {
         return false
             || memory_desc_wrapper(desc_.src_desc).has_zero_dim()