the solver checks its proto type (SolverParameter.type) on
instantiation:
- if the proto type is unspecified it's set according to the class type
`Solver::type()`
- if the proto type and class type conflict, the solver dies loudly
this helps avoid accidental instantiation of a different solver type
than intended when the solver def and class differ. guaranteed type
information in the SolverParameter will simplify multi-solver
coordination too.
virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
void DisplayOutputBlobs(const int net_id);
void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss);
virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
void DisplayOutputBlobs(const int net_id);
void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss);
+ /// Harmonize solver class type with configured proto type.
+ void CheckType(SolverParameter* param);
SolverParameter param_;
int iter_;
SolverParameter param_;
int iter_;
requested_early_exit_(false) {
SolverParameter param;
ReadSolverParamsFromTextFileOrDie(param_file, ¶m);
requested_early_exit_(false) {
SolverParameter param;
ReadSolverParamsFromTextFileOrDie(param_file, ¶m);
Init(param);
}
template <typename Dtype>
Init(param);
}
template <typename Dtype>
+void Solver<Dtype>::CheckType(SolverParameter* param) {
+ // Harmonize solver class type with configured type to avoid confusion.
+ if (param->has_type()) {
+ CHECK_EQ(param->type(), this->type())
+ << "Solver type must agree with instantiated solver class.";
+ } else {
+ param->set_type(this->type());
+ }
+}
+
+template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
CHECK(Caffe::root_solver() || root_solver_)
<< "root_solver_ needs to be set for all non-root solvers";
void Solver<Dtype>::Init(const SolverParameter& param) {
CHECK(Caffe::root_solver() || root_solver_)
<< "root_solver_ needs to be set for all non-root solvers";
+TYPED_TEST(SGDSolverTest, TestSolverType) {
+ this->TestLeastSquaresUpdate();
+ EXPECT_NE(this->solver_->type(), string(""));
+ EXPECT_EQ(this->solver_->type(), this->solver_->param().type());
+}
template <typename TypeParam>
class AdaGradSolverTest : public GradientBasedSolverTest<TypeParam> {
template <typename TypeParam>
class AdaGradSolverTest : public GradientBasedSolverTest<TypeParam> {