Solver switching support & implementation of Nesterov's accelerated gradient and...
authorqipeng <pengrobertqi@163.com>
Sat, 19 Jul 2014 21:14:21 +0000 (14:14 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Mon, 1 Sep 2014 18:33:41 +0000 (11:33 -0700)
include/caffe/solver.hpp
src/caffe/solver.cpp
tools/train_net.cpp

index 9012c5d..87c6563 100644 (file)
@@ -66,17 +66,62 @@ class SGDSolver : public Solver<Dtype> {
       : Solver<Dtype>(param_file) {}
 
  protected:
-  virtual void PreSolve();
+  void PreSolve();
   Dtype GetLearningRate();
   virtual void ComputeUpdateValue();
-  virtual void SnapshotSolverState(SolverState * state);
-  virtual void RestoreSolverState(const SolverState& state);
+  void SnapshotSolverState(SolverState * state);
+  void RestoreSolverState(const SolverState& state);
   // history maintains the historical momentum data.
-  vector<shared_ptr<Blob<Dtype> > > history_;
+  // update maintains update related data and is not needed in snapshots.
+  vector<shared_ptr<Blob<Dtype> > > history_, update_;
 
   DISABLE_COPY_AND_ASSIGN(SGDSolver);
 };
 
+template <typename Dtype>
+class NesterovSolver : public SGDSolver<Dtype> {
+ public:
+  explicit NesterovSolver(const SolverParameter& param)
+      : SGDSolver<Dtype>(param) {}
+  explicit NesterovSolver(const string& param_file)
+      : SGDSolver<Dtype>(param_file) {}
+
+ protected:
+  virtual void ComputeUpdateValue();
+
+  DISABLE_COPY_AND_ASSIGN(NesterovSolver);
+};
+
+template <typename Dtype>
+class AdaGradSolver : public SGDSolver<Dtype> {
+ public:
+  explicit AdaGradSolver(const SolverParameter& param)
+      : SGDSolver<Dtype>(param) {}
+  explicit AdaGradSolver(const string& param_file)
+      : SGDSolver<Dtype>(param_file) {}
+
+ protected:
+  virtual void ComputeUpdateValue();
+
+  DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
+};
+
+template <typename Dtype>
+Solver<Dtype>* GetSolver(const SolverParameter& param) {
+  SolverParameter_SolverType type = param.solver_type();
+
+  switch (type) {
+  case SolverParameter_SolverType_SGD:
+      return new SGDSolver<Dtype>(param);
+  case SolverParameter_SolverType_NESTEROV:
+      return new NesterovSolver<Dtype>(param);
+  case SolverParameter_SolverType_ADAGRAD:
+      return new AdaGradSolver<Dtype>(param);
+  default:
+      LOG(FATAL) << "Unknown SolverType: " << type;
+  }
+}
+
 
 }  // namespace caffe
 
index 80582b3..5632f24 100644 (file)
@@ -377,11 +377,15 @@ void SGDSolver<Dtype>::PreSolve() {
   // Initialize the history
   vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
   history_.clear();
+  update_.clear();
   for (int i = 0; i < net_params.size(); ++i) {
     const Blob<Dtype>* net_param = net_params[i].get();
     history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
         net_param->num(), net_param->channels(), net_param->height(),
         net_param->width())));
+    update_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
+        net_param->num(), net_param->channels(), net_param->height(),
+        net_param->width())));
   }
 }
 
@@ -470,7 +474,192 @@ void SGDSolver<Dtype>::RestoreSolverState(const SolverState& state) {
   }
 }
 
+template <typename Dtype>
+void NesterovSolver<Dtype>::ComputeUpdateValue() {
+  vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+  vector<float>& net_params_lr = this->net_->params_lr();
+  vector<float>& net_params_weight_decay = this->net_->params_weight_decay();
+  // get the learning rate
+  Dtype rate = this->GetLearningRate();
+  if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
+    LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
+  }
+  Dtype momentum = this->param_.momentum();
+  Dtype weight_decay = this->param_.weight_decay();
+  switch (Caffe::mode()) {
+  case Caffe::CPU:
+    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
+      // save history momentum for stepping back
+      caffe_copy(net_params[param_id]->count(),
+          this->history_[param_id]->cpu_data(),
+          this->update_[param_id]->mutable_cpu_data());
+
+      Dtype local_rate = rate * net_params_lr[param_id];
+      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
+      caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
+          net_params[param_id]->cpu_diff(), momentum,
+          this->history_[param_id]->mutable_cpu_data());
+      if (local_decay) {
+        // add weight decay
+        caffe_axpy(net_params[param_id]->count(),
+            local_decay * local_rate,
+            net_params[param_id]->cpu_data(),
+            this->history_[param_id]->mutable_cpu_data());
+      }
+      // compute udpate: step back then over step
+      caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
+          this->history_[param_id]->cpu_data(), -momentum,
+          this->update_[param_id]->mutable_cpu_data());
+
+      // copy
+      caffe_copy(net_params[param_id]->count(),
+          this->update_[param_id]->cpu_data(),
+          net_params[param_id]->mutable_cpu_diff());
+    }
+    break;
+  case Caffe::GPU:
+#ifndef CPU_ONLY
+    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
+      // save history momentum for stepping back
+      caffe_copy(net_params[param_id]->count(),
+          this->history_[param_id]->gpu_data(),
+          this->update_[param_id]->mutable_gpu_data());
+
+      Dtype local_rate = rate * net_params_lr[param_id];
+      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
+      caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
+          net_params[param_id]->gpu_diff(), momentum,
+          this->history_[param_id]->mutable_gpu_data());
+      if (local_decay) {
+        // add weight decay
+        caffe_gpu_axpy(net_params[param_id]->count(),
+            local_decay * local_rate,
+            net_params[param_id]->gpu_data(),
+            this->history_[param_id]->mutable_gpu_data());
+      }
+      // compute udpate: step back then over step
+      caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
+          this->history_[param_id]->gpu_data(), -momentum,
+          this->update_[param_id]->mutable_gpu_data());
+
+      // copy
+      caffe_copy(net_params[param_id]->count(),
+          this->update_[param_id]->gpu_data(),
+          net_params[param_id]->mutable_gpu_diff());
+    }
+#else
+    NO_GPU;
+#endif
+    break;
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  }
+}
+
+template <typename Dtype>
+void AdaGradSolver<Dtype>::ComputeUpdateValue() {
+  vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+  vector<float>& net_params_lr = this->net_->params_lr();
+  vector<float>& net_params_weight_decay = this->net_->params_weight_decay();
+  // get the learning rate
+  Dtype rate = this->GetLearningRate();
+  if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
+    LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
+  }
+  Dtype weight_decay = this->param_.weight_decay();
+  switch (Caffe::mode()) {
+  case Caffe::CPU:
+    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
+      Dtype local_rate = rate * net_params_lr[param_id];
+      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
+
+      if (local_decay) {
+        // add weight decay
+        caffe_axpy(net_params[param_id]->count(),
+            local_decay * local_rate,
+            net_params[param_id]->cpu_data(),
+            this->history_[param_id]->mutable_cpu_data());
+      }
+
+      // compute square of gradient in update
+      caffe_powx(net_params[param_id]->count(),
+          net_params[param_id]->cpu_data(), Dtype(2),
+          this->update_[param_id]->mutable_cpu_data());
+
+      // update history
+      caffe_add(net_params[param_id]->count(),
+          this->update_[param_id]->cpu_data(),
+          this->history_[param_id]->cpu_data(),
+          this->history_[param_id]->mutable_cpu_data());
+
+      // prepare update
+      caffe_powx(net_params[param_id]->count(),
+                this->history_[param_id]->cpu_data(), Dtype(-0.5),
+                this->update_[param_id]->mutable_cpu_data());
+
+      caffe_mul(net_params[param_id]->count(),
+                net_params[param_id]->cpu_diff(),
+                this->update_[param_id]->cpu_data(),
+                this->update_[param_id]->mutable_cpu_data());
+
+      // scale and copy
+      caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
+          this->update_[param_id]->cpu_data(), Dtype(0),
+          net_params[param_id]->mutable_cpu_diff());
+    }
+    break;
+  case Caffe::GPU:
+#ifndef CPU_ONLY
+    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
+      Dtype local_rate = rate * net_params_lr[param_id];
+      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
+
+      if (local_decay) {
+        // add weight decay
+        caffe_gpu_axpy(net_params[param_id]->count(),
+            local_decay * local_rate,
+            net_params[param_id]->gpu_data(),
+            this->history_[param_id]->mutable_gpu_data());
+      }
+
+      // compute square of gradient in update
+      caffe_gpu_powx(net_params[param_id]->count(),
+          net_params[param_id]->gpu_data(), Dtype(2),
+          this->update_[param_id]->mutable_gpu_data());
+
+      // update history
+      caffe_gpu_add(net_params[param_id]->count(),
+          this->update_[param_id]->gpu_data(),
+          this->history_[param_id]->gpu_data(),
+          this->history_[param_id]->mutable_gpu_data());
+
+      // prepare update
+      caffe_gpu_powx(net_params[param_id]->count(),
+                this->history_[param_id]->gpu_data(), Dtype(-0.5),
+                this->update_[param_id]->mutable_gpu_data());
+
+      caffe_gpu_mul(net_params[param_id]->count(),
+                net_params[param_id]->gpu_diff(),
+                this->update_[param_id]->gpu_data(),
+                this->update_[param_id]->mutable_gpu_data());
+
+      // scale and copy
+      caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
+          this->update_[param_id]->gpu_data(), Dtype(0),
+          net_params[param_id]->mutable_gpu_diff());
+    }
+#else
+    NO_GPU;
+#endif
+    break;
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  }
+}
+
 INSTANTIATE_CLASS(Solver);
 INSTANTIATE_CLASS(SGDSolver);
+INSTANTIATE_CLASS(NesterovSolver);
+INSTANTIATE_CLASS(AdaGradSolver);
 
 }  // namespace caffe
index 622bca3..1176759 100644 (file)
@@ -1,6 +1,9 @@
 #include "caffe/caffe.hpp"
 
+using namespace caffe;  // NOLINT(build/namespaces)
+
 int main(int argc, char** argv) {
+
   LOG(FATAL) << "Deprecated. Use caffe train --solver=... "
                 "[--snapshot=...] instead.";
   return 0;