From 8f5448be12ae4df1a9bf4250183a4756733abfa9 Mon Sep 17 00:00:00 2001
From: Sergio
Date: Sun, 13 Apr 2014 20:21:15 -0700
Subject: [PATCH] Use loops in GPU again to avoid over-writing of bottom_diff

---
 src/caffe/layers/pooling_layer.cu | 74 ++++++++++++++++++---------------------
 1 file changed, 35 insertions(+), 39 deletions(-)

diff --git a/src/caffe/layers/pooling_layer.cu b/src/caffe/layers/pooling_layer.cu
index 8be1521..e3d7546 100644
--- a/src/caffe/layers/pooling_layer.cu
+++ b/src/caffe/layers/pooling_layer.cu
@@ -12,7 +12,7 @@ using std::max;
 using std::min;
 
 namespace caffe {
-
+
 template <typename Dtype>
 __global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
     const int num, const int channels, const int height,
@@ -195,40 +195,6 @@ Dtype PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   return Dtype(0.);
 }
 
-// template <typename Dtype>
-// __global__ void MaxPoolBackward(const int nthreads, const Dtype* bottom_data,
-//     const Dtype* top_data, const Dtype* top_diff,
-//     const int num, const int channels, const int height,
-//     const int width, const int pooled_height, const int pooled_width,
-//     const int ksize, const int stride, Dtype* bottom_diff, int* mask) {
-//   int index = threadIdx.x + blockIdx.x * blockDim.x;
-//   if (index < nthreads) {
-//     // find out the local index
-//     // find out the local offset
-//     int w = index % width;
-//     int h = (index / width) % height;
-//     int c = (index / width / height) % channels;
-//     int n = index / width / height / channels;
-//     int phstart = (h < ksize) ? 0 : (h - ksize) / stride + 1;
-//     int phend = min(h / stride + 1, pooled_height);
-//     int pwstart = (w < ksize) ? 0 : (w - ksize) / stride + 1;
-//     int pwend = min(w / stride + 1, pooled_width);
-//     Dtype gradient = 0;
-//     Dtype bottom_datum =
-//         bottom_data[((n * channels + c) * height + h) * width + w];
-//     top_data += (n * channels + c) * pooled_height * pooled_width;
-//     top_diff += (n * channels + c) * pooled_height * pooled_width;
-//     //bottom_diff[index] += top_diff[mask[index]];
-//     for (int ph = phstart; ph < phend; ++ph) {
-//       for (int pw = pwstart; pw < pwend; ++pw) {
-//         gradient += top_diff[ph * pooled_width + pw] *
-//             (bottom_datum == top_data[ph * pooled_width + pw]);
-//       }
-//     }
-//     bottom_diff[index] = gradient;
-//   }  // (if index < nthreads)
-// }
-
 template <typename Dtype>
 __global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
     const int num, const int channels, const int height,
@@ -237,13 +203,43 @@ __global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
   CUDA_KERNEL_LOOP(index, nthreads) {
     // find out the local index
     // find out the local offset
-    int c = (index / pooled_width / pooled_height) % channels;
-    int n = index / pooled_width / pooled_height / channels;
-    bottom_diff += (n * channels + c) * height * width;
-    bottom_diff[mask[index]] += top_diff[index];
+    int w = index % width;
+    int h = (index / width) % height;
+    int c = (index / width / height) % channels;
+    int n = index / width / height / channels;
+    int phstart = (h < ksize) ? 0 : (h - ksize) / stride + 1;
+    int phend = min(h / stride + 1, pooled_height);
+    int pwstart = (w < ksize) ? 0 : (w - ksize) / stride + 1;
+    int pwend = min(w / stride + 1, pooled_width);
+    Dtype gradient = 0;
+    top_diff += (n * channels + c) * pooled_height * pooled_width;
+    mask += (n * channels + c) * pooled_height * pooled_width;
+    //bottom_diff[index] += top_diff[mask[index]];
+    for (int ph = phstart; ph < phend; ++ph) {
+      for (int pw = pwstart; pw < pwend; ++pw) {
+        if (mask[ph * pooled_width + pw] == h * width + w)
+          gradient += top_diff[ph * pooled_width + pw];
+      }
+    }
+    bottom_diff[index] = gradient;
   }
 }
 
+// template <typename Dtype>
+// __global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
+//     const int num, const int channels, const int height,
+//     const int width, const int pooled_height, const int pooled_width,
+//     const int ksize, const int stride, Dtype* bottom_diff, int* mask) {
+//   CUDA_KERNEL_LOOP(index, nthreads) {
+//     // find out the local index
+//     // find out the local offset
+//     int c = (index / pooled_width / pooled_height) % channels;
+//     int n = index / pooled_width / pooled_height / channels;
+//     bottom_diff += (n * channels + c) * height * width;
+//     bottom_diff[mask[index]] += top_diff[index];
+//   }
+// }
+
 template <typename Dtype>
 __global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff,
     const int num, const int channels, const int height,
-- 
2.7.4
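
Note on the change: the scatter kernel this patch comments out gives one thread to each pooled output and does bottom_diff[mask[index]] += top_diff[index]. When pooling windows overlap (stride < ksize), two outputs can share the same argmax, so two threads read-modify-write the same bottom_diff slot without synchronization and one update can be lost. The restored gather kernel gives one thread to each bottom element and accumulates into a private local variable, so every bottom_diff slot has exactly one writer. The standalone 1-D sketch below illustrates the two patterns; it is not Caffe code, and the kernel names, sizes, and test values are illustrative.

// race_demo.cu -- minimal sketch, not Caffe code. Contrasts the racy
// scatter backward pass with the race-free gather backward pass for a
// 1-D max pool with overlapping windows (ksize = 2, stride = 1).
#include <cstdio>

// Scatter form (the pattern the patch reverts): one thread per pooled
// output. Overlapping windows can map two outputs to the same bottom
// index, so the unsynchronized "+=" can lose an update.
__global__ void ScatterBackward(const float* top_diff, const int* mask,
    const int pooled_n, float* bottom_diff) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < pooled_n)
    bottom_diff[mask[i]] += top_diff[i];  // data race across threads
}

// Gather form (the pattern the patch restores): one thread per bottom
// element loops over every pooled output whose window covers it and
// accumulates locally, so each thread is the sole writer of its slot.
__global__ void GatherBackward(const float* top_diff, const int* mask,
    const int n, const int pooled_n, const int ksize, const int stride,
    float* bottom_diff) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  int pstart = (i < ksize) ? 0 : (i - ksize) / stride + 1;
  int pend = min(i / stride + 1, pooled_n);
  float gradient = 0;
  for (int p = pstart; p < pend; ++p) {
    if (mask[p] == i)
      gradient += top_diff[p];
  }
  bottom_diff[i] = gradient;  // plain store, no race
}

int main() {
  // bottom = [0, 3, 2, 5], windows {0,1}, {1,2}, {2,3};
  // argmax mask = [1, 1, 3]: windows 0 and 1 both pick bottom index 1.
  const int n = 4, pooled_n = 3;
  float top_diff[pooled_n] = {1.f, 1.f, 1.f};
  int mask[pooled_n] = {1, 1, 3};
  float* d_td; int* d_m; float* d_bd;
  cudaMalloc(&d_td, sizeof(top_diff));
  cudaMalloc(&d_m, sizeof(mask));
  cudaMalloc(&d_bd, n * sizeof(float));
  cudaMemcpy(d_td, top_diff, sizeof(top_diff), cudaMemcpyHostToDevice);
  cudaMemcpy(d_m, mask, sizeof(mask), cudaMemcpyHostToDevice);
  GatherBackward<<<1, n>>>(d_td, d_m, n, pooled_n, 2, 1, d_bd);
  float bd[n];
  cudaMemcpy(bd, d_bd, sizeof(bd), cudaMemcpyDeviceToHost);
  // Prints 0 2 0 1: both window gradients reach bottom index 1. The
  // scatter kernel can instead yield 0 1 0 1 when the two "+=" on
  // index 1 interleave.
  printf("%g %g %g %g\n", bd[0], bd[1], bd[2], bd[3]);
  return 0;
}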