Publishing R5 content (#72)
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / simple_concat.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_thread.hpp"
18
19 #include "simple_concat.hpp"
20
21 namespace mkldnn {
22 namespace impl {
23 namespace cpu {
24
25 template <data_type_t data_type>
26 void simple_concat_t<data_type>::execute() {
27     const int num_arrs = conf_.n_inputs();
28     int *perm = conf_.perm_, *iperm = conf_.iperm_;
29     int concat_dim = conf_.concat_dim();
30     auto o_base_ptr = reinterpret_cast<data_t *>(this->memory());
31
32     for (int a = 0; a < num_arrs; ++a) {
33         const memory_desc_wrapper i_d(conf_.src_pd(a));
34         const memory_desc_wrapper o_d(conf_.src_image_pd(a));
35
36         input_ptrs_[a] = reinterpret_cast<const data_t *>(
37                 this->input_memory(a)) + i_d.blk_off(0);
38         output_ptrs_[a] = o_base_ptr + o_d.blk_off(0);
39         nelems_to_copy_[a] = nelems_to_concat(concat_dim, perm, iperm, i_d);
40         for (int i = 0; i < TENSOR_MAX_DIMS; i++) {
41             if (i < perm[concat_dim])
42                 is_[a][i] = size_t(i_d.blocking_desc().strides[0][iperm[i]]);
43             else
44                 is_[a][i] = 0;
45         }
46     }
47
48     const memory_desc_wrapper o_d(conf_.src_image_pd());
49     auto &blk = o_d.blocking_desc();
50     strides_t os = { 0 };
51     for (int i = 0; i < perm[concat_dim]; i++)
52         os[i] = o_d.blocking_desc().strides[0][iperm[i]];
53     dims_t phys_dims;
54     for (size_t i = 0; i < sizeof(phys_dims)/sizeof(phys_dims[0]); i++)
55         phys_dims[i] = (i < (size_t)perm[concat_dim]) ?
56                 o_d.dims()[iperm[i]] / blk.block_dims[iperm[i]] :
57                 1;
58
59     switch (perm[concat_dim]) {
60     case (0): {
61         for (int a = 0; a < num_arrs; ++a) {
62             const data_t *i = &input_ptrs_[a][0];
63             data_t *o = &output_ptrs_[a][0];
64             parallel_nd((ptrdiff_t)nelems_to_copy_[a],
65                     [&](ptrdiff_t e) { o[e] = i[e]; });
66         }
67         break;
68     }
69     default:
70         parallel_nd(phys_dims[0], phys_dims[1], phys_dims[2], phys_dims[3],
71             phys_dims[4], num_arrs,
72             [&](int n0, int n1, int n2, int n3, int n4, int a) {
73             // XXX: this code may access unitialized values in is_[*][0-4] --
74             // that's why we have to set them to zero although this is
75             // probably benign
76             size_t in_off = is_[a][0] * n0 + is_[a][1] * n1
77                     + is_[a][2] * n2 + is_[a][3] * n3
78                     + is_[a][4] * n4;
79             size_t out_off = os[0] * n0 + os[1] * n1
80                     + os[2] * n2 + os[3] * n3 + os[4] * n4;
81             const data_t *i = &input_ptrs_[a][in_off];
82             data_t *o = &output_ptrs_[a][out_off];
83
84             PRAGMA_OMP_SIMD()
85             for (size_t e = 0; e < nelems_to_copy_[a]; ++e)
86                 o[e] = i[e];
87         });
88     }
89 }
90
91 template struct simple_concat_t<data_type::f32>;
92 template struct simple_concat_t<data_type::u8>;
93 template struct simple_concat_t<data_type::s8>;
94 template struct simple_concat_t<data_type::s32>;
95
96 }
97 }
98 }