virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
void DisplayOutputBlobs(const int net_id);
void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss);
- /// Harmonize solver class type with configured proto type.
- void CheckType(SolverParameter* param);
SolverParameter param_;
int iter_;
requested_early_exit_(false) {
SolverParameter param;
ReadSolverParamsFromTextFileOrDie(param_file, ¶m);
- CheckType(¶m);
Init(param);
}
template <typename Dtype>
-void Solver<Dtype>::CheckType(SolverParameter* param) {
- // Harmonize solver class type with configured type to avoid confusion.
- if (param->has_type()) {
- CHECK_EQ(param->type(), this->type())
- << "Solver type must agree with instantiated solver class.";
- } else {
- param->set_type(this->type());
- }
-}
-
-template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
CHECK(Caffe::root_solver() || root_solver_)
<< "root_solver_ needs to be set for all non-root solvers";
}
}
-TYPED_TEST(SGDSolverTest, TestSolverType) {
- this->TestLeastSquaresUpdate();
- EXPECT_NE(this->solver_->type(), string(""));
- EXPECT_EQ(this->solver_->type(), this->solver_->param().type());
-}
template <typename TypeParam>
class AdaGradSolverTest : public GradientBasedSolverTest<TypeParam> {