1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "mkldnn_thread.hpp"
19 #include "simple_concat.hpp"
// Execution step of simple_concat_t: for every input it records a source
// pointer, a destination pointer inside the memory returned by
// this->memory(), an element count and a set of input strides, and then
// copies the elements with parallel_nd.
// NOTE(review): this listing appears to be missing lines (closing braces,
// else branches, the case labels of the switch) -- confirm against the
// complete file. The comments below describe only the visible statements.
25 template <data_type_t data_type>
26 void simple_concat_t<data_type>::execute() {
// Number of inputs, the two permutation tables kept in conf_, and the
// index of the dimension being concatenated.
27 const int num_arrs = conf_.n_inputs();
28 int *perm = conf_.perm_, *iperm = conf_.iperm_;
29 int concat_dim = conf_.concat_dim();
// Base pointer of the output; each destination pointer is an offset from it.
30 auto o_base_ptr = reinterpret_cast<data_t *>(this->memory());
// Per-input setup.
32 for (int a = 0; a < num_arrs; ++a) {
// i_d wraps the descriptor of input a; o_d wraps src_image_pd(a),
// presumably the part of the output that input a maps to (TODO confirm).
33 const memory_desc_wrapper i_d(conf_.src_pd(a));
34 const memory_desc_wrapper o_d(conf_.src_image_pd(a));
// Source and destination start pointers, each shifted by blk_off(0) of
// the corresponding descriptor.
36 input_ptrs_[a] = reinterpret_cast<const data_t *>(
37 this->input_memory(a)) + i_d.blk_off(0);
38 output_ptrs_[a] = o_base_ptr + o_d.blk_off(0);
// Element count used as the trip count of the copy loops below.
39 nelems_to_copy_[a] = nelems_to_concat(concat_dim, perm, iperm, i_d);
// Input strides in permuted order, filled for positions below
// perm[concat_dim]. What happens for the remaining positions is not
// visible here.
40 for (int i = 0; i < TENSOR_MAX_DIMS; i++) {
41 if (i < perm[concat_dim])
42 is_[a][i] = size_t(i_d.blocking_desc().strides[0][iperm[i]]);
// Descriptor obtained from src_image_pd() with no index; it supplies the
// output strides (os) and the loop extents (phys_dims) computed below.
48 const memory_desc_wrapper o_d(conf_.src_image_pd());
49 auto &blk = o_d.blocking_desc();
// Output strides in permuted order for positions below perm[concat_dim].
51 for (int i = 0; i < perm[concat_dim]; i++)
52 os[i] = o_d.blocking_desc().strides[0][iperm[i]];
// Loop extents: dimension size divided by its block size for positions
// below perm[concat_dim]. The other arm of the conditional is not visible.
54 for (size_t i = 0; i < sizeof(phys_dims)/sizeof(phys_dims[0]); i++)
55 phys_dims[i] = (i < (size_t)perm[concat_dim]) ?
56 o_d.dims()[iperm[i]] / blk.block_dims[iperm[i]] :
// Dispatch on the permuted position of the concat dimension.
// NOTE(review): the case labels are not visible; verify which value
// selects each of the two copy strategies below.
59 switch (perm[concat_dim]) {
// Flat copy: elements [0, nelems_to_copy_[a]) of each input are copied to
// the same indices of its destination, parallelized over the elements.
61 for (int a = 0; a < num_arrs; ++a) {
62 const data_t *i = &input_ptrs_[a][0];
63 data_t *o = &output_ptrs_[a][0];
64 parallel_nd((ptrdiff_t)nelems_to_copy_[a],
65 [&](ptrdiff_t e) { o[e] = i[e]; });
// Strided copy: parallelize over the five extents in phys_dims and the
// input index; each work item derives its offsets from the strides.
70 parallel_nd(phys_dims[0], phys_dims[1], phys_dims[2], phys_dims[3],
71 phys_dims[4], num_arrs,
72 [&](int n0, int n1, int n2, int n3, int n4, int a) {
73 // XXX: this code may access uninitialized values in is_[*][0-4] --
74 // that's why we have to set them to zero although this is
// Offsets into input a and into its destination for this work item.
76 size_t in_off = is_[a][0] * n0 + is_[a][1] * n1
77 + is_[a][2] * n2 + is_[a][3] * n3
79 size_t out_off = os[0] * n0 + os[1] * n1
80 + os[2] * n2 + os[3] * n3 + os[4] * n4;
81 const data_t *i = &input_ptrs_[a][in_off];
82 data_t *o = &output_ptrs_[a][out_off];
// Serial loop over the nelems_to_copy_[a] elements of this work item
// (the loop body is not visible here).
85 for (size_t e = 0; e < nelems_to_copy_[a]; ++e)
// Explicit instantiations of simple_concat_t for the f32, u8, s8 and s32
// data types.
91 template struct simple_concat_t<data_type::f32>;
92 template struct simple_concat_t<data_type::u8>;
93 template struct simple_concat_t<data_type::s8>;
94 template struct simple_concat_t<data_type::s32>;