PoolingLayer customizable output shape rounding mode
[platform/upstream/caffe.git] / src / caffe / layers / pooling_layer.cpp
index 90897db..f2a0885 100644 (file)
@@ -35,6 +35,7 @@ void PoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       || (!pool_param.has_stride_h() && !pool_param.has_stride_w()))
       << "Stride is stride OR stride_h and stride_w are required.";
   global_pooling_ = pool_param.global_pooling();
+  round_mode_ = pool_param.round_mode();
   if (global_pooling_) {
     kernel_h_ = bottom[0]->height();
     kernel_w_ = bottom[0]->width();
@@ -87,10 +88,22 @@ void PoolingLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
     kernel_h_ = bottom[0]->height();
     kernel_w_ = bottom[0]->width();
   }
-  pooled_height_ = static_cast<int>(ceil(static_cast<float>(
-      height_ + 2 * pad_h_ - kernel_h_) / stride_h_)) + 1;
-  pooled_width_ = static_cast<int>(ceil(static_cast<float>(
-      width_ + 2 * pad_w_ - kernel_w_) / stride_w_)) + 1;
+  switch (round_mode_) {
+  case PoolingParameter_RoundMode_CEIL:
+    pooled_height_ = static_cast<int>(ceil(static_cast<float>(
+        height_ + 2 * pad_h_ - kernel_h_) / stride_h_)) + 1;
+    pooled_width_ = static_cast<int>(ceil(static_cast<float>(
+        width_ + 2 * pad_w_ - kernel_w_) / stride_w_)) + 1;
+    break;
+  case PoolingParameter_RoundMode_FLOOR:
+    pooled_height_ = static_cast<int>(floor(static_cast<float>(
+        height_ + 2 * pad_h_ - kernel_h_) / stride_h_)) + 1;
+    pooled_width_ = static_cast<int>(floor(static_cast<float>(
+        width_ + 2 * pad_w_ - kernel_w_) / stride_w_)) + 1;
+    break;
+  default:
+    LOG(FATAL) << "Unknown rounding mode.";
+  }
   if (pad_h_ || pad_w_) {
     // If we have padding, ensure that the last pooling starts strictly
     // inside the image (instead of at the padding); otherwise clip the last.