Imported Upstream version ceres 1.13.0
[platform/upstream/ceres-solver.git] / internal / ceres / levenberg_marquardt_strategy_test.cc
1 // Ceres Solver - A fast non-linear least squares minimizer
2 // Copyright 2015 Google Inc. All rights reserved.
3 // http://ceres-solver.org/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are met:
7 //
8 // * Redistributions of source code must retain the above copyright notice,
9 //   this list of conditions and the following disclaimer.
10 // * Redistributions in binary form must reproduce the above copyright notice,
11 //   this list of conditions and the following disclaimer in the documentation
12 //   and/or other materials provided with the distribution.
13 // * Neither the name of Google Inc. nor the names of its contributors may be
14 //   used to endorse or promote products derived from this software without
15 //   specific prior written permission.
16 //
17 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21 // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22 // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23 // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24 // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26 // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27 // POSSIBILITY OF SUCH DAMAGE.
28 //
29 // Author: sameeragarwal@google.com (Sameer Agarwal)
30
31 #include "ceres/internal/eigen.h"
32 #include "ceres/internal/scoped_ptr.h"
33 #include "ceres/levenberg_marquardt_strategy.h"
34 #include "ceres/linear_solver.h"
35 #include "ceres/trust_region_strategy.h"
36 #include "glog/logging.h"
37 #include "gmock/gmock.h"
38 #include "gmock/mock-log.h"
39 #include "gtest/gtest.h"
40
41 using testing::AllOf;
42 using testing::AnyNumber;
43 using testing::HasSubstr;
44 using testing::ScopedMockLog;
45 using testing::_;
46
47 namespace ceres {
48 namespace internal {
49
50 const double kTolerance = 1e-16;
51
52 // Linear solver that takes as input a vector and checks that the
53 // caller passes the same vector as LinearSolver::PerSolveOptions.D.
54 class RegularizationCheckingLinearSolver : public DenseSparseMatrixSolver {
55  public:
56   RegularizationCheckingLinearSolver(const int num_cols, const double* diagonal)
57       : num_cols_(num_cols),
58         diagonal_(diagonal) {
59   }
60
61   virtual ~RegularizationCheckingLinearSolver() {}
62
63  private:
64   virtual LinearSolver::Summary SolveImpl(
65       DenseSparseMatrix* A,
66       const double* b,
67       const LinearSolver::PerSolveOptions& per_solve_options,
68       double* x) {
69     CHECK_NOTNULL(per_solve_options.D);
70     for (int i = 0; i < num_cols_; ++i) {
71       EXPECT_NEAR(per_solve_options.D[i], diagonal_[i], kTolerance)
72           << i << " " << per_solve_options.D[i] << " " << diagonal_[i];
73     }
74     return LinearSolver::Summary();
75   }
76
77   const int num_cols_;
78   const double* diagonal_;
79 };
80
81 TEST(LevenbergMarquardtStrategy, AcceptRejectStepRadiusScaling) {
82   TrustRegionStrategy::Options options;
83   options.initial_radius = 2.0;
84   options.max_radius = 20.0;
85   options.min_lm_diagonal = 1e-8;
86   options.max_lm_diagonal = 1e8;
87
88   // We need a non-null pointer here, so anything should do.
89   scoped_ptr<LinearSolver> linear_solver(
90       new RegularizationCheckingLinearSolver(0, NULL));
91   options.linear_solver = linear_solver.get();
92
93   LevenbergMarquardtStrategy lms(options);
94   EXPECT_EQ(lms.Radius(), options.initial_radius);
95   lms.StepRejected(0.0);
96   EXPECT_EQ(lms.Radius(), 1.0);
97   lms.StepRejected(-1.0);
98   EXPECT_EQ(lms.Radius(), 0.25);
99   lms.StepAccepted(1.0);
100   EXPECT_EQ(lms.Radius(), 0.25 * 3.0);
101   lms.StepAccepted(1.0);
102   EXPECT_EQ(lms.Radius(), 0.25 * 3.0 * 3.0);
103   lms.StepAccepted(0.25);
104   EXPECT_EQ(lms.Radius(), 0.25 * 3.0 * 3.0 / 1.125);
105   lms.StepAccepted(1.0);
106   EXPECT_EQ(lms.Radius(), 0.25 * 3.0 * 3.0 / 1.125 * 3.0);
107   lms.StepAccepted(1.0);
108   EXPECT_EQ(lms.Radius(), 0.25 * 3.0 * 3.0 / 1.125 * 3.0 * 3.0);
109   lms.StepAccepted(1.0);
110   EXPECT_EQ(lms.Radius(), options.max_radius);
111 }
112
113 TEST(LevenbergMarquardtStrategy, CorrectDiagonalToLinearSolver) {
114   Matrix jacobian(2, 3);
115   jacobian.setZero();
116   jacobian(0, 0) = 0.0;
117   jacobian(0, 1) = 1.0;
118   jacobian(1, 1) = 1.0;
119   jacobian(0, 2) = 100.0;
120
121   double residual = 1.0;
122   double x[3];
123   DenseSparseMatrix dsm(jacobian);
124
125   TrustRegionStrategy::Options options;
126   options.initial_radius = 2.0;
127   options.max_radius = 20.0;
128   options.min_lm_diagonal = 1e-2;
129   options.max_lm_diagonal = 1e2;
130
131   double diagonal[3];
132   diagonal[0] = options.min_lm_diagonal;
133   diagonal[1] = 2.0;
134   diagonal[2] = options.max_lm_diagonal;
135   for (int i = 0; i < 3; ++i) {
136     diagonal[i] = sqrt(diagonal[i] / options.initial_radius);
137   }
138
139   RegularizationCheckingLinearSolver linear_solver(3, diagonal);
140   options.linear_solver = &linear_solver;
141
142   LevenbergMarquardtStrategy lms(options);
143   TrustRegionStrategy::PerSolveOptions pso;
144
145   {
146     ScopedMockLog log;
147     EXPECT_CALL(log, Log(_, _, _)).Times(AnyNumber());
148     // This using directive is needed get around the fact that there
149     // are versions of glog which are not in the google namespace.
150     using namespace google;
151
152 #if defined(_MSC_VER)
153     // Use GLOG_WARNING to support MSVC if GLOG_NO_ABBREVIATED_SEVERITIES
154     // is defined.
155     EXPECT_CALL(log, Log(GLOG_WARNING, _,
156                          HasSubstr("Failed to compute a step")));
157 #else
158     EXPECT_CALL(log, Log(google::WARNING, _,
159                          HasSubstr("Failed to compute a step")));
160 #endif
161
162     TrustRegionStrategy::Summary summary =
163         lms.ComputeStep(pso, &dsm, &residual, x);
164     EXPECT_EQ(summary.termination_type, LINEAR_SOLVER_FAILURE);
165   }
166 }
167
168 }  // namespace internal
169 }  // namespace ceres