/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <math.h>

#include "c_types_map.hpp"
#include "type_helpers.hpp"

#include "cpu_batch_normalization_utils.hpp"
#include "jit_generator.hpp"

#include "ncsp_batch_normalization.hpp"
// clang 6 and 7 generate incorrect code with OMP_SIMD in some particular cases
#if (defined __clang_major__) && (__clang_major__ >= 6)
#define SAFE_TO_USE_OMP_SIMD 0
#else
#define SAFE_TO_USE_OMP_SIMD 1
#endif
namespace mkldnn {
namespace impl {
namespace cpu {

using namespace memory_tracking::names;
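// Forward batch normalization over the plain (ncsp) layout. When statistics
// have to be computed, every thread writes its partial per-channel sums into
// the ws_reduce scratchpad, the threads are synchronized with barriers, and
// the partial sums are then combined into per-channel mean and variance
// before the final normalization sweep over the data.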
void ncsp_batch_normalization_fwd_t::execute_forward() const {
    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto dst = reinterpret_cast<data_t *>(this->memory(0));

    auto scratchpad = this->scratchpad();

    const bool calculate_stats = !pd()->stats_is_src();
    const bool save_stats = pd()->is_training();
    const bool is_training = pd()->is_training();
    const bool fuse_bn_relu = pd()->fuse_bn_relu();
    data_t *mean, *variance;
    if (!calculate_stats) {
        mean = reinterpret_cast<data_t *>(
                const_cast<char *>(this->input_memory(1)));
        variance = reinterpret_cast<data_t *>(
                const_cast<char *>(this->input_memory(2)));
    } else {
        if (save_stats) {
            mean = reinterpret_cast<data_t *>(this->memory(1));
            variance = reinterpret_cast<data_t *>(this->memory(2));
        } else {
            mean = scratchpad.get<data_t>(key_bnorm_tmp_mean);
            variance = scratchpad.get<data_t>(key_bnorm_tmp_var);
        }
    }
    auto idx_scale_shift = 1 + 2 * pd()->stats_is_src();
    auto scaleshift = reinterpret_cast<const data_t *>(
            this->input_memory(idx_scale_shift));
    auto ws = reinterpret_cast<uint8_t *>(this->memory(pd()->ws_idx()));
    auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction);
    const float eps = pd()->desc()->batch_norm_epsilon;
    const bool use_scaleshift = pd()->use_scaleshift();
    const bool with_relu = pd()->with_relu_post_op();
    auto maybe_post_op
            = [&](data_t res) { return (with_relu && res < 0) ? 0 : res; };
    const bool has_spatial = utils::one_of(pd()->ndims(), 4, 5);
    int SP = (has_spatial) ? pd()->H() * pd()->W() * pd()->D() : 1;
    size_t N = pd()->MB();
    size_t C = pd()->C();
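    // Heuristic: if the tensor does not comfortably fit into (a share of)
    // the last level cache, process the channels in blocks so that the
    // working set of a single pass stays cache-resident.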
    int nthr = mkldnn_get_max_threads();
    size_t l3_size_ = get_cache_size(3, true) * nthr / 2;
    size_t data_size = N * C * SP * sizeof(data_t);
    bool do_blocking = (data_size >= l3_size_ / 2 && l3_size_ > 0);
    parallel(0, [&](const int ithr, const int nthr) {
        int C_blks_per_iter = 1, iters = 1;
        int C_ithr = 0, C_nthr = 0, N_ithr = 0, N_nthr = 0, N_s = 0, N_e = 0;
        int S_ithr = 0, S_nthr = 0, S_s = 0, S_e = 0;
        int C_blk_gl_s = 0, C_blk_gl_e = 0, C_blk_s = 0, C_blk_e = 0;
        if (do_blocking) {
            size_t working_set_size = N * SP * sizeof(data_t);
            bnorm_utils::cache_balance(
                    working_set_size, C, C_blks_per_iter, iters);
        } else
            C_blks_per_iter = C;
        int last_iter_blks = C - (iters - 1) * C_blks_per_iter;
        bool spatial_thr_allowed
                = bnorm_utils::thread_balance(do_blocking, true, ithr, nthr, N,
                        C_blks_per_iter, SP, C_ithr, C_nthr, C_blk_s, C_blk_e,
                        N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s, S_e);
        balance211(C_blks_per_iter, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
        int SP_N_ithr = N_ithr * S_nthr + S_ithr;
        int SP_N_nthr = N_nthr * S_nthr;
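        // Each iteration processes one block of C_blks_per_iter channels.
        // Within a block the threads form a (C, N, SP) grid; the N and SP
        // dimensions cooperate on the per-channel reductions below.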
        for (int it = 0; it < iters; ++it) {
            if (it == iters - 1 && iters > 1) {
                // On the last iteration the access pattern to ws_reduce
                // might change (due to re-balance on C). So sync the
                // threads if they are not synced by the algorithm.
                if (SP_N_nthr == 1 && mkldnn_thr_syncable())
                    mkldnn_thr_barrier();

                S_s = S_e = C_blk_s = C_blk_e = N_s = N_e = 0;
                spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking,
                        spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP,
                        C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s,
                        N_e, S_ithr, S_nthr, S_s, S_e);
                balance211(last_iter_blks, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
                SP_N_ithr = N_ithr * S_nthr + S_ithr;
                SP_N_nthr = N_nthr * S_nthr;
            }
            size_t C_off = it * C_blks_per_iter;
            // On the last iteration the access pattern to ws_reduce
            // might change (due to re-balance on C). Since sync is not always
            // possible (in case of TBB) use different parts of ws for each
            // iteration if threads are not synced by the algorithm.
            size_t ws_iter_off = (mkldnn_thr_syncable() ? 0 : 1) * C_off;
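            // ws_reduce is laid out as an [SP_N_nthr][C_blks_per_iter] table:
            // row SP_N_ithr holds this thread's partial per-channel sums,
            // which are combined column-wise after the barriers below.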
            if (calculate_stats) {
                data_t *mean_blk = mean + C_off;
                data_t *variance_blk = variance + C_off;
                for (int c = C_blk_s; c < C_blk_e; c++) {
                    size_t off = (c + C_off) * SP;
                    data_t sum = 0;
                    for (int n = N_s; n < N_e; ++n)
                        PRAGMA_OMP_SIMD(reduction(+ : sum))
                        for (int sp = S_s; sp < S_e; ++sp) {
                            sum += src[off + n * C * SP + sp];
                        }
                    ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c]
                            = sum;
                }

                if (SP_N_nthr > 1) mkldnn_thr_barrier();
                for (int c = C_blk_gl_s; c < C_blk_gl_e; c++) {
                    mean_blk[c] = 0.;
                    for (int n = 0; n < SP_N_nthr; n++)
                        mean_blk[c] += ws_reduce[ws_iter_off
                                + n * C_blks_per_iter + c];
                    mean_blk[c] /= (N * SP);
                }

                if (SP_N_nthr > 1) mkldnn_thr_barrier();
                for (int c = C_blk_s; c < C_blk_e; c++) {
                    size_t off = c + C_off;
                    data_t sum = 0.;
                    for (int n = N_s; n < N_e; ++n)
                        PRAGMA_OMP_SIMD(reduction(+ : sum))
                        for (int sp = S_s; sp < S_e; ++sp) {
                            data_t m = src[off * SP + n * C * SP + sp]
                                    - mean[off];
                            sum += m * m;
                        }
                    ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c]
                            = sum;
                }

                if (SP_N_nthr > 1) mkldnn_thr_barrier();
                for (int c = C_blk_gl_s; c < C_blk_gl_e; c++) {
                    variance_blk[c] = 0.;
                    for (int n = 0; n < SP_N_nthr; n++)
                        variance_blk[c] += ws_reduce[ws_iter_off
                                + n * C_blks_per_iter + c];
                    variance_blk[c] /= (N * SP);
                }

                if (SP_N_nthr > 1) mkldnn_thr_barrier();
            }
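            // Normalization sweep:
            //   dst = sm * (src - mean) / sqrt(var + eps) + sv,
            // with an optional fused ReLU on top.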
            for (int c = C_blk_s; c < C_blk_e; c++) {
                size_t off = c + C_off;
                data_t sm = use_scaleshift ? scaleshift[off] : 1;
                data_t sv = use_scaleshift ? scaleshift[C + off] : 0;
                data_t sqrt_variance = static_cast<data_t>(
                        1.0f / sqrtf(variance[off] + eps));
                for (int n = N_s; n < N_e; ++n)
#if SAFE_TO_USE_OMP_SIMD
                    PRAGMA_OMP_SIMD()
#endif
                    for (int sp = S_s; sp < S_e; ++sp) {
                        size_t d_off = off * SP + n * C * SP + sp;
                        data_t bn_res = sm * (src[d_off] - mean[off])
                                        * sqrt_variance + sv;
                        if (fuse_bn_relu) {
                            // ws keeps the ReLU mask for the backward pass
                            if (is_training) ws[d_off] = bn_res > 0;
                            if (bn_res <= 0) bn_res = 0;
                        }
                        dst[d_off] = maybe_post_op(bn_res);
                    }
            }
        }
    });
}
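// Backward batch normalization: first reduce diff_gamma and diff_beta per
// channel into ws_reduce (two stacked tables), then use them together with
// the saved mean/variance to compute diff_src, optionally applying the ReLU
// mask recorded in ws during the forward pass.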
void ncsp_batch_normalization_bwd_t::execute_backward() const {
    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto mean = reinterpret_cast<const data_t *>(this->input_memory(1));
    auto variance = reinterpret_cast<const data_t *>(this->input_memory(2));
    auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(3));
    auto scaleshift = reinterpret_cast<const data_t *>(this->input_memory(4));
    auto diff_src = reinterpret_cast<data_t *>(this->memory(0));

    auto scratchpad = this->scratchpad();

    auto diff_scaleshift = this->memory(1)
            ? reinterpret_cast<data_t *>(this->memory(1))
            : scratchpad.get<data_t>(key_bnorm_tmp_diff_ss);
    auto ws = reinterpret_cast<const uint8_t *>(
            this->input_memory(pd()->ws_idx()));
    auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction);
    const bool has_spatial = utils::one_of(pd()->ndims(), 4, 5);
    int SP = (has_spatial) ? pd()->H() * pd()->W() * pd()->D() : 1;
    size_t C = pd()->C(), N = pd()->MB();
    const bool use_scaleshift = pd()->use_scaleshift();
    const float eps = pd()->desc()->batch_norm_epsilon;
    const bool calculate_diff_stats = !pd()->use_global_stats();
    const bool fuse_bn_relu = pd()->fuse_bn_relu();

    int nthr = mkldnn_get_max_threads();
    size_t l3_size_ = get_cache_size(3, true) * nthr / 2;
    size_t data_size = N * C * SP * sizeof(data_t);
    bool do_blocking = (data_size >= l3_size_ / 2 && l3_size_ > 0);
    parallel(0, [&](const int ithr, const int nthr) {
        int C_blks_per_iter = 1, iters = 1;
        int C_ithr = 0, C_nthr = 0, N_ithr = 0, N_nthr = 0, N_s = 0, N_e = 0;
        int S_ithr = 0, S_nthr = 0, S_s = 0, S_e = 0;
        int C_blk_gl_s = 0, C_blk_gl_e = 0, C_blk_s = 0, C_blk_e = 0;
        if (do_blocking) {
            // Backward touches both src and diff_dst, hence the factor of 2.
            size_t working_set_size = 2 * N * SP * sizeof(data_t);
            bnorm_utils::cache_balance(
                    working_set_size, C, C_blks_per_iter, iters);
        } else
            C_blks_per_iter = C;
        int last_iter_blks = C - (iters - 1) * C_blks_per_iter;
        bool spatial_thr_allowed
                = bnorm_utils::thread_balance(do_blocking, true, ithr, nthr, N,
                        C_blks_per_iter, SP, C_ithr, C_nthr, C_blk_s, C_blk_e,
                        N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s, S_e);
        balance211(C_blks_per_iter, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
        int SP_N_ithr = N_ithr * S_nthr + S_ithr;
        int SP_N_nthr = N_nthr * S_nthr;
        for (int it = 0; it < iters; ++it) {
            if (it == iters - 1 && iters > 1) {
                // On the last iteration the access pattern to ws_reduce
                // might change (due to re-balance on C). So sync the
                // threads if they are not synced by the algorithm.
                if (SP_N_nthr == 1 && mkldnn_thr_syncable())
                    mkldnn_thr_barrier();

                S_s = S_e = C_blk_s = C_blk_e = N_s = N_e = 0;
                spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking,
                        spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP,
                        C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s,
                        N_e, S_ithr, S_nthr, S_s, S_e);
                balance211(last_iter_blks, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
                SP_N_ithr = N_ithr * S_nthr + S_ithr;
                SP_N_nthr = N_nthr * S_nthr;
            }
            size_t C_off = it * C_blks_per_iter;
            // On the last iteration the access pattern to ws_reduce
            // might change (due to re-balance on C). Since sync is not always
            // possible (in case of TBB) use different parts of ws for each
            // iteration if threads are not synced by the algorithm.
            size_t ws_iter_off = (mkldnn_thr_syncable() ? 0 : 1) * 2 * C_off;
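            // Here ws_reduce holds two stacked [SP_N_nthr][C_blks_per_iter]
            // tables: partial diff_gamma sums first, then partial diff_beta
            // sums at offset SP_N_nthr * C_blks_per_iter.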
            data_t *diff_gamma_blk = diff_scaleshift + C_off;
            data_t *diff_beta_blk = diff_scaleshift + C + C_off;
            for (int c = C_blk_s; c < C_blk_e; c++) {
                size_t off = c + C_off;
                data_t diff_gamma = 0.0, diff_beta = 0.0;
                data_t v_mean = mean[off];
                for (int n = N_s; n < N_e; ++n)
                    PRAGMA_OMP_SIMD(reduction(+ : diff_gamma, diff_beta))
                    for (int sp = S_s; sp < S_e; ++sp) {
                        const size_t d_off = off * SP + n * C * SP + sp;
                        data_t dd;
                        if (fuse_bn_relu)
                            // zero the gradients of values that the fused
                            // ReLU clipped, according to the forward ws mask
                            dd = (!ws[d_off]) ? 0 : diff_dst[d_off];
                        else
                            dd = diff_dst[d_off];
                        diff_gamma += (src[d_off] - v_mean) * dd;
                        diff_beta += dd;
                    }
                ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c]
                        = diff_gamma;
                ws_reduce[ws_iter_off + SP_N_nthr * C_blks_per_iter
                        + SP_N_ithr * C_blks_per_iter + c] = diff_beta;
            }
            if (SP_N_nthr > 1) mkldnn_thr_barrier();
            for (int c = C_blk_gl_s; c < C_blk_gl_e; c++) {
                data_t sqrt_variance = static_cast<data_t>(
                        1.0f / sqrtf(variance[c + C_off] + eps));
                diff_gamma_blk[c] = 0.;
                diff_beta_blk[c] = 0.;
                for (int n = 0; n < SP_N_nthr; n++) {
                    diff_gamma_blk[c] += ws_reduce[ws_iter_off
                            + n * C_blks_per_iter + c];
                    diff_beta_blk[c] += ws_reduce[ws_iter_off
                            + SP_N_nthr * C_blks_per_iter + n * C_blks_per_iter
                            + c];
                }
                diff_gamma_blk[c] *= sqrt_variance;
            }
            if (SP_N_nthr > 1) mkldnn_thr_barrier();
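            // diff_src = gamma / sqrt(var + eps) * (diff_dst
            //         - diff_beta / (N * SP)
            //         - (src - mean) * diff_gamma * sqrt_variance / (N * SP))
            // (note: diff_gamma_blk already carries one sqrt_variance factor);
            // the statistics terms are skipped when global stats are used.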
            for (int c = C_blk_s; c < C_blk_e; c++) {
                size_t off = c + C_off;
                data_t gamma = use_scaleshift ? scaleshift[off] : 1;
                data_t sqrt_variance = static_cast<data_t>(
                        1.0f / sqrtf(variance[off] + eps));
                data_t v_mean = mean[off];
                for (int n = N_s; n < N_e; ++n)
#if SAFE_TO_USE_OMP_SIMD
                    PRAGMA_OMP_SIMD()
#endif
                    for (int sp = S_s; sp < S_e; ++sp) {
                        const size_t d_off = off * SP + n * C * SP + sp;
                        data_t v_diff_src;
                        if (fuse_bn_relu)
                            v_diff_src = (!ws[d_off]) ? 0 : diff_dst[d_off];
                        else
                            v_diff_src = diff_dst[d_off];
                        if (calculate_diff_stats) {
                            v_diff_src -= diff_beta_blk[c] / (SP * N)
                                    + (src[d_off] - v_mean) * diff_gamma_blk[c]
                                            * sqrt_variance / (SP * N);
                        }
                        v_diff_src *= gamma * sqrt_variance;
                        diff_src[d_off] = v_diff_src;
                    }
            }
        }
    });
}
}
}
}

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s