Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_batch_normalization.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18 #include <math.h>
19
20 #include "c_types_map.hpp"
21 #include "type_helpers.hpp"
22 #include "mkldnn_thread.hpp"
23
24 #include "ref_batch_normalization.hpp"
25
26 namespace mkldnn {
27 namespace impl {
28 namespace cpu {
29
30 template <impl::data_type_t data_type>
31 void ref_batch_normalization_fwd_t<data_type>::execute_forward() const {
32     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
33     /* FIXME: check this */
34     data_t* mean = pd()->stats_is_src() ?
35         const_cast<data_t*>(reinterpret_cast<const data_t*>(
36                this->input_memory(1))) :
37         reinterpret_cast<data_t*>(this->memory(1));
38
39     data_t* variance = pd()->stats_is_src() ?
40         const_cast<data_t*>(reinterpret_cast<const data_t*>(
41                 this->input_memory(2))) :
42         reinterpret_cast<data_t*>(this->memory(2));
43
44     auto idx_scaleshift = 1 + 2*pd()->stats_is_src();
45     auto scaleshift =
46         reinterpret_cast<const data_t *>(this->input_memory(idx_scaleshift));
47
48     auto dst = reinterpret_cast<data_t*>(this->memory(0));
49     auto ws = reinterpret_cast<uint8_t *>(this->memory(pd()->ws_idx()));
50
51     /* fast return */
52     if (this->pd()->has_zero_dim_memory()) return;
53
54     const memory_desc_wrapper data_d(pd()->src_pd());
55     const memory_desc_wrapper scaleshift_d(pd()->weights_pd());
56
57     const int N = pd()->MB();
58     const int C = pd()->C();
59     int H = 1, W = 1, D = 1;
60     const bool has_spatial = utils::one_of(data_d.ndims(), 4 ,5);
61     if (has_spatial)
62     {
63         D = pd()->D();
64         H = pd()->H();
65         W = pd()->W();
66     }
67
68     const float eps = pd()->desc()->batch_norm_epsilon;
69     const bool use_scaleshift = pd()->use_scaleshift();;
70     const bool save_stats = pd()->is_training();
71     const bool is_training = pd()->is_training();
72     const bool fuse_bn_relu = pd()->fuse_bn_relu();
73     const bool calculate_stats = !pd()->stats_is_src();
74
75     const bool with_relu = pd()->with_relu_post_op();
76     auto maybe_post_op = [&](data_t res) {
77         return (with_relu && res < 0) ? 0 : res;
78     };
79     const bool is_3d = data_d.ndims() == 5;
80
81     auto data_offset = [&] (const memory_desc_wrapper &data_d, int n, int c, int d,
82             int h, int w) {
83         if (has_spatial)
84         {
85             if (is_3d) return data_d.off(n, c, d, h, w);
86             else return data_d.off(n, c, h, w);
87         }
88         else return data_d.off(n, c);
89     };
90
91     parallel_nd(C, [&](int c) {
92         data_t v_mean = calculate_stats ? 0 : mean[c];
93         data_t v_variance = calculate_stats ? 0 : variance[c];
94
95         data_t sm = use_scaleshift ? scaleshift[scaleshift_d.off(0, c)] : 1;
96         data_t sv = use_scaleshift ? scaleshift[scaleshift_d.off(1, c)] : 0;
97         if (calculate_stats) {
98             for (int n = 0; n < N; ++n)
99             for (int d = 0; d < D; ++d)
100             for (int h = 0; h < H; ++h)
101             for (int w = 0; w < W; ++w)
102                 v_mean += src[data_offset(data_d, n, c, d, h, w)];
103             v_mean /= W*N*H*D;
104
105             for (int n = 0; n < N; ++n)
106             for (int d = 0; d < D; ++d)
107             for (int h = 0; h < H; ++h)
108             for (int w = 0; w < W; ++w) {
109                 data_t m = src[data_offset(data_d,n,c,d,h,w)] - v_mean;
110                 v_variance += m*m;
111             }
112             v_variance /= W*H*N*D;
113         }
114
115         data_t sqrt_variance =
116             static_cast<data_t>(1.0f / sqrtf(v_variance + eps));
117
118         for (int n = 0; n < N; ++n)
119         for (int d = 0; d < D; ++d)
120         for (int h = 0; h < H; ++h)
121         for (int w = 0; w < W; ++w) {
122             auto d_off = data_offset(data_d,n,c,d,h,w);
123             data_t bn_res = sm * (src[d_off] - v_mean) * sqrt_variance + sv;
124             if (fuse_bn_relu) {
125                 if (bn_res <= 0) {
126                     bn_res = 0;
127                     if (is_training)
128                         ws[d_off] = 0;
129                 } else {
130                     if (is_training)
131                         ws[d_off] = 1;
132                 }
133             }
134             dst[d_off] = maybe_post_op(bn_res);
135         }
136
137         if (calculate_stats) {
138             if (save_stats) {
139                 mean[c] = v_mean;
140                 variance[c] = v_variance;
141             }
142         }
143     });
144 }
145
146 template struct ref_batch_normalization_fwd_t<data_type::f32>;
147
148 template <impl::data_type_t data_type>
149 void ref_batch_normalization_bwd_t<data_type>::execute_backward() const {
150     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
151     auto mean = reinterpret_cast<const data_t *>(this->input_memory(1));
152     auto variance = reinterpret_cast<const data_t *>(this->input_memory(2));
153     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(3));
154     auto scaleshift = reinterpret_cast<const data_t *>(this->input_memory(4));
155     auto ws = reinterpret_cast<const uint8_t *>(
156             this->input_memory(pd()->ws_idx()));
157
158     auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
159     auto diff_scaleshift = reinterpret_cast<data_t *>(this->memory(1));
160
161     const memory_desc_wrapper data_d(pd()->src_pd());
162     const memory_desc_wrapper diff_data_d(pd()->diff_src_pd());
163     const memory_desc_wrapper scaleshift_d(pd()->weights_pd());
164     const memory_desc_wrapper diff_scaleshift_d(pd()->diff_weights_pd());
165     const memory_desc_wrapper mean_d(pd()->mean_pd());
166     const memory_desc_wrapper variance_d(pd()->variance_pd());
167
168     const int C = pd()->C();
169
170     /* fast return */
171     if (this->pd()->has_zero_dim_memory()) {
172         if (diff_scaleshift) {
173             for (int c = 0; c < C; ++c) {
174                 diff_scaleshift[diff_scaleshift_d.off(0, c)] = 0;
175                 diff_scaleshift[diff_scaleshift_d.off(1, c)] = 0;
176             }
177         }
178         return;
179     }
180
181     const int N = pd()->MB();
182     int H = 1, W = 1, D = 1;
183     const bool has_spatial = utils::one_of(data_d.ndims(), 4 ,5);
184     if (has_spatial)
185     {
186         D = pd()->D();
187         H = pd()->H();
188         W = pd()->W();
189     }
190
191     const float eps = pd()->desc()->batch_norm_epsilon;
192     const bool use_scaleshift = pd()->use_scaleshift();
193     const bool calculate_diff_stats = !pd()->use_global_stats();
194     const bool fuse_bn_relu = pd()->fuse_bn_relu();
195
196     const bool is_3d = data_d.ndims() == 5;
197
198     auto data_offset = [&] (const memory_desc_wrapper &data_d, int n, int c, int d,
199             int h, int w) {
200         if (has_spatial)
201         {
202             if (is_3d) return data_d.off(n, c, d, h, w);
203             else return data_d.off(n, c, h, w);
204         }
205         else return data_d.off(n, c);
206     };
207
208     parallel_nd(C, [&](int c) {
209         data_t v_mean = mean[mean_d.off(c)];
210         data_t v_variance = variance[variance_d.off(c)];
211         data_t sqrt_variance = static_cast<data_t>(1.0f / sqrtf(v_variance + eps));
212         data_t gamma = use_scaleshift ? scaleshift[scaleshift_d.off(0, c)] : 1;
213         data_t diff_gamma = data_t(0);
214         data_t diff_beta = data_t(0);
215         diff_gamma = 0.0;
216         diff_beta = 0.0;
217
218         for (int n = 0; n < N; ++n)
219         for (int d = 0; d < D; ++d)
220         for (int h = 0; h < H; ++h)
221         for (int w = 0; w < W; ++w) {
222             const size_t s_off = data_offset(data_d, n, c, d, h, w);
223             data_t dd = diff_dst[data_offset(diff_data_d, n, c, d, h, w)];
224             if (fuse_bn_relu && !ws[s_off])
225                 dd = 0;
226
227             diff_gamma += (src[s_off] - v_mean) * dd;
228             diff_beta += dd;
229         }
230         diff_gamma *= sqrt_variance;
231
232         if (diff_scaleshift) {
233             diff_scaleshift[diff_scaleshift_d.off(0, c)] = diff_gamma;
234             diff_scaleshift[diff_scaleshift_d.off(1, c)] = diff_beta;
235         }
236
237         for (int n = 0; n < N; ++n)
238         for (int d = 0; d < D; ++d)
239         for (int h = 0; h < H; ++h)
240         for (int w = 0; w < W; ++w) {
241             const size_t s_off = data_offset(data_d, n, c, d, h, w);
242             const size_t dd_off = data_offset(diff_data_d, n, c, d, h, w);
243             data_t dd = diff_dst[dd_off];
244             if (fuse_bn_relu && !ws[s_off])
245                 dd = 0;
246
247             data_t v_diff_src = dd;
248             if (calculate_diff_stats) {
249                 v_diff_src -= diff_beta/(D*W*H*N) +
250                     (src[s_off] - v_mean) *
251                     diff_gamma*sqrt_variance/(D*W*H*N);
252             }
253             v_diff_src *= gamma*sqrt_variance;
254             diff_src[dd_off] = v_diff_src;
255         }
256     });
257 }
258
259 template struct ref_batch_normalization_bwd_t<data_type::f32>;
260
261 }
262 }
263 }
264
265 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s