Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm / f32 / gemm_utils_f32.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 #include <cmath>
17
18 #include "mkldnn_thread.hpp"
19 #include "utils.hpp"
20 #include "gemm_utils_f32.hpp"
21
22 namespace mkldnn {
23 namespace impl {
24 namespace cpu {
25 namespace gemm_utils {
26 #define BM_NOCOPY_AVX 64
27 #define BN_NOCOPY_AVX 48
28 #define BK_NOCOPY_AVX 384
29 #define BN_LARGE_NOCOPY_AVX 192
30 #define BM_SMALL_NOCOPY_AVX 16
31 #define BN_SMALL_NOCOPY_AVX 1
32 #define BK_SMALL_NOCOPY_AVX 4
33 // Determine number of threads for each dimension of a 3-D partitioning
34 // algorithm based on input parameters
35 // m/n/k - First/second/third parameter for GEMM
36 // nthrs - total available number of threads
37 // nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension
38 // BM/BN/BK - blocking values
39 void calc_nthr_nocopy_avx(int m, int n, int k,
40         int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, int *BM, int *BN,
41         int *BK)
42 {
43     int nthr, nthr_m, nthr_n, nthr_k;
44     int MB, NB, KB;
45
46     nthr = nthrs;
47     nthr_m = (m + BM_NOCOPY_AVX - 1) / BM_NOCOPY_AVX;
48     nthr_n = (n + BN_NOCOPY_AVX - 1) / BN_NOCOPY_AVX;
49     nthr_k = 1;
50
51     // Partition along K dimension
52     //  - if threading allows having barriers (e.g. OMP)
53     //  - if there is not enough parallelism along M or N
54     if (mkldnn_thr_syncable()) {
55         int nthr_other = nthr_k = 1;
56         while ((nthr_m * nthr_n * nthr_other < nthr)
57                 && (k / (nthr_other + 1) > BK_NOCOPY_AVX)) {
58             nthr_other++;
59             if ((nthr / nthr_other) * nthr_other > 0.9 * nthr)
60                 nthr_k = nthr_other;
61         }
62     }
63     nthr /= nthr_k;
64
65     if (nthr_m == 1)
66         nthr_n = nthr;
67     if (nthr_n == 1)
68         nthr_m = nthr;
69
70     // Simple partition reduction
71     while (nthr_m * nthr_n > nthr)
72         if (nthr_m > nthr_n)
73             nthr_m--;
74         else
75             nthr_n--;
76     while (nthr_m * nthr_n < nthr)
77         if (nthr_m < nthr_n)
78             nthr_m++;
79         else
80             nthr_n++;
81
82     if ((nthr_m * nthr_n > nthr) && (nthr_m > 1) && (nthr_n > 1)) {
83
84         if (nthr_m <= nthr_n) {
85             nthr_m = (int)sqrt((double)nthr);
86             if (nthr_m > (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX)
87                 nthr_m = (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX;
88             nthr_n = nthr / nthr_m;
89
90             while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) {
91                 nthr_m--;
92                 nthr_n = nthr / nthr_m;
93             }
94         } else {
95             nthr_n = (int)sqrt((double)nthr);
96             if (nthr_n > (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX)
97                 nthr_n = (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX;
98             nthr_m = nthr / nthr_n;
99
100             while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) {
101                 nthr_n--;
102                 nthr_m = nthr / nthr_n;
103             }
104         }
105     }
106
107     MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX - 1;
108     MB -= MB % BM_SMALL_NOCOPY_AVX;
109     NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX - 1;
110     NB -= NB % BN_SMALL_NOCOPY_AVX;
111     KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX - 1;
112     KB -= KB % BK_SMALL_NOCOPY_AVX;
113
114     if (MB * nthr_m > m)
115         nthr_m = (m + MB - 1) / MB;
116     if (NB * nthr_n > n)
117         nthr_n = (n + NB - 1) / NB;
118     if (KB * nthr_k > k)
119         nthr_k = (k + KB - 1) / KB;
120
121     *nthrs_m = nthr_m;
122     *nthrs_n = nthr_n;
123     *nthrs_k = nthr_k;
124
125     *BM = MB;
126     *BN = NB;
127     *BK = KB;
128 }
129 #undef BM_NOCOPY_AVX
130 #undef BN_NOCOPY_AVX
131 #undef BK_NOCOPY_AVX
132 #undef BN_LARGE_NOCOPY_AVX
133 #undef BM_SMALL_NOCOPY_AVX
134 #undef BN_SMALL_NOCOPY_AVX
135 #undef BK_SMALL_NOCOPY_AVX
136
137 #define BM_NOCOPY_AVX512_COMMON 32
138 #define BN_NOCOPY_AVX512_COMMON 64
139 #define BK_NOCOPY_AVX512_COMMON 192
140 #define BN_LARGE_NOCOPY_AVX512_COMMON 192
141 #define BM_SMALL_NOCOPY_AVX512_COMMON 16
142 #define BN_SMALL_NOCOPY_AVX512_COMMON 1
143 #define BK_SMALL_NOCOPY_AVX512_COMMON 4
144 // Determine number of threads for each dimension of a 3-D partitioning
145 // algorithm based on input parameters
146 // m/n/k - First/second/third parameter for GEMM
147 // nthrs - total available number of threads
148 // nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension
149 // BM/BN/BK - blocking values
150 void calc_nthr_nocopy_avx512_common(int m,
151         int n, int k, int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k,
152         int *BM, int *BN, int *BK)
153 {
154     int nthr, nthr_m, nthr_n, nthr_k = 1;
155     int MB, NB, KB;
156     nthr = nthrs;
157
158     int counter = 0;
159     float ratio_float = 1.;
160     int ratio = 1;
161     nthr = nthrs;
162     int nthr_m_gt_n;
163
164     // Partition along K dimension
165     //  - if threading allows having barriers (e.g. OMP)
166     //  - if there is not enough parallelism along M or N
167     if (mkldnn_thr_syncable()) {
168         if (n <= 2 * BN_NOCOPY_AVX512_COMMON &&
169                 m <= 2 * BM_NOCOPY_AVX512_COMMON * nthr) {
170             nthr_k = k / BK_NOCOPY_AVX512_COMMON;
171             if (nthr_k > nthr / 4)
172                 nthr_k = nthr / 4;
173             if (nthr_k < 1)
174                 nthr_k = 1;
175
176             while ((nthr_k > 1) && (nthr % nthr_k)) {
177                 nthr_k--;
178             }
179             nthr /= nthr_k;
180         } else {
181             nthr_k = 1;
182         }
183     }
184     nthr_m = (m + BM_NOCOPY_AVX512_COMMON - 1) / BM_NOCOPY_AVX512_COMMON;
185     nthr_n = (n + BN_NOCOPY_AVX512_COMMON - 1) / BN_NOCOPY_AVX512_COMMON;
186
187     if (nthr_m < 1)
188         nthr_m = 1;
189     if (nthr_n < 1)
190         nthr_n = 1;
191
192     nthr_m_gt_n = nthr_m > nthr_n ? 1 : 0;
193     ratio_float = (float)nthr_m / nthr_n;
194
195     if (nthr_m_gt_n)
196         ratio = (int)ratio_float;
197     else
198         ratio = (int)(1. / ratio_float);
199
200     // scale down nthr_m and nthr_n if they are too large
201     while (nthr_m * nthr_n > 4 * nthr) {
202         nthr_m /= 2;
203         nthr_n /= 2;
204     }
205
206     if (nthr_m < 1)
207         nthr_m = 1;
208     if (nthr_n < 1)
209         nthr_n = 1;
210
211     // Simple partition reduction
212     counter = 0;
213     while (nthr_m * nthr_n > nthr) {
214         if (nthr_m > nthr_n) {
215             if (counter < ratio)
216                 nthr_m--;
217             else {
218                 nthr_n--;
219                 counter = -1;
220             }
221         } else {
222             if (counter < ratio)
223                 nthr_n--;
224             else {
225                 nthr_m--;
226                 counter = -1;
227             }
228         }
229         counter++;
230     }
231
232     // Simple partition increment
233     counter = 0;
234     while (nthr_m * nthr_n < 0.95 * nthr) {
235         if (nthr_m > nthr_n) {
236             if (counter < ratio)
237                 nthr_m++;
238             else {
239                 nthr_n++;
240                 counter = -1;
241             }
242         } else {
243             if (counter < ratio)
244                 nthr_n++;
245             else {
246                 nthr_m++;
247                 counter = -1;
248             }
249         }
250         counter++;
251     }
252
253     // if nothing works out, then this should work
254     if ((nthr_m * nthr_n > nthr)) {
255
256         if (nthr_m <= nthr_n) {
257             nthr_m = (int)sqrt((double)nthr);
258             if (nthr_m > (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1)
259                             / BM_SMALL_NOCOPY_AVX512_COMMON)
260                 nthr_m = (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1)
261                         / BM_SMALL_NOCOPY_AVX512_COMMON;
262             nthr_n = nthr / nthr_m;
263
264             while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) {
265                 nthr_m--;
266                 nthr_n = nthr / nthr_m;
267             }
268         } else {
269             nthr_n = (int)sqrt((double)nthr);
270             if (nthr_n > (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1)
271                             / BN_SMALL_NOCOPY_AVX512_COMMON)
272                 nthr_n = (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1)
273                         / BN_SMALL_NOCOPY_AVX512_COMMON;
274             nthr_m = nthr / nthr_n;
275
276             while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) {
277                 nthr_n--;
278                 nthr_m = nthr / nthr_n;
279             }
280         }
281     }
282
283     MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX512_COMMON - 1;
284     MB -= MB % BM_SMALL_NOCOPY_AVX512_COMMON;
285     NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX512_COMMON - 1;
286     NB -= NB % BN_SMALL_NOCOPY_AVX512_COMMON;
287     KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX512_COMMON - 1;
288     KB -= KB % BK_SMALL_NOCOPY_AVX512_COMMON;
289
290     if (MB * nthr_m > m)
291         nthr_m = (m + MB - 1) / MB;
292     if (NB * nthr_n > n)
293         nthr_n = (n + NB - 1) / NB;
294     if (KB * nthr_k > k)
295         nthr_k = (k + KB - 1) / KB;
296
297     *nthrs_m = nthr_m;
298     *nthrs_n = nthr_n;
299     *nthrs_k = nthr_k;
300
301     *BM = MB;
302     *BN = NB;
303     *BK = KB;
304 }
305 #undef BM_NOCOPY_AVX512_COMMON
306 #undef BN_NOCOPY_AVX512_COMMON
307 #undef BK_NOCOPY_AVX512_COMMON
308 #undef BN_LARGE_NOCOPY_AVX512_COMMON
309 #undef BM_SMALL_NOCOPY_AVX512_COMMON
310 #undef BN_SMALL_NOCOPY_AVX512_COMMON
311 #undef BK_SMALL_NOCOPY_AVX512_COMMON
312
313 // Partition n values as equally as possible among nthr threads
314 // and set the offset (t_offset) and number of values (t_block) for ithr
315 // Assumption: 0 <= ithr < nthr
316 void partition_unit_diff(
317         int ithr, int nthr, int n, int *t_offset, int *t_block)
318 {
319     int band = n / nthr;
320     if (band == 0)
321         band = 1;
322     int tail = n - band * nthr;
323     if (tail < 0)
324         tail = 0;
325
326     if (ithr < tail) {
327         band++;
328         *t_offset = band * ithr;
329         *t_block = band;
330     } else {
331         *t_offset = band * ithr + tail;
332         *t_block = band;
333     }
334
335     if (*t_offset >= n) {
336         *t_offset = 0;
337         *t_block = 0;
338     }
339
340     if (*t_offset + *t_block > n) {
341         *t_block = n - *t_offset;
342     }
343 }
344
345 // Sum the m*n values from p_src into p_dst, assuming the two-dimensional
346 // arrays have leading dimensions ld_src and ld_dst, respectively
347 template<typename data_t>
348 void sum_two_matrices(int m, int n,
349         data_t * __restrict p_src, dim_t ld_src,
350         data_t * __restrict p_dst, dim_t ld_dst)
351 {
352     int i, j;
353     for (j = 0; j < n; j++) {
354         for (i = 0; i < m; i++) {
355             p_dst[i + j * ld_dst] += p_src[i + j * ld_src];
356         }
357     }
358 }
359
360 template
361 void sum_two_matrices<float>(int m, int n,
362         float * __restrict p_src, dim_t ld_src,
363         float * __restrict p_dst, dim_t ld_dst);
364
365 template
366 void sum_two_matrices<double>(int m, int n,
367         double * __restrict p_src, dim_t ld_src,
368         double * __restrict p_dst, dim_t ld_dst);
369 }
370 }
371 }
372 }