Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / cpu_convolution_pd.hpp
index 1db3f4a..f50287a 100644 (file)
@@ -31,20 +31,19 @@ namespace mkldnn {
 namespace impl {
 namespace cpu {
 
-template <bool with_relu>
-struct _cpu_convolution_fwd_pd_t: public _convolution_fwd_pd_t<with_relu> {
+struct cpu_convolution_fwd_pd_t: public convolution_fwd_pd_t {
     using cpu_memory_pd_t = cpu_memory_t::pd_t;
 
-    _cpu_convolution_fwd_pd_t(engine_t *engine,
-            const typename _cpu_convolution_fwd_pd_t::base_desc_t *adesc,
+    cpu_convolution_fwd_pd_t(engine_t *engine,
+            const convolution_desc_t *adesc,
             const primitive_attr_t *attr,
-            const typename _cpu_convolution_fwd_pd_t::base_class *hint_fwd_pd)
-        : _convolution_fwd_pd_t<with_relu>(engine, adesc, attr, hint_fwd_pd)
-        , src_pd_(this->engine_, &this->cdesc_().src_desc)
-        , dst_pd_(this->engine_, &this->cdesc_().dst_desc)
-        , weights_pd_(this->engine_, &this->cdesc_().weights_desc)
-        , bias_pd_(this->engine_, &this->cdesc_().bias_desc) {}
-    virtual ~_cpu_convolution_fwd_pd_t() {}
+            const typename cpu_convolution_fwd_pd_t::base_class *hint_fwd_pd)
+        : convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+        , src_pd_(this->engine_, &this->desc()->src_desc)
+        , dst_pd_(this->engine_, &this->desc()->dst_desc)
+        , weights_pd_(this->engine_, &this->desc()->weights_desc)
+        , bias_pd_(this->engine_, &this->desc()->bias_desc) {}
+    virtual ~cpu_convolution_fwd_pd_t() {}
 
     virtual const cpu_memory_pd_t *src_pd(int index = 0) const override
     { return index == 0 ? &src_pd_ : nullptr; }
@@ -56,13 +55,26 @@ struct _cpu_convolution_fwd_pd_t: public _convolution_fwd_pd_t<with_relu> {
         return nullptr;
     }
 
-    bool want_padded_bias() const {
-        if (!this->with_bias()) return false;
+    bool has_padded_dst() const {
         memory_desc_wrapper dst_d(&dst_pd_);
         if (!dst_d.is_blocking_desc()) return false;
         return this->OC() != dst_d.blocking_desc().padding_dims[1];
     }
 
+    bool wants_padded_bias() const {
+        if (!this->with_bias()) return false;
+        return has_padded_dst();
+    }
+
+    bool wants_zero_pad_dst(bool jit_impl = true) const {
+        if (!has_padded_dst()) return false;
+        const auto &po = this->attr()->post_ops_;
+        int idx;
+        if ((idx = po.find(primitive_kind::eltwise)) == -1) return false;
+        return !math::eltwise_fwd_preserves_zero(po.entry_[idx].eltwise.alg,
+                jit_impl);
+    }
+
 protected:
     cpu_memory_pd_t src_pd_, dst_pd_;
     cpu_memory_pd_t weights_pd_, bias_pd_;
@@ -70,14 +82,14 @@ protected:
     inline memory_format_t src_format()
     {
         using namespace memory_format;
-        return utils::pick(this->cdesc_().src_desc.ndims - 3, ncw, nchw, ncdhw);
+        return utils::pick(this->desc()->src_desc.ndims - 3, ncw, nchw, ncdhw);
     }
     inline memory_format_t wei_format()
     {
         using namespace memory_format;
         return this->with_groups()
-            ? utils::pick(this->cdesc_().src_desc.ndims - 3, goiw, goihw, goidhw)
-            : utils::pick(this->cdesc_().src_desc.ndims - 3, oiw, oihw, oidhw);
+            ? utils::pick(this->desc()->src_desc.ndims - 3, goiw, goihw, goidhw)
+            : utils::pick(this->desc()->src_desc.ndims - 3, oiw, oihw, oidhw);
     }
 
     virtual status_t set_default_params() {
@@ -90,13 +102,12 @@ protected:
             CHECK(weights_pd_.set_format(wei_format()));
         if (bias_pd_.desc()->format == any)
             CHECK(bias_pd_.set_format(x));
+        if (this->desc()->alg_kind == alg_kind::convolution_auto)
+            CHECK(this->set_alg_kind(alg_kind::convolution_direct));
         return status::success;
     }
 };
 
-using cpu_convolution_fwd_pd_t = _cpu_convolution_fwd_pd_t<false>;
-using cpu_convolution_relu_fwd_pd_t = _cpu_convolution_fwd_pd_t<true>;
-
 struct cpu_convolution_bwd_data_pd_t: public convolution_bwd_data_pd_t {
     using cpu_memory_pd_t = cpu_memory_t::pd_t;
 
@@ -148,6 +159,8 @@ protected:
            CHECK(weights_pd_.set_format(wei_format()));
         if (bias_pd_.desc()->format == any)
             CHECK(bias_pd_.set_format(x));
+        if (this->desc()->alg_kind == alg_kind::convolution_auto)
+            CHECK(this->set_alg_kind(alg_kind::convolution_direct));
         return status::success;
     }
 };
@@ -177,7 +190,7 @@ struct cpu_convolution_bwd_weights_pd_t: public convolution_bwd_weights_pd_t {
             return  nullptr;
         }
 
-    bool want_padded_bias() const {
+    bool wants_padded_bias() const {
         if (!this->with_bias()) return false;
         memory_desc_wrapper diff_dst_d(&diff_dst_pd_);
         if (!diff_dst_d.is_blocking_desc()) return false;
@@ -212,6 +225,8 @@ protected:
             CHECK(diff_weights_pd_.set_format(wei_format()));
         if (diff_bias_pd_.desc()->format == any)
             CHECK(diff_bias_pd_.set_format(x));
+        if (this->desc()->alg_kind == alg_kind::convolution_auto)
+            CHECK(this->set_alg_kind(alg_kind::convolution_direct));
         return status::success;
     }
 };