Add a section of how to link IE with CMake project (#99)
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm / gemm_utils.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 #include <math.h>
17
18 #include "mkldnn_thread.hpp"
19 #include "utils.hpp"
20
21 namespace mkldnn {
22 namespace impl {
23 namespace cpu {
24 namespace gemm_utils {
25 #define BM_NOCOPY_AVX 64
26 #define BN_NOCOPY_AVX 48
27 #define BK_NOCOPY_AVX 384
28 #define BN_LARGE_NOCOPY_AVX 192
29 #define BM_SMALL_NOCOPY_AVX 16
30 #define BN_SMALL_NOCOPY_AVX 1
31 #define BK_SMALL_NOCOPY_AVX 4
32 // Determine number of threads for each dimension of a 3-D partitioning
33 // algorithm based on input parameters
34 // m/n/k - First/second/third parameter for GEMM
35 // nthrs - total available number of threads
36 // nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension
37 // BM/BN/BK - blocking values
38 void calc_nthr_nocopy_avx(int m, int n, int k,
39         int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, int *BM, int *BN,
40         int *BK)
41 {
42     int nthr, nthr_m, nthr_n, nthr_k;
43     int MB, NB, KB;
44
45     nthr = nthrs;
46     nthr_m = (m + BM_NOCOPY_AVX - 1) / BM_NOCOPY_AVX;
47     nthr_n = (n + BN_NOCOPY_AVX - 1) / BN_NOCOPY_AVX;
48     nthr_k = 1;
49
50     // Partition along K dimension
51     //  - if threading allows having barriers (e.g. OMP)
52     //  - if there is not enough parallelism along M or N
53     if (mkldnn_thr_syncable()) {
54         int nthr_other = nthr_k = 1;
55         while ((nthr_m * nthr_n * nthr_other < nthr)
56                 && (k / (nthr_other + 1) > BK_NOCOPY_AVX)) {
57             nthr_other++;
58             if ((nthr / nthr_other) * nthr_other > 0.9 * nthr)
59                 nthr_k = nthr_other;
60         }
61     }
62     nthr /= nthr_k;
63
64     if (nthr_m == 1)
65         nthr_n = nthr;
66     if (nthr_n == 1)
67         nthr_m = nthr;
68
69     // Simple partition reduction
70     while (nthr_m * nthr_n > nthr)
71         if (nthr_m > nthr_n)
72             nthr_m--;
73         else
74             nthr_n--;
75     while (nthr_m * nthr_n < nthr)
76         if (nthr_m < nthr_n)
77             nthr_m++;
78         else
79             nthr_n++;
80
81     if ((nthr_m * nthr_n > nthr) && (nthr_m > 1) && (nthr_n > 1)) {
82
83         if (nthr_m <= nthr_n) {
84             nthr_m = (int)sqrt((double)nthr);
85             if (nthr_m > (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX)
86                 nthr_m = (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX;
87             nthr_n = nthr / nthr_m;
88
89             while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) {
90                 nthr_m--;
91                 nthr_n = nthr / nthr_m;
92             }
93         } else {
94             nthr_n = (int)sqrt((double)nthr);
95             if (nthr_n > (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX)
96                 nthr_n = (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX;
97             nthr_m = nthr / nthr_n;
98
99             while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) {
100                 nthr_n--;
101                 nthr_m = nthr / nthr_n;
102             }
103         }
104     }
105
106     MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX - 1;
107     MB -= MB % BM_SMALL_NOCOPY_AVX;
108     NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX - 1;
109     NB -= NB % BN_SMALL_NOCOPY_AVX;
110     KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX - 1;
111     KB -= KB % BK_SMALL_NOCOPY_AVX;
112
113     if (MB * nthr_m > m)
114         nthr_m = (m + MB - 1) / MB;
115     if (NB * nthr_n > n)
116         nthr_n = (n + NB - 1) / NB;
117     if (KB * nthr_k > k)
118         nthr_k = (k + KB - 1) / KB;
119
120     *nthrs_m = nthr_m;
121     *nthrs_n = nthr_n;
122     *nthrs_k = nthr_k;
123
124     *BM = MB;
125     *BN = NB;
126     *BK = KB;
127 }
128 #undef BM_NOCOPY_AVX
129 #undef BN_NOCOPY_AVX
130 #undef BK_NOCOPY_AVX
131 #undef BN_LARGE_NOCOPY_AVX
132 #undef BM_SMALL_NOCOPY_AVX
133 #undef BN_SMALL_NOCOPY_AVX
134 #undef BK_SMALL_NOCOPY_AVX
135
136 #define BM_NOCOPY_AVX512_COMMON 32
137 #define BN_NOCOPY_AVX512_COMMON 64
138 #define BK_NOCOPY_AVX512_COMMON 192
139 #define BN_LARGE_NOCOPY_AVX512_COMMON 192
140 #define BM_SMALL_NOCOPY_AVX512_COMMON 16
141 #define BN_SMALL_NOCOPY_AVX512_COMMON 1
142 #define BK_SMALL_NOCOPY_AVX512_COMMON 4
143 // Determine number of threads for each dimension of a 3-D partitioning
144 // algorithm based on input parameters
145 // m/n/k - First/second/third parameter for GEMM
146 // nthrs - total available number of threads
147 // nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension
148 // BM/BN/BK - blocking values
149 void calc_nthr_nocopy_avx512_common(int m,
150         int n, int k, int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k,
151         int *BM, int *BN, int *BK)
152 {
153     int nthr, nthr_m, nthr_n, nthr_k = 1;
154     int MB, NB, KB;
155     nthr = nthrs;
156
157     int counter = 0;
158     float ratio_float = 1.;
159     int ratio = 1;
160     nthr = nthrs;
161     int nthr_m_gt_n;
162
163     // Partition along K dimension
164     //  - if threading allows having barriers (e.g. OMP)
165     //  - if there is not enough parallelism along M or N
166     if (mkldnn_thr_syncable()) {
167         if (n <= 2 * BN_NOCOPY_AVX512_COMMON &&
168                 m <= 2 * BM_NOCOPY_AVX512_COMMON * nthr) {
169             nthr_k = k / BK_NOCOPY_AVX512_COMMON;
170             if (nthr_k > nthr / 4)
171                 nthr_k = nthr / 4;
172             if (nthr_k < 1)
173                 nthr_k = 1;
174
175             while ((nthr_k > 1) && (nthr % nthr_k)) {
176                 nthr_k--;
177             }
178             nthr /= nthr_k;
179         } else {
180             nthr_k = 1;
181         }
182     }
183     nthr_m = (m + BM_NOCOPY_AVX512_COMMON - 1) / BM_NOCOPY_AVX512_COMMON;
184     nthr_n = (n + BN_NOCOPY_AVX512_COMMON - 1) / BN_NOCOPY_AVX512_COMMON;
185
186     if (nthr_m < 1)
187         nthr_m = 1;
188     if (nthr_n < 1)
189         nthr_n = 1;
190
191     nthr_m_gt_n = nthr_m > nthr_n ? 1 : 0;
192     ratio_float = (float)nthr_m / nthr_n;
193
194     if (nthr_m_gt_n)
195         ratio = (int)ratio_float;
196     else
197         ratio = (int)(1. / ratio_float);
198
199     // scale down nthr_m and nthr_n if they are too large
200     while (nthr_m * nthr_n > 4 * nthr) {
201         nthr_m /= 2;
202         nthr_n /= 2;
203     }
204
205     if (nthr_m < 1)
206         nthr_m = 1;
207     if (nthr_n < 1)
208         nthr_n = 1;
209
210     // Simple partition reduction
211     counter = 0;
212     while (nthr_m * nthr_n > nthr) {
213         if (nthr_m > nthr_n) {
214             if (counter < ratio)
215                 nthr_m--;
216             else {
217                 nthr_n--;
218                 counter = -1;
219             }
220         } else {
221             if (counter < ratio)
222                 nthr_n--;
223             else {
224                 nthr_m--;
225                 counter = -1;
226             }
227         }
228         counter++;
229     }
230
231     // Simple partition increment
232     counter = 0;
233     while (nthr_m * nthr_n < 0.95 * nthr) {
234         if (nthr_m > nthr_n) {
235             if (counter < ratio)
236                 nthr_m++;
237             else {
238                 nthr_n++;
239                 counter = -1;
240             }
241         } else {
242             if (counter < ratio)
243                 nthr_n++;
244             else {
245                 nthr_m++;
246                 counter = -1;
247             }
248         }
249         counter++;
250     }
251
252     // if nothing works out, then this should work
253     if ((nthr_m * nthr_n > nthr)) {
254
255         if (nthr_m <= nthr_n) {
256             nthr_m = (int)sqrt((double)nthr);
257             if (nthr_m > (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1)
258                             / BM_SMALL_NOCOPY_AVX512_COMMON)
259                 nthr_m = (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1)
260                         / BM_SMALL_NOCOPY_AVX512_COMMON;
261             nthr_n = nthr / nthr_m;
262
263             while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) {
264                 nthr_m--;
265                 nthr_n = nthr / nthr_m;
266             }
267         } else {
268             nthr_n = (int)sqrt((double)nthr);
269             if (nthr_n > (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1)
270                             / BN_SMALL_NOCOPY_AVX512_COMMON)
271                 nthr_n = (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1)
272                         / BN_SMALL_NOCOPY_AVX512_COMMON;
273             nthr_m = nthr / nthr_n;
274
275             while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) {
276                 nthr_n--;
277                 nthr_m = nthr / nthr_n;
278             }
279         }
280     }
281
282     MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX512_COMMON - 1;
283     MB -= MB % BM_SMALL_NOCOPY_AVX512_COMMON;
284     NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX512_COMMON - 1;
285     NB -= NB % BN_SMALL_NOCOPY_AVX512_COMMON;
286     KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX512_COMMON - 1;
287     KB -= KB % BK_SMALL_NOCOPY_AVX512_COMMON;
288
289     if (MB * nthr_m > m)
290         nthr_m = (m + MB - 1) / MB;
291     if (NB * nthr_n > n)
292         nthr_n = (n + NB - 1) / NB;
293     if (KB * nthr_k > k)
294         nthr_k = (k + KB - 1) / KB;
295
296     *nthrs_m = nthr_m;
297     *nthrs_n = nthr_n;
298     *nthrs_k = nthr_k;
299
300     *BM = MB;
301     *BN = NB;
302     *BK = KB;
303 }
304 #undef BM_NOCOPY_AVX512_COMMON
305 #undef BN_NOCOPY_AVX512_COMMON
306 #undef BK_NOCOPY_AVX512_COMMON
307 #undef BN_LARGE_NOCOPY_AVX512_COMMON
308 #undef BM_SMALL_NOCOPY_AVX512_COMMON
309 #undef BN_SMALL_NOCOPY_AVX512_COMMON
310 #undef BK_SMALL_NOCOPY_AVX512_COMMON
311
312 // Partition n values as equally as possible among nthr threads
313 // and set the offset (t_offset) and number of values (t_block) for ithr
314 // Assumption: 0 <= ithr < nthr
315 void partition_unit_diff(
316         int ithr, int nthr, int n, int *t_offset, int *t_block)
317 {
318     int band = n / nthr;
319     if (band == 0)
320         band = 1;
321     int tail = n - band * nthr;
322     if (tail < 0)
323         tail = 0;
324
325     if (ithr < tail) {
326         band++;
327         *t_offset = band * ithr;
328         *t_block = band;
329     } else {
330         *t_offset = band * ithr + tail;
331         *t_block = band;
332     }
333
334     if (*t_offset >= n) {
335         *t_offset = 0;
336         *t_block = 0;
337     }
338
339     if (*t_offset + *t_block > n) {
340         *t_block = n - *t_offset;
341     }
342 }
343
344 // Sum the m*n values from p_src into p_dst, assuming the two-dimensional
345 // arrays have leading dimensions ld_src and ld_dst, respectively
346 template<typename data_t>
347 void sum_two_matrices(
348         int m, int n, data_t *p_src, int ld_src, data_t *p_dst, int ld_dst)
349 {
350     int i, j;
351     for (j = 0; j < n; j++) {
352         for (i = 0; i < m; i++) {
353             p_dst[i + j * ld_dst] += p_src[i + j * ld_src];
354         }
355     }
356 }
357
358 template void sum_two_matrices<float>(
359         int m, int n, float *p_src, int ld_src, float *p_dst, int ld_dst);
360
361 template void sum_two_matrices<double>(
362         int m, int n, double *p_src, int ld_src, double *p_dst, int ld_dst);
363 }
364 }
365 }
366 }