// Copyright 2013 Yangqing Jia

#include <algorithm>
#include <cfloat>

#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"
#include "caffe/util/math_functions.hpp"

// Tolerance used when matching a bottom value against a window's pooled
// maximum in MaxPoolBackward.
#define CAFFE_MAX_POOLING_THRESHOLD 1e-8f

using std::max;
using std::min;

namespace caffe {
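// MaxPoolForward computes one top (pooled) element per thread. The flat index
// is decomposed as index -> (n, c, ph, pw); the thread then scans its
// ksize x ksize window, clipped at the image border, and writes the maximum.
// For example, with a 4 x 4 input, ksize = stride = 2, and a 2 x 2 pooled
// map, the thread with (ph, pw) = (1, 1) reduces over rows 2..3 and
// columns 2..3 of its (n, c) plane.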
template <typename Dtype>
__global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
    const int num, const int channels, const int height,
    const int width, const int pooled_height, const int pooled_width,
    const int ksize, const int stride, Dtype* top_data) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index < nthreads) {
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;
    int hstart = ph * stride;
    int hend = min(hstart + ksize, height);
    int wstart = pw * stride;
    int wend = min(wstart + ksize, width);
    Dtype maxval = -FLT_MAX;
    bottom_data += (n * channels + c) * height * width;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        maxval = max(maxval, bottom_data[h * width + w]);
      }
    }
    top_data[index] = maxval;
  }  // (if index < nthreads)
}
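// AvePoolForward uses the same index -> (n, c, ph, pw) decomposition as
// MaxPoolForward, but sums the (border-clipped) window and divides by its
// actual area (hend - hstart) * (wend - wstart), so windows that hang over
// the border are averaged only over the pixels they cover.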
template <typename Dtype>
__global__ void AvePoolForward(const int nthreads, const Dtype* bottom_data,
    const int num, const int channels, const int height,
    const int width, const int pooled_height, const int pooled_width,
    const int ksize, const int stride, Dtype* top_data) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index < nthreads) {
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;
    int hstart = ph * stride;
    int hend = min(hstart + ksize, height);
    int wstart = pw * stride;
    int wend = min(wstart + ksize, width);
    Dtype aveval = 0;
    bottom_data += (n * channels + c) * height * width;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        aveval += bottom_data[h * width + w];
      }
    }
    top_data[index] = aveval / (hend - hstart) / (wend - wstart);
  }  // (if index < nthreads)
}
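// Forward_gpu dispatches on the pooling method and launches one thread per
// top element: CAFFE_GET_BLOCKS(count) blocks of CAFFE_CUDA_NUM_THREADS
// threads cover the count outputs, and any surplus threads exit through the
// (index < nthreads) guard in the kernels above.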
template <typename Dtype>
void PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  const Dtype* bottom_data = bottom[0]->gpu_data();
  Dtype* top_data = (*top)[0]->mutable_gpu_data();
  int count = (*top)[0]->count();
  switch (this->layer_param_.pool()) {
  case LayerParameter_PoolMethod_MAX:
    MaxPoolForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
        count, bottom_data, bottom[0]->num(), CHANNELS_,
        HEIGHT_, WIDTH_, POOLED_HEIGHT_, POOLED_WIDTH_, KSIZE_, STRIDE_,
        top_data);
    break;
  case LayerParameter_PoolMethod_AVE:
    AvePoolForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
        count, bottom_data, bottom[0]->num(), CHANNELS_,
        HEIGHT_, WIDTH_, POOLED_HEIGHT_, POOLED_WIDTH_, KSIZE_, STRIDE_,
        top_data);
    break;
  default:
    LOG(FATAL) << "Unknown pooling method.";
  }
  CUDA_POST_KERNEL_CHECK;
}
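// MaxPoolBackward computes one bottom_diff element per thread. Each thread
// finds the range [phstart, phend) x [pwstart, pwend) of pooling windows that
// contain its pixel and accumulates top_diff from every window whose pooled
// maximum matches the pixel's value to within CAFFE_MAX_POOLING_THRESHOLD
// (so if several pixels tie for the maximum, each receives the gradient).
// Window-range arithmetic, e.g. with ksize = 3, stride = 2, h = 5:
// phstart = (5 - 3) / 2 + 1 = 2 and phend = min(5 / 2 + 1, pooled_height),
// i.e. only window ph = 2 (rows 4..6) contains row 5.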
template <typename Dtype>
__global__ void MaxPoolBackward(const int nthreads, const Dtype* bottom_data,
    const Dtype* top_data, const Dtype* top_diff,
    const int num, const int channels, const int height,
    const int width, const int pooled_height, const int pooled_width,
    const int ksize, const int stride, Dtype* bottom_diff) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index < nthreads) {
    // Find the local offset (n, c, h, w) of this bottom element.
    int w = index % width;
    int h = (index / width) % height;
    int c = (index / width / height) % channels;
    int n = index / width / height / channels;
    // Range of pooling windows that contain (h, w).
    int phstart = (h < ksize) ? 0 : (h - ksize) / stride + 1;
    int phend = min(h / stride + 1, pooled_height);
    int pwstart = (w < ksize) ? 0 : (w - ksize) / stride + 1;
    int pwend = min(w / stride + 1, pooled_width);
    Dtype gradient = 0;
    Dtype bottom_datum =
        bottom_data[((n * channels + c) * height + h) * width + w];
    top_data += (n * channels + c) * pooled_height * pooled_width;
    top_diff += (n * channels + c) * pooled_height * pooled_width;
    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        gradient += top_diff[ph * pooled_width + pw] *
            (bottom_datum >= top_data[ph * pooled_width + pw] -
                CAFFE_MAX_POOLING_THRESHOLD);
      }
    }
    bottom_diff[index] = gradient;
  }  // (if index < nthreads)
}
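// AvePoolBackward mirrors AvePoolForward: each covering window contributed
// pixel / poolsize to its average, so the thread accumulates
// top_diff / poolsize over every window containing its pixel, where poolsize
// is that window's border-clipped area.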
template <typename Dtype>
__global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff,
    const int num, const int channels, const int height,
    const int width, const int pooled_height, const int pooled_width,
    const int ksize, const int stride, Dtype* bottom_diff) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index < nthreads) {
    // Find the local offset (n, c, h, w) of this bottom element.
    int w = index % width;
    int h = (index / width) % height;
    int c = (index / width / height) % channels;
    int n = index / width / height / channels;
    // Range of pooling windows that contain (h, w).
    int phstart = (h < ksize) ? 0 : (h - ksize) / stride + 1;
    int phend = min(h / stride + 1, pooled_height);
    int pwstart = (w < ksize) ? 0 : (w - ksize) / stride + 1;
    int pwend = min(w / stride + 1, pooled_width);
    Dtype gradient = 0;
    top_diff += (n * channels + c) * pooled_height * pooled_width;
    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        // Figure out the (border-clipped) size of this pooling window.
        int poolsize = (min(ph * stride + ksize, height) - ph * stride) *
            (min(pw * stride + ksize, width) - pw * stride);
        gradient += top_diff[ph * pooled_width + pw] / poolsize;
      }
    }
    bottom_diff[index] = gradient;
  }  // (if index < nthreads)
}
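// Backward_gpu launches one thread per bottom element (count comes from the
// bottom blob), so each thread owns a single bottom_diff entry and no atomic
// operations are needed. The Dtype return value is presumably the layer's
// contribution to the loss, as elsewhere in this Layer interface; pooling
// contributes nothing, hence Dtype(0.).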
template <typename Dtype>
Dtype PoolingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
  if (!propagate_down) {
    return Dtype(0.);
  }
  const Dtype* top_diff = top[0]->gpu_diff();
  Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
  int count = (*bottom)[0]->count();
  switch (this->layer_param_.pool()) {
  case LayerParameter_PoolMethod_MAX:
    MaxPoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
        count, (*bottom)[0]->gpu_data(), top[0]->gpu_data(), top_diff,
        top[0]->num(), CHANNELS_, HEIGHT_, WIDTH_, POOLED_HEIGHT_,
        POOLED_WIDTH_, KSIZE_, STRIDE_, bottom_diff);
    break;
  case LayerParameter_PoolMethod_AVE:
    AvePoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
        count, top_diff, top[0]->num(), CHANNELS_,
        HEIGHT_, WIDTH_, POOLED_HEIGHT_, POOLED_WIDTH_, KSIZE_, STRIDE_,
        bottom_diff);
    break;
  default:
    LOG(FATAL) << "Unknown pooling method.";
  }
  CUDA_POST_KERNEL_CHECK;
  return Dtype(0.);
}

INSTANTIATE_CLASS(PoolingLayer);

}  // namespace caffe