Scaffold engine switching for pooling.
Caffe pooling is still instantiated without regard for the engine setting in the following places (a usage sketch of the new dispatch follows the list):
- LRNLayer
- PoolingLayer tests
- StochasticPoolingLayer tests
- MaxPoolingDropout tests
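
For reference, once the scaffold is wired through, engine selection happens per layer via PoolingParameter. A minimal usage sketch; the layer name "pool1" and the float type are illustrative, not part of this patch:

    LayerParameter param;
    param.set_name("pool1");
    param.set_type(LayerParameter_LayerType_POOLING);
    param.mutable_pooling_param()->set_pool(PoolingParameter_PoolMethod_MAX);
    param.mutable_pooling_param()->set_kernel_size(3);
    param.mutable_pooling_param()->set_stride(2);
    // Request the (currently only) CAFFE engine; unknown engines LOG(FATAL).
    param.mutable_pooling_param()->set_engine(PoolingParameter_Engine_CAFFE);
    PoolingLayer<float>* pool = GetPoolingLayer<float>(param.name(), param);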
protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
- vector<Blob<Dtype>*>* top);
+ vector<Blob<Dtype>*>* top) = 0;
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
- vector<Blob<Dtype>*>* top);
+ vector<Blob<Dtype>*>* top) = 0;
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
- const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
+ const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) = 0;
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
- const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
+ const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) = 0;
int kernel_h_, kernel_w_;
int stride_h_, stride_w_;
Blob<int> max_idx_;
};
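
With Forward_* and Backward_* now pure virtual, PoolingLayer<Dtype> is abstract and can no longer be constructed directly; callers go through an engine subclass (or the factory function added below). A minimal illustration:

    // PoolingLayer<float> base(param);     // no longer compiles: abstract
    CaffePoolingLayer<float> layer(param);  // concrete engine implementation
    PoolingLayer<float>* p = &layer;        // still usable via the base type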
+/* CaffePoolingLayer - the default (CAFFE engine) implementation of
+ PoolingLayer.
+*/
+template <typename Dtype>
+class CaffePoolingLayer : public PoolingLayer<Dtype> {
+ public:
+ explicit CaffePoolingLayer(const LayerParameter& param)
+ : PoolingLayer<Dtype>(param) {}
+ virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ virtual inline LayerParameter_LayerType type() const {
+ return LayerParameter_LayerType_POOLING;
+ }
+ virtual inline int ExactNumBottomBlobs() const { return 1; }
+ virtual inline int MinTopBlobs() const { return 1; }
+ // MAX POOL layers can output an extra top blob for the mask;
+ // others can only output the pooled inputs.
+ virtual inline int MaxTopBlobs() const {
+ return (this->layer_param_.pooling_param().pool() ==
+ PoolingParameter_PoolMethod_MAX) ? 2 : 1;
+ }
+
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
+ virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
+
+ Blob<Dtype> rand_idx_;
+ Blob<int> max_idx_;
+};
+
} // namespace caffe
#endif // CAFFE_VISION_LAYERS_HPP_
template ConvolutionLayer<double>* GetConvolutionLayer(const string& name,
const LayerParameter& param);
+// Get pooling layer according to engine.
+template <typename Dtype>
+PoolingLayer<Dtype>* GetPoolingLayer(const string& name,
+ const LayerParameter& param) {
+ PoolingParameter_Engine engine = param.pooling_param().engine();
+ if (engine == PoolingParameter_Engine_CAFFE) {
+ return new CaffePoolingLayer<Dtype>(param);
+ } else {
+ LOG(FATAL) << "Layer " << name << " has unknown engine.";
+ }
+}
+
+template PoolingLayer<float>* GetPoolingLayer(const string& name,
+ const LayerParameter& param);
+template PoolingLayer<double>* GetPoolingLayer(const string& name,
+ const LayerParameter& param);
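
The if/else chain is deliberately shaped so further engines can be added ahead of the fallback. A hypothetical sketch only; no CUDNN engine or CuDNNPoolingLayer exists in this patch:

    if (engine == PoolingParameter_Engine_CAFFE) {
      return new CaffePoolingLayer<Dtype>(param);
    } else if (engine == PoolingParameter_Engine_CUDNN) {  // hypothetical
      return new CuDNNPoolingLayer<Dtype>(param);          // hypothetical
    } else {
      LOG(FATAL) << "Layer " << name << " has unknown engine.";
    }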
// A function to get a specific layer from the specification given in
// LayerParameter. Ideally this would be replaced by a factory pattern,
case LayerParameter_LayerType_MULTINOMIAL_LOGISTIC_LOSS:
return new MultinomialLogisticLossLayer<Dtype>(param);
case LayerParameter_LayerType_POOLING:
- return new PoolingLayer<Dtype>(param);
+ return GetPoolingLayer<Dtype>(name, param);
case LayerParameter_LayerType_POWER:
return new PowerLayer<Dtype>(param);
case LayerParameter_LayerType_RELU:
--- /dev/null
+#include <algorithm>
+#include <cfloat>
+#include <vector>
+
+#include "caffe/common.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/syncedmem.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+using std::min;
+using std::max;
+
+template <typename Dtype>
+void CaffePoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) {
+ PoolingLayer<Dtype>::LayerSetUp(bottom, top);
+ PoolingParameter pool_param = this->layer_param_.pooling_param();
+ // If max pooling, we will initialize the vector index part.
+ if (pool_param.pool() == PoolingParameter_PoolMethod_MAX &&
+ top->size() == 1) {
+ max_idx_.Reshape(bottom[0]->num(), this->channels_, this->pooled_height_,
+ this->pooled_width_);
+ }
+ // If stochastic pooling, we will initialize the random index part.
+ if (pool_param.pool() == PoolingParameter_PoolMethod_STOCHASTIC) {
+ rand_idx_.Reshape(bottom[0]->num(), this->channels_, this->pooled_height_,
+ this->pooled_width_);
+ }
+}
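
The max_idx_ buffer is only needed when the mask is not exported; supplying a second top blob switches the layer to writing argmax indices into top[1] instead. A test-style sketch with illustrative shapes:

    LayerParameter param;
    param.mutable_pooling_param()->set_pool(PoolingParameter_PoolMethod_MAX);
    param.mutable_pooling_param()->set_kernel_size(2);
    param.mutable_pooling_param()->set_stride(2);
    Blob<float> bottom(1, 1, 4, 4);
    Blob<float> pooled, mask;
    vector<Blob<float>*> bottom_vec(1, &bottom);
    vector<Blob<float>*> top_vec;
    top_vec.push_back(&pooled);
    top_vec.push_back(&mask);  // top->size() > 1: mask goes to top[1]
    CaffePoolingLayer<float> layer(param);
    layer.SetUp(bottom_vec, &top_vec);
    layer.Forward(bottom_vec, &top_vec);  // pooled and mask are 1x1x2x2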
+
+// TODO(Yangqing): Is there a faster way to do pooling in the channel-first
+// case?
+template <typename Dtype>
+void CaffePoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) {
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ Dtype* top_data = (*top)[0]->mutable_cpu_data();
+ const int top_count = (*top)[0]->count();
+ // We'll output the mask to top[1] if it's of size >1.
+ const bool use_top_mask = top->size() > 1;
+ int* mask = NULL; // suppress warnings about uninitialized variables
+ Dtype* top_mask = NULL;
+ // Different pooling methods. We explicitly do the switch outside the for
+ // loop to save time, although this results in more code.
+ switch (this->layer_param_.pooling_param().pool()) {
+ case PoolingParameter_PoolMethod_MAX:
+ // Initialize
+ if (use_top_mask) {
+ top_mask = (*top)[1]->mutable_cpu_data();
+ caffe_set(top_count, Dtype(-1), top_mask);
+ } else {
+ mask = max_idx_.mutable_cpu_data();
+ caffe_set(top_count, -1, mask);
+ }
+ caffe_set(top_count, Dtype(-FLT_MAX), top_data);
+ // The main loop
+ for (int n = 0; n < bottom[0]->num(); ++n) {
+ for (int c = 0; c < this->channels_; ++c) {
+ for (int ph = 0; ph < this->pooled_height_; ++ph) {
+ for (int pw = 0; pw < this->pooled_width_; ++pw) {
+ int hstart = ph * this->stride_h_ - this->pad_h_;
+ int wstart = pw * this->stride_w_ - this->pad_w_;
+ int hend = min(hstart + this->kernel_h_, this->height_);
+ int wend = min(wstart + this->kernel_w_, this->width_);
+ hstart = max(hstart, 0);
+ wstart = max(wstart, 0);
+ const int pool_index = ph * this->pooled_width_ + pw;
+ for (int h = hstart; h < hend; ++h) {
+ for (int w = wstart; w < wend; ++w) {
+ const int index = h * this->width_ + w;
+ if (bottom_data[index] > top_data[pool_index]) {
+ top_data[pool_index] = bottom_data[index];
+ if (use_top_mask) {
+ top_mask[pool_index] = static_cast<Dtype>(index);
+ } else {
+ mask[pool_index] = index;
+ }
+ }
+ }
+ }
+ }
+ }
+ // compute offset
+ bottom_data += bottom[0]->offset(0, 1);
+ top_data += (*top)[0]->offset(0, 1);
+ if (use_top_mask) {
+ top_mask += (*top)[0]->offset(0, 1);
+ } else {
+ mask += (*top)[0]->offset(0, 1);
+ }
+ }
+ }
+ break;
+ case PoolingParameter_PoolMethod_AVE:
+ for (int i = 0; i < top_count; ++i) {
+ top_data[i] = 0;
+ }
+ // The main loop
+ for (int n = 0; n < bottom[0]->num(); ++n) {
+ for (int c = 0; c < this->channels_; ++c) {
+ for (int ph = 0; ph < this->pooled_height_; ++ph) {
+ for (int pw = 0; pw < this->pooled_width_; ++pw) {
+ int hstart = ph * this->stride_h_ - this->pad_h_;
+ int wstart = pw * this->stride_w_ - this->pad_w_;
+ int hend = min(hstart + this->kernel_h_,
+ this->height_ + this->pad_h_);
+ int wend = min(wstart + this->kernel_w_,
+ this->width_ + this->pad_w_);
+ int pool_size = (hend - hstart) * (wend - wstart);
+ hstart = max(hstart, 0);
+ wstart = max(wstart, 0);
+ hend = min(hend, this->height_);
+ wend = min(wend, this->width_);
+ for (int h = hstart; h < hend; ++h) {
+ for (int w = wstart; w < wend; ++w) {
+ top_data[ph * this->pooled_width_ + pw] +=
+ bottom_data[h * this->width_ + w];
+ }
+ }
+ top_data[ph * this->pooled_width_ + pw] /= pool_size;
+ }
+ }
+ // compute offset
+ bottom_data += bottom[0]->offset(0, 1);
+ top_data += (*top)[0]->offset(0, 1);
+ }
+ }
+ break;
+ case PoolingParameter_PoolMethod_STOCHASTIC:
+ NOT_IMPLEMENTED;
+ break;
+ default:
+ LOG(FATAL) << "Unknown pooling method.";
+ }
+}
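
One subtlety in the AVE branch deserves a worked example: pool_size is computed from the padded window before hstart/hend are clipped to the image, so padded positions count toward the averaging denominator. With illustrative numbers:

    // height_ = 4, kernel_h_ = 3, stride_h_ = 2, pad_h_ = 1, ph = 0:
    //   hstart = 0 * 2 - 1          = -1
    //   hend   = min(-1 + 3, 4 + 1) =  2
    //   window height for pool_size = hend - hstart = 3  (pad row included)
    //   after clipping: hstart = 0, hend = 2  -> only rows 0 and 1 summed
    // The two real rows are summed but divided by a three-row-high window.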
+
+template <typename Dtype>
+void CaffePoolingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
+ if (!propagate_down[0]) {
+ return;
+ }
+ const Dtype* top_diff = top[0]->cpu_diff();
+ Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+ // Different pooling methods. We explicitly do the switch outside the for
+ // loop to save time, although this results in more code.
+ caffe_set((*bottom)[0]->count(), Dtype(0), bottom_diff);
+ // We'll output the mask to top[1] if it's of size >1.
+ const bool use_top_mask = top.size() > 1;
+ const int* mask = NULL; // suppress warnings about uninitialized variables
+ const Dtype* top_mask = NULL;
+ switch (this->layer_param_.pooling_param().pool()) {
+ case PoolingParameter_PoolMethod_MAX:
+ // The main loop
+ if (use_top_mask) {
+ top_mask = top[1]->cpu_data();
+ } else {
+ mask = max_idx_.cpu_data();
+ }
+ for (int n = 0; n < top[0]->num(); ++n) {
+ for (int c = 0; c < this->channels_; ++c) {
+ for (int ph = 0; ph < this->pooled_height_; ++ph) {
+ for (int pw = 0; pw < this->pooled_width_; ++pw) {
+ const int index = ph * this->pooled_width_ + pw;
+ const int bottom_index =
+ use_top_mask ? top_mask[index] : mask[index];
+ bottom_diff[bottom_index] += top_diff[index];
+ }
+ }
+ bottom_diff += (*bottom)[0]->offset(0, 1);
+ top_diff += top[0]->offset(0, 1);
+ if (use_top_mask) {
+ top_mask += top[0]->offset(0, 1);
+ } else {
+ mask += top[0]->offset(0, 1);
+ }
+ }
+ }
+ break;
+ case PoolingParameter_PoolMethod_AVE:
+ // The main loop
+ for (int n = 0; n < top[0]->num(); ++n) {
+ for (int c = 0; c < this->channels_; ++c) {
+ for (int ph = 0; ph < this->pooled_height_; ++ph) {
+ for (int pw = 0; pw < this->pooled_width_; ++pw) {
+ int hstart = ph * this->stride_h_ - this->pad_h_;
+ int wstart = pw * this->stride_w_ - this->pad_w_;
+ int hend = min(hstart + this->kernel_h_,
+ this->height_ + this->pad_h_);
+ int wend = min(wstart + this->kernel_w_,
+ this->width_ + this->pad_w_);
+ int pool_size = (hend - hstart) * (wend - wstart);
+ hstart = max(hstart, 0);
+ wstart = max(wstart, 0);
+ hend = min(hend, this->height_);
+ wend = min(wend, this->width_);
+ for (int h = hstart; h < hend; ++h) {
+ for (int w = wstart; w < wend; ++w) {
+ bottom_diff[h * this->width_ + w] +=
+ top_diff[ph * this->pooled_width_ + pw] / pool_size;
+ }
+ }
+ }
+ }
+ // offset
+ bottom_diff += (*bottom)[0]->offset(0, 1);
+ top_diff += top[0]->offset(0, 1);
+ }
+ }
+ break;
+ case PoolingParameter_PoolMethod_STOCHASTIC:
+ NOT_IMPLEMENTED;
+ break;
+ default:
+ LOG(FATAL) << "Unknown pooling method.";
+ }
+}
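
The backward paths are covered indirectly by the gradient checker in the tests updated below; a minimal sketch, assuming param, bottom_vec, and top_vec set up as in the earlier sketch (thresholds copied from those tests):

    CaffePoolingLayer<float> layer(param);
    GradientChecker<float> checker(1e-4, 1e-2);
    checker.CheckGradientExhaustive(&layer, &bottom_vec, &top_vec);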
+
+#ifdef CPU_ONLY
+STUB_GPU(CaffePoolingLayer);
+#endif
+
+INSTANTIATE_CLASS(CaffePoolingLayer);
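
For orientation, the two macros roughly expand as sketched here (paraphrased, not verbatim, from caffe/common.hpp and device_alternate.hpp):

    // INSTANTIATE_CLASS(CaffePoolingLayer):
    template class CaffePoolingLayer<float>;
    template class CaffePoolingLayer<double>;
    // STUB_GPU(CaffePoolingLayer) defines Forward_gpu/Backward_gpu bodies
    // that abort at runtime so CPU_ONLY builds still link, roughly:
    //   template <typename Dtype>
    //   void CaffePoolingLayer<Dtype>::Forward_gpu(...) { NO_GPU; }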
+
+} // namespace caffe
+
template <typename Dtype>
-void PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void CaffePoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
}
// NOLINT_NEXT_LINE(whitespace/operators)
MaxPoolForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
- count, bottom_data, bottom[0]->num(), channels_,
- height_, width_, pooled_height_, pooled_width_, kernel_h_,
- kernel_w_, stride_h_, stride_w_, pad_h_, pad_w_, top_data,
- mask, top_mask);
+ count, bottom_data, bottom[0]->num(), this->channels_, this->height_,
+ this->width_, this->pooled_height_, this->pooled_width_,
+ this->kernel_h_, this->kernel_w_, this->stride_h_, this->stride_w_,
+ this->pad_h_, this->pad_w_, top_data, mask, top_mask);
break;
case PoolingParameter_PoolMethod_AVE:
// NOLINT_NEXT_LINE(whitespace/operators)
AvePoolForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
- count, bottom_data, bottom[0]->num(), channels_,
- height_, width_, pooled_height_, pooled_width_, kernel_h_,
- kernel_w_, stride_h_, stride_w_, pad_h_, pad_w_, top_data);
+ count, bottom_data, bottom[0]->num(), this->channels_, this->height_,
+ this->width_, this->pooled_height_, this->pooled_width_,
+ this->kernel_h_, this->kernel_w_, this->stride_h_, this->stride_w_,
+ this->pad_h_, this->pad_w_, top_data);
break;
case PoolingParameter_PoolMethod_STOCHASTIC:
if (Caffe::phase() == Caffe::TRAIN) {
// NOLINT_NEXT_LINE(whitespace/operators)
StoPoolForwardTrain<Dtype><<<CAFFE_GET_BLOCKS(count),
CAFFE_CUDA_NUM_THREADS>>>(
- count, bottom_data, bottom[0]->num(), channels_,
- height_, width_, pooled_height_, pooled_width_, kernel_h_,
- kernel_w_, stride_h_, stride_w_,
+ count, bottom_data, bottom[0]->num(), this->channels_, this->height_,
+ this->width_, this->pooled_height_, this->pooled_width_,
+ this->kernel_h_, this->kernel_w_, this->stride_h_, this->stride_w_,
rand_idx_.mutable_gpu_data(), top_data);
} else {
// NOLINT_NEXT_LINE(whitespace/operators)
StoPoolForwardTest<Dtype><<<CAFFE_GET_BLOCKS(count),
CAFFE_CUDA_NUM_THREADS>>>(
- count, bottom_data, bottom[0]->num(), channels_,
- height_, width_, pooled_height_, pooled_width_, kernel_h_,
- kernel_w_, stride_h_, stride_w_, top_data);
+ count, bottom_data, bottom[0]->num(), this->channels_, this->height_,
+ this->width_, this->pooled_height_, this->pooled_width_,
+ this->kernel_h_, this->kernel_w_, this->stride_h_, this->stride_w_,
+ top_data);
}
break;
default:
template <typename Dtype>
-void PoolingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+void CaffePoolingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
if (!propagate_down[0]) {
return;
}
// NOLINT_NEXT_LINE(whitespace/operators)
MaxPoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
- count, top_diff, mask, top_mask, top[0]->num(), channels_,
- height_, width_, pooled_height_, pooled_width_,
- kernel_h_, kernel_w_, stride_h_, stride_w_, pad_h_, pad_w_,
- bottom_diff);
+ count, top_diff, mask, top_mask, top[0]->num(), this->channels_,
+ this->height_, this->width_, this->pooled_height_, this->pooled_width_,
+ this->kernel_h_, this->kernel_w_, this->stride_h_, this->stride_w_,
+ this->pad_h_, this->pad_w_, bottom_diff);
break;
case PoolingParameter_PoolMethod_AVE:
// NOLINT_NEXT_LINE(whitespace/operators)
AvePoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
- count, top_diff, top[0]->num(), channels_,
- height_, width_, pooled_height_, pooled_width_, kernel_h_,
- kernel_w_, stride_h_, stride_w_, pad_h_, pad_w_, bottom_diff);
+ count, top_diff, top[0]->num(), this->channels_, this->height_,
+ this->width_, this->pooled_height_, this->pooled_width_,
+ this->kernel_h_, this->kernel_w_, this->stride_h_, this->stride_w_,
+ this->pad_h_, this->pad_w_, bottom_diff);
break;
case PoolingParameter_PoolMethod_STOCHASTIC:
// NOLINT_NEXT_LINE(whitespace/operators)
StoPoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
- count, rand_idx_.gpu_data(), top_diff,
- top[0]->num(), channels_, height_, width_, pooled_height_,
- pooled_width_, kernel_h_, kernel_w_, stride_h_, stride_w_,
+ count, rand_idx_.gpu_data(), top_diff, top[0]->num(), this->channels_,
+ this->height_, this->width_, this->pooled_height_, this->pooled_width_,
+ this->kernel_h_, this->kernel_w_, this->stride_h_, this->stride_w_,
bottom_diff);
break;
default:
}
-INSTANTIATE_CLASS(PoolingLayer);
+INSTANTIATE_CLASS(CaffePoolingLayer);
} // namespace caffe
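
The pervasive channels_ -> this->channels_ edits in the .cu hunks above are required by C++ name lookup, not style: members inherited from a dependent base class (here PoolingLayer<Dtype>) are invisible to unqualified lookup inside a class template. A minimal illustration:

    template <typename T> struct Base { int n_; };
    template <typename T> struct Derived : Base<T> {
      int get() { return this->n_; }  // a bare n_ would fail to compile
    };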
PoolingParameter_PoolMethod_AVE);
pool_param.mutable_pooling_param()->set_pad(pre_pad_);
pool_param.mutable_pooling_param()->set_kernel_size(size_);
- pool_layer_.reset(new PoolingLayer<Dtype>(pool_param));
+ pool_layer_.reset(new CaffePoolingLayer<Dtype>(pool_param));
pool_layer_->SetUp(square_top_vec_, &pool_top_vec_);
CHECK_EQ(pool_output_.num(), num_);
CHECK_EQ(pool_output_.channels(), channels_);
stride_w_ = pool_param.stride_w();
}
if (pad_h_ != 0 || pad_w_ != 0) {
- CHECK(this->layer_param_.pooling_param().pool()
+ CHECK(pool_param.pool()
== PoolingParameter_PoolMethod_AVE
- || this->layer_param_.pooling_param().pool()
+ || pool_param.pool()
== PoolingParameter_PoolMethod_MAX)
<< "Padding implemented only for average and max pooling.";
CHECK_LT(pad_h_, kernel_h_);
if (top->size() > 1) {
(*top)[1]->ReshapeLike(*(*top)[0]);
}
- // If max pooling, we will initialize the vector index part.
- if (this->layer_param_.pooling_param().pool() ==
- PoolingParameter_PoolMethod_MAX && top->size() == 1) {
- max_idx_.Reshape(bottom[0]->num(), channels_, pooled_height_,
- pooled_width_);
- }
- // If stochastic pooling, we will initialize the random index part.
- if (this->layer_param_.pooling_param().pool() ==
- PoolingParameter_PoolMethod_STOCHASTIC) {
- rand_idx_.Reshape(bottom[0]->num(), channels_, pooled_height_,
- pooled_width_);
- }
-}
-
-// TODO(Yangqing): Is there a faster way to do pooling in the channel-first
-// case?
-template <typename Dtype>
-void PoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
- vector<Blob<Dtype>*>* top) {
- const Dtype* bottom_data = bottom[0]->cpu_data();
- Dtype* top_data = (*top)[0]->mutable_cpu_data();
- const int top_count = (*top)[0]->count();
- // We'll output the mask to top[1] if it's of size >1.
- const bool use_top_mask = top->size() > 1;
- int* mask = NULL; // suppress warnings about uninitalized variables
- Dtype* top_mask = NULL;
- // Different pooling methods. We explicitly do the switch outside the for
- // loop to save time, although this results in more code.
- switch (this->layer_param_.pooling_param().pool()) {
- case PoolingParameter_PoolMethod_MAX:
- // Initialize
- if (use_top_mask) {
- top_mask = (*top)[1]->mutable_cpu_data();
- caffe_set(top_count, Dtype(-1), top_mask);
- } else {
- mask = max_idx_.mutable_cpu_data();
- caffe_set(top_count, -1, mask);
- }
- caffe_set(top_count, Dtype(-FLT_MAX), top_data);
- // The main loop
- for (int n = 0; n < bottom[0]->num(); ++n) {
- for (int c = 0; c < channels_; ++c) {
- for (int ph = 0; ph < pooled_height_; ++ph) {
- for (int pw = 0; pw < pooled_width_; ++pw) {
- int hstart = ph * stride_h_ - pad_h_;
- int wstart = pw * stride_w_ - pad_w_;
- int hend = min(hstart + kernel_h_, height_);
- int wend = min(wstart + kernel_w_, width_);
- hstart = max(hstart, 0);
- wstart = max(wstart, 0);
- const int pool_index = ph * pooled_width_ + pw;
- for (int h = hstart; h < hend; ++h) {
- for (int w = wstart; w < wend; ++w) {
- const int index = h * width_ + w;
- if (bottom_data[index] > top_data[pool_index]) {
- top_data[pool_index] = bottom_data[index];
- if (use_top_mask) {
- top_mask[pool_index] = static_cast<Dtype>(index);
- } else {
- mask[pool_index] = index;
- }
- }
- }
- }
- }
- }
- // compute offset
- bottom_data += bottom[0]->offset(0, 1);
- top_data += (*top)[0]->offset(0, 1);
- if (use_top_mask) {
- top_mask += (*top)[0]->offset(0, 1);
- } else {
- mask += (*top)[0]->offset(0, 1);
- }
- }
- }
- break;
- case PoolingParameter_PoolMethod_AVE:
- for (int i = 0; i < top_count; ++i) {
- top_data[i] = 0;
- }
- // The main loop
- for (int n = 0; n < bottom[0]->num(); ++n) {
- for (int c = 0; c < channels_; ++c) {
- for (int ph = 0; ph < pooled_height_; ++ph) {
- for (int pw = 0; pw < pooled_width_; ++pw) {
- int hstart = ph * stride_h_ - pad_h_;
- int wstart = pw * stride_w_ - pad_w_;
- int hend = min(hstart + kernel_h_, height_ + pad_h_);
- int wend = min(wstart + kernel_w_, width_ + pad_w_);
- int pool_size = (hend - hstart) * (wend - wstart);
- hstart = max(hstart, 0);
- wstart = max(wstart, 0);
- hend = min(hend, height_);
- wend = min(wend, width_);
- for (int h = hstart; h < hend; ++h) {
- for (int w = wstart; w < wend; ++w) {
- top_data[ph * pooled_width_ + pw] +=
- bottom_data[h * width_ + w];
- }
- }
- top_data[ph * pooled_width_ + pw] /= pool_size;
- }
- }
- // compute offset
- bottom_data += bottom[0]->offset(0, 1);
- top_data += (*top)[0]->offset(0, 1);
- }
- }
- break;
- case PoolingParameter_PoolMethod_STOCHASTIC:
- NOT_IMPLEMENTED;
- break;
- default:
- LOG(FATAL) << "Unknown pooling method.";
- }
}
-template <typename Dtype>
-void PoolingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
- const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
- if (!propagate_down[0]) {
- return;
- }
- const Dtype* top_diff = top[0]->cpu_diff();
- Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
- // Different pooling methods. We explicitly do the switch outside the for
- // loop to save time, although this results in more codes.
- caffe_set((*bottom)[0]->count(), Dtype(0), bottom_diff);
- // We'll output the mask to top[1] if it's of size >1.
- const bool use_top_mask = top.size() > 1;
- const int* mask = NULL; // suppress warnings about uninitialized variables
- const Dtype* top_mask = NULL;
- switch (this->layer_param_.pooling_param().pool()) {
- case PoolingParameter_PoolMethod_MAX:
- // The main loop
- if (use_top_mask) {
- top_mask = top[1]->cpu_data();
- } else {
- mask = max_idx_.cpu_data();
- }
- for (int n = 0; n < top[0]->num(); ++n) {
- for (int c = 0; c < channels_; ++c) {
- for (int ph = 0; ph < pooled_height_; ++ph) {
- for (int pw = 0; pw < pooled_width_; ++pw) {
- const int index = ph * pooled_width_ + pw;
- const int bottom_index =
- use_top_mask ? top_mask[index] : mask[index];
- bottom_diff[bottom_index] += top_diff[index];
- }
- }
- bottom_diff += (*bottom)[0]->offset(0, 1);
- top_diff += top[0]->offset(0, 1);
- if (use_top_mask) {
- top_mask += top[0]->offset(0, 1);
- } else {
- mask += top[0]->offset(0, 1);
- }
- }
- }
- break;
- case PoolingParameter_PoolMethod_AVE:
- // The main loop
- for (int n = 0; n < top[0]->num(); ++n) {
- for (int c = 0; c < channels_; ++c) {
- for (int ph = 0; ph < pooled_height_; ++ph) {
- for (int pw = 0; pw < pooled_width_; ++pw) {
- int hstart = ph * stride_h_ - pad_h_;
- int wstart = pw * stride_w_ - pad_w_;
- int hend = min(hstart + kernel_h_, height_ + pad_h_);
- int wend = min(wstart + kernel_w_, width_ + pad_w_);
- int pool_size = (hend - hstart) * (wend - wstart);
- hstart = max(hstart, 0);
- wstart = max(wstart, 0);
- hend = min(hend, height_);
- wend = min(wend, width_);
- for (int h = hstart; h < hend; ++h) {
- for (int w = wstart; w < wend; ++w) {
- bottom_diff[h * width_ + w] +=
- top_diff[ph * pooled_width_ + pw] / pool_size;
- }
- }
- }
- }
- // offset
- bottom_diff += (*bottom)[0]->offset(0, 1);
- top_diff += top[0]->offset(0, 1);
- }
- }
- break;
- case PoolingParameter_PoolMethod_STOCHASTIC:
- NOT_IMPLEMENTED;
- break;
- default:
- LOG(FATAL) << "Unknown pooling method.";
- }
-}
-
-
-#ifdef CPU_ONLY
-STUB_GPU(PoolingLayer);
-#endif
-
INSTANTIATE_CLASS(PoolingLayer);
-
} // namespace caffe
PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
pooling_param->set_kernel_size(3);
pooling_param->set_stride(2);
- PoolingLayer<Dtype> max_layer(layer_param);
+ CaffePoolingLayer<Dtype> max_layer(layer_param);
max_layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
DropoutLayer<Dtype> dropout_layer(layer_param);
dropout_layer.SetUp(this->blob_top_vec_, &(this->blob_top_vec_));
PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
pooling_param->set_kernel_size(3);
pooling_param->set_stride(2);
- PoolingLayer<Dtype> layer(layer_param);
+ CaffePoolingLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
const Dtype* top_data = this->blob_top_->cpu_data();
PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
pooling_param->set_kernel_size(3);
pooling_param->set_stride(2);
- PoolingLayer<Dtype> layer(layer_param);
+ CaffePoolingLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
for (int i = 0; i < this->blob_top_->count(); ++i) {
blob_bottom_->mutable_cpu_data()[i + 13] = 2;
blob_bottom_->mutable_cpu_data()[i + 14] = 3;
}
- PoolingLayer<Dtype> layer(layer_param);
+ CaffePoolingLayer<Dtype> layer(layer_param);
layer.SetUp(blob_bottom_vec_, &blob_top_vec_);
EXPECT_EQ(blob_top_->num(), num);
EXPECT_EQ(blob_top_->channels(), channels);
blob_bottom_->mutable_cpu_data()[i + 34] = 18;
blob_bottom_->mutable_cpu_data()[i + 35] = 11;
}
- PoolingLayer<Dtype> layer(layer_param);
+ CaffePoolingLayer<Dtype> layer(layer_param);
layer.SetUp(blob_bottom_vec_, &blob_top_vec_);
EXPECT_EQ(blob_top_->num(), num);
EXPECT_EQ(blob_top_->channels(), channels);
blob_bottom_->mutable_cpu_data()[i + 34] = 18;
blob_bottom_->mutable_cpu_data()[i + 35] = 11;
}
- PoolingLayer<Dtype> layer(layer_param);
+ CaffePoolingLayer<Dtype> layer(layer_param);
layer.SetUp(blob_bottom_vec_, &blob_top_vec_);
EXPECT_EQ(blob_top_->num(), num);
EXPECT_EQ(blob_top_->channels(), channels);
PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
pooling_param->set_kernel_size(3);
pooling_param->set_stride(2);
- PoolingLayer<Dtype> layer(layer_param);
+ CaffePoolingLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_->num());
EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_->channels());
pooling_param->set_stride(2);
pooling_param->set_pad(1);
pooling_param->set_pool(PoolingParameter_PoolMethod_AVE);
- PoolingLayer<Dtype> layer(layer_param);
+ CaffePoolingLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_->num());
EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_->channels());
pooling_param->set_stride(2);
pooling_param->set_pad(1);
pooling_param->set_pool(PoolingParameter_PoolMethod_MAX);
- PoolingLayer<Dtype> layer(layer_param);
+ CaffePoolingLayer<Dtype> layer(layer_param);
GradientChecker<Dtype> checker(1e-4, 1e-2);
checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
&(this->blob_top_vec_));
this->blob_bottom_->mutable_cpu_data()[6] = 4;
this->blob_bottom_->mutable_cpu_data()[7] = 2;
this->blob_bottom_->mutable_cpu_data()[8] = 1;
- PoolingLayer<Dtype> layer(layer_param);
+ CaffePoolingLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
EXPECT_EQ(this->blob_top_->num(), 1);
EXPECT_EQ(this->blob_top_->channels(), 1);
pooling_param->set_stride(2);
pooling_param->set_pool(PoolingParameter_PoolMethod_MAX);
this->blob_top_vec_.push_back(this->blob_top_mask_);
- PoolingLayer<Dtype> layer(layer_param);
+ CaffePoolingLayer<Dtype> layer(layer_param);
GradientChecker<Dtype> checker(1e-4, 1e-2);
checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
&(this->blob_top_vec_));
filler_param.set_value(Dtype(2));
ConstantFiller<Dtype> filler(filler_param);
filler.Fill(this->blob_bottom_);
- PoolingLayer<Dtype> layer(layer_param);
+ CaffePoolingLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
EXPECT_EQ(this->blob_top_->num(), 1);
EXPECT_EQ(this->blob_top_->channels(), 1);
pooling_param->set_kernel_w(kernel_w);
pooling_param->set_stride(2);
pooling_param->set_pool(PoolingParameter_PoolMethod_AVE);
- PoolingLayer<Dtype> layer(layer_param);
+ CaffePoolingLayer<Dtype> layer(layer_param);
GradientChecker<Dtype> checker(1e-2, 1e-2);
checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
&(this->blob_top_vec_));
pooling_param->set_stride(2);
pooling_param->set_pad(2);
pooling_param->set_pool(PoolingParameter_PoolMethod_AVE);
- PoolingLayer<Dtype> layer(layer_param);
+ CaffePoolingLayer<Dtype> layer(layer_param);
GradientChecker<Dtype> checker(1e-2, 1e-2);
checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
&(this->blob_top_vec_));
PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
pooling_param->set_kernel_size(3);
pooling_param->set_stride(2);
- PoolingLayer<TypeParam> layer(layer_param);
+ CaffePoolingLayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_->num());
EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_->channels());
pooling_param->set_kernel_size(3);
pooling_param->set_stride(2);
pooling_param->set_pool(PoolingParameter_PoolMethod_STOCHASTIC);
- PoolingLayer<TypeParam> layer(layer_param);
+ CaffePoolingLayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
pooling_param->set_kernel_size(3);
pooling_param->set_stride(2);
pooling_param->set_pool(PoolingParameter_PoolMethod_STOCHASTIC);
- PoolingLayer<TypeParam> layer(layer_param);
+ CaffePoolingLayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
pooling_param->set_kernel_size(3);
pooling_param->set_stride(2);
pooling_param->set_pool(PoolingParameter_PoolMethod_STOCHASTIC);
- PoolingLayer<TypeParam> layer(layer_param);
+ CaffePoolingLayer<TypeParam> layer(layer_param);
GradientChecker<TypeParam> checker(1e-4, 1e-2);
// it is too expensive to call curand multiple times, so we don't do an
// exhaustive gradient check.