misc update
authorYangqing Jia <jiayq84@gmail.com>
Thu, 19 Sep 2013 06:32:04 +0000 (23:32 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Thu, 19 Sep 2013 06:32:04 +0000 (23:32 -0700)
src/caffeine/blob.hpp
src/caffeine/common.hpp
src/caffeine/layer.cpp
src/caffeine/layers/im2col_layer.cpp [new file with mode: 0644]
src/caffeine/layers/lrn_layer.cu [new file with mode: 0644]
src/caffeine/proto/layer_param.proto
src/caffeine/test/test_im2col_layer.cpp [new file with mode: 0644]
src/caffeine/util/im2col.cpp [new file with mode: 0644]
src/caffeine/util/im2col.hpp [new file with mode: 0644]
src/caffeine/vision_layers.hpp

index 8a729a5..acef484 100644 (file)
@@ -24,15 +24,19 @@ class Blob {
   inline int height() const { return height_; }
   inline int width() const { return width_; }
   inline int count() const {return count_; }
+  inline int offset(const int n, const int c = 0, const int h = 0,
+      const int w = 0) const {
+    return ((n * channels_ + c) * height_ + h) * width_ + w;
+  }
 
   inline Dtype data_at(const int n, const int c, const int h,
       const int w) const {
-    return cpu_data()[((n * channels_ + c) * height_ + h) * width_ + w];
+    return *(cpu_data() + offset(n, c, h, w));
   }
 
   inline Dtype diff_at(const int n, const int c, const int h,
       const int w) const {
-    return cpu_diff()[((n * channels_ + c) * height_ + h) * width_ + w];
+    return *(cpu_diff() + offset(n, c, h, w));
   }
   
   const Dtype* cpu_data() const;
index 2da0df1..9fd5084 100644 (file)
@@ -24,7 +24,7 @@
   template class classname<float>; \
   template class classname<double>
 
-#define NOT_IMPLEMENTED CHECK(false) << "Not Implemented"
+#define NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented Yet"
 
 namespace caffeine {
 
index c33746c..9d61fe9 100644 (file)
@@ -36,7 +36,6 @@ inline Dtype Layer<Dtype>::Backward(const vector<Blob<Dtype>*>& top,
   }
 };
 
-template class Layer<float>;
-template class Layer<double>;
+INSTANTIATE_CLASS(Layer);
 
 }  // namespace caffeine
diff --git a/src/caffeine/layers/im2col_layer.cpp b/src/caffeine/layers/im2col_layer.cpp
new file mode 100644 (file)
index 0000000..c6020e3
--- /dev/null
@@ -0,0 +1,46 @@
+#include "caffeine/layer.hpp"
+#include "caffeine/util/im2col.hpp"
+#include "caffeine/vision_layers.hpp"
+
+namespace caffeine {
+
+template <typename Dtype>
+void Im2colLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom.size(), 1) << "Im2col Layer takes a single blob as input.";
+  CHECK_EQ(top->size(), 1) << "Im2col Layer takes a single blob as output.";
+  KSIZE_ = this->layer_param_.kernelsize();
+  STRIDE_ = this->layer_param_.stride();
+  CHANNELS_ = bottom[0]->channels();
+  HEIGHT_ = bottom[0]->height();
+  WIDTH_ = bottom[0]->width();
+  (*top)[0]->Reshape(bottom[0]->num(), CHANNELS_ * KSIZE_ * KSIZE_,
+      (HEIGHT_ - KSIZE_) / STRIDE_ + 1, (WIDTH_ - KSIZE_) / STRIDE_ + 1);
+};
+
+template <typename Dtype>
+void Im2colLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = (*top)[0]->mutable_cpu_data();
+  for (int n = 0; n < bottom[0]->num(); ++n) {
+    im2col_cpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
+        WIDTH_, KSIZE_, STRIDE_, top_data + (*top)[0]->offset(n));
+  }
+}
+
+template <typename Dtype>
+Dtype Im2colLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+  const Dtype* top_diff = top[0]->cpu_diff();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+  for (int n = 0; n < top[0]->num(); ++n) {
+    col2im_cpu(top_diff + top[0]->offset(n), CHANNELS_, HEIGHT_,
+        WIDTH_, KSIZE_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n));
+  }
+  return Dtype(0.);
+}
+
+INSTANTIATE_CLASS(Im2colLayer);
+
+}  // namespace caffeine
diff --git a/src/caffeine/layers/lrn_layer.cu b/src/caffeine/layers/lrn_layer.cu
new file mode 100644 (file)
index 0000000..53e423f
--- /dev/null
@@ -0,0 +1,40 @@
+#include "caffeine/layer.hpp"
+#include "caffeine/vision_layers.hpp"
+
+
+namespace caffeine {
+
+template <typename Dtype>
+void LRNLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom.size(), 1) <<
+      "Local Response Normalization Layer takes a single blob as input.";
+  CHECK_EQ(top->size(), 1) << 
+      "Local Response Normalization Layer takes a single blob as output.";
+  (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
+      bottom[0]->height(), bottom[0]->width());
+};
+
+template <typename Dtype>
+void LRNLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  /*
+  const int size = this->layer_param_->local_size();
+  const int pre_pad = (size - 1) / 2;
+  const Dtype alpha = this->layer_param_->alpha();
+  const Dtype beta = this->layer_param_->beta();
+  */
+  NOT_IMPLEMENTED;
+}
+
+template <typename Dtype>
+Dtype LRNLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+  NOT_IMPLEMENTED;
+  return Dtype(0.);
+}
+
+INSTANTIATE_CLASS(LRNLayer);
+
+
+}  // namespace caffeine
index a20342f..4c87680 100644 (file)
@@ -25,7 +25,7 @@ message LayerParameter {
   required string type = 2; // the string to specify the layer type
 
   // Parameters to specify layers with inner products.
-  optional int32 num_output = 3; // The number of outputs for the layer
+  optional uint32 num_output = 3; // The number of outputs for the layer
   optional bool biasterm = 4 [default = true]; // whether to have bias terms
   optional FillerParameter weight_filler = 5; // The filler for the weight
   optional FillerParameter bias_filler = 6; // The filler for the bias
@@ -37,6 +37,7 @@ message LayerParameter {
   optional string pool = 11 [default = 'max']; // The pooling method
   optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio
 
-  optional float alpha = 13 [default = 1.]; // for local response norm
-  optional float beta = 14 [default = 0.75]; // for local response norm
+  optional uint32 local_size = 13 [default = 5]; // for local response norm
+  optional float alpha = 14 [default = 1.]; // for local response norm
+  optional float beta = 15 [default = 0.75]; // for local response norm
 }
diff --git a/src/caffeine/test/test_im2col_layer.cpp b/src/caffeine/test/test_im2col_layer.cpp
new file mode 100644 (file)
index 0000000..d192d75
--- /dev/null
@@ -0,0 +1,74 @@
+#include <cstring>
+#include <cuda_runtime.h>
+
+#include "gtest/gtest.h"
+#include "caffeine/blob.hpp"
+#include "caffeine/common.hpp"
+#include "caffeine/filler.hpp"
+#include "caffeine/vision_layers.hpp"
+#include "caffeine/test/test_gradient_check_util.hpp"
+
+
+namespace caffeine {
+
+extern cudaDeviceProp CAFFEINE_TEST_CUDA_PROP;
+  
+template <typename Dtype>
+class Im2colLayerTest : public ::testing::Test {
+ protected:
+  Im2colLayerTest()
+      : blob_bottom_(new Blob<Dtype>(2, 3, 6, 5)),
+        blob_top_(new Blob<Dtype>()) {
+    // fill the values
+    FillerParameter filler_param;
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  };
+  virtual ~Im2colLayerTest() { delete blob_bottom_; delete blob_top_; }
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(Im2colLayerTest, Dtypes);
+
+TYPED_TEST(Im2colLayerTest, TestSetup) {
+  LayerParameter layer_param;
+  layer_param.set_kernelsize(3);
+  layer_param.set_stride(2);
+  Im2colLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  EXPECT_EQ(this->blob_top_->num(), 2);
+  EXPECT_EQ(this->blob_top_->channels(), 27);
+  EXPECT_EQ(this->blob_top_->height(), 2);
+  EXPECT_EQ(this->blob_top_->width(), 2);
+}
+
+TYPED_TEST(Im2colLayerTest, TestCPU) {
+  LayerParameter layer_param;
+  layer_param.set_kernelsize(3);
+  layer_param.set_stride(2);
+  Im2colLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  // We are lazy and will only check the top left block
+  for (int c = 0; c < 27; ++c) {
+    EXPECT_EQ(this->blob_top_->data_at(0, c, 0, 0),
+        this->blob_bottom_->data_at(0, (c / 9), (c / 3) % 3, c % 3));
+  }
+}
+
+TYPED_TEST(Im2colLayerTest, TestCPUGradient) {
+  LayerParameter layer_param;
+  layer_param.set_kernelsize(3);
+  layer_param.set_stride(2);
+  Im2colLayer<TypeParam> layer(layer_param);
+  GradientChecker<TypeParam> checker(1e-2, 1e-2);
+  checker.CheckGradient(layer, this->blob_bottom_vec_, this->blob_top_vec_);
+}
+
+}
diff --git a/src/caffeine/util/im2col.cpp b/src/caffeine/util/im2col.cpp
new file mode 100644 (file)
index 0000000..11c3bf6
--- /dev/null
@@ -0,0 +1,67 @@
+#include <cmath>
+#include <cstdlib>
+#include <cstring>
+
+#include "caffeine/util/im2col.hpp"
+
+namespace caffeine {
+
+template <typename Dtype>
+void im2col_cpu(const Dtype* data_im, const int channels,
+    const int height, const int width, const int ksize, const int stride,
+    Dtype* data_col) {
+  int height_col = (height - ksize) / stride + 1;
+  int width_col = (width - ksize) / stride + 1;
+  int channels_col = channels * ksize * ksize;
+  for (int c = 0; c < channels_col; ++c) {
+    int w_offset = c % ksize;
+    int h_offset = (c / ksize) % ksize;
+    int c_im = c / ksize / ksize;
+    for (int h = 0; h < height_col; ++h) {
+      for (int w = 0; w < width_col; ++w) {
+        data_col[(c * height_col + h) * width_col + w] =
+            data_im[(c_im * height + h * stride + h_offset) * width
+                + w * stride + w_offset];
+      }
+    }
+  }
+}
+
+// Explicit instantiation
+template void im2col_cpu<float>(const float* data_im, const int channels,
+    const int height, const int width, const int ksize, const int stride,
+    float* data_col);
+template void im2col_cpu<double>(const double* data_im, const int channels,
+    const int height, const int width, const int ksize, const int stride,
+    double* data_col);
+
+template <typename Dtype>
+void col2im_cpu(const Dtype* data_col, const int channels,
+    const int height, const int width, const int ksize, const int stride,
+    Dtype* data_im) {
+  memset(data_im, 0, sizeof(Dtype) * height * width * channels);
+  int height_col = (height - ksize) / stride + 1;
+  int width_col = (width - ksize) / stride + 1;
+  int channels_col = channels * ksize * ksize;
+  for (int c = 0; c < channels_col; ++c) {
+    int w_offset = c % ksize;
+    int h_offset = (c / ksize) % ksize;
+    int c_im = c / ksize / ksize;
+    for (int h = 0; h < height_col; ++h) {
+      for (int w = 0; w < width_col; ++w) {
+        data_im[(c_im * height + h * stride + h_offset) * width + w * stride 
+            + w_offset] += data_col[(c * height_col + h) * width_col + w];
+      }
+    }
+  }
+}
+
+// Explicit instantiation
+template void col2im_cpu<float>(const float* data_col, const int channels,
+    const int height, const int width, const int psize, const int stride,
+    float* data_im);
+template void col2im_cpu<double>(const double* data_col, const int channels,
+    const int height, const int width, const int psize, const int stride,
+    double* data_im);
+
+}  // namespace caffeine
diff --git a/src/caffeine/util/im2col.hpp b/src/caffeine/util/im2col.hpp
new file mode 100644 (file)
index 0000000..76f401d
--- /dev/null
@@ -0,0 +1,20 @@
+#ifndef _CAFFEINE_UTIL__IM2COL_HPP_
+#define _CAFFEINE_UTIL_IM2COL_HPP_
+
+namespace caffeine {
+
+template <typename Dtype>
+void im2col_cpu(const Dtype* data_im, const int channels,
+    const int height, const int width, const int ksize, const int stride,
+    Dtype* data_col);
+
+template <typename Dtype>
+void col2im_cpu(const Dtype* data_col, const int channels,
+    const int height, const int width, const int psize, const int stride,
+    Dtype* data_im);
+
+
+
+}  // namespace caffeine
+
+#endif  // CAFFEINE_UTIL_IM2COL_HPP_
index 19fd00e..2fe6a1b 100644 (file)
@@ -108,6 +108,66 @@ class PaddingLayer : public Layer<Dtype> {
   int WIDTH_OUT_;
 };
 
+template <typename Dtype>
+class LRNLayer : public Layer<Dtype> {
+ public:
+  explicit LRNLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {};
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  //virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+  //    vector<Blob<Dtype>*>* top);
+  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  //virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+  //    const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+};
+
+template <typename Dtype>
+class Im2colLayer : public Layer<Dtype> {
+ public:
+  explicit Im2colLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {};
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  //virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+  //    vector<Blob<Dtype>*>* top);
+  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  //virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+  //    const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  int KSIZE_;
+  int STRIDE_;
+  int CHANNELS_;
+  int HEIGHT_;
+  int WIDTH_;
+};
+
+template <typename Dtype>
+class ConvolutionLayer : public Layer<Dtype> {
+ public:
+  explicit ConvolutionLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {};
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  //virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+  //    vector<Blob<Dtype>*>* top);
+  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  //virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+  //    const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  Blob<Dtype> col_bob_;
+};
+
 }  // namespace caffeine
 
 #endif  // CAFFEINE_VISION_LAYERS_HPP_