#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/solver.hpp"
+#include "caffe/util/io.hpp"
#include "caffe/test/test_caffe_main.hpp"
GradientBasedSolverTest() :
seed_(1701), num_(4), channels_(3), height_(10), width_(10) {}
+ string snapshot_prefix_;
shared_ptr<SGDSolver<Dtype> > solver_;
int seed_;
int num_, channels_, height_, width_;
virtual void InitSolverFromProtoString(const string& proto) {
SolverParameter param;
CHECK(google::protobuf::TextFormat::ParseFromString(proto, ¶m));
- // Disable saving a final snapshot so the tests don't pollute the user's
- // working directory with useless snapshots.
- param.set_snapshot_after_train(false);
// Set the solver_mode according to current Caffe::mode.
switch (Caffe::mode()) {
case Caffe::CPU:
param.delta() : 0;
}
- void RunLeastSquaresSolver(const Dtype learning_rate,
+ string RunLeastSquaresSolver(const Dtype learning_rate,
const Dtype weight_decay, const Dtype momentum, const int num_iters,
- const int iter_size = 1) {
+ const int iter_size = 1, const bool snapshot = false,
+ const char* from_snapshot = NULL) {
ostringstream proto;
proto <<
+ "snapshot_after_train: " << snapshot << " "
"max_iter: " << num_iters << " "
"base_lr: " << learning_rate << " "
"lr_policy: 'fixed' "
if (momentum != 0) {
proto << "momentum: " << momentum << " ";
}
+ MakeTempDir(&snapshot_prefix_);
+ proto << "snapshot_prefix: '" << snapshot_prefix_ << "/' ";
+ if (snapshot) {
+ proto << "snapshot: " << num_iters << " ";
+ }
Caffe::set_random_seed(this->seed_);
this->InitSolverFromProtoString(proto.str());
+ Caffe::set_random_seed(this->seed_);
+ if (from_snapshot != NULL) {
+ this->solver_->Restore(from_snapshot);
+ vector<Blob<Dtype>*> empty_bottom_vec;
+ for (int i = 0; i < this->solver_->iter(); ++i) {
+ this->solver_->net()->Forward(empty_bottom_vec);
+ }
+ }
this->solver_->Solve();
+ if (snapshot) {
+ ostringstream resume_file;
+ resume_file << snapshot_prefix_ << "/_iter_" << num_iters
+ << ".solverstate";
+ string resume_filename = resume_file.str();
+ return resume_filename;
+ }
+ return string();
}
// Compute an update value given the current state of the train net,
// Check that the solver's solution matches ours.
CheckLeastSquaresUpdate(updated_params);
}
+
+ void TestSnapshot(const Dtype learning_rate = 1.0,
+ const Dtype weight_decay = 0.0, const Dtype momentum = 0.0,
+ const int num_iters = 1) {
+ // Run the solver for num_iters * 2 iterations.
+ const int total_num_iters = num_iters * 2;
+ bool snapshot = false;
+ const int kIterSize = 1;
+ RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
+ total_num_iters, kIterSize, snapshot);
+
+ // Save the resulting param values.
+ vector<shared_ptr<Blob<Dtype> > > param_copies;
+ const vector<shared_ptr<Blob<Dtype> > >& orig_params =
+ solver_->net()->params();
+ param_copies.resize(orig_params.size());
+ for (int i = 0; i < orig_params.size(); ++i) {
+ param_copies[i].reset(new Blob<Dtype>());
+ const bool kReshape = true;
+ for (int copy_diff = false; copy_diff <= true; ++copy_diff) {
+ param_copies[i]->CopyFrom(*orig_params[i], copy_diff, kReshape);
+ }
+ }
+
+ // Run the solver for num_iters iterations and snapshot.
+ snapshot = true;
+ string snapshot_name = RunLeastSquaresSolver(learning_rate, weight_decay,
+ momentum, num_iters, kIterSize, snapshot);
+
+ // Reinitialize the solver and run for num_iters more iterations.
+ snapshot = false;
+ RunLeastSquaresSolver(learning_rate, weight_decay,
+ momentum, total_num_iters, kIterSize, snapshot, snapshot_name.c_str());
+
+ // Check that params now match.
+ const vector<shared_ptr<Blob<Dtype> > >& params = solver_->net()->params();
+ for (int i = 0; i < params.size(); ++i) {
+ for (int j = 0; j < params[i]->count(); ++j) {
+ EXPECT_EQ(param_copies[i]->cpu_data()[j], params[i]->cpu_data()[j])
+ << "param " << i << " data differed at dim " << j;
+ EXPECT_EQ(param_copies[i]->cpu_diff()[j], params[i]->cpu_diff()[j])
+ << "param " << i << " diff differed at dim " << j;
+ }
+ }
+ }
};
kIterSize);
}
+TYPED_TEST(SGDSolverTest, TestSnapshot) {
+ typedef typename TypeParam::Dtype Dtype;
+ const Dtype kLearningRate = 0.01;
+ const Dtype kWeightDecay = 0.1;
+ const Dtype kMomentum = 0.9;
+ const int kNumIters = 4;
+ for (int i = 1; i <= kNumIters; ++i) {
+ this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+ }
+}
+
+
template <typename TypeParam>
class AdaGradSolverTest : public GradientBasedSolverTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
kIterSize);
}
+TYPED_TEST(AdaGradSolverTest, TestSnapshot) {
+ typedef typename TypeParam::Dtype Dtype;
+ const Dtype kLearningRate = 0.01;
+ const Dtype kWeightDecay = 0.1;
+ const Dtype kMomentum = 0.0;
+ const int kNumIters = 4;
+ for (int i = 1; i <= kNumIters; ++i) {
+ this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+ }
+}
+
+
template <typename TypeParam>
class NesterovSolverTest : public GradientBasedSolverTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
kIterSize);
}
+TYPED_TEST(NesterovSolverTest, TestSnapshot) {
+ typedef typename TypeParam::Dtype Dtype;
+ const Dtype kLearningRate = 0.01;
+ const Dtype kWeightDecay = 0.1;
+ const Dtype kMomentum = 0.0;
+ const int kNumIters = 4;
+ for (int i = 1; i <= kNumIters; ++i) {
+ this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+ }
+}
+
} // namespace caffe