From: Yangqing Jia
Date: Thu, 19 Sep 2013 01:17:09 +0000 (-0700)
Subject: padding layer cuda code, need debug
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=aef20cdaa27087b4fbd20336cd0318ca92a4d260;p=platform%2Fupstream%2Fcaffe.git

padding layer cuda code, need debug
---

diff --git a/src/caffeine/blob.hpp b/src/caffeine/blob.hpp
index 42b3d96..8a729a5 100644
--- a/src/caffeine/blob.hpp
+++ b/src/caffeine/blob.hpp
@@ -24,6 +24,16 @@ class Blob {
   inline int height() const { return height_; }
   inline int width() const { return width_; }
   inline int count() const {return count_; }
+
+  inline Dtype data_at(const int n, const int c, const int h,
+      const int w) const {
+    return cpu_data()[((n * channels_ + c) * height_ + h) * width_ + w];
+  }
+
+  inline Dtype diff_at(const int n, const int c, const int h,
+      const int w) const {
+    return cpu_diff()[((n * channels_ + c) * height_ + h) * width_ + w];
+  }
 
   const Dtype* cpu_data() const;
   const Dtype* gpu_data() const;
diff --git a/src/caffeine/common.hpp b/src/caffeine/common.hpp
index 6721e26..2da0df1 100644
--- a/src/caffeine/common.hpp
+++ b/src/caffeine/common.hpp
@@ -15,6 +15,11 @@
 #define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
 #define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK)
 
+#define CUDA_POST_KERNEL_CHECK \
+  if (cudaSuccess != cudaPeekAtLastError()) {\
+    LOG(FATAL) << "Cuda kernel failed. Error: " << cudaGetLastError(); \
+  }
+
 #define INSTANTIATE_CLASS(classname) \
   template class classname<float>; \
   template class classname<double>
diff --git a/src/caffeine/filler.hpp b/src/caffeine/filler.hpp
index 07f31da..b15b38e 100644
--- a/src/caffeine/filler.hpp
+++ b/src/caffeine/filler.hpp
@@ -61,7 +61,8 @@ class UniformFiller : public Filler<Dtype> {
       break;
     case sizeof(double):
       VSL_CHECK(vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffeine::vsl_stream(),
-          count, (double*)data, this->filler_param_.min(), this->filler_param_.max()));
+          count, (double*)data, this->filler_param_.min(),
+          this->filler_param_.max()));
       break;
     default:
       CHECK(false) << "Unknown dtype.";
diff --git a/src/caffeine/layers/padding_layer.cu b/src/caffeine/layers/padding_layer.cu
new file mode 100644
index 0000000..2171467
--- /dev/null
+++ b/src/caffeine/layers/padding_layer.cu
@@ -0,0 +1,136 @@
+#include "caffeine/layer.hpp"
+#include "caffeine/vision_layers.hpp"
+
+#include <cstring>
+
+namespace caffeine {
+
+template <typename Dtype>
+void PaddingLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  PAD_ = this->layer_param_.pad();
+  CHECK_EQ(bottom.size(), 1) << "Padding Layer takes a single blob as input.";
+  CHECK_EQ(top->size(), 1) << "Padding Layer takes a single blob as output.";
+  NUM_ = bottom[0]->num();
+  CHANNEL_ = bottom[0]->channels();
+  HEIGHT_IN_ = bottom[0]->height();
+  WIDTH_IN_ = bottom[0]->width();
+  HEIGHT_OUT_ = HEIGHT_IN_ + PAD_ * 2;
+  WIDTH_OUT_ = WIDTH_IN_ + PAD_ * 2;
+  (*top)[0]->Reshape(NUM_, CHANNEL_, HEIGHT_OUT_, WIDTH_OUT_);
+
+};
+
+template <typename Dtype>
+void PaddingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  Dtype* top_data = (*top)[0]->mutable_cpu_data();
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  memset(top_data, 0, sizeof(Dtype) * (*top)[0]->count());
+  // In short, top[n, c, h, w] = bottom[n, c, h-pad, w-pad] if in range
+  for (int n = 0; n < NUM_; ++n) {
+    for (int c = 0; c < CHANNEL_; ++c) {
+      for (int h = 0; h < HEIGHT_IN_; ++h) {
+        // copy the width part
+        memcpy(
+            top_data + ((n * CHANNEL_ + c) * HEIGHT_OUT_ + h + PAD_)
+                * WIDTH_OUT_ + PAD_,
+            bottom_data + ((n * CHANNEL_ + c) * HEIGHT_IN_ + h) * WIDTH_IN_,
+            sizeof(Dtype) * WIDTH_IN_);
+      }
+    }
+  }
+}
+
+template <typename Dtype>
+Dtype PaddingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+  const Dtype* top_diff = top[0]->cpu_diff();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+  //memset(bottom_data, 0, sizeof(Dtype) * (*bottom)[0]->count());
+  for (int n = 0; n < NUM_; ++n) {
+    for (int c = 0; c < CHANNEL_; ++c) {
+      for (int h = 0; h < HEIGHT_IN_; ++h) {
+        // copy the width part
+        memcpy(
+            bottom_diff + ((n * CHANNEL_ + c) * HEIGHT_IN_ + h) * WIDTH_IN_,
+            top_diff + ((n * CHANNEL_ + c) * HEIGHT_OUT_ + h + PAD_)
+                * WIDTH_OUT_ + PAD_,
+            sizeof(Dtype) * WIDTH_IN_);
+      }
+    }
+  }
+  return Dtype(0.);
+}
+
+template <typename Dtype>
+__global__ void PaddingForward(const int count, const Dtype* in, Dtype* out,
+    const int num, const int channel, const int height_in, const int width_in,
+    const int pad) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < count) {
+    int height_out = height_in + pad + pad;
+    int width_out = width_in + pad + pad;
+    int w = index % width_in;
+    index /= width_in;
+    int h = index % height_in;
+    index /= height_in;
+    int c = index % channel;
+    index /= channel;
+    out[((index * channel + c) * height_out + h + pad) * width_out + pad + w] =
+        in[((index * channel + c) * height_in + h) * width_in + w];
+  }
+}
+
+template <typename Dtype>
+void PaddingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = (*top)[0]->mutable_gpu_data();
+  const int count = bottom[0]->count();
+  // First, set all data to be zero for the boundary pixels
+  CUDA_CHECK(cudaMemset(top_data, 0, sizeof(Dtype) * (*top)[0]->count()));
+  PaddingForward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
+      count, bottom_data, top_data, NUM_, CHANNEL_, HEIGHT_IN_, WIDTH_IN_,
+      PAD_);
+  CUDA_POST_KERNEL_CHECK;
+}
+
+template <typename Dtype>
+__global__ void PaddingBackward(const int count, const Dtype* in, Dtype* out,
+    const int num, const int channel, const int height_in, const int width_in,
+    const int pad) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < count) {
+    int height_out = height_in + pad + pad;
+    int width_out = width_in + pad + pad;
+    int w = index % width_in;
+    index /= width_in;
+    int h = index % height_in;
+    index /= height_in;
+    int c = index % channel;
+    index /= channel;
+    out[((index * channel + c) * height_in + h) * width_in + w] =
+        in[((index * channel + c) * height_out + h + pad) * width_out + pad + w];
+  }
+}
+
+template <typename Dtype>
+Dtype PaddingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down,
+      vector<Blob<Dtype>*>* bottom) {
+  if (propagate_down) {
+    const Dtype* top_diff = top[0]->gpu_diff();
+    Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+    const int count = (*bottom)[0]->count();
+    PaddingBackward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
+        count, top_diff, bottom_diff, NUM_, CHANNEL_, HEIGHT_IN_, WIDTH_IN_,
+        PAD_);
+  }
+  return Dtype(0);
+}
+
+INSTANTIATE_CLASS(PaddingLayer);
+
+
+}  // namespace caffeine
diff --git a/src/caffeine/proto/layer_param.proto b/src/caffeine/proto/layer_param.proto
index 27246ff..a20342f 100644
--- a/src/caffeine/proto/layer_param.proto
+++ b/src/caffeine/proto/layer_param.proto
@@ -31,13 +31,12 @@ message LayerParameter {
   optional FillerParameter bias_filler = 6; // The filler for the bias
 
   optional uint32 pad = 7 [default = 0]; // The padding size
-  optional float pad_value = 8 [default = 0]; // The padding value
-  optional uint32 kernelsize = 9; // The kernel size
-  optional uint32 group = 10 [default = 1]; // The group size for group conv
-  optional uint32 stride = 11 [default = 1]; // The stride
-  optional string pool = 12 [default = 'max']; // The pooling method
-  optional float dropout_ratio = 13 [default = 0.5]; // dropout ratio
+  optional uint32 kernelsize = 8; // The kernel size
+  optional uint32 group = 9 [default = 1]; // The group size for group conv
+  optional uint32 stride = 10 [default = 1]; // The stride
+  optional string pool = 11 [default = 'max']; // The pooling method
+  optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio
 
-  optional float alpha = 14 [default = 1.]; // for local response norm
-  optional float beta = 15 [default = 0.75]; // for local response norm
+  optional float alpha = 13 [default = 1.]; // for local response norm
+  optional float beta = 14 [default = 0.75]; // for local response norm
 }
diff --git a/src/caffeine/test_padding_layer.cpp b/src/caffeine/test_padding_layer.cpp
new file mode 100644
index 0000000..8f28040
--- /dev/null
+++ b/src/caffeine/test_padding_layer.cpp
@@ -0,0 +1,103 @@
+#include <cstring>
+#include <cuda_runtime.h>
+
+#include "gtest/gtest.h"
+#include "caffeine/blob.hpp"
+#include "caffeine/common.hpp"
+#include "caffeine/filler.hpp"
+#include "caffeine/vision_layers.hpp"
+#include "caffeine/test/test_gradient_check_util.hpp"
+
+
+namespace caffeine {
+
+extern cudaDeviceProp CAFFEINE_TEST_CUDA_PROP;
+
+template <typename Dtype>
+class PaddingLayerTest : public ::testing::Test {
+ protected:
+  PaddingLayerTest()
+      : blob_bottom_(new Blob<Dtype>(2, 3, 4, 5)),
+        blob_top_(new Blob<Dtype>()) {
+    // fill the values
+    FillerParameter filler_param;
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  };
+  virtual ~PaddingLayerTest() { delete blob_bottom_; delete blob_top_; }
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(PaddingLayerTest, Dtypes);
+
+TYPED_TEST(PaddingLayerTest, TestCPU) {
+  LayerParameter layer_param;
+  layer_param.set_pad(1);
+  Caffeine::set_mode(Caffeine::CPU);
+  PaddingLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  EXPECT_EQ(this->blob_top_->num(), 2);
+  EXPECT_EQ(this->blob_top_->channels(), 3);
+  EXPECT_EQ(this->blob_top_->height(), 6);
+  EXPECT_EQ(this->blob_top_->width(), 7);
+  for (int n = 0; n < 2; ++n) {
+    for (int c = 0; c < 3; ++c) {
+      for (int h = 0; h < 4; ++h) {
+        for (int w = 0; w < 5; ++w) {
+          EXPECT_EQ(this->blob_bottom_->data_at(n, c, h, w),
+              this->blob_top_->data_at(n, c, h + 1, w + 1));
+        }
+      }
+    }
+  }
+}
+
+TYPED_TEST(PaddingLayerTest, TestCPUGrad) {
+  LayerParameter layer_param;
+  layer_param.set_pad(1);
+  Caffeine::set_mode(Caffeine::CPU);
+  PaddingLayer<TypeParam> layer(layer_param);
+  GradientChecker<TypeParam> checker(1e-2, 1e-3);
+  checker.CheckGradient(layer, this->blob_bottom_vec_, this->blob_top_vec_);
+}
+
+TYPED_TEST(PaddingLayerTest, TestGPU) {
+  LayerParameter layer_param;
+  layer_param.set_pad(1);
+  Caffeine::set_mode(Caffeine::GPU);
+  PaddingLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  EXPECT_EQ(this->blob_top_->num(), 2);
+  EXPECT_EQ(this->blob_top_->channels(), 3);
+  EXPECT_EQ(this->blob_top_->height(), 6);
+  EXPECT_EQ(this->blob_top_->width(), 7);
+  for (int n = 0; n < 2; ++n) {
+    for (int c = 0; c < 3; ++c) {
+      for (int h = 0; h < 4; ++h) {
+        for (int w = 0; w < 5; ++w) {
+          EXPECT_EQ(this->blob_bottom_->data_at(n, c, h, w),
+              this->blob_top_->data_at(n, c, h + 1, w + 1));
+        }
+      }
+    }
+  }
+}
+
+TYPED_TEST(PaddingLayerTest, TestGPUGrad) {
+  LayerParameter layer_param;
+  layer_param.set_pad(1);
+  Caffeine::set_mode(Caffeine::GPU);
+  PaddingLayer<TypeParam> layer(layer_param);
+  GradientChecker<TypeParam> checker(1e-2, 1e-3);
+  checker.CheckGradient(layer, this->blob_bottom_vec_, this->blob_top_vec_);
+}
+
+}
diff --git a/src/caffeine/vision_layers.hpp b/src/caffeine/vision_layers.hpp
index e324c8e..19fd00e 100644
--- a/src/caffeine/vision_layers.hpp
+++ b/src/caffeine/vision_layers.hpp
@@ -83,6 +83,31 @@ class InnerProductLayer : public Layer<Dtype> {
   shared_ptr<SyncedMemory> bias_multiplier_;
 };
 
+template <typename Dtype>
+class PaddingLayer : public Layer<Dtype> {
+ public:
+  explicit PaddingLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {};
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  unsigned int PAD_;
+  int NUM_;
+  int CHANNEL_;
+  int HEIGHT_IN_;
+  int WIDTH_IN_;
+  int HEIGHT_OUT_;
+  int WIDTH_OUT_;
+};
+
 }  // namespace caffeine
 
 #endif  // CAFFEINE_VISION_LAYERS_HPP_
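
Note on the copy this patch implements: it reduces to plain NCHW offset arithmetic, top[n, c, h + pad, w + pad] = bottom[n, c, h, w], with every other output element left at zero, using the same linear offset ((n * channels + c) * height + h) * width + w that the new Blob::data_at() exposes. The standalone C++ sketch below restates that forward copy outside the layer machinery; it is an illustration only, with the function name, float type, and the small check in main() invented for this note rather than taken from the patch.

#include <cassert>
#include <cstring>
#include <vector>

// Standalone restatement of the padding forward pass: each input row of
// width_in elements is copied into the interior of a zero-filled output of
// size (height_in + 2*pad) x (width_in + 2*pad), per image and per channel.
// Offsets use the NCHW formula ((n * channels + c) * height + h) * width + w.
void pad_forward(const float* bottom, float* top, int num, int channels,
                 int height_in, int width_in, int pad) {
  const int height_out = height_in + 2 * pad;
  const int width_out = width_in + 2 * pad;
  std::memset(top, 0,
              sizeof(float) * num * channels * height_out * width_out);
  for (int n = 0; n < num; ++n) {
    for (int c = 0; c < channels; ++c) {
      for (int h = 0; h < height_in; ++h) {
        std::memcpy(
            top + ((n * channels + c) * height_out + h + pad) * width_out
                + pad,
            bottom + ((n * channels + c) * height_in + h) * width_in,
            sizeof(float) * width_in);
      }
    }
  }
}

int main() {
  // One 2x3 single-channel image padded by 1 becomes a 4x5 output with the
  // input centered; row 1 of the output should read 0 1 2 3 0.
  std::vector<float> bottom = {1, 2, 3, 4, 5, 6};
  std::vector<float> top(4 * 5, -1.0f);
  pad_forward(bottom.data(), top.data(), 1, 1, 2, 3, 1);
  assert(top[0] == 0.0f);          // corner stays zero
  assert(top[1 * 5 + 1] == 1.0f);  // bottom(0,0,0,0) landed at top(0,0,1,1)
  assert(top[2 * 5 + 3] == 6.0f);  // bottom(0,0,1,2) landed at top(0,0,2,3)
  return 0;
}

The CUDA kernels in the patch perform the same index arithmetic, but with one thread per bottom element instead of one memcpy per row.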