some samples updated according to new CommandLineParser class
[profile/ivi/opencv.git] / samples / c / find_obj_ferns.cpp
1 #include "opencv2/highgui/highgui.hpp"
2 #include "opencv2/core/core.hpp"
3 #include "opencv2/imgproc/imgproc.hpp"
4 #include "opencv2/features2d/features2d.hpp"
5 #include "opencv2/objdetect/objdetect.hpp"
6
7 #include <algorithm>
8 #include <iostream>
9 #include <vector>
10
11 using namespace cv;
12
13 void help()
14 {
15     printf( "This program shows the use of the \"fern\" plannar PlanarObjectDetector point\n"
16             "descriptor classifier"
17             "Usage:\n"
18             "./find_obj_ferns [--object_filename]=<object_filename, box.png as default> \n"
19             "[--scene_filename]=<scene_filename box_in_scene.png as default>]\n\n");
20 }
21
22 int main(int argc, const char** argv)
23 {
24     help();
25
26     CommandLineParser parser(argc, argv);
27
28     string objectFileName = parser.get<string>("object_filename", "box.png");
29     string sceneFileName = parser.get<string>("scene_filename", "box_in_scene.png");
30
31     cvNamedWindow("Object", 1);
32     cvNamedWindow("Image", 1);
33     cvNamedWindow("Object Correspondence", 1);
34     
35     Mat object = imread( objectFileName.c_str(), CV_LOAD_IMAGE_GRAYSCALE );
36     Mat image;
37     
38     double imgscale = 1;
39
40     Mat _image = imread( sceneFileName.c_str(), CV_LOAD_IMAGE_GRAYSCALE );
41     resize(_image, image, Size(), 1./imgscale, 1./imgscale, INTER_CUBIC);
42
43
44     if( !object.data || !image.data )
45     {
46         fprintf( stderr, "Can not load %s and/or %s\n"
47                 "Usage: find_obj_ferns [<object_filename> <scene_filename>]\n",
48                 objectFileName.c_str(), sceneFileName.c_str() );
49         exit(-1);
50     }
51
52     Size patchSize(32, 32);
53     LDetector ldetector(7, 20, 2, 2000, patchSize.width, 2);
54     ldetector.setVerbose(true);
55     PlanarObjectDetector detector;
56     
57     vector<Mat> objpyr, imgpyr;
58     int blurKSize = 3;
59     double sigma = 0;
60     GaussianBlur(object, object, Size(blurKSize, blurKSize), sigma, sigma);
61     GaussianBlur(image, image, Size(blurKSize, blurKSize), sigma, sigma);
62     buildPyramid(object, objpyr, ldetector.nOctaves-1);
63     buildPyramid(image, imgpyr, ldetector.nOctaves-1);
64     
65     vector<KeyPoint> objKeypoints, imgKeypoints;
66         PatchGenerator gen(0,256,5,true,0.8,1.2,-CV_PI/2,CV_PI/2,-CV_PI/2,CV_PI/2);
67     
68     string model_filename = format("%s_model.xml.gz", objectFileName.c_str());
69     printf("Trying to load %s ...\n", model_filename.c_str());
70     FileStorage fs(model_filename, FileStorage::READ);
71     if( fs.isOpened() )
72     {
73         detector.read(fs.getFirstTopLevelNode());
74         printf("Successfully loaded %s.\n", model_filename.c_str());
75     }
76     else
77     {
78         printf("The file not found and can not be read. Let's train the model.\n");
79         printf("Step 1. Finding the robust keypoints ...\n");
80         ldetector.setVerbose(true);
81         ldetector.getMostStable2D(object, objKeypoints, 100, gen);
82         printf("Done.\nStep 2. Training ferns-based planar object detector ...\n");
83         detector.setVerbose(true);
84     
85         detector.train(objpyr, objKeypoints, patchSize.width, 100, 11, 10000, ldetector, gen);
86         printf("Done.\nStep 3. Saving the model to %s ...\n", model_filename.c_str());
87         if( fs.open(model_filename, FileStorage::WRITE) )
88             detector.write(fs, "ferns_model");
89     }
90     printf("Now find the keypoints in the image, try recognize them and compute the homography matrix\n");
91     fs.release();
92         
93     vector<Point2f> dst_corners;
94     Mat correspond( object.rows + image.rows, std::max(object.cols, image.cols), CV_8UC3);
95     correspond = Scalar(0.);
96     Mat part(correspond, Rect(0, 0, object.cols, object.rows));
97     cvtColor(object, part, CV_GRAY2BGR);
98     part = Mat(correspond, Rect(0, object.rows, image.cols, image.rows));
99     cvtColor(image, part, CV_GRAY2BGR);
100  
101     vector<int> pairs;
102     Mat H;
103     
104     double t = (double)getTickCount();
105     objKeypoints = detector.getModelPoints();
106     ldetector(imgpyr, imgKeypoints, 300);
107     
108     std::cout << "Object keypoints: " << objKeypoints.size() << "\n";
109     std::cout << "Image keypoints: " << imgKeypoints.size() << "\n";
110     bool found = detector(imgpyr, imgKeypoints, H, dst_corners, &pairs);
111     t = (double)getTickCount() - t;
112     printf("%gms\n", t*1000/getTickFrequency());
113     
114     int i = 0;
115     if( found )
116     {
117         for( i = 0; i < 4; i++ )
118         {
119             Point r1 = dst_corners[i%4];
120             Point r2 = dst_corners[(i+1)%4];
121             line( correspond, Point(r1.x, r1.y+object.rows),
122                  Point(r2.x, r2.y+object.rows), Scalar(0,0,255) );
123         }
124     }
125     
126     for( i = 0; i < (int)pairs.size(); i += 2 )
127     {
128         line( correspond, objKeypoints[pairs[i]].pt,
129              imgKeypoints[pairs[i+1]].pt + Point2f(0,(float)object.rows),
130              Scalar(0,255,0) );
131     }
132     
133     imshow( "Object Correspondence", correspond );
134     Mat objectColor;
135     cvtColor(object, objectColor, CV_GRAY2BGR);
136     for( i = 0; i < (int)objKeypoints.size(); i++ )
137     {
138         circle( objectColor, objKeypoints[i].pt, 2, Scalar(0,0,255), -1 );
139         circle( objectColor, objKeypoints[i].pt, (1 << objKeypoints[i].octave)*15, Scalar(0,255,0), 1 );
140     }
141     Mat imageColor;
142     cvtColor(image, imageColor, CV_GRAY2BGR);
143     for( i = 0; i < (int)imgKeypoints.size(); i++ )
144     {
145         circle( imageColor, imgKeypoints[i].pt, 2, Scalar(0,0,255), -1 );
146         circle( imageColor, imgKeypoints[i].pt, (1 << imgKeypoints[i].octave)*15, Scalar(0,255,0), 1 );
147     }
148     imwrite("correspond.png", correspond );
149     imshow( "Object", objectColor );
150     imshow( "Image", imageColor );
151     
152     waitKey(0);
153     return 0;
154 }