Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_concat.hpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef REF_CONCAT_HPP
18 #define REF_CONCAT_HPP
19
20 #include "cpu_concat.hpp"
21 #include "reorder_pd.hpp"
22
23 namespace mkldnn {
24 namespace impl {
25 namespace cpu {
26
27 struct ref_concat_t: public cpu_primitive_t {
28     using cpu_memory_pd_t = cpu_memory_t::pd_t;
29
30     struct pd_t: public cpu_concat_pd_t {
31         pd_t(const memory_desc_t *output_d, int n, int concat_dim,
32                 const cpu_memory_pd_t **input_pds, const primitive_attr_t *attr)
33             : cpu_concat_pd_t(output_d, n, concat_dim, input_pds, attr) {}
34         pd_t(const pd_t &rhs)
35             : cpu_concat_pd_t(rhs)
36         {
37             for (size_t i = 0; i < rhs.reorder_pds_.size(); ++i) {
38                 reorder_pds_.push_back(
39                         (const reorder_pd_t *)rhs.reorder_pds_[i]->clone());
40             }
41         }
42
43         ~pd_t() {
44             for (size_t i = 0; i < reorder_pds_.size(); ++i) {
45                 delete reorder_pds_[i];
46             }
47         }
48
49         static status_t create(concat_pd_t **concat_pd,
50                 const memory_desc_t *output_d, int n, int concat_dim,
51                 const memory_pd_t **input_pds, const primitive_attr_t *attr) {
52             auto _pd = new pd_t(output_d, n, concat_dim,
53                     (const cpu_memory_pd_t **)input_pds, attr);
54             if (_pd == nullptr) return out_of_memory;
55             if (_pd->init() != success) { delete _pd; return unimplemented; }
56             return safe_ptr_assign<concat_pd_t>(*concat_pd, _pd);
57         }
58         virtual status_t create_primitive(primitive_t **primitive,
59                 const primitive_at_t *inputs,
60                 const primitive_t **outputs) const override {
61             double ms = get_msec();
62             auto n = n_inputs();
63             nstl::vector<primitive_t *> reorders;
64             reorders.resize(n);
65             for (int i = 0; i < n; ++i) {
66                 CHECK(reorder_pds_[i]->create_primitive(&reorders[i],
67                         &inputs[i], outputs));
68             }
69             primitive_t::input_vector ins(inputs, inputs + n_);
70             primitive_t::output_vector outs(outputs, outputs + 1);
71             auto ret = safe_ptr_assign<primitive_t>(*primitive,
72                     new ref_concat_t(this, ins, outs, reorders));
73             ms = get_msec() - ms;
74             if (mkldnn_verbose()->level >= 2) {
75                 printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms);
76                 fflush(0);
77             }
78             return ret;
79         }
80         virtual pd_t *clone() const override { return  new pd_t(*this); }
81         virtual const char *name() const override { return "ref:any"; }
82
83         virtual status_t init() override {
84             assert(engine()->kind() == engine_kind::cpu);
85
86             bool ok = cpu_concat_pd_t::init() == success;
87             if (!ok) return unimplemented;
88
89             for (int i = 0; i < n_; ++i) {
90                 auto r_impls = engine_->get_reorder_implementation_list();
91                 for (auto r = r_impls; *r; ++r) {
92                     const primitive_attr_t dummy_attr; /* alpha == 1. */
93                     reorder_pd_t *r_pd;
94                     if ((*r)(&r_pd, &src_pds_[i], &src_image_pds_[i],
95                                 &dummy_attr) == status::success) {
96                         r_pd->init_info();
97                         reorder_pds_.push_back(r_pd);
98                         break;
99                     }
100                 }
101             }
102             return (size_t)n_ == reorder_pds_.size() ? success : unimplemented;
103         }
104
105         nstl::vector<const reorder_pd_t *> reorder_pds_;
106     };
107
108     ref_concat_t(const pd_t *apd, const input_vector &inputs,
109             const output_vector &outputs, nstl::vector<primitive_t *> reorders)
110         : cpu_primitive_t(apd, inputs, outputs),
111         reorders_(reorders) {}
112
113     ~ref_concat_t() {
114         const auto n = reorders_.size();
115         for (size_t i = 0; i < n; ++i)
116             delete reorders_[i];
117     }
118
119     virtual void execute(event_t *e) const {
120         for (size_t i = 0; i < reorders_.size(); ++i) {
121             event_t ei;
122             reorders_[i]->execute(&ei);
123         }
124         e->set_state(event_t::ready);
125     }
126
127 private:
128     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
129     nstl::vector<primitive_t *> reorders_;
130 };
131
132 }
133 }
134 }
135
136 #endif