virtual void Fill(Blob<Dtype>* blob) {
CHECK(blob->count());
int fan_in = blob->count() / blob->shape(0);
- // Compatible for ND Convolution
- int fan_out = blob->count() / blob->shape(1);
+ // Compatibility with ND blobs
+ int fan_out = blob->num_axes() > 1 ?
+ blob->count() / blob->shape(1) :
+ blob->count();
Dtype n = fan_in; // default to fan_in
if (this->filler_param_.variance_norm() ==
FillerParameter_VarianceNorm_AVERAGE) {
virtual void Fill(Blob<Dtype>* blob) {
CHECK(blob->count());
int fan_in = blob->count() / blob->shape(0);
- // Compatible for ND Convolution
- int fan_out = blob->count() / blob->shape(1);
+ // Compatibility with ND blobs
+ int fan_out = blob->num_axes() > 1 ?
+ blob->count() / blob->shape(1) :
+ blob->count();
Dtype n = fan_in; // default to fan_in
if (this->filler_param_.variance_norm() ==
FillerParameter_VarianceNorm_AVERAGE) {