651f439245d8521f67aece59cc94b0e4917ac2de
[platform/upstream/opencv.git] / samples / c / find_obj.cpp
1 /*
2  * A Demo to OpenCV Implementation of SURF
3  * Further Information Refer to "SURF: Speed-Up Robust Feature"
4  * Author: Liu Liu
5  * liuliu.1987+opencv@gmail.com
6  */
7 #include "opencv2/objdetect/objdetect.hpp"
8 #include "opencv2/features2d/features2d.hpp"
9 #include "opencv2/calib3d/calib3d.hpp"
10 #include "opencv2/nonfree/nonfree.hpp"
11 #include "opencv2/imgproc/imgproc_c.h"
12 #include "opencv2/highgui/highgui_c.h"
13 #include "opencv2/legacy/legacy.hpp"
14 #include "opencv2/legacy/compat.hpp"
15
16 #include <iostream>
17 #include <vector>
18 #include <stdio.h>
19
20 using namespace std;
21 static void help()
22 {
23     printf(
24         "This program demonstrated the use of the SURF Detector and Descriptor using\n"
25         "either FLANN (fast approx nearst neighbor classification) or brute force matching\n"
26         "on planar objects.\n"
27         "Usage:\n"
28         "./find_obj <object_filename> <scene_filename>, default is box.png  and box_in_scene.png\n\n");
29     return;
30 }
31
32 // define whether to use approximate nearest-neighbor search
33 #define USE_FLANN
34
35 #ifdef USE_FLANN
36 static void
37 flannFindPairs( const CvSeq*, const CvSeq* objectDescriptors,
38            const CvSeq*, const CvSeq* imageDescriptors, vector<int>& ptpairs )
39 {
40     int length = (int)(objectDescriptors->elem_size/sizeof(float));
41
42     cv::Mat m_object(objectDescriptors->total, length, CV_32F);
43     cv::Mat m_image(imageDescriptors->total, length, CV_32F);
44
45
46     // copy descriptors
47     CvSeqReader obj_reader;
48     float* obj_ptr = m_object.ptr<float>(0);
49     cvStartReadSeq( objectDescriptors, &obj_reader );
50     for(int i = 0; i < objectDescriptors->total; i++ )
51     {
52         const float* descriptor = (const float*)obj_reader.ptr;
53         CV_NEXT_SEQ_ELEM( obj_reader.seq->elem_size, obj_reader );
54         memcpy(obj_ptr, descriptor, length*sizeof(float));
55         obj_ptr += length;
56     }
57     CvSeqReader img_reader;
58     float* img_ptr = m_image.ptr<float>(0);
59     cvStartReadSeq( imageDescriptors, &img_reader );
60     for(int i = 0; i < imageDescriptors->total; i++ )
61     {
62         const float* descriptor = (const float*)img_reader.ptr;
63         CV_NEXT_SEQ_ELEM( img_reader.seq->elem_size, img_reader );
64         memcpy(img_ptr, descriptor, length*sizeof(float));
65         img_ptr += length;
66     }
67
68     // find nearest neighbors using FLANN
69     cv::Mat m_indices(objectDescriptors->total, 2, CV_32S);
70     cv::Mat m_dists(objectDescriptors->total, 2, CV_32F);
71     cv::flann::Index flann_index(m_image, cv::flann::KDTreeIndexParams(4));  // using 4 randomized kdtrees
72     flann_index.knnSearch(m_object, m_indices, m_dists, 2, cv::flann::SearchParams(64) ); // maximum number of leafs checked
73
74     int* indices_ptr = m_indices.ptr<int>(0);
75     float* dists_ptr = m_dists.ptr<float>(0);
76     for (int i=0;i<m_indices.rows;++i) {
77         if (dists_ptr[2*i]<0.6*dists_ptr[2*i+1]) {
78             ptpairs.push_back(i);
79             ptpairs.push_back(indices_ptr[2*i]);
80         }
81     }
82 }
83 #else
84
85 static double
86 compareSURFDescriptors( const float* d1, const float* d2, double best, int length )
87 {
88     double total_cost = 0;
89     assert( length % 4 == 0 );
90     for( int i = 0; i < length; i += 4 )
91     {
92         double t0 = d1[i  ] - d2[i  ];
93         double t1 = d1[i+1] - d2[i+1];
94         double t2 = d1[i+2] - d2[i+2];
95         double t3 = d1[i+3] - d2[i+3];
96         total_cost += t0*t0 + t1*t1 + t2*t2 + t3*t3;
97         if( total_cost > best )
98             break;
99     }
100     return total_cost;
101 }
102
103 static int
104 naiveNearestNeighbor( const float* vec, int laplacian,
105                       const CvSeq* model_keypoints,
106                       const CvSeq* model_descriptors )
107 {
108     int length = (int)(model_descriptors->elem_size/sizeof(float));
109     int i, neighbor = -1;
110     double d, dist1 = 1e6, dist2 = 1e6;
111     CvSeqReader reader, kreader;
112     cvStartReadSeq( model_keypoints, &kreader, 0 );
113     cvStartReadSeq( model_descriptors, &reader, 0 );
114
115     for( i = 0; i < model_descriptors->total; i++ )
116     {
117         const CvSURFPoint* kp = (const CvSURFPoint*)kreader.ptr;
118         const float* mvec = (const float*)reader.ptr;
119         CV_NEXT_SEQ_ELEM( kreader.seq->elem_size, kreader );
120         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
121         if( laplacian != kp->laplacian )
122             continue;
123         d = compareSURFDescriptors( vec, mvec, dist2, length );
124         if( d < dist1 )
125         {
126             dist2 = dist1;
127             dist1 = d;
128             neighbor = i;
129         }
130         else if ( d < dist2 )
131             dist2 = d;
132     }
133     if ( dist1 < 0.6*dist2 )
134         return neighbor;
135     return -1;
136 }
137
138 static void
139 findPairs( const CvSeq* objectKeypoints, const CvSeq* objectDescriptors,
140            const CvSeq* imageKeypoints, const CvSeq* imageDescriptors, vector<int>& ptpairs )
141 {
142     int i;
143     CvSeqReader reader, kreader;
144     cvStartReadSeq( objectKeypoints, &kreader );
145     cvStartReadSeq( objectDescriptors, &reader );
146     ptpairs.clear();
147
148     for( i = 0; i < objectDescriptors->total; i++ )
149     {
150         const CvSURFPoint* kp = (const CvSURFPoint*)kreader.ptr;
151         const float* descriptor = (const float*)reader.ptr;
152         CV_NEXT_SEQ_ELEM( kreader.seq->elem_size, kreader );
153         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
154         int nearest_neighbor = naiveNearestNeighbor( descriptor, kp->laplacian, imageKeypoints, imageDescriptors );
155         if( nearest_neighbor >= 0 )
156         {
157             ptpairs.push_back(i);
158             ptpairs.push_back(nearest_neighbor);
159         }
160     }
161 }
162 #endif
163
164 /* a rough implementation for object location */
165 static int
166 locatePlanarObject( const CvSeq* objectKeypoints, const CvSeq* objectDescriptors,
167                     const CvSeq* imageKeypoints, const CvSeq* imageDescriptors,
168                     const CvPoint src_corners[4], CvPoint dst_corners[4] )
169 {
170     double h[9];
171     CvMat _h = cvMat(3, 3, CV_64F, h);
172     vector<int> ptpairs;
173     vector<CvPoint2D32f> pt1, pt2;
174     CvMat _pt1, _pt2;
175     int i, n;
176
177 #ifdef USE_FLANN
178     flannFindPairs( objectKeypoints, objectDescriptors, imageKeypoints, imageDescriptors, ptpairs );
179 #else
180     findPairs( objectKeypoints, objectDescriptors, imageKeypoints, imageDescriptors, ptpairs );
181 #endif
182
183     n = (int)(ptpairs.size()/2);
184     if( n < 4 )
185         return 0;
186
187     pt1.resize(n);
188     pt2.resize(n);
189     for( i = 0; i < n; i++ )
190     {
191         pt1[i] = ((CvSURFPoint*)cvGetSeqElem(objectKeypoints,ptpairs[i*2]))->pt;
192         pt2[i] = ((CvSURFPoint*)cvGetSeqElem(imageKeypoints,ptpairs[i*2+1]))->pt;
193     }
194
195     _pt1 = cvMat(1, n, CV_32FC2, &pt1[0] );
196     _pt2 = cvMat(1, n, CV_32FC2, &pt2[0] );
197     if( !cvFindHomography( &_pt1, &_pt2, &_h, CV_RANSAC, 5 ))
198         return 0;
199
200     for( i = 0; i < 4; i++ )
201     {
202         double x = src_corners[i].x, y = src_corners[i].y;
203         double Z = 1./(h[6]*x + h[7]*y + h[8]);
204         double X = (h[0]*x + h[1]*y + h[2])*Z;
205         double Y = (h[3]*x + h[4]*y + h[5])*Z;
206         dst_corners[i] = cvPoint(cvRound(X), cvRound(Y));
207     }
208
209     return 1;
210 }
211
212 int main(int argc, char** argv)
213 {
214     const char* object_filename = argc == 3 ? argv[1] : "box.png";
215     const char* scene_filename = argc == 3 ? argv[2] : "box_in_scene.png";
216
217     cv::initModule_nonfree();
218     help();
219
220     IplImage* object = cvLoadImage( object_filename, CV_LOAD_IMAGE_GRAYSCALE );
221     IplImage* image = cvLoadImage( scene_filename, CV_LOAD_IMAGE_GRAYSCALE );
222     if( !object || !image )
223     {
224         fprintf( stderr, "Can not load %s and/or %s\n",
225             object_filename, scene_filename );
226         exit(-1);
227     }
228
229     CvMemStorage* storage = cvCreateMemStorage(0);
230
231     cvNamedWindow("Object", 1);
232     cvNamedWindow("Object Correspond", 1);
233
234     static cv::Scalar colors[] =
235     {
236         cv::Scalar(0,0,255),
237         cv::Scalar(0,128,255),
238         cv::Scalar(0,255,255),
239         cv::Scalar(0,255,0),
240         cv::Scalar(255,128,0),
241         cv::Scalar(255,255,0),
242         cv::Scalar(255,0,0),
243         cv::Scalar(255,0,255),
244         cv::Scalar(255,255,255)
245     };
246
247     IplImage* object_color = cvCreateImage(cvGetSize(object), 8, 3);
248     cvCvtColor( object, object_color, CV_GRAY2BGR );
249
250     CvSeq* objectKeypoints = 0, *objectDescriptors = 0;
251     CvSeq* imageKeypoints = 0, *imageDescriptors = 0;
252     int i;
253     CvSURFParams params = cvSURFParams(500, 1);
254
255     double tt = (double)cvGetTickCount();
256     cvExtractSURF( object, 0, &objectKeypoints, &objectDescriptors, storage, params );
257     printf("Object Descriptors: %d\n", objectDescriptors->total);
258
259     cvExtractSURF( image, 0, &imageKeypoints, &imageDescriptors, storage, params );
260     printf("Image Descriptors: %d\n", imageDescriptors->total);
261     tt = (double)cvGetTickCount() - tt;
262
263     printf( "Extraction time = %gms\n", tt/(cvGetTickFrequency()*1000.));
264
265     CvPoint src_corners[4] = {CvPoint(0,0), CvPoint(object->width,0), CvPoint(object->width, object->height), CvPoint(0, object->height)};
266     CvPoint dst_corners[4];
267     IplImage* correspond = cvCreateImage( cvSize(image->width, object->height+image->height), 8, 1 );
268     cvSetImageROI( correspond, cvRect( 0, 0, object->width, object->height ) );
269     cvCopy( object, correspond );
270     cvSetImageROI( correspond, cvRect( 0, object->height, correspond->width, correspond->height ) );
271     cvCopy( image, correspond );
272     cvResetImageROI( correspond );
273
274 #ifdef USE_FLANN
275     printf("Using approximate nearest neighbor search\n");
276 #endif
277
278     if( locatePlanarObject( objectKeypoints, objectDescriptors, imageKeypoints,
279         imageDescriptors, src_corners, dst_corners ))
280     {
281         for( i = 0; i < 4; i++ )
282         {
283             CvPoint r1 = dst_corners[i%4];
284             CvPoint r2 = dst_corners[(i+1)%4];
285             cvLine( correspond, cvPoint(r1.x, r1.y+object->height ),
286                 cvPoint(r2.x, r2.y+object->height ), colors[8] );
287         }
288     }
289     vector<int> ptpairs;
290 #ifdef USE_FLANN
291     flannFindPairs( objectKeypoints, objectDescriptors, imageKeypoints, imageDescriptors, ptpairs );
292 #else
293     findPairs( objectKeypoints, objectDescriptors, imageKeypoints, imageDescriptors, ptpairs );
294 #endif
295     for( i = 0; i < (int)ptpairs.size(); i += 2 )
296     {
297         CvSURFPoint* r1 = (CvSURFPoint*)cvGetSeqElem( objectKeypoints, ptpairs[i] );
298         CvSURFPoint* r2 = (CvSURFPoint*)cvGetSeqElem( imageKeypoints, ptpairs[i+1] );
299         cvLine( correspond, cvPointFrom32f(r1->pt),
300             cvPoint(cvRound(r2->pt.x), cvRound(r2->pt.y+object->height)), colors[8] );
301     }
302
303     cvShowImage( "Object Correspond", correspond );
304     for( i = 0; i < objectKeypoints->total; i++ )
305     {
306         CvSURFPoint* r = (CvSURFPoint*)cvGetSeqElem( objectKeypoints, i );
307         CvPoint center;
308         int radius;
309         center.x = cvRound(r->pt.x);
310         center.y = cvRound(r->pt.y);
311         radius = cvRound(r->size*1.2/9.*2);
312         cvCircle( object_color, center, radius, colors[0], 1, 8, 0 );
313     }
314     cvShowImage( "Object", object_color );
315
316     cvWaitKey(0);
317
318     cvDestroyWindow("Object");
319     cvDestroyWindow("Object Correspond");
320
321     return 0;
322 }