Merge pull request #18061 from danielenricocahall:fix-kd-tree
authorDanny <33044223+danielenricocahall@users.noreply.github.com>
Fri, 4 Sep 2020 17:01:05 +0000 (13:01 -0400)
committerGitHub <noreply@github.com>
Fri, 4 Sep 2020 17:01:05 +0000 (17:01 +0000)
Fix KD Tree kNN Implementation

* Make KDTree mode in kNN functional

remove docs and revert change

Make KDTree mode in kNN functional

spacing

Make KDTree mode in kNN functional

fix window compilations warnings

Make KDTree mode in kNN functional

fix window compilations warnings

Make KDTree mode in kNN functional

casting

Make KDTree mode in kNN functional

formatting

Make KDTree mode in kNN functional

* test coding style

modules/ml/src/kdtree.cpp
modules/ml/src/knearest.cpp
modules/ml/test/test_knearest.cpp

index 1ab840093697cd287d7762ac8a58d27f31c6a1ab..a80e12964a606dc1ff104f8b233d26f7e8b58db6 100644 (file)
@@ -101,7 +101,7 @@ medianPartition( size_t* ofs, int a, int b, const float* vals )
         int i0 = a, i1 = (a+b)/2, i2 = b;
         float v0 = vals[ofs[i0]], v1 = vals[ofs[i1]], v2 = vals[ofs[i2]];
         int ip = v0 < v1 ? (v1 < v2 ? i1 : v0 < v2 ? i2 : i0) :
-            v0 < v2 ? i0 : (v1 < v2 ? i2 : i1);
+                 v0 < v2 ? (v1 == v0 ? i2 : i0): (v1 < v2 ? i2 : i1);
         float pivot = vals[ofs[ip]];
         std::swap(ofs[ip], ofs[i2]);
 
@@ -131,7 +131,6 @@ medianPartition( size_t* ofs, int a, int b, const float* vals )
         CV_Assert(vals[ofs[k]] >= pivot);
         more += vals[ofs[k]] > pivot;
     }
-    CV_Assert(std::abs(more - less) <= 1);
 
     return vals[ofs[middle]];
 }
index ca23d0f4d6ac04611c6f4e4d91e49183d2bbdada..3d8f9b5d2ed03495ffc3c47251dbf728f6985a05 100644 (file)
@@ -381,36 +381,23 @@ public:
         Mat res, nr, d;
         if( _results.needed() )
         {
-            _results.create(testcount, 1, CV_32F);
             res = _results.getMat();
         }
         if( _neighborResponses.needed() )
         {
-            _neighborResponses.create(testcount, k, CV_32F);
             nr = _neighborResponses.getMat();
         }
         if( _dists.needed() )
         {
-            _dists.create(testcount, k, CV_32F);
             d = _dists.getMat();
         }
 
         for (int i=0; i<test_samples.rows; ++i)
         {
             Mat _res, _nr, _d;
-            if (res.rows>i)
-            {
-                _res = res.row(i);
-            }
-            if (nr.rows>i)
-            {
-                _nr = nr.row(i);
-            }
-            if (d.rows>i)
-            {
-                _d = d.row(i);
-            }
             tr.findNearest(test_samples.row(i), k, Emax, _res, _nr, _d, noArray());
+            res.push_back(_res.t());
+            _results.assign(res);
         }
 
         return result; // currently always 0
index 49e6b0d12aeb2dfb79225ac96514d70d960a2f1b..80baed96266e32e6f3e65ae8ee2f9a1116ff6fc4 100644 (file)
@@ -37,18 +37,31 @@ TEST(ML_KNearest, accuracy)
         EXPECT_LE(err, 0.01f);
     }
     {
-        // TODO: broken
-#if 0
         SCOPED_TRACE("KDTree");
-        Mat bestLabels;
+        Mat neighborIndexes;
         float err = 1000;
         Ptr<KNearest> knn = KNearest::create();
         knn->setAlgorithmType(KNearest::KDTREE);
         knn->train(trainData, ml::ROW_SAMPLE, trainLabels);
-        knn->findNearest(testData, 4, bestLabels);
+        knn->findNearest(testData, 4, neighborIndexes);
+        Mat bestLabels;
+        // The output of the KDTree are the neighbor indexes, not actual class labels
+        // so we need to do some extra work to get actual predictions
+        for(int row_num = 0; row_num < neighborIndexes.rows; ++row_num){
+            vector<float> labels;
+            for(int index = 0; index < neighborIndexes.row(row_num).cols; ++index) {
+                labels.push_back(trainLabels.at<float>(neighborIndexes.row(row_num).at<int>(0, index) , 0));
+            }
+            // computing the mode of the output class predictions to determine overall prediction
+            std::vector<int> histogram(3,0);
+            for( int i=0; i<3; ++i )
+                ++histogram[ static_cast<int>(labels[i]) ];
+            int bestLabel = static_cast<int>(std::max_element( histogram.begin(), histogram.end() ) - histogram.begin());
+            bestLabels.push_back(bestLabel);
+        }
+        bestLabels.convertTo(bestLabels, testLabels.type());
         EXPECT_TRUE(calcErr( bestLabels, testLabels, sizes, err, true ));
         EXPECT_LE(err, 0.01f);
-#endif
     }
 }
 
@@ -74,4 +87,26 @@ TEST(ML_KNearest, regression_12347)
     EXPECT_EQ(2, zBestLabels.at<float>(1,0));
 }
 
+TEST(ML_KNearest, bug_11877)
+{
+    Mat trainData = (Mat_<float>(5,2) << 3, 3, 3, 3, 4, 4, 4, 4, 4, 4);
+    Mat trainLabels = (Mat_<float>(5,1) << 0, 0, 1, 1, 1);
+
+    Ptr<KNearest> knnKdt = KNearest::create();
+    knnKdt->setAlgorithmType(KNearest::KDTREE);
+    knnKdt->setIsClassifier(true);
+
+    knnKdt->train(trainData, ml::ROW_SAMPLE, trainLabels);
+
+    Mat testData = (Mat_<float>(2,2) << 3.1, 3.1, 4, 4.1);
+    Mat testLabels = (Mat_<int>(2,1) << 0, 1);
+    Mat result;
+
+    knnKdt->findNearest(testData, 1, result);
+
+    EXPECT_EQ(1, int(result.at<int>(0, 0)));
+    EXPECT_EQ(2, int(result.at<int>(1, 0)));
+    EXPECT_EQ(0, trainLabels.at<int>(result.at<int>(0, 0), 0));
+}
+
 }} // namespace