Added Multistep, Poly and Sigmoid learning rate decay policies
authorSergio <sguada@gmail.com>
Sat, 4 Oct 2014 00:14:20 +0000 (17:14 -0700)
committerSergio <sguada@gmail.com>
Mon, 22 Dec 2014 01:00:05 +0000 (17:00 -0800)
Conflicts:
include/caffe/solver.hpp
src/caffe/proto/caffe.proto
src/caffe/solver.cpp

examples/lenet/lenet_multistep_solver.prototxt [new file with mode: 0644]
examples/lenet/lenet_stepearly_solver.prototxt [new file with mode: 0644]
include/caffe/solver.hpp
src/caffe/proto/caffe.proto
src/caffe/solver.cpp

diff --git a/examples/lenet/lenet_multistep_solver.prototxt b/examples/lenet/lenet_multistep_solver.prototxt
new file mode 100644 (file)
index 0000000..fadd7c9
--- /dev/null
@@ -0,0 +1,33 @@
+# The training protocol buffer definition
+train_net: "lenet_train.prototxt"
+# The testing protocol buffer definition
+test_net: "lenet_test.prototxt"
+# test_iter specifies how many forward passes the test should carry out.
+# In the case of MNIST, we have test batch size 100 and 100 test iterations,
+# covering the full 10,000 testing images.
+test_iter: 100
+# Carry out testing every 500 training iterations.
+test_interval: 500
+# The base learning rate, momentum and the weight decay of the network.
+base_lr: 0.01
+momentum: 0.9
+weight_decay: 0.0005
+# The learning rate policy
+lr_policy: "multistep"
+gamma: 0.9
+stepvalue: 1000
+stepvalue: 2000
+stepvalue: 2500
+stepvalue: 3000
+stepvalue: 3500
+stepvalue: 4000
+# Display every 100 iterations
+display: 100
+# The maximum number of iterations
+max_iter: 10000
+# snapshot intermediate results
+snapshot: 5000
+snapshot_prefix: "lenet"
+# solver mode: 0 for CPU and 1 for GPU
+solver_mode: 1
+device_id: 1
diff --git a/examples/lenet/lenet_stepearly_solver.prototxt b/examples/lenet/lenet_stepearly_solver.prototxt
new file mode 100644 (file)
index 0000000..efc6a33
--- /dev/null
@@ -0,0 +1,28 @@
+# The training protocol buffer definition
+train_net: "lenet_train.prototxt"
+# The testing protocol buffer definition
+test_net: "lenet_test.prototxt"
+# test_iter specifies how many forward passes the test should carry out.
+# In the case of MNIST, we have test batch size 100 and 100 test iterations,
+# covering the full 10,000 testing images.
+test_iter: 100
+# Carry out testing every 500 training iterations.
+test_interval: 500
+# The base learning rate, momentum and the weight decay of the network.
+base_lr: 0.01
+momentum: 0.9
+weight_decay: 0.0005
+# The learning rate policy
+lr_policy: "stepearly"
+gamma: 0.9
+stepearly: 1
+# Display every 100 iterations
+display: 100
+# The maximum number of iterations
+max_iter: 10000
+# snapshot intermediate results
+snapshot: 5000
+snapshot_prefix: "lenet"
+# solver mode: 0 for CPU and 1 for GPU
+solver_mode: 1
+device_id: 1
index 6fd159d..51aebb3 100644 (file)
@@ -56,6 +56,7 @@ class Solver {
 
   SolverParameter param_;
   int iter_;
+  int current_step_;
   shared_ptr<Net<Dtype> > net_;
   vector<shared_ptr<Net<Dtype> > > test_nets_;
 
index a789aee..88f670f 100644 (file)
@@ -63,7 +63,7 @@ message NetParameter {
 // NOTE
 // Update the next available ID when you add a new SolverParameter field.
 //
-// SolverParameter next available ID: 34 (last added: average_loss)
+// SolverParameter next available ID: 35 (last added: stepvalue)
 message SolverParameter {
   //////////////////////////////////////////////////////////////////////////////
   // Specifying the train and test networks
@@ -124,7 +124,10 @@ message SolverParameter {
   // regularization types supported: L1 and L2
   // controlled by weight_decay
   optional string regularization_type = 29 [default = "L2"];
-  optional int32 stepsize = 13; // the stepsize for learning rate policy "step"
+  // the stepsize for learning rate policy "step"
+  optional int32 stepsize = 13;
+  // the stepsize for learning rate policy "multistep"
+  repeated int32 stepvalue = 34;
   optional int32 snapshot = 14 [default = 0]; // The snapshot interval
   optional string snapshot_prefix = 15; // The prefix for the snapshot.
   // whether to snapshot diff in the results or not. Snapshotting diff will help
@@ -166,6 +169,7 @@ message SolverState {
   optional int32 iter = 1; // The current iteration
   optional string learned_net = 2; // The file that stores the learned net.
   repeated BlobProto history = 3; // The history for sgd solvers
+  optional int32 current_step = 4 [default = 0]; // The current step for learning rate
 }
 
 enum Phase {
index 5c34e48..886c3cc 100644 (file)
@@ -158,6 +158,7 @@ template <typename Dtype>
 void Solver<Dtype>::Solve(const char* resume_file) {
   Caffe::set_phase(Caffe::TRAIN);
   LOG(INFO) << "Solving " << net_->name();
+  LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
   PreSolve();
 
   iter_ = 0;
@@ -257,7 +258,6 @@ void Solver<Dtype>::TestAll() {
   }
 }
 
-
 template <typename Dtype>
 void Solver<Dtype>::Test(const int test_net_id) {
   LOG(INFO) << "Iteration " << iter_
@@ -336,6 +336,7 @@ void Solver<Dtype>::Snapshot() {
   SnapshotSolverState(&state);
   state.set_iter(iter_);
   state.set_learned_net(model_filename);
+  state.set_current_step(current_step_);
   snapshot_filename = filename + ".solverstate";
   LOG(INFO) << "Snapshotting solver state to " << snapshot_filename;
   WriteProtoToBinaryFile(state, snapshot_filename.c_str());
@@ -351,6 +352,7 @@ void Solver<Dtype>::Restore(const char* state_file) {
     net_->CopyTrainedLayersFrom(net_param);
   }
   iter_ = state.iter();
+  current_step_ = state.current_step();
   RestoreSolverState(state);
 }
 
@@ -361,8 +363,15 @@ void Solver<Dtype>::Restore(const char* state_file) {
 //    - step: return base_lr * gamma ^ (floor(iter / step))
 //    - exp: return base_lr * gamma ^ iter
 //    - inv: return base_lr * (1 + gamma * iter) ^ (- power)
-// where base_lr, gamma, step and power are defined in the solver parameter
-// protocol buffer, and iter is the current iteration.
+//    - multistep: similar to step but it allows non uniform steps defined by
+//      stepvalue
+//    - poly: the effective learning rate follows a polynomial decay, to be
+//      zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
+//    - sigmoid: the effective learning rate follows a sigmod decay
+//      return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
+//
+// where base_lr, max_iter, gamma, step, stepvalue and power are defined
+// in the solver parameter protocol buffer, and iter is the current iteration.
 template <typename Dtype>
 Dtype SGDSolver<Dtype>::GetLearningRate() {
   Dtype rate;
@@ -370,22 +379,38 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
   if (lr_policy == "fixed") {
     rate = this->param_.base_lr();
   } else if (lr_policy == "step") {
-    int current_step = this->iter_ / this->param_.stepsize();
+    this->current_step_ = this->iter_ / this->param_.stepsize();
     rate = this->param_.base_lr() *
-        pow(this->param_.gamma(), current_step);
+        pow(this->param_.gamma(), this->current_step_);
   } else if (lr_policy == "exp") {
     rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
   } else if (lr_policy == "inv") {
     rate = this->param_.base_lr() *
         pow(Dtype(1) + this->param_.gamma() * this->iter_,
             - this->param_.power());
+  } else if (lr_policy == "multistep") {
+    if (this->current_step_ < this->param_.stepvalue_size() &&
+          this->iter_ >= this->param_.stepvalue(this->current_step_)) {
+      this->current_step_++;
+      LOG(INFO) << "MultiStep Status: Iteration " <<
+      this->iter_ << ", step = " << this->current_step_;
+    }
+    rate = this->param_.base_lr() *
+        pow(this->param_.gamma(), this->current_step_);
+  } else if (lr_policy == "poly") {
+    rate = this->param_.base_lr() * pow(Dtype(1.) -
+        (Dtype(this->iter_) / Dtype(this->param_.max_iter())),
+        this->param_.power());
+  } else if (lr_policy == "sigmoid") {
+    rate = this->param_.base_lr() * (Dtype(1.) /
+        (Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
+          Dtype(this->param_.stepsize())))));
   } else {
     LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
   }
   return rate;
 }
 
-
 template <typename Dtype>
 void SGDSolver<Dtype>::PreSolve() {
   // Initialize the history