return new DataLayer<Dtype>(param);
} else if (type == "dropout") {
return new DropoutLayer<Dtype>(param);
+ } else if (type == "euclidean_loss") {
+ return new EuclideanLossLayer<Dtype>(param);
} else if (type == "im2col") {
return new Im2colLayer<Dtype>(param);
} else if (type == "innerproduct") {
return loss / num;
}
-// TODO: implement the GPU version
+// TODO: implement the GPU version for multinomial loss
+
+template <typename Dtype>
+void EuclideanLossLayer<Dtype>::SetUp(
+ const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+ CHECK_EQ(bottom.size(), 2) << "Loss Layer takes two blobs as input.";
+ CHECK_EQ(top->size(), 0) << "Loss Layer takes no blobs as output.";
+ CHECK_EQ(bottom[0]->num(), bottom[1]->num())
+ << "The data and label should have the same number.";
+ CHECK_EQ(bottom[0]->channels(), bottom[1]->channels());
+ CHECK_EQ(bottom[0]->height(), bottom[1]->height());
+ CHECK_EQ(bottom[0]->width(), bottom[1]->width());
+ difference_.Reshape(bottom[0]->num(), bottom[0]->channels(),
+ bottom[0]->height(), bottom[0]->width());
+}
+
+template <typename Dtype>
+Dtype EuclideanLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+ int count = (*bottom)[0]->count();
+ int num = (*bottom)[0]->num();
+ caffe_sub(count, (*bottom)[0]->cpu_data(), (*bottom)[1]->cpu_data(),
+ difference_.mutable_cpu_data());
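+ // Loss: squared difference summed over all elements, halved and averaged
+ // over the num examples: loss = 1 / (2 * num) * sum_i (data_i - label_i)^2.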
+ Dtype loss = caffe_cpu_dot(
+ count, difference_.cpu_data(), difference_.cpu_data()) / num / Dtype(2);
+ // Compute the gradient w.r.t. the first bottom blob: (data - label) / num.
+ caffe_axpby(count, Dtype(1) / num, difference_.cpu_data(), Dtype(0),
+ (*bottom)[0]->mutable_cpu_diff());
+ return loss;
+}
INSTANTIATE_CLASS(MultinomialLogisticLossLayer);
+INSTANTIATE_CLASS(EuclideanLossLayer);
} // namespace caffe
--- /dev/null
+// Copyright 2013 Yangqing Jia
+
+#include <cmath>
+#include <cstdlib>
+#include <cstring>
+#include <cuda_runtime.h>
+
+#include "gtest/gtest.h"
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+
+extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
+
+template <typename Dtype>
+class EuclideanLossLayerTest : public ::testing::Test {
+ protected:
+ EuclideanLossLayerTest()
+ : blob_bottom_data_(new Blob<Dtype>(10, 5, 1, 1)),
+ blob_bottom_label_(new Blob<Dtype>(10, 5, 1, 1)) {
+ // fill the values
+ FillerParameter filler_param;
+ GaussianFiller<Dtype> filler(filler_param);
+ filler.Fill(this->blob_bottom_data_);
+ blob_bottom_vec_.push_back(blob_bottom_data_);
+ filler.Fill(this->blob_bottom_label_);
+ blob_bottom_vec_.push_back(blob_bottom_label_);
+ }
+ virtual ~EuclideanLossLayerTest() {
+ delete blob_bottom_data_;
+ delete blob_bottom_label_;
+ }
+ Blob<Dtype>* const blob_bottom_data_;
+ Blob<Dtype>* const blob_bottom_label_;
+ vector<Blob<Dtype>*> blob_bottom_vec_;
+ vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(EuclideanLossLayerTest, Dtypes);
+
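+// Check the analytic gradient from Backward_cpu against a numerical estimate.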
+TYPED_TEST(EuclideanLossLayerTest, TestGradientCPU) {
+ LayerParameter layer_param;
+ Caffe::set_mode(Caffe::CPU);
+ EuclideanLossLayer<TypeParam> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
+ GradientChecker<TypeParam> checker(1e-2, 1e-2, 1701);
+ checker.CheckGradientSingle(layer, this->blob_bottom_vec_,
+ this->blob_top_vec_, 0, -1, -1);
+}
+
+} // namespace caffe
LayerParameter layer_param;
Caffe::set_mode(Caffe::CPU);
MultinomialLogisticLossLayer<TypeParam> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
GradientChecker<TypeParam> checker(1e-2, 1e-2, 1701, 0, 0.05);
checker.CheckGradientSingle(layer, this->blob_bottom_vec_,
this->blob_top_vec_, 0, -1, -1);
double* Y) { cblas_daxpy(N, alpha, X, 1, Y, 1); }
template <>
+void caffe_axpby<float>(const int N, const float alpha, const float* X,
+ const float beta, float* Y) {
+ cblas_saxpby(N, alpha, X, 1, beta, Y, 1);
+}
+
+template <>
+void caffe_axpby<double>(const int N, const double alpha, const double* X,
+ const double beta, double* Y) {
+ cblas_daxpby(N, alpha, X, 1, beta, Y, 1);
+}
+
+template <>
void caffe_copy<float>(const int N, const float* X, float* Y) {
cblas_scopy(N, X, 1, Y, 1);
}
}
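+// Elementwise add / subtract specializations backed by the v?Add / v?Sub
+// vector routines.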
template <>
+void caffe_add<float>(const int n, const float* a, const float* b,
+ float* y) { vsAdd(n, a, b, y); }
+
+template <>
+void caffe_add<double>(const int n, const double* a, const double* b,
+ double* y) { vdAdd(n, a, b, y); }
+
+template <>
+void caffe_sub<float>(const int n, const float* a, const float* b,
+ float* y) { vsSub(n, a, b, y); }
+
+template <>
+void caffe_sub<double>(const int n, const double* a, const double* b,
+ double* y) { vdSub(n, a, b, y); }
+
+template <>
void caffe_mul<float>(const int n, const float* a, const float* b,
float* y) { vsMul(n, a, b, y); }
Dtype* Y);
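+// Y = alpha * X + beta * Y.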
template <typename Dtype>
+void caffe_axpby(const int N, const Dtype alpha, const Dtype* X,
+ const Dtype beta, Dtype* Y);
+
+template <typename Dtype>
void caffe_copy(const int N, const Dtype *X, Dtype *Y);
template <typename Dtype>
void caffe_sqr(const int N, const Dtype* a, Dtype* y);
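+// Elementwise y[i] = a[i] + b[i] and y[i] = a[i] - b[i].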
template <typename Dtype>
+void caffe_add(const int N, const Dtype* a, const Dtype* b, Dtype* y);
+
+template <typename Dtype>
+void caffe_sub(const int N, const Dtype* a, const Dtype* b, Dtype* y);
+
+template <typename Dtype>
void caffe_mul(const int N, const Dtype* a, const Dtype* b, Dtype* y);
template <typename Dtype>
// const bool propagate_down, vector<Blob<Dtype>*>* bottom);
};
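+// EuclideanLossLayer: computes the sum-of-squares loss between its two bottom
+// blobs, (1 / (2 * num)) * sum (bottom[0] - bottom[1])^2.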
+template <typename Dtype>
+class EuclideanLossLayer : public Layer<Dtype> {
+ public:
+ explicit EuclideanLossLayer(const LayerParameter& param)
+ : Layer<Dtype>(param), difference_() {}
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ protected:
+ // The loss layer does nothing during the forward pass - all computation is
+ // carried out in the backward pass.
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) { return; }
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) { return; }
+ virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ // virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ // const bool propagate_down, vector<Blob<Dtype>*>* bottom);
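+ // Cached elementwise difference bottom[0] - bottom[1], filled in Backward_cpu.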
+ Blob<Dtype> difference_;
+};
+
+
} // namespace caffe
#endif // CAFFE_VISION_LAYERS_HPP_