1 /*M///////////////////////////////////////////////////////////////////////////////////////
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
10 // Intel License Agreement
11 // For Open Source Computer Vision Library
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
19 // * Redistribution's of source code must retain the above copyright notice,
20 // this list of conditions and the following disclaimer.
22 // * Redistribution's in binary form must reproduce the above copyright notice,
23 // this list of conditions and the following disclaimer in the documentation
24 // and/or other materials provided with the distribution.
26 // * The name of Intel Corporation may not be used to endorse or promote products
27 // derived from this software without specific prior written permission.
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
62 #include "_cvcommon.h"
63 #include "cvclassifier.h"
/* Array descriptor used by the sorting code in this file. Only the typedef
 * header is visible in this excerpt; the file reads and writes its <data>
 * and <step> members (for example va.data = val->data.ptr, and the
 * aux->data / aux->step accesses in LessThanValArray). */
71 typedef struct CvValArray
/*
 * LessThanValArray
 *
 * Strict "less than" predicate over indices. Two indices are ordered by the
 * float values located at data + index * step inside the array descriptor
 * handed to the constructor.
 *
 * T   - descriptor type exposing a byte pointer <data> and a byte stride <step>
 * Idx - index type; each index is converted to int before it is used
 */
template<typename T, typename Idx>
class LessThanValArray
{
public:
    LessThanValArray( const T* _aux ) : aux(_aux) {}

    bool operator()(Idx a, Idx b) const
    {
        /* fetch both keys first, then compare them */
        const float lhs = *( (float*) (aux->data + ((int) (a)) * aux->step) );
        const float rhs = *( (float*) (aux->data + ((int) (b)) * aux->step) );

        return lhs < rhs;
    }

    const T* aux;
};
/*
 * cvGetSortedIndices
 *
 * Fills every row of <idx> with the positions 0..idx->cols-1 and then sorts
 * that row with std::sort, ordering the positions by the float values that
 * the CvValArray descriptor <va> addresses inside <val>.
 *
 * val      - CV_32FC1 matrix holding the values (asserted)
 * idx      - CV_16SC1, CV_32SC1 or CV_32FC1 matrix receiving the indices
 *            (asserted); the element type selects which loop below runs
 * sortcols - NOTE(review): presumably selects between the transposed layout
 *            (idx->rows == val->cols, idx->cols == val->rows) and the
 *            same-shape layout asserted below; the branch itself is not
 *            visible in this excerpt - confirm
 */
91 void cvGetSortedIndices( CvMat* val, CvMat* idx, int sortcols )
102 CV_Assert( idx != NULL );
103 CV_Assert( val != NULL );
105 idxtype = CV_MAT_TYPE( idx->type );
106 CV_Assert( idxtype == CV_16SC1 || idxtype == CV_32SC1 || idxtype == CV_32FC1 );
107 CV_Assert( CV_MAT_TYPE( val->type ) == CV_32FC1 );
/* first shape check: idx is the transpose of val */
110 CV_Assert( idx->rows == val->cols );
111 CV_Assert( idx->cols == val->rows );
112 istep = CV_ELEM_SIZE( val->type );
/* second shape check: idx has the same shape as val */
117 CV_Assert( idx->rows == val->rows );
118 CV_Assert( idx->cols == val->cols );
120 jstep = CV_ELEM_SIZE( val->type );
123 va.data = val->data.ptr;
/* indices stored as short */
128 for( i = 0; i < idx->rows; i++ )
130 for( j = 0; j < idx->cols; j++ )
132 CV_MAT_ELEM( *idx, short, i, j ) = (short) j;
134 std::sort((short*) (idx->data.ptr + (size_t)i * idx->step),
135 (short*) (idx->data.ptr + (size_t)i * idx->step) + idx->cols,
136 LessThanValArray<CvValArray, short>(&va));
/* indices stored as int */
142 for( i = 0; i < idx->rows; i++ )
144 for( j = 0; j < idx->cols; j++ )
146 CV_MAT_ELEM( *idx, int, i, j ) = j;
148 std::sort((int*) (idx->data.ptr + (size_t)i * idx->step),
149 (int*) (idx->data.ptr + (size_t)i * idx->step) + idx->cols,
150 LessThanValArray<CvValArray, int>(&va));
/* indices stored as float */
156 for( i = 0; i < idx->rows; i++ )
158 for( j = 0; j < idx->cols; j++ )
160 CV_MAT_ELEM( *idx, float, i, j ) = (float) j;
162 std::sort((float*) (idx->data.ptr + (size_t)i * idx->step),
163 (float*) (idx->data.ptr + (size_t)i * idx->step) + idx->cols,
164 LessThanValArray<CvValArray, float>(&va));
/*
 * cvReleaseStumpClassifier
 *
 * Release callback of a stump classifier: passes the caller's pointer
 * variable <classifier> to cvFree.
 */
176 void cvReleaseStumpClassifier( CvClassifier** classifier )
178 cvFree( classifier );
183 float cvEvalStumpClassifier( CvClassifier* classifier, CvMat* sample )
185 assert( classifier != NULL );
186 assert( sample != NULL );
187 assert( CV_MAT_TYPE( sample->type ) == CV_32FC1 );
189 if( (CV_MAT_ELEM( (*sample), float, 0,
190 ((CvStumpClassifier*) classifier)->compidx )) <
191 ((CvStumpClassifier*) classifier)->threshold )
192 return ((CvStumpClassifier*) classifier)->left;
193 return ((CvStumpClassifier*) classifier)->right;
/*
 * ICV_DEF_FIND_STUMP_THRESHOLD( suffix, type, error )
 *
 * Generates a static function icvFindStumpThreshold_<suffix> that scans the
 * <num> sample indices stored in <idxdata> (elements of <type>, converted to
 * int) and searches one component for the split with the smallest
 * curlerror + currerror.
 *
 * - When *sumw equals FLT_MAX the weighted sums (*sumw, *sumwy, *sumwyy) are
 *   computed first.
 * - Whenever curlerror + currerror is smaller than *lerror + *rerror, the
 *   errors and *threshold are updated; the code then replaces *threshold by
 *   the midpoint of itself and *prevval (the guarding condition is outside
 *   this excerpt).
 * - <error> is a code fragment; the derived macros in this file pass
 *   statements that assign curlerror and currerror.
 *
 * The assert on (*prevval) <= (*curval) documents that the indices are
 * expected to address the values in non-decreasing order.
 */
196 #define ICV_DEF_FIND_STUMP_THRESHOLD( suffix, type, error ) \
197 static int icvFindStumpThreshold_##suffix( \
198 uchar* data, size_t datastep, \
199 uchar* wdata, size_t wstep, \
200 uchar* ydata, size_t ystep, \
201 uchar* idxdata, size_t idxstep, int num, \
204 float* threshold, float* left, float* right, \
205 float* sumw, float* sumwy, float* sumwyy ) \
214 float curleft = 0.0F; \
215 float curright = 0.0F; \
216 float* prevval = NULL; \
217 float* curval = NULL; \
218 float curlerror = 0.0F; \
219 float currerror = 0.0F; \
224 if( *sumw == FLT_MAX ) \
226 /* calculate sums */ \
234 for( i = 0; i < num; i++ ) \
236 idx = (int) ( *((type*) (idxdata + i*idxstep)) ); \
237 w = (float*) (wdata + idx * wstep); \
239 y = (float*) (ydata + idx * ystep); \
242 *sumwyy += wy * (*y); \
246 for( i = 0; i < num; i++ ) \
248 idx = (int) ( *((type*) (idxdata + i*idxstep)) ); \
249 curval = (float*) (data + idx * datastep); \
250 /* for debug purpose */ \
251 if( i > 0 ) assert( (*prevval) <= (*curval) ); \
253 wyr = *sumwy - wyl; \
256 if( wl > 0.0 ) curleft = wyl / wl; \
257 else curleft = 0.0F; \
259 if( wr > 0.0 ) curright = wyr / wr; \
260 else curright = 0.0F; \
264 if( curlerror + currerror < (*lerror) + (*rerror) ) \
266 (*lerror) = curlerror; \
267 (*rerror) = currerror; \
268 *threshold = *curval; \
270 *threshold = 0.5F * (*threshold + *prevval); \
279 wl += *((float*) (wdata + idx * wstep)); \
280 wyl += (*((float*) (wdata + idx * wstep))) \
281 * (*((float*) (ydata + idx * ystep))); \
282 wyyl += *((float*) (wdata + idx * wstep)) \
283 * (*((float*) (ydata + idx * ystep))) \
284 * (*((float*) (ydata + idx * ystep))); \
286 while( (++i) < num && \
287 ( *((float*) (data + (idx = \
288 (int) ( *((type*) (idxdata + i*idxstep))) ) * datastep)) \
292 } /* for each value */ \
/* misclassification error
 * err = MIN( wpos, wneg );
 */
/*
 * Instantiates ICV_DEF_FIND_STUMP_THRESHOLD as
 * icvFindStumpThreshold_misc_<suffix>. For each side the positive weight is
 * 0.5 * (w + wy) and the side error is the smaller of the positive weight and
 * the remaining weight (w - wpos). curleft / curright are rescaled to
 * 0.5 * (1 + value) beforehand.
 */
300 #define ICV_DEF_FIND_STUMP_THRESHOLD_MISC( suffix, type ) \
301 ICV_DEF_FIND_STUMP_THRESHOLD( misc_##suffix, type, \
302 float wposl = 0.5F * ( wl + wyl ); \
303 float wposr = 0.5F * ( wr + wyr ); \
304 curleft = 0.5F * ( 1.0F + curleft ); \
305 curright = 0.5F * ( 1.0F + curright ); \
306 curlerror = MIN( wposl, wl - wposl ); \
307 currerror = MIN( wposr, wr - wposr ); \
/* gini error
 * err = 2 * wpos * wneg /(wpos + wneg)
 */
/*
 * Instantiates ICV_DEF_FIND_STUMP_THRESHOLD as
 * icvFindStumpThreshold_gini_<suffix>. After rescaling curleft / curright to
 * 0.5 * (1 + value), each side error is 2 * wpos * (1 - value), where wpos is
 * 0.5 * (w + wy) for that side.
 */
313 #define ICV_DEF_FIND_STUMP_THRESHOLD_GINI( suffix, type ) \
314 ICV_DEF_FIND_STUMP_THRESHOLD( gini_##suffix, type, \
315 float wposl = 0.5F * ( wl + wyl ); \
316 float wposr = 0.5F * ( wr + wyr ); \
317 curleft = 0.5F * ( 1.0F + curleft ); \
318 curright = 0.5F * ( 1.0F + curright ); \
319 curlerror = 2.0F * wposl * ( 1.0F - curleft ); \
320 currerror = 2.0F * wposr * ( 1.0F - curright ); \
/* Guard used by the entropy error macro in this file: logf is only applied
 * to p when p > CV_ENTROPY_THRESHOLD and to (1 - p) when
 * p < 1 - CV_ENTROPY_THRESHOLD. */
323 #define CV_ENTROPY_THRESHOLD FLT_MIN
/* entropy error
 * err = - wpos * log(wpos / (wpos + wneg)) - wneg * log(wneg / (wpos + wneg))
 */
/*
 * Instantiates ICV_DEF_FIND_STUMP_THRESHOLD as
 * icvFindStumpThreshold_entropy_<suffix>. Both side errors start at 0 and
 * accumulate -wpos * logf( p ) and -(w - wpos) * logf( 1 - p ), where p is the
 * rescaled curleft / curright. Each term is skipped when its logf argument
 * would not exceed CV_ENTROPY_THRESHOLD.
 */
328 #define ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( suffix, type ) \
329 ICV_DEF_FIND_STUMP_THRESHOLD( entropy_##suffix, type, \
330 float wposl = 0.5F * ( wl + wyl ); \
331 float wposr = 0.5F * ( wr + wyr ); \
332 curleft = 0.5F * ( 1.0F + curleft ); \
333 curright = 0.5F * ( 1.0F + curright ); \
334 curlerror = currerror = 0.0F; \
335 if( curleft > CV_ENTROPY_THRESHOLD ) \
336 curlerror -= wposl * logf( curleft ); \
337 if( curleft < 1.0F - CV_ENTROPY_THRESHOLD ) \
338 curlerror -= (wl - wposl) * logf( 1.0F - curleft ); \
340 if( curright > CV_ENTROPY_THRESHOLD ) \
341 currerror -= wposr * logf( curright ); \
342 if( curright < 1.0F - CV_ENTROPY_THRESHOLD ) \
343 currerror -= (wr - wposr) * logf( 1.0F - curright ); \
/*
 * Instantiates ICV_DEF_FIND_STUMP_THRESHOLD as
 * icvFindStumpThreshold_sq_<suffix>. The side errors are the expanded form of
 * sum( w * (y - v)^2 ) = sum(w*y*y) + v*v*sum(w) - 2*v*sum(w*y), with v being
 * curleft (left side) or curright (right side); the right-side sum(w*y*y) is
 * obtained as (*sumwyy) - wyyl.
 */
346 /* least sum of squares error */
347 #define ICV_DEF_FIND_STUMP_THRESHOLD_SQ( suffix, type ) \
348 ICV_DEF_FIND_STUMP_THRESHOLD( sq_##suffix, type, \
349 /* calculate error (sum of squares) */ \
350 /* err = sum( w * (y - left(right)Val)^2 ) */ \
351 curlerror = wyyl + curleft * curleft * wl - 2.0F * curleft * wyl; \
352 currerror = (*sumwyy) - wyyl + curright * curright * wr - 2.0F * curright * wyr; \
/* Instantiate the threshold search for each error measure (misc, gini,
 * entropy, sq) and each index element type (short, int, float); the
 * generated functions are named icvFindStumpThreshold_<measure>_<16s|32s|32f>. */
355 ICV_DEF_FIND_STUMP_THRESHOLD_MISC( 16s, short )
357 ICV_DEF_FIND_STUMP_THRESHOLD_MISC( 32s, int )
359 ICV_DEF_FIND_STUMP_THRESHOLD_MISC( 32f, float )
362 ICV_DEF_FIND_STUMP_THRESHOLD_GINI( 16s, short )
364 ICV_DEF_FIND_STUMP_THRESHOLD_GINI( 32s, int )
366 ICV_DEF_FIND_STUMP_THRESHOLD_GINI( 32f, float )
369 ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( 16s, short )
371 ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( 32s, int )
373 ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( 32f, float )
376 ICV_DEF_FIND_STUMP_THRESHOLD_SQ( 16s, short )
378 ICV_DEF_FIND_STUMP_THRESHOLD_SQ( 32s, int )
380 ICV_DEF_FIND_STUMP_THRESHOLD_SQ( 32f, float )
/* Pointer type for the generated icvFindStumpThreshold_* functions. Each
 * array argument is passed as a byte pointer together with its byte step;
 * <num> is the number of indices in <idxdata>. */
382 typedef int (*CvFindThresholdFunc)( uchar* data, size_t datastep,
383 uchar* wdata, size_t wstep,
384 uchar* ydata, size_t ystep,
385 uchar* idxdata, size_t idxstep, int num,
388 float* threshold, float* left, float* right,
389 float* sumw, float* sumwy, float* sumwyy );
/* Dispatch tables, one per index element type (16s, 32s, 32f). The entry
 * order is misc, gini, entropy, sq; the stump constructors in this file
 * index these tables with the <error> field of the training parameters. */
391 CvFindThresholdFunc findStumpThreshold_16s[4] = {
392 icvFindStumpThreshold_misc_16s,
393 icvFindStumpThreshold_gini_16s,
394 icvFindStumpThreshold_entropy_16s,
395 icvFindStumpThreshold_sq_16s
398 CvFindThresholdFunc findStumpThreshold_32s[4] = {
399 icvFindStumpThreshold_misc_32s,
400 icvFindStumpThreshold_gini_32s,
401 icvFindStumpThreshold_entropy_32s,
402 icvFindStumpThreshold_sq_32s
405 CvFindThresholdFunc findStumpThreshold_32f[4] = {
406 icvFindStumpThreshold_misc_32f,
407 icvFindStumpThreshold_gini_32f,
408 icvFindStumpThreshold_entropy_32f,
409 icvFindStumpThreshold_sq_32f
/*
 * cvCreateStumpClassifier
 *
 * Trains a single stump over the n components of <trainData>: for every
 * component the sample indices are sorted by that component's values and the
 * threshold search selected by the <error> field of the train params is run
 * through findStumpThreshold_32s. The stump is returned as CvClassifier*.
 *
 * Visible preconditions (CV_Assert): trainData, trainClasses and weights are
 * non-NULL CV_32FC1 matrices, trainParams is non-NULL, and both
 * missedMeasurementsMask and compIdx are NULL.
 * sampleIdx is optional; when present it is expected to be CV_32FC1 and its
 * float entries are converted to int indices.
 * trainClasses and weights may be a single row or a single column; the step
 * between elements is chosen accordingly.
 */
413 CvClassifier* cvCreateStumpClassifier( CvMat* trainData,
417 CvMat* missedMeasurementsMask,
421 CvClassifierTrainParams* trainParams
424 CvStumpClassifier* stump = NULL;
425 int m = 0; /* number of samples */
426 int n = 0; /* number of components */
432 uchar* idxdata = NULL;
434 int l = 0; /* number of indices */
/* FLT_MAX marks the sums as "not yet computed" for the threshold search */
441 float sumw = FLT_MAX;
442 float sumwy = FLT_MAX;
443 float sumwyy = FLT_MAX;
445 CV_Assert( trainData != NULL );
446 CV_Assert( CV_MAT_TYPE( trainData->type ) == CV_32FC1 );
447 CV_Assert( trainClasses != NULL );
448 CV_Assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
449 CV_Assert( missedMeasurementsMask == NULL );
450 CV_Assert( compIdx == NULL );
451 CV_Assert( weights != NULL );
452 CV_Assert( CV_MAT_TYPE( weights->type ) == CV_32FC1 );
453 CV_Assert( trainParams != NULL );
/* cstep = byte step between components, sstep = byte step between samples */
455 data = trainData->data.ptr;
456 if( CV_IS_ROW_SAMPLE( flags ) )
458 cstep = CV_ELEM_SIZE( trainData->type );
459 sstep = trainData->step;
465 sstep = CV_ELEM_SIZE( trainData->type );
466 cstep = trainData->step;
471 ydata = trainClasses->data.ptr;
472 if( trainClasses->rows == 1 )
474 assert( trainClasses->cols == m );
475 ystep = CV_ELEM_SIZE( trainClasses->type );
479 assert( trainClasses->rows == m );
480 ystep = trainClasses->step;
483 wdata = weights->data.ptr;
484 if( weights->rows == 1 )
486 assert( weights->cols == m );
487 wstep = CV_ELEM_SIZE( weights->type );
491 assert( weights->rows == m );
492 wstep = weights->step;
496 if( sampleIdx != NULL )
498 assert( CV_MAT_TYPE( sampleIdx->type ) == CV_32FC1 );
500 idxdata = sampleIdx->data.ptr;
501 if( sampleIdx->rows == 1 )
504 idxstep = CV_ELEM_SIZE( sampleIdx->type );
509 idxstep = sampleIdx->step;
514 idx = (int*) cvAlloc( l * sizeof( int ) );
515 stump = (CvStumpClassifier*) cvAlloc( sizeof( CvStumpClassifier) );
518 memset( (void*) stump, 0, sizeof( CvStumpClassifier ) );
520 stump->eval = cvEvalStumpClassifier;
523 stump->release = cvReleaseStumpClassifier;
/* start from the worst possible error so that any candidate split wins */
525 stump->lerror = FLT_MAX;
526 stump->rerror = FLT_MAX;
531 if( sampleIdx != NULL )
533 for( i = 0; i < l; i++ )
535 idx[i] = (int) *((float*) (idxdata + i*idxstep));
540 for( i = 0; i < l; i++ )
546 for( i = 0; i < n; i++ )
550 va.data = data + i * ((size_t) cstep);
/* order the sample indices by the values of component i */
552 std::sort(idx, idx + l, LessThanValArray<CvValArray, int>(&va));
553 if( findStumpThreshold_32s[(int) ((CvStumpTrainParams*) trainParams)->error]
554 ( data + i * ((size_t) cstep), sstep,
555 wdata, wstep, ydata, ystep, (uchar*) idx, sizeof( int ), l,
556 &(stump->lerror), &(stump->rerror),
557 &(stump->threshold), &(stump->left), &(stump->right),
558 &sumw, &sumwy, &sumwyy ) )
562 } /* for each component */
/* classification mode: map the leaf values to -1 / +1 by thresholding at 0.5 */
568 if( ((CvStumpTrainParams*) trainParams)->type == CV_CLASSIFICATION_CLASS )
570 stump->left = 2.0F * (stump->left >= 0.5F) - 1.0F;
571 stump->right = 2.0F * (stump->right >= 0.5F) - 1.0F;
574 return (CvClassifier*) stump;
/*
 * cvCreateMTStumpClassifier
 *
 * Multithreaded stump classifier constructor.
 * Includes support for huge training data through a callback function.
 */
/*
 * Trains a stump like cvCreateStumpClassifier, with the work split into
 * portions of components handled inside an OpenMP parallel region.
 *
 * What the visible code establishes:
 * - trainParams and trainClasses must be non-NULL, trainClasses must be
 *   CV_32FC1, missedMeasurementsMask and compIdx must be NULL (CV_Assert).
 * - trainData may be NULL; then the getTrainData callback of the train
 *   params is expected to be set and the number of components is <numcomp>.
 * - The train params may carry presorted indices (sortedIdx) of type
 *   CV_16SC1, CV_32SC1 or CV_32FC1.
 * - When sampleIdx and sorted data are both present, a byte mask <filter>
 *   of m entries marks the selected samples.
 * - Each thread keeps its own best candidate; the overall best (smallest
 *   lerror + rerror) is copied into <stump> inside the critical section
 *   c_beststump.
 *
 * NOTE(review): no check of weights != NULL is visible before
 * weights->data.ptr is read - confirm the caller guarantees it.
 */
584 CvClassifier* cvCreateMTStumpClassifier( CvMat* trainData,
588 CvMat* missedMeasurementsMask,
592 CvClassifierTrainParams* trainParams )
594 CvStumpClassifier* stump = NULL;
595 int m = 0; /* number of samples */
596 int n = 0; /* number of components */
600 int datan = 0; /* num components */
603 uchar* idxdata = NULL;
605 int l = 0; /* number of indices */
609 uchar* sorteddata = NULL;
611 size_t sortedcstep = 0; /* component step */
612 size_t sortedsstep = 0; /* sample step */
613 int sortedn = 0; /* num components */
614 int sortedm = 0; /* num samples */
623 /* private variables */
652 /* end private variables */
654 CV_Assert( trainParams != NULL );
655 CV_Assert( trainClasses != NULL );
656 CV_Assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
657 CV_Assert( missedMeasurementsMask == NULL );
658 CV_Assert( compIdx == NULL );
660 stumperror = (int) ((CvMTStumpTrainParams*) trainParams)->error;
/* the number of samples m is taken from the shape of trainClasses */
662 ydata = trainClasses->data.ptr;
663 if( trainClasses->rows == 1 )
665 m = trainClasses->cols;
666 ystep = CV_ELEM_SIZE( trainClasses->type );
670 m = trainClasses->rows;
671 ystep = trainClasses->step;
674 wdata = weights->data.ptr;
675 if( weights->rows == 1 )
677 CV_Assert( weights->cols == m );
678 wstep = CV_ELEM_SIZE( weights->type );
682 CV_Assert( weights->rows == m );
683 wstep = weights->step;
/* optional presorted indices: one row per component, one column per sample */
686 if( ((CvMTStumpTrainParams*) trainParams)->sortedIdx != NULL )
689 CV_MAT_TYPE( ((CvMTStumpTrainParams*) trainParams)->sortedIdx->type );
690 assert( sortedtype == CV_16SC1 || sortedtype == CV_32SC1
691 || sortedtype == CV_32FC1 );
692 sorteddata = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->data.ptr;
693 sortedsstep = CV_ELEM_SIZE( sortedtype );
694 sortedcstep = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->step;
695 sortedn = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->rows;
696 sortedm = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->cols;
699 if( trainData == NULL )
701 assert( ((CvMTStumpTrainParams*) trainParams)->getTrainData != NULL );
702 n = ((CvMTStumpTrainParams*) trainParams)->numcomp;
707 assert( CV_MAT_TYPE( trainData->type ) == CV_32FC1 );
708 data = trainData->data.ptr;
709 if( CV_IS_ROW_SAMPLE( flags ) )
711 cstep = CV_ELEM_SIZE( trainData->type );
712 sstep = trainData->step;
713 assert( m == trainData->rows );
714 datan = n = trainData->cols;
718 sstep = CV_ELEM_SIZE( trainData->type );
719 cstep = trainData->step;
720 assert( m == trainData->cols );
721 datan = n = trainData->rows;
723 if( ((CvMTStumpTrainParams*) trainParams)->getTrainData != NULL )
725 n = ((CvMTStumpTrainParams*) trainParams)->numcomp;
728 assert( datan <= n );
730 if( sampleIdx != NULL )
732 assert( CV_MAT_TYPE( sampleIdx->type ) == CV_32FC1 );
733 idxdata = sampleIdx->data.ptr;
734 idxstep = ( sampleIdx->rows == 1 )
735 ? CV_ELEM_SIZE( sampleIdx->type ) : sampleIdx->step;
736 l = ( sampleIdx->rows == 1 ) ? sampleIdx->cols : sampleIdx->rows;
738 if( sorteddata != NULL )
/* filter[k] != 0 means that sample k is listed in sampleIdx */
740 filter = (char*) cvAlloc( sizeof( char ) * m );
741 memset( (void*) filter, 0, sizeof( char ) * m );
742 for( i = 0; i < l; i++ )
744 filter[(int) *((float*) (idxdata + i * idxstep))] = (char) 1;
753 stump = (CvStumpClassifier*) cvAlloc( sizeof( CvStumpClassifier) );
756 memset( (void*) stump, 0, sizeof( CvStumpClassifier ) );
758 portion = ((CvMTStumpTrainParams*)trainParams)->portion;
765 portion /= omp_get_max_threads();
769 stump->eval = cvEvalStumpClassifier;
772 stump->release = cvReleaseStumpClassifier;
774 stump->lerror = FLT_MAX;
775 stump->rerror = FLT_MAX;
781 #pragma omp parallel private(mat, va, lerror, rerror, left, right, threshold, \
782 optcompidx, sumw, sumwy, sumwyy, t_compidx, t_n, \
783 ti, tj, tk, t_data, t_cstep, t_sstep, matcstep, \
818 /* prepare matrix for callback */
819 if( CV_IS_ROW_SAMPLE( flags ) )
821 mat = cvMat( m, portion, CV_32FC1, 0 );
822 matcstep = CV_ELEM_SIZE( mat.type );
827 mat = cvMat( portion, m, CV_32FC1, 0 );
829 matsstep = CV_ELEM_SIZE( mat.type );
831 mat.data.ptr = (uchar*) cvAlloc( sizeof( float ) * mat.rows * mat.cols );
834 if( filter != NULL || sortedn < n )
836 t_idx = (int*) cvAlloc( sizeof( int ) * m );
837 if( sortedn == 0 || filter == NULL )
839 if( idxdata != NULL )
841 for( ti = 0; ti < l; ti++ )
843 t_idx[ti] = (int) *((float*) (idxdata + ti * idxstep));
848 for( ti = 0; ti < l; ti++ )
857 #pragma omp critical(c_compidx)
863 while( t_compidx < n )
866 if( t_compidx < datan )
868 t_n = ( t_n < (datan - t_compidx) ) ? t_n : (datan - t_compidx);
875 t_n = ( t_n < (n - t_compidx) ) ? t_n : (n - t_compidx);
/* t_data is biased so that t_data + ti * t_cstep addresses component ti
 * inside the per-thread matrix that starts at component t_compidx */
878 t_data = mat.data.ptr - t_compidx * ((size_t) t_cstep );
880 /* calculate components */
881 ((CvMTStumpTrainParams*)trainParams)->getTrainData( &mat,
882 sampleIdx, compIdx, t_compidx, t_n,
883 ((CvMTStumpTrainParams*)trainParams)->userdata );
886 if( sorteddata != NULL )
890 /* have sorted indices and filter */
/* sorted indices stored as short: keep only the samples selected by filter */
894 for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
897 for( tj = 0; tj < sortedm; tj++ )
899 int curidx = (int) ( *((short*) (sorteddata
900 + ti * sortedcstep + tj * sortedsstep)) );
901 if( filter[curidx] != 0 )
903 t_idx[tk++] = curidx;
906 if( findStumpThreshold_32s[stumperror](
907 t_data + ti * t_cstep, t_sstep,
908 wdata, wstep, ydata, ystep,
909 (uchar*) t_idx, sizeof( int ), tk,
911 &threshold, &left, &right,
912 &sumw, &sumwy, &sumwyy ) )
/* sorted indices stored as int */
919 for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
922 for( tj = 0; tj < sortedm; tj++ )
924 int curidx = (int) ( *((int*) (sorteddata
925 + ti * sortedcstep + tj * sortedsstep)) );
926 if( filter[curidx] != 0 )
928 t_idx[tk++] = curidx;
931 if( findStumpThreshold_32s[stumperror](
932 t_data + ti * t_cstep, t_sstep,
933 wdata, wstep, ydata, ystep,
934 (uchar*) t_idx, sizeof( int ), tk,
936 &threshold, &left, &right,
937 &sumw, &sumwy, &sumwyy ) )
/* sorted indices stored as float */
944 for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
947 for( tj = 0; tj < sortedm; tj++ )
949 int curidx = (int) ( *((float*) (sorteddata
950 + ti * sortedcstep + tj * sortedsstep)) );
951 if( filter[curidx] != 0 )
953 t_idx[tk++] = curidx;
956 if( findStumpThreshold_32s[stumperror](
957 t_data + ti * t_cstep, t_sstep,
958 wdata, wstep, ydata, ystep,
959 (uchar*) t_idx, sizeof( int ), tk,
961 &threshold, &left, &right,
962 &sumw, &sumwy, &sumwyy ) )
975 /* have sorted indices */
/* no filter: the presorted rows are passed to the search function directly,
 * using the variant that matches their element type */
979 for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
981 if( findStumpThreshold_16s[stumperror](
982 t_data + ti * t_cstep, t_sstep,
983 wdata, wstep, ydata, ystep,
984 sorteddata + ti * sortedcstep, sortedsstep, sortedm,
986 &threshold, &left, &right,
987 &sumw, &sumwy, &sumwyy ) )
994 for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
996 if( findStumpThreshold_32s[stumperror](
997 t_data + ti * t_cstep, t_sstep,
998 wdata, wstep, ydata, ystep,
999 sorteddata + ti * sortedcstep, sortedsstep, sortedm,
1001 &threshold, &left, &right,
1002 &sumw, &sumwy, &sumwyy ) )
1009 for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
1011 if( findStumpThreshold_32f[stumperror](
1012 t_data + ti * t_cstep, t_sstep,
1013 wdata, wstep, ydata, ystep,
1014 sorteddata + ti * sortedcstep, sortedsstep, sortedm,
1016 &threshold, &left, &right,
1017 &sumw, &sumwy, &sumwyy ) )
/* components without presorted indices: sort t_idx by value here */
1030 ti = MAX( t_compidx, MIN( sortedn, t_compidx + t_n ) );
1031 for( ; ti < t_compidx + t_n; ti++ )
1033 va.data = t_data + ti * t_cstep;
1035 std::sort(t_idx, t_idx + l, LessThanValArray<CvValArray, int>(&va));
1036 if( findStumpThreshold_32s[stumperror](
1037 t_data + ti * t_cstep, t_sstep,
1038 wdata, wstep, ydata, ystep,
1039 (uchar*)t_idx, sizeof( int ), l,
1041 &threshold, &left, &right,
1042 &sumw, &sumwy, &sumwyy ) )
1048 #pragma omp critical(c_compidx)
1049 #endif /* _OPENMP */
1051 t_compidx = compidx;
1054 } /* while have training data */
1056 /* get the best classifier */
1058 #pragma omp critical(c_beststump)
1059 #endif /* _OPENMP */
1061 if( lerror + rerror < stump->lerror + stump->rerror )
1063 stump->lerror = lerror;
1064 stump->rerror = rerror;
1065 stump->compidx = optcompidx;
1066 stump->threshold = threshold;
1068 stump->right = right;
1072 /* free allocated memory */
1073 if( mat.data.ptr != NULL )
1075 cvFree( &(mat.data.ptr) );
1081 } /* end of parallel region */
1085 /* free allocated memory */
1086 if( filter != NULL )
/* classification mode: map the leaf values to -1 / +1 by thresholding at 0.5 */
1091 if( ((CvMTStumpTrainParams*) trainParams)->type == CV_CLASSIFICATION_CLASS )
1093 stump->left = 2.0F * (stump->left >= 0.5F) - 1.0F;
1094 stump->right = 2.0F * (stump->right >= 0.5F) - 1.0F;
1097 return (CvClassifier*) stump;
/*
 * cvEvalCARTClassifier
 *
 * Evaluates a CART on one sample. At node <idx> the sample component
 * compidx[idx] is compared with threshold[idx]; a smaller value moves to
 * left[idx], otherwise to right[idx]. The value returned is val[-idx], so
 * the walk is expected to end on a non-positive idx (the loop condition is
 * not visible in this excerpt).
 *
 * sample - CV_32FC1 row vector (element (0, c) is read) or column vector
 *          (element (c, 0) is read); checked with CV_ASSERT
 */
1101 float cvEvalCARTClassifier( CvClassifier* classifier, CvMat* sample )
1103 CV_FUNCNAME( "cvEvalCARTClassifier" );
1110 CV_ASSERT( classifier != NULL );
1111 CV_ASSERT( sample != NULL );
1112 CV_ASSERT( CV_MAT_TYPE( sample->type ) == CV_32FC1 );
1113 CV_ASSERT( sample->rows == 1 || sample->cols == 1 );
/* row-vector sample */
1115 if( sample->rows == 1 )
1119 if( (CV_MAT_ELEM( (*sample), float, 0,
1120 ((CvCARTClassifier*) classifier)->compidx[idx] )) <
1121 ((CvCARTClassifier*) classifier)->threshold[idx] )
1123 idx = ((CvCARTClassifier*) classifier)->left[idx];
1127 idx = ((CvCARTClassifier*) classifier)->right[idx];
/* column-vector sample */
1135 if( (CV_MAT_ELEM( (*sample), float,
1136 ((CvCARTClassifier*) classifier)->compidx[idx], 0 )) <
1137 ((CvCARTClassifier*) classifier)->threshold[idx] )
1139 idx = ((CvCARTClassifier*) classifier)->left[idx];
1143 idx = ((CvCARTClassifier*) classifier)->right[idx];
1150 return ((CvCARTClassifier*) classifier)->val[-idx];
/*
 * cvEvalCARTClassifierIdx
 *
 * Same tree walk as cvEvalCARTClassifier, but instead of the stored value it
 * returns the negated final node index, (float) (-idx).
 *
 * sample - CV_32FC1 row or column vector; checked with CV_ASSERT
 */
1154 float cvEvalCARTClassifierIdx( CvClassifier* classifier, CvMat* sample )
1156 CV_FUNCNAME( "cvEvalCARTClassifierIdx" );
1163 CV_ASSERT( classifier != NULL );
1164 CV_ASSERT( sample != NULL );
1165 CV_ASSERT( CV_MAT_TYPE( sample->type ) == CV_32FC1 );
1166 CV_ASSERT( sample->rows == 1 || sample->cols == 1 );
/* row-vector sample */
1168 if( sample->rows == 1 )
1172 if( (CV_MAT_ELEM( (*sample), float, 0,
1173 ((CvCARTClassifier*) classifier)->compidx[idx] )) <
1174 ((CvCARTClassifier*) classifier)->threshold[idx] )
1176 idx = ((CvCARTClassifier*) classifier)->left[idx];
1180 idx = ((CvCARTClassifier*) classifier)->right[idx];
/* column-vector sample */
1188 if( (CV_MAT_ELEM( (*sample), float,
1189 ((CvCARTClassifier*) classifier)->compidx[idx], 0 )) <
1190 ((CvCARTClassifier*) classifier)->threshold[idx] )
1192 idx = ((CvCARTClassifier*) classifier)->left[idx];
1196 idx = ((CvCARTClassifier*) classifier)->right[idx];
1203 return (float) (-idx);
/*
 * cvReleaseCARTClassifier
 *
 * Release callback of a CART classifier: passes the caller's pointer
 * variable <classifier> to cvFree. cvCreateCARTClassifier allocates the
 * classifier and its arrays as one block, so a single free suffices.
 */
1207 void cvReleaseCARTClassifier( CvClassifier** classifier )
1209 cvFree( classifier );
/*
 * icvDefaultSplitIdx_R
 *
 * Default split callback for row samples. <userdata> is used as the training
 * matrix. *left and *right are created as 1 x trainData->rows CV_32FC1
 * matrices whose <cols> field is reset to 0 and then used as a fill counter.
 * A sample index is appended to *left when element (sample, compidx) is
 * below <threshold>; the other visible assignment appends to *right.
 *
 * The first loop covers every row of the training matrix, the second one
 * only the indices listed in <idx> (stored as floats). The condition
 * choosing between the two loops is not visible in this excerpt.
 */
1213 static void CV_CDECL icvDefaultSplitIdx_R( int compidx, float threshold,
1214 CvMat* idx, CvMat** left, CvMat** right,
1217 CvMat* trainData = (CvMat*) userdata;
1220 *left = cvCreateMat( 1, trainData->rows, CV_32FC1 );
1221 *right = cvCreateMat( 1, trainData->rows, CV_32FC1 );
1222 (*left)->cols = (*right)->cols = 0;
1225 for( i = 0; i < trainData->rows; i++ )
1227 if( CV_MAT_ELEM( *trainData, float, i, compidx ) < threshold )
1229 (*left)->data.fl[(*left)->cols++] = (float) i;
1233 (*right)->data.fl[(*right)->cols++] = (float) i;
/* idx may be a row or a column vector */
1244 idxdata = idx->data.ptr;
1245 idxnum = (idx->rows == 1) ? idx->cols : idx->rows;
1246 idxstep = (idx->rows == 1) ? CV_ELEM_SIZE( idx->type ) : idx->step;
1247 for( i = 0; i < idxnum; i++ )
1249 index = (int) *((float*) (idxdata + i * idxstep));
1250 if( CV_MAT_ELEM( *trainData, float, index, compidx ) < threshold )
1252 (*left)->data.fl[(*left)->cols++] = (float) index;
1256 (*right)->data.fl[(*right)->cols++] = (float) index;
/*
 * icvDefaultSplitIdx_C
 *
 * Default split callback for column samples. <userdata> is used as the
 * training matrix. *left and *right are created as 1 x trainData->cols
 * CV_32FC1 matrices whose <cols> field is reset to 0 and then used as a fill
 * counter. A sample index is appended to *left when element
 * (compidx, sample) is below <threshold>; the other visible assignment
 * appends to *right.
 *
 * The first loop covers every column of the training matrix, the second one
 * only the indices listed in <idx> (stored as floats). The condition
 * choosing between the two loops is not visible in this excerpt.
 */
1262 static void CV_CDECL icvDefaultSplitIdx_C( int compidx, float threshold,
1263 CvMat* idx, CvMat** left, CvMat** right,
1266 CvMat* trainData = (CvMat*) userdata;
1269 *left = cvCreateMat( 1, trainData->cols, CV_32FC1 );
1270 *right = cvCreateMat( 1, trainData->cols, CV_32FC1 );
1271 (*left)->cols = (*right)->cols = 0;
1274 for( i = 0; i < trainData->cols; i++ )
1276 if( CV_MAT_ELEM( *trainData, float, compidx, i ) < threshold )
1278 (*left)->data.fl[(*left)->cols++] = (float) i;
1282 (*right)->data.fl[(*right)->cols++] = (float) i;
/* idx may be a row or a column vector */
1293 idxdata = idx->data.ptr;
1294 idxnum = (idx->rows == 1) ? idx->cols : idx->rows;
1295 idxstep = (idx->rows == 1) ? CV_ELEM_SIZE( idx->type ) : idx->step;
1296 for( i = 0; i < idxnum; i++ )
1298 index = (int) *((float*) (idxdata + i * idxstep));
1299 if( CV_MAT_ELEM( *trainData, float, compidx, index ) < threshold )
1301 (*left)->data.fl[(*left)->cols++] = (float) index;
1305 (*right)->data.fl[(*right)->cols++] = (float) index;
1311 /* internal structure used in CART creation */
/* Only the <stump> member is visible in this excerpt. cvCreateCARTClassifier
 * additionally uses the members sampleIdx, errdrop, leftflag and parent of
 * this structure. */
1312 typedef struct CvCARTNode
1315 CvStumpClassifier* stump;
/*
 * cvCreateCARTClassifier
 *
 * Builds a CART with up to <count> internal nodes (count is taken from the
 * CvCARTTrainParams). The tree grows one node at a time: the last added node
 * is split, a candidate stump is trained for each side whose error is not
 * zero, and the candidate with the largest error drop becomes the next node.
 *
 * Memory layout: the classifier and its arrays are one cvAlloc block, in the
 * order compidx[count], threshold[count], left[count], right[count],
 * val[count + 1]. The work arrays <intnode> (accepted nodes) and <list>
 * (pending candidates) share a second block of 2 * count CvCARTNode entries.
 *
 * When the train params carry no splitIdx callback, icvDefaultSplitIdx_R or
 * icvDefaultSplitIdx_C is used, with <trainData> as user data.
 */
1322 CvClassifier* cvCreateCARTClassifier( CvMat* trainData,
1324 CvMat* trainClasses,
1326 CvMat* missedMeasurementsMask,
1330 CvClassifierTrainParams* trainParams )
1332 CvCARTClassifier* cart = NULL;
1333 size_t datasize = 0;
1338 CvCARTNode* intnode = NULL;
1339 CvCARTNode* list = NULL;
1344 float maxerrdrop = 0.0F;
1347 void (*splitIdxCallback)( int compidx, float threshold,
1348 CvMat* idx, CvMat** left, CvMat** right,
1352 count = ((CvCARTTrainParams*) trainParams)->count;
1354 assert( count > 0 );
/* header + per-node (threshold, compidx, left, right) + count + 1 leaf values */
1356 datasize = sizeof( *cart ) + (sizeof( float ) + 3 * sizeof( int )) * count +
1357 sizeof( float ) * (count + 1);
1359 cart = (CvCARTClassifier*) cvAlloc( datasize );
1360 memset( cart, 0, datasize );
1362 cart->count = count;
1364 cart->eval = cvEvalCARTClassifier;
1366 cart->release = cvReleaseCARTClassifier;
/* carve the arrays out of the memory that follows the structure */
1368 cart->compidx = (int*) (cart + 1);
1369 cart->threshold = (float*) (cart->compidx + count);
1370 cart->left = (int*) (cart->threshold + count);
1371 cart->right = (int*) (cart->left + count);
1372 cart->val = (float*) (cart->right + count);
1374 datasize = sizeof( CvCARTNode ) * (count + count);
1375 intnode = (CvCARTNode*) cvAlloc( datasize );
1376 memset( intnode, 0, datasize );
1377 list = (CvCARTNode*) (intnode + count);
1379 splitIdxCallback = ((CvCARTTrainParams*) trainParams)->splitIdx;
1380 userdata = ((CvCARTTrainParams*) trainParams)->userdata;
1381 if( splitIdxCallback == NULL )
1383 splitIdxCallback = ( CV_IS_ROW_SAMPLE( flags ) )
1384 ? icvDefaultSplitIdx_R : icvDefaultSplitIdx_C;
1385 userdata = trainData;
1388 /* create root of the tree */
1389 intnode[0].sampleIdx = sampleIdx;
1390 intnode[0].stump = (CvStumpClassifier*)
1391 ((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
1392 trainClasses, typeMask, missedMeasurementsMask, compIdx, sampleIdx, weights,
1393 ((CvCARTTrainParams*) trainParams)->stumpTrainParams );
1394 cart->left[0] = cart->right[0] = 0;
1398 for( i = 1; i < count; i++ )
1400 /* split last added node */
1401 splitIdxCallback( intnode[i-1].stump->compidx, intnode[i-1].stump->threshold,
1402 intnode[i-1].sampleIdx, &lidx, &ridx, userdata );
/* left side: a candidate is trained only when the parent's left error is non-zero */
1404 if( intnode[i-1].stump->lerror != 0.0F )
1406 list[listcount].sampleIdx = lidx;
1407 list[listcount].stump = (CvStumpClassifier*)
1408 ((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
1409 trainClasses, typeMask, missedMeasurementsMask, compIdx,
1410 list[listcount].sampleIdx,
1411 weights, ((CvCARTTrainParams*) trainParams)->stumpTrainParams );
/* error drop = parent's side error minus the total error of the candidate */
1412 list[listcount].errdrop = intnode[i-1].stump->lerror
1413 - (list[listcount].stump->lerror + list[listcount].stump->rerror);
1414 list[listcount].leftflag = 1;
1415 list[listcount].parent = i-1;
1420 cvReleaseMat( &lidx );
/* right side: same procedure with the parent's right error */
1422 if( intnode[i-1].stump->rerror != 0.0F )
1424 list[listcount].sampleIdx = ridx;
1425 list[listcount].stump = (CvStumpClassifier*)
1426 ((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
1427 trainClasses, typeMask, missedMeasurementsMask, compIdx,
1428 list[listcount].sampleIdx,
1429 weights, ((CvCARTTrainParams*) trainParams)->stumpTrainParams );
1430 list[listcount].errdrop = intnode[i-1].stump->rerror
1431 - (list[listcount].stump->lerror + list[listcount].stump->rerror);
1432 list[listcount].leftflag = 0;
1433 list[listcount].parent = i-1;
1438 cvReleaseMat( &ridx );
1441 if( listcount == 0 ) break;
1443 /* find the best node to be added to the tree */
1445 maxerrdrop = list[idx].errdrop;
1446 for( j = 1; j < listcount; j++ )
1448 if( list[j].errdrop > maxerrdrop )
1451 maxerrdrop = list[j].errdrop;
/* the winner becomes node i and is linked to its parent */
1454 intnode[i] = list[idx];
1455 if( list[idx].leftflag )
1457 cart->left[list[idx].parent] = i;
1461 cart->right[list[idx].parent] = i;
/* remove the winner from the list by overwriting it with the last entry */
1463 if( idx != (listcount - 1) )
1465 list[idx] = list[listcount - 1];
1470 /* fill <cart> fields */
1473 for( i = 0; i < count && (intnode[i].stump != NULL); i++ )
1476 cart->compidx[i] = intnode[i].stump->compidx;
1477 cart->threshold[i] = intnode[i].stump->threshold;
/* a link that is still <= 0 has no child node: it becomes a leaf that
 * stores the stump's value in val[j] */
1480 if( cart->left[i] <= 0 )
1483 cart->val[j] = intnode[i].stump->left;
1486 if( cart->right[i] <= 0 )
1488 cart->right[i] = -j;
1489 cart->val[j] = intnode[i].stump->right;
/* release the stumps and index matrices of the accepted nodes ... */
1495 for( i = 0; i < count && (intnode[i].stump != NULL); i++ )
1497 intnode[i].stump->release( (CvClassifier**) &(intnode[i].stump) );
1500 cvReleaseMat( &(intnode[i].sampleIdx) );
/* ... and of the candidates that were never added to the tree */
1503 for( i = 0; i < listcount; i++ )
1505 list[i].stump->release( (CvClassifier**) &(list[i].stump) );
1506 cvReleaseMat( &(list[i].sampleIdx) );
1511 return (CvClassifier*) cart;
1514 /****************************************************************************************\
1516 \****************************************************************************************/
/* State kept between the boosting calls. Only <count> is visible in this
 * excerpt; the boosting functions in this file also use an int array member
 * <idx>, which is either NULL or holds <count> sample indices. */
1518 typedef struct CvBoostTrainer
1521 int count; /* (idx) ? number_of_indices : number_of_samples */
/*
 * cvBoostStartTraining, cvBoostNextWeakClassifier, cvBoostEndTraining
 *
 * These functions perform the training of a 2-class boosting classifier
 * using ANY appropriate weak classifier.
 */
/*
 * icvBoostStartTraining
 *
 * Allocates a CvBoostTrainer and initialises the training targets for the
 * weak classifier: for every selected sample, weakTrainVals[idx] is set to
 * 2 * trainClasses[idx] - 1.
 *
 * trainClasses and weakTrainVals must be non-NULL CV_32FC1 vectors with the
 * same number of elements. When sample indices are used, they are copied
 * into an int array stored directly behind the trainer structure (one
 * allocation of sizeof( *ptr ) + sizeof( *ptr->idx ) * idxnum bytes).
 */
1534 CvBoostTrainer* icvBoostStartTraining( CvMat* trainClasses,
1535 CvMat* weakTrainVals,
1550 CvBoostTrainer* ptr;
1556 assert( trainClasses != NULL );
1557 assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1558 assert( weakTrainVals != NULL );
1559 assert( CV_MAT_TYPE( weakTrainVals->type ) == CV_32FC1 );
1561 CV_MAT2VEC( *trainClasses, ydata, ystep, m );
1562 CV_MAT2VEC( *weakTrainVals, traindata, trainstep, trainnum );
1564 CV_Assert( m == trainnum );
1571 CV_MAT2VEC( *sampleIdx, idxdata, idxstep, idxnum );
1574 datasize = sizeof( *ptr ) + sizeof( *ptr->idx ) * idxnum;
1575 ptr = (CvBoostTrainer*) cvAlloc( datasize );
1576 memset( ptr, 0, datasize );
/* the index array lives right after the structure */
1587 ptr->idx = (int*) (ptr + 1);
1588 ptr->count = idxnum;
1589 for( i = 0; i < ptr->count; i++ )
/* read the index through cvRawDataToScalar, using the element type of sampleIdx */
1591 cvRawDataToScalar( idxdata + i*idxstep, CV_MAT_TYPE( sampleIdx->type ), &s );
1592 ptr->idx[i] = (int) s.val[0];
1595 for( i = 0; i < ptr->count; i++ )
1597 idx = (ptr->idx) ? ptr->idx[i] : i;
/* training target: 2 * y - 1 */
1599 *((float*) (traindata + idx * trainstep)) =
1600 2.0F * (*((float*) (ydata + idx * ystep))) - 1.0F;
/*
 * Discrete AdaBoost functions
 */
/*
 * icvBoostNextWeakClassifierDAB
 *
 * Weight update step of Discrete AdaBoost. Over the samples selected by the
 * trainer the function:
 *   1. accumulates the total weight and the weight of the samples whose
 *      weak response differs from 2 * y - 1;
 *   2. sets err = -cvLogRatio( err );
 *   3. multiplies the weight of every sample by expf( err * mismatch ), where
 *      mismatch is 1 for a differing response and 0 otherwise;
 *   4. divides every selected weight by the new weight sum.
 *
 * weakEvalVals, trainClasses and weights must be non-NULL CV_32FC1 vectors
 * of equal length (CV_Assert). Statements between these steps and the return
 * statement are not visible in this excerpt.
 */
1612 float icvBoostNextWeakClassifierDAB( CvMat* weakEvalVals,
1613 CvMat* trainClasses,
1614 CvMat* /*weakTrainVals*/,
1616 CvBoostTrainer* trainer )
1633 CV_Assert( weakEvalVals != NULL );
1634 CV_Assert( CV_MAT_TYPE( weakEvalVals->type ) == CV_32FC1 );
1635 CV_Assert( trainClasses != NULL );
1636 CV_Assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1637 CV_Assert( weights != NULL );
1638 CV_Assert( CV_MAT_TYPE( weights ->type ) == CV_32FC1 );
1640 CV_MAT2VEC( *weakEvalVals, evaldata, evalstep, m );
1641 CV_MAT2VEC( *trainClasses, ydata, ystep, ynum );
1642 CV_MAT2VEC( *weights, wdata, wstep, wnum );
1644 CV_Assert( m == ynum );
1645 CV_Assert( m == wnum );
/* step 1: total weight and weight of the mismatching samples */
1649 for( i = 0; i < trainer->count; i++ )
1651 idx = (trainer->idx) ? trainer->idx[i] : i;
1653 sumw += *((float*) (wdata + idx*wstep));
1654 err += (*((float*) (wdata + idx*wstep))) *
1655 ( (*((float*) (evaldata + idx*evalstep))) !=
1656 2.0F * (*((float*) (ydata + idx*ystep))) - 1.0F );
1659 err = -cvLogRatio( err );
/* step 3: re-weight the samples and accumulate the new weight sum */
1661 for( i = 0; i < trainer->count; i++ )
1663 idx = (trainer->idx) ? trainer->idx[i] : i;
1665 *((float*) (wdata + idx*wstep)) *= expf( err *
1666 ((*((float*) (evaldata + idx*evalstep))) !=
1667 2.0F * (*((float*) (ydata + idx*ystep))) - 1.0F) );
1668 sumw += *((float*) (wdata + idx*wstep));
/* step 4: normalise the selected weights */
1670 for( i = 0; i < trainer->count; i++ )
1672 idx = (trainer->idx) ? trainer->idx[i] : i;
1674 *((float*) (wdata + idx * wstep)) /= sumw;
1682 * Real AdaBoost functions
// icvBoostNextWeakClassifierRAB: weight update registered in slot 1 of the
// nextWeakClassifier table. Every selected weight is multiplied by
// expf( (0.5 - y) * cvLogRatio( weak response ) ) and the selected weights are
// then divided by their sum.
// NOTE(review): the weights parameter, local declarations, braces and the
// return statement are not shown in this listing.
1686 float icvBoostNextWeakClassifierRAB( CvMat* weakEvalVals,
1687 CvMat* trainClasses,
1688 CvMat* /*weakTrainVals*/,
1690 CvBoostTrainer* trainer )
1705 CV_Assert( weakEvalVals != NULL );
1706 CV_Assert( CV_MAT_TYPE( weakEvalVals->type ) == CV_32FC1 );
1707 CV_Assert( trainClasses != NULL );
1708 CV_Assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1709 CV_Assert( weights != NULL );
1710 CV_Assert( CV_MAT_TYPE( weights ->type ) == CV_32FC1 );
1712 CV_MAT2VEC( *weakEvalVals, evaldata, evalstep, m );
1713 CV_MAT2VEC( *trainClasses, ydata, ystep, ynum );
1714 CV_MAT2VEC( *weights, wdata, wstep, wnum );
1716 CV_Assert( m == ynum );
1717 CV_Assert( m == wnum );
1721 for( i = 0; i < trainer->count; i++ )
1723 idx = (trainer->idx) ? trainer->idx[i] : i;
// The weak response is passed straight to cvLogRatio, so it is presumably a
// probability-like value here - confirm against the weak learner.
1725 *((float*) (wdata + idx*wstep)) *= expf( (-(*((float*) (ydata + idx*ystep))) + 0.5F)
1726 * cvLogRatio( *((float*) (evaldata + idx*evalstep)) ) );
1727 sumw += *((float*) (wdata + idx*wstep));
// Normalisation pass over the same sample selection.
1729 for( i = 0; i < trainer->count; i++ )
1731 idx = (trainer->idx) ? trainer->idx[i] : i;
1733 *((float*) (wdata + idx*wstep)) /= sumw;
1741 * LogitBoost functions
// Lower bounds used by icvResponsesAndWeightsLB: the first clamps the
// denominators of the working response, the second clamps the sample weight.
1744 #define CV_LB_PROB_THRESH 0.01F
1745 #define CV_LB_WEIGHT_THRESHOLD 0.0001F
// icvResponsesAndWeightsLB: for each of num selected samples, turns the
// accumulated score in fdata into a probability p, writes the weight
// max( p*(1-p), CV_LB_WEIGHT_THRESHOLD ) to wdata and the working response to
// traindata. All buffers are addressed as byte pointer plus byte step.
// NOTE(review): the indices parameter, local declarations and braces are not
// shown in this listing.
1748 void icvResponsesAndWeightsLB( int num, uchar* wdata, int wstep,
1749 uchar* ydata, int ystep,
1750 uchar* fdata, int fstep,
1751 uchar* traindata, int trainstep,
1757 for( i = 0; i < num; i++ )
1759 idx = (indices) ? indices[i] : i;
// Logistic function of the accumulated score.
1761 p = 1.0F / (1.0F + expf( -(*((float*) (fdata + idx*fstep)))) );
1762 *((float*) (wdata + idx*wstep)) = MAX( p * (1.0F - p), CV_LB_WEIGHT_THRESHOLD );
// Label exactly equal to 1.0F gives the response 1 / max( p, threshold ).
1763 if( *((float*) (ydata + idx*ystep)) == 1.0F )
1765 *((float*) (traindata + idx*trainstep)) =
1766 1.0F / (MAX( p, CV_LB_PROB_THRESH ));
// The other assignment is the negative counterpart built from 1 - p.
1770 *((float*) (traindata + idx*trainstep)) =
1771 -1.0F / (MAX( 1.0F - p, CV_LB_PROB_THRESH ));
// icvBoostStartTrainingLB: start routine registered in slot 2 of the
// startTraining table. It allocates a CvBoostTrainer with room for m floats
// (the F array) and idxnum ints (the index array), copies the sample indices
// and computes the initial responses and weights via icvResponsesAndWeightsLB.
// NOTE(review): remaining parameters, local declarations, braces, the body of
// the loop over m and the return statement are not shown in this listing.
1777 CvBoostTrainer* icvBoostStartTrainingLB( CvMat* trainClasses,
1778 CvMat* weakTrainVals,
1784 CvBoostTrainer* ptr;
1801 assert( trainClasses != NULL );
1802 assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1803 assert( weakTrainVals != NULL );
1804 assert( CV_MAT_TYPE( weakTrainVals->type ) == CV_32FC1 );
1805 assert( weights != NULL );
1806 assert( CV_MAT_TYPE( weights->type ) == CV_32FC1 );
1808 CV_MAT2VEC( *trainClasses, ydata, ystep, m );
1809 CV_MAT2VEC( *weakTrainVals, traindata, trainstep, trainnum );
1810 CV_MAT2VEC( *weights, wdata, wstep, wnum );
1812 CV_Assert( m == trainnum );
1813 CV_Assert( m == wnum );
1821 CV_MAT2VEC( *sampleIdx, idxdata, idxstep, idxnum );
// Single allocation: struct, then m floats for F, then idxnum ints for idx.
1824 datasize = sizeof( *ptr ) + sizeof( *ptr->F ) * m + sizeof( *ptr->idx ) * idxnum;
1825 ptr = (CvBoostTrainer*) cvAlloc( datasize );
1826 memset( ptr, 0, datasize );
1827 ptr->F = (float*) (ptr + 1);
// The index array starts right after the m entries of F.
1837 ptr->idx = (int*) (ptr->F + m);
1838 ptr->count = idxnum;
1839 for( i = 0; i < ptr->count; i++ )
1841 cvRawDataToScalar( idxdata + i*idxstep, CV_MAT_TYPE( sampleIdx->type ), &s );
1842 ptr->idx[i] = (int) s.val[0];
1846 for( i = 0; i < m; i++ )
// F is passed as a densely packed float buffer, hence the sizeof step.
1851 icvResponsesAndWeightsLB( ptr->count, wdata, wstep, ydata, ystep,
1852 (uchar*) ptr->F, sizeof( *ptr->F ),
1853 traindata, trainstep, ptr->idx );
// icvBoostNextWeakClassifierLB: round update registered in slot 2 of the
// nextWeakClassifier table. It adds the weak response of each selected sample
// to trainer->F and then recomputes weights and working responses with
// icvResponsesAndWeightsLB.
// NOTE(review): the weights parameter, local declarations, braces and the
// return statement are not shown in this listing.
1859 float icvBoostNextWeakClassifierLB( CvMat* weakEvalVals,
1860 CvMat* trainClasses,
1861 CvMat* weakTrainVals,
1863 CvBoostTrainer* trainer )
// Input checks use plain assert; the length checks below use CV_Assert.
1879 assert( weakEvalVals != NULL );
1880 assert( CV_MAT_TYPE( weakEvalVals->type ) == CV_32FC1 );
1881 assert( trainClasses != NULL );
1882 assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1883 assert( weakTrainVals != NULL );
1884 assert( CV_MAT_TYPE( weakTrainVals->type ) == CV_32FC1 );
1885 assert( weights != NULL );
1886 assert( CV_MAT_TYPE( weights ->type ) == CV_32FC1 );
1888 CV_MAT2VEC( *weakEvalVals, evaldata, evalstep, m );
1889 CV_MAT2VEC( *trainClasses, ydata, ystep, ynum );
1890 CV_MAT2VEC( *weakTrainVals, traindata, trainstep, trainnum );
1891 CV_MAT2VEC( *weights, wdata, wstep, wnum );
1893 CV_Assert( m == ynum );
1894 CV_Assert( m == wnum );
1895 CV_Assert( m == trainnum );
// Left disabled: count may be the number of indices rather than m.
1896 //assert( m == trainer->count );
1898 for( i = 0; i < trainer->count; i++ )
1900 idx = (trainer->idx) ? trainer->idx[i] : i;
// F is indexed by sample position, not by loop position.
1902 trainer->F[idx] += *((float*) (evaldata + idx * evalstep));
1905 icvResponsesAndWeightsLB( trainer->count, wdata, wstep, ydata, ystep,
1906 (uchar*) trainer->F, sizeof( *trainer->F ),
1907 traindata, trainstep, trainer->idx );
// icvBoostNextWeakClassifierGAB: weight update registered in slot 3 of the
// nextWeakClassifier table. Each selected weight is multiplied by
// expf( -response * (2*y - 1) ) and the selected weights are then divided by
// their sum.
// NOTE(review): the weights parameter, local declarations, braces and the
// return statement are not shown in this listing.
1918 float icvBoostNextWeakClassifierGAB( CvMat* weakEvalVals,
1919 CvMat* trainClasses,
1920 CvMat* /*weakTrainVals*/,
1922 CvBoostTrainer* trainer )
1937 CV_Assert( weakEvalVals != NULL );
1938 CV_Assert( CV_MAT_TYPE( weakEvalVals->type ) == CV_32FC1 );
1939 CV_Assert( trainClasses != NULL );
1940 CV_Assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1941 CV_Assert( weights != NULL );
1942 CV_Assert( CV_MAT_TYPE( weights->type ) == CV_32FC1 );
1944 CV_MAT2VEC( *weakEvalVals, evaldata, evalstep, m );
1945 CV_MAT2VEC( *trainClasses, ydata, ystep, ynum );
1946 CV_MAT2VEC( *weights, wdata, wstep, wnum );
1948 CV_Assert( m == ynum );
1949 CV_Assert( m == wnum );
1952 for( i = 0; i < trainer->count; i++ )
1954 idx = (trainer->idx) ? trainer->idx[i] : i;
// The exponent is negative when response and 2*y - 1 have the same sign, so
// such samples lose weight and the others gain weight.
1956 *((float*) (wdata + idx*wstep)) *=
1957 expf( -(*((float*) (evaldata + idx*evalstep)))
1958 * ( 2.0F * (*((float*) (ydata + idx*ystep))) - 1.0F ) );
1959 sumw += *((float*) (wdata + idx*wstep));
// Normalisation pass over the same sample selection.
1962 for( i = 0; i < trainer->count; i++ )
1964 idx = (trainer->idx) ? trainer->idx[i] : i;
1966 *((float*) (wdata + idx*wstep)) /= sumw;
// Function pointer types and the two dispatch tables used by the public
// cvBoostStartTraining and cvBoostNextWeakClassifier wrappers, which index
// them with the boosting type.
// NOTE(review): some parameters of the typedefs are not shown in this listing.
1972 typedef CvBoostTrainer* (*CvBoostStartTraining)( CvMat* trainClasses,
1973 CvMat* weakTrainVals,
1978 typedef float (*CvBoostNextWeakClassifier)( CvMat* weakEvalVals,
1979 CvMat* trainClasses,
1980 CvMat* weakTrainVals,
1982 CvBoostTrainer* data );
// Slot 2 is the only one with its own start routine; the other three slots
// share icvBoostStartTraining.
1984 CvBoostStartTraining startTraining[4] = {
1985 icvBoostStartTraining,
1986 icvBoostStartTraining,
1987 icvBoostStartTrainingLB,
1988 icvBoostStartTraining
// Same slot order as above: DAB, RAB, LB, GAB.
1991 CvBoostNextWeakClassifier nextWeakClassifier[4] = {
1992 icvBoostNextWeakClassifierDAB,
1993 icvBoostNextWeakClassifierRAB,
1994 icvBoostNextWeakClassifierLB,
1995 icvBoostNextWeakClassifierGAB
// cvBoostStartTraining: public entry point that forwards all arguments to the
// start routine stored in startTraining[type].
// NOTE(review): the table has four entries and no range check on type is
// visible here - confirm that callers only pass the four boosting types.
2004 CvBoostTrainer* cvBoostStartTraining( CvMat* trainClasses,
2005 CvMat* weakTrainVals,
2010 return startTraining[type]( trainClasses, weakTrainVals, weights, sampleIdx, type );
// cvBoostEndTraining: receives the address of a trainer pointer; cvBtEnd in
// this file calls it on its boosttrainer member. The body is not shown in
// this listing, so its cleanup steps are not described here.
2014 void cvBoostEndTraining( CvBoostTrainer** trainer )
// cvBoostNextWeakClassifier: public entry point for one boosting round. It
// forwards to the routine stored in nextWeakClassifier[trainer->type] and
// returns that routine's result.
// NOTE(review): trainer is dereferenced without a visible NULL check.
2021 float cvBoostNextWeakClassifier( CvMat* weakEvalVals,
2022 CvMat* trainClasses,
2023 CvMat* weakTrainVals,
2025 CvBoostTrainer* trainer )
2027 return nextWeakClassifier[trainer->type]( weakEvalVals, trainClasses,
2028 weakTrainVals, weights, trainer );
2031 /****************************************************************************************\
2032 * Boosted tree models *
2033 \****************************************************************************************/
// Training state of a boosted tree model, created by cvBtStart and released
// by cvBtEnd. Besides the members visible below, the functions in this file
// access trainData, flags, ydata, ystep, m, sampleIdx, numsamples, weights,
// param, numclasses and type through this struct.
2035 typedef struct CvBtTrainer
// Class labels or regression targets as passed to cvBtStart; not released by
// cvBtEnd.
2041 CvMat* trainClasses;
// Parameter blocks handed to cvCreateCARTClassifier; cvBtStart makes
// cartParams point at stumpParams.
2054 CvMTStumpTrainParams stumpParams;
2055 CvCARTTrainParams cartParams;
// f is indexed as f[sample * numclasses + class] by cvBtNext.
2057 float* f; /* F_(m-1) */
// y holds the per-round training targets handed to the tree builder.
2058 CvMat* y; /* yhat */
// Only set by cvBtStart for types up to CV_GABCLASS.
2060 CvBoostTrainer* boosttrainer;
2064 * cvBtStart, cvBtNext, cvBtEnd
2066 * These functions perform iterative training of
2067 * 2-class (CV_DABCLASS - CV_GABCLASS, CV_L2CLASS), K-class (CV_LKCLASS) classifier
2068 * or fit regression model (CV_LSREG, CV_LADREG, CV_MREG)
2069 * using decision tree as a weak classifier.
// Signature of the routines that compute the initial constant approximation.
// The result is written to approx; cvBtStart passes a buffer of numclasses
// floats.
2072 typedef void (*CvZeroApproxFunc)( float* approx, CvBtTrainer* trainer );
2074 /* Mean zero approximation */
// Writes the mean of the selected target values to approx[0]. Samples are
// selected through icvGetIdxAt on trainer->sampleIdx.
// NOTE(review): the initialisation of approx[0] is not shown in this listing,
// and no guard against numsamples being zero is visible.
2075 static void icvZeroApproxMean( float* approx, CvBtTrainer* trainer )
2081 for( i = 0; i < trainer->numsamples; i++ )
2083 idx = icvGetIdxAt( trainer->sampleIdx, i );
2084 approx[0] += *((float*) (trainer->ydata + idx * trainer->ystep));
2086 approx[0] /= (float) trainer->numsamples;
2090 * Median zero approximation
// Writes a median of the selected target values to approx[0]. The values are
// copied into the first numsamples entries of trainer->f, which is therefore
// overwritten and used as scratch space.
2092 static void icvZeroApproxMed( float* approx, CvBtTrainer* trainer )
2097 for( i = 0; i < trainer->numsamples; i++ )
2099 idx = icvGetIdxAt( trainer->sampleIdx, i );
// Written at position i, so the copy is dense even when sampleIdx is sparse.
2100 trainer->f[i] = *((float*) (trainer->ydata + idx * trainer->ystep));
2103 std::sort(trainer->f, trainer->f + trainer->numsamples);
// For an even count this picks the upper of the two middle elements.
2104 approx[0] = trainer->f[trainer->numsamples / 2];
2108 * 0.5 * log( mean(y) / (1 - mean(y)) ) where y in {0, 1}
// Writes half of cvLogRatio( mean target ) to approx[0]; the mean comes from
// icvZeroApproxMean.
2110 static void icvZeroApproxLog( float* approx, CvBtTrainer* trainer )
2114 icvZeroApproxMean( &y_mean, trainer );
2115 approx[0] = 0.5F * cvLogRatio( y_mean );
2119 * 0 zero approximation
// Iterates over all numclasses entries of approx. The loop body is not shown
// in this listing; going by the comment above it presumably stores zero in
// each entry - confirm.
2121 static void icvZeroApprox0( float* approx, CvBtTrainer* trainer )
2125 for( i = 0; i < trainer->numclasses; i++ )
// Initial approximation routine per boosting type. cvBtStart indexes this
// table with the type value, so the entry order has to match the numeric
// order of the type constants named in the trailing comments.
2131 static CvZeroApproxFunc icvZeroApproxFunc[] =
2133 icvZeroApprox0, /* CV_DABCLASS */
2134 icvZeroApprox0, /* CV_RABCLASS */
2135 icvZeroApprox0, /* CV_LBCLASS */
2136 icvZeroApprox0, /* CV_GABCLASS */
2137 icvZeroApproxLog, /* CV_L2CLASS */
2138 icvZeroApprox0, /* CV_LKCLASS */
2139 icvZeroApproxMean, /* CV_LSREG */
2140 icvZeroApproxMed, /* CV_LADREG */
2141 icvZeroApproxMed, /* CV_MREG */
// Forward declaration: cvBtStart runs the first training round itself.
2145 void cvBtNext( CvCARTClassifier** trees, CvBtTrainer* trainer );
// cvBtStart: validates the arguments, allocates and fills a CvBtTrainer,
// and trains the first tree or trees through cvBtNext. For types above
// CV_GABCLASS it first computes the initial constant approximation, stores it
// in f for every sample and class, and afterwards adds it to the values of
// the first trees.
// NOTE(review): several parameters, local declarations, braces, the switch or
// if structure around the stumpParams assignments and the return statement
// are not shown in this listing.
2148 CvBtTrainer* cvBtStart( CvCARTClassifier** trees,
2151 CvMat* trainClasses,
2158 CvBtTrainer* ptr = 0;
2160 CV_FUNCNAME( "cvBtStart" );
2171 CV_ERROR( CV_StsNullPtr, "Invalid trees parameter" );
// type is later used as an index into function tables, hence the range check.
2174 if( type < CV_DABCLASS || type > CV_MREG )
2176 CV_ERROR( CV_StsUnsupportedFormat, "Unsupported type parameter" );
2178 if( type == CV_LKCLASS )
2180 CV_ASSERT( numclasses >= 2 );
// The label matrix is treated as a vector; m is its longer dimension.
2187 m = MAX( trainClasses->rows, trainClasses->cols );
// Storage for f (m * numclasses floats behind the struct) is only reserved
// for types above CV_GABCLASS.
2189 data_size = sizeof( *ptr );
2190 if( type > CV_GABCLASS )
2192 data_size += m * numclasses * sizeof( *(ptr->f) );
2194 CV_CALL( ptr = (CvBtTrainer*) cvAlloc( data_size ) );
2195 memset( ptr, 0, data_size );
// NOTE(review): as far as the visible lines show, f is set for every type,
// so for types up to CV_GABCLASS it points just past the allocation and must
// not be dereferenced - confirm.
2196 ptr->f = (float*) (ptr + 1);
2198 ptr->trainData = trainData;
2200 ptr->trainClasses = trainClasses;
2201 CV_MAT2VEC( *trainClasses, ptr->ydata, ptr->ystep, ptr->m );
2203 memset( &(ptr->cartParams), 0, sizeof( ptr->cartParams ) );
2204 memset( &(ptr->stumpParams), 0, sizeof( ptr->stumpParams ) );
// Three alternative stump configurations; the conditions selecting between
// them are not shown in this listing.
2209 ptr->stumpParams.error = CV_MISCLASSIFICATION;
2210 ptr->stumpParams.type = CV_CLASSIFICATION_CLASS;
2213 ptr->stumpParams.error = CV_GINI;
2214 ptr->stumpParams.type = CV_CLASSIFICATION;
2217 ptr->stumpParams.error = CV_SQUARE;
2218 ptr->stumpParams.type = CV_REGRESSION;
// The tree parameters reference the stump parameters stored in the same
// trainer object.
2220 ptr->cartParams.count = numsplits;
2221 ptr->cartParams.stumpTrainParams = (CvClassifierTrainParams*) &(ptr->stumpParams);
2222 ptr->cartParams.stumpConstructor = cvCreateMTStumpClassifier;
// param[0] is the factor cvBtNext applies to the tree values; param[1] is
// used by the per-type routines for trimming and for the MREG quantile.
2224 ptr->param[0] = param[0];
2225 ptr->param[1] = param[1];
2227 ptr->numclasses = numclasses;
2229 CV_CALL( ptr->y = cvCreateMat( 1, m, CV_32FC1 ) );
2230 ptr->sampleIdx = sampleIdx;
2231 ptr->numsamples = ( sampleIdx == NULL ) ? ptr->m
2232 : MAX( sampleIdx->rows, sampleIdx->cols );
// All sample weights start at 1.
2234 ptr->weights = cvCreateMat( 1, m, CV_32FC1 );
2235 cvSet( ptr->weights, cvScalar( 1.0 ) );
2237 if( type <= CV_GABCLASS )
// The boosting trainer is started with a NULL sample index even when
// sampleIdx was supplied.
2239 ptr->boosttrainer = cvBoostStartTraining( ptr->trainClasses, ptr->y,
2240 ptr->weights, NULL, type );
2242 CV_CALL( cvBtNext( trees, ptr ) );
// Remaining types: compute the initial approximation, one value per class.
2246 data_size = sizeof( *zero_approx ) * numclasses;
2247 CV_CALL( zero_approx = (float*) cvAlloc( data_size ) );
2248 icvZeroApproxFunc[type]( zero_approx, ptr );
2249 for( i = 0; i < m; i++ )
2251 for( j = 0; j < numclasses; j++ )
2253 ptr->f[i * numclasses + j] = zero_approx[j];
2257 CV_CALL( cvBtNext( trees, ptr ) );
// Fold the constant into the first trees; the value array is walked from 0
// to count inclusive.
2259 for( i = 0; i < numclasses; i++ )
2261 for( j = 0; j <= trees[i]->count; j++ )
2263 trees[i]->val[j] += zero_approx[i];
2266 CV_CALL( cvFree( &zero_approx ) );
// icvBtNext_LSREG: one round that stores the residual y - f in trainer->y and
// fits a tree to it, writing the result to trees[0].
// NOTE(review): the residual loop runs over all m samples and indexes f by
// sample position, while the tree is trained on trainer->sampleIdx only.
2274 static void icvBtNext_LSREG( CvCARTClassifier** trees, CvBtTrainer* trainer )
2278 /* yhat_i = y_i - F_(m-1)(x_i) */
2279 for( i = 0; i < trainer->m; i++ )
2281 trainer->y->data.fl[i] =
2282 *((float*) (trainer->ydata + i * trainer->ystep)) - trainer->f[i];
2285 trees[0] = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData,
2287 trainer->y, NULL, NULL, NULL, trainer->sampleIdx, trainer->weights,
2288 (CvClassifierTrainParams*) &trainer->cartParams );
// icvBtNext_LADREG: one round that fits a tree to the sign of the residual
// and then derives each tree value from the median of the residuals of the
// samples assigned to it.
// NOTE(review): local declarations, braces, the store of val into the tree,
// the release of the scratch buffers and the assignment to trees[0] are not
// shown in this listing.
2292 static void icvBtNext_LADREG( CvCARTClassifier** trees, CvBtTrainer* trainer )
2294 CvCARTClassifier* ptr;
// Scratch buffers sized for all m samples: idx holds the tree output index
// per sample, resp collects residuals.
2307 data_size = trainer->m * sizeof( *idx );
2308 idx = (int*) cvAlloc( data_size );
2309 data_size = trainer->m * sizeof( *resp );
2310 resp = (float*) cvAlloc( data_size );
2312 /* yhat_i = sign(y_i - F_(m-1)(x_i)) */
2313 for( i = 0; i < trainer->numsamples; i++ )
2315 index = icvGetIdxAt( trainer->sampleIdx, i );
2316 trainer->y->data.fl[index] = (float)
2317 CV_SIGN( *((float*) (trainer->ydata + index * trainer->ystep))
2318 - trainer->f[index] );
2321 ptr = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData, trainer->flags,
2322 trainer->y, NULL, NULL, NULL, trainer->sampleIdx, trainer->weights,
2323 (CvClassifierTrainParams*) &trainer->cartParams );
// Walk the selected samples by moving the data pointer of one sample header.
2325 CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2326 CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2327 sample_data = sample.data.ptr;
2328 for( i = 0; i < trainer->numsamples; i++ )
2330 index = icvGetIdxAt( trainer->sampleIdx, i );
2331 sample.data.ptr = sample_data + index * sample_step;
2332 idx[index] = (int) cvEvalCARTClassifierIdx( (CvClassifier*) ptr, &sample );
// For every tree output index j, gather the residuals of its samples.
2334 for( j = 0; j <= ptr->count; j++ )
2337 for( i = 0; i < trainer->numsamples; i++ )
2339 index = icvGetIdxAt( trainer->sampleIdx, i );
2340 if( idx[index] == j )
2342 resp[respnum++] = *((float*) (trainer->ydata + index * trainer->ystep))
2343 - trainer->f[index];
// Median via full sort; for an even count the upper middle element is used.
2348 std::sort(resp, resp + respnum);
2349 val = resp[respnum / 2];
// icvBtNext_MREG: one round of the CV_MREG variant. The training target is
// the residual clipped in magnitude to delta, where delta is a quantile of
// the absolute residuals chosen by param[1]. Each tree value is then set to
// the median residual of its samples plus the mean clipped deviation from
// that median.
// NOTE(review): local declarations, braces, the store of val into the tree,
// the release of the scratch buffers and the assignment to trees[0] are not
// shown in this listing.
// NOTE(review): the quantile index assumes param[1] lies in the range 0 to 1
// and numsamples is at least 1 - confirm against the callers.
2365 static void icvBtNext_MREG( CvCARTClassifier** trees, CvBtTrainer* trainer )
2367 CvCARTClassifier* ptr;
2383 data_size = trainer->m * sizeof( *idx );
2384 idx = (int*) cvAlloc( data_size );
2385 data_size = trainer->m * sizeof( *resp );
2386 resp = (float*) cvAlloc( data_size );
2387 data_size = trainer->m * sizeof( *resid );
2388 resid = (float*) cvAlloc( data_size );
2390 /* resid_i = (y_i - F_(m-1)(x_i)) */
2391 for( i = 0; i < trainer->numsamples; i++ )
2393 index = icvGetIdxAt( trainer->sampleIdx, i );
2394 resid[index] = *((float*) (trainer->ydata + index * trainer->ystep))
2395 - trainer->f[index];
// resp is filled densely (position i) so it can be sorted as one range.
2397 resp[i] = (float) fabs( resid[index] );
2400 /* delta = quantile_alpha{abs(resid_i)} */
2401 std::sort(resp, resp + trainer->numsamples);
2402 delta = resp[(int)(trainer->param[1] * (trainer->numsamples - 1))];
// Training target: residual with its magnitude limited to delta.
2405 for( i = 0; i < trainer->numsamples; i++ )
2407 index = icvGetIdxAt( trainer->sampleIdx, i );
2408 trainer->y->data.fl[index] = MIN( delta, ((float) fabs( resid[index] )) ) *
2409 CV_SIGN( resid[index] );
2412 ptr = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData, trainer->flags,
2413 trainer->y, NULL, NULL, NULL, trainer->sampleIdx, trainer->weights,
2414 (CvClassifierTrainParams*) &trainer->cartParams );
2416 CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2417 CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2418 sample_data = sample.data.ptr;
2419 for( i = 0; i < trainer->numsamples; i++ )
2421 index = icvGetIdxAt( trainer->sampleIdx, i );
2422 sample.data.ptr = sample_data + index * sample_step;
2423 idx[index] = (int) cvEvalCARTClassifierIdx( (CvClassifier*) ptr, &sample );
// resp is reused here to gather the residuals of one tree output at a time.
2425 for( j = 0; j <= ptr->count; j++ )
2429 for( i = 0; i < trainer->numsamples; i++ )
2431 index = icvGetIdxAt( trainer->sampleIdx, i );
2432 if( idx[index] == j )
2434 resp[respnum++] = *((float*) (trainer->ydata + index * trainer->ystep))
2435 - trainer->f[index];
2440 /* rhat = median(y_i - F_(m-1)(x_i)) */
2441 std::sort(resp, resp + respnum);
2442 rhat = resp[respnum / 2];
2444 /* val = sum{sign(r_i - rhat_i) * min(delta, abs(r_i - rhat_i)}
2445 * r_i = y_i - F_(m-1)(x_i)
2448 for( i = 0; i < respnum; i++ )
2450 val += CV_SIGN( resp[i] - rhat )
2451 * MIN( delta, (float) fabs( resp[i] - rhat ) );
2454 val = rhat + val / (float) respnum;
// Saturation limits for the exponentials below. The active pair is 1e+8 and
// 18.0 (exp of 18 is roughly 6.6e7, just under 1e+8); a wider pair is kept
// commented out.
2472 //#define CV_VAL_MAX 1e304
2474 //#define CV_LOG_VAL_MAX 700.0
2476 #define CV_VAL_MAX 1e+8
2478 #define CV_LOG_VAL_MAX 18.0
// icvBtNext_L2CLASS: one round of the CV_L2CLASS variant. It computes the
// training target and a weight val * (2 - val) per selected sample, optionally
// trims low-weight samples when param[1] is below 1, fits a tree on the
// remaining samples and accumulates, per tree output, the sum of targets and
// the sum of weights of its samples.
// NOTE(review): in the target loop the saturated branch of the exponential
// yields CV_LOG_VAL_MAX (18.0) rather than CV_VAL_MAX, whereas
// icvBtNext_LKCLASS uses CV_VAL_MAX in the same position. This looks like an
// inconsistency - confirm which value is intended.
// NOTE(review): local declarations, braces, initialisations of the counters,
// the final computation of each tree value and the assignment to trees[0]
// are not shown in this listing.
2480 static void icvBtNext_L2CLASS( CvCARTClassifier** trees, CvBtTrainer* trainer )
2482 CvCARTClassifier* ptr;
2496 float* sorted_weights;
2502 data_size = trainer->m * sizeof( *idx );
2503 idx = (int*) cvAlloc( data_size );
2505 data_size = trainer->m * sizeof( *weights );
2506 weights = (float*) cvAlloc( data_size );
2507 data_size = trainer->m * sizeof( *sorted_weights );
2508 sorted_weights = (float*) cvAlloc( data_size );
2510 /* yhat_i = (4 * y_i - 2) / ( 1 + exp( (4 * y_i - 2) * F_(m-1)(x_i) ) ).
2514 for( i = 0; i < trainer->numsamples; i++ )
2516 index = icvGetIdxAt( trainer->sampleIdx, i );
2517 val = 4.0F * (*((float*) (trainer->ydata + index * trainer->ystep))) - 2.0F;
2518 val_f = val * trainer->f[index];
2519 val_f = ( val_f < CV_LOG_VAL_MAX ) ? exp( val_f ) : CV_LOG_VAL_MAX;
2520 val = (float) ( (double) val / ( 1.0 + val_f ) );
2521 trainer->y->data.fl[index] = val;
2522 val = (float) fabs( val );
2523 weights[index] = val * (2.0F - val);
2524 sorted_weights[i] = weights[index];
2525 sum_weights += sorted_weights[i];
2529 sample_idx = trainer->sampleIdx;
2530 trimmed_num = trainer->numsamples;
2531 if( trainer->param[1] < 1.0F )
2533 /* perform weight trimming */
2538 std::sort(sorted_weights, sorted_weights + trainer->numsamples);
// Weight mass that may be dropped: the fraction 1 - param[1] of the total.
2540 sum_weights *= (1.0F - trainer->param[1]);
// Consume the smallest weights until that mass is used up; the weight at the
// stopping position becomes the threshold.
2543 do { sum_weights -= sorted_weights[++i]; }
2544 while( sum_weights > 0.0F && i < (trainer->numsamples - 1) );
2546 threshold = sorted_weights[i];
// Step back over ties so every sample equal to the threshold is kept.
2548 while( i > 0 && sorted_weights[i-1] == threshold ) i--;
2552 trimmed_num = trainer->numsamples - i;
// The kept indices are stored as floats in a 1 x trimmed_num matrix.
2553 trimmed_idx = cvCreateMat( 1, trimmed_num, CV_32FC1 );
2555 for( i = 0; i < trainer->numsamples; i++ )
2557 index = icvGetIdxAt( trainer->sampleIdx, i );
2558 if( weights[index] >= threshold )
2560 CV_MAT_ELEM( *trimmed_idx, float, 0, count ) = (float) index;
2565 assert( count == trimmed_num );
2567 sample_idx = trimmed_idx;
2569 printf( "Used samples %%: %g\n",
2570 (float) trimmed_num / (float) trainer->numsamples * 100.0F );
// The tree is trained on sample_idx, which is either the original index or
// the trimmed one.
2574 ptr = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData, trainer->flags,
2575 trainer->y, NULL, NULL, NULL, sample_idx, trainer->weights,
2576 (CvClassifierTrainParams*) &trainer->cartParams );
2578 CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2579 CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2580 sample_data = sample.data.ptr;
2581 for( i = 0; i < trimmed_num; i++ )
2583 index = icvGetIdxAt( sample_idx, i );
2584 sample.data.ptr = sample_data + index * sample_step;
2585 idx[index] = (int) cvEvalCARTClassifierIdx( (CvClassifier*) ptr, &sample );
// Per tree output j: sum of targets in val, sum of weights in sum_weights.
2587 for( j = 0; j <= ptr->count; j++ )
2592 for( i = 0; i < trimmed_num; i++ )
2594 index = icvGetIdxAt( sample_idx, i );
2595 if( idx[index] == j )
2597 val += trainer->y->data.fl[index];
2598 sum_weights += weights[index];
2602 if( sum_weights > 0.0F )
2613 if( trimmed_idx != NULL ) cvReleaseMat( &trimmed_idx );
2614 cvFree( &sorted_weights );
// icvBtNext_LKCLASS: one round of the CV_LKCLASS variant, producing one tree
// per class in trees[k]. For each class it computes a target and a weight
// val * (1 - val) per selected sample, optionally trims low-weight samples
// when param[1] is below 1, fits the tree and sets each tree value from the
// target and weight sums of its samples, scaled by (K - 1) / K.
// NOTE(review): local declarations, braces, counter initialisations, part of
// the target expression and the fallback for a zero weight sum are not shown
// in this listing.
2621 static void icvBtNext_LKCLASS( CvCARTClassifier** trees, CvBtTrainer* trainer )
2623 int i, j, k, kk, num;
2635 float* sorted_weights;
2644 data_size = trainer->m * sizeof( *idx );
2645 idx = (int*) cvAlloc( data_size );
2646 data_size = trainer->m * sizeof( *weights );
2647 weights = (float*) cvAlloc( data_size );
2648 data_size = trainer->m * sizeof( *sorted_weights );
2649 sorted_weights = (float*) cvAlloc( data_size );
// One index matrix with room for all samples is allocated up front and
// reused for every class.
2650 trimmed_idx = cvCreateMat( 1, trainer->numsamples, CV_32FC1 );
2652 for( k = 0; k < trainer->numclasses; k++ )
2654 /* yhat_i = y_i - p_k(x_i), y_i in {0, 1} */
2655 /* p_k(x_i) = exp(f_k(x_i)) / (sum_exp_f(x_i)) */
2657 for( i = 0; i < trainer->numsamples; i++ )
2659 index = icvGetIdxAt( trainer->sampleIdx, i );
2660 /* p_k(x_i) = 1 / (1 + sum(exp(f_kk(x_i) - f_k(x_i)))), kk != k */
// num is the offset of this sample in the per-class layout of f.
2661 num = index * trainer->numclasses;
2662 f_k = (double) trainer->f[num + k];
2664 for( kk = 0; kk < trainer->numclasses; kk++ )
2666 if( kk == k ) continue;
2667 exp_f = (double) trainer->f[num + kk] - f_k;
// Exponent and running sum both saturate at CV_VAL_MAX.
2668 exp_f = (exp_f < CV_LOG_VAL_MAX) ? exp( exp_f ) : CV_VAL_MAX;
2669 if( exp_f == CV_VAL_MAX || exp_f >= (CV_VAL_MAX - sum_exp_f) )
2671 sum_exp_f = CV_VAL_MAX;
2677 val = (float) ( (*((float*) (trainer->ydata + index * trainer->ystep)))
// A saturated sum is treated as probability zero.
2679 val -= (float) ( (sum_exp_f == CV_VAL_MAX) ? 0.0 : ( 1.0 / sum_exp_f ) );
2681 assert( val >= -1.0F );
2682 assert( val <= 1.0F );
2684 trainer->y->data.fl[index] = val;
2685 val = (float) fabs( val );
2686 weights[index] = val * (1.0F - val);
2687 sorted_weights[i] = weights[index];
2688 sum_weights += sorted_weights[i];
2691 sample_idx = trainer->sampleIdx;
2692 trimmed_num = trainer->numsamples;
2693 if( trainer->param[1] < 1.0F )
2695 /* perform weight trimming */
2700 std::sort(sorted_weights, sorted_weights + trainer->numsamples);
2702 sum_weights *= (1.0F - trainer->param[1]);
// Consume the smallest weights until the droppable mass is used up.
2705 do { sum_weights -= sorted_weights[++i]; }
2706 while( sum_weights > 0.0F && i < (trainer->numsamples - 1) );
2708 threshold = sorted_weights[i];
// Step back over ties so every sample equal to the threshold is kept.
2710 while( i > 0 && sorted_weights[i-1] == threshold ) i--;
2714 trimmed_num = trainer->numsamples - i;
// The matrix header is narrowed in place instead of reallocating.
2715 trimmed_idx->cols = trimmed_num;
2717 for( i = 0; i < trainer->numsamples; i++ )
2719 index = icvGetIdxAt( trainer->sampleIdx, i );
2720 if( weights[index] >= threshold )
2722 CV_MAT_ELEM( *trimmed_idx, float, 0, count ) = (float) index;
2727 assert( count == trimmed_num );
2729 sample_idx = trimmed_idx;
2731 printf( "k: %d Used samples %%: %g\n", k,
2732 (float) trimmed_num / (float) trainer->numsamples * 100.0F );
2734 } /* weight trimming */
2736 trees[k] = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData,
2737 trainer->flags, trainer->y, NULL, NULL, NULL, sample_idx, trainer->weights,
2738 (CvClassifierTrainParams*) &trainer->cartParams );
2740 CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2741 CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2742 sample_data = sample.data.ptr;
2743 for( i = 0; i < trimmed_num; i++ )
2745 index = icvGetIdxAt( sample_idx, i );
2746 sample.data.ptr = sample_data + index * sample_step;
2747 idx[index] = (int) cvEvalCARTClassifierIdx( (CvClassifier*) trees[k],
// Per tree output j: sum of targets in val, sum of weights in sum_weights.
2750 for( j = 0; j <= trees[k]->count; j++ )
2755 for( i = 0; i < trimmed_num; i++ )
2757 index = icvGetIdxAt( sample_idx, i );
2758 if( idx[index] == j )
2760 val += trainer->y->data.fl[index];
2761 sum_weights += weights[index];
2765 if( sum_weights > 0.0F )
2767 val = ((float) (trainer->numclasses - 1)) * val /
2768 ((float) (trainer->numclasses)) / sum_weights;
2774 trees[k]->val[j] = val;
2776 } /* for each class */
2778 cvReleaseMat( &trimmed_idx );
2779 cvFree( &sorted_weights );
// icvBtNext_XXBCLASS: one round for the types that use the generic boosting
// trainer. It trims the samples by weight, fits a tree to trainer->y,
// evaluates the tree on all m samples, lets cvBoostNextWeakClassifier update
// the weights and scales the tree values by the returned alpha.
// NOTE(review): local declarations and braces are not shown in this listing.
2785 static void icvBtNext_XXBCLASS( CvCARTClassifier** trees, CvBtTrainer* trainer )
2789 CvMat* weak_eval_vals;
2796 weak_eval_vals = cvCreateMat( 1, trainer->m, CV_32FC1 );
// cvTrimWeights may return NULL, the original index or a new matrix; the
// release check at the end distinguishes these cases.
2798 sample_idx = cvTrimWeights( trainer->weights, trainer->sampleIdx,
2799 trainer->param[1] );
2800 num_samples = ( sample_idx == NULL )
2801 ? trainer->m : MAX( sample_idx->rows, sample_idx->cols );
2803 printf( "Used samples %%: %g\n",
2804 (float) num_samples / (float) trainer->numsamples * 100.0F );
2806 trees[0] = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData,
2807 trainer->flags, trainer->y, NULL, NULL, NULL,
2808 sample_idx, trainer->weights,
2809 (CvClassifierTrainParams*) &trainer->cartParams );
2811 /* evaluate samples */
2812 CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2813 CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2814 sample_data = sample.data.ptr;
// Evaluation covers every sample, not only the trimmed selection.
2816 for( i = 0; i < trainer->m; i++ )
2818 sample.data.ptr = sample_data + i * sample_step;
2819 weak_eval_vals->data.fl[i] = trees[0]->eval( (CvClassifier*) trees[0], &sample );
2822 alpha = cvBoostNextWeakClassifier( weak_eval_vals, trainer->trainClasses,
2823 trainer->y, trainer->weights, trainer->boosttrainer );
2825 /* multiply tree by alpha */
2826 for( i = 0; i <= trees[0]->count; i++ )
2828 trees[0]->val[i] *= alpha;
// For CV_RABCLASS the scaled values are additionally passed through
// cvLogRatio.
2830 if( trainer->type == CV_RABCLASS )
2832 for( i = 0; i <= trees[0]->count; i++ )
2834 trees[0]->val[i] = cvLogRatio( trees[0]->val[i] );
// Only release the index matrix when it is not the one owned by the trainer.
2838 if( sample_idx != NULL && sample_idx != trainer->sampleIdx )
2840 cvReleaseMat( &sample_idx );
2842 cvReleaseMat( &weak_eval_vals );
// Signature of the per-type round routines and the table that holds them.
// NOTE(review): the table entries are not shown in this listing.
2845 typedef void (*CvBtNextFunc)( CvCARTClassifier** trees, CvBtTrainer* trainer );
2847 static CvBtNextFunc icvBtNextFunc[] =
// cvBtNext: trains the next tree or trees by dispatching on trainer->type,
// scales all tree values by param[0] when it differs from 1, and for types
// above CV_GABCLASS adds the new tree outputs to f for every selected sample.
2861 void cvBtNext( CvCARTClassifier** trees, CvBtTrainer* trainer )
2869 icvBtNextFunc[trainer->type]( trees, trainer );
// Exact float comparison: scaling is skipped only when param[0] is exactly 1.
2872 if( trainer->param[0] != 1.0F )
2874 for( j = 0; j < trainer->numclasses; j++ )
2876 for( i = 0; i <= trees[j]->count; i++ )
2878 trees[j]->val[i] *= trainer->param[0];
2883 if( trainer->type > CV_GABCLASS )
2885 /* update F_(m-1) */
2886 CV_GET_SAMPLE( *(trainer->trainData), trainer->flags, 0, sample );
2887 CV_GET_SAMPLE_STEP( *(trainer->trainData), trainer->flags, sample_step );
2888 sample_data = sample.data.ptr;
2889 for( i = 0; i < trainer->numsamples; i++ )
2891 index = icvGetIdxAt( trainer->sampleIdx, i );
2892 sample.data.ptr = sample_data + index * sample_step;
2893 for( j = 0; j < trainer->numclasses; j++ )
// f is laid out sample-major with numclasses entries per sample.
2895 trainer->f[index * trainer->numclasses + j] +=
2896 trees[j]->eval( (CvClassifier*) (trees[j]), &sample );
// cvBtEnd: releases a trainer created by cvBtStart. It frees the y and
// weights matrices and the boosting trainer when present, then frees the
// trainer block itself through cvFree. A NULL argument or NULL trainer is
// reported as CV_StsNullPtr.
2903 void cvBtEnd( CvBtTrainer** trainer )
2905 CV_FUNCNAME( "cvBtEnd" );
2909 if( trainer == NULL || (*trainer) == NULL )
2911 CV_ERROR( CV_StsNullPtr, "Invalid trainer parameter" );
2914 if( (*trainer)->y != NULL )
2916 CV_CALL( cvReleaseMat( &((*trainer)->y) ) );
2918 if( (*trainer)->weights != NULL )
2920 CV_CALL( cvReleaseMat( &((*trainer)->weights) ) );
2922 if( (*trainer)->boosttrainer != NULL )
2924 CV_CALL( cvBoostEndTraining( &((*trainer)->boosttrainer) ) );
// f needs no separate release: cvBtStart places it in the trainer allocation.
2926 CV_CALL( cvFree( trainer ) );
2931 /****************************************************************************************\
2932 * Boosted tree model as a classifier *
2933 \****************************************************************************************/
// cvEvalBtClassifier: sums the outputs of numiter trees for one sample. A
// tunable classifier keeps its trees in a sequence, a finished one in a plain
// array; both storage forms are handled.
// NOTE(review): local declarations, braces, the pointer advance in the array
// branch and the return statement are not shown in this listing.
2936 float cvEvalBtClassifier( CvClassifier* classifier, CvMat* sample )
2940 CV_FUNCNAME( "cvEvalBtClassifier" );
2947 if( CV_IS_TUNABLE( classifier->flags ) )
2950 CvCARTClassifier* tree;
2952 CV_CALL( cvStartReadSeq( ((CvBtClassifier*) classifier)->seq, &reader ) );
2953 for( i = 0; i < ((CvBtClassifier*) classifier)->numiter; i++ )
2955 CV_READ_SEQ_ELEM( tree, reader );
2956 val += tree->eval( (CvClassifier*) tree, sample );
2961 CvCARTClassifier** ptree;
2963 ptree = ((CvBtClassifier*) classifier)->trees;
2964 for( i = 0; i < ((CvBtClassifier*) classifier)->numiter; i++ )
2966 val += (*ptree)->eval( (CvClassifier*) (*ptree), sample );
// cvEvalBtClassifier2: two-class decision. Returns 1.0F when the summed tree
// output from cvEvalBtClassifier is non-negative and 0.0F otherwise.
2977 float cvEvalBtClassifier2( CvClassifier* classifier, CvMat* sample )
2981 CV_FUNCNAME( "cvEvalBtClassifier2" );
2985 CV_CALL( val = cvEvalBtClassifier( classifier, sample ) );
2989 return (float) (val >= 0.0F);
// cvEvalBtClassifierK: K-class decision. Accumulates one score per class over
// all iterations (trees are stored iteration by iteration, one per class) and
// then searches for the largest score. The score buffer is allocated and
// freed on every call.
// NOTE(review): local declarations, braces, the pointer advance in the array
// branch, the update inside the maximum search and the return statement are
// not shown in this listing.
2993 float cvEvalBtClassifierK( CvClassifier* classifier, CvMat* sample )
2997 CV_FUNCNAME( "cvEvalBtClassifierK" );
3008 numclasses = ((CvBtClassifier*) classifier)->numclasses;
3009 data_size = sizeof( *vals ) * numclasses;
3010 CV_CALL( vals = (float*) cvAlloc( data_size ) );
3011 memset( vals, 0, data_size );
3013 if( CV_IS_TUNABLE( classifier->flags ) )
3016 CvCARTClassifier* tree;
3018 CV_CALL( cvStartReadSeq( ((CvBtClassifier*) classifier)->seq, &reader ) );
3019 for( i = 0; i < ((CvBtClassifier*) classifier)->numiter; i++ )
3021 for( k = 0; k < numclasses; k++ )
3023 CV_READ_SEQ_ELEM( tree, reader );
3024 vals[k] += tree->eval( (CvClassifier*) tree, sample );
3031 CvCARTClassifier** ptree;
3033 ptree = ((CvBtClassifier*) classifier)->trees;
3034 for( i = 0; i < ((CvBtClassifier*) classifier)->numiter; i++ )
3036 for( k = 0; k < numclasses; k++ )
3038 vals[k] += (*ptree)->eval( (CvClassifier*) (*ptree), sample );
// The search starts from the class held in cls and scans classes 1 and up
// with a strict comparison.
3044 max_val = vals[cls];
3045 for( k = 1; k < numclasses; k++ )
3047 if( vals[k] > max_val )
3054 CV_CALL( cvFree( &vals ) );
// Evaluation routine per boosting type; icvAllocBtClassifier indexes this
// table with the type value. The first five visible slots use the two-class
// decision and the sixth uses the K-class decision.
// NOTE(review): any further entries are not shown in this listing.
3061 typedef float (*CvEvalBtClassifier)( CvClassifier* classifier, CvMat* sample );
3063 static CvEvalBtClassifier icvEvalBtClassifier[] =
3065 cvEvalBtClassifier2,
3066 cvEvalBtClassifier2,
3067 cvEvalBtClassifier2,
3068 cvEvalBtClassifier2,
3069 cvEvalBtClassifier2,
3070 cvEvalBtClassifierK,
// cvSaveBtClassifier: writes the classifier as text to filename. The header
// holds type and numclasses on one line, then numfeatures, then numiter. Each
// of the numclasses * numiter trees follows as its count, one line per split
// and a final line with count + 1 values.
// NOTE(review): local declarations, braces, three of the four arguments of
// the split line, the file close and the return statement are not shown in
// this listing.
3077 int cvSaveBtClassifier( CvClassifier* classifier, const char* filename )
3079 CV_FUNCNAME( "cvSaveBtClassifier" );
// The reader is zeroed up front; it is only started for tunable classifiers.
3086 memset(&reader, 0, sizeof(reader));
3087 CvCARTClassifier* tree;
3089 CV_ASSERT( classifier );
3090 CV_ASSERT( filename );
// icvMkDir is called with the file name before the file is opened.
3092 if( !icvMkDir( filename ) || (file = fopen( filename, "w" )) == 0 )
3094 CV_ERROR( CV_StsError, "Unable to create file" );
3097 if( CV_IS_TUNABLE( classifier->flags ) )
3099 CV_CALL( cvStartReadSeq( ((CvBtClassifier*) classifier)->seq, &reader ) );
3101 fprintf( file, "%d %d\n%d\n%d\n", (int) ((CvBtClassifier*) classifier)->type,
3102 ((CvBtClassifier*) classifier)->numclasses,
3103 ((CvBtClassifier*) classifier)->numfeatures,
3104 ((CvBtClassifier*) classifier)->numiter );
3106 for( i = 0; i < ((CvBtClassifier*) classifier)->numclasses *
3107 ((CvBtClassifier*) classifier)->numiter; i++ )
// Trees come from the sequence while tunable, from the array otherwise.
3109 if( CV_IS_TUNABLE( classifier->flags ) )
3111 CV_READ_SEQ_ELEM( tree, reader );
3115 tree = ((CvBtClassifier*) classifier)->trees[i];
3118 fprintf( file, "%d\n", tree->count );
3119 for( j = 0; j < tree->count; j++ )
3121 fprintf( file, "%d %g %d %d\n", tree->compidx[j],
// The value array has one more entry than there are splits.
3126 for( j = 0; j <= tree->count; j++ )
3128 fprintf( file, "%g ", tree->val[j] );
3130 fprintf( file, "\n" );
// cvReleaseBtClassifier: releases every tree through its own release
// callback, then the tree container, then the classifier block itself.
// A tunable classifier stores the trees in a sequence whose memory storage is
// released here.
// NOTE(review): local declarations, braces, the pointer advance and the
// release of the array in the non-tunable branch are not shown in this
// listing.
3142 void cvReleaseBtClassifier( CvClassifier** ptr )
3144 CV_FUNCNAME( "cvReleaseBtClassifier" );
3150 if( ptr == NULL || *ptr == NULL )
3152 CV_ERROR( CV_StsNullPtr, "" );
3154 if( CV_IS_TUNABLE( (*ptr)->flags ) )
3157 CvCARTClassifier* tree;
3159 CV_CALL( cvStartReadSeq( ((CvBtClassifier*) *ptr)->seq, &reader ) );
3160 for( i = 0; i < ((CvBtClassifier*) *ptr)->numclasses *
3161 ((CvBtClassifier*) *ptr)->numiter; i++ )
3163 CV_READ_SEQ_ELEM( tree, reader );
// release receives the address of the local copy, so the pointer stored in
// the sequence is not modified by this call.
3164 tree->release( (CvClassifier**) (&tree) );
3166 CV_CALL( cvReleaseMemStorage( &(((CvBtClassifier*) *ptr)->seq->storage) ) );
3170 CvCARTClassifier** ptree;
3172 ptree = ((CvBtClassifier*) *ptr)->trees;
3173 for( i = 0; i < ((CvBtClassifier*) *ptr)->numclasses *
3174 ((CvBtClassifier*) *ptr)->numiter; i++ )
3176 (*ptree)->release( (CvClassifier**) ptree );
3181 CV_CALL( cvFree( ptr ) );
// cvTuneBtClassifier: tune callback of the boosted tree classifier. Only the
// classifier and flags arguments are named; the matrix arguments are unused.
// With a tunable flags argument it runs one more training round and appends
// the new trees to the sequence. The other visible path converts a tunable
// classifier to its final form: the sequence is copied to an array, the
// sequence storage and the trainer are released and the tunable flag is
// cleared.
// NOTE(review): local declarations, braces and the else structure between the
// two paths are not shown in this listing.
3187 static void cvTuneBtClassifier( CvClassifier* classifier, CvMat*, int flags,
3188 CvMat*, CvMat* , CvMat*, CvMat*, CvMat* )
3190 CV_FUNCNAME( "cvTuneBtClassifier" );
3196 if( CV_IS_TUNABLE( flags ) )
3198 if( !CV_IS_TUNABLE( classifier->flags ) )
3200 CV_ERROR( CV_StsUnsupportedFormat,
3201 "Classifier does not support tune function" );
3205 /* tune classifier */
3206 CvCARTClassifier** trees;
3208 printf( "Iteration %d\n", ((CvBtClassifier*) classifier)->numiter + 1 );
// Temporary array of numclasses tree pointers for this round.
3210 data_size = sizeof( *trees ) * ((CvBtClassifier*) classifier)->numclasses;
3211 CV_CALL( trees = (CvCARTClassifier**) cvAlloc( data_size ) );
3212 CV_CALL( cvBtNext( trees,
3213 (CvBtTrainer*) ((CvBtClassifier*) classifier)->trainer ) );
// The pointers are copied into the sequence, so only the temporary array is
// freed here.
3214 CV_CALL( cvSeqPushMulti( ((CvBtClassifier*) classifier)->seq,
3215 trees, ((CvBtClassifier*) classifier)->numclasses ) );
3216 CV_CALL( cvFree( &trees ) );
3217 ((CvBtClassifier*) classifier)->numiter++;
3222 if( CV_IS_TUNABLE( classifier->flags ) )
3227 assert( ((CvBtClassifier*) classifier)->seq->total ==
3228 ((CvBtClassifier*) classifier)->numiter *
3229 ((CvBtClassifier*) classifier)->numclasses );
// Copy all tree pointers from the sequence into one flat array.
3231 data_size = sizeof( ((CvBtClassifier*) classifier)->trees[0] ) *
3232 ((CvBtClassifier*) classifier)->seq->total;
3233 CV_CALL( ptr = cvAlloc( data_size ) );
3234 CV_CALL( cvCvtSeqToArray( ((CvBtClassifier*) classifier)->seq, ptr ) );
3235 CV_CALL( cvReleaseMemStorage(
3236 &(((CvBtClassifier*) classifier)->seq->storage) ) );
3237 ((CvBtClassifier*) classifier)->trees = (CvCARTClassifier**) ptr;
3238 classifier->flags &= ~CV_TUNABLE;
3239 CV_CALL( cvBtEnd( (CvBtTrainer**)
3240 &(((CvBtClassifier*) classifier)->trainer )) );
3241 ((CvBtClassifier*) classifier)->trainer = NULL;
// Allocates a zero-initialized CvBtClassifier and wires up its callbacks.
//
// type       - boosting type; indexes the icvEvalBtClassifier table to pick
//              the `eval` callback, and must be CV_LKCLASS whenever
//              numclasses != 1 (asserted).
// flags      - when CV_IS_TUNABLE( flags ) holds, tree storage is an empty
//              sequence of tree pointers in a newly created memory storage.
// numclasses - asserted to be >= 1.
// numiter    - asserted to be >= 0.
//
// The allocation results of cvAlloc / cvCreateSeq are used without a check in
// the visible code.
3248 static CvBtClassifier* icvAllocBtClassifier( CvBoostType type, int flags, int numclasses,
3251 CvBtClassifier* ptr;
3254 assert( numclasses >= 1 );
3255 assert( numiter >= 0 );
3256 assert( ( numclasses == 1 ) || (type == CV_LKCLASS) );
3258 data_size = sizeof( *ptr );
3259 ptr = (CvBtClassifier*) cvAlloc( data_size );
3260 memset( ptr, 0, data_size );
// Tunable layout: growable sequence whose elements have the size of one tree
// pointer.
3262 if( CV_IS_TUNABLE( flags ) )
3264 ptr->seq = cvCreateSeq( 0, sizeof( *(ptr->seq) ), sizeof( *(ptr->trees) ),
3265 cvCreateMemStorage() );
// Fixed layout: zero-filled array of numclasses * numiter tree pointers.
// NOTE(review): presumably the alternative to the tunable branch - confirm.
3270 data_size = numclasses * numiter * sizeof( *(ptr->trees) );
3271 ptr->trees = (CvCARTClassifier**) cvAlloc( data_size );
3272 memset( ptr->trees, 0, data_size );
3274 ptr->numiter = numiter;
3278 ptr->numclasses = numclasses;
// Callback table of the classifier.
3281 ptr->eval = icvEvalBtClassifier[(int) type];
3282 ptr->tune = cvTuneBtClassifier;
3283 ptr->save = cvSaveBtClassifier;
3284 ptr->release = cvReleaseBtClassifier;
// Trains a boosted-trees classifier and returns it as a CvClassifier*.
//
// Requirements enforced below:
//   - trainData, trainClasses and trainParams must be non-NULL;
//   - typeMask, missedMeasurementsMask, compIdx and weights must be NULL;
//   - sampleIdx is rejected with CV_StsBadArg for boosting types in the range
//     CV_DABCLASS..CV_GABCLASS;
//   - the iteration count taken from trainParams must be > 0.
// trainParams is accessed as a CvBtClassifierTrainParams (type, numiter,
// numsplits, param).
//
// The classifier is always built in tunable form; when the caller's `flags`
// are not tunable it is finalized through its `tune` callback before return.
3290 CvClassifier* cvCreateBtClassifier( CvMat* trainData,
3292 CvMat* trainClasses,
3294 CvMat* missedMeasurementsMask,
3298 CvClassifierTrainParams* trainParams )
3300 CvBtClassifier* ptr = 0;
3302 CV_FUNCNAME( "cvCreateBtClassifier" );
3309 CvCARTClassifier** trees;
3312 CV_ASSERT( trainData != NULL );
3313 CV_ASSERT( trainClasses != NULL );
3314 CV_ASSERT( typeMask == NULL );
3315 CV_ASSERT( missedMeasurementsMask == NULL );
3316 CV_ASSERT( compIdx == NULL );
3317 CV_ASSERT( weights == NULL );
3318 CV_ASSERT( trainParams != NULL );
3320 type = ((CvBtClassifierTrainParams*) trainParams)->type;
3322 if( type >= CV_DABCLASS && type <= CV_GABCLASS && sampleIdx )
3324 CV_ERROR( CV_StsBadArg, "Sample indices are not supported for this type" );
// For CV_LKCLASS the class count is derived from the largest value found in
// trainClasses (max + 1); at least two classes are required.
3327 if( type == CV_LKCLASS )
3332 cvMinMaxLoc( trainClasses, &min_val, &max_val );
3333 num_classes = (int) (max_val + 1.0);
3335 CV_ASSERT( num_classes >= 2 );
3341 num_iter = ((CvBtClassifierTrainParams*) trainParams)->numiter;
3343 CV_ASSERT( num_iter > 0 );
// CV_TUNABLE is forced on here so that trees can be appended per iteration.
3345 ptr = icvAllocBtClassifier( type, CV_TUNABLE | flags, num_classes, num_iter );
// Feature count is the number of columns for row-sample layout, rows otherwise.
3346 ptr->numfeatures = (CV_IS_ROW_SAMPLE( flags )) ? trainData->cols : trainData->rows;
// First iteration: cvBtStart creates the trainer and receives the temporary
// per-class tree array, which is then appended to the sequence and freed.
3350 printf( "Iteration %d\n", 1 );
3352 data_size = sizeof( *trees ) * ptr->numclasses;
3353 CV_CALL( trees = (CvCARTClassifier**) cvAlloc( data_size ) );
3355 CV_CALL( ptr->trainer = cvBtStart( trees, trainData, flags, trainClasses, sampleIdx,
3356 ((CvBtClassifierTrainParams*) trainParams)->numsplits, type, num_classes,
3357 &(((CvBtClassifierTrainParams*) trainParams)->param[0]) ) );
3359 CV_CALL( cvSeqPushMulti( ptr->seq, trees, ptr->numclasses ) );
3360 CV_CALL( cvFree( &trees ) );
// Remaining iterations go through the classifier's own tune callback.
3363 for( i = 1; i < num_iter; i++ )
3365 ptr->tune( (CvClassifier*) ptr, NULL, CV_TUNABLE, NULL, NULL, NULL, NULL, NULL );
// Caller did not ask for a tunable classifier: tune with flags == 0.
3367 if( !CV_IS_TUNABLE( flags ) )
3370 ptr->tune( (CvClassifier*) ptr, NULL, 0, NULL, NULL, NULL, NULL, NULL );
3375 return (CvClassifier*) ptr;
// Loads a boosted-trees classifier from a text file.
//
// filename - path of the file to read; must be non-NULL. Failure to open it is
//            reported as CV_StsError ("Unable to open file").
//
// File layout as consumed below:
//   header : four integers - type, num_classes, num_features, num_classifiers
//   then num_classes * num_classifiers trees, each consisting of
//     count                        (one integer)
//     count split records          (four values each, "%d %g %d %d")
//     count + 1 leaf values        ("%g" each)
// Every fscanf result is checked with CV_Assert against the expected number
// of converted fields.
3379 CvClassifier* cvCreateBtClassifierFromFile( const char* filename )
3381 CvBtClassifier* ptr = 0;
3383 CV_FUNCNAME( "cvCreateBtClassifierFromFile" );
3390 int num_classifiers;
3394 int values_read = -1;
3396 CV_ASSERT( filename != NULL );
3399 file = fopen( filename, "r" );
3402 CV_ERROR( CV_StsError, "Unable to open file" );
3405 values_read = fscanf( file, "%d %d %d %d", &type, &num_classes, &num_features, &num_classifiers );
3406 CV_Assert(values_read == 4);
// Header sanity checks: known boosting type, positive feature/classifier counts.
3408 CV_ASSERT( type >= (int) CV_DABCLASS && type <= (int) CV_MREG );
3409 CV_ASSERT( num_features > 0 );
3410 CV_ASSERT( num_classifiers > 0 );
3412 if( (CvBoostType) type != CV_LKCLASS )
// Non-tunable allocation (flags == 0): trees go into the fixed `trees` array.
3416 ptr = icvAllocBtClassifier( (CvBoostType) type, 0, num_classes, num_classifiers );
3417 ptr->numfeatures = num_features;
3419 for( i = 0; i < num_classes * num_classifiers; i++ )
3422 CvCARTClassifier* tree;
3424 values_read = fscanf( file, "%d", &count );
3425 CV_Assert(values_read == 1);
// NOTE(review): `count` comes straight from the file and is not range-checked
// before it is used to size the allocation below - confirm whether a negative
// or very large value can reach this point.
//
// One allocation holds the tree header followed by its arrays, in this order:
//   compidx[count], threshold[count], left[count], right[count], val[count + 1]
3427 data_size = sizeof( *tree )
3428 + count * ( sizeof( *(tree->compidx) ) + sizeof( *(tree->threshold) ) +
3429 sizeof( *(tree->right) ) + sizeof( *(tree->left) ) )
3430 + (count + 1) * ( sizeof( *(tree->val) ) );
3431 CV_CALL( tree = (CvCARTClassifier*) cvAlloc( data_size ) );
3432 memset( tree, 0, data_size );
3433 tree->eval = cvEvalCARTClassifier;
3436 tree->release = cvReleaseCARTClassifier;
// Carve the arrays out of the block that follows the header.
3437 tree->compidx = (int*) ( tree + 1 );
3438 tree->threshold = (float*) ( tree->compidx + count );
3439 tree->left = (int*) ( tree->threshold + count );
3440 tree->right = (int*) ( tree->left + count );
3441 tree->val = (float*) ( tree->right + count );
3443 tree->count = count;
3444 for( j = 0; j < tree->count; j++ )
3446 values_read = fscanf( file, "%d %g %d %d", &(tree->compidx[j]),
3447 &(tree->threshold[j]),
3449 &(tree->right[j]) );
3450 CV_Assert(values_read == 4);
// There is one more leaf value than there are splits, hence `<=`.
3452 for( j = 0; j <= tree->count; j++ )
3454 values_read = fscanf( file, "%g", &(tree->val[j]) );
3455 CV_Assert(values_read == 1);
3457 ptr->trees[i] = tree;
3464 return (CvClassifier*) ptr;
3467 /****************************************************************************************\
3468 * Utility functions *
3469 \****************************************************************************************/
// Selects the samples whose weight is at or above a trimming threshold.
//
// weights - vector of sample weights; must be CV_32FC1 (asserted). It is
//           dereferenced without a NULL check.
// idx     - optional vector of sample indices to consider; when NULL, all
//           wnum weights are considered.
// factor  - trimming is performed only when 0 < factor < 1.
//
// When trimming produces a new selection, `ptr` receives a newly created
// 1 x K CV_32FC1 matrix holding the kept sample indices stored as floats.
3472 CvMat* cvTrimWeights( CvMat* weights, CvMat* idx, float factor )
3476 CV_FUNCNAME( "cvTrimWeights" );
3485 float* sorted_weights;
3487 CV_ASSERT( CV_MAT_TYPE( weights->type ) == CV_32FC1 );
3490 sorted_weights = NULL;
3492 if( factor > 0.0F && factor < 1.0F )
3496 CV_MAT2VEC( *weights, wdata, wstep, wnum );
3497 num = ( idx == NULL ) ? wnum : MAX( idx->rows, idx->cols );
// Scratch copy of the considered weights; also accumulates their sum.
3499 data_size = num * sizeof( *sorted_weights );
3500 sorted_weights = (float*) cvAlloc( data_size );
3501 memset( sorted_weights, 0, data_size );
3504 for( i = 0; i < num; i++ )
3506 index = icvGetIdxAt( idx, i );
3507 sorted_weights[i] = *((float*) (wdata + index * wstep));
3508 sum_weights += sorted_weights[i];
// Ascending order, so the walk below starts from the smallest weights.
3511 std::sort(sorted_weights, sorted_weights + num);
3513 sum_weights *= (1.0F - factor);
// Subtract sorted weights from the scaled sum until it is no longer positive
// or the last element has been reached; the weight at that position becomes
// the threshold.
3516 do { sum_weights -= sorted_weights[++i]; }
3517 while( sum_weights > 0.0F && i < (num - 1) );
3519 threshold = sorted_weights[i];
// Step back over equal weights so that ties with the threshold are all kept.
3521 while( i > 0 && sorted_weights[i-1] == threshold ) i--;
// A new index matrix is built when something is actually trimmed (i > 0) or
// when the supplied idx is not already CV_32FC1.
3523 if( i > 0 || ( idx != NULL && CV_MAT_TYPE( idx->type ) != CV_32FC1 ) )
3525 CV_CALL( ptr = cvCreateMat( 1, num - i, CV_32FC1 ) );
3527 for( i = 0; i < num; i++ )
3529 index = icvGetIdxAt( idx, i );
3530 if( *((float*) (wdata + index * wstep)) >= threshold )
3532 CV_MAT_ELEM( *ptr, float, 0, count ) = (float) index;
// Every column of the new matrix must have been filled.
3537 assert( count == ptr->cols );
3539 cvFree( &sorted_weights );
3549 void cvReadTrainData( const char* filename, int flags,
3551 CvMat** trainClasses )
3554 CV_FUNCNAME( "cvReadTrainData" );
3562 int values_read = -1;
3564 if( filename == NULL )
3566 CV_ERROR( CV_StsNullPtr, "filename must be specified" );
3568 if( trainData == NULL )
3570 CV_ERROR( CV_StsNullPtr, "trainData must be not NULL" );
3572 if( trainClasses == NULL )
3574 CV_ERROR( CV_StsNullPtr, "trainClasses must be not NULL" );
3578 *trainClasses = NULL;
3579 file = fopen( filename, "r" );
3582 CV_ERROR( CV_StsError, "Unable to open file" );
3585 values_read = fscanf( file, "%d %d", &m, &n );
3586 CV_Assert(values_read == 2);
3588 if( CV_IS_ROW_SAMPLE( flags ) )
3590 CV_CALL( *trainData = cvCreateMat( m, n, CV_32FC1 ) );
3594 CV_CALL( *trainData = cvCreateMat( n, m, CV_32FC1 ) );
3597 CV_CALL( *trainClasses = cvCreateMat( 1, m, CV_32FC1 ) );
3599 for( i = 0; i < m; i++ )
3601 for( j = 0; j < n; j++ )
3603 values_read = fscanf( file, "%f", &val );
3604 CV_Assert(values_read == 1);
3605 if( CV_IS_ROW_SAMPLE( flags ) )
3607 CV_MAT_ELEM( **trainData, float, i, j ) = val;
3611 CV_MAT_ELEM( **trainData, float, j, i ) = val;
3614 values_read = fscanf( file, "%f", &val );
3615 CV_Assert(values_read == 2);
3616 CV_MAT_ELEM( **trainClasses, float, 0, i ) = val;
// Writes a training set to a text file.
//
// filename     - output path; NULL is reported as CV_StsNullPtr, failure to
//                create the file as CV_StsError.
// flags        - CV_IS_ROW_SAMPLE( flags ) tells whether samples are the rows
//                (m = rows, n = cols) or the columns (m = cols, n = rows) of
//                trainData.
// trainData    - CV_32FC1 feature matrix; anything else is reported as
//                CV_StsUnsupportedFormat.
// trainClasses - CV_32FC1 row or column vector with one entry per sample; a
//                wrong type or shape is CV_StsUnsupportedFormat, a length
//                different from m is CV_StsUnmatchedSizes.
// sampleIdx    - optional vector of sample indices; when given, its length is
//                the number of samples written.
//
// Output layout: a "count n" header line, then one line per written sample
// with its n feature values followed by its class value, all printed with %g.
3626 void cvWriteTrainData( const char* filename, int flags,
3627 CvMat* trainData, CvMat* trainClasses, CvMat* sampleIdx )
3629 CV_FUNCNAME( "cvWriteTrainData" );
3641 if( filename == NULL )
3643 CV_ERROR( CV_StsNullPtr, "filename must be specified" );
3645 if( trainData == NULL || CV_MAT_TYPE( trainData->type ) != CV_32FC1 )
3647 CV_ERROR( CV_StsUnsupportedFormat, "Invalid trainData" );
3649 if( CV_IS_ROW_SAMPLE( flags ) )
3651 m = trainData->rows;
3652 n = trainData->cols;
3656 n = trainData->rows;
3657 m = trainData->cols;
3659 if( trainClasses == NULL || CV_MAT_TYPE( trainClasses->type ) != CV_32FC1 ||
3660 MIN( trainClasses->rows, trainClasses->cols ) != 1 )
3662 CV_ERROR( CV_StsUnsupportedFormat, "Invalid trainClasses" );
// clsrow remembers whether the class vector is laid out as a single row.
3664 clsrow = (trainClasses->rows == 1);
3665 if( m != ( (clsrow) ? trainClasses->cols : trainClasses->rows ) )
3667 CV_ERROR( CV_StsUnmatchedSizes, "Incorrect trainData and trainClasses sizes" );
3670 if( sampleIdx != NULL )
3672 count = (sampleIdx->rows == 1) ? sampleIdx->cols : sampleIdx->rows;
3680 file = fopen( filename, "w" );
3683 CV_ERROR( CV_StsError, "Unable to create file" );
3686 fprintf( file, "%d %d\n", count, n );
3688 for( i = 0; i < count; i++ )
// Fetch the i-th sample index from sampleIdx, honouring its row/column layout.
3692 if( sampleIdx->rows == 1 )
3694 sc = cvGet2D( sampleIdx, 0, i );
3698 sc = cvGet2D( sampleIdx, i, 0 );
3700 idx = (int) sc.val[0];
// Feature values of sample idx, then its class value on the same line.
3706 for( j = 0; j < n; j++ )
3708 fprintf( file, "%g ", ( (CV_IS_ROW_SAMPLE( flags ))
3709 ? CV_MAT_ELEM( *trainData, float, idx, j )
3710 : CV_MAT_ELEM( *trainData, float, j, idx ) ) );
3712 fprintf( file, "%g\n", ( (clsrow)
3713 ? CV_MAT_ELEM( *trainClasses, float, 0, idx )
3714 : CV_MAT_ELEM( *trainClasses, float, idx, 0 ) ) );
// ICV_RAND_SHUFFLE( suffix, type ) defines a file-local function
//   icvRandShuffle_<suffix>( uchar* data, size_t step, int num )
// that shuffles in place `num` elements of `type`, where element k is located
// at byte offset k * step from `data`.
//
// For every i in [0, num - 2] the element at i is swapped with the element at
// i + (int)( rn * (num - i) ), where rn is a cvRandInt() result scaled by
// 1 / (1.0F + UINT_MAX). The generator is created from `seed` on every call.
//
// NOTE(review): the scaling is done in float; cvRandInt() values very close to
// UINT_MAX round up when converted, which looks like it can make rn exactly
// 1.0 and the swap index equal to num (one past the last element) - confirm.
//
// The macro is instantiated below for uchar, short, int and float elements.
3723 #define ICV_RAND_SHUFFLE( suffix, type ) \
3724 static void icvRandShuffle_##suffix( uchar* data, size_t step, int num ) \
3732 CvRNG state = cvRNG((int)seed); \
3734 for( i = 0; i < (num-1); i++ ) \
3736 rn = ((float) cvRandInt( &state )) / (1.0F + UINT_MAX); \
3737 CV_SWAP( *((type*)(data + i * step)), \
3738 *((type*)(data + ( i + (int)( rn * (num - i ) ) )* step)), \
3743 ICV_RAND_SHUFFLE( 8U, uchar )
3745 ICV_RAND_SHUFFLE( 16S, short )
3747 ICV_RAND_SHUFFLE( 32S, int )
3749 ICV_RAND_SHUFFLE( 32F, float )
// Randomly shuffles the elements of a vector in place.
//
// mat - must be a valid CvMat with a single row or a single column;
//       otherwise CV_StsUnsupportedFormat is reported. The element type
//       selects one of the icvRandShuffle_8U / _16S / _32S / _32F helpers;
//       the last CV_StsUnsupportedFormat error presumably covers any other
//       element type - confirm.
//
// NOTE(review): the name registered with CV_FUNCNAME is "cvRandShuffle", which
// differs from this function's name, so error reports will carry that name.
3752 void cvRandShuffleVec( CvMat* mat )
3754 CV_FUNCNAME( "cvRandShuffle" );
3762 if( (mat == NULL) || !CV_IS_MAT( mat ) || MIN( mat->rows, mat->cols ) != 1 )
3764 CV_ERROR( CV_StsUnsupportedFormat, "" );
// Obtain the raw data pointer, element step and element count of the vector.
3767 CV_MAT2VEC( *mat, data, step, num );
3768 switch( CV_MAT_TYPE( mat->type ) )
3771 icvRandShuffle_8U( data, step, num);
3774 icvRandShuffle_16S( data, step, num);
3777 icvRandShuffle_32S( data, step, num);
3780 icvRandShuffle_32F( data, step, num);
3783 CV_ERROR( CV_StsUnsupportedFormat, "" );