1 /* Original code has been submitted by Liu Liu.
2 ----------------------------------------------------------------------------------
3 * Spill-Tree for Approximate KNN Search
5 * mailto: liuliu.1987+opencv@gmail.com
7 * An Investigation of Practical Approximate Nearest Neighbor Algorithms
10 * Redistribution and use in source and binary forms, with or
11 * without modification, are permitted provided that the following
13 * Redistributions of source code must retain the above
14 * copyright notice, this list of conditions and the following
16 * Redistributions in binary form must reproduce the above
17 * copyright notice, this list of conditions and the following
18 * disclaimer in the documentation and/or other materials
19 * provided with the distribution.
20 * The name of Contributor may not be used to endorse or
21 * promote products derived from this software without
22 * specific prior written permission.
24 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
25 * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
26 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
27 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
28 * DISCLAIMED. IN NO EVENT SHALL THE CONTRIBUTORS BE LIABLE FOR ANY
29 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
31 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
32 * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
33 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
34 * TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
35 * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
39 #include "precomp.hpp"
40 #include "_featuretree.h"
42 struct CvSpillTreeNode
44 bool leaf; // is leaf or not (leaf is the point that have no more child)
45 bool spill; // is not a non-overlapping point (defeatist search)
46 CvSpillTreeNode* lc; // left child (<)
47 CvSpillTreeNode* rc; // right child (>)
48 int cc; // child count
49 CvMat* u; // projection vector
50 CvMat* center; // center
51 int i; // original index
52 double r; // radius of remaining feature point
53 double ub; // upper bound
54 double lb; // lower bound
55 double mp; // mean point
56 double p; // projection value
61 CvSpillTreeNode* root;
62 CvMat** refmat; // leaf ref matrix
63 int total; // total leaves
64 int naive; // under this value, we perform naive search
66 double rho; // under this value, it is a spill tree
67 double tau; // the overlapping buffer ratio
76 // find the farthest node in the "list" from "node"
77 static inline CvSpillTreeNode*
78 icvFarthestNode( CvSpillTreeNode* node,
79 CvSpillTreeNode* list,
82 double farthest = -1.;
83 CvSpillTreeNode* result = NULL;
84 for ( int i = 0; i < total; i++ )
86 double norm = cvNorm( node->center, list->center );
87 if ( norm > farthest )
97 // clone a new tree node
98 static inline CvSpillTreeNode*
99 icvCloneSpillTreeNode( CvSpillTreeNode* node )
101 CvSpillTreeNode* result = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
102 memcpy( result, node, sizeof(CvSpillTreeNode) );
106 // append the link-list of a tree node
108 icvAppendSpillTreeNode( CvSpillTreeNode* node,
109 CvSpillTreeNode* append )
111 if ( node->lc == NULL )
113 node->lc = node->rc = append;
114 node->lc->lc = node->rc->rc = NULL;
116 append->lc = node->rc;
118 node->rc->rc = append;
124 #define _dispatch_mat_ptr(x, step) (CV_MAT_DEPTH((x)->type) == CV_32F ? (void*)((x)->data.fl+(step)) : (CV_MAT_DEPTH((x)->type) == CV_64F ? (void*)((x)->data.db+(step)) : (void*)(0)))
127 icvDFSInitSpillTreeNode( const CvSpillTree* tr,
129 CvSpillTreeNode* node )
131 if ( node->cc <= tr->naive )
133 // already get to a leaf, terminate the recursion.
139 // random select a node, then find a farthest node from this one, then find a farthest from that one...
140 // to approximate the farthest node-pair
141 static CvRNG rng_state = cvRNG(0xdeadbeef);
142 int rn = cvRandInt( &rng_state ) % node->cc;
143 CvSpillTreeNode* lnode = NULL;
144 CvSpillTreeNode* rnode = node->lc;
145 for ( int i = 0; i < rn; i++ )
147 lnode = icvFarthestNode( rnode, node->lc, node->cc );
148 rnode = icvFarthestNode( lnode, node->lc, node->cc );
150 // u is the projection vector
151 node->u = cvCreateMat( 1, d, tr->type );
152 cvSub( lnode->center, rnode->center, node->u );
153 cvNormalize( node->u, node->u );
155 // find the center of node in hyperspace
156 node->center = cvCreateMat( 1, d, tr->type );
157 cvZero( node->center );
158 CvSpillTreeNode* it = node->lc;
159 for ( int i = 0; i < node->cc; i++ )
161 cvAdd( it->center, node->center, node->center );
164 cvConvertScale( node->center, node->center, 1./node->cc );
166 // project every node to "u", and find the mean point "mp"
170 for ( int i = 0; i < node->cc; i++ )
172 node->mp += ( it->p = cvDotProduct( it->center, node->u ) );
173 double norm = cvNorm( node->center, it->center );
174 if ( norm > node->r )
178 node->mp = node->mp / node->cc;
180 // overlapping buffer and upper bound, lower bound
181 double ob = (lnode->p-rnode->p)*tr->tau*.5;
182 node->ub = node->mp+ob;
183 node->lb = node->mp-ob;
187 for ( int i = 0; i < node->cc; i++ )
189 if ( it->p <= node->ub )
191 if ( it->p >= node->lb )
193 if ( it->p < node->mp )
199 // precision problem, return the node as it is.
200 if (( l == 0 )||( r == 0 ))
202 cvReleaseMat( &(node->u) );
203 cvReleaseMat( &(node->center) );
208 CvSpillTreeNode* lc = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
209 memset(lc, 0, sizeof(CvSpillTreeNode));
210 CvSpillTreeNode* rc = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
211 memset(rc, 0, sizeof(CvSpillTreeNode));
212 lc->lc = lc->rc = rc->lc = rc->rc = NULL;
214 int undo = cvRound(node->cc*tr->rho);
215 if (( sl >= undo )||( sr >= undo ))
217 // it is not a spill point (defeatist search disabled)
219 for ( int i = 0; i < node->cc; i++ )
221 CvSpillTreeNode* next = it->rc;
222 if ( it->p < node->mp )
223 icvAppendSpillTreeNode( lc, it );
225 icvAppendSpillTreeNode( rc, it );
232 for ( int i = 0; i < node->cc; i++ )
234 CvSpillTreeNode* next = it->rc;
235 if ( it->p < node->lb )
236 icvAppendSpillTreeNode( lc, it );
237 else if ( it->p > node->ub )
238 icvAppendSpillTreeNode( rc, it );
240 CvSpillTreeNode* cit = icvCloneSpillTreeNode( it );
241 icvAppendSpillTreeNode( lc, it );
242 icvAppendSpillTreeNode( rc, cit );
252 icvDFSInitSpillTreeNode( tr, d, node->lc );
253 icvDFSInitSpillTreeNode( tr, d, node->rc );
257 icvCreateSpillTree( const CvMat* raw_data,
262 int n = raw_data->rows;
263 int d = raw_data->cols;
265 CvSpillTree* tr = (CvSpillTree*)cvAlloc( sizeof(CvSpillTree) );
266 tr->root = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
267 memset(tr->root, 0, sizeof(CvSpillTreeNode));
268 tr->refmat = (CvMat**)cvAlloc( sizeof(CvMat*)*n );
273 tr->type = raw_data->type;
275 // tie a link-list to the root node
276 tr->root->lc = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
277 memset(tr->root->lc, 0, sizeof(CvSpillTreeNode));
278 tr->root->lc->center = cvCreateMatHeader( 1, d, tr->type );
279 cvSetData( tr->root->lc->center, _dispatch_mat_ptr(raw_data, 0), raw_data->step );
280 tr->refmat[0] = tr->root->lc->center;
281 tr->root->lc->lc = NULL;
282 tr->root->lc->leaf = true;
284 CvSpillTreeNode* node = tr->root->lc;
285 for ( int i = 1; i < n; i++ )
287 CvSpillTreeNode* newnode = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
288 memset(newnode, 0, sizeof(CvSpillTreeNode));
289 newnode->center = cvCreateMatHeader( 1, d, tr->type );
290 cvSetData( newnode->center, _dispatch_mat_ptr(raw_data, i*d), raw_data->step );
291 tr->refmat[i] = newnode->center;
294 newnode->leaf = true;
301 icvDFSInitSpillTreeNode( tr, d, tr->root );
306 icvSpillTreeNodeHeapify( CvResult * heap,
310 if ( heap[i].index == -1 )
312 int l, r, largest = i;
318 if (( l < k )&&( heap[l].index == -1 ))
320 else if (( r < k )&&( heap[r].index == -1 ))
323 if (( l < k )&&( heap[l].distance > heap[i].distance ))
325 if (( r < k )&&( heap[r].distance > heap[largest].distance ))
329 CV_SWAP( heap[largest], heap[i], inp );
330 } while ( largest != i );
334 icvSpillTreeDFSearch( CvSpillTree* tr,
335 CvSpillTreeNode* node,
343 if ((emax > 0)&&( *es >= emax ))
347 while ( node->spill )
351 p = cvDotProduct( node->u, desc );
352 if ( p < node->lb && node->lc->cc >= k ) // check the number of children larger than k otherwise you'll skip over better neighbor
354 else if ( p > node->ub && node->rc->cc >= k )
363 // a leaf, naive search
364 CvSpillTreeNode* it = node->lc;
365 for ( int i = 0; i < node->cc; i++ )
369 distance = cvNorm( it->center, desc );
371 if (( heap[0].index == -1)||( distance < heap[0].distance ))
373 CvResult current_result;
374 current_result.index = it->i;
375 current_result.distance = distance;
376 heap[0] = current_result;
377 icvSpillTreeNodeHeapify( heap, 0, k );
385 dist = cvNorm( node->center, desc );
386 // impossible case, skip
387 if (( heap[0].index != -1 )&&( dist-node->r > heap[0].distance ))
389 p = cvDotProduct( node->u, desc );
393 icvSpillTreeDFSearch( tr, node->lc, heap, es, desc, k, emax, cache );
394 icvSpillTreeDFSearch( tr, node->rc, heap, es, desc, k, emax, cache );
396 icvSpillTreeDFSearch( tr, node->rc, heap, es, desc, k, emax, cache );
397 icvSpillTreeDFSearch( tr, node->lc, heap, es, desc, k, emax, cache );
402 icvFindSpillTreeFeatures( CvSpillTree* tr,
409 assert( desc->type == tr->type );
410 CvResult* heap = (CvResult*)cvAlloc( k*sizeof(heap[0]) );
411 bool* cache = (bool*)cvAlloc( sizeof(bool)*tr->total );
412 for ( int j = 0; j < desc->rows; j++ )
414 CvMat _desc = cvMat( 1, desc->cols, desc->type, _dispatch_mat_ptr(desc, j*desc->cols) );
415 for ( int i = 0; i < k; i++ ) {
421 memset( cache, 0, sizeof(bool)*tr->total );
423 icvSpillTreeDFSearch( tr, tr->root, heap, &es, &_desc, k, emax, cache );
425 for ( int i = k-1; i > 0; i-- )
427 CV_SWAP( heap[i], heap[0], inp );
428 icvSpillTreeNodeHeapify( heap, 0, i );
430 int* rs = results->data.i+j*results->cols;
431 double* dt = dist->data.db+j*dist->cols;
432 for ( int i = 0; i < k; i++, rs++, dt++ )
433 if ( heap[i].index != -1 )
436 *dt = heap[i].distance;
445 icvDFSReleaseSpillTreeNode( CvSpillTreeNode* node )
449 CvSpillTreeNode* it = node->lc;
450 for ( int i = 0; i < node->cc; i++ )
452 CvSpillTreeNode* s = it;
457 cvReleaseMat( &node->u );
458 cvReleaseMat( &node->center );
459 icvDFSReleaseSpillTreeNode( node->lc );
460 icvDFSReleaseSpillTreeNode( node->rc );
466 icvReleaseSpillTree( CvSpillTree** tr )
468 for ( int i = 0; i < (*tr)->total; i++ )
469 cvReleaseMat( &((*tr)->refmat[i]) );
470 cvFree( &((*tr)->refmat) );
471 icvDFSReleaseSpillTreeNode( (*tr)->root );
475 class CvSpillTreeWrap : public CvFeatureTree {
478 CvSpillTreeWrap(const CvMat* raw_data,
482 tr = icvCreateSpillTree(raw_data, naive, rho, tau);
485 icvReleaseSpillTree(&tr);
488 void FindFeatures(const CvMat* desc, int k, int emax, CvMat* results, CvMat* dist) {
489 icvFindSpillTreeFeatures(tr, desc, results, dist, k, emax);
493 CvFeatureTree* cvCreateSpillTree( const CvMat* raw_data,
497 return new CvSpillTreeWrap(raw_data, naive, rho, tau);