Imported Upstream version ceres 1.13.0
[platform/upstream/ceres-solver.git] / internal / ceres / linear_least_squares_problems.cc
1 // Ceres Solver - A fast non-linear least squares minimizer
2 // Copyright 2015 Google Inc. All rights reserved.
3 // http://ceres-solver.org/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are met:
7 //
8 // * Redistributions of source code must retain the above copyright notice,
9 //   this list of conditions and the following disclaimer.
10 // * Redistributions in binary form must reproduce the above copyright notice,
11 //   this list of conditions and the following disclaimer in the documentation
12 //   and/or other materials provided with the distribution.
13 // * Neither the name of Google Inc. nor the names of its contributors may be
14 //   used to endorse or promote products derived from this software without
15 //   specific prior written permission.
16 //
17 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21 // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22 // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23 // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24 // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26 // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27 // POSSIBILITY OF SUCH DAMAGE.
28 //
29 // Author: sameeragarwal@google.com (Sameer Agarwal)
30
31 #include "ceres/linear_least_squares_problems.h"
32
33 #include <cstdio>
34 #include <string>
35 #include <vector>
36 #include "ceres/block_sparse_matrix.h"
37 #include "ceres/block_structure.h"
38 #include "ceres/casts.h"
39 #include "ceres/file.h"
40 #include "ceres/internal/scoped_ptr.h"
41 #include "ceres/stringprintf.h"
42 #include "ceres/triplet_sparse_matrix.h"
43 #include "ceres/types.h"
44 #include "glog/logging.h"
45
46 namespace ceres {
47 namespace internal {
48
49 using std::string;
50
51 LinearLeastSquaresProblem* CreateLinearLeastSquaresProblemFromId(int id) {
52   switch (id) {
53     case 0:
54       return LinearLeastSquaresProblem0();
55     case 1:
56       return LinearLeastSquaresProblem1();
57     case 2:
58       return LinearLeastSquaresProblem2();
59     case 3:
60       return LinearLeastSquaresProblem3();
61     case 4:
62       return LinearLeastSquaresProblem4();
63     default:
64       LOG(FATAL) << "Unknown problem id requested " << id;
65   }
66   return NULL;
67 }
68
69 /*
70 A = [1   2]
71     [3   4]
72     [6 -10]
73
74 b = [  8
75       18
76      -18]
77
78 x = [2
79      3]
80
81 D = [1
82      2]
83
84 x_D = [1.78448275;
85        2.82327586;]
86  */
87 LinearLeastSquaresProblem* LinearLeastSquaresProblem0() {
88   LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
89
90   TripletSparseMatrix* A = new TripletSparseMatrix(3, 2, 6);
91   problem->b.reset(new double[3]);
92   problem->D.reset(new double[2]);
93
94   problem->x.reset(new double[2]);
95   problem->x_D.reset(new double[2]);
96
97   int* Ai = A->mutable_rows();
98   int* Aj = A->mutable_cols();
99   double* Ax = A->mutable_values();
100
101   int counter = 0;
102   for (int i = 0; i < 3; ++i) {
103     for (int j = 0; j< 2; ++j) {
104       Ai[counter] = i;
105       Aj[counter] = j;
106       ++counter;
107     }
108   }
109
110   Ax[0] = 1.;
111   Ax[1] = 2.;
112   Ax[2] = 3.;
113   Ax[3] = 4.;
114   Ax[4] = 6;
115   Ax[5] = -10;
116   A->set_num_nonzeros(6);
117   problem->A.reset(A);
118
119   problem->b[0] = 8;
120   problem->b[1] = 18;
121   problem->b[2] = -18;
122
123   problem->x[0] = 2.0;
124   problem->x[1] = 3.0;
125
126   problem->D[0] = 1;
127   problem->D[1] = 2;
128
129   problem->x_D[0] = 1.78448275;
130   problem->x_D[1] = 2.82327586;
131   return problem;
132 }
133
134
135 /*
136       A = [1 0  | 2 0 0
137            3 0  | 0 4 0
138            0 5  | 0 0 6
139            0 7  | 8 0 0
140            0 9  | 1 0 0
141            0 0  | 1 1 1]
142
143       b = [0
144            1
145            2
146            3
147            4
148            5]
149
150       c = A'* b = [ 3
151                    67
152                    33
153                     9
154                    17]
155
156       A'A = [10    0    2   12   0
157               0  155   65    0  30
158               2   65   70    1   1
159              12    0    1   17   1
160               0   30    1    1  37]
161
162       S = [ 42.3419  -1.4000  -11.5806
163             -1.4000   2.6000    1.0000
164             11.5806   1.0000   31.1935]
165
166       r = [ 4.3032
167             5.4000
168             5.0323]
169
170       S\r = [ 0.2102
171               2.1367
172               0.1388]
173
174       A\b = [-2.3061
175               0.3172
176               0.2102
177               2.1367
178               0.1388]
179 */
180 // The following two functions create a TripletSparseMatrix and a
181 // BlockSparseMatrix version of this problem.
182
183 // TripletSparseMatrix version.
184 LinearLeastSquaresProblem* LinearLeastSquaresProblem1() {
185   int num_rows = 6;
186   int num_cols = 5;
187
188   LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
189   TripletSparseMatrix* A = new TripletSparseMatrix(num_rows,
190                                                    num_cols,
191                                                    num_rows * num_cols);
192   problem->b.reset(new double[num_rows]);
193   problem->D.reset(new double[num_cols]);
194   problem->num_eliminate_blocks = 2;
195
196   int* rows = A->mutable_rows();
197   int* cols = A->mutable_cols();
198   double* values = A->mutable_values();
199
200   int nnz = 0;
201
202   // Row 1
203   {
204     rows[nnz] = 0;
205     cols[nnz] = 0;
206     values[nnz++] = 1;
207
208     rows[nnz] = 0;
209     cols[nnz] = 2;
210     values[nnz++] = 2;
211   }
212
213   // Row 2
214   {
215     rows[nnz] = 1;
216     cols[nnz] = 0;
217     values[nnz++] = 3;
218
219     rows[nnz] = 1;
220     cols[nnz] = 3;
221     values[nnz++] = 4;
222   }
223
224   // Row 3
225   {
226     rows[nnz] = 2;
227     cols[nnz] = 1;
228     values[nnz++] = 5;
229
230     rows[nnz] = 2;
231     cols[nnz] = 4;
232     values[nnz++] = 6;
233   }
234
235   // Row 4
236   {
237     rows[nnz] = 3;
238     cols[nnz] = 1;
239     values[nnz++] = 7;
240
241     rows[nnz] = 3;
242     cols[nnz] = 2;
243     values[nnz++] = 8;
244   }
245
246   // Row 5
247   {
248     rows[nnz] = 4;
249     cols[nnz] = 1;
250     values[nnz++] = 9;
251
252     rows[nnz] = 4;
253     cols[nnz] = 2;
254     values[nnz++] = 1;
255   }
256
257   // Row 6
258   {
259     rows[nnz] = 5;
260     cols[nnz] = 2;
261     values[nnz++] = 1;
262
263     rows[nnz] = 5;
264     cols[nnz] = 3;
265     values[nnz++] = 1;
266
267     rows[nnz] = 5;
268     cols[nnz] = 4;
269     values[nnz++] = 1;
270   }
271
272   A->set_num_nonzeros(nnz);
273   CHECK(A->IsValid());
274
275   problem->A.reset(A);
276
277   for (int i = 0; i < num_cols; ++i) {
278     problem->D.get()[i] = 1;
279   }
280
281   for (int i = 0; i < num_rows; ++i) {
282     problem->b.get()[i] = i;
283   }
284
285   return problem;
286 }
287
288 // BlockSparseMatrix version
289 LinearLeastSquaresProblem* LinearLeastSquaresProblem2() {
290   int num_rows = 6;
291   int num_cols = 5;
292
293   LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
294
295   problem->b.reset(new double[num_rows]);
296   problem->D.reset(new double[num_cols]);
297   problem->num_eliminate_blocks = 2;
298
299   CompressedRowBlockStructure* bs = new CompressedRowBlockStructure;
300   scoped_array<double> values(new double[num_rows * num_cols]);
301
302   for (int c = 0; c < num_cols; ++c) {
303     bs->cols.push_back(Block());
304     bs->cols.back().size = 1;
305     bs->cols.back().position = c;
306   }
307
308   int nnz = 0;
309
310   // Row 1
311   {
312     values[nnz++] = 1;
313     values[nnz++] = 2;
314
315     bs->rows.push_back(CompressedRow());
316     CompressedRow& row = bs->rows.back();
317     row.block.size = 1;
318     row.block.position = 0;
319     row.cells.push_back(Cell(0, 0));
320     row.cells.push_back(Cell(2, 1));
321   }
322
323   // Row 2
324   {
325     values[nnz++] = 3;
326     values[nnz++] = 4;
327
328     bs->rows.push_back(CompressedRow());
329     CompressedRow& row = bs->rows.back();
330     row.block.size = 1;
331     row.block.position = 1;
332     row.cells.push_back(Cell(0, 2));
333     row.cells.push_back(Cell(3, 3));
334   }
335
336   // Row 3
337   {
338     values[nnz++] = 5;
339     values[nnz++] = 6;
340
341     bs->rows.push_back(CompressedRow());
342     CompressedRow& row = bs->rows.back();
343     row.block.size = 1;
344     row.block.position = 2;
345     row.cells.push_back(Cell(1, 4));
346     row.cells.push_back(Cell(4, 5));
347   }
348
349   // Row 4
350   {
351     values[nnz++] = 7;
352     values[nnz++] = 8;
353
354     bs->rows.push_back(CompressedRow());
355     CompressedRow& row = bs->rows.back();
356     row.block.size = 1;
357     row.block.position = 3;
358     row.cells.push_back(Cell(1, 6));
359     row.cells.push_back(Cell(2, 7));
360   }
361
362   // Row 5
363   {
364     values[nnz++] = 9;
365     values[nnz++] = 1;
366
367     bs->rows.push_back(CompressedRow());
368     CompressedRow& row = bs->rows.back();
369     row.block.size = 1;
370     row.block.position = 4;
371     row.cells.push_back(Cell(1, 8));
372     row.cells.push_back(Cell(2, 9));
373   }
374
375   // Row 6
376   {
377     values[nnz++] = 1;
378     values[nnz++] = 1;
379     values[nnz++] = 1;
380
381     bs->rows.push_back(CompressedRow());
382     CompressedRow& row = bs->rows.back();
383     row.block.size = 1;
384     row.block.position = 5;
385     row.cells.push_back(Cell(2, 10));
386     row.cells.push_back(Cell(3, 11));
387     row.cells.push_back(Cell(4, 12));
388   }
389
390   BlockSparseMatrix* A = new BlockSparseMatrix(bs);
391   memcpy(A->mutable_values(), values.get(), nnz * sizeof(*A->values()));
392
393   for (int i = 0; i < num_cols; ++i) {
394     problem->D.get()[i] = 1;
395   }
396
397   for (int i = 0; i < num_rows; ++i) {
398     problem->b.get()[i] = i;
399   }
400
401   problem->A.reset(A);
402
403   return problem;
404 }
405
406
407 /*
408       A = [1 0
409            3 0
410            0 5
411            0 7
412            0 9
413            0 0]
414
415       b = [0
416            1
417            2
418            3
419            4
420            5]
421 */
422 // BlockSparseMatrix version
423 LinearLeastSquaresProblem* LinearLeastSquaresProblem3() {
424   int num_rows = 5;
425   int num_cols = 2;
426
427   LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
428
429   problem->b.reset(new double[num_rows]);
430   problem->D.reset(new double[num_cols]);
431   problem->num_eliminate_blocks = 2;
432
433   CompressedRowBlockStructure* bs = new CompressedRowBlockStructure;
434   scoped_array<double> values(new double[num_rows * num_cols]);
435
436   for (int c = 0; c < num_cols; ++c) {
437     bs->cols.push_back(Block());
438     bs->cols.back().size = 1;
439     bs->cols.back().position = c;
440   }
441
442   int nnz = 0;
443
444   // Row 1
445   {
446     values[nnz++] = 1;
447     bs->rows.push_back(CompressedRow());
448     CompressedRow& row = bs->rows.back();
449     row.block.size = 1;
450     row.block.position = 0;
451     row.cells.push_back(Cell(0, 0));
452   }
453
454   // Row 2
455   {
456     values[nnz++] = 3;
457     bs->rows.push_back(CompressedRow());
458     CompressedRow& row = bs->rows.back();
459     row.block.size = 1;
460     row.block.position = 1;
461     row.cells.push_back(Cell(0, 1));
462   }
463
464   // Row 3
465   {
466     values[nnz++] = 5;
467     bs->rows.push_back(CompressedRow());
468     CompressedRow& row = bs->rows.back();
469     row.block.size = 1;
470     row.block.position = 2;
471     row.cells.push_back(Cell(1, 2));
472   }
473
474   // Row 4
475   {
476     values[nnz++] = 7;
477     bs->rows.push_back(CompressedRow());
478     CompressedRow& row = bs->rows.back();
479     row.block.size = 1;
480     row.block.position = 3;
481     row.cells.push_back(Cell(1, 3));
482   }
483
484   // Row 5
485   {
486     values[nnz++] = 9;
487     bs->rows.push_back(CompressedRow());
488     CompressedRow& row = bs->rows.back();
489     row.block.size = 1;
490     row.block.position = 4;
491     row.cells.push_back(Cell(1, 4));
492   }
493
494   BlockSparseMatrix* A = new BlockSparseMatrix(bs);
495   memcpy(A->mutable_values(), values.get(), nnz * sizeof(*A->values()));
496
497   for (int i = 0; i < num_cols; ++i) {
498     problem->D.get()[i] = 1;
499   }
500
501   for (int i = 0; i < num_rows; ++i) {
502     problem->b.get()[i] = i;
503   }
504
505   problem->A.reset(A);
506
507   return problem;
508 }
509
510 /*
511       A = [1 2 0 0 0 1 1
512            1 4 0 0 0 5 6
513            0 0 9 0 0 3 1]
514
515       b = [0
516            1
517            2]
518 */
519 // BlockSparseMatrix version
520 //
521 // This problem has the unique property that it has two different
522 // sized f-blocks, but only one of them occurs in the rows involving
523 // the one e-block. So performing Schur elimination on this problem
524 // tests the Schur Eliminator's ability to handle non-e-block rows
525 // correctly when their structure does not conform to the static
526 // structure determined by DetectStructure.
527 //
528 // NOTE: This problem is too small and rank deficient to be solved without
529 // the diagonal regularization.
530 LinearLeastSquaresProblem* LinearLeastSquaresProblem4() {
531   int num_rows = 3;
532   int num_cols = 7;
533
534   LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
535
536   problem->b.reset(new double[num_rows]);
537   problem->D.reset(new double[num_cols]);
538   problem->num_eliminate_blocks = 1;
539
540   CompressedRowBlockStructure* bs = new CompressedRowBlockStructure;
541   scoped_array<double> values(new double[num_rows * num_cols]);
542
543   // Column block structure
544   bs->cols.push_back(Block());
545   bs->cols.back().size = 2;
546   bs->cols.back().position = 0;
547
548   bs->cols.push_back(Block());
549   bs->cols.back().size = 3;
550   bs->cols.back().position = 2;
551
552   bs->cols.push_back(Block());
553   bs->cols.back().size = 2;
554   bs->cols.back().position = 5;
555
556   int nnz = 0;
557
558   // Row 1 & 2
559   {
560     bs->rows.push_back(CompressedRow());
561     CompressedRow& row = bs->rows.back();
562     row.block.size = 2;
563     row.block.position = 0;
564
565     row.cells.push_back(Cell(0, nnz));
566     values[nnz++] = 1;
567     values[nnz++] = 2;
568     values[nnz++] = 1;
569     values[nnz++] = 4;
570
571     row.cells.push_back(Cell(2, nnz));
572     values[nnz++] = 1;
573     values[nnz++] = 1;
574     values[nnz++] = 5;
575     values[nnz++] = 6;
576   }
577
578   // Row 3
579   {
580     bs->rows.push_back(CompressedRow());
581     CompressedRow& row = bs->rows.back();
582     row.block.size = 1;
583     row.block.position = 2;
584
585     row.cells.push_back(Cell(1, nnz));
586     values[nnz++] = 9;
587     values[nnz++] = 0;
588     values[nnz++] = 0;
589
590     row.cells.push_back(Cell(2, nnz));
591     values[nnz++] = 3;
592     values[nnz++] = 1;
593   }
594
595   BlockSparseMatrix* A = new BlockSparseMatrix(bs);
596   memcpy(A->mutable_values(), values.get(), nnz * sizeof(*A->values()));
597
598   for (int i = 0; i < num_cols; ++i) {
599     problem->D.get()[i] = (i + 1) * 100;
600   }
601
602   for (int i = 0; i < num_rows; ++i) {
603     problem->b.get()[i] = i;
604   }
605
606   problem->A.reset(A);
607   return problem;
608 }
609
610 namespace {
611 bool DumpLinearLeastSquaresProblemToConsole(const SparseMatrix* A,
612                                             const double* D,
613                                             const double* b,
614                                             const double* x,
615                                             int num_eliminate_blocks) {
616   CHECK_NOTNULL(A);
617   Matrix AA;
618   A->ToDenseMatrix(&AA);
619   LOG(INFO) << "A^T: \n" << AA.transpose();
620
621   if (D != NULL) {
622     LOG(INFO) << "A's appended diagonal:\n"
623               << ConstVectorRef(D, A->num_cols());
624   }
625
626   if (b != NULL) {
627     LOG(INFO) << "b: \n" << ConstVectorRef(b, A->num_rows());
628   }
629
630   if (x != NULL) {
631     LOG(INFO) << "x: \n" << ConstVectorRef(x, A->num_cols());
632   }
633   return true;
634 }
635
636 void WriteArrayToFileOrDie(const string& filename,
637                            const double* x,
638                            const int size) {
639   CHECK_NOTNULL(x);
640   VLOG(2) << "Writing array to: " << filename;
641   FILE* fptr = fopen(filename.c_str(), "w");
642   CHECK_NOTNULL(fptr);
643   for (int i = 0; i < size; ++i) {
644     fprintf(fptr, "%17f\n", x[i]);
645   }
646   fclose(fptr);
647 }
648
649 bool DumpLinearLeastSquaresProblemToTextFile(const string& filename_base,
650                                              const SparseMatrix* A,
651                                              const double* D,
652                                              const double* b,
653                                              const double* x,
654                                              int num_eliminate_blocks) {
655   CHECK_NOTNULL(A);
656   LOG(INFO) << "writing to: " << filename_base << "*";
657
658   string matlab_script;
659   StringAppendF(&matlab_script,
660                 "function lsqp = load_trust_region_problem()\n");
661   StringAppendF(&matlab_script,
662                 "lsqp.num_rows = %d;\n", A->num_rows());
663   StringAppendF(&matlab_script,
664                 "lsqp.num_cols = %d;\n", A->num_cols());
665
666   {
667     string filename = filename_base + "_A.txt";
668     FILE* fptr = fopen(filename.c_str(), "w");
669     CHECK_NOTNULL(fptr);
670     A->ToTextFile(fptr);
671     fclose(fptr);
672     StringAppendF(&matlab_script,
673                   "tmp = load('%s', '-ascii');\n", filename.c_str());
674     StringAppendF(
675         &matlab_script,
676         "lsqp.A = sparse(tmp(:, 1) + 1, tmp(:, 2) + 1, tmp(:, 3), %d, %d);\n",
677         A->num_rows(),
678         A->num_cols());
679   }
680
681
682   if (D != NULL) {
683     string filename = filename_base + "_D.txt";
684     WriteArrayToFileOrDie(filename, D, A->num_cols());
685     StringAppendF(&matlab_script,
686                   "lsqp.D = load('%s', '-ascii');\n", filename.c_str());
687   }
688
689   if (b != NULL) {
690     string filename = filename_base + "_b.txt";
691     WriteArrayToFileOrDie(filename, b, A->num_rows());
692     StringAppendF(&matlab_script,
693                   "lsqp.b = load('%s', '-ascii');\n", filename.c_str());
694   }
695
696   if (x != NULL) {
697     string filename = filename_base + "_x.txt";
698     WriteArrayToFileOrDie(filename, x, A->num_cols());
699     StringAppendF(&matlab_script,
700                   "lsqp.x = load('%s', '-ascii');\n", filename.c_str());
701   }
702
703   string matlab_filename = filename_base + ".m";
704   WriteStringToFileOrDie(matlab_script, matlab_filename);
705   return true;
706 }
707 }  // namespace
708
709 bool DumpLinearLeastSquaresProblem(const string& filename_base,
710                                    DumpFormatType dump_format_type,
711                                    const SparseMatrix* A,
712                                    const double* D,
713                                    const double* b,
714                                    const double* x,
715                                    int num_eliminate_blocks) {
716   switch (dump_format_type) {
717     case CONSOLE:
718       return DumpLinearLeastSquaresProblemToConsole(A, D, b, x,
719                                                     num_eliminate_blocks);
720     case TEXTFILE:
721       return DumpLinearLeastSquaresProblemToTextFile(filename_base,
722                                                      A, D, b, x,
723                                                      num_eliminate_blocks);
724     default:
725       LOG(FATAL) << "Unknown DumpFormatType " << dump_format_type;
726   }
727
728   return true;
729 }
730
731 }  // namespace internal
732 }  // namespace ceres