softmaxwithloss layer: softmax + loss
authorYangqing Jia <jiayq84@gmail.com>
Wed, 9 Oct 2013 05:25:03 +0000 (22:25 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Wed, 9 Oct 2013 05:25:03 +0000 (22:25 -0700)
src/caffe/layer_factory.hpp
src/caffe/layers/loss_layer.cu
src/caffe/layers/lrn_layer.cu
src/caffe/layers/softmax_loss_layer.cu [new file with mode: 0644]
src/caffe/optimization/solver.cpp
src/caffe/test/test_softmax_with_loss_layer.cpp [new file with mode: 0644]
src/caffe/vision_layers.hpp
src/programs/train_alexnet.cpp

index c2795b9..d231e17 100644 (file)
@@ -43,6 +43,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
     return new ReLULayer<Dtype>(param);
   } else if (type == "softmax") {
     return new SoftmaxLayer<Dtype>(param);
+  } else if (type == "softmax_loss") {
+    return new SoftmaxWithLossLayer<Dtype>(param);
   } else if (type == "multinomial_logistic_loss") {
     return new MultinomialLogisticLossLayer<Dtype>(param);
   } else {
index 737f1a2..18a8023 100644 (file)
@@ -35,7 +35,7 @@ Dtype MultinomialLogisticLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>
   int dim = (*bottom)[0]->count() / (*bottom)[0]->num();
   memset(bottom_diff, 0, sizeof(Dtype) * (*bottom)[0]->count());
   Dtype loss = 0;
-  const Dtype kLOG_THRESHOLD = 1e-8;
+  const Dtype kLOG_THRESHOLD = 1e-20;
   for (int i = 0; i < num; ++i) {
     int label = static_cast<int>(bottom_label[i]);
     Dtype prob = max(bottom_data[i * dim + label], kLOG_THRESHOLD);
index c2a5201..2afbf38 100644 (file)
@@ -109,7 +109,7 @@ __global__ void LRNComputeDiff(const int nthreads, const Dtype* bottom_data,
     int pre_pad = size - (size + 1) / 2;
     int post_pad = size - pre_pad - 1;
     Dtype accum_ratio = 0;
-    // accumulate values 
+    // accumulate values
     while (head < post_pad) {
       accum_ratio += top_diff[head * step] * top_data[head * step] /
           scale[head * step];
diff --git a/src/caffe/layers/softmax_loss_layer.cu b/src/caffe/layers/softmax_loss_layer.cu
new file mode 100644 (file)
index 0000000..3a001c0
--- /dev/null
@@ -0,0 +1,73 @@
+// Copyright 2013 Yangqing Jia
+
+#include <algorithm>
+#include <cfloat>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+
+using std::max;
+
+namespace caffe {
+
+template <typename Dtype>
+void SoftmaxWithLossLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom.size(), 2) << "SoftmaxLoss Layer takes a single blob as input.";
+  CHECK_EQ(top->size(), 0) << "SoftmaxLoss Layer takes no blob as output.";
+  softmax_bottom_vec_.clear();
+  softmax_bottom_vec_.push_back(bottom[0]);
+  softmax_top_vec_.push_back(&prob_);
+  softmax_layer_->SetUp(softmax_bottom_vec_, &softmax_top_vec_);
+};
+
+template <typename Dtype>
+void SoftmaxWithLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  // The forward pass computes the softmax prob values.
+  softmax_bottom_vec_[0] = bottom[0];
+  softmax_layer_->Forward(softmax_bottom_vec_, &softmax_top_vec_);
+}
+
+template <typename Dtype>
+void SoftmaxWithLossLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  // The forward pass computes the softmax prob values.
+  softmax_bottom_vec_[0] = bottom[0];
+  softmax_layer_->Forward(softmax_bottom_vec_, &softmax_top_vec_);
+}
+
+template <typename Dtype>
+Dtype SoftmaxWithLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  // First, compute the diff
+  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+  const Dtype* prob_data = prob_.cpu_data();
+  memcpy(bottom_diff, prob_data, sizeof(Dtype) * prob_.count());
+  const Dtype* label = (*bottom)[1]->cpu_data();
+  int num = prob_.num();
+  int dim = prob_.count() / num;
+  Dtype loss = 0;
+  for (int i = 0; i < num; ++i) {
+    bottom_diff[i * dim + static_cast<int>(label[i])] -= 1;
+    loss += -log(max(prob_data[i * dim + static_cast<int>(label[i])], FLT_MIN));
+  }
+  // Scale down gradient
+  caffe_scal(prob_.count(), Dtype(1) / num, bottom_diff);
+  return loss / num;
+}
+
+template <typename Dtype>
+Dtype SoftmaxWithLossLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+  // TODO(Yangqing): implement the GPU version of softmax.
+  return Backward_cpu(top, propagate_down, bottom);
+}
+
+INSTANTIATE_CLASS(SoftmaxWithLossLayer);
+
+
+}  // namespace caffe
index d841383..6d82df3 100644 (file)
@@ -45,7 +45,8 @@ void Solver<Dtype>::Solve(Net<Dtype>* net) {
 template <typename Dtype>
 void Solver<Dtype>::Snapshot(bool is_final) {
   NetParameter net_param;
-  net_->ToProto(&net_param);
+  // For intermediate results, we will also dump the gradient values.
+  net_->ToProto(&net_param, !is_final);
   stringstream ss;
   ss << param_.snapshot_prefix();
   if (is_final) {
@@ -83,13 +84,11 @@ void SGDSolver<Dtype>::PreSolve() {
   // First of all, see if we need to initialize the history
   vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
   history_.clear();
-  if (this->param_.momentum() > 0) {
-    for (int i = 0; i < net_params.size(); ++i) {
-      const Blob<Dtype>* net_param = net_params[i].get();
-      history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
-          net_param->num(), net_param->channels(), net_param->height(),
-          net_param->width())));
-    }
+  for (int i = 0; i < net_params.size(); ++i) {
+    const Blob<Dtype>* net_param = net_params[i].get();
+    history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
+        net_param->num(), net_param->channels(), net_param->height(),
+        net_param->width())));
   }
 }
 
diff --git a/src/caffe/test/test_softmax_with_loss_layer.cpp b/src/caffe/test/test_softmax_with_loss_layer.cpp
new file mode 100644 (file)
index 0000000..a955192
--- /dev/null
@@ -0,0 +1,71 @@
+// Copyright 2013 Yangqing Jia
+
+#include <cmath>
+#include <cstdlib>
+#include <cstring>
+#include <cuda_runtime.h>
+
+#include "gtest/gtest.h"
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+
+extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
+
+template <typename Dtype>
+class SoftmaxWithLossLayerTest : public ::testing::Test {
+ protected:
+  SoftmaxWithLossLayerTest()
+      : blob_bottom_data_(new Blob<Dtype>(10, 5, 1, 1)),
+        blob_bottom_label_(new Blob<Dtype>(10, 1, 1, 1)) {
+    // fill the values
+    FillerParameter filler_param;
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_data_);
+    blob_bottom_vec_.push_back(blob_bottom_data_);
+    for (int i = 0; i < blob_bottom_label_->count(); ++i) {
+      blob_bottom_label_->mutable_cpu_data()[i] = rand() % 5;
+    }
+    blob_bottom_vec_.push_back(blob_bottom_label_);
+  }
+  virtual ~SoftmaxWithLossLayerTest() {
+    delete blob_bottom_data_;
+    delete blob_bottom_label_;
+  }
+  Blob<Dtype>* const blob_bottom_data_;
+  Blob<Dtype>* const blob_bottom_label_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(SoftmaxWithLossLayerTest, Dtypes);
+
+
+TYPED_TEST(SoftmaxWithLossLayerTest, TestGradientCPU) {
+  LayerParameter layer_param;
+  Caffe::set_mode(Caffe::CPU);
+  SoftmaxWithLossLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
+  GradientChecker<TypeParam> checker(1e-2, 1e-2, 1701);
+  checker.CheckGradientSingle(layer, this->blob_bottom_vec_,
+      this->blob_top_vec_, 0, -1, -1);
+}
+
+TYPED_TEST(SoftmaxWithLossLayerTest, TestGradientGPU) {
+  LayerParameter layer_param;
+  Caffe::set_mode(Caffe::GPU);
+  SoftmaxWithLossLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
+  GradientChecker<TypeParam> checker(1e-2, 1e-2, 1701);
+  checker.CheckGradientSingle(layer, this->blob_bottom_vec_,
+      this->blob_top_vec_, 0, -1, -1);
+}
+
+}
index 6f943a8..7f94d99 100644 (file)
@@ -233,6 +233,7 @@ class ConvolutionLayer : public Layer<Dtype> {
   int N_;
 };
 
+// This function is used to create a pthread that prefetches the data.
 template <typename Dtype>
 void* DataLayerPrefetch(void* layer_pointer);
 
@@ -293,6 +294,7 @@ class SoftmaxLayer : public Layer<Dtype> {
   Blob<Dtype> scale_;
 };
 
+
 template <typename Dtype>
 class MultinomialLogisticLossLayer : public Layer<Dtype> {
  public:
@@ -314,6 +316,39 @@ class MultinomialLogisticLossLayer : public Layer<Dtype> {
   //     const bool propagate_down, vector<Blob<Dtype>*>* bottom);
 };
 
+
+// SoftmaxWithLossLayer is a layer that implements softmax and then computes
+// the loss - it is preferred over softmax + multinomiallogisticloss in the
+// sense that during training, this will produce more numerically stable
+// gradients. During testing this layer could be replaced by a softmax layer
+// to generate probability outputs.
+template <typename Dtype>
+class SoftmaxWithLossLayer : public Layer<Dtype> {
+ public:
+  explicit SoftmaxWithLossLayer(const LayerParameter& param)
+      : Layer<Dtype>(param), softmax_layer_(new SoftmaxLayer<Dtype>(param)) {}
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+     const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+
+  shared_ptr<SoftmaxLayer<Dtype> > softmax_layer_;
+  // prob stores the output probability of the layer.
+  Blob<Dtype> prob_;
+  // Vector holders to call the underlying softmax layer forward and backward.
+  vector<Blob<Dtype>*> softmax_bottom_vec_;
+  vector<Blob<Dtype>*> softmax_top_vec_;
+};
+
+
 template <typename Dtype>
 class EuclideanLossLayer : public Layer<Dtype> {
  public:
@@ -345,8 +380,6 @@ class AccuracyLayer : public Layer<Dtype> {
       vector<Blob<Dtype>*>* top);
 
  protected:
-  // The loss layer will do nothing during forward - all computation are
-  // carried out in the backward pass.
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   // The accuracy layer should not be used to compute backward operations.
index d908859..6b2260a 100644 (file)
@@ -33,15 +33,13 @@ int main(int argc, char** argv) {
   LOG(ERROR) << "Initial loss: " << caffe_net.Backward();
 
   SolverParameter solver_param;
-  solver_param.set_base_lr(0.001);
+  solver_param.set_base_lr(0.002);
   solver_param.set_display(1);
   solver_param.set_max_iter(600000);
   solver_param.set_lr_policy("fixed");
-  //solver_param.set_gamma(0.0001);
-  //solver_param.set_power(0.75);
   solver_param.set_momentum(0.9);
   solver_param.set_weight_decay(0.0005);
-  solver_param.set_snapshot(100);
+  solver_param.set_snapshot(1000);
   solver_param.set_snapshot_prefix("alexnet");
 
   LOG(ERROR) << "Starting Optimization";