allow multiple test nets
authorJeff Donahue <jeff.donahue@gmail.com>
Sat, 10 May 2014 03:28:55 +0000 (20:28 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Sat, 10 May 2014 04:02:46 +0000 (21:02 -0700)
include/caffe/solver.hpp
src/caffe/proto/caffe.proto
src/caffe/solver.cpp

index aef9b22..3112c59 100644 (file)
@@ -33,7 +33,8 @@ class Solver {
   // written to disk together with the learned net.
   void Snapshot();
   // The test routine
-  void Test();
+  void TestAll();
+  void Test(const int test_net_id = 0);
   virtual void SnapshotSolverState(SolverState* state) = 0;
   // The Restore function implements how one should restore the solver to a
   // previously snapshotted state. You should implement the RestoreSolverState()
@@ -44,7 +45,7 @@ class Solver {
   SolverParameter param_;
   int iter_;
   shared_ptr<Net<Dtype> > net_;
-  shared_ptr<Net<Dtype> > test_net_;
+  vector<shared_ptr<Net<Dtype> > > test_nets_;
 
   DISABLE_COPY_AND_ASSIGN(Solver);
 };
index eb086a1..cf3a9b7 100644 (file)
@@ -64,9 +64,9 @@ message SolverParameter {
   // If {train,test}_net is specified, {train,test}_net_param should not be,
   // and vice versa.
   optional string train_net = 1; // The proto filename for the train net.
-  optional string test_net = 2; // The proto filename for the test net.
+  repeated string test_net = 2; // The proto filenames for the test nets.
   optional NetParameter train_net_param = 21; // Full params for the train net.
-  optional NetParameter test_net_param = 22; // Full params for the test net.
+  repeated NetParameter test_net_param = 22; // Full params for the test nets.
   // The number of iterations for each testing phase.
   optional int32 test_iter = 3 [default = 0];
   // The number of iterations between two testing phases.
index 9420ca3..e3dc705 100644 (file)
@@ -19,13 +19,13 @@ namespace caffe {
 
 template <typename Dtype>
 Solver<Dtype>::Solver(const SolverParameter& param)
-    : net_(), test_net_() {
+    : net_() {
   Init(param);
 }
 
 template <typename Dtype>
 Solver<Dtype>::Solver(const string& param_file)
-    : net_(), test_net_() {
+    : net_() {
   SolverParameter param;
   ReadProtoFromTextFile(param_file, &param);
   Init(param);
@@ -49,23 +49,25 @@ void Solver<Dtype>::Init(const SolverParameter& param) {
     LOG(INFO) << "Creating training net from file: " << param_.train_net();
     net_.reset(new Net<Dtype>(param_.train_net()));
   }
-  bool has_test_net = false;
-  NetParameter test_net_param;
-  if (param_.has_test_net_param()) {
-    CHECK(!param_.has_test_net()) << "Either test_net_param or test_net may be "
-                                  << "specified, but not both.";
-    LOG(INFO) << "Creating testing net specified in SolverParameter.";
-    test_net_.reset(new Net<Dtype>(param_.test_net_param()));
-    has_test_net = true;
-  } else if (param_.has_test_net()) {
-    LOG(INFO) << "Creating testing net from file: " << param_.test_net();
-    test_net_.reset(new Net<Dtype>(param_.test_net()));
-    has_test_net = true;
-  }
-  if (has_test_net) {
+  const int num_test_net_params = param_.test_net_param_size();
+  const int num_test_net_files = param_.test_net_size();
+  const int num_test_nets = num_test_net_params + num_test_net_files;
+  if (num_test_nets) {
     CHECK_GT(param_.test_iter(), 0);
     CHECK_GT(param_.test_interval(), 0);
   }
+  test_nets_.resize(num_test_nets);
+  for (int i = 0; i < num_test_net_params; ++i) {
+      LOG(INFO) << "Creating testing net (#" << i
+                << ") specified in SolverParameter.";
+      test_nets_[i].reset(new Net<Dtype>(param_.test_net_param(i)));
+  }
+  for (int i = 0, test_net_id = num_test_net_params;
+       i < num_test_net_files; ++i, ++test_net_id) {
+      LOG(INFO) << "Creating testing net (#" << test_net_id
+                << ") from file: " << param.test_net(i);
+      test_nets_[test_net_id].reset(new Net<Dtype>(param_.test_net(i)));
+  }
   LOG(INFO) << "Solver scaffolding done.";
 }
 
@@ -92,7 +94,7 @@ void Solver<Dtype>::Solve(const char* resume_file) {
   // there's not enough memory to run the test net and crash, etc.; and to gauge
   // the effect of the first training iterations.
   if (param_.test_interval()) {
-    Test();
+    TestAll();
   }
 
   // For a network that is trained by the solver, no bottom or top vecs
@@ -107,7 +109,7 @@ void Solver<Dtype>::Solve(const char* resume_file) {
       LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
     }
     if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
-      Test();
+      TestAll();
     }
     // Check if we need to do snapshot
     if (param_.snapshot() && iter_ % param_.snapshot() == 0) {
@@ -122,18 +124,27 @@ void Solver<Dtype>::Solve(const char* resume_file) {
 
 
 template <typename Dtype>
-void Solver<Dtype>::Test() {
-  LOG(INFO) << "Iteration " << iter_ << ", Testing net";
+void Solver<Dtype>::TestAll() {
+  for (int test_net_id = 0; test_net_id < test_nets_.size(); ++test_net_id) {
+    Test(test_net_id);
+  }
+}
+
+
+template <typename Dtype>
+void Solver<Dtype>::Test(const int test_net_id) {
+  LOG(INFO) << "Iteration " << iter_ << ", Testing net (#" << test_net_id << ")";
   // We need to set phase to test before running.
   Caffe::set_phase(Caffe::TEST);
-  CHECK_NOTNULL(test_net_.get())->ShareTrainedLayersWith(net_.get());
+  CHECK_NOTNULL(test_nets_[test_net_id].get())->
+      ShareTrainedLayersWith(net_.get());
   vector<Dtype> test_score;
   vector<Blob<Dtype>*> bottom_vec;
   Dtype loss = 0;
   for (int i = 0; i < param_.test_iter(); ++i) {
     Dtype iter_loss;
     const vector<Blob<Dtype>*>& result =
-        test_net_->Forward(bottom_vec, &iter_loss);
+        test_nets_[test_net_id]->Forward(bottom_vec, &iter_loss);
     if (param_.test_compute_loss()) {
       loss += iter_loss;
     }