Merge pull request #1987 from tnarihi/fix-siam-example
[platform/upstream/caffeonacl.git] / include / caffe / filler.hpp
1 // Fillers are random number generators that fills a blob using the specified
2 // algorithm. The expectation is that they are only going to be used during
3 // initialization time and will not involve any GPUs.
4
5 #ifndef CAFFE_FILLER_HPP
6 #define CAFFE_FILLER_HPP
7
8 #include <string>
9
10 #include "caffe/blob.hpp"
11 #include "caffe/common.hpp"
12 #include "caffe/proto/caffe.pb.h"
13 #include "caffe/syncedmem.hpp"
14 #include "caffe/util/math_functions.hpp"
15
16 namespace caffe {
17
18 /// @brief Fills a Blob with constant or randomly-generated data.
19 template <typename Dtype>
20 class Filler {
21  public:
22   explicit Filler(const FillerParameter& param) : filler_param_(param) {}
23   virtual ~Filler() {}
24   virtual void Fill(Blob<Dtype>* blob) = 0;
25  protected:
26   FillerParameter filler_param_;
27 };  // class Filler
28
29
30 /// @brief Fills a Blob with constant values @f$ x = 0 @f$.
31 template <typename Dtype>
32 class ConstantFiller : public Filler<Dtype> {
33  public:
34   explicit ConstantFiller(const FillerParameter& param)
35       : Filler<Dtype>(param) {}
36   virtual void Fill(Blob<Dtype>* blob) {
37     Dtype* data = blob->mutable_cpu_data();
38     const int count = blob->count();
39     const Dtype value = this->filler_param_.value();
40     CHECK(count);
41     for (int i = 0; i < count; ++i) {
42       data[i] = value;
43     }
44     CHECK_EQ(this->filler_param_.sparse(), -1)
45          << "Sparsity not supported by this Filler.";
46   }
47 };
48
49 /// @brief Fills a Blob with uniformly distributed values @f$ x\sim U(a, b) @f$.
50 template <typename Dtype>
51 class UniformFiller : public Filler<Dtype> {
52  public:
53   explicit UniformFiller(const FillerParameter& param)
54       : Filler<Dtype>(param) {}
55   virtual void Fill(Blob<Dtype>* blob) {
56     CHECK(blob->count());
57     caffe_rng_uniform<Dtype>(blob->count(), Dtype(this->filler_param_.min()),
58         Dtype(this->filler_param_.max()), blob->mutable_cpu_data());
59     CHECK_EQ(this->filler_param_.sparse(), -1)
60          << "Sparsity not supported by this Filler.";
61   }
62 };
63
64 /// @brief Fills a Blob with Gaussian-distributed values @f$ x = a @f$.
65 template <typename Dtype>
66 class GaussianFiller : public Filler<Dtype> {
67  public:
68   explicit GaussianFiller(const FillerParameter& param)
69       : Filler<Dtype>(param) {}
70   virtual void Fill(Blob<Dtype>* blob) {
71     Dtype* data = blob->mutable_cpu_data();
72     CHECK(blob->count());
73     caffe_rng_gaussian<Dtype>(blob->count(), Dtype(this->filler_param_.mean()),
74         Dtype(this->filler_param_.std()), blob->mutable_cpu_data());
75     int sparse = this->filler_param_.sparse();
76     CHECK_GE(sparse, -1);
77     if (sparse >= 0) {
78       // Sparse initialization is implemented for "weight" blobs; i.e. matrices.
79       // These have num == channels == 1; width is number of inputs; height is
80       // number of outputs.  The 'sparse' variable specifies the mean number
81       // of non-zero input weights for a given output.
82       CHECK_GE(blob->num_axes(), 1);
83       const int num_outputs = blob->shape(0);
84       Dtype non_zero_probability = Dtype(sparse) / Dtype(num_outputs);
85       rand_vec_.reset(new SyncedMemory(blob->count() * sizeof(int)));
86       int* mask = reinterpret_cast<int*>(rand_vec_->mutable_cpu_data());
87       caffe_rng_bernoulli(blob->count(), non_zero_probability, mask);
88       for (int i = 0; i < blob->count(); ++i) {
89         data[i] *= mask[i];
90       }
91     }
92   }
93
94  protected:
95   shared_ptr<SyncedMemory> rand_vec_;
96 };
97
98 /** @brief Fills a Blob with values @f$ x \in [0, 1] @f$
99  *         such that @f$ \forall i \sum_j x_{ij} = 1 @f$.
100  */
101 template <typename Dtype>
102 class PositiveUnitballFiller : public Filler<Dtype> {
103  public:
104   explicit PositiveUnitballFiller(const FillerParameter& param)
105       : Filler<Dtype>(param) {}
106   virtual void Fill(Blob<Dtype>* blob) {
107     Dtype* data = blob->mutable_cpu_data();
108     DCHECK(blob->count());
109     caffe_rng_uniform<Dtype>(blob->count(), 0, 1, blob->mutable_cpu_data());
110     // We expect the filler to not be called very frequently, so we will
111     // just use a simple implementation
112     int dim = blob->count() / blob->num();
113     CHECK(dim);
114     for (int i = 0; i < blob->num(); ++i) {
115       Dtype sum = 0;
116       for (int j = 0; j < dim; ++j) {
117         sum += data[i * dim + j];
118       }
119       for (int j = 0; j < dim; ++j) {
120         data[i * dim + j] /= sum;
121       }
122     }
123     CHECK_EQ(this->filler_param_.sparse(), -1)
124          << "Sparsity not supported by this Filler.";
125   }
126 };
127
128 /**
129  * @brief Fills a Blob with values @f$ x \sim U(-a, +a) @f$ where @f$ a @f$
130  *        is set inversely proportional to the number of incoming nodes.
131  *
132  * A Filler based on the paper [Bengio and Glorot 2010]: Understanding
133  * the difficulty of training deep feedforward neuralnetworks, but does not
134  * use the fan_out value.
135  *
136  * It fills the incoming matrix by randomly sampling uniform data from
137  * [-scale, scale] where scale = sqrt(3 / fan_in) where fan_in is the number
138  * of input nodes. You should make sure the input blob has shape (num, a, b, c)
139  * where a * b * c = fan_in.
140  *
141  * TODO(dox): make notation in above comment consistent with rest & use LaTeX.
142  */
143 template <typename Dtype>
144 class XavierFiller : public Filler<Dtype> {
145  public:
146   explicit XavierFiller(const FillerParameter& param)
147       : Filler<Dtype>(param) {}
148   virtual void Fill(Blob<Dtype>* blob) {
149     CHECK(blob->count());
150     int fan_in = blob->count() / blob->num();
151     Dtype scale = sqrt(Dtype(3) / fan_in);
152     caffe_rng_uniform<Dtype>(blob->count(), -scale, scale,
153         blob->mutable_cpu_data());
154     CHECK_EQ(this->filler_param_.sparse(), -1)
155          << "Sparsity not supported by this Filler.";
156   }
157 };
158
159
160 /**
161  * @brief Get a specific filler from the specification given in FillerParameter.
162  *
163  * Ideally this would be replaced by a factory pattern, but we will leave it
164  * this way for now.
165  */
166 template <typename Dtype>
167 Filler<Dtype>* GetFiller(const FillerParameter& param) {
168   const std::string& type = param.type();
169   if (type == "constant") {
170     return new ConstantFiller<Dtype>(param);
171   } else if (type == "gaussian") {
172     return new GaussianFiller<Dtype>(param);
173   } else if (type == "positive_unitball") {
174     return new PositiveUnitballFiller<Dtype>(param);
175   } else if (type == "uniform") {
176     return new UniformFiller<Dtype>(param);
177   } else if (type == "xavier") {
178     return new XavierFiller<Dtype>(param);
179   } else {
180     CHECK(false) << "Unknown filler name: " << param.type();
181   }
182   return (Filler<Dtype>*)(NULL);
183 }
184
185 }  // namespace caffe
186
187 #endif  // CAFFE_FILLER_HPP_