4b9f24f1bc1abd8fbdae24fd6890cd72a920df9e
[platform/upstream/opencv.git] / apps / haartraining / cvboost.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 //   * Redistribution's of source code must retain the above copyright notice,
20 //     this list of conditions and the following disclaimer.
21 //
22 //   * Redistribution's in binary form must reproduce the above copyright notice,
23 //     this list of conditions and the following disclaimer in the documentation
24 //     and/or other materials provided with the distribution.
25 //
26 //   * The name of Intel Corporation may not be used to endorse or promote products
27 //     derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
39 //
40 //M*/
41
42 #include "cvconfig.h"
43
44 #ifdef HAVE_MALLOC_H
45   #include <malloc.h>
46 #endif
47
48 #ifdef HAVE_MEMORY_H
49   #include <memory.h>
50 #endif
51
52 #ifdef _OPENMP
53   #include <omp.h>
54 #endif /* _OPENMP */
55
56 #include <cstdio>
57 #include <cfloat>
58 #include <cmath>
59 #include <ctime>
60 #include <climits>
61
62 #include "_cvcommon.h"
63 #include "cvclassifier.h"
64
65 #ifdef _OPENMP
66 #include "omp.h"
67 #endif
68
69 #define CV_BOOST_IMPL
70
/*
 * Strided view over a sequence of float values.
 * LessThanValArray reads element k as *(float*) (data + k * step).
 */
typedef struct CvValArray
{
    uchar* data;   /* address of element 0 */
    size_t step;   /* distance in bytes between consecutive elements */
} CvValArray;
76
/*
 * Comparison functor for std::sort over index arrays.
 * An index is "less" than another when the float it refers to in the
 * strided array *aux is smaller.  Indices are converted to int before use.
 */
template<typename T, typename Idx>
class LessThanValArray
{
public:
    LessThanValArray( const T* _aux ) : aux(_aux) {}

    bool operator()(Idx a, Idx b) const
    {
        const float lhs = valueAt( a );
        const float rhs = valueAt( b );
        return lhs < rhs;
    }

    const T* aux;

private:
    /* fetches the float addressed by index k */
    float valueAt( Idx k ) const
    {
        return *( (float*) (aux->data + ((int) (k)) * aux->step ) );
    }
};
89
90 CV_BOOST_IMPL
91 void cvGetSortedIndices( CvMat* val, CvMat* idx, int sortcols )
92 {
93     int idxtype = 0;
94     size_t istep = 0;
95     size_t jstep = 0;
96
97     int i = 0;
98     int j = 0;
99
100     CvValArray va;
101
102     CV_Assert( idx != NULL );
103     CV_Assert( val != NULL );
104
105     idxtype = CV_MAT_TYPE( idx->type );
106     CV_Assert( idxtype == CV_16SC1 || idxtype == CV_32SC1 || idxtype == CV_32FC1 );
107     CV_Assert( CV_MAT_TYPE( val->type ) == CV_32FC1 );
108     if( sortcols )
109     {
110         CV_Assert( idx->rows == val->cols );
111         CV_Assert( idx->cols == val->rows );
112         istep = CV_ELEM_SIZE( val->type );
113         jstep = val->step;
114     }
115     else
116     {
117         CV_Assert( idx->rows == val->rows );
118         CV_Assert( idx->cols == val->cols );
119         istep = val->step;
120         jstep = CV_ELEM_SIZE( val->type );
121     }
122
123     va.data = val->data.ptr;
124     va.step = jstep;
125     switch( idxtype )
126     {
127         case CV_16SC1:
128             for( i = 0; i < idx->rows; i++ )
129             {
130                 for( j = 0; j < idx->cols; j++ )
131                 {
132                     CV_MAT_ELEM( *idx, short, i, j ) = (short) j;
133                 }
134                 std::sort((short*) (idx->data.ptr + (size_t)i * idx->step),
135                           (short*) (idx->data.ptr + (size_t)i * idx->step) + idx->cols,
136                           LessThanValArray<CvValArray, short>(&va));
137                 va.data += istep;
138             }
139             break;
140
141         case CV_32SC1:
142             for( i = 0; i < idx->rows; i++ )
143             {
144                 for( j = 0; j < idx->cols; j++ )
145                 {
146                     CV_MAT_ELEM( *idx, int, i, j ) = j;
147                 }
148                 std::sort((int*) (idx->data.ptr + (size_t)i * idx->step),
149                           (int*) (idx->data.ptr + (size_t)i * idx->step) + idx->cols,
150                           LessThanValArray<CvValArray, int>(&va));
151                 va.data += istep;
152             }
153             break;
154
155         case CV_32FC1:
156             for( i = 0; i < idx->rows; i++ )
157             {
158                 for( j = 0; j < idx->cols; j++ )
159                 {
160                     CV_MAT_ELEM( *idx, float, i, j ) = (float) j;
161                 }
162                 std::sort((float*) (idx->data.ptr + (size_t)i * idx->step),
163                           (float*) (idx->data.ptr + (size_t)i * idx->step) + idx->cols,
164                           LessThanValArray<CvValArray, float>(&va));
165                 va.data += istep;
166             }
167             break;
168
169         default:
170             assert( 0 );
171             break;
172     }
173 }
174
175 CV_BOOST_IMPL
176 void cvReleaseStumpClassifier( CvClassifier** classifier )
177 {
178     cvFree( classifier );
179     *classifier = 0;
180 }
181
182 CV_BOOST_IMPL
183 float cvEvalStumpClassifier( CvClassifier* classifier, CvMat* sample )
184 {
185     assert( classifier != NULL );
186     assert( sample != NULL );
187     assert( CV_MAT_TYPE( sample->type ) == CV_32FC1 );
188
189     if( (CV_MAT_ELEM( (*sample), float, 0,
190             ((CvStumpClassifier*) classifier)->compidx )) <
191         ((CvStumpClassifier*) classifier)->threshold )
192         return ((CvStumpClassifier*) classifier)->left;
193     return ((CvStumpClassifier*) classifier)->right;
194 }
195
/*
 * ICV_DEF_FIND_STUMP_THRESHOLD( suffix, type, error )
 *
 * Defines icvFindStumpThreshold_<suffix>(), which scans one feature for the
 * split point with the smallest left + right error.
 *
 *   suffix - appended to the generated function name
 *   type   - element type of the index array (idxdata)
 *   error  - statements that set curlerror and currerror from the running
 *            sums (wl, wyl, wyyl, wr, wyr, *sumwyy); they may also rescale
 *            curleft and curright
 *
 * Parameters of the generated function:
 *   data, datastep   - feature values; the value of sample k is the float at
 *                      data + k * datastep
 *   wdata, wstep     - sample weights, addressed the same way
 *   ydata, ystep     - sample responses, addressed the same way
 *   idxdata, idxstep - num sample indices of type `type`; they must list the
 *                      samples in ascending order of feature value
 *                      (checked by an assert only)
 *   lerror, rerror   - in/out: best left/right error found so far; replaced
 *                      when a candidate has a strictly smaller sum
 *   threshold, left, right - out: written together with lerror/rerror
 *   sumw, sumwy, sumwyy    - in/out: totals of w, w*y and w*y*y over the
 *                      indexed samples; computed here only when
 *                      *sumw == FLT_MAX, otherwise reused as passed in
 *
 * Returns 1 if the outputs were updated, 0 otherwise.
 */
#define ICV_DEF_FIND_STUMP_THRESHOLD( suffix, type, error )                              \
static int icvFindStumpThreshold_##suffix(                                               \
        uchar* data, size_t datastep,                                                    \
        uchar* wdata, size_t wstep,                                                      \
        uchar* ydata, size_t ystep,                                                      \
        uchar* idxdata, size_t idxstep, int num,                                         \
        float* lerror,                                                                   \
        float* rerror,                                                                   \
        float* threshold, float* left, float* right,                                     \
        float* sumw, float* sumwy, float* sumwyy )                                       \
{                                                                                        \
    int found = 0;                                                                       \
    float wyl  = 0.0F;                                                                   \
    float wl   = 0.0F;                                                                   \
    float wyyl = 0.0F;                                                                   \
    float wyr  = 0.0F;                                                                   \
    float wr   = 0.0F;                                                                   \
                                                                                         \
    float curleft  = 0.0F;                                                               \
    float curright = 0.0F;                                                               \
    float* prevval = NULL;                                                               \
    float* curval  = NULL;                                                               \
    float curlerror = 0.0F;                                                              \
    float currerror = 0.0F;                                                              \
                                                                                         \
    int i = 0;                                                                           \
    int idx = 0;                                                                         \
                                                                                         \
    if( *sumw == FLT_MAX )                                                               \
    {                                                                                    \
        /* totals not computed yet (FLT_MAX marker): compute them once */                \
        float *y = NULL;                                                                 \
        float *w = NULL;                                                                 \
        float wy = 0.0F;                                                                 \
                                                                                         \
        *sumw   = 0.0F;                                                                  \
        *sumwy  = 0.0F;                                                                  \
        *sumwyy = 0.0F;                                                                  \
        for( i = 0; i < num; i++ )                                                       \
        {                                                                                \
            idx = (int) ( *((type*) (idxdata + i*idxstep)) );                            \
            w = (float*) (wdata + idx * wstep);                                          \
            *sumw += *w;                                                                 \
            y = (float*) (ydata + idx * ystep);                                          \
            wy = (*w) * (*y);                                                            \
            *sumwy += wy;                                                                \
            *sumwyy += wy * (*y);                                                        \
        }                                                                                \
    }                                                                                    \
                                                                                         \
    for( i = 0; i < num; i++ )                                                           \
    {                                                                                    \
        idx = (int) ( *((type*) (idxdata + i*idxstep)) );                                \
        curval = (float*) (data + idx * datastep);                                       \
        /* debug check: indices must be sorted by ascending value */                     \
        if( i > 0 ) assert( (*prevval) <= (*curval) );                                   \
                                                                                         \
        /* left sums cover the samples below *curval; right = total - left */            \
        wyr  = *sumwy - wyl;                                                             \
        wr   = *sumw  - wl;                                                              \
                                                                                         \
        if( wl > 0.0 ) curleft = wyl / wl;                                               \
        else curleft = 0.0F;                                                             \
                                                                                         \
        if( wr > 0.0 ) curright = wyr / wr;                                              \
        else curright = 0.0F;                                                            \
                                                                                         \
        error                                                                            \
                                                                                         \
        if( curlerror + currerror < (*lerror) + (*rerror) )                              \
        {                                                                                \
            (*lerror) = curlerror;                                                       \
            (*rerror) = currerror;                                                       \
            *threshold = *curval;                                                        \
            /* not the first value: use the midpoint to the previous one */              \
            if( i > 0 ) {                                                                \
                *threshold = 0.5F * (*threshold + *prevval);                             \
            }                                                                            \
            *left  = curleft;                                                            \
            *right = curright;                                                           \
            found = 1;                                                                   \
        }                                                                                \
                                                                                         \
        /* add all samples whose value equals *curval to the left sums */                \
        do                                                                               \
        {                                                                                \
            wl  += *((float*) (wdata + idx * wstep));                                    \
            wyl += (*((float*) (wdata + idx * wstep)))                                   \
                * (*((float*) (ydata + idx * ystep)));                                   \
            wyyl += *((float*) (wdata + idx * wstep))                                    \
                * (*((float*) (ydata + idx * ystep)))                                    \
                * (*((float*) (ydata + idx * ystep)));                                   \
        }                                                                                \
        while( (++i) < num &&                                                            \
            ( *((float*) (data + (idx =                                                  \
                (int) ( *((type*) (idxdata + i*idxstep))) ) * datastep))                 \
                == *curval ) );                                                          \
        --i; /* the for statement increments i again */                                  \
        prevval = curval;                                                                \
    } /* for each distinct value */                                                      \
                                                                                         \
    return found;                                                                        \
}
296
/* misclassification error
 * err = MIN( wpos, wneg );
 *
 * Defines icvFindStumpThreshold_misc_<suffix>.
 * wposl/wposr are computed as 0.5 * (w + wy) on each side; the other class
 * weight is w - wpos.  curleft and curright are mapped through
 * 0.5 * (1 + x) before being reported.
 * NOTE(review): reading wpos as "weight of the positive samples" presumes
 * responses of -1/+1 - confirm against the callers.
 */
#define ICV_DEF_FIND_STUMP_THRESHOLD_MISC( suffix, type )                                \
    ICV_DEF_FIND_STUMP_THRESHOLD( misc_##suffix, type,                                   \
        float wposl = 0.5F * ( wl + wyl );                                               \
        float wposr = 0.5F * ( wr + wyr );                                               \
        curleft = 0.5F * ( 1.0F + curleft );                                             \
        curright = 0.5F * ( 1.0F + curright );                                           \
        curlerror = MIN( wposl, wl - wposl );                                            \
        currerror = MIN( wposr, wr - wposr );                                            \
    )
309
/* gini error
 * err = 2 * wpos * wneg /(wpos + wneg)
 *
 * Defines icvFindStumpThreshold_gini_<suffix>.
 * After the 0.5 * (1 + x) mapping curleft/curright equal wpos / w on their
 * side (0 when the side is empty), so the error is evaluated as
 * 2 * wpos * (1 - wpos / w).
 * NOTE(review): presumes responses of -1/+1 - confirm against the callers.
 */
#define ICV_DEF_FIND_STUMP_THRESHOLD_GINI( suffix, type )                                \
    ICV_DEF_FIND_STUMP_THRESHOLD( gini_##suffix, type,                                   \
        float wposl = 0.5F * ( wl + wyl );                                               \
        float wposr = 0.5F * ( wr + wyr );                                               \
        curleft = 0.5F * ( 1.0F + curleft );                                             \
        curright = 0.5F * ( 1.0F + curright );                                           \
        curlerror = 2.0F * wposl * ( 1.0F - curleft );                                   \
        currerror = 2.0F * wposr * ( 1.0F - curright );                                  \
    )
322
/* bound used below to keep logf() away from an argument of zero */
#define CV_ENTROPY_THRESHOLD FLT_MIN

/* entropy error
 * err = - wpos * log(wpos / (wpos + wneg)) - wneg * log(wneg / (wpos + wneg))
 *
 * Defines icvFindStumpThreshold_entropy_<suffix>.
 * Each of the two terms is added only when its logarithm argument passes the
 * CV_ENTROPY_THRESHOLD test; a skipped term contributes 0.
 * NOTE(review): presumes responses of -1/+1 - confirm against the callers.
 */
#define ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( suffix, type )                             \
    ICV_DEF_FIND_STUMP_THRESHOLD( entropy_##suffix, type,                                \
        float wposl = 0.5F * ( wl + wyl );                                               \
        float wposr = 0.5F * ( wr + wyr );                                               \
        curleft = 0.5F * ( 1.0F + curleft );                                             \
        curright = 0.5F * ( 1.0F + curright );                                           \
        curlerror = currerror = 0.0F;                                                    \
        if( curleft > CV_ENTROPY_THRESHOLD )                                             \
            curlerror -= wposl * logf( curleft );                                        \
        if( curleft < 1.0F - CV_ENTROPY_THRESHOLD )                                      \
            curlerror -= (wl - wposl) * logf( 1.0F - curleft );                          \
                                                                                         \
        if( curright > CV_ENTROPY_THRESHOLD )                                            \
            currerror -= wposr * logf( curright );                                       \
        if( curright < 1.0F - CV_ENTROPY_THRESHOLD )                                     \
            currerror -= (wr - wposr) * logf( 1.0F - curright );                         \
    )
345
/* least sum of squares error
 *
 * Defines icvFindStumpThreshold_sq_<suffix>.
 * curleft/curright are used unchanged (weighted mean response of each side).
 */
#define ICV_DEF_FIND_STUMP_THRESHOLD_SQ( suffix, type )                                  \
    ICV_DEF_FIND_STUMP_THRESHOLD( sq_##suffix, type,                                     \
        /* calculate error (sum of squares)          */                                  \
        /* err = sum( w * (y - left(right)Val)^2 )   */                                  \
        curlerror = wyyl + curleft * curleft * wl - 2.0F * curleft * wyl;                \
        currerror = (*sumwyy) - wyyl + curright * curright * wr - 2.0F * curright * wyr; \
    )
354
/*
 * Generate the threshold search for every error criterion and every index
 * element type:
 *   icvFindStumpThreshold_{misc,gini,entropy,sq}_{16s,32s,32f}
 */
ICV_DEF_FIND_STUMP_THRESHOLD_MISC( 16s, short )

ICV_DEF_FIND_STUMP_THRESHOLD_MISC( 32s, int )

ICV_DEF_FIND_STUMP_THRESHOLD_MISC( 32f, float )


ICV_DEF_FIND_STUMP_THRESHOLD_GINI( 16s, short )

ICV_DEF_FIND_STUMP_THRESHOLD_GINI( 32s, int )

ICV_DEF_FIND_STUMP_THRESHOLD_GINI( 32f, float )


ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( 16s, short )

ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( 32s, int )

ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( 32f, float )


ICV_DEF_FIND_STUMP_THRESHOLD_SQ( 16s, short )

ICV_DEF_FIND_STUMP_THRESHOLD_SQ( 32s, int )

ICV_DEF_FIND_STUMP_THRESHOLD_SQ( 32f, float )
381
/*
 * Signature shared by all icvFindStumpThreshold_* functions generated by
 * ICV_DEF_FIND_STUMP_THRESHOLD; see that macro for the parameter contract.
 */
typedef int (*CvFindThresholdFunc)( uchar* data, size_t datastep,
                                    uchar* wdata, size_t wstep,
                                    uchar* ydata, size_t ystep,
                                    uchar* idxdata, size_t idxstep, int num,
                                    float* lerror,
                                    float* rerror,
                                    float* threshold, float* left, float* right,
                                    float* sumw, float* sumwy, float* sumwyy );
390
/*
 * Dispatch tables, one per index element type.  Slot order in each table:
 * misclassification, gini, entropy, squared error.
 */

/* threshold search over short (16s) index arrays
 * NOTE(review): presumably indexed by the train params' error value like the
 * 32s table below - confirm at the use site. */
CvFindThresholdFunc findStumpThreshold_16s[4] = {
        icvFindStumpThreshold_misc_16s,
        icvFindStumpThreshold_gini_16s,
        icvFindStumpThreshold_entropy_16s,
        icvFindStumpThreshold_sq_16s
    };

/* threshold search over int (32s) index arrays;
 * cvCreateStumpClassifier indexes this table with (int) trainParams->error */
CvFindThresholdFunc findStumpThreshold_32s[4] = {
        icvFindStumpThreshold_misc_32s,
        icvFindStumpThreshold_gini_32s,
        icvFindStumpThreshold_entropy_32s,
        icvFindStumpThreshold_sq_32s
    };

/* threshold search over float (32f) index arrays
 * NOTE(review): presumably indexed the same way - confirm at the use site. */
CvFindThresholdFunc findStumpThreshold_32f[4] = {
        icvFindStumpThreshold_misc_32f,
        icvFindStumpThreshold_gini_32f,
        icvFindStumpThreshold_entropy_32f,
        icvFindStumpThreshold_sq_32f
    };
411
412 CV_BOOST_IMPL
413 CvClassifier* cvCreateStumpClassifier( CvMat* trainData,
414                       int flags,
415                       CvMat* trainClasses,
416                       CvMat* /*typeMask*/,
417                       CvMat* missedMeasurementsMask,
418                       CvMat* compIdx,
419                       CvMat* sampleIdx,
420                       CvMat* weights,
421                       CvClassifierTrainParams* trainParams
422                     )
423 {
424     CvStumpClassifier* stump = NULL;
425     int m = 0; /* number of samples */
426     int n = 0; /* number of components */
427     uchar* data = NULL;
428     int cstep   = 0;
429     int sstep   = 0;
430     uchar* ydata = NULL;
431     int ystep    = 0;
432     uchar* idxdata = NULL;
433     int idxstep    = 0;
434     int l = 0; /* number of indices */
435     uchar* wdata = NULL;
436     int wstep    = 0;
437
438     int* idx = NULL;
439     int i = 0;
440
441     float sumw   = FLT_MAX;
442     float sumwy  = FLT_MAX;
443     float sumwyy = FLT_MAX;
444
445     CV_Assert( trainData != NULL );
446     CV_Assert( CV_MAT_TYPE( trainData->type ) == CV_32FC1 );
447     CV_Assert( trainClasses != NULL );
448     CV_Assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
449     CV_Assert( missedMeasurementsMask == NULL );
450     CV_Assert( compIdx == NULL );
451     CV_Assert( weights != NULL );
452     CV_Assert( CV_MAT_TYPE( weights->type ) == CV_32FC1 );
453     CV_Assert( trainParams != NULL );
454
455     data = trainData->data.ptr;
456     if( CV_IS_ROW_SAMPLE( flags ) )
457     {
458         cstep = CV_ELEM_SIZE( trainData->type );
459         sstep = trainData->step;
460         m = trainData->rows;
461         n = trainData->cols;
462     }
463     else
464     {
465         sstep = CV_ELEM_SIZE( trainData->type );
466         cstep = trainData->step;
467         m = trainData->cols;
468         n = trainData->rows;
469     }
470
471     ydata = trainClasses->data.ptr;
472     if( trainClasses->rows == 1 )
473     {
474         assert( trainClasses->cols == m );
475         ystep = CV_ELEM_SIZE( trainClasses->type );
476     }
477     else
478     {
479         assert( trainClasses->rows == m );
480         ystep = trainClasses->step;
481     }
482
483     wdata = weights->data.ptr;
484     if( weights->rows == 1 )
485     {
486         assert( weights->cols == m );
487         wstep = CV_ELEM_SIZE( weights->type );
488     }
489     else
490     {
491         assert( weights->rows == m );
492         wstep = weights->step;
493     }
494
495     l = m;
496     if( sampleIdx != NULL )
497     {
498         assert( CV_MAT_TYPE( sampleIdx->type ) == CV_32FC1 );
499
500         idxdata = sampleIdx->data.ptr;
501         if( sampleIdx->rows == 1 )
502         {
503             l = sampleIdx->cols;
504             idxstep = CV_ELEM_SIZE( sampleIdx->type );
505         }
506         else
507         {
508             l = sampleIdx->rows;
509             idxstep = sampleIdx->step;
510         }
511         assert( l <= m );
512     }
513
514     idx = (int*) cvAlloc( l * sizeof( int ) );
515     stump = (CvStumpClassifier*) cvAlloc( sizeof( CvStumpClassifier) );
516
517     /* START */
518     memset( (void*) stump, 0, sizeof( CvStumpClassifier ) );
519
520     stump->eval = cvEvalStumpClassifier;
521     stump->tune = NULL;
522     stump->save = NULL;
523     stump->release = cvReleaseStumpClassifier;
524
525     stump->lerror = FLT_MAX;
526     stump->rerror = FLT_MAX;
527     stump->left  = 0.0F;
528     stump->right = 0.0F;
529
530     /* copy indices */
531     if( sampleIdx != NULL )
532     {
533         for( i = 0; i < l; i++ )
534         {
535             idx[i] = (int) *((float*) (idxdata + i*idxstep));
536         }
537     }
538     else
539     {
540         for( i = 0; i < l; i++ )
541         {
542             idx[i] = i;
543         }
544     }
545
546     for( i = 0; i < n; i++ )
547     {
548         CvValArray va;
549
550         va.data = data + i * ((size_t) cstep);
551         va.step = sstep;
552         std::sort(idx, idx + l, LessThanValArray<CvValArray, int>(&va));
553         if( findStumpThreshold_32s[(int) ((CvStumpTrainParams*) trainParams)->error]
554               ( data + i * ((size_t) cstep), sstep,
555                 wdata, wstep, ydata, ystep, (uchar*) idx, sizeof( int ), l,
556                 &(stump->lerror), &(stump->rerror),
557                 &(stump->threshold), &(stump->left), &(stump->right),
558                 &sumw, &sumwy, &sumwyy ) )
559         {
560             stump->compidx = i;
561         }
562     } /* for each component */
563
564     /* END */
565
566     cvFree( &idx );
567
568     if( ((CvStumpTrainParams*) trainParams)->type == CV_CLASSIFICATION_CLASS )
569     {
570         stump->left = 2.0F * (stump->left >= 0.5F) - 1.0F;
571         stump->right = 2.0F * (stump->right >= 0.5F) - 1.0F;
572     }
573
574     return (CvClassifier*) stump;
575 }
576
577 /*
578  * cvCreateMTStumpClassifier
579  *
580  * Multithreaded stump classifier constructor
581  * Includes huge train data support through callback function
582  */
583 CV_BOOST_IMPL
584 CvClassifier* cvCreateMTStumpClassifier( CvMat* trainData,
585                       int flags,
586                       CvMat* trainClasses,
587                       CvMat* /*typeMask*/,
588                       CvMat* missedMeasurementsMask,
589                       CvMat* compIdx,
590                       CvMat* sampleIdx,
591                       CvMat* weights,
592                       CvClassifierTrainParams* trainParams )
593 {
594     CvStumpClassifier* stump = NULL;
595     int m = 0; /* number of samples */
596     int n = 0; /* number of components */
597     uchar* data = NULL;
598     size_t cstep   = 0;
599     size_t sstep   = 0;
600     int    datan   = 0; /* num components */
601     uchar* ydata = NULL;
602     size_t ystep = 0;
603     uchar* idxdata = NULL;
604     size_t idxstep = 0;
605     int    l = 0; /* number of indices */
606     uchar* wdata = NULL;
607     size_t wstep = 0;
608
609     uchar* sorteddata = NULL;
610     int    sortedtype    = 0;
611     size_t sortedcstep   = 0; /* component step */
612     size_t sortedsstep   = 0; /* sample step */
613     int    sortedn       = 0; /* num components */
614     int    sortedm       = 0; /* num samples */
615
616     char* filter = NULL;
617     int i = 0;
618
619     int compidx = 0;
620     int stumperror;
621     int portion;
622
623     /* private variables */
624     CvMat mat;
625     CvValArray va;
626     float lerror;
627     float rerror;
628     float left;
629     float right;
630     float threshold;
631     int optcompidx;
632
633     float sumw;
634     float sumwy;
635     float sumwyy;
636
637     int t_compidx;
638     int t_n;
639
640     int ti;
641     int tj;
642     int tk;
643
644     uchar* t_data;
645     size_t t_cstep;
646     size_t t_sstep;
647
648     size_t matcstep;
649     size_t matsstep;
650
651     int* t_idx;
652     /* end private variables */
653
654     CV_Assert( trainParams != NULL );
655     CV_Assert( trainClasses != NULL );
656     CV_Assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
657     CV_Assert( missedMeasurementsMask == NULL );
658     CV_Assert( compIdx == NULL );
659
660     stumperror = (int) ((CvMTStumpTrainParams*) trainParams)->error;
661
662     ydata = trainClasses->data.ptr;
663     if( trainClasses->rows == 1 )
664     {
665         m = trainClasses->cols;
666         ystep = CV_ELEM_SIZE( trainClasses->type );
667     }
668     else
669     {
670         m = trainClasses->rows;
671         ystep = trainClasses->step;
672     }
673
674     wdata = weights->data.ptr;
675     if( weights->rows == 1 )
676     {
677         CV_Assert( weights->cols == m );
678         wstep = CV_ELEM_SIZE( weights->type );
679     }
680     else
681     {
682         CV_Assert( weights->rows == m );
683         wstep = weights->step;
684     }
685
686     if( ((CvMTStumpTrainParams*) trainParams)->sortedIdx != NULL )
687     {
688         sortedtype =
689             CV_MAT_TYPE( ((CvMTStumpTrainParams*) trainParams)->sortedIdx->type );
690         assert( sortedtype == CV_16SC1 || sortedtype == CV_32SC1
691                 || sortedtype == CV_32FC1 );
692         sorteddata = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->data.ptr;
693         sortedsstep = CV_ELEM_SIZE( sortedtype );
694         sortedcstep = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->step;
695         sortedn = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->rows;
696         sortedm = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->cols;
697     }
698
699     if( trainData == NULL )
700     {
701         assert( ((CvMTStumpTrainParams*) trainParams)->getTrainData != NULL );
702         n = ((CvMTStumpTrainParams*) trainParams)->numcomp;
703         assert( n > 0 );
704     }
705     else
706     {
707         assert( CV_MAT_TYPE( trainData->type ) == CV_32FC1 );
708         data = trainData->data.ptr;
709         if( CV_IS_ROW_SAMPLE( flags ) )
710         {
711             cstep = CV_ELEM_SIZE( trainData->type );
712             sstep = trainData->step;
713             assert( m == trainData->rows );
714             datan = n = trainData->cols;
715         }
716         else
717         {
718             sstep = CV_ELEM_SIZE( trainData->type );
719             cstep = trainData->step;
720             assert( m == trainData->cols );
721             datan = n = trainData->rows;
722         }
723         if( ((CvMTStumpTrainParams*) trainParams)->getTrainData != NULL )
724         {
725             n = ((CvMTStumpTrainParams*) trainParams)->numcomp;
726         }
727     }
728     assert( datan <= n );
729
730     if( sampleIdx != NULL )
731     {
732         assert( CV_MAT_TYPE( sampleIdx->type ) == CV_32FC1 );
733         idxdata = sampleIdx->data.ptr;
734         idxstep = ( sampleIdx->rows == 1 )
735             ? CV_ELEM_SIZE( sampleIdx->type ) : sampleIdx->step;
736         l = ( sampleIdx->rows == 1 ) ? sampleIdx->cols : sampleIdx->rows;
737
738         if( sorteddata != NULL )
739         {
740             filter = (char*) cvAlloc( sizeof( char ) * m );
741             memset( (void*) filter, 0, sizeof( char ) * m );
742             for( i = 0; i < l; i++ )
743             {
744                 filter[(int) *((float*) (idxdata + i * idxstep))] = (char) 1;
745             }
746         }
747     }
748     else
749     {
750         l = m;
751     }
752
753     stump = (CvStumpClassifier*) cvAlloc( sizeof( CvStumpClassifier) );
754
755     /* START */
756     memset( (void*) stump, 0, sizeof( CvStumpClassifier ) );
757
758     portion = ((CvMTStumpTrainParams*)trainParams)->portion;
759
760     if( portion < 1 )
761     {
762         /* auto portion */
763         portion = n;
764         #ifdef _OPENMP
765         portion /= omp_get_max_threads();
766         #endif /* _OPENMP */
767     }
768
769     stump->eval = cvEvalStumpClassifier;
770     stump->tune = NULL;
771     stump->save = NULL;
772     stump->release = cvReleaseStumpClassifier;
773
774     stump->lerror = FLT_MAX;
775     stump->rerror = FLT_MAX;
776     stump->left  = 0.0F;
777     stump->right = 0.0F;
778
779     compidx = 0;
780     #ifdef _OPENMP
781     #pragma omp parallel private(mat, va, lerror, rerror, left, right, threshold, \
782                                  optcompidx, sumw, sumwy, sumwyy, t_compidx, t_n, \
783                                  ti, tj, tk, t_data, t_cstep, t_sstep, matcstep,  \
784                                  matsstep, t_idx)
785     #endif /* _OPENMP */
786     {
787         lerror = FLT_MAX;
788         rerror = FLT_MAX;
789         left  = 0.0F;
790         right = 0.0F;
791         threshold = 0.0F;
792         optcompidx = 0;
793
794         sumw   = FLT_MAX;
795         sumwy  = FLT_MAX;
796         sumwyy = FLT_MAX;
797
798         t_compidx = 0;
799         t_n = 0;
800
801         ti = 0;
802         tj = 0;
803         tk = 0;
804
805         t_data = NULL;
806         t_cstep = 0;
807         t_sstep = 0;
808
809         matcstep = 0;
810         matsstep = 0;
811
812         t_idx = NULL;
813
814         mat.data.ptr = NULL;
815
816         if( datan < n )
817         {
818             /* prepare matrix for callback */
819             if( CV_IS_ROW_SAMPLE( flags ) )
820             {
821                 mat = cvMat( m, portion, CV_32FC1, 0 );
822                 matcstep = CV_ELEM_SIZE( mat.type );
823                 matsstep = mat.step;
824             }
825             else
826             {
827                 mat = cvMat( portion, m, CV_32FC1, 0 );
828                 matcstep = mat.step;
829                 matsstep = CV_ELEM_SIZE( mat.type );
830             }
831             mat.data.ptr = (uchar*) cvAlloc( sizeof( float ) * mat.rows * mat.cols );
832         }
833
834         if( filter != NULL || sortedn < n )
835         {
836             t_idx = (int*) cvAlloc( sizeof( int ) * m );
837             if( sortedn == 0 || filter == NULL )
838             {
839                 if( idxdata != NULL )
840                 {
841                     for( ti = 0; ti < l; ti++ )
842                     {
843                         t_idx[ti] = (int) *((float*) (idxdata + ti * idxstep));
844                     }
845                 }
846                 else
847                 {
848                     for( ti = 0; ti < l; ti++ )
849                     {
850                         t_idx[ti] = ti;
851                     }
852                 }
853             }
854         }
855
856         #ifdef _OPENMP
857         #pragma omp critical(c_compidx)
858         #endif /* _OPENMP */
859         {
860             t_compidx = compidx;
861             compidx += portion;
862         }
863         while( t_compidx < n )
864         {
865             t_n = portion;
866             if( t_compidx < datan )
867             {
868                 t_n = ( t_n < (datan - t_compidx) ) ? t_n : (datan - t_compidx);
869                 t_data = data;
870                 t_cstep = cstep;
871                 t_sstep = sstep;
872             }
873             else
874             {
875                 t_n = ( t_n < (n - t_compidx) ) ? t_n : (n - t_compidx);
876                 t_cstep = matcstep;
877                 t_sstep = matsstep;
878                 t_data = mat.data.ptr - t_compidx * ((size_t) t_cstep );
879
880                 /* calculate components */
881                 ((CvMTStumpTrainParams*)trainParams)->getTrainData( &mat,
882                         sampleIdx, compIdx, t_compidx, t_n,
883                         ((CvMTStumpTrainParams*)trainParams)->userdata );
884             }
885
886             if( sorteddata != NULL )
887             {
888                 if( filter != NULL )
889                 {
890                     /* have sorted indices and filter */
891                     switch( sortedtype )
892                     {
893                         case CV_16SC1:
894                             for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
895                             {
896                                 tk = 0;
897                                 for( tj = 0; tj < sortedm; tj++ )
898                                 {
899                                     int curidx = (int) ( *((short*) (sorteddata
900                                             + ti * sortedcstep + tj * sortedsstep)) );
901                                     if( filter[curidx] != 0 )
902                                     {
903                                         t_idx[tk++] = curidx;
904                                     }
905                                 }
906                                 if( findStumpThreshold_32s[stumperror](
907                                         t_data + ti * t_cstep, t_sstep,
908                                         wdata, wstep, ydata, ystep,
909                                         (uchar*) t_idx, sizeof( int ), tk,
910                                         &lerror, &rerror,
911                                         &threshold, &left, &right,
912                                         &sumw, &sumwy, &sumwyy ) )
913                                 {
914                                     optcompidx = ti;
915                                 }
916                             }
917                             break;
918                         case CV_32SC1:
919                             for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
920                             {
921                                 tk = 0;
922                                 for( tj = 0; tj < sortedm; tj++ )
923                                 {
924                                     int curidx = (int) ( *((int*) (sorteddata
925                                             + ti * sortedcstep + tj * sortedsstep)) );
926                                     if( filter[curidx] != 0 )
927                                     {
928                                         t_idx[tk++] = curidx;
929                                     }
930                                 }
931                                 if( findStumpThreshold_32s[stumperror](
932                                         t_data + ti * t_cstep, t_sstep,
933                                         wdata, wstep, ydata, ystep,
934                                         (uchar*) t_idx, sizeof( int ), tk,
935                                         &lerror, &rerror,
936                                         &threshold, &left, &right,
937                                         &sumw, &sumwy, &sumwyy ) )
938                                 {
939                                     optcompidx = ti;
940                                 }
941                             }
942                             break;
943                         case CV_32FC1:
944                             for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
945                             {
946                                 tk = 0;
947                                 for( tj = 0; tj < sortedm; tj++ )
948                                 {
949                                     int curidx = (int) ( *((float*) (sorteddata
950                                             + ti * sortedcstep + tj * sortedsstep)) );
951                                     if( filter[curidx] != 0 )
952                                     {
953                                         t_idx[tk++] = curidx;
954                                     }
955                                 }
956                                 if( findStumpThreshold_32s[stumperror](
957                                         t_data + ti * t_cstep, t_sstep,
958                                         wdata, wstep, ydata, ystep,
959                                         (uchar*) t_idx, sizeof( int ), tk,
960                                         &lerror, &rerror,
961                                         &threshold, &left, &right,
962                                         &sumw, &sumwy, &sumwyy ) )
963                                 {
964                                     optcompidx = ti;
965                                 }
966                             }
967                             break;
968                         default:
969                             assert( 0 );
970                             break;
971                     }
972                 }
973                 else
974                 {
975                     /* have sorted indices */
976                     switch( sortedtype )
977                     {
978                         case CV_16SC1:
979                             for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
980                             {
981                                 if( findStumpThreshold_16s[stumperror](
982                                         t_data + ti * t_cstep, t_sstep,
983                                         wdata, wstep, ydata, ystep,
984                                         sorteddata + ti * sortedcstep, sortedsstep, sortedm,
985                                         &lerror, &rerror,
986                                         &threshold, &left, &right,
987                                         &sumw, &sumwy, &sumwyy ) )
988                                 {
989                                     optcompidx = ti;
990                                 }
991                             }
992                             break;
993                         case CV_32SC1:
994                             for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
995                             {
996                                 if( findStumpThreshold_32s[stumperror](
997                                         t_data + ti * t_cstep, t_sstep,
998                                         wdata, wstep, ydata, ystep,
999                                         sorteddata + ti * sortedcstep, sortedsstep, sortedm,
1000                                         &lerror, &rerror,
1001                                         &threshold, &left, &right,
1002                                         &sumw, &sumwy, &sumwyy ) )
1003                                 {
1004                                     optcompidx = ti;
1005                                 }
1006                             }
1007                             break;
1008                         case CV_32FC1:
1009                             for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
1010                             {
1011                                 if( findStumpThreshold_32f[stumperror](
1012                                         t_data + ti * t_cstep, t_sstep,
1013                                         wdata, wstep, ydata, ystep,
1014                                         sorteddata + ti * sortedcstep, sortedsstep, sortedm,
1015                                         &lerror, &rerror,
1016                                         &threshold, &left, &right,
1017                                         &sumw, &sumwy, &sumwyy ) )
1018                                 {
1019                                     optcompidx = ti;
1020                                 }
1021                             }
1022                             break;
1023                         default:
1024                             assert( 0 );
1025                             break;
1026                     }
1027                 }
1028             }
1029
1030             ti = MAX( t_compidx, MIN( sortedn, t_compidx + t_n ) );
1031             for( ; ti < t_compidx + t_n; ti++ )
1032             {
1033                 va.data = t_data + ti * t_cstep;
1034                 va.step = t_sstep;
1035                 std::sort(t_idx, t_idx + l, LessThanValArray<CvValArray, int>(&va));
1036                 if( findStumpThreshold_32s[stumperror](
1037                         t_data + ti * t_cstep, t_sstep,
1038                         wdata, wstep, ydata, ystep,
1039                         (uchar*)t_idx, sizeof( int ), l,
1040                         &lerror, &rerror,
1041                         &threshold, &left, &right,
1042                         &sumw, &sumwy, &sumwyy ) )
1043                 {
1044                     optcompidx = ti;
1045                 }
1046             }
1047             #ifdef _OPENMP
1048             #pragma omp critical(c_compidx)
1049             #endif /* _OPENMP */
1050             {
1051                 t_compidx = compidx;
1052                 compidx += portion;
1053             }
1054         } /* while have training data */
1055
1056         /* get the best classifier */
1057         #ifdef _OPENMP
1058         #pragma omp critical(c_beststump)
1059         #endif /* _OPENMP */
1060         {
1061             if( lerror + rerror < stump->lerror + stump->rerror )
1062             {
1063                 stump->lerror    = lerror;
1064                 stump->rerror    = rerror;
1065                 stump->compidx   = optcompidx;
1066                 stump->threshold = threshold;
1067                 stump->left      = left;
1068                 stump->right     = right;
1069             }
1070         }
1071
1072         /* free allocated memory */
1073         if( mat.data.ptr != NULL )
1074         {
1075             cvFree( &(mat.data.ptr) );
1076         }
1077         if( t_idx != NULL )
1078         {
1079             cvFree( &t_idx );
1080         }
1081     } /* end of parallel region */
1082
1083     /* END */
1084
1085     /* free allocated memory */
1086     if( filter != NULL )
1087     {
1088         cvFree( &filter );
1089     }
1090
1091     if( ((CvMTStumpTrainParams*) trainParams)->type == CV_CLASSIFICATION_CLASS )
1092     {
1093         stump->left = 2.0F * (stump->left >= 0.5F) - 1.0F;
1094         stump->right = 2.0F * (stump->right >= 0.5F) - 1.0F;
1095     }
1096
1097     return (CvClassifier*) stump;
1098 }
1099
1100 CV_BOOST_IMPL
1101 float cvEvalCARTClassifier( CvClassifier* classifier, CvMat* sample )
1102 {
1103     CV_FUNCNAME( "cvEvalCARTClassifier" );
1104
1105     int idx = 0;
1106
1107     __BEGIN__;
1108
1109
1110     CV_ASSERT( classifier != NULL );
1111     CV_ASSERT( sample != NULL );
1112     CV_ASSERT( CV_MAT_TYPE( sample->type ) == CV_32FC1 );
1113     CV_ASSERT( sample->rows == 1 || sample->cols == 1 );
1114
1115     if( sample->rows == 1 )
1116     {
1117         do
1118         {
1119             if( (CV_MAT_ELEM( (*sample), float, 0,
1120                     ((CvCARTClassifier*) classifier)->compidx[idx] )) <
1121                 ((CvCARTClassifier*) classifier)->threshold[idx] )
1122             {
1123                 idx = ((CvCARTClassifier*) classifier)->left[idx];
1124             }
1125             else
1126             {
1127                 idx = ((CvCARTClassifier*) classifier)->right[idx];
1128             }
1129         } while( idx > 0 );
1130     }
1131     else
1132     {
1133         do
1134         {
1135             if( (CV_MAT_ELEM( (*sample), float,
1136                     ((CvCARTClassifier*) classifier)->compidx[idx], 0 )) <
1137                 ((CvCARTClassifier*) classifier)->threshold[idx] )
1138             {
1139                 idx = ((CvCARTClassifier*) classifier)->left[idx];
1140             }
1141             else
1142             {
1143                 idx = ((CvCARTClassifier*) classifier)->right[idx];
1144             }
1145         } while( idx > 0 );
1146     }
1147
1148     __END__;
1149
1150     return ((CvCARTClassifier*) classifier)->val[-idx];
1151 }
1152
1153 static
1154 float cvEvalCARTClassifierIdx( CvClassifier* classifier, CvMat* sample )
1155 {
1156     CV_FUNCNAME( "cvEvalCARTClassifierIdx" );
1157
1158     int idx = 0;
1159
1160     __BEGIN__;
1161
1162
1163     CV_ASSERT( classifier != NULL );
1164     CV_ASSERT( sample != NULL );
1165     CV_ASSERT( CV_MAT_TYPE( sample->type ) == CV_32FC1 );
1166     CV_ASSERT( sample->rows == 1 || sample->cols == 1 );
1167
1168     if( sample->rows == 1 )
1169     {
1170         do
1171         {
1172             if( (CV_MAT_ELEM( (*sample), float, 0,
1173                     ((CvCARTClassifier*) classifier)->compidx[idx] )) <
1174                 ((CvCARTClassifier*) classifier)->threshold[idx] )
1175             {
1176                 idx = ((CvCARTClassifier*) classifier)->left[idx];
1177             }
1178             else
1179             {
1180                 idx = ((CvCARTClassifier*) classifier)->right[idx];
1181             }
1182         } while( idx > 0 );
1183     }
1184     else
1185     {
1186         do
1187         {
1188             if( (CV_MAT_ELEM( (*sample), float,
1189                     ((CvCARTClassifier*) classifier)->compidx[idx], 0 )) <
1190                 ((CvCARTClassifier*) classifier)->threshold[idx] )
1191             {
1192                 idx = ((CvCARTClassifier*) classifier)->left[idx];
1193             }
1194             else
1195             {
1196                 idx = ((CvCARTClassifier*) classifier)->right[idx];
1197             }
1198         } while( idx > 0 );
1199     }
1200
1201     __END__;
1202
1203     return (float) (-idx);
1204 }
1205
1206 CV_BOOST_IMPL
1207 void cvReleaseCARTClassifier( CvClassifier** classifier )
1208 {
1209     cvFree( classifier );
1210     *classifier = NULL;
1211 }
1212
1213 static void CV_CDECL icvDefaultSplitIdx_R( int compidx, float threshold,
1214                                     CvMat* idx, CvMat** left, CvMat** right,
1215                                     void* userdata )
1216 {
1217     CvMat* trainData = (CvMat*) userdata;
1218     int i = 0;
1219
1220     *left = cvCreateMat( 1, trainData->rows, CV_32FC1 );
1221     *right = cvCreateMat( 1, trainData->rows, CV_32FC1 );
1222     (*left)->cols = (*right)->cols = 0;
1223     if( idx == NULL )
1224     {
1225         for( i = 0; i < trainData->rows; i++ )
1226         {
1227             if( CV_MAT_ELEM( *trainData, float, i, compidx ) < threshold )
1228             {
1229                 (*left)->data.fl[(*left)->cols++] = (float) i;
1230             }
1231             else
1232             {
1233                 (*right)->data.fl[(*right)->cols++] = (float) i;
1234             }
1235         }
1236     }
1237     else
1238     {
1239         uchar* idxdata;
1240         int idxnum;
1241         int idxstep;
1242         int index;
1243
1244         idxdata = idx->data.ptr;
1245         idxnum = (idx->rows == 1) ? idx->cols : idx->rows;
1246         idxstep = (idx->rows == 1) ? CV_ELEM_SIZE( idx->type ) : idx->step;
1247         for( i = 0; i < idxnum; i++ )
1248         {
1249             index = (int) *((float*) (idxdata + i * idxstep));
1250             if( CV_MAT_ELEM( *trainData, float, index, compidx ) < threshold )
1251             {
1252                 (*left)->data.fl[(*left)->cols++] = (float) index;
1253             }
1254             else
1255             {
1256                 (*right)->data.fl[(*right)->cols++] = (float) index;
1257             }
1258         }
1259     }
1260 }
1261
1262 static void CV_CDECL icvDefaultSplitIdx_C( int compidx, float threshold,
1263                                     CvMat* idx, CvMat** left, CvMat** right,
1264                                     void* userdata )
1265 {
1266     CvMat* trainData = (CvMat*) userdata;
1267     int i = 0;
1268
1269     *left = cvCreateMat( 1, trainData->cols, CV_32FC1 );
1270     *right = cvCreateMat( 1, trainData->cols, CV_32FC1 );
1271     (*left)->cols = (*right)->cols = 0;
1272     if( idx == NULL )
1273     {
1274         for( i = 0; i < trainData->cols; i++ )
1275         {
1276             if( CV_MAT_ELEM( *trainData, float, compidx, i ) < threshold )
1277             {
1278                 (*left)->data.fl[(*left)->cols++] = (float) i;
1279             }
1280             else
1281             {
1282                 (*right)->data.fl[(*right)->cols++] = (float) i;
1283             }
1284         }
1285     }
1286     else
1287     {
1288         uchar* idxdata;
1289         int idxnum;
1290         int idxstep;
1291         int index;
1292
1293         idxdata = idx->data.ptr;
1294         idxnum = (idx->rows == 1) ? idx->cols : idx->rows;
1295         idxstep = (idx->rows == 1) ? CV_ELEM_SIZE( idx->type ) : idx->step;
1296         for( i = 0; i < idxnum; i++ )
1297         {
1298             index = (int) *((float*) (idxdata + i * idxstep));
1299             if( CV_MAT_ELEM( *trainData, float, compidx, index ) < threshold )
1300             {
1301                 (*left)->data.fl[(*left)->cols++] = (float) index;
1302             }
1303             else
1304             {
1305                 (*right)->data.fl[(*right)->cols++] = (float) index;
1306             }
1307         }
1308     }
1309 }
1310
1311 /* internal structure used in CART creation: one trained node, either already placed in the tree or still a candidate child */
1312 typedef struct CvCARTNode
1313 {
1314     CvMat* sampleIdx;          /* indices of the samples this node was trained on (the caller's sampleIdx for the root) */
1315     CvStumpClassifier* stump;  /* stump trained on those samples by the configured stump constructor */
1316     int parent;                /* position of the parent among the accepted tree nodes */
1317     int leftflag;              /* 1 - node hangs on the parent's left branch, 0 - on the right branch */
1318     float errdrop;             /* parent's error on that branch minus (stump->lerror + stump->rerror) */
1319 } CvCARTNode;
1320
1321 CV_BOOST_IMPL
1322 CvClassifier* cvCreateCARTClassifier( CvMat* trainData,
1323                      int flags,
1324                      CvMat* trainClasses,
1325                      CvMat* typeMask,
1326                      CvMat* missedMeasurementsMask,
1327                      CvMat* compIdx,
1328                      CvMat* sampleIdx,
1329                      CvMat* weights,
1330                      CvClassifierTrainParams* trainParams )
1331 {
1332     CvCARTClassifier* cart = NULL;
1333     size_t datasize = 0;
1334     int count = 0;
1335     int i = 0;
1336     int j = 0;
1337
1338     CvCARTNode* intnode = NULL;
1339     CvCARTNode* list = NULL;
1340     int listcount = 0;
1341     CvMat* lidx = NULL;
1342     CvMat* ridx = NULL;
1343
1344     float maxerrdrop = 0.0F;
1345     int idx = 0;
1346
1347     void (*splitIdxCallback)( int compidx, float threshold,
1348                               CvMat* idx, CvMat** left, CvMat** right,
1349                               void* userdata );
1350     void* userdata;
1351
1352     count = ((CvCARTTrainParams*) trainParams)->count;
1353
1354     assert( count > 0 );
1355
1356     datasize = sizeof( *cart ) + (sizeof( float ) + 3 * sizeof( int )) * count +
1357         sizeof( float ) * (count + 1);
1358
1359     cart = (CvCARTClassifier*) cvAlloc( datasize );
1360     memset( cart, 0, datasize );
1361
1362     cart->count = count;
1363
1364     cart->eval = cvEvalCARTClassifier;
1365     cart->save = NULL;
1366     cart->release = cvReleaseCARTClassifier;
1367
1368     cart->compidx = (int*) (cart + 1);
1369     cart->threshold = (float*) (cart->compidx + count);
1370     cart->left  = (int*) (cart->threshold + count);
1371     cart->right = (int*) (cart->left + count);
1372     cart->val = (float*) (cart->right + count);
1373
1374     datasize = sizeof( CvCARTNode ) * (count + count);
1375     intnode = (CvCARTNode*) cvAlloc( datasize );
1376     memset( intnode, 0, datasize );
1377     list = (CvCARTNode*) (intnode + count);
1378
1379     splitIdxCallback = ((CvCARTTrainParams*) trainParams)->splitIdx;
1380     userdata = ((CvCARTTrainParams*) trainParams)->userdata;
1381     if( splitIdxCallback == NULL )
1382     {
1383         splitIdxCallback = ( CV_IS_ROW_SAMPLE( flags ) )
1384             ? icvDefaultSplitIdx_R : icvDefaultSplitIdx_C;
1385         userdata = trainData;
1386     }
1387
1388     /* create root of the tree */
1389     intnode[0].sampleIdx = sampleIdx;
1390     intnode[0].stump = (CvStumpClassifier*)
1391         ((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
1392             trainClasses, typeMask, missedMeasurementsMask, compIdx, sampleIdx, weights,
1393             ((CvCARTTrainParams*) trainParams)->stumpTrainParams );
1394     cart->left[0] = cart->right[0] = 0;
1395
1396     /* build tree */
1397     listcount = 0;
1398     for( i = 1; i < count; i++ )
1399     {
1400         /* split last added node */
1401         splitIdxCallback( intnode[i-1].stump->compidx, intnode[i-1].stump->threshold,
1402             intnode[i-1].sampleIdx, &lidx, &ridx, userdata );
1403
1404         if( intnode[i-1].stump->lerror != 0.0F )
1405         {
1406             list[listcount].sampleIdx = lidx;
1407             list[listcount].stump = (CvStumpClassifier*)
1408                 ((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
1409                     trainClasses, typeMask, missedMeasurementsMask, compIdx,
1410                     list[listcount].sampleIdx,
1411                     weights, ((CvCARTTrainParams*) trainParams)->stumpTrainParams );
1412             list[listcount].errdrop = intnode[i-1].stump->lerror
1413                 - (list[listcount].stump->lerror + list[listcount].stump->rerror);
1414             list[listcount].leftflag = 1;
1415             list[listcount].parent = i-1;
1416             listcount++;
1417         }
1418         else
1419         {
1420             cvReleaseMat( &lidx );
1421         }
1422         if( intnode[i-1].stump->rerror != 0.0F )
1423         {
1424             list[listcount].sampleIdx = ridx;
1425             list[listcount].stump = (CvStumpClassifier*)
1426                 ((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
1427                     trainClasses, typeMask, missedMeasurementsMask, compIdx,
1428                     list[listcount].sampleIdx,
1429                     weights, ((CvCARTTrainParams*) trainParams)->stumpTrainParams );
1430             list[listcount].errdrop = intnode[i-1].stump->rerror
1431                 - (list[listcount].stump->lerror + list[listcount].stump->rerror);
1432             list[listcount].leftflag = 0;
1433             list[listcount].parent = i-1;
1434             listcount++;
1435         }
1436         else
1437         {
1438             cvReleaseMat( &ridx );
1439         }
1440
1441         if( listcount == 0 ) break;
1442
1443         /* find the best node to be added to the tree */
1444         idx = 0;
1445         maxerrdrop = list[idx].errdrop;
1446         for( j = 1; j < listcount; j++ )
1447         {
1448             if( list[j].errdrop > maxerrdrop )
1449             {
1450                 idx = j;
1451                 maxerrdrop = list[j].errdrop;
1452             }
1453         }
1454         intnode[i] = list[idx];
1455         if( list[idx].leftflag )
1456         {
1457             cart->left[list[idx].parent] = i;
1458         }
1459         else
1460         {
1461             cart->right[list[idx].parent] = i;
1462         }
1463         if( idx != (listcount - 1) )
1464         {
1465             list[idx] = list[listcount - 1];
1466         }
1467         listcount--;
1468     }
1469
1470     /* fill <cart> fields */
1471     j = 0;
1472     cart->count = 0;
1473     for( i = 0; i < count && (intnode[i].stump != NULL); i++ )
1474     {
1475         cart->count++;
1476         cart->compidx[i] = intnode[i].stump->compidx;
1477         cart->threshold[i] = intnode[i].stump->threshold;
1478
1479         /* leaves */
1480         if( cart->left[i] <= 0 )
1481         {
1482             cart->left[i] = -j;
1483             cart->val[j] = intnode[i].stump->left;
1484             j++;
1485         }
1486         if( cart->right[i] <= 0 )
1487         {
1488             cart->right[i] = -j;
1489             cart->val[j] = intnode[i].stump->right;
1490             j++;
1491         }
1492     }
1493
1494     /* CLEAN UP */
1495     for( i = 0; i < count && (intnode[i].stump != NULL); i++ )
1496     {
1497         intnode[i].stump->release( (CvClassifier**) &(intnode[i].stump) );
1498         if( i != 0 )
1499         {
1500             cvReleaseMat( &(intnode[i].sampleIdx) );
1501         }
1502     }
1503     for( i = 0; i < listcount; i++ )
1504     {
1505         list[i].stump->release( (CvClassifier**) &(list[i].stump) );
1506         cvReleaseMat( &(list[i].sampleIdx) );
1507     }
1508
1509     cvFree( &intnode );
1510
1511     return (CvClassifier*) cart;
1512 }
1513
1514 /****************************************************************************************\
1515 *                                        Boosting                                        *
1516 \****************************************************************************************/
1517
1518 typedef struct CvBoostTrainer
1519 {
1520     CvBoostType type;      /* boosting type the trainer was started with */
1521     int count;             /* (idx) ? number_of_indices : number_of_samples */
1522     int* idx;              /* sample indices stored right after this structure, or NULL when all samples are used */
1523     float* F;              /* set to NULL by icvBoostStartTraining */
1524 } CvBoostTrainer;
1525
1526 /*
1527  * cvBoostStartTraining, cvBoostNextWeakClassifier, cvBoostEndTraining
1528  *
1529  * These functions perform training of a 2-class boosting classifier
1530  * using ANY appropriate weak classifier.
1531  */
1532
1533 static
1534 CvBoostTrainer* icvBoostStartTraining( CvMat* trainClasses,
1535                                        CvMat* weakTrainVals,
1536                                        CvMat* /*weights*/,
1537                                        CvMat* sampleIdx,
1538                                        CvBoostType type )
1539 {
1540     uchar* ydata;
1541     int ystep;
1542     int m;
1543     uchar* traindata;
1544     int trainstep;
1545     int trainnum;
1546     int i;
1547     int idx;
1548
1549     size_t datasize;
1550     CvBoostTrainer* ptr;
1551
1552     int idxnum;
1553     int idxstep;
1554     uchar* idxdata;
1555
1556     assert( trainClasses != NULL );
1557     assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1558     assert( weakTrainVals != NULL );
1559     assert( CV_MAT_TYPE( weakTrainVals->type ) == CV_32FC1 );
1560
1561     CV_MAT2VEC( *trainClasses, ydata, ystep, m );
1562     CV_MAT2VEC( *weakTrainVals, traindata, trainstep, trainnum );
1563
1564     CV_Assert( m == trainnum );
1565
1566     idxnum = 0;
1567     idxstep = 0;
1568     idxdata = NULL;
1569     if( sampleIdx )
1570     {
1571         CV_MAT2VEC( *sampleIdx, idxdata, idxstep, idxnum );
1572     }
1573
1574     datasize = sizeof( *ptr ) + sizeof( *ptr->idx ) * idxnum;
1575     ptr = (CvBoostTrainer*) cvAlloc( datasize );
1576     memset( ptr, 0, datasize );
1577     ptr->F = NULL;
1578     ptr->idx = NULL;
1579
1580     ptr->count = m;
1581     ptr->type = type;
1582
1583     if( idxnum > 0 )
1584     {
1585         CvScalar s;
1586
1587         ptr->idx = (int*) (ptr + 1);
1588         ptr->count = idxnum;
1589         for( i = 0; i < ptr->count; i++ )
1590         {
1591             cvRawDataToScalar( idxdata + i*idxstep, CV_MAT_TYPE( sampleIdx->type ), &s );
1592             ptr->idx[i] = (int) s.val[0];
1593         }
1594     }
1595     for( i = 0; i < ptr->count; i++ )
1596     {
1597         idx = (ptr->idx) ? ptr->idx[i] : i;
1598
1599         *((float*) (traindata + idx * trainstep)) =
1600             2.0F * (*((float*) (ydata + idx * ystep))) - 1.0F;
1601     }
1602
1603     return ptr;
1604 }
1605
1606 /*
1607  *
1608  * Discrete AdaBoost functions
1609  *
1610  */
1611 static
1612 float icvBoostNextWeakClassifierDAB( CvMat* weakEvalVals,
1613                                      CvMat* trainClasses,
1614                                      CvMat* /*weakTrainVals*/,
1615                                      CvMat* weights,
1616                                      CvBoostTrainer* trainer )
1617 {
1618     uchar* evaldata;
1619     int evalstep;
1620     int m;
1621     uchar* ydata;
1622     int ystep;
1623     int ynum;
1624     uchar* wdata;
1625     int wstep;
1626     int wnum;
1627
1628     float sumw;
1629     float err;
1630     int i;
1631     int idx;
1632
1633     CV_Assert( weakEvalVals != NULL );
1634     CV_Assert( CV_MAT_TYPE( weakEvalVals->type ) == CV_32FC1 );
1635     CV_Assert( trainClasses != NULL );
1636     CV_Assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1637     CV_Assert( weights != NULL );
1638     CV_Assert( CV_MAT_TYPE( weights ->type ) == CV_32FC1 );
1639
1640     CV_MAT2VEC( *weakEvalVals, evaldata, evalstep, m );
1641     CV_MAT2VEC( *trainClasses, ydata, ystep, ynum );
1642     CV_MAT2VEC( *weights, wdata, wstep, wnum );
1643
1644     CV_Assert( m == ynum );
1645     CV_Assert( m == wnum );
1646
1647     sumw = 0.0F;
1648     err = 0.0F;
1649     for( i = 0; i < trainer->count; i++ )
1650     {
1651         idx = (trainer->idx) ? trainer->idx[i] : i;
1652
1653         sumw += *((float*) (wdata + idx*wstep));
1654         err += (*((float*) (wdata + idx*wstep))) *
1655             ( (*((float*) (evaldata + idx*evalstep))) !=
1656                 2.0F * (*((float*) (ydata + idx*ystep))) - 1.0F );
1657     }
1658     err /= sumw;
1659     err = -cvLogRatio( err );
1660
1661     for( i = 0; i < trainer->count; i++ )
1662     {
1663         idx = (trainer->idx) ? trainer->idx[i] : i;
1664
1665         *((float*) (wdata + idx*wstep)) *= expf( err *
1666             ((*((float*) (evaldata + idx*evalstep))) !=
1667                 2.0F * (*((float*) (ydata + idx*ystep))) - 1.0F) );
1668         sumw += *((float*) (wdata + idx*wstep));
1669     }
1670     for( i = 0; i < trainer->count; i++ )
1671     {
1672         idx = (trainer->idx) ? trainer->idx[i] : i;
1673
1674         *((float*) (wdata + idx * wstep)) /= sumw;
1675     }
1676
1677     return err;
1678 }
1679
1680 /*
1681  *
1682  * Real AdaBoost functions
1683  *
1684  */
1685 static
1686 float icvBoostNextWeakClassifierRAB( CvMat* weakEvalVals,
1687                                      CvMat* trainClasses,
1688                                      CvMat* /*weakTrainVals*/,
1689                                      CvMat* weights,
1690                                      CvBoostTrainer* trainer )
1691 {
1692     uchar* evaldata;
1693     int evalstep;
1694     int m;
1695     uchar* ydata;
1696     int ystep;
1697     int ynum;
1698     uchar* wdata;
1699     int wstep;
1700     int wnum;
1701
1702     float sumw;
1703     int i, idx;
1704
1705     CV_Assert( weakEvalVals != NULL );
1706     CV_Assert( CV_MAT_TYPE( weakEvalVals->type ) == CV_32FC1 );
1707     CV_Assert( trainClasses != NULL );
1708     CV_Assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1709     CV_Assert( weights != NULL );
1710     CV_Assert( CV_MAT_TYPE( weights ->type ) == CV_32FC1 );
1711
1712     CV_MAT2VEC( *weakEvalVals, evaldata, evalstep, m );
1713     CV_MAT2VEC( *trainClasses, ydata, ystep, ynum );
1714     CV_MAT2VEC( *weights, wdata, wstep, wnum );
1715
1716     CV_Assert( m == ynum );
1717     CV_Assert( m == wnum );
1718
1719
1720     sumw = 0.0F;
1721     for( i = 0; i < trainer->count; i++ )
1722     {
1723         idx = (trainer->idx) ? trainer->idx[i] : i;
1724
1725         *((float*) (wdata + idx*wstep)) *= expf( (-(*((float*) (ydata + idx*ystep))) + 0.5F)
1726             * cvLogRatio( *((float*) (evaldata + idx*evalstep)) ) );
1727         sumw += *((float*) (wdata + idx*wstep));
1728     }
1729     for( i = 0; i < trainer->count; i++ )
1730     {
1731         idx = (trainer->idx) ? trainer->idx[i] : i;
1732
1733         *((float*) (wdata + idx*wstep)) /= sumw;
1734     }
1735
1736     return 1.0F;
1737 }
1738
1739 /*
1740  *
1741  * LogitBoost functions
1742  *
1743  */
1744 #define CV_LB_PROB_THRESH      0.01F
1745 #define CV_LB_WEIGHT_THRESHOLD 0.0001F
1746
1747 static
1748 void icvResponsesAndWeightsLB( int num, uchar* wdata, int wstep,
1749                                uchar* ydata, int ystep,
1750                                uchar* fdata, int fstep,
1751                                uchar* traindata, int trainstep,
1752                                int* indices )
1753 {
1754     int i, idx;
1755     float p;
1756
1757     for( i = 0; i < num; i++ )
1758     {
1759         idx = (indices) ? indices[i] : i;
1760
1761         p = 1.0F / (1.0F + expf( -(*((float*) (fdata + idx*fstep)))) );
1762         *((float*) (wdata + idx*wstep)) = MAX( p * (1.0F - p), CV_LB_WEIGHT_THRESHOLD );
1763         if( *((float*) (ydata + idx*ystep)) == 1.0F )
1764         {
1765             *((float*) (traindata + idx*trainstep)) =
1766                 1.0F / (MAX( p, CV_LB_PROB_THRESH ));
1767         }
1768         else
1769         {
1770             *((float*) (traindata + idx*trainstep)) =
1771                 -1.0F / (MAX( 1.0F - p, CV_LB_PROB_THRESH ));
1772         }
1773     }
1774 }
1775
1776 static
1777 CvBoostTrainer* icvBoostStartTrainingLB( CvMat* trainClasses,
1778                                          CvMat* weakTrainVals,
1779                                          CvMat* weights,
1780                                          CvMat* sampleIdx,
1781                                          CvBoostType type )
1782 {
1783     size_t datasize;
1784     CvBoostTrainer* ptr;
1785
1786     uchar* ydata;
1787     int ystep;
1788     int m;
1789     uchar* traindata;
1790     int trainstep;
1791     int trainnum;
1792     uchar* wdata;
1793     int wstep;
1794     int wnum;
1795     int i;
1796
1797     int idxnum;
1798     int idxstep;
1799     uchar* idxdata;
1800
1801     assert( trainClasses != NULL );
1802     assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1803     assert( weakTrainVals != NULL );
1804     assert( CV_MAT_TYPE( weakTrainVals->type ) == CV_32FC1 );
1805     assert( weights != NULL );
1806     assert( CV_MAT_TYPE( weights->type ) == CV_32FC1 );
1807
1808     CV_MAT2VEC( *trainClasses, ydata, ystep, m );
1809     CV_MAT2VEC( *weakTrainVals, traindata, trainstep, trainnum );
1810     CV_MAT2VEC( *weights, wdata, wstep, wnum );
1811
1812     CV_Assert( m == trainnum );
1813     CV_Assert( m == wnum );
1814
1815
1816     idxnum = 0;
1817     idxstep = 0;
1818     idxdata = NULL;
1819     if( sampleIdx )
1820     {
1821         CV_MAT2VEC( *sampleIdx, idxdata, idxstep, idxnum );
1822     }
1823
1824     datasize = sizeof( *ptr ) + sizeof( *ptr->F ) * m + sizeof( *ptr->idx ) * idxnum;
1825     ptr = (CvBoostTrainer*) cvAlloc( datasize );
1826     memset( ptr, 0, datasize );
1827     ptr->F = (float*) (ptr + 1);
1828     ptr->idx = NULL;
1829
1830     ptr->count = m;
1831     ptr->type = type;
1832
1833     if( idxnum > 0 )
1834     {
1835         CvScalar s;
1836
1837         ptr->idx = (int*) (ptr->F + m);
1838         ptr->count = idxnum;
1839         for( i = 0; i < ptr->count; i++ )
1840         {
1841             cvRawDataToScalar( idxdata + i*idxstep, CV_MAT_TYPE( sampleIdx->type ), &s );
1842             ptr->idx[i] = (int) s.val[0];
1843         }
1844     }
1845
1846     for( i = 0; i < m; i++ )
1847     {
1848         ptr->F[i] = 0.0F;
1849     }
1850
1851     icvResponsesAndWeightsLB( ptr->count, wdata, wstep, ydata, ystep,
1852                               (uchar*) ptr->F, sizeof( *ptr->F ),
1853                               traindata, trainstep, ptr->idx );
1854
1855     return ptr;
1856 }
1857
1858 static
1859 float icvBoostNextWeakClassifierLB( CvMat* weakEvalVals,
1860                                     CvMat* trainClasses,
1861                                     CvMat* weakTrainVals,
1862                                     CvMat* weights,
1863                                     CvBoostTrainer* trainer )
1864 {
1865     uchar* evaldata;
1866     int evalstep;
1867     int m;
1868     uchar* ydata;
1869     int ystep;
1870     int ynum;
1871     uchar* traindata;
1872     int trainstep;
1873     int trainnum;
1874     uchar* wdata;
1875     int wstep;
1876     int wnum;
1877     int i, idx;
1878
1879     assert( weakEvalVals != NULL );
1880     assert( CV_MAT_TYPE( weakEvalVals->type ) == CV_32FC1 );
1881     assert( trainClasses != NULL );
1882     assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1883     assert( weakTrainVals != NULL );
1884     assert( CV_MAT_TYPE( weakTrainVals->type ) == CV_32FC1 );
1885     assert( weights != NULL );
1886     assert( CV_MAT_TYPE( weights ->type ) == CV_32FC1 );
1887
1888     CV_MAT2VEC( *weakEvalVals, evaldata, evalstep, m );
1889     CV_MAT2VEC( *trainClasses, ydata, ystep, ynum );
1890     CV_MAT2VEC( *weakTrainVals, traindata, trainstep, trainnum );
1891     CV_MAT2VEC( *weights, wdata, wstep, wnum );
1892
1893     CV_Assert( m == ynum );
1894     CV_Assert( m == wnum );
1895     CV_Assert( m == trainnum );
1896     //assert( m == trainer->count );
1897
1898     for( i = 0; i < trainer->count; i++ )
1899     {
1900         idx = (trainer->idx) ? trainer->idx[i] : i;
1901
1902         trainer->F[idx] += *((float*) (evaldata + idx * evalstep));
1903     }
1904
1905     icvResponsesAndWeightsLB( trainer->count, wdata, wstep, ydata, ystep,
1906                               (uchar*) trainer->F, sizeof( *trainer->F ),
1907                               traindata, trainstep, trainer->idx );
1908
1909     return 1.0F;
1910 }
1911
1912 /*
1913  *
1914  * Gentle AdaBoost
1915  *
1916  */
1917 static
1918 float icvBoostNextWeakClassifierGAB( CvMat* weakEvalVals,
1919                                      CvMat* trainClasses,
1920                                      CvMat* /*weakTrainVals*/,
1921                                      CvMat* weights,
1922                                      CvBoostTrainer* trainer )
1923 {
1924     uchar* evaldata;
1925     int evalstep;
1926     int m;
1927     uchar* ydata;
1928     int ystep;
1929     int ynum;
1930     uchar* wdata;
1931     int wstep;
1932     int wnum;
1933
1934     int i, idx;
1935     float sumw;
1936
1937     CV_Assert( weakEvalVals != NULL );
1938     CV_Assert( CV_MAT_TYPE( weakEvalVals->type ) == CV_32FC1 );
1939     CV_Assert( trainClasses != NULL );
1940     CV_Assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1941     CV_Assert( weights != NULL );
1942     CV_Assert( CV_MAT_TYPE( weights->type ) == CV_32FC1 );
1943
1944     CV_MAT2VEC( *weakEvalVals, evaldata, evalstep, m );
1945     CV_MAT2VEC( *trainClasses, ydata, ystep, ynum );
1946     CV_MAT2VEC( *weights, wdata, wstep, wnum );
1947
1948     CV_Assert( m == ynum );
1949     CV_Assert( m == wnum );
1950
1951     sumw = 0.0F;
1952     for( i = 0; i < trainer->count; i++ )
1953     {
1954         idx = (trainer->idx) ? trainer->idx[i] : i;
1955
1956         *((float*) (wdata + idx*wstep)) *=
1957             expf( -(*((float*) (evaldata + idx*evalstep)))
1958                   * ( 2.0F * (*((float*) (ydata + idx*ystep))) - 1.0F ) );
1959         sumw += *((float*) (wdata + idx*wstep));
1960     }
1961
1962     for( i = 0; i < trainer->count; i++ )
1963     {
1964         idx = (trainer->idx) ? trainer->idx[i] : i;
1965
1966         *((float*) (wdata + idx*wstep)) /= sumw;
1967     }
1968
1969     return 1.0F;
1970 }
1971
1972 typedef CvBoostTrainer* (*CvBoostStartTraining)( CvMat* trainClasses,
1973                                                  CvMat* weakTrainVals,
1974                                                  CvMat* weights,
1975                                                  CvMat* sampleIdx,
1976                                                  CvBoostType type );
1977
1978 typedef float (*CvBoostNextWeakClassifier)( CvMat* weakEvalVals,
1979                                             CvMat* trainClasses,
1980                                             CvMat* weakTrainVals,
1981                                             CvMat* weights,
1982                                             CvBoostTrainer* data );
1983
1984 CvBoostStartTraining startTraining[4] = {
1985         icvBoostStartTraining,
1986         icvBoostStartTraining,
1987         icvBoostStartTrainingLB,
1988         icvBoostStartTraining
1989     };
1990
1991 CvBoostNextWeakClassifier nextWeakClassifier[4] = {
1992         icvBoostNextWeakClassifierDAB,
1993         icvBoostNextWeakClassifierRAB,
1994         icvBoostNextWeakClassifierLB,
1995         icvBoostNextWeakClassifierGAB
1996     };
1997
1998 /*
1999  *
2000  * Dispatchers
2001  *
2002  */
2003 CV_BOOST_IMPL
2004 CvBoostTrainer* cvBoostStartTraining( CvMat* trainClasses,
2005                                       CvMat* weakTrainVals,
2006                                       CvMat* weights,
2007                                       CvMat* sampleIdx,
2008                                       CvBoostType type )
2009 {
2010     return startTraining[type]( trainClasses, weakTrainVals, weights, sampleIdx, type );
2011 }
2012
2013 CV_BOOST_IMPL
2014 void cvBoostEndTraining( CvBoostTrainer** trainer )
2015 {
2016     cvFree( trainer );
2017     *trainer = NULL;
2018 }
2019
2020 CV_BOOST_IMPL
2021 float cvBoostNextWeakClassifier( CvMat* weakEvalVals,
2022                                  CvMat* trainClasses,
2023                                  CvMat* weakTrainVals,
2024                                  CvMat* weights,
2025                                  CvBoostTrainer* trainer )
2026 {
2027     return nextWeakClassifier[trainer->type]( weakEvalVals, trainClasses,
2028         weakTrainVals, weights, trainer    );
2029 }
2030
2031 /****************************************************************************************\
2032 *                                    Boosted tree models                                 *
2033 \****************************************************************************************/
2034
2035 typedef struct CvBtTrainer
2036 {
2037     /* {{ external */
2038     CvMat* trainData;
2039     int flags;
2040
2041     CvMat* trainClasses;
2042     int m;
2043     uchar* ydata;
2044     int ystep;
2045
2046     CvMat* sampleIdx;
2047     int numsamples;
2048
2049     float param[2];
2050     CvBoostType type;
2051     int numclasses;
2052     /* }} external */
2053
2054     CvMTStumpTrainParams stumpParams;
2055     CvCARTTrainParams  cartParams;
2056
2057     float* f;          /* F_(m-1) */
2058     CvMat* y;          /* yhat    */
2059     CvMat* weights;
2060     CvBoostTrainer* boosttrainer;
2061 } CvBtTrainer;
2062
2063 /*
2064  * cvBtStart, cvBtNext, cvBtEnd
2065  *
2066  * These functions perform iterative training of
2067  * 2-class (CV_DABCLASS - CV_GABCLASS, CV_L2CLASS), K-class (CV_LKCLASS) classifier
2068  * or fit regression model (CV_LSREG, CV_LADREG, CV_MREG)
2069  * using decision tree as a weak classifier.
2070  */
2071
2072 typedef void (*CvZeroApproxFunc)( float* approx, CvBtTrainer* trainer );
2073
2074 /* Mean zero approximation */
2075 static void icvZeroApproxMean( float* approx, CvBtTrainer* trainer )
2076 {
2077     int i;
2078     int idx;
2079
2080     approx[0] = 0.0F;
2081     for( i = 0; i < trainer->numsamples; i++ )
2082     {
2083         idx = icvGetIdxAt( trainer->sampleIdx, i );
2084         approx[0] += *((float*) (trainer->ydata + idx * trainer->ystep));
2085     }
2086     approx[0] /= (float) trainer->numsamples;
2087 }
2088
2089 /*
2090  * Median zero approximation
2091  */
2092 static void icvZeroApproxMed( float* approx, CvBtTrainer* trainer )
2093 {
2094     int i;
2095     int idx;
2096
2097     for( i = 0; i < trainer->numsamples; i++ )
2098     {
2099         idx = icvGetIdxAt( trainer->sampleIdx, i );
2100         trainer->f[i] = *((float*) (trainer->ydata + idx * trainer->ystep));
2101     }
2102
2103     std::sort(trainer->f, trainer->f + trainer->numsamples);
2104     approx[0] = trainer->f[trainer->numsamples / 2];
2105 }
2106
2107 /*
2108  * 0.5 * log( mean(y) / (1 - mean(y)) ) where y in {0, 1}
2109  */
2110 static void icvZeroApproxLog( float* approx, CvBtTrainer* trainer )
2111 {
2112     float y_mean;
2113
2114     icvZeroApproxMean( &y_mean, trainer );
2115     approx[0] = 0.5F * cvLogRatio( y_mean );
2116 }
2117
2118 /*
2119  * 0 zero approximation
2120  */
2121 static void icvZeroApprox0( float* approx, CvBtTrainer* trainer )
2122 {
2123     int i;
2124
2125     for( i = 0; i < trainer->numclasses; i++ )
2126     {
2127         approx[i] = 0.0F;
2128     }
2129 }
2130
2131 static CvZeroApproxFunc icvZeroApproxFunc[] =
2132 {
2133     icvZeroApprox0,    /* CV_DABCLASS */
2134     icvZeroApprox0,    /* CV_RABCLASS */
2135     icvZeroApprox0,    /* CV_LBCLASS  */
2136     icvZeroApprox0,    /* CV_GABCLASS */
2137     icvZeroApproxLog,  /* CV_L2CLASS  */
2138     icvZeroApprox0,    /* CV_LKCLASS  */
2139     icvZeroApproxMean, /* CV_LSREG    */
2140     icvZeroApproxMed,  /* CV_LADREG   */
2141     icvZeroApproxMed,  /* CV_MREG     */
2142 };
2143
2144 CV_BOOST_IMPL
2145 void cvBtNext( CvCARTClassifier** trees, CvBtTrainer* trainer );
2146
2147 static
2148 CvBtTrainer* cvBtStart( CvCARTClassifier** trees,
2149                         CvMat* trainData,
2150                         int flags,
2151                         CvMat* trainClasses,
2152                         CvMat* sampleIdx,
2153                         int numsplits,
2154                         CvBoostType type,
2155                         int numclasses,
2156                         float* param )
2157 {
2158     CvBtTrainer* ptr = 0;
2159
2160     CV_FUNCNAME( "cvBtStart" );
2161
2162     __BEGIN__;
2163
2164     size_t data_size;
2165     float* zero_approx;
2166     int m;
2167     int i, j;
2168
2169     if( trees == NULL )
2170     {
2171         CV_ERROR( CV_StsNullPtr, "Invalid trees parameter" );
2172     }
2173
2174     if( type < CV_DABCLASS || type > CV_MREG )
2175     {
2176         CV_ERROR( CV_StsUnsupportedFormat, "Unsupported type parameter" );
2177     }
2178     if( type == CV_LKCLASS )
2179     {
2180         CV_ASSERT( numclasses >= 2 );
2181     }
2182     else
2183     {
2184         numclasses = 1;
2185     }
2186
2187     m = MAX( trainClasses->rows, trainClasses->cols );
2188     ptr = NULL;
2189     data_size = sizeof( *ptr );
2190     if( type > CV_GABCLASS )
2191     {
2192         data_size += m * numclasses * sizeof( *(ptr->f) );
2193     }
2194     CV_CALL( ptr = (CvBtTrainer*) cvAlloc( data_size ) );
2195     memset( ptr, 0, data_size );
2196     ptr->f = (float*) (ptr + 1);
2197
2198     ptr->trainData = trainData;
2199     ptr->flags = flags;
2200     ptr->trainClasses = trainClasses;
2201     CV_MAT2VEC( *trainClasses, ptr->ydata, ptr->ystep, ptr->m );
2202
2203     memset( &(ptr->cartParams), 0, sizeof( ptr->cartParams ) );
2204     memset( &(ptr->stumpParams), 0, sizeof( ptr->stumpParams ) );
2205
2206     switch( type )
2207     {
2208         case CV_DABCLASS:
2209             ptr->stumpParams.error = CV_MISCLASSIFICATION;
2210             ptr->stumpParams.type  = CV_CLASSIFICATION_CLASS;
2211             break;
2212         case CV_RABCLASS:
2213             ptr->stumpParams.error = CV_GINI;
2214             ptr->stumpParams.type  = CV_CLASSIFICATION;
2215             break;
2216         default:
2217             ptr->stumpParams.error = CV_SQUARE;
2218             ptr->stumpParams.type  = CV_REGRESSION;
2219     }
2220     ptr->cartParams.count = numsplits;
2221     ptr->cartParams.stumpTrainParams = (CvClassifierTrainParams*) &(ptr->stumpParams);
2222     ptr->cartParams.stumpConstructor = cvCreateMTStumpClassifier;
2223
2224     ptr->param[0] = param[0];
2225     ptr->param[1] = param[1];
2226     ptr->type = type;
2227     ptr->numclasses = numclasses;
2228
2229     CV_CALL( ptr->y = cvCreateMat( 1, m, CV_32FC1 ) );
2230     ptr->sampleIdx = sampleIdx;
2231     ptr->numsamples = ( sampleIdx == NULL ) ? ptr->m
2232                              : MAX( sampleIdx->rows, sampleIdx->cols );
2233
2234     ptr->weights = cvCreateMat( 1, m, CV_32FC1 );
2235     cvSet( ptr->weights, cvScalar( 1.0 ) );
2236
2237     if( type <= CV_GABCLASS )
2238     {
2239         ptr->boosttrainer = cvBoostStartTraining( ptr->trainClasses, ptr->y,
2240             ptr->weights, NULL, type );
2241
2242         CV_CALL( cvBtNext( trees, ptr ) );
2243     }
2244     else
2245     {
2246         data_size = sizeof( *zero_approx ) * numclasses;
2247         CV_CALL( zero_approx = (float*) cvAlloc( data_size ) );
2248         icvZeroApproxFunc[type]( zero_approx, ptr );
2249         for( i = 0; i < m; i++ )
2250         {
2251             for( j = 0; j < numclasses; j++ )
2252             {
2253                 ptr->f[i * numclasses + j] = zero_approx[j];
2254             }
2255         }
2256
2257         CV_CALL( cvBtNext( trees, ptr ) );
2258
2259         for( i = 0; i < numclasses; i++ )
2260         {
2261             for( j = 0; j <= trees[i]->count; j++ )
2262             {
2263                 trees[i]->val[j] += zero_approx[i];
2264             }
2265         }
2266         CV_CALL( cvFree( &zero_approx ) );
2267     }
2268
2269     __END__;
2270
2271     return ptr;
2272 }
2273
2274 static void icvBtNext_LSREG( CvCARTClassifier** trees, CvBtTrainer* trainer )
2275 {
2276     int i;
2277
2278     /* yhat_i = y_i - F_(m-1)(x_i) */
2279     for( i = 0; i < trainer->m; i++ )
2280     {
2281         trainer->y->data.fl[i] =
2282             *((float*) (trainer->ydata + i * trainer->ystep)) - trainer->f[i];
2283     }
2284
2285     trees[0] = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData,
2286         trainer->flags,
2287         trainer->y, NULL, NULL, NULL, trainer->sampleIdx, trainer->weights,
2288         (CvClassifierTrainParams*) &trainer->cartParams );
2289 }
2290
2291
2292 static void icvBtNext_LADREG( CvCARTClassifier** trees, CvBtTrainer* trainer )
2293 {
2294     CvCARTClassifier* ptr;
2295     int i, j;
2296     CvMat sample;
2297     int sample_step;
2298     uchar* sample_data;
2299     int index;
2300
2301     int data_size;
2302     int* idx;
2303     float* resp;
2304     int respnum;
2305     float val;
2306
2307     data_size = trainer->m * sizeof( *idx );
2308     idx = (int*) cvAlloc( data_size );
2309     data_size = trainer->m * sizeof( *resp );
2310     resp = (float*) cvAlloc( data_size );
2311
2312     /* yhat_i = sign(y_i - F_(m-1)(x_i)) */
2313     for( i = 0; i < trainer->numsamples; i++ )
2314     {
2315         index = icvGetIdxAt( trainer->sampleIdx, i );
2316         trainer->y->data.fl[index] = (float)
2317              CV_SIGN( *((float*) (trainer->ydata + index * trainer->ystep))
2318                      - trainer->f[index] );
2319     }
2320
2321     ptr = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData, trainer->flags,
2322         trainer->y, NULL, NULL, NULL, trainer->sampleIdx, trainer->weights,
2323         (CvClassifierTrainParams*) &trainer->cartParams );
2324
2325     CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2326     CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2327     sample_data = sample.data.ptr;
2328     for( i = 0; i < trainer->numsamples; i++ )
2329     {
2330         index = icvGetIdxAt( trainer->sampleIdx, i );
2331         sample.data.ptr = sample_data + index * sample_step;
2332         idx[index] = (int) cvEvalCARTClassifierIdx( (CvClassifier*) ptr, &sample );
2333     }
2334     for( j = 0; j <= ptr->count; j++ )
2335     {
2336         respnum = 0;
2337         for( i = 0; i < trainer->numsamples; i++ )
2338         {
2339             index = icvGetIdxAt( trainer->sampleIdx, i );
2340             if( idx[index] == j )
2341             {
2342                 resp[respnum++] = *((float*) (trainer->ydata + index * trainer->ystep))
2343                                   - trainer->f[index];
2344             }
2345         }
2346         if( respnum > 0 )
2347         {
2348             std::sort(resp, resp + respnum);
2349             val = resp[respnum / 2];
2350         }
2351         else
2352         {
2353             val = 0.0F;
2354         }
2355         ptr->val[j] = val;
2356     }
2357
2358     cvFree( &idx );
2359     cvFree( &resp );
2360
2361     trees[0] = ptr;
2362 }
2363
2364
2365 static void icvBtNext_MREG( CvCARTClassifier** trees, CvBtTrainer* trainer )
2366 {
2367     CvCARTClassifier* ptr;
2368     int i, j;
2369     CvMat sample;
2370     int sample_step;
2371     uchar* sample_data;
2372
2373     int data_size;
2374     int* idx;
2375     float* resid;
2376     float* resp;
2377     int respnum;
2378     float rhat;
2379     float val;
2380     float delta;
2381     int index;
2382
2383     data_size = trainer->m * sizeof( *idx );
2384     idx = (int*) cvAlloc( data_size );
2385     data_size = trainer->m * sizeof( *resp );
2386     resp = (float*) cvAlloc( data_size );
2387     data_size = trainer->m * sizeof( *resid );
2388     resid = (float*) cvAlloc( data_size );
2389
2390     /* resid_i = (y_i - F_(m-1)(x_i)) */
2391     for( i = 0; i < trainer->numsamples; i++ )
2392     {
2393         index = icvGetIdxAt( trainer->sampleIdx, i );
2394         resid[index] = *((float*) (trainer->ydata + index * trainer->ystep))
2395                        - trainer->f[index];
2396         /* for delta */
2397         resp[i] = (float) fabs( resid[index] );
2398     }
2399
2400     /* delta = quantile_alpha{abs(resid_i)} */
2401     std::sort(resp, resp + trainer->numsamples);
2402     delta = resp[(int)(trainer->param[1] * (trainer->numsamples - 1))];
2403
2404     /* yhat_i */
2405     for( i = 0; i < trainer->numsamples; i++ )
2406     {
2407         index = icvGetIdxAt( trainer->sampleIdx, i );
2408         trainer->y->data.fl[index] = MIN( delta, ((float) fabs( resid[index] )) ) *
2409                                  CV_SIGN( resid[index] );
2410     }
2411
2412     ptr = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData, trainer->flags,
2413         trainer->y, NULL, NULL, NULL, trainer->sampleIdx, trainer->weights,
2414         (CvClassifierTrainParams*) &trainer->cartParams );
2415
2416     CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2417     CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2418     sample_data = sample.data.ptr;
2419     for( i = 0; i < trainer->numsamples; i++ )
2420     {
2421         index = icvGetIdxAt( trainer->sampleIdx, i );
2422         sample.data.ptr = sample_data + index * sample_step;
2423         idx[index] = (int) cvEvalCARTClassifierIdx( (CvClassifier*) ptr, &sample );
2424     }
2425     for( j = 0; j <= ptr->count; j++ )
2426     {
2427         respnum = 0;
2428
2429         for( i = 0; i < trainer->numsamples; i++ )
2430         {
2431             index = icvGetIdxAt( trainer->sampleIdx, i );
2432             if( idx[index] == j )
2433             {
2434                 resp[respnum++] = *((float*) (trainer->ydata + index * trainer->ystep))
2435                                   - trainer->f[index];
2436             }
2437         }
2438         if( respnum > 0 )
2439         {
2440             /* rhat = median(y_i - F_(m-1)(x_i)) */
2441             std::sort(resp, resp + respnum);
2442             rhat = resp[respnum / 2];
2443
2444             /* val = sum{sign(r_i - rhat_i) * min(delta, abs(r_i - rhat_i)}
2445              * r_i = y_i - F_(m-1)(x_i)
2446              */
2447             val = 0.0F;
2448             for( i = 0; i < respnum; i++ )
2449             {
2450                 val += CV_SIGN( resp[i] - rhat )
2451                        * MIN( delta, (float) fabs( resp[i] - rhat ) );
2452             }
2453
2454             val = rhat + val / (float) respnum;
2455         }
2456         else
2457         {
2458             val = 0.0F;
2459         }
2460
2461         ptr->val[j] = val;
2462
2463     }
2464
2465     cvFree( &resid );
2466     cvFree( &resp );
2467     cvFree( &idx );
2468
2469     trees[0] = ptr;
2470 }
2471
2472 //#define CV_VAL_MAX 1e304
2473
2474 //#define CV_LOG_VAL_MAX 700.0
2475
2476 #define CV_VAL_MAX 1e+8
2477
2478 #define CV_LOG_VAL_MAX 18.0
2479
2480 static void icvBtNext_L2CLASS( CvCARTClassifier** trees, CvBtTrainer* trainer )
2481 {
2482     CvCARTClassifier* ptr;
2483     int i, j;
2484     CvMat sample;
2485     int sample_step;
2486     uchar* sample_data;
2487
2488     int data_size;
2489     int* idx;
2490     int respnum;
2491     float val;
2492     double val_f;
2493
2494     float sum_weights;
2495     float* weights;
2496     float* sorted_weights;
2497     CvMat* trimmed_idx;
2498     CvMat* sample_idx;
2499     int index;
2500     int trimmed_num;
2501
2502     data_size = trainer->m * sizeof( *idx );
2503     idx = (int*) cvAlloc( data_size );
2504
2505     data_size = trainer->m * sizeof( *weights );
2506     weights = (float*) cvAlloc( data_size );
2507     data_size = trainer->m * sizeof( *sorted_weights );
2508     sorted_weights = (float*) cvAlloc( data_size );
2509
2510     /* yhat_i = (4 * y_i - 2) / ( 1 + exp( (4 * y_i - 2) * F_(m-1)(x_i) ) ).
2511      *   y_i in {0, 1}
2512      */
2513     sum_weights = 0.0F;
2514     for( i = 0; i < trainer->numsamples; i++ )
2515     {
2516         index = icvGetIdxAt( trainer->sampleIdx, i );
2517         val = 4.0F * (*((float*) (trainer->ydata + index * trainer->ystep))) - 2.0F;
2518         val_f = val * trainer->f[index];
2519         val_f = ( val_f < CV_LOG_VAL_MAX ) ? exp( val_f ) : CV_LOG_VAL_MAX;
2520         val = (float) ( (double) val / ( 1.0 + val_f ) );
2521         trainer->y->data.fl[index] = val;
2522         val = (float) fabs( val );
2523         weights[index] = val * (2.0F - val);
2524         sorted_weights[i] = weights[index];
2525         sum_weights += sorted_weights[i];
2526     }
2527
2528     trimmed_idx = NULL;
2529     sample_idx = trainer->sampleIdx;
2530     trimmed_num = trainer->numsamples;
2531     if( trainer->param[1] < 1.0F )
2532     {
2533         /* perform weight trimming */
2534
2535         float threshold;
2536         int count;
2537
2538         std::sort(sorted_weights, sorted_weights + trainer->numsamples);
2539
2540         sum_weights *= (1.0F - trainer->param[1]);
2541
2542         i = -1;
2543         do { sum_weights -= sorted_weights[++i]; }
2544         while( sum_weights > 0.0F && i < (trainer->numsamples - 1) );
2545
2546         threshold = sorted_weights[i];
2547
2548         while( i > 0 && sorted_weights[i-1] == threshold ) i--;
2549
2550         if( i > 0 )
2551         {
2552             trimmed_num = trainer->numsamples - i;
2553             trimmed_idx = cvCreateMat( 1, trimmed_num, CV_32FC1 );
2554             count = 0;
2555             for( i = 0; i < trainer->numsamples; i++ )
2556             {
2557                 index = icvGetIdxAt( trainer->sampleIdx, i );
2558                 if( weights[index] >= threshold )
2559                 {
2560                     CV_MAT_ELEM( *trimmed_idx, float, 0, count ) = (float) index;
2561                     count++;
2562                 }
2563             }
2564
2565             assert( count == trimmed_num );
2566
2567             sample_idx = trimmed_idx;
2568
2569             printf( "Used samples %%: %g\n",
2570                 (float) trimmed_num / (float) trainer->numsamples * 100.0F );
2571         }
2572     }
2573
2574     ptr = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData, trainer->flags,
2575         trainer->y, NULL, NULL, NULL, sample_idx, trainer->weights,
2576         (CvClassifierTrainParams*) &trainer->cartParams );
2577
2578     CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2579     CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2580     sample_data = sample.data.ptr;
2581     for( i = 0; i < trimmed_num; i++ )
2582     {
2583         index = icvGetIdxAt( sample_idx, i );
2584         sample.data.ptr = sample_data + index * sample_step;
2585         idx[index] = (int) cvEvalCARTClassifierIdx( (CvClassifier*) ptr, &sample );
2586     }
2587     for( j = 0; j <= ptr->count; j++ )
2588     {
2589         respnum = 0;
2590         val = 0.0F;
2591         sum_weights = 0.0F;
2592         for( i = 0; i < trimmed_num; i++ )
2593         {
2594             index = icvGetIdxAt( sample_idx, i );
2595             if( idx[index] == j )
2596             {
2597                 val += trainer->y->data.fl[index];
2598                 sum_weights += weights[index];
2599                 respnum++;
2600             }
2601         }
2602         if( sum_weights > 0.0F )
2603         {
2604             val /= sum_weights;
2605         }
2606         else
2607         {
2608             val = 0.0F;
2609         }
2610         ptr->val[j] = val;
2611     }
2612
2613     if( trimmed_idx != NULL ) cvReleaseMat( &trimmed_idx );
2614     cvFree( &sorted_weights );
2615     cvFree( &weights );
2616     cvFree( &idx );
2617
2618     trees[0] = ptr;
2619 }
2620
2621 static void icvBtNext_LKCLASS( CvCARTClassifier** trees, CvBtTrainer* trainer )
2622 {
2623     int i, j, k, kk, num;
2624     CvMat sample;
2625     int sample_step;
2626     uchar* sample_data;
2627
2628     int data_size;
2629     int* idx;
2630     int respnum;
2631     float val;
2632
2633     float sum_weights;
2634     float* weights;
2635     float* sorted_weights;
2636     CvMat* trimmed_idx;
2637     CvMat* sample_idx;
2638     int index;
2639     int trimmed_num;
2640     double sum_exp_f;
2641     double exp_f;
2642     double f_k;
2643
2644     data_size = trainer->m * sizeof( *idx );
2645     idx = (int*) cvAlloc( data_size );
2646     data_size = trainer->m * sizeof( *weights );
2647     weights = (float*) cvAlloc( data_size );
2648     data_size = trainer->m * sizeof( *sorted_weights );
2649     sorted_weights = (float*) cvAlloc( data_size );
2650     trimmed_idx = cvCreateMat( 1, trainer->numsamples, CV_32FC1 );
2651
2652     for( k = 0; k < trainer->numclasses; k++ )
2653     {
2654         /* yhat_i = y_i - p_k(x_i), y_i in {0, 1}      */
2655         /* p_k(x_i) = exp(f_k(x_i)) / (sum_exp_f(x_i)) */
2656         sum_weights = 0.0F;
2657         for( i = 0; i < trainer->numsamples; i++ )
2658         {
2659             index = icvGetIdxAt( trainer->sampleIdx, i );
2660             /* p_k(x_i) = 1 / (1 + sum(exp(f_kk(x_i) - f_k(x_i)))), kk != k */
2661             num = index * trainer->numclasses;
2662             f_k = (double) trainer->f[num + k];
2663             sum_exp_f = 1.0;
2664             for( kk = 0; kk < trainer->numclasses; kk++ )
2665             {
2666                 if( kk == k ) continue;
2667                 exp_f = (double) trainer->f[num + kk] - f_k;
2668                 exp_f = (exp_f < CV_LOG_VAL_MAX) ? exp( exp_f ) : CV_VAL_MAX;
2669                 if( exp_f == CV_VAL_MAX || exp_f >= (CV_VAL_MAX - sum_exp_f) )
2670                 {
2671                     sum_exp_f = CV_VAL_MAX;
2672                     break;
2673                 }
2674                 sum_exp_f += exp_f;
2675             }
2676
2677             val = (float) ( (*((float*) (trainer->ydata + index * trainer->ystep)))
2678                             == (float) k );
2679             val -= (float) ( (sum_exp_f == CV_VAL_MAX) ? 0.0 : ( 1.0 / sum_exp_f ) );
2680
2681             assert( val >= -1.0F );
2682             assert( val <= 1.0F );
2683
2684             trainer->y->data.fl[index] = val;
2685             val = (float) fabs( val );
2686             weights[index] = val * (1.0F - val);
2687             sorted_weights[i] = weights[index];
2688             sum_weights += sorted_weights[i];
2689         }
2690
2691         sample_idx = trainer->sampleIdx;
2692         trimmed_num = trainer->numsamples;
2693         if( trainer->param[1] < 1.0F )
2694         {
2695             /* perform weight trimming */
2696
2697             float threshold;
2698             int count;
2699
2700             std::sort(sorted_weights, sorted_weights + trainer->numsamples);
2701
2702             sum_weights *= (1.0F - trainer->param[1]);
2703
2704             i = -1;
2705             do { sum_weights -= sorted_weights[++i]; }
2706             while( sum_weights > 0.0F && i < (trainer->numsamples - 1) );
2707
2708             threshold = sorted_weights[i];
2709
2710             while( i > 0 && sorted_weights[i-1] == threshold ) i--;
2711
2712             if( i > 0 )
2713             {
2714                 trimmed_num = trainer->numsamples - i;
2715                 trimmed_idx->cols = trimmed_num;
2716                 count = 0;
2717                 for( i = 0; i < trainer->numsamples; i++ )
2718                 {
2719                     index = icvGetIdxAt( trainer->sampleIdx, i );
2720                     if( weights[index] >= threshold )
2721                     {
2722                         CV_MAT_ELEM( *trimmed_idx, float, 0, count ) = (float) index;
2723                         count++;
2724                     }
2725                 }
2726
2727                 assert( count == trimmed_num );
2728
2729                 sample_idx = trimmed_idx;
2730
2731                 printf( "k: %d Used samples %%: %g\n", k,
2732                     (float) trimmed_num / (float) trainer->numsamples * 100.0F );
2733             }
2734         } /* weight trimming */
2735
2736         trees[k] = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData,
2737             trainer->flags, trainer->y, NULL, NULL, NULL, sample_idx, trainer->weights,
2738             (CvClassifierTrainParams*) &trainer->cartParams );
2739
2740         CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2741         CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2742         sample_data = sample.data.ptr;
2743         for( i = 0; i < trimmed_num; i++ )
2744         {
2745             index = icvGetIdxAt( sample_idx, i );
2746             sample.data.ptr = sample_data + index * sample_step;
2747             idx[index] = (int) cvEvalCARTClassifierIdx( (CvClassifier*) trees[k],
2748                                                         &sample );
2749         }
2750         for( j = 0; j <= trees[k]->count; j++ )
2751         {
2752             respnum = 0;
2753             val = 0.0F;
2754             sum_weights = 0.0F;
2755             for( i = 0; i < trimmed_num; i++ )
2756             {
2757                 index = icvGetIdxAt( sample_idx, i );
2758                 if( idx[index] == j )
2759                 {
2760                     val += trainer->y->data.fl[index];
2761                     sum_weights += weights[index];
2762                     respnum++;
2763                 }
2764             }
2765             if( sum_weights > 0.0F )
2766             {
2767                 val = ((float) (trainer->numclasses - 1)) * val /
2768                       ((float) (trainer->numclasses)) / sum_weights;
2769             }
2770             else
2771             {
2772                 val = 0.0F;
2773             }
2774             trees[k]->val[j] = val;
2775         }
2776     } /* for each class */
2777
2778     cvReleaseMat( &trimmed_idx );
2779     cvFree( &sorted_weights );
2780     cvFree( &weights );
2781     cvFree( &idx );
2782 }
2783
2784
2785 static void icvBtNext_XXBCLASS( CvCARTClassifier** trees, CvBtTrainer* trainer )
2786 {
2787     float alpha;
2788     int i;
2789     CvMat* weak_eval_vals;
2790     CvMat* sample_idx;
2791     int num_samples;
2792     CvMat sample;
2793     uchar* sample_data;
2794     int sample_step;
2795
2796     weak_eval_vals = cvCreateMat( 1, trainer->m, CV_32FC1 );
2797
2798     sample_idx = cvTrimWeights( trainer->weights, trainer->sampleIdx,
2799                                 trainer->param[1] );
2800     num_samples = ( sample_idx == NULL )
2801         ? trainer->m : MAX( sample_idx->rows, sample_idx->cols );
2802
2803     printf( "Used samples %%: %g\n",
2804         (float) num_samples / (float) trainer->numsamples * 100.0F );
2805
2806     trees[0] = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData,
2807         trainer->flags, trainer->y, NULL, NULL, NULL,
2808         sample_idx, trainer->weights,
2809         (CvClassifierTrainParams*) &trainer->cartParams );
2810
2811     /* evaluate samples */
2812     CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2813     CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2814     sample_data = sample.data.ptr;
2815
2816     for( i = 0; i < trainer->m; i++ )
2817     {
2818         sample.data.ptr = sample_data + i * sample_step;
2819         weak_eval_vals->data.fl[i] = trees[0]->eval( (CvClassifier*) trees[0], &sample );
2820     }
2821
2822     alpha = cvBoostNextWeakClassifier( weak_eval_vals, trainer->trainClasses,
2823         trainer->y, trainer->weights, trainer->boosttrainer );
2824
2825     /* multiply tree by alpha */
2826     for( i = 0; i <= trees[0]->count; i++ )
2827     {
2828         trees[0]->val[i] *= alpha;
2829     }
2830     if( trainer->type == CV_RABCLASS )
2831     {
2832         for( i = 0; i <= trees[0]->count; i++ )
2833         {
2834             trees[0]->val[i] = cvLogRatio( trees[0]->val[i] );
2835         }
2836     }
2837
2838     if( sample_idx != NULL && sample_idx != trainer->sampleIdx )
2839     {
2840         cvReleaseMat( &sample_idx );
2841     }
2842     cvReleaseMat( &weak_eval_vals );
2843 }
2844
2845 typedef void (*CvBtNextFunc)( CvCARTClassifier** trees, CvBtTrainer* trainer );
2846
2847 static CvBtNextFunc icvBtNextFunc[] =
2848 {
2849     icvBtNext_XXBCLASS,
2850     icvBtNext_XXBCLASS,
2851     icvBtNext_XXBCLASS,
2852     icvBtNext_XXBCLASS,
2853     icvBtNext_L2CLASS,
2854     icvBtNext_LKCLASS,
2855     icvBtNext_LSREG,
2856     icvBtNext_LADREG,
2857     icvBtNext_MREG
2858 };
2859
2860 CV_BOOST_IMPL
2861 void cvBtNext( CvCARTClassifier** trees, CvBtTrainer* trainer )
2862 {
2863     int i, j;
2864     int index;
2865     CvMat sample;
2866     int sample_step;
2867     uchar* sample_data;
2868
2869     icvBtNextFunc[trainer->type]( trees, trainer );
2870
2871     /* shrinkage */
2872     if( trainer->param[0] != 1.0F )
2873     {
2874         for( j = 0; j < trainer->numclasses; j++ )
2875         {
2876             for( i = 0; i <= trees[j]->count; i++ )
2877             {
2878                 trees[j]->val[i] *= trainer->param[0];
2879             }
2880         }
2881     }
2882
2883     if( trainer->type > CV_GABCLASS )
2884     {
2885         /* update F_(m-1) */
2886         CV_GET_SAMPLE( *(trainer->trainData), trainer->flags, 0, sample );
2887         CV_GET_SAMPLE_STEP( *(trainer->trainData), trainer->flags, sample_step );
2888         sample_data = sample.data.ptr;
2889         for( i = 0; i < trainer->numsamples; i++ )
2890         {
2891             index = icvGetIdxAt( trainer->sampleIdx, i );
2892             sample.data.ptr = sample_data + index * sample_step;
2893             for( j = 0; j < trainer->numclasses; j++ )
2894             {
2895                 trainer->f[index * trainer->numclasses + j] +=
2896                     trees[j]->eval( (CvClassifier*) (trees[j]), &sample );
2897             }
2898         }
2899     }
2900 }
2901
2902 static
2903 void cvBtEnd( CvBtTrainer** trainer )
2904 {
2905     CV_FUNCNAME( "cvBtEnd" );
2906
2907     __BEGIN__;
2908
2909     if( trainer == NULL || (*trainer) == NULL )
2910     {
2911         CV_ERROR( CV_StsNullPtr, "Invalid trainer parameter" );
2912     }
2913
2914     if( (*trainer)->y != NULL )
2915     {
2916         CV_CALL( cvReleaseMat( &((*trainer)->y) ) );
2917     }
2918     if( (*trainer)->weights != NULL )
2919     {
2920         CV_CALL( cvReleaseMat( &((*trainer)->weights) ) );
2921     }
2922     if( (*trainer)->boosttrainer != NULL )
2923     {
2924         CV_CALL( cvBoostEndTraining( &((*trainer)->boosttrainer) ) );
2925     }
2926     CV_CALL( cvFree( trainer ) );
2927
2928     __END__;
2929 }
2930
2931 /****************************************************************************************\
2932 *                         Boosted tree model as a classifier                             *
2933 \****************************************************************************************/
2934
2935 static
2936 float cvEvalBtClassifier( CvClassifier* classifier, CvMat* sample )
2937 {
2938     float val;
2939
2940     CV_FUNCNAME( "cvEvalBtClassifier" );
2941
2942     __BEGIN__;
2943
2944     int i;
2945
2946     val = 0.0F;
2947     if( CV_IS_TUNABLE( classifier->flags ) )
2948     {
2949         CvSeqReader reader;
2950         CvCARTClassifier* tree;
2951
2952         CV_CALL( cvStartReadSeq( ((CvBtClassifier*) classifier)->seq, &reader ) );
2953         for( i = 0; i < ((CvBtClassifier*) classifier)->numiter; i++ )
2954         {
2955             CV_READ_SEQ_ELEM( tree, reader );
2956             val += tree->eval( (CvClassifier*) tree, sample );
2957         }
2958     }
2959     else
2960     {
2961         CvCARTClassifier** ptree;
2962
2963         ptree = ((CvBtClassifier*) classifier)->trees;
2964         for( i = 0; i < ((CvBtClassifier*) classifier)->numiter; i++ )
2965         {
2966             val += (*ptree)->eval( (CvClassifier*) (*ptree), sample );
2967             ptree++;
2968         }
2969     }
2970
2971     __END__;
2972
2973     return val;
2974 }
2975
2976 static
2977 float cvEvalBtClassifier2( CvClassifier* classifier, CvMat* sample )
2978 {
2979     float val;
2980
2981     CV_FUNCNAME( "cvEvalBtClassifier2" );
2982
2983     __BEGIN__;
2984
2985     CV_CALL( val = cvEvalBtClassifier( classifier, sample ) );
2986
2987     __END__;
2988
2989     return (float) (val >= 0.0F);
2990 }
2991
2992 static
2993 float cvEvalBtClassifierK( CvClassifier* classifier, CvMat* sample )
2994 {
2995     int cls = 0;
2996
2997     CV_FUNCNAME( "cvEvalBtClassifierK" );
2998
2999     __BEGIN__;
3000
3001     int i, k;
3002     float max_val;
3003     int numclasses;
3004
3005     float* vals;
3006     size_t data_size;
3007
3008     numclasses = ((CvBtClassifier*) classifier)->numclasses;
3009     data_size = sizeof( *vals ) * numclasses;
3010     CV_CALL( vals = (float*) cvAlloc( data_size ) );
3011     memset( vals, 0, data_size );
3012
3013     if( CV_IS_TUNABLE( classifier->flags ) )
3014     {
3015         CvSeqReader reader;
3016         CvCARTClassifier* tree;
3017
3018         CV_CALL( cvStartReadSeq( ((CvBtClassifier*) classifier)->seq, &reader ) );
3019         for( i = 0; i < ((CvBtClassifier*) classifier)->numiter; i++ )
3020         {
3021             for( k = 0; k < numclasses; k++ )
3022             {
3023                 CV_READ_SEQ_ELEM( tree, reader );
3024                 vals[k] += tree->eval( (CvClassifier*) tree, sample );
3025             }
3026         }
3027
3028     }
3029     else
3030     {
3031         CvCARTClassifier** ptree;
3032
3033         ptree = ((CvBtClassifier*) classifier)->trees;
3034         for( i = 0; i < ((CvBtClassifier*) classifier)->numiter; i++ )
3035         {
3036             for( k = 0; k < numclasses; k++ )
3037             {
3038                 vals[k] += (*ptree)->eval( (CvClassifier*) (*ptree), sample );
3039                 ptree++;
3040             }
3041         }
3042     }
3043
3044     max_val = vals[cls];
3045     for( k = 1; k < numclasses; k++ )
3046     {
3047         if( vals[k] > max_val )
3048         {
3049             max_val = vals[k];
3050             cls = k;
3051         }
3052     }
3053
3054     CV_CALL( cvFree( &vals ) );
3055
3056     __END__;
3057
3058     return (float) cls;
3059 }
3060
3061 typedef float (*CvEvalBtClassifier)( CvClassifier* classifier, CvMat* sample );
3062
3063 static CvEvalBtClassifier icvEvalBtClassifier[] =
3064 {
3065     cvEvalBtClassifier2,
3066     cvEvalBtClassifier2,
3067     cvEvalBtClassifier2,
3068     cvEvalBtClassifier2,
3069     cvEvalBtClassifier2,
3070     cvEvalBtClassifierK,
3071     cvEvalBtClassifier,
3072     cvEvalBtClassifier,
3073     cvEvalBtClassifier
3074 };
3075
3076 static
3077 int cvSaveBtClassifier( CvClassifier* classifier, const char* filename )
3078 {
3079     CV_FUNCNAME( "cvSaveBtClassifier" );
3080
3081     __BEGIN__;
3082
3083     FILE* file;
3084     int i, j;
3085     CvSeqReader reader;
3086     memset(&reader, 0, sizeof(reader));
3087     CvCARTClassifier* tree;
3088
3089     CV_ASSERT( classifier );
3090     CV_ASSERT( filename );
3091
3092     if( !icvMkDir( filename ) || (file = fopen( filename, "w" )) == 0 )
3093     {
3094         CV_ERROR( CV_StsError, "Unable to create file" );
3095     }
3096
3097     if( CV_IS_TUNABLE( classifier->flags ) )
3098     {
3099         CV_CALL( cvStartReadSeq( ((CvBtClassifier*) classifier)->seq, &reader ) );
3100     }
3101     fprintf( file, "%d %d\n%d\n%d\n", (int) ((CvBtClassifier*) classifier)->type,
3102                                       ((CvBtClassifier*) classifier)->numclasses,
3103                                       ((CvBtClassifier*) classifier)->numfeatures,
3104                                       ((CvBtClassifier*) classifier)->numiter );
3105
3106     for( i = 0; i < ((CvBtClassifier*) classifier)->numclasses *
3107                     ((CvBtClassifier*) classifier)->numiter; i++ )
3108     {
3109         if( CV_IS_TUNABLE( classifier->flags ) )
3110         {
3111             CV_READ_SEQ_ELEM( tree, reader );
3112         }
3113         else
3114         {
3115             tree = ((CvBtClassifier*) classifier)->trees[i];
3116         }
3117
3118         fprintf( file, "%d\n", tree->count );
3119         for( j = 0; j < tree->count; j++ )
3120         {
3121             fprintf( file, "%d %g %d %d\n", tree->compidx[j],
3122                                             tree->threshold[j],
3123                                             tree->left[j],
3124                                             tree->right[j] );
3125         }
3126         for( j = 0; j <= tree->count; j++ )
3127         {
3128             fprintf( file, "%g ", tree->val[j] );
3129         }
3130         fprintf( file, "\n" );
3131     }
3132
3133     fclose( file );
3134
3135     __END__;
3136
3137     return 1;
3138 }
3139
3140
3141 static
3142 void cvReleaseBtClassifier( CvClassifier** ptr )
3143 {
3144     CV_FUNCNAME( "cvReleaseBtClassifier" );
3145
3146     __BEGIN__;
3147
3148     int i;
3149
3150     if( ptr == NULL || *ptr == NULL )
3151     {
3152         CV_ERROR( CV_StsNullPtr, "" );
3153     }
3154     if( CV_IS_TUNABLE( (*ptr)->flags ) )
3155     {
3156         CvSeqReader reader;
3157         CvCARTClassifier* tree;
3158
3159         CV_CALL( cvStartReadSeq( ((CvBtClassifier*) *ptr)->seq, &reader ) );
3160         for( i = 0; i < ((CvBtClassifier*) *ptr)->numclasses *
3161                         ((CvBtClassifier*) *ptr)->numiter; i++ )
3162         {
3163             CV_READ_SEQ_ELEM( tree, reader );
3164             tree->release( (CvClassifier**) (&tree) );
3165         }
3166         CV_CALL( cvReleaseMemStorage( &(((CvBtClassifier*) *ptr)->seq->storage) ) );
3167     }
3168     else
3169     {
3170         CvCARTClassifier** ptree;
3171
3172         ptree = ((CvBtClassifier*) *ptr)->trees;
3173         for( i = 0; i < ((CvBtClassifier*) *ptr)->numclasses *
3174                         ((CvBtClassifier*) *ptr)->numiter; i++ )
3175         {
3176             (*ptree)->release( (CvClassifier**) ptree );
3177             ptree++;
3178         }
3179     }
3180
3181     CV_CALL( cvFree( ptr ) );
3182     *ptr = NULL;
3183
3184     __END__;
3185 }
3186
3187 static void cvTuneBtClassifier( CvClassifier* classifier, CvMat*, int flags,
3188                          CvMat*, CvMat* , CvMat*, CvMat*, CvMat* )
3189 {
3190     CV_FUNCNAME( "cvTuneBtClassifier" );
3191
3192     __BEGIN__;
3193
3194     size_t data_size;
3195
3196     if( CV_IS_TUNABLE( flags ) )
3197     {
3198         if( !CV_IS_TUNABLE( classifier->flags ) )
3199         {
3200             CV_ERROR( CV_StsUnsupportedFormat,
3201                       "Classifier does not support tune function" );
3202         }
3203         else
3204         {
3205             /* tune classifier */
3206             CvCARTClassifier** trees;
3207
3208             printf( "Iteration %d\n", ((CvBtClassifier*) classifier)->numiter + 1 );
3209
3210             data_size = sizeof( *trees ) * ((CvBtClassifier*) classifier)->numclasses;
3211             CV_CALL( trees = (CvCARTClassifier**) cvAlloc( data_size ) );
3212             CV_CALL( cvBtNext( trees,
3213                 (CvBtTrainer*) ((CvBtClassifier*) classifier)->trainer ) );
3214             CV_CALL( cvSeqPushMulti( ((CvBtClassifier*) classifier)->seq,
3215                 trees, ((CvBtClassifier*) classifier)->numclasses ) );
3216             CV_CALL( cvFree( &trees ) );
3217             ((CvBtClassifier*) classifier)->numiter++;
3218         }
3219     }
3220     else
3221     {
3222         if( CV_IS_TUNABLE( classifier->flags ) )
3223         {
3224             /* convert */
3225             void* ptr;
3226
3227             assert( ((CvBtClassifier*) classifier)->seq->total ==
3228                         ((CvBtClassifier*) classifier)->numiter *
3229                         ((CvBtClassifier*) classifier)->numclasses );
3230
3231             data_size = sizeof( ((CvBtClassifier*) classifier)->trees[0] ) *
3232                 ((CvBtClassifier*) classifier)->seq->total;
3233             CV_CALL( ptr = cvAlloc( data_size ) );
3234             CV_CALL( cvCvtSeqToArray( ((CvBtClassifier*) classifier)->seq, ptr ) );
3235             CV_CALL( cvReleaseMemStorage(
3236                     &(((CvBtClassifier*) classifier)->seq->storage) ) );
3237             ((CvBtClassifier*) classifier)->trees = (CvCARTClassifier**) ptr;
3238             classifier->flags &= ~CV_TUNABLE;
3239             CV_CALL( cvBtEnd( (CvBtTrainer**)
3240                 &(((CvBtClassifier*) classifier)->trainer )) );
3241             ((CvBtClassifier*) classifier)->trainer = NULL;
3242         }
3243     }
3244
3245     __END__;
3246 }
3247
3248 static CvBtClassifier* icvAllocBtClassifier( CvBoostType type, int flags, int numclasses,
3249                                       int numiter )
3250 {
3251     CvBtClassifier* ptr;
3252     size_t data_size;
3253
3254     assert( numclasses >= 1 );
3255     assert( numiter >= 0 );
3256     assert( ( numclasses == 1 ) || (type == CV_LKCLASS) );
3257
3258     data_size = sizeof( *ptr );
3259     ptr = (CvBtClassifier*) cvAlloc( data_size );
3260     memset( ptr, 0, data_size );
3261
3262     if( CV_IS_TUNABLE( flags ) )
3263     {
3264         ptr->seq = cvCreateSeq( 0, sizeof( *(ptr->seq) ), sizeof( *(ptr->trees) ),
3265                                 cvCreateMemStorage() );
3266         ptr->numiter = 0;
3267     }
3268     else
3269     {
3270         data_size = numclasses * numiter * sizeof( *(ptr->trees) );
3271         ptr->trees = (CvCARTClassifier**) cvAlloc( data_size );
3272         memset( ptr->trees, 0, data_size );
3273
3274         ptr->numiter = numiter;
3275     }
3276
3277     ptr->flags = flags;
3278     ptr->numclasses = numclasses;
3279     ptr->type = type;
3280
3281     ptr->eval = icvEvalBtClassifier[(int) type];
3282     ptr->tune = cvTuneBtClassifier;
3283     ptr->save = cvSaveBtClassifier;
3284     ptr->release = cvReleaseBtClassifier;
3285
3286     return ptr;
3287 }
3288
3289 CV_BOOST_IMPL
3290 CvClassifier* cvCreateBtClassifier( CvMat* trainData,
3291                                     int flags,
3292                                     CvMat* trainClasses,
3293                                     CvMat* typeMask,
3294                                     CvMat* missedMeasurementsMask,
3295                                     CvMat* compIdx,
3296                                     CvMat* sampleIdx,
3297                                     CvMat* weights,
3298                                     CvClassifierTrainParams* trainParams )
3299 {
3300     CvBtClassifier* ptr = 0;
3301
3302     CV_FUNCNAME( "cvCreateBtClassifier" );
3303
3304     __BEGIN__;
3305     CvBoostType type;
3306     int num_classes;
3307     int num_iter;
3308     int i;
3309     CvCARTClassifier** trees;
3310     size_t data_size;
3311
3312     CV_ASSERT( trainData != NULL );
3313     CV_ASSERT( trainClasses != NULL );
3314     CV_ASSERT( typeMask == NULL );
3315     CV_ASSERT( missedMeasurementsMask == NULL );
3316     CV_ASSERT( compIdx == NULL );
3317     CV_ASSERT( weights == NULL );
3318     CV_ASSERT( trainParams != NULL );
3319
3320     type = ((CvBtClassifierTrainParams*) trainParams)->type;
3321
3322     if( type >= CV_DABCLASS && type <= CV_GABCLASS && sampleIdx )
3323     {
3324         CV_ERROR( CV_StsBadArg, "Sample indices are not supported for this type" );
3325     }
3326
3327     if( type == CV_LKCLASS )
3328     {
3329         double min_val;
3330         double max_val;
3331
3332         cvMinMaxLoc( trainClasses, &min_val, &max_val );
3333         num_classes = (int) (max_val + 1.0);
3334
3335         CV_ASSERT( num_classes >= 2 );
3336     }
3337     else
3338     {
3339         num_classes = 1;
3340     }
3341     num_iter = ((CvBtClassifierTrainParams*) trainParams)->numiter;
3342
3343     CV_ASSERT( num_iter > 0 );
3344
3345     ptr = icvAllocBtClassifier( type, CV_TUNABLE | flags, num_classes, num_iter );
3346     ptr->numfeatures = (CV_IS_ROW_SAMPLE( flags )) ? trainData->cols : trainData->rows;
3347
3348     i = 0;
3349
3350     printf( "Iteration %d\n", 1 );
3351
3352     data_size = sizeof( *trees ) * ptr->numclasses;
3353     CV_CALL( trees = (CvCARTClassifier**) cvAlloc( data_size ) );
3354
3355     CV_CALL( ptr->trainer = cvBtStart( trees, trainData, flags, trainClasses, sampleIdx,
3356         ((CvBtClassifierTrainParams*) trainParams)->numsplits, type, num_classes,
3357         &(((CvBtClassifierTrainParams*) trainParams)->param[0]) ) );
3358
3359     CV_CALL( cvSeqPushMulti( ptr->seq, trees, ptr->numclasses ) );
3360     CV_CALL( cvFree( &trees ) );
3361     ptr->numiter++;
3362
3363     for( i = 1; i < num_iter; i++ )
3364     {
3365         ptr->tune( (CvClassifier*) ptr, NULL, CV_TUNABLE, NULL, NULL, NULL, NULL, NULL );
3366     }
3367     if( !CV_IS_TUNABLE( flags ) )
3368     {
3369         /* convert */
3370         ptr->tune( (CvClassifier*) ptr, NULL, 0, NULL, NULL, NULL, NULL, NULL );
3371     }
3372
3373     __END__;
3374
3375     return (CvClassifier*) ptr;
3376 }
3377
3378 CV_BOOST_IMPL
3379 CvClassifier* cvCreateBtClassifierFromFile( const char* filename )
3380 {
3381     CvBtClassifier* ptr = 0;
3382
3383     CV_FUNCNAME( "cvCreateBtClassifierFromFile" );
3384
3385     __BEGIN__;
3386
3387     FILE* file;
3388     int i, j;
3389     int data_size;
3390     int num_classifiers;
3391     int num_features;
3392     int num_classes;
3393     int type;
3394     int values_read = -1;
3395
3396     CV_ASSERT( filename != NULL );
3397
3398     ptr = NULL;
3399     file = fopen( filename, "r" );
3400     if( !file )
3401     {
3402         CV_ERROR( CV_StsError, "Unable to open file" );
3403     }
3404
3405     values_read = fscanf( file, "%d %d %d %d", &type, &num_classes, &num_features, &num_classifiers );
3406     CV_Assert(values_read == 4);
3407
3408     CV_ASSERT( type >= (int) CV_DABCLASS && type <= (int) CV_MREG );
3409     CV_ASSERT( num_features > 0 );
3410     CV_ASSERT( num_classifiers > 0 );
3411
3412     if( (CvBoostType) type != CV_LKCLASS )
3413     {
3414         num_classes = 1;
3415     }
3416     ptr = icvAllocBtClassifier( (CvBoostType) type, 0, num_classes, num_classifiers );
3417     ptr->numfeatures = num_features;
3418
3419     for( i = 0; i < num_classes * num_classifiers; i++ )
3420     {
3421         int count;
3422         CvCARTClassifier* tree;
3423
3424         values_read = fscanf( file, "%d", &count );
3425         CV_Assert(values_read == 1);
3426
3427         data_size = sizeof( *tree )
3428             + count * ( sizeof( *(tree->compidx) ) + sizeof( *(tree->threshold) ) +
3429                         sizeof( *(tree->right) ) + sizeof( *(tree->left) ) )
3430             + (count + 1) * ( sizeof( *(tree->val) ) );
3431         CV_CALL( tree = (CvCARTClassifier*) cvAlloc( data_size ) );
3432         memset( tree, 0, data_size );
3433         tree->eval = cvEvalCARTClassifier;
3434         tree->tune = NULL;
3435         tree->save = NULL;
3436         tree->release = cvReleaseCARTClassifier;
3437         tree->compidx = (int*) ( tree + 1 );
3438         tree->threshold = (float*) ( tree->compidx + count );
3439         tree->left = (int*) ( tree->threshold + count );
3440         tree->right = (int*) ( tree->left + count );
3441         tree->val = (float*) ( tree->right + count );
3442
3443         tree->count = count;
3444         for( j = 0; j < tree->count; j++ )
3445         {
3446             values_read = fscanf( file, "%d %g %d %d", &(tree->compidx[j]),
3447                                          &(tree->threshold[j]),
3448                                          &(tree->left[j]),
3449                                          &(tree->right[j]) );
3450             CV_Assert(values_read == 4);
3451         }
3452         for( j = 0; j <= tree->count; j++ )
3453         {
3454             values_read = fscanf( file, "%g", &(tree->val[j]) );
3455             CV_Assert(values_read == 1);
3456         }
3457         ptr->trees[i] = tree;
3458     }
3459
3460     fclose( file );
3461
3462     __END__;
3463
3464     return (CvClassifier*) ptr;
3465 }
3466
3467 /****************************************************************************************\
3468 *                                    Utility functions                                   *
3469 \****************************************************************************************/
3470
3471 CV_BOOST_IMPL
3472 CvMat* cvTrimWeights( CvMat* weights, CvMat* idx, float factor )
3473 {
3474     CvMat* ptr = 0;
3475
3476     CV_FUNCNAME( "cvTrimWeights" );
3477     __BEGIN__;
3478     int i, index, num;
3479     float sum_weights;
3480     uchar* wdata;
3481     size_t wstep;
3482     int wnum;
3483     float threshold;
3484     int count;
3485     float* sorted_weights;
3486
3487     CV_ASSERT( CV_MAT_TYPE( weights->type ) == CV_32FC1 );
3488
3489     ptr = idx;
3490     sorted_weights = NULL;
3491
3492     if( factor > 0.0F && factor < 1.0F )
3493     {
3494         size_t data_size;
3495
3496         CV_MAT2VEC( *weights, wdata, wstep, wnum );
3497         num = ( idx == NULL ) ? wnum : MAX( idx->rows, idx->cols );
3498
3499         data_size = num * sizeof( *sorted_weights );
3500         sorted_weights = (float*) cvAlloc( data_size );
3501         memset( sorted_weights, 0, data_size );
3502
3503         sum_weights = 0.0F;
3504         for( i = 0; i < num; i++ )
3505         {
3506             index = icvGetIdxAt( idx, i );
3507             sorted_weights[i] = *((float*) (wdata + index * wstep));
3508             sum_weights += sorted_weights[i];
3509         }
3510
3511         std::sort(sorted_weights, sorted_weights + num);
3512
3513         sum_weights *= (1.0F - factor);
3514
3515         i = -1;
3516         do { sum_weights -= sorted_weights[++i]; }
3517         while( sum_weights > 0.0F && i < (num - 1) );
3518
3519         threshold = sorted_weights[i];
3520
3521         while( i > 0 && sorted_weights[i-1] == threshold ) i--;
3522
3523         if( i > 0 || ( idx != NULL && CV_MAT_TYPE( idx->type ) != CV_32FC1 ) )
3524         {
3525             CV_CALL( ptr = cvCreateMat( 1, num - i, CV_32FC1 ) );
3526             count = 0;
3527             for( i = 0; i < num; i++ )
3528             {
3529                 index = icvGetIdxAt( idx, i );
3530                 if( *((float*) (wdata + index * wstep)) >= threshold )
3531                 {
3532                     CV_MAT_ELEM( *ptr, float, 0, count ) = (float) index;
3533                     count++;
3534                 }
3535             }
3536
3537             assert( count == ptr->cols );
3538         }
3539         cvFree( &sorted_weights );
3540     }
3541
3542     __END__;
3543
3544     return ptr;
3545 }
3546
3547
3548 CV_BOOST_IMPL
3549 void cvReadTrainData( const char* filename, int flags,
3550                       CvMat** trainData,
3551                       CvMat** trainClasses )
3552 {
3553
3554     CV_FUNCNAME( "cvReadTrainData" );
3555
3556     __BEGIN__;
3557
3558     FILE* file;
3559     int m, n;
3560     int i, j;
3561     float val;
3562     int values_read = -1;
3563
3564     if( filename == NULL )
3565     {
3566         CV_ERROR( CV_StsNullPtr, "filename must be specified" );
3567     }
3568     if( trainData == NULL )
3569     {
3570         CV_ERROR( CV_StsNullPtr, "trainData must be not NULL" );
3571     }
3572     if( trainClasses == NULL )
3573     {
3574         CV_ERROR( CV_StsNullPtr, "trainClasses must be not NULL" );
3575     }
3576
3577     *trainData = NULL;
3578     *trainClasses = NULL;
3579     file = fopen( filename, "r" );
3580     if( !file )
3581     {
3582         CV_ERROR( CV_StsError, "Unable to open file" );
3583     }
3584
3585     values_read = fscanf( file, "%d %d", &m, &n );
3586     CV_Assert(values_read == 2);
3587
3588     if( CV_IS_ROW_SAMPLE( flags ) )
3589     {
3590         CV_CALL( *trainData = cvCreateMat( m, n, CV_32FC1 ) );
3591     }
3592     else
3593     {
3594         CV_CALL( *trainData = cvCreateMat( n, m, CV_32FC1 ) );
3595     }
3596
3597     CV_CALL( *trainClasses = cvCreateMat( 1, m, CV_32FC1 ) );
3598
3599     for( i = 0; i < m; i++ )
3600     {
3601         for( j = 0; j < n; j++ )
3602         {
3603             values_read = fscanf( file, "%f", &val );
3604             CV_Assert(values_read == 1);
3605             if( CV_IS_ROW_SAMPLE( flags ) )
3606             {
3607                 CV_MAT_ELEM( **trainData, float, i, j ) = val;
3608             }
3609             else
3610             {
3611                 CV_MAT_ELEM( **trainData, float, j, i ) = val;
3612             }
3613         }
3614         values_read = fscanf( file, "%f", &val );
3615         CV_Assert(values_read == 2);
3616         CV_MAT_ELEM( **trainClasses, float, 0, i ) = val;
3617     }
3618
3619     fclose( file );
3620
3621     __END__;
3622
3623 }
3624
3625 CV_BOOST_IMPL
3626 void cvWriteTrainData( const char* filename, int flags,
3627                        CvMat* trainData, CvMat* trainClasses, CvMat* sampleIdx )
3628 {
3629     CV_FUNCNAME( "cvWriteTrainData" );
3630
3631     __BEGIN__;
3632
3633     FILE* file;
3634     int m, n;
3635     int i, j;
3636     int clsrow;
3637     int count;
3638     int idx;
3639     CvScalar sc;
3640
3641     if( filename == NULL )
3642     {
3643         CV_ERROR( CV_StsNullPtr, "filename must be specified" );
3644     }
3645     if( trainData == NULL || CV_MAT_TYPE( trainData->type ) != CV_32FC1 )
3646     {
3647         CV_ERROR( CV_StsUnsupportedFormat, "Invalid trainData" );
3648     }
3649     if( CV_IS_ROW_SAMPLE( flags ) )
3650     {
3651         m = trainData->rows;
3652         n = trainData->cols;
3653     }
3654     else
3655     {
3656         n = trainData->rows;
3657         m = trainData->cols;
3658     }
3659     if( trainClasses == NULL || CV_MAT_TYPE( trainClasses->type ) != CV_32FC1 ||
3660         MIN( trainClasses->rows, trainClasses->cols ) != 1 )
3661     {
3662         CV_ERROR( CV_StsUnsupportedFormat, "Invalid trainClasses" );
3663     }
3664     clsrow = (trainClasses->rows == 1);
3665     if( m != ( (clsrow) ? trainClasses->cols : trainClasses->rows ) )
3666     {
3667         CV_ERROR( CV_StsUnmatchedSizes, "Incorrect trainData and trainClasses sizes" );
3668     }
3669
3670     if( sampleIdx != NULL )
3671     {
3672         count = (sampleIdx->rows == 1) ? sampleIdx->cols : sampleIdx->rows;
3673     }
3674     else
3675     {
3676         count = m;
3677     }
3678
3679
3680     file = fopen( filename, "w" );
3681     if( !file )
3682     {
3683         CV_ERROR( CV_StsError, "Unable to create file" );
3684     }
3685
3686     fprintf( file, "%d %d\n", count, n );
3687
3688     for( i = 0; i < count; i++ )
3689     {
3690         if( sampleIdx )
3691         {
3692             if( sampleIdx->rows == 1 )
3693             {
3694                 sc = cvGet2D( sampleIdx, 0, i );
3695             }
3696             else
3697             {
3698                 sc = cvGet2D( sampleIdx, i, 0 );
3699             }
3700             idx = (int) sc.val[0];
3701         }
3702         else
3703         {
3704             idx = i;
3705         }
3706         for( j = 0; j < n; j++ )
3707         {
3708             fprintf( file, "%g ", ( (CV_IS_ROW_SAMPLE( flags ))
3709                                     ? CV_MAT_ELEM( *trainData, float, idx, j )
3710                                     : CV_MAT_ELEM( *trainData, float, j, idx ) ) );
3711         }
3712         fprintf( file, "%g\n", ( (clsrow)
3713                                 ? CV_MAT_ELEM( *trainClasses, float, 0, idx )
3714                                 : CV_MAT_ELEM( *trainClasses, float, idx, 0 ) ) );
3715     }
3716
3717     fclose( file );
3718
3719     __END__;
3720 }
3721
3722
3723 #define ICV_RAND_SHUFFLE( suffix, type )                                                 \
3724 static void icvRandShuffle_##suffix( uchar* data, size_t step, int num )                 \
3725 {                                                                                        \
3726     time_t seed;                                                                         \
3727     type tmp;                                                                            \
3728     int i;                                                                               \
3729     float rn;                                                                            \
3730                                                                                          \
3731     time( &seed );                                                                       \
3732     CvRNG state = cvRNG((int)seed);                                                      \
3733                                                                                          \
3734     for( i = 0; i < (num-1); i++ )                                                       \
3735     {                                                                                    \
3736         rn = ((float) cvRandInt( &state )) / (1.0F + UINT_MAX);                          \
3737         CV_SWAP( *((type*)(data + i * step)),                                            \
3738                  *((type*)(data + ( i + (int)( rn * (num - i ) ) )* step)),              \
3739                  tmp );                                                                  \
3740     }                                                                                    \
3741 }
3742
3743 ICV_RAND_SHUFFLE( 8U, uchar )
3744
3745 ICV_RAND_SHUFFLE( 16S, short )
3746
3747 ICV_RAND_SHUFFLE( 32S, int )
3748
3749 ICV_RAND_SHUFFLE( 32F, float )
3750
3751 CV_BOOST_IMPL
3752 void cvRandShuffleVec( CvMat* mat )
3753 {
3754     CV_FUNCNAME( "cvRandShuffle" );
3755
3756     __BEGIN__;
3757
3758     uchar* data;
3759     size_t step;
3760     int num;
3761
3762     if( (mat == NULL) || !CV_IS_MAT( mat ) || MIN( mat->rows, mat->cols ) != 1 )
3763     {
3764         CV_ERROR( CV_StsUnsupportedFormat, "" );
3765     }
3766
3767     CV_MAT2VEC( *mat, data, step, num );
3768     switch( CV_MAT_TYPE( mat->type ) )
3769     {
3770         case CV_8UC1:
3771             icvRandShuffle_8U( data, step, num);
3772             break;
3773         case CV_16SC1:
3774             icvRandShuffle_16S( data, step, num);
3775             break;
3776         case CV_32SC1:
3777             icvRandShuffle_32S( data, step, num);
3778             break;
3779         case CV_32FC1:
3780             icvRandShuffle_32F( data, step, num);
3781             break;
3782         default:
3783             CV_ERROR( CV_StsUnsupportedFormat, "" );
3784     }
3785
3786     __END__;
3787 }
3788
3789 /* End of file. */