Merge pull request #231 from BVLC/next
[platform/upstream/caffeonacl.git] / include / caffe / filler.hpp
1 // Copyright 2013 Yangqing Jia
2
3 // Fillers are random number generators that fills a blob using the specified
4 // algorithm. The expectation is that they are only going to be used during
5 // initialization time and will not involve any GPUs.
6
7 #ifndef CAFFE_FILLER_HPP
8 #define CAFFE_FILLER_HPP
9
10 #include <mkl.h>
11 #include <string>
12
13 #include "caffe/common.hpp"
14 #include "caffe/blob.hpp"
15 #include "caffe/syncedmem.hpp"
16 #include "caffe/util/math_functions.hpp"
17 #include "caffe/proto/caffe.pb.h"
18
19 namespace caffe {
20
21 template <typename Dtype>
22 class Filler {
23  public:
24   explicit Filler(const FillerParameter& param) : filler_param_(param) {}
25   virtual ~Filler() {}
26   virtual void Fill(Blob<Dtype>* blob) = 0;
27  protected:
28   FillerParameter filler_param_;
29 };  // class Filler
30
31
32 template <typename Dtype>
33 class ConstantFiller : public Filler<Dtype> {
34  public:
35   explicit ConstantFiller(const FillerParameter& param)
36       : Filler<Dtype>(param) {}
37   virtual void Fill(Blob<Dtype>* blob) {
38     Dtype* data = blob->mutable_cpu_data();
39     const int count = blob->count();
40     const Dtype value = this->filler_param_.value();
41     CHECK(count);
42     for (int i = 0; i < count; ++i) {
43       data[i] = value;
44     }
45   }
46 };
47
48 template <typename Dtype>
49 class UniformFiller : public Filler<Dtype> {
50  public:
51   explicit UniformFiller(const FillerParameter& param)
52       : Filler<Dtype>(param) {}
53   virtual void Fill(Blob<Dtype>* blob) {
54     CHECK(blob->count());
55     caffe_vRngUniform<Dtype>(blob->count(), blob->mutable_cpu_data(),
56         Dtype(this->filler_param_.min()),
57         Dtype(this->filler_param_.max()));
58   }
59 };
60
61 template <typename Dtype>
62 class GaussianFiller : public Filler<Dtype> {
63  public:
64   explicit GaussianFiller(const FillerParameter& param)
65       : Filler<Dtype>(param) {}
66   virtual void Fill(Blob<Dtype>* blob) {
67     Dtype* data = blob->mutable_cpu_data();
68     CHECK(blob->count());
69     caffe_vRngGaussian<Dtype>(blob->count(), blob->mutable_cpu_data(),
70         Dtype(this->filler_param_.mean()),
71         Dtype(this->filler_param_.std()));
72   }
73 };
74
75 template <typename Dtype>
76 class PositiveUnitballFiller : public Filler<Dtype> {
77  public:
78   explicit PositiveUnitballFiller(const FillerParameter& param)
79       : Filler<Dtype>(param) {}
80   virtual void Fill(Blob<Dtype>* blob) {
81     Dtype* data = blob->mutable_cpu_data();
82     DCHECK(blob->count());
83     caffe_vRngUniform<Dtype>(blob->count(), blob->mutable_cpu_data(), 0, 1);
84     // We expect the filler to not be called very frequently, so we will
85     // just use a simple implementation
86     int dim = blob->count() / blob->num();
87     CHECK(dim);
88     for (int i = 0; i < blob->num(); ++i) {
89       Dtype sum = 0;
90       for (int j = 0; j < dim; ++j) {
91         sum += data[i * dim + j];
92       }
93       for (int j = 0; j < dim; ++j) {
94         data[i * dim + j] /= sum;
95       }
96     }
97   }
98 };
99
100 // A filler based on the paper [Bengio and Glorot 2010]: Understanding
101 // the difficulty of training deep feedforward neuralnetworks, but does not
102 // use the fan_out value.
103 //
104 // It fills the incoming matrix by randomly sampling uniform data from
105 // [-scale, scale] where scale = sqrt(3 / fan_in) where fan_in is the number
106 // of input nodes. You should make sure the input blob has shape (num, a, b, c)
107 // where a * b * c = fan_in.
108 template <typename Dtype>
109 class XavierFiller : public Filler<Dtype> {
110  public:
111   explicit XavierFiller(const FillerParameter& param)
112       : Filler<Dtype>(param) {}
113   virtual void Fill(Blob<Dtype>* blob) {
114     CHECK(blob->count());
115     int fan_in = blob->count() / blob->num();
116     Dtype scale = sqrt(Dtype(3) / fan_in);
117     caffe_vRngUniform<Dtype>(blob->count(), blob->mutable_cpu_data(),
118         -scale, scale);
119   }
120 };
121
122
123 // A function to get a specific filler from the specification given in
124 // FillerParameter. Ideally this would be replaced by a factory pattern,
125 // but we will leave it this way for now.
126 template <typename Dtype>
127 Filler<Dtype>* GetFiller(const FillerParameter& param) {
128   const std::string& type = param.type();
129   if (type == "constant") {
130     return new ConstantFiller<Dtype>(param);
131   } else if (type == "gaussian") {
132     return new GaussianFiller<Dtype>(param);
133   } else if (type == "positive_unitball") {
134     return new PositiveUnitballFiller<Dtype>(param);
135   } else if (type == "uniform") {
136     return new UniformFiller<Dtype>(param);
137   } else if (type == "xavier") {
138     return new XavierFiller<Dtype>(param);
139   } else {
140     CHECK(false) << "Unknown filler name: " << param.type();
141   }
142   return (Filler<Dtype>*)(NULL);
143 }
144
145 }  // namespace caffe
146
147 #endif  // CAFFE_FILLER_HPP_