fixed many warnings from GCC 4.6.1
[profile/ivi/opencv.git] / modules / ml / src / rtrees.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40
41 #include "precomp.hpp"
42
43 CvForestTree::CvForestTree()
44 {
45     forest = NULL;
46 }
47
48
49 CvForestTree::~CvForestTree()
50 {
51     clear();
52 }
53
54
55 bool CvForestTree::train( CvDTreeTrainData* _data,
56                           const CvMat* _subsample_idx,
57                           CvRTrees* _forest )
58 {
59     clear();
60     forest = _forest;
61
62     data = _data;
63     data->shared = true;
64     return do_train(_subsample_idx);
65 }
66
67
68 bool
69 CvForestTree::train( const CvMat*, int, const CvMat*, const CvMat*,
70                     const CvMat*, const CvMat*, const CvMat*, CvDTreeParams )
71 {
72     assert(0);
73     return false;
74 }
75
76
77 bool
78 CvForestTree::train( CvDTreeTrainData*, const CvMat* )
79 {
80     assert(0);
81     return false;
82 }
83
84
85
86 namespace cv
87 {
88
89 ForestTreeBestSplitFinder::ForestTreeBestSplitFinder( CvForestTree* _tree, CvDTreeNode* _node ) :
90     DTreeBestSplitFinder(_tree, _node) {}
91
92 ForestTreeBestSplitFinder::ForestTreeBestSplitFinder( const ForestTreeBestSplitFinder& finder, Split spl ) :
93     DTreeBestSplitFinder( finder, spl ) {}
94
95 void ForestTreeBestSplitFinder::operator()(const BlockedRange& range)
96 {
97     int vi, vi1 = range.begin(), vi2 = range.end();
98     int n = node->sample_count;
99     CvDTreeTrainData* data = tree->get_data();
100     AutoBuffer<uchar> inn_buf(2*n*(sizeof(int) + sizeof(float)));
101
102     CvForestTree* ftree = (CvForestTree*)tree;
103     const CvMat* active_var_mask = ftree->forest->get_active_var_mask();
104
105     for( vi = vi1; vi < vi2; vi++ )
106     {
107         CvDTreeSplit *res;
108         int ci = data->var_type->data.i[vi];
109         if( node->num_valid[vi] <= 1
110             || (active_var_mask && !active_var_mask->data.ptr[vi]) )
111             continue;
112
113         if( data->is_classifier )
114         {
115             if( ci >= 0 )
116                 res = ftree->find_split_cat_class( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
117             else
118                 res = ftree->find_split_ord_class( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
119         }
120         else
121         {
122             if( ci >= 0 )
123                 res = ftree->find_split_cat_reg( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
124             else
125                 res = ftree->find_split_ord_reg( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
126         }
127
128         if( res && bestSplit->quality < split->quality )
129                 memcpy( (CvDTreeSplit*)bestSplit, (CvDTreeSplit*)split, splitSize );
130     }
131 }
132 }
133
134 CvDTreeSplit* CvForestTree::find_best_split( CvDTreeNode* node )
135 {
136     CvMat* active_var_mask = 0;
137     if( forest )
138     {
139         int var_count;
140         CvRNG* rng = forest->get_rng();
141
142         active_var_mask = forest->get_active_var_mask();
143         var_count = active_var_mask->cols;
144
145         CV_Assert( var_count == data->var_count );
146
147         for( int vi = 0; vi < var_count; vi++ )
148         {
149             uchar temp;
150             int i1 = cvRandInt(rng) % var_count;
151             int i2 = cvRandInt(rng) % var_count;
152             CV_SWAP( active_var_mask->data.ptr[i1],
153                 active_var_mask->data.ptr[i2], temp );
154         }
155     }
156
157     cv::ForestTreeBestSplitFinder finder( this, node );
158
159     cv::parallel_reduce(cv::BlockedRange(0, data->var_count), finder);
160
161     CvDTreeSplit *bestSplit = 0;
162     if( finder.bestSplit->quality > 0 )
163     {
164         bestSplit = data->new_split_cat( 0, -1.0f );
165         memcpy( bestSplit, finder.bestSplit, finder.splitSize );
166     }
167
168     return bestSplit;
169 }
170
171 void CvForestTree::read( CvFileStorage* fs, CvFileNode* fnode, CvRTrees* _forest, CvDTreeTrainData* _data )
172 {
173     CvDTree::read( fs, fnode, _data );
174     forest = _forest;
175 }
176
177
178 void CvForestTree::read( CvFileStorage*, CvFileNode* )
179 {
180     assert(0);
181 }
182
183 void CvForestTree::read( CvFileStorage* _fs, CvFileNode* _node,
184                          CvDTreeTrainData* _data )
185 {
186     CvDTree::read( _fs, _node, _data );
187 }
188
189
190 //////////////////////////////////////////////////////////////////////////////////////////
191 //                                  Random trees                                        //
192 //////////////////////////////////////////////////////////////////////////////////////////
193 CvRTParams::CvRTParams() : CvDTreeParams( 5, 10, 0, false, 10, 0, false, false, 0 ),
194     calc_var_importance(false), nactive_vars(0)
195 {
196     term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 50, 0.1 );
197 }
198
199 CvRTParams::CvRTParams( int _max_depth, int _min_sample_count,
200                         float _regression_accuracy, bool _use_surrogates,
201                         int _max_categories, const float* _priors, bool _calc_var_importance,
202                         int _nactive_vars, int max_num_of_trees_in_the_forest,
203                         float forest_accuracy, int termcrit_type ) :
204     CvDTreeParams( _max_depth, _min_sample_count, _regression_accuracy,
205                    _use_surrogates, _max_categories, 0,
206                    false, false, _priors ),
207     calc_var_importance(_calc_var_importance),
208     nactive_vars(_nactive_vars)
209 {
210     term_crit = cvTermCriteria(termcrit_type,
211         max_num_of_trees_in_the_forest, forest_accuracy);
212 }
213
214 CvRTrees::CvRTrees()
215 {
216     nclasses         = 0;
217     oob_error        = 0;
218     ntrees           = 0;
219     trees            = NULL;
220     data             = NULL;
221     active_var_mask  = NULL;
222     var_importance   = NULL;
223     rng = &cv::theRNG();
224     default_model_name = "my_random_trees";
225 }
226
227
228 void CvRTrees::clear()
229 {
230     int k;
231     for( k = 0; k < ntrees; k++ )
232         delete trees[k];
233     cvFree( &trees );
234
235     delete data;
236     data = 0;
237
238     cvReleaseMat( &active_var_mask );
239     cvReleaseMat( &var_importance );
240     ntrees = 0;
241 }
242
243
244 CvRTrees::~CvRTrees()
245 {
246     clear();
247 }
248
249
250 CvMat* CvRTrees::get_active_var_mask()
251 {
252     return active_var_mask;
253 }
254
255
256 CvRNG* CvRTrees::get_rng()
257 {
258     return &rng->state;
259 }
260
261 bool CvRTrees::train( const CvMat* _train_data, int _tflag,
262                         const CvMat* _responses, const CvMat* _var_idx,
263                         const CvMat* _sample_idx, const CvMat* _var_type,
264                         const CvMat* _missing_mask, CvRTParams params )
265 {
266     clear();
267
268     CvDTreeParams tree_params( params.max_depth, params.min_sample_count,
269         params.regression_accuracy, params.use_surrogates, params.max_categories,
270         params.cv_folds, params.use_1se_rule, false, params.priors );
271
272     data = new CvDTreeTrainData();
273     data->set_data( _train_data, _tflag, _responses, _var_idx,
274         _sample_idx, _var_type, _missing_mask, tree_params, true);
275
276     int var_count = data->var_count;
277     if( params.nactive_vars > var_count )
278         params.nactive_vars = var_count;
279     else if( params.nactive_vars == 0 )
280         params.nactive_vars = (int)sqrt((double)var_count);
281     else if( params.nactive_vars < 0 )
282         CV_Error( CV_StsBadArg, "<nactive_vars> must be non-negative" );
283
284     // Create mask of active variables at the tree nodes
285     active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 );
286     if( params.calc_var_importance )
287     {
288         var_importance  = cvCreateMat( 1, var_count, CV_32FC1 );
289         cvZero(var_importance);
290     }
291     { // initialize active variables mask
292         CvMat submask1, submask2;
293         CV_Assert( (active_var_mask->cols >= 1) && (params.nactive_vars > 0) && (params.nactive_vars <= active_var_mask->cols) );
294         cvGetCols( active_var_mask, &submask1, 0, params.nactive_vars );
295         cvSet( &submask1, cvScalar(1) );
296         if( params.nactive_vars < active_var_mask->cols )
297         {
298             cvGetCols( active_var_mask, &submask2, params.nactive_vars, var_count );
299             cvZero( &submask2 );
300         }
301     }
302
303     return grow_forest( params.term_crit );
304 }
305
306 bool CvRTrees::train( CvMLData* data, CvRTParams params )
307 {
308     const CvMat* values = data->get_values();
309     const CvMat* response = data->get_responses();
310     const CvMat* missing = data->get_missing();
311     const CvMat* var_types = data->get_var_types();
312     const CvMat* train_sidx = data->get_train_sample_idx();
313     const CvMat* var_idx = data->get_var_idx();
314
315     return train( values, CV_ROW_SAMPLE, response, var_idx,
316                   train_sidx, var_types, missing, params );
317 }
318
319 bool CvRTrees::grow_forest( const CvTermCriteria term_crit )
320 {
321     CvMat* sample_idx_mask_for_tree = 0;
322     CvMat* sample_idx_for_tree      = 0;
323
324     const int max_ntrees = term_crit.max_iter;
325     const double max_oob_err = term_crit.epsilon;
326
327     const int dims = data->var_count;
328     float maximal_response = 0;
329
330     CvMat* oob_sample_votes        = 0;
331     CvMat* oob_responses       = 0;
332
333     float* oob_samples_perm_ptr= 0;
334
335     float* samples_ptr     = 0;
336     uchar* missing_ptr     = 0;
337     float* true_resp_ptr   = 0;
338     bool is_oob_or_vimportance = (max_oob_err > 0 && term_crit.type != CV_TERMCRIT_ITER) || var_importance;
339
340     // oob_predictions_sum[i] = sum of predicted values for the i-th sample
341     // oob_num_of_predictions[i] = number of summands
342     //                            (number of predictions for the i-th sample)
343     // initialize these variable to avoid warning C4701
344     CvMat oob_predictions_sum = cvMat( 1, 1, CV_32FC1 );
345     CvMat oob_num_of_predictions = cvMat( 1, 1, CV_32FC1 );
346      
347     nsamples = data->sample_count;
348     nclasses = data->get_num_classes();
349
350     if ( is_oob_or_vimportance )
351     {
352         if( data->is_classifier )
353         {
354             oob_sample_votes = cvCreateMat( nsamples, nclasses, CV_32SC1 );
355             cvZero(oob_sample_votes);
356         }
357         else
358         {
359             // oob_responses[0,i] = oob_predictions_sum[i]
360             //    = sum of predicted values for the i-th sample
361             // oob_responses[1,i] = oob_num_of_predictions[i]
362             //    = number of summands (number of predictions for the i-th sample)
363             oob_responses = cvCreateMat( 2, nsamples, CV_32FC1 );
364             cvZero(oob_responses);
365             cvGetRow( oob_responses, &oob_predictions_sum, 0 );
366             cvGetRow( oob_responses, &oob_num_of_predictions, 1 );
367         }
368         
369         oob_samples_perm_ptr     = (float*)cvAlloc( sizeof(float)*nsamples*dims );
370         samples_ptr              = (float*)cvAlloc( sizeof(float)*nsamples*dims );
371         missing_ptr              = (uchar*)cvAlloc( sizeof(uchar)*nsamples*dims );
372         true_resp_ptr            = (float*)cvAlloc( sizeof(float)*nsamples );            
373
374         data->get_vectors( 0, samples_ptr, missing_ptr, true_resp_ptr );
375         
376         double minval, maxval;
377         CvMat responses = cvMat(1, nsamples, CV_32FC1, true_resp_ptr);
378         cvMinMaxLoc( &responses, &minval, &maxval );
379         maximal_response = (float)MAX( MAX( fabs(minval), fabs(maxval) ), 0 );
380     }
381
382     trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );
383     memset( trees, 0, sizeof(trees[0])*max_ntrees );
384
385     sample_idx_mask_for_tree = cvCreateMat( 1, nsamples, CV_8UC1 );
386     sample_idx_for_tree      = cvCreateMat( 1, nsamples, CV_32SC1 );
387
388     ntrees = 0;
389     while( ntrees < max_ntrees )
390     {
391         int i, oob_samples_count = 0;
392         double ncorrect_responses = 0; // used for estimation of variable importance
393         CvForestTree* tree = 0;
394
395         cvZero( sample_idx_mask_for_tree );
396         for(i = 0; i < nsamples; i++ ) //form sample for creation one tree
397         {
398             int idx = (*rng)(nsamples);
399             sample_idx_for_tree->data.i[i] = idx;
400             sample_idx_mask_for_tree->data.ptr[idx] = 0xFF;
401         }
402
403         trees[ntrees] = new CvForestTree();
404         tree = trees[ntrees];
405         tree->train( data, sample_idx_for_tree, this );
406
407         if ( is_oob_or_vimportance )
408         {
409             CvMat sample, missing;
410             // form array of OOB samples indices and get these samples
411             sample   = cvMat( 1, dims, CV_32FC1, samples_ptr );
412             missing  = cvMat( 1, dims, CV_8UC1,  missing_ptr );
413
414             oob_error = 0;
415             for( i = 0; i < nsamples; i++,
416                 sample.data.fl += dims, missing.data.ptr += dims )
417             {
418                 CvDTreeNode* predicted_node = 0;
419                 // check if the sample is OOB
420                 if( sample_idx_mask_for_tree->data.ptr[i] )
421                     continue;
422
423                 // predict oob samples
424                 if( !predicted_node )
425                     predicted_node = tree->predict(&sample, &missing, true);
426
427                 if( !data->is_classifier ) //regression
428                 {
429                     double avg_resp, resp = predicted_node->value;
430                     oob_predictions_sum.data.fl[i] += (float)resp;
431                     oob_num_of_predictions.data.fl[i] += 1;
432
433                     // compute oob error
434                     avg_resp = oob_predictions_sum.data.fl[i]/oob_num_of_predictions.data.fl[i];
435                     avg_resp -= true_resp_ptr[i];
436                     oob_error += avg_resp*avg_resp;
437                     resp = (resp - true_resp_ptr[i])/maximal_response;
438                     ncorrect_responses += exp( -resp*resp );
439                 }
440                 else //classification
441                 {
442                     double prdct_resp;
443                     CvPoint max_loc;
444                     CvMat votes;
445
446                     cvGetRow(oob_sample_votes, &votes, i);
447                     votes.data.i[predicted_node->class_idx]++;
448
449                     // compute oob error
450                     cvMinMaxLoc( &votes, 0, 0, 0, &max_loc );
451
452                     prdct_resp = data->cat_map->data.i[max_loc.x];
453                     oob_error += (fabs(prdct_resp - true_resp_ptr[i]) < FLT_EPSILON) ? 0 : 1;
454
455                     ncorrect_responses += cvRound(predicted_node->value - true_resp_ptr[i]) == 0;
456                 }
457                 oob_samples_count++;
458             }
459             if( oob_samples_count > 0 )
460                 oob_error /= (double)oob_samples_count;
461
462             // estimate variable importance
463             if( var_importance && oob_samples_count > 0 )
464             {
465                 int m;
466
467                 memcpy( oob_samples_perm_ptr, samples_ptr, dims*nsamples*sizeof(float));
468                 for( m = 0; m < dims; m++ )
469                 {
470                     double ncorrect_responses_permuted = 0;
471                     // randomly permute values of the m-th variable in the oob samples
472                     float* mth_var_ptr = oob_samples_perm_ptr + m;
473
474                     for( i = 0; i < nsamples; i++ )
475                     {
476                         int i1, i2;
477                         float temp;
478
479                         if( sample_idx_mask_for_tree->data.ptr[i] ) //the sample is not OOB
480                             continue;
481                         i1 = (*rng)(nsamples);
482                         i2 = (*rng)(nsamples);
483                         CV_SWAP( mth_var_ptr[i1*dims], mth_var_ptr[i2*dims], temp );
484
485                         // turn values of (m-1)-th variable, that were permuted
486                         // at the previous iteration, untouched
487                         if( m > 1 )
488                             oob_samples_perm_ptr[i*dims+m-1] = samples_ptr[i*dims+m-1];
489                     }
490
491                     // predict "permuted" cases and calculate the number of votes for the
492                     // correct class in the variable-m-permuted oob data
493                     sample  = cvMat( 1, dims, CV_32FC1, oob_samples_perm_ptr );
494                     missing = cvMat( 1, dims, CV_8UC1, missing_ptr );
495                     for( i = 0; i < nsamples; i++,
496                         sample.data.fl += dims, missing.data.ptr += dims )
497                     {
498                         double predct_resp, true_resp;
499
500                         if( sample_idx_mask_for_tree->data.ptr[i] ) //the sample is not OOB
501                             continue;
502
503                         predct_resp = tree->predict(&sample, &missing, true)->value;
504                         true_resp   = true_resp_ptr[i];
505                         if( data->is_classifier )
506                             ncorrect_responses_permuted += cvRound(true_resp - predct_resp) == 0;
507                         else
508                         {
509                             true_resp = (true_resp - predct_resp)/maximal_response;
510                             ncorrect_responses_permuted += exp( -true_resp*true_resp );
511                         }
512                     }
513                     var_importance->data.fl[m] += (float)(ncorrect_responses
514                         - ncorrect_responses_permuted);
515                 }
516             }
517         }
518         ntrees++;
519         if( term_crit.type != CV_TERMCRIT_ITER && oob_error < max_oob_err )
520             break;
521     }
522
523     if( var_importance )
524     {
525         for ( int vi = 0; vi < var_importance->cols; vi++ )
526                 var_importance->data.fl[vi] = ( var_importance->data.fl[vi] > 0 ) ?
527                     var_importance->data.fl[vi] : 0;
528         cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
529     }
530
531     cvFree( &oob_samples_perm_ptr );
532     cvFree( &samples_ptr );
533     cvFree( &missing_ptr );
534     cvFree( &true_resp_ptr );
535     
536     cvReleaseMat( &sample_idx_mask_for_tree );
537     cvReleaseMat( &sample_idx_for_tree );
538
539     cvReleaseMat( &oob_sample_votes );
540     cvReleaseMat( &oob_responses );
541
542     return true;
543 }
544
545
546 const CvMat* CvRTrees::get_var_importance()
547 {
548     return var_importance;
549 }
550
551
552 float CvRTrees::get_proximity( const CvMat* sample1, const CvMat* sample2,
553                               const CvMat* missing1, const CvMat* missing2 ) const
554 {
555     float result = 0;
556
557     for( int i = 0; i < ntrees; i++ )
558         result += trees[i]->predict( sample1, missing1 ) ==
559         trees[i]->predict( sample2, missing2 ) ?  1 : 0;
560     result = result/(float)ntrees;
561
562     return result;
563 }
564
565 float CvRTrees::calc_error( CvMLData* _data, int type , std::vector<float> *resp )
566 {
567     float err = 0;
568     const CvMat* values = _data->get_values();
569     const CvMat* response = _data->get_responses();
570     const CvMat* missing = _data->get_missing();
571     const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
572     const CvMat* var_types = _data->get_var_types();
573     int* sidx = sample_idx ? sample_idx->data.i : 0;
574     int r_step = CV_IS_MAT_CONT(response->type) ?
575                 1 : response->step / CV_ELEM_SIZE(response->type);
576     bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
577     int sample_count = sample_idx ? sample_idx->cols : 0;
578     sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
579     float* pred_resp = 0;
580     if( resp && (sample_count > 0) )
581     {
582         resp->resize( sample_count );
583         pred_resp = &((*resp)[0]);
584     }
585     if ( is_classifier )
586     {
587         for( int i = 0; i < sample_count; i++ )
588         {
589             CvMat sample, miss;
590             int si = sidx ? sidx[i] : i;
591             cvGetRow( values, &sample, si ); 
592             if( missing ) 
593                 cvGetRow( missing, &miss, si );             
594             float r = (float)predict( &sample, missing ? &miss : 0 );
595             if( pred_resp )
596                 pred_resp[i] = r;
597             int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;
598             err += d;
599         }
600         err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
601     }
602     else
603     {
604         for( int i = 0; i < sample_count; i++ )
605         {
606             CvMat sample, miss;
607             int si = sidx ? sidx[i] : i;
608             cvGetRow( values, &sample, si );
609             if( missing ) 
610                 cvGetRow( missing, &miss, si );             
611             float r = (float)predict( &sample, missing ? &miss : 0 );
612             if( pred_resp )
613                 pred_resp[i] = r;
614             float d = r - response->data.fl[si*r_step];
615             err += d*d;
616         }
617         err = sample_count ? err / (float)sample_count : -FLT_MAX;    
618     }
619     return err;
620 }
621
622 float CvRTrees::get_train_error()
623 {
624     float err = -1;
625
626     int sample_count = data->sample_count;
627     int var_count = data->var_count;
628
629     float *values_ptr = (float*)cvAlloc( sizeof(float)*sample_count*var_count );
630     uchar *missing_ptr = (uchar*)cvAlloc( sizeof(uchar)*sample_count*var_count );
631     float *responses_ptr = (float*)cvAlloc( sizeof(float)*sample_count );
632
633     data->get_vectors( 0, values_ptr, missing_ptr, responses_ptr);
634     
635     if (data->is_classifier)
636     {
637         int err_count = 0;
638         float *vp = values_ptr;
639         uchar *mp = missing_ptr;    
640         for (int si = 0; si < sample_count; si++, vp += var_count, mp += var_count)
641         {
642             CvMat sample = cvMat( 1, var_count, CV_32FC1, vp );
643             CvMat missing = cvMat( 1, var_count, CV_8UC1,  mp );
644             float r = predict( &sample, &missing );
645             if (fabs(r - responses_ptr[si]) >= FLT_EPSILON)
646                 err_count++;
647         }
648         err = (float)err_count / (float)sample_count;
649     }
650     else
651         CV_Error( CV_StsBadArg, "This method is not supported for regression problems" );
652     
653     cvFree( &values_ptr );
654     cvFree( &missing_ptr );
655     cvFree( &responses_ptr ); 
656
657     return err;
658 }
659
660
661 float CvRTrees::predict( const CvMat* sample, const CvMat* missing ) const
662 {
663     double result = -1;
664     int k;
665
666     if( nclasses > 0 ) //classification
667     {
668         int max_nvotes = 0;
669         cv::AutoBuffer<int> _votes(nclasses);
670         int* votes = _votes;
671         memset( votes, 0, sizeof(*votes)*nclasses );
672         for( k = 0; k < ntrees; k++ )
673         {
674             CvDTreeNode* predicted_node = trees[k]->predict( sample, missing );
675             int nvotes;
676             int class_idx = predicted_node->class_idx;
677             CV_Assert( 0 <= class_idx && class_idx < nclasses );
678
679             nvotes = ++votes[class_idx];
680             if( nvotes > max_nvotes )
681             {
682                 max_nvotes = nvotes;
683                 result = predicted_node->value;
684             }
685         }
686     }
687     else // regression
688     {
689         result = 0;
690         for( k = 0; k < ntrees; k++ )
691             result += trees[k]->predict( sample, missing )->value;
692         result /= (double)ntrees;
693     }
694
695     return (float)result;
696 }
697
698 float CvRTrees::predict_prob( const CvMat* sample, const CvMat* missing) const
699 {
700         if( nclasses == 2 ) //classification
701     {
702         cv::AutoBuffer<int> _votes(nclasses);
703         int* votes = _votes;
704         memset( votes, 0, sizeof(*votes)*nclasses );
705         for( int k = 0; k < ntrees; k++ )
706         {
707             CvDTreeNode* predicted_node = trees[k]->predict( sample, missing );
708             int class_idx = predicted_node->class_idx;
709             CV_Assert( 0 <= class_idx && class_idx < nclasses );
710                         
711             ++votes[class_idx];
712         }
713                 
714                 return float(votes[1])/ntrees;
715     }
716     else // regression
717                 CV_Error(CV_StsBadArg, "This function works for binary classification problems only...");
718         
719     return -1;
720 }
721
722 void CvRTrees::write( CvFileStorage* fs, const char* name ) const
723 {
724     int k;
725
726     if( ntrees < 1 || !trees || nsamples < 1 )
727         CV_Error( CV_StsBadArg, "Invalid CvRTrees object" );
728
729     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_RTREES );
730
731     cvWriteInt( fs, "nclasses", nclasses );
732     cvWriteInt( fs, "nsamples", nsamples );
733     cvWriteInt( fs, "nactive_vars", (int)cvSum(active_var_mask).val[0] );
734     cvWriteReal( fs, "oob_error", oob_error );
735
736     if( var_importance )
737         cvWrite( fs, "var_importance", var_importance );
738
739     cvWriteInt( fs, "ntrees", ntrees );
740
741     data->write_params( fs );
742
743     cvStartWriteStruct( fs, "trees", CV_NODE_SEQ );
744
745     for( k = 0; k < ntrees; k++ )
746     {
747         cvStartWriteStruct( fs, 0, CV_NODE_MAP );
748         trees[k]->write( fs );
749         cvEndWriteStruct( fs );
750     }
751
752     cvEndWriteStruct( fs ); //trees
753     cvEndWriteStruct( fs ); //CV_TYPE_NAME_ML_RTREES
754 }
755
756
757 void CvRTrees::read( CvFileStorage* fs, CvFileNode* fnode )
758 {
759     int nactive_vars, var_count, k;
760     CvSeqReader reader;
761     CvFileNode* trees_fnode = 0;
762
763     clear();
764
765     nclasses     = cvReadIntByName( fs, fnode, "nclasses", -1 );
766     nsamples     = cvReadIntByName( fs, fnode, "nsamples" );
767     nactive_vars = cvReadIntByName( fs, fnode, "nactive_vars", -1 );
768     oob_error    = cvReadRealByName(fs, fnode, "oob_error", -1 );
769     ntrees       = cvReadIntByName( fs, fnode, "ntrees", -1 );
770
771     var_importance = (CvMat*)cvReadByName( fs, fnode, "var_importance" );
772
773     if( nclasses < 0 || nsamples <= 0 || nactive_vars < 0 || oob_error < 0 || ntrees <= 0)
774         CV_Error( CV_StsParseError, "Some <nclasses>, <nsamples>, <var_count>, "
775         "<nactive_vars>, <oob_error>, <ntrees> of tags are missing" );
776
777     rng = &cv::theRNG();
778
779     trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*ntrees );
780     memset( trees, 0, sizeof(trees[0])*ntrees );
781
782     data = new CvDTreeTrainData();
783     data->read_params( fs, fnode );
784     data->shared = true;
785
786     trees_fnode = cvGetFileNodeByName( fs, fnode, "trees" );
787     if( !trees_fnode || !CV_NODE_IS_SEQ(trees_fnode->tag) )
788         CV_Error( CV_StsParseError, "<trees> tag is missing" );
789
790     cvStartReadSeq( trees_fnode->data.seq, &reader );
791     if( reader.seq->total != ntrees )
792         CV_Error( CV_StsParseError,
793         "<ntrees> is not equal to the number of trees saved in file" );
794
795     for( k = 0; k < ntrees; k++ )
796     {
797         trees[k] = new CvForestTree();
798         trees[k]->read( fs, (CvFileNode*)reader.ptr, this, data );
799         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
800     }
801
802     var_count = data->var_count;
803     active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 );
804     {
805         // initialize active variables mask
806         CvMat submask1;
807                 cvGetCols( active_var_mask, &submask1, 0, nactive_vars );
808         cvSet( &submask1, cvScalar(1) );
809
810                 if( nactive_vars < var_count )
811                 {
812                         CvMat submask2;
813                         cvGetCols( active_var_mask, &submask2, nactive_vars, var_count );
814                         cvZero( &submask2 );
815                 }
816     }
817 }
818
819
820 int CvRTrees::get_tree_count() const
821 {
822     return ntrees;
823 }
824
825 CvForestTree* CvRTrees::get_tree(int i) const
826 {
827     return (unsigned)i < (unsigned)ntrees ? trees[i] : 0;
828 }
829
830 using namespace cv;
831
832 bool CvRTrees::train( const Mat& _train_data, int _tflag,
833                      const Mat& _responses, const Mat& _var_idx,
834                      const Mat& _sample_idx, const Mat& _var_type,
835                      const Mat& _missing_mask, CvRTParams _params )
836 {
837     CvMat tdata = _train_data, responses = _responses, vidx = _var_idx,
838     sidx = _sample_idx, vtype = _var_type, mmask = _missing_mask;
839     return train(&tdata, _tflag, &responses, vidx.data.ptr ? &vidx : 0,
840                  sidx.data.ptr ? &sidx : 0, vtype.data.ptr ? &vtype : 0,
841                  mmask.data.ptr ? &mmask : 0, _params);
842 }
843
844
845 float CvRTrees::predict( const Mat& _sample, const Mat& _missing ) const
846 {
847     CvMat sample = _sample, mmask = _missing;
848     return predict(&sample, mmask.data.ptr ? &mmask : 0);
849 }
850
851 float CvRTrees::predict_prob( const Mat& _sample, const Mat& _missing) const
852 {
853     CvMat sample = _sample, mmask = _missing;
854     return predict_prob(&sample, mmask.data.ptr ? &mmask : 0);
855 }
856
857 Mat CvRTrees::getVarImportance()
858 {
859     return Mat(get_var_importance());
860 }
861
862 // End of file.