From ced55b009ae4fd6c0685543a013b1439da5879ba Mon Sep 17 00:00:00 2001 From: knsong Date: Sat, 17 Feb 2018 15:56:32 +0800 Subject: [PATCH] Fix compatibility for ND convolution --- include/caffe/filler.hpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/include/caffe/filler.hpp b/include/caffe/filler.hpp index bb92ded..e3e86a5 100644 --- a/include/caffe/filler.hpp +++ b/include/caffe/filler.hpp @@ -108,9 +108,9 @@ class PositiveUnitballFiller : public Filler { caffe_rng_uniform(blob->count(), 0, 1, blob->mutable_cpu_data()); // We expect the filler to not be called very frequently, so we will // just use a simple implementation - int dim = blob->count() / blob->num(); + int dim = blob->count() / blob->shape(0); CHECK(dim); - for (int i = 0; i < blob->num(); ++i) { + for (int i = 0; i < blob->shape(0); ++i) { Dtype sum = 0; for (int j = 0; j < dim; ++j) { sum += data[i * dim + j]; @@ -147,8 +147,9 @@ class XavierFiller : public Filler { : Filler(param) {} virtual void Fill(Blob* blob) { CHECK(blob->count()); - int fan_in = blob->count() / blob->num(); - int fan_out = blob->count() / blob->channels(); + int fan_in = blob->count() / blob->shape(0); + // Compatible for ND Convolution + int fan_out = blob->count() / blob->shape(1); Dtype n = fan_in; // default to fan_in if (this->filler_param_.variance_norm() == FillerParameter_VarianceNorm_AVERAGE) { @@ -189,8 +190,9 @@ class MSRAFiller : public Filler { : Filler(param) {} virtual void Fill(Blob* blob) { CHECK(blob->count()); - int fan_in = blob->count() / blob->num(); - int fan_out = blob->count() / blob->channels(); + int fan_in = blob->count() / blob->shape(0); + // Compatible for ND Convolution + int fan_out = blob->count() / blob->shape(1); Dtype n = fan_in; // default to fan_in if (this->filler_param_.variance_norm() == FillerParameter_VarianceNorm_AVERAGE) { -- 2.7.4