706ee156807f5d95be181fa27067b49485e57fee
[platform/upstream/caffeonacl.git] / src / caffe / layers / pooling_layer.cu
1 // Copyright 2013 Yangqing Jia
2
3 #include <algorithm>
4 #include <cfloat>
5 #include "caffe/layer.hpp"
6 #include "caffe/vision_layers.hpp"
7 #include "caffe/util/math_functions.hpp"
8
9 #define CAFFE_MAX_POOLING_THRESHOLD 1e-8f
10
11 using std::max;
12 using std::min;
13
14 namespace caffe {
15
16 template <typename Dtype>
17 __global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
18     const int num, const int channels, const int height,
19     const int width, const int pooled_height, const int pooled_width,
20     const int ksize, const int stride, Dtype* top_data) {
21   int index = threadIdx.x + blockIdx.x * blockDim.x;
22   if (index < nthreads) {
23     int pw = index % pooled_width;
24     int ph = (index / pooled_width) % pooled_height;
25     int c = (index / pooled_width / pooled_height) % channels;
26     int n = index / pooled_width / pooled_height / channels;
27     int hstart = ph * stride;
28     int hend = min(hstart + ksize, height);
29     int wstart = pw * stride;
30     int wend = min(wstart + ksize, width);
31     Dtype maxval = -FLT_MAX;
32     bottom_data += (n * channels + c) * height * width;
33     for (int h = hstart; h < hend; ++h) {
34       for (int w = wstart; w < wend; ++w) {
35         maxval = max(maxval, bottom_data[h * width + w]);
36       }
37     }
38     top_data[index] = maxval;
39   }  // (if index < nthreads)
40 }
41
42 template <typename Dtype>
43 __global__ void AvePoolForward(const int nthreads, const Dtype* bottom_data,
44     const int num, const int channels, const int height,
45     const int width, const int pooled_height, const int pooled_width,
46     const int ksize, const int stride, Dtype* top_data) {
47   int index = threadIdx.x + blockIdx.x * blockDim.x;
48   if (index < nthreads) {
49     int pw = index % pooled_width;
50     int ph = (index / pooled_width) % pooled_height;
51     int c = (index / pooled_width / pooled_height) % channels;
52     int n = index / pooled_width / pooled_height / channels;
53     int hstart = ph * stride;
54     int hend = min(hstart + ksize, height);
55     int wstart = pw * stride;
56     int wend = min(wstart + ksize, width);
57     Dtype aveval = 0;
58     bottom_data += (n * channels + c) * height * width;
59     for (int h = hstart; h < hend; ++h) {
60       for (int w = wstart; w < wend; ++w) {
61         aveval += bottom_data[h * width + w];
62       }
63     }
64     top_data[index] = aveval / (hend - hstart) / (wend - wstart);
65   }  // (if index < nthreads)
66 }
67
68 template <typename Dtype>
69 void PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
70       vector<Blob<Dtype>*>* top) {
71   const Dtype* bottom_data = bottom[0]->gpu_data();
72   Dtype* top_data = (*top)[0]->mutable_gpu_data();
73   int count = (*top)[0]->count();
74   switch (this->layer_param_.pool()) {
75   case LayerParameter_PoolMethod_MAX:
76     MaxPoolForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
77         count, bottom_data, bottom[0]->num(), CHANNELS_,
78         HEIGHT_, WIDTH_, POOLED_HEIGHT_, POOLED_WIDTH_, KSIZE_, STRIDE_,
79         top_data);
80     break;
81   case LayerParameter_PoolMethod_AVE:
82     AvePoolForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
83         count, bottom_data, bottom[0]->num(), CHANNELS_,
84         HEIGHT_, WIDTH_, POOLED_HEIGHT_, POOLED_WIDTH_, KSIZE_, STRIDE_,
85         top_data);
86     break;
87   default:
88     LOG(FATAL) << "Unknown pooling method.";
89   }
90   CUDA_POST_KERNEL_CHECK;
91 }
92
93 template <typename Dtype>
94 __global__ void MaxPoolBackward(const int nthreads, const Dtype* bottom_data,
95     const Dtype* top_data, const Dtype* top_diff,
96     const int num, const int channels, const int height,
97     const int width, const int pooled_height, const int pooled_width,
98     const int ksize, const int stride, Dtype* bottom_diff) {
99   int index = threadIdx.x + blockIdx.x * blockDim.x;
100   if (index < nthreads) {
101     // find out the local index
102     // find out the local offset
103     int w = index % width;
104     int h = (index / width) % height;
105     int c = (index / width / height) % channels;
106     int n = index / width / height / channels;
107     int phstart = (h < ksize) ? 0 : (h - ksize) / stride + 1;
108     int phend = min(h / stride + 1, pooled_height);
109     int pwstart = (w < ksize) ? 0 : (w - ksize) / stride + 1;
110     int pwend = min(w / stride + 1, pooled_width);
111     Dtype gradient = 0;
112     Dtype bottom_datum =
113         bottom_data[((n * channels + c) * height + h) * width + w];
114     top_data += (n * channels + c) * pooled_height * pooled_width;
115     top_diff += (n * channels + c) * pooled_height * pooled_width;
116     for (int ph = phstart; ph < phend; ++ph) {
117       for (int pw = pwstart; pw < pwend; ++pw) {
118         gradient += top_diff[ph * pooled_width + pw] *
119             (bottom_datum >= top_data[ph * pooled_width + pw] -
120                 CAFFE_MAX_POOLING_THRESHOLD);
121       }
122     }
123     bottom_diff[index] = gradient;
124   }  // (if index < nthreads)
125 }
126
127
128 template <typename Dtype>
129 __global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff,
130     const int num, const int channels, const int height,
131     const int width, const int pooled_height, const int pooled_width,
132     const int ksize, const int stride, Dtype* bottom_diff) {
133   int index = threadIdx.x + blockIdx.x * blockDim.x;
134   if (index < nthreads) {
135     // find out the local index
136     // find out the local offset
137     int w = index % width;
138     int h = (index / width) % height;
139     int c = (index / width / height) % channels;
140     int n = index / width / height / channels;
141     int phstart = (h < ksize) ? 0 : (h - ksize) / stride + 1;
142     int phend = min(h / stride + 1, pooled_height);
143     int pwstart = (w < ksize) ? 0 : (w - ksize) / stride + 1;
144     int pwend = min(w / stride + 1, pooled_width);
145     Dtype gradient = 0;
146     top_diff += (n * channels + c) * pooled_height * pooled_width;
147     for (int ph = phstart; ph < phend; ++ph) {
148       for (int pw = pwstart; pw < pwend; ++pw) {
149         // figure out the pooling size
150         int poolsize = (min(ph * stride + ksize, height) - ph * stride) *
151             (min(pw * stride + ksize, width) - pw * stride);
152         gradient += top_diff[ph * pooled_width + pw] / poolsize;
153       }
154     }
155     bottom_diff[index] = gradient;
156   }  // (if index < nthreads)
157 }
158
159 template <typename Dtype>
160 Dtype PoolingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
161       const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
162   if (!propagate_down) {
163     return Dtype(0.);
164   }
165   const Dtype* top_diff = top[0]->gpu_diff();
166   Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
167   int count = (*bottom)[0]->count();
168   switch (this->layer_param_.pool()) {
169   case LayerParameter_PoolMethod_MAX:
170     MaxPoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
171         count, (*bottom)[0]->gpu_data(), top[0]->gpu_data(), top_diff,
172         top[0]->num(), CHANNELS_, HEIGHT_, WIDTH_, POOLED_HEIGHT_,
173         POOLED_WIDTH_, KSIZE_, STRIDE_, bottom_diff);
174     break;
175   case LayerParameter_PoolMethod_AVE:
176     AvePoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
177         count, top_diff, top[0]->num(), CHANNELS_,
178         HEIGHT_, WIDTH_, POOLED_HEIGHT_, POOLED_WIDTH_, KSIZE_, STRIDE_,
179         bottom_diff);
180     break;
181   default:
182     LOG(FATAL) << "Unknown pooling method.";
183   }
184   CUDA_POST_KERNEL_CHECK;
185   return Dtype(0.);
186 }
187
188
189 INSTANTIATE_CLASS(PoolingLayer);
190
191
192 }  // namespace caffe