Add IoU to be stable for tracking
authorTae-Young Chung <ty83.chung@samsung.com>
Tue, 2 Jul 2024 08:53:01 +0000 (17:53 +0900)
committerTae-Young Chung <ty83.chung@samsung.com>
Tue, 2 Jul 2024 08:53:01 +0000 (17:53 +0900)
Change-Id: I5fa0c8682b1898d00340bfec4112f3be0edb4011
Signed-off-by: Tae-Young Chung <ty83.chung@samsung.com>
inference/backends/private/include/RectsTracker.h
inference/backends/private/src/RectsTracker.cpp

index 25cab07b0d865e7dce8db93faac6a0d1f9fbf72a..8311dde8bbc6a2cc94e5def3e0d3997d7f639568 100644 (file)
@@ -30,7 +30,9 @@ class RectsTracker : public AbstractTracker {
 private:
     int _numberOfRects;
     std::vector<singleo::Rect> _rects;
+    float _iouThresValue;
 
+    float iouValue(std::tuple<float, float> lt1, std::tuple<float, float> rb1, std::tuple<float, float> lt2, std::tuple<float, float> rb2);
 public:
     RectsTracker(const int numbers);
     virtual ~RectsTracker();
index 38d0656f6c79e6671885871f1612834a6da1f935..3797975a81d975d7fd647757b0fae966f22156dc 100644 (file)
@@ -27,6 +27,7 @@ namespace inference
 RectsTracker::RectsTracker(const int numbers) : AbstractTracker(numbers * 2)
 {
     _numberOfRects = numbers;
+    _iouThresValue = 0.95;
 }
 
 RectsTracker::~RectsTracker()
@@ -34,6 +35,29 @@ RectsTracker::~RectsTracker()
 
 }
 
+float RectsTracker::iouValue(std::tuple<float, float> lt1, std::tuple<float, float> rb1, std::tuple<float, float> lt2, std::tuple<float, float> rb2)
+{
+    float overlapX1 = std::max(get<0>(lt1), get<0>(lt2));
+    float overlapY1 = std::max(get<1>(lt1), get<1>(lt2));
+
+    float overlapX2 = std::min(get<0>(rb1), get<0>(rb2));
+    float overlapY2 = std::min(get<1>(rb1), get<1>(rb2));
+
+    float overlapW = overlapX2 - overlapX1;
+    float overlapH = overlapY2 - overlapY1;
+
+    if (overlapW < 0.f || overlapH < 0.f)
+        return 0.0f;
+
+    float areaOverlap = overlapW * overlapH;
+    float area1 = (get<0>(rb1) - get<0>(lt1)) * (get<1>(rb1) - get<1>(lt1));
+    float area2 = (get<0>(rb2) - get<0>(lt2)) * (get<1>(rb2) - get<1>(lt2));
+
+    float areaCombined = area1 + area2 - areaOverlap;
+
+    return areaOverlap / (areaCombined + 1e-5);
+}
+
 void RectsTracker::init(vector<Rect> &rects)
 {
     if (rects.size() * 2 != _filters.size()) {
@@ -61,6 +85,9 @@ vector<Rect> &RectsTracker::predict()
         auto newRB = _filters[filterIndex]->predict();
         filterIndex++;
 
+        if (iouValue({rect.left, rect.top}, {rect.right, rect.bottom}, newLT, newRB) >= _iouThresValue)
+            continue;
+
         rect.left = static_cast<int>(get<0>(newLT));
         rect.top = static_cast<int>(get<1>(newLT));
         rect.right = static_cast<int>(get<0>(newRB));
@@ -85,10 +112,22 @@ vector<Rect> &RectsTracker::update(vector<Rect> &rects)
         auto newRB = _filters[filterIndex]->get();
         filterIndex++;
 
-        _rects[rectIndex].left = static_cast<int>(get<0>(newLT));
-        _rects[rectIndex].top = static_cast<int>(get<1>(newLT));
-        _rects[rectIndex].right = static_cast<int>(get<0>(newRB));
-        _rects[rectIndex].bottom = static_cast<int>(get<1>(newRB));
+        if (iouValue({rect.left, rect.top}, {rect.right, rect.bottom}, newLT, newRB) >= _iouThresValue) {
+            SINGLEO_LOGI("iou: %.4f", iouValue({rect.left, rect.top}, {rect.right, rect.bottom}, newLT, newRB));
+            filterIndex--;
+            filterIndex--;
+
+            _filters[filterIndex]->update(_rects[rectIndex].left, _rects[rectIndex].top);
+            filterIndex++;
+            _filters[filterIndex]->update(_rects[rectIndex].right, _rects[rectIndex].bottom);
+            filterIndex++;
+        } else {
+            _rects[rectIndex].left = static_cast<int>(get<0>(newLT));
+            _rects[rectIndex].top = static_cast<int>(get<1>(newLT));
+            _rects[rectIndex].right = static_cast<int>(get<0>(newRB));
+            _rects[rectIndex].bottom = static_cast<int>(get<1>(newRB));
+        }
+
         rectIndex++;
     }