converted some more samples to C++
[profile/ivi/opencv.git] / samples / c / find_obj_ferns.cpp
1 #include "opencv2/highgui/highgui.hpp"
2 #include "opencv2/core/core.hpp"
3 #include "opencv2/imgproc/imgproc.hpp"
4 #include "opencv2/features2d/features2d.hpp"
5
6 #include <algorithm>
7 #include <iostream>
8 #include <vector>
9
10 using namespace cv;
11
12 int main(int argc, char** argv)
13 {
14     const char* object_filename = argc > 1 ? argv[1] : "box.png";
15     const char* scene_filename = argc > 2 ? argv[2] : "box_in_scene.png";
16     int i;
17     
18     cvNamedWindow("Object", 1);
19     cvNamedWindow("Image", 1);
20     cvNamedWindow("Object Correspondence", 1);
21     
22     Mat object = imread( object_filename, CV_LOAD_IMAGE_GRAYSCALE );
23     Mat image;
24     
25     double imgscale = 1;
26
27     Mat _image = imread( scene_filename, CV_LOAD_IMAGE_GRAYSCALE );
28     resize(_image, image, Size(), 1./imgscale, 1./imgscale, INTER_CUBIC);
29
30
31     if( !object.data || !image.data )
32     {
33         fprintf( stderr, "Can not load %s and/or %s\n"
34                 "Usage: find_obj [<object_filename> <scene_filename>]\n",
35                 object_filename, scene_filename );
36         exit(-1);
37     }
38
39     Size patchSize(32, 32);
40     LDetector ldetector(7, 20, 2, 2000, patchSize.width, 2);
41     ldetector.setVerbose(true);
42     PlanarObjectDetector detector;
43     
44     vector<Mat> objpyr, imgpyr;
45     int blurKSize = 3;
46     double sigma = 0;
47     GaussianBlur(object, object, Size(blurKSize, blurKSize), sigma, sigma);
48     GaussianBlur(image, image, Size(blurKSize, blurKSize), sigma, sigma);
49     buildPyramid(object, objpyr, ldetector.nOctaves-1);
50     buildPyramid(image, imgpyr, ldetector.nOctaves-1);
51     
52     vector<KeyPoint> objKeypoints, imgKeypoints;
53         PatchGenerator gen(0,256,5,true,0.8,1.2,-CV_PI/2,CV_PI/2,-CV_PI/2,CV_PI/2);
54     
55     string model_filename = format("%s_model.xml.gz", object_filename);
56     printf("Trying to load %s ...\n", model_filename.c_str());
57     FileStorage fs(model_filename, FileStorage::READ);
58     if( fs.isOpened() )
59     {
60         detector.read(fs.getFirstTopLevelNode());
61         printf("Successfully loaded %s.\n", model_filename.c_str());
62     }
63     else
64     {
65         printf("The file not found and can not be read. Let's train the model.\n");
66         printf("Step 1. Finding the robust keypoints ...\n");
67         ldetector.setVerbose(true);
68         ldetector.getMostStable2D(object, objKeypoints, 100, gen);
69         printf("Done.\nStep 2. Training ferns-based planar object detector ...\n");
70         detector.setVerbose(true);
71     
72         detector.train(objpyr, objKeypoints, patchSize.width, 100, 11, 10000, ldetector, gen);
73         printf("Done.\nStep 3. Saving the model to %s ...\n", model_filename.c_str());
74         if( fs.open(model_filename, FileStorage::WRITE) )
75             detector.write(fs, "ferns_model");
76     }
77     printf("Now find the keypoints in the image, try recognize them and compute the homography matrix\n");
78     fs.release();
79         
80     vector<Point2f> dst_corners;
81     Mat correspond( object.rows + image.rows, std::max(object.cols, image.cols), CV_8UC3);
82     correspond = Scalar(0.);
83     Mat part(correspond, Rect(0, 0, object.cols, object.rows));
84     cvtColor(object, part, CV_GRAY2BGR);
85     part = Mat(correspond, Rect(0, object.rows, image.cols, image.rows));
86     cvtColor(image, part, CV_GRAY2BGR);
87  
88     vector<int> pairs;
89     Mat H;
90     
91     double t = (double)getTickCount();
92     objKeypoints = detector.getModelPoints();
93     ldetector(imgpyr, imgKeypoints, 300);
94     
95     std::cout << "Object keypoints: " << objKeypoints.size() << "\n";
96     std::cout << "Image keypoints: " << imgKeypoints.size() << "\n";
97     bool found = detector(imgpyr, imgKeypoints, H, dst_corners, &pairs);
98     t = (double)getTickCount() - t;
99     printf("%gms\n", t*1000/getTickFrequency());
100     
101     if( found )
102     {
103         for( i = 0; i < 4; i++ )
104         {
105             Point r1 = dst_corners[i%4];
106             Point r2 = dst_corners[(i+1)%4];
107             line( correspond, Point(r1.x, r1.y+object.rows),
108                  Point(r2.x, r2.y+object.rows), Scalar(0,0,255) );
109         }
110     }
111     
112     for( i = 0; i < (int)pairs.size(); i += 2 )
113     {
114         line( correspond, objKeypoints[pairs[i]].pt,
115              imgKeypoints[pairs[i+1]].pt + Point2f(0,object.rows),
116              Scalar(0,255,0) );
117     }
118     
119     imshow( "Object Correspondence", correspond );
120     Mat objectColor;
121     cvtColor(object, objectColor, CV_GRAY2BGR);
122     for( i = 0; i < (int)objKeypoints.size(); i++ )
123     {
124         circle( objectColor, objKeypoints[i].pt, 2, Scalar(0,0,255), -1 );
125         circle( objectColor, objKeypoints[i].pt, (1 << objKeypoints[i].octave)*15, Scalar(0,255,0), 1 );
126     }
127     Mat imageColor;
128     cvtColor(image, imageColor, CV_GRAY2BGR);
129     for( i = 0; i < (int)imgKeypoints.size(); i++ )
130     {
131         circle( imageColor, imgKeypoints[i].pt, 2, Scalar(0,0,255), -1 );
132         circle( imageColor, imgKeypoints[i].pt, (1 << imgKeypoints[i].octave)*15, Scalar(0,255,0), 1 );
133     }
134     imwrite("correspond.png", correspond );
135     imshow( "Object", objectColor );
136     imshow( "Image", imageColor );
137     
138     waitKey(0);
139     return 0;
140 }