lrn forward
authorYangqing Jia <jiayq84@gmail.com>
Thu, 19 Sep 2013 20:29:32 +0000 (13:29 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Thu, 19 Sep 2013 20:29:32 +0000 (13:29 -0700)
src/caffeine/layers/lrn_layer.cu
src/caffeine/test/test_im2col_layer.cpp
src/caffeine/test/test_lrn_layer.cpp [new file with mode: 0644]
src/caffeine/util/blas.cpp
src/caffeine/util/blas.hpp
src/caffeine/vision_layers.hpp

index 53e423f..9103ce9 100644 (file)
@@ -1,6 +1,6 @@
 #include "caffeine/layer.hpp"
 #include "caffeine/vision_layers.hpp"
-
+#include "caffeine/util/blas.hpp"
 
 namespace caffeine {
 
@@ -11,20 +11,63 @@ void LRNLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
       "Local Response Normalization Layer takes a single blob as input.";
   CHECK_EQ(top->size(), 1) << 
       "Local Response Normalization Layer takes a single blob as output.";
-  (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
-      bottom[0]->height(), bottom[0]->width());
+  num_ = bottom[0]->num();
+  channels_ = bottom[0]->channels();
+  height_ = bottom[0]->height();
+  width_ = bottom[0]->width();
+  (*top)[0]->Reshape(num_, channels_, height_, width_);
+  scale_.Reshape(num_, channels_, height_, width_);
+  size_ = this->layer_param_.local_size();
+  pre_pad_ = (size_ - 1) / 2;
+  alpha_ = this->layer_param_.alpha();
+  beta_ = this->layer_param_.beta();
 };
 
 template <typename Dtype>
 void LRNLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
-  /*
-  const int size = this->layer_param_->local_size();
-  const int pre_pad = (size - 1) / 2;
-  const Dtype alpha = this->layer_param_->alpha();
-  const Dtype beta = this->layer_param_->beta();
-  */
-  NOT_IMPLEMENTED;
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = (*top)[0]->mutable_cpu_data();
+  Dtype* scale_data = scale_.mutable_cpu_data();
+  // start with the constant value
+  for (int i = 0; i < scale_.count(); ++i) {
+    scale_data[i] = 1.;
+  }
+  Blob<Dtype> padded_square(1, channels_ + size_ - 1, height_, width_);
+  Dtype* padded_square_data = padded_square.mutable_cpu_data();
+  memset(padded_square_data, 0, sizeof(Dtype) * padded_square.count());
+  Dtype alpha_over_size = alpha_ / size_;
+  // go through the images
+  for (int n = 0; n < num_; ++n) {
+    // compute the padded square
+    caffeine_sqr(channels_ * height_ * width_,
+        bottom_data + bottom[0]->offset(n),
+        padded_square_data + padded_square.offset(0, pre_pad_));
+    // Create the first channel scale
+    for (int c = 0; c < size_; ++c) {
+      caffeine_axpy<Dtype>(height_ * width_, alpha_over_size,
+          padded_square_data + padded_square.offset(0, c),
+          scale_data + scale_.offset(n, 0));
+    }
+    for (int c = 1; c < channels_; ++c) {
+      // copy previous scale
+      caffeine_copy<Dtype>(height_ * width_,
+          scale_data + scale_.offset(n, c - 1),
+          scale_data + scale_.offset(n, c));
+      // add head
+      caffeine_axpy<Dtype>(height_ * width_, alpha_over_size,
+          padded_square_data + padded_square.offset(0, c + size_ - 1),
+          scale_data + scale_.offset(n, c));
+      // subtract tail
+      caffeine_axpy<Dtype>(height_ * width_, -alpha_over_size,
+          padded_square_data + padded_square.offset(0, c - 1),
+          scale_data + scale_.offset(n, c));
+    }
+  }
+
+  // In the end, compute output
+  caffeine_powx<Dtype>(scale_.count(), scale_data, -beta_, top_data);
+  caffeine_mul<Dtype>(scale_.count(), top_data, bottom_data, top_data);
 }
 
 template <typename Dtype>
index d192d75..da5632d 100644 (file)
@@ -53,6 +53,7 @@ TYPED_TEST(Im2colLayerTest, TestCPU) {
   layer_param.set_kernelsize(3);
   layer_param.set_stride(2);
   Im2colLayer<TypeParam> layer(layer_param);
+  Caffeine::set_mode(Caffeine::CPU);
   layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
   layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
   // We are lazy and will only check the top left block
@@ -66,6 +67,7 @@ TYPED_TEST(Im2colLayerTest, TestCPUGradient) {
   LayerParameter layer_param;
   layer_param.set_kernelsize(3);
   layer_param.set_stride(2);
+  Caffeine::set_mode(Caffeine::CPU);
   Im2colLayer<TypeParam> layer(layer_param);
   GradientChecker<TypeParam> checker(1e-2, 1e-2);
   checker.CheckGradient(layer, this->blob_bottom_vec_, this->blob_top_vec_);
diff --git a/src/caffeine/test/test_lrn_layer.cpp b/src/caffeine/test/test_lrn_layer.cpp
new file mode 100644 (file)
index 0000000..af76cf3
--- /dev/null
@@ -0,0 +1,103 @@
+#include <algorithm>
+#include <cstring>
+#include <cuda_runtime.h>
+#include <iostream>
+
+#include "gtest/gtest.h"
+#include "caffeine/blob.hpp"
+#include "caffeine/common.hpp"
+#include "caffeine/filler.hpp"
+#include "caffeine/vision_layers.hpp"
+#include "caffeine/test/test_gradient_check_util.hpp"
+
+using std::min;
+using std::max;
+
+namespace caffeine {
+
+extern cudaDeviceProp CAFFEINE_TEST_CUDA_PROP;
+  
+template <typename Dtype>
+class LRNLayerTest : public ::testing::Test {
+ protected:
+  LRNLayerTest()
+      : blob_bottom_(new Blob<Dtype>(1, 10, 1, 1)),
+        blob_top_(new Blob<Dtype>()) {
+    // fill the values
+    FillerParameter filler_param;
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  };
+  virtual ~LRNLayerTest() { delete blob_bottom_; delete blob_top_; }
+  void ReferenceLRNForward(const Blob<Dtype>& blob_bottom,
+      const LayerParameter& layer_param, Blob<Dtype>* blob_top);
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+template <typename Dtype>
+void LRNLayerTest<Dtype>::ReferenceLRNForward(
+    const Blob<Dtype>& blob_bottom, const LayerParameter& layer_param, 
+    Blob<Dtype>* blob_top) {
+  blob_top->Reshape(blob_bottom.num(), blob_bottom.channels(),
+      blob_bottom.height(), blob_bottom.width());
+  const Dtype* bottom_data = blob_bottom.cpu_data();
+  Dtype* top_data = blob_top->mutable_cpu_data();
+  Dtype alpha = layer_param.alpha();
+  Dtype beta = layer_param.beta();
+  int size = layer_param.local_size();
+  for (int n = 0; n < blob_bottom.num(); ++n) {
+    for (int c = 0; c < blob_bottom.channels(); ++c) {
+      for (int h = 0; h < blob_bottom.height(); ++h) {
+        for (int w = 0; w < blob_bottom.width(); ++w) {
+          int c_start = c - (size - 1) / 2;
+          int c_end = min(c_start + size, blob_bottom.channels());
+          c_start = max(c_start, 0);
+          Dtype scale = 1.;
+          for (int i = c_start; i < c_end; ++i) {
+            Dtype value = blob_bottom.data_at(n, i, h, w);
+            scale += value * value * alpha / size;
+          }
+          *(top_data + blob_top->offset(n, c, h, w)) =
+            blob_bottom.data_at(n, c, h, w) / pow(scale, beta);
+        }
+      }
+    }
+  }
+}
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(LRNLayerTest, Dtypes);
+
+TYPED_TEST(LRNLayerTest, TestSetup) {
+  LayerParameter layer_param;
+  LRNLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  EXPECT_EQ(this->blob_top_->num(), 1);
+  EXPECT_EQ(this->blob_top_->channels(), 10);
+  EXPECT_EQ(this->blob_top_->height(), 1);
+  EXPECT_EQ(this->blob_top_->width(), 1);
+}
+
+TYPED_TEST(LRNLayerTest, TestCPU) {
+  LayerParameter layer_param;
+  LRNLayer<TypeParam> layer(layer_param);
+  Caffeine::set_mode(Caffeine::CPU);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  Blob<TypeParam> top_reference;
+  this->ReferenceLRNForward(*(this->blob_bottom_), layer_param,
+      &top_reference);
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_GE(this->blob_top_->cpu_data()[i],
+        top_reference.cpu_data()[i] - 1e-5);
+    EXPECT_LE(this->blob_top_->cpu_data()[i],
+        top_reference.cpu_data()[i] + 1e-5);
+  }
+}
+
+}
index e03bac2..c542cbc 100644 (file)
@@ -93,4 +93,48 @@ void caffeine_gpu_gemv<double>(const CBLAS_TRANSPOSE TransA, const int M,
       A, N, x, 1, &beta, y, 1));
 }
 
+template <>
+void caffeine_axpy<float>(const int N, const float alpha, const float* X,
+    float* Y) { cblas_saxpy(N, alpha, X, 1, Y, 1); }
+
+template <>
+void caffeine_axpy<double>(const int N, const double alpha, const double* X,
+    double* Y) { cblas_daxpy(N, alpha, X, 1, Y, 1); }
+
+template <>
+void caffeine_copy<float>(const int N, const float* X, float* Y) {
+  cblas_scopy(N, X, 1, Y, 1);
+}
+
+template <>
+void caffeine_copy<double>(const int N, const double* X, double* Y) {
+  cblas_dcopy(N, X, 1, Y, 1);
+}
+
+template <>
+void caffeine_sqr<float>(const int n, const float* a, float* y){
+  vsSqr(n, a, y);
+}
+
+template <>
+void caffeine_sqr<double>(const int n, const double* a, double* y) {
+  vdSqr(n, a, y);
+}
+
+template <>
+void caffeine_mul<float>(const int n, const float* a, const float* b,
+    float* y) { vsMul(n, a, b, y); }
+
+template <>
+void caffeine_mul<double>(const int n, const double* a, const double* b,
+    double* y) { vdMul(n, a, b, y); }
+
+template <>
+void caffeine_powx<float>(const int n, const float* a, const float b,
+    float* y) { vsPowx(n, a, b, y); }
+
+template <>
+void caffeine_powx<double>(const int n, const double* a, const double b,
+    double* y) { vdPowx(n, a, b, y); }
+
 }  // namespace caffeine
index 00eeea4..509cf87 100644 (file)
@@ -33,6 +33,22 @@ void caffeine_gpu_gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N,
     const Dtype alpha, const Dtype* A, const Dtype* x, const Dtype beta,
     Dtype* y);
 
+template <typename Dtype>
+void caffeine_axpy(const int N, const Dtype alpha, const Dtype* X,
+    Dtype* Y);
+
+template <typename Dtype>
+void caffeine_copy(const int N, const Dtype *X, Dtype *Y);
+
+template <typename Dtype>
+void caffeine_sqr(const int N, const Dtype* a, Dtype* y);
+
+template <typename Dtype>
+void caffeine_mul(const int N, const Dtype* a, const Dtype* b, Dtype* y);
+
+template <typename Dtype>
+void caffeine_powx(const int n, const Dtype* a, const Dtype b, Dtype* y);
+
 }  // namespace caffeine
 
 
index 2fe6a1b..d931bc2 100644 (file)
@@ -124,6 +124,16 @@ class LRNLayer : public Layer<Dtype> {
       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
   //virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
   //    const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  // scale_ stores the intermediate summing results
+  Blob<Dtype> scale_;
+  int size_;
+  int pre_pad_;
+  Dtype alpha_;
+  Dtype beta_;
+  int num_;
+  int channels_;
+  int height_;
+  int width_;
 };
 
 template <typename Dtype>