Fix sparse GaussianFiller for new IPLayer weight axes
authorJeff Donahue <jeff.donahue@gmail.com>
Mon, 16 Feb 2015 09:29:17 +0000 (01:29 -0800)
committerJeff Donahue <jeff.donahue@gmail.com>
Tue, 3 Mar 2015 23:55:13 +0000 (15:55 -0800)
include/caffe/filler.hpp

index eebf565..bb18e8e 100644 (file)
@@ -79,9 +79,8 @@ class GaussianFiller : public Filler<Dtype> {
       // These have num == channels == 1; width is number of inputs; height is
       // number of outputs.  The 'sparse' variable specifies the mean number
       // of non-zero input weights for a given output.
-      CHECK_EQ(blob->num(), 1);
-      CHECK_EQ(blob->channels(), 1);
-      int num_outputs = blob->height();
+      CHECK_GE(blob->num_axes(), 1);
+      const int num_outputs = blob->shape(0);
       Dtype non_zero_probability = Dtype(sparse) / Dtype(num_outputs);
       rand_vec_.reset(new SyncedMemory(blob->count() * sizeof(int)));
       int* mask = reinterpret_cast<int*>(rand_vec_->mutable_cpu_data());