added possibility of matcher filtering to sample
authorMaria Dimashova <no@email>
Fri, 29 Oct 2010 13:13:46 +0000 (13:13 +0000)
committerMaria Dimashova <no@email>
Fri, 29 Oct 2010 13:13:46 +0000 (13:13 +0000)
samples/cpp/descriptor_extractor_matcher.cpp

index 102a81b..afbfbca 100644 (file)
@@ -11,6 +11,58 @@ using namespace std;
 #define DRAW_RICH_KEYPOINTS_MODE     0
 #define DRAW_OUTLIERS_MODE           0
 
+const string winName = "correspondences";
+
+enum { NONE_FILTER = 0, CROSS_CHECK_FILTER = 1 };
+
+int getMatcherFilterType( const string& str )
+{
+    if( str == "NoneFilter" )
+        return NONE_FILTER;
+    if( str == "CrossCheckFilter" )
+        return CROSS_CHECK_FILTER;
+    CV_Assert(0);
+    return -1;
+}
+
+void simpleMatching( Ptr<DescriptorMatcher>& descriptorMatcher,
+                     const Mat& descriptors1, const Mat& descriptors2,
+                     vector<DMatch>& matches12 )
+{
+    vector<DMatch> matches;
+    descriptorMatcher->match( descriptors1, descriptors2, matches12 );
+}
+
+void crossCheckMatching( Ptr<DescriptorMatcher>& descriptorMatcher,
+                         const Mat& descriptors1, const Mat& descriptors2,
+                         vector<DMatch>& filteredMatches12, int knn=1 )
+{
+    filteredMatches12.clear();
+    vector<vector<DMatch> > matches12, matches21;
+    descriptorMatcher->knnMatch( descriptors1, descriptors2, matches12, knn );
+    descriptorMatcher->knnMatch( descriptors2, descriptors1, matches21, knn );
+    for( size_t m = 0; m < matches12.size(); m++ )
+    {
+        bool findCrossCheck = false;
+        for( size_t fk = 0; fk < matches12[m].size(); fk++ )
+        {
+            DMatch forward = matches12[m][fk];
+
+            for( size_t bk = 0; bk < matches21[forward.trainIdx].size(); bk++ )
+            {
+                DMatch backward = matches21[forward.trainIdx][bk];
+                if( backward.trainIdx == forward.queryIdx )
+                {
+                    filteredMatches12.push_back(forward);
+                    findCrossCheck = true;
+                    break;
+                }
+            }
+            if( findCrossCheck ) break;
+        }
+    }
+}
+
 void warpPerspectiveRand( const Mat& src, Mat& dst, Mat& H, RNG& rng )
 {
     H.create(3, 3, CV_32FC1);
@@ -27,12 +79,10 @@ void warpPerspectiveRand( const Mat& src, Mat& dst, Mat& H, RNG& rng )
     warpPerspective( src, dst, H, src.size() );
 }
 
-const string winName = "correspondences";
-
 void doIteration( const Mat& img1, Mat& img2, bool isWarpPerspective,
                   vector<KeyPoint>& keypoints1, const Mat& descriptors1,
                   Ptr<FeatureDetector>& detector, Ptr<DescriptorExtractor>& descriptorExtractor,
-                  Ptr<DescriptorMatcher>& descriptorMatcher,
+                  Ptr<DescriptorMatcher>& descriptorMatcher, int matcherFilter, bool eval,
                   double ransacReprojThreshold, RNG& rng )
 {
     assert( !img1.empty() );
@@ -47,7 +97,7 @@ void doIteration( const Mat& img1, Mat& img2, bool isWarpPerspective,
     detector->detect( img2, keypoints2 );
     cout << keypoints2.size() << " points" << endl << ">" << endl;
 
-    if( !H12.empty() )
+    if( !H12.empty() && eval )
     {
         cout << "< Evaluate feature detector..." << endl;
         float repeatability;
@@ -64,11 +114,18 @@ void doIteration( const Mat& img1, Mat& img2, bool isWarpPerspective,
     cout << ">" << endl;
 
     cout << "< Matching descriptors..." << endl;
-    vector<DMatch> matches;
-    descriptorMatcher->match( descriptors1, descriptors2, matches, Mat() );
+    vector<DMatch> filteredMatches;
+    switch( matcherFilter )
+    {
+    case CROSS_CHECK_FILTER :
+        crossCheckMatching( descriptorMatcher, descriptors1, descriptors2, filteredMatches, 1 );
+        break;
+    default :
+        simpleMatching( descriptorMatcher, descriptors1, descriptors2, filteredMatches );
+    }
     cout << ">" << endl;
 
-    if( !H12.empty() )
+    if( !H12.empty() && eval )
     {
         cout << "< Evaluate descriptor match..." << endl;
         vector<Point2f> curve;
@@ -79,14 +136,17 @@ void doIteration( const Mat& img1, Mat& img2, bool isWarpPerspective,
         cout << ">" << endl;
     }
 
-    vector<int> trainIdxs( matches.size() );
-    for( size_t i = 0; i < matches.size(); i++ )
-        trainIdxs[i] = matches[i].trainIdx;
+    vector<int> queryIdxs( filteredMatches.size() ), trainIdxs( filteredMatches.size() );
+    for( size_t i = 0; i < filteredMatches.size(); i++ )
+    {
+        queryIdxs[i] = filteredMatches[i].queryIdx;
+        trainIdxs[i] = filteredMatches[i].trainIdx;
+    }
 
     if( !isWarpPerspective && ransacReprojThreshold >= 0 )
     {
         cout << "< Computing homography (RANSAC)..." << endl;
-        vector<Point2f> points1; KeyPoint::convert(keypoints1, points1);
+        vector<Point2f> points1; KeyPoint::convert(keypoints1, points1, queryIdxs);
         vector<Point2f> points2; KeyPoint::convert(keypoints2, points2, trainIdxs);
         H12 = findHomography( Mat(points1), Mat(points2), CV_RANSAC, ransacReprojThreshold );
         cout << ">" << endl;
@@ -95,8 +155,8 @@ void doIteration( const Mat& img1, Mat& img2, bool isWarpPerspective,
     Mat drawImg;
     if( !H12.empty() ) // filter outliers
     {
-        vector<char> matchesMask( matches.size(), 0 );
-        vector<Point2f> points1; KeyPoint::convert(keypoints1, points1);
+        vector<char> matchesMask( filteredMatches.size(), 0 );
+        vector<Point2f> points1; KeyPoint::convert(keypoints1, points1, queryIdxs);
         vector<Point2f> points2; KeyPoint::convert(keypoints2, points2, trainIdxs);
         Mat points1t; perspectiveTransform(Mat(points1), points1t, H12);
         for( size_t i1 = 0; i1 < points1.size(); i1++ )
@@ -105,7 +165,7 @@ void doIteration( const Mat& img1, Mat& img2, bool isWarpPerspective,
                 matchesMask[i1] = 1;
         }
         // draw inliers
-        drawMatches( img1, keypoints1, img2, keypoints2, matches, drawImg, CV_RGB(0, 255, 0), CV_RGB(0, 0, 255), matchesMask
+        drawMatches( img1, keypoints1, img2, keypoints2, filteredMatches, drawImg, CV_RGB(0, 255, 0), CV_RGB(0, 0, 255), matchesMask
 #if DRAW_RICH_KEYPOINTS_MODE
                      , DrawMatchesFlags::DRAW_RICH_KEYPOINTS
 #endif
@@ -115,48 +175,56 @@ void doIteration( const Mat& img1, Mat& img2, bool isWarpPerspective,
         // draw outliers
         for( size_t i1 = 0; i1 < matchesMask.size(); i1++ )
             matchesMask[i1] = !matchesMask[i1];
-        drawMatches( img1, keypoints1, img2, keypoints2, matches, drawImg, CV_RGB(0, 0, 255), CV_RGB(255, 0, 0), matchesMask,
+        drawMatches( img1, keypoints1, img2, keypoints2, filteredMatches, drawImg, CV_RGB(0, 0, 255), CV_RGB(255, 0, 0), matchesMask,
                      DrawMatchesFlags::DRAW_OVER_OUTIMG | DrawMatchesFlags::NOT_DRAW_SINGLE_POINTS );
 #endif
     }
     else
-        drawMatches( img1, keypoints1, img2, keypoints2, matches, drawImg );
+        drawMatches( img1, keypoints1, img2, keypoints2, filteredMatches, drawImg );
 
     imshow( winName, drawImg );
 }
 
 int main(int argc, char** argv)
 {
-    if( argc != 4 && argc != 6 )
+    if( argc != 7 && argc != 8 )
     {
         cout << "Format:" << endl;
         cout << "case1: second image is obtained from the first (given) image using random generated homography matrix" << endl;
-        cout << argv[0] << " [detectorType] [descriptorType] [image1]" << endl;
+        cout << argv[0] << " [detectorType] [descriptorType] [matcherType] [matcherFilterType] [image] [evaluate(0 or 1)]" << endl;
+        cout << "Example:" << endl;
+        cout << "./descriptor_extractor_matcher SURF SURF FlannBased NoneFilter cola.jpg 0" << endl;
+        cout << endl;
         cout << "case2: both images are given. If ransacReprojThreshold>=0 then homography matrix are calculated" << endl;
-        cout << argv[0] << " [detectorType] [descriptorType] [image1] [image2] [ransacReprojThreshold]" << endl;
+        cout << argv[0] << " [detectorType] [descriptorType] [matcherType] [matcherFilterType] [image1] [image2] [ransacReprojThreshold]" << endl;
         cout << endl << "Mathes are filtered using homography matrix in case1 and case2 (if ransacReprojThreshold>=0)" << endl;
+        cout << "Example:" << endl;
+        cout << "./descriptor_extractor_matcher SURF SURF BruteForce CrossCheckFilter cola1.jpg cola2.jpg 3" << endl;
+
         return -1;
     }
-    bool isWarpPerspective = argc == 4;
+    bool isWarpPerspective = argc == 7;
     double ransacReprojThreshold = -1;
     if( !isWarpPerspective )
-        ransacReprojThreshold = atof(argv[5]);
+        ransacReprojThreshold = atof(argv[7]);
 
     cout << "< Creating detector, descriptor extractor and descriptor matcher ..." << endl;
     Ptr<FeatureDetector> detector = createFeatureDetector( argv[1] );
     Ptr<DescriptorExtractor> descriptorExtractor = createDescriptorExtractor( argv[2] );
-    Ptr<DescriptorMatcher> descriptorMatcher = createDescriptorMatcher( "BruteForce" );
+    Ptr<DescriptorMatcher> descriptorMatcher = createDescriptorMatcher( argv[3] );
+    int mactherFilterType = getMatcherFilterType( argv[4] );
+    bool eval = !isWarpPerspective ? false : (atoi(argv[6]) == 0 ? false : true);
     cout << ">" << endl;
     if( detector.empty() || descriptorExtractor.empty() || descriptorMatcher.empty()  )
     {
         cout << "Can not create detector or descriptor exstractor or descriptor matcher of given types" << endl;
         return -1;
-       }
+    }
                
     cout << "< Reading the images..." << endl;
-    Mat img1 = imread( argv[3] ), img2;
+    Mat img1 = imread( argv[5] ), img2;
     if( !isWarpPerspective )
-        img2 = imread( argv[4] );
+        img2 = imread( argv[6] );
     cout << ">" << endl;
     if( img1.empty() || (!isWarpPerspective && img2.empty()) )
     {
@@ -177,11 +245,11 @@ int main(int argc, char** argv)
     namedWindow(winName, 1);
     RNG rng = theRNG();
     doIteration( img1, img2, isWarpPerspective, keypoints1, descriptors1,
-                 detector, descriptorExtractor, descriptorMatcher,
+                 detector, descriptorExtractor, descriptorMatcher, mactherFilterType, eval,
                  ransacReprojThreshold, rng );
     for(;;)
     {
-        char c = (char)cvWaitKey(0);
+        char c = (char)waitKey(0);
         if( c == '\x1b' ) // esc
         {
             cout << "Exiting ..." << endl;
@@ -190,7 +258,7 @@ int main(int argc, char** argv)
         else if( isWarpPerspective )
         {
             doIteration( img1, img2, isWarpPerspective, keypoints1, descriptors1,
-                         detector, descriptorExtractor, descriptorMatcher,
+                         detector, descriptorExtractor, descriptorMatcher, mactherFilterType, eval,
                          ransacReprojThreshold, rng );
         }
     }