made random generators of MLL classes depended on default rng (theRNG) (#205).
[profile/ivi/opencv.git] / modules / ml / src / rtrees.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40
41 #include "precomp.hpp"
42
43 CvForestTree::CvForestTree()
44 {
45     forest = NULL;
46 }
47
48
49 CvForestTree::~CvForestTree()
50 {
51     clear();
52 }
53
54
55 bool CvForestTree::train( CvDTreeTrainData* _data,
56                           const CvMat* _subsample_idx,
57                           CvRTrees* _forest )
58 {
59     clear();
60     forest = _forest;
61
62     data = _data;
63     data->shared = true;
64     return do_train(_subsample_idx);
65 }
66
67
68 bool
69 CvForestTree::train( const CvMat*, int, const CvMat*, const CvMat*,
70                     const CvMat*, const CvMat*, const CvMat*, CvDTreeParams )
71 {
72     assert(0);
73     return false;
74 }
75
76
77 bool
78 CvForestTree::train( CvDTreeTrainData*, const CvMat* )
79 {
80     assert(0);
81     return false;
82 }
83
84
85
86 namespace cv
87 {
88
89 ForestTreeBestSplitFinder::ForestTreeBestSplitFinder( CvForestTree* _tree, CvDTreeNode* _node ) :
90     DTreeBestSplitFinder(_tree, _node) {}
91
92 ForestTreeBestSplitFinder::ForestTreeBestSplitFinder( const ForestTreeBestSplitFinder& finder, Split spl ) :
93     DTreeBestSplitFinder( finder, spl ) {}
94
95 void ForestTreeBestSplitFinder::operator()(const BlockedRange& range)
96 {
97     int vi, vi1 = range.begin(), vi2 = range.end();
98     int n = node->sample_count;
99     CvDTreeTrainData* data = tree->get_data();
100     AutoBuffer<uchar> inn_buf(2*n*(sizeof(int) + sizeof(float)));
101
102     CvForestTree* ftree = (CvForestTree*)tree;
103     const CvMat* active_var_mask = ftree->forest->get_active_var_mask();
104
105     for( vi = vi1; vi < vi2; vi++ )
106     {
107         CvDTreeSplit *res;
108         int ci = data->var_type->data.i[vi];
109         if( node->num_valid[vi] <= 1
110             || (active_var_mask && !active_var_mask->data.ptr[vi]) )
111             continue;
112
113         if( data->is_classifier )
114         {
115             if( ci >= 0 )
116                 res = ftree->find_split_cat_class( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
117             else
118                 res = ftree->find_split_ord_class( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
119         }
120         else
121         {
122             if( ci >= 0 )
123                 res = ftree->find_split_cat_reg( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
124             else
125                 res = ftree->find_split_ord_reg( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
126         }
127
128         if( res && bestSplit->quality < split->quality )
129                 memcpy( (CvDTreeSplit*)bestSplit, (CvDTreeSplit*)split, splitSize );
130     }
131 }
132 }
133
134 CvDTreeSplit* CvForestTree::find_best_split( CvDTreeNode* node )
135 {
136     CvMat* active_var_mask = 0;
137     if( forest )
138     {
139         int var_count;
140         CvRNG* rng = forest->get_rng();
141
142         active_var_mask = forest->get_active_var_mask();
143         var_count = active_var_mask->cols;
144
145         CV_Assert( var_count == data->var_count );
146
147         for( int vi = 0; vi < var_count; vi++ )
148         {
149             uchar temp;
150             int i1 = cvRandInt(rng) % var_count;
151             int i2 = cvRandInt(rng) % var_count;
152             CV_SWAP( active_var_mask->data.ptr[i1],
153                 active_var_mask->data.ptr[i2], temp );
154         }
155     }
156
157     cv::ForestTreeBestSplitFinder finder( this, node );
158
159     cv::parallel_reduce(cv::BlockedRange(0, data->var_count), finder);
160
161     CvDTreeSplit *bestSplit = 0;
162     if( finder.bestSplit->quality > 0 )
163     {
164         bestSplit = data->new_split_cat( 0, -1.0f );
165         memcpy( bestSplit, finder.bestSplit, finder.splitSize );
166     }
167
168     return bestSplit;
169 }
170
171 void CvForestTree::read( CvFileStorage* fs, CvFileNode* fnode, CvRTrees* _forest, CvDTreeTrainData* _data )
172 {
173     CvDTree::read( fs, fnode, _data );
174     forest = _forest;
175 }
176
177
178 void CvForestTree::read( CvFileStorage*, CvFileNode* )
179 {
180     assert(0);
181 }
182
183 void CvForestTree::read( CvFileStorage* _fs, CvFileNode* _node,
184                          CvDTreeTrainData* _data )
185 {
186     CvDTree::read( _fs, _node, _data );
187 }
188
189
190 //////////////////////////////////////////////////////////////////////////////////////////
191 //                                  Random trees                                        //
192 //////////////////////////////////////////////////////////////////////////////////////////
193
194 CvRTrees::CvRTrees()
195 {
196     nclasses         = 0;
197     oob_error        = 0;
198     ntrees           = 0;
199     trees            = NULL;
200     data             = NULL;
201     active_var_mask  = NULL;
202     var_importance   = NULL;
203     rng = &cv::theRNG();
204     default_model_name = "my_random_trees";
205 }
206
207
208 void CvRTrees::clear()
209 {
210     int k;
211     for( k = 0; k < ntrees; k++ )
212         delete trees[k];
213     cvFree( &trees );
214
215     delete data;
216     data = 0;
217
218     cvReleaseMat( &active_var_mask );
219     cvReleaseMat( &var_importance );
220     ntrees = 0;
221 }
222
223
224 CvRTrees::~CvRTrees()
225 {
226     clear();
227 }
228
229
230 CvMat* CvRTrees::get_active_var_mask()
231 {
232     return active_var_mask;
233 }
234
235
236 CvRNG* CvRTrees::get_rng()
237 {
238     return &rng->state;
239 }
240
241 bool CvRTrees::train( const CvMat* _train_data, int _tflag,
242                         const CvMat* _responses, const CvMat* _var_idx,
243                         const CvMat* _sample_idx, const CvMat* _var_type,
244                         const CvMat* _missing_mask, CvRTParams params )
245 {
246     clear();
247
248     CvDTreeParams tree_params( params.max_depth, params.min_sample_count,
249         params.regression_accuracy, params.use_surrogates, params.max_categories,
250         params.cv_folds, params.use_1se_rule, false, params.priors );
251
252     data = new CvDTreeTrainData();
253     data->set_data( _train_data, _tflag, _responses, _var_idx,
254         _sample_idx, _var_type, _missing_mask, tree_params, true);
255
256     int var_count = data->var_count;
257     if( params.nactive_vars > var_count )
258         params.nactive_vars = var_count;
259     else if( params.nactive_vars == 0 )
260         params.nactive_vars = (int)sqrt((double)var_count);
261     else if( params.nactive_vars < 0 )
262         CV_Error( CV_StsBadArg, "<nactive_vars> must be non-negative" );
263
264     // Create mask of active variables at the tree nodes
265     active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 );
266     if( params.calc_var_importance )
267     {
268         var_importance  = cvCreateMat( 1, var_count, CV_32FC1 );
269         cvZero(var_importance);
270     }
271     { // initialize active variables mask
272         CvMat submask1, submask2;
273         CV_Assert( (active_var_mask->cols >= 1) && (params.nactive_vars > 0) && (params.nactive_vars <= active_var_mask->cols) );
274         cvGetCols( active_var_mask, &submask1, 0, params.nactive_vars );
275         cvSet( &submask1, cvScalar(1) );
276         if( params.nactive_vars < active_var_mask->cols )
277         {
278             cvGetCols( active_var_mask, &submask2, params.nactive_vars, var_count );
279             cvZero( &submask2 );
280         }
281     }
282
283     return grow_forest( params.term_crit );
284 }
285
286 bool CvRTrees::train( CvMLData* data, CvRTParams params )
287 {
288     const CvMat* values = data->get_values();
289     const CvMat* response = data->get_responses();
290     const CvMat* missing = data->get_missing();
291     const CvMat* var_types = data->get_var_types();
292     const CvMat* train_sidx = data->get_train_sample_idx();
293     const CvMat* var_idx = data->get_var_idx();
294
295     return train( values, CV_ROW_SAMPLE, response, var_idx,
296                   train_sidx, var_types, missing, params );
297 }
298
299 bool CvRTrees::grow_forest( const CvTermCriteria term_crit )
300 {
301     CvMat* sample_idx_mask_for_tree = 0;
302     CvMat* sample_idx_for_tree      = 0;
303
304     const int max_ntrees = term_crit.max_iter;
305     const double max_oob_err = term_crit.epsilon;
306
307     const int dims = data->var_count;
308     float maximal_response = 0;
309
310     CvMat* oob_sample_votes        = 0;
311     CvMat* oob_responses       = 0;
312
313     float* oob_samples_perm_ptr= 0;
314
315     float* samples_ptr     = 0;
316     uchar* missing_ptr     = 0;
317     float* true_resp_ptr   = 0;
318     bool is_oob_or_vimportance = (max_oob_err > 0 && term_crit.type != CV_TERMCRIT_ITER) || var_importance;
319
320     // oob_predictions_sum[i] = sum of predicted values for the i-th sample
321     // oob_num_of_predictions[i] = number of summands
322     //                            (number of predictions for the i-th sample)
323     // initialize these variable to avoid warning C4701
324     CvMat oob_predictions_sum = cvMat( 1, 1, CV_32FC1 );
325     CvMat oob_num_of_predictions = cvMat( 1, 1, CV_32FC1 );
326      
327     nsamples = data->sample_count;
328     nclasses = data->get_num_classes();
329
330     if ( is_oob_or_vimportance )
331     {
332         if( data->is_classifier )
333         {
334             oob_sample_votes = cvCreateMat( nsamples, nclasses, CV_32SC1 );
335             cvZero(oob_sample_votes);
336         }
337         else
338         {
339             // oob_responses[0,i] = oob_predictions_sum[i]
340             //    = sum of predicted values for the i-th sample
341             // oob_responses[1,i] = oob_num_of_predictions[i]
342             //    = number of summands (number of predictions for the i-th sample)
343             oob_responses = cvCreateMat( 2, nsamples, CV_32FC1 );
344             cvZero(oob_responses);
345             cvGetRow( oob_responses, &oob_predictions_sum, 0 );
346             cvGetRow( oob_responses, &oob_num_of_predictions, 1 );
347         }
348         
349         oob_samples_perm_ptr     = (float*)cvAlloc( sizeof(float)*nsamples*dims );
350         samples_ptr              = (float*)cvAlloc( sizeof(float)*nsamples*dims );
351         missing_ptr              = (uchar*)cvAlloc( sizeof(uchar)*nsamples*dims );
352         true_resp_ptr            = (float*)cvAlloc( sizeof(float)*nsamples );            
353
354         data->get_vectors( 0, samples_ptr, missing_ptr, true_resp_ptr );
355         
356         double minval, maxval;
357         CvMat responses = cvMat(1, nsamples, CV_32FC1, true_resp_ptr);
358         cvMinMaxLoc( &responses, &minval, &maxval );
359         maximal_response = (float)MAX( MAX( fabs(minval), fabs(maxval) ), 0 );
360     }
361
362     trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );
363     memset( trees, 0, sizeof(trees[0])*max_ntrees );
364
365     sample_idx_mask_for_tree = cvCreateMat( 1, nsamples, CV_8UC1 );
366     sample_idx_for_tree      = cvCreateMat( 1, nsamples, CV_32SC1 );
367
368     ntrees = 0;
369     while( ntrees < max_ntrees )
370     {
371         int i, oob_samples_count = 0;
372         double ncorrect_responses = 0; // used for estimation of variable importance
373         CvForestTree* tree = 0;
374
375         cvZero( sample_idx_mask_for_tree );
376         for(i = 0; i < nsamples; i++ ) //form sample for creation one tree
377         {
378             int idx = (*rng)(nsamples);
379             sample_idx_for_tree->data.i[i] = idx;
380             sample_idx_mask_for_tree->data.ptr[idx] = 0xFF;
381         }
382
383         trees[ntrees] = new CvForestTree();
384         tree = trees[ntrees];
385         tree->train( data, sample_idx_for_tree, this );
386
387         if ( is_oob_or_vimportance )
388         {
389             CvMat sample, missing;
390             // form array of OOB samples indices and get these samples
391             sample   = cvMat( 1, dims, CV_32FC1, samples_ptr );
392             missing  = cvMat( 1, dims, CV_8UC1,  missing_ptr );
393
394             oob_error = 0;
395             for( i = 0; i < nsamples; i++,
396                 sample.data.fl += dims, missing.data.ptr += dims )
397             {
398                 CvDTreeNode* predicted_node = 0;
399                 // check if the sample is OOB
400                 if( sample_idx_mask_for_tree->data.ptr[i] )
401                     continue;
402
403                 // predict oob samples
404                 if( !predicted_node )
405                     predicted_node = tree->predict(&sample, &missing, true);
406
407                 if( !data->is_classifier ) //regression
408                 {
409                     double avg_resp, resp = predicted_node->value;
410                     oob_predictions_sum.data.fl[i] += (float)resp;
411                     oob_num_of_predictions.data.fl[i] += 1;
412
413                     // compute oob error
414                     avg_resp = oob_predictions_sum.data.fl[i]/oob_num_of_predictions.data.fl[i];
415                     avg_resp -= true_resp_ptr[i];
416                     oob_error += avg_resp*avg_resp;
417                     resp = (resp - true_resp_ptr[i])/maximal_response;
418                     ncorrect_responses += exp( -resp*resp );
419                 }
420                 else //classification
421                 {
422                     double prdct_resp;
423                     CvPoint max_loc;
424                     CvMat votes;
425
426                     cvGetRow(oob_sample_votes, &votes, i);
427                     votes.data.i[predicted_node->class_idx]++;
428
429                     // compute oob error
430                     cvMinMaxLoc( &votes, 0, 0, 0, &max_loc );
431
432                     prdct_resp = data->cat_map->data.i[max_loc.x];
433                     oob_error += (fabs(prdct_resp - true_resp_ptr[i]) < FLT_EPSILON) ? 0 : 1;
434
435                     ncorrect_responses += cvRound(predicted_node->value - true_resp_ptr[i]) == 0;
436                 }
437                 oob_samples_count++;
438             }
439             if( oob_samples_count > 0 )
440                 oob_error /= (double)oob_samples_count;
441
442             // estimate variable importance
443             if( var_importance && oob_samples_count > 0 )
444             {
445                 int m;
446
447                 memcpy( oob_samples_perm_ptr, samples_ptr, dims*nsamples*sizeof(float));
448                 for( m = 0; m < dims; m++ )
449                 {
450                     double ncorrect_responses_permuted = 0;
451                     // randomly permute values of the m-th variable in the oob samples
452                     float* mth_var_ptr = oob_samples_perm_ptr + m;
453
454                     for( i = 0; i < nsamples; i++ )
455                     {
456                         int i1, i2;
457                         float temp;
458
459                         if( sample_idx_mask_for_tree->data.ptr[i] ) //the sample is not OOB
460                             continue;
461                         i1 = (*rng)(nsamples);
462                         i2 = (*rng)(nsamples);
463                         CV_SWAP( mth_var_ptr[i1*dims], mth_var_ptr[i2*dims], temp );
464
465                         // turn values of (m-1)-th variable, that were permuted
466                         // at the previous iteration, untouched
467                         if( m > 1 )
468                             oob_samples_perm_ptr[i*dims+m-1] = samples_ptr[i*dims+m-1];
469                     }
470
471                     // predict "permuted" cases and calculate the number of votes for the
472                     // correct class in the variable-m-permuted oob data
473                     sample  = cvMat( 1, dims, CV_32FC1, oob_samples_perm_ptr );
474                     missing = cvMat( 1, dims, CV_8UC1, missing_ptr );
475                     for( i = 0; i < nsamples; i++,
476                         sample.data.fl += dims, missing.data.ptr += dims )
477                     {
478                         double predct_resp, true_resp;
479
480                         if( sample_idx_mask_for_tree->data.ptr[i] ) //the sample is not OOB
481                             continue;
482
483                         predct_resp = tree->predict(&sample, &missing, true)->value;
484                         true_resp   = true_resp_ptr[i];
485                         if( data->is_classifier )
486                             ncorrect_responses_permuted += cvRound(true_resp - predct_resp) == 0;
487                         else
488                         {
489                             true_resp = (true_resp - predct_resp)/maximal_response;
490                             ncorrect_responses_permuted += exp( -true_resp*true_resp );
491                         }
492                     }
493                     var_importance->data.fl[m] += (float)(ncorrect_responses
494                         - ncorrect_responses_permuted);
495                 }
496             }
497         }
498         ntrees++;
499         if( term_crit.type != CV_TERMCRIT_ITER && oob_error < max_oob_err )
500             break;
501     }
502
503     if( var_importance )
504     {
505         for ( int vi = 0; vi < var_importance->cols; vi++ )
506                 var_importance->data.fl[vi] = ( var_importance->data.fl[vi] > 0 ) ?
507                     var_importance->data.fl[vi] : 0;
508         cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
509     }
510
511     cvFree( &oob_samples_perm_ptr );
512     cvFree( &samples_ptr );
513     cvFree( &missing_ptr );
514     cvFree( &true_resp_ptr );
515     
516     cvReleaseMat( &sample_idx_mask_for_tree );
517     cvReleaseMat( &sample_idx_for_tree );
518
519     cvReleaseMat( &oob_sample_votes );
520     cvReleaseMat( &oob_responses );
521
522     return true;
523 }
524
525
526 const CvMat* CvRTrees::get_var_importance()
527 {
528     return var_importance;
529 }
530
531
532 float CvRTrees::get_proximity( const CvMat* sample1, const CvMat* sample2,
533                               const CvMat* missing1, const CvMat* missing2 ) const
534 {
535     float result = 0;
536
537     for( int i = 0; i < ntrees; i++ )
538         result += trees[i]->predict( sample1, missing1 ) ==
539         trees[i]->predict( sample2, missing2 ) ?  1 : 0;
540     result = result/(float)ntrees;
541
542     return result;
543 }
544
545 float CvRTrees::calc_error( CvMLData* _data, int type , std::vector<float> *resp )
546 {
547     float err = 0;
548     const CvMat* values = _data->get_values();
549     const CvMat* response = _data->get_responses();
550     const CvMat* missing = _data->get_missing();
551     const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
552     const CvMat* var_types = _data->get_var_types();
553     int* sidx = sample_idx ? sample_idx->data.i : 0;
554     int r_step = CV_IS_MAT_CONT(response->type) ?
555                 1 : response->step / CV_ELEM_SIZE(response->type);
556     bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
557     int sample_count = sample_idx ? sample_idx->cols : 0;
558     sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
559     float* pred_resp = 0;
560     if( resp && (sample_count > 0) )
561     {
562         resp->resize( sample_count );
563         pred_resp = &((*resp)[0]);
564     }
565     if ( is_classifier )
566     {
567         for( int i = 0; i < sample_count; i++ )
568         {
569             CvMat sample, miss;
570             int si = sidx ? sidx[i] : i;
571             cvGetRow( values, &sample, si ); 
572             if( missing ) 
573                 cvGetRow( missing, &miss, si );             
574             float r = (float)predict( &sample, missing ? &miss : 0 );
575             if( pred_resp )
576                 pred_resp[i] = r;
577             int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;
578             err += d;
579         }
580         err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
581     }
582     else
583     {
584         for( int i = 0; i < sample_count; i++ )
585         {
586             CvMat sample, miss;
587             int si = sidx ? sidx[i] : i;
588             cvGetRow( values, &sample, si );
589             if( missing ) 
590                 cvGetRow( missing, &miss, si );             
591             float r = (float)predict( &sample, missing ? &miss : 0 );
592             if( pred_resp )
593                 pred_resp[i] = r;
594             float d = r - response->data.fl[si*r_step];
595             err += d*d;
596         }
597         err = sample_count ? err / (float)sample_count : -FLT_MAX;    
598     }
599     return err;
600 }
601
602 float CvRTrees::get_train_error()
603 {
604     float err = -1;
605
606     int sample_count = data->sample_count;
607     int var_count = data->var_count;
608
609     float *values_ptr = (float*)cvAlloc( sizeof(float)*sample_count*var_count );
610     uchar *missing_ptr = (uchar*)cvAlloc( sizeof(uchar)*sample_count*var_count );
611     float *responses_ptr = (float*)cvAlloc( sizeof(float)*sample_count );
612
613     data->get_vectors( 0, values_ptr, missing_ptr, responses_ptr);
614     
615     if (data->is_classifier)
616     {
617         int err_count = 0;
618         float *vp = values_ptr;
619         uchar *mp = missing_ptr;    
620         for (int si = 0; si < sample_count; si++, vp += var_count, mp += var_count)
621         {
622             CvMat sample = cvMat( 1, var_count, CV_32FC1, vp );
623             CvMat missing = cvMat( 1, var_count, CV_8UC1,  mp );
624             float r = predict( &sample, &missing );
625             if (fabs(r - responses_ptr[si]) >= FLT_EPSILON)
626                 err_count++;
627         }
628         err = (float)err_count / (float)sample_count;
629     }
630     else
631         CV_Error( CV_StsBadArg, "This method is not supported for regression problems" );
632     
633     cvFree( &values_ptr );
634     cvFree( &missing_ptr );
635     cvFree( &responses_ptr ); 
636
637     return err;
638 }
639
640
641 float CvRTrees::predict( const CvMat* sample, const CvMat* missing ) const
642 {
643     double result = -1;
644     int k;
645
646     if( nclasses > 0 ) //classification
647     {
648         int max_nvotes = 0;
649         int* votes = (int*)alloca( sizeof(int)*nclasses );
650         memset( votes, 0, sizeof(*votes)*nclasses );
651         for( k = 0; k < ntrees; k++ )
652         {
653             CvDTreeNode* predicted_node = trees[k]->predict( sample, missing );
654             int nvotes;
655             int class_idx = predicted_node->class_idx;
656             CV_Assert( 0 <= class_idx && class_idx < nclasses );
657
658             nvotes = ++votes[class_idx];
659             if( nvotes > max_nvotes )
660             {
661                 max_nvotes = nvotes;
662                 result = predicted_node->value;
663             }
664         }
665     }
666     else // regression
667     {
668         result = 0;
669         for( k = 0; k < ntrees; k++ )
670             result += trees[k]->predict( sample, missing )->value;
671         result /= (double)ntrees;
672     }
673
674     return (float)result;
675 }
676
677 float CvRTrees::predict_prob( const CvMat* sample, const CvMat* missing) const
678 {
679     double result = -1;
680     int k;
681         
682         if( nclasses == 2 ) //classification
683     {
684         int max_nvotes = 0;
685         int* votes = (int*)alloca( sizeof(int)*nclasses );
686         memset( votes, 0, sizeof(*votes)*nclasses );
687         for( k = 0; k < ntrees; k++ )
688         {
689             CvDTreeNode* predicted_node = trees[k]->predict( sample, missing );
690             int nvotes;
691             int class_idx = predicted_node->class_idx;
692             CV_Assert( 0 <= class_idx && class_idx < nclasses );
693                         
694             nvotes = ++votes[class_idx];
695             if( nvotes > max_nvotes )
696             {
697                 max_nvotes = nvotes;
698                 result = predicted_node->value;
699             }
700         }
701                 
702                 return float(votes[1])/ntrees;
703     }
704     else // regression
705                 CV_Error(CV_StsBadArg, "This function works for binary classification problems only...");
706         
707     return -1;
708 }
709
710 void CvRTrees::write( CvFileStorage* fs, const char* name ) const
711 {
712     int k;
713
714     if( ntrees < 1 || !trees || nsamples < 1 )
715         CV_Error( CV_StsBadArg, "Invalid CvRTrees object" );
716
717     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_RTREES );
718
719     cvWriteInt( fs, "nclasses", nclasses );
720     cvWriteInt( fs, "nsamples", nsamples );
721     cvWriteInt( fs, "nactive_vars", (int)cvSum(active_var_mask).val[0] );
722     cvWriteReal( fs, "oob_error", oob_error );
723
724     if( var_importance )
725         cvWrite( fs, "var_importance", var_importance );
726
727     cvWriteInt( fs, "ntrees", ntrees );
728
729     data->write_params( fs );
730
731     cvStartWriteStruct( fs, "trees", CV_NODE_SEQ );
732
733     for( k = 0; k < ntrees; k++ )
734     {
735         cvStartWriteStruct( fs, 0, CV_NODE_MAP );
736         trees[k]->write( fs );
737         cvEndWriteStruct( fs );
738     }
739
740     cvEndWriteStruct( fs ); //trees
741     cvEndWriteStruct( fs ); //CV_TYPE_NAME_ML_RTREES
742 }
743
744
745 void CvRTrees::read( CvFileStorage* fs, CvFileNode* fnode )
746 {
747     int nactive_vars, var_count, k;
748     CvSeqReader reader;
749     CvFileNode* trees_fnode = 0;
750
751     clear();
752
753     nclasses     = cvReadIntByName( fs, fnode, "nclasses", -1 );
754     nsamples     = cvReadIntByName( fs, fnode, "nsamples" );
755     nactive_vars = cvReadIntByName( fs, fnode, "nactive_vars", -1 );
756     oob_error    = cvReadRealByName(fs, fnode, "oob_error", -1 );
757     ntrees       = cvReadIntByName( fs, fnode, "ntrees", -1 );
758
759     var_importance = (CvMat*)cvReadByName( fs, fnode, "var_importance" );
760
761     if( nclasses < 0 || nsamples <= 0 || nactive_vars < 0 || oob_error < 0 || ntrees <= 0)
762         CV_Error( CV_StsParseError, "Some <nclasses>, <nsamples>, <var_count>, "
763         "<nactive_vars>, <oob_error>, <ntrees> of tags are missing" );
764
765     rng = &cv::theRNG();
766
767     trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*ntrees );
768     memset( trees, 0, sizeof(trees[0])*ntrees );
769
770     data = new CvDTreeTrainData();
771     data->read_params( fs, fnode );
772     data->shared = true;
773
774     trees_fnode = cvGetFileNodeByName( fs, fnode, "trees" );
775     if( !trees_fnode || !CV_NODE_IS_SEQ(trees_fnode->tag) )
776         CV_Error( CV_StsParseError, "<trees> tag is missing" );
777
778     cvStartReadSeq( trees_fnode->data.seq, &reader );
779     if( reader.seq->total != ntrees )
780         CV_Error( CV_StsParseError,
781         "<ntrees> is not equal to the number of trees saved in file" );
782
783     for( k = 0; k < ntrees; k++ )
784     {
785         trees[k] = new CvForestTree();
786         trees[k]->read( fs, (CvFileNode*)reader.ptr, this, data );
787         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
788     }
789
790     var_count = data->var_count;
791     active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 );
792     {
793         // initialize active variables mask
794         CvMat submask1, submask2;
795         cvGetCols( active_var_mask, &submask1, 0, nactive_vars );
796         cvGetCols( active_var_mask, &submask2, nactive_vars, var_count );
797         cvSet( &submask1, cvScalar(1) );
798         cvZero( &submask2 );
799     }
800 }
801
802
803 int CvRTrees::get_tree_count() const
804 {
805     return ntrees;
806 }
807
808 CvForestTree* CvRTrees::get_tree(int i) const
809 {
810     return (unsigned)i < (unsigned)ntrees ? trees[i] : 0;
811 }
812
813 using namespace cv;
814
815 bool CvRTrees::train( const Mat& _train_data, int _tflag,
816                      const Mat& _responses, const Mat& _var_idx,
817                      const Mat& _sample_idx, const Mat& _var_type,
818                      const Mat& _missing_mask, CvRTParams _params )
819 {
820     CvMat tdata = _train_data, responses = _responses, vidx = _var_idx,
821     sidx = _sample_idx, vtype = _var_type, mmask = _missing_mask;
822     return train(&tdata, _tflag, &responses, vidx.data.ptr ? &vidx : 0,
823                  sidx.data.ptr ? &sidx : 0, vtype.data.ptr ? &vtype : 0,
824                  mmask.data.ptr ? &mmask : 0, _params);
825 }
826
827
828 float CvRTrees::predict( const Mat& _sample, const Mat& _missing ) const
829 {
830     CvMat sample = _sample, mmask = _missing;
831     return predict(&sample, mmask.data.ptr ? &mmask : 0);
832 }
833
834 float CvRTrees::predict_prob( const Mat& _sample, const Mat& _missing) const
835 {
836     CvMat sample = _sample, mmask = _missing;
837     return predict_prob(&sample, mmask.data.ptr ? &mmask : 0);
838 }
839
840 Mat CvRTrees::getVarImportance()
841 {
842     return Mat(get_var_importance());
843 }
844
845 // End of file.