1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
20 #include "c_types_map.hpp"
21 #include "type_helpers.hpp"
22 #include "mkldnn_thread.hpp"
24 #include "ref_batch_normalization.hpp"
/*
 * Reference (non-optimized) forward batch normalization.
 *
 * For every channel c the visible code:
 *   1. takes the mean/variance either from the inputs (stats_is_src()) or
 *      accumulates them from src over all (n, d, h, w) positions;
 *   2. writes sm * (src - mean) / sqrt(variance + eps) + sv to dst, where
 *      sm/sv come from scaleshift when use_scaleshift() is set (else 1 / 0),
 *      after passing the value through the optional ReLU post-op;
 *   3. stores the computed variance back to variance[c] when the statistics
 *      were calculated here.
 *
 * NOTE(review): the numbering embedded in this excerpt skips values, so some
 * statements (e.g. closing braces, the H/W/D setup, parts of the statistics
 * accumulation, the fuse_bn_relu / ws handling) are not visible here. The
 * comments below describe only what is visible and hedge the rest.
 */
30 template <impl::data_type_t data_type>
31 void ref_batch_normalization_fwd_t<data_type>::execute_forward() const {
/* Input 0 is the source tensor. */
32 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
33 /* FIXME: check this */
/* mean: read from input 1 (constness cast away) when stats_is_src(),
 * otherwise it points at output memory 1. */
34 data_t* mean = pd()->stats_is_src() ?
35 const_cast<data_t*>(reinterpret_cast<const data_t*>(
36 this->input_memory(1))) :
37 reinterpret_cast<data_t*>(this->memory(1));
/* variance: same selection as mean, using input 2 / output memory 2. */
39 data_t* variance = pd()->stats_is_src() ?
40 const_cast<data_t*>(reinterpret_cast<const data_t*>(
41 this->input_memory(2))) :
42 reinterpret_cast<data_t*>(this->memory(2));
/* Scale/shift input index: 1 when the statistics are not inputs, 3 when
 * inputs 1 and 2 are occupied by mean and variance. */
44 auto idx_scaleshift = 1 + 2*pd()->stats_is_src();
/* NOTE(review): the left-hand side of this initializer is not visible in this
 * excerpt; presumably it declares the `scaleshift` pointer used below. */
46 reinterpret_cast<const data_t *>(this->input_memory(idx_scaleshift));
/* Output 0 is the destination; the workspace lives at index ws_idx(). */
48 auto dst = reinterpret_cast<data_t*>(this->memory(0));
49 auto ws = reinterpret_cast<uint8_t *>(this->memory(pd()->ws_idx()));
/* Nothing to compute when any memory has a zero dimension. */
52 if (this->pd()->has_zero_dim_memory()) return;
54 const memory_desc_wrapper data_d(pd()->src_pd());
55 const memory_desc_wrapper scaleshift_d(pd()->weights_pd());
/* Problem sizes. H, W, D start at 1; the code that overrides them for
 * spatial tensors is not visible here (presumably guarded by has_spatial). */
57 const int N = pd()->MB();
58 const int C = pd()->C();
59 int H = 1, W = 1, D = 1;
60 const bool has_spatial = utils::one_of(data_d.ndims(), 4 ,5);
68 const float eps = pd()->desc()->batch_norm_epsilon;
/* NOTE(review): the doubled ';' below is a stray empty statement. */
69 const bool use_scaleshift = pd()->use_scaleshift();;
/* NOTE(review): save_stats and is_training are both initialized from
 * pd()->is_training(). */
70 const bool save_stats = pd()->is_training();
71 const bool is_training = pd()->is_training();
72 const bool fuse_bn_relu = pd()->fuse_bn_relu();
/* Statistics are computed here unless they are supplied as inputs. */
73 const bool calculate_stats = !pd()->stats_is_src();
75 const bool with_relu = pd()->with_relu_post_op();
/* ReLU post-op: negative results become 0 when with_relu is set; otherwise
 * the value is returned unchanged. */
76 auto maybe_post_op = [&](data_t res) {
77 return (with_relu && res < 0) ? 0 : res;
79 const bool is_3d = data_d.ndims() == 5;
/* Maps logical coordinates to an element offset through the descriptor:
 * (n, c, d, h, w) for 5D data, (n, c, h, w) for the other branch of the
 * is_3d test, and (n, c) in the final else. The condition of the outer
 * if is not visible in this excerpt (presumably has_spatial). */
81 auto data_offset = [&] (const memory_desc_wrapper &data_d, int n, int c, int d,
85 if (is_3d) return data_d.off(n, c, d, h, w);
86 else return data_d.off(n, c, h, w);
88 else return data_d.off(n, c);
/* The lambda is invoked through parallel_nd with a channel index c; every
 * visible access below is indexed by that c. */
91 parallel_nd(C, [&](int c) {
/* Start from zero when accumulating, otherwise use the supplied stats. */
92 data_t v_mean = calculate_stats ? 0 : mean[c];
93 data_t v_variance = calculate_stats ? 0 : variance[c];
/* sm = scale (row 0 of scaleshift), sv = shift (row 1); identity values
 * 1 and 0 are used when scale/shift is disabled. */
95 data_t sm = use_scaleshift ? scaleshift[scaleshift_d.off(0, c)] : 1;
96 data_t sv = use_scaleshift ? scaleshift[scaleshift_d.off(1, c)] : 0;
97 if (calculate_stats) {
/* Pass 1: sum src over all positions of channel c into v_mean.
 * NOTE(review): the division by the element count is not visible here. */
98 for (int n = 0; n < N; ++n)
99 for (int d = 0; d < D; ++d)
100 for (int h = 0; h < H; ++h)
101 for (int w = 0; w < W; ++w)
102 v_mean += src[data_offset(data_d, n, c, d, h, w)];
/* Pass 2: m is the deviation of each element from v_mean.
 * NOTE(review): the statement accumulating into v_variance is not visible
 * here (presumably v_variance += m*m). */
105 for (int n = 0; n < N; ++n)
106 for (int d = 0; d < D; ++d)
107 for (int h = 0; h < H; ++h)
108 for (int w = 0; w < W; ++w) {
109 data_t m = src[data_offset(data_d,n,c,d,h,w)] - v_mean;
/* Normalize the accumulated value by the number of elements per channel. */
112 v_variance /= W*H*N*D;
/* Despite its name, sqrt_variance holds the reciprocal
 * 1 / sqrt(v_variance + eps), so it is used as a multiplier below. */
115 data_t sqrt_variance =
116 static_cast<data_t>(1.0f / sqrtf(v_variance + eps));
/* Pass 3: normalize, scale and shift each element of channel c. */
118 for (int n = 0; n < N; ++n)
119 for (int d = 0; d < D; ++d)
120 for (int h = 0; h < H; ++h)
121 for (int w = 0; w < W; ++w) {
122 auto d_off = data_offset(data_d,n,c,d,h,w);
123 data_t bn_res = sm * (src[d_off] - v_mean) * sqrt_variance + sv;
/* NOTE(review): lines between the computation above and the store below are
 * not visible; fuse_bn_relu / ws are declared earlier but their use is not
 * shown here, so bn_res may be modified before being stored. */
134 dst[d_off] = maybe_post_op(bn_res);
/* Write the freshly computed statistics back to the output buffers.
 * NOTE(review): only the variance store is visible; the guard on the hidden
 * line(s) and the mean store are presumed. */
137 if (calculate_stats) {
140 variance[c] = v_variance;
/* Explicit instantiation of the forward reference primitive for f32 data. */
146 template struct ref_batch_normalization_fwd_t<data_type::f32>;
/*
 * Reference (non-optimized) backward batch normalization.
 *
 * For every channel c the visible code:
 *   1. accumulates diff_gamma = sum((src - mean) * diff_dst) scaled by
 *      1 / sqrt(variance + eps) (diff_beta is declared alongside it);
 *   2. stores diff_gamma / diff_beta into diff_scaleshift when that output
 *      pointer is non-null;
 *   3. computes diff_src from diff_dst, subtracting the mean-gradient terms
 *      when calculate_diff_stats is set, and scaling by
 *      gamma / sqrt(variance + eps).
 *
 * NOTE(review): the numbering embedded in this excerpt skips values, so some
 * statements (closing braces, the H/W/D setup, the bodies of the
 * fuse_bn_relu checks, the diff_beta accumulation) are not visible here. The
 * comments below describe only what is visible and hedge the rest.
 */
148 template <impl::data_type_t data_type>
149 void ref_batch_normalization_bwd_t<data_type>::execute_backward() const {
/* Inputs 0..4: src, mean, variance, diff_dst, scaleshift; the workspace is
 * read from input index ws_idx(). */
150 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
151 auto mean = reinterpret_cast<const data_t *>(this->input_memory(1));
152 auto variance = reinterpret_cast<const data_t *>(this->input_memory(2));
153 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(3));
154 auto scaleshift = reinterpret_cast<const data_t *>(this->input_memory(4));
155 auto ws = reinterpret_cast<const uint8_t *>(
156 this->input_memory(pd()->ws_idx()));
/* Outputs: diff_src (memory 0) and diff_scaleshift (memory 1). */
158 auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
159 auto diff_scaleshift = reinterpret_cast<data_t *>(this->memory(1));
/* Descriptor wrappers used to translate logical indices into offsets. */
161 const memory_desc_wrapper data_d(pd()->src_pd());
162 const memory_desc_wrapper diff_data_d(pd()->diff_src_pd());
163 const memory_desc_wrapper scaleshift_d(pd()->weights_pd());
164 const memory_desc_wrapper diff_scaleshift_d(pd()->diff_weights_pd());
165 const memory_desc_wrapper mean_d(pd()->mean_pd());
166 const memory_desc_wrapper variance_d(pd()->variance_pd());
168 const int C = pd()->C();
/* Zero-dimension case: zero both rows of diff_scaleshift for every channel
 * when that output exists.
 * NOTE(review): the early return that presumably follows is not visible. */
171 if (this->pd()->has_zero_dim_memory()) {
172 if (diff_scaleshift) {
173 for (int c = 0; c < C; ++c) {
174 diff_scaleshift[diff_scaleshift_d.off(0, c)] = 0;
175 diff_scaleshift[diff_scaleshift_d.off(1, c)] = 0;
/* Problem sizes. H, W, D start at 1; the code that overrides them for
 * spatial tensors is not visible here (presumably guarded by has_spatial). */
181 const int N = pd()->MB();
182 int H = 1, W = 1, D = 1;
183 const bool has_spatial = utils::one_of(data_d.ndims(), 4 ,5);
191 const float eps = pd()->desc()->batch_norm_epsilon;
192 const bool use_scaleshift = pd()->use_scaleshift();
/* The mean/variance gradient terms are applied only when global statistics
 * are not in use. */
193 const bool calculate_diff_stats = !pd()->use_global_stats();
194 const bool fuse_bn_relu = pd()->fuse_bn_relu();
196 const bool is_3d = data_d.ndims() == 5;
/* Maps logical coordinates to an element offset through the given
 * descriptor: (n, c, d, h, w) for 5D data, (n, c, h, w) for the other branch
 * of the is_3d test, and (n, c) in the final else. The condition of the
 * outer if is not visible in this excerpt (presumably has_spatial). */
198 auto data_offset = [&] (const memory_desc_wrapper &data_d, int n, int c, int d,
202 if (is_3d) return data_d.off(n, c, d, h, w);
203 else return data_d.off(n, c, h, w);
205 else return data_d.off(n, c);
/* The lambda is invoked through parallel_nd with a channel index c; every
 * visible access below is indexed by that c. */
208 parallel_nd(C, [&](int c) {
209 data_t v_mean = mean[mean_d.off(c)];
210 data_t v_variance = variance[variance_d.off(c)];
/* Despite its name, sqrt_variance holds the reciprocal
 * 1 / sqrt(v_variance + eps). */
211 data_t sqrt_variance = static_cast<data_t>(1.0f / sqrtf(v_variance + eps));
/* gamma is the scale (row 0 of scaleshift), or 1 when scale/shift is off. */
212 data_t gamma = use_scaleshift ? scaleshift[scaleshift_d.off(0, c)] : 1;
/* Per-channel gradient accumulators. */
213 data_t diff_gamma = data_t(0);
214 data_t diff_beta = data_t(0);
/* Pass 1: reduce over all positions of channel c. */
218 for (int n = 0; n < N; ++n)
219 for (int d = 0; d < D; ++d)
220 for (int h = 0; h < H; ++h)
221 for (int w = 0; w < W; ++w) {
222 const size_t s_off = data_offset(data_d, n, c, d, h, w);
223 data_t dd = diff_dst[data_offset(diff_data_d, n, c, d, h, w)];
/* The workspace is indexed with the src offset (s_off).
 * NOTE(review): the body of this if is not visible; presumably it zeroes dd
 * where the workspace entry is 0. */
224 if (fuse_bn_relu && !ws[s_off])
227 diff_gamma += (src[s_off] - v_mean) * dd;
/* NOTE(review): the diff_beta accumulation is not visible in this excerpt. */
230 diff_gamma *= sqrt_variance;
/* Store the scale/shift gradients only when the output buffer exists. */
232 if (diff_scaleshift) {
233 diff_scaleshift[diff_scaleshift_d.off(0, c)] = diff_gamma;
234 diff_scaleshift[diff_scaleshift_d.off(1, c)] = diff_beta;
/* Pass 2: compute the gradient with respect to the source. */
237 for (int n = 0; n < N; ++n)
238 for (int d = 0; d < D; ++d)
239 for (int h = 0; h < H; ++h)
240 for (int w = 0; w < W; ++w) {
241 const size_t s_off = data_offset(data_d, n, c, d, h, w);
242 const size_t dd_off = data_offset(diff_data_d, n, c, d, h, w);
243 data_t dd = diff_dst[dd_off];
/* Same workspace check as in pass 1; its body is not visible here. */
244 if (fuse_bn_relu && !ws[s_off])
247 data_t v_diff_src = dd;
/* Subtract the contributions that flow through the batch mean and variance;
 * both terms are divided by the per-channel element count D*W*H*N. */
248 if (calculate_diff_stats) {
249 v_diff_src -= diff_beta/(D*W*H*N) +
250 (src[s_off] - v_mean) *
251 diff_gamma*sqrt_variance/(D*W*H*N);
/* Final scaling by gamma / sqrt(variance + eps); the result is written at
 * the diff tensor offset (dd_off), not the src offset. */
253 v_diff_src *= gamma*sqrt_variance;
254 diff_src[dd_off] = v_diff_src;
/* Explicit instantiation of the backward reference primitive for f32 data. */
259 template struct ref_batch_normalization_bwd_t<data_type::f32>;
265 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s