String getModelName() const CV_OVERRIDE { return NAME_BRUTE_FORCE; }
int getType() const CV_OVERRIDE { return ml::KNearest::BRUTE_FORCE; }
- void findNearestCore( const Mat& _samples, int k0, const Range& range,
+ void findNearestCore( const Mat& _samples, int k, const Range& range,
Mat* results, Mat* neighbor_responses,
Mat* dists, float* presult ) const
{
int testidx, baseidx, i, j, d = samples.cols, nsamples = samples.rows;
int testcount = range.end - range.start;
- int k = std::min(k0, nsamples);
AutoBuffer<float> buf(testcount*k*2);
float* dbuf = buf.data();
float* nr = neighbor_responses->ptr<float>(testidx + range.start);
for( j = 0; j < k; j++ )
nr[j] = rbuf[testidx*k + j];
- for( ; j < k0; j++ )
+ for( ; j < k; j++ )
nr[j] = 0.f;
}
float* dptr = dists->ptr<float>(testidx + range.start);
for( j = 0; j < k; j++ )
dptr[j] = dbuf[testidx*k + j];
- for( ; j < k0; j++ )
+ for( ; j < k; j++ )
dptr[j] = 0.f;
}
{
float result = 0.f;
CV_Assert( 0 < k );
+ k = std::min(k, samples.rows);
Mat test_samples = _samples.getMat();
CV_Assert( test_samples.type() == CV_32F && test_samples.cols == samples.cols );
{
float result = 0.f;
CV_Assert( 0 < k );
+ k = std::min(k, samples.rows);
Mat test_samples = _samples.getMat();
CV_Assert( test_samples.type() == CV_32F && test_samples.cols == samples.cols );
TEST(ML_EM, save_load) { CV_EMTest_SaveLoad test; test.safe_run(); }
TEST(ML_EM, classification) { CV_EMTest_Classification test; test.safe_run(); }
+TEST(ML_KNearest, regression_12347)
+{
+ Mat xTrainData = (Mat_<float>(5,2) << 1, 1.1, 1.1, 1, 2, 2, 2.1, 2, 2.1, 2.1);
+ Mat yTrainLabels = (Mat_<float>(5,1) << 1, 1, 2, 2, 2);
+ Ptr<KNearest> knn = KNearest::create();
+ knn->train(xTrainData, ml::ROW_SAMPLE, yTrainLabels);
+
+ Mat xTestData = (Mat_<float>(2,2) << 1.1, 1.1, 2, 2.2);
+ Mat zBestLabels, neighbours, dist;
+ // check output shapes:
+ int K = 16, Kexp = std::min(K, xTrainData.rows);
+ knn->findNearest(xTestData, K, zBestLabels, neighbours, dist);
+ EXPECT_EQ(xTestData.rows, zBestLabels.rows);
+ EXPECT_EQ(neighbours.cols, Kexp);
+ EXPECT_EQ(dist.cols, Kexp);
+ // see if the result is still correct:
+ K = 2;
+ knn->findNearest(xTestData, K, zBestLabels, neighbours, dist);
+ EXPECT_EQ(1, zBestLabels.at<float>(0,0));
+ EXPECT_EQ(2, zBestLabels.at<float>(1,0));
+}
+
}} // namespace