inline int height() const { return height_; }
inline int width() const { return width_; }
inline int count() const { return count_; }
+
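+  // Offset of element (n, c, h, w) in row-major layout:
+  //   ((n * channels_ + c) * height_ + h) * width_ + w.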
+ inline Dtype data_at(const int n, const int c, const int h,
+ const int w) const {
+ return cpu_data()[((n * channels_ + c) * height_ + h) * width_ + w];
+ }
+
+ inline Dtype diff_at(const int n, const int c, const int h,
+ const int w) const {
+ return cpu_diff()[((n * channels_ + c) * height_ + h) * width_ + w];
+ }
const Dtype* cpu_data() const;
const Dtype* gpu_data() const;
#define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
#define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK)
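+// Aborts if the most recent kernel launch reported an error.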
+#define CUDA_POST_KERNEL_CHECK \
+  if (cudaSuccess != cudaPeekAtLastError()) { \
+    LOG(FATAL) << "Cuda kernel failed. Error: " \
+               << cudaGetErrorString(cudaGetLastError()); \
+ }
+
#define INSTANTIATE_CLASS(classname) \
template class classname<float>; \
template class classname<double>
break;
case sizeof(double):
VSL_CHECK(vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffeine::vsl_stream(),
- count, (double*)data, this->filler_param_.min(), this->filler_param_.max()));
+ count, (double*)data, this->filler_param_.min(),
+ this->filler_param_.max()));
break;
default:
CHECK(false) << "Unknown dtype.";
--- /dev/null
+#include "caffeine/layer.hpp"
+#include "caffeine/vision_layers.hpp"
+
+#include <cstring>
+
+namespace caffeine {
+
+template <typename Dtype>
+void PaddingLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) {
+ PAD_ = this->layer_param_.pad();
+ CHECK_EQ(bottom.size(), 1) << "Padding Layer takes a single blob as input.";
+ CHECK_EQ(top->size(), 1) << "Padding Layer takes a single blob as output.";
+ NUM_ = bottom[0]->num();
+ CHANNEL_ = bottom[0]->channels();
+ HEIGHT_IN_ = bottom[0]->height();
+ WIDTH_IN_ = bottom[0]->width();
+ HEIGHT_OUT_ = HEIGHT_IN_ + PAD_ * 2;
+ WIDTH_OUT_ = WIDTH_IN_ + PAD_ * 2;
+ (*top)[0]->Reshape(NUM_, CHANNEL_, HEIGHT_OUT_, WIDTH_OUT_);
+}
+
+template <typename Dtype>
+void PaddingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) {
+ Dtype* top_data = (*top)[0]->mutable_cpu_data();
+ const Dtype* bottom_data = bottom[0]->cpu_data();
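+  // Zero-fill the output so the padding border stays at zero.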
+ memset(top_data, 0, sizeof(Dtype) * (*top)[0]->count());
+ // In short, top[n, c, h, w] = bottom[n, c, h-pad, w-pad] if in range
+ for (int n = 0; n < NUM_; ++n) {
+ for (int c = 0; c < CHANNEL_; ++c) {
+ for (int h = 0; h < HEIGHT_IN_; ++h) {
+ // copy the width part
+ memcpy(
+ top_data + ((n * CHANNEL_ + c) * HEIGHT_OUT_ + h + PAD_)
+ * WIDTH_OUT_ + PAD_,
+ bottom_data + ((n * CHANNEL_ + c) * HEIGHT_IN_ + h) * WIDTH_IN_,
+ sizeof(Dtype) * WIDTH_IN_);
+ }
+ }
+ }
+}
+
+template <typename Dtype>
+Dtype PaddingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+ const Dtype* top_diff = top[0]->cpu_diff();
+ Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+  // No memset needed: every bottom_diff element is overwritten below.
+ for (int n = 0; n < NUM_; ++n) {
+ for (int c = 0; c < CHANNEL_; ++c) {
+ for (int h = 0; h < HEIGHT_IN_; ++h) {
+ // copy the width part
+ memcpy(
+ bottom_diff + ((n * CHANNEL_ + c) * HEIGHT_IN_ + h) * WIDTH_IN_,
+ top_diff + ((n * CHANNEL_ + c) * HEIGHT_OUT_ + h + PAD_)
+ * WIDTH_OUT_ + PAD_,
+ sizeof(Dtype) * WIDTH_IN_);
+ }
+ }
+ }
+ return Dtype(0.);
+}
+
+template <typename Dtype>
+__global__ void PaddingForward(const int count, const Dtype* in, Dtype* out,
+ const int num, const int channel, const int height_in, const int width_in,
+ const int pad) {
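+  // One thread per bottom (unpadded) element.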
+ int index = threadIdx.x + blockIdx.x * blockDim.x;
+ if (index < count) {
+ int height_out = height_in + pad + pad;
+ int width_out = width_in + pad + pad;
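+    // Decompose the linear index into (n, c, h, w); n is the final quotient.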
+ int w = index % width_in;
+ index /= width_in;
+ int h = index % height_in;
+ index /= height_in;
+ int c = index % channel;
+ index /= channel;
+ out[((index * channel + c) * height_out + h + pad) * width_out + pad + w] =
+ in[((index * channel + c) * height_in + h) * width_in + w];
+ }
+}
+
+template <typename Dtype>
+void PaddingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) {
+ const Dtype* bottom_data = bottom[0]->gpu_data();
+ Dtype* top_data = (*top)[0]->mutable_gpu_data();
+ const int count = bottom[0]->count();
+ // First, set all data to be zero for the boundary pixels
+ CUDA_CHECK(cudaMemset(top_data, 0, sizeof(Dtype) * (*top)[0]->count()));
+  PaddingForward<Dtype><<<CAFFEINE_GET_BLOCKS(count),
+                          CAFFEINE_CUDA_NUM_THREADS>>>(
+ count, bottom_data, top_data, NUM_, CHANNEL_, HEIGHT_IN_, WIDTH_IN_,
+ PAD_);
+ CUDA_POST_KERNEL_CHECK;
+}
+
+template <typename Dtype>
+__global__ void PaddingBackward(const int count, const Dtype* in, Dtype* out,
+ const int num, const int channel, const int height_in, const int width_in,
+ const int pad) {
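+  // One thread per bottom element: read the gradient from the matching
+  // interior position of the padded top blob.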
+ int index = threadIdx.x + blockIdx.x * blockDim.x;
+ if (index < count) {
+ int height_out = height_in + pad + pad;
+ int width_out = width_in + pad + pad;
+ int w = index % width_in;
+ index /= width_in;
+ int h = index % height_in;
+ index /= height_in;
+ int c = index % channel;
+ index /= channel;
+ out[((index * channel + c) * height_in + h) * width_in + w] =
+ in[((index * channel + c) * height_out + h + pad) * width_out + pad + w];
+ }
+}
+
+template <typename Dtype>
+Dtype PaddingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down,
+ vector<Blob<Dtype>*>* bottom) {
+ if (propagate_down) {
+ const Dtype* top_diff = top[0]->gpu_diff();
+ Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+ const int count = (*bottom)[0]->count();
+    PaddingBackward<Dtype><<<CAFFEINE_GET_BLOCKS(count),
+                             CAFFEINE_CUDA_NUM_THREADS>>>(
+        count, top_diff, bottom_diff, NUM_, CHANNEL_, HEIGHT_IN_, WIDTH_IN_,
+        PAD_);
+    CUDA_POST_KERNEL_CHECK;
+ }
+ return Dtype(0);
+}
+
+INSTANTIATE_CLASS(PaddingLayer);
+
+} // namespace caffeine
optional FillerParameter bias_filler = 6; // The filler for the bias
optional uint32 pad = 7 [default = 0]; // The padding size
- optional float pad_value = 8 [default = 0]; // The padding value
- optional uint32 kernelsize = 9; // The kernel size
- optional uint32 group = 10 [default = 1]; // The group size for group conv
- optional uint32 stride = 11 [default = 1]; // The stride
- optional string pool = 12 [default = 'max']; // The pooling method
- optional float dropout_ratio = 13 [default = 0.5]; // dropout ratio
+ optional uint32 kernelsize = 8; // The kernel size
+ optional uint32 group = 9 [default = 1]; // The group size for group conv
+ optional uint32 stride = 10 [default = 1]; // The stride
+ optional string pool = 11 [default = 'max']; // The pooling method
+ optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio
- optional float alpha = 14 [default = 1.]; // for local response norm
- optional float beta = 15 [default = 0.75]; // for local response norm
+ optional float alpha = 13 [default = 1.]; // for local response norm
+ optional float beta = 14 [default = 0.75]; // for local response norm
}
--- /dev/null
+#include <cstring>
+#include <cuda_runtime.h>
+
+#include "gtest/gtest.h"
+#include "caffeine/blob.hpp"
+#include "caffeine/common.hpp"
+#include "caffeine/filler.hpp"
+#include "caffeine/vision_layers.hpp"
+#include "caffeine/test/test_gradient_check_util.hpp"
+
+namespace caffeine {
+
+extern cudaDeviceProp CAFFEINE_TEST_CUDA_PROP;
+
+template <typename Dtype>
+class PaddingLayerTest : public ::testing::Test {
+ protected:
+ PaddingLayerTest()
+ : blob_bottom_(new Blob<Dtype>(2, 3, 4, 5)),
+ blob_top_(new Blob<Dtype>()) {
+ // fill the values
+ FillerParameter filler_param;
+ GaussianFiller<Dtype> filler(filler_param);
+ filler.Fill(this->blob_bottom_);
+ blob_bottom_vec_.push_back(blob_bottom_);
+ blob_top_vec_.push_back(blob_top_);
+  }
+ virtual ~PaddingLayerTest() { delete blob_bottom_; delete blob_top_; }
+ Blob<Dtype>* const blob_bottom_;
+ Blob<Dtype>* const blob_top_;
+ vector<Blob<Dtype>*> blob_bottom_vec_;
+ vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(PaddingLayerTest, Dtypes);
+
+TYPED_TEST(PaddingLayerTest, TestCPU) {
+ LayerParameter layer_param;
+ layer_param.set_pad(1);
+ Caffeine::set_mode(Caffeine::CPU);
+ PaddingLayer<TypeParam> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+ layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+ EXPECT_EQ(this->blob_top_->num(), 2);
+ EXPECT_EQ(this->blob_top_->channels(), 3);
+ EXPECT_EQ(this->blob_top_->height(), 6);
+ EXPECT_EQ(this->blob_top_->width(), 7);
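+  // Every interior top pixel must match its source bottom pixel.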
+ for (int n = 0; n < 2; ++n) {
+ for (int c = 0; c < 3; ++c) {
+ for (int h = 0; h < 4; ++h) {
+ for (int w = 0; w < 5; ++w) {
+ EXPECT_EQ(this->blob_bottom_->data_at(n, c, h, w),
+ this->blob_top_->data_at(n, c, h + 1, w + 1));
+ }
+ }
+ }
+ }
+}
+
+TYPED_TEST(PaddingLayerTest, TestCPUGrad) {
+ LayerParameter layer_param;
+ layer_param.set_pad(1);
+ Caffeine::set_mode(Caffeine::CPU);
+ PaddingLayer<TypeParam> layer(layer_param);
+ GradientChecker<TypeParam> checker(1e-2, 1e-3);
+ checker.CheckGradient(layer, this->blob_bottom_vec_, this->blob_top_vec_);
+}
+
+TYPED_TEST(PaddingLayerTest, TestGPU) {
+ LayerParameter layer_param;
+ layer_param.set_pad(1);
+ Caffeine::set_mode(Caffeine::GPU);
+ PaddingLayer<TypeParam> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+ layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+ EXPECT_EQ(this->blob_top_->num(), 2);
+ EXPECT_EQ(this->blob_top_->channels(), 3);
+ EXPECT_EQ(this->blob_top_->height(), 6);
+ EXPECT_EQ(this->blob_top_->width(), 7);
+ for (int n = 0; n < 2; ++n) {
+ for (int c = 0; c < 3; ++c) {
+ for (int h = 0; h < 4; ++h) {
+ for (int w = 0; w < 5; ++w) {
+ EXPECT_EQ(this->blob_bottom_->data_at(n, c, h, w),
+ this->blob_top_->data_at(n, c, h + 1, w + 1));
+ }
+ }
+ }
+ }
+}
+
+TYPED_TEST(PaddingLayerTest, TestGPUGrad) {
+ LayerParameter layer_param;
+ layer_param.set_pad(1);
+ Caffeine::set_mode(Caffeine::GPU);
+ PaddingLayer<TypeParam> layer(layer_param);
+ GradientChecker<TypeParam> checker(1e-2, 1e-3);
+ checker.CheckGradient(layer, this->blob_bottom_vec_, this->blob_top_vec_);
+}
+
+}  // namespace caffeine
shared_ptr<SyncedMemory> bias_multiplier_;
};
+template <typename Dtype>
+class PaddingLayer : public Layer<Dtype> {
+ public:
+ explicit PaddingLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {}
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
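+  // Shape info cached at SetUp; the OUT dims are the IN dims plus 2 * PAD_.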
+ unsigned int PAD_;
+ int NUM_;
+ int CHANNEL_;
+ int HEIGHT_IN_;
+ int WIDTH_IN_;
+ int HEIGHT_OUT_;
+ int WIDTH_OUT_;
+};
+
} // namespace caffeine
#endif // CAFFEINE_VISION_LAYERS_HPP_