Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / simple_concat.hpp
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16
17 #ifndef SIMPLE_CONCAT_HPP
18 #define SIMPLE_CONCAT_HPP
19
20 #include "memory_tracking.hpp"
21
22 #include "cpu_concat.hpp"
23
24 namespace mkldnn {
25 namespace impl {
26 namespace cpu {
27
28 template <data_type_t data_type>
29 struct simple_concat_t: public cpu_primitive_t {
30     using cpu_memory_pd_t = cpu_memory_t::pd_t;
31
32     struct pd_t: public cpu_concat_pd_t {
33         pd_t(const memory_desc_t *output_d, int n, int concat_dim,
34                 const cpu_memory_pd_t **input_pds,
35                 const primitive_attr_t *attr)
36             : cpu_concat_pd_t(output_d, n, concat_dim, input_pds, attr) {}
37
38         pd_t(const pd_t &rhs) : cpu_concat_pd_t(rhs) {
39             for (size_t i = 0; i < sizeof(perm_)/sizeof(perm_[0]); i++) {
40                 perm_[i] = rhs.perm_[i];
41                 iperm_[i] = rhs.iperm_[i];
42             }
43         }
44
45         DECLARE_CPU_CONCAT_PD_T("simple:any", simple_concat_t);
46
47         virtual status_t init() override {
48             const memory_desc_wrapper dst_d(&dst_pd_);
49             bool ok = true
50                 && cpu_concat_pd_t::init() == success
51                 && dst_d.ndims() <= 6;
52             if (!ok) return unimplemented;
53
54             for (size_t i = 0; i < src_pds_.size(); ++i) {
55                 const memory_desc_wrapper i_d(&src_pds_[i]);
56                 const memory_desc_wrapper o_d(&src_image_pds_[i]);
57                 ok = ok
58                     && utils::everyone_is(data_type, i_d.data_type(),
59                             o_d.data_type())
60                     && i_d.format() == o_d.format()
61                     && !utils::one_of(i_d.format(), memory_format::blocked,
62                             memory_format::wino_fmt)
63                     && !i_d.is_additional_buffer();
64                 if (!ok) return unimplemented;
65             }
66
67             format_perm();
68
69             // density check
70             for (size_t i = 0; i < src_pds_.size(); ++i) {
71                 const memory_desc_wrapper i_d(&src_pds_[i]);
72                 const memory_desc_wrapper o_d(&src_image_pds_[i]);
73                 ok = ok
74                     && nelems_to_concat(i_d) == size_to_concat(i_d)
75                     && nelems_to_concat(o_d) == size_to_concat(o_d);
76                 if (!ok) return unimplemented;
77             }
78
79             init_scratchpad();
80
81             return success;
82         }
83
84         dims_t perm_;
85         dims_t iperm_;
86
87         size_t nelems_to_concat(const memory_desc_wrapper &data_d) const {
88             const int ndims = data_d.ndims();
89             auto &blk = data_d.blocking_desc();
90
91             size_t nelems = 1;
92             for (int i = perm_[concat_dim()]; i < ndims; i++)
93                 nelems *= data_d.dims()[iperm_[i]] / blk.block_dims[iperm_[i]];
94             for (int i = 0; i < ndims; i++)
95                 nelems *= blk.block_dims[i];
96
97             return nelems;
98         }
99
100     private:
101         void format_perm() {
102             const memory_desc_wrapper dst_d(&dst_pd_);
103             const int ndims = dst_d.ndims();
104
105             strides_t strides;
106             utils::array_copy(strides, dst_d.blocking_desc().strides[0], ndims);
107
108             for (int i = 0; i < ndims; i++) iperm_[i] = i;
109
110             for (int i = 0; i < ndims - 1; i++) {
111                 bool swapped = false;
112                 for (int j = 0; j < ndims - i - 1; j++) {
113                     if (strides[j] < strides[j + 1]) {
114                         nstl::swap(strides[j], strides[j + 1]);
115                         nstl::swap(iperm_[j], iperm_[j + 1]);
116                         swapped = true;
117                     }
118                 }
119                 if (swapped == false)
120                     break;
121             }
122
123             for (int i = 0; i < ndims; i++) perm_[iperm_[i]] = i;
124         }
125
126         size_t size_to_concat(const memory_desc_wrapper &data_d) const {
127             size_t max_size = 0;
128             auto &blk = data_d.blocking_desc();
129             for (int d = perm_[concat_dim()]; d < data_d.ndims(); ++d) {
130                 auto block = blk.block_dims[iperm_[d]];
131                 max_size = nstl::max(max_size,
132                         size_t(blk.padding_dims[iperm_[d]] / block)
133                         * blk.strides[0][iperm_[d]]);
134                 if (block > 1) max_size = nstl::max(max_size,
135                         size_t(block * blk.strides[1][iperm_[d]]));
136             }
137             return max_size;
138         }
139
140         void init_scratchpad() {
141             using namespace memory_tracking::names;
142             auto scratchpad = scratchpad_registry().registrar();
143             scratchpad.book(key_concat_iptrs, sizeof(data_t *) * n_inputs());
144             scratchpad.book(key_concat_optrs, sizeof(data_t *) * n_inputs());
145             scratchpad.book(key_concat_nelems, sizeof(size_t) * n_inputs());
146             scratchpad.book(key_concat_istrides,
147                     sizeof(strides_t) * n_inputs());
148         }
149     };
150
151     simple_concat_t(const pd_t *apd, const input_vector &inputs,
152             const output_vector &outputs)
153         : cpu_primitive_t(apd, inputs, outputs) {}
154     ~simple_concat_t() {}
155
156     virtual void execute(event_t *e) const {
157         execute();
158         e->set_state(event_t::ready);
159     }
160
161     typedef typename prec_traits<data_type>::type data_t;
162
163 private:
164     void execute() const;
165     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
166 };
167
168 }
169 }
170 }
171
172 #endif