1 // Ceres Solver - A fast non-linear least squares minimizer
2 // Copyright 2015 Google Inc. All rights reserved.
3 // http://ceres-solver.org/
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are met:
8 // * Redistributions of source code must retain the above copyright notice,
9 // this list of conditions and the following disclaimer.
10 // * Redistributions in binary form must reproduce the above copyright notice,
11 // this list of conditions and the following disclaimer in the documentation
12 // and/or other materials provided with the distribution.
13 // * Neither the name of Google Inc. nor the names of its contributors may be
14 // used to endorse or promote products derived from this software without
15 // specific prior written permission.
17 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21 // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22 // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23 // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24 // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26 // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27 // POSSIBILITY OF SUCH DAMAGE.
29 // Author: sameeragarwal@google.com (Sameer Agarwal)
31 #include "ceres/loss_function.h"
35 #include "glog/logging.h"
36 #include "gtest/gtest.h"
42 // Helper function for testing a LossFunction callback.
44 // Compares the values of rho'(s) and rho''(s) computed by the
45 // callback with estimates obtained by symmetric finite differencing
// of rho(s). Each estimate must match the callback's value to within an
// absolute tolerance of 1e-6 (ASSERT_NEAR).
//
// Parameters:
//   loss - the LossFunction under test; only its Evaluate() is called.
//   s    - the squared norm at which rho(s) and its derivatives are checked.
//          The loss is also evaluated at s + kH and s - kH, so both must be
//          valid arguments for it.
//
// NOTE(review): the declarations of rho, fwd and bwd are not visible in this
// listing. From the indexing below they are presumably double[3] arrays laid
// out as {rho, rho', rho''} — confirm.
47 void AssertLossFunctionIsValid(const LossFunction& loss, double s) {
50 // Evaluate rho(s), rho'(s) and rho''(s).
52 loss.Evaluate(s, rho);
54 // Use symmetric finite differencing to estimate rho'(s) and
// rho''(s) from the values of rho at s + kH and s - kH.
56 const double kH = 1e-4;
61 loss.Evaluate(s + kH, fwd);
62 loss.Evaluate(s - kH, bwd);
// First derivative: central difference (rho(s + h) - rho(s - h)) / (2h).
65 const double fd_1 = (fwd[0] - bwd[0]) / (2 * kH);
66 ASSERT_NEAR(fd_1, rho[1], 1e-6);
// Second derivative: (rho(s + h) - 2 rho(s) + rho(s - h)) / h^2.
69 const double fd_2 = (fwd[0] - 2*rho[0] + bwd[0]) / (kH * kH);
70 ASSERT_NEAR(fd_2, rho[2], 1e-6);
74 // For each loss, try two values of the scale parameter, a = 0.7 and 1.3
75 // (for losses that take one), and two values of the squared norm,
76 // s = 0.357 and 1.792.
78 // Note that for the Huber loss this exercises both code paths
79 // (i.e. both the small-s and the large-s branch).
// TrivialLoss is constructed without a scale parameter, so only the two
// squared norms s = 0.357 and s = 1.792 are checked.
81 TEST(LossFunction, TrivialLoss) {
82 AssertLossFunctionIsValid(TrivialLoss(), 0.357);
83 AssertLossFunctionIsValid(TrivialLoss(), 1.792);
// Finite-difference checks for HuberLoss(a) with a = 0.7 and a = 1.3, each at
// a small (0.357) and a larger (1.792) squared norm, so that both the small-s
// and the large-s code paths are exercised.
// NOTE(review): if the switch point is a * a, then 0.49 and 1.69 both lie
// between the two s values — confirm against HuberLoss.
86 TEST(LossFunction, HuberLoss) {
87 AssertLossFunctionIsValid(HuberLoss(0.7), 0.357);
88 AssertLossFunctionIsValid(HuberLoss(0.7), 1.792);
89 AssertLossFunctionIsValid(HuberLoss(1.3), 0.357);
90 AssertLossFunctionIsValid(HuberLoss(1.3), 1.792);
// Finite-difference checks for SoftLOneLoss(a) with a = 0.7 and a = 1.3,
// each at squared norms s = 0.357 and s = 1.792.
93 TEST(LossFunction, SoftLOneLoss) {
94 AssertLossFunctionIsValid(SoftLOneLoss(0.7), 0.357);
95 AssertLossFunctionIsValid(SoftLOneLoss(0.7), 1.792);
96 AssertLossFunctionIsValid(SoftLOneLoss(1.3), 0.357);
97 AssertLossFunctionIsValid(SoftLOneLoss(1.3), 1.792);
// Finite-difference checks for CauchyLoss(a) with a = 0.7 and a = 1.3,
// each at squared norms s = 0.357 and s = 1.792.
100 TEST(LossFunction, CauchyLoss) {
101 AssertLossFunctionIsValid(CauchyLoss(0.7), 0.357);
102 AssertLossFunctionIsValid(CauchyLoss(0.7), 1.792);
103 AssertLossFunctionIsValid(CauchyLoss(1.3), 0.357);
104 AssertLossFunctionIsValid(CauchyLoss(1.3), 1.792);
// Finite-difference checks for ArctanLoss(a) with a = 0.7 and a = 1.3,
// each at squared norms s = 0.357 and s = 1.792.
107 TEST(LossFunction, ArctanLoss) {
108 AssertLossFunctionIsValid(ArctanLoss(0.7), 0.357);
109 AssertLossFunctionIsValid(ArctanLoss(0.7), 1.792);
110 AssertLossFunctionIsValid(ArctanLoss(1.3), 0.357);
111 AssertLossFunctionIsValid(ArctanLoss(1.3), 1.792);
// Finite-difference checks for TolerantLoss(a, b). The test also checks that
// rho(0) is zero and that the loss stays valid around its large-argument
// approximation threshold.
114 TEST(LossFunction, TolerantLoss) {
// Two (a, b) pairs, each at s = 0.357, 1.792 and the larger value 55.5.
115 AssertLossFunctionIsValid(TolerantLoss(0.7, 0.4), 0.357);
116 AssertLossFunctionIsValid(TolerantLoss(0.7, 0.4), 1.792);
117 AssertLossFunctionIsValid(TolerantLoss(0.7, 0.4), 55.5);
118 AssertLossFunctionIsValid(TolerantLoss(1.3, 0.1), 0.357);
119 AssertLossFunctionIsValid(TolerantLoss(1.3, 0.1), 1.792);
120 AssertLossFunctionIsValid(TolerantLoss(1.3, 0.1), 55.5);
121 // Check the value at zero is actually zero.
// NOTE(review): the declaration of rho is not visible in this listing;
// presumably a local double[3] — confirm.
123 TolerantLoss(0.7, 0.4).Evaluate(0.0, rho);
124 ASSERT_NEAR(rho[0], 0.0, 1e-6);
125 // Check that the loss is valid just before, at, just after, and far past the approximation threshold.
126 // A threshold of 36.7 is used by the implementation.
// With a = 20.0 and b = 1.0, the arguments below put (s - a) / b at 36.6,
// 36.7, 36.8 and 1000.0. This suggests the threshold applies to (s - a) / b;
// confirm against TolerantLoss.
127 AssertLossFunctionIsValid(TolerantLoss(20.0, 1.0), 20.0 + 36.6);
128 AssertLossFunctionIsValid(TolerantLoss(20.0, 1.0), 20.0 + 36.7);
129 AssertLossFunctionIsValid(TolerantLoss(20.0, 1.0), 20.0 + 36.8);
130 AssertLossFunctionIsValid(TolerantLoss(20.0, 1.0), 20.0 + 1000.0);
// Finite-difference checks for TukeyLoss(a) with a = 0.7 and a = 1.3,
// each at squared norms s = 0.357 and s = 1.792.
133 TEST(LossFunction, TukeyLoss) {
134 AssertLossFunctionIsValid(TukeyLoss(0.7), 0.357);
135 AssertLossFunctionIsValid(TukeyLoss(0.7), 1.792);
136 AssertLossFunctionIsValid(TukeyLoss(1.3), 0.357);
137 AssertLossFunctionIsValid(TukeyLoss(1.3), 1.792);
// Finite-difference checks for two ComposedLoss instances, each at squared
// norms s = 0.357 and s = 1.792. Both are built from the addresses of f and g
// with DO_NOT_TAKE_OWNERSHIP; presumably c therefore does not delete them, so
// f and g must outlive c.
// NOTE(review): the declarations of f and g, and the brace scopes that let
// `c` be declared twice, are not visible in this listing — confirm.
140 TEST(LossFunction, ComposedLoss) {
144 ComposedLoss c(&f, DO_NOT_TAKE_OWNERSHIP, &g, DO_NOT_TAKE_OWNERSHIP);
145 AssertLossFunctionIsValid(c, 0.357);
146 AssertLossFunctionIsValid(c, 1.792);
151 ComposedLoss c(&f, DO_NOT_TAKE_OWNERSHIP, &g, DO_NOT_TAKE_OWNERSHIP);
152 AssertLossFunctionIsValid(c, 0.357);
153 AssertLossFunctionIsValid(c, 1.792);
157 TEST(LossFunction, ScaledLoss) {
158 // Wrap a few loss functions, and a few scale factors. This can't combine
159 // construction with the call to AssertLossFunctionIsValid() because Apple's
160 // GCC is unable to eliminate the copy of ScaledLoss, which is not copyable.
// Each `scaled_loss` below is a separate local, so each presumably lives in
// its own brace scope (the braces are not visible in this listing).
// A NULL wrapped loss, scaled by 6.
162 ScaledLoss scaled_loss(NULL, 6, TAKE_OWNERSHIP);
163 AssertLossFunctionIsValid(scaled_loss, 0.323);
166 ScaledLoss scaled_loss(new TrivialLoss(), 10, TAKE_OWNERSHIP);
167 AssertLossFunctionIsValid(scaled_loss, 0.357);
170 ScaledLoss scaled_loss(new HuberLoss(0.7), 0.1, TAKE_OWNERSHIP);
171 AssertLossFunctionIsValid(scaled_loss, 1.792);
174 ScaledLoss scaled_loss(new SoftLOneLoss(1.3), 0.1, TAKE_OWNERSHIP);
175 AssertLossFunctionIsValid(scaled_loss, 1.792);
178 ScaledLoss scaled_loss(new CauchyLoss(1.3), 10, TAKE_OWNERSHIP);
179 AssertLossFunctionIsValid(scaled_loss, 1.792);
182 ScaledLoss scaled_loss(new ArctanLoss(1.3), 10, TAKE_OWNERSHIP);
183 AssertLossFunctionIsValid(scaled_loss, 1.792);
186 ScaledLoss scaled_loss(
187 new TolerantLoss(1.3, 0.1), 10, TAKE_OWNERSHIP);
188 AssertLossFunctionIsValid(scaled_loss, 1.792);
// NOTE(review): the parentheses in the following call do not balance as shown.
// The (loss, ownership, loss, ownership) arguments match the ComposedLoss
// constructor, so a `new ComposedLoss(` line is presumably missing from this
// listing — confirm.
191 ScaledLoss scaled_loss(
193 new HuberLoss(0.8), TAKE_OWNERSHIP,
194 new TolerantLoss(1.3, 0.5), TAKE_OWNERSHIP), 10, TAKE_OWNERSHIP);
195 AssertLossFunctionIsValid(scaled_loss, 1.792);
// Checks that LossFunctionWrapper::Evaluate() reproduces the loss it wraps.
// The wrapper is checked at construction (HuberLoss(1.0)) and after each
// Reset(): an owned HuberLoss(0.5), a borrowed HuberLoss(0.3), and NULL both
// with and without ownership. In the NULL cases the expected values come from
// TrivialLoss. All three components must agree to within 1e-12.
// NOTE(review): the rest of the constructor call and the declarations of s,
// rho and rho_gold are not visible in this listing. Presumably they are an
// ownership flag, a double, and two double[3] arrays — confirm.
199 TEST(LossFunction, LossFunctionWrapper) {
201 HuberLoss loss_function1(1.0);
202 LossFunctionWrapper loss_function_wrapper(new HuberLoss(1.0),
208 loss_function1.Evaluate(s, rho_gold);
209 loss_function_wrapper.Evaluate(s, rho);
210 for (int i = 0; i < 3; ++i) {
211 EXPECT_NEAR(rho[i], rho_gold[i], 1e-12);
// Reset to an owned HuberLoss with a different scale.
215 HuberLoss loss_function2(0.5);
216 loss_function_wrapper.Reset(new HuberLoss(0.5), TAKE_OWNERSHIP);
217 loss_function_wrapper.Evaluate(s, rho);
218 loss_function2.Evaluate(s, rho_gold);
219 for (int i = 0; i < 3; ++i) {
220 EXPECT_NEAR(rho[i], rho_gold[i], 1e-12);
223 // Not taking ownership.
// loss_function3 is a local, so the wrapper only borrows it here.
224 HuberLoss loss_function3(0.3);
225 loss_function_wrapper.Reset(&loss_function3, DO_NOT_TAKE_OWNERSHIP);
226 loss_function_wrapper.Evaluate(s, rho);
227 loss_function3.Evaluate(s, rho_gold);
228 for (int i = 0; i < 3; ++i) {
229 EXPECT_NEAR(rho[i], rho_gold[i], 1e-12);
// Set to NULL, taking ownership; the expected values come from TrivialLoss.
233 TrivialLoss loss_function4;
234 loss_function_wrapper.Reset(NULL, TAKE_OWNERSHIP);
235 loss_function_wrapper.Evaluate(s, rho);
236 loss_function4.Evaluate(s, rho_gold);
237 for (int i = 0; i < 3; ++i) {
238 EXPECT_NEAR(rho[i], rho_gold[i], 1e-12);
241 // Set to NULL, not taking ownership
242 loss_function_wrapper.Reset(NULL, DO_NOT_TAKE_OWNERSHIP);
243 loss_function_wrapper.Evaluate(s, rho);
244 loss_function4.Evaluate(s, rho_gold);
245 for (int i = 0; i < 3; ++i) {
246 EXPECT_NEAR(rho[i], rho_gold[i], 1e-12);
251 } // namespace internal