virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
void DisplayOutputBlobs(const int net_id);
void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss);
+ /// Harmonize solver class type with configured proto type.
+ void CheckType(SolverParameter* param);
SolverParameter param_;
int iter_;
requested_early_exit_(false) {
SolverParameter param;
ReadSolverParamsFromTextFileOrDie(param_file, ¶m);
+ CheckType(¶m);
Init(param);
}
template <typename Dtype>
+void Solver<Dtype>::CheckType(SolverParameter* param) {
+ // Harmonize solver class type with configured type to avoid confusion.
+ if (param->has_type()) {
+ CHECK_EQ(param->type(), this->type())
+ << "Solver type must agree with instantiated solver class.";
+ } else {
+ param->set_type(this->type());
+ }
+}
+
+template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
CHECK(Caffe::root_solver() || root_solver_)
<< "root_solver_ needs to be set for all non-root solvers";
}
}
+TYPED_TEST(SGDSolverTest, TestSolverType) {
+ this->TestLeastSquaresUpdate();
+ EXPECT_NE(this->solver_->type(), string(""));
+ EXPECT_EQ(this->solver_->type(), this->solver_->param().type());
+}
template <typename TypeParam>
class AdaGradSolverTest : public GradientBasedSolverTest<TypeParam> {