improved numerical stability for AdaGrad
author qipeng <pengrobertqi@163.com>
Wed, 23 Jul 2014 04:17:19 +0000 (21:17 -0700)
committer Jeff Donahue <jeff.donahue@gmail.com>
Mon, 1 Sep 2014 18:33:41 +0000 (11:33 -0700)
src/caffe/solver.cpp

index 5632f24..abcbe5e 100644 (file)
@@ -481,6 +481,7 @@ void NesterovSolver<Dtype>::ComputeUpdateValue() {
   vector<float>& net_params_weight_decay = this->net_->params_weight_decay();
   // get the learning rate
   Dtype rate = this->GetLearningRate();
+  Dtype delta = this->param_.delta();
   if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
     LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
   }
@@ -594,10 +595,13 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue() {
 
       // prepare update
       caffe_powx(net_params[param_id]->count(),
-                this->history_[param_id]->cpu_data(), Dtype(-0.5),
+                this->history_[param_id]->cpu_data(), Dtype(0.5),
                 this->update_[param_id]->mutable_cpu_data());
 
-      caffe_mul(net_params[param_id]->count(),
+      caffe_add_scalar(net_params[param_id]->count(),
+                delta, this->update_[param_id]->mutable_cpu_data());
+
+      caffe_div(net_params[param_id]->count(),
                 net_params[param_id]->cpu_diff(),
                 this->update_[param_id]->cpu_data(),
                 this->update_[param_id]->mutable_cpu_data());
@@ -635,10 +639,13 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue() {
 
       // prepare update
       caffe_gpu_powx(net_params[param_id]->count(),
-                this->history_[param_id]->gpu_data(), Dtype(-0.5),
+                this->history_[param_id]->gpu_data(), Dtype(0.5),
                 this->update_[param_id]->mutable_gpu_data());
 
-      caffe_gpu_mul(net_params[param_id]->count(),
+      caffe_gpu_add_scalar(net_params[param_id]->count(),
+                delta, this->update_[param_id]->mutable_gpu_data());
+
+      caffe_gpu_div(net_params[param_id]->count(),
                 net_params[param_id]->gpu_diff(),
                 this->update_[param_id]->gpu_data(),
                 this->update_[param_id]->mutable_gpu_data());