Wraps cv::EMD for Python and Java
[platform/upstream/opencv.git] / modules / imgproc / src / emd.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 //   * Redistribution's of source code must retain the above copyright notice,
20 //     this list of conditions and the following disclaimer.
21 //
22 //   * Redistribution's in binary form must reproduce the above copyright notice,
23 //     this list of conditions and the following disclaimer in the documentation
24 //     and/or other materials provided with the distribution.
25 //
26 //   * The name of Intel Corporation may not be used to endorse or promote products
27 //     derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
39 //
40 //M*/
41
42 /*
43     Partially based on Yossi Rubner code:
44     =========================================================================
45     emd.c
46
47     Last update: 3/14/98
48
49     An implementation of the Earth Movers Distance.
50     Based of the solution for the Transportation problem as described in
51     "Introduction to Mathematical Programming" by F. S. Hillier and
52     G. J. Lieberman, McGraw-Hill, 1990.
53
54     Copyright (C) 1998 Yossi Rubner
55     Computer Science Department, Stanford University
56     E-Mail: rubner@cs.stanford.edu   URL: http://vision.stanford.edu/~rubner
57     ==========================================================================
58 */
59 #include "precomp.hpp"
60
61 #define MAX_ITERATIONS 500
62 #define CV_EMD_INF   ((float)1e20)
63 #define CV_EMD_EPS   ((float)1e-5)
64
65 /* CvNode1D is used for lists, representing 1D sparse array */
66 typedef struct CvNode1D
67 {
68     float val;
69     struct CvNode1D *next;
70 }
71 CvNode1D;
72
73 /* CvNode2D is used for lists, representing 2D sparse matrix */
74 typedef struct CvNode2D
75 {
76     float val;
77     struct CvNode2D *next[2];  /* next row & next column */
78     int i, j;
79 }
80 CvNode2D;
81
82
83 typedef struct CvEMDState
84 {
85     int ssize, dsize;
86
87     float **cost;
88     CvNode2D *_x;
89     CvNode2D *end_x;
90     CvNode2D *enter_x;
91     char **is_x;
92
93     CvNode2D **rows_x;
94     CvNode2D **cols_x;
95
96     CvNode1D *u;
97     CvNode1D *v;
98
99     int* idx1;
100     int* idx2;
101
102     /* find_loop buffers */
103     CvNode2D **loop;
104     char *is_used;
105
106     /* russel buffers */
107     float *s;
108     float *d;
109     float **delta;
110
111     float weight, max_cost;
112     char *buffer;
113 }
114 CvEMDState;
115
116 /* static function declaration */
117 static int icvInitEMD( const float *signature1, int size1,
118                        const float *signature2, int size2,
119                        int dims, CvDistanceFunction dist_func, void *user_param,
120                        const float* cost, int cost_step,
121                        CvEMDState * state, float *lower_bound,
122                        cv::AutoBuffer<char>& _buffer );
123
124 static int icvFindBasicVariables( float **cost, char **is_x,
125                                   CvNode1D * u, CvNode1D * v, int ssize, int dsize );
126
127 static float icvIsOptimal( float **cost, char **is_x,
128                            CvNode1D * u, CvNode1D * v,
129                            int ssize, int dsize, CvNode2D * enter_x );
130
131 static void icvRussel( CvEMDState * state );
132
133
134 static bool icvNewSolution( CvEMDState * state );
135 static int icvFindLoop( CvEMDState * state );
136
137 static void icvAddBasicVariable( CvEMDState * state,
138                                  int min_i, int min_j,
139                                  CvNode1D * prev_u_min_i,
140                                  CvNode1D * prev_v_min_j,
141                                  CvNode1D * u_head );
142
143 static float icvDistL2( const float *x, const float *y, void *user_param );
144 static float icvDistL1( const float *x, const float *y, void *user_param );
145 static float icvDistC( const float *x, const float *y, void *user_param );
146
147 /* The main function */
148 CV_IMPL float cvCalcEMD2( const CvArr* signature_arr1,
149             const CvArr* signature_arr2,
150             int dist_type,
151             CvDistanceFunction dist_func,
152             const CvArr* cost_matrix,
153             CvArr* flow_matrix,
154             float *lower_bound,
155             void *user_param )
156 {
157     cv::AutoBuffer<char> local_buf;
158     CvEMDState state;
159     float emd = 0;
160
161     memset( &state, 0, sizeof(state));
162
163     double total_cost = 0;
164     int result = 0;
165     float eps, min_delta;
166     CvNode2D *xp = 0;
167     CvMat sign_stub1, *signature1 = (CvMat*)signature_arr1;
168     CvMat sign_stub2, *signature2 = (CvMat*)signature_arr2;
169     CvMat cost_stub, *cost = &cost_stub;
170     CvMat flow_stub, *flow = (CvMat*)flow_matrix;
171     int dims, size1, size2;
172
173     signature1 = cvGetMat( signature1, &sign_stub1 );
174     signature2 = cvGetMat( signature2, &sign_stub2 );
175
176     if( signature1->cols != signature2->cols )
177         CV_Error( CV_StsUnmatchedSizes, "The arrays must have equal number of columns (which is number of dimensions but 1)" );
178
179     dims = signature1->cols - 1;
180     size1 = signature1->rows;
181     size2 = signature2->rows;
182
183     if( !CV_ARE_TYPES_EQ( signature1, signature2 ))
184         CV_Error( CV_StsUnmatchedFormats, "The array must have equal types" );
185
186     if( CV_MAT_TYPE( signature1->type ) != CV_32FC1 )
187         CV_Error( CV_StsUnsupportedFormat, "The signatures must be 32fC1" );
188
189     if( flow )
190     {
191         flow = cvGetMat( flow, &flow_stub );
192
193         if( flow->rows != size1 || flow->cols != size2 )
194             CV_Error( CV_StsUnmatchedSizes,
195             "The flow matrix size does not match to the signatures' sizes" );
196
197         if( CV_MAT_TYPE( flow->type ) != CV_32FC1 )
198             CV_Error( CV_StsUnsupportedFormat, "The flow matrix must be 32fC1" );
199     }
200
201     cost->data.fl = 0;
202     cost->step = 0;
203
204     if( dist_type < 0 )
205     {
206         if( cost_matrix )
207         {
208             if( dist_func )
209                 CV_Error( CV_StsBadArg,
210                 "Only one of cost matrix or distance function should be non-NULL in case of user-defined distance" );
211
212             if( lower_bound )
213                 CV_Error( CV_StsBadArg,
214                 "The lower boundary can not be calculated if the cost matrix is used" );
215
216             cost = cvGetMat( cost_matrix, &cost_stub );
217             if( cost->rows != size1 || cost->cols != size2 )
218                 CV_Error( CV_StsUnmatchedSizes,
219                 "The cost matrix size does not match to the signatures' sizes" );
220
221             if( CV_MAT_TYPE( cost->type ) != CV_32FC1 )
222                 CV_Error( CV_StsUnsupportedFormat, "The cost matrix must be 32fC1" );
223         }
224         else if( !dist_func )
225             CV_Error( CV_StsNullPtr, "In case of user-defined distance Distance function is undefined" );
226     }
227     else
228     {
229         if( dims == 0 )
230             CV_Error( CV_StsBadSize,
231             "Number of dimensions can be 0 only if a user-defined metric is used" );
232         user_param = (void *) (size_t)dims;
233         switch (dist_type)
234         {
235         case CV_DIST_L1:
236             dist_func = icvDistL1;
237             break;
238         case CV_DIST_L2:
239             dist_func = icvDistL2;
240             break;
241         case CV_DIST_C:
242             dist_func = icvDistC;
243             break;
244         default:
245             CV_Error( CV_StsBadFlag, "Bad or unsupported metric type" );
246         }
247     }
248
249     result = icvInitEMD( signature1->data.fl, size1,
250                         signature2->data.fl, size2,
251                         dims, dist_func, user_param,
252                         cost->data.fl, cost->step,
253                         &state, lower_bound, local_buf );
254
255     if( result > 0 && lower_bound )
256     {
257         emd = *lower_bound;
258         return emd;
259     }
260
261     eps = CV_EMD_EPS * state.max_cost;
262
263     /* if ssize = 1 or dsize = 1 then we are done, else ... */
264     if( state.ssize > 1 && state.dsize > 1 )
265     {
266         int itr;
267
268         for( itr = 1; itr < MAX_ITERATIONS; itr++ )
269         {
270             /* find basic variables */
271             result = icvFindBasicVariables( state.cost, state.is_x,
272                                             state.u, state.v, state.ssize, state.dsize );
273             if( result < 0 )
274                 break;
275
276             /* check for optimality */
277             min_delta = icvIsOptimal( state.cost, state.is_x,
278                                       state.u, state.v,
279                                       state.ssize, state.dsize, state.enter_x );
280
281             if( min_delta == CV_EMD_INF )
282                 CV_Error( CV_StsNoConv, "" );
283
284             /* if no negative deltamin, we found the optimal solution */
285             if( min_delta >= -eps )
286                 break;
287
288             /* improve solution */
289             if(!icvNewSolution( &state ))
290                 CV_Error( CV_StsNoConv, "" );
291         }
292     }
293
294     /* compute the total flow */
295     for( xp = state._x; xp < state.end_x; xp++ )
296     {
297         float val = xp->val;
298         int i = xp->i;
299         int j = xp->j;
300
301         if( xp == state.enter_x )
302           continue;
303
304         int ci = state.idx1[i];
305         int cj = state.idx2[j];
306
307         if( ci >= 0 && cj >= 0 )
308         {
309             total_cost += (double)val * state.cost[i][j];
310             if( flow )
311                 ((float*)(flow->data.ptr + flow->step*ci))[cj] = val;
312         }
313     }
314
315     emd = (float) (total_cost / state.weight);
316     return emd;
317 }
318
319
320 /************************************************************************************\
321 *          initialize structure, allocate buffers and generate initial golution      *
322 \************************************************************************************/
323 static int icvInitEMD( const float* signature1, int size1,
324             const float* signature2, int size2,
325             int dims, CvDistanceFunction dist_func, void* user_param,
326             const float* cost, int cost_step,
327             CvEMDState* state, float* lower_bound,
328             cv::AutoBuffer<char>& _buffer )
329 {
330     float s_sum = 0, d_sum = 0, diff;
331     int i, j;
332     int ssize = 0, dsize = 0;
333     int equal_sums = 1;
334     int buffer_size;
335     float max_cost = 0;
336     char *buffer, *buffer_end;
337
338     memset( state, 0, sizeof( *state ));
339     assert( cost_step % sizeof(float) == 0 );
340     cost_step /= sizeof(float);
341
342     /* calculate buffer size */
343     buffer_size = (size1+1) * (size2+1) * (sizeof( float ) +    /* cost */
344                                    sizeof( char ) +     /* is_x */
345                                    sizeof( float )) +   /* delta matrix */
346         (size1 + size2 + 2) * (sizeof( CvNode2D ) + /* _x */
347                            sizeof( CvNode2D * ) +  /* cols_x & rows_x */
348                            sizeof( CvNode1D ) + /* u & v */
349                            sizeof( float ) + /* s & d */
350                            sizeof( int ) + sizeof(CvNode2D*)) +  /* idx1 & idx2 */
351         (size1+1) * (sizeof( float * ) + sizeof( char * ) + /* rows pointers for */
352                  sizeof( float * )) + 256;      /*  cost, is_x and delta */
353
354     if( buffer_size < (int) (dims * 2 * sizeof( float )))
355     {
356         buffer_size = dims * 2 * sizeof( float );
357     }
358
359     /* allocate buffers */
360     _buffer.allocate(buffer_size);
361
362     state->buffer = buffer = _buffer;
363     buffer_end = buffer + buffer_size;
364
365     state->idx1 = (int*) buffer;
366     buffer += (size1 + 1) * sizeof( int );
367
368     state->idx2 = (int*) buffer;
369     buffer += (size2 + 1) * sizeof( int );
370
371     state->s = (float *) buffer;
372     buffer += (size1 + 1) * sizeof( float );
373
374     state->d = (float *) buffer;
375     buffer += (size2 + 1) * sizeof( float );
376
377     /* sum up the supply and demand */
378     for( i = 0; i < size1; i++ )
379     {
380         float weight = signature1[i * (dims + 1)];
381
382         if( weight > 0 )
383         {
384             s_sum += weight;
385             state->s[ssize] = weight;
386             state->idx1[ssize++] = i;
387
388         }
389         else if( weight < 0 )
390             CV_Error(CV_StsBadArg, "signature1 must not contain negative weights");
391     }
392
393     for( i = 0; i < size2; i++ )
394     {
395         float weight = signature2[i * (dims + 1)];
396
397         if( weight > 0 )
398         {
399             d_sum += weight;
400             state->d[dsize] = weight;
401             state->idx2[dsize++] = i;
402         }
403         else if( weight < 0 )
404             CV_Error(CV_StsBadArg, "signature2 must not contain negative weights");
405     }
406
407     if( ssize == 0 )
408         CV_Error(CV_StsBadArg, "signature1 must contain at least one non-zero value");
409     if( dsize == 0 )
410         CV_Error(CV_StsBadArg, "signature2 must contain at least one non-zero value");
411
412     /* if supply different than the demand, add a zero-cost dummy cluster */
413     diff = s_sum - d_sum;
414     if( fabs( diff ) >= CV_EMD_EPS * s_sum )
415     {
416         equal_sums = 0;
417         if( diff < 0 )
418         {
419             state->s[ssize] = -diff;
420             state->idx1[ssize++] = -1;
421         }
422         else
423         {
424             state->d[dsize] = diff;
425             state->idx2[dsize++] = -1;
426         }
427     }
428
429     state->ssize = ssize;
430     state->dsize = dsize;
431     state->weight = s_sum > d_sum ? s_sum : d_sum;
432
433     if( lower_bound && equal_sums )     /* check lower bound */
434     {
435         int sz1 = size1 * (dims + 1), sz2 = size2 * (dims + 1);
436         float lb = 0;
437
438         float* xs = (float *) buffer;
439         float* xd = xs + dims;
440
441         memset( xs, 0, dims*sizeof(xs[0]));
442         memset( xd, 0, dims*sizeof(xd[0]));
443
444         for( j = 0; j < sz1; j += dims + 1 )
445         {
446             float weight = signature1[j];
447             for( i = 0; i < dims; i++ )
448                 xs[i] += signature1[j + i + 1] * weight;
449         }
450
451         for( j = 0; j < sz2; j += dims + 1 )
452         {
453             float weight = signature2[j];
454             for( i = 0; i < dims; i++ )
455                 xd[i] += signature2[j + i + 1] * weight;
456         }
457
458         lb = dist_func( xs, xd, user_param ) / state->weight;
459         i = *lower_bound <= lb;
460         *lower_bound = lb;
461         if( i )
462             return 1;
463     }
464
465     /* assign pointers */
466     state->is_used = (char *) buffer;
467     /* init delta matrix */
468     state->delta = (float **) buffer;
469     buffer += ssize * sizeof( float * );
470
471     for( i = 0; i < ssize; i++ )
472     {
473         state->delta[i] = (float *) buffer;
474         buffer += dsize * sizeof( float );
475     }
476
477     state->loop = (CvNode2D **) buffer;
478     buffer += (ssize + dsize + 1) * sizeof(CvNode2D*);
479
480     state->_x = state->end_x = (CvNode2D *) buffer;
481     buffer += (ssize + dsize) * sizeof( CvNode2D );
482
483     /* init cost matrix */
484     state->cost = (float **) buffer;
485     buffer += ssize * sizeof( float * );
486
487     /* compute the distance matrix */
488     for( i = 0; i < ssize; i++ )
489     {
490         int ci = state->idx1[i];
491
492         state->cost[i] = (float *) buffer;
493         buffer += dsize * sizeof( float );
494
495         if( ci >= 0 )
496         {
497             for( j = 0; j < dsize; j++ )
498             {
499                 int cj = state->idx2[j];
500                 if( cj < 0 )
501                     state->cost[i][j] = 0;
502                 else
503                 {
504                     float val;
505                     if( dist_func )
506                     {
507                         val = dist_func( signature1 + ci * (dims + 1) + 1,
508                                          signature2 + cj * (dims + 1) + 1,
509                                          user_param );
510                     }
511                     else
512                     {
513                         assert( cost );
514                         val = cost[cost_step*ci + cj];
515                     }
516                     state->cost[i][j] = val;
517                     if( max_cost < val )
518                         max_cost = val;
519                 }
520             }
521         }
522         else
523         {
524             for( j = 0; j < dsize; j++ )
525                 state->cost[i][j] = 0;
526         }
527     }
528
529     state->max_cost = max_cost;
530
531     memset( buffer, 0, buffer_end - buffer );
532
533     state->rows_x = (CvNode2D **) buffer;
534     buffer += ssize * sizeof( CvNode2D * );
535
536     state->cols_x = (CvNode2D **) buffer;
537     buffer += dsize * sizeof( CvNode2D * );
538
539     state->u = (CvNode1D *) buffer;
540     buffer += ssize * sizeof( CvNode1D );
541
542     state->v = (CvNode1D *) buffer;
543     buffer += dsize * sizeof( CvNode1D );
544
545     /* init is_x matrix */
546     state->is_x = (char **) buffer;
547     buffer += ssize * sizeof( char * );
548
549     for( i = 0; i < ssize; i++ )
550     {
551         state->is_x[i] = buffer;
552         buffer += dsize;
553     }
554
555     assert( buffer <= buffer_end );
556
557     icvRussel( state );
558
559     state->enter_x = (state->end_x)++;
560     return 0;
561 }
562
563
564 /****************************************************************************************\
565 *                              icvFindBasicVariables                                   *
566 \****************************************************************************************/
567 static int icvFindBasicVariables( float **cost, char **is_x,
568                        CvNode1D * u, CvNode1D * v, int ssize, int dsize )
569 {
570     int i, j, found;
571     int u_cfound, v_cfound;
572     CvNode1D u0_head, u1_head, *cur_u, *prev_u;
573     CvNode1D v0_head, v1_head, *cur_v, *prev_v;
574
575     /* initialize the rows list (u) and the columns list (v) */
576     u0_head.next = u;
577     for( i = 0; i < ssize; i++ )
578     {
579         u[i].next = u + i + 1;
580     }
581     u[ssize - 1].next = 0;
582     u1_head.next = 0;
583
584     v0_head.next = ssize > 1 ? v + 1 : 0;
585     for( i = 1; i < dsize; i++ )
586     {
587         v[i].next = v + i + 1;
588     }
589     v[dsize - 1].next = 0;
590     v1_head.next = 0;
591
592     /* there are ssize+dsize variables but only ssize+dsize-1 independent equations,
593        so set v[0]=0 */
594     v[0].val = 0;
595     v1_head.next = v;
596     v1_head.next->next = 0;
597
598     /* loop until all variables are found */
599     u_cfound = v_cfound = 0;
600     while( u_cfound < ssize || v_cfound < dsize )
601     {
602         found = 0;
603         if( v_cfound < dsize )
604         {
605             /* loop over all marked columns */
606             prev_v = &v1_head;
607
608             for( found |= (cur_v = v1_head.next) != 0; cur_v != 0; cur_v = cur_v->next )
609             {
610                 float cur_v_val = cur_v->val;
611
612                 j = (int)(cur_v - v);
613                 /* find the variables in column j */
614                 prev_u = &u0_head;
615                 for( cur_u = u0_head.next; cur_u != 0; )
616                 {
617                     i = (int)(cur_u - u);
618                     if( is_x[i][j] )
619                     {
620                         /* compute u[i] */
621                         cur_u->val = cost[i][j] - cur_v_val;
622                         /* ...and add it to the marked list */
623                         prev_u->next = cur_u->next;
624                         cur_u->next = u1_head.next;
625                         u1_head.next = cur_u;
626                         cur_u = prev_u->next;
627                     }
628                     else
629                     {
630                         prev_u = cur_u;
631                         cur_u = cur_u->next;
632                     }
633                 }
634                 prev_v->next = cur_v->next;
635                 v_cfound++;
636             }
637         }
638
639         if( u_cfound < ssize )
640         {
641             /* loop over all marked rows */
642             prev_u = &u1_head;
643             for( found |= (cur_u = u1_head.next) != 0; cur_u != 0; cur_u = cur_u->next )
644             {
645                 float cur_u_val = cur_u->val;
646                 float *_cost;
647                 char *_is_x;
648
649                 i = (int)(cur_u - u);
650                 _cost = cost[i];
651                 _is_x = is_x[i];
652                 /* find the variables in rows i */
653                 prev_v = &v0_head;
654                 for( cur_v = v0_head.next; cur_v != 0; )
655                 {
656                     j = (int)(cur_v - v);
657                     if( _is_x[j] )
658                     {
659                         /* compute v[j] */
660                         cur_v->val = _cost[j] - cur_u_val;
661                         /* ...and add it to the marked list */
662                         prev_v->next = cur_v->next;
663                         cur_v->next = v1_head.next;
664                         v1_head.next = cur_v;
665                         cur_v = prev_v->next;
666                     }
667                     else
668                     {
669                         prev_v = cur_v;
670                         cur_v = cur_v->next;
671                     }
672                 }
673                 prev_u->next = cur_u->next;
674                 u_cfound++;
675             }
676         }
677
678         if( !found )
679             return -1;
680     }
681
682     return 0;
683 }
684
685
686 /****************************************************************************************\
687 *                                   icvIsOptimal                                       *
688 \****************************************************************************************/
689 static float
690 icvIsOptimal( float **cost, char **is_x,
691               CvNode1D * u, CvNode1D * v, int ssize, int dsize, CvNode2D * enter_x )
692 {
693     float delta, min_delta = CV_EMD_INF;
694     int i, j, min_i = 0, min_j = 0;
695
696     /* find the minimal cij-ui-vj over all i,j */
697     for( i = 0; i < ssize; i++ )
698     {
699         float u_val = u[i].val;
700         float *_cost = cost[i];
701         char *_is_x = is_x[i];
702
703         for( j = 0; j < dsize; j++ )
704         {
705             if( !_is_x[j] )
706             {
707                 delta = _cost[j] - u_val - v[j].val;
708                 if( min_delta > delta )
709                 {
710                     min_delta = delta;
711                     min_i = i;
712                     min_j = j;
713                 }
714             }
715         }
716     }
717
718     enter_x->i = min_i;
719     enter_x->j = min_j;
720
721     return min_delta;
722 }
723
724 /****************************************************************************************\
725 *                                   icvNewSolution                                     *
726 \****************************************************************************************/
727 static bool
728 icvNewSolution( CvEMDState * state )
729 {
730     int i, j;
731     float min_val = CV_EMD_INF;
732     int steps;
733     CvNode2D head, *cur_x, *next_x, *leave_x = 0;
734     CvNode2D *enter_x = state->enter_x;
735     CvNode2D **loop = state->loop;
736
737     /* enter the new basic variable */
738     i = enter_x->i;
739     j = enter_x->j;
740     state->is_x[i][j] = 1;
741     enter_x->next[0] = state->rows_x[i];
742     enter_x->next[1] = state->cols_x[j];
743     enter_x->val = 0;
744     state->rows_x[i] = enter_x;
745     state->cols_x[j] = enter_x;
746
747     /* find a chain reaction */
748     steps = icvFindLoop( state );
749
750     if( steps == 0 )
751         return false;
752
753     /* find the largest value in the loop */
754     for( i = 1; i < steps; i += 2 )
755     {
756         float temp = loop[i]->val;
757
758         if( min_val > temp )
759         {
760             leave_x = loop[i];
761             min_val = temp;
762         }
763     }
764
765     /* update the loop */
766     for( i = 0; i < steps; i += 2 )
767     {
768         float temp0 = loop[i]->val + min_val;
769         float temp1 = loop[i + 1]->val - min_val;
770
771         loop[i]->val = temp0;
772         loop[i + 1]->val = temp1;
773     }
774
775     /* remove the leaving basic variable */
776     i = leave_x->i;
777     j = leave_x->j;
778     state->is_x[i][j] = 0;
779
780     head.next[0] = state->rows_x[i];
781     cur_x = &head;
782     while( (next_x = cur_x->next[0]) != leave_x )
783     {
784         cur_x = next_x;
785         assert( cur_x );
786     }
787     cur_x->next[0] = next_x->next[0];
788     state->rows_x[i] = head.next[0];
789
790     head.next[1] = state->cols_x[j];
791     cur_x = &head;
792     while( (next_x = cur_x->next[1]) != leave_x )
793     {
794         cur_x = next_x;
795         assert( cur_x );
796     }
797     cur_x->next[1] = next_x->next[1];
798     state->cols_x[j] = head.next[1];
799
800     /* set enter_x to be the new empty slot */
801     state->enter_x = leave_x;
802
803     return true;
804 }
805
806
807
808 /****************************************************************************************\
809 *                                    icvFindLoop                                       *
810 \****************************************************************************************/
811 static int
812 icvFindLoop( CvEMDState * state )
813 {
814     int i, steps = 1;
815     CvNode2D *new_x;
816     CvNode2D **loop = state->loop;
817     CvNode2D *enter_x = state->enter_x, *_x = state->_x;
818     char *is_used = state->is_used;
819
820     memset( is_used, 0, state->ssize + state->dsize );
821
822     new_x = loop[0] = enter_x;
823     is_used[enter_x - _x] = 1;
824     steps = 1;
825
826     do
827     {
828         if( (steps & 1) == 1 )
829         {
830             /* find an unused x in the row */
831             new_x = state->rows_x[new_x->i];
832             while( new_x != 0 && is_used[new_x - _x] )
833                 new_x = new_x->next[0];
834         }
835         else
836         {
837             /* find an unused x in the column, or the entering x */
838             new_x = state->cols_x[new_x->j];
839             while( new_x != 0 && is_used[new_x - _x] && new_x != enter_x )
840                 new_x = new_x->next[1];
841             if( new_x == enter_x )
842                 break;
843         }
844
845         if( new_x != 0 )        /* found the next x */
846         {
847             /* add x to the loop */
848             loop[steps++] = new_x;
849             is_used[new_x - _x] = 1;
850         }
851         else                    /* didn't find the next x */
852         {
853             /* backtrack */
854             do
855             {
856                 i = steps & 1;
857                 new_x = loop[steps - 1];
858                 do
859                 {
860                     new_x = new_x->next[i];
861                 }
862                 while( new_x != 0 && is_used[new_x - _x] );
863
864                 if( new_x == 0 )
865                 {
866                     is_used[loop[--steps] - _x] = 0;
867                 }
868             }
869             while( new_x == 0 && steps > 0 );
870
871             is_used[loop[steps - 1] - _x] = 0;
872             loop[steps - 1] = new_x;
873             is_used[new_x - _x] = 1;
874         }
875     }
876     while( steps > 0 );
877
878     return steps;
879 }
880
881
882
883 /****************************************************************************************\
884 *                                        icvRussel                                     *
885 \****************************************************************************************/
886 static void
887 icvRussel( CvEMDState * state )
888 {
889     int i, j, min_i = -1, min_j = -1;
890     float min_delta, diff;
891     CvNode1D u_head, *cur_u, *prev_u;
892     CvNode1D v_head, *cur_v, *prev_v;
893     CvNode1D *prev_u_min_i = 0, *prev_v_min_j = 0, *remember;
894     CvNode1D *u = state->u, *v = state->v;
895     int ssize = state->ssize, dsize = state->dsize;
896     float eps = CV_EMD_EPS * state->max_cost;
897     float **cost = state->cost;
898     float **delta = state->delta;
899
900     /* initialize the rows list (ur), and the columns list (vr) */
901     u_head.next = u;
902     for( i = 0; i < ssize; i++ )
903     {
904         u[i].next = u + i + 1;
905     }
906     u[ssize - 1].next = 0;
907
908     v_head.next = v;
909     for( i = 0; i < dsize; i++ )
910     {
911         v[i].val = -CV_EMD_INF;
912         v[i].next = v + i + 1;
913     }
914     v[dsize - 1].next = 0;
915
916     /* find the maximum row and column values (ur[i] and vr[j]) */
917     for( i = 0; i < ssize; i++ )
918     {
919         float u_val = -CV_EMD_INF;
920         float *cost_row = cost[i];
921
922         for( j = 0; j < dsize; j++ )
923         {
924             float temp = cost_row[j];
925
926             if( u_val < temp )
927                 u_val = temp;
928             if( v[j].val < temp )
929                 v[j].val = temp;
930         }
931         u[i].val = u_val;
932     }
933
934     /* compute the delta matrix */
935     for( i = 0; i < ssize; i++ )
936     {
937         float u_val = u[i].val;
938         float *delta_row = delta[i];
939         float *cost_row = cost[i];
940
941         for( j = 0; j < dsize; j++ )
942         {
943             delta_row[j] = cost_row[j] - u_val - v[j].val;
944         }
945     }
946
947     /* find the basic variables */
948     do
949     {
950         /* find the smallest delta[i][j] */
951         min_i = -1;
952         min_delta = CV_EMD_INF;
953         prev_u = &u_head;
954         for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
955         {
956             i = (int)(cur_u - u);
957             float *delta_row = delta[i];
958
959             prev_v = &v_head;
960             for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
961             {
962                 j = (int)(cur_v - v);
963                 if( min_delta > delta_row[j] )
964                 {
965                     min_delta = delta_row[j];
966                     min_i = i;
967                     min_j = j;
968                     prev_u_min_i = prev_u;
969                     prev_v_min_j = prev_v;
970                 }
971                 prev_v = cur_v;
972             }
973             prev_u = cur_u;
974         }
975
976         if( min_i < 0 )
977             break;
978
979         /* add x[min_i][min_j] to the basis, and adjust supplies and cost */
980         remember = prev_u_min_i->next;
981         icvAddBasicVariable( state, min_i, min_j, prev_u_min_i, prev_v_min_j, &u_head );
982
983         /* update the necessary delta[][] */
984         if( remember == prev_u_min_i->next )    /* line min_i was deleted */
985         {
986             for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
987             {
988                 j = (int)(cur_v - v);
989                 if( cur_v->val == cost[min_i][j] )      /* column j needs updating */
990                 {
991                     float max_val = -CV_EMD_INF;
992
993                     /* find the new maximum value in the column */
994                     for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
995                     {
996                         float temp = cost[cur_u - u][j];
997
998                         if( max_val < temp )
999                             max_val = temp;
1000                     }
1001
1002                     /* if needed, adjust the relevant delta[*][j] */
1003                     diff = max_val - cur_v->val;
1004                     cur_v->val = max_val;
1005                     if( fabs( diff ) < eps )
1006                     {
1007                         for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
1008                             delta[cur_u - u][j] += diff;
1009                     }
1010                 }
1011             }
1012         }
1013         else                    /* column min_j was deleted */
1014         {
1015             for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
1016             {
1017                 i = (int)(cur_u - u);
1018                 if( cur_u->val == cost[i][min_j] )      /* row i needs updating */
1019                 {
1020                     float max_val = -CV_EMD_INF;
1021
1022                     /* find the new maximum value in the row */
1023                     for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
1024                     {
1025                         float temp = cost[i][cur_v - v];
1026
1027                         if( max_val < temp )
1028                             max_val = temp;
1029                     }
1030
1031                     /* if needed, adjust the relevant delta[i][*] */
1032                     diff = max_val - cur_u->val;
1033                     cur_u->val = max_val;
1034
1035                     if( fabs( diff ) < eps )
1036                     {
1037                         for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
1038                             delta[i][cur_v - v] += diff;
1039                     }
1040                 }
1041             }
1042         }
1043     }
1044     while( u_head.next != 0 || v_head.next != 0 );
1045 }
1046
1047
1048
1049 /****************************************************************************************\
1050 *                                   icvAddBasicVariable                                *
1051 \****************************************************************************************/
1052 static void
1053 icvAddBasicVariable( CvEMDState * state,
1054                      int min_i, int min_j,
1055                      CvNode1D * prev_u_min_i, CvNode1D * prev_v_min_j, CvNode1D * u_head )
1056 {
1057     float temp;
1058     CvNode2D *end_x = state->end_x;
1059
1060     if( state->s[min_i] < state->d[min_j] + state->weight * CV_EMD_EPS )
1061     {                           /* supply exhausted */
1062         temp = state->s[min_i];
1063         state->s[min_i] = 0;
1064         state->d[min_j] -= temp;
1065     }
1066     else                        /* demand exhausted */
1067     {
1068         temp = state->d[min_j];
1069         state->d[min_j] = 0;
1070         state->s[min_i] -= temp;
1071     }
1072
1073     /* x(min_i,min_j) is a basic variable */
1074     state->is_x[min_i][min_j] = 1;
1075
1076     end_x->val = temp;
1077     end_x->i = min_i;
1078     end_x->j = min_j;
1079     end_x->next[0] = state->rows_x[min_i];
1080     end_x->next[1] = state->cols_x[min_j];
1081     state->rows_x[min_i] = end_x;
1082     state->cols_x[min_j] = end_x;
1083     state->end_x = end_x + 1;
1084
1085     /* delete supply row only if the empty, and if not last row */
1086     if( state->s[min_i] == 0 && u_head->next->next != 0 )
1087         prev_u_min_i->next = prev_u_min_i->next->next;  /* remove row from list */
1088     else
1089         prev_v_min_j->next = prev_v_min_j->next->next;  /* remove column from list */
1090 }
1091
1092
1093 /****************************************************************************************\
1094 *                                  standard  metrics                                     *
1095 \****************************************************************************************/
1096 static float
1097 icvDistL1( const float *x, const float *y, void *user_param )
1098 {
1099     int i, dims = (int)(size_t)user_param;
1100     double s = 0;
1101
1102     for( i = 0; i < dims; i++ )
1103     {
1104         double t = x[i] - y[i];
1105
1106         s += fabs( t );
1107     }
1108     return (float)s;
1109 }
1110
1111 static float
1112 icvDistL2( const float *x, const float *y, void *user_param )
1113 {
1114     int i, dims = (int)(size_t)user_param;
1115     double s = 0;
1116
1117     for( i = 0; i < dims; i++ )
1118     {
1119         double t = x[i] - y[i];
1120
1121         s += t * t;
1122     }
1123     return cvSqrt( (float)s );
1124 }
1125
1126 static float
1127 icvDistC( const float *x, const float *y, void *user_param )
1128 {
1129     int i, dims = (int)(size_t)user_param;
1130     double s = 0;
1131
1132     for( i = 0; i < dims; i++ )
1133     {
1134         double t = fabs( x[i] - y[i] );
1135
1136         if( s < t )
1137             s = t;
1138     }
1139     return (float)s;
1140 }
1141
1142
1143 float cv::EMD( InputArray _signature1, InputArray _signature2,
1144                int distType, InputArray _cost,
1145                float* lowerBound, OutputArray _flow )
1146 {
1147     CV_INSTRUMENT_REGION()
1148
1149     Mat signature1 = _signature1.getMat(), signature2 = _signature2.getMat();
1150     Mat cost = _cost.getMat(), flow;
1151
1152     CvMat _csignature1 = signature1;
1153     CvMat _csignature2 = signature2;
1154     CvMat _ccost = cost, _cflow;
1155     if( _flow.needed() )
1156     {
1157         _flow.create(signature1.rows, signature2.rows, CV_32F);
1158         flow = _flow.getMat();
1159         flow = Scalar::all(0);
1160         _cflow = flow;
1161     }
1162
1163     return cvCalcEMD2( &_csignature1, &_csignature2, distType, 0, cost.empty() ? 0 : &_ccost,
1164                        _flow.needed() ? &_cflow : 0, lowerBound, 0 );
1165 }
1166
1167 float cv::wrapperEMD(InputArray _signature1, InputArray _signature2,
1168                int distType, InputArray _cost,
1169                Ptr<float> lowerBound, OutputArray _flow)
1170 {
1171     return EMD(_signature1, _signature2, distType, _cost, lowerBound.get(), _flow);
1172 }
1173
1174 /* End of file. */