/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>
#include <math.h>

#include "c_types_map.hpp"
#include "type_helpers.hpp"

#include "cpu_batch_normalization_utils.hpp"
#include "jit_generator.hpp"

#include "ncsp_batch_normalization.hpp"

// clang 6 and 7 generate incorrect code with OMP_SIMD in some particular cases
#if (defined __clang_major__) && (__clang_major__ >= 6)
#define SAFE_TO_USE_OMP_SIMD 0
#else
#define SAFE_TO_USE_OMP_SIMD 1
#endif

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace memory_tracking::names;

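// Forward pass: per channel, dst = gamma * (src - mean) / sqrt(variance + eps)
// + beta. Statistics are either taken from the inputs (stats_is_src) or
// computed here over the minibatch and spatial dimensions; in training they
// are also written to the mean/variance outputs. Without scale/shift,
// gamma = 1 and beta = 0.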
void ncsp_batch_normalization_fwd_t::execute_forward() const {
    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto dst = reinterpret_cast<data_t *>(this->memory(0));
    auto scratchpad = this->scratchpad();

    const bool calculate_stats = !pd()->stats_is_src();
    const bool save_stats = pd()->is_training();
    const bool is_training = pd()->is_training();
    const bool fuse_bn_relu = pd()->fuse_bn_relu();

    data_t *mean, *variance;
    if (!calculate_stats) {
        mean = reinterpret_cast<data_t *>(
                const_cast<char *>(this->input_memory(1)));
        variance = reinterpret_cast<data_t *>(
                const_cast<char *>(this->input_memory(2)));
    } else {
        if (save_stats) {
            mean = reinterpret_cast<data_t *>(this->memory(1));
            variance = reinterpret_cast<data_t *>(this->memory(2));
        } else {
            mean = scratchpad.get<data_t>(key_bnorm_tmp_mean);
            variance = scratchpad.get<data_t>(key_bnorm_tmp_var);
        }
    }
    auto idx_scale_shift = 1 + 2 * pd()->stats_is_src();
    auto scaleshift = reinterpret_cast<const data_t *>(
            this->input_memory(idx_scale_shift));
    auto ws = reinterpret_cast<uint8_t *>(this->memory(pd()->ws_idx()));
    auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction);

    const float eps = pd()->desc()->batch_norm_epsilon;
    const bool use_scaleshift = pd()->use_scaleshift();
    const bool with_relu = pd()->with_relu_post_op();
    auto maybe_post_op
            = [&](data_t res) { return (with_relu && res < 0) ? 0 : res; };
    const bool has_spatial = utils::one_of(pd()->ndims(), 4, 5);
    int SP = (has_spatial) ? pd()->H() * pd()->W() * pd()->D() : 1;
    size_t N = pd()->MB();
    size_t C = pd()->C();

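    // Blocking heuristic: if the data is large relative to the (estimated)
    // aggregate L3 size, process the channels in blocks so the working set
    // of one block stays cache-resident across the passes below.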
    int nthr = mkldnn_get_max_threads();
    size_t l3_size_ = get_cache_size(3, true) * nthr / 2;
    size_t data_size = N * C * SP * sizeof(data_t);
    bool do_blocking = (data_size >= l3_size_ / 2 && l3_size_ > 0);

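    // Threads are decomposed over channel blocks (C), minibatch (N), and
    // spatial (SP) dimensions; thread_balance() chooses the decomposition
    // and balance211() assigns each thread its global channel range.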
    parallel(0, [&](const int ithr, const int nthr) {
        int C_blks_per_iter = 1, iters = 1;
        int C_ithr = 0, C_nthr = 0, N_ithr = 0, N_nthr = 0, N_s = 0, N_e = 0;
        int S_ithr = 0, S_nthr = 0, S_s = 0, S_e = 0;
        int C_blk_gl_s = 0, C_blk_gl_e = 0, C_blk_s = 0, C_blk_e = 0;
        if (do_blocking) {
            size_t working_set_size = N * SP * sizeof(data_t);
            bnorm_utils::cache_balance(
                    working_set_size, C, C_blks_per_iter, iters);
        } else
            C_blks_per_iter = C;
        int last_iter_blks = C - (iters - 1) * C_blks_per_iter;
        bool spatial_thr_allowed
                = bnorm_utils::thread_balance(do_blocking, true, ithr, nthr, N,
                        C_blks_per_iter, SP, C_ithr, C_nthr, C_blk_s, C_blk_e,
                        N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s, S_e);
        balance211(C_blks_per_iter, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
        int SP_N_ithr = N_ithr * S_nthr + S_ithr;
        int SP_N_nthr = N_nthr * S_nthr;
        for (int it = 0; it < iters; ++it) {
            if (it == iters - 1 && iters > 1) {
                // On the last iteration the access pattern to ws_reduce
                // might change (due to re-balance on C). So sync the
                // threads if they are not synced by the algorithm.
                if (SP_N_nthr == 1 && mkldnn_thr_syncable())
                    mkldnn_thr_barrier();

                S_s = S_e = C_blk_s = C_blk_e = N_s = N_e = 0;
                spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking,
                        spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP,
                        C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s,
                        N_e, S_ithr, S_nthr, S_s, S_e);
                balance211(last_iter_blks, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
                SP_N_ithr = N_ithr * S_nthr + S_ithr;
                SP_N_nthr = N_nthr * S_nthr;
            }
            size_t C_off = it * C_blks_per_iter;
            // On the last iteration the access pattern to ws_reduce
            // might change (due to re-balance on C). Since sync is not always
            // possible (in case of TBB) use different parts of ws for each
            // iteration if threads are not synced by the algorithm.
            size_t ws_iter_off = (mkldnn_thr_syncable() ? 0 : 1) * C_off;

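            // Mean and variance are each computed in two phases: every
            // thread accumulates a partial sum over its N x SP sub-range
            // into ws_reduce, and after a barrier the partial sums are
            // combined per channel and divided by N * SP.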
            if (calculate_stats) {
                data_t *mean_blk = mean + C_off;
                data_t *variance_blk = variance + C_off;
                for (int c = C_blk_s; c < C_blk_e; c++) {
                    size_t off = (c + C_off) * SP;
                    data_t sum = 0;
                    for (int n = N_s; n < N_e; ++n)
                        PRAGMA_OMP_SIMD(reduction(+ : sum))
                        for (int sp = S_s; sp < S_e; ++sp) {
                            sum += src[off + n * C * SP + sp];
                        }
                    ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c]
                            = sum;
                }

                if (SP_N_nthr > 1) mkldnn_thr_barrier();

                for (int c = C_blk_gl_s; c < C_blk_gl_e; c++) {
                    mean_blk[c] = 0.;
                    for (int n = 0; n < SP_N_nthr; n++)
                        mean_blk[c] += ws_reduce[ws_iter_off
                                + n * C_blks_per_iter + c];
                    mean_blk[c] /= (N * SP);
                }

                if (SP_N_nthr > 1) mkldnn_thr_barrier();

                for (int c = C_blk_s; c < C_blk_e; c++) {
                    size_t off = c + C_off;
                    data_t sum = 0.;
                    for (int n = N_s; n < N_e; ++n)
                        PRAGMA_OMP_SIMD(reduction(+ : sum))
                        for (int sp = S_s; sp < S_e; ++sp) {
                            data_t m = src[off * SP + n * C * SP + sp]
                                    - mean[off];
                            sum += m * m;
                        }
                    ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c]
                            = sum;
                }

                if (SP_N_nthr > 1) mkldnn_thr_barrier();

                for (int c = C_blk_gl_s; c < C_blk_gl_e; c++) {
                    variance_blk[c] = 0.;
                    for (int n = 0; n < SP_N_nthr; n++)
                        variance_blk[c] += ws_reduce[ws_iter_off
                                + n * C_blks_per_iter + c];
                    variance_blk[c] /= (N * SP);
                }

                if (SP_N_nthr > 1) mkldnn_thr_barrier();
            }

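            // Normalization pass: note that sqrt_variance in fact holds the
            // reciprocal 1 / sqrt(variance + eps), so normalizing is a single
            // multiply. With fused ReLU, a 0/1 mask of active elements is
            // stored to the workspace (in training) for the backward pass.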
            for (int c = C_blk_s; c < C_blk_e; c++) {
                size_t off = c + C_off;
                data_t sm = use_scaleshift ? scaleshift[off] : 1;
                data_t sv = use_scaleshift ? scaleshift[C + off] : 0;
                data_t sqrt_variance
                        = static_cast<data_t>(1.0f / sqrtf(variance[off] + eps));
                for (int n = N_s; n < N_e; ++n)
#if SAFE_TO_USE_OMP_SIMD
                    PRAGMA_OMP_SIMD()
#endif
                    for (int sp = S_s; sp < S_e; ++sp) {
                        size_t d_off = off * SP + n * C * SP + sp;
                        data_t bn_res
                                = sm * (src[d_off] - mean[off]) * sqrt_variance
                                + sv;
                        if (fuse_bn_relu) {
                            if (bn_res <= 0) {
                                bn_res = 0;
                                if (is_training)
                                    ws[d_off] = 0;
                            } else {
                                if (is_training)
                                    ws[d_off] = 1;
                            }
                        }
                        dst[d_off] = maybe_post_op(bn_res);
                    }
            }
        }
    });
}

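// Backward pass: with inv_std = 1 / sqrt(variance + eps), per channel
//     diff_gamma = sum((src - mean) * diff_dst) * inv_std
//     diff_beta  = sum(diff_dst)
//     diff_src   = gamma * inv_std * (diff_dst - (diff_beta
//                  + (src - mean) * diff_gamma * inv_std) / (N * SP))
// The diff_beta/diff_gamma correction terms are applied only when this
// primitive computed the statistics itself (i.e. no global stats); with
// fused ReLU, diff_dst is masked by the forward workspace first.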
void ncsp_batch_normalization_bwd_t::execute_backward() const {
    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto mean = reinterpret_cast<const data_t *>(this->input_memory(1));
    auto variance = reinterpret_cast<const data_t *>(this->input_memory(2));
    auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(3));
    auto scaleshift = reinterpret_cast<const data_t *>(this->input_memory(4));
    auto diff_src = reinterpret_cast<data_t *>(this->memory(0));

    auto scratchpad = this->scratchpad();

    auto diff_scaleshift = this->memory(1)
        ? reinterpret_cast<data_t *>(this->memory(1))
        : scratchpad.get<data_t>(key_bnorm_tmp_diff_ss);
    auto ws = reinterpret_cast<const uint8_t *>(
            this->input_memory(pd()->ws_idx()));
    auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction);

    const bool has_spatial = utils::one_of(pd()->ndims(), 4, 5);
    int SP = (has_spatial) ? pd()->H() * pd()->W() * pd()->D() : 1;
    size_t C = pd()->C(), N = pd()->MB();
    const bool use_scaleshift = pd()->use_scaleshift();
    const float eps = pd()->desc()->batch_norm_epsilon;
    const bool calculate_diff_stats = !pd()->use_global_stats();
    const bool fuse_bn_relu = pd()->fuse_bn_relu();

    int nthr = mkldnn_get_max_threads();
    size_t l3_size_ = get_cache_size(3, true) * nthr / 2;
    size_t data_size = N * C * SP * sizeof(data_t);
    bool do_blocking = (data_size >= l3_size_ / 2 && l3_size_ > 0);

    parallel(0, [&](const int ithr, const int nthr) {
        int C_blks_per_iter = 1, iters = 1;
        int C_ithr = 0, C_nthr = 0, N_ithr = 0, N_nthr = 0, N_s = 0, N_e = 0;
        int S_ithr = 0, S_nthr = 0, S_s = 0, S_e = 0;
        int C_blk_gl_s = 0, C_blk_gl_e = 0, C_blk_s = 0, C_blk_e = 0;
        if (do_blocking) {
            size_t working_set_size = 2 * N * SP * sizeof(data_t);
            bnorm_utils::cache_balance(
                    working_set_size, C, C_blks_per_iter, iters);
        } else
            C_blks_per_iter = C;
        int last_iter_blks = C - (iters - 1) * C_blks_per_iter;
        bool spatial_thr_allowed
                = bnorm_utils::thread_balance(do_blocking, true, ithr, nthr, N,
                        C_blks_per_iter, SP, C_ithr, C_nthr, C_blk_s, C_blk_e,
                        N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s, S_e);
        balance211(C_blks_per_iter, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
        int SP_N_ithr = N_ithr * S_nthr + S_ithr;
        int SP_N_nthr = N_nthr * S_nthr;

        for (int it = 0; it < iters; ++it) {
            if (it == iters - 1 && iters > 1) {
                // On the last iteration the access pattern to ws_reduce
                // might change (due to re-balance on C). So sync the
                // threads if they are not synced by the algorithm.
                if (SP_N_nthr == 1 && mkldnn_thr_syncable())
                    mkldnn_thr_barrier();

                C_blk_s = C_blk_e = N_s = N_e = 0;
                spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking,
                        spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP,
                        C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s,
                        N_e, S_ithr, S_nthr, S_s, S_e);
                balance211(last_iter_blks, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
                SP_N_ithr = N_ithr * S_nthr + S_ithr;
                SP_N_nthr = N_nthr * S_nthr;
            }
            size_t C_off = it * C_blks_per_iter;
            // On the last iteration the access pattern to ws_reduce
            // might change (due to re-balance on C). Since sync is not always
            // possible (in case of TBB) use different parts of ws for each
            // iteration if threads are not synced by the algorithm.
            size_t ws_iter_off = (mkldnn_thr_syncable() ? 0 : 1) * 2 * C_off;

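            // ws_reduce layout for the backward reduction: the first
            // SP_N_nthr * C_blks_per_iter entries hold the per-thread
            // diff_gamma partial sums, the next SP_N_nthr * C_blks_per_iter
            // entries hold the diff_beta partial sums.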
            data_t *diff_gamma_blk = diff_scaleshift + C_off;
            data_t *diff_beta_blk = diff_scaleshift + C + C_off;
            for (int c = C_blk_s; c < C_blk_e; c++) {
                size_t off = c + C_off;
                data_t diff_gamma = 0.0, diff_beta = 0.0;
                data_t v_mean = mean[off];
                for (int n = N_s; n < N_e; ++n)
                    PRAGMA_OMP_SIMD(reduction(+ : diff_gamma, diff_beta))
                    for (int sp = S_s; sp < S_e; ++sp) {
                        const size_t d_off = off * SP + n * C * SP + sp;
                        data_t dd;
                        if (fuse_bn_relu)
                            dd = (!ws[d_off]) ? 0 : diff_dst[d_off];
                        else
                            dd = diff_dst[d_off];
                        diff_gamma += (src[d_off] - v_mean) * dd;
                        diff_beta += dd;
                    }
                ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c]
                        = diff_gamma;
                ws_reduce[ws_iter_off + SP_N_nthr * C_blks_per_iter
                        + SP_N_ithr * C_blks_per_iter + c] = diff_beta;
            }

            if (SP_N_nthr > 1) mkldnn_thr_barrier();

            for (int c = C_blk_gl_s; c < C_blk_gl_e; c++) {
                data_t sqrt_variance = static_cast<data_t>(
                        1.0f / sqrtf(variance[c + C_off] + eps));
                diff_gamma_blk[c] = 0.;
                diff_beta_blk[c] = 0.;
                for (int n = 0; n < SP_N_nthr; n++) {
                    diff_gamma_blk[c] += ws_reduce[ws_iter_off
                            + n * C_blks_per_iter + c];
                    diff_beta_blk[c] += ws_reduce[ws_iter_off
                            + SP_N_nthr * C_blks_per_iter + n * C_blks_per_iter
                            + c];
                }
                diff_gamma_blk[c] *= sqrt_variance;
            }

            if (SP_N_nthr > 1) mkldnn_thr_barrier();

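            // Final pass: compute diff_src. As in the forward pass,
            // sqrt_variance holds the reciprocal 1 / sqrt(variance + eps);
            // the correction terms are skipped when global statistics are
            // used (calculate_diff_stats == false).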
            for (int c = C_blk_s; c < C_blk_e; c++) {
                size_t off = c + C_off;
                data_t gamma = use_scaleshift ? scaleshift[off] : 1;
                data_t sqrt_variance
                        = static_cast<data_t>(1.0f / sqrtf(variance[off] + eps));
                data_t v_mean = mean[off];
                for (int n = N_s; n < N_e; ++n)
#if SAFE_TO_USE_OMP_SIMD
                    PRAGMA_OMP_SIMD()
#endif
                    for (int sp = S_s; sp < S_e; ++sp) {
                        const size_t d_off = off * SP + n * C * SP + sp;

                        data_t v_diff_src;
                        if (fuse_bn_relu)
                            v_diff_src = (!ws[d_off]) ? 0 : diff_dst[d_off];
                        else
                            v_diff_src = diff_dst[d_off];
                        if (calculate_diff_stats) {
                            v_diff_src -= diff_beta_blk[c] / (SP * N)
                                    + (src[d_off] - v_mean) * diff_gamma_blk[c]
                                            * sqrt_variance / (SP * N);
                        }
                        v_diff_src *= gamma * sqrt_variance;
                        diff_src[d_off] = v_diff_src;
                    }
            }
        }
    });
}

} // namespace cpu
} // namespace impl
} // namespace mkldnn

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s