fixing pooling SetUp() to allow default values for stride and pad
authorRonghang Hu <huronghang@hotmail.com>
Sat, 5 Jul 2014 00:29:56 +0000 (17:29 -0700)
committerRonghang Hu <huronghang@hotmail.com>
Sat, 5 Jul 2014 15:11:03 +0000 (08:11 -0700)
src/caffe/layers/pooling_layer.cpp

index 3b64741..d4feaad 100644 (file)
@@ -35,15 +35,15 @@ void PoolingLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   CHECK(pool_param.has_kernel_size() ||
       (pool_param.has_kernel_h() && pool_param.has_kernel_w()))
       << "For non-square filters both kernel_h and kernel_w are required.";
-  CHECK((pool_param.has_pad()
-      && !(pool_param.has_pad_h() || pool_param.has_pad_w()))
-      || (!pool_param.has_pad()
-      && (pool_param.has_pad_h() && pool_param.has_pad_w())
-      || (!pool_param.has_pad_h() && !pool_param.has_pad_w())))
-      << "Padding size is pad OR pad_h and pad_w; not both";
-  CHECK(!pool_param.has_stride() !=
-      !(pool_param.has_stride_h() && pool_param.has_stride_w()))
+  CHECK((!pool_param.has_pad() && pool_param.has_pad_h() 
+      && pool_param.has_pad_w())
+      || (!pool_param.has_pad_h() && !pool_param.has_pad_w()))
+      << "pad is pad OR pad_h and pad_w are required.";
+  CHECK((!pool_param.has_stride() && pool_param.has_stride_h() 
+      && pool_param.has_stride_w())
+      || (!pool_param.has_stride_h() && !pool_param.has_stride_w()))
       << "Stride is stride OR stride_h and stride_w are required.";
+
   if (pool_param.has_kernel_size()) {
     kernel_h_ = kernel_w_ = pool_param.kernel_size();
   } else {
@@ -51,13 +51,13 @@ void PoolingLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
     kernel_w_ = pool_param.kernel_w();
   }
   CHECK_GT(kernel_h_ * kernel_w_, 0) << "Filter dimensions cannot be zero.";
-  if (pool_param.has_pad()) {
+  if (!pool_param.has_pad_h()) {
     pad_h_ = pad_w_ = pool_param.pad();
   } else {
     pad_h_ = pool_param.pad_h();
     pad_w_ = pool_param.pad_w();
   }
-  if (pool_param.has_stride()) {
+  if (!pool_param.has_stride_h()) {
     stride_h_ = stride_w_ = pool_param.stride();
   } else {
     stride_h_ = pool_param.stride_h();