Merge pull request #429 from shelhamer/next
[platform/upstream/caffeonacl.git] / include / caffe / filler.hpp
1 // Copyright 2014 BVLC and contributors.
2
3 // Fillers are random number generators that fills a blob using the specified
4 // algorithm. The expectation is that they are only going to be used during
5 // initialization time and will not involve any GPUs.
6
7 #ifndef CAFFE_FILLER_HPP
8 #define CAFFE_FILLER_HPP
9
10 #include <string>
11
12 #include "caffe/common.hpp"
13 #include "caffe/blob.hpp"
14 #include "caffe/syncedmem.hpp"
15 #include "caffe/util/math_functions.hpp"
16 #include "caffe/proto/caffe.pb.h"
17
18 namespace caffe {
19
20 template <typename Dtype>
21 class Filler {
22  public:
23   explicit Filler(const FillerParameter& param) : filler_param_(param) {}
24   virtual ~Filler() {}
25   virtual void Fill(Blob<Dtype>* blob) = 0;
26  protected:
27   FillerParameter filler_param_;
28 };  // class Filler
29
30
31 template <typename Dtype>
32 class ConstantFiller : public Filler<Dtype> {
33  public:
34   explicit ConstantFiller(const FillerParameter& param)
35       : Filler<Dtype>(param) {}
36   virtual void Fill(Blob<Dtype>* blob) {
37     Dtype* data = blob->mutable_cpu_data();
38     const int count = blob->count();
39     const Dtype value = this->filler_param_.value();
40     CHECK(count);
41     for (int i = 0; i < count; ++i) {
42       data[i] = value;
43     }
44     CHECK_EQ(this->filler_param_.sparse(), -1)
45          << "Sparsity not supported by this Filler.";
46   }
47 };
48
49 template <typename Dtype>
50 class UniformFiller : public Filler<Dtype> {
51  public:
52   explicit UniformFiller(const FillerParameter& param)
53       : Filler<Dtype>(param) {}
54   virtual void Fill(Blob<Dtype>* blob) {
55     CHECK(blob->count());
56     caffe_rng_uniform<Dtype>(blob->count(), Dtype(this->filler_param_.min()),
57         Dtype(this->filler_param_.max()), blob->mutable_cpu_data());
58     CHECK_EQ(this->filler_param_.sparse(), -1)
59          << "Sparsity not supported by this Filler.";
60   }
61 };
62
63 template <typename Dtype>
64 class GaussianFiller : public Filler<Dtype> {
65  public:
66   explicit GaussianFiller(const FillerParameter& param)
67       : Filler<Dtype>(param) {}
68   virtual void Fill(Blob<Dtype>* blob) {
69     Dtype* data = blob->mutable_cpu_data();
70     CHECK(blob->count());
71     caffe_rng_gaussian<Dtype>(blob->count(), Dtype(this->filler_param_.mean()),
72         Dtype(this->filler_param_.std()), blob->mutable_cpu_data());
73     int sparse = this->filler_param_.sparse();
74     CHECK_GE(sparse, -1);
75     if (sparse >= 0) {
76       // Sparse initialization is implemented for "weight" blobs; i.e. matrices.
77       // These have num == channels == 1; height is number of inputs; width is
78       // number of outputs.  The 'sparse' variable specifies the mean number
79       // of non-zero input weights for a given output.
80       CHECK_EQ(blob->num(), 1);
81       CHECK_EQ(blob->channels(), 1);
82       int num_inputs = blob->height();
83       Dtype non_zero_probability = Dtype(sparse) / Dtype(num_inputs);
84       rand_vec_.reset(new SyncedMemory(blob->count() * sizeof(int)));
85       int* mask = reinterpret_cast<int*>(rand_vec_->mutable_cpu_data());
86       caffe_rng_bernoulli(blob->count(), non_zero_probability, mask);
87       for (int i = 0; i < blob->count(); ++i) {
88         data[i] *= mask[i];
89       }
90     }
91   }
92
93  protected:
94   shared_ptr<SyncedMemory> rand_vec_;
95 };
96
97 template <typename Dtype>
98 class PositiveUnitballFiller : public Filler<Dtype> {
99  public:
100   explicit PositiveUnitballFiller(const FillerParameter& param)
101       : Filler<Dtype>(param) {}
102   virtual void Fill(Blob<Dtype>* blob) {
103     Dtype* data = blob->mutable_cpu_data();
104     DCHECK(blob->count());
105     caffe_rng_uniform<Dtype>(blob->count(), 0, 1, blob->mutable_cpu_data());
106     // We expect the filler to not be called very frequently, so we will
107     // just use a simple implementation
108     int dim = blob->count() / blob->num();
109     CHECK(dim);
110     for (int i = 0; i < blob->num(); ++i) {
111       Dtype sum = 0;
112       for (int j = 0; j < dim; ++j) {
113         sum += data[i * dim + j];
114       }
115       for (int j = 0; j < dim; ++j) {
116         data[i * dim + j] /= sum;
117       }
118     }
119     CHECK_EQ(this->filler_param_.sparse(), -1)
120          << "Sparsity not supported by this Filler.";
121   }
122 };
123
124 // A filler based on the paper [Bengio and Glorot 2010]: Understanding
125 // the difficulty of training deep feedforward neuralnetworks, but does not
126 // use the fan_out value.
127 //
128 // It fills the incoming matrix by randomly sampling uniform data from
129 // [-scale, scale] where scale = sqrt(3 / fan_in) where fan_in is the number
130 // of input nodes. You should make sure the input blob has shape (num, a, b, c)
131 // where a * b * c = fan_in.
132 template <typename Dtype>
133 class XavierFiller : public Filler<Dtype> {
134  public:
135   explicit XavierFiller(const FillerParameter& param)
136       : Filler<Dtype>(param) {}
137   virtual void Fill(Blob<Dtype>* blob) {
138     CHECK(blob->count());
139     int fan_in = blob->count() / blob->num();
140     Dtype scale = sqrt(Dtype(3) / fan_in);
141     caffe_rng_uniform<Dtype>(blob->count(), -scale, scale,
142         blob->mutable_cpu_data());
143     CHECK_EQ(this->filler_param_.sparse(), -1)
144          << "Sparsity not supported by this Filler.";
145   }
146 };
147
148
149 // A function to get a specific filler from the specification given in
150 // FillerParameter. Ideally this would be replaced by a factory pattern,
151 // but we will leave it this way for now.
152 template <typename Dtype>
153 Filler<Dtype>* GetFiller(const FillerParameter& param) {
154   const std::string& type = param.type();
155   if (type == "constant") {
156     return new ConstantFiller<Dtype>(param);
157   } else if (type == "gaussian") {
158     return new GaussianFiller<Dtype>(param);
159   } else if (type == "positive_unitball") {
160     return new PositiveUnitballFiller<Dtype>(param);
161   } else if (type == "uniform") {
162     return new UniformFiller<Dtype>(param);
163   } else if (type == "xavier") {
164     return new XavierFiller<Dtype>(param);
165   } else {
166     CHECK(false) << "Unknown filler name: " << param.type();
167   }
168   return (Filler<Dtype>*)(NULL);
169 }
170
171 }  // namespace caffe
172
173 #endif  // CAFFE_FILLER_HPP_