152b4006454648674208c8e7118e9418d0ed94a9
[platform/upstream/opencv.git] / samples / cpp / matching_to_many_images.cpp
1 #include "opencv2/highgui/highgui.hpp"
2 #include "opencv2/features2d/features2d.hpp"
3 #include "opencv2/contrib/contrib.hpp"
4
5 #include <iostream>
6 #include <fstream>
7
8 using namespace cv;
9 using namespace std;
10
11 const string defaultDetectorType = "SURF";
12 const string defaultDescriptorType = "SURF";
13 const string defaultMatcherType = "FlannBased";
14 const string defaultQueryImageName = "../../opencv/samples/cpp/matching_to_many_images/query.png";
15 const string defaultFileWithTrainImages = "../../opencv/samples/cpp/matching_to_many_images/train/trainImages.txt";
16 const string defaultDirToSaveResImages = "../../opencv/samples/cpp/matching_to_many_images/results";
17
18 static void printPrompt( const string& applName )
19 {
20     cout << "/*\n"
21          << " * This is a sample on matching descriptors detected on one image to descriptors detected in image set.\n"
22          << " * So we have one query image and several train images. For each keypoint descriptor of query image\n"
23          << " * the one nearest train descriptor is found the entire collection of train images. To visualize the result\n"
24          << " * of matching we save images, each of which combines query and train image with matches between them (if they exist).\n"
25          << " * Match is drawn as line between corresponding points. Count of all matches is equel to count of\n"
26          << " * query keypoints, so we have the same count of lines in all set of result images (but not for each result\n"
27          << " * (train) image).\n"
28          << " */\n" << endl;
29
30     cout << endl << "Format:\n" << endl;
31     cout << "./" << applName << " [detectorType] [descriptorType] [matcherType] [queryImage] [fileWithTrainImages] [dirToSaveResImages]" << endl;
32     cout << endl;
33
34     cout << "\nExample:" << endl
35          << "./" << applName << " " << defaultDetectorType << " " << defaultDescriptorType << " " << defaultMatcherType << " "
36          << defaultQueryImageName << " " << defaultFileWithTrainImages << " " << defaultDirToSaveResImages << endl;
37 }
38
39 static void maskMatchesByTrainImgIdx( const vector<DMatch>& matches, int trainImgIdx, vector<char>& mask )
40 {
41     mask.resize( matches.size() );
42     fill( mask.begin(), mask.end(), 0 );
43     for( size_t i = 0; i < matches.size(); i++ )
44     {
45         if( matches[i].imgIdx == trainImgIdx )
46             mask[i] = 1;
47     }
48 }
49
50 static void readTrainFilenames( const string& filename, string& dirName, vector<string>& trainFilenames )
51 {
52     trainFilenames.clear();
53
54     ifstream file( filename.c_str() );
55     if ( !file.is_open() )
56         return;
57
58     size_t pos = filename.rfind('\\');
59     char dlmtr = '\\';
60     if (pos == string::npos)
61     {
62         pos = filename.rfind('/');
63         dlmtr = '/';
64     }
65     dirName = pos == string::npos ? "" : filename.substr(0, pos) + dlmtr;
66
67     while( !file.eof() )
68     {
69         string str; getline( file, str );
70         if( str.empty() ) break;
71         trainFilenames.push_back(str);
72     }
73     file.close();
74 }
75
76 static bool createDetectorDescriptorMatcher( const string& detectorType, const string& descriptorType, const string& matcherType,
77                                       Ptr<FeatureDetector>& featureDetector,
78                                       Ptr<DescriptorExtractor>& descriptorExtractor,
79                                       Ptr<DescriptorMatcher>& descriptorMatcher )
80 {
81     cout << "< Creating feature detector, descriptor extractor and descriptor matcher ..." << endl;
82     featureDetector = FeatureDetector::create( detectorType );
83     descriptorExtractor = DescriptorExtractor::create( descriptorType );
84     descriptorMatcher = DescriptorMatcher::create( matcherType );
85     cout << ">" << endl;
86
87     bool isCreated = featureDetector && descriptorExtractor && descriptorMatcher;
88     if( !isCreated )
89         cout << "Can not create feature detector or descriptor extractor or descriptor matcher of given types." << endl << ">" << endl;
90
91     return isCreated;
92 }
93
94 static bool readImages( const string& queryImageName, const string& trainFilename,
95                  Mat& queryImage, vector <Mat>& trainImages, vector<string>& trainImageNames )
96 {
97     cout << "< Reading the images..." << endl;
98     queryImage = imread( queryImageName, IMREAD_GRAYSCALE);
99     if( queryImage.empty() )
100     {
101         cout << "Query image can not be read." << endl << ">" << endl;
102         return false;
103     }
104     string trainDirName;
105     readTrainFilenames( trainFilename, trainDirName, trainImageNames );
106     if( trainImageNames.empty() )
107     {
108         cout << "Train image filenames can not be read." << endl << ">" << endl;
109         return false;
110     }
111     int readImageCount = 0;
112     for( size_t i = 0; i < trainImageNames.size(); i++ )
113     {
114         string filename = trainDirName + trainImageNames[i];
115         Mat img = imread( filename, IMREAD_GRAYSCALE );
116         if( img.empty() )
117             cout << "Train image " << filename << " can not be read." << endl;
118         else
119             readImageCount++;
120         trainImages.push_back( img );
121     }
122     if( !readImageCount )
123     {
124         cout << "All train images can not be read." << endl << ">" << endl;
125         return false;
126     }
127     else
128         cout << readImageCount << " train images were read." << endl;
129     cout << ">" << endl;
130
131     return true;
132 }
133
134 static void detectKeypoints( const Mat& queryImage, vector<KeyPoint>& queryKeypoints,
135                       const vector<Mat>& trainImages, vector<vector<KeyPoint> >& trainKeypoints,
136                       Ptr<FeatureDetector>& featureDetector )
137 {
138     cout << endl << "< Extracting keypoints from images..." << endl;
139     featureDetector->detect( queryImage, queryKeypoints );
140     featureDetector->detect( trainImages, trainKeypoints );
141     cout << ">" << endl;
142 }
143
144 static void computeDescriptors( const Mat& queryImage, vector<KeyPoint>& queryKeypoints, Mat& queryDescriptors,
145                          const vector<Mat>& trainImages, vector<vector<KeyPoint> >& trainKeypoints, vector<Mat>& trainDescriptors,
146                          Ptr<DescriptorExtractor>& descriptorExtractor )
147 {
148     cout << "< Computing descriptors for keypoints..." << endl;
149     descriptorExtractor->compute( queryImage, queryKeypoints, queryDescriptors );
150     descriptorExtractor->compute( trainImages, trainKeypoints, trainDescriptors );
151
152     int totalTrainDesc = 0;
153     for( vector<Mat>::const_iterator tdIter = trainDescriptors.begin(); tdIter != trainDescriptors.end(); tdIter++ )
154         totalTrainDesc += tdIter->rows;
155
156     cout << "Query descriptors count: " << queryDescriptors.rows << "; Total train descriptors count: " << totalTrainDesc << endl;
157     cout << ">" << endl;
158 }
159
160 static void matchDescriptors( const Mat& queryDescriptors, const vector<Mat>& trainDescriptors,
161                        vector<DMatch>& matches, Ptr<DescriptorMatcher>& descriptorMatcher )
162 {
163     cout << "< Set train descriptors collection in the matcher and match query descriptors to them..." << endl;
164     TickMeter tm;
165
166     tm.start();
167     descriptorMatcher->add( trainDescriptors );
168     descriptorMatcher->train();
169     tm.stop();
170     double buildTime = tm.getTimeMilli();
171
172     tm.start();
173     descriptorMatcher->match( queryDescriptors, matches );
174     tm.stop();
175     double matchTime = tm.getTimeMilli();
176
177     CV_Assert( queryDescriptors.rows == (int)matches.size() || matches.empty() );
178
179     cout << "Number of matches: " << matches.size() << endl;
180     cout << "Build time: " << buildTime << " ms; Match time: " << matchTime << " ms" << endl;
181     cout << ">" << endl;
182 }
183
184 static void saveResultImages( const Mat& queryImage, const vector<KeyPoint>& queryKeypoints,
185                        const vector<Mat>& trainImages, const vector<vector<KeyPoint> >& trainKeypoints,
186                        const vector<DMatch>& matches, const vector<string>& trainImagesNames, const string& resultDir )
187 {
188     cout << "< Save results..." << endl;
189     Mat drawImg;
190     vector<char> mask;
191     for( size_t i = 0; i < trainImages.size(); i++ )
192     {
193         if( !trainImages[i].empty() )
194         {
195             maskMatchesByTrainImgIdx( matches, (int)i, mask );
196             drawMatches( queryImage, queryKeypoints, trainImages[i], trainKeypoints[i],
197                          matches, drawImg, Scalar(255, 0, 0), Scalar(0, 255, 255), mask );
198             string filename = resultDir + "/res_" + trainImagesNames[i];
199             if( !imwrite( filename, drawImg ) )
200                 cout << "Image " << filename << " can not be saved (may be because directory " << resultDir << " does not exist)." << endl;
201         }
202     }
203     cout << ">" << endl;
204 }
205
206 int main(int argc, char** argv)
207 {
208     string detectorType = defaultDetectorType;
209     string descriptorType = defaultDescriptorType;
210     string matcherType = defaultMatcherType;
211     string queryImageName = defaultQueryImageName;
212     string fileWithTrainImages = defaultFileWithTrainImages;
213     string dirToSaveResImages = defaultDirToSaveResImages;
214
215     if( argc != 7 && argc != 1 )
216     {
217         printPrompt( argv[0] );
218         return -1;
219     }
220
221     if( argc != 1 )
222     {
223         detectorType = argv[1]; descriptorType = argv[2]; matcherType = argv[3];
224         queryImageName = argv[4]; fileWithTrainImages = argv[5];
225         dirToSaveResImages = argv[6];
226     }
227
228     Ptr<FeatureDetector> featureDetector;
229     Ptr<DescriptorExtractor> descriptorExtractor;
230     Ptr<DescriptorMatcher> descriptorMatcher;
231     if( !createDetectorDescriptorMatcher( detectorType, descriptorType, matcherType, featureDetector, descriptorExtractor, descriptorMatcher ) )
232     {
233         printPrompt( argv[0] );
234         return -1;
235     }
236
237     Mat queryImage;
238     vector<Mat> trainImages;
239     vector<string> trainImagesNames;
240     if( !readImages( queryImageName, fileWithTrainImages, queryImage, trainImages, trainImagesNames ) )
241     {
242         printPrompt( argv[0] );
243         return -1;
244     }
245
246     vector<KeyPoint> queryKeypoints;
247     vector<vector<KeyPoint> > trainKeypoints;
248     detectKeypoints( queryImage, queryKeypoints, trainImages, trainKeypoints, featureDetector );
249
250     Mat queryDescriptors;
251     vector<Mat> trainDescriptors;
252     computeDescriptors( queryImage, queryKeypoints, queryDescriptors,
253                         trainImages, trainKeypoints, trainDescriptors,
254                         descriptorExtractor );
255
256     vector<DMatch> matches;
257     matchDescriptors( queryDescriptors, trainDescriptors, matches, descriptorMatcher );
258
259     saveResultImages( queryImage, queryKeypoints, trainImages, trainKeypoints,
260                       matches, trainImagesNames, dirToSaveResImages );
261     return 0;
262 }