1 #ifndef CAFFE_SOLVER_HPP_
2 #define CAFFE_SOLVER_HPP_
3 #include <boost/function.hpp>
7 #include "caffe/net.hpp"
8 #include "caffe/solver_factory.hpp"
13 * @brief Enumeration of actions that a client of the Solver may request by
14 * implementing the Solver's action request function, which a
15 * client may optionally provide in order to request early termination
16 * or saving a snapshot without exiting. In the executable caffe, this
17 * mechanism is used to allow the snapshot to be saved when stopping
18 * execution with a SIGINT (Ctrl-C).
// Action codes a client may request of the Solver. They are the enumerators
// of SolverAction::Enum, the type returned by ActionCallback.
20 namespace SolverAction {
22 NONE = 0, // Nothing requested: take no special action.
23 STOP = 1, // Stop training. snapshot_after_train controls whether a
24 // snapshot is created.
25 SNAPSHOT = 2 // Take a snapshot, and keep training.
30 * @brief Type of a function that returns a Solver Action enumeration.
// Callable taking no arguments and returning the SolverAction::Enum value the
// client currently requests. boost::function type-erases the callable, so any
// compatible function or function object can be stored.
32 typedef boost::function<SolverAction::Enum()> ActionCallback;
35 * @brief An interface for classes that perform optimization on Net%s.
37 * Requires implementation of ApplyUpdate to compute a parameter update
38 * given the current state of the Net parameters.
// Solver is parameterized on Dtype, the template argument of the Net it
// holds (see net_ and test_nets_).
40 template <typename Dtype>
// Constructs a solver from an in-memory SolverParameter. root_solver is
// optional and defaults to NULL.
43 explicit Solver(const SolverParameter& param,
44 const Solver* root_solver = NULL);
// Constructs a solver from a file name instead of a SolverParameter.
// NOTE(review): presumably param_file is parsed into a SolverParameter --
// confirm in the implementation.
45 explicit Solver(const string& param_file, const Solver* root_solver = NULL);
// Initializes the solver from param.
// NOTE(review): presumably invoked by the constructors -- confirm.
46 void Init(const SolverParameter& param);
50 // A client of the Solver may optionally call this to set the function
51 // that the solver uses to see what action it should take (e.g. snapshot or
52 // exit training early).
53 void SetActionFunction(ActionCallback func);
// Returns the action currently requested of the solver.
// NOTE(review): presumably obtained by invoking the function installed with
// SetActionFunction -- confirm in the implementation.
54 SolverAction::Enum GetRequestedAction();
55 // The main entry point of the solver. By default resume_file is NULL, i.e.
56 // no saved state is named; pass a resume file to continue an earlier run.
// NOTE(review): presumably a non-NULL resume_file is handed to Restore()
// before training starts -- confirm in the implementation.
57 virtual void Solve(const char* resume_file = NULL);
// Convenience overload for callers that hold the resume path as a string:
// forwards its C-string to Solve(const char*). The argument is taken by const
// reference so the path is not copied merely to be forwarded.
58 inline void Solve(const string& resume_file) { Solve(resume_file.c_str()); }
60 // Restore simply dispatches to one of the RestoreSolverStateFrom___
61 // protected methods. Concrete solvers implement those methods to restore
62 // their state from the corresponding snapshot type.
63 void Restore(const char* resume_file);
64 // The Solver::Snapshot function implements the basic snapshotting utility
65 // that stores the learned net. You should implement the SnapshotSolverState()
66 // function that produces a SolverState protocol buffer that needs to be
67 // written to disk together with the learned net.
// Read-only access to the solver configuration stored in param_.
70 inline const SolverParameter& param() const { return param_; }
// Returns a copy of the shared_ptr net_, so the caller shares ownership of
// the net with the solver.
71 inline shared_ptr<Net<Dtype> > net() { return net_; }
// Accessor for the test nets, returned by const reference.
// NOTE(review): body is not visible here; presumably returns test_nets_.
72 inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
// Returns iter_, the solver's iteration count. Declared const because it only
// reads solver state, which also makes it callable through a const Solver.
75 int iter() const { return iter_; }
77 // Invoked at specific points during an iteration
// Both hooks are pure virtual, so every concrete callback must implement them.
// NOTE(review): the names suggest on_start fires when an iteration begins and
// on_gradients_ready once gradients are available -- confirm at the call sites.
80 virtual void on_start() = 0;
81 virtual void on_gradients_ready() = 0;
// Read-only view of the registered callbacks.
86 const vector<Callback*>& callbacks() const { return callbacks_; }
// Appends value to callbacks_. Only the raw pointer is stored.
// NOTE(review): presumably the caller keeps ownership and must keep the
// callback alive while the solver uses it -- confirm.
87 void add_callback(Callback* value) {
88 callbacks_.push_back(value);
// NOTE(review): by its name, verifies that snapshots can be written to the
// configured location -- confirm in the implementation.
91 void CheckSnapshotWritePermissions();
93 * @brief Returns the solver type.
// The base implementation returns the empty string; the function is virtual so
// that subclasses can report their own type name.
95 virtual inline const char* type() const { return ""; }
98 // Make and apply the update value for the current iteration.
// Pure virtual: each concrete solver supplies its own update rule.
99 virtual void ApplyUpdate() = 0;
// Snapshot helpers, each returning a string.
// NOTE(review): presumably SnapshotFilename builds the snapshot path for the
// given extension, and the SnapshotTo* functions return the name of the model
// file they wrote -- confirm in the implementation.
100 string SnapshotFilename(const string extension);
101 string SnapshotToBinaryProto();
102 string SnapshotToHDF5();
// Runs a test pass; test_net_id defaults to 0.
// NOTE(review): presumably test_net_id indexes into test_nets_ -- confirm.
105 void Test(const int test_net_id = 0);
// Solver-specific state persistence. All three are pure virtual, so each
// concrete solver must provide them. The two RestoreSolverStateFrom* methods
// are the targets that Restore() dispatches to, one per snapshot format.
106 virtual void SnapshotSolverState(const string& model_filename) = 0;
107 virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
108 virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
// NOTE(review): by its name, reports the output blobs of the net identified
// by net_id -- confirm in the implementation.
109 void DisplayOutputBlobs(const int net_id);
// NOTE(review): presumably folds loss into smoothed_loss_ using losses_ as a
// window of average_loss entries, with start_iter marking the first iteration
// of the run -- confirm in the implementation.
110 void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss);
111 /// Harmonize solver class type with configured proto type.
// param is passed as a pointer to non-const, so the call may modify it.
112 void CheckType(SolverParameter* param);
// Solver configuration; exposed read-only through param().
114 SolverParameter param_;
// The net returned by net(); held through a shared_ptr.
117 shared_ptr<Net<Dtype> > net_;
// Nets held for testing; one shared_ptr per test net.
118 vector<shared_ptr<Net<Dtype> > > test_nets_;
// Callbacks registered via add_callback(); stored as raw pointers.
119 vector<Callback*> callbacks_;
// NOTE(review): presumably the recent loss values that UpdateSmoothedLoss
// averages into smoothed_loss_ -- confirm in the implementation.
120 vector<Dtype> losses_;
121 Dtype smoothed_loss_;
123 // The root solver that holds root nets (actually containing shared layers)
124 // in data parallelism
// Const pointer to a const Solver: the pointer cannot be reseated after
// construction, and the root solver cannot be modified through it.
125 const Solver* const root_solver_;
127 // A function that can be set by a client of the Solver to provide indication
128 // that it wants a snapshot saved and/or to exit early.
// NOTE(review): presumably assigned by SetActionFunction -- confirm.
129 ActionCallback action_request_function_;
131 // True iff a request to stop early was received.
132 bool requested_early_exit_;
// Macro defined outside this file.
// NOTE(review): by its name, it forbids copy construction and copy assignment
// of Solver objects -- confirm against the macro's definition.
134 DISABLE_COPY_AND_ASSIGN(Solver);
138 * @brief Solver that only computes gradients, used as worker
139 * for multi-GPU training.
141 template <typename Dtype>
142 class WorkerSolver : public Solver<Dtype> {
// Forwards both arguments unchanged to the Solver<Dtype> constructor.
144 explicit WorkerSolver(const SolverParameter& param,
145 const Solver<Dtype>* root_solver = NULL)
146 : Solver<Dtype>(param, root_solver) {}
// Intentionally empty: the worker performs no parameter update of its own.
149 void ApplyUpdate() {}
// Saving and restoring solver state is unsupported on a worker. Each of the
// three methods below only logs a FATAL message.
150 void SnapshotSolverState(const string& model_filename) {
151 LOG(FATAL) << "Should not be called on worker solver.";
153 void RestoreSolverStateFromBinaryProto(const string& state_file) {
154 LOG(FATAL) << "Should not be called on worker solver.";
156 void RestoreSolverStateFromHDF5(const string& state_file) {
157 LOG(FATAL) << "Should not be called on worker solver.";
163 #endif // CAFFE_SOLVER_HPP_