8f7fcb436c0c8cc148c2bbf06662d32e757220a9
[platform/upstream/opencv.git] / modules / objdetect / src / latentsvmdetector.cpp
1 #include "precomp.hpp"
2 #include "opencv2/imgproc/imgproc_c.h"
3 #include "opencv2/objdetect/objdetect_c.h"
4 #include "_lsvmparser.h"
5 #include "_lsvm_matching.h"
6
/*
// Load a trained detector from a file
//
// API
// CvLatentSvmDetector* cvLoadLatentSvmDetector(const char* filename);
// INPUT
// filename             - path to the file containing the parameters of
//                        the trained Latent SVM detector
// OUTPUT
// trained Latent SVM detector in internal representation, or 0 if the
// model could not be loaded; release it with cvReleaseLatentSvmDetector
*/
18 CvLatentSvmDetector* cvLoadLatentSvmDetector(const char* filename)
19 {
20     CvLatentSvmDetector* detector = 0;
21     CvLSVMFilterObject** filters = 0;
22     int kFilters = 0;
23     int kComponents = 0;
24     int* kPartFilters = 0;
25     float* b = 0;
26     float scoreThreshold = 0.f;
27     int err_code = 0;
28
29     err_code = loadModel(filename, &filters, &kFilters, &kComponents, &kPartFilters, &b, &scoreThreshold);
30     if (err_code != LATENT_SVM_OK) return 0;
31
32     detector = (CvLatentSvmDetector*)malloc(sizeof(CvLatentSvmDetector));
33     detector->filters = filters;
34     detector->b = b;
35     detector->num_components = kComponents;
36     detector->num_filters = kFilters;
37     detector->num_part_filters = kPartFilters;
38     detector->score_threshold = scoreThreshold;
39
40     return detector;
41 }
42
/*
// Release memory allocated for a CvLatentSvmDetector structure
//
// API
// void cvReleaseLatentSvmDetector(CvLatentSvmDetector** detector);
// INPUT
// detector             - address of the CvLatentSvmDetector pointer to be released
// OUTPUT
// none; *detector is set to 0 after the structure has been freed
*/
52 void cvReleaseLatentSvmDetector(CvLatentSvmDetector** detector)
53 {
54     free((*detector)->b);
55     free((*detector)->num_part_filters);
56     for (int i = 0; i < (*detector)->num_filters; i++)
57     {
58         free((*detector)->filters[i]->H);
59         free((*detector)->filters[i]);
60     }
61     free((*detector)->filters);
62     free((*detector));
63     *detector = 0;
64 }
65
/*
// Find rectangular regions in the given image that are likely
// to contain objects, together with their confidence levels
//
// API
// CvSeq* cvLatentSvmDetectObjects(IplImage* image,
//                                  CvLatentSvmDetector* detector,
//                                  CvMemStorage* storage,
//                                  float overlap_threshold = 0.5f,
//                                  int numThreads = -1);
// INPUT
// image                - image to detect objects in; a 3-channel image is
//                        modified in place (BGR->RGB) while the detector runs
// detector             - Latent SVM detector in internal representation
// storage              - memory storage to store the resultant sequence
//                          of the object candidate rectangles
// overlap_threshold    - threshold for the non-maximum suppression algorithm
//                        (see P. Felzenszwalb et al., "Object Detection with
//                        Discriminatively Trained Part Based Models", 2010)
// numThreads           - number of threads passed to the object search
// OUTPUT
// sequence of detected objects (bounding boxes and confidence levels stored in
// CvObjectDetection structures), or NULL if the object search fails
*/
85 CvSeq* cvLatentSvmDetectObjects(IplImage* image,
86                                 CvLatentSvmDetector* detector,
87                                 CvMemStorage* storage,
88                                 float overlap_threshold, int numThreads)
89 {
90     CvLSVMFeaturePyramid *H = 0;
91     CvPoint *points = 0, *oppPoints = 0;
92     int kPoints = 0;
93     float *score = 0;
94     unsigned int maxXBorder = 0, maxYBorder = 0;
95     int numBoxesOut = 0;
96     CvPoint *pointsOut = 0;
97     CvPoint *oppPointsOut = 0;
98     float *scoreOut = 0;
99     CvSeq* result_seq = 0;
100     int error = 0;
101
102     if(image->nChannels == 3)
103         cvCvtColor(image, image, CV_BGR2RGB);
104
105     // Getting maximum filter dimensions
106     getMaxFilterDims((const CvLSVMFilterObject**)(detector->filters), detector->num_components,
107                      detector->num_part_filters, &maxXBorder, &maxYBorder);
108     // Create feature pyramid with nullable border
109     H = createFeaturePyramidWithBorder(image, maxXBorder, maxYBorder);
110     // Search object
111     error = searchObjectThresholdSomeComponents(H, (const CvLSVMFilterObject**)(detector->filters),
112         detector->num_components, detector->num_part_filters, detector->b, detector->score_threshold,
113         &points, &oppPoints, &score, &kPoints, numThreads);
114     if (error != LATENT_SVM_OK)
115     {
116         return NULL;
117     }
118     // Clipping boxes
119     clippingBoxes(image->width, image->height, points, kPoints);
120     clippingBoxes(image->width, image->height, oppPoints, kPoints);
121     // NMS procedure
122     nonMaximumSuppression(kPoints, points, oppPoints, score, overlap_threshold,
123                 &numBoxesOut, &pointsOut, &oppPointsOut, &scoreOut);
124
125     result_seq = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvObjectDetection), storage );
126
127     for (int i = 0; i < numBoxesOut; i++)
128     {
129         CvObjectDetection detection = {CvRect(), 0};
130         detection.score = scoreOut[i];
131         CvRect bounding_box;
132         bounding_box.x = pointsOut[i].x;
133         bounding_box.y = pointsOut[i].y;
134         bounding_box.width = oppPointsOut[i].x - pointsOut[i].x;
135         bounding_box.height = oppPointsOut[i].y - pointsOut[i].y;
136         detection.rect = bounding_box;
137         cvSeqPush(result_seq, &detection);
138     }
139
140     if(image->nChannels == 3)
141         cvCvtColor(image, image, CV_RGB2BGR);
142
143     freeFeaturePyramidObject(&H);
144     free(points);
145     free(oppPoints);
146     free(score);
147     free(scoreOut);
148
149     return result_seq;
150 }
151
152 namespace cv
153 {
154 LatentSvmDetector::ObjectDetection::ObjectDetection() : score(0.f), classID(-1)
155 {}
156
157 LatentSvmDetector::ObjectDetection::ObjectDetection( const Rect& _rect, float _score, int _classID ) :
158     rect(_rect), score(_score), classID(_classID)
159 {}
160
161 LatentSvmDetector::LatentSvmDetector()
162 {}
163
164 LatentSvmDetector::LatentSvmDetector( const std::vector<String>& filenames, const std::vector<String>& _classNames )
165 {
166     load( filenames, _classNames );
167 }
168
169 LatentSvmDetector::~LatentSvmDetector()
170 {
171     clear();
172 }
173
174 void LatentSvmDetector::clear()
175 {
176     for( size_t i = 0; i < detectors.size(); i++ )
177         cvReleaseLatentSvmDetector( &detectors[i] );
178     detectors.clear();
179
180     classNames.clear();
181 }
182
183 bool LatentSvmDetector::empty() const
184 {
185     return detectors.empty();
186 }
187
188 const std::vector<String>& LatentSvmDetector::getClassNames() const
189 {
190     return classNames;
191 }
192
193 size_t LatentSvmDetector::getClassCount() const
194 {
195     return classNames.size();
196 }
197
198 static String extractModelName( const String& filename )
199 {
200     size_t startPos = filename.rfind('/');
201     if( startPos == String::npos )
202         startPos = filename.rfind('\\');
203
204     if( startPos == String::npos )
205         startPos = 0;
206     else
207         startPos++;
208
209     const int extentionSize = 4; //.xml
210
211     int substrLength = (int)(filename.size() - startPos - extentionSize);
212
213     return filename.substr(startPos, substrLength);
214 }
215
216 bool LatentSvmDetector::load( const std::vector<String>& filenames, const std::vector<String>& _classNames )
217 {
218     clear();
219
220     CV_Assert( _classNames.empty() || _classNames.size() == filenames.size() );
221
222     for( size_t i = 0; i < filenames.size(); i++ )
223     {
224         const String filename = filenames[i];
225         if( filename.length() < 5 || filename.substr(filename.length()-4, 4) != ".xml" )
226             continue;
227
228         CvLatentSvmDetector* detector = cvLoadLatentSvmDetector( filename.c_str() );
229         if( detector )
230         {
231             detectors.push_back( detector );
232             if( _classNames.empty() )
233             {
234                 classNames.push_back( extractModelName(filenames[i]) );
235             }
236             else
237                 classNames.push_back( _classNames[i] );
238         }
239     }
240
241     return !empty();
242 }
243
244 void LatentSvmDetector::detect( const Mat& image,
245                                 std::vector<ObjectDetection>& objectDetections,
246                                 float overlapThreshold,
247                                 int numThreads )
248 {
249     objectDetections.clear();
250     if( numThreads <= 0 )
251         numThreads = 1;
252
253     for( size_t classID = 0; classID < detectors.size(); classID++ )
254     {
255         IplImage image_ipl = image;
256         CvMemStorage* storage = cvCreateMemStorage(0);
257         CvSeq* detections = cvLatentSvmDetectObjects( &image_ipl, detectors[classID], storage, overlapThreshold, numThreads );
258
259         // convert results
260         objectDetections.reserve( objectDetections.size() + detections->total );
261         for( int detectionIdx = 0; detectionIdx < detections->total; detectionIdx++ )
262         {
263             CvObjectDetection detection = *(CvObjectDetection*)cvGetSeqElem( detections, detectionIdx );
264             objectDetections.push_back( ObjectDetection(Rect(detection.rect), detection.score, (int)classID) );
265         }
266
267         cvReleaseMemStorage( &storage );
268     }
269 }
270
271 } // namespace cv