im2col
authorYangqing Jia <jiayq84@gmail.com>
Thu, 19 Sep 2013 23:26:01 +0000 (16:26 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Thu, 19 Sep 2013 23:26:01 +0000 (16:26 -0700)
.gitignore
src/Makefile
src/caffeine/layers/dropout_layer.cu
src/caffeine/layers/im2col_layer.cpp
src/caffeine/layers/padding_layer.cu
src/caffeine/test/test_im2col_layer.cpp
src/caffeine/util/im2col.cu [new file with mode: 0644]
src/caffeine/util/im2col.hpp
src/caffeine/vision_layers.hpp

index f4489a6..ca27edc 100644 (file)
@@ -2,6 +2,7 @@
 *.slo
 *.lo
 *.o
+*.cuo
 
 # Compiled Dynamic libraries
 *.so
index cd89c7d..2b53501 100644 (file)
@@ -17,7 +17,7 @@ PROTO_GEN_HEADER := ${PROTO_SRCS:.proto=.pb.h}
 PROTO_GEN_CC := ${PROTO_SRCS:.proto=.pb.cc}
 PROTO_GEN_PY := ${PROTO_SRCS:.proto=_pb2.py}
 CXX_OBJS := ${CXX_SRCS:.cpp=.o}
-CU_OBJS := ${CU_SRCS:.cu=.o}
+CU_OBJS := ${CU_SRCS:.cu=.cuo}
 PROTO_OBJS := ${PROTO_SRCS:.proto=.pb.o}
 OBJS := $(PROTO_OBJS) $(CXX_OBJS) $(CU_OBJS)
 TEST_OBJS := ${TEST_SRCS:.cpp=.o}
@@ -63,7 +63,7 @@ $(TEST_BINS): %.testbin : %.o
 $(NAME): $(PROTO_GEN_CC) $(OBJS)
        $(LINK) -shared $(OBJS) -o $(NAME)
 
-$(CU_OBJS): %.o: %.cu
+$(CU_OBJS): %.cuo: %.cu
        $(NVCC) -c $< -o $@
 
 $(PROTO_GEN_CC): $(PROTO_SRCS)
index 9818907..2d9e260 100644 (file)
@@ -82,6 +82,7 @@ void DropoutLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     DropoutForward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
         count, bottom_data, (unsigned int*)rand_vec_->gpu_data(), uint_thres_, scale_,
         top_data);
+    CUDA_POST_KERNEL_CHECK;
   } else {
     CUDA_CHECK(cudaMemcpy(top_data, bottom_data,
         count * sizeof(Dtype), cudaMemcpyDeviceToDevice));
@@ -112,6 +113,7 @@ Dtype DropoutLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     const int count = (*bottom)[0]->count();
     DropoutBackward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
         count, top_diff, mask, uint_thres_, scale_, bottom_diff);
+    CUDA_POST_KERNEL_CHECK;
   }
   return Dtype(0);
 }
index c6020e3..ee87493 100644 (file)
@@ -1,6 +1,7 @@
 #include "caffeine/layer.hpp"
 #include "caffeine/util/im2col.hpp"
 #include "caffeine/vision_layers.hpp"
+#include "caffeine/common.hpp"
 
 namespace caffeine {
 
@@ -30,6 +31,17 @@ void Im2colLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 }
 
 template <typename Dtype>
+void Im2colLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = (*top)[0]->mutable_gpu_data();
+  for (int n = 0; n < bottom[0]->num(); ++n) {
+    im2col_gpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
+        WIDTH_, KSIZE_, STRIDE_, top_data + (*top)[0]->offset(n));
+  }
+}
+
+template <typename Dtype>
 Dtype Im2colLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
       const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
   const Dtype* top_diff = top[0]->cpu_diff();
index 2171467..53bd1f5 100644 (file)
@@ -126,6 +126,7 @@ Dtype PaddingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     PaddingBackward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
         count, top_diff, bottom_diff, NUM_, CHANNEL_, HEIGHT_IN_, WIDTH_IN_,
         PAD_);
+    CUDA_POST_KERNEL_CHECK;
   }
   return Dtype(0);
 }
index 7456dfb..e2b28e6 100644 (file)
@@ -64,6 +64,21 @@ TYPED_TEST(Im2colLayerTest, TestCPU) {
   }
 }
 
+TYPED_TEST(Im2colLayerTest, TestGPU) {
+  LayerParameter layer_param;
+  layer_param.set_kernelsize(3);
+  layer_param.set_stride(2);
+  Im2colLayer<TypeParam> layer(layer_param);
+  Caffeine::set_mode(Caffeine::GPU);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  // We are lazy and will only check the top left block
+  for (int c = 0; c < 27; ++c) {
+    EXPECT_EQ(this->blob_bottom_->data_at(0, (c / 9), (c / 3) % 3, c % 3),
+        this->blob_top_->data_at(0, c, 0, 0));
+  }
+}
+
 TYPED_TEST(Im2colLayerTest, TestCPUGradient) {
   LayerParameter layer_param;
   layer_param.set_kernelsize(3);
diff --git a/src/caffeine/util/im2col.cu b/src/caffeine/util/im2col.cu
new file mode 100644 (file)
index 0000000..ac9e845
--- /dev/null
@@ -0,0 +1,87 @@
+#include <cmath>
+#include <cstdlib>
+#include <cstring>
+
+#include "caffeine/common.hpp"
+#include "caffeine/util/im2col.hpp"
+
+namespace caffeine {
+
+template <typename Dtype>
+__global__ void im2col_gpu_kernel(const int n, const Dtype* data_im,
+  const int height, const int width, const int ksize,
+  const int stride, const int height_col, const int width_col, Dtype* data_col) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < n) {
+    int w_out = index % width_col;
+    index /= width_col;
+    int h_out = index % height_col;
+    int channel_in = index / height_col;
+    int channel_out = channel_in * ksize * ksize;
+    int h_in = h_out * stride;
+    int w_in = w_out * stride;
+    data_col += (channel_out * height_col + h_out) * width_col + w_out;
+    data_im += (channel_in * height + h_in) * width + w_in;
+    for (int i = 0; i < ksize; ++i) {
+      for (int j = 0; j < ksize; ++j) {
+        *data_col = data_im[i * width + j];
+        data_col += height_col * width_col;
+      }
+    }
+  }
+}
+
+template <typename Dtype>
+void im2col_gpu(const Dtype* data_im, const int channels,
+    const int height, const int width, const int ksize, const int stride,
+    Dtype* data_col) {
+  // We are going to launch channels * height_col * width_col kernels, each
+  // kernel responsible for copying a single-channel grid.
+  int height_col = (height - ksize) / stride + 1;
+  int width_col = (width - ksize) / stride + 1;
+  int num_kernels = channels * height_col * width_col;
+  im2col_gpu_kernel<<<CAFFEINE_GET_BLOCKS(num_kernels), CAFFEINE_CUDA_NUM_THREADS>>>(
+      num_kernels, data_im, height, width, ksize, stride, height_col, width_col,
+      data_col);
+  CUDA_POST_KERNEL_CHECK;
+}
+
+// Explicit instantiation
+template void im2col_gpu<float>(const float* data_im, const int channels,
+    const int height, const int width, const int ksize, const int stride,
+    float* data_col);
+template void im2col_gpu<double>(const double* data_im, const int channels,
+    const int height, const int width, const int ksize, const int stride,
+    double* data_col);
+
+/*
+template <typename Dtype>
+void col2im_gpu(const Dtype* data_col, const int channels,
+    const int height, const int width, const int ksize, const int stride,
+    Dtype* data_im) {
+  memset(data_im, 0, sizeof(Dtype) * height * width * channels);
+  int height_col = (height - ksize) / stride + 1;
+  int width_col = (width - ksize) / stride + 1;
+  int channels_col = channels * ksize * ksize;
+  for (int c = 0; c < channels_col; ++c) {
+    int w_offset = c % ksize;
+    int h_offset = (c / ksize) % ksize;
+    int c_im = c / ksize / ksize;
+    for (int h = 0; h < height_col; ++h) {
+      for (int w = 0; w < width_col; ++w) {
+        data_im[(c_im * height + h * stride + h_offset) * width + w * stride 
+            + w_offset] += data_col[(c * height_col + h) * width_col + w];
+      }
+    }
+  }
+}
+
+// Explicit instantiation
+template void col2im_gpu<float>(const float* data_col, const int channels,
+    const int height, const int width, const int psize, const int stride,
+    float* data_im);
+template void col2im_gpu<double>(const double* data_col, const int channels,
+    const int height, const int width, const int psize, const int stride,
+    double* data_im);
+*/
+}  // namespace caffeine
index 76f401d..f634990 100644 (file)
@@ -13,7 +13,15 @@ void col2im_cpu(const Dtype* data_col, const int channels,
     const int height, const int width, const int psize, const int stride,
     Dtype* data_im);
 
+template <typename Dtype>
+void im2col_gpu(const Dtype* data_im, const int channels,
+    const int height, const int width, const int ksize, const int stride,
+    Dtype* data_col);
 
+template <typename Dtype>
+void col2im_gpu(const Dtype* data_col, const int channels,
+    const int height, const int width, const int psize, const int stride,
+    Dtype* data_im);
 
 }  // namespace caffeine
 
index d931bc2..2d24bf8 100644 (file)
@@ -146,8 +146,8 @@ class Im2colLayer : public Layer<Dtype> {
  protected:
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
-  //virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-  //    vector<Blob<Dtype>*>* top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
   virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
   //virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,