euclidean layer update
authorYangqing Jia <jiayq84@gmail.com>
Mon, 30 Sep 2013 18:04:40 +0000 (11:04 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Mon, 30 Sep 2013 18:04:40 +0000 (11:04 -0700)
src/caffe/layer_factory.hpp
src/caffe/layers/loss_layer.cu [moved from src/caffe/layers/multinomial_logistic_loss_layer.cu with 54% similarity]
src/caffe/test/test_euclidean_loss_layer.cpp [new file with mode: 0644]
src/caffe/test/test_multinomial_logistic_loss_layer.cpp
src/caffe/util/math_functions.cpp
src/caffe/util/math_functions.hpp
src/caffe/vision_layers.hpp

index 8453fd5..e9f8dbb 100644 (file)
@@ -25,6 +25,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
     return new DataLayer<Dtype>(param);
   } else if (type == "dropout") {
     return new DropoutLayer<Dtype>(param);
+  } else if (type == "euclidean_loss") {
+    return new EuclideanLossLayer<Dtype>(param);
   } else if (type == "im2col") {
     return new Im2colLayer<Dtype>(param);
   } else if (type == "innerproduct") {
similarity index 54%
rename from src/caffe/layers/multinomial_logistic_loss_layer.cu
rename to src/caffe/layers/loss_layer.cu
index 5ffa4ac..0c09a5d 100644 (file)
@@ -44,9 +44,39 @@ Dtype MultinomialLogisticLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>
   return loss / num;
 }
 
-// TODO: implement the GPU version
+// TODO: implement the GPU version for multinomial loss
+
+template <typename Dtype>
+void EuclideanLossLayer<Dtype>::SetUp(
+  const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom.size(), 2) << "Loss Layer takes two blobs as input.";
+  CHECK_EQ(top->size(), 0) << "Loss Layer takes no as output.";
+  CHECK_EQ(bottom[0]->num(), bottom[1]->num())
+      << "The data and label should have the same number.";
+  CHECK_EQ(bottom[0]->channels(), bottom[1]->channels());
+  CHECK_EQ(bottom[0]->height(), bottom[1]->height());
+  CHECK_EQ(bottom[0]->width(), bottom[1]->width());
+  difference_.Reshape(bottom[0]->num(), bottom[0]->channels(),
+      bottom[0]->height(), bottom[0]->width());
+}
+
+template <typename Dtype>
+Dtype EuclideanLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+  int count = (*bottom)[0]->count();
+  int num = (*bottom)[0]->num();
+  caffe_sub(count, (*bottom)[0]->cpu_data(), (*bottom)[1]->cpu_data(),
+      difference_.mutable_cpu_data());
+  Dtype loss = caffe_cpu_dot(
+      count, difference_.cpu_data(), difference_.cpu_data()) / num / Dtype(2);
+  // Compute the gradient
+  caffe_axpby(count, Dtype(1) / num, difference_.cpu_data(), Dtype(0),
+      (*bottom)[0]->mutable_cpu_diff());
+  return loss;
+}
 
 INSTANTIATE_CLASS(MultinomialLogisticLossLayer);
+INSTANTIATE_CLASS(EuclideanLossLayer);
 
 
 }  // namespace caffe
diff --git a/src/caffe/test/test_euclidean_loss_layer.cpp b/src/caffe/test/test_euclidean_loss_layer.cpp
new file mode 100644 (file)
index 0000000..82ea682
--- /dev/null
@@ -0,0 +1,58 @@
+// Copyright 2013 Yangqing Jia
+
+#include <cmath>
+#include <cstdlib>
+#include <cstring>
+#include <cuda_runtime.h>
+
+#include "gtest/gtest.h"
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+
+extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
+
+template <typename Dtype>
+class EuclideanLossLayerTest : public ::testing::Test {
+ protected:
+  EuclideanLossLayerTest()
+      : blob_bottom_data_(new Blob<Dtype>(10, 5, 1, 1)),
+        blob_bottom_label_(new Blob<Dtype>(10, 5, 1, 1)) {
+    // fill the values
+    FillerParameter filler_param;
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_data_);
+    blob_bottom_vec_.push_back(blob_bottom_data_);
+    filler.Fill(this->blob_bottom_label_);
+    blob_bottom_vec_.push_back(blob_bottom_label_);
+  }
+  virtual ~EuclideanLossLayerTest() {
+    delete blob_bottom_data_;
+    delete blob_bottom_label_;
+  }
+  Blob<Dtype>* const blob_bottom_data_;
+  Blob<Dtype>* const blob_bottom_label_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(EuclideanLossLayerTest, Dtypes);
+
+TYPED_TEST(EuclideanLossLayerTest, TestGradientCPU) {
+  LayerParameter layer_param;
+  Caffe::set_mode(Caffe::CPU);
+  EuclideanLossLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
+  GradientChecker<TypeParam> checker(1e-2, 1e-2, 1701);
+  checker.CheckGradientSingle(layer, this->blob_bottom_vec_,
+      this->blob_top_vec_, 0, -1, -1);
+}
+
+}
index de50245..5595c84 100644 (file)
@@ -52,6 +52,7 @@ TYPED_TEST(MultinomialLogisticLossLayerTest, TestGradientCPU) {
   LayerParameter layer_param;
   Caffe::set_mode(Caffe::CPU);
   MultinomialLogisticLossLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
   GradientChecker<TypeParam> checker(1e-2, 1e-2, 1701, 0, 0.05);
   checker.CheckGradientSingle(layer, this->blob_bottom_vec_,
       this->blob_top_vec_, 0, -1, -1);
index 7cd3b26..a3a94cd 100644 (file)
@@ -104,6 +104,18 @@ void caffe_axpy<double>(const int N, const double alpha, const double* X,
     double* Y) { cblas_daxpy(N, alpha, X, 1, Y, 1); }
 
 template <>
+void caffe_axpby<float>(const int N, const float alpha, const float* X,
+    const float beta, float* Y) {
+  cblas_saxpby(N, alpha, X, 1, beta, Y, 1);
+}
+
+template <>
+void caffe_axpby<double>(const int N, const double alpha, const double* X,
+    const double beta, double* Y) {
+  cblas_daxpby(N, alpha, X, 1, beta, Y, 1);
+}
+
+template <>
 void caffe_copy<float>(const int N, const float* X, float* Y) {
   cblas_scopy(N, X, 1, Y, 1);
 }
@@ -144,6 +156,22 @@ void caffe_sqr<double>(const int n, const double* a, double* y) {
 }
 
 template <>
+void caffe_add<float>(const int n, const float* a, const float* b,
+    float* y) { vsAdd(n, a, b, y); }
+
+template <>
+void caffe_add<double>(const int n, const double* a, const double* b,
+    double* y) { vdAdd(n, a, b, y); }
+
+template <>
+void caffe_sub<float>(const int n, const float* a, const float* b,
+    float* y) { vsSub(n, a, b, y); }
+
+template <>
+void caffe_sub<double>(const int n, const double* a, const double* b,
+    double* y) { vdSub(n, a, b, y); }
+
+template <>
 void caffe_mul<float>(const int n, const float* a, const float* b,
     float* y) { vsMul(n, a, b, y); }
 
index f09afe3..e3ace98 100644 (file)
@@ -40,6 +40,10 @@ void caffe_axpy(const int N, const Dtype alpha, const Dtype* X,
     Dtype* Y);
 
 template <typename Dtype>
+void caffe_axpby(const int N, const Dtype alpha, const Dtype* X,
+    const Dtype beta, Dtype* Y);
+
+template <typename Dtype>
 void caffe_copy(const int N, const Dtype *X, Dtype *Y);
 
 template <typename Dtype>
@@ -52,6 +56,12 @@ template <typename Dtype>
 void caffe_sqr(const int N, const Dtype* a, Dtype* y);
 
 template <typename Dtype>
+void caffe_add(const int N, const Dtype* a, const Dtype* b, Dtype* y);
+
+template <typename Dtype>
+void caffe_sub(const int N, const Dtype* a, const Dtype* b, Dtype* y);
+
+template <typename Dtype>
 void caffe_mul(const int N, const Dtype* a, const Dtype* b, Dtype* y);
 
 template <typename Dtype>
index 66eeda0..29d2eb3 100644 (file)
@@ -301,6 +301,29 @@ class MultinomialLogisticLossLayer : public Layer<Dtype> {
   //     const bool propagate_down, vector<Blob<Dtype>*>* bottom);
 };
 
+template <typename Dtype>
+class EuclideanLossLayer : public Layer<Dtype> {
+ public:
+  explicit EuclideanLossLayer(const LayerParameter& param)
+      : Layer<Dtype>(param), difference_() {}
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+ protected:
+  // The loss layer will do nothing during forward - all computation are
+  // carried out in the backward pass.
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) { return; }
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) { return; }
+  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  // virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+  //     const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  Blob<Dtype> difference_;
+};
+
+
 }  // namespace caffe
 
 #endif  // CAFFE_VISION_LAYERS_HPP_