// Copyright 2013 Yangqing Jia

// Fillers are random number generators that fill a blob using the specified
// algorithm. The expectation is that they are only going to be used during
// initialization time and will not involve any GPUs.
7 #ifndef CAFFE_FILLER_HPP
8 #define CAFFE_FILLER_HPP
13 #include "caffe/common.hpp"
14 #include "caffe/blob.hpp"
15 #include "caffe/syncedmem.hpp"
16 #include "caffe/util/math_functions.hpp"
17 #include "caffe/proto/caffe.pb.h"
21 template <typename Dtype>
24 explicit Filler(const FillerParameter& param) : filler_param_(param) {}
26 virtual void Fill(Blob<Dtype>* blob) = 0;
28 FillerParameter filler_param_;
32 template <typename Dtype>
33 class ConstantFiller : public Filler<Dtype> {
35 explicit ConstantFiller(const FillerParameter& param)
36 : Filler<Dtype>(param) {}
37 virtual void Fill(Blob<Dtype>* blob) {
38 Dtype* data = blob->mutable_cpu_data();
39 const int count = blob->count();
40 const Dtype value = this->filler_param_.value();
42 for (int i = 0; i < count; ++i) {
48 template <typename Dtype>
49 class UniformFiller : public Filler<Dtype> {
51 explicit UniformFiller(const FillerParameter& param)
52 : Filler<Dtype>(param) {}
53 virtual void Fill(Blob<Dtype>* blob) {
55 caffe_vRngUniform<Dtype>(blob->count(), blob->mutable_cpu_data(),
56 Dtype(this->filler_param_.min()),
57 Dtype(this->filler_param_.max()));
61 template <typename Dtype>
62 class GaussianFiller : public Filler<Dtype> {
64 explicit GaussianFiller(const FillerParameter& param)
65 : Filler<Dtype>(param) {}
66 virtual void Fill(Blob<Dtype>* blob) {
67 Dtype* data = blob->mutable_cpu_data();
69 caffe_vRngGaussian<Dtype>(blob->count(), blob->mutable_cpu_data(),
70 Dtype(this->filler_param_.mean()),
71 Dtype(this->filler_param_.std()));
75 template <typename Dtype>
76 class PositiveUnitballFiller : public Filler<Dtype> {
78 explicit PositiveUnitballFiller(const FillerParameter& param)
79 : Filler<Dtype>(param) {}
80 virtual void Fill(Blob<Dtype>* blob) {
81 Dtype* data = blob->mutable_cpu_data();
82 DCHECK(blob->count());
83 caffe_vRngUniform<Dtype>(blob->count(), blob->mutable_cpu_data(), 0, 1);
84 // We expect the filler to not be called very frequently, so we will
85 // just use a simple implementation
86 int dim = blob->count() / blob->num();
88 for (int i = 0; i < blob->num(); ++i) {
90 for (int j = 0; j < dim; ++j) {
91 sum += data[i * dim + j];
93 for (int j = 0; j < dim; ++j) {
94 data[i * dim + j] /= sum;
100 // A filler based on the paper [Bengio and Glorot 2010]: Understanding
101 // the difficulty of training deep feedforward neuralnetworks, but does not
102 // use the fan_out value.
104 // It fills the incoming matrix by randomly sampling uniform data from
105 // [-scale, scale] where scale = sqrt(3 / fan_in) where fan_in is the number
106 // of input nodes. You should make sure the input blob has shape (num, a, b, c)
107 // where a * b * c = fan_in.
108 template <typename Dtype>
109 class XavierFiller : public Filler<Dtype> {
111 explicit XavierFiller(const FillerParameter& param)
112 : Filler<Dtype>(param) {}
113 virtual void Fill(Blob<Dtype>* blob) {
114 CHECK(blob->count());
115 int fan_in = blob->count() / blob->num();
116 Dtype scale = sqrt(Dtype(3) / fan_in);
117 caffe_vRngUniform<Dtype>(blob->count(), blob->mutable_cpu_data(),
123 // A function to get a specific filler from the specification given in
124 // FillerParameter. Ideally this would be replaced by a factory pattern,
125 // but we will leave it this way for now.
126 template <typename Dtype>
127 Filler<Dtype>* GetFiller(const FillerParameter& param) {
128 const std::string& type = param.type();
129 if (type == "constant") {
130 return new ConstantFiller<Dtype>(param);
131 } else if (type == "gaussian") {
132 return new GaussianFiller<Dtype>(param);
133 } else if (type == "positive_unitball") {
134 return new PositiveUnitballFiller<Dtype>(param);
135 } else if (type == "uniform") {
136 return new UniformFiller<Dtype>(param);
137 } else if (type == "xavier") {
138 return new XavierFiller<Dtype>(param);
140 CHECK(false) << "Unknown filler name: " << param.type();
142 return (Filler<Dtype>*)(NULL);
147 #endif // CAFFE_FILLER_HPP_