98c872dc21d573f1f3b70c9f8ab4212ac88e80f3
[platform/upstream/caffeonacl.git] / src / caffe / optimization / solver.hpp
1 // Copyright Yangqing Jia 2013
2
3 #ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
4 #define CAFFE_OPTIMIZATION_SOLVER_HPP_
5
6 #include <vector>
7
8 namespace caffe {
9
10 template <typename Dtype>
11 class Solver {
12  public:
13   explicit Solver(const SolverParameter& param)
14       : param_(param) {}
15   // The main entry of the solver function. In default, iter will be zero. Pass
16   // in a non-zero iter number to resume training for a pre-trained net.
17   void Solve(Net<Dtype>* net, char* state_file = NULL);
18   virtual ~Solver() {}
19
20  protected:
21   // PreSolve is run before any solving iteration starts, allowing one to
22   // put up some scaffold.
23   virtual void PreSolve() {}
24   // Get the update value for the current iteration.
25   virtual void ComputeUpdateValue() = 0;
26   // The Solver::Snapshot function implements the basic snapshotting utility
27   // that stores the learned net. You should implement the SnapshotSolverState()
28   // function that produces a SolverState protocol buffer that needs to be
29   // written to disk together with the learned net.
30   void Snapshot();
31   virtual void SnapshotSolverState(SolverState* state) = 0;
32   // The Restore function implements how one should restore the solver to a
33   // previously snapshotted state. You should implement the RestoreSolverState()
34   // function that restores the state from a SolverState protocol buffer.
35   void Restore(char* state_file);
36   virtual void RestoreSolverState(const SolverState& state) = 0;
37   SolverParameter param_;
38   int iter_;
39   Net<Dtype>* net_;
40
41   DISABLE_COPY_AND_ASSIGN(Solver);
42 };
43
44
45 template <typename Dtype>
46 class SGDSolver : public Solver<Dtype> {
47  public:
48   explicit SGDSolver(const SolverParameter& param)
49       : Solver<Dtype>(param) {}
50
51  protected:
52   virtual void PreSolve();
53   virtual Dtype GetLearningRate();
54   virtual void ComputeUpdateValue();
55   virtual void SnapshotSolverState(SolverState * state);
56   virtual void RestoreSolverState(const SolverState& state);
57   // history maintains the historical momentum data.
58   vector<shared_ptr<Blob<Dtype> > > history_;
59 };
60
61
62 }  // namspace caffe
63
64 #endif  // CAFFE_OPTIMIZATION_SOLVER_HPP_