Merge pull request #5184 from shaibagon/fix_batch_norm_param_upgrade
[platform/upstream/caffeonacl.git] / include / caffe / solver_factory.hpp
1 /**
 * @brief A solver factory that allows one to register solvers, similar to the
 * layer factory. At runtime, registered solvers can be created by passing a
 * SolverParameter protobuf message to the CreateSolver function:
5  *
6  *     SolverRegistry<Dtype>::CreateSolver(param);
7  *
8  * There are two ways to register a solver. Assuming that we have a solver like:
9  *
10  *   template <typename Dtype>
11  *   class MyAwesomeSolver : public Solver<Dtype> {
12  *     // your implementations
13  *   };
14  *
15  * and its type is its C++ class name, but without the "Solver" at the end
16  * ("MyAwesomeSolver" -> "MyAwesome").
17  *
18  * If the solver is going to be created simply by its constructor, in your C++
19  * file, add the following line:
20  *
21  *    REGISTER_SOLVER_CLASS(MyAwesome);
22  *
 * Or, if the solver is going to be created by another creator function, in the
 * format of:
 *
 *    template <typename Dtype>
 *    Solver<Dtype>* GetMyAwesomeSolver(const SolverParameter& param) {
 *      // your implementation
 *    }
 *
 * then you can register the creator function instead, like:
 *
 *    REGISTER_SOLVER_CREATOR(MyAwesome, GetMyAwesomeSolver);
34  *
35  * Note that each solver type should only be registered once.
36  */
37
38 #ifndef CAFFE_SOLVER_FACTORY_H_
39 #define CAFFE_SOLVER_FACTORY_H_
40
41 #include <map>
42 #include <string>
43 #include <vector>
44
45 #include "caffe/common.hpp"
46 #include "caffe/proto/caffe.pb.h"
47
48 namespace caffe {
49
// Forward declaration only: this header handles Solver<Dtype> exclusively
// through pointers (creator return types), so the complete type is not needed.
template <class Dtype>
class Solver;
52
53 template <typename Dtype>
54 class SolverRegistry {
55  public:
56   typedef Solver<Dtype>* (*Creator)(const SolverParameter&);
57   typedef std::map<string, Creator> CreatorRegistry;
58
59   static CreatorRegistry& Registry() {
60     static CreatorRegistry* g_registry_ = new CreatorRegistry();
61     return *g_registry_;
62   }
63
64   // Adds a creator.
65   static void AddCreator(const string& type, Creator creator) {
66     CreatorRegistry& registry = Registry();
67     CHECK_EQ(registry.count(type), 0)
68         << "Solver type " << type << " already registered.";
69     registry[type] = creator;
70   }
71
72   // Get a solver using a SolverParameter.
73   static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
74     const string& type = param.type();
75     CreatorRegistry& registry = Registry();
76     CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
77         << " (known types: " << SolverTypeListString() << ")";
78     return registry[type](param);
79   }
80
81   static vector<string> SolverTypeList() {
82     CreatorRegistry& registry = Registry();
83     vector<string> solver_types;
84     for (typename CreatorRegistry::iterator iter = registry.begin();
85          iter != registry.end(); ++iter) {
86       solver_types.push_back(iter->first);
87     }
88     return solver_types;
89   }
90
91  private:
92   // Solver registry should never be instantiated - everything is done with its
93   // static variables.
94   SolverRegistry() {}
95
96   static string SolverTypeListString() {
97     vector<string> solver_types = SolverTypeList();
98     string solver_types_str;
99     for (vector<string>::iterator iter = solver_types.begin();
100          iter != solver_types.end(); ++iter) {
101       if (iter != solver_types.begin()) {
102         solver_types_str += ", ";
103       }
104       solver_types_str += *iter;
105     }
106     return solver_types_str;
107   }
108 };
109
110
111 template <typename Dtype>
112 class SolverRegisterer {
113  public:
114   SolverRegisterer(const string& type,
115       Solver<Dtype>* (*creator)(const SolverParameter&)) {
116     // LOG(INFO) << "Registering solver type: " << type;
117     SolverRegistry<Dtype>::AddCreator(type, creator);
118   }
119 };
120
121
// Registers `creator<float>` and `creator<double>` under the solver type name
// given by the stringized `type` argument. Expands to two static
// SolverRegisterer objects, so registration runs during static initialization
// of the translation unit that uses the macro. Names are fully qualified so the
// macro also works outside namespace caffe. The expansion does not end in ';',
// so invoke it like a declaration: REGISTER_SOLVER_CREATOR(Foo, GetFooSolver);
// (No line continuation on the final line: a trailing backslash here would
// silently absorb whatever follows.)
#define REGISTER_SOLVER_CREATOR(type, creator)                                 \
  static ::caffe::SolverRegisterer<float> g_creator_f_##type(                  \
      #type, creator<float>);                                                  \
  static ::caffe::SolverRegisterer<double> g_creator_d_##type(                 \
      #type, creator<double>)

// Defines a creator template Creator_<type>Solver<Dtype> that returns
// `new <type>Solver<Dtype>(param)`, then registers it for float and double via
// REGISTER_SOLVER_CREATOR. The class <type>Solver must be visible where the
// macro is used. Invoke with a trailing semicolon: REGISTER_SOLVER_CLASS(Foo);
#define REGISTER_SOLVER_CLASS(type)                                            \
  template <typename Dtype>                                                    \
  ::caffe::Solver<Dtype>* Creator_##type##Solver(                              \
      const ::caffe::SolverParameter& param)                                   \
  {                                                                            \
    return new type##Solver<Dtype>(param);                                     \
  }                                                                            \
  REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)
134
135 }  // namespace caffe
136
137 #endif  // CAFFE_SOLVER_FACTORY_H_