From: Jeff Donahue
Date: Sat, 10 May 2014 03:28:55 +0000 (-0700)
Subject: allow multiple test nets
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c97fff670299c8600ddd8b0ba54bbd7c325ee2f5;p=platform%2Fupstream%2Fcaffe.git

allow multiple test nets
---

diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index aef9b22..3112c59 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -33,7 +33,8 @@ class Solver {
   // written to disk together with the learned net.
   void Snapshot();
   // The test routine
-  void Test();
+  void TestAll();
+  void Test(const int test_net_id = 0);
   virtual void SnapshotSolverState(SolverState* state) = 0;
   // The Restore function implements how one should restore the solver to a
   // previously snapshotted state. You should implement the RestoreSolverState()
@@ -44,7 +45,7 @@ class Solver {
   SolverParameter param_;
   int iter_;
   shared_ptr<Net<Dtype> > net_;
-  shared_ptr<Net<Dtype> > test_net_;
+  vector<shared_ptr<Net<Dtype> > > test_nets_;
 
   DISABLE_COPY_AND_ASSIGN(Solver);
 };
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index eb086a1..cf3a9b7 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -64,9 +64,9 @@ message SolverParameter {
   // If {train,test}_net is specified, {train,test}_net_param should not be,
   // and vice versa.
   optional string train_net = 1; // The proto filename for the train net.
-  optional string test_net = 2; // The proto filename for the test net.
+  repeated string test_net = 2; // The proto filenames for the test nets.
   optional NetParameter train_net_param = 21; // Full params for the train net.
-  optional NetParameter test_net_param = 22; // Full params for the test net.
+  repeated NetParameter test_net_param = 22; // Full params for the test nets.
   // The number of iterations for each testing phase.
   optional int32 test_iter = 3 [default = 0];
   // The number of iterations between two testing phases.
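Note: with test_net and test_net_param now repeated fields, protoc generates the usual repeated-field accessors for them (add_test_net(), test_net_size(), test_net(i), and likewise for test_net_param), which the updated Solver::Init() below relies on. As a minimal sketch (not part of this patch; the file names are hypothetical), a SolverParameter carrying two test nets could be built like this:

    #include "caffe/proto/caffe.pb.h"

    // Sketch: one training net plus two test nets, using the accessors
    // protoc generates once test_net is a repeated field.
    caffe::SolverParameter MakeMultiTestParam() {
      caffe::SolverParameter param;
      param.set_train_net("examples/train.prototxt");  // hypothetical paths
      param.add_test_net("examples/val_a.prototxt");   // becomes test net #0
      param.add_test_net("examples/val_b.prototxt");   // becomes test net #1
      param.set_test_iter(100);       // shared by every test net
      param.set_test_interval(1000);  // test all nets every 1000 iterations
      return param;
    }
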
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 9420ca3..e3dc705 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -19,13 +19,13 @@ namespace caffe {
 
 template <typename Dtype>
 Solver<Dtype>::Solver(const SolverParameter& param)
-    : net_(), test_net_() {
+    : net_() {
   Init(param);
 }
 
 template <typename Dtype>
 Solver<Dtype>::Solver(const string& param_file)
-    : net_(), test_net_() {
+    : net_() {
   SolverParameter param;
   ReadProtoFromTextFile(param_file, &param);
   Init(param);
@@ -49,23 +49,25 @@ void Solver<Dtype>::Init(const SolverParameter& param) {
     LOG(INFO) << "Creating training net from file: " << param_.train_net();
     net_.reset(new Net<Dtype>(param_.train_net()));
   }
-  bool has_test_net = false;
-  NetParameter test_net_param;
-  if (param_.has_test_net_param()) {
-    CHECK(!param_.has_test_net()) << "Either test_net_param or test_net may be "
-        << "specified, but not both.";
-    LOG(INFO) << "Creating testing net specified in SolverParameter.";
-    test_net_.reset(new Net<Dtype>(param_.test_net_param()));
-    has_test_net = true;
-  } else if (param_.has_test_net()) {
-    LOG(INFO) << "Creating testing net from file: " << param_.test_net();
-    test_net_.reset(new Net<Dtype>(param_.test_net()));
-    has_test_net = true;
-  }
-  if (has_test_net) {
+  const int num_test_net_params = param_.test_net_param_size();
+  const int num_test_net_files = param_.test_net_size();
+  const int num_test_nets = num_test_net_params + num_test_net_files;
+  if (num_test_nets) {
     CHECK_GT(param_.test_iter(), 0);
     CHECK_GT(param_.test_interval(), 0);
   }
+  test_nets_.resize(num_test_nets);
+  for (int i = 0; i < num_test_net_params; ++i) {
+    LOG(INFO) << "Creating testing net (#" << i
+              << ") specified in SolverParameter.";
+    test_nets_[i].reset(new Net<Dtype>(param_.test_net_param(i)));
+  }
+  for (int i = 0, test_net_id = num_test_net_params;
+       i < num_test_net_files; ++i, ++test_net_id) {
+    LOG(INFO) << "Creating testing net (#" << test_net_id
+              << ") from file: " << param.test_net(i);
+    test_nets_[test_net_id].reset(new Net<Dtype>(param_.test_net(i)));
+  }
   LOG(INFO) << "Solver scaffolding done.";
 }
 
@@ -92,7 +94,7 @@ void Solver<Dtype>::Solve(const char* resume_file) {
   // there's not enough memory to run the test net and crash, etc.; and to gauge
   // the effect of the first training iterations.
   if (param_.test_interval()) {
-    Test();
+    TestAll();
  }
 
   // For a network that is trained by the solver, no bottom or top vecs
@@ -107,7 +109,7 @@ void Solver<Dtype>::Solve(const char* resume_file) {
       LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
     }
     if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
-      Test();
+      TestAll();
     }
     // Check if we need to do snapshot
     if (param_.snapshot() && iter_ % param_.snapshot() == 0) {
@@ -122,18 +124,27 @@
 
 
 template <typename Dtype>
-void Solver<Dtype>::Test() {
-  LOG(INFO) << "Iteration " << iter_ << ", Testing net";
+void Solver<Dtype>::TestAll() {
+  for (int test_net_id = 0; test_net_id < test_nets_.size(); ++test_net_id) {
+    Test(test_net_id);
+  }
+}
+
+
+template <typename Dtype>
+void Solver<Dtype>::Test(const int test_net_id) {
+  LOG(INFO) << "Iteration " << iter_ << ", Testing net (#" << test_net_id << ")";
   // We need to set phase to test before running.
   Caffe::set_phase(Caffe::TEST);
-  CHECK_NOTNULL(test_net_.get())->ShareTrainedLayersWith(net_.get());
+  CHECK_NOTNULL(test_nets_[test_net_id].get())->
+      ShareTrainedLayersWith(net_.get());
   vector<Dtype> test_score;
   vector<Blob<Dtype>*> bottom_vec;
   Dtype loss = 0;
   for (int i = 0; i < param_.test_iter(); ++i) {
     Dtype iter_loss;
     const vector<Blob<Dtype>*>& result =
-        test_net_->Forward(bottom_vec, &iter_loss);
+        test_nets_[test_net_id]->Forward(bottom_vec, &iter_loss);
     if (param_.test_compute_loss()) {
       loss += iter_loss;
     }
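
For context, a hedged usage sketch (not part of this patch; the solver prototxt path and its contents are hypothetical): with the repeated test_net field, a solver definition can simply list each test net on its own line, and Solve() now calls TestAll(), which runs Test() once per configured net both before training and every test_interval iterations.

    #include "caffe/proto/caffe.pb.h"
    #include "caffe/solver.hpp"
    #include "caffe/util/io.hpp"

    int main() {
      // examples/multitest_solver.prototxt (hypothetical) might contain:
      //   train_net: "examples/train.prototxt"
      //   test_net: "examples/val_a.prototxt"   # test net #0
      //   test_net: "examples/val_b.prototxt"   # test net #1
      //   test_iter: 100
      //   test_interval: 1000
      caffe::SolverParameter param;
      caffe::ReadProtoFromTextFile("examples/multitest_solver.prototxt", &param);

      // SGDSolver is the concrete Solver subclass declared in solver.hpp;
      // during Solve(), TestAll() calls Test(0) and Test(1), running
      // test_iter forward passes on each test net.
      caffe::SGDSolver<float> solver(param);
      solver.Solve();
      return 0;
    }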