Add Euclidean distance to be stable for tracking points
authorTae-Young Chung <ty83.chung@samsung.com>
Wed, 3 Jul 2024 04:23:02 +0000 (13:23 +0900)
committerTae-Young Chung <ty83.chung@samsung.com>
Wed, 3 Jul 2024 04:23:02 +0000 (13:23 +0900)
Change-Id: I11452a96c639aafb6d25fde1405f3b7c1f3b1802
Signed-off-by: Tae-Young Chung <ty83.chung@samsung.com>
inference/backends/private/include/PointsTracker.h
inference/backends/private/src/PointsTracker.cpp

index f44f20e648ec141ee58dd5a61069b751b1022efc..b561f254777ff25c72842701df275a7beecebd90 100644 (file)
@@ -30,7 +30,9 @@ class PointsTracker : protected AbstractTracker {
 private:
     int _numberOfPoints;
     std::vector<singleo::Point> _points;
+    float _distanceThresValue;
 
+    float distanceValue(std::tuple<float, float> pt1, std::tuple<float, float> pt2);
 public:
     PointsTracker(const int numbers);
     virtual ~PointsTracker();
index 6fdd1828e81947394045fbdd7d1e213109bb5873..ace9fa917e7007163e959e3d5c89bf341f04fb85 100644 (file)
@@ -27,6 +27,7 @@ namespace inference
 PointsTracker::PointsTracker(const int numbers) : AbstractTracker(numbers)
 {
     _numberOfPoints = numbers;
+    _distanceThresValue = sqrt(2.f);
 }
 
 PointsTracker::~PointsTracker()
@@ -34,6 +35,14 @@ PointsTracker::~PointsTracker()
 
 }
 
+float PointsTracker::distanceValue(std::tuple<float, float> pt1, std::tuple<float, float> pt2)
+{
+    float diffX = get<0>(pt1) - get<0>(pt2);
+    float diffY = get<1>(pt1) - get<1>(pt2);
+
+    return sqrt(diffX * diffX + diffY * diffY);
+}
+
 void PointsTracker::init(vector<Point> &points)
 {
     if (points.size() != _filters.size()) {
@@ -56,6 +65,11 @@ vector<Point> &PointsTracker::predict()
     for (auto &point : _points) {
         auto newPoint = _filters[index]->predict();
 
+        if (distanceValue({static_cast<float>(point.x), static_cast<float>(point.y)}, newPoint) <= _distanceThresValue) {
+            index++;
+            continue;
+        }
+
         point.x = get<0>(newPoint);
         point.y = get<1>(newPoint);
         index++;
@@ -72,8 +86,12 @@ vector<Point> &PointsTracker::update(vector<Point> &points)
         _filters[index]->update(point.x, point.y);
         auto newPoint = _filters[index]->get();
 
-        _points[index].x = static_cast<int>(get<0>(newPoint));
-        _points[index].y = static_cast<int>(get<1>(newPoint));
+        if (distanceValue({static_cast<float>(point.x), static_cast<float>(point.y)}, newPoint) <= _distanceThresValue) {
+            _filters[index]->update(_points[index].x, _points[index].y);
+        } else {
+            _points[index].x = static_cast<int>(get<0>(newPoint));
+            _points[index].y = static_cast<int>(get<1>(newPoint));
+        }
         index++;
     }