// These have num == channels == 1; width is number of inputs; height is
// number of outputs. The 'sparse' variable specifies the mean number
// of non-zero input weights for a given output.
- CHECK_EQ(blob->num(), 1);
- CHECK_EQ(blob->channels(), 1);
- int num_outputs = blob->height();
+ CHECK_GE(blob->num_axes(), 1);
+ const int num_outputs = blob->shape(0);
Dtype non_zero_probability = Dtype(sparse) / Dtype(num_outputs);
rand_vec_.reset(new SyncedMemory(blob->count() * sizeof(int)));
int* mask = reinterpret_cast<int*>(rand_vec_->mutable_cpu_data());