make RNG function names more similar to other caffe math function names
[platform/upstream/caffeonacl.git] / include / caffe / filler.hpp
1 // Copyright 2014 BVLC and contributors.
2
3 // Fillers are random number generators that fills a blob using the specified
4 // algorithm. The expectation is that they are only going to be used during
5 // initialization time and will not involve any GPUs.
6
7 #ifndef CAFFE_FILLER_HPP
8 #define CAFFE_FILLER_HPP
9
10 #include <string>
11
12 #include "caffe/common.hpp"
13 #include "caffe/blob.hpp"
14 #include "caffe/syncedmem.hpp"
15 #include "caffe/util/math_functions.hpp"
16 #include "caffe/proto/caffe.pb.h"
17
18 namespace caffe {
19
20 template <typename Dtype>
21 class Filler {
22  public:
23   explicit Filler(const FillerParameter& param) : filler_param_(param) {}
24   virtual ~Filler() {}
25   virtual void Fill(Blob<Dtype>* blob) = 0;
26  protected:
27   FillerParameter filler_param_;
28 };  // class Filler
29
30
31 template <typename Dtype>
32 class ConstantFiller : public Filler<Dtype> {
33  public:
34   explicit ConstantFiller(const FillerParameter& param)
35       : Filler<Dtype>(param) {}
36   virtual void Fill(Blob<Dtype>* blob) {
37     Dtype* data = blob->mutable_cpu_data();
38     const int count = blob->count();
39     const Dtype value = this->filler_param_.value();
40     CHECK(count);
41     for (int i = 0; i < count; ++i) {
42       data[i] = value;
43     }
44   }
45 };
46
47 template <typename Dtype>
48 class UniformFiller : public Filler<Dtype> {
49  public:
50   explicit UniformFiller(const FillerParameter& param)
51       : Filler<Dtype>(param) {}
52   virtual void Fill(Blob<Dtype>* blob) {
53     CHECK(blob->count());
54     caffe_rng_uniform<Dtype>(blob->count(), blob->mutable_cpu_data(),
55         Dtype(this->filler_param_.min()),
56         Dtype(this->filler_param_.max()));
57   }
58 };
59
60 template <typename Dtype>
61 class GaussianFiller : public Filler<Dtype> {
62  public:
63   explicit GaussianFiller(const FillerParameter& param)
64       : Filler<Dtype>(param) {}
65   virtual void Fill(Blob<Dtype>* blob) {
66     Dtype* data = blob->mutable_cpu_data();
67     CHECK(blob->count());
68     caffe_rng_gaussian<Dtype>(blob->count(), blob->mutable_cpu_data(),
69         Dtype(this->filler_param_.mean()),
70         Dtype(this->filler_param_.std()));
71   }
72 };
73
74 template <typename Dtype>
75 class PositiveUnitballFiller : public Filler<Dtype> {
76  public:
77   explicit PositiveUnitballFiller(const FillerParameter& param)
78       : Filler<Dtype>(param) {}
79   virtual void Fill(Blob<Dtype>* blob) {
80     Dtype* data = blob->mutable_cpu_data();
81     DCHECK(blob->count());
82     caffe_rng_uniform<Dtype>(blob->count(), blob->mutable_cpu_data(), 0, 1);
83     // We expect the filler to not be called very frequently, so we will
84     // just use a simple implementation
85     int dim = blob->count() / blob->num();
86     CHECK(dim);
87     for (int i = 0; i < blob->num(); ++i) {
88       Dtype sum = 0;
89       for (int j = 0; j < dim; ++j) {
90         sum += data[i * dim + j];
91       }
92       for (int j = 0; j < dim; ++j) {
93         data[i * dim + j] /= sum;
94       }
95     }
96   }
97 };
98
99 // A filler based on the paper [Bengio and Glorot 2010]: Understanding
100 // the difficulty of training deep feedforward neuralnetworks, but does not
101 // use the fan_out value.
102 //
103 // It fills the incoming matrix by randomly sampling uniform data from
104 // [-scale, scale] where scale = sqrt(3 / fan_in) where fan_in is the number
105 // of input nodes. You should make sure the input blob has shape (num, a, b, c)
106 // where a * b * c = fan_in.
107 template <typename Dtype>
108 class XavierFiller : public Filler<Dtype> {
109  public:
110   explicit XavierFiller(const FillerParameter& param)
111       : Filler<Dtype>(param) {}
112   virtual void Fill(Blob<Dtype>* blob) {
113     CHECK(blob->count());
114     int fan_in = blob->count() / blob->num();
115     Dtype scale = sqrt(Dtype(3) / fan_in);
116     caffe_rng_uniform<Dtype>(blob->count(), blob->mutable_cpu_data(),
117         -scale, scale);
118   }
119 };
120
121
122 // A function to get a specific filler from the specification given in
123 // FillerParameter. Ideally this would be replaced by a factory pattern,
124 // but we will leave it this way for now.
125 template <typename Dtype>
126 Filler<Dtype>* GetFiller(const FillerParameter& param) {
127   const std::string& type = param.type();
128   if (type == "constant") {
129     return new ConstantFiller<Dtype>(param);
130   } else if (type == "gaussian") {
131     return new GaussianFiller<Dtype>(param);
132   } else if (type == "positive_unitball") {
133     return new PositiveUnitballFiller<Dtype>(param);
134   } else if (type == "uniform") {
135     return new UniformFiller<Dtype>(param);
136   } else if (type == "xavier") {
137     return new XavierFiller<Dtype>(param);
138   } else {
139     CHECK(false) << "Unknown filler name: " << param.type();
140   }
141   return (Filler<Dtype>*)(NULL);
142 }
143
144 }  // namespace caffe
145
146 #endif  // CAFFE_FILLER_HPP_