explicit UniformFiller(const FillerParameter& param)
: Filler<Dtype>(param) {}
virtual void Fill(Blob<Dtype>* blob) {
- DCHECK(blob->count());
+ CHECK(blob->count());
caffe_vRngUniform<Dtype>(blob->count(), blob->mutable_cpu_data(),
Dtype(this->filler_param_.min()),
Dtype(this->filler_param_.max()));
: Filler<Dtype>(param) {}
virtual void Fill(Blob<Dtype>* blob) {
Dtype* data = blob->mutable_cpu_data();
- DCHECK(blob->count());
+ CHECK(blob->count());
caffe_vRngGaussian<Dtype>(blob->count(), blob->mutable_cpu_data(),
Dtype(this->filler_param_.mean()),
Dtype(this->filler_param_.std()));
// We expect the filler to not be called very frequently, so we will
// just use a simple implementation
int dim = blob->count() / blob->num();
- DCHECK(dim);
+ CHECK(dim);
for (int i = 0; i < blob->num(); ++i) {
Dtype sum = 0;
for (int j = 0; j < dim; ++j) {
}
};
+// A filler based on the paper [Bengio and Glorot 2010]: Understanding
+// the difficulty of training deep feedforward neuralnetworks, but does not
+// use the fan_out value.
+//
+// It fills the incoming matrix by randomly sampling uniform data from
+// [-scale, scale] where scale = sqrt(3 / fan_in) where fan_in is the number
+// of input nodes, and in our case we consider the blob width as the scale.
+// You should make sure the input blob has shape (1, 1, height, width).
template <typename Dtype>
class XavierFiller : public Filler<Dtype> {
public:
explicit XavierFiller(const FillerParameter& param)
: Filler<Dtype>(param) {}
virtual void Fill(Blob<Dtype>* blob) {
-
+ CHECK(blob->count());
+ CHECK_EQ(blob->num(), 1) << "XavierFiller requires blob.num() = 1.";
+ CHECK_EQ(blob->channels(), 1)
+ << "XavierFiller requires blob.channels() = 1.";
+ int fan_in = blob->width();
+ Dtype scale = sqrt(Dtype(3) / fan_in);
+ caffe_vRngUniform<Dtype>(blob->count(), blob->mutable_cpu_data(),
+ -scale, scale);
}
};
const std::string& type = param.type();
if (type == "constant") {
return new ConstantFiller<Dtype>(param);
- } else if (type == "uniform") {
- return new UniformFiller<Dtype>(param);
} else if (type == "gaussian") {
return new GaussianFiller<Dtype>(param);
} else if (type == "positive_unitball") {
return new PositiveUnitballFiller<Dtype>(param);
+ } else if (type == "uniform") {
+ return new UniformFiller<Dtype>(param);
+ } else if (type == "xavier") {
+ return new XavierFiller<Dtype>(param);
} else {
CHECK(false) << "Unknown filler name: " << param.type();
}