__global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
- const int kernel_size, const int stride, const int pad, Dtype* top_data,
+ const int kernel_size_h, const int kernel_size_w, const int stride_h,
+ const int stride_w, const int pad_h, const int pad_w, Dtype* top_data,
int* mask, Dtype* top_mask) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
- int hstart = ph * stride - pad;
- int wstart = pw * stride - pad;
- int hend = min(hstart + kernel_size, height);
- int wend = min(wstart + kernel_size, width);
+ int hstart = ph * stride_h - pad_h;
+ int wstart = pw * stride_w - pad_w;
+ int hend = min(hstart + kernel_size_h, height);
+ int wend = min(wstart + kernel_size_w, width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
Dtype maxval = -FLT_MAX;
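// Aside (editor's sketch, not part of this patch): the arithmetic above maps
// a flat output index to (n, c, ph, pw) and then to a clamped input window.
// A minimal host-side check of that window math, assuming kernel 3x2,
// stride 2x1, pad 1x0 on a 5x5 input (pooled output is then 3x4):
#include <algorithm>
#include <cstdio>
int main() {
  const int height = 5, width = 5;
  const int kernel_size_h = 3, kernel_size_w = 2;
  const int stride_h = 2, stride_w = 1;
  const int pad_h = 1, pad_w = 0;
  for (int ph = 0; ph < 3; ++ph) {
    for (int pw = 0; pw < 4; ++pw) {
      int hstart = std::max(ph * stride_h - pad_h, 0);
      int wstart = std::max(pw * stride_w - pad_w, 0);
      int hend = std::min(ph * stride_h - pad_h + kernel_size_h, height);
      int wend = std::min(pw * stride_w - pad_w + kernel_size_w, width);
      std::printf("(%d,%d) -> rows [%d,%d) cols [%d,%d)\n",
                  ph, pw, hstart, hend, wstart, wend);
    }
  }
  return 0;
}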
__global__ void AvePoolForward(const int nthreads, const Dtype* bottom_data,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
- const int kernel_size, const int stride, const int pad, Dtype* top_data) {
+ const int kernel_size_h, const int kernel_size_w, const int stride_h,
+ const int stride_w, const int pad_h, const int pad_w, Dtype* top_data) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
- int hstart = ph * stride - pad;
- int wstart = pw * stride - pad;
- int hend = min(hstart + kernel_size, height + pad);
- int wend = min(wstart + kernel_size, width + pad);
+ int hstart = ph * stride_h - pad_h;
+ int wstart = pw * stride_w - pad_w;
+ int hend = min(hstart + kernel_size_h, height + pad_h);
+ int wend = min(wstart + kernel_size_w, width + pad_w);
int pool_size = (hend - hstart) * (wend - wstart);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
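// Aside (editor's worked example, not part of this patch): pool_size above is
// computed before hstart/wstart are clamped, so border windows still divide
// by the full padded window area, i.e. padding contributes zeros to the
// average. Assuming kernel 3x3, stride 2, pad 1 on a 4x4 input, at output
// position (0, 0):
//   hstart = 0*2 - 1 = -1,  hend = min(-1 + 3, 4 + 1) = 2
//   wstart = -1,            wend = 2
//   pool_size = (2 - (-1)) * (2 - (-1)) = 9
// After clamping, only the 2x2 block of real pixels is summed, but the sum
// is still divided by 9.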
__global__ void StoPoolForwardTrain(const int nthreads,
const Dtype* bottom_data,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
- const int kernel_size, const int stride, Dtype* rand_idx, Dtype* top_data) {
+ const int kernel_size_h, const int kernel_size_w, const int stride_h,
+ const int stride_w, Dtype* rand_idx, Dtype* top_data) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
- int hstart = ph * stride;
- int hend = min(hstart + kernel_size, height);
- int wstart = pw * stride;
- int wend = min(wstart + kernel_size, width);
+ int hstart = ph * stride_h;
+ int hend = min(hstart + kernel_size_h, height);
+ int wstart = pw * stride_w;
+ int wend = min(wstart + kernel_size_w, width);
Dtype cumsum = 0.;
bottom_data += (n * channels + c) * height * width;
// First pass: get sum
__global__ void StoPoolForwardTest(const int nthreads,
const Dtype* bottom_data,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
- const int kernel_size, const int stride, Dtype* top_data) {
+ const int kernel_size_h, const int kernel_size_w, const int stride_h,
+ const int stride_w, Dtype* top_data) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
- int hstart = ph * stride;
- int hend = min(hstart + kernel_size, height);
- int wstart = pw * stride;
- int wend = min(wstart + kernel_size, width);
+ int hstart = ph * stride_h;
+ int hend = min(hstart + kernel_size_h, height);
+ int wstart = pw * stride_w;
+ int wend = min(wstart + kernel_size_w, width);
// We seed cumsum with FLT_MIN (not 0) to avoid divide-by-zero problems
Dtype cumsum = FLT_MIN;
Dtype cumvalues = 0.;
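// Aside (editor's sketch, not part of this patch): at test time stochastic
// pooling degenerates to a probability-weighted average, sum(a_i^2)/sum(a_i);
// seeding cumsum with FLT_MIN keeps the division finite when every
// activation in the window is zero. Scalar illustration:
#include <cfloat>
#include <cstdio>
int main() {
  const float a[4] = {1.f, 2.f, 3.f, 4.f};
  float cumsum = FLT_MIN;   // avoids 0/0 for an all-zero window
  float cumvalues = 0.f;
  for (int i = 0; i < 4; ++i) {
    cumsum += a[i];
    cumvalues += a[i] * a[i];
  }
  std::printf("pooled value = %f\n", cumvalues / cumsum);  // prints 3.0
  return 0;
}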
// NOLINT_NEXT_LINE(whitespace/operators)
MaxPoolForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, bottom[0]->num(), channels_,
- height_, width_, pooled_height_, pooled_width_, kernel_size_, stride_,
- pad_, top_data, mask, top_mask);
+ height_, width_, pooled_height_, pooled_width_, kernel_size_h_,
+ kernel_size_w_, stride_h_, stride_w_, pad_h_, pad_w_, top_data,
+ mask, top_mask);
break;
case PoolingParameter_PoolMethod_AVE:
// NOLINT_NEXT_LINE(whitespace/operators)
AvePoolForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, bottom[0]->num(), channels_,
- height_, width_, pooled_height_, pooled_width_, kernel_size_, stride_,
- pad_, top_data);
+ height_, width_, pooled_height_, pooled_width_, kernel_size_h_,
+ kernel_size_w_, stride_h_, stride_w_, pad_h_, pad_w_, top_data);
break;
case PoolingParameter_PoolMethod_STOCHASTIC:
if (Caffe::phase() == Caffe::TRAIN) {
// NOLINT_NEXT_LINE(whitespace/operators)
StoPoolForwardTrain<Dtype><<<CAFFE_GET_BLOCKS(count),
CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, bottom[0]->num(), channels_,
- height_, width_, pooled_height_, pooled_width_, kernel_size_, stride_,
+ height_, width_, pooled_height_, pooled_width_, kernel_size_h_,
+ kernel_size_w_, stride_h_, stride_w_,
rand_idx_.mutable_gpu_data(), top_data);
} else {
// NOLINT_NEXT_LINE(whitespace/operators)
StoPoolForwardTest<Dtype><<<CAFFE_GET_BLOCKS(count),
CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, bottom[0]->num(), channels_,
- height_, width_, pooled_height_, pooled_width_, kernel_size_, stride_,
- top_data);
+ height_, width_, pooled_height_, pooled_width_, kernel_size_h_,
+ kernel_size_w_, stride_h_, stride_w_, top_data);
}
break;
default:
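// Aside (editor's note): the launch helpers used above are, to the best of
// my knowledge, defined along these lines in Caffe's CUDA headers; shown
// here only for reference, not as part of this patch:
//   #define CUDA_KERNEL_LOOP(i, n) \
//     for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
//          i < (n); \
//          i += blockDim.x * gridDim.x)
//   // number of blocks needed for N threads at CAFFE_CUDA_NUM_THREADS each
//   inline int CAFFE_GET_BLOCKS(const int N) {
//     return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
//   }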
__global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
const int* mask, const Dtype* top_mask, const int num, const int channels,
const int height, const int width, const int pooled_height,
- const int pooled_width, const int kernel_size, const int stride,
- const int pad, Dtype* bottom_diff) {
+ const int pooled_width, const int kernel_size_h, const int kernel_size_w,
+ const int stride_h, const int stride_w, const int pad_h, const int pad_w,
+ Dtype* bottom_diff) {
CUDA_KERNEL_LOOP(index, nthreads) {
// find out the local index
// find out the local offset
int w = index % width;
int h = (index / width) % height;
int c = (index / width / height) % channels;
int n = index / width / height / channels;
int phstart =
- (h + pad < kernel_size) ? 0 : (h + pad - kernel_size) / stride + 1;
- int phend = min((h + pad) / stride + 1, pooled_height);
+ (h + pad_h < kernel_size_h) ? 0 : (h + pad_h - kernel_size_h) / stride_h + 1;
+ int phend = min((h + pad_h) / stride_h + 1, pooled_height);
int pwstart =
- (w + pad < kernel_size) ? 0 : (w + pad - kernel_size) / stride + 1;
- int pwend = min((w + pad) / stride + 1, pooled_width);
+ (w + pad_w < kernel_size_w) ? 0 : (w + pad_w - kernel_size_w) / stride_w + 1;
+ int pwend = min((w + pad_w) / stride_w + 1, pooled_width);
Dtype gradient = 0;
int offset = (n * channels + c) * pooled_height * pooled_width;
top_diff += offset;
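// Aside (editor's sketch, not part of this patch): the [phstart, phend) x
// [pwstart, pwend) bounds above invert the forward mapping, enumerating
// exactly the output windows that contain input pixel (h, w). Brute-force
// check of the closed form, assuming kernel 3, stride 2, pad 1, height 7:
#include <algorithm>
#include <cstdio>
int main() {
  const int height = 7, kernel_size_h = 3, stride_h = 2, pad_h = 1;
  const int pooled_height =
      (height + 2 * pad_h - kernel_size_h) / stride_h + 1;
  for (int h = 0; h < height; ++h) {
    int phstart = (h + pad_h < kernel_size_h) ?
        0 : (h + pad_h - kernel_size_h) / stride_h + 1;
    int phend = std::min((h + pad_h) / stride_h + 1, pooled_height);
    for (int ph = 0; ph < pooled_height; ++ph) {
      int hstart = ph * stride_h - pad_h;
      bool covers = (hstart <= h) && (h < hstart + kernel_size_h);
      bool in_range = (phstart <= ph) && (ph < phend);
      if (covers != in_range) std::printf("mismatch at h=%d ph=%d\n", h, ph);
    }
  }
  std::printf("check done\n");  // no mismatch lines expected
  return 0;
}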
__global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
- const int kernel_size, const int stride, const int pad,
+ const int kernel_size_h, const int kernel_size_w, const int stride_h,
+ const int stride_w, const int pad_h, const int pad_w,
Dtype* bottom_diff) {
CUDA_KERNEL_LOOP(index, nthreads) {
// find out the local index
// find out the local offset
- int w = index % width + pad;
- int h = (index / width) % height + pad;
+ int w = index % width + pad_w;
+ int h = (index / width) % height + pad_h;
int c = (index / width / height) % channels;
int n = index / width / height / channels;
- int phstart = (h < kernel_size) ? 0 : (h - kernel_size) / stride + 1;
- int phend = min(h / stride + 1, pooled_height);
- int pwstart = (w < kernel_size) ? 0 : (w - kernel_size) / stride + 1;
- int pwend = min(w / stride + 1, pooled_width);
+ int phstart = (h < kernel_size_h) ? 0 : (h - kernel_size_h) / stride_h + 1;
+ int phend = min(h / stride_h + 1, pooled_height);
+ int pwstart = (w < kernel_size_w) ? 0 : (w - kernel_size_w) / stride_w + 1;
+ int pwend = min(w / stride_w + 1, pooled_width);
Dtype gradient = 0;
top_diff += (n * channels + c) * pooled_height * pooled_width;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
// figure out the pooling size
- int hstart = ph * stride - pad;
- int wstart = pw * stride - pad;
- int hend = min(hstart + kernel_size, height + pad);
- int wend = min(wstart + kernel_size, width + pad);
+ int hstart = ph * stride_h - pad_h;
+ int wstart = pw * stride_w - pad_w;
+ int hend = min(hstart + kernel_size_h, height + pad_h);
+ int wend = min(wstart + kernel_size_w, width + pad_w);
int pool_size = (hend - hstart) * (wend - wstart);
gradient += top_diff[ph * pooled_width + pw] / pool_size;
}
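// Aside (editor's worked example, not part of this patch): a 1-D view of the
// accumulation above; each top gradient is spread evenly over its window
// (kernel 2, stride 2, pad 0, input length 4):
//   top_diff    = [0.8, 0.4]
//   bottom_diff = [0.8/2, 0.8/2, 0.4/2, 0.4/2] = [0.4, 0.4, 0.2, 0.2]
// With overlapping windows (stride < kernel), an input element sums
// contributions from every window that covers it.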
__global__ void StoPoolBackward(const int nthreads,
const Dtype* rand_idx, const Dtype* top_diff,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
- const int kernel_size, const int stride, Dtype* bottom_diff) {
+ const int kernel_size_h, const int kernel_size_w, const int stride_h,
+ const int stride_w, Dtype* bottom_diff) {
CUDA_KERNEL_LOOP(index, nthreads) {
// find out the local index
// find out the local offset
int w = index % width;
int h = (index / width) % height;
int c = (index / width / height) % channels;
int n = index / width / height / channels;
- int phstart = (h < kernel_size) ? 0 : (h - kernel_size) / stride + 1;
- int phend = min(h / stride + 1, pooled_height);
- int pwstart = (w < kernel_size) ? 0 : (w - kernel_size) / stride + 1;
- int pwend = min(w / stride + 1, pooled_width);
+ int phstart = (h < kernel_size_h) ? 0 : (h - kernel_size_h) / stride_h + 1;
+ int phend = min(h / stride_h + 1, pooled_height);
+ int pwstart = (w < kernel_size_w) ? 0 : (w - kernel_size_w) / stride_w + 1;
+ int pwend = min(w / stride_w + 1, pooled_width);
Dtype gradient = 0;
rand_idx += (n * channels + c) * pooled_height * pooled_width;
top_diff += (n * channels + c) * pooled_height * pooled_width;
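// Aside (editor's hedged sketch, not necessarily the verbatim source): the
// body of this loop typically routes each top gradient to the single input
// element whose flat index was recorded in rand_idx on the forward pass:
//   for (int ph = phstart; ph < phend; ++ph) {
//     for (int pw = pwstart; pw < pwend; ++pw) {
//       gradient += top_diff[ph * pooled_width + pw] *
//           (index == static_cast<int>(rand_idx[ph * pooled_width + pw]));
//     }
//   }
//   bottom_diff[index] = gradient;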
MaxPoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, top_diff, mask, top_mask, top[0]->num(), channels_,
height_, width_, pooled_height_, pooled_width_,
- kernel_size_, stride_, pad_, bottom_diff);
+ kernel_size_h_, kernel_size_w_, stride_h_, stride_w_, pad_h_, pad_w_,
+ bottom_diff);
break;
case PoolingParameter_PoolMethod_AVE:
// NOLINT_NEXT_LINE(whitespace/operators)
AvePoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, top_diff, top[0]->num(), channels_,
- height_, width_, pooled_height_, pooled_width_, kernel_size_, stride_,
- pad_, bottom_diff);
+ height_, width_, pooled_height_, pooled_width_, kernel_size_h_,
+ kernel_size_w_, stride_h_, stride_w_, pad_h_, pad_w_, bottom_diff);
break;
case PoolingParameter_PoolMethod_STOCHASTIC:
// NOLINT_NEXT_LINE(whitespace/operators)
StoPoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, rand_idx_.gpu_data(), top_diff,
top[0]->num(), channels_, height_, width_, pooled_height_,
- pooled_width_, kernel_size_, stride_, bottom_diff);
+ pooled_width_, kernel_size_h_, kernel_size_w_, stride_h_, stride_w_,
+ bottom_diff);
break;
default:
LOG(FATAL) << "Unknown pooling method.";