c2526493af318a45ec7438168803c35e47192d49
[platform/upstream/opencv.git] / modules / legacy / src / spilltree.cpp
1 /* Original code has been submitted by Liu Liu.
2    ----------------------------------------------------------------------------------
3    * Spill-Tree for Approximate KNN Search
4    * Author: Liu Liu
5    * mailto: liuliu.1987+opencv@gmail.com
6    * Refer to Paper:
7    * An Investigation of Practical Approximate Nearest Neighbor Algorithms
8    * cvMergeSpillTree TBD
9    *
10    * Redistribution and use in source and binary forms, with or
11    * without modification, are permitted provided that the following
12    * conditions are met:
13    *    Redistributions of source code must retain the above
14    *    copyright notice, this list of conditions and the following
15    *    disclaimer.
16    *    Redistributions in binary form must reproduce the above
17    *    copyright notice, this list of conditions and the following
18    *    disclaimer in the documentation and/or other materials
19    *    provided with the distribution.
20    *    The name of Contributor may not be used to endorse or
21    *    promote products derived from this software without
22    *    specific prior written permission.
23    *
24    * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
25    * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
26    * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
27    * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
28    * DISCLAIMED. IN NO EVENT SHALL THE CONTRIBUTORS BE LIABLE FOR ANY
29    * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30    * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
31    * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
32    * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
33    * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
34    * TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
35    * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
36    * OF SUCH DAMAGE.
37    */
38
39 #include "precomp.hpp"
40 #include "_featuretree.h"
41
42 struct CvSpillTreeNode
43 {
44   bool leaf; // is leaf or not (leaf is the point that have no more child)
45   bool spill; // is not a non-overlapping point (defeatist search)
46   CvSpillTreeNode* lc; // left child (<)
47   CvSpillTreeNode* rc; // right child (>)
48   int cc; // child count
49   CvMat* u; // projection vector
50   CvMat* center; // center
51   int i; // original index
52   double r; // radius of remaining feature point
53   double ub; // upper bound
54   double lb; // lower bound
55   double mp; // mean point
56   double p; // projection value
57 };
58
59 struct CvSpillTree
60 {
61   CvSpillTreeNode* root;
62   CvMat** refmat; // leaf ref matrix
63   int total; // total leaves
64   int naive; // under this value, we perform naive search
65   int type; // mat type
66   double rho; // under this value, it is a spill tree
67   double tau; // the overlapping buffer ratio
68 };
69
70 struct CvResult
71 {
72   int index;
73   double distance;
74 };
75
76 // find the farthest node in the "list" from "node"
77 static inline CvSpillTreeNode*
78 icvFarthestNode( CvSpillTreeNode* node,
79          CvSpillTreeNode* list,
80          int total )
81 {
82   double farthest = -1.;
83   CvSpillTreeNode* result = NULL;
84   for ( int i = 0; i < total; i++ )
85     {
86       double norm = cvNorm( node->center, list->center );
87       if ( norm > farthest )
88     {
89       farthest = norm;
90       result = list;
91     }
92       list = list->rc;
93     }
94   return result;
95 }
96
97 // clone a new tree node
98 static inline CvSpillTreeNode*
99 icvCloneSpillTreeNode( CvSpillTreeNode* node )
100 {
101   CvSpillTreeNode* result = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
102   memcpy( result, node, sizeof(CvSpillTreeNode) );
103   return result;
104 }
105
106 // append the link-list of a tree node
107 static inline void
108 icvAppendSpillTreeNode( CvSpillTreeNode* node,
109             CvSpillTreeNode* append )
110 {
111   if ( node->lc == NULL )
112     {
113       node->lc = node->rc = append;
114       node->lc->lc = node->rc->rc = NULL;
115     } else {
116       append->lc = node->rc;
117       append->rc = NULL;
118       node->rc->rc = append;
119       node->rc = append;
120     }
121   node->cc++;
122 }
123
124 #define _dispatch_mat_ptr(x, step) (CV_MAT_DEPTH((x)->type) == CV_32F ? (void*)((x)->data.fl+(step)) : (CV_MAT_DEPTH((x)->type) == CV_64F ? (void*)((x)->data.db+(step)) : (void*)(0)))
125
126 static void
127 icvDFSInitSpillTreeNode( const CvSpillTree* tr,
128              const int d,
129              CvSpillTreeNode* node )
130 {
131   if ( node->cc <= tr->naive )
132     {
133       // already get to a leaf, terminate the recursion.
134       node->leaf = true;
135       node->spill = false;
136       return;
137     }
138
139   // random select a node, then find a farthest node from this one, then find a farthest from that one...
140   // to approximate the farthest node-pair
141   static CvRNG rng_state = cvRNG(0xdeadbeef);
142   int rn = cvRandInt( &rng_state ) % node->cc;
143   CvSpillTreeNode* lnode = NULL;
144   CvSpillTreeNode* rnode = node->lc;
145   for ( int i = 0; i < rn; i++ )
146     rnode = rnode->rc;
147   lnode = icvFarthestNode( rnode, node->lc, node->cc );
148   rnode = icvFarthestNode( lnode, node->lc, node->cc );
149
150   // u is the projection vector
151   node->u = cvCreateMat( 1, d, tr->type );
152   cvSub( lnode->center, rnode->center, node->u );
153   cvNormalize( node->u, node->u );
154
155   // find the center of node in hyperspace
156   node->center = cvCreateMat( 1, d, tr->type );
157   cvZero( node->center );
158   CvSpillTreeNode* it = node->lc;
159   for ( int i = 0; i < node->cc; i++ )
160     {
161       cvAdd( it->center, node->center, node->center );
162       it = it->rc;
163     }
164   cvConvertScale( node->center, node->center, 1./node->cc );
165
166   // project every node to "u", and find the mean point "mp"
167   it = node->lc;
168   node->r = -1.;
169   node->mp = 0;
170   for ( int i = 0; i < node->cc; i++ )
171     {
172       node->mp += ( it->p = cvDotProduct( it->center, node->u ) );
173       double norm = cvNorm( node->center, it->center );
174       if ( norm > node->r )
175     node->r = norm;
176       it = it->rc;
177     }
178   node->mp = node->mp / node->cc;
179
180   // overlapping buffer and upper bound, lower bound
181   double ob = (lnode->p-rnode->p)*tr->tau*.5;
182   node->ub = node->mp+ob;
183   node->lb = node->mp-ob;
184   int sl = 0, l = 0;
185   int sr = 0, r = 0;
186   it = node->lc;
187   for ( int i = 0; i < node->cc; i++ )
188     {
189       if ( it->p <= node->ub )
190     sl++;
191       if ( it->p >= node->lb )
192     sr++;
193       if ( it->p < node->mp )
194     l++;
195       else
196     r++;
197       it = it->rc;
198     }
199   // precision problem, return the node as it is.
200   if (( l == 0 )||( r == 0 ))
201     {
202       cvReleaseMat( &(node->u) );
203       cvReleaseMat( &(node->center) );
204       node->leaf = true;
205       node->spill = false;
206       return;
207     }
208   CvSpillTreeNode* lc = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
209   memset(lc, 0, sizeof(CvSpillTreeNode));
210   CvSpillTreeNode* rc = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
211   memset(rc, 0, sizeof(CvSpillTreeNode));
212   lc->lc = lc->rc = rc->lc = rc->rc = NULL;
213   lc->cc = rc->cc = 0;
214   int undo = cvRound(node->cc*tr->rho);
215   if (( sl >= undo )||( sr >= undo ))
216     {
217       // it is not a spill point (defeatist search disabled)
218       it = node->lc;
219       for ( int i = 0; i < node->cc; i++ )
220     {
221       CvSpillTreeNode* next = it->rc;
222       if ( it->p < node->mp )
223         icvAppendSpillTreeNode( lc, it );
224       else
225         icvAppendSpillTreeNode( rc, it );
226       it = next;
227     }
228       node->spill = false;
229     } else {
230       // a spill point
231       it = node->lc;
232       for ( int i = 0; i < node->cc; i++ )
233     {
234       CvSpillTreeNode* next = it->rc;
235       if ( it->p < node->lb )
236         icvAppendSpillTreeNode( lc, it );
237       else if ( it->p > node->ub )
238         icvAppendSpillTreeNode( rc, it );
239       else {
240         CvSpillTreeNode* cit = icvCloneSpillTreeNode( it );
241         icvAppendSpillTreeNode( lc, it );
242         icvAppendSpillTreeNode( rc, cit );
243       }
244       it = next;
245     }
246       node->spill = true;
247     }
248   node->lc = lc;
249   node->rc = rc;
250
251   // recursion process
252   icvDFSInitSpillTreeNode( tr, d, node->lc );
253   icvDFSInitSpillTreeNode( tr, d, node->rc );
254 }
255
256 static CvSpillTree*
257 icvCreateSpillTree( const CvMat* raw_data,
258             const int naive,
259             const double rho,
260             const double tau )
261 {
262   int n = raw_data->rows;
263   int d = raw_data->cols;
264
265   CvSpillTree* tr = (CvSpillTree*)cvAlloc( sizeof(CvSpillTree) );
266   tr->root = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
267   memset(tr->root, 0, sizeof(CvSpillTreeNode));
268   tr->refmat = (CvMat**)cvAlloc( sizeof(CvMat*)*n );
269   tr->total = n;
270   tr->naive = naive;
271   tr->rho = rho;
272   tr->tau = tau;
273   tr->type = raw_data->type;
274
275   // tie a link-list to the root node
276   tr->root->lc = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
277   memset(tr->root->lc, 0, sizeof(CvSpillTreeNode));
278   tr->root->lc->center = cvCreateMatHeader( 1, d, tr->type );
279   cvSetData( tr->root->lc->center, _dispatch_mat_ptr(raw_data, 0), raw_data->step );
280   tr->refmat[0] = tr->root->lc->center;
281   tr->root->lc->lc = NULL;
282   tr->root->lc->leaf = true;
283   tr->root->lc->i = 0;
284   CvSpillTreeNode* node = tr->root->lc;
285   for ( int i = 1; i < n; i++ )
286     {
287       CvSpillTreeNode* newnode = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
288       memset(newnode, 0, sizeof(CvSpillTreeNode));
289       newnode->center = cvCreateMatHeader( 1, d, tr->type );
290       cvSetData( newnode->center, _dispatch_mat_ptr(raw_data, i*d), raw_data->step );
291       tr->refmat[i] = newnode->center;
292       newnode->lc = node;
293       newnode->i = i;
294       newnode->leaf = true;
295       newnode->rc = NULL;
296       node->rc = newnode;
297       node = newnode;
298     }
299   tr->root->rc = node;
300   tr->root->cc = n;
301   icvDFSInitSpillTreeNode( tr, d, tr->root );
302   return tr;
303 }
304
305 static void
306 icvSpillTreeNodeHeapify( CvResult * heap,
307              int i,
308              const int k )
309 {
310   if ( heap[i].index == -1 )
311     return;
312   int l, r, largest = i;
313   CvResult inp;
314   do {
315     i = largest;
316     r = (i+1)<<1;
317     l = r-1;
318     if (( l < k )&&( heap[l].index == -1 ))
319       largest = l;
320     else if (( r < k )&&( heap[r].index == -1 ))
321       largest = r;
322     else {
323       if (( l < k )&&( heap[l].distance > heap[i].distance ))
324         largest = l;
325       if (( r < k )&&( heap[r].distance > heap[largest].distance ))
326         largest = r;
327     }
328     if ( largest != i )
329       CV_SWAP( heap[largest], heap[i], inp );
330   } while ( largest != i );
331 }
332
333 static void
334 icvSpillTreeDFSearch( CvSpillTree* tr,
335               CvSpillTreeNode* node,
336               CvResult* heap,
337               int* es,
338               const CvMat* desc,
339               const int k,
340               const int emax,
341                       bool * cache)
342 {
343   if ((emax > 0)&&( *es >= emax ))
344     return;
345   double dist, p=0;
346   double distance;
347   while ( node->spill )
348     {
349       // defeatist search
350       if ( !node->leaf )
351     p = cvDotProduct( node->u, desc );
352       if ( p < node->lb && node->lc->cc >= k ) // check the number of children larger than k otherwise you'll skip over better neighbor
353     node = node->lc;
354       else if ( p > node->ub && node->rc->cc >= k )
355     node = node->rc;
356       else
357     break;
358       if ( NULL == node )
359     return;
360     }
361   if ( node->leaf )
362     {
363       // a leaf, naive search
364       CvSpillTreeNode* it = node->lc;
365       for ( int i = 0; i < node->cc; i++ )
366         {
367           if ( !cache[it->i] )
368           {
369         distance = cvNorm( it->center, desc );
370             cache[it->i] = true;
371         if (( heap[0].index == -1)||( distance < heap[0].distance ))
372           {
373                 CvResult  current_result;
374                 current_result.index = it->i;
375                 current_result.distance = distance;
376                 heap[0] = current_result;
377             icvSpillTreeNodeHeapify( heap, 0, k );
378         (*es)++;
379           }
380           }
381           it = it->rc;
382     }
383       return;
384     }
385   dist = cvNorm( node->center, desc );
386   // impossible case, skip
387   if (( heap[0].index != -1 )&&( dist-node->r > heap[0].distance ))
388     return;
389   p = cvDotProduct( node->u, desc );
390   // guided dfs
391   if ( p < node->mp )
392     {
393       icvSpillTreeDFSearch( tr, node->lc, heap, es, desc, k, emax, cache );
394       icvSpillTreeDFSearch( tr, node->rc, heap, es, desc, k, emax, cache );
395     } else {
396     icvSpillTreeDFSearch( tr, node->rc, heap, es, desc, k, emax, cache );
397     icvSpillTreeDFSearch( tr, node->lc, heap, es, desc, k, emax, cache );
398     }
399 }
400
401 static void
402 icvFindSpillTreeFeatures( CvSpillTree* tr,
403               const CvMat* desc,
404               CvMat* results,
405               CvMat* dist,
406               const int k,
407               const int emax )
408 {
409   assert( desc->type == tr->type );
410   CvResult* heap = (CvResult*)cvAlloc( k*sizeof(heap[0]) );
411   bool* cache = (bool*)cvAlloc( sizeof(bool)*tr->total );
412   for ( int j = 0; j < desc->rows; j++ )
413     {
414       CvMat _desc = cvMat( 1, desc->cols, desc->type, _dispatch_mat_ptr(desc, j*desc->cols) );
415       for ( int i = 0; i < k; i++ ) {
416         CvResult current;
417         current.index=-1;
418         current.distance=-1;
419     heap[i] = current;
420       }
421       memset( cache, 0, sizeof(bool)*tr->total );
422       int es = 0;
423       icvSpillTreeDFSearch( tr, tr->root, heap, &es, &_desc, k, emax, cache );
424       CvResult inp;
425       for ( int i = k-1; i > 0; i-- )
426     {
427       CV_SWAP( heap[i], heap[0], inp );
428       icvSpillTreeNodeHeapify( heap, 0, i );
429     }
430       int* rs = results->data.i+j*results->cols;
431       double* dt = dist->data.db+j*dist->cols;
432       for ( int i = 0; i < k; i++, rs++, dt++ )
433     if ( heap[i].index != -1 )
434       {
435         *rs = heap[i].index;
436         *dt = heap[i].distance;
437       } else
438         *rs = -1;
439     }
440   cvFree( &heap );
441   cvFree( &cache );
442 }
443
444 static void
445 icvDFSReleaseSpillTreeNode( CvSpillTreeNode* node )
446 {
447   if ( node->leaf )
448     {
449       CvSpillTreeNode* it = node->lc;
450       for ( int i = 0; i < node->cc; i++ )
451         {
452           CvSpillTreeNode* s = it;
453           it = it->rc;
454           cvFree( &s );
455         }
456     } else {
457       cvReleaseMat( &node->u );
458       cvReleaseMat( &node->center );
459       icvDFSReleaseSpillTreeNode( node->lc );
460       icvDFSReleaseSpillTreeNode( node->rc );
461     }
462   cvFree( &node );
463 }
464
465 static void
466 icvReleaseSpillTree( CvSpillTree** tr )
467 {
468   for ( int i = 0; i < (*tr)->total; i++ )
469     cvReleaseMat( &((*tr)->refmat[i]) );
470   cvFree( &((*tr)->refmat) );
471   icvDFSReleaseSpillTreeNode( (*tr)->root );
472   cvFree( tr );
473 }
474
475 class CvSpillTreeWrap : public CvFeatureTree {
476   CvSpillTree* tr;
477 public:
478   CvSpillTreeWrap(const CvMat* raw_data,
479           const int naive,
480           const double rho,
481           const double tau) {
482     tr = icvCreateSpillTree(raw_data, naive, rho, tau);
483   }
484   ~CvSpillTreeWrap() {
485     icvReleaseSpillTree(&tr);
486   }
487
488   void FindFeatures(const CvMat* desc, int k, int emax, CvMat* results, CvMat* dist) {
489     icvFindSpillTreeFeatures(tr, desc, results, dist, k, emax);
490   }
491 };
492
493 CvFeatureTree* cvCreateSpillTree( const CvMat* raw_data,
494                   const int naive,
495                   const double rho,
496                   const double tau ) {
497   return new CvSpillTreeWrap(raw_data, naive, rho, tau);
498 }