more const in pooling layer CUDA kernels
authorJonathan L Long <jonlong@cs.berkeley.edu>
Wed, 20 May 2015 05:59:23 +0000 (22:59 -0700)
committerJonathan L Long <jonlong@cs.berkeley.edu>
Wed, 20 May 2015 06:41:06 +0000 (23:41 -0700)
This treats pointer arguments in the same way as non-pointer arguments,
and should help to avoid issues like the previous dangerous state issue.

src/caffe/layers/pooling_layer.cu

index a1080eb..ca4b13f 100644 (file)
@@ -9,21 +9,21 @@
 namespace caffe {
 
 template <typename Dtype>
-__global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
-    const int num, const int channels, const int height,
-    const int width, const int pooled_height, const int pooled_width,
-    const int kernel_h, const int kernel_w, const int stride_h,
-    const int stride_w, const int pad_h, const int pad_w, Dtype* top_data,
-    int* mask, Dtype* top_mask) {
+__global__ void MaxPoolForward(const int nthreads,
+    const Dtype* const bottom_data, const int num, const int channels,
+    const int height, const int width, const int pooled_height,
+    const int pooled_width, const int kernel_h, const int kernel_w,
+    const int stride_h, const int stride_w, const int pad_h, const int pad_w,
+    Dtype* const top_data, int* mask, Dtype* top_mask) {
   CUDA_KERNEL_LOOP(index, nthreads) {
-    int pw = index % pooled_width;
-    int ph = (index / pooled_width) % pooled_height;
-    int c = (index / pooled_width / pooled_height) % channels;
-    int n = index / pooled_width / pooled_height / channels;
+    const int pw = index % pooled_width;
+    const int ph = (index / pooled_width) % pooled_height;
+    const int c = (index / pooled_width / pooled_height) % channels;
+    const int n = index / pooled_width / pooled_height / channels;
     int hstart = ph * stride_h - pad_h;
     int wstart = pw * stride_w - pad_w;
-    int hend = min(hstart + kernel_h, height);
-    int wend = min(wstart + kernel_w, width);
+    const int hend = min(hstart + kernel_h, height);
+    const int wend = min(wstart + kernel_w, width);
     hstart = max(hstart, 0);
     wstart = max(wstart, 0);
     Dtype maxval = -FLT_MAX;
@@ -48,21 +48,22 @@ __global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
 }
 
 template <typename Dtype>
-__global__ void AvePoolForward(const int nthreads, const Dtype* bottom_data,
-    const int num, const int channels, const int height,
-    const int width, const int pooled_height, const int pooled_width,
-    const int kernel_h, const int kernel_w, const int stride_h,
-    const int stride_w, const int pad_h, const int pad_w, Dtype* top_data) {
+__global__ void AvePoolForward(const int nthreads,
+    const Dtype* const bottom_data, const int num, const int channels,
+    const int height, const int width, const int pooled_height,
+    const int pooled_width, const int kernel_h, const int kernel_w,
+    const int stride_h, const int stride_w, const int pad_h, const int pad_w,
+    Dtype* const top_data) {
   CUDA_KERNEL_LOOP(index, nthreads) {
-    int pw = index % pooled_width;
-    int ph = (index / pooled_width) % pooled_height;
-    int c = (index / pooled_width / pooled_height) % channels;
-    int n = index / pooled_width / pooled_height / channels;
+    const int pw = index % pooled_width;
+    const int ph = (index / pooled_width) % pooled_height;
+    const int c = (index / pooled_width / pooled_height) % channels;
+    const int n = index / pooled_width / pooled_height / channels;
     int hstart = ph * stride_h - pad_h;
     int wstart = pw * stride_w - pad_w;
     int hend = min(hstart + kernel_h, height + pad_h);
     int wend = min(wstart + kernel_w, width + pad_w);
-    int pool_size = (hend - hstart) * (wend - wstart);
+    const int pool_size = (hend - hstart) * (wend - wstart);
     hstart = max(hstart, 0);
     wstart = max(wstart, 0);
     hend = min(hend, height);
@@ -81,20 +82,20 @@ __global__ void AvePoolForward(const int nthreads, const Dtype* bottom_data,
 
 template <typename Dtype>
 __global__ void StoPoolForwardTrain(const int nthreads,
-    const Dtype* bottom_data,
+    const Dtype* const bottom_data,
     const int num, const int channels, const int height,
     const int width, const int pooled_height, const int pooled_width,
     const int kernel_h, const int kernel_w, const int stride_h,
-    const int stride_w, Dtype* rand_idx, Dtype* top_data) {
+    const int stride_w, Dtype* const rand_idx, Dtype* const top_data) {
   CUDA_KERNEL_LOOP(index, nthreads) {
-    int pw = index % pooled_width;
-    int ph = (index / pooled_width) % pooled_height;
-    int c = (index / pooled_width / pooled_height) % channels;
-    int n = index / pooled_width / pooled_height / channels;
-    int hstart = ph * stride_h;
-    int hend = min(hstart + kernel_h, height);
-    int wstart = pw * stride_w;
-    int wend = min(wstart + kernel_w, width);
+    const int pw = index % pooled_width;
+    const int ph = (index / pooled_width) % pooled_height;
+    const int c = (index / pooled_width / pooled_height) % channels;
+    const int n = index / pooled_width / pooled_height / channels;
+    const int hstart = ph * stride_h;
+    const int hend = min(hstart + kernel_h, height);
+    const int wstart = pw * stride_w;
+    const int wend = min(wstart + kernel_w, width);
     Dtype cumsum = 0.;
     const Dtype* const bottom_slice =
         bottom_data + (n * channels + c) * height * width;
@@ -104,7 +105,7 @@ __global__ void StoPoolForwardTrain(const int nthreads,
         cumsum += bottom_slice[h * width + w];
       }
     }
-    float thres = rand_idx[index] * cumsum;
+    const float thres = rand_idx[index] * cumsum;
     // Second pass: get value, and set index.
     cumsum = 0;
     for (int h = hstart; h < hend; ++h) {
@@ -123,20 +124,20 @@ __global__ void StoPoolForwardTrain(const int nthreads,
 
 template <typename Dtype>
 __global__ void StoPoolForwardTest(const int nthreads,
-    const Dtype* bottom_data,
+    const Dtype* const bottom_data,
     const int num, const int channels, const int height,
     const int width, const int pooled_height, const int pooled_width,
     const int kernel_h, const int kernel_w, const int stride_h,
-    const int stride_w, Dtype* top_data) {
+    const int stride_w, Dtype* const top_data) {
   CUDA_KERNEL_LOOP(index, nthreads) {
-    int pw = index % pooled_width;
-    int ph = (index / pooled_width) % pooled_height;
-    int c = (index / pooled_width / pooled_height) % channels;
-    int n = index / pooled_width / pooled_height / channels;
-    int hstart = ph * stride_h;
-    int hend = min(hstart + kernel_h, height);
-    int wstart = pw * stride_w;
-    int wend = min(wstart + kernel_w, width);
+    const int pw = index % pooled_width;
+    const int ph = (index / pooled_width) % pooled_height;
+    const int c = (index / pooled_width / pooled_height) % channels;
+    const int n = index / pooled_width / pooled_height / channels;
+    const int hstart = ph * stride_h;
+    const int hend = min(hstart + kernel_h, height);
+    const int wstart = pw * stride_w;
+    const int wend = min(wstart + kernel_w, width);
     // We set cumsum to be 0 to avoid divide-by-zero problems
     Dtype cumsum = FLT_MIN;
     Dtype cumvalues = 0.;
@@ -214,27 +215,27 @@ void PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
 
 
 template <typename Dtype>
-__global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
-    const int* mask, const Dtype* top_mask, const int num, const int channels,
-    const int height, const int width, const int pooled_height,
-    const int pooled_width, const int kernel_h, const int kernel_w,
-    const int stride_h, const int stride_w, const int pad_h, const int pad_w,
-    Dtype* bottom_diff) {
+__global__ void MaxPoolBackward(const int nthreads, const Dtype* const top_diff,
+    const int* const mask, const Dtype* const top_mask, const int num,
+    const int channels, const int height, const int width,
+    const int pooled_height, const int pooled_width, const int kernel_h,
+    const int kernel_w, const int stride_h, const int stride_w, const int pad_h,
+    const int pad_w, Dtype* const bottom_diff) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     // find out the local index
     // find out the local offset
-    int w = index % width;
-    int h = (index / width) % height;
-    int c = (index / width / height) % channels;
-    int n = index / width / height / channels;
-    int phstart =
-        (h + pad_h < kernel_h) ? 0 : (h + pad_h - kernel_h) / stride_h + 1;
-    int phend = min((h + pad_h) / stride_h + 1, pooled_height);
-    int pwstart =
-        (w + pad_w < kernel_w) ? 0 : (w + pad_w - kernel_w) / stride_w + 1;
-    int pwend = min((w + pad_w) / stride_w + 1, pooled_width);
+    const int w = index % width;
+    const int h = (index / width) % height;
+    const int c = (index / width / height) % channels;
+    const int n = index / width / height / channels;
+    const int phstart =
+         (h + pad_h < kernel_h) ? 0 : (h + pad_h - kernel_h) / stride_h + 1;
+    const int phend = min((h + pad_h) / stride_h + 1, pooled_height);
+    const int pwstart =
+         (w + pad_w < kernel_w) ? 0 : (w + pad_w - kernel_w) / stride_w + 1;
+    const int pwend = min((w + pad_w) / stride_w + 1, pooled_width);
     Dtype gradient = 0;
-    int offset = (n * channels + c) * pooled_height * pooled_width;
+    const int offset = (n * channels + c) * pooled_height * pooled_width;
     const Dtype* const top_diff_slice = top_diff + offset;
     if (mask) {
       const int* const mask_slice = mask + offset;
@@ -260,23 +261,23 @@ __global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
 }
 
 template <typename Dtype>
-__global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff,
+__global__ void AvePoolBackward(const int nthreads, const Dtype* const top_diff,
     const int num, const int channels, const int height,
     const int width, const int pooled_height, const int pooled_width,
     const int kernel_h, const int kernel_w, const int stride_h,
     const int stride_w, const int pad_h, const int pad_w,
-    Dtype* bottom_diff) {
+    Dtype* const bottom_diff) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     // find out the local index
     // find out the local offset
-    int w = index % width + pad_w;
-    int h = (index / width) % height + pad_h;
-    int c = (index / width / height) % channels;
-    int n = index / width / height / channels;
-    int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
-    int phend = min(h / stride_h + 1, pooled_height);
-    int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
-    int pwend = min(w / stride_w + 1, pooled_width);
+    const int w = index % width + pad_w;
+    const int h = (index / width) % height + pad_h;
+    const int c = (index / width / height) % channels;
+    const int n = index / width / height / channels;
+    const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
+    const int phend = min(h / stride_h + 1, pooled_height);
+    const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
+    const int pwend = min(w / stride_w + 1, pooled_width);
     Dtype gradient = 0;
     const Dtype* const top_diff_slice =
         top_diff + (n * channels + c) * pooled_height * pooled_width;
@@ -298,22 +299,22 @@ __global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff,
 
 template <typename Dtype>
 __global__ void StoPoolBackward(const int nthreads,
-    const Dtype* rand_idx, const Dtype* top_diff,
+    const Dtype* const rand_idx, const Dtype* const top_diff,
     const int num, const int channels, const int height,
     const int width, const int pooled_height, const int pooled_width,
     const int kernel_h, const int kernel_w, const int stride_h,
-    const int stride_w, Dtype* bottom_diff) {
+    const int stride_w, Dtype* const bottom_diff) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     // find out the local index
     // find out the local offset
-    int w = index % width;
-    int h = (index / width) % height;
-    int c = (index / width / height) % channels;
-    int n = index / width / height / channels;
-    int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
-    int phend = min(h / stride_h + 1, pooled_height);
-    int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
-    int pwend = min(w / stride_w + 1, pooled_width);
+    const int w = index % width;
+    const int h = (index / width) % height;
+    const int c = (index / width / height) % channels;
+    const int n = index / width / height / channels;
+    const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
+    const int phend = min(h / stride_h + 1, pooled_height);
+    const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
+    const int pwend = min(w / stride_w + 1, pooled_width);
     Dtype gradient = 0;
     const Dtype* const rand_idx_slice =
         rand_idx + (n * channels + c) * pooled_height * pooled_width;