1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef SIMPLE_CONCAT_HPP
18 #define SIMPLE_CONCAT_HPP
20 #include "memory_tracking.hpp"
22 #include "cpu_concat.hpp"
// Primitive class template parameterised by the element data type
// `data_type`; derives from cpu_primitive_t.
28 template <data_type_t data_type>
29 struct simple_concat_t: public cpu_primitive_t {
// Short alias for the memory primitive-descriptor type; used for the
// `input_pds` parameter of pd_t's constructor.
30 using cpu_memory_pd_t = cpu_memory_t::pd_t;
// Primitive descriptor for simple_concat_t; extends cpu_concat_pd_t.
32 struct pd_t: public cpu_concat_pd_t {
// Forwards the output descriptor, the input count `n`, the concat dimension,
// the input primitive descriptors and the attributes unchanged to the
// cpu_concat_pd_t constructor. Does no other work.
33 pd_t(const memory_desc_t *output_d, int n, int concat_dim,
34 const cpu_memory_pd_t **input_pds,
35 const primitive_attr_t *attr)
36 : cpu_concat_pd_t(output_d, n, concat_dim, input_pds, attr) {}
// Copy constructor: copy-constructs the cpu_concat_pd_t base from `rhs`, then
// copies perm_ and iperm_ element by element.
38 pd_t(const pd_t &rhs) : cpu_concat_pd_t(rhs) {
// The bound is the element count of perm_; iperm_ is indexed with the same
// bound, so it is presumably declared with the same length — TODO confirm at
// the member declarations.
39 for (size_t i = 0; i < sizeof(perm_)/sizeof(perm_[0]); i++) {
40 perm_[i] = rhs.perm_[i];
41 iperm_[i] = rhs.iperm_[i];
// Project macro invoked with the string "simple:any" and the primitive type
// simple_concat_t. Its expansion is defined outside this file (presumably in
// cpu_concat.hpp — confirm there what members it generates).
45 DECLARE_CPU_CONCAT_PD_T("simple:any", simple_concat_t);
// Decides whether this implementation can handle the configured concat.
// Returns `unimplemented` as soon as one of the checks below fails.
47 virtual status_t init() override {
48 const memory_desc_wrapper dst_d(&dst_pd_);
// The base-class init() must report success and the destination may have at
// most 6 dimensions.
50 && cpu_concat_pd_t::init() == success
51 && dst_d.ndims() <= 6;
52 if (!ok) return unimplemented;
// Per-input checks: i_d wraps the i-th entry of src_pds_, o_d the matching
// entry of src_image_pds_.
54 for (size_t i = 0; i < src_pds_.size(); ++i) {
55 const memory_desc_wrapper i_d(&src_pds_[i]);
56 const memory_desc_wrapper o_d(&src_image_pds_[i]);
// The source data type must equal the template parameter data_type (any
// further arguments of this everyone_is call are not visible here).
58 && utils::everyone_is(data_type, i_d.data_type(),
// i_d and o_d must report the same format; that format may be neither
// `blocked` nor `wino_fmt`, and i_d must not report an additional buffer.
60 && i_d.format() == o_d.format()
61 && !utils::one_of(i_d.format(), memory_format::blocked,
62 memory_format::wino_fmt)
63 && !i_d.is_additional_buffer();
64 if (!ok) return unimplemented;
// Second pass over the inputs: for both i_d and o_d the value returned by
// nelems_to_concat must equal the value returned by size_to_concat.
70 for (size_t i = 0; i < src_pds_.size(); ++i) {
71 const memory_desc_wrapper i_d(&src_pds_[i]);
72 const memory_desc_wrapper o_d(&src_image_pds_[i]);
74 && nelems_to_concat(i_d) == size_to_concat(i_d)
75 && nelems_to_concat(o_d) == size_to_concat(o_d);
76 if (!ok) return unimplemented;
// Builds a product in `nelems` from the dimensions and block sizes of
// `data_d`; init() compares the result against size_to_concat().
// NOTE(review): the initial value of `nelems` and the return statement are
// not visible here — presumably starts at 1 and is returned; confirm.
87 size_t nelems_to_concat(const memory_desc_wrapper &data_d) const {
88 const int ndims = data_d.ndims();
89 auto &blk = data_d.blocking_desc();
// For every permuted position from perm_[concat_dim()] up to ndims - 1,
// multiply by that dimension's size divided by its block size; iperm_[i]
// maps the position back to a dimension index.
92 for (int i = perm_[concat_dim()]; i < ndims; i++)
93 nelems *= data_d.dims()[iperm_[i]] / blk.block_dims[iperm_[i]];
// Then multiply by the block size of every dimension, not only those at or
// after the concat position.
94 for (int i = 0; i < ndims; i++)
95 nelems *= blk.block_dims[i];
// Fills iperm_ and perm_ from the destination's strides: after the sort,
// iperm_[k] is the dimension index at sorted position k, and perm_ is the
// inverse mapping (dimension index -> sorted position).
102 const memory_desc_wrapper dst_d(&dst_pd_);
103 const int ndims = dst_d.ndims();
// Work on a local copy of the first ndims entries of strides[0] so the
// descriptor itself is not reordered.
106 utils::array_copy(strides, dst_d.blocking_desc().strides[0], ndims);
// Start from the identity ordering.
108 for (int i = 0; i < ndims; i++) iperm_[i] = i;
// Repeated passes of adjacent swaps: a pair is exchanged when the left stride
// is smaller than the right one, so larger strides move toward the front.
// iperm_ is swapped in lockstep with the stride copy.
110 for (int i = 0; i < ndims - 1; i++) {
111 bool swapped = false;
112 for (int j = 0; j < ndims - i - 1; j++) {
113 if (strides[j] < strides[j + 1]) {
114 nstl::swap(strides[j], strides[j + 1]);
115 nstl::swap(iperm_[j], iperm_[j + 1]);
// NOTE(review): the statement guarded by this check is not visible here —
// presumably it ends the sort once a pass makes no swap; confirm.
119 if (swapped == false)
// Invert iperm_ into perm_.
123 for (int i = 0; i < ndims; i++) perm_[iperm_[i]] = i;
// Computes a running maximum `max_size` over the permuted positions from
// perm_[concat_dim()] up to data_d.ndims() - 1; init() compares the result
// against nelems_to_concat().
// NOTE(review): the initial value of `max_size` and the return statement are
// not visible here — confirm them before relying on this description.
126 size_t size_to_concat(const memory_desc_wrapper &data_d) const {
128 auto &blk = data_d.blocking_desc();
129 for (int d = perm_[concat_dim()]; d < data_d.ndims(); ++d) {
130 auto block = blk.block_dims[iperm_[d]];
// Candidate 1: (padded dimension size / block size) * strides[0] of that
// dimension.
131 max_size = nstl::max(max_size,
132 size_t(blk.padding_dims[iperm_[d]] / block)
133 * blk.strides[0][iperm_[d]]);
// Candidate 2, only when the dimension has a block size above 1:
// block size * strides[1] of that dimension.
134 if (block > 1) max_size = nstl::max(max_size,
135 size_t(block * blk.strides[1][iperm_[d]]));
// Books four scratchpad entries through the registrar obtained from
// scratchpad_registry(); each is sized as n_inputs() elements.
140 void init_scratchpad() {
141 using namespace memory_tracking::names;
142 auto scratchpad = scratchpad_registry().registrar();
// Two entries of n_inputs() data_t pointers each (keys: iptrs and optrs).
143 scratchpad.book(key_concat_iptrs, sizeof(data_t *) * n_inputs());
144 scratchpad.book(key_concat_optrs, sizeof(data_t *) * n_inputs());
// One entry of n_inputs() size_t values (key: nelems).
145 scratchpad.book(key_concat_nelems, sizeof(size_t) * n_inputs());
// One entry of n_inputs() strides_t objects (key: istrides).
146 scratchpad.book(key_concat_istrides,
147 sizeof(strides_t) * n_inputs());
// Constructor: forwards the primitive descriptor and the input/output vectors
// to cpu_primitive_t; the body is empty.
151 simple_concat_t(const pd_t *apd, const input_vector &inputs,
152 const output_vector &outputs)
153 : cpu_primitive_t(apd, inputs, outputs) {}
// Destructor with an empty body.
154 ~simple_concat_t() {}
// Event-taking entry point. The only statement visible here marks the event
// `e` as ready.
// NOTE(review): presumably the parameterless execute() is called before the
// state change — that statement is not visible here; confirm.
156 virtual void execute(event_t *e) const {
158 e->set_state(event_t::ready);
// Element type obtained from prec_traits for the template's data_type.
161 typedef typename prec_traits<data_type>::type data_t;
// Parameterless execute; declared here, defined elsewhere.
164 void execute() const;
// Returns primitive_t::pd() cast to this primitive's own pd_t.
165 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }