6 #include "google/protobuf/text_format.h"
8 #include "gtest/gtest.h"
10 #include "caffe/common.hpp"
11 #include "caffe/parallel.hpp"
12 #include "caffe/proto/caffe.pb.h"
13 #include "caffe/sgd_solvers.hpp"
14 #include "caffe/util/io.hpp"
16 #include "caffe/test/test_caffe_main.hpp"
18 using std::ostringstream;
// Base fixture for the gradient-based solver tests in this file. It derives
// from MultiDeviceTest<TypeParam> and takes Dtype from TypeParam. Concrete
// fixtures implement InitSolver to create a specific solver (SGD, AdaGrad,
// Nesterov, AdaDelta, ...) and store it in solver_.
22 template <typename TypeParam>
23 class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
24 typedef typename TypeParam::Dtype Dtype;
// Sets the RNG seed to 1701 and the expected data dimensions to
// num_ x channels_ x height_ x width_ = 4 x 3 x 10 x 10. Heap-allocates
// input_file_ with the path of solver_data_list.txt under CMAKE_SOURCE_DIR.
// NOTE(review): the initializer list ends in a trailing comma and the body's
// opening/closing lines are not visible in this copy; confirm against upstream.
27 GradientBasedSolverTest() :
28 seed_(1701), num_(4), channels_(3), height_(10), width_(10),
30 input_file_ = new string(
31 CMAKE_SOURCE_DIR "caffe/test/test_data/solver_data_list.txt" CMAKE_EXT);
// NOTE(review): the destructor body is not visible here. Presumably it
// releases the heap-allocated input_file_; TODO confirm.
33 ~GradientBasedSolverTest() {
// Directory prefix for snapshot files; filled via MakeTempDir in
// RunLeastSquaresSolver.
37 string snapshot_prefix_;
// Solver under test; created by the subclass's InitSolver.
38 shared_ptr<SGDSolver<Dtype> > solver_;
// Synchronizer used by the multi-device path of RunLeastSquaresSolver.
39 shared_ptr<P2PSync<Dtype> > sync_;
41 // Dimensions are determined by generate_sample_data.py
42 // TODO this is brittle and the hdf5 file should be checked instead.
43 int num_, channels_, height_, width_;
45 Dtype delta_; // Stability constant for RMSProp, AdaGrad, AdaDelta and Adam
47 // Test data: check out generate_sample_data.py in the same directory.
// Creates the concrete solver for `param`; implemented by each subclass.
50 virtual void InitSolver(const SolverParameter& param) = 0;
// Parses `proto` (protobuf text format) into a SolverParameter and aborts via
// CHECK if parsing fails. Then forces solver_mode to match the current
// Caffe::mode() (LOG(FATAL) on an unknown mode) and caches param.delta() in
// delta_, which ComputeLeastSquaresUpdate uses.
// NOTE(review): the `case`/`break` lines of the switch and the call that
// builds the solver (presumably InitSolver(param)) are missing from this copy.
// NOTE(review): `¶m` below looks like a mis-encoded `&param`.
52 virtual void InitSolverFromProtoString(const string& proto) {
53 SolverParameter param;
54 CHECK(google::protobuf::TextFormat::ParseFromString(proto, ¶m));
55 // Set the solver_mode according to current Caffe::mode.
56 switch (Caffe::mode()) {
58 param.set_solver_mode(SolverParameter_SolverMode_CPU);
61 param.set_solver_mode(SolverParameter_SolverMode_GPU);
64 LOG(FATAL) << "Unknown Caffe mode: " << Caffe::mode();
67 delta_ = param.delta();
// Builds a text-format solver definition for a small least-squares network and
// trains it.
//
// The solver proto sets snapshot_after_train, max_iter (= num_iters), base_lr,
// iter_size and device_id. The net has:
//   - a data source reading input_file_ with batch_size num_ / iter_size;
//   - an InnerProduct layer 'innerprod' whose params are named 'weights' and
//     'bias';
//   - a EuclideanLoss layer on 'innerprod' vs. 'targets'.
// When share_ is true, a second InnerProduct layer 'innerprod2' reuses the same
// param names, so both layers share one set of weights.
// weight_decay is emitted only when non-zero. momentum, snapshot_prefix (a
// fresh temp dir) and snapshot (= num_iters) are also appended; their guards
// are not visible here.
//
// The RNG is seeded with seed_ before the solver is built. If from_snapshot is
// non-NULL, solver state is restored from that file and the net is run forward
// iter() times (presumably to advance the data source to the restored
// position; TODO confirm). Training then runs through Solve(), or, for the
// multi-device path, through P2PSync over a GPU list that starts with the
// solver's own device_id.
//
// Returns the path "<snapshot_prefix_>/_iter_<num_iters>..." of the expected
// snapshot file. The file-name suffix is not visible in this copy.
// NOTE(review): several lines (the proto stream declaration, most of the net
// definition, the momentum/snapshot guards and the single- vs multi-device
// branch) are missing from this copy.
70 string RunLeastSquaresSolver(const Dtype learning_rate,
71 const Dtype weight_decay, const Dtype momentum, const int num_iters,
72 const int iter_size = 1, const int devices = 1,
73 const bool snapshot = false, const char* from_snapshot = NULL) {
77 if (Caffe::mode() == Caffe::GPU) {
78 CUDA_CHECK(cudaGetDevice(&device_id));
82 "snapshot_after_train: " << snapshot << " "
83 "max_iter: " << num_iters << " "
84 "base_lr: " << learning_rate << " "
86 "iter_size: " << iter_size << " "
87 "device_id: " << device_id << " "
89 " name: 'TestNetwork' "
94 " source: '" << *(this->input_file_) << "' "
95 " batch_size: " << num_ / iter_size << " "
115 " name: 'innerprod' "
116 " type: 'InnerProduct' "
117 " param { name: 'weights' } "
118 " param { name: 'bias' } "
119 " inner_product_param { "
130 " bottom: '" << string(share_ ? "data1": "data") << "' "
131 " top: '" << string(share_ ? "innerprod1": "innerprod") << "' "
136 " name: 'innerprod2' "
137 " type: 'InnerProduct' "
138 " param { name: 'weights' } "
139 " param { name: 'bias' } "
140 " inner_product_param { "
152 " top: 'innerprod2' "
157 " bottom: 'innerprod1' "
158 " bottom: 'innerprod2' "
168 " type: 'EuclideanLoss' "
169 " bottom: 'innerprod' "
170 " bottom: 'targets' "
173 if (weight_decay != 0) {
174 proto << "weight_decay: " << weight_decay << " ";
177 proto << "momentum: " << momentum << " ";
179 MakeTempDir(&snapshot_prefix_);
180 proto << "snapshot_prefix: '" << snapshot_prefix_ << "/' ";
182 proto << "snapshot: " << num_iters << " ";
184 Caffe::set_random_seed(this->seed_);
185 this->InitSolverFromProtoString(proto.str());
186 if (from_snapshot != NULL) {
187 this->solver_->Restore(from_snapshot);
188 for (int i = 0; i < this->solver_->iter(); ++i) {
189 this->solver_->net()->Forward();
193 this->solver_->Solve();
195 LOG(INFO) << "Multi-GPU test on " << devices << " devices";
197 // put current device at the beginning
198 int device_id = solver_->param().device_id();
199 gpus.push_back(device_id);
200 for (int i = 0; gpus.size() < devices; ++i) {
// The solver count is raised for the duration of the P2P run and reset to 1
// afterwards.
204 Caffe::set_solver_count(gpus.size());
205 this->sync_.reset(new P2PSync<Dtype>(
206 this->solver_, NULL, this->solver_->param()));
207 this->sync_->Run(gpus);
208 Caffe::set_solver_count(1);
211 ostringstream resume_file;
212 resume_file << snapshot_prefix_ << "/_iter_" << num_iters
214 string resume_filename = resume_file.str();
215 return resume_filename;
220 // Compute an update value given the current state of the train net,
221 // using the analytical formula for the least squares gradient.
222 // updated_params will store the updated weight and bias results,
223 // using the blobs' diffs to hold the update values themselves.
//
// Parameters:
//   learning_rate, weight_decay, momentum: hyperparameters mirrored from the
//     solver run.
//   num_iters: in this function it is used only for Adam's bias correction
//     terms (1 - momentum2^t) and (1 - momentum^t). Callers pass the 1-based
//     index of the update being computed.
//   updated_params: cleared and refilled with two blobs, [0] = weights and
//     [1] = bias.
// Index i == D in the loops below stands for the bias term, i.e. the implicit
// constant-1 column of X.
// NOTE(review): the declarations of N, grad and element, the `grad /= N`
// scaling, and the if/else lines that pick between the bias and weight
// branches are not visible in this copy.
224 void ComputeLeastSquaresUpdate(const Dtype learning_rate,
225 const Dtype weight_decay, const Dtype momentum, const int num_iters,
226 vector<shared_ptr<Blob<Dtype> > >* updated_params) {
228 const int D = channels_ * height_ * width_;
230 // Run a forward pass, and manually compute the update values from the
232 Net<Dtype>& net = *this->solver_->net();
234 ASSERT_TRUE(net.has_blob("data"));
235 const Blob<Dtype>& data = *net.blob_by_name("data");
236 ASSERT_TRUE(net.has_blob("targets"));
237 const Blob<Dtype>& targets = *net.blob_by_name("targets");
238 ASSERT_TRUE(net.has_layer("innerprod"));
239 const vector<shared_ptr<Blob<Dtype> > >& param_blobs =
240 net.layer_by_name("innerprod")->blobs();
241 const int num_param_blobs = 2;
242 ASSERT_EQ(num_param_blobs, param_blobs.size());
243 const Blob<Dtype>& weights = *param_blobs[0];
244 const Blob<Dtype>& bias = *param_blobs[1];
245 ASSERT_EQ(D * N, data.count());
246 ASSERT_EQ(N, targets.count());
247 ASSERT_EQ(D, weights.count());
248 ASSERT_EQ(1, bias.count());
250 updated_params->clear();
251 updated_params->resize(num_param_blobs);
252 for (int i = 0; i < num_param_blobs; ++i) {
253 (*updated_params)[i].reset(new Blob<Dtype>());
255 Blob<Dtype>& updated_weights = *(*updated_params)[0];
256 updated_weights.ReshapeLike(weights);
257 Blob<Dtype>& updated_bias = *(*updated_params)[1];
258 updated_bias.ReshapeLike(bias);
// Inclusive bound: iterations 0..D-1 are weights, iteration D is the bias.
260 for (int i = 0; i <= D; ++i) {
261 // Compute the derivative with respect to the ith weight (i.e., the ith
262 // element of the gradient).
264 for (int j = 0; j <= D; ++j) {
265 // Compute element (i, j) of X^T * X.
267 for (int k = 0; k < N; ++k) {
268 // (i, k) in X^T (== (k, i) in X) times (k, j) in X.
269 const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i];
270 const Dtype element_j = (j == D) ? 1 : data.cpu_data()[k * D + j];
271 element += element_i * element_j;
274 grad += element * bias.cpu_data()[0];
276 grad += element * weights.cpu_data()[j];
// Subtract (X^T y)_i.
279 for (int k = 0; k < N; ++k) {
280 const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i];
281 grad -= element_i * targets.cpu_data()[k];
283 // Scale the gradient over the N samples.
285 // Add the weight decay to the gradient.
286 grad += weight_decay *
287 ((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]);
288 // Finally, compute update.
// History layout assumed below: [0] weights, [1] bias; AdaDelta and Adam keep
// a second pair at [2], [3].
289 const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
290 if (solver_->type() != string("AdaDelta")
291 && solver_->type() != string("Adam")) {
292 ASSERT_EQ(2, history.size()); // 1 blob for weights, 1 for bias
294 ASSERT_EQ(4, history.size()); // additional blobs for update history
296 Dtype update_value = learning_rate * grad;
297 const Dtype history_value = (i == D) ?
298 history[1]->cpu_data()[0] : history[0]->cpu_data()[i];
299 const Dtype temp = momentum * history_value;
300 if (solver_->type() == string("SGD")) {
301 update_value += temp;
302 } else if (solver_->type() == string("Nesterov")) {
303 update_value += temp;
304 // step back then over-step
305 update_value = (1 + momentum) * update_value - temp;
306 } else if (solver_->type() == string("AdaGrad")) {
307 update_value /= std::sqrt(history_value + grad * grad) + delta_;
308 } else if (solver_->type() == string("RMSProp")) {
// NOTE(review): hard-coded decay; presumably it must match the solver's
// rms_decay setting. TODO confirm.
309 const Dtype rms_decay = 0.95;
310 update_value /= std::sqrt(rms_decay*history_value
311 + grad * grad * (1 - rms_decay)) + delta_;
312 } else if (solver_->type() == string("AdaDelta")) {
313 const Dtype update_history_value = (i == D) ?
314 history[1 + num_param_blobs]->cpu_data()[0] :
315 history[0 + num_param_blobs]->cpu_data()[i];
316 const Dtype weighted_gradient_average =
317 momentum * history_value + (1 - momentum) * (grad * grad);
318 update_value = grad * std::sqrt((update_history_value + delta_) /
319 (weighted_gradient_average + delta_)) * learning_rate;
320 // not actually needed, just here for illustrative purposes
321 // const Dtype weighted_update_average =
322 // momentum * update_history_value + (1 - momentum) * (update_value);
323 } else if (solver_->type() == string("Adam")) {
324 const Dtype momentum2 = 0.999;
325 const Dtype m = history_value;
326 const Dtype v = (i == D) ?
327 history[1 + num_param_blobs]->cpu_data()[0] :
328 history[0 + num_param_blobs]->cpu_data()[i];
329 const Dtype val_m = (1 - momentum) * grad + momentum * m;
330 const Dtype val_v = (1 - momentum2) * grad * grad + momentum2 * v;
// Bias-corrected step size for iteration num_iters.
331 Dtype alpha_t = learning_rate *
332 std::sqrt(Dtype(1) - pow(momentum2, num_iters)) /
333 (Dtype(1.) - pow(momentum, num_iters));
334 update_value = alpha_t * val_m / (std::sqrt(val_v) + delta_);
336 LOG(FATAL) << "Unknown solver type: " << solver_->type();
// Store the update in the diff and the updated parameter in the data.
339 updated_bias.mutable_cpu_diff()[0] = update_value;
340 updated_bias.mutable_cpu_data()[0] = bias.cpu_data()[0] - update_value;
342 updated_weights.mutable_cpu_diff()[i] = update_value;
343 updated_weights.mutable_cpu_data()[i] =
344 weights.cpu_data()[i] - update_value;
// Compares the solver's current 'innerprod' weights and bias with the expected
// values in updated_params ([0] = weights, [1] = bias; data holds the updated
// parameter, diff holds the update). The tolerance is
// max(1e-7, 1e-2 * min(|expected|, |actual|)): relative to the smaller
// magnitude, with an absolute floor. For the "SGD" solver it also checks that
// the history blobs hold the last update values (the expected diffs).
349 void CheckLeastSquaresUpdate(
350 const vector<shared_ptr<Blob<Dtype> > >& updated_params) {
351 const int D = channels_ * height_ * width_;
353 const Blob<Dtype>& updated_weights = *updated_params[0];
354 const Blob<Dtype>& updated_bias = *updated_params[1];
356 Net<Dtype>& net = *this->solver_->net();
357 ASSERT_TRUE(net.has_layer("innerprod"));
358 const vector<shared_ptr<Blob<Dtype> > >& param_blobs =
359 net.layer_by_name("innerprod")->blobs();
360 ASSERT_EQ(2, param_blobs.size());
361 const Blob<Dtype>& solver_updated_weights = *param_blobs[0];
362 ASSERT_EQ(D, solver_updated_weights.count());
363 const double kPrecision = 1e-2;
364 const double kMinPrecision = 1e-7;
365 for (int i = 0; i < D; ++i) {
366 const Dtype expected_updated_weight = updated_weights.cpu_data()[i];
367 const Dtype solver_updated_weight = solver_updated_weights.cpu_data()[i];
368 const Dtype error_margin = std::max(kMinPrecision, kPrecision *
369 std::min(fabs(expected_updated_weight), fabs(solver_updated_weight)));
370 EXPECT_NEAR(expected_updated_weight, solver_updated_weight, error_margin);
372 const Blob<Dtype>& solver_updated_bias_blob = *param_blobs[1];
373 ASSERT_EQ(1, solver_updated_bias_blob.count());
374 const Dtype expected_updated_bias = updated_bias.cpu_data()[0];
375 const Dtype solver_updated_bias = solver_updated_bias_blob.cpu_data()[0];
376 const Dtype error_margin = std::max(kMinPrecision, kPrecision *
377 std::min(fabs(expected_updated_bias), fabs(solver_updated_bias)));
378 EXPECT_NEAR(expected_updated_bias, solver_updated_bias, error_margin);
380 // Check the solver's history -- should contain the previous update value.
381 if (solver_->type() == string("SGD")) {
382 const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
383 ASSERT_EQ(2, history.size());
384 for (int i = 0; i < D; ++i) {
385 const Dtype expected_history = updated_weights.cpu_diff()[i];
386 const Dtype solver_history = history[0]->cpu_data()[i];
387 const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision *
388 std::min(fabs(expected_history), fabs(solver_history)));
389 EXPECT_NEAR(expected_history, solver_history, error_margin_hist);
391 const Dtype expected_history = updated_bias.cpu_diff()[0];
392 const Dtype solver_history = history[1]->cpu_data()[0];
393 const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision *
394 std::min(fabs(expected_history), fabs(solver_history)));
395 EXPECT_NEAR(expected_history, solver_history, error_margin_hist);
// Verifies gradient accumulation.
// 1. Trains once (the "no accumulation" reference) and snapshots the
//    'innerprod' params.
// 2. Trains again with iter_size = kIterSize, i.e. smaller sub-batches whose
//    gradients are accumulated.
// 3. Compares every weight and the bias with the same relative tolerance as
//    CheckLeastSquaresUpdate.
// NOTE(review): the remaining arguments of the first RunLeastSquaresSolver
// call are not visible in this copy.
399 void CheckAccumulation(const Dtype kLearningRate, const Dtype kWeightDecay,
400 const Dtype kMomentum, const int kNumIters, const int kIterSize) {
401 const double kPrecision = 1e-2;
402 const double kMinPrecision = 1e-7;
403 // Solve without accumulation and save parameters.
404 this->RunLeastSquaresSolver(kLearningRate, kWeightDecay, kMomentum,
406 // Save parameters for comparison.
407 Net<Dtype>& net = *this->solver_->net();
408 const vector<shared_ptr<Blob<Dtype> > >& param_blobs =
409 net.layer_by_name("innerprod")->blobs();
410 vector<shared_ptr<Blob<Dtype> > > noaccum_params(param_blobs.size());
411 for (int i = 0; i < param_blobs.size(); ++i) {
412 noaccum_params[i].reset(new Blob<Dtype>());
// Copy data only (copy_diff = false), reshaping the destination to match.
413 noaccum_params[i]->CopyFrom(*param_blobs[i], false, true);
415 // Solve by equivalent accumulation of gradients over divided batches.
416 this->RunLeastSquaresSolver(kLearningRate, kWeightDecay, kMomentum,
417 kNumIters, kIterSize);
418 Net<Dtype>& net_accum = *this->solver_->net();
419 const vector<shared_ptr<Blob<Dtype> > >& accum_params =
420 net_accum.layer_by_name("innerprod")->blobs();
421 // Compare accumulated parameters against no accumulation standard.
422 const int D = this->channels_ * this->height_ * this->width_;
423 for (int i = 0; i < D; ++i) {
424 const Dtype expected_param = noaccum_params[0]->cpu_data()[i];
425 const Dtype accum_param = accum_params[0]->cpu_data()[i];
426 const Dtype error_margin = std::max(kMinPrecision, kPrecision *
427 std::min(fabs(expected_param), fabs(accum_param)));
428 EXPECT_NEAR(expected_param, accum_param, error_margin);
430 ASSERT_EQ(1, accum_params[1]->count());
431 const Dtype expected_bias = noaccum_params[1]->cpu_data()[0];
432 const Dtype accum_bias = accum_params[1]->cpu_data()[0];
433 const Dtype error_margin = std::max(kMinPrecision, kPrecision *
434 std::min(fabs(expected_bias), fabs(accum_bias)));
435 EXPECT_NEAR(expected_bias, accum_bias, error_margin);
438 // Test that the correct update is computed for a regularized least squares
441 // E = (1/(2n)) || X w - y ||^2 + (lambda / 2) || w ||^2
442 // \nabla_w E = (1/n) (X^T X w - X^T y) + lambda * w
444 // X \in R^{n x (d+1)} (each example is a row, (d+1)th element is always 1)
445 // w \in R^{(d+1) x 1} ((d+1)th element is the bias)
447 // lambda is weight_decay
449 // TestLeastSquaresUpdate works "inductively", assuming that the solver
450 // correctly updates the net K (= iter_to_check) times, then given the history
451 // from the Kth update, we compute the (K+1)th update and check that it
452 // matches the solver's (K+1)th update.
//
// The check is repeated for every device count from 1 to available_devices.
// available_devices stays 1 unless Caffe is in GPU mode, where it is queried
// via cudaGetDeviceCount. num_ is scaled by the device count for the reference
// single-device run.
// NOTE(review): the statement between the "Reinitialize" comment and the
// second RunLeastSquaresSolver call is not visible in this copy; confirm
// whether num_ is reset there before the multi-device run.
453 void TestLeastSquaresUpdate(const Dtype learning_rate = 1.0,
454 const Dtype weight_decay = 0.0, const Dtype momentum = 0.0,
455 const int iter_to_check = 0) {
456 const int kNum = num_;
457 const int kIterSize = 1;
458 // Test over all numbers of devices.
459 int available_devices = 1;
461 if (Caffe::mode() == Caffe::GPU) {
462 CUDA_CHECK(cudaGetDeviceCount(&available_devices));
465 for (int devices = 1; devices <= available_devices; ++devices) {
466 // Configure batch size for single / multi device equivalence.
467 // Constant data is needed for multi device as for accumulation.
468 num_ = kNum * devices;
470 // Initialize the solver and run K (= iter_to_check) solver iterations
471 // (on single device).
472 RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
473 iter_to_check, kIterSize, 1);
475 // Compute the (K+1)th update using the analytic least squares gradient.
476 vector<shared_ptr<Blob<Dtype> > > updated_params;
477 ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum,
478 iter_to_check + 1, &updated_params);
480 // Reinitialize the solver and run K+1 solver iterations.
482 RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
483 iter_to_check + 1, kIterSize, devices);
485 // Check that the solver's solution matches ours.
486 CheckLeastSquaresUpdate(updated_params);
// Verifies snapshot/restore.
// 1. Runs 2 * num_iters iterations uninterrupted and saves copies (data and
//    diff) of every learnable param and every history blob.
// 2. Runs num_iters iterations, which produces a snapshot.
// 3. Restores from that snapshot and continues to 2 * num_iters.
// 4. Requires the params and history to match the uninterrupted run exactly
//    (EXPECT_EQ, no tolerance).
// NOTE(review): the line after "Run the solver for num_iters iterations and
// snapshot." is not visible here; presumably it sets snapshot = true, which is
// why `snapshot` is not const. TODO confirm.
490 void TestSnapshot(const Dtype learning_rate = 1.0,
491 const Dtype weight_decay = 0.0, const Dtype momentum = 0.0,
492 const int num_iters = 1) {
493 // Run the solver for num_iters * 2 iterations.
494 const int total_num_iters = num_iters * 2;
495 bool snapshot = false;
496 const int kIterSize = 1;
497 const int kDevices = 1;
498 RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
499 total_num_iters, kIterSize, kDevices, snapshot);
501 // Save the resulting param values.
502 vector<shared_ptr<Blob<Dtype> > > param_copies;
503 const vector<Blob<Dtype>*>& orig_params =
504 solver_->net()->learnable_params();
505 param_copies.resize(orig_params.size());
506 for (int i = 0; i < orig_params.size(); ++i) {
507 param_copies[i].reset(new Blob<Dtype>());
508 const bool kReshape = true;
// Two passes: copy_diff = false copies data, copy_diff = true copies diff.
509 for (int copy_diff = false; copy_diff <= true; ++copy_diff) {
510 param_copies[i]->CopyFrom(*orig_params[i], copy_diff, kReshape);
514 // Save the solver history
515 vector<shared_ptr<Blob<Dtype> > > history_copies;
516 const vector<shared_ptr<Blob<Dtype> > >& orig_history = solver_->history();
517 history_copies.resize(orig_history.size());
518 for (int i = 0; i < orig_history.size(); ++i) {
519 history_copies[i].reset(new Blob<Dtype>());
520 const bool kReshape = true;
521 for (int copy_diff = false; copy_diff <= true; ++copy_diff) {
522 history_copies[i]->CopyFrom(*orig_history[i], copy_diff, kReshape);
526 // Run the solver for num_iters iterations and snapshot.
528 string snapshot_name = RunLeastSquaresSolver(learning_rate, weight_decay,
529 momentum, num_iters, kIterSize, kDevices, snapshot);
531 // Reinitialize the solver and run for num_iters more iterations.
533 RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
534 total_num_iters, kIterSize, kDevices,
535 snapshot, snapshot_name.c_str());
537 // Check that params now match.
538 const vector<Blob<Dtype>*>& params = solver_->net()->learnable_params();
539 for (int i = 0; i < params.size(); ++i) {
540 for (int j = 0; j < params[i]->count(); ++j) {
541 EXPECT_EQ(param_copies[i]->cpu_data()[j], params[i]->cpu_data()[j])
542 << "param " << i << " data differed at dim " << j;
543 EXPECT_EQ(param_copies[i]->cpu_diff()[j], params[i]->cpu_diff()[j])
544 << "param " << i << " diff differed at dim " << j;
548 // Check that history now matches.
549 const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
550 for (int i = 0; i < history.size(); ++i) {
551 for (int j = 0; j < history[i]->count(); ++j) {
552 EXPECT_EQ(history_copies[i]->cpu_data()[j], history[i]->cpu_data()[j])
553 << "history blob " << i << " data differed at dim " << j;
554 EXPECT_EQ(history_copies[i]->cpu_diff()[j], history[i]->cpu_diff()[j])
555 << "history blob " << i << " diff differed at dim " << j;
562 template <typename TypeParam>
563 class SGDSolverTest : public GradientBasedSolverTest<TypeParam> {
564 typedef typename TypeParam::Dtype Dtype;
567 virtual void InitSolver(const SolverParameter& param) {
568 this->solver_.reset(new SGDSolver<Dtype>(param));
572 TYPED_TEST_CASE(SGDSolverTest, TestDtypesAndDevices);
574 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdate) {
575 this->TestLeastSquaresUpdate();
578 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateLROneHundredth) {
579 typedef typename TypeParam::Dtype Dtype;
580 const Dtype kLearningRate = 0.01;
581 this->TestLeastSquaresUpdate(kLearningRate);
584 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithWeightDecay) {
585 typedef typename TypeParam::Dtype Dtype;
586 const Dtype kLearningRate = 0.01;
587 const Dtype kWeightDecay = 0.5;
588 const Dtype kMomentum = 0;
589 const int kNumIters = 1;
590 for (int i = 0; i <= kNumIters; ++i) {
591 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
595 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithWeightDecayMultiIter) {
596 typedef typename TypeParam::Dtype Dtype;
597 const Dtype kLearningRate = 0.01;
598 const Dtype kWeightDecay = 0.5;
599 const Dtype kMomentum = 0;
600 const int kNumIters = 4;
601 for (int i = 0; i <= kNumIters; ++i) {
602 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
606 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithMomentum) {
607 typedef typename TypeParam::Dtype Dtype;
608 const Dtype kLearningRate = 0.01;
609 const Dtype kWeightDecay = 0;
610 const Dtype kMomentum = 0.5;
611 const int kNumIters = 1;
612 for (int i = 0; i <= kNumIters; ++i) {
613 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
617 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
618 typedef typename TypeParam::Dtype Dtype;
619 const Dtype kLearningRate = 0.01;
620 const Dtype kWeightDecay = 0;
621 const Dtype kMomentum = 0.5;
622 const int kNumIters = 4;
623 for (int i = 0; i <= kNumIters; ++i) {
624 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
628 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverything) {
629 typedef typename TypeParam::Dtype Dtype;
630 const Dtype kLearningRate = 0.01;
631 const Dtype kWeightDecay = 0.5;
632 const Dtype kMomentum = 0.5;
633 const int kNumIters = 4;
634 for (int i = 0; i <= kNumIters; ++i) {
635 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
639 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingShare) {
640 typedef typename TypeParam::Dtype Dtype;
641 const Dtype kLearningRate = 0.01;
642 const Dtype kWeightDecay = 0.5;
643 const Dtype kMomentum = 0.5;
644 const int kNumIters = 4;
646 for (int i = 0; i <= kNumIters; ++i) {
647 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
651 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
652 typedef typename TypeParam::Dtype Dtype;
653 const Dtype kLearningRate = 0.01;
654 const Dtype kWeightDecay = 0.5;
655 const Dtype kMomentum = 0.9;
656 const int kNumIters = 4;
657 const int kIterSize = 2;
658 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
662 TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
663 typedef typename TypeParam::Dtype Dtype;
664 const Dtype kLearningRate = 0.01;
665 const Dtype kWeightDecay = 0.5;
666 const Dtype kMomentum = 0.9;
667 const int kNumIters = 4;
668 const int kIterSize = 2;
670 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
674 TYPED_TEST(SGDSolverTest, TestSnapshot) {
675 typedef typename TypeParam::Dtype Dtype;
676 const Dtype kLearningRate = 0.01;
677 const Dtype kWeightDecay = 0.5;
678 const Dtype kMomentum = 0.9;
679 const int kNumIters = 4;
680 for (int i = 1; i <= kNumIters; ++i) {
681 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
685 TYPED_TEST(SGDSolverTest, TestSnapshotShare) {
686 typedef typename TypeParam::Dtype Dtype;
687 const Dtype kLearningRate = 0.01;
688 const Dtype kWeightDecay = 0.5;
689 const Dtype kMomentum = 0.9;
690 const int kNumIters = 4;
692 for (int i = 1; i <= kNumIters; ++i) {
693 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
697 TYPED_TEST(SGDSolverTest, TestSolverType) {
698 this->TestLeastSquaresUpdate();
699 EXPECT_NE(this->solver_->type(), string(""));
700 EXPECT_EQ(this->solver_->type(), this->solver_->param().type());
703 template <typename TypeParam>
704 class AdaGradSolverTest : public GradientBasedSolverTest<TypeParam> {
705 typedef typename TypeParam::Dtype Dtype;
708 virtual void InitSolver(const SolverParameter& param) {
709 this->solver_.reset(new AdaGradSolver<Dtype>(param));
713 TYPED_TEST_CASE(AdaGradSolverTest, TestDtypesAndDevices);
715 TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdate) {
716 this->TestLeastSquaresUpdate();
719 TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateLROneHundredth) {
720 typedef typename TypeParam::Dtype Dtype;
721 const Dtype kLearningRate = 0.01;
722 this->TestLeastSquaresUpdate(kLearningRate);
725 TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithWeightDecay) {
726 typedef typename TypeParam::Dtype Dtype;
727 const Dtype kLearningRate = 0.01;
728 const Dtype kWeightDecay = 0.5;
729 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay);
732 TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithEverything) {
733 typedef typename TypeParam::Dtype Dtype;
734 const Dtype kLearningRate = 0.01;
735 const Dtype kWeightDecay = 0.5;
736 const Dtype kMomentum = 0;
737 const int kNumIters = 4;
738 for (int i = 0; i <= kNumIters; ++i) {
739 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
743 TYPED_TEST(AdaGradSolverTest,
744 TestAdaGradLeastSquaresUpdateWithEverythingShare) {
745 typedef typename TypeParam::Dtype Dtype;
746 const Dtype kLearningRate = 0.01;
747 const Dtype kWeightDecay = 0.5;
748 const Dtype kMomentum = 0;
749 const int kNumIters = 4;
751 for (int i = 0; i <= kNumIters; ++i) {
752 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
756 TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
757 typedef typename TypeParam::Dtype Dtype;
758 const Dtype kLearningRate = 0.01;
759 const Dtype kWeightDecay = 0.5;
760 const Dtype kMomentum = 0;
761 const int kNumIters = 4;
762 const int kIterSize = 2;
763 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
767 TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
768 typedef typename TypeParam::Dtype Dtype;
769 const Dtype kLearningRate = 0.01;
770 const Dtype kWeightDecay = 0.5;
771 const Dtype kMomentum = 0;
772 const int kNumIters = 4;
773 const int kIterSize = 2;
775 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
779 TYPED_TEST(AdaGradSolverTest, TestSnapshot) {
780 typedef typename TypeParam::Dtype Dtype;
781 const Dtype kLearningRate = 0.01;
782 const Dtype kWeightDecay = 0.5;
783 const Dtype kMomentum = 0;
784 const int kNumIters = 4;
785 for (int i = 1; i <= kNumIters; ++i) {
786 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
790 TYPED_TEST(AdaGradSolverTest, TestSnapshotShare) {
791 typedef typename TypeParam::Dtype Dtype;
792 const Dtype kLearningRate = 0.01;
793 const Dtype kWeightDecay = 0.5;
794 const Dtype kMomentum = 0;
795 const int kNumIters = 4;
797 for (int i = 1; i <= kNumIters; ++i) {
798 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
803 template <typename TypeParam>
804 class NesterovSolverTest : public GradientBasedSolverTest<TypeParam> {
805 typedef typename TypeParam::Dtype Dtype;
808 virtual void InitSolver(const SolverParameter& param) {
809 this->solver_.reset(new NesterovSolver<Dtype>(param));
813 TYPED_TEST_CASE(NesterovSolverTest, TestDtypesAndDevices);
815 TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdate) {
816 this->TestLeastSquaresUpdate();
819 TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateLROneHundredth) {
820 typedef typename TypeParam::Dtype Dtype;
821 const Dtype kLearningRate = 0.01;
822 this->TestLeastSquaresUpdate(kLearningRate);
825 TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithWeightDecay) {
826 typedef typename TypeParam::Dtype Dtype;
827 const Dtype kLearningRate = 0.01;
828 const Dtype kWeightDecay = 0.5;
829 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay);
832 TYPED_TEST(NesterovSolverTest,
833 TestNesterovLeastSquaresUpdateWithWeightDecayMultiIter) {
834 typedef typename TypeParam::Dtype Dtype;
835 const Dtype kLearningRate = 0.01;
836 const Dtype kWeightDecay = 0.5;
837 const Dtype kMomentum = 0;
838 const int kNumIters = 4;
839 for (int i = 0; i <= kNumIters; ++i) {
840 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
844 TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithMomentum) {
845 typedef typename TypeParam::Dtype Dtype;
846 const Dtype kLearningRate = 0.01;
847 const Dtype kWeightDecay = 0;
848 const Dtype kMomentum = 0.5;
849 const int kNumIters = 1;
850 for (int i = 0; i <= kNumIters; ++i) {
851 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
855 TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
856 typedef typename TypeParam::Dtype Dtype;
857 const Dtype kLearningRate = 0.01;
858 const Dtype kWeightDecay = 0;
859 const Dtype kMomentum = 0.5;
860 const int kNumIters = 4;
861 for (int i = 0; i <= kNumIters; ++i) {
862 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
866 TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithEverything) {
867 typedef typename TypeParam::Dtype Dtype;
868 const Dtype kLearningRate = 0.01;
869 const Dtype kWeightDecay = 0.5;
870 const Dtype kMomentum = 0.9;
871 const int kNumIters = 4;
872 for (int i = 0; i <= kNumIters; ++i) {
873 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
877 TYPED_TEST(NesterovSolverTest,
878 TestNesterovLeastSquaresUpdateWithEverythingShare) {
879 typedef typename TypeParam::Dtype Dtype;
880 const Dtype kLearningRate = 0.01;
881 const Dtype kWeightDecay = 0.5;
882 const Dtype kMomentum = 0.9;
883 const int kNumIters = 4;
885 for (int i = 0; i <= kNumIters; ++i) {
886 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
890 TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
891 typedef typename TypeParam::Dtype Dtype;
892 const Dtype kLearningRate = 0.01;
893 const Dtype kWeightDecay = 0.5;
894 const Dtype kMomentum = 0.9;
895 const int kNumIters = 4;
896 const int kIterSize = 2;
897 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
901 TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
902 typedef typename TypeParam::Dtype Dtype;
903 const Dtype kLearningRate = 0.01;
904 const Dtype kWeightDecay = 0.5;
905 const Dtype kMomentum = 0.9;
906 const int kNumIters = 4;
907 const int kIterSize = 2;
909 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
913 TYPED_TEST(NesterovSolverTest, TestSnapshot) {
914 typedef typename TypeParam::Dtype Dtype;
915 const Dtype kLearningRate = 0.01;
916 const Dtype kWeightDecay = 0.5;
917 const Dtype kMomentum = 0.9;
918 const int kNumIters = 4;
919 for (int i = 1; i <= kNumIters; ++i) {
920 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
924 TYPED_TEST(NesterovSolverTest, TestSnapshotShare) {
925 typedef typename TypeParam::Dtype Dtype;
926 const Dtype kLearningRate = 0.01;
927 const Dtype kWeightDecay = 0.5;
928 const Dtype kMomentum = 0.9;
929 const int kNumIters = 4;
931 for (int i = 1; i <= kNumIters; ++i) {
932 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
936 template <typename TypeParam>
937 class AdaDeltaSolverTest : public GradientBasedSolverTest<TypeParam> {
938 typedef typename TypeParam::Dtype Dtype;
941 virtual void InitSolver(const SolverParameter& param) {
942 this->solver_.reset(new AdaDeltaSolver<Dtype>(param));
946 TYPED_TEST_CASE(AdaDeltaSolverTest, TestDtypesAndDevices);
948 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdate) {
949 typedef typename TypeParam::Dtype Dtype;
950 const Dtype kLearningRate = 0.1;
951 this->TestLeastSquaresUpdate(kLearningRate);
954 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithWeightDecay) {
955 typedef typename TypeParam::Dtype Dtype;
956 const Dtype kLearningRate = 0.1;
957 const Dtype kWeightDecay = 0.5;
958 const Dtype kMomentum = 0.95;
959 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
962 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithHalfMomentum) {
963 typedef typename TypeParam::Dtype Dtype;
964 const Dtype kLearningRate = 0.1;
965 const Dtype kWeightDecay = 0.0;
966 const Dtype kMomentum = 0.5;
967 const int kNumIters = 1;
968 for (int i = 0; i <= kNumIters; ++i) {
969 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
973 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithMomentum) {
974 typedef typename TypeParam::Dtype Dtype;
975 const Dtype kLearningRate = 0.1;
976 const Dtype kWeightDecay = 0.0;
977 const Dtype kMomentum = 0.95;
978 const int kNumIters = 1;
979 for (int i = 0; i <= kNumIters; ++i) {
980 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
984 TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
985 typedef typename TypeParam::Dtype Dtype;
986 const Dtype kLearningRate = 0.1;
987 const Dtype kWeightDecay = 0.0;
988 const Dtype kMomentum = 0.95;
989 const int kNumIters = 4;
990 for (int i = 0; i <= kNumIters; ++i) {
991 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
995 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverything) {
996 typedef typename TypeParam::Dtype Dtype;
997 const Dtype kLearningRate = 0.1;
998 const Dtype kWeightDecay = 0.1;
999 const Dtype kMomentum = 0.95;
1000 const int kNumIters = 4;
1001 for (int i = 0; i <= kNumIters; ++i) {
1002 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1006 TYPED_TEST(AdaDeltaSolverTest,
1007 TestAdaDeltaLeastSquaresUpdateWithEverythingShare) {
1008 typedef typename TypeParam::Dtype Dtype;
1009 const Dtype kLearningRate = 0.1;
1010 const Dtype kWeightDecay = 0.1;
1011 const Dtype kMomentum = 0.95;
1012 const int kNumIters = 4;
1013 this->share_ = true;
1014 for (int i = 0; i <= kNumIters; ++i) {
1015 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1019 TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
1020 typedef typename TypeParam::Dtype Dtype;
1021 const Dtype kLearningRate = 0.1;
1022 const Dtype kWeightDecay = 0.1;
1023 const Dtype kMomentum = 0.95;
1024 const int kNumIters = 4;
1025 const int kIterSize = 2;
1026 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1030 TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
1031 typedef typename TypeParam::Dtype Dtype;
1032 const Dtype kLearningRate = 0.1;
1033 const Dtype kWeightDecay = 0.1;
1034 const Dtype kMomentum = 0.95;
1035 const int kNumIters = 4;
1036 const int kIterSize = 2;
1037 this->share_ = true;
1038 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1042 TYPED_TEST(AdaDeltaSolverTest, TestSnapshot) {
1043 typedef typename TypeParam::Dtype Dtype;
1044 const Dtype kLearningRate = 0.1;
1045 const Dtype kWeightDecay = 0.1;
1046 const Dtype kMomentum = 0.95;
1047 const int kNumIters = 4;
1048 for (int i = 1; i <= kNumIters; ++i) {
1049 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1053 TYPED_TEST(AdaDeltaSolverTest, TestSnapshotShare) {
1054 typedef typename TypeParam::Dtype Dtype;
1055 const Dtype kLearningRate = 0.1;
1056 const Dtype kWeightDecay = 0.1;
1057 const Dtype kMomentum = 0.95;
1058 const int kNumIters = 4;
1059 this->share_ = true;
1060 for (int i = 1; i <= kNumIters; ++i) {
1061 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1065 template <typename TypeParam>
1066 class AdamSolverTest : public GradientBasedSolverTest<TypeParam> {
1067 typedef typename TypeParam::Dtype Dtype;
1070 virtual void InitSolver(const SolverParameter& param) {
1071 SolverParameter new_param = param;
1072 const Dtype momentum = 0.9;
1073 new_param.set_momentum(momentum);
1074 const Dtype momentum2 = 0.999;
1075 new_param.set_momentum2(momentum2);
1076 this->solver_.reset(new AdamSolver<Dtype>(new_param));
1080 TYPED_TEST_CASE(AdamSolverTest, TestDtypesAndDevices);
1082 TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdate) {
1083 typedef typename TypeParam::Dtype Dtype;
1084 const Dtype kLearningRate = 0.01;
1085 const Dtype kWeightDecay = 0;
1086 const Dtype kMomentum = 0.9;
1087 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
1090 TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdateWithWeightDecay) {
1091 typedef typename TypeParam::Dtype Dtype;
1092 const Dtype kLearningRate = 0.01;
1093 const Dtype kWeightDecay = 0.5;
1094 const Dtype kMomentum = 0.9;
1095 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
1098 TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdateWithEverything) {
1099 typedef typename TypeParam::Dtype Dtype;
1100 const Dtype kLearningRate = 0.01;
1101 const Dtype kWeightDecay = 0.5;
1102 const Dtype kMomentum = 0.9;
1103 const int kNumIters = 4;
1104 for (int i = 0; i <= kNumIters; ++i) {
1105 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1109 TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdateWithEverythingShare) {
1110 typedef typename TypeParam::Dtype Dtype;
1111 const Dtype kLearningRate = 0.01;
1112 const Dtype kWeightDecay = 0.5;
1113 const Dtype kMomentum = 0.9;
1114 const int kNumIters = 4;
1115 this->share_ = true;
1116 for (int i = 0; i <= kNumIters; ++i) {
1117 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1121 TYPED_TEST(AdamSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
1122 typedef typename TypeParam::Dtype Dtype;
1123 const Dtype kLearningRate = 0.01;
1124 const Dtype kWeightDecay = 0.5;
1125 const Dtype kMomentum = 0.9;
1126 const int kNumIters = 4;
1127 const int kIterSize = 2;
1128 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1132 TYPED_TEST(AdamSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
1133 typedef typename TypeParam::Dtype Dtype;
1134 const Dtype kLearningRate = 0.01;
1135 const Dtype kWeightDecay = 0.5;
1136 const Dtype kMomentum = 0.9;
1137 const int kNumIters = 4;
1138 const int kIterSize = 2;
1139 this->share_ = true;
1140 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1144 TYPED_TEST(AdamSolverTest, TestSnapshot) {
1145 typedef typename TypeParam::Dtype Dtype;
1146 const Dtype kLearningRate = 0.01;
1147 const Dtype kWeightDecay = 0.5;
1148 const Dtype kMomentum = 0.9;
1149 const int kNumIters = 4;
1150 for (int i = 1; i <= kNumIters; ++i) {
1151 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1155 TYPED_TEST(AdamSolverTest, TestSnapshotShare) {
1156 typedef typename TypeParam::Dtype Dtype;
1157 const Dtype kLearningRate = 0.01;
1158 const Dtype kWeightDecay = 0.5;
1159 const Dtype kMomentum = 0.9;
1160 const int kNumIters = 4;
1161 this->share_ = true;
1162 for (int i = 1; i <= kNumIters; ++i) {
1163 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1167 template <typename TypeParam>
1168 class RMSPropSolverTest : public GradientBasedSolverTest<TypeParam> {
1169 typedef typename TypeParam::Dtype Dtype;
1172 virtual void InitSolver(const SolverParameter& param) {
1173 const Dtype rms_decay = 0.95;
1174 SolverParameter new_param = param;
1175 new_param.set_rms_decay(rms_decay);
1176 this->solver_.reset(new RMSPropSolver<Dtype>(new_param));
1180 TYPED_TEST_CASE(RMSPropSolverTest, TestDtypesAndDevices);
1182 TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithWeightDecay) {
1183 typedef typename TypeParam::Dtype Dtype;
1184 const Dtype kLearningRate = 1.0;
1185 const Dtype kWeightDecay = 0.5;
1186 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay);
1189 TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithRmsDecay) {
1190 typedef typename TypeParam::Dtype Dtype;
1191 const Dtype kLearningRate = 0.01;
1192 const Dtype kWeightDecay = 0.0;
1193 const Dtype kMomentum = 0.0;
1194 const int kNumIters = 4;
1195 for (int i = 0; i <= kNumIters; ++i) {
1196 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1200 TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithEverything) {
1201 typedef typename TypeParam::Dtype Dtype;
1202 const Dtype kLearningRate = 0.01;
1203 const Dtype kWeightDecay = 0.5;
1204 const Dtype kMomentum = 0.0;
1205 const int kNumIters = 4;
1206 for (int i = 0; i <= kNumIters; ++i) {
1207 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1211 TYPED_TEST(RMSPropSolverTest,
1212 TestRMSPropLeastSquaresUpdateWithEverythingShare) {
1213 typedef typename TypeParam::Dtype Dtype;
1214 const Dtype kLearningRate = 0.01;
1215 const Dtype kWeightDecay = 0.5;
1216 const Dtype kMomentum = 0.0;
1217 const int kNumIters = 4;
1218 this->share_ = true;
1219 for (int i = 0; i <= kNumIters; ++i) {
1220 this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
1224 TYPED_TEST(RMSPropSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
1225 typedef typename TypeParam::Dtype Dtype;
1226 const Dtype kLearningRate = 0.01;
1227 const Dtype kWeightDecay = 0.5;
1228 const Dtype kMomentum = 0.0;
1229 const int kNumIters = 4;
1230 const int kIterSize = 2;
1231 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1235 TYPED_TEST(RMSPropSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
1236 typedef typename TypeParam::Dtype Dtype;
1237 const Dtype kLearningRate = 0.01;
1238 const Dtype kWeightDecay = 0.5;
1239 const Dtype kMomentum = 0.0;
1240 const int kNumIters = 4;
1241 const int kIterSize = 2;
1242 this->share_ = true;
1243 this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
1247 TYPED_TEST(RMSPropSolverTest, TestSnapshot) {
1248 typedef typename TypeParam::Dtype Dtype;
1249 const Dtype kLearningRate = 0.01;
1250 const Dtype kWeightDecay = 0.5;
1251 const Dtype kMomentum = 0;
1252 const int kNumIters = 4;
1253 for (int i = 1; i <= kNumIters; ++i) {
1254 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1258 TYPED_TEST(RMSPropSolverTest, TestSnapshotShare) {
1259 typedef typename TypeParam::Dtype Dtype;
1260 const Dtype kLearningRate = 0.01;
1261 const Dtype kWeightDecay = 0.5;
1262 const Dtype kMomentum = 0;
1263 const int kNumIters = 4;
1264 this->share_ = true;
1265 for (int i = 1; i <= kNumIters; ++i) {
1266 this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
1270 } // namespace caffe