Merge pull request #5207 from CDLuminate/cmake-bump-soversion-to-rc4
[platform/upstream/caffeonacl.git] / src / caffe / solver.cpp
1 #include <cstdio>
2
3 #include <string>
4 #include <vector>
5
6 #include "caffe/solver.hpp"
7 #include "caffe/util/format.hpp"
8 #include "caffe/util/hdf5.hpp"
9 #include "caffe/util/io.hpp"
10 #include "caffe/util/upgrade_proto.hpp"
11
12 namespace caffe {
13
14 template<typename Dtype>
15 void Solver<Dtype>::SetActionFunction(ActionCallback func) {
16   action_request_function_ = func;
17 }
18
19 template<typename Dtype>
20 SolverAction::Enum Solver<Dtype>::GetRequestedAction() {
21   if (action_request_function_) {
22     // If the external request function has been set, call it.
23     return action_request_function_();
24   }
25   return SolverAction::NONE;
26 }
27
28 template <typename Dtype>
29 Solver<Dtype>::Solver(const SolverParameter& param)
30     : net_(), callbacks_(), requested_early_exit_(false) {
31   Init(param);
32 }
33
34 template <typename Dtype>
35 Solver<Dtype>::Solver(const string& param_file)
36     : net_(), callbacks_(), requested_early_exit_(false) {
37   SolverParameter param;
38   ReadSolverParamsFromTextFileOrDie(param_file, &param);
39   Init(param);
40 }
41
42 template <typename Dtype>
43 void Solver<Dtype>::Init(const SolverParameter& param) {
44   LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
45     << std::endl << param.DebugString();
46   param_ = param;
47   CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
48   CheckSnapshotWritePermissions();
49   if (param_.random_seed() >= 0) {
50     Caffe::set_random_seed(param_.random_seed() + Caffe::solver_rank());
51   }
52   // Scaffolding code
53   InitTrainNet();
54   if (Caffe::root_solver()) {
55     InitTestNets();
56     LOG(INFO) << "Solver scaffolding done.";
57   }
58   iter_ = 0;
59   current_step_ = 0;
60 }
61
62 template <typename Dtype>
63 void Solver<Dtype>::InitTrainNet() {
64   const int num_train_nets = param_.has_net() + param_.has_net_param() +
65       param_.has_train_net() + param_.has_train_net_param();
66   const string& field_names = "net, net_param, train_net, train_net_param";
67   CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net "
68       << "using one of these fields: " << field_names;
69   CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than "
70       << "one of these fields specifying a train_net: " << field_names;
71   NetParameter net_param;
72   if (param_.has_train_net_param()) {
73     LOG_IF(INFO, Caffe::root_solver())
74         << "Creating training net specified in train_net_param.";
75     net_param.CopyFrom(param_.train_net_param());
76   } else if (param_.has_train_net()) {
77     LOG_IF(INFO, Caffe::root_solver())
78         << "Creating training net from train_net file: " << param_.train_net();
79     ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param);
80   }
81   if (param_.has_net_param()) {
82     LOG_IF(INFO, Caffe::root_solver())
83         << "Creating training net specified in net_param.";
84     net_param.CopyFrom(param_.net_param());
85   }
86   if (param_.has_net()) {
87     LOG_IF(INFO, Caffe::root_solver())
88         << "Creating training net from net file: " << param_.net();
89     ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);
90   }
91   // Set the correct NetState.  We start with the solver defaults (lowest
92   // precedence); then, merge in any NetState specified by the net_param itself;
93   // finally, merge in any NetState specified by the train_state (highest
94   // precedence).
95   NetState net_state;
96   net_state.set_phase(TRAIN);
97   net_state.MergeFrom(net_param.state());
98   net_state.MergeFrom(param_.train_state());
99   net_param.mutable_state()->CopyFrom(net_state);
100   net_.reset(new Net<Dtype>(net_param));
101 }
102
103 template <typename Dtype>
104 void Solver<Dtype>::InitTestNets() {
105   CHECK(Caffe::root_solver());
106   const bool has_net_param = param_.has_net_param();
107   const bool has_net_file = param_.has_net();
108   const int num_generic_nets = has_net_param + has_net_file;
109   CHECK_LE(num_generic_nets, 1)
110       << "Both net_param and net_file may not be specified.";
111   const int num_test_net_params = param_.test_net_param_size();
112   const int num_test_net_files = param_.test_net_size();
113   const int num_test_nets = num_test_net_params + num_test_net_files;
114   if (num_generic_nets) {
115       CHECK_GE(param_.test_iter_size(), num_test_nets)
116           << "test_iter must be specified for each test network.";
117   } else {
118       CHECK_EQ(param_.test_iter_size(), num_test_nets)
119           << "test_iter must be specified for each test network.";
120   }
121   // If we have a generic net (specified by net or net_param, rather than
122   // test_net or test_net_param), we may have an unlimited number of actual
123   // test networks -- the actual number is given by the number of remaining
124   // test_iters after any test nets specified by test_net_param and/or test_net
125   // are evaluated.
126   const int num_generic_net_instances = param_.test_iter_size() - num_test_nets;
127   const int num_test_net_instances = num_test_nets + num_generic_net_instances;
128   if (param_.test_state_size()) {
129     CHECK_EQ(param_.test_state_size(), num_test_net_instances)
130         << "test_state must be unspecified or specified once per test net.";
131   }
132   if (num_test_net_instances) {
133     CHECK_GT(param_.test_interval(), 0);
134   }
135   int test_net_id = 0;
136   vector<string> sources(num_test_net_instances);
137   vector<NetParameter> net_params(num_test_net_instances);
138   for (int i = 0; i < num_test_net_params; ++i, ++test_net_id) {
139       sources[test_net_id] = "test_net_param";
140       net_params[test_net_id].CopyFrom(param_.test_net_param(i));
141   }
142   for (int i = 0; i < num_test_net_files; ++i, ++test_net_id) {
143       sources[test_net_id] = "test_net file: " + param_.test_net(i);
144       ReadNetParamsFromTextFileOrDie(param_.test_net(i),
145           &net_params[test_net_id]);
146   }
147   const int remaining_test_nets = param_.test_iter_size() - test_net_id;
148   if (has_net_param) {
149     for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
150       sources[test_net_id] = "net_param";
151       net_params[test_net_id].CopyFrom(param_.net_param());
152     }
153   }
154   if (has_net_file) {
155     for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
156       sources[test_net_id] = "net file: " + param_.net();
157       ReadNetParamsFromTextFileOrDie(param_.net(), &net_params[test_net_id]);
158     }
159   }
160   test_nets_.resize(num_test_net_instances);
161   for (int i = 0; i < num_test_net_instances; ++i) {
162     // Set the correct NetState.  We start with the solver defaults (lowest
163     // precedence); then, merge in any NetState specified by the net_param
164     // itself; finally, merge in any NetState specified by the test_state
165     // (highest precedence).
166     NetState net_state;
167     net_state.set_phase(TEST);
168     net_state.MergeFrom(net_params[i].state());
169     if (param_.test_state_size()) {
170       net_state.MergeFrom(param_.test_state(i));
171     }
172     net_params[i].mutable_state()->CopyFrom(net_state);
173     LOG(INFO)
174         << "Creating test net (#" << i << ") specified by " << sources[i];
175     test_nets_[i].reset(new Net<Dtype>(net_params[i]));
176     test_nets_[i]->set_debug_info(param_.debug_info());
177   }
178 }
179
180 template <typename Dtype>
181 void Solver<Dtype>::Step(int iters) {
182   const int start_iter = iter_;
183   const int stop_iter = iter_ + iters;
184   int average_loss = this->param_.average_loss();
185   losses_.clear();
186   smoothed_loss_ = 0;
187   iteration_timer_.Start();
188
189   while (iter_ < stop_iter) {
190     // zero-init the params
191     net_->ClearParamDiffs();
192     if (param_.test_interval() && iter_ % param_.test_interval() == 0
193         && (iter_ > 0 || param_.test_initialization())) {
194       if (Caffe::root_solver()) {
195         TestAll();
196       }
197       if (requested_early_exit_) {
198         // Break out of the while loop because stop was requested while testing.
199         break;
200       }
201     }
202
203     for (int i = 0; i < callbacks_.size(); ++i) {
204       callbacks_[i]->on_start();
205     }
206     const bool display = param_.display() && iter_ % param_.display() == 0;
207     net_->set_debug_info(display && param_.debug_info());
208     // accumulate the loss and gradient
209     Dtype loss = 0;
210     for (int i = 0; i < param_.iter_size(); ++i) {
211       loss += net_->ForwardBackward();
212     }
213     loss /= param_.iter_size();
214     // average the loss across iterations for smoothed reporting
215     UpdateSmoothedLoss(loss, start_iter, average_loss);
216     if (display) {
217       float lapse = iteration_timer_.Seconds();
218       float per_s = (iter_ - iterations_last_) / (lapse ? lapse : 1);
219       LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
220           << " (" << per_s << " iter/s, " << lapse << "s/"
221           << param_.display() << " iters), loss = " << smoothed_loss_;
222       iteration_timer_.Start();
223       iterations_last_ = iter_;
224       const vector<Blob<Dtype>*>& result = net_->output_blobs();
225       int score_index = 0;
226       for (int j = 0; j < result.size(); ++j) {
227         const Dtype* result_vec = result[j]->cpu_data();
228         const string& output_name =
229             net_->blob_names()[net_->output_blob_indices()[j]];
230         const Dtype loss_weight =
231             net_->blob_loss_weights()[net_->output_blob_indices()[j]];
232         for (int k = 0; k < result[j]->count(); ++k) {
233           ostringstream loss_msg_stream;
234           if (loss_weight) {
235             loss_msg_stream << " (* " << loss_weight
236                             << " = " << loss_weight * result_vec[k] << " loss)";
237           }
238           LOG_IF(INFO, Caffe::root_solver()) << "    Train net output #"
239               << score_index++ << ": " << output_name << " = "
240               << result_vec[k] << loss_msg_stream.str();
241         }
242       }
243     }
244     for (int i = 0; i < callbacks_.size(); ++i) {
245       callbacks_[i]->on_gradients_ready();
246     }
247     ApplyUpdate();
248
249     // Increment the internal iter_ counter -- its value should always indicate
250     // the number of times the weights have been updated.
251     ++iter_;
252
253     SolverAction::Enum request = GetRequestedAction();
254
255     // Save a snapshot if needed.
256     if ((param_.snapshot()
257          && iter_ % param_.snapshot() == 0
258          && Caffe::root_solver()) ||
259          (request == SolverAction::SNAPSHOT)) {
260       Snapshot();
261     }
262     if (SolverAction::STOP == request) {
263       requested_early_exit_ = true;
264       // Break out of training loop.
265       break;
266     }
267   }
268 }
269
270 template <typename Dtype>
271 void Solver<Dtype>::Solve(const char* resume_file) {
272   CHECK(Caffe::root_solver());
273   LOG(INFO) << "Solving " << net_->name();
274   LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
275
276   // Initialize to false every time we start solving.
277   requested_early_exit_ = false;
278
279   if (resume_file) {
280     LOG(INFO) << "Restoring previous solver status from " << resume_file;
281     Restore(resume_file);
282   }
283
284   // For a network that is trained by the solver, no bottom or top vecs
285   // should be given, and we will just provide dummy vecs.
286   int start_iter = iter_;
287   Step(param_.max_iter() - iter_);
288   // If we haven't already, save a snapshot after optimization, unless
289   // overridden by setting snapshot_after_train := false
290   if (param_.snapshot_after_train()
291       && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
292     Snapshot();
293   }
294   if (requested_early_exit_) {
295     LOG(INFO) << "Optimization stopped early.";
296     return;
297   }
298   // After the optimization is done, run an additional train and test pass to
299   // display the train and test loss/outputs if appropriate (based on the
300   // display and test_interval settings, respectively).  Unlike in the rest of
301   // training, for the train net we only run a forward pass as we've already
302   // updated the parameters "max_iter" times -- this final pass is only done to
303   // display the loss, which is computed in the forward pass.
304   if (param_.display() && iter_ % param_.display() == 0) {
305     int average_loss = this->param_.average_loss();
306     Dtype loss;
307     net_->Forward(&loss);
308
309     UpdateSmoothedLoss(loss, start_iter, average_loss);
310
311     LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;
312   }
313   if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
314     TestAll();
315   }
316   LOG(INFO) << "Optimization Done.";
317 }
318
319 template <typename Dtype>
320 void Solver<Dtype>::TestAll() {
321   for (int test_net_id = 0;
322        test_net_id < test_nets_.size() && !requested_early_exit_;
323        ++test_net_id) {
324     Test(test_net_id);
325   }
326 }
327
328 template <typename Dtype>
329 void Solver<Dtype>::Test(const int test_net_id) {
330   CHECK(Caffe::root_solver());
331   LOG(INFO) << "Iteration " << iter_
332             << ", Testing net (#" << test_net_id << ")";
333   CHECK_NOTNULL(test_nets_[test_net_id].get())->
334       ShareTrainedLayersWith(net_.get());
335   vector<Dtype> test_score;
336   vector<int> test_score_output_id;
337   const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
338   Dtype loss = 0;
339   for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
340     SolverAction::Enum request = GetRequestedAction();
341     // Check to see if stoppage of testing/training has been requested.
342     while (request != SolverAction::NONE) {
343         if (SolverAction::SNAPSHOT == request) {
344           Snapshot();
345         } else if (SolverAction::STOP == request) {
346           requested_early_exit_ = true;
347         }
348         request = GetRequestedAction();
349     }
350     if (requested_early_exit_) {
351       // break out of test loop.
352       break;
353     }
354
355     Dtype iter_loss;
356     const vector<Blob<Dtype>*>& result =
357         test_net->Forward(&iter_loss);
358     if (param_.test_compute_loss()) {
359       loss += iter_loss;
360     }
361     if (i == 0) {
362       for (int j = 0; j < result.size(); ++j) {
363         const Dtype* result_vec = result[j]->cpu_data();
364         for (int k = 0; k < result[j]->count(); ++k) {
365           test_score.push_back(result_vec[k]);
366           test_score_output_id.push_back(j);
367         }
368       }
369     } else {
370       int idx = 0;
371       for (int j = 0; j < result.size(); ++j) {
372         const Dtype* result_vec = result[j]->cpu_data();
373         for (int k = 0; k < result[j]->count(); ++k) {
374           test_score[idx++] += result_vec[k];
375         }
376       }
377     }
378   }
379   if (requested_early_exit_) {
380     LOG(INFO)     << "Test interrupted.";
381     return;
382   }
383   if (param_.test_compute_loss()) {
384     loss /= param_.test_iter(test_net_id);
385     LOG(INFO) << "Test loss: " << loss;
386   }
387   for (int i = 0; i < test_score.size(); ++i) {
388     const int output_blob_index =
389         test_net->output_blob_indices()[test_score_output_id[i]];
390     const string& output_name = test_net->blob_names()[output_blob_index];
391     const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index];
392     ostringstream loss_msg_stream;
393     const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id);
394     if (loss_weight) {
395       loss_msg_stream << " (* " << loss_weight
396                       << " = " << loss_weight * mean_score << " loss)";
397     }
398     LOG(INFO) << "    Test net output #" << i << ": " << output_name << " = "
399               << mean_score << loss_msg_stream.str();
400   }
401 }
402
403 template <typename Dtype>
404 void Solver<Dtype>::Snapshot() {
405   CHECK(Caffe::root_solver());
406   string model_filename;
407   switch (param_.snapshot_format()) {
408   case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
409     model_filename = SnapshotToBinaryProto();
410     break;
411   case caffe::SolverParameter_SnapshotFormat_HDF5:
412     model_filename = SnapshotToHDF5();
413     break;
414   default:
415     LOG(FATAL) << "Unsupported snapshot format.";
416   }
417
418   SnapshotSolverState(model_filename);
419 }
420
421 template <typename Dtype>
422 void Solver<Dtype>::CheckSnapshotWritePermissions() {
423   if (Caffe::root_solver() && param_.snapshot()) {
424     CHECK(param_.has_snapshot_prefix())
425         << "In solver params, snapshot is specified but snapshot_prefix is not";
426     string probe_filename = SnapshotFilename(".tempfile");
427     std::ofstream probe_ofs(probe_filename.c_str());
428     if (probe_ofs.good()) {
429       probe_ofs.close();
430       std::remove(probe_filename.c_str());
431     } else {
432       LOG(FATAL) << "Cannot write to snapshot prefix '"
433           << param_.snapshot_prefix() << "'.  Make sure "
434           << "that the directory exists and is writeable.";
435     }
436   }
437 }
438
439 template <typename Dtype>
440 string Solver<Dtype>::SnapshotFilename(const string extension) {
441   return param_.snapshot_prefix() + "_iter_" + caffe::format_int(iter_)
442     + extension;
443 }
444
445 template <typename Dtype>
446 string Solver<Dtype>::SnapshotToBinaryProto() {
447   string model_filename = SnapshotFilename(".caffemodel");
448   LOG(INFO) << "Snapshotting to binary proto file " << model_filename;
449   NetParameter net_param;
450   net_->ToProto(&net_param, param_.snapshot_diff());
451   WriteProtoToBinaryFile(net_param, model_filename);
452   return model_filename;
453 }
454
455 template <typename Dtype>
456 string Solver<Dtype>::SnapshotToHDF5() {
457   string model_filename = SnapshotFilename(".caffemodel.h5");
458   LOG(INFO) << "Snapshotting to HDF5 file " << model_filename;
459   net_->ToHDF5(model_filename, param_.snapshot_diff());
460   return model_filename;
461 }
462
463 template <typename Dtype>
464 void Solver<Dtype>::Restore(const char* state_file) {
465   CHECK(Caffe::root_solver());
466   string state_filename(state_file);
467   if (state_filename.size() >= 3 &&
468       state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) {
469     RestoreSolverStateFromHDF5(state_filename);
470   } else {
471     RestoreSolverStateFromBinaryProto(state_filename);
472   }
473 }
474
475 template <typename Dtype>
476 void Solver<Dtype>::UpdateSmoothedLoss(Dtype loss, int start_iter,
477     int average_loss) {
478   if (losses_.size() < average_loss) {
479     losses_.push_back(loss);
480     int size = losses_.size();
481     smoothed_loss_ = (smoothed_loss_ * (size - 1) + loss) / size;
482   } else {
483     int idx = (iter_ - start_iter) % average_loss;
484     smoothed_loss_ += (loss - losses_[idx]) / average_loss;
485     losses_[idx] = loss;
486   }
487 }
488
489 INSTANTIATE_CLASS(Solver);
490
491 }  // namespace caffe