Use loops in GPU again to avoid over-writting of bottom_diff
authorSergio <sguada@gmail.com>
Mon, 14 Apr 2014 03:21:15 +0000 (20:21 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Sat, 24 May 2014 22:15:11 +0000 (15:15 -0700)
src/caffe/layers/pooling_layer.cu

index 8be1521..e3d7546 100644 (file)
@@ -12,7 +12,7 @@ using std::max;
 using std::min;
 
 namespace caffe {
-
 template <typename Dtype>
 __global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
     const int num, const int channels, const int height,
@@ -195,40 +195,6 @@ Dtype PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   return Dtype(0.);
 }
 
-// template <typename Dtype>
-// __global__ void MaxPoolBackward(const int nthreads, const Dtype* bottom_data,
-//     const Dtype* top_data, const Dtype* top_diff,
-//     const int num, const int channels, const int height,
-//     const int width, const int pooled_height, const int pooled_width,
-//     const int ksize, const int stride, Dtype* bottom_diff, int* mask) {
-//   int index = threadIdx.x + blockIdx.x * blockDim.x;
-//   if (index < nthreads) {
-//     // find out the local index
-//     // find out the local offset
-//     int w = index % width;
-//     int h = (index / width) % height;
-//     int c = (index / width / height) % channels;
-//     int n = index / width / height / channels;
-//     int phstart = (h < ksize) ? 0 : (h - ksize) / stride + 1;
-//     int phend = min(h / stride + 1, pooled_height);
-//     int pwstart = (w < ksize) ? 0 : (w - ksize) / stride + 1;
-//     int pwend = min(w / stride + 1, pooled_width);
-//     Dtype gradient = 0;
-//     Dtype bottom_datum =
-//         bottom_data[((n * channels + c) * height + h) * width + w];
-//     top_data += (n * channels + c) * pooled_height * pooled_width;
-//     top_diff += (n * channels + c) * pooled_height * pooled_width;
-//     //bottom_diff[index] += top_diff[mask[index]];
-//     for (int ph = phstart; ph < phend; ++ph) {
-//       for (int pw = pwstart; pw < pwend; ++pw) {
-//         gradient += top_diff[ph * pooled_width + pw] *
-//             (bottom_datum == top_data[ph * pooled_width + pw]);
-//       }
-//     }
-//     bottom_diff[index] = gradient;
-//   }  // (if index < nthreads)
-// }
-
 template <typename Dtype>
 __global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
     const int num, const int channels, const int height,
@@ -237,13 +203,43 @@ __global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
   CUDA_KERNEL_LOOP(index, nthreads) {
     // find out the local index
     // find out the local offset
-    int c = (index / pooled_width / pooled_height) % channels;
-    int n = index / pooled_width / pooled_height / channels;
-    bottom_diff += (n * channels + c) * height * width;
-    bottom_diff[mask[index]] += top_diff[index];
+    int w = index % width;
+    int h = (index / width) % height;
+    int c = (index / width / height) % channels;
+    int n = index / width / height / channels;
+    int phstart = (h < ksize) ? 0 : (h - ksize) / stride + 1;
+    int phend = min(h / stride + 1, pooled_height);
+    int pwstart = (w < ksize) ? 0 : (w - ksize) / stride + 1;
+    int pwend = min(w / stride + 1, pooled_width);
+    Dtype gradient = 0;
+    top_diff += (n * channels + c) * pooled_height * pooled_width;
+    mask += (n * channels + c) * pooled_height * pooled_width;
+    //bottom_diff[index] += top_diff[mask[index]];
+    for (int ph = phstart; ph < phend; ++ph) {
+      for (int pw = pwstart; pw < pwend; ++pw) {
+        if (mask[ph * pooled_width + pw] == h * width + w)
+          gradient += top_diff[ph * pooled_width + pw];
+      }
+    }
+    bottom_diff[index] = gradient;
   }  
 }
 
+// template <typename Dtype>
+// __global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
+//     const int num, const int channels, const int height,
+//     const int width, const int pooled_height, const int pooled_width,
+//     const int ksize, const int stride, Dtype* bottom_diff, int* mask) {
+//   CUDA_KERNEL_LOOP(index, nthreads) {
+//     // find out the local index
+//     // find out the local offset
+//     int c = (index / pooled_width / pooled_height) % channels;
+//     int n = index / pooled_width / pooled_height / channels;
+//     bottom_diff += (n * channels + c) * height * width;
+//     bottom_diff[mask[index]] += top_diff[index];
+//   } 
+// }
+
 template <typename Dtype>
 __global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff,
     const int num, const int channels, const int height,