removed contrib, legacy and softcsscade modules; removed latentsvm and datamatrix...
[platform/upstream/opencv.git] / samples / cpp / bagofwords_classification.cpp
1 #include "opencv2/opencv_modules.hpp"
2 #include "opencv2/highgui/highgui.hpp"
3 #include "opencv2/imgproc/imgproc.hpp"
4 #include "opencv2/features2d/features2d.hpp"
5 #include "opencv2/nonfree/nonfree.hpp"
6 #include "opencv2/ml/ml.hpp"
7
8 #include <fstream>
9 #include <iostream>
10 #include <memory>
11 #include <functional>
12
13 #if defined WIN32 || defined _WIN32
14 #define WIN32_LEAN_AND_MEAN
15 #include <windows.h>
16 #undef min
17 #undef max
18 #include "sys/types.h"
19 #endif
20 #include <sys/stat.h>
21
22 #define DEBUG_DESC_PROGRESS
23
24 using namespace cv;
25 using namespace std;
26
27 const string paramsFile = "params.xml";
28 const string vocabularyFile = "vocabulary.xml.gz";
29 const string bowImageDescriptorsDir = "/bowImageDescriptors";
30 const string svmsDir = "/svms";
31 const string plotsDir = "/plots";
32
33 static void help(char** argv)
34 {
35     cout << "\nThis program shows how to read in, train on and produce test results for the PASCAL VOC (Visual Object Challenge) data. \n"
36      << "It shows how to use detectors, descriptors and recognition methods \n"
37         "Using OpenCV version %s\n" << CV_VERSION << "\n"
38      << "Call: \n"
39     << "Format:\n ./" << argv[0] << " [VOC path] [result directory]  \n"
40     << "       or:  \n"
41     << " ./" << argv[0] << " [VOC path] [result directory] [feature detector] [descriptor extractor] [descriptor matcher] \n"
42     << "\n"
43     << "Input parameters: \n"
44     << "[VOC path]             Path to Pascal VOC data (e.g. /home/my/VOCdevkit/VOC2010). Note: VOC2007-VOC2010 are supported. \n"
45     << "[result directory]     Path to result diractory. Following folders will be created in [result directory]: \n"
46     << "                         bowImageDescriptors - to store image descriptors, \n"
47     << "                         svms - to store trained svms, \n"
48     << "                         plots - to store files for plots creating. \n"
49     << "[feature detector]     Feature detector name (e.g. SURF, FAST...) - see createFeatureDetector() function in detectors.cpp \n"
50     << "                         Currently 12/2010, this is FAST, STAR, SIFT, SURF, MSER, GFTT, HARRIS \n"
51     << "[descriptor extractor] Descriptor extractor name (e.g. SURF, SIFT) - see createDescriptorExtractor() function in descriptors.cpp \n"
52     << "                         Currently 12/2010, this is SURF, OpponentSIFT, SIFT, OpponentSURF, BRIEF \n"
53     << "[descriptor matcher]   Descriptor matcher name (e.g. BruteForce) - see createDescriptorMatcher() function in matchers.cpp \n"
54     << "                         Currently 12/2010, this is BruteForce, BruteForce-L1, FlannBased, BruteForce-Hamming, BruteForce-HammingLUT \n"
55     << "\n";
56 }
57
58 static void makeDir( const string& dir )
59 {
60 #if defined WIN32 || defined _WIN32
61     CreateDirectory( dir.c_str(), 0 );
62 #else
63     mkdir( dir.c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH );
64 #endif
65 }
66
67 static void makeUsedDirs( const string& rootPath )
68 {
69     makeDir(rootPath + bowImageDescriptorsDir);
70     makeDir(rootPath + svmsDir);
71     makeDir(rootPath + plotsDir);
72 }
73
74 /****************************************************************************************\
75 *                    Classes to work with PASCAL VOC dataset                             *
76 \****************************************************************************************/
77 //
78 // TODO: refactor this part of the code
79 //
80
81
82 //used to specify the (sub-)dataset over which operations are performed
83 enum ObdDatasetType {CV_OBD_TRAIN, CV_OBD_TEST};
84
85 class ObdObject
86 {
87 public:
88     string object_class;
89     Rect boundingBox;
90 };
91
92 //extended object data specific to VOC
93 enum VocPose {CV_VOC_POSE_UNSPECIFIED, CV_VOC_POSE_FRONTAL, CV_VOC_POSE_REAR, CV_VOC_POSE_LEFT, CV_VOC_POSE_RIGHT};
94 class VocObjectData
95 {
96 public:
97     bool difficult;
98     bool occluded;
99     bool truncated;
100     VocPose pose;
101 };
102 //enum VocDataset {CV_VOC2007, CV_VOC2008, CV_VOC2009, CV_VOC2010};
103 enum VocPlotType {CV_VOC_PLOT_SCREEN, CV_VOC_PLOT_PNG};
104 enum VocGT {CV_VOC_GT_NONE, CV_VOC_GT_DIFFICULT, CV_VOC_GT_PRESENT};
105 enum VocConfCond {CV_VOC_CCOND_RECALL, CV_VOC_CCOND_SCORETHRESH};
106 enum VocTask {CV_VOC_TASK_CLASSIFICATION, CV_VOC_TASK_DETECTION};
107
108 class ObdImage
109 {
110 public:
111     ObdImage(string p_id, string p_path) : id(p_id), path(p_path) {}
112     string id;
113     string path;
114 };
115
116 //used by getDetectorGroundTruth to sort a two dimensional list of floats in descending order
117 class ObdScoreIndexSorter
118 {
119 public:
120     float score;
121     int image_idx;
122     int obj_idx;
123     bool operator < (const ObdScoreIndexSorter& compare) const {return (score < compare.score);}
124 };
125
126 class VocData
127 {
128 public:
129     VocData( const string& vocPath, bool useTestDataset )
130         { initVoc( vocPath, useTestDataset ); }
131     ~VocData(){}
132     /* functions for returning classification/object data for multiple images given an object class */
133     void getClassImages(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<char>& object_present);
134     void getClassObjects(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<vector<ObdObject> >& objects);
135     void getClassObjects(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<vector<ObdObject> >& objects, vector<vector<VocObjectData> >& object_data, vector<VocGT>& ground_truth);
136     /* functions for returning object data for a single image given an image id */
137     ObdImage getObjects(const string& id, vector<ObdObject>& objects);
138     ObdImage getObjects(const string& id, vector<ObdObject>& objects, vector<VocObjectData>& object_data);
139     ObdImage getObjects(const string& obj_class, const string& id, vector<ObdObject>& objects, vector<VocObjectData>& object_data, VocGT& ground_truth);
140     /* functions for returning the ground truth (present/absent) for groups of images */
141     void getClassifierGroundTruth(const string& obj_class, const vector<ObdImage>& images, vector<char>& ground_truth);
142     void getClassifierGroundTruth(const string& obj_class, const vector<string>& images, vector<char>& ground_truth);
143     int getDetectorGroundTruth(const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<vector<Rect> >& bounding_boxes, const vector<vector<float> >& scores, vector<vector<char> >& ground_truth, vector<vector<char> >& detection_difficult, bool ignore_difficult = true);
144     /* functions for writing VOC-compatible results files */
145     void writeClassifierResultsFile(const string& out_dir, const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<float>& scores, const int competition = 1, const bool overwrite_ifexists = false);
146     /* functions for calculating metrics from a set of classification/detection results */
147     string getResultsFilename(const string& obj_class, const VocTask task, const ObdDatasetType dataset, const int competition = -1, const int number = -1);
148     void calcClassifierPrecRecall(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap, vector<size_t>& ranking);
149     void calcClassifierPrecRecall(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap);
150     void calcClassifierPrecRecall(const string& input_file, vector<float>& precision, vector<float>& recall, float& ap, bool outputRankingFile = false);
151     /* functions for calculating confusion matrices */
152     void calcClassifierConfMatRow(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, const VocConfCond cond, const float threshold, vector<string>& output_headers, vector<float>& output_values);
153     void calcDetectorConfMatRow(const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<vector<float> >& scores, const vector<vector<Rect> >& bounding_boxes, const VocConfCond cond, const float threshold, vector<string>& output_headers, vector<float>& output_values, bool ignore_difficult = true);
154     /* functions for outputting gnuplot output files */
155     void savePrecRecallToGnuplot(const string& output_file, const vector<float>& precision, const vector<float>& recall, const float ap, const string title = string(), const VocPlotType plot_type = CV_VOC_PLOT_SCREEN);
156     /* functions for reading in result/ground truth files */
157     void readClassifierGroundTruth(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<char>& object_present);
158     void readClassifierResultsFile(const std:: string& input_file, vector<ObdImage>& images, vector<float>& scores);
159     void readDetectorResultsFile(const string& input_file, vector<ObdImage>& images, vector<vector<float> >& scores, vector<vector<Rect> >& bounding_boxes);
160     /* functions for getting dataset info */
161     const vector<string>& getObjectClasses();
162     string getResultsDirectory();
163 protected:
164     void initVoc( const string& vocPath, const bool useTestDataset );
165     void initVoc2007to2010( const string& vocPath, const bool useTestDataset);
166     void readClassifierGroundTruth(const string& filename, vector<string>& image_codes, vector<char>& object_present);
167     void readClassifierResultsFile(const string& input_file, vector<string>& image_codes, vector<float>& scores);
168     void readDetectorResultsFile(const string& input_file, vector<string>& image_codes, vector<vector<float> >& scores, vector<vector<Rect> >& bounding_boxes);
169     void extractVocObjects(const string filename, vector<ObdObject>& objects, vector<VocObjectData>& object_data);
170     string getImagePath(const string& input_str);
171
172     void getClassImages_impl(const string& obj_class, const string& dataset_str, vector<ObdImage>& images, vector<char>& object_present);
173     void calcPrecRecall_impl(const vector<char>& ground_truth, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap, vector<size_t>& ranking, int recall_normalization = -1);
174
175     //test two bounding boxes to see if they meet the overlap criteria defined in the VOC documentation
176     float testBoundingBoxesForOverlap(const Rect detection, const Rect ground_truth);
177     //extract class and dataset name from a VOC-standard classification/detection results filename
178     void extractDataFromResultsFilename(const string& input_file, string& class_name, string& dataset_name);
179     //get classifier ground truth for a single image
180     bool getClassifierGroundTruthImage(const string& obj_class, const string& id);
181
182     //utility functions
183     void getSortOrder(const vector<float>& values, vector<size_t>& order, bool descending = true);
184     int stringToInteger(const string input_str);
185     void readFileToString(const string filename, string& file_contents);
186     string integerToString(const int input_int);
187     string checkFilenamePathsep(const string filename, bool add_trailing_slash = false);
188     void convertImageCodesToObdImages(const vector<string>& image_codes, vector<ObdImage>& images);
189     int extractXMLBlock(const string src, const string tag, const int searchpos, string& tag_contents);
190     //utility sorter
191     struct orderingSorter
192     {
193         bool operator ()(std::pair<size_t, vector<float>::const_iterator> const& a, std::pair<size_t, vector<float>::const_iterator> const& b)
194         {
195             return (*a.second) > (*b.second);
196         }
197     };
198     //data members
199     string m_vocPath;
200     string m_vocName;
201     //string m_resPath;
202
203     string m_annotation_path;
204     string m_image_path;
205     string m_imageset_path;
206     string m_class_imageset_path;
207
208     vector<string> m_classifier_gt_all_ids;
209     vector<char> m_classifier_gt_all_present;
210     string m_classifier_gt_class;
211
212     //data members
213     string m_train_set;
214     string m_test_set;
215
216     vector<string> m_object_classes;
217
218
219     float m_min_overlap;
220     bool m_sampled_ap;
221 };
222
223
224 //Return the classification ground truth data for all images of a given VOC object class
225 //--------------------------------------------------------------------------------------
226 //INPUTS:
227 // - obj_class          The VOC object class identifier string
228 // - dataset            Specifies whether to extract images from the training or test set
229 //OUTPUTS:
230 // - images             An array of ObdImage containing info of all images extracted from the ground truth file
231 // - object_present     An array of bools specifying whether the object defined by 'obj_class' is present in each image or not
232 //NOTES:
233 // This function is primarily useful for the classification task, where only
234 // whether a given object is present or not in an image is required, and not each object instance's
235 // position etc.
236 void VocData::getClassImages(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<char>& object_present)
237 {
238     string dataset_str;
239     //generate the filename of the classification ground-truth textfile for the object class
240     if (dataset == CV_OBD_TRAIN)
241     {
242         dataset_str = m_train_set;
243     } else {
244         dataset_str = m_test_set;
245     }
246
247     getClassImages_impl(obj_class, dataset_str, images, object_present);
248 }
249
250 void VocData::getClassImages_impl(const string& obj_class, const string& dataset_str, vector<ObdImage>& images, vector<char>& object_present)
251 {
252     //generate the filename of the classification ground-truth textfile for the object class
253     string gtFilename = m_class_imageset_path;
254     gtFilename.replace(gtFilename.find("%s"),2,obj_class);
255     gtFilename.replace(gtFilename.find("%s"),2,dataset_str);
256
257     //parse the ground truth file, storing in two separate vectors
258     //for the image code and the ground truth value
259     vector<string> image_codes;
260     readClassifierGroundTruth(gtFilename, image_codes, object_present);
261
262     //prepare output arrays
263     images.clear();
264
265     convertImageCodesToObdImages(image_codes, images);
266 }
267
268 //Return the object data for all images of a given VOC object class
269 //-----------------------------------------------------------------
270 //INPUTS:
271 // - obj_class          The VOC object class identifier string
272 // - dataset            Specifies whether to extract images from the training or test set
273 //OUTPUTS:
274 // - images             An array of ObdImage containing info of all images in chosen dataset (tag, path etc.)
275 // - objects            Contains the extended object info (bounding box etc.) for each object instance in each image
276 // - object_data        Contains VOC-specific extended object info (marked difficult etc.)
277 // - ground_truth       Specifies whether there are any difficult/non-difficult instances of the current
278 //                          object class within each image
279 //NOTES:
280 // This function returns extended object information in addition to the absent/present
281 // classification data returned by getClassImages. The objects returned for each image in the 'objects'
282 // array are of all object classes present in the image, and not just the class defined by 'obj_class'.
283 // 'ground_truth' can be used to determine quickly whether an object instance of the given class is present
284 // in an image or not.
285 void VocData::getClassObjects(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<vector<ObdObject> >& objects)
286 {
287     vector<vector<VocObjectData> > object_data;
288     vector<VocGT> ground_truth;
289
290     getClassObjects(obj_class,dataset,images,objects,object_data,ground_truth);
291 }
292
293 void VocData::getClassObjects(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<vector<ObdObject> >& objects, vector<vector<VocObjectData> >& object_data, vector<VocGT>& ground_truth)
294 {
295     //generate the filename of the classification ground-truth textfile for the object class
296     string gtFilename = m_class_imageset_path;
297     gtFilename.replace(gtFilename.find("%s"),2,obj_class);
298     if (dataset == CV_OBD_TRAIN)
299     {
300         gtFilename.replace(gtFilename.find("%s"),2,m_train_set);
301     } else {
302         gtFilename.replace(gtFilename.find("%s"),2,m_test_set);
303     }
304
305     //parse the ground truth file, storing in two separate vectors
306     //for the image code and the ground truth value
307     vector<string> image_codes;
308     vector<char> object_present;
309     readClassifierGroundTruth(gtFilename, image_codes, object_present);
310
311     //prepare output arrays
312     images.clear();
313     objects.clear();
314     object_data.clear();
315     ground_truth.clear();
316
317     string annotationFilename;
318     vector<ObdObject> image_objects;
319     vector<VocObjectData> image_object_data;
320     VocGT image_gt;
321
322     //transfer to output arrays and read in object data for each image
323     for (size_t i = 0; i < image_codes.size(); ++i)
324     {
325         ObdImage image = getObjects(obj_class, image_codes[i], image_objects, image_object_data, image_gt);
326
327         images.push_back(image);
328         objects.push_back(image_objects);
329         object_data.push_back(image_object_data);
330         ground_truth.push_back(image_gt);
331     }
332 }
333
334 //Return ground truth data for the objects present in an image with a given UID
335 //-----------------------------------------------------------------------------
336 //INPUTS:
337 // - id                 VOC Dataset unique identifier (string code in form YYYY_XXXXXX where YYYY is the year)
338 //OUTPUTS:
339 // - obj_class (*3)     Specifies the object class to use to resolve 'ground_truth'
340 // - objects            Contains the extended object info (bounding box etc.) for each object in the image
341 // - object_data (*2,3) Contains VOC-specific extended object info (marked difficult etc.)
342 // - ground_truth (*3)  Specifies whether there are any difficult/non-difficult instances of the current
343 //                          object class within the image
344 //RETURN VALUE:
345 // ObdImage containing path and other details of image file with given code
346 //NOTES:
347 // There are three versions of this function
348 //  * One returns a simple array of objects given an id [1]
349 //  * One returns the same as (1) plus VOC specific object data [2]
350 //  * One returns the same as (2) plus the ground_truth flag. This also requires an extra input obj_class [3]
351 ObdImage VocData::getObjects(const string& id, vector<ObdObject>& objects)
352 {
353     vector<VocObjectData> object_data;
354     ObdImage image = getObjects(id, objects, object_data);
355
356     return image;
357 }
358
359 ObdImage VocData::getObjects(const string& id, vector<ObdObject>& objects, vector<VocObjectData>& object_data)
360 {
361     //first generate the filename of the annotation file
362     string annotationFilename = m_annotation_path;
363
364     annotationFilename.replace(annotationFilename.find("%s"),2,id);
365
366     //extract objects contained in the current image from the xml
367     extractVocObjects(annotationFilename,objects,object_data);
368
369     //generate image path from extracted string code
370     string path = getImagePath(id);
371
372     ObdImage image(id, path);
373     return image;
374 }
375
376 ObdImage VocData::getObjects(const string& obj_class, const string& id, vector<ObdObject>& objects, vector<VocObjectData>& object_data, VocGT& ground_truth)
377 {
378
379     //extract object data (except for ground truth flag)
380     ObdImage image = getObjects(id,objects,object_data);
381
382     //pregenerate a flag to indicate whether the current class is present or not in the image
383     ground_truth = CV_VOC_GT_NONE;
384     //iterate through all objects in current image
385     for (size_t j = 0; j < objects.size(); ++j)
386     {
387         if (objects[j].object_class == obj_class)
388         {
389             if (object_data[j].difficult == false)
390             {
391                 //if at least one non-difficult example is present, this flag is always set to CV_VOC_GT_PRESENT
392                 ground_truth = CV_VOC_GT_PRESENT;
393                 break;
394             } else {
395                 //set if at least one object instance is present, but it is marked difficult
396                 ground_truth = CV_VOC_GT_DIFFICULT;
397             }
398         }
399     }
400
401     return image;
402 }
403
404 //Return ground truth data for the presence/absence of a given object class in an arbitrary array of images
405 //---------------------------------------------------------------------------------------------------------
406 //INPUTS:
407 // - obj_class          The VOC object class identifier string
408 // - images             An array of ObdImage OR strings containing the images for which ground truth
409 //                          will be computed
410 //OUTPUTS:
411 // - ground_truth       An output array indicating the presence/absence of obj_class within each image
412 void VocData::getClassifierGroundTruth(const string& obj_class, const vector<ObdImage>& images, vector<char>& ground_truth)
413 {
414     vector<char>(images.size()).swap(ground_truth);
415
416     vector<ObdObject> objects;
417     vector<VocObjectData> object_data;
418     vector<char>::iterator gt_it = ground_truth.begin();
419     for (vector<ObdImage>::const_iterator it = images.begin(); it != images.end(); ++it, ++gt_it)
420     {
421         //getObjects(obj_class, it->id, objects, object_data, voc_ground_truth);
422         (*gt_it) = (getClassifierGroundTruthImage(obj_class, it->id));
423     }
424 }
425
426 void VocData::getClassifierGroundTruth(const string& obj_class, const vector<string>& images, vector<char>& ground_truth)
427 {
428     vector<char>(images.size()).swap(ground_truth);
429
430     vector<ObdObject> objects;
431     vector<VocObjectData> object_data;
432     vector<char>::iterator gt_it = ground_truth.begin();
433     for (vector<string>::const_iterator it = images.begin(); it != images.end(); ++it, ++gt_it)
434     {
435         //getObjects(obj_class, (*it), objects, object_data, voc_ground_truth);
436         (*gt_it) = (getClassifierGroundTruthImage(obj_class, (*it)));
437     }
438 }
439
440 //Return ground truth data for the accuracy of detection results
441 //--------------------------------------------------------------
442 //INPUTS:
443 // - obj_class          The VOC object class identifier string
444 // - images             An array of ObdImage containing the images for which ground truth
445 //                          will be computed
446 // - bounding_boxes     A 2D input array containing the bounding box rects of the objects of
447 //                          obj_class which were detected in each image
448 //OUTPUTS:
449 // - ground_truth       A 2D output array indicating whether each object detection was accurate
450 //                          or not
451 // - detection_difficult A 2D output array indicating whether the detection fired on an object
452 //                          marked as 'difficult'. This allows it to be ignored if necessary
453 //                          (the voc documentation specifies objects marked as difficult
454 //                          have no effects on the results and are effectively ignored)
455 // - (ignore_difficult) If set to true, objects marked as difficult will be ignored when returning
456 //                          the number of hits for p-r normalization (default = true)
457 //RETURN VALUE:
458 //                      Returns the number of object hits in total in the gt to allow proper normalization
459 //                          of a p-r curve
460 //NOTES:
461 // As stated in the VOC documentation, multiple detections of the same object in an image are
462 // considered FALSE detections e.g. 5 detections of a single object is counted as 1 correct
463 // detection and 4 false detections - it is the responsibility of the participant's system
464 // to filter multiple detections from its output
465 int VocData::getDetectorGroundTruth(const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<vector<Rect> >& bounding_boxes, const vector<vector<float> >& scores, vector<vector<char> >& ground_truth, vector<vector<char> >& detection_difficult, bool ignore_difficult)
466 {
467     int recall_normalization = 0;
468
469     /* first create a list of indices referring to the elements of bounding_boxes and scores in
470      * descending order of scores */
471     vector<ObdScoreIndexSorter> sorted_ids;
472     {
473         /* first count how many objects to allow preallocation */
474         size_t obj_count = 0;
475         CV_Assert(images.size() == bounding_boxes.size());
476         CV_Assert(scores.size() == bounding_boxes.size());
477         for (size_t im_idx = 0; im_idx < scores.size(); ++im_idx)
478         {
479             CV_Assert(scores[im_idx].size() == bounding_boxes[im_idx].size());
480             obj_count += scores[im_idx].size();
481         }
482         /* preallocate id vector */
483         sorted_ids.resize(obj_count);
484         /* now copy across scores and indexes to preallocated vector */
485         int flat_pos = 0;
486         for (size_t im_idx = 0; im_idx < scores.size(); ++im_idx)
487         {
488             for (size_t ob_idx = 0; ob_idx < scores[im_idx].size(); ++ob_idx)
489             {
490                 sorted_ids[flat_pos].score = scores[im_idx][ob_idx];
491                 sorted_ids[flat_pos].image_idx = (int)im_idx;
492                 sorted_ids[flat_pos].obj_idx = (int)ob_idx;
493                 ++flat_pos;
494             }
495         }
496         /* and sort the vector in descending order of score */
497         std::sort(sorted_ids.begin(),sorted_ids.end());
498         std::reverse(sorted_ids.begin(),sorted_ids.end());
499     }
500
501     /* prepare ground truth + difficult vector (1st dimension) */
502     vector<vector<char> >(images.size()).swap(ground_truth);
503     vector<vector<char> >(images.size()).swap(detection_difficult);
504     vector<vector<char> > detected(images.size());
505
506     vector<vector<ObdObject> > img_objects(images.size());
507     vector<vector<VocObjectData> > img_object_data(images.size());
508     /* preload object ground truth bounding box data */
509     {
510         vector<vector<ObdObject> > img_objects_all(images.size());
511         vector<vector<VocObjectData> > img_object_data_all(images.size());
512         for (size_t image_idx = 0; image_idx < images.size(); ++image_idx)
513         {
514             /* prepopulate ground truth bounding boxes */
515             getObjects(images[image_idx].id, img_objects_all[image_idx], img_object_data_all[image_idx]);
516             /* meanwhile, also set length of target ground truth + difficult vector to same as number of object detections (2nd dimension) */
517             ground_truth[image_idx].resize(bounding_boxes[image_idx].size());
518             detection_difficult[image_idx].resize(bounding_boxes[image_idx].size());
519         }
520
521         /* save only instances of the object class concerned */
522         for (size_t image_idx = 0; image_idx < images.size(); ++image_idx)
523         {
524             for (size_t obj_idx = 0; obj_idx < img_objects_all[image_idx].size(); ++obj_idx)
525             {
526                 if (img_objects_all[image_idx][obj_idx].object_class == obj_class)
527                 {
528                     img_objects[image_idx].push_back(img_objects_all[image_idx][obj_idx]);
529                     img_object_data[image_idx].push_back(img_object_data_all[image_idx][obj_idx]);
530                 }
531             }
532             detected[image_idx].resize(img_objects[image_idx].size(), false);
533         }
534     }
535
536     /* calculate the total number of objects in the ground truth for the current dataset */
537     {
538         vector<ObdImage> gt_images;
539         vector<char> gt_object_present;
540         getClassImages(obj_class, dataset, gt_images, gt_object_present);
541
542         for (size_t image_idx = 0; image_idx < gt_images.size(); ++image_idx)
543         {
544             vector<ObdObject> gt_img_objects;
545             vector<VocObjectData> gt_img_object_data;
546             getObjects(gt_images[image_idx].id, gt_img_objects, gt_img_object_data);
547             for (size_t obj_idx = 0; obj_idx < gt_img_objects.size(); ++obj_idx)
548             {
549                 if (gt_img_objects[obj_idx].object_class == obj_class)
550                 {
551                     if ((gt_img_object_data[obj_idx].difficult == false) || (ignore_difficult == false))
552                         ++recall_normalization;
553                 }
554             }
555         }
556     }
557
558 #ifdef PR_DEBUG
559     int printed_count = 0;
560 #endif
561     /* now iterate through detections in descending order of score, assigning to ground truth bounding boxes if possible */
562     for (size_t detect_idx = 0; detect_idx < sorted_ids.size(); ++detect_idx)
563     {
564         //read in indexes to make following code easier to read
565         int im_idx = sorted_ids[detect_idx].image_idx;
566         int ob_idx = sorted_ids[detect_idx].obj_idx;
567         //set ground truth for the current object to false by default
568         ground_truth[im_idx][ob_idx] = false;
569         detection_difficult[im_idx][ob_idx] = false;
570         float maxov = -1.0;
571         bool max_is_difficult = false;
572         int max_gt_obj_idx = -1;
573         //-- for each detected object iterate through objects present in the bounding box ground truth --
574         for (size_t gt_obj_idx = 0; gt_obj_idx < img_objects[im_idx].size(); ++gt_obj_idx)
575         {
576             if (detected[im_idx][gt_obj_idx] == false)
577             {
578                 //check if the detected object and ground truth object overlap by a sufficient margin
579                 float ov = testBoundingBoxesForOverlap(bounding_boxes[im_idx][ob_idx], img_objects[im_idx][gt_obj_idx].boundingBox);
580                 if (ov != -1.0)
581                 {
582                     //if all conditions are met store the overlap score and index (as objects are assigned to the highest scoring match)
583                     if (ov > maxov)
584                     {
585                         maxov = ov;
586                         max_gt_obj_idx = (int)gt_obj_idx;
587                         //store whether the maximum detection is marked as difficult or not
588                         max_is_difficult = (img_object_data[im_idx][gt_obj_idx].difficult);
589                     }
590                 }
591             }
592         }
593         //-- if a match was found, set the ground truth of the current object to true --
594         if (maxov != -1.0)
595         {
596             CV_Assert(max_gt_obj_idx != -1);
597             ground_truth[im_idx][ob_idx] = true;
598             //store whether the maximum detection was marked as 'difficult' or not
599             detection_difficult[im_idx][ob_idx] = max_is_difficult;
600             //remove the ground truth object so it doesn't match with subsequent detected objects
601             //** this is the behaviour defined by the voc documentation **
602             detected[im_idx][max_gt_obj_idx] = true;
603         }
604 #ifdef PR_DEBUG
605         if (printed_count < 10)
606         {
607             cout << printed_count << ": id=" << images[im_idx].id << ", score=" << scores[im_idx][ob_idx] << " (" << ob_idx << ") [" << bounding_boxes[im_idx][ob_idx].x << "," <<
608                     bounding_boxes[im_idx][ob_idx].y << "," << bounding_boxes[im_idx][ob_idx].width + bounding_boxes[im_idx][ob_idx].x <<
609                     "," << bounding_boxes[im_idx][ob_idx].height + bounding_boxes[im_idx][ob_idx].y << "] detected=" << ground_truth[im_idx][ob_idx] <<
610                     ", difficult=" << detection_difficult[im_idx][ob_idx] << endl;
611             ++printed_count;
612             /* print ground truth */
613             for (int gt_obj_idx = 0; gt_obj_idx < img_objects[im_idx].size(); ++gt_obj_idx)
614             {
615                 cout << "    GT: [" << img_objects[im_idx][gt_obj_idx].boundingBox.x << "," <<
616                         img_objects[im_idx][gt_obj_idx].boundingBox.y << "," << img_objects[im_idx][gt_obj_idx].boundingBox.width + img_objects[im_idx][gt_obj_idx].boundingBox.x <<
617                         "," << img_objects[im_idx][gt_obj_idx].boundingBox.height + img_objects[im_idx][gt_obj_idx].boundingBox.y << "]";
618                 if (gt_obj_idx == max_gt_obj_idx) cout << " <--- (" << maxov << " overlap)";
619                 cout << endl;
620             }
621         }
622 #endif
623     }
624
625     return recall_normalization;
626 }
627
628 //Write VOC-compliant classifier results file
629 //-------------------------------------------
630 //INPUTS:
631 // - obj_class          The VOC object class identifier string
632 // - dataset            Specifies whether working with the training or test set
633 // - images             An array of ObdImage containing the images for which data will be saved to the result file
634 // - scores             A corresponding array of confidence scores given a query
635 // - (competition)      If specified, defines which competition the results are for (see VOC documentation - default 1)
636 //NOTES:
637 // The result file path and filename are determined automatically using m_results_directory as a base
638 void VocData::writeClassifierResultsFile( const string& out_dir, const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<float>& scores, const int competition, const bool overwrite_ifexists)
639 {
640     CV_Assert(images.size() == scores.size());
641
642     string output_file_base, output_file;
643     if (dataset == CV_OBD_TRAIN)
644     {
645         output_file_base = out_dir + "/comp" + integerToString(competition) + "_cls_" + m_train_set + "_" + obj_class;
646     } else {
647         output_file_base = out_dir + "/comp" + integerToString(competition) + "_cls_" + m_test_set + "_" + obj_class;
648     }
649     output_file = output_file_base + ".txt";
650
651     //check if file exists, and if so create a numbered new file instead
652     if (overwrite_ifexists == false)
653     {
654         struct stat stFileInfo;
655         if (stat(output_file.c_str(),&stFileInfo) == 0)
656         {
657             string output_file_new;
658             int filenum = 0;
659             do
660             {
661                 ++filenum;
662                 output_file_new = output_file_base + "_" + integerToString(filenum);
663                 output_file = output_file_new + ".txt";
664             } while (stat(output_file.c_str(),&stFileInfo) == 0);
665         }
666     }
667
668     //output data to file
669     std::ofstream result_file(output_file.c_str());
670     if (result_file.is_open())
671     {
672         for (size_t i = 0; i < images.size(); ++i)
673         {
674             result_file << images[i].id << " " << scores[i] << endl;
675         }
676         result_file.close();
677     } else {
678         string err_msg = "could not open classifier results file '" + output_file + "' for writing. Before running for the first time, a 'results' subdirectory should be created within the VOC dataset base directory. e.g. if the VOC data is stored in /VOC/VOC2010 then the path /VOC/results must be created.";
679         CV_Error(CV_StsError,err_msg.c_str());
680     }
681 }
682
683 //---------------------------------------
684 //CALCULATE METRICS FROM VOC RESULTS DATA
685 //---------------------------------------
686
687 //Utility function to construct a VOC-standard classification results filename
688 //----------------------------------------------------------------------------
689 //INPUTS:
690 // - obj_class          The VOC object class identifier string
691 // - task               Specifies whether to generate a filename for the classification or detection task
692 // - dataset            Specifies whether working with the training or test set
693 // - (competition)      If specified, defines which competition the results are for (see VOC documentation
694 //                      default of -1 means this is set to 1 for the classification task and 3 for the detection task)
695 // - (number)           If specified and above 0, defines which of a number of duplicate results file produced for a given set of
696 //                      of settings should be used (this number will be added as a postfix to the filename)
697 //NOTES:
698 // This is primarily useful for returning the filename of a classification file previously computed using writeClassifierResultsFile
699 // for example when calling calcClassifierPrecRecall
700 string VocData::getResultsFilename(const string& obj_class, const VocTask task, const ObdDatasetType dataset, const int competition, const int number)
701 {
702     if ((competition < 1) && (competition != -1))
703         CV_Error(CV_StsBadArg,"competition argument should be a positive non-zero number or -1 to accept the default");
704     if ((number < 1) && (number != -1))
705         CV_Error(CV_StsBadArg,"number argument should be a positive non-zero number or -1 to accept the default");
706
707     string dset, task_type;
708
709     if (dataset == CV_OBD_TRAIN)
710     {
711         dset = m_train_set;
712     } else {
713         dset = m_test_set;
714     }
715
716     int comp = competition;
717     if (task == CV_VOC_TASK_CLASSIFICATION)
718     {
719         task_type = "cls";
720         if (comp == -1) comp = 1;
721     } else {
722         task_type = "det";
723         if (comp == -1) comp = 3;
724     }
725
726     stringstream ss;
727     if (number < 1)
728     {
729         ss << "comp" << comp << "_" << task_type << "_" << dset << "_" << obj_class << ".txt";
730     } else {
731         ss << "comp" << comp << "_" << task_type << "_" << dset << "_" << obj_class << "_" << number << ".txt";
732     }
733
734     string filename = ss.str();
735     return filename;
736 }
737
738 //Calculate metrics for classification results
739 //--------------------------------------------
740 //INPUTS:
741 // - ground_truth       A vector of booleans determining whether the currently tested class is present in each input image
742 // - scores             A vector containing the similarity score for each input image (higher is more similar)
743 //OUTPUTS:
744 // - precision          A vector containing the precision calculated at each datapoint of a p-r curve generated from the result set
745 // - recall             A vector containing the recall calculated at each datapoint of a p-r curve generated from the result set
746 // - ap                The ap metric calculated from the result set
747 // - (ranking)          A vector of the same length as 'ground_truth' and 'scores' containing the order of the indices in both of
748 //                      these arrays when sorting by the ranking score in descending order
749 //NOTES:
750 // The result file path and filename are determined automatically using m_results_directory as a base
751 void VocData::calcClassifierPrecRecall(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap, vector<size_t>& ranking)
752 {
753     vector<char> res_ground_truth;
754     getClassifierGroundTruth(obj_class, images, res_ground_truth);
755
756     calcPrecRecall_impl(res_ground_truth, scores, precision, recall, ap, ranking);
757 }
758
759 void VocData::calcClassifierPrecRecall(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap)
760 {
761     vector<char> res_ground_truth;
762     getClassifierGroundTruth(obj_class, images, res_ground_truth);
763
764     vector<size_t> ranking;
765     calcPrecRecall_impl(res_ground_truth, scores, precision, recall, ap, ranking);
766 }
767
768 //< Overloaded version which accepts VOC classification result file input instead of array of scores/ground truth >
769 //INPUTS:
770 // - input_file         The path to the VOC standard results file to use for calculating precision/recall
771 //                      If a full path is not specified, it is assumed this file is in the VOC standard results directory
772 //                      A VOC standard filename can be retrieved (as used by writeClassifierResultsFile) by calling  getClassifierResultsFilename
773
774 void VocData::calcClassifierPrecRecall(const string& input_file, vector<float>& precision, vector<float>& recall, float& ap, bool outputRankingFile)
775 {
776     //read in classification results file
777     vector<string> res_image_codes;
778     vector<float> res_scores;
779
780     string input_file_std = checkFilenamePathsep(input_file);
781     readClassifierResultsFile(input_file_std, res_image_codes, res_scores);
782
783     //extract the object class and dataset from the results file filename
784     string class_name, dataset_name;
785     extractDataFromResultsFilename(input_file_std, class_name, dataset_name);
786
787     //generate the ground truth for the images extracted from the results file
788     vector<char> res_ground_truth;
789
790     getClassifierGroundTruth(class_name, res_image_codes, res_ground_truth);
791
792     if (outputRankingFile)
793     {
794         /* 1. store sorting order by score (descending) in 'order' */
795         vector<std::pair<size_t, vector<float>::const_iterator> > order(res_scores.size());
796
797         size_t n = 0;
798         for (vector<float>::const_iterator it = res_scores.begin(); it != res_scores.end(); ++it, ++n)
799             order[n] = make_pair(n, it);
800
801         std::sort(order.begin(),order.end(),orderingSorter());
802
803         /* 2. save ranking results to text file */
804         string input_file_std1 = checkFilenamePathsep(input_file);
805         size_t fnamestart = input_file_std1.rfind("/");
806         string scoregt_file_str = input_file_std1.substr(0,fnamestart+1) + "scoregt_" + class_name + ".txt";
807         std::ofstream scoregt_file(scoregt_file_str.c_str());
808         if (scoregt_file.is_open())
809         {
810             for (size_t i = 0; i < res_scores.size(); ++i)
811             {
812                 scoregt_file << res_image_codes[order[i].first] << " " << res_scores[order[i].first] << " " << res_ground_truth[order[i].first] << endl;
813             }
814             scoregt_file.close();
815         } else {
816             string err_msg = "could not open scoregt file '" + scoregt_file_str + "' for writing.";
817             CV_Error(CV_StsError,err_msg.c_str());
818         }
819     }
820
821     //finally, calculate precision+recall+ap
822     vector<size_t> ranking;
823     calcPrecRecall_impl(res_ground_truth,res_scores,precision,recall,ap,ranking);
824 }
825
826 //< Protected implementation of Precision-Recall calculation used by both calcClassifierPrecRecall and calcDetectorPrecRecall >
827
828 void VocData::calcPrecRecall_impl(const vector<char>& ground_truth, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap, vector<size_t>& ranking, int recall_normalization)
829 {
830     CV_Assert(ground_truth.size() == scores.size());
831
832     //add extra element for p-r at 0 recall (in case that first retrieved is positive)
833     vector<float>(scores.size()+1).swap(precision);
834     vector<float>(scores.size()+1).swap(recall);
835
836     // SORT RESULTS BY THEIR SCORE
837     /* 1. store sorting order in 'order' */
838     VocData::getSortOrder(scores, ranking);
839
840 #ifdef PR_DEBUG
841     std::ofstream scoregt_file("D:/pr.txt");
842     if (scoregt_file.is_open())
843     {
844        for (int i = 0; i < scores.size(); ++i)
845        {
846            scoregt_file << scores[ranking[i]] << " " << ground_truth[ranking[i]] << endl;
847        }
848        scoregt_file.close();
849     }
850 #endif
851
852     // CALCULATE PRECISION+RECALL
853
854     int retrieved_hits = 0;
855
856     int recall_norm;
857     if (recall_normalization != -1)
858     {
859         recall_norm = recall_normalization;
860     } else {
861         recall_norm = (int)std::count_if(ground_truth.begin(),ground_truth.end(),std::bind2nd(std::equal_to<char>(),(char)1));
862     }
863
864     ap = 0;
865     recall[0] = 0;
866     for (size_t idx = 0; idx < ground_truth.size(); ++idx)
867     {
868         if (ground_truth[ranking[idx]] != 0) ++retrieved_hits;
869
870         precision[idx+1] = static_cast<float>(retrieved_hits)/static_cast<float>(idx+1);
871         recall[idx+1] = static_cast<float>(retrieved_hits)/static_cast<float>(recall_norm);
872
873         if (idx == 0)
874         {
875             //add further point at 0 recall with the same precision value as the first computed point
876             precision[idx] = precision[idx+1];
877         }
878         if (recall[idx+1] == 1.0)
879         {
880             //if recall = 1, then end early as all positive images have been found
881             recall.resize(idx+2);
882             precision.resize(idx+2);
883             break;
884         }
885     }
886
887     /* ap calculation */
888     if (m_sampled_ap == false)
889     {
890         // FOR VOC2010+ AP IS CALCULATED FROM ALL DATAPOINTS
891         /* make precision monotonically decreasing for purposes of calculating ap */
892         vector<float> precision_monot(precision.size());
893         vector<float>::iterator prec_m_it = precision_monot.begin();
894         for (vector<float>::iterator prec_it = precision.begin(); prec_it != precision.end(); ++prec_it, ++prec_m_it)
895         {
896             vector<float>::iterator max_elem;
897             max_elem = std::max_element(prec_it,precision.end());
898             (*prec_m_it) = (*max_elem);
899         }
900         /* calculate ap */
901         for (size_t idx = 0; idx < (recall.size()-1); ++idx)
902         {
903             ap += (recall[idx+1] - recall[idx])*precision_monot[idx+1] +   //no need to take min of prec - is monotonically decreasing
904                     0.5f*(recall[idx+1] - recall[idx])*std::abs(precision_monot[idx+1] - precision_monot[idx]);
905         }
906     } else {
907         // FOR BEFORE VOC2010 AP IS CALCULATED BY SAMPLING PRECISION AT RECALL 0.0,0.1,..,1.0
908
909         for (float recall_pos = 0.f; recall_pos <= 1.f; recall_pos += 0.1f)
910         {
911             //find iterator of the precision corresponding to the first recall >= recall_pos
912             vector<float>::iterator recall_it = recall.begin();
913             vector<float>::iterator prec_it = precision.begin();
914
915             while ((*recall_it) < recall_pos)
916             {
917                 ++recall_it;
918                 ++prec_it;
919                 if (recall_it == recall.end()) break;
920             }
921
922             /* if no recall >= recall_pos found, this level of recall is never reached so stop adding to ap */
923             if (recall_it == recall.end()) break;
924
925             /* if the prec_it is valid, compute the max precision at this level of recall or higher */
926             vector<float>::iterator max_prec = std::max_element(prec_it,precision.end());
927
928             ap += (*max_prec)/11;
929         }
930     }
931 }
932
933 /* functions for calculating confusion matrix rows */
934
935 //Calculate rows of a confusion matrix
936 //------------------------------------
937 //INPUTS:
938 // - obj_class          The VOC object class identifier string for the confusion matrix row to compute
939 // - images             An array of ObdImage containing the images to use for the computation
940 // - scores             A corresponding array of confidence scores for the presence of obj_class in each image
941 // - cond               Defines whether to use a cut off point based on recall (CV_VOC_CCOND_RECALL) or score
942 //                      (CV_VOC_CCOND_SCORETHRESH) the latter is useful for classifier detections where positive
943 //                      values are positive detections and negative values are negative detections
944 // - threshold          Threshold value for cond. In case of CV_VOC_CCOND_RECALL, is proportion recall (e.g. 0.5).
945 //                      In the case of CV_VOC_CCOND_SCORETHRESH is the value above which to count results.
946 //OUTPUTS:
947 // - output_headers     An output vector of object class headers for the confusion matrix row
948 // - output_values      An output vector of values for the confusion matrix row corresponding to the classes
949 //                      defined in output_headers
950 //NOTES:
951 // The methodology used by the classifier version of this function is that true positives have a single unit
952 // added to the obj_class column in the confusion matrix row, whereas false positives have a single unit
953 // distributed in proportion between all the columns in the confusion matrix row corresponding to the objects
954 // present in the image.
955 void VocData::calcClassifierConfMatRow(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, const VocConfCond cond, const float threshold, vector<string>& output_headers, vector<float>& output_values)
956 {
957     CV_Assert(images.size() == scores.size());
958
959     // SORT RESULTS BY THEIR SCORE
960     /* 1. store sorting order in 'ranking' */
961     vector<size_t> ranking;
962     VocData::getSortOrder(scores, ranking);
963
964     // CALCULATE CONFUSION MATRIX ENTRIES
965     /* prepare object category headers */
966     output_headers = m_object_classes;
967     vector<float>(output_headers.size(),0.0).swap(output_values);
968     /* find the index of the target object class in the headers for later use */
969     int target_idx;
970     {
971         vector<string>::iterator target_idx_it = std::find(output_headers.begin(),output_headers.end(),obj_class);
972         /* if the target class can not be found, raise an exception */
973         if (target_idx_it == output_headers.end())
974         {
975             string err_msg = "could not find the target object class '" + obj_class + "' in list of valid classes.";
976             CV_Error(CV_StsError,err_msg.c_str());
977         }
978         /* convert iterator to index */
979         target_idx = (int)std::distance(output_headers.begin(),target_idx_it);
980     }
981
982     /* prepare variables related to calculating recall if using the recall threshold */
983     int retrieved_hits = 0;
984     int total_relevant = 0;
985     if (cond == CV_VOC_CCOND_RECALL)
986     {
987         vector<char> ground_truth;
988         /* in order to calculate the total number of relevant images for normalization of recall
989             it's necessary to extract the ground truth for the images under consideration */
990         getClassifierGroundTruth(obj_class, images, ground_truth);
991         total_relevant = (int)std::count_if(ground_truth.begin(),ground_truth.end(),std::bind2nd(std::equal_to<char>(),(char)1));
992     }
993
994     /* iterate through images */
995     vector<ObdObject> img_objects;
996     vector<VocObjectData> img_object_data;
997     int total_images = 0;
998     for (size_t image_idx = 0; image_idx < images.size(); ++image_idx)
999     {
1000         /* if using the score as the break condition, check for it now */
1001         if (cond == CV_VOC_CCOND_SCORETHRESH)
1002         {
1003             if (scores[ranking[image_idx]] <= threshold) break;
1004         }
1005         /* if continuing for this iteration, increment the image counter for later normalization */
1006         ++total_images;
1007         /* for each image retrieve the objects contained */
1008         getObjects(images[ranking[image_idx]].id, img_objects, img_object_data);
1009         //check if the tested for object class is present
1010         if (getClassifierGroundTruthImage(obj_class, images[ranking[image_idx]].id))
1011         {
1012             //if the target class is present, assign fully to the target class element in the confusion matrix row
1013             output_values[target_idx] += 1.0;
1014             if (cond == CV_VOC_CCOND_RECALL) ++retrieved_hits;
1015         } else {
1016             //first delete all objects marked as difficult
1017             for (size_t obj_idx = 0; obj_idx < img_objects.size(); ++obj_idx)
1018             {
1019                 if (img_object_data[obj_idx].difficult == true)
1020                 {
1021                     vector<ObdObject>::iterator it1 = img_objects.begin();
1022                     std::advance(it1,obj_idx);
1023                     img_objects.erase(it1);
1024                     vector<VocObjectData>::iterator it2 = img_object_data.begin();
1025                     std::advance(it2,obj_idx);
1026                     img_object_data.erase(it2);
1027                     --obj_idx;
1028                 }
1029             }
1030             //if the target class is not present, add values to the confusion matrix row in equal proportions to all objects present in the image
1031             for (size_t obj_idx = 0; obj_idx < img_objects.size(); ++obj_idx)
1032             {
1033                 //find the index of the currently considered object
1034                 vector<string>::iterator class_idx_it = std::find(output_headers.begin(),output_headers.end(),img_objects[obj_idx].object_class);
1035                 //if the class name extracted from the ground truth file could not be found in the list of available classes, raise an exception
1036                 if (class_idx_it == output_headers.end())
1037                 {
1038                     string err_msg = "could not find object class '" + img_objects[obj_idx].object_class + "' specified in the ground truth file of '" + images[ranking[image_idx]].id +"'in list of valid classes.";
1039                     CV_Error(CV_StsError,err_msg.c_str());
1040                 }
1041                 /* convert iterator to index */
1042                 int class_idx = (int)std::distance(output_headers.begin(),class_idx_it);
1043                 //add to confusion matrix row in proportion
1044                 output_values[class_idx] += 1.f/static_cast<float>(img_objects.size());
1045             }
1046         }
1047         //check break conditions if breaking on certain level of recall
1048         if (cond == CV_VOC_CCOND_RECALL)
1049         {
1050             if(static_cast<float>(retrieved_hits)/static_cast<float>(total_relevant) >= threshold) break;
1051         }
1052     }
1053     /* finally, normalize confusion matrix row */
1054     for (vector<float>::iterator it = output_values.begin(); it < output_values.end(); ++it)
1055     {
1056         (*it) /= static_cast<float>(total_images);
1057     }
1058 }
1059
1060 // NOTE: doesn't ignore repeated detections
1061 void VocData::calcDetectorConfMatRow(const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<vector<float> >& scores, const vector<vector<Rect> >& bounding_boxes, const VocConfCond cond, const float threshold, vector<string>& output_headers, vector<float>& output_values, bool ignore_difficult)
1062 {
1063     CV_Assert(images.size() == scores.size());
1064     CV_Assert(images.size() == bounding_boxes.size());
1065
1066     //collapse scores and ground_truth vectors into 1D vectors to allow ranking
1067     /* define final flat vectors */
1068     vector<string> images_flat;
1069     vector<float> scores_flat;
1070     vector<Rect> bounding_boxes_flat;
1071     {
1072         /* first count how many objects to allow preallocation */
1073         int obj_count = 0;
1074         CV_Assert(scores.size() == bounding_boxes.size());
1075         for (size_t img_idx = 0; img_idx < scores.size(); ++img_idx)
1076         {
1077             CV_Assert(scores[img_idx].size() == bounding_boxes[img_idx].size());
1078             for (size_t obj_idx = 0; obj_idx < scores[img_idx].size(); ++obj_idx)
1079             {
1080                 ++obj_count;
1081             }
1082         }
1083         /* preallocate vectors */
1084         images_flat.resize(obj_count);
1085         scores_flat.resize(obj_count);
1086         bounding_boxes_flat.resize(obj_count);
1087         /* now copy across to preallocated vectors */
1088         int flat_pos = 0;
1089         for (size_t img_idx = 0; img_idx < scores.size(); ++img_idx)
1090         {
1091             for (size_t obj_idx = 0; obj_idx < scores[img_idx].size(); ++obj_idx)
1092             {
1093                 images_flat[flat_pos] = images[img_idx].id;
1094                 scores_flat[flat_pos] = scores[img_idx][obj_idx];
1095                 bounding_boxes_flat[flat_pos] = bounding_boxes[img_idx][obj_idx];
1096                 ++flat_pos;
1097             }
1098         }
1099     }
1100
1101     // SORT RESULTS BY THEIR SCORE
1102     /* 1. store sorting order in 'ranking' */
1103     vector<size_t> ranking;
1104     VocData::getSortOrder(scores_flat, ranking);
1105
1106     // CALCULATE CONFUSION MATRIX ENTRIES
1107     /* prepare object category headers */
1108     output_headers = m_object_classes;
1109     output_headers.push_back("background");
1110     vector<float>(output_headers.size(),0.0).swap(output_values);
1111
1112     /* prepare variables related to calculating recall if using the recall threshold */
1113     int retrieved_hits = 0;
1114     int total_relevant = 0;
1115     if (cond == CV_VOC_CCOND_RECALL)
1116     {
1117 //        vector<char> ground_truth;
1118 //        /* in order to calculate the total number of relevant images for normalization of recall
1119 //            it's necessary to extract the ground truth for the images under consideration */
1120 //        getClassifierGroundTruth(obj_class, images, ground_truth);
1121 //        total_relevant = std::count_if(ground_truth.begin(),ground_truth.end(),std::bind2nd(std::equal_to<bool>(),true));
1122         /* calculate the total number of objects in the ground truth for the current dataset */
1123         vector<ObdImage> gt_images;
1124         vector<char> gt_object_present;
1125         getClassImages(obj_class, dataset, gt_images, gt_object_present);
1126
1127         for (size_t image_idx = 0; image_idx < gt_images.size(); ++image_idx)
1128         {
1129             vector<ObdObject> gt_img_objects;
1130             vector<VocObjectData> gt_img_object_data;
1131             getObjects(gt_images[image_idx].id, gt_img_objects, gt_img_object_data);
1132             for (size_t obj_idx = 0; obj_idx < gt_img_objects.size(); ++obj_idx)
1133             {
1134                 if (gt_img_objects[obj_idx].object_class == obj_class)
1135                 {
1136                     if ((gt_img_object_data[obj_idx].difficult == false) || (ignore_difficult == false))
1137                         ++total_relevant;
1138                 }
1139             }
1140         }
1141     }
1142
1143     /* iterate through objects */
1144     vector<ObdObject> img_objects;
1145     vector<VocObjectData> img_object_data;
1146     int total_objects = 0;
1147     for (size_t image_idx = 0; image_idx < images.size(); ++image_idx)
1148     {
1149         /* if using the score as the break condition, check for it now */
1150         if (cond == CV_VOC_CCOND_SCORETHRESH)
1151         {
1152             if (scores_flat[ranking[image_idx]] <= threshold) break;
1153         }
1154         /* increment the image counter for later normalization */
1155         ++total_objects;
1156         /* for each image retrieve the objects contained */
1157         getObjects(images[ranking[image_idx]].id, img_objects, img_object_data);
1158
1159         //find the ground truth object which has the highest overlap score with the detected object
1160         float maxov = -1.0;
1161         int max_gt_obj_idx = -1;
1162         //-- for each detected object iterate through objects present in ground truth --
1163         for (size_t gt_obj_idx = 0; gt_obj_idx < img_objects.size(); ++gt_obj_idx)
1164         {
1165             //check difficulty flag
1166             if (ignore_difficult || (img_object_data[gt_obj_idx].difficult == false))
1167             {
1168                 //if the class matches, then check if the detected object and ground truth object overlap by a sufficient margin
1169                 float ov = testBoundingBoxesForOverlap(bounding_boxes_flat[ranking[image_idx]], img_objects[gt_obj_idx].boundingBox);
1170                 if (ov != -1.f)
1171                 {
1172                     //if all conditions are met store the overlap score and index (as objects are assigned to the highest scoring match)
1173                     if (ov > maxov)
1174                     {
1175                         maxov = ov;
1176                         max_gt_obj_idx = (int)gt_obj_idx;
1177                     }
1178                 }
1179             }
1180         }
1181
1182         //assign to appropriate object class if an object was detected
1183         if (maxov != -1.0)
1184         {
1185             //find the index of the currently considered object
1186             vector<string>::iterator class_idx_it = std::find(output_headers.begin(),output_headers.end(),img_objects[max_gt_obj_idx].object_class);
1187             //if the class name extracted from the ground truth file could not be found in the list of available classes, raise an exception
1188             if (class_idx_it == output_headers.end())
1189             {
1190                 string err_msg = "could not find object class '" + img_objects[max_gt_obj_idx].object_class + "' specified in the ground truth file of '" + images[ranking[image_idx]].id +"'in list of valid classes.";
1191                 CV_Error(CV_StsError,err_msg.c_str());
1192             }
1193             /* convert iterator to index */
1194             int class_idx = (int)std::distance(output_headers.begin(),class_idx_it);
1195             //add to confusion matrix row in proportion
1196             output_values[class_idx] += 1.0;
1197         } else {
1198             //otherwise assign to background class
1199             output_values[output_values.size()-1] += 1.0;
1200         }
1201
1202         //check break conditions if breaking on certain level of recall
1203         if (cond == CV_VOC_CCOND_RECALL)
1204         {
1205             if(static_cast<float>(retrieved_hits)/static_cast<float>(total_relevant) >= threshold) break;
1206         }
1207     }
1208
1209     /* finally, normalize confusion matrix row */
1210     for (vector<float>::iterator it = output_values.begin(); it < output_values.end(); ++it)
1211     {
1212         (*it) /= static_cast<float>(total_objects);
1213     }
1214 }
1215
1216 //Save Precision-Recall results to a p-r curve in GNUPlot format
1217 //--------------------------------------------------------------
1218 //INPUTS:
1219 // - output_file        The file to which to save the GNUPlot data file. If only a filename is specified, the data
1220 //                      file is saved to the standard VOC results directory.
1221 // - precision          Vector of precisions as returned from calcClassifier/DetectorPrecRecall
1222 // - recall             Vector of recalls as returned from calcClassifier/DetectorPrecRecall
1223 // - ap                ap as returned from calcClassifier/DetectorPrecRecall
1224 // - (title)            Title to use for the plot (if not specified, just the ap is printed as the title)
1225 //                      This also specifies the filename of the output file if printing to pdf
1226 // - (plot_type)        Specifies whether to instruct GNUPlot to save to a PDF file (CV_VOC_PLOT_PDF) or directly
1227 //                      to screen (CV_VOC_PLOT_SCREEN) in the datafile
1228 //NOTES:
1229 // The GNUPlot data file can be executed using GNUPlot from the commandline in the following way:
1230 //      >> GNUPlot <output_file>
1231 // This will then display the p-r curve on the screen or save it to a pdf file depending on plot_type
1232
1233 void VocData::savePrecRecallToGnuplot(const string& output_file, const vector<float>& precision, const vector<float>& recall, const float ap, const string title, const VocPlotType plot_type)
1234 {
1235     string output_file_std = checkFilenamePathsep(output_file);
1236
1237     //if no directory is specified, by default save the output file in the results directory
1238 //    if (output_file_std.find("/") == output_file_std.npos)
1239 //    {
1240 //        output_file_std = m_results_directory + output_file_std;
1241 //    }
1242
1243     std::ofstream plot_file(output_file_std.c_str());
1244
1245     if (plot_file.is_open())
1246     {
1247         plot_file << "set xrange [0:1]" << endl;
1248         plot_file << "set yrange [0:1]" << endl;
1249         plot_file << "set size square" << endl;
1250         string title_text = title;
1251         if (title_text.size() == 0) title_text = "Precision-Recall Curve";
1252         plot_file << "set title \"" << title_text << " (ap: " << ap << ")\"" << endl;
1253         plot_file << "set xlabel \"Recall\"" << endl;
1254         plot_file << "set ylabel \"Precision\"" << endl;
1255         plot_file << "set style data lines" << endl;
1256         plot_file << "set nokey" << endl;
1257         if (plot_type == CV_VOC_PLOT_PNG)
1258         {
1259             plot_file << "set terminal png" << endl;
1260             string pdf_filename;
1261             if (title.size() != 0)
1262             {
1263                 pdf_filename = title;
1264             } else {
1265                 pdf_filename = "prcurve";
1266             }
1267             plot_file << "set out \"" << title << ".png\"" << endl;
1268         }
1269         plot_file << "plot \"-\" using 1:2" << endl;
1270         plot_file << "# X Y" << endl;
1271         CV_Assert(precision.size() == recall.size());
1272         for (size_t i = 0; i < precision.size(); ++i)
1273         {
1274             plot_file << "  " << recall[i] << " " << precision[i] << endl;
1275         }
1276         plot_file << "end" << endl;
1277         if (plot_type == CV_VOC_PLOT_SCREEN)
1278         {
1279             plot_file << "pause -1" << endl;
1280         }
1281         plot_file.close();
1282     } else {
1283         string err_msg = "could not open plot file '" + output_file_std + "' for writing.";
1284         CV_Error(CV_StsError,err_msg.c_str());
1285     }
1286 }
1287
1288 void VocData::readClassifierGroundTruth(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<char>& object_present)
1289 {
1290     images.clear();
1291
1292     string gtFilename = m_class_imageset_path;
1293     gtFilename.replace(gtFilename.find("%s"),2,obj_class);
1294     if (dataset == CV_OBD_TRAIN)
1295     {
1296         gtFilename.replace(gtFilename.find("%s"),2,m_train_set);
1297     } else {
1298         gtFilename.replace(gtFilename.find("%s"),2,m_test_set);
1299     }
1300
1301     vector<string> image_codes;
1302     readClassifierGroundTruth(gtFilename, image_codes, object_present);
1303
1304     convertImageCodesToObdImages(image_codes, images);
1305 }
1306
1307 void VocData::readClassifierResultsFile(const std:: string& input_file, vector<ObdImage>& images, vector<float>& scores)
1308 {
1309     images.clear();
1310
1311     string input_file_std = checkFilenamePathsep(input_file);
1312
1313     //if no directory is specified, by default search for the input file in the results directory
1314 //    if (input_file_std.find("/") == input_file_std.npos)
1315 //    {
1316 //        input_file_std = m_results_directory + input_file_std;
1317 //    }
1318
1319     vector<string> image_codes;
1320     readClassifierResultsFile(input_file_std, image_codes, scores);
1321
1322     convertImageCodesToObdImages(image_codes, images);
1323 }
1324
1325 void VocData::readDetectorResultsFile(const string& input_file, vector<ObdImage>& images, vector<vector<float> >& scores, vector<vector<Rect> >& bounding_boxes)
1326 {
1327     images.clear();
1328
1329     string input_file_std = checkFilenamePathsep(input_file);
1330
1331     //if no directory is specified, by default search for the input file in the results directory
1332 //    if (input_file_std.find("/") == input_file_std.npos)
1333 //    {
1334 //        input_file_std = m_results_directory + input_file_std;
1335 //    }
1336
1337     vector<string> image_codes;
1338     readDetectorResultsFile(input_file_std, image_codes, scores, bounding_boxes);
1339
1340     convertImageCodesToObdImages(image_codes, images);
1341 }
1342
1343 const vector<string>& VocData::getObjectClasses()
1344 {
1345     return m_object_classes;
1346 }
1347
1348 //string VocData::getResultsDirectory()
1349 //{
1350 //    return m_results_directory;
1351 //}
1352
1353 //---------------------------------------------------------
1354 // Protected Functions ------------------------------------
1355 //---------------------------------------------------------
1356
1357 static string getVocName( const string& vocPath )
1358 {
1359     size_t found = vocPath.rfind( '/' );
1360     if( found == string::npos )
1361     {
1362         found = vocPath.rfind( '\\' );
1363         if( found == string::npos )
1364             return vocPath;
1365     }
1366     return vocPath.substr(found + 1, vocPath.size() - found);
1367 }
1368
1369 void VocData::initVoc( const string& vocPath, const bool useTestDataset )
1370 {
1371     initVoc2007to2010( vocPath, useTestDataset );
1372 }
1373
1374 //Initialize file paths and settings for the VOC 2010 dataset
1375 //-----------------------------------------------------------
1376 void VocData::initVoc2007to2010( const string& vocPath, const bool useTestDataset )
1377 {
1378     //check format of root directory and modify if necessary
1379
1380     m_vocName = getVocName( vocPath );
1381
1382     CV_Assert( !m_vocName.compare("VOC2007") || !m_vocName.compare("VOC2008") ||
1383                !m_vocName.compare("VOC2009") || !m_vocName.compare("VOC2010") );
1384
1385     m_vocPath = checkFilenamePathsep( vocPath, true );
1386
1387     if (useTestDataset)
1388     {
1389         m_train_set = "trainval";
1390         m_test_set = "test";
1391     } else {
1392         m_train_set = "train";
1393         m_test_set = "val";
1394     }
1395
1396     // initialize main classification/detection challenge paths
1397     m_annotation_path = m_vocPath + "/Annotations/%s.xml";
1398     m_image_path = m_vocPath + "/JPEGImages/%s.jpg";
1399     m_imageset_path = m_vocPath + "/ImageSets/Main/%s.txt";
1400     m_class_imageset_path = m_vocPath + "/ImageSets/Main/%s_%s.txt";
1401
1402     //define available object_classes for VOC2010 dataset
1403     m_object_classes.push_back("aeroplane");
1404     m_object_classes.push_back("bicycle");
1405     m_object_classes.push_back("bird");
1406     m_object_classes.push_back("boat");
1407     m_object_classes.push_back("bottle");
1408     m_object_classes.push_back("bus");
1409     m_object_classes.push_back("car");
1410     m_object_classes.push_back("cat");
1411     m_object_classes.push_back("chair");
1412     m_object_classes.push_back("cow");
1413     m_object_classes.push_back("diningtable");
1414     m_object_classes.push_back("dog");
1415     m_object_classes.push_back("horse");
1416     m_object_classes.push_back("motorbike");
1417     m_object_classes.push_back("person");
1418     m_object_classes.push_back("pottedplant");
1419     m_object_classes.push_back("sheep");
1420     m_object_classes.push_back("sofa");
1421     m_object_classes.push_back("train");
1422     m_object_classes.push_back("tvmonitor");
1423
1424     m_min_overlap = 0.5;
1425
1426     //up until VOC 2010, ap was calculated by sampling p-r curve, not taking complete curve
1427     m_sampled_ap = ((m_vocName == "VOC2007") || (m_vocName == "VOC2008") || (m_vocName == "VOC2009"));
1428 }
1429
1430 //Read a VOC classification ground truth text file for a given object class and dataset
1431 //-------------------------------------------------------------------------------------
1432 //INPUTS:
1433 // - filename           The path of the text file to read
1434 //OUTPUTS:
1435 // - image_codes        VOC image codes extracted from the GT file in the form 20XX_XXXXXX where the first four
1436 //                          digits specify the year of the dataset, and the last group specifies a unique ID
1437 // - object_present     For each image in the 'image_codes' array, specifies whether the object class described
1438 //                          in the loaded GT file is present or not
1439 void VocData::readClassifierGroundTruth(const string& filename, vector<string>& image_codes, vector<char>& object_present)
1440 {
1441     image_codes.clear();
1442     object_present.clear();
1443
1444     std::ifstream gtfile(filename.c_str());
1445     if (!gtfile.is_open())
1446     {
1447         string err_msg = "could not open VOC ground truth textfile '" + filename + "'.";
1448         CV_Error(CV_StsError,err_msg.c_str());
1449     }
1450
1451     string line;
1452     string image;
1453     int obj_present = 0;
1454     while (!gtfile.eof())
1455     {
1456         std::getline(gtfile,line);
1457         std::istringstream iss(line);
1458         iss >> image >> obj_present;
1459         if (!iss.fail())
1460         {
1461             image_codes.push_back(image);
1462             object_present.push_back(obj_present == 1);
1463         } else {
1464             if (!gtfile.eof()) CV_Error(CV_StsParseError,"error parsing VOC ground truth textfile.");
1465         }
1466     }
1467     gtfile.close();
1468 }
1469
1470 void VocData::readClassifierResultsFile(const string& input_file, vector<string>& image_codes, vector<float>& scores)
1471 {
1472     //check if results file exists
1473     std::ifstream result_file(input_file.c_str());
1474     if (result_file.is_open())
1475     {
1476         string line;
1477         string image;
1478         float score;
1479         //read in the results file
1480         while (!result_file.eof())
1481         {
1482             std::getline(result_file,line);
1483             std::istringstream iss(line);
1484             iss >> image >> score;
1485             if (!iss.fail())
1486             {
1487                 image_codes.push_back(image);
1488                 scores.push_back(score);
1489             } else {
1490                 if(!result_file.eof()) CV_Error(CV_StsParseError,"error parsing VOC classifier results file.");
1491             }
1492         }
1493         result_file.close();
1494     } else {
1495         string err_msg = "could not open classifier results file '" + input_file + "' for reading.";
1496         CV_Error(CV_StsError,err_msg.c_str());
1497     }
1498 }
1499
1500 void VocData::readDetectorResultsFile(const string& input_file, vector<string>& image_codes, vector<vector<float> >& scores, vector<vector<Rect> >& bounding_boxes)
1501 {
1502     image_codes.clear();
1503     scores.clear();
1504     bounding_boxes.clear();
1505
1506     //check if results file exists
1507     std::ifstream result_file(input_file.c_str());
1508     if (result_file.is_open())
1509     {
1510         string line;
1511         string image;
1512         Rect bounding_box;
1513         float score;
1514         //read in the results file
1515         while (!result_file.eof())
1516         {
1517             std::getline(result_file,line);
1518             std::istringstream iss(line);
1519             iss >> image >> score >> bounding_box.x >> bounding_box.y >> bounding_box.width >> bounding_box.height;
1520             if (!iss.fail())
1521             {
1522                 //convert right and bottom positions to width and height
1523                 bounding_box.width -= bounding_box.x;
1524                 bounding_box.height -= bounding_box.y;
1525                 //convert to 0-indexing
1526                 bounding_box.x -= 1;
1527                 bounding_box.y -= 1;
1528                 //store in output vectors
1529                 /* first check if the current image code has been seen before */
1530                 vector<string>::iterator image_codes_it = std::find(image_codes.begin(),image_codes.end(),image);
1531                 if (image_codes_it == image_codes.end())
1532                 {
1533                     image_codes.push_back(image);
1534                     vector<float> score_vect(1);
1535                     score_vect[0] = score;
1536                     scores.push_back(score_vect);
1537                     vector<Rect> bounding_box_vect(1);
1538                     bounding_box_vect[0] = bounding_box;
1539                     bounding_boxes.push_back(bounding_box_vect);
1540                 } else {
1541                     /* if the image index has been seen before, add the current object below it in the 2D arrays */
1542                     int image_idx = (int)std::distance(image_codes.begin(),image_codes_it);
1543                     scores[image_idx].push_back(score);
1544                     bounding_boxes[image_idx].push_back(bounding_box);
1545                 }
1546             } else {
1547                 if(!result_file.eof()) CV_Error(CV_StsParseError,"error parsing VOC detector results file.");
1548             }
1549         }
1550         result_file.close();
1551     } else {
1552         string err_msg = "could not open detector results file '" + input_file + "' for reading.";
1553         CV_Error(CV_StsError,err_msg.c_str());
1554     }
1555 }
1556
1557
1558 //Read a VOC annotation xml file for a given image
1559 //------------------------------------------------
1560 //INPUTS:
1561 // - filename           The path of the xml file to read
1562 //OUTPUTS:
1563 // - objects            Array of VocObject describing all object instances present in the given image
1564 void VocData::extractVocObjects(const string filename, vector<ObdObject>& objects, vector<VocObjectData>& object_data)
1565 {
1566 #ifdef PR_DEBUG
1567     int block = 1;
1568     cout << "SAMPLE VOC OBJECT EXTRACTION for " << filename << ":" << endl;
1569 #endif
1570     objects.clear();
1571     object_data.clear();
1572
1573     string contents, object_contents, tag_contents;
1574
1575     readFileToString(filename, contents);
1576
1577     //keep on extracting 'object' blocks until no more can be found
1578     if (extractXMLBlock(contents, "annotation", 0, contents) != -1)
1579     {
1580         int searchpos = 0;
1581         searchpos = extractXMLBlock(contents, "object", searchpos, object_contents);
1582         while (searchpos != -1)
1583         {
1584 #ifdef PR_DEBUG
1585             cout << "SEARCHPOS:" << searchpos << endl;
1586             cout << "start block " << block << " ---------" << endl;
1587             cout << object_contents << endl;
1588             cout << "end block " << block << " -----------" << endl;
1589             ++block;
1590 #endif
1591
1592             ObdObject object;
1593             VocObjectData object_d;
1594
1595             //object class -------------
1596
1597             if (extractXMLBlock(object_contents, "name", 0, tag_contents) == -1) CV_Error(CV_StsError,"missing <name> tag in object definition of '" + filename + "'");
1598             object.object_class.swap(tag_contents);
1599
1600             //object bounding box -------------
1601
1602             int xmax, xmin, ymax, ymin;
1603
1604             if (extractXMLBlock(object_contents, "xmax", 0, tag_contents) == -1) CV_Error(CV_StsError,"missing <xmax> tag in object definition of '" + filename + "'");
1605             xmax = stringToInteger(tag_contents);
1606
1607             if (extractXMLBlock(object_contents, "xmin", 0, tag_contents) == -1) CV_Error(CV_StsError,"missing <xmin> tag in object definition of '" + filename + "'");
1608             xmin = stringToInteger(tag_contents);
1609
1610             if (extractXMLBlock(object_contents, "ymax", 0, tag_contents) == -1) CV_Error(CV_StsError,"missing <ymax> tag in object definition of '" + filename + "'");
1611             ymax = stringToInteger(tag_contents);
1612
1613             if (extractXMLBlock(object_contents, "ymin", 0, tag_contents) == -1) CV_Error(CV_StsError,"missing <ymin> tag in object definition of '" + filename + "'");
1614             ymin = stringToInteger(tag_contents);
1615
1616             object.boundingBox.x = xmin-1;      //convert to 0-based indexing
1617             object.boundingBox.width = xmax - xmin;
1618             object.boundingBox.y = ymin-1;
1619             object.boundingBox.height = ymax - ymin;
1620
1621             CV_Assert(xmin != 0);
1622             CV_Assert(xmax > xmin);
1623             CV_Assert(ymin != 0);
1624             CV_Assert(ymax > ymin);
1625
1626
1627             //object tags -------------
1628
1629             if (extractXMLBlock(object_contents, "difficult", 0, tag_contents) != -1)
1630             {
1631                 object_d.difficult = (tag_contents == "1");
1632             } else object_d.difficult = false;
1633             if (extractXMLBlock(object_contents, "occluded", 0, tag_contents) != -1)
1634             {
1635                 object_d.occluded = (tag_contents == "1");
1636             } else object_d.occluded = false;
1637             if (extractXMLBlock(object_contents, "truncated", 0, tag_contents) != -1)
1638             {
1639                 object_d.truncated = (tag_contents == "1");
1640             } else object_d.truncated = false;
1641             if (extractXMLBlock(object_contents, "pose", 0, tag_contents) != -1)
1642             {
1643                 if (tag_contents == "Frontal") object_d.pose = CV_VOC_POSE_FRONTAL;
1644                 if (tag_contents == "Rear") object_d.pose = CV_VOC_POSE_REAR;
1645                 if (tag_contents == "Left") object_d.pose = CV_VOC_POSE_LEFT;
1646                 if (tag_contents == "Right") object_d.pose = CV_VOC_POSE_RIGHT;
1647             }
1648
1649             //add to array of objects
1650             objects.push_back(object);
1651             object_data.push_back(object_d);
1652
1653             //extract next 'object' block from file if it exists
1654             searchpos = extractXMLBlock(contents, "object", searchpos, object_contents);
1655         }
1656     }
1657 }
1658
1659 //Converts an image identifier string in the format YYYY_XXXXXX to a single index integer of form XXXXXXYYYY
1660 //where Y represents a year and returns the image path
1661 //----------------------------------------------------------------------------------------------------------
1662 string VocData::getImagePath(const string& input_str)
1663 {
1664     string path = m_image_path;
1665     path.replace(path.find("%s"),2,input_str);
1666     return path;
1667 }
1668
1669 //Tests two boundary boxes for overlap (using the intersection over union metric) and returns the overlap if the objects
1670 //defined by the two bounding boxes are considered to be matched according to the criterion outlined in
1671 //the VOC documentation [namely intersection/union > some threshold] otherwise returns -1.0 (no match)
1672 //----------------------------------------------------------------------------------------------------------
1673 float VocData::testBoundingBoxesForOverlap(const Rect detection, const Rect ground_truth)
1674 {
1675     int detection_x2 = detection.x + detection.width;
1676     int detection_y2 = detection.y + detection.height;
1677     int ground_truth_x2 = ground_truth.x + ground_truth.width;
1678     int ground_truth_y2 = ground_truth.y + ground_truth.height;
1679     //first calculate the boundaries of the intersection of the rectangles
1680     int intersection_x = std::max(detection.x, ground_truth.x); //rightmost left
1681     int intersection_y = std::max(detection.y, ground_truth.y); //bottommost top
1682     int intersection_x2 = std::min(detection_x2, ground_truth_x2); //leftmost right
1683     int intersection_y2 = std::min(detection_y2, ground_truth_y2); //topmost bottom
1684     //then calculate the width and height of the intersection rect
1685     int intersection_width = intersection_x2 - intersection_x + 1;
1686     int intersection_height = intersection_y2 - intersection_y + 1;
1687     //if there is no overlap then return false straight away
1688     if ((intersection_width <= 0) || (intersection_height <= 0)) return -1.0;
1689     //otherwise calculate the intersection
1690     int intersection_area = intersection_width*intersection_height;
1691
1692     //now calculate the union
1693     int union_area = (detection.width+1)*(detection.height+1) + (ground_truth.width+1)*(ground_truth.height+1) - intersection_area;
1694
1695     //calculate the intersection over union and use as threshold as per VOC documentation
1696     float overlap = static_cast<float>(intersection_area)/static_cast<float>(union_area);
1697     if (overlap > m_min_overlap)
1698     {
1699         return overlap;
1700     } else {
1701         return -1.0;
1702     }
1703 }
1704
1705 //Extracts the object class and dataset from the filename of a VOC standard results text file, which takes
1706 //the format 'comp<n>_{cls/det}_<dataset>_<objclass>.txt'
1707 //----------------------------------------------------------------------------------------------------------
1708 void VocData::extractDataFromResultsFilename(const string& input_file, string& class_name, string& dataset_name)
1709 {
1710     string input_file_std = checkFilenamePathsep(input_file);
1711
1712     size_t fnamestart = input_file_std.rfind("/");
1713     size_t fnameend = input_file_std.rfind(".txt");
1714
1715     if ((fnamestart == input_file_std.npos) || (fnameend == input_file_std.npos))
1716         CV_Error(CV_StsError,"Could not extract filename of results file.");
1717
1718     ++fnamestart;
1719     if (fnamestart >= fnameend)
1720         CV_Error(CV_StsError,"Could not extract filename of results file.");
1721
1722     //extract dataset and class names, triggering exception if the filename format is not correct
1723     string filename = input_file_std.substr(fnamestart, fnameend-fnamestart);
1724     size_t datasetstart = filename.find("_");
1725     datasetstart = filename.find("_",datasetstart+1);
1726     size_t classstart = filename.find("_",datasetstart+1);
1727     //allow for appended index after a further '_' by discarding this part if it exists
1728     size_t classend = filename.find("_",classstart+1);
1729     if (classend == filename.npos) classend = filename.size();
1730     if ((datasetstart == filename.npos) || (classstart == filename.npos))
1731         CV_Error(CV_StsError,"Error parsing results filename. Is it in standard format of 'comp<n>_{cls/det}_<dataset>_<objclass>.txt'?");
1732     ++datasetstart;
1733     ++classstart;
1734     if (((datasetstart-classstart) < 1) || ((classend-datasetstart) < 1))
1735         CV_Error(CV_StsError,"Error parsing results filename. Is it in standard format of 'comp<n>_{cls/det}_<dataset>_<objclass>.txt'?");
1736
1737     dataset_name = filename.substr(datasetstart,classstart-datasetstart-1);
1738     class_name = filename.substr(classstart,classend-classstart);
1739 }
1740
1741 bool VocData::getClassifierGroundTruthImage(const string& obj_class, const string& id)
1742 {
1743     /* if the classifier ground truth data for all images of the current class has not been loaded yet, load it now */
1744     if (m_classifier_gt_all_ids.empty() || (m_classifier_gt_class != obj_class))
1745     {
1746         m_classifier_gt_all_ids.clear();
1747         m_classifier_gt_all_present.clear();
1748         m_classifier_gt_class = obj_class;
1749         for (int i=0; i<2; ++i) //run twice (once over test set and once over training set)
1750         {
1751             //generate the filename of the classification ground-truth textfile for the object class
1752             string gtFilename = m_class_imageset_path;
1753             gtFilename.replace(gtFilename.find("%s"),2,obj_class);
1754             if (i == 0)
1755             {
1756                 gtFilename.replace(gtFilename.find("%s"),2,m_train_set);
1757             } else {
1758                 gtFilename.replace(gtFilename.find("%s"),2,m_test_set);
1759             }
1760
1761             //parse the ground truth file, storing in two separate vectors
1762             //for the image code and the ground truth value
1763             vector<string> image_codes;
1764             vector<char> object_present;
1765             readClassifierGroundTruth(gtFilename, image_codes, object_present);
1766
1767             m_classifier_gt_all_ids.insert(m_classifier_gt_all_ids.end(),image_codes.begin(),image_codes.end());
1768             m_classifier_gt_all_present.insert(m_classifier_gt_all_present.end(),object_present.begin(),object_present.end());
1769
1770             CV_Assert(m_classifier_gt_all_ids.size() == m_classifier_gt_all_present.size());
1771         }
1772     }
1773
1774
1775     //search for the image code
1776     vector<string>::iterator it = find (m_classifier_gt_all_ids.begin(), m_classifier_gt_all_ids.end(), id);
1777     if (it != m_classifier_gt_all_ids.end())
1778     {
1779         //image found, so return corresponding ground truth
1780         return m_classifier_gt_all_present[std::distance(m_classifier_gt_all_ids.begin(),it)] != 0;
1781     } else {
1782         string err_msg = "could not find classifier ground truth for image '" + id + "' and class '" + obj_class + "'";
1783         CV_Error(CV_StsError,err_msg.c_str());
1784     }
1785
1786     return true;
1787 }
1788
1789 //-------------------------------------------------------------------
1790 // Protected Functions (utility) ------------------------------------
1791 //-------------------------------------------------------------------
1792
1793 //returns a vector containing indexes of the input vector in sorted ascending/descending order
1794 void VocData::getSortOrder(const vector<float>& values, vector<size_t>& order, bool descending)
1795 {
1796     /* 1. store sorting order in 'order_pair' */
1797     vector<std::pair<size_t, vector<float>::const_iterator> > order_pair(values.size());
1798
1799     size_t n = 0;
1800     for (vector<float>::const_iterator it = values.begin(); it != values.end(); ++it, ++n)
1801         order_pair[n] = make_pair(n, it);
1802
1803     std::sort(order_pair.begin(),order_pair.end(),orderingSorter());
1804     if (descending == false) std::reverse(order_pair.begin(),order_pair.end());
1805
1806     vector<size_t>(order_pair.size()).swap(order);
1807     for (size_t i = 0; i < order_pair.size(); ++i)
1808     {
1809         order[i] = order_pair[i].first;
1810     }
1811 }
1812
1813 void VocData::readFileToString(const string filename, string& file_contents)
1814 {
1815     std::ifstream ifs(filename.c_str());
1816     if (!ifs.is_open()) CV_Error(CV_StsError,"could not open text file");
1817
1818     stringstream oss;
1819     oss << ifs.rdbuf();
1820
1821     file_contents = oss.str();
1822 }
1823
1824 int VocData::stringToInteger(const string input_str)
1825 {
1826     int result = 0;
1827
1828     stringstream ss(input_str);
1829     if ((ss >> result).fail())
1830     {
1831         CV_Error(CV_StsBadArg,"could not perform string to integer conversion");
1832     }
1833     return result;
1834 }
1835
1836 string VocData::integerToString(const int input_int)
1837 {
1838     string result;
1839
1840     stringstream ss;
1841     if ((ss << input_int).fail())
1842     {
1843         CV_Error(CV_StsBadArg,"could not perform integer to string conversion");
1844     }
1845     result = ss.str();
1846     return result;
1847 }
1848
1849 string VocData::checkFilenamePathsep( const string filename, bool add_trailing_slash )
1850 {
1851     string filename_new = filename;
1852
1853     size_t pos = filename_new.find("\\\\");
1854     while (pos != filename_new.npos)
1855     {
1856         filename_new.replace(pos,2,"/");
1857         pos = filename_new.find("\\\\", pos);
1858     }
1859     pos = filename_new.find("\\");
1860     while (pos != filename_new.npos)
1861     {
1862         filename_new.replace(pos,1,"/");
1863         pos = filename_new.find("\\", pos);
1864     }
1865     if (add_trailing_slash)
1866     {
1867         //add training slash if this is missing
1868         if (filename_new.rfind("/") != filename_new.length()-1) filename_new += "/";
1869     }
1870
1871     return filename_new;
1872 }
1873
1874 void VocData::convertImageCodesToObdImages(const vector<string>& image_codes, vector<ObdImage>& images)
1875 {
1876     images.clear();
1877     images.reserve(image_codes.size());
1878
1879     string path;
1880     //transfer to output arrays
1881     for (size_t i = 0; i < image_codes.size(); ++i)
1882     {
1883         //generate image path and indices from extracted string code
1884         path = getImagePath(image_codes[i]);
1885         images.push_back(ObdImage(image_codes[i], path));
1886     }
1887 }
1888
1889 //Extract text from within a given tag from an XML file
1890 //-----------------------------------------------------
1891 //INPUTS:
1892 // - src            XML source file
1893 // - tag            XML tag delimiting block to extract
1894 // - searchpos      position within src at which to start search
1895 //OUTPUTS:
1896 // - tag_contents   text extracted between <tag> and </tag> tags
1897 //RETURN VALUE:
1898 // - the position of the final character extracted in tag_contents within src
1899 //      (can be used to call extractXMLBlock recursively to extract multiple blocks)
1900 //      returns -1 if the tag could not be found
1901 int VocData::extractXMLBlock(const string src, const string tag, const int searchpos, string& tag_contents)
1902 {
1903     size_t startpos, next_startpos, endpos;
1904     int embed_count = 1;
1905
1906     //find position of opening tag
1907     startpos = src.find("<" + tag + ">", searchpos);
1908     if (startpos == string::npos) return -1;
1909
1910     //initialize endpos -
1911     // start searching for end tag anywhere after opening tag
1912     endpos = startpos;
1913
1914     //find position of next opening tag
1915     next_startpos = src.find("<" + tag + ">", startpos+1);
1916
1917     //match opening tags with closing tags, and only
1918     //accept final closing tag of same level as original
1919     //opening tag
1920     while (embed_count > 0)
1921     {
1922         endpos = src.find("</" + tag + ">", endpos+1);
1923         if (endpos == string::npos) return -1;
1924
1925         //the next code is only executed if there are embedded tags with the same name
1926         if (next_startpos != string::npos)
1927         {
1928             while (next_startpos<endpos)
1929             {
1930                 //counting embedded start tags
1931                 ++embed_count;
1932                 next_startpos = src.find("<" + tag + ">", next_startpos+1);
1933                 if (next_startpos == string::npos) break;
1934             }
1935         }
1936         //passing end tag so decrement nesting level
1937         --embed_count;
1938     }
1939
1940     //finally, extract the tag region
1941     startpos += tag.length() + 2;
1942     if (startpos > src.length()) return -1;
1943     if (endpos > src.length()) return -1;
1944     tag_contents = src.substr(startpos,endpos-startpos);
1945     return static_cast<int>(endpos);
1946 }
1947
1948 /****************************************************************************************\
1949 *                            Sample on image classification                             *
1950 \****************************************************************************************/
1951 //
1952 // This part of the code was a little refactor
1953 //
1954 struct DDMParams
1955 {
1956     DDMParams() : detectorType("SURF"), descriptorType("SURF"), matcherType("BruteForce") {}
1957     DDMParams( const string _detectorType, const string _descriptorType, const string& _matcherType ) :
1958         detectorType(_detectorType), descriptorType(_descriptorType), matcherType(_matcherType){}
1959     void read( const FileNode& fn )
1960     {
1961         fn["detectorType"] >> detectorType;
1962         fn["descriptorType"] >> descriptorType;
1963         fn["matcherType"] >> matcherType;
1964     }
1965     void write( FileStorage& fs ) const
1966     {
1967         fs << "detectorType" << detectorType;
1968         fs << "descriptorType" << descriptorType;
1969         fs << "matcherType" << matcherType;
1970     }
1971     void print() const
1972     {
1973         cout << "detectorType: " << detectorType << endl;
1974         cout << "descriptorType: " << descriptorType << endl;
1975         cout << "matcherType: " << matcherType << endl;
1976     }
1977
1978     string detectorType;
1979     string descriptorType;
1980     string matcherType;
1981 };
1982
1983 struct VocabTrainParams
1984 {
1985     VocabTrainParams() : trainObjClass("chair"), vocabSize(1000), memoryUse(200), descProportion(0.3f) {}
1986     VocabTrainParams( const string _trainObjClass, size_t _vocabSize, size_t _memoryUse, float _descProportion ) :
1987             trainObjClass(_trainObjClass), vocabSize((int)_vocabSize), memoryUse((int)_memoryUse), descProportion(_descProportion) {}
1988     void read( const FileNode& fn )
1989     {
1990         fn["trainObjClass"] >> trainObjClass;
1991         fn["vocabSize"] >> vocabSize;
1992         fn["memoryUse"] >> memoryUse;
1993         fn["descProportion"] >> descProportion;
1994     }
1995     void write( FileStorage& fs ) const
1996     {
1997         fs << "trainObjClass" << trainObjClass;
1998         fs << "vocabSize" << vocabSize;
1999         fs << "memoryUse" << memoryUse;
2000         fs << "descProportion" << descProportion;
2001     }
2002     void print() const
2003     {
2004         cout << "trainObjClass: " << trainObjClass << endl;
2005         cout << "vocabSize: " << vocabSize << endl;
2006         cout << "memoryUse: " << memoryUse << endl;
2007         cout << "descProportion: " << descProportion << endl;
2008     }
2009
2010
2011     string trainObjClass; // Object class used for training visual vocabulary.
2012                           // It shouldn't matter which object class is specified here - visual vocab will still be the same.
2013     int vocabSize; //number of visual words in vocabulary to train
2014     int memoryUse; // Memory to preallocate (in MB) when training vocab.
2015                    // Change this depending on the size of the dataset/available memory.
2016     float descProportion; // Specifies the number of descriptors to use from each image as a proportion of the total num descs.
2017 };
2018
2019 struct SVMTrainParamsExt
2020 {
2021     SVMTrainParamsExt() : descPercent(0.5f), targetRatio(0.4f), balanceClasses(true) {}
2022     SVMTrainParamsExt( float _descPercent, float _targetRatio, bool _balanceClasses ) :
2023             descPercent(_descPercent), targetRatio(_targetRatio), balanceClasses(_balanceClasses) {}
2024     void read( const FileNode& fn )
2025     {
2026         fn["descPercent"] >> descPercent;
2027         fn["targetRatio"] >> targetRatio;
2028         fn["balanceClasses"] >> balanceClasses;
2029     }
2030     void write( FileStorage& fs ) const
2031     {
2032         fs << "descPercent" << descPercent;
2033         fs << "targetRatio" << targetRatio;
2034         fs << "balanceClasses" << balanceClasses;
2035     }
2036     void print() const
2037     {
2038         cout << "descPercent: " << descPercent << endl;
2039         cout << "targetRatio: " << targetRatio << endl;
2040         cout << "balanceClasses: " << balanceClasses << endl;
2041     }
2042
2043     float descPercent; // Percentage of extracted descriptors to use for training.
2044     float targetRatio; // Try to get this ratio of positive to negative samples (minimum).
2045     bool balanceClasses;    // Balance class weights by number of samples in each (if true cSvmTrainTargetRatio is ignored).
2046 };
2047
2048 static void readUsedParams( const FileNode& fn, string& vocName, DDMParams& ddmParams, VocabTrainParams& vocabTrainParams, SVMTrainParamsExt& svmTrainParamsExt )
2049 {
2050     fn["vocName"] >> vocName;
2051
2052     FileNode currFn = fn;
2053
2054     currFn = fn["ddmParams"];
2055     ddmParams.read( currFn );
2056
2057     currFn = fn["vocabTrainParams"];
2058     vocabTrainParams.read( currFn );
2059
2060     currFn = fn["svmTrainParamsExt"];
2061     svmTrainParamsExt.read( currFn );
2062 }
2063
2064 static void writeUsedParams( FileStorage& fs, const string& vocName, const DDMParams& ddmParams, const VocabTrainParams& vocabTrainParams, const SVMTrainParamsExt& svmTrainParamsExt )
2065 {
2066     fs << "vocName" << vocName;
2067
2068     fs << "ddmParams" << "{";
2069     ddmParams.write(fs);
2070     fs << "}";
2071
2072     fs << "vocabTrainParams" << "{";
2073     vocabTrainParams.write(fs);
2074     fs << "}";
2075
2076     fs << "svmTrainParamsExt" << "{";
2077     svmTrainParamsExt.write(fs);
2078     fs << "}";
2079 }
2080
2081 static void printUsedParams( const string& vocPath, const string& resDir,
2082                       const DDMParams& ddmParams, const VocabTrainParams& vocabTrainParams,
2083                       const SVMTrainParamsExt& svmTrainParamsExt )
2084 {
2085     cout << "CURRENT CONFIGURATION" << endl;
2086     cout << "----------------------------------------------------------------" << endl;
2087     cout << "vocPath: " << vocPath << endl;
2088     cout << "resDir: " << resDir << endl;
2089     cout << endl; ddmParams.print();
2090     cout << endl; vocabTrainParams.print();
2091     cout << endl; svmTrainParamsExt.print();
2092     cout << "----------------------------------------------------------------" << endl << endl;
2093 }
2094
2095 static bool readVocabulary( const string& filename, Mat& vocabulary )
2096 {
2097     cout << "Reading vocabulary...";
2098     FileStorage fs( filename, FileStorage::READ );
2099     if( fs.isOpened() )
2100     {
2101         fs["vocabulary"] >> vocabulary;
2102         cout << "done" << endl;
2103         return true;
2104     }
2105     return false;
2106 }
2107
2108 static bool writeVocabulary( const string& filename, const Mat& vocabulary )
2109 {
2110     cout << "Saving vocabulary..." << endl;
2111     FileStorage fs( filename, FileStorage::WRITE );
2112     if( fs.isOpened() )
2113     {
2114         fs << "vocabulary" << vocabulary;
2115         return true;
2116     }
2117     return false;
2118 }
2119
2120 static Mat trainVocabulary( const string& filename, VocData& vocData, const VocabTrainParams& trainParams,
2121                      const Ptr<FeatureDetector>& fdetector, const Ptr<DescriptorExtractor>& dextractor )
2122 {
2123     Mat vocabulary;
2124     if( !readVocabulary( filename, vocabulary) )
2125     {
2126         CV_Assert( dextractor->descriptorType() == CV_32FC1 );
2127         const int elemSize = CV_ELEM_SIZE(dextractor->descriptorType());
2128         const int descByteSize = dextractor->descriptorSize() * elemSize;
2129         const int bytesInMB = 1048576;
2130         const int maxDescCount = (trainParams.memoryUse * bytesInMB) / descByteSize; // Total number of descs to use for training.
2131
2132         cout << "Extracting VOC data..." << endl;
2133         vector<ObdImage> images;
2134         vector<char> objectPresent;
2135         vocData.getClassImages( trainParams.trainObjClass, CV_OBD_TRAIN, images, objectPresent );
2136
2137         cout << "Computing descriptors..." << endl;
2138         RNG& rng = theRNG();
2139         TermCriteria terminate_criterion;
2140         terminate_criterion.epsilon = FLT_EPSILON;
2141         BOWKMeansTrainer bowTrainer( trainParams.vocabSize, terminate_criterion, 3, KMEANS_PP_CENTERS );
2142
2143         while( images.size() > 0 )
2144         {
2145             if( bowTrainer.descriptorsCount() > maxDescCount )
2146             {
2147 #ifdef DEBUG_DESC_PROGRESS
2148                 cout << "Breaking due to full memory ( descriptors count = " << bowTrainer.descriptorsCount()
2149                         << "; descriptor size in bytes = " << descByteSize << "; all used memory = "
2150                         << bowTrainer.descriptorsCount()*descByteSize << endl;
2151 #endif
2152                 break;
2153             }
2154
2155             // Randomly pick an image from the dataset which hasn't yet been seen
2156             // and compute the descriptors from that image.
2157             int randImgIdx = rng( (unsigned)images.size() );
2158             Mat colorImage = imread( images[randImgIdx].path );
2159             vector<KeyPoint> imageKeypoints;
2160             fdetector->detect( colorImage, imageKeypoints );
2161             Mat imageDescriptors;
2162             dextractor->compute( colorImage, imageKeypoints, imageDescriptors );
2163
2164             //check that there were descriptors calculated for the current image
2165             if( !imageDescriptors.empty() )
2166             {
2167                 int descCount = imageDescriptors.rows;
2168                 // Extract trainParams.descProportion descriptors from the image, breaking if the 'allDescriptors' matrix becomes full
2169                 int descsToExtract = static_cast<int>(trainParams.descProportion * static_cast<float>(descCount));
2170                 // Fill mask of used descriptors
2171                 vector<char> usedMask( descCount, false );
2172                 fill( usedMask.begin(), usedMask.begin() + descsToExtract, true );
2173                 for( int i = 0; i < descCount; i++ )
2174                 {
2175                     int i1 = rng(descCount), i2 = rng(descCount);
2176                     char tmp = usedMask[i1]; usedMask[i1] = usedMask[i2]; usedMask[i2] = tmp;
2177                 }
2178
2179                 for( int i = 0; i < descCount; i++ )
2180                 {
2181                     if( usedMask[i] && bowTrainer.descriptorsCount() < maxDescCount )
2182                         bowTrainer.add( imageDescriptors.row(i) );
2183                 }
2184             }
2185
2186 #ifdef DEBUG_DESC_PROGRESS
2187             cout << images.size() << " images left, " << images[randImgIdx].id << " processed - "
2188                     <</* descs_extracted << "/" << image_descriptors.rows << " extracted - " << */
2189                     cvRound((static_cast<double>(bowTrainer.descriptorsCount())/static_cast<double>(maxDescCount))*100.0)
2190                     << " % memory used" << ( imageDescriptors.empty() ? " -> no descriptors extracted, skipping" : "") << endl;
2191 #endif
2192
2193             // Delete the current element from images so it is not added again
2194             images.erase( images.begin() + randImgIdx );
2195         }
2196
2197         cout << "Maximum allowed descriptor count: " << maxDescCount << ", Actual descriptor count: " << bowTrainer.descriptorsCount() << endl;
2198
2199         cout << "Training vocabulary..." << endl;
2200         vocabulary = bowTrainer.cluster();
2201
2202         if( !writeVocabulary(filename, vocabulary) )
2203         {
2204             cout << "Error: file " << filename << " can not be opened to write" << endl;
2205             exit(-1);
2206         }
2207     }
2208     return vocabulary;
2209 }
2210
2211 static bool readBowImageDescriptor( const string& file, Mat& bowImageDescriptor )
2212 {
2213     FileStorage fs( file, FileStorage::READ );
2214     if( fs.isOpened() )
2215     {
2216         fs["imageDescriptor"] >> bowImageDescriptor;
2217         return true;
2218     }
2219     return false;
2220 }
2221
2222 static bool writeBowImageDescriptor( const string& file, const Mat& bowImageDescriptor )
2223 {
2224     FileStorage fs( file, FileStorage::WRITE );
2225     if( fs.isOpened() )
2226     {
2227         fs << "imageDescriptor" << bowImageDescriptor;
2228         return true;
2229     }
2230     return false;
2231 }
2232
2233 // Load in the bag of words vectors for a set of images, from file if possible
2234 static void calculateImageDescriptors( const vector<ObdImage>& images, vector<Mat>& imageDescriptors,
2235                                 Ptr<BOWImgDescriptorExtractor>& bowExtractor, const Ptr<FeatureDetector>& fdetector,
2236                                 const string& resPath )
2237 {
2238     CV_Assert( !bowExtractor->getVocabulary().empty() );
2239     imageDescriptors.resize( images.size() );
2240
2241     for( size_t i = 0; i < images.size(); i++ )
2242     {
2243         string filename = resPath + bowImageDescriptorsDir + "/" + images[i].id + ".xml.gz";
2244         if( readBowImageDescriptor( filename, imageDescriptors[i] ) )
2245         {
2246 #ifdef DEBUG_DESC_PROGRESS
2247             cout << "Loaded bag of word vector for image " << i+1 << " of " << images.size() << " (" << images[i].id << ")" << endl;
2248 #endif
2249         }
2250         else
2251         {
2252             Mat colorImage = imread( images[i].path );
2253 #ifdef DEBUG_DESC_PROGRESS
2254             cout << "Computing descriptors for image " << i+1 << " of " << images.size() << " (" << images[i].id << ")" << flush;
2255 #endif
2256             vector<KeyPoint> keypoints;
2257             fdetector->detect( colorImage, keypoints );
2258 #ifdef DEBUG_DESC_PROGRESS
2259                 cout << " + generating BoW vector" << std::flush;
2260 #endif
2261             bowExtractor->compute( colorImage, keypoints, imageDescriptors[i] );
2262 #ifdef DEBUG_DESC_PROGRESS
2263             cout << " ...DONE " << static_cast<int>(static_cast<float>(i+1)/static_cast<float>(images.size())*100.0)
2264                  << " % complete" << endl;
2265 #endif
2266             if( !imageDescriptors[i].empty() )
2267             {
2268                 if( !writeBowImageDescriptor( filename, imageDescriptors[i] ) )
2269                 {
2270                     cout << "Error: file " << filename << "can not be opened to write bow image descriptor" << endl;
2271                     exit(-1);
2272                 }
2273             }
2274         }
2275     }
2276 }
2277
2278 static void removeEmptyBowImageDescriptors( vector<ObdImage>& images, vector<Mat>& bowImageDescriptors,
2279                                      vector<char>& objectPresent )
2280 {
2281     CV_Assert( !images.empty() );
2282     for( int i = (int)images.size() - 1; i >= 0; i-- )
2283     {
2284         bool res = bowImageDescriptors[i].empty();
2285         if( res )
2286         {
2287             cout << "Removing image " << images[i].id << " due to no descriptors..." << endl;
2288             images.erase( images.begin() + i );
2289             bowImageDescriptors.erase( bowImageDescriptors.begin() + i );
2290             objectPresent.erase( objectPresent.begin() + i );
2291         }
2292     }
2293 }
2294
2295 static void removeBowImageDescriptorsByCount( vector<ObdImage>& images, vector<Mat> bowImageDescriptors, vector<char> objectPresent,
2296                                        const SVMTrainParamsExt& svmParamsExt, int descsToDelete )
2297 {
2298     RNG& rng = theRNG();
2299     int pos_ex = (int)std::count( objectPresent.begin(), objectPresent.end(), (char)1 );
2300     int neg_ex = (int)std::count( objectPresent.begin(), objectPresent.end(), (char)0 );
2301
2302     while( descsToDelete != 0 )
2303     {
2304         int randIdx = rng((unsigned)images.size());
2305
2306         // Prefer positive training examples according to svmParamsExt.targetRatio if required
2307         if( objectPresent[randIdx] )
2308         {
2309             if( (static_cast<float>(pos_ex)/static_cast<float>(neg_ex+pos_ex)  < svmParamsExt.targetRatio) &&
2310                 (neg_ex > 0) && (svmParamsExt.balanceClasses == false) )
2311             { continue; }
2312             else
2313             { pos_ex--; }
2314         }
2315         else
2316         { neg_ex--; }
2317
2318         images.erase( images.begin() + randIdx );
2319         bowImageDescriptors.erase( bowImageDescriptors.begin() + randIdx );
2320         objectPresent.erase( objectPresent.begin() + randIdx );
2321
2322         descsToDelete--;
2323     }
2324     CV_Assert( bowImageDescriptors.size() == objectPresent.size() );
2325 }
2326
2327 static void setSVMParams( CvSVMParams& svmParams, CvMat& class_wts_cv, const Mat& responses, bool balanceClasses )
2328 {
2329     int pos_ex = countNonZero(responses == 1);
2330     int neg_ex = countNonZero(responses == -1);
2331     cout << pos_ex << " positive training samples; " << neg_ex << " negative training samples" << endl;
2332
2333     svmParams.svm_type = CvSVM::C_SVC;
2334     svmParams.kernel_type = CvSVM::RBF;
2335     if( balanceClasses )
2336     {
2337         Mat class_wts( 2, 1, CV_32FC1 );
2338         // The first training sample determines the '+1' class internally, even if it is negative,
2339         // so store whether this is the case so that the class weights can be reversed accordingly.
2340         bool reversed_classes = (responses.at<float>(0) < 0.f);
2341         if( reversed_classes == false )
2342         {
2343             class_wts.at<float>(0) = static_cast<float>(pos_ex)/static_cast<float>(pos_ex+neg_ex); // weighting for costs of positive class + 1 (i.e. cost of false positive - larger gives greater cost)
2344             class_wts.at<float>(1) = static_cast<float>(neg_ex)/static_cast<float>(pos_ex+neg_ex); // weighting for costs of negative class - 1 (i.e. cost of false negative)
2345         }
2346         else
2347         {
2348             class_wts.at<float>(0) = static_cast<float>(neg_ex)/static_cast<float>(pos_ex+neg_ex);
2349             class_wts.at<float>(1) = static_cast<float>(pos_ex)/static_cast<float>(pos_ex+neg_ex);
2350         }
2351         class_wts_cv = class_wts;
2352         svmParams.class_weights = &class_wts_cv;
2353     }
2354 }
2355
2356 static void setSVMTrainAutoParams( CvParamGrid& c_grid, CvParamGrid& gamma_grid,
2357                             CvParamGrid& p_grid, CvParamGrid& nu_grid,
2358                             CvParamGrid& coef_grid, CvParamGrid& degree_grid )
2359 {
2360     c_grid = CvSVM::get_default_grid(CvSVM::C);
2361
2362     gamma_grid = CvSVM::get_default_grid(CvSVM::GAMMA);
2363
2364     p_grid = CvSVM::get_default_grid(CvSVM::P);
2365     p_grid.step = 0;
2366
2367     nu_grid = CvSVM::get_default_grid(CvSVM::NU);
2368     nu_grid.step = 0;
2369
2370     coef_grid = CvSVM::get_default_grid(CvSVM::COEF);
2371     coef_grid.step = 0;
2372
2373     degree_grid = CvSVM::get_default_grid(CvSVM::DEGREE);
2374     degree_grid.step = 0;
2375 }
2376
2377 static void trainSVMClassifier( CvSVM& svm, const SVMTrainParamsExt& svmParamsExt, const string& objClassName, VocData& vocData,
2378                          Ptr<BOWImgDescriptorExtractor>& bowExtractor, const Ptr<FeatureDetector>& fdetector,
2379                          const string& resPath )
2380 {
2381     /* first check if a previously trained svm for the current class has been saved to file */
2382     string svmFilename = resPath + svmsDir + "/" + objClassName + ".xml.gz";
2383
2384     FileStorage fs( svmFilename, FileStorage::READ);
2385     if( fs.isOpened() )
2386     {
2387         cout << "*** LOADING SVM CLASSIFIER FOR CLASS " << objClassName << " ***" << endl;
2388         svm.load( svmFilename.c_str() );
2389     }
2390     else
2391     {
2392         cout << "*** TRAINING CLASSIFIER FOR CLASS " << objClassName << " ***" << endl;
2393         cout << "CALCULATING BOW VECTORS FOR TRAINING SET OF " << objClassName << "..." << endl;
2394
2395         // Get classification ground truth for images in the training set
2396         vector<ObdImage> images;
2397         vector<Mat> bowImageDescriptors;
2398         vector<char> objectPresent;
2399         vocData.getClassImages( objClassName, CV_OBD_TRAIN, images, objectPresent );
2400
2401         // Compute the bag of words vector for each image in the training set.
2402         calculateImageDescriptors( images, bowImageDescriptors, bowExtractor, fdetector, resPath );
2403
2404         // Remove any images for which descriptors could not be calculated
2405         removeEmptyBowImageDescriptors( images, bowImageDescriptors, objectPresent );
2406
2407         CV_Assert( svmParamsExt.descPercent > 0.f && svmParamsExt.descPercent <= 1.f );
2408         if( svmParamsExt.descPercent < 1.f )
2409         {
2410             int descsToDelete = static_cast<int>(static_cast<float>(images.size())*(1.0-svmParamsExt.descPercent));
2411
2412             cout << "Using " << (images.size() - descsToDelete) << " of " << images.size() <<
2413                     " descriptors for training (" << svmParamsExt.descPercent*100.0 << " %)" << endl;
2414             removeBowImageDescriptorsByCount( images, bowImageDescriptors, objectPresent, svmParamsExt, descsToDelete );
2415         }
2416
2417         // Prepare the input matrices for SVM training.
2418         Mat trainData( (int)images.size(), bowExtractor->getVocabulary().rows, CV_32FC1 );
2419         Mat responses( (int)images.size(), 1, CV_32SC1 );
2420
2421         // Transfer bag of words vectors and responses across to the training data matrices
2422         for( size_t imageIdx = 0; imageIdx < images.size(); imageIdx++ )
2423         {
2424             // Transfer image descriptor (bag of words vector) to training data matrix
2425             Mat submat = trainData.row((int)imageIdx);
2426             if( bowImageDescriptors[imageIdx].cols != bowExtractor->descriptorSize() )
2427             {
2428                 cout << "Error: computed bow image descriptor size " << bowImageDescriptors[imageIdx].cols
2429                      << " differs from vocabulary size" << bowExtractor->getVocabulary().cols << endl;
2430                 exit(-1);
2431             }
2432             bowImageDescriptors[imageIdx].copyTo( submat );
2433
2434             // Set response value
2435             responses.at<int>((int)imageIdx) = objectPresent[imageIdx] ? 1 : -1;
2436         }
2437
2438         cout << "TRAINING SVM FOR CLASS ..." << objClassName << "..." << endl;
2439         CvSVMParams svmParams;
2440         CvMat class_wts_cv;
2441         setSVMParams( svmParams, class_wts_cv, responses, svmParamsExt.balanceClasses );
2442         CvParamGrid c_grid, gamma_grid, p_grid, nu_grid, coef_grid, degree_grid;
2443         setSVMTrainAutoParams( c_grid, gamma_grid,  p_grid, nu_grid, coef_grid, degree_grid );
2444         svm.train_auto( trainData, responses, Mat(), Mat(), svmParams, 10, c_grid, gamma_grid, p_grid, nu_grid, coef_grid, degree_grid );
2445         cout << "SVM TRAINING FOR CLASS " << objClassName << " COMPLETED" << endl;
2446
2447         svm.save( svmFilename.c_str() );
2448         cout << "SAVED CLASSIFIER TO FILE" << endl;
2449     }
2450 }
2451
2452 static void computeConfidences( CvSVM& svm, const string& objClassName, VocData& vocData,
2453                          Ptr<BOWImgDescriptorExtractor>& bowExtractor, const Ptr<FeatureDetector>& fdetector,
2454                          const string& resPath )
2455 {
2456     cout << "*** CALCULATING CONFIDENCES FOR CLASS " << objClassName << " ***" << endl;
2457     cout << "CALCULATING BOW VECTORS FOR TEST SET OF " << objClassName << "..." << endl;
2458     // Get classification ground truth for images in the test set
2459     vector<ObdImage> images;
2460     vector<Mat> bowImageDescriptors;
2461     vector<char> objectPresent;
2462     vocData.getClassImages( objClassName, CV_OBD_TEST, images, objectPresent );
2463
2464     // Compute the bag of words vector for each image in the test set
2465     calculateImageDescriptors( images, bowImageDescriptors, bowExtractor, fdetector, resPath );
2466     // Remove any images for which descriptors could not be calculated
2467     removeEmptyBowImageDescriptors( images, bowImageDescriptors, objectPresent);
2468
2469     // Use the bag of words vectors to calculate classifier output for each image in test set
2470     cout << "CALCULATING CONFIDENCE SCORES FOR CLASS " << objClassName << "..." << endl;
2471     vector<float> confidences( images.size() );
2472     float signMul = 1.f;
2473     for( size_t imageIdx = 0; imageIdx < images.size(); imageIdx++ )
2474     {
2475         if( imageIdx == 0 )
2476         {
2477             // In the first iteration, determine the sign of the positive class
2478             float classVal = confidences[imageIdx] = svm.predict( bowImageDescriptors[imageIdx], false );
2479             float scoreVal = confidences[imageIdx] = svm.predict( bowImageDescriptors[imageIdx], true );
2480             signMul = (classVal < 0) == (scoreVal < 0) ? 1.f : -1.f;
2481         }
2482         // svm output of decision function
2483         confidences[imageIdx] = signMul * svm.predict( bowImageDescriptors[imageIdx], true );
2484     }
2485
2486     cout << "WRITING QUERY RESULTS TO VOC RESULTS FILE FOR CLASS " << objClassName << "..." << endl;
2487     vocData.writeClassifierResultsFile( resPath + plotsDir, objClassName, CV_OBD_TEST, images, confidences, 1, true );
2488
2489     cout << "DONE - " << objClassName << endl;
2490     cout << "---------------------------------------------------------------" << endl;
2491 }
2492
2493 static void computeGnuPlotOutput( const string& resPath, const string& objClassName, VocData& vocData )
2494 {
2495     vector<float> precision, recall;
2496     float ap;
2497
2498     const string resultFile = vocData.getResultsFilename( objClassName, CV_VOC_TASK_CLASSIFICATION, CV_OBD_TEST);
2499     const string plotFile = resultFile.substr(0, resultFile.size()-4) + ".plt";
2500
2501     cout << "Calculating precision recall curve for class '" <<objClassName << "'" << endl;
2502     vocData.calcClassifierPrecRecall( resPath + plotsDir + "/" + resultFile, precision, recall, ap, true );
2503     cout << "Outputting to GNUPlot file..." << endl;
2504     vocData.savePrecRecallToGnuplot( resPath + plotsDir + "/" + plotFile, precision, recall, ap, objClassName, CV_VOC_PLOT_PNG );
2505 }
2506
2507
2508
2509
2510 int main(int argc, char** argv)
2511 {
2512     if( argc != 3 && argc != 6 )
2513     {
2514         help(argv);
2515         return -1;
2516     }
2517
2518     cv::initModule_nonfree();
2519
2520     const string vocPath = argv[1], resPath = argv[2];
2521
2522     // Read or set default parameters
2523     string vocName;
2524     DDMParams ddmParams;
2525     VocabTrainParams vocabTrainParams;
2526     SVMTrainParamsExt svmTrainParamsExt;
2527
2528     makeUsedDirs( resPath );
2529
2530     FileStorage paramsFS( resPath + "/" + paramsFile, FileStorage::READ );
2531     if( paramsFS.isOpened() )
2532     {
2533        readUsedParams( paramsFS.root(), vocName, ddmParams, vocabTrainParams, svmTrainParamsExt );
2534        CV_Assert( vocName == getVocName(vocPath) );
2535     }
2536     else
2537     {
2538         vocName = getVocName(vocPath);
2539         if( argc!= 6 )
2540         {
2541             cout << "Feature detector, descriptor extractor, descriptor matcher must be set" << endl;
2542             return -1;
2543         }
2544         ddmParams = DDMParams( argv[3], argv[4], argv[5] ); // from command line
2545         // vocabTrainParams and svmTrainParamsExt is set by defaults
2546         paramsFS.open( resPath + "/" + paramsFile, FileStorage::WRITE );
2547         if( paramsFS.isOpened() )
2548         {
2549             writeUsedParams( paramsFS, vocName, ddmParams, vocabTrainParams, svmTrainParamsExt );
2550             paramsFS.release();
2551         }
2552         else
2553         {
2554             cout << "File " << (resPath + "/" + paramsFile) << "can not be opened to write" << endl;
2555             return -1;
2556         }
2557     }
2558
2559     // Create detector, descriptor, matcher.
2560     Ptr<FeatureDetector> featureDetector = FeatureDetector::create( ddmParams.detectorType );
2561     Ptr<DescriptorExtractor> descExtractor = DescriptorExtractor::create( ddmParams.descriptorType );
2562     Ptr<BOWImgDescriptorExtractor> bowExtractor;
2563     if( !featureDetector || !descExtractor )
2564     {
2565         cout << "featureDetector or descExtractor was not created" << endl;
2566         return -1;
2567     }
2568     {
2569         Ptr<DescriptorMatcher> descMatcher = DescriptorMatcher::create( ddmParams.matcherType );
2570         if( !featureDetector || !descExtractor || !descMatcher )
2571         {
2572             cout << "descMatcher was not created" << endl;
2573             return -1;
2574         }
2575         bowExtractor = makePtr<BOWImgDescriptorExtractor>( descExtractor, descMatcher );
2576     }
2577
2578     // Print configuration to screen
2579     printUsedParams( vocPath, resPath, ddmParams, vocabTrainParams, svmTrainParamsExt );
2580     // Create object to work with VOC
2581     VocData vocData( vocPath, false );
2582
2583     // 1. Train visual word vocabulary if a pre-calculated vocabulary file doesn't already exist from previous run
2584     Mat vocabulary = trainVocabulary( resPath + "/" + vocabularyFile, vocData, vocabTrainParams,
2585                                       featureDetector, descExtractor );
2586     bowExtractor->setVocabulary( vocabulary );
2587
2588     // 2. Train a classifier and run a sample query for each object class
2589     const vector<string>& objClasses = vocData.getObjectClasses(); // object class list
2590     for( size_t classIdx = 0; classIdx < objClasses.size(); ++classIdx )
2591     {
2592         // Train a classifier on train dataset
2593         CvSVM svm;
2594         trainSVMClassifier( svm, svmTrainParamsExt, objClasses[classIdx], vocData,
2595                             bowExtractor, featureDetector, resPath );
2596
2597         // Now use the classifier over all images on the test dataset and rank according to score order
2598         // also calculating precision-recall etc.
2599         computeConfidences( svm, objClasses[classIdx], vocData,
2600                             bowExtractor, featureDetector, resPath );
2601         // Calculate precision/recall/ap and use GNUPlot to output to a pdf file
2602         computeGnuPlotOutput( resPath, objClasses[classIdx], vocData );
2603     }
2604     return 0;
2605 }