LICENSE governs the whole project so strip file headers
[platform/upstream/caffeonacl.git] / include / caffe / filler.hpp
1 // Fillers are random number generators that fills a blob using the specified
2 // algorithm. The expectation is that they are only going to be used during
3 // initialization time and will not involve any GPUs.
4
5 #ifndef CAFFE_FILLER_HPP
6 #define CAFFE_FILLER_HPP
7
8 #include <string>
9
10 #include "caffe/blob.hpp"
11 #include "caffe/common.hpp"
12 #include "caffe/proto/caffe.pb.h"
13 #include "caffe/syncedmem.hpp"
14 #include "caffe/util/math_functions.hpp"
15
16 namespace caffe {
17
18 template <typename Dtype>
19 class Filler {
20  public:
21   explicit Filler(const FillerParameter& param) : filler_param_(param) {}
22   virtual ~Filler() {}
23   virtual void Fill(Blob<Dtype>* blob) = 0;
24  protected:
25   FillerParameter filler_param_;
26 };  // class Filler
27
28
29 template <typename Dtype>
30 class ConstantFiller : public Filler<Dtype> {
31  public:
32   explicit ConstantFiller(const FillerParameter& param)
33       : Filler<Dtype>(param) {}
34   virtual void Fill(Blob<Dtype>* blob) {
35     Dtype* data = blob->mutable_cpu_data();
36     const int count = blob->count();
37     const Dtype value = this->filler_param_.value();
38     CHECK(count);
39     for (int i = 0; i < count; ++i) {
40       data[i] = value;
41     }
42     CHECK_EQ(this->filler_param_.sparse(), -1)
43          << "Sparsity not supported by this Filler.";
44   }
45 };
46
47 template <typename Dtype>
48 class UniformFiller : public Filler<Dtype> {
49  public:
50   explicit UniformFiller(const FillerParameter& param)
51       : Filler<Dtype>(param) {}
52   virtual void Fill(Blob<Dtype>* blob) {
53     CHECK(blob->count());
54     caffe_rng_uniform<Dtype>(blob->count(), Dtype(this->filler_param_.min()),
55         Dtype(this->filler_param_.max()), blob->mutable_cpu_data());
56     CHECK_EQ(this->filler_param_.sparse(), -1)
57          << "Sparsity not supported by this Filler.";
58   }
59 };
60
61 template <typename Dtype>
62 class GaussianFiller : public Filler<Dtype> {
63  public:
64   explicit GaussianFiller(const FillerParameter& param)
65       : Filler<Dtype>(param) {}
66   virtual void Fill(Blob<Dtype>* blob) {
67     Dtype* data = blob->mutable_cpu_data();
68     CHECK(blob->count());
69     caffe_rng_gaussian<Dtype>(blob->count(), Dtype(this->filler_param_.mean()),
70         Dtype(this->filler_param_.std()), blob->mutable_cpu_data());
71     int sparse = this->filler_param_.sparse();
72     CHECK_GE(sparse, -1);
73     if (sparse >= 0) {
74       // Sparse initialization is implemented for "weight" blobs; i.e. matrices.
75       // These have num == channels == 1; height is number of inputs; width is
76       // number of outputs.  The 'sparse' variable specifies the mean number
77       // of non-zero input weights for a given output.
78       CHECK_EQ(blob->num(), 1);
79       CHECK_EQ(blob->channels(), 1);
80       int num_inputs = blob->height();
81       Dtype non_zero_probability = Dtype(sparse) / Dtype(num_inputs);
82       rand_vec_.reset(new SyncedMemory(blob->count() * sizeof(int)));
83       int* mask = reinterpret_cast<int*>(rand_vec_->mutable_cpu_data());
84       caffe_rng_bernoulli(blob->count(), non_zero_probability, mask);
85       for (int i = 0; i < blob->count(); ++i) {
86         data[i] *= mask[i];
87       }
88     }
89   }
90
91  protected:
92   shared_ptr<SyncedMemory> rand_vec_;
93 };
94
95 template <typename Dtype>
96 class PositiveUnitballFiller : public Filler<Dtype> {
97  public:
98   explicit PositiveUnitballFiller(const FillerParameter& param)
99       : Filler<Dtype>(param) {}
100   virtual void Fill(Blob<Dtype>* blob) {
101     Dtype* data = blob->mutable_cpu_data();
102     DCHECK(blob->count());
103     caffe_rng_uniform<Dtype>(blob->count(), 0, 1, blob->mutable_cpu_data());
104     // We expect the filler to not be called very frequently, so we will
105     // just use a simple implementation
106     int dim = blob->count() / blob->num();
107     CHECK(dim);
108     for (int i = 0; i < blob->num(); ++i) {
109       Dtype sum = 0;
110       for (int j = 0; j < dim; ++j) {
111         sum += data[i * dim + j];
112       }
113       for (int j = 0; j < dim; ++j) {
114         data[i * dim + j] /= sum;
115       }
116     }
117     CHECK_EQ(this->filler_param_.sparse(), -1)
118          << "Sparsity not supported by this Filler.";
119   }
120 };
121
122 // A filler based on the paper [Bengio and Glorot 2010]: Understanding
123 // the difficulty of training deep feedforward neuralnetworks, but does not
124 // use the fan_out value.
125 //
126 // It fills the incoming matrix by randomly sampling uniform data from
127 // [-scale, scale] where scale = sqrt(3 / fan_in) where fan_in is the number
128 // of input nodes. You should make sure the input blob has shape (num, a, b, c)
129 // where a * b * c = fan_in.
130 template <typename Dtype>
131 class XavierFiller : public Filler<Dtype> {
132  public:
133   explicit XavierFiller(const FillerParameter& param)
134       : Filler<Dtype>(param) {}
135   virtual void Fill(Blob<Dtype>* blob) {
136     CHECK(blob->count());
137     int fan_in = blob->count() / blob->num();
138     Dtype scale = sqrt(Dtype(3) / fan_in);
139     caffe_rng_uniform<Dtype>(blob->count(), -scale, scale,
140         blob->mutable_cpu_data());
141     CHECK_EQ(this->filler_param_.sparse(), -1)
142          << "Sparsity not supported by this Filler.";
143   }
144 };
145
146
147 // A function to get a specific filler from the specification given in
148 // FillerParameter. Ideally this would be replaced by a factory pattern,
149 // but we will leave it this way for now.
150 template <typename Dtype>
151 Filler<Dtype>* GetFiller(const FillerParameter& param) {
152   const std::string& type = param.type();
153   if (type == "constant") {
154     return new ConstantFiller<Dtype>(param);
155   } else if (type == "gaussian") {
156     return new GaussianFiller<Dtype>(param);
157   } else if (type == "positive_unitball") {
158     return new PositiveUnitballFiller<Dtype>(param);
159   } else if (type == "uniform") {
160     return new UniformFiller<Dtype>(param);
161   } else if (type == "xavier") {
162     return new XavierFiller<Dtype>(param);
163   } else {
164     CHECK(false) << "Unknown filler name: " << param.type();
165   }
166   return (Filler<Dtype>*)(NULL);
167 }
168
169 }  // namespace caffe
170
171 #endif  // CAFFE_FILLER_HPP_