c9d289e8167f8d473e23b104f63bb619ad962b1d
[platform/upstream/opencv.git] / samples / cpp / bagofwords_classification.cpp
1 #include "opencv2/opencv_modules.hpp"
2 #include "opencv2/highgui/highgui.hpp"
3 #include "opencv2/imgproc/imgproc.hpp"
4 #include "opencv2/features2d/features2d.hpp"
5 #include "opencv2/nonfree/nonfree.hpp"
6 #include "opencv2/ml/ml.hpp"
7 #ifdef HAVE_OPENCV_OCL
8 #define _OCL_SVM_ 1 //select whether using ocl::svm method or not, default is using
9 #include "opencv2/ocl/ocl.hpp"
10 #endif
11
12 #include <fstream>
13 #include <iostream>
14 #include <memory>
15 #include <functional>
16
17 #if defined WIN32 || defined _WIN32
18 #define WIN32_LEAN_AND_MEAN
19 #include <windows.h>
20 #undef min
21 #undef max
22 #include "sys/types.h"
23 #endif
24 #include <sys/stat.h>
25
26 #define DEBUG_DESC_PROGRESS
27
28 using namespace cv;
29 using namespace std;
30
31 const string paramsFile = "params.xml";
32 const string vocabularyFile = "vocabulary.xml.gz";
33 const string bowImageDescriptorsDir = "/bowImageDescriptors";
34 const string svmsDir = "/svms";
35 const string plotsDir = "/plots";
36
37 static void help(char** argv)
38 {
39     cout << "\nThis program shows how to read in, train on and produce test results for the PASCAL VOC (Visual Object Challenge) data. \n"
40      << "It shows how to use detectors, descriptors and recognition methods \n"
41         "Using OpenCV version %s\n" << CV_VERSION << "\n"
42      << "Call: \n"
43     << "Format:\n ./" << argv[0] << " [VOC path] [result directory]  \n"
44     << "       or:  \n"
45     << " ./" << argv[0] << " [VOC path] [result directory] [feature detector] [descriptor extractor] [descriptor matcher] \n"
46     << "\n"
47     << "Input parameters: \n"
48     << "[VOC path]             Path to Pascal VOC data (e.g. /home/my/VOCdevkit/VOC2010). Note: VOC2007-VOC2010 are supported. \n"
49     << "[result directory]     Path to result diractory. Following folders will be created in [result directory]: \n"
50     << "                         bowImageDescriptors - to store image descriptors, \n"
51     << "                         svms - to store trained svms, \n"
52     << "                         plots - to store files for plots creating. \n"
53     << "[feature detector]     Feature detector name (e.g. SURF, FAST...) - see createFeatureDetector() function in detectors.cpp \n"
54     << "                         Currently 12/2010, this is FAST, STAR, SIFT, SURF, MSER, GFTT, HARRIS \n"
55     << "[descriptor extractor] Descriptor extractor name (e.g. SURF, SIFT) - see createDescriptorExtractor() function in descriptors.cpp \n"
56     << "                         Currently 12/2010, this is SURF, OpponentSIFT, SIFT, OpponentSURF, BRIEF \n"
57     << "[descriptor matcher]   Descriptor matcher name (e.g. BruteForce) - see createDescriptorMatcher() function in matchers.cpp \n"
58     << "                         Currently 12/2010, this is BruteForce, BruteForce-L1, FlannBased, BruteForce-Hamming, BruteForce-HammingLUT \n"
59     << "\n";
60 }
61
62 static void makeDir( const string& dir )
63 {
64 #if defined WIN32 || defined _WIN32
65     CreateDirectory( dir.c_str(), 0 );
66 #else
67     mkdir( dir.c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH );
68 #endif
69 }
70
71 static void makeUsedDirs( const string& rootPath )
72 {
73     makeDir(rootPath + bowImageDescriptorsDir);
74     makeDir(rootPath + svmsDir);
75     makeDir(rootPath + plotsDir);
76 }
77
78 /****************************************************************************************\
79 *                    Classes to work with PASCAL VOC dataset                             *
80 \****************************************************************************************/
81 //
82 // TODO: refactor this part of the code
83 //
84
85
86 //used to specify the (sub-)dataset over which operations are performed
87 enum ObdDatasetType {CV_OBD_TRAIN, CV_OBD_TEST};
88
89 class ObdObject
90 {
91 public:
92     string object_class;
93     Rect boundingBox;
94 };
95
96 //extended object data specific to VOC
97 enum VocPose {CV_VOC_POSE_UNSPECIFIED, CV_VOC_POSE_FRONTAL, CV_VOC_POSE_REAR, CV_VOC_POSE_LEFT, CV_VOC_POSE_RIGHT};
98 class VocObjectData
99 {
100 public:
101     bool difficult;
102     bool occluded;
103     bool truncated;
104     VocPose pose;
105 };
106 //enum VocDataset {CV_VOC2007, CV_VOC2008, CV_VOC2009, CV_VOC2010};
107 enum VocPlotType {CV_VOC_PLOT_SCREEN, CV_VOC_PLOT_PNG};
108 enum VocGT {CV_VOC_GT_NONE, CV_VOC_GT_DIFFICULT, CV_VOC_GT_PRESENT};
109 enum VocConfCond {CV_VOC_CCOND_RECALL, CV_VOC_CCOND_SCORETHRESH};
110 enum VocTask {CV_VOC_TASK_CLASSIFICATION, CV_VOC_TASK_DETECTION};
111
112 class ObdImage
113 {
114 public:
115     ObdImage(string p_id, string p_path) : id(p_id), path(p_path) {}
116     string id;
117     string path;
118 };
119
120 //used by getDetectorGroundTruth to sort a two dimensional list of floats in descending order
121 class ObdScoreIndexSorter
122 {
123 public:
124     float score;
125     int image_idx;
126     int obj_idx;
127     bool operator < (const ObdScoreIndexSorter& compare) const {return (score < compare.score);}
128 };
129
130 class VocData
131 {
132 public:
133     VocData( const string& vocPath, bool useTestDataset )
134         { initVoc( vocPath, useTestDataset ); }
135     ~VocData(){}
136     /* functions for returning classification/object data for multiple images given an object class */
137     void getClassImages(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<char>& object_present);
138     void getClassObjects(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<vector<ObdObject> >& objects);
139     void getClassObjects(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<vector<ObdObject> >& objects, vector<vector<VocObjectData> >& object_data, vector<VocGT>& ground_truth);
140     /* functions for returning object data for a single image given an image id */
141     ObdImage getObjects(const string& id, vector<ObdObject>& objects);
142     ObdImage getObjects(const string& id, vector<ObdObject>& objects, vector<VocObjectData>& object_data);
143     ObdImage getObjects(const string& obj_class, const string& id, vector<ObdObject>& objects, vector<VocObjectData>& object_data, VocGT& ground_truth);
144     /* functions for returning the ground truth (present/absent) for groups of images */
145     void getClassifierGroundTruth(const string& obj_class, const vector<ObdImage>& images, vector<char>& ground_truth);
146     void getClassifierGroundTruth(const string& obj_class, const vector<string>& images, vector<char>& ground_truth);
147     int getDetectorGroundTruth(const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<vector<Rect> >& bounding_boxes, const vector<vector<float> >& scores, vector<vector<char> >& ground_truth, vector<vector<char> >& detection_difficult, bool ignore_difficult = true);
148     /* functions for writing VOC-compatible results files */
149     void writeClassifierResultsFile(const string& out_dir, const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<float>& scores, const int competition = 1, const bool overwrite_ifexists = false);
150     /* functions for calculating metrics from a set of classification/detection results */
151     string getResultsFilename(const string& obj_class, const VocTask task, const ObdDatasetType dataset, const int competition = -1, const int number = -1);
152     void calcClassifierPrecRecall(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap, vector<size_t>& ranking);
153     void calcClassifierPrecRecall(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap);
154     void calcClassifierPrecRecall(const string& input_file, vector<float>& precision, vector<float>& recall, float& ap, bool outputRankingFile = false);
155     /* functions for calculating confusion matrices */
156     void calcClassifierConfMatRow(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, const VocConfCond cond, const float threshold, vector<string>& output_headers, vector<float>& output_values);
157     void calcDetectorConfMatRow(const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<vector<float> >& scores, const vector<vector<Rect> >& bounding_boxes, const VocConfCond cond, const float threshold, vector<string>& output_headers, vector<float>& output_values, bool ignore_difficult = true);
158     /* functions for outputting gnuplot output files */
159     void savePrecRecallToGnuplot(const string& output_file, const vector<float>& precision, const vector<float>& recall, const float ap, const string title = string(), const VocPlotType plot_type = CV_VOC_PLOT_SCREEN);
160     /* functions for reading in result/ground truth files */
161     void readClassifierGroundTruth(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<char>& object_present);
162     void readClassifierResultsFile(const std:: string& input_file, vector<ObdImage>& images, vector<float>& scores);
163     void readDetectorResultsFile(const string& input_file, vector<ObdImage>& images, vector<vector<float> >& scores, vector<vector<Rect> >& bounding_boxes);
164     /* functions for getting dataset info */
165     const vector<string>& getObjectClasses();
166     string getResultsDirectory();
167 protected:
168     void initVoc( const string& vocPath, const bool useTestDataset );
169     void initVoc2007to2010( const string& vocPath, const bool useTestDataset);
170     void readClassifierGroundTruth(const string& filename, vector<string>& image_codes, vector<char>& object_present);
171     void readClassifierResultsFile(const string& input_file, vector<string>& image_codes, vector<float>& scores);
172     void readDetectorResultsFile(const string& input_file, vector<string>& image_codes, vector<vector<float> >& scores, vector<vector<Rect> >& bounding_boxes);
173     void extractVocObjects(const string filename, vector<ObdObject>& objects, vector<VocObjectData>& object_data);
174     string getImagePath(const string& input_str);
175
176     void getClassImages_impl(const string& obj_class, const string& dataset_str, vector<ObdImage>& images, vector<char>& object_present);
177     void calcPrecRecall_impl(const vector<char>& ground_truth, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap, vector<size_t>& ranking, int recall_normalization = -1);
178
179     //test two bounding boxes to see if they meet the overlap criteria defined in the VOC documentation
180     float testBoundingBoxesForOverlap(const Rect detection, const Rect ground_truth);
181     //extract class and dataset name from a VOC-standard classification/detection results filename
182     void extractDataFromResultsFilename(const string& input_file, string& class_name, string& dataset_name);
183     //get classifier ground truth for a single image
184     bool getClassifierGroundTruthImage(const string& obj_class, const string& id);
185
186     //utility functions
187     void getSortOrder(const vector<float>& values, vector<size_t>& order, bool descending = true);
188     int stringToInteger(const string input_str);
189     void readFileToString(const string filename, string& file_contents);
190     string integerToString(const int input_int);
191     string checkFilenamePathsep(const string filename, bool add_trailing_slash = false);
192     void convertImageCodesToObdImages(const vector<string>& image_codes, vector<ObdImage>& images);
193     int extractXMLBlock(const string src, const string tag, const int searchpos, string& tag_contents);
194     //utility sorter
195     struct orderingSorter
196     {
197         bool operator ()(std::pair<size_t, vector<float>::const_iterator> const& a, std::pair<size_t, vector<float>::const_iterator> const& b)
198         {
199             return (*a.second) > (*b.second);
200         }
201     };
202     //data members
203     string m_vocPath;
204     string m_vocName;
205     //string m_resPath;
206
207     string m_annotation_path;
208     string m_image_path;
209     string m_imageset_path;
210     string m_class_imageset_path;
211
212     vector<string> m_classifier_gt_all_ids;
213     vector<char> m_classifier_gt_all_present;
214     string m_classifier_gt_class;
215
216     //data members
217     string m_train_set;
218     string m_test_set;
219
220     vector<string> m_object_classes;
221
222
223     float m_min_overlap;
224     bool m_sampled_ap;
225 };
226
227
228 //Return the classification ground truth data for all images of a given VOC object class
229 //--------------------------------------------------------------------------------------
230 //INPUTS:
231 // - obj_class          The VOC object class identifier string
232 // - dataset            Specifies whether to extract images from the training or test set
233 //OUTPUTS:
234 // - images             An array of ObdImage containing info of all images extracted from the ground truth file
235 // - object_present     An array of bools specifying whether the object defined by 'obj_class' is present in each image or not
236 //NOTES:
237 // This function is primarily useful for the classification task, where only
238 // whether a given object is present or not in an image is required, and not each object instance's
239 // position etc.
240 void VocData::getClassImages(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<char>& object_present)
241 {
242     string dataset_str;
243     //generate the filename of the classification ground-truth textfile for the object class
244     if (dataset == CV_OBD_TRAIN)
245     {
246         dataset_str = m_train_set;
247     } else {
248         dataset_str = m_test_set;
249     }
250
251     getClassImages_impl(obj_class, dataset_str, images, object_present);
252 }
253
254 void VocData::getClassImages_impl(const string& obj_class, const string& dataset_str, vector<ObdImage>& images, vector<char>& object_present)
255 {
256     //generate the filename of the classification ground-truth textfile for the object class
257     string gtFilename = m_class_imageset_path;
258     gtFilename.replace(gtFilename.find("%s"),2,obj_class);
259     gtFilename.replace(gtFilename.find("%s"),2,dataset_str);
260
261     //parse the ground truth file, storing in two separate vectors
262     //for the image code and the ground truth value
263     vector<string> image_codes;
264     readClassifierGroundTruth(gtFilename, image_codes, object_present);
265
266     //prepare output arrays
267     images.clear();
268
269     convertImageCodesToObdImages(image_codes, images);
270 }
271
272 //Return the object data for all images of a given VOC object class
273 //-----------------------------------------------------------------
274 //INPUTS:
275 // - obj_class          The VOC object class identifier string
276 // - dataset            Specifies whether to extract images from the training or test set
277 //OUTPUTS:
278 // - images             An array of ObdImage containing info of all images in chosen dataset (tag, path etc.)
279 // - objects            Contains the extended object info (bounding box etc.) for each object instance in each image
280 // - object_data        Contains VOC-specific extended object info (marked difficult etc.)
281 // - ground_truth       Specifies whether there are any difficult/non-difficult instances of the current
282 //                          object class within each image
283 //NOTES:
284 // This function returns extended object information in addition to the absent/present
285 // classification data returned by getClassImages. The objects returned for each image in the 'objects'
286 // array are of all object classes present in the image, and not just the class defined by 'obj_class'.
287 // 'ground_truth' can be used to determine quickly whether an object instance of the given class is present
288 // in an image or not.
289 void VocData::getClassObjects(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<vector<ObdObject> >& objects)
290 {
291     vector<vector<VocObjectData> > object_data;
292     vector<VocGT> ground_truth;
293
294     getClassObjects(obj_class,dataset,images,objects,object_data,ground_truth);
295 }
296
297 void VocData::getClassObjects(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<vector<ObdObject> >& objects, vector<vector<VocObjectData> >& object_data, vector<VocGT>& ground_truth)
298 {
299     //generate the filename of the classification ground-truth textfile for the object class
300     string gtFilename = m_class_imageset_path;
301     gtFilename.replace(gtFilename.find("%s"),2,obj_class);
302     if (dataset == CV_OBD_TRAIN)
303     {
304         gtFilename.replace(gtFilename.find("%s"),2,m_train_set);
305     } else {
306         gtFilename.replace(gtFilename.find("%s"),2,m_test_set);
307     }
308
309     //parse the ground truth file, storing in two separate vectors
310     //for the image code and the ground truth value
311     vector<string> image_codes;
312     vector<char> object_present;
313     readClassifierGroundTruth(gtFilename, image_codes, object_present);
314
315     //prepare output arrays
316     images.clear();
317     objects.clear();
318     object_data.clear();
319     ground_truth.clear();
320
321     string annotationFilename;
322     vector<ObdObject> image_objects;
323     vector<VocObjectData> image_object_data;
324     VocGT image_gt;
325
326     //transfer to output arrays and read in object data for each image
327     for (size_t i = 0; i < image_codes.size(); ++i)
328     {
329         ObdImage image = getObjects(obj_class, image_codes[i], image_objects, image_object_data, image_gt);
330
331         images.push_back(image);
332         objects.push_back(image_objects);
333         object_data.push_back(image_object_data);
334         ground_truth.push_back(image_gt);
335     }
336 }
337
338 //Return ground truth data for the objects present in an image with a given UID
339 //-----------------------------------------------------------------------------
340 //INPUTS:
341 // - id                 VOC Dataset unique identifier (string code in form YYYY_XXXXXX where YYYY is the year)
342 //OUTPUTS:
343 // - obj_class (*3)     Specifies the object class to use to resolve 'ground_truth'
344 // - objects            Contains the extended object info (bounding box etc.) for each object in the image
345 // - object_data (*2,3) Contains VOC-specific extended object info (marked difficult etc.)
346 // - ground_truth (*3)  Specifies whether there are any difficult/non-difficult instances of the current
347 //                          object class within the image
348 //RETURN VALUE:
349 // ObdImage containing path and other details of image file with given code
350 //NOTES:
351 // There are three versions of this function
352 //  * One returns a simple array of objects given an id [1]
353 //  * One returns the same as (1) plus VOC specific object data [2]
354 //  * One returns the same as (2) plus the ground_truth flag. This also requires an extra input obj_class [3]
355 ObdImage VocData::getObjects(const string& id, vector<ObdObject>& objects)
356 {
357     vector<VocObjectData> object_data;
358     ObdImage image = getObjects(id, objects, object_data);
359
360     return image;
361 }
362
363 ObdImage VocData::getObjects(const string& id, vector<ObdObject>& objects, vector<VocObjectData>& object_data)
364 {
365     //first generate the filename of the annotation file
366     string annotationFilename = m_annotation_path;
367
368     annotationFilename.replace(annotationFilename.find("%s"),2,id);
369
370     //extract objects contained in the current image from the xml
371     extractVocObjects(annotationFilename,objects,object_data);
372
373     //generate image path from extracted string code
374     string path = getImagePath(id);
375
376     ObdImage image(id, path);
377     return image;
378 }
379
380 ObdImage VocData::getObjects(const string& obj_class, const string& id, vector<ObdObject>& objects, vector<VocObjectData>& object_data, VocGT& ground_truth)
381 {
382
383     //extract object data (except for ground truth flag)
384     ObdImage image = getObjects(id,objects,object_data);
385
386     //pregenerate a flag to indicate whether the current class is present or not in the image
387     ground_truth = CV_VOC_GT_NONE;
388     //iterate through all objects in current image
389     for (size_t j = 0; j < objects.size(); ++j)
390     {
391         if (objects[j].object_class == obj_class)
392         {
393             if (object_data[j].difficult == false)
394             {
395                 //if at least one non-difficult example is present, this flag is always set to CV_VOC_GT_PRESENT
396                 ground_truth = CV_VOC_GT_PRESENT;
397                 break;
398             } else {
399                 //set if at least one object instance is present, but it is marked difficult
400                 ground_truth = CV_VOC_GT_DIFFICULT;
401             }
402         }
403     }
404
405     return image;
406 }
407
408 //Return ground truth data for the presence/absence of a given object class in an arbitrary array of images
409 //---------------------------------------------------------------------------------------------------------
410 //INPUTS:
411 // - obj_class          The VOC object class identifier string
412 // - images             An array of ObdImage OR strings containing the images for which ground truth
413 //                          will be computed
414 //OUTPUTS:
415 // - ground_truth       An output array indicating the presence/absence of obj_class within each image
416 void VocData::getClassifierGroundTruth(const string& obj_class, const vector<ObdImage>& images, vector<char>& ground_truth)
417 {
418     vector<char>(images.size()).swap(ground_truth);
419
420     vector<ObdObject> objects;
421     vector<VocObjectData> object_data;
422     vector<char>::iterator gt_it = ground_truth.begin();
423     for (vector<ObdImage>::const_iterator it = images.begin(); it != images.end(); ++it, ++gt_it)
424     {
425         //getObjects(obj_class, it->id, objects, object_data, voc_ground_truth);
426         (*gt_it) = (getClassifierGroundTruthImage(obj_class, it->id));
427     }
428 }
429
430 void VocData::getClassifierGroundTruth(const string& obj_class, const vector<string>& images, vector<char>& ground_truth)
431 {
432     vector<char>(images.size()).swap(ground_truth);
433
434     vector<ObdObject> objects;
435     vector<VocObjectData> object_data;
436     vector<char>::iterator gt_it = ground_truth.begin();
437     for (vector<string>::const_iterator it = images.begin(); it != images.end(); ++it, ++gt_it)
438     {
439         //getObjects(obj_class, (*it), objects, object_data, voc_ground_truth);
440         (*gt_it) = (getClassifierGroundTruthImage(obj_class, (*it)));
441     }
442 }
443
444 //Return ground truth data for the accuracy of detection results
445 //--------------------------------------------------------------
446 //INPUTS:
447 // - obj_class          The VOC object class identifier string
448 // - images             An array of ObdImage containing the images for which ground truth
449 //                          will be computed
450 // - bounding_boxes     A 2D input array containing the bounding box rects of the objects of
451 //                          obj_class which were detected in each image
452 //OUTPUTS:
453 // - ground_truth       A 2D output array indicating whether each object detection was accurate
454 //                          or not
455 // - detection_difficult A 2D output array indicating whether the detection fired on an object
456 //                          marked as 'difficult'. This allows it to be ignored if necessary
457 //                          (the voc documentation specifies objects marked as difficult
458 //                          have no effects on the results and are effectively ignored)
459 // - (ignore_difficult) If set to true, objects marked as difficult will be ignored when returning
460 //                          the number of hits for p-r normalization (default = true)
461 //RETURN VALUE:
462 //                      Returns the number of object hits in total in the gt to allow proper normalization
463 //                          of a p-r curve
464 //NOTES:
465 // As stated in the VOC documentation, multiple detections of the same object in an image are
466 // considered FALSE detections e.g. 5 detections of a single object is counted as 1 correct
467 // detection and 4 false detections - it is the responsibility of the participant's system
468 // to filter multiple detections from its output
469 int VocData::getDetectorGroundTruth(const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<vector<Rect> >& bounding_boxes, const vector<vector<float> >& scores, vector<vector<char> >& ground_truth, vector<vector<char> >& detection_difficult, bool ignore_difficult)
470 {
471     int recall_normalization = 0;
472
473     /* first create a list of indices referring to the elements of bounding_boxes and scores in
474      * descending order of scores */
475     vector<ObdScoreIndexSorter> sorted_ids;
476     {
477         /* first count how many objects to allow preallocation */
478         size_t obj_count = 0;
479         CV_Assert(images.size() == bounding_boxes.size());
480         CV_Assert(scores.size() == bounding_boxes.size());
481         for (size_t im_idx = 0; im_idx < scores.size(); ++im_idx)
482         {
483             CV_Assert(scores[im_idx].size() == bounding_boxes[im_idx].size());
484             obj_count += scores[im_idx].size();
485         }
486         /* preallocate id vector */
487         sorted_ids.resize(obj_count);
488         /* now copy across scores and indexes to preallocated vector */
489         int flat_pos = 0;
490         for (size_t im_idx = 0; im_idx < scores.size(); ++im_idx)
491         {
492             for (size_t ob_idx = 0; ob_idx < scores[im_idx].size(); ++ob_idx)
493             {
494                 sorted_ids[flat_pos].score = scores[im_idx][ob_idx];
495                 sorted_ids[flat_pos].image_idx = (int)im_idx;
496                 sorted_ids[flat_pos].obj_idx = (int)ob_idx;
497                 ++flat_pos;
498             }
499         }
500         /* and sort the vector in descending order of score */
501         std::sort(sorted_ids.begin(),sorted_ids.end());
502         std::reverse(sorted_ids.begin(),sorted_ids.end());
503     }
504
505     /* prepare ground truth + difficult vector (1st dimension) */
506     vector<vector<char> >(images.size()).swap(ground_truth);
507     vector<vector<char> >(images.size()).swap(detection_difficult);
508     vector<vector<char> > detected(images.size());
509
510     vector<vector<ObdObject> > img_objects(images.size());
511     vector<vector<VocObjectData> > img_object_data(images.size());
512     /* preload object ground truth bounding box data */
513     {
514         vector<vector<ObdObject> > img_objects_all(images.size());
515         vector<vector<VocObjectData> > img_object_data_all(images.size());
516         for (size_t image_idx = 0; image_idx < images.size(); ++image_idx)
517         {
518             /* prepopulate ground truth bounding boxes */
519             getObjects(images[image_idx].id, img_objects_all[image_idx], img_object_data_all[image_idx]);
520             /* meanwhile, also set length of target ground truth + difficult vector to same as number of object detections (2nd dimension) */
521             ground_truth[image_idx].resize(bounding_boxes[image_idx].size());
522             detection_difficult[image_idx].resize(bounding_boxes[image_idx].size());
523         }
524
525         /* save only instances of the object class concerned */
526         for (size_t image_idx = 0; image_idx < images.size(); ++image_idx)
527         {
528             for (size_t obj_idx = 0; obj_idx < img_objects_all[image_idx].size(); ++obj_idx)
529             {
530                 if (img_objects_all[image_idx][obj_idx].object_class == obj_class)
531                 {
532                     img_objects[image_idx].push_back(img_objects_all[image_idx][obj_idx]);
533                     img_object_data[image_idx].push_back(img_object_data_all[image_idx][obj_idx]);
534                 }
535             }
536             detected[image_idx].resize(img_objects[image_idx].size(), false);
537         }
538     }
539
540     /* calculate the total number of objects in the ground truth for the current dataset */
541     {
542         vector<ObdImage> gt_images;
543         vector<char> gt_object_present;
544         getClassImages(obj_class, dataset, gt_images, gt_object_present);
545
546         for (size_t image_idx = 0; image_idx < gt_images.size(); ++image_idx)
547         {
548             vector<ObdObject> gt_img_objects;
549             vector<VocObjectData> gt_img_object_data;
550             getObjects(gt_images[image_idx].id, gt_img_objects, gt_img_object_data);
551             for (size_t obj_idx = 0; obj_idx < gt_img_objects.size(); ++obj_idx)
552             {
553                 if (gt_img_objects[obj_idx].object_class == obj_class)
554                 {
555                     if ((gt_img_object_data[obj_idx].difficult == false) || (ignore_difficult == false))
556                         ++recall_normalization;
557                 }
558             }
559         }
560     }
561
562 #ifdef PR_DEBUG
563     int printed_count = 0;
564 #endif
565     /* now iterate through detections in descending order of score, assigning to ground truth bounding boxes if possible */
566     for (size_t detect_idx = 0; detect_idx < sorted_ids.size(); ++detect_idx)
567     {
568         //read in indexes to make following code easier to read
569         int im_idx = sorted_ids[detect_idx].image_idx;
570         int ob_idx = sorted_ids[detect_idx].obj_idx;
571         //set ground truth for the current object to false by default
572         ground_truth[im_idx][ob_idx] = false;
573         detection_difficult[im_idx][ob_idx] = false;
574         float maxov = -1.0;
575         bool max_is_difficult = false;
576         int max_gt_obj_idx = -1;
577         //-- for each detected object iterate through objects present in the bounding box ground truth --
578         for (size_t gt_obj_idx = 0; gt_obj_idx < img_objects[im_idx].size(); ++gt_obj_idx)
579         {
580             if (detected[im_idx][gt_obj_idx] == false)
581             {
582                 //check if the detected object and ground truth object overlap by a sufficient margin
583                 float ov = testBoundingBoxesForOverlap(bounding_boxes[im_idx][ob_idx], img_objects[im_idx][gt_obj_idx].boundingBox);
584                 if (ov != -1.0)
585                 {
586                     //if all conditions are met store the overlap score and index (as objects are assigned to the highest scoring match)
587                     if (ov > maxov)
588                     {
589                         maxov = ov;
590                         max_gt_obj_idx = (int)gt_obj_idx;
591                         //store whether the maximum detection is marked as difficult or not
592                         max_is_difficult = (img_object_data[im_idx][gt_obj_idx].difficult);
593                     }
594                 }
595             }
596         }
597         //-- if a match was found, set the ground truth of the current object to true --
598         if (maxov != -1.0)
599         {
600             CV_Assert(max_gt_obj_idx != -1);
601             ground_truth[im_idx][ob_idx] = true;
602             //store whether the maximum detection was marked as 'difficult' or not
603             detection_difficult[im_idx][ob_idx] = max_is_difficult;
604             //remove the ground truth object so it doesn't match with subsequent detected objects
605             //** this is the behaviour defined by the voc documentation **
606             detected[im_idx][max_gt_obj_idx] = true;
607         }
608 #ifdef PR_DEBUG
609         if (printed_count < 10)
610         {
611             cout << printed_count << ": id=" << images[im_idx].id << ", score=" << scores[im_idx][ob_idx] << " (" << ob_idx << ") [" << bounding_boxes[im_idx][ob_idx].x << "," <<
612                     bounding_boxes[im_idx][ob_idx].y << "," << bounding_boxes[im_idx][ob_idx].width + bounding_boxes[im_idx][ob_idx].x <<
613                     "," << bounding_boxes[im_idx][ob_idx].height + bounding_boxes[im_idx][ob_idx].y << "] detected=" << ground_truth[im_idx][ob_idx] <<
614                     ", difficult=" << detection_difficult[im_idx][ob_idx] << endl;
615             ++printed_count;
616             /* print ground truth */
617             for (int gt_obj_idx = 0; gt_obj_idx < img_objects[im_idx].size(); ++gt_obj_idx)
618             {
619                 cout << "    GT: [" << img_objects[im_idx][gt_obj_idx].boundingBox.x << "," <<
620                         img_objects[im_idx][gt_obj_idx].boundingBox.y << "," << img_objects[im_idx][gt_obj_idx].boundingBox.width + img_objects[im_idx][gt_obj_idx].boundingBox.x <<
621                         "," << img_objects[im_idx][gt_obj_idx].boundingBox.height + img_objects[im_idx][gt_obj_idx].boundingBox.y << "]";
622                 if (gt_obj_idx == max_gt_obj_idx) cout << " <--- (" << maxov << " overlap)";
623                 cout << endl;
624             }
625         }
626 #endif
627     }
628
629     return recall_normalization;
630 }
631
632 //Write VOC-compliant classifier results file
633 //-------------------------------------------
634 //INPUTS:
635 // - obj_class          The VOC object class identifier string
636 // - dataset            Specifies whether working with the training or test set
637 // - images             An array of ObdImage containing the images for which data will be saved to the result file
638 // - scores             A corresponding array of confidence scores given a query
639 // - (competition)      If specified, defines which competition the results are for (see VOC documentation - default 1)
640 //NOTES:
641 // The result file path and filename are determined automatically using m_results_directory as a base
642 void VocData::writeClassifierResultsFile( const string& out_dir, const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<float>& scores, const int competition, const bool overwrite_ifexists)
643 {
644     CV_Assert(images.size() == scores.size());
645
646     string output_file_base, output_file;
647     if (dataset == CV_OBD_TRAIN)
648     {
649         output_file_base = out_dir + "/comp" + integerToString(competition) + "_cls_" + m_train_set + "_" + obj_class;
650     } else {
651         output_file_base = out_dir + "/comp" + integerToString(competition) + "_cls_" + m_test_set + "_" + obj_class;
652     }
653     output_file = output_file_base + ".txt";
654
655     //check if file exists, and if so create a numbered new file instead
656     if (overwrite_ifexists == false)
657     {
658         struct stat stFileInfo;
659         if (stat(output_file.c_str(),&stFileInfo) == 0)
660         {
661             string output_file_new;
662             int filenum = 0;
663             do
664             {
665                 ++filenum;
666                 output_file_new = output_file_base + "_" + integerToString(filenum);
667                 output_file = output_file_new + ".txt";
668             } while (stat(output_file.c_str(),&stFileInfo) == 0);
669         }
670     }
671
672     //output data to file
673     std::ofstream result_file(output_file.c_str());
674     if (result_file.is_open())
675     {
676         for (size_t i = 0; i < images.size(); ++i)
677         {
678             result_file << images[i].id << " " << scores[i] << endl;
679         }
680         result_file.close();
681     } else {
682         string err_msg = "could not open classifier results file '" + output_file + "' for writing. Before running for the first time, a 'results' subdirectory should be created within the VOC dataset base directory. e.g. if the VOC data is stored in /VOC/VOC2010 then the path /VOC/results must be created.";
683         CV_Error(CV_StsError,err_msg.c_str());
684     }
685 }
686
687 //---------------------------------------
688 //CALCULATE METRICS FROM VOC RESULTS DATA
689 //---------------------------------------
690
691 //Utility function to construct a VOC-standard classification results filename
692 //----------------------------------------------------------------------------
693 //INPUTS:
694 // - obj_class          The VOC object class identifier string
695 // - task               Specifies whether to generate a filename for the classification or detection task
696 // - dataset            Specifies whether working with the training or test set
697 // - (competition)      If specified, defines which competition the results are for (see VOC documentation
698 //                      default of -1 means this is set to 1 for the classification task and 3 for the detection task)
699 // - (number)           If specified and above 0, defines which of a number of duplicate results file produced for a given set of
700 //                      of settings should be used (this number will be added as a postfix to the filename)
701 //NOTES:
702 // This is primarily useful for returning the filename of a classification file previously computed using writeClassifierResultsFile
703 // for example when calling calcClassifierPrecRecall
704 string VocData::getResultsFilename(const string& obj_class, const VocTask task, const ObdDatasetType dataset, const int competition, const int number)
705 {
706     if ((competition < 1) && (competition != -1))
707         CV_Error(CV_StsBadArg,"competition argument should be a positive non-zero number or -1 to accept the default");
708     if ((number < 1) && (number != -1))
709         CV_Error(CV_StsBadArg,"number argument should be a positive non-zero number or -1 to accept the default");
710
711     string dset, task_type;
712
713     if (dataset == CV_OBD_TRAIN)
714     {
715         dset = m_train_set;
716     } else {
717         dset = m_test_set;
718     }
719
720     int comp = competition;
721     if (task == CV_VOC_TASK_CLASSIFICATION)
722     {
723         task_type = "cls";
724         if (comp == -1) comp = 1;
725     } else {
726         task_type = "det";
727         if (comp == -1) comp = 3;
728     }
729
730     stringstream ss;
731     if (number < 1)
732     {
733         ss << "comp" << comp << "_" << task_type << "_" << dset << "_" << obj_class << ".txt";
734     } else {
735         ss << "comp" << comp << "_" << task_type << "_" << dset << "_" << obj_class << "_" << number << ".txt";
736     }
737
738     string filename = ss.str();
739     return filename;
740 }
741
742 //Calculate metrics for classification results
743 //--------------------------------------------
744 //INPUTS:
745 // - ground_truth       A vector of booleans determining whether the currently tested class is present in each input image
746 // - scores             A vector containing the similarity score for each input image (higher is more similar)
747 //OUTPUTS:
748 // - precision          A vector containing the precision calculated at each datapoint of a p-r curve generated from the result set
749 // - recall             A vector containing the recall calculated at each datapoint of a p-r curve generated from the result set
750 // - ap                The ap metric calculated from the result set
751 // - (ranking)          A vector of the same length as 'ground_truth' and 'scores' containing the order of the indices in both of
752 //                      these arrays when sorting by the ranking score in descending order
753 //NOTES:
754 // The result file path and filename are determined automatically using m_results_directory as a base
755 void VocData::calcClassifierPrecRecall(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap, vector<size_t>& ranking)
756 {
757     vector<char> res_ground_truth;
758     getClassifierGroundTruth(obj_class, images, res_ground_truth);
759
760     calcPrecRecall_impl(res_ground_truth, scores, precision, recall, ap, ranking);
761 }
762
763 void VocData::calcClassifierPrecRecall(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap)
764 {
765     vector<char> res_ground_truth;
766     getClassifierGroundTruth(obj_class, images, res_ground_truth);
767
768     vector<size_t> ranking;
769     calcPrecRecall_impl(res_ground_truth, scores, precision, recall, ap, ranking);
770 }
771
772 //< Overloaded version which accepts VOC classification result file input instead of array of scores/ground truth >
773 //INPUTS:
774 // - input_file         The path to the VOC standard results file to use for calculating precision/recall
775 //                      If a full path is not specified, it is assumed this file is in the VOC standard results directory
776 //                      A VOC standard filename can be retrieved (as used by writeClassifierResultsFile) by calling  getClassifierResultsFilename
777
778 void VocData::calcClassifierPrecRecall(const string& input_file, vector<float>& precision, vector<float>& recall, float& ap, bool outputRankingFile)
779 {
780     //read in classification results file
781     vector<string> res_image_codes;
782     vector<float> res_scores;
783
784     string input_file_std = checkFilenamePathsep(input_file);
785     readClassifierResultsFile(input_file_std, res_image_codes, res_scores);
786
787     //extract the object class and dataset from the results file filename
788     string class_name, dataset_name;
789     extractDataFromResultsFilename(input_file_std, class_name, dataset_name);
790
791     //generate the ground truth for the images extracted from the results file
792     vector<char> res_ground_truth;
793
794     getClassifierGroundTruth(class_name, res_image_codes, res_ground_truth);
795
796     if (outputRankingFile)
797     {
798         /* 1. store sorting order by score (descending) in 'order' */
799         vector<std::pair<size_t, vector<float>::const_iterator> > order(res_scores.size());
800
801         size_t n = 0;
802         for (vector<float>::const_iterator it = res_scores.begin(); it != res_scores.end(); ++it, ++n)
803             order[n] = make_pair(n, it);
804
805         std::sort(order.begin(),order.end(),orderingSorter());
806
807         /* 2. save ranking results to text file */
808         string input_file_std1 = checkFilenamePathsep(input_file);
809         size_t fnamestart = input_file_std1.rfind("/");
810         string scoregt_file_str = input_file_std1.substr(0,fnamestart+1) + "scoregt_" + class_name + ".txt";
811         std::ofstream scoregt_file(scoregt_file_str.c_str());
812         if (scoregt_file.is_open())
813         {
814             for (size_t i = 0; i < res_scores.size(); ++i)
815             {
816                 scoregt_file << res_image_codes[order[i].first] << " " << res_scores[order[i].first] << " " << res_ground_truth[order[i].first] << endl;
817             }
818             scoregt_file.close();
819         } else {
820             string err_msg = "could not open scoregt file '" + scoregt_file_str + "' for writing.";
821             CV_Error(CV_StsError,err_msg.c_str());
822         }
823     }
824
825     //finally, calculate precision+recall+ap
826     vector<size_t> ranking;
827     calcPrecRecall_impl(res_ground_truth,res_scores,precision,recall,ap,ranking);
828 }
829
830 //< Protected implementation of Precision-Recall calculation used by both calcClassifierPrecRecall and calcDetectorPrecRecall >
831
832 void VocData::calcPrecRecall_impl(const vector<char>& ground_truth, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap, vector<size_t>& ranking, int recall_normalization)
833 {
834     CV_Assert(ground_truth.size() == scores.size());
835
836     //add extra element for p-r at 0 recall (in case that first retrieved is positive)
837     vector<float>(scores.size()+1).swap(precision);
838     vector<float>(scores.size()+1).swap(recall);
839
840     // SORT RESULTS BY THEIR SCORE
841     /* 1. store sorting order in 'order' */
842     VocData::getSortOrder(scores, ranking);
843
844 #ifdef PR_DEBUG
845     std::ofstream scoregt_file("D:/pr.txt");
846     if (scoregt_file.is_open())
847     {
848        for (int i = 0; i < scores.size(); ++i)
849        {
850            scoregt_file << scores[ranking[i]] << " " << ground_truth[ranking[i]] << endl;
851        }
852        scoregt_file.close();
853     }
854 #endif
855
856     // CALCULATE PRECISION+RECALL
857
858     int retrieved_hits = 0;
859
860     int recall_norm;
861     if (recall_normalization != -1)
862     {
863         recall_norm = recall_normalization;
864     } else {
865         recall_norm = (int)std::count_if(ground_truth.begin(),ground_truth.end(),std::bind2nd(std::equal_to<char>(),(char)1));
866     }
867
868     ap = 0;
869     recall[0] = 0;
870     for (size_t idx = 0; idx < ground_truth.size(); ++idx)
871     {
872         if (ground_truth[ranking[idx]] != 0) ++retrieved_hits;
873
874         precision[idx+1] = static_cast<float>(retrieved_hits)/static_cast<float>(idx+1);
875         recall[idx+1] = static_cast<float>(retrieved_hits)/static_cast<float>(recall_norm);
876
877         if (idx == 0)
878         {
879             //add further point at 0 recall with the same precision value as the first computed point
880             precision[idx] = precision[idx+1];
881         }
882         if (recall[idx+1] == 1.0)
883         {
884             //if recall = 1, then end early as all positive images have been found
885             recall.resize(idx+2);
886             precision.resize(idx+2);
887             break;
888         }
889     }
890
891     /* ap calculation */
892     if (m_sampled_ap == false)
893     {
894         // FOR VOC2010+ AP IS CALCULATED FROM ALL DATAPOINTS
895         /* make precision monotonically decreasing for purposes of calculating ap */
896         vector<float> precision_monot(precision.size());
897         vector<float>::iterator prec_m_it = precision_monot.begin();
898         for (vector<float>::iterator prec_it = precision.begin(); prec_it != precision.end(); ++prec_it, ++prec_m_it)
899         {
900             vector<float>::iterator max_elem;
901             max_elem = std::max_element(prec_it,precision.end());
902             (*prec_m_it) = (*max_elem);
903         }
904         /* calculate ap */
905         for (size_t idx = 0; idx < (recall.size()-1); ++idx)
906         {
907             ap += (recall[idx+1] - recall[idx])*precision_monot[idx+1] +   //no need to take min of prec - is monotonically decreasing
908                     0.5f*(recall[idx+1] - recall[idx])*std::abs(precision_monot[idx+1] - precision_monot[idx]);
909         }
910     } else {
911         // FOR BEFORE VOC2010 AP IS CALCULATED BY SAMPLING PRECISION AT RECALL 0.0,0.1,..,1.0
912
913         for (float recall_pos = 0.f; recall_pos <= 1.f; recall_pos += 0.1f)
914         {
915             //find iterator of the precision corresponding to the first recall >= recall_pos
916             vector<float>::iterator recall_it = recall.begin();
917             vector<float>::iterator prec_it = precision.begin();
918
919             while ((*recall_it) < recall_pos)
920             {
921                 ++recall_it;
922                 ++prec_it;
923                 if (recall_it == recall.end()) break;
924             }
925
926             /* if no recall >= recall_pos found, this level of recall is never reached so stop adding to ap */
927             if (recall_it == recall.end()) break;
928
929             /* if the prec_it is valid, compute the max precision at this level of recall or higher */
930             vector<float>::iterator max_prec = std::max_element(prec_it,precision.end());
931
932             ap += (*max_prec)/11;
933         }
934     }
935 }
936
937 /* functions for calculating confusion matrix rows */
938
939 //Calculate rows of a confusion matrix
940 //------------------------------------
941 //INPUTS:
942 // - obj_class          The VOC object class identifier string for the confusion matrix row to compute
943 // - images             An array of ObdImage containing the images to use for the computation
944 // - scores             A corresponding array of confidence scores for the presence of obj_class in each image
945 // - cond               Defines whether to use a cut off point based on recall (CV_VOC_CCOND_RECALL) or score
946 //                      (CV_VOC_CCOND_SCORETHRESH) the latter is useful for classifier detections where positive
947 //                      values are positive detections and negative values are negative detections
948 // - threshold          Threshold value for cond. In case of CV_VOC_CCOND_RECALL, is proportion recall (e.g. 0.5).
949 //                      In the case of CV_VOC_CCOND_SCORETHRESH is the value above which to count results.
950 //OUTPUTS:
951 // - output_headers     An output vector of object class headers for the confusion matrix row
952 // - output_values      An output vector of values for the confusion matrix row corresponding to the classes
953 //                      defined in output_headers
954 //NOTES:
955 // The methodology used by the classifier version of this function is that true positives have a single unit
956 // added to the obj_class column in the confusion matrix row, whereas false positives have a single unit
957 // distributed in proportion between all the columns in the confusion matrix row corresponding to the objects
958 // present in the image.
959 void VocData::calcClassifierConfMatRow(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, const VocConfCond cond, const float threshold, vector<string>& output_headers, vector<float>& output_values)
960 {
961     CV_Assert(images.size() == scores.size());
962
963     // SORT RESULTS BY THEIR SCORE
964     /* 1. store sorting order in 'ranking' */
965     vector<size_t> ranking;
966     VocData::getSortOrder(scores, ranking);
967
968     // CALCULATE CONFUSION MATRIX ENTRIES
969     /* prepare object category headers */
970     output_headers = m_object_classes;
971     vector<float>(output_headers.size(),0.0).swap(output_values);
972     /* find the index of the target object class in the headers for later use */
973     int target_idx;
974     {
975         vector<string>::iterator target_idx_it = std::find(output_headers.begin(),output_headers.end(),obj_class);
976         /* if the target class can not be found, raise an exception */
977         if (target_idx_it == output_headers.end())
978         {
979             string err_msg = "could not find the target object class '" + obj_class + "' in list of valid classes.";
980             CV_Error(CV_StsError,err_msg.c_str());
981         }
982         /* convert iterator to index */
983         target_idx = (int)std::distance(output_headers.begin(),target_idx_it);
984     }
985
986     /* prepare variables related to calculating recall if using the recall threshold */
987     int retrieved_hits = 0;
988     int total_relevant = 0;
989     if (cond == CV_VOC_CCOND_RECALL)
990     {
991         vector<char> ground_truth;
992         /* in order to calculate the total number of relevant images for normalization of recall
993             it's necessary to extract the ground truth for the images under consideration */
994         getClassifierGroundTruth(obj_class, images, ground_truth);
995         total_relevant = (int)std::count_if(ground_truth.begin(),ground_truth.end(),std::bind2nd(std::equal_to<char>(),(char)1));
996     }
997
998     /* iterate through images */
999     vector<ObdObject> img_objects;
1000     vector<VocObjectData> img_object_data;
1001     int total_images = 0;
1002     for (size_t image_idx = 0; image_idx < images.size(); ++image_idx)
1003     {
1004         /* if using the score as the break condition, check for it now */
1005         if (cond == CV_VOC_CCOND_SCORETHRESH)
1006         {
1007             if (scores[ranking[image_idx]] <= threshold) break;
1008         }
1009         /* if continuing for this iteration, increment the image counter for later normalization */
1010         ++total_images;
1011         /* for each image retrieve the objects contained */
1012         getObjects(images[ranking[image_idx]].id, img_objects, img_object_data);
1013         //check if the tested for object class is present
1014         if (getClassifierGroundTruthImage(obj_class, images[ranking[image_idx]].id))
1015         {
1016             //if the target class is present, assign fully to the target class element in the confusion matrix row
1017             output_values[target_idx] += 1.0;
1018             if (cond == CV_VOC_CCOND_RECALL) ++retrieved_hits;
1019         } else {
1020             //first delete all objects marked as difficult
1021             for (size_t obj_idx = 0; obj_idx < img_objects.size(); ++obj_idx)
1022             {
1023                 if (img_object_data[obj_idx].difficult == true)
1024                 {
1025                     vector<ObdObject>::iterator it1 = img_objects.begin();
1026                     std::advance(it1,obj_idx);
1027                     img_objects.erase(it1);
1028                     vector<VocObjectData>::iterator it2 = img_object_data.begin();
1029                     std::advance(it2,obj_idx);
1030                     img_object_data.erase(it2);
1031                     --obj_idx;
1032                 }
1033             }
1034             //if the target class is not present, add values to the confusion matrix row in equal proportions to all objects present in the image
1035             for (size_t obj_idx = 0; obj_idx < img_objects.size(); ++obj_idx)
1036             {
1037                 //find the index of the currently considered object
1038                 vector<string>::iterator class_idx_it = std::find(output_headers.begin(),output_headers.end(),img_objects[obj_idx].object_class);
1039                 //if the class name extracted from the ground truth file could not be found in the list of available classes, raise an exception
1040                 if (class_idx_it == output_headers.end())
1041                 {
1042                     string err_msg = "could not find object class '" + img_objects[obj_idx].object_class + "' specified in the ground truth file of '" + images[ranking[image_idx]].id +"'in list of valid classes.";
1043                     CV_Error(CV_StsError,err_msg.c_str());
1044                 }
1045                 /* convert iterator to index */
1046                 int class_idx = (int)std::distance(output_headers.begin(),class_idx_it);
1047                 //add to confusion matrix row in proportion
1048                 output_values[class_idx] += 1.f/static_cast<float>(img_objects.size());
1049             }
1050         }
1051         //check break conditions if breaking on certain level of recall
1052         if (cond == CV_VOC_CCOND_RECALL)
1053         {
1054             if(static_cast<float>(retrieved_hits)/static_cast<float>(total_relevant) >= threshold) break;
1055         }
1056     }
1057     /* finally, normalize confusion matrix row */
1058     for (vector<float>::iterator it = output_values.begin(); it < output_values.end(); ++it)
1059     {
1060         (*it) /= static_cast<float>(total_images);
1061     }
1062 }
1063
1064 // NOTE: doesn't ignore repeated detections
1065 void VocData::calcDetectorConfMatRow(const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<vector<float> >& scores, const vector<vector<Rect> >& bounding_boxes, const VocConfCond cond, const float threshold, vector<string>& output_headers, vector<float>& output_values, bool ignore_difficult)
1066 {
1067     CV_Assert(images.size() == scores.size());
1068     CV_Assert(images.size() == bounding_boxes.size());
1069
1070     //collapse scores and ground_truth vectors into 1D vectors to allow ranking
1071     /* define final flat vectors */
1072     vector<string> images_flat;
1073     vector<float> scores_flat;
1074     vector<Rect> bounding_boxes_flat;
1075     {
1076         /* first count how many objects to allow preallocation */
1077         int obj_count = 0;
1078         CV_Assert(scores.size() == bounding_boxes.size());
1079         for (size_t img_idx = 0; img_idx < scores.size(); ++img_idx)
1080         {
1081             CV_Assert(scores[img_idx].size() == bounding_boxes[img_idx].size());
1082             for (size_t obj_idx = 0; obj_idx < scores[img_idx].size(); ++obj_idx)
1083             {
1084                 ++obj_count;
1085             }
1086         }
1087         /* preallocate vectors */
1088         images_flat.resize(obj_count);
1089         scores_flat.resize(obj_count);
1090         bounding_boxes_flat.resize(obj_count);
1091         /* now copy across to preallocated vectors */
1092         int flat_pos = 0;
1093         for (size_t img_idx = 0; img_idx < scores.size(); ++img_idx)
1094         {
1095             for (size_t obj_idx = 0; obj_idx < scores[img_idx].size(); ++obj_idx)
1096             {
1097                 images_flat[flat_pos] = images[img_idx].id;
1098                 scores_flat[flat_pos] = scores[img_idx][obj_idx];
1099                 bounding_boxes_flat[flat_pos] = bounding_boxes[img_idx][obj_idx];
1100                 ++flat_pos;
1101             }
1102         }
1103     }
1104
1105     // SORT RESULTS BY THEIR SCORE
1106     /* 1. store sorting order in 'ranking' */
1107     vector<size_t> ranking;
1108     VocData::getSortOrder(scores_flat, ranking);
1109
1110     // CALCULATE CONFUSION MATRIX ENTRIES
1111     /* prepare object category headers */
1112     output_headers = m_object_classes;
1113     output_headers.push_back("background");
1114     vector<float>(output_headers.size(),0.0).swap(output_values);
1115
1116     /* prepare variables related to calculating recall if using the recall threshold */
1117     int retrieved_hits = 0;
1118     int total_relevant = 0;
1119     if (cond == CV_VOC_CCOND_RECALL)
1120     {
1121 //        vector<char> ground_truth;
1122 //        /* in order to calculate the total number of relevant images for normalization of recall
1123 //            it's necessary to extract the ground truth for the images under consideration */
1124 //        getClassifierGroundTruth(obj_class, images, ground_truth);
1125 //        total_relevant = std::count_if(ground_truth.begin(),ground_truth.end(),std::bind2nd(std::equal_to<bool>(),true));
1126         /* calculate the total number of objects in the ground truth for the current dataset */
1127         vector<ObdImage> gt_images;
1128         vector<char> gt_object_present;
1129         getClassImages(obj_class, dataset, gt_images, gt_object_present);
1130
1131         for (size_t image_idx = 0; image_idx < gt_images.size(); ++image_idx)
1132         {
1133             vector<ObdObject> gt_img_objects;
1134             vector<VocObjectData> gt_img_object_data;
1135             getObjects(gt_images[image_idx].id, gt_img_objects, gt_img_object_data);
1136             for (size_t obj_idx = 0; obj_idx < gt_img_objects.size(); ++obj_idx)
1137             {
1138                 if (gt_img_objects[obj_idx].object_class == obj_class)
1139                 {
1140                     if ((gt_img_object_data[obj_idx].difficult == false) || (ignore_difficult == false))
1141                         ++total_relevant;
1142                 }
1143             }
1144         }
1145     }
1146
1147     /* iterate through objects */
1148     vector<ObdObject> img_objects;
1149     vector<VocObjectData> img_object_data;
1150     int total_objects = 0;
1151     for (size_t image_idx = 0; image_idx < images.size(); ++image_idx)
1152     {
1153         /* if using the score as the break condition, check for it now */
1154         if (cond == CV_VOC_CCOND_SCORETHRESH)
1155         {
1156             if (scores_flat[ranking[image_idx]] <= threshold) break;
1157         }
1158         /* increment the image counter for later normalization */
1159         ++total_objects;
1160         /* for each image retrieve the objects contained */
1161         getObjects(images[ranking[image_idx]].id, img_objects, img_object_data);
1162
1163         //find the ground truth object which has the highest overlap score with the detected object
1164         float maxov = -1.0;
1165         int max_gt_obj_idx = -1;
1166         //-- for each detected object iterate through objects present in ground truth --
1167         for (size_t gt_obj_idx = 0; gt_obj_idx < img_objects.size(); ++gt_obj_idx)
1168         {
1169             //check difficulty flag
1170             if (ignore_difficult || (img_object_data[gt_obj_idx].difficult == false))
1171             {
1172                 //if the class matches, then check if the detected object and ground truth object overlap by a sufficient margin
1173                 float ov = testBoundingBoxesForOverlap(bounding_boxes_flat[ranking[image_idx]], img_objects[gt_obj_idx].boundingBox);
1174                 if (ov != -1.f)
1175                 {
1176                     //if all conditions are met store the overlap score and index (as objects are assigned to the highest scoring match)
1177                     if (ov > maxov)
1178                     {
1179                         maxov = ov;
1180                         max_gt_obj_idx = (int)gt_obj_idx;
1181                     }
1182                 }
1183             }
1184         }
1185
1186         //assign to appropriate object class if an object was detected
1187         if (maxov != -1.0)
1188         {
1189             //find the index of the currently considered object
1190             vector<string>::iterator class_idx_it = std::find(output_headers.begin(),output_headers.end(),img_objects[max_gt_obj_idx].object_class);
1191             //if the class name extracted from the ground truth file could not be found in the list of available classes, raise an exception
1192             if (class_idx_it == output_headers.end())
1193             {
1194                 string err_msg = "could not find object class '" + img_objects[max_gt_obj_idx].object_class + "' specified in the ground truth file of '" + images[ranking[image_idx]].id +"'in list of valid classes.";
1195                 CV_Error(CV_StsError,err_msg.c_str());
1196             }
1197             /* convert iterator to index */
1198             int class_idx = (int)std::distance(output_headers.begin(),class_idx_it);
1199             //add to confusion matrix row in proportion
1200             output_values[class_idx] += 1.0;
1201         } else {
1202             //otherwise assign to background class
1203             output_values[output_values.size()-1] += 1.0;
1204         }
1205
1206         //check break conditions if breaking on certain level of recall
1207         if (cond == CV_VOC_CCOND_RECALL)
1208         {
1209             if(static_cast<float>(retrieved_hits)/static_cast<float>(total_relevant) >= threshold) break;
1210         }
1211     }
1212
1213     /* finally, normalize confusion matrix row */
1214     for (vector<float>::iterator it = output_values.begin(); it < output_values.end(); ++it)
1215     {
1216         (*it) /= static_cast<float>(total_objects);
1217     }
1218 }
1219
1220 //Save Precision-Recall results to a p-r curve in GNUPlot format
1221 //--------------------------------------------------------------
1222 //INPUTS:
1223 // - output_file        The file to which to save the GNUPlot data file. If only a filename is specified, the data
1224 //                      file is saved to the standard VOC results directory.
1225 // - precision          Vector of precisions as returned from calcClassifier/DetectorPrecRecall
1226 // - recall             Vector of recalls as returned from calcClassifier/DetectorPrecRecall
1227 // - ap                ap as returned from calcClassifier/DetectorPrecRecall
1228 // - (title)            Title to use for the plot (if not specified, just the ap is printed as the title)
1229 //                      This also specifies the filename of the output file if printing to pdf
1230 // - (plot_type)        Specifies whether to instruct GNUPlot to save to a PDF file (CV_VOC_PLOT_PDF) or directly
1231 //                      to screen (CV_VOC_PLOT_SCREEN) in the datafile
1232 //NOTES:
1233 // The GNUPlot data file can be executed using GNUPlot from the commandline in the following way:
1234 //      >> GNUPlot <output_file>
1235 // This will then display the p-r curve on the screen or save it to a pdf file depending on plot_type
1236
1237 void VocData::savePrecRecallToGnuplot(const string& output_file, const vector<float>& precision, const vector<float>& recall, const float ap, const string title, const VocPlotType plot_type)
1238 {
1239     string output_file_std = checkFilenamePathsep(output_file);
1240
1241     //if no directory is specified, by default save the output file in the results directory
1242 //    if (output_file_std.find("/") == output_file_std.npos)
1243 //    {
1244 //        output_file_std = m_results_directory + output_file_std;
1245 //    }
1246
1247     std::ofstream plot_file(output_file_std.c_str());
1248
1249     if (plot_file.is_open())
1250     {
1251         plot_file << "set xrange [0:1]" << endl;
1252         plot_file << "set yrange [0:1]" << endl;
1253         plot_file << "set size square" << endl;
1254         string title_text = title;
1255         if (title_text.size() == 0) title_text = "Precision-Recall Curve";
1256         plot_file << "set title \"" << title_text << " (ap: " << ap << ")\"" << endl;
1257         plot_file << "set xlabel \"Recall\"" << endl;
1258         plot_file << "set ylabel \"Precision\"" << endl;
1259         plot_file << "set style data lines" << endl;
1260         plot_file << "set nokey" << endl;
1261         if (plot_type == CV_VOC_PLOT_PNG)
1262         {
1263             plot_file << "set terminal png" << endl;
1264             string pdf_filename;
1265             if (title.size() != 0)
1266             {
1267                 pdf_filename = title;
1268             } else {
1269                 pdf_filename = "prcurve";
1270             }
1271             plot_file << "set out \"" << title << ".png\"" << endl;
1272         }
1273         plot_file << "plot \"-\" using 1:2" << endl;
1274         plot_file << "# X Y" << endl;
1275         CV_Assert(precision.size() == recall.size());
1276         for (size_t i = 0; i < precision.size(); ++i)
1277         {
1278             plot_file << "  " << recall[i] << " " << precision[i] << endl;
1279         }
1280         plot_file << "end" << endl;
1281         if (plot_type == CV_VOC_PLOT_SCREEN)
1282         {
1283             plot_file << "pause -1" << endl;
1284         }
1285         plot_file.close();
1286     } else {
1287         string err_msg = "could not open plot file '" + output_file_std + "' for writing.";
1288         CV_Error(CV_StsError,err_msg.c_str());
1289     }
1290 }
1291
1292 void VocData::readClassifierGroundTruth(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<char>& object_present)
1293 {
1294     images.clear();
1295
1296     string gtFilename = m_class_imageset_path;
1297     gtFilename.replace(gtFilename.find("%s"),2,obj_class);
1298     if (dataset == CV_OBD_TRAIN)
1299     {
1300         gtFilename.replace(gtFilename.find("%s"),2,m_train_set);
1301     } else {
1302         gtFilename.replace(gtFilename.find("%s"),2,m_test_set);
1303     }
1304
1305     vector<string> image_codes;
1306     readClassifierGroundTruth(gtFilename, image_codes, object_present);
1307
1308     convertImageCodesToObdImages(image_codes, images);
1309 }
1310
1311 void VocData::readClassifierResultsFile(const std:: string& input_file, vector<ObdImage>& images, vector<float>& scores)
1312 {
1313     images.clear();
1314
1315     string input_file_std = checkFilenamePathsep(input_file);
1316
1317     //if no directory is specified, by default search for the input file in the results directory
1318 //    if (input_file_std.find("/") == input_file_std.npos)
1319 //    {
1320 //        input_file_std = m_results_directory + input_file_std;
1321 //    }
1322
1323     vector<string> image_codes;
1324     readClassifierResultsFile(input_file_std, image_codes, scores);
1325
1326     convertImageCodesToObdImages(image_codes, images);
1327 }
1328
1329 void VocData::readDetectorResultsFile(const string& input_file, vector<ObdImage>& images, vector<vector<float> >& scores, vector<vector<Rect> >& bounding_boxes)
1330 {
1331     images.clear();
1332
1333     string input_file_std = checkFilenamePathsep(input_file);
1334
1335     //if no directory is specified, by default search for the input file in the results directory
1336 //    if (input_file_std.find("/") == input_file_std.npos)
1337 //    {
1338 //        input_file_std = m_results_directory + input_file_std;
1339 //    }
1340
1341     vector<string> image_codes;
1342     readDetectorResultsFile(input_file_std, image_codes, scores, bounding_boxes);
1343
1344     convertImageCodesToObdImages(image_codes, images);
1345 }
1346
1347 const vector<string>& VocData::getObjectClasses()
1348 {
1349     return m_object_classes;
1350 }
1351
1352 //string VocData::getResultsDirectory()
1353 //{
1354 //    return m_results_directory;
1355 //}
1356
1357 //---------------------------------------------------------
1358 // Protected Functions ------------------------------------
1359 //---------------------------------------------------------
1360
1361 static string getVocName( const string& vocPath )
1362 {
1363     size_t found = vocPath.rfind( '/' );
1364     if( found == string::npos )
1365     {
1366         found = vocPath.rfind( '\\' );
1367         if( found == string::npos )
1368             return vocPath;
1369     }
1370     return vocPath.substr(found + 1, vocPath.size() - found);
1371 }
1372
1373 void VocData::initVoc( const string& vocPath, const bool useTestDataset )
1374 {
1375     initVoc2007to2010( vocPath, useTestDataset );
1376 }
1377
1378 //Initialize file paths and settings for the VOC 2010 dataset
1379 //-----------------------------------------------------------
1380 void VocData::initVoc2007to2010( const string& vocPath, const bool useTestDataset )
1381 {
1382     //check format of root directory and modify if necessary
1383
1384     m_vocName = getVocName( vocPath );
1385
1386     CV_Assert( !m_vocName.compare("VOC2007") || !m_vocName.compare("VOC2008") ||
1387                !m_vocName.compare("VOC2009") || !m_vocName.compare("VOC2010") );
1388
1389     m_vocPath = checkFilenamePathsep( vocPath, true );
1390
1391     if (useTestDataset)
1392     {
1393         m_train_set = "trainval";
1394         m_test_set = "test";
1395     } else {
1396         m_train_set = "train";
1397         m_test_set = "val";
1398     }
1399
1400     // initialize main classification/detection challenge paths
1401     m_annotation_path = m_vocPath + "/Annotations/%s.xml";
1402     m_image_path = m_vocPath + "/JPEGImages/%s.jpg";
1403     m_imageset_path = m_vocPath + "/ImageSets/Main/%s.txt";
1404     m_class_imageset_path = m_vocPath + "/ImageSets/Main/%s_%s.txt";
1405
1406     //define available object_classes for VOC2010 dataset
1407     m_object_classes.push_back("aeroplane");
1408     m_object_classes.push_back("bicycle");
1409     m_object_classes.push_back("bird");
1410     m_object_classes.push_back("boat");
1411     m_object_classes.push_back("bottle");
1412     m_object_classes.push_back("bus");
1413     m_object_classes.push_back("car");
1414     m_object_classes.push_back("cat");
1415     m_object_classes.push_back("chair");
1416     m_object_classes.push_back("cow");
1417     m_object_classes.push_back("diningtable");
1418     m_object_classes.push_back("dog");
1419     m_object_classes.push_back("horse");
1420     m_object_classes.push_back("motorbike");
1421     m_object_classes.push_back("person");
1422     m_object_classes.push_back("pottedplant");
1423     m_object_classes.push_back("sheep");
1424     m_object_classes.push_back("sofa");
1425     m_object_classes.push_back("train");
1426     m_object_classes.push_back("tvmonitor");
1427
1428     m_min_overlap = 0.5;
1429
1430     //up until VOC 2010, ap was calculated by sampling p-r curve, not taking complete curve
1431     m_sampled_ap = ((m_vocName == "VOC2007") || (m_vocName == "VOC2008") || (m_vocName == "VOC2009"));
1432 }
1433
1434 //Read a VOC classification ground truth text file for a given object class and dataset
1435 //-------------------------------------------------------------------------------------
1436 //INPUTS:
1437 // - filename           The path of the text file to read
1438 //OUTPUTS:
1439 // - image_codes        VOC image codes extracted from the GT file in the form 20XX_XXXXXX where the first four
1440 //                          digits specify the year of the dataset, and the last group specifies a unique ID
1441 // - object_present     For each image in the 'image_codes' array, specifies whether the object class described
1442 //                          in the loaded GT file is present or not
1443 void VocData::readClassifierGroundTruth(const string& filename, vector<string>& image_codes, vector<char>& object_present)
1444 {
1445     image_codes.clear();
1446     object_present.clear();
1447
1448     std::ifstream gtfile(filename.c_str());
1449     if (!gtfile.is_open())
1450     {
1451         string err_msg = "could not open VOC ground truth textfile '" + filename + "'.";
1452         CV_Error(CV_StsError,err_msg.c_str());
1453     }
1454
1455     string line;
1456     string image;
1457     int obj_present = 0;
1458     while (!gtfile.eof())
1459     {
1460         std::getline(gtfile,line);
1461         std::istringstream iss(line);
1462         iss >> image >> obj_present;
1463         if (!iss.fail())
1464         {
1465             image_codes.push_back(image);
1466             object_present.push_back(obj_present == 1);
1467         } else {
1468             if (!gtfile.eof()) CV_Error(CV_StsParseError,"error parsing VOC ground truth textfile.");
1469         }
1470     }
1471     gtfile.close();
1472 }
1473
1474 void VocData::readClassifierResultsFile(const string& input_file, vector<string>& image_codes, vector<float>& scores)
1475 {
1476     //check if results file exists
1477     std::ifstream result_file(input_file.c_str());
1478     if (result_file.is_open())
1479     {
1480         string line;
1481         string image;
1482         float score;
1483         //read in the results file
1484         while (!result_file.eof())
1485         {
1486             std::getline(result_file,line);
1487             std::istringstream iss(line);
1488             iss >> image >> score;
1489             if (!iss.fail())
1490             {
1491                 image_codes.push_back(image);
1492                 scores.push_back(score);
1493             } else {
1494                 if(!result_file.eof()) CV_Error(CV_StsParseError,"error parsing VOC classifier results file.");
1495             }
1496         }
1497         result_file.close();
1498     } else {
1499         string err_msg = "could not open classifier results file '" + input_file + "' for reading.";
1500         CV_Error(CV_StsError,err_msg.c_str());
1501     }
1502 }
1503
1504 void VocData::readDetectorResultsFile(const string& input_file, vector<string>& image_codes, vector<vector<float> >& scores, vector<vector<Rect> >& bounding_boxes)
1505 {
1506     image_codes.clear();
1507     scores.clear();
1508     bounding_boxes.clear();
1509
1510     //check if results file exists
1511     std::ifstream result_file(input_file.c_str());
1512     if (result_file.is_open())
1513     {
1514         string line;
1515         string image;
1516         Rect bounding_box;
1517         float score;
1518         //read in the results file
1519         while (!result_file.eof())
1520         {
1521             std::getline(result_file,line);
1522             std::istringstream iss(line);
1523             iss >> image >> score >> bounding_box.x >> bounding_box.y >> bounding_box.width >> bounding_box.height;
1524             if (!iss.fail())
1525             {
1526                 //convert right and bottom positions to width and height
1527                 bounding_box.width -= bounding_box.x;
1528                 bounding_box.height -= bounding_box.y;
1529                 //convert to 0-indexing
1530                 bounding_box.x -= 1;
1531                 bounding_box.y -= 1;
1532                 //store in output vectors
1533                 /* first check if the current image code has been seen before */
1534                 vector<string>::iterator image_codes_it = std::find(image_codes.begin(),image_codes.end(),image);
1535                 if (image_codes_it == image_codes.end())
1536                 {
1537                     image_codes.push_back(image);
1538                     vector<float> score_vect(1);
1539                     score_vect[0] = score;
1540                     scores.push_back(score_vect);
1541                     vector<Rect> bounding_box_vect(1);
1542                     bounding_box_vect[0] = bounding_box;
1543                     bounding_boxes.push_back(bounding_box_vect);
1544                 } else {
1545                     /* if the image index has been seen before, add the current object below it in the 2D arrays */
1546                     int image_idx = (int)std::distance(image_codes.begin(),image_codes_it);
1547                     scores[image_idx].push_back(score);
1548                     bounding_boxes[image_idx].push_back(bounding_box);
1549                 }
1550             } else {
1551                 if(!result_file.eof()) CV_Error(CV_StsParseError,"error parsing VOC detector results file.");
1552             }
1553         }
1554         result_file.close();
1555     } else {
1556         string err_msg = "could not open detector results file '" + input_file + "' for reading.";
1557         CV_Error(CV_StsError,err_msg.c_str());
1558     }
1559 }
1560
1561
1562 //Read a VOC annotation xml file for a given image
1563 //------------------------------------------------
1564 //INPUTS:
1565 // - filename           The path of the xml file to read
1566 //OUTPUTS:
1567 // - objects            Array of VocObject describing all object instances present in the given image
1568 void VocData::extractVocObjects(const string filename, vector<ObdObject>& objects, vector<VocObjectData>& object_data)
1569 {
1570 #ifdef PR_DEBUG
1571     int block = 1;
1572     cout << "SAMPLE VOC OBJECT EXTRACTION for " << filename << ":" << endl;
1573 #endif
1574     objects.clear();
1575     object_data.clear();
1576
1577     string contents, object_contents, tag_contents;
1578
1579     readFileToString(filename, contents);
1580
1581     //keep on extracting 'object' blocks until no more can be found
1582     if (extractXMLBlock(contents, "annotation", 0, contents) != -1)
1583     {
1584         int searchpos = 0;
1585         searchpos = extractXMLBlock(contents, "object", searchpos, object_contents);
1586         while (searchpos != -1)
1587         {
1588 #ifdef PR_DEBUG
1589             cout << "SEARCHPOS:" << searchpos << endl;
1590             cout << "start block " << block << " ---------" << endl;
1591             cout << object_contents << endl;
1592             cout << "end block " << block << " -----------" << endl;
1593             ++block;
1594 #endif
1595
1596             ObdObject object;
1597             VocObjectData object_d;
1598
1599             //object class -------------
1600
1601             if (extractXMLBlock(object_contents, "name", 0, tag_contents) == -1) CV_Error(CV_StsError,"missing <name> tag in object definition of '" + filename + "'");
1602             object.object_class.swap(tag_contents);
1603
1604             //object bounding box -------------
1605
1606             int xmax, xmin, ymax, ymin;
1607
1608             if (extractXMLBlock(object_contents, "xmax", 0, tag_contents) == -1) CV_Error(CV_StsError,"missing <xmax> tag in object definition of '" + filename + "'");
1609             xmax = stringToInteger(tag_contents);
1610
1611             if (extractXMLBlock(object_contents, "xmin", 0, tag_contents) == -1) CV_Error(CV_StsError,"missing <xmin> tag in object definition of '" + filename + "'");
1612             xmin = stringToInteger(tag_contents);
1613
1614             if (extractXMLBlock(object_contents, "ymax", 0, tag_contents) == -1) CV_Error(CV_StsError,"missing <ymax> tag in object definition of '" + filename + "'");
1615             ymax = stringToInteger(tag_contents);
1616
1617             if (extractXMLBlock(object_contents, "ymin", 0, tag_contents) == -1) CV_Error(CV_StsError,"missing <ymin> tag in object definition of '" + filename + "'");
1618             ymin = stringToInteger(tag_contents);
1619
1620             object.boundingBox.x = xmin-1;      //convert to 0-based indexing
1621             object.boundingBox.width = xmax - xmin;
1622             object.boundingBox.y = ymin-1;
1623             object.boundingBox.height = ymax - ymin;
1624
1625             CV_Assert(xmin != 0);
1626             CV_Assert(xmax > xmin);
1627             CV_Assert(ymin != 0);
1628             CV_Assert(ymax > ymin);
1629
1630
1631             //object tags -------------
1632
1633             if (extractXMLBlock(object_contents, "difficult", 0, tag_contents) != -1)
1634             {
1635                 object_d.difficult = (tag_contents == "1");
1636             } else object_d.difficult = false;
1637             if (extractXMLBlock(object_contents, "occluded", 0, tag_contents) != -1)
1638             {
1639                 object_d.occluded = (tag_contents == "1");
1640             } else object_d.occluded = false;
1641             if (extractXMLBlock(object_contents, "truncated", 0, tag_contents) != -1)
1642             {
1643                 object_d.truncated = (tag_contents == "1");
1644             } else object_d.truncated = false;
1645             if (extractXMLBlock(object_contents, "pose", 0, tag_contents) != -1)
1646             {
1647                 if (tag_contents == "Frontal") object_d.pose = CV_VOC_POSE_FRONTAL;
1648                 if (tag_contents == "Rear") object_d.pose = CV_VOC_POSE_REAR;
1649                 if (tag_contents == "Left") object_d.pose = CV_VOC_POSE_LEFT;
1650                 if (tag_contents == "Right") object_d.pose = CV_VOC_POSE_RIGHT;
1651             }
1652
1653             //add to array of objects
1654             objects.push_back(object);
1655             object_data.push_back(object_d);
1656
1657             //extract next 'object' block from file if it exists
1658             searchpos = extractXMLBlock(contents, "object", searchpos, object_contents);
1659         }
1660     }
1661 }
1662
1663 //Converts an image identifier string in the format YYYY_XXXXXX to a single index integer of form XXXXXXYYYY
1664 //where Y represents a year and returns the image path
1665 //----------------------------------------------------------------------------------------------------------
1666 string VocData::getImagePath(const string& input_str)
1667 {
1668     string path = m_image_path;
1669     path.replace(path.find("%s"),2,input_str);
1670     return path;
1671 }
1672
1673 //Tests two boundary boxes for overlap (using the intersection over union metric) and returns the overlap if the objects
1674 //defined by the two bounding boxes are considered to be matched according to the criterion outlined in
1675 //the VOC documentation [namely intersection/union > some threshold] otherwise returns -1.0 (no match)
1676 //----------------------------------------------------------------------------------------------------------
1677 float VocData::testBoundingBoxesForOverlap(const Rect detection, const Rect ground_truth)
1678 {
1679     int detection_x2 = detection.x + detection.width;
1680     int detection_y2 = detection.y + detection.height;
1681     int ground_truth_x2 = ground_truth.x + ground_truth.width;
1682     int ground_truth_y2 = ground_truth.y + ground_truth.height;
1683     //first calculate the boundaries of the intersection of the rectangles
1684     int intersection_x = std::max(detection.x, ground_truth.x); //rightmost left
1685     int intersection_y = std::max(detection.y, ground_truth.y); //bottommost top
1686     int intersection_x2 = std::min(detection_x2, ground_truth_x2); //leftmost right
1687     int intersection_y2 = std::min(detection_y2, ground_truth_y2); //topmost bottom
1688     //then calculate the width and height of the intersection rect
1689     int intersection_width = intersection_x2 - intersection_x + 1;
1690     int intersection_height = intersection_y2 - intersection_y + 1;
1691     //if there is no overlap then return false straight away
1692     if ((intersection_width <= 0) || (intersection_height <= 0)) return -1.0;
1693     //otherwise calculate the intersection
1694     int intersection_area = intersection_width*intersection_height;
1695
1696     //now calculate the union
1697     int union_area = (detection.width+1)*(detection.height+1) + (ground_truth.width+1)*(ground_truth.height+1) - intersection_area;
1698
1699     //calculate the intersection over union and use as threshold as per VOC documentation
1700     float overlap = static_cast<float>(intersection_area)/static_cast<float>(union_area);
1701     if (overlap > m_min_overlap)
1702     {
1703         return overlap;
1704     } else {
1705         return -1.0;
1706     }
1707 }
1708
1709 //Extracts the object class and dataset from the filename of a VOC standard results text file, which takes
1710 //the format 'comp<n>_{cls/det}_<dataset>_<objclass>.txt'
1711 //----------------------------------------------------------------------------------------------------------
1712 void VocData::extractDataFromResultsFilename(const string& input_file, string& class_name, string& dataset_name)
1713 {
1714     string input_file_std = checkFilenamePathsep(input_file);
1715
1716     size_t fnamestart = input_file_std.rfind("/");
1717     size_t fnameend = input_file_std.rfind(".txt");
1718
1719     if ((fnamestart == input_file_std.npos) || (fnameend == input_file_std.npos))
1720         CV_Error(CV_StsError,"Could not extract filename of results file.");
1721
1722     ++fnamestart;
1723     if (fnamestart >= fnameend)
1724         CV_Error(CV_StsError,"Could not extract filename of results file.");
1725
1726     //extract dataset and class names, triggering exception if the filename format is not correct
1727     string filename = input_file_std.substr(fnamestart, fnameend-fnamestart);
1728     size_t datasetstart = filename.find("_");
1729     datasetstart = filename.find("_",datasetstart+1);
1730     size_t classstart = filename.find("_",datasetstart+1);
1731     //allow for appended index after a further '_' by discarding this part if it exists
1732     size_t classend = filename.find("_",classstart+1);
1733     if (classend == filename.npos) classend = filename.size();
1734     if ((datasetstart == filename.npos) || (classstart == filename.npos))
1735         CV_Error(CV_StsError,"Error parsing results filename. Is it in standard format of 'comp<n>_{cls/det}_<dataset>_<objclass>.txt'?");
1736     ++datasetstart;
1737     ++classstart;
1738     if (((datasetstart-classstart) < 1) || ((classend-datasetstart) < 1))
1739         CV_Error(CV_StsError,"Error parsing results filename. Is it in standard format of 'comp<n>_{cls/det}_<dataset>_<objclass>.txt'?");
1740
1741     dataset_name = filename.substr(datasetstart,classstart-datasetstart-1);
1742     class_name = filename.substr(classstart,classend-classstart);
1743 }
1744
1745 bool VocData::getClassifierGroundTruthImage(const string& obj_class, const string& id)
1746 {
1747     /* if the classifier ground truth data for all images of the current class has not been loaded yet, load it now */
1748     if (m_classifier_gt_all_ids.empty() || (m_classifier_gt_class != obj_class))
1749     {
1750         m_classifier_gt_all_ids.clear();
1751         m_classifier_gt_all_present.clear();
1752         m_classifier_gt_class = obj_class;
1753         for (int i=0; i<2; ++i) //run twice (once over test set and once over training set)
1754         {
1755             //generate the filename of the classification ground-truth textfile for the object class
1756             string gtFilename = m_class_imageset_path;
1757             gtFilename.replace(gtFilename.find("%s"),2,obj_class);
1758             if (i == 0)
1759             {
1760                 gtFilename.replace(gtFilename.find("%s"),2,m_train_set);
1761             } else {
1762                 gtFilename.replace(gtFilename.find("%s"),2,m_test_set);
1763             }
1764
1765             //parse the ground truth file, storing in two separate vectors
1766             //for the image code and the ground truth value
1767             vector<string> image_codes;
1768             vector<char> object_present;
1769             readClassifierGroundTruth(gtFilename, image_codes, object_present);
1770
1771             m_classifier_gt_all_ids.insert(m_classifier_gt_all_ids.end(),image_codes.begin(),image_codes.end());
1772             m_classifier_gt_all_present.insert(m_classifier_gt_all_present.end(),object_present.begin(),object_present.end());
1773
1774             CV_Assert(m_classifier_gt_all_ids.size() == m_classifier_gt_all_present.size());
1775         }
1776     }
1777
1778
1779     //search for the image code
1780     vector<string>::iterator it = find (m_classifier_gt_all_ids.begin(), m_classifier_gt_all_ids.end(), id);
1781     if (it != m_classifier_gt_all_ids.end())
1782     {
1783         //image found, so return corresponding ground truth
1784         return m_classifier_gt_all_present[std::distance(m_classifier_gt_all_ids.begin(),it)] != 0;
1785     } else {
1786         string err_msg = "could not find classifier ground truth for image '" + id + "' and class '" + obj_class + "'";
1787         CV_Error(CV_StsError,err_msg.c_str());
1788     }
1789
1790     return true;
1791 }
1792
1793 //-------------------------------------------------------------------
1794 // Protected Functions (utility) ------------------------------------
1795 //-------------------------------------------------------------------
1796
1797 //returns a vector containing indexes of the input vector in sorted ascending/descending order
1798 void VocData::getSortOrder(const vector<float>& values, vector<size_t>& order, bool descending)
1799 {
1800     /* 1. store sorting order in 'order_pair' */
1801     vector<std::pair<size_t, vector<float>::const_iterator> > order_pair(values.size());
1802
1803     size_t n = 0;
1804     for (vector<float>::const_iterator it = values.begin(); it != values.end(); ++it, ++n)
1805         order_pair[n] = make_pair(n, it);
1806
1807     std::sort(order_pair.begin(),order_pair.end(),orderingSorter());
1808     if (descending == false) std::reverse(order_pair.begin(),order_pair.end());
1809
1810     vector<size_t>(order_pair.size()).swap(order);
1811     for (size_t i = 0; i < order_pair.size(); ++i)
1812     {
1813         order[i] = order_pair[i].first;
1814     }
1815 }
1816
1817 void VocData::readFileToString(const string filename, string& file_contents)
1818 {
1819     std::ifstream ifs(filename.c_str());
1820     if (!ifs.is_open()) CV_Error(CV_StsError,"could not open text file");
1821
1822     stringstream oss;
1823     oss << ifs.rdbuf();
1824
1825     file_contents = oss.str();
1826 }
1827
1828 int VocData::stringToInteger(const string input_str)
1829 {
1830     int result = 0;
1831
1832     stringstream ss(input_str);
1833     if ((ss >> result).fail())
1834     {
1835         CV_Error(CV_StsBadArg,"could not perform string to integer conversion");
1836     }
1837     return result;
1838 }
1839
1840 string VocData::integerToString(const int input_int)
1841 {
1842     string result;
1843
1844     stringstream ss;
1845     if ((ss << input_int).fail())
1846     {
1847         CV_Error(CV_StsBadArg,"could not perform integer to string conversion");
1848     }
1849     result = ss.str();
1850     return result;
1851 }
1852
1853 string VocData::checkFilenamePathsep( const string filename, bool add_trailing_slash )
1854 {
1855     string filename_new = filename;
1856
1857     size_t pos = filename_new.find("\\\\");
1858     while (pos != filename_new.npos)
1859     {
1860         filename_new.replace(pos,2,"/");
1861         pos = filename_new.find("\\\\", pos);
1862     }
1863     pos = filename_new.find("\\");
1864     while (pos != filename_new.npos)
1865     {
1866         filename_new.replace(pos,1,"/");
1867         pos = filename_new.find("\\", pos);
1868     }
1869     if (add_trailing_slash)
1870     {
1871         //add training slash if this is missing
1872         if (filename_new.rfind("/") != filename_new.length()-1) filename_new += "/";
1873     }
1874
1875     return filename_new;
1876 }
1877
1878 void VocData::convertImageCodesToObdImages(const vector<string>& image_codes, vector<ObdImage>& images)
1879 {
1880     images.clear();
1881     images.reserve(image_codes.size());
1882
1883     string path;
1884     //transfer to output arrays
1885     for (size_t i = 0; i < image_codes.size(); ++i)
1886     {
1887         //generate image path and indices from extracted string code
1888         path = getImagePath(image_codes[i]);
1889         images.push_back(ObdImage(image_codes[i], path));
1890     }
1891 }
1892
1893 //Extract text from within a given tag from an XML file
1894 //-----------------------------------------------------
1895 //INPUTS:
1896 // - src            XML source file
1897 // - tag            XML tag delimiting block to extract
1898 // - searchpos      position within src at which to start search
1899 //OUTPUTS:
1900 // - tag_contents   text extracted between <tag> and </tag> tags
1901 //RETURN VALUE:
1902 // - the position of the final character extracted in tag_contents within src
1903 //      (can be used to call extractXMLBlock recursively to extract multiple blocks)
1904 //      returns -1 if the tag could not be found
1905 int VocData::extractXMLBlock(const string src, const string tag, const int searchpos, string& tag_contents)
1906 {
1907     size_t startpos, next_startpos, endpos;
1908     int embed_count = 1;
1909
1910     //find position of opening tag
1911     startpos = src.find("<" + tag + ">", searchpos);
1912     if (startpos == string::npos) return -1;
1913
1914     //initialize endpos -
1915     // start searching for end tag anywhere after opening tag
1916     endpos = startpos;
1917
1918     //find position of next opening tag
1919     next_startpos = src.find("<" + tag + ">", startpos+1);
1920
1921     //match opening tags with closing tags, and only
1922     //accept final closing tag of same level as original
1923     //opening tag
1924     while (embed_count > 0)
1925     {
1926         endpos = src.find("</" + tag + ">", endpos+1);
1927         if (endpos == string::npos) return -1;
1928
1929         //the next code is only executed if there are embedded tags with the same name
1930         if (next_startpos != string::npos)
1931         {
1932             while (next_startpos<endpos)
1933             {
1934                 //counting embedded start tags
1935                 ++embed_count;
1936                 next_startpos = src.find("<" + tag + ">", next_startpos+1);
1937                 if (next_startpos == string::npos) break;
1938             }
1939         }
1940         //passing end tag so decrement nesting level
1941         --embed_count;
1942     }
1943
1944     //finally, extract the tag region
1945     startpos += tag.length() + 2;
1946     if (startpos > src.length()) return -1;
1947     if (endpos > src.length()) return -1;
1948     tag_contents = src.substr(startpos,endpos-startpos);
1949     return static_cast<int>(endpos);
1950 }
1951
1952 /****************************************************************************************\
1953 *                            Sample on image classification                             *
1954 \****************************************************************************************/
1955 //
1956 // This part of the code was a little refactor
1957 //
1958 struct DDMParams
1959 {
1960     DDMParams() : detectorType("SURF"), descriptorType("SURF"), matcherType("BruteForce") {}
1961     DDMParams( const string _detectorType, const string _descriptorType, const string& _matcherType ) :
1962         detectorType(_detectorType), descriptorType(_descriptorType), matcherType(_matcherType){}
1963     void read( const FileNode& fn )
1964     {
1965         fn["detectorType"] >> detectorType;
1966         fn["descriptorType"] >> descriptorType;
1967         fn["matcherType"] >> matcherType;
1968     }
1969     void write( FileStorage& fs ) const
1970     {
1971         fs << "detectorType" << detectorType;
1972         fs << "descriptorType" << descriptorType;
1973         fs << "matcherType" << matcherType;
1974     }
1975     void print() const
1976     {
1977         cout << "detectorType: " << detectorType << endl;
1978         cout << "descriptorType: " << descriptorType << endl;
1979         cout << "matcherType: " << matcherType << endl;
1980     }
1981
1982     string detectorType;
1983     string descriptorType;
1984     string matcherType;
1985 };
1986
1987 struct VocabTrainParams
1988 {
1989     VocabTrainParams() : trainObjClass("chair"), vocabSize(1000), memoryUse(200), descProportion(0.3f) {}
1990     VocabTrainParams( const string _trainObjClass, size_t _vocabSize, size_t _memoryUse, float _descProportion ) :
1991             trainObjClass(_trainObjClass), vocabSize((int)_vocabSize), memoryUse((int)_memoryUse), descProportion(_descProportion) {}
1992     void read( const FileNode& fn )
1993     {
1994         fn["trainObjClass"] >> trainObjClass;
1995         fn["vocabSize"] >> vocabSize;
1996         fn["memoryUse"] >> memoryUse;
1997         fn["descProportion"] >> descProportion;
1998     }
1999     void write( FileStorage& fs ) const
2000     {
2001         fs << "trainObjClass" << trainObjClass;
2002         fs << "vocabSize" << vocabSize;
2003         fs << "memoryUse" << memoryUse;
2004         fs << "descProportion" << descProportion;
2005     }
2006     void print() const
2007     {
2008         cout << "trainObjClass: " << trainObjClass << endl;
2009         cout << "vocabSize: " << vocabSize << endl;
2010         cout << "memoryUse: " << memoryUse << endl;
2011         cout << "descProportion: " << descProportion << endl;
2012     }
2013
2014
2015     string trainObjClass; // Object class used for training visual vocabulary.
2016                           // It shouldn't matter which object class is specified here - visual vocab will still be the same.
2017     int vocabSize; //number of visual words in vocabulary to train
2018     int memoryUse; // Memory to preallocate (in MB) when training vocab.
2019                    // Change this depending on the size of the dataset/available memory.
2020     float descProportion; // Specifies the number of descriptors to use from each image as a proportion of the total num descs.
2021 };
2022
2023 struct SVMTrainParamsExt
2024 {
2025     SVMTrainParamsExt() : descPercent(0.5f), targetRatio(0.4f), balanceClasses(true) {}
2026     SVMTrainParamsExt( float _descPercent, float _targetRatio, bool _balanceClasses ) :
2027             descPercent(_descPercent), targetRatio(_targetRatio), balanceClasses(_balanceClasses) {}
2028     void read( const FileNode& fn )
2029     {
2030         fn["descPercent"] >> descPercent;
2031         fn["targetRatio"] >> targetRatio;
2032         fn["balanceClasses"] >> balanceClasses;
2033     }
2034     void write( FileStorage& fs ) const
2035     {
2036         fs << "descPercent" << descPercent;
2037         fs << "targetRatio" << targetRatio;
2038         fs << "balanceClasses" << balanceClasses;
2039     }
2040     void print() const
2041     {
2042         cout << "descPercent: " << descPercent << endl;
2043         cout << "targetRatio: " << targetRatio << endl;
2044         cout << "balanceClasses: " << balanceClasses << endl;
2045     }
2046
2047     float descPercent; // Percentage of extracted descriptors to use for training.
2048     float targetRatio; // Try to get this ratio of positive to negative samples (minimum).
2049     bool balanceClasses;    // Balance class weights by number of samples in each (if true cSvmTrainTargetRatio is ignored).
2050 };
2051
2052 static void readUsedParams( const FileNode& fn, string& vocName, DDMParams& ddmParams, VocabTrainParams& vocabTrainParams, SVMTrainParamsExt& svmTrainParamsExt )
2053 {
2054     fn["vocName"] >> vocName;
2055
2056     FileNode currFn = fn;
2057
2058     currFn = fn["ddmParams"];
2059     ddmParams.read( currFn );
2060
2061     currFn = fn["vocabTrainParams"];
2062     vocabTrainParams.read( currFn );
2063
2064     currFn = fn["svmTrainParamsExt"];
2065     svmTrainParamsExt.read( currFn );
2066 }
2067
2068 static void writeUsedParams( FileStorage& fs, const string& vocName, const DDMParams& ddmParams, const VocabTrainParams& vocabTrainParams, const SVMTrainParamsExt& svmTrainParamsExt )
2069 {
2070     fs << "vocName" << vocName;
2071
2072     fs << "ddmParams" << "{";
2073     ddmParams.write(fs);
2074     fs << "}";
2075
2076     fs << "vocabTrainParams" << "{";
2077     vocabTrainParams.write(fs);
2078     fs << "}";
2079
2080     fs << "svmTrainParamsExt" << "{";
2081     svmTrainParamsExt.write(fs);
2082     fs << "}";
2083 }
2084
2085 static void printUsedParams( const string& vocPath, const string& resDir,
2086                       const DDMParams& ddmParams, const VocabTrainParams& vocabTrainParams,
2087                       const SVMTrainParamsExt& svmTrainParamsExt )
2088 {
2089     cout << "CURRENT CONFIGURATION" << endl;
2090     cout << "----------------------------------------------------------------" << endl;
2091     cout << "vocPath: " << vocPath << endl;
2092     cout << "resDir: " << resDir << endl;
2093     cout << endl; ddmParams.print();
2094     cout << endl; vocabTrainParams.print();
2095     cout << endl; svmTrainParamsExt.print();
2096     cout << "----------------------------------------------------------------" << endl << endl;
2097 }
2098
2099 static bool readVocabulary( const string& filename, Mat& vocabulary )
2100 {
2101     cout << "Reading vocabulary...";
2102     FileStorage fs( filename, FileStorage::READ );
2103     if( fs.isOpened() )
2104     {
2105         fs["vocabulary"] >> vocabulary;
2106         cout << "done" << endl;
2107         return true;
2108     }
2109     return false;
2110 }
2111
2112 static bool writeVocabulary( const string& filename, const Mat& vocabulary )
2113 {
2114     cout << "Saving vocabulary..." << endl;
2115     FileStorage fs( filename, FileStorage::WRITE );
2116     if( fs.isOpened() )
2117     {
2118         fs << "vocabulary" << vocabulary;
2119         return true;
2120     }
2121     return false;
2122 }
2123
2124 static Mat trainVocabulary( const string& filename, VocData& vocData, const VocabTrainParams& trainParams,
2125                      const Ptr<FeatureDetector>& fdetector, const Ptr<DescriptorExtractor>& dextractor )
2126 {
2127     Mat vocabulary;
2128     if( !readVocabulary( filename, vocabulary) )
2129     {
2130         CV_Assert( dextractor->descriptorType() == CV_32FC1 );
2131         const int elemSize = CV_ELEM_SIZE(dextractor->descriptorType());
2132         const int descByteSize = dextractor->descriptorSize() * elemSize;
2133         const int bytesInMB = 1048576;
2134         const int maxDescCount = (trainParams.memoryUse * bytesInMB) / descByteSize; // Total number of descs to use for training.
2135
2136         cout << "Extracting VOC data..." << endl;
2137         vector<ObdImage> images;
2138         vector<char> objectPresent;
2139         vocData.getClassImages( trainParams.trainObjClass, CV_OBD_TRAIN, images, objectPresent );
2140
2141         cout << "Computing descriptors..." << endl;
2142         RNG& rng = theRNG();
2143         TermCriteria terminate_criterion;
2144         terminate_criterion.epsilon = FLT_EPSILON;
2145         BOWKMeansTrainer bowTrainer( trainParams.vocabSize, terminate_criterion, 3, KMEANS_PP_CENTERS );
2146
2147         while( images.size() > 0 )
2148         {
2149             if( bowTrainer.descriptorsCount() > maxDescCount )
2150             {
2151 #ifdef DEBUG_DESC_PROGRESS
2152                 cout << "Breaking due to full memory ( descriptors count = " << bowTrainer.descriptorsCount()
2153                         << "; descriptor size in bytes = " << descByteSize << "; all used memory = "
2154                         << bowTrainer.descriptorsCount()*descByteSize << endl;
2155 #endif
2156                 break;
2157             }
2158
2159             // Randomly pick an image from the dataset which hasn't yet been seen
2160             // and compute the descriptors from that image.
2161             int randImgIdx = rng( (unsigned)images.size() );
2162             Mat colorImage = imread( images[randImgIdx].path );
2163             vector<KeyPoint> imageKeypoints;
2164             fdetector->detect( colorImage, imageKeypoints );
2165             Mat imageDescriptors;
2166             dextractor->compute( colorImage, imageKeypoints, imageDescriptors );
2167
2168             //check that there were descriptors calculated for the current image
2169             if( !imageDescriptors.empty() )
2170             {
2171                 int descCount = imageDescriptors.rows;
2172                 // Extract trainParams.descProportion descriptors from the image, breaking if the 'allDescriptors' matrix becomes full
2173                 int descsToExtract = static_cast<int>(trainParams.descProportion * static_cast<float>(descCount));
2174                 // Fill mask of used descriptors
2175                 vector<char> usedMask( descCount, false );
2176                 fill( usedMask.begin(), usedMask.begin() + descsToExtract, true );
2177                 for( int i = 0; i < descCount; i++ )
2178                 {
2179                     int i1 = rng(descCount), i2 = rng(descCount);
2180                     char tmp = usedMask[i1]; usedMask[i1] = usedMask[i2]; usedMask[i2] = tmp;
2181                 }
2182
2183                 for( int i = 0; i < descCount; i++ )
2184                 {
2185                     if( usedMask[i] && bowTrainer.descriptorsCount() < maxDescCount )
2186                         bowTrainer.add( imageDescriptors.row(i) );
2187                 }
2188             }
2189
2190 #ifdef DEBUG_DESC_PROGRESS
2191             cout << images.size() << " images left, " << images[randImgIdx].id << " processed - "
2192                     <</* descs_extracted << "/" << image_descriptors.rows << " extracted - " << */
2193                     cvRound((static_cast<double>(bowTrainer.descriptorsCount())/static_cast<double>(maxDescCount))*100.0)
2194                     << " % memory used" << ( imageDescriptors.empty() ? " -> no descriptors extracted, skipping" : "") << endl;
2195 #endif
2196
2197             // Delete the current element from images so it is not added again
2198             images.erase( images.begin() + randImgIdx );
2199         }
2200
2201         cout << "Maximum allowed descriptor count: " << maxDescCount << ", Actual descriptor count: " << bowTrainer.descriptorsCount() << endl;
2202
2203         cout << "Training vocabulary..." << endl;
2204         vocabulary = bowTrainer.cluster();
2205
2206         if( !writeVocabulary(filename, vocabulary) )
2207         {
2208             cout << "Error: file " << filename << " can not be opened to write" << endl;
2209             exit(-1);
2210         }
2211     }
2212     return vocabulary;
2213 }
2214
2215 static bool readBowImageDescriptor( const string& file, Mat& bowImageDescriptor )
2216 {
2217     FileStorage fs( file, FileStorage::READ );
2218     if( fs.isOpened() )
2219     {
2220         fs["imageDescriptor"] >> bowImageDescriptor;
2221         return true;
2222     }
2223     return false;
2224 }
2225
2226 static bool writeBowImageDescriptor( const string& file, const Mat& bowImageDescriptor )
2227 {
2228     FileStorage fs( file, FileStorage::WRITE );
2229     if( fs.isOpened() )
2230     {
2231         fs << "imageDescriptor" << bowImageDescriptor;
2232         return true;
2233     }
2234     return false;
2235 }
2236
2237 // Load in the bag of words vectors for a set of images, from file if possible
2238 static void calculateImageDescriptors( const vector<ObdImage>& images, vector<Mat>& imageDescriptors,
2239                                 Ptr<BOWImgDescriptorExtractor>& bowExtractor, const Ptr<FeatureDetector>& fdetector,
2240                                 const string& resPath )
2241 {
2242     CV_Assert( !bowExtractor->getVocabulary().empty() );
2243     imageDescriptors.resize( images.size() );
2244
2245     for( size_t i = 0; i < images.size(); i++ )
2246     {
2247         string filename = resPath + bowImageDescriptorsDir + "/" + images[i].id + ".xml.gz";
2248         if( readBowImageDescriptor( filename, imageDescriptors[i] ) )
2249         {
2250 #ifdef DEBUG_DESC_PROGRESS
2251             cout << "Loaded bag of word vector for image " << i+1 << " of " << images.size() << " (" << images[i].id << ")" << endl;
2252 #endif
2253         }
2254         else
2255         {
2256             Mat colorImage = imread( images[i].path );
2257 #ifdef DEBUG_DESC_PROGRESS
2258             cout << "Computing descriptors for image " << i+1 << " of " << images.size() << " (" << images[i].id << ")" << flush;
2259 #endif
2260             vector<KeyPoint> keypoints;
2261             fdetector->detect( colorImage, keypoints );
2262 #ifdef DEBUG_DESC_PROGRESS
2263                 cout << " + generating BoW vector" << std::flush;
2264 #endif
2265             bowExtractor->compute( colorImage, keypoints, imageDescriptors[i] );
2266 #ifdef DEBUG_DESC_PROGRESS
2267             cout << " ...DONE " << static_cast<int>(static_cast<float>(i+1)/static_cast<float>(images.size())*100.0)
2268                  << " % complete" << endl;
2269 #endif
2270             if( !imageDescriptors[i].empty() )
2271             {
2272                 if( !writeBowImageDescriptor( filename, imageDescriptors[i] ) )
2273                 {
2274                     cout << "Error: file " << filename << "can not be opened to write bow image descriptor" << endl;
2275                     exit(-1);
2276                 }
2277             }
2278         }
2279     }
2280 }
2281
2282 static void removeEmptyBowImageDescriptors( vector<ObdImage>& images, vector<Mat>& bowImageDescriptors,
2283                                      vector<char>& objectPresent )
2284 {
2285     CV_Assert( !images.empty() );
2286     for( int i = (int)images.size() - 1; i >= 0; i-- )
2287     {
2288         bool res = bowImageDescriptors[i].empty();
2289         if( res )
2290         {
2291             cout << "Removing image " << images[i].id << " due to no descriptors..." << endl;
2292             images.erase( images.begin() + i );
2293             bowImageDescriptors.erase( bowImageDescriptors.begin() + i );
2294             objectPresent.erase( objectPresent.begin() + i );
2295         }
2296     }
2297 }
2298
2299 static void removeBowImageDescriptorsByCount( vector<ObdImage>& images, vector<Mat> bowImageDescriptors, vector<char> objectPresent,
2300                                        const SVMTrainParamsExt& svmParamsExt, int descsToDelete )
2301 {
2302     RNG& rng = theRNG();
2303     int pos_ex = (int)std::count( objectPresent.begin(), objectPresent.end(), (char)1 );
2304     int neg_ex = (int)std::count( objectPresent.begin(), objectPresent.end(), (char)0 );
2305
2306     while( descsToDelete != 0 )
2307     {
2308         int randIdx = rng((unsigned)images.size());
2309
2310         // Prefer positive training examples according to svmParamsExt.targetRatio if required
2311         if( objectPresent[randIdx] )
2312         {
2313             if( (static_cast<float>(pos_ex)/static_cast<float>(neg_ex+pos_ex)  < svmParamsExt.targetRatio) &&
2314                 (neg_ex > 0) && (svmParamsExt.balanceClasses == false) )
2315             { continue; }
2316             else
2317             { pos_ex--; }
2318         }
2319         else
2320         { neg_ex--; }
2321
2322         images.erase( images.begin() + randIdx );
2323         bowImageDescriptors.erase( bowImageDescriptors.begin() + randIdx );
2324         objectPresent.erase( objectPresent.begin() + randIdx );
2325
2326         descsToDelete--;
2327     }
2328     CV_Assert( bowImageDescriptors.size() == objectPresent.size() );
2329 }
2330
2331 static void setSVMParams( CvSVMParams& svmParams, CvMat& class_wts_cv, const Mat& responses, bool balanceClasses )
2332 {
2333     int pos_ex = countNonZero(responses == 1);
2334     int neg_ex = countNonZero(responses == -1);
2335     cout << pos_ex << " positive training samples; " << neg_ex << " negative training samples" << endl;
2336
2337     svmParams.svm_type = CvSVM::C_SVC;
2338     svmParams.kernel_type = CvSVM::RBF;
2339     if( balanceClasses )
2340     {
2341         Mat class_wts( 2, 1, CV_32FC1 );
2342         // The first training sample determines the '+1' class internally, even if it is negative,
2343         // so store whether this is the case so that the class weights can be reversed accordingly.
2344         bool reversed_classes = (responses.at<float>(0) < 0.f);
2345         if( reversed_classes == false )
2346         {
2347             class_wts.at<float>(0) = static_cast<float>(pos_ex)/static_cast<float>(pos_ex+neg_ex); // weighting for costs of positive class + 1 (i.e. cost of false positive - larger gives greater cost)
2348             class_wts.at<float>(1) = static_cast<float>(neg_ex)/static_cast<float>(pos_ex+neg_ex); // weighting for costs of negative class - 1 (i.e. cost of false negative)
2349         }
2350         else
2351         {
2352             class_wts.at<float>(0) = static_cast<float>(neg_ex)/static_cast<float>(pos_ex+neg_ex);
2353             class_wts.at<float>(1) = static_cast<float>(pos_ex)/static_cast<float>(pos_ex+neg_ex);
2354         }
2355         class_wts_cv = class_wts;
2356         svmParams.class_weights = &class_wts_cv;
2357     }
2358 }
2359
2360 static void setSVMTrainAutoParams( CvParamGrid& c_grid, CvParamGrid& gamma_grid,
2361                             CvParamGrid& p_grid, CvParamGrid& nu_grid,
2362                             CvParamGrid& coef_grid, CvParamGrid& degree_grid )
2363 {
2364     c_grid = CvSVM::get_default_grid(CvSVM::C);
2365
2366     gamma_grid = CvSVM::get_default_grid(CvSVM::GAMMA);
2367
2368     p_grid = CvSVM::get_default_grid(CvSVM::P);
2369     p_grid.step = 0;
2370
2371     nu_grid = CvSVM::get_default_grid(CvSVM::NU);
2372     nu_grid.step = 0;
2373
2374     coef_grid = CvSVM::get_default_grid(CvSVM::COEF);
2375     coef_grid.step = 0;
2376
2377     degree_grid = CvSVM::get_default_grid(CvSVM::DEGREE);
2378     degree_grid.step = 0;
2379 }
2380
2381 #if defined HAVE_OPENCV_OCL && _OCL_SVM_
2382 static void trainSVMClassifier( cv::ocl::CvSVM_OCL& svm, const SVMTrainParamsExt& svmParamsExt, const string& objClassName, VocData& vocData,
2383                                Ptr<BOWImgDescriptorExtractor>& bowExtractor, const Ptr<FeatureDetector>& fdetector,
2384                                const string& resPath )
2385 #else
2386 static void trainSVMClassifier( CvSVM& svm, const SVMTrainParamsExt& svmParamsExt, const string& objClassName, VocData& vocData,
2387                          Ptr<BOWImgDescriptorExtractor>& bowExtractor, const Ptr<FeatureDetector>& fdetector,
2388                          const string& resPath )
2389 #endif
2390 {
2391     /* first check if a previously trained svm for the current class has been saved to file */
2392     string svmFilename = resPath + svmsDir + "/" + objClassName + ".xml.gz";
2393
2394     FileStorage fs( svmFilename, FileStorage::READ);
2395     if( fs.isOpened() )
2396     {
2397         cout << "*** LOADING SVM CLASSIFIER FOR CLASS " << objClassName << " ***" << endl;
2398         svm.load( svmFilename.c_str() );
2399     }
2400     else
2401     {
2402         cout << "*** TRAINING CLASSIFIER FOR CLASS " << objClassName << " ***" << endl;
2403         cout << "CALCULATING BOW VECTORS FOR TRAINING SET OF " << objClassName << "..." << endl;
2404
2405         // Get classification ground truth for images in the training set
2406         vector<ObdImage> images;
2407         vector<Mat> bowImageDescriptors;
2408         vector<char> objectPresent;
2409         vocData.getClassImages( objClassName, CV_OBD_TRAIN, images, objectPresent );
2410
2411         // Compute the bag of words vector for each image in the training set.
2412         calculateImageDescriptors( images, bowImageDescriptors, bowExtractor, fdetector, resPath );
2413
2414         // Remove any images for which descriptors could not be calculated
2415         removeEmptyBowImageDescriptors( images, bowImageDescriptors, objectPresent );
2416
2417         CV_Assert( svmParamsExt.descPercent > 0.f && svmParamsExt.descPercent <= 1.f );
2418         if( svmParamsExt.descPercent < 1.f )
2419         {
2420             int descsToDelete = static_cast<int>(static_cast<float>(images.size())*(1.0-svmParamsExt.descPercent));
2421
2422             cout << "Using " << (images.size() - descsToDelete) << " of " << images.size() <<
2423                     " descriptors for training (" << svmParamsExt.descPercent*100.0 << " %)" << endl;
2424             removeBowImageDescriptorsByCount( images, bowImageDescriptors, objectPresent, svmParamsExt, descsToDelete );
2425         }
2426
2427         // Prepare the input matrices for SVM training.
2428         Mat trainData( (int)images.size(), bowExtractor->getVocabulary().rows, CV_32FC1 );
2429         Mat responses( (int)images.size(), 1, CV_32SC1 );
2430
2431         // Transfer bag of words vectors and responses across to the training data matrices
2432         for( size_t imageIdx = 0; imageIdx < images.size(); imageIdx++ )
2433         {
2434             // Transfer image descriptor (bag of words vector) to training data matrix
2435             Mat submat = trainData.row((int)imageIdx);
2436             if( bowImageDescriptors[imageIdx].cols != bowExtractor->descriptorSize() )
2437             {
2438                 cout << "Error: computed bow image descriptor size " << bowImageDescriptors[imageIdx].cols
2439                      << " differs from vocabulary size" << bowExtractor->getVocabulary().cols << endl;
2440                 exit(-1);
2441             }
2442             bowImageDescriptors[imageIdx].copyTo( submat );
2443
2444             // Set response value
2445             responses.at<int>((int)imageIdx) = objectPresent[imageIdx] ? 1 : -1;
2446         }
2447
2448         cout << "TRAINING SVM FOR CLASS ..." << objClassName << "..." << endl;
2449         CvSVMParams svmParams;
2450         CvMat class_wts_cv;
2451         setSVMParams( svmParams, class_wts_cv, responses, svmParamsExt.balanceClasses );
2452         CvParamGrid c_grid, gamma_grid, p_grid, nu_grid, coef_grid, degree_grid;
2453         setSVMTrainAutoParams( c_grid, gamma_grid,  p_grid, nu_grid, coef_grid, degree_grid );
2454         svm.train_auto( trainData, responses, Mat(), Mat(), svmParams, 10, c_grid, gamma_grid, p_grid, nu_grid, coef_grid, degree_grid );
2455         cout << "SVM TRAINING FOR CLASS " << objClassName << " COMPLETED" << endl;
2456
2457         svm.save( svmFilename.c_str() );
2458         cout << "SAVED CLASSIFIER TO FILE" << endl;
2459     }
2460 }
2461
2462 #if defined HAVE_OPENCV_OCL && _OCL_SVM_
2463 static void computeConfidences( cv::ocl::CvSVM_OCL& svm, const string& objClassName, VocData& vocData,
2464                                Ptr<BOWImgDescriptorExtractor>& bowExtractor, const Ptr<FeatureDetector>& fdetector,
2465                                const string& resPath )
2466 #else
2467 static void computeConfidences( CvSVM& svm, const string& objClassName, VocData& vocData,
2468                          Ptr<BOWImgDescriptorExtractor>& bowExtractor, const Ptr<FeatureDetector>& fdetector,
2469                          const string& resPath )
2470 #endif
2471 {
2472     cout << "*** CALCULATING CONFIDENCES FOR CLASS " << objClassName << " ***" << endl;
2473     cout << "CALCULATING BOW VECTORS FOR TEST SET OF " << objClassName << "..." << endl;
2474     // Get classification ground truth for images in the test set
2475     vector<ObdImage> images;
2476     vector<Mat> bowImageDescriptors;
2477     vector<char> objectPresent;
2478     vocData.getClassImages( objClassName, CV_OBD_TEST, images, objectPresent );
2479
2480     // Compute the bag of words vector for each image in the test set
2481     calculateImageDescriptors( images, bowImageDescriptors, bowExtractor, fdetector, resPath );
2482     // Remove any images for which descriptors could not be calculated
2483     removeEmptyBowImageDescriptors( images, bowImageDescriptors, objectPresent);
2484
2485     // Use the bag of words vectors to calculate classifier output for each image in test set
2486     cout << "CALCULATING CONFIDENCE SCORES FOR CLASS " << objClassName << "..." << endl;
2487     vector<float> confidences( images.size() );
2488     float signMul = 1.f;
2489     for( size_t imageIdx = 0; imageIdx < images.size(); imageIdx++ )
2490     {
2491         if( imageIdx == 0 )
2492         {
2493             // In the first iteration, determine the sign of the positive class
2494             float classVal = confidences[imageIdx] = svm.predict( bowImageDescriptors[imageIdx], false );
2495             float scoreVal = confidences[imageIdx] = svm.predict( bowImageDescriptors[imageIdx], true );
2496             signMul = (classVal < 0) == (scoreVal < 0) ? 1.f : -1.f;
2497         }
2498         // svm output of decision function
2499         confidences[imageIdx] = signMul * svm.predict( bowImageDescriptors[imageIdx], true );
2500     }
2501
2502     cout << "WRITING QUERY RESULTS TO VOC RESULTS FILE FOR CLASS " << objClassName << "..." << endl;
2503     vocData.writeClassifierResultsFile( resPath + plotsDir, objClassName, CV_OBD_TEST, images, confidences, 1, true );
2504
2505     cout << "DONE - " << objClassName << endl;
2506     cout << "---------------------------------------------------------------" << endl;
2507 }
2508
2509 static void computeGnuPlotOutput( const string& resPath, const string& objClassName, VocData& vocData )
2510 {
2511     vector<float> precision, recall;
2512     float ap;
2513
2514     const string resultFile = vocData.getResultsFilename( objClassName, CV_VOC_TASK_CLASSIFICATION, CV_OBD_TEST);
2515     const string plotFile = resultFile.substr(0, resultFile.size()-4) + ".plt";
2516
2517     cout << "Calculating precision recall curve for class '" <<objClassName << "'" << endl;
2518     vocData.calcClassifierPrecRecall( resPath + plotsDir + "/" + resultFile, precision, recall, ap, true );
2519     cout << "Outputting to GNUPlot file..." << endl;
2520     vocData.savePrecRecallToGnuplot( resPath + plotsDir + "/" + plotFile, precision, recall, ap, objClassName, CV_VOC_PLOT_PNG );
2521 }
2522
2523
2524
2525
2526 int main(int argc, char** argv)
2527 {
2528     if( argc != 3 && argc != 6 )
2529     {
2530         help(argv);
2531         return -1;
2532     }
2533
2534     cv::initModule_nonfree();
2535
2536     const string vocPath = argv[1], resPath = argv[2];
2537
2538     // Read or set default parameters
2539     string vocName;
2540     DDMParams ddmParams;
2541     VocabTrainParams vocabTrainParams;
2542     SVMTrainParamsExt svmTrainParamsExt;
2543
2544     makeUsedDirs( resPath );
2545
2546     FileStorage paramsFS( resPath + "/" + paramsFile, FileStorage::READ );
2547     if( paramsFS.isOpened() )
2548     {
2549        readUsedParams( paramsFS.root(), vocName, ddmParams, vocabTrainParams, svmTrainParamsExt );
2550        CV_Assert( vocName == getVocName(vocPath) );
2551     }
2552     else
2553     {
2554         vocName = getVocName(vocPath);
2555         if( argc!= 6 )
2556         {
2557             cout << "Feature detector, descriptor extractor, descriptor matcher must be set" << endl;
2558             return -1;
2559         }
2560         ddmParams = DDMParams( argv[3], argv[4], argv[5] ); // from command line
2561         // vocabTrainParams and svmTrainParamsExt is set by defaults
2562         paramsFS.open( resPath + "/" + paramsFile, FileStorage::WRITE );
2563         if( paramsFS.isOpened() )
2564         {
2565             writeUsedParams( paramsFS, vocName, ddmParams, vocabTrainParams, svmTrainParamsExt );
2566             paramsFS.release();
2567         }
2568         else
2569         {
2570             cout << "File " << (resPath + "/" + paramsFile) << "can not be opened to write" << endl;
2571             return -1;
2572         }
2573     }
2574
2575     // Create detector, descriptor, matcher.
2576     Ptr<FeatureDetector> featureDetector = FeatureDetector::create( ddmParams.detectorType );
2577     Ptr<DescriptorExtractor> descExtractor = DescriptorExtractor::create( ddmParams.descriptorType );
2578     Ptr<BOWImgDescriptorExtractor> bowExtractor;
2579     if( !featureDetector || !descExtractor )
2580     {
2581         cout << "featureDetector or descExtractor was not created" << endl;
2582         return -1;
2583     }
2584     {
2585         Ptr<DescriptorMatcher> descMatcher = DescriptorMatcher::create( ddmParams.matcherType );
2586         if( !featureDetector || !descExtractor || !descMatcher )
2587         {
2588             cout << "descMatcher was not created" << endl;
2589             return -1;
2590         }
2591         bowExtractor = makePtr<BOWImgDescriptorExtractor>( descExtractor, descMatcher );
2592     }
2593
2594     // Print configuration to screen
2595     printUsedParams( vocPath, resPath, ddmParams, vocabTrainParams, svmTrainParamsExt );
2596     // Create object to work with VOC
2597     VocData vocData( vocPath, false );
2598
2599     // 1. Train visual word vocabulary if a pre-calculated vocabulary file doesn't already exist from previous run
2600     Mat vocabulary = trainVocabulary( resPath + "/" + vocabularyFile, vocData, vocabTrainParams,
2601                                       featureDetector, descExtractor );
2602     bowExtractor->setVocabulary( vocabulary );
2603
2604     // 2. Train a classifier and run a sample query for each object class
2605     const vector<string>& objClasses = vocData.getObjectClasses(); // object class list
2606     for( size_t classIdx = 0; classIdx < objClasses.size(); ++classIdx )
2607     {
2608         // Train a classifier on train dataset
2609 #if defined HAVE_OPENCV_OCL && _OCL_SVM_
2610         cv::ocl::CvSVM_OCL svm;
2611 #else
2612         CvSVM svm;
2613 #endif
2614         trainSVMClassifier( svm, svmTrainParamsExt, objClasses[classIdx], vocData,
2615                             bowExtractor, featureDetector, resPath );
2616
2617         // Now use the classifier over all images on the test dataset and rank according to score order
2618         // also calculating precision-recall etc.
2619         computeConfidences( svm, objClasses[classIdx], vocData,
2620                             bowExtractor, featureDetector, resPath );
2621         // Calculate precision/recall/ap and use GNUPlot to output to a pdf file
2622         computeGnuPlotOutput( resPath, objClasses[classIdx], vocData );
2623     }
2624     return 0;
2625 }