Make training iterations 0-indexed.
authorJeff Donahue <jeff.donahue@gmail.com>
Sun, 27 Jul 2014 00:03:27 +0000 (17:03 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Sun, 27 Jul 2014 19:04:22 +0000 (12:04 -0700)
src/caffe/solver.cpp

index 0dfc7f0..b44948e 100644 (file)
@@ -89,25 +89,24 @@ void Solver<Dtype>::Solve(const char* resume_file) {
     LOG(INFO) << "Restoring previous solver status from " << resume_file;
     Restore(resume_file);
   }
-
-  // Run a test pass before doing any training to avoid waiting a potentially
-  // very long time (param_.test_interval() training iterations) to report that
-  // there's not enough memory to run the test net and crash, etc.; and to gauge
-  // the effect of the first training iterations.
-  if (param_.test_interval()) {
-    TestAll();
-  }
+  // Remember the initial iter_ value; will be non-zero if we loaded from a
+  // resume_file above.
+  const int start_iter = iter_;
 
   // For a network that is trained by the solver, no bottom or top vecs
   // should be given, and we will just provide dummy vecs.
   vector<Blob<Dtype>*> bottom_vec;
-  while (iter_++ < param_.max_iter()) {
-    const bool display = param_.display() && iter_ % param_.display() == 0;
-    if (display) {
-      net_->set_debug_info(param_.debug_info());
-    } else {
-      net_->set_debug_info(false);
+  for (; iter_ < param_.max_iter(); ++iter_) {
+    // Save a snapshot if needed.
+    if (param_.snapshot() && iter_ > start_iter &&
+        iter_ % param_.snapshot() == 0) {
+      Snapshot();
+    }
+    if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
+      TestAll();
     }
+    const bool display = param_.display() && iter_ % param_.display() == 0;
+    net_->set_debug_info(display && param_.debug_info());
     Dtype loss = net_->ForwardBackward(bottom_vec);
     ComputeUpdateValue();
     net_->Update();
@@ -115,17 +114,23 @@ void Solver<Dtype>::Solve(const char* resume_file) {
     if (display) {
       LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
     }
-    if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
-      TestAll();
-    }
-    // Check if we need to do snapshot
-    if (param_.snapshot() && iter_ % param_.snapshot() == 0) {
-      Snapshot();
-    }
   }
-  // After the optimization is done, always do a snapshot.
-  iter_--;
+  // Always save a snapshot after optimization.
   Snapshot();
+  // After the optimization is done, run an additional train and test pass to
+  // display the train and test loss/outputs if appropriate (based on the
+  // display and test_interval settings, respectively).  Unlike in the rest of
+  // training, for the train net we only run a forward pass as we've already
+  // updated the parameters "max_iter" times -- this final pass is only done to
+  // display the loss, which is computed in the forward pass.
+  if (param_.display() && iter_ % param_.display() == 0) {
+    Dtype loss;
+    net_->Forward(bottom_vec, &loss);
+    LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
+  }
+  if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
+    TestAll();
+  }
   LOG(INFO) << "Optimization Done.";
 }