1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
18 #include "mkldnn_thread.hpp"
24 namespace gemm_utils {
// Blocking constants used by calc_nthr_nocopy_avx2.
// BM/BN are the block sizes for the initial ceiling-division thread counts
// along m and n; BK is the threshold the per-thread share of k is compared
// against when deciding whether to split along k.
25 #define BM_NOCOPY_AVX2 64
26 #define BN_NOCOPY_AVX2 48
27 #define BK_NOCOPY_AVX2 384
// NOTE(review): BN_LARGE_NOCOPY_AVX2 is not referenced by the code visible
// here -- confirm whether it is still needed.
28 #define BN_LARGE_NOCOPY_AVX2 192
// The *_SMALL values are the granularities that the per-thread block sizes
// MB/NB/KB are rounded up to; BM_SMALL/BN_SMALL also cap the thread counts in
// the sqrt-based fallback split.
29 #define BM_SMALL_NOCOPY_AVX2 16
30 #define BN_SMALL_NOCOPY_AVX2 1
31 #define BK_SMALL_NOCOPY_AVX2 4
32 // Determine number of threads for each dimension of a 3-D partitioning
33 // algorithm based on input parameters
34 // m/n/k - First/second/third parameter for GEMM
35 // nthrs - total available number of threads
36 // nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension
37 // BM/BN/BK - blocking values
38 void calc_nthr_nocopy_avx2(int m, int n, int k,
39 int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, int *BM, int *BN,
42 int nthr, nthr_m, nthr_n, nthr_k, nthr_other;
// Initial estimate: one thread per BM_NOCOPY_AVX2 x BN_NOCOPY_AVX2 block,
// using ceiling division so a partial block still gets a thread.
46 nthr_m = (m + BM_NOCOPY_AVX2 - 1) / BM_NOCOPY_AVX2;
47 nthr_n = (n + BN_NOCOPY_AVX2 - 1) / BN_NOCOPY_AVX2;
50 // Partition along K dimension if there is not enough parallelism along M or N
52 nthr_other = nthr_k = 1;
// nthr_other is the candidate split count along k. The search continues while
// the m/n/k thread product is below nthr and one more split would still leave
// each thread more than BK_NOCOPY_AVX2 elements of k.
53 while ((nthr_m * nthr_n * nthr_other < nthr)
54 && (k / (nthr_other + 1) > BK_NOCOPY_AVX2)) {
// (nthr / nthr_other) * nthr_other is the largest multiple of nthr_other that
// does not exceed nthr; the test passes when that multiple covers more than
// 90% of nthr.
56 if ((nthr / nthr_other) * nthr_other > 0.9 * nthr)
66 // Simple partition reduction
67 while (nthr_m * nthr_n > nthr)
72 while (nthr_m * nthr_n < nthr)
// Fallback: the product still exceeds nthr and both dimensions use more than
// one thread, so rebuild the m/n split starting from sqrt(nthr).
78 if ((nthr_m * nthr_n > nthr) && (nthr_m > 1) && (nthr_n > 1)) {
80 if (nthr_m <= nthr_n) {
// Seed the smaller dimension (m) with floor(sqrt(nthr)), capped at the number
// of BM_SMALL_NOCOPY_AVX2-sized blocks in m; n receives the integer quotient.
81 nthr_m = (int)sqrt((double)nthr);
82 if (nthr_m > (m + BM_SMALL_NOCOPY_AVX2 - 1) / BM_SMALL_NOCOPY_AVX2)
83 nthr_m = (m + BM_SMALL_NOCOPY_AVX2 - 1) / BM_SMALL_NOCOPY_AVX2;
84 nthr_n = nthr / nthr_m;
// Repeats while nthr_m > 1 and the product is not exactly nthr, recomputing
// nthr_n from the current nthr_m each time.
86 while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) {
88 nthr_n = nthr / nthr_m;
// Mirror case: seed n with floor(sqrt(nthr)), capped at the number of
// BN_SMALL_NOCOPY_AVX2-sized blocks in n; m receives the integer quotient.
91 nthr_n = (int)sqrt((double)nthr);
92 if (nthr_n > (n + BN_SMALL_NOCOPY_AVX2 - 1) / BN_SMALL_NOCOPY_AVX2)
93 nthr_n = (n + BN_SMALL_NOCOPY_AVX2 - 1) / BN_SMALL_NOCOPY_AVX2;
94 nthr_m = nthr / nthr_n;
96 while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) {
98 nthr_m = nthr / nthr_n;
// Per-thread block sizes: the ceiling share of each dimension, rounded up to
// a multiple of the matching *_SMALL granularity (add granularity - 1, then
// subtract the remainder).
103 MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX2 - 1;
104 MB -= MB % BM_SMALL_NOCOPY_AVX2;
105 NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX2 - 1;
106 NB -= NB % BN_SMALL_NOCOPY_AVX2;
107 KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX2 - 1;
108 KB -= KB % BK_SMALL_NOCOPY_AVX2;
// Recompute the thread counts from the rounded block sizes (ceiling
// division).
111 nthr_m = (m + MB - 1) / MB;
113 nthr_n = (n + NB - 1) / NB;
115 nthr_k = (k + KB - 1) / KB;
// Remove the AVX2 blocking constants so they are not defined for the code
// that follows.
125 #undef BM_NOCOPY_AVX2
126 #undef BN_NOCOPY_AVX2
127 #undef BK_NOCOPY_AVX2
128 #undef BN_LARGE_NOCOPY_AVX2
129 #undef BM_SMALL_NOCOPY_AVX2
130 #undef BN_SMALL_NOCOPY_AVX2
131 #undef BK_SMALL_NOCOPY_AVX2
// Blocking constants used by calc_nthr_nocopy_avx512_common.
// BM/BN are the block sizes for the initial ceiling-division thread counts
// along m and n and for the size test that gates the k split; BK is the chunk
// size used to derive the candidate k split (k / BK).
133 #define BM_NOCOPY_AVX512_COMMON 32
134 #define BN_NOCOPY_AVX512_COMMON 64
135 #define BK_NOCOPY_AVX512_COMMON 192
// NOTE(review): BN_LARGE_NOCOPY_AVX512_COMMON is not referenced by the code
// visible here -- confirm whether it is still needed.
136 #define BN_LARGE_NOCOPY_AVX512_COMMON 192
// The *_SMALL values are the granularities that the per-thread block sizes
// MB/NB/KB are rounded up to; BM_SMALL/BN_SMALL also cap the thread counts in
// the sqrt-based fallback split.
137 #define BM_SMALL_NOCOPY_AVX512_COMMON 16
138 #define BN_SMALL_NOCOPY_AVX512_COMMON 1
139 #define BK_SMALL_NOCOPY_AVX512_COMMON 4
140 // Determine number of threads for each dimension of a 3-D partitioning
141 // algorithm based on input parameters
142 // m/n/k - First/second/third parameter for GEMM
143 // nthrs - total available number of threads
144 // nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension
145 // BM/BN/BK - blocking values
146 void calc_nthr_nocopy_avx512_common(int m,
147 int n, int k, int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k,
148 int *BM, int *BN, int *BK)
150 int nthr, nthr_m, nthr_n, nthr_k;
// ratio_float later holds nthr_m / nthr_n as a float; 1.0 is its starting
// value.
155 float ratio_float = 1.;
160 /* Partition along K dimension if there is enough K and there is not enough
// The k split is only considered when n is at most two BN blocks wide and m
// is at most two BM blocks per available thread.
162 if (n <= 2 * BN_NOCOPY_AVX512_COMMON &&
163 m <= 2 * BM_NOCOPY_AVX512_COMMON * nthr) {
// Candidate k split: the number of whole BK-sized chunks in k (integer
// division), then compared against a quarter of the available threads.
164 nthr_k = k / BK_NOCOPY_AVX512_COMMON;
165 if (nthr_k > nthr / 4)
// Repeats while nthr_k > 1 and nthr_k does not divide nthr evenly.
170 while ((nthr_k > 1) && (nthr % nthr_k)) {
// Initial m/n estimate: one thread per BM x BN block, using ceiling division.
177 nthr_m = (m + BM_NOCOPY_AVX512_COMMON - 1) / BM_NOCOPY_AVX512_COMMON;
178 nthr_n = (n + BN_NOCOPY_AVX512_COMMON - 1) / BN_NOCOPY_AVX512_COMMON;
// Record which dimension has more threads and the float ratio between them.
185 nthr_m_gt_n = nthr_m > nthr_n ? 1 : 0;
186 ratio_float = (float)nthr_m / nthr_n;
// ratio is the truncated ratio_float or the truncated reciprocal.
// NOTE(review): presumably selected by nthr_m_gt_n so that ratio >= 1 --
// confirm against the surrounding branch.
189 ratio = (int)ratio_float;
191 ratio = (int)(1. / ratio_float);
193 // scale down nthr_m and nthr_n if they are too large
194 while (nthr_m * nthr_n > 4 * nthr) {
204 // Simple partition reduction
206 while (nthr_m * nthr_n > nthr) {
207 if (nthr_m > nthr_n) {
225 // Simple partition increment
// Runs while fewer than 95% of the nthr threads are covered by the m x n
// split.
227 while (nthr_m * nthr_n < 0.95 * nthr) {
228 if (nthr_m > nthr_n) {
246 // if nothing works out, then this should work
// Unlike the AVX2 variant, this fallback does not also require nthr_m > 1
// and nthr_n > 1.
247 if ((nthr_m * nthr_n > nthr)) {
249 if (nthr_m <= nthr_n) {
// Seed the smaller dimension (m) with floor(sqrt(nthr)), capped at the number
// of BM_SMALL-sized blocks in m; n receives the integer quotient.
250 nthr_m = (int)sqrt((double)nthr);
251 if (nthr_m > (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1)
252 / BM_SMALL_NOCOPY_AVX512_COMMON)
253 nthr_m = (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1)
254 / BM_SMALL_NOCOPY_AVX512_COMMON;
255 nthr_n = nthr / nthr_m;
// Repeats while nthr_m > 1 and the product is not exactly nthr, recomputing
// nthr_n from the current nthr_m each time.
257 while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) {
259 nthr_n = nthr / nthr_m;
// Mirror case: seed n with floor(sqrt(nthr)), capped at the number of
// BN_SMALL-sized blocks in n; m receives the integer quotient.
262 nthr_n = (int)sqrt((double)nthr);
263 if (nthr_n > (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1)
264 / BN_SMALL_NOCOPY_AVX512_COMMON)
265 nthr_n = (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1)
266 / BN_SMALL_NOCOPY_AVX512_COMMON;
267 nthr_m = nthr / nthr_n;
269 while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) {
271 nthr_m = nthr / nthr_n;
// Per-thread block sizes: the ceiling share of each dimension, rounded up to
// a multiple of the matching *_SMALL granularity (add granularity - 1, then
// subtract the remainder).
276 MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX512_COMMON - 1;
277 MB -= MB % BM_SMALL_NOCOPY_AVX512_COMMON;
278 NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX512_COMMON - 1;
279 NB -= NB % BN_SMALL_NOCOPY_AVX512_COMMON;
280 KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX512_COMMON - 1;
281 KB -= KB % BK_SMALL_NOCOPY_AVX512_COMMON;
// Recompute the thread counts from the rounded block sizes (ceiling
// division).
284 nthr_m = (m + MB - 1) / MB;
286 nthr_n = (n + NB - 1) / NB;
288 nthr_k = (k + KB - 1) / KB;
// Remove the AVX-512 blocking constants so they are not defined for the code
// that follows.
298 #undef BM_NOCOPY_AVX512_COMMON
299 #undef BN_NOCOPY_AVX512_COMMON
300 #undef BK_NOCOPY_AVX512_COMMON
301 #undef BN_LARGE_NOCOPY_AVX512_COMMON
302 #undef BM_SMALL_NOCOPY_AVX512_COMMON
303 #undef BN_SMALL_NOCOPY_AVX512_COMMON
304 #undef BK_SMALL_NOCOPY_AVX512_COMMON
306 // Partition n values as equally as possible among nthr threads
307 // and set the offset (t_offset) and number of values (t_block) for ithr
308 // Assumption: 0 <= ithr < nthr
// ithr     - index of the thread whose slice is computed
// nthr     - number of threads sharing the n values
// t_offset - out: index of the first value assigned to ithr
// t_block  - out: number of values assigned to ithr
309 void partition_unit_diff(
310 int ithr, int nthr, int n, int *t_offset, int *t_block)
// tail is what remains of n after handing `band` values to each of the nthr
// threads.
315 int tail = n - band * nthr;
// The offset is either band * ithr or band * ithr shifted by the whole tail.
// NOTE(review): presumably the shifted form applies to threads past the first
// `tail` ones -- confirm against the branch condition.
321 *t_offset = band * ithr;
324 *t_offset = band * ithr + tail;
// Guard for an offset at or beyond the end of the range.
328 if (*t_offset >= n) {
// Clamp the block so offset + block never runs past n.
333 if (*t_offset + *t_block > n) {
334 *t_block = n - *t_offset;
338 // Sum the m*n values from p_src into p_dst, assuming the two-dimensional
339 // arrays have leading dimensions ld_src and ld_dst, respectively
// Element (i, j) is addressed as i + j * ld, so i (0..m-1) is the contiguous
// index and j (0..n-1) strides by the leading dimension. p_dst is updated in
// place; p_src is only read.
340 void sum_two_matrices(
341 int m, int n, float *p_src, int ld_src, float *p_dst, int ld_dst)
// Outer loop over j, inner loop over the contiguous index i.
344 for (j = 0; j < n; j++) {
345 for (i = 0; i < m; i++) {
346 p_dst[i + j * ld_dst] += p_src[i + j * ld_src];