#include <cstdio>

#include <string>
#include <vector>

#include "caffe/solver.hpp"
#include "caffe/util/format.hpp"
#include "caffe/util/hdf5.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/upgrade_proto.hpp"

namespace caffe {
template<typename Dtype>
void Solver<Dtype>::SetActionFunction(ActionCallback func) {
  action_request_function_ = func;
}
template<typename Dtype>
SolverAction::Enum Solver<Dtype>::GetRequestedAction() {
  if (action_request_function_) {
    // If the external request function has been set, call it.
    return action_request_function_();
  }
  return SolverAction::NONE;
}
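// Both constructors defer the real setup work to Init(); the second one
// simply parses the SolverParameter from a prototxt file first.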
template <typename Dtype>
Solver<Dtype>::Solver(const SolverParameter& param, const Solver* root_solver)
    : net_(), callbacks_(), root_solver_(root_solver),
      requested_early_exit_(false) {
  Init(param);
}
template <typename Dtype>
Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
    : net_(), callbacks_(), root_solver_(root_solver),
      requested_early_exit_(false) {
  SolverParameter param;
  ReadSolverParamsFromTextFileOrDie(param_file, &param);
  Init(param);
}
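// Init() validates the SolverParameter, seeds the RNG if requested, and
// builds the train net (and, on the root solver, the test nets).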
template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
  CHECK(Caffe::root_solver() || root_solver_)
      << "root_solver_ needs to be set for all non-root solvers";
  LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
      << std::endl << param.DebugString();
  param_ = param;
  CHECK_GE(param_.average_loss(), 1) << "average_loss should be at least 1.";
  CheckSnapshotWritePermissions();
  if (Caffe::root_solver() && param_.random_seed() >= 0) {
    Caffe::set_random_seed(param_.random_seed());
  }
  // Scaffolding code
  InitTrainNet();
  if (Caffe::root_solver()) {
    InitTestNets();
    LOG(INFO) << "Solver scaffolding done.";
  }
  iter_ = 0;
  current_step_ = 0;
}
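// A training net can be specified in four ways: as an inline NetParameter
// (net_param / train_net_param) or as a prototxt file (net / train_net).
// Exactly one of the four fields must be set.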
template <typename Dtype>
void Solver<Dtype>::InitTrainNet() {
  const int num_train_nets = param_.has_net() + param_.has_net_param() +
      param_.has_train_net() + param_.has_train_net_param();
  const string& field_names = "net, net_param, train_net, train_net_param";
  CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net "
      << "using one of these fields: " << field_names;
  CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than "
      << "one of these fields specifying a train_net: " << field_names;
  NetParameter net_param;
  if (param_.has_train_net_param()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net specified in train_net_param.";
    net_param.CopyFrom(param_.train_net_param());
  } else if (param_.has_train_net()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net from train_net file: " << param_.train_net();
    ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param);
  }
  if (param_.has_net_param()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net specified in net_param.";
    net_param.CopyFrom(param_.net_param());
  }
  if (param_.has_net()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net from net file: " << param_.net();
    ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);
  }
  // Set the correct NetState. We start with the solver defaults (lowest
  // precedence); then, merge in any NetState specified by the net_param itself;
  // finally, merge in any NetState specified by the train_state (highest
  // precedence).
  NetState net_state;
  net_state.set_phase(TRAIN);
  net_state.MergeFrom(net_param.state());
  net_state.MergeFrom(param_.train_state());
  net_param.mutable_state()->CopyFrom(net_state);
  if (Caffe::root_solver()) {
    net_.reset(new Net<Dtype>(net_param));
  } else {
    net_.reset(new Net<Dtype>(net_param, root_solver_->net_.get()));
  }
}
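// Test nets may be listed explicitly (test_net_param / test_net); any
// remaining test_iter entries instantiate copies of the generic net
// (net_param / net), so a single model definition can be evaluated under
// several test_state configurations.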
template <typename Dtype>
void Solver<Dtype>::InitTestNets() {
  CHECK(Caffe::root_solver());
  const bool has_net_param = param_.has_net_param();
  const bool has_net_file = param_.has_net();
  const int num_generic_nets = has_net_param + has_net_file;
  CHECK_LE(num_generic_nets, 1)
      << "Both net_param and net_file may not be specified.";
  const int num_test_net_params = param_.test_net_param_size();
  const int num_test_net_files = param_.test_net_size();
  const int num_test_nets = num_test_net_params + num_test_net_files;
  if (num_generic_nets) {
    CHECK_GE(param_.test_iter_size(), num_test_nets)
        << "test_iter must be specified for each test network.";
  } else {
    CHECK_EQ(param_.test_iter_size(), num_test_nets)
        << "test_iter must be specified for each test network.";
  }
  // If we have a generic net (specified by net or net_param, rather than
  // test_net or test_net_param), we may have an unlimited number of actual
  // test networks -- the actual number is given by the number of remaining
  // test_iters after any test nets specified by test_net_param and/or test_net
  // are evaluated.
  const int num_generic_net_instances = param_.test_iter_size() - num_test_nets;
  const int num_test_net_instances = num_test_nets + num_generic_net_instances;
  if (param_.test_state_size()) {
    CHECK_EQ(param_.test_state_size(), num_test_net_instances)
        << "test_state must be unspecified or specified once per test net.";
  }
  if (num_test_net_instances) {
    CHECK_GT(param_.test_interval(), 0);
  }
  int test_net_id = 0;
  vector<string> sources(num_test_net_instances);
  vector<NetParameter> net_params(num_test_net_instances);
  for (int i = 0; i < num_test_net_params; ++i, ++test_net_id) {
    sources[test_net_id] = "test_net_param";
    net_params[test_net_id].CopyFrom(param_.test_net_param(i));
  }
  for (int i = 0; i < num_test_net_files; ++i, ++test_net_id) {
    sources[test_net_id] = "test_net file: " + param_.test_net(i);
    ReadNetParamsFromTextFileOrDie(param_.test_net(i),
        &net_params[test_net_id]);
  }
  const int remaining_test_nets = param_.test_iter_size() - test_net_id;
  if (has_net_param) {
    for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
      sources[test_net_id] = "net_param";
      net_params[test_net_id].CopyFrom(param_.net_param());
    }
  }
  if (has_net_file) {
    for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
      sources[test_net_id] = "net file: " + param_.net();
      ReadNetParamsFromTextFileOrDie(param_.net(), &net_params[test_net_id]);
    }
  }
  test_nets_.resize(num_test_net_instances);
  for (int i = 0; i < num_test_net_instances; ++i) {
    // Set the correct NetState. We start with the solver defaults (lowest
    // precedence); then, merge in any NetState specified by the net_param
    // itself; finally, merge in any NetState specified by the test_state
    // (highest precedence).
    NetState net_state;
    net_state.set_phase(TEST);
    net_state.MergeFrom(net_params[i].state());
    if (param_.test_state_size()) {
      net_state.MergeFrom(param_.test_state(i));
    }
    net_params[i].mutable_state()->CopyFrom(net_state);
    LOG(INFO)
        << "Creating test net (#" << i << ") specified by " << sources[i];
    if (Caffe::root_solver()) {
      test_nets_[i].reset(new Net<Dtype>(net_params[i]));
    } else {
      test_nets_[i].reset(new Net<Dtype>(net_params[i],
          root_solver_->test_nets_[i].get()));
    }
    test_nets_[i]->set_debug_info(param_.debug_info());
  }
}
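// Step() runs the main training loop: each iteration clears the accumulated
// gradients, optionally tests, runs iter_size forward/backward passes, lets
// registered callbacks synchronize gradients (multi-GPU), applies the
// parameter update, and then handles any snapshot/stop request.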
template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
  const int start_iter = iter_;
  const int stop_iter = iter_ + iters;
  int average_loss = this->param_.average_loss();
  losses_.clear();
  smoothed_loss_ = 0;

  while (iter_ < stop_iter) {
    // zero-init the params
    net_->ClearParamDiffs();
    if (param_.test_interval() && iter_ % param_.test_interval() == 0
        && (iter_ > 0 || param_.test_initialization())
        && Caffe::root_solver()) {
      TestAll();
      if (requested_early_exit_) {
        // Break out of the while loop because stop was requested while testing.
        break;
      }
    }

    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_start();
    }
    const bool display = param_.display() && iter_ % param_.display() == 0;
    net_->set_debug_info(display && param_.debug_info());
    // accumulate the loss and gradient
    Dtype loss = 0;
    for (int i = 0; i < param_.iter_size(); ++i) {
      loss += net_->ForwardBackward();
    }
    loss /= param_.iter_size();
    // average the loss across iterations for smoothed reporting
    UpdateSmoothedLoss(loss, start_iter, average_loss);
    if (display) {
      LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
          << ", loss = " << smoothed_loss_;
      const vector<Blob<Dtype>*>& result = net_->output_blobs();
      int score_index = 0;
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();
        const string& output_name =
            net_->blob_names()[net_->output_blob_indices()[j]];
        const Dtype loss_weight =
            net_->blob_loss_weights()[net_->output_blob_indices()[j]];
        for (int k = 0; k < result[j]->count(); ++k) {
          ostringstream loss_msg_stream;
          if (loss_weight) {
            loss_msg_stream << " (* " << loss_weight
                            << " = " << loss_weight * result_vec[k] << " loss)";
          }
          LOG_IF(INFO, Caffe::root_solver()) << "    Train net output #"
              << score_index++ << ": " << output_name << " = "
              << result_vec[k] << loss_msg_stream.str();
        }
      }
    }
    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_gradients_ready();
    }
    ApplyUpdate();

    // Increment the internal iter_ counter -- its value should always indicate
    // the number of times the weights have been updated.
    ++iter_;

    SolverAction::Enum request = GetRequestedAction();

    // Save a snapshot if needed.
    if ((param_.snapshot()
         && iter_ % param_.snapshot() == 0
         && Caffe::root_solver()) ||
        (request == SolverAction::SNAPSHOT)) {
      Snapshot();
    }
    if (SolverAction::STOP == request) {
      requested_early_exit_ = true;
      // Break out of training loop.
      break;
    }
  }
}
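// Solve() drives a full optimization run on the root solver: optionally
// restore saved state, call Step() for the remaining iterations, snapshot the
// final model, and (if display/test_interval line up) report one last
// loss/test pass.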
template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
  CHECK(Caffe::root_solver());
  LOG(INFO) << "Solving " << net_->name();
  LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();

  // Initialize to false every time we start solving.
  requested_early_exit_ = false;

  if (resume_file) {
    LOG(INFO) << "Restoring previous solver status from " << resume_file;
    Restore(resume_file);
  }

  // For a network that is trained by the solver, no bottom or top vecs
  // should be given, and we will just provide dummy vecs.
  int start_iter = iter_;
  Step(param_.max_iter() - iter_);
  // If we haven't already, save a snapshot after optimization, unless
  // overridden by setting snapshot_after_train := false
  if (param_.snapshot_after_train()
      && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
    Snapshot();
  }
  if (requested_early_exit_) {
    LOG(INFO) << "Optimization stopped early.";
    return;
  }
  // After the optimization is done, run an additional train and test pass to
  // display the train and test loss/outputs if appropriate (based on the
  // display and test_interval settings, respectively). Unlike in the rest of
  // training, for the train net we only run a forward pass as we've already
  // updated the parameters "max_iter" times -- this final pass is only done to
  // display the loss, which is computed in the forward pass.
  if (param_.display() && iter_ % param_.display() == 0) {
    int average_loss = this->param_.average_loss();
    Dtype loss;
    net_->Forward(&loss);

    UpdateSmoothedLoss(loss, start_iter, average_loss);

    LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;
  }
  if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
    TestAll();
  }
  LOG(INFO) << "Optimization Done.";
}
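// Run every test net in turn, stopping early if a STOP request arrives.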
template <typename Dtype>
void Solver<Dtype>::TestAll() {
  for (int test_net_id = 0;
       test_net_id < test_nets_.size() && !requested_early_exit_;
       ++test_net_id) {
    Test(test_net_id);
  }
}
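// Test() runs test_iter(test_net_id) forward passes, accumulating every
// output blob value, and then reports the per-output means (scaled by the
// blob's loss weight where one is set).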
template <typename Dtype>
void Solver<Dtype>::Test(const int test_net_id) {
  CHECK(Caffe::root_solver());
  LOG(INFO) << "Iteration " << iter_
            << ", Testing net (#" << test_net_id << ")";
  CHECK_NOTNULL(test_nets_[test_net_id].get())->
      ShareTrainedLayersWith(net_.get());
  vector<Dtype> test_score;
  vector<int> test_score_output_id;
  const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
  Dtype loss = 0;
  for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
    SolverAction::Enum request = GetRequestedAction();
    // Check to see if stoppage of testing/training has been requested.
    while (request != SolverAction::NONE) {
      if (SolverAction::SNAPSHOT == request) {
        Snapshot();
      } else if (SolverAction::STOP == request) {
        requested_early_exit_ = true;
      }
      request = GetRequestedAction();
    }
    if (requested_early_exit_) {
      // break out of test loop.
      break;
    }

    Dtype iter_loss;
    const vector<Blob<Dtype>*>& result =
        test_net->Forward(&iter_loss);
    if (param_.test_compute_loss()) {
      loss += iter_loss;
    }
    if (i == 0) {
      // First iteration: record one score slot per output value.
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();
        for (int k = 0; k < result[j]->count(); ++k) {
          test_score.push_back(result_vec[k]);
          test_score_output_id.push_back(j);
        }
      }
    } else {
      // Subsequent iterations: accumulate into the existing slots.
      int idx = 0;
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();
        for (int k = 0; k < result[j]->count(); ++k) {
          test_score[idx++] += result_vec[k];
        }
      }
    }
  }
  if (requested_early_exit_) {
    LOG(INFO) << "Test interrupted.";
    return;
  }
  if (param_.test_compute_loss()) {
    loss /= param_.test_iter(test_net_id);
    LOG(INFO) << "Test loss: " << loss;
  }
  for (int i = 0; i < test_score.size(); ++i) {
    const int output_blob_index =
        test_net->output_blob_indices()[test_score_output_id[i]];
    const string& output_name = test_net->blob_names()[output_blob_index];
    const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index];
    ostringstream loss_msg_stream;
    const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id);
    if (loss_weight) {
      loss_msg_stream << " (* " << loss_weight
                      << " = " << loss_weight * mean_score << " loss)";
    }
    LOG(INFO) << "    Test net output #" << i << ": " << output_name << " = "
              << mean_score << loss_msg_stream.str();
  }
}
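// Snapshot() writes the model weights in the configured format (binaryproto
// or HDF5) and then asks the concrete solver subclass to persist its own
// state (e.g. the update history for SGD-style solvers) alongside it.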
template <typename Dtype>
void Solver<Dtype>::Snapshot() {
  CHECK(Caffe::root_solver());
  string model_filename;
  switch (param_.snapshot_format()) {
  case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
    model_filename = SnapshotToBinaryProto();
    break;
  case caffe::SolverParameter_SnapshotFormat_HDF5:
    model_filename = SnapshotToHDF5();
    break;
  default:
    LOG(FATAL) << "Unsupported snapshot format.";
  }

  SnapshotSolverState(model_filename);
}
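// Fail fast at startup if snapshots are requested but the snapshot_prefix
// location is missing or not writeable, by probing with a temporary file.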
template <typename Dtype>
void Solver<Dtype>::CheckSnapshotWritePermissions() {
  if (Caffe::root_solver() && param_.snapshot()) {
    CHECK(param_.has_snapshot_prefix())
        << "In solver params, snapshot is specified but snapshot_prefix is not";
    string probe_filename = SnapshotFilename(".tempfile");
    std::ofstream probe_ofs(probe_filename.c_str());
    if (probe_ofs.good()) {
      probe_ofs.close();
      std::remove(probe_filename.c_str());
    } else {
      LOG(FATAL) << "Cannot write to snapshot prefix '"
          << param_.snapshot_prefix() << "'. Make sure "
          << "that the directory exists and is writeable.";
    }
  }
}
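// Snapshot files are named "<snapshot_prefix>_iter_<iter><extension>", where
// the extension is supplied by the caller (e.g. ".caffemodel" below).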
template <typename Dtype>
string Solver<Dtype>::SnapshotFilename(const string extension) {
  return param_.snapshot_prefix() + "_iter_" + caffe::format_int(iter_)
      + extension;
}
template <typename Dtype>
string Solver<Dtype>::SnapshotToBinaryProto() {
  string model_filename = SnapshotFilename(".caffemodel");
  LOG(INFO) << "Snapshotting to binary proto file " << model_filename;
  NetParameter net_param;
  net_->ToProto(&net_param, param_.snapshot_diff());
  WriteProtoToBinaryFile(net_param, model_filename);
  return model_filename;
}
template <typename Dtype>
string Solver<Dtype>::SnapshotToHDF5() {
  string model_filename = SnapshotFilename(".caffemodel.h5");
  LOG(INFO) << "Snapshotting to HDF5 file " << model_filename;
  net_->ToHDF5(model_filename, param_.snapshot_diff());
  return model_filename;
}
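// Restore() dispatches on the state file's extension: ".h5" files are read
// as HDF5 solver state, everything else as a binary SolverState proto.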
template <typename Dtype>
void Solver<Dtype>::Restore(const char* state_file) {
  CHECK(Caffe::root_solver());
  string state_filename(state_file);
  if (state_filename.size() >= 3 &&
      state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) {
    RestoreSolverStateFromHDF5(state_filename);
  } else {
    RestoreSolverStateFromBinaryProto(state_filename);
  }
}
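// UpdateSmoothedLoss() keeps smoothed_loss_ equal to the mean of the last
// average_loss iteration losses, maintained incrementally via a circular
// buffer (losses_). For example, with average_loss = 3 and per-iteration
// losses 2, 4, 6, 8 the reported values are 2, 3, 4 and then
// (4 + 6 + 8) / 3 = 6.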
template <typename Dtype>
void Solver<Dtype>::UpdateSmoothedLoss(Dtype loss, int start_iter,
    int average_loss) {
  if (losses_.size() < average_loss) {
    losses_.push_back(loss);
    int size = losses_.size();
    smoothed_loss_ = (smoothed_loss_ * (size - 1) + loss) / size;
  } else {
    int idx = (iter_ - start_iter) % average_loss;
    smoothed_loss_ += (loss - losses_[idx]) / average_loss;
    losses_[idx] = loss;
  }
}
INSTANTIATE_CLASS(Solver);

}  // namespace caffe