first version of moving KDTree from core to ml
[profile/ivi/opencv.git] / modules / features2d / test / test_nearestneighbors.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                           License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000-2008, Intel Corporation, all rights reserved.
14 // Copyright (C) 2009, Willow Garage Inc., all rights reserved.
15 // Copyright (C) 2014, Itseez Inc, all rights reserved.
16 // Third party copyrights are property of their respective owners.
17 //
18 // Redistribution and use in source and binary forms, with or without modification,
19 // are permitted provided that the following conditions are met:
20 //
21 //   * Redistribution's of source code must retain the above copyright notice,
22 //     this list of conditions and the following disclaimer.
23 //
24 //   * Redistribution's in binary form must reproduce the above copyright notice,
25 //     this list of conditions and the following disclaimer in the documentation
26 //     and/or other materials provided with the distribution.
27 //
28 //   * The name of the copyright holders may not be used to endorse or promote products
29 //     derived from this software without specific prior written permission.
30 //
31 // This software is provided by the copyright holders and contributors "as is" and
32 // any express or implied warranties, including, but not limited to, the implied
33 // warranties of merchantability and fitness for a particular purpose are disclaimed.
34 // In no event shall the Intel Corporation or contributors be liable for any direct,
35 // indirect, incidental, special, exemplary, or consequential damages
36 // (including, but not limited to, procurement of substitute goods or services;
37 // loss of use, data, or profits; or business interruption) however caused
38 // and on any theory of liability, whether in contract, strict liability,
39 // or tort (including negligence or otherwise) arising in any way out of
40 // the use of this software, even if advised of the possibility of such damage.
41 //
42 //M*/
43
44 #include "test_precomp.hpp"
45
46 #include <algorithm>
47 #include <vector>
48 #include <iostream>
49
50 using namespace std;
51 using namespace cv;
52 using namespace cv::flann;
53
54 //--------------------------------------------------------------------------------
55 class NearestNeighborTest : public cvtest::BaseTest
56 {
57 public:
58     NearestNeighborTest() {}
59 protected:
60     static const int minValue = 0;
61     static const int maxValue = 1;
62     static const int dims = 30;
63     static const int featuresCount = 2000;
64     static const int K = 1; // * should also test 2nd nn etc.?
65
66
67     virtual void run( int start_from );
68     virtual void createModel( const Mat& data ) = 0;
69     virtual int findNeighbors( Mat& points, Mat& neighbors ) = 0;
70     virtual int checkGetPoins( const Mat& data );
71     virtual int checkFindBoxed();
72     virtual int checkFind( const Mat& data );
73     virtual void releaseModel() = 0;
74 };
75
76 int NearestNeighborTest::checkGetPoins( const Mat& )
77 {
78    return cvtest::TS::OK;
79 }
80
81 int NearestNeighborTest::checkFindBoxed()
82 {
83     return cvtest::TS::OK;
84 }
85
86 int NearestNeighborTest::checkFind( const Mat& data )
87 {
88     int code = cvtest::TS::OK;
89     int pointsCount = 1000;
90     float noise = 0.2f;
91
92     RNG rng;
93     Mat points( pointsCount, dims, CV_32FC1 );
94     Mat results( pointsCount, K, CV_32SC1 );
95
96     std::vector<int> fmap( pointsCount );
97     for( int pi = 0; pi < pointsCount; pi++ )
98     {
99         int fi = rng.next() % featuresCount;
100         fmap[pi] = fi;
101         for( int d = 0; d < dims; d++ )
102             points.at<float>(pi, d) = data.at<float>(fi, d) + rng.uniform(0.0f, 1.0f) * noise;
103     }
104
105     code = findNeighbors( points, results );
106
107     if( code == cvtest::TS::OK )
108     {
109         int correctMatches = 0;
110         for( int pi = 0; pi < pointsCount; pi++ )
111         {
112             if( fmap[pi] == results.at<int>(pi, 0) )
113                 correctMatches++;
114         }
115
116         double correctPerc = correctMatches / (double)pointsCount;
117         if (correctPerc < .75)
118         {
119             ts->printf( cvtest::TS::LOG, "correct_perc = %d\n", correctPerc );
120             code = cvtest::TS::FAIL_BAD_ACCURACY;
121         }
122     }
123
124     return code;
125 }
126
127 void NearestNeighborTest::run( int /*start_from*/ ) {
128     int code = cvtest::TS::OK, tempCode;
129     Mat desc( featuresCount, dims, CV_32FC1 );
130     randu( desc, Scalar(minValue), Scalar(maxValue) );
131
132     createModel( desc );
133
134     tempCode = checkGetPoins( desc );
135     if( tempCode != cvtest::TS::OK )
136     {
137         ts->printf( cvtest::TS::LOG, "bad accuracy of GetPoints \n" );
138         code = tempCode;
139     }
140
141     tempCode = checkFindBoxed();
142     if( tempCode != cvtest::TS::OK )
143     {
144         ts->printf( cvtest::TS::LOG, "bad accuracy of FindBoxed \n" );
145         code = tempCode;
146     }
147
148     tempCode = checkFind( desc );
149     if( tempCode != cvtest::TS::OK )
150     {
151         ts->printf( cvtest::TS::LOG, "bad accuracy of Find \n" );
152         code = tempCode;
153     }
154
155     releaseModel();
156
157     ts->set_failed_test_info( code );
158 }
159
160 //--------------------------------------------------------------------------------
161 class CV_KDTreeTest_CPP : public NearestNeighborTest
162 {
163 public:
164     CV_KDTreeTest_CPP() {}
165 protected:
166     virtual void createModel( const Mat& data );
167     virtual int checkGetPoins( const Mat& data );
168     virtual int findNeighbors( Mat& points, Mat& neighbors );
169     virtual int checkFindBoxed();
170     virtual void releaseModel();
171     ml::KDTree* tr;
172 };
173
174
175 void CV_KDTreeTest_CPP::createModel( const Mat& data )
176 {
177     tr = new ml::KDTree( data, false );
178 }
179
180 int CV_KDTreeTest_CPP::checkGetPoins( const Mat& data )
181 {
182     Mat res1( data.size(), data.type() ),
183         res3( data.size(), data.type() );
184     Mat idxs( 1, data.rows, CV_32SC1 );
185     for( int pi = 0; pi < data.rows; pi++ )
186     {
187         idxs.at<int>(0, pi) = pi;
188         // 1st way
189         const float* point = tr->getPoint(pi);
190         for( int di = 0; di < data.cols; di++ )
191             res1.at<float>(pi, di) = point[di];
192     }
193
194     // 3d way
195     tr->getPoints( idxs, res3 );
196
197     if( cvtest::norm( res1, data, NORM_L1) != 0 ||
198         cvtest::norm( res3, data, NORM_L1) != 0)
199         return cvtest::TS::FAIL_BAD_ACCURACY;
200     return cvtest::TS::OK;
201 }
202
203 int CV_KDTreeTest_CPP::checkFindBoxed()
204 {
205     vector<float> min( dims, static_cast<float>(minValue)), max(dims, static_cast<float>(maxValue));
206     vector<int> indices;
207     tr->findOrthoRange( min, max, indices );
208     // TODO check indices
209     if( (int)indices.size() != featuresCount)
210         return cvtest::TS::FAIL_BAD_ACCURACY;
211     return cvtest::TS::OK;
212 }
213
214 int CV_KDTreeTest_CPP::findNeighbors( Mat& points, Mat& neighbors )
215 {
216     const int emax = 20;
217     Mat neighbors2( neighbors.size(), CV_32SC1 );
218     int j;
219     for( int pi = 0; pi < points.rows; pi++ )
220     {
221         // 1st way
222         Mat nrow = neighbors.row(pi);
223         tr->findNearest( points.row(pi), neighbors.cols, emax, nrow );
224
225         // 2nd way
226         vector<int> neighborsIdx2( neighbors2.cols, 0 );
227         tr->findNearest( points.row(pi), neighbors2.cols, emax, neighborsIdx2 );
228         vector<int>::const_iterator it2 = neighborsIdx2.begin();
229         for( j = 0; it2 != neighborsIdx2.end(); ++it2, j++ )
230             neighbors2.at<int>(pi,j) = *it2;
231     }
232
233     // compare results
234     if( cvtest::norm( neighbors, neighbors2, NORM_L1 ) != 0 )
235         return cvtest::TS::FAIL_BAD_ACCURACY;
236
237     return cvtest::TS::OK;
238 }
239
240 void CV_KDTreeTest_CPP::releaseModel()
241 {
242     delete tr;
243 }
244
245 //--------------------------------------------------------------------------------
246 class CV_FlannTest : public NearestNeighborTest
247 {
248 public:
249     CV_FlannTest() {}
250 protected:
251     void createIndex( const Mat& data, const IndexParams& params );
252     int knnSearch( Mat& points, Mat& neighbors );
253     int radiusSearch( Mat& points, Mat& neighbors );
254     virtual void releaseModel();
255     Index* index;
256 };
257
258 void CV_FlannTest::createIndex( const Mat& data, const IndexParams& params )
259 {
260     index = new Index( data, params );
261 }
262
263 int CV_FlannTest::knnSearch( Mat& points, Mat& neighbors )
264 {
265     Mat dist( points.rows, neighbors.cols, CV_32FC1);
266     int knn = 1, j;
267
268     // 1st way
269     index->knnSearch( points, neighbors, dist, knn, SearchParams() );
270
271     // 2nd way
272     Mat neighbors1( neighbors.size(), CV_32SC1 );
273     for( int i = 0; i < points.rows; i++ )
274     {
275         float* fltPtr = points.ptr<float>(i);
276         vector<float> query( fltPtr, fltPtr + points.cols );
277         vector<int> indices( neighbors1.cols, 0 );
278         vector<float> dists( dist.cols, 0 );
279         index->knnSearch( query, indices, dists, knn, SearchParams() );
280         vector<int>::const_iterator it = indices.begin();
281         for( j = 0; it != indices.end(); ++it, j++ )
282             neighbors1.at<int>(i,j) = *it;
283     }
284
285     // compare results
286     if( cvtest::norm( neighbors, neighbors1, NORM_L1 ) != 0 )
287         return cvtest::TS::FAIL_BAD_ACCURACY;
288
289     return cvtest::TS::OK;
290 }
291
292 int CV_FlannTest::radiusSearch( Mat& points, Mat& neighbors )
293 {
294     Mat dist( 1, neighbors.cols, CV_32FC1);
295     Mat neighbors1( neighbors.size(), CV_32SC1 );
296     float radius = 10.0f;
297     int j;
298
299     // radiusSearch can only search one feature at a time for range search
300     for( int i = 0; i < points.rows; i++ )
301     {
302         // 1st way
303         Mat p( 1, points.cols, CV_32FC1, points.ptr<float>(i) ),
304             n( 1, neighbors.cols, CV_32SC1, neighbors.ptr<int>(i) );
305         index->radiusSearch( p, n, dist, radius, neighbors.cols, SearchParams() );
306
307         // 2nd way
308         float* fltPtr = points.ptr<float>(i);
309         vector<float> query( fltPtr, fltPtr + points.cols );
310         vector<int> indices( neighbors1.cols, 0 );
311         vector<float> dists( dist.cols, 0 );
312         index->radiusSearch( query, indices, dists, radius, neighbors.cols, SearchParams() );
313         vector<int>::const_iterator it = indices.begin();
314         for( j = 0; it != indices.end(); ++it, j++ )
315             neighbors1.at<int>(i,j) = *it;
316     }
317     // compare results
318     if( cvtest::norm( neighbors, neighbors1, NORM_L1 ) != 0 )
319         return cvtest::TS::FAIL_BAD_ACCURACY;
320
321     return cvtest::TS::OK;
322 }
323
324 void CV_FlannTest::releaseModel()
325 {
326     delete index;
327 }
328
329 //---------------------------------------
330 class CV_FlannLinearIndexTest : public CV_FlannTest
331 {
332 public:
333     CV_FlannLinearIndexTest() {}
334 protected:
335     virtual void createModel( const Mat& data ) { createIndex( data, LinearIndexParams() ); }
336     virtual int findNeighbors( Mat& points, Mat& neighbors ) { return knnSearch( points, neighbors ); }
337 };
338
339 //---------------------------------------
340 class CV_FlannKMeansIndexTest : public CV_FlannTest
341 {
342 public:
343     CV_FlannKMeansIndexTest() {}
344 protected:
345     virtual void createModel( const Mat& data ) { createIndex( data, KMeansIndexParams() ); }
346     virtual int findNeighbors( Mat& points, Mat& neighbors ) { return radiusSearch( points, neighbors ); }
347 };
348
349 //---------------------------------------
350 class CV_FlannKDTreeIndexTest : public CV_FlannTest
351 {
352 public:
353     CV_FlannKDTreeIndexTest() {}
354 protected:
355     virtual void createModel( const Mat& data ) { createIndex( data, KDTreeIndexParams() ); }
356     virtual int findNeighbors( Mat& points, Mat& neighbors ) { return radiusSearch( points, neighbors ); }
357 };
358
359 //----------------------------------------
360 class CV_FlannCompositeIndexTest : public CV_FlannTest
361 {
362 public:
363     CV_FlannCompositeIndexTest() {}
364 protected:
365     virtual void createModel( const Mat& data ) { createIndex( data, CompositeIndexParams() ); }
366     virtual int findNeighbors( Mat& points, Mat& neighbors ) { return knnSearch( points, neighbors ); }
367 };
368
369 //----------------------------------------
370 class CV_FlannAutotunedIndexTest : public CV_FlannTest
371 {
372 public:
373     CV_FlannAutotunedIndexTest() {}
374 protected:
375     virtual void createModel( const Mat& data ) { createIndex( data, AutotunedIndexParams() ); }
376     virtual int findNeighbors( Mat& points, Mat& neighbors ) { return knnSearch( points, neighbors ); }
377 };
378 //----------------------------------------
379 class CV_FlannSavedIndexTest : public CV_FlannTest
380 {
381 public:
382     CV_FlannSavedIndexTest() {}
383 protected:
384     virtual void createModel( const Mat& data );
385     virtual int findNeighbors( Mat& points, Mat& neighbors ) { return knnSearch( points, neighbors ); }
386 };
387
388 void CV_FlannSavedIndexTest::createModel(const cv::Mat &data)
389 {
390     switch ( cvtest::randInt(ts->get_rng()) % 2 )
391     {
392         //case 0: createIndex( data, LinearIndexParams() ); break; // nothing to save for linear search
393         case 0: createIndex( data, KMeansIndexParams() ); break;
394         case 1: createIndex( data, KDTreeIndexParams() ); break;
395         //case 2: createIndex( data, CompositeIndexParams() ); break; // nothing to save for linear search
396         //case 2: createIndex( data, AutotunedIndexParams() ); break; // possible linear index !
397         default: assert(0);
398     }
399     string filename = tempfile();
400     index->save( filename );
401
402     createIndex( data, SavedIndexParams(filename.c_str()));
403     remove( filename.c_str() );
404 }
405
406 TEST(Features2d_KDTree_CPP, regression) { CV_KDTreeTest_CPP test; test.safe_run(); }
407 TEST(Features2d_FLANN_Linear, regression) { CV_FlannLinearIndexTest test; test.safe_run(); }
408 TEST(Features2d_FLANN_KMeans, regression) { CV_FlannKMeansIndexTest test; test.safe_run(); }
409 TEST(Features2d_FLANN_KDTree, regression) { CV_FlannKDTreeIndexTest test; test.safe_run(); }
410 TEST(Features2d_FLANN_Composite, regression) { CV_FlannCompositeIndexTest test; test.safe_run(); }
411 TEST(Features2d_FLANN_Auto, regression) { CV_FlannAutotunedIndexTest test; test.safe_run(); }
412 TEST(Features2d_FLANN_Saved, regression) { CV_FlannSavedIndexTest test; test.safe_run(); }