added infogain loss layer
authorYangqing Jia <jiayq84@gmail.com>
Mon, 11 Nov 2013 19:40:31 +0000 (11:40 -0800)
committerYangqing Jia <jiayq84@gmail.com>
Mon, 11 Nov 2013 19:40:31 +0000 (11:40 -0800)
include/caffe/vision_layers.hpp
src/caffe/layer_factory.cpp
src/caffe/layers/loss_layer.cu

index 2e24bef..2290f0f 100644 (file)
@@ -345,6 +345,29 @@ class MultinomialLogisticLossLayer : public Layer<Dtype> {
   //     const bool propagate_down, vector<Blob<Dtype>*>* bottom);
 };
 
+template <typename Dtype>
+class InfogainLossLayer : public Layer<Dtype> {
+ public:
+  explicit InfogainLossLayer(const LayerParameter& param)
+      : Layer<Dtype>(param), infogain_() {}
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+ protected:
+  // The loss layer will do nothing during forward - all computation are
+  // carried out in the backward pass.
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) { return; }
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) { return; }
+  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  // virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+  //     const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+
+  Blob<Dtype> infogain_;
+};
+
 
 // SoftmaxWithLossLayer is a layer that implements softmax and then computes
 // the loss - it is preferred over softmax + multinomiallogisticloss in the
index 178607f..b663cb2 100644 (file)
@@ -33,6 +33,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
     return new EuclideanLossLayer<Dtype>(param);
   } else if (type == "im2col") {
     return new Im2colLayer<Dtype>(param);
+  } else if (type == "infogain_loss") {
+    return new InfogainLossLayer<Dtype>(param);
   } else if (type == "innerproduct") {
     return new InnerProductLayer<Dtype>(param);
   } else if (type == "lrn") {
index 0a6f5ee..ac05ba4 100644 (file)
@@ -6,6 +6,7 @@
 #include "caffe/layer.hpp"
 #include "caffe/vision_layers.hpp"
 #include "caffe/util/math_functions.hpp"
+#include "caffe/util/io.hpp"
 
 using std::max;
 
@@ -17,7 +18,7 @@ template <typename Dtype>
 void MultinomialLogisticLossLayer<Dtype>::SetUp(
     const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
   CHECK_EQ(bottom.size(), 2) << "Loss Layer takes two blobs as input.";
-  CHECK_EQ(top->size(), 0) << "Loss Layer takes no as output.";
+  CHECK_EQ(top->size(), 0) << "Loss Layer takes no output.";
   CHECK_EQ(bottom[0]->num(), bottom[1]->num())
       << "The data and label should have the same number.";
   CHECK_EQ(bottom[1]->channels(), 1);
@@ -50,6 +51,49 @@ Dtype MultinomialLogisticLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>
 
 
 template <typename Dtype>
+void InfogainLossLayer<Dtype>::SetUp(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom.size(), 2) << "Loss Layer takes two blobs as input.";
+  CHECK_EQ(top->size(), 0) << "Loss Layer takes no output.";
+  CHECK_EQ(bottom[0]->num(), bottom[1]->num())
+      << "The data and label should have the same number.";
+  CHECK_EQ(bottom[1]->channels(), 1);
+  CHECK_EQ(bottom[1]->height(), 1);
+  CHECK_EQ(bottom[1]->width(), 1);
+  BlobProto blob_proto;
+  ReadProtoFromBinaryFile(this->layer_param_.source(), &blob_proto);
+  infogain_.FromProto(blob_proto);
+  CHECK_EQ(infogain_.num(), 1);
+  CHECK_EQ(infogain_.channels(), 1);
+  CHECK_EQ(infogain_.height(), infogain_.width());
+};
+
+
+template <typename Dtype>
+Dtype InfogainLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  const Dtype* bottom_data = (*bottom)[0]->cpu_data();
+  const Dtype* bottom_label = (*bottom)[1]->cpu_data();
+  const Dtype* infogain_mat = infogain_.cpu_data();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+  int num = (*bottom)[0]->num();
+  int dim = (*bottom)[0]->count() / (*bottom)[0]->num();
+  CHECK_EQ(infogain_.height(), dim);
+  Dtype loss = 0;
+  for (int i = 0; i < num; ++i) {
+    int label = static_cast<int>(bottom_label[i]);
+    for (int j = 0; j < dim; ++j) {
+      Dtype prob = max(bottom_data[i * dim + j], kLOG_THRESHOLD);
+      loss -= infogain_mat[label * dim + j] * log(prob);
+      bottom_diff[i * dim + j] = - infogain_mat[label * dim + j] / prob / num;
+    }
+  }
+  return loss / num;
+}
+
+
+template <typename Dtype>
 void EuclideanLossLayer<Dtype>::SetUp(
   const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
   CHECK_EQ(bottom.size(), 2) << "Loss Layer takes two blobs as input.";
@@ -122,6 +166,7 @@ void AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 }
 
 INSTANTIATE_CLASS(MultinomialLogisticLossLayer);
+INSTANTIATE_CLASS(InfogainLossLayer);
 INSTANTIATE_CLASS(EuclideanLossLayer);
 INSTANTIATE_CLASS(AccuracyLayer);