Change solver type to string and provide solver registry
authorRonghang Hu <huronghang@hotmail.com>
Fri, 25 Sep 2015 02:40:45 +0000 (19:40 -0700)
committerRonghang Hu <huronghang@hotmail.com>
Sat, 17 Oct 2015 05:32:32 +0000 (22:32 -0700)
15 files changed:
include/caffe/caffe.hpp
include/caffe/sgd_solvers.hpp
include/caffe/solver.hpp
include/caffe/solver_factory.hpp [new file with mode: 0644]
src/caffe/proto/caffe.proto
src/caffe/solver_factory.cpp [deleted file]
src/caffe/solvers/adadelta_solver.cpp
src/caffe/solvers/adagrad_solver.cpp
src/caffe/solvers/adam_solver.cpp
src/caffe/solvers/nesterov_solver.cpp
src/caffe/solvers/rmsprop_solver.cpp
src/caffe/solvers/sgd_solver.cpp
src/caffe/test/test_gradient_based_solver.cpp
src/caffe/test/test_solver_factory.cpp [new file with mode: 0644]
tools/caffe.cpp

index 68a5e1d..bd77283 100644 (file)
@@ -13,6 +13,7 @@
 #include "caffe/parallel.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/solver.hpp"
+#include "caffe/solver_factory.hpp"
 #include "caffe/util/benchmark.hpp"
 #include "caffe/util/io.hpp"
 #include "caffe/vision_layers.hpp"
index 6bf1d70..1fc52d8 100644 (file)
@@ -19,6 +19,7 @@ class SGDSolver : public Solver<Dtype> {
       : Solver<Dtype>(param) { PreSolve(); }
   explicit SGDSolver(const string& param_file)
       : Solver<Dtype>(param_file) { PreSolve(); }
+  virtual inline const char* type() const { return "SGD"; }
 
   const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }
 
@@ -51,6 +52,7 @@ class NesterovSolver : public SGDSolver<Dtype> {
       : SGDSolver<Dtype>(param) {}
   explicit NesterovSolver(const string& param_file)
       : SGDSolver<Dtype>(param_file) {}
+  virtual inline const char* type() const { return "Nesterov"; }
 
  protected:
   virtual void ComputeUpdateValue(int param_id, Dtype rate);
@@ -65,6 +67,7 @@ class AdaGradSolver : public SGDSolver<Dtype> {
       : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
   explicit AdaGradSolver(const string& param_file)
       : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
+  virtual inline const char* type() const { return "AdaGrad"; }
 
  protected:
   virtual void ComputeUpdateValue(int param_id, Dtype rate);
@@ -84,6 +87,7 @@ class RMSPropSolver : public SGDSolver<Dtype> {
       : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
   explicit RMSPropSolver(const string& param_file)
       : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
+  virtual inline const char* type() const { return "RMSProp"; }
 
  protected:
   virtual void ComputeUpdateValue(int param_id, Dtype rate);
@@ -106,6 +110,7 @@ class AdaDeltaSolver : public SGDSolver<Dtype> {
       : SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); }
   explicit AdaDeltaSolver(const string& param_file)
       : SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); }
+  virtual inline const char* type() const { return "AdaDelta"; }
 
  protected:
   void AdaDeltaPreSolve();
@@ -129,6 +134,7 @@ class AdamSolver : public SGDSolver<Dtype> {
       : SGDSolver<Dtype>(param) { AdamPreSolve();}
   explicit AdamSolver(const string& param_file)
       : SGDSolver<Dtype>(param_file) { AdamPreSolve(); }
+  virtual inline const char* type() const { return "Adam"; }
 
  protected:
   void AdamPreSolve();
index a045ccf..298a68f 100644 (file)
@@ -5,6 +5,7 @@
 #include <vector>
 
 #include "caffe/net.hpp"
+#include "caffe/solver_factory.hpp"
 
 namespace caffe {
 
@@ -83,6 +84,10 @@ class Solver {
   }
 
   void CheckSnapshotWritePermissions();
+  /**
+   * @brief Returns the solver type.
+   */
+  virtual inline const char* type() const { return ""; }
 
  protected:
   // Make and apply the update value for the current iteration.
@@ -148,10 +153,6 @@ class WorkerSolver : public Solver<Dtype> {
   }
 };
 
-// The solver factory function
-template <typename Dtype>
-Solver<Dtype>* GetSolver(const SolverParameter& param);
-
 }  // namespace caffe
 
 #endif  // CAFFE_SOLVER_HPP_
diff --git a/include/caffe/solver_factory.hpp b/include/caffe/solver_factory.hpp
new file mode 100644 (file)
index 0000000..cfff721
--- /dev/null
@@ -0,0 +1,137 @@
+/**
+ * @brief A solver factory that allows one to register solvers, similar to the
+ * layer factory. At runtime, a registered solver can be created by passing a
+ * SolverParameter protocol buffer to the CreateSolver function:
+ *
+ *     SolverRegistry<Dtype>::CreateSolver(param);
+ *
+ * There are two ways to register a solver. Assuming that we have a solver like:
+ *
+ *   template <typename Dtype>
+ *   class MyAwesomeSolver : public Solver<Dtype> {
+ *     // your implementations
+ *   };
+ *
+ * and its type is its C++ class name, but without the "Solver" at the end
+ * ("MyAwesomeSolver" -> "MyAwesome").
+ *
+ * If the solver is going to be created simply by its constructor, in your c++
+ * file, add the following line:
+ *
+ *    REGISTER_SOLVER_CLASS(MyAwesome);
+ *
+ * Or, if the solver is going to be created by another creator function, in the
+ * format of:
+ *
+ *    template <typename Dtype>
+ *    Solver<Dtype>* GetMyAwesomeSolver(const SolverParameter& param) {
+ *      // your implementation
+ *    }
+ *
+ * then you can register the creator function instead, like
+ *
+ *    REGISTER_SOLVER_CREATOR(MyAwesome, GetMyAwesomeSolver);
+ *
+ * Note that each solver type should only be registered once.
+ */
+
+#ifndef CAFFE_SOLVER_FACTORY_H_
+#define CAFFE_SOLVER_FACTORY_H_
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "caffe/common.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+namespace caffe {
+
+template <typename Dtype>
+class Solver;
+
+template <typename Dtype>
+class SolverRegistry {
+ public:
+  typedef Solver<Dtype>* (*Creator)(const SolverParameter&);
+  typedef std::map<string, Creator> CreatorRegistry;
+
+  static CreatorRegistry& Registry() {
+    static CreatorRegistry* g_registry_ = new CreatorRegistry();
+    return *g_registry_;
+  }
+
+  // Registers creator under type; CHECK-fails if type is already registered.
+  static void AddCreator(const string& type, Creator creator) {
+    CreatorRegistry& registry = Registry();
+    CHECK_EQ(registry.count(type), 0)
+        << "Solver type " << type << " already registered.";
+    registry[type] = creator;
+  }
+
+  // Creates the solver registered for param.type(); CHECK-fails if unknown.
+  static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
+    const string& type = param.type();
+    CreatorRegistry& registry = Registry();
+    CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
+        << " (known types: " << SolverTypeListString() << ")";
+    return registry[type](param);
+  }
+
+  static vector<string> SolverTypeList() {
+    CreatorRegistry& registry = Registry();
+    vector<string> solver_types;
+    for (typename CreatorRegistry::iterator iter = registry.begin();
+         iter != registry.end(); ++iter) {
+      solver_types.push_back(iter->first);
+    }
+    return solver_types;
+  }
+
+ private:
+  // Private constructor: the registry is never instantiated; all access goes
+  // through the static member functions above.
+  SolverRegistry() {}
+
+  static string SolverTypeListString() {
+    vector<string> solver_types = SolverTypeList();
+    string solver_types_str;
+    for (vector<string>::iterator iter = solver_types.begin();
+         iter != solver_types.end(); ++iter) {
+      if (iter != solver_types.begin()) {
+        solver_types_str += ", ";
+      }
+      solver_types_str += *iter;
+    }
+    return solver_types_str;
+  }
+};
+
+
+template <typename Dtype>
+class SolverRegisterer {
+ public:
+  SolverRegisterer(const string& type,
+      Solver<Dtype>* (*creator)(const SolverParameter&)) {
+    // LOG(INFO) << "Registering solver type: " << type;
+    SolverRegistry<Dtype>::AddCreator(type, creator);
+  }
+};
+
+// Registers creator<float> and creator<double> under the string name #type.
+#define REGISTER_SOLVER_CREATOR(type, creator)                                 \
+  static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>);    \
+  static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>)
+
+#define REGISTER_SOLVER_CLASS(type)                                            \
+  template <typename Dtype>                                                    \
+  Solver<Dtype>* Creator_##type##Solver(                                       \
+      const SolverParameter& param)                                            \
+  {                                                                            \
+    return new type##Solver<Dtype>(param);                                     \
+  }                                                                            \
+  REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)
+
+}  // namespace caffe
+
+#endif  // CAFFE_SOLVER_FACTORY_H_
index 4794991..76c869c 100644 (file)
@@ -98,7 +98,7 @@ message NetParameter {
 // NOTE
 // Update the next available ID when you add a new SolverParameter field.
 //
-// SolverParameter next available ID: 40 (last added: momentum2)
+// SolverParameter next available ID: 41 (last added: type)
 message SolverParameter {
   //////////////////////////////////////////////////////////////////////////////
   // Specifying the train and test networks
@@ -209,16 +209,9 @@ message SolverParameter {
   // (and by default) initialize using a seed derived from the system clock.
   optional int64 random_seed = 20 [default = -1];
 
-  // Solver type
-  enum SolverType {
-    SGD = 0;
-    NESTEROV = 1;
-    ADAGRAD = 2;
-    RMSPROP = 3;
-    ADADELTA = 4;
-    ADAM = 5;
-  }
-  optional SolverType solver_type = 30 [default = SGD];
+  // Solver type: the name it is registered under, e.g. "SGD" or "Adam".
+  optional string type = 40 [default = "SGD"];
+
   // numerical stability for RMSProp, AdaGrad and AdaDelta and Adam
   optional float delta = 31 [default = 1e-8];
   // parameters for the Adam solver
@@ -234,6 +227,18 @@ message SolverParameter {
 
   // If false, don't save a snapshot after training finishes.
   optional bool snapshot_after_train = 28 [default = true];
+
+  // DEPRECATED: old solver enum types, use string instead
+  enum SolverType {
+    SGD = 0;
+    NESTEROV = 1;
+    ADAGRAD = 2;
+    RMSPROP = 3;
+    ADADELTA = 4;
+    ADAM = 5;
+  }
+  // DEPRECATED: use type instead of solver_type
+  optional SolverType solver_type = 30 [default = SGD];
 }
 
 // A message that stores the solver snapshots
diff --git a/src/caffe/solver_factory.cpp b/src/caffe/solver_factory.cpp
deleted file mode 100644 (file)
index f78fab2..0000000
+++ /dev/null
@@ -1,32 +0,0 @@
-#include "caffe/solver.hpp"
-#include "caffe/sgd_solvers.hpp"
-
-namespace caffe {
-
-template <typename Dtype>
-Solver<Dtype>* GetSolver(const SolverParameter& param) {
-  SolverParameter_SolverType type = param.solver_type();
-
-  switch (type) {
-  case SolverParameter_SolverType_SGD:
-    return new SGDSolver<Dtype>(param);
-  case SolverParameter_SolverType_NESTEROV:
-    return new NesterovSolver<Dtype>(param);
-  case SolverParameter_SolverType_ADAGRAD:
-    return new AdaGradSolver<Dtype>(param);
-  case SolverParameter_SolverType_RMSPROP:
-    return new RMSPropSolver<Dtype>(param);
-  case SolverParameter_SolverType_ADADELTA:
-    return new AdaDeltaSolver<Dtype>(param);
-  case SolverParameter_SolverType_ADAM:
-    return new AdamSolver<Dtype>(param);
-  default:
-    LOG(FATAL) << "Unknown SolverType: " << type;
-  }
-  return (Solver<Dtype>*) NULL;
-}
-
-template Solver<float>* GetSolver(const SolverParameter& param);
-template Solver<double>* GetSolver(const SolverParameter& param);
-
-}  // namespace caffe
index 45cd4eb..a37899e 100644 (file)
@@ -151,5 +151,6 @@ void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
 }
 
 INSTANTIATE_CLASS(AdaDeltaSolver);
+REGISTER_SOLVER_CLASS(AdaDelta);
 
 }  // namespace caffe
index 627d816..5e40632 100644 (file)
@@ -84,5 +84,6 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
 }
 
 INSTANTIATE_CLASS(AdaGradSolver);
+REGISTER_SOLVER_CLASS(AdaGrad);
 
 }  // namespace caffe
index 8c334f6..cb0fbfe 100644 (file)
@@ -108,5 +108,6 @@ void AdamSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
 }
 
 INSTANTIATE_CLASS(AdamSolver);
+REGISTER_SOLVER_CLASS(Adam);
 
 }  // namespace caffe
index 8135ee2..34bf01e 100644 (file)
@@ -66,5 +66,6 @@ void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
 }
 
 INSTANTIATE_CLASS(NesterovSolver);
+REGISTER_SOLVER_CLASS(Nesterov);
 
 }  // namespace caffe
index 96d1b3d..c624767 100644 (file)
@@ -80,5 +80,6 @@ void RMSPropSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
 }
 
 INSTANTIATE_CLASS(RMSPropSolver);
+REGISTER_SOLVER_CLASS(RMSProp);
 
 }  // namespace caffe
index 89ef5ec..32bf19b 100644 (file)
@@ -343,5 +343,6 @@ void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
 }
 
 INSTANTIATE_CLASS(SGDSolver);
+REGISTER_SOLVER_CLASS(SGD);
 
 }  // namespace caffe
index 1767ad3..84c6747 100644 (file)
@@ -47,7 +47,6 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
   // Test data: check out generate_sample_data.py in the same directory.
   string* input_file_;
 
-  virtual SolverParameter_SolverType solver_type() = 0;
   virtual void InitSolver(const SolverParameter& param) = 0;
 
   virtual void InitSolverFromProtoString(const string& proto) {
@@ -290,8 +289,8 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
           ((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]);
       // Finally, compute update.
       const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
-      if (solver_type() != SolverParameter_SolverType_ADADELTA
-          && solver_type() != SolverParameter_SolverType_ADAM) {
+      if (solver_->type() != string("AdaDelta")
+          && solver_->type() != string("Adam")) {
         ASSERT_EQ(2, history.size());  // 1 blob for weights, 1 for bias
       } else {
         ASSERT_EQ(4, history.size());  // additional blobs for update history
@@ -300,26 +299,19 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
       const Dtype history_value = (i == D) ?
             history[1]->cpu_data()[0] : history[0]->cpu_data()[i];
       const Dtype temp = momentum * history_value;
-      switch (solver_type()) {
-      case SolverParameter_SolverType_SGD:
+      if (solver_->type() == string("SGD")) {
         update_value += temp;
-        break;
-      case SolverParameter_SolverType_NESTEROV:
+      } else if (solver_->type() == string("Nesterov")) {
         update_value += temp;
         // step back then over-step
         update_value = (1 + momentum) * update_value - temp;
-        break;
-      case SolverParameter_SolverType_ADAGRAD:
+      } else if (solver_->type() == string("AdaGrad")) {
         update_value /= std::sqrt(history_value + grad * grad) + delta_;
-        break;
-      case SolverParameter_SolverType_RMSPROP: {
+      } else if (solver_->type() == string("RMSProp")) {
         const Dtype rms_decay = 0.95;
         update_value /= std::sqrt(rms_decay*history_value
             + grad * grad * (1 - rms_decay)) + delta_;
-        }
-        break;
-      case SolverParameter_SolverType_ADADELTA:
-      {
+      } else if (solver_->type() == string("AdaDelta")) {
         const Dtype update_history_value = (i == D) ?
             history[1 + num_param_blobs]->cpu_data()[0] :
             history[0 + num_param_blobs]->cpu_data()[i];
@@ -330,9 +322,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
         // not actually needed, just here for illustrative purposes
         // const Dtype weighted_update_average =
         //   momentum * update_history_value + (1 - momentum) * (update_value);
-        break;
-      }
-      case SolverParameter_SolverType_ADAM: {
+      } else if (solver_->type() == string("Adam")) {
         const Dtype momentum2 = 0.999;
         const Dtype m = history_value;
         const Dtype v = (i == D) ?
@@ -344,10 +334,8 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
             std::sqrt(Dtype(1) - pow(momentum2, num_iters)) /
             (Dtype(1.) - pow(momentum, num_iters));
         update_value = alpha_t * val_m / (std::sqrt(val_v) + delta_);
-        break;
-      }
-      default:
-        LOG(FATAL) << "Unknown solver type: " << solver_type();
+      } else {
+        LOG(FATAL) << "Unknown solver type: " << solver_->type();
       }
       if (i == D) {
         updated_bias.mutable_cpu_diff()[0] = update_value;
@@ -392,7 +380,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
     EXPECT_NEAR(expected_updated_bias, solver_updated_bias, error_margin);
 
     // Check the solver's history -- should contain the previous update value.
-    if (solver_type() == SolverParameter_SolverType_SGD) {
+    if (solver_->type() == string("SGD")) {
       const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
       ASSERT_EQ(2, history.size());
       for (int i = 0; i < D; ++i) {
@@ -581,10 +569,6 @@ class SGDSolverTest : public GradientBasedSolverTest<TypeParam> {
   virtual void InitSolver(const SolverParameter& param) {
     this->solver_.reset(new SGDSolver<Dtype>(param));
   }
-
-  virtual SolverParameter_SolverType solver_type() {
-    return SolverParameter_SolverType_SGD;
-  }
 };
 
 TYPED_TEST_CASE(SGDSolverTest, TestDtypesAndDevices);
@@ -721,9 +705,6 @@ class AdaGradSolverTest : public GradientBasedSolverTest<TypeParam> {
   virtual void InitSolver(const SolverParameter& param) {
     this->solver_.reset(new AdaGradSolver<Dtype>(param));
   }
-  virtual SolverParameter_SolverType solver_type() {
-    return SolverParameter_SolverType_ADAGRAD;
-  }
 };
 
 TYPED_TEST_CASE(AdaGradSolverTest, TestDtypesAndDevices);
@@ -824,9 +805,6 @@ class NesterovSolverTest : public GradientBasedSolverTest<TypeParam> {
   virtual void InitSolver(const SolverParameter& param) {
     this->solver_.reset(new NesterovSolver<Dtype>(param));
   }
-  virtual SolverParameter_SolverType solver_type() {
-    return SolverParameter_SolverType_NESTEROV;
-  }
 };
 
 TYPED_TEST_CASE(NesterovSolverTest, TestDtypesAndDevices);
@@ -960,10 +938,6 @@ class AdaDeltaSolverTest : public GradientBasedSolverTest<TypeParam> {
   virtual void InitSolver(const SolverParameter& param) {
     this->solver_.reset(new AdaDeltaSolver<Dtype>(param));
   }
-
-  virtual SolverParameter_SolverType solver_type() {
-    return SolverParameter_SolverType_ADADELTA;
-  }
 };
 
 TYPED_TEST_CASE(AdaDeltaSolverTest, TestDtypesAndDevices);
@@ -1098,9 +1072,6 @@ class AdamSolverTest : public GradientBasedSolverTest<TypeParam> {
     new_param.set_momentum2(momentum2);
     this->solver_.reset(new AdamSolver<Dtype>(new_param));
   }
-  virtual SolverParameter_SolverType solver_type() {
-    return SolverParameter_SolverType_ADAM;
-  }
 };
 
 TYPED_TEST_CASE(AdamSolverTest, TestDtypesAndDevices);
@@ -1201,9 +1172,6 @@ class RMSPropSolverTest : public GradientBasedSolverTest<TypeParam> {
     new_param.set_rms_decay(rms_decay);
     this->solver_.reset(new RMSPropSolver<Dtype>(new_param));
   }
-  virtual SolverParameter_SolverType solver_type() {
-    return SolverParameter_SolverType_RMSPROP;
-  }
 };
 
 TYPED_TEST_CASE(RMSPropSolverTest, TestDtypesAndDevices);
diff --git a/src/caffe/test/test_solver_factory.cpp b/src/caffe/test/test_solver_factory.cpp
new file mode 100644 (file)
index 0000000..eef5290
--- /dev/null
@@ -0,0 +1,50 @@
+#include <map>
+#include <string>
+
+#include "boost/scoped_ptr.hpp"
+#include "google/protobuf/text_format.h"
+#include "gtest/gtest.h"
+
+#include "caffe/common.hpp"
+#include "caffe/solver.hpp"
+#include "caffe/solver_factory.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+
+template <typename TypeParam>
+class SolverFactoryTest : public MultiDeviceTest<TypeParam> {
+ protected:
+  SolverParameter simple_solver_param() {
+    const string solver_proto =
+        "train_net_param { "
+        "  layer { "
+        "    name: 'data' type: 'DummyData' top: 'data' "
+        "    dummy_data_param { shape { dim: 1 } } "
+        "  } "
+        "} ";
+    SolverParameter solver_param;
+    CHECK(google::protobuf::TextFormat::ParseFromString(
+        solver_proto, &solver_param));
+    return solver_param;
+  }
+};
+
+TYPED_TEST_CASE(SolverFactoryTest, TestDtypesAndDevices);
+
+TYPED_TEST(SolverFactoryTest, TestCreateSolver) {
+  typedef typename TypeParam::Dtype Dtype;
+  typename SolverRegistry<Dtype>::CreatorRegistry& registry =
+      SolverRegistry<Dtype>::Registry();
+  shared_ptr<Solver<Dtype> > solver;
+  SolverParameter solver_param = this->simple_solver_param();
+  for (typename SolverRegistry<Dtype>::CreatorRegistry::iterator iter =
+       registry.begin(); iter != registry.end(); ++iter) {
+    solver_param.set_type(iter->first);
+    solver.reset(SolverRegistry<Dtype>::CreateSolver(solver_param));
+    EXPECT_EQ(iter->first, solver->type());
+  }
+}
+
+}  // namespace caffe
index e3f684b..1cb6ad8 100644 (file)
@@ -194,7 +194,7 @@ int train() {
         GetRequestedAction(FLAGS_sighup_effect));
 
   shared_ptr<caffe::Solver<float> >
-    solver(caffe::GetSolver<float>(solver_param));
+      solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
 
   solver->SetActionFunction(signal_handler.GetActionFunction());