Fix compatibility for ND convolution
authorknsong <sunskn@163.com>
Sat, 17 Feb 2018 07:56:32 +0000 (15:56 +0800)
committerknsong <sunskn@163.com>
Sat, 17 Feb 2018 07:56:32 +0000 (15:56 +0800)
include/caffe/filler.hpp

index bb92ded..e3e86a5 100644 (file)
@@ -108,9 +108,9 @@ class PositiveUnitballFiller : public Filler<Dtype> {
     caffe_rng_uniform<Dtype>(blob->count(), 0, 1, blob->mutable_cpu_data());
     // We expect the filler to not be called very frequently, so we will
     // just use a simple implementation
-    int dim = blob->count() / blob->num();
+    int dim = blob->count() / blob->shape(0);
     CHECK(dim);
-    for (int i = 0; i < blob->num(); ++i) {
+    for (int i = 0; i < blob->shape(0); ++i) {
       Dtype sum = 0;
       for (int j = 0; j < dim; ++j) {
         sum += data[i * dim + j];
@@ -147,8 +147,9 @@ class XavierFiller : public Filler<Dtype> {
       : Filler<Dtype>(param) {}
   virtual void Fill(Blob<Dtype>* blob) {
     CHECK(blob->count());
-    int fan_in = blob->count() / blob->num();
-    int fan_out = blob->count() / blob->channels();
+    int fan_in = blob->count() / blob->shape(0);
+    // Compatible for ND Convolution
+    int fan_out = blob->count() / blob->shape(1);
     Dtype n = fan_in;  // default to fan_in
     if (this->filler_param_.variance_norm() ==
         FillerParameter_VarianceNorm_AVERAGE) {
@@ -189,8 +190,9 @@ class MSRAFiller : public Filler<Dtype> {
       : Filler<Dtype>(param) {}
   virtual void Fill(Blob<Dtype>* blob) {
     CHECK(blob->count());
-    int fan_in = blob->count() / blob->num();
-    int fan_out = blob->count() / blob->channels();
+    int fan_in = blob->count() / blob->shape(0);
+    // Compatible for ND Convolution
+    int fan_out = blob->count() / blob->shape(1);
     Dtype n = fan_in;  // default to fan_in
     if (this->filler_param_.variance_norm() ==
         FillerParameter_VarianceNorm_AVERAGE) {