ece3913e88ad01939488851f19aacaf24bce58fe
[platform/upstream/caffeonacl.git] / src / caffe / solver.cpp
1 #include <cstdio>
2
3 #include <string>
4 #include <vector>
5
6 #include "caffe/solver.hpp"
7 #include "caffe/util/format.hpp"
8 #include "caffe/util/hdf5.hpp"
9 #include "caffe/util/io.hpp"
10 #include "caffe/util/upgrade_proto.hpp"
11
12 namespace caffe {
13
14 template<typename Dtype>
15 void Solver<Dtype>::SetActionFunction(ActionCallback func) {
16   action_request_function_ = func;
17 }
18
19 template<typename Dtype>
20 SolverAction::Enum Solver<Dtype>::GetRequestedAction() {
21   if (action_request_function_) {
22     // If the external request function has been set, call it.
23     return action_request_function_();
24   }
25   return SolverAction::NONE;
26 }
27
28 template <typename Dtype>
29 Solver<Dtype>::Solver(const SolverParameter& param, const Solver* root_solver)
30     : net_(), callbacks_(), root_solver_(root_solver),
31       requested_early_exit_(false) {
32   Init(param);
33 }
34
35 template <typename Dtype>
36 Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
37     : net_(), callbacks_(), root_solver_(root_solver),
38       requested_early_exit_(false) {
39   SolverParameter param;
40   ReadSolverParamsFromTextFileOrDie(param_file, &param);
41   Init(param);
42 }
43
44 template <typename Dtype>
45 void Solver<Dtype>::Init(const SolverParameter& param) {
46   CHECK(Caffe::root_solver() || root_solver_)
47       << "root_solver_ needs to be set for all non-root solvers";
48   LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
49     << std::endl << param.DebugString();
50   param_ = param;
51   CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
52   CheckSnapshotWritePermissions();
53   if (Caffe::root_solver() && param_.random_seed() >= 0) {
54     Caffe::set_random_seed(param_.random_seed());
55   }
56   // Scaffolding code
57   InitTrainNet();
58   if (Caffe::root_solver()) {
59     InitTestNets();
60     LOG(INFO) << "Solver scaffolding done.";
61   }
62   iter_ = 0;
63   current_step_ = 0;
64 }
65
66 template <typename Dtype>
67 void Solver<Dtype>::InitTrainNet() {
68   const int num_train_nets = param_.has_net() + param_.has_net_param() +
69       param_.has_train_net() + param_.has_train_net_param();
70   const string& field_names = "net, net_param, train_net, train_net_param";
71   CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net "
72       << "using one of these fields: " << field_names;
73   CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than "
74       << "one of these fields specifying a train_net: " << field_names;
75   NetParameter net_param;
76   if (param_.has_train_net_param()) {
77     LOG_IF(INFO, Caffe::root_solver())
78         << "Creating training net specified in train_net_param.";
79     net_param.CopyFrom(param_.train_net_param());
80   } else if (param_.has_train_net()) {
81     LOG_IF(INFO, Caffe::root_solver())
82         << "Creating training net from train_net file: " << param_.train_net();
83     ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param);
84   }
85   if (param_.has_net_param()) {
86     LOG_IF(INFO, Caffe::root_solver())
87         << "Creating training net specified in net_param.";
88     net_param.CopyFrom(param_.net_param());
89   }
90   if (param_.has_net()) {
91     LOG_IF(INFO, Caffe::root_solver())
92         << "Creating training net from net file: " << param_.net();
93     ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);
94   }
95   // Set the correct NetState.  We start with the solver defaults (lowest
96   // precedence); then, merge in any NetState specified by the net_param itself;
97   // finally, merge in any NetState specified by the train_state (highest
98   // precedence).
99   NetState net_state;
100   net_state.set_phase(TRAIN);
101   net_state.MergeFrom(net_param.state());
102   net_state.MergeFrom(param_.train_state());
103   net_param.mutable_state()->CopyFrom(net_state);
104   if (Caffe::root_solver()) {
105     net_.reset(new Net<Dtype>(net_param));
106   } else {
107     net_.reset(new Net<Dtype>(net_param, root_solver_->net_.get()));
108   }
109 }
110
111 template <typename Dtype>
112 void Solver<Dtype>::InitTestNets() {
113   CHECK(Caffe::root_solver());
114   const bool has_net_param = param_.has_net_param();
115   const bool has_net_file = param_.has_net();
116   const int num_generic_nets = has_net_param + has_net_file;
117   CHECK_LE(num_generic_nets, 1)
118       << "Both net_param and net_file may not be specified.";
119   const int num_test_net_params = param_.test_net_param_size();
120   const int num_test_net_files = param_.test_net_size();
121   const int num_test_nets = num_test_net_params + num_test_net_files;
122   if (num_generic_nets) {
123       CHECK_GE(param_.test_iter_size(), num_test_nets)
124           << "test_iter must be specified for each test network.";
125   } else {
126       CHECK_EQ(param_.test_iter_size(), num_test_nets)
127           << "test_iter must be specified for each test network.";
128   }
129   // If we have a generic net (specified by net or net_param, rather than
130   // test_net or test_net_param), we may have an unlimited number of actual
131   // test networks -- the actual number is given by the number of remaining
132   // test_iters after any test nets specified by test_net_param and/or test_net
133   // are evaluated.
134   const int num_generic_net_instances = param_.test_iter_size() - num_test_nets;
135   const int num_test_net_instances = num_test_nets + num_generic_net_instances;
136   if (param_.test_state_size()) {
137     CHECK_EQ(param_.test_state_size(), num_test_net_instances)
138         << "test_state must be unspecified or specified once per test net.";
139   }
140   if (num_test_net_instances) {
141     CHECK_GT(param_.test_interval(), 0);
142   }
143   int test_net_id = 0;
144   vector<string> sources(num_test_net_instances);
145   vector<NetParameter> net_params(num_test_net_instances);
146   for (int i = 0; i < num_test_net_params; ++i, ++test_net_id) {
147       sources[test_net_id] = "test_net_param";
148       net_params[test_net_id].CopyFrom(param_.test_net_param(i));
149   }
150   for (int i = 0; i < num_test_net_files; ++i, ++test_net_id) {
151       sources[test_net_id] = "test_net file: " + param_.test_net(i);
152       ReadNetParamsFromTextFileOrDie(param_.test_net(i),
153           &net_params[test_net_id]);
154   }
155   const int remaining_test_nets = param_.test_iter_size() - test_net_id;
156   if (has_net_param) {
157     for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
158       sources[test_net_id] = "net_param";
159       net_params[test_net_id].CopyFrom(param_.net_param());
160     }
161   }
162   if (has_net_file) {
163     for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
164       sources[test_net_id] = "net file: " + param_.net();
165       ReadNetParamsFromTextFileOrDie(param_.net(), &net_params[test_net_id]);
166     }
167   }
168   test_nets_.resize(num_test_net_instances);
169   for (int i = 0; i < num_test_net_instances; ++i) {
170     // Set the correct NetState.  We start with the solver defaults (lowest
171     // precedence); then, merge in any NetState specified by the net_param
172     // itself; finally, merge in any NetState specified by the test_state
173     // (highest precedence).
174     NetState net_state;
175     net_state.set_phase(TEST);
176     net_state.MergeFrom(net_params[i].state());
177     if (param_.test_state_size()) {
178       net_state.MergeFrom(param_.test_state(i));
179     }
180     net_params[i].mutable_state()->CopyFrom(net_state);
181     LOG(INFO)
182         << "Creating test net (#" << i << ") specified by " << sources[i];
183     if (Caffe::root_solver()) {
184       test_nets_[i].reset(new Net<Dtype>(net_params[i]));
185     } else {
186       test_nets_[i].reset(new Net<Dtype>(net_params[i],
187           root_solver_->test_nets_[i].get()));
188     }
189     test_nets_[i]->set_debug_info(param_.debug_info());
190   }
191 }
192
193 template <typename Dtype>
194 void Solver<Dtype>::Step(int iters) {
195   const int start_iter = iter_;
196   const int stop_iter = iter_ + iters;
197   int average_loss = this->param_.average_loss();
198   losses_.clear();
199   smoothed_loss_ = 0;
200
201   while (iter_ < stop_iter) {
202     // zero-init the params
203     net_->ClearParamDiffs();
204     if (param_.test_interval() && iter_ % param_.test_interval() == 0
205         && (iter_ > 0 || param_.test_initialization())
206         && Caffe::root_solver()) {
207       TestAll();
208       if (requested_early_exit_) {
209         // Break out of the while loop because stop was requested while testing.
210         break;
211       }
212     }
213
214     for (int i = 0; i < callbacks_.size(); ++i) {
215       callbacks_[i]->on_start();
216     }
217     const bool display = param_.display() && iter_ % param_.display() == 0;
218     net_->set_debug_info(display && param_.debug_info());
219     // accumulate the loss and gradient
220     Dtype loss = 0;
221     for (int i = 0; i < param_.iter_size(); ++i) {
222       loss += net_->ForwardBackward();
223     }
224     loss /= param_.iter_size();
225     // average the loss across iterations for smoothed reporting
226     UpdateSmoothedLoss(loss, start_iter, average_loss);
227     if (display) {
228       LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
229           << ", loss = " << smoothed_loss_;
230       const vector<Blob<Dtype>*>& result = net_->output_blobs();
231       int score_index = 0;
232       for (int j = 0; j < result.size(); ++j) {
233         const Dtype* result_vec = result[j]->cpu_data();
234         const string& output_name =
235             net_->blob_names()[net_->output_blob_indices()[j]];
236         const Dtype loss_weight =
237             net_->blob_loss_weights()[net_->output_blob_indices()[j]];
238         for (int k = 0; k < result[j]->count(); ++k) {
239           ostringstream loss_msg_stream;
240           if (loss_weight) {
241             loss_msg_stream << " (* " << loss_weight
242                             << " = " << loss_weight * result_vec[k] << " loss)";
243           }
244           LOG_IF(INFO, Caffe::root_solver()) << "    Train net output #"
245               << score_index++ << ": " << output_name << " = "
246               << result_vec[k] << loss_msg_stream.str();
247         }
248       }
249     }
250     for (int i = 0; i < callbacks_.size(); ++i) {
251       callbacks_[i]->on_gradients_ready();
252     }
253     ApplyUpdate();
254
255     // Increment the internal iter_ counter -- its value should always indicate
256     // the number of times the weights have been updated.
257     ++iter_;
258
259     SolverAction::Enum request = GetRequestedAction();
260
261     // Save a snapshot if needed.
262     if ((param_.snapshot()
263          && iter_ % param_.snapshot() == 0
264          && Caffe::root_solver()) ||
265          (request == SolverAction::SNAPSHOT)) {
266       Snapshot();
267     }
268     if (SolverAction::STOP == request) {
269       requested_early_exit_ = true;
270       // Break out of training loop.
271       break;
272     }
273   }
274 }
275
276 template <typename Dtype>
277 void Solver<Dtype>::Solve(const char* resume_file) {
278   CHECK(Caffe::root_solver());
279   LOG(INFO) << "Solving " << net_->name();
280   LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
281
282   // Initialize to false every time we start solving.
283   requested_early_exit_ = false;
284
285   if (resume_file) {
286     LOG(INFO) << "Restoring previous solver status from " << resume_file;
287     Restore(resume_file);
288   }
289
290   // For a network that is trained by the solver, no bottom or top vecs
291   // should be given, and we will just provide dummy vecs.
292   int start_iter = iter_;
293   Step(param_.max_iter() - iter_);
294   // If we haven't already, save a snapshot after optimization, unless
295   // overridden by setting snapshot_after_train := false
296   if (param_.snapshot_after_train()
297       && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
298     Snapshot();
299   }
300   if (requested_early_exit_) {
301     LOG(INFO) << "Optimization stopped early.";
302     return;
303   }
304   // After the optimization is done, run an additional train and test pass to
305   // display the train and test loss/outputs if appropriate (based on the
306   // display and test_interval settings, respectively).  Unlike in the rest of
307   // training, for the train net we only run a forward pass as we've already
308   // updated the parameters "max_iter" times -- this final pass is only done to
309   // display the loss, which is computed in the forward pass.
310   if (param_.display() && iter_ % param_.display() == 0) {
311     int average_loss = this->param_.average_loss();
312     Dtype loss;
313     net_->Forward(&loss);
314
315     UpdateSmoothedLoss(loss, start_iter, average_loss);
316
317     LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;
318   }
319   if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
320     TestAll();
321   }
322   LOG(INFO) << "Optimization Done.";
323 }
324
325 template <typename Dtype>
326 void Solver<Dtype>::TestAll() {
327   for (int test_net_id = 0;
328        test_net_id < test_nets_.size() && !requested_early_exit_;
329        ++test_net_id) {
330     Test(test_net_id);
331   }
332 }
333
334 template <typename Dtype>
335 void Solver<Dtype>::Test(const int test_net_id) {
336   CHECK(Caffe::root_solver());
337   LOG(INFO) << "Iteration " << iter_
338             << ", Testing net (#" << test_net_id << ")";
339   CHECK_NOTNULL(test_nets_[test_net_id].get())->
340       ShareTrainedLayersWith(net_.get());
341   vector<Dtype> test_score;
342   vector<int> test_score_output_id;
343   const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
344   Dtype loss = 0;
345   for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
346     SolverAction::Enum request = GetRequestedAction();
347     // Check to see if stoppage of testing/training has been requested.
348     while (request != SolverAction::NONE) {
349         if (SolverAction::SNAPSHOT == request) {
350           Snapshot();
351         } else if (SolverAction::STOP == request) {
352           requested_early_exit_ = true;
353         }
354         request = GetRequestedAction();
355     }
356     if (requested_early_exit_) {
357       // break out of test loop.
358       break;
359     }
360
361     Dtype iter_loss;
362     const vector<Blob<Dtype>*>& result =
363         test_net->Forward(&iter_loss);
364     if (param_.test_compute_loss()) {
365       loss += iter_loss;
366     }
367     if (i == 0) {
368       for (int j = 0; j < result.size(); ++j) {
369         const Dtype* result_vec = result[j]->cpu_data();
370         for (int k = 0; k < result[j]->count(); ++k) {
371           test_score.push_back(result_vec[k]);
372           test_score_output_id.push_back(j);
373         }
374       }
375     } else {
376       int idx = 0;
377       for (int j = 0; j < result.size(); ++j) {
378         const Dtype* result_vec = result[j]->cpu_data();
379         for (int k = 0; k < result[j]->count(); ++k) {
380           test_score[idx++] += result_vec[k];
381         }
382       }
383     }
384   }
385   if (requested_early_exit_) {
386     LOG(INFO)     << "Test interrupted.";
387     return;
388   }
389   if (param_.test_compute_loss()) {
390     loss /= param_.test_iter(test_net_id);
391     LOG(INFO) << "Test loss: " << loss;
392   }
393   for (int i = 0; i < test_score.size(); ++i) {
394     const int output_blob_index =
395         test_net->output_blob_indices()[test_score_output_id[i]];
396     const string& output_name = test_net->blob_names()[output_blob_index];
397     const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index];
398     ostringstream loss_msg_stream;
399     const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id);
400     if (loss_weight) {
401       loss_msg_stream << " (* " << loss_weight
402                       << " = " << loss_weight * mean_score << " loss)";
403     }
404     LOG(INFO) << "    Test net output #" << i << ": " << output_name << " = "
405               << mean_score << loss_msg_stream.str();
406   }
407 }
408
409 template <typename Dtype>
410 void Solver<Dtype>::Snapshot() {
411   CHECK(Caffe::root_solver());
412   string model_filename;
413   switch (param_.snapshot_format()) {
414   case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
415     model_filename = SnapshotToBinaryProto();
416     break;
417   case caffe::SolverParameter_SnapshotFormat_HDF5:
418     model_filename = SnapshotToHDF5();
419     break;
420   default:
421     LOG(FATAL) << "Unsupported snapshot format.";
422   }
423
424   SnapshotSolverState(model_filename);
425 }
426
427 template <typename Dtype>
428 void Solver<Dtype>::CheckSnapshotWritePermissions() {
429   if (Caffe::root_solver() && param_.snapshot()) {
430     CHECK(param_.has_snapshot_prefix())
431         << "In solver params, snapshot is specified but snapshot_prefix is not";
432     string probe_filename = SnapshotFilename(".tempfile");
433     std::ofstream probe_ofs(probe_filename.c_str());
434     if (probe_ofs.good()) {
435       probe_ofs.close();
436       std::remove(probe_filename.c_str());
437     } else {
438       LOG(FATAL) << "Cannot write to snapshot prefix '"
439           << param_.snapshot_prefix() << "'.  Make sure "
440           << "that the directory exists and is writeable.";
441     }
442   }
443 }
444
445 template <typename Dtype>
446 string Solver<Dtype>::SnapshotFilename(const string extension) {
447   return param_.snapshot_prefix() + "_iter_" + caffe::format_int(iter_)
448     + extension;
449 }
450
451 template <typename Dtype>
452 string Solver<Dtype>::SnapshotToBinaryProto() {
453   string model_filename = SnapshotFilename(".caffemodel");
454   LOG(INFO) << "Snapshotting to binary proto file " << model_filename;
455   NetParameter net_param;
456   net_->ToProto(&net_param, param_.snapshot_diff());
457   WriteProtoToBinaryFile(net_param, model_filename);
458   return model_filename;
459 }
460
461 template <typename Dtype>
462 string Solver<Dtype>::SnapshotToHDF5() {
463   string model_filename = SnapshotFilename(".caffemodel.h5");
464   LOG(INFO) << "Snapshotting to HDF5 file " << model_filename;
465   net_->ToHDF5(model_filename, param_.snapshot_diff());
466   return model_filename;
467 }
468
469 template <typename Dtype>
470 void Solver<Dtype>::Restore(const char* state_file) {
471   CHECK(Caffe::root_solver());
472   string state_filename(state_file);
473   if (state_filename.size() >= 3 &&
474       state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) {
475     RestoreSolverStateFromHDF5(state_filename);
476   } else {
477     RestoreSolverStateFromBinaryProto(state_filename);
478   }
479 }
480
481 template <typename Dtype>
482 void Solver<Dtype>::UpdateSmoothedLoss(Dtype loss, int start_iter,
483     int average_loss) {
484   if (losses_.size() < average_loss) {
485     losses_.push_back(loss);
486     int size = losses_.size();
487     smoothed_loss_ = (smoothed_loss_ * (size - 1) + loss) / size;
488   } else {
489     int idx = (iter_ - start_iter) % average_loss;
490     smoothed_loss_ += (loss - losses_[idx]) / average_loss;
491     losses_[idx] = loss;
492   }
493 }
494
495 INSTANTIATE_CLASS(Solver);
496
497 }  // namespace caffe