1 // Copyright 2014 BVLC and contributors.
// Fillers are random number generators that fill a blob using the specified
// algorithm. They are expected to be used only at initialization time and do
// not involve any GPUs.
7 #ifndef CAFFE_FILLER_HPP
8 #define CAFFE_FILLER_HPP
12 #include "caffe/common.hpp"
13 #include "caffe/blob.hpp"
14 #include "caffe/syncedmem.hpp"
15 #include "caffe/util/math_functions.hpp"
16 #include "caffe/proto/caffe.pb.h"
20 template <typename Dtype>
23 explicit Filler(const FillerParameter& param) : filler_param_(param) {}
25 virtual void Fill(Blob<Dtype>* blob) = 0;
27 FillerParameter filler_param_;
31 template <typename Dtype>
32 class ConstantFiller : public Filler<Dtype> {
34 explicit ConstantFiller(const FillerParameter& param)
35 : Filler<Dtype>(param) {}
36 virtual void Fill(Blob<Dtype>* blob) {
37 Dtype* data = blob->mutable_cpu_data();
38 const int count = blob->count();
39 const Dtype value = this->filler_param_.value();
41 for (int i = 0; i < count; ++i) {
47 template <typename Dtype>
48 class UniformFiller : public Filler<Dtype> {
50 explicit UniformFiller(const FillerParameter& param)
51 : Filler<Dtype>(param) {}
52 virtual void Fill(Blob<Dtype>* blob) {
54 caffe_rng_uniform<Dtype>(blob->count(), blob->mutable_cpu_data(),
55 Dtype(this->filler_param_.min()),
56 Dtype(this->filler_param_.max()));
60 template <typename Dtype>
61 class GaussianFiller : public Filler<Dtype> {
63 explicit GaussianFiller(const FillerParameter& param)
64 : Filler<Dtype>(param) {}
65 virtual void Fill(Blob<Dtype>* blob) {
66 Dtype* data = blob->mutable_cpu_data();
68 caffe_rng_gaussian<Dtype>(blob->count(), blob->mutable_cpu_data(),
69 Dtype(this->filler_param_.mean()),
70 Dtype(this->filler_param_.std()));
74 template <typename Dtype>
75 class PositiveUnitballFiller : public Filler<Dtype> {
77 explicit PositiveUnitballFiller(const FillerParameter& param)
78 : Filler<Dtype>(param) {}
79 virtual void Fill(Blob<Dtype>* blob) {
80 Dtype* data = blob->mutable_cpu_data();
81 DCHECK(blob->count());
82 caffe_rng_uniform<Dtype>(blob->count(), blob->mutable_cpu_data(), 0, 1);
83 // We expect the filler to not be called very frequently, so we will
84 // just use a simple implementation
85 int dim = blob->count() / blob->num();
87 for (int i = 0; i < blob->num(); ++i) {
89 for (int j = 0; j < dim; ++j) {
90 sum += data[i * dim + j];
92 for (int j = 0; j < dim; ++j) {
93 data[i * dim + j] /= sum;
99 // A filler based on the paper [Bengio and Glorot 2010]: Understanding
100 // the difficulty of training deep feedforward neuralnetworks, but does not
101 // use the fan_out value.
103 // It fills the incoming matrix by randomly sampling uniform data from
104 // [-scale, scale] where scale = sqrt(3 / fan_in) where fan_in is the number
105 // of input nodes. You should make sure the input blob has shape (num, a, b, c)
106 // where a * b * c = fan_in.
107 template <typename Dtype>
108 class XavierFiller : public Filler<Dtype> {
110 explicit XavierFiller(const FillerParameter& param)
111 : Filler<Dtype>(param) {}
112 virtual void Fill(Blob<Dtype>* blob) {
113 CHECK(blob->count());
114 int fan_in = blob->count() / blob->num();
115 Dtype scale = sqrt(Dtype(3) / fan_in);
116 caffe_rng_uniform<Dtype>(blob->count(), blob->mutable_cpu_data(),
122 // A function to get a specific filler from the specification given in
123 // FillerParameter. Ideally this would be replaced by a factory pattern,
124 // but we will leave it this way for now.
125 template <typename Dtype>
126 Filler<Dtype>* GetFiller(const FillerParameter& param) {
127 const std::string& type = param.type();
128 if (type == "constant") {
129 return new ConstantFiller<Dtype>(param);
130 } else if (type == "gaussian") {
131 return new GaussianFiller<Dtype>(param);
132 } else if (type == "positive_unitball") {
133 return new PositiveUnitballFiller<Dtype>(param);
134 } else if (type == "uniform") {
135 return new UniformFiller<Dtype>(param);
136 } else if (type == "xavier") {
137 return new XavierFiller<Dtype>(param);
139 CHECK(false) << "Unknown filler name: " << param.type();
141 return (Filler<Dtype>*)(NULL);
#endif  // CAFFE_FILLER_HPP