TestSnapshot expects .h5 snapshots, explicitly checks history.
authorEric Tzeng <etzeng@eecs.berkeley.edu>
Thu, 30 Jul 2015 01:40:38 +0000 (18:40 -0700)
committerEric Tzeng <etzeng@eecs.berkeley.edu>
Fri, 7 Aug 2015 21:56:38 +0000 (14:56 -0700)
src/caffe/test/test_gradient_based_solver.cpp

index 94b5001..78bf4b3 100644 (file)
@@ -139,7 +139,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
     if (snapshot) {
       ostringstream resume_file;
       resume_file << snapshot_prefix_ << "/_iter_" << num_iters
-                  << ".solverstate";
+                  << ".solverstate.h5";
       string resume_filename = resume_file.str();
       return resume_filename;
     }
@@ -394,6 +394,18 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
       }
     }
 
+    // Save the solver history
+    vector<shared_ptr<Blob<Dtype> > > history_copies;
+    const vector<shared_ptr<Blob<Dtype> > >& orig_history = solver_->history();
+    history_copies.resize(orig_history.size());
+    for (int i = 0; i < orig_history.size(); ++i) {
+      history_copies[i].reset(new Blob<Dtype>());
+      const bool kReshape = true;
+      for (int copy_diff = false; copy_diff <= true; ++copy_diff) {
+        history_copies[i]->CopyFrom(*orig_history[i], copy_diff, kReshape);
+      }
+    }
+
     // Run the solver for num_iters iterations and snapshot.
     snapshot = true;
     string snapshot_name = RunLeastSquaresSolver(learning_rate, weight_decay,
@@ -414,6 +426,17 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
             << "param " << i << " diff differed at dim " << j;
       }
     }
+
+    // Check that history now matches.
+    const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
+    for (int i = 0; i < history.size(); ++i) {
+      for (int j = 0; j < history[i]->count(); ++j) {
+        EXPECT_EQ(history_copies[i]->cpu_data()[j], history[i]->cpu_data()[j])
+            << "history blob " << i << " data differed at dim " << j;
+        EXPECT_EQ(history_copies[i]->cpu_diff()[j], history[i]->cpu_diff()[j])
+            << "history blob " << i << " diff differed at dim " << j;
+      }
+    }
   }
 };