add padding for average pooling
authorJeff Donahue <jeff.donahue@gmail.com>
Sat, 29 Mar 2014 03:18:06 +0000 (20:18 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Tue, 8 Apr 2014 18:36:18 +0000 (11:36 -0700)
src/caffe/layers/pooling_layer.cpp
src/caffe/layers/pooling_layer.cu
src/caffe/proto/caffe.proto
src/caffe/test/test_pooling_layer.cpp

index a186741232febe1a444047b90a278037aa62a073..7e880a27b692d79d26ed19ce3fd4eb1cd3a48b5c 100644 (file)
@@ -20,13 +20,19 @@ void PoolingLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   CHECK_EQ(top->size(), 1) << "PoolingLayer takes a single blob as output.";
   kernel_size_ = this->layer_param_.pooling_param().kernel_size();
   stride_ = this->layer_param_.pooling_param().stride();
+  pad_ = this->layer_param_.pooling_param().pad();
+  if (pad_ != 0) {
+    CHECK_EQ(this->layer_param_.pooling_param().pool(),
+             PoolingParameter_PoolMethod_AVE)
+        << "Padding implemented only for average pooling.";
+  }
   channels_ = bottom[0]->channels();
   height_ = bottom[0]->height();
   width_ = bottom[0]->width();
-  pooled_height_ = static_cast<int>(
-      ceil(static_cast<float>(height_ - kernel_size_) / stride_)) + 1;
-  pooled_width_ = static_cast<int>(
-      ceil(static_cast<float>(width_ - kernel_size_) / stride_)) + 1;
+  pooled_height_ = static_cast<int>(ceil(static_cast<float>(
+      height_ + 2 * pad_ - kernel_size_) / stride_)) + 1;
+  pooled_width_ = static_cast<int>(ceil(static_cast<float>(
+      width_ + 2 * pad_ - kernel_size_) / stride_)) + 1;
   (*top)[0]->Reshape(bottom[0]->num(), channels_, pooled_height_,
       pooled_width_);
   // If stochastic pooling, we will initialize the random index part.
@@ -86,18 +92,22 @@ Dtype PoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       for (int c = 0; c < channels_; ++c) {
         for (int ph = 0; ph < pooled_height_; ++ph) {
           for (int pw = 0; pw < pooled_width_; ++pw) {
-            int hstart = ph * stride_;
-            int wstart = pw * stride_;
-            int hend = min(hstart + kernel_size_, height_);
-            int wend = min(wstart + kernel_size_, width_);
+            int hstart = ph * stride_ - pad_;
+            int wstart = pw * stride_ - pad_;
+            int hend = min(hstart + kernel_size_, height_ + pad_);
+            int wend = min(wstart + kernel_size_, width_ + pad_);
+            int pool_size = (hend - hstart) * (wend - wstart);
+            hstart = max(hstart, 0);
+            wstart = max(wstart, 0);
+            hend = min(hend, height_);
+            wend = min(wend, width_);
             for (int h = hstart; h < hend; ++h) {
               for (int w = wstart; w < wend; ++w) {
                 top_data[ph * pooled_width_ + pw] +=
                     bottom_data[h * width_ + w];
               }
             }
-            top_data[ph * pooled_width_ + pw] /=
-                (hend - hstart) * (wend - wstart);
+            top_data[ph * pooled_width_ + pw] /= pool_size;
           }
         }
         // compute offset
@@ -163,15 +173,19 @@ void PoolingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
       for (int c = 0; c < channels_; ++c) {
         for (int ph = 0; ph < pooled_height_; ++ph) {
           for (int pw = 0; pw < pooled_width_; ++pw) {
-            int hstart = ph * stride_;
-            int wstart = pw * stride_;
-            int hend = min(hstart + kernel_size_, height_);
-            int wend = min(wstart + kernel_size_, width_);
-            int poolsize = (hend - hstart) * (wend - wstart);
+            int hstart = ph * stride_ - pad_;
+            int wstart = pw * stride_ - pad_;
+            int hend = min(hstart + kernel_size_, height_ + pad_);
+            int wend = min(wstart + kernel_size_, width_ + pad_);
+            int pool_size = (hend - hstart) * (wend - wstart);
+            hstart = max(hstart, 0);
+            wstart = max(wstart, 0);
+            hend = min(hend, height_);
+            wend = min(wend, width_);
             for (int h = hstart; h < hend; ++h) {
               for (int w = wstart; w < wend; ++w) {
                 bottom_diff[h * width_ + w] +=
-                  top_diff[ph * pooled_width_ + pw] / poolsize;
+                  top_diff[ph * pooled_width_ + pw] / pool_size;
               }
             }
           }
index 7adf348be34e992e8a6805f18f52d17130eb69ac..74150df249337bd3975e00195e9729509cd2dc8f 100644 (file)
@@ -16,13 +16,13 @@ namespace caffe {
 template <typename Dtype>
 __global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
     const int num, const int channels, const int height,
-    const int width, const int pooled_height_, const int pooled_width,
+    const int width, const int pooled_height, const int pooled_width,
     const int kernel_size, const int stride, Dtype* top_data) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     int pw = index % pooled_width;
-    int ph = (index / pooled_width) % pooled_height_;
-    int c = (index / pooled_width / pooled_height_) % channels;
-    int n = index / pooled_width / pooled_height_ / channels;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
     int hstart = ph * stride;
     int hend = min(hstart + kernel_size, height);
     int wstart = pw * stride;
@@ -41,17 +41,22 @@ __global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
 template <typename Dtype>
 __global__ void AvePoolForward(const int nthreads, const Dtype* bottom_data,
     const int num, const int channels, const int height,
-    const int width, const int pooled_height_, const int pooled_width,
-    const int kernel_size, const int stride, Dtype* top_data) {
+    const int width, const int pooled_height, const int pooled_width,
+    const int kernel_size, const int stride, const int pad, Dtype* top_data) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     int pw = index % pooled_width;
-    int ph = (index / pooled_width) % pooled_height_;
-    int c = (index / pooled_width / pooled_height_) % channels;
-    int n = index / pooled_width / pooled_height_ / channels;
-    int hstart = ph * stride;
-    int hend = min(hstart + kernel_size, height);
-    int wstart = pw * stride;
-    int wend = min(wstart + kernel_size, width);
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+    int hstart = ph * stride - pad;
+    int wstart = pw * stride - pad;
+    int hend = min(hstart + kernel_size, height + pad);
+    int wend = min(wstart + kernel_size, width + pad);
+    int pool_size = (hend - hstart) * (wend - wstart);
+    hstart = max(hstart, 0);
+    wstart = max(wstart, 0);
+    hend = min(hend, height);
+    wend = min(wend, width);
     Dtype aveval = 0;
     bottom_data += (n * channels + c) * height * width;
     for (int h = hstart; h < hend; ++h) {
@@ -59,7 +64,7 @@ __global__ void AvePoolForward(const int nthreads, const Dtype* bottom_data,
         aveval += bottom_data[h * width + w];
       }
     }
-    top_data[index] = aveval / (hend - hstart) / (wend - wstart);
+    top_data[index] = aveval / pool_size;
   }
 }
 
@@ -67,13 +72,13 @@ template <typename Dtype>
 __global__ void StoPoolForwardTrain(const int nthreads,
     const Dtype* bottom_data,
     const int num, const int channels, const int height,
-    const int width, const int pooled_height_, const int pooled_width,
+    const int width, const int pooled_height, const int pooled_width,
     const int kernel_size, const int stride, float* rand_idx, Dtype* top_data) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     int pw = index % pooled_width;
-    int ph = (index / pooled_width) % pooled_height_;
-    int c = (index / pooled_width / pooled_height_) % channels;
-    int n = index / pooled_width / pooled_height_ / channels;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
     int hstart = ph * stride;
     int hend = min(hstart + kernel_size, height);
     int wstart = pw * stride;
@@ -107,13 +112,13 @@ template <typename Dtype>
 __global__ void StoPoolForwardTest(const int nthreads,
     const Dtype* bottom_data,
     const int num, const int channels, const int height,
-    const int width, const int pooled_height_, const int pooled_width,
+    const int width, const int pooled_height, const int pooled_width,
     const int kernel_size, const int stride, Dtype* top_data) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     int pw = index % pooled_width;
-    int ph = (index / pooled_width) % pooled_height_;
-    int c = (index / pooled_width / pooled_height_) % channels;
-    int n = index / pooled_width / pooled_height_ / channels;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
     int hstart = ph * stride;
     int hend = min(hstart + kernel_size, height);
     int wstart = pw * stride;
@@ -153,7 +158,7 @@ Dtype PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     AvePoolForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
         count, bottom_data, bottom[0]->num(), channels_,
         height_, width_, pooled_height_, pooled_width_, kernel_size_, stride_,
-        top_data);
+        pad_, top_data);
     break;
   case PoolingParameter_PoolMethod_STOCHASTIC:
     if (Caffe::phase() == Caffe::TRAIN) {
@@ -186,7 +191,7 @@ template <typename Dtype>
 __global__ void MaxPoolBackward(const int nthreads, const Dtype* bottom_data,
     const Dtype* top_data, const Dtype* top_diff,
     const int num, const int channels, const int height,
-    const int width, const int pooled_height_, const int pooled_width,
+    const int width, const int pooled_height, const int pooled_width,
     const int kernel_size, const int stride, Dtype* bottom_diff) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     // find out the local index
@@ -196,14 +201,14 @@ __global__ void MaxPoolBackward(const int nthreads, const Dtype* bottom_data,
     int c = (index / width / height) % channels;
     int n = index / width / height / channels;
     int phstart = (h < kernel_size) ? 0 : (h - kernel_size) / stride + 1;
-    int phend = min(h / stride + 1, pooled_height_);
+    int phend = min(h / stride + 1, pooled_height);
     int pwstart = (w < kernel_size) ? 0 : (w - kernel_size) / stride + 1;
     int pwend = min(w / stride + 1, pooled_width);
     Dtype gradient = 0;
     Dtype bottom_datum =
         bottom_data[((n * channels + c) * height + h) * width + w];
-    top_data += (n * channels + c) * pooled_height_ * pooled_width;
-    top_diff += (n * channels + c) * pooled_height_ * pooled_width;
+    top_data += (n * channels + c) * pooled_height * pooled_width;
+    top_diff += (n * channels + c) * pooled_height * pooled_width;
     for (int ph = phstart; ph < phend; ++ph) {
       for (int pw = pwstart; pw < pwend; ++pw) {
         gradient += top_diff[ph * pooled_width + pw] *
@@ -218,27 +223,31 @@ __global__ void MaxPoolBackward(const int nthreads, const Dtype* bottom_data,
 template <typename Dtype>
 __global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff,
     const int num, const int channels, const int height,
-    const int width, const int pooled_height_, const int pooled_width,
-    const int kernel_size, const int stride, Dtype* bottom_diff) {
+    const int width, const int pooled_height, const int pooled_width,
+    const int kernel_size, const int stride, const int pad,
+    Dtype* bottom_diff) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     // find out the local index
     // find out the local offset
-    int w = index % width;
-    int h = (index / width) % height;
+    int w = index % width + pad;
+    int h = (index / width) % height + pad;
     int c = (index / width / height) % channels;
     int n = index / width / height / channels;
     int phstart = (h < kernel_size) ? 0 : (h - kernel_size) / stride + 1;
-    int phend = min(h / stride + 1, pooled_height_);
+    int phend = min(h / stride + 1, pooled_height);
     int pwstart = (w < kernel_size) ? 0 : (w - kernel_size) / stride + 1;
     int pwend = min(w / stride + 1, pooled_width);
     Dtype gradient = 0;
-    top_diff += (n * channels + c) * pooled_height_ * pooled_width;
+    top_diff += (n * channels + c) * pooled_height * pooled_width;
     for (int ph = phstart; ph < phend; ++ph) {
       for (int pw = pwstart; pw < pwend; ++pw) {
         // figure out the pooling size
-        int poolsize = (min(ph * stride + kernel_size, height) - ph * stride) *
-            (min(pw * stride + kernel_size, width) - pw * stride);
-        gradient += top_diff[ph * pooled_width + pw] / poolsize;
+        int hstart = ph * stride - pad;
+        int wstart = pw * stride - pad;
+        int hend = min(hstart + kernel_size, height + pad);
+        int wend = min(wstart + kernel_size, width + pad);
+        int pool_size = (hend - hstart) * (wend - wstart);
+        gradient += top_diff[ph * pooled_width + pw] / pool_size;
       }
     }
     bottom_diff[index] = gradient;
@@ -250,7 +259,7 @@ template <typename Dtype>
 __global__ void StoPoolBackward(const int nthreads,
     const float* rand_idx, const Dtype* top_diff,
     const int num, const int channels, const int height,
-    const int width, const int pooled_height_, const int pooled_width,
+    const int width, const int pooled_height, const int pooled_width,
     const int kernel_size, const int stride, Dtype* bottom_diff) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     // find out the local index
@@ -260,12 +269,12 @@ __global__ void StoPoolBackward(const int nthreads,
     int c = (index / width / height) % channels;
     int n = index / width / height / channels;
     int phstart = (h < kernel_size) ? 0 : (h - kernel_size) / stride + 1;
-    int phend = min(h / stride + 1, pooled_height_);
+    int phend = min(h / stride + 1, pooled_height);
     int pwstart = (w < kernel_size) ? 0 : (w - kernel_size) / stride + 1;
     int pwend = min(w / stride + 1, pooled_width);
     Dtype gradient = 0;
-    rand_idx += (n * channels + c) * pooled_height_ * pooled_width;
-    top_diff += (n * channels + c) * pooled_height_ * pooled_width;
+    rand_idx += (n * channels + c) * pooled_height * pooled_width;
+    top_diff += (n * channels + c) * pooled_height * pooled_width;
     for (int ph = phstart; ph < phend; ++ph) {
       for (int pw = pwstart; pw < pwend; ++pw) {
         gradient += top_diff[ph * pooled_width + pw] *
@@ -299,7 +308,7 @@ void PoolingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     AvePoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
         count, top_diff, top[0]->num(), channels_,
         height_, width_, pooled_height_, pooled_width_, kernel_size_, stride_,
-        bottom_diff);
+        pad_, bottom_diff);
     break;
   case PoolingParameter_PoolMethod_STOCHASTIC:
     // NOLINT_NEXT_LINE(whitespace/operators)
index e0011483f6bac7bcd44752fbd72676635cb3b999..54f2743ee4f2a84876f5c38141ffda35dcf72013 100644 (file)
@@ -290,7 +290,8 @@ message PoolingParameter {
   optional PoolMethod pool = 1 [default = MAX]; // The pooling method
   optional uint32 kernel_size = 2; // The kernel size
   optional uint32 stride = 3 [default = 1]; // The stride
-  optional uint32 pad = 4 [default = 0]; // The padding size
+  // The padding size -- currently implemented only for average pooling.
+  optional uint32 pad = 4 [default = 0];
 }
 
 // Message that stores parameters used by PowerLayer
index d1246a098c82ac5e49ed23e79c289bf860eb4b53..a57110491ec9be1467f2b55b4532dd6a55bf02d8 100644 (file)
@@ -56,6 +56,21 @@ TYPED_TEST(PoolingLayerTest, TestSetup) {
   EXPECT_EQ(this->blob_top_->width(), 2);
 }
 
+TYPED_TEST(PoolingLayerTest, TestSetupPadded) {
+  LayerParameter layer_param;
+  PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
+  pooling_param->set_kernel_size(3);
+  pooling_param->set_stride(2);
+  pooling_param->set_pad(1);
+  pooling_param->set_pool(PoolingParameter_PoolMethod_AVE);
+  PoolingLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_->num());
+  EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_->channels());
+  EXPECT_EQ(this->blob_top_->height(), 4);
+  EXPECT_EQ(this->blob_top_->width(), 3);
+}
+
 /*
 TYPED_TEST(PoolingLayerTest, PrintGPUBackward) {
   LayerParameter layer_param;
@@ -111,6 +126,72 @@ TYPED_TEST(PoolingLayerTest, TestGPUGradientMax) {
 }
 
 
+TYPED_TEST(PoolingLayerTest, TestCPUForwardAve) {
+  LayerParameter layer_param;
+  PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
+  pooling_param->set_kernel_size(3);
+  pooling_param->set_stride(1);
+  pooling_param->set_pad(1);
+  pooling_param->set_pool(PoolingParameter_PoolMethod_AVE);
+  Caffe::set_mode(Caffe::CPU);
+  this->blob_bottom_->Reshape(1, 1, 3, 3);
+  FillerParameter filler_param;
+  filler_param.set_value(TypeParam(2));
+  ConstantFiller<TypeParam> filler(filler_param);
+  filler.Fill(this->blob_bottom_);
+  PoolingLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  EXPECT_EQ(this->blob_top_->num(), 1);
+  EXPECT_EQ(this->blob_top_->channels(), 1);
+  EXPECT_EQ(this->blob_top_->height(), 3);
+  EXPECT_EQ(this->blob_top_->width(), 3);
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  TypeParam epsilon = 1e-5;
+  EXPECT_NEAR(this->blob_top_->cpu_data()[0], 8.0 / 9, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[1], 4.0 / 3, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[2], 8.0 / 9, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[3], 4.0 / 3, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[4], 2.0, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[5], 4.0 / 3, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[6], 8.0 / 9, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[7], 4.0 / 3, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[8], 8.0 / 9, epsilon);
+}
+
+
+TYPED_TEST(PoolingLayerTest, TestGPUForwardAve) {
+  LayerParameter layer_param;
+  PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
+  pooling_param->set_kernel_size(3);
+  pooling_param->set_stride(1);
+  pooling_param->set_pad(1);
+  pooling_param->set_pool(PoolingParameter_PoolMethod_AVE);
+  Caffe::set_mode(Caffe::GPU);
+  this->blob_bottom_->Reshape(1, 1, 3, 3);
+  FillerParameter filler_param;
+  filler_param.set_value(TypeParam(2));
+  ConstantFiller<TypeParam> filler(filler_param);
+  filler.Fill(this->blob_bottom_);
+  PoolingLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  EXPECT_EQ(this->blob_top_->num(), 1);
+  EXPECT_EQ(this->blob_top_->channels(), 1);
+  EXPECT_EQ(this->blob_top_->height(), 3);
+  EXPECT_EQ(this->blob_top_->width(), 3);
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  TypeParam epsilon = 1e-5;
+  EXPECT_NEAR(this->blob_top_->cpu_data()[0], 8.0 / 9, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[1], 4.0 / 3, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[2], 8.0 / 9, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[3], 4.0 / 3, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[4], 2.0, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[5], 4.0 / 3, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[6], 8.0 / 9, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[7], 4.0 / 3, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[8], 8.0 / 9, epsilon);
+}
+
+
 TYPED_TEST(PoolingLayerTest, TestCPUGradientAve) {
   LayerParameter layer_param;
   PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
@@ -139,4 +220,34 @@ TYPED_TEST(PoolingLayerTest, TestGPUGradientAve) {
 }
 
 
+TYPED_TEST(PoolingLayerTest, TestCPUGradientAvePadded) {
+  LayerParameter layer_param;
+  PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
+  pooling_param->set_kernel_size(3);
+  pooling_param->set_stride(2);
+  pooling_param->set_pad(2);
+  pooling_param->set_pool(PoolingParameter_PoolMethod_AVE);
+  Caffe::set_mode(Caffe::CPU);
+  PoolingLayer<TypeParam> layer(layer_param);
+  GradientChecker<TypeParam> checker(1e-2, 1e-2);
+  checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
+      &(this->blob_top_vec_));
+}
+
+
+TYPED_TEST(PoolingLayerTest, TestGPUGradientAvePadded) {
+  LayerParameter layer_param;
+  PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
+  pooling_param->set_kernel_size(3);
+  pooling_param->set_stride(2);
+  pooling_param->set_pad(2);
+  pooling_param->set_pool(PoolingParameter_PoolMethod_AVE);
+  Caffe::set_mode(Caffe::GPU);
+  PoolingLayer<TypeParam> layer(layer_param);
+  GradientChecker<TypeParam> checker(1e-2, 1e-2);
+  checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
+      &(this->blob_top_vec_));
+}
+
+
 }  // namespace caffe