CHECK_EQ(top->size(), 1) << "PoolingLayer takes a single blob as output.";
kernel_size_ = this->layer_param_.pooling_param().kernel_size();
stride_ = this->layer_param_.pooling_param().stride();
+ pad_ = this->layer_param_.pooling_param().pad();
+ if (pad_ != 0) {
+ CHECK_EQ(this->layer_param_.pooling_param().pool(),
+ PoolingParameter_PoolMethod_AVE)
+ << "Padding implemented only for average pooling.";
+ }
channels_ = bottom[0]->channels();
height_ = bottom[0]->height();
width_ = bottom[0]->width();
- pooled_height_ = static_cast<int>(
- ceil(static_cast<float>(height_ - kernel_size_) / stride_)) + 1;
- pooled_width_ = static_cast<int>(
- ceil(static_cast<float>(width_ - kernel_size_) / stride_)) + 1;
+ pooled_height_ = static_cast<int>(ceil(static_cast<float>(
+ height_ + 2 * pad_ - kernel_size_) / stride_)) + 1;
+ pooled_width_ = static_cast<int>(ceil(static_cast<float>(
+ width_ + 2 * pad_ - kernel_size_) / stride_)) + 1;
(*top)[0]->Reshape(bottom[0]->num(), channels_, pooled_height_,
pooled_width_);
// If stochastic pooling, we will initialize the random index part.
for (int c = 0; c < channels_; ++c) {
for (int ph = 0; ph < pooled_height_; ++ph) {
for (int pw = 0; pw < pooled_width_; ++pw) {
- int hstart = ph * stride_;
- int wstart = pw * stride_;
- int hend = min(hstart + kernel_size_, height_);
- int wend = min(wstart + kernel_size_, width_);
+ int hstart = ph * stride_ - pad_;
+ int wstart = pw * stride_ - pad_;
+ int hend = min(hstart + kernel_size_, height_ + pad_);
+ int wend = min(wstart + kernel_size_, width_ + pad_);
+ int pool_size = (hend - hstart) * (wend - wstart);
+ hstart = max(hstart, 0);
+ wstart = max(wstart, 0);
+ hend = min(hend, height_);
+ wend = min(wend, width_);
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
top_data[ph * pooled_width_ + pw] +=
bottom_data[h * width_ + w];
}
}
- top_data[ph * pooled_width_ + pw] /=
- (hend - hstart) * (wend - wstart);
+ top_data[ph * pooled_width_ + pw] /= pool_size;
}
}
// compute offset
for (int c = 0; c < channels_; ++c) {
for (int ph = 0; ph < pooled_height_; ++ph) {
for (int pw = 0; pw < pooled_width_; ++pw) {
- int hstart = ph * stride_;
- int wstart = pw * stride_;
- int hend = min(hstart + kernel_size_, height_);
- int wend = min(wstart + kernel_size_, width_);
- int poolsize = (hend - hstart) * (wend - wstart);
+ int hstart = ph * stride_ - pad_;
+ int wstart = pw * stride_ - pad_;
+ int hend = min(hstart + kernel_size_, height_ + pad_);
+ int wend = min(wstart + kernel_size_, width_ + pad_);
+ int pool_size = (hend - hstart) * (wend - wstart);
+ hstart = max(hstart, 0);
+ wstart = max(wstart, 0);
+ hend = min(hend, height_);
+ wend = min(wend, width_);
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
bottom_diff[h * width_ + w] +=
- top_diff[ph * pooled_width_ + pw] / poolsize;
+ top_diff[ph * pooled_width_ + pw] / pool_size;
}
}
}
template <typename Dtype>
__global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
const int num, const int channels, const int height,
- const int width, const int pooled_height_, const int pooled_width,
+ const int width, const int pooled_height, const int pooled_width,
const int kernel_size, const int stride, Dtype* top_data) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
- int ph = (index / pooled_width) % pooled_height_;
- int c = (index / pooled_width / pooled_height_) % channels;
- int n = index / pooled_width / pooled_height_ / channels;
+ int ph = (index / pooled_width) % pooled_height;
+ int c = (index / pooled_width / pooled_height) % channels;
+ int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride;
int hend = min(hstart + kernel_size, height);
int wstart = pw * stride;
template <typename Dtype>
__global__ void AvePoolForward(const int nthreads, const Dtype* bottom_data,
const int num, const int channels, const int height,
- const int width, const int pooled_height_, const int pooled_width,
- const int kernel_size, const int stride, Dtype* top_data) {
+ const int width, const int pooled_height, const int pooled_width,
+ const int kernel_size, const int stride, const int pad, Dtype* top_data) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
- int ph = (index / pooled_width) % pooled_height_;
- int c = (index / pooled_width / pooled_height_) % channels;
- int n = index / pooled_width / pooled_height_ / channels;
- int hstart = ph * stride;
- int hend = min(hstart + kernel_size, height);
- int wstart = pw * stride;
- int wend = min(wstart + kernel_size, width);
+ int ph = (index / pooled_width) % pooled_height;
+ int c = (index / pooled_width / pooled_height) % channels;
+ int n = index / pooled_width / pooled_height / channels;
+ int hstart = ph * stride - pad;
+ int wstart = pw * stride - pad;
+ int hend = min(hstart + kernel_size, height + pad);
+ int wend = min(wstart + kernel_size, width + pad);
+ int pool_size = (hend - hstart) * (wend - wstart);
+ hstart = max(hstart, 0);
+ wstart = max(wstart, 0);
+ hend = min(hend, height);
+ wend = min(wend, width);
Dtype aveval = 0;
bottom_data += (n * channels + c) * height * width;
for (int h = hstart; h < hend; ++h) {
aveval += bottom_data[h * width + w];
}
}
- top_data[index] = aveval / (hend - hstart) / (wend - wstart);
+ top_data[index] = aveval / pool_size;
}
}
__global__ void StoPoolForwardTrain(const int nthreads,
const Dtype* bottom_data,
const int num, const int channels, const int height,
- const int width, const int pooled_height_, const int pooled_width,
+ const int width, const int pooled_height, const int pooled_width,
const int kernel_size, const int stride, float* rand_idx, Dtype* top_data) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
- int ph = (index / pooled_width) % pooled_height_;
- int c = (index / pooled_width / pooled_height_) % channels;
- int n = index / pooled_width / pooled_height_ / channels;
+ int ph = (index / pooled_width) % pooled_height;
+ int c = (index / pooled_width / pooled_height) % channels;
+ int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride;
int hend = min(hstart + kernel_size, height);
int wstart = pw * stride;
__global__ void StoPoolForwardTest(const int nthreads,
const Dtype* bottom_data,
const int num, const int channels, const int height,
- const int width, const int pooled_height_, const int pooled_width,
+ const int width, const int pooled_height, const int pooled_width,
const int kernel_size, const int stride, Dtype* top_data) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
- int ph = (index / pooled_width) % pooled_height_;
- int c = (index / pooled_width / pooled_height_) % channels;
- int n = index / pooled_width / pooled_height_ / channels;
+ int ph = (index / pooled_width) % pooled_height;
+ int c = (index / pooled_width / pooled_height) % channels;
+ int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride;
int hend = min(hstart + kernel_size, height);
int wstart = pw * stride;
AvePoolForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, bottom[0]->num(), channels_,
height_, width_, pooled_height_, pooled_width_, kernel_size_, stride_,
- top_data);
+ pad_, top_data);
break;
case PoolingParameter_PoolMethod_STOCHASTIC:
if (Caffe::phase() == Caffe::TRAIN) {
__global__ void MaxPoolBackward(const int nthreads, const Dtype* bottom_data,
const Dtype* top_data, const Dtype* top_diff,
const int num, const int channels, const int height,
- const int width, const int pooled_height_, const int pooled_width,
+ const int width, const int pooled_height, const int pooled_width,
const int kernel_size, const int stride, Dtype* bottom_diff) {
CUDA_KERNEL_LOOP(index, nthreads) {
// find out the local index
int c = (index / width / height) % channels;
int n = index / width / height / channels;
int phstart = (h < kernel_size) ? 0 : (h - kernel_size) / stride + 1;
- int phend = min(h / stride + 1, pooled_height_);
+ int phend = min(h / stride + 1, pooled_height);
int pwstart = (w < kernel_size) ? 0 : (w - kernel_size) / stride + 1;
int pwend = min(w / stride + 1, pooled_width);
Dtype gradient = 0;
Dtype bottom_datum =
bottom_data[((n * channels + c) * height + h) * width + w];
- top_data += (n * channels + c) * pooled_height_ * pooled_width;
- top_diff += (n * channels + c) * pooled_height_ * pooled_width;
+ top_data += (n * channels + c) * pooled_height * pooled_width;
+ top_diff += (n * channels + c) * pooled_height * pooled_width;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
gradient += top_diff[ph * pooled_width + pw] *
template <typename Dtype>
__global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff,
const int num, const int channels, const int height,
- const int width, const int pooled_height_, const int pooled_width,
- const int kernel_size, const int stride, Dtype* bottom_diff) {
+ const int width, const int pooled_height, const int pooled_width,
+ const int kernel_size, const int stride, const int pad,
+ Dtype* bottom_diff) {
CUDA_KERNEL_LOOP(index, nthreads) {
// find out the local index
// find out the local offset
- int w = index % width;
- int h = (index / width) % height;
+ int w = index % width + pad;
+ int h = (index / width) % height + pad;
int c = (index / width / height) % channels;
int n = index / width / height / channels;
int phstart = (h < kernel_size) ? 0 : (h - kernel_size) / stride + 1;
- int phend = min(h / stride + 1, pooled_height_);
+ int phend = min(h / stride + 1, pooled_height);
int pwstart = (w < kernel_size) ? 0 : (w - kernel_size) / stride + 1;
int pwend = min(w / stride + 1, pooled_width);
Dtype gradient = 0;
- top_diff += (n * channels + c) * pooled_height_ * pooled_width;
+ top_diff += (n * channels + c) * pooled_height * pooled_width;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
// figure out the pooling size
- int poolsize = (min(ph * stride + kernel_size, height) - ph * stride) *
- (min(pw * stride + kernel_size, width) - pw * stride);
- gradient += top_diff[ph * pooled_width + pw] / poolsize;
+ int hstart = ph * stride - pad;
+ int wstart = pw * stride - pad;
+ int hend = min(hstart + kernel_size, height + pad);
+ int wend = min(wstart + kernel_size, width + pad);
+ int pool_size = (hend - hstart) * (wend - wstart);
+ gradient += top_diff[ph * pooled_width + pw] / pool_size;
}
}
bottom_diff[index] = gradient;
__global__ void StoPoolBackward(const int nthreads,
const float* rand_idx, const Dtype* top_diff,
const int num, const int channels, const int height,
- const int width, const int pooled_height_, const int pooled_width,
+ const int width, const int pooled_height, const int pooled_width,
const int kernel_size, const int stride, Dtype* bottom_diff) {
CUDA_KERNEL_LOOP(index, nthreads) {
// find out the local index
int c = (index / width / height) % channels;
int n = index / width / height / channels;
int phstart = (h < kernel_size) ? 0 : (h - kernel_size) / stride + 1;
- int phend = min(h / stride + 1, pooled_height_);
+ int phend = min(h / stride + 1, pooled_height);
int pwstart = (w < kernel_size) ? 0 : (w - kernel_size) / stride + 1;
int pwend = min(w / stride + 1, pooled_width);
Dtype gradient = 0;
- rand_idx += (n * channels + c) * pooled_height_ * pooled_width;
- top_diff += (n * channels + c) * pooled_height_ * pooled_width;
+ rand_idx += (n * channels + c) * pooled_height * pooled_width;
+ top_diff += (n * channels + c) * pooled_height * pooled_width;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
gradient += top_diff[ph * pooled_width + pw] *
AvePoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, top_diff, top[0]->num(), channels_,
height_, width_, pooled_height_, pooled_width_, kernel_size_, stride_,
- bottom_diff);
+ pad_, bottom_diff);
break;
case PoolingParameter_PoolMethod_STOCHASTIC:
// NOLINT_NEXT_LINE(whitespace/operators)
EXPECT_EQ(this->blob_top_->width(), 2);
}
+// Verifies that SetUp accounts for padding when sizing the top blob:
+// pooled dim = ceil((dim + 2*pad - kernel) / stride) + 1.  With kernel 3,
+// stride 2, pad 1 the expected 4x3 output implies the fixture bottom is
+// 6x5 spatially -- TODO confirm against the test fixture's Reshape.
+TYPED_TEST(PoolingLayerTest, TestSetupPadded) {
+ LayerParameter layer_param;
+ PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
+ pooling_param->set_kernel_size(3);
+ pooling_param->set_stride(2);
+ pooling_param->set_pad(1);
+ // Padding is only implemented for average pooling (SetUp CHECKs this).
+ pooling_param->set_pool(PoolingParameter_PoolMethod_AVE);
+ PoolingLayer<TypeParam> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+ EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_->num());
+ EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_->channels());
+ EXPECT_EQ(this->blob_top_->height(), 4);
+ EXPECT_EQ(this->blob_top_->width(), 3);
+}
+
/*
TYPED_TEST(PoolingLayerTest, PrintGPUBackward) {
LayerParameter layer_param;
}
+// CPU forward check for padded average pooling on a 3x3 input filled with
+// the constant 2.  Kernel 3, stride 1, pad 1 keeps the 3x3 spatial size,
+// and the divisor is the full window size (9) including padded cells, so:
+//   corners sum 4 valid cells  -> 2*4/9 = 8/9
+//   edges   sum 6 valid cells  -> 2*6/9 = 4/3
+//   center  sum 9 valid cells  -> 2*9/9 = 2
+TYPED_TEST(PoolingLayerTest, TestCPUForwardAve) {
+ LayerParameter layer_param;
+ PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
+ pooling_param->set_kernel_size(3);
+ pooling_param->set_stride(1);
+ pooling_param->set_pad(1);
+ pooling_param->set_pool(PoolingParameter_PoolMethod_AVE);
+ Caffe::set_mode(Caffe::CPU);
+ this->blob_bottom_->Reshape(1, 1, 3, 3);
+ FillerParameter filler_param;
+ // Constant fill makes every expected output value easy to compute by hand.
+ filler_param.set_value(TypeParam(2));
+ ConstantFiller<TypeParam> filler(filler_param);
+ filler.Fill(this->blob_bottom_);
+ PoolingLayer<TypeParam> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+ EXPECT_EQ(this->blob_top_->num(), 1);
+ EXPECT_EQ(this->blob_top_->channels(), 1);
+ EXPECT_EQ(this->blob_top_->height(), 3);
+ EXPECT_EQ(this->blob_top_->width(), 3);
+ layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+ TypeParam epsilon = 1e-5;
+ EXPECT_NEAR(this->blob_top_->cpu_data()[0], 8.0 / 9, epsilon);
+ EXPECT_NEAR(this->blob_top_->cpu_data()[1], 4.0 / 3, epsilon);
+ EXPECT_NEAR(this->blob_top_->cpu_data()[2], 8.0 / 9, epsilon);
+ EXPECT_NEAR(this->blob_top_->cpu_data()[3], 4.0 / 3, epsilon);
+ EXPECT_NEAR(this->blob_top_->cpu_data()[4], 2.0, epsilon);
+ EXPECT_NEAR(this->blob_top_->cpu_data()[5], 4.0 / 3, epsilon);
+ EXPECT_NEAR(this->blob_top_->cpu_data()[6], 8.0 / 9, epsilon);
+ EXPECT_NEAR(this->blob_top_->cpu_data()[7], 4.0 / 3, epsilon);
+ EXPECT_NEAR(this->blob_top_->cpu_data()[8], 8.0 / 9, epsilon);
+}
+
+
+// GPU counterpart of TestCPUForwardAve: exercises the AvePoolForward CUDA
+// kernel's pad handling on a 3x3 constant-2 input (kernel 3, stride 1,
+// pad 1).  Expected values mirror the CPU path -- the average divides by
+// the full 3x3 window (9) even where the window overlaps padding:
+// corners 8/9, edges 4/3, center 2.
+TYPED_TEST(PoolingLayerTest, TestGPUForwardAve) {
+ LayerParameter layer_param;
+ PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
+ pooling_param->set_kernel_size(3);
+ pooling_param->set_stride(1);
+ pooling_param->set_pad(1);
+ pooling_param->set_pool(PoolingParameter_PoolMethod_AVE);
+ Caffe::set_mode(Caffe::GPU);
+ this->blob_bottom_->Reshape(1, 1, 3, 3);
+ FillerParameter filler_param;
+ // Constant fill makes every expected output value easy to compute by hand.
+ filler_param.set_value(TypeParam(2));
+ ConstantFiller<TypeParam> filler(filler_param);
+ filler.Fill(this->blob_bottom_);
+ PoolingLayer<TypeParam> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+ EXPECT_EQ(this->blob_top_->num(), 1);
+ EXPECT_EQ(this->blob_top_->channels(), 1);
+ EXPECT_EQ(this->blob_top_->height(), 3);
+ EXPECT_EQ(this->blob_top_->width(), 3);
+ layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+ TypeParam epsilon = 1e-5;
+ EXPECT_NEAR(this->blob_top_->cpu_data()[0], 8.0 / 9, epsilon);
+ EXPECT_NEAR(this->blob_top_->cpu_data()[1], 4.0 / 3, epsilon);
+ EXPECT_NEAR(this->blob_top_->cpu_data()[2], 8.0 / 9, epsilon);
+ EXPECT_NEAR(this->blob_top_->cpu_data()[3], 4.0 / 3, epsilon);
+ EXPECT_NEAR(this->blob_top_->cpu_data()[4], 2.0, epsilon);
+ EXPECT_NEAR(this->blob_top_->cpu_data()[5], 4.0 / 3, epsilon);
+ EXPECT_NEAR(this->blob_top_->cpu_data()[6], 8.0 / 9, epsilon);
+ EXPECT_NEAR(this->blob_top_->cpu_data()[7], 4.0 / 3, epsilon);
+ EXPECT_NEAR(this->blob_top_->cpu_data()[8], 8.0 / 9, epsilon);
+}
+
+
TYPED_TEST(PoolingLayerTest, TestCPUGradientAve) {
LayerParameter layer_param;
PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
}
+// Numeric gradient check for the padded CPU average-pooling backward pass
+// (kernel 3, stride 2, pad 2).  pad > kernel/2 makes some windows overlap
+// padding on both sides, stressing the clamped hstart/hend bookkeeping.
+TYPED_TEST(PoolingLayerTest, TestCPUGradientAvePadded) {
+ LayerParameter layer_param;
+ PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
+ pooling_param->set_kernel_size(3);
+ pooling_param->set_stride(2);
+ pooling_param->set_pad(2);
+ pooling_param->set_pool(PoolingParameter_PoolMethod_AVE);
+ Caffe::set_mode(Caffe::CPU);
+ PoolingLayer<TypeParam> layer(layer_param);
+ // Finite-difference check over all bottom entries (step 1e-2, tol 1e-2).
+ GradientChecker<TypeParam> checker(1e-2, 1e-2);
+ checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
+ &(this->blob_top_vec_));
+}
+
+
+// GPU counterpart of TestCPUGradientAvePadded: numeric gradient check for
+// the AvePoolBackward CUDA kernel with kernel 3, stride 2, pad 2, which
+// must reproduce the same padded pool_size divisor as the forward pass.
+TYPED_TEST(PoolingLayerTest, TestGPUGradientAvePadded) {
+ LayerParameter layer_param;
+ PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
+ pooling_param->set_kernel_size(3);
+ pooling_param->set_stride(2);
+ pooling_param->set_pad(2);
+ pooling_param->set_pool(PoolingParameter_PoolMethod_AVE);
+ Caffe::set_mode(Caffe::GPU);
+ PoolingLayer<TypeParam> layer(layer_param);
+ // Finite-difference check over all bottom entries (step 1e-2, tol 1e-2).
+ GradientChecker<TypeParam> checker(1e-2, 1e-2);
+ checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
+ &(this->blob_top_vec_));
+}
+
+
} // namespace caffe