converted some more samples to C++
[profile/ivi/opencv.git] / samples / c / find_obj.cpp
1 /*
2  * A Demo to OpenCV Implementation of SURF
3  * Further Information Refer to "SURF: Speed-Up Robust Feature"
4  * Author: Liu Liu
5  * liuliu.1987+opencv@gmail.com
6  */
7 #include <opencv2/objdetect/objdetect.hpp>
8 #include <opencv2/features2d/features2d.hpp>
9 #include <opencv2/highgui/highgui.hpp>
10 #include <opencv2/calib3d/calib3d.hpp>
11 #include <opencv2/imgproc/imgproc_c.h>
12
13 #include <iostream>
14 #include <vector>
15
16 using namespace std;
17
18
19 // define whether to use approximate nearest-neighbor search
20 #define USE_FLANN
21
22
23 IplImage *image = 0;
24
25 double
26 compareSURFDescriptors( const float* d1, const float* d2, double best, int length )
27 {
28     double total_cost = 0;
29     assert( length % 4 == 0 );
30     for( int i = 0; i < length; i += 4 )
31     {
32         double t0 = d1[i] - d2[i];
33         double t1 = d1[i+1] - d2[i+1];
34         double t2 = d1[i+2] - d2[i+2];
35         double t3 = d1[i+3] - d2[i+3];
36         total_cost += t0*t0 + t1*t1 + t2*t2 + t3*t3;
37         if( total_cost > best )
38             break;
39     }
40     return total_cost;
41 }
42
43
44 int
45 naiveNearestNeighbor( const float* vec, int laplacian,
46                       const CvSeq* model_keypoints,
47                       const CvSeq* model_descriptors )
48 {
49     int length = (int)(model_descriptors->elem_size/sizeof(float));
50     int i, neighbor = -1;
51     double d, dist1 = 1e6, dist2 = 1e6;
52     CvSeqReader reader, kreader;
53     cvStartReadSeq( model_keypoints, &kreader, 0 );
54     cvStartReadSeq( model_descriptors, &reader, 0 );
55
56     for( i = 0; i < model_descriptors->total; i++ )
57     {
58         const CvSURFPoint* kp = (const CvSURFPoint*)kreader.ptr;
59         const float* mvec = (const float*)reader.ptr;
60         CV_NEXT_SEQ_ELEM( kreader.seq->elem_size, kreader );
61         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
62         if( laplacian != kp->laplacian )
63             continue;
64         d = compareSURFDescriptors( vec, mvec, dist2, length );
65         if( d < dist1 )
66         {
67             dist2 = dist1;
68             dist1 = d;
69             neighbor = i;
70         }
71         else if ( d < dist2 )
72             dist2 = d;
73     }
74     if ( dist1 < 0.6*dist2 )
75         return neighbor;
76     return -1;
77 }
78
79 void
80 findPairs( const CvSeq* objectKeypoints, const CvSeq* objectDescriptors,
81            const CvSeq* imageKeypoints, const CvSeq* imageDescriptors, vector<int>& ptpairs )
82 {
83     int i;
84     CvSeqReader reader, kreader;
85     cvStartReadSeq( objectKeypoints, &kreader );
86     cvStartReadSeq( objectDescriptors, &reader );
87     ptpairs.clear();
88
89     for( i = 0; i < objectDescriptors->total; i++ )
90     {
91         const CvSURFPoint* kp = (const CvSURFPoint*)kreader.ptr;
92         const float* descriptor = (const float*)reader.ptr;
93         CV_NEXT_SEQ_ELEM( kreader.seq->elem_size, kreader );
94         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
95         int nearest_neighbor = naiveNearestNeighbor( descriptor, kp->laplacian, imageKeypoints, imageDescriptors );
96         if( nearest_neighbor >= 0 )
97         {
98             ptpairs.push_back(i);
99             ptpairs.push_back(nearest_neighbor);
100         }
101     }
102 }
103
104
105 void
106 flannFindPairs( const CvSeq*, const CvSeq* objectDescriptors,
107            const CvSeq*, const CvSeq* imageDescriptors, vector<int>& ptpairs )
108 {
109         int length = (int)(objectDescriptors->elem_size/sizeof(float));
110
111     cv::Mat m_object(objectDescriptors->total, length, CV_32F);
112         cv::Mat m_image(imageDescriptors->total, length, CV_32F);
113
114
115         // copy descriptors
116     CvSeqReader obj_reader;
117         float* obj_ptr = m_object.ptr<float>(0);
118     cvStartReadSeq( objectDescriptors, &obj_reader );
119     for(int i = 0; i < objectDescriptors->total; i++ )
120     {
121         const float* descriptor = (const float*)obj_reader.ptr;
122         CV_NEXT_SEQ_ELEM( obj_reader.seq->elem_size, obj_reader );
123         memcpy(obj_ptr, descriptor, length*sizeof(float));
124         obj_ptr += length;
125     }
126     CvSeqReader img_reader;
127         float* img_ptr = m_image.ptr<float>(0);
128     cvStartReadSeq( imageDescriptors, &img_reader );
129     for(int i = 0; i < imageDescriptors->total; i++ )
130     {
131         const float* descriptor = (const float*)img_reader.ptr;
132         CV_NEXT_SEQ_ELEM( img_reader.seq->elem_size, img_reader );
133         memcpy(img_ptr, descriptor, length*sizeof(float));
134         img_ptr += length;
135     }
136
137     // find nearest neighbors using FLANN
138     cv::Mat m_indices(objectDescriptors->total, 2, CV_32S);
139     cv::Mat m_dists(objectDescriptors->total, 2, CV_32F);
140     cv::flann::Index flann_index(m_image, cv::flann::KDTreeIndexParams(4));  // using 4 randomized kdtrees
141     flann_index.knnSearch(m_object, m_indices, m_dists, 2, cv::flann::SearchParams(64) ); // maximum number of leafs checked
142
143     int* indices_ptr = m_indices.ptr<int>(0);
144     float* dists_ptr = m_dists.ptr<float>(0);
145     for (int i=0;i<m_indices.rows;++i) {
146         if (dists_ptr[2*i]<0.6*dists_ptr[2*i+1]) {
147                 ptpairs.push_back(i);
148                 ptpairs.push_back(indices_ptr[2*i]);
149         }
150     }
151 }
152
153
154 /* a rough implementation for object location */
155 int
156 locatePlanarObject( const CvSeq* objectKeypoints, const CvSeq* objectDescriptors,
157                     const CvSeq* imageKeypoints, const CvSeq* imageDescriptors,
158                     const CvPoint src_corners[4], CvPoint dst_corners[4] )
159 {
160     double h[9];
161     CvMat _h = cvMat(3, 3, CV_64F, h);
162     vector<int> ptpairs;
163     vector<CvPoint2D32f> pt1, pt2;
164     CvMat _pt1, _pt2;
165     int i, n;
166
167 #ifdef USE_FLANN
168     flannFindPairs( objectKeypoints, objectDescriptors, imageKeypoints, imageDescriptors, ptpairs );
169 #else
170     findPairs( objectKeypoints, objectDescriptors, imageKeypoints, imageDescriptors, ptpairs );
171 #endif
172
173     n = (int)(ptpairs.size()/2);
174     if( n < 4 )
175         return 0;
176
177     pt1.resize(n);
178     pt2.resize(n);
179     for( i = 0; i < n; i++ )
180     {
181         pt1[i] = ((CvSURFPoint*)cvGetSeqElem(objectKeypoints,ptpairs[i*2]))->pt;
182         pt2[i] = ((CvSURFPoint*)cvGetSeqElem(imageKeypoints,ptpairs[i*2+1]))->pt;
183     }
184
185     _pt1 = cvMat(1, n, CV_32FC2, &pt1[0] );
186     _pt2 = cvMat(1, n, CV_32FC2, &pt2[0] );
187     if( !cvFindHomography( &_pt1, &_pt2, &_h, CV_RANSAC, 5 ))
188         return 0;
189
190     for( i = 0; i < 4; i++ )
191     {
192         double x = src_corners[i].x, y = src_corners[i].y;
193         double Z = 1./(h[6]*x + h[7]*y + h[8]);
194         double X = (h[0]*x + h[1]*y + h[2])*Z;
195         double Y = (h[3]*x + h[4]*y + h[5])*Z;
196         dst_corners[i] = cvPoint(cvRound(X), cvRound(Y));
197     }
198
199     return 1;
200 }
201
202 int main(int argc, char** argv)
203 {
204     const char* object_filename = argc == 3 ? argv[1] : "box.png";
205     const char* scene_filename = argc == 3 ? argv[2] : "box_in_scene.png";
206
207     CvMemStorage* storage = cvCreateMemStorage(0);
208
209     cvNamedWindow("Object", 1);
210     cvNamedWindow("Object Correspond", 1);
211
212     static CvScalar colors[] = 
213     {
214         {{0,0,255}},
215         {{0,128,255}},
216         {{0,255,255}},
217         {{0,255,0}},
218         {{255,128,0}},
219         {{255,255,0}},
220         {{255,0,0}},
221         {{255,0,255}},
222         {{255,255,255}}
223     };
224
225     IplImage* object = cvLoadImage( object_filename, CV_LOAD_IMAGE_GRAYSCALE );
226     IplImage* image = cvLoadImage( scene_filename, CV_LOAD_IMAGE_GRAYSCALE );
227     if( !object || !image )
228     {
229         fprintf( stderr, "Can not load %s and/or %s\n"
230             "Usage: find_obj [<object_filename> <scene_filename>]\n",
231             object_filename, scene_filename );
232         exit(-1);
233     }
234     IplImage* object_color = cvCreateImage(cvGetSize(object), 8, 3);
235     cvCvtColor( object, object_color, CV_GRAY2BGR );
236     
237     CvSeq *objectKeypoints = 0, *objectDescriptors = 0;
238     CvSeq *imageKeypoints = 0, *imageDescriptors = 0;
239     int i;
240     CvSURFParams params = cvSURFParams(500, 1);
241
242     double tt = (double)cvGetTickCount();
243     cvExtractSURF( object, 0, &objectKeypoints, &objectDescriptors, storage, params );
244     printf("Object Descriptors: %d\n", objectDescriptors->total);
245     cvExtractSURF( image, 0, &imageKeypoints, &imageDescriptors, storage, params );
246     printf("Image Descriptors: %d\n", imageDescriptors->total);
247     tt = (double)cvGetTickCount() - tt;
248     printf( "Extraction time = %gms\n", tt/(cvGetTickFrequency()*1000.));
249     CvPoint src_corners[4] = {{0,0}, {object->width,0}, {object->width, object->height}, {0, object->height}};
250     CvPoint dst_corners[4];
251     IplImage* correspond = cvCreateImage( cvSize(image->width, object->height+image->height), 8, 1 );
252     cvSetImageROI( correspond, cvRect( 0, 0, object->width, object->height ) );
253     cvCopy( object, correspond );
254     cvSetImageROI( correspond, cvRect( 0, object->height, correspond->width, correspond->height ) );
255     cvCopy( image, correspond );
256     cvResetImageROI( correspond );
257
258 #ifdef USE_FLANN
259     printf("Using approximate nearest neighbor search\n");
260 #endif
261
262     if( locatePlanarObject( objectKeypoints, objectDescriptors, imageKeypoints,
263         imageDescriptors, src_corners, dst_corners ))
264     {
265         for( i = 0; i < 4; i++ )
266         {
267             CvPoint r1 = dst_corners[i%4];
268             CvPoint r2 = dst_corners[(i+1)%4];
269             cvLine( correspond, cvPoint(r1.x, r1.y+object->height ),
270                 cvPoint(r2.x, r2.y+object->height ), colors[8] );
271         }
272     }
273     vector<int> ptpairs;
274 #ifdef USE_FLANN
275     flannFindPairs( objectKeypoints, objectDescriptors, imageKeypoints, imageDescriptors, ptpairs );
276 #else
277     findPairs( objectKeypoints, objectDescriptors, imageKeypoints, imageDescriptors, ptpairs );
278 #endif
279     for( i = 0; i < (int)ptpairs.size(); i += 2 )
280     {
281         CvSURFPoint* r1 = (CvSURFPoint*)cvGetSeqElem( objectKeypoints, ptpairs[i] );
282         CvSURFPoint* r2 = (CvSURFPoint*)cvGetSeqElem( imageKeypoints, ptpairs[i+1] );
283         cvLine( correspond, cvPointFrom32f(r1->pt),
284             cvPoint(cvRound(r2->pt.x), cvRound(r2->pt.y+object->height)), colors[8] );
285     }
286
287     cvShowImage( "Object Correspond", correspond );
288     for( i = 0; i < objectKeypoints->total; i++ )
289     {
290         CvSURFPoint* r = (CvSURFPoint*)cvGetSeqElem( objectKeypoints, i );
291         CvPoint center;
292         int radius;
293         center.x = cvRound(r->pt.x);
294         center.y = cvRound(r->pt.y);
295         radius = cvRound(r->size*1.2/9.*2);
296         cvCircle( object_color, center, radius, colors[0], 1, 8, 0 );
297     }
298     cvShowImage( "Object", object_color );
299
300     cvWaitKey(0);
301
302     cvDestroyWindow("Object");
303     cvDestroyWindow("Object SURF");
304     cvDestroyWindow("Object Correspond");
305
306     return 0;
307 }