solver: check and set type to reconcile class and proto
[platform/upstream/caffeonacl.git] / src / caffe / solver.cpp
1 #include <cstdio>
2
3 #include <string>
4 #include <vector>
5
6 #include "caffe/solver.hpp"
7 #include "caffe/util/format.hpp"
8 #include "caffe/util/hdf5.hpp"
9 #include "caffe/util/io.hpp"
10 #include "caffe/util/upgrade_proto.hpp"
11
12 namespace caffe {
13
14 template<typename Dtype>
15 void Solver<Dtype>::SetActionFunction(ActionCallback func) {
16   action_request_function_ = func;
17 }
18
19 template<typename Dtype>
20 SolverAction::Enum Solver<Dtype>::GetRequestedAction() {
21   if (action_request_function_) {
22     // If the external request function has been set, call it.
23     return action_request_function_();
24   }
25   return SolverAction::NONE;
26 }
27
28 template <typename Dtype>
29 Solver<Dtype>::Solver(const SolverParameter& param, const Solver* root_solver)
30     : net_(), callbacks_(), root_solver_(root_solver),
31       requested_early_exit_(false) {
32   Init(param);
33 }
34
35 template <typename Dtype>
36 Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
37     : net_(), callbacks_(), root_solver_(root_solver),
38       requested_early_exit_(false) {
39   SolverParameter param;
40   ReadSolverParamsFromTextFileOrDie(param_file, &param);
41   CheckType(&param);
42   Init(param);
43 }
44
45 template <typename Dtype>
46 void Solver<Dtype>::CheckType(SolverParameter* param) {
47   // Harmonize solver class type with configured type to avoid confusion.
48   if (param->has_type()) {
49     CHECK_EQ(param->type(), this->type())
50         << "Solver type must agree with instantiated solver class.";
51   } else {
52     param->set_type(this->type());
53   }
54 }
55
56 template <typename Dtype>
57 void Solver<Dtype>::Init(const SolverParameter& param) {
58   CHECK(Caffe::root_solver() || root_solver_)
59       << "root_solver_ needs to be set for all non-root solvers";
60   LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
61     << std::endl << param.DebugString();
62   param_ = param;
63   CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
64   CheckSnapshotWritePermissions();
65   if (Caffe::root_solver() && param_.random_seed() >= 0) {
66     Caffe::set_random_seed(param_.random_seed());
67   }
68   // Scaffolding code
69   InitTrainNet();
70   if (Caffe::root_solver()) {
71     InitTestNets();
72     LOG(INFO) << "Solver scaffolding done.";
73   }
74   iter_ = 0;
75   current_step_ = 0;
76 }
77
78 template <typename Dtype>
79 void Solver<Dtype>::InitTrainNet() {
80   const int num_train_nets = param_.has_net() + param_.has_net_param() +
81       param_.has_train_net() + param_.has_train_net_param();
82   const string& field_names = "net, net_param, train_net, train_net_param";
83   CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net "
84       << "using one of these fields: " << field_names;
85   CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than "
86       << "one of these fields specifying a train_net: " << field_names;
87   NetParameter net_param;
88   if (param_.has_train_net_param()) {
89     LOG_IF(INFO, Caffe::root_solver())
90         << "Creating training net specified in train_net_param.";
91     net_param.CopyFrom(param_.train_net_param());
92   } else if (param_.has_train_net()) {
93     LOG_IF(INFO, Caffe::root_solver())
94         << "Creating training net from train_net file: " << param_.train_net();
95     ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param);
96   }
97   if (param_.has_net_param()) {
98     LOG_IF(INFO, Caffe::root_solver())
99         << "Creating training net specified in net_param.";
100     net_param.CopyFrom(param_.net_param());
101   }
102   if (param_.has_net()) {
103     LOG_IF(INFO, Caffe::root_solver())
104         << "Creating training net from net file: " << param_.net();
105     ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);
106   }
107   // Set the correct NetState.  We start with the solver defaults (lowest
108   // precedence); then, merge in any NetState specified by the net_param itself;
109   // finally, merge in any NetState specified by the train_state (highest
110   // precedence).
111   NetState net_state;
112   net_state.set_phase(TRAIN);
113   net_state.MergeFrom(net_param.state());
114   net_state.MergeFrom(param_.train_state());
115   net_param.mutable_state()->CopyFrom(net_state);
116   if (Caffe::root_solver()) {
117     net_.reset(new Net<Dtype>(net_param));
118   } else {
119     net_.reset(new Net<Dtype>(net_param, root_solver_->net_.get()));
120   }
121 }
122
123 template <typename Dtype>
124 void Solver<Dtype>::InitTestNets() {
125   CHECK(Caffe::root_solver());
126   const bool has_net_param = param_.has_net_param();
127   const bool has_net_file = param_.has_net();
128   const int num_generic_nets = has_net_param + has_net_file;
129   CHECK_LE(num_generic_nets, 1)
130       << "Both net_param and net_file may not be specified.";
131   const int num_test_net_params = param_.test_net_param_size();
132   const int num_test_net_files = param_.test_net_size();
133   const int num_test_nets = num_test_net_params + num_test_net_files;
134   if (num_generic_nets) {
135       CHECK_GE(param_.test_iter_size(), num_test_nets)
136           << "test_iter must be specified for each test network.";
137   } else {
138       CHECK_EQ(param_.test_iter_size(), num_test_nets)
139           << "test_iter must be specified for each test network.";
140   }
141   // If we have a generic net (specified by net or net_param, rather than
142   // test_net or test_net_param), we may have an unlimited number of actual
143   // test networks -- the actual number is given by the number of remaining
144   // test_iters after any test nets specified by test_net_param and/or test_net
145   // are evaluated.
146   const int num_generic_net_instances = param_.test_iter_size() - num_test_nets;
147   const int num_test_net_instances = num_test_nets + num_generic_net_instances;
148   if (param_.test_state_size()) {
149     CHECK_EQ(param_.test_state_size(), num_test_net_instances)
150         << "test_state must be unspecified or specified once per test net.";
151   }
152   if (num_test_net_instances) {
153     CHECK_GT(param_.test_interval(), 0);
154   }
155   int test_net_id = 0;
156   vector<string> sources(num_test_net_instances);
157   vector<NetParameter> net_params(num_test_net_instances);
158   for (int i = 0; i < num_test_net_params; ++i, ++test_net_id) {
159       sources[test_net_id] = "test_net_param";
160       net_params[test_net_id].CopyFrom(param_.test_net_param(i));
161   }
162   for (int i = 0; i < num_test_net_files; ++i, ++test_net_id) {
163       sources[test_net_id] = "test_net file: " + param_.test_net(i);
164       ReadNetParamsFromTextFileOrDie(param_.test_net(i),
165           &net_params[test_net_id]);
166   }
167   const int remaining_test_nets = param_.test_iter_size() - test_net_id;
168   if (has_net_param) {
169     for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
170       sources[test_net_id] = "net_param";
171       net_params[test_net_id].CopyFrom(param_.net_param());
172     }
173   }
174   if (has_net_file) {
175     for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
176       sources[test_net_id] = "net file: " + param_.net();
177       ReadNetParamsFromTextFileOrDie(param_.net(), &net_params[test_net_id]);
178     }
179   }
180   test_nets_.resize(num_test_net_instances);
181   for (int i = 0; i < num_test_net_instances; ++i) {
182     // Set the correct NetState.  We start with the solver defaults (lowest
183     // precedence); then, merge in any NetState specified by the net_param
184     // itself; finally, merge in any NetState specified by the test_state
185     // (highest precedence).
186     NetState net_state;
187     net_state.set_phase(TEST);
188     net_state.MergeFrom(net_params[i].state());
189     if (param_.test_state_size()) {
190       net_state.MergeFrom(param_.test_state(i));
191     }
192     net_params[i].mutable_state()->CopyFrom(net_state);
193     LOG(INFO)
194         << "Creating test net (#" << i << ") specified by " << sources[i];
195     if (Caffe::root_solver()) {
196       test_nets_[i].reset(new Net<Dtype>(net_params[i]));
197     } else {
198       test_nets_[i].reset(new Net<Dtype>(net_params[i],
199           root_solver_->test_nets_[i].get()));
200     }
201     test_nets_[i]->set_debug_info(param_.debug_info());
202   }
203 }
204
205 template <typename Dtype>
206 void Solver<Dtype>::Step(int iters) {
207   const int start_iter = iter_;
208   const int stop_iter = iter_ + iters;
209   int average_loss = this->param_.average_loss();
210   losses_.clear();
211   smoothed_loss_ = 0;
212
213   while (iter_ < stop_iter) {
214     // zero-init the params
215     net_->ClearParamDiffs();
216     if (param_.test_interval() && iter_ % param_.test_interval() == 0
217         && (iter_ > 0 || param_.test_initialization())
218         && Caffe::root_solver()) {
219       TestAll();
220       if (requested_early_exit_) {
221         // Break out of the while loop because stop was requested while testing.
222         break;
223       }
224     }
225
226     for (int i = 0; i < callbacks_.size(); ++i) {
227       callbacks_[i]->on_start();
228     }
229     const bool display = param_.display() && iter_ % param_.display() == 0;
230     net_->set_debug_info(display && param_.debug_info());
231     // accumulate the loss and gradient
232     Dtype loss = 0;
233     for (int i = 0; i < param_.iter_size(); ++i) {
234       loss += net_->ForwardBackward();
235     }
236     loss /= param_.iter_size();
237     // average the loss across iterations for smoothed reporting
238     UpdateSmoothedLoss(loss, start_iter, average_loss);
239     if (display) {
240       LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
241           << ", loss = " << smoothed_loss_;
242       const vector<Blob<Dtype>*>& result = net_->output_blobs();
243       int score_index = 0;
244       for (int j = 0; j < result.size(); ++j) {
245         const Dtype* result_vec = result[j]->cpu_data();
246         const string& output_name =
247             net_->blob_names()[net_->output_blob_indices()[j]];
248         const Dtype loss_weight =
249             net_->blob_loss_weights()[net_->output_blob_indices()[j]];
250         for (int k = 0; k < result[j]->count(); ++k) {
251           ostringstream loss_msg_stream;
252           if (loss_weight) {
253             loss_msg_stream << " (* " << loss_weight
254                             << " = " << loss_weight * result_vec[k] << " loss)";
255           }
256           LOG_IF(INFO, Caffe::root_solver()) << "    Train net output #"
257               << score_index++ << ": " << output_name << " = "
258               << result_vec[k] << loss_msg_stream.str();
259         }
260       }
261     }
262     for (int i = 0; i < callbacks_.size(); ++i) {
263       callbacks_[i]->on_gradients_ready();
264     }
265     ApplyUpdate();
266
267     // Increment the internal iter_ counter -- its value should always indicate
268     // the number of times the weights have been updated.
269     ++iter_;
270
271     SolverAction::Enum request = GetRequestedAction();
272
273     // Save a snapshot if needed.
274     if ((param_.snapshot()
275          && iter_ % param_.snapshot() == 0
276          && Caffe::root_solver()) ||
277          (request == SolverAction::SNAPSHOT)) {
278       Snapshot();
279     }
280     if (SolverAction::STOP == request) {
281       requested_early_exit_ = true;
282       // Break out of training loop.
283       break;
284     }
285   }
286 }
287
288 template <typename Dtype>
289 void Solver<Dtype>::Solve(const char* resume_file) {
290   CHECK(Caffe::root_solver());
291   LOG(INFO) << "Solving " << net_->name();
292   LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
293
294   // Initialize to false every time we start solving.
295   requested_early_exit_ = false;
296
297   if (resume_file) {
298     LOG(INFO) << "Restoring previous solver status from " << resume_file;
299     Restore(resume_file);
300   }
301
302   // For a network that is trained by the solver, no bottom or top vecs
303   // should be given, and we will just provide dummy vecs.
304   int start_iter = iter_;
305   Step(param_.max_iter() - iter_);
306   // If we haven't already, save a snapshot after optimization, unless
307   // overridden by setting snapshot_after_train := false
308   if (param_.snapshot_after_train()
309       && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
310     Snapshot();
311   }
312   if (requested_early_exit_) {
313     LOG(INFO) << "Optimization stopped early.";
314     return;
315   }
316   // After the optimization is done, run an additional train and test pass to
317   // display the train and test loss/outputs if appropriate (based on the
318   // display and test_interval settings, respectively).  Unlike in the rest of
319   // training, for the train net we only run a forward pass as we've already
320   // updated the parameters "max_iter" times -- this final pass is only done to
321   // display the loss, which is computed in the forward pass.
322   if (param_.display() && iter_ % param_.display() == 0) {
323     int average_loss = this->param_.average_loss();
324     Dtype loss;
325     net_->Forward(&loss);
326
327     UpdateSmoothedLoss(loss, start_iter, average_loss);
328
329     LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;
330   }
331   if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
332     TestAll();
333   }
334   LOG(INFO) << "Optimization Done.";
335 }
336
337 template <typename Dtype>
338 void Solver<Dtype>::TestAll() {
339   for (int test_net_id = 0;
340        test_net_id < test_nets_.size() && !requested_early_exit_;
341        ++test_net_id) {
342     Test(test_net_id);
343   }
344 }
345
346 template <typename Dtype>
347 void Solver<Dtype>::Test(const int test_net_id) {
348   CHECK(Caffe::root_solver());
349   LOG(INFO) << "Iteration " << iter_
350             << ", Testing net (#" << test_net_id << ")";
351   CHECK_NOTNULL(test_nets_[test_net_id].get())->
352       ShareTrainedLayersWith(net_.get());
353   vector<Dtype> test_score;
354   vector<int> test_score_output_id;
355   const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
356   Dtype loss = 0;
357   for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
358     SolverAction::Enum request = GetRequestedAction();
359     // Check to see if stoppage of testing/training has been requested.
360     while (request != SolverAction::NONE) {
361         if (SolverAction::SNAPSHOT == request) {
362           Snapshot();
363         } else if (SolverAction::STOP == request) {
364           requested_early_exit_ = true;
365         }
366         request = GetRequestedAction();
367     }
368     if (requested_early_exit_) {
369       // break out of test loop.
370       break;
371     }
372
373     Dtype iter_loss;
374     const vector<Blob<Dtype>*>& result =
375         test_net->Forward(&iter_loss);
376     if (param_.test_compute_loss()) {
377       loss += iter_loss;
378     }
379     if (i == 0) {
380       for (int j = 0; j < result.size(); ++j) {
381         const Dtype* result_vec = result[j]->cpu_data();
382         for (int k = 0; k < result[j]->count(); ++k) {
383           test_score.push_back(result_vec[k]);
384           test_score_output_id.push_back(j);
385         }
386       }
387     } else {
388       int idx = 0;
389       for (int j = 0; j < result.size(); ++j) {
390         const Dtype* result_vec = result[j]->cpu_data();
391         for (int k = 0; k < result[j]->count(); ++k) {
392           test_score[idx++] += result_vec[k];
393         }
394       }
395     }
396   }
397   if (requested_early_exit_) {
398     LOG(INFO)     << "Test interrupted.";
399     return;
400   }
401   if (param_.test_compute_loss()) {
402     loss /= param_.test_iter(test_net_id);
403     LOG(INFO) << "Test loss: " << loss;
404   }
405   for (int i = 0; i < test_score.size(); ++i) {
406     const int output_blob_index =
407         test_net->output_blob_indices()[test_score_output_id[i]];
408     const string& output_name = test_net->blob_names()[output_blob_index];
409     const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index];
410     ostringstream loss_msg_stream;
411     const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id);
412     if (loss_weight) {
413       loss_msg_stream << " (* " << loss_weight
414                       << " = " << loss_weight * mean_score << " loss)";
415     }
416     LOG(INFO) << "    Test net output #" << i << ": " << output_name << " = "
417               << mean_score << loss_msg_stream.str();
418   }
419 }
420
421 template <typename Dtype>
422 void Solver<Dtype>::Snapshot() {
423   CHECK(Caffe::root_solver());
424   string model_filename;
425   switch (param_.snapshot_format()) {
426   case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
427     model_filename = SnapshotToBinaryProto();
428     break;
429   case caffe::SolverParameter_SnapshotFormat_HDF5:
430     model_filename = SnapshotToHDF5();
431     break;
432   default:
433     LOG(FATAL) << "Unsupported snapshot format.";
434   }
435
436   SnapshotSolverState(model_filename);
437 }
438
439 template <typename Dtype>
440 void Solver<Dtype>::CheckSnapshotWritePermissions() {
441   if (Caffe::root_solver() && param_.snapshot()) {
442     CHECK(param_.has_snapshot_prefix())
443         << "In solver params, snapshot is specified but snapshot_prefix is not";
444     string probe_filename = SnapshotFilename(".tempfile");
445     std::ofstream probe_ofs(probe_filename.c_str());
446     if (probe_ofs.good()) {
447       probe_ofs.close();
448       std::remove(probe_filename.c_str());
449     } else {
450       LOG(FATAL) << "Cannot write to snapshot prefix '"
451           << param_.snapshot_prefix() << "'.  Make sure "
452           << "that the directory exists and is writeable.";
453     }
454   }
455 }
456
457 template <typename Dtype>
458 string Solver<Dtype>::SnapshotFilename(const string extension) {
459   return param_.snapshot_prefix() + "_iter_" + caffe::format_int(iter_)
460     + extension;
461 }
462
463 template <typename Dtype>
464 string Solver<Dtype>::SnapshotToBinaryProto() {
465   string model_filename = SnapshotFilename(".caffemodel");
466   LOG(INFO) << "Snapshotting to binary proto file " << model_filename;
467   NetParameter net_param;
468   net_->ToProto(&net_param, param_.snapshot_diff());
469   WriteProtoToBinaryFile(net_param, model_filename);
470   return model_filename;
471 }
472
473 template <typename Dtype>
474 string Solver<Dtype>::SnapshotToHDF5() {
475   string model_filename = SnapshotFilename(".caffemodel.h5");
476   LOG(INFO) << "Snapshotting to HDF5 file " << model_filename;
477   net_->ToHDF5(model_filename, param_.snapshot_diff());
478   return model_filename;
479 }
480
481 template <typename Dtype>
482 void Solver<Dtype>::Restore(const char* state_file) {
483   CHECK(Caffe::root_solver());
484   string state_filename(state_file);
485   if (state_filename.size() >= 3 &&
486       state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) {
487     RestoreSolverStateFromHDF5(state_filename);
488   } else {
489     RestoreSolverStateFromBinaryProto(state_filename);
490   }
491 }
492
493 template <typename Dtype>
494 void Solver<Dtype>::UpdateSmoothedLoss(Dtype loss, int start_iter,
495     int average_loss) {
496   if (losses_.size() < average_loss) {
497     losses_.push_back(loss);
498     int size = losses_.size();
499     smoothed_loss_ = (smoothed_loss_ * (size - 1) + loss) / size;
500   } else {
501     int idx = (iter_ - start_iter) % average_loss;
502     smoothed_loss_ += (loss - losses_[idx]) / average_loss;
503     losses_[idx] = loss;
504   }
505 }
506
507 INSTANTIATE_CLASS(Solver);
508
509 }  // namespace caffe