6 #include "caffe/solver.hpp"
7 #include "caffe/util/format.hpp"
8 #include "caffe/util/hdf5.hpp"
9 #include "caffe/util/io.hpp"
10 #include "caffe/util/upgrade_proto.hpp"
// Install an external callback that the solver polls for control
// requests (e.g. SNAPSHOT or STOP) during optimization.
template<typename Dtype>
void Solver<Dtype>::SetActionFunction(ActionCallback func) {
  action_request_function_ = func;
// Poll for a pending control request.  Delegates to the registered
// callback when one is installed; otherwise reports NONE.
template<typename Dtype>
SolverAction::Enum Solver<Dtype>::GetRequestedAction() {
  if (action_request_function_) {
    // If the external request function has been set, call it.
    return action_request_function_();
  return SolverAction::NONE;
// Construct a solver directly from an in-memory SolverParameter.
// Members start empty/false; early-exit flag is cleared.
template <typename Dtype>
Solver<Dtype>::Solver(const SolverParameter& param)
    : net_(), callbacks_(), requested_early_exit_(false) {
34 template <typename Dtype>
35 Solver<Dtype>::Solver(const string& param_file)
36 : net_(), callbacks_(), requested_early_exit_(false) {
37 SolverParameter param;
38 ReadSolverParamsFromTextFileOrDie(param_file, ¶m);
// Initialize the solver from `param`: log and validate the parameters,
// verify snapshot writability up front, and seed the RNG.
template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
  LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
      << std::endl << param.DebugString();
  CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
  // Fail fast on unwritable snapshot prefixes, before any training work.
  CheckSnapshotWritePermissions();
  if (param_.random_seed() >= 0) {
    // Offset the seed by the solver rank so parallel workers draw
    // distinct random streams from the same base seed.
    Caffe::set_random_seed(param_.random_seed() + Caffe::solver_rank());
  if (Caffe::root_solver()) {
  LOG(INFO) << "Solver scaffolding done.";
// Build the training net from exactly one of the four fields that may
// specify it (net, net_param, train_net, train_net_param), then stamp
// the merged TRAIN-phase NetState onto it before construction.
template <typename Dtype>
void Solver<Dtype>::InitTrainNet() {
  // has_*() booleans convert to 0/1, so this sum counts how many of the
  // mutually exclusive net-specifying fields were set.
  const int num_train_nets = param_.has_net() + param_.has_net_param() +
      param_.has_train_net() + param_.has_train_net_param();
  const string& field_names = "net, net_param, train_net, train_net_param";
  CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net "
      << "using one of these fields: " << field_names;
  CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than "
      << "one of these fields specifying a train_net: " << field_names;
  NetParameter net_param;
  if (param_.has_train_net_param()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net specified in train_net_param.";
    net_param.CopyFrom(param_.train_net_param());
  } else if (param_.has_train_net()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net from train_net file: " << param_.train_net();
    ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param);
  if (param_.has_net_param()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net specified in net_param.";
    net_param.CopyFrom(param_.net_param());
  if (param_.has_net()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net from net file: " << param_.net();
    ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);
  // Set the correct NetState. We start with the solver defaults (lowest
  // precedence); then, merge in any NetState specified by the net_param itself;
  // finally, merge in any NetState specified by the train_state (highest
  net_state.set_phase(TRAIN);
  net_state.MergeFrom(net_param.state());
  net_state.MergeFrom(param_.train_state());
  net_param.mutable_state()->CopyFrom(net_state);
  net_.reset(new Net<Dtype>(net_param));
// Build every test net: those given explicitly via test_net_param /
// test_net, plus instances of the generic net (net / net_param) for any
// remaining test_iter entries.  Each net gets a merged TEST-phase
// NetState before construction.
template <typename Dtype>
void Solver<Dtype>::InitTestNets() {
  CHECK(Caffe::root_solver());
  const bool has_net_param = param_.has_net_param();
  const bool has_net_file = param_.has_net();
  const int num_generic_nets = has_net_param + has_net_file;
  CHECK_LE(num_generic_nets, 1)
      << "Both net_param and net_file may not be specified.";
  const int num_test_net_params = param_.test_net_param_size();
  const int num_test_net_files = param_.test_net_size();
  const int num_test_nets = num_test_net_params + num_test_net_files;
  if (num_generic_nets) {
    // With a generic net available, surplus test_iter entries create
    // extra instances of it, so only a lower bound is enforced here.
    CHECK_GE(param_.test_iter_size(), num_test_nets)
        << "test_iter must be specified for each test network.";
    CHECK_EQ(param_.test_iter_size(), num_test_nets)
        << "test_iter must be specified for each test network.";
  // If we have a generic net (specified by net or net_param, rather than
  // test_net or test_net_param), we may have an unlimited number of actual
  // test networks -- the actual number is given by the number of remaining
  // test_iters after any test nets specified by test_net_param and/or test_net
  const int num_generic_net_instances = param_.test_iter_size() - num_test_nets;
  const int num_test_net_instances = num_test_nets + num_generic_net_instances;
  if (param_.test_state_size()) {
    CHECK_EQ(param_.test_state_size(), num_test_net_instances)
        << "test_state must be unspecified or specified once per test net.";
  if (num_test_net_instances) {
    CHECK_GT(param_.test_interval(), 0);
  // `sources` records, for logging, where each net's parameters came from.
  vector<string> sources(num_test_net_instances);
  vector<NetParameter> net_params(num_test_net_instances);
  for (int i = 0; i < num_test_net_params; ++i, ++test_net_id) {
    sources[test_net_id] = "test_net_param";
    net_params[test_net_id].CopyFrom(param_.test_net_param(i));
  for (int i = 0; i < num_test_net_files; ++i, ++test_net_id) {
    sources[test_net_id] = "test_net file: " + param_.test_net(i);
    ReadNetParamsFromTextFileOrDie(param_.test_net(i),
        &net_params[test_net_id]);
  // Any remaining test_iter slots are filled with copies of the generic net.
  const int remaining_test_nets = param_.test_iter_size() - test_net_id;
  for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
    sources[test_net_id] = "net_param";
    net_params[test_net_id].CopyFrom(param_.net_param());
  for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
    sources[test_net_id] = "net file: " + param_.net();
    ReadNetParamsFromTextFileOrDie(param_.net(), &net_params[test_net_id]);
  test_nets_.resize(num_test_net_instances);
  for (int i = 0; i < num_test_net_instances; ++i) {
    // Set the correct NetState. We start with the solver defaults (lowest
    // precedence); then, merge in any NetState specified by the net_param
    // itself; finally, merge in any NetState specified by the test_state
    // (highest precedence).
    net_state.set_phase(TEST);
    net_state.MergeFrom(net_params[i].state());
    if (param_.test_state_size()) {
      net_state.MergeFrom(param_.test_state(i));
    net_params[i].mutable_state()->CopyFrom(net_state);
        << "Creating test net (#" << i << ") specified by " << sources[i];
    test_nets_[i].reset(new Net<Dtype>(net_params[i]));
    test_nets_[i]->set_debug_info(param_.debug_info());
// Run `iters` optimization iterations.  Each pass: optionally test,
// accumulate gradients over iter_size forward/backward passes, update
// the smoothed loss, log progress and per-output scores, notify the
// parallel-training callbacks, and honor SNAPSHOT/STOP requests.
template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
  const int start_iter = iter_;
  const int stop_iter = iter_ + iters;
  int average_loss = this->param_.average_loss();
  iteration_timer_.Start();
  while (iter_ < stop_iter) {
    // zero-init the params
    net_->ClearParamDiffs();
    // Test before the iteration (and before the very first one only if
    // test_initialization is set).
    if (param_.test_interval() && iter_ % param_.test_interval() == 0
        && (iter_ > 0 || param_.test_initialization())) {
      if (Caffe::root_solver()) {
      if (requested_early_exit_) {
        // Break out of the while loop because stop was requested while testing.
    // Hooks for parallel training run before the forward/backward work.
    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_start();
    const bool display = param_.display() && iter_ % param_.display() == 0;
    net_->set_debug_info(display && param_.debug_info());
    // accumulate the loss and gradient
    for (int i = 0; i < param_.iter_size(); ++i) {
      loss += net_->ForwardBackward();
    loss /= param_.iter_size();
    // average the loss across iterations for smoothed reporting
    UpdateSmoothedLoss(loss, start_iter, average_loss);
    float lapse = iteration_timer_.Seconds();
    // Avoid division by zero when the timer reads 0 seconds.
    float per_s = (iter_ - iterations_last_) / (lapse ? lapse : 1);
    LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
        << " (" << per_s << " iter/s, " << lapse << "s/"
        << param_.display() << " iters), loss = " << smoothed_loss_;
    iteration_timer_.Start();
    iterations_last_ = iter_;
    const vector<Blob<Dtype>*>& result = net_->output_blobs();
    for (int j = 0; j < result.size(); ++j) {
      const Dtype* result_vec = result[j]->cpu_data();
      const string& output_name =
          net_->blob_names()[net_->output_blob_indices()[j]];
      const Dtype loss_weight =
          net_->blob_loss_weights()[net_->output_blob_indices()[j]];
      for (int k = 0; k < result[j]->count(); ++k) {
        ostringstream loss_msg_stream;
          loss_msg_stream << " (* " << loss_weight
                          << " = " << loss_weight * result_vec[k] << " loss)";
        LOG_IF(INFO, Caffe::root_solver()) << "    Train net output #"
            << score_index++ << ": " << output_name << " = "
            << result_vec[k] << loss_msg_stream.str();
    // Hooks run again once gradients are accumulated (e.g. for reduction).
    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_gradients_ready();
    // Increment the internal iter_ counter -- its value should always indicate
    // the number of times the weights have been updated.
    SolverAction::Enum request = GetRequestedAction();
    // Save a snapshot if needed.
    if ((param_.snapshot()
         && iter_ % param_.snapshot() == 0
         && Caffe::root_solver()) ||
         (request == SolverAction::SNAPSHOT)) {
    if (SolverAction::STOP == request) {
      requested_early_exit_ = true;
      // Break out of training loop.
// Top-level entry point: optionally restore state from `resume_file`,
// run Step() up to max_iter, snapshot at the end, and do a final
// display/test pass when the schedules call for one.
template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
  CHECK(Caffe::root_solver());
  LOG(INFO) << "Solving " << net_->name();
  LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
  // Initialize to false every time we start solving.
  requested_early_exit_ = false;
  LOG(INFO) << "Restoring previous solver status from " << resume_file;
  Restore(resume_file);
  // For a network that is trained by the solver, no bottom or top vecs
  // should be given, and we will just provide dummy vecs.
  int start_iter = iter_;
  Step(param_.max_iter() - iter_);
  // If we haven't already, save a snapshot after optimization, unless
  // overridden by setting snapshot_after_train := false
  if (param_.snapshot_after_train()
      && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
  if (requested_early_exit_) {
    LOG(INFO) << "Optimization stopped early.";
  // After the optimization is done, run an additional train and test pass to
  // display the train and test loss/outputs if appropriate (based on the
  // display and test_interval settings, respectively). Unlike in the rest of
  // training, for the train net we only run a forward pass as we've already
  // updated the parameters "max_iter" times -- this final pass is only done to
  // display the loss, which is computed in the forward pass.
  if (param_.display() && iter_ % param_.display() == 0) {
    int average_loss = this->param_.average_loss();
    net_->Forward(&loss);
    UpdateSmoothedLoss(loss, start_iter, average_loss);
    LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;
  if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
  LOG(INFO) << "Optimization Done.";
// Run Test() on every configured test net in order, stopping early if a
// STOP action has been requested.
template <typename Dtype>
void Solver<Dtype>::TestAll() {
  for (int test_net_id = 0;
       test_net_id < test_nets_.size() && !requested_early_exit_;
// Evaluate one test net for its configured test_iter iterations,
// accumulating every output blob value and reporting the per-iteration
// mean (annotated with the blob's loss weight when it contributes to
// the loss).  Honors SNAPSHOT/STOP requests between iterations.
template <typename Dtype>
void Solver<Dtype>::Test(const int test_net_id) {
  CHECK(Caffe::root_solver());
  LOG(INFO) << "Iteration " << iter_
            << ", Testing net (#" << test_net_id << ")";
  // The test net shares the training net's current weights; no copy.
  CHECK_NOTNULL(test_nets_[test_net_id].get())->
      ShareTrainedLayersWith(net_.get());
  vector<Dtype> test_score;
  // For each accumulated score, which output blob index it came from.
  vector<int> test_score_output_id;
  const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
  for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
    SolverAction::Enum request = GetRequestedAction();
    // Check to see if stoppage of testing/training has been requested.
    while (request != SolverAction::NONE) {
      if (SolverAction::SNAPSHOT == request) {
      } else if (SolverAction::STOP == request) {
        requested_early_exit_ = true;
      request = GetRequestedAction();
    if (requested_early_exit_) {
      // break out of test loop.
    const vector<Blob<Dtype>*>& result =
        test_net->Forward(&iter_loss);
    if (param_.test_compute_loss()) {
    // Record each output value and the blob it came from.
    for (int j = 0; j < result.size(); ++j) {
      const Dtype* result_vec = result[j]->cpu_data();
      for (int k = 0; k < result[j]->count(); ++k) {
        test_score.push_back(result_vec[k]);
        test_score_output_id.push_back(j);
    // Accumulate this iteration's outputs into the running totals.
    for (int j = 0; j < result.size(); ++j) {
      const Dtype* result_vec = result[j]->cpu_data();
      for (int k = 0; k < result[j]->count(); ++k) {
        test_score[idx++] += result_vec[k];
  if (requested_early_exit_) {
    LOG(INFO) << "Test interrupted.";
  if (param_.test_compute_loss()) {
    loss /= param_.test_iter(test_net_id);
    LOG(INFO) << "Test loss: " << loss;
  for (int i = 0; i < test_score.size(); ++i) {
    const int output_blob_index =
        test_net->output_blob_indices()[test_score_output_id[i]];
    const string& output_name = test_net->blob_names()[output_blob_index];
    const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index];
    ostringstream loss_msg_stream;
    // Mean over the test iterations configured for this net.
    const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id);
      loss_msg_stream << " (* " << loss_weight
                      << " = " << loss_weight * mean_score << " loss)";
    LOG(INFO) << "    Test net output #" << i << ": " << output_name << " = "
              << mean_score << loss_msg_stream.str();
// Write the current model to disk in the configured snapshot format
// (binary proto or HDF5), then save the solver state alongside it.
template <typename Dtype>
void Solver<Dtype>::Snapshot() {
  CHECK(Caffe::root_solver());
  string model_filename;
  switch (param_.snapshot_format()) {
  case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
    model_filename = SnapshotToBinaryProto();
  case caffe::SolverParameter_SnapshotFormat_HDF5:
    model_filename = SnapshotToHDF5();
    LOG(FATAL) << "Unsupported snapshot format.";
  // Solver state (iteration, history, ...) is saved next to the model.
  SnapshotSolverState(model_filename);
// Fail fast at init time if snapshotting is enabled but the snapshot
// prefix is missing or unwritable, rather than after hours of training.
template <typename Dtype>
void Solver<Dtype>::CheckSnapshotWritePermissions() {
  if (Caffe::root_solver() && param_.snapshot()) {
    CHECK(param_.has_snapshot_prefix())
        << "In solver params, snapshot is specified but snapshot_prefix is not";
    // Probe writability by creating (then removing) a temporary file.
    string probe_filename = SnapshotFilename(".tempfile");
    std::ofstream probe_ofs(probe_filename.c_str());
    if (probe_ofs.good()) {
      std::remove(probe_filename.c_str());
      LOG(FATAL) << "Cannot write to snapshot prefix '"
          << param_.snapshot_prefix() << "'. Make sure "
          << "that the directory exists and is writeable.";
// Build a snapshot path from the configured prefix and the current
// iteration count (the given extension is appended to this base).
template <typename Dtype>
string Solver<Dtype>::SnapshotFilename(const string extension) {
  return param_.snapshot_prefix() + "_iter_" + caffe::format_int(iter_)
// Serialize the net to a ".caffemodel" binary protobuf file and return
// the filename written (diffs included only if snapshot_diff is set).
template <typename Dtype>
string Solver<Dtype>::SnapshotToBinaryProto() {
  string model_filename = SnapshotFilename(".caffemodel");
  LOG(INFO) << "Snapshotting to binary proto file " << model_filename;
  NetParameter net_param;
  net_->ToProto(&net_param, param_.snapshot_diff());
  WriteProtoToBinaryFile(net_param, model_filename);
  return model_filename;
// Serialize the net to a ".caffemodel.h5" HDF5 file and return the
// filename written (diffs included only if snapshot_diff is set).
template <typename Dtype>
string Solver<Dtype>::SnapshotToHDF5() {
  string model_filename = SnapshotFilename(".caffemodel.h5");
  LOG(INFO) << "Snapshotting to HDF5 file " << model_filename;
  net_->ToHDF5(model_filename, param_.snapshot_diff());
  return model_filename;
// Load solver state from `state_file`, dispatching on the ".h5"
// filename suffix between the HDF5 and binary-proto readers.
template <typename Dtype>
void Solver<Dtype>::Restore(const char* state_file) {
  CHECK(Caffe::root_solver());
  string state_filename(state_file);
  // Suffix check guards the compare() call against short filenames.
  if (state_filename.size() >= 3 &&
      state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) {
    RestoreSolverStateFromHDF5(state_filename);
    RestoreSolverStateFromBinaryProto(state_filename);
// Maintain smoothed_loss_ as a running mean of the most recent
// `average_loss` losses, using losses_ as the sample window.
template <typename Dtype>
void Solver<Dtype>::UpdateSmoothedLoss(Dtype loss, int start_iter,
  if (losses_.size() < average_loss) {
    // Window not yet full: append and update the mean incrementally.
    losses_.push_back(loss);
    int size = losses_.size();
    smoothed_loss_ = (smoothed_loss_ * (size - 1) + loss) / size;
    // Window full: overwrite the oldest sample and adjust the mean by
    // the difference it contributes.
    int idx = (iter_ - start_iter) % average_loss;
    smoothed_loss_ += (loss - losses_[idx]) / average_loss;
489 INSTANTIATE_CLASS(Solver);