Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / simple_concat.hpp
index 45193b2..84946da 100644 (file)
@@ -17,6 +17,8 @@
 #ifndef SIMPLE_CONCAT_HPP
 #define SIMPLE_CONCAT_HPP
 
+#include "memory_tracking.hpp"
+
 #include "cpu_concat.hpp"
 
 namespace mkldnn {
@@ -28,29 +30,25 @@ struct simple_concat_t: public cpu_primitive_t {
     using cpu_memory_pd_t = cpu_memory_t::pd_t;
 
     struct pd_t: public cpu_concat_pd_t {
-        pd_t(const memory_desc_t *output_d, int n,
-                int concat_dim, const cpu_memory_pd_t **input_pds,
+        pd_t(const memory_desc_t *output_d, int n, int concat_dim,
+                const cpu_memory_pd_t **input_pds,
                 const primitive_attr_t *attr)
-            : cpu_concat_pd_t(output_d, n, concat_dim, input_pds, attr)
-        {}
+            : cpu_concat_pd_t(output_d, n, concat_dim, input_pds, attr) {}
+
         pd_t(const pd_t &rhs) : cpu_concat_pd_t(rhs) {
             for (size_t i = 0; i < sizeof(perm_)/sizeof(perm_[0]); i++) {
                 perm_[i] = rhs.perm_[i];
                 iperm_[i] = rhs.iperm_[i];
             }
         }
+
         DECLARE_CPU_CONCAT_PD_T("simple:any", simple_concat_t);
 
         virtual status_t init() override {
-            auto is_dense = [&](const memory_desc_wrapper &data_d) {
-                return nelems_to_concat(concat_dim_, perm_, iperm_, data_d)
-                        == _size_to_concat(concat_dim_, perm_, iperm_, data_d);
-            };
             const memory_desc_wrapper dst_d(&dst_pd_);
             bool ok = true
                 && cpu_concat_pd_t::init() == success
                 && dst_d.ndims() <= 6;
-
             if (!ok) return unimplemented;
 
             for (size_t i = 0; i < src_pds_.size(); ++i) {
@@ -61,118 +59,110 @@ struct simple_concat_t: public cpu_primitive_t {
                             o_d.data_type())
                     && i_d.format() == o_d.format()
                     && !utils::one_of(i_d.format(), memory_format::blocked,
-                        memory_format::wino_fmt)
+                            memory_format::wino_fmt)
                     && !i_d.is_additional_buffer();
+                if (!ok) return unimplemented;
             }
 
-            if (!ok)
-                return unimplemented;
-
-            format_perm(dst_d.ndims(), dst_d.blocking_desc().strides[0], perm_,
-                    iperm_);
+            format_perm();
 
+            // density check
             for (size_t i = 0; i < src_pds_.size(); ++i) {
                 const memory_desc_wrapper i_d(&src_pds_[i]);
                 const memory_desc_wrapper o_d(&src_image_pds_[i]);
-                ok = ok && is_dense(i_d) && is_dense(o_d);
+                ok = ok
+                    && nelems_to_concat(i_d) == size_to_concat(i_d)
+                    && nelems_to_concat(o_d) == size_to_concat(o_d);
+                if (!ok) return unimplemented;
             }
 
-            return ok ? success : unimplemented;
+            init_scratchpad();
+
+            return success;
         }
+
         dims_t perm_;
         dims_t iperm_;
-    };
 
-    simple_concat_t(const pd_t *conf, const input_vector &inputs,
-            const output_vector &outputs)
-        : cpu_primitive_t(&conf_, inputs, outputs), conf_(*conf)
-    {
-        const int n = conf_.n_inputs();
-        input_ptrs_ = (decltype(input_ptrs_))malloc(
-                sizeof(*input_ptrs_) * n, 64);
-        output_ptrs_ = (decltype(output_ptrs_))malloc(
-                sizeof(*output_ptrs_) * n, 64);
-        nelems_to_copy_ = (decltype(nelems_to_copy_))malloc(
-                sizeof(*nelems_to_copy_) * n, 64);
-        is_ = (decltype(is_))malloc(sizeof(*is_) * n, 64);
-    }
+        size_t nelems_to_concat(const memory_desc_wrapper &data_d) const {
+            const int ndims = data_d.ndims();
+            auto &blk = data_d.blocking_desc();
 
-    ~simple_concat_t() {
-        free(input_ptrs_);
-        free(output_ptrs_);
-        free(nelems_to_copy_);
-        free(is_);
-    }
+            size_t nelems = 1;
+            for (int i = perm_[concat_dim()]; i < ndims; i++)
+                nelems *= data_d.dims()[iperm_[i]] / blk.block_dims[iperm_[i]];
+            for (int i = 0; i < ndims; i++)
+                nelems *= blk.block_dims[i];
 
-    virtual void execute(event_t *e) {
-        execute();
-        e->set_state(event_t::ready);
-    }
+            return nelems;
+        }
 
-    typedef typename prec_traits<data_type>::type data_t;
+    private:
+        void format_perm() {
+            const memory_desc_wrapper dst_d(&dst_pd_);
+            const int ndims = dst_d.ndims();
 
-private:
-    static void format_perm(
-            const int ndims, const stride_t *strides, int *perm, int *iperm) {
-        assert(ndims >= 0);
-        bool swapped;
-        strides_t strides_tmp;
-        utils::array_copy(strides_tmp, strides, ndims);
-        for (int i = 0; i < ndims; i++)
-            iperm[i] = i;
-        for (int i = 0; i < ndims - 1; i++) {
-            swapped = false;
-            for (int j = 0; j < ndims - i - 1; j++) {
-                if (strides_tmp[j] < strides_tmp[j + 1]) {
-                    nstl::swap(strides_tmp[j], strides_tmp[j + 1]);
-                    nstl::swap(iperm[j], iperm[j + 1]);
-                    swapped = true;
+            strides_t strides;
+            utils::array_copy(strides, dst_d.blocking_desc().strides[0], ndims);
+
+            for (int i = 0; i < ndims; i++) iperm_[i] = i;
+
+            for (int i = 0; i < ndims - 1; i++) {
+                bool swapped = false;
+                for (int j = 0; j < ndims - i - 1; j++) {
+                    if (strides[j] < strides[j + 1]) {
+                        nstl::swap(strides[j], strides[j + 1]);
+                        nstl::swap(iperm_[j], iperm_[j + 1]);
+                        swapped = true;
+                    }
                 }
+                if (swapped == false)
+                    break;
             }
-            if (swapped == false)
-                break;
-        }
-        for (int i = 0; i < ndims; i++)
-            perm[iperm[i]] = i;
-    }
 
-    static size_t nelems_to_concat(const int concat_dim, int *perm, int *iperm,
-            const memory_desc_wrapper &data_d) {
-        const int ndims = data_d.ndims();
-        auto &blk = data_d.blocking_desc();
-        int nelems = 1;
-        for (int i = perm[concat_dim]; i < ndims; i++) {
-            nelems *= data_d.dims()[iperm[i]] / blk.block_dims[iperm[i]];
+            for (int i = 0; i < ndims; i++) perm_[iperm_[i]] = i;
         }
-        for (int i = 0; i < ndims; i++) {
-            nelems *= blk.block_dims[i];
-        }
-        return nelems;
-    }
 
-    static size_t _size_to_concat(const int concat_dim, int *perm, int *iperm,
-            const memory_desc_wrapper &data_d) {
-        size_t max_size = 0;
-        auto &blk = data_d.blocking_desc();
-        for (int d = perm[concat_dim]; d < data_d.ndims(); ++d) {
-            auto block = blk.block_dims[iperm[d]];
-            max_size = nstl::max(max_size,
-                    size_t(blk.padding_dims[iperm[d]] / block)
-                            * blk.strides[0][iperm[d]]);
-            if (block > 1)
+        size_t size_to_concat(const memory_desc_wrapper &data_d) const {
+            size_t max_size = 0;
+            auto &blk = data_d.blocking_desc();
+            for (int d = perm_[concat_dim()]; d < data_d.ndims(); ++d) {
+                auto block = blk.block_dims[iperm_[d]];
                 max_size = nstl::max(max_size,
-                        size_t(block * blk.strides[1][iperm[d]]));
+                        size_t(blk.padding_dims[iperm_[d]] / block)
+                        * blk.strides[0][iperm_[d]]);
+                if (block > 1) max_size = nstl::max(max_size,
+                        size_t(block * blk.strides[1][iperm_[d]]));
+            }
+            return max_size;
+        }
+
+        void init_scratchpad() {
+            using namespace memory_tracking::names;
+            auto scratchpad = scratchpad_registry().registrar();
+            scratchpad.book(key_concat_iptrs, sizeof(data_t *) * n_inputs());
+            scratchpad.book(key_concat_optrs, sizeof(data_t *) * n_inputs());
+            scratchpad.book(key_concat_nelems, sizeof(size_t) * n_inputs());
+            scratchpad.book(key_concat_istrides,
+                    sizeof(strides_t) * n_inputs());
         }
-        return max_size;
+    };
+
+    simple_concat_t(const pd_t *apd, const input_vector &inputs,
+            const output_vector &outputs)
+        : cpu_primitive_t(apd, inputs, outputs) {}
+    ~simple_concat_t() {}
+
+    virtual void execute(event_t *e) const {
+        execute();
+        e->set_state(event_t::ready);
     }
 
-    void execute();
-    pd_t conf_;
+    typedef typename prec_traits<data_type>::type data_t;
 
-    const data_t **input_ptrs_ = nullptr;
-    data_t **output_ptrs_ = nullptr;
-    size_t *nelems_to_copy_ = nullptr;
-    strides_t *is_ = nullptr;
+private:
+    void execute() const;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
 };
 
 }