Merge pull request #2887 from ilya-lavrenov:ipp_morph_fix
[platform/upstream/opencv.git] / samples / c / find_obj_ferns.cpp
1 #include "opencv2/core/utility.hpp"
2 #include "opencv2/highgui.hpp"
3 #include "opencv2/imgproc.hpp"
4 #include "opencv2/features2d.hpp"
5 #include "opencv2/objdetect.hpp"
6 #include "opencv2/legacy.hpp"
7
8 #include <algorithm>
9 #include <iostream>
10 #include <vector>
11 #include <stdio.h>
12
13 using namespace std;
14 using namespace cv;
15
16 static void help()
17 {
18     printf( "This program shows the use of the \"fern\" plannar PlanarObjectDetector point\n"
19             "descriptor classifier\n"
20             "Usage:\n"
21             "./find_obj_ferns <object_filename> <scene_filename>, default: box.png and box_in_scene.png\n\n");
22     return;
23 }
24
25
26 int main(int argc, char** argv)
27 {
28     int i;
29
30     const char* object_filename = argc > 1 ? argv[1] : "box.png";
31     const char* scene_filename = argc > 2 ? argv[2] : "box_in_scene.png";
32
33     help();
34
35     Mat object = imread( object_filename, IMREAD_GRAYSCALE );
36     Mat scene = imread( scene_filename, IMREAD_GRAYSCALE );
37
38     if( !object.data || !scene.data )
39     {
40         fprintf( stderr, "Can not load %s and/or %s\n",
41                 object_filename, scene_filename );
42         exit(-1);
43     }
44
45     double imgscale = 1;
46     Mat image;
47
48     resize(scene, image, Size(), 1./imgscale, 1./imgscale, INTER_CUBIC);
49
50     namedWindow("Object", 1);
51     namedWindow("Image", 1);
52     namedWindow("Object Correspondence", 1);
53
54     Size patchSize(32, 32);
55     LDetector ldetector(7, 20, 2, 2000, patchSize.width, 2);
56     ldetector.setVerbose(true);
57     PlanarObjectDetector detector;
58
59     vector<Mat> objpyr, imgpyr;
60     int blurKSize = 3;
61     double sigma = 0;
62     GaussianBlur(object, object, Size(blurKSize, blurKSize), sigma, sigma);
63     GaussianBlur(image, image, Size(blurKSize, blurKSize), sigma, sigma);
64     buildPyramid(object, objpyr, ldetector.nOctaves-1);
65     buildPyramid(image, imgpyr, ldetector.nOctaves-1);
66
67     vector<KeyPoint> objKeypoints, imgKeypoints;
68     PatchGenerator gen(0,256,5,true,0.8,1.2,-CV_PI/2,CV_PI/2,-CV_PI/2,CV_PI/2);
69
70     string model_filename = format("%s_model.xml.gz", object_filename);
71     printf("Trying to load %s ...\n", model_filename.c_str());
72     FileStorage fs(model_filename, FileStorage::READ);
73     if( fs.isOpened() )
74     {
75         detector.read(fs.getFirstTopLevelNode());
76         printf("Successfully loaded %s.\n", model_filename.c_str());
77     }
78     else
79     {
80         printf("The file not found and can not be read. Let's train the model.\n");
81         printf("Step 1. Finding the robust keypoints ...\n");
82         ldetector.setVerbose(true);
83         ldetector.getMostStable2D(object, objKeypoints, 100, gen);
84         printf("Done.\nStep 2. Training ferns-based planar object detector ...\n");
85         detector.setVerbose(true);
86
87         detector.train(objpyr, objKeypoints, patchSize.width, 100, 11, 10000, ldetector, gen);
88         printf("Done.\nStep 3. Saving the model to %s ...\n", model_filename.c_str());
89         if( fs.open(model_filename, FileStorage::WRITE) )
90             detector.write(fs, "ferns_model");
91     }
92     printf("Now find the keypoints in the image, try recognize them and compute the homography matrix\n");
93     fs.release();
94
95     vector<Point2f> dst_corners;
96     Mat correspond( object.rows + image.rows, std::max(object.cols, image.cols), CV_8UC3);
97     correspond = Scalar(0.);
98     Mat part(correspond, Rect(0, 0, object.cols, object.rows));
99     cvtColor(object, part, CV_GRAY2BGR);
100     part = Mat(correspond, Rect(0, object.rows, image.cols, image.rows));
101     cvtColor(image, part, CV_GRAY2BGR);
102
103     vector<int> pairs;
104     Mat H;
105
106     double t = (double)getTickCount();
107     objKeypoints = detector.getModelPoints();
108     ldetector(imgpyr, imgKeypoints, 300);
109
110     std::cout << "Object keypoints: " << objKeypoints.size() << "\n";
111     std::cout << "Image keypoints: " << imgKeypoints.size() << "\n";
112     bool found = detector(imgpyr, imgKeypoints, H, dst_corners, &pairs);
113     t = (double)getTickCount() - t;
114     printf("%gms\n", t*1000/getTickFrequency());
115
116     if( found )
117     {
118         for( i = 0; i < 4; i++ )
119         {
120             Point r1 = dst_corners[i%4];
121             Point r2 = dst_corners[(i+1)%4];
122             line( correspond, Point(r1.x, r1.y+object.rows),
123                  Point(r2.x, r2.y+object.rows), Scalar(0,0,255) );
124         }
125     }
126
127     for( i = 0; i < (int)pairs.size(); i += 2 )
128     {
129         line( correspond, objKeypoints[pairs[i]].pt,
130              imgKeypoints[pairs[i+1]].pt + Point2f(0,(float)object.rows),
131              Scalar(0,255,0) );
132     }
133
134     imshow( "Object Correspondence", correspond );
135     Mat objectColor;
136     cvtColor(object, objectColor, CV_GRAY2BGR);
137     for( i = 0; i < (int)objKeypoints.size(); i++ )
138     {
139         circle( objectColor, objKeypoints[i].pt, 2, Scalar(0,0,255), -1 );
140         circle( objectColor, objKeypoints[i].pt, (1 << objKeypoints[i].octave)*15, Scalar(0,255,0), 1 );
141     }
142     Mat imageColor;
143     cvtColor(image, imageColor, CV_GRAY2BGR);
144     for( i = 0; i < (int)imgKeypoints.size(); i++ )
145     {
146         circle( imageColor, imgKeypoints[i].pt, 2, Scalar(0,0,255), -1 );
147         circle( imageColor, imgKeypoints[i].pt, (1 << imgKeypoints[i].octave)*15, Scalar(0,255,0), 1 );
148     }
149
150     imwrite("correspond.png", correspond );
151     imshow( "Object", objectColor );
152     imshow( "Image", imageColor );
153
154     waitKey(0);
155
156     return 0;
157 }