updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_batch_normalization.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18 #include <math.h>
19
20 #include "c_types_map.hpp"
21 #include "type_helpers.hpp"
22 #include "mkldnn_thread.hpp"
23 #include "simple_q10n.hpp"
24
25 #include "bfloat16_utils.hpp"
26 #include "ref_batch_normalization.hpp"
27
28 #define DECLARE_DATA_OFFSET \
29     auto data_offset = [&](const memory_desc_wrapper &data_d, int n, int c, \
30             int d, int h, int w) { \
31         if (has_spatial) { \
32             if (is_3d) \
33                 return data_d.off(n, c, d, h, w); \
34             else \
35                 return data_d.off(n, c, h, w); \
36         } else { \
37             return data_d.off(n, c); \
38         } \
39     }
40
41 namespace mkldnn {
42 namespace impl {
43 namespace cpu {
44
45 namespace {
46
47 typedef float acc_data_t;
48
49 template <typename T>
50 inline float maybe_up_convert(T x) {
51     return x;
52 }
53
54 template <>
55 inline float maybe_up_convert<mkldnn_bfloat16_t>(mkldnn_bfloat16_t x) {
56     return bf16_cvt_utils::cvt_bfloat16_to_float(x);
57 }
58
59 }
60
61 template <data_type_t data_type>
62 void ref_batch_normalization_fwd_t<data_type>::execute_forward()
63 const {
64     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
65     /* FIXME: check this */
66     acc_data_t *mean = pd()->stats_is_src() ?
67         const_cast<acc_data_t *>(reinterpret_cast<const acc_data_t *>(
68                this->input_memory(1))) :
69         reinterpret_cast<acc_data_t *>(this->memory(1));
70
71     acc_data_t *variance = pd()->stats_is_src() ?
72         const_cast<acc_data_t *>(reinterpret_cast<const acc_data_t *>(
73                 this->input_memory(2))) :
74         reinterpret_cast<acc_data_t *>(this->memory(2));
75
76     auto idx_scaleshift = 1 + 2 * pd()->stats_is_src();
77     auto scaleshift =
78         reinterpret_cast<const acc_data_t *>(this->input_memory(idx_scaleshift));
79
80     auto dst = reinterpret_cast<data_t *>(this->memory(0));
81     auto ws = reinterpret_cast<uint8_t *>(this->memory(pd()->ws_idx()));
82
83     /* fast return */
84     if (this->pd()->has_zero_dim_memory()) return;
85
86     const memory_desc_wrapper data_d(pd()->src_pd());
87     const memory_desc_wrapper scaleshift_d(pd()->weights_pd());
88
89     const dim_t N = pd()->MB();
90     const dim_t C = pd()->C();
91     dim_t H = 1, W = 1, D = 1;
92     const bool has_spatial = utils::one_of(data_d.ndims(), 4, 5);
93     if (has_spatial) {
94         D = pd()->D();
95         H = pd()->H();
96         W = pd()->W();
97     }
98
99     const float eps = pd()->desc()->batch_norm_epsilon;
100     const bool use_scaleshift = pd()->use_scaleshift();;
101     const bool save_stats = pd()->is_training();
102     const bool is_training = pd()->is_training();
103     const bool fuse_bn_relu = pd()->fuse_bn_relu();
104     const bool calculate_stats = !pd()->stats_is_src();
105
106     const bool with_relu = pd()->with_relu_post_op();
107     auto maybe_post_op = [&](acc_data_t res) {
108         return (with_relu && res < 0.0f) ? 0.0f : res;
109     };
110     const bool is_3d = data_d.ndims() == 5;
111
112     //auto data_offset(const memory_desc_wrapper &, int, int, int, int, int)
113     DECLARE_DATA_OFFSET;
114
115     parallel_nd(C, [&](int c) {
116         acc_data_t v_mean = calculate_stats ? 0 : mean[c];
117         acc_data_t v_variance = calculate_stats ? 0 : variance[c];
118
119         if (calculate_stats) {
120             for (int n = 0; n < N; ++n)
121             for (int d = 0; d < D; ++d)
122             for (int h = 0; h < H; ++h)
123             for (int w = 0; w < W; ++w) {
124                 v_mean += maybe_up_convert(src[data_offset(data_d, n, c, d, h, w)]);
125             }
126             v_mean /= W * N * H * D;
127
128             for (int n = 0; n < N; ++n)
129             for (int d = 0; d < D; ++d)
130             for (int h = 0; h < H; ++h)
131             for (int w = 0; w < W; ++w) {
132                 acc_data_t m = maybe_up_convert(src[data_offset(data_d, n, c, d, h, w)]) - v_mean;
133                 v_variance += m * m;
134             }
135             v_variance /= W * H * N * D;
136         }
137
138         acc_data_t sqrt_variance = sqrtf(v_variance + eps);
139         acc_data_t sm = (use_scaleshift
140             ? scaleshift[scaleshift_d.off(0, c)]
141             : 1.0f) / sqrt_variance;
142         acc_data_t sv = use_scaleshift ? scaleshift[scaleshift_d.off(1, c)] : 0;
143
144         for (dim_t n = 0; n < N; ++n)
145         for (dim_t d = 0; d < D; ++d)
146         for (dim_t h = 0; h < H; ++h)
147         for (dim_t w = 0; w < W; ++w) {
148             auto d_off = data_offset(data_d, n, c, d, h, w);
149             acc_data_t bn_res = sm * (maybe_up_convert(src[d_off]) - v_mean) + sv;
150             if (fuse_bn_relu) {
151                 if (bn_res <= 0) {
152                     bn_res = 0;
153                     if (is_training)
154                         ws[d_off] = 0;
155                 } else {
156                     if (is_training)
157                         ws[d_off] = 1;
158                 }
159             }
160             if (data_type == data_type::s8) {
161                 dst[d_off] = qz_a1b0<float, data_t>()(
162                         maybe_post_op(bn_res), round_mode::nearest);
163             } else if (data_type == data_type::bf16) {
164                 const float bn_res_p = maybe_post_op(bn_res);
165                 bf16_cvt_utils::cvt_float_to_bfloat16(
166                         (mkldnn_bfloat16_t *)&dst[d_off], &bn_res_p);
167             } else {
168                 dst[d_off] = static_cast<data_t>(maybe_post_op(bn_res));
169             }
170         }
171
172         if (calculate_stats) {
173             if (save_stats) {
174                 mean[c] = v_mean;
175                 variance[c] = v_variance;
176             }
177         }
178     });
179 }
180
181 template struct ref_batch_normalization_fwd_t<data_type::s8>;
182 template struct ref_batch_normalization_fwd_t<data_type::f32>;
183 template struct ref_batch_normalization_fwd_t<data_type::bf16>;
184
185 template <data_type_t data_type>
186 void ref_batch_normalization_bwd_t<data_type>::execute_backward()
187 const {
188     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
189     auto mean = reinterpret_cast<const acc_data_t *>(this->input_memory(1));
190     auto variance = reinterpret_cast<const acc_data_t *>(this->input_memory(2));
191     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(3));
192     auto scaleshift = reinterpret_cast<const acc_data_t *>(this->input_memory(4));
193     auto ws = reinterpret_cast<const uint8_t *>(
194             this->input_memory(pd()->ws_idx()));
195
196     auto diff_src = reinterpret_cast<data_t *>(this->memory(0));
197     auto diff_scaleshift = reinterpret_cast<acc_data_t *>(this->memory(1));
198
199     const memory_desc_wrapper data_d(pd()->src_pd());
200     const memory_desc_wrapper diff_data_d(pd()->diff_src_pd());
201     const memory_desc_wrapper scaleshift_d(pd()->weights_pd());
202     const memory_desc_wrapper diff_scaleshift_d(pd()->diff_weights_pd());
203     const memory_desc_wrapper mean_d(pd()->mean_pd());
204     const memory_desc_wrapper variance_d(pd()->variance_pd());
205
206     const dim_t C = pd()->C();
207
208     /* fast return */
209     if (this->pd()->has_zero_dim_memory()) {
210         if (diff_scaleshift) {
211             for (dim_t c = 0; c < C; ++c) {
212                 diff_scaleshift[diff_scaleshift_d.off(0, c)] = 0;
213                 diff_scaleshift[diff_scaleshift_d.off(1, c)] = 0;
214             }
215         }
216         return;
217     }
218
219     const dim_t N = pd()->MB();
220     dim_t H = 1, W = 1, D = 1;
221     const bool has_spatial = utils::one_of(data_d.ndims(), 4, 5);
222     if (has_spatial) {
223         D = pd()->D();
224         H = pd()->H();
225         W = pd()->W();
226     }
227
228     const float eps = pd()->desc()->batch_norm_epsilon;
229     const bool use_scaleshift = pd()->use_scaleshift();
230     const bool calculate_diff_stats = !pd()->use_global_stats();
231     const bool fuse_bn_relu = pd()->fuse_bn_relu();
232
233     const bool is_3d = data_d.ndims() == 5;
234
235     //auto data_offset(const memory_desc_wrapper &, int, int, int, int, int)
236     DECLARE_DATA_OFFSET;
237
238     parallel_nd(C, [&](int c) {
239         acc_data_t v_mean = mean[mean_d.off(c)];
240         acc_data_t v_variance = variance[variance_d.off(c)];
241         acc_data_t sqrt_variance = static_cast<acc_data_t>(1.0f / sqrtf(v_variance + eps));
242         acc_data_t gamma = use_scaleshift ? scaleshift[scaleshift_d.off(0, c)] : 1;
243         acc_data_t diff_gamma = acc_data_t(0);
244         acc_data_t diff_beta = acc_data_t(0);
245
246         for (dim_t n = 0; n < N; ++n)
247         for (dim_t d = 0; d < D; ++d)
248         for (dim_t h = 0; h < H; ++h)
249         for (dim_t w = 0; w < W; ++w) {
250             const size_t s_off = data_offset(data_d, n, c, d, h, w);
251             acc_data_t dd;
252             if (fuse_bn_relu && !ws[s_off])
253                 dd = 0;
254             else
255                 dd = maybe_up_convert(
256                     diff_dst[data_offset(diff_data_d, n, c, d, h, w)]);
257             diff_gamma += (maybe_up_convert(src[s_off]) - v_mean) * dd;
258             diff_beta += dd;
259         }
260         diff_gamma *= sqrt_variance;
261
262         if (diff_scaleshift) {
263             diff_scaleshift[diff_scaleshift_d.off(0, c)] = diff_gamma;
264             diff_scaleshift[diff_scaleshift_d.off(1, c)] = diff_beta;
265         }
266
267         for (dim_t n = 0; n < N; ++n)
268         for (dim_t d = 0; d < D; ++d)
269         for (dim_t h = 0; h < H; ++h)
270         for (dim_t w = 0; w < W; ++w) {
271             const size_t s_off = data_offset(data_d, n, c, d, h, w);
272             const size_t dd_off = data_offset(diff_data_d, n, c, d, h, w);
273             acc_data_t dd;
274             if (fuse_bn_relu && !ws[s_off])
275                 dd = 0;
276             else
277                 dd = maybe_up_convert(diff_dst[dd_off]);
278             acc_data_t v_diff_src = dd;
279             if (calculate_diff_stats) {
280                 v_diff_src -= diff_beta / (D * W * H * N) +
281                     (maybe_up_convert(src[s_off]) - v_mean) * diff_gamma * sqrt_variance / (D * W * H * N);
282             }
283             v_diff_src *= gamma * sqrt_variance;
284             if (data_type == data_type::bf16) {
285                 bf16_cvt_utils::cvt_float_to_bfloat16(
286                         (mkldnn_bfloat16_t *)&diff_src[dd_off], &v_diff_src);
287             } else {
288                 diff_src[dd_off] = static_cast<data_t>(v_diff_src);
289             }
290         }
291     });
292 }
293
294 template struct ref_batch_normalization_bwd_t<data_type::f32>;
295 template struct ref_batch_normalization_bwd_t<data_type::bf16>;
296
297 }
298 }
299 }
300
301 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s