Update pooling_layer.cu
authorRonghang Hu <huronghang@hotmail.com>
Fri, 4 Jul 2014 04:05:23 +0000 (21:05 -0700)
committerRonghang Hu <huronghang@hotmail.com>
Fri, 4 Jul 2014 04:05:23 +0000 (21:05 -0700)
Replace pad_, kernel_size_, stride_ with pad_h_, pad_w_, kernel_size_h_, kernel_size_w_, stride_h_, stride_w_ to support pooling on rectangle regions.

src/caffe/layers/pooling_layer.cu

index f07fe3c..3ac74f8 100644 (file)
@@ -17,17 +17,18 @@ template <typename Dtype>
 __global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
     const int num, const int channels, const int height,
     const int width, const int pooled_height, const int pooled_width,
-    const int kernel_size, const int stride, const int pad, Dtype* top_data,
+    const int kernel_size_h, const int kernel_size_w, const int stride_h, 
+    const int stride_w, const int pad_h, const int pad_w, Dtype* top_data,
     int* mask, Dtype* top_mask) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     int pw = index % pooled_width;
     int ph = (index / pooled_width) % pooled_height;
     int c = (index / pooled_width / pooled_height) % channels;
     int n = index / pooled_width / pooled_height / channels;
-    int hstart = ph * stride - pad;
-    int wstart = pw * stride - pad;
-    int hend = min(hstart + kernel_size, height);
-    int wend = min(wstart + kernel_size, width);
+    int hstart = ph * stride_h - pad_h;
+    int wstart = pw * stride_w - pad_w;
+    int hend = min(hstart + kernel_size_h, height);
+    int wend = min(wstart + kernel_size_w, width);
     hstart = max(hstart, 0);
     wstart = max(wstart, 0);
     Dtype maxval = -FLT_MAX;
@@ -54,16 +55,17 @@ template <typename Dtype>
 __global__ void AvePoolForward(const int nthreads, const Dtype* bottom_data,
     const int num, const int channels, const int height,
     const int width, const int pooled_height, const int pooled_width,
-    const int kernel_size, const int stride, const int pad, Dtype* top_data) {
+    const int kernel_size_h, const int kernel_size_w, const int stride_h, 
+    const int stride_w, const int pad_h, const int pad_w, Dtype* top_data) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     int pw = index % pooled_width;
     int ph = (index / pooled_width) % pooled_height;
     int c = (index / pooled_width / pooled_height) % channels;
     int n = index / pooled_width / pooled_height / channels;
-    int hstart = ph * stride - pad;
-    int wstart = pw * stride - pad;
-    int hend = min(hstart + kernel_size, height + pad);
-    int wend = min(wstart + kernel_size, width + pad);
+    int hstart = ph * stride_h - pad_h;
+    int wstart = pw * stride_w - pad_w;
+    int hend = min(hstart + kernel_size_h, height + pad_h);
+    int wend = min(wstart + kernel_size_w, width + pad_w);
     int pool_size = (hend - hstart) * (wend - wstart);
     hstart = max(hstart, 0);
     wstart = max(wstart, 0);
@@ -85,16 +87,17 @@ __global__ void StoPoolForwardTrain(const int nthreads,
     const Dtype* bottom_data,
     const int num, const int channels, const int height,
     const int width, const int pooled_height, const int pooled_width,
-    const int kernel_size, const int stride, Dtype* rand_idx, Dtype* top_data) {
+    const int kernel_size_h, const int kernel_size_w, const int stride_h, 
+    const int stride_w, Dtype* rand_idx, Dtype* top_data) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     int pw = index % pooled_width;
     int ph = (index / pooled_width) % pooled_height;
     int c = (index / pooled_width / pooled_height) % channels;
     int n = index / pooled_width / pooled_height / channels;
-    int hstart = ph * stride;
-    int hend = min(hstart + kernel_size, height);
-    int wstart = pw * stride;
-    int wend = min(wstart + kernel_size, width);
+    int hstart = ph * stride_h;
+    int hend = min(hstart + kernel_size_h, height);
+    int wstart = pw * stride_w;
+    int wend = min(wstart + kernel_size_w, width);
     Dtype cumsum = 0.;
     bottom_data += (n * channels + c) * height * width;
     // First pass: get sum
@@ -125,16 +128,17 @@ __global__ void StoPoolForwardTest(const int nthreads,
     const Dtype* bottom_data,
     const int num, const int channels, const int height,
     const int width, const int pooled_height, const int pooled_width,
-    const int kernel_size, const int stride, Dtype* top_data) {
+    const int kernel_size_h, const int kernel_size_w, const int stride_h, 
+    const int stride_w, Dtype* top_data) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     int pw = index % pooled_width;
     int ph = (index / pooled_width) % pooled_height;
     int c = (index / pooled_width / pooled_height) % channels;
     int n = index / pooled_width / pooled_height / channels;
-    int hstart = ph * stride;
-    int hend = min(hstart + kernel_size, height);
-    int wstart = pw * stride;
-    int wend = min(wstart + kernel_size, width);
+    int hstart = ph * stride_h;
+    int hend = min(hstart + kernel_size_h, height);
+    int wstart = pw * stride_w;
+    int wend = min(wstart + kernel_size_w, width);
     // We set cumsum to be 0 to avoid divide-by-zero problems
     Dtype cumsum = FLT_MIN;
     Dtype cumvalues = 0.;
@@ -171,15 +175,16 @@ Dtype PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     // NOLINT_NEXT_LINE(whitespace/operators)
     MaxPoolForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
         count, bottom_data, bottom[0]->num(), channels_,
-        height_, width_, pooled_height_, pooled_width_, kernel_size_, stride_,
-        pad_, top_data, mask, top_mask);
+        height_, width_, pooled_height_, pooled_width_, kernel_size_h_, 
+        kernel_size_w_, stride_h_, stride_w_, pad_h_, pad_w_, top_data, 
+        mask, top_mask);
     break;
   case PoolingParameter_PoolMethod_AVE:
     // NOLINT_NEXT_LINE(whitespace/operators)
     AvePoolForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
         count, bottom_data, bottom[0]->num(), channels_,
-        height_, width_, pooled_height_, pooled_width_, kernel_size_, stride_,
-        pad_, top_data);
+        height_, width_, pooled_height_, pooled_width_, kernel_size_h_, 
+        kernel_size_w_, stride_h_, stride_w_, pad_h_, pad_w_, top_data);
     break;
   case PoolingParameter_PoolMethod_STOCHASTIC:
     if (Caffe::phase() == Caffe::TRAIN) {
@@ -190,15 +195,16 @@ Dtype PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       StoPoolForwardTrain<Dtype><<<CAFFE_GET_BLOCKS(count),
                                    CAFFE_CUDA_NUM_THREADS>>>(
           count, bottom_data, bottom[0]->num(), channels_,
-          height_, width_, pooled_height_, pooled_width_, kernel_size_, stride_,
+          height_, width_, pooled_height_, pooled_width_, kernel_size_h_, 
+          kernel_size_w_, stride_h_, stride_w_,
           rand_idx_.mutable_gpu_data(), top_data);
     } else {
       // NOLINT_NEXT_LINE(whitespace/operators)
       StoPoolForwardTest<Dtype><<<CAFFE_GET_BLOCKS(count),
                                   CAFFE_CUDA_NUM_THREADS>>>(
           count, bottom_data, bottom[0]->num(), channels_,
-          height_, width_, pooled_height_, pooled_width_, kernel_size_, stride_,
-          top_data);
+          height_, width_, pooled_height_, pooled_width_, kernel_size_h_, 
+          kernel_size_w_, stride_h_, stride_w_, top_data);
     }
     break;
   default:
@@ -213,8 +219,9 @@ template <typename Dtype>
 __global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
     const int* mask, const Dtype* top_mask, const int num, const int channels,
     const int height, const int width, const int pooled_height,
-    const int pooled_width, const int kernel_size, const int stride,
-    const int pad, Dtype* bottom_diff) {
+    const int pooled_width, const int kernel_size_h, const int kernel_size_w, 
+    const int stride_h, const int stride_w, const int pad_h, const int pad_w, 
+    Dtype* bottom_diff) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     // find out the local index
     // find out the local offset
@@ -223,11 +230,11 @@ __global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
     int c = (index / width / height) % channels;
     int n = index / width / height / channels;
     int phstart =
-        (h + pad < kernel_size) ? 0 : (h + pad - kernel_size) / stride + 1;
-    int phend = min((h + pad) / stride + 1, pooled_height);
+        (h + pad_h < kernel_size_h) ? 0 : (h + pad_h - kernel_size_h) / stride_h + 1;
+    int phend = min((h + pad_h) / stride_h + 1, pooled_height);
     int pwstart =
-        (w + pad < kernel_size) ? 0 : (w + pad - kernel_size) / stride + 1;
-    int pwend = min((w + pad) / stride + 1, pooled_width);
+        (w + pad_w < kernel_size_w) ? 0 : (w + pad_w - kernel_size_w) / stride_w + 1;
+    int pwend = min((w + pad_w) / stride_w + 1, pooled_width);
     Dtype gradient = 0;
     int offset = (n * channels + c) * pooled_height * pooled_width;
     top_diff += offset;
@@ -258,28 +265,29 @@ template <typename Dtype>
 __global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff,
     const int num, const int channels, const int height,
     const int width, const int pooled_height, const int pooled_width,
-    const int kernel_size, const int stride, const int pad,
+    const int kernel_size_h, const int kernel_size_w, const int stride_h, 
+    const int stride_w, const int pad_h, const int pad_w, 
     Dtype* bottom_diff) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     // find out the local index
     // find out the local offset
-    int w = index % width + pad;
-    int h = (index / width) % height + pad;
+    int w = index % width + pad_w;
+    int h = (index / width) % height + pad_h;
     int c = (index / width / height) % channels;
     int n = index / width / height / channels;
-    int phstart = (h < kernel_size) ? 0 : (h - kernel_size) / stride + 1;
-    int phend = min(h / stride + 1, pooled_height);
-    int pwstart = (w < kernel_size) ? 0 : (w - kernel_size) / stride + 1;
-    int pwend = min(w / stride + 1, pooled_width);
+    int phstart = (h < kernel_size_h) ? 0 : (h - kernel_size_h) / stride_h + 1;
+    int phend = min(h / stride_h + 1, pooled_height);
+    int pwstart = (w < kernel_size_w) ? 0 : (w - kernel_size_w) / stride_w + 1;
+    int pwend = min(w / stride_w + 1, pooled_width);
     Dtype gradient = 0;
     top_diff += (n * channels + c) * pooled_height * pooled_width;
     for (int ph = phstart; ph < phend; ++ph) {
       for (int pw = pwstart; pw < pwend; ++pw) {
         // figure out the pooling size
-        int hstart = ph * stride - pad;
-        int wstart = pw * stride - pad;
-        int hend = min(hstart + kernel_size, height + pad);
-        int wend = min(wstart + kernel_size, width + pad);
+        int hstart = ph * stride_h - pad_h;
+        int wstart = pw * stride_w - pad_w;
+        int hend = min(hstart + kernel_size_h, height + pad_h);
+        int wend = min(wstart + kernel_size_w, width + pad_w);
         int pool_size = (hend - hstart) * (wend - wstart);
         gradient += top_diff[ph * pooled_width + pw] / pool_size;
       }
@@ -294,7 +302,8 @@ __global__ void StoPoolBackward(const int nthreads,
     const Dtype* rand_idx, const Dtype* top_diff,
     const int num, const int channels, const int height,
     const int width, const int pooled_height, const int pooled_width,
-    const int kernel_size, const int stride, Dtype* bottom_diff) {
+    const int kernel_size_h, const int kernel_size_w, const int stride_h, 
+    const int stride_w, Dtype* bottom_diff) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     // find out the local index
     // find out the local offset
@@ -302,10 +311,10 @@ __global__ void StoPoolBackward(const int nthreads,
     int h = (index / width) % height;
     int c = (index / width / height) % channels;
     int n = index / width / height / channels;
-    int phstart = (h < kernel_size) ? 0 : (h - kernel_size) / stride + 1;
-    int phend = min(h / stride + 1, pooled_height);
-    int pwstart = (w < kernel_size) ? 0 : (w - kernel_size) / stride + 1;
-    int pwend = min(w / stride + 1, pooled_width);
+    int phstart = (h < kernel_size_h) ? 0 : (h - kernel_size_h) / stride_h + 1;
+    int phend = min(h / stride_h + 1, pooled_height);
+    int pwstart = (w < kernel_size_w) ? 0 : (w - kernel_size_w) / stride_w + 1;
+    int pwend = min(w / stride_w + 1, pooled_width);
     Dtype gradient = 0;
     rand_idx += (n * channels + c) * pooled_height * pooled_width;
     top_diff += (n * channels + c) * pooled_height * pooled_width;
@@ -345,21 +354,23 @@ void PoolingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     MaxPoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
         count, top_diff, mask, top_mask, top[0]->num(), channels_,
         height_, width_, pooled_height_, pooled_width_,
-        kernel_size_, stride_, pad_, bottom_diff);
+        kernel_size_h_, kernel_size_w_, stride_h_, stride_w_, pad_h_, pad_w_, 
+        bottom_diff);
     break;
   case PoolingParameter_PoolMethod_AVE:
     // NOLINT_NEXT_LINE(whitespace/operators)
     AvePoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
         count, top_diff, top[0]->num(), channels_,
-        height_, width_, pooled_height_, pooled_width_, kernel_size_, stride_,
-        pad_, bottom_diff);
+        height_, width_, pooled_height_, pooled_width_, kernel_size_h_, 
+        kernel_size_w_, stride_h_, stride_w_, pad_h_, pad_w_, bottom_diff);
     break;
   case PoolingParameter_PoolMethod_STOCHASTIC:
     // NOLINT_NEXT_LINE(whitespace/operators)
     StoPoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
         count, rand_idx_.gpu_data(), top_diff,
         top[0]->num(), channels_, height_, width_, pooled_height_,
-        pooled_width_, kernel_size_, stride_, bottom_diff);
+        pooled_width_, kernel_size_h_, kernel_size_w_, stride_h_, stride_w_, 
+        bottom_diff);
     break;
   default:
     LOG(FATAL) << "Unknown pooling method.";