1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
20 #include "jit_generator.hpp"
21 #include "cpu_batch_normalization_utils.hpp"
26 namespace bnorm_utils {
28 void cache_balance(size_t working_set_size, int C_blks, int &C_blks_per_iter,
30 int nthrs = mkldnn_get_max_threads();
31 int l3_size = get_cache_size(3, true) * nthrs / 2;
33 C_blks_per_iter = l3_size / working_set_size;
35 if (C_blks_per_iter == 0)
37 if (C_blks_per_iter > C_blks)
38 C_blks_per_iter = C_blks;
40 iters = (C_blks + C_blks_per_iter - 1) / C_blks_per_iter;
// Distributes nthr threads over the channel-block (C_blks), minibatch (N) and
// spatial (SP) dimensions. For thread ithr it writes the per-dimension thread
// counts (*_nthr), this thread's index along each dimension (*_ithr), and its
// [start, end) range along each dimension as computed by balance211().
// Returns spatial_thr_allowed, which the body may clear (it is a by-value
// parameter, so the change is only visible through the return value).
//
// NOTE(review): this listing is missing lines. The embedded numbering skips
// 52->55, 57->59, 64->67, 74->76 and 86->88, and the closing brace is absent.
// As shown, the text does not compile and the branch structure cannot be
// fully recovered from it. Restore from upstream before changing logic.
43 bool thread_balance(bool do_blocking, bool spatial_thr_allowed, int ithr,
44 int nthr, int N, int C_blks, int SP, int &C_ithr, int &C_nthr,
45 int &C_blk_s, int &C_blk_e, int &N_ithr, int &N_nthr, int &N_s,
46 int &N_e, int &S_ithr, int &S_nthr, int &S_s, int &S_e) {
// Split over channel blocks only: either there are no more threads than
// channel blocks, or mkldnn_thr_syncable() reports false. The latter
// presumably means threads cannot synchronize, which splitting N/SP would
// require -- confirm. Each thread covers all of N and SP.
47 if (nthr <= C_blks || !mkldnn_thr_syncable()) {
48 C_ithr = ithr; C_nthr = nthr;
49 N_ithr = 0; N_nthr = 1;
50 S_ithr = 0; S_nthr = 1;
51 N_s = 0; N_e = N; S_s = 0; S_e = SP;
52 balance211(C_blks, C_nthr, C_ithr, C_blk_s, C_blk_e);
// NOTE(review): lines 53-54 are missing. Presumably they close the branch
// above (`} else {`) and choose between the two alternative splits below.
// do_blocking is otherwise unused in the visible text, so it is the likely
// selector, but which alternative it selects is not visible -- confirm.
// Alternative 1: N first, then C with the remaining threads, then SP.
55 N_nthr = nstl::min(N, nthr);
56 C_nthr = nstl::min(C_blks, nthr / N_nthr);
57 S_nthr = nstl::min(SP, nthr / (C_nthr * N_nthr));
// NOTE(review): line 58 is missing (presumably the `else` between the two
// alternatives).
// Alternative 2: C first, using math::gcd(nthr, C_blks) (presumably the
// greatest common divisor, so C threads divide both nthr and C_blks evenly),
// then N, then SP with the remaining threads.
59 C_nthr = math::gcd(nthr, C_blks);
60 N_nthr = nstl::min(N, nthr / C_nthr);
61 S_nthr = nstl::min(SP, nthr / (C_nthr * N_nthr));
// NOTE(review): lines 62-63 are missing, and so is line 65, the statement
// this `if` controls (presumably forcing a single spatial thread). As shown,
// line 67 would become this `if`'s body.
64 if (!spatial_thr_allowed)
// S_nthr can be 0 (e.g. when SP is 0, via nstl::min above); keep at least
// one spatial thread so the index math below never divides by zero.
67 if (S_nthr < 1) S_nthr = 1;
// Only the first C_nthr * N_nthr * S_nthr threads receive work. ithr is
// decomposed with S varying fastest, then N, then C.
68 if (ithr < C_nthr * N_nthr * S_nthr) {
69 N_ithr = (ithr / S_nthr) % N_nthr ;
70 C_ithr = ithr / (N_nthr * S_nthr);
71 S_ithr = ithr % S_nthr;
72 balance211(C_blks, C_nthr, C_ithr, C_blk_s, C_blk_e);
73 balance211(N, N_nthr, N_ithr, N_s, N_e);
74 balance211(SP, S_nthr, S_ithr, S_s, S_e);
// NOTE(review): line 75 is missing (presumably `} else {`). The two lines
// below then apply to surplus threads. They get negated indices and -1
// ranges, presumably so callers can recognise and skip them.
76 S_ithr = N_ithr = C_ithr = -ithr;
77 S_s = S_e = N_s = N_e = C_blk_s = C_blk_e = -1;
81 // spatial_thr_allowed is meant to help maintain
82 // consistent decisions about spatial threading
83 // between multiple invocations of this routine.
84 // It is the caller's responsibility to check the
85 // return value and pass it as a flag to the
86 // next call if needed.
// NOTE(review): line 87 is missing. As shown, the assignment below would run
// unconditionally and the function would always return false, which
// contradicts the comment above. A condition on the computed split
// presumably guarded it -- confirm.
88 spatial_thr_allowed = false;
90 return spatial_thr_allowed;
93 bool is_spatial_thr(const batch_normalization_pd_t *bdesc, int simd_w,
95 if (!mkldnn_thr_syncable()) return false;
97 int nthr = mkldnn_get_max_threads();
98 int SP = bdesc->W() * bdesc->D() * bdesc->H();
99 int C_PADDED = memory_desc_wrapper(bdesc->src_pd())
100 .blocking_desc().padding_dims[1];
101 assert(C_PADDED % simd_w == 0);
103 size_t data = bdesc->MB() * C_PADDED * SP * data_size;
104 size_t l3_size_ = get_cache_size(3, true) * nthr / 2;
105 bool do_blocking = (data >= l3_size_ / 2 && l3_size_ > 0);
106 int C_blks_per_iter{ 1 }, iters{ 1 };
107 int C_blks = C_PADDED / simd_w;
110 int num_tensors = bdesc->is_fwd() ? 1 : 2;
111 size_t working_set_size
112 = (bdesc->MB() * SP * simd_w * data_size) * num_tensors;
113 cache_balance(working_set_size, C_blks, C_blks_per_iter, iters);
116 // Spatial threading decision made in this function shall be consistent
117 // with thread_balance() behavior.
118 C_blks = do_blocking ? C_blks_per_iter : C_blks;
120 if (nthr <= C_blks) return false;
124 int N_nthr = nstl::min(bdesc->MB(), nthr);
125 int C_nthr = nstl::min(C_blks, nthr / N_nthr);
126 S_nthr = nstl::min(SP, nthr / (C_nthr * N_nthr));
128 int C_nthr = math::gcd(nthr, C_blks);
129 int N_nthr = nstl::min(bdesc->MB(), nthr / C_nthr);
130 S_nthr = nstl::min(SP, nthr / (C_nthr * N_nthr));