Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / samples / validation_app / YOLOObjectDetectionProcessor.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <iostream>
8 #include <limits>
9 #include <map>
10 #include <memory>
11 #include <string>
12 #include <list>
13 #include <vector>
14 #include <algorithm>
15
16 using namespace std;
17
18 class YOLOObjectDetectionProcessor : public ObjectDetectionProcessor {
19 private:
20     /**
21      * \brief This function analyses the YOLO net output for a single class
22      * @param net_out - The output data
23      * @param class_num - The class number
24      * @return a list of found boxes
25      */
26     std::vector<DetectedObject> yoloNetParseOutput(const float *net_out, int class_num) {
27         float threshold = 0.2f;         // The confidence threshold
28         int C = 20;                     // classes
29         int B = 2;                      // bounding boxes
30         int S = 7;                      // cell size
31
32         std::vector<DetectedObject> boxes;
33         std::vector<DetectedObject> boxes_result;
34         int SS = S * S;                 // number of grid cells 7*7 = 49
35         // First 980 values corresponds to probabilities for each of the 20 classes for each grid cell.
36         // These probabilities are conditioned on objects being present in each grid cell.
37         int prob_size = SS * C;         // class probabilities 49 * 20 = 980
38         // The next 98 values are confidence scores for 2 bounding boxes predicted by each grid cells.
39         int conf_size = SS * B;         // 49*2 = 98 confidences for each grid cell
40
41         const float *probs = &net_out[0];
42         const float *confs = &net_out[prob_size];
43         const float *cords = &net_out[prob_size + conf_size];     // 98*4 = 392 coords x, y, w, h
44
45         for (int grid = 0; grid < SS; grid++) {
46             int row = grid / S;
47             int col = grid % S;
48             for (int b = 0; b < B; b++) {
49                 int objectType = class_num;
50
51                 float conf = confs[(grid * B + b)];
52                 float xc = (cords[(grid * B + b) * 4 + 0] + col) / S;
53                 float yc = (cords[(grid * B + b) * 4 + 1] + row) / S;
54                 float w = pow(cords[(grid * B + b) * 4 + 2], 2);
55                 float h = pow(cords[(grid * B + b) * 4 + 3], 2);
56                 float prob = probs[grid * C + class_num] * conf;
57
58                 DetectedObject bx(objectType, xc - w / 2, yc - h / 2, xc + w / 2,
59                         yc + h / 2, prob);
60
61                 if (prob >= threshold) {
62                     boxes.push_back(bx);
63                 }
64             }
65         }
66
67         // Sorting the higher probabilities to the top
68         sort(boxes.begin(), boxes.end(),
69                 [](const DetectedObject & a, const DetectedObject & b) -> bool {
70                     return a.prob > b.prob;
71                 });
72
73         // Filtering out overlapping boxes
74         std::vector<bool> overlapped(boxes.size(), false);
75         for (size_t i = 0; i < boxes.size(); i++) {
76             if (overlapped[i])
77                 continue;
78
79             DetectedObject box_i = boxes[i];
80             for (size_t j = i + 1; j < boxes.size(); j++) {
81                 DetectedObject box_j = boxes[j];
82                 if (DetectedObject::ioU(box_i, box_j) >= 0.4) {
83                     overlapped[j] = true;
84                 }
85             }
86         }
87
88         for (size_t i = 0; i < boxes.size(); i++) {
89             if (boxes[i].prob > 0.0f) {
90                 boxes_result.push_back(boxes[i]);
91             }
92         }
93         return boxes_result;
94     }
95
96 protected:
97     std::map<std::string, std::list<DetectedObject>> processResult(std::vector<std::string> files) {
98         std::map<std::string, std::list<DetectedObject>> detectedObjects;
99
100         std::string firstOutputName = this->outInfo.begin()->first;
101         const auto detectionOutArray = inferRequest.GetBlob(firstOutputName);
102         const float *box = detectionOutArray->buffer().as<float*>();
103
104         std::string file = *files.begin();
105         for (int c = 0; c < 20; c++) {
106             std::vector<DetectedObject> result = yoloNetParseOutput(box, c);
107             detectedObjects[file].insert(detectedObjects[file].end(), result.begin(), result.end());
108         }
109
110         return detectedObjects;
111     }
112
113 public:
114     YOLOObjectDetectionProcessor(const std::string& flags_m, const std::string& flags_d, const std::string& flags_i, const std::string& subdir, int flags_b,
115             double threshold,
116             InferencePlugin plugin, CsvDumper& dumper,
117             const std::string& flags_a, const std::string& classes_list_file) :
118
119                 ObjectDetectionProcessor(flags_m, flags_d, flags_i, subdir, flags_b, threshold,
120                         plugin, dumper, flags_a, classes_list_file, PreprocessingOptions(true, ResizeCropPolicy::Resize), false) { }
121 };