Refactor NMS procedure at RegionLayer
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Fri, 8 Dec 2017 11:14:17 +0000 (14:14 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Thu, 21 Dec 2017 09:21:45 +0000 (12:21 +0300)
modules/dnn/src/darknet/darknet_io.cpp
modules/dnn/src/layers/region_layer.cpp

index bf6a5af..daddcf7 100644 (file)
@@ -482,7 +482,7 @@ namespace cv {
                     }
                     else if (layer_type == "region")
                     {
-                        float thresh = 0.001;    // in the original Darknet is equal to the detection threshold set by the user
+                        float thresh = getParam<float>(layer_params, "thresh", 0.001);
                         int coords = getParam<int>(layer_params, "coords", 4);
                         int classes = getParam<int>(layer_params, "classes", -1);
                         int num_of_anchors = getParam<int>(layer_params, "num", -1);
index 94993fa..688cb90 100644 (file)
@@ -43,7 +43,7 @@
 #include "../precomp.hpp"
 #include <opencv2/dnn/shape_utils.hpp>
 #include <opencv2/dnn/all_layers.hpp>
-#include <iostream>
+#include "nms.inl.hpp"
 #include "opencl_kernels_dnn.hpp"
 
 namespace cv
@@ -173,8 +173,7 @@ public:
             if (nmsThreshold > 0) {
                 Mat mat = outBlob.getMat(ACCESS_WRITE);
                 float *dstData = mat.ptr<float>();
-                do_nms_sort(dstData, rows*cols*anchors, nmsThreshold);
-                //do_nms(dstData, rows*cols*anchors, nmsThreshold);
+                do_nms_sort(dstData, rows*cols*anchors, thresh, nmsThreshold);
             }
 
         }
@@ -263,128 +262,48 @@ public:
             }
 
             if (nmsThreshold > 0) {
-                do_nms_sort(dstData, rows*cols*anchors, nmsThreshold);
-                //do_nms(dstData, rows*cols*anchors, nmsThreshold);
+                do_nms_sort(dstData, rows*cols*anchors, thresh, nmsThreshold);
             }
 
         }
     }
 
-
-    struct box {
-        float x, y, w, h;
-        float *probs;
-    };
-
-    float overlap(float x1, float w1, float x2, float w2)
-    {
-        float l1 = x1 - w1 / 2;
-        float l2 = x2 - w2 / 2;
-        float left = l1 > l2 ? l1 : l2;
-        float r1 = x1 + w1 / 2;
-        float r2 = x2 + w2 / 2;
-        float right = r1 < r2 ? r1 : r2;
-        return right - left;
-    }
-
-    float box_intersection(box a, box b)
-    {
-        float w = overlap(a.x, a.w, b.x, b.w);
-        float h = overlap(a.y, a.h, b.y, b.h);
-        if (w < 0 || h < 0) return 0;
-        float area = w*h;
-        return area;
-    }
-
-    float box_union(box a, box b)
+    static inline float rectOverlap(const Rect2f& a, const Rect2f& b)
     {
-        float i = box_intersection(a, b);
-        float u = a.w*a.h + b.w*b.h - i;
-        return u;
+        return 1.0f - jaccardDistance(a, b);
     }
 
-    float box_iou(box a, box b)
+    void do_nms_sort(float *detections, int total, float score_thresh, float nms_thresh)
     {
-        return box_intersection(a, b) / box_union(a, b);
-    }
-
-    struct sortable_bbox {
-        int index;
-        float *probs;
-    };
-
-    struct nms_comparator {
-        int k;
-        nms_comparator(int _k) : k(_k) {}
-        bool operator ()(sortable_bbox v1, sortable_bbox v2) {
-            return v2.probs[k] < v1.probs[k];
-        }
-    };
-
-    void do_nms_sort(float *detections, int total, float nms_thresh)
-    {
-        std::vector<box> boxes(total);
-        for (int i = 0; i < total; ++i) {
-            box &b = boxes[i];
-            int box_index = i * (classes + coords + 1);
-            b.x = detections[box_index + 0];
-            b.y = detections[box_index + 1];
-            b.w = detections[box_index + 2];
-            b.h = detections[box_index + 3];
-            int class_index = i * (classes + 5) + 5;
-            b.probs = (detections + class_index);
-        }
-
-        std::vector<sortable_bbox> s(total);
-
-        for (int i = 0; i < total; ++i) {
-            s[i].index = i;
-            int class_index = i * (classes + 5) + 5;
-            s[i].probs = (detections + class_index);
-        }
+        std::vector<Rect2f> boxes(total);
+        std::vector<float> scores(total);
 
-        for (int k = 0; k < classes; ++k) {
-            std::stable_sort(s.begin(), s.end(), nms_comparator(k));
-            for (int i = 0; i < total; ++i) {
-                if (boxes[s[i].index].probs[k] == 0) continue;
-                box a = boxes[s[i].index];
-                for (int j = i + 1; j < total; ++j) {
-                    box b = boxes[s[j].index];
-                    if (box_iou(a, b) > nms_thresh) {
-                        boxes[s[j].index].probs[k] = 0;
-                    }
-                }
-            }
-        }
-    }
-
-    void do_nms(float *detections, int total, float nms_thresh)
-    {
-        std::vector<box> boxes(total);
-        for (int i = 0; i < total; ++i) {
-            box &b = boxes[i];
+        for (int i = 0; i < total; ++i)
+        {
+            Rect2f &b = boxes[i];
             int box_index = i * (classes + coords + 1);
-            b.x = detections[box_index + 0];
-            b.y = detections[box_index + 1];
-            b.w = detections[box_index + 2];
-            b.h = detections[box_index + 3];
-            int class_index = i * (classes + 5) + 5;
-            b.probs = (detections + class_index);
+            b.width = detections[box_index + 2];
+            b.height = detections[box_index + 3];
+            b.x = detections[box_index + 0] - b.width / 2;
+            b.y = detections[box_index + 1] - b.height / 2;
         }
 
-        for (int i = 0; i < total; ++i) {
-            bool any = false;
-            for (int k = 0; k < classes; ++k) any = any || (boxes[i].probs[k] > 0);
-            if (!any) {
-                continue;
+        std::vector<int> indices;
+        for (int k = 0; k < classes; ++k)
+        {
+            for (int i = 0; i < total; ++i)
+            {
+                int box_index = i * (classes + coords + 1);
+                int class_index = box_index + 5;
+                scores[i] = detections[class_index + k];
+                detections[class_index + k] = 0;
             }
-            for (int j = i + 1; j < total; ++j) {
-                if (box_iou(boxes[i], boxes[j]) > nms_thresh) {
-                    for (int k = 0; k < classes; ++k) {
-                        if (boxes[i].probs[k] < boxes[j].probs[k]) boxes[i].probs[k] = 0;
-                        else boxes[j].probs[k] = 0;
-                    }
-                }
+            NMSFast_(boxes, scores, score_thresh, nms_thresh, 1, 0, indices, rectOverlap);
+            for (int i = 0, n = indices.size(); i < n; ++i)
+            {
+                int box_index = indices[i] * (classes + coords + 1);
+                int class_index = box_index + 5;
+                detections[class_index + k] = scores[indices[i]];
             }
         }
     }