ml: fix adjusting K in KNearest (#12358)
authorberak <px1704@web.de>
Fri, 31 Aug 2018 13:07:53 +0000 (15:07 +0200)
committerVadim Pisarevsky <vadim.pisarevsky@gmail.com>
Fri, 31 Aug 2018 13:07:53 +0000 (16:07 +0300)
modules/ml/src/knearest.cpp
modules/ml/test/test_emknearestkmeans.cpp

index d608012..df48b00 100644 (file)
@@ -140,13 +140,12 @@ public:
     String getModelName() const CV_OVERRIDE { return NAME_BRUTE_FORCE; }
     int getType() const CV_OVERRIDE { return ml::KNearest::BRUTE_FORCE; }
 
-    void findNearestCore( const Mat& _samples, int k0, const Range& range,
+    void findNearestCore( const Mat& _samples, int k, const Range& range,
                           Mat* results, Mat* neighbor_responses,
                           Mat* dists, float* presult ) const
     {
         int testidx, baseidx, i, j, d = samples.cols, nsamples = samples.rows;
         int testcount = range.end - range.start;
-        int k = std::min(k0, nsamples);
 
         AutoBuffer<float> buf(testcount*k*2);
         float* dbuf = buf.data();
@@ -215,7 +214,7 @@ public:
                 float* nr = neighbor_responses->ptr<float>(testidx + range.start);
                 for( j = 0; j < k; j++ )
                     nr[j] = rbuf[testidx*k + j];
-                for( ; j < k0; j++ )
+                for( ; j < k; j++ )
                     nr[j] = 0.f;
             }
 
@@ -224,7 +223,7 @@ public:
                 float* dptr = dists->ptr<float>(testidx + range.start);
                 for( j = 0; j < k; j++ )
                     dptr[j] = dbuf[testidx*k + j];
-                for( ; j < k0; j++ )
+                for( ; j < k; j++ )
                     dptr[j] = 0.f;
             }
 
@@ -307,6 +306,7 @@ public:
     {
         float result = 0.f;
         CV_Assert( 0 < k );
+        k = std::min(k, samples.rows);
 
         Mat test_samples = _samples.getMat();
         CV_Assert( test_samples.type() == CV_32F && test_samples.cols == samples.cols );
@@ -363,6 +363,7 @@ public:
     {
         float result = 0.f;
         CV_Assert( 0 < k );
+        k = std::min(k, samples.rows);
 
         Mat test_samples = _samples.getMat();
         CV_Assert( test_samples.type() == CV_32F && test_samples.cols == samples.cols );
index 6755c2e..691815c 100644 (file)
@@ -702,4 +702,26 @@ TEST(ML_EM, accuracy) { CV_EMTest test; test.safe_run(); }
 TEST(ML_EM, save_load) { CV_EMTest_SaveLoad test; test.safe_run(); }
 TEST(ML_EM, classification) { CV_EMTest_Classification test; test.safe_run(); }
 
+TEST(ML_KNearest, regression_12347)
+{
+    Mat xTrainData = (Mat_<float>(5,2) << 1, 1.1, 1.1, 1, 2, 2, 2.1, 2, 2.1, 2.1);
+    Mat yTrainLabels = (Mat_<float>(5,1) << 1, 1, 2, 2, 2);
+    Ptr<KNearest> knn = KNearest::create();
+    knn->train(xTrainData, ml::ROW_SAMPLE, yTrainLabels);
+
+    Mat xTestData = (Mat_<float>(2,2) << 1.1, 1.1, 2, 2.2);
+    Mat zBestLabels, neighbours, dist;
+    // check output shapes:
+    int K = 16, Kexp = std::min(K, xTrainData.rows);
+    knn->findNearest(xTestData, K, zBestLabels, neighbours, dist);
+    EXPECT_EQ(xTestData.rows, zBestLabels.rows);
+    EXPECT_EQ(neighbours.cols, Kexp);
+    EXPECT_EQ(dist.cols, Kexp);
+    // see if the result is still correct:
+    K = 2;
+    knn->findNearest(xTestData, K, zBestLabels, neighbours, dist);
+    EXPECT_EQ(1, zBestLabels.at<float>(0,0));
+    EXPECT_EQ(2, zBestLabels.at<float>(1,0));
+}
+
 }} // namespace