Implement RMSProp Solver
authorEren Golge <erogol@hotmail.com>
Sun, 9 Aug 2015 06:45:08 +0000 (23:45 -0700)
committerRonghang Hu <huronghang@hotmail.com>
Sun, 9 Aug 2015 06:45:08 +0000 (23:45 -0700)
Implement RMSProp solver and cleaned up to adjust to new solver interface that uses
accumulated gradients and refactored regularization.

examples/mnist/lenet_solver_rmsprop.prototxt [new file with mode: 0644]
examples/mnist/train_lenet_rmsprop.sh [new file with mode: 0755]
include/caffe/solver.hpp
src/caffe/proto/caffe.proto
src/caffe/solver.cpp
src/caffe/test/test_gradient_based_solver.cpp

diff --git a/examples/mnist/lenet_solver_rmsprop.prototxt b/examples/mnist/lenet_solver_rmsprop.prototxt
new file mode 100644 (file)
index 0000000..74dadc5
--- /dev/null
@@ -0,0 +1,27 @@
+# The train/test net protocol buffer definition
+net: "examples/mnist/lenet_train_test.prototxt"
+# test_iter specifies how many forward passes the test should carry out.
+# In the case of MNIST, we have test batch size 100 and 100 test iterations,
+# covering the full 10,000 testing images.
+test_iter: 100
+# Carry out testing every 500 training iterations.
+test_interval: 500
+# The base learning rate, momentum and the weight decay of the network.
+base_lr: 0.01
+momentum: 0.0
+weight_decay: 0.0005
+# The learning rate policy
+lr_policy: "inv"
+gamma: 0.0001
+power: 0.75
+# Display every 100 iterations
+display: 100
+# The maximum number of iterations
+max_iter: 10000
+# snapshot intermediate results
+snapshot: 5000
+snapshot_prefix: "examples/mnist/lenet_rmsprop"
+# solver mode: CPU or GPU
+solver_mode: GPU
+solver_type: RMSPROP
+rms_decay: 0.98
diff --git a/examples/mnist/train_lenet_rmsprop.sh b/examples/mnist/train_lenet_rmsprop.sh
new file mode 100755 (executable)
index 0000000..621cab2
--- /dev/null
@@ -0,0 +1,3 @@
+#!/usr/bin/env sh
+
+./build/tools/caffe train --solver=examples/mnist/lenet_solver_rmsprop.prototxt
index 703434b..fbade93 100644 (file)
@@ -135,6 +135,29 @@ class AdaGradSolver : public SGDSolver<Dtype> {
   DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
 };
 
+
+template <typename Dtype>
+class RMSPropSolver : public SGDSolver<Dtype> {
+ public:
+  explicit RMSPropSolver(const SolverParameter& param)
+      : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
+  explicit RMSPropSolver(const string& param_file)
+      : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
+
+ protected:
+  virtual void ComputeUpdateValue(int param_id, Dtype rate);
+  void constructor_sanity_check() {
+    CHECK_EQ(0, this->param_.momentum())
+        << "Momentum cannot be used with RMSProp.";
+    CHECK_GE(this->param_.rms_decay(), 0)
+        << "rms_decay should lie between 0 and 1.";
+    CHECK_LT(this->param_.rms_decay(), 1)
+        << "rms_decay should lie between 0 and 1.";
+  }
+
+  DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
+};
+
 template <typename Dtype>
 Solver<Dtype>* GetSolver(const SolverParameter& param) {
   SolverParameter_SolverType type = param.solver_type();
@@ -146,6 +169,8 @@ Solver<Dtype>* GetSolver(const SolverParameter& param) {
       return new NesterovSolver<Dtype>(param);
   case SolverParameter_SolverType_ADAGRAD:
       return new AdaGradSolver<Dtype>(param);
+  case SolverParameter_SolverType_RMSPROP:
+      return new RMSPropSolver<Dtype>(param);
   default:
       LOG(FATAL) << "Unknown SolverType: " << type;
   }
index a13c0e7..89f1459 100644 (file)
@@ -98,7 +98,7 @@ message NetParameter {
 // NOTE
 // Update the next available ID when you add a new SolverParameter field.
 //
-// SolverParameter next available ID: 38 (last added: snapshot_format)
+// SolverParameter next available ID: 39 (last added: rms_decay)
 message SolverParameter {
   //////////////////////////////////////////////////////////////////////////////
   // Specifying the train and test networks
@@ -153,7 +153,23 @@ message SolverParameter {
   optional int32 max_iter = 7; // the maximum number of iterations
   // accumulate gradients over `iter_size` x `batch_size` instances
   optional int32 iter_size = 36 [default = 1];
-  optional string lr_policy = 8; // The learning rate decay policy.
+
+  // The learning rate decay policy. The currently implemented learning rate
+  // policies are as follows:
+  //    - fixed: always return base_lr.
+  //    - step: return base_lr * gamma ^ (floor(iter / step))
+  //    - exp: return base_lr * gamma ^ iter
+  //    - inv: return base_lr * (1 + gamma * iter) ^ (- power)
+  //    - multistep: similar to step but it allows non uniform steps defined by
+  //      stepvalue
+  //    - poly: the effective learning rate follows a polynomial decay, to be
+  //      zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
+  //    - sigmoid: the effective learning rate follows a sigmod decay
+  //      return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
+  //
+  // where base_lr, max_iter, gamma, step, stepvalue and power are defined
+  // in the solver parameter protocol buffer, and iter is the current iteration.
+  optional string lr_policy = 8;
   optional float gamma = 9; // The parameter to compute the learning rate.
   optional float power = 10; // The parameter to compute the learning rate.
   optional float momentum = 11; // The momentum value.
@@ -198,11 +214,16 @@ message SolverParameter {
     SGD = 0;
     NESTEROV = 1;
     ADAGRAD = 2;
+    RMSPROP = 3;
   }
   optional SolverType solver_type = 30 [default = SGD];
   // numerical stability for AdaGrad
   optional float delta = 31 [default = 1e-8];
 
+  // RMSProp decay value
+  // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
+  optional float rms_decay = 38;
+
   // If true, print information about the state of the net that may help with
   // debugging learning problems.
   optional bool debug_info = 23 [default = false];
index 32276ac..43834c0 100644 (file)
@@ -859,9 +859,85 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
   }
 }
 
+template <typename Dtype>
+void RMSPropSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+  const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+  const vector<float>& net_params_lr = this->net_->params_lr();
+
+  // get the learning rate
+  Dtype delta = this->param_.delta();
+  Dtype rms_decay = this->param_.rms_decay();
+  Dtype local_rate = rate * net_params_lr[param_id];
+
+  switch (Caffe::mode()) {
+  case Caffe::CPU:
+    // compute square of gradient in update
+    caffe_powx(net_params[param_id]->count(),
+        net_params[param_id]->cpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // update history
+    caffe_cpu_axpby(net_params[param_id] -> count(),
+        Dtype(1-rms_decay), this->update_[param_id]->cpu_data(),
+        rms_decay, this->history_[param_id]-> mutable_cpu_data());
+
+    // prepare update
+    caffe_powx(net_params[param_id]->count(),
+        this->history_[param_id]->cpu_data(), Dtype(0.5),
+        this->update_[param_id]->mutable_cpu_data());
+
+    caffe_add_scalar(net_params[param_id]->count(),
+        delta, this->update_[param_id]->mutable_cpu_data());
+
+    caffe_div(net_params[param_id]->count(),
+        net_params[param_id]->cpu_diff(), this->update_[param_id]->cpu_data(),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // scale and copy
+    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
+        this->update_[param_id]->cpu_data(), Dtype(0),
+        net_params[param_id]->mutable_cpu_diff());
+    break;
+  case Caffe::GPU:
+#ifndef CPU_ONLY
+    // compute square of gradient in update
+    caffe_gpu_powx(net_params[param_id]->count(),
+        net_params[param_id]->gpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_gpu_data());
+
+    // update history
+    caffe_gpu_axpby(net_params[param_id] -> count(),
+        Dtype(1-rms_decay), this->update_[param_id]->gpu_data(),
+        rms_decay, this->history_[param_id]-> mutable_gpu_data());
+
+    // prepare update
+    caffe_gpu_powx(net_params[param_id]->count(),
+        this->history_[param_id]->gpu_data(), Dtype(0.5),
+        this->update_[param_id]->mutable_gpu_data());
+
+    caffe_gpu_add_scalar(net_params[param_id]->count(),
+        delta, this->update_[param_id]->mutable_gpu_data());
+
+    caffe_gpu_div(net_params[param_id]->count(),
+        net_params[param_id]->gpu_diff(), this->update_[param_id]->gpu_data(),
+        this->update_[param_id]->mutable_gpu_data());
+
+    caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
+        this->update_[param_id]->gpu_data(), Dtype(0),
+        net_params[param_id]->mutable_gpu_diff());
+#else
+    NO_GPU;
+#endif
+    break;
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  }
+}
+
 INSTANTIATE_CLASS(Solver);
 INSTANTIATE_CLASS(SGDSolver);
 INSTANTIATE_CLASS(NesterovSolver);
 INSTANTIATE_CLASS(AdaGradSolver);
+INSTANTIATE_CLASS(RMSPropSolver);
 
 }  // namespace caffe
index 7bb0ec1..b091892 100644 (file)
@@ -52,13 +52,14 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
         LOG(FATAL) << "Unknown Caffe mode: " << Caffe::mode();
     }
     InitSolver(param);
-    delta_ = (solver_type() == SolverParameter_SolverType_ADAGRAD) ?
-         param.delta() : 0;
+    delta_ = (solver_type() == SolverParameter_SolverType_ADAGRAD ||
+        solver_type() == SolverParameter_SolverType_RMSPROP) ?
+        param.delta() : 0;
   }
 
   string RunLeastSquaresSolver(const Dtype learning_rate,
-      const Dtype weight_decay, const Dtype momentum, const int num_iters,
-      const int iter_size = 1, const bool snapshot = false,
+      const Dtype weight_decay, const Dtype momentum, const Dtype rms_decay,
+      const int num_iters, const int iter_size = 1, const bool snapshot = false,
       const char* from_snapshot = NULL) {
     ostringstream proto;
     proto <<
@@ -173,6 +174,9 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
     if (momentum != 0) {
       proto << "momentum: " << momentum << " ";
     }
+    if (rms_decay != 0) {
+      proto << "rms_decay: " << rms_decay << " ";
+    }
     MakeTempDir(&snapshot_prefix_);
     proto << "snapshot_prefix: '" << snapshot_prefix_ << "/' ";
     if (snapshot) {
@@ -204,7 +208,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
   // updated_params will store the updated weight and bias results,
   // using the blobs' diffs to hold the update values themselves.
   void ComputeLeastSquaresUpdate(const Dtype learning_rate,
-      const Dtype weight_decay, const Dtype momentum,
+      const Dtype weight_decay, const Dtype momentum, const Dtype rms_decay,
       vector<shared_ptr<Blob<Dtype> > >* updated_params) {
     const int N = num_;
     const int D = channels_ * height_ * width_;
@@ -287,6 +291,10 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
       case SolverParameter_SolverType_ADAGRAD:
         update_value /= std::sqrt(history_value + grad * grad) + delta_;
         break;
+      case SolverParameter_SolverType_RMSPROP:
+        update_value /= std::sqrt(rms_decay*history_value
+            + grad * grad * (1 - rms_decay)) + delta_;
+        break;
       default:
         LOG(FATAL) << "Unknown solver type: " << solver_type();
       }
@@ -352,13 +360,14 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
   }
 
   void CheckAccumulation(const Dtype kLearningRate, const Dtype kWeightDecay,
-      const Dtype kMomentum, const int kNumIters, const int kIterSize) {
+      const Dtype kMomentum, const Dtype rms_decay, const int kNumIters,
+      const int kIterSize) {
     const double kPrecision = 1e-2;
     const double kMinPrecision = 1e-7;
     constant_data_ = true;
     // Solve without accumulation and save parameters.
     this->RunLeastSquaresSolver(kLearningRate, kWeightDecay, kMomentum,
-        kNumIters);
+        rms_decay, kNumIters);
     // Save parameters for comparison.
     Net<Dtype>& net = *this->solver_->net();
     const vector<shared_ptr<Blob<Dtype> > >& param_blobs =
@@ -370,7 +379,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
     }
     // Solve by equivalent accumulation of gradients over divided batches.
     this->RunLeastSquaresSolver(kLearningRate, kWeightDecay, kMomentum,
-        kNumIters, kIterSize);
+        rms_decay, kNumIters, kIterSize);
     Net<Dtype>& net_accum = *this->solver_->net();
     const vector<shared_ptr<Blob<Dtype> > >& accum_params =
         net_accum.layer_by_name("innerprod")->blobs();
@@ -408,18 +417,19 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
   // matches the solver's (K+1)th update.
   void TestLeastSquaresUpdate(const Dtype learning_rate = 1.0,
       const Dtype weight_decay = 0.0, const Dtype momentum = 0.0,
-      const int iter_to_check = 0) {
+      const Dtype rms_decay = 0.0, const int iter_to_check = 0) {
     // Initialize the solver and run K (= iter_to_check) solver iterations.
-    RunLeastSquaresSolver(learning_rate, weight_decay, momentum, iter_to_check);
+    RunLeastSquaresSolver(learning_rate, weight_decay, momentum, rms_decay,
+        iter_to_check);
 
     // Compute the (K+1)th update using the analytic least squares gradient.
     vector<shared_ptr<Blob<Dtype> > > updated_params;
     ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum,
-                              &updated_params);
+        rms_decay, &updated_params);
 
     // Reinitialize the solver and run K+1 solver iterations.
-    RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
-                          iter_to_check + 1);
+    RunLeastSquaresSolver(learning_rate, weight_decay, momentum, rms_decay,
+        iter_to_check + 1);
 
     // Check that the solver's solution matches ours.
     CheckLeastSquaresUpdate(updated_params);
@@ -427,12 +437,12 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
 
   void TestSnapshot(const Dtype learning_rate = 1.0,
       const Dtype weight_decay = 0.0, const Dtype momentum = 0.0,
-      const int num_iters = 1) {
+      const Dtype rms_decay = 0.0, const int num_iters = 1) {
     // Run the solver for num_iters * 2 iterations.
     const int total_num_iters = num_iters * 2;
     bool snapshot = false;
     const int kIterSize = 1;
-    RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
+    RunLeastSquaresSolver(learning_rate, weight_decay, momentum, rms_decay,
                           total_num_iters, kIterSize, snapshot);
 
     // Save the resulting param values.
@@ -463,12 +473,12 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
     // Run the solver for num_iters iterations and snapshot.
     snapshot = true;
     string snapshot_name = RunLeastSquaresSolver(learning_rate, weight_decay,
-        momentum, num_iters, kIterSize, snapshot);
+        momentum, rms_decay, num_iters, kIterSize, snapshot);
 
     // Reinitialize the solver and run for num_iters more iterations.
     snapshot = false;
-    RunLeastSquaresSolver(learning_rate, weight_decay,
-        momentum, total_num_iters, kIterSize, snapshot, snapshot_name.c_str());
+    RunLeastSquaresSolver(learning_rate, weight_decay, momentum, rms_decay,
+        total_num_iters, kIterSize, snapshot, snapshot_name.c_str());
 
     // Check that params now match.
     const vector<Blob<Dtype>*>& params = solver_->net()->learnable_params();
@@ -548,9 +558,11 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithMomentum) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0;
   const Dtype kMomentum = 0.5;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 1;
   for (int i = 0; i <= kNumIters; ++i) {
-    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
   }
 }
 
@@ -559,9 +571,11 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0;
   const Dtype kMomentum = 0.5;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   for (int i = 0; i <= kNumIters; ++i) {
-    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
   }
 }
 
@@ -570,9 +584,11 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverything) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.5;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   for (int i = 0; i <= kNumIters; ++i) {
-    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
   }
 }
 
@@ -581,10 +597,12 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingShare) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.5;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   this->share_ = true;
   for (int i = 0; i <= kNumIters; ++i) {
-    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
   }
 }
 
@@ -593,10 +611,11 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.9;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   const int kIterSize = 2;
-  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
-      kIterSize);
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+      kNumIters, kIterSize);
 }
 
 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
@@ -604,11 +623,12 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.9;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   const int kIterSize = 2;
   this->share_ = true;
-  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
-      kIterSize);
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+      kNumIters, kIterSize);
 }
 
 TYPED_TEST(SGDSolverTest, TestSnapshot) {
@@ -616,9 +636,10 @@ TYPED_TEST(SGDSolverTest, TestSnapshot) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.9;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   for (int i = 1; i <= kNumIters; ++i) {
-    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
   }
 }
 
@@ -627,10 +648,11 @@ TYPED_TEST(SGDSolverTest, TestSnapshotShare) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.9;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   this->share_ = true;
   for (int i = 1; i <= kNumIters; ++i) {
-    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
   }
 }
 
@@ -672,22 +694,26 @@ TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithEverything) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   for (int i = 0; i <= kNumIters; ++i) {
-    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
   }
 }
 
 TYPED_TEST(AdaGradSolverTest,
-           TestAdaGradLeastSquaresUpdateWithEverythingShare) {
+      TestAdaGradLeastSquaresUpdateWithEverythingShare) {
   typedef typename TypeParam::Dtype Dtype;
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   this->share_ = true;
   for (int i = 0; i <= kNumIters; ++i) {
-    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
   }
 }
 
@@ -696,10 +722,11 @@ TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   const int kIterSize = 2;
-  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
-      kIterSize);
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+      kNumIters, kIterSize);
 }
 
 TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
@@ -707,11 +734,12 @@ TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   const int kIterSize = 2;
   this->share_ = true;
-  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
-      kIterSize);
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+      kNumIters, kIterSize);
 }
 
 TYPED_TEST(AdaGradSolverTest, TestSnapshot) {
@@ -719,9 +747,10 @@ TYPED_TEST(AdaGradSolverTest, TestSnapshot) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   for (int i = 1; i <= kNumIters; ++i) {
-    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
   }
 }
 
@@ -730,10 +759,11 @@ TYPED_TEST(AdaGradSolverTest, TestSnapshotShare) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   this->share_ = true;
   for (int i = 1; i <= kNumIters; ++i) {
-    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
   }
 }
 
@@ -787,9 +817,11 @@ TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithMomentum) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0;
   const Dtype kMomentum = 0.5;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 1;
   for (int i = 0; i <= kNumIters; ++i) {
-    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
   }
 }
 
@@ -798,9 +830,11 @@ TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0;
   const Dtype kMomentum = 0.5;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   for (int i = 0; i <= kNumIters; ++i) {
-    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
   }
 }
 
@@ -821,10 +855,12 @@ TYPED_TEST(NesterovSolverTest,
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.9;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   this->share_ = true;
   for (int i = 0; i <= kNumIters; ++i) {
-    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
   }
 }
 
@@ -833,10 +869,11 @@ TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.9;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   const int kIterSize = 2;
-  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
-      kIterSize);
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+      kNumIters, kIterSize);
 }
 
 TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
@@ -844,11 +881,12 @@ TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.9;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   const int kIterSize = 2;
   this->share_ = true;
-  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
-      kIterSize);
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+      kNumIters, kIterSize);
 }
 
 TYPED_TEST(NesterovSolverTest, TestSnapshot) {
@@ -856,9 +894,10 @@ TYPED_TEST(NesterovSolverTest, TestSnapshot) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.9;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   for (int i = 1; i <= kNumIters; ++i) {
-    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
   }
 }
 
@@ -867,10 +906,124 @@ TYPED_TEST(NesterovSolverTest, TestSnapshotShare) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.9;
+  const Dtype kRMSDecay = 0;
+  const int kNumIters = 4;
+  this->share_ = true;
+  for (int i = 1; i <= kNumIters; ++i) {
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
+  }
+}
+
+template <typename TypeParam>
+class RMSPropSolverTest : public GradientBasedSolverTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+  virtual void InitSolver(const SolverParameter& param) {
+    this->solver_.reset(new RMSPropSolver<Dtype>(param));
+  }
+  virtual SolverParameter_SolverType solver_type() {
+    return SolverParameter_SolverType_RMSPROP;
+  }
+};
+
+TYPED_TEST_CASE(RMSPropSolverTest, TestDtypesAndDevices);
+
+TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithWeightDecay) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  const Dtype kWeightDecay = 0.5;
+  this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay);
+}
+
+TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithRmsDecay) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.0;
+  const Dtype kMomentum = 0.0;
+  const Dtype kRMSDecay = 0.95;
+  const int kNumIters = 4;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
+  }
+}
+
+TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithEverything) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.0;
+  const Dtype kRMSDecay = 0.95;
+  const int kNumIters = 4;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
+  }
+}
+
+TYPED_TEST(RMSPropSolverTest,
+      TestRMSPropLeastSquaresUpdateWithEverythingShare) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.0;
+  const Dtype kRMSDecay = 0.95;
+  const int kNumIters = 4;
+  this->share_ = true;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
+  }
+}
+
+TYPED_TEST(RMSPropSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.0;
+  const Dtype kRMSDecay = 0.95;
+  const int kNumIters = 4;
+  const int kIterSize = 2;
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+      kNumIters, kIterSize);
+}
+
+TYPED_TEST(RMSPropSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.0;
+  const Dtype kRMSDecay = 0.95;
+  const int kNumIters = 4;
+  const int kIterSize = 2;
+  this->share_ = true;
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+      kNumIters, kIterSize);
+}
+
+TYPED_TEST(RMSPropSolverTest, TestSnapshot) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0;
+  const Dtype kRMSDecay = 0.95;
+  const int kNumIters = 4;
+  for (int i = 1; i <= kNumIters; ++i) {
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
+  }
+}
+
+TYPED_TEST(RMSPropSolverTest, TestSnapshotShare) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0;
+  const Dtype kRMSDecay = 0.95;
   const int kNumIters = 4;
   this->share_ = true;
   for (int i = 1; i <= kNumIters; ++i) {
-    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
   }
 }