6 #include "google/protobuf/text_format.h"
8 #include "gtest/gtest.h"
10 #include "caffe/common.hpp"
11 #include "caffe/parallel.hpp"
12 #include "caffe/proto/caffe.pb.h"
13 #include "caffe/sgd_solvers.hpp"
14 #include "caffe/util/io.hpp"
16 #include "caffe/test/test_caffe_main.hpp"
18 using std::ostringstream;
22 template <typename TypeParam>
23 class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
24 typedef typename TypeParam::Dtype Dtype;
27 GradientBasedSolverTest() :
28 seed_(1701), num_(4), channels_(3), height_(10), width_(10),
30 input_file_ = new string(
31 CMAKE_SOURCE_DIR "caffe/test/test_data/solver_data_list.txt" CMAKE_EXT);
33 ~GradientBasedSolverTest() {
37 string snapshot_prefix_;
38 shared_ptr<SGDSolver<Dtype> > solver_;
39 shared_ptr<P2PSync<Dtype> > sync_;
41 // Dimensions are determined by generate_sample_data.py
42 // TODO this is brittle and the hdf5 file should be checked instead.
43 int num_, channels_, height_, width_;
45 Dtype delta_; // Stability constant for RMSProp, AdaGrad, AdaDelta and Adam
47 // Test data: check out generate_sample_data.py in the same directory.
50 virtual void InitSolver(const SolverParameter& param) = 0;
52 virtual void InitSolverFromProtoString(const string& proto) {
53 SolverParameter param;
54 CHECK(google::protobuf::TextFormat::ParseFromString(proto, ¶m));
55 // Set the solver_mode according to current Caffe::mode.
56 switch (Caffe::mode()) {
58 param.set_solver_mode(SolverParameter_SolverMode_CPU);
61 param.set_solver_mode(SolverParameter_SolverMode_GPU);
64 LOG(FATAL) << "Unknown Caffe mode: " << Caffe::mode();
67 delta_ = param.delta();
70 string RunLeastSquaresSolver(const Dtype learning_rate,
71 const Dtype weight_decay, const Dtype momentum, const int num_iters,
72 const int iter_size = 1, const int devices = 1,
73 const bool snapshot = false, const char* from_snapshot = NULL) {
77 if (Caffe::mode() == Caffe::GPU) {
78 CUDA_CHECK(cudaGetDevice(&device_id));
82 "snapshot_after_train: " << snapshot << " "
83 "max_iter: " << num_iters << " "
84 "base_lr: " << learning_rate << " "
86 "iter_size: " << iter_size << " "
87 "device_id: " << device_id << " "
89 " name: 'TestNetwork' "
94 " source: '" << *(this->input_file_) << "' "
95 " batch_size: " << num_ / iter_size << " "
115 " name: 'innerprod' "
116 " type: 'InnerProduct' "
117 " param { name: 'weights' } "
118 " param { name: 'bias' } "
119 " inner_product_param { "
130 " bottom: '" << string(share_ ? "data1": "data") << "' "
131 " top: '" << string(share_ ? "innerprod1": "innerprod") << "' "
136 " name: 'innerprod2' "
137 " type: 'InnerProduct' "
138 " param { name: 'weights' } "
139 " param { name: 'bias' } "
140 " inner_product_param { "
152 " top: 'innerprod2' "
157 " bottom: 'innerprod1' "
158 " bottom: 'innerprod2' "
168 " type: 'EuclideanLoss' "
169 " bottom: 'innerprod' "
170 " bottom: 'targets' "
173 if (weight_decay != 0) {
174 proto << "weight_decay: " << weight_decay << " ";
177 proto << "momentum: " << momentum << " ";
179 MakeTempDir(&snapshot_prefix_);
180 proto << "snapshot_prefix: '" << snapshot_prefix_ << "/' ";
182 proto << "snapshot: " << num_iters << " ";
184 Caffe::set_random_seed(this->seed_);
185 this->InitSolverFromProtoString(proto.str());
186 if (from_snapshot != NULL) {
187 this->solver_->Restore(from_snapshot);
188 for (int i = 0; i < this->solver_->iter(); ++i) {
189 this->solver_->net()->Forward();
193 this->solver_->Solve();
195 LOG(INFO) << "Multi-GPU test on " << devices << " devices";
197 // put current device at the beginning
198 int device_id = solver_->param().device_id();
199 gpus.push_back(device_id);
200 for (int i = 0; gpus.size() < devices; ++i) {
204 Caffe::set_solver_count(gpus.size());
205 this->sync_.reset(new P2PSync<Dtype>(
206 this->solver_, NULL, this->solver_->param()));
207 this->sync_->Run(gpus);
208 Caffe::set_solver_count(1);
211 ostringstream resume_file;
212 resume_file << snapshot_prefix_ << "/_iter_" << num_iters
214 string resume_filename = resume_file.str();
215 return resume_filename;
220 // Compute an update value given the current state of the train net,
221 // using the analytical formula for the least squares gradient.
222 // updated_params will store the updated weight and bias results,
223 // using the blobs' diffs to hold the update values themselves.
224 void ComputeLeastSquaresUpdate(const Dtype learning_rate,
225 const Dtype weight_decay, const Dtype momentum, const int num_iters,
226 vector<shared_ptr<Blob<Dtype> > >* updated_params) {
228 const int D = channels_ * height_ * width_;
230 // Run a forward pass, and manually compute the update values from the
232 Net<Dtype>& net = *this->solver_->net();
234 ASSERT_TRUE(net.has_blob("data"));
235 const Blob<Dtype>& data = *net.blob_by_name("data");
236 ASSERT_TRUE(net.has_blob("targets"));
237 const Blob<Dtype>& targets = *net.blob_by_name("targets");
238 ASSERT_TRUE(net.has_layer("innerprod"));
239 const vector<shared_ptr<Blob<Dtype> > >& param_blobs =
240 net.layer_by_name("innerprod")->blobs();
241 const int num_param_blobs = 2;
242 ASSERT_EQ(num_param_blobs, param_blobs.size());
243 const Blob<Dtype>& weights = *param_blobs[0];
244 const Blob<Dtype>& bias = *param_blobs[1];
245 ASSERT_EQ(D * N, data.count());
246 ASSERT_EQ(N, targets.count());
247 ASSERT_EQ(D, weights.count());
248 ASSERT_EQ(1, bias.count());
250 updated_params->clear();
251 updated_params->resize(num_param_blobs);
252 for (int i = 0; i < num_param_blobs; ++i) {
253 (*updated_params)[i].reset(new Blob<Dtype>());
255 Blob<Dtype>& updated_weights = *(*updated_params)[0];
256 updated_weights.ReshapeLike(weights);
257 Blob<Dtype>& updated_bias = *(*updated_params)[1];
258 updated_bias.ReshapeLike(bias);
260 for (int i = 0; i <= D; ++i) {
261 // Compute the derivative with respect to the ith weight (i.e., the ith
262 // element of the gradient).
264 for (int j = 0; j <= D; ++j) {
265 // Compute element (i, j) of X^T * X.
267 for (int k = 0; k < N; ++k) {
268 // (i, k) in X^T (== (k, i) in X) times (k, j) in X.
269 const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i];
270 const Dtype element_j = (j == D) ? 1 : data.cpu_data()[k * D + j];
271 element += element_i * element_j;
274 grad += element * bias.cpu_data()[0];
276 grad += element * weights.cpu_data()[j];
279 for (int k = 0; k < N; ++k) {
280 const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i];
281 grad -= element_i * targets.cpu_data()[k];
283 // Scale the gradient over the N samples.
285 // Add the weight decay to the gradient.
286 grad += weight_decay *
287 ((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]);
288 // Finally, compute update.
289 const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
290 if (solver_->type() != string("AdaDelta")
291 && solver_->type() != string("Adam")) {
292 ASSERT_EQ(2, history.size()); // 1 blob for weights, 1 for bias
294 ASSERT_EQ(4, history.size()); // additional blobs for update history
296 Dtype update_value = learning_rate * grad;
297 const Dtype history_value = (i == D) ?
298 history[1]->cpu_data()[0] : history[0]->cpu_data()[i];
299 const Dtype temp = momentum * history_value;
300 if (solver_->type() == string("SGD")) {
301 update_value += temp;
302 } else if (solver_->type() == string("Nesterov")) {
303 update_value += temp;
304 // step back then over-step
305 update_value = (1 + momentum) * update_value - temp;
306 } else if (solver_->type() == string("AdaGrad")) {
307 update_value /= std::sqrt(history_value + grad * grad) + delta_;
308 } else if (solver_->type() == string("RMSProp")) {
309 const Dtype rms_decay = 0.95;
310 update_value /= std::sqrt(rms_decay*history_value
311 + grad * grad * (1 - rms_decay)) + delta_;
312 } else if (solver_->type() == string("AdaDelta")) {
313 const Dtype update_history_value = (i == D) ?
314 history[1 + num_param_blobs]->cpu_data()[0] :
315 history[0 + num_param_blobs]->cpu_data()[i];
316 const Dtype weighted_gradient_average =
317 momentum * history_value + (1 - momentum) * (grad * grad);
318 update_value = grad * std::sqrt((update_history_value + delta_) /
319 (weighted_gradient_average + delta_)) * learning_rate;
320 // not actually needed, just here for illustrative purposes
321 // const Dtype weighted_update_average =
322 // momentum * update_history_value + (1 - momentum) * (update_value);
323 } else if (solver_->type() == string("Adam")) {
324 const Dtype momentum2 = 0.999;
325 const Dtype m = history_value;
326 const Dtype v = (i == D) ?
327 history[1 + num_param_blobs]->cpu_data()[0] :
328 history[0 + num_param_blobs]->cpu_data()[i];
329 const Dtype val_m = (1 - momentum) * grad + momentum * m;
330 const Dtype val_v = (1 - momentum2) * grad * grad + momentum2 * v;
331 Dtype alpha_t = learning_rate *
332 std::sqrt(Dtype(1) - pow(momentum2, num_iters)) /
333 (Dtype(1.) - pow(momentum, num_iters));
334 update_value = alpha_t * val_m / (std::sqrt(val_v) + delta_);
336 LOG(FATAL) << "Unknown solver type: " << solver_->type();
339 updated_bias.mutable_cpu_diff()[0] = update_value;
340 updated_bias.mutable_cpu_data()[0] = bias.cpu_data()[0] - update_value;
342 updated_weights.mutable_cpu_diff()[i] = update_value;
343 updated_weights.mutable_cpu_data()[i] =
344 weights.cpu_data()[i] - update_value;
349 void CheckLeastSquaresUpdate(
350 const vector<shared_ptr<Blob<Dtype> > >& updated_params) {
351 const int D = channels_ * height_ * width_;
353 const Blob<Dtype>& updated_weights = *updated_params[0];
354 const Blob<Dtype>& updated_bias = *updated_params[1];
356 Net<Dtype>& net = *this->solver_->net();
357 ASSERT_TRUE(net.has_layer("innerprod"));
358 const vector<shared_ptr<Blob<Dtype> > >& param_blobs =
359 net.layer_by_name("innerprod")->blobs();
360 ASSERT_EQ(2, param_blobs.size());
361 const Blob<Dtype>& solver_updated_weights = *param_blobs[0];
362 ASSERT_EQ(D, solver_updated_weights.count());
363 const double kPrecision = 1e-2;
364 const double kMinPrecision = 1e-7;
365 for (int i = 0; i < D; ++i) {
366 const Dtype expected_updated_weight = updated_weights.cpu_data()[i];
367 const Dtype solver_updated_weight = solver_updated_weights.cpu_data()[i];
368 const Dtype error_margin = std::max(kMinPrecision, kPrecision *
369 std::min(fabs(expected_updated_weight), fabs(solver_updated_weight)));
370 EXPECT_NEAR(expected_updated_weight, solver_updated_weight, error_margin);
372 const Blob<Dtype>& solver_updated_bias_blob = *param_blobs[1];
373 ASSERT_EQ(1, solver_updated_bias_blob.count());
374 const Dtype expected_updated_bias = updated_bias.cpu_data()[0];
375 const Dtype solver_updated_bias = solver_updated_bias_blob.cpu_data()[0];
376 const Dtype error_margin = std::max(kMinPrecision, kPrecision *
377 std::min(fabs(expected_updated_bias), fabs(solver_updated_bias)));
378 EXPECT_NEAR(expected_updated_bias, solver_updated_bias, error_margin);
380 // Check the solver's history -- should contain the previous update value.
381 if (solver_->type() == string("SGD")) {
382 const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
383 ASSERT_EQ(2, history.size());
384 for (int i = 0; i < D; ++i) {
385 const Dtype expected_history = updated_weights.cpu_diff()[i];
386 const Dtype solver_history = history[0]->cpu_data()[i];
387 const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision *
388 std::min(fabs(expected_history), fabs(solver_history)));
389 EXPECT_NEAR(expected_history, solver_history, error_margin_hist);
391 const Dtype expected_history = updated_bias.cpu_diff()[0];
392 const Dtype solver_history = history[1]->cpu_data()[0];
393 const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision *
394 std::min(fabs(expected_history), fabs(solver_history)));
395 EXPECT_NEAR(expected_history, solver_history, error_margin_hist);
399 void CheckAccumulation(const Dtype kLearningRate, const Dtype kWeightDecay,
400 const Dtype kMomentum, const int kNumIters, const int kIterSize) {
401 const double kPrecision = 1e-2;
402 const double kMinPrecision = 1e-7;
403 // Solve without accumulation and save parameters.
404 this->RunLeastSquaresSolver(kLearningRate, kWeightDecay, kMomentum,
406 // Save parameters for comparison.
407 Net<Dtype>& net = *this->solver_->net();
408 const vector<shared_ptr<Blob<Dtype> > >& param_blobs =
409 net.layer_by_name("innerprod")->blobs();
410 vector<shared_ptr<Blob<Dtype> > > noaccum_params(param_blobs.size());
411 for (int i = 0; i < param_blobs.size(); ++i) {
412 noaccum_params[i].reset(new Blob<Dtype>());
413 noaccum_params[i]->CopyFrom(*param_blobs[i], false, true);
415 // Solve by equivalent accumulation of gradients over divided batches.
416 this->RunLeastSquaresSolver(kLearningRate, kWeightDecay, kMomentum,
417 kNumIters, kIterSize);
418 Net<Dtype>& net_accum = *this->solver_->net();
419 const vector<shared_ptr<Blob<Dtype> > >& accum_params =
420 net_accum.layer_by_name("innerprod")->blobs();
421 // Compare accumulated parameters against no accumulation standard.
422 const int D = this->channels_ * this->height_ * this->width_;
423 for (int i = 0; i < D; ++i) {
424 const Dtype expected_param = noaccum_params[0]->cpu_data()[i];
425 const Dtype accum_param = accum_params[0]->cpu_data()[i];
426 const Dtype error_margin = std::max(kMinPrecision, kPrecision *
427 std::min(fabs(expected_param), fabs(accum_param)));
428 EXPECT_NEAR(expected_param, accum_param, error_margin);
430 ASSERT_EQ(1, accum_params[1]->count());
431 const Dtype expected_bias = noaccum_params[1]->cpu_data()[0];
432 const Dtype accum_bias = accum_params[1]->cpu_data()[0];
433 const Dtype error_margin = std::max(kMinPrecision, kPrecision *
434 std::min(fabs(expected_bias), fabs(accum_bias)));
435 EXPECT_NEAR(expected_bias, accum_bias, error_margin);
438 // Test that the correct update is computed for a regularized least squares
441 // E = (1/(2n)) || X w - y ||^2 + (lambda / 2) || w ||^2
442 // \nabla_w E = (1/n) (X^T X w - X^T y) + lambda * w
444 // X \in R^{n x (d+1)} (each example is a row, (d+1)th element is always 1)
445 // w \in R^{(d+1) x 1} ((d+1)th element is the bias)
447 // lambda is weight_decay
449 // TestLeastSquaresUpdate works "inductively", assuming that the solver
450 // correctly updates the net K (= iter_to_check) times, then given the history
451 // from the Kth update, we compute the (K+1)th update and check that it
452 // matches the solver's (K+1)th update.
453 void TestLeastSquaresUpdate(const Dtype learning_rate = 1.0,
454 const Dtype weight_decay = 0.0, const Dtype momentum = 0.0,
455 const int iter_to_check = 0) {
456 const int kNum = num_;
457 const int kIterSize = 1;
458 // Test over all numbers of devices.
459 int available_devices = 1;
461 if (Caffe::mode() == Caffe::GPU) {
462 CUDA_CHECK(cudaGetDeviceCount(&available_devices));
465 for (int devices = 1; devices <= available_devices; ++devices) {
466 // Configure batch size for single / multi device equivalence.
467 // Constant data is needed for multi device as for accumulation.
468 num_ = kNum * devices;
470 // Initialize the solver and run K (= iter_to_check) solver iterations
471 // (on single device).
472 RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
473 iter_to_check, kIterSize, 1);
475 // Compute the (K+1)th update using the analytic least squares gradient.
476 vector<shared_ptr<Blob<Dtype> > > updated_params;
477 ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum,
478 iter_to_check + 1, &updated_params);
480 // Reinitialize the solver and run K+1 solver iterations.
482 RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
483 iter_to_check + 1, kIterSize, devices);
485 // Check that the solver's solution matches ours.
486 CheckLeastSquaresUpdate(updated_params);
490 void TestSnapshot(const Dtype learning_rate = 1.0,
491 const Dtype weight_decay = 0.0, const Dtype momentum = 0.0,
492 const int num_iters = 1) {
493 // Run the solver for num_iters * 2 iterations.
494 const int total_num_iters = num_iters * 2;
495 bool snapshot = false;
496 const int kIterSize = 1;
497 const int kDevices = 1;
498 RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
499 total_num_iters, kIterSize, kDevices, snapshot);
501 // Save the resulting param values.
502 vector<shared_ptr<Blob<Dtype> > > param_copies;
503 const vector<Blob<Dtype>*>& orig_params =
504 solver_->net()->learnable_params();
505 param_copies.resize(orig_params.size());
506 for (int i = 0; i < orig_params.size(); ++i) {
507 param_copies[i].reset(new Blob<Dtype>());
508 const bool kReshape = true;
509 for (int copy_diff = false; copy_diff <= true; ++copy_diff) {
510 param_copies[i]->CopyFrom(*orig_params[i], copy_diff, kReshape);
514 // Save the solver history
515 vector<shared_ptr<Blob<Dtype> > > history_copies;
516 const vector<shared_ptr<Blob<Dtype> > >& orig_history = solver_->history();
517 history_copies.resize(orig_history.size());
518 for (int i = 0; i < orig_history.size(); ++i) {
519 history_copies[i].reset(new Blob<Dtype>());
520 const bool kReshape = true;
521 for (int copy_diff = false; copy_diff <= true; ++copy_diff) {
522 history_copies[i]->CopyFrom(*orig_history[i], copy_diff, kReshape);
526 // Run the solver for num_iters iterations and snapshot.
528 string snapshot_name = RunLeastSquaresSolver(learning_rate, weight_decay,
529 momentum, num_iters, kIterSize, kDevices, snapshot);
531 // Reinitialize the solver and run for num_iters more iterations.
533 RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
534 total_num_iters, kIterSize, kDevices,
535 snapshot, snapshot_name.c_str());
537 // Check that params now match.
538 const vector<Blob<Dtype>*>& params = solver_->net()->learnable_params();
539 for (int i = 0; i < params.size(); ++i) {
540 for (int j = 0; j < params[i]->count(); ++j) {
541 EXPECT_EQ(param_copies[i]->cpu_data()[j], params[i]->cpu_data()[j])
542 << "param " << i << " data differed at dim " << j;
543 EXPECT_EQ(param_copies[i]->cpu_diff()[j], params[i]->cpu_diff()[j])
544 << "param " << i << " diff differed at dim " << j;
548 // Check that history now matches.
549 const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
550 for (int i = 0; i < history.size(); ++i) {
551 for (int j = 0; j < history[i]->count(); ++j) {
552 EXPECT_EQ(history_copies[i]->cpu_data()[j], history[i]->cpu_data()[j])
553 << "history blob " << i << " data differed at dim " << j;
554 EXPECT_EQ(history_copies[i]->cpu_diff()[j], history[i]->cpu_diff()[j])
555 << "history blob " << i << " diff differed at dim " << j;
562 template <typename TypeParam>
563 class SGDSolverTest : public GradientBasedSolverTest<TypeParam> {
564 typedef typename TypeParam::Dtype Dtype;
567 virtual void InitSolver(const SolverParameter& param) {
568 this->solver_.reset(new SGDSolver<Dtype>(param));
572 TYPED_TEST_CASE(SGDSolverTest, TestDtypesAndDevices);
574 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdate) {
575 this->TestLeastSquaresUpdate();
578 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateLROneHundredth) {
579 typedef typename TypeParam::Dtype Dtype;
580 const Dtype kLearningRate = 0.01;
581 this->TestLeastSquaresUpdate(kLearningRate);
584 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithWeightDecay) {
585 typedef typename TypeParam::Dtype Dtype;
586 const Dtype kLearningRate = 0.01;
587 const Dtype kWeightDecay = 0.5;
588 const Dtype kMomentum = 0;
589 const int kNumIters = 1;
590 for (int i = 0; i <= kNumIters; ++i) {
591 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
595 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithWeightDecayMultiIter) {
596 typedef typename TypeParam::Dtype Dtype;
597 const Dtype kLearningRate = 0.01;
598 const Dtype kWeightDecay = 0.5;
599 const Dtype kMomentum = 0;
600 const int kNumIters = 4;
601 for (int i = 0; i <= kNumIters; ++i) {
602 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
606 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithMomentum) {
607 typedef typename TypeParam::Dtype Dtype;
608 const Dtype kLearningRate = 0.01;
609 const Dtype kWeightDecay = 0;
610 const Dtype kMomentum = 0.5;
611 const int kNumIters = 1;
612 for (int i = 0; i <= kNumIters; ++i) {
613 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
617 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
618 typedef typename TypeParam::Dtype Dtype;
619 const Dtype kLearningRate = 0.01;
620 const Dtype kWeightDecay = 0;
621 const Dtype kMomentum = 0.5;
622 const int kNumIters = 4;
623 for (int i = 0; i <= kNumIters; ++i) {
624 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
628 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverything) {
629 typedef typename TypeParam::Dtype Dtype;
630 const Dtype kLearningRate = 0.01;
631 const Dtype kWeightDecay = 0.5;
632 const Dtype kMomentum = 0.5;
633 const int kNumIters = 4;
634 for (int i = 0; i <= kNumIters; ++i) {
635 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
639 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingShare) {
640 typedef typename TypeParam::Dtype Dtype;
641 const Dtype kLearningRate = 0.01;
642 const Dtype kWeightDecay = 0.5;
643 const Dtype kMomentum = 0.5;
644 const int kNumIters = 4;
646 for (int i = 0; i <= kNumIters; ++i) {
647 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
651 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
652 typedef typename TypeParam::Dtype Dtype;
653 const Dtype kLearningRate = 0.01;
654 const Dtype kWeightDecay = 0.5;
655 const Dtype kMomentum = 0.9;
656 const int kNumIters = 4;
657 const int kIterSize = 2;
658 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
662 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
663 typedef typename TypeParam::Dtype Dtype;
664 const Dtype kLearningRate = 0.01;
665 const Dtype kWeightDecay = 0.5;
666 const Dtype kMomentum = 0.9;
667 const int kNumIters = 4;
668 const int kIterSize = 2;
670 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
674 TYPED_TEST(SGDSolverTest, TestSnapshot) {
675 typedef typename TypeParam::Dtype Dtype;
676 const Dtype kLearningRate = 0.01;
677 const Dtype kWeightDecay = 0.5;
678 const Dtype kMomentum = 0.9;
679 const int kNumIters = 4;
680 for (int i = 1; i <= kNumIters; ++i) {
681 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
685 TYPED_TEST(SGDSolverTest, TestSnapshotShare) {
686 typedef typename TypeParam::Dtype Dtype;
687 const Dtype kLearningRate = 0.01;
688 const Dtype kWeightDecay = 0.5;
689 const Dtype kMomentum = 0.9;
690 const int kNumIters = 4;
692 for (int i = 1; i <= kNumIters; ++i) {
693 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
698 template <typename TypeParam>
699 class AdaGradSolverTest : public GradientBasedSolverTest<TypeParam> {
700 typedef typename TypeParam::Dtype Dtype;
703 virtual void InitSolver(const SolverParameter& param) {
704 this->solver_.reset(new AdaGradSolver<Dtype>(param));
708 TYPED_TEST_CASE(AdaGradSolverTest, TestDtypesAndDevices);
710 TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdate) {
711 this->TestLeastSquaresUpdate();
714 TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateLROneHundredth) {
715 typedef typename TypeParam::Dtype Dtype;
716 const Dtype kLearningRate = 0.01;
717 this->TestLeastSquaresUpdate(kLearningRate);
720 TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithWeightDecay) {
721 typedef typename TypeParam::Dtype Dtype;
722 const Dtype kLearningRate = 0.01;
723 const Dtype kWeightDecay = 0.5;
724 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay);
727 TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithEverything) {
728 typedef typename TypeParam::Dtype Dtype;
729 const Dtype kLearningRate = 0.01;
730 const Dtype kWeightDecay = 0.5;
731 const Dtype kMomentum = 0;
732 const int kNumIters = 4;
733 for (int i = 0; i <= kNumIters; ++i) {
734 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
738 TYPED_TEST(AdaGradSolverTest,
739 TestAdaGradLeastSquaresUpdateWithEverythingShare) {
740 typedef typename TypeParam::Dtype Dtype;
741 const Dtype kLearningRate = 0.01;
742 const Dtype kWeightDecay = 0.5;
743 const Dtype kMomentum = 0;
744 const int kNumIters = 4;
746 for (int i = 0; i <= kNumIters; ++i) {
747 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
751 TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
752 typedef typename TypeParam::Dtype Dtype;
753 const Dtype kLearningRate = 0.01;
754 const Dtype kWeightDecay = 0.5;
755 const Dtype kMomentum = 0;
756 const int kNumIters = 4;
757 const int kIterSize = 2;
758 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
762 TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
763 typedef typename TypeParam::Dtype Dtype;
764 const Dtype kLearningRate = 0.01;
765 const Dtype kWeightDecay = 0.5;
766 const Dtype kMomentum = 0;
767 const int kNumIters = 4;
768 const int kIterSize = 2;
770 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
774 TYPED_TEST(AdaGradSolverTest, TestSnapshot) {
775 typedef typename TypeParam::Dtype Dtype;
776 const Dtype kLearningRate = 0.01;
777 const Dtype kWeightDecay = 0.5;
778 const Dtype kMomentum = 0;
779 const int kNumIters = 4;
780 for (int i = 1; i <= kNumIters; ++i) {
781 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
785 TYPED_TEST(AdaGradSolverTest, TestSnapshotShare) {
786 typedef typename TypeParam::Dtype Dtype;
787 const Dtype kLearningRate = 0.01;
788 const Dtype kWeightDecay = 0.5;
789 const Dtype kMomentum = 0;
790 const int kNumIters = 4;
792 for (int i = 1; i <= kNumIters; ++i) {
793 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
798 template <typename TypeParam>
799 class NesterovSolverTest : public GradientBasedSolverTest<TypeParam> {
800 typedef typename TypeParam::Dtype Dtype;
803 virtual void InitSolver(const SolverParameter& param) {
804 this->solver_.reset(new NesterovSolver<Dtype>(param));
808 TYPED_TEST_CASE(NesterovSolverTest, TestDtypesAndDevices);
810 TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdate) {
811 this->TestLeastSquaresUpdate();
814 TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateLROneHundredth) {
815 typedef typename TypeParam::Dtype Dtype;
816 const Dtype kLearningRate = 0.01;
817 this->TestLeastSquaresUpdate(kLearningRate);
820 TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithWeightDecay) {
821 typedef typename TypeParam::Dtype Dtype;
822 const Dtype kLearningRate = 0.01;
823 const Dtype kWeightDecay = 0.5;
824 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay);
827 TYPED_TEST(NesterovSolverTest,
828 TestNesterovLeastSquaresUpdateWithWeightDecayMultiIter) {
829 typedef typename TypeParam::Dtype Dtype;
830 const Dtype kLearningRate = 0.01;
831 const Dtype kWeightDecay = 0.5;
832 const Dtype kMomentum = 0;
833 const int kNumIters = 4;
834 for (int i = 0; i <= kNumIters; ++i) {
835 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
839 TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithMomentum) {
840 typedef typename TypeParam::Dtype Dtype;
841 const Dtype kLearningRate = 0.01;
842 const Dtype kWeightDecay = 0;
843 const Dtype kMomentum = 0.5;
844 const int kNumIters = 1;
845 for (int i = 0; i <= kNumIters; ++i) {
846 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
850 TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
851 typedef typename TypeParam::Dtype Dtype;
852 const Dtype kLearningRate = 0.01;
853 const Dtype kWeightDecay = 0;
854 const Dtype kMomentum = 0.5;
855 const int kNumIters = 4;
856 for (int i = 0; i <= kNumIters; ++i) {
857 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
861 TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithEverything) {
862 typedef typename TypeParam::Dtype Dtype;
863 const Dtype kLearningRate = 0.01;
864 const Dtype kWeightDecay = 0.5;
865 const Dtype kMomentum = 0.9;
866 const int kNumIters = 4;
867 for (int i = 0; i <= kNumIters; ++i) {
868 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
872 TYPED_TEST(NesterovSolverTest,
873 TestNesterovLeastSquaresUpdateWithEverythingShare) {
874 typedef typename TypeParam::Dtype Dtype;
875 const Dtype kLearningRate = 0.01;
876 const Dtype kWeightDecay = 0.5;
877 const Dtype kMomentum = 0.9;
878 const int kNumIters = 4;
880 for (int i = 0; i <= kNumIters; ++i) {
881 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
885 TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
886 typedef typename TypeParam::Dtype Dtype;
887 const Dtype kLearningRate = 0.01;
888 const Dtype kWeightDecay = 0.5;
889 const Dtype kMomentum = 0.9;
890 const int kNumIters = 4;
891 const int kIterSize = 2;
892 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
896 TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
897 typedef typename TypeParam::Dtype Dtype;
898 const Dtype kLearningRate = 0.01;
899 const Dtype kWeightDecay = 0.5;
900 const Dtype kMomentum = 0.9;
901 const int kNumIters = 4;
902 const int kIterSize = 2;
904 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
908 TYPED_TEST(NesterovSolverTest, TestSnapshot) {
909 typedef typename TypeParam::Dtype Dtype;
910 const Dtype kLearningRate = 0.01;
911 const Dtype kWeightDecay = 0.5;
912 const Dtype kMomentum = 0.9;
913 const int kNumIters = 4;
914 for (int i = 1; i <= kNumIters; ++i) {
915 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
919 TYPED_TEST(NesterovSolverTest, TestSnapshotShare) {
920 typedef typename TypeParam::Dtype Dtype;
921 const Dtype kLearningRate = 0.01;
922 const Dtype kWeightDecay = 0.5;
923 const Dtype kMomentum = 0.9;
924 const int kNumIters = 4;
926 for (int i = 1; i <= kNumIters; ++i) {
927 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
931 template <typename TypeParam>
932 class AdaDeltaSolverTest : public GradientBasedSolverTest<TypeParam> {
933 typedef typename TypeParam::Dtype Dtype;
936 virtual void InitSolver(const SolverParameter& param) {
937 this->solver_.reset(new AdaDeltaSolver<Dtype>(param));
941 TYPED_TEST_CASE(AdaDeltaSolverTest, TestDtypesAndDevices);
943 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdate) {
944 typedef typename TypeParam::Dtype Dtype;
945 const Dtype kLearningRate = 0.1;
946 this->TestLeastSquaresUpdate(kLearningRate);
949 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithWeightDecay) {
950 typedef typename TypeParam::Dtype Dtype;
951 const Dtype kLearningRate = 0.1;
952 const Dtype kWeightDecay = 0.5;
953 const Dtype kMomentum = 0.95;
954 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
957 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithHalfMomentum) {
958 typedef typename TypeParam::Dtype Dtype;
959 const Dtype kLearningRate = 0.1;
960 const Dtype kWeightDecay = 0.0;
961 const Dtype kMomentum = 0.5;
962 const int kNumIters = 1;
963 for (int i = 0; i <= kNumIters; ++i) {
964 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
968 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithMomentum) {
969 typedef typename TypeParam::Dtype Dtype;
970 const Dtype kLearningRate = 0.1;
971 const Dtype kWeightDecay = 0.0;
972 const Dtype kMomentum = 0.95;
973 const int kNumIters = 1;
974 for (int i = 0; i <= kNumIters; ++i) {
975 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
979 TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
980 typedef typename TypeParam::Dtype Dtype;
981 const Dtype kLearningRate = 0.1;
982 const Dtype kWeightDecay = 0.0;
983 const Dtype kMomentum = 0.95;
984 const int kNumIters = 4;
985 for (int i = 0; i <= kNumIters; ++i) {
986 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
990 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverything) {
991 typedef typename TypeParam::Dtype Dtype;
992 const Dtype kLearningRate = 0.1;
993 const Dtype kWeightDecay = 0.1;
994 const Dtype kMomentum = 0.95;
995 const int kNumIters = 4;
996 for (int i = 0; i <= kNumIters; ++i) {
997 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1001 TYPED_TEST(AdaDeltaSolverTest,
1002 TestAdaDeltaLeastSquaresUpdateWithEverythingShare) {
1003 typedef typename TypeParam::Dtype Dtype;
1004 const Dtype kLearningRate = 0.1;
1005 const Dtype kWeightDecay = 0.1;
1006 const Dtype kMomentum = 0.95;
1007 const int kNumIters = 4;
1008 this->share_ = true;
1009 for (int i = 0; i <= kNumIters; ++i) {
1010 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1014 TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
1015 typedef typename TypeParam::Dtype Dtype;
1016 const Dtype kLearningRate = 0.1;
1017 const Dtype kWeightDecay = 0.1;
1018 const Dtype kMomentum = 0.95;
1019 const int kNumIters = 4;
1020 const int kIterSize = 2;
1021 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1025 TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
1026 typedef typename TypeParam::Dtype Dtype;
1027 const Dtype kLearningRate = 0.1;
1028 const Dtype kWeightDecay = 0.1;
1029 const Dtype kMomentum = 0.95;
1030 const int kNumIters = 4;
1031 const int kIterSize = 2;
1032 this->share_ = true;
1033 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1037 TYPED_TEST(AdaDeltaSolverTest, TestSnapshot) {
1038 typedef typename TypeParam::Dtype Dtype;
1039 const Dtype kLearningRate = 0.1;
1040 const Dtype kWeightDecay = 0.1;
1041 const Dtype kMomentum = 0.95;
1042 const int kNumIters = 4;
1043 for (int i = 1; i <= kNumIters; ++i) {
1044 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1048 TYPED_TEST(AdaDeltaSolverTest, TestSnapshotShare) {
1049 typedef typename TypeParam::Dtype Dtype;
1050 const Dtype kLearningRate = 0.1;
1051 const Dtype kWeightDecay = 0.1;
1052 const Dtype kMomentum = 0.95;
1053 const int kNumIters = 4;
1054 this->share_ = true;
1055 for (int i = 1; i <= kNumIters; ++i) {
1056 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1060 template <typename TypeParam>
1061 class AdamSolverTest : public GradientBasedSolverTest<TypeParam> {
1062 typedef typename TypeParam::Dtype Dtype;
1065 virtual void InitSolver(const SolverParameter& param) {
1066 SolverParameter new_param = param;
1067 const Dtype momentum = 0.9;
1068 new_param.set_momentum(momentum);
1069 const Dtype momentum2 = 0.999;
1070 new_param.set_momentum2(momentum2);
1071 this->solver_.reset(new AdamSolver<Dtype>(new_param));
1075 TYPED_TEST_CASE(AdamSolverTest, TestDtypesAndDevices);
1077 TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdate) {
1078 typedef typename TypeParam::Dtype Dtype;
1079 const Dtype kLearningRate = 0.01;
1080 const Dtype kWeightDecay = 0;
1081 const Dtype kMomentum = 0.9;
1082 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
1085 TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdateWithWeightDecay) {
1086 typedef typename TypeParam::Dtype Dtype;
1087 const Dtype kLearningRate = 0.01;
1088 const Dtype kWeightDecay = 0.5;
1089 const Dtype kMomentum = 0.9;
1090 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
1093 TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdateWithEverything) {
1094 typedef typename TypeParam::Dtype Dtype;
1095 const Dtype kLearningRate = 0.01;
1096 const Dtype kWeightDecay = 0.5;
1097 const Dtype kMomentum = 0.9;
1098 const int kNumIters = 4;
1099 for (int i = 0; i <= kNumIters; ++i) {
1100 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1104 TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdateWithEverythingShare) {
1105 typedef typename TypeParam::Dtype Dtype;
1106 const Dtype kLearningRate = 0.01;
1107 const Dtype kWeightDecay = 0.5;
1108 const Dtype kMomentum = 0.9;
1109 const int kNumIters = 4;
1110 this->share_ = true;
1111 for (int i = 0; i <= kNumIters; ++i) {
1112 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1116 TYPED_TEST(AdamSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
1117 typedef typename TypeParam::Dtype Dtype;
1118 const Dtype kLearningRate = 0.01;
1119 const Dtype kWeightDecay = 0.5;
1120 const Dtype kMomentum = 0.9;
1121 const int kNumIters = 4;
1122 const int kIterSize = 2;
1123 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1127 TYPED_TEST(AdamSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
1128 typedef typename TypeParam::Dtype Dtype;
1129 const Dtype kLearningRate = 0.01;
1130 const Dtype kWeightDecay = 0.5;
1131 const Dtype kMomentum = 0.9;
1132 const int kNumIters = 4;
1133 const int kIterSize = 2;
1134 this->share_ = true;
1135 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1139 TYPED_TEST(AdamSolverTest, TestSnapshot) {
1140 typedef typename TypeParam::Dtype Dtype;
1141 const Dtype kLearningRate = 0.01;
1142 const Dtype kWeightDecay = 0.5;
1143 const Dtype kMomentum = 0.9;
1144 const int kNumIters = 4;
1145 for (int i = 1; i <= kNumIters; ++i) {
1146 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1150 TYPED_TEST(AdamSolverTest, TestSnapshotShare) {
1151 typedef typename TypeParam::Dtype Dtype;
1152 const Dtype kLearningRate = 0.01;
1153 const Dtype kWeightDecay = 0.5;
1154 const Dtype kMomentum = 0.9;
1155 const int kNumIters = 4;
1156 this->share_ = true;
1157 for (int i = 1; i <= kNumIters; ++i) {
1158 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1162 template <typename TypeParam>
1163 class RMSPropSolverTest : public GradientBasedSolverTest<TypeParam> {
1164 typedef typename TypeParam::Dtype Dtype;
1167 virtual void InitSolver(const SolverParameter& param) {
1168 const Dtype rms_decay = 0.95;
1169 SolverParameter new_param = param;
1170 new_param.set_rms_decay(rms_decay);
1171 this->solver_.reset(new RMSPropSolver<Dtype>(new_param));
1175 TYPED_TEST_CASE(RMSPropSolverTest, TestDtypesAndDevices);
1177 TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithWeightDecay) {
1178 typedef typename TypeParam::Dtype Dtype;
1179 const Dtype kLearningRate = 1.0;
1180 const Dtype kWeightDecay = 0.5;
1181 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay);
1184 TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithRmsDecay) {
1185 typedef typename TypeParam::Dtype Dtype;
1186 const Dtype kLearningRate = 0.01;
1187 const Dtype kWeightDecay = 0.0;
1188 const Dtype kMomentum = 0.0;
1189 const int kNumIters = 4;
1190 for (int i = 0; i <= kNumIters; ++i) {
1191 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1195 TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithEverything) {
1196 typedef typename TypeParam::Dtype Dtype;
1197 const Dtype kLearningRate = 0.01;
1198 const Dtype kWeightDecay = 0.5;
1199 const Dtype kMomentum = 0.0;
1200 const int kNumIters = 4;
1201 for (int i = 0; i <= kNumIters; ++i) {
1202 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1206 TYPED_TEST(RMSPropSolverTest,
1207 TestRMSPropLeastSquaresUpdateWithEverythingShare) {
1208 typedef typename TypeParam::Dtype Dtype;
1209 const Dtype kLearningRate = 0.01;
1210 const Dtype kWeightDecay = 0.5;
1211 const Dtype kMomentum = 0.0;
1212 const int kNumIters = 4;
1213 this->share_ = true;
1214 for (int i = 0; i <= kNumIters; ++i) {
1215 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1219 TYPED_TEST(RMSPropSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
1220 typedef typename TypeParam::Dtype Dtype;
1221 const Dtype kLearningRate = 0.01;
1222 const Dtype kWeightDecay = 0.5;
1223 const Dtype kMomentum = 0.0;
1224 const int kNumIters = 4;
1225 const int kIterSize = 2;
1226 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1230 TYPED_TEST(RMSPropSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
1231 typedef typename TypeParam::Dtype Dtype;
1232 const Dtype kLearningRate = 0.01;
1233 const Dtype kWeightDecay = 0.5;
1234 const Dtype kMomentum = 0.0;
1235 const int kNumIters = 4;
1236 const int kIterSize = 2;
1237 this->share_ = true;
1238 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1242 TYPED_TEST(RMSPropSolverTest, TestSnapshot) {
1243 typedef typename TypeParam::Dtype Dtype;
1244 const Dtype kLearningRate = 0.01;
1245 const Dtype kWeightDecay = 0.5;
1246 const Dtype kMomentum = 0;
1247 const int kNumIters = 4;
1248 for (int i = 1; i <= kNumIters; ++i) {
1249 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1253 TYPED_TEST(RMSPropSolverTest, TestSnapshotShare) {
1254 typedef typename TypeParam::Dtype Dtype;
1255 const Dtype kLearningRate = 0.01;
1256 const Dtype kWeightDecay = 0.5;
1257 const Dtype kMomentum = 0;
1258 const int kNumIters = 4;
1259 this->share_ = true;
1260 for (int i = 1; i <= kNumIters; ++i) {
1261 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1265 } // namespace caffe