const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
int max_top_blobs_;
- int kernel_size_h_, kernel_size_w_;
+ int kernel_h_, kernel_w_;
int stride_h_, stride_w_;
int pad_h_, pad_w_;
int channels_;
&& pool_param.has_stride_w())
|| (!pool_param.has_stride_h() && !pool_param.has_stride_w()))
<< "Stride is stride OR stride_h and stride_w are required.";
-
if (pool_param.has_kernel_size()) {
kernel_h_ = kernel_w_ = pool_param.kernel_size();
} else {
kernel_h_ = pool_param.kernel_h();
kernel_w_ = pool_param.kernel_w();
}
- CHECK_GT(kernel_h_ * kernel_w_, 0) << "Filter dimensions cannot be zero.";
+ CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero.";
+ CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero.";
if (!pool_param.has_pad_h()) {
pad_h_ = pad_w_ = pool_param.pad();
} else {
|| this->layer_param_.pooling_param().pool()
== PoolingParameter_PoolMethod_MAX)
<< "Padding implemented only for average and max pooling.";
- CHECK_LT(pad_h_, kernel_size_h_);
- CHECK_LT(pad_w_, kernel_size_w_);
+ CHECK_LT(pad_h_, kernel_h_);
+ CHECK_LT(pad_w_, kernel_w_);
}
channels_ = bottom[0]->channels();
height_ = bottom[0]->height();
width_ = bottom[0]->width();
pooled_height_ = static_cast<int>(ceil(static_cast<float>(
- height_ + 2 * pad_h_ - kernel_size_h_) / stride_h_)) + 1;
+ height_ + 2 * pad_h_ - kernel_h_) / stride_h_)) + 1;
pooled_width_ = static_cast<int>(ceil(static_cast<float>(
- width_ + 2 * pad_w_ - kernel_size_w_) / stride_w_)) + 1;
+ width_ + 2 * pad_w_ - kernel_w_) / stride_w_)) + 1;
if (pad_h_ || pad_w_) {
// If we have padding, ensure that the last pooling starts strictly
// inside the image (instead of at the padding); otherwise clip the last.
for (int pw = 0; pw < pooled_width_; ++pw) {
int hstart = ph * stride_h_ - pad_h_;
int wstart = pw * stride_w_ - pad_w_;
- int hend = min(hstart + kernel_size_h_, height_);
- int wend = min(wstart + kernel_size_w_, width_);
+ int hend = min(hstart + kernel_h_, height_);
+ int wend = min(wstart + kernel_w_, width_);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
const int pool_index = ph * pooled_width_ + pw;
for (int pw = 0; pw < pooled_width_; ++pw) {
int hstart = ph * stride_h_ - pad_h_;
int wstart = pw * stride_w_ - pad_w_;
- int hend = min(hstart + kernel_size_h_, height_ + pad_h_);
- int wend = min(wstart + kernel_size_w_, width_ + pad_w_);
+ int hend = min(hstart + kernel_h_, height_ + pad_h_);
+ int wend = min(wstart + kernel_w_, width_ + pad_w_);
int pool_size = (hend - hstart) * (wend - wstart);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
for (int pw = 0; pw < pooled_width_; ++pw) {
int hstart = ph * stride_h_ - pad_h_;
int wstart = pw * stride_w_ - pad_w_;
- int hend = min(hstart + kernel_size_h_, height_ + pad_h_);
- int wend = min(wstart + kernel_size_w_, width_ + pad_w_);
+ int hend = min(hstart + kernel_h_, height_ + pad_h_);
+ int wend = min(wstart + kernel_w_, width_ + pad_w_);
int pool_size = (hend - hstart) * (wend - wstart);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
__global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
- const int kernel_size_h, const int kernel_size_w, const int stride_h,
+ const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w, Dtype* top_data,
int* mask, Dtype* top_mask) {
CUDA_KERNEL_LOOP(index, nthreads) {
int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride_h - pad_h;
int wstart = pw * stride_w - pad_w;
- int hend = min(hstart + kernel_size_h, height);
- int wend = min(wstart + kernel_size_w, width);
+ int hend = min(hstart + kernel_h, height);
+ int wend = min(wstart + kernel_w, width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
Dtype maxval = -FLT_MAX;
__global__ void AvePoolForward(const int nthreads, const Dtype* bottom_data,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
- const int kernel_size_h, const int kernel_size_w, const int stride_h,
+ const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w, Dtype* top_data) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride_h - pad_h;
int wstart = pw * stride_w - pad_w;
- int hend = min(hstart + kernel_size_h, height + pad_h);
- int wend = min(wstart + kernel_size_w, width + pad_w);
+ int hend = min(hstart + kernel_h, height + pad_h);
+ int wend = min(wstart + kernel_w, width + pad_w);
int pool_size = (hend - hstart) * (wend - wstart);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
const Dtype* bottom_data,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
- const int kernel_size_h, const int kernel_size_w, const int stride_h,
+ const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, Dtype* rand_idx, Dtype* top_data) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride_h;
- int hend = min(hstart + kernel_size_h, height);
+ int hend = min(hstart + kernel_h, height);
int wstart = pw * stride_w;
- int wend = min(wstart + kernel_size_w, width);
+ int wend = min(wstart + kernel_w, width);
Dtype cumsum = 0.;
bottom_data += (n * channels + c) * height * width;
// First pass: get sum
const Dtype* bottom_data,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
- const int kernel_size_h, const int kernel_size_w, const int stride_h,
+ const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, Dtype* top_data) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride_h;
- int hend = min(hstart + kernel_size_h, height);
+ int hend = min(hstart + kernel_h, height);
int wstart = pw * stride_w;
- int wend = min(wstart + kernel_size_w, width);
+ int wend = min(wstart + kernel_w, width);
// We set cumsum to be 0 to avoid divide-by-zero problems
Dtype cumsum = FLT_MIN;
Dtype cumvalues = 0.;
// NOLINT_NEXT_LINE(whitespace/operators)
MaxPoolForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, bottom[0]->num(), channels_,
- height_, width_, pooled_height_, pooled_width_, kernel_size_h_,
- kernel_size_w_, stride_h_, stride_w_, pad_h_, pad_w_, top_data,
+ height_, width_, pooled_height_, pooled_width_, kernel_h_,
+ kernel_w_, stride_h_, stride_w_, pad_h_, pad_w_, top_data,
mask, top_mask);
break;
case PoolingParameter_PoolMethod_AVE:
// NOLINT_NEXT_LINE(whitespace/operators)
AvePoolForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, bottom[0]->num(), channels_,
- height_, width_, pooled_height_, pooled_width_, kernel_size_h_,
- kernel_size_w_, stride_h_, stride_w_, pad_h_, pad_w_, top_data);
+ height_, width_, pooled_height_, pooled_width_, kernel_h_,
+ kernel_w_, stride_h_, stride_w_, pad_h_, pad_w_, top_data);
break;
case PoolingParameter_PoolMethod_STOCHASTIC:
if (Caffe::phase() == Caffe::TRAIN) {
StoPoolForwardTrain<Dtype><<<CAFFE_GET_BLOCKS(count),
CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, bottom[0]->num(), channels_,
- height_, width_, pooled_height_, pooled_width_, kernel_size_h_,
- kernel_size_w_, stride_h_, stride_w_,
+ height_, width_, pooled_height_, pooled_width_, kernel_h_,
+ kernel_w_, stride_h_, stride_w_,
rand_idx_.mutable_gpu_data(), top_data);
} else {
// NOLINT_NEXT_LINE(whitespace/operators)
StoPoolForwardTest<Dtype><<<CAFFE_GET_BLOCKS(count),
CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, bottom[0]->num(), channels_,
- height_, width_, pooled_height_, pooled_width_, kernel_size_h_,
- kernel_size_w_, stride_h_, stride_w_, top_data);
+ height_, width_, pooled_height_, pooled_width_, kernel_h_,
+ kernel_w_, stride_h_, stride_w_, top_data);
}
break;
default:
__global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
const int* mask, const Dtype* top_mask, const int num, const int channels,
const int height, const int width, const int pooled_height,
- const int pooled_width, const int kernel_size_h, const int kernel_size_w,
+ const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
Dtype* bottom_diff) {
CUDA_KERNEL_LOOP(index, nthreads) {
int c = (index / width / height) % channels;
int n = index / width / height / channels;
int phstart =
- (h + pad_h < kernel_size_h) ? 0 : (h + pad_h - kernel_size_h) / stride_h + 1;
+ (h + pad_h < kernel_h) ? 0 : (h + pad_h - kernel_h) / stride_h + 1;
int phend = min((h + pad_h) / stride_h + 1, pooled_height);
int pwstart =
- (w + pad_w < kernel_size_w) ? 0 : (w + pad_w - kernel_size_w) / stride_w + 1;
+ (w + pad_w < kernel_w) ? 0 : (w + pad_w - kernel_w) / stride_w + 1;
int pwend = min((w + pad_w) / stride_w + 1, pooled_width);
Dtype gradient = 0;
int offset = (n * channels + c) * pooled_height * pooled_width;
__global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
- const int kernel_size_h, const int kernel_size_w, const int stride_h,
+ const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w,
Dtype* bottom_diff) {
CUDA_KERNEL_LOOP(index, nthreads) {
int h = (index / width) % height + pad_h;
int c = (index / width / height) % channels;
int n = index / width / height / channels;
- int phstart = (h < kernel_size_h) ? 0 : (h - kernel_size_h) / stride_h + 1;
+ int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
int phend = min(h / stride_h + 1, pooled_height);
- int pwstart = (w < kernel_size_w) ? 0 : (w - kernel_size_w) / stride_w + 1;
+ int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
int pwend = min(w / stride_w + 1, pooled_width);
Dtype gradient = 0;
top_diff += (n * channels + c) * pooled_height * pooled_width;
// figure out the pooling size
int hstart = ph * stride_h - pad_h;
int wstart = pw * stride_w - pad_w;
- int hend = min(hstart + kernel_size_h, height + pad_h);
- int wend = min(wstart + kernel_size_w, width + pad_w);
+ int hend = min(hstart + kernel_h, height + pad_h);
+ int wend = min(wstart + kernel_w, width + pad_w);
int pool_size = (hend - hstart) * (wend - wstart);
gradient += top_diff[ph * pooled_width + pw] / pool_size;
}
const Dtype* rand_idx, const Dtype* top_diff,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
- const int kernel_size_h, const int kernel_size_w, const int stride_h,
+ const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, Dtype* bottom_diff) {
CUDA_KERNEL_LOOP(index, nthreads) {
// find out the local index
int h = (index / width) % height;
int c = (index / width / height) % channels;
int n = index / width / height / channels;
- int phstart = (h < kernel_size_h) ? 0 : (h - kernel_size_h) / stride_h + 1;
+ int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
int phend = min(h / stride_h + 1, pooled_height);
- int pwstart = (w < kernel_size_w) ? 0 : (w - kernel_size_w) / stride_w + 1;
+ int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
int pwend = min(w / stride_w + 1, pooled_width);
Dtype gradient = 0;
rand_idx += (n * channels + c) * pooled_height * pooled_width;
MaxPoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, top_diff, mask, top_mask, top[0]->num(), channels_,
height_, width_, pooled_height_, pooled_width_,
- kernel_size_h_, kernel_size_w_, stride_h_, stride_w_, pad_h_, pad_w_,
+ kernel_h_, kernel_w_, stride_h_, stride_w_, pad_h_, pad_w_,
bottom_diff);
break;
case PoolingParameter_PoolMethod_AVE:
// NOLINT_NEXT_LINE(whitespace/operators)
AvePoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, top_diff, top[0]->num(), channels_,
- height_, width_, pooled_height_, pooled_width_, kernel_size_h_,
- kernel_size_w_, stride_h_, stride_w_, pad_h_, pad_w_, bottom_diff);
+ height_, width_, pooled_height_, pooled_width_, kernel_h_,
+ kernel_w_, stride_h_, stride_w_, pad_h_, pad_w_, bottom_diff);
break;
case PoolingParameter_PoolMethod_STOCHASTIC:
// NOLINT_NEXT_LINE(whitespace/operators)
StoPoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, rand_idx_.gpu_data(), top_diff,
top[0]->num(), channels_, height_, width_, pooled_height_,
- pooled_width_, kernel_size_h_, kernel_size_w_, stride_h_, stride_w_,
+ pooled_width_, kernel_h_, kernel_w_, stride_h_, stride_w_,
bottom_diff);
break;
default:
optional uint32 pad_h = 9 [default = 0]; // The padding height
optional uint32 pad_w = 10 [default = 0]; // The padding width
optional uint32 kernel_size = 2; // The kernel size (square)
- optional uint32 kernel_size_h = 5; // The kernel height
- optional uint32 kernel_size_w = 6; // The kernel width
+ optional uint32 kernel_h = 5; // The kernel height
+ optional uint32 kernel_w = 6; // The kernel width
optional uint32 stride = 3 [default = 1]; // The stride (equal in Y, X)
optional uint32 stride_h = 7; // The stride height
optional uint32 stride_w = 8; // The stride width
Blob<Dtype>* const blob_top_mask_;
vector<Blob<Dtype>*> blob_bottom_vec_;
vector<Blob<Dtype>*> blob_top_vec_;
-
- void TestForward() {
+ // Test for 2x2 square pooling layer
+ void TestForwardSquare() {
LayerParameter layer_param;
PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
pooling_param->set_kernel_size(2);
}
}
}
+ // Test for a 3x2 rectangular MAX pooling layer (kernel_h = 3 > kernel_w = 2)
+ void TestForwardRectHigh() {
+ LayerParameter layer_param;
+ PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
+ pooling_param->set_kernel_h(3);
+ pooling_param->set_kernel_w(2);
+ pooling_param->set_pool(PoolingParameter_PoolMethod_MAX);
+ const int num = 2;
+ const int channels = 2;
+ blob_bottom_->Reshape(num, channels, 6, 6);
+ // Input: num x channels (2 x 2) copies of this 6x6 matrix, row by row:
+ // [35 1 6 26 19 24]
+ // [ 3 32 7 21 23 25]
+ // [31 9 2 22 27 20]
+ // [ 8 28 33 17 10 15]
+ // [30 5 34 12 14 16]
+ // [ 4 36 29 13 18 11] (this is generated by magic(6) in MATLAB)
+ for (int i = 0; i < 36 * num * channels; i += 36) {  // one 6x6 plane per step
+ blob_bottom_->mutable_cpu_data()[i + 0] = 35;
+ blob_bottom_->mutable_cpu_data()[i + 1] = 1;
+ blob_bottom_->mutable_cpu_data()[i + 2] = 6;
+ blob_bottom_->mutable_cpu_data()[i + 3] = 26;
+ blob_bottom_->mutable_cpu_data()[i + 4] = 19;
+ blob_bottom_->mutable_cpu_data()[i + 5] = 24;
+ blob_bottom_->mutable_cpu_data()[i + 6] = 3;
+ blob_bottom_->mutable_cpu_data()[i + 7] = 32;
+ blob_bottom_->mutable_cpu_data()[i + 8] = 7;
+ blob_bottom_->mutable_cpu_data()[i + 9] = 21;
+ blob_bottom_->mutable_cpu_data()[i + 10] = 23;
+ blob_bottom_->mutable_cpu_data()[i + 11] = 25;
+ blob_bottom_->mutable_cpu_data()[i + 12] = 31;
+ blob_bottom_->mutable_cpu_data()[i + 13] = 9;
+ blob_bottom_->mutable_cpu_data()[i + 14] = 2;
+ blob_bottom_->mutable_cpu_data()[i + 15] = 22;
+ blob_bottom_->mutable_cpu_data()[i + 16] = 27;
+ blob_bottom_->mutable_cpu_data()[i + 17] = 20;
+ blob_bottom_->mutable_cpu_data()[i + 18] = 8;
+ blob_bottom_->mutable_cpu_data()[i + 19] = 28;
+ blob_bottom_->mutable_cpu_data()[i + 20] = 33;
+ blob_bottom_->mutable_cpu_data()[i + 21] = 17;
+ blob_bottom_->mutable_cpu_data()[i + 22] = 10;
+ blob_bottom_->mutable_cpu_data()[i + 23] = 15;
+ blob_bottom_->mutable_cpu_data()[i + 24] = 30;
+ blob_bottom_->mutable_cpu_data()[i + 25] = 5;
+ blob_bottom_->mutable_cpu_data()[i + 26] = 34;
+ blob_bottom_->mutable_cpu_data()[i + 27] = 12;
+ blob_bottom_->mutable_cpu_data()[i + 28] = 14;
+ blob_bottom_->mutable_cpu_data()[i + 29] = 16;
+ blob_bottom_->mutable_cpu_data()[i + 30] = 4;
+ blob_bottom_->mutable_cpu_data()[i + 31] = 36;
+ blob_bottom_->mutable_cpu_data()[i + 32] = 29;
+ blob_bottom_->mutable_cpu_data()[i + 33] = 13;
+ blob_bottom_->mutable_cpu_data()[i + 34] = 18;
+ blob_bottom_->mutable_cpu_data()[i + 35] = 11;
+ }
+ PoolingLayer<Dtype> layer(layer_param);
+ layer.SetUp(blob_bottom_vec_, &blob_top_vec_);
+ EXPECT_EQ(blob_top_->num(), num);
+ EXPECT_EQ(blob_top_->channels(), channels);
+ EXPECT_EQ(blob_top_->height(), 4);  // 6 - kernel_h + 1
+ EXPECT_EQ(blob_top_->width(), 5);  // 6 - kernel_w + 1
+ if (blob_top_vec_.size() > 1) {
+ EXPECT_EQ(blob_top_mask_->num(), num);
+ EXPECT_EQ(blob_top_mask_->channels(), channels);
+ EXPECT_EQ(blob_top_mask_->height(), 4);
+ EXPECT_EQ(blob_top_mask_->width(), 5);
+ }
+ layer.Forward(blob_bottom_vec_, &blob_top_vec_);
+ // Expected output: num x channels (2 x 2) copies of this 4x5 matrix:
+ // [35 32 26 27 27]
+ // [32 33 33 27 27]
+ // [31 34 34 27 27]
+ // [36 36 34 18 18]
+ for (int i = 0; i < 20 * num * channels; i += 20) {  // one 4x5 plane per step
+ EXPECT_EQ(blob_top_->cpu_data()[i + 0], 35);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 1], 32);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 2], 26);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 3], 27);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 4], 27);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 5], 32);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 6], 33);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 7], 33);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 8], 27);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 9], 27);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 10], 31);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 11], 34);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 12], 34);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 13], 27);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 14], 27);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 15], 36);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 16], 36);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 17], 34);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 18], 18);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 19], 18);
+ }
+ if (blob_top_vec_.size() > 1) {
+ // [ 0  7  3 16 16]   <- expected mask: 0-based index into the 6x6 input
+ // [ 7 20 20 16 16]
+ // [12 26 26 16 16]
+ // [31 31 26 34 34]
+ for (int i = 0; i < 20 * num * channels; i += 20) {
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 0], 0);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 1], 7);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 2], 3);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 3], 16);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 4], 16);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 5], 7);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 6], 20);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 7], 20);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 8], 16);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 9], 16);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 10], 12);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 11], 26);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 12], 26);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 13], 16);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 14], 16);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 15], 31);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 16], 31);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 17], 26);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 18], 34);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 19], 34);
+ }
+ }
+ }
+ // Test for a 2x3 rectangular MAX pooling layer (kernel_h = 2 < kernel_w = 3)
+ void TestForwardRectWide() {
+ LayerParameter layer_param;
+ PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
+ pooling_param->set_kernel_h(2);
+ pooling_param->set_kernel_w(3);
+ pooling_param->set_pool(PoolingParameter_PoolMethod_MAX);
+ const int num = 2;
+ const int channels = 2;
+ blob_bottom_->Reshape(num, channels, 6, 6);
+ // Input: num x channels (2 x 2) copies of this 6x6 matrix, row by row:
+ // [35 1 6 26 19 24]
+ // [ 3 32 7 21 23 25]
+ // [31 9 2 22 27 20]
+ // [ 8 28 33 17 10 15]
+ // [30 5 34 12 14 16]
+ // [ 4 36 29 13 18 11] (this is generated by magic(6) in MATLAB)
+ for (int i = 0; i < 36 * num * channels; i += 36) {  // one 6x6 plane per step
+ blob_bottom_->mutable_cpu_data()[i + 0] = 35;
+ blob_bottom_->mutable_cpu_data()[i + 1] = 1;
+ blob_bottom_->mutable_cpu_data()[i + 2] = 6;
+ blob_bottom_->mutable_cpu_data()[i + 3] = 26;
+ blob_bottom_->mutable_cpu_data()[i + 4] = 19;
+ blob_bottom_->mutable_cpu_data()[i + 5] = 24;
+ blob_bottom_->mutable_cpu_data()[i + 6] = 3;
+ blob_bottom_->mutable_cpu_data()[i + 7] = 32;
+ blob_bottom_->mutable_cpu_data()[i + 8] = 7;
+ blob_bottom_->mutable_cpu_data()[i + 9] = 21;
+ blob_bottom_->mutable_cpu_data()[i + 10] = 23;
+ blob_bottom_->mutable_cpu_data()[i + 11] = 25;
+ blob_bottom_->mutable_cpu_data()[i + 12] = 31;
+ blob_bottom_->mutable_cpu_data()[i + 13] = 9;
+ blob_bottom_->mutable_cpu_data()[i + 14] = 2;
+ blob_bottom_->mutable_cpu_data()[i + 15] = 22;
+ blob_bottom_->mutable_cpu_data()[i + 16] = 27;
+ blob_bottom_->mutable_cpu_data()[i + 17] = 20;
+ blob_bottom_->mutable_cpu_data()[i + 18] = 8;
+ blob_bottom_->mutable_cpu_data()[i + 19] = 28;
+ blob_bottom_->mutable_cpu_data()[i + 20] = 33;
+ blob_bottom_->mutable_cpu_data()[i + 21] = 17;
+ blob_bottom_->mutable_cpu_data()[i + 22] = 10;
+ blob_bottom_->mutable_cpu_data()[i + 23] = 15;
+ blob_bottom_->mutable_cpu_data()[i + 24] = 30;
+ blob_bottom_->mutable_cpu_data()[i + 25] = 5;
+ blob_bottom_->mutable_cpu_data()[i + 26] = 34;
+ blob_bottom_->mutable_cpu_data()[i + 27] = 12;
+ blob_bottom_->mutable_cpu_data()[i + 28] = 14;
+ blob_bottom_->mutable_cpu_data()[i + 29] = 16;
+ blob_bottom_->mutable_cpu_data()[i + 30] = 4;
+ blob_bottom_->mutable_cpu_data()[i + 31] = 36;
+ blob_bottom_->mutable_cpu_data()[i + 32] = 29;
+ blob_bottom_->mutable_cpu_data()[i + 33] = 13;
+ blob_bottom_->mutable_cpu_data()[i + 34] = 18;
+ blob_bottom_->mutable_cpu_data()[i + 35] = 11;
+ }
+ PoolingLayer<Dtype> layer(layer_param);
+ layer.SetUp(blob_bottom_vec_, &blob_top_vec_);
+ EXPECT_EQ(blob_top_->num(), num);
+ EXPECT_EQ(blob_top_->channels(), channels);
+ EXPECT_EQ(blob_top_->height(), 5);  // 6 - kernel_h + 1
+ EXPECT_EQ(blob_top_->width(), 4);  // 6 - kernel_w + 1
+ if (blob_top_vec_.size() > 1) {
+ EXPECT_EQ(blob_top_mask_->num(), num);
+ EXPECT_EQ(blob_top_mask_->channels(), channels);
+ EXPECT_EQ(blob_top_mask_->height(), 5);
+ EXPECT_EQ(blob_top_mask_->width(), 4);
+ }
+ layer.Forward(blob_bottom_vec_, &blob_top_vec_);
+ // Expected output: num x channels (2 x 2) copies of this 5x4 matrix:
+ // [35 32 26 26]
+ // [32 32 27 27]
+ // [33 33 33 27]
+ // [34 34 34 17]
+ // [36 36 34 18]
+ for (int i = 0; i < 20 * num * channels; i += 20) {  // one 5x4 plane per step
+ EXPECT_EQ(blob_top_->cpu_data()[i + 0], 35);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 1], 32);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 2], 26);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 3], 26);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 4], 32);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 5], 32);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 6], 27);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 7], 27);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 8], 33);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 9], 33);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 10], 33);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 11], 27);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 12], 34);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 13], 34);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 14], 34);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 15], 17);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 16], 36);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 17], 36);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 18], 34);
+ EXPECT_EQ(blob_top_->cpu_data()[i + 19], 18);
+ }
+ if (blob_top_vec_.size() > 1) {
+ // [ 0  7  3  3]   <- expected mask: 0-based index into the 6x6 input
+ // [ 7  7 16 16]
+ // [20 20 20 16]
+ // [26 26 26 21]
+ // [31 31 26 34]
+ for (int i = 0; i < 20 * num * channels; i += 20) {
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 0], 0);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 1], 7);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 2], 3);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 3], 3);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 4], 7);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 5], 7);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 6], 16);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 7], 16);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 8], 20);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 9], 20);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 10], 20);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 11], 16);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 12], 26);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 13], 26);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 14], 26);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 15], 21);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 16], 31);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 17], 31);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 18], 26);
+ EXPECT_EQ(blob_top_mask_->cpu_data()[i + 19], 34);
+ }
+ }
+ }
};
typedef ::testing::Types<float, double> Dtypes;
TYPED_TEST(PoolingLayerTest, TestCPUForwardMax) {
Caffe::set_mode(Caffe::CPU);
- this->TestForward();
+ this->TestForwardSquare();  // kernel_size = 2
+ this->TestForwardRectHigh();  // kernel_h = 3, kernel_w = 2
+ this->TestForwardRectWide();  // kernel_h = 2, kernel_w = 3
}
TYPED_TEST(PoolingLayerTest, TestGPUForwardMax) {
Caffe::set_mode(Caffe::GPU);
- this->TestForward();
+ this->TestForwardSquare();  // kernel_size = 2
+ this->TestForwardRectHigh();  // kernel_h = 3, kernel_w = 2
+ this->TestForwardRectWide();  // kernel_h = 2, kernel_w = 3
}
TYPED_TEST(PoolingLayerTest, TestCPUForwardMaxTopMask) {
Caffe::set_mode(Caffe::CPU);
this->blob_top_vec_.push_back(this->blob_top_mask_);
- this->TestForward();
+ this->TestForwardSquare();  // kernel_size = 2; mask checks enabled
+ this->TestForwardRectHigh();  // kernel_h = 3, kernel_w = 2
+ this->TestForwardRectWide();  // kernel_h = 2, kernel_w = 3
}
TYPED_TEST(PoolingLayerTest, TestGPUForwardMaxTopMask) {
Caffe::set_mode(Caffe::GPU);
this->blob_top_vec_.push_back(this->blob_top_mask_);
- this->TestForward();
+ this->TestForwardSquare();  // kernel_size = 2; mask checks enabled
+ this->TestForwardRectHigh();  // kernel_h = 3, kernel_w = 2
+ this->TestForwardRectWide();  // kernel_h = 2, kernel_w = 3
}
TYPED_TEST(PoolingLayerTest, TestCPUGradientMax) {