Fixed bugs in BruteForceMatcher and its specialization
[platform/upstream/opencv.git] / modules / features2d / src / descriptors.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 //   * Redistribution's of source code must retain the above copyright notice,
20 //     this list of conditions and the following disclaimer.
21 //
22 //   * Redistribution's in binary form must reproduce the above copyright notice,
23 //     this list of conditions and the following disclaimer in the documentation
24 //     and/or other materials provided with the distribution.
25 //
26 //   * The name of Intel Corporation may not be used to endorse or promote products
27 //     derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
39 //
40 //M*/
41
42 #include "precomp.hpp"
43
44 #ifdef HAVE_EIGEN2
45 #include <Eigen/Array>
46 #endif
47
48 //#define _KDTREE
49
50 using namespace std;
51
52 const int draw_shift_bits = 4;
53 const int draw_multiplier = 1 << draw_shift_bits;
54
55 namespace cv
56 {
57
58 Mat windowedMatchingMask( const vector<KeyPoint>& keypoints1, const vector<KeyPoint>& keypoints2,
59                           float maxDeltaX, float maxDeltaY )
60 {
61     if( keypoints1.empty() || keypoints2.empty() )
62         return Mat();
63
64     Mat mask( keypoints1.size(), keypoints2.size(), CV_8UC1 );
65     for( size_t i = 0; i < keypoints1.size(); i++ )
66     {
67         for( size_t j = 0; j < keypoints2.size(); j++ )
68         {
69             Point2f diff = keypoints2[j].pt - keypoints1[i].pt;
70             mask.at<uchar>(i, j) = std::abs(diff.x) < maxDeltaX && std::abs(diff.y) < maxDeltaY;
71         }
72     }
73     return mask;
74 }
75
76 /*
77  * Drawing functions
78  */
79
80 static inline void _drawKeypoint( Mat& img, const KeyPoint& p, const Scalar& color, int flags )
81 {
82     Point center( p.pt.x * draw_multiplier, p.pt.y * draw_multiplier );
83
84     if( flags & DrawMatchesFlags::DRAW_RICH_KEYPOINTS )
85     {
86         int radius = p.size/2 * draw_multiplier; // KeyPoint::size is a diameter
87
88         // draw the circles around keypoints with the keypoints size
89         circle( img, center, radius, color, 1, CV_AA, draw_shift_bits );
90
91         // draw orientation of the keypoint, if it is applicable
92         if( p.angle != -1 )
93         {
94             float srcAngleRad = p.angle*CV_PI/180;
95             Point orient(cos(srcAngleRad)*radius, sin(srcAngleRad)*radius);
96             line( img, center, center+orient, color, 1, CV_AA, draw_shift_bits );
97         }
98 #if 0
99         else
100         {
101             // draw center with R=1
102             int radius = 1 * draw_multiplier;
103             circle( img, center, radius, color, 1, CV_AA, draw_shift_bits );
104         }
105 #endif
106     }
107     else
108     {
109         // draw center with R=3
110         int radius = 3 * draw_multiplier;
111         circle( img, center, radius, color, 1, CV_AA, draw_shift_bits );
112     }
113 }
114
115 void drawKeypoints( const Mat& image, const vector<KeyPoint>& keypoints, Mat& outImg,
116                     const Scalar& _color, int flags )
117 {
118     if( !(flags & DrawMatchesFlags::DRAW_OVER_OUTIMG) )
119         cvtColor( image, outImg, CV_GRAY2BGR );
120
121     RNG& rng=theRNG();
122     bool isRandColor = _color == Scalar::all(-1);
123
124     for( vector<KeyPoint>::const_iterator i = keypoints.begin(), ie = keypoints.end(); i != ie; ++i )
125     {
126         Scalar color = isRandColor ? Scalar(rng(256), rng(256), rng(256)) : _color;
127         _drawKeypoint( outImg, *i, color, flags );
128     }
129 }
130
131 static void _prepareImgAndDrawKeypoints( const Mat& img1, const vector<KeyPoint>& keypoints1,
132                                          const Mat& img2, const vector<KeyPoint>& keypoints2,
133                                          Mat& outImg, Mat& outImg1, Mat& outImg2,
134                                          const Scalar& singlePointColor, int flags )
135 {
136     Size size( img1.cols + img2.cols, MAX(img1.rows, img2.rows) );
137     if( flags & DrawMatchesFlags::DRAW_OVER_OUTIMG )
138     {
139         if( size.width > outImg.cols || size.height > outImg.rows )
140             CV_Error( CV_StsBadSize, "outImg has size less than need to draw img1 and img2 together" );
141         outImg1 = outImg( Rect(0, 0, img1.cols, img1.rows) );
142         outImg2 = outImg( Rect(img1.cols, 0, img2.cols, img2.rows) );
143     }
144     else
145     {
146         outImg.create( size, CV_MAKETYPE(img1.depth(), 3) );
147         outImg1 = outImg( Rect(0, 0, img1.cols, img1.rows) );
148         outImg2 = outImg( Rect(img1.cols, 0, img2.cols, img2.rows) );
149         cvtColor( img1, outImg1, CV_GRAY2RGB );
150         cvtColor( img2, outImg2, CV_GRAY2RGB );
151     }
152
153     // draw keypoints
154     if( !(flags & DrawMatchesFlags::NOT_DRAW_SINGLE_POINTS) )
155     {
156         Mat outImg1 = outImg( Rect(0, 0, img1.cols, img1.rows) );
157         drawKeypoints( outImg1, keypoints1, outImg1, singlePointColor, flags + DrawMatchesFlags::DRAW_OVER_OUTIMG );
158
159         Mat outImg2 = outImg( Rect(img1.cols, 0, img2.cols, img2.rows) );
160         drawKeypoints( outImg2, keypoints2, outImg2, singlePointColor, flags + DrawMatchesFlags::DRAW_OVER_OUTIMG );
161     }
162 }
163
164 static inline void _drawMatch( Mat& outImg, Mat& outImg1, Mat& outImg2 ,
165                           const KeyPoint& kp1, const KeyPoint& kp2, const Scalar& matchColor, int flags )
166 {
167     RNG& rng = theRNG();
168     bool isRandMatchColor = matchColor == Scalar::all(-1);
169     Scalar color = isRandMatchColor ? Scalar( rng(256), rng(256), rng(256) ) : matchColor;
170
171     _drawKeypoint( outImg1, kp1, color, flags );
172     _drawKeypoint( outImg2, kp2, color, flags );
173
174     Point2f pt1 = kp1.pt,
175             pt2 = kp2.pt,
176             dpt2 = Point2f( std::min(pt2.x+outImg1.cols, float(outImg.cols-1)), pt2.y );
177
178     line( outImg, Point(pt1.x*draw_multiplier, pt1.y*draw_multiplier), Point(dpt2.x*draw_multiplier, dpt2.y*draw_multiplier),
179           color, 1, CV_AA, draw_shift_bits );
180 }
181
182 void drawMatches( const Mat& img1, const vector<KeyPoint>& keypoints1,
183                   const Mat& img2,const vector<KeyPoint>& keypoints2,
184                   const vector<int>& matches1to2, Mat& outImg,
185                   const Scalar& matchColor, const Scalar& singlePointColor,
186                   const vector<char>& matchesMask, int flags )
187 {
188     if( matches1to2.size() != keypoints1.size() )
189         CV_Error( CV_StsBadSize, "matches1to2 must have the same size as keypoints1" );
190     if( !matchesMask.empty() && matchesMask.size() != matches1to2.size() )
191         CV_Error( CV_StsBadSize, "matchesMask must have the same size as matches1to2" );
192
193     Mat outImg1, outImg2;
194     _prepareImgAndDrawKeypoints( img1, keypoints1, img2, keypoints2,
195                                  outImg, outImg1, outImg2, singlePointColor, flags );
196
197     // draw matches
198     for( size_t i1 = 0; i1 < keypoints1.size(); i1++ )
199     {
200         int i2 = matches1to2[i1];
201         if( (matchesMask.empty() || matchesMask[i1] ) && i2 >= 0 )
202         {
203             const KeyPoint &kp1 = keypoints1[i1], &kp2 = keypoints2[i2];
204             _drawMatch( outImg, outImg1, outImg2, kp1, kp2, matchColor, flags );
205         }
206     }
207 }
208
209 void drawMatches( const Mat& img1, const vector<KeyPoint>& keypoints1,
210                   const Mat& img2, const vector<KeyPoint>& keypoints2,
211                   const vector<DMatch>& matches1to2, Mat& outImg,
212                   const Scalar& matchColor, const Scalar& singlePointColor,
213                   const vector<char>& matchesMask, int flags )
214 {
215     if( !matchesMask.empty() && matchesMask.size() != matches1to2.size() )
216         CV_Error( CV_StsBadSize, "matchesMask must have the same size as matches1to2" );
217
218     Mat outImg1, outImg2;
219     _prepareImgAndDrawKeypoints( img1, keypoints1, img2, keypoints2,
220                                  outImg, outImg1, outImg2, singlePointColor, flags );
221
222     // draw matches
223     for( size_t m = 0; m < matches1to2.size(); m++ )
224     {
225         int i1 = matches1to2[m].indexQuery;
226         int i2 = matches1to2[m].indexTrain;
227         if( matchesMask.empty() || matchesMask[m] )
228         {
229             const KeyPoint &kp1 = keypoints1[i1], &kp2 = keypoints2[i2];
230             _drawMatch( outImg, outImg1, outImg2, kp1, kp2, matchColor, flags );
231         }
232     }
233 }
234
235 void drawMatches( const Mat& img1, const vector<KeyPoint>& keypoints1,
236                   const Mat& img2, const vector<KeyPoint>& keypoints2,
237                   const vector<vector<DMatch> >& matches1to2, Mat& outImg,
238                   const Scalar& matchColor, const Scalar& singlePointColor,
239                   const vector<vector<char> >& matchesMask, int flags )
240 {
241     if( !matchesMask.empty() && matchesMask.size() != matches1to2.size() )
242         CV_Error( CV_StsBadSize, "matchesMask must have the same size as matches1to2" );
243
244     Mat outImg1, outImg2;
245     _prepareImgAndDrawKeypoints( img1, keypoints1, img2, keypoints2,
246                                  outImg, outImg1, outImg2, singlePointColor, flags );
247
248     // draw matches
249     for( size_t i = 0; i < matches1to2.size(); i++ )
250     {
251         for( size_t j = 0; j < matches1to2[i].size(); j++ )
252         {
253             int i1 = matches1to2[i][j].indexQuery;
254             int i2 = matches1to2[i][j].indexTrain;
255             if( matchesMask.empty() || matchesMask[i][j] )
256             {
257                 const KeyPoint &kp1 = keypoints1[i1], &kp2 = keypoints2[i2];
258                 _drawMatch( outImg, outImg1, outImg2, kp1, kp2, matchColor, flags );
259             }
260         }
261     }
262 }
263
264 /****************************************************************************************\
265 *                                 DescriptorExtractor                                    *
266 \****************************************************************************************/
267 /*
268  *   DescriptorExtractor
269  */
270 struct RoiPredicate
271 {
272     RoiPredicate(float _minX, float _minY, float _maxX, float _maxY)
273         : minX(_minX), minY(_minY), maxX(_maxX), maxY(_maxY)
274     {}
275
276     bool operator()( const KeyPoint& keyPt) const
277     {
278         Point2f pt = keyPt.pt;
279         return (pt.x < minX) || (pt.x >= maxX) || (pt.y < minY) || (pt.y >= maxY);
280     }
281
282     float minX, minY, maxX, maxY;
283 };
284
285 void DescriptorExtractor::removeBorderKeypoints( vector<KeyPoint>& keypoints,
286                                                  Size imageSize, int borderPixels )
287 {
288     keypoints.erase( remove_if(keypoints.begin(), keypoints.end(),
289                                RoiPredicate((float)borderPixels, (float)borderPixels,
290                                             (float)(imageSize.width - borderPixels),
291                                             (float)(imageSize.height - borderPixels))),
292                      keypoints.end());
293 }
294
295 /****************************************************************************************\
296 *                                SiftDescriptorExtractor                                 *
297 \****************************************************************************************/
298 SiftDescriptorExtractor::SiftDescriptorExtractor( double magnification, bool isNormalize, bool recalculateAngles,
299                                                   int nOctaves, int nOctaveLayers, int firstOctave, int angleMode )
300     : sift( magnification, isNormalize, recalculateAngles, nOctaves, nOctaveLayers, firstOctave, angleMode )
301 {}
302
303 void SiftDescriptorExtractor::compute( const Mat& image,
304                                        vector<KeyPoint>& keypoints,
305                                        Mat& descriptors) const
306 {
307     bool useProvidedKeypoints = true;
308     sift(image, Mat(), keypoints, descriptors, useProvidedKeypoints);
309 }
310
311 void SiftDescriptorExtractor::read (const FileNode &fn)
312 {
313     double magnification = fn["magnification"];
314     bool isNormalize = (int)fn["isNormalize"] != 0;
315     bool recalculateAngles = (int)fn["recalculateAngles"] != 0;
316     int nOctaves = fn["nOctaves"];
317     int nOctaveLayers = fn["nOctaveLayers"];
318     int firstOctave = fn["firstOctave"];
319     int angleMode = fn["angleMode"];
320
321     sift = SIFT( magnification, isNormalize, recalculateAngles, nOctaves, nOctaveLayers, firstOctave, angleMode );
322 }
323
324 void SiftDescriptorExtractor::write (FileStorage &fs) const
325 {
326 //    fs << "algorithm" << getAlgorithmName ();
327
328     SIFT::CommonParams commParams = sift.getCommonParams ();
329     SIFT::DescriptorParams descriptorParams = sift.getDescriptorParams ();
330     fs << "magnification" << descriptorParams.magnification;
331     fs << "isNormalize" << descriptorParams.isNormalize;
332     fs << "recalculateAngles" << descriptorParams.recalculateAngles;
333     fs << "nOctaves" << commParams.nOctaves;
334     fs << "nOctaveLayers" << commParams.nOctaveLayers;
335     fs << "firstOctave" << commParams.firstOctave;
336     fs << "angleMode" << commParams.angleMode;
337 }
338
339 /****************************************************************************************\
340 *                                SurfDescriptorExtractor                                 *
341 \****************************************************************************************/
342 SurfDescriptorExtractor::SurfDescriptorExtractor( int nOctaves,
343                                                   int nOctaveLayers, bool extended )
344     : surf( 0.0, nOctaves, nOctaveLayers, extended )
345 {}
346
347 void SurfDescriptorExtractor::compute( const Mat& image,
348                                        vector<KeyPoint>& keypoints,
349                                        Mat& descriptors) const
350 {
351     // Compute descriptors for given keypoints
352     vector<float> _descriptors;
353     Mat mask;
354     bool useProvidedKeypoints = true;
355     surf(image, mask, keypoints, _descriptors, useProvidedKeypoints);
356
357     descriptors.create((int)keypoints.size(), (int)surf.descriptorSize(), CV_32FC1);
358     assert( (int)_descriptors.size() == descriptors.rows * descriptors.cols );
359     std::copy(_descriptors.begin(), _descriptors.end(), descriptors.begin<float>());
360 }
361
362 void SurfDescriptorExtractor::read( const FileNode &fn )
363 {
364     int nOctaves = fn["nOctaves"];
365     int nOctaveLayers = fn["nOctaveLayers"];
366     bool extended = (int)fn["extended"] != 0;
367
368     surf = SURF( 0.0, nOctaves, nOctaveLayers, extended );
369 }
370
371 void SurfDescriptorExtractor::write( FileStorage &fs ) const
372 {
373 //    fs << "algorithm" << getAlgorithmName ();
374
375     fs << "nOctaves" << surf.nOctaves;
376     fs << "nOctaveLayers" << surf.nOctaveLayers;
377     fs << "extended" << surf.extended;
378 }
379
380 /****************************************************************************************\
381 *           Factory functions for descriptor extractor and matcher creating              *
382 \****************************************************************************************/
383
384 Ptr<DescriptorExtractor> createDescriptorExtractor( const string& descriptorExtractorType )
385 {
386     DescriptorExtractor* de = 0;
387     if( !descriptorExtractorType.compare( "SIFT" ) )
388     {
389         de = new SiftDescriptorExtractor/*( double magnification=SIFT::DescriptorParams::GET_DEFAULT_MAGNIFICATION(),
390                              bool isNormalize=true, bool recalculateAngles=true,
391                              int nOctaves=SIFT::CommonParams::DEFAULT_NOCTAVES,
392                              int nOctaveLayers=SIFT::CommonParams::DEFAULT_NOCTAVE_LAYERS,
393                              int firstOctave=SIFT::CommonParams::DEFAULT_FIRST_OCTAVE,
394                              int angleMode=SIFT::CommonParams::FIRST_ANGLE )*/;
395     }
396     else if( !descriptorExtractorType.compare( "SURF" ) )
397     {
398         de = new SurfDescriptorExtractor/*( int nOctaves=4, int nOctaveLayers=2, bool extended=false )*/;
399     }
400     else
401     {
402         //CV_Error( CV_StsBadArg, "unsupported descriptor extractor type");
403     }
404     return de;
405 }
406
407 Ptr<DescriptorMatcher> createDescriptorMatcher( const string& descriptorMatcherType )
408 {
409     DescriptorMatcher* dm = 0;
410     if( !descriptorMatcherType.compare( "BruteForce" ) )
411     {
412         dm = new BruteForceMatcher<L2<float> >();
413     }
414     else if ( !descriptorMatcherType.compare( "BruteForce-L1" ) )
415     {
416         dm = new BruteForceMatcher<L1<float> >();
417     }
418     else
419     {
420         //CV_Error( CV_StsBadArg, "unsupported descriptor matcher type");
421     }
422
423     return dm;
424 }
425
426 /****************************************************************************************\
427 *                             BruteForceMatcher L2 specialization                        *
428 \****************************************************************************************/
429 template<>
430 void BruteForceMatcher<L2<float> >::matchImpl( const Mat& query, const Mat& mask, vector<DMatch>& matches ) const
431 {
432     assert( mask.empty() || (mask.rows == query.rows && mask.cols == train.rows) );
433     assert( query.cols == train.cols ||  query.empty() ||  train.empty() );
434
435     matches.clear();
436     matches.reserve( query.rows );
437 #if (!defined HAVE_EIGEN2)
438     Mat norms;
439     cv::reduce( train.mul( train ), norms, 1, 0);
440     norms = norms.t();
441     Mat desc_2t = train.t();
442     for( int i=0;i<query.rows;i++ )
443     {
444         Mat distances = (-2)*query.row(i)*desc_2t;
445         distances += norms;
446         DMatch match;
447         match.indexTrain = -1;
448         double minVal;
449         Point minLoc;
450         if( mask.empty() )
451         {
452             minMaxLoc ( distances, &minVal, 0, &minLoc );
453         }
454         else
455         {
456             minMaxLoc ( distances, &minVal, 0, &minLoc, 0, mask.row( i ) );
457         }
458         match.indexTrain = minLoc.x;
459
460         if( match.indexTrain != -1 )
461         {
462             match.indexQuery = i;
463             double queryNorm = norm( query.row(i) );
464             match.distance = sqrt( minVal + queryNorm*queryNorm );
465             matches.push_back( match );
466         }
467     }
468
469 #else
470     Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic> desc1t;
471     Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic> desc2;
472     cv2eigen( query.t(), desc1t);
473     cv2eigen( train, desc2 );
474
475     Eigen::Matrix<float, Eigen::Dynamic, 1> norms = desc2.rowwise().squaredNorm() / 2;
476
477     if( mask.empty() )
478     {
479         for( int i=0;i<query.rows;i++ )
480         {
481             Eigen::Matrix<float, Eigen::Dynamic, 1> distances = desc2*desc1t.col(i);
482             distances -= norms;
483             DMatch match;
484             match.indexQuery = i;
485             match.distance = sqrt( (-2)*distances.maxCoeff( &match.indexTrain ) + desc1t.col(i).squaredNorm() );
486             matches.push_back( match );
487         }
488     }
489     else
490     {
491         for( int i=0;i<query.rows;i++ )
492         {
493             Eigen::Matrix<float, Eigen::Dynamic, 1> distances = desc2*desc1t.col(i);
494             distances -= norms;
495
496             float maxCoeff = -std::numeric_limits<float>::max();
497             DMatch match;
498             match.indexTrain = -1;
499             for( int j=0;j<train.rows;j++ )
500             {
501                 if( possibleMatch( mask, i, j ) && distances( j, 0 ) > maxCoeff )
502                 {
503                     maxCoeff = distances( j, 0 );
504                     match.indexTrain = j;
505                 }
506             }
507
508             if( match.indexTrain != -1 )
509             {
510                 match.indexQuery = i;
511                 match.distance = sqrt( (-2)*maxCoeff + desc1t.col(i).squaredNorm() );
512                 matches.push_back( match );
513             }
514         }
515     }
516 #endif
517 }
518
519 /****************************************************************************************\
520 *                                GenericDescriptorMatch                                  *
521 \****************************************************************************************/
522 /*
523  * KeyPointCollection
524  */
525 void KeyPointCollection::add( const Mat& _image, const vector<KeyPoint>& _points )
526 {
527     // update m_start_indices
528     if( startIndices.empty() )
529         startIndices.push_back(0);
530     else
531         startIndices.push_back((int)(*startIndices.rbegin() + points.rbegin()->size()));
532
533     // add image and keypoints
534     images.push_back(_image);
535     points.push_back(_points);
536 }
537
538 KeyPoint KeyPointCollection::getKeyPoint( int index ) const
539 {
540     size_t i = 0;
541     for(; i < startIndices.size() && startIndices[i] <= index; i++);
542     i--;
543     assert(i < startIndices.size() && (size_t)index - startIndices[i] < points[i].size());
544
545     return points[i][index - startIndices[i]];
546 }
547
548 size_t KeyPointCollection::calcKeypointCount() const
549 {
550     if( startIndices.empty() )
551         return 0;
552     return *startIndices.rbegin() + points.rbegin()->size();
553 }
554
555 void KeyPointCollection::clear()
556 {
557     images.clear();
558     points.clear();
559     startIndices.clear();
560 }
561
562 /*
563  * GenericDescriptorMatch
564  */
565
566 void GenericDescriptorMatch::match( const Mat&, vector<KeyPoint>&, vector<DMatch>& )
567 {
568 }
569
570 void GenericDescriptorMatch::match( const Mat&, vector<KeyPoint>&, vector<vector<DMatch> >&, float )
571 {
572 }
573
574 void GenericDescriptorMatch::add( KeyPointCollection& collection )
575 {
576     for( size_t i = 0; i < collection.images.size(); i++ )
577         add( collection.images[i], collection.points[i] );
578 }
579
580 void GenericDescriptorMatch::classify( const Mat& image, vector<cv::KeyPoint>& points )
581 {
582     vector<int> keypointIndices;
583     match( image, points, keypointIndices );
584
585     // remap keypoint indices to descriptors
586     for( size_t i = 0; i < keypointIndices.size(); i++ )
587         points[i].class_id = collection.getKeyPoint(keypointIndices[i]).class_id;
588 };
589
590 void GenericDescriptorMatch::clear()
591 {
592     collection.clear();
593 }
594
595 /*
596  * Factory function for GenericDescriptorMatch creating
597  */
598 Ptr<GenericDescriptorMatch> createGenericDescriptorMatch( const string& genericDescritptorMatchType, const string &paramsFilename )
599 {
600     GenericDescriptorMatch *descriptorMatch = 0;
601     if( ! genericDescritptorMatchType.compare ("ONEWAY") )
602     {
603         descriptorMatch = new OneWayDescriptorMatch ();
604     }
605     else if( ! genericDescritptorMatchType.compare ("FERN") )
606     {
607         FernDescriptorMatch::Params params;
608         params.signatureSize = numeric_limits<int>::max();
609         descriptorMatch = new FernDescriptorMatch (params);
610     }
611     else if( ! genericDescritptorMatchType.compare ("CALONDER") )
612     {
613         //descriptorMatch = new CalonderDescriptorMatch ();
614     }
615
616     if( !paramsFilename.empty() && descriptorMatch != 0 )
617     {
618         FileStorage fs = FileStorage( paramsFilename, FileStorage::READ );
619         if( fs.isOpened() )
620         {
621             descriptorMatch->read( fs.root() );
622             fs.release();
623         }
624     }
625
626     return descriptorMatch;
627 }
628
629 /****************************************************************************************\
630 *                                OneWayDescriptorMatch                                  *
631 \****************************************************************************************/
632 OneWayDescriptorMatch::OneWayDescriptorMatch()
633 {}
634
635 OneWayDescriptorMatch::OneWayDescriptorMatch( const Params& _params)
636 {
637     initialize(_params);
638 }
639
640 OneWayDescriptorMatch::~OneWayDescriptorMatch()
641 {}
642
643 void OneWayDescriptorMatch::initialize( const Params& _params, OneWayDescriptorBase *_base)
644 {
645     base.release();
646     if (_base != 0)
647     {
648         base = _base;
649     }
650     params = _params;
651 }
652
653 void OneWayDescriptorMatch::add( const Mat& image, vector<KeyPoint>& keypoints )
654 {
655     if( base.empty() )
656         base = new OneWayDescriptorObject( params.patchSize, params.poseCount, params.pcaFilename,
657                                            params.trainPath, params.trainImagesList, params.minScale, params.maxScale, params.stepScale);
658
659     size_t trainFeatureCount = keypoints.size();
660
661     base->Allocate( (int)trainFeatureCount );
662
663     IplImage _image = image;
664     for( size_t i = 0; i < keypoints.size(); i++ )
665         base->InitializeDescriptor( (int)i, &_image, keypoints[i], "" );
666
667     collection.add( Mat(), keypoints );
668
669 #if defined(_KDTREE)
670     base->ConvertDescriptorsArrayToTree();
671 #endif
672 }
673
674 void OneWayDescriptorMatch::add( KeyPointCollection& keypoints )
675 {
676     if( base.empty() )
677         base = new OneWayDescriptorObject( params.patchSize, params.poseCount, params.pcaFilename,
678                                            params.trainPath, params.trainImagesList, params.minScale, params.maxScale, params.stepScale);
679
680     size_t trainFeatureCount = keypoints.calcKeypointCount();
681
682     base->Allocate( (int)trainFeatureCount );
683
684     int count = 0;
685     for( size_t i = 0; i < keypoints.points.size(); i++ )
686     {
687         for( size_t j = 0; j < keypoints.points[i].size(); j++ )
688         {
689             IplImage img = keypoints.images[i];
690             base->InitializeDescriptor( count++, &img, keypoints.points[i][j], "" );
691         }
692
693         collection.add( Mat(), keypoints.points[i] );
694     }
695
696 #if defined(_KDTREE)
697     base->ConvertDescriptorsArrayToTree();
698 #endif
699 }
700
701 void OneWayDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points, vector<int>& indices)
702 {
703     vector<DMatch> matchings( points.size() );
704     indices.resize(points.size());
705
706     match( image, points, matchings );
707
708     for( size_t i = 0; i < points.size(); i++ )
709         indices[i] = matchings[i].indexTrain;
710 }
711
712 void OneWayDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points, vector<DMatch>& matches )
713 {
714     matches.resize( points.size() );
715     IplImage _image = image;
716     for( size_t i = 0; i < points.size(); i++ )
717     {
718         int poseIdx = -1;
719
720         DMatch match;
721         match.indexQuery = (int)i;
722         match.indexTrain = -1;
723         base->FindDescriptor( &_image, points[i].pt, match.indexTrain, poseIdx, match.distance );
724         matches[i] = match;
725     }
726 }
727
728 void OneWayDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points, vector<vector<DMatch> >& matches, float /*threshold*/ )
729 {
730     matches.clear();
731     matches.resize( points.size() );
732
733     vector<DMatch> dmatches;
734     match( image, points, dmatches );
735     for( size_t i=0;i<matches.size();i++ )
736     {
737         matches[i].push_back( dmatches[i] );
738     }
739
740     /*
741     printf("Start matching %d points\n", points.size());
742     //std::cout << "Start matching " << points.size() << "points\n";
743     assert(collection.images.size() == 1);
744     int n = collection.points[0].size();
745
746     printf("n = %d\n", n);
747     for( size_t i = 0; i < points.size(); i++ )
748     {
749         //printf("Matching %d\n", i);
750         //int poseIdx = -1;
751
752         DMatch match;
753         match.indexQuery = i;
754         match.indexTrain = -1;
755
756
757         CvPoint pt = points[i].pt;
758         CvRect roi = cvRect(cvRound(pt.x - 24/4),
759                             cvRound(pt.y - 24/4),
760                             24/2, 24/2);
761         cvSetImageROI(&_image, roi);
762
763         std::vector<int> desc_idxs;
764         std::vector<int> pose_idxs;
765         std::vector<float> distances;
766         std::vector<float> _scales;
767
768
769         base->FindDescriptor(&_image, n, desc_idxs, pose_idxs, distances, _scales);
770         cvResetImageROI(&_image);
771
772         for( int j=0;j<n;j++ )
773         {
774             match.indexTrain = desc_idxs[j];
775             match.distance = distances[j];
776             matches[i].push_back( match );
777         }
778
779         //sort( matches[i].begin(), matches[i].end(), compareIndexTrain );
780         //for( int j=0;j<n;j++ )
781         //{
782             //printf( "%d %f;  ",matches[i][j].indexTrain, matches[i][j].distance);
783         //}
784         //printf("\n\n\n");
785
786
787
788         //base->FindDescriptor( &_image, 100, points[i].pt, match.indexTrain, poseIdx, match.distance );
789         //matches[i].push_back( match );
790     }
791     */
792 }
793
794
795 void OneWayDescriptorMatch::read( const FileNode &fn )
796 {
797     base = new OneWayDescriptorObject( params.patchSize, params.poseCount, string (), string (), string (),
798                                        params.minScale, params.maxScale, params.stepScale );
799     base->Read (fn);
800 }
801
802
803 void OneWayDescriptorMatch::write( FileStorage& fs ) const
804 {
805     base->Write (fs);
806 }
807
808 void OneWayDescriptorMatch::classify( const Mat& image, vector<KeyPoint>& points )
809 {
810     IplImage _image = image;
811     for( size_t i = 0; i < points.size(); i++ )
812     {
813         int descIdx = -1;
814         int poseIdx = -1;
815         float distance;
816         base->FindDescriptor(&_image, points[i].pt, descIdx, poseIdx, distance);
817         points[i].class_id = collection.getKeyPoint(descIdx).class_id;
818     }
819 }
820
821 void OneWayDescriptorMatch::clear ()
822 {
823     GenericDescriptorMatch::clear();
824     base->clear ();
825 }
826
827 /****************************************************************************************\
828 *                                  FernDescriptorMatch                                   *
829 \****************************************************************************************/
830 FernDescriptorMatch::Params::Params( int _nclasses, int _patchSize, int _signatureSize,
831                                      int _nstructs, int _structSize, int _nviews, int _compressionMethod,
832                                      const PatchGenerator& _patchGenerator ) :
833     nclasses(_nclasses), patchSize(_patchSize), signatureSize(_signatureSize),
834     nstructs(_nstructs), structSize(_structSize), nviews(_nviews),
835     compressionMethod(_compressionMethod), patchGenerator(_patchGenerator)
836 {}
837
838 FernDescriptorMatch::Params::Params( const string& _filename )
839 {
840     filename = _filename;
841 }
842
843 FernDescriptorMatch::FernDescriptorMatch()
844 {}
845
846 FernDescriptorMatch::FernDescriptorMatch( const Params& _params )
847 {
848     params = _params;
849 }
850
851 FernDescriptorMatch::~FernDescriptorMatch()
852 {}
853
854 void FernDescriptorMatch::initialize( const Params& _params )
855 {
856     classifier.release();
857     params = _params;
858     if( !params.filename.empty() )
859     {
860         classifier = new FernClassifier;
861         FileStorage fs(params.filename, FileStorage::READ);
862         if( fs.isOpened() )
863             classifier->read( fs.getFirstTopLevelNode() );
864     }
865 }
866
867 void FernDescriptorMatch::add( const Mat& image, vector<KeyPoint>& keypoints )
868 {
869     if( params.filename.empty() )
870         collection.add( image, keypoints );
871 }
872
873 void FernDescriptorMatch::trainFernClassifier()
874 {
875     if( classifier.empty() )
876     {
877         assert( params.filename.empty() );
878
879         vector<vector<Point2f> > points;
880         for( size_t imgIdx = 0; imgIdx < collection.images.size(); imgIdx++ )
881             KeyPoint::convert( collection.points[imgIdx], points[imgIdx] );
882
883         classifier = new FernClassifier( points, collection.images, vector<vector<int> >(), 0, // each points is a class
884                                          params.patchSize, params.signatureSize, params.nstructs, params.structSize,
885                                          params.nviews, params.compressionMethod, params.patchGenerator );
886     }
887 }
888
889 void FernDescriptorMatch::calcBestProbAndMatchIdx( const Mat& image, const Point2f& pt,
890                                                    float& bestProb, int& bestMatchIdx, vector<float>& signature )
891 {
892     (*classifier)( image, pt, signature);
893
894     bestProb = -FLT_MAX;
895     bestMatchIdx = -1;
896     for( int ci = 0; ci < classifier->getClassCount(); ci++ )
897     {
898         if( signature[ci] > bestProb )
899         {
900             bestProb = signature[ci];
901             bestMatchIdx = ci;
902         }
903     }
904 }
905
906 void FernDescriptorMatch::match( const Mat& image, vector<KeyPoint>& keypoints, vector<int>& indices )
907 {
908     trainFernClassifier();
909
910     indices.resize( keypoints.size() );
911     vector<float> signature( (size_t)classifier->getClassCount() );
912
913     for( size_t pi = 0; pi < keypoints.size(); pi++ )
914     {
915         //calcBestProbAndMatchIdx( image, keypoints[pi].pt, bestProb, indices[pi], signature );
916         //TODO: use octave and image pyramid
917         indices[pi] = (*classifier)(image, keypoints[pi].pt, signature);
918     }
919 }
920
921 void FernDescriptorMatch::match( const Mat& image, vector<KeyPoint>& keypoints, vector<DMatch>& matches )
922 {
923     trainFernClassifier();
924
925     matches.resize( keypoints.size() );
926     vector<float> signature( (size_t)classifier->getClassCount() );
927
928     for( int pi = 0; pi < (int)keypoints.size(); pi++ )
929     {
930         matches[pi].indexQuery = pi;
931         calcBestProbAndMatchIdx( image, keypoints[pi].pt, matches[pi].distance, matches[pi].indexTrain, signature );
932         //matching[pi].distance is log of probability so we need to transform it
933         matches[pi].distance = -matches[pi].distance;
934     }
935 }
936
937 void FernDescriptorMatch::match( const Mat& image, vector<KeyPoint>& keypoints, vector<vector<DMatch> >& matches, float threshold )
938 {
939     trainFernClassifier();
940
941     matches.resize( keypoints.size() );
942     vector<float> signature( (size_t)classifier->getClassCount() );
943
944     for( int pi = 0; pi < (int)keypoints.size(); pi++ )
945     {
946         (*classifier)( image, keypoints[pi].pt, signature);
947
948         DMatch match;
949         match.indexQuery = pi;
950
951         for( int ci = 0; ci < classifier->getClassCount(); ci++ )
952         {
953             if( -signature[ci] < threshold )
954             {
955                 match.distance = -signature[ci];
956                 match.indexTrain = ci;
957                 matches[pi].push_back( match );
958             }
959         }
960     }
961 }
962
963 void FernDescriptorMatch::classify( const Mat& image, vector<KeyPoint>& keypoints )
964 {
965     trainFernClassifier();
966
967     vector<float> signature( (size_t)classifier->getClassCount() );
968     for( size_t pi = 0; pi < keypoints.size(); pi++ )
969     {
970         float bestProb = 0;
971         int bestMatchIdx = -1;
972         calcBestProbAndMatchIdx( image, keypoints[pi].pt, bestProb, bestMatchIdx, signature );
973         keypoints[pi].class_id = collection.getKeyPoint(bestMatchIdx).class_id;
974     }
975 }
976
977 void FernDescriptorMatch::read( const FileNode &fn )
978 {
979     params.nclasses = fn["nclasses"];
980     params.patchSize = fn["patchSize"];
981     params.signatureSize = fn["signatureSize"];
982     params.nstructs = fn["nstructs"];
983     params.structSize = fn["structSize"];
984     params.nviews = fn["nviews"];
985     params.compressionMethod = fn["compressionMethod"];
986
987     //classifier->read(fn);
988 }
989
990 void FernDescriptorMatch::write( FileStorage& fs ) const
991 {
992     fs << "nclasses" << params.nclasses;
993     fs << "patchSize" << params.patchSize;
994     fs << "signatureSize" << params.signatureSize;
995     fs << "nstructs" << params.nstructs;
996     fs << "structSize" << params.structSize;
997     fs << "nviews" << params.nviews;
998     fs << "compressionMethod" << params.compressionMethod;
999
1000 //    classifier->write(fs);
1001 }
1002
1003 void FernDescriptorMatch::clear ()
1004 {
1005     GenericDescriptorMatch::clear();
1006     classifier.release();
1007 }
1008
1009 /****************************************************************************************\
1010 *                                  VectorDescriptorMatch                                 *
1011 \****************************************************************************************/
1012 void VectorDescriptorMatch::add( const Mat& image, vector<KeyPoint>& keypoints )
1013 {
1014     Mat descriptors;
1015     extractor->compute( image, keypoints, descriptors );
1016     matcher->add( descriptors );
1017
1018     collection.add( Mat(), keypoints );
1019 };
1020
1021 void VectorDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points, vector<int>& keypointIndices )
1022 {
1023     Mat descriptors;
1024     extractor->compute( image, points, descriptors );
1025
1026     matcher->match( descriptors, keypointIndices );
1027 };
1028
1029 void VectorDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points, vector<DMatch>& matches )
1030 {
1031     Mat descriptors;
1032     extractor->compute( image, points, descriptors );
1033
1034     matcher->match( descriptors, matches );
1035 }
1036
1037 void VectorDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points,
1038                                    vector<vector<DMatch> >& matches, float threshold )
1039 {
1040     Mat descriptors;
1041     extractor->compute( image, points, descriptors );
1042
1043     matcher->match( descriptors, matches, threshold );
1044 }
1045
1046 void VectorDescriptorMatch::clear()
1047 {
1048     GenericDescriptorMatch::clear();
1049     matcher->clear();
1050 }
1051
1052 void VectorDescriptorMatch::read( const FileNode& fn )
1053 {
1054     GenericDescriptorMatch::read(fn);
1055     extractor->read (fn);
1056 }
1057
1058 void VectorDescriptorMatch::write (FileStorage& fs) const
1059 {
1060     GenericDescriptorMatch::write(fs);
1061     extractor->write (fs);
1062 }
1063
1064 }