Moved the layer factory implementation to cpp; added snapshot and restore functions...
author     Yangqing Jia <jiayq84@gmail.com>
           Tue, 15 Oct 2013 18:28:26 +0000 (11:28 -0700)
committer  Yangqing Jia <jiayq84@gmail.com>
           Tue, 15 Oct 2013 18:28:26 +0000 (11:28 -0700)
src/caffe/caffe.hpp
src/caffe/layer.hpp
src/caffe/layer_factory.cpp [moved from src/caffe/layer_factory.hpp with 93% similarity]
src/caffe/net.cpp
src/caffe/optimization/solver.cpp
src/caffe/optimization/solver.hpp
src/caffe/proto/caffe.proto
src/caffe/vision_layers.hpp

diff --git a/src/caffe/caffe.hpp b/src/caffe/caffe.hpp
index 800138f..5806bc0 100644
@@ -7,7 +7,6 @@
 #include "caffe/blob.hpp"
 #include "caffe/filler.hpp"
 #include "caffe/layer.hpp"
-#include "caffe/layer_factory.hpp"
 #include "caffe/net.hpp"
 #include "caffe/vision_layers.hpp"
 
diff --git a/src/caffe/layer.hpp b/src/caffe/layer.hpp
index cbfde0c..adc6365 100644
@@ -127,6 +127,10 @@ void Layer<Dtype>::ToProto(LayerParameter* param, bool write_diff) {
   }
 }
 
+// The layer factory function
+template <typename Dtype>
+Layer<Dtype>* GetLayer(const LayerParameter& param);
+
 }  // namespace caffe
 
 #endif  // CAFFE_LAYER_H_
diff --git a/src/caffe/layer_factory.hpp b/src/caffe/layer_factory.cpp
similarity index 93%
rename from src/caffe/layer_factory.hpp
rename to src/caffe/layer_factory.cpp
index d231e17..6961bb3 100644
@@ -54,6 +54,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
   return (Layer<Dtype>*)(NULL);
 }
 
+template Layer<float>* GetLayer(const LayerParameter& param);
+template Layer<double>* GetLayer(const LayerParameter& param);
 
 }  // namespace caffe
 
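The hunk above shows only the tail of GetLayer() plus the new explicit instantiations for float and double. For orientation, a minimal sketch of the shape such a string-dispatched factory takes is given below; the layer type strings and the concrete layer classes named here are illustrative assumptions, not necessarily what this tree registers.

    // Illustrative sketch only -- not the body of GetLayer() in this commit.
    // A factory of this shape switches on the type string carried in the
    // LayerParameter and returns a freshly allocated layer, falling through
    // to the NULL return shown in the hunk above for unknown types.
    #include <string>
    #include "caffe/layer.hpp"
    #include "caffe/vision_layers.hpp"

    namespace caffe {

    template <typename Dtype>
    Layer<Dtype>* GetLayerSketch(const LayerParameter& param) {
      const std::string& type = param.type();
      if (type == "innerproduct") {           // hypothetical type string
        return new InnerProductLayer<Dtype>(param);
      } else if (type == "relu") {            // hypothetical type string
        return new ReLULayer<Dtype>(param);
      }
      LOG(ERROR) << "Unknown layer type: " << type;
      return (Layer<Dtype>*)(NULL);
    }

    }  // namespace caffe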
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index 22250da..e1442ec 100644
@@ -6,7 +6,7 @@
 #include <vector>
 
 #include "caffe/proto/caffe.pb.h"
-#include "caffe/layer_factory.hpp"
+#include "caffe/layer.hpp"
 #include "caffe/net.hpp"
 
 using std::pair;
diff --git a/src/caffe/optimization/solver.cpp b/src/caffe/optimization/solver.cpp
index b2a5760..73c69c0 100644
@@ -18,11 +18,17 @@ using std::min;
 namespace caffe {
 
 template <typename Dtype>
-void Solver<Dtype>::Solve(Net<Dtype>* net) {
+void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) {
   net_ = net;
   LOG(INFO) << "Solving " << net_->name();
   PreSolve();
+
   iter_ = 0;
+  if (resume_file) {
+    LOG(INFO) << "Restoring previous solver status from " << resume_file;
+    Restore(resume_file);
+  }
+
   // For a network that is trained by the solver, no bottom or top vecs
   // should be given, and we will just provide dummy vecs.
   vector<Blob<Dtype>*> bottom_vec;
@@ -56,8 +62,26 @@ void Solver<Dtype>::Snapshot(bool is_final) {
     sprintf(iter_str_buffer, "_iter_%d", iter_);
     filename += iter_str_buffer;
   }
-  LOG(ERROR) << "Snapshotting to " << filename;
+  LOG(INFO) << "Snapshotting to " << filename;
   WriteProtoToBinaryFile(net_param, filename.c_str());
+  SolverState state;
+  SnapshotSolverState(&state);
+  state.set_iter(iter_);
+  state.set_learned_net(filename);
+  filename += ".solverstate";
+  LOG(INFO) << "Snapshotting solver state to " << filename;
+  WriteProtoToBinaryFile(state, filename.c_str());
+}
+
+template <typename Dtype>
+void Solver<Dtype>::Restore(char* state_file) {
+  SolverState state;
+  NetParameter net_param;
+  ReadProtoFromBinaryFile(state_file, &state);
+  ReadProtoFromBinaryFile(state.learned_net().c_str(), &net_param);
+  net_->CopyTrainedLayersFrom(net_param);
+  iter_ = state.iter();
+  RestoreSolverState(state);
 }
 
 
@@ -167,6 +191,24 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
   }
 }
 
+template <typename Dtype>
+void SGDSolver<Dtype>::SnapshotSolverState(SolverState* state) {
+  state->clear_history();
+  for (int i = 0; i < history_.size(); ++i) {
+    // Add history
+    BlobProto* history_blob = state->add_history();
+    history_[i]->ToProto(history_blob);
+  }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::RestoreSolverState(const SolverState& state) {
+  CHECK_EQ(state.history_size(), history_.size())
+      << "Incorrect length of history blobs.";
+  for (int i = 0; i < history_.size(); ++i) {
+    history_[i]->FromProto(state.history(i));
+  }
+}
 
 INSTANTIATE_CLASS(Solver);
 INSTANTIATE_CLASS(SGDSolver);
diff --git a/src/caffe/optimization/solver.hpp b/src/caffe/optimization/solver.hpp
index 8dc41af..a5ea612 100644
@@ -12,8 +12,9 @@ class Solver {
  public:
   explicit Solver(const SolverParameter& param)
       : param_(param) {}
-  // The main entry of the solver function.
-  void Solve(Net<Dtype>* net);
+  // The main entry of the solver function. By default iter starts at zero;
+  // pass in a solver state file to resume training from a previous snapshot.
+  void Solve(Net<Dtype>* net, char* state_file = NULL);
   virtual ~Solver() {}
 
  protected:
@@ -22,7 +23,17 @@ class Solver {
   virtual void PreSolve() {}
   // Get the update value for the current iteration.
   virtual void ComputeUpdateValue() = 0;
+  // The Solver::Snapshot function implements the basic snapshotting utility
+  // that stores the learned net. Derived solvers should implement
+  // SnapshotSolverState() to produce the SolverState protocol buffer that is
+  // written to disk together with the learned net.
   void Snapshot(bool is_final = false);
+  virtual void SnapshotSolverState(SolverState* state) = 0;
+  // The Restore function restores the solver from a previously snapshotted
+  // state. Derived solvers should implement RestoreSolverState() to read
+  // their state back from a SolverState protocol buffer.
+  void Restore(char* state_file);
+  virtual void RestoreSolverState(const SolverState& state) = 0;
   SolverParameter param_;
   int iter_;
   Net<Dtype>* net_;
@@ -39,8 +50,10 @@ class SGDSolver : public Solver<Dtype> {
 
  protected:
   virtual void PreSolve();
-  Dtype GetLearningRate();
+  virtual Dtype GetLearningRate();
   virtual void ComputeUpdateValue();
+  virtual void SnapshotSolverState(SolverState* state);
+  virtual void RestoreSolverState(const SolverState& state);
   // history maintains the historical momentum data.
   vector<shared_ptr<Blob<Dtype> > > history_;
 };
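To make the new Solve()/Snapshot()/Restore() contract concrete, here is a caller-side sketch. Assumptions: the include paths follow this tree's layout, SGDSolver exposes the usual constructor taking a SolverParameter, and the resume file name is a placeholder rather than anything produced by this commit.

    #include "caffe/net.hpp"
    #include "caffe/optimization/solver.hpp"

    // Caller-side sketch: resume_file may be NULL for a fresh run, or the
    // path of a previously written .solverstate file to continue training.
    void TrainOrResume(const caffe::SolverParameter& solver_param,
                       caffe::Net<float>* net, char* resume_file) {
      caffe::SGDSolver<float> solver(solver_param);
      // With resume_file == NULL, iter_ starts at 0; otherwise Restore()
      // reloads the learned net named inside the state file, the iteration
      // counter, and (for SGD) the momentum history before the loop starts.
      solver.Solve(net, resume_file);
      // Each snapshot then writes two files:
      //   <snapshot_prefix>_iter_<N>              -- the learned NetParameter
      //   <snapshot_prefix>_iter_<N>.solverstate  -- the SolverState
    }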
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 87f2c2c..4be9696 100644
@@ -105,4 +105,22 @@ message SolverParameter {
   optional float stepsize = 12; // the stepsize for learning rate policy "step"
 
   optional string snapshot_prefix = 13; // The prefix for the snapshot.
+
+  // AdaGrad solver parameters
+  // For AdaGrad, we will first run normal SGD using the SGD parameters above
+  // for adagrad_skip iterations, and then switch to the AdaGrad algorithm,
+  // with the learning rate being adagrad_gamma * adagrad_skip. Note that
+  // AdaGrad will NOT use the learning rate multiplier specified in the layer
+  // parameter specifications, as it adjusts the learning rate of individual
+  // parameters in a data-dependent way.
+  //    WORK IN PROGRESS: not actually implemented yet.
+  optional float adagrad_gamma = 14; // adagrad learning rate multiplier
+  optional float adagrad_skip = 15; // the steps to skip before adagrad kicks in
 }
+
+// A message that stores the solver snapshots
+message SolverState {
+  optional int32 iter = 1; // The current iteration
+  optional string learned_net = 2; // The file that stores the learned net.
+  repeated BlobProto history = 3; // The history for sgd solvers
+}
\ No newline at end of file
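The AdaGrad block in SolverParameter above describes the intended scheme but is explicitly marked as not yet implemented. As a reference for the per-parameter update that description corresponds to, here is a minimal sketch; the epsilon term and the helper function itself are assumptions for illustration, not code or fields in this commit.

    #include <cmath>
    #include <vector>

    // Sketch of the per-parameter AdaGrad step described above: after the
    // plain-SGD warmup of adagrad_skip iterations, each parameter keeps a
    // running sum of squared gradients and scales its own step size by it.
    // `rate` corresponds to adagrad_gamma * adagrad_skip per the comment;
    // `epsilon` (an assumption, not a proto field) avoids division by zero.
    void AdagradUpdate(float rate, float epsilon,
                       const std::vector<float>& gradient,
                       std::vector<float>* history,
                       std::vector<float>* weight) {
      for (size_t i = 0; i < weight->size(); ++i) {
        (*history)[i] += gradient[i] * gradient[i];            // accumulate g^2
        (*weight)[i] -= rate * gradient[i] /
                        (std::sqrt((*history)[i]) + epsilon);  // scaled step
      }
    }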
diff --git a/src/caffe/vision_layers.hpp b/src/caffe/vision_layers.hpp
index b07307b..0dc3476 100644
@@ -274,6 +274,7 @@ class DataLayer : public Layer<Dtype> {
   pthread_t thread_;
   shared_ptr<Blob<Dtype> > prefetch_data_;
   shared_ptr<Blob<Dtype> > prefetch_label_;
+  Blob<Dtype> data_mean_;
 };