From 2cd46db70d7625968cbf232d70723825d74d5040 Mon Sep 17 00:00:00 2001 From: Tobias Domhan Date: Sat, 10 May 2014 18:41:05 +0200 Subject: [PATCH] multiple test_iter --- src/caffe/proto/caffe.proto | 2 +- src/caffe/solver.cpp | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index cf3a9b7..12c4dc6 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -68,7 +68,7 @@ message SolverParameter { optional NetParameter train_net_param = 21; // Full params for the train net. repeated NetParameter test_net_param = 22; // Full params for the test nets. // The number of iterations for each testing phase. - optional int32 test_iter = 3 [default = 0]; + repeated int32 test_iter = 3; // The number of iterations between two testing phases. optional int32 test_interval = 4 [default = 0]; optional bool test_compute_loss = 19 [default = false]; diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index e3dc705..e68f719 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -53,7 +53,7 @@ void Solver::Init(const SolverParameter& param) { const int num_test_net_files = param_.test_net_size(); const int num_test_nets = num_test_net_params + num_test_net_files; if (num_test_nets) { - CHECK_GT(param_.test_iter(), 0); + CHECK_EQ(param_.test_iter().size(), num_test_nets) << "you need to specify test_iter for each test network."; CHECK_GT(param_.test_interval(), 0); } test_nets_.resize(num_test_nets); @@ -141,7 +141,7 @@ void Solver::Test(const int test_net_id) { vector test_score; vector*> bottom_vec; Dtype loss = 0; - for (int i = 0; i < param_.test_iter(); ++i) { + for (int i = 0; i < param_.test_iter().Get(test_net_id); ++i) { Dtype iter_loss; const vector*>& result = test_nets_[test_net_id]->Forward(bottom_vec, &iter_loss); @@ -166,12 +166,12 @@ void Solver::Test(const int test_net_id) { } } if (param_.test_compute_loss()) { - loss /= param_.test_iter(); + loss /= param_.test_iter().Get(test_net_id); LOG(INFO) << "Test loss: " << loss; } for (int i = 0; i < test_score.size(); ++i) { LOG(INFO) << "Test score #" << i << ": " - << test_score[i] / param_.test_iter(); + << test_score[i] / param_.test_iter().Get(test_net_id); } Caffe::set_phase(Caffe::TRAIN); } -- 2.7.4