pooling layer
authorYangqing Jia <jiayq84@gmail.com>
Mon, 23 Sep 2013 21:42:24 +0000 (14:42 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Mon, 23 Sep 2013 21:42:24 +0000 (14:42 -0700)
src/caffe/layers/pooling_layer.cpp
src/caffe/layers/pooling_layer.cu [new file with mode: 0644]
src/caffe/test/test_gradient_check_util.hpp
src/caffe/test/test_pooling_layer.cpp
src/caffe/util/im2col.cu
src/caffe/vision_layers.hpp

index d7e6b4f..88a8f41 100644 (file)
@@ -4,13 +4,13 @@
 #include "caffe/vision_layers.hpp"
 #include "caffe/util/math_functions.hpp"
 
+#define CAFFE_MAX_POOLING_THRESHOLD 1e-8f
+
 using std::max;
 using std::min;
 
 namespace caffe {
 
-const float CAFFE_MAX_POOLING_THRESHOLD = 1e-8;
-
 template <typename Dtype>
 void PoolingLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
diff --git a/src/caffe/layers/pooling_layer.cu b/src/caffe/layers/pooling_layer.cu
new file mode 100644 (file)
index 0000000..2086869
--- /dev/null
@@ -0,0 +1,187 @@
+#include <algorithm>
+#include <cfloat>
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+
+#define CAFFE_MAX_POOLING_THRESHOLD 1e-8f
+
+using std::max;
+using std::min;
+
+namespace caffe {
+
+template <typename Dtype>
+__global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
+    const int num, const int channels, const int height,
+    const int width, const int pooled_height, const int pooled_width,
+    const int ksize, const int stride, Dtype* top_data) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < nthreads) {
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+    int hstart = ph * stride;
+    int hend = min(hstart + ksize, height);
+    int wstart = pw * stride;
+    int wend = min(wstart + ksize, width);
+    Dtype maxval = -FLT_MAX;
+    bottom_data += (n * channels + c) * height * width;
+    for (int h = hstart; h < hend; ++h) {
+      for (int w = wstart; w < wend; ++w) {
+        maxval = max(maxval, bottom_data[h * width + w]);
+      }
+    }
+    top_data[index] = maxval;
+  }  // (if index < nthreads)
+}
+
+template <typename Dtype>
+__global__ void AvePoolForward(const int nthreads, const Dtype* bottom_data,
+    const int num, const int channels, const int height,
+    const int width, const int pooled_height, const int pooled_width,
+    const int ksize, const int stride, Dtype* top_data) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < nthreads) {
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+    int hstart = ph * stride;
+    int hend = min(hstart + ksize, height);
+    int wstart = pw * stride;
+    int wend = min(wstart + ksize, width);
+    Dtype aveval = 0;
+    bottom_data += (n * channels + c) * height * width;
+    for (int h = hstart; h < hend; ++h) {
+      for (int w = wstart; w < wend; ++w) {
+        aveval += bottom_data[h * width + w];
+      }
+    }
+    top_data[index] = aveval / ksize / ksize;
+  }  // (if index < nthreads)
+}
+
+template <typename Dtype>
+void PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = (*top)[0]->mutable_gpu_data();
+  int count = (*top)[0]->count();
+  switch (this->layer_param_.pool()) {
+  case LayerParameter_PoolMethod_MAX:
+    MaxPoolForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+        count, bottom_data, bottom[0]->num(), CHANNELS_,
+        HEIGHT_, WIDTH_, POOLED_HEIGHT_, POOLED_WIDTH_, KSIZE_, STRIDE_,
+        top_data);
+    break;
+  case LayerParameter_PoolMethod_AVE:
+    AvePoolForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+        count, bottom_data, bottom[0]->num(), CHANNELS_,
+        HEIGHT_, WIDTH_, POOLED_HEIGHT_, POOLED_WIDTH_, KSIZE_, STRIDE_,
+        top_data);
+    break;
+  default:
+    LOG(FATAL) << "Unknown pooling method.";
+  }
+  CUDA_POST_KERNEL_CHECK;
+}
+
+template <typename Dtype>
+__global__ void MaxPoolBackward(const int nthreads, const Dtype* bottom_data,
+    const Dtype* top_data, const Dtype* top_diff,
+    const int num, const int channels, const int height,
+    const int width, const int pooled_height, const int pooled_width,
+    const int ksize, const int stride, Dtype* bottom_diff) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < nthreads) {
+    // find out the local index
+    // find out the local offset
+    int w = index % width;
+    int h = (index / width) % height;
+    int c = (index / width / height) % channels;
+    int n = index / width / height / channels;
+    int phstart = (h < ksize) ? 0 : (h - ksize) / stride + 1;
+    int phend = min(h / stride + 1, pooled_height);
+    int pwstart = (w < ksize) ? 0 : (w - ksize) / stride + 1;
+    int pwend = min(w / stride + 1, pooled_width);
+    Dtype gradient = 0;
+    Dtype bottom_datum = 
+        bottom_data[((n * channels + c) * height + h) * width + w];
+    top_data += (n * channels + c) * pooled_height * pooled_width;
+    top_diff += (n * channels + c) * pooled_height * pooled_width;
+    for (int ph = phstart; ph < phend; ++ph) {
+      for (int pw = pwstart; pw < pwend; ++pw) {
+        gradient += top_diff[ph * pooled_width + pw] *
+            (bottom_datum >= top_data[ph * pooled_width + pw] -
+                CAFFE_MAX_POOLING_THRESHOLD);
+      }
+    }
+    bottom_diff[index] = gradient;
+  }  // (if index < nthreads)
+}
+
+
+template <typename Dtype>
+__global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff,
+    const int num, const int channels, const int height,
+    const int width, const int pooled_height, const int pooled_width,
+    const int ksize, const int stride, Dtype* bottom_diff) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < nthreads) {
+    // find out the local index
+    // find out the local offset
+    int w = index % width;
+    int h = (index / width) % height;
+    int c = (index / width / height) % channels;
+    int n = index / width / height / channels;
+    int phstart = (h < ksize) ? 0 : (h - ksize) / stride + 1;
+    int phend = min(h / stride + 1, pooled_height);
+    int pwstart = (w < ksize) ? 0 : (w - ksize) / stride + 1;
+    int pwend = min(w / stride + 1, pooled_width);
+    Dtype gradient = 0;
+    top_diff += (n * channels + c) * pooled_height * pooled_width;
+    for (int ph = phstart; ph < phend; ++ph) {
+      for (int pw = pwstart; pw < pwend; ++pw) {
+        gradient += top_diff[ph * pooled_width + pw];
+      }
+    }
+    bottom_diff[index] = gradient / ksize / ksize;
+  }  // (if index < nthreads)
+}
+
+template <typename Dtype>
+Dtype PoolingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+  if (!propagate_down) {
+    return Dtype(0.);
+  }
+  const Dtype* top_diff = top[0]->gpu_diff();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+  int count = (*bottom)[0]->count();
+  switch (this->layer_param_.pool()) {
+  case LayerParameter_PoolMethod_MAX:
+    MaxPoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+        count, (*bottom)[0]->gpu_data(), top[0]->gpu_data(), top_diff,
+        top[0]->num(), CHANNELS_, HEIGHT_, WIDTH_, POOLED_HEIGHT_,
+        POOLED_WIDTH_, KSIZE_, STRIDE_, bottom_diff);
+    break;
+  case LayerParameter_PoolMethod_AVE:
+    AvePoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+        count, top_diff, top[0]->num(), CHANNELS_,
+        HEIGHT_, WIDTH_, POOLED_HEIGHT_, POOLED_WIDTH_, KSIZE_, STRIDE_,
+        bottom_diff);
+    break;
+  default:
+    LOG(FATAL) << "Unknown pooling method.";
+  }
+  CUDA_POST_KERNEL_CHECK;
+  return Dtype(0.);
+}
+
+
+INSTANTIATE_CLASS(PoolingLayer);
+
+
+}  // namespace caffe
index 8cf7851..7984aaf 100644 (file)
@@ -110,9 +110,11 @@ void GradientChecker<Dtype>::CheckGradientSingle(Layer<Dtype>& layer,
         Dtype scale = max(max(fabs(computed_gradient), fabs(estimated_gradient)),
             1.);
         EXPECT_GT(computed_gradient, estimated_gradient - threshold_ * scale)
-          << "debug: (blob_id, feat_id)=" << blobid << "," << feat_id;
+          << "debug: (top_id, top_data_id, blob_id, feat_id)="
+          << top_id << "," << top_data_id << "," << blobid << "," << feat_id;
         EXPECT_LT(computed_gradient, estimated_gradient + threshold_ * scale)
-          << "debug: (blob_id, feat_id)=" << blobid << "," << feat_id;
+          << "debug: (top_id, top_data_id, blob_id, feat_id)="
+          << top_id << "," << top_data_id << "," << blobid << "," << feat_id;
       }
       //LOG(ERROR) << "Feature: " << current_blob->cpu_data()[feat_id];
       //LOG(ERROR) << "computed gradient: " << computed_gradient
index d9d76fe..2e3c60f 100644 (file)
@@ -21,6 +21,7 @@ class PoolingLayerTest : public ::testing::Test {
       : blob_bottom_(new Blob<Dtype>()),
         blob_top_(new Blob<Dtype>()) {};
   virtual void SetUp() {
+    Caffe::set_random_seed(1701);
     blob_bottom_->Reshape(2, 3, 6, 5);
     // fill the values
     FillerParameter filler_param;
@@ -53,6 +54,71 @@ TYPED_TEST(PoolingLayerTest, TestSetup) {
   EXPECT_EQ(this->blob_top_->width(), 2);
 }
 
+TYPED_TEST(PoolingLayerTest, TestGPUMax) {
+  LayerParameter layer_param;
+  layer_param.set_kernelsize(3);
+  layer_param.set_stride(2);
+  layer_param.set_pool(LayerParameter_PoolMethod_MAX);
+  Caffe::set_mode(Caffe::CPU);
+  PoolingLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  Blob<TypeParam> blob_reference(*this->blob_top_);
+  Caffe::set_mode(Caffe::GPU);
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  for (int i = 0; i < blob_reference.count(); ++i) {
+    EXPECT_EQ(blob_reference.cpu_data()[i], this->blob_top_->cpu_data()[i])
+        << "debug: index " << i;
+  }
+}
+
+TYPED_TEST(PoolingLayerTest, TestGPUAve) {
+  LayerParameter layer_param;
+  layer_param.set_kernelsize(3);
+  layer_param.set_stride(2);
+  layer_param.set_pool(LayerParameter_PoolMethod_AVE);
+  Caffe::set_mode(Caffe::CPU);
+  PoolingLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  Blob<TypeParam> blob_reference(*this->blob_top_);
+  Caffe::set_mode(Caffe::GPU);
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  for (int i = 0; i < blob_reference.count(); ++i) {
+    EXPECT_GE(blob_reference.cpu_data()[i], this->blob_top_->cpu_data()[i] - 1e-4)
+        << "debug: index " << i;
+    EXPECT_LE(blob_reference.cpu_data()[i], this->blob_top_->cpu_data()[i] + 1e-4)
+        << "debug: index " << i;
+  }
+}
+
+/*
+TYPED_TEST(PoolingLayerTest, PrintGPUBackward) {
+  LayerParameter layer_param;
+  layer_param.set_kernelsize(3);
+  layer_param.set_stride(2);
+  layer_param.set_pool(LayerParameter_PoolMethod_MAX);
+  Caffe::set_mode(Caffe::GPU);
+  PoolingLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    cout << "bottom data " << i << " " << this->blob_bottom_->cpu_data()[i] << endl;
+  }
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    cout << "top data " << i << " " << this->blob_top_->cpu_data()[i] << endl;
+  }  
+
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    this->blob_top_->mutable_cpu_diff()[i] = 1.;
+  }
+  layer.Backward(this->blob_top_vec_, true, &(this->blob_bottom_vec_));
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    cout << "bottom diff " << i << " " << this->blob_bottom_->cpu_diff()[i] << endl;
+  }  
+}
+*/
+
 TYPED_TEST(PoolingLayerTest, TestCPUGradientMax) {
   LayerParameter layer_param;
   layer_param.set_kernelsize(3);
@@ -64,6 +130,17 @@ TYPED_TEST(PoolingLayerTest, TestCPUGradientMax) {
   checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, this->blob_top_vec_);
 }
 
+TYPED_TEST(PoolingLayerTest, TestGPUGradientMax) {
+  LayerParameter layer_param;
+  layer_param.set_kernelsize(3);
+  layer_param.set_stride(2);
+  layer_param.set_pool(LayerParameter_PoolMethod_MAX);
+  Caffe::set_mode(Caffe::GPU);
+  PoolingLayer<TypeParam> layer(layer_param);
+  GradientChecker<TypeParam> checker(1e-4, 1e-2);
+  checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, this->blob_top_vec_);
+}
+
 
 TYPED_TEST(PoolingLayerTest, TestCPUGradientAve) {
   LayerParameter layer_param;
@@ -76,16 +153,17 @@ TYPED_TEST(PoolingLayerTest, TestCPUGradientAve) {
   checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, this->blob_top_vec_);
 }
 
-/*
-TYPED_TEST(PoolingLayerTest, TestGPUGradient) {
+
+TYPED_TEST(PoolingLayerTest, TestGPUGradientAve) {
   LayerParameter layer_param;
   layer_param.set_kernelsize(3);
   layer_param.set_stride(2);
+  layer_param.set_pool(LayerParameter_PoolMethod_AVE);
   Caffe::set_mode(Caffe::GPU);
   PoolingLayer<TypeParam> layer(layer_param);
   GradientChecker<TypeParam> checker(1e-2, 1e-2);
   checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, this->blob_top_vec_);
 }
-*/
+
 
 }
index aef6842..161a45d 100644 (file)
@@ -60,6 +60,7 @@ __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col,
   const int stride, const int height_col, const int width_col, Dtype* data_im) {
   int index = threadIdx.x + blockIdx.x * blockDim.x;
   if (index < n) {
+    Dtype val = 0;
     int w = index % width;
     int h = (index / width) % height;
     int c = index / (width * height);
@@ -72,9 +73,10 @@ __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col,
       for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
         // the col location: [c * width * height + h_out, w_out]
         int c_col = c * ksize * ksize + (h - h_col * stride) * ksize + (w - w_col * stride); 
-        data_im[index] += data_col[(c_col * height_col + h_col) * width_col + w_col];
+        val += data_col[(c_col * height_col + h_col) * width_col + w_col];
       }
     }
+    data_im[index] = val;
   }
 }
 
@@ -82,7 +84,7 @@ template <typename Dtype>
 void col2im_gpu(const Dtype* data_col, const int channels,
     const int height, const int width, const int ksize, const int stride,
     Dtype* data_im) {
-  CUDA_CHECK(cudaMemset(data_im, 0, sizeof(Dtype) * height * width * channels));
+  //CUDA_CHECK(cudaMemset(data_im, 0, sizeof(Dtype) * height * width * channels));
   int height_col = (height - ksize) / stride + 1;
   int width_col = (width - ksize) / stride + 1;
   int num_kernels = channels * height * width;
index 31c6b1d..5d99c48 100644 (file)
@@ -169,12 +169,12 @@ class PoolingLayer : public Layer<Dtype> {
  protected:
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
-  //virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-  //    vector<Blob<Dtype>*>* top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
   virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  //virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
-  //    const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
   int KSIZE_;
   int STRIDE_;
   int CHANNELS_;