4 #include "caffe/layers/batch_norm_layer.hpp"
5 #include "caffe/util/math_functions.hpp"
// LayerSetUp: reads the BatchNormParameter from the layer configuration,
// decides whether stored statistics are used, takes the channel count from
// bottom[0], allocates and zeroes three statistic blobs, and pins their
// learning-rate multipliers to zero.
//   bottom: input blobs; only bottom[0] is inspected by the statements shown.
//   top:    output blobs; not referenced by the statements shown.
// NOTE(review): some structural lines of this function (closing braces, an
// else branch, the declaration of sz) are not visible in this view; the
// comments below describe only the statements that are shown.
9 template <typename Dtype>
10 void BatchNormLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
11 const vector<Blob<Dtype>*>& top) {
12 BatchNormParameter param = this->layer_param_.batch_norm_param();
13 moving_average_fraction_ = param.moving_average_fraction();
// Stored statistics are the default in the TEST phase; an explicit
// use_global_stats value in the parameter overrides that default.
14 use_global_stats_ = this->phase_ == TEST;
15 if (param.has_use_global_stats())
16 use_global_stats_ = param.use_global_stats();
// NOTE(review): the statement selected for a 1-axis bottom is not visible
// between the next two lines - verify that shape(1) is only read when
// bottom[0] actually has an axis 1.
17 if (bottom[0]->num_axes() == 1)
20 channels_ = bottom[0]->shape(1);
// When blobs are already present only a log message is emitted; presumably
// the allocation below is the alternative branch - TODO confirm.
22 if (this->blobs_.size() > 0) {
23 LOG(INFO) << "Skipping parameter initialization";
// Three blobs are created. Forward_cpu reads blobs_[0] and blobs_[1] as the
// accumulated mean and variance and blobs_[2] element 0 as their normalizer.
25 this->blobs_.resize(3);
27 sz.push_back(channels_);
28 this->blobs_[0].reset(new Blob<Dtype>(sz));
29 this->blobs_[1].reset(new Blob<Dtype>(sz));
// NOTE(review): as shown, blobs_[2] is built from the same sz as the other
// two; any adjustment of sz before this line is not visible here.
31 this->blobs_[2].reset(new Blob<Dtype>(sz));
// Fill every element of the three new blobs with zero.
32 for (int i = 0; i < 3; ++i) {
33 caffe_set(this->blobs_[i]->count(), Dtype(0),
34 this->blobs_[i]->mutable_cpu_data());
37 // Mask statistics from optimization by setting local learning rates
38 // for mean, variance, and the bias correction to zero.
39 for (int i = 0; i < this->blobs_.size(); ++i) {
// If the layer configuration has exactly i ParamSpecs, append one for blob i
// with lr_mult set to zero.
40 if (this->layer_param_.param_size() == i) {
41 ParamSpec* fixed_param_spec = this->layer_param_.add_param();
42 fixed_param_spec->set_lr_mult(0.f);
// A ParamSpec that already exists for blob i is required to have lr_mult
// equal to zero; the check fails with the message below otherwise.
44 CHECK_EQ(this->layer_param_.param(i).lr_mult(), 0.f)
45 << "Cannot configure batch normalization statistics as layer "
// Reshape: sizes the output blob and the internal work buffers to match
// bottom[0], and (re)fills the all-ones multiplier vectors used for the
// sum and broadcast products in Forward_cpu and Backward_cpu.
//   bottom: input blobs; only bottom[0] is inspected.
//   top:    output blobs; top[0] is reshaped like bottom[0].
// NOTE(review): some structural lines of this function (closing braces, the
// declaration of sz, some updates of sz[0]) are not visible in this view;
// the comments below describe only the statements that are shown.
51 template <typename Dtype>
52 void BatchNormLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
53 const vector<Blob<Dtype>*>& top) {
// The channel count must still equal the value captured in channels_.
// NOTE(review): the guard admits a 1-axis bottom (num_axes() >= 1) yet the
// check reads shape(1), whereas LayerSetUp tests num_axes() == 1 separately.
// Confirm whether the guard should be num_axes() > 1.
54 if (bottom[0]->num_axes() >= 1)
55 CHECK_EQ(bottom[0]->shape(1), channels_);
56 top[0]->ReshapeLike(*bottom[0]);
// sz starts with channels_ as an entry and sizes variance_; temp_ and
// x_norm_ take the full shape of bottom[0].
59 sz.push_back(channels_);
61 variance_.Reshape(sz);
62 temp_.ReshapeLike(*bottom[0]);
63 x_norm_.ReshapeLike(*bottom[0]);
// batch_sum_multiplier_ gets one entry per item along axis 0 of bottom[0].
64 sz[0] = bottom[0]->shape(0);
65 batch_sum_multiplier_.Reshape(sz);
// spatial_dim is the number of elements per (axis-0 item, channel) pair.
67 int spatial_dim = bottom[0]->count()/(channels_*bottom[0]->shape(0));
// Rebuild the spatial multiplier only when it is unset or its length differs
// from spatial_dim, then fill it with ones.
68 if (spatial_sum_multiplier_.num_axes() == 0 ||
69 spatial_sum_multiplier_.shape(0) != spatial_dim) {
71 spatial_sum_multiplier_.Reshape(sz);
72 Dtype* multiplier_data = spatial_sum_multiplier_.mutable_cpu_data();
73 caffe_set(spatial_sum_multiplier_.count(), Dtype(1), multiplier_data);
// num_by_chans_ is the scratch buffer for per-(item, channel) values; it is
// compared against channels_ * shape(0).
76 int numbychans = channels_*bottom[0]->shape(0);
77 if (num_by_chans_.num_axes() == 0 ||
78 num_by_chans_.shape(0) != numbychans) {
80 num_by_chans_.Reshape(sz);
// Fill batch_sum_multiplier_ with ones.
81 caffe_set(batch_sum_multiplier_.count(), Dtype(1),
82 batch_sum_multiplier_.mutable_cpu_data());
// Forward_cpu: normalizes bottom[0] per channel into top[0]. The mean is
// subtracted, the result is divided by sqrt(variance + eps_), and the
// normalized output is cached in x_norm_ for the backward pass.
// With use_global_stats_ the mean and variance come from the stored blobs;
// otherwise they are computed from the current input and folded into the
// stored accumulators.
//   bottom: input blobs; only bottom[0] is read.
//   top:    output blobs; top[0] receives the result and may be the same
//           blob as bottom[0].
// NOTE(review): some structural lines of this function (closing braces and
// else branches) are not visible in this view; the comments below describe
// only the statements that are shown. The gemv/gemm descriptions assume
// standard BLAS semantics for the caffe_cpu_* wrappers - TODO confirm.
86 template <typename Dtype>
87 void BatchNormLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
88 const vector<Blob<Dtype>*>& top) {
89 const Dtype* bottom_data = bottom[0]->cpu_data();
90 Dtype* top_data = top[0]->mutable_cpu_data();
91 int num = bottom[0]->shape(0);
92 int spatial_dim = bottom[0]->count()/(bottom[0]->shape(0)*channels_);
// When not operating in place, start from a copy of the input in top_data;
// all later steps modify top_data.
94 if (bottom[0] != top[0]) {
95 caffe_copy(bottom[0]->count(), bottom_data, top_data);
98 if (use_global_stats_) {
99 // use the stored mean/variance estimates.
// blobs_[2] element 0 is the accumulated normalizer; a zero value yields a
// scale factor of 0 rather than a division by zero.
100 const Dtype scale_factor = this->blobs_[2]->cpu_data()[0] == 0 ?
101 0 : 1 / this->blobs_[2]->cpu_data()[0];
102 caffe_cpu_scale(variance_.count(), scale_factor,
103 this->blobs_[0]->cpu_data(), mean_.mutable_cpu_data());
104 caffe_cpu_scale(variance_.count(), scale_factor,
105 this->blobs_[1]->cpu_data(), variance_.mutable_cpu_data());
// Per-channel mean of the input: sum over spatial_dim scaled by
// 1 / (num * spatial_dim) into num_by_chans_, then sum over num into mean_.
// Presumably this is the branch taken without stored statistics - TODO confirm.
108 caffe_cpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim,
109 1. / (num * spatial_dim), bottom_data,
110 spatial_sum_multiplier_.cpu_data(), 0.,
111 num_by_chans_.mutable_cpu_data());
112 caffe_cpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
113 num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0.,
114 mean_.mutable_cpu_data());
// Subtract the mean: replicate mean_ over num into num_by_chans_, then add
// -1 times its replication over spatial_dim to top_data (alpha -1, beta 1).
118 caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
119 batch_sum_multiplier_.cpu_data(), mean_.cpu_data(), 0.,
120 num_by_chans_.mutable_cpu_data());
121 caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_ * num,
122 spatial_dim, 1, -1, num_by_chans_.cpu_data(),
123 spatial_sum_multiplier_.cpu_data(), 1., top_data);
125 if (!use_global_stats_) {
126 // compute variance using var(X) = E((X-EX)^2)
127 caffe_powx(top[0]->count(), top_data, Dtype(2),
128 temp_.mutable_cpu_data()); // (X-EX)^2
// Same two-stage reduction as for the mean, applied to the squared values.
129 caffe_cpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim,
130 1. / (num * spatial_dim), temp_.cpu_data(),
131 spatial_sum_multiplier_.cpu_data(), 0.,
132 num_by_chans_.mutable_cpu_data());
133 caffe_cpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
134 num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0.,
135 variance_.mutable_cpu_data()); // E((X_EX)^2)
137 // compute and save moving average
// The normalizer in blobs_[2] is decayed by moving_average_fraction_ and
// incremented by one; the mean accumulator in blobs_[0] is updated with the
// same decay (presumably Y = 1 * mean_ + fraction * Y - TODO confirm axpby).
138 this->blobs_[2]->mutable_cpu_data()[0] *= moving_average_fraction_;
139 this->blobs_[2]->mutable_cpu_data()[0] += 1;
140 caffe_cpu_axpby(mean_.count(), Dtype(1), mean_.cpu_data(),
141 moving_average_fraction_, this->blobs_[0]->mutable_cpu_data());
// m is the number of values per channel; the variance is scaled by m/(m-1)
// before accumulation, or left unscaled when m is 1 or less.
142 int m = bottom[0]->count()/channels_;
143 Dtype bias_correction_factor = m > 1 ? Dtype(m)/(m-1) : 1;
144 caffe_cpu_axpby(variance_.count(), bias_correction_factor,
145 variance_.cpu_data(), moving_average_fraction_,
146 this->blobs_[1]->mutable_cpu_data());
149 // normalize variance
// variance_ becomes (variance + eps_) raised to the power 0.5.
150 caffe_add_scalar(variance_.count(), eps_, variance_.mutable_cpu_data());
151 caffe_powx(variance_.count(), variance_.cpu_data(), Dtype(0.5),
152 variance_.mutable_cpu_data());
154 // replicate variance to input size
155 caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
156 batch_sum_multiplier_.cpu_data(), variance_.cpu_data(), 0.,
157 num_by_chans_.mutable_cpu_data());
158 caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_ * num,
159 spatial_dim, 1, 1., num_by_chans_.cpu_data(),
160 spatial_sum_multiplier_.cpu_data(), 0., temp_.mutable_cpu_data());
// Elementwise division of the centered data by the replicated values; temp_
// keeps the divisor and is read again by Backward_cpu.
161 caffe_div(temp_.count(), top_data, temp_.cpu_data(), top_data);
162 // TODO(cdoersch): The caching is only needed because later in-place layers
163 // might clobber the data. Can we skip this if they won't?
164 caffe_copy(x_norm_.count(), top_data,
165 x_norm_.mutable_cpu_data());
// Backward_cpu: writes the gradient with respect to the input into the diff
// of bottom[0], using the top diff, the normalized output cached in x_norm_
// and the divisor cached in temp_ by Forward_cpu.
//   top:            output blobs; the diff of top[0] is read.
//   propagate_down: not referenced by the statements shown.
//   bottom:         input blobs; the diff of bottom[0] is written.
// NOTE(review): some structural lines of this function (closing braces, an
// else branch and any early exit) are not visible in this view; the comments
// below describe only the statements that are shown. The gemv/gemm
// descriptions assume standard BLAS semantics for the caffe_cpu_* wrappers -
// TODO confirm.
168 template <typename Dtype>
169 void BatchNormLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
170 const vector<bool>& propagate_down,
171 const vector<Blob<Dtype>*>& bottom) {
172 const Dtype* top_diff;
// Not in place: read the top diff directly.
173 if (bottom[0] != top[0]) {
174 top_diff = top[0]->cpu_diff();
// Presumably the in-place case: the top diff is first copied into the diff
// of x_norm_ so that writing bottom_diff cannot overwrite it - TODO confirm.
176 caffe_copy(x_norm_.count(), top[0]->cpu_diff(), x_norm_.mutable_cpu_diff());
177 top_diff = x_norm_.cpu_diff();
179 Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
// With stored statistics the gradient is the top diff divided elementwise
// by temp_.
180 if (use_global_stats_) {
181 caffe_div(temp_.count(), top_diff, temp_.cpu_data(), bottom_diff);
// top_data here is the cached normalized output, i.e. Y in the formula below.
184 const Dtype* top_data = x_norm_.cpu_data();
185 int num = bottom[0]->shape()[0];
186 int spatial_dim = bottom[0]->count()/(bottom[0]->shape(0)*channels_);
187 // if Y = (X-mean(X))/(sqrt(var(X)+eps)), then
190 // (dE/dY - mean(dE/dY) - mean(dE/dY \cdot Y) \cdot Y)
191 // ./ sqrt(var(X) + eps)
193 // where \cdot and ./ are hadamard product and elementwise division,
194 // respectively, dE/dY is the top diff, and mean/var/sum are all computed
195 // along all dimensions except the channels dimension. In the above
196 // equation, the operations allow for expansion (i.e. broadcast) along all
197 // dimensions except the channels dimension where required.
199 // sum(dE/dY \cdot Y)
// bottom_diff is used as scratch for the elementwise product, which is then
// reduced over spatial_dim and over num into mean_.
200 caffe_mul(temp_.count(), top_data, top_diff, bottom_diff);
201 caffe_cpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim, 1.,
202 bottom_diff, spatial_sum_multiplier_.cpu_data(), 0.,
203 num_by_chans_.mutable_cpu_data());
204 caffe_cpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
205 num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0.,
206 mean_.mutable_cpu_data());
208 // reshape (broadcast) the above
209 caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
210 batch_sum_multiplier_.cpu_data(), mean_.cpu_data(), 0.,
211 num_by_chans_.mutable_cpu_data());
212 caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_ * num,
213 spatial_dim, 1, 1., num_by_chans_.cpu_data(),
214 spatial_sum_multiplier_.cpu_data(), 0., bottom_diff);
216 // sum(dE/dY \cdot Y) \cdot Y
217 caffe_mul(temp_.count(), top_data, bottom_diff, bottom_diff);
219 // sum(dE/dY)-sum(dE/dY \cdot Y) \cdot Y
// Per-channel sum of the top diff, reduced the same way into mean_.
220 caffe_cpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim, 1.,
221 top_diff, spatial_sum_multiplier_.cpu_data(), 0.,
222 num_by_chans_.mutable_cpu_data());
223 caffe_cpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
224 num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0.,
225 mean_.mutable_cpu_data());
226 // reshape (broadcast) the above to make
227 // sum(dE/dY)-sum(dE/dY \cdot Y) \cdot Y
// The final gemm uses beta 1, so the broadcast sum is accumulated onto the
// term already held in bottom_diff.
228 caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
229 batch_sum_multiplier_.cpu_data(), mean_.cpu_data(), 0.,
230 num_by_chans_.mutable_cpu_data());
231 caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num * channels_,
232 spatial_dim, 1, 1., num_by_chans_.cpu_data(),
233 spatial_sum_multiplier_.cpu_data(), 1., bottom_diff);
235 // dE/dY - mean(dE/dY)-mean(dE/dY \cdot Y) \cdot Y
// The accumulated sums are scaled by -1 / (num * spatial_dim), which turns
// them into negated means, and the top diff is added (presumably
// Y = 1 * top_diff + scale * Y - TODO confirm axpby).
236 caffe_cpu_axpby(temp_.count(), Dtype(1), top_diff,
237 Dtype(-1. / (num * spatial_dim)), bottom_diff);
239 // note: temp_ still contains sqrt(var(X)+eps), computed during the forward
241 caffe_div(temp_.count(), bottom_diff, temp_.cpu_data(), bottom_diff);
// File-level macro invocations for BatchNormLayer. The macro definitions are
// outside this file, so the notes below are assumptions to be confirmed.
// NOTE(review): presumably STUB_GPU supplies placeholder GPU methods for
// builds without GPU support; any preprocessor guard around it is not
// visible in this view.
246 STUB_GPU(BatchNormLayer);
// NOTE(review): presumably instantiates the class template for the
// supported Dtype types and registers the layer under the name BatchNorm.
249 INSTANTIATE_CLASS(BatchNormLayer);
250 REGISTER_LAYER_CLASS(BatchNorm);