From 4dcbe9287d334a06049bc95af37143fc73f15536 Mon Sep 17 00:00:00 2001 From: Ronghang Hu Date: Sun, 9 Aug 2015 01:32:25 -0700 Subject: [PATCH] Encapsulate kRMSDecay in solver tests Instead of introducing another argument kRMSDecay and setting it for every test, this param could be set by the RMSProp test class for encapsulation. --- src/caffe/test/test_gradient_based_solver.cpp | 156 ++++++++++---------------- 1 file changed, 58 insertions(+), 98 deletions(-) diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp index b091892..e3c6b6d 100644 --- a/src/caffe/test/test_gradient_based_solver.cpp +++ b/src/caffe/test/test_gradient_based_solver.cpp @@ -58,8 +58,8 @@ class GradientBasedSolverTest : public MultiDeviceTest { } string RunLeastSquaresSolver(const Dtype learning_rate, - const Dtype weight_decay, const Dtype momentum, const Dtype rms_decay, - const int num_iters, const int iter_size = 1, const bool snapshot = false, + const Dtype weight_decay, const Dtype momentum, const int num_iters, + const int iter_size = 1, const bool snapshot = false, const char* from_snapshot = NULL) { ostringstream proto; proto << @@ -174,9 +174,6 @@ class GradientBasedSolverTest : public MultiDeviceTest { if (momentum != 0) { proto << "momentum: " << momentum << " "; } - if (rms_decay != 0) { - proto << "rms_decay: " << rms_decay << " "; - } MakeTempDir(&snapshot_prefix_); proto << "snapshot_prefix: '" << snapshot_prefix_ << "/' "; if (snapshot) { @@ -208,7 +205,7 @@ class GradientBasedSolverTest : public MultiDeviceTest { // updated_params will store the updated weight and bias results, // using the blobs' diffs to hold the update values themselves. void ComputeLeastSquaresUpdate(const Dtype learning_rate, - const Dtype weight_decay, const Dtype momentum, const Dtype rms_decay, + const Dtype weight_decay, const Dtype momentum, vector > >* updated_params) { const int N = num_; const int D = channels_ * height_ * width_; @@ -291,9 +288,11 @@ class GradientBasedSolverTest : public MultiDeviceTest { case SolverParameter_SolverType_ADAGRAD: update_value /= std::sqrt(history_value + grad * grad) + delta_; break; - case SolverParameter_SolverType_RMSPROP: + case SolverParameter_SolverType_RMSPROP: { + const Dtype rms_decay = 0.95; update_value /= std::sqrt(rms_decay*history_value + grad * grad * (1 - rms_decay)) + delta_; + } break; default: LOG(FATAL) << "Unknown solver type: " << solver_type(); @@ -360,14 +359,13 @@ class GradientBasedSolverTest : public MultiDeviceTest { } void CheckAccumulation(const Dtype kLearningRate, const Dtype kWeightDecay, - const Dtype kMomentum, const Dtype rms_decay, const int kNumIters, - const int kIterSize) { + const Dtype kMomentum, const int kNumIters, const int kIterSize) { const double kPrecision = 1e-2; const double kMinPrecision = 1e-7; constant_data_ = true; // Solve without accumulation and save parameters. this->RunLeastSquaresSolver(kLearningRate, kWeightDecay, kMomentum, - rms_decay, kNumIters); + kNumIters); // Save parameters for comparison. Net& net = *this->solver_->net(); const vector > >& param_blobs = @@ -379,7 +377,7 @@ class GradientBasedSolverTest : public MultiDeviceTest { } // Solve by equivalent accumulation of gradients over divided batches. this->RunLeastSquaresSolver(kLearningRate, kWeightDecay, kMomentum, - rms_decay, kNumIters, kIterSize); + kNumIters, kIterSize); Net& net_accum = *this->solver_->net(); const vector > >& accum_params = net_accum.layer_by_name("innerprod")->blobs(); @@ -417,18 +415,17 @@ class GradientBasedSolverTest : public MultiDeviceTest { // matches the solver's (K+1)th update. void TestLeastSquaresUpdate(const Dtype learning_rate = 1.0, const Dtype weight_decay = 0.0, const Dtype momentum = 0.0, - const Dtype rms_decay = 0.0, const int iter_to_check = 0) { + const int iter_to_check = 0) { // Initialize the solver and run K (= iter_to_check) solver iterations. - RunLeastSquaresSolver(learning_rate, weight_decay, momentum, rms_decay, - iter_to_check); + RunLeastSquaresSolver(learning_rate, weight_decay, momentum, iter_to_check); // Compute the (K+1)th update using the analytic least squares gradient. vector > > updated_params; ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum, - rms_decay, &updated_params); + &updated_params); // Reinitialize the solver and run K+1 solver iterations. - RunLeastSquaresSolver(learning_rate, weight_decay, momentum, rms_decay, + RunLeastSquaresSolver(learning_rate, weight_decay, momentum, iter_to_check + 1); // Check that the solver's solution matches ours. @@ -437,13 +434,13 @@ class GradientBasedSolverTest : public MultiDeviceTest { void TestSnapshot(const Dtype learning_rate = 1.0, const Dtype weight_decay = 0.0, const Dtype momentum = 0.0, - const Dtype rms_decay = 0.0, const int num_iters = 1) { + const int num_iters = 1) { // Run the solver for num_iters * 2 iterations. const int total_num_iters = num_iters * 2; bool snapshot = false; const int kIterSize = 1; - RunLeastSquaresSolver(learning_rate, weight_decay, momentum, rms_decay, - total_num_iters, kIterSize, snapshot); + RunLeastSquaresSolver(learning_rate, weight_decay, momentum, + total_num_iters, kIterSize, snapshot); // Save the resulting param values. vector > > param_copies; @@ -473,11 +470,11 @@ class GradientBasedSolverTest : public MultiDeviceTest { // Run the solver for num_iters iterations and snapshot. snapshot = true; string snapshot_name = RunLeastSquaresSolver(learning_rate, weight_decay, - momentum, rms_decay, num_iters, kIterSize, snapshot); + momentum, num_iters, kIterSize, snapshot); // Reinitialize the solver and run for num_iters more iterations. snapshot = false; - RunLeastSquaresSolver(learning_rate, weight_decay, momentum, rms_decay, + RunLeastSquaresSolver(learning_rate, weight_decay, momentum, total_num_iters, kIterSize, snapshot, snapshot_name.c_str()); // Check that params now match. @@ -558,11 +555,9 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithMomentum) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0; const Dtype kMomentum = 0.5; - const Dtype kRMSDecay = 0; const int kNumIters = 1; for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, - kRMSDecay, i); + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); } } @@ -571,11 +566,9 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0; const Dtype kMomentum = 0.5; - const Dtype kRMSDecay = 0; const int kNumIters = 4; for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, - kRMSDecay, i); + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); } } @@ -584,11 +577,9 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverything) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0.5; - const Dtype kRMSDecay = 0; const int kNumIters = 4; for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, - kRMSDecay, i); + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); } } @@ -597,12 +588,10 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingShare) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0.5; - const Dtype kRMSDecay = 0; const int kNumIters = 4; this->share_ = true; for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, - kRMSDecay, i); + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); } } @@ -611,11 +600,10 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingAccum) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0.9; - const Dtype kRMSDecay = 0; const int kNumIters = 4; const int kIterSize = 2; - this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, - kNumIters, kIterSize); + this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters, + kIterSize); } TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) { @@ -623,12 +611,11 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0.9; - const Dtype kRMSDecay = 0; const int kNumIters = 4; const int kIterSize = 2; this->share_ = true; - this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, - kNumIters, kIterSize); + this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters, + kIterSize); } TYPED_TEST(SGDSolverTest, TestSnapshot) { @@ -636,10 +623,9 @@ TYPED_TEST(SGDSolverTest, TestSnapshot) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0.9; - const Dtype kRMSDecay = 0; const int kNumIters = 4; for (int i = 1; i <= kNumIters; ++i) { - this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i); + this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i); } } @@ -648,11 +634,10 @@ TYPED_TEST(SGDSolverTest, TestSnapshotShare) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0.9; - const Dtype kRMSDecay = 0; const int kNumIters = 4; this->share_ = true; for (int i = 1; i <= kNumIters; ++i) { - this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i); + this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i); } } @@ -694,11 +679,9 @@ TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithEverything) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0; - const Dtype kRMSDecay = 0; const int kNumIters = 4; for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, - kRMSDecay, i); + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); } } @@ -708,12 +691,10 @@ TYPED_TEST(AdaGradSolverTest, const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0; - const Dtype kRMSDecay = 0; const int kNumIters = 4; this->share_ = true; for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, - kRMSDecay, i); + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); } } @@ -722,11 +703,10 @@ TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithEverythingAccum) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0; - const Dtype kRMSDecay = 0; const int kNumIters = 4; const int kIterSize = 2; - this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, - kNumIters, kIterSize); + this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters, + kIterSize); } TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) { @@ -734,12 +714,11 @@ TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0; - const Dtype kRMSDecay = 0; const int kNumIters = 4; const int kIterSize = 2; this->share_ = true; - this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, - kNumIters, kIterSize); + this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters, + kIterSize); } TYPED_TEST(AdaGradSolverTest, TestSnapshot) { @@ -747,10 +726,9 @@ TYPED_TEST(AdaGradSolverTest, TestSnapshot) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0; - const Dtype kRMSDecay = 0; const int kNumIters = 4; for (int i = 1; i <= kNumIters; ++i) { - this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i); + this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i); } } @@ -759,11 +737,10 @@ TYPED_TEST(AdaGradSolverTest, TestSnapshotShare) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0; - const Dtype kRMSDecay = 0; const int kNumIters = 4; this->share_ = true; for (int i = 1; i <= kNumIters; ++i) { - this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i); + this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i); } } @@ -817,11 +794,9 @@ TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithMomentum) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0; const Dtype kMomentum = 0.5; - const Dtype kRMSDecay = 0; const int kNumIters = 1; for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, - kRMSDecay, i); + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); } } @@ -830,11 +805,9 @@ TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0; const Dtype kMomentum = 0.5; - const Dtype kRMSDecay = 0; const int kNumIters = 4; for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, - kRMSDecay, i); + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); } } @@ -855,12 +828,10 @@ TYPED_TEST(NesterovSolverTest, const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0.9; - const Dtype kRMSDecay = 0; const int kNumIters = 4; this->share_ = true; for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, - kRMSDecay, i); + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); } } @@ -869,11 +840,10 @@ TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithEverythingAccum) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0.9; - const Dtype kRMSDecay = 0; const int kNumIters = 4; const int kIterSize = 2; - this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, - kNumIters, kIterSize); + this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters, + kIterSize); } TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) { @@ -881,12 +851,11 @@ TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0.9; - const Dtype kRMSDecay = 0; const int kNumIters = 4; const int kIterSize = 2; this->share_ = true; - this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, - kNumIters, kIterSize); + this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters, + kIterSize); } TYPED_TEST(NesterovSolverTest, TestSnapshot) { @@ -894,10 +863,9 @@ TYPED_TEST(NesterovSolverTest, TestSnapshot) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0.9; - const Dtype kRMSDecay = 0; const int kNumIters = 4; for (int i = 1; i <= kNumIters; ++i) { - this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i); + this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i); } } @@ -906,11 +874,10 @@ TYPED_TEST(NesterovSolverTest, TestSnapshotShare) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0.9; - const Dtype kRMSDecay = 0; const int kNumIters = 4; this->share_ = true; for (int i = 1; i <= kNumIters; ++i) { - this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i); + this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i); } } @@ -920,7 +887,10 @@ class RMSPropSolverTest : public GradientBasedSolverTest { protected: virtual void InitSolver(const SolverParameter& param) { - this->solver_.reset(new RMSPropSolver(param)); + const Dtype rms_decay = 0.95; + SolverParameter new_param = param; + new_param.set_rms_decay(rms_decay); + this->solver_.reset(new RMSPropSolver(new_param)); } virtual SolverParameter_SolverType solver_type() { return SolverParameter_SolverType_RMSPROP; @@ -941,11 +911,9 @@ TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithRmsDecay) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.0; const Dtype kMomentum = 0.0; - const Dtype kRMSDecay = 0.95; const int kNumIters = 4; for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, - kRMSDecay, i); + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); } } @@ -954,11 +922,9 @@ TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithEverything) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0.0; - const Dtype kRMSDecay = 0.95; const int kNumIters = 4; for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, - kRMSDecay, i); + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); } } @@ -968,12 +934,10 @@ TYPED_TEST(RMSPropSolverTest, const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0.0; - const Dtype kRMSDecay = 0.95; const int kNumIters = 4; this->share_ = true; for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, - kRMSDecay, i); + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); } } @@ -982,11 +946,10 @@ TYPED_TEST(RMSPropSolverTest, TestLeastSquaresUpdateWithEverythingAccum) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0.0; - const Dtype kRMSDecay = 0.95; const int kNumIters = 4; const int kIterSize = 2; - this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, - kNumIters, kIterSize); + this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters, + kIterSize); } TYPED_TEST(RMSPropSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) { @@ -994,12 +957,11 @@ TYPED_TEST(RMSPropSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0.0; - const Dtype kRMSDecay = 0.95; const int kNumIters = 4; const int kIterSize = 2; this->share_ = true; - this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, - kNumIters, kIterSize); + this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters, + kIterSize); } TYPED_TEST(RMSPropSolverTest, TestSnapshot) { @@ -1007,10 +969,9 @@ TYPED_TEST(RMSPropSolverTest, TestSnapshot) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0; - const Dtype kRMSDecay = 0.95; const int kNumIters = 4; for (int i = 1; i <= kNumIters; ++i) { - this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i); + this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i); } } @@ -1019,11 +980,10 @@ TYPED_TEST(RMSPropSolverTest, TestSnapshotShare) { const Dtype kLearningRate = 0.01; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0; - const Dtype kRMSDecay = 0.95; const int kNumIters = 4; this->share_ = true; for (int i = 1; i <= kNumIters; ++i) { - this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i); + this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i); } } -- 2.7.4