TestGradientBasedSolver: add TestSnapshot to verify behavior when
author Jeff Donahue <jeff.donahue@gmail.com>
Thu, 30 Jul 2015 00:27:58 +0000 (17:27 -0700)
committer Eric Tzeng <etzeng@eecs.berkeley.edu>
Fri, 7 Aug 2015 20:48:42 +0000 (13:48 -0700)
restoring net/solver from snapshot

src/caffe/test/test_gradient_based_solver.cpp

index c9135d6..94b5001 100644 (file)
@@ -10,6 +10,7 @@
 #include "caffe/common.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/solver.hpp"
+#include "caffe/util/io.hpp"
 
 #include "caffe/test/test_caffe_main.hpp"
 
@@ -25,6 +26,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
   GradientBasedSolverTest() :
       seed_(1701), num_(4), channels_(3), height_(10), width_(10) {}
 
+  string snapshot_prefix_;
   shared_ptr<SGDSolver<Dtype> > solver_;
   int seed_;
   int num_, channels_, height_, width_;
@@ -36,9 +38,6 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
   virtual void InitSolverFromProtoString(const string& proto) {
     SolverParameter param;
     CHECK(google::protobuf::TextFormat::ParseFromString(proto, &param));
-    // Disable saving a final snapshot so the tests don't pollute the user's
-    // working directory with useless snapshots.
-    param.set_snapshot_after_train(false);
     // Set the solver_mode according to current Caffe::mode.
     switch (Caffe::mode()) {
       case Caffe::CPU:
@@ -55,11 +54,13 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
          param.delta() : 0;
   }
 
-  void RunLeastSquaresSolver(const Dtype learning_rate,
+  string RunLeastSquaresSolver(const Dtype learning_rate,
       const Dtype weight_decay, const Dtype momentum, const int num_iters,
-      const int iter_size = 1) {
+      const int iter_size = 1, const bool snapshot = false,
+      const char* from_snapshot = NULL) {
     ostringstream proto;
     proto <<
+       "snapshot_after_train: " << snapshot << " "
        "max_iter: " << num_iters << " "
        "base_lr: " << learning_rate << " "
        "lr_policy: 'fixed' "
@@ -119,9 +120,30 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
     if (momentum != 0) {
       proto << "momentum: " << momentum << " ";
     }
+    MakeTempDir(&snapshot_prefix_);
+    proto << "snapshot_prefix: '" << snapshot_prefix_ << "/' ";
+    if (snapshot) {
+      proto << "snapshot: " << num_iters << " ";
+    }
     Caffe::set_random_seed(this->seed_);
     this->InitSolverFromProtoString(proto.str());
+    Caffe::set_random_seed(this->seed_);
+    if (from_snapshot != NULL) {
+      this->solver_->Restore(from_snapshot);
+      vector<Blob<Dtype>*> empty_bottom_vec;
+      for (int i = 0; i < this->solver_->iter(); ++i) {
+        this->solver_->net()->Forward(empty_bottom_vec);
+      }
+    }
     this->solver_->Solve();
+    if (snapshot) {
+      ostringstream resume_file;
+      resume_file << snapshot_prefix_ << "/_iter_" << num_iters
+                  << ".solverstate";
+      string resume_filename = resume_file.str();
+      return resume_filename;
+    }
+    return string();
   }
 
   // Compute an update value given the current state of the train net,
@@ -348,6 +370,51 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
     // Check that the solver's solution matches ours.
     CheckLeastSquaresUpdate(updated_params);
   }
+
+  // Checks that a solver resumed from a snapshot ends up in the same state
+  // as one that trained without interruption.
+  //
+  // Runs RunLeastSquaresSolver three times with the same hyperparameters:
+  //   1. 2 * num_iters iterations with no snapshot (the reference run);
+  //   2. num_iters iterations with snapshotting requested;
+  //   3. with max_iter 2 * num_iters, starting from the file name returned
+  //      by run 2.
+  // Every param's data and diff after run 3 must equal, element for
+  // element, the values saved after run 1.
+  void TestSnapshot(const Dtype learning_rate = 1.0,
+      const Dtype weight_decay = 0.0, const Dtype momentum = 0.0,
+      const int num_iters = 1) {
+    const int kIterSize = 1;
+    const int total_num_iters = num_iters * 2;
+
+    // Reference run: train straight through without snapshotting.
+    RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
+                          total_num_iters, kIterSize, false);
+
+    // Keep private copies of the reference params, since the runs below
+    // reinitialize the solver.
+    const vector<shared_ptr<Blob<Dtype> > >& orig_params =
+        this->solver_->net()->params();
+    vector<shared_ptr<Blob<Dtype> > > param_copies(orig_params.size());
+    for (size_t i = 0; i < orig_params.size(); ++i) {
+      const bool kReshape = true;
+      param_copies[i].reset(new Blob<Dtype>());
+      // Two passes: first with copy_diff off, then with it on, so that both
+      // cpu_data() and cpu_diff() can be compared below.
+      param_copies[i]->CopyFrom(*orig_params[i], false, kReshape);
+      param_copies[i]->CopyFrom(*orig_params[i], true, kReshape);
+    }
+
+    // Train for num_iters iterations and snapshot.
+    const string snapshot_name = RunLeastSquaresSolver(learning_rate,
+        weight_decay, momentum, num_iters, kIterSize, true);
+
+    // Reinitialize the solver, restore from that snapshot and keep training
+    // with max_iter set to total_num_iters.
+    RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
+        total_num_iters, kIterSize, false, snapshot_name.c_str());
+
+    // The resumed run must reproduce the reference run exactly. Sizes are
+    // checked first so that a mismatch fails the test cleanly instead of
+    // indexing past the end of either container.
+    const vector<shared_ptr<Blob<Dtype> > >& params =
+        this->solver_->net()->params();
+    ASSERT_EQ(param_copies.size(), params.size());
+    for (size_t i = 0; i < params.size(); ++i) {
+      ASSERT_EQ(param_copies[i]->count(), params[i]->count())
+          << "param " << i << " count differed";
+      for (int j = 0; j < params[i]->count(); ++j) {
+        EXPECT_EQ(param_copies[i]->cpu_data()[j], params[i]->cpu_data()[j])
+            << "param " << i << " data differed at dim " << j;
+        EXPECT_EQ(param_copies[i]->cpu_diff()[j], params[i]->cpu_diff()[j])
+            << "param " << i << " diff differed at dim " << j;
+      }
+    }
+  }
 };
 
 
@@ -428,6 +495,18 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
       kIterSize);
 }
 
+// Runs the fixture's TestSnapshot check for the SGD solver with weight decay
+// and momentum both non-zero, once for each num_iters value from 1 through
+// kMaxSnapshotIters.
+TYPED_TEST(SGDSolverTest, TestSnapshot) {
+  typedef typename TypeParam::Dtype Dtype;
+  const int kMaxSnapshotIters = 4;
+  const Dtype kRate = 0.01;
+  const Dtype kDecay = 0.1;
+  const Dtype kMu = 0.9;
+  int snapshot_iters = 1;
+  while (snapshot_iters <= kMaxSnapshotIters) {
+    this->TestSnapshot(kRate, kDecay, kMu, snapshot_iters);
+    ++snapshot_iters;
+  }
+}
+
+
 template <typename TypeParam>
 class AdaGradSolverTest : public GradientBasedSolverTest<TypeParam> {
   typedef typename TypeParam::Dtype Dtype;
@@ -482,6 +561,18 @@ TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
       kIterSize);
 }
 
+// Runs the fixture's TestSnapshot check for the AdaGrad solver with non-zero
+// weight decay and momentum left at zero, once for each num_iters value from
+// 1 through kMaxSnapshotIters.
+TYPED_TEST(AdaGradSolverTest, TestSnapshot) {
+  typedef typename TypeParam::Dtype Dtype;
+  const int kMaxSnapshotIters = 4;
+  const Dtype kRate = 0.01;
+  const Dtype kDecay = 0.1;
+  const Dtype kMu = 0.0;
+  int snapshot_iters = 1;
+  while (snapshot_iters <= kMaxSnapshotIters) {
+    this->TestSnapshot(kRate, kDecay, kMu, snapshot_iters);
+    ++snapshot_iters;
+  }
+}
+
+
 template <typename TypeParam>
 class NesterovSolverTest : public GradientBasedSolverTest<TypeParam> {
   typedef typename TypeParam::Dtype Dtype;
@@ -558,4 +649,15 @@ TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
       kIterSize);
 }
 
+// Runs the fixture's TestSnapshot check for the Nesterov solver, once for
+// each num_iters value from 1 through kNumIters.
+//
+// Momentum is non-zero on purpose: with momentum 0 the history restored
+// from the snapshot would contribute nothing to the updates that follow,
+// so the test could not detect a history that failed to round-trip.
+// NOTE(review): value chosen to match SGDSolverTest.TestSnapshot; confirm
+// the test still passes on all device/dtype combinations.
+TYPED_TEST(NesterovSolverTest, TestSnapshot) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.1;
+  const Dtype kMomentum = 0.9;
+  const int kNumIters = 4;
+  for (int i = 1; i <= kNumIters; ++i) {
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
 }  // namespace caffe