1 // Ceres Solver - A fast non-linear least squares minimizer
2 // Copyright 2015 Google Inc. All rights reserved.
3 // http://ceres-solver.org/
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are met:
8 // * Redistributions of source code must retain the above copyright notice,
9 // this list of conditions and the following disclaimer.
10 // * Redistributions in binary form must reproduce the above copyright notice,
11 // this list of conditions and the following disclaimer in the documentation
12 // and/or other materials provided with the distribution.
13 // * Neither the name of Google Inc. nor the names of its contributors may be
14 // used to endorse or promote products derived from this software without
15 // specific prior written permission.
17 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21 // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22 // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23 // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24 // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26 // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27 // POSSIBILITY OF SUCH DAMAGE.
29 // Author: sameeragarwal@google.com (Sameer Agarwal)
31 // CostFunctionToFunctor is an adapter class that allows users to use
32 // SizedCostFunction objects in templated functors which are to be used for
33 // automatic differentiation. This allows the user to seamlessly mix
34 // analytic, numeric and automatic differentiation.
36 // For example, let us assume that
38 // class IntrinsicProjection : public SizedCostFunction<2, 5, 3> {
40 // IntrinsicProjection(const double* observation);
41 // virtual bool Evaluate(double const* const* parameters,
43 // double** jacobians) const;
46 // is a cost function that implements the projection of a point in its
47 // local coordinate system onto its image plane and subtracts it from
48 // the observed point projection. It can compute its residual and,
49 // via either analytic or numerical differentiation, its Jacobians.
52 // Now we would like to compose the action of this CostFunction with
53 // the action of camera extrinsics, i.e., rotation and
54 // translation. Say we have a templated function
56 // template<typename T>
57 // void RotateAndTranslatePoint(const T* rotation,
58 // const T* translation,
62 // Then we can now do the following,
64 // struct CameraProjection {
65 // CameraProjection(const double* observation)
66 // : intrinsic_projection_(new IntrinsicProjection(observation)) {
68 // template <typename T>
69 // bool operator()(const T* rotation,
70 // const T* translation,
71 // const T* intrinsics,
72 // const T* point,
73 // T* residual) const {
74 // T transformed_point[3];
75 // RotateAndTranslatePoint(rotation, translation, point, transformed_point);
77 // // Note that we call intrinsic_projection_, just like it was
78 // // any other templated functor.
80 // return intrinsic_projection_(intrinsics, transformed_point, residual);
84 // CostFunctionToFunctor<2,5,3> intrinsic_projection_;
87 #ifndef CERES_PUBLIC_COST_FUNCTION_TO_FUNCTOR_H_
88 #define CERES_PUBLIC_COST_FUNCTION_TO_FUNCTOR_H_
93 #include "ceres/cost_function.h"
94 #include "ceres/dynamic_cost_function_to_functor.h"
95 #include "ceres/internal/fixed_array.h"
96 #include "ceres/internal/port.h"
97 #include "ceres/internal/scoped_ptr.h"
// Adapter that exposes a CostFunction through operator() overloads (one
// family taking double pointers, one templated on the scalar type), each of
// which forwards to the cost_functor_ member.
//
// Template parameters:
//   kNumResiduals - number of residuals; the constructor requires it to be
//                   positive or equal to DYNAMIC.
//   N0 .. N9      - sizes of the parameter blocks. N1..N9 default to 0, which
//                   the constructor treats as "block not used"; the used
//                   sizes are checked against
//                   cost_function->parameter_block_sizes().
101 template <int kNumResiduals,
102 int N0, int N1 = 0, int N2 = 0, int N3 = 0, int N4 = 0,
103 int N5 = 0, int N6 = 0, int N7 = 0, int N8 = 0, int N9 = 0>
104 class CostFunctionToFunctor {
// Wraps cost_function in cost_functor_ and then validates, at run time, that
// the template arguments agree with what the cost function itself reports.
// Every validation below is a CHECK-style macro (presumably fatal on failure
// -- confirm against the logging library in use).
106 // Takes ownership of cost_function.
107 explicit CostFunctionToFunctor(CostFunction* cost_function)
108 : cost_functor_(cost_function) {
// The pointer must be non-null, and the residual count must be either a
// positive compile-time value or the DYNAMIC sentinel.
109 CHECK_NOTNULL(cost_function);
110 CHECK(kNumResiduals > 0 || kNumResiduals == DYNAMIC);
// The used block sizes must form a prefix of N1..N9: as soon as one size is
// zero, every later size must be zero as well. Each alternative below spells
// out one admissible prefix length (0 through 9 non-zero sizes after N0).
112 // This block breaks the 80 column rule to keep it somewhat readable.
113 CHECK((!N1 && !N2 && !N3 && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) ||
114 ((N1 > 0) && !N2 && !N3 && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) ||
115 ((N1 > 0) && (N2 > 0) && !N3 && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) || // NOLINT
116 ((N1 > 0) && (N2 > 0) && (N3 > 0) && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) || // NOLINT
117 ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && !N5 && !N6 && !N7 && !N8 && !N9) || // NOLINT
118 ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && !N6 && !N7 && !N8 && !N9) || // NOLINT
119 ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && !N7 && !N8 && !N9) || // NOLINT
120 ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && (N7 > 0) && !N8 && !N9) || // NOLINT
121 ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && (N7 > 0) && (N8 > 0) && !N9) || // NOLINT
122 ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && (N7 > 0) && (N8 > 0) && (N9 > 0))) // NOLINT
123 << "Zero block cannot precede a non-zero block. Block sizes are "
124 << "(ignore trailing 0s): " << N0 << ", " << N1 << ", " << N2 << ", "
125 << N3 << ", " << N4 << ", " << N5 << ", " << N6 << ", " << N7 << ", "
// NOTE(review): the streamed message above ends on a dangling `<< ", "` with
// no terminating `;` -- the tail of the statement (presumably N8 and N9)
// appears to be missing from this listing; confirm against the upstream header.

// The number of positive sizes among N0..N9 must equal the number of
// parameter blocks the cost function reports.
128 const std::vector<int32>& parameter_block_sizes =
129 cost_function->parameter_block_sizes();
130 const int num_parameter_blocks =
131 (N0 > 0) + (N1 > 0) + (N2 > 0) + (N3 > 0) + (N4 > 0) +
132 (N5 > 0) + (N6 > 0) + (N7 > 0) + (N8 > 0) + (N9 > 0);
133 CHECK_EQ(static_cast<int>(parameter_block_sizes.size()),
134 num_parameter_blocks);
// Each declared size must match the reported size at the same index. Index 0
// is read unconditionally; the remaining indices are guarded by the vector's
// size so that unused trailing blocks are never indexed.
// NOTE(review): if N0 were 0 and the cost function reported zero blocks, the
// count check above would pass and parameter_block_sizes[0] would index an
// empty vector -- confirm N0 > 0 is enforced elsewhere.
136 CHECK_EQ(N0, parameter_block_sizes[0]);
137 if (parameter_block_sizes.size() > 1) CHECK_EQ(N1, parameter_block_sizes[1]); // NOLINT
138 if (parameter_block_sizes.size() > 2) CHECK_EQ(N2, parameter_block_sizes[2]); // NOLINT
139 if (parameter_block_sizes.size() > 3) CHECK_EQ(N3, parameter_block_sizes[3]); // NOLINT
140 if (parameter_block_sizes.size() > 4) CHECK_EQ(N4, parameter_block_sizes[4]); // NOLINT
141 if (parameter_block_sizes.size() > 5) CHECK_EQ(N5, parameter_block_sizes[5]); // NOLINT
142 if (parameter_block_sizes.size() > 6) CHECK_EQ(N6, parameter_block_sizes[6]); // NOLINT
143 if (parameter_block_sizes.size() > 7) CHECK_EQ(N7, parameter_block_sizes[7]); // NOLINT
144 if (parameter_block_sizes.size() > 8) CHECK_EQ(N8, parameter_block_sizes[8]); // NOLINT
145 if (parameter_block_sizes.size() > 9) CHECK_EQ(N9, parameter_block_sizes[9]); // NOLINT
// Final cross-check: the sum of the reported block sizes equals the sum of
// the declared sizes.
// NOTE(review): accumulate is called unqualified; presumably it resolves to
// std::accumulate through argument-dependent lookup on the vector iterators --
// confirm <numeric> is included and consider qualifying the call.
// NOTE(review): no closing brace for this constructor is visible in this
// listing; confirm against the upstream header.
147 CHECK_EQ(accumulate(parameter_block_sizes.begin(),
148 parameter_block_sizes.end(), 0),
149 N0 + N1 + N2 + N3 + N4 + N5 + N6 + N7 + N8 + N9);
// One-parameter-block overload for plain doubles. Passes the address of x0
// (presumably read by cost_functor_ as a one-element array of block pointers,
// mirroring the multi-block overloads) together with residuals to
// cost_functor_ and returns its result.
// NOTE(review): the closing brace of this body is not visible in this
// listing; other lines may be missing as well -- confirm against upstream.
152 bool operator()(const double* x0, double* residuals) const {
164 return cost_functor_(&x0, residuals);
// Two-parameter-block overload for plain doubles. Packs the block pointers,
// in order, into a two-element FixedArray and forwards that array and
// residuals to cost_functor_, returning its result.
// NOTE(review): x1 is used below but the visible signature declares only x0
// and residuals, and no closing brace is visible -- lines appear to be
// missing from this listing; confirm against upstream.
167 bool operator()(const double* x0,
169 double* residuals) const {
180 internal::FixedArray<const double*> parameter_blocks(2);
181 parameter_blocks[0] = x0;
182 parameter_blocks[1] = x1;
// get() presumably yields the raw pointer to the array's storage -- confirm.
183 return cost_functor_(parameter_blocks.get(), residuals);
// Three-parameter-block overload for plain doubles. Packs the block
// pointers, in order, into a three-element FixedArray and forwards that array
// and residuals to cost_functor_, returning its result.
// NOTE(review): x1 and x2 are used below but are not declared in the visible
// signature, and no closing brace is visible -- lines appear to be missing
// from this listing; confirm against upstream.
186 bool operator()(const double* x0,
189 double* residuals) const {
200 internal::FixedArray<const double*> parameter_blocks(3);
201 parameter_blocks[0] = x0;
202 parameter_blocks[1] = x1;
203 parameter_blocks[2] = x2;
204 return cost_functor_(parameter_blocks.get(), residuals);
// Four-parameter-block overload for plain doubles. Packs the block pointers,
// in order, into a four-element FixedArray and forwards that array and
// residuals to cost_functor_, returning its result.
// NOTE(review): x1..x3 are used below but are not declared in the visible
// signature, and no closing brace is visible -- lines appear to be missing
// from this listing; confirm against upstream.
207 bool operator()(const double* x0,
211 double* residuals) const {
222 internal::FixedArray<const double*> parameter_blocks(4);
223 parameter_blocks[0] = x0;
224 parameter_blocks[1] = x1;
225 parameter_blocks[2] = x2;
226 parameter_blocks[3] = x3;
227 return cost_functor_(parameter_blocks.get(), residuals);
// Five-parameter-block overload for plain doubles. Packs the block pointers,
// in order, into a five-element FixedArray and forwards that array and
// residuals to cost_functor_, returning its result.
// NOTE(review): x1..x4 are used below but are not declared in the visible
// signature, and no closing brace is visible -- lines appear to be missing
// from this listing; confirm against upstream.
230 bool operator()(const double* x0,
235 double* residuals) const {
246 internal::FixedArray<const double*> parameter_blocks(5);
247 parameter_blocks[0] = x0;
248 parameter_blocks[1] = x1;
249 parameter_blocks[2] = x2;
250 parameter_blocks[3] = x3;
251 parameter_blocks[4] = x4;
252 return cost_functor_(parameter_blocks.get(), residuals);
// Six-parameter-block overload for plain doubles. Packs the block pointers,
// in order, into a six-element FixedArray and forwards that array and
// residuals to cost_functor_, returning its result.
// NOTE(review): x1..x5 are used below but are not declared in the visible
// signature, and no closing brace is visible -- lines appear to be missing
// from this listing; confirm against upstream.
255 bool operator()(const double* x0,
261 double* residuals) const {
272 internal::FixedArray<const double*> parameter_blocks(6);
273 parameter_blocks[0] = x0;
274 parameter_blocks[1] = x1;
275 parameter_blocks[2] = x2;
276 parameter_blocks[3] = x3;
277 parameter_blocks[4] = x4;
278 parameter_blocks[5] = x5;
279 return cost_functor_(parameter_blocks.get(), residuals);
// Seven-parameter-block overload for plain doubles. Packs the block
// pointers, in order, into a seven-element FixedArray and forwards that array
// and residuals to cost_functor_, returning its result.
// NOTE(review): x1..x6 are used below but are not declared in the visible
// signature, and no closing brace is visible -- lines appear to be missing
// from this listing; confirm against upstream.
282 bool operator()(const double* x0,
289 double* residuals) const {
300 internal::FixedArray<const double*> parameter_blocks(7);
301 parameter_blocks[0] = x0;
302 parameter_blocks[1] = x1;
303 parameter_blocks[2] = x2;
304 parameter_blocks[3] = x3;
305 parameter_blocks[4] = x4;
306 parameter_blocks[5] = x5;
307 parameter_blocks[6] = x6;
308 return cost_functor_(parameter_blocks.get(), residuals);
// Eight-parameter-block overload for plain doubles. Packs the block
// pointers, in order, into an eight-element FixedArray and forwards that
// array and residuals to cost_functor_, returning its result.
// NOTE(review): x1..x7 are used below but are not declared in the visible
// signature, and no closing brace is visible -- lines appear to be missing
// from this listing; confirm against upstream.
311 bool operator()(const double* x0,
319 double* residuals) const {
330 internal::FixedArray<const double*> parameter_blocks(8);
331 parameter_blocks[0] = x0;
332 parameter_blocks[1] = x1;
333 parameter_blocks[2] = x2;
334 parameter_blocks[3] = x3;
335 parameter_blocks[4] = x4;
336 parameter_blocks[5] = x5;
337 parameter_blocks[6] = x6;
338 parameter_blocks[7] = x7;
339 return cost_functor_(parameter_blocks.get(), residuals);
// Nine-parameter-block overload for plain doubles. Packs the block pointers,
// in order, into a nine-element FixedArray and forwards that array and
// residuals to cost_functor_, returning its result.
// NOTE(review): x1..x8 are used below but are not declared in the visible
// signature, and no closing brace is visible -- lines appear to be missing
// from this listing; confirm against upstream.
342 bool operator()(const double* x0,
351 double* residuals) const {
362 internal::FixedArray<const double*> parameter_blocks(9);
363 parameter_blocks[0] = x0;
364 parameter_blocks[1] = x1;
365 parameter_blocks[2] = x2;
366 parameter_blocks[3] = x3;
367 parameter_blocks[4] = x4;
368 parameter_blocks[5] = x5;
369 parameter_blocks[6] = x6;
370 parameter_blocks[7] = x7;
371 parameter_blocks[8] = x8;
372 return cost_functor_(parameter_blocks.get(), residuals);
// Ten-parameter-block overload for plain doubles (the maximum number of
// blocks the class template declares). Packs the block pointers, in order,
// into a ten-element FixedArray and forwards that array and residuals to
// cost_functor_, returning its result.
// NOTE(review): x1..x9 are used below but are not declared in the visible
// signature, and no closing brace is visible -- lines appear to be missing
// from this listing; confirm against upstream.
375 bool operator()(const double* x0,
385 double* residuals) const {
396 internal::FixedArray<const double*> parameter_blocks(10);
397 parameter_blocks[0] = x0;
398 parameter_blocks[1] = x1;
399 parameter_blocks[2] = x2;
400 parameter_blocks[3] = x3;
401 parameter_blocks[4] = x4;
402 parameter_blocks[5] = x5;
403 parameter_blocks[6] = x6;
404 parameter_blocks[7] = x7;
405 parameter_blocks[8] = x8;
406 parameter_blocks[9] = x9;
407 return cost_functor_(parameter_blocks.get(), residuals);
// One-parameter-block overload templated on the scalar type JetT (the file
// header describes this class as being for use in automatic differentiation).
// Passes the address of x0 and residuals to cost_functor_ and returns its
// result.
// NOTE(review): the closing brace of this body is not visible in this
// listing; other lines may be missing as well -- confirm against upstream.
410 template <typename JetT>
411 bool operator()(const JetT* x0, JetT* residuals) const {
422 return cost_functor_(&x0, residuals);
// Two-parameter-block overload templated on the scalar type JetT. Creates a
// two-element FixedArray of `const JetT*` and forwards it with residuals to
// cost_functor_, returning its result.
// NOTE(review): no assignment into jets and no parameter other than x0 and
// residuals is visible here, and no closing brace -- lines appear to be
// missing from this listing; confirm against upstream.
425 template <typename JetT>
426 bool operator()(const JetT* x0,
428 JetT* residuals) const {
439 internal::FixedArray<const JetT*> jets(2);
// get() presumably yields the raw pointer to the array's storage -- confirm.
442 return cost_functor_(jets.get(), residuals);
// Three-parameter-block overload templated on the scalar type JetT. Creates
// a three-element FixedArray of `const JetT*` and forwards it with residuals
// to cost_functor_, returning its result.
// NOTE(review): no assignment into jets and no parameter other than x0 and
// residuals is visible here, and no closing brace -- lines appear to be
// missing from this listing; confirm against upstream.
445 template <typename JetT>
446 bool operator()(const JetT* x0,
449 JetT* residuals) const {
460 internal::FixedArray<const JetT*> jets(3);
464 return cost_functor_(jets.get(), residuals);
// Four-parameter-block overload templated on the scalar type JetT. Creates a
// four-element FixedArray of `const JetT*` and forwards it with residuals to
// cost_functor_, returning its result.
// NOTE(review): no assignment into jets and no parameter other than x0 and
// residuals is visible here, and no closing brace -- lines appear to be
// missing from this listing; confirm against upstream.
467 template <typename JetT>
468 bool operator()(const JetT* x0,
472 JetT* residuals) const {
483 internal::FixedArray<const JetT*> jets(4);
488 return cost_functor_(jets.get(), residuals);
// Five-parameter-block overload templated on the scalar type JetT. Creates a
// five-element FixedArray of `const JetT*` and forwards it with residuals to
// cost_functor_, returning its result.
// NOTE(review): no assignment into jets and no parameter other than x0 and
// residuals is visible here, and no closing brace -- lines appear to be
// missing from this listing; confirm against upstream.
491 template <typename JetT>
492 bool operator()(const JetT* x0,
497 JetT* residuals) const {
508 internal::FixedArray<const JetT*> jets(5);
514 return cost_functor_(jets.get(), residuals);
// Six-parameter-block overload templated on the scalar type JetT. Creates a
// six-element FixedArray of `const JetT*` and forwards it with residuals to
// cost_functor_, returning its result.
// NOTE(review): no assignment into jets and no parameter other than x0 and
// residuals is visible here, and no closing brace -- lines appear to be
// missing from this listing; confirm against upstream.
517 template <typename JetT>
518 bool operator()(const JetT* x0,
524 JetT* residuals) const {
535 internal::FixedArray<const JetT*> jets(6);
542 return cost_functor_(jets.get(), residuals);
// Seven-parameter-block overload templated on the scalar type JetT. Creates
// a seven-element FixedArray of `const JetT*` and forwards it with residuals
// to cost_functor_, returning its result.
// NOTE(review): no assignment into jets and no parameter other than x0 and
// residuals is visible here, and no closing brace -- lines appear to be
// missing from this listing; confirm against upstream.
545 template <typename JetT>
546 bool operator()(const JetT* x0,
553 JetT* residuals) const {
564 internal::FixedArray<const JetT*> jets(7);
572 return cost_functor_(jets.get(), residuals);
// Eight-parameter-block overload templated on the scalar type JetT. Creates
// an eight-element FixedArray of `const JetT*` and forwards it with residuals
// to cost_functor_, returning its result.
// NOTE(review): no assignment into jets and no parameter other than x0 and
// residuals is visible here, and no closing brace -- lines appear to be
// missing from this listing; confirm against upstream.
575 template <typename JetT>
576 bool operator()(const JetT* x0,
584 JetT* residuals) const {
595 internal::FixedArray<const JetT*> jets(8);
604 return cost_functor_(jets.get(), residuals);
// Nine-parameter-block overload templated on the scalar type JetT. Creates a
// nine-element FixedArray of `const JetT*` and forwards it with residuals to
// cost_functor_, returning its result.
// NOTE(review): no assignment into jets and no parameter other than x0 and
// residuals is visible here, and no closing brace -- lines appear to be
// missing from this listing; confirm against upstream.
607 template <typename JetT>
608 bool operator()(const JetT* x0,
617 JetT* residuals) const {
628 internal::FixedArray<const JetT*> jets(9);
638 return cost_functor_(jets.get(), residuals);
// Ten-parameter-block overload templated on the scalar type JetT (the
// maximum number of blocks the class template declares). Creates a
// ten-element FixedArray of `const JetT*` and forwards it with residuals to
// cost_functor_, returning its result.
// NOTE(review): no assignment into jets and no parameter other than x0 and
// residuals is visible here, and no closing brace -- lines appear to be
// missing from this listing; confirm against upstream.
641 template <typename JetT>
642 bool operator()(const JetT* x0,
652 JetT* residuals) const {
663 internal::FixedArray<const JetT*> jets(10);
674 return cost_functor_(jets.get(), residuals);
// Functor that every operator() overload of this class delegates to. It is
// constructed from the CostFunction pointer handed to the constructor.
678 DynamicCostFunctionToFunctor cost_functor_;
683 #endif // CERES_PUBLIC_COST_FUNCTION_TO_FUNCTOR_H_