From: Yangqing Jia
Date: Mon, 23 Sep 2013 21:42:24 +0000 (-0700)
Subject: pooling layer
X-Git-Tag: submit/tizen/20180823.020014~1018
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=65b44b62d206b2fa04c31b1958c46b59cffacdf7;p=platform%2Fupstream%2Fcaffeonacl.git

pooling layer
---

diff --git a/src/caffe/layers/pooling_layer.cpp b/src/caffe/layers/pooling_layer.cpp
index d7e6b4f..88a8f41 100644
--- a/src/caffe/layers/pooling_layer.cpp
+++ b/src/caffe/layers/pooling_layer.cpp
@@ -4,13 +4,13 @@
 #include "caffe/vision_layers.hpp"
 #include "caffe/util/math_functions.hpp"
 
+#define CAFFE_MAX_POOLING_THRESHOLD 1e-8f
+
 using std::max;
 using std::min;
 
 namespace caffe {
 
-const float CAFFE_MAX_POOLING_THRESHOLD = 1e-8;
-
 template <typename Dtype>
 void PoolingLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
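Annotation (illustration only, not part of the patch): every kernel in the new pooling_layer.cu below launches one thread per output element and decodes the flat thread index into (n, c, ph, pw) coordinates by repeated division and modulo. A minimal standalone sketch of that arithmetic, with made-up dimensions:

    #include <cstdio>

    int main() {
      // Hypothetical dimensions, for illustration only.
      const int num = 2, channels = 3, pooled_height = 3, pooled_width = 2;
      const int count = num * channels * pooled_height * pooled_width;
      for (int index = 0; index < count; ++index) {
        // The same decomposition each CUDA thread performs on its own index.
        int pw = index % pooled_width;
        int ph = (index / pooled_width) % pooled_height;
        int c = (index / pooled_width / pooled_height) % channels;
        int n = index / pooled_width / pooled_height / channels;
        printf("index %2d -> (n=%d, c=%d, ph=%d, pw=%d)\n", index, n, c, ph, pw);
      }
      return 0;
    }

The layout is row-major with pw varying fastest, so consecutive threads handle consecutive output elements and the final top_data[index] writes coalesce.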
diff --git a/src/caffe/layers/pooling_layer.cu b/src/caffe/layers/pooling_layer.cu
new file mode 100644
index 0000000..2086869
--- /dev/null
+++ b/src/caffe/layers/pooling_layer.cu
@@ -0,0 +1,187 @@
+#include <algorithm>
+#include <cfloat>
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+
+#define CAFFE_MAX_POOLING_THRESHOLD 1e-8f
+
+using std::max;
+using std::min;
+
+namespace caffe {
+
+template <typename Dtype>
+__global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
+    const int num, const int channels, const int height,
+    const int width, const int pooled_height, const int pooled_width,
+    const int ksize, const int stride, Dtype* top_data) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < nthreads) {
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+    int hstart = ph * stride;
+    int hend = min(hstart + ksize, height);
+    int wstart = pw * stride;
+    int wend = min(wstart + ksize, width);
+    Dtype maxval = -FLT_MAX;
+    bottom_data += (n * channels + c) * height * width;
+    for (int h = hstart; h < hend; ++h) {
+      for (int w = wstart; w < wend; ++w) {
+        maxval = max(maxval, bottom_data[h * width + w]);
+      }
+    }
+    top_data[index] = maxval;
+  } // (if index < nthreads)
+}
+
+template <typename Dtype>
+__global__ void AvePoolForward(const int nthreads, const Dtype* bottom_data,
+    const int num, const int channels, const int height,
+    const int width, const int pooled_height, const int pooled_width,
+    const int ksize, const int stride, Dtype* top_data) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < nthreads) {
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+    int hstart = ph * stride;
+    int hend = min(hstart + ksize, height);
+    int wstart = pw * stride;
+    int wend = min(wstart + ksize, width);
+    Dtype aveval = 0;
+    bottom_data += (n * channels + c) * height * width;
+    for (int h = hstart; h < hend; ++h) {
+      for (int w = wstart; w < wend; ++w) {
+        aveval += bottom_data[h * width + w];
+      }
+    }
+    top_data[index] = aveval / ksize / ksize;
+  } // (if index < nthreads)
+}
+
+template <typename Dtype>
+void PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = (*top)[0]->mutable_gpu_data();
+  int count = (*top)[0]->count();
+  switch (this->layer_param_.pool()) {
+  case LayerParameter_PoolMethod_MAX:
+    MaxPoolForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+        count, bottom_data, bottom[0]->num(), CHANNELS_,
+        HEIGHT_, WIDTH_, POOLED_HEIGHT_, POOLED_WIDTH_, KSIZE_, STRIDE_,
+        top_data);
+    break;
+  case LayerParameter_PoolMethod_AVE:
+    AvePoolForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+        count, bottom_data, bottom[0]->num(), CHANNELS_,
+        HEIGHT_, WIDTH_, POOLED_HEIGHT_, POOLED_WIDTH_, KSIZE_, STRIDE_,
+        top_data);
+    break;
+  default:
+    LOG(FATAL) << "Unknown pooling method.";
+  }
+  CUDA_POST_KERNEL_CHECK;
+}
+
+template <typename Dtype>
+__global__ void MaxPoolBackward(const int nthreads, const Dtype* bottom_data,
+    const Dtype* top_data, const Dtype* top_diff,
+    const int num, const int channels, const int height,
+    const int width, const int pooled_height, const int pooled_width,
+    const int ksize, const int stride, Dtype* bottom_diff) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < nthreads) {
+    // find out the local index
+    // find out the local offset
+    int w = index % width;
+    int h = (index / width) % height;
+    int c = (index / width / height) % channels;
+    int n = index / width / height / channels;
+    int phstart = (h < ksize) ? 0 : (h - ksize) / stride + 1;
+    int phend = min(h / stride + 1, pooled_height);
+    int pwstart = (w < ksize) ? 0 : (w - ksize) / stride + 1;
+    int pwend = min(w / stride + 1, pooled_width);
+    Dtype gradient = 0;
+    Dtype bottom_datum =
+        bottom_data[((n * channels + c) * height + h) * width + w];
+    top_data += (n * channels + c) * pooled_height * pooled_width;
+    top_diff += (n * channels + c) * pooled_height * pooled_width;
+    for (int ph = phstart; ph < phend; ++ph) {
+      for (int pw = pwstart; pw < pwend; ++pw) {
+        gradient += top_diff[ph * pooled_width + pw] *
+            (bottom_datum >= top_data[ph * pooled_width + pw] -
+                CAFFE_MAX_POOLING_THRESHOLD);
+      }
+    }
+    bottom_diff[index] = gradient;
+  } // (if index < nthreads)
+}
+
+
+template <typename Dtype>
+__global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff,
+    const int num, const int channels, const int height,
+    const int width, const int pooled_height, const int pooled_width,
+    const int ksize, const int stride, Dtype* bottom_diff) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < nthreads) {
+    // find out the local index
+    // find out the local offset
+    int w = index % width;
+    int h = (index / width) % height;
+    int c = (index / width / height) % channels;
+    int n = index / width / height / channels;
+    int phstart = (h < ksize) ? 0 : (h - ksize) / stride + 1;
+    int phend = min(h / stride + 1, pooled_height);
+    int pwstart = (w < ksize) ? 0 : (w - ksize) / stride + 1;
+    int pwend = min(w / stride + 1, pooled_width);
+    Dtype gradient = 0;
+    top_diff += (n * channels + c) * pooled_height * pooled_width;
+    for (int ph = phstart; ph < phend; ++ph) {
+      for (int pw = pwstart; pw < pwend; ++pw) {
+        gradient += top_diff[ph * pooled_width + pw];
+      }
+    }
+    bottom_diff[index] = gradient / ksize / ksize;
+  } // (if index < nthreads)
+}
+
+template <typename Dtype>
+Dtype PoolingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+  if (!propagate_down) {
+    return Dtype(0.);
+  }
+  const Dtype* top_diff = top[0]->gpu_diff();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+  int count = (*bottom)[0]->count();
+  switch (this->layer_param_.pool()) {
+  case LayerParameter_PoolMethod_MAX:
+    MaxPoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+        count, (*bottom)[0]->gpu_data(), top[0]->gpu_data(), top_diff,
+        top[0]->num(), CHANNELS_, HEIGHT_, WIDTH_, POOLED_HEIGHT_,
+        POOLED_WIDTH_, KSIZE_, STRIDE_, bottom_diff);
+    break;
+  case LayerParameter_PoolMethod_AVE:
+    AvePoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+        count, top_diff, top[0]->num(), CHANNELS_,
+        HEIGHT_, WIDTH_, POOLED_HEIGHT_, POOLED_WIDTH_, KSIZE_, STRIDE_,
+        bottom_diff);
+    break;
+  default:
+    LOG(FATAL) << "Unknown pooling method.";
+  }
+  CUDA_POST_KERNEL_CHECK;
+  return Dtype(0.);
+}
+
+
+INSTANTIATE_CLASS(PoolingLayer);
+
+
+} // namespace caffe
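Annotation (illustration only, not part of the patch): MaxPoolBackward routes each output gradient back to the input position that produced the maximum by multiplying top_diff with a boolean comparison, which promotes to 0 or 1 in arithmetic. The CAFFE_MAX_POOLING_THRESHOLD slack absorbs float rounding between the stored maximum and the re-read input; a side effect is that exact ties each receive the full gradient. A standalone sketch of the trick:

    #include <cstdio>

    int main() {
      const float kThreshold = 1e-8f;  // stands in for CAFFE_MAX_POOLING_THRESHOLD
      float window[4] = {0.2f, 0.9f, 0.9f, 0.1f};  // one pooling window
      float pooled_max = 0.9f;  // what MaxPoolForward stored in top_data
      float top_diff = 2.0f;    // incoming gradient for that output
      for (int i = 0; i < 4; ++i) {
        // The bool result of >= promotes to 0.0f or 1.0f in the product.
        float grad = top_diff * (window[i] >= pooled_max - kThreshold);
        printf("bottom[%d] gets %.1f\n", i, grad);  // both 0.9 entries get 2.0
      }
      return 0;
    }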
diff --git a/src/caffe/test/test_gradient_check_util.hpp b/src/caffe/test/test_gradient_check_util.hpp
index 8cf7851..7984aaf 100644
--- a/src/caffe/test/test_gradient_check_util.hpp
+++ b/src/caffe/test/test_gradient_check_util.hpp
@@ -110,9 +110,11 @@ void GradientChecker<Dtype>::CheckGradientSingle(Layer<Dtype>& layer,
       Dtype scale = max(max(fabs(computed_gradient), fabs(estimated_gradient)),
           1.);
       EXPECT_GT(computed_gradient, estimated_gradient - threshold_ * scale)
-          << "debug: (blob_id, feat_id)=" << blobid << "," << feat_id;
+          << "debug: (top_id, top_data_id, blob_id, feat_id)="
+          << top_id << "," << top_data_id << "," << blobid << "," << feat_id;
       EXPECT_LT(computed_gradient, estimated_gradient + threshold_ * scale)
-          << "debug: (blob_id, feat_id)=" << blobid << "," << feat_id;
+          << "debug: (top_id, top_data_id, blob_id, feat_id)="
+          << top_id << "," << top_data_id << "," << blobid << "," << feat_id;
     }
     //LOG(ERROR) << "Feature: " << current_blob->cpu_data()[feat_id];
     //LOG(ERROR) << "computed gradient: " << computed_gradient
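Annotation (an assumption about the checker's internals, which this hunk does not show): estimated_gradient is presumably a centered finite difference of the layer's objective, so the EXPECT_GT/EXPECT_LT pair above bounds the analytic gradient within threshold_ * scale of that estimate. A sketch of centered differencing on a toy function:

    #include <cstdio>

    // f'(x) ~= (f(x + h) - f(x - h)) / (2 h), accurate to O(h^2).
    float centered_difference(float (*f)(float), float x, float h) {
      return (f(x + h) - f(x - h)) / h / 2.0f;
    }

    float square(float x) { return x * x; }

    int main() {
      // d/dx of x^2 at x = 3 is exactly 6; the estimate lands close to it.
      printf("estimate: %f\n", centered_difference(square, 3.0f, 1e-2f));
      return 0;
    }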
diff --git a/src/caffe/test/test_pooling_layer.cpp b/src/caffe/test/test_pooling_layer.cpp
index d9d76fe..2e3c60f 100644
--- a/src/caffe/test/test_pooling_layer.cpp
+++ b/src/caffe/test/test_pooling_layer.cpp
@@ -21,6 +21,7 @@ class PoolingLayerTest : public ::testing::Test {
       : blob_bottom_(new Blob<Dtype>()),
         blob_top_(new Blob<Dtype>()) {};
   virtual void SetUp() {
+    Caffe::set_random_seed(1701);
     blob_bottom_->Reshape(2, 3, 6, 5);
     // fill the values
     FillerParameter filler_param;
@@ -53,6 +54,71 @@ TYPED_TEST(PoolingLayerTest, TestSetup) {
   EXPECT_EQ(this->blob_top_->width(), 2);
 }
 
+TYPED_TEST(PoolingLayerTest, TestGPUMax) {
+  LayerParameter layer_param;
+  layer_param.set_kernelsize(3);
+  layer_param.set_stride(2);
+  layer_param.set_pool(LayerParameter_PoolMethod_MAX);
+  Caffe::set_mode(Caffe::CPU);
+  PoolingLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  Blob<TypeParam> blob_reference(*this->blob_top_);
+  Caffe::set_mode(Caffe::GPU);
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  for (int i = 0; i < blob_reference.count(); ++i) {
+    EXPECT_EQ(blob_reference.cpu_data()[i], this->blob_top_->cpu_data()[i])
+        << "debug: index " << i;
+  }
+}
+
+TYPED_TEST(PoolingLayerTest, TestGPUAve) {
+  LayerParameter layer_param;
+  layer_param.set_kernelsize(3);
+  layer_param.set_stride(2);
+  layer_param.set_pool(LayerParameter_PoolMethod_AVE);
+  Caffe::set_mode(Caffe::CPU);
+  PoolingLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  Blob<TypeParam> blob_reference(*this->blob_top_);
+  Caffe::set_mode(Caffe::GPU);
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  for (int i = 0; i < blob_reference.count(); ++i) {
+    EXPECT_GE(blob_reference.cpu_data()[i], this->blob_top_->cpu_data()[i] - 1e-4)
+        << "debug: index " << i;
+    EXPECT_LE(blob_reference.cpu_data()[i], this->blob_top_->cpu_data()[i] + 1e-4)
+        << "debug: index " << i;
+  }
+}
+
+/*
+TYPED_TEST(PoolingLayerTest, PrintGPUBackward) {
+  LayerParameter layer_param;
+  layer_param.set_kernelsize(3);
+  layer_param.set_stride(2);
+  layer_param.set_pool(LayerParameter_PoolMethod_MAX);
+  Caffe::set_mode(Caffe::GPU);
+  PoolingLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    cout << "bottom data " << i << " " << this->blob_bottom_->cpu_data()[i] << endl;
+  }
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    cout << "top data " << i << " " << this->blob_top_->cpu_data()[i] << endl;
+  }
+
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    this->blob_top_->mutable_cpu_diff()[i] = 1.;
+  }
+  layer.Backward(this->blob_top_vec_, true, &(this->blob_bottom_vec_));
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    cout << "bottom diff " << i << " " << this->blob_bottom_->cpu_diff()[i] << endl;
+  }
+}
+*/
+
 TYPED_TEST(PoolingLayerTest, TestCPUGradientMax) {
   LayerParameter layer_param;
   layer_param.set_kernelsize(3);
@@ -64,6 +130,17 @@ TYPED_TEST(PoolingLayerTest, TestCPUGradientMax) {
   checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, this->blob_top_vec_);
 }
 
+TYPED_TEST(PoolingLayerTest, TestGPUGradientMax) {
+  LayerParameter layer_param;
+  layer_param.set_kernelsize(3);
+  layer_param.set_stride(2);
+  layer_param.set_pool(LayerParameter_PoolMethod_MAX);
+  Caffe::set_mode(Caffe::GPU);
+  PoolingLayer<TypeParam> layer(layer_param);
+  GradientChecker<TypeParam> checker(1e-4, 1e-2);
+  checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, this->blob_top_vec_);
+}
+
 
 TYPED_TEST(PoolingLayerTest, TestCPUGradientAve) {
   LayerParameter layer_param;
@@ -76,16 +153,17 @@ TYPED_TEST(PoolingLayerTest, TestCPUGradientAve) {
   checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, this->blob_top_vec_);
 }
 
-/*
-TYPED_TEST(PoolingLayerTest, TestGPUGradient) {
+
+TYPED_TEST(PoolingLayerTest, TestGPUGradientAve) {
   LayerParameter layer_param;
   layer_param.set_kernelsize(3);
   layer_param.set_stride(2);
+  layer_param.set_pool(LayerParameter_PoolMethod_AVE);
   Caffe::set_mode(Caffe::GPU);
   PoolingLayer<TypeParam> layer(layer_param);
   GradientChecker<TypeParam> checker(1e-2, 1e-2);
   checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, this->blob_top_vec_);
 }
-*/
+
 
 }
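Annotation (illustration only, not part of the patch): TestGPUMax can demand bit-exact equality because max pooling copies a single input value through unchanged, while TestGPUAve allows a 1e-4 band because the GPU may accumulate the window sum in a different order than the CPU. The GE/LE pair states the same bound one could express with gtest's EXPECT_NEAR, sketched here:

    #include <gtest/gtest.h>

    // Illustration only: EXPECT_NEAR states the same +/- 1e-4 band that
    // TestGPUAve spells out as an EXPECT_GE / EXPECT_LE pair.
    TEST(ToleranceSketch, NearEqualsGeLePair) {
      float reference = 0.500000f;            // CPU result
      float computed = 0.500030f;             // GPU result, different rounding
      EXPECT_NEAR(reference, computed, 1e-4);
      EXPECT_GE(reference, computed - 1e-4);  // equivalent lower bound
      EXPECT_LE(reference, computed + 1e-4);  // equivalent upper bound
    }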
diff --git a/src/caffe/util/im2col.cu b/src/caffe/util/im2col.cu
index aef6842..161a45d 100644
--- a/src/caffe/util/im2col.cu
+++ b/src/caffe/util/im2col.cu
@@ -60,6 +60,7 @@ __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col,
     const int stride, const int height_col, const int width_col, Dtype* data_im) {
   int index = threadIdx.x + blockIdx.x * blockDim.x;
   if (index < n) {
+    Dtype val = 0;
     int w = index % width;
     int h = (index / width) % height;
     int c = index / (width * height);
@@ -72,9 +73,10 @@ __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col,
       for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
         // the col location: [c * width * height + h_out, w_out]
         int c_col = c * ksize * ksize + (h - h_col * stride) * ksize
             + (w - w_col * stride);
-        data_im[index] += data_col[(c_col * height_col + h_col) * width_col + w_col];
+        val += data_col[(c_col * height_col + h_col) * width_col + w_col];
       }
     }
+    data_im[index] = val;
   }
 }
 
@@ -82,7 +84,7 @@ template <typename Dtype>
 void col2im_gpu(const Dtype* data_col, const int channels,
     const int height, const int width, const int ksize, const int stride,
     Dtype* data_im) {
-  CUDA_CHECK(cudaMemset(data_im, 0, sizeof(Dtype) * height * width * channels));
+  //CUDA_CHECK(cudaMemset(data_im, 0, sizeof(Dtype) * height * width * channels));
   int height_col = (height - ksize) / stride + 1;
   int width_col = (width - ksize) / stride + 1;
   int num_kernels = channels * height * width;
diff --git a/src/caffe/vision_layers.hpp b/src/caffe/vision_layers.hpp
index 31c6b1d..5d99c48 100644
--- a/src/caffe/vision_layers.hpp
+++ b/src/caffe/vision_layers.hpp
@@ -169,12 +169,12 @@ class PoolingLayer : public Layer<Dtype> {
  protected:
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
-  //virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-  //    vector<Blob<Dtype>*>* top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
   virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  //virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
-  //    const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
   int KSIZE_;
   int STRIDE_;
   int CHANNELS_;
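Annotation (illustration only, not part of the patch): the col2im change above makes each thread own one output element outright: partial sums accumulate in the register val and reach global memory in a single store, so the cudaMemset that previously zeroed data_im becomes unnecessary and a read-modify-write of global memory per term is avoided. The general pattern, as a standalone CUDA sketch with hypothetical names:

    // accumulate_then_store: one thread per output, register accumulation,
    // one global write; dst requires no prior zeroing.
    __global__ void accumulate_then_store(const float* src, const int terms,
        const int n_out, float* dst) {
      int index = threadIdx.x + blockIdx.x * blockDim.x;
      if (index < n_out) {
        float val = 0;
        for (int t = 0; t < terms; ++t) {
          val += src[index * terms + t];  // gather this output's terms
        }
        dst[index] = val;  // single coalesced store replaces += on global memory
      }
    }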