From 48a8a64c9b1596f60d6eabff4c0df887a3ea53bf Mon Sep 17 00:00:00 2001
From: Sergey Karayev
Date: Mon, 28 Apr 2014 19:39:36 -0700
Subject: [PATCH] Split all loss layers into own .cpp files

---
 include/caffe/loss_layers.hpp                      |   4 +-
 src/caffe/layers/accuracy_layer.cpp                |  65 ++++++
 src/caffe/layers/euclidean_loss_layer.cpp          |  51 +++++
 src/caffe/layers/hinge_loss_layer.cpp              |  57 +++++
 src/caffe/layers/infogain_loss_layer.cpp           |  76 +++++++
 src/caffe/layers/loss_layer.cpp                    | 230 +--------------------
 .../layers/multinomial_logistic_loss_layer.cpp     |  60 ++++++
 7 files changed, 313 insertions(+), 230 deletions(-)
 create mode 100644 src/caffe/layers/accuracy_layer.cpp
 create mode 100644 src/caffe/layers/euclidean_loss_layer.cpp
 create mode 100644 src/caffe/layers/hinge_loss_layer.cpp
 create mode 100644 src/caffe/layers/infogain_loss_layer.cpp
 create mode 100644 src/caffe/layers/multinomial_logistic_loss_layer.cpp

diff --git a/include/caffe/loss_layers.hpp b/include/caffe/loss_layers.hpp
index 6ddfcc4..a13e889 100644
--- a/include/caffe/loss_layers.hpp
+++ b/include/caffe/loss_layers.hpp
@@ -20,6 +20,8 @@
 
 namespace caffe {
 
+const float kLOG_THRESHOLD = 1e-20;
+
 // LossLayer takes two inputs of same num, and has no output.
 template <typename Dtype>
 class LossLayer : public Layer<Dtype> {
@@ -29,7 +31,7 @@ class LossLayer : public Layer<Dtype> {
   virtual void SetUp(
       const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top);
   virtual void FurtherSetUp(
-      const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top);
+      const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {}
 };
 
 // SigmoidCrossEntropyLossLayer
diff --git a/src/caffe/layers/accuracy_layer.cpp b/src/caffe/layers/accuracy_layer.cpp
new file mode 100644
index 0000000..3e67170
--- /dev/null
+++ b/src/caffe/layers/accuracy_layer.cpp
@@ -0,0 +1,65 @@
+// Copyright 2014 BVLC and contributors.
+
+#include <algorithm>
+#include <cfloat>
+#include <cmath>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/util/io.hpp"
+
+using std::max;
+
+namespace caffe {
+
+template <typename Dtype>
+void AccuracyLayer<Dtype>::SetUp(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom.size(), 2) << "Accuracy Layer takes two blobs as input.";
+  CHECK_EQ(top->size(), 1) << "Accuracy Layer takes 1 output.";
+  CHECK_EQ(bottom[0]->num(), bottom[1]->num())
+      << "The data and label should have the same number.";
+  CHECK_EQ(bottom[1]->channels(), 1);
+  CHECK_EQ(bottom[1]->height(), 1);
+  CHECK_EQ(bottom[1]->width(), 1);
+  (*top)[0]->Reshape(1, 2, 1, 1);
+}
+
+template <typename Dtype>
+Dtype AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  Dtype accuracy = 0;
+  Dtype logprob = 0;
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  const Dtype* bottom_label = bottom[1]->cpu_data();
+  int num = bottom[0]->num();
+  int dim = bottom[0]->count() / bottom[0]->num();
+  for (int i = 0; i < num; ++i) {
+    // Accuracy
+    Dtype maxval = -FLT_MAX;
+    int max_id = 0;
+    for (int j = 0; j < dim; ++j) {
+      if (bottom_data[i * dim + j] > maxval) {
+        maxval = bottom_data[i * dim + j];
+        max_id = j;
+      }
+    }
+    if (max_id == static_cast<int>(bottom_label[i])) {
+      ++accuracy;
+    }
+    Dtype prob = max(bottom_data[i * dim + static_cast<int>(bottom_label[i])],
+        Dtype(kLOG_THRESHOLD));
+    logprob -= log(prob);
+  }
+  // LOG(INFO) << "Accuracy: " << accuracy;
+  (*top)[0]->mutable_cpu_data()[0] = accuracy / num;
+  (*top)[0]->mutable_cpu_data()[1] = logprob / num;
+  // Accuracy layer should not be used as a loss function.
+  return Dtype(0);
+}
+
+INSTANTIATE_CLASS(AccuracyLayer);
+
+}  // namespace caffe
diff --git a/src/caffe/layers/euclidean_loss_layer.cpp b/src/caffe/layers/euclidean_loss_layer.cpp
new file mode 100644
index 0000000..9bf7f98
--- /dev/null
+++ b/src/caffe/layers/euclidean_loss_layer.cpp
@@ -0,0 +1,51 @@
+// Copyright 2014 BVLC and contributors.
+
+#include <algorithm>
+#include <cfloat>
+#include <cmath>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/util/io.hpp"
+
+using std::max;
+
+namespace caffe {
+
+template <typename Dtype>
+void EuclideanLossLayer<Dtype>::FurtherSetUp(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom[0]->channels(), bottom[1]->channels());
+  CHECK_EQ(bottom[0]->height(), bottom[1]->height());
+  CHECK_EQ(bottom[0]->width(), bottom[1]->width());
+  difference_.Reshape(bottom[0]->num(), bottom[0]->channels(),
+      bottom[0]->height(), bottom[0]->width());
+}
+
+template <typename Dtype>
+Dtype EuclideanLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  int count = bottom[0]->count();
+  int num = bottom[0]->num();
+  caffe_sub(count, bottom[0]->cpu_data(), bottom[1]->cpu_data(),
+      difference_.mutable_cpu_data());
+  Dtype loss = caffe_cpu_dot(
+      count, difference_.cpu_data(), difference_.cpu_data()) / num / Dtype(2);
+  return loss;
+}
+
+template <typename Dtype>
+void EuclideanLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+  int count = (*bottom)[0]->count();
+  int num = (*bottom)[0]->num();
+  // Compute the gradient
+  caffe_cpu_axpby(count, Dtype(1) / num, difference_.cpu_data(), Dtype(0),
+      (*bottom)[0]->mutable_cpu_diff());
+}
+
+INSTANTIATE_CLASS(EuclideanLossLayer);
+
+}  // namespace caffe
diff --git a/src/caffe/layers/hinge_loss_layer.cpp b/src/caffe/layers/hinge_loss_layer.cpp
new file mode 100644
index 0000000..24329fb
--- /dev/null
+++ b/src/caffe/layers/hinge_loss_layer.cpp
@@ -0,0 +1,57 @@
+// Copyright 2014 BVLC and contributors.
+
+#include <algorithm>
+#include <cfloat>
+#include <cmath>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/util/io.hpp"
+
+using std::max;
+
+namespace caffe {
+
+template <typename Dtype>
+Dtype HingeLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+  const Dtype* label = bottom[1]->cpu_data();
+  int num = bottom[0]->num();
+  int count = bottom[0]->count();
+  int dim = count / num;
+
+  caffe_copy(count, bottom_data, bottom_diff);
+  for (int i = 0; i < num; ++i) {
+    bottom_diff[i * dim + static_cast<int>(label[i])] *= -1;
+  }
+  for (int i = 0; i < num; ++i) {
+    for (int j = 0; j < dim; ++j) {
+      bottom_diff[i * dim + j] = max(Dtype(0), 1 + bottom_diff[i * dim + j]);
+    }
+  }
+  return caffe_cpu_asum(count, bottom_diff) / num;
+}
+
+template <typename Dtype>
+void HingeLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+  const Dtype* label = (*bottom)[1]->cpu_data();
+  int num = (*bottom)[0]->num();
+  int count = (*bottom)[0]->count();
+  int dim = count / num;
+
+  caffe_cpu_sign(count, bottom_diff, bottom_diff);
+  for (int i = 0; i < num; ++i) {
+    bottom_diff[i * dim + static_cast<int>(label[i])] *= -1;
+  }
+  caffe_scal(count, Dtype(1. / num), bottom_diff);
+}
+
+INSTANTIATE_CLASS(HingeLossLayer);
+
+}  // namespace caffe
diff --git a/src/caffe/layers/infogain_loss_layer.cpp b/src/caffe/layers/infogain_loss_layer.cpp
new file mode 100644
index 0000000..ab6e67d
--- /dev/null
+++ b/src/caffe/layers/infogain_loss_layer.cpp
@@ -0,0 +1,76 @@
+// Copyright 2014 BVLC and contributors.
+
+#include <algorithm>
+#include <cfloat>
+#include <cmath>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/util/io.hpp"
+
+using std::max;
+
+namespace caffe {
+
+template <typename Dtype>
+void InfogainLossLayer<Dtype>::FurtherSetUp(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom[1]->channels(), 1);
+  CHECK_EQ(bottom[1]->height(), 1);
+  CHECK_EQ(bottom[1]->width(), 1);
+
+  BlobProto blob_proto;
+  ReadProtoFromBinaryFile(
+      this->layer_param_.infogain_loss_param().source(), &blob_proto);
+  infogain_.FromProto(blob_proto);
+  CHECK_EQ(infogain_.num(), 1);
+  CHECK_EQ(infogain_.channels(), 1);
+  CHECK_EQ(infogain_.height(), infogain_.width());
+}
+
+
+template <typename Dtype>
+Dtype InfogainLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  const Dtype* bottom_label = bottom[1]->cpu_data();
+  const Dtype* infogain_mat = infogain_.cpu_data();
+  int num = bottom[0]->num();
+  int dim = bottom[0]->count() / bottom[0]->num();
+  CHECK_EQ(infogain_.height(), dim);
+  Dtype loss = 0;
+  for (int i = 0; i < num; ++i) {
+    int label = static_cast<int>(bottom_label[i]);
+    for (int j = 0; j < dim; ++j) {
+      Dtype prob = max(bottom_data[i * dim + j], Dtype(kLOG_THRESHOLD));
+      loss -= infogain_mat[label * dim + j] * log(prob);
+    }
+  }
+  return loss / num;
+}
+
+template <typename Dtype>
+void InfogainLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  const Dtype* bottom_data = (*bottom)[0]->cpu_data();
+  const Dtype* bottom_label = (*bottom)[1]->cpu_data();
+  const Dtype* infogain_mat = infogain_.cpu_data();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+  int num = (*bottom)[0]->num();
+  int dim = (*bottom)[0]->count() / (*bottom)[0]->num();
+  CHECK_EQ(infogain_.height(), dim);
+  for (int i = 0; i < num; ++i) {
+    int label = static_cast<int>(bottom_label[i]);
+    for (int j = 0; j < dim; ++j) {
+      Dtype prob = max(bottom_data[i * dim + j], Dtype(kLOG_THRESHOLD));
+      bottom_diff[i * dim + j] = - infogain_mat[label * dim + j] / prob / num;
+    }
+  }
+}
+
+INSTANTIATE_CLASS(InfogainLossLayer);
+
+}  // namespace caffe
diff --git a/src/caffe/layers/loss_layer.cpp b/src/caffe/layers/loss_layer.cpp
index 3fc34a6..1efb623 100644
--- a/src/caffe/layers/loss_layer.cpp
+++ b/src/caffe/layers/loss_layer.cpp
@@ -14,8 +14,6 @@ using std::max;
 
 namespace caffe {
 
-const float kLOG_THRESHOLD = 1e-20;
-
 template <typename Dtype>
 void LossLayer<Dtype>::SetUp(
     const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
@@ -26,232 +24,6 @@ void LossLayer<Dtype>::SetUp(
   FurtherSetUp(bottom, top);
 }
 
-template <typename Dtype>
-void LossLayer<Dtype>::FurtherSetUp(
-    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
-  // Nothing to do
-}
-
-template <typename Dtype>
-void MultinomialLogisticLossLayer<Dtype>::FurtherSetUp(
-    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
-  CHECK_EQ(bottom[1]->channels(), 1);
-  CHECK_EQ(bottom[1]->height(), 1);
-  CHECK_EQ(bottom[1]->width(), 1);
-}
-
-template <typename Dtype>
-Dtype MultinomialLogisticLossLayer<Dtype>::Forward_cpu(
-    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
-  const Dtype* bottom_data = bottom[0]->cpu_data();
-  const Dtype* bottom_label = bottom[1]->cpu_data();
-  int num = bottom[0]->num();
-  int dim = bottom[0]->count() / bottom[0]->num();
-  Dtype loss = 0;
-  for (int i = 0; i < num; ++i) {
-    int label = static_cast<int>(bottom_label[i]);
-    Dtype prob = max(bottom_data[i * dim + label], Dtype(kLOG_THRESHOLD));
-    loss -= log(prob);
-  }
-  return loss / num;
-}
-
-template <typename Dtype>
-void MultinomialLogisticLossLayer<Dtype>::Backward_cpu(
-    const vector<Blob<Dtype>*>& top, const bool propagate_down,
-    vector<Blob<Dtype>*>* bottom) {
-  const Dtype* bottom_data = (*bottom)[0]->cpu_data();
-  const Dtype* bottom_label = (*bottom)[1]->cpu_data();
-  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
-  int num = (*bottom)[0]->num();
-  int dim = (*bottom)[0]->count() / (*bottom)[0]->num();
-  memset(bottom_diff, 0, sizeof(Dtype) * (*bottom)[0]->count());
-  for (int i = 0; i < num; ++i) {
-    int label = static_cast<int>(bottom_label[i]);
-    Dtype prob = max(bottom_data[i * dim + label], Dtype(kLOG_THRESHOLD));
-    bottom_diff[i * dim + label] = -1. / prob / num;
-  }
-}
-
-
-template <typename Dtype>
-void InfogainLossLayer<Dtype>::FurtherSetUp(
-    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
-  CHECK_EQ(bottom[1]->channels(), 1);
-  CHECK_EQ(bottom[1]->height(), 1);
-  CHECK_EQ(bottom[1]->width(), 1);
-
-  BlobProto blob_proto;
-  ReadProtoFromBinaryFile(
-      this->layer_param_.infogain_loss_param().source(), &blob_proto);
-  infogain_.FromProto(blob_proto);
-  CHECK_EQ(infogain_.num(), 1);
-  CHECK_EQ(infogain_.channels(), 1);
-  CHECK_EQ(infogain_.height(), infogain_.width());
-}
-
-
-template <typename Dtype>
-Dtype InfogainLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-    vector<Blob<Dtype>*>* top) {
-  const Dtype* bottom_data = bottom[0]->cpu_data();
-  const Dtype* bottom_label = bottom[1]->cpu_data();
-  const Dtype* infogain_mat = infogain_.cpu_data();
-  int num = bottom[0]->num();
-  int dim = bottom[0]->count() / bottom[0]->num();
-  CHECK_EQ(infogain_.height(), dim);
-  Dtype loss = 0;
-  for (int i = 0; i < num; ++i) {
-    int label = static_cast<int>(bottom_label[i]);
-    for (int j = 0; j < dim; ++j) {
-      Dtype prob = max(bottom_data[i * dim + j], Dtype(kLOG_THRESHOLD));
-      loss -= infogain_mat[label * dim + j] * log(prob);
-    }
-  }
-  return loss / num;
-}
-
-template <typename Dtype>
-void InfogainLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
-    vector<Blob<Dtype>*>* bottom) {
-  const Dtype* bottom_data = (*bottom)[0]->cpu_data();
-  const Dtype* bottom_label = (*bottom)[1]->cpu_data();
-  const Dtype* infogain_mat = infogain_.cpu_data();
-  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
-  int num = (*bottom)[0]->num();
-  int dim = (*bottom)[0]->count() / (*bottom)[0]->num();
-  CHECK_EQ(infogain_.height(), dim);
-  for (int i = 0; i < num; ++i) {
-    int label = static_cast<int>(bottom_label[i]);
-    for (int j = 0; j < dim; ++j) {
-      Dtype prob = max(bottom_data[i * dim + j], Dtype(kLOG_THRESHOLD));
-      bottom_diff[i * dim + j] = - infogain_mat[label * dim + j] / prob / num;
-    }
-  }
-}
-
-
-template <typename Dtype>
-void EuclideanLossLayer<Dtype>::FurtherSetUp(
-    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
-  CHECK_EQ(bottom[0]->channels(), bottom[1]->channels());
-  CHECK_EQ(bottom[0]->height(), bottom[1]->height());
-  CHECK_EQ(bottom[0]->width(), bottom[1]->width());
-  difference_.Reshape(bottom[0]->num(), bottom[0]->channels(),
-      bottom[0]->height(), bottom[0]->width());
-}
-
-template <typename Dtype>
-Dtype EuclideanLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-    vector<Blob<Dtype>*>* top) {
-  int count = bottom[0]->count();
-  int num = bottom[0]->num();
-  caffe_sub(count, bottom[0]->cpu_data(), bottom[1]->cpu_data(),
-      difference_.mutable_cpu_data());
-  Dtype loss = caffe_cpu_dot(
-      count, difference_.cpu_data(), difference_.cpu_data()) / num / Dtype(2);
-  return loss;
-}
-
-template <typename Dtype>
-void EuclideanLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
-  int count = (*bottom)[0]->count();
-  int num = (*bottom)[0]->num();
-  // Compute the gradient
-  caffe_cpu_axpby(count, Dtype(1) / num, difference_.cpu_data(), Dtype(0),
-      (*bottom)[0]->mutable_cpu_diff());
-}
-
-template <typename Dtype>
-Dtype HingeLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-    vector<Blob<Dtype>*>* top) {
-  const Dtype* bottom_data = bottom[0]->cpu_data();
-  Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
-  const Dtype* label = bottom[1]->cpu_data();
-  int num = bottom[0]->num();
-  int count = bottom[0]->count();
-  int dim = count / num;
-
-  caffe_copy(count, bottom_data, bottom_diff);
-  for (int i = 0; i < num; ++i) {
-    bottom_diff[i * dim + static_cast<int>(label[i])] *= -1;
-  }
-  for (int i = 0; i < num; ++i) {
-    for (int j = 0; j < dim; ++j) {
-      bottom_diff[i * dim + j] = max(Dtype(0), 1 + bottom_diff[i * dim + j]);
-    }
-  }
-  return caffe_cpu_asum(count, bottom_diff) / num;
-}
-
-template <typename Dtype>
-void HingeLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
-  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
-  const Dtype* label = (*bottom)[1]->cpu_data();
-  int num = (*bottom)[0]->num();
-  int count = (*bottom)[0]->count();
-  int dim = count / num;
-
-  caffe_cpu_sign(count, bottom_diff, bottom_diff);
-  for (int i = 0; i < num; ++i) {
-    bottom_diff[i * dim + static_cast<int>(label[i])] *= -1;
-  }
-  caffe_scal(count, Dtype(1. / num), bottom_diff);
-}
-
-template <typename Dtype>
-void AccuracyLayer<Dtype>::SetUp(
-    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
-  CHECK_EQ(bottom.size(), 2) << "Accuracy Layer takes two blobs as input.";
-  CHECK_EQ(top->size(), 1) << "Accuracy Layer takes 1 output.";
-  CHECK_EQ(bottom[0]->num(), bottom[1]->num())
-      << "The data and label should have the same number.";
-  CHECK_EQ(bottom[1]->channels(), 1);
-  CHECK_EQ(bottom[1]->height(), 1);
-  CHECK_EQ(bottom[1]->width(), 1);
-  (*top)[0]->Reshape(1, 2, 1, 1);
-}
-
-template <typename Dtype>
-Dtype AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-    vector<Blob<Dtype>*>* top) {
-  Dtype accuracy = 0;
-  Dtype logprob = 0;
-  const Dtype* bottom_data = bottom[0]->cpu_data();
-  const Dtype* bottom_label = bottom[1]->cpu_data();
-  int num = bottom[0]->num();
-  int dim = bottom[0]->count() / bottom[0]->num();
-  for (int i = 0; i < num; ++i) {
-    // Accuracy
-    Dtype maxval = -FLT_MAX;
-    int max_id = 0;
-    for (int j = 0; j < dim; ++j) {
-      if (bottom_data[i * dim + j] > maxval) {
-        maxval = bottom_data[i * dim + j];
-        max_id = j;
-      }
-    }
-    if (max_id == static_cast<int>(bottom_label[i])) {
-      ++accuracy;
-    }
-    Dtype prob = max(bottom_data[i * dim + static_cast<int>(bottom_label[i])],
-        Dtype(kLOG_THRESHOLD));
-    logprob -= log(prob);
-  }
-  // LOG(INFO) << "Accuracy: " << accuracy;
-  (*top)[0]->mutable_cpu_data()[0] = accuracy / num;
-  (*top)[0]->mutable_cpu_data()[1] = logprob / num;
-  // Accuracy layer should not be used as a loss function.
-  return Dtype(0);
-}
-
-INSTANTIATE_CLASS(MultinomialLogisticLossLayer);
-INSTANTIATE_CLASS(InfogainLossLayer);
-INSTANTIATE_CLASS(EuclideanLossLayer);
-INSTANTIATE_CLASS(HingeLossLayer);
-INSTANTIATE_CLASS(AccuracyLayer);
+INSTANTIATE_CLASS(LossLayer);
 
 }  // namespace caffe
diff --git a/src/caffe/layers/multinomial_logistic_loss_layer.cpp b/src/caffe/layers/multinomial_logistic_loss_layer.cpp
new file mode 100644
index 0000000..6486621
--- /dev/null
+++ b/src/caffe/layers/multinomial_logistic_loss_layer.cpp
@@ -0,0 +1,60 @@
+// Copyright 2014 BVLC and contributors.
+
+#include <algorithm>
+#include <cfloat>
+#include <cmath>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/util/io.hpp"
+
+using std::max;
+
+namespace caffe {
+
+template <typename Dtype>
+void MultinomialLogisticLossLayer<Dtype>::FurtherSetUp(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom[1]->channels(), 1);
+  CHECK_EQ(bottom[1]->height(), 1);
+  CHECK_EQ(bottom[1]->width(), 1);
+}
+
+template <typename Dtype>
+Dtype MultinomialLogisticLossLayer<Dtype>::Forward_cpu(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  const Dtype* bottom_label = bottom[1]->cpu_data();
+  int num = bottom[0]->num();
+  int dim = bottom[0]->count() / bottom[0]->num();
+  Dtype loss = 0;
+  for (int i = 0; i < num; ++i) {
+    int label = static_cast<int>(bottom_label[i]);
+    Dtype prob = max(bottom_data[i * dim + label], Dtype(kLOG_THRESHOLD));
+    loss -= log(prob);
+  }
+  return loss / num;
+}
+
+template <typename Dtype>
+void MultinomialLogisticLossLayer<Dtype>::Backward_cpu(
+    const vector<Blob<Dtype>*>& top, const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  const Dtype* bottom_data = (*bottom)[0]->cpu_data();
+  const Dtype* bottom_label = (*bottom)[1]->cpu_data();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+  int num = (*bottom)[0]->num();
+  int dim = (*bottom)[0]->count() / (*bottom)[0]->num();
+  memset(bottom_diff, 0, sizeof(Dtype) * (*bottom)[0]->count());
+  for (int i = 0; i < num; ++i) {
+    int label = static_cast<int>(bottom_label[i]);
+    Dtype prob = max(bottom_data[i * dim + label], Dtype(kLOG_THRESHOLD));
+    bottom_diff[i * dim + label] = -1. / prob / num;
+  }
+}
+
+INSTANTIATE_CLASS(MultinomialLogisticLossLayer);
+
+}  // namespace caffe
-- 
2.7.4