avoid dangerous state in LRN layer CUDA kernels
author    Jonathan L Long <jonlong@cs.berkeley.edu>    Wed, 20 May 2015 05:52:12 +0000 (22:52 -0700)
committer Jonathan L Long <jonlong@cs.berkeley.edu>    Wed, 20 May 2015 06:41:06 +0000 (23:41 -0700)
src/caffe/layers/lrn_layer.cu

index 24aa6a3..e50ae8d 100644
@@ -18,8 +18,8 @@ __global__ void LRNFillScale(const int nthreads, const Dtype* in,
     int n = index / width / height;
     int offset = (n * channels * height + h) * width + w;
     int step = height * width;
-    in += offset;
-    scale += offset;
+    const Dtype* const in_off = in + offset;
+    Dtype* const scale_off = scale + offset;
     int head = 0;
     int pre_pad = (size - 1) / 2;
     int post_pad = size - pre_pad - 1;
@@ -27,24 +27,26 @@ __global__ void LRNFillScale(const int nthreads, const Dtype* in,
     // fill the scale at [n, :, h, w]
     // accumulate values
     while (head < post_pad && head < channels) {
-      accum_scale += in[head * step] * in[head * step];
+      accum_scale += in_off[head * step] * in_off[head * step];
       ++head;
     }
     // both add and subtract
     while (head < channels) {
-      accum_scale += in[head * step] * in[head * step];
+      accum_scale += in_off[head * step] * in_off[head * step];
       if (head - size >= 0) {
-        accum_scale -= in[(head - size) * step] * in[(head - size) * step];
+        accum_scale -= in_off[(head - size) * step]
+                       * in_off[(head - size) * step];
       }
-      scale[(head - post_pad) * step] = k + accum_scale * alpha_over_size;
+      scale_off[(head - post_pad) * step] = k + accum_scale * alpha_over_size;
       ++head;
     }
     // subtract only
     while (head < channels + post_pad) {
       if (head - size >= 0) {
-        accum_scale -= in[(head - size) * step] * in[(head - size) * step];
+        accum_scale -= in_off[(head - size) * step]
+                       * in_off[(head - size) * step];
       }
-      scale[(head - post_pad) * step] = k + accum_scale * alpha_over_size;
+      scale_off[(head - post_pad) * step] = k + accum_scale * alpha_over_size;
       ++head;
     }
   }
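
The hunk above only renames state; the sliding-window arithmetic is unchanged. For reference, here is a minimal CPU sketch of what LRNFillScale computes along one (n, h, w) column of channels; fill_scale_column is an illustrative name, not Caffe code, and the channel stride is taken as 1 since the sketch walks a contiguous vector rather than strided device memory.

    #include <vector>

    // CPU sketch of LRNFillScale's per-column recurrence: scale[c] is
    // k + alpha_over_size times the sum of squares of in[] over a window
    // of `size` channels around c, maintained incrementally as the window
    // slides. `scale` must have in.size() elements.
    void fill_scale_column(const std::vector<float>& in,
                           std::vector<float>& scale,
                           int size, float k, float alpha_over_size) {
      const int channels = static_cast<int>(in.size());
      const int pre_pad = (size - 1) / 2;
      const int post_pad = size - pre_pad - 1;
      int head = 0;
      float accum_scale = 0.f;
      // accumulate values ahead of the first output channel
      while (head < post_pad && head < channels) {
        accum_scale += in[head] * in[head];
        ++head;
      }
      // both add and subtract: the window gains one channel, drops one
      while (head < channels) {
        accum_scale += in[head] * in[head];
        if (head - size >= 0) {
          accum_scale -= in[head - size] * in[head - size];
        }
        scale[head - post_pad] = k + accum_scale * alpha_over_size;
        ++head;
      }
      // subtract only: drain the trailing edge of the window
      while (head < channels + post_pad) {
        if (head - size >= 0) {
          accum_scale -= in[head - size] * in[head - size];
        }
        scale[head - post_pad] = k + accum_scale * alpha_over_size;
        ++head;
      }
    }
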
@@ -131,43 +133,45 @@ __global__ void LRNComputeDiff(const int nthreads, const Dtype* bottom_data,
     int n = index / width / height;
     int offset = (n * channels * height + h) * width + w;
     int step = height * width;
-    bottom_data += offset;
-    top_data += offset;
-    scale += offset;
-    top_diff += offset;
-    bottom_diff += offset;
+    const Dtype* const bottom_off = bottom_data + offset;
+    const Dtype* const top_off = top_data + offset;
+    const Dtype* const scale_off = scale + offset;
+    const Dtype* const top_diff_off = top_diff + offset;
+    Dtype* const bottom_diff_off = bottom_diff + offset;
     int head = 0;
     int pre_pad = size - (size + 1) / 2;
     int post_pad = size - pre_pad - 1;
     Dtype accum_ratio = 0;
     // accumulate values
     while (head < post_pad && head < channels) {
-      accum_ratio += top_diff[head * step] * top_data[head * step] /
-          scale[head * step];
+      accum_ratio += top_diff_off[head * step] * top_off[head * step] /
+          scale_off[head * step];
       ++head;
     }
     // both add and subtract
     while (head < channels) {
-      accum_ratio += top_diff[head * step] * top_data[head * step] /
-          scale[head * step];
+      accum_ratio += top_diff_off[head * step] * top_off[head * step] /
+          scale_off[head * step];
       if (head - size >= 0) {
-        accum_ratio -= top_diff[(head - size) * step] *
-            top_data[(head - size) * step] / scale[(head - size) * step];
+        accum_ratio -= top_diff_off[(head - size) * step] *
+            top_off[(head - size) * step] / scale_off[(head - size) * step];
       }
-      bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step]
-          * pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
-          bottom_data[(head - post_pad) * step] * accum_ratio;
+      bottom_diff_off[(head - post_pad) * step] =
+          top_diff_off[(head - post_pad) * step]
+            * pow(scale_off[(head - post_pad) * step], negative_beta)
+          - cache_ratio * bottom_off[(head - post_pad) * step] * accum_ratio;
       ++head;
     }
     // subtract only
     while (head < channels + post_pad) {
       if (head - size >= 0) {
-        accum_ratio -= top_diff[(head - size) * step] *
-            top_data[(head - size) * step] / scale[(head - size) * step];
+        accum_ratio -= top_diff_off[(head - size) * step] *
+            top_off[(head - size) * step] / scale_off[(head - size) * step];
       }
-      bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step]
-          * pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
-          bottom_data[(head - post_pad) * step] * accum_ratio;
+      bottom_diff_off[(head - post_pad) * step] =
+          top_diff_off[(head - post_pad) * step]
+            * pow(scale_off[(head - post_pad) * step], negative_beta)
+          - cache_ratio * bottom_off[(head - post_pad) * step] * accum_ratio;
       ++head;
     }
   }
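
For context on the commit title: Caffe's CUDA_KERNEL_LOOP is a grid-stride loop, so a single thread may execute the loop body for several indices. Bumping the kernel's pointer parameters (in += offset; and friends) inside that body leaves them displaced when the next index comes around, compounding the offsets; taking fresh const local pointers per index, as this commit does, leaves the parameters untouched. Below is a minimal standalone sketch of the hazard; buggy_kernel and fixed_kernel are illustrative names, not Caffe code.

    #include <cstdio>
    #include <cuda_runtime.h>

    // Grid-stride loop, matching Caffe's CUDA_KERNEL_LOOP macro.
    #define KERNEL_LOOP(i, n) \
      for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
           i < (n); \
           i += blockDim.x * gridDim.x)

    // Dangerous state: `in` is advanced in place, so the second index a
    // thread handles reads from base + offset_1 + offset_2 instead of
    // base + offset_2.
    __global__ void buggy_kernel(const int n, const float* in, float* out) {
      KERNEL_LOOP(index, n) {
        const int offset = index;  // stand-in for the LRN offset arithmetic
        in += offset;              // persists into the next loop iteration
        out[index] = in[0];
      }
    }

    // The commit's pattern: a fresh local pointer per index; the parameters
    // are never mutated, so each iteration starts from a clean slate.
    __global__ void fixed_kernel(const int n, const float* in, float* out) {
      KERNEL_LOOP(index, n) {
        const int offset = index;
        const float* const in_off = in + offset;
        out[index] = in_off[0];
      }
    }

    int main() {
      const int n = 8;
      // Oversized input so the buggy kernel's stray reads stay in bounds.
      const int n_in = 4 * n;
      float h_in[n_in], h_out[n];
      for (int i = 0; i < n_in; ++i) { h_in[i] = static_cast<float>(i); }
      float *d_in, *d_out;
      cudaMalloc(&d_in, n_in * sizeof(float));
      cudaMalloc(&d_out, n * sizeof(float));
      cudaMemcpy(d_in, h_in, n_in * sizeof(float), cudaMemcpyHostToDevice);
      // Two threads for eight indices: each thread loops four times, so
      // the buggy kernel's pointer bumps accumulate across iterations.
      buggy_kernel<<<1, 2>>>(n, d_in, d_out);
      cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
      for (int i = 0; i < n; ++i) { printf("buggy out[%d] = %g\n", i, h_out[i]); }
      fixed_kernel<<<1, 2>>>(n, d_in, d_out);
      cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
      for (int i = 0; i < n; ++i) { printf("fixed out[%d] = %g\n", i, h_out[i]); }
      cudaFree(d_in);
      cudaFree(d_out);
      return 0;
    }

With one block of two threads covering eight indices, the fixed kernel reproduces its input (out[i] = i) while the buggy one drifts (e.g. out[4] = 6, out[7] = 16). When the grid is large enough that every index gets its own thread, the two kernels agree; Caffe's usual launch parameters cover every index that way, which is why the old code worked in practice and the mutation was latent danger rather than an active bug.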