Refactor solvers' regularization and logging code
author    Cyprien Noel <cyprien.noel@gmail.com>
          Tue, 19 May 2015 01:30:00 +0000 (18:30 -0700)
committer Evan Shelhamer <shelhamer@imaginarynumber.net>
          Wed, 27 May 2015 03:07:00 +0000 (20:07 -0700)
include/caffe/solver.hpp
src/caffe/solver.cpp
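The diff below factors the shared regularization and learning-rate logging out of each
solver's monolithic ComputeUpdateValue() into SGDSolver::MakeUpdate(), which loops over
the net's parameters, calls Regularize(param_id) and the per-parameter
ComputeUpdateValue(param_id, rate), and finally applies the result with net_->Update().
As a hedged, illustration-only sketch (not part of this patch, gradient clipping
omitted), the scalar order of operations for one parameter on the plain SGD path with
L2 weight decay is:

    #include <cstdio>

    // Illustration only: one parameter w, its gradient g (the blob diff), and the
    // solver history h, stepped in the same order as the refactored code.
    int main() {
      double w = 1.0;                       // blob data
      double g = 0.2;                       // blob diff from backprop
      double h = 0.0;                       // momentum history (history_)
      const double local_rate = 0.01, local_decay = 0.0005, momentum = 0.9;

      g += local_decay * w;                 // Regularize(): L2 decay added to the diff
      h = momentum * h + local_rate * g;    // ComputeUpdateValue(): update history
      g = h;                                // copy history back into the diff
      w -= g;                               // net_->Update(): data -= diff
      std::printf("w = %f\n", w);
      return 0;
    }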

include/caffe/solver.hpp
index 4dcdc3d..c920679 100644
@@ -39,8 +39,8 @@ class Solver {
   int iter() { return iter_; }
 
  protected:
-  // Get the update value for the current iteration.
-  virtual void ComputeUpdateValue() = 0;
+  // Get and apply the update value for the current iteration.
+  virtual void MakeUpdate() = 0;
   // The Solver::Snapshot function implements the basic snapshotting utility
   // that stores the learned net. You should implement the SnapshotSolverState()
   // function that produces a SolverState protocol buffer that needs to be
@@ -80,7 +80,9 @@ class SGDSolver : public Solver<Dtype> {
  protected:
   void PreSolve();
   Dtype GetLearningRate();
-  virtual void ComputeUpdateValue();
+  virtual void MakeUpdate();
+  virtual void Regularize(int param_id);
+  virtual void ComputeUpdateValue(int param_id, Dtype rate);
   virtual void ClipGradients();
   virtual void SnapshotSolverState(SolverState * state);
   virtual void RestoreSolverState(const SolverState& state);
@@ -102,7 +104,7 @@ class NesterovSolver : public SGDSolver<Dtype> {
       : SGDSolver<Dtype>(param_file) {}
 
  protected:
-  virtual void ComputeUpdateValue();
+  virtual void ComputeUpdateValue(int param_id, Dtype rate);
 
   DISABLE_COPY_AND_ASSIGN(NesterovSolver);
 };
@@ -116,7 +118,7 @@ class AdaGradSolver : public SGDSolver<Dtype> {
       : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
 
  protected:
-  virtual void ComputeUpdateValue();
+  virtual void ComputeUpdateValue(int param_id, Dtype rate);
   void constructor_sanity_check() {
     CHECK_EQ(0, this->param_.momentum())
         << "Momentum cannot be used with AdaGrad.";
src/caffe/solver.cpp
index 877b19b..88f6d31 100644
@@ -207,8 +207,7 @@ void Solver<Dtype>::Step(int iters) {
         }
       }
     }
-    ComputeUpdateValue();
-    net_->Update();
+    MakeUpdate();
 
     // Increment the internal iter_ counter -- its value should always indicate
     // the number of times the weights have been updated.
@@ -456,95 +455,118 @@ void SGDSolver<Dtype>::ClipGradients() {
 }
 
 template <typename Dtype>
-void SGDSolver<Dtype>::ComputeUpdateValue() {
-  const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
-  const vector<float>& net_params_lr = this->net_->params_lr();
-  const vector<float>& net_params_weight_decay =
-      this->net_->params_weight_decay();
-  // get the learning rate
+void SGDSolver<Dtype>::MakeUpdate() {
   Dtype rate = GetLearningRate();
   if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
     LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
   }
   ClipGradients();
-  Dtype momentum = this->param_.momentum();
+  for (int param_id = 0; param_id < this->net_->params().size(); ++param_id) {
+    Regularize(param_id);
+    ComputeUpdateValue(param_id, rate);
+  }
+  this->net_->Update();
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::Regularize(int param_id) {
+  const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+  const vector<float>& net_params_weight_decay =
+      this->net_->params_weight_decay();
   Dtype weight_decay = this->param_.weight_decay();
   string regularization_type = this->param_.regularization_type();
   switch (Caffe::mode()) {
-  case Caffe::CPU:
-    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
-      // Compute the value to history, and then copy them to the blob's diff.
-      Dtype local_rate = rate * net_params_lr[param_id];
-      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
-
-      if (local_decay) {
-        if (regularization_type == "L2") {
-          // add weight decay
-          caffe_axpy(net_params[param_id]->count(),
-              local_decay,
-              net_params[param_id]->cpu_data(),
-              net_params[param_id]->mutable_cpu_diff());
-        } else if (regularization_type == "L1") {
-          caffe_cpu_sign(net_params[param_id]->count(),
-              net_params[param_id]->cpu_data(),
-              temp_[param_id]->mutable_cpu_data());
-          caffe_axpy(net_params[param_id]->count(),
-              local_decay,
-              temp_[param_id]->cpu_data(),
-              net_params[param_id]->mutable_cpu_diff());
-        } else {
-          LOG(FATAL) << "Unknown regularization type: " << regularization_type;
-        }
+  case Caffe::CPU: {
+    Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
+    if (local_decay) {
+      if (regularization_type == "L2") {
+        // add weight decay
+        caffe_axpy(net_params[param_id]->count(),
+            local_decay,
+            net_params[param_id]->cpu_data(),
+            net_params[param_id]->mutable_cpu_diff());
+      } else if (regularization_type == "L1") {
+        caffe_cpu_sign(net_params[param_id]->count(),
+            net_params[param_id]->cpu_data(),
+            temp_[param_id]->mutable_cpu_data());
+        caffe_axpy(net_params[param_id]->count(),
+            local_decay,
+            temp_[param_id]->cpu_data(),
+            net_params[param_id]->mutable_cpu_diff());
+      } else {
+        LOG(FATAL) << "Unknown regularization type: " << regularization_type;
       }
-
-      caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
-                net_params[param_id]->cpu_diff(), momentum,
-                history_[param_id]->mutable_cpu_data());
-      // copy
-      caffe_copy(net_params[param_id]->count(),
-          history_[param_id]->cpu_data(),
-          net_params[param_id]->mutable_cpu_diff());
     }
     break;
-  case Caffe::GPU:
+  }
+  case Caffe::GPU: {
 #ifndef CPU_ONLY
-    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
-      // Compute the value to history, and then copy them to the blob's diff.
-      Dtype local_rate = rate * net_params_lr[param_id];
-      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
-
-      if (local_decay) {
-        if (regularization_type == "L2") {
-          // add weight decay
-          caffe_gpu_axpy(net_params[param_id]->count(),
-              local_decay,
-              net_params[param_id]->gpu_data(),
-              net_params[param_id]->mutable_gpu_diff());
-        } else if (regularization_type == "L1") {
-          caffe_gpu_sign(net_params[param_id]->count(),
-              net_params[param_id]->gpu_data(),
-              temp_[param_id]->mutable_gpu_data());
-          caffe_gpu_axpy(net_params[param_id]->count(),
-              local_decay,
-              temp_[param_id]->gpu_data(),
-              net_params[param_id]->mutable_gpu_diff());
-        } else {
-          LOG(FATAL) << "Unknown regularization type: " << regularization_type;
-        }
+    Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
+    if (local_decay) {
+      if (regularization_type == "L2") {
+        // add weight decay
+        caffe_gpu_axpy(net_params[param_id]->count(),
+            local_decay,
+            net_params[param_id]->gpu_data(),
+            net_params[param_id]->mutable_gpu_diff());
+      } else if (regularization_type == "L1") {
+        caffe_gpu_sign(net_params[param_id]->count(),
+            net_params[param_id]->gpu_data(),
+            temp_[param_id]->mutable_gpu_data());
+        caffe_gpu_axpy(net_params[param_id]->count(),
+            local_decay,
+            temp_[param_id]->gpu_data(),
+            net_params[param_id]->mutable_gpu_diff());
+      } else {
+        LOG(FATAL) << "Unknown regularization type: " << regularization_type;
       }
-
-      caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
-                net_params[param_id]->gpu_diff(), momentum,
-                history_[param_id]->mutable_gpu_data());
-      // copy
-      caffe_copy(net_params[param_id]->count(),
-          history_[param_id]->gpu_data(),
-          net_params[param_id]->mutable_gpu_diff());
     }
 #else
     NO_GPU;
 #endif
     break;
+  }
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+  const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+  const vector<float>& net_params_lr = this->net_->params_lr();
+  Dtype momentum = this->param_.momentum();
+  switch (Caffe::mode()) {
+  case Caffe::CPU: {
+    // Compute the update value into the history, then copy it to the blob's diff.
+    Dtype local_rate = rate * net_params_lr[param_id];
+
+    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
+              net_params[param_id]->cpu_diff(), momentum,
+              history_[param_id]->mutable_cpu_data());
+    // copy
+    caffe_copy(net_params[param_id]->count(),
+        history_[param_id]->cpu_data(),
+        net_params[param_id]->mutable_cpu_diff());
+    break;
+  }
+  case Caffe::GPU: {
+#ifndef CPU_ONLY
+    // Compute the update value into the history, then copy it to the blob's diff.
+    Dtype local_rate = rate * net_params_lr[param_id];
+
+    caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
+              net_params[param_id]->gpu_diff(), momentum,
+              history_[param_id]->mutable_gpu_data());
+    // copy
+    caffe_copy(net_params[param_id]->count(),
+        history_[param_id]->gpu_data(),
+        net_params[param_id]->mutable_gpu_diff());
+#else
+    NO_GPU;
+#endif
+    break;
+  }
   default:
     LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
   }
@@ -571,252 +593,144 @@ void SGDSolver<Dtype>::RestoreSolverState(const SolverState& state) {
 }
 
 template <typename Dtype>
-void NesterovSolver<Dtype>::ComputeUpdateValue() {
+void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
   const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
   const vector<float>& net_params_lr = this->net_->params_lr();
-  const vector<float>& net_params_weight_decay =
-      this->net_->params_weight_decay();
-  // get the learning rate
-  Dtype rate = this->GetLearningRate();
-  if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
-    LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
-  }
-  SGDSolver<Dtype>::ClipGradients();
   Dtype momentum = this->param_.momentum();
-  Dtype weight_decay = this->param_.weight_decay();
-  string regularization_type = this->param_.regularization_type();
   switch (Caffe::mode()) {
-  case Caffe::CPU:
-    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
-      // save history momentum for stepping back
-      caffe_copy(net_params[param_id]->count(),
-          this->history_[param_id]->cpu_data(),
-          this->update_[param_id]->mutable_cpu_data());
-
-      Dtype local_rate = rate * net_params_lr[param_id];
-      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
-
-      if (local_decay) {
-        if (regularization_type == "L2") {
-          // add weight decay
-          caffe_axpy(net_params[param_id]->count(),
-              local_decay,
-              net_params[param_id]->cpu_data(),
-              net_params[param_id]->mutable_cpu_diff());
-        } else if (regularization_type == "L1") {
-          caffe_cpu_sign(net_params[param_id]->count(),
-              net_params[param_id]->cpu_data(),
-              this->temp_[param_id]->mutable_cpu_data());
-          caffe_axpy(net_params[param_id]->count(),
-              local_decay,
-              this->temp_[param_id]->cpu_data(),
-              net_params[param_id]->mutable_cpu_diff());
-        } else {
-          LOG(FATAL) << "Unknown regularization type: " << regularization_type;
-        }
-      }
-
-      // update history
-      caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
-                net_params[param_id]->cpu_diff(), momentum,
-                this->history_[param_id]->mutable_cpu_data());
-
-      // compute udpate: step back then over step
-      caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
-          this->history_[param_id]->cpu_data(), -momentum,
-          this->update_[param_id]->mutable_cpu_data());
-
-      // copy
-      caffe_copy(net_params[param_id]->count(),
-          this->update_[param_id]->cpu_data(),
-          net_params[param_id]->mutable_cpu_diff());
-    }
+  case Caffe::CPU: {
+    // save history momentum for stepping back
+    caffe_copy(net_params[param_id]->count(),
+        this->history_[param_id]->cpu_data(),
+        this->update_[param_id]->mutable_cpu_data());
+
+    Dtype local_rate = rate * net_params_lr[param_id];
+
+    // update history
+    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
+              net_params[param_id]->cpu_diff(), momentum,
+              this->history_[param_id]->mutable_cpu_data());
+
+    // compute update: step back then over step
+    caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
+        this->history_[param_id]->cpu_data(), -momentum,
+        this->update_[param_id]->mutable_cpu_data());
+
+    // copy
+    caffe_copy(net_params[param_id]->count(),
+        this->update_[param_id]->cpu_data(),
+        net_params[param_id]->mutable_cpu_diff());
     break;
-  case Caffe::GPU:
+  }
+  case Caffe::GPU: {
 #ifndef CPU_ONLY
-    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
-      // save history momentum for stepping back
-      caffe_copy(net_params[param_id]->count(),
-          this->history_[param_id]->gpu_data(),
-          this->update_[param_id]->mutable_gpu_data());
-
-      Dtype local_rate = rate * net_params_lr[param_id];
-      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
-
-      if (local_decay) {
-        if (regularization_type == "L2") {
-          // add weight decay
-          caffe_gpu_axpy(net_params[param_id]->count(),
-              local_decay,
-              net_params[param_id]->gpu_data(),
-              net_params[param_id]->mutable_gpu_diff());
-        } else if (regularization_type == "L1") {
-          caffe_gpu_sign(net_params[param_id]->count(),
-              net_params[param_id]->gpu_data(),
-              this->temp_[param_id]->mutable_gpu_data());
-          caffe_gpu_axpy(net_params[param_id]->count(),
-              local_decay,
-              this->temp_[param_id]->gpu_data(),
-              net_params[param_id]->mutable_gpu_diff());
-        } else {
-          LOG(FATAL) << "Unknown regularization type: " << regularization_type;
-        }
-      }
-
-      // update history
-      caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
-                net_params[param_id]->gpu_diff(), momentum,
-                this->history_[param_id]->mutable_gpu_data());
-
-      // compute udpate: step back then over step
-      caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
-          this->history_[param_id]->gpu_data(), -momentum,
-          this->update_[param_id]->mutable_gpu_data());
-
-      // copy
-      caffe_copy(net_params[param_id]->count(),
-          this->update_[param_id]->gpu_data(),
-          net_params[param_id]->mutable_gpu_diff());
-    }
+    // save history momentum for stepping back
+    caffe_copy(net_params[param_id]->count(),
+        this->history_[param_id]->gpu_data(),
+        this->update_[param_id]->mutable_gpu_data());
+
+    Dtype local_rate = rate * net_params_lr[param_id];
+
+    // update history
+    caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
+              net_params[param_id]->gpu_diff(), momentum,
+              this->history_[param_id]->mutable_gpu_data());
+
+    // compute update: step back then over step
+    caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
+        this->history_[param_id]->gpu_data(), -momentum,
+        this->update_[param_id]->mutable_gpu_data());
+
+    // copy
+    caffe_copy(net_params[param_id]->count(),
+        this->update_[param_id]->gpu_data(),
+        net_params[param_id]->mutable_gpu_diff());
 #else
     NO_GPU;
 #endif
     break;
+  }
   default:
     LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
   }
 }
 
 template <typename Dtype>
-void AdaGradSolver<Dtype>::ComputeUpdateValue() {
+void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
   const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
   const vector<float>& net_params_lr = this->net_->params_lr();
-  const vector<float>& net_params_weight_decay =
-      this->net_->params_weight_decay();
-  // get the learning rate
-  Dtype rate = this->GetLearningRate();
   Dtype delta = this->param_.delta();
-  if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
-    LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
-  }
-  SGDSolver<Dtype>::ClipGradients();
-  Dtype weight_decay = this->param_.weight_decay();
-  string regularization_type = this->param_.regularization_type();
   switch (Caffe::mode()) {
-  case Caffe::CPU:
-    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
-      Dtype local_rate = rate * net_params_lr[param_id];
-      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
-
-      if (local_decay) {
-        if (regularization_type == "L2") {
-          // add weight decay
-          caffe_axpy(net_params[param_id]->count(),
-              local_decay,
-              net_params[param_id]->cpu_data(),
-              net_params[param_id]->mutable_cpu_diff());
-        } else if (regularization_type == "L1") {
-          caffe_cpu_sign(net_params[param_id]->count(),
-              net_params[param_id]->cpu_data(),
-              this->temp_[param_id]->mutable_cpu_data());
-          caffe_axpy(net_params[param_id]->count(),
-              local_decay,
-              this->temp_[param_id]->cpu_data(),
-              net_params[param_id]->mutable_cpu_diff());
-        } else {
-          LOG(FATAL) << "Unknown regularization type: " << regularization_type;
-        }
-      }
-
-      // compute square of gradient in update
-      caffe_powx(net_params[param_id]->count(),
-          net_params[param_id]->cpu_diff(), Dtype(2),
-          this->update_[param_id]->mutable_cpu_data());
-
-      // update history
-      caffe_add(net_params[param_id]->count(),
-          this->update_[param_id]->cpu_data(),
-          this->history_[param_id]->cpu_data(),
-          this->history_[param_id]->mutable_cpu_data());
-
-      // prepare update
-      caffe_powx(net_params[param_id]->count(),
-                this->history_[param_id]->cpu_data(), Dtype(0.5),
-                this->update_[param_id]->mutable_cpu_data());
-
-      caffe_add_scalar(net_params[param_id]->count(),
-                delta, this->update_[param_id]->mutable_cpu_data());
-
-      caffe_div(net_params[param_id]->count(),
-                net_params[param_id]->cpu_diff(),
-                this->update_[param_id]->cpu_data(),
-                this->update_[param_id]->mutable_cpu_data());
-
-      // scale and copy
-      caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
-          this->update_[param_id]->cpu_data(), Dtype(0),
-          net_params[param_id]->mutable_cpu_diff());
-    }
+  case Caffe::CPU: {
+    Dtype local_rate = rate * net_params_lr[param_id];
+
+    // compute square of gradient in update
+    caffe_powx(net_params[param_id]->count(),
+        net_params[param_id]->cpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // update history
+    caffe_add(net_params[param_id]->count(),
+        this->update_[param_id]->cpu_data(),
+        this->history_[param_id]->cpu_data(),
+        this->history_[param_id]->mutable_cpu_data());
+
+    // prepare update
+    caffe_powx(net_params[param_id]->count(),
+              this->history_[param_id]->cpu_data(), Dtype(0.5),
+              this->update_[param_id]->mutable_cpu_data());
+
+    caffe_add_scalar(net_params[param_id]->count(),
+              delta, this->update_[param_id]->mutable_cpu_data());
+
+    caffe_div(net_params[param_id]->count(),
+              net_params[param_id]->cpu_diff(),
+              this->update_[param_id]->cpu_data(),
+              this->update_[param_id]->mutable_cpu_data());
+
+    // scale and copy
+    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
+        this->update_[param_id]->cpu_data(), Dtype(0),
+        net_params[param_id]->mutable_cpu_diff());
     break;
-  case Caffe::GPU:
+  }
+  case Caffe::GPU: {
 #ifndef CPU_ONLY
-    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
-      Dtype local_rate = rate * net_params_lr[param_id];
-      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
-
-      if (local_decay) {
-        if (regularization_type == "L2") {
-          // add weight decay
-          caffe_gpu_axpy(net_params[param_id]->count(),
-              local_decay,
-              net_params[param_id]->gpu_data(),
-              net_params[param_id]->mutable_gpu_diff());
-        } else if (regularization_type == "L1") {
-          caffe_gpu_sign(net_params[param_id]->count(),
-              net_params[param_id]->gpu_data(),
-              this->temp_[param_id]->mutable_gpu_data());
-          caffe_gpu_axpy(net_params[param_id]->count(),
-              local_decay,
-              this->temp_[param_id]->gpu_data(),
-              net_params[param_id]->mutable_gpu_diff());
-        } else {
-          LOG(FATAL) << "Unknown regularization type: " << regularization_type;
-        }
-      }
-
-      // compute square of gradient in update
-      caffe_gpu_powx(net_params[param_id]->count(),
-          net_params[param_id]->gpu_diff(), Dtype(2),
-          this->update_[param_id]->mutable_gpu_data());
-
-      // update history
-      caffe_gpu_add(net_params[param_id]->count(),
-          this->update_[param_id]->gpu_data(),
-          this->history_[param_id]->gpu_data(),
-          this->history_[param_id]->mutable_gpu_data());
-
-      // prepare update
-      caffe_gpu_powx(net_params[param_id]->count(),
-                this->history_[param_id]->gpu_data(), Dtype(0.5),
-                this->update_[param_id]->mutable_gpu_data());
-
-      caffe_gpu_add_scalar(net_params[param_id]->count(),
-                delta, this->update_[param_id]->mutable_gpu_data());
-
-      caffe_gpu_div(net_params[param_id]->count(),
-                net_params[param_id]->gpu_diff(),
-                this->update_[param_id]->gpu_data(),
-                this->update_[param_id]->mutable_gpu_data());
-
-      // scale and copy
-      caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
-          this->update_[param_id]->gpu_data(), Dtype(0),
-          net_params[param_id]->mutable_gpu_diff());
-    }
+    Dtype local_rate = rate * net_params_lr[param_id];
+
+    // compute square of gradient in update
+    caffe_gpu_powx(net_params[param_id]->count(),
+        net_params[param_id]->gpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_gpu_data());
+
+    // update history
+    caffe_gpu_add(net_params[param_id]->count(),
+        this->update_[param_id]->gpu_data(),
+        this->history_[param_id]->gpu_data(),
+        this->history_[param_id]->mutable_gpu_data());
+
+    // prepare update
+    caffe_gpu_powx(net_params[param_id]->count(),
+              this->history_[param_id]->gpu_data(), Dtype(0.5),
+              this->update_[param_id]->mutable_gpu_data());
+
+    caffe_gpu_add_scalar(net_params[param_id]->count(),
+              delta, this->update_[param_id]->mutable_gpu_data());
+
+    caffe_gpu_div(net_params[param_id]->count(),
+              net_params[param_id]->gpu_diff(),
+              this->update_[param_id]->gpu_data(),
+              this->update_[param_id]->mutable_gpu_data());
+
+    // scale and copy
+    caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
+        this->update_[param_id]->gpu_data(), Dtype(0),
+        net_params[param_id]->mutable_gpu_diff());
 #else
     NO_GPU;
 #endif
     break;
+  }
   default:
     LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
   }
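For reference, the two subclass rules shown above reduce, for a single weight, to the
scalar forms below. This is a hedged, illustration-only sketch of the arithmetic that
the refactored ComputeUpdateValue(param_id, rate) overrides perform; it is not part of
the patch.

    #include <cmath>
    #include <cstdio>

    int main() {
      double g = 0.2, h = 0.05;             // blob diff and history for one weight
      const double local_rate = 0.01, momentum = 0.9, delta = 1e-8;

      // NesterovSolver: save the old history, step it, then "step back and over step".
      double h_prev = h;                            // saved into update_ in the code
      h = momentum * h + local_rate * g;            // update history
      double nesterov_diff = (1.0 + momentum) * h - momentum * h_prev;

      // AdaGradSolver: divide by the root of the accumulated squared gradients.
      double hist_sq = 0.04;                        // running sum of g^2 (history_)
      hist_sq += g * g;
      double adagrad_diff = local_rate * g / (std::sqrt(hist_sq) + delta);

      std::printf("nesterov=%f adagrad=%f\n", nesterov_diff, adagrad_diff);
      return 0;
    }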