From: Evan Shelhamer Date: Mon, 21 Nov 2016 17:35:57 +0000 (-0800) Subject: solver: check and set type to reconcile class and proto X-Git-Tag: submit/tizen/20180823.020014~119^2 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=e52451de914312b80a83459cb160c2f72a5b4fea;p=platform%2Fupstream%2Fcaffeonacl.git solver: check and set type to reconcile class and proto the solver checks its proto type (SolverParameter.type) on instantiation: - if the proto type is unspecified it's set according to the class type `Solver::type()` - if the proto type and class type conflict, the solver dies loudly this helps avoid accidental instantiation of a different solver type than intended when the solver def and class differ. guaranteed type information in the SolverParameter will simplify multi-solver coordination too. --- diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index eafcee3..ef38d6e 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -108,6 +108,8 @@ class Solver { virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0; void DisplayOutputBlobs(const int net_id); void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss); + /// Harmonize solver class type with configured proto type. + void CheckType(SolverParameter* param); SolverParameter param_; int iter_; diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index ece3913..ae6a5a3 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -38,10 +38,22 @@ Solver::Solver(const string& param_file, const Solver* root_solver) requested_early_exit_(false) { SolverParameter param; ReadSolverParamsFromTextFileOrDie(param_file, ¶m); + CheckType(¶m); Init(param); } template +void Solver::CheckType(SolverParameter* param) { + // Harmonize solver class type with configured type to avoid confusion. + if (param->has_type()) { + CHECK_EQ(param->type(), this->type()) + << "Solver type must agree with instantiated solver class."; + } else { + param->set_type(this->type()); + } +} + +template void Solver::Init(const SolverParameter& param) { CHECK(Caffe::root_solver() || root_solver_) << "root_solver_ needs to be set for all non-root solvers"; diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp index 975a8f0..e81caea 100644 --- a/src/caffe/test/test_gradient_based_solver.cpp +++ b/src/caffe/test/test_gradient_based_solver.cpp @@ -694,6 +694,11 @@ TYPED_TEST(SGDSolverTest, TestSnapshotShare) { } } +TYPED_TEST(SGDSolverTest, TestSolverType) { + this->TestLeastSquaresUpdate(); + EXPECT_NE(this->solver_->type(), string("")); + EXPECT_EQ(this->solver_->type(), this->solver_->param().type()); +} template class AdaGradSolverTest : public GradientBasedSolverTest {