1 /*M///////////////////////////////////////////////////////////////////////////////////////
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
10 // Intel License Agreement
11 // For Open Source Computer Vision Library
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
19 // * Redistribution's of source code must retain the above copyright notice,
20 // this list of conditions and the following disclaimer.
22 // * Redistribution's in binary form must reproduce the above copyright notice,
23 // this list of conditions and the following disclaimer in the documentation
24 // and/or other materials provided with the distribution.
26 // * The name of Intel Corporation may not be used to endorse or promote products
27 // derived from this software without specific prior written permission.
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
42 #include "precomp.hpp"
45 #include <Eigen/Array>
// Number of fractional bits in the fixed-point coordinates used by the drawing helpers in this
// file; it is passed as the last argument of their circle()/line() calls.
static const int draw_shift_bits = 4;
// Scale factor (2^draw_shift_bits) applied to pixel coordinates and radii before drawing.
static const int draw_multiplier = 1 << draw_shift_bits;
58 Mat windowedMatchingMask( const vector<KeyPoint>& keypoints1, const vector<KeyPoint>& keypoints2,
59 float maxDeltaX, float maxDeltaY )
61 if( keypoints1.empty() || keypoints2.empty() )
64 Mat mask( keypoints1.size(), keypoints2.size(), CV_8UC1 );
65 for( size_t i = 0; i < keypoints1.size(); i++ )
67 for( size_t j = 0; j < keypoints2.size(); j++ )
69 Point2f diff = keypoints2[j].pt - keypoints1[i].pt;
70 mask.at<uchar>(i, j) = std::abs(diff.x) < maxDeltaX && std::abs(diff.y) < maxDeltaY;
// Draws one keypoint `p` on `img` using `color`.
// The centre and all radii are multiplied by draw_multiplier and draw_shift_bits is handed to
// circle()/line(), so positions are drawn with fractional-pixel precision.
// With DrawMatchesFlags::DRAW_RICH_KEYPOINTS a circle of radius p.size/2 and an orientation
// segment derived from p.angle (degrees, converted to radians) are drawn; otherwise a fixed
// circle of radius 3.
// NOTE(review): the structural lines between the statements (braces, the test guarding the
// orientation segment, and what looks like a preprocessor-disabled region around the "R=1"
// branch) are not visible in this excerpt — confirm against the full file.
80 static inline void _drawKeypoint( Mat& img, const KeyPoint& p, const Scalar& color, int flags )
// Keypoint centre in fixed-point coordinates (implicit float -> int conversion).
82 Point center( p.pt.x * draw_multiplier, p.pt.y * draw_multiplier );
84 if( flags & DrawMatchesFlags::DRAW_RICH_KEYPOINTS )
86 int radius = p.size/2 * draw_multiplier; // KeyPoint::size is a diameter
88 // draw the circles around keypoints with the keypoints size
89 circle( img, center, radius, color, 1, CV_AA, draw_shift_bits );
91 // draw orientation of the keypoint, if it is applicable
94 float srcAngleRad = p.angle*CV_PI/180;
95 Point orient(cos(srcAngleRad)*radius, sin(srcAngleRad)*radius);
96 line( img, center, center+orient, color, 1, CV_AA, draw_shift_bits );
101 // draw center with R=1
102 int radius = 1 * draw_multiplier;
103 circle( img, center, radius, color, 1, CV_AA, draw_shift_bits );
109 // draw center with R=3
110 int radius = 3 * draw_multiplier;
111 circle( img, center, radius, color, 1, CV_AA, draw_shift_bits );
115 void drawKeypoints( const Mat& image, const vector<KeyPoint>& keypoints, Mat& outImg,
116 const Scalar& _color, int flags )
118 if( !(flags & DrawMatchesFlags::DRAW_OVER_OUTIMG) )
119 cvtColor( image, outImg, CV_GRAY2BGR );
122 bool isRandColor = _color == Scalar::all(-1);
124 for( vector<KeyPoint>::const_iterator i = keypoints.begin(), ie = keypoints.end(); i != ie; ++i )
126 Scalar color = isRandColor ? Scalar(rng(256), rng(256), rng(256)) : _color;
127 _drawKeypoint( outImg, *i, color, flags );
131 static void _prepareImgAndDrawKeypoints( const Mat& img1, const vector<KeyPoint>& keypoints1,
132 const Mat& img2, const vector<KeyPoint>& keypoints2,
133 Mat& outImg, Mat& outImg1, Mat& outImg2,
134 const Scalar& singlePointColor, int flags )
136 Size size( img1.cols + img2.cols, MAX(img1.rows, img2.rows) );
137 if( flags & DrawMatchesFlags::DRAW_OVER_OUTIMG )
139 if( size.width > outImg.cols || size.height > outImg.rows )
140 CV_Error( CV_StsBadSize, "outImg has size less than need to draw img1 and img2 together" );
141 outImg1 = outImg( Rect(0, 0, img1.cols, img1.rows) );
142 outImg2 = outImg( Rect(img1.cols, 0, img2.cols, img2.rows) );
146 outImg.create( size, CV_MAKETYPE(img1.depth(), 3) );
147 outImg1 = outImg( Rect(0, 0, img1.cols, img1.rows) );
148 outImg2 = outImg( Rect(img1.cols, 0, img2.cols, img2.rows) );
149 cvtColor( img1, outImg1, CV_GRAY2RGB );
150 cvtColor( img2, outImg2, CV_GRAY2RGB );
154 if( !(flags & DrawMatchesFlags::NOT_DRAW_SINGLE_POINTS) )
156 Mat outImg1 = outImg( Rect(0, 0, img1.cols, img1.rows) );
157 drawKeypoints( outImg1, keypoints1, outImg1, singlePointColor, flags + DrawMatchesFlags::DRAW_OVER_OUTIMG );
159 Mat outImg2 = outImg( Rect(img1.cols, 0, img2.cols, img2.rows) );
160 drawKeypoints( outImg2, keypoints2, outImg2, singlePointColor, flags + DrawMatchesFlags::DRAW_OVER_OUTIMG );
164 static inline void _drawMatch( Mat& outImg, Mat& outImg1, Mat& outImg2 ,
165 const KeyPoint& kp1, const KeyPoint& kp2, const Scalar& matchColor, int flags )
168 bool isRandMatchColor = matchColor == Scalar::all(-1);
169 Scalar color = isRandMatchColor ? Scalar( rng(256), rng(256), rng(256) ) : matchColor;
171 _drawKeypoint( outImg1, kp1, color, flags );
172 _drawKeypoint( outImg2, kp2, color, flags );
174 Point2f pt1 = kp1.pt,
176 dpt2 = Point2f( std::min(pt2.x+outImg1.cols, float(outImg.cols-1)), pt2.y );
178 line( outImg, Point(pt1.x*draw_multiplier, pt1.y*draw_multiplier), Point(dpt2.x*draw_multiplier, dpt2.y*draw_multiplier),
179 color, 1, CV_AA, draw_shift_bits );
182 void drawMatches( const Mat& img1, const vector<KeyPoint>& keypoints1,
183 const Mat& img2,const vector<KeyPoint>& keypoints2,
184 const vector<int>& matches1to2, Mat& outImg,
185 const Scalar& matchColor, const Scalar& singlePointColor,
186 const vector<char>& matchesMask, int flags )
188 if( matches1to2.size() != keypoints1.size() )
189 CV_Error( CV_StsBadSize, "matches1to2 must have the same size as keypoints1" );
190 if( !matchesMask.empty() && matchesMask.size() != matches1to2.size() )
191 CV_Error( CV_StsBadSize, "matchesMask must have the same size as matches1to2" );
193 Mat outImg1, outImg2;
194 _prepareImgAndDrawKeypoints( img1, keypoints1, img2, keypoints2,
195 outImg, outImg1, outImg2, singlePointColor, flags );
198 for( size_t i1 = 0; i1 < keypoints1.size(); i1++ )
200 int i2 = matches1to2[i1];
201 if( (matchesMask.empty() || matchesMask[i1] ) && i2 >= 0 )
203 const KeyPoint &kp1 = keypoints1[i1], &kp2 = keypoints2[i2];
204 _drawMatch( outImg, outImg1, outImg2, kp1, kp2, matchColor, flags );
209 void drawMatches( const Mat& img1, const vector<KeyPoint>& keypoints1,
210 const Mat& img2, const vector<KeyPoint>& keypoints2,
211 const vector<DMatch>& matches1to2, Mat& outImg,
212 const Scalar& matchColor, const Scalar& singlePointColor,
213 const vector<char>& matchesMask, int flags )
215 if( !matchesMask.empty() && matchesMask.size() != matches1to2.size() )
216 CV_Error( CV_StsBadSize, "matchesMask must have the same size as matches1to2" );
218 Mat outImg1, outImg2;
219 _prepareImgAndDrawKeypoints( img1, keypoints1, img2, keypoints2,
220 outImg, outImg1, outImg2, singlePointColor, flags );
223 for( size_t m = 0; m < matches1to2.size(); m++ )
225 int i1 = matches1to2[m].indexQuery;
226 int i2 = matches1to2[m].indexTrain;
227 if( matchesMask.empty() || matchesMask[m] )
229 const KeyPoint &kp1 = keypoints1[i1], &kp2 = keypoints2[i2];
230 _drawMatch( outImg, outImg1, outImg2, kp1, kp2, matchColor, flags );
235 void drawMatches( const Mat& img1, const vector<KeyPoint>& keypoints1,
236 const Mat& img2, const vector<KeyPoint>& keypoints2,
237 const vector<vector<DMatch> >& matches1to2, Mat& outImg,
238 const Scalar& matchColor, const Scalar& singlePointColor,
239 const vector<vector<char> >& matchesMask, int flags )
241 if( !matchesMask.empty() && matchesMask.size() != matches1to2.size() )
242 CV_Error( CV_StsBadSize, "matchesMask must have the same size as matches1to2" );
244 Mat outImg1, outImg2;
245 _prepareImgAndDrawKeypoints( img1, keypoints1, img2, keypoints2,
246 outImg, outImg1, outImg2, singlePointColor, flags );
249 for( size_t i = 0; i < matches1to2.size(); i++ )
251 for( size_t j = 0; j < matches1to2[i].size(); j++ )
253 int i1 = matches1to2[i][j].indexQuery;
254 int i2 = matches1to2[i][j].indexTrain;
255 if( matchesMask.empty() || matchesMask[i][j] )
257 const KeyPoint &kp1 = keypoints1[i1], &kp2 = keypoints2[i2];
258 _drawMatch( outImg, outImg1, outImg2, kp1, kp2, matchColor, flags );
264 /****************************************************************************************\
265 * DescriptorExtractor *
266 \****************************************************************************************/
268 * DescriptorExtractor
272 RoiPredicate(float _minX, float _minY, float _maxX, float _maxY)
273 : minX(_minX), minY(_minY), maxX(_maxX), maxY(_maxY)
276 bool operator()( const KeyPoint& keyPt) const
278 Point2f pt = keyPt.pt;
279 return (pt.x < minX) || (pt.x >= maxX) || (pt.y < minY) || (pt.y >= maxY);
282 float minX, minY, maxX, maxY;
285 void DescriptorExtractor::removeBorderKeypoints( vector<KeyPoint>& keypoints,
286 Size imageSize, int borderPixels )
288 keypoints.erase( remove_if(keypoints.begin(), keypoints.end(),
289 RoiPredicate((float)borderPixels, (float)borderPixels,
290 (float)(imageSize.width - borderPixels),
291 (float)(imageSize.height - borderPixels))),
295 /****************************************************************************************\
296 * SiftDescriptorExtractor *
297 \****************************************************************************************/
// Constructor: forwards all parameters, in the same order, to the wrapped `sift` member.
// NOTE(review): the constructor body is not visible in this excerpt.
298 SiftDescriptorExtractor::SiftDescriptorExtractor( double magnification, bool isNormalize, bool recalculateAngles,
299 int nOctaves, int nOctaveLayers, int firstOctave, int angleMode )
300 : sift( magnification, isNormalize, recalculateAngles, nOctaves, nOctaveLayers, firstOctave, angleMode )
303 void SiftDescriptorExtractor::compute( const Mat& image,
304 vector<KeyPoint>& keypoints,
305 Mat& descriptors) const
307 bool useProvidedKeypoints = true;
308 sift(image, Mat(), keypoints, descriptors, useProvidedKeypoints);
311 void SiftDescriptorExtractor::read (const FileNode &fn)
313 double magnification = fn["magnification"];
314 bool isNormalize = (int)fn["isNormalize"] != 0;
315 bool recalculateAngles = (int)fn["recalculateAngles"] != 0;
316 int nOctaves = fn["nOctaves"];
317 int nOctaveLayers = fn["nOctaveLayers"];
318 int firstOctave = fn["firstOctave"];
319 int angleMode = fn["angleMode"];
321 sift = SIFT( magnification, isNormalize, recalculateAngles, nOctaves, nOctaveLayers, firstOctave, angleMode );
324 void SiftDescriptorExtractor::write (FileStorage &fs) const
326 // fs << "algorithm" << getAlgorithmName ();
328 SIFT::CommonParams commParams = sift.getCommonParams ();
329 SIFT::DescriptorParams descriptorParams = sift.getDescriptorParams ();
330 fs << "magnification" << descriptorParams.magnification;
331 fs << "isNormalize" << descriptorParams.isNormalize;
332 fs << "recalculateAngles" << descriptorParams.recalculateAngles;
333 fs << "nOctaves" << commParams.nOctaves;
334 fs << "nOctaveLayers" << commParams.nOctaveLayers;
335 fs << "firstOctave" << commParams.firstOctave;
336 fs << "angleMode" << commParams.angleMode;
339 /****************************************************************************************\
340 * SurfDescriptorExtractor *
341 \****************************************************************************************/
// Constructor: builds the wrapped `surf` member from the given parameters, passing 0.0 as its
// first argument.
// NOTE(review): the constructor body is not visible in this excerpt.
342 SurfDescriptorExtractor::SurfDescriptorExtractor( int nOctaves,
343 int nOctaveLayers, bool extended )
344 : surf( 0.0, nOctaves, nOctaveLayers, extended )
347 void SurfDescriptorExtractor::compute( const Mat& image,
348 vector<KeyPoint>& keypoints,
349 Mat& descriptors) const
351 // Compute descriptors for given keypoints
352 vector<float> _descriptors;
354 bool useProvidedKeypoints = true;
355 surf(image, mask, keypoints, _descriptors, useProvidedKeypoints);
357 descriptors.create((int)keypoints.size(), (int)surf.descriptorSize(), CV_32FC1);
358 assert( (int)_descriptors.size() == descriptors.rows * descriptors.cols );
359 std::copy(_descriptors.begin(), _descriptors.end(), descriptors.begin<float>());
362 void SurfDescriptorExtractor::read( const FileNode &fn )
364 int nOctaves = fn["nOctaves"];
365 int nOctaveLayers = fn["nOctaveLayers"];
366 bool extended = (int)fn["extended"] != 0;
368 surf = SURF( 0.0, nOctaves, nOctaveLayers, extended );
371 void SurfDescriptorExtractor::write( FileStorage &fs ) const
373 // fs << "algorithm" << getAlgorithmName ();
375 fs << "nOctaves" << surf.nOctaves;
376 fs << "nOctaveLayers" << surf.nOctaveLayers;
377 fs << "extended" << surf.extended;
380 /****************************************************************************************\
381 * Factory functions for descriptor extractor and matcher creating *
382 \****************************************************************************************/
384 Ptr<DescriptorExtractor> createDescriptorExtractor( const string& descriptorExtractorType )
386 DescriptorExtractor* de = 0;
387 if( !descriptorExtractorType.compare( "SIFT" ) )
389 de = new SiftDescriptorExtractor/*( double magnification=SIFT::DescriptorParams::GET_DEFAULT_MAGNIFICATION(),
390 bool isNormalize=true, bool recalculateAngles=true,
391 int nOctaves=SIFT::CommonParams::DEFAULT_NOCTAVES,
392 int nOctaveLayers=SIFT::CommonParams::DEFAULT_NOCTAVE_LAYERS,
393 int firstOctave=SIFT::CommonParams::DEFAULT_FIRST_OCTAVE,
394 int angleMode=SIFT::CommonParams::FIRST_ANGLE )*/;
396 else if( !descriptorExtractorType.compare( "SURF" ) )
398 de = new SurfDescriptorExtractor/*( int nOctaves=4, int nOctaveLayers=2, bool extended=false )*/;
402 //CV_Error( CV_StsBadArg, "unsupported descriptor extractor type");
407 Ptr<DescriptorMatcher> createDescriptorMatcher( const string& descriptorMatcherType )
409 DescriptorMatcher* dm = 0;
410 if( !descriptorMatcherType.compare( "BruteForce" ) )
412 dm = new BruteForceMatcher<L2<float> >();
414 else if ( !descriptorMatcherType.compare( "BruteForce-L1" ) )
416 dm = new BruteForceMatcher<L1<float> >();
420 //CV_Error( CV_StsBadArg, "unsupported descriptor matcher type");
426 /****************************************************************************************\
427 * BruteForceMatcher L2 specialization *
428 \****************************************************************************************/
// L2<float> specialisation of the brute-force matcher: for every row of `query` it appends to
// `matches` the best-scoring row of `train`, honouring `mask` when one is given.
// Two implementations are selected by the preprocessor: an OpenCV-only one when HAVE_EIGEN2 is
// not defined and an Eigen-based one otherwise.
// NOTE(review): several lines of this function (declarations, the #else/#endif directives, the
// mask tests and the step that combines `norms` with `distances`) are not visible in this
// excerpt; the comments below describe only the statements that are shown.
430 void BruteForceMatcher<L2<float> >::matchImpl( const Mat& query, const Mat& mask, vector<DMatch>& matches ) const
// Debug-only shape checks: mask is query.rows x train.rows, descriptors have equal length.
432 assert( mask.empty() || (mask.rows == query.rows && mask.cols == train.rows) );
433 assert( query.cols == train.cols || query.empty() || train.empty() );
// At most one match per query row is appended.
436 matches.reserve( query.rows );
437 #if (!defined HAVE_EIGEN2)
// Per-row sum of squared train elements (reduce along dim 1 with reduction type 0).
439 cv::reduce( train.mul( train ), norms, 1, 0);
441 Mat desc_2t = train.t();
442 for( int i=0;i<query.rows;i++ )
// -2 * <query_i, train_j> for every train row j.
444 Mat distances = (-2)*query.row(i)*desc_2t;
// -1 marks "no match found".
447 match.indexTrain = -1;
// Minimum over all train rows ...
452 minMaxLoc ( distances, &minVal, 0, &minLoc );
// ... or only over the train rows enabled in row i of the mask.
456 minMaxLoc ( distances, &minVal, 0, &minLoc, 0, mask.row( i ) );
458 match.indexTrain = minLoc.x;
460 if( match.indexTrain != -1 )
462 match.indexQuery = i;
463 double queryNorm = norm( query.row(i) );
// Adds the squared query norm back before taking the square root.
464 match.distance = sqrt( minVal + queryNorm*queryNorm );
465 matches.push_back( match );
// ---- Eigen-based statements ----
470 Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic> desc1t;
471 Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic> desc2;
472 cv2eigen( query.t(), desc1t);
473 cv2eigen( train, desc2 );
// Half of the squared norm of every train row.
475 Eigen::Matrix<float, Eigen::Dynamic, 1> norms = desc2.rowwise().squaredNorm() / 2;
// Loop without per-element mask test: the best train row is the one with the largest coefficient.
479 for( int i=0;i<query.rows;i++ )
481 Eigen::Matrix<float, Eigen::Dynamic, 1> distances = desc2*desc1t.col(i);
484 match.indexQuery = i;
485 match.distance = sqrt( (-2)*distances.maxCoeff( &match.indexTrain ) + desc1t.col(i).squaredNorm() );
486 matches.push_back( match );
// Loop with per-element mask test via possibleMatch().
491 for( int i=0;i<query.rows;i++ )
493 Eigen::Matrix<float, Eigen::Dynamic, 1> distances = desc2*desc1t.col(i);
496 float maxCoeff = -std::numeric_limits<float>::max();
498 match.indexTrain = -1;
499 for( int j=0;j<train.rows;j++ )
501 if( possibleMatch( mask, i, j ) && distances( j, 0 ) > maxCoeff )
503 maxCoeff = distances( j, 0 );
504 match.indexTrain = j;
// Queries for which the mask allowed no train row produce no match.
508 if( match.indexTrain != -1 )
510 match.indexQuery = i;
511 match.distance = sqrt( (-2)*maxCoeff + desc1t.col(i).squaredNorm() );
512 matches.push_back( match );
519 /****************************************************************************************\
520 * GenericDescriptorMatch *
521 \****************************************************************************************/
525 void KeyPointCollection::add( const Mat& _image, const vector<KeyPoint>& _points )
527 // update m_start_indices
528 if( startIndices.empty() )
529 startIndices.push_back(0);
531 startIndices.push_back((int)(*startIndices.rbegin() + points.rbegin()->size()));
533 // add image and keypoints
534 images.push_back(_image);
535 points.push_back(_points);
538 KeyPoint KeyPointCollection::getKeyPoint( int index ) const
541 for(; i < startIndices.size() && startIndices[i] <= index; i++);
543 assert(i < startIndices.size() && (size_t)index - startIndices[i] < points[i].size());
545 return points[i][index - startIndices[i]];
548 size_t KeyPointCollection::calcKeypointCount() const
550 if( startIndices.empty() )
552 return *startIndices.rbegin() + points.rbegin()->size();
// Empties the collection.
// NOTE(review): only the reset of startIndices is visible in this excerpt; presumably the
// images and points containers are cleared as well — confirm against the full file.
555 void KeyPointCollection::clear()
559 startIndices.clear();
563 * GenericDescriptorMatch
// Base-class overload producing one DMatch per keypoint; all parameters are unnamed here.
// NOTE(review): the body is not visible in this excerpt — presumably a no-op default that
// derived classes override; confirm.
566 void GenericDescriptorMatch::match( const Mat&, vector<KeyPoint>&, vector<DMatch>& )
// Base-class overload producing a list of DMatch per keypoint with a float threshold; all
// parameters are unnamed here.
// NOTE(review): the body is not visible in this excerpt — presumably a no-op default that
// derived classes override; confirm.
570 void GenericDescriptorMatch::match( const Mat&, vector<KeyPoint>&, vector<vector<DMatch> >&, float )
574 void GenericDescriptorMatch::add( KeyPointCollection& collection )
576 for( size_t i = 0; i < collection.images.size(); i++ )
577 add( collection.images[i], collection.points[i] );
580 void GenericDescriptorMatch::classify( const Mat& image, vector<cv::KeyPoint>& points )
582 vector<int> keypointIndices;
583 match( image, points, keypointIndices );
585 // remap keypoint indices to descriptors
586 for( size_t i = 0; i < keypointIndices.size(); i++ )
587 points[i].class_id = collection.getKeyPoint(keypointIndices[i]).class_id;
// Resets the matcher state held by the base class.
// NOTE(review): the body is not visible in this excerpt — presumably it clears `collection`;
// confirm.
590 void GenericDescriptorMatch::clear()
596 * Factory function for GenericDescriptorMatch creating
598 Ptr<GenericDescriptorMatch> createGenericDescriptorMatch( const string& genericDescritptorMatchType, const string ¶msFilename )
600 GenericDescriptorMatch *descriptorMatch = 0;
601 if( ! genericDescritptorMatchType.compare ("ONEWAY") )
603 descriptorMatch = new OneWayDescriptorMatch ();
605 else if( ! genericDescritptorMatchType.compare ("FERN") )
607 FernDescriptorMatch::Params params;
608 params.signatureSize = numeric_limits<int>::max();
609 descriptorMatch = new FernDescriptorMatch (params);
611 else if( ! genericDescritptorMatchType.compare ("CALONDER") )
613 //descriptorMatch = new CalonderDescriptorMatch ();
616 if( !paramsFilename.empty() && descriptorMatch != 0 )
618 FileStorage fs = FileStorage( paramsFilename, FileStorage::READ );
621 descriptorMatch->read( fs.root() );
626 return descriptorMatch;
629 /****************************************************************************************\
630 * OneWayDescriptorMatch *
631 \****************************************************************************************/
// Construction, destruction and (re)initialisation of the one-way descriptor matcher.
// NOTE(review): none of the bodies of these four definitions is visible in this excerpt, so
// their behaviour is not described here.

// Default constructor.
632 OneWayDescriptorMatch::OneWayDescriptorMatch()
// Constructor taking matcher parameters.
635 OneWayDescriptorMatch::OneWayDescriptorMatch( const Params& _params)
// Destructor.
640 OneWayDescriptorMatch::~OneWayDescriptorMatch()
// Takes parameters and an optional descriptor base object.
643 void OneWayDescriptorMatch::initialize( const Params& _params, OneWayDescriptorBase *_base)
// Trains the matcher on the keypoints of a single image: one descriptor per keypoint is
// allocated and initialised in `base`, and the keypoints are recorded in `collection`.
// NOTE(review): the line preceding the creation of `base` (presumably a test that it has not
// been created yet) and the preprocessor lines around ConvertDescriptorsArrayToTree() are not
// visible in this excerpt — confirm.
653 void OneWayDescriptorMatch::add( const Mat& image, vector<KeyPoint>& keypoints )
656 base = new OneWayDescriptorObject( params.patchSize, params.poseCount, params.pcaFilename,
657 params.trainPath, params.trainImagesList, params.minScale, params.maxScale, params.stepScale);
659 size_t trainFeatureCount = keypoints.size();
661 base->Allocate( (int)trainFeatureCount );
// IplImage header over the Mat data, required by the descriptor base API.
663 IplImage _image = image;
664 for( size_t i = 0; i < keypoints.size(); i++ )
665 base->InitializeDescriptor( (int)i, &_image, keypoints[i], "" );
// Only the keypoints are stored; the image slot receives an empty Mat.
667 collection.add( Mat(), keypoints );
670 base->ConvertDescriptorsArrayToTree();
// Trains the matcher on a whole collection: descriptors for all keypoints of all images are
// allocated at once and initialised image by image; the keypoints of every image are recorded
// in the member `collection`.
// NOTE(review): the declaration of `count`, the line preceding the creation of `base` and the
// preprocessor lines around ConvertDescriptorsArrayToTree() are not visible in this excerpt —
// confirm.
674 void OneWayDescriptorMatch::add( KeyPointCollection& keypoints )
677 base = new OneWayDescriptorObject( params.patchSize, params.poseCount, params.pcaFilename,
678 params.trainPath, params.trainImagesList, params.minScale, params.maxScale, params.stepScale);
680 size_t trainFeatureCount = keypoints.calcKeypointCount();
682 base->Allocate( (int)trainFeatureCount );
685 for( size_t i = 0; i < keypoints.points.size(); i++ )
687 for( size_t j = 0; j < keypoints.points[i].size(); j++ )
// IplImage header over the i-th image, rebuilt for every keypoint.
689 IplImage img = keypoints.images[i];
// `count` is the running global descriptor index.
690 base->InitializeDescriptor( count++, &img, keypoints.points[i][j], "" );
// Only the keypoints are stored; the image slot receives an empty Mat.
693 collection.add( Mat(), keypoints.points[i] );
697 base->ConvertDescriptorsArrayToTree();
701 void OneWayDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points, vector<int>& indices)
703 vector<DMatch> matchings( points.size() );
704 indices.resize(points.size());
706 match( image, points, matchings );
708 for( size_t i = 0; i < points.size(); i++ )
709 indices[i] = matchings[i].indexTrain;
// Finds, for every query keypoint, a training descriptor via base->FindDescriptor() and
// reports it as a DMatch (indexQuery = position of the keypoint in `points`).
// NOTE(review): the declarations of `match` and `poseIdx` and the statement storing the result
// into `matches` are not visible in this excerpt — confirm.
712 void OneWayDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points, vector<DMatch>& matches )
714 matches.resize( points.size() );
// IplImage header over the Mat data, required by the descriptor base API.
715 IplImage _image = image;
716 for( size_t i = 0; i < points.size(); i++ )
721 match.indexQuery = (int)i;
// -1 is the value indexTrain holds before FindDescriptor() is called.
722 match.indexTrain = -1;
723 base->FindDescriptor( &_image, points[i].pt, match.indexTrain, poseIdx, match.distance );
// Multi-match overload; the threshold parameter is unused (its name is commented out).
// The first statements delegate to the single-match overload and store its result as the only
// candidate of every query keypoint.
// NOTE(review): the statements from the printf onwards look like an alternative implementation
// that is disabled in the full file (block comment or preprocessor lines are not visible in
// this excerpt) — confirm before relying on them. If they are ever enabled, note that
// printf("%d") is given a size_t argument, which is a format/argument mismatch.
728 void OneWayDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points, vector<vector<DMatch> >& matches, float /*threshold*/ )
731 matches.resize( points.size() );
733 vector<DMatch> dmatches;
734 match( image, points, dmatches );
735 for( size_t i=0;i<matches.size();i++ )
737 matches[i].push_back( dmatches[i] );
741 printf("Start matching %d points\n", points.size());
742 //std::cout << "Start matching " << points.size() << "points\n";
743 assert(collection.images.size() == 1);
744 int n = collection.points[0].size();
746 printf("n = %d\n", n);
747 for( size_t i = 0; i < points.size(); i++ )
749 //printf("Matching %d\n", i);
753 match.indexQuery = i;
754 match.indexTrain = -1;
757 CvPoint pt = points[i].pt;
758 CvRect roi = cvRect(cvRound(pt.x - 24/4),
759 cvRound(pt.y - 24/4),
761 cvSetImageROI(&_image, roi);
763 std::vector<int> desc_idxs;
764 std::vector<int> pose_idxs;
765 std::vector<float> distances;
766 std::vector<float> _scales;
769 base->FindDescriptor(&_image, n, desc_idxs, pose_idxs, distances, _scales);
770 cvResetImageROI(&_image);
772 for( int j=0;j<n;j++ )
774 match.indexTrain = desc_idxs[j];
775 match.distance = distances[j];
776 matches[i].push_back( match );
779 //sort( matches[i].begin(), matches[i].end(), compareIndexTrain );
780 //for( int j=0;j<n;j++ )
782 //printf( "%d %f; ",matches[i][j].indexTrain, matches[i][j].distance);
788 //base->FindDescriptor( &_image, 100, points[i].pt, match.indexTrain, poseIdx, match.distance );
789 //matches[i].push_back( match );
// Creates a fresh descriptor base from the current parameters, passing empty strings for the
// three string arguments.
// NOTE(review): the statement that presumably loads the base from `fn` is not visible in this
// excerpt — confirm.
795 void OneWayDescriptorMatch::read( const FileNode &fn )
797 base = new OneWayDescriptorObject( params.patchSize, params.poseCount, string (), string (), string (),
798 params.minScale, params.maxScale, params.stepScale );
// Serialises the matcher to `fs`.
// NOTE(review): the body is not visible in this excerpt.
803 void OneWayDescriptorMatch::write( FileStorage& fs ) const
// Assigns to every query keypoint the class_id of the training keypoint selected by
// base->FindDescriptor().
// NOTE(review): the declarations of descIdx, poseIdx and distance are not visible in this
// excerpt; if descIdx can stay at a "not found" value, getKeyPoint() receives an invalid
// index — confirm.
808 void OneWayDescriptorMatch::classify( const Mat& image, vector<KeyPoint>& points )
// IplImage header over the Mat data, required by the descriptor base API.
810 IplImage _image = image;
811 for( size_t i = 0; i < points.size(); i++ )
816 base->FindDescriptor(&_image, points[i].pt, descIdx, poseIdx, distance);
817 points[i].class_id = collection.getKeyPoint(descIdx).class_id;
// Resets the matcher, starting with the state held by the base class.
// NOTE(review): any further statements (e.g. releasing `base`) are not visible in this excerpt.
821 void OneWayDescriptorMatch::clear ()
823 GenericDescriptorMatch::clear();
827 /****************************************************************************************\
828 * FernDescriptorMatch *
829 \****************************************************************************************/
// Constructor storing every argument in the member of the same name (without the leading
// underscore).
// NOTE(review): the constructor body is not visible in this excerpt.
830 FernDescriptorMatch::Params::Params( int _nclasses, int _patchSize, int _signatureSize,
831 int _nstructs, int _structSize, int _nviews, int _compressionMethod,
832 const PatchGenerator& _patchGenerator ) :
833 nclasses(_nclasses), patchSize(_patchSize), signatureSize(_signatureSize),
834 nstructs(_nstructs), structSize(_structSize), nviews(_nviews),
835 compressionMethod(_compressionMethod), patchGenerator(_patchGenerator)
// Constructor for parameters that refer to a file: only `filename` is set here; the other
// members are not initialised by any visible statement.
838 FernDescriptorMatch::Params::Params( const string& _filename )
840 filename = _filename;
// Constructors and destructor of the Fern-based matcher.
// NOTE(review): none of the three bodies is visible in this excerpt.

// Default constructor.
843 FernDescriptorMatch::FernDescriptorMatch()
// Constructor taking matcher parameters.
846 FernDescriptorMatch::FernDescriptorMatch( const Params& _params )
// Destructor.
851 FernDescriptorMatch::~FernDescriptorMatch()
// Drops any existing classifier and, when the parameters name a file, creates a new
// classifier and loads it from the first top-level node of that file.
// NOTE(review): the statement storing _params into `params` and any check that the file was
// opened are not visible in this excerpt — confirm.
854 void FernDescriptorMatch::initialize( const Params& _params )
856 classifier.release();
858 if( !params.filename.empty() )
860 classifier = new FernClassifier;
861 FileStorage fs(params.filename, FileStorage::READ);
863 classifier->read( fs.getFirstTopLevelNode() );
867 void FernDescriptorMatch::add( const Mat& image, vector<KeyPoint>& keypoints )
869 if( params.filename.empty() )
870 collection.add( image, keypoints );
873 void FernDescriptorMatch::trainFernClassifier()
875 if( classifier.empty() )
877 assert( params.filename.empty() );
879 vector<vector<Point2f> > points;
880 for( size_t imgIdx = 0; imgIdx < collection.images.size(); imgIdx++ )
881 KeyPoint::convert( collection.points[imgIdx], points[imgIdx] );
883 classifier = new FernClassifier( points, collection.images, vector<vector<int> >(), 0, // each points is a class
884 params.patchSize, params.signatureSize, params.nstructs, params.structSize,
885 params.nviews, params.compressionMethod, params.patchGenerator );
// Evaluates the classifier at `pt` (filling `signature`) and scans all classes for the one
// with the largest signature value, reported through bestProb.
// NOTE(review): the initialisation of bestProb / bestMatchIdx and the statement that records
// the winning class index are not visible in this excerpt — confirm.
889 void FernDescriptorMatch::calcBestProbAndMatchIdx( const Mat& image, const Point2f& pt,
890 float& bestProb, int& bestMatchIdx, vector<float>& signature )
892 (*classifier)( image, pt, signature);
896 for( int ci = 0; ci < classifier->getClassCount(); ci++ )
898 if( signature[ci] > bestProb )
900 bestProb = signature[ci];
906 void FernDescriptorMatch::match( const Mat& image, vector<KeyPoint>& keypoints, vector<int>& indices )
908 trainFernClassifier();
910 indices.resize( keypoints.size() );
911 vector<float> signature( (size_t)classifier->getClassCount() );
913 for( size_t pi = 0; pi < keypoints.size(); pi++ )
915 //calcBestProbAndMatchIdx( image, keypoints[pi].pt, bestProb, indices[pi], signature );
916 //TODO: use octave and image pyramid
917 indices[pi] = (*classifier)(image, keypoints[pi].pt, signature);
921 void FernDescriptorMatch::match( const Mat& image, vector<KeyPoint>& keypoints, vector<DMatch>& matches )
923 trainFernClassifier();
925 matches.resize( keypoints.size() );
926 vector<float> signature( (size_t)classifier->getClassCount() );
928 for( int pi = 0; pi < (int)keypoints.size(); pi++ )
930 matches[pi].indexQuery = pi;
931 calcBestProbAndMatchIdx( image, keypoints[pi].pt, matches[pi].distance, matches[pi].indexTrain, signature );
932 //matching[pi].distance is log of probability so we need to transform it
933 matches[pi].distance = -matches[pi].distance;
937 void FernDescriptorMatch::match( const Mat& image, vector<KeyPoint>& keypoints, vector<vector<DMatch> >& matches, float threshold )
939 trainFernClassifier();
941 matches.resize( keypoints.size() );
942 vector<float> signature( (size_t)classifier->getClassCount() );
944 for( int pi = 0; pi < (int)keypoints.size(); pi++ )
946 (*classifier)( image, keypoints[pi].pt, signature);
949 match.indexQuery = pi;
951 for( int ci = 0; ci < classifier->getClassCount(); ci++ )
953 if( -signature[ci] < threshold )
955 match.distance = -signature[ci];
956 match.indexTrain = ci;
957 matches[pi].push_back( match );
// Trains the classifier if necessary and assigns to every query keypoint the class_id of the
// training keypoint chosen by calcBestProbAndMatchIdx().
// NOTE(review): the declaration of bestProb is not visible in this excerpt. If
// calcBestProbAndMatchIdx() leaves bestMatchIdx at -1, getKeyPoint() receives an invalid
// index — confirm.
963 void FernDescriptorMatch::classify( const Mat& image, vector<KeyPoint>& keypoints )
965 trainFernClassifier();
967 vector<float> signature( (size_t)classifier->getClassCount() );
968 for( size_t pi = 0; pi < keypoints.size(); pi++ )
971 int bestMatchIdx = -1;
972 calcBestProbAndMatchIdx( image, keypoints[pi].pt, bestProb, bestMatchIdx, signature );
973 keypoints[pi].class_id = collection.getKeyPoint(bestMatchIdx).class_id;
977 void FernDescriptorMatch::read( const FileNode &fn )
979 params.nclasses = fn["nclasses"];
980 params.patchSize = fn["patchSize"];
981 params.signatureSize = fn["signatureSize"];
982 params.nstructs = fn["nstructs"];
983 params.structSize = fn["structSize"];
984 params.nviews = fn["nviews"];
985 params.compressionMethod = fn["compressionMethod"];
987 //classifier->read(fn);
990 void FernDescriptorMatch::write( FileStorage& fs ) const
992 fs << "nclasses" << params.nclasses;
993 fs << "patchSize" << params.patchSize;
994 fs << "signatureSize" << params.signatureSize;
995 fs << "nstructs" << params.nstructs;
996 fs << "structSize" << params.structSize;
997 fs << "nviews" << params.nviews;
998 fs << "compressionMethod" << params.compressionMethod;
1000 // classifier->write(fs);
// Resets the matcher: clears the state held by the base class and releases the classifier, so
// that trainFernClassifier() sees an empty classifier the next time it runs.
1003 void FernDescriptorMatch::clear ()
1005 GenericDescriptorMatch::clear();
1006 classifier.release();
1009 /****************************************************************************************\
1010 * VectorDescriptorMatch *
1011 \****************************************************************************************/
1012 void VectorDescriptorMatch::add( const Mat& image, vector<KeyPoint>& keypoints )
1015 extractor->compute( image, keypoints, descriptors );
1016 matcher->add( descriptors );
1018 collection.add( Mat(), keypoints );
1021 void VectorDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points, vector<int>& keypointIndices )
1024 extractor->compute( image, points, descriptors );
1026 matcher->match( descriptors, keypointIndices );
1029 void VectorDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points, vector<DMatch>& matches )
1032 extractor->compute( image, points, descriptors );
1034 matcher->match( descriptors, matches );
1037 void VectorDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points,
1038 vector<vector<DMatch> >& matches, float threshold )
1041 extractor->compute( image, points, descriptors );
1043 matcher->match( descriptors, matches, threshold );
// Resets the matcher, starting with the state held by the base class.
// NOTE(review): any further statements (e.g. clearing `matcher`) are not visible in this
// excerpt.
1046 void VectorDescriptorMatch::clear()
1048 GenericDescriptorMatch::clear();
// Reads the base-class state from `fn`, then lets the descriptor extractor read its own
// parameters from the same node.
1052 void VectorDescriptorMatch::read( const FileNode& fn )
1054 GenericDescriptorMatch::read(fn);
1055 extractor->read (fn);
// Writes the base-class state to `fs`, then lets the descriptor extractor write its own
// parameters to the same storage.
1058 void VectorDescriptorMatch::write (FileStorage& fs) const
1060 GenericDescriptorMatch::write(fs);
1061 extractor->write (fs);