#include <cstdio>

#include <fstream>
#include <string>
#include <vector>

#include "caffe/solver.hpp"
#include "caffe/util/format.hpp"
#include "caffe/util/hdf5.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/upgrade_proto.hpp"

namespace caffe {

template<typename Dtype>
void Solver<Dtype>::SetActionFunction(ActionCallback func) {
  action_request_function_ = func;
}

template<typename Dtype>
SolverAction::Enum Solver<Dtype>::GetRequestedAction() {
  if (action_request_function_) {
    // If the external request function has been set, call it.
    return action_request_function_();
  }
  return SolverAction::NONE;
}
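
// Illustrative usage (a sketch, not part of this solver): a caller can feed
// external control signals into the solver through SetActionFunction.
// Assuming a hypothetical flag set by a SIGINT handler, the wiring might be:
//
//   SolverAction::Enum OnCtrlC() {
//     return got_sigint ? SolverAction::STOP : SolverAction::NONE;
//   }
//   solver.SetActionFunction(&OnCtrlC);
//
// Step() and Test() poll GetRequestedAction() at safe points, so STOP and
// SNAPSHOT requests take effect between iterations rather than mid-update.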

template <typename Dtype>
Solver<Dtype>::Solver(const SolverParameter& param, const Solver* root_solver)
    : net_(), callbacks_(), root_solver_(root_solver),
      requested_early_exit_(false) {
  Init(param);
}

template <typename Dtype>
Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
    : net_(), callbacks_(), root_solver_(root_solver),
      requested_early_exit_(false) {
  SolverParameter param;
  ReadSolverParamsFromTextFileOrDie(param_file, &param);
  Init(param);
}

template <typename Dtype>
void Solver<Dtype>::CheckType(SolverParameter* param) {
  // Harmonize solver class type with configured type to avoid confusion.
  if (param->has_type()) {
    CHECK_EQ(param->type(), this->type())
        << "Solver type must agree with instantiated solver class.";
  } else {
    param->set_type(this->type());
  }
}

template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
  CHECK(Caffe::root_solver() || root_solver_)
      << "root_solver_ needs to be set for all non-root solvers";
  LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
      << std::endl << param.DebugString();
  param_ = param;
  CHECK_GE(param_.average_loss(), 1) << "average_loss should be at least 1.";
  CheckSnapshotWritePermissions();
  if (Caffe::root_solver() && param_.random_seed() >= 0) {
    Caffe::set_random_seed(param_.random_seed());
  }
  // Scaffolding code.
  InitTrainNet();
  if (Caffe::root_solver()) {
    InitTestNets();
    LOG(INFO) << "Solver scaffolding done.";
  }
  iter_ = 0;
  current_step_ = 0;
}

template <typename Dtype>
void Solver<Dtype>::InitTrainNet() {
  const int num_train_nets = param_.has_net() + param_.has_net_param() +
      param_.has_train_net() + param_.has_train_net_param();
  const string& field_names = "net, net_param, train_net, train_net_param";
  CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net "
      << "using one of these fields: " << field_names;
  CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than "
      << "one of these fields specifying a train_net: " << field_names;
  NetParameter net_param;
  if (param_.has_train_net_param()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net specified in train_net_param.";
    net_param.CopyFrom(param_.train_net_param());
  } else if (param_.has_train_net()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net from train_net file: " << param_.train_net();
    ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param);
  }
  if (param_.has_net_param()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net specified in net_param.";
    net_param.CopyFrom(param_.net_param());
  }
  if (param_.has_net()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net from net file: " << param_.net();
    ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);
  }
  // Set the correct NetState.  We start with the solver defaults (lowest
  // precedence); then, merge in any NetState specified by the net_param
  // itself; finally, merge in any NetState specified by the train_state
  // (highest precedence).
  NetState net_state;
  net_state.set_phase(TRAIN);
  net_state.MergeFrom(net_param.state());
  net_state.MergeFrom(param_.train_state());
  net_param.mutable_state()->CopyFrom(net_state);
  if (Caffe::root_solver()) {
    net_.reset(new Net<Dtype>(net_param));
  } else {
    net_.reset(new Net<Dtype>(net_param, root_solver_->net_.get()));
  }
}
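
// For illustration, with a hypothetical solver config such as
//
//   net: "train_val.prototxt"
//   train_state { level: 1 stage: "finetune" }
//
// the merge order above means: the solver default (phase TRAIN) applies
// first, any state block inside train_val.prototxt overrides it, and
// train_state wins last, so the net is built at level 1 with stage
// "finetune" regardless of what the prototxt's own state says.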

template <typename Dtype>
void Solver<Dtype>::InitTestNets() {
  CHECK(Caffe::root_solver());
  const bool has_net_param = param_.has_net_param();
  const bool has_net_file = param_.has_net();
  const int num_generic_nets = has_net_param + has_net_file;
  CHECK_LE(num_generic_nets, 1)
      << "net_param and net may not both be specified.";
  const int num_test_net_params = param_.test_net_param_size();
  const int num_test_net_files = param_.test_net_size();
  const int num_test_nets = num_test_net_params + num_test_net_files;
  if (num_generic_nets) {
    CHECK_GE(param_.test_iter_size(), num_test_nets)
        << "test_iter must be specified for each test network.";
  } else {
    CHECK_EQ(param_.test_iter_size(), num_test_nets)
        << "test_iter must be specified for each test network.";
  }
  // If we have a generic net (specified by net or net_param, rather than
  // test_net or test_net_param), we may have an unlimited number of actual
  // test networks -- the actual number is given by the number of remaining
  // test_iters after any test nets specified by test_net_param and/or
  // test_net are evaluated.
  const int num_generic_net_instances = param_.test_iter_size() - num_test_nets;
  const int num_test_net_instances = num_test_nets + num_generic_net_instances;
  if (param_.test_state_size()) {
    CHECK_EQ(param_.test_state_size(), num_test_net_instances)
        << "test_state must be unspecified or specified once per test net.";
  }
  if (num_test_net_instances) {
    CHECK_GT(param_.test_interval(), 0);
  }
  int test_net_id = 0;
  vector<string> sources(num_test_net_instances);
  vector<NetParameter> net_params(num_test_net_instances);
  for (int i = 0; i < num_test_net_params; ++i, ++test_net_id) {
    sources[test_net_id] = "test_net_param";
    net_params[test_net_id].CopyFrom(param_.test_net_param(i));
  }
  for (int i = 0; i < num_test_net_files; ++i, ++test_net_id) {
    sources[test_net_id] = "test_net file: " + param_.test_net(i);
    ReadNetParamsFromTextFileOrDie(param_.test_net(i),
        &net_params[test_net_id]);
  }
  const int remaining_test_nets = param_.test_iter_size() - test_net_id;
  if (has_net_param) {
    for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
      sources[test_net_id] = "net_param";
      net_params[test_net_id].CopyFrom(param_.net_param());
    }
  }
  if (has_net_file) {
    for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
      sources[test_net_id] = "net file: " + param_.net();
      ReadNetParamsFromTextFileOrDie(param_.net(), &net_params[test_net_id]);
    }
  }
  test_nets_.resize(num_test_net_instances);
  for (int i = 0; i < num_test_net_instances; ++i) {
    // Set the correct NetState.  We start with the solver defaults (lowest
    // precedence); then, merge in any NetState specified by the net_param
    // itself; finally, merge in any NetState specified by the test_state
    // (highest precedence).
    NetState net_state;
    net_state.set_phase(TEST);
    net_state.MergeFrom(net_params[i].state());
    if (param_.test_state_size()) {
      net_state.MergeFrom(param_.test_state(i));
    }
    net_params[i].mutable_state()->CopyFrom(net_state);
    LOG(INFO)
        << "Creating test net (#" << i << ") specified by " << sources[i];
    if (Caffe::root_solver()) {
      test_nets_[i].reset(new Net<Dtype>(net_params[i]));
    } else {
      test_nets_[i].reset(new Net<Dtype>(net_params[i],
          root_solver_->test_nets_[i].get()));
    }
    test_nets_[i]->set_debug_info(param_.debug_info());
  }
}
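
// Illustrative solver config sketch (the values are made up): one generic
// net plus one dedicated test net yields two test net instances, so two
// test_iter entries are required.  test_net/test_net_param instances are
// assigned first; remaining test_iter entries go to the generic net:
//
//   net: "train_val.prototxt"
//   test_net: "extra_test.prototxt"
//   test_iter: 100   # for extra_test.prototxt
//   test_iter: 50    # for the generic net instance
//   test_interval: 500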

template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
  const int start_iter = iter_;
  const int stop_iter = iter_ + iters;
  int average_loss = this->param_.average_loss();
  losses_.clear();
  smoothed_loss_ = 0;

  while (iter_ < stop_iter) {
    // zero-init the params
    net_->ClearParamDiffs();
    if (param_.test_interval() && iter_ % param_.test_interval() == 0
        && (iter_ > 0 || param_.test_initialization())
        && Caffe::root_solver()) {
      TestAll();
      if (requested_early_exit_) {
        // Break out of the while loop because stop was requested while testing.
        break;
      }
    }

    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_start();
    }
    const bool display = param_.display() && iter_ % param_.display() == 0;
    net_->set_debug_info(display && param_.debug_info());
    // accumulate the loss and gradient
    Dtype loss = 0;
    for (int i = 0; i < param_.iter_size(); ++i) {
      loss += net_->ForwardBackward();
    }
    loss /= param_.iter_size();
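
    // Gradients accumulate across the iter_size forward/backward passes
    // above, so each weight update sees an effective batch of
    // batch_size * iter_size.  Illustrative numbers: a data layer
    // batch_size of 32 with iter_size 4 behaves like one batch of 128
    // per update, at the cost of four passes per update.
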
    // average the loss across iterations for smoothed reporting
    UpdateSmoothedLoss(loss, start_iter, average_loss);
    if (display) {
      LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
          << ", loss = " << smoothed_loss_;
      const vector<Blob<Dtype>*>& result = net_->output_blobs();
      int score_index = 0;
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();
        const string& output_name =
            net_->blob_names()[net_->output_blob_indices()[j]];
        const Dtype loss_weight =
            net_->blob_loss_weights()[net_->output_blob_indices()[j]];
        for (int k = 0; k < result[j]->count(); ++k) {
          ostringstream loss_msg_stream;
          if (loss_weight) {
            loss_msg_stream << " (* " << loss_weight
                            << " = " << loss_weight * result_vec[k] << " loss)";
          }
          LOG_IF(INFO, Caffe::root_solver()) << "    Train net output #"
              << score_index++ << ": " << output_name << " = "
              << result_vec[k] << loss_msg_stream.str();
        }
      }
    }
    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_gradients_ready();
    }
    ApplyUpdate();

    // Increment the internal iter_ counter -- its value should always indicate
    // the number of times the weights have been updated.
    ++iter_;

    SolverAction::Enum request = GetRequestedAction();

    // Save a snapshot if needed.
    if ((param_.snapshot()
         && iter_ % param_.snapshot() == 0
         && Caffe::root_solver()) ||
         (request == SolverAction::SNAPSHOT)) {
      Snapshot();
    }
    if (SolverAction::STOP == request) {
      requested_early_exit_ = true;
      // Break out of training loop.
      break;
    }
  }
}

template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
  CHECK(Caffe::root_solver());
  LOG(INFO) << "Solving " << net_->name();
  LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();

  // Initialize to false every time we start solving.
  requested_early_exit_ = false;

  if (resume_file) {
    LOG(INFO) << "Restoring previous solver status from " << resume_file;
    Restore(resume_file);
  }

  // For a network that is trained by the solver, no bottom or top vecs
  // should be given, and we will just provide dummy vecs.
  int start_iter = iter_;
  Step(param_.max_iter() - iter_);
  // If we haven't already, save a snapshot after optimization, unless
  // overridden by setting snapshot_after_train := false.
  if (param_.snapshot_after_train()
      && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
    Snapshot();
  }
  if (requested_early_exit_) {
    LOG(INFO) << "Optimization stopped early.";
    return;
  }
  // After the optimization is done, run an additional train and test pass to
  // display the train and test loss/outputs if appropriate (based on the
  // display and test_interval settings, respectively).  Unlike in the rest of
  // training, for the train net we only run a forward pass as we've already
  // updated the parameters "max_iter" times -- this final pass is only done to
  // display the loss, which is computed in the forward pass.
  if (param_.display() && iter_ % param_.display() == 0) {
    int average_loss = this->param_.average_loss();
    Dtype loss;
    net_->Forward(&loss);

    UpdateSmoothedLoss(loss, start_iter, average_loss);

    LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;
  }
  if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
    TestAll();
  }
  LOG(INFO) << "Optimization Done.";
}
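
// Illustrative driver code (file names are made up): a concrete subclass
// such as SGDSolver is constructed from a solver prototxt, then Solve()
// runs the whole optimization, optionally resuming from a saved state:
//
//   caffe::SGDSolver<float> solver("solver.prototxt");
//   solver.Solve();  // or: solver.Solve("mynet_iter_1000.solverstate");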

template <typename Dtype>
void Solver<Dtype>::TestAll() {
  for (int test_net_id = 0;
       test_net_id < test_nets_.size() && !requested_early_exit_;
       ++test_net_id) {
    Test(test_net_id);
  }
}

template <typename Dtype>
void Solver<Dtype>::Test(const int test_net_id) {
  CHECK(Caffe::root_solver());
  LOG(INFO) << "Iteration " << iter_
            << ", Testing net (#" << test_net_id << ")";
  CHECK_NOTNULL(test_nets_[test_net_id].get())->
      ShareTrainedLayersWith(net_.get());
  vector<Dtype> test_score;
  vector<int> test_score_output_id;
  const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
  Dtype loss = 0;
  for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
    SolverAction::Enum request = GetRequestedAction();
    // Check to see if stoppage of testing/training has been requested.
    while (request != SolverAction::NONE) {
      if (SolverAction::SNAPSHOT == request) {
        Snapshot();
      } else if (SolverAction::STOP == request) {
        requested_early_exit_ = true;
      }
      request = GetRequestedAction();
    }
    if (requested_early_exit_) {
      // Break out of test loop.
      break;
    }

    Dtype iter_loss;
    const vector<Blob<Dtype>*>& result =
        test_net->Forward(&iter_loss);
    if (param_.test_compute_loss()) {
      loss += iter_loss;
    }
    if (i == 0) {
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();
        for (int k = 0; k < result[j]->count(); ++k) {
          test_score.push_back(result_vec[k]);
          test_score_output_id.push_back(j);
        }
      }
    } else {
      int idx = 0;
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();
        for (int k = 0; k < result[j]->count(); ++k) {
          test_score[idx++] += result_vec[k];
        }
      }
    }
  }
  if (requested_early_exit_) {
    LOG(INFO) << "Test interrupted.";
    return;
  }
  if (param_.test_compute_loss()) {
    loss /= param_.test_iter(test_net_id);
    LOG(INFO) << "Test loss: " << loss;
  }
  for (int i = 0; i < test_score.size(); ++i) {
    const int output_blob_index =
        test_net->output_blob_indices()[test_score_output_id[i]];
    const string& output_name = test_net->blob_names()[output_blob_index];
    const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index];
    ostringstream loss_msg_stream;
    const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id);
    if (loss_weight) {
      loss_msg_stream << " (* " << loss_weight
                      << " = " << loss_weight * mean_score << " loss)";
    }
    LOG(INFO) << "    Test net output #" << i << ": " << output_name << " = "
              << mean_score << loss_msg_stream.str();
  }
}
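
// Illustrative arithmetic: with test_iter(test_net_id) == 50 and an
// "accuracy" output averaging 0.9 per batch, test_score[i] accumulates to
// roughly 45 across the loop, and the report above prints 45 / 50 = 0.9.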

template <typename Dtype>
void Solver<Dtype>::Snapshot() {
  CHECK(Caffe::root_solver());
  string model_filename;
  switch (param_.snapshot_format()) {
  case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
    model_filename = SnapshotToBinaryProto();
    break;
  case caffe::SolverParameter_SnapshotFormat_HDF5:
    model_filename = SnapshotToHDF5();
    break;
  default:
    LOG(FATAL) << "Unsupported snapshot format.";
  }

  SnapshotSolverState(model_filename);
}
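
// Context note: in SolverParameter (solver.proto), snapshot_format defaults
// to BINARYPROTO, so SnapshotToBinaryProto() is the usual path; HDF5
// snapshots require an explicit "snapshot_format: HDF5" in the solver config.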

template <typename Dtype>
void Solver<Dtype>::CheckSnapshotWritePermissions() {
  if (Caffe::root_solver() && param_.snapshot()) {
    CHECK(param_.has_snapshot_prefix())
        << "In solver params, snapshot is specified but snapshot_prefix is not";
    string probe_filename = SnapshotFilename(".tempfile");
    std::ofstream probe_ofs(probe_filename.c_str());
    if (probe_ofs.good()) {
      probe_ofs.close();
      std::remove(probe_filename.c_str());
    } else {
      LOG(FATAL) << "Cannot write to snapshot prefix '"
          << param_.snapshot_prefix() << "'.  Make sure "
          << "that the directory exists and is writeable.";
    }
  }
}

template <typename Dtype>
string Solver<Dtype>::SnapshotFilename(const string extension) {
  return param_.snapshot_prefix() + "_iter_" + caffe::format_int(iter_)
      + extension;
}
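
// Example: with snapshot_prefix "models/mynet" (an arbitrary value) and
// iter_ == 10000, SnapshotFilename(".caffemodel") returns
// "models/mynet_iter_10000.caffemodel".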

template <typename Dtype>
string Solver<Dtype>::SnapshotToBinaryProto() {
  string model_filename = SnapshotFilename(".caffemodel");
  LOG(INFO) << "Snapshotting to binary proto file " << model_filename;
  NetParameter net_param;
  net_->ToProto(&net_param, param_.snapshot_diff());
  WriteProtoToBinaryFile(net_param, model_filename);
  return model_filename;
}

template <typename Dtype>
string Solver<Dtype>::SnapshotToHDF5() {
  string model_filename = SnapshotFilename(".caffemodel.h5");
  LOG(INFO) << "Snapshotting to HDF5 file " << model_filename;
  net_->ToHDF5(model_filename, param_.snapshot_diff());
  return model_filename;
}

template <typename Dtype>
void Solver<Dtype>::Restore(const char* state_file) {
  CHECK(Caffe::root_solver());
  string state_filename(state_file);
  if (state_filename.size() >= 3 &&
      state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) {
    RestoreSolverStateFromHDF5(state_filename);
  } else {
    RestoreSolverStateFromBinaryProto(state_filename);
  }
}

template <typename Dtype>
void Solver<Dtype>::UpdateSmoothedLoss(Dtype loss, int start_iter,
    int average_loss) {
  if (losses_.size() < average_loss) {
    // Window not yet full: keep a cumulative mean of the losses seen so far.
    losses_.push_back(loss);
    int size = losses_.size();
    smoothed_loss_ = (smoothed_loss_ * (size - 1) + loss) / size;
  } else {
    // Window full: replace the oldest entry and update the sliding mean.
    int idx = (iter_ - start_iter) % average_loss;
    smoothed_loss_ += (loss - losses_[idx]) / average_loss;
    losses_[idx] = loss;
  }
}
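
// Worked example with illustrative numbers, average_loss == 3 and raw
// losses 2, 4, 6, 3:
//   iter 0: cumulative mean 2 / 1              -> smoothed_loss_ = 2
//   iter 1: (2 * 1 + 4) / 2                    -> smoothed_loss_ = 3
//   iter 2: (3 * 2 + 6) / 3                    -> smoothed_loss_ = 4
//   iter 3: window full, oldest entry (2) replaced:
//           4 + (3 - 2) / 3 = 4.33, the sliding mean of {4, 6, 3}.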

INSTANTIATE_CLASS(Solver);

}  // namespace caffe