#include "caffe/blob.hpp"
#include "caffe/filler.hpp"
#include "caffe/layer.hpp"
-#include "caffe/layer_factory.hpp"
#include "caffe/net.hpp"
#include "caffe/vision_layers.hpp"
}
}
+// The layer factory function
+// Returns a new Layer instance whose concrete type is selected from the
+// given LayerParameter. Explicit instantiations for float and double are
+// provided in the corresponding .cpp file.
+// NOTE(review): returns a raw pointer; ownership presumably passes to the
+// caller — confirm against call sites.
+template <typename Dtype>
+Layer<Dtype>* GetLayer(const LayerParameter& param);
+
} // namespace caffe
#endif // CAFFE_LAYER_H_
return (Layer<Dtype>*)(NULL);
}
+// Explicitly instantiate the layer factory for the float and double types
+// declared in the header.
+template Layer<float>* GetLayer(const LayerParameter& param);
+template Layer<double>* GetLayer(const LayerParameter& param);
} // namespace caffe
#include <vector>
#include "caffe/proto/caffe.pb.h"
-#include "caffe/layer_factory.hpp"
+#include "caffe/layer.hpp"
#include "caffe/net.hpp"
using std::pair;
namespace caffe {
template <typename Dtype>
-void Solver<Dtype>::Solve(Net<Dtype>* net) {
+void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) {
net_ = net;
LOG(INFO) << "Solving " << net_->name();
PreSolve();
+
iter_ = 0;
+ if (resume_file) {
+ LOG(INFO) << "Restoring previous solver status from " << resume_file;
+ Restore(resume_file);
+ }
+
// For a network that is trained by the solver, no bottom or top vecs
// should be given, and we will just provide dummy vecs.
vector<Blob<Dtype>*> bottom_vec;
sprintf(iter_str_buffer, "_iter_%d", iter_);
filename += iter_str_buffer;
}
- LOG(ERROR) << "Snapshotting to " << filename;
+ LOG(INFO) << "Snapshotting to " << filename;
WriteProtoToBinaryFile(net_param, filename.c_str());
+ SolverState state;
+ SnapshotSolverState(&state);
+ state.set_iter(iter_);
+ state.set_learned_net(filename);
+ filename += ".solverstate";
+ LOG(INFO) << "Snapshotting solver state to " << filename;
+ WriteProtoToBinaryFile(state, filename.c_str());
+}
+
+// Restore the solver from a previously snapshotted SolverState file:
+// reload the learned net recorded in the state, resume the iteration
+// counter, and let the concrete solver restore its own state.
+template <typename Dtype>
+void Solver<Dtype>::Restore(char* state_file) {
+  SolverState state;
+  NetParameter net_param;
+  // Read the snapshotted solver state, then the learned net file it
+  // recorded via state.learned_net at snapshot time.
+  ReadProtoFromBinaryFile(state_file, &state);
+  ReadProtoFromBinaryFile(state.learned_net().c_str(), &net_param);
+  net_->CopyTrainedLayersFrom(net_param);
+  // Resume iteration count, then dispatch to the subclass hook (e.g. the
+  // SGD solver restores its momentum history here).
+  iter_ = state.iter();
+  RestoreSolverState(state);
}
}
}
+// Snapshot the SGD-specific solver state: copy every momentum history blob
+// into the SolverState protocol buffer so training can be resumed later.
+template <typename Dtype>
+void SGDSolver<Dtype>::SnapshotSolverState(SolverState* state) {
+  state->clear_history();
+  for (int i = 0; i < history_.size(); ++i) {
+    // Append this parameter's history blob to the state.
+    BlobProto* history_blob = state->add_history();
+    history_[i]->ToProto(history_blob);
+  }
+}
+
+// Restore the SGD-specific solver state: read the momentum history blobs
+// back from a snapshotted SolverState. The number of history blobs in the
+// state must match the number maintained by the current solver.
+template <typename Dtype>
+void SGDSolver<Dtype>::RestoreSolverState(const SolverState& state) {
+  CHECK_EQ(state.history_size(), history_.size())
+      << "Incorrect length of history blobs.";
+  for (int i = 0; i < history_.size(); ++i) {
+    history_[i]->FromProto(state.history(i));
+  }
+}
INSTANTIATE_CLASS(Solver);
INSTANTIATE_CLASS(SGDSolver);
public:
explicit Solver(const SolverParameter& param)
: param_(param) {}
- // The main entry of the solver function.
- void Solve(Net<Dtype>* net);
+  // The main entry of the solver function. By default, iter will be zero.
+  // Pass in a previously snapshotted solver state file to resume training
+  // from that checkpoint.
+ void Solve(Net<Dtype>* net, char* state_file = NULL);
virtual ~Solver() {}
protected:
virtual void PreSolve() {}
// Get the update value for the current iteration.
virtual void ComputeUpdateValue() = 0;
+ // The Solver::Snapshot function implements the basic snapshotting utility
+ // that stores the learned net. You should implement the SnapshotSolverState()
+ // function that produces a SolverState protocol buffer that needs to be
+ // written to disk together with the learned net.
void Snapshot(bool is_final = false);
+ virtual void SnapshotSolverState(SolverState* state) = 0;
+ // The Restore function implements how one should restore the solver to a
+ // previously snapshotted state. You should implement the RestoreSolverState()
+ // function that restores the state from a SolverState protocol buffer.
+ void Restore(char* state_file);
+ virtual void RestoreSolverState(const SolverState& state) = 0;
SolverParameter param_;
int iter_;
Net<Dtype>* net_;
protected:
virtual void PreSolve();
- Dtype GetLearningRate();
+ virtual Dtype GetLearningRate();
virtual void ComputeUpdateValue();
+ virtual void SnapshotSolverState(SolverState * state);
+ virtual void RestoreSolverState(const SolverState& state);
// history maintains the historical momentum data.
vector<shared_ptr<Blob<Dtype> > > history_;
};
optional float stepsize = 12; // the stepsize for learning rate policy "step"
optional string snapshot_prefix = 13; // The prefix for the snapshot.
+
+ // Adagrad solver parameters
+ // For Adagrad, we will first run normal sgd using the sgd parameters above
+ // for adagrad_skip iterations, and then kick in the adagrad algorithm, with
+ // the learning rate being adagrad_gamma * adagrad_skip. Note that the adagrad
+ // algorithm will NOT use the learning rate multiplier that is specified in
+ // the layer parameter specifications, as it will adjust the learning rate
+ // of individual parameters in a data-dependent way.
+ // WORK IN PROGRESS: not actually implemented yet.
+ optional float adagrad_gamma = 14; // adagrad learning rate multiplier
+ optional float adagrad_skip = 15; // the steps to skip before adagrad kicks in
}
+
+// A message that stores the solver snapshots. It is written alongside the
+// snapshotted net (with a ".solverstate" suffix) and read back by
+// Solver::Restore() to resume training.
+message SolverState {
+  optional int32 iter = 1; // The current iteration
+  optional string learned_net = 2; // The file that stores the learned net.
+  repeated BlobProto history = 3; // The history for sgd solvers
+}
\ No newline at end of file
pthread_t thread_;
shared_ptr<Blob<Dtype> > prefetch_data_;
shared_ptr<Blob<Dtype> > prefetch_label_;
+ Blob<Dtype> data_mean_;
};