solver: check and set type to reconcile class and proto
author Evan Shelhamer <shelhamer@imaginarynumber.net>
Mon, 21 Nov 2016 17:35:57 +0000 (09:35 -0800)
committer Evan Shelhamer <shelhamer@imaginarynumber.net>
Mon, 21 Nov 2016 17:35:57 +0000 (09:35 -0800)
the solver checks its proto type (SolverParameter.type) on
instantiation:

- if the proto type is unspecified, it is set to the class type returned by
  `Solver::type()`
- if the proto type and class type conflict, the solver dies loudly

this helps avoid accidental instantiation of a different solver type
than intended when the solver def and class differ. guaranteed type
information in the SolverParameter will simplify multi-solver
coordination too.

include/caffe/solver.hpp
src/caffe/solver.cpp
src/caffe/test/test_gradient_based_solver.cpp

index eafcee3..ef38d6e 100644 (file)
@@ -108,6 +108,8 @@ class Solver {
   virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
   void DisplayOutputBlobs(const int net_id);
   void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss);
+  /// Harmonize solver class type with configured proto type.
+  void CheckType(SolverParameter* param);
 
   SolverParameter param_;
   int iter_;
index ece3913..ae6a5a3 100644 (file)
@@ -38,10 +38,22 @@ Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
       requested_early_exit_(false) {
   SolverParameter param;
   ReadSolverParamsFromTextFileOrDie(param_file, &param);
+  CheckType(&param);
   Init(param);
 }
 
 template <typename Dtype>
+void Solver<Dtype>::CheckType(SolverParameter* param) {
+  // Harmonize solver class type with configured type to avoid confusion.
+  if (param->has_type()) {  // proto names a type: it must equal type()
+    CHECK_EQ(param->type(), this->type())
+        << "Solver type must agree with instantiated solver class.";
+  } else {  // proto type unset: write type() into *param
+    param->set_type(this->type());
+  }
+}  // NOTE(review): run from the ctor; confirm type() there is the subclass's.
+
+template <typename Dtype>
 void Solver<Dtype>::Init(const SolverParameter& param) {
   CHECK(Caffe::root_solver() || root_solver_)
       << "root_solver_ needs to be set for all non-root solvers";
index 975a8f0..e81caea 100644 (file)
@@ -694,6 +694,11 @@ TYPED_TEST(SGDSolverTest, TestSnapshotShare) {
   }
 }
 
+TYPED_TEST(SGDSolverTest, TestSolverType) {  // type() non-empty, == proto type
+  this->TestLeastSquaresUpdate();  // presumably builds solver_; verify
+  EXPECT_NE(this->solver_->type(), string(""));
+  EXPECT_EQ(this->solver_->type(), this->solver_->param().type());
+}
 
 template <typename TypeParam>
 class AdaGradSolverTest : public GradientBasedSolverTest<TypeParam> {