Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm_convolution.hpp
index d0d65c1..2a0da52 100644 (file)
 #define CPU_JIT_GEMM_CONVOLUTION_HPP
 
 #include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+
 #include "cpu_convolution_pd.hpp"
 #include "cpu_engine.hpp"
 #include "gemm_convolution_utils.hpp"
 #include "gemm/gemm.hpp"
-#include "scratchpad.hpp"
 #include "ref_eltwise.hpp"
 #include "ref_depthwise.hpp"
 
@@ -30,34 +31,15 @@ namespace mkldnn {
 namespace impl {
 namespace cpu {
 
-template <bool with_relu>
-struct _gemm_convolution_fwd_t: public cpu_primitive_t {
-    struct pd_t: public _cpu_convolution_fwd_pd_t<with_relu> {
+struct gemm_convolution_fwd_t: public cpu_primitive_t {
+    struct pd_t: public cpu_convolution_fwd_pd_t {
         pd_t(engine_t *engine,
-                const typename pd_t::base_desc_t *adesc,
-                const primitive_attr_t *attr,
+                const convolution_desc_t *adesc, const primitive_attr_t *attr,
                 const typename pd_t::base_class *hint_fwd_pd)
-            : _cpu_convolution_fwd_pd_t<with_relu>(engine, adesc, attr,
-                    hint_fwd_pd)
+            : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
             , jcp_() {}
 
-        DECLARE_COMMON_PD_T(GEMM_IMPL_STR, _gemm_convolution_fwd_t<with_relu>);
-
-        inline memory_format_t src_format()
-        {
-            using namespace memory_format;
-            return (utils::pick(this->cdesc_().src_desc.ndims - 3,
-                ncw, nchw, ncdhw));
-        }
-        inline memory_format_t wei_format()
-        {
-            using namespace memory_format;
-            return (this->with_groups()
-                ? utils::pick(this->cdesc_().src_desc.ndims - 3,
-                    goiw, goihw, goidhw)
-                : utils::pick(this->cdesc_().src_desc.ndims - 3,
-                    oiw, oihw, oidhw));
-        }
+        DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_fwd_t);
 
         virtual status_t init() override {
             using namespace prop_kind;
@@ -67,26 +49,47 @@ struct _gemm_convolution_fwd_t: public cpu_primitive_t {
 
             bool ok = true
                 && this->set_default_params() == status::success
-                && utils::one_of(this->cdesc_().prop_kind, forward_training,
+                && utils::one_of(this->desc()->prop_kind, forward_training,
                            forward_inference)
-                && this->cdesc_().alg_kind == alg_kind::convolution_direct
+                && utils::one_of(this->desc()->alg_kind,
+                        alg_kind::convolution_auto,
+                        alg_kind::convolution_direct)
                 && !this->has_zero_dim_memory()
                 && utils::everyone_is(data_type::f32,
-                           this->cdesc_().src_desc.data_type,
-                           this->cdesc_().weights_desc.data_type,
-                           this->cdesc_().dst_desc.data_type)
+                           this->desc()->src_desc.data_type,
+                           this->desc()->weights_desc.data_type,
+                           this->desc()->dst_desc.data_type)
                 && IMPLICATION(this->with_bias(), data_type::f32
-                                   == this->cdesc_().bias_desc.data_type)
+                                   == this->desc()->bias_desc.data_type)
                 && this->src_pd_.desc()->format == src_format()
                 && this->dst_pd_.desc()->format == src_format()
                 && this->weights_pd_.desc()->format == wei_format()
                 && this->is_gemm_conv_format();
-            return ok ? status::success : status::unimplemented;
+            if (!ok) return status::unimplemented;
+
+            auto scratchpad = scratchpad_registry().registrar();
+            return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
+                    *desc(), src_pd(), weights_pd(0), dst_pd(),
+                    mkldnn_get_max_threads());
         }
 
         jit_gemm_conv_conf_t jcp_;
 
     protected:
+        memory_format_t src_format() const {
+            using namespace memory_format;
+            const int ndims_sp = this->desc()->src_desc.ndims - 2;
+            return (utils::pick(ndims_sp - 1, ncw, nchw, ncdhw));
+        }
+
+        memory_format_t wei_format() const {
+            using namespace memory_format;
+            const int ndims_sp = this->desc()->src_desc.ndims - 2;
+            return (this->with_groups()
+                ? utils::pick(ndims_sp - 1, goiw, goihw, goidhw)
+                : utils::pick(ndims_sp - 1, oiw, oihw, oidhw));
+        }
+
         virtual status_t set_default_params() override {
             using namespace memory_format;
             if (this->src_pd_.desc()->format == any)
@@ -97,11 +100,12 @@ struct _gemm_convolution_fwd_t: public cpu_primitive_t {
                 CHECK(this->weights_pd_.set_format(wei_format()));
             if (this->bias_pd_.desc()->format == any)
                 CHECK(this->bias_pd_.set_format(x));
+            if (this->desc()->alg_kind == alg_kind::convolution_auto)
+                CHECK(this->set_alg_kind(alg_kind::convolution_direct));
             return status::success;
         }
 
         virtual bool is_gemm_conv_format() const {
-            bool ok = true;
             auto const &po = this->attr()->post_ops_;
 
             auto is_eltwise = [&](int idx) { return po.entry_[idx].is_eltwise(); };
@@ -110,48 +114,24 @@ struct _gemm_convolution_fwd_t: public cpu_primitive_t {
             auto is_simple = [&](int idx) { return (is_eltwise(idx) || is_depthwise(idx)); };
 
             switch (po.len_) {
-                using namespace mkldnn::impl::primitive_kind;
-            case 0: // no post_ops
-                break;
-            case 1:
-                ok = ok && // sum OR eltwise/depthwise
-                        (is_simple(0) || is_sum(0));
-                break;
-            case 2:
-                ok = ok && // sum->eltwise/depthwise OR eltwise/depthwise->eltwise/depthwise
-                           ((is_sum(0) && is_simple(1)) || (is_simple(0) && is_simple(1)));
-                break;
-            case 3:
-                ok = ok && // sum->eltwise/depthwise->eltwise/depthwise
-                     (is_sum(0) && is_simple(1) && is_simple(2));
-                break;
-
-            default: ok = false;
+            case 0: return true;
+            case 1: return is_simple(0) || is_sum(0);
+            case 2: return (is_sum(0) && is_simple(1)) || (is_simple(0) && is_simple(1));
+            case 3: return is_sum(0) && is_simple(1) && is_simple(2);
+            default: return false;
             }
-            return ok;
+            return false;
         }
     };
 
-    _gemm_convolution_fwd_t(const pd_t *pd, const input_vector &inputs,
+    gemm_convolution_fwd_t(const pd_t *apd, const input_vector &inputs,
            const output_vector &outputs)
-        : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd)
-        , scratchpad_(nullptr)
+        : cpu_primitive_t(apd, inputs, outputs, true)
     {
-        using namespace prop_kind;
-
-        const auto &post_ops = conf_.attr()->post_ops_;
+        const auto &post_ops = pd()->attr()->post_ops_;
         const data_t one = 1.0, zero = 0.0;
         beta_ = post_ops.find(primitive_kind::sum) >= 0 ? one : zero;
 
-        jit_gemm_convolution_utils::init_conf(conf_.jcp_,
-            *(conf_.cdesc()), conf_.src_pd(), conf_.weights_pd(0),
-            conf_.dst_pd(), mkldnn_get_max_threads(), with_relu,
-            conf_.negative_slope());
-
-        size_t size = (size_t)conf_.jcp_.im2col_sz * sizeof(data_t);
-        jit_gemm_convolution_utils::prepare_scratchpad(this->conf_.jcp_,
-                &this->scratchpad_, size, this->conf_.jcp_.nthr);
-
         for (int i = 0; i < post_ops.len_; i++) {
             auto &post_op = post_ops.entry_[i];
             if (post_op.is_eltwise()) {
@@ -168,10 +148,7 @@ struct _gemm_convolution_fwd_t: public cpu_primitive_t {
         }
 
         use_fast_relu = false;
-        if (conf_.jcp_.with_relu && post_ops.len_ == 0) {
-            use_fast_relu = true;
-            fast_relu_ns = conf_.jcp_.relu_negative_slope;
-        } else if (post_ops.len_ == 1 && post_ops.entry_[0].is_relu(true, false)) {
+        if (post_ops.len_ == 1 && post_ops.entry_[0].is_relu(true, false)) {
             use_fast_relu = true;
             fast_relu_ns = post_ops.entry_[0].eltwise.alpha;
         } else if (post_ops.len_ == 2 && post_ops.entry_[0].is_sum() && post_ops.entry_[1].is_relu(true, false)) {
@@ -180,9 +157,7 @@ struct _gemm_convolution_fwd_t: public cpu_primitive_t {
         }
     }
 
-    ~_gemm_convolution_fwd_t() {
-        delete this->scratchpad_;
-
+    ~gemm_convolution_fwd_t() {
         for (auto inj : eltwise_injectors)
             delete inj;
         eltwise_injectors.clear();
@@ -190,19 +165,19 @@ struct _gemm_convolution_fwd_t: public cpu_primitive_t {
         for (auto inj : depthwise_injectors)
             delete inj;
         depthwise_injectors.clear();
-    };
+    }
 
     typedef typename prec_traits<data_type::f32>::type data_t;
 
-    virtual void execute(event_t *e) {
+    virtual void execute(event_t *e) const {
         execute_forward();
         e->set_state(event_t::ready);
     }
 
 private:
-    void execute_forward();
-    pd_t conf_;
-    scratchpad_t *scratchpad_;
+    void execute_forward() const;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
     data_t beta_;
 
     nstl::vector<ref_eltwise_scalar_fwd_t*> eltwise_injectors;
@@ -212,39 +187,16 @@ private:
     float fast_relu_ns;
 };
 
-using gemm_convolution_fwd_t =
-                         _gemm_convolution_fwd_t<false>;
-using gemm_convolution_relu_t =
-                         _gemm_convolution_fwd_t<true>;
-
 struct gemm_convolution_bwd_data_t: public cpu_primitive_t {
     struct pd_t: public cpu_convolution_bwd_data_pd_t {
         pd_t(engine_t *engine,
-                const convolution_desc_t *adesc,
-                const primitive_attr_t *attr,
+                const convolution_desc_t *adesc, const primitive_attr_t *attr,
                 const convolution_fwd_pd_t *hint_fwd_pd)
             : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
-            , jcp_()
-        {}
+            , jcp_() {}
 
         DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_bwd_data_t);
 
-        inline memory_format_t src_format()
-        {
-            using namespace memory_format;
-            return (utils::pick(this->desc()->diff_src_desc.ndims - 3,
-                ncw, nchw, ncdhw));
-        }
-        inline memory_format_t wei_format()
-        {
-            using namespace memory_format;
-            return (this->with_groups()
-                ? utils::pick(this->desc()->diff_src_desc.ndims - 3,
-                    goiw, goihw, goidhw)
-                : utils::pick(this->desc()->diff_src_desc.ndims - 3,
-                    oiw, oihw, oidhw));
-        }
-
         virtual status_t init() override {
             using namespace prop_kind;
             using namespace memory_format;
@@ -254,7 +206,8 @@ struct gemm_convolution_bwd_data_t: public cpu_primitive_t {
             bool ok = true
                 && this->set_default_params() == status::success
                 && this->desc()->prop_kind == backward_data
-                && this->desc()->alg_kind == alg_kind::convolution_direct
+                && utils::one_of(this->desc()->alg_kind, alg_kind::convolution_auto,
+                           alg_kind::convolution_direct)
                 && !this->has_zero_dim_memory()
                 && utils::everyone_is(data_type::f32,
                         this->desc()->diff_src_desc.data_type,
@@ -263,12 +216,31 @@ struct gemm_convolution_bwd_data_t: public cpu_primitive_t {
                 && this->diff_src_pd_.desc()->format == src_format()
                 && this->diff_dst_pd_.desc()->format == src_format()
                 && this->weights_pd_.desc()->format == wei_format();
-            return ok ? status::success : status::unimplemented;
+            if (!ok) return status::unimplemented;
+
+            auto scratchpad = scratchpad_registry().registrar();
+            return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
+                    *desc(), diff_src_pd(), weights_pd(0), diff_dst_pd(),
+                    mkldnn_get_max_threads());
         }
 
         jit_gemm_conv_conf_t jcp_;
 
     protected:
+        memory_format_t src_format() const {
+            using namespace memory_format;
+            const int ndims_sp = this->desc()->diff_src_desc.ndims - 2;
+            return (utils::pick(ndims_sp - 1, ncw, nchw, ncdhw));
+        }
+
+        memory_format_t wei_format() const {
+            using namespace memory_format;
+            const int ndims_sp = this->desc()->diff_src_desc.ndims - 2;
+            return (this->with_groups()
+                ? utils::pick(ndims_sp - 1, goiw, goihw, goidhw)
+                : utils::pick(ndims_sp - 1, oiw, oihw, oidhw));
+        }
+
         virtual status_t set_default_params() override {
             using namespace memory_format;
             if (this->diff_src_pd_.desc()->format == any)
@@ -277,34 +249,21 @@ struct gemm_convolution_bwd_data_t: public cpu_primitive_t {
                 CHECK(this->diff_dst_pd_.set_format(src_format()));
             if (this->weights_pd_.desc()->format == any)
                 CHECK(this->weights_pd_.set_format(wei_format()));
+            if (this->desc()->alg_kind == alg_kind::convolution_auto)
+                CHECK(this->set_alg_kind(alg_kind::convolution_direct));
             return status::success;
         }
     };
 
-    gemm_convolution_bwd_data_t(const pd_t *pd, const input_vector &inputs,
+    gemm_convolution_bwd_data_t(const pd_t *apd, const input_vector &inputs,
               const output_vector &outputs)
-        : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd)
-        , scratchpad_(nullptr)
-    {
-        using namespace prop_kind;
-
-        jit_gemm_convolution_utils::init_conf(conf_.jcp_,
-            *(conf_.desc()), conf_.diff_src_pd(), conf_.weights_pd(0),
-            conf_.diff_dst_pd(), mkldnn_get_max_threads());
-
-        size_t size = (size_t)conf_.jcp_.im2col_sz * sizeof(data_t);
-        jit_gemm_convolution_utils::prepare_scratchpad(this->conf_.jcp_,
-                &this->scratchpad_, size, this->conf_.jcp_.nthr);
-    }
-
-    ~gemm_convolution_bwd_data_t() {
-        delete this->scratchpad_;
-    };
+        : cpu_primitive_t(apd, inputs, outputs, true) {}
+    ~gemm_convolution_bwd_data_t() {}
 
     typedef typename prec_traits<data_type::f32>::type data_t;
 
-    virtual void execute(event_t *e) {
-        switch (conf_.desc()->prop_kind) {
+    virtual void execute(event_t *e) const {
+        switch (pd()->desc()->prop_kind) {
         case prop_kind::backward_data:
             execute_backward_data();
             break;
@@ -315,9 +274,8 @@ struct gemm_convolution_bwd_data_t: public cpu_primitive_t {
     }
 
 private:
-    void execute_backward_data();
-    pd_t conf_;
-    scratchpad_t *scratchpad_;
+    void execute_backward_data() const;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
 };
 
 struct gemm_convolution_bwd_weights_t: public cpu_primitive_t {
@@ -327,27 +285,10 @@ struct gemm_convolution_bwd_weights_t: public cpu_primitive_t {
                 const primitive_attr_t *attr,
                 const convolution_fwd_pd_t *hint_fwd_pd)
             : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
-            , jcp_()
-        {}
+            , jcp_() {}
 
         DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_bwd_weights_t);
 
-        inline memory_format_t src_format()
-        {
-            using namespace memory_format;
-            return (utils::pick(this->desc()->src_desc.ndims - 3,
-                ncw, nchw, ncdhw));
-        }
-        inline memory_format_t wei_format()
-        {
-            using namespace memory_format;
-            return (this->with_groups()
-                ? utils::pick(this->desc()->src_desc.ndims - 3,
-                    goiw, goihw, goidhw)
-                : utils::pick(this->desc()->src_desc.ndims - 3,
-                    oiw, oihw, oidhw));
-        }
-
         virtual status_t init() override {
             using namespace prop_kind;
             using namespace memory_format;
@@ -357,7 +298,8 @@ struct gemm_convolution_bwd_weights_t: public cpu_primitive_t {
             bool ok = true
             && this->set_default_params() == status::success
             && this->desc()->prop_kind == backward_weights
-            && this->desc()->alg_kind == alg_kind::convolution_direct
+            && utils::one_of(this->desc()->alg_kind, alg_kind::convolution_auto,
+                       alg_kind::convolution_direct)
             && !this->has_zero_dim_memory()
             && utils::everyone_is(data_type::f32,
                     this->desc()->src_desc.data_type,
@@ -368,12 +310,31 @@ struct gemm_convolution_bwd_weights_t: public cpu_primitive_t {
             && this->src_pd_.desc()->format == src_format()
             && this->diff_dst_pd_.desc()->format == src_format()
             && this->diff_weights_pd_.desc()->format == wei_format();
-            return ok ? status::success : status::unimplemented;
+            if (!ok) return status::unimplemented;
+
+            auto scratchpad = scratchpad_registry().registrar();
+            return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
+                    *desc(), src_pd(), diff_weights_pd(0), diff_dst_pd(),
+                    mkldnn_get_max_threads());
         }
 
         jit_gemm_conv_conf_t jcp_;
 
     protected:
+        memory_format_t src_format() const {
+            using namespace memory_format;
+            const int ndims_sp = this->desc()->src_desc.ndims - 2;
+            return (utils::pick(ndims_sp - 1, ncw, nchw, ncdhw));
+        }
+
+        memory_format_t wei_format() const {
+            using namespace memory_format;
+            const int ndims_sp = this->desc()->src_desc.ndims - 2;
+            return (this->with_groups()
+                ? utils::pick(ndims_sp - 1, goiw, goihw, goidhw)
+                : utils::pick(ndims_sp - 1, oiw, oihw, oidhw));
+        }
+
         virtual status_t set_default_params() override {
             using namespace memory_format;
             if (this->src_pd_.desc()->format == any)
@@ -384,38 +345,21 @@ struct gemm_convolution_bwd_weights_t: public cpu_primitive_t {
                 CHECK(this->diff_weights_pd_.set_format(wei_format()));
             if (this->diff_bias_pd_.desc()->format == any)
                 CHECK(this->diff_bias_pd_.set_format(x));
+            if (this->desc()->alg_kind == alg_kind::convolution_auto)
+                CHECK(this->set_alg_kind(alg_kind::convolution_direct));
             return status::success;
         }
     };
 
-    gemm_convolution_bwd_weights_t(const pd_t *pd, const input_vector &inputs,
+    gemm_convolution_bwd_weights_t(const pd_t *apd, const input_vector &inputs,
               const output_vector &outputs)
-        : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd)
-        , scratchpad_(nullptr)
-    {
-        using namespace prop_kind;
-
-        jit_gemm_convolution_utils::init_conf(conf_.jcp_,
-            *(conf_.desc()), conf_.src_pd(), conf_.diff_weights_pd(0),
-            conf_.diff_dst_pd(), mkldnn_get_max_threads());
-        const memory_desc_wrapper weights_d(conf_.diff_weights_pd(0));
-
-        size_t size = (size_t)conf_.jcp_.im2col_sz  * sizeof(data_t);
-        if (conf_.jcp_.need_wei_reduction)
-            size += (size_t)conf_.jcp_.ngroups * weights_d.size();
-
-        jit_gemm_convolution_utils::prepare_scratchpad(this->conf_.jcp_,
-                &this->scratchpad_, size, conf_.jcp_.nthr);
-    }
-
-    ~gemm_convolution_bwd_weights_t() {
-        delete this->scratchpad_;
-     };
+        : cpu_primitive_t(apd, inputs, outputs, true) {}
+    ~gemm_convolution_bwd_weights_t() {}
 
     typedef typename prec_traits<data_type::f32>::type data_t;
 
-    virtual void execute(event_t *e) {
-        switch (conf_.desc()->prop_kind) {
+    virtual void execute(event_t *e) const {
+        switch (pd()->desc()->prop_kind) {
         case prop_kind::backward_weights:
             execute_backward_weights();
             break;
@@ -426,9 +370,8 @@ struct gemm_convolution_bwd_weights_t: public cpu_primitive_t {
     }
 
 private:
-    void execute_backward_weights();
-    pd_t conf_;
-    scratchpad_t *scratchpad_;
+    void execute_backward_weights() const;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
 };
 
 }