Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / common / mkldnn_thread_parallel_nd.hpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef MKLDNN_THREAD_PARALLEL_ND_HPP
18 #define MKLDNN_THREAD_PARALLEL_ND_HPP
19
20 /* This header must be included by mkldnn_thread.hpp only */
21
22 /* Functions:
23  *  - parallel(nthr, f)              - executes f in parallel using at most
24  *                                     nthr threads. If nthr equals 0
25  *                                     mkldnn_get_max_threads() threads is
26  *                                     used
27  *  - for_nd(ithr, nthr, dims..., f) - multidimensional for loop for already
28  *                                     created threads
29  *  - parallel_nd(dims..., f)        - creates a parallel section and then
30  *                                     calls for_nd
31  *  - parallel_nd_in_omp(dims..., f) - queries current nthr and ithr and then
32  *                                     calls for_nd (mostly for convenience)
33  */
34
35 namespace mkldnn {
36 namespace impl {
37
38 /* general parallelization */
39 template <typename F>
40 void parallel(int nthr, F f) {
41     if (nthr == 0) nthr = mkldnn_get_max_threads();
42 #if MKLDNN_THR == MKLDNN_THR_SEQ
43     assert(nthr == 1);
44     f(0, 1);
45 #elif MKLDNN_THR == MKLDNN_THR_OMP
46     if (nthr == 1) { f(0, 1); return; }
47 #   pragma omp parallel num_threads(nthr)
48     f(mkldnn_get_thread_num(), mkldnn_get_num_threads());
49 #elif MKLDNN_THR == MKLDNN_THR_TBB
50     if (nthr == 1) { f(0, 1); return; }
51     tbb::parallel_for(0, nthr, [&](int ithr) { f(ithr, nthr); });
52 #endif
53 }
54
55 /* for_nd section */
56
57 template <typename T0, typename F>
58 void for_nd(const int ithr, const int nthr, const T0 &D0, F f) {
59     T0 start{0}, end{0};
60     balance211(D0, nthr, ithr, start, end);
61     for (T0 d0 = start; d0 < end; ++d0) f(d0);
62 }
63
64 template <typename T0, typename T1, typename F>
65 void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, F f) {
66     const size_t work_amount = (size_t)D0 * D1;
67     if (work_amount == 0) return;
68     size_t start{0}, end{0};
69     balance211(work_amount, nthr, ithr, start, end);
70
71     T0 d0{0}; T1 d1{0};
72     utils::nd_iterator_init(start, d0, D0, d1, D1);
73     for (size_t iwork = start; iwork < end; ++iwork) {
74         f(d0, d1);
75         utils::nd_iterator_step(d0, D0, d1, D1);
76     }
77 }
78
79 template <typename T0, typename T1, typename T2, typename F>
80 void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
81         const T2 &D2, F f) {
82     const size_t work_amount = (size_t)D0 * D1 * D2;
83     if (work_amount == 0) return;
84     size_t start{0}, end{0};
85     balance211(work_amount, nthr, ithr, start, end);
86
87     T0 d0{0}; T1 d1{0}; T2 d2{0};
88     utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2);
89     for (size_t iwork = start; iwork < end; ++iwork) {
90         f(d0, d1, d2);
91         utils::nd_iterator_step(d0, D0, d1, D1, d2, D2);
92     }
93 }
94
95 template <typename T0, typename T1, typename T2, typename T3, typename F>
96 void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
97         const T2 &D2, const T3 &D3, F f) {
98     const size_t work_amount = (size_t)D0 * D1 * D2 * D3;
99     if (work_amount == 0) return;
100     size_t start{0}, end{0};
101     balance211(work_amount, nthr, ithr, start, end);
102
103     T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0};
104     utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3);
105     for (size_t iwork = start; iwork < end; ++iwork) {
106         f(d0, d1, d2, d3);
107         utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3);
108     }
109 }
110
111 template <typename T0, typename T1, typename T2, typename T3, typename T4,
112          typename F>
113 void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
114         const T2 &D2, const T3 &D3, const T4 &D4, F f) {
115     const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4;
116     if (work_amount == 0) return;
117     size_t start{0}, end{0};
118     balance211(work_amount, nthr, ithr, start, end);
119
120     T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0};
121     utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
122     for (size_t iwork = start; iwork < end; ++iwork) {
123         f(d0, d1, d2, d3, d4);
124         utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
125     }
126 }
127
128 template <typename T0, typename T1, typename T2, typename T3, typename T4,
129          typename T5, typename F>
130 void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
131         const T2 &D2, const T3 &D3, const T4 &D4, const T5 &D5, F f) {
132     const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5;
133     if (work_amount == 0) return;
134     size_t start{0}, end{0};
135     balance211(work_amount, nthr, ithr, start, end);
136
137     T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0};
138     utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4,
139             d5, D5);
140     for (size_t iwork = start; iwork < end; ++iwork) {
141         f(d0, d1, d2, d3, d4, d5);
142         utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
143     }
144 }
145
146 // Skip a lambda function in the parameter pack.
147 template <typename T>
148 constexpr size_t get_work_amount(const T &v) { return 1; }
149 template <typename T, typename ...Args>
150 constexpr size_t get_work_amount(const T &v, Args &&...args)
151 { return (size_t)v * get_work_amount(utils::forward<Args>(args)...); }
152
153 /* parallel_nd and parallel_nd_in_omp section */
154
155 #if MKLDNN_THR != MKLDNN_THR_TBB
156 template <typename ...Args>
157 void parallel_nd(Args &&...args) {
158 #if MKLDNN_THR == MKLDNN_THR_SEQ
159     for_nd(0, 1, utils::forward<Args>(args)...);
160 #elif MKLDNN_THR == MKLDNN_THR_OMP
161     const bool do_parallel = get_work_amount(utils::forward<Args>(args)...) > 1;
162 #   pragma omp parallel if (do_parallel)
163     {
164         const int nthr = !do_parallel ? 1 : mkldnn_get_num_threads();
165         const int ithr = !do_parallel ? 0 : mkldnn_get_thread_num();
166         for_nd(ithr, nthr, utils::forward<Args>(args)...);
167     }
168 #endif
169 }
170 #else // MKLDNN_THR != MKLDNN_THR_TBB
171
172 // gcc 4.8 has a bug with passing parameter pack to lambdas.
173 // So have to explicitly instantiate all the cases.
174
175 template <typename T0, typename F>
176 void parallel_nd(const T0 &D0, F f) {
177     const int nthr = mkldnn_get_max_threads();
178     tbb::parallel_for(0, nthr, [&](int ithr) {
179         for_nd(ithr, nthr, D0, f);
180     });
181 }
182
183 template <typename T0, typename T1, typename F>
184 void parallel_nd(const T0 &D0, const T1 &D1, F f) {
185     const int nthr = mkldnn_get_max_threads();
186     tbb::parallel_for(0, nthr, [&](int ithr) {
187         for_nd(ithr, nthr, D0, D1, f);
188     });
189 }
190
191 template <typename T0, typename T1, typename T2, typename F>
192 void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, F f) {
193     const int nthr = mkldnn_get_max_threads();
194     tbb::parallel_for(0, nthr, [&](int ithr) {
195         for_nd(ithr, nthr, D0, D1, D2, f);
196     });
197 }
198
199 template <typename T0, typename T1, typename T2, typename T3, typename F>
200 void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, F f) {
201     const int nthr = mkldnn_get_max_threads();
202     tbb::parallel_for(0, nthr, [&](int ithr) {
203         for_nd(ithr, nthr, D0, D1, D2, D3, f);
204     });
205 }
206
207 template <typename T0, typename T1, typename T2, typename T3, typename T4,
208          typename F>
209 void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3,
210         const T4 &D4, F f) {
211     const int nthr = mkldnn_get_max_threads();
212     tbb::parallel_for(0, nthr, [&](int ithr) {
213         for_nd(ithr, nthr, D0, D1, D2, D3, D4, f);
214     });
215 }
216
217 template <typename T0, typename T1, typename T2, typename T3, typename T4,
218          typename T5, typename F>
219 void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3,
220         const T4 &D4, const T5 &D5, F f) {
221     const int nthr = mkldnn_get_max_threads();
222     tbb::parallel_for(0, nthr, [&](int ithr) {
223         for_nd(ithr, nthr, D0, D1, D2, D3, D4, D5, f);
224     });
225 }
226 #endif
227
228 template <typename ...Args>
229 void parallel_nd_in_omp(Args &&...args) {
230 #if MKLDNN_THR == MKLDNN_THR_SEQ
231     for_nd(0, 1, utils::forward<Args>(args)...);
232 #elif MKLDNN_THR == MKLDNN_THR_OMP
233     for_nd(mkldnn_get_thread_num(), mkldnn_get_num_threads(),
234             utils::forward<Args>(args)...);
235 #elif MKLDNN_THR == MKLDNN_THR_TBB
236     assert(!"unsupported parallel_nd_in_omp()");
237 #endif
238 }
239
240 } // namespace impl
241 } // namespace mkldnn
242
243 #endif