padding layer cuda code, need debug
authorYangqing Jia <jiayq84@gmail.com>
Thu, 19 Sep 2013 01:17:09 +0000 (18:17 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Thu, 19 Sep 2013 01:17:09 +0000 (18:17 -0700)
src/caffeine/blob.hpp
src/caffeine/common.hpp
src/caffeine/filler.hpp
src/caffeine/layers/padding_layer.cu [new file with mode: 0644]
src/caffeine/proto/layer_param.proto
src/caffeine/test_padding_layer.cpp [new file with mode: 0644]
src/caffeine/vision_layers.hpp

index 42b3d96..8a729a5 100644 (file)
@@ -24,6 +24,16 @@ class Blob {
   inline int height() const { return height_; }
   inline int width() const { return width_; }
   inline int count() const {return count_; }
+
+  inline Dtype data_at(const int n, const int c, const int h,
+      const int w) const {
+    return cpu_data()[((n * channels_ + c) * height_ + h) * width_ + w];
+  }
+
+  inline Dtype diff_at(const int n, const int c, const int h,
+      const int w) const {
+    return cpu_diff()[((n * channels_ + c) * height_ + h) * width_ + w];
+  }
   
   const Dtype* cpu_data() const;
   const Dtype* gpu_data() const;
index 6721e26..2da0df1 100644 (file)
 #define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
 #define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK)
 
+#define CUDA_POST_KERNEL_CHECK \
+  if (cudaSuccess != cudaPeekAtLastError()) {\
+    LOG(FATAL) << "Cuda kernel failed. Error: " << cudaGetLastError(); \
+  }
+
 #define INSTANTIATE_CLASS(classname) \
   template class classname<float>; \
   template class classname<double>
index 07f31da..b15b38e 100644 (file)
@@ -61,7 +61,8 @@ class UniformFiller : public Filler<Dtype> {
       break;
     case sizeof(double):
       VSL_CHECK(vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffeine::vsl_stream(),
-          count, (double*)data, this->filler_param_.min(), this->filler_param_.max()));
+          count, (double*)data, this->filler_param_.min(),
+          this->filler_param_.max()));
       break;
     default:
       CHECK(false) << "Unknown dtype.";
diff --git a/src/caffeine/layers/padding_layer.cu b/src/caffeine/layers/padding_layer.cu
new file mode 100644 (file)
index 0000000..2171467
--- /dev/null
@@ -0,0 +1,136 @@
+#include "caffeine/layer.hpp"
+#include "caffeine/vision_layers.hpp"
+
+#include <iostream>
+
+namespace caffeine {
+
+template <typename Dtype>
+void PaddingLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  PAD_ = this->layer_param_.pad();
+  CHECK_EQ(bottom.size(), 1) << "Padding Layer takes a single blob as input.";
+  CHECK_EQ(top->size(), 1) << "Padding Layer takes a single blob as output.";
+  NUM_ = bottom[0]->num();
+  CHANNEL_ = bottom[0]->channels();
+  HEIGHT_IN_ = bottom[0]->height();
+  WIDTH_IN_ = bottom[0]->width();
+  HEIGHT_OUT_ = HEIGHT_IN_ + PAD_ * 2;
+  WIDTH_OUT_ = WIDTH_IN_ + PAD_ * 2;
+  (*top)[0]->Reshape(NUM_, CHANNEL_, HEIGHT_OUT_, WIDTH_OUT_);
+
+};
+
+template <typename Dtype>
+void PaddingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  Dtype* top_data = (*top)[0]->mutable_cpu_data();
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  memset(top_data, 0, sizeof(Dtype) * (*top)[0]->count());
+  // In short, top[n, c, h, w] = bottom[n, c, h-pad, w-pad] if in range
+  for (int n = 0; n < NUM_; ++n) {
+    for (int c = 0; c < CHANNEL_; ++c) {
+      for (int h = 0; h < HEIGHT_IN_; ++h) {
+        // copy the width part
+        memcpy(
+            top_data + ((n * CHANNEL_ + c) * HEIGHT_OUT_ + h + PAD_)
+                * WIDTH_OUT_ + PAD_,
+            bottom_data + ((n * CHANNEL_ + c) * HEIGHT_IN_ + h) * WIDTH_IN_,
+            sizeof(Dtype) * WIDTH_IN_);
+      }
+    }
+  }
+}
+
+template <typename Dtype>
+Dtype PaddingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+  const Dtype* top_diff = top[0]->cpu_diff();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+  //memset(bottom_data, 0, sizeof(Dtype) * (*bottom)[0]->count());
+  for (int n = 0; n < NUM_; ++n) {
+    for (int c = 0; c < CHANNEL_; ++c) {
+      for (int h = 0; h < HEIGHT_IN_; ++h) {
+        // copy the width part
+        memcpy(
+            bottom_diff + ((n * CHANNEL_ + c) * HEIGHT_IN_ + h) * WIDTH_IN_,
+            top_diff + ((n * CHANNEL_ + c) * HEIGHT_OUT_ + h + PAD_)
+                * WIDTH_OUT_ + PAD_,
+            sizeof(Dtype) * WIDTH_IN_);
+      }
+    }
+  }
+  return Dtype(0.);
+}
+
+template <typename Dtype>
+__global__ void PaddingForward(const int count, const Dtype* in, Dtype* out,
+    const int num, const int channel, const int height_in, const int width_in,
+    const int pad) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < count) {
+    int height_out = height_in + pad + pad;
+    int width_out = width_in + pad + pad;
+    int w = index % width_in;
+    index /= width_in;
+    int h = index % height_in;
+    index /= height_in;
+    int c = index % channel;
+    index /= channel;
+    out[((index * channel + c) * height_out + h + pad) * width_out + pad + w] =
+        in[((index * channel + c) * height_in + h) * width_in + w];
+  }
+}
+
+template <typename Dtype>
+void PaddingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = (*top)[0]->mutable_gpu_data();
+  const int count = bottom[0]->count();
+  // First, set all data to be zero for the boundary pixels
+  CUDA_CHECK(cudaMemset(top_data, 0, sizeof(Dtype) * (*top)[0]->count()));
+  PaddingForward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
+      count, bottom_data, top_data, NUM_, CHANNEL_, HEIGHT_IN_, WIDTH_IN_,
+      PAD_);
+  CUDA_POST_KERNEL_CHECK;
+}
+
+template <typename Dtype>
+__global__ void PaddingBackward(const int count, const Dtype* in, Dtype* out,
+    const int num, const int channel, const int height_in, const int width_in,
+    const int pad) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < count) {
+    int height_out = height_in + pad + pad;
+    int width_out = width_in + pad + pad;
+    int w = index % width_in;
+    index /= width_in;
+    int h = index % height_in;
+    index /= height_in;
+    int c = index % channel;
+    index /= channel;
+    out[((index * channel + c) * height_in + h) * width_in + w] =
+        in[((index * channel + c) * height_out + h + pad) * width_out + pad + w];
+  }
+}
+
+template <typename Dtype>
+Dtype PaddingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  if (propagate_down) {
+    const Dtype* top_diff = top[0]->gpu_diff();
+    Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+    const int count = (*bottom)[0]->count();
+    PaddingBackward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
+        count, top_diff, bottom_diff, NUM_, CHANNEL_, HEIGHT_IN_, WIDTH_IN_,
+        PAD_);
+  }
+  return Dtype(0);
+}
+
+INSTANTIATE_CLASS(PaddingLayer);
+
+
+}  // namespace caffeine
index 27246ff..a20342f 100644 (file)
@@ -31,13 +31,12 @@ message LayerParameter {
   optional FillerParameter bias_filler = 6; // The filler for the bias
 
   optional uint32 pad = 7 [default = 0]; // The padding size
-  optional float pad_value = 8 [default = 0]; // The padding value
-  optional uint32 kernelsize = 9; // The kernel size
-  optional uint32 group = 10 [default = 1]; // The group size for group conv
-  optional uint32 stride = 11 [default = 1]; // The stride
-  optional string pool = 12 [default = 'max']; // The pooling method
-  optional float dropout_ratio = 13 [default = 0.5]; // dropout ratio
+  optional uint32 kernelsize = 8; // The kernel size
+  optional uint32 group = 9 [default = 1]; // The group size for group conv
+  optional uint32 stride = 10 [default = 1]; // The stride
+  optional string pool = 11 [default = 'max']; // The pooling method
+  optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio
 
-  optional float alpha = 14 [default = 1.]; // for local response norm
-  optional float beta = 15 [default = 0.75]; // for local response norm
+  optional float alpha = 13 [default = 1.]; // for local response norm
+  optional float beta = 14 [default = 0.75]; // for local response norm
 }
diff --git a/src/caffeine/test_padding_layer.cpp b/src/caffeine/test_padding_layer.cpp
new file mode 100644 (file)
index 0000000..8f28040
--- /dev/null
@@ -0,0 +1,103 @@
+#include <cstring>
+#include <cuda_runtime.h>
+
+#include "gtest/gtest.h"
+#include "caffeine/blob.hpp"
+#include "caffeine/common.hpp"
+#include "caffeine/filler.hpp"
+#include "caffeine/vision_layers.hpp"
+#include "caffeine/test/test_gradient_check_util.hpp"
+
+
+namespace caffeine {
+
+extern cudaDeviceProp CAFFEINE_TEST_CUDA_PROP;
+  
+template <typename Dtype>
+class PaddingLayerTest : public ::testing::Test {
+ protected:
+  PaddingLayerTest()
+      : blob_bottom_(new Blob<Dtype>(2, 3, 4, 5)),
+        blob_top_(new Blob<Dtype>()) {
+    // fill the values
+    FillerParameter filler_param;
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  };
+  virtual ~PaddingLayerTest() { delete blob_bottom_; delete blob_top_; }
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(PaddingLayerTest, Dtypes);
+
+TYPED_TEST(PaddingLayerTest, TestCPU) {
+  LayerParameter layer_param;
+  layer_param.set_pad(1);
+  Caffeine::set_mode(Caffeine::CPU);
+  PaddingLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  EXPECT_EQ(this->blob_top_->num(), 2);
+  EXPECT_EQ(this->blob_top_->channels(), 3);
+  EXPECT_EQ(this->blob_top_->height(), 6);
+  EXPECT_EQ(this->blob_top_->width(), 7);
+  for (int n = 0; n < 2; ++n) {
+    for (int c = 0; c < 3; ++c) {
+      for (int h = 0; h < 4; ++h) {
+        for (int w = 0; w < 5; ++w) {
+          EXPECT_EQ(this->blob_bottom_->data_at(n, c, h, w),
+              this->blob_top_->data_at(n, c, h + 1, w + 1));
+        }
+      }
+    }
+  }
+}
+
+TYPED_TEST(PaddingLayerTest, TestCPUGrad) {
+  LayerParameter layer_param;
+  layer_param.set_pad(1);
+  Caffeine::set_mode(Caffeine::CPU);
+  PaddingLayer<TypeParam> layer(layer_param);
+  GradientChecker<TypeParam> checker(1e-2, 1e-3);
+  checker.CheckGradient(layer, this->blob_bottom_vec_, this->blob_top_vec_);
+}
+
+TYPED_TEST(PaddingLayerTest, TestGPU) {
+  LayerParameter layer_param;
+  layer_param.set_pad(1);
+  Caffeine::set_mode(Caffeine::GPU);
+  PaddingLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  EXPECT_EQ(this->blob_top_->num(), 2);
+  EXPECT_EQ(this->blob_top_->channels(), 3);
+  EXPECT_EQ(this->blob_top_->height(), 6);
+  EXPECT_EQ(this->blob_top_->width(), 7);
+  for (int n = 0; n < 2; ++n) {
+    for (int c = 0; c < 3; ++c) {
+      for (int h = 0; h < 4; ++h) {
+        for (int w = 0; w < 5; ++w) {
+          EXPECT_EQ(this->blob_bottom_->data_at(n, c, h, w),
+              this->blob_top_->data_at(n, c, h + 1, w + 1));
+        }
+      }
+    }
+  }
+}
+
+TYPED_TEST(PaddingLayerTest, TestGPUGrad) {
+  LayerParameter layer_param;
+  layer_param.set_pad(1);
+  Caffeine::set_mode(Caffeine::GPU);
+  PaddingLayer<TypeParam> layer(layer_param);
+  GradientChecker<TypeParam> checker(1e-2, 1e-3);
+  checker.CheckGradient(layer, this->blob_bottom_vec_, this->blob_top_vec_);
+}
+
+}
index e324c8e..19fd00e 100644 (file)
@@ -83,6 +83,31 @@ class InnerProductLayer : public Layer<Dtype> {
   shared_ptr<SyncedMemory> bias_multiplier_;
 };
 
+template <typename Dtype>
+class PaddingLayer : public Layer<Dtype> {
+ public:
+  explicit PaddingLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {};
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  unsigned int PAD_;
+  int NUM_;
+  int CHANNEL_;
+  int HEIGHT_IN_;
+  int WIDTH_IN_;
+  int HEIGHT_OUT_;
+  int WIDTH_OUT_;
+};
+
 }  // namespace caffeine
 
 #endif  // CAFFEINE_VISION_LAYERS_HPP_