solver: check and set type to reconcile class and proto
[platform/upstream/caffeonacl.git] / include / caffe / solver.hpp
1 #ifndef CAFFE_SOLVER_HPP_
2 #define CAFFE_SOLVER_HPP_
3 #include <boost/function.hpp>
4 #include <string>
5 #include <vector>
6
7 #include "caffe/net.hpp"
8 #include "caffe/solver_factory.hpp"
9
10 namespace caffe {
11
12 /**
13   * @brief Enumeration of actions that a client of the Solver may request by
14   * implementing the Solver's action request function, which a
15   * client may optionally provide in order to request early termination
16   * or saving a snapshot without exiting. In the executable caffe, this
17   * mechanism is used to allow the snapshot to be saved when stopping
18   * execution with a SIGINT (Ctrl-C).
19   */
20   namespace SolverAction {
21     enum Enum {
22       NONE = 0,  // Take no special action.
23       STOP = 1,  // Stop training. snapshot_after_train controls whether a
24                  // snapshot is created.
25       SNAPSHOT = 2  // Take a snapshot, and keep training.
26     };
27   }
28
29 /**
30  * @brief Type of a function that returns a Solver Action enumeration.
31  */
32 typedef boost::function<SolverAction::Enum()> ActionCallback;
33
34 /**
35  * @brief An interface for classes that perform optimization on Net%s.
36  *
37  * Requires implementation of ApplyUpdate to compute a parameter update
38  * given the current state of the Net parameters.
39  */
40 template <typename Dtype>
41 class Solver {
42  public:
43   explicit Solver(const SolverParameter& param,
44       const Solver* root_solver = NULL);
45   explicit Solver(const string& param_file, const Solver* root_solver = NULL);
46   void Init(const SolverParameter& param);
47   void InitTrainNet();
48   void InitTestNets();
49
50   // Client of the Solver optionally may call this in order to set the function
51   // that the solver uses to see what action it should take (e.g. snapshot or
52   // exit training early).
53   void SetActionFunction(ActionCallback func);
54   SolverAction::Enum GetRequestedAction();
55   // The main entry of the solver function. In default, iter will be zero. Pass
56   // in a non-zero iter number to resume training for a pre-trained net.
57   virtual void Solve(const char* resume_file = NULL);
58   inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
59   void Step(int iters);
60   // The Restore method simply dispatches to one of the
61   // RestoreSolverStateFrom___ protected methods. You should implement these
62   // methods to restore the state from the appropriate snapshot type.
63   void Restore(const char* resume_file);
64   // The Solver::Snapshot function implements the basic snapshotting utility
65   // that stores the learned net. You should implement the SnapshotSolverState()
66   // function that produces a SolverState protocol buffer that needs to be
67   // written to disk together with the learned net.
68   void Snapshot();
69   virtual ~Solver() {}
70   inline const SolverParameter& param() const { return param_; }
71   inline shared_ptr<Net<Dtype> > net() { return net_; }
72   inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
73     return test_nets_;
74   }
75   int iter() { return iter_; }
76
77   // Invoked at specific points during an iteration
78   class Callback {
79    protected:
80     virtual void on_start() = 0;
81     virtual void on_gradients_ready() = 0;
82
83     template <typename T>
84     friend class Solver;
85   };
86   const vector<Callback*>& callbacks() const { return callbacks_; }
87   void add_callback(Callback* value) {
88     callbacks_.push_back(value);
89   }
90
91   void CheckSnapshotWritePermissions();
92   /**
93    * @brief Returns the solver type.
94    */
95   virtual inline const char* type() const { return ""; }
96
97  protected:
98   // Make and apply the update value for the current iteration.
99   virtual void ApplyUpdate() = 0;
100   string SnapshotFilename(const string extension);
101   string SnapshotToBinaryProto();
102   string SnapshotToHDF5();
103   // The test routine
104   void TestAll();
105   void Test(const int test_net_id = 0);
106   virtual void SnapshotSolverState(const string& model_filename) = 0;
107   virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
108   virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
109   void DisplayOutputBlobs(const int net_id);
110   void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss);
111   /// Harmonize solver class type with configured proto type.
112   void CheckType(SolverParameter* param);
113
114   SolverParameter param_;
115   int iter_;
116   int current_step_;
117   shared_ptr<Net<Dtype> > net_;
118   vector<shared_ptr<Net<Dtype> > > test_nets_;
119   vector<Callback*> callbacks_;
120   vector<Dtype> losses_;
121   Dtype smoothed_loss_;
122
123   // The root solver that holds root nets (actually containing shared layers)
124   // in data parallelism
125   const Solver* const root_solver_;
126
127   // A function that can be set by a client of the Solver to provide indication
128   // that it wants a snapshot saved and/or to exit early.
129   ActionCallback action_request_function_;
130
131   // True iff a request to stop early was received.
132   bool requested_early_exit_;
133
134   DISABLE_COPY_AND_ASSIGN(Solver);
135 };
136
137 /**
138  * @brief Solver that only computes gradients, used as worker
139  *        for multi-GPU training.
140  */
141 template <typename Dtype>
142 class WorkerSolver : public Solver<Dtype> {
143  public:
144   explicit WorkerSolver(const SolverParameter& param,
145       const Solver<Dtype>* root_solver = NULL)
146       : Solver<Dtype>(param, root_solver) {}
147
148  protected:
149   void ApplyUpdate() {}
150   void SnapshotSolverState(const string& model_filename) {
151     LOG(FATAL) << "Should not be called on worker solver.";
152   }
153   void RestoreSolverStateFromBinaryProto(const string& state_file) {
154     LOG(FATAL) << "Should not be called on worker solver.";
155   }
156   void RestoreSolverStateFromHDF5(const string& state_file) {
157     LOG(FATAL) << "Should not be called on worker solver.";
158   }
159 };
160
161 }  // namespace caffe
162
163 #endif  // CAFFE_SOLVER_HPP_