1 /*M///////////////////////////////////////////////////////////////////////////////////////
3 IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
5 By downloading, copying, installing or using the software you agree to this license.
6 If you do not agree to this license, do not download, install,
7 copy or use the software.
10 Intel License Agreement
12 Copyright (C) 2000, Intel Corporation, all rights reserved.
13 Third party copyrights are property of their respective owners.
15 Redistribution and use in source and binary forms, with or without modification,
16 are permitted provided that the following conditions are met:
18 * Redistribution's of source code must retain the above copyright notice,
19 this list of conditions and the following disclaimer.
21 * Redistribution's in binary form must reproduce the above copyright notice,
22 this list of conditions and the following disclaimer in the documentation
23 and/or other materials provided with the distribution.
25 * The name of Intel Corporation may not be used to endorse or promote products
26 derived from this software without specific prior written permission.
28 This software is provided by the copyright holders and contributors "as is" and
29 any express or implied warranties, including, but not limited to, the implied
30 warranties of merchantability and fitness for a particular purpose are disclaimed.
31 In no event shall the Intel Corporation or contributors be liable for any direct,
32 indirect, incidental, special, exemplary, or consequential damages
33 (including, but not limited to, procurement of substitute goods or services;
34 loss of use, data, or profits; or business interruption) however caused
35 and on any theory of liability, whether in contract, strict liability,
36 or tort (including negligence or otherwise) arising in any way out of
37 the use of this software, even if advised of the possibility of such damage.
41 #include "precomp.hpp"
// Module-wide constants and sort helpers for the Extra-Trees (ERTrees) implementation.
// NOTE(review): every line in this listing carries a stray leading number from a
// numbered-listing paste, and gaps in that numbering show lines are missing.
// Sentinel magnitude: ordered-variable values with |v| >= ord_nan are rejected as "too large".
43 static const float ord_nan = FLT_MAX*0.5f;
// Memory-storage sizing: 64 KiB minimum block, 1 KiB growth delta.
44 static const int min_block_size = 1 << 16;
45 static const int block_size_delta = 1 << 10;
// Quicksort over an array of int pointers, ordered by the pointed-to int values.
47 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
48 static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )
// Quicksort over (ushort slot, int label) pairs, ordered by the int label.
50 #define CV_CMP_PAIRS(a,b) (*((a).i) < *((b).i))
51 static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair16u32s, CV_CMP_PAIRS, int )
// CvERTreeTrainData::set_data
// Validates the training set and builds the packed internal representation used by
// Extremely Randomized Trees: the per-variable work buffer "buf" (16-bit when
// sample_count < 65536, else 32-bit), category maps for categorical variables,
// class priors, per-tree memory storages, and the root node.
// Surrogate splits and incremental data updates are rejected up front.
// NOTE(review): this listing is a truncated numbered dump — many original lines
// (braces, else-branches, some statements) are absent, so the code below is not
// compilable as-is; comments describe only the statements that are visible.
55 void CvERTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
56 const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
57 const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
58 bool _shared, bool _add_labels, bool _update_data )
// Temporaries released in the (not visible here) cleanup section at the end.
60 CvMat* sample_indices = 0;
64 CvPair16u32s* pair16u32s_ptr = 0;
65 CvDTreeTrainData* data = 0;
68 unsigned short* udst = 0;
71 CV_FUNCNAME( "CvERTreeTrainData::set_data" );
75 int sample_all = 0, r_type, cv_n;
76 int total_c_count = 0;
77 int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
78 int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
81 const int *sidx = 0, *vidx = 0;
83 uint64 effective_buf_size = 0;
84 int effective_buf_height = 0, effective_buf_width = 0;
// Extra-trees restrictions: no surrogate splits, no incremental data update.
86 if ( _params.use_surrogates )
87 CV_ERROR(CV_StsBadArg, "CvERTrees do not support surrogate splits");
89 if( _update_data && data_root )
91 CV_ERROR(CV_StsBadArg, "CvERTrees do not support data update");
99 CV_CALL( set_params( _params ));
101 // check parameter types and sizes
102 CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
104 train_data = _train_data;
105 responses = _responses;
106 missing_mask = _missing_mask;
// Element strides depend on whether samples are rows or columns of _train_data.
108 if( _tflag == CV_ROW_SAMPLE )
110 ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
113 ms_step = _missing_mask->step, mv_step = 1;
117 dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
120 mv_step = _missing_mask->step, ms_step = 1;
124 sample_count = sample_all;
// Optional subsets of samples and variables.
129 CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));
130 sidx = sample_indices->data.i;
131 sample_count = sample_indices->rows + sample_indices->cols - 1;
136 CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
137 vidx = var_idx->data.i;
138 var_count = var_idx->rows + var_idx->cols - 1;
// Responses must be a 1D CV_32SC1/CV_32FC1 vector with one entry per sample.
141 if( !CV_IS_MAT(_responses) ||
142 (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
143 CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
144 (_responses->rows != 1 && _responses->cols != 1) ||
145 _responses->rows + _responses->cols - 1 != sample_all )
146 CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
147 "floating-point vector containing as many elements as "
148 "the total number of samples in the training data matrix" );
// A 16-bit work buffer is usable only when indices fit in unsigned short.
151 if ( sample_count < 65536 )
154 r_type = CV_VAR_CATEGORICAL;
156 CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_count, &r_type ));
158 CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
163 is_classifier = r_type == CV_VAR_CATEGORICAL;
165 // step 0. calc the number of categorical vars
166 for( vi = 0; vi < var_count; vi++ )
168 char vt = var_type0 ? var_type0->data.ptr[vi] : CV_VAR_ORDERED;
// Categorical vars get indices 0,1,2,... ; ordered vars get ~0,~1,~2,... (negative).
169 var_type->data.i[vi] = vt == CV_VAR_CATEGORICAL ? cat_var_count++ : ord_var_count--;
172 ord_var_count = ~ord_var_count;
173 cv_n = params.cv_folds;
174 // set the two last elements of var_type array to be able
175 // to locate responses and cross-validation labels using
176 // the corresponding get_* functions.
177 var_type->data.i[var_count] = cat_var_count;
178 var_type->data.i[var_count+1] = cat_var_count+1;
180 // in case of single ordered predictor we need dummy cv_labels
181 // for safe split_node_data() operation
182 have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;
184 work_var_count = cat_var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0);
187 buf_count = shared ? 2 : 1;
189 buf_size = -1; // the member buf_size is obsolete
// Overflow guard: the buffer size is computed in 64 bits and cross-checked against
// the 32-bit width*height product used by cvCreateMat.
191 effective_buf_size = (uint64)(work_var_count + 1)*(uint64)sample_count * buf_count; // this is the total size of "CvMat buf" to be allocated
192 effective_buf_width = sample_count;
193 effective_buf_height = work_var_count+1;
195 if (effective_buf_width >= effective_buf_height)
196 effective_buf_height *= buf_count;
198 effective_buf_width *= buf_count;
200 if ((uint64)effective_buf_width * (uint64)effective_buf_height != effective_buf_size)
202 CV_Error(CV_StsBadArg, "The memory buffer cannot be allocated since its size exceeds integer fields limit");
// Allocate the work buffer: 16-bit (with pair helpers) or 32-bit (with int pointers).
207 CV_CALL( buf = cvCreateMat( effective_buf_height, effective_buf_width, CV_16UC1 ));
208 CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));
212 CV_CALL( buf = cvCreateMat( effective_buf_height, effective_buf_width, CV_32SC1 ));
213 CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
// Category bookkeeping arrays — always at least one element so cvCreateMat succeeds.
216 size = is_classifier ? cat_var_count+1 : cat_var_count;
217 size = !size ? 1 : size;
218 CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));
219 CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));
221 size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;
222 size = !size ? 1 : size;
223 CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));
225 // now calculate the maximum size of split,
226 // create memory storage that will keep nodes and splits of the decision tree
227 // allocate root node and the buffer for the whole training data
228 max_split_size = cvAlign(sizeof(CvDTreeSplit) +
229 (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
230 tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
231 tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
232 CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
233 CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
235 nv_size = var_count*sizeof(int);
236 nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));
238 temp_block_size = nv_size;
// Cross-validation storage (cv_n folds).
// NOTE(review): the error message below is grammatically broken in the original
// ("The many folds..." — presumably meant "Too many folds...").
242 if( sample_count < cv_n*MAX(params.min_sample_count,10) )
243 CV_ERROR( CV_StsOutOfRange,
244 "The many folds in cross-validation for such a small dataset" );
246 cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
247 temp_block_size = MAX(temp_block_size, cv_size);
250 temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
251 CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
252 CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
254 CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
256 CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
263 _fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));
264 if (is_buf_16u && (cat_var_count || is_classifier))
265 _idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));
267 // transform the training data to convenient representation
268 for( vi = 0; vi <= var_count; vi++ )
271 const uchar* mask = 0;
272 int m_step = 0, step;
273 const int* idata = 0;
274 const float* fdata = 0;
277 if( vi < var_count ) // analyze i-th input variable
279 int vi0 = vidx ? vidx[vi] : vi;
280 ci = get_var_type(vi);
281 step = ds_step; m_step = ms_step;
282 if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
283 idata = _train_data->data.i + vi0*dv_step;
285 fdata = _train_data->data.fl + vi0*dv_step;
287 mask = _missing_mask->data.ptr + vi0*mv_step;
289 else // analyze _responses
292 step = CV_IS_MAT_CONT(_responses->type) ?
293 1 : _responses->step / CV_ELEM_SIZE(_responses->type);
294 if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
295 idata = _responses->data.i;
297 fdata = _responses->data.fl;
300 if( (vi < var_count && ci>=0) ||
301 (vi == var_count && is_classifier) ) // process categorical variable or response
303 int c_count, prev_label;
// Destination column inside the work buffer (16-bit or 32-bit flavor).
307 udst = (unsigned short*)(buf->data.s + ci*sample_count);
309 idst = buf->data.i + ci*sample_count;
// Collect values (missing -> INT_MAX) and validate categorical entries:
// floats must be exact integers and values must fit the storage type.
312 for( i = 0; i < sample_count; i++ )
314 int val = INT_MAX, si = sidx ? sidx[i] : i;
315 if( !mask || !mask[(size_t)si*m_step] )
318 val = idata[(size_t)si*step];
321 float t = fdata[(size_t)si*step];
325 sprintf( err, "%d-th value of %d-th (categorical) "
326 "variable is not an integer", i, vi );
327 CV_ERROR( CV_StsBadArg, err );
333 sprintf( err, "%d-th value of %d-th (categorical) "
334 "variable is too large", i, vi );
335 CV_ERROR( CV_StsBadArg, err );
342 pair16u32s_ptr[i].u = udst + i;
343 pair16u32s_ptr[i].i = _idst + i;
348 int_ptr[i] = idst + i;
// Sort so equal labels become contiguous, then count distinct categories.
352 c_count = num_valid > 0;
356 icvSortPairs( pair16u32s_ptr, sample_count, 0 );
357 // count the categories
358 for( i = 1; i < num_valid; i++ )
359 if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)
364 icvSortIntPtr( int_ptr, sample_count, 0 );
365 // count the categories
366 for( i = 1; i < num_valid; i++ )
367 c_count += *int_ptr[i] != *int_ptr[i-1];
371 max_c_count = MAX( max_c_count, c_count );
372 cat_count->data.i[ci] = c_count;
373 cat_ofs->data.i[ci] = total_c_count;
375 // resize cat_map, if need
376 if( cat_map->cols < total_c_count + c_count )
379 CV_CALL( cat_map = cvCreateMat( 1,
380 MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
381 for( i = 0; i < total_c_count; i++ )
382 cat_map->data.i[i] = tmp_map->data.i[i];
383 cvReleaseMat( &tmp_map );
386 c_map = cat_map->data.i + total_c_count;
387 total_c_count += c_count;
392 // compact the class indices and build the map
393 prev_label = ~*pair16u32s_ptr[0].i;
394 for( i = 0; i < num_valid; i++ )
396 int cur_label = *pair16u32s_ptr[i].i;
397 if( cur_label != prev_label )
398 c_map[++c_count] = prev_label = cur_label;
399 *pair16u32s_ptr[i].u = (unsigned short)c_count;
401 // replace labels for missing values with 65535
402 for( ; i < sample_count; i++ )
403 *pair16u32s_ptr[i].u = 65535;
407 // compact the class indices and build the map
408 prev_label = ~*int_ptr[0];
409 for( i = 0; i < num_valid; i++ )
411 int cur_label = *int_ptr[i];
412 if( cur_label != prev_label )
413 c_map[++c_count] = prev_label = cur_label;
414 *int_ptr[i] = c_count;
416 // replace labels for missing values with -1
417 for( ; i < sample_count; i++ )
421 else if( ci < 0 ) // process ordered variable
// Ordered variables stay in the original matrix; only count valid entries
// and range-check against the ord_nan sentinel.
423 for( i = 0; i < sample_count; i++ )
426 int si = sidx ? sidx[i] : i;
427 if( !mask || !mask[(size_t)si*m_step] )
430 val = (float)idata[(size_t)si*step];
432 val = fdata[(size_t)si*step];
434 if( fabs(val) >= ord_nan )
436 sprintf( err, "%d-th value of %d-th (ordered) "
437 "variable (=%g) is too large", i, vi, val );
438 CV_ERROR( CV_StsBadArg, err );
445 data_root->set_num_valid(vi, num_valid);
// Store the (possibly subsampled) sample indices as the last work-buffer column.
450 udst = (unsigned short*)(buf->data.s + get_work_var_count()*sample_count);
452 idst = buf->data.i + get_work_var_count()*sample_count;
454 for (i = 0; i < sample_count; i++)
457 udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;
459 idst[i] = sidx ? sidx[i] : i;
// Generate cyclic cross-validation fold labels, then shuffle them randomly.
464 unsigned short* usdst = 0;
469 usdst = (unsigned short*)(buf->data.s + (get_work_var_count()-1)*sample_count);
470 for( i = vi = 0; i < sample_count; i++ )
472 usdst[i] = (unsigned short)vi++;
// Branch-free wrap-around: zero vi once it reaches cv_n.
473 vi &= vi < cv_n ? -1 : 0;
476 for( i = 0; i < sample_count; i++ )
478 int a = (*rng)(sample_count);
479 int b = (*rng)(sample_count);
480 unsigned short unsh = (unsigned short)vi;
481 CV_SWAP( usdst[a], usdst[b], unsh );
486 idst2 = buf->data.i + (get_work_var_count()-1)*sample_count;
487 for( i = vi = 0; i < sample_count; i++ )
490 vi &= vi < cv_n ? -1 : 0;
493 for( i = 0; i < sample_count; i++ )
495 int a = (*rng)(sample_count);
496 int b = (*rng)(sample_count);
497 CV_SWAP( idst2[a], idst2[b], vi );
// Shrink cat_map's logical width to what was actually used.
503 cat_map->cols = MAX( total_c_count, 1 );
// Per-split storage, now that the maximum category count is known.
505 max_split_size = cvAlign(sizeof(CvDTreeSplit) +
506 (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
507 CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
// Class priors (classification only), normalized to sum to 1.
509 have_priors = is_classifier && params.priors;
512 int m = get_num_classes();
514 CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
515 for( i = 0; i < m; i++ )
517 double val = have_priors ? params.priors[i] : 1.;
519 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
520 priors->data.db[i] = val;
526 cvScale( priors, priors, 1./sum );
528 CV_CALL( priors_mult = cvCloneMat( priors ));
529 CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
// Per-sample split direction and split work buffer used during tree growing.
532 CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
533 CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
// Cleanup of temporaries (shared by the normal and error exit paths).
545 cvReleaseMat( &var_type0 );
546 cvReleaseMat( &sample_indices );
547 cvReleaseMat( &tmp_map );
// CvERTreeTrainData::get_ord_var_data
// Gathers, for node n and ordered variable vi, the raw (unsorted) float values and a
// per-sample missing flag into caller-supplied buffers, honoring the row-vs-column
// sample layout. Unlike the CvDTreeTrainData base behavior, values are returned in
// node sample order (not sorted) — extra-trees picks thresholds at random.
// NOTE(review): braces for the if/else and loop bodies are missing from this
// truncated listing (embedded numbering has gaps).
550 void CvERTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf,
551 const float** ord_values, const int** missing, int* sample_indices_buf )
// Map vi through the active-variable subset, if one was given.
553 int vidx = var_idx ? var_idx->data.i[vi] : vi;
554 int node_sample_count = n->sample_count;
555 // may use missing_buf as buffer for sample indices!
556 const int* sample_indices = get_sample_indices(n, sample_indices_buf ? sample_indices_buf : missing_buf);
// Element strides for the training-data and missing-mask matrices.
558 int td_step = train_data->step/CV_ELEM_SIZE(train_data->type);
559 int m_step = missing_mask ? missing_mask->step/CV_ELEM_SIZE(missing_mask->type) : 1;
560 if( tflag == CV_ROW_SAMPLE )
// Samples are rows: element (sample idx, variable vidx).
562 for( int i = 0; i < node_sample_count; i++ )
564 int idx = sample_indices[i];
565 missing_buf[i] = missing_mask ? *(missing_mask->data.ptr + idx * m_step + vi) : 0;
566 ord_values_buf[i] = *(train_data->data.fl + idx * td_step + vidx);
// Samples are columns: element (variable vidx, sample idx).
570 for( int i = 0; i < node_sample_count; i++ )
572 int idx = sample_indices[i];
573 missing_buf[i] = missing_mask ? *(missing_mask->data.ptr + vi* m_step + idx) : 0;
574 ord_values_buf[i] = *(train_data->data.fl + vidx* td_step + idx);
// Hand the filled buffers back through the out-parameters.
576 *ord_values = ord_values_buf;
577 *missing = missing_buf;
// CvERTreeTrainData::get_sample_indices
// The per-node sample indices live in the last column of the work buffer, after the
// categorical variables, the class labels (classifiers only) and the cv labels
// (when have_labels is set) — hence the computed column index below.
581 const int* CvERTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf )
583 return get_cat_var_data( n, var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0), indices_buf );
// CvERTreeTrainData::get_cv_labels
// Cross-validation fold labels occupy the work-buffer column right after the
// categorical variables (plus the class-label column for classifiers).
// NOTE(review): the truncated listing skips lines before this return — presumably a
// `have_labels` guard; confirm against the full file.
587 const int* CvERTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf )
590 return get_cat_var_data( n, var_count + (is_classifier ? 1 : 0), labels_buf );
// CvERTreeTrainData::get_cat_var_data
// Returns the categorical values of variable vi for the samples of node n.
// With a 32-bit work buffer this is a direct pointer into "buf"; with a 16-bit
// buffer the unsigned shorts are widened into cat_values_buf first.
// NOTE(review): the trailing `return cat_values;` (and surrounding braces) are
// missing from this truncated listing.
595 const int* CvERTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf )
597 int ci = get_var_type( vi);
598 const int* cat_values = 0;
// 32-bit buffer: point straight at the node's slice of column ci.
600 cat_values = buf->data.i + n->buf_idx*get_length_subbuf() + ci*sample_count + n->offset;
// 16-bit buffer: widen ushort -> int into the caller's buffer.
602 const unsigned short* short_values = (const unsigned short*)(buf->data.s + n->buf_idx*get_length_subbuf() +
603 ci*sample_count + n->offset);
604 for( int i = 0; i < n->sample_count; i++ )
605 cat_values_buf[i] = short_values[i];
606 cat_values = cat_values_buf;
// CvERTreeTrainData::get_vectors
// Decompresses the internal training representation back into dense per-sample
// arrays: feature values (var_count floats per sample), a missing-value byte mask,
// and responses (class index / original class label / ordered response).
// NOTE(review): truncated numbered listing — braces, else-branches and several
// statements are missing; comments describe visible statements only.
611 void CvERTreeTrainData::get_vectors( const CvMat* _subsample_idx,
612 float* values, uchar* missing,
613 float* _responses, bool get_class_idx )
615 CvMat* subsample_idx = 0;
616 CvMat* subsample_co = 0;
// Scratch space reused below for labels / ordered values / response buffers.
618 cv::AutoBuffer<uchar> inn_buf(sample_count*(sizeof(float) + sizeof(int)));
620 CV_FUNCNAME( "CvERTreeTrainData::get_vectors" );
624 int i, vi, total = sample_count, count = total, cur_ofs = 0;
// Optional subsample: build an index array plus a (count, offset) table per sample.
630 CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
631 sidx = subsample_idx->data.i;
632 CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
633 co = subsample_co->data.i;
634 cvZero( subsample_co );
635 count = subsample_idx->cols + subsample_idx->rows - 1;
636 for( i = 0; i < count; i++ )
638 for( i = 0; i < total; i++ )
640 int count_i = co[i*2];
643 co[i*2+1] = cur_ofs*var_count;
// Start pessimistic: mark everything missing, then clear flags as values are filled.
650 memset( missing, 1, count*var_count );
652 for( vi = 0; vi < var_count; vi++ )
654 int ci = get_var_type(vi);
655 if( ci >= 0 ) // categorical
657 float* dst = values + vi;
658 uchar* m = missing ? missing + vi : 0;
659 int* lbls_buf = (int*)(uchar*)inn_buf;
660 const int* src = get_cat_var_data(data_root, vi, lbls_buf);
662 for( i = 0; i < count; i++, dst += var_count )
664 int idx = sidx ? sidx[i] : i;
// Missing categorical markers differ by buffer width: -1 (32-bit) vs 65535 (16-bit).
669 *m = (!is_buf_16u && val < 0) || (is_buf_16u && (val == 65535));
// Ordered variable: fetch raw values + int missing flags, convert flags to bytes.
676 int* mis_buf = (int*)(uchar*)inn_buf;
677 const float *dst = 0;
679 get_ord_var_data(data_root, vi, values + vi, mis_buf, &dst, &mis, 0);
680 for (int si = 0; si < total; si++)
681 *(missing + vi + si) = mis[si] == 0 ? 0 : 1;
// Classification responses: compact class index, or original label via cat_map.
690 int* lbls_buf = (int*)(uchar*)inn_buf;
691 const int* src = get_class_labels(data_root, lbls_buf);
692 for( i = 0; i < count; i++ )
694 int idx = sidx ? sidx[i] : i;
695 int val = get_class_idx ? src[idx] :
696 cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
697 _responses[i] = (float)val;
// Regression responses: copy the ordered response values.
702 float* _values_buf = (float*)(uchar*)inn_buf;
703 int* sample_idx_buf = (int*)(_values_buf + sample_count);
704 const float* _values = get_ord_responses(data_root, _values_buf, sample_idx_buf);
705 for( i = 0; i < count; i++ )
707 int idx = sidx ? sidx[i] : i;
708 _responses[i] = _values[idx];
// Cleanup (shared with the error-exit path).
715 cvReleaseMat( &subsample_idx );
716 cvReleaseMat( &subsample_co );
// CvERTreeTrainData::subsample_data
// Extra-trees trains every tree on the full sample set (no bootstrap), so only a
// null _subsample_idx is accepted: the root node is shallow-copied and returned.
// A non-null index array is an error.
// NOTE(review): truncated listing — braces and several copy statements are missing.
719 CvDTreeNode* CvERTreeTrainData::subsample_data( const CvMat* _subsample_idx )
721 CvDTreeNode* root = 0;
723 CV_FUNCNAME( "CvERTreeTrainData::subsample_data" );
// Fails when set_data() has not been called yet.
728 CV_ERROR( CV_StsError, "No training data has been set" );
730 if( !_subsample_idx )
732 // make a copy of the root node
735 root = new_node( 0, 1, 0, 0 );
// Shallow-copy the bookkeeping arrays from the saved root, then refresh num_valid.
738 root->num_valid = temp.num_valid;
739 if( root->num_valid )
741 for( i = 0; i < var_count; i++ )
742 root->num_valid[i] = data_root->num_valid[i];
744 root->cv_Tn = temp.cv_Tn;
745 root->cv_node_risk = temp.cv_node_risk;
746 root->cv_node_error = temp.cv_node_error;
749 CV_ERROR( CV_StsError, "_subsample_idx must be null for extra-trees" );
// CvForestERTree::calc_node_dir
// Fills data->direction (-1 left / +1 right / 0 missing) for every sample of the
// node according to its chosen split, computes the left/right weight sums L and R,
// and returns split->quality normalized by (L + R).
// NOTE(review): truncated listing — braces, the dir[] assignments and the L/R
// accumulation lines for the ordered branch are missing from view.
755 double CvForestERTree::calc_node_dir( CvDTreeNode* node )
757 char* dir = (char*)data->direction->data.ptr;
758 int i, n = node->sample_count, vi = node->split->var_idx;
761 assert( !node->split->inversed );
763 if( data->get_var_type(vi) >= 0 ) // split on categorical var
765 cv::AutoBuffer<uchar> inn_buf(n*sizeof(int)*(!data->have_priors ? 1 : 2));
766 int* labels_buf = (int*)(uchar*)inn_buf;
767 const int* labels = data->get_cat_var_data( node, vi, labels_buf );
768 const int* subset = node->split->subset;
769 if( !data->have_priors )
// Unweighted: d in {-1,0,+1}; (d & 1) counts non-missing samples, so
// L and R can be recovered from sum and sum_abs below.
771 int sum = 0, sum_abs = 0;
773 for( i = 0; i < n; i++ )
// Missing markers differ by buffer width: idx < 0 (32-bit) or idx == 65535 (16-bit).
776 int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?
777 CV_DTREE_CAT_DIR(idx,subset) : 0;
778 sum += d; sum_abs += d & 1;
782 R = (sum_abs + sum) >> 1;
783 L = (sum_abs - sum) >> 1;
// Weighted: same scheme with per-class prior weights.
787 const double* priors = data->priors_mult->data.db;
788 double sum = 0, sum_abs = 0;
789 int *responses_buf = labels_buf + n;
790 const int* responses = data->get_class_labels(node, responses_buf);
792 for( i = 0; i < n; i++ )
795 double w = priors[responses[i]];
796 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
797 sum += d*w; sum_abs += (d & 1)*w;
801 R = (sum_abs + sum) * 0.5;
802 L = (sum_abs - sum) * 0.5;
805 else // split on ordered var
807 float split_val = node->split->ord.c;
808 cv::AutoBuffer<uchar> inn_buf(n*(sizeof(int)*(!data->have_priors ? 1 : 2) + sizeof(float)));
809 float* val_buf = (float*)(uchar*)inn_buf;
810 int* missing_buf = (int*)(val_buf + n);
811 const float* val = 0;
812 const int* missing = 0;
813 data->get_ord_var_data( node, vi, val_buf, missing_buf, &val, &missing, 0 );
815 if( !data->have_priors )
// Unweighted ordered split: samples with val < split_val go left.
818 for( i = 0; i < n; i++ )
824 if ( val[i] < split_val)
// Weighted ordered split: accumulate prior weights instead of counts.
839 const double* priors = data->priors_mult->data.db;
840 int* responses_buf = missing_buf + n;
841 const int* responses = data->get_class_labels(node, responses_buf);
843 for( i = 0; i < n; i++ )
849 double w = priors[responses[i]];
850 if ( val[i] < split_val)
// maxlr is used by the pruning/quality logic of the base class.
865 node->maxlr = MAX( L, R );
866 return node->split->quality/(L + R);
// CvForestERTree::find_split_ord_class
// Extra-trees split for an ordered variable in classification: pick ONE random
// threshold uniformly in [min, max] of the node's values, then score it with a
// Gini-style criterion (sum of squared class counts, prior-weighted when priors
// are set). Returns a new/updated split, or (not shown) 0 when no split is found.
// NOTE(review): truncated listing — braces, min/max tracking and the counter
// accumulation statements are missing from view.
869 CvDTreeSplit* CvForestERTree::find_split_ord_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split,
872 const float epsilon = FLT_EPSILON*2;
873 const float split_delta = (1 + FLT_EPSILON) * FLT_EPSILON;
875 int n = node->sample_count;
876 int m = data->get_num_classes();
878 cv::AutoBuffer<uchar> inn_buf;
880 inn_buf.allocate(n*(2*sizeof(int) + sizeof(float)));
881 uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
// Layout of the scratch buffer: [values | missing flags | responses].
882 float* values_buf = (float*)ext_buf;
883 int* missing_buf = (int*)(values_buf + n);
884 const float* values = 0;
885 const int* missing = 0;
886 data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing, 0 );
887 int* responses_buf = missing_buf + n;
888 const int* responses = data->get_class_labels( node, responses_buf );
890 double lbest_val = 0, rbest_val = 0, best_val = init_quality, split_val = 0;
891 const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
892 bool is_find_split = false;
// Skip leading missing values to find the first valid sample.
// NOTE(review): defect — missing[smpi] is read BEFORE the smpi < n bounds check;
// the conditions should be swapped ((smpi < n) && missing[smpi]) to avoid an
// out-of-bounds read when every value is missing.
895 while ( missing[smpi] && (smpi < n) )
// Scan remaining samples for the min/max of the non-missing values.
901 for (; smpi < n; smpi++)
903 float ptemp = values[smpi];
904 int ms = missing[smpi];
911 float fdiff = pmax-pmin;
// Only split when the value range is non-degenerate; then draw the threshold
// uniformly at random and clamp it strictly inside (pmin, pmax).
914 is_find_split = true;
915 cv::RNG* rng = data->rng;
916 split_val = pmin + rng->uniform(0.f, 1.f) * fdiff ;
917 if (split_val - pmin <= FLT_EPSILON)
918 split_val = pmin + split_delta;
919 if (pmax - split_val <= FLT_EPSILON)
920 split_val = pmax - split_delta;
922 // calculate Gini index
// Unweighted branch: integer class counters on each side.
925 cv::AutoBuffer<int> lrc(m*2);
926 int *lc = lrc, *rc = lc + m;
929 // init arrays of class instance counters on both sides of the split
930 for(int i = 0; i < m; i++ )
935 for( int si = 0; si < n; si++ )
937 int r = responses[si];
938 float val = values[si];
939 int ms = missing[si];
941 if ( val < split_val )
// Gini-style score: sum of squared per-class counts, weighted by the other side's total.
952 for (int i = 0; i < m; i++)
954 lbest_val += lc[i]*lc[i];
955 rbest_val += rc[i]*rc[i];
957 best_val = (lbest_val*R + rbest_val*L) / ((double)(L*R));
// Weighted branch: double counters accumulating per-class prior weights.
961 cv::AutoBuffer<double> lrc(m*2);
962 double *lc = lrc, *rc = lc + m;
965 // init arrays of class instance counters on both sides of the split
966 for(int i = 0; i < m; i++ )
971 for( int si = 0; si < n; si++ )
973 int r = responses[si];
974 float val = values[si];
975 int ms = missing[si];
976 double p = priors[r];
978 if ( val < split_val )
989 for (int i = 0; i < m; i++)
991 lbest_val += lc[i]*lc[i];
992 rbest_val += rc[i]*rc[i];
994 best_val = (lbest_val*R + rbest_val*L) / (L*R);
// Package the result (split_point = -1 marks a random/unsorted threshold split).
999 CvDTreeSplit* split = 0;
1002 split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
1003 split->var_idx = vi;
1004 split->ord.c = (float)split_val;
1005 split->ord.split_point = -1;
1006 split->inversed = 0;
1007 split->quality = (float)best_val;
// CvForestERTree::find_split_cat_class
// Extra-trees split for a categorical variable in classification: among the
// categories actually present at this node, pick a random non-empty proper subset
// (a shuffled 0/1 mask), send masked categories left, and score the partition with
// the same Gini-style criterion as the ordered case.
// NOTE(review): truncated listing — braces and several accumulation statements
// are missing from view.
1012 CvDTreeSplit* CvForestERTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split,
1015 int ci = data->get_var_type(vi);
1016 int n = node->sample_count;
1017 int cm = data->get_num_classes();
1018 int vm = data->cat_count->data.i[ci];
1019 double best_val = init_quality;
1020 CvDTreeSplit *split = 0;
// Scratch: [category labels | class responses], each n ints.
1024 cv::AutoBuffer<int> inn_buf;
1026 inn_buf.allocate(2*n);
1027 int* ext_buf = _ext_buf ? (int*)_ext_buf : (int*)inn_buf;
1029 const int* labels = data->get_cat_var_data( node, vi, ext_buf );
1030 const int* responses = data->get_class_labels( node, ext_buf + n );
1032 const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
1034 // create random class mask
// Mark which category values actually occur among the node's (non-missing) samples.
1035 cv::AutoBuffer<int> valid_cidx(vm);
1036 for (int i = 0; i < vm; i++)
1040 for (int si = 0; si < n; si++)
// Missing marker: 65535 for 16-bit buffers, negative for 32-bit buffers.
1043 if ( ((c == 65535) && data->is_buf_16u) || ((c<0) && (!data->is_buf_16u)) )
// Compact the present categories to indices 0..valid_ccount-1.
1048 int valid_ccount = 0;
1049 for (int i = 0; i < vm; i++)
1050 if (valid_cidx[i] >= 0)
1052 valid_cidx[i] = valid_ccount;
1055 if (valid_ccount > 1)
// Choose a random subset size in [1, valid_ccount-1], set that many mask bits,
// then Fisher-Yates-style shuffle the mask.
1057 CvRNG* rng = forest->get_rng();
1058 int l_cval_count = 1 + cvRandInt(rng) % (valid_ccount-1);
1060 CvMat* var_class_mask = cvCreateMat( 1, valid_ccount, CV_8UC1 );
1062 memset(var_class_mask->data.ptr, 0, valid_ccount*CV_ELEM_SIZE(var_class_mask->type));
1063 cvGetCols( var_class_mask, &submask, 0, l_cval_count );
1064 cvSet( &submask, cvScalar(1) );
1065 for (int i = 0; i < valid_ccount; i++)
1068 int i1 = cvRandInt( rng ) % valid_ccount;
1069 int i2 = cvRandInt( rng ) % valid_ccount;
1070 CV_SWAP( var_class_mask->data.ptr[i1], var_class_mask->data.ptr[i2], temp );
1073 split = _split ? _split : data->new_split_cat( 0, -1.0f );
1074 split->var_idx = vi;
1075 memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
1077 // calculate Gini index
1078 double lbest_val = 0, rbest_val = 0;
// Unweighted branch: integer class counters per side.
1081 cv::AutoBuffer<int> lrc(cm*2);
1082 int *lc = lrc, *rc = lc + cm;
1084 // init arrays of class instance counters on both sides of the split
1085 for(int i = 0; i < cm; i++ )
1090 for( int si = 0; si < n; si++ )
1092 int r = responses[si];
1093 int var_class_idx = labels[si];
1094 if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )
1096 int mask_class_idx = valid_cidx[var_class_idx];
1097 if (var_class_mask->data.ptr[mask_class_idx])
// Masked categories go left and are recorded in the split's bit subset.
1101 split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
1109 for (int i = 0; i < cm; i++)
1111 lbest_val += lc[i]*lc[i];
1112 rbest_val += rc[i]*rc[i];
1114 best_val = (lbest_val*R + rbest_val*L) / ((double)(L*R));
// Weighted branch.
// NOTE(review): counters are declared int here although they accumulate
// prior WEIGHTS — the analogous branch in find_split_ord_class uses double;
// confirm against upstream.
1118 cv::AutoBuffer<int> lrc(cm*2);
1119 int *lc = lrc, *rc = lc + cm;
1120 double L = 0, R = 0;
1121 // init arrays of class instance counters on both sides of the split
1122 for(int i = 0; i < cm; i++ )
1127 for( int si = 0; si < n; si++ )
1129 int r = responses[si];
1130 int var_class_idx = labels[si];
1131 if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )
// NOTE(review): priors is a per-CLASS array; indexing it with the sample
// index si looks wrong — compare `priors[r]` in find_split_ord_class.
1133 double p = priors[si];
1134 int mask_class_idx = valid_cidx[var_class_idx];
1136 if (var_class_mask->data.ptr[mask_class_idx])
1140 split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
1148 for (int i = 0; i < cm; i++)
1150 lbest_val += lc[i]*lc[i];
1151 rbest_val += rc[i]*rc[i];
1153 best_val = (lbest_val*R + rbest_val*L) / (L*R);
1155 split->quality = (float)best_val;
1157 cvReleaseMat(&var_class_mask);
// CvForestERTree::find_split_ord_reg
// Extra-trees split for an ordered variable in regression: draw one random
// threshold in [min, max] of the node's values and score it by the variance-
// reduction style criterion (lsum^2*R + rsum^2*L)/(L*R), where lsum/rsum are
// response sums and L/R the sample counts on each side.
// NOTE(review): truncated listing — braces, min/max tracking and the lsum/rsum
// accumulation statements are missing from view.
1164 CvDTreeSplit* CvForestERTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split,
1167 const float epsilon = FLT_EPSILON*2;
1168 const float split_delta = (1 + FLT_EPSILON) * FLT_EPSILON;
1169 int n = node->sample_count;
1170 cv::AutoBuffer<uchar> inn_buf;
1172 inn_buf.allocate(n*(2*sizeof(int) + 2*sizeof(float)));
1173 uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
// Layout: [values | missing flags | float responses | sample indices].
1174 float* values_buf = (float*)ext_buf;
1175 int* missing_buf = (int*)(values_buf + n);
1176 const float* values = 0;
1177 const int* missing = 0;
1178 data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing, 0 );
1179 float* responses_buf = (float*)(missing_buf + n);
1180 int* sample_indices_buf = (int*)(responses_buf + n);
1181 const float* responses = data->get_ord_responses( node, responses_buf, sample_indices_buf );
1183 double best_val = init_quality, split_val = 0, lsum = 0, rsum = 0;
1186 bool is_find_split = false;
// Skip leading missing values.
// NOTE(review): same defect as find_split_ord_class — missing[smpi] is read
// before the smpi < n bounds check; conditions should be swapped.
1189 while ( missing[smpi] && (smpi < n) )
1194 pmin = values[smpi];
// Find the min/max of the non-missing values.
1196 for (; smpi < n; smpi++)
1198 float ptemp = values[smpi];
1199 int m = missing[smpi];
1206 float fdiff = pmax-pmin;
1207 if (fdiff > epsilon)
// Draw a random threshold and clamp it strictly inside (pmin, pmax).
1209 is_find_split = true;
1210 cv::RNG* rng = data->rng;
1211 split_val = pmin + rng->uniform(0.f, 1.f) * fdiff ;
1212 if (split_val - pmin <= FLT_EPSILON)
1213 split_val = pmin + split_delta;
1214 if (pmax - split_val <= FLT_EPSILON)
1215 split_val = pmax - split_delta;
// Accumulate response sums and counts on each side of the threshold.
1217 for (int si = 0; si < n; si++)
1219 float r = responses[si];
1220 float val = values[si];
1221 int m = missing[si];
1223 if (val < split_val)
1234 best_val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
// Package the result (split_point = -1 marks a random threshold split).
1237 CvDTreeSplit* split = 0;
1240 split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
1241 split->var_idx = vi;
1242 split->ord.c = (float)split_val;
1243 split->ord.split_point = -1;
1244 split->inversed = 0;
1245 split->quality = (float)best_val;
// CvForestERTree::find_split_cat_reg
// Extra-trees split for a categorical variable in regression: pick a random
// non-empty proper subset of the categories present at the node (shuffled 0/1
// mask), send masked categories left, and score by the regression criterion
// (lsum^2*R + rsum^2*L)/(L*R).
// NOTE(review): truncated listing — braces and the lsum/rsum/L/R accumulation
// statements are missing from view.
1250 CvDTreeSplit* CvForestERTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split,
1253 int ci = data->get_var_type(vi);
1254 int n = node->sample_count;
1255 int vm = data->cat_count->data.i[ci];
1256 double best_val = init_quality;
1257 CvDTreeSplit *split = 0;
1258 float lsum = 0, rsum = 0;
// Scratch layout: [valid_cidx (vm ints) | labels | float responses | sample indices].
1262 int base_size = vm*sizeof(int);
1263 cv::AutoBuffer<uchar> inn_buf(base_size);
1265 inn_buf.allocate(base_size + n*(2*sizeof(int) + sizeof(float)));
1266 uchar* base_buf = (uchar*)inn_buf;
1267 uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
1268 int* labels_buf = (int*)ext_buf;
1269 const int* labels = data->get_cat_var_data( node, vi, labels_buf );
1270 float* responses_buf = (float*)(labels_buf + n);
1271 int* sample_indices_buf = (int*)(responses_buf + n);
1272 const float* responses = data->get_ord_responses( node, responses_buf, sample_indices_buf );
1274 // create random class mask
// Mark which category values occur among the node's non-missing samples.
1275 int *valid_cidx = (int*)base_buf;
1276 for (int i = 0; i < vm; i++)
1280 for (int si = 0; si < n; si++)
// Missing marker: 65535 for 16-bit buffers, negative for 32-bit buffers.
1283 if ( ((c == 65535) && data->is_buf_16u) || ((c<0) && (!data->is_buf_16u)) )
// Compact the present categories to indices 0..valid_ccount-1.
1288 int valid_ccount = 0;
1289 for (int i = 0; i < vm; i++)
1290 if (valid_cidx[i] >= 0)
1292 valid_cidx[i] = valid_ccount;
1295 if (valid_ccount > 1)
// Random subset size in [1, valid_ccount-1]; set that many mask bits, then shuffle.
1297 CvRNG* rng = forest->get_rng();
1298 int l_cval_count = 1 + cvRandInt(rng) % (valid_ccount-1);
1300 CvMat* var_class_mask = cvCreateMat( 1, valid_ccount, CV_8UC1 );
1302 memset(var_class_mask->data.ptr, 0, valid_ccount*CV_ELEM_SIZE(var_class_mask->type));
1303 cvGetCols( var_class_mask, &submask, 0, l_cval_count );
1304 cvSet( &submask, cvScalar(1) );
1305 for (int i = 0; i < valid_ccount; i++)
1308 int i1 = cvRandInt( rng ) % valid_ccount;
1309 int i2 = cvRandInt( rng ) % valid_ccount;
1310 CV_SWAP( var_class_mask->data.ptr[i1], var_class_mask->data.ptr[i2], temp );
1313 split = _split ? _split : data->new_split_cat( 0, -1.0f);
1314 split->var_idx = vi;
1315 memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
// Accumulate response sums / counts per side; record left categories in the bitset.
1318 for( int si = 0; si < n; si++ )
1320 float r = responses[si];
1321 int var_class_idx = labels[si];
1322 if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )
1324 int mask_class_idx = valid_cidx[var_class_idx];
1325 if (var_class_mask->data.ptr[mask_class_idx])
1329 split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
1337 best_val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
1339 split->quality = (float)best_val;
1341 cvReleaseMat(&var_class_mask);
// Distributes the samples of 'node' between two newly created child nodes,
// guided by the per-sample direction array data->direction (dir[i] != 0
// sends sample i to the right child -- see the nr1 accumulation below).
// NOTE(review): this excerpt appears to have some original lines elided
// (opening/closing braces, a few conditionals and else-branches); the
// comments below describe only the code that is visible here.
1348 void CvForestERTree::split_node_data( CvDTreeNode* node )
1350 int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;
1351 char* dir = (char*)data->direction->data.ptr;
1352 CvDTreeNode *left = 0, *right = 0;
1353 int new_buf_idx = data->get_child_buf_idx( node );
1354 CvMat* buf = data->buf;
1355 size_t length_buf_row = data->get_length_subbuf();
1356 cv::AutoBuffer<int> temp_buf(n);
// Presumably fills in directions for samples whose split value is missing;
// implementation not visible here -- TODO confirm.
1358 complete_node_dir(node);
// Count how many samples go left (nl) and right (nr); loop body elided
// in this excerpt.
1360 for( i = nl = nr = 0; i < n; i++ )
1367 bool split_input_data;
1368 node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
1369 node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );
// Children are split further only while depth allows it and at least one
// child still holds more than min_sample_count samples.
1371 split_input_data = node->depth + 1 < data->params.max_depth &&
1372 (node->left->sample_count > data->params.min_sample_count ||
1373 node->right->sample_count > data->params.min_sample_count);
// Scratch buffer reused by every per-variable pass below.
1375 cv::AutoBuffer<uchar> inn_buf(n*(sizeof(int)+sizeof(float)));
1376 // split ordered vars
1377 for( vi = 0; vi < data->var_count; vi++ )
// get_var_type(vi) >= 0 marks a categorical variable (handled in the next
// loop); ordered variables fall through here.
1379 int ci = data->get_var_type(vi);
1380 if (ci >= 0) continue;
1382 int n1 = node->get_num_valid(vi), nr1 = 0;
1383 float* values_buf = (float*)(uchar*)inn_buf;
1384 int* missing_buf = (int*)(values_buf + n);
1385 const float* values = 0;
1386 const int* missing = 0;
1387 data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing, 0 );
// For ordered variables only the valid-sample counters need updating:
// nr1 = number of non-missing samples routed to the right child.
1389 for( i = 0; i < n; i++ )
1390 nr1 += ((!missing[i]) & dir[i]);
1391 left->set_num_valid(vi, n1 - nr1);
1392 right->set_num_valid(vi, nr1);
1394 // split categorical vars, responses and cv_labels using new_idx relocation table
1395 for( vi = 0; vi < data->get_work_var_count() + data->ord_var_count; vi++ )
1397 int ci = data->get_var_type(vi);
1398 if (ci < 0) continue;
1400 int n1 = node->get_num_valid(vi), nr1 = 0;
1401 const int* src_lbls = data->get_cat_var_data(node, vi, (int*)(uchar*)inn_buf);
// Copy the labels first: ldst/rdst below may alias the source buffer rows.
1403 for(i = 0; i < n; i++)
1404 temp_buf[i] = src_lbls[i];
// 16-bit buffer layout: labels are stored as unsigned short, with 65535
// apparently acting as the "missing" marker (see the idx != 65535 count).
1406 if (data->is_buf_16u)
1408 unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*length_buf_row +
1409 ci*scount + left->offset);
1410 unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*length_buf_row +
1411 ci*scount + right->offset);
// Scatter each label into the left or right child's row; the branch that
// chooses between ldst and rdst is elided in this excerpt.
1413 for( i = 0; i < n; i++ )
1416 int idx = temp_buf[i];
1419 *rdst = (unsigned short)idx;
1421 nr1 += (idx != 65535);
1425 *ldst = (unsigned short)idx;
// Only true input variables carry valid-sample counters (responses and
// cv_labels live past var_count in the work-variable range).
1430 if( vi < data->var_count )
1432 left->set_num_valid(vi, n1 - nr1);
1433 right->set_num_valid(vi, nr1);
// 32-bit buffer variant of the same scatter (the else and loop bodies are
// elided in this excerpt).
1438 int *ldst = buf->data.i + left->buf_idx*length_buf_row +
1439 ci*scount + left->offset;
1440 int *rdst = buf->data.i + right->buf_idx*length_buf_row +
1441 ci*scount + right->offset;
1443 for( i = 0; i < n; i++ )
1446 int idx = temp_buf[i];
1461 if( vi < data->var_count )
1463 left->set_num_valid(vi, n1 - nr1);
1464 right->set_num_valid(vi, nr1);
1469 // split sample indices
1470 int *sample_idx_src_buf = (int*)(uchar*)inn_buf;
1471 const int* sample_idx_src = 0;
// Sample indices only need relocating if the children will be split again.
1472 if (split_input_data)
1474 sample_idx_src = data->get_sample_indices(node, sample_idx_src_buf);
1476 for(i = 0; i < n; i++)
1477 temp_buf[i] = sample_idx_src[i];
// Sample indices are stored in the row right after the work variables.
1479 int pos = data->get_work_var_count();
1481 if (data->is_buf_16u)
1483 unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*length_buf_row +
1484 pos*scount + left->offset);
1485 unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*length_buf_row +
1486 pos*scount + right->offset);
1488 for (i = 0; i < n; i++)
1491 unsigned short idx = (unsigned short)temp_buf[i];
1506 int* ldst = buf->data.i + left->buf_idx*length_buf_row +
1507 pos*scount + left->offset;
1508 int* rdst = buf->data.i + right->buf_idx*length_buf_row +
1509 pos*scount + right->offset;
1510 for (i = 0; i < n; i++)
1513 int idx = temp_buf[i];
1528 // deallocate the parent node data that is not needed anymore
1529 data->free_node_data(node);
// Default constructor; body not visible in this excerpt (presumably empty,
// with configuration done later in train() -- TODO confirm).
1532 CvERTrees::CvERTrees()
// Destructor; body not visible in this excerpt (cleanup is presumably
// handled by the CvRTrees base class -- TODO confirm).
1536 CvERTrees::~CvERTrees()
1540 std::string CvERTrees::getName() const
1542 return CV_TYPE_NAME_ML_ERTREES;
// Trains the extremely-randomized-trees forest on raw CvMat inputs:
// prepares per-tree parameters, builds the shared CvERTreeTrainData,
// fixes up params.nactive_vars, initializes the active-variable mask and
// then grows the forest.
// NOTE(review): the CV_FUNCNAME/__BEGIN__/__END__ scaffolding, braces and
// the final return are elided in this excerpt; comments cover only the
// visible code.
1545 bool CvERTrees::train( const CvMat* _train_data, int _tflag,
1546 const CvMat* _responses, const CvMat* _var_idx,
1547 const CvMat* _sample_idx, const CvMat* _var_type,
1548 const CvMat* _missing_mask, CvRTParams params )
1550 bool result = false;
1552 CV_FUNCNAME("CvERTrees::train");
// Per-tree parameters derived from the forest parameters; the hard-coded
// 'false' disables the ninth CvDTreeParams option (truncate_pruned_tree in
// this API version -- TODO confirm).
1558 CvDTreeParams tree_params( params.max_depth, params.min_sample_count,
1559 params.regression_accuracy, params.use_surrogates, params.max_categories,
1560 params.cv_folds, params.use_1se_rule, false, params.priors );
1562 data = new CvERTreeTrainData();
1563 CV_CALL(data->set_data( _train_data, _tflag, _responses, _var_idx,
1564 _sample_idx, _var_type, _missing_mask, tree_params, true));
1566 var_count = data->var_count;
// Normalize nactive_vars: clamp to var_count, default 0 -> sqrt(var_count)
// (the usual random-forest heuristic), negative is a usage error.
1567 if( params.nactive_vars > var_count )
1568 params.nactive_vars = var_count;
1569 else if( params.nactive_vars == 0 )
1570 params.nactive_vars = (int)sqrt((double)var_count);
1571 else if( params.nactive_vars < 0 )
1572 CV_ERROR( CV_StsBadArg, "<nactive_vars> must be non-negative" );
1574 // Create mask of active variables at the tree nodes
1575 CV_CALL(active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 ));
1576 if( params.calc_var_importance )
1578 CV_CALL(var_importance = cvCreateMat( 1, var_count, CV_32FC1 ));
1579 cvZero(var_importance);
1581 { // initialize active variables mask
// First nactive_vars entries are set to 1, the remainder zeroed; the mask
// is presumably re-shuffled per node elsewhere -- TODO confirm.
1582 CvMat submask1, submask2;
1583 CV_Assert( (active_var_mask->cols >= 1) && (params.nactive_vars > 0) && (params.nactive_vars <= active_var_mask->cols) );
1584 cvGetCols( active_var_mask, &submask1, 0, params.nactive_vars );
1585 cvSet( &submask1, cvScalar(1) );
1586 if( params.nactive_vars < active_var_mask->cols )
1588 cvGetCols( active_var_mask, &submask2, params.nactive_vars, var_count );
1589 cvZero( &submask2 );
// Grow the forest until the termination criterion is met.
1593 CV_CALL(result = grow_forest( params.term_crit ));
// CvMLData convenience overload: delegates entirely to CvRTrees::train,
// which presumably extracts the matrices from _data and re-enters the
// virtual CvMat-based train() -- TODO confirm.
// NOTE(review): __BEGIN__/__END__ scaffolding, braces and the final
// 'return result;' are elided in this excerpt.
1602 bool CvERTrees::train( CvMLData* _data, CvRTParams params)
1604 bool result = false;
1606 CV_FUNCNAME( "CvERTrees::train" );
1610 CV_CALL( result = CvRTrees::train( _data, params) );
// Grows up to term_crit.max_iter extremely randomized trees, optionally
// tracking out-of-bag (OOB) error and permutation-based variable
// importance, and stops early when the OOB error drops below
// term_crit.epsilon (for non-ITER-only criteria).
// NOTE(review): several original lines are elided in this excerpt
// (declarations of e.g. oob_error/ntrees/nsamples/nclasses/rng, the
// __BEGIN__/__END__ macros, break statements and many braces); the
// comments below describe only the visible code.
1617 bool CvERTrees::grow_forest( const CvTermCriteria term_crit )
1619 bool result = false;
1621 CvMat* sample_idx_for_tree = 0;
1623 CV_FUNCNAME("CvERTrees::grow_forest");
1626 const int max_ntrees = term_crit.max_iter;
1627 const double max_oob_err = term_crit.epsilon;
1629 const int dims = data->var_count;
1630 float maximal_response = 0;
1632 CvMat* oob_sample_votes = 0;
1633 CvMat* oob_responses = 0;
1635 float* oob_samples_perm_ptr= 0;
1637 float* samples_ptr = 0;
1638 uchar* missing_ptr = 0;
1639 float* true_resp_ptr = 0;
// OOB bookkeeping is needed when the stop criterion involves epsilon
// (type != CV_TERMCRIT_ITER with a positive threshold) or when variable
// importance was requested.
1640 bool is_oob_or_vimportance = ((max_oob_err > 0) && (term_crit.type != CV_TERMCRIT_ITER)) || var_importance;
1642 // oob_predictions_sum[i] = sum of predicted values for the i-th sample
1643 // oob_num_of_predictions[i] = number of summands
1644 // (number of predictions for the i-th sample)
1645 // initialize these variable to avoid warning C4701
1646 CvMat oob_predictions_sum = cvMat( 1, 1, CV_32FC1 );
1647 CvMat oob_num_of_predictions = cvMat( 1, 1, CV_32FC1 );
1649 nsamples = data->sample_count;
1650 nclasses = data->get_num_classes();
1652 if ( is_oob_or_vimportance )
// Classification: one vote counter per (sample, class).
1654 if( data->is_classifier )
1656 CV_CALL(oob_sample_votes = cvCreateMat( nsamples, nclasses, CV_32SC1 ));
1657 cvZero(oob_sample_votes);
1661 // oob_responses[0,i] = oob_predictions_sum[i]
1662 // = sum of predicted values for the i-th sample
1663 // oob_responses[1,i] = oob_num_of_predictions[i]
1664 // = number of summands (number of predictions for the i-th sample)
1665 CV_CALL(oob_responses = cvCreateMat( 2, nsamples, CV_32FC1 ));
1666 cvZero(oob_responses);
// The two 1x1 headers initialized above are re-pointed at the real rows.
1667 cvGetRow( oob_responses, &oob_predictions_sum, 0 );
1668 cvGetRow( oob_responses, &oob_num_of_predictions, 1 );
// Dense copies of the training set used for prediction and for the
// permuted copies needed by variable importance.
1671 CV_CALL(oob_samples_perm_ptr = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
1672 CV_CALL(samples_ptr = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
1673 CV_CALL(missing_ptr = (uchar*)cvAlloc( sizeof(uchar)*nsamples*dims ));
1674 CV_CALL(true_resp_ptr = (float*)cvAlloc( sizeof(float)*nsamples ));
1676 CV_CALL(data->get_vectors( 0, samples_ptr, missing_ptr, true_resp_ptr ));
// Largest absolute response, used below to normalize regression residuals
// before the exp(-r*r) "correctness" weighting.
1678 double minval, maxval;
1679 CvMat responses = cvMat(1, nsamples, CV_32FC1, true_resp_ptr);
1680 cvMinMaxLoc( &responses, &minval, &maxval );
1681 maximal_response = (float)MAX( MAX( fabs(minval), fabs(maxval) ), 0 );
1685 trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );
1686 memset( trees, 0, sizeof(trees[0])*max_ntrees );
// Identity sample-index mapping; not visibly used further in this excerpt.
1688 CV_CALL(sample_idx_for_tree = cvCreateMat( 1, nsamples, CV_32SC1 ));
1690 for (int i = 0; i < nsamples; i++)
1691 sample_idx_for_tree->data.i[i] = i;
1693 while( ntrees < max_ntrees )
1695 int i, oob_samples_count = 0;
1696 double ncorrect_responses = 0; // used for estimation of variable importance
1697 CvForestTree* tree = 0;
// Each tree is trained via tree->train(data, 0, this); the 0 second
// argument appears to mean "use the whole training set" (no bootstrap),
// as expected for extremely randomized trees -- TODO confirm.
1699 trees[ntrees] = new CvForestERTree();
1700 tree = (CvForestERTree*)trees[ntrees];
1701 CV_CALL(tree->train( data, 0, this ));
1703 if ( is_oob_or_vimportance )
1705 CvMat sample, missing;
1706 // form array of OOB samples indices and get these samples
// Single-row headers advanced through the dense sample/missing arrays.
1707 sample = cvMat( 1, dims, CV_32FC1, samples_ptr );
1708 missing = cvMat( 1, dims, CV_8UC1, missing_ptr );
// Note: this loop visits every sample (any OOB filtering that may have
// existed is not visible in this excerpt).
1711 for( i = 0; i < nsamples; i++,
1712 sample.data.fl += dims, missing.data.ptr += dims )
1714 CvDTreeNode* predicted_node = 0;
1716 // predict oob samples
1717 if( !predicted_node )
1718 CV_CALL(predicted_node = tree->predict(&sample, &missing, true));
1720 if( !data->is_classifier ) //regression
// Running mean of per-tree predictions vs. the true response gives the
// cumulative OOB squared error.
1722 double avg_resp, resp = predicted_node->value;
1723 oob_predictions_sum.data.fl[i] += (float)resp;
1724 oob_num_of_predictions.data.fl[i] += 1;
1726 // compute oob error
1727 avg_resp = oob_predictions_sum.data.fl[i]/oob_num_of_predictions.data.fl[i];
1728 avg_resp -= true_resp_ptr[i];
1729 oob_error += avg_resp*avg_resp;
// Residual normalized by the largest |response|; exp(-r*r) scores a
// perfect prediction as 1 and decays with error.
1730 resp = (resp - true_resp_ptr[i])/maximal_response;
1731 ncorrect_responses += exp( -resp*resp );
1733 else //classification
// Majority vote across the trees seen so far; the vote matrix row for
// sample i is updated, then the winning class is mapped back through
// cat_map to compare against the true response.
1739 cvGetRow(oob_sample_votes, &votes, i);
1740 votes.data.i[predicted_node->class_idx]++;
1742 // compute oob error
1743 cvMinMaxLoc( &votes, 0, 0, 0, &max_loc );
1745 prdct_resp = data->cat_map->data.i[max_loc.x];
1746 oob_error += (fabs(prdct_resp - true_resp_ptr[i]) < FLT_EPSILON) ? 0 : 1;
1748 ncorrect_responses += cvRound(predicted_node->value - true_resp_ptr[i]) == 0;
1750 oob_samples_count++;
1752 if( oob_samples_count > 0 )
1753 oob_error /= (double)oob_samples_count;
1755 // estimate variable importance
// Permutation importance: for each variable m, shuffle its column in a
// copy of the data, re-predict, and credit the drop in correctness.
1756 if( var_importance && oob_samples_count > 0 )
1760 memcpy( oob_samples_perm_ptr, samples_ptr, dims*nsamples*sizeof(float));
1761 for( m = 0; m < dims; m++ )
1763 double ncorrect_responses_permuted = 0;
1764 // randomly permute values of the m-th variable in the oob samples
1765 float* mth_var_ptr = oob_samples_perm_ptr + m;
1767 for( i = 0; i < nsamples; i++ )
1772 i1 = (*rng)(nsamples);
1773 i2 = (*rng)(nsamples);
1774 CV_SWAP( mth_var_ptr[i1*dims], mth_var_ptr[i2*dims], temp );
1776 // turn values of (m-1)-th variable, that were permuted
1777 // at the previous iteration, untouched
1779 oob_samples_perm_ptr[i*dims+m-1] = samples_ptr[i*dims+m-1];
1782 // predict "permuted" cases and calculate the number of votes for the
1783 // correct class in the variable-m-permuted oob data
1784 sample = cvMat( 1, dims, CV_32FC1, oob_samples_perm_ptr );
1785 missing = cvMat( 1, dims, CV_8UC1, missing_ptr );
1786 for( i = 0; i < nsamples; i++,
1787 sample.data.fl += dims, missing.data.ptr += dims )
1789 double predct_resp, true_resp;
1791 predct_resp = tree->predict(&sample, &missing, true)->value;
1792 true_resp = true_resp_ptr[i];
1793 if( data->is_classifier )
1794 ncorrect_responses_permuted += cvRound(true_resp - predct_resp) == 0;
// Regression uses the same exp(-r*r) correctness weighting as above.
1797 true_resp = (true_resp - predct_resp)/maximal_response;
1798 ncorrect_responses_permuted += exp( -true_resp*true_resp );
// Importance of variable m accumulates the loss of accuracy caused by
// permuting it.
1801 var_importance->data.fl[m] += (float)(ncorrect_responses
1802 - ncorrect_responses_permuted);
// Early stop once OOB error is below the requested threshold (the break
// body is elided in this excerpt).
1807 if( term_crit.type != CV_TERMCRIT_ITER && oob_error < max_oob_err )
// Post-process importance: clamp negatives to zero, then L1-normalize so
// the entries sum to 1.
1810 if( var_importance )
1812 for ( int vi = 0; vi < var_importance->cols; vi++ )
1813 var_importance->data.fl[vi] = ( var_importance->data.fl[vi] > 0 ) ?
1814 var_importance->data.fl[vi] : 0;
1815 cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
// Release all scratch allocations (safe on null pointers in this API).
1820 cvFree( &oob_samples_perm_ptr );
1821 cvFree( &samples_ptr );
1822 cvFree( &missing_ptr );
1823 cvFree( &true_resp_ptr );
1825 cvReleaseMat( &sample_idx_for_tree );
1827 cvReleaseMat( &oob_sample_votes );
1828 cvReleaseMat( &oob_responses );
// cv::Mat convenience overload: wraps each input in a CvMat header (no
// data is copied) and forwards to the CvMat*-based train(). Optional
// matrices whose data pointer is null (i.e. empty Mats) are passed as 0;
// _responses is always passed by address.
// NOTE(review): the function's braces fall outside this excerpt.
1837 bool CvERTrees::train( const Mat& _train_data, int _tflag,
1838 const Mat& _responses, const Mat& _var_idx,
1839 const Mat& _sample_idx, const Mat& _var_type,
1840 const Mat& _missing_mask, CvRTParams params )
1842 CvMat tdata = _train_data, responses = _responses, vidx = _var_idx,
1843 sidx = _sample_idx, vtype = _var_type, mmask = _missing_mask;
1844 return train(&tdata, _tflag, &responses, vidx.data.ptr ? &vidx : 0,
1845 sidx.data.ptr ? &sidx : 0, vtype.data.ptr ? &vtype : 0,
1846 mmask.data.ptr ? &mmask : 0, params);