add tests for rectangular pooling regions
authorRonghang Hu <huronghang@hotmail.com>
Sat, 5 Jul 2014 15:21:50 +0000 (08:21 -0700)
committerRonghang Hu <huronghang@hotmail.com>
Sat, 5 Jul 2014 15:21:50 +0000 (08:21 -0700)
include/caffe/vision_layers.hpp
src/caffe/layers/pooling_layer.cpp
src/caffe/layers/pooling_layer.cu
src/caffe/proto/caffe.proto
src/caffe/test/test_pooling_layer.cpp

index 62b65bc..450524e 100644 (file)
@@ -361,7 +361,7 @@ class PoolingLayer : public Layer<Dtype> {
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
   int max_top_blobs_;
-  int kernel_size_h_, kernel_size_w_;
+  int kernel_h_, kernel_w_;
   int stride_h_, stride_w_;
   int pad_h_, pad_w_;
   int channels_;
index d4feaad..d8cd2e2 100644 (file)
@@ -43,14 +43,14 @@ void PoolingLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
       && pool_param.has_stride_w())
       || (!pool_param.has_stride_h() && !pool_param.has_stride_w()))
       << "Stride is stride OR stride_h and stride_w are required.";
-
   if (pool_param.has_kernel_size()) {
     kernel_h_ = kernel_w_ = pool_param.kernel_size();
   } else {
     kernel_h_ = pool_param.kernel_h();
     kernel_w_ = pool_param.kernel_w();
   }
-  CHECK_GT(kernel_h_ * kernel_w_, 0) << "Filter dimensions cannot be zero.";
+  CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero.";
+  CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero.";
   if (!pool_param.has_pad_h()) {
     pad_h_ = pad_w_ = pool_param.pad();
   } else {
@@ -69,16 +69,16 @@ void PoolingLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
         || this->layer_param_.pooling_param().pool()
         == PoolingParameter_PoolMethod_MAX)
         << "Padding implemented only for average and max pooling.";
-    CHECK_LT(pad_h_, kernel_size_h_);
-    CHECK_LT(pad_w_, kernel_size_w_);
+    CHECK_LT(pad_h_, kernel_h_);
+    CHECK_LT(pad_w_, kernel_w_);
   }
   channels_ = bottom[0]->channels();
   height_ = bottom[0]->height();
   width_ = bottom[0]->width();
   pooled_height_ = static_cast<int>(ceil(static_cast<float>(
-      height_ + 2 * pad_h_ - kernel_size_h_) / stride_h_)) + 1;
+      height_ + 2 * pad_h_ - kernel_h_) / stride_h_)) + 1;
   pooled_width_ = static_cast<int>(ceil(static_cast<float>(
-      width_ + 2 * pad_w_ - kernel_size_w_) / stride_w_)) + 1;
+      width_ + 2 * pad_w_ - kernel_w_) / stride_w_)) + 1;
   if (pad_h_ || pad_w_) {
     // If we have padding, ensure that the last pooling starts strictly
     // inside the image (instead of at the padding); otherwise clip the last.
@@ -142,8 +142,8 @@ Dtype PoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
           for (int pw = 0; pw < pooled_width_; ++pw) {
             int hstart = ph * stride_h_ - pad_h_;
             int wstart = pw * stride_w_ - pad_w_;
-            int hend = min(hstart + kernel_size_h_, height_);
-            int wend = min(wstart + kernel_size_w_, width_);
+            int hend = min(hstart + kernel_h_, height_);
+            int wend = min(wstart + kernel_w_, width_);
             hstart = max(hstart, 0);
             wstart = max(wstart, 0);
             const int pool_index = ph * pooled_width_ + pw;
@@ -184,8 +184,8 @@ Dtype PoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
           for (int pw = 0; pw < pooled_width_; ++pw) {
             int hstart = ph * stride_h_ - pad_h_;
             int wstart = pw * stride_w_ - pad_w_;
-            int hend = min(hstart + kernel_size_h_, height_ + pad_h_);
-            int wend = min(wstart + kernel_size_w_, width_ + pad_w_);
+            int hend = min(hstart + kernel_h_, height_ + pad_h_);
+            int wend = min(wstart + kernel_w_, width_ + pad_w_);
             int pool_size = (hend - hstart) * (wend - wstart);
             hstart = max(hstart, 0);
             wstart = max(wstart, 0);
@@ -266,8 +266,8 @@ void PoolingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
           for (int pw = 0; pw < pooled_width_; ++pw) {
             int hstart = ph * stride_h_ - pad_h_;
             int wstart = pw * stride_w_ - pad_w_;
-            int hend = min(hstart + kernel_size_h_, height_ + pad_h_);
-            int wend = min(wstart + kernel_size_w_, width_ + pad_w_);
+            int hend = min(hstart + kernel_h_, height_ + pad_h_);
+            int wend = min(wstart + kernel_w_, width_ + pad_w_);
             int pool_size = (hend - hstart) * (wend - wstart);
             hstart = max(hstart, 0);
             wstart = max(wstart, 0);
index 3ac74f8..e38028d 100644 (file)
@@ -17,7 +17,7 @@ template <typename Dtype>
 __global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
     const int num, const int channels, const int height,
     const int width, const int pooled_height, const int pooled_width,
-    const int kernel_size_h, const int kernel_size_w, const int stride_h, 
+    const int kernel_h, const int kernel_w, const int stride_h, 
     const int stride_w, const int pad_h, const int pad_w, Dtype* top_data,
     int* mask, Dtype* top_mask) {
   CUDA_KERNEL_LOOP(index, nthreads) {
@@ -27,8 +27,8 @@ __global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
     int n = index / pooled_width / pooled_height / channels;
     int hstart = ph * stride_h - pad_h;
     int wstart = pw * stride_w - pad_w;
-    int hend = min(hstart + kernel_size_h, height);
-    int wend = min(wstart + kernel_size_w, width);
+    int hend = min(hstart + kernel_h, height);
+    int wend = min(wstart + kernel_w, width);
     hstart = max(hstart, 0);
     wstart = max(wstart, 0);
     Dtype maxval = -FLT_MAX;
@@ -55,7 +55,7 @@ template <typename Dtype>
 __global__ void AvePoolForward(const int nthreads, const Dtype* bottom_data,
     const int num, const int channels, const int height,
     const int width, const int pooled_height, const int pooled_width,
-    const int kernel_size_h, const int kernel_size_w, const int stride_h, 
+    const int kernel_h, const int kernel_w, const int stride_h, 
     const int stride_w, const int pad_h, const int pad_w, Dtype* top_data) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     int pw = index % pooled_width;
@@ -64,8 +64,8 @@ __global__ void AvePoolForward(const int nthreads, const Dtype* bottom_data,
     int n = index / pooled_width / pooled_height / channels;
     int hstart = ph * stride_h - pad_h;
     int wstart = pw * stride_w - pad_w;
-    int hend = min(hstart + kernel_size_h, height + pad_h);
-    int wend = min(wstart + kernel_size_w, width + pad_w);
+    int hend = min(hstart + kernel_h, height + pad_h);
+    int wend = min(wstart + kernel_w, width + pad_w);
     int pool_size = (hend - hstart) * (wend - wstart);
     hstart = max(hstart, 0);
     wstart = max(wstart, 0);
@@ -87,7 +87,7 @@ __global__ void StoPoolForwardTrain(const int nthreads,
     const Dtype* bottom_data,
     const int num, const int channels, const int height,
     const int width, const int pooled_height, const int pooled_width,
-    const int kernel_size_h, const int kernel_size_w, const int stride_h, 
+    const int kernel_h, const int kernel_w, const int stride_h, 
     const int stride_w, Dtype* rand_idx, Dtype* top_data) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     int pw = index % pooled_width;
@@ -95,9 +95,9 @@ __global__ void StoPoolForwardTrain(const int nthreads,
     int c = (index / pooled_width / pooled_height) % channels;
     int n = index / pooled_width / pooled_height / channels;
     int hstart = ph * stride_h;
-    int hend = min(hstart + kernel_size_h, height);
+    int hend = min(hstart + kernel_h, height);
     int wstart = pw * stride_w;
-    int wend = min(wstart + kernel_size_w, width);
+    int wend = min(wstart + kernel_w, width);
     Dtype cumsum = 0.;
     bottom_data += (n * channels + c) * height * width;
     // First pass: get sum
@@ -128,7 +128,7 @@ __global__ void StoPoolForwardTest(const int nthreads,
     const Dtype* bottom_data,
     const int num, const int channels, const int height,
     const int width, const int pooled_height, const int pooled_width,
-    const int kernel_size_h, const int kernel_size_w, const int stride_h, 
+    const int kernel_h, const int kernel_w, const int stride_h, 
     const int stride_w, Dtype* top_data) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     int pw = index % pooled_width;
@@ -136,9 +136,9 @@ __global__ void StoPoolForwardTest(const int nthreads,
     int c = (index / pooled_width / pooled_height) % channels;
     int n = index / pooled_width / pooled_height / channels;
     int hstart = ph * stride_h;
-    int hend = min(hstart + kernel_size_h, height);
+    int hend = min(hstart + kernel_h, height);
     int wstart = pw * stride_w;
-    int wend = min(wstart + kernel_size_w, width);
+    int wend = min(wstart + kernel_w, width);
     // We set cumsum to be 0 to avoid divide-by-zero problems
     Dtype cumsum = FLT_MIN;
     Dtype cumvalues = 0.;
@@ -175,16 +175,16 @@ Dtype PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     // NOLINT_NEXT_LINE(whitespace/operators)
     MaxPoolForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
         count, bottom_data, bottom[0]->num(), channels_,
-        height_, width_, pooled_height_, pooled_width_, kernel_size_h_, 
-        kernel_size_w_, stride_h_, stride_w_, pad_h_, pad_w_, top_data, 
+        height_, width_, pooled_height_, pooled_width_, kernel_h_, 
+        kernel_w_, stride_h_, stride_w_, pad_h_, pad_w_, top_data, 
         mask, top_mask);
     break;
   case PoolingParameter_PoolMethod_AVE:
     // NOLINT_NEXT_LINE(whitespace/operators)
     AvePoolForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
         count, bottom_data, bottom[0]->num(), channels_,
-        height_, width_, pooled_height_, pooled_width_, kernel_size_h_, 
-        kernel_size_w_, stride_h_, stride_w_, pad_h_, pad_w_, top_data);
+        height_, width_, pooled_height_, pooled_width_, kernel_h_, 
+        kernel_w_, stride_h_, stride_w_, pad_h_, pad_w_, top_data);
     break;
   case PoolingParameter_PoolMethod_STOCHASTIC:
     if (Caffe::phase() == Caffe::TRAIN) {
@@ -195,16 +195,16 @@ Dtype PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       StoPoolForwardTrain<Dtype><<<CAFFE_GET_BLOCKS(count),
                                    CAFFE_CUDA_NUM_THREADS>>>(
           count, bottom_data, bottom[0]->num(), channels_,
-          height_, width_, pooled_height_, pooled_width_, kernel_size_h_, 
-          kernel_size_w_, stride_h_, stride_w_,
+          height_, width_, pooled_height_, pooled_width_, kernel_h_, 
+          kernel_w_, stride_h_, stride_w_,
           rand_idx_.mutable_gpu_data(), top_data);
     } else {
       // NOLINT_NEXT_LINE(whitespace/operators)
       StoPoolForwardTest<Dtype><<<CAFFE_GET_BLOCKS(count),
                                   CAFFE_CUDA_NUM_THREADS>>>(
           count, bottom_data, bottom[0]->num(), channels_,
-          height_, width_, pooled_height_, pooled_width_, kernel_size_h_, 
-          kernel_size_w_, stride_h_, stride_w_, top_data);
+          height_, width_, pooled_height_, pooled_width_, kernel_h_, 
+          kernel_w_, stride_h_, stride_w_, top_data);
     }
     break;
   default:
@@ -219,7 +219,7 @@ template <typename Dtype>
 __global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
     const int* mask, const Dtype* top_mask, const int num, const int channels,
     const int height, const int width, const int pooled_height,
-    const int pooled_width, const int kernel_size_h, const int kernel_size_w, 
+    const int pooled_width, const int kernel_h, const int kernel_w, 
     const int stride_h, const int stride_w, const int pad_h, const int pad_w, 
     Dtype* bottom_diff) {
   CUDA_KERNEL_LOOP(index, nthreads) {
@@ -230,10 +230,10 @@ __global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
     int c = (index / width / height) % channels;
     int n = index / width / height / channels;
     int phstart =
-        (h + pad_h < kernel_size_h) ? 0 : (h + pad_h - kernel_size_h) / stride_h + 1;
+        (h + pad_h < kernel_h) ? 0 : (h + pad_h - kernel_h) / stride_h + 1;
     int phend = min((h + pad_h) / stride_h + 1, pooled_height);
     int pwstart =
-        (w + pad_w < kernel_size_w) ? 0 : (w + pad_w - kernel_size_w) / stride_w + 1;
+        (w + pad_w < kernel_w) ? 0 : (w + pad_w - kernel_w) / stride_w + 1;
     int pwend = min((w + pad_w) / stride_w + 1, pooled_width);
     Dtype gradient = 0;
     int offset = (n * channels + c) * pooled_height * pooled_width;
@@ -265,7 +265,7 @@ template <typename Dtype>
 __global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff,
     const int num, const int channels, const int height,
     const int width, const int pooled_height, const int pooled_width,
-    const int kernel_size_h, const int kernel_size_w, const int stride_h, 
+    const int kernel_h, const int kernel_w, const int stride_h, 
     const int stride_w, const int pad_h, const int pad_w, 
     Dtype* bottom_diff) {
   CUDA_KERNEL_LOOP(index, nthreads) {
@@ -275,9 +275,9 @@ __global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff,
     int h = (index / width) % height + pad_h;
     int c = (index / width / height) % channels;
     int n = index / width / height / channels;
-    int phstart = (h < kernel_size_h) ? 0 : (h - kernel_size_h) / stride_h + 1;
+    int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
     int phend = min(h / stride_h + 1, pooled_height);
-    int pwstart = (w < kernel_size_w) ? 0 : (w - kernel_size_w) / stride_w + 1;
+    int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
     int pwend = min(w / stride_w + 1, pooled_width);
     Dtype gradient = 0;
     top_diff += (n * channels + c) * pooled_height * pooled_width;
@@ -286,8 +286,8 @@ __global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff,
         // figure out the pooling size
         int hstart = ph * stride_h - pad_h;
         int wstart = pw * stride_w - pad_w;
-        int hend = min(hstart + kernel_size_h, height + pad_h);
-        int wend = min(wstart + kernel_size_w, width + pad_w);
+        int hend = min(hstart + kernel_h, height + pad_h);
+        int wend = min(wstart + kernel_w, width + pad_w);
         int pool_size = (hend - hstart) * (wend - wstart);
         gradient += top_diff[ph * pooled_width + pw] / pool_size;
       }
@@ -302,7 +302,7 @@ __global__ void StoPoolBackward(const int nthreads,
     const Dtype* rand_idx, const Dtype* top_diff,
     const int num, const int channels, const int height,
     const int width, const int pooled_height, const int pooled_width,
-    const int kernel_size_h, const int kernel_size_w, const int stride_h, 
+    const int kernel_h, const int kernel_w, const int stride_h, 
     const int stride_w, Dtype* bottom_diff) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     // find out the local index
@@ -311,9 +311,9 @@ __global__ void StoPoolBackward(const int nthreads,
     int h = (index / width) % height;
     int c = (index / width / height) % channels;
     int n = index / width / height / channels;
-    int phstart = (h < kernel_size_h) ? 0 : (h - kernel_size_h) / stride_h + 1;
+    int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
     int phend = min(h / stride_h + 1, pooled_height);
-    int pwstart = (w < kernel_size_w) ? 0 : (w - kernel_size_w) / stride_w + 1;
+    int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
     int pwend = min(w / stride_w + 1, pooled_width);
     Dtype gradient = 0;
     rand_idx += (n * channels + c) * pooled_height * pooled_width;
@@ -354,22 +354,22 @@ void PoolingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     MaxPoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
         count, top_diff, mask, top_mask, top[0]->num(), channels_,
         height_, width_, pooled_height_, pooled_width_,
-        kernel_size_h_, kernel_size_w_, stride_h_, stride_w_, pad_h_, pad_w_, 
+        kernel_h_, kernel_w_, stride_h_, stride_w_, pad_h_, pad_w_, 
         bottom_diff);
     break;
   case PoolingParameter_PoolMethod_AVE:
     // NOLINT_NEXT_LINE(whitespace/operators)
     AvePoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
         count, top_diff, top[0]->num(), channels_,
-        height_, width_, pooled_height_, pooled_width_, kernel_size_h_, 
-        kernel_size_w_, stride_h_, stride_w_, pad_h_, pad_w_, bottom_diff);
+        height_, width_, pooled_height_, pooled_width_, kernel_h_, 
+        kernel_w_, stride_h_, stride_w_, pad_h_, pad_w_, bottom_diff);
     break;
   case PoolingParameter_PoolMethod_STOCHASTIC:
     // NOLINT_NEXT_LINE(whitespace/operators)
     StoPoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
         count, rand_idx_.gpu_data(), top_diff,
         top[0]->num(), channels_, height_, width_, pooled_height_,
-        pooled_width_, kernel_size_h_, kernel_size_w_, stride_h_, stride_w_, 
+        pooled_width_, kernel_h_, kernel_w_, stride_h_, stride_w_, 
         bottom_diff);
     break;
   default:
index effebd0..70a4ab2 100644 (file)
@@ -411,8 +411,8 @@ message PoolingParameter {
   optional uint32 pad_h = 9 [default = 0]; // The padding height
   optional uint32 pad_w = 10 [default = 0]; // The padding width
   optional uint32 kernel_size = 2; // The kernel size (square)
-  optional uint32 kernel_size_h = 5; // The kernel height
-  optional uint32 kernel_size_w = 6; // The kernel width
+  optional uint32 kernel_h = 5; // The kernel height
+  optional uint32 kernel_w = 6; // The kernel width
   optional uint32 stride = 3 [default = 1]; // The stride (equal in Y, X)
   optional uint32 stride_h = 7; // The stride height
   optional uint32 stride_w = 8; // The stride width
index b13d11f..651d203 100644 (file)
@@ -44,8 +44,8 @@ class PoolingLayerTest : public ::testing::Test {
   Blob<Dtype>* const blob_top_mask_;
   vector<Blob<Dtype>*> blob_bottom_vec_;
   vector<Blob<Dtype>*> blob_top_vec_;
-
-  void TestForward() {
+  // Test for 2x 2 square pooling layer
+  void TestForwardSquare() {
     LayerParameter layer_param;
     PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
     pooling_param->set_kernel_size(2);
@@ -116,6 +116,256 @@ class PoolingLayerTest : public ::testing::Test {
       }
     }
   }
+  // Test for 3x 2 rectangular pooling layer with kernel_h > kernel_w
+  void TestForwardRectHigh() {
+    LayerParameter layer_param;
+    PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
+    pooling_param->set_kernel_h(3);
+    pooling_param->set_kernel_w(2);
+    pooling_param->set_pool(PoolingParameter_PoolMethod_MAX);
+    const int num = 2;
+    const int channels = 2;
+    blob_bottom_->Reshape(num, channels, 6, 6);
+    // Input: 2x 2 channels of:
+    // [35     1     6    26    19    24]
+    // [ 3    32     7    21    23    25]
+    // [31     9     2    22    27    20]
+    // [ 8    28    33    17    10    15]
+    // [30     5    34    12    14    16]
+    // [ 4    36    29    13    18    11] (this is generated by magic(6) in MATLAB)
+    for (int i = 0; i < 36 * num * channels; i += 36) {
+      blob_bottom_->mutable_cpu_data()[i +  0] = 35;
+      blob_bottom_->mutable_cpu_data()[i +  1] = 1;
+      blob_bottom_->mutable_cpu_data()[i +  2] = 6;
+      blob_bottom_->mutable_cpu_data()[i +  3] = 26;
+      blob_bottom_->mutable_cpu_data()[i +  4] = 19;
+      blob_bottom_->mutable_cpu_data()[i +  5] = 24;
+      blob_bottom_->mutable_cpu_data()[i +  6] = 3;
+      blob_bottom_->mutable_cpu_data()[i +  7] = 32;
+      blob_bottom_->mutable_cpu_data()[i +  8] = 7;
+      blob_bottom_->mutable_cpu_data()[i +  9] = 21;
+      blob_bottom_->mutable_cpu_data()[i + 10] = 23;
+      blob_bottom_->mutable_cpu_data()[i + 11] = 25;
+      blob_bottom_->mutable_cpu_data()[i + 12] = 31;
+      blob_bottom_->mutable_cpu_data()[i + 13] = 9;
+      blob_bottom_->mutable_cpu_data()[i + 14] = 2;
+      blob_bottom_->mutable_cpu_data()[i + 15] = 22;
+      blob_bottom_->mutable_cpu_data()[i + 16] = 27;
+      blob_bottom_->mutable_cpu_data()[i + 17] = 20;
+      blob_bottom_->mutable_cpu_data()[i + 18] = 8;
+      blob_bottom_->mutable_cpu_data()[i + 19] = 28;
+      blob_bottom_->mutable_cpu_data()[i + 20] = 33;
+      blob_bottom_->mutable_cpu_data()[i + 21] = 17;
+      blob_bottom_->mutable_cpu_data()[i + 22] = 10;
+      blob_bottom_->mutable_cpu_data()[i + 23] = 15;
+      blob_bottom_->mutable_cpu_data()[i + 24] = 30;
+      blob_bottom_->mutable_cpu_data()[i + 25] = 5;
+      blob_bottom_->mutable_cpu_data()[i + 26] = 34;
+      blob_bottom_->mutable_cpu_data()[i + 27] = 12;
+      blob_bottom_->mutable_cpu_data()[i + 28] = 14;
+      blob_bottom_->mutable_cpu_data()[i + 29] = 16;
+      blob_bottom_->mutable_cpu_data()[i + 30] = 4;
+      blob_bottom_->mutable_cpu_data()[i + 31] = 36;
+      blob_bottom_->mutable_cpu_data()[i + 32] = 29;
+      blob_bottom_->mutable_cpu_data()[i + 33] = 13;
+      blob_bottom_->mutable_cpu_data()[i + 34] = 18;
+      blob_bottom_->mutable_cpu_data()[i + 35] = 11;
+    }
+    PoolingLayer<Dtype> layer(layer_param);
+    layer.SetUp(blob_bottom_vec_, &blob_top_vec_);
+    EXPECT_EQ(blob_top_->num(), num);
+    EXPECT_EQ(blob_top_->channels(), channels);
+    EXPECT_EQ(blob_top_->height(), 4);
+    EXPECT_EQ(blob_top_->width(), 5);
+    if (blob_top_vec_.size() > 1) {
+      EXPECT_EQ(blob_top_mask_->num(), num);
+      EXPECT_EQ(blob_top_mask_->channels(), channels);
+      EXPECT_EQ(blob_top_mask_->height(), 4);
+      EXPECT_EQ(blob_top_mask_->width(), 5);
+    }
+    layer.Forward(blob_bottom_vec_, &blob_top_vec_);
+    // Expected output: 2x 2 channels of:
+    // [35    32    26    27    27]
+    // [32    33    33    27    27]
+    // [31    34    34    27    27]
+    // [36    36    34    18    18]
+    for (int i = 0; i < 20 * num * channels; i += 20) {
+      EXPECT_EQ(blob_top_->cpu_data()[i +  0], 35);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  1], 32);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  2], 26);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  3], 27);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  4], 27);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  5], 32);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  6], 33);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  7], 33);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  8], 27);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  9], 27);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 10], 31);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 11], 34);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 12], 34);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 13], 27);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 14], 27);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 15], 36);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 16], 36);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 17], 34);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 18], 18);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 19], 18);
+    }
+    if (blob_top_vec_.size() > 1) {
+        // [ 1     8     4    17    17]
+        // [ 8    21    21    17    17]
+        // [13    27    27    17    17]
+        // [32    32    27    35    35]
+      for (int i = 0; i < 20 * num * channels; i += 20) {
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  0],  0);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  1],  7);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  2],  3);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  3], 16);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  4], 16);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  5],  7);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  6], 20);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  7], 20);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  8], 16);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  9], 16);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 10], 12);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 11], 26);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 12], 26);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 13], 16);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 14], 16);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 15], 31);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 16], 31);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 17], 26);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 18], 34);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 19], 34);
+      }
+    }
+  }
+  // Test for rectangular pooling layer with kernel_w > kernel_h
+  void TestForwardRectWide() {
+    LayerParameter layer_param;
+    PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
+    pooling_param->set_kernel_h(2);
+    pooling_param->set_kernel_w(3);
+    pooling_param->set_pool(PoolingParameter_PoolMethod_MAX);
+    const int num = 2;
+    const int channels = 2;
+    blob_bottom_->Reshape(num, channels, 6, 6);
+    // Input: 2x 2 channels of:
+    // [35     1     6    26    19    24]
+    // [ 3    32     7    21    23    25]
+    // [31     9     2    22    27    20]
+    // [ 8    28    33    17    10    15]
+    // [30     5    34    12    14    16]
+    // [ 4    36    29    13    18    11] (this is generated by magic(6) in MATLAB)
+    for (int i = 0; i < 36 * num * channels; i += 36) {
+      blob_bottom_->mutable_cpu_data()[i +  0] = 35;
+      blob_bottom_->mutable_cpu_data()[i +  1] = 1;
+      blob_bottom_->mutable_cpu_data()[i +  2] = 6;
+      blob_bottom_->mutable_cpu_data()[i +  3] = 26;
+      blob_bottom_->mutable_cpu_data()[i +  4] = 19;
+      blob_bottom_->mutable_cpu_data()[i +  5] = 24;
+      blob_bottom_->mutable_cpu_data()[i +  6] = 3;
+      blob_bottom_->mutable_cpu_data()[i +  7] = 32;
+      blob_bottom_->mutable_cpu_data()[i +  8] = 7;
+      blob_bottom_->mutable_cpu_data()[i +  9] = 21;
+      blob_bottom_->mutable_cpu_data()[i + 10] = 23;
+      blob_bottom_->mutable_cpu_data()[i + 11] = 25;
+      blob_bottom_->mutable_cpu_data()[i + 12] = 31;
+      blob_bottom_->mutable_cpu_data()[i + 13] = 9;
+      blob_bottom_->mutable_cpu_data()[i + 14] = 2;
+      blob_bottom_->mutable_cpu_data()[i + 15] = 22;
+      blob_bottom_->mutable_cpu_data()[i + 16] = 27;
+      blob_bottom_->mutable_cpu_data()[i + 17] = 20;
+      blob_bottom_->mutable_cpu_data()[i + 18] = 8;
+      blob_bottom_->mutable_cpu_data()[i + 19] = 28;
+      blob_bottom_->mutable_cpu_data()[i + 20] = 33;
+      blob_bottom_->mutable_cpu_data()[i + 21] = 17;
+      blob_bottom_->mutable_cpu_data()[i + 22] = 10;
+      blob_bottom_->mutable_cpu_data()[i + 23] = 15;
+      blob_bottom_->mutable_cpu_data()[i + 24] = 30;
+      blob_bottom_->mutable_cpu_data()[i + 25] = 5;
+      blob_bottom_->mutable_cpu_data()[i + 26] = 34;
+      blob_bottom_->mutable_cpu_data()[i + 27] = 12;
+      blob_bottom_->mutable_cpu_data()[i + 28] = 14;
+      blob_bottom_->mutable_cpu_data()[i + 29] = 16;
+      blob_bottom_->mutable_cpu_data()[i + 30] = 4;
+      blob_bottom_->mutable_cpu_data()[i + 31] = 36;
+      blob_bottom_->mutable_cpu_data()[i + 32] = 29;
+      blob_bottom_->mutable_cpu_data()[i + 33] = 13;
+      blob_bottom_->mutable_cpu_data()[i + 34] = 18;
+      blob_bottom_->mutable_cpu_data()[i + 35] = 11;
+    }
+    PoolingLayer<Dtype> layer(layer_param);
+    layer.SetUp(blob_bottom_vec_, &blob_top_vec_);
+    EXPECT_EQ(blob_top_->num(), num);
+    EXPECT_EQ(blob_top_->channels(), channels);
+    EXPECT_EQ(blob_top_->height(), 5);
+    EXPECT_EQ(blob_top_->width(), 4);
+    if (blob_top_vec_.size() > 1) {
+      EXPECT_EQ(blob_top_mask_->num(), num);
+      EXPECT_EQ(blob_top_mask_->channels(), channels);
+      EXPECT_EQ(blob_top_mask_->height(), 5);
+      EXPECT_EQ(blob_top_mask_->width(), 4);
+    }
+    layer.Forward(blob_bottom_vec_, &blob_top_vec_);
+    // Expected output: 2x 2 channels of:
+    // [35    32    26    26]
+    // [32    32    27    27]
+    // [33    33    33    27]
+    // [34    34    34    17]
+    // [36    36    34    18]
+    for (int i = 0; i < 20 * num * channels; i += 20) {
+      EXPECT_EQ(blob_top_->cpu_data()[i +  0], 35);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  1], 32);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  2], 26);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  3], 26);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  4], 32);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  5], 32);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  6], 27);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  7], 27);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  8], 33);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  9], 33);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 10], 33);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 11], 27);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 12], 34);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 13], 34);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 14], 34);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 15], 17);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 16], 36);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 17], 36);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 18], 34);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 19], 18);
+    }
+    if (blob_top_vec_.size() > 1) {
+        // [ 1     8     4     4]
+        // [ 8     8    17    17]
+        // [21    21    21    17]
+        // [27    27    27    22]
+        // [32    32    27    35]
+      for (int i = 0; i < 20 * num * channels; i += 20) {
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  0],  0);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  1],  7);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  2],  3);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  3],  3);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  4],  7);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  5],  7);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  6], 16);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  7], 16);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  8], 20);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  9], 20);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 10], 20);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 11], 16);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 12], 26);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 13], 26);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 14], 26);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 15], 21);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 16], 31);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 17], 31);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 18], 26);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 19], 34);
+      }
+    }
+  }
 };
 
 typedef ::testing::Types<float, double> Dtypes;
@@ -207,24 +457,32 @@ TYPED_TEST(PoolingLayerTest, PrintCPUBackward) {
 
 TYPED_TEST(PoolingLayerTest, TestCPUForwardMax) {
   Caffe::set_mode(Caffe::CPU);
-  this->TestForward();
+  this->TestForwardSquare();
+  this->TestForwardRectHigh();
+  this->TestForwardRectWide();
 }
 
 TYPED_TEST(PoolingLayerTest, TestGPUForwardMax) {
   Caffe::set_mode(Caffe::GPU);
-  this->TestForward();
+  this->TestForwardSquare();
+  this->TestForwardRectHigh();
+  this->TestForwardRectWide();
 }
 
 TYPED_TEST(PoolingLayerTest, TestCPUForwardMaxTopMask) {
   Caffe::set_mode(Caffe::CPU);
   this->blob_top_vec_.push_back(this->blob_top_mask_);
-  this->TestForward();
+  this->TestForwardSquare();
+  this->TestForwardRectHigh();
+  this->TestForwardRectWide();
 }
 
 TYPED_TEST(PoolingLayerTest, TestGPUForwardMaxTopMask) {
   Caffe::set_mode(Caffe::GPU);
   this->blob_top_vec_.push_back(this->blob_top_mask_);
-  this->TestForward();
+  this->TestForwardSquare();
+  this->TestForwardRectHigh();
+  this->TestForwardRectWide();
 }
 
 TYPED_TEST(PoolingLayerTest, TestCPUGradientMax) {