xavier filler
authorYangqing Jia <jiayq84@gmail.com>
Fri, 27 Sep 2013 17:25:23 +0000 (10:25 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Fri, 27 Sep 2013 17:25:23 +0000 (10:25 -0700)
src/caffe/filler.hpp
src/caffe/test/lenet.hpp
src/caffe/test/test_net_proto.cpp

index f455460..ffe7a50 100644 (file)
@@ -51,7 +51,7 @@ class UniformFiller : public Filler<Dtype> {
   explicit UniformFiller(const FillerParameter& param)
       : Filler<Dtype>(param) {}
   virtual void Fill(Blob<Dtype>* blob) {
-    DCHECK(blob->count());
+    CHECK(blob->count());
     caffe_vRngUniform<Dtype>(blob->count(), blob->mutable_cpu_data(),
         Dtype(this->filler_param_.min()),
         Dtype(this->filler_param_.max()));
@@ -65,7 +65,7 @@ class GaussianFiller : public Filler<Dtype> {
       : Filler<Dtype>(param) {}
   virtual void Fill(Blob<Dtype>* blob) {
     Dtype* data = blob->mutable_cpu_data();
-    DCHECK(blob->count());
+    CHECK(blob->count());
     caffe_vRngGaussian<Dtype>(blob->count(), blob->mutable_cpu_data(),
         Dtype(this->filler_param_.mean()),
         Dtype(this->filler_param_.std()));
@@ -84,7 +84,7 @@ class PositiveUnitballFiller : public Filler<Dtype> {
     // We expect the filler to not be called very frequently, so we will
     // just use a simple implementation
     int dim = blob->count() / blob->num();
-    DCHECK(dim);
+    CHECK(dim);
     for (int i = 0; i < blob->num(); ++i) {
       Dtype sum = 0;
       for (int j = 0; j < dim; ++j) {
@@ -97,13 +97,28 @@ class PositiveUnitballFiller : public Filler<Dtype> {
   }
 };
 
+// A filler based on the paper [Bengio and Glorot 2010]: Understanding
+// the difficulty of training deep feedforward neuralnetworks, but does not
+// use the fan_out value.
+//
+// It fills the incoming matrix by randomly sampling uniform data from
+// [-scale, scale] where scale = sqrt(3 / fan_in) where fan_in is the number
+// of input nodes, and in our case we consider the blob width as the scale.
+// You should make sure the input blob has shape (1, 1, height, width).
 template <typename Dtype>
 class XavierFiller : public Filler<Dtype> {
  public:
   explicit XavierFiller(const FillerParameter& param)
       : Filler<Dtype>(param) {}
   virtual void Fill(Blob<Dtype>* blob) {
-    
+    CHECK(blob->count());
+    CHECK_EQ(blob->num(), 1) << "XavierFiller requires blob.num() = 1.";
+    CHECK_EQ(blob->channels(), 1)
+        << "XavierFiller requires blob.channels() = 1.";
+    int fan_in = blob->width();
+    Dtype scale = sqrt(Dtype(3) / fan_in);
+    caffe_vRngUniform<Dtype>(blob->count(), blob->mutable_cpu_data(),
+        -scale, scale);    
   }
 };
 
@@ -116,12 +131,14 @@ Filler<Dtype>* GetFiller(const FillerParameter& param) {
   const std::string& type = param.type();
   if (type == "constant") {
     return new ConstantFiller<Dtype>(param);
-  } else if (type == "uniform") {
-    return new UniformFiller<Dtype>(param);
   } else if (type == "gaussian") {
     return new GaussianFiller<Dtype>(param);
   } else if (type == "positive_unitball") {
     return new PositiveUnitballFiller<Dtype>(param);
+  } else if (type == "uniform") {
+    return new UniformFiller<Dtype>(param);
+  } else if (type == "xavier") {
+    return new XavierFiller<Dtype>(param);
   } else {
     CHECK(false) << "Unknown filler name: " << param.type();
   }
index 29ec3c4..266f0b2 100644 (file)
@@ -15,6 +15,12 @@ layers {\n\
     num_output: 20\n\
     kernelsize: 5\n\
     stride: 1\n\
+    weight_filler {\n\
+      type: \"xavier\"\n\
+    }\n\
+    bias_filler {\n\
+      type: \"constant\"\n\
+    }\n\
   }\n\
   bottom: \"data\"\n\
   top: \"conv1\"\n\
@@ -37,6 +43,12 @@ layers {\n\
     num_output: 50\n\
     kernelsize: 5\n\
     stride: 1\n\
+    weight_filler {\n\
+      type: \"xavier\"\n\
+    }\n\
+    bias_filler {\n\
+      type: \"constant\"\n\
+    }\n\
   }\n\
   bottom: \"pool1\"\n\
   top: \"conv2\"\n\
@@ -57,6 +69,12 @@ layers {\n\
     name: \"ip1\"\n\
     type: \"innerproduct\"\n\
     num_output: 500\n\
+    weight_filler {\n\
+      type: \"xavier\"\n\
+    }\n\
+    bias_filler {\n\
+      type: \"constant\"\n\
+    }\n\
   }\n\
   bottom: \"pool2\"\n\
   top: \"ip1\"\n\
@@ -74,6 +92,12 @@ layers {\n\
     name: \"ip2\"\n\
     type: \"innerproduct\"\n\
     num_output: 10\n\
+    weight_filler {\n\
+      type: \"xavier\"\n\
+    }\n\
+    bias_filler {\n\
+      type: \"constant\"\n\
+    }\n\
   }\n\
   bottom: \"relu1\"\n\
   top: \"ip2\"\n\
index f0b0e7d..013bd67 100644 (file)
@@ -40,7 +40,7 @@ TYPED_TEST(NetProtoTest, TestSetup) {
   shared_ptr<Filler<TypeParam> > filler;
   filler.reset(new ConstantFiller<TypeParam>(filler_param));
   filler->Fill(label.get());
-  filler.reset(new GaussianFiller<TypeParam>(filler_param));
+  filler.reset(new UniformFiller<TypeParam>(filler_param));
   filler->Fill(data.get());
 
   vector<Blob<TypeParam>*> bottom_vec;