#include "caffe/parallel.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/solver.hpp"
+#include "caffe/solver_factory.hpp"
#include "caffe/util/benchmark.hpp"
#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp"
: Solver<Dtype>(param) { PreSolve(); }
explicit SGDSolver(const string& param_file)
: Solver<Dtype>(param_file) { PreSolve(); }
+ virtual inline const char* type() const { return "SGD"; }
const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }
: SGDSolver<Dtype>(param) {}
explicit NesterovSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) {}
+ virtual inline const char* type() const { return "Nesterov"; }
protected:
virtual void ComputeUpdateValue(int param_id, Dtype rate);
: SGDSolver<Dtype>(param) { constructor_sanity_check(); }
explicit AdaGradSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
+ virtual inline const char* type() const { return "AdaGrad"; }
protected:
virtual void ComputeUpdateValue(int param_id, Dtype rate);
: SGDSolver<Dtype>(param) { constructor_sanity_check(); }
explicit RMSPropSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
+ virtual inline const char* type() const { return "RMSProp"; }
protected:
virtual void ComputeUpdateValue(int param_id, Dtype rate);
: SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); }
explicit AdaDeltaSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); }
+ virtual inline const char* type() const { return "AdaDelta"; }
protected:
void AdaDeltaPreSolve();
: SGDSolver<Dtype>(param) { AdamPreSolve();}
explicit AdamSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { AdamPreSolve(); }
+ virtual inline const char* type() const { return "Adam"; }
protected:
void AdamPreSolve();
#include <vector>
#include "caffe/net.hpp"
+#include "caffe/solver_factory.hpp"
namespace caffe {
}
void CheckSnapshotWritePermissions();
+ /**
+ * @brief Returns the solver type.
+ */
+ virtual inline const char* type() const { return ""; }
protected:
// Make and apply the update value for the current iteration.
}
};
-// The solver factory function
-template <typename Dtype>
-Solver<Dtype>* GetSolver(const SolverParameter& param);
-
} // namespace caffe
#endif // CAFFE_SOLVER_HPP_
--- /dev/null
+/**
+ * @brief A solver factory that allows one to register solvers, similar to
+ * layer factory. During runtime, registered solvers could be called by passing
+ * a SolverParameter protobuffer to the CreateSolver function:
+ *
+ * SolverRegistry<Dtype>::CreateSolver(param);
+ *
+ * There are two ways to register a solver. Assuming that we have a solver like:
+ *
+ * template <typename Dtype>
+ * class MyAwesomeSolver : public Solver<Dtype> {
+ * // your implementations
+ * };
+ *
+ * and its type is its C++ class name, but without the "Solver" at the end
+ * ("MyAwesomeSolver" -> "MyAwesome").
+ *
+ * If the solver is going to be created simply by its constructor, in your c++
+ * file, add the following line:
+ *
+ * REGISTER_SOLVER_CLASS(MyAwesome);
+ *
+ * Or, if the solver is going to be created by another creator function, in the
+ * format of:
+ *
+ * template <typename Dtype>
+ * Solver<Dtype*> GetMyAwesomeSolver(const SolverParameter& param) {
+ * // your implementation
+ * }
+ *
+ * then you can register the creator function instead, like
+ *
+ * REGISTER_SOLVER_CREATOR(MyAwesome, GetMyAwesomeSolver)
+ *
+ * Note that each solver type should only be registered once.
+ */
+
+#ifndef CAFFE_SOLVER_FACTORY_H_
+#define CAFFE_SOLVER_FACTORY_H_
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "caffe/common.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+namespace caffe {
+
+template <typename Dtype>
+class Solver;
+
+template <typename Dtype>
+class SolverRegistry {
+ public:
+ typedef Solver<Dtype>* (*Creator)(const SolverParameter&);
+ typedef std::map<string, Creator> CreatorRegistry;
+
+ static CreatorRegistry& Registry() {
+ static CreatorRegistry* g_registry_ = new CreatorRegistry();
+ return *g_registry_;
+ }
+
+ // Adds a creator.
+ static void AddCreator(const string& type, Creator creator) {
+ CreatorRegistry& registry = Registry();
+ CHECK_EQ(registry.count(type), 0)
+ << "Solver type " << type << " already registered.";
+ registry[type] = creator;
+ }
+
+ // Get a solver using a SolverParameter.
+ static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
+ const string& type = param.type();
+ CreatorRegistry& registry = Registry();
+ CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
+ << " (known types: " << SolverTypeListString() << ")";
+ return registry[type](param);
+ }
+
+ static vector<string> SolverTypeList() {
+ CreatorRegistry& registry = Registry();
+ vector<string> solver_types;
+ for (typename CreatorRegistry::iterator iter = registry.begin();
+ iter != registry.end(); ++iter) {
+ solver_types.push_back(iter->first);
+ }
+ return solver_types;
+ }
+
+ private:
+ // Solver registry should never be instantiated - everything is done with its
+ // static variables.
+ SolverRegistry() {}
+
+ static string SolverTypeListString() {
+ vector<string> solver_types = SolverTypeList();
+ string solver_types_str;
+ for (vector<string>::iterator iter = solver_types.begin();
+ iter != solver_types.end(); ++iter) {
+ if (iter != solver_types.begin()) {
+ solver_types_str += ", ";
+ }
+ solver_types_str += *iter;
+ }
+ return solver_types_str;
+ }
+};
+
+
+template <typename Dtype>
+class SolverRegisterer {
+ public:
+ SolverRegisterer(const string& type,
+ Solver<Dtype>* (*creator)(const SolverParameter&)) {
+ // LOG(INFO) << "Registering solver type: " << type;
+ SolverRegistry<Dtype>::AddCreator(type, creator);
+ }
+};
+
+
+#define REGISTER_SOLVER_CREATOR(type, creator) \
+ static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>); \
+ static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>) \
+
+#define REGISTER_SOLVER_CLASS(type) \
+ template <typename Dtype> \
+ Solver<Dtype>* Creator_##type##Solver( \
+ const SolverParameter& param) \
+ { \
+ return new type##Solver<Dtype>(param); \
+ } \
+ REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)
+
+} // namespace caffe
+
+#endif // CAFFE_SOLVER_FACTORY_H_
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
-// SolverParameter next available ID: 40 (last added: momentum2)
+// SolverParameter next available ID: 41 (last added: type)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
// (and by default) initialize using a seed derived from the system clock.
optional int64 random_seed = 20 [default = -1];
- // Solver type
- enum SolverType {
- SGD = 0;
- NESTEROV = 1;
- ADAGRAD = 2;
- RMSPROP = 3;
- ADADELTA = 4;
- ADAM = 5;
- }
- optional SolverType solver_type = 30 [default = SGD];
+ // type of the solver
+ optional string type = 40 [default = "SGD"];
+
// numerical stability for RMSProp, AdaGrad and AdaDelta and Adam
optional float delta = 31 [default = 1e-8];
// parameters for the Adam solver
// If false, don't save a snapshot after training finishes.
optional bool snapshot_after_train = 28 [default = true];
+
+ // DEPRECATED: old solver enum types, use string instead
+ enum SolverType {
+ SGD = 0;
+ NESTEROV = 1;
+ ADAGRAD = 2;
+ RMSPROP = 3;
+ ADADELTA = 4;
+ ADAM = 5;
+ }
+ // DEPRECATED: use type instead of solver_type
+ optional SolverType solver_type = 30 [default = SGD];
}
// A message that stores the solver snapshots
+++ /dev/null
-#include "caffe/solver.hpp"
-#include "caffe/sgd_solvers.hpp"
-
-namespace caffe {
-
-template <typename Dtype>
-Solver<Dtype>* GetSolver(const SolverParameter& param) {
- SolverParameter_SolverType type = param.solver_type();
-
- switch (type) {
- case SolverParameter_SolverType_SGD:
- return new SGDSolver<Dtype>(param);
- case SolverParameter_SolverType_NESTEROV:
- return new NesterovSolver<Dtype>(param);
- case SolverParameter_SolverType_ADAGRAD:
- return new AdaGradSolver<Dtype>(param);
- case SolverParameter_SolverType_RMSPROP:
- return new RMSPropSolver<Dtype>(param);
- case SolverParameter_SolverType_ADADELTA:
- return new AdaDeltaSolver<Dtype>(param);
- case SolverParameter_SolverType_ADAM:
- return new AdamSolver<Dtype>(param);
- default:
- LOG(FATAL) << "Unknown SolverType: " << type;
- }
- return (Solver<Dtype>*) NULL;
-}
-
-template Solver<float>* GetSolver(const SolverParameter& param);
-template Solver<double>* GetSolver(const SolverParameter& param);
-
-} // namespace caffe
}
INSTANTIATE_CLASS(AdaDeltaSolver);
+REGISTER_SOLVER_CLASS(AdaDelta);
} // namespace caffe
}
INSTANTIATE_CLASS(AdaGradSolver);
+REGISTER_SOLVER_CLASS(AdaGrad);
} // namespace caffe
}
INSTANTIATE_CLASS(AdamSolver);
+REGISTER_SOLVER_CLASS(Adam);
} // namespace caffe
}
INSTANTIATE_CLASS(NesterovSolver);
+REGISTER_SOLVER_CLASS(Nesterov);
} // namespace caffe
}
INSTANTIATE_CLASS(RMSPropSolver);
+REGISTER_SOLVER_CLASS(RMSProp);
} // namespace caffe
}
INSTANTIATE_CLASS(SGDSolver);
+REGISTER_SOLVER_CLASS(SGD);
} // namespace caffe
// Test data: check out generate_sample_data.py in the same directory.
string* input_file_;
- virtual SolverParameter_SolverType solver_type() = 0;
virtual void InitSolver(const SolverParameter& param) = 0;
virtual void InitSolverFromProtoString(const string& proto) {
((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]);
// Finally, compute update.
const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
- if (solver_type() != SolverParameter_SolverType_ADADELTA
- && solver_type() != SolverParameter_SolverType_ADAM) {
+ if (solver_->type() != string("AdaDelta")
+ && solver_->type() != string("Adam")) {
ASSERT_EQ(2, history.size()); // 1 blob for weights, 1 for bias
} else {
ASSERT_EQ(4, history.size()); // additional blobs for update history
const Dtype history_value = (i == D) ?
history[1]->cpu_data()[0] : history[0]->cpu_data()[i];
const Dtype temp = momentum * history_value;
- switch (solver_type()) {
- case SolverParameter_SolverType_SGD:
+ if (solver_->type() == string("SGD")) {
update_value += temp;
- break;
- case SolverParameter_SolverType_NESTEROV:
+ } else if (solver_->type() == string("Nesterov")) {
update_value += temp;
// step back then over-step
update_value = (1 + momentum) * update_value - temp;
- break;
- case SolverParameter_SolverType_ADAGRAD:
+ } else if (solver_->type() == string("AdaGrad")) {
update_value /= std::sqrt(history_value + grad * grad) + delta_;
- break;
- case SolverParameter_SolverType_RMSPROP: {
+ } else if (solver_->type() == string("RMSProp")) {
const Dtype rms_decay = 0.95;
update_value /= std::sqrt(rms_decay*history_value
+ grad * grad * (1 - rms_decay)) + delta_;
- }
- break;
- case SolverParameter_SolverType_ADADELTA:
- {
+ } else if (solver_->type() == string("AdaDelta")) {
const Dtype update_history_value = (i == D) ?
history[1 + num_param_blobs]->cpu_data()[0] :
history[0 + num_param_blobs]->cpu_data()[i];
// not actually needed, just here for illustrative purposes
// const Dtype weighted_update_average =
// momentum * update_history_value + (1 - momentum) * (update_value);
- break;
- }
- case SolverParameter_SolverType_ADAM: {
+ } else if (solver_->type() == string("Adam")) {
const Dtype momentum2 = 0.999;
const Dtype m = history_value;
const Dtype v = (i == D) ?
std::sqrt(Dtype(1) - pow(momentum2, num_iters)) /
(Dtype(1.) - pow(momentum, num_iters));
update_value = alpha_t * val_m / (std::sqrt(val_v) + delta_);
- break;
- }
- default:
- LOG(FATAL) << "Unknown solver type: " << solver_type();
+ } else {
+ LOG(FATAL) << "Unknown solver type: " << solver_->type();
}
if (i == D) {
updated_bias.mutable_cpu_diff()[0] = update_value;
EXPECT_NEAR(expected_updated_bias, solver_updated_bias, error_margin);
// Check the solver's history -- should contain the previous update value.
- if (solver_type() == SolverParameter_SolverType_SGD) {
+ if (solver_->type() == string("SGD")) {
const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
ASSERT_EQ(2, history.size());
for (int i = 0; i < D; ++i) {
virtual void InitSolver(const SolverParameter& param) {
this->solver_.reset(new SGDSolver<Dtype>(param));
}
-
- virtual SolverParameter_SolverType solver_type() {
- return SolverParameter_SolverType_SGD;
- }
};
TYPED_TEST_CASE(SGDSolverTest, TestDtypesAndDevices);
virtual void InitSolver(const SolverParameter& param) {
this->solver_.reset(new AdaGradSolver<Dtype>(param));
}
- virtual SolverParameter_SolverType solver_type() {
- return SolverParameter_SolverType_ADAGRAD;
- }
};
TYPED_TEST_CASE(AdaGradSolverTest, TestDtypesAndDevices);
virtual void InitSolver(const SolverParameter& param) {
this->solver_.reset(new NesterovSolver<Dtype>(param));
}
- virtual SolverParameter_SolverType solver_type() {
- return SolverParameter_SolverType_NESTEROV;
- }
};
TYPED_TEST_CASE(NesterovSolverTest, TestDtypesAndDevices);
virtual void InitSolver(const SolverParameter& param) {
this->solver_.reset(new AdaDeltaSolver<Dtype>(param));
}
-
- virtual SolverParameter_SolverType solver_type() {
- return SolverParameter_SolverType_ADADELTA;
- }
};
TYPED_TEST_CASE(AdaDeltaSolverTest, TestDtypesAndDevices);
new_param.set_momentum2(momentum2);
this->solver_.reset(new AdamSolver<Dtype>(new_param));
}
- virtual SolverParameter_SolverType solver_type() {
- return SolverParameter_SolverType_ADAM;
- }
};
TYPED_TEST_CASE(AdamSolverTest, TestDtypesAndDevices);
new_param.set_rms_decay(rms_decay);
this->solver_.reset(new RMSPropSolver<Dtype>(new_param));
}
- virtual SolverParameter_SolverType solver_type() {
- return SolverParameter_SolverType_RMSPROP;
- }
};
TYPED_TEST_CASE(RMSPropSolverTest, TestDtypesAndDevices);
--- /dev/null
+#include <map>
+#include <string>
+
+#include "boost/scoped_ptr.hpp"
+#include "google/protobuf/text_format.h"
+#include "gtest/gtest.h"
+
+#include "caffe/common.hpp"
+#include "caffe/solver.hpp"
+#include "caffe/solver_factory.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+
+template <typename TypeParam>
+class SolverFactoryTest : public MultiDeviceTest<TypeParam> {
+ protected:
+ SolverParameter simple_solver_param() {
+ const string solver_proto =
+ "train_net_param { "
+ " layer { "
+ " name: 'data' type: 'DummyData' top: 'data' "
+ " dummy_data_param { shape { dim: 1 } } "
+ " } "
+ "} ";
+ SolverParameter solver_param;
+ CHECK(google::protobuf::TextFormat::ParseFromString(
+ solver_proto, &solver_param));
+ return solver_param;
+ }
+};
+
+TYPED_TEST_CASE(SolverFactoryTest, TestDtypesAndDevices);
+
+TYPED_TEST(SolverFactoryTest, TestCreateSolver) {
+ typedef typename TypeParam::Dtype Dtype;
+ typename SolverRegistry<Dtype>::CreatorRegistry& registry =
+ SolverRegistry<Dtype>::Registry();
+ shared_ptr<Solver<Dtype> > solver;
+ SolverParameter solver_param = this->simple_solver_param();
+ for (typename SolverRegistry<Dtype>::CreatorRegistry::iterator iter =
+ registry.begin(); iter != registry.end(); ++iter) {
+ solver_param.set_type(iter->first);
+ solver.reset(SolverRegistry<Dtype>::CreateSolver(solver_param));
+ EXPECT_EQ(iter->first, solver->type());
+ }
+}
+
+} // namespace caffe
GetRequestedAction(FLAGS_sighup_effect));
shared_ptr<caffe::Solver<float> >
- solver(caffe::GetSolver<float>(solver_param));
+ solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
solver->SetActionFunction(signal_handler.GetActionFunction());