Split solver code into one file per solver class
author    Ronghang Hu <huronghang@hotmail.com>
          Fri, 25 Sep 2015 00:11:07 +0000 (17:11 -0700)
committer Ronghang Hu <huronghang@hotmail.com>
          Sat, 17 Oct 2015 05:32:12 +0000 (22:32 -0700)
13 files changed:
include/caffe/sgd_solvers.hpp [new file with mode: 0644]
include/caffe/solver.hpp
python/caffe/_caffe.cpp
src/caffe/solver.cpp
src/caffe/solver_factory.cpp [new file with mode: 0644]
src/caffe/solvers/adadelta_solver.cpp [new file with mode: 0644]
src/caffe/solvers/adagrad_solver.cpp [new file with mode: 0644]
src/caffe/solvers/adam_solver.cpp [new file with mode: 0644]
src/caffe/solvers/nesterov_solver.cpp [new file with mode: 0644]
src/caffe/solvers/rmsprop_solver.cpp [new file with mode: 0644]
src/caffe/solvers/sgd_solver.cpp [new file with mode: 0644]
src/caffe/test/test_gradient_based_solver.cpp
src/caffe/test/test_solver.cpp

diff --git a/include/caffe/sgd_solvers.hpp b/include/caffe/sgd_solvers.hpp
new file mode 100644 (file)
index 0000000..6bf1d70
--- /dev/null
@@ -0,0 +1,142 @@
+#ifndef CAFFE_SGD_SOLVERS_HPP_
+#define CAFFE_SGD_SOLVERS_HPP_
+
+#include <string>
+#include <vector>
+
+#include "caffe/solver.hpp"
+
+namespace caffe {
+
+/**
+ * @brief Optimizes the parameters of a Net using
+ *        stochastic gradient descent (SGD) with momentum.
+ */
+template <typename Dtype>
+class SGDSolver : public Solver<Dtype> {
+ public:
+  explicit SGDSolver(const SolverParameter& param)
+      : Solver<Dtype>(param) { PreSolve(); }
+  explicit SGDSolver(const string& param_file)
+      : Solver<Dtype>(param_file) { PreSolve(); }
+
+  const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }
+
+ protected:
+  void PreSolve();
+  Dtype GetLearningRate();
+  virtual void ApplyUpdate();
+  virtual void Normalize(int param_id);
+  virtual void Regularize(int param_id);
+  virtual void ComputeUpdateValue(int param_id, Dtype rate);
+  virtual void ClipGradients();
+  virtual void SnapshotSolverState(const string& model_filename);
+  virtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
+  virtual void SnapshotSolverStateToHDF5(const string& model_filename);
+  virtual void RestoreSolverStateFromHDF5(const string& state_file);
+  virtual void RestoreSolverStateFromBinaryProto(const string& state_file);
+  // history maintains the historical momentum data.
+  // update maintains update related data and is not needed in snapshots.
+  // temp maintains other information that might be needed in computation
+  //   of gradients/updates and is not needed in snapshots
+  vector<shared_ptr<Blob<Dtype> > > history_, update_, temp_;
+
+  DISABLE_COPY_AND_ASSIGN(SGDSolver);
+};
+
+template <typename Dtype>
+class NesterovSolver : public SGDSolver<Dtype> {
+ public:
+  explicit NesterovSolver(const SolverParameter& param)
+      : SGDSolver<Dtype>(param) {}
+  explicit NesterovSolver(const string& param_file)
+      : SGDSolver<Dtype>(param_file) {}
+
+ protected:
+  virtual void ComputeUpdateValue(int param_id, Dtype rate);
+
+  DISABLE_COPY_AND_ASSIGN(NesterovSolver);
+};
+
+template <typename Dtype>
+class AdaGradSolver : public SGDSolver<Dtype> {
+ public:
+  explicit AdaGradSolver(const SolverParameter& param)
+      : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
+  explicit AdaGradSolver(const string& param_file)
+      : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
+
+ protected:
+  virtual void ComputeUpdateValue(int param_id, Dtype rate);
+  void constructor_sanity_check() {
+    CHECK_EQ(0, this->param_.momentum())
+        << "Momentum cannot be used with AdaGrad.";
+  }
+
+  DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
+};
+
+
+template <typename Dtype>
+class RMSPropSolver : public SGDSolver<Dtype> {
+ public:
+  explicit RMSPropSolver(const SolverParameter& param)
+      : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
+  explicit RMSPropSolver(const string& param_file)
+      : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
+
+ protected:
+  virtual void ComputeUpdateValue(int param_id, Dtype rate);
+  void constructor_sanity_check() {
+    CHECK_EQ(0, this->param_.momentum())
+        << "Momentum cannot be used with RMSProp.";
+    CHECK_GE(this->param_.rms_decay(), 0)
+        << "rms_decay should lie between 0 and 1.";
+    CHECK_LT(this->param_.rms_decay(), 1)
+        << "rms_decay should lie between 0 and 1.";
+  }
+
+  DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
+};
+
+template <typename Dtype>
+class AdaDeltaSolver : public SGDSolver<Dtype> {
+ public:
+  explicit AdaDeltaSolver(const SolverParameter& param)
+      : SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); }
+  explicit AdaDeltaSolver(const string& param_file)
+      : SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); }
+
+ protected:
+  void AdaDeltaPreSolve();
+  virtual void ComputeUpdateValue(int param_id, Dtype rate);
+
+  DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
+};
+
+/**
+ * @brief AdamSolver, an algorithm for first-order gradient-based optimization
+ *        of stochastic objective functions, based on adaptive estimates of
+ *        lower-order moments. Described in [1].
+ *
+ * [1] D. P. Kingma and J. L. Ba, "ADAM: A Method for Stochastic Optimization."
+ *     arXiv preprint arXiv:1412.6980v8 (2014).
+ */
+template <typename Dtype>
+class AdamSolver : public SGDSolver<Dtype> {
+ public:
+  explicit AdamSolver(const SolverParameter& param)
+      : SGDSolver<Dtype>(param) { AdamPreSolve();}
+  explicit AdamSolver(const string& param_file)
+      : SGDSolver<Dtype>(param_file) { AdamPreSolve(); }
+
+ protected:
+  void AdamPreSolve();
+  virtual void ComputeUpdateValue(int param_id, Dtype rate);
+
+  DISABLE_COPY_AND_ASSIGN(AdamSolver);
+};
+
+}  // namespace caffe
+
+#endif  // CAFFE_SGD_SOLVERS_HPP_
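
For reference, the momentum update that SGDSolver::ComputeUpdateValue implements (its definition, removed from solver.cpp below, moves into the new src/caffe/solvers/sgd_solver.cpp) can be written with local learning rate \alpha, momentum \mu, gradient g_t, and history v as:

    v_t = \mu v_{t-1} + \alpha g_t
    \theta_{t+1} = \theta_t - v_t

The history_ blobs store v; the final subtraction from the weights is done by Net::Update() after the diff is set.
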
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index 2ecf539..a045ccf 100644 (file)
@@ -1,5 +1,5 @@
-#ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
-#define CAFFE_OPTIMIZATION_SOLVER_HPP_
+#ifndef CAFFE_SOLVER_HPP_
+#define CAFFE_SOLVER_HPP_
 #include <boost/function.hpp>
 #include <string>
 #include <vector>
@@ -148,158 +148,10 @@ class WorkerSolver : public Solver<Dtype> {
   }
 };
 
-/**
- * @brief Optimizes the parameters of a Net using
- *        stochastic gradient descent (SGD) with momentum.
- */
-template <typename Dtype>
-class SGDSolver : public Solver<Dtype> {
- public:
-  explicit SGDSolver(const SolverParameter& param)
-      : Solver<Dtype>(param) { PreSolve(); }
-  explicit SGDSolver(const string& param_file)
-      : Solver<Dtype>(param_file) { PreSolve(); }
-
-  const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }
-
- protected:
-  void PreSolve();
-  Dtype GetLearningRate();
-  virtual void ApplyUpdate();
-  virtual void Normalize(int param_id);
-  virtual void Regularize(int param_id);
-  virtual void ComputeUpdateValue(int param_id, Dtype rate);
-  virtual void ClipGradients();
-  virtual void SnapshotSolverState(const string& model_filename);
-  virtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
-  virtual void SnapshotSolverStateToHDF5(const string& model_filename);
-  virtual void RestoreSolverStateFromHDF5(const string& state_file);
-  virtual void RestoreSolverStateFromBinaryProto(const string& state_file);
-  // history maintains the historical momentum data.
-  // update maintains update related data and is not needed in snapshots.
-  // temp maintains other information that might be needed in computation
-  //   of gradients/updates and is not needed in snapshots
-  vector<shared_ptr<Blob<Dtype> > > history_, update_, temp_;
-
-  DISABLE_COPY_AND_ASSIGN(SGDSolver);
-};
-
+// The solver factory function
 template <typename Dtype>
-class NesterovSolver : public SGDSolver<Dtype> {
- public:
-  explicit NesterovSolver(const SolverParameter& param)
-      : SGDSolver<Dtype>(param) {}
-  explicit NesterovSolver(const string& param_file)
-      : SGDSolver<Dtype>(param_file) {}
-
- protected:
-  virtual void ComputeUpdateValue(int param_id, Dtype rate);
-
-  DISABLE_COPY_AND_ASSIGN(NesterovSolver);
-};
-
-template <typename Dtype>
-class AdaGradSolver : public SGDSolver<Dtype> {
- public:
-  explicit AdaGradSolver(const SolverParameter& param)
-      : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
-  explicit AdaGradSolver(const string& param_file)
-      : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
-
- protected:
-  virtual void ComputeUpdateValue(int param_id, Dtype rate);
-  void constructor_sanity_check() {
-    CHECK_EQ(0, this->param_.momentum())
-        << "Momentum cannot be used with AdaGrad.";
-  }
-
-  DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
-};
-
-
-template <typename Dtype>
-class RMSPropSolver : public SGDSolver<Dtype> {
- public:
-  explicit RMSPropSolver(const SolverParameter& param)
-      : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
-  explicit RMSPropSolver(const string& param_file)
-      : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
-
- protected:
-  virtual void ComputeUpdateValue(int param_id, Dtype rate);
-  void constructor_sanity_check() {
-    CHECK_EQ(0, this->param_.momentum())
-        << "Momentum cannot be used with RMSProp.";
-    CHECK_GE(this->param_.rms_decay(), 0)
-        << "rms_decay should lie between 0 and 1.";
-    CHECK_LT(this->param_.rms_decay(), 1)
-        << "rms_decay should lie between 0 and 1.";
-  }
-
-  DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
-};
-
-template <typename Dtype>
-class AdaDeltaSolver : public SGDSolver<Dtype> {
- public:
-  explicit AdaDeltaSolver(const SolverParameter& param)
-      : SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); }
-  explicit AdaDeltaSolver(const string& param_file)
-      : SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); }
-
- protected:
-  void AdaDeltaPreSolve();
-  virtual void ComputeUpdateValue(int param_id, Dtype rate);
-
-  DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
-};
-
-/**
- * @brief AdamSolver, an algorithm for first-order gradient-based optimization
- *        of stochastic objective functions, based on adaptive estimates of
- *        lower-order moments. Described in [1].
- *
- * [1] D. P. Kingma and J. L. Ba, "ADAM: A Method for Stochastic Optimization."
- *     arXiv preprint arXiv:1412.6980v8 (2014).
- */
-template <typename Dtype>
-class AdamSolver : public SGDSolver<Dtype> {
- public:
-  explicit AdamSolver(const SolverParameter& param)
-      : SGDSolver<Dtype>(param) { AdamPreSolve();}
-  explicit AdamSolver(const string& param_file)
-      : SGDSolver<Dtype>(param_file) { AdamPreSolve(); }
-
- protected:
-  void AdamPreSolve();
-  virtual void ComputeUpdateValue(int param_id, Dtype rate);
-
-  DISABLE_COPY_AND_ASSIGN(AdamSolver);
-};
-
-template <typename Dtype>
-Solver<Dtype>* GetSolver(const SolverParameter& param) {
-  SolverParameter_SolverType type = param.solver_type();
-
-  switch (type) {
-  case SolverParameter_SolverType_SGD:
-    return new SGDSolver<Dtype>(param);
-  case SolverParameter_SolverType_NESTEROV:
-    return new NesterovSolver<Dtype>(param);
-  case SolverParameter_SolverType_ADAGRAD:
-    return new AdaGradSolver<Dtype>(param);
-  case SolverParameter_SolverType_RMSPROP:
-    return new RMSPropSolver<Dtype>(param);
-  case SolverParameter_SolverType_ADADELTA:
-    return new AdaDeltaSolver<Dtype>(param);
-  case SolverParameter_SolverType_ADAM:
-    return new AdamSolver<Dtype>(param);
-  default:
-    LOG(FATAL) << "Unknown SolverType: " << type;
-  }
-  return (Solver<Dtype>*) NULL;
-}
+Solver<Dtype>* GetSolver(const SolverParameter& param);
 
 }  // namespace caffe
 
-#endif  // CAFFE_OPTIMIZATION_SOLVER_HPP_
+#endif  // CAFFE_SOLVER_HPP_
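
Note that GetSolver<Dtype> is now only declared in this header; its definition lives in the new src/caffe/solver_factory.cpp, which explicitly instantiates the template for float and double (see the bottom of that file below) so that callers in other translation units still link.
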
diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp
index ccd5776..0e38dee 100644 (file)
@@ -16,6 +16,7 @@
 
 #include "caffe/caffe.hpp"
 #include "caffe/python_layer.hpp"
+#include "caffe/sgd_solvers.hpp"
 
 // Temporary solution for numpy < 1.7 versions: old macro, no promises.
 // You're strongly advised to upgrade to >= 1.7.
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 12c13dd..016a028 100644 (file)
@@ -1,18 +1,11 @@
 #include <cstdio>
 
-#include <algorithm>
 #include <string>
 #include <vector>
 
-#include "hdf5.h"
-#include "hdf5_hl.h"
-
-#include "caffe/net.hpp"
-#include "caffe/proto/caffe.pb.h"
 #include "caffe/solver.hpp"
 #include "caffe/util/hdf5.hpp"
 #include "caffe/util/io.hpp"
-#include "caffe/util/math_functions.hpp"
 #include "caffe/util/upgrade_proto.hpp"
 
 namespace caffe {
@@ -492,810 +485,6 @@ void Solver<Dtype>::Restore(const char* state_file) {
   }
 }
 
-// Return the current learning rate. The currently implemented learning rate
-// policies are as follows:
-//    - fixed: always return base_lr.
-//    - step: return base_lr * gamma ^ (floor(iter / step))
-//    - exp: return base_lr * gamma ^ iter
-//    - inv: return base_lr * (1 + gamma * iter) ^ (- power)
-//    - multistep: similar to step but it allows non-uniform steps defined by
-//      stepvalue
-//    - poly: the effective learning rate follows a polynomial decay, to be
-//      zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
-//    - sigmoid: the effective learning rate follows a sigmoid decay
-//      return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
-//
-// where base_lr, max_iter, gamma, step, stepvalue and power are defined
-// in the solver parameter protocol buffer, and iter is the current iteration.
-template <typename Dtype>
-Dtype SGDSolver<Dtype>::GetLearningRate() {
-  Dtype rate;
-  const string& lr_policy = this->param_.lr_policy();
-  if (lr_policy == "fixed") {
-    rate = this->param_.base_lr();
-  } else if (lr_policy == "step") {
-    this->current_step_ = this->iter_ / this->param_.stepsize();
-    rate = this->param_.base_lr() *
-        pow(this->param_.gamma(), this->current_step_);
-  } else if (lr_policy == "exp") {
-    rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
-  } else if (lr_policy == "inv") {
-    rate = this->param_.base_lr() *
-        pow(Dtype(1) + this->param_.gamma() * this->iter_,
-            - this->param_.power());
-  } else if (lr_policy == "multistep") {
-    if (this->current_step_ < this->param_.stepvalue_size() &&
-          this->iter_ >= this->param_.stepvalue(this->current_step_)) {
-      this->current_step_++;
-      LOG(INFO) << "MultiStep Status: Iteration " <<
-      this->iter_ << ", step = " << this->current_step_;
-    }
-    rate = this->param_.base_lr() *
-        pow(this->param_.gamma(), this->current_step_);
-  } else if (lr_policy == "poly") {
-    rate = this->param_.base_lr() * pow(Dtype(1.) -
-        (Dtype(this->iter_) / Dtype(this->param_.max_iter())),
-        this->param_.power());
-  } else if (lr_policy == "sigmoid") {
-    rate = this->param_.base_lr() * (Dtype(1.) /
-        (Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
-          Dtype(this->param_.stepsize())))));
-  } else {
-    LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
-  }
-  return rate;
-}
-
-template <typename Dtype>
-void SGDSolver<Dtype>::PreSolve() {
-  // Initialize the history
-  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
-  history_.clear();
-  update_.clear();
-  temp_.clear();
-  for (int i = 0; i < net_params.size(); ++i) {
-    const vector<int>& shape = net_params[i]->shape();
-    history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
-    update_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
-    temp_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
-  }
-}
-
-template <typename Dtype>
-void SGDSolver<Dtype>::ClipGradients() {
-  const Dtype clip_gradients = this->param_.clip_gradients();
-  if (clip_gradients < 0) { return; }
-  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
-  Dtype sumsq_diff = 0;
-  for (int i = 0; i < net_params.size(); ++i) {
-    sumsq_diff += net_params[i]->sumsq_diff();
-  }
-  const Dtype l2norm_diff = std::sqrt(sumsq_diff);
-  if (l2norm_diff > clip_gradients) {
-    Dtype scale_factor = clip_gradients / l2norm_diff;
-    LOG(INFO) << "Gradient clipping: scaling down gradients (L2 norm "
-        << l2norm_diff << " > " << clip_gradients << ") "
-        << "by scale factor " << scale_factor;
-    for (int i = 0; i < net_params.size(); ++i) {
-      net_params[i]->scale_diff(scale_factor);
-    }
-  }
-}
-
-template <typename Dtype>
-void SGDSolver<Dtype>::ApplyUpdate() {
-  CHECK(Caffe::root_solver());
-  Dtype rate = GetLearningRate();
-  if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
-    LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
-  }
-  ClipGradients();
-  for (int param_id = 0; param_id < this->net_->learnable_params().size();
-       ++param_id) {
-    Normalize(param_id);
-    Regularize(param_id);
-    ComputeUpdateValue(param_id, rate);
-  }
-  this->net_->Update();
-}
-
-template <typename Dtype>
-void SGDSolver<Dtype>::Normalize(int param_id) {
-  if (this->param_.iter_size() == 1) { return; }
-  // Scale gradient to counterbalance accumulation.
-  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
-  const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size();
-  switch (Caffe::mode()) {
-  case Caffe::CPU: {
-    caffe_scal(net_params[param_id]->count(), accum_normalization,
-        net_params[param_id]->mutable_cpu_diff());
-    break;
-  }
-  case Caffe::GPU: {
-#ifndef CPU_ONLY
-    caffe_gpu_scal(net_params[param_id]->count(), accum_normalization,
-        net_params[param_id]->mutable_gpu_diff());
-#else
-    NO_GPU;
-#endif
-    break;
-  }
-  default:
-    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
-  }
-}
-
-template <typename Dtype>
-void SGDSolver<Dtype>::Regularize(int param_id) {
-  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
-  const vector<float>& net_params_weight_decay =
-      this->net_->params_weight_decay();
-  Dtype weight_decay = this->param_.weight_decay();
-  string regularization_type = this->param_.regularization_type();
-  Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
-  switch (Caffe::mode()) {
-  case Caffe::CPU: {
-    if (local_decay) {
-      if (regularization_type == "L2") {
-        // add weight decay
-        caffe_axpy(net_params[param_id]->count(),
-            local_decay,
-            net_params[param_id]->cpu_data(),
-            net_params[param_id]->mutable_cpu_diff());
-      } else if (regularization_type == "L1") {
-        caffe_cpu_sign(net_params[param_id]->count(),
-            net_params[param_id]->cpu_data(),
-            temp_[param_id]->mutable_cpu_data());
-        caffe_axpy(net_params[param_id]->count(),
-            local_decay,
-            temp_[param_id]->cpu_data(),
-            net_params[param_id]->mutable_cpu_diff());
-      } else {
-        LOG(FATAL) << "Unknown regularization type: " << regularization_type;
-      }
-    }
-    break;
-  }
-  case Caffe::GPU: {
-#ifndef CPU_ONLY
-    if (local_decay) {
-      if (regularization_type == "L2") {
-        // add weight decay
-        caffe_gpu_axpy(net_params[param_id]->count(),
-            local_decay,
-            net_params[param_id]->gpu_data(),
-            net_params[param_id]->mutable_gpu_diff());
-      } else if (regularization_type == "L1") {
-        caffe_gpu_sign(net_params[param_id]->count(),
-            net_params[param_id]->gpu_data(),
-            temp_[param_id]->mutable_gpu_data());
-        caffe_gpu_axpy(net_params[param_id]->count(),
-            local_decay,
-            temp_[param_id]->gpu_data(),
-            net_params[param_id]->mutable_gpu_diff());
-      } else {
-        LOG(FATAL) << "Unknown regularization type: " << regularization_type;
-      }
-    }
-#else
-    NO_GPU;
-#endif
-    break;
-  }
-  default:
-    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
-  }
-}
-
-template <typename Dtype>
-void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
-  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
-  const vector<float>& net_params_lr = this->net_->params_lr();
-  Dtype momentum = this->param_.momentum();
-  Dtype local_rate = rate * net_params_lr[param_id];
-  // Compute the update to history, then copy it to the parameter diff.
-  switch (Caffe::mode()) {
-  case Caffe::CPU: {
-    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
-              net_params[param_id]->cpu_diff(), momentum,
-              history_[param_id]->mutable_cpu_data());
-    caffe_copy(net_params[param_id]->count(),
-        history_[param_id]->cpu_data(),
-        net_params[param_id]->mutable_cpu_diff());
-    break;
-  }
-  case Caffe::GPU: {
-#ifndef CPU_ONLY
-    caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
-              net_params[param_id]->gpu_diff(), momentum,
-              history_[param_id]->mutable_gpu_data());
-    caffe_copy(net_params[param_id]->count(),
-        history_[param_id]->gpu_data(),
-        net_params[param_id]->mutable_gpu_diff());
-#else
-    NO_GPU;
-#endif
-    break;
-  }
-  default:
-    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
-  }
-}
-
-template <typename Dtype>
-void SGDSolver<Dtype>::SnapshotSolverState(const string& model_filename) {
-  switch (this->param_.snapshot_format()) {
-    case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
-      SnapshotSolverStateToBinaryProto(model_filename);
-      break;
-    case caffe::SolverParameter_SnapshotFormat_HDF5:
-      SnapshotSolverStateToHDF5(model_filename);
-      break;
-    default:
-      LOG(FATAL) << "Unsupported snapshot format.";
-  }
-}
-
-template <typename Dtype>
-void SGDSolver<Dtype>::SnapshotSolverStateToBinaryProto(
-    const string& model_filename) {
-  SolverState state;
-  state.set_iter(this->iter_);
-  state.set_learned_net(model_filename);
-  state.set_current_step(this->current_step_);
-  state.clear_history();
-  for (int i = 0; i < history_.size(); ++i) {
-    // Add history
-    BlobProto* history_blob = state.add_history();
-    history_[i]->ToProto(history_blob);
-  }
-  string snapshot_filename = Solver<Dtype>::SnapshotFilename(".solverstate");
-  LOG(INFO)
-    << "Snapshotting solver state to binary proto file " << snapshot_filename;
-  WriteProtoToBinaryFile(state, snapshot_filename.c_str());
-}
-
-template <typename Dtype>
-void SGDSolver<Dtype>::SnapshotSolverStateToHDF5(
-    const string& model_filename) {
-  string snapshot_filename =
-      Solver<Dtype>::SnapshotFilename(".solverstate.h5");
-  LOG(INFO) << "Snapshotting solver state to HDF5 file " << snapshot_filename;
-  hid_t file_hid = H5Fcreate(snapshot_filename.c_str(), H5F_ACC_TRUNC,
-      H5P_DEFAULT, H5P_DEFAULT);
-  CHECK_GE(file_hid, 0)
-      << "Couldn't open " << snapshot_filename << " to save solver state.";
-  hdf5_save_int(file_hid, "iter", this->iter_);
-  hdf5_save_string(file_hid, "learned_net", model_filename);
-  hdf5_save_int(file_hid, "current_step", this->current_step_);
-  hid_t history_hid = H5Gcreate2(file_hid, "history", H5P_DEFAULT, H5P_DEFAULT,
-      H5P_DEFAULT);
-  CHECK_GE(history_hid, 0)
-      << "Error saving solver state to " << snapshot_filename << ".";
-  for (int i = 0; i < history_.size(); ++i) {
-    ostringstream oss;
-    oss << i;
-    hdf5_save_nd_dataset<Dtype>(history_hid, oss.str(), *history_[i]);
-  }
-  H5Gclose(history_hid);
-  H5Fclose(file_hid);
-}
-
-template <typename Dtype>
-void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
-    const string& state_file) {
-  SolverState state;
-  ReadProtoFromBinaryFile(state_file, &state);
-  this->iter_ = state.iter();
-  if (state.has_learned_net()) {
-    NetParameter net_param;
-    ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
-    this->net_->CopyTrainedLayersFrom(net_param);
-  }
-  this->current_step_ = state.current_step();
-  CHECK_EQ(state.history_size(), history_.size())
-      << "Incorrect length of history blobs.";
-  LOG(INFO) << "SGDSolver: restoring history";
-  for (int i = 0; i < history_.size(); ++i) {
-    history_[i]->FromProto(state.history(i));
-  }
-}
-
-template <typename Dtype>
-void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
-  hid_t file_hid = H5Fopen(state_file.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT);
-  CHECK_GE(file_hid, 0) << "Couldn't open solver state file " << state_file;
-  this->iter_ = hdf5_load_int(file_hid, "iter");
-  if (H5LTfind_dataset(file_hid, "learned_net")) {
-    string learned_net = hdf5_load_string(file_hid, "learned_net");
-    this->net_->CopyTrainedLayersFrom(learned_net);
-  }
-  this->current_step_ = hdf5_load_int(file_hid, "current_step");
-  hid_t history_hid = H5Gopen2(file_hid, "history", H5P_DEFAULT);
-  CHECK_GE(history_hid, 0) << "Error reading history from " << state_file;
-  int state_history_size = hdf5_get_num_links(history_hid);
-  CHECK_EQ(state_history_size, history_.size())
-      << "Incorrect length of history blobs.";
-  for (int i = 0; i < history_.size(); ++i) {
-    ostringstream oss;
-    oss << i;
-    hdf5_load_nd_dataset<Dtype>(history_hid, oss.str().c_str(), 0,
-                                kMaxBlobAxes, history_[i].get());
-  }
-  H5Gclose(history_hid);
-  H5Fclose(file_hid);
-}
-
-template <typename Dtype>
-void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
-  CHECK(Caffe::root_solver());
-  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
-  const vector<float>& net_params_lr = this->net_->params_lr();
-  Dtype momentum = this->param_.momentum();
-  Dtype local_rate = rate * net_params_lr[param_id];
-  switch (Caffe::mode()) {
-  case Caffe::CPU: {
-    // save history momentum for stepping back
-    caffe_copy(net_params[param_id]->count(),
-        this->history_[param_id]->cpu_data(),
-        this->update_[param_id]->mutable_cpu_data());
-
-    // update history
-    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
-              net_params[param_id]->cpu_diff(), momentum,
-              this->history_[param_id]->mutable_cpu_data());
-
-    // compute update: step back then over step
-    caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
-        this->history_[param_id]->cpu_data(), -momentum,
-        this->update_[param_id]->mutable_cpu_data());
-
-    // copy
-    caffe_copy(net_params[param_id]->count(),
-        this->update_[param_id]->cpu_data(),
-        net_params[param_id]->mutable_cpu_diff());
-    break;
-  }
-  case Caffe::GPU: {
-#ifndef CPU_ONLY
-    // save history momentum for stepping back
-    caffe_copy(net_params[param_id]->count(),
-        this->history_[param_id]->gpu_data(),
-        this->update_[param_id]->mutable_gpu_data());
-
-    // update history
-    caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
-              net_params[param_id]->gpu_diff(), momentum,
-              this->history_[param_id]->mutable_gpu_data());
-
-    // compute update: step back then over step
-    caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
-        this->history_[param_id]->gpu_data(), -momentum,
-        this->update_[param_id]->mutable_gpu_data());
-
-    // copy
-    caffe_copy(net_params[param_id]->count(),
-        this->update_[param_id]->gpu_data(),
-        net_params[param_id]->mutable_gpu_diff());
-#else
-    NO_GPU;
-#endif
-    break;
-  }
-  default:
-    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
-  }
-}
-
-template <typename Dtype>
-void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
-  CHECK(Caffe::root_solver());
-  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
-  const vector<float>& net_params_lr = this->net_->params_lr();
-  Dtype delta = this->param_.delta();
-  Dtype local_rate = rate * net_params_lr[param_id];
-  switch (Caffe::mode()) {
-  case Caffe::CPU: {
-    // compute square of gradient in update
-    caffe_powx(net_params[param_id]->count(),
-        net_params[param_id]->cpu_diff(), Dtype(2),
-        this->update_[param_id]->mutable_cpu_data());
-
-    // update history
-    caffe_add(net_params[param_id]->count(),
-        this->update_[param_id]->cpu_data(),
-        this->history_[param_id]->cpu_data(),
-        this->history_[param_id]->mutable_cpu_data());
-
-    // prepare update
-    caffe_powx(net_params[param_id]->count(),
-              this->history_[param_id]->cpu_data(), Dtype(0.5),
-              this->update_[param_id]->mutable_cpu_data());
-
-    caffe_add_scalar(net_params[param_id]->count(),
-              delta, this->update_[param_id]->mutable_cpu_data());
-
-    caffe_div(net_params[param_id]->count(),
-              net_params[param_id]->cpu_diff(),
-              this->update_[param_id]->cpu_data(),
-              this->update_[param_id]->mutable_cpu_data());
-
-    // scale and copy
-    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
-        this->update_[param_id]->cpu_data(), Dtype(0),
-        net_params[param_id]->mutable_cpu_diff());
-    break;
-  }
-  case Caffe::GPU: {
-#ifndef CPU_ONLY
-    // compute square of gradient in update
-    caffe_gpu_powx(net_params[param_id]->count(),
-        net_params[param_id]->gpu_diff(), Dtype(2),
-        this->update_[param_id]->mutable_gpu_data());
-
-    // update history
-    caffe_gpu_add(net_params[param_id]->count(),
-        this->update_[param_id]->gpu_data(),
-        this->history_[param_id]->gpu_data(),
-        this->history_[param_id]->mutable_gpu_data());
-
-    // prepare update
-    caffe_gpu_powx(net_params[param_id]->count(),
-              this->history_[param_id]->gpu_data(), Dtype(0.5),
-              this->update_[param_id]->mutable_gpu_data());
-
-    caffe_gpu_add_scalar(net_params[param_id]->count(),
-              delta, this->update_[param_id]->mutable_gpu_data());
-
-    caffe_gpu_div(net_params[param_id]->count(),
-              net_params[param_id]->gpu_diff(),
-              this->update_[param_id]->gpu_data(),
-              this->update_[param_id]->mutable_gpu_data());
-
-    // scale and copy
-    caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
-        this->update_[param_id]->gpu_data(), Dtype(0),
-        net_params[param_id]->mutable_gpu_diff());
-#else
-    NO_GPU;
-#endif
-    break;
-  }
-  default:
-    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
-  }
-}
-
-template <typename Dtype>
-void RMSPropSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
-  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
-  const vector<float>& net_params_lr = this->net_->params_lr();
-
-  // get the learning rate
-  Dtype delta = this->param_.delta();
-  Dtype rms_decay = this->param_.rms_decay();
-  Dtype local_rate = rate * net_params_lr[param_id];
-
-  switch (Caffe::mode()) {
-  case Caffe::CPU:
-    // compute square of gradient in update
-    caffe_powx(net_params[param_id]->count(),
-        net_params[param_id]->cpu_diff(), Dtype(2),
-        this->update_[param_id]->mutable_cpu_data());
-
-    // update history
-    caffe_cpu_axpby(net_params[param_id] -> count(),
-        Dtype(1-rms_decay), this->update_[param_id]->cpu_data(),
-        rms_decay, this->history_[param_id]-> mutable_cpu_data());
-
-    // prepare update
-    caffe_powx(net_params[param_id]->count(),
-        this->history_[param_id]->cpu_data(), Dtype(0.5),
-        this->update_[param_id]->mutable_cpu_data());
-
-    caffe_add_scalar(net_params[param_id]->count(),
-        delta, this->update_[param_id]->mutable_cpu_data());
-
-    caffe_div(net_params[param_id]->count(),
-        net_params[param_id]->cpu_diff(), this->update_[param_id]->cpu_data(),
-        this->update_[param_id]->mutable_cpu_data());
-
-    // scale and copy
-    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
-        this->update_[param_id]->cpu_data(), Dtype(0),
-        net_params[param_id]->mutable_cpu_diff());
-    break;
-  case Caffe::GPU:
-#ifndef CPU_ONLY
-    // compute square of gradient in update
-    caffe_gpu_powx(net_params[param_id]->count(),
-        net_params[param_id]->gpu_diff(), Dtype(2),
-        this->update_[param_id]->mutable_gpu_data());
-
-    // update history
-    caffe_gpu_axpby(net_params[param_id] -> count(),
-        Dtype(1-rms_decay), this->update_[param_id]->gpu_data(),
-        rms_decay, this->history_[param_id]-> mutable_gpu_data());
-
-    // prepare update
-    caffe_gpu_powx(net_params[param_id]->count(),
-        this->history_[param_id]->gpu_data(), Dtype(0.5),
-        this->update_[param_id]->mutable_gpu_data());
-
-    caffe_gpu_add_scalar(net_params[param_id]->count(),
-        delta, this->update_[param_id]->mutable_gpu_data());
-
-    caffe_gpu_div(net_params[param_id]->count(),
-        net_params[param_id]->gpu_diff(), this->update_[param_id]->gpu_data(),
-        this->update_[param_id]->mutable_gpu_data());
-
-    caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
-        this->update_[param_id]->gpu_data(), Dtype(0),
-        net_params[param_id]->mutable_gpu_diff());
-#else
-    NO_GPU;
-#endif
-    break;
-  default:
-    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
-  }
-}
-
-template <typename Dtype>
-void AdaDeltaSolver<Dtype>::AdaDeltaPreSolve() {
-  // Add the extra history entries for AdaDelta after those from
-  // SGDSolver::PreSolve
-  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
-  for (int i = 0; i < net_params.size(); ++i) {
-        const vector<int>& shape = net_params[i]->shape();
-        this->history_.push_back(
-                shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
-  }
-}
-
-template <typename Dtype>
-void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
-  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
-  const vector<float>& net_params_lr = this->net_->params_lr();
-  Dtype delta = this->param_.delta();
-  Dtype momentum = this->param_.momentum();
-  Dtype local_rate = rate * net_params_lr[param_id];
-  size_t update_history_offset = net_params.size();
-  switch (Caffe::mode()) {
-  case Caffe::CPU: {
-    // compute square of gradient in update
-    caffe_powx(net_params[param_id]->count(),
-        net_params[param_id]->cpu_diff(), Dtype(2),
-        this->update_[param_id]->mutable_cpu_data());
-
-    // update history of gradients
-    caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
-        this->update_[param_id]->cpu_data(), momentum,
-        this->history_[param_id]->mutable_cpu_data());
-
-    // add delta to history to guard against dividing by zero later
-    caffe_set(net_params[param_id]->count(), delta,
-        this->temp_[param_id]->mutable_cpu_data());
-
-    caffe_add(net_params[param_id]->count(),
-        this->temp_[param_id]->cpu_data(),
-        this->history_[update_history_offset + param_id]->cpu_data(),
-        this->update_[param_id]->mutable_cpu_data());
-
-    caffe_add(net_params[param_id]->count(),
-        this->temp_[param_id]->cpu_data(),
-        this->history_[param_id]->cpu_data(),
-        this->temp_[param_id]->mutable_cpu_data());
-
-    // divide history of updates by history of gradients
-    caffe_div(net_params[param_id]->count(),
-        this->update_[param_id]->cpu_data(),
-        this->temp_[param_id]->cpu_data(),
-        this->update_[param_id]->mutable_cpu_data());
-
-    // jointly compute the RMS of both for update and gradient history
-    caffe_powx(net_params[param_id]->count(),
-        this->update_[param_id]->cpu_data(), Dtype(0.5),
-        this->update_[param_id]->mutable_cpu_data());
-
-    // compute the update
-    caffe_mul(net_params[param_id]->count(),
-        net_params[param_id]->cpu_diff(),
-        this->update_[param_id]->cpu_data(),
-        net_params[param_id]->mutable_cpu_diff());
-
-    // compute square of update
-    caffe_powx(net_params[param_id]->count(),
-        net_params[param_id]->cpu_diff(), Dtype(2),
-        this->update_[param_id]->mutable_cpu_data());
-
-    // update history of updates
-    caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
-        this->update_[param_id]->cpu_data(), momentum,
-        this->history_[update_history_offset + param_id]->mutable_cpu_data());
-
-    // apply learning rate
-    caffe_cpu_scale(net_params[param_id]->count(), local_rate,
-        net_params[param_id]->cpu_diff(),
-        net_params[param_id]->mutable_cpu_diff());
-    break;
-  }
-  case Caffe::GPU: {
-#ifndef CPU_ONLY
-    // compute square of gradient in update
-    caffe_gpu_powx(net_params[param_id]->count(),
-        net_params[param_id]->gpu_diff(), Dtype(2),
-        this->update_[param_id]->mutable_gpu_data());
-
-    // update history of gradients
-    caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
-        this->update_[param_id]->gpu_data(), momentum,
-        this->history_[param_id]->mutable_gpu_data());
-
-    // add delta to history to guard against dividing by zero later
-    caffe_gpu_set(net_params[param_id]->count(), delta,
-        this->temp_[param_id]->mutable_gpu_data());
-
-    caffe_gpu_add(net_params[param_id]->count(),
-        this->temp_[param_id]->gpu_data(),
-        this->history_[update_history_offset + param_id]->gpu_data(),
-        this->update_[param_id]->mutable_gpu_data());
-
-    caffe_gpu_add(net_params[param_id]->count(),
-        this->temp_[param_id]->gpu_data(),
-        this->history_[param_id]->gpu_data(),
-        this->temp_[param_id]->mutable_gpu_data());
-
-    // divide history of updates by history of gradients
-    caffe_gpu_div(net_params[param_id]->count(),
-        this->update_[param_id]->gpu_data(),
-        this->temp_[param_id]->gpu_data(),
-        this->update_[param_id]->mutable_gpu_data());
-
-    // jointly compute the RMS of both for update and gradient history
-    caffe_gpu_powx(net_params[param_id]->count(),
-        this->update_[param_id]->gpu_data(), Dtype(0.5),
-        this->update_[param_id]->mutable_gpu_data());
-
-    // compute the update and copy to net_diff
-    caffe_gpu_mul(net_params[param_id]->count(),
-        net_params[param_id]->gpu_diff(),
-        this->update_[param_id]->gpu_data(),
-        net_params[param_id]->mutable_gpu_diff());
-
-    // compute square of update
-    caffe_gpu_powx(net_params[param_id]->count(),
-        net_params[param_id]->gpu_diff(), Dtype(2),
-        this->update_[param_id]->mutable_gpu_data());
-
-    // update history of updates
-    caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
-        this->update_[param_id]->gpu_data(), momentum,
-        this->history_[update_history_offset + param_id]->mutable_gpu_data());
-
-    // apply learning rate
-    caffe_gpu_scale(net_params[param_id]->count(), local_rate,
-        net_params[param_id]->gpu_diff(),
-        net_params[param_id]->mutable_gpu_diff());
-#else
-    NO_GPU;
-#endif
-    break;
-  }
-  default:
-    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
-  }
-}
-
-template <typename Dtype>
-void AdamSolver<Dtype>::AdamPreSolve() {
-  // Add the extra history entries for Adam after those from
-  // SGDSolver::PreSolve
-  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
-  for (int i = 0; i < net_params.size(); ++i) {
-    const vector<int>& shape = net_params[i]->shape();
-    this->history_.push_back(
-            shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
-  }
-}
-
-template <typename Dtype>
-void AdamSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
-  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
-  const vector<float>& net_params_lr = this->net_->params_lr();
-  Dtype local_rate = rate * net_params_lr[param_id];
-  const Dtype beta1 = this->param_.momentum();
-  const Dtype beta2 = this->param_.momentum2();
-
-  // we create aliases for convenience
-  size_t update_history_offset = net_params.size();
-  Blob<Dtype>* val_m = this->history_[param_id].get();
-  Blob<Dtype>* val_v = this->history_[param_id + update_history_offset].get();
-  Blob<Dtype>* val_t = this->temp_[param_id].get();
-
-  const int t = this->iter_  + 1;
-  const Dtype correction = std::sqrt(Dtype(1) - pow(beta2, t)) /
-      (Dtype(1.) - pow(beta1, t));
-  const int N = net_params[param_id]->count();
-  const Dtype eps_hat = this->param_.delta();
-
-  switch (Caffe::mode()) {
-    case Caffe::CPU: {
-    // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t
-    caffe_cpu_axpby(N, Dtype(1)-beta1,
-        net_params[param_id]->cpu_diff(), beta1,
-        val_m->mutable_cpu_data());
-
-    // update v <- \beta_2 v_{t-1} + (1-\beta_2)g_t^2
-    caffe_mul(N,
-        net_params[param_id]->cpu_diff(),
-        net_params[param_id]->cpu_diff(),
-    val_t->mutable_cpu_data());
-    caffe_cpu_axpby(N, Dtype(1)-beta2,
-        val_t->cpu_data(), beta2,
-        val_v->mutable_cpu_data());
-
-    // set update
-    caffe_powx(N,
-        val_v->cpu_data(), Dtype(0.5),
-        val_t->mutable_cpu_data());
-    caffe_add_scalar(N, eps_hat, val_t->mutable_cpu_data());
-    caffe_div(N,
-        val_m->cpu_data(),
-        val_t->cpu_data(),
-        val_t->mutable_cpu_data());
-
-    caffe_cpu_scale(N, local_rate*correction,
-        val_t->cpu_data(),
-        net_params[param_id]->mutable_cpu_diff());
-    break;
-  }
-  case Caffe::GPU: {
-#ifndef CPU_ONLY
-    // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t
-    caffe_gpu_axpby(N, Dtype(1)-beta1,
-        net_params[param_id]->gpu_diff(), beta1,
-        val_m->mutable_gpu_data());
-
-    // update v <- \beta_2 v_{t-1} + (1-\beta_2)g_t^2
-    caffe_gpu_mul(N,
-        net_params[param_id]->gpu_diff(),
-        net_params[param_id]->gpu_diff(),
-        val_t->mutable_gpu_data());
-    caffe_gpu_axpby(N, Dtype(1)-beta2,
-        val_t->gpu_data(), beta2,
-        val_v->mutable_gpu_data());
-
-    // set update
-    caffe_gpu_powx(N,
-        val_v->gpu_data(), Dtype(0.5),
-        val_t->mutable_gpu_data());
-    caffe_gpu_add_scalar(N, eps_hat,
-        val_t->mutable_gpu_data());
-    caffe_gpu_div(N,
-        val_m->gpu_data(),
-        val_t->gpu_data(),
-        val_t->mutable_gpu_data());
-
-    caffe_gpu_scale(N, local_rate*correction,
-        val_t->gpu_data(),
-        net_params[param_id]->mutable_gpu_diff());
-#else
-    NO_GPU;
-#endif
-    break;
-  }
-  default:
-    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
-  }
-}
-
 INSTANTIATE_CLASS(Solver);
-INSTANTIATE_CLASS(SGDSolver);
-INSTANTIATE_CLASS(NesterovSolver);
-INSTANTIATE_CLASS(AdaGradSolver);
-INSTANTIATE_CLASS(RMSPropSolver);
-INSTANTIATE_CLASS(AdaDeltaSolver);
-INSTANTIATE_CLASS(AdamSolver);
 
 }  // namespace caffe
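
A worked example of the "step" policy removed above (GetLearningRate moves into the new src/caffe/solvers/sgd_solver.cpp): with base_lr = 0.01, gamma = 0.1, and stepsize = 100000, the returned rate is 0.01 for iterations 0-99999, 0.01 * 0.1 = 0.001 for iterations 100000-199999, and 0.0001 thereafter.
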
diff --git a/src/caffe/solver_factory.cpp b/src/caffe/solver_factory.cpp
new file mode 100644 (file)
index 0000000..f78fab2
--- /dev/null
@@ -0,0 +1,32 @@
+#include "caffe/solver.hpp"
+#include "caffe/sgd_solvers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+Solver<Dtype>* GetSolver(const SolverParameter& param) {
+  SolverParameter_SolverType type = param.solver_type();
+
+  switch (type) {
+  case SolverParameter_SolverType_SGD:
+    return new SGDSolver<Dtype>(param);
+  case SolverParameter_SolverType_NESTEROV:
+    return new NesterovSolver<Dtype>(param);
+  case SolverParameter_SolverType_ADAGRAD:
+    return new AdaGradSolver<Dtype>(param);
+  case SolverParameter_SolverType_RMSPROP:
+    return new RMSPropSolver<Dtype>(param);
+  case SolverParameter_SolverType_ADADELTA:
+    return new AdaDeltaSolver<Dtype>(param);
+  case SolverParameter_SolverType_ADAM:
+    return new AdamSolver<Dtype>(param);
+  default:
+    LOG(FATAL) << "Unknown SolverType: " << type;
+  }
+  return (Solver<Dtype>*) NULL;
+}
+
+template Solver<float>* GetSolver(const SolverParameter& param);
+template Solver<double>* GetSolver(const SolverParameter& param);
+
+}  // namespace caffe
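
A minimal caller sketch for the factory, assuming the usual Caffe I/O helper ReadProtoFromTextFileOrDie; the "solver.prototxt" path is a placeholder, not part of this commit:

    #include <boost/shared_ptr.hpp>
    #include "caffe/caffe.hpp"

    int main() {
      caffe::SolverParameter solver_param;
      // Parse a solver definition; the path here is a placeholder.
      caffe::ReadProtoFromTextFileOrDie("solver.prototxt", &solver_param);
      // The factory dispatches on solver_param.solver_type() as above.
      boost::shared_ptr<caffe::Solver<float> >
          solver(caffe::GetSolver<float>(solver_param));
      solver->Solve();
      return 0;
    }
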
diff --git a/src/caffe/solvers/adadelta_solver.cpp b/src/caffe/solvers/adadelta_solver.cpp
new file mode 100644 (file)
index 0000000..45cd4eb
--- /dev/null
@@ -0,0 +1,155 @@
+#include <vector>
+
+#include "caffe/sgd_solvers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void AdaDeltaSolver<Dtype>::AdaDeltaPreSolve() {
+  // Add the extra history entries for AdaDelta after those from
+  // SGDSolver::PreSolve
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  for (int i = 0; i < net_params.size(); ++i) {
+        const vector<int>& shape = net_params[i]->shape();
+        this->history_.push_back(
+                shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
+  }
+}
+
+template <typename Dtype>
+void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  const vector<float>& net_params_lr = this->net_->params_lr();
+  Dtype delta = this->param_.delta();
+  Dtype momentum = this->param_.momentum();
+  Dtype local_rate = rate * net_params_lr[param_id];
+  size_t update_history_offset = net_params.size();
+  switch (Caffe::mode()) {
+  case Caffe::CPU: {
+    // compute square of gradient in update
+    caffe_powx(net_params[param_id]->count(),
+        net_params[param_id]->cpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // update history of gradients
+    caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
+        this->update_[param_id]->cpu_data(), momentum,
+        this->history_[param_id]->mutable_cpu_data());
+
+    // add delta to history to guard against dividing by zero later
+    caffe_set(net_params[param_id]->count(), delta,
+        this->temp_[param_id]->mutable_cpu_data());
+
+    caffe_add(net_params[param_id]->count(),
+        this->temp_[param_id]->cpu_data(),
+        this->history_[update_history_offset + param_id]->cpu_data(),
+        this->update_[param_id]->mutable_cpu_data());
+
+    caffe_add(net_params[param_id]->count(),
+        this->temp_[param_id]->cpu_data(),
+        this->history_[param_id]->cpu_data(),
+        this->temp_[param_id]->mutable_cpu_data());
+
+    // divide history of updates by history of gradients
+    caffe_div(net_params[param_id]->count(),
+        this->update_[param_id]->cpu_data(),
+        this->temp_[param_id]->cpu_data(),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // jointly compute the RMS of both for update and gradient history
+    caffe_powx(net_params[param_id]->count(),
+        this->update_[param_id]->cpu_data(), Dtype(0.5),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // compute the update
+    caffe_mul(net_params[param_id]->count(),
+        net_params[param_id]->cpu_diff(),
+        this->update_[param_id]->cpu_data(),
+        net_params[param_id]->mutable_cpu_diff());
+
+    // compute square of update
+    caffe_powx(net_params[param_id]->count(),
+        net_params[param_id]->cpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // update history of updates
+    caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
+        this->update_[param_id]->cpu_data(), momentum,
+        this->history_[update_history_offset + param_id]->mutable_cpu_data());
+
+    // apply learning rate
+    caffe_cpu_scale(net_params[param_id]->count(), local_rate,
+        net_params[param_id]->cpu_diff(),
+        net_params[param_id]->mutable_cpu_diff());
+    break;
+  }
+  case Caffe::GPU: {
+#ifndef CPU_ONLY
+    // compute square of gradient in update
+    caffe_gpu_powx(net_params[param_id]->count(),
+        net_params[param_id]->gpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_gpu_data());
+
+    // update history of gradients
+    caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
+        this->update_[param_id]->gpu_data(), momentum,
+        this->history_[param_id]->mutable_gpu_data());
+
+    // add delta to history to guard against dividing by zero later
+    caffe_gpu_set(net_params[param_id]->count(), delta,
+        this->temp_[param_id]->mutable_gpu_data());
+
+    caffe_gpu_add(net_params[param_id]->count(),
+        this->temp_[param_id]->gpu_data(),
+        this->history_[update_history_offset + param_id]->gpu_data(),
+        this->update_[param_id]->mutable_gpu_data());
+
+    caffe_gpu_add(net_params[param_id]->count(),
+        this->temp_[param_id]->gpu_data(),
+        this->history_[param_id]->gpu_data(),
+        this->temp_[param_id]->mutable_gpu_data());
+
+    // divide history of updates by history of gradients
+    caffe_gpu_div(net_params[param_id]->count(),
+        this->update_[param_id]->gpu_data(),
+        this->temp_[param_id]->gpu_data(),
+        this->update_[param_id]->mutable_gpu_data());
+
+    // jointly compute the RMS of both for update and gradient history
+    caffe_gpu_powx(net_params[param_id]->count(),
+        this->update_[param_id]->gpu_data(), Dtype(0.5),
+        this->update_[param_id]->mutable_gpu_data());
+
+    // compute the update and copy to net_diff
+    caffe_gpu_mul(net_params[param_id]->count(),
+        net_params[param_id]->gpu_diff(),
+        this->update_[param_id]->gpu_data(),
+        net_params[param_id]->mutable_gpu_diff());
+
+    // compute square of update
+    caffe_gpu_powx(net_params[param_id]->count(),
+        net_params[param_id]->gpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_gpu_data());
+
+    // update history of updates
+    caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
+        this->update_[param_id]->gpu_data(), momentum,
+        this->history_[update_history_offset + param_id]->mutable_gpu_data());
+
+    // apply learning rate
+    caffe_gpu_scale(net_params[param_id]->count(), local_rate,
+        net_params[param_id]->gpu_diff(),
+        net_params[param_id]->mutable_gpu_diff());
+#else
+    NO_GPU;
+#endif
+    break;
+  }
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  }
+}
+
+INSTANTIATE_CLASS(AdaDeltaSolver);
+
+}  // namespace caffe
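
In equations, with decay \rho (the momentum field), \epsilon (the delta field), and gradient g_t, the AdaDelta step above is:

    E[g^2]_t = \rho E[g^2]_{t-1} + (1 - \rho) g_t^2
    \Delta\theta_t = -\sqrt{(E[\Delta\theta^2]_{t-1} + \epsilon) / (E[g^2]_t + \epsilon)} \cdot g_t
    E[\Delta\theta^2]_t = \rho E[\Delta\theta^2]_{t-1} + (1 - \rho) \Delta\theta_t^2

history_[param_id] holds E[g^2] and history_[update_history_offset + param_id] holds E[\Delta\theta^2]. The code additionally scales the step by local_rate, whereas the original AdaDelta formulation has no separate learning rate.
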
diff --git a/src/caffe/solvers/adagrad_solver.cpp b/src/caffe/solvers/adagrad_solver.cpp
new file mode 100644 (file)
index 0000000..627d816
--- /dev/null
@@ -0,0 +1,88 @@
+#include <vector>
+
+#include "caffe/sgd_solvers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+  CHECK(Caffe::root_solver());
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  const vector<float>& net_params_lr = this->net_->params_lr();
+  Dtype delta = this->param_.delta();
+  Dtype local_rate = rate * net_params_lr[param_id];
+  switch (Caffe::mode()) {
+  case Caffe::CPU: {
+    // compute square of gradient in update
+    caffe_powx(net_params[param_id]->count(),
+        net_params[param_id]->cpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // update history
+    caffe_add(net_params[param_id]->count(),
+        this->update_[param_id]->cpu_data(),
+        this->history_[param_id]->cpu_data(),
+        this->history_[param_id]->mutable_cpu_data());
+
+    // prepare update
+    caffe_powx(net_params[param_id]->count(),
+              this->history_[param_id]->cpu_data(), Dtype(0.5),
+              this->update_[param_id]->mutable_cpu_data());
+
+    caffe_add_scalar(net_params[param_id]->count(),
+              delta, this->update_[param_id]->mutable_cpu_data());
+
+    caffe_div(net_params[param_id]->count(),
+              net_params[param_id]->cpu_diff(),
+              this->update_[param_id]->cpu_data(),
+              this->update_[param_id]->mutable_cpu_data());
+
+    // scale and copy
+    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
+        this->update_[param_id]->cpu_data(), Dtype(0),
+        net_params[param_id]->mutable_cpu_diff());
+    break;
+  }
+  case Caffe::GPU: {
+#ifndef CPU_ONLY
+    // compute square of gradient in update
+    caffe_gpu_powx(net_params[param_id]->count(),
+        net_params[param_id]->gpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_gpu_data());
+
+    // update history
+    caffe_gpu_add(net_params[param_id]->count(),
+        this->update_[param_id]->gpu_data(),
+        this->history_[param_id]->gpu_data(),
+        this->history_[param_id]->mutable_gpu_data());
+
+    // prepare update
+    caffe_gpu_powx(net_params[param_id]->count(),
+              this->history_[param_id]->gpu_data(), Dtype(0.5),
+              this->update_[param_id]->mutable_gpu_data());
+
+    caffe_gpu_add_scalar(net_params[param_id]->count(),
+              delta, this->update_[param_id]->mutable_gpu_data());
+
+    caffe_gpu_div(net_params[param_id]->count(),
+              net_params[param_id]->gpu_diff(),
+              this->update_[param_id]->gpu_data(),
+              this->update_[param_id]->mutable_gpu_data());
+
+    // scale and copy
+    caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
+        this->update_[param_id]->gpu_data(), Dtype(0),
+        net_params[param_id]->mutable_gpu_diff());
+#else
+    NO_GPU;
+#endif
+    break;
+  }
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  }
+}
+
+INSTANTIATE_CLASS(AdaGradSolver);
+
+}  // namespace caffe
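
Equivalently, with accumulated squared gradients h and \epsilon = delta, the AdaGrad step above is:

    h_t = h_{t-1} + g_t^2
    \Delta\theta_t = -local_rate \cdot g_t / (\sqrt{h_t} + \epsilon)

Note that \epsilon is added after the square root, matching the caffe_powx / caffe_add_scalar ordering in the code.
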
diff --git a/src/caffe/solvers/adam_solver.cpp b/src/caffe/solvers/adam_solver.cpp
new file mode 100644 (file)
index 0000000..8c334f6
--- /dev/null
@@ -0,0 +1,112 @@
+#include <vector>
+
+#include "caffe/sgd_solvers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void AdamSolver<Dtype>::AdamPreSolve() {
+  // Add the extra history entries for Adam after those from
+  // SGDSolver::PreSolve
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  for (int i = 0; i < net_params.size(); ++i) {
+    const vector<int>& shape = net_params[i]->shape();
+    this->history_.push_back(
+            shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
+  }
+}
+
+template <typename Dtype>
+void AdamSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  const vector<float>& net_params_lr = this->net_->params_lr();
+  Dtype local_rate = rate * net_params_lr[param_id];
+  const Dtype beta1 = this->param_.momentum();
+  const Dtype beta2 = this->param_.momentum2();
+
+  // we create aliases for convenience
+  size_t update_history_offset = net_params.size();
+  Blob<Dtype>* val_m = this->history_[param_id].get();
+  Blob<Dtype>* val_v = this->history_[param_id + update_history_offset].get();
+  Blob<Dtype>* val_t = this->temp_[param_id].get();
+
+  const int t = this->iter_  + 1;
+  const Dtype correction = std::sqrt(Dtype(1) - pow(beta2, t)) /
+      (Dtype(1.) - pow(beta1, t));
+  const int N = net_params[param_id]->count();
+  const Dtype eps_hat = this->param_.delta();
+
+  switch (Caffe::mode()) {
+    case Caffe::CPU: {
+    // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t
+    caffe_cpu_axpby(N, Dtype(1)-beta1,
+        net_params[param_id]->cpu_diff(), beta1,
+        val_m->mutable_cpu_data());
+
+    // update v <- \beta_2 v_{t-1} + (1-\beta_2)g_t^2
+    caffe_mul(N,
+        net_params[param_id]->cpu_diff(),
+        net_params[param_id]->cpu_diff(),
+    val_t->mutable_cpu_data());
+    caffe_cpu_axpby(N, Dtype(1)-beta2,
+        val_t->cpu_data(), beta2,
+        val_v->mutable_cpu_data());
+
+    // set update
+    caffe_powx(N,
+        val_v->cpu_data(), Dtype(0.5),
+        val_t->mutable_cpu_data());
+    caffe_add_scalar(N, eps_hat, val_t->mutable_cpu_data());
+    caffe_div(N,
+        val_m->cpu_data(),
+        val_t->cpu_data(),
+        val_t->mutable_cpu_data());
+
+    caffe_cpu_scale(N, local_rate*correction,
+        val_t->cpu_data(),
+        net_params[param_id]->mutable_cpu_diff());
+    break;
+  }
+  case Caffe::GPU: {
+#ifndef CPU_ONLY
+    // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t
+    caffe_gpu_axpby(N, Dtype(1)-beta1,
+        net_params[param_id]->gpu_diff(), beta1,
+        val_m->mutable_gpu_data());
+
+    // update v <- \beta_2 v_{t-1} + (1-\beta_2)g_t^2
+    caffe_gpu_mul(N,
+        net_params[param_id]->gpu_diff(),
+        net_params[param_id]->gpu_diff(),
+        val_t->mutable_gpu_data());
+    caffe_gpu_axpby(N, Dtype(1)-beta2,
+        val_t->gpu_data(), beta2,
+        val_v->mutable_gpu_data());
+
+    // set update
+    caffe_gpu_powx(N,
+        val_v->gpu_data(), Dtype(0.5),
+        val_t->mutable_gpu_data());
+    caffe_gpu_add_scalar(N, eps_hat,
+        val_t->mutable_gpu_data());
+    caffe_gpu_div(N,
+        val_m->gpu_data(),
+        val_t->gpu_data(),
+        val_t->mutable_gpu_data());
+
+    caffe_gpu_scale(N, local_rate*correction,
+        val_t->gpu_data(),
+        net_params[param_id]->mutable_gpu_diff());
+#else
+    NO_GPU;
+#endif
+    break;
+  }
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  }
+}
+
+INSTANTIATE_CLASS(AdamSolver);
+
+}  // namespace caffe
diff --git a/src/caffe/solvers/nesterov_solver.cpp b/src/caffe/solvers/nesterov_solver.cpp
new file mode 100644 (file)
index 0000000..8135ee2
--- /dev/null
@@ -0,0 +1,70 @@
+#include <vector>
+
+#include "caffe/sgd_solvers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+  CHECK(Caffe::root_solver());
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  const vector<float>& net_params_lr = this->net_->params_lr();
+  Dtype momentum = this->param_.momentum();
+  Dtype local_rate = rate * net_params_lr[param_id];
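+  // With history h_t = momentum * h_{t-1} + local_rate * g_t, the diff
+  // written below is (1 + momentum) * h_t - momentum * h_{t-1}: step back
+  // out of the previous momentum step, then overstep along the new one.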
+  switch (Caffe::mode()) {
+  case Caffe::CPU: {
+    // save history momentum for stepping back
+    caffe_copy(net_params[param_id]->count(),
+        this->history_[param_id]->cpu_data(),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // update history
+    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
+        net_params[param_id]->cpu_diff(), momentum,
+        this->history_[param_id]->mutable_cpu_data());
+
+    // compute update: step back then over step
+    caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
+        this->history_[param_id]->cpu_data(), -momentum,
+        this->update_[param_id]->mutable_cpu_data());
+
+    // copy
+    caffe_copy(net_params[param_id]->count(),
+        this->update_[param_id]->cpu_data(),
+        net_params[param_id]->mutable_cpu_diff());
+    break;
+  }
+  case Caffe::GPU: {
+#ifndef CPU_ONLY
+    // save history momentum for stepping back
+    caffe_copy(net_params[param_id]->count(),
+        this->history_[param_id]->gpu_data(),
+        this->update_[param_id]->mutable_gpu_data());
+
+    // update history
+    caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
+        net_params[param_id]->gpu_diff(), momentum,
+        this->history_[param_id]->mutable_gpu_data());
+
+    // compute update: step back then over step
+    caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
+        this->history_[param_id]->gpu_data(), -momentum,
+        this->update_[param_id]->mutable_gpu_data());
+
+    // copy
+    caffe_copy(net_params[param_id]->count(),
+        this->update_[param_id]->gpu_data(),
+        net_params[param_id]->mutable_gpu_diff());
+#else
+    NO_GPU;
+#endif
+    break;
+  }
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  }
+}
+
+INSTANTIATE_CLASS(NesterovSolver);
+
+}  // namespace caffe
diff --git a/src/caffe/solvers/rmsprop_solver.cpp b/src/caffe/solvers/rmsprop_solver.cpp
new file mode 100644 (file)
index 0000000..96d1b3d
--- /dev/null
@@ -0,0 +1,84 @@
+#include <vector>
+
+#include "caffe/sgd_solvers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void RMSPropSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  const vector<float>& net_params_lr = this->net_->params_lr();
+
+  // get the learning rate
+  Dtype delta = this->param_.delta();
+  Dtype rms_decay = this->param_.rms_decay();
+  Dtype local_rate = rate * net_params_lr[param_id];
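+  // history <- rms_decay * history + (1 - rms_decay) * g^2, then
+  // diff <- local_rate * g / (sqrt(history) + delta).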
+
+  switch (Caffe::mode()) {
+  case Caffe::CPU:
+    // compute square of gradient in update
+    caffe_powx(net_params[param_id]->count(),
+        net_params[param_id]->cpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // update history
+    caffe_cpu_axpby(net_params[param_id]->count(),
+        Dtype(1-rms_decay), this->update_[param_id]->cpu_data(),
+        rms_decay, this->history_[param_id]->mutable_cpu_data());
+
+    // prepare update
+    caffe_powx(net_params[param_id]->count(),
+        this->history_[param_id]->cpu_data(), Dtype(0.5),
+        this->update_[param_id]->mutable_cpu_data());
+
+    caffe_add_scalar(net_params[param_id]->count(),
+        delta, this->update_[param_id]->mutable_cpu_data());
+
+    caffe_div(net_params[param_id]->count(),
+        net_params[param_id]->cpu_diff(), this->update_[param_id]->cpu_data(),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // scale and copy
+    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
+        this->update_[param_id]->cpu_data(), Dtype(0),
+        net_params[param_id]->mutable_cpu_diff());
+    break;
+  case Caffe::GPU:
+#ifndef CPU_ONLY
+    // compute square of gradient in update
+    caffe_gpu_powx(net_params[param_id]->count(),
+        net_params[param_id]->gpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_gpu_data());
+
+    // update history
+    caffe_gpu_axpby(net_params[param_id]->count(),
+        Dtype(1-rms_decay), this->update_[param_id]->gpu_data(),
+        rms_decay, this->history_[param_id]->mutable_gpu_data());
+
+    // prepare update
+    caffe_gpu_powx(net_params[param_id]->count(),
+        this->history_[param_id]->gpu_data(), Dtype(0.5),
+        this->update_[param_id]->mutable_gpu_data());
+
+    caffe_gpu_add_scalar(net_params[param_id]->count(),
+        delta, this->update_[param_id]->mutable_gpu_data());
+
+    caffe_gpu_div(net_params[param_id]->count(),
+        net_params[param_id]->gpu_diff(), this->update_[param_id]->gpu_data(),
+        this->update_[param_id]->mutable_gpu_data());
+
+    caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
+        this->update_[param_id]->gpu_data(), Dtype(0),
+        net_params[param_id]->mutable_gpu_diff());
+#else
+    NO_GPU;
+#endif
+    break;
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  }
+}
+
+INSTANTIATE_CLASS(RMSPropSolver);
+
+}  // namespace caffe
diff --git a/src/caffe/solvers/sgd_solver.cpp b/src/caffe/solvers/sgd_solver.cpp
new file mode 100644 (file)
index 0000000..89ef5ec
--- /dev/null
@@ -0,0 +1,347 @@
+#include <string>
+#include <vector>
+
+#include "caffe/sgd_solvers.hpp"
+#include "caffe/util/hdf5.hpp"
+#include "caffe/util/io.hpp"
+#include "caffe/util/upgrade_proto.hpp"
+
+namespace caffe {
+
+// Return the current learning rate. The currently implemented learning rate
+// policies are as follows:
+//    - fixed: always return base_lr.
+//    - step: return base_lr * gamma ^ (floor(iter / stepsize))
+//    - exp: return base_lr * gamma ^ iter
+//    - inv: return base_lr * (1 + gamma * iter) ^ (- power)
+//    - multistep: similar to step but allows non-uniform steps defined by
+//      stepvalue
+//    - poly: the effective learning rate follows a polynomial decay, reaching
+//      zero at max_iter: return base_lr * (1 - iter/max_iter) ^ power
+//    - sigmoid: the effective learning rate follows a sigmoid decay:
+//      return base_lr * (1 / (1 + exp(-gamma * (iter - stepsize))))
+//
+// where base_lr, max_iter, gamma, stepsize, stepvalue and power are defined
+// in the solver parameter protocol buffer, and iter is the current iteration.
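+//
+// For example, with lr_policy = "step", base_lr = 0.01, gamma = 0.1 and
+// stepsize = 100000, the rate is 0.01 for iterations [0, 100000), 0.001 for
+// [100000, 200000), and so on.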
+template <typename Dtype>
+Dtype SGDSolver<Dtype>::GetLearningRate() {
+  Dtype rate;
+  const string& lr_policy = this->param_.lr_policy();
+  if (lr_policy == "fixed") {
+    rate = this->param_.base_lr();
+  } else if (lr_policy == "step") {
+    this->current_step_ = this->iter_ / this->param_.stepsize();
+    rate = this->param_.base_lr() *
+        pow(this->param_.gamma(), this->current_step_);
+  } else if (lr_policy == "exp") {
+    rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
+  } else if (lr_policy == "inv") {
+    rate = this->param_.base_lr() *
+        pow(Dtype(1) + this->param_.gamma() * this->iter_,
+            - this->param_.power());
+  } else if (lr_policy == "multistep") {
+    if (this->current_step_ < this->param_.stepvalue_size() &&
+          this->iter_ >= this->param_.stepvalue(this->current_step_)) {
+      this->current_step_++;
+      LOG(INFO) << "MultiStep Status: Iteration " <<
+      this->iter_ << ", step = " << this->current_step_;
+    }
+    rate = this->param_.base_lr() *
+        pow(this->param_.gamma(), this->current_step_);
+  } else if (lr_policy == "poly") {
+    rate = this->param_.base_lr() * pow(Dtype(1.) -
+        (Dtype(this->iter_) / Dtype(this->param_.max_iter())),
+        this->param_.power());
+  } else if (lr_policy == "sigmoid") {
+    rate = this->param_.base_lr() * (Dtype(1.) /
+        (Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
+          Dtype(this->param_.stepsize())))));
+  } else {
+    LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
+  }
+  return rate;
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::PreSolve() {
+  // Initialize the history
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  history_.clear();
+  update_.clear();
+  temp_.clear();
+  for (int i = 0; i < net_params.size(); ++i) {
+    const vector<int>& shape = net_params[i]->shape();
+    history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
+    update_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
+    temp_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
+  }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::ClipGradients() {
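+  // If the global L2 norm of all parameter gradients exceeds clip_gradients,
+  // scale every gradient by clip_gradients / l2norm_diff so the norm equals
+  // the threshold; a negative clip_gradients disables clipping.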
+  const Dtype clip_gradients = this->param_.clip_gradients();
+  if (clip_gradients < 0) { return; }
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  Dtype sumsq_diff = 0;
+  for (int i = 0; i < net_params.size(); ++i) {
+    sumsq_diff += net_params[i]->sumsq_diff();
+  }
+  const Dtype l2norm_diff = std::sqrt(sumsq_diff);
+  if (l2norm_diff > clip_gradients) {
+    Dtype scale_factor = clip_gradients / l2norm_diff;
+    LOG(INFO) << "Gradient clipping: scaling down gradients (L2 norm "
+        << l2norm_diff << " > " << clip_gradients << ") "
+        << "by scale factor " << scale_factor;
+    for (int i = 0; i < net_params.size(); ++i) {
+      net_params[i]->scale_diff(scale_factor);
+    }
+  }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::ApplyUpdate() {
+  CHECK(Caffe::root_solver());
+  Dtype rate = GetLearningRate();
+  if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
+    LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
+  }
+  ClipGradients();
+  for (int param_id = 0; param_id < this->net_->learnable_params().size();
+       ++param_id) {
+    Normalize(param_id);
+    Regularize(param_id);
+    ComputeUpdateValue(param_id, rate);
+  }
+  this->net_->Update();
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::Normalize(int param_id) {
+  if (this->param_.iter_size() == 1) { return; }
+  // Scale gradient to counterbalance accumulation.
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size();
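+  // e.g. with iter_size == 4, gradients accumulated over 4 forward/backward
+  // passes are scaled by 0.25, matching the update of a single 4x-larger
+  // batch.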
+  switch (Caffe::mode()) {
+  case Caffe::CPU: {
+    caffe_scal(net_params[param_id]->count(), accum_normalization,
+        net_params[param_id]->mutable_cpu_diff());
+    break;
+  }
+  case Caffe::GPU: {
+#ifndef CPU_ONLY
+    caffe_gpu_scal(net_params[param_id]->count(), accum_normalization,
+        net_params[param_id]->mutable_gpu_diff());
+#else
+    NO_GPU;
+#endif
+    break;
+  }
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::Regularize(int param_id) {
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  const vector<float>& net_params_weight_decay =
+      this->net_->params_weight_decay();
+  Dtype weight_decay = this->param_.weight_decay();
+  string regularization_type = this->param_.regularization_type();
+  Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
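+  // L2 regularization adds local_decay * w to the gradient (weight decay);
+  // L1 regularization adds local_decay * sign(w).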
+  switch (Caffe::mode()) {
+  case Caffe::CPU: {
+    if (local_decay) {
+      if (regularization_type == "L2") {
+        // add weight decay
+        caffe_axpy(net_params[param_id]->count(),
+            local_decay,
+            net_params[param_id]->cpu_data(),
+            net_params[param_id]->mutable_cpu_diff());
+      } else if (regularization_type == "L1") {
+        caffe_cpu_sign(net_params[param_id]->count(),
+            net_params[param_id]->cpu_data(),
+            temp_[param_id]->mutable_cpu_data());
+        caffe_axpy(net_params[param_id]->count(),
+            local_decay,
+            temp_[param_id]->cpu_data(),
+            net_params[param_id]->mutable_cpu_diff());
+      } else {
+        LOG(FATAL) << "Unknown regularization type: " << regularization_type;
+      }
+    }
+    break;
+  }
+  case Caffe::GPU: {
+#ifndef CPU_ONLY
+    if (local_decay) {
+      if (regularization_type == "L2") {
+        // add weight decay
+        caffe_gpu_axpy(net_params[param_id]->count(),
+            local_decay,
+            net_params[param_id]->gpu_data(),
+            net_params[param_id]->mutable_gpu_diff());
+      } else if (regularization_type == "L1") {
+        caffe_gpu_sign(net_params[param_id]->count(),
+            net_params[param_id]->gpu_data(),
+            temp_[param_id]->mutable_gpu_data());
+        caffe_gpu_axpy(net_params[param_id]->count(),
+            local_decay,
+            temp_[param_id]->gpu_data(),
+            net_params[param_id]->mutable_gpu_diff());
+      } else {
+        LOG(FATAL) << "Unknown regularization type: " << regularization_type;
+      }
+    }
+#else
+    NO_GPU;
+#endif
+    break;
+  }
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  const vector<float>& net_params_lr = this->net_->params_lr();
+  Dtype momentum = this->param_.momentum();
+  Dtype local_rate = rate * net_params_lr[param_id];
+  // Compute the update to history, then copy it to the parameter diff.
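+  // history <- momentum * history + local_rate * diff; the result is copied
+  // into the parameter diff, which Net::Update() subtracts from the weights.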
+  switch (Caffe::mode()) {
+  case Caffe::CPU: {
+    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
+        net_params[param_id]->cpu_diff(), momentum,
+        history_[param_id]->mutable_cpu_data());
+    caffe_copy(net_params[param_id]->count(),
+        history_[param_id]->cpu_data(),
+        net_params[param_id]->mutable_cpu_diff());
+    break;
+  }
+  case Caffe::GPU: {
+#ifndef CPU_ONLY
+    caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
+        net_params[param_id]->gpu_diff(), momentum,
+        history_[param_id]->mutable_gpu_data());
+    caffe_copy(net_params[param_id]->count(),
+        history_[param_id]->gpu_data(),
+        net_params[param_id]->mutable_gpu_diff());
+#else
+    NO_GPU;
+#endif
+    break;
+  }
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::SnapshotSolverState(const string& model_filename) {
+  switch (this->param_.snapshot_format()) {
+    case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
+      SnapshotSolverStateToBinaryProto(model_filename);
+      break;
+    case caffe::SolverParameter_SnapshotFormat_HDF5:
+      SnapshotSolverStateToHDF5(model_filename);
+      break;
+    default:
+      LOG(FATAL) << "Unsupported snapshot format.";
+  }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::SnapshotSolverStateToBinaryProto(
+    const string& model_filename) {
+  SolverState state;
+  state.set_iter(this->iter_);
+  state.set_learned_net(model_filename);
+  state.set_current_step(this->current_step_);
+  state.clear_history();
+  for (int i = 0; i < history_.size(); ++i) {
+    // Add history
+    BlobProto* history_blob = state.add_history();
+    history_[i]->ToProto(history_blob);
+  }
+  string snapshot_filename = Solver<Dtype>::SnapshotFilename(".solverstate");
+  LOG(INFO)
+    << "Snapshotting solver state to binary proto file " << snapshot_filename;
+  WriteProtoToBinaryFile(state, snapshot_filename.c_str());
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::SnapshotSolverStateToHDF5(
+    const string& model_filename) {
+  string snapshot_filename =
+      Solver<Dtype>::SnapshotFilename(".solverstate.h5");
+  LOG(INFO) << "Snapshotting solver state to HDF5 file " << snapshot_filename;
+  hid_t file_hid = H5Fcreate(snapshot_filename.c_str(), H5F_ACC_TRUNC,
+      H5P_DEFAULT, H5P_DEFAULT);
+  CHECK_GE(file_hid, 0)
+      << "Couldn't open " << snapshot_filename << " to save solver state.";
+  hdf5_save_int(file_hid, "iter", this->iter_);
+  hdf5_save_string(file_hid, "learned_net", model_filename);
+  hdf5_save_int(file_hid, "current_step", this->current_step_);
+  hid_t history_hid = H5Gcreate2(file_hid, "history", H5P_DEFAULT, H5P_DEFAULT,
+      H5P_DEFAULT);
+  CHECK_GE(history_hid, 0)
+      << "Error saving solver state to " << snapshot_filename << ".";
+  for (int i = 0; i < history_.size(); ++i) {
+    ostringstream oss;
+    oss << i;
+    hdf5_save_nd_dataset<Dtype>(history_hid, oss.str(), *history_[i]);
+  }
+  H5Gclose(history_hid);
+  H5Fclose(file_hid);
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
+    const string& state_file) {
+  SolverState state;
+  ReadProtoFromBinaryFile(state_file, &state);
+  this->iter_ = state.iter();
+  if (state.has_learned_net()) {
+    NetParameter net_param;
+    ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
+    this->net_->CopyTrainedLayersFrom(net_param);
+  }
+  this->current_step_ = state.current_step();
+  CHECK_EQ(state.history_size(), history_.size())
+      << "Incorrect length of history blobs.";
+  LOG(INFO) << "SGDSolver: restoring history";
+  for (int i = 0; i < history_.size(); ++i) {
+    history_[i]->FromProto(state.history(i));
+  }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
+  hid_t file_hid = H5Fopen(state_file.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT);
+  CHECK_GE(file_hid, 0) << "Couldn't open solver state file " << state_file;
+  this->iter_ = hdf5_load_int(file_hid, "iter");
+  if (H5LTfind_dataset(file_hid, "learned_net")) {
+    string learned_net = hdf5_load_string(file_hid, "learned_net");
+    this->net_->CopyTrainedLayersFrom(learned_net);
+  }
+  this->current_step_ = hdf5_load_int(file_hid, "current_step");
+  hid_t history_hid = H5Gopen2(file_hid, "history", H5P_DEFAULT);
+  CHECK_GE(history_hid, 0) << "Error reading history from " << state_file;
+  int state_history_size = hdf5_get_num_links(history_hid);
+  CHECK_EQ(state_history_size, history_.size())
+      << "Incorrect length of history blobs.";
+  for (int i = 0; i < history_.size(); ++i) {
+    ostringstream oss;
+    oss << i;
+    hdf5_load_nd_dataset<Dtype>(history_hid, oss.str().c_str(), 0,
+                                kMaxBlobAxes, history_[i].get());
+  }
+  H5Gclose(history_hid);
+  H5Fclose(file_hid);
+}
+
+INSTANTIATE_CLASS(SGDSolver);
+
+}  // namespace caffe
diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp
index 7ad7467..1767ad3 100644 (file)
@@ -10,7 +10,7 @@
 #include "caffe/common.hpp"
 #include "caffe/parallel.hpp"
 #include "caffe/proto/caffe.pb.h"
-#include "caffe/solver.hpp"
+#include "caffe/sgd_solvers.hpp"
 #include "caffe/util/io.hpp"
 
 #include "caffe/test/test_caffe_main.hpp"
diff --git a/src/caffe/test/test_solver.cpp b/src/caffe/test/test_solver.cpp
index ceabc9c..b181642 100644 (file)
@@ -7,6 +7,7 @@
 
 #include "caffe/common.hpp"
 #include "caffe/proto/caffe.pb.h"
+#include "caffe/sgd_solvers.hpp"
 #include "caffe/solver.hpp"
 
 #include "caffe/test/test_caffe_main.hpp"