Merge pull request #3075 from akarsakov:ipp_imgproc_fix
[profile/ivi/opencv.git] / samples / cpp / bagofwords_classification.cpp
1 #include "opencv2/opencv_modules.hpp"
2 #include "opencv2/imgcodecs.hpp"
3 #include "opencv2/highgui/highgui.hpp"
4 #include "opencv2/imgproc/imgproc.hpp"
5 #include "opencv2/features2d/features2d.hpp"
6 #include "opencv2/nonfree/nonfree.hpp"
7 #include "opencv2/ml/ml.hpp"
8
9 #include <fstream>
10 #include <iostream>
11 #include <memory>
12 #include <functional>
13
14 #if defined WIN32 || defined _WIN32
15 #define WIN32_LEAN_AND_MEAN
16 #include <windows.h>
17 #undef min
18 #undef max
19 #include "sys/types.h"
20 #endif
21 #include <sys/stat.h>
22
23 #define DEBUG_DESC_PROGRESS
24
25 using namespace cv;
26 using namespace cv::ml;
27 using namespace std;
28
29 const string paramsFile = "params.xml";
30 const string vocabularyFile = "vocabulary.xml.gz";
31 const string bowImageDescriptorsDir = "/bowImageDescriptors";
32 const string svmsDir = "/svms";
33 const string plotsDir = "/plots";
34
35 static void help(char** argv)
36 {
37     cout << "\nThis program shows how to read in, train on and produce test results for the PASCAL VOC (Visual Object Challenge) data. \n"
38      << "It shows how to use detectors, descriptors and recognition methods \n"
39         "Using OpenCV version %s\n" << CV_VERSION << "\n"
40      << "Call: \n"
41     << "Format:\n ./" << argv[0] << " [VOC path] [result directory]  \n"
42     << "       or:  \n"
43     << " ./" << argv[0] << " [VOC path] [result directory] [feature detector] [descriptor extractor] [descriptor matcher] \n"
44     << "\n"
45     << "Input parameters: \n"
46     << "[VOC path]             Path to Pascal VOC data (e.g. /home/my/VOCdevkit/VOC2010). Note: VOC2007-VOC2010 are supported. \n"
47     << "[result directory]     Path to result diractory. Following folders will be created in [result directory]: \n"
48     << "                         bowImageDescriptors - to store image descriptors, \n"
49     << "                         svms - to store trained svms, \n"
50     << "                         plots - to store files for plots creating. \n"
51     << "[feature detector]     Feature detector name (e.g. SURF, FAST...) - see createFeatureDetector() function in detectors.cpp \n"
52     << "                         Currently 12/2010, this is FAST, STAR, SIFT, SURF, MSER, GFTT, HARRIS \n"
53     << "[descriptor extractor] Descriptor extractor name (e.g. SURF, SIFT) - see createDescriptorExtractor() function in descriptors.cpp \n"
54     << "                         Currently 12/2010, this is SURF, OpponentSIFT, SIFT, OpponentSURF, BRIEF \n"
55     << "[descriptor matcher]   Descriptor matcher name (e.g. BruteForce) - see createDescriptorMatcher() function in matchers.cpp \n"
56     << "                         Currently 12/2010, this is BruteForce, BruteForce-L1, FlannBased, BruteForce-Hamming, BruteForce-HammingLUT \n"
57     << "\n";
58 }
59
60 static void makeDir( const string& dir )
61 {
62 #if defined WIN32 || defined _WIN32
63     CreateDirectory( dir.c_str(), 0 );
64 #else
65     mkdir( dir.c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH );
66 #endif
67 }
68
69 static void makeUsedDirs( const string& rootPath )
70 {
71     makeDir(rootPath + bowImageDescriptorsDir);
72     makeDir(rootPath + svmsDir);
73     makeDir(rootPath + plotsDir);
74 }
75
76 /****************************************************************************************\
77 *                    Classes to work with PASCAL VOC dataset                             *
78 \****************************************************************************************/
79 //
80 // TODO: refactor this part of the code
81 //
82
83
84 //used to specify the (sub-)dataset over which operations are performed
85 enum ObdDatasetType {CV_OBD_TRAIN, CV_OBD_TEST};
86
87 class ObdObject
88 {
89 public:
90     string object_class;
91     Rect boundingBox;
92 };
93
94 //extended object data specific to VOC
95 enum VocPose {CV_VOC_POSE_UNSPECIFIED, CV_VOC_POSE_FRONTAL, CV_VOC_POSE_REAR, CV_VOC_POSE_LEFT, CV_VOC_POSE_RIGHT};
96 class VocObjectData
97 {
98 public:
99     bool difficult;
100     bool occluded;
101     bool truncated;
102     VocPose pose;
103 };
104 //enum VocDataset {CV_VOC2007, CV_VOC2008, CV_VOC2009, CV_VOC2010};
105 enum VocPlotType {CV_VOC_PLOT_SCREEN, CV_VOC_PLOT_PNG};
106 enum VocGT {CV_VOC_GT_NONE, CV_VOC_GT_DIFFICULT, CV_VOC_GT_PRESENT};
107 enum VocConfCond {CV_VOC_CCOND_RECALL, CV_VOC_CCOND_SCORETHRESH};
108 enum VocTask {CV_VOC_TASK_CLASSIFICATION, CV_VOC_TASK_DETECTION};
109
110 class ObdImage
111 {
112 public:
113     ObdImage(string p_id, string p_path) : id(p_id), path(p_path) {}
114     string id;
115     string path;
116 };
117
118 //used by getDetectorGroundTruth to sort a two dimensional list of floats in descending order
119 class ObdScoreIndexSorter
120 {
121 public:
122     float score;
123     int image_idx;
124     int obj_idx;
125     bool operator < (const ObdScoreIndexSorter& compare) const {return (score < compare.score);}
126 };
127
128 class VocData
129 {
130 public:
131     VocData( const string& vocPath, bool useTestDataset )
132         { initVoc( vocPath, useTestDataset ); }
133     ~VocData(){}
134     /* functions for returning classification/object data for multiple images given an object class */
135     void getClassImages(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<char>& object_present);
136     void getClassObjects(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<vector<ObdObject> >& objects);
137     void getClassObjects(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<vector<ObdObject> >& objects, vector<vector<VocObjectData> >& object_data, vector<VocGT>& ground_truth);
138     /* functions for returning object data for a single image given an image id */
139     ObdImage getObjects(const string& id, vector<ObdObject>& objects);
140     ObdImage getObjects(const string& id, vector<ObdObject>& objects, vector<VocObjectData>& object_data);
141     ObdImage getObjects(const string& obj_class, const string& id, vector<ObdObject>& objects, vector<VocObjectData>& object_data, VocGT& ground_truth);
142     /* functions for returning the ground truth (present/absent) for groups of images */
143     void getClassifierGroundTruth(const string& obj_class, const vector<ObdImage>& images, vector<char>& ground_truth);
144     void getClassifierGroundTruth(const string& obj_class, const vector<string>& images, vector<char>& ground_truth);
145     int getDetectorGroundTruth(const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<vector<Rect> >& bounding_boxes, const vector<vector<float> >& scores, vector<vector<char> >& ground_truth, vector<vector<char> >& detection_difficult, bool ignore_difficult = true);
146     /* functions for writing VOC-compatible results files */
147     void writeClassifierResultsFile(const string& out_dir, const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<float>& scores, const int competition = 1, const bool overwrite_ifexists = false);
148     /* functions for calculating metrics from a set of classification/detection results */
149     string getResultsFilename(const string& obj_class, const VocTask task, const ObdDatasetType dataset, const int competition = -1, const int number = -1);
150     void calcClassifierPrecRecall(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap, vector<size_t>& ranking);
151     void calcClassifierPrecRecall(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap);
152     void calcClassifierPrecRecall(const string& input_file, vector<float>& precision, vector<float>& recall, float& ap, bool outputRankingFile = false);
153     /* functions for calculating confusion matrices */
154     void calcClassifierConfMatRow(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, const VocConfCond cond, const float threshold, vector<string>& output_headers, vector<float>& output_values);
155     void calcDetectorConfMatRow(const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<vector<float> >& scores, const vector<vector<Rect> >& bounding_boxes, const VocConfCond cond, const float threshold, vector<string>& output_headers, vector<float>& output_values, bool ignore_difficult = true);
156     /* functions for outputting gnuplot output files */
157     void savePrecRecallToGnuplot(const string& output_file, const vector<float>& precision, const vector<float>& recall, const float ap, const string title = string(), const VocPlotType plot_type = CV_VOC_PLOT_SCREEN);
158     /* functions for reading in result/ground truth files */
159     void readClassifierGroundTruth(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<char>& object_present);
160     void readClassifierResultsFile(const std:: string& input_file, vector<ObdImage>& images, vector<float>& scores);
161     void readDetectorResultsFile(const string& input_file, vector<ObdImage>& images, vector<vector<float> >& scores, vector<vector<Rect> >& bounding_boxes);
162     /* functions for getting dataset info */
163     const vector<string>& getObjectClasses();
164     string getResultsDirectory();
165 protected:
166     void initVoc( const string& vocPath, const bool useTestDataset );
167     void initVoc2007to2010( const string& vocPath, const bool useTestDataset);
168     void readClassifierGroundTruth(const string& filename, vector<string>& image_codes, vector<char>& object_present);
169     void readClassifierResultsFile(const string& input_file, vector<string>& image_codes, vector<float>& scores);
170     void readDetectorResultsFile(const string& input_file, vector<string>& image_codes, vector<vector<float> >& scores, vector<vector<Rect> >& bounding_boxes);
171     void extractVocObjects(const string filename, vector<ObdObject>& objects, vector<VocObjectData>& object_data);
172     string getImagePath(const string& input_str);
173
174     void getClassImages_impl(const string& obj_class, const string& dataset_str, vector<ObdImage>& images, vector<char>& object_present);
175     void calcPrecRecall_impl(const vector<char>& ground_truth, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap, vector<size_t>& ranking, int recall_normalization = -1);
176
177     //test two bounding boxes to see if they meet the overlap criteria defined in the VOC documentation
178     float testBoundingBoxesForOverlap(const Rect detection, const Rect ground_truth);
179     //extract class and dataset name from a VOC-standard classification/detection results filename
180     void extractDataFromResultsFilename(const string& input_file, string& class_name, string& dataset_name);
181     //get classifier ground truth for a single image
182     bool getClassifierGroundTruthImage(const string& obj_class, const string& id);
183
184     //utility functions
185     void getSortOrder(const vector<float>& values, vector<size_t>& order, bool descending = true);
186     int stringToInteger(const string input_str);
187     void readFileToString(const string filename, string& file_contents);
188     string integerToString(const int input_int);
189     string checkFilenamePathsep(const string filename, bool add_trailing_slash = false);
190     void convertImageCodesToObdImages(const vector<string>& image_codes, vector<ObdImage>& images);
191     int extractXMLBlock(const string src, const string tag, const int searchpos, string& tag_contents);
192     //utility sorter
193     struct orderingSorter
194     {
195         bool operator ()(std::pair<size_t, vector<float>::const_iterator> const& a, std::pair<size_t, vector<float>::const_iterator> const& b)
196         {
197             return (*a.second) > (*b.second);
198         }
199     };
200     //data members
201     string m_vocPath;
202     string m_vocName;
203     //string m_resPath;
204
205     string m_annotation_path;
206     string m_image_path;
207     string m_imageset_path;
208     string m_class_imageset_path;
209
210     vector<string> m_classifier_gt_all_ids;
211     vector<char> m_classifier_gt_all_present;
212     string m_classifier_gt_class;
213
214     //data members
215     string m_train_set;
216     string m_test_set;
217
218     vector<string> m_object_classes;
219
220
221     float m_min_overlap;
222     bool m_sampled_ap;
223 };
224
225
226 //Return the classification ground truth data for all images of a given VOC object class
227 //--------------------------------------------------------------------------------------
228 //INPUTS:
229 // - obj_class          The VOC object class identifier string
230 // - dataset            Specifies whether to extract images from the training or test set
231 //OUTPUTS:
232 // - images             An array of ObdImage containing info of all images extracted from the ground truth file
233 // - object_present     An array of bools specifying whether the object defined by 'obj_class' is present in each image or not
234 //NOTES:
235 // This function is primarily useful for the classification task, where only
236 // whether a given object is present or not in an image is required, and not each object instance's
237 // position etc.
238 void VocData::getClassImages(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<char>& object_present)
239 {
240     string dataset_str;
241     //generate the filename of the classification ground-truth textfile for the object class
242     if (dataset == CV_OBD_TRAIN)
243     {
244         dataset_str = m_train_set;
245     } else {
246         dataset_str = m_test_set;
247     }
248
249     getClassImages_impl(obj_class, dataset_str, images, object_present);
250 }
251
252 void VocData::getClassImages_impl(const string& obj_class, const string& dataset_str, vector<ObdImage>& images, vector<char>& object_present)
253 {
254     //generate the filename of the classification ground-truth textfile for the object class
255     string gtFilename = m_class_imageset_path;
256     gtFilename.replace(gtFilename.find("%s"),2,obj_class);
257     gtFilename.replace(gtFilename.find("%s"),2,dataset_str);
258
259     //parse the ground truth file, storing in two separate vectors
260     //for the image code and the ground truth value
261     vector<string> image_codes;
262     readClassifierGroundTruth(gtFilename, image_codes, object_present);
263
264     //prepare output arrays
265     images.clear();
266
267     convertImageCodesToObdImages(image_codes, images);
268 }
269
270 //Return the object data for all images of a given VOC object class
271 //-----------------------------------------------------------------
272 //INPUTS:
273 // - obj_class          The VOC object class identifier string
274 // - dataset            Specifies whether to extract images from the training or test set
275 //OUTPUTS:
276 // - images             An array of ObdImage containing info of all images in chosen dataset (tag, path etc.)
277 // - objects            Contains the extended object info (bounding box etc.) for each object instance in each image
278 // - object_data        Contains VOC-specific extended object info (marked difficult etc.)
279 // - ground_truth       Specifies whether there are any difficult/non-difficult instances of the current
280 //                          object class within each image
281 //NOTES:
282 // This function returns extended object information in addition to the absent/present
283 // classification data returned by getClassImages. The objects returned for each image in the 'objects'
284 // array are of all object classes present in the image, and not just the class defined by 'obj_class'.
285 // 'ground_truth' can be used to determine quickly whether an object instance of the given class is present
286 // in an image or not.
287 void VocData::getClassObjects(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<vector<ObdObject> >& objects)
288 {
289     vector<vector<VocObjectData> > object_data;
290     vector<VocGT> ground_truth;
291
292     getClassObjects(obj_class,dataset,images,objects,object_data,ground_truth);
293 }
294
295 void VocData::getClassObjects(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<vector<ObdObject> >& objects, vector<vector<VocObjectData> >& object_data, vector<VocGT>& ground_truth)
296 {
297     //generate the filename of the classification ground-truth textfile for the object class
298     string gtFilename = m_class_imageset_path;
299     gtFilename.replace(gtFilename.find("%s"),2,obj_class);
300     if (dataset == CV_OBD_TRAIN)
301     {
302         gtFilename.replace(gtFilename.find("%s"),2,m_train_set);
303     } else {
304         gtFilename.replace(gtFilename.find("%s"),2,m_test_set);
305     }
306
307     //parse the ground truth file, storing in two separate vectors
308     //for the image code and the ground truth value
309     vector<string> image_codes;
310     vector<char> object_present;
311     readClassifierGroundTruth(gtFilename, image_codes, object_present);
312
313     //prepare output arrays
314     images.clear();
315     objects.clear();
316     object_data.clear();
317     ground_truth.clear();
318
319     string annotationFilename;
320     vector<ObdObject> image_objects;
321     vector<VocObjectData> image_object_data;
322     VocGT image_gt;
323
324     //transfer to output arrays and read in object data for each image
325     for (size_t i = 0; i < image_codes.size(); ++i)
326     {
327         ObdImage image = getObjects(obj_class, image_codes[i], image_objects, image_object_data, image_gt);
328
329         images.push_back(image);
330         objects.push_back(image_objects);
331         object_data.push_back(image_object_data);
332         ground_truth.push_back(image_gt);
333     }
334 }
335
336 //Return ground truth data for the objects present in an image with a given UID
337 //-----------------------------------------------------------------------------
338 //INPUTS:
339 // - id                 VOC Dataset unique identifier (string code in form YYYY_XXXXXX where YYYY is the year)
340 //OUTPUTS:
341 // - obj_class (*3)     Specifies the object class to use to resolve 'ground_truth'
342 // - objects            Contains the extended object info (bounding box etc.) for each object in the image
343 // - object_data (*2,3) Contains VOC-specific extended object info (marked difficult etc.)
344 // - ground_truth (*3)  Specifies whether there are any difficult/non-difficult instances of the current
345 //                          object class within the image
346 //RETURN VALUE:
347 // ObdImage containing path and other details of image file with given code
348 //NOTES:
349 // There are three versions of this function
350 //  * One returns a simple array of objects given an id [1]
351 //  * One returns the same as (1) plus VOC specific object data [2]
352 //  * One returns the same as (2) plus the ground_truth flag. This also requires an extra input obj_class [3]
353 ObdImage VocData::getObjects(const string& id, vector<ObdObject>& objects)
354 {
355     vector<VocObjectData> object_data;
356     ObdImage image = getObjects(id, objects, object_data);
357
358     return image;
359 }
360
361 ObdImage VocData::getObjects(const string& id, vector<ObdObject>& objects, vector<VocObjectData>& object_data)
362 {
363     //first generate the filename of the annotation file
364     string annotationFilename = m_annotation_path;
365
366     annotationFilename.replace(annotationFilename.find("%s"),2,id);
367
368     //extract objects contained in the current image from the xml
369     extractVocObjects(annotationFilename,objects,object_data);
370
371     //generate image path from extracted string code
372     string path = getImagePath(id);
373
374     ObdImage image(id, path);
375     return image;
376 }
377
378 ObdImage VocData::getObjects(const string& obj_class, const string& id, vector<ObdObject>& objects, vector<VocObjectData>& object_data, VocGT& ground_truth)
379 {
380
381     //extract object data (except for ground truth flag)
382     ObdImage image = getObjects(id,objects,object_data);
383
384     //pregenerate a flag to indicate whether the current class is present or not in the image
385     ground_truth = CV_VOC_GT_NONE;
386     //iterate through all objects in current image
387     for (size_t j = 0; j < objects.size(); ++j)
388     {
389         if (objects[j].object_class == obj_class)
390         {
391             if (object_data[j].difficult == false)
392             {
393                 //if at least one non-difficult example is present, this flag is always set to CV_VOC_GT_PRESENT
394                 ground_truth = CV_VOC_GT_PRESENT;
395                 break;
396             } else {
397                 //set if at least one object instance is present, but it is marked difficult
398                 ground_truth = CV_VOC_GT_DIFFICULT;
399             }
400         }
401     }
402
403     return image;
404 }
405
406 //Return ground truth data for the presence/absence of a given object class in an arbitrary array of images
407 //---------------------------------------------------------------------------------------------------------
408 //INPUTS:
409 // - obj_class          The VOC object class identifier string
410 // - images             An array of ObdImage OR strings containing the images for which ground truth
411 //                          will be computed
412 //OUTPUTS:
413 // - ground_truth       An output array indicating the presence/absence of obj_class within each image
414 void VocData::getClassifierGroundTruth(const string& obj_class, const vector<ObdImage>& images, vector<char>& ground_truth)
415 {
416     vector<char>(images.size()).swap(ground_truth);
417
418     vector<ObdObject> objects;
419     vector<VocObjectData> object_data;
420     vector<char>::iterator gt_it = ground_truth.begin();
421     for (vector<ObdImage>::const_iterator it = images.begin(); it != images.end(); ++it, ++gt_it)
422     {
423         //getObjects(obj_class, it->id, objects, object_data, voc_ground_truth);
424         (*gt_it) = (getClassifierGroundTruthImage(obj_class, it->id));
425     }
426 }
427
428 void VocData::getClassifierGroundTruth(const string& obj_class, const vector<string>& images, vector<char>& ground_truth)
429 {
430     vector<char>(images.size()).swap(ground_truth);
431
432     vector<ObdObject> objects;
433     vector<VocObjectData> object_data;
434     vector<char>::iterator gt_it = ground_truth.begin();
435     for (vector<string>::const_iterator it = images.begin(); it != images.end(); ++it, ++gt_it)
436     {
437         //getObjects(obj_class, (*it), objects, object_data, voc_ground_truth);
438         (*gt_it) = (getClassifierGroundTruthImage(obj_class, (*it)));
439     }
440 }
441
442 //Return ground truth data for the accuracy of detection results
443 //--------------------------------------------------------------
444 //INPUTS:
445 // - obj_class          The VOC object class identifier string
446 // - images             An array of ObdImage containing the images for which ground truth
447 //                          will be computed
448 // - bounding_boxes     A 2D input array containing the bounding box rects of the objects of
449 //                          obj_class which were detected in each image
450 //OUTPUTS:
451 // - ground_truth       A 2D output array indicating whether each object detection was accurate
452 //                          or not
453 // - detection_difficult A 2D output array indicating whether the detection fired on an object
454 //                          marked as 'difficult'. This allows it to be ignored if necessary
455 //                          (the voc documentation specifies objects marked as difficult
456 //                          have no effects on the results and are effectively ignored)
457 // - (ignore_difficult) If set to true, objects marked as difficult will be ignored when returning
458 //                          the number of hits for p-r normalization (default = true)
459 //RETURN VALUE:
460 //                      Returns the number of object hits in total in the gt to allow proper normalization
461 //                          of a p-r curve
462 //NOTES:
463 // As stated in the VOC documentation, multiple detections of the same object in an image are
464 // considered FALSE detections e.g. 5 detections of a single object is counted as 1 correct
465 // detection and 4 false detections - it is the responsibility of the participant's system
466 // to filter multiple detections from its output
467 int VocData::getDetectorGroundTruth(const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<vector<Rect> >& bounding_boxes, const vector<vector<float> >& scores, vector<vector<char> >& ground_truth, vector<vector<char> >& detection_difficult, bool ignore_difficult)
468 {
469     int recall_normalization = 0;
470
471     /* first create a list of indices referring to the elements of bounding_boxes and scores in
472      * descending order of scores */
473     vector<ObdScoreIndexSorter> sorted_ids;
474     {
475         /* first count how many objects to allow preallocation */
476         size_t obj_count = 0;
477         CV_Assert(images.size() == bounding_boxes.size());
478         CV_Assert(scores.size() == bounding_boxes.size());
479         for (size_t im_idx = 0; im_idx < scores.size(); ++im_idx)
480         {
481             CV_Assert(scores[im_idx].size() == bounding_boxes[im_idx].size());
482             obj_count += scores[im_idx].size();
483         }
484         /* preallocate id vector */
485         sorted_ids.resize(obj_count);
486         /* now copy across scores and indexes to preallocated vector */
487         int flat_pos = 0;
488         for (size_t im_idx = 0; im_idx < scores.size(); ++im_idx)
489         {
490             for (size_t ob_idx = 0; ob_idx < scores[im_idx].size(); ++ob_idx)
491             {
492                 sorted_ids[flat_pos].score = scores[im_idx][ob_idx];
493                 sorted_ids[flat_pos].image_idx = (int)im_idx;
494                 sorted_ids[flat_pos].obj_idx = (int)ob_idx;
495                 ++flat_pos;
496             }
497         }
498         /* and sort the vector in descending order of score */
499         std::sort(sorted_ids.begin(),sorted_ids.end());
500         std::reverse(sorted_ids.begin(),sorted_ids.end());
501     }
502
503     /* prepare ground truth + difficult vector (1st dimension) */
504     vector<vector<char> >(images.size()).swap(ground_truth);
505     vector<vector<char> >(images.size()).swap(detection_difficult);
506     vector<vector<char> > detected(images.size());
507
508     vector<vector<ObdObject> > img_objects(images.size());
509     vector<vector<VocObjectData> > img_object_data(images.size());
510     /* preload object ground truth bounding box data */
511     {
512         vector<vector<ObdObject> > img_objects_all(images.size());
513         vector<vector<VocObjectData> > img_object_data_all(images.size());
514         for (size_t image_idx = 0; image_idx < images.size(); ++image_idx)
515         {
516             /* prepopulate ground truth bounding boxes */
517             getObjects(images[image_idx].id, img_objects_all[image_idx], img_object_data_all[image_idx]);
518             /* meanwhile, also set length of target ground truth + difficult vector to same as number of object detections (2nd dimension) */
519             ground_truth[image_idx].resize(bounding_boxes[image_idx].size());
520             detection_difficult[image_idx].resize(bounding_boxes[image_idx].size());
521         }
522
523         /* save only instances of the object class concerned */
524         for (size_t image_idx = 0; image_idx < images.size(); ++image_idx)
525         {
526             for (size_t obj_idx = 0; obj_idx < img_objects_all[image_idx].size(); ++obj_idx)
527             {
528                 if (img_objects_all[image_idx][obj_idx].object_class == obj_class)
529                 {
530                     img_objects[image_idx].push_back(img_objects_all[image_idx][obj_idx]);
531                     img_object_data[image_idx].push_back(img_object_data_all[image_idx][obj_idx]);
532                 }
533             }
534             detected[image_idx].resize(img_objects[image_idx].size(), false);
535         }
536     }
537
538     /* calculate the total number of objects in the ground truth for the current dataset */
539     {
540         vector<ObdImage> gt_images;
541         vector<char> gt_object_present;
542         getClassImages(obj_class, dataset, gt_images, gt_object_present);
543
544         for (size_t image_idx = 0; image_idx < gt_images.size(); ++image_idx)
545         {
546             vector<ObdObject> gt_img_objects;
547             vector<VocObjectData> gt_img_object_data;
548             getObjects(gt_images[image_idx].id, gt_img_objects, gt_img_object_data);
549             for (size_t obj_idx = 0; obj_idx < gt_img_objects.size(); ++obj_idx)
550             {
551                 if (gt_img_objects[obj_idx].object_class == obj_class)
552                 {
553                     if ((gt_img_object_data[obj_idx].difficult == false) || (ignore_difficult == false))
554                         ++recall_normalization;
555                 }
556             }
557         }
558     }
559
560 #ifdef PR_DEBUG
561     int printed_count = 0;
562 #endif
563     /* now iterate through detections in descending order of score, assigning to ground truth bounding boxes if possible */
564     for (size_t detect_idx = 0; detect_idx < sorted_ids.size(); ++detect_idx)
565     {
566         //read in indexes to make following code easier to read
567         int im_idx = sorted_ids[detect_idx].image_idx;
568         int ob_idx = sorted_ids[detect_idx].obj_idx;
569         //set ground truth for the current object to false by default
570         ground_truth[im_idx][ob_idx] = false;
571         detection_difficult[im_idx][ob_idx] = false;
572         float maxov = -1.0;
573         bool max_is_difficult = false;
574         int max_gt_obj_idx = -1;
575         //-- for each detected object iterate through objects present in the bounding box ground truth --
576         for (size_t gt_obj_idx = 0; gt_obj_idx < img_objects[im_idx].size(); ++gt_obj_idx)
577         {
578             if (detected[im_idx][gt_obj_idx] == false)
579             {
580                 //check if the detected object and ground truth object overlap by a sufficient margin
581                 float ov = testBoundingBoxesForOverlap(bounding_boxes[im_idx][ob_idx], img_objects[im_idx][gt_obj_idx].boundingBox);
582                 if (ov != -1.0)
583                 {
584                     //if all conditions are met store the overlap score and index (as objects are assigned to the highest scoring match)
585                     if (ov > maxov)
586                     {
587                         maxov = ov;
588                         max_gt_obj_idx = (int)gt_obj_idx;
589                         //store whether the maximum detection is marked as difficult or not
590                         max_is_difficult = (img_object_data[im_idx][gt_obj_idx].difficult);
591                     }
592                 }
593             }
594         }
595         //-- if a match was found, set the ground truth of the current object to true --
596         if (maxov != -1.0)
597         {
598             CV_Assert(max_gt_obj_idx != -1);
599             ground_truth[im_idx][ob_idx] = true;
600             //store whether the maximum detection was marked as 'difficult' or not
601             detection_difficult[im_idx][ob_idx] = max_is_difficult;
602             //remove the ground truth object so it doesn't match with subsequent detected objects
603             //** this is the behaviour defined by the voc documentation **
604             detected[im_idx][max_gt_obj_idx] = true;
605         }
606 #ifdef PR_DEBUG
607         if (printed_count < 10)
608         {
609             cout << printed_count << ": id=" << images[im_idx].id << ", score=" << scores[im_idx][ob_idx] << " (" << ob_idx << ") [" << bounding_boxes[im_idx][ob_idx].x << "," <<
610                     bounding_boxes[im_idx][ob_idx].y << "," << bounding_boxes[im_idx][ob_idx].width + bounding_boxes[im_idx][ob_idx].x <<
611                     "," << bounding_boxes[im_idx][ob_idx].height + bounding_boxes[im_idx][ob_idx].y << "] detected=" << ground_truth[im_idx][ob_idx] <<
612                     ", difficult=" << detection_difficult[im_idx][ob_idx] << endl;
613             ++printed_count;
614             /* print ground truth */
615             for (int gt_obj_idx = 0; gt_obj_idx < img_objects[im_idx].size(); ++gt_obj_idx)
616             {
617                 cout << "    GT: [" << img_objects[im_idx][gt_obj_idx].boundingBox.x << "," <<
618                         img_objects[im_idx][gt_obj_idx].boundingBox.y << "," << img_objects[im_idx][gt_obj_idx].boundingBox.width + img_objects[im_idx][gt_obj_idx].boundingBox.x <<
619                         "," << img_objects[im_idx][gt_obj_idx].boundingBox.height + img_objects[im_idx][gt_obj_idx].boundingBox.y << "]";
620                 if (gt_obj_idx == max_gt_obj_idx) cout << " <--- (" << maxov << " overlap)";
621                 cout << endl;
622             }
623         }
624 #endif
625     }
626
627     return recall_normalization;
628 }
629
630 //Write VOC-compliant classifier results file
631 //-------------------------------------------
632 //INPUTS:
633 // - obj_class          The VOC object class identifier string
634 // - dataset            Specifies whether working with the training or test set
635 // - images             An array of ObdImage containing the images for which data will be saved to the result file
636 // - scores             A corresponding array of confidence scores given a query
637 // - (competition)      If specified, defines which competition the results are for (see VOC documentation - default 1)
638 //NOTES:
639 // The result file path and filename are determined automatically using m_results_directory as a base
640 void VocData::writeClassifierResultsFile( const string& out_dir, const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<float>& scores, const int competition, const bool overwrite_ifexists)
641 {
642     CV_Assert(images.size() == scores.size());
643
644     string output_file_base, output_file;
645     if (dataset == CV_OBD_TRAIN)
646     {
647         output_file_base = out_dir + "/comp" + integerToString(competition) + "_cls_" + m_train_set + "_" + obj_class;
648     } else {
649         output_file_base = out_dir + "/comp" + integerToString(competition) + "_cls_" + m_test_set + "_" + obj_class;
650     }
651     output_file = output_file_base + ".txt";
652
653     //check if file exists, and if so create a numbered new file instead
654     if (overwrite_ifexists == false)
655     {
656         struct stat stFileInfo;
657         if (stat(output_file.c_str(),&stFileInfo) == 0)
658         {
659             string output_file_new;
660             int filenum = 0;
661             do
662             {
663                 ++filenum;
664                 output_file_new = output_file_base + "_" + integerToString(filenum);
665                 output_file = output_file_new + ".txt";
666             } while (stat(output_file.c_str(),&stFileInfo) == 0);
667         }
668     }
669
670     //output data to file
671     std::ofstream result_file(output_file.c_str());
672     if (result_file.is_open())
673     {
674         for (size_t i = 0; i < images.size(); ++i)
675         {
676             result_file << images[i].id << " " << scores[i] << endl;
677         }
678         result_file.close();
679     } else {
680         string err_msg = "could not open classifier results file '" + output_file + "' for writing. Before running for the first time, a 'results' subdirectory should be created within the VOC dataset base directory. e.g. if the VOC data is stored in /VOC/VOC2010 then the path /VOC/results must be created.";
681         CV_Error(Error::StsError,err_msg.c_str());
682     }
683 }
684
685 //---------------------------------------
686 //CALCULATE METRICS FROM VOC RESULTS DATA
687 //---------------------------------------
688
689 //Utility function to construct a VOC-standard classification results filename
690 //----------------------------------------------------------------------------
691 //INPUTS:
692 // - obj_class          The VOC object class identifier string
693 // - task               Specifies whether to generate a filename for the classification or detection task
694 // - dataset            Specifies whether working with the training or test set
695 // - (competition)      If specified, defines which competition the results are for (see VOC documentation
696 //                      default of -1 means this is set to 1 for the classification task and 3 for the detection task)
697 // - (number)           If specified and above 0, defines which of a number of duplicate results file produced for a given set of
698 //                      of settings should be used (this number will be added as a postfix to the filename)
699 //NOTES:
700 // This is primarily useful for returning the filename of a classification file previously computed using writeClassifierResultsFile
701 // for example when calling calcClassifierPrecRecall
702 string VocData::getResultsFilename(const string& obj_class, const VocTask task, const ObdDatasetType dataset, const int competition, const int number)
703 {
704     if ((competition < 1) && (competition != -1))
705         CV_Error(Error::StsBadArg,"competition argument should be a positive non-zero number or -1 to accept the default");
706     if ((number < 1) && (number != -1))
707         CV_Error(Error::StsBadArg,"number argument should be a positive non-zero number or -1 to accept the default");
708
709     string dset, task_type;
710
711     if (dataset == CV_OBD_TRAIN)
712     {
713         dset = m_train_set;
714     } else {
715         dset = m_test_set;
716     }
717
718     int comp = competition;
719     if (task == CV_VOC_TASK_CLASSIFICATION)
720     {
721         task_type = "cls";
722         if (comp == -1) comp = 1;
723     } else {
724         task_type = "det";
725         if (comp == -1) comp = 3;
726     }
727
728     stringstream ss;
729     if (number < 1)
730     {
731         ss << "comp" << comp << "_" << task_type << "_" << dset << "_" << obj_class << ".txt";
732     } else {
733         ss << "comp" << comp << "_" << task_type << "_" << dset << "_" << obj_class << "_" << number << ".txt";
734     }
735
736     string filename = ss.str();
737     return filename;
738 }
739
740 //Calculate metrics for classification results
741 //--------------------------------------------
742 //INPUTS:
743 // - ground_truth       A vector of booleans determining whether the currently tested class is present in each input image
744 // - scores             A vector containing the similarity score for each input image (higher is more similar)
745 //OUTPUTS:
746 // - precision          A vector containing the precision calculated at each datapoint of a p-r curve generated from the result set
747 // - recall             A vector containing the recall calculated at each datapoint of a p-r curve generated from the result set
748 // - ap                The ap metric calculated from the result set
749 // - (ranking)          A vector of the same length as 'ground_truth' and 'scores' containing the order of the indices in both of
750 //                      these arrays when sorting by the ranking score in descending order
751 //NOTES:
752 // The result file path and filename are determined automatically using m_results_directory as a base
753 void VocData::calcClassifierPrecRecall(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap, vector<size_t>& ranking)
754 {
755     vector<char> res_ground_truth;
756     getClassifierGroundTruth(obj_class, images, res_ground_truth);
757
758     calcPrecRecall_impl(res_ground_truth, scores, precision, recall, ap, ranking);
759 }
760
761 void VocData::calcClassifierPrecRecall(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap)
762 {
763     vector<char> res_ground_truth;
764     getClassifierGroundTruth(obj_class, images, res_ground_truth);
765
766     vector<size_t> ranking;
767     calcPrecRecall_impl(res_ground_truth, scores, precision, recall, ap, ranking);
768 }
769
770 //< Overloaded version which accepts VOC classification result file input instead of array of scores/ground truth >
771 //INPUTS:
772 // - input_file         The path to the VOC standard results file to use for calculating precision/recall
773 //                      If a full path is not specified, it is assumed this file is in the VOC standard results directory
774 //                      A VOC standard filename can be retrieved (as used by writeClassifierResultsFile) by calling  getClassifierResultsFilename
775
776 void VocData::calcClassifierPrecRecall(const string& input_file, vector<float>& precision, vector<float>& recall, float& ap, bool outputRankingFile)
777 {
778     //read in classification results file
779     vector<string> res_image_codes;
780     vector<float> res_scores;
781
782     string input_file_std = checkFilenamePathsep(input_file);
783     readClassifierResultsFile(input_file_std, res_image_codes, res_scores);
784
785     //extract the object class and dataset from the results file filename
786     string class_name, dataset_name;
787     extractDataFromResultsFilename(input_file_std, class_name, dataset_name);
788
789     //generate the ground truth for the images extracted from the results file
790     vector<char> res_ground_truth;
791
792     getClassifierGroundTruth(class_name, res_image_codes, res_ground_truth);
793
794     if (outputRankingFile)
795     {
796         /* 1. store sorting order by score (descending) in 'order' */
797         vector<std::pair<size_t, vector<float>::const_iterator> > order(res_scores.size());
798
799         size_t n = 0;
800         for (vector<float>::const_iterator it = res_scores.begin(); it != res_scores.end(); ++it, ++n)
801             order[n] = make_pair(n, it);
802
803         std::sort(order.begin(),order.end(),orderingSorter());
804
805         /* 2. save ranking results to text file */
806         string input_file_std1 = checkFilenamePathsep(input_file);
807         size_t fnamestart = input_file_std1.rfind("/");
808         string scoregt_file_str = input_file_std1.substr(0,fnamestart+1) + "scoregt_" + class_name + ".txt";
809         std::ofstream scoregt_file(scoregt_file_str.c_str());
810         if (scoregt_file.is_open())
811         {
812             for (size_t i = 0; i < res_scores.size(); ++i)
813             {
814                 scoregt_file << res_image_codes[order[i].first] << " " << res_scores[order[i].first] << " " << res_ground_truth[order[i].first] << endl;
815             }
816             scoregt_file.close();
817         } else {
818             string err_msg = "could not open scoregt file '" + scoregt_file_str + "' for writing.";
819             CV_Error(Error::StsError,err_msg.c_str());
820         }
821     }
822
823     //finally, calculate precision+recall+ap
824     vector<size_t> ranking;
825     calcPrecRecall_impl(res_ground_truth,res_scores,precision,recall,ap,ranking);
826 }
827
828 //< Protected implementation of Precision-Recall calculation used by both calcClassifierPrecRecall and calcDetectorPrecRecall >
829
830 void VocData::calcPrecRecall_impl(const vector<char>& ground_truth, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap, vector<size_t>& ranking, int recall_normalization)
831 {
832     CV_Assert(ground_truth.size() == scores.size());
833
834     //add extra element for p-r at 0 recall (in case that first retrieved is positive)
835     vector<float>(scores.size()+1).swap(precision);
836     vector<float>(scores.size()+1).swap(recall);
837
838     // SORT RESULTS BY THEIR SCORE
839     /* 1. store sorting order in 'order' */
840     VocData::getSortOrder(scores, ranking);
841
842 #ifdef PR_DEBUG
843     std::ofstream scoregt_file("D:/pr.txt");
844     if (scoregt_file.is_open())
845     {
846        for (int i = 0; i < scores.size(); ++i)
847        {
848            scoregt_file << scores[ranking[i]] << " " << ground_truth[ranking[i]] << endl;
849        }
850        scoregt_file.close();
851     }
852 #endif
853
854     // CALCULATE PRECISION+RECALL
855
856     int retrieved_hits = 0;
857
858     int recall_norm;
859     if (recall_normalization != -1)
860     {
861         recall_norm = recall_normalization;
862     } else {
863         recall_norm = (int)std::count_if(ground_truth.begin(),ground_truth.end(),std::bind2nd(std::equal_to<char>(),(char)1));
864     }
865
866     ap = 0;
867     recall[0] = 0;
868     for (size_t idx = 0; idx < ground_truth.size(); ++idx)
869     {
870         if (ground_truth[ranking[idx]] != 0) ++retrieved_hits;
871
872         precision[idx+1] = static_cast<float>(retrieved_hits)/static_cast<float>(idx+1);
873         recall[idx+1] = static_cast<float>(retrieved_hits)/static_cast<float>(recall_norm);
874
875         if (idx == 0)
876         {
877             //add further point at 0 recall with the same precision value as the first computed point
878             precision[idx] = precision[idx+1];
879         }
880         if (recall[idx+1] == 1.0)
881         {
882             //if recall = 1, then end early as all positive images have been found
883             recall.resize(idx+2);
884             precision.resize(idx+2);
885             break;
886         }
887     }
888
889     /* ap calculation */
890     if (m_sampled_ap == false)
891     {
892         // FOR VOC2010+ AP IS CALCULATED FROM ALL DATAPOINTS
893         /* make precision monotonically decreasing for purposes of calculating ap */
894         vector<float> precision_monot(precision.size());
895         vector<float>::iterator prec_m_it = precision_monot.begin();
896         for (vector<float>::iterator prec_it = precision.begin(); prec_it != precision.end(); ++prec_it, ++prec_m_it)
897         {
898             vector<float>::iterator max_elem;
899             max_elem = std::max_element(prec_it,precision.end());
900             (*prec_m_it) = (*max_elem);
901         }
902         /* calculate ap */
903         for (size_t idx = 0; idx < (recall.size()-1); ++idx)
904         {
905             ap += (recall[idx+1] - recall[idx])*precision_monot[idx+1] +   //no need to take min of prec - is monotonically decreasing
906                     0.5f*(recall[idx+1] - recall[idx])*std::abs(precision_monot[idx+1] - precision_monot[idx]);
907         }
908     } else {
909         // FOR BEFORE VOC2010 AP IS CALCULATED BY SAMPLING PRECISION AT RECALL 0.0,0.1,..,1.0
910
911         for (float recall_pos = 0.f; recall_pos <= 1.f; recall_pos += 0.1f)
912         {
913             //find iterator of the precision corresponding to the first recall >= recall_pos
914             vector<float>::iterator recall_it = recall.begin();
915             vector<float>::iterator prec_it = precision.begin();
916
917             while ((*recall_it) < recall_pos)
918             {
919                 ++recall_it;
920                 ++prec_it;
921                 if (recall_it == recall.end()) break;
922             }
923
924             /* if no recall >= recall_pos found, this level of recall is never reached so stop adding to ap */
925             if (recall_it == recall.end()) break;
926
927             /* if the prec_it is valid, compute the max precision at this level of recall or higher */
928             vector<float>::iterator max_prec = std::max_element(prec_it,precision.end());
929
930             ap += (*max_prec)/11;
931         }
932     }
933 }
934
935 /* functions for calculating confusion matrix rows */
936
937 //Calculate rows of a confusion matrix
938 //------------------------------------
939 //INPUTS:
940 // - obj_class          The VOC object class identifier string for the confusion matrix row to compute
941 // - images             An array of ObdImage containing the images to use for the computation
942 // - scores             A corresponding array of confidence scores for the presence of obj_class in each image
943 // - cond               Defines whether to use a cut off point based on recall (CV_VOC_CCOND_RECALL) or score
944 //                      (CV_VOC_CCOND_SCORETHRESH) the latter is useful for classifier detections where positive
945 //                      values are positive detections and negative values are negative detections
946 // - threshold          Threshold value for cond. In case of CV_VOC_CCOND_RECALL, is proportion recall (e.g. 0.5).
947 //                      In the case of CV_VOC_CCOND_SCORETHRESH is the value above which to count results.
948 //OUTPUTS:
949 // - output_headers     An output vector of object class headers for the confusion matrix row
950 // - output_values      An output vector of values for the confusion matrix row corresponding to the classes
951 //                      defined in output_headers
952 //NOTES:
953 // The methodology used by the classifier version of this function is that true positives have a single unit
954 // added to the obj_class column in the confusion matrix row, whereas false positives have a single unit
955 // distributed in proportion between all the columns in the confusion matrix row corresponding to the objects
956 // present in the image.
957 void VocData::calcClassifierConfMatRow(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, const VocConfCond cond, const float threshold, vector<string>& output_headers, vector<float>& output_values)
958 {
959     CV_Assert(images.size() == scores.size());
960
961     // SORT RESULTS BY THEIR SCORE
962     /* 1. store sorting order in 'ranking' */
963     vector<size_t> ranking;
964     VocData::getSortOrder(scores, ranking);
965
966     // CALCULATE CONFUSION MATRIX ENTRIES
967     /* prepare object category headers */
968     output_headers = m_object_classes;
969     vector<float>(output_headers.size(),0.0).swap(output_values);
970     /* find the index of the target object class in the headers for later use */
971     int target_idx;
972     {
973         vector<string>::iterator target_idx_it = std::find(output_headers.begin(),output_headers.end(),obj_class);
974         /* if the target class can not be found, raise an exception */
975         if (target_idx_it == output_headers.end())
976         {
977             string err_msg = "could not find the target object class '" + obj_class + "' in list of valid classes.";
978             CV_Error(Error::StsError,err_msg.c_str());
979         }
980         /* convert iterator to index */
981         target_idx = (int)std::distance(output_headers.begin(),target_idx_it);
982     }
983
984     /* prepare variables related to calculating recall if using the recall threshold */
985     int retrieved_hits = 0;
986     int total_relevant = 0;
987     if (cond == CV_VOC_CCOND_RECALL)
988     {
989         vector<char> ground_truth;
990         /* in order to calculate the total number of relevant images for normalization of recall
991             it's necessary to extract the ground truth for the images under consideration */
992         getClassifierGroundTruth(obj_class, images, ground_truth);
993         total_relevant = (int)std::count_if(ground_truth.begin(),ground_truth.end(),std::bind2nd(std::equal_to<char>(),(char)1));
994     }
995
996     /* iterate through images */
997     vector<ObdObject> img_objects;
998     vector<VocObjectData> img_object_data;
999     int total_images = 0;
1000     for (size_t image_idx = 0; image_idx < images.size(); ++image_idx)
1001     {
1002         /* if using the score as the break condition, check for it now */
1003         if (cond == CV_VOC_CCOND_SCORETHRESH)
1004         {
1005             if (scores[ranking[image_idx]] <= threshold) break;
1006         }
1007         /* if continuing for this iteration, increment the image counter for later normalization */
1008         ++total_images;
1009         /* for each image retrieve the objects contained */
1010         getObjects(images[ranking[image_idx]].id, img_objects, img_object_data);
1011         //check if the tested for object class is present
1012         if (getClassifierGroundTruthImage(obj_class, images[ranking[image_idx]].id))
1013         {
1014             //if the target class is present, assign fully to the target class element in the confusion matrix row
1015             output_values[target_idx] += 1.0;
1016             if (cond == CV_VOC_CCOND_RECALL) ++retrieved_hits;
1017         } else {
1018             //first delete all objects marked as difficult
1019             for (size_t obj_idx = 0; obj_idx < img_objects.size(); ++obj_idx)
1020             {
1021                 if (img_object_data[obj_idx].difficult == true)
1022                 {
1023                     vector<ObdObject>::iterator it1 = img_objects.begin();
1024                     std::advance(it1,obj_idx);
1025                     img_objects.erase(it1);
1026                     vector<VocObjectData>::iterator it2 = img_object_data.begin();
1027                     std::advance(it2,obj_idx);
1028                     img_object_data.erase(it2);
1029                     --obj_idx;
1030                 }
1031             }
1032             //if the target class is not present, add values to the confusion matrix row in equal proportions to all objects present in the image
1033             for (size_t obj_idx = 0; obj_idx < img_objects.size(); ++obj_idx)
1034             {
1035                 //find the index of the currently considered object
1036                 vector<string>::iterator class_idx_it = std::find(output_headers.begin(),output_headers.end(),img_objects[obj_idx].object_class);
1037                 //if the class name extracted from the ground truth file could not be found in the list of available classes, raise an exception
1038                 if (class_idx_it == output_headers.end())
1039                 {
1040                     string err_msg = "could not find object class '" + img_objects[obj_idx].object_class + "' specified in the ground truth file of '" + images[ranking[image_idx]].id +"'in list of valid classes.";
1041                     CV_Error(Error::StsError,err_msg.c_str());
1042                 }
1043                 /* convert iterator to index */
1044                 int class_idx = (int)std::distance(output_headers.begin(),class_idx_it);
1045                 //add to confusion matrix row in proportion
1046                 output_values[class_idx] += 1.f/static_cast<float>(img_objects.size());
1047             }
1048         }
1049         //check break conditions if breaking on certain level of recall
1050         if (cond == CV_VOC_CCOND_RECALL)
1051         {
1052             if(static_cast<float>(retrieved_hits)/static_cast<float>(total_relevant) >= threshold) break;
1053         }
1054     }
1055     /* finally, normalize confusion matrix row */
1056     for (vector<float>::iterator it = output_values.begin(); it < output_values.end(); ++it)
1057     {
1058         (*it) /= static_cast<float>(total_images);
1059     }
1060 }
1061
1062 // NOTE: doesn't ignore repeated detections
1063 void VocData::calcDetectorConfMatRow(const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<vector<float> >& scores, const vector<vector<Rect> >& bounding_boxes, const VocConfCond cond, const float threshold, vector<string>& output_headers, vector<float>& output_values, bool ignore_difficult)
1064 {
1065     CV_Assert(images.size() == scores.size());
1066     CV_Assert(images.size() == bounding_boxes.size());
1067
1068     //collapse scores and ground_truth vectors into 1D vectors to allow ranking
1069     /* define final flat vectors */
1070     vector<string> images_flat;
1071     vector<float> scores_flat;
1072     vector<Rect> bounding_boxes_flat;
1073     {
1074         /* first count how many objects to allow preallocation */
1075         int obj_count = 0;
1076         CV_Assert(scores.size() == bounding_boxes.size());
1077         for (size_t img_idx = 0; img_idx < scores.size(); ++img_idx)
1078         {
1079             CV_Assert(scores[img_idx].size() == bounding_boxes[img_idx].size());
1080             for (size_t obj_idx = 0; obj_idx < scores[img_idx].size(); ++obj_idx)
1081             {
1082                 ++obj_count;
1083             }
1084         }
1085         /* preallocate vectors */
1086         images_flat.resize(obj_count);
1087         scores_flat.resize(obj_count);
1088         bounding_boxes_flat.resize(obj_count);
1089         /* now copy across to preallocated vectors */
1090         int flat_pos = 0;
1091         for (size_t img_idx = 0; img_idx < scores.size(); ++img_idx)
1092         {
1093             for (size_t obj_idx = 0; obj_idx < scores[img_idx].size(); ++obj_idx)
1094             {
1095                 images_flat[flat_pos] = images[img_idx].id;
1096                 scores_flat[flat_pos] = scores[img_idx][obj_idx];
1097                 bounding_boxes_flat[flat_pos] = bounding_boxes[img_idx][obj_idx];
1098                 ++flat_pos;
1099             }
1100         }
1101     }
1102
1103     // SORT RESULTS BY THEIR SCORE
1104     /* 1. store sorting order in 'ranking' */
1105     vector<size_t> ranking;
1106     VocData::getSortOrder(scores_flat, ranking);
1107
1108     // CALCULATE CONFUSION MATRIX ENTRIES
1109     /* prepare object category headers */
1110     output_headers = m_object_classes;
1111     output_headers.push_back("background");
1112     vector<float>(output_headers.size(),0.0).swap(output_values);
1113
1114     /* prepare variables related to calculating recall if using the recall threshold */
1115     int retrieved_hits = 0;
1116     int total_relevant = 0;
1117     if (cond == CV_VOC_CCOND_RECALL)
1118     {
1119 //        vector<char> ground_truth;
1120 //        /* in order to calculate the total number of relevant images for normalization of recall
1121 //            it's necessary to extract the ground truth for the images under consideration */
1122 //        getClassifierGroundTruth(obj_class, images, ground_truth);
1123 //        total_relevant = std::count_if(ground_truth.begin(),ground_truth.end(),std::bind2nd(std::equal_to<bool>(),true));
1124         /* calculate the total number of objects in the ground truth for the current dataset */
1125         vector<ObdImage> gt_images;
1126         vector<char> gt_object_present;
1127         getClassImages(obj_class, dataset, gt_images, gt_object_present);
1128
1129         for (size_t image_idx = 0; image_idx < gt_images.size(); ++image_idx)
1130         {
1131             vector<ObdObject> gt_img_objects;
1132             vector<VocObjectData> gt_img_object_data;
1133             getObjects(gt_images[image_idx].id, gt_img_objects, gt_img_object_data);
1134             for (size_t obj_idx = 0; obj_idx < gt_img_objects.size(); ++obj_idx)
1135             {
1136                 if (gt_img_objects[obj_idx].object_class == obj_class)
1137                 {
1138                     if ((gt_img_object_data[obj_idx].difficult == false) || (ignore_difficult == false))
1139                         ++total_relevant;
1140                 }
1141             }
1142         }
1143     }
1144
1145     /* iterate through objects */
1146     vector<ObdObject> img_objects;
1147     vector<VocObjectData> img_object_data;
1148     int total_objects = 0;
1149     for (size_t image_idx = 0; image_idx < images.size(); ++image_idx)
1150     {
1151         /* if using the score as the break condition, check for it now */
1152         if (cond == CV_VOC_CCOND_SCORETHRESH)
1153         {
1154             if (scores_flat[ranking[image_idx]] <= threshold) break;
1155         }
1156         /* increment the image counter for later normalization */
1157         ++total_objects;
1158         /* for each image retrieve the objects contained */
1159         getObjects(images[ranking[image_idx]].id, img_objects, img_object_data);
1160
1161         //find the ground truth object which has the highest overlap score with the detected object
1162         float maxov = -1.0;
1163         int max_gt_obj_idx = -1;
1164         //-- for each detected object iterate through objects present in ground truth --
1165         for (size_t gt_obj_idx = 0; gt_obj_idx < img_objects.size(); ++gt_obj_idx)
1166         {
1167             //check difficulty flag
1168             if (ignore_difficult || (img_object_data[gt_obj_idx].difficult == false))
1169             {
1170                 //if the class matches, then check if the detected object and ground truth object overlap by a sufficient margin
1171                 float ov = testBoundingBoxesForOverlap(bounding_boxes_flat[ranking[image_idx]], img_objects[gt_obj_idx].boundingBox);
1172                 if (ov != -1.f)
1173                 {
1174                     //if all conditions are met store the overlap score and index (as objects are assigned to the highest scoring match)
1175                     if (ov > maxov)
1176                     {
1177                         maxov = ov;
1178                         max_gt_obj_idx = (int)gt_obj_idx;
1179                     }
1180                 }
1181             }
1182         }
1183
1184         //assign to appropriate object class if an object was detected
1185         if (maxov != -1.0)
1186         {
1187             //find the index of the currently considered object
1188             vector<string>::iterator class_idx_it = std::find(output_headers.begin(),output_headers.end(),img_objects[max_gt_obj_idx].object_class);
1189             //if the class name extracted from the ground truth file could not be found in the list of available classes, raise an exception
1190             if (class_idx_it == output_headers.end())
1191             {
1192                 string err_msg = "could not find object class '" + img_objects[max_gt_obj_idx].object_class + "' specified in the ground truth file of '" + images[ranking[image_idx]].id +"'in list of valid classes.";
1193                 CV_Error(Error::StsError,err_msg.c_str());
1194             }
1195             /* convert iterator to index */
1196             int class_idx = (int)std::distance(output_headers.begin(),class_idx_it);
1197             //add to confusion matrix row in proportion
1198             output_values[class_idx] += 1.0;
1199         } else {
1200             //otherwise assign to background class
1201             output_values[output_values.size()-1] += 1.0;
1202         }
1203
1204         //check break conditions if breaking on certain level of recall
1205         if (cond == CV_VOC_CCOND_RECALL)
1206         {
1207             if(static_cast<float>(retrieved_hits)/static_cast<float>(total_relevant) >= threshold) break;
1208         }
1209     }
1210
1211     /* finally, normalize confusion matrix row */
1212     for (vector<float>::iterator it = output_values.begin(); it < output_values.end(); ++it)
1213     {
1214         (*it) /= static_cast<float>(total_objects);
1215     }
1216 }
1217
1218 //Save Precision-Recall results to a p-r curve in GNUPlot format
1219 //--------------------------------------------------------------
1220 //INPUTS:
1221 // - output_file        The file to which to save the GNUPlot data file. If only a filename is specified, the data
1222 //                      file is saved to the standard VOC results directory.
1223 // - precision          Vector of precisions as returned from calcClassifier/DetectorPrecRecall
1224 // - recall             Vector of recalls as returned from calcClassifier/DetectorPrecRecall
1225 // - ap                ap as returned from calcClassifier/DetectorPrecRecall
1226 // - (title)            Title to use for the plot (if not specified, just the ap is printed as the title)
1227 //                      This also specifies the filename of the output file if printing to pdf
1228 // - (plot_type)        Specifies whether to instruct GNUPlot to save to a PDF file (CV_VOC_PLOT_PDF) or directly
1229 //                      to screen (CV_VOC_PLOT_SCREEN) in the datafile
1230 //NOTES:
1231 // The GNUPlot data file can be executed using GNUPlot from the commandline in the following way:
1232 //      >> GNUPlot <output_file>
1233 // This will then display the p-r curve on the screen or save it to a pdf file depending on plot_type
1234
1235 void VocData::savePrecRecallToGnuplot(const string& output_file, const vector<float>& precision, const vector<float>& recall, const float ap, const string title, const VocPlotType plot_type)
1236 {
1237     string output_file_std = checkFilenamePathsep(output_file);
1238
1239     //if no directory is specified, by default save the output file in the results directory
1240 //    if (output_file_std.find("/") == output_file_std.npos)
1241 //    {
1242 //        output_file_std = m_results_directory + output_file_std;
1243 //    }
1244
1245     std::ofstream plot_file(output_file_std.c_str());
1246
1247     if (plot_file.is_open())
1248     {
1249         plot_file << "set xrange [0:1]" << endl;
1250         plot_file << "set yrange [0:1]" << endl;
1251         plot_file << "set size square" << endl;
1252         string title_text = title;
1253         if (title_text.size() == 0) title_text = "Precision-Recall Curve";
1254         plot_file << "set title \"" << title_text << " (ap: " << ap << ")\"" << endl;
1255         plot_file << "set xlabel \"Recall\"" << endl;
1256         plot_file << "set ylabel \"Precision\"" << endl;
1257         plot_file << "set style data lines" << endl;
1258         plot_file << "set nokey" << endl;
1259         if (plot_type == CV_VOC_PLOT_PNG)
1260         {
1261             plot_file << "set terminal png" << endl;
1262             string pdf_filename;
1263             if (title.size() != 0)
1264             {
1265                 pdf_filename = title;
1266             } else {
1267                 pdf_filename = "prcurve";
1268             }
1269             plot_file << "set out \"" << title << ".png\"" << endl;
1270         }
1271         plot_file << "plot \"-\" using 1:2" << endl;
1272         plot_file << "# X Y" << endl;
1273         CV_Assert(precision.size() == recall.size());
1274         for (size_t i = 0; i < precision.size(); ++i)
1275         {
1276             plot_file << "  " << recall[i] << " " << precision[i] << endl;
1277         }
1278         plot_file << "end" << endl;
1279         if (plot_type == CV_VOC_PLOT_SCREEN)
1280         {
1281             plot_file << "pause -1" << endl;
1282         }
1283         plot_file.close();
1284     } else {
1285         string err_msg = "could not open plot file '" + output_file_std + "' for writing.";
1286         CV_Error(Error::StsError,err_msg.c_str());
1287     }
1288 }
1289
1290 void VocData::readClassifierGroundTruth(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<char>& object_present)
1291 {
1292     images.clear();
1293
1294     string gtFilename = m_class_imageset_path;
1295     gtFilename.replace(gtFilename.find("%s"),2,obj_class);
1296     if (dataset == CV_OBD_TRAIN)
1297     {
1298         gtFilename.replace(gtFilename.find("%s"),2,m_train_set);
1299     } else {
1300         gtFilename.replace(gtFilename.find("%s"),2,m_test_set);
1301     }
1302
1303     vector<string> image_codes;
1304     readClassifierGroundTruth(gtFilename, image_codes, object_present);
1305
1306     convertImageCodesToObdImages(image_codes, images);
1307 }
1308
1309 void VocData::readClassifierResultsFile(const std:: string& input_file, vector<ObdImage>& images, vector<float>& scores)
1310 {
1311     images.clear();
1312
1313     string input_file_std = checkFilenamePathsep(input_file);
1314
1315     //if no directory is specified, by default search for the input file in the results directory
1316 //    if (input_file_std.find("/") == input_file_std.npos)
1317 //    {
1318 //        input_file_std = m_results_directory + input_file_std;
1319 //    }
1320
1321     vector<string> image_codes;
1322     readClassifierResultsFile(input_file_std, image_codes, scores);
1323
1324     convertImageCodesToObdImages(image_codes, images);
1325 }
1326
1327 void VocData::readDetectorResultsFile(const string& input_file, vector<ObdImage>& images, vector<vector<float> >& scores, vector<vector<Rect> >& bounding_boxes)
1328 {
1329     images.clear();
1330
1331     string input_file_std = checkFilenamePathsep(input_file);
1332
1333     //if no directory is specified, by default search for the input file in the results directory
1334 //    if (input_file_std.find("/") == input_file_std.npos)
1335 //    {
1336 //        input_file_std = m_results_directory + input_file_std;
1337 //    }
1338
1339     vector<string> image_codes;
1340     readDetectorResultsFile(input_file_std, image_codes, scores, bounding_boxes);
1341
1342     convertImageCodesToObdImages(image_codes, images);
1343 }
1344
1345 const vector<string>& VocData::getObjectClasses()
1346 {
1347     return m_object_classes;
1348 }
1349
1350 //string VocData::getResultsDirectory()
1351 //{
1352 //    return m_results_directory;
1353 //}
1354
1355 //---------------------------------------------------------
1356 // Protected Functions ------------------------------------
1357 //---------------------------------------------------------
1358
1359 static string getVocName( const string& vocPath )
1360 {
1361     size_t found = vocPath.rfind( '/' );
1362     if( found == string::npos )
1363     {
1364         found = vocPath.rfind( '\\' );
1365         if( found == string::npos )
1366             return vocPath;
1367     }
1368     return vocPath.substr(found + 1, vocPath.size() - found);
1369 }
1370
1371 void VocData::initVoc( const string& vocPath, const bool useTestDataset )
1372 {
1373     initVoc2007to2010( vocPath, useTestDataset );
1374 }
1375
1376 //Initialize file paths and settings for the VOC 2010 dataset
1377 //-----------------------------------------------------------
1378 void VocData::initVoc2007to2010( const string& vocPath, const bool useTestDataset )
1379 {
1380     //check format of root directory and modify if necessary
1381
1382     m_vocName = getVocName( vocPath );
1383
1384     CV_Assert( !m_vocName.compare("VOC2007") || !m_vocName.compare("VOC2008") ||
1385                !m_vocName.compare("VOC2009") || !m_vocName.compare("VOC2010") );
1386
1387     m_vocPath = checkFilenamePathsep( vocPath, true );
1388
1389     if (useTestDataset)
1390     {
1391         m_train_set = "trainval";
1392         m_test_set = "test";
1393     } else {
1394         m_train_set = "train";
1395         m_test_set = "val";
1396     }
1397
1398     // initialize main classification/detection challenge paths
1399     m_annotation_path = m_vocPath + "/Annotations/%s.xml";
1400     m_image_path = m_vocPath + "/JPEGImages/%s.jpg";
1401     m_imageset_path = m_vocPath + "/ImageSets/Main/%s.txt";
1402     m_class_imageset_path = m_vocPath + "/ImageSets/Main/%s_%s.txt";
1403
1404     //define available object_classes for VOC2010 dataset
1405     m_object_classes.push_back("aeroplane");
1406     m_object_classes.push_back("bicycle");
1407     m_object_classes.push_back("bird");
1408     m_object_classes.push_back("boat");
1409     m_object_classes.push_back("bottle");
1410     m_object_classes.push_back("bus");
1411     m_object_classes.push_back("car");
1412     m_object_classes.push_back("cat");
1413     m_object_classes.push_back("chair");
1414     m_object_classes.push_back("cow");
1415     m_object_classes.push_back("diningtable");
1416     m_object_classes.push_back("dog");
1417     m_object_classes.push_back("horse");
1418     m_object_classes.push_back("motorbike");
1419     m_object_classes.push_back("person");
1420     m_object_classes.push_back("pottedplant");
1421     m_object_classes.push_back("sheep");
1422     m_object_classes.push_back("sofa");
1423     m_object_classes.push_back("train");
1424     m_object_classes.push_back("tvmonitor");
1425
1426     m_min_overlap = 0.5;
1427
1428     //up until VOC 2010, ap was calculated by sampling p-r curve, not taking complete curve
1429     m_sampled_ap = ((m_vocName == "VOC2007") || (m_vocName == "VOC2008") || (m_vocName == "VOC2009"));
1430 }
1431
1432 //Read a VOC classification ground truth text file for a given object class and dataset
1433 //-------------------------------------------------------------------------------------
1434 //INPUTS:
1435 // - filename           The path of the text file to read
1436 //OUTPUTS:
1437 // - image_codes        VOC image codes extracted from the GT file in the form 20XX_XXXXXX where the first four
1438 //                          digits specify the year of the dataset, and the last group specifies a unique ID
1439 // - object_present     For each image in the 'image_codes' array, specifies whether the object class described
1440 //                          in the loaded GT file is present or not
1441 void VocData::readClassifierGroundTruth(const string& filename, vector<string>& image_codes, vector<char>& object_present)
1442 {
1443     image_codes.clear();
1444     object_present.clear();
1445
1446     std::ifstream gtfile(filename.c_str());
1447     if (!gtfile.is_open())
1448     {
1449         string err_msg = "could not open VOC ground truth textfile '" + filename + "'.";
1450         CV_Error(Error::StsError,err_msg.c_str());
1451     }
1452
1453     string line;
1454     string image;
1455     int obj_present = 0;
1456     while (!gtfile.eof())
1457     {
1458         std::getline(gtfile,line);
1459         std::istringstream iss(line);
1460         iss >> image >> obj_present;
1461         if (!iss.fail())
1462         {
1463             image_codes.push_back(image);
1464             object_present.push_back(obj_present == 1);
1465         } else {
1466             if (!gtfile.eof()) CV_Error(Error::StsParseError,"error parsing VOC ground truth textfile.");
1467         }
1468     }
1469     gtfile.close();
1470 }
1471
1472 void VocData::readClassifierResultsFile(const string& input_file, vector<string>& image_codes, vector<float>& scores)
1473 {
1474     //check if results file exists
1475     std::ifstream result_file(input_file.c_str());
1476     if (result_file.is_open())
1477     {
1478         string line;
1479         string image;
1480         float score;
1481         //read in the results file
1482         while (!result_file.eof())
1483         {
1484             std::getline(result_file,line);
1485             std::istringstream iss(line);
1486             iss >> image >> score;
1487             if (!iss.fail())
1488             {
1489                 image_codes.push_back(image);
1490                 scores.push_back(score);
1491             } else {
1492                 if(!result_file.eof()) CV_Error(Error::StsParseError,"error parsing VOC classifier results file.");
1493             }
1494         }
1495         result_file.close();
1496     } else {
1497         string err_msg = "could not open classifier results file '" + input_file + "' for reading.";
1498         CV_Error(Error::StsError,err_msg.c_str());
1499     }
1500 }
1501
1502 void VocData::readDetectorResultsFile(const string& input_file, vector<string>& image_codes, vector<vector<float> >& scores, vector<vector<Rect> >& bounding_boxes)
1503 {
1504     image_codes.clear();
1505     scores.clear();
1506     bounding_boxes.clear();
1507
1508     //check if results file exists
1509     std::ifstream result_file(input_file.c_str());
1510     if (result_file.is_open())
1511     {
1512         string line;
1513         string image;
1514         Rect bounding_box;
1515         float score;
1516         //read in the results file
1517         while (!result_file.eof())
1518         {
1519             std::getline(result_file,line);
1520             std::istringstream iss(line);
1521             iss >> image >> score >> bounding_box.x >> bounding_box.y >> bounding_box.width >> bounding_box.height;
1522             if (!iss.fail())
1523             {
1524                 //convert right and bottom positions to width and height
1525                 bounding_box.width -= bounding_box.x;
1526                 bounding_box.height -= bounding_box.y;
1527                 //convert to 0-indexing
1528                 bounding_box.x -= 1;
1529                 bounding_box.y -= 1;
1530                 //store in output vectors
1531                 /* first check if the current image code has been seen before */
1532                 vector<string>::iterator image_codes_it = std::find(image_codes.begin(),image_codes.end(),image);
1533                 if (image_codes_it == image_codes.end())
1534                 {
1535                     image_codes.push_back(image);
1536                     vector<float> score_vect(1);
1537                     score_vect[0] = score;
1538                     scores.push_back(score_vect);
1539                     vector<Rect> bounding_box_vect(1);
1540                     bounding_box_vect[0] = bounding_box;
1541                     bounding_boxes.push_back(bounding_box_vect);
1542                 } else {
1543                     /* if the image index has been seen before, add the current object below it in the 2D arrays */
1544                     int image_idx = (int)std::distance(image_codes.begin(),image_codes_it);
1545                     scores[image_idx].push_back(score);
1546                     bounding_boxes[image_idx].push_back(bounding_box);
1547                 }
1548             } else {
1549                 if(!result_file.eof()) CV_Error(Error::StsParseError,"error parsing VOC detector results file.");
1550             }
1551         }
1552         result_file.close();
1553     } else {
1554         string err_msg = "could not open detector results file '" + input_file + "' for reading.";
1555         CV_Error(Error::StsError,err_msg.c_str());
1556     }
1557 }
1558
1559
1560 //Read a VOC annotation xml file for a given image
1561 //------------------------------------------------
1562 //INPUTS:
1563 // - filename           The path of the xml file to read
1564 //OUTPUTS:
1565 // - objects            Array of VocObject describing all object instances present in the given image
1566 void VocData::extractVocObjects(const string filename, vector<ObdObject>& objects, vector<VocObjectData>& object_data)
1567 {
1568 #ifdef PR_DEBUG
1569     int block = 1;
1570     cout << "SAMPLE VOC OBJECT EXTRACTION for " << filename << ":" << endl;
1571 #endif
1572     objects.clear();
1573     object_data.clear();
1574
1575     string contents, object_contents, tag_contents;
1576
1577     readFileToString(filename, contents);
1578
1579     //keep on extracting 'object' blocks until no more can be found
1580     if (extractXMLBlock(contents, "annotation", 0, contents) != -1)
1581     {
1582         int searchpos = 0;
1583         searchpos = extractXMLBlock(contents, "object", searchpos, object_contents);
1584         while (searchpos != -1)
1585         {
1586 #ifdef PR_DEBUG
1587             cout << "SEARCHPOS:" << searchpos << endl;
1588             cout << "start block " << block << " ---------" << endl;
1589             cout << object_contents << endl;
1590             cout << "end block " << block << " -----------" << endl;
1591             ++block;
1592 #endif
1593
1594             ObdObject object;
1595             VocObjectData object_d;
1596
1597             //object class -------------
1598
1599             if (extractXMLBlock(object_contents, "name", 0, tag_contents) == -1) CV_Error(Error::StsError,"missing <name> tag in object definition of '" + filename + "'");
1600             object.object_class.swap(tag_contents);
1601
1602             //object bounding box -------------
1603
1604             int xmax, xmin, ymax, ymin;
1605
1606             if (extractXMLBlock(object_contents, "xmax", 0, tag_contents) == -1) CV_Error(Error::StsError,"missing <xmax> tag in object definition of '" + filename + "'");
1607             xmax = stringToInteger(tag_contents);
1608
1609             if (extractXMLBlock(object_contents, "xmin", 0, tag_contents) == -1) CV_Error(Error::StsError,"missing <xmin> tag in object definition of '" + filename + "'");
1610             xmin = stringToInteger(tag_contents);
1611
1612             if (extractXMLBlock(object_contents, "ymax", 0, tag_contents) == -1) CV_Error(Error::StsError,"missing <ymax> tag in object definition of '" + filename + "'");
1613             ymax = stringToInteger(tag_contents);
1614
1615             if (extractXMLBlock(object_contents, "ymin", 0, tag_contents) == -1) CV_Error(Error::StsError,"missing <ymin> tag in object definition of '" + filename + "'");
1616             ymin = stringToInteger(tag_contents);
1617
1618             object.boundingBox.x = xmin-1;      //convert to 0-based indexing
1619             object.boundingBox.width = xmax - xmin;
1620             object.boundingBox.y = ymin-1;
1621             object.boundingBox.height = ymax - ymin;
1622
1623             CV_Assert(xmin != 0);
1624             CV_Assert(xmax > xmin);
1625             CV_Assert(ymin != 0);
1626             CV_Assert(ymax > ymin);
1627
1628
1629             //object tags -------------
1630
1631             if (extractXMLBlock(object_contents, "difficult", 0, tag_contents) != -1)
1632             {
1633                 object_d.difficult = (tag_contents == "1");
1634             } else object_d.difficult = false;
1635             if (extractXMLBlock(object_contents, "occluded", 0, tag_contents) != -1)
1636             {
1637                 object_d.occluded = (tag_contents == "1");
1638             } else object_d.occluded = false;
1639             if (extractXMLBlock(object_contents, "truncated", 0, tag_contents) != -1)
1640             {
1641                 object_d.truncated = (tag_contents == "1");
1642             } else object_d.truncated = false;
1643             if (extractXMLBlock(object_contents, "pose", 0, tag_contents) != -1)
1644             {
1645                 if (tag_contents == "Frontal") object_d.pose = CV_VOC_POSE_FRONTAL;
1646                 if (tag_contents == "Rear") object_d.pose = CV_VOC_POSE_REAR;
1647                 if (tag_contents == "Left") object_d.pose = CV_VOC_POSE_LEFT;
1648                 if (tag_contents == "Right") object_d.pose = CV_VOC_POSE_RIGHT;
1649             }
1650
1651             //add to array of objects
1652             objects.push_back(object);
1653             object_data.push_back(object_d);
1654
1655             //extract next 'object' block from file if it exists
1656             searchpos = extractXMLBlock(contents, "object", searchpos, object_contents);
1657         }
1658     }
1659 }
1660
1661 //Converts an image identifier string in the format YYYY_XXXXXX to a single index integer of form XXXXXXYYYY
1662 //where Y represents a year and returns the image path
1663 //----------------------------------------------------------------------------------------------------------
1664 string VocData::getImagePath(const string& input_str)
1665 {
1666     string path = m_image_path;
1667     path.replace(path.find("%s"),2,input_str);
1668     return path;
1669 }
1670
1671 //Tests two boundary boxes for overlap (using the intersection over union metric) and returns the overlap if the objects
1672 //defined by the two bounding boxes are considered to be matched according to the criterion outlined in
1673 //the VOC documentation [namely intersection/union > some threshold] otherwise returns -1.0 (no match)
1674 //----------------------------------------------------------------------------------------------------------
1675 float VocData::testBoundingBoxesForOverlap(const Rect detection, const Rect ground_truth)
1676 {
1677     int detection_x2 = detection.x + detection.width;
1678     int detection_y2 = detection.y + detection.height;
1679     int ground_truth_x2 = ground_truth.x + ground_truth.width;
1680     int ground_truth_y2 = ground_truth.y + ground_truth.height;
1681     //first calculate the boundaries of the intersection of the rectangles
1682     int intersection_x = std::max(detection.x, ground_truth.x); //rightmost left
1683     int intersection_y = std::max(detection.y, ground_truth.y); //bottommost top
1684     int intersection_x2 = std::min(detection_x2, ground_truth_x2); //leftmost right
1685     int intersection_y2 = std::min(detection_y2, ground_truth_y2); //topmost bottom
1686     //then calculate the width and height of the intersection rect
1687     int intersection_width = intersection_x2 - intersection_x + 1;
1688     int intersection_height = intersection_y2 - intersection_y + 1;
1689     //if there is no overlap then return false straight away
1690     if ((intersection_width <= 0) || (intersection_height <= 0)) return -1.0;
1691     //otherwise calculate the intersection
1692     int intersection_area = intersection_width*intersection_height;
1693
1694     //now calculate the union
1695     int union_area = (detection.width+1)*(detection.height+1) + (ground_truth.width+1)*(ground_truth.height+1) - intersection_area;
1696
1697     //calculate the intersection over union and use as threshold as per VOC documentation
1698     float overlap = static_cast<float>(intersection_area)/static_cast<float>(union_area);
1699     if (overlap > m_min_overlap)
1700     {
1701         return overlap;
1702     } else {
1703         return -1.0;
1704     }
1705 }
1706
1707 //Extracts the object class and dataset from the filename of a VOC standard results text file, which takes
1708 //the format 'comp<n>_{cls/det}_<dataset>_<objclass>.txt'
1709 //----------------------------------------------------------------------------------------------------------
1710 void VocData::extractDataFromResultsFilename(const string& input_file, string& class_name, string& dataset_name)
1711 {
1712     string input_file_std = checkFilenamePathsep(input_file);
1713
1714     size_t fnamestart = input_file_std.rfind("/");
1715     size_t fnameend = input_file_std.rfind(".txt");
1716
1717     if ((fnamestart == input_file_std.npos) || (fnameend == input_file_std.npos))
1718         CV_Error(Error::StsError,"Could not extract filename of results file.");
1719
1720     ++fnamestart;
1721     if (fnamestart >= fnameend)
1722         CV_Error(Error::StsError,"Could not extract filename of results file.");
1723
1724     //extract dataset and class names, triggering exception if the filename format is not correct
1725     string filename = input_file_std.substr(fnamestart, fnameend-fnamestart);
1726     size_t datasetstart = filename.find("_");
1727     datasetstart = filename.find("_",datasetstart+1);
1728     size_t classstart = filename.find("_",datasetstart+1);
1729     //allow for appended index after a further '_' by discarding this part if it exists
1730     size_t classend = filename.find("_",classstart+1);
1731     if (classend == filename.npos) classend = filename.size();
1732     if ((datasetstart == filename.npos) || (classstart == filename.npos))
1733         CV_Error(Error::StsError,"Error parsing results filename. Is it in standard format of 'comp<n>_{cls/det}_<dataset>_<objclass>.txt'?");
1734     ++datasetstart;
1735     ++classstart;
1736     if (((datasetstart-classstart) < 1) || ((classend-datasetstart) < 1))
1737         CV_Error(Error::StsError,"Error parsing results filename. Is it in standard format of 'comp<n>_{cls/det}_<dataset>_<objclass>.txt'?");
1738
1739     dataset_name = filename.substr(datasetstart,classstart-datasetstart-1);
1740     class_name = filename.substr(classstart,classend-classstart);
1741 }
1742
1743 bool VocData::getClassifierGroundTruthImage(const string& obj_class, const string& id)
1744 {
1745     /* if the classifier ground truth data for all images of the current class has not been loaded yet, load it now */
1746     if (m_classifier_gt_all_ids.empty() || (m_classifier_gt_class != obj_class))
1747     {
1748         m_classifier_gt_all_ids.clear();
1749         m_classifier_gt_all_present.clear();
1750         m_classifier_gt_class = obj_class;
1751         for (int i=0; i<2; ++i) //run twice (once over test set and once over training set)
1752         {
1753             //generate the filename of the classification ground-truth textfile for the object class
1754             string gtFilename = m_class_imageset_path;
1755             gtFilename.replace(gtFilename.find("%s"),2,obj_class);
1756             if (i == 0)
1757             {
1758                 gtFilename.replace(gtFilename.find("%s"),2,m_train_set);
1759             } else {
1760                 gtFilename.replace(gtFilename.find("%s"),2,m_test_set);
1761             }
1762
1763             //parse the ground truth file, storing in two separate vectors
1764             //for the image code and the ground truth value
1765             vector<string> image_codes;
1766             vector<char> object_present;
1767             readClassifierGroundTruth(gtFilename, image_codes, object_present);
1768
1769             m_classifier_gt_all_ids.insert(m_classifier_gt_all_ids.end(),image_codes.begin(),image_codes.end());
1770             m_classifier_gt_all_present.insert(m_classifier_gt_all_present.end(),object_present.begin(),object_present.end());
1771
1772             CV_Assert(m_classifier_gt_all_ids.size() == m_classifier_gt_all_present.size());
1773         }
1774     }
1775
1776
1777     //search for the image code
1778     vector<string>::iterator it = find (m_classifier_gt_all_ids.begin(), m_classifier_gt_all_ids.end(), id);
1779     if (it != m_classifier_gt_all_ids.end())
1780     {
1781         //image found, so return corresponding ground truth
1782         return m_classifier_gt_all_present[std::distance(m_classifier_gt_all_ids.begin(),it)] != 0;
1783     } else {
1784         string err_msg = "could not find classifier ground truth for image '" + id + "' and class '" + obj_class + "'";
1785         CV_Error(Error::StsError,err_msg.c_str());
1786     }
1787
1788     return true;
1789 }
1790
1791 //-------------------------------------------------------------------
1792 // Protected Functions (utility) ------------------------------------
1793 //-------------------------------------------------------------------
1794
1795 //returns a vector containing indexes of the input vector in sorted ascending/descending order
1796 void VocData::getSortOrder(const vector<float>& values, vector<size_t>& order, bool descending)
1797 {
1798     /* 1. store sorting order in 'order_pair' */
1799     vector<std::pair<size_t, vector<float>::const_iterator> > order_pair(values.size());
1800
1801     size_t n = 0;
1802     for (vector<float>::const_iterator it = values.begin(); it != values.end(); ++it, ++n)
1803         order_pair[n] = make_pair(n, it);
1804
1805     std::sort(order_pair.begin(),order_pair.end(),orderingSorter());
1806     if (descending == false) std::reverse(order_pair.begin(),order_pair.end());
1807
1808     vector<size_t>(order_pair.size()).swap(order);
1809     for (size_t i = 0; i < order_pair.size(); ++i)
1810     {
1811         order[i] = order_pair[i].first;
1812     }
1813 }
1814
1815 void VocData::readFileToString(const string filename, string& file_contents)
1816 {
1817     std::ifstream ifs(filename.c_str());
1818     if (!ifs.is_open()) CV_Error(Error::StsError,"could not open text file");
1819
1820     stringstream oss;
1821     oss << ifs.rdbuf();
1822
1823     file_contents = oss.str();
1824 }
1825
1826 int VocData::stringToInteger(const string input_str)
1827 {
1828     int result = 0;
1829
1830     stringstream ss(input_str);
1831     if ((ss >> result).fail())
1832     {
1833         CV_Error(Error::StsBadArg,"could not perform string to integer conversion");
1834     }
1835     return result;
1836 }
1837
1838 string VocData::integerToString(const int input_int)
1839 {
1840     string result;
1841
1842     stringstream ss;
1843     if ((ss << input_int).fail())
1844     {
1845         CV_Error(Error::StsBadArg,"could not perform integer to string conversion");
1846     }
1847     result = ss.str();
1848     return result;
1849 }
1850
1851 string VocData::checkFilenamePathsep( const string filename, bool add_trailing_slash )
1852 {
1853     string filename_new = filename;
1854
1855     size_t pos = filename_new.find("\\\\");
1856     while (pos != filename_new.npos)
1857     {
1858         filename_new.replace(pos,2,"/");
1859         pos = filename_new.find("\\\\", pos);
1860     }
1861     pos = filename_new.find("\\");
1862     while (pos != filename_new.npos)
1863     {
1864         filename_new.replace(pos,1,"/");
1865         pos = filename_new.find("\\", pos);
1866     }
1867     if (add_trailing_slash)
1868     {
1869         //add training slash if this is missing
1870         if (filename_new.rfind("/") != filename_new.length()-1) filename_new += "/";
1871     }
1872
1873     return filename_new;
1874 }
1875
1876 void VocData::convertImageCodesToObdImages(const vector<string>& image_codes, vector<ObdImage>& images)
1877 {
1878     images.clear();
1879     images.reserve(image_codes.size());
1880
1881     string path;
1882     //transfer to output arrays
1883     for (size_t i = 0; i < image_codes.size(); ++i)
1884     {
1885         //generate image path and indices from extracted string code
1886         path = getImagePath(image_codes[i]);
1887         images.push_back(ObdImage(image_codes[i], path));
1888     }
1889 }
1890
1891 //Extract text from within a given tag from an XML file
1892 //-----------------------------------------------------
1893 //INPUTS:
1894 // - src            XML source file
1895 // - tag            XML tag delimiting block to extract
1896 // - searchpos      position within src at which to start search
1897 //OUTPUTS:
1898 // - tag_contents   text extracted between <tag> and </tag> tags
1899 //RETURN VALUE:
1900 // - the position of the final character extracted in tag_contents within src
1901 //      (can be used to call extractXMLBlock recursively to extract multiple blocks)
1902 //      returns -1 if the tag could not be found
1903 int VocData::extractXMLBlock(const string src, const string tag, const int searchpos, string& tag_contents)
1904 {
1905     size_t startpos, next_startpos, endpos;
1906     int embed_count = 1;
1907
1908     //find position of opening tag
1909     startpos = src.find("<" + tag + ">", searchpos);
1910     if (startpos == string::npos) return -1;
1911
1912     //initialize endpos -
1913     // start searching for end tag anywhere after opening tag
1914     endpos = startpos;
1915
1916     //find position of next opening tag
1917     next_startpos = src.find("<" + tag + ">", startpos+1);
1918
1919     //match opening tags with closing tags, and only
1920     //accept final closing tag of same level as original
1921     //opening tag
1922     while (embed_count > 0)
1923     {
1924         endpos = src.find("</" + tag + ">", endpos+1);
1925         if (endpos == string::npos) return -1;
1926
1927         //the next code is only executed if there are embedded tags with the same name
1928         if (next_startpos != string::npos)
1929         {
1930             while (next_startpos<endpos)
1931             {
1932                 //counting embedded start tags
1933                 ++embed_count;
1934                 next_startpos = src.find("<" + tag + ">", next_startpos+1);
1935                 if (next_startpos == string::npos) break;
1936             }
1937         }
1938         //passing end tag so decrement nesting level
1939         --embed_count;
1940     }
1941
1942     //finally, extract the tag region
1943     startpos += tag.length() + 2;
1944     if (startpos > src.length()) return -1;
1945     if (endpos > src.length()) return -1;
1946     tag_contents = src.substr(startpos,endpos-startpos);
1947     return static_cast<int>(endpos);
1948 }
1949
1950 /****************************************************************************************\
1951 *                            Sample on image classification                             *
1952 \****************************************************************************************/
1953 //
1954 // This part of the code was a little refactor
1955 //
1956 struct DDMParams
1957 {
1958     DDMParams() : detectorType("SURF"), descriptorType("SURF"), matcherType("BruteForce") {}
1959     DDMParams( const string _detectorType, const string _descriptorType, const string& _matcherType ) :
1960         detectorType(_detectorType), descriptorType(_descriptorType), matcherType(_matcherType){}
1961     void read( const FileNode& fn )
1962     {
1963         fn["detectorType"] >> detectorType;
1964         fn["descriptorType"] >> descriptorType;
1965         fn["matcherType"] >> matcherType;
1966     }
1967     void write( FileStorage& fs ) const
1968     {
1969         fs << "detectorType" << detectorType;
1970         fs << "descriptorType" << descriptorType;
1971         fs << "matcherType" << matcherType;
1972     }
1973     void print() const
1974     {
1975         cout << "detectorType: " << detectorType << endl;
1976         cout << "descriptorType: " << descriptorType << endl;
1977         cout << "matcherType: " << matcherType << endl;
1978     }
1979
1980     string detectorType;
1981     string descriptorType;
1982     string matcherType;
1983 };
1984
1985 struct VocabTrainParams
1986 {
1987     VocabTrainParams() : trainObjClass("chair"), vocabSize(1000), memoryUse(200), descProportion(0.3f) {}
1988     VocabTrainParams( const string _trainObjClass, size_t _vocabSize, size_t _memoryUse, float _descProportion ) :
1989             trainObjClass(_trainObjClass), vocabSize((int)_vocabSize), memoryUse((int)_memoryUse), descProportion(_descProportion) {}
1990     void read( const FileNode& fn )
1991     {
1992         fn["trainObjClass"] >> trainObjClass;
1993         fn["vocabSize"] >> vocabSize;
1994         fn["memoryUse"] >> memoryUse;
1995         fn["descProportion"] >> descProportion;
1996     }
1997     void write( FileStorage& fs ) const
1998     {
1999         fs << "trainObjClass" << trainObjClass;
2000         fs << "vocabSize" << vocabSize;
2001         fs << "memoryUse" << memoryUse;
2002         fs << "descProportion" << descProportion;
2003     }
2004     void print() const
2005     {
2006         cout << "trainObjClass: " << trainObjClass << endl;
2007         cout << "vocabSize: " << vocabSize << endl;
2008         cout << "memoryUse: " << memoryUse << endl;
2009         cout << "descProportion: " << descProportion << endl;
2010     }
2011
2012
2013     string trainObjClass; // Object class used for training visual vocabulary.
2014                           // It shouldn't matter which object class is specified here - visual vocab will still be the same.
2015     int vocabSize; //number of visual words in vocabulary to train
2016     int memoryUse; // Memory to preallocate (in MB) when training vocab.
2017                    // Change this depending on the size of the dataset/available memory.
2018     float descProportion; // Specifies the number of descriptors to use from each image as a proportion of the total num descs.
2019 };
2020
2021 struct SVMTrainParamsExt
2022 {
2023     SVMTrainParamsExt() : descPercent(0.5f), targetRatio(0.4f), balanceClasses(true) {}
2024     SVMTrainParamsExt( float _descPercent, float _targetRatio, bool _balanceClasses ) :
2025             descPercent(_descPercent), targetRatio(_targetRatio), balanceClasses(_balanceClasses) {}
2026     void read( const FileNode& fn )
2027     {
2028         fn["descPercent"] >> descPercent;
2029         fn["targetRatio"] >> targetRatio;
2030         fn["balanceClasses"] >> balanceClasses;
2031     }
2032     void write( FileStorage& fs ) const
2033     {
2034         fs << "descPercent" << descPercent;
2035         fs << "targetRatio" << targetRatio;
2036         fs << "balanceClasses" << balanceClasses;
2037     }
2038     void print() const
2039     {
2040         cout << "descPercent: " << descPercent << endl;
2041         cout << "targetRatio: " << targetRatio << endl;
2042         cout << "balanceClasses: " << balanceClasses << endl;
2043     }
2044
2045     float descPercent; // Percentage of extracted descriptors to use for training.
2046     float targetRatio; // Try to get this ratio of positive to negative samples (minimum).
2047     bool balanceClasses;    // Balance class weights by number of samples in each (if true cSvmTrainTargetRatio is ignored).
2048 };
2049
2050 static void readUsedParams( const FileNode& fn, string& vocName, DDMParams& ddmParams, VocabTrainParams& vocabTrainParams, SVMTrainParamsExt& svmTrainParamsExt )
2051 {
2052     fn["vocName"] >> vocName;
2053
2054     FileNode currFn = fn;
2055
2056     currFn = fn["ddmParams"];
2057     ddmParams.read( currFn );
2058
2059     currFn = fn["vocabTrainParams"];
2060     vocabTrainParams.read( currFn );
2061
2062     currFn = fn["svmTrainParamsExt"];
2063     svmTrainParamsExt.read( currFn );
2064 }
2065
2066 static void writeUsedParams( FileStorage& fs, const string& vocName, const DDMParams& ddmParams, const VocabTrainParams& vocabTrainParams, const SVMTrainParamsExt& svmTrainParamsExt )
2067 {
2068     fs << "vocName" << vocName;
2069
2070     fs << "ddmParams" << "{";
2071     ddmParams.write(fs);
2072     fs << "}";
2073
2074     fs << "vocabTrainParams" << "{";
2075     vocabTrainParams.write(fs);
2076     fs << "}";
2077
2078     fs << "svmTrainParamsExt" << "{";
2079     svmTrainParamsExt.write(fs);
2080     fs << "}";
2081 }
2082
2083 static void printUsedParams( const string& vocPath, const string& resDir,
2084                       const DDMParams& ddmParams, const VocabTrainParams& vocabTrainParams,
2085                       const SVMTrainParamsExt& svmTrainParamsExt )
2086 {
2087     cout << "CURRENT CONFIGURATION" << endl;
2088     cout << "----------------------------------------------------------------" << endl;
2089     cout << "vocPath: " << vocPath << endl;
2090     cout << "resDir: " << resDir << endl;
2091     cout << endl; ddmParams.print();
2092     cout << endl; vocabTrainParams.print();
2093     cout << endl; svmTrainParamsExt.print();
2094     cout << "----------------------------------------------------------------" << endl << endl;
2095 }
2096
2097 static bool readVocabulary( const string& filename, Mat& vocabulary )
2098 {
2099     cout << "Reading vocabulary...";
2100     FileStorage fs( filename, FileStorage::READ );
2101     if( fs.isOpened() )
2102     {
2103         fs["vocabulary"] >> vocabulary;
2104         cout << "done" << endl;
2105         return true;
2106     }
2107     return false;
2108 }
2109
2110 static bool writeVocabulary( const string& filename, const Mat& vocabulary )
2111 {
2112     cout << "Saving vocabulary..." << endl;
2113     FileStorage fs( filename, FileStorage::WRITE );
2114     if( fs.isOpened() )
2115     {
2116         fs << "vocabulary" << vocabulary;
2117         return true;
2118     }
2119     return false;
2120 }
2121
2122 static Mat trainVocabulary( const string& filename, VocData& vocData, const VocabTrainParams& trainParams,
2123                      const Ptr<FeatureDetector>& fdetector, const Ptr<DescriptorExtractor>& dextractor )
2124 {
2125     Mat vocabulary;
2126     if( !readVocabulary( filename, vocabulary) )
2127     {
2128         CV_Assert( dextractor->descriptorType() == CV_32FC1 );
2129         const int elemSize = CV_ELEM_SIZE(dextractor->descriptorType());
2130         const int descByteSize = dextractor->descriptorSize() * elemSize;
2131         const int bytesInMB = 1048576;
2132         const int maxDescCount = (trainParams.memoryUse * bytesInMB) / descByteSize; // Total number of descs to use for training.
2133
2134         cout << "Extracting VOC data..." << endl;
2135         vector<ObdImage> images;
2136         vector<char> objectPresent;
2137         vocData.getClassImages( trainParams.trainObjClass, CV_OBD_TRAIN, images, objectPresent );
2138
2139         cout << "Computing descriptors..." << endl;
2140         RNG& rng = theRNG();
2141         TermCriteria terminate_criterion;
2142         terminate_criterion.epsilon = FLT_EPSILON;
2143         BOWKMeansTrainer bowTrainer( trainParams.vocabSize, terminate_criterion, 3, KMEANS_PP_CENTERS );
2144
2145         while( images.size() > 0 )
2146         {
2147             if( bowTrainer.descriptorsCount() > maxDescCount )
2148             {
2149 #ifdef DEBUG_DESC_PROGRESS
2150                 cout << "Breaking due to full memory ( descriptors count = " << bowTrainer.descriptorsCount()
2151                         << "; descriptor size in bytes = " << descByteSize << "; all used memory = "
2152                         << bowTrainer.descriptorsCount()*descByteSize << endl;
2153 #endif
2154                 break;
2155             }
2156
2157             // Randomly pick an image from the dataset which hasn't yet been seen
2158             // and compute the descriptors from that image.
2159             int randImgIdx = rng( (unsigned)images.size() );
2160             Mat colorImage = imread( images[randImgIdx].path );
2161             vector<KeyPoint> imageKeypoints;
2162             fdetector->detect( colorImage, imageKeypoints );
2163             Mat imageDescriptors;
2164             dextractor->compute( colorImage, imageKeypoints, imageDescriptors );
2165
2166             //check that there were descriptors calculated for the current image
2167             if( !imageDescriptors.empty() )
2168             {
2169                 int descCount = imageDescriptors.rows;
2170                 // Extract trainParams.descProportion descriptors from the image, breaking if the 'allDescriptors' matrix becomes full
2171                 int descsToExtract = static_cast<int>(trainParams.descProportion * static_cast<float>(descCount));
2172                 // Fill mask of used descriptors
2173                 vector<char> usedMask( descCount, false );
2174                 fill( usedMask.begin(), usedMask.begin() + descsToExtract, true );
2175                 for( int i = 0; i < descCount; i++ )
2176                 {
2177                     int i1 = rng(descCount), i2 = rng(descCount);
2178                     char tmp = usedMask[i1]; usedMask[i1] = usedMask[i2]; usedMask[i2] = tmp;
2179                 }
2180
2181                 for( int i = 0; i < descCount; i++ )
2182                 {
2183                     if( usedMask[i] && bowTrainer.descriptorsCount() < maxDescCount )
2184                         bowTrainer.add( imageDescriptors.row(i) );
2185                 }
2186             }
2187
2188 #ifdef DEBUG_DESC_PROGRESS
2189             cout << images.size() << " images left, " << images[randImgIdx].id << " processed - "
2190                     <</* descs_extracted << "/" << image_descriptors.rows << " extracted - " << */
2191                     cvRound((static_cast<double>(bowTrainer.descriptorsCount())/static_cast<double>(maxDescCount))*100.0)
2192                     << " % memory used" << ( imageDescriptors.empty() ? " -> no descriptors extracted, skipping" : "") << endl;
2193 #endif
2194
2195             // Delete the current element from images so it is not added again
2196             images.erase( images.begin() + randImgIdx );
2197         }
2198
2199         cout << "Maximum allowed descriptor count: " << maxDescCount << ", Actual descriptor count: " << bowTrainer.descriptorsCount() << endl;
2200
2201         cout << "Training vocabulary..." << endl;
2202         vocabulary = bowTrainer.cluster();
2203
2204         if( !writeVocabulary(filename, vocabulary) )
2205         {
2206             cout << "Error: file " << filename << " can not be opened to write" << endl;
2207             exit(-1);
2208         }
2209     }
2210     return vocabulary;
2211 }
2212
2213 static bool readBowImageDescriptor( const string& file, Mat& bowImageDescriptor )
2214 {
2215     FileStorage fs( file, FileStorage::READ );
2216     if( fs.isOpened() )
2217     {
2218         fs["imageDescriptor"] >> bowImageDescriptor;
2219         return true;
2220     }
2221     return false;
2222 }
2223
2224 static bool writeBowImageDescriptor( const string& file, const Mat& bowImageDescriptor )
2225 {
2226     FileStorage fs( file, FileStorage::WRITE );
2227     if( fs.isOpened() )
2228     {
2229         fs << "imageDescriptor" << bowImageDescriptor;
2230         return true;
2231     }
2232     return false;
2233 }
2234
2235 // Load in the bag of words vectors for a set of images, from file if possible
2236 static void calculateImageDescriptors( const vector<ObdImage>& images, vector<Mat>& imageDescriptors,
2237                                 Ptr<BOWImgDescriptorExtractor>& bowExtractor, const Ptr<FeatureDetector>& fdetector,
2238                                 const string& resPath )
2239 {
2240     CV_Assert( !bowExtractor->getVocabulary().empty() );
2241     imageDescriptors.resize( images.size() );
2242
2243     for( size_t i = 0; i < images.size(); i++ )
2244     {
2245         string filename = resPath + bowImageDescriptorsDir + "/" + images[i].id + ".xml.gz";
2246         if( readBowImageDescriptor( filename, imageDescriptors[i] ) )
2247         {
2248 #ifdef DEBUG_DESC_PROGRESS
2249             cout << "Loaded bag of word vector for image " << i+1 << " of " << images.size() << " (" << images[i].id << ")" << endl;
2250 #endif
2251         }
2252         else
2253         {
2254             Mat colorImage = imread( images[i].path );
2255 #ifdef DEBUG_DESC_PROGRESS
2256             cout << "Computing descriptors for image " << i+1 << " of " << images.size() << " (" << images[i].id << ")" << flush;
2257 #endif
2258             vector<KeyPoint> keypoints;
2259             fdetector->detect( colorImage, keypoints );
2260 #ifdef DEBUG_DESC_PROGRESS
2261                 cout << " + generating BoW vector" << std::flush;
2262 #endif
2263             bowExtractor->compute( colorImage, keypoints, imageDescriptors[i] );
2264 #ifdef DEBUG_DESC_PROGRESS
2265             cout << " ...DONE " << static_cast<int>(static_cast<float>(i+1)/static_cast<float>(images.size())*100.0)
2266                  << " % complete" << endl;
2267 #endif
2268             if( !imageDescriptors[i].empty() )
2269             {
2270                 if( !writeBowImageDescriptor( filename, imageDescriptors[i] ) )
2271                 {
2272                     cout << "Error: file " << filename << "can not be opened to write bow image descriptor" << endl;
2273                     exit(-1);
2274                 }
2275             }
2276         }
2277     }
2278 }
2279
2280 static void removeEmptyBowImageDescriptors( vector<ObdImage>& images, vector<Mat>& bowImageDescriptors,
2281                                      vector<char>& objectPresent )
2282 {
2283     CV_Assert( !images.empty() );
2284     for( int i = (int)images.size() - 1; i >= 0; i-- )
2285     {
2286         bool res = bowImageDescriptors[i].empty();
2287         if( res )
2288         {
2289             cout << "Removing image " << images[i].id << " due to no descriptors..." << endl;
2290             images.erase( images.begin() + i );
2291             bowImageDescriptors.erase( bowImageDescriptors.begin() + i );
2292             objectPresent.erase( objectPresent.begin() + i );
2293         }
2294     }
2295 }
2296
2297 static void removeBowImageDescriptorsByCount( vector<ObdImage>& images, vector<Mat> bowImageDescriptors, vector<char> objectPresent,
2298                                        const SVMTrainParamsExt& svmParamsExt, int descsToDelete )
2299 {
2300     RNG& rng = theRNG();
2301     int pos_ex = (int)std::count( objectPresent.begin(), objectPresent.end(), (char)1 );
2302     int neg_ex = (int)std::count( objectPresent.begin(), objectPresent.end(), (char)0 );
2303
2304     while( descsToDelete != 0 )
2305     {
2306         int randIdx = rng((unsigned)images.size());
2307
2308         // Prefer positive training examples according to svmParamsExt.targetRatio if required
2309         if( objectPresent[randIdx] )
2310         {
2311             if( (static_cast<float>(pos_ex)/static_cast<float>(neg_ex+pos_ex)  < svmParamsExt.targetRatio) &&
2312                 (neg_ex > 0) && (svmParamsExt.balanceClasses == false) )
2313             { continue; }
2314             else
2315             { pos_ex--; }
2316         }
2317         else
2318         { neg_ex--; }
2319
2320         images.erase( images.begin() + randIdx );
2321         bowImageDescriptors.erase( bowImageDescriptors.begin() + randIdx );
2322         objectPresent.erase( objectPresent.begin() + randIdx );
2323
2324         descsToDelete--;
2325     }
2326     CV_Assert( bowImageDescriptors.size() == objectPresent.size() );
2327 }
2328
2329 static void setSVMParams( SVM::Params& svmParams, Mat& class_wts_cv, const Mat& responses, bool balanceClasses )
2330 {
2331     int pos_ex = countNonZero(responses == 1);
2332     int neg_ex = countNonZero(responses == -1);
2333     cout << pos_ex << " positive training samples; " << neg_ex << " negative training samples" << endl;
2334
2335     svmParams.svmType = SVM::C_SVC;
2336     svmParams.kernelType = SVM::RBF;
2337     if( balanceClasses )
2338     {
2339         Mat class_wts( 2, 1, CV_32FC1 );
2340         // The first training sample determines the '+1' class internally, even if it is negative,
2341         // so store whether this is the case so that the class weights can be reversed accordingly.
2342         bool reversed_classes = (responses.at<float>(0) < 0.f);
2343         if( reversed_classes == false )
2344         {
2345             class_wts.at<float>(0) = static_cast<float>(pos_ex)/static_cast<float>(pos_ex+neg_ex); // weighting for costs of positive class + 1 (i.e. cost of false positive - larger gives greater cost)
2346             class_wts.at<float>(1) = static_cast<float>(neg_ex)/static_cast<float>(pos_ex+neg_ex); // weighting for costs of negative class - 1 (i.e. cost of false negative)
2347         }
2348         else
2349         {
2350             class_wts.at<float>(0) = static_cast<float>(neg_ex)/static_cast<float>(pos_ex+neg_ex);
2351             class_wts.at<float>(1) = static_cast<float>(pos_ex)/static_cast<float>(pos_ex+neg_ex);
2352         }
2353         class_wts_cv = class_wts;
2354         svmParams.classWeights = class_wts_cv;
2355     }
2356 }
2357
2358 static void setSVMTrainAutoParams( ParamGrid& c_grid, ParamGrid& gamma_grid,
2359                             ParamGrid& p_grid, ParamGrid& nu_grid,
2360                             ParamGrid& coef_grid, ParamGrid& degree_grid )
2361 {
2362     c_grid = SVM::getDefaultGrid(SVM::C);
2363
2364     gamma_grid = SVM::getDefaultGrid(SVM::GAMMA);
2365
2366     p_grid = SVM::getDefaultGrid(SVM::P);
2367     p_grid.logStep = 0;
2368
2369     nu_grid = SVM::getDefaultGrid(SVM::NU);
2370     nu_grid.logStep = 0;
2371
2372     coef_grid = SVM::getDefaultGrid(SVM::COEF);
2373     coef_grid.logStep = 0;
2374
2375     degree_grid = SVM::getDefaultGrid(SVM::DEGREE);
2376     degree_grid.logStep = 0;
2377 }
2378
2379 static Ptr<SVM> trainSVMClassifier( const SVMTrainParamsExt& svmParamsExt, const string& objClassName, VocData& vocData,
2380                          Ptr<BOWImgDescriptorExtractor>& bowExtractor, const Ptr<FeatureDetector>& fdetector,
2381                          const string& resPath )
2382 {
2383     /* first check if a previously trained svm for the current class has been saved to file */
2384     string svmFilename = resPath + svmsDir + "/" + objClassName + ".xml.gz";
2385     Ptr<SVM> svm;
2386
2387     FileStorage fs( svmFilename, FileStorage::READ);
2388     if( fs.isOpened() )
2389     {
2390         cout << "*** LOADING SVM CLASSIFIER FOR CLASS " << objClassName << " ***" << endl;
2391         svm = StatModel::load<SVM>( svmFilename );
2392     }
2393     else
2394     {
2395         cout << "*** TRAINING CLASSIFIER FOR CLASS " << objClassName << " ***" << endl;
2396         cout << "CALCULATING BOW VECTORS FOR TRAINING SET OF " << objClassName << "..." << endl;
2397
2398         // Get classification ground truth for images in the training set
2399         vector<ObdImage> images;
2400         vector<Mat> bowImageDescriptors;
2401         vector<char> objectPresent;
2402         vocData.getClassImages( objClassName, CV_OBD_TRAIN, images, objectPresent );
2403
2404         // Compute the bag of words vector for each image in the training set.
2405         calculateImageDescriptors( images, bowImageDescriptors, bowExtractor, fdetector, resPath );
2406
2407         // Remove any images for which descriptors could not be calculated
2408         removeEmptyBowImageDescriptors( images, bowImageDescriptors, objectPresent );
2409
2410         CV_Assert( svmParamsExt.descPercent > 0.f && svmParamsExt.descPercent <= 1.f );
2411         if( svmParamsExt.descPercent < 1.f )
2412         {
2413             int descsToDelete = static_cast<int>(static_cast<float>(images.size())*(1.0-svmParamsExt.descPercent));
2414
2415             cout << "Using " << (images.size() - descsToDelete) << " of " << images.size() <<
2416                     " descriptors for training (" << svmParamsExt.descPercent*100.0 << " %)" << endl;
2417             removeBowImageDescriptorsByCount( images, bowImageDescriptors, objectPresent, svmParamsExt, descsToDelete );
2418         }
2419
2420         // Prepare the input matrices for SVM training.
2421         Mat trainData( (int)images.size(), bowExtractor->getVocabulary().rows, CV_32FC1 );
2422         Mat responses( (int)images.size(), 1, CV_32SC1 );
2423
2424         // Transfer bag of words vectors and responses across to the training data matrices
2425         for( size_t imageIdx = 0; imageIdx < images.size(); imageIdx++ )
2426         {
2427             // Transfer image descriptor (bag of words vector) to training data matrix
2428             Mat submat = trainData.row((int)imageIdx);
2429             if( bowImageDescriptors[imageIdx].cols != bowExtractor->descriptorSize() )
2430             {
2431                 cout << "Error: computed bow image descriptor size " << bowImageDescriptors[imageIdx].cols
2432                      << " differs from vocabulary size" << bowExtractor->getVocabulary().cols << endl;
2433                 exit(-1);
2434             }
2435             bowImageDescriptors[imageIdx].copyTo( submat );
2436
2437             // Set response value
2438             responses.at<int>((int)imageIdx) = objectPresent[imageIdx] ? 1 : -1;
2439         }
2440
2441         cout << "TRAINING SVM FOR CLASS ..." << objClassName << "..." << endl;
2442         SVM::Params svmParams;
2443         Mat class_wts_cv;
2444         setSVMParams( svmParams, class_wts_cv, responses, svmParamsExt.balanceClasses );
2445         svm = SVM::create(svmParams);
2446         ParamGrid c_grid, gamma_grid, p_grid, nu_grid, coef_grid, degree_grid;
2447         setSVMTrainAutoParams( c_grid, gamma_grid,  p_grid, nu_grid, coef_grid, degree_grid );
2448
2449         svm->trainAuto(TrainData::create(trainData, ROW_SAMPLE, responses), 10,
2450                        c_grid, gamma_grid, p_grid, nu_grid, coef_grid, degree_grid);
2451         cout << "SVM TRAINING FOR CLASS " << objClassName << " COMPLETED" << endl;
2452
2453         svm->save( svmFilename );
2454         cout << "SAVED CLASSIFIER TO FILE" << endl;
2455     }
2456     return svm;
2457 }
2458
2459 static void computeConfidences( const Ptr<SVM>& svm, const string& objClassName, VocData& vocData,
2460                          Ptr<BOWImgDescriptorExtractor>& bowExtractor, const Ptr<FeatureDetector>& fdetector,
2461                          const string& resPath )
2462 {
2463     cout << "*** CALCULATING CONFIDENCES FOR CLASS " << objClassName << " ***" << endl;
2464     cout << "CALCULATING BOW VECTORS FOR TEST SET OF " << objClassName << "..." << endl;
2465     // Get classification ground truth for images in the test set
2466     vector<ObdImage> images;
2467     vector<Mat> bowImageDescriptors;
2468     vector<char> objectPresent;
2469     vocData.getClassImages( objClassName, CV_OBD_TEST, images, objectPresent );
2470
2471     // Compute the bag of words vector for each image in the test set
2472     calculateImageDescriptors( images, bowImageDescriptors, bowExtractor, fdetector, resPath );
2473     // Remove any images for which descriptors could not be calculated
2474     removeEmptyBowImageDescriptors( images, bowImageDescriptors, objectPresent);
2475
2476     // Use the bag of words vectors to calculate classifier output for each image in test set
2477     cout << "CALCULATING CONFIDENCE SCORES FOR CLASS " << objClassName << "..." << endl;
2478     vector<float> confidences( images.size() );
2479     float signMul = 1.f;
2480     for( size_t imageIdx = 0; imageIdx < images.size(); imageIdx++ )
2481     {
2482         if( imageIdx == 0 )
2483         {
2484             // In the first iteration, determine the sign of the positive class
2485             float classVal = confidences[imageIdx] = svm->predict( bowImageDescriptors[imageIdx], noArray(), 0 );
2486             float scoreVal = confidences[imageIdx] = svm->predict( bowImageDescriptors[imageIdx], noArray(), StatModel::RAW_OUTPUT );
2487             signMul = (classVal < 0) == (scoreVal < 0) ? 1.f : -1.f;
2488         }
2489         // svm output of decision function
2490         confidences[imageIdx] = signMul * svm->predict( bowImageDescriptors[imageIdx], noArray(), StatModel::RAW_OUTPUT );
2491     }
2492
2493     cout << "WRITING QUERY RESULTS TO VOC RESULTS FILE FOR CLASS " << objClassName << "..." << endl;
2494     vocData.writeClassifierResultsFile( resPath + plotsDir, objClassName, CV_OBD_TEST, images, confidences, 1, true );
2495
2496     cout << "DONE - " << objClassName << endl;
2497     cout << "---------------------------------------------------------------" << endl;
2498 }
2499
2500 static void computeGnuPlotOutput( const string& resPath, const string& objClassName, VocData& vocData )
2501 {
2502     vector<float> precision, recall;
2503     float ap;
2504
2505     const string resultFile = vocData.getResultsFilename( objClassName, CV_VOC_TASK_CLASSIFICATION, CV_OBD_TEST);
2506     const string plotFile = resultFile.substr(0, resultFile.size()-4) + ".plt";
2507
2508     cout << "Calculating precision recall curve for class '" <<objClassName << "'" << endl;
2509     vocData.calcClassifierPrecRecall( resPath + plotsDir + "/" + resultFile, precision, recall, ap, true );
2510     cout << "Outputting to GNUPlot file..." << endl;
2511     vocData.savePrecRecallToGnuplot( resPath + plotsDir + "/" + plotFile, precision, recall, ap, objClassName, CV_VOC_PLOT_PNG );
2512 }
2513
2514
2515
2516
2517 int main(int argc, char** argv)
2518 {
2519     if( argc != 3 && argc != 6 )
2520     {
2521         help(argv);
2522         return -1;
2523     }
2524
2525     cv::initModule_nonfree();
2526
2527     const string vocPath = argv[1], resPath = argv[2];
2528
2529     // Read or set default parameters
2530     string vocName;
2531     DDMParams ddmParams;
2532     VocabTrainParams vocabTrainParams;
2533     SVMTrainParamsExt svmTrainParamsExt;
2534
2535     makeUsedDirs( resPath );
2536
2537     FileStorage paramsFS( resPath + "/" + paramsFile, FileStorage::READ );
2538     if( paramsFS.isOpened() )
2539     {
2540        readUsedParams( paramsFS.root(), vocName, ddmParams, vocabTrainParams, svmTrainParamsExt );
2541        CV_Assert( vocName == getVocName(vocPath) );
2542     }
2543     else
2544     {
2545         vocName = getVocName(vocPath);
2546         if( argc!= 6 )
2547         {
2548             cout << "Feature detector, descriptor extractor, descriptor matcher must be set" << endl;
2549             return -1;
2550         }
2551         ddmParams = DDMParams( argv[3], argv[4], argv[5] ); // from command line
2552         // vocabTrainParams and svmTrainParamsExt is set by defaults
2553         paramsFS.open( resPath + "/" + paramsFile, FileStorage::WRITE );
2554         if( paramsFS.isOpened() )
2555         {
2556             writeUsedParams( paramsFS, vocName, ddmParams, vocabTrainParams, svmTrainParamsExt );
2557             paramsFS.release();
2558         }
2559         else
2560         {
2561             cout << "File " << (resPath + "/" + paramsFile) << "can not be opened to write" << endl;
2562             return -1;
2563         }
2564     }
2565
2566     // Create detector, descriptor, matcher.
2567     Ptr<FeatureDetector> featureDetector = FeatureDetector::create( ddmParams.detectorType );
2568     Ptr<DescriptorExtractor> descExtractor = DescriptorExtractor::create( ddmParams.descriptorType );
2569     Ptr<BOWImgDescriptorExtractor> bowExtractor;
2570     if( !featureDetector || !descExtractor )
2571     {
2572         cout << "featureDetector or descExtractor was not created" << endl;
2573         return -1;
2574     }
2575     {
2576         Ptr<DescriptorMatcher> descMatcher = DescriptorMatcher::create( ddmParams.matcherType );
2577         if( !featureDetector || !descExtractor || !descMatcher )
2578         {
2579             cout << "descMatcher was not created" << endl;
2580             return -1;
2581         }
2582         bowExtractor = makePtr<BOWImgDescriptorExtractor>( descExtractor, descMatcher );
2583     }
2584
2585     // Print configuration to screen
2586     printUsedParams( vocPath, resPath, ddmParams, vocabTrainParams, svmTrainParamsExt );
2587     // Create object to work with VOC
2588     VocData vocData( vocPath, false );
2589
2590     // 1. Train visual word vocabulary if a pre-calculated vocabulary file doesn't already exist from previous run
2591     Mat vocabulary = trainVocabulary( resPath + "/" + vocabularyFile, vocData, vocabTrainParams,
2592                                       featureDetector, descExtractor );
2593     bowExtractor->setVocabulary( vocabulary );
2594
2595     // 2. Train a classifier and run a sample query for each object class
2596     const vector<string>& objClasses = vocData.getObjectClasses(); // object class list
2597     for( size_t classIdx = 0; classIdx < objClasses.size(); ++classIdx )
2598     {
2599         // Train a classifier on train dataset
2600         Ptr<SVM> svm = trainSVMClassifier( svmTrainParamsExt, objClasses[classIdx], vocData,
2601                                            bowExtractor, featureDetector, resPath );
2602
2603         // Now use the classifier over all images on the test dataset and rank according to score order
2604         // also calculating precision-recall etc.
2605         computeConfidences( svm, objClasses[classIdx], vocData,
2606                             bowExtractor, featureDetector, resPath );
2607         // Calculate precision/recall/ap and use GNUPlot to output to a pdf file
2608         computeGnuPlotOutput( resPath, objClasses[classIdx], vocData );
2609     }
2610     return 0;
2611 }