975a8f0f88a9b811c2c997926add7d565315ed86
[platform/upstream/caffeonacl.git] / src / caffe / test / test_gradient_based_solver.cpp
1 #include <algorithm>
2 #include <string>
3 #include <utility>
4 #include <vector>
5
6 #include "google/protobuf/text_format.h"
7
8 #include "gtest/gtest.h"
9
10 #include "caffe/common.hpp"
11 #include "caffe/parallel.hpp"
12 #include "caffe/proto/caffe.pb.h"
13 #include "caffe/sgd_solvers.hpp"
14 #include "caffe/util/io.hpp"
15
16 #include "caffe/test/test_caffe_main.hpp"
17
18 using std::ostringstream;
19
20 namespace caffe {
21
22 template <typename TypeParam>
23 class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
24   typedef typename TypeParam::Dtype Dtype;
25
26  protected:
27   GradientBasedSolverTest() :
28       seed_(1701), num_(4), channels_(3), height_(10), width_(10),
29       share_(false) {
30         input_file_ = new string(
31         CMAKE_SOURCE_DIR "caffe/test/test_data/solver_data_list.txt" CMAKE_EXT);
32       }
33   ~GradientBasedSolverTest() {
34     delete input_file_;
35   }
36
37   string snapshot_prefix_;
38   shared_ptr<SGDSolver<Dtype> > solver_;
39   shared_ptr<P2PSync<Dtype> > sync_;
40   int seed_;
41   // Dimensions are determined by generate_sample_data.py
42   // TODO this is brittle and the hdf5 file should be checked instead.
43   int num_, channels_, height_, width_;
44   bool share_;
45   Dtype delta_;  // Stability constant for RMSProp, AdaGrad, AdaDelta and Adam
46
47   // Test data: check out generate_sample_data.py in the same directory.
48   string* input_file_;
49
50   virtual void InitSolver(const SolverParameter& param) = 0;
51
52   virtual void InitSolverFromProtoString(const string& proto) {
53     SolverParameter param;
54     CHECK(google::protobuf::TextFormat::ParseFromString(proto, &param));
55     // Set the solver_mode according to current Caffe::mode.
56     switch (Caffe::mode()) {
57       case Caffe::CPU:
58         param.set_solver_mode(SolverParameter_SolverMode_CPU);
59         break;
60       case Caffe::GPU:
61         param.set_solver_mode(SolverParameter_SolverMode_GPU);
62         break;
63       default:
64         LOG(FATAL) << "Unknown Caffe mode: " << Caffe::mode();
65     }
66     InitSolver(param);
67     delta_ = param.delta();
68   }
69
70   string RunLeastSquaresSolver(const Dtype learning_rate,
71       const Dtype weight_decay, const Dtype momentum, const int num_iters,
72       const int iter_size = 1, const int devices = 1,
73       const bool snapshot = false, const char* from_snapshot = NULL) {
74     ostringstream proto;
75     int device_id = 0;
76 #ifndef CPU_ONLY
77     if (Caffe::mode() == Caffe::GPU) {
78       CUDA_CHECK(cudaGetDevice(&device_id));
79     }
80 #endif
81     proto <<
82        "snapshot_after_train: " << snapshot << " "
83        "max_iter: " << num_iters << " "
84        "base_lr: " << learning_rate << " "
85        "lr_policy: 'fixed' "
86        "iter_size: " << iter_size << " "
87        "device_id: " << device_id << " "
88        "net_param { "
89        "  name: 'TestNetwork' "
90        "  layer { "
91        "    name: 'data' "
92        "    type: 'HDF5Data' "
93        "    hdf5_data_param { "
94        "      source: '" << *(this->input_file_) << "' "
95        "      batch_size: " << num_ / iter_size << " "
96        "    } "
97        "    top: 'data' "
98        "    top: 'targets' "
99        "  } ";
100     if (share_) {
101       proto <<
102          "  layer { "
103          "    name: 'slice' "
104          "    type: 'Slice' "
105          "    bottom: 'data' "
106          "    top: 'data1' "
107          "    top: 'data2' "
108          "    slice_param { "
109          "      axis: 0 "
110          "    } "
111          "  } ";
112     }
113     proto <<
114        "  layer { "
115        "    name: 'innerprod' "
116        "    type: 'InnerProduct' "
117        "    param { name: 'weights' } "
118        "    param { name: 'bias' } "
119        "    inner_product_param { "
120        "      num_output: 1 "
121        "      weight_filler { "
122        "        type: 'gaussian' "
123        "        std: 1.0 "
124        "      } "
125        "      bias_filler { "
126        "        type: 'gaussian' "
127        "        std: 1.0 "
128        "      } "
129        "    } "
130        "    bottom: '" << string(share_ ? "data1": "data") << "' "
131        "    top: '" << string(share_ ? "innerprod1": "innerprod") << "' "
132        "  } ";
133     if (share_) {
134       proto <<
135          "  layer { "
136          "    name: 'innerprod2' "
137          "    type: 'InnerProduct' "
138          "    param { name: 'weights' } "
139          "    param { name: 'bias' } "
140          "    inner_product_param { "
141          "      num_output: 1 "
142          "      weight_filler { "
143          "        type: 'gaussian' "
144          "        std: 1.0 "
145          "      } "
146          "      bias_filler { "
147          "        type: 'gaussian' "
148          "        std: 1.0 "
149          "      } "
150          "    } "
151          "    bottom: 'data2' "
152          "    top: 'innerprod2' "
153          "  } "
154          "  layer { "
155          "    name: 'concat' "
156          "    type: 'Concat' "
157          "    bottom: 'innerprod1' "
158          "    bottom: 'innerprod2' "
159          "    top: 'innerprod' "
160          "    concat_param { "
161          "      axis: 0 "
162          "    } "
163          "  } ";
164     }
165     proto <<
166        "  layer { "
167        "    name: 'loss' "
168        "    type: 'EuclideanLoss' "
169        "    bottom: 'innerprod' "
170        "    bottom: 'targets' "
171        "  } "
172        "} ";
173     if (weight_decay != 0) {
174       proto << "weight_decay: " << weight_decay << " ";
175     }
176     if (momentum != 0) {
177       proto << "momentum: " << momentum << " ";
178     }
179     MakeTempDir(&snapshot_prefix_);
180     proto << "snapshot_prefix: '" << snapshot_prefix_ << "/' ";
181     if (snapshot) {
182       proto << "snapshot: " << num_iters << " ";
183     }
184     Caffe::set_random_seed(this->seed_);
185     this->InitSolverFromProtoString(proto.str());
186     if (from_snapshot != NULL) {
187       this->solver_->Restore(from_snapshot);
188       for (int i = 0; i < this->solver_->iter(); ++i) {
189         this->solver_->net()->Forward();
190       }
191     }
192     if (devices == 1) {
193       this->solver_->Solve();
194     } else {
195       LOG(INFO) << "Multi-GPU test on " << devices << " devices";
196       vector<int> gpus;
197       // put current device at the beginning
198       int device_id = solver_->param().device_id();
199       gpus.push_back(device_id);
200       for (int i = 0; gpus.size() < devices; ++i) {
201         if (i != device_id)
202           gpus.push_back(i);
203       }
204       Caffe::set_solver_count(gpus.size());
205       this->sync_.reset(new P2PSync<Dtype>(
206           this->solver_, NULL, this->solver_->param()));
207       this->sync_->Run(gpus);
208       Caffe::set_solver_count(1);
209     }
210     if (snapshot) {
211       ostringstream resume_file;
212       resume_file << snapshot_prefix_ << "/_iter_" << num_iters
213                   << ".solverstate";
214       string resume_filename = resume_file.str();
215       return resume_filename;
216     }
217     return string();
218   }
219
220   // Compute an update value given the current state of the train net,
221   // using the analytical formula for the least squares gradient.
222   // updated_params will store the updated weight and bias results,
223   // using the blobs' diffs to hold the update values themselves.
224   void ComputeLeastSquaresUpdate(const Dtype learning_rate,
225       const Dtype weight_decay, const Dtype momentum, const int num_iters,
226       vector<shared_ptr<Blob<Dtype> > >* updated_params) {
227     const int N = num_;
228     const int D = channels_ * height_ * width_;
229
230     // Run a forward pass, and manually compute the update values from the
231     // result.
232     Net<Dtype>& net = *this->solver_->net();
233     net.Forward();
234     ASSERT_TRUE(net.has_blob("data"));
235     const Blob<Dtype>& data = *net.blob_by_name("data");
236     ASSERT_TRUE(net.has_blob("targets"));
237     const Blob<Dtype>& targets = *net.blob_by_name("targets");
238     ASSERT_TRUE(net.has_layer("innerprod"));
239     const vector<shared_ptr<Blob<Dtype> > >& param_blobs =
240         net.layer_by_name("innerprod")->blobs();
241     const int num_param_blobs = 2;
242     ASSERT_EQ(num_param_blobs, param_blobs.size());
243     const Blob<Dtype>& weights = *param_blobs[0];
244     const Blob<Dtype>& bias = *param_blobs[1];
245     ASSERT_EQ(D * N, data.count());
246     ASSERT_EQ(N, targets.count());
247     ASSERT_EQ(D, weights.count());
248     ASSERT_EQ(1, bias.count());
249
250     updated_params->clear();
251     updated_params->resize(num_param_blobs);
252     for (int i = 0; i < num_param_blobs; ++i) {
253       (*updated_params)[i].reset(new Blob<Dtype>());
254     }
255     Blob<Dtype>& updated_weights = *(*updated_params)[0];
256     updated_weights.ReshapeLike(weights);
257     Blob<Dtype>& updated_bias = *(*updated_params)[1];
258     updated_bias.ReshapeLike(bias);
259
260     for (int i = 0; i <= D; ++i) {
261       // Compute the derivative with respect to the ith weight (i.e., the ith
262       // element of the gradient).
263       Dtype grad = 0;
264       for (int j = 0; j <= D; ++j) {
265         // Compute element (i, j) of X^T * X.
266         Dtype element = 0;
267         for (int k = 0; k < N; ++k) {
268           // (i, k) in X^T (== (k, i) in X) times (k, j) in X.
269           const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i];
270           const Dtype element_j = (j == D) ? 1 : data.cpu_data()[k * D + j];
271           element += element_i * element_j;
272         }
273         if (j == D) {
274           grad += element * bias.cpu_data()[0];
275         } else {
276           grad += element * weights.cpu_data()[j];
277         }
278       }
279       for (int k = 0; k < N; ++k) {
280         const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i];
281         grad -= element_i * targets.cpu_data()[k];
282       }
283       // Scale the gradient over the N samples.
284       grad /= N;
285       // Add the weight decay to the gradient.
286       grad += weight_decay *
287           ((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]);
288       // Finally, compute update.
289       const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
290       if (solver_->type() != string("AdaDelta")
291           && solver_->type() != string("Adam")) {
292         ASSERT_EQ(2, history.size());  // 1 blob for weights, 1 for bias
293       } else {
294         ASSERT_EQ(4, history.size());  // additional blobs for update history
295       }
296       Dtype update_value = learning_rate * grad;
297       const Dtype history_value = (i == D) ?
298             history[1]->cpu_data()[0] : history[0]->cpu_data()[i];
299       const Dtype temp = momentum * history_value;
300       if (solver_->type() == string("SGD")) {
301         update_value += temp;
302       } else if (solver_->type() == string("Nesterov")) {
303         update_value += temp;
304         // step back then over-step
305         update_value = (1 + momentum) * update_value - temp;
306       } else if (solver_->type() == string("AdaGrad")) {
307         update_value /= std::sqrt(history_value + grad * grad) + delta_;
308       } else if (solver_->type() == string("RMSProp")) {
309         const Dtype rms_decay = 0.95;
310         update_value /= std::sqrt(rms_decay*history_value
311             + grad * grad * (1 - rms_decay)) + delta_;
312       } else if (solver_->type() == string("AdaDelta")) {
313         const Dtype update_history_value = (i == D) ?
314             history[1 + num_param_blobs]->cpu_data()[0] :
315             history[0 + num_param_blobs]->cpu_data()[i];
316         const Dtype weighted_gradient_average =
317             momentum * history_value + (1 - momentum) * (grad * grad);
318         update_value = grad * std::sqrt((update_history_value + delta_) /
319             (weighted_gradient_average + delta_)) * learning_rate;
320         // not actually needed, just here for illustrative purposes
321         // const Dtype weighted_update_average =
322         //   momentum * update_history_value + (1 - momentum) * (update_value);
323       } else if (solver_->type() == string("Adam")) {
324         const Dtype momentum2 = 0.999;
325         const Dtype m = history_value;
326         const Dtype v = (i == D) ?
327             history[1 + num_param_blobs]->cpu_data()[0] :
328             history[0 + num_param_blobs]->cpu_data()[i];
329         const Dtype val_m = (1 - momentum) * grad + momentum * m;
330         const Dtype val_v = (1 - momentum2) * grad * grad + momentum2 * v;
331         Dtype alpha_t = learning_rate *
332             std::sqrt(Dtype(1) - pow(momentum2, num_iters)) /
333             (Dtype(1.) - pow(momentum, num_iters));
334         update_value = alpha_t * val_m / (std::sqrt(val_v) + delta_);
335       } else {
336         LOG(FATAL) << "Unknown solver type: " << solver_->type();
337       }
338       if (i == D) {
339         updated_bias.mutable_cpu_diff()[0] = update_value;
340         updated_bias.mutable_cpu_data()[0] = bias.cpu_data()[0] - update_value;
341       } else {
342         updated_weights.mutable_cpu_diff()[i] = update_value;
343         updated_weights.mutable_cpu_data()[i] =
344             weights.cpu_data()[i] - update_value;
345       }
346     }
347   }
348
349   void CheckLeastSquaresUpdate(
350       const vector<shared_ptr<Blob<Dtype> > >& updated_params) {
351     const int D = channels_ * height_ * width_;
352
353     const Blob<Dtype>& updated_weights = *updated_params[0];
354     const Blob<Dtype>& updated_bias = *updated_params[1];
355
356     Net<Dtype>& net = *this->solver_->net();
357     ASSERT_TRUE(net.has_layer("innerprod"));
358     const vector<shared_ptr<Blob<Dtype> > >& param_blobs =
359         net.layer_by_name("innerprod")->blobs();
360     ASSERT_EQ(2, param_blobs.size());
361     const Blob<Dtype>& solver_updated_weights = *param_blobs[0];
362     ASSERT_EQ(D, solver_updated_weights.count());
363     const double kPrecision = 1e-2;
364     const double kMinPrecision = 1e-7;
365     for (int i = 0; i < D; ++i) {
366       const Dtype expected_updated_weight = updated_weights.cpu_data()[i];
367       const Dtype solver_updated_weight = solver_updated_weights.cpu_data()[i];
368       const Dtype error_margin = std::max(kMinPrecision, kPrecision *
369           std::min(fabs(expected_updated_weight), fabs(solver_updated_weight)));
370       EXPECT_NEAR(expected_updated_weight, solver_updated_weight, error_margin);
371     }
372     const Blob<Dtype>& solver_updated_bias_blob = *param_blobs[1];
373     ASSERT_EQ(1, solver_updated_bias_blob.count());
374     const Dtype expected_updated_bias = updated_bias.cpu_data()[0];
375     const Dtype solver_updated_bias = solver_updated_bias_blob.cpu_data()[0];
376     const Dtype error_margin = std::max(kMinPrecision, kPrecision *
377           std::min(fabs(expected_updated_bias), fabs(solver_updated_bias)));
378     EXPECT_NEAR(expected_updated_bias, solver_updated_bias, error_margin);
379
380     // Check the solver's history -- should contain the previous update value.
381     if (solver_->type() == string("SGD")) {
382       const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
383       ASSERT_EQ(2, history.size());
384       for (int i = 0; i < D; ++i) {
385         const Dtype expected_history = updated_weights.cpu_diff()[i];
386         const Dtype solver_history = history[0]->cpu_data()[i];
387         const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision *
388             std::min(fabs(expected_history), fabs(solver_history)));
389         EXPECT_NEAR(expected_history, solver_history, error_margin_hist);
390       }
391       const Dtype expected_history = updated_bias.cpu_diff()[0];
392       const Dtype solver_history = history[1]->cpu_data()[0];
393       const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision *
394           std::min(fabs(expected_history), fabs(solver_history)));
395       EXPECT_NEAR(expected_history, solver_history, error_margin_hist);
396     }
397   }
398
399   void CheckAccumulation(const Dtype kLearningRate, const Dtype kWeightDecay,
400       const Dtype kMomentum, const int kNumIters, const int kIterSize) {
401     const double kPrecision = 1e-2;
402     const double kMinPrecision = 1e-7;
403     // Solve without accumulation and save parameters.
404     this->RunLeastSquaresSolver(kLearningRate, kWeightDecay, kMomentum,
405         kNumIters);
406     // Save parameters for comparison.
407     Net<Dtype>& net = *this->solver_->net();
408     const vector<shared_ptr<Blob<Dtype> > >& param_blobs =
409         net.layer_by_name("innerprod")->blobs();
410     vector<shared_ptr<Blob<Dtype> > > noaccum_params(param_blobs.size());
411     for (int i = 0; i < param_blobs.size(); ++i) {
412       noaccum_params[i].reset(new Blob<Dtype>());
413       noaccum_params[i]->CopyFrom(*param_blobs[i], false, true);
414     }
415     // Solve by equivalent accumulation of gradients over divided batches.
416     this->RunLeastSquaresSolver(kLearningRate, kWeightDecay, kMomentum,
417         kNumIters, kIterSize);
418     Net<Dtype>& net_accum = *this->solver_->net();
419     const vector<shared_ptr<Blob<Dtype> > >& accum_params =
420         net_accum.layer_by_name("innerprod")->blobs();
421     // Compare accumulated parameters against no accumulation standard.
422     const int D = this->channels_ * this->height_ * this->width_;
423     for (int i = 0; i < D; ++i) {
424       const Dtype expected_param = noaccum_params[0]->cpu_data()[i];
425       const Dtype accum_param = accum_params[0]->cpu_data()[i];
426       const Dtype error_margin = std::max(kMinPrecision, kPrecision *
427           std::min(fabs(expected_param), fabs(accum_param)));
428       EXPECT_NEAR(expected_param, accum_param, error_margin);
429     }
430     ASSERT_EQ(1, accum_params[1]->count());
431     const Dtype expected_bias = noaccum_params[1]->cpu_data()[0];
432     const Dtype accum_bias = accum_params[1]->cpu_data()[0];
433     const Dtype error_margin = std::max(kMinPrecision, kPrecision *
434         std::min(fabs(expected_bias), fabs(accum_bias)));
435     EXPECT_NEAR(expected_bias, accum_bias, error_margin);
436   }
437
438   // Test that the correct update is computed for a regularized least squares
439   // problem:
440   //
441   //            E = (1/(2n)) || X w - y ||^2 + (lambda / 2) || w ||^2
442   //   \nabla_w E = (1/n) (X^T X w - X^T y) + lambda * w
443   //
444   // X \in R^{n x (d+1)} (each example is a row, (d+1)th element is always 1)
445   // w \in R^{(d+1) x 1} ((d+1)th element is the bias)
446   // y \in R^{n x 1}
447   // lambda is weight_decay
448   //
449   // TestLeastSquaresUpdate works "inductively", assuming that the solver
450   // correctly updates the net K (= iter_to_check) times, then given the history
451   // from the Kth update, we compute the (K+1)th update and check that it
452   // matches the solver's (K+1)th update.
453   void TestLeastSquaresUpdate(const Dtype learning_rate = 1.0,
454       const Dtype weight_decay = 0.0, const Dtype momentum = 0.0,
455       const int iter_to_check = 0) {
456     const int kNum = num_;
457     const int kIterSize = 1;
458     // Test over all numbers of devices.
459     int available_devices = 1;
460 #ifndef CPU_ONLY
461     if (Caffe::mode() == Caffe::GPU) {
462       CUDA_CHECK(cudaGetDeviceCount(&available_devices));
463     }
464 #endif
465     for (int devices = 1; devices <= available_devices; ++devices) {
466       // Configure batch size for single / multi device equivalence.
467       // Constant data is needed for multi device as for accumulation.
468       num_ = kNum * devices;
469
470       // Initialize the solver and run K (= iter_to_check) solver iterations
471       // (on single device).
472       RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
473                             iter_to_check, kIterSize, 1);
474
475       // Compute the (K+1)th update using the analytic least squares gradient.
476       vector<shared_ptr<Blob<Dtype> > > updated_params;
477       ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum,
478           iter_to_check + 1, &updated_params);
479
480       // Reinitialize the solver and run K+1 solver iterations.
481       num_ = kNum;
482       RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
483           iter_to_check + 1, kIterSize, devices);
484
485       // Check that the solver's solution matches ours.
486       CheckLeastSquaresUpdate(updated_params);
487     }
488   }
489
490   void TestSnapshot(const Dtype learning_rate = 1.0,
491       const Dtype weight_decay = 0.0, const Dtype momentum = 0.0,
492       const int num_iters = 1) {
493     // Run the solver for num_iters * 2 iterations.
494     const int total_num_iters = num_iters * 2;
495     bool snapshot = false;
496     const int kIterSize = 1;
497     const int kDevices = 1;
498     RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
499         total_num_iters, kIterSize, kDevices, snapshot);
500
501     // Save the resulting param values.
502     vector<shared_ptr<Blob<Dtype> > > param_copies;
503     const vector<Blob<Dtype>*>& orig_params =
504         solver_->net()->learnable_params();
505     param_copies.resize(orig_params.size());
506     for (int i = 0; i < orig_params.size(); ++i) {
507       param_copies[i].reset(new Blob<Dtype>());
508       const bool kReshape = true;
509       for (int copy_diff = false; copy_diff <= true; ++copy_diff) {
510         param_copies[i]->CopyFrom(*orig_params[i], copy_diff, kReshape);
511       }
512     }
513
514     // Save the solver history
515     vector<shared_ptr<Blob<Dtype> > > history_copies;
516     const vector<shared_ptr<Blob<Dtype> > >& orig_history = solver_->history();
517     history_copies.resize(orig_history.size());
518     for (int i = 0; i < orig_history.size(); ++i) {
519       history_copies[i].reset(new Blob<Dtype>());
520       const bool kReshape = true;
521       for (int copy_diff = false; copy_diff <= true; ++copy_diff) {
522         history_copies[i]->CopyFrom(*orig_history[i], copy_diff, kReshape);
523       }
524     }
525
526     // Run the solver for num_iters iterations and snapshot.
527     snapshot = true;
528     string snapshot_name = RunLeastSquaresSolver(learning_rate, weight_decay,
529         momentum, num_iters, kIterSize, kDevices, snapshot);
530
531     // Reinitialize the solver and run for num_iters more iterations.
532     snapshot = false;
533     RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
534         total_num_iters, kIterSize, kDevices,
535         snapshot, snapshot_name.c_str());
536
537     // Check that params now match.
538     const vector<Blob<Dtype>*>& params = solver_->net()->learnable_params();
539     for (int i = 0; i < params.size(); ++i) {
540       for (int j = 0; j < params[i]->count(); ++j) {
541         EXPECT_EQ(param_copies[i]->cpu_data()[j], params[i]->cpu_data()[j])
542             << "param " << i << " data differed at dim " << j;
543         EXPECT_EQ(param_copies[i]->cpu_diff()[j], params[i]->cpu_diff()[j])
544             << "param " << i << " diff differed at dim " << j;
545       }
546     }
547
548     // Check that history now matches.
549     const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
550     for (int i = 0; i < history.size(); ++i) {
551       for (int j = 0; j < history[i]->count(); ++j) {
552         EXPECT_EQ(history_copies[i]->cpu_data()[j], history[i]->cpu_data()[j])
553             << "history blob " << i << " data differed at dim " << j;
554         EXPECT_EQ(history_copies[i]->cpu_diff()[j], history[i]->cpu_diff()[j])
555             << "history blob " << i << " diff differed at dim " << j;
556       }
557     }
558   }
559 };
560
561
562 template <typename TypeParam>
563 class SGDSolverTest : public GradientBasedSolverTest<TypeParam> {
564   typedef typename TypeParam::Dtype Dtype;
565
566  protected:
567   virtual void InitSolver(const SolverParameter& param) {
568     this->solver_.reset(new SGDSolver<Dtype>(param));
569   }
570 };
571
572 TYPED_TEST_CASE(SGDSolverTest, TestDtypesAndDevices);
573
574 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdate) {
575   this->TestLeastSquaresUpdate();
576 }
577
578 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateLROneHundredth) {
579   typedef typename TypeParam::Dtype Dtype;
580   const Dtype kLearningRate = 0.01;
581   this->TestLeastSquaresUpdate(kLearningRate);
582 }
583
584 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithWeightDecay) {
585   typedef typename TypeParam::Dtype Dtype;
586   const Dtype kLearningRate = 0.01;
587   const Dtype kWeightDecay = 0.5;
588   const Dtype kMomentum = 0;
589   const int kNumIters = 1;
590   for (int i = 0; i <= kNumIters; ++i) {
591     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
592   }
593 }
594
595 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithWeightDecayMultiIter) {
596   typedef typename TypeParam::Dtype Dtype;
597   const Dtype kLearningRate = 0.01;
598   const Dtype kWeightDecay = 0.5;
599   const Dtype kMomentum = 0;
600   const int kNumIters = 4;
601   for (int i = 0; i <= kNumIters; ++i) {
602     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
603   }
604 }
605
606 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithMomentum) {
607   typedef typename TypeParam::Dtype Dtype;
608   const Dtype kLearningRate = 0.01;
609   const Dtype kWeightDecay = 0;
610   const Dtype kMomentum = 0.5;
611   const int kNumIters = 1;
612   for (int i = 0; i <= kNumIters; ++i) {
613     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
614   }
615 }
616
617 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
618   typedef typename TypeParam::Dtype Dtype;
619   const Dtype kLearningRate = 0.01;
620   const Dtype kWeightDecay = 0;
621   const Dtype kMomentum = 0.5;
622   const int kNumIters = 4;
623   for (int i = 0; i <= kNumIters; ++i) {
624     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
625   }
626 }
627
628 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverything) {
629   typedef typename TypeParam::Dtype Dtype;
630   const Dtype kLearningRate = 0.01;
631   const Dtype kWeightDecay = 0.5;
632   const Dtype kMomentum = 0.5;
633   const int kNumIters = 4;
634   for (int i = 0; i <= kNumIters; ++i) {
635     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
636   }
637 }
638
639 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingShare) {
640   typedef typename TypeParam::Dtype Dtype;
641   const Dtype kLearningRate = 0.01;
642   const Dtype kWeightDecay = 0.5;
643   const Dtype kMomentum = 0.5;
644   const int kNumIters = 4;
645   this->share_ = true;
646   for (int i = 0; i <= kNumIters; ++i) {
647     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
648   }
649 }
650
651 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
652   typedef typename TypeParam::Dtype Dtype;
653   const Dtype kLearningRate = 0.01;
654   const Dtype kWeightDecay = 0.5;
655   const Dtype kMomentum = 0.9;
656   const int kNumIters = 4;
657   const int kIterSize = 2;
658   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
659       kIterSize);
660 }
661
662 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
663   typedef typename TypeParam::Dtype Dtype;
664   const Dtype kLearningRate = 0.01;
665   const Dtype kWeightDecay = 0.5;
666   const Dtype kMomentum = 0.9;
667   const int kNumIters = 4;
668   const int kIterSize = 2;
669   this->share_ = true;
670   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
671       kIterSize);
672 }
673
674 TYPED_TEST(SGDSolverTest, TestSnapshot) {
675   typedef typename TypeParam::Dtype Dtype;
676   const Dtype kLearningRate = 0.01;
677   const Dtype kWeightDecay = 0.5;
678   const Dtype kMomentum = 0.9;
679   const int kNumIters = 4;
680   for (int i = 1; i <= kNumIters; ++i) {
681     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
682   }
683 }
684
685 TYPED_TEST(SGDSolverTest, TestSnapshotShare) {
686   typedef typename TypeParam::Dtype Dtype;
687   const Dtype kLearningRate = 0.01;
688   const Dtype kWeightDecay = 0.5;
689   const Dtype kMomentum = 0.9;
690   const int kNumIters = 4;
691   this->share_ = true;
692   for (int i = 1; i <= kNumIters; ++i) {
693     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
694   }
695 }
696
697
698 template <typename TypeParam>
699 class AdaGradSolverTest : public GradientBasedSolverTest<TypeParam> {
700   typedef typename TypeParam::Dtype Dtype;
701
702  protected:
703   virtual void InitSolver(const SolverParameter& param) {
704     this->solver_.reset(new AdaGradSolver<Dtype>(param));
705   }
706 };
707
708 TYPED_TEST_CASE(AdaGradSolverTest, TestDtypesAndDevices);
709
710 TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdate) {
711   this->TestLeastSquaresUpdate();
712 }
713
714 TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateLROneHundredth) {
715   typedef typename TypeParam::Dtype Dtype;
716   const Dtype kLearningRate = 0.01;
717   this->TestLeastSquaresUpdate(kLearningRate);
718 }
719
720 TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithWeightDecay) {
721   typedef typename TypeParam::Dtype Dtype;
722   const Dtype kLearningRate = 0.01;
723   const Dtype kWeightDecay = 0.5;
724   this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay);
725 }
726
727 TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithEverything) {
728   typedef typename TypeParam::Dtype Dtype;
729   const Dtype kLearningRate = 0.01;
730   const Dtype kWeightDecay = 0.5;
731   const Dtype kMomentum = 0;
732   const int kNumIters = 4;
733   for (int i = 0; i <= kNumIters; ++i) {
734     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
735   }
736 }
737
738 TYPED_TEST(AdaGradSolverTest,
739       TestAdaGradLeastSquaresUpdateWithEverythingShare) {
740   typedef typename TypeParam::Dtype Dtype;
741   const Dtype kLearningRate = 0.01;
742   const Dtype kWeightDecay = 0.5;
743   const Dtype kMomentum = 0;
744   const int kNumIters = 4;
745   this->share_ = true;
746   for (int i = 0; i <= kNumIters; ++i) {
747     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
748   }
749 }
750
751 TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
752   typedef typename TypeParam::Dtype Dtype;
753   const Dtype kLearningRate = 0.01;
754   const Dtype kWeightDecay = 0.5;
755   const Dtype kMomentum = 0;
756   const int kNumIters = 4;
757   const int kIterSize = 2;
758   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
759       kIterSize);
760 }
761
762 TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
763   typedef typename TypeParam::Dtype Dtype;
764   const Dtype kLearningRate = 0.01;
765   const Dtype kWeightDecay = 0.5;
766   const Dtype kMomentum = 0;
767   const int kNumIters = 4;
768   const int kIterSize = 2;
769   this->share_ = true;
770   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
771       kIterSize);
772 }
773
774 TYPED_TEST(AdaGradSolverTest, TestSnapshot) {
775   typedef typename TypeParam::Dtype Dtype;
776   const Dtype kLearningRate = 0.01;
777   const Dtype kWeightDecay = 0.5;
778   const Dtype kMomentum = 0;
779   const int kNumIters = 4;
780   for (int i = 1; i <= kNumIters; ++i) {
781     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
782   }
783 }
784
785 TYPED_TEST(AdaGradSolverTest, TestSnapshotShare) {
786   typedef typename TypeParam::Dtype Dtype;
787   const Dtype kLearningRate = 0.01;
788   const Dtype kWeightDecay = 0.5;
789   const Dtype kMomentum = 0;
790   const int kNumIters = 4;
791   this->share_ = true;
792   for (int i = 1; i <= kNumIters; ++i) {
793     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
794   }
795 }
796
797
798 template <typename TypeParam>
799 class NesterovSolverTest : public GradientBasedSolverTest<TypeParam> {
800   typedef typename TypeParam::Dtype Dtype;
801
802  protected:
803   virtual void InitSolver(const SolverParameter& param) {
804     this->solver_.reset(new NesterovSolver<Dtype>(param));
805   }
806 };
807
808 TYPED_TEST_CASE(NesterovSolverTest, TestDtypesAndDevices);
809
810 TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdate) {
811   this->TestLeastSquaresUpdate();
812 }
813
814 TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateLROneHundredth) {
815   typedef typename TypeParam::Dtype Dtype;
816   const Dtype kLearningRate = 0.01;
817   this->TestLeastSquaresUpdate(kLearningRate);
818 }
819
820 TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithWeightDecay) {
821   typedef typename TypeParam::Dtype Dtype;
822   const Dtype kLearningRate = 0.01;
823   const Dtype kWeightDecay = 0.5;
824   this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay);
825 }
826
827 TYPED_TEST(NesterovSolverTest,
828            TestNesterovLeastSquaresUpdateWithWeightDecayMultiIter) {
829   typedef typename TypeParam::Dtype Dtype;
830   const Dtype kLearningRate = 0.01;
831   const Dtype kWeightDecay = 0.5;
832   const Dtype kMomentum = 0;
833   const int kNumIters = 4;
834   for (int i = 0; i <= kNumIters; ++i) {
835     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
836   }
837 }
838
839 TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithMomentum) {
840   typedef typename TypeParam::Dtype Dtype;
841   const Dtype kLearningRate = 0.01;
842   const Dtype kWeightDecay = 0;
843   const Dtype kMomentum = 0.5;
844   const int kNumIters = 1;
845   for (int i = 0; i <= kNumIters; ++i) {
846     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
847   }
848 }
849
850 TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
851   typedef typename TypeParam::Dtype Dtype;
852   const Dtype kLearningRate = 0.01;
853   const Dtype kWeightDecay = 0;
854   const Dtype kMomentum = 0.5;
855   const int kNumIters = 4;
856   for (int i = 0; i <= kNumIters; ++i) {
857     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
858   }
859 }
860
861 TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithEverything) {
862   typedef typename TypeParam::Dtype Dtype;
863   const Dtype kLearningRate = 0.01;
864   const Dtype kWeightDecay = 0.5;
865   const Dtype kMomentum = 0.9;
866   const int kNumIters = 4;
867   for (int i = 0; i <= kNumIters; ++i) {
868     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
869   }
870 }
871
872 TYPED_TEST(NesterovSolverTest,
873            TestNesterovLeastSquaresUpdateWithEverythingShare) {
874   typedef typename TypeParam::Dtype Dtype;
875   const Dtype kLearningRate = 0.01;
876   const Dtype kWeightDecay = 0.5;
877   const Dtype kMomentum = 0.9;
878   const int kNumIters = 4;
879   this->share_ = true;
880   for (int i = 0; i <= kNumIters; ++i) {
881     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
882   }
883 }
884
885 TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
886   typedef typename TypeParam::Dtype Dtype;
887   const Dtype kLearningRate = 0.01;
888   const Dtype kWeightDecay = 0.5;
889   const Dtype kMomentum = 0.9;
890   const int kNumIters = 4;
891   const int kIterSize = 2;
892   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
893       kIterSize);
894 }
895
896 TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
897   typedef typename TypeParam::Dtype Dtype;
898   const Dtype kLearningRate = 0.01;
899   const Dtype kWeightDecay = 0.5;
900   const Dtype kMomentum = 0.9;
901   const int kNumIters = 4;
902   const int kIterSize = 2;
903   this->share_ = true;
904   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
905       kIterSize);
906 }
907
908 TYPED_TEST(NesterovSolverTest, TestSnapshot) {
909   typedef typename TypeParam::Dtype Dtype;
910   const Dtype kLearningRate = 0.01;
911   const Dtype kWeightDecay = 0.5;
912   const Dtype kMomentum = 0.9;
913   const int kNumIters = 4;
914   for (int i = 1; i <= kNumIters; ++i) {
915     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
916   }
917 }
918
919 TYPED_TEST(NesterovSolverTest, TestSnapshotShare) {
920   typedef typename TypeParam::Dtype Dtype;
921   const Dtype kLearningRate = 0.01;
922   const Dtype kWeightDecay = 0.5;
923   const Dtype kMomentum = 0.9;
924   const int kNumIters = 4;
925   this->share_ = true;
926   for (int i = 1; i <= kNumIters; ++i) {
927     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
928   }
929 }
930
931 template <typename TypeParam>
932 class AdaDeltaSolverTest : public GradientBasedSolverTest<TypeParam> {
933   typedef typename TypeParam::Dtype Dtype;
934
935  protected:
936   virtual void InitSolver(const SolverParameter& param) {
937     this->solver_.reset(new AdaDeltaSolver<Dtype>(param));
938   }
939 };
940
941 TYPED_TEST_CASE(AdaDeltaSolverTest, TestDtypesAndDevices);
942
943 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdate) {
944   typedef typename TypeParam::Dtype Dtype;
945   const Dtype kLearningRate = 0.1;
946   this->TestLeastSquaresUpdate(kLearningRate);
947 }
948
949 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithWeightDecay) {
950   typedef typename TypeParam::Dtype Dtype;
951   const Dtype kLearningRate = 0.1;
952   const Dtype kWeightDecay = 0.5;
953   const Dtype kMomentum = 0.95;
954   this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
955 }
956
957 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithHalfMomentum) {
958   typedef typename TypeParam::Dtype Dtype;
959   const Dtype kLearningRate = 0.1;
960   const Dtype kWeightDecay = 0.0;
961   const Dtype kMomentum = 0.5;
962   const int kNumIters = 1;
963   for (int i = 0; i <= kNumIters; ++i) {
964     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
965   }
966 }
967
968 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithMomentum) {
969   typedef typename TypeParam::Dtype Dtype;
970   const Dtype kLearningRate = 0.1;
971   const Dtype kWeightDecay = 0.0;
972   const Dtype kMomentum = 0.95;
973   const int kNumIters = 1;
974   for (int i = 0; i <= kNumIters; ++i) {
975     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
976   }
977 }
978
979 TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
980   typedef typename TypeParam::Dtype Dtype;
981   const Dtype kLearningRate = 0.1;
982   const Dtype kWeightDecay = 0.0;
983   const Dtype kMomentum = 0.95;
984   const int kNumIters = 4;
985   for (int i = 0; i <= kNumIters; ++i) {
986     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
987   }
988 }
989
990 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverything) {
991   typedef typename TypeParam::Dtype Dtype;
992   const Dtype kLearningRate = 0.1;
993   const Dtype kWeightDecay = 0.1;
994   const Dtype kMomentum = 0.95;
995   const int kNumIters = 4;
996   for (int i = 0; i <= kNumIters; ++i) {
997     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
998   }
999 }
1000
1001 TYPED_TEST(AdaDeltaSolverTest,
1002            TestAdaDeltaLeastSquaresUpdateWithEverythingShare) {
1003   typedef typename TypeParam::Dtype Dtype;
1004   const Dtype kLearningRate = 0.1;
1005   const Dtype kWeightDecay = 0.1;
1006   const Dtype kMomentum = 0.95;
1007   const int kNumIters = 4;
1008   this->share_ = true;
1009   for (int i = 0; i <= kNumIters; ++i) {
1010     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1011   }
1012 }
1013
1014 TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
1015   typedef typename TypeParam::Dtype Dtype;
1016   const Dtype kLearningRate = 0.1;
1017   const Dtype kWeightDecay = 0.1;
1018   const Dtype kMomentum = 0.95;
1019   const int kNumIters = 4;
1020   const int kIterSize = 2;
1021   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1022       kIterSize);
1023 }
1024
1025 TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
1026   typedef typename TypeParam::Dtype Dtype;
1027   const Dtype kLearningRate = 0.1;
1028   const Dtype kWeightDecay = 0.1;
1029   const Dtype kMomentum = 0.95;
1030   const int kNumIters = 4;
1031   const int kIterSize = 2;
1032   this->share_ = true;
1033   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1034       kIterSize);
1035 }
1036
1037 TYPED_TEST(AdaDeltaSolverTest, TestSnapshot) {
1038   typedef typename TypeParam::Dtype Dtype;
1039   const Dtype kLearningRate = 0.1;
1040   const Dtype kWeightDecay = 0.1;
1041   const Dtype kMomentum = 0.95;
1042   const int kNumIters = 4;
1043   for (int i = 1; i <= kNumIters; ++i) {
1044     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1045   }
1046 }
1047
1048 TYPED_TEST(AdaDeltaSolverTest, TestSnapshotShare) {
1049   typedef typename TypeParam::Dtype Dtype;
1050   const Dtype kLearningRate = 0.1;
1051   const Dtype kWeightDecay = 0.1;
1052   const Dtype kMomentum = 0.95;
1053   const int kNumIters = 4;
1054   this->share_ = true;
1055   for (int i = 1; i <= kNumIters; ++i) {
1056     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1057   }
1058 }
1059
1060 template <typename TypeParam>
1061 class AdamSolverTest : public GradientBasedSolverTest<TypeParam> {
1062   typedef typename TypeParam::Dtype Dtype;
1063
1064  protected:
1065   virtual void InitSolver(const SolverParameter& param) {
1066     SolverParameter new_param = param;
1067     const Dtype momentum = 0.9;
1068     new_param.set_momentum(momentum);
1069     const Dtype momentum2 = 0.999;
1070     new_param.set_momentum2(momentum2);
1071     this->solver_.reset(new AdamSolver<Dtype>(new_param));
1072   }
1073 };
1074
1075 TYPED_TEST_CASE(AdamSolverTest, TestDtypesAndDevices);
1076
1077 TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdate) {
1078   typedef typename TypeParam::Dtype Dtype;
1079   const Dtype kLearningRate = 0.01;
1080   const Dtype kWeightDecay = 0;
1081   const Dtype kMomentum = 0.9;
1082   this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
1083 }
1084
1085 TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdateWithWeightDecay) {
1086   typedef typename TypeParam::Dtype Dtype;
1087   const Dtype kLearningRate = 0.01;
1088   const Dtype kWeightDecay = 0.5;
1089   const Dtype kMomentum = 0.9;
1090   this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
1091 }
1092
1093 TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdateWithEverything) {
1094   typedef typename TypeParam::Dtype Dtype;
1095   const Dtype kLearningRate = 0.01;
1096   const Dtype kWeightDecay = 0.5;
1097   const Dtype kMomentum = 0.9;
1098   const int kNumIters = 4;
1099   for (int i = 0; i <= kNumIters; ++i) {
1100     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1101   }
1102 }
1103
1104 TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdateWithEverythingShare) {
1105   typedef typename TypeParam::Dtype Dtype;
1106   const Dtype kLearningRate = 0.01;
1107   const Dtype kWeightDecay = 0.5;
1108   const Dtype kMomentum = 0.9;
1109   const int kNumIters = 4;
1110   this->share_ = true;
1111   for (int i = 0; i <= kNumIters; ++i) {
1112     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1113   }
1114 }
1115
1116 TYPED_TEST(AdamSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
1117   typedef typename TypeParam::Dtype Dtype;
1118   const Dtype kLearningRate = 0.01;
1119   const Dtype kWeightDecay = 0.5;
1120   const Dtype kMomentum = 0.9;
1121   const int kNumIters = 4;
1122   const int kIterSize = 2;
1123   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1124       kIterSize);
1125 }
1126
1127 TYPED_TEST(AdamSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
1128   typedef typename TypeParam::Dtype Dtype;
1129   const Dtype kLearningRate = 0.01;
1130   const Dtype kWeightDecay = 0.5;
1131   const Dtype kMomentum = 0.9;
1132   const int kNumIters = 4;
1133   const int kIterSize = 2;
1134   this->share_ = true;
1135   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1136       kIterSize);
1137 }
1138
1139 TYPED_TEST(AdamSolverTest, TestSnapshot) {
1140   typedef typename TypeParam::Dtype Dtype;
1141   const Dtype kLearningRate = 0.01;
1142   const Dtype kWeightDecay = 0.5;
1143   const Dtype kMomentum = 0.9;
1144   const int kNumIters = 4;
1145   for (int i = 1; i <= kNumIters; ++i) {
1146     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1147   }
1148 }
1149
1150 TYPED_TEST(AdamSolverTest, TestSnapshotShare) {
1151   typedef typename TypeParam::Dtype Dtype;
1152   const Dtype kLearningRate = 0.01;
1153   const Dtype kWeightDecay = 0.5;
1154   const Dtype kMomentum = 0.9;
1155   const int kNumIters = 4;
1156   this->share_ = true;
1157   for (int i = 1; i <= kNumIters; ++i) {
1158     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1159   }
1160 }
1161
1162 template <typename TypeParam>
1163 class RMSPropSolverTest : public GradientBasedSolverTest<TypeParam> {
1164   typedef typename TypeParam::Dtype Dtype;
1165
1166  protected:
1167   virtual void InitSolver(const SolverParameter& param) {
1168     const Dtype rms_decay = 0.95;
1169     SolverParameter new_param = param;
1170     new_param.set_rms_decay(rms_decay);
1171     this->solver_.reset(new RMSPropSolver<Dtype>(new_param));
1172   }
1173 };
1174
1175 TYPED_TEST_CASE(RMSPropSolverTest, TestDtypesAndDevices);
1176
1177 TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithWeightDecay) {
1178   typedef typename TypeParam::Dtype Dtype;
1179   const Dtype kLearningRate = 1.0;
1180   const Dtype kWeightDecay = 0.5;
1181   this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay);
1182 }
1183
1184 TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithRmsDecay) {
1185   typedef typename TypeParam::Dtype Dtype;
1186   const Dtype kLearningRate = 0.01;
1187   const Dtype kWeightDecay = 0.0;
1188   const Dtype kMomentum = 0.0;
1189   const int kNumIters = 4;
1190   for (int i = 0; i <= kNumIters; ++i) {
1191     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1192   }
1193 }
1194
1195 TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithEverything) {
1196   typedef typename TypeParam::Dtype Dtype;
1197   const Dtype kLearningRate = 0.01;
1198   const Dtype kWeightDecay = 0.5;
1199   const Dtype kMomentum = 0.0;
1200   const int kNumIters = 4;
1201   for (int i = 0; i <= kNumIters; ++i) {
1202     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1203   }
1204 }
1205
1206 TYPED_TEST(RMSPropSolverTest,
1207       TestRMSPropLeastSquaresUpdateWithEverythingShare) {
1208   typedef typename TypeParam::Dtype Dtype;
1209   const Dtype kLearningRate = 0.01;
1210   const Dtype kWeightDecay = 0.5;
1211   const Dtype kMomentum = 0.0;
1212   const int kNumIters = 4;
1213   this->share_ = true;
1214   for (int i = 0; i <= kNumIters; ++i) {
1215     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1216   }
1217 }
1218
1219 TYPED_TEST(RMSPropSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
1220   typedef typename TypeParam::Dtype Dtype;
1221   const Dtype kLearningRate = 0.01;
1222   const Dtype kWeightDecay = 0.5;
1223   const Dtype kMomentum = 0.0;
1224   const int kNumIters = 4;
1225   const int kIterSize = 2;
1226   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1227       kIterSize);
1228 }
1229
1230 TYPED_TEST(RMSPropSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
1231   typedef typename TypeParam::Dtype Dtype;
1232   const Dtype kLearningRate = 0.01;
1233   const Dtype kWeightDecay = 0.5;
1234   const Dtype kMomentum = 0.0;
1235   const int kNumIters = 4;
1236   const int kIterSize = 2;
1237   this->share_ = true;
1238   this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1239       kIterSize);
1240 }
1241
1242 TYPED_TEST(RMSPropSolverTest, TestSnapshot) {
1243   typedef typename TypeParam::Dtype Dtype;
1244   const Dtype kLearningRate = 0.01;
1245   const Dtype kWeightDecay = 0.5;
1246   const Dtype kMomentum = 0;
1247   const int kNumIters = 4;
1248   for (int i = 1; i <= kNumIters; ++i) {
1249     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1250   }
1251 }
1252
1253 TYPED_TEST(RMSPropSolverTest, TestSnapshotShare) {
1254   typedef typename TypeParam::Dtype Dtype;
1255   const Dtype kLearningRate = 0.01;
1256   const Dtype kWeightDecay = 0.5;
1257   const Dtype kMomentum = 0;
1258   const int kNumIters = 4;
1259   this->share_ = true;
1260   for (int i = 1; i <= kNumIters; ++i) {
1261     this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1262   }
1263 }
1264
1265 }  // namespace caffe