updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_shuffle.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18 #include <math.h>
19
20 #include "c_types_map.hpp"
21 #include "mkldnn_thread.hpp"
22 #include "type_helpers.hpp"
23 #include "format_traits.hpp"
24
25 #include "ref_shuffle.hpp"
26
27 namespace mkldnn {
28 namespace impl {
29 namespace cpu {
30
31 using namespace memory_format;
32
33 template <int data_type_size>
34 template <mkldnn_memory_format_t fmt>
35 void ref_shuffle_t<data_type_size>::execute_() const {
36     using namespace prop_kind;
37     using namespace utils;
38
39     const memory_desc_wrapper data_d(pd()->data_pd());
40
41     auto input = reinterpret_cast<const data_t*>(this->input_memory(0));
42     auto output = reinterpret_cast<data_t*>(this->memory(0));
43
44     const int axis = pd()->axis();
45     const int axis_size = pd()->axis_size();
46
47     const int MB = pd()->MB();
48     const int C = pd()->C();
49     int H = 1, W = 1, D = 1, HW = 1, SP = 1;
50     const bool has_spatial = utils::one_of(data_d.ndims(), 3, 4, 5);
51     if (has_spatial)
52     {
53         D = pd()->D();
54         H = pd()->H();
55         W = pd()->W();
56         HW = H * W;
57         SP = D * HW;
58     }
59     const size_t stride_mb = data_d.blocking_desc().strides[0][0];
60     constexpr int blksize = format_traits<fmt>::blk_size;
61
62     if (axis == 1 && one_of(fmt, nChw16c, nChw8c, nChw4c, nCdhw16c, nCdhw8c,
63             nCdhw4c)) {
64 #if MKLDNN_THR == MKLDNN_THR_OMP
65 #       pragma omp parallel for collapse(3) schedule(static)
66         for (int mb = 0; mb < MB; ++mb)
67         for (int cb = 0; cb < C; cb += blksize)
68         for (int sp = 0; sp < SP; ++sp) {
69             const size_t off = mb * stride_mb + sp * blksize;
70             const size_t output_off = off + cb * SP;
71             PRAGMA_OMP_SIMD()
72             for (int cc = 0; cc < nstl::min(blksize, C - cb); ++cc)
73             {
74                 int input_c = rev_transposed_[cb + cc];
75                 const size_t input_off = off + input_c / blksize * SP * blksize
76                                            + input_c % blksize;
77                 output[output_off + cc] = input[input_off];
78             }
79         }
80 #else
81         parallel_nd(MB, utils::div_up(C, blksize), SP, [&](int mb, int c,
82                   int sp) {
83             const size_t off = mb * stride_mb + sp * blksize;
84             const int cb = c * blksize;
85             const size_t output_off = off + cb * SP;
86             for (int cc = 0; cc < nstl::min(blksize, C - cb); ++cc)
87             {
88                 int input_c = rev_transposed_[cb + cc];
89                 const size_t input_off = off + input_c / blksize * SP * blksize
90                                            + input_c % blksize;
91                 output[output_off + cc] = input[input_off];
92             }
93         });
94 #endif
95     } else if (axis == 1 && one_of(fmt, nhwc, ndhwc)) {
96         parallel_nd(MB, SP, [&](int mb, int sp) {
97             const size_t off = mb * stride_mb + sp * C;
98             PRAGMA_OMP_SIMD()
99             for (int c = 0; c < C; ++c)
100                 output[off + c] = input[off + rev_transposed_[c]];
101         });
102     } else if (axis == 1 && one_of(fmt, nchw, ncdhw)) {
103         parallel_nd(MB, C, [&](int mb, int c) {
104             const size_t output_off = mb * stride_mb + c * SP;
105             const size_t input_off = mb * stride_mb + rev_transposed_[c] * SP;
106             PRAGMA_OMP_SIMD()
107             for (int sp = 0; sp < SP; ++sp) {
108                 output[output_off + sp] = input[input_off + sp];
109             }
110         });
111     } else {
112         auto dims = pd()->desc()->data_desc.dims;
113         auto ndims = pd()->desc()->data_desc.ndims;
114         const size_t outer_size = utils::array_product(dims, axis);
115         const size_t inner_size = utils::array_product(dims + axis + 1,
116                                          ndims - axis - 1);
117         const size_t dim = axis_size * inner_size;
118
119         parallel_nd(outer_size, axis_size, inner_size, [&](size_t ou, int a,
120                size_t in)
121         {
122             const size_t off = ou * dim + in;
123             auto &o = output[data_d.off_l(off + a * inner_size)];
124             o = input[data_d.off_l(off + rev_transposed_[a] * inner_size)];
125         });
126     }
127 }
128
129 template void ref_shuffle_t<4>::execute_<nCdhw16c>() const;
130 template void ref_shuffle_t<4>::execute_<nChw16c>() const;
131 template void ref_shuffle_t<4>::execute_<nCdhw8c>() const;
132 template void ref_shuffle_t<4>::execute_<nChw8c>() const;
133 template void ref_shuffle_t<4>::execute_<nCdhw4c>() const;
134 template void ref_shuffle_t<4>::execute_<nChw4c>() const;
135 template void ref_shuffle_t<4>::execute_<ncdhw>() const;
136 template void ref_shuffle_t<4>::execute_<nchw>() const;
137 template void ref_shuffle_t<4>::execute_<ndhwc>() const;
138 template void ref_shuffle_t<4>::execute_<nhwc>() const;
139 template void ref_shuffle_t<4>::execute_<any>() const;
140
141 template void ref_shuffle_t<1>::execute_<nCdhw16c>() const;
142 template void ref_shuffle_t<1>::execute_<nChw16c>() const;
143 template void ref_shuffle_t<1>::execute_<nCdhw8c>() const;
144 template void ref_shuffle_t<1>::execute_<nChw8c>() const;
145 template void ref_shuffle_t<1>::execute_<nCdhw4c>() const;
146 template void ref_shuffle_t<1>::execute_<nChw4c>() const;
147 template void ref_shuffle_t<1>::execute_<ncdhw>() const;
148 template void ref_shuffle_t<1>::execute_<nchw>() const;
149 template void ref_shuffle_t<1>::execute_<ndhwc>() const;
150 template void ref_shuffle_t<1>::execute_<nhwc>() const;
151 template void ref_shuffle_t<1>::execute_<any>() const;
152
153 }
154 }
155 }
156
157 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s