Added MSRAFiller, an Xavier-like filler designed for use with ReLUs
authorNick Carlevaris-Bianco <carlevar@umich.edu>
Mon, 16 Feb 2015 05:19:43 +0000 (15:49 +1030)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Tue, 26 May 2015 19:23:04 +0000 (12:23 -0700)
...instead of tanh. Based on paper: He et al, "Delving Deep into
Rectifiers: Surpassing Human-Level Performance on ImageNet
Classification," 2015.

- add VarianceNorm option to FillerParameters which allows one to
  normalize by fan_in, fan_out or their average.
- update XavierFiller to use the VarianceNorm option (default behavior
  unchanged).
- add tests for MSRAFiller and XavierFiller.

include/caffe/filler.hpp
src/caffe/proto/caffe.proto
src/caffe/test/test_filler.cpp

index eebf565..0125b30 100644 (file)
@@ -127,17 +127,18 @@ class PositiveUnitballFiller : public Filler<Dtype> {
 };
 
 /**
- * @brief Fills a Blob with values @f$ x \sim U(-a, +a) @f$ where @f$ a @f$
- *        is set inversely proportional to the number of incoming nodes.
+ * @brief Fills a Blob with values @f$ x \sim U(-a, +a) @f$ where @f$ a @f$ is
+ *        set inversely proportional to number of incoming nodes, outgoing
+ *        nodes, or their average.
  *
  * A Filler based on the paper [Bengio and Glorot 2010]: Understanding
- * the difficulty of training deep feedforward neuralnetworks, but does not
- * use the fan_out value.
+ * the difficulty of training deep feedforward neuralnetworks.
  *
- * It fills the incoming matrix by randomly sampling uniform data from
- * [-scale, scale] where scale = sqrt(3 / fan_in) where fan_in is the number
- * of input nodes. You should make sure the input blob has shape (num, a, b, c)
- * where a * b * c = fan_in.
+ * It fills the incoming matrix by randomly sampling uniform data from [-scale,
+ * scale] where scale = sqrt(3 / n) where n is the fan_in, fan_out, or their
+ * average, depending on the variance_norm option. You should make sure the
+ * input blob has shape (num, a, b, c) where a * b * c = fan_in and num * b * c
+ * = fan_out. Note that this is currently not the case for inner product layers.
  *
  * TODO(dox): make notation in above comment consistent with rest & use LaTeX.
  */
@@ -149,7 +150,16 @@ class XavierFiller : public Filler<Dtype> {
   virtual void Fill(Blob<Dtype>* blob) {
     CHECK(blob->count());
     int fan_in = blob->count() / blob->num();
-    Dtype scale = sqrt(Dtype(3) / fan_in);
+    int fan_out = blob->count() / blob->channels();
+    Dtype n = fan_in;  // default to fan_in
+    if (this->filler_param_.variance_norm() ==
+        FillerParameter_VarianceNorm_AVERAGE) {
+      n = (fan_in + fan_out) / Dtype(2);
+    } else if (this->filler_param_.variance_norm() ==
+        FillerParameter_VarianceNorm_FAN_OUT) {
+      n = fan_out;
+    }
+    Dtype scale = sqrt(Dtype(3) / n);
     caffe_rng_uniform<Dtype>(blob->count(), -scale, scale,
         blob->mutable_cpu_data());
     CHECK_EQ(this->filler_param_.sparse(), -1)
@@ -157,6 +167,44 @@ class XavierFiller : public Filler<Dtype> {
   }
 };
 
+/**
+ * @brief Fills a Blob with values @f$ x \sim N(0, \sigma^2) @f$ where
+ *        @f$ \sigma^2 @f$ is set inversely proportional to number of incoming
+ *        nodes, outgoing nodes, or their average.
+ *
+ * A Filler based on the paper [He, Zhang, Ren and Sun 2015]: Specifically
+ * accounts for ReLU nonlinearities.
+ *
+ * It fills the incoming matrix by randomly sampling Gaussian data with std =
+ * sqrt(2 / n) where n is the fan_in, fan_out, or their average, depending on
+ * the variance_norm option. You should make sure the input blob has shape (num,
+ * a, b, c) where a * b * c = fan_in and num * b * c = fan_out. Note that this
+ * is currently not the case for inner product layers.
+ */
+template <typename Dtype>
+class MSRAFiller : public Filler<Dtype> {
+ public:
+  explicit MSRAFiller(const FillerParameter& param)
+      : Filler<Dtype>(param) {}
+  virtual void Fill(Blob<Dtype>* blob) {
+    CHECK(blob->count());
+    int fan_in = blob->count() / blob->num();
+    int fan_out = blob->count() / blob->channels();
+    Dtype n = fan_in;  // default to fan_in
+    if (this->filler_param_.variance_norm() ==
+        FillerParameter_VarianceNorm_AVERAGE) {
+      n = (fan_in + fan_out) / Dtype(2);
+    } else if (this->filler_param_.variance_norm() ==
+        FillerParameter_VarianceNorm_FAN_OUT) {
+      n = fan_out;
+    }
+    Dtype std = sqrt(Dtype(2) / n);
+    caffe_rng_gaussian<Dtype>(blob->count(), Dtype(0), std,
+        blob->mutable_cpu_data());
+    CHECK_EQ(this->filler_param_.sparse(), -1)
+         << "Sparsity not supported by this Filler.";
+  }
+};
 
 /**
  * @brief Get a specific filler from the specification given in FillerParameter.
@@ -177,6 +225,8 @@ Filler<Dtype>* GetFiller(const FillerParameter& param) {
     return new UniformFiller<Dtype>(param);
   } else if (type == "xavier") {
     return new XavierFiller<Dtype>(param);
+  } else if (type == "msra") {
+    return new MSRAFiller<Dtype>(param);
   } else {
     CHECK(false) << "Unknown filler name: " << param.type();
   }
index 84b475c..f9d4d61 100644 (file)
@@ -41,6 +41,14 @@ message FillerParameter {
   // The expected number of non-zero output weights for a given input in
   // Gaussian filler -- the default -1 means don't perform sparsification.
   optional int32 sparse = 7 [default = -1];
+  // Normalize the filler variance by fan_in, fan_out, or their average.
+  // Applies to 'xavier' and 'msra' fillers.
+  enum VarianceNorm {
+    FAN_IN = 0;
+    FAN_OUT = 1;
+    AVERAGE = 2;
+  }
+  optional VarianceNorm variance_norm = 8 [default = FAN_IN];
 }
 
 message NetParameter {
index e04b0fd..728b8dc 100644 (file)
@@ -142,4 +142,102 @@ TYPED_TEST(GaussianFillerTest, TestFill) {
   EXPECT_LE(var, target_var * 5.);
 }
 
+template <typename Dtype>
+class XavierFillerTest : public ::testing::Test {
+ protected:
+  XavierFillerTest()
+      : blob_(new Blob<Dtype>(1000, 2, 4, 5)),
+        filler_param_() {
+  }
+  virtual void test_params(FillerParameter_VarianceNorm variance_norm,
+      Dtype n) {
+    this->filler_param_.set_variance_norm(variance_norm);
+    this->filler_.reset(new XavierFiller<Dtype>(this->filler_param_));
+    this->filler_->Fill(blob_);
+    EXPECT_TRUE(this->blob_);
+    const int count = this->blob_->count();
+    const Dtype* data = this->blob_->cpu_data();
+    Dtype mean = 0.;
+    Dtype ex2 = 0.;
+    for (int i = 0; i < count; ++i) {
+      mean += data[i];
+      ex2 += data[i] * data[i];
+    }
+    mean /= count;
+    ex2 /= count;
+    Dtype std = sqrt(ex2 - mean*mean);
+    Dtype target_std = sqrt(2.0 / n);
+    EXPECT_NEAR(mean, 0.0, 0.1);
+    EXPECT_NEAR(std, target_std, 0.1);
+  }
+  virtual ~XavierFillerTest() { delete blob_; }
+  Blob<Dtype>* const blob_;
+  FillerParameter filler_param_;
+  shared_ptr<XavierFiller<Dtype> > filler_;
+};
+
+TYPED_TEST_CASE(XavierFillerTest, TestDtypes);
+
+TYPED_TEST(XavierFillerTest, TestFillFanIn) {
+  TypeParam n = 2*4*5;
+  this->test_params(FillerParameter_VarianceNorm_FAN_IN, n);
+}
+TYPED_TEST(XavierFillerTest, TestFillFanOut) {
+  TypeParam n = 1000*4*5;
+  this->test_params(FillerParameter_VarianceNorm_FAN_OUT, n);
+}
+TYPED_TEST(XavierFillerTest, TestFillAverage) {
+  TypeParam n = (2*4*5 + 1000*4*5) / 2.0;
+  this->test_params(FillerParameter_VarianceNorm_AVERAGE, n);
+}
+
+template <typename Dtype>
+class MSRAFillerTest : public ::testing::Test {
+ protected:
+  MSRAFillerTest()
+      : blob_(new Blob<Dtype>(1000, 2, 4, 5)),
+        filler_param_() {
+  }
+  virtual void test_params(FillerParameter_VarianceNorm variance_norm,
+      Dtype n) {
+    this->filler_param_.set_variance_norm(variance_norm);
+    this->filler_.reset(new MSRAFiller<Dtype>(this->filler_param_));
+    this->filler_->Fill(blob_);
+    EXPECT_TRUE(this->blob_);
+    const int count = this->blob_->count();
+    const Dtype* data = this->blob_->cpu_data();
+    Dtype mean = 0.;
+    Dtype ex2 = 0.;
+    for (int i = 0; i < count; ++i) {
+      mean += data[i];
+      ex2 += data[i] * data[i];
+    }
+    mean /= count;
+    ex2 /= count;
+    Dtype std = sqrt(ex2 - mean*mean);
+    Dtype target_std = sqrt(2.0 / n);
+    EXPECT_NEAR(mean, 0.0, 0.1);
+    EXPECT_NEAR(std, target_std, 0.1);
+  }
+  virtual ~MSRAFillerTest() { delete blob_; }
+  Blob<Dtype>* const blob_;
+  FillerParameter filler_param_;
+  shared_ptr<MSRAFiller<Dtype> > filler_;
+};
+
+TYPED_TEST_CASE(MSRAFillerTest, TestDtypes);
+
+TYPED_TEST(MSRAFillerTest, TestFillFanIn) {
+  TypeParam n = 2*4*5;
+  this->test_params(FillerParameter_VarianceNorm_FAN_IN, n);
+}
+TYPED_TEST(MSRAFillerTest, TestFillFanOut) {
+  TypeParam n = 1000*4*5;
+  this->test_params(FillerParameter_VarianceNorm_FAN_OUT, n);
+}
+TYPED_TEST(MSRAFillerTest, TestFillAverage) {
+  TypeParam n = (2*4*5 + 1000*4*5) / 2.0;
+  this->test_params(FillerParameter_VarianceNorm_AVERAGE, n);
+}
+
 }  // namespace caffe