Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm / gemm_utils.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 #include <math.h>
17
18 #include "mkldnn_thread.hpp"
19 #include "utils.hpp"
20
21 namespace mkldnn {
22 namespace impl {
23 namespace cpu {
24 namespace gemm_utils {
25 #define BM_NOCOPY_AVX2 64
26 #define BN_NOCOPY_AVX2 48
27 #define BK_NOCOPY_AVX2 384
28 #define BN_LARGE_NOCOPY_AVX2 192
29 #define BM_SMALL_NOCOPY_AVX2 16
30 #define BN_SMALL_NOCOPY_AVX2 1
31 #define BK_SMALL_NOCOPY_AVX2 4
32 // Determine number of threads for each dimension of a 3-D partitioning
33 // algorithm based on input parameters
34 // m/n/k - First/second/third parameter for GEMM
35 // nthrs - total available number of threads
36 // nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension
37 // BM/BN/BK - blocking values
38 void calc_nthr_nocopy_avx2(int m, int n, int k,
39         int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, int *BM, int *BN,
40         int *BK)
41 {
42     int nthr, nthr_m, nthr_n, nthr_k, nthr_other;
43     int MB, NB, KB;
44
45     nthr = nthrs;
46     nthr_m = (m + BM_NOCOPY_AVX2 - 1) / BM_NOCOPY_AVX2;
47     nthr_n = (n + BN_NOCOPY_AVX2 - 1) / BN_NOCOPY_AVX2;
48     nthr_k = 1;
49
50     // Partition along K dimension if there is not enough parallelism along M or
51     // N.
52     nthr_other = nthr_k = 1;
53     while ((nthr_m * nthr_n * nthr_other < nthr)
54             && (k / (nthr_other + 1) > BK_NOCOPY_AVX2)) {
55         nthr_other++;
56         if ((nthr / nthr_other) * nthr_other > 0.9 * nthr)
57             nthr_k = nthr_other;
58     }
59     nthr /= nthr_k;
60
61     if (nthr_m == 1)
62         nthr_n = nthr;
63     if (nthr_n == 1)
64         nthr_m = nthr;
65
66     // Simple partition reduction
67     while (nthr_m * nthr_n > nthr)
68         if (nthr_m > nthr_n)
69             nthr_m--;
70         else
71             nthr_n--;
72     while (nthr_m * nthr_n < nthr)
73         if (nthr_m < nthr_n)
74             nthr_m++;
75         else
76             nthr_n++;
77
78     if ((nthr_m * nthr_n > nthr) && (nthr_m > 1) && (nthr_n > 1)) {
79
80         if (nthr_m <= nthr_n) {
81             nthr_m = (int)sqrt((double)nthr);
82             if (nthr_m > (m + BM_SMALL_NOCOPY_AVX2 - 1) / BM_SMALL_NOCOPY_AVX2)
83                 nthr_m = (m + BM_SMALL_NOCOPY_AVX2 - 1) / BM_SMALL_NOCOPY_AVX2;
84             nthr_n = nthr / nthr_m;
85
86             while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) {
87                 nthr_m--;
88                 nthr_n = nthr / nthr_m;
89             }
90         } else {
91             nthr_n = (int)sqrt((double)nthr);
92             if (nthr_n > (n + BN_SMALL_NOCOPY_AVX2 - 1) / BN_SMALL_NOCOPY_AVX2)
93                 nthr_n = (n + BN_SMALL_NOCOPY_AVX2 - 1) / BN_SMALL_NOCOPY_AVX2;
94             nthr_m = nthr / nthr_n;
95
96             while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) {
97                 nthr_n--;
98                 nthr_m = nthr / nthr_n;
99             }
100         }
101     }
102
103     MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX2 - 1;
104     MB -= MB % BM_SMALL_NOCOPY_AVX2;
105     NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX2 - 1;
106     NB -= NB % BN_SMALL_NOCOPY_AVX2;
107     KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX2 - 1;
108     KB -= KB % BK_SMALL_NOCOPY_AVX2;
109
110     if (MB * nthr_m > m)
111         nthr_m = (m + MB - 1) / MB;
112     if (NB * nthr_n > n)
113         nthr_n = (n + NB - 1) / NB;
114     if (KB * nthr_k > k)
115         nthr_k = (k + KB - 1) / KB;
116
117     *nthrs_m = nthr_m;
118     *nthrs_n = nthr_n;
119     *nthrs_k = nthr_k;
120
121     *BM = MB;
122     *BN = NB;
123     *BK = KB;
124 }
125 #undef BM_NOCOPY_AVX2
126 #undef BN_NOCOPY_AVX2
127 #undef BK_NOCOPY_AVX2
128 #undef BN_LARGE_NOCOPY_AVX2
129 #undef BM_SMALL_NOCOPY_AVX2
130 #undef BN_SMALL_NOCOPY_AVX2
131 #undef BK_SMALL_NOCOPY_AVX2
132
133 #define BM_NOCOPY_AVX512_COMMON 32
134 #define BN_NOCOPY_AVX512_COMMON 64
135 #define BK_NOCOPY_AVX512_COMMON 192
136 #define BN_LARGE_NOCOPY_AVX512_COMMON 192
137 #define BM_SMALL_NOCOPY_AVX512_COMMON 16
138 #define BN_SMALL_NOCOPY_AVX512_COMMON 1
139 #define BK_SMALL_NOCOPY_AVX512_COMMON 4
140 // Determine number of threads for each dimension of a 3-D partitioning
141 // algorithm based on input parameters
142 // m/n/k - First/second/third parameter for GEMM
143 // nthrs - total available number of threads
144 // nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension
145 // BM/BN/BK - blocking values
146 void calc_nthr_nocopy_avx512_common(int m,
147         int n, int k, int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k,
148         int *BM, int *BN, int *BK)
149 {
150     int nthr, nthr_m, nthr_n, nthr_k;
151     int MB, NB, KB;
152     nthr = nthrs;
153
154     int counter = 0;
155     float ratio_float = 1.;
156     int ratio = 1;
157     nthr = nthrs;
158     int nthr_m_gt_n;
159
160     /* Partition along K dimension if there is enough K and there is not enough
161      * M/N */
162     if (n <= 2 * BN_NOCOPY_AVX512_COMMON &&
163             m <= 2 * BM_NOCOPY_AVX512_COMMON * nthr) {
164         nthr_k = k / BK_NOCOPY_AVX512_COMMON;
165         if (nthr_k > nthr / 4)
166             nthr_k = nthr / 4;
167         if (nthr_k < 1)
168             nthr_k = 1;
169
170         while ((nthr_k > 1) && (nthr % nthr_k)) {
171             nthr_k--;
172         }
173         nthr /= nthr_k;
174     } else {
175         nthr_k = 1;
176     }
177     nthr_m = (m + BM_NOCOPY_AVX512_COMMON - 1) / BM_NOCOPY_AVX512_COMMON;
178     nthr_n = (n + BN_NOCOPY_AVX512_COMMON - 1) / BN_NOCOPY_AVX512_COMMON;
179
180     if (nthr_m < 1)
181         nthr_m = 1;
182     if (nthr_n < 1)
183         nthr_n = 1;
184
185     nthr_m_gt_n = nthr_m > nthr_n ? 1 : 0;
186     ratio_float = (float)nthr_m / nthr_n;
187
188     if (nthr_m_gt_n)
189         ratio = (int)ratio_float;
190     else
191         ratio = (int)(1. / ratio_float);
192
193     // scale down nthr_m and nthr_n if they are too large
194     while (nthr_m * nthr_n > 4 * nthr) {
195         nthr_m /= 2;
196         nthr_n /= 2;
197     }
198
199     if (nthr_m < 1)
200         nthr_m = 1;
201     if (nthr_n < 1)
202         nthr_n = 1;
203
204     // Simple partition reduction
205     counter = 0;
206     while (nthr_m * nthr_n > nthr) {
207         if (nthr_m > nthr_n) {
208             if (counter < ratio)
209                 nthr_m--;
210             else {
211                 nthr_n--;
212                 counter = -1;
213             }
214         } else {
215             if (counter < ratio)
216                 nthr_n--;
217             else {
218                 nthr_m--;
219                 counter = -1;
220             }
221         }
222         counter++;
223     }
224
225     // Simple partition increment
226     counter = 0;
227     while (nthr_m * nthr_n < 0.95 * nthr) {
228         if (nthr_m > nthr_n) {
229             if (counter < ratio)
230                 nthr_m++;
231             else {
232                 nthr_n++;
233                 counter = -1;
234             }
235         } else {
236             if (counter < ratio)
237                 nthr_n++;
238             else {
239                 nthr_m++;
240                 counter = -1;
241             }
242         }
243         counter++;
244     }
245
246     // if nothing works out, then this should work
247     if ((nthr_m * nthr_n > nthr)) {
248
249         if (nthr_m <= nthr_n) {
250             nthr_m = (int)sqrt((double)nthr);
251             if (nthr_m > (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1)
252                             / BM_SMALL_NOCOPY_AVX512_COMMON)
253                 nthr_m = (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1)
254                         / BM_SMALL_NOCOPY_AVX512_COMMON;
255             nthr_n = nthr / nthr_m;
256
257             while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) {
258                 nthr_m--;
259                 nthr_n = nthr / nthr_m;
260             }
261         } else {
262             nthr_n = (int)sqrt((double)nthr);
263             if (nthr_n > (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1)
264                             / BN_SMALL_NOCOPY_AVX512_COMMON)
265                 nthr_n = (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1)
266                         / BN_SMALL_NOCOPY_AVX512_COMMON;
267             nthr_m = nthr / nthr_n;
268
269             while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) {
270                 nthr_n--;
271                 nthr_m = nthr / nthr_n;
272             }
273         }
274     }
275
276     MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX512_COMMON - 1;
277     MB -= MB % BM_SMALL_NOCOPY_AVX512_COMMON;
278     NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX512_COMMON - 1;
279     NB -= NB % BN_SMALL_NOCOPY_AVX512_COMMON;
280     KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX512_COMMON - 1;
281     KB -= KB % BK_SMALL_NOCOPY_AVX512_COMMON;
282
283     if (MB * nthr_m > m)
284         nthr_m = (m + MB - 1) / MB;
285     if (NB * nthr_n > n)
286         nthr_n = (n + NB - 1) / NB;
287     if (KB * nthr_k > k)
288         nthr_k = (k + KB - 1) / KB;
289
290     *nthrs_m = nthr_m;
291     *nthrs_n = nthr_n;
292     *nthrs_k = nthr_k;
293
294     *BM = MB;
295     *BN = NB;
296     *BK = KB;
297 }
298 #undef BM_NOCOPY_AVX512_COMMON
299 #undef BN_NOCOPY_AVX512_COMMON
300 #undef BK_NOCOPY_AVX512_COMMON
301 #undef BN_LARGE_NOCOPY_AVX512_COMMON
302 #undef BM_SMALL_NOCOPY_AVX512_COMMON
303 #undef BN_SMALL_NOCOPY_AVX512_COMMON
304 #undef BK_SMALL_NOCOPY_AVX512_COMMON
305
306 // Partition n values as equally as possible among nthr threads
307 // and set the offset (t_offset) and number of values (t_block) for ithr
308 // Assumption: 0 <= ithr < nthr
309 void partition_unit_diff(
310         int ithr, int nthr, int n, int *t_offset, int *t_block)
311 {
312     int band = n / nthr;
313     if (band == 0)
314         band = 1;
315     int tail = n - band * nthr;
316     if (tail < 0)
317         tail = 0;
318
319     if (ithr < tail) {
320         band++;
321         *t_offset = band * ithr;
322         *t_block = band;
323     } else {
324         *t_offset = band * ithr + tail;
325         *t_block = band;
326     }
327
328     if (*t_offset >= n) {
329         *t_offset = 0;
330         *t_block = 0;
331     }
332
333     if (*t_offset + *t_block > n) {
334         *t_block = n - *t_offset;
335     }
336 }
337
338 // Sum the m*n values from p_src into p_dst, assuming the two-dimensional
339 // arrays have leading dimensions ld_src and ld_dst, respectively
340 void sum_two_matrices(
341         int m, int n, float *p_src, int ld_src, float *p_dst, int ld_dst)
342 {
343     int i, j;
344     for (j = 0; j < n; j++) {
345         for (i = 0; i < m; i++) {
346             p_dst[i + j * ld_dst] += p_src[i + j * ld_src];
347         }
348     }
349 }
350 }
351 }
352 }
353 }