Cleanup batch norm layer, include global stats computation
authorCarl Doersch <cdoersch@cs.cmu.edu>
Tue, 6 Oct 2015 21:19:59 +0000 (14:19 -0700)
committerCarl Doersch <cdoersch@cs.cmu.edu>
Thu, 22 Oct 2015 15:00:13 +0000 (08:00 -0700)
examples/cifar10/cifar10_full_sigmoid_train_test.prototxt
examples/cifar10/cifar10_full_sigmoid_train_test_bn.prototxt
include/caffe/common_layers.hpp
src/caffe/layers/batch_norm_layer.cpp
src/caffe/layers/batch_norm_layer.cu
src/caffe/proto/caffe.proto
src/caffe/test/test_batch_norm_layer.cpp

index 6f5bf26..fba69b8 100644 (file)
@@ -176,10 +176,10 @@ layer {
   top: "ip1"
   param {
     lr_mult: 1
-    decay_mult: 250
+    decay_mult: 0
   }
   param {
-    lr_mult: 0.2
+    lr_mult: 2
     decay_mult: 0
   }
   inner_product_param {
index 85c2dff..1a81075 100644 (file)
@@ -12,7 +12,7 @@ layer {
   }
   data_param {
     source: "examples/cifar10/cifar10_train_lmdb"
-    batch_size: 111
+    batch_size: 100
     backend: LMDB
   }
 }
@@ -41,21 +41,16 @@ layer {
   param {
     lr_mult: 1
   }
-  param {
-    lr_mult: 2
-  }
   convolution_param {
     num_output: 32
     pad: 2
     kernel_size: 5
     stride: 1
+    bias_term: false
     weight_filler {
       type: "gaussian"
       std: 0.0001
     }
-    bias_filler {
-      type: "constant"
-    }
   }
 }
 layer {
@@ -75,23 +70,14 @@ layer {
   type: "BatchNorm"
   bottom: "pool1"
   top: "bn1"
-  bn_param {
-    scale_filler {
-      type: "constant"
-      value: 1
-    }
-    shift_filler {
-      type: "constant"
-      value: 0.001
-    } 
+  param {
+    lr_mult: 0
   }
   param {
-    lr_mult: 1.00001
-    decay_mult: 0
+    lr_mult: 0
   }
   param {
-    lr_mult: 1.00001
-    decay_mult: 0
+    lr_mult: 0
   }
 }
 
@@ -110,50 +96,35 @@ layer {
   param {
     lr_mult: 1
   }
-  param {
-    lr_mult: 2
-  }
   convolution_param {
     num_output: 32
     pad: 2
     kernel_size: 5
     stride: 1
+    bias_term: false
     weight_filler {
       type: "gaussian"
       std: 0.01
     }
-    bias_filler {
-      type: "constant"
-    }
   }
 }
 
-
-
 layer {
   name: "bn2"
   type: "BatchNorm"
   bottom: "conv2"
   top: "bn2"
-  bn_param {
-    scale_filler {
-      type: "constant"
-      value: 1
-    }
-    shift_filler {
-      type: "constant"
-      value: 0.001
-    } 
+  param {
+    lr_mult: 0
   }
   param {
-    lr_mult: 1.00001
-    decay_mult: 0
+    lr_mult: 0
   }
   param {
-    lr_mult: 1.00001
-    decay_mult: 0
+    lr_mult: 0
   }
 }
+
 layer {
   name: "Sigmoid2"
   type: "Sigmoid"
@@ -176,53 +147,38 @@ layer {
   type: "Convolution"
   bottom: "pool2"
   top: "conv3"
+  param {
+    lr_mult: 1
+  }
   convolution_param {
     num_output: 64
     pad: 2
     kernel_size: 5
     stride: 1
+    bias_term: false
     weight_filler {
       type: "gaussian"
       std: 0.01
     }
-    bias_filler {
-      type: "constant"
-    }
-  }
-  param {
-    lr_mult: 1
   }
-  param {
-    lr_mult: 1
-  }
-
 }
 
-
 layer {
   name: "bn3"
   type: "BatchNorm"
   bottom: "conv3"
   top: "bn3"
-  bn_param {
-    scale_filler {
-      type: "constant"
-      value: 1
-    }
-    shift_filler {
-      type: "constant"
-      value: 0.001
-    } 
+  param {
+    lr_mult: 0
   }
   param {
-    lr_mult: 1.00001
-    decay_mult: 0
+    lr_mult: 0
   }
   param {
-    lr_mult: 1.00001
-    decay_mult: 0
+    lr_mult: 0
   }
 }
+
 layer {
   name: "Sigmoid3"
   type: "Sigmoid"
@@ -248,10 +204,10 @@ layer {
   top: "ip1"
   param {
     lr_mult: 1
-    decay_mult: 250
+    decay_mult: 1
   }
   param {
-    lr_mult: 0.2
+    lr_mult: 1
     decay_mult: 0
   }
   inner_product_param {
index 09605db..da38f12 100644 (file)
@@ -79,9 +79,35 @@ class ArgMaxLayer : public Layer<Dtype> {
 };
 
 /**
-* @brief Batch Normalization per-channel with scale & shift linear transform.
-*
-*/
+ * @brief Normalizes the input to have 0-mean and/or unit (1) variance across
+ *        the batch.
+ *
+ * This layer computes Batch Normalization described in [1].  For
+ * each channel in the data (i.e. axis 1), it subtracts the mean and divides
+ * by the variance, where both statistics are computed across both spatial
+ * dimensions and across the different examples in the batch.
+ * 
+ * By default, during training time, the network is computing global mean/
+ * variance statistics via a running average, which is then used at test
+ * time to allow deterministic outputs for each input.  You can manually
+ * toggle whether the network is accumulating or using the statistics via the
+ * use_global_stats option.  IMPORTANT: for this feature to work, you MUST
+ * set the learning rate to zero for all three parameter blobs, i.e., 
+ * param {lr_mult: 0} three times in the layer definition.
+ *
+ * Note that the original paper also included a per-channel learned bias and
+ * scaling factor.  It is possible (though a bit cumbersome) to implement
+ * this in caffe using a single-channel DummyDataLayer filled with zeros,
+ * followed by a Convolution layer with output the same size as the current.
+ * This produces a channel-specific value that can be added or multiplied by
+ * the BatchNorm layer's output.
+ * 
+ * [1] S. Ioffe and C. Szegedy, "Batch Normalization: Accelerating Deep Network
+ *     Training by Reducing Internal Covariate Shift." arXiv preprint 
+ *     arXiv:1502.03167 (2015).  
+ *
+ * TODO(dox): thorough documentation for Forward, Backward, and proto params.
+ */
 template <typename Dtype>
 class BatchNormLayer : public Layer<Dtype> {
  public:
@@ -89,11 +115,10 @@ class BatchNormLayer : public Layer<Dtype> {
       : Layer<Dtype>(param) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top);
-
   virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top);
 
-  virtual inline const char* type() const { return "BN"; }
+  virtual inline const char* type() const { return "BatchNorm"; }
   virtual inline int ExactNumBottomBlobs() const { return 1; }
   virtual inline int ExactNumTopBlobs() const { return 1; }
 
@@ -105,26 +130,19 @@ class BatchNormLayer : public Layer<Dtype> {
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
-
-  // spatial mean & variance
-  Blob<Dtype> spatial_mean_, spatial_variance_;
-  // batch mean & variance
-  Blob<Dtype> batch_mean_, batch_variance_;
-  // buffer blob
-  Blob<Dtype> buffer_blob_;
+     const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
 
-  Blob<Dtype> x_norm_;
-  // x_sum_multiplier is used to carry out sum using BLAS
-  Blob<Dtype> spatial_sum_multiplier_, batch_sum_multiplier_;
+  Blob<Dtype> mean_, variance_, temp_, x_norm_;
+  bool use_global_stats_;
+  Dtype moving_average_fraction_;
+  int channels_;
+  Dtype eps_;
 
-  // dimension
-  int N_;
-  int C_;
-  int H_;
-  int W_;
-  // eps
-  Dtype var_eps_;
+  // extra temporarary variables is used to carry out sums/broadcasting
+  // using BLAS
+  Blob<Dtype> batch_sum_multiplier_;
+  Blob<Dtype> num_by_chans_;
+  Blob<Dtype> spatial_sum_multiplier_;
 };
 
 /**
index 8dea349..94c2b96 100644 (file)
 #include <vector>
 
 #include "caffe/common_layers.hpp"
-#include "caffe/filler.hpp"
 #include "caffe/layer.hpp"
 #include "caffe/util/math_functions.hpp"
 
 namespace caffe {
-  template <typename Dtype>
-  void BatchNormLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
-      const vector<Blob<Dtype>*>& top) {
-    top[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
-        bottom[0]->height(), bottom[0]->width());
-
-    x_norm_.Reshape(bottom[0]->num(), bottom[0]->channels(),
-        bottom[0]->height(), bottom[0]->width());
-
-    // Figure out the dimensions
-    N_ = bottom[0]->num();
-    C_ = bottom[0]->channels();
-    H_ = bottom[0]->height();
-    W_ = bottom[0]->width();
 
-    // mean
-    spatial_mean_.Reshape(N_, C_, 1, 1);
-    batch_mean_.Reshape(1, C_, 1, 1);
-    // variance
-    spatial_variance_.Reshape(N_, C_, 1, 1);
-    batch_variance_.Reshape(1, C_, 1, 1);
-    // buffer blod
-    buffer_blob_.Reshape(N_, C_, H_, W_);
-
-    // fill spatial multiplier
-    spatial_sum_multiplier_.Reshape(1, 1, H_, W_);
-    Dtype* spatial_multipl_data = spatial_sum_multiplier_.mutable_cpu_data();
-    caffe_set(spatial_sum_multiplier_.count(), Dtype(1),
-        spatial_multipl_data);
-    caffe_set(spatial_sum_multiplier_.count(), Dtype(0),
-        spatial_sum_multiplier_.mutable_cpu_diff());
-    // fill batch multiplier
-    batch_sum_multiplier_.Reshape(N_, 1, 1, 1);
-    Dtype* batch_multiplier_data = batch_sum_multiplier_.mutable_cpu_data();
-    caffe_set(batch_sum_multiplier_.count(), Dtype(1),
-        batch_multiplier_data);
-    caffe_set(batch_sum_multiplier_.count(), Dtype(0),
-        batch_sum_multiplier_.mutable_cpu_diff());
-    this->param_propagate_down_.resize(this->blobs_.size(), true);
-  }
-  template <typename Dtype>
-  void BatchNormLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+template <typename Dtype>
+void BatchNormLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
-    CHECK_NE(top[0], bottom[0]) << this->type() << " Layer does not "
-    "allow in-place computation.";
-
-    top[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
-        bottom[0]->height(), bottom[0]->width());
-
-    x_norm_.Reshape(bottom[0]->num(), bottom[0]->channels(),
-        bottom[0]->height(), bottom[0]->width());
-    // Figure out the dimensions
-    N_ = bottom[0]->num();
-    C_ = bottom[0]->channels();
-    H_ = bottom[0]->height();
-    W_ = bottom[0]->width();
-    var_eps_ = 1e-9;
-
-    // mean
-    spatial_mean_.Reshape(N_, C_, 1, 1);
-    batch_mean_.Reshape(1, C_, 1, 1);
-    // variance
-    spatial_variance_.Reshape(N_, C_, 1, 1);
-    batch_variance_.Reshape(1, C_, 1, 1);
-    // buffer blod
-    buffer_blob_.Reshape(N_, C_, H_, W_);
-
-    // fill spatial multiplier
-    spatial_sum_multiplier_.Reshape(1, 1, H_, W_);
-    Dtype* spatial_multipl_data = spatial_sum_multiplier_.mutable_cpu_data();
-    caffe_set(spatial_sum_multiplier_.count(), Dtype(1),
-        spatial_multipl_data);
-    caffe_set(spatial_sum_multiplier_.count(), Dtype(0),
-        spatial_sum_multiplier_.mutable_cpu_diff());
-
-    // fill batch multiplier
-    batch_sum_multiplier_.Reshape(N_, 1, 1, 1);
-    Dtype* batch_multiplier_data = batch_sum_multiplier_.mutable_cpu_data();
-    caffe_set(batch_sum_multiplier_.count(), Dtype(1),
-        batch_multiplier_data);
-    caffe_set(batch_sum_multiplier_.count(), Dtype(0),
-        batch_sum_multiplier_.mutable_cpu_diff());
-
-    // Check if we need to set up the weights
-    if (this->blobs_.size() > 0) {
-      LOG(INFO) << "Skipping parameter initialization";
-    } else {
-      this->blobs_.resize(2);
-
-      // fill scale with scale_filler
-      this->blobs_[0].reset(new Blob<Dtype>(1, C_, 1, 1));
-      caffe_set(this->blobs_[0]->count(), Dtype(1),
-          this->blobs_[0]->mutable_cpu_data());
-
-      // fill shift with shift_filler
-      this->blobs_[1].reset(new Blob<Dtype>(1, C_, 1, 1));
-      caffe_set(this->blobs_[1]->count(), Dtype(0),
-          this->blobs_[1]->mutable_cpu_data());
-    }  // parameter initialization
-    this->param_propagate_down_.resize(this->blobs_.size(), true);
+  BatchNormParameter param = this->layer_param_.batch_norm_param();
+  moving_average_fraction_ = param.moving_average_fraction();
+  use_global_stats_ = this->phase_ == TEST;
+  if (param.has_use_global_stats())
+    use_global_stats_ = param.use_global_stats();
+  if (bottom[0]->num_axes() == 1)
+    channels_ = 1;
+  else
+    channels_ = bottom[0]->shape(1);
+  eps_ = param.eps();
+  if (this->blobs_.size() > 0) {
+    LOG(INFO) << "Skipping parameter initialization";
+  } else {
+    this->blobs_.resize(3);
+    vector<int> sz;
+    sz.push_back(channels_);
+    this->blobs_[0].reset(new Blob<Dtype>(sz));
+    this->blobs_[1].reset(new Blob<Dtype>(sz));
+    sz[0]=1;
+    this->blobs_[2].reset(new Blob<Dtype>(sz));
+    for (int i = 0; i < 3; ++i) {
+      caffe_set(this->blobs_[i]->count(), Dtype(0),
+                this->blobs_[i]->mutable_cpu_data());
+    }
   }
+}
 
-  template <typename Dtype>
-  void BatchNormLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+template <typename Dtype>
+void BatchNormLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
-    const Dtype* bottom_data = bottom[0]->cpu_data();
-    Dtype* top_data = top[0]->mutable_cpu_data();
-    const Dtype* const_top_data = top[0]->cpu_data();
-
-    const Dtype* scale_data = this->blobs_[0]->cpu_data();
-    const Dtype* shift_data = this->blobs_[1]->cpu_data();
-
-    // put the squares of bottom into buffer_blob_
-    caffe_powx(bottom[0]->count(), bottom_data, Dtype(2),
-        buffer_blob_.mutable_cpu_data());
+  if (bottom[0]->num_axes() >= 1)
+    CHECK_EQ(bottom[0]->shape(1), channels_);
+  top[0]->ReshapeLike(*bottom[0]);
+
+  vector<int> sz;
+  sz.push_back(channels_);
+  mean_.Reshape(sz);
+  variance_.Reshape(sz);
+  temp_.ReshapeLike(*bottom[0]);
+  x_norm_.ReshapeLike(*bottom[0]);
+  sz[0]=bottom[0]->shape(0);
+  batch_sum_multiplier_.Reshape(sz);
+
+  int spatial_dim = bottom[0]->count()/(channels_*bottom[0]->shape(0));
+  if (spatial_sum_multiplier_.num_axes() == 0 ||
+      spatial_sum_multiplier_.shape(0) != spatial_dim) {
+    sz[0] = spatial_dim;
+    spatial_sum_multiplier_.Reshape(sz);
+    Dtype* multiplier_data = spatial_sum_multiplier_.mutable_cpu_data();
+    caffe_set(spatial_sum_multiplier_.count(), Dtype(1), multiplier_data);
+  }
 
+  int numbychans = channels_*bottom[0]->shape(0);
+  if (num_by_chans_.num_axes() == 0 ||
+      num_by_chans_.shape(0) != numbychans) {
+    sz[0] = numbychans;
+    num_by_chans_.Reshape(sz);
+    caffe_set(batch_sum_multiplier_.count(), Dtype(1),
+        batch_sum_multiplier_.mutable_cpu_data());
+  }
+}
+
+template <typename Dtype>
+void BatchNormLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = top[0]->mutable_cpu_data();
+  int num = bottom[0]->shape(0);
+  int spatial_dim = bottom[0]->count()/(bottom[0]->shape(0)*channels_);
+
+  // elementwise square
+  caffe_powx(bottom[0]->count(), bottom_data, Dtype(2),
+             temp_.mutable_cpu_data());
+
+  if (use_global_stats_) {
+    // use the stored mean/variance estimates.  TODO(cdoersch): allow an option
+    // to use an unbiased variance estimate, like the paper does.
+    const Dtype scale_factor = 1 / this->blobs_[2]->cpu_data()[0];
+    caffe_cpu_scale(variance_.count(), scale_factor,
+        this->blobs_[0]->cpu_data(), mean_.mutable_cpu_data());
+    caffe_cpu_scale(variance_.count(), scale_factor,
+        this->blobs_[1]->cpu_data(), variance_.mutable_cpu_data());
+  } else {
     // computes variance using var(X) = E(X^2) - (EX)^2
-    // EX across spatial
-    caffe_cpu_gemv<Dtype>(CblasNoTrans, N_ * C_, H_ * W_,
-        Dtype(1. / (H_ * W_)), bottom_data,
-        spatial_sum_multiplier_.cpu_data(), Dtype(0),
-        spatial_mean_.mutable_cpu_data());
-    // EX across batch
-    caffe_cpu_gemv<Dtype>(CblasTrans, N_, C_, Dtype(1. / N_),
-        spatial_mean_.cpu_data(),
-        batch_sum_multiplier_.cpu_data(), Dtype(0),
-        batch_mean_.mutable_cpu_data());
-
-    // E(X^2) across spatial
-    caffe_cpu_gemv<Dtype>(CblasNoTrans, N_ * C_, H_ * W_,
-        Dtype(1. / (H_ * W_)), buffer_blob_.cpu_data(),
-        spatial_sum_multiplier_.cpu_data(), Dtype(0),
-        spatial_variance_.mutable_cpu_data());
-    // E(X^2) across batch
-    caffe_cpu_gemv<Dtype>(CblasTrans, N_, C_, Dtype(1. / N_),
-        spatial_variance_.cpu_data(),
-        batch_sum_multiplier_.cpu_data(), Dtype(0),
-        batch_variance_.mutable_cpu_data());
-
-    caffe_powx(batch_mean_.count(), batch_mean_.cpu_data(), Dtype(2),
-        buffer_blob_.mutable_cpu_data());  // (EX)^2
-    caffe_sub(batch_mean_.count(), batch_variance_.cpu_data(),
-        buffer_blob_.cpu_data(),
-        batch_variance_.mutable_cpu_data());  // variance
-
-    // do mean and variance normalization
-    // subtract mean
-    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_,
-        C_, 1, Dtype(1),
-        batch_sum_multiplier_.cpu_data(),
-        batch_mean_.cpu_data(), Dtype(0),
-        spatial_mean_.mutable_cpu_data());
-
-    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_,
-        H_ * W_, 1, Dtype(-1),
-        spatial_mean_.cpu_data(),
-        spatial_sum_multiplier_.cpu_data(), Dtype(0),
-        buffer_blob_.mutable_cpu_data());
-
-    caffe_add(buffer_blob_.count(), bottom_data,
-        buffer_blob_.cpu_data(), top_data);
-
-    // normalize variance
-    caffe_add_scalar(batch_variance_.count(), var_eps_,
-        batch_variance_.mutable_cpu_data());
-    caffe_powx(batch_variance_.count(),
-        batch_variance_.cpu_data(), Dtype(0.5),
-        batch_variance_.mutable_cpu_data());
-
-    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_,
-        C_, 1, Dtype(1),
-        batch_sum_multiplier_.cpu_data(),
-        batch_variance_.cpu_data(), Dtype(0),
-        spatial_variance_.mutable_cpu_data());
-    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans,
-        N_ * C_, H_ * W_, 1, Dtype(1),
-        spatial_variance_.cpu_data(),
-        spatial_sum_multiplier_.cpu_data(), Dtype(0),
-        buffer_blob_.mutable_cpu_data());
-
-    caffe_div(buffer_blob_.count(), const_top_data,
-        buffer_blob_.cpu_data(), top_data);
-
-    // Saving x_norm
-    caffe_copy(buffer_blob_.count(), const_top_data,
-        x_norm_.mutable_cpu_data());
-    // scale
-    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1),
-        batch_sum_multiplier_.cpu_data(), scale_data, Dtype(0),
-        spatial_variance_.mutable_cpu_data());
-    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_,
-        H_ * W_, 1, Dtype(1),
-        spatial_variance_.cpu_data(),
-        spatial_sum_multiplier_.cpu_data(), Dtype(0),
-        buffer_blob_.mutable_cpu_data());
-    caffe_mul(buffer_blob_.count(), top_data,
-        buffer_blob_.cpu_data(), top_data);
-
-    // shift
-    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1),
-        batch_sum_multiplier_.cpu_data(), shift_data, Dtype(0),
-        spatial_mean_.mutable_cpu_data());
-    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans,
-        N_ * C_, H_ * W_, 1, Dtype(1),
-        spatial_mean_.cpu_data(),
-        spatial_sum_multiplier_.cpu_data(), Dtype(0),
-        buffer_blob_.mutable_cpu_data());
-    caffe_add(buffer_blob_.count(), const_top_data,
-        buffer_blob_.cpu_data(), top_data);
+    caffe_cpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim,
+        1. / (num * spatial_dim), bottom_data,
+        spatial_sum_multiplier_.cpu_data(), 0.,
+        num_by_chans_.mutable_cpu_data());
+    caffe_cpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
+        num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0.,
+        mean_.mutable_cpu_data());
+    caffe_cpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim,
+        1. / (num * spatial_dim), temp_.cpu_data(),
+        spatial_sum_multiplier_.cpu_data(), 0.,
+        num_by_chans_.mutable_cpu_data());
+    caffe_cpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
+        num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0.,
+        variance_.mutable_cpu_data());
+    this->blobs_[2]->mutable_cpu_data()[0] *= moving_average_fraction_;
+    this->blobs_[2]->mutable_cpu_data()[0] += 1;
+    caffe_cpu_axpby(mean_.count(), Dtype(1), mean_.cpu_data(),
+        moving_average_fraction_, this->blobs_[0]->mutable_cpu_data());
+    Dtype m = Dtype(bottom[0]->count()/channels_);
+    caffe_cpu_axpby(variance_.count(), m/(m-1), variance_.cpu_data(),
+        moving_average_fraction_, this->blobs_[1]->mutable_cpu_data());
   }
+  // elementwise square of mean
+  caffe_powx(mean_.count(), mean_.cpu_data(), Dtype(2),
+             temp_.mutable_cpu_data());
 
-  template <typename Dtype>
-  void BatchNormLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const vector<bool>& propagate_down,
-      const vector<Blob<Dtype>*>& bottom) {
-    const Dtype* top_diff = top[0]->cpu_diff();
-    const Dtype* bottom_data = bottom[0]->cpu_data();
-    Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
-
-    Dtype* scale_diff = this->blobs_[0]->mutable_cpu_diff();
-    Dtype* shift_diff = this->blobs_[1]->mutable_cpu_diff();
-    const Dtype* scale_data = this->blobs_[0]->cpu_data();
-
-// Propagate layer to parameters
-    // gradient w.r.t. scale
-    caffe_mul(buffer_blob_.count(), x_norm_.cpu_data(),
-        top_diff, buffer_blob_.mutable_cpu_data());
-    // EX across spatial
-    caffe_cpu_gemv<Dtype>(CblasNoTrans, N_ * C_,
-        H_ * W_, Dtype(1), buffer_blob_.cpu_data(),
-        spatial_sum_multiplier_.cpu_data(), Dtype(0),
-        spatial_variance_.mutable_cpu_diff());
-    // EX across batch
-    caffe_cpu_gemv<Dtype>(CblasTrans, N_, C_, Dtype(1),
-        spatial_variance_.cpu_diff(),
-        batch_sum_multiplier_.cpu_data(), Dtype(0), scale_diff);
-
-    // gradient w.r.t. shift
-    // EX across spatial
-    caffe_cpu_gemv<Dtype>(CblasNoTrans, N_ * C_,
-        H_ * W_, Dtype(1), top_diff,
-        spatial_sum_multiplier_.cpu_data(),
-        Dtype(0), spatial_mean_.mutable_cpu_diff());
-    // EX across batch
-    caffe_cpu_gemv<Dtype>(CblasTrans, N_, C_,
-        Dtype(1), spatial_mean_.cpu_diff(),
-        batch_sum_multiplier_.cpu_data(),
-        Dtype(0), shift_diff);
+  caffe_sub(mean_.count(), variance_.cpu_data(), temp_.cpu_data(),
+            variance_.mutable_cpu_data());  // variance
 
-// Propagate down
+  // normalize variance
+  caffe_add_scalar(variance_.count(), eps_, variance_.mutable_cpu_data());
+  caffe_powx(variance_.count(), variance_.cpu_data(), Dtype(0.5),
+             variance_.mutable_cpu_data());
 
-    // put scale * top_diff to buffer_blob_
-    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1),
-        batch_sum_multiplier_.cpu_data(), scale_data, Dtype(0),
-        spatial_variance_.mutable_cpu_data());
-    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_,
-        H_ * W_, 1, Dtype(1),
-        spatial_variance_.cpu_data(),
-        spatial_sum_multiplier_.cpu_data(), Dtype(0),
-        buffer_blob_.mutable_cpu_data());
-    caffe_mul(buffer_blob_.count(), top_diff, buffer_blob_.cpu_data(),
-        buffer_blob_.mutable_cpu_data());
-
-    // use new top diff for computation
-    caffe_mul(buffer_blob_.count(),  x_norm_.cpu_data(),
-        buffer_blob_.cpu_data(), bottom_diff);
-    // EX across spatial
-    caffe_cpu_gemv<Dtype>(CblasNoTrans, N_ * C_, H_ * W_,
-        Dtype(1), bottom_diff,
-        spatial_sum_multiplier_.cpu_data(), Dtype(0),
-        spatial_mean_.mutable_cpu_data());
-    // EX across batch
-    caffe_cpu_gemv<Dtype>(CblasTrans, N_, C_, Dtype(1),
-        spatial_mean_.cpu_data(),
-        batch_sum_multiplier_.cpu_data(), Dtype(0),
-        batch_mean_.mutable_cpu_data());
-
-    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans,
-        N_, C_, 1, Dtype(1),
-        batch_sum_multiplier_.cpu_data(),
-        batch_mean_.cpu_data(), Dtype(0),
-        spatial_mean_.mutable_cpu_data());
-    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_,
-        H_ * W_, 1, Dtype(1),
-        spatial_mean_.cpu_data(),
-        spatial_sum_multiplier_.cpu_data(), Dtype(0),
-        bottom_diff);
-
-    caffe_mul(buffer_blob_.count(),
-        x_norm_.cpu_data(), bottom_diff, bottom_diff);
-
-    // EX across spatial
-    caffe_cpu_gemv<Dtype>(CblasNoTrans, N_ * C_,
-        H_ * W_, Dtype(1), buffer_blob_.cpu_data(),
-        spatial_sum_multiplier_.cpu_data(), Dtype(0),
-        spatial_mean_.mutable_cpu_data());
-    // EX across batch
-    caffe_cpu_gemv<Dtype>(CblasTrans, N_, C_, Dtype(1),
-        spatial_mean_.cpu_data(),
-        batch_sum_multiplier_.cpu_data(), Dtype(0),
-        batch_mean_.mutable_cpu_data());
-
-    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans,
-        N_, C_, 1, Dtype(1),
-        batch_sum_multiplier_.cpu_data(),
-        batch_mean_.cpu_data(), Dtype(0),
-        spatial_mean_.mutable_cpu_data());
-    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans,
-        N_ * C_, H_ * W_, 1, Dtype(1),
-        spatial_mean_.cpu_data(),
-        spatial_sum_multiplier_.cpu_data(), Dtype(1), bottom_diff);
-
-    caffe_cpu_axpby(buffer_blob_.count(), Dtype(1),
-        buffer_blob_.cpu_data(), Dtype(-1. / (N_ * H_ * W_)),
-        bottom_diff);
-
-    // put the squares of bottom into buffer_blob_
-//    caffe_powx(buffer_blob_.count(), bottom_data, Dtype(2),
-//        buffer_blob_.mutable_cpu_data());
+  // do mean and variance normalization
+  if (bottom[0] != top[0]) {
+    caffe_copy(bottom[0]->count(), bottom_data, top_data);
+  }
+  // subtract mean
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
+      batch_sum_multiplier_.cpu_data(), mean_.cpu_data(), 0.,
+      num_by_chans_.mutable_cpu_data());
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_ * num,
+      spatial_dim, 1, -1, num_by_chans_.cpu_data(),
+      spatial_sum_multiplier_.cpu_data(), 1., top_data);
+  // replicate variance to input size
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
+      batch_sum_multiplier_.cpu_data(), variance_.cpu_data(), 0.,
+      num_by_chans_.mutable_cpu_data());
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_ * num,
+      spatial_dim, 1, 1., num_by_chans_.cpu_data(),
+      spatial_sum_multiplier_.cpu_data(), 0., temp_.mutable_cpu_data());
+  caffe_div(temp_.count(), top_data, temp_.cpu_data(), top_data);
+  // TODO(cdoersch): The caching is only needed because later in-place layers
+  //                 might clobber the data.  Can we skip this if they won't?
+  caffe_copy(x_norm_.count(), top_data,
+      x_norm_.mutable_cpu_data());
+}
+
+template <typename Dtype>
+void BatchNormLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
+  CHECK(!use_global_stats_);
+  const Dtype* top_diff;
+  if (bottom[0] != top[0]) {
+    top_diff = top[0]->cpu_diff();
+  } else {
+    caffe_copy(x_norm_.count(), top[0]->cpu_diff(), x_norm_.mutable_cpu_diff());
+    top_diff = x_norm_.cpu_diff();
+  }
+  const Dtype* top_data = x_norm_.cpu_data();
+  Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+  int num = bottom[0]->shape()[0];
+  int spatial_dim = bottom[0]->count()/(bottom[0]->shape(0)*channels_);
+  // if Y = (X-mean(X))/(sqrt(var(X)+eps)), then
+  //
+  // dE(Y)/dX =
+  //   (dE/dY - mean(dE/dY) - mean(dE/dY \cdot Y) \cdot Y)
+  //     ./ sqrt(var(X) + eps)
+  //
+  // where \cdot and ./ are hadamard product and elementwise division,
+  // respectively, dE/dY is the top diff, and mean/var/sum are all computed
+  // along all dimensions except the channels dimension.  In the above
+  // equation, the operations allow for expansion (i.e. broadcast) along all
+  // dimensions except the channels dimension where required.
+
+  // sum(dE/dY \cdot Y)
+  caffe_mul(temp_.count(), top_data, top_diff, bottom_diff);
+  caffe_cpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim, 1.,
+      bottom_diff, spatial_sum_multiplier_.cpu_data(), 0.,
+      num_by_chans_.mutable_cpu_data());
+  caffe_cpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
+      num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0.,
+      mean_.mutable_cpu_data());
+
+  // reshape (broadcast) the above
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
+      batch_sum_multiplier_.cpu_data(), mean_.cpu_data(), 0.,
+      num_by_chans_.mutable_cpu_data());
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_ * num,
+      spatial_dim, 1, 1., num_by_chans_.cpu_data(),
+      spatial_sum_multiplier_.cpu_data(), 0., bottom_diff);
+
+  // sum(dE/dY \cdot Y) \cdot Y
+  caffe_mul(temp_.count(), top_data, bottom_diff, bottom_diff);
+
+  // sum(dE/dY)-sum(dE/dY \cdot Y) \cdot Y
+  caffe_cpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim, 1.,
+      top_diff, spatial_sum_multiplier_.cpu_data(), 0.,
+      num_by_chans_.mutable_cpu_data());
+  caffe_cpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
+      num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0.,
+      mean_.mutable_cpu_data());
+  // reshape (broadcast) the above to make
+  // sum(dE/dY)-sum(dE/dY \cdot Y) \cdot Y
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
+      batch_sum_multiplier_.cpu_data(), mean_.cpu_data(), 0.,
+      num_by_chans_.mutable_cpu_data());
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num * channels_,
+      spatial_dim, 1, 1., num_by_chans_.cpu_data(),
+      spatial_sum_multiplier_.cpu_data(), 1., bottom_diff);
+
+  // dE/dY - mean(dE/dY)-mean(dE/dY \cdot Y) \cdot Y
+  caffe_cpu_axpby(temp_.count(), Dtype(1), top_diff,
+      Dtype(-1. / (num * spatial_dim)), bottom_diff);
+
+  // note: temp_ still contains sqrt(var(X)+eps), computed during the forward
+  // pass.
+  caffe_div(temp_.count(), bottom_diff, temp_.cpu_data(), bottom_diff);
+}
 
-    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans,
-        N_, C_, 1, Dtype(1),
-        batch_sum_multiplier_.cpu_data(),
-        batch_variance_.cpu_data(), Dtype(0),
-        spatial_variance_.mutable_cpu_data());
-    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans,
-        N_ * C_, H_ * W_, 1, Dtype(1),
-        spatial_variance_.cpu_data(),
-        spatial_sum_multiplier_.cpu_data(), Dtype(0),
-        buffer_blob_.mutable_cpu_data());
 
-    caffe_div(buffer_blob_.count(), bottom_diff,
-        buffer_blob_.cpu_data(), bottom_diff);
-  }
 #ifdef CPU_ONLY
 STUB_GPU(BatchNormLayer);
 #endif
 
-  INSTANTIATE_CLASS(BatchNormLayer);
-  REGISTER_LAYER_CLASS(BatchNorm);
+INSTANTIATE_CLASS(BatchNormLayer);
+REGISTER_LAYER_CLASS(BatchNorm);
 }  // namespace caffe
-
index e87f8c6..cd8924a 100644 (file)
 #include <vector>
 
 #include "caffe/common_layers.hpp"
-#include "caffe/filler.hpp"
 #include "caffe/layer.hpp"
 #include "caffe/util/math_functions.hpp"
 
 namespace caffe {
-  template <typename Dtype>
-  void BatchNormLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-      const vector<Blob<Dtype>*>& top) {
-    const Dtype* bottom_data = bottom[0]->gpu_data();
-    const Dtype* const_top_data = top[0]->gpu_data();
-    Dtype* top_data = top[0]->mutable_gpu_data();
-    Dtype* spatial_mean_data = spatial_mean_.mutable_gpu_data();
-    Dtype* buffer_data = buffer_blob_.mutable_gpu_data();
-    const Dtype* const_buffer_data = buffer_blob_.gpu_data();
-
-
-  // put the squares of bottom into buffer_blob_
-    caffe_gpu_powx(bottom[0]->count(), bottom_data, Dtype(2),
-        buffer_blob_.mutable_gpu_data());
 
+template <typename Dtype>
+void BatchNormLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = top[0]->mutable_gpu_data();
+  int num = bottom[0]->shape(0);
+  int spatial_dim = bottom[0]->count()/(channels_*bottom[0]->shape(0));
+
+  // elementwise square
+  caffe_gpu_powx(bottom[0]->count(), bottom_data, Dtype(2),
+      temp_.mutable_gpu_data());
+
+  if (use_global_stats_) {
+    // use the stored mean/variance estimates.  TODO(cdoersch): allow an option
+    // to use an unbiased variance estimate, like the paper does.
+    const Dtype scale_factor = 1 / this->blobs_[2]->cpu_data()[0];
+    caffe_gpu_scale(variance_.count(), scale_factor,
+        this->blobs_[0]->gpu_data(), mean_.mutable_gpu_data());
+    caffe_gpu_scale(variance_.count(), scale_factor,
+        this->blobs_[1]->gpu_data(), variance_.mutable_gpu_data());
+  } else {
     // computes variance using var(X) = E(X^2) - (EX)^2
-    // EX across spatial
-    caffe_gpu_gemv<Dtype>(CblasNoTrans, N_ * C_, H_ * W_,
-      Dtype(1. / (H_ * W_)),
-      bottom_data, spatial_sum_multiplier_.gpu_data(),
-      Dtype(0), spatial_mean_data);
-    // EX across batch
-    caffe_gpu_gemv<Dtype>(CblasTrans, N_, C_, Dtype(1. / N_),
-        spatial_mean_.gpu_data(),
-        batch_sum_multiplier_.gpu_data(), Dtype(0),
-      batch_mean_.mutable_gpu_data());
-
-    // E(X^2) across spatial
-    caffe_gpu_gemv<Dtype>(CblasNoTrans, N_ * C_, H_ * W_,
-      Dtype(1. / (H_ * W_)), buffer_data,
-        spatial_sum_multiplier_.gpu_data(), Dtype(0),
-      spatial_variance_.mutable_gpu_data());
-    // E(X^2) across batch
-    caffe_gpu_gemv<Dtype>(CblasTrans, N_, C_, Dtype(1. / N_),
-        spatial_variance_.gpu_data(),
-        batch_sum_multiplier_.gpu_data(), Dtype(0),
-      batch_variance_.mutable_gpu_data());
-
-    caffe_gpu_powx(batch_mean_.count(), batch_mean_.gpu_data(),
-      Dtype(2), buffer_blob_.mutable_gpu_data());  // (EX)^2
-    caffe_gpu_sub(batch_mean_.count(), batch_variance_.gpu_data(),
-      buffer_data, batch_variance_.mutable_gpu_data());  // variance
-
-    // do mean and variance normalization
-    // subtract mean
-    caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1),
-        batch_sum_multiplier_.gpu_data(), batch_mean_.gpu_data(), Dtype(0),
-        spatial_mean_data);
-    caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_, H_ * W_,
-        1, -Dtype(1),
-        spatial_mean_.gpu_data(), spatial_sum_multiplier_.gpu_data(), Dtype(0),
-        buffer_blob_.mutable_gpu_data());
-
-    caffe_gpu_add(buffer_blob_.count(), bottom_data, buffer_data, top_data);
-
-    // normalize variance
-    caffe_gpu_add_scalar(batch_variance_.count(), var_eps_,
-      batch_variance_.mutable_gpu_data());
-    caffe_gpu_powx(batch_variance_.count(), batch_variance_.gpu_data(),
-      Dtype(0.5), batch_variance_.mutable_gpu_data());
-
-    caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1),
-        batch_sum_multiplier_.gpu_data(), batch_variance_.gpu_data(), Dtype(0),
-        spatial_variance_.mutable_gpu_data());
-    caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_,
-        H_ * W_, 1, Dtype(1),
-        spatial_variance_.gpu_data(), spatial_sum_multiplier_.gpu_data(),
-        Dtype(0), buffer_blob_.mutable_gpu_data());
-
-    caffe_gpu_div(buffer_blob_.count(), top_data, buffer_data, top_data);
-
-    // Saving x_norm
-    caffe_copy(top[0]->count(), const_top_data, x_norm_.mutable_gpu_data());
-
-    // scale
-    caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1),
-        batch_sum_multiplier_.gpu_data(), this->blobs_[0]->gpu_data(),
-        Dtype(0), spatial_variance_.mutable_gpu_data());
-    caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_,
-        H_ * W_, 1, Dtype(1),
-        spatial_variance_.gpu_data(), spatial_sum_multiplier_.gpu_data(),
-        Dtype(0), buffer_blob_.mutable_gpu_data());
-
-    caffe_gpu_mul(buffer_blob_.count(), top_data, buffer_data, top_data);
-
-    // shift
-    caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1),
-        batch_sum_multiplier_.gpu_data(),
-        this->blobs_[1]->gpu_data(), Dtype(0),
-        spatial_mean_data);
-    caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_, H_ * W_, 1,
-        Dtype(1),
-        spatial_mean_.gpu_data(), spatial_sum_multiplier_.gpu_data(), Dtype(0),
-        buffer_blob_.mutable_gpu_data());
-    caffe_gpu_add(buffer_blob_.count(), top_data, buffer_data, top_data);
+    caffe_gpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim,
+        1. / (num * spatial_dim), bottom_data,
+        spatial_sum_multiplier_.gpu_data(), 0.,
+        num_by_chans_.mutable_gpu_data());
+    caffe_gpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
+        num_by_chans_.gpu_data(), batch_sum_multiplier_.gpu_data(), 0.,
+        mean_.mutable_gpu_data());
+    caffe_gpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim,
+        1. / (num * spatial_dim), temp_.gpu_data(),
+        spatial_sum_multiplier_.gpu_data(), 0.,
+        num_by_chans_.mutable_gpu_data());
+    caffe_gpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
+        num_by_chans_.gpu_data(), batch_sum_multiplier_.gpu_data(), 0.,
+        variance_.mutable_gpu_data());
+    this->blobs_[2]->mutable_cpu_data()[0] *= moving_average_fraction_;
+    this->blobs_[2]->mutable_cpu_data()[0] += 1;
+    caffe_gpu_axpby(mean_.count(), Dtype(1), mean_.gpu_data(),
+        moving_average_fraction_, this->blobs_[0]->mutable_gpu_data());
+    Dtype m = Dtype(bottom[0]->count()/channels_);
+    caffe_gpu_axpby(variance_.count(), m/(m-1), variance_.gpu_data(),
+        moving_average_fraction_, this->blobs_[1]->mutable_gpu_data());
   }
+  // elementwise square of mean
+  caffe_gpu_powx(mean_.count(), mean_.gpu_data(), Dtype(2),
+                 temp_.mutable_gpu_data());
+
+  caffe_gpu_sub(mean_.count(), variance_.gpu_data(), temp_.gpu_data(),
+                variance_.mutable_gpu_data());  // variance
+
+  // normalize variance
+  caffe_gpu_add_scalar(variance_.count(), eps_, variance_.mutable_gpu_data());
+  caffe_gpu_powx(variance_.count(), variance_.gpu_data(), Dtype(0.5),
+      variance_.mutable_gpu_data());
 
-  template <typename Dtype>
-  void BatchNormLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const vector<bool>& propagate_down,
-      const vector<Blob<Dtype>*>& bottom) {
-    const Dtype* top_diff = top[0]->gpu_diff();
-    const Dtype* top_data = top[0]->gpu_data();
-    const Dtype* bottom_data = bottom[0]->gpu_data();
-    Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
-    const Dtype* const_bottom_diff = bottom[0]->gpu_diff();
-    Dtype* spatial_mean_data = spatial_mean_.mutable_gpu_data();
-    Dtype* buffer_data = buffer_blob_.mutable_gpu_data();
-    const Dtype* const_buffer_data = buffer_blob_.gpu_data();
-
-    // Propage to layer params
-    // gradient w.r.t. scale
-    caffe_gpu_mul(buffer_blob_.count(), x_norm_.gpu_data(),
-    top_diff, buffer_blob_.mutable_gpu_data());
-    // EX across spatial
-    caffe_gpu_gemv<Dtype>(CblasNoTrans, N_ * C_, H_ * W_, Dtype(1),
-        buffer_data, spatial_sum_multiplier_.gpu_data(), Dtype(0),
-    spatial_variance_.mutable_gpu_data());
-    // EX across batch
-    caffe_gpu_gemv<Dtype>(CblasTrans, N_, C_, Dtype(1),
-        spatial_variance_.gpu_data(),
-        batch_sum_multiplier_.gpu_data(), Dtype(0),
-        this->blobs_[0]->mutable_gpu_diff());
-
-    // gradient w.r.t. shift
-    // EX across spatial
-    caffe_gpu_gemv<Dtype>(CblasNoTrans, N_ * C_, H_ * W_, Dtype(1),
-      top_diff, spatial_sum_multiplier_.gpu_data(),
-      Dtype(0), spatial_mean_data);
-    // EX across batch
-    caffe_gpu_gemv<Dtype>(CblasTrans, N_, C_, Dtype(1),
-        spatial_mean_.gpu_data(),
-        batch_sum_multiplier_.gpu_data(), Dtype(0),
-        this->blobs_[1]->mutable_gpu_diff());
-
-    // Propagate down
-    // scale top diff
-    caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1),
-        batch_sum_multiplier_.gpu_data(), this->blobs_[0]->gpu_data(),
-        Dtype(0), spatial_variance_.mutable_gpu_data());
-    caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_,
-        H_ * W_, 1, Dtype(1),
-        spatial_variance_.gpu_data(), spatial_sum_multiplier_.gpu_data(),
-        Dtype(0),
-        buffer_blob_.mutable_gpu_data());
-    caffe_gpu_mul(buffer_blob_.count(), top_diff, buffer_data,
-        buffer_blob_.mutable_gpu_data());
-
-    // use new top diff for computation
-    caffe_gpu_mul(buffer_blob_.count(), x_norm_.gpu_data(),
-       buffer_data, bottom_diff);
-    // EX across spatial
-    caffe_gpu_gemv<Dtype>(CblasNoTrans, N_ * C_, H_ * W_,
-        Dtype(1), bottom_diff,
-        spatial_sum_multiplier_.gpu_data(), Dtype(0), spatial_mean_data);
-    // EX across batch
-    caffe_gpu_gemv<Dtype>(CblasTrans, N_, C_, Dtype(1),
-        spatial_mean_.gpu_data(),
-        batch_sum_multiplier_.gpu_data(), Dtype(0),
-        batch_mean_.mutable_gpu_data());
-
-    caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1),
-        batch_sum_multiplier_.gpu_data(),
-        batch_mean_.gpu_data(), Dtype(0),
-        spatial_mean_data);
-    caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_,
-        H_ * W_, 1, Dtype(1), spatial_mean_.gpu_data(),
-        spatial_sum_multiplier_.gpu_data(), Dtype(0),
-        bottom_diff);
-
-    caffe_gpu_mul(buffer_blob_.count(), x_norm_.gpu_data(),
-        bottom_diff, bottom_diff);
-
-    // EX across spatial
-    caffe_gpu_gemv<Dtype>(CblasNoTrans, N_ * C_, H_ * W_, Dtype(1),
-        buffer_data, spatial_sum_multiplier_.gpu_data(),
-        Dtype(0), spatial_mean_data);
-
-    // EX across batch
-    caffe_gpu_gemv<Dtype>(CblasTrans, N_, C_, Dtype(1),
-        spatial_mean_.gpu_data(),
-        batch_sum_multiplier_.gpu_data(), Dtype(0),
-        batch_mean_.mutable_gpu_data());
-
-    caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_,
-        C_, 1, Dtype(1),
-        batch_sum_multiplier_.gpu_data(),
-        batch_mean_.gpu_data(), Dtype(0),
-        spatial_mean_data);
-    caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_,
-        H_ * W_, 1, Dtype(1),
-        spatial_mean_.gpu_data(), spatial_sum_multiplier_.gpu_data(),
-        Dtype(1),
-        bottom_diff);
-
-    caffe_gpu_axpby(buffer_blob_.count(), Dtype(1), buffer_data,
-        Dtype(-1. / (N_ * H_ * W_)),
-        bottom_diff);
-
-    // put the squares of bottom into buffer_blob_
-//    caffe_gpu_powx(buffer_blob_.count(), bottom_data, Dtype(2),
-//        buffer_blob_.mutable_gpu_data());
-
-    caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1),
-        batch_sum_multiplier_.gpu_data(), batch_variance_.gpu_data(), Dtype(0),
-        spatial_variance_.mutable_gpu_data());
-    caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_,
-        H_ * W_, 1, Dtype(1),
-        spatial_variance_.gpu_data(), spatial_sum_multiplier_.gpu_data(),
-        Dtype(0),
-        buffer_blob_.mutable_gpu_data());
-
-    caffe_gpu_div(buffer_blob_.count(), const_bottom_diff,
-    const_buffer_data, bottom_diff);
+  // do mean and variance normalization
+  if (bottom[0] != top[0]) {
+    caffe_copy(bottom[0]->count(), bottom_data, top_data);
   }
+  // subtract mean
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
+      batch_sum_multiplier_.gpu_data(), mean_.gpu_data(), 0.,
+      num_by_chans_.mutable_gpu_data());
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_ * num,
+      spatial_dim, 1, -1, num_by_chans_.gpu_data(),
+      spatial_sum_multiplier_.gpu_data(), 1., top_data);
+  // replicate variance to input size
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
+      batch_sum_multiplier_.gpu_data(), variance_.gpu_data(), 0.,
+      num_by_chans_.mutable_gpu_data());
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_ * num,
+      spatial_dim, 1, 1., num_by_chans_.gpu_data(),
+      spatial_sum_multiplier_.gpu_data(), 0., temp_.mutable_gpu_data());
+  caffe_gpu_div(temp_.count(), top_data, temp_.gpu_data(), top_data);
+  // TODO(cdoersch): The caching is only needed because later in-place layers
+  //                 might clobber the data.  Can we skip this if they won't?
+  caffe_copy(x_norm_.count(), top_data,
+      x_norm_.mutable_gpu_data());
+}
+
+template <typename Dtype>
+void BatchNormLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
+  CHECK(!use_global_stats_);
+  const Dtype* top_diff;
+  if (bottom[0] != top[0]) {
+    top_diff = top[0]->gpu_diff();
+  } else {
+    caffe_copy(x_norm_.count(), top[0]->gpu_diff(), x_norm_.mutable_gpu_diff());
+    top_diff = x_norm_.gpu_diff();
+  }
+  const Dtype* top_data = x_norm_.gpu_data();
+  Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+  int num = bottom[0]->shape()[0];
+  int spatial_dim = bottom[0]->count()/(channels_*bottom[0]->shape(0));
+  // if Y = (X-mean(X))/(sqrt(var(X)+eps)), then
+  //
+  // dE(Y)/dX =
+  //   (dE/dY - mean(dE/dY) - mean(dE/dY \cdot Y) \cdot Y)
+  //     ./ sqrt(var(X) + eps)
+  //
+  // where \cdot and ./ are hadamard product and elementwise division,
+  // respectively, dE/dY is the top diff, and mean/var/sum are all computed
+  // along all dimensions except the channels dimension.  In the above
+  // equation, the operations allow for expansion (i.e. broadcast) along all
+  // dimensions except the channels dimension where required.
+
+  // sum(dE/dY \cdot Y)
+  caffe_gpu_mul(temp_.count(), top_data, top_diff, bottom_diff);
+  caffe_gpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim, 1.,
+      bottom_diff, spatial_sum_multiplier_.gpu_data(), 0.,
+      num_by_chans_.mutable_gpu_data());
+  caffe_gpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
+      num_by_chans_.gpu_data(), batch_sum_multiplier_.gpu_data(), 0.,
+      mean_.mutable_gpu_data());
+
+  // reshape (broadcast) the above
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
+      batch_sum_multiplier_.gpu_data(), mean_.gpu_data(), 0.,
+      num_by_chans_.mutable_gpu_data());
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_ * num,
+      spatial_dim, 1, 1., num_by_chans_.gpu_data(),
+      spatial_sum_multiplier_.gpu_data(), 0., bottom_diff);
+
+  // sum(dE/dY \cdot Y) \cdot Y
+  caffe_gpu_mul(temp_.count(), top_data, bottom_diff, bottom_diff);
+
+  // sum(dE/dY)-sum(dE/dY \cdot Y) \cdot Y
+  caffe_gpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim, 1.,
+      top_diff, spatial_sum_multiplier_.gpu_data(), 0.,
+      num_by_chans_.mutable_gpu_data());
+  caffe_gpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
+      num_by_chans_.gpu_data(), batch_sum_multiplier_.gpu_data(), 0.,
+      mean_.mutable_gpu_data());
+  // reshape (broadcast) the above to make
+  // sum(dE/dY)-sum(dE/dY \cdot Y) \cdot Y
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
+      batch_sum_multiplier_.gpu_data(), mean_.gpu_data(), 0.,
+      num_by_chans_.mutable_gpu_data());
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num * channels_,
+      spatial_dim, 1, 1., num_by_chans_.gpu_data(),
+      spatial_sum_multiplier_.gpu_data(), 1., bottom_diff);
+
+  // dE/dY - mean(dE/dY)-mean(dE/dY \cdot Y) \cdot Y
+  caffe_gpu_axpby(temp_.count(), Dtype(1), top_diff,
+      Dtype(-1. / (num * spatial_dim)), bottom_diff);
+
+  // note: temp_ still contains sqrt(var(X)+eps), computed during the forward
+  // pass.
+  caffe_gpu_div(temp_.count(), bottom_diff, temp_.gpu_data(), bottom_diff);
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(BatchNormLayer);
 
-  INSTANTIATE_LAYER_GPU_FUNCS(BatchNormLayer);
-}  // namespace caffe
 
+}  // namespace caffe
index a8747c1..99dd3c9 100644 (file)
@@ -301,7 +301,7 @@ message ParamSpec {
 // NOTE
 // Update the next available ID when you add a new LayerParameter field.
 //
-// LayerParameter next available layer-specific ID: 139 (last added: tile_param)
+// LayerParameter next available layer-specific ID: 140 (last added: batch_norm_param)
 message LayerParameter {
   optional string name = 1; // the layer name
   optional string type = 2; // the layer type
@@ -350,6 +350,7 @@ message LayerParameter {
   // The default for the engine is set by the ENGINE switch at compile-time.
   optional AccuracyParameter accuracy_param = 102;
   optional ArgMaxParameter argmax_param = 103;
+  optional BatchNormParameter batch_norm_param = 139;
   optional ConcatParameter concat_param = 104;
   optional ContrastiveLossParameter contrastive_loss_param = 105;
   optional ConvolutionParameter convolution_param = 106;
@@ -461,6 +462,18 @@ message ConcatParameter {
   optional uint32 concat_dim = 1 [default = 1];
 }
 
+message BatchNormParameter {
+  // If false, accumulate global mean/variance values via a moving average. If
+  // true, use those accumulated values instead of computing mean/variance
+  // across the batch.
+  optional bool use_global_stats = 1;
+  // How much does the moving average decay each iteration?
+  optional float moving_average_fraction = 2 [default = .999];
+  // Small value to add to the variance estimate so that we don't divide by
+  // zero.
+  optional float eps = 3 [default = 1e-5];
+}
+
 message ContrastiveLossParameter {
   // margin for dissimilar pair
   optional float margin = 1 [default = 1.0];
index 704efd5..22b9667 100644 (file)
@@ -60,7 +60,50 @@ namespace caffe {
         for ( int k = 0; k < height; ++k ) {
           for ( int l = 0; l < width; ++l ) {
             Dtype data = this->blob_top_->data_at(i, j, k, l);
-            Dtype bottom_data = this->blob_bottom_->data_at(i, j, k, l);
+            sum += data;
+            var += data * data;
+          }
+        }
+      }
+      sum /= height * width * num;
+      var /= height * width * num;
+
+      const Dtype kErrorBound = 0.001;
+      // expect zero mean
+      EXPECT_NEAR(0, sum, kErrorBound);
+      // expect unit variance
+      EXPECT_NEAR(1, var, kErrorBound);
+    }
+  }
+
+  TYPED_TEST(BatchNormLayerTest, TestForwardInplace) {
+    typedef typename TypeParam::Dtype Dtype;
+    Blob<Dtype> blob_inplace(5, 2, 3, 4);
+    vector<Blob<Dtype>*> blob_bottom_vec;
+    vector<Blob<Dtype>*> blob_top_vec;
+    LayerParameter layer_param;
+    FillerParameter filler_param;
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(&blob_inplace);
+    blob_bottom_vec.push_back(&blob_inplace);
+    blob_top_vec.push_back(&blob_inplace);
+
+    BatchNormLayer<Dtype> layer(layer_param);
+    layer.SetUp(blob_bottom_vec, blob_top_vec);
+    layer.Forward(blob_bottom_vec, blob_top_vec);
+
+    // Test mean
+    int num = blob_inplace.num();
+    int channels = blob_inplace.channels();
+    int height = blob_inplace.height();
+    int width = blob_inplace.width();
+
+    for (int j = 0; j < channels; ++j) {
+      Dtype sum = 0, var = 0;
+      for (int i = 0; i < num; ++i) {
+        for ( int k = 0; k < height; ++k ) {
+          for ( int l = 0; l < width; ++l ) {
+            Dtype data = blob_inplace.data_at(i, j, k, l);
             sum += data;
             var += data * data;
           }