1D blob handling in MSRA/Xavier fillers
authorNoiredd <snowball91b@gmail.com>
Tue, 6 Mar 2018 12:39:49 +0000 (13:39 +0100)
committerWook Song <wook16.song@samsung.com>
Thu, 23 Jan 2020 13:50:35 +0000 (22:50 +0900)
include/caffe/filler.hpp

index e3e86a5..a447736 100644 (file)
@@ -148,8 +148,10 @@ class XavierFiller : public Filler<Dtype> {
   virtual void Fill(Blob<Dtype>* blob) {
     CHECK(blob->count());
     int fan_in = blob->count() / blob->shape(0);
-    // Compatible for ND Convolution
-    int fan_out = blob->count() / blob->shape(1);
+    // Compatibility with ND blobs
+    int fan_out = blob->num_axes() > 1 ?
+                  blob->count() / blob->shape(1) :
+                  blob->count();
     Dtype n = fan_in;  // default to fan_in
     if (this->filler_param_.variance_norm() ==
         FillerParameter_VarianceNorm_AVERAGE) {
@@ -191,8 +193,10 @@ class MSRAFiller : public Filler<Dtype> {
   virtual void Fill(Blob<Dtype>* blob) {
     CHECK(blob->count());
     int fan_in = blob->count() / blob->shape(0);
-    // Compatible for ND Convolution
-    int fan_out = blob->count() / blob->shape(1);
+    // Compatibility with ND blobs
+    int fan_out = blob->num_axes() > 1 ?
+                  blob->count() / blob->shape(1) :
+                  blob->count();
     Dtype n = fan_in;  // default to fan_in
     if (this->filler_param_.variance_norm() ==
         FillerParameter_VarianceNorm_AVERAGE) {