multiple test_iter
authorTobias Domhan <tdomhan@gmail.com>
Sat, 10 May 2014 16:41:05 +0000 (18:41 +0200)
committerTobias Domhan <tdomhan@gmail.com>
Sat, 10 May 2014 16:41:05 +0000 (18:41 +0200)
src/caffe/proto/caffe.proto
src/caffe/solver.cpp

index cf3a9b7..12c4dc6 100644 (file)
@@ -68,7 +68,7 @@ message SolverParameter {
   optional NetParameter train_net_param = 21; // Full params for the train net.
   repeated NetParameter test_net_param = 22; // Full params for the test nets.
   // The number of iterations for each testing phase.
-  optional int32 test_iter = 3 [default = 0];
+  repeated int32 test_iter = 3;
   // The number of iterations between two testing phases.
   optional int32 test_interval = 4 [default = 0];
   optional bool test_compute_loss = 19 [default = false];
index e3dc705..e68f719 100644 (file)
@@ -53,7 +53,7 @@ void Solver<Dtype>::Init(const SolverParameter& param) {
   const int num_test_net_files = param_.test_net_size();
   const int num_test_nets = num_test_net_params + num_test_net_files;
   if (num_test_nets) {
-    CHECK_GT(param_.test_iter(), 0);
+    CHECK_EQ(param_.test_iter().size(), num_test_nets) << "you need to specify test_iter for each test network.";
     CHECK_GT(param_.test_interval(), 0);
   }
   test_nets_.resize(num_test_nets);
@@ -141,7 +141,7 @@ void Solver<Dtype>::Test(const int test_net_id) {
   vector<Dtype> test_score;
   vector<Blob<Dtype>*> bottom_vec;
   Dtype loss = 0;
-  for (int i = 0; i < param_.test_iter(); ++i) {
+  for (int i = 0; i < param_.test_iter().Get(test_net_id); ++i) {
     Dtype iter_loss;
     const vector<Blob<Dtype>*>& result =
         test_nets_[test_net_id]->Forward(bottom_vec, &iter_loss);
@@ -166,12 +166,12 @@ void Solver<Dtype>::Test(const int test_net_id) {
     }
   }
   if (param_.test_compute_loss()) {
-    loss /= param_.test_iter();
+    loss /= param_.test_iter().Get(test_net_id);
     LOG(INFO) << "Test loss: " << loss;
   }
   for (int i = 0; i < test_score.size(); ++i) {
     LOG(INFO) << "Test score #" << i << ": "
-        << test_score[i] / param_.test_iter();
+        << test_score[i] / param_.test_iter().Get(test_net_id);
   }
   Caffe::set_phase(Caffe::TRAIN);
 }