From 7a8fcc763dc89e717a58319a77da9d9813a829c0 Mon Sep 17 00:00:00 2001 From: Jonathan L Long Date: Tue, 19 May 2015 22:41:50 -0700 Subject: [PATCH] avoid dangerous state in pooling layer CUDA kernels Previously, pointers were modified with the assumption that they would only be modified once. While this is true so far in practice, the introduction of CUDA_KERNEL_LOOP makes this a dangerous assumption. --- src/caffe/layers/pooling_layer.cu | 57 ++++++++++++++++++++++----------------- 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/src/caffe/layers/pooling_layer.cu b/src/caffe/layers/pooling_layer.cu index d1d4850..a1080eb 100644 --- a/src/caffe/layers/pooling_layer.cu +++ b/src/caffe/layers/pooling_layer.cu @@ -28,12 +28,13 @@ __global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data, wstart = max(wstart, 0); Dtype maxval = -FLT_MAX; int maxidx = -1; - bottom_data += (n * channels + c) * height * width; + const Dtype* const bottom_slice = + bottom_data + (n * channels + c) * height * width; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - if (bottom_data[h * width + w] > maxval) { + if (bottom_slice[h * width + w] > maxval) { maxidx = h * width + w; - maxval = bottom_data[maxidx]; + maxval = bottom_slice[maxidx]; } } } @@ -67,10 +68,11 @@ __global__ void AvePoolForward(const int nthreads, const Dtype* bottom_data, hend = min(hend, height); wend = min(wend, width); Dtype aveval = 0; - bottom_data += (n * channels + c) * height * width; + const Dtype* const bottom_slice = + bottom_data + (n * channels + c) * height * width; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - aveval += bottom_data[h * width + w]; + aveval += bottom_slice[h * width + w]; } } top_data[index] = aveval / pool_size; @@ -94,11 +96,12 @@ __global__ void StoPoolForwardTrain(const int nthreads, int wstart = pw * stride_w; int wend = min(wstart + kernel_w, width); Dtype cumsum = 0.; - bottom_data += (n * channels + c) * height * width; + const Dtype* const bottom_slice = + bottom_data + (n * channels + c) * height * width; // First pass: get sum for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - cumsum += bottom_data[h * width + w]; + cumsum += bottom_slice[h * width + w]; } } float thres = rand_idx[index] * cumsum; @@ -106,10 +109,10 @@ __global__ void StoPoolForwardTrain(const int nthreads, cumsum = 0; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - cumsum += bottom_data[h * width + w]; + cumsum += bottom_slice[h * width + w]; if (cumsum >= thres) { rand_idx[index] = ((n * channels + c) * height + h) * width + w; - top_data[index] = bottom_data[h * width + w]; + top_data[index] = bottom_slice[h * width + w]; return; } } @@ -137,12 +140,13 @@ __global__ void StoPoolForwardTest(const int nthreads, // We set cumsum to be 0 to avoid divide-by-zero problems Dtype cumsum = FLT_MIN; Dtype cumvalues = 0.; - bottom_data += (n * channels + c) * height * width; + const Dtype* const bottom_slice = + bottom_data + (n * channels + c) * height * width; // First pass: get sum for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - cumsum += bottom_data[h * width + w]; - cumvalues += bottom_data[h * width + w] * bottom_data[h * width + w]; + cumsum += bottom_slice[h * width + w]; + cumvalues += bottom_slice[h * width + w] * bottom_slice[h * width + w]; } } top_data[index] = cumvalues / cumsum; @@ -231,22 +235,22 @@ __global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff, int pwend = min((w + pad_w) / stride_w + 1, pooled_width); Dtype gradient = 0; int offset = (n * channels + c) * pooled_height * pooled_width; - top_diff += offset; + const Dtype* const top_diff_slice = top_diff + offset; if (mask) { - mask += offset; + const int* const mask_slice = mask + offset; for (int ph = phstart; ph < phend; ++ph) { for (int pw = pwstart; pw < pwend; ++pw) { - if (mask[ph * pooled_width + pw] == h * width + w) { - gradient += top_diff[ph * pooled_width + pw]; + if (mask_slice[ph * pooled_width + pw] == h * width + w) { + gradient += top_diff_slice[ph * pooled_width + pw]; } } } } else { - top_mask += offset; + const Dtype* const top_mask_slice = top_mask + offset; for (int ph = phstart; ph < phend; ++ph) { for (int pw = pwstart; pw < pwend; ++pw) { - if (top_mask[ph * pooled_width + pw] == h * width + w) { - gradient += top_diff[ph * pooled_width + pw]; + if (top_mask_slice[ph * pooled_width + pw] == h * width + w) { + gradient += top_diff_slice[ph * pooled_width + pw]; } } } @@ -274,7 +278,8 @@ __global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff, int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1; int pwend = min(w / stride_w + 1, pooled_width); Dtype gradient = 0; - top_diff += (n * channels + c) * pooled_height * pooled_width; + const Dtype* const top_diff_slice = + top_diff + (n * channels + c) * pooled_height * pooled_width; for (int ph = phstart; ph < phend; ++ph) { for (int pw = pwstart; pw < pwend; ++pw) { // figure out the pooling size @@ -283,7 +288,7 @@ __global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff, int hend = min(hstart + kernel_h, height + pad_h); int wend = min(wstart + kernel_w, width + pad_w); int pool_size = (hend - hstart) * (wend - wstart); - gradient += top_diff[ph * pooled_width + pw] / pool_size; + gradient += top_diff_slice[ph * pooled_width + pw] / pool_size; } } bottom_diff[index] = gradient; @@ -310,12 +315,14 @@ __global__ void StoPoolBackward(const int nthreads, int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1; int pwend = min(w / stride_w + 1, pooled_width); Dtype gradient = 0; - rand_idx += (n * channels + c) * pooled_height * pooled_width; - top_diff += (n * channels + c) * pooled_height * pooled_width; + const Dtype* const rand_idx_slice = + rand_idx + (n * channels + c) * pooled_height * pooled_width; + const Dtype* const top_diff_slice = + top_diff + (n * channels + c) * pooled_height * pooled_width; for (int ph = phstart; ph < phend; ++ph) { for (int pw = pwstart; pw < pwend; ++pw) { - gradient += top_diff[ph * pooled_width + pw] * - (index == static_cast(rand_idx[ph * pooled_width + pw])); + gradient += top_diff_slice[ph * pooled_width + pw] * + (index == static_cast(rand_idx_slice[ph * pooled_width + pw])); } } bottom_diff[index] = gradient; -- 2.7.4