Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / simple_concat.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_thread.hpp"
18
19 #include "simple_concat.hpp"
20
21 namespace mkldnn {
22 namespace impl {
23 namespace cpu {
24
25 using namespace memory_tracking::names;
26
27 template <data_type_t data_type>
28 void simple_concat_t<data_type>::execute() const {
29     auto scratchpad = this->scratchpad();
30     auto iptrs = scratchpad.template get<const data_t *>(key_concat_iptrs);
31     auto optrs = scratchpad.template get<data_t *>(key_concat_optrs);
32     auto nelems_to_copy = scratchpad.template get<size_t>(key_concat_nelems);
33     auto is = scratchpad.template get<strides_t>(key_concat_istrides);
34
35     const int num_arrs = pd()->n_inputs();
36     const ptrdiff_t *perm = pd()->perm_, *iperm = pd()->iperm_;
37     const int concat_dim = pd()->concat_dim();
38     auto o_base_ptr = reinterpret_cast<data_t *>(this->memory());
39
40     for (int a = 0; a < num_arrs; ++a) {
41         const memory_desc_wrapper i_d(pd()->src_pd(a));
42         const memory_desc_wrapper o_d(pd()->src_image_pd(a));
43
44         iptrs[a] = reinterpret_cast<const data_t *>(
45                 this->input_memory(a)) + i_d.blk_off(0);
46         optrs[a] = o_base_ptr + o_d.blk_off(0);
47         nelems_to_copy[a] = pd()->nelems_to_concat(i_d);
48         for (int i = 0; i < TENSOR_MAX_DIMS; i++) {
49             if (i < perm[concat_dim])
50                 is[a][i] = size_t(i_d.blocking_desc().strides[0][iperm[i]]);
51             else
52                 is[a][i] = 0;
53         }
54     }
55
56     const memory_desc_wrapper o_d(pd()->src_image_pd());
57     auto &blk = o_d.blocking_desc();
58
59     strides_t os = { 0 };
60     for (int i = 0; i < perm[concat_dim]; i++)
61         os[i] = o_d.blocking_desc().strides[0][iperm[i]];
62
63     dims_t phys_dims;
64     for (size_t i = 0; i < sizeof(phys_dims)/sizeof(phys_dims[0]); i++)
65         phys_dims[i] = (i < (size_t)perm[concat_dim])
66             ?  o_d.dims()[iperm[i]] / blk.block_dims[iperm[i]] : 1;
67
68     if (perm[concat_dim] == 0) {
69         for (int a = 0; a < num_arrs; ++a) {
70             const data_t *i = &iptrs[a][0];
71             data_t *o = &optrs[a][0];
72             parallel_nd((ptrdiff_t)nelems_to_copy[a],
73                     [&](ptrdiff_t e) { o[e] = i[e]; });
74         }
75     } else {
76         parallel_nd(phys_dims[0], phys_dims[1], phys_dims[2], phys_dims[3],
77             phys_dims[4], num_arrs,
78             [&](int n0, int n1, int n2, int n3, int n4, int a) {
79             // XXX: this code may access uninitialized values in is[*][0-4] --
80             // that's why we have to set them to zero although this is
81             // probably benign
82             size_t in_off = is[a][0] * n0 + is[a][1] * n1 + is[a][2] * n2
83                     + is[a][3] * n3 + is[a][4] * n4;
84             size_t out_off = os[0] * n0 + os[1] * n1 + os[2] * n2
85                     + os[3] * n3 + os[4] * n4;
86             const data_t *i = &iptrs[a][in_off];
87             data_t *o = &optrs[a][out_off];
88 #if defined(__GNUC__) && !defined(__INTEL_COMPILER)
89             // The code below performs data copying: o[e] = i[e]
90             // and uses a workaround to make GNU compilers optimize it
91             uint8_t *ptro = reinterpret_cast<uint8_t *>(o);
92             const uint8_t *ptri = reinterpret_cast<const uint8_t *>(i);
93             const size_t main_part =
94                 nelems_to_copy[a] * sizeof(data_t) / sizeof(uint32_t);
95             const size_t tail_part =
96                 nelems_to_copy[a] * sizeof(data_t) % sizeof(uint32_t);
97
98             PRAGMA_OMP_SIMD()
99             for (size_t e = 0; e < main_part; ++e) {
100                 *(reinterpret_cast<uint32_t *>(ptro))
101                     = *(reinterpret_cast<const uint32_t *>(ptri));
102                 ptro += sizeof(uint32_t);
103                 ptri += sizeof(uint32_t);
104             }
105             for (size_t e = 0; e < tail_part; ++e) {
106                 *ptro = *ptri;
107                 ++ptro;
108                 ++ptri;
109             }
110 #else
111             PRAGMA_OMP_SIMD()
112             for (size_t e = 0; e < nelems_to_copy[a]; ++e) o[e] = i[e];
113 #endif
114         });
115     }
116 }
117
118 template struct simple_concat_t<data_type::f32>;
119 template struct simple_concat_t<data_type::u8>;
120 template struct simple_concat_t<data_type::s8>;
121 template struct simple_concat_t<data_type::s32>;
122
123 }
124 }
125 }