Merge pull request #286 from longjon/pycaffe-solver
[platform/upstream/caffeonacl.git] / include / caffe / filler.hpp
1 // Copyright 2014 BVLC and contributors.
2
3 // Fillers are random number generators that fills a blob using the specified
4 // algorithm. The expectation is that they are only going to be used during
5 // initialization time and will not involve any GPUs.
6
7 #ifndef CAFFE_FILLER_HPP
8 #define CAFFE_FILLER_HPP
9
10 #include <string>
11
12 #include "caffe/common.hpp"
13 #include "caffe/blob.hpp"
14 #include "caffe/syncedmem.hpp"
15 #include "caffe/util/math_functions.hpp"
16 #include "caffe/proto/caffe.pb.h"
17
18 namespace caffe {
19
20 template <typename Dtype>
21 class Filler {
22  public:
23   explicit Filler(const FillerParameter& param) : filler_param_(param) {}
24   virtual ~Filler() {}
25   virtual void Fill(Blob<Dtype>* blob) = 0;
26  protected:
27   FillerParameter filler_param_;
28 };  // class Filler
29
30
31 template <typename Dtype>
32 class ConstantFiller : public Filler<Dtype> {
33  public:
34   explicit ConstantFiller(const FillerParameter& param)
35       : Filler<Dtype>(param) {}
36   virtual void Fill(Blob<Dtype>* blob) {
37     Dtype* data = blob->mutable_cpu_data();
38     const int count = blob->count();
39     const Dtype value = this->filler_param_.value();
40     CHECK(count);
41     for (int i = 0; i < count; ++i) {
42       data[i] = value;
43     }
44   }
45 };
46
47 template <typename Dtype>
48 class UniformFiller : public Filler<Dtype> {
49  public:
50   explicit UniformFiller(const FillerParameter& param)
51       : Filler<Dtype>(param) {}
52   virtual void Fill(Blob<Dtype>* blob) {
53     CHECK(blob->count());
54     caffe_rng_uniform<Dtype>(blob->count(), Dtype(this->filler_param_.min()),
55         Dtype(this->filler_param_.max()), blob->mutable_cpu_data());
56   }
57 };
58
59 template <typename Dtype>
60 class GaussianFiller : public Filler<Dtype> {
61  public:
62   explicit GaussianFiller(const FillerParameter& param)
63       : Filler<Dtype>(param) {}
64   virtual void Fill(Blob<Dtype>* blob) {
65     Dtype* data = blob->mutable_cpu_data();
66     CHECK(blob->count());
67     caffe_rng_gaussian<Dtype>(blob->count(), Dtype(this->filler_param_.mean()),
68         Dtype(this->filler_param_.std()), blob->mutable_cpu_data());
69   }
70 };
71
72 template <typename Dtype>
73 class PositiveUnitballFiller : public Filler<Dtype> {
74  public:
75   explicit PositiveUnitballFiller(const FillerParameter& param)
76       : Filler<Dtype>(param) {}
77   virtual void Fill(Blob<Dtype>* blob) {
78     Dtype* data = blob->mutable_cpu_data();
79     DCHECK(blob->count());
80     caffe_rng_uniform<Dtype>(blob->count(), 0, 1, blob->mutable_cpu_data());
81     // We expect the filler to not be called very frequently, so we will
82     // just use a simple implementation
83     int dim = blob->count() / blob->num();
84     CHECK(dim);
85     for (int i = 0; i < blob->num(); ++i) {
86       Dtype sum = 0;
87       for (int j = 0; j < dim; ++j) {
88         sum += data[i * dim + j];
89       }
90       for (int j = 0; j < dim; ++j) {
91         data[i * dim + j] /= sum;
92       }
93     }
94   }
95 };
96
97 // A filler based on the paper [Bengio and Glorot 2010]: Understanding
98 // the difficulty of training deep feedforward neuralnetworks, but does not
99 // use the fan_out value.
100 //
101 // It fills the incoming matrix by randomly sampling uniform data from
102 // [-scale, scale] where scale = sqrt(3 / fan_in) where fan_in is the number
103 // of input nodes. You should make sure the input blob has shape (num, a, b, c)
104 // where a * b * c = fan_in.
105 template <typename Dtype>
106 class XavierFiller : public Filler<Dtype> {
107  public:
108   explicit XavierFiller(const FillerParameter& param)
109       : Filler<Dtype>(param) {}
110   virtual void Fill(Blob<Dtype>* blob) {
111     CHECK(blob->count());
112     int fan_in = blob->count() / blob->num();
113     Dtype scale = sqrt(Dtype(3) / fan_in);
114     caffe_rng_uniform<Dtype>(blob->count(), -scale, scale,
115         blob->mutable_cpu_data());
116   }
117 };
118
119
120 // A function to get a specific filler from the specification given in
121 // FillerParameter. Ideally this would be replaced by a factory pattern,
122 // but we will leave it this way for now.
123 template <typename Dtype>
124 Filler<Dtype>* GetFiller(const FillerParameter& param) {
125   const std::string& type = param.type();
126   if (type == "constant") {
127     return new ConstantFiller<Dtype>(param);
128   } else if (type == "gaussian") {
129     return new GaussianFiller<Dtype>(param);
130   } else if (type == "positive_unitball") {
131     return new PositiveUnitballFiller<Dtype>(param);
132   } else if (type == "uniform") {
133     return new UniformFiller<Dtype>(param);
134   } else if (type == "xavier") {
135     return new XavierFiller<Dtype>(param);
136   } else {
137     CHECK(false) << "Unknown filler name: " << param.type();
138   }
139   return (Filler<Dtype>*)(NULL);
140 }
141
142 }  // namespace caffe
143
144 #endif  // CAFFE_FILLER_HPP_