solver: check and set type to reconcile class and proto
[platform/upstream/caffeonacl.git] / src / caffe / test / test_gradient_based_solver.cpp
1 #include <algorithm>
2 #include <string>
3 #include <utility>
4 #include <vector>
5
6 #include "google/protobuf/text_format.h"
7
8 #include "gtest/gtest.h"
9
10 #include "caffe/common.hpp"
11 #include "caffe/parallel.hpp"
12 #include "caffe/proto/caffe.pb.h"
13 #include "caffe/sgd_solvers.hpp"
14 #include "caffe/util/io.hpp"
15
16 #include "caffe/test/test_caffe_main.hpp"
17
18 using std::ostringstream;
19
20 namespace caffe {
21
22 template <typename TypeParam>
23 class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
24   typedef typename TypeParam::Dtype Dtype;
25
26  protected:
27   GradientBasedSolverTest() :
28       seed_(1701), num_(4), channels_(3), height_(10), width_(10),
29       share_(false) {
30         input_file_ = new string(
31         CMAKE_SOURCE_DIR "caffe/test/test_data/solver_data_list.txt" CMAKE_EXT);
32       }
33   ~GradientBasedSolverTest() {
34     delete input_file_;
35   }
36
37   string snapshot_prefix_;
38   shared_ptr<SGDSolver<Dtype> > solver_;
39   shared_ptr<P2PSync<Dtype> > sync_;
40   int seed_;
41   // Dimensions are determined by generate_sample_data.py
42   // TODO this is brittle and the hdf5 file should be checked instead.
43   int num_, channels_, height_, width_;
44   bool share_;
45   Dtype delta_;  // Stability constant for RMSProp, AdaGrad, AdaDelta and Adam
46
47   // Test data: check out generate_sample_data.py in the same directory.
48   string* input_file_;
49
50   virtual void InitSolver(const SolverParameter& param) = 0;
51
52   virtual void InitSolverFromProtoString(const string& proto) {
53     SolverParameter param;
54     CHECK(google::protobuf::TextFormat::ParseFromString(proto, &param));
55     // Set the solver_mode according to current Caffe::mode.
56     switch (Caffe::mode()) {
57       case Caffe::CPU:
58         param.set_solver_mode(SolverParameter_SolverMode_CPU);
59         break;
60       case Caffe::GPU:
61         param.set_solver_mode(SolverParameter_SolverMode_GPU);
62         break;
63       default:
64         LOG(FATAL) << "Unknown Caffe mode: " << Caffe::mode();
65     }
66     InitSolver(param);
67     delta_ = param.delta();
68   }
69
70   string RunLeastSquaresSolver(const Dtype learning_rate,
71       const Dtype weight_decay, const Dtype momentum, const int num_iters,
72       const int iter_size = 1, const int devices = 1,
73       const bool snapshot = false, const char* from_snapshot = NULL) {
74     ostringstream proto;
75     int device_id = 0;
76 #ifndef CPU_ONLY
77     if (Caffe::mode() == Caffe::GPU) {
78       CUDA_CHECK(cudaGetDevice(&device_id));
79     }
80 #endif
81     proto <<
82        "snapshot_after_train: " << snapshot << " "
83        "max_iter: " << num_iters << " "
84        "base_lr: " << learning_rate << " "
85        "lr_policy: 'fixed' "
86        "iter_size: " << iter_size << " "
87        "device_id: " << device_id << " "
88        "net_param { "
89        "  name: 'TestNetwork' "
90        "  layer { "
91        "    name: 'data' "
92        "    type: 'HDF5Data' "
93        "    hdf5_data_param { "
94        "      source: '" << *(this->input_file_) << "' "
95        "      batch_size: " << num_ / iter_size << " "
96        "    } "
97        "    top: 'data' "
98        "    top: 'targets' "
99        "  } ";
100     if (share_) {
101       proto <<
102          "  layer { "
103          "    name: 'slice' "
104          "    type: 'Slice' "
105          "    bottom: 'data' "
106          "    top: 'data1' "
107          "    top: 'data2' "
108          "    slice_param { "
109          "      axis: 0 "
110          "    } "
111          "  } ";
112     }
113     proto <<
114        "  layer { "
115        "    name: 'innerprod' "
116        "    type: 'InnerProduct' "
117        "    param { name: 'weights' } "
118        "    param { name: 'bias' } "
119        "    inner_product_param { "
120        "      num_output: 1 "
121        "      weight_filler { "
122        "        type: 'gaussian' "
123        "        std: 1.0 "
124        "      } "
125        "      bias_filler { "
126        "        type: 'gaussian' "
127        "        std: 1.0 "
128        "      } "
129        "    } "
130        "    bottom: '" << string(share_ ? "data1": "data") << "' "
131        "    top: '" << string(share_ ? "innerprod1": "innerprod") << "' "
132        "  } ";
133     if (share_) {
134       proto <<
135          "  layer { "
136          "    name: 'innerprod2' "
137          "    type: 'InnerProduct' "
138          "    param { name: 'weights' } "
139          "    param { name: 'bias' } "
140          "    inner_product_param { "
141          "      num_output: 1 "
142          "      weight_filler { "
143          "        type: 'gaussian' "
144          "        std: 1.0 "
145          "      } "
146          "      bias_filler { "
147          "        type: 'gaussian' "
148          "        std: 1.0 "
149          "      } "
150          "    } "
151          "    bottom: 'data2' "
152          "    top: 'innerprod2' "
153          "  } "
154          "  layer { "
155          "    name: 'concat' "
156          "    type: 'Concat' "
157          "    bottom: 'innerprod1' "
158          "    bottom: 'innerprod2' "
159          "    top: 'innerprod' "
160          "    concat_param { "
161          "      axis: 0 "
162          "    } "
163          "  } ";
164     }
165     proto <<
166        "  layer { "
167        "    name: 'loss' "
168        "    type: 'EuclideanLoss' "
169        "    bottom: 'innerprod' "
170        "    bottom: 'targets' "
171        "  } "
172        "} ";
173     if (weight_decay != 0) {
174       proto << "weight_decay: " << weight_decay << " ";
175     }
176     if (momentum != 0) {
177       proto << "momentum: " << momentum << " ";
178     }
179     MakeTempDir(&snapshot_prefix_);
180     proto << "snapshot_prefix: '" << snapshot_prefix_ << "/' ";
181     if (snapshot) {
182       proto << "snapshot: " << num_iters << " ";
183     }
184     Caffe::set_random_seed(this->seed_);
185     this->InitSolverFromProtoString(proto.str());
186     if (from_snapshot != NULL) {
187       this->solver_->Restore(from_snapshot);
188       for (int i = 0; i < this->solver_->iter(); ++i) {
189         this->solver_->net()->Forward();
190       }
191     }
192     if (devices == 1) {
193       this->solver_->Solve();
194     } else {
195       LOG(INFO) << "Multi-GPU test on " << devices << " devices";
196       vector<int> gpus;
197       // put current device at the beginning
198       int device_id = solver_->param().device_id();
199       gpus.push_back(device_id);
200       for (int i = 0; gpus.size() < devices; ++i) {
201         if (i != device_id)
202           gpus.push_back(i);
203       }
204       Caffe::set_solver_count(gpus.size());
205       this->sync_.reset(new P2PSync<Dtype>(
206           this->solver_, NULL, this->solver_->param()));
207       this->sync_->Run(gpus);
208       Caffe::set_solver_count(1);
209     }
210     if (snapshot) {
211       ostringstream resume_file;
212       resume_file << snapshot_prefix_ << "/_iter_" << num_iters
213                   << ".solverstate";
214       string resume_filename = resume_file.str();
215       return resume_filename;
216     }
217     return string();
218   }
219
220   // Compute an update value given the current state of the train net,
221   // using the analytical formula for the least squares gradient.
222   // updated_params will store the updated weight and bias results,
223   // using the blobs' diffs to hold the update values themselves.
224   void ComputeLeastSquaresUpdate(const Dtype learning_rate,
225       const Dtype weight_decay, const Dtype momentum, const int num_iters,
226       vector<shared_ptr<Blob<Dtype> > >* updated_params) {
227     const int N = num_;
228     const int D = channels_ * height_ * width_;
229
230     // Run a forward pass, and manually compute the update values from the
231     // result.
232     Net<Dtype>& net = *this->solver_->net();
233     net.Forward();
234     ASSERT_TRUE(net.has_blob("data"));
235     const Blob<Dtype>& data = *net.blob_by_name("data");
236     ASSERT_TRUE(net.has_blob("targets"));
237     const Blob<Dtype>& targets = *net.blob_by_name("targets");
238     ASSERT_TRUE(net.has_layer("innerprod"));
239     const vector<shared_ptr<Blob<Dtype> > >& param_blobs =
240         net.layer_by_name("innerprod")->blobs();
241     const int num_param_blobs = 2;
242     ASSERT_EQ(num_param_blobs, param_blobs.size());
243     const Blob<Dtype>& weights = *param_blobs[0];
244     const Blob<Dtype>& bias = *param_blobs[1];
245     ASSERT_EQ(D * N, data.count());
246     ASSERT_EQ(N, targets.count());
247     ASSERT_EQ(D, weights.count());
248     ASSERT_EQ(1, bias.count());
249
250     updated_params->clear();
251     updated_params->resize(num_param_blobs);
252     for (int i = 0; i < num_param_blobs; ++i) {
253       (*updated_params)[i].reset(new Blob<Dtype>());
254     }
255     Blob<Dtype>& updated_weights = *(*updated_params)[0];
256     updated_weights.ReshapeLike(weights);
257     Blob<Dtype>& updated_bias = *(*updated_params)[1];
258     updated_bias.ReshapeLike(bias);
259
260     for (int i = 0; i <= D; ++i) {
261       // Compute the derivative with respect to the ith weight (i.e., the ith
262       // element of the gradient).
263       Dtype grad = 0;
264       for (int j = 0; j <= D; ++j) {
265         // Compute element (i, j) of X^T * X.
266         Dtype element = 0;
267         for (int k = 0; k < N; ++k) {
268           // (i, k) in X^T (== (k, i) in X) times (k, j) in X.
269           const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i];
270           const Dtype element_j = (j == D) ? 1 : data.cpu_data()[k * D + j];
271           element += element_i * element_j;
272         }
273         if (j == D) {
274           grad += element * bias.cpu_data()[0];
275         } else {
276           grad += element * weights.cpu_data()[j];
277         }
278       }
279       for (int k = 0; k < N; ++k) {
280         const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i];
281         grad -= element_i * targets.cpu_data()[k];
282       }
283       // Scale the gradient over the N samples.
284       grad /= N;
285       // Add the weight decay to the gradient.
286       grad += weight_decay *
287           ((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]);
288       // Finally, compute update.
289       const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
290       if (solver_->type() != string("AdaDelta")
291           && solver_->type() != string("Adam")) {
292         ASSERT_EQ(2, history.size());  // 1 blob for weights, 1 for bias
293       } else {
294         ASSERT_EQ(4, history.size());  // additional blobs for update history
295       }
296       Dtype update_value = learning_rate * grad;
297       const Dtype history_value = (i == D) ?
298             history[1]->cpu_data()[0] : history[0]->cpu_data()[i];
299       const Dtype temp = momentum * history_value;
300       if (solver_->type() == string("SGD")) {
301         update_value += temp;
302       } else if (solver_->type() == string("Nesterov")) {
303         update_value += temp;
304         // step back then over-step
305         update_value = (1 + momentum) * update_value - temp;
306       } else if (solver_->type() == string("AdaGrad")) {
307         update_value /= std::sqrt(history_value + grad * grad) + delta_;
308       } else if (solver_->type() == string("RMSProp")) {
309         const Dtype rms_decay = 0.95;
310         update_value /= std::sqrt(rms_decay*history_value
311             + grad * grad * (1 - rms_decay)) + delta_;
312       } else if (solver_->type() == string("AdaDelta")) {
313         const Dtype update_history_value = (i == D) ?
314             history[1 + num_param_blobs]->cpu_data()[0] :
315             history[0 + num_param_blobs]->cpu_data()[i];
316         const Dtype weighted_gradient_average =
317             momentum * history_value + (1 - momentum) * (grad * grad);
318         update_value = grad * std::sqrt((update_history_value + delta_) /
319             (weighted_gradient_average + delta_)) * learning_rate;
320         // not actually needed, just here for illustrative purposes
321         // const Dtype weighted_update_average =
322         //   momentum * update_history_value + (1 - momentum) * (update_value);
323       } else if (solver_->type() == string("Adam")) {
324         const Dtype momentum2 = 0.999;
325         const Dtype m = history_value;
326         const Dtype v = (i == D) ?
327             history[1 + num_param_blobs]->cpu_data()[0] :
328             history[0 + num_param_blobs]->cpu_data()[i];
329         const Dtype val_m = (1 - momentum) * grad + momentum * m;
330         const Dtype val_v = (1 - momentum2) * grad * grad + momentum2 * v;
331         Dtype alpha_t = learning_rate *
332             std::sqrt(Dtype(1) - pow(momentum2, num_iters)) /
333             (Dtype(1.) - pow(momentum, num_iters));
334         update_value = alpha_t * val_m / (std::sqrt(val_v) + delta_);
335       } else {
336         LOG(FATAL) << "Unknown solver type: " << solver_->type();
337       }
338       if (i == D) {
339         updated_bias.mutable_cpu_diff()[0] = update_value;
340         updated_bias.mutable_cpu_data()[0] = bias.cpu_data()[0] - update_value;
341       } else {
342         updated_weights.mutable_cpu_diff()[i] = update_value;
343         updated_weights.mutable_cpu_data()[i] =
344             weights.cpu_data()[i] - update_value;
345       }
346     }
347   }
348
349   void CheckLeastSquaresUpdate(
350       const vector<shared_ptr<Blob<Dtype> > >& updated_params) {
351     const int D = channels_ * height_ * width_;
352
353     const Blob<Dtype>& updated_weights = *updated_params[0];
354     const Blob<Dtype>& updated_bias = *updated_params[1];
355
356     Net<Dtype>& net = *this->solver_->net();
357     ASSERT_TRUE(net.has_layer("innerprod"));
358     const vector<shared_ptr<Blob<Dtype> > >& param_blobs =
359         net.layer_by_name("innerprod")->blobs();
360     ASSERT_EQ(2, param_blobs.size());
361     const Blob<Dtype>& solver_updated_weights = *param_blobs[0];
362     ASSERT_EQ(D, solver_updated_weights.count());
363     const double kPrecision = 1e-2;
364     const double kMinPrecision = 1e-7;
365     for (int i = 0; i < D; ++i) {
366       const Dtype expected_updated_weight = updated_weights.cpu_data()[i];
367       const Dtype solver_updated_weight = solver_updated_weights.cpu_data()[i];
368       const Dtype error_margin = std::max(kMinPrecision, kPrecision *
369           std::min(fabs(expected_updated_weight), fabs(solver_updated_weight)));
370       EXPECT_NEAR(expected_updated_weight, solver_updated_weight, error_margin);
371     }
372     const Blob<Dtype>& solver_updated_bias_blob = *param_blobs[1];
373     ASSERT_EQ(1, solver_updated_bias_blob.count());
374     const Dtype expected_updated_bias = updated_bias.cpu_data()[0];
375     const Dtype solver_updated_bias = solver_updated_bias_blob.cpu_data()[0];
376     const Dtype error_margin = std::max(kMinPrecision, kPrecision *
377           std::min(fabs(expected_updated_bias), fabs(solver_updated_bias)));
378     EXPECT_NEAR(expected_updated_bias, solver_updated_bias, error_margin);
379
380     // Check the solver's history -- should contain the previous update value.
381     if (solver_->type() == string("SGD")) {
382       const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
383       ASSERT_EQ(2, history.size());
384       for (int i = 0; i < D; ++i) {
385         const Dtype expected_history = updated_weights.cpu_diff()[i];
386         const Dtype solver_history = history[0]->cpu_data()[i];
387         const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision *
388             std::min(fabs(expected_history), fabs(solver_history)));
389         EXPECT_NEAR(expected_history, solver_history, error_margin_hist);
390       }
391       const Dtype expected_history = updated_bias.cpu_diff()[0];
392       const Dtype solver_history = history[1]->cpu_data()[0];
393       const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision *
394           std::min(fabs(expected_history), fabs(solver_history)));
395       EXPECT_NEAR(expected_history, solver_history, error_margin_hist);
396     }
397   }
398
399   void CheckAccumulation(const Dtype kLearningRate, const Dtype kWeightDecay,
400       const Dtype kMomentum, const int kNumIters, const int kIterSize) {
401     const double kPrecision = 1e-2;
402     const double kMinPrecision = 1e-7;
403     // Solve without accumulation and save parameters.
404     this->RunLeastSquaresSolver(kLearningRate, kWeightDecay, kMomentum,
405         kNumIters);
406     // Save parameters for comparison.
407     Net<Dtype>& net = *this->solver_->net();
408     const vector<shared_ptr<Blob<Dtype> > >& param_blobs =
409         net.layer_by_name("innerprod")->blobs();
410     vector<shared_ptr<Blob<Dtype> > > noaccum_params(param_blobs.size());
411     for (int i = 0; i < param_blobs.size(); ++i) {
412       noaccum_params[i].reset(new Blob<Dtype>());
413       noaccum_params[i]->CopyFrom(*param_blobs[i], false, true);
414     }
415     // Solve by equivalent accumulation of gradients over divided batches.
416     this->RunLeastSquaresSolver(kLearningRate, kWeightDecay, kMomentum,
417         kNumIters, kIterSize);
418     Net<Dtype>& net_accum = *this->solver_->net();
419     const vector<shared_ptr<Blob<Dtype> > >& accum_params =
420         net_accum.layer_by_name("innerprod")->blobs();
421     // Compare accumulated parameters against no accumulation standard.
422     const int D = this->channels_ * this->height_ * this->width_;
423     for (int i = 0; i < D; ++i) {
424       const Dtype expected_param = noaccum_params[0]->cpu_data()[i];
425       const Dtype accum_param = accum_params[0]->cpu_data()[i];
426       const Dtype error_margin = std::max(kMinPrecision, kPrecision *
427           std::min(fabs(expected_param), fabs(accum_param)));
428       EXPECT_NEAR(expected_param, accum_param, error_margin);
429     }
430     ASSERT_EQ(1, accum_params[1]->count());
431     const Dtype expected_bias = noaccum_params[1]->cpu_data()[0];
432     const Dtype accum_bias = accum_params[1]->cpu_data()[0];
433     const Dtype error_margin = std::max(kMinPrecision, kPrecision *
434         std::min(fabs(expected_bias), fabs(accum_bias)));
435     EXPECT_NEAR(expected_bias, accum_bias, error_margin);
436   }
437
438   // Test that the correct update is computed for a regularized least squares
439   // problem:
440   //
441   //            E = (1/(2n)) || X w - y ||^2 + (lambda / 2) || w ||^2
442   //   \nabla_w E = (1/n) (X^T X w - X^T y) + lambda * w
443   //
444   // X \in R^{n x (d+1)} (each example is a row, (d+1)th element is always 1)
445   // w \in R^{(d+1) x 1} ((d+1)th element is the bias)
446   // y \in R^{n x 1}
447   // lambda is weight_decay
448   //
449   // TestLeastSquaresUpdate works "inductively", assuming that the solver
450   // correctly updates the net K (= iter_to_check) times, then given the history
451   // from the Kth update, we compute the (K+1)th update and check that it
452   // matches the solver's (K+1)th update.
453   void TestLeastSquaresUpdate(const Dtype learning_rate = 1.0,
454       const Dtype weight_decay = 0.0, const Dtype momentum = 0.0,
455       const int iter_to_check = 0) {
456     const int kNum = num_;
457     const int kIterSize = 1;
458     // Test over all numbers of devices.
459     int available_devices = 1;
460 #ifndef CPU_ONLY
461     if (Caffe::mode() == Caffe::GPU) {
462       CUDA_CHECK(cudaGetDeviceCount(&available_devices));
463     }
464 #endif
465     for (int devices = 1; devices <= available_devices; ++devices) {
466       // Configure batch size for single / multi device equivalence.
467       // Constant data is needed for multi device as for accumulation.
468       num_ = kNum * devices;
469
470       // Initialize the solver and run K (= iter_to_check) solver iterations
471       // (on single device).
472       RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
473                             iter_to_check, kIterSize, 1);
474
475       // Compute the (K+1)th update using the analytic least squares gradient.
476       vector<shared_ptr<Blob<Dtype> > > updated_params;
477       ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum,
478           iter_to_check + 1, &updated_params);
479
480       // Reinitialize the solver and run K+1 solver iterations.
481       num_ = kNum;
482       RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
483           iter_to_check + 1, kIterSize, devices);
484
485       // Check that the solver's solution matches ours.
486       CheckLeastSquaresUpdate(updated_params);
487     }
488   }
489
490   void TestSnapshot(const Dtype learning_rate = 1.0,
491       const Dtype weight_decay = 0.0, const Dtype momentum = 0.0,
492       const int num_iters = 1) {
493     // Run the solver for num_iters * 2 iterations.
494     const int total_num_iters = num_iters * 2;
495     bool snapshot = false;
496     const int kIterSize = 1;
497     const int kDevices = 1;
498     RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
499         total_num_iters, kIterSize, kDevices, snapshot);
500
501     // Save the resulting param values.
502     vector<shared_ptr<Blob<Dtype> > > param_copies;
503     const vector<Blob<Dtype>*>& orig_params =
504         solver_->net()->learnable_params();
505     param_copies.resize(orig_params.size());
506     for (int i = 0; i < orig_params.size(); ++i) {
507       param_copies[i].reset(new Blob<Dtype>());
508       const bool kReshape = true;
509       for (int copy_diff = false; copy_diff <= true; ++copy_diff) {
510         param_copies[i]->CopyFrom(*orig_params[i], copy_diff, kReshape);
511       }
512     }
513
514     // Save the solver history
515     vector<shared_ptr<Blob<Dtype> > > history_copies;
516     const vector<shared_ptr<Blob<Dtype> > >& orig_history = solver_->history();
517     history_copies.resize(orig_history.size());
518     for (int i = 0; i < orig_history.size(); ++i) {
519       history_copies[i].reset(new Blob<Dtype>());
520       const bool kReshape = true;
521       for (int copy_diff = false; copy_diff <= true; ++copy_diff) {
522         history_copies[i]->CopyFrom(*orig_history[i], copy_diff, kReshape);
523       }
524     }
525
526     // Run the solver for num_iters iterations and snapshot.
527     snapshot = true;
528     string snapshot_name = RunLeastSquaresSolver(learning_rate, weight_decay,
529         momentum, num_iters, kIterSize, kDevices, snapshot);
530
531     // Reinitialize the solver and run for num_iters more iterations.
532     snapshot = false;
533     RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
534         total_num_iters, kIterSize, kDevices,
535         snapshot, snapshot_name.c_str());
536
537     // Check that params now match.
538     const vector<Blob<Dtype>*>& params = solver_->net()->learnable_params();
539     for (int i = 0; i < params.size(); ++i) {
540       for (int j = 0; j < params[i]->count(); ++j) {
541         EXPECT_EQ(param_copies[i]->cpu_data()[j], params[i]->cpu_data()[j])
542             << "param " << i << " data differed at dim " << j;
543         EXPECT_EQ(param_copies[i]->cpu_diff()[j], params[i]->cpu_diff()[j])
544             << "param " << i << " diff differed at dim " << j;
545       }
546     }
547
548     // Check that history now matches.
549     const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
550     for (int i = 0; i < history.size(); ++i) {
551       for (int j = 0; j < history[i]->count(); ++j) {
552         EXPECT_EQ(history_copies[i]->cpu_data()[j], history[i]->cpu_data()[j])
553             << "history blob " << i << " data differed at dim " << j;
554         EXPECT_EQ(history_copies[i]->cpu_diff()[j], history[i]->cpu_diff()[j])
555             << "history blob " << i << " diff differed at dim " << j;
556       }
557     }
558   }
559 };
560
561
562 template <typename TypeParam>
563 class SGDSolverTest : public GradientBasedSolverTest<TypeParam> {
564   typedef typename TypeParam::Dtype Dtype;
565
566  protected:
567   virtual void InitSolver(const SolverParameter& param) {
568     this->solver_.reset(new SGDSolver<Dtype>(param));
569   }
570 };
571
572 TYPED_TEST_CASE(SGDSolverTest, TestDtypesAndDevices);
573
574 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdate) {
575   this->TestLeastSquaresUpdate();
576 }
577
578 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateLROneHundredth) {
579   typedef typename TypeParam::Dtype Dtype;
580   const Dtype kLearningRate = 0.01;
581   this->TestLeastSquaresUpdate(kLearningRate);
582 }
583
584 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithWeightDecay) {
585   typedef typename TypeParam::Dtype Dtype;
586   const Dtype kLearningRate = 0.01;
587   const Dtype kWeightDecay = 0.5;
588   const Dtype kMomentum = 0;
589   const int kNumIters = 1;
590   for (int i = 0; i <= kNumIters; ++i) {
591     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
592   }
593 }
594
595 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithWeightDecayMultiIter) {
596   typedef typename TypeParam::Dtype Dtype;
597   const Dtype kLearningRate = 0.01;
598   const Dtype kWeightDecay = 0.5;
599   const Dtype kMomentum = 0;
600   const int kNumIters = 4;
601   for (int i = 0; i <= kNumIters; ++i) {
602     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
603   }
604 }
605
606 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithMomentum) {
607   typedef typename TypeParam::Dtype Dtype;
608   const Dtype kLearningRate = 0.01;
609   const Dtype kWeightDecay = 0;
610   const Dtype kMomentum = 0.5;
611   const int kNumIters = 1;
612   for (int i = 0; i <= kNumIters; ++i) {
613     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
614   }
615 }
616
617 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
618   typedef typename TypeParam::Dtype Dtype;
619   const Dtype kLearningRate = 0.01;
620   const Dtype kWeightDecay = 0;
621   const Dtype kMomentum = 0.5;
622   const int kNumIters = 4;
623   for (int i = 0; i <= kNumIters; ++i) {
624     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
625   }
626 }
627
628 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverything) {
629   typedef typename TypeParam::Dtype Dtype;
630   const Dtype kLearningRate = 0.01;
631   const Dtype kWeightDecay = 0.5;
632   const Dtype kMomentum = 0.5;
633   const int kNumIters = 4;
634   for (int i = 0; i <= kNumIters; ++i) {
635     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
636   }
637 }
638
639 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingShare) {
640   typedef typename TypeParam::Dtype Dtype;
641   const Dtype kLearningRate = 0.01;
642   const Dtype kWeightDecay = 0.5;
643   const Dtype kMomentum = 0.5;
644   const int kNumIters = 4;
645   this->share_ = true;
646   for (int i = 0; i <= kNumIters; ++i) {
647     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
648   }
649 }
650
651 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
652   typedef typename TypeParam::Dtype Dtype;
653   const Dtype kLearningRate = 0.01;
654   const Dtype kWeightDecay = 0.5;
655   const Dtype kMomentum = 0.9;
656   const int kNumIters = 4;
657   const int kIterSize = 2;
658   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
659       kIterSize);
660 }
661
662 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
663   typedef typename TypeParam::Dtype Dtype;
664   const Dtype kLearningRate = 0.01;
665   const Dtype kWeightDecay = 0.5;
666   const Dtype kMomentum = 0.9;
667   const int kNumIters = 4;
668   const int kIterSize = 2;
669   this->share_ = true;
670   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
671       kIterSize);
672 }
673
674 TYPED_TEST(SGDSolverTest, TestSnapshot) {
675   typedef typename TypeParam::Dtype Dtype;
676   const Dtype kLearningRate = 0.01;
677   const Dtype kWeightDecay = 0.5;
678   const Dtype kMomentum = 0.9;
679   const int kNumIters = 4;
680   for (int i = 1; i <= kNumIters; ++i) {
681     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
682   }
683 }
684
685 TYPED_TEST(SGDSolverTest, TestSnapshotShare) {
686   typedef typename TypeParam::Dtype Dtype;
687   const Dtype kLearningRate = 0.01;
688   const Dtype kWeightDecay = 0.5;
689   const Dtype kMomentum = 0.9;
690   const int kNumIters = 4;
691   this->share_ = true;
692   for (int i = 1; i <= kNumIters; ++i) {
693     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
694   }
695 }
696
697 TYPED_TEST(SGDSolverTest, TestSolverType) {
698   this->TestLeastSquaresUpdate();
699   EXPECT_NE(this->solver_->type(), string(""));
700   EXPECT_EQ(this->solver_->type(), this->solver_->param().type());
701 }
702
703 template <typename TypeParam>
704 class AdaGradSolverTest : public GradientBasedSolverTest<TypeParam> {
705   typedef typename TypeParam::Dtype Dtype;
706
707  protected:
708   virtual void InitSolver(const SolverParameter& param) {
709     this->solver_.reset(new AdaGradSolver<Dtype>(param));
710   }
711 };
712
713 TYPED_TEST_CASE(AdaGradSolverTest, TestDtypesAndDevices);
714
715 TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdate) {
716   this->TestLeastSquaresUpdate();
717 }
718
719 TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateLROneHundredth) {
720   typedef typename TypeParam::Dtype Dtype;
721   const Dtype kLearningRate = 0.01;
722   this->TestLeastSquaresUpdate(kLearningRate);
723 }
724
725 TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithWeightDecay) {
726   typedef typename TypeParam::Dtype Dtype;
727   const Dtype kLearningRate = 0.01;
728   const Dtype kWeightDecay = 0.5;
729   this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay);
730 }
731
732 TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithEverything) {
733   typedef typename TypeParam::Dtype Dtype;
734   const Dtype kLearningRate = 0.01;
735   const Dtype kWeightDecay = 0.5;
736   const Dtype kMomentum = 0;
737   const int kNumIters = 4;
738   for (int i = 0; i <= kNumIters; ++i) {
739     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
740   }
741 }
742
743 TYPED_TEST(AdaGradSolverTest,
744       TestAdaGradLeastSquaresUpdateWithEverythingShare) {
745   typedef typename TypeParam::Dtype Dtype;
746   const Dtype kLearningRate = 0.01;
747   const Dtype kWeightDecay = 0.5;
748   const Dtype kMomentum = 0;
749   const int kNumIters = 4;
750   this->share_ = true;
751   for (int i = 0; i <= kNumIters; ++i) {
752     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
753   }
754 }
755
756 TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
757   typedef typename TypeParam::Dtype Dtype;
758   const Dtype kLearningRate = 0.01;
759   const Dtype kWeightDecay = 0.5;
760   const Dtype kMomentum = 0;
761   const int kNumIters = 4;
762   const int kIterSize = 2;
763   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
764       kIterSize);
765 }
766
767 TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
768   typedef typename TypeParam::Dtype Dtype;
769   const Dtype kLearningRate = 0.01;
770   const Dtype kWeightDecay = 0.5;
771   const Dtype kMomentum = 0;
772   const int kNumIters = 4;
773   const int kIterSize = 2;
774   this->share_ = true;
775   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
776       kIterSize);
777 }
778
779 TYPED_TEST(AdaGradSolverTest, TestSnapshot) {
780   typedef typename TypeParam::Dtype Dtype;
781   const Dtype kLearningRate = 0.01;
782   const Dtype kWeightDecay = 0.5;
783   const Dtype kMomentum = 0;
784   const int kNumIters = 4;
785   for (int i = 1; i <= kNumIters; ++i) {
786     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
787   }
788 }
789
790 TYPED_TEST(AdaGradSolverTest, TestSnapshotShare) {
791   typedef typename TypeParam::Dtype Dtype;
792   const Dtype kLearningRate = 0.01;
793   const Dtype kWeightDecay = 0.5;
794   const Dtype kMomentum = 0;
795   const int kNumIters = 4;
796   this->share_ = true;
797   for (int i = 1; i <= kNumIters; ++i) {
798     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
799   }
800 }
801
802
803 template <typename TypeParam>
804 class NesterovSolverTest : public GradientBasedSolverTest<TypeParam> {
805   typedef typename TypeParam::Dtype Dtype;
806
807  protected:
808   virtual void InitSolver(const SolverParameter& param) {
809     this->solver_.reset(new NesterovSolver<Dtype>(param));
810   }
811 };
812
813 TYPED_TEST_CASE(NesterovSolverTest, TestDtypesAndDevices);
814
815 TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdate) {
816   this->TestLeastSquaresUpdate();
817 }
818
819 TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateLROneHundredth) {
820   typedef typename TypeParam::Dtype Dtype;
821   const Dtype kLearningRate = 0.01;
822   this->TestLeastSquaresUpdate(kLearningRate);
823 }
824
825 TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithWeightDecay) {
826   typedef typename TypeParam::Dtype Dtype;
827   const Dtype kLearningRate = 0.01;
828   const Dtype kWeightDecay = 0.5;
829   this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay);
830 }
831
832 TYPED_TEST(NesterovSolverTest,
833            TestNesterovLeastSquaresUpdateWithWeightDecayMultiIter) {
834   typedef typename TypeParam::Dtype Dtype;
835   const Dtype kLearningRate = 0.01;
836   const Dtype kWeightDecay = 0.5;
837   const Dtype kMomentum = 0;
838   const int kNumIters = 4;
839   for (int i = 0; i <= kNumIters; ++i) {
840     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
841   }
842 }
843
844 TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithMomentum) {
845   typedef typename TypeParam::Dtype Dtype;
846   const Dtype kLearningRate = 0.01;
847   const Dtype kWeightDecay = 0;
848   const Dtype kMomentum = 0.5;
849   const int kNumIters = 1;
850   for (int i = 0; i <= kNumIters; ++i) {
851     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
852   }
853 }
854
855 TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
856   typedef typename TypeParam::Dtype Dtype;
857   const Dtype kLearningRate = 0.01;
858   const Dtype kWeightDecay = 0;
859   const Dtype kMomentum = 0.5;
860   const int kNumIters = 4;
861   for (int i = 0; i <= kNumIters; ++i) {
862     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
863   }
864 }
865
866 TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithEverything) {
867   typedef typename TypeParam::Dtype Dtype;
868   const Dtype kLearningRate = 0.01;
869   const Dtype kWeightDecay = 0.5;
870   const Dtype kMomentum = 0.9;
871   const int kNumIters = 4;
872   for (int i = 0; i <= kNumIters; ++i) {
873     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
874   }
875 }
876
877 TYPED_TEST(NesterovSolverTest,
878            TestNesterovLeastSquaresUpdateWithEverythingShare) {
879   typedef typename TypeParam::Dtype Dtype;
880   const Dtype kLearningRate = 0.01;
881   const Dtype kWeightDecay = 0.5;
882   const Dtype kMomentum = 0.9;
883   const int kNumIters = 4;
884   this->share_ = true;
885   for (int i = 0; i <= kNumIters; ++i) {
886     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
887   }
888 }
889
890 TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
891   typedef typename TypeParam::Dtype Dtype;
892   const Dtype kLearningRate = 0.01;
893   const Dtype kWeightDecay = 0.5;
894   const Dtype kMomentum = 0.9;
895   const int kNumIters = 4;
896   const int kIterSize = 2;
897   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
898       kIterSize);
899 }
900
901 TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
902   typedef typename TypeParam::Dtype Dtype;
903   const Dtype kLearningRate = 0.01;
904   const Dtype kWeightDecay = 0.5;
905   const Dtype kMomentum = 0.9;
906   const int kNumIters = 4;
907   const int kIterSize = 2;
908   this->share_ = true;
909   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
910       kIterSize);
911 }
912
913 TYPED_TEST(NesterovSolverTest, TestSnapshot) {
914   typedef typename TypeParam::Dtype Dtype;
915   const Dtype kLearningRate = 0.01;
916   const Dtype kWeightDecay = 0.5;
917   const Dtype kMomentum = 0.9;
918   const int kNumIters = 4;
919   for (int i = 1; i <= kNumIters; ++i) {
920     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
921   }
922 }
923
924 TYPED_TEST(NesterovSolverTest, TestSnapshotShare) {
925   typedef typename TypeParam::Dtype Dtype;
926   const Dtype kLearningRate = 0.01;
927   const Dtype kWeightDecay = 0.5;
928   const Dtype kMomentum = 0.9;
929   const int kNumIters = 4;
930   this->share_ = true;
931   for (int i = 1; i <= kNumIters; ++i) {
932     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
933   }
934 }
935
936 template <typename TypeParam>
937 class AdaDeltaSolverTest : public GradientBasedSolverTest<TypeParam> {
938   typedef typename TypeParam::Dtype Dtype;
939
940  protected:
941   virtual void InitSolver(const SolverParameter& param) {
942     this->solver_.reset(new AdaDeltaSolver<Dtype>(param));
943   }
944 };
945
946 TYPED_TEST_CASE(AdaDeltaSolverTest, TestDtypesAndDevices);
947
948 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdate) {
949   typedef typename TypeParam::Dtype Dtype;
950   const Dtype kLearningRate = 0.1;
951   this->TestLeastSquaresUpdate(kLearningRate);
952 }
953
954 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithWeightDecay) {
955   typedef typename TypeParam::Dtype Dtype;
956   const Dtype kLearningRate = 0.1;
957   const Dtype kWeightDecay = 0.5;
958   const Dtype kMomentum = 0.95;
959   this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
960 }
961
962 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithHalfMomentum) {
963   typedef typename TypeParam::Dtype Dtype;
964   const Dtype kLearningRate = 0.1;
965   const Dtype kWeightDecay = 0.0;
966   const Dtype kMomentum = 0.5;
967   const int kNumIters = 1;
968   for (int i = 0; i <= kNumIters; ++i) {
969     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
970   }
971 }
972
973 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithMomentum) {
974   typedef typename TypeParam::Dtype Dtype;
975   const Dtype kLearningRate = 0.1;
976   const Dtype kWeightDecay = 0.0;
977   const Dtype kMomentum = 0.95;
978   const int kNumIters = 1;
979   for (int i = 0; i <= kNumIters; ++i) {
980     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
981   }
982 }
983
984 TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
985   typedef typename TypeParam::Dtype Dtype;
986   const Dtype kLearningRate = 0.1;
987   const Dtype kWeightDecay = 0.0;
988   const Dtype kMomentum = 0.95;
989   const int kNumIters = 4;
990   for (int i = 0; i <= kNumIters; ++i) {
991     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
992   }
993 }
994
995 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverything) {
996   typedef typename TypeParam::Dtype Dtype;
997   const Dtype kLearningRate = 0.1;
998   const Dtype kWeightDecay = 0.1;
999   const Dtype kMomentum = 0.95;
1000   const int kNumIters = 4;
1001   for (int i = 0; i <= kNumIters; ++i) {
1002     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1003   }
1004 }
1005
1006 TYPED_TEST(AdaDeltaSolverTest,
1007            TestAdaDeltaLeastSquaresUpdateWithEverythingShare) {
1008   typedef typename TypeParam::Dtype Dtype;
1009   const Dtype kLearningRate = 0.1;
1010   const Dtype kWeightDecay = 0.1;
1011   const Dtype kMomentum = 0.95;
1012   const int kNumIters = 4;
1013   this->share_ = true;
1014   for (int i = 0; i <= kNumIters; ++i) {
1015     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1016   }
1017 }
1018
1019 TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
1020   typedef typename TypeParam::Dtype Dtype;
1021   const Dtype kLearningRate = 0.1;
1022   const Dtype kWeightDecay = 0.1;
1023   const Dtype kMomentum = 0.95;
1024   const int kNumIters = 4;
1025   const int kIterSize = 2;
1026   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1027       kIterSize);
1028 }
1029
1030 TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
1031   typedef typename TypeParam::Dtype Dtype;
1032   const Dtype kLearningRate = 0.1;
1033   const Dtype kWeightDecay = 0.1;
1034   const Dtype kMomentum = 0.95;
1035   const int kNumIters = 4;
1036   const int kIterSize = 2;
1037   this->share_ = true;
1038   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1039       kIterSize);
1040 }
1041
1042 TYPED_TEST(AdaDeltaSolverTest, TestSnapshot) {
1043   typedef typename TypeParam::Dtype Dtype;
1044   const Dtype kLearningRate = 0.1;
1045   const Dtype kWeightDecay = 0.1;
1046   const Dtype kMomentum = 0.95;
1047   const int kNumIters = 4;
1048   for (int i = 1; i <= kNumIters; ++i) {
1049     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1050   }
1051 }
1052
1053 TYPED_TEST(AdaDeltaSolverTest, TestSnapshotShare) {
1054   typedef typename TypeParam::Dtype Dtype;
1055   const Dtype kLearningRate = 0.1;
1056   const Dtype kWeightDecay = 0.1;
1057   const Dtype kMomentum = 0.95;
1058   const int kNumIters = 4;
1059   this->share_ = true;
1060   for (int i = 1; i <= kNumIters; ++i) {
1061     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1062   }
1063 }
1064
1065 template <typename TypeParam>
1066 class AdamSolverTest : public GradientBasedSolverTest<TypeParam> {
1067   typedef typename TypeParam::Dtype Dtype;
1068
1069  protected:
1070   virtual void InitSolver(const SolverParameter& param) {
1071     SolverParameter new_param = param;
1072     const Dtype momentum = 0.9;
1073     new_param.set_momentum(momentum);
1074     const Dtype momentum2 = 0.999;
1075     new_param.set_momentum2(momentum2);
1076     this->solver_.reset(new AdamSolver<Dtype>(new_param));
1077   }
1078 };
1079
1080 TYPED_TEST_CASE(AdamSolverTest, TestDtypesAndDevices);
1081
1082 TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdate) {
1083   typedef typename TypeParam::Dtype Dtype;
1084   const Dtype kLearningRate = 0.01;
1085   const Dtype kWeightDecay = 0;
1086   const Dtype kMomentum = 0.9;
1087   this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
1088 }
1089
1090 TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdateWithWeightDecay) {
1091   typedef typename TypeParam::Dtype Dtype;
1092   const Dtype kLearningRate = 0.01;
1093   const Dtype kWeightDecay = 0.5;
1094   const Dtype kMomentum = 0.9;
1095   this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
1096 }
1097
1098 TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdateWithEverything) {
1099   typedef typename TypeParam::Dtype Dtype;
1100   const Dtype kLearningRate = 0.01;
1101   const Dtype kWeightDecay = 0.5;
1102   const Dtype kMomentum = 0.9;
1103   const int kNumIters = 4;
1104   for (int i = 0; i <= kNumIters; ++i) {
1105     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1106   }
1107 }
1108
1109 TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdateWithEverythingShare) {
1110   typedef typename TypeParam::Dtype Dtype;
1111   const Dtype kLearningRate = 0.01;
1112   const Dtype kWeightDecay = 0.5;
1113   const Dtype kMomentum = 0.9;
1114   const int kNumIters = 4;
1115   this->share_ = true;
1116   for (int i = 0; i <= kNumIters; ++i) {
1117     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1118   }
1119 }
1120
1121 TYPED_TEST(AdamSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
1122   typedef typename TypeParam::Dtype Dtype;
1123   const Dtype kLearningRate = 0.01;
1124   const Dtype kWeightDecay = 0.5;
1125   const Dtype kMomentum = 0.9;
1126   const int kNumIters = 4;
1127   const int kIterSize = 2;
1128   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1129       kIterSize);
1130 }
1131
1132 TYPED_TEST(AdamSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
1133   typedef typename TypeParam::Dtype Dtype;
1134   const Dtype kLearningRate = 0.01;
1135   const Dtype kWeightDecay = 0.5;
1136   const Dtype kMomentum = 0.9;
1137   const int kNumIters = 4;
1138   const int kIterSize = 2;
1139   this->share_ = true;
1140   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1141       kIterSize);
1142 }
1143
1144 TYPED_TEST(AdamSolverTest, TestSnapshot) {
1145   typedef typename TypeParam::Dtype Dtype;
1146   const Dtype kLearningRate = 0.01;
1147   const Dtype kWeightDecay = 0.5;
1148   const Dtype kMomentum = 0.9;
1149   const int kNumIters = 4;
1150   for (int i = 1; i <= kNumIters; ++i) {
1151     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1152   }
1153 }
1154
1155 TYPED_TEST(AdamSolverTest, TestSnapshotShare) {
1156   typedef typename TypeParam::Dtype Dtype;
1157   const Dtype kLearningRate = 0.01;
1158   const Dtype kWeightDecay = 0.5;
1159   const Dtype kMomentum = 0.9;
1160   const int kNumIters = 4;
1161   this->share_ = true;
1162   for (int i = 1; i <= kNumIters; ++i) {
1163     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1164   }
1165 }
1166
1167 template <typename TypeParam>
1168 class RMSPropSolverTest : public GradientBasedSolverTest<TypeParam> {
1169   typedef typename TypeParam::Dtype Dtype;
1170
1171  protected:
1172   virtual void InitSolver(const SolverParameter& param) {
1173     const Dtype rms_decay = 0.95;
1174     SolverParameter new_param = param;
1175     new_param.set_rms_decay(rms_decay);
1176     this->solver_.reset(new RMSPropSolver<Dtype>(new_param));
1177   }
1178 };
1179
1180 TYPED_TEST_CASE(RMSPropSolverTest, TestDtypesAndDevices);
1181
1182 TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithWeightDecay) {
1183   typedef typename TypeParam::Dtype Dtype;
1184   const Dtype kLearningRate = 1.0;
1185   const Dtype kWeightDecay = 0.5;
1186   this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay);
1187 }
1188
1189 TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithRmsDecay) {
1190   typedef typename TypeParam::Dtype Dtype;
1191   const Dtype kLearningRate = 0.01;
1192   const Dtype kWeightDecay = 0.0;
1193   const Dtype kMomentum = 0.0;
1194   const int kNumIters = 4;
1195   for (int i = 0; i <= kNumIters; ++i) {
1196     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1197   }
1198 }
1199
1200 TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithEverything) {
1201   typedef typename TypeParam::Dtype Dtype;
1202   const Dtype kLearningRate = 0.01;
1203   const Dtype kWeightDecay = 0.5;
1204   const Dtype kMomentum = 0.0;
1205   const int kNumIters = 4;
1206   for (int i = 0; i <= kNumIters; ++i) {
1207     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1208   }
1209 }
1210
1211 TYPED_TEST(RMSPropSolverTest,
1212       TestRMSPropLeastSquaresUpdateWithEverythingShare) {
1213   typedef typename TypeParam::Dtype Dtype;
1214   const Dtype kLearningRate = 0.01;
1215   const Dtype kWeightDecay = 0.5;
1216   const Dtype kMomentum = 0.0;
1217   const int kNumIters = 4;
1218   this->share_ = true;
1219   for (int i = 0; i <= kNumIters; ++i) {
1220     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1221   }
1222 }
1223
1224 TYPED_TEST(RMSPropSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
1225   typedef typename TypeParam::Dtype Dtype;
1226   const Dtype kLearningRate = 0.01;
1227   const Dtype kWeightDecay = 0.5;
1228   const Dtype kMomentum = 0.0;
1229   const int kNumIters = 4;
1230   const int kIterSize = 2;
1231   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1232       kIterSize);
1233 }
1234
1235 TYPED_TEST(RMSPropSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
1236   typedef typename TypeParam::Dtype Dtype;
1237   const Dtype kLearningRate = 0.01;
1238   const Dtype kWeightDecay = 0.5;
1239   const Dtype kMomentum = 0.0;
1240   const int kNumIters = 4;
1241   const int kIterSize = 2;
1242   this->share_ = true;
1243   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1244       kIterSize);
1245 }
1246
1247 TYPED_TEST(RMSPropSolverTest, TestSnapshot) {
1248   typedef typename TypeParam::Dtype Dtype;
1249   const Dtype kLearningRate = 0.01;
1250   const Dtype kWeightDecay = 0.5;
1251   const Dtype kMomentum = 0;
1252   const int kNumIters = 4;
1253   for (int i = 1; i <= kNumIters; ++i) {
1254     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1255   }
1256 }
1257
1258 TYPED_TEST(RMSPropSolverTest, TestSnapshotShare) {
1259   typedef typename TypeParam::Dtype Dtype;
1260   const Dtype kLearningRate = 0.01;
1261   const Dtype kWeightDecay = 0.5;
1262   const Dtype kMomentum = 0;
1263   const int kNumIters = 4;
1264   this->share_ = true;
1265   for (int i = 1; i <= kNumIters; ++i) {
1266     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1267   }
1268 }
1269
1270 }  // namespace caffe