 * @brief A solver factory that allows one to register solvers, similar to
 * the layer factory. At runtime, a registered solver can be created by passing
 * a SolverParameter protobuf message to the CreateSolver function:
6 * SolverRegistry<Dtype>::CreateSolver(param);
8 * There are two ways to register a solver. Assuming that we have a solver like:
10 * template <typename Dtype>
11 * class MyAwesomeSolver : public Solver<Dtype> {
12 * // your implementations
15 * and its type is its C++ class name, but without the "Solver" at the end
16 * ("MyAwesomeSolver" -> "MyAwesome").
18 * If the solver is going to be created simply by its constructor, in your C++
19 * file, add the following line:
21 * REGISTER_SOLVER_CLASS(MyAwesome);
23 * Or, if the solver is going to be created by another creator function, in the
26 * template <typename Dtype>
 *    Solver<Dtype>* GetMyAwesomeSolver(const SolverParameter& param) {
28 * // your implementation
31 * then you can register the creator function instead, like
 *    REGISTER_SOLVER_CREATOR(MyAwesome, GetMyAwesomeSolver);
35 * Note that each solver type should only be registered once.
38 #ifndef CAFFE_SOLVER_FACTORY_H_
39 #define CAFFE_SOLVER_FACTORY_H_
45 #include "caffe/common.hpp"
46 #include "caffe/proto/caffe.pb.h"
50 template <typename Dtype>
53 template <typename Dtype>
54 class SolverRegistry {
56 typedef Solver<Dtype>* (*Creator)(const SolverParameter&);
57 typedef std::map<string, Creator> CreatorRegistry;
59 static CreatorRegistry& Registry() {
60 static CreatorRegistry* g_registry_ = new CreatorRegistry();
65 static void AddCreator(const string& type, Creator creator) {
66 CreatorRegistry& registry = Registry();
67 CHECK_EQ(registry.count(type), 0)
68 << "Solver type " << type << " already registered.";
69 registry[type] = creator;
72 // Get a solver using a SolverParameter.
73 static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
74 const string& type = param.type();
75 CreatorRegistry& registry = Registry();
76 CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
77 << " (known types: " << SolverTypeListString() << ")";
78 return registry[type](param);
81 static vector<string> SolverTypeList() {
82 CreatorRegistry& registry = Registry();
83 vector<string> solver_types;
84 for (typename CreatorRegistry::iterator iter = registry.begin();
85 iter != registry.end(); ++iter) {
86 solver_types.push_back(iter->first);
92 // Solver registry should never be instantiated - everything is done with its
96 static string SolverTypeListString() {
97 vector<string> solver_types = SolverTypeList();
98 string solver_types_str;
99 for (vector<string>::iterator iter = solver_types.begin();
100 iter != solver_types.end(); ++iter) {
101 if (iter != solver_types.begin()) {
102 solver_types_str += ", ";
104 solver_types_str += *iter;
106 return solver_types_str;
111 template <typename Dtype>
112 class SolverRegisterer {
114 SolverRegisterer(const string& type,
115 Solver<Dtype>* (*creator)(const SolverParameter&)) {
116 // LOG(INFO) << "Registering solver type: " << type;
117 SolverRegistry<Dtype>::AddCreator(type, creator);
// Registers the creator function template `creator` for both float and double
// under the solver type name #type (the stringified macro argument).
//
// The expansion defines two file-static SolverRegisterer objects; constructing
// them adds creator<float> / creator<double> to the matching SolverRegistry.
// The last declaration deliberately has no semicolon, so write the invocation
// as a statement:  REGISTER_SOLVER_CREATOR(MyAwesome, GetMyAwesomeSolver);
//
// Note: the final line must NOT end in a line continuation, otherwise whatever
// follows the macro in this header would be spliced into its body.
#define REGISTER_SOLVER_CREATOR(type, creator) \
  static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>); \
  static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>)
// Registers the class template <type>Solver under the solver type name #type.
//
// The expansion first defines a creator function template,
// Creator_<type>Solver, that returns `new <type>Solver<Dtype>(param)`, and
// then hands it to REGISTER_SOLVER_CREATOR. As with that macro, the invocation
// supplies the terminating semicolon:  REGISTER_SOLVER_CLASS(MyAwesome);
#define REGISTER_SOLVER_CLASS(type) \
  template <typename Dtype> \
  Solver<Dtype>* Creator_##type##Solver(const SolverParameter& param) { \
    return new type##Solver<Dtype>(param); \
  } \
  REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)
137 #endif // CAFFE_SOLVER_FACTORY_H_