some samples updated according to new CommandLineParser class
[profile/ivi/opencv.git] / samples / c / find_obj_calonder.cpp
1 #include "opencv2/highgui/highgui.hpp"
2 #include "opencv2/core/core.hpp"
3 #include "opencv2/imgproc/imgproc.hpp"
4 #include "opencv2/features2d/features2d.hpp"
5
6 #include <iostream>
7 #include <fstream>
8
9 using namespace std;
10 using namespace cv;
11
12 void help()
13 {
14     cout << "This program shows the use of the Calonder point descriptor classifier"
15                 "SURF is used to detect interest points, Calonder is used to describe/match these points\n"
16                 "Format:" << endl <<
17             "   classifier_file(to write) test_image file_with_train_images_filenames(txt)" <<
18             "   or" << endl <<
19             "   classifier_file(to read) test_image"
20                 "Using OpenCV version %s\n" << CV_VERSION << "\n"
21             << endl;
22 }
23 /*
24  * Generates random perspective transform of image
25  */
26 void warpPerspectiveRand( const Mat& src, Mat& dst, Mat& H, RNG& rng )
27 {
28     H.create(3, 3, CV_32FC1);
29     H.at<float>(0,0) = rng.uniform( 0.8f, 1.2f);
30     H.at<float>(0,1) = rng.uniform(-0.1f, 0.1f);
31     H.at<float>(0,2) = rng.uniform(-0.1f, 0.1f)*src.cols;
32     H.at<float>(1,0) = rng.uniform(-0.1f, 0.1f);
33     H.at<float>(1,1) = rng.uniform( 0.8f, 1.2f);
34     H.at<float>(1,2) = rng.uniform(-0.1f, 0.1f)*src.rows;
35     H.at<float>(2,0) = rng.uniform( -1e-4f, 1e-4f);
36     H.at<float>(2,1) = rng.uniform( -1e-4f, 1e-4f);
37     H.at<float>(2,2) = rng.uniform( 0.8f, 1.2f);
38
39     warpPerspective( src, dst, H, src.size() );
40 }
41
42 /*
43  * Trains Calonder classifier and writes trained classifier in file:
44  *      imgFilename - name of .txt file which contains list of full filenames of train images,
45  *      classifierFilename - name of binary file in which classifier will be written.
46  *
47  * To train Calonder classifier RTreeClassifier class need to be used.
48  */
49 void trainCalonderClassifier( const string& classifierFilename, const string& imgFilename )
50 {
51     // Reads train images
52     ifstream is( imgFilename.c_str(), ifstream::in );
53     vector<Mat> trainImgs;
54     while( !is.eof() )
55     {
56         string str;
57         getline( is, str );
58         if (str.empty()) break;
59         Mat img = imread( str, CV_LOAD_IMAGE_GRAYSCALE );
60         if( !img.empty() )
61             trainImgs.push_back( img );
62     }
63     if( trainImgs.empty() )
64     {
65         cout << "All train images can not be read." << endl;
66         exit(-1);
67     }
68     cout << trainImgs.size() << " train images were read." << endl;
69
70     // Extracts keypoints from train images
71     SurfFeatureDetector detector;
72     vector<BaseKeypoint> trainPoints;
73     vector<IplImage> iplTrainImgs(trainImgs.size());
74     for( size_t imgIdx = 0; imgIdx < trainImgs.size(); imgIdx++ )
75     {
76         iplTrainImgs[imgIdx] = trainImgs[imgIdx];
77         vector<KeyPoint> kps; detector.detect( trainImgs[imgIdx], kps );
78
79         for( size_t pointIdx = 0; pointIdx < kps.size(); pointIdx++ )
80         {
81             Point2f p = kps[pointIdx].pt;
82             trainPoints.push_back( BaseKeypoint(cvRound(p.x), cvRound(p.y), &iplTrainImgs[imgIdx]) );
83         }
84     }
85
86     // Trains Calonder classifier on extracted points
87     RTreeClassifier classifier;
88     classifier.train( trainPoints, theRNG(), 48, 9, 100 );
89     // Writes classifier
90     classifier.write( classifierFilename.c_str() );
91 }
92
93 /*
94  * Test Calonder classifier to match keypoints on given image:
95  *      classifierFilename - name of file from which classifier will be read,
96  *      imgFilename - test image filename.
97  *
98  * To calculate keypoint descriptors you may use RTreeClassifier class (as to train),
99  * but it is convenient to use CalonderDescriptorExtractor class which is wrapper of
100  * RTreeClassifier.
101  */
102 void testCalonderClassifier( const string& classifierFilename, const string& imgFilename )
103 {
104     Mat img1 = imread( imgFilename, CV_LOAD_IMAGE_GRAYSCALE ), img2, H12;
105     if( img1.empty() )
106     {
107         cout << "Test image can not be read." << endl;
108         exit(-1);
109     }
110     warpPerspectiveRand( img1, img2, H12, theRNG() );
111
112     // Exstract keypoints from test images
113     SurfFeatureDetector detector;
114     vector<KeyPoint> keypoints1; detector.detect( img1, keypoints1 );
115     vector<KeyPoint> keypoints2; detector.detect( img2, keypoints2 );
116
117     // Compute descriptors
118     CalonderDescriptorExtractor<float> de( classifierFilename );
119     Mat descriptors1;  de.compute( img1, keypoints1, descriptors1 );
120     Mat descriptors2;  de.compute( img2, keypoints2, descriptors2 );
121
122     // Match descriptors
123     BruteForceMatcher<L1<float> > matcher;
124     vector<DMatch> matches;
125     matcher.match( descriptors1, descriptors2, matches );
126
127     // Prepare inlier mask
128     vector<char> matchesMask( matches.size(), 0 );
129     vector<Point2f> points1; KeyPoint::convert( keypoints1, points1 );
130     vector<Point2f> points2; KeyPoint::convert( keypoints2, points2 );
131     Mat points1t; perspectiveTransform(Mat(points1), points1t, H12);
132     for( size_t mi = 0; mi < matches.size(); mi++ )
133     {
134         if( norm(points2[matches[mi].trainIdx] - points1t.at<Point2f>(mi,0)) < 4 ) // inlier
135             matchesMask[mi] = 1;
136     }
137
138     // Draw
139     Mat drawImg;
140     drawMatches( img1, keypoints1, img2, keypoints2, matches, drawImg, CV_RGB(0, 255, 0), CV_RGB(0, 0, 255), matchesMask );
141     string winName = "Matches";
142     namedWindow( winName, WINDOW_AUTOSIZE );
143     imshow( winName, drawImg );
144     waitKey();
145 }
146
147 int main( int argc, char **argv )
148 {
149     if( argc != 4 && argc != 3 )
150     {
151         help();
152         return -1;
153     }
154
155     if( argc == 4 )
156         trainCalonderClassifier( argv[1], argv[3] );
157
158     testCalonderClassifier( argv[1], argv[2] );
159
160     return 0;
161 }