Adam solver
authorPatWie <patrick@wieschollek.info>
Mon, 3 Aug 2015 15:31:14 +0000 (17:31 +0200)
committerRonghang Hu <huronghang@hotmail.com>
Fri, 14 Aug 2015 16:01:53 +0000 (09:01 -0700)
This commit implements the Adam solver by Kingma et. al for CPU and
GPU. All solver parameters are defined in the caffe.proto. This also
adds an example for the MNIST dataset.

examples/mnist/lenet_solver_adam.prototxt [new file with mode: 0644]
examples/mnist/train_lenet_adam.sh [new file with mode: 0755]
include/caffe/solver.hpp
src/caffe/proto/caffe.proto
src/caffe/solver.cpp
src/caffe/test/test_gradient_based_solver.cpp

diff --git a/examples/mnist/lenet_solver_adam.prototxt b/examples/mnist/lenet_solver_adam.prototxt
new file mode 100644 (file)
index 0000000..d22c571
--- /dev/null
@@ -0,0 +1,26 @@
+# The train/test net protocol buffer definition
+# this follows "ADAM: A METHOD FOR STOCHASTIC OPTIMIZATION"
+net: "examples/mnist/lenet_train_test.prototxt"
+# test_iter specifies how many forward passes the test should carry out.
+# In the case of MNIST, we have test batch size 100 and 100 test iterations,
+# covering the full 10,000 testing images.
+test_iter: 100
+# Carry out testing every 500 training iterations.
+test_interval: 500
+# All parameters are from the cited paper above
+base_lr: 0.001
+momentum: 0.9
+momentum2: 0.999
+# since Adam dynamically changes the learning rate, we set the base learning
+# rate to a fixed value
+lr_policy: "fixed"
+# Display every 100 iterations
+display: 100
+# The maximum number of iterations
+max_iter: 10000
+# snapshot intermediate results
+snapshot: 5000
+snapshot_prefix: "examples/mnist/lenet"
+# solver mode: CPU or GPU
+solver_type: ADAM
+solver_mode: GPU
diff --git a/examples/mnist/train_lenet_adam.sh b/examples/mnist/train_lenet_adam.sh
new file mode 100755 (executable)
index 0000000..a32ecf2
--- /dev/null
@@ -0,0 +1,3 @@
+#!/usr/bin/env sh
+
+./build/tools/caffe train --solver=examples/mnist/lenet_solver_adam.prototxt
index d2b9992..582aa14 100644 (file)
@@ -218,6 +218,21 @@ class AdaDeltaSolver : public SGDSolver<Dtype> {
 };
 
 template <typename Dtype>
+class AdamSolver : public SGDSolver<Dtype> {
+ public:
+  explicit AdamSolver(const SolverParameter& param)
+      : SGDSolver<Dtype>(param) { AdamPreSolve();}
+  explicit AdamSolver(const string& param_file)
+      : SGDSolver<Dtype>(param_file) { AdamPreSolve(); }
+
+ protected:
+  void AdamPreSolve();
+  virtual void ComputeUpdateValue(int param_id, Dtype rate);
+
+  DISABLE_COPY_AND_ASSIGN(AdamSolver);
+};
+
+template <typename Dtype>
 Solver<Dtype>* GetSolver(const SolverParameter& param) {
   SolverParameter_SolverType type = param.solver_type();
 
@@ -232,6 +247,8 @@ Solver<Dtype>* GetSolver(const SolverParameter& param) {
       return new RMSPropSolver<Dtype>(param);
   case SolverParameter_SolverType_ADADELTA:
       return new AdaDeltaSolver<Dtype>(param);
+  case SolverParameter_SolverType_ADAM:
+      return new AdamSolver<Dtype>(param);
   default:
       LOG(FATAL) << "Unknown SolverType: " << type;
   }
index fc0d961..d4c97d2 100644 (file)
@@ -98,7 +98,7 @@ message NetParameter {
 // NOTE
 // Update the next available ID when you add a new SolverParameter field.
 //
-// SolverParameter next available ID: 39 (last added: rms_decay)
+// SolverParameter next available ID: 40 (last added: momentum2)
 message SolverParameter {
   //////////////////////////////////////////////////////////////////////////////
   // Specifying the train and test networks
@@ -216,10 +216,13 @@ message SolverParameter {
     ADAGRAD = 2;
     RMSPROP = 3;
     ADADELTA = 4;
+    ADAM = 5;
   }
   optional SolverType solver_type = 30 [default = SGD];
-  // numerical stability for AdaGrad
+  // numerical stability for RMSProp, AdaGrad and AdaDelta and Adam
   optional float delta = 31 [default = 1e-8];
+  // parameters for the Adam solver
+  optional float momentum2 = 39 [default = 0.999];
 
   // RMSProp decay value
   // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
index 248f238..9348e11 100644 (file)
@@ -1114,11 +1114,115 @@ void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
   }
 }
 
+template <typename Dtype>
+void AdamSolver<Dtype>::AdamPreSolve() {
+  // Add the extra history entries for Adam after those from
+  // SGDSolver::PreSolve
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  for (int i = 0; i < net_params.size(); ++i) {
+    const vector<int>& shape = net_params[i]->shape();
+    this->history_.push_back(
+            shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
+  }
+}
+
+template <typename Dtype>
+void AdamSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  const vector<float>& net_params_lr = this->net_->params_lr();
+  Dtype local_rate = rate * net_params_lr[param_id];
+  const Dtype beta1 = this->param_.momentum();
+  const Dtype beta2 = this->param_.momentum2();
+
+  // we create aliases for convenience
+  size_t update_history_offset = net_params.size();
+  Blob<Dtype>* val_m = this->history_[param_id].get();
+  Blob<Dtype>* val_v = this->history_[param_id + update_history_offset].get();
+  Blob<Dtype>* val_t = this->temp_[param_id].get();
+
+  const int t = this->iter_  + 1;
+  const Dtype correction = std::sqrt(Dtype(1) - pow(beta2, t)) /
+      (Dtype(1.) - pow(beta1, t));
+  const int N = net_params[param_id]->count();
+  const Dtype eps_hat = this->param_.delta();
+
+  switch (Caffe::mode()) {
+    case Caffe::CPU: {
+    // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t
+    caffe_cpu_axpby(N, Dtype(1)-beta1,
+        net_params[param_id]->cpu_diff(), beta1,
+        val_m->mutable_cpu_data());
+
+    // update v <- \beta_2 m_{t-1} + (1-\beta_2)g_t^2
+    caffe_mul(N,
+        net_params[param_id]->cpu_diff(),
+        net_params[param_id]->cpu_diff(),
+    val_t->mutable_cpu_data());
+    caffe_cpu_axpby(N, Dtype(1)-beta2,
+        val_t->cpu_data(), beta2,
+        val_v->mutable_cpu_data());
+
+    // set update
+    caffe_powx(N,
+        val_v->cpu_data(), Dtype(0.5),
+        val_t->mutable_cpu_data());
+    caffe_add_scalar(N, eps_hat, val_t->mutable_cpu_data());
+    caffe_div(N,
+        val_m->cpu_data(),
+        val_t->cpu_data(),
+        val_t->mutable_cpu_data());
+
+    caffe_cpu_scale(N, local_rate*correction,
+        val_t->cpu_data(),
+        net_params[param_id]->mutable_cpu_diff());
+    break;
+  }
+  case Caffe::GPU: {
+#ifndef CPU_ONLY
+    // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t
+    caffe_gpu_axpby(N, Dtype(1)-beta1,
+        net_params[param_id]->gpu_diff(), beta1,
+        val_m->mutable_gpu_data());
+
+    // update v <- \beta_2 m_{t-1} + (1-\beta_2)g_t^2
+    caffe_gpu_mul(N,
+        net_params[param_id]->gpu_diff(),
+        net_params[param_id]->gpu_diff(),
+        val_t->mutable_gpu_data());
+    caffe_gpu_axpby(N, Dtype(1)-beta2,
+        val_t->gpu_data(), beta2,
+        val_v->mutable_gpu_data());
+
+    // set update
+    caffe_gpu_powx(N,
+        val_v->gpu_data(), Dtype(0.5),
+        val_t->mutable_gpu_data());
+    caffe_gpu_add_scalar(N, eps_hat,
+        val_t->mutable_gpu_data());
+    caffe_gpu_div(N,
+        val_m->gpu_data(),
+        val_t->gpu_data(),
+        val_t->mutable_gpu_data());
+
+    caffe_gpu_scale(N, local_rate*correction,
+        val_t->gpu_data(),
+        net_params[param_id]->mutable_gpu_diff());
+#else
+    NO_GPU;
+#endif
+    break;
+  }
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  }
+}
+
 INSTANTIATE_CLASS(Solver);
 INSTANTIATE_CLASS(SGDSolver);
 INSTANTIATE_CLASS(NesterovSolver);
 INSTANTIATE_CLASS(AdaGradSolver);
 INSTANTIATE_CLASS(RMSPropSolver);
 INSTANTIATE_CLASS(AdaDeltaSolver);
+INSTANTIATE_CLASS(AdamSolver);
 
 }  // namespace caffe
index 1d255a8..dcbfff1 100644 (file)
@@ -42,7 +42,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
   // TODO this is brittle and the hdf5 file should be checked instead.
   int num_, channels_, height_, width_;
   bool share_;
-  Dtype delta_;  // Stability constant for AdaGrad.
+  Dtype delta_;  // Stability constant for RMSProp, AdaGrad, AdaDelta and Adam
 
   // Test data: check out generate_sample_data.py in the same directory.
   string* input_file_;
@@ -65,10 +65,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
         LOG(FATAL) << "Unknown Caffe mode: " << Caffe::mode();
     }
     InitSolver(param);
-    delta_ = (solver_type() == SolverParameter_SolverType_ADAGRAD ||
-        solver_type() == SolverParameter_SolverType_RMSPROP ||
-        solver_type() == SolverParameter_SolverType_ADADELTA) ?
-        param.delta() : 0;
+    delta_ = param.delta();
   }
 
   string RunLeastSquaresSolver(const Dtype learning_rate,
@@ -216,7 +213,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
   // updated_params will store the updated weight and bias results,
   // using the blobs' diffs to hold the update values themselves.
   void ComputeLeastSquaresUpdate(const Dtype learning_rate,
-      const Dtype weight_decay, const Dtype momentum,
+      const Dtype weight_decay, const Dtype momentum, const int num_iters,
       vector<shared_ptr<Blob<Dtype> > >* updated_params) {
     const int N = num_;
     const int D = channels_ * height_ * width_;
@@ -282,7 +279,8 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
           ((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]);
       // Finally, compute update.
       const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
-      if (solver_type() != SolverParameter_SolverType_ADADELTA) {
+      if (solver_type() != SolverParameter_SolverType_ADADELTA
+          && solver_type() != SolverParameter_SolverType_ADAM) {
         ASSERT_EQ(2, history.size());  // 1 blob for weights, 1 for bias
       } else {
         ASSERT_EQ(4, history.size());  // additional blobs for update history
@@ -312,16 +310,31 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
       case SolverParameter_SolverType_ADADELTA:
       {
         const Dtype update_history_value = (i == D) ?
-            history[3]->cpu_data()[0] : history[2]->cpu_data()[i];
+            history[1 + num_param_blobs]->cpu_data()[0] :
+            history[0 + num_param_blobs]->cpu_data()[i];
         const Dtype weighted_gradient_average =
             momentum * history_value + (1 - momentum) * (grad * grad);
         update_value = grad * std::sqrt((update_history_value + delta_) /
-            (weighted_gradient_average + delta_));
+            (weighted_gradient_average + delta_)) * learning_rate;
         // not actually needed, just here for illustrative purposes
         // const Dtype weighted_update_average =
         //   momentum * update_history_value + (1 - momentum) * (update_value);
         break;
       }
+      case SolverParameter_SolverType_ADAM: {
+        const Dtype momentum2 = 0.999;
+        const Dtype m = history_value;
+        const Dtype v = (i == D) ?
+            history[1 + num_param_blobs]->cpu_data()[0] :
+            history[0 + num_param_blobs]->cpu_data()[i];
+        const Dtype val_m = (1 - momentum) * grad + momentum * m;
+        const Dtype val_v = (1 - momentum2) * grad * grad + momentum2 * v;
+        Dtype alpha_t = learning_rate *
+            std::sqrt(Dtype(1) - pow(momentum2, num_iters)) /
+            (Dtype(1.) - pow(momentum, num_iters));
+        update_value = alpha_t * val_m / (std::sqrt(val_v) + delta_);
+        break;
+      }
       default:
         LOG(FATAL) << "Unknown solver type: " << solver_type();
       }
@@ -465,7 +478,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
       // Compute the (K+1)th update using the analytic least squares gradient.
       vector<shared_ptr<Blob<Dtype> > > updated_params;
       ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum,
-          &updated_params);
+          iter_to_check + 1, &updated_params);
 
       // Reinitialize the solver and run K+1 solver iterations.
       num_ = kNum;
@@ -946,13 +959,13 @@ TYPED_TEST_CASE(AdaDeltaSolverTest, TestDtypesAndDevices);
 
 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdate) {
   typedef typename TypeParam::Dtype Dtype;
-  const Dtype kLearningRate = 1.0;
+  const Dtype kLearningRate = 0.1;
   this->TestLeastSquaresUpdate(kLearningRate);
 }
 
 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithWeightDecay) {
   typedef typename TypeParam::Dtype Dtype;
-  const Dtype kLearningRate = 1.0;
+  const Dtype kLearningRate = 0.1;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.95;
   this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
@@ -960,64 +973,64 @@ TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithWeightDecay) {
 
 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithHalfMomentum) {
   typedef typename TypeParam::Dtype Dtype;
-  const Dtype kLearningRate = 1.0;
+  const Dtype kLearningRate = 0.1;
   const Dtype kWeightDecay = 0.0;
   const Dtype kMomentum = 0.5;
   const int kNumIters = 1;
   for (int i = 0; i <= kNumIters; ++i) {
-      this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
   }
 }
 
 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithMomentum) {
   typedef typename TypeParam::Dtype Dtype;
-  const Dtype kLearningRate = 1.0;
+  const Dtype kLearningRate = 0.1;
   const Dtype kWeightDecay = 0.0;
   const Dtype kMomentum = 0.95;
   const int kNumIters = 1;
   for (int i = 0; i <= kNumIters; ++i) {
-      this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
   }
 }
 
 TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
   typedef typename TypeParam::Dtype Dtype;
-  const Dtype kLearningRate = 1.0;
+  const Dtype kLearningRate = 0.1;
   const Dtype kWeightDecay = 0.0;
   const Dtype kMomentum = 0.95;
   const int kNumIters = 4;
   for (int i = 0; i <= kNumIters; ++i) {
-      this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
   }
 }
 
 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverything) {
   typedef typename TypeParam::Dtype Dtype;
-  const Dtype kLearningRate = 1.0;
+  const Dtype kLearningRate = 0.1;
   const Dtype kWeightDecay = 0.1;
   const Dtype kMomentum = 0.95;
   const int kNumIters = 4;
   for (int i = 0; i <= kNumIters; ++i) {
-      this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
   }
 }
 
 TYPED_TEST(AdaDeltaSolverTest,
            TestAdaDeltaLeastSquaresUpdateWithEverythingShare) {
   typedef typename TypeParam::Dtype Dtype;
-  const Dtype kLearningRate = 1.0;
+  const Dtype kLearningRate = 0.1;
   const Dtype kWeightDecay = 0.1;
   const Dtype kMomentum = 0.95;
   const int kNumIters = 4;
   this->share_ = true;
   for (int i = 0; i <= kNumIters; ++i) {
-      this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
   }
 }
 
 TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
   typedef typename TypeParam::Dtype Dtype;
-  const Dtype kLearningRate = 1.0;
+  const Dtype kLearningRate = 0.1;
   const Dtype kWeightDecay = 0.1;
   const Dtype kMomentum = 0.95;
   const int kNumIters = 4;
@@ -1028,7 +1041,7 @@ TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
 
 TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
   typedef typename TypeParam::Dtype Dtype;
-  const Dtype kLearningRate = 1.0;
+  const Dtype kLearningRate = 0.1;
   const Dtype kWeightDecay = 0.1;
   const Dtype kMomentum = 0.95;
   const int kNumIters = 4;
@@ -1040,7 +1053,7 @@ TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
 
 TYPED_TEST(AdaDeltaSolverTest, TestSnapshot) {
   typedef typename TypeParam::Dtype Dtype;
-  const Dtype kLearningRate = 1.0;
+  const Dtype kLearningRate = 0.1;
   const Dtype kWeightDecay = 0.1;
   const Dtype kMomentum = 0.95;
   const int kNumIters = 4;
@@ -1051,7 +1064,7 @@ TYPED_TEST(AdaDeltaSolverTest, TestSnapshot) {
 
 TYPED_TEST(AdaDeltaSolverTest, TestSnapshotShare) {
   typedef typename TypeParam::Dtype Dtype;
-  const Dtype kLearningRate = 1.0;
+  const Dtype kLearningRate = 0.1;
   const Dtype kWeightDecay = 0.1;
   const Dtype kMomentum = 0.95;
   const int kNumIters = 4;
@@ -1062,6 +1075,111 @@ TYPED_TEST(AdaDeltaSolverTest, TestSnapshotShare) {
 }
 
 template <typename TypeParam>
+class AdamSolverTest : public GradientBasedSolverTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+  virtual void InitSolver(const SolverParameter& param) {
+    SolverParameter new_param = param;
+    const Dtype momentum = 0.9;
+    new_param.set_momentum(momentum);
+    const Dtype momentum2 = 0.999;
+    new_param.set_momentum2(momentum2);
+    this->solver_.reset(new AdamSolver<Dtype>(new_param));
+  }
+  virtual SolverParameter_SolverType solver_type() {
+    return SolverParameter_SolverType_ADAM;
+  }
+};
+
+TYPED_TEST_CASE(AdamSolverTest, TestDtypesAndDevices);
+
+TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdate) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0;
+  const Dtype kMomentum = 0.9;
+  this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
+}
+
+TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdateWithWeightDecay) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.9;
+  this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
+}
+
+TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdateWithEverything) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.9;
+  const int kNumIters = 4;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
+TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdateWithEverythingShare) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.9;
+  const int kNumIters = 4;
+  this->share_ = true;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
+TYPED_TEST(AdamSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.9;
+  const int kNumIters = 4;
+  const int kIterSize = 2;
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
+      kIterSize);
+}
+
+TYPED_TEST(AdamSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.9;
+  const int kNumIters = 4;
+  const int kIterSize = 2;
+  this->share_ = true;
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
+      kIterSize);
+}
+
+TYPED_TEST(AdamSolverTest, TestSnapshot) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.9;
+  const int kNumIters = 4;
+  for (int i = 1; i <= kNumIters; ++i) {
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
+TYPED_TEST(AdamSolverTest, TestSnapshotShare) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.9;
+  const int kNumIters = 4;
+  this->share_ = true;
+  for (int i = 1; i <= kNumIters; ++i) {
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
+template <typename TypeParam>
 class RMSPropSolverTest : public GradientBasedSolverTest<TypeParam> {
   typedef typename TypeParam::Dtype Dtype;