pooling layer cpu
authorYangqing Jia <jiayq84@gmail.com>
Mon, 23 Sep 2013 18:03:39 +0000 (11:03 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Mon, 23 Sep 2013 18:03:39 +0000 (11:03 -0700)
src/caffe/layers/pooling_layer.cpp [new file with mode: 0644]
src/caffe/proto/layer_param.proto
src/caffe/test/test_gradient_check_util.hpp
src/caffe/test/test_pooling_layer.cpp [new file with mode: 0644]
src/caffe/util/math_functions.cpp
src/caffe/util/math_functions.hpp
src/caffe/vision_layers.hpp

diff --git a/src/caffe/layers/pooling_layer.cpp b/src/caffe/layers/pooling_layer.cpp
new file mode 100644 (file)
index 0000000..d7e6b4f
--- /dev/null
@@ -0,0 +1,193 @@
+#include <algorithm>
+#include <cfloat>
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+
+using std::max;
+using std::min;
+
+namespace caffe {
+
+const float CAFFE_MAX_POOLING_THRESHOLD = 1e-8;
+
+template <typename Dtype>
+void PoolingLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom.size(), 1) << "PoolingLayer takes a single blob as input.";
+  CHECK_EQ(top->size(), 1) << "PoolingLayer takes a single blob as output.";
+  KSIZE_ = this->layer_param_.kernelsize();
+  STRIDE_ = this->layer_param_.stride();
+  CHANNELS_ = bottom[0]->channels();
+  HEIGHT_ = bottom[0]->height();
+  WIDTH_ = bottom[0]->width();
+  POOLED_HEIGHT_ = int(ceil(float(HEIGHT_ - KSIZE_) / STRIDE_)) + 1;
+  POOLED_WIDTH_ = int(ceil(float(WIDTH_ - KSIZE_) / STRIDE_)) + 1;
+  (*top)[0]->Reshape(bottom[0]->num(), CHANNELS_, POOLED_HEIGHT_,
+      POOLED_WIDTH_);
+};
+
+
+// TODO: Is there a faster way to do pooling in the channel-first case?
+template <typename Dtype>
+void PoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = (*top)[0]->mutable_cpu_data();
+  // Different pooling methods. We explicitly do the switch outside the for
+  // loop to save time, although this results in more codes.
+  int top_count = (*top)[0]->count();
+  switch (this->layer_param_.pool()) {
+  case LayerParameter_PoolMethod_MAX:
+    // Initialize
+    for (int i = 0; i < top_count; ++i) {
+      top_data[i] = -FLT_MAX;
+    }
+    // The main loop
+    for (int n = 0; n < bottom[0]->num(); ++n) {
+      for (int c = 0; c < CHANNELS_; ++c) {
+        for (int ph = 0; ph < POOLED_HEIGHT_; ++ph) {
+          for (int pw = 0; pw < POOLED_WIDTH_; ++pw) {
+            int hstart = ph * STRIDE_;
+            int wstart = pw * STRIDE_;
+            int hend = min(hstart + KSIZE_, HEIGHT_);
+            int wend = min(wstart + KSIZE_, WIDTH_);
+            for (int h = hstart; h < hend; ++h) {
+              for (int w = wstart; w < wend; ++w) {
+                top_data[ph * POOLED_WIDTH_ + pw] =
+                  max(top_data[ph * POOLED_WIDTH_ + pw],
+                      bottom_data[h * WIDTH_ + w]);
+              }
+            }
+          }
+        }
+        // compute offset
+        bottom_data += bottom[0]->offset(0, 1);
+        top_data += (*top)[0]->offset(0, 1);
+      }
+    }
+    break;
+  case LayerParameter_PoolMethod_AVE:
+    for (int i = 0; i < top_count; ++i) {
+      top_data[i] = 0;
+    }
+    // The main loop
+    for (int n = 0; n < bottom[0]->num(); ++n) {
+      for (int c = 0; c < CHANNELS_; ++c) {
+        for (int ph = 0; ph < POOLED_HEIGHT_; ++ph) {
+          for (int pw = 0; pw < POOLED_WIDTH_; ++pw) {
+            int hstart = ph * STRIDE_;
+            int wstart = pw * STRIDE_;
+            int hend = min(hstart + KSIZE_, HEIGHT_);
+            int wend = min(wstart + KSIZE_, WIDTH_);
+            for (int h = hstart; h < hend; ++h) {
+              for (int w = wstart; w < wend; ++w) {
+                top_data[ph * POOLED_WIDTH_ + pw] +=
+                    bottom_data[h * WIDTH_ + w];
+              }
+            }
+          }
+        }
+        // compute offset
+        bottom_data += bottom[0]->offset(0, 1);
+        top_data += (*top)[0]->offset(0, 1);
+      }
+    }
+    // Our implementation simply divides the pooled values by KSIZE^2,
+    // regardless of the actual pooling region. This would allow one to not 
+    // trust too much on the border pooling regions, but I am not sure what
+    // benefit / harm it would bring to the actual code.
+    caffe_scal<Dtype>(top_count, Dtype(1.) / KSIZE_ / KSIZE_,
+        (*top)[0]->mutable_cpu_data());
+    break;
+  default:
+    LOG(FATAL) << "Unknown pooling method.";
+  }
+}
+
+template <typename Dtype>
+Dtype PoolingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+  if (!propagate_down) {
+    return Dtype(0.);
+  }
+  const Dtype* top_diff = top[0]->cpu_diff();
+  const Dtype* top_data = top[0]->cpu_data();
+  const Dtype* bottom_data = (*bottom)[0]->cpu_data();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+  // Different pooling methods. We explicitly do the switch outside the for
+  // loop to save time, although this results in more codes.
+  memset(bottom_diff, 0, (*bottom)[0]->count() * sizeof(Dtype));
+  switch (this->layer_param_.pool()) {
+  case LayerParameter_PoolMethod_MAX:
+    // The main loop
+    for (int n = 0; n < top[0]->num(); ++n) {
+      for (int c = 0; c < CHANNELS_; ++c) {
+        for (int ph = 0; ph < POOLED_HEIGHT_; ++ph) {
+          for (int pw = 0; pw < POOLED_WIDTH_; ++pw) {
+            int hstart = ph * STRIDE_;
+            int wstart = pw * STRIDE_;
+            int hend = min(hstart + KSIZE_, HEIGHT_);
+            int wend = min(wstart + KSIZE_, WIDTH_);
+            for (int h = hstart; h < hend; ++h) {
+              for (int w = wstart; w < wend; ++w) {
+                bottom_diff[h * WIDTH_ + w] +=
+                    top_diff[ph * POOLED_WIDTH_ + pw] *
+                    (bottom_data[h * WIDTH_ + w] >=
+                        top_data[ph * POOLED_WIDTH_ + pw] -
+                        CAFFE_MAX_POOLING_THRESHOLD);
+              }
+            }
+          }
+        }
+        // offset
+        bottom_data += (*bottom)[0]->offset(0, 1);
+        top_data += top[0]->offset(0, 1);
+        bottom_diff += (*bottom)[0]->offset(0, 1);
+        top_diff += top[0]->offset(0, 1);
+      }
+    }
+    break;
+  case LayerParameter_PoolMethod_AVE:
+    // The main loop
+    for (int n = 0; n < top[0]->num(); ++n) {
+      for (int c = 0; c < CHANNELS_; ++c) {
+        for (int ph = 0; ph < POOLED_HEIGHT_; ++ph) {
+          for (int pw = 0; pw < POOLED_WIDTH_; ++pw) {
+            int hstart = ph * STRIDE_;
+            int wstart = pw * STRIDE_;
+            int hend = min(hstart + KSIZE_, HEIGHT_);
+            int wend = min(wstart + KSIZE_, WIDTH_);
+            for (int h = hstart; h < hend; ++h) {
+              for (int w = wstart; w < wend; ++w) {
+                bottom_diff[h * WIDTH_ + w] +=
+                  top_diff[ph * POOLED_WIDTH_ + pw];
+              }
+            }
+          }
+        }
+        // offset
+        bottom_data += (*bottom)[0]->offset(0, 1);
+        top_data += top[0]->offset(0, 1);
+        bottom_diff += (*bottom)[0]->offset(0, 1);
+        top_diff += top[0]->offset(0, 1);
+      }
+    }
+    // Our implementation simply divides the pooled values by KSIZE^2,
+    // regardless of the actual pooling region. This would allow one to not 
+    // trust too much on the border pooling regions, but I am not sure what
+    // benefit / harm it would bring to the actual code.
+    caffe_scal<Dtype>((*bottom)[0]->count(), Dtype(1.) / KSIZE_ / KSIZE_,
+        (*bottom)[0]->mutable_cpu_diff());
+    break;
+  default:
+    LOG(FATAL) << "Unknown pooling method.";
+  }
+  return Dtype(0.);
+}
+
+
+INSTANTIATE_CLASS(PoolingLayer);
+
+
+}  // namespace caffe
index e3cd0b4..08a094c 100644 (file)
@@ -34,7 +34,11 @@ message LayerParameter {
   optional uint32 kernelsize = 8; // The kernel size
   optional uint32 group = 9 [default = 1]; // The group size for group conv
   optional uint32 stride = 10 [default = 1]; // The stride
-  optional string pool = 11 [default = 'max']; // The pooling method
+  enum PoolMethod {
+    MAX = 0;
+    AVE = 1;
+  }
+  optional PoolMethod pool = 11 [default = MAX]; // The pooling method
   optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio
 
   optional uint32 local_size = 13 [default = 5]; // for local response norm
index 475fd43..8cf7851 100644 (file)
@@ -109,8 +109,10 @@ void GradientChecker<Dtype>::CheckGradientSingle(Layer<Dtype>& layer,
         // the scale factor by 1.
         Dtype scale = max(max(fabs(computed_gradient), fabs(estimated_gradient)),
             1.);
-        EXPECT_GT(computed_gradient, estimated_gradient - threshold_ * scale);
-        EXPECT_LT(computed_gradient, estimated_gradient + threshold_ * scale);
+        EXPECT_GT(computed_gradient, estimated_gradient - threshold_ * scale)
+          << "debug: (blob_id, feat_id)=" << blobid << "," << feat_id;
+        EXPECT_LT(computed_gradient, estimated_gradient + threshold_ * scale)
+          << "debug: (blob_id, feat_id)=" << blobid << "," << feat_id;
       }
       //LOG(ERROR) << "Feature: " << current_blob->cpu_data()[feat_id];
       //LOG(ERROR) << "computed gradient: " << computed_gradient
diff --git a/src/caffe/test/test_pooling_layer.cpp b/src/caffe/test/test_pooling_layer.cpp
new file mode 100644 (file)
index 0000000..d9d76fe
--- /dev/null
@@ -0,0 +1,91 @@
+#include <cstring>
+#include <cuda_runtime.h>
+
+#include "gtest/gtest.h"
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+
+extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
+
+template <typename Dtype>
+class PoolingLayerTest : public ::testing::Test {
+ protected:
+  PoolingLayerTest()
+      : blob_bottom_(new Blob<Dtype>()),
+        blob_top_(new Blob<Dtype>()) {};
+  virtual void SetUp() {
+    blob_bottom_->Reshape(2, 3, 6, 5);
+    // fill the values
+    FillerParameter filler_param;
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  };
+  virtual ~PoolingLayerTest() { delete blob_bottom_; delete blob_top_; }
+  void ReferenceLRNForward(const Blob<Dtype>& blob_bottom,
+      const LayerParameter& layer_param, Blob<Dtype>* blob_top);
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(PoolingLayerTest, Dtypes);
+
+TYPED_TEST(PoolingLayerTest, TestSetup) {
+  LayerParameter layer_param;
+  layer_param.set_kernelsize(3);
+  layer_param.set_stride(2);
+  PoolingLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_->num());
+  EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_->channels());
+  EXPECT_EQ(this->blob_top_->height(), 3);
+  EXPECT_EQ(this->blob_top_->width(), 2);
+}
+
+TYPED_TEST(PoolingLayerTest, TestCPUGradientMax) {
+  LayerParameter layer_param;
+  layer_param.set_kernelsize(3);
+  layer_param.set_stride(2);
+  layer_param.set_pool(LayerParameter_PoolMethod_MAX);
+  Caffe::set_mode(Caffe::CPU);
+  PoolingLayer<TypeParam> layer(layer_param);
+  GradientChecker<TypeParam> checker(1e-4, 1e-2);
+  checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, this->blob_top_vec_);
+}
+
+
+TYPED_TEST(PoolingLayerTest, TestCPUGradientAve) {
+  LayerParameter layer_param;
+  layer_param.set_kernelsize(3);
+  layer_param.set_stride(2);
+  layer_param.set_pool(LayerParameter_PoolMethod_AVE);
+  Caffe::set_mode(Caffe::CPU);
+  PoolingLayer<TypeParam> layer(layer_param);
+  GradientChecker<TypeParam> checker(1e-2, 1e-2);
+  checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, this->blob_top_vec_);
+}
+
+/*
+TYPED_TEST(PoolingLayerTest, TestGPUGradient) {
+  LayerParameter layer_param;
+  layer_param.set_kernelsize(3);
+  layer_param.set_stride(2);
+  Caffe::set_mode(Caffe::GPU);
+  PoolingLayer<TypeParam> layer(layer_param);
+  GradientChecker<TypeParam> checker(1e-2, 1e-2);
+  checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, this->blob_top_vec_);
+}
+*/
+
+}
index f503b83..c70dc29 100644 (file)
@@ -112,6 +112,16 @@ void caffe_copy<double>(const int N, const double* X, double* Y) {
 }
 
 template <>
+void caffe_scal<float>(const int N, const float alpha, float *X) {
+  cblas_sscal(N, alpha, X, 1);
+}
+
+template <>
+void caffe_scal<double>(const int N, const double alpha, double *X) {
+  cblas_dscal(N, alpha, X, 1);
+}
+
+template <>
 void caffe_sqr<float>(const int n, const float* a, float* y){
   vsSqr(n, a, y);
 }
index 60a4310..12887c3 100644 (file)
@@ -41,6 +41,9 @@ template <typename Dtype>
 void caffe_copy(const int N, const Dtype *X, Dtype *Y);
 
 template <typename Dtype>
+void caffe_scal (const int N, const Dtype alpha, Dtype *X);
+
+template <typename Dtype>
 void caffe_sqr(const int N, const Dtype* a, Dtype* y);
 
 template <typename Dtype>
index 35d1991..31c6b1d 100644 (file)
@@ -160,6 +160,31 @@ class Im2colLayer : public Layer<Dtype> {
 };
 
 template <typename Dtype>
+class PoolingLayer : public Layer<Dtype> {
+ public:
+  explicit PoolingLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {};
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  //virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+  //    vector<Blob<Dtype>*>* top);
+  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  //virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+  //    const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  int KSIZE_;
+  int STRIDE_;
+  int CHANNELS_;
+  int HEIGHT_;
+  int WIDTH_;
+  int POOLED_HEIGHT_;
+  int POOLED_WIDTH_;
+};
+
+template <typename Dtype>
 class ConvolutionLayer : public Layer<Dtype> {
  public:
   explicit ConvolutionLayer(const LayerParameter& param)