// Fillers are random number generators that fill a blob using the specified
// algorithm. They are expected to be used only at initialization time and do
// not involve any GPUs.
5 #ifndef CAFFE_FILLER_HPP
6 #define CAFFE_FILLER_HPP
10 #include "caffe/blob.hpp"
11 #include "caffe/common.hpp"
12 #include "caffe/proto/caffe.pb.h"
13 #include "caffe/syncedmem.hpp"
14 #include "caffe/util/math_functions.hpp"
18 template <typename Dtype>
21 explicit Filler(const FillerParameter& param) : filler_param_(param) {}
23 virtual void Fill(Blob<Dtype>* blob) = 0;
25 FillerParameter filler_param_;
29 template <typename Dtype>
30 class ConstantFiller : public Filler<Dtype> {
32 explicit ConstantFiller(const FillerParameter& param)
33 : Filler<Dtype>(param) {}
34 virtual void Fill(Blob<Dtype>* blob) {
35 Dtype* data = blob->mutable_cpu_data();
36 const int count = blob->count();
37 const Dtype value = this->filler_param_.value();
39 for (int i = 0; i < count; ++i) {
42 CHECK_EQ(this->filler_param_.sparse(), -1)
43 << "Sparsity not supported by this Filler.";
47 template <typename Dtype>
48 class UniformFiller : public Filler<Dtype> {
50 explicit UniformFiller(const FillerParameter& param)
51 : Filler<Dtype>(param) {}
52 virtual void Fill(Blob<Dtype>* blob) {
54 caffe_rng_uniform<Dtype>(blob->count(), Dtype(this->filler_param_.min()),
55 Dtype(this->filler_param_.max()), blob->mutable_cpu_data());
56 CHECK_EQ(this->filler_param_.sparse(), -1)
57 << "Sparsity not supported by this Filler.";
61 template <typename Dtype>
62 class GaussianFiller : public Filler<Dtype> {
64 explicit GaussianFiller(const FillerParameter& param)
65 : Filler<Dtype>(param) {}
66 virtual void Fill(Blob<Dtype>* blob) {
67 Dtype* data = blob->mutable_cpu_data();
69 caffe_rng_gaussian<Dtype>(blob->count(), Dtype(this->filler_param_.mean()),
70 Dtype(this->filler_param_.std()), blob->mutable_cpu_data());
71 int sparse = this->filler_param_.sparse();
74 // Sparse initialization is implemented for "weight" blobs; i.e. matrices.
75 // These have num == channels == 1; height is number of inputs; width is
76 // number of outputs. The 'sparse' variable specifies the mean number
77 // of non-zero input weights for a given output.
78 CHECK_EQ(blob->num(), 1);
79 CHECK_EQ(blob->channels(), 1);
80 int num_inputs = blob->height();
81 Dtype non_zero_probability = Dtype(sparse) / Dtype(num_inputs);
82 rand_vec_.reset(new SyncedMemory(blob->count() * sizeof(int)));
83 int* mask = reinterpret_cast<int*>(rand_vec_->mutable_cpu_data());
84 caffe_rng_bernoulli(blob->count(), non_zero_probability, mask);
85 for (int i = 0; i < blob->count(); ++i) {
92 shared_ptr<SyncedMemory> rand_vec_;
95 template <typename Dtype>
96 class PositiveUnitballFiller : public Filler<Dtype> {
98 explicit PositiveUnitballFiller(const FillerParameter& param)
99 : Filler<Dtype>(param) {}
100 virtual void Fill(Blob<Dtype>* blob) {
101 Dtype* data = blob->mutable_cpu_data();
102 DCHECK(blob->count());
103 caffe_rng_uniform<Dtype>(blob->count(), 0, 1, blob->mutable_cpu_data());
104 // We expect the filler to not be called very frequently, so we will
105 // just use a simple implementation
106 int dim = blob->count() / blob->num();
108 for (int i = 0; i < blob->num(); ++i) {
110 for (int j = 0; j < dim; ++j) {
111 sum += data[i * dim + j];
113 for (int j = 0; j < dim; ++j) {
114 data[i * dim + j] /= sum;
117 CHECK_EQ(this->filler_param_.sparse(), -1)
118 << "Sparsity not supported by this Filler.";
122 // A filler based on the paper [Bengio and Glorot 2010]: Understanding
123 // the difficulty of training deep feedforward neuralnetworks, but does not
124 // use the fan_out value.
126 // It fills the incoming matrix by randomly sampling uniform data from
127 // [-scale, scale] where scale = sqrt(3 / fan_in) where fan_in is the number
128 // of input nodes. You should make sure the input blob has shape (num, a, b, c)
129 // where a * b * c = fan_in.
130 template <typename Dtype>
131 class XavierFiller : public Filler<Dtype> {
133 explicit XavierFiller(const FillerParameter& param)
134 : Filler<Dtype>(param) {}
135 virtual void Fill(Blob<Dtype>* blob) {
136 CHECK(blob->count());
137 int fan_in = blob->count() / blob->num();
138 Dtype scale = sqrt(Dtype(3) / fan_in);
139 caffe_rng_uniform<Dtype>(blob->count(), -scale, scale,
140 blob->mutable_cpu_data());
141 CHECK_EQ(this->filler_param_.sparse(), -1)
142 << "Sparsity not supported by this Filler.";
147 // A function to get a specific filler from the specification given in
148 // FillerParameter. Ideally this would be replaced by a factory pattern,
149 // but we will leave it this way for now.
150 template <typename Dtype>
151 Filler<Dtype>* GetFiller(const FillerParameter& param) {
152 const std::string& type = param.type();
153 if (type == "constant") {
154 return new ConstantFiller<Dtype>(param);
155 } else if (type == "gaussian") {
156 return new GaussianFiller<Dtype>(param);
157 } else if (type == "positive_unitball") {
158 return new PositiveUnitballFiller<Dtype>(param);
159 } else if (type == "uniform") {
160 return new UniformFiller<Dtype>(param);
161 } else if (type == "xavier") {
162 return new XavierFiller<Dtype>(param);
164 CHECK(false) << "Unknown filler name: " << param.type();
166 return (Filler<Dtype>*)(NULL);
#endif  // CAFFE_FILLER_HPP