a7b01c8706521511648d39c5827aa2311c668e36
[platform/upstream/opencv.git] / samples / c / find_obj_calonder.cpp
1 #include "opencv2/highgui/highgui.hpp"
2 #include "opencv2/core/core.hpp"
3 #include "opencv2/imgproc/imgproc.hpp"
4 #include "opencv2/features2d/features2d.hpp"
5 #include "opencv2/nonfree/nonfree.hpp"
6 #include "opencv2/legacy/legacy.hpp"
7
8 #include <iostream>
9 #include <fstream>
10
11 using namespace std;
12 using namespace cv;
13
14 static void help()
15 {
16     cout << "This program shows the use of the Calonder point descriptor classifier"
17             "SURF is used to detect interest points, Calonder is used to describe/match these points\n"
18             "Format:" << endl <<
19             "   classifier_file(to write) test_image file_with_train_images_filenames(txt)" <<
20             "   or" << endl <<
21             "   classifier_file(to read) test_image" << "\n" << endl <<
22             "Using OpenCV version " << CV_VERSION << "\n" << endl;
23
24     return;
25 }
26
27
28 /*
29  * Generates random perspective transform of image
30  */
31 static void warpPerspectiveRand( const Mat& src, Mat& dst, Mat& H, RNG& rng )
32 {
33     H.create(3, 3, CV_32FC1);
34     H.at<float>(0,0) = rng.uniform( 0.8f, 1.2f);
35     H.at<float>(0,1) = rng.uniform(-0.1f, 0.1f);
36     H.at<float>(0,2) = rng.uniform(-0.1f, 0.1f)*src.cols;
37     H.at<float>(1,0) = rng.uniform(-0.1f, 0.1f);
38     H.at<float>(1,1) = rng.uniform( 0.8f, 1.2f);
39     H.at<float>(1,2) = rng.uniform(-0.1f, 0.1f)*src.rows;
40     H.at<float>(2,0) = rng.uniform( -1e-4f, 1e-4f);
41     H.at<float>(2,1) = rng.uniform( -1e-4f, 1e-4f);
42     H.at<float>(2,2) = rng.uniform( 0.8f, 1.2f);
43
44     warpPerspective( src, dst, H, src.size() );
45 }
46
47 /*
48  * Trains Calonder classifier and writes trained classifier in file:
49  *      imgFilename - name of .txt file which contains list of full filenames of train images,
50  *      classifierFilename - name of binary file in which classifier will be written.
51  *
52  * To train Calonder classifier RTreeClassifier class need to be used.
53  */
54 static void trainCalonderClassifier( const string& classifierFilename, const string& imgFilename )
55 {
56     // Reads train images
57     ifstream is( imgFilename.c_str(), ifstream::in );
58     vector<Mat> trainImgs;
59     while( !is.eof() )
60     {
61         string str;
62         getline( is, str );
63         if (str.empty()) break;
64         Mat img = imread( str, IMREAD_GRAYSCALE );
65         if( !img.empty() )
66             trainImgs.push_back( img );
67     }
68     if( trainImgs.empty() )
69     {
70         cout << "All train images can not be read." << endl;
71         exit(-1);
72     }
73     cout << trainImgs.size() << " train images were read." << endl;
74
75     // Extracts keypoints from train images
76     SurfFeatureDetector detector;
77     vector<BaseKeypoint> trainPoints;
78     vector<IplImage> iplTrainImgs(trainImgs.size());
79     for( size_t imgIdx = 0; imgIdx < trainImgs.size(); imgIdx++ )
80     {
81         iplTrainImgs[imgIdx] = trainImgs[imgIdx];
82         vector<KeyPoint> kps; detector.detect( trainImgs[imgIdx], kps );
83
84         for( size_t pointIdx = 0; pointIdx < kps.size(); pointIdx++ )
85         {
86             Point2f p = kps[pointIdx].pt;
87             trainPoints.push_back( BaseKeypoint(cvRound(p.x), cvRound(p.y), &iplTrainImgs[imgIdx]) );
88         }
89     }
90
91     // Trains Calonder classifier on extracted points
92     RTreeClassifier classifier;
93     classifier.train( trainPoints, theRNG(), 48, 9, 100 );
94     // Writes classifier
95     classifier.write( classifierFilename.c_str() );
96 }
97
98 /*
99  * Test Calonder classifier to match keypoints on given image:
100  *      classifierFilename - name of file from which classifier will be read,
101  *      imgFilename - test image filename.
102  *
103  * To calculate keypoint descriptors you may use RTreeClassifier class (as to train),
104  * but it is convenient to use CalonderDescriptorExtractor class which is wrapper of
105  * RTreeClassifier.
106  */
107 static void testCalonderClassifier( const string& classifierFilename, const string& imgFilename )
108 {
109     Mat img1 = imread( imgFilename, IMREAD_GRAYSCALE ), img2, H12;
110     if( img1.empty() )
111     {
112         cout << "Test image can not be read." << endl;
113         exit(-1);
114     }
115     warpPerspectiveRand( img1, img2, H12, theRNG() );
116
117     // Exstract keypoints from test images
118     SurfFeatureDetector detector;
119     vector<KeyPoint> keypoints1; detector.detect( img1, keypoints1 );
120     vector<KeyPoint> keypoints2; detector.detect( img2, keypoints2 );
121
122     // Compute descriptors
123     CalonderDescriptorExtractor<float> de( classifierFilename );
124     Mat descriptors1;  de.compute( img1, keypoints1, descriptors1 );
125     Mat descriptors2;  de.compute( img2, keypoints2, descriptors2 );
126
127     // Match descriptors
128     BFMatcher matcher(de.defaultNorm());
129     vector<DMatch> matches;
130     matcher.match( descriptors1, descriptors2, matches );
131
132     // Prepare inlier mask
133     vector<char> matchesMask( matches.size(), 0 );
134     vector<Point2f> points1; KeyPoint::convert( keypoints1, points1 );
135     vector<Point2f> points2; KeyPoint::convert( keypoints2, points2 );
136     Mat points1t; perspectiveTransform(Mat(points1), points1t, H12);
137     for( size_t mi = 0; mi < matches.size(); mi++ )
138     {
139         if( norm(points2[matches[mi].trainIdx] - points1t.at<Point2f>((int)mi,0)) < 4 ) // inlier
140             matchesMask[mi] = 1;
141     }
142
143     // Draw
144     Mat drawImg;
145     drawMatches( img1, keypoints1, img2, keypoints2, matches, drawImg, CV_RGB(0, 255, 0), CV_RGB(0, 0, 255), matchesMask );
146     string winName = "Matches";
147     namedWindow( winName, WINDOW_AUTOSIZE );
148     imshow( winName, drawImg );
149     waitKey();
150 }
151
152 int main( int argc, char **argv )
153 {
154     if( argc != 4 && argc != 3 )
155     {
156         help();
157         return -1;
158     }
159
160     if( argc == 4 )
161         trainCalonderClassifier( argv[1], argv[3] );
162
163     testCalonderClassifier( argv[1], argv[2] );
164
165     return 0;
166 }