Add OpenCV source code
[platform/upstream/opencv.git] / modules / ml / src / ertrees.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2
3   IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4
5   By downloading, copying, installing or using the software you agree to this license.
6   If you do not agree to this license, do not download, install,
7   copy or use the software.
8
9
10                         Intel License Agreement
11
12  Copyright (C) 2000, Intel Corporation, all rights reserved.
13  Third party copyrights are property of their respective owners.
14
15  Redistribution and use in source and binary forms, with or without modification,
16  are permitted provided that the following conditions are met:
17
18    * Redistribution's of source code must retain the above copyright notice,
19      this list of conditions and the following disclaimer.
20
21    * Redistribution's in binary form must reproduce the above copyright notice,
22      this list of conditions and the following disclaimer in the documentation
23      and/or other materials provided with the distribution.
24
25    * The name of Intel Corporation may not be used to endorse or promote products
26      derived from this software without specific prior written permission.
27
28  This software is provided by the copyright holders and contributors "as is" and
29  any express or implied warranties, including, but not limited to, the implied
30  warranties of merchantability and fitness for a particular purpose are disclaimed.
31  In no event shall the Intel Corporation or contributors be liable for any direct,
32  indirect, incidental, special, exemplary, or consequential damages
33  (including, but not limited to, procurement of substitute goods or services;
34  loss of use, data, or profits; or business interruption) however caused
35  and on any theory of liability, whether in contract, strict liability,
36  or tort (including negligence or otherwise) arising in any way out of
37  the use of this software, even if advised of the possibility of such damage.
38
39 M*/
40
41 #include "precomp.hpp"
42
43 static const float ord_nan = FLT_MAX*0.5f;
44 static const int min_block_size = 1 << 16;
45 static const int block_size_delta = 1 << 10;
46
47 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
48 static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )
49
50 #define CV_CMP_PAIRS(a,b) (*((a).i) < *((b).i))
51 static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair16u32s, CV_CMP_PAIRS, int )
52
53 ///
54
55 void CvERTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
56     const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
57     const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
58     bool _shared, bool _add_labels, bool _update_data )
59 {
60     CvMat* sample_indices = 0;
61     CvMat* var_type0 = 0;
62     CvMat* tmp_map = 0;
63     int** int_ptr = 0;
64     CvPair16u32s* pair16u32s_ptr = 0;
65     CvDTreeTrainData* data = 0;
66     float *_fdst = 0;
67     int *_idst = 0;
68     unsigned short* udst = 0;
69     int* idst = 0;
70
71     CV_FUNCNAME( "CvERTreeTrainData::set_data" );
72
73     __BEGIN__;
74
75     int sample_all = 0, r_type, cv_n;
76     int total_c_count = 0;
77     int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
78     int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
79     int vi, i, size;
80     char err[100];
81     const int *sidx = 0, *vidx = 0;
82
83     uint64 effective_buf_size = 0;
84     int effective_buf_height = 0, effective_buf_width = 0;
85
86     if ( _params.use_surrogates )
87         CV_ERROR(CV_StsBadArg, "CvERTrees do not support surrogate splits");
88
89     if( _update_data && data_root )
90     {
91         CV_ERROR(CV_StsBadArg, "CvERTrees do not support data update");
92     }
93
94     clear();
95
96     var_all = 0;
97     rng = &cv::theRNG();
98
99     CV_CALL( set_params( _params ));
100
101     // check parameter types and sizes
102     CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
103
104     train_data = _train_data;
105     responses = _responses;
106     missing_mask = _missing_mask;
107
108     if( _tflag == CV_ROW_SAMPLE )
109     {
110         ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
111         dv_step = 1;
112         if( _missing_mask )
113             ms_step = _missing_mask->step, mv_step = 1;
114     }
115     else
116     {
117         dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
118         ds_step = 1;
119         if( _missing_mask )
120             mv_step = _missing_mask->step, ms_step = 1;
121     }
122     tflag = _tflag;
123
124     sample_count = sample_all;
125     var_count = var_all;
126
127     if( _sample_idx )
128     {
129         CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));
130         sidx = sample_indices->data.i;
131         sample_count = sample_indices->rows + sample_indices->cols - 1;
132     }
133
134     if( _var_idx )
135     {
136         CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
137         vidx = var_idx->data.i;
138         var_count = var_idx->rows + var_idx->cols - 1;
139     }
140
141     if( !CV_IS_MAT(_responses) ||
142         (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
143          CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
144         (_responses->rows != 1 && _responses->cols != 1) ||
145         _responses->rows + _responses->cols - 1 != sample_all )
146         CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
147                   "floating-point vector containing as many elements as "
148                   "the total number of samples in the training data matrix" );
149
150     is_buf_16u = false;
151     if ( sample_count < 65536 )
152         is_buf_16u = true;
153
154     r_type = CV_VAR_CATEGORICAL;
155     if( _var_type )
156         CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_count, &r_type ));
157
158     CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
159
160     cat_var_count = 0;
161     ord_var_count = -1;
162
163     is_classifier = r_type == CV_VAR_CATEGORICAL;
164
165     // step 0. calc the number of categorical vars
166     for( vi = 0; vi < var_count; vi++ )
167     {
168         char vt = var_type0 ? var_type0->data.ptr[vi] : CV_VAR_ORDERED;
169         var_type->data.i[vi] = vt == CV_VAR_CATEGORICAL ? cat_var_count++ : ord_var_count--;
170     }
171
172     ord_var_count = ~ord_var_count;
173     cv_n = params.cv_folds;
174     // set the two last elements of var_type array to be able
175     // to locate responses and cross-validation labels using
176     // the corresponding get_* functions.
177     var_type->data.i[var_count] = cat_var_count;
178     var_type->data.i[var_count+1] = cat_var_count+1;
179
180     // in case of single ordered predictor we need dummy cv_labels
181     // for safe split_node_data() operation
182     have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;
183
184     work_var_count = cat_var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0);
185
186     shared = _shared;
187     buf_count = shared ? 2 : 1;
188
189     buf_size = -1; // the member buf_size is obsolete
190
191     effective_buf_size = (uint64)(work_var_count + 1)*(uint64)sample_count * buf_count; // this is the total size of "CvMat buf" to be allocated
192     effective_buf_width = sample_count;
193     effective_buf_height = work_var_count+1;
194
195     if (effective_buf_width >= effective_buf_height)
196         effective_buf_height *= buf_count;
197     else
198         effective_buf_width *= buf_count;
199
200     if ((uint64)effective_buf_width * (uint64)effective_buf_height != effective_buf_size)
201     {
202         CV_Error(CV_StsBadArg, "The memory buffer cannot be allocated since its size exceeds integer fields limit");
203     }
204
205     if ( is_buf_16u )
206     {
207         CV_CALL( buf = cvCreateMat( effective_buf_height, effective_buf_width, CV_16UC1 ));
208         CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));
209     }
210     else
211     {
212         CV_CALL( buf = cvCreateMat( effective_buf_height, effective_buf_width, CV_32SC1 ));
213         CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
214     }
215
216     size = is_classifier ? cat_var_count+1 : cat_var_count;
217     size = !size ? 1 : size;
218     CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));
219     CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));
220
221     size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;
222     size = !size ? 1 : size;
223     CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));
224
225     // now calculate the maximum size of split,
226     // create memory storage that will keep nodes and splits of the decision tree
227     // allocate root node and the buffer for the whole training data
228     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
229         (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
230     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
231     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
232     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
233     CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
234
235     nv_size = var_count*sizeof(int);
236     nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));
237
238     temp_block_size = nv_size;
239
240     if( cv_n )
241     {
242         if( sample_count < cv_n*MAX(params.min_sample_count,10) )
243             CV_ERROR( CV_StsOutOfRange,
244                 "The many folds in cross-validation for such a small dataset" );
245
246         cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
247         temp_block_size = MAX(temp_block_size, cv_size);
248     }
249
250     temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
251     CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
252     CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
253     if( cv_size )
254         CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
255
256     CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
257
258     max_c_count = 1;
259
260     _fdst = 0;
261     _idst = 0;
262     if (ord_var_count)
263         _fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));
264     if (is_buf_16u && (cat_var_count || is_classifier))
265         _idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));
266
267     // transform the training data to convenient representation
268     for( vi = 0; vi <= var_count; vi++ )
269     {
270         int ci;
271         const uchar* mask = 0;
272         int m_step = 0, step;
273         const int* idata = 0;
274         const float* fdata = 0;
275         int num_valid = 0;
276
277         if( vi < var_count ) // analyze i-th input variable
278         {
279             int vi0 = vidx ? vidx[vi] : vi;
280             ci = get_var_type(vi);
281             step = ds_step; m_step = ms_step;
282             if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
283                 idata = _train_data->data.i + vi0*dv_step;
284             else
285                 fdata = _train_data->data.fl + vi0*dv_step;
286             if( _missing_mask )
287                 mask = _missing_mask->data.ptr + vi0*mv_step;
288         }
289         else // analyze _responses
290         {
291             ci = cat_var_count;
292             step = CV_IS_MAT_CONT(_responses->type) ?
293                 1 : _responses->step / CV_ELEM_SIZE(_responses->type);
294             if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
295                 idata = _responses->data.i;
296             else
297                 fdata = _responses->data.fl;
298         }
299
300         if( (vi < var_count && ci>=0) ||
301             (vi == var_count && is_classifier) ) // process categorical variable or response
302         {
303             int c_count, prev_label;
304             int* c_map;
305
306             if (is_buf_16u)
307                 udst = (unsigned short*)(buf->data.s + ci*sample_count);
308             else
309                 idst = buf->data.i + ci*sample_count;
310
311             // copy data
312             for( i = 0; i < sample_count; i++ )
313             {
314                 int val = INT_MAX, si = sidx ? sidx[i] : i;
315                 if( !mask || !mask[(size_t)si*m_step] )
316                 {
317                     if( idata )
318                         val = idata[(size_t)si*step];
319                     else
320                     {
321                         float t = fdata[(size_t)si*step];
322                         val = cvRound(t);
323                         if( val != t )
324                         {
325                             sprintf( err, "%d-th value of %d-th (categorical) "
326                                 "variable is not an integer", i, vi );
327                             CV_ERROR( CV_StsBadArg, err );
328                         }
329                     }
330
331                     if( val == INT_MAX )
332                     {
333                         sprintf( err, "%d-th value of %d-th (categorical) "
334                             "variable is too large", i, vi );
335                         CV_ERROR( CV_StsBadArg, err );
336                     }
337                     num_valid++;
338                 }
339                 if (is_buf_16u)
340                 {
341                     _idst[i] = val;
342                     pair16u32s_ptr[i].u = udst + i;
343                     pair16u32s_ptr[i].i = _idst + i;
344                 }
345                 else
346                 {
347                     idst[i] = val;
348                     int_ptr[i] = idst + i;
349                 }
350             }
351
352             c_count = num_valid > 0;
353
354             if (is_buf_16u)
355             {
356                 icvSortPairs( pair16u32s_ptr, sample_count, 0 );
357                 // count the categories
358                 for( i = 1; i < num_valid; i++ )
359                     if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)
360                         c_count ++ ;
361             }
362             else
363             {
364                 icvSortIntPtr( int_ptr, sample_count, 0 );
365                 // count the categories
366                 for( i = 1; i < num_valid; i++ )
367                     c_count += *int_ptr[i] != *int_ptr[i-1];
368             }
369
370             if( vi > 0 )
371                 max_c_count = MAX( max_c_count, c_count );
372             cat_count->data.i[ci] = c_count;
373             cat_ofs->data.i[ci] = total_c_count;
374
375             // resize cat_map, if need
376             if( cat_map->cols < total_c_count + c_count )
377             {
378                 tmp_map = cat_map;
379                 CV_CALL( cat_map = cvCreateMat( 1,
380                     MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
381                 for( i = 0; i < total_c_count; i++ )
382                     cat_map->data.i[i] = tmp_map->data.i[i];
383                 cvReleaseMat( &tmp_map );
384             }
385
386             c_map = cat_map->data.i + total_c_count;
387             total_c_count += c_count;
388
389             c_count = -1;
390             if (is_buf_16u)
391             {
392                 // compact the class indices and build the map
393                 prev_label = ~*pair16u32s_ptr[0].i;
394                 for( i = 0; i < num_valid; i++ )
395                 {
396                     int cur_label = *pair16u32s_ptr[i].i;
397                     if( cur_label != prev_label )
398                         c_map[++c_count] = prev_label = cur_label;
399                     *pair16u32s_ptr[i].u = (unsigned short)c_count;
400                 }
401                 // replace labels for missing values with 65535
402                 for( ; i < sample_count; i++ )
403                     *pair16u32s_ptr[i].u = 65535;
404             }
405             else
406             {
407                 // compact the class indices and build the map
408                 prev_label = ~*int_ptr[0];
409                 for( i = 0; i < num_valid; i++ )
410                 {
411                     int cur_label = *int_ptr[i];
412                     if( cur_label != prev_label )
413                         c_map[++c_count] = prev_label = cur_label;
414                     *int_ptr[i] = c_count;
415                 }
416                 // replace labels for missing values with -1
417                 for( ; i < sample_count; i++ )
418                     *int_ptr[i] = -1;
419             }
420         }
421         else if( ci < 0 ) // process ordered variable
422         {
423             for( i = 0; i < sample_count; i++ )
424             {
425                 float val = ord_nan;
426                 int si = sidx ? sidx[i] : i;
427                 if( !mask || !mask[(size_t)si*m_step] )
428                 {
429                     if( idata )
430                         val = (float)idata[(size_t)si*step];
431                     else
432                         val = fdata[(size_t)si*step];
433
434                     if( fabs(val) >= ord_nan )
435                     {
436                         sprintf( err, "%d-th value of %d-th (ordered) "
437                             "variable (=%g) is too large", i, vi, val );
438                         CV_ERROR( CV_StsBadArg, err );
439                     }
440                     num_valid++;
441                 }
442             }
443         }
444         if( vi < var_count )
445             data_root->set_num_valid(vi, num_valid);
446     }
447
448     // set sample labels
449     if (is_buf_16u)
450         udst = (unsigned short*)(buf->data.s + get_work_var_count()*sample_count);
451     else
452         idst = buf->data.i + get_work_var_count()*sample_count;
453
454     for (i = 0; i < sample_count; i++)
455     {
456         if (udst)
457             udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;
458         else
459             idst[i] = sidx ? sidx[i] : i;
460     }
461
462     if( cv_n )
463     {
464         unsigned short* usdst = 0;
465         int* idst2 = 0;
466
467         if (is_buf_16u)
468         {
469             usdst = (unsigned short*)(buf->data.s + (get_work_var_count()-1)*sample_count);
470             for( i = vi = 0; i < sample_count; i++ )
471             {
472                 usdst[i] = (unsigned short)vi++;
473                 vi &= vi < cv_n ? -1 : 0;
474             }
475
476             for( i = 0; i < sample_count; i++ )
477             {
478                 int a = (*rng)(sample_count);
479                 int b = (*rng)(sample_count);
480                 unsigned short unsh = (unsigned short)vi;
481                 CV_SWAP( usdst[a], usdst[b], unsh );
482             }
483         }
484         else
485         {
486             idst2 = buf->data.i + (get_work_var_count()-1)*sample_count;
487             for( i = vi = 0; i < sample_count; i++ )
488             {
489                 idst2[i] = vi++;
490                 vi &= vi < cv_n ? -1 : 0;
491             }
492
493             for( i = 0; i < sample_count; i++ )
494             {
495                 int a = (*rng)(sample_count);
496                 int b = (*rng)(sample_count);
497                 CV_SWAP( idst2[a], idst2[b], vi );
498             }
499         }
500     }
501
502     if ( cat_map )
503         cat_map->cols = MAX( total_c_count, 1 );
504
505     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
506         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
507     CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
508
509     have_priors = is_classifier && params.priors;
510     if( is_classifier )
511     {
512         int m = get_num_classes();
513         double sum = 0;
514         CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
515         for( i = 0; i < m; i++ )
516         {
517             double val = have_priors ? params.priors[i] : 1.;
518             if( val <= 0 )
519                 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
520             priors->data.db[i] = val;
521             sum += val;
522         }
523
524         // normalize weights
525         if( have_priors )
526             cvScale( priors, priors, 1./sum );
527
528         CV_CALL( priors_mult = cvCloneMat( priors ));
529         CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
530     }
531
532     CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
533     CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
534
535     __END__;
536
537     if( data )
538         delete data;
539
540     if (_fdst)
541         cvFree( &_fdst );
542     if (_idst)
543         cvFree( &_idst );
544     cvFree( &int_ptr );
545     cvReleaseMat( &var_type0 );
546     cvReleaseMat( &sample_indices );
547     cvReleaseMat( &tmp_map );
548 }
549
550 void CvERTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf,
551                                           const float** ord_values, const int** missing, int* sample_indices_buf )
552 {
553     int vidx = var_idx ? var_idx->data.i[vi] : vi;
554     int node_sample_count = n->sample_count;
555     // may use missing_buf as buffer for sample indices!
556     const int* sample_indices = get_sample_indices(n, sample_indices_buf ? sample_indices_buf : missing_buf);
557
558     int td_step = train_data->step/CV_ELEM_SIZE(train_data->type);
559     int m_step = missing_mask ? missing_mask->step/CV_ELEM_SIZE(missing_mask->type) : 1;
560     if( tflag == CV_ROW_SAMPLE )
561     {
562         for( int i = 0; i < node_sample_count; i++ )
563         {
564             int idx = sample_indices[i];
565             missing_buf[i] = missing_mask ? *(missing_mask->data.ptr + idx * m_step + vi) : 0;
566             ord_values_buf[i] = *(train_data->data.fl + idx * td_step + vidx);
567         }
568     }
569     else
570         for( int i = 0; i < node_sample_count; i++ )
571         {
572             int idx = sample_indices[i];
573             missing_buf[i] = missing_mask ? *(missing_mask->data.ptr + vi* m_step + idx) : 0;
574             ord_values_buf[i] = *(train_data->data.fl + vidx* td_step + idx);
575         }
576     *ord_values = ord_values_buf;
577     *missing = missing_buf;
578 }
579
580
581 const int* CvERTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf )
582 {
583     return get_cat_var_data( n, var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0), indices_buf );
584 }
585
586
587 const int* CvERTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf )
588 {
589     if (have_labels)
590         return get_cat_var_data( n, var_count + (is_classifier ? 1 : 0), labels_buf );
591     return 0;
592 }
593
594
595 const int* CvERTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf )
596 {
597     int ci = get_var_type( vi);
598     const int* cat_values = 0;
599     if( !is_buf_16u )
600         cat_values = buf->data.i + n->buf_idx*get_length_subbuf() + ci*sample_count + n->offset;
601     else {
602         const unsigned short* short_values = (const unsigned short*)(buf->data.s + n->buf_idx*get_length_subbuf() +
603             ci*sample_count + n->offset);
604         for( int i = 0; i < n->sample_count; i++ )
605             cat_values_buf[i] = short_values[i];
606         cat_values = cat_values_buf;
607     }
608     return cat_values;
609 }
610
611 void CvERTreeTrainData::get_vectors( const CvMat* _subsample_idx,
612                                     float* values, uchar* missing,
613                                     float* _responses, bool get_class_idx )
614 {
615     CvMat* subsample_idx = 0;
616     CvMat* subsample_co = 0;
617
618     cv::AutoBuffer<uchar> inn_buf(sample_count*(sizeof(float) + sizeof(int)));
619
620     CV_FUNCNAME( "CvERTreeTrainData::get_vectors" );
621
622     __BEGIN__;
623
624     int i, vi, total = sample_count, count = total, cur_ofs = 0;
625     int* sidx = 0;
626     int* co = 0;
627
628     if( _subsample_idx )
629     {
630         CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
631         sidx = subsample_idx->data.i;
632         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
633         co = subsample_co->data.i;
634         cvZero( subsample_co );
635         count = subsample_idx->cols + subsample_idx->rows - 1;
636         for( i = 0; i < count; i++ )
637             co[sidx[i]*2]++;
638         for( i = 0; i < total; i++ )
639         {
640             int count_i = co[i*2];
641             if( count_i )
642             {
643                 co[i*2+1] = cur_ofs*var_count;
644                 cur_ofs += count_i;
645             }
646         }
647     }
648
649     if( missing )
650         memset( missing, 1, count*var_count );
651
652     for( vi = 0; vi < var_count; vi++ )
653     {
654         int ci = get_var_type(vi);
655         if( ci >= 0 ) // categorical
656         {
657             float* dst = values + vi;
658             uchar* m = missing ? missing + vi : 0;
659             int* lbls_buf = (int*)(uchar*)inn_buf;
660             const int* src = get_cat_var_data(data_root, vi, lbls_buf);
661
662             for( i = 0; i < count; i++, dst += var_count )
663             {
664                 int idx = sidx ? sidx[i] : i;
665                 int val = src[idx];
666                 *dst = (float)val;
667                 if( m )
668                 {
669                     *m = (!is_buf_16u && val < 0) || (is_buf_16u && (val == 65535));
670                     m += var_count;
671                 }
672             }
673         }
674         else // ordered
675         {
676             int* mis_buf = (int*)(uchar*)inn_buf;
677             const float *dst = 0;
678             const int* mis = 0;
679             get_ord_var_data(data_root, vi, values + vi, mis_buf, &dst, &mis, 0);
680             for (int si = 0; si < total; si++)
681                 *(missing + vi + si) = mis[si] == 0 ? 0 : 1;
682         }
683     }
684
685     // copy responses
686     if( _responses )
687     {
688         if( is_classifier )
689         {
690             int* lbls_buf = (int*)(uchar*)inn_buf;
691             const int* src = get_class_labels(data_root, lbls_buf);
692             for( i = 0; i < count; i++ )
693             {
694                 int idx = sidx ? sidx[i] : i;
695                 int val = get_class_idx ? src[idx] :
696                     cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
697                 _responses[i] = (float)val;
698             }
699         }
700         else
701         {
702             float* _values_buf = (float*)(uchar*)inn_buf;
703             int* sample_idx_buf = (int*)(_values_buf + sample_count);
704             const float* _values = get_ord_responses(data_root, _values_buf, sample_idx_buf);
705             for( i = 0; i < count; i++ )
706             {
707                 int idx = sidx ? sidx[i] : i;
708                 _responses[i] = _values[idx];
709             }
710         }
711     }
712
713     __END__;
714
715     cvReleaseMat( &subsample_idx );
716     cvReleaseMat( &subsample_co );
717 }
718
719 CvDTreeNode* CvERTreeTrainData::subsample_data( const CvMat* _subsample_idx )
720 {
721     CvDTreeNode* root = 0;
722
723     CV_FUNCNAME( "CvERTreeTrainData::subsample_data" );
724
725     __BEGIN__;
726
727     if( !data_root )
728         CV_ERROR( CV_StsError, "No training data has been set" );
729
730     if( !_subsample_idx )
731     {
732         // make a copy of the root node
733         CvDTreeNode temp;
734         int i;
735         root = new_node( 0, 1, 0, 0 );
736         temp = *root;
737         *root = *data_root;
738         root->num_valid = temp.num_valid;
739         if( root->num_valid )
740         {
741             for( i = 0; i < var_count; i++ )
742                 root->num_valid[i] = data_root->num_valid[i];
743         }
744         root->cv_Tn = temp.cv_Tn;
745         root->cv_node_risk = temp.cv_node_risk;
746         root->cv_node_error = temp.cv_node_error;
747     }
748     else
749         CV_ERROR( CV_StsError, "_subsample_idx must be null for extra-trees" );
750     __END__;
751
752     return root;
753 }
754
755 double CvForestERTree::calc_node_dir( CvDTreeNode* node )
756 {
757     char* dir = (char*)data->direction->data.ptr;
758     int i, n = node->sample_count, vi = node->split->var_idx;
759     double L, R;
760
761     assert( !node->split->inversed );
762
763     if( data->get_var_type(vi) >= 0 ) // split on categorical var
764     {
765         cv::AutoBuffer<uchar> inn_buf(n*sizeof(int)*(!data->have_priors ? 1 : 2));
766         int* labels_buf = (int*)(uchar*)inn_buf;
767         const int* labels = data->get_cat_var_data( node, vi, labels_buf );
768         const int* subset = node->split->subset;
769         if( !data->have_priors )
770         {
771             int sum = 0, sum_abs = 0;
772
773             for( i = 0; i < n; i++ )
774             {
775                 int idx = labels[i];
776                 int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?
777                     CV_DTREE_CAT_DIR(idx,subset) : 0;
778                 sum += d; sum_abs += d & 1;
779                 dir[i] = (char)d;
780             }
781
782             R = (sum_abs + sum) >> 1;
783             L = (sum_abs - sum) >> 1;
784         }
785         else
786         {
787             const double* priors = data->priors_mult->data.db;
788             double sum = 0, sum_abs = 0;
789             int *responses_buf = labels_buf + n;
790             const int* responses = data->get_class_labels(node, responses_buf);
791
792             for( i = 0; i < n; i++ )
793             {
794                 int idx = labels[i];
795                 double w = priors[responses[i]];
796                 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
797                 sum += d*w; sum_abs += (d & 1)*w;
798                 dir[i] = (char)d;
799             }
800
801             R = (sum_abs + sum) * 0.5;
802             L = (sum_abs - sum) * 0.5;
803         }
804     }
805     else // split on ordered var
806     {
807         float split_val = node->split->ord.c;
808         cv::AutoBuffer<uchar> inn_buf(n*(sizeof(int)*(!data->have_priors ? 1 : 2) + sizeof(float)));
809         float* val_buf = (float*)(uchar*)inn_buf;
810         int* missing_buf = (int*)(val_buf + n);
811         const float* val = 0;
812         const int* missing = 0;
813         data->get_ord_var_data( node, vi, val_buf, missing_buf, &val, &missing, 0 );
814
815         if( !data->have_priors )
816         {
817             L = R = 0;
818             for( i = 0; i < n; i++ )
819             {
820                 if ( missing[i] )
821                     dir[i] = (char)0;
822                 else
823                 {
824                     if ( val[i] < split_val)
825                     {
826                         dir[i] = (char)-1;
827                         L++;
828                     }
829                     else
830                     {
831                         dir[i] = (char)1;
832                         R++;
833                     }
834                 }
835             }
836         }
837         else
838         {
839             const double* priors = data->priors_mult->data.db;
840             int* responses_buf = missing_buf + n;
841             const int* responses = data->get_class_labels(node, responses_buf);
842             L = R = 0;
843             for( i = 0; i < n; i++ )
844             {
845                 if ( missing[i] )
846                     dir[i] = (char)0;
847                 else
848                 {
849                     double w = priors[responses[i]];
850                     if ( val[i] < split_val)
851                     {
852                         dir[i] = (char)-1;
853                          L += w;
854                     }
855                     else
856                     {
857                         dir[i] = (char)1;
858                         R += w;
859                     }
860                 }
861             }
862         }
863     }
864
865     node->maxlr = MAX( L, R );
866     return node->split->quality/(L + R);
867 }
868
869 CvDTreeSplit* CvForestERTree::find_split_ord_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split,
870                                                     uchar* _ext_buf )
871 {
872     const float epsilon = FLT_EPSILON*2;
873     const float split_delta = (1 + FLT_EPSILON) * FLT_EPSILON;
874
875     int n = node->sample_count;
876     int m = data->get_num_classes();
877
878     cv::AutoBuffer<uchar> inn_buf;
879     if( !_ext_buf )
880         inn_buf.allocate(n*(2*sizeof(int) + sizeof(float)));
881     uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
882     float* values_buf = (float*)ext_buf;
883     int* missing_buf = (int*)(values_buf + n);
884     const float* values = 0;
885     const int* missing = 0;
886     data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing, 0 );
887     int* responses_buf = missing_buf + n;
888     const int* responses = data->get_class_labels( node, responses_buf );
889
890     double lbest_val = 0, rbest_val = 0, best_val = init_quality, split_val = 0;
891     const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
892     bool is_find_split = false;
893     float pmin, pmax;
894     int smpi = 0;
895     while ( missing[smpi] && (smpi < n) )
896         smpi++;
897     assert(smpi < n);
898
899     pmin = values[smpi];
900     pmax = pmin;
901     for (; smpi < n; smpi++)
902     {
903         float ptemp = values[smpi];
904         int ms = missing[smpi];
905         if (ms) continue;
906         if ( ptemp < pmin)
907             pmin = ptemp;
908         if ( ptemp > pmax)
909             pmax = ptemp;
910     }
911     float fdiff = pmax-pmin;
912     if (fdiff > epsilon)
913     {
914         is_find_split = true;
915         cv::RNG* rng = data->rng;
916         split_val = pmin + rng->uniform(0.f, 1.f) * fdiff ;
917         if (split_val - pmin <= FLT_EPSILON)
918             split_val = pmin + split_delta;
919         if (pmax - split_val <= FLT_EPSILON)
920             split_val = pmax - split_delta;
921
922         // calculate Gini index
923         if ( !priors )
924         {
925             cv::AutoBuffer<int> lrc(m*2);
926             int *lc = lrc, *rc = lc + m;
927             int L = 0, R = 0;
928
929             // init arrays of class instance counters on both sides of the split
930             for(int i = 0; i < m; i++ )
931             {
932                 lc[i] = 0;
933                 rc[i] = 0;
934             }
935             for( int si = 0; si < n; si++ )
936             {
937                 int r = responses[si];
938                 float val = values[si];
939                 int ms = missing[si];
940                 if (ms) continue;
941                 if ( val < split_val )
942                 {
943                     lc[r]++;
944                     L++;
945                 }
946                 else
947                 {
948                     rc[r]++;
949                     R++;
950                 }
951             }
952             for (int i = 0; i < m; i++)
953             {
954                 lbest_val += lc[i]*lc[i];
955                 rbest_val += rc[i]*rc[i];
956             }
957             best_val = (lbest_val*R + rbest_val*L) / ((double)(L*R));
958         }
959         else
960         {
961             cv::AutoBuffer<double> lrc(m*2);
962             double *lc = lrc, *rc = lc + m;
963             double L = 0, R = 0;
964
965             // init arrays of class instance counters on both sides of the split
966             for(int i = 0; i < m; i++ )
967             {
968                 lc[i] = 0;
969                 rc[i] = 0;
970             }
971             for( int si = 0; si < n; si++ )
972             {
973                 int r = responses[si];
974                 float val = values[si];
975                 int ms = missing[si];
976                 double p = priors[r];
977                 if (ms) continue;
978                 if ( val < split_val )
979                 {
980                     lc[r] += p;
981                     L += p;
982                 }
983                 else
984                 {
985                     rc[r] += p;
986                     R += p;
987                 }
988             }
989             for (int i = 0; i < m; i++)
990             {
991                 lbest_val += lc[i]*lc[i];
992                 rbest_val += rc[i]*rc[i];
993             }
994             best_val = (lbest_val*R + rbest_val*L) / (L*R);
995         }
996
997     }
998
999     CvDTreeSplit* split = 0;
1000     if( is_find_split )
1001     {
1002         split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
1003         split->var_idx = vi;
1004         split->ord.c = (float)split_val;
1005         split->ord.split_point = -1;
1006         split->inversed = 0;
1007         split->quality = (float)best_val;
1008     }
1009     return split;
1010 }
1011
1012 CvDTreeSplit* CvForestERTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split,
1013                                                     uchar* _ext_buf )
1014 {
1015     int ci = data->get_var_type(vi);
1016     int n = node->sample_count;
1017     int cm = data->get_num_classes();
1018     int vm = data->cat_count->data.i[ci];
1019     double best_val = init_quality;
1020     CvDTreeSplit *split = 0;
1021
1022     if ( vm > 1 )
1023     {
1024         cv::AutoBuffer<int> inn_buf;
1025         if( !_ext_buf )
1026             inn_buf.allocate(2*n);
1027         int* ext_buf = _ext_buf ? (int*)_ext_buf : (int*)inn_buf;
1028
1029         const int* labels = data->get_cat_var_data( node, vi, ext_buf );
1030         const int* responses = data->get_class_labels( node, ext_buf + n );
1031
1032         const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
1033
1034         // create random class mask
1035         cv::AutoBuffer<int> valid_cidx(vm);
1036         for (int i = 0; i < vm; i++)
1037         {
1038             valid_cidx[i] = -1;
1039         }
1040         for (int si = 0; si < n; si++)
1041         {
1042             int c = labels[si];
1043             if ( ((c == 65535) && data->is_buf_16u) || ((c<0) && (!data->is_buf_16u)) )
1044                 continue;
1045             valid_cidx[c]++;
1046         }
1047
1048         int valid_ccount = 0;
1049         for (int i = 0; i < vm; i++)
1050             if (valid_cidx[i] >= 0)
1051             {
1052                 valid_cidx[i] = valid_ccount;
1053                 valid_ccount++;
1054             }
1055         if (valid_ccount > 1)
1056         {
1057             CvRNG* rng = forest->get_rng();
1058             int l_cval_count = 1 + cvRandInt(rng) % (valid_ccount-1);
1059
1060             CvMat* var_class_mask = cvCreateMat( 1, valid_ccount, CV_8UC1 );
1061             CvMat submask;
1062             memset(var_class_mask->data.ptr, 0, valid_ccount*CV_ELEM_SIZE(var_class_mask->type));
1063             cvGetCols( var_class_mask, &submask, 0, l_cval_count );
1064             cvSet( &submask, cvScalar(1) );
1065             for (int i = 0; i < valid_ccount; i++)
1066             {
1067                 uchar temp;
1068                 int i1 =  cvRandInt( rng ) % valid_ccount;
1069                 int i2 = cvRandInt( rng ) % valid_ccount;
1070                 CV_SWAP( var_class_mask->data.ptr[i1], var_class_mask->data.ptr[i2], temp );
1071             }
1072
1073             split = _split ? _split : data->new_split_cat( 0, -1.0f );
1074             split->var_idx = vi;
1075             memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
1076
1077             // calculate Gini index
1078             double lbest_val = 0, rbest_val = 0;
1079             if( !priors )
1080             {
1081                 cv::AutoBuffer<int> lrc(cm*2);
1082                 int *lc = lrc, *rc = lc + cm;
1083                 int L = 0, R = 0;
1084                 // init arrays of class instance counters on both sides of the split
1085                 for(int i = 0; i < cm; i++ )
1086                 {
1087                     lc[i] = 0;
1088                     rc[i] = 0;
1089                 }
1090                 for( int si = 0; si < n; si++ )
1091                 {
1092                     int r = responses[si];
1093                     int var_class_idx = labels[si];
1094                     if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )
1095                         continue;
1096                     int mask_class_idx = valid_cidx[var_class_idx];
1097                     if (var_class_mask->data.ptr[mask_class_idx])
1098                     {
1099                         lc[r]++;
1100                         L++;
1101                         split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
1102                     }
1103                     else
1104                     {
1105                         rc[r]++;
1106                         R++;
1107                     }
1108                 }
1109                 for (int i = 0; i < cm; i++)
1110                 {
1111                     lbest_val += lc[i]*lc[i];
1112                     rbest_val += rc[i]*rc[i];
1113                 }
1114                 best_val = (lbest_val*R + rbest_val*L) / ((double)(L*R));
1115             }
1116             else
1117             {
1118                 cv::AutoBuffer<int> lrc(cm*2);
1119                 int *lc = lrc, *rc = lc + cm;
1120                 double L = 0, R = 0;
1121                 // init arrays of class instance counters on both sides of the split
1122                 for(int i = 0; i < cm; i++ )
1123                 {
1124                     lc[i] = 0;
1125                     rc[i] = 0;
1126                 }
1127                 for( int si = 0; si < n; si++ )
1128                 {
1129                     int r = responses[si];
1130                     int var_class_idx = labels[si];
1131                     if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )
1132                         continue;
1133                     double p = priors[si];
1134                     int mask_class_idx = valid_cidx[var_class_idx];
1135
1136                     if (var_class_mask->data.ptr[mask_class_idx])
1137                     {
1138                         lc[r]+=(int)p;
1139                         L+=p;
1140                         split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
1141                     }
1142                     else
1143                     {
1144                         rc[r]+=(int)p;
1145                         R+=p;
1146                     }
1147                 }
1148                 for (int i = 0; i < cm; i++)
1149                 {
1150                     lbest_val += lc[i]*lc[i];
1151                     rbest_val += rc[i]*rc[i];
1152                 }
1153                 best_val = (lbest_val*R + rbest_val*L) / (L*R);
1154             }
1155             split->quality = (float)best_val;
1156
1157             cvReleaseMat(&var_class_mask);
1158         }
1159     }
1160
1161     return split;
1162 }
1163
1164 CvDTreeSplit* CvForestERTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split,
1165                                                   uchar* _ext_buf )
1166 {
1167     const float epsilon = FLT_EPSILON*2;
1168     const float split_delta = (1 + FLT_EPSILON) * FLT_EPSILON;
1169     int n = node->sample_count;
1170     cv::AutoBuffer<uchar> inn_buf;
1171     if( !_ext_buf )
1172         inn_buf.allocate(n*(2*sizeof(int) + 2*sizeof(float)));
1173     uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
1174     float* values_buf = (float*)ext_buf;
1175     int* missing_buf = (int*)(values_buf + n);
1176     const float* values = 0;
1177     const int* missing = 0;
1178     data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing, 0 );
1179     float* responses_buf =  (float*)(missing_buf + n);
1180     int* sample_indices_buf =  (int*)(responses_buf + n);
1181     const float* responses = data->get_ord_responses( node, responses_buf, sample_indices_buf );
1182
1183     double best_val = init_quality, split_val = 0, lsum = 0, rsum = 0;
1184     int L = 0, R = 0;
1185
1186     bool is_find_split = false;
1187     float pmin, pmax;
1188     int smpi = 0;
1189     while ( missing[smpi] && (smpi < n) )
1190         smpi++;
1191
1192     assert(smpi < n);
1193
1194     pmin = values[smpi];
1195     pmax = pmin;
1196     for (; smpi < n; smpi++)
1197     {
1198         float ptemp = values[smpi];
1199         int m = missing[smpi];
1200         if (m) continue;
1201         if ( ptemp < pmin)
1202             pmin = ptemp;
1203         if ( ptemp > pmax)
1204             pmax = ptemp;
1205     }
1206     float fdiff = pmax-pmin;
1207     if (fdiff > epsilon)
1208     {
1209         is_find_split = true;
1210         cv::RNG* rng = data->rng;
1211         split_val = pmin + rng->uniform(0.f, 1.f) * fdiff ;
1212         if (split_val - pmin <= FLT_EPSILON)
1213             split_val = pmin + split_delta;
1214         if (pmax - split_val <= FLT_EPSILON)
1215             split_val = pmax - split_delta;
1216
1217         for (int si = 0; si < n; si++)
1218         {
1219             float r = responses[si];
1220             float val = values[si];
1221             int m = missing[si];
1222             if (m) continue;
1223             if (val < split_val)
1224             {
1225                 lsum += r;
1226                 L++;
1227             }
1228             else
1229             {
1230                 rsum += r;
1231                 R++;
1232             }
1233         }
1234         best_val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
1235     }
1236
1237     CvDTreeSplit* split = 0;
1238     if( is_find_split )
1239     {
1240         split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
1241         split->var_idx = vi;
1242         split->ord.c = (float)split_val;
1243         split->ord.split_point = -1;
1244         split->inversed = 0;
1245         split->quality = (float)best_val;
1246     }
1247     return split;
1248 }
1249
1250 CvDTreeSplit* CvForestERTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split,
1251                                                   uchar* _ext_buf )
1252 {
1253     int ci = data->get_var_type(vi);
1254     int n = node->sample_count;
1255     int vm = data->cat_count->data.i[ci];
1256     double best_val = init_quality;
1257     CvDTreeSplit *split = 0;
1258     float lsum = 0, rsum = 0;
1259
1260     if ( vm > 1 )
1261     {
1262         int base_size =  vm*sizeof(int);
1263         cv::AutoBuffer<uchar> inn_buf(base_size);
1264         if( !_ext_buf )
1265             inn_buf.allocate(base_size + n*(2*sizeof(int) + sizeof(float)));
1266         uchar* base_buf = (uchar*)inn_buf;
1267         uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
1268         int* labels_buf = (int*)ext_buf;
1269         const int* labels = data->get_cat_var_data( node, vi, labels_buf );
1270         float* responses_buf =  (float*)(labels_buf + n);
1271         int* sample_indices_buf = (int*)(responses_buf + n);
1272         const float* responses = data->get_ord_responses( node, responses_buf, sample_indices_buf );
1273
1274         // create random class mask
1275         int *valid_cidx = (int*)base_buf;
1276         for (int i = 0; i < vm; i++)
1277         {
1278             valid_cidx[i] = -1;
1279         }
1280         for (int si = 0; si < n; si++)
1281         {
1282             int c = labels[si];
1283             if ( ((c == 65535) && data->is_buf_16u) || ((c<0) && (!data->is_buf_16u)) )
1284                         continue;
1285             valid_cidx[c]++;
1286         }
1287
1288         int valid_ccount = 0;
1289         for (int i = 0; i < vm; i++)
1290             if (valid_cidx[i] >= 0)
1291             {
1292                 valid_cidx[i] = valid_ccount;
1293                 valid_ccount++;
1294             }
1295         if (valid_ccount > 1)
1296         {
1297             CvRNG* rng = forest->get_rng();
1298             int l_cval_count = 1 + cvRandInt(rng) % (valid_ccount-1);
1299
1300             CvMat* var_class_mask = cvCreateMat( 1, valid_ccount, CV_8UC1 );
1301             CvMat submask;
1302             memset(var_class_mask->data.ptr, 0, valid_ccount*CV_ELEM_SIZE(var_class_mask->type));
1303             cvGetCols( var_class_mask, &submask, 0, l_cval_count );
1304             cvSet( &submask, cvScalar(1) );
1305             for (int i = 0; i < valid_ccount; i++)
1306             {
1307                 uchar temp;
1308                 int i1 = cvRandInt( rng ) % valid_ccount;
1309                 int i2 = cvRandInt( rng ) % valid_ccount;
1310                 CV_SWAP( var_class_mask->data.ptr[i1], var_class_mask->data.ptr[i2], temp );
1311             }
1312
1313             split = _split ? _split : data->new_split_cat( 0, -1.0f);
1314             split->var_idx = vi;
1315             memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
1316
1317             int L = 0, R = 0;
1318             for( int si = 0; si < n; si++ )
1319             {
1320                 float r = responses[si];
1321                 int var_class_idx = labels[si];
1322                 if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )
1323                         continue;
1324                 int mask_class_idx = valid_cidx[var_class_idx];
1325                 if (var_class_mask->data.ptr[mask_class_idx])
1326                 {
1327                     lsum += r;
1328                     L++;
1329                     split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
1330                 }
1331                 else
1332                 {
1333                     rsum += r;
1334                     R++;
1335                 }
1336             }
1337             best_val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
1338
1339             split->quality = (float)best_val;
1340
1341             cvReleaseMat(&var_class_mask);
1342         }
1343     }
1344
1345     return split;
1346 }
1347
1348 void CvForestERTree::split_node_data( CvDTreeNode* node )
1349 {
1350     int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;
1351     char* dir = (char*)data->direction->data.ptr;
1352     CvDTreeNode *left = 0, *right = 0;
1353     int new_buf_idx = data->get_child_buf_idx( node );
1354     CvMat* buf = data->buf;
1355     size_t length_buf_row = data->get_length_subbuf();
1356     cv::AutoBuffer<int> temp_buf(n);
1357
1358     complete_node_dir(node);
1359
1360     for( i = nl = nr = 0; i < n; i++ )
1361     {
1362         int d = dir[i];
1363         nr += d;
1364         nl += d^1;
1365     }
1366
1367     bool split_input_data;
1368     node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
1369     node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );
1370
1371     split_input_data = node->depth + 1 < data->params.max_depth &&
1372         (node->left->sample_count > data->params.min_sample_count ||
1373         node->right->sample_count > data->params.min_sample_count);
1374
1375     cv::AutoBuffer<uchar> inn_buf(n*(sizeof(int)+sizeof(float)));
1376     // split ordered vars
1377     for( vi = 0; vi < data->var_count; vi++ )
1378     {
1379         int ci = data->get_var_type(vi);
1380         if (ci >= 0) continue;
1381
1382         int n1 = node->get_num_valid(vi), nr1 = 0;
1383         float* values_buf = (float*)(uchar*)inn_buf;
1384         int* missing_buf = (int*)(values_buf + n);
1385         const float* values = 0;
1386         const int* missing = 0;
1387         data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing, 0 );
1388
1389         for( i = 0; i < n; i++ )
1390             nr1 += ((!missing[i]) & dir[i]);
1391         left->set_num_valid(vi, n1 - nr1);
1392         right->set_num_valid(vi, nr1);
1393     }
1394     // split categorical vars, responses and cv_labels using new_idx relocation table
1395     for( vi = 0; vi < data->get_work_var_count() + data->ord_var_count; vi++ )
1396     {
1397         int ci = data->get_var_type(vi);
1398         if (ci < 0) continue;
1399
1400         int n1 = node->get_num_valid(vi), nr1 = 0;
1401         const int* src_lbls = data->get_cat_var_data(node, vi, (int*)(uchar*)inn_buf);
1402
1403         for(i = 0; i < n; i++)
1404             temp_buf[i] = src_lbls[i];
1405
1406         if (data->is_buf_16u)
1407         {
1408             unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*length_buf_row +
1409                 ci*scount + left->offset);
1410             unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*length_buf_row +
1411                 ci*scount + right->offset);
1412
1413             for( i = 0; i < n; i++ )
1414             {
1415                 int d = dir[i];
1416                 int idx = temp_buf[i];
1417                 if (d)
1418                 {
1419                     *rdst = (unsigned short)idx;
1420                     rdst++;
1421                     nr1 += (idx != 65535);
1422                 }
1423                 else
1424                 {
1425                     *ldst = (unsigned short)idx;
1426                     ldst++;
1427                 }
1428             }
1429
1430             if( vi < data->var_count )
1431             {
1432                 left->set_num_valid(vi, n1 - nr1);
1433                 right->set_num_valid(vi, nr1);
1434             }
1435         }
1436         else
1437         {
1438             int *ldst = buf->data.i + left->buf_idx*length_buf_row +
1439                 ci*scount + left->offset;
1440             int *rdst = buf->data.i + right->buf_idx*length_buf_row +
1441                 ci*scount + right->offset;
1442
1443             for( i = 0; i < n; i++ )
1444             {
1445                 int d = dir[i];
1446                 int idx = temp_buf[i];
1447                 if (d)
1448                 {
1449                     *rdst = idx;
1450                     rdst++;
1451                     nr1 += (idx >= 0);
1452                 }
1453                 else
1454                 {
1455                     *ldst = idx;
1456                     ldst++;
1457                 }
1458
1459             }
1460
1461             if( vi < data->var_count )
1462             {
1463                 left->set_num_valid(vi, n1 - nr1);
1464                 right->set_num_valid(vi, nr1);
1465             }
1466         }
1467     }
1468
1469     // split sample indices
1470     int *sample_idx_src_buf = (int*)(uchar*)inn_buf;
1471     const int* sample_idx_src = 0;
1472     if (split_input_data)
1473     {
1474         sample_idx_src = data->get_sample_indices(node, sample_idx_src_buf);
1475
1476         for(i = 0; i < n; i++)
1477             temp_buf[i] = sample_idx_src[i];
1478
1479         int pos = data->get_work_var_count();
1480
1481         if (data->is_buf_16u)
1482         {
1483             unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*length_buf_row +
1484                 pos*scount + left->offset);
1485             unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*length_buf_row +
1486                 pos*scount + right->offset);
1487
1488             for (i = 0; i < n; i++)
1489             {
1490                 int d = dir[i];
1491                 unsigned short idx = (unsigned short)temp_buf[i];
1492                 if (d)
1493                 {
1494                     *rdst = idx;
1495                     rdst++;
1496                 }
1497                 else
1498                 {
1499                     *ldst = idx;
1500                     ldst++;
1501                 }
1502             }
1503         }
1504         else
1505         {
1506             int* ldst = buf->data.i + left->buf_idx*length_buf_row +
1507                 pos*scount + left->offset;
1508             int* rdst = buf->data.i + right->buf_idx*length_buf_row +
1509                 pos*scount + right->offset;
1510             for (i = 0; i < n; i++)
1511             {
1512                 int d = dir[i];
1513                 int idx = temp_buf[i];
1514                 if (d)
1515                 {
1516                     *rdst = idx;
1517                     rdst++;
1518                 }
1519                 else
1520                 {
1521                     *ldst = idx;
1522                     ldst++;
1523                 }
1524             }
1525         }
1526     }
1527
1528     // deallocate the parent node data that is not needed anymore
1529     data->free_node_data(node);
1530 }
1531
1532 CvERTrees::CvERTrees()
1533 {
1534 }
1535
1536 CvERTrees::~CvERTrees()
1537 {
1538 }
1539
1540 std::string CvERTrees::getName() const
1541 {
1542     return CV_TYPE_NAME_ML_ERTREES;
1543 }
1544
1545 bool CvERTrees::train( const CvMat* _train_data, int _tflag,
1546                         const CvMat* _responses, const CvMat* _var_idx,
1547                         const CvMat* _sample_idx, const CvMat* _var_type,
1548                         const CvMat* _missing_mask, CvRTParams params )
1549 {
1550     bool result = false;
1551
1552     CV_FUNCNAME("CvERTrees::train");
1553     __BEGIN__
1554     int var_count = 0;
1555
1556     clear();
1557
1558     CvDTreeParams tree_params( params.max_depth, params.min_sample_count,
1559         params.regression_accuracy, params.use_surrogates, params.max_categories,
1560         params.cv_folds, params.use_1se_rule, false, params.priors );
1561
1562     data = new CvERTreeTrainData();
1563     CV_CALL(data->set_data( _train_data, _tflag, _responses, _var_idx,
1564         _sample_idx, _var_type, _missing_mask, tree_params, true));
1565
1566     var_count = data->var_count;
1567     if( params.nactive_vars > var_count )
1568         params.nactive_vars = var_count;
1569     else if( params.nactive_vars == 0 )
1570         params.nactive_vars = (int)sqrt((double)var_count);
1571     else if( params.nactive_vars < 0 )
1572         CV_ERROR( CV_StsBadArg, "<nactive_vars> must be non-negative" );
1573
1574     // Create mask of active variables at the tree nodes
1575     CV_CALL(active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 ));
1576     if( params.calc_var_importance )
1577     {
1578         CV_CALL(var_importance  = cvCreateMat( 1, var_count, CV_32FC1 ));
1579         cvZero(var_importance);
1580     }
1581     { // initialize active variables mask
1582         CvMat submask1, submask2;
1583         CV_Assert( (active_var_mask->cols >= 1) && (params.nactive_vars > 0) && (params.nactive_vars <= active_var_mask->cols) );
1584         cvGetCols( active_var_mask, &submask1, 0, params.nactive_vars );
1585         cvSet( &submask1, cvScalar(1) );
1586         if( params.nactive_vars < active_var_mask->cols )
1587         {
1588             cvGetCols( active_var_mask, &submask2, params.nactive_vars, var_count );
1589             cvZero( &submask2 );
1590         }
1591     }
1592
1593     CV_CALL(result = grow_forest( params.term_crit ));
1594
1595     result = true;
1596
1597     __END__
1598     return result;
1599
1600 }
1601
1602 bool CvERTrees::train( CvMLData* _data, CvRTParams params)
1603 {
1604    bool result = false;
1605
1606     CV_FUNCNAME( "CvERTrees::train" );
1607
1608     __BEGIN__;
1609
1610     CV_CALL( result = CvRTrees::train( _data, params) );
1611
1612     __END__;
1613
1614     return result;
1615 }
1616
1617 bool CvERTrees::grow_forest( const CvTermCriteria term_crit )
1618 {
1619     bool result = false;
1620
1621     CvMat* sample_idx_for_tree      = 0;
1622
1623     CV_FUNCNAME("CvERTrees::grow_forest");
1624     __BEGIN__;
1625
1626     const int max_ntrees = term_crit.max_iter;
1627     const double max_oob_err = term_crit.epsilon;
1628
1629     const int dims = data->var_count;
1630     float maximal_response = 0;
1631
1632     CvMat* oob_sample_votes    = 0;
1633     CvMat* oob_responses       = 0;
1634
1635     float* oob_samples_perm_ptr= 0;
1636
1637     float* samples_ptr     = 0;
1638     uchar* missing_ptr     = 0;
1639     float* true_resp_ptr   = 0;
1640     bool is_oob_or_vimportance = ((max_oob_err > 0) && (term_crit.type != CV_TERMCRIT_ITER)) || var_importance;
1641
1642     // oob_predictions_sum[i] = sum of predicted values for the i-th sample
1643     // oob_num_of_predictions[i] = number of summands
1644     //                            (number of predictions for the i-th sample)
1645     // initialize these variable to avoid warning C4701
1646     CvMat oob_predictions_sum = cvMat( 1, 1, CV_32FC1 );
1647     CvMat oob_num_of_predictions = cvMat( 1, 1, CV_32FC1 );
1648
1649     nsamples = data->sample_count;
1650     nclasses = data->get_num_classes();
1651
1652     if ( is_oob_or_vimportance )
1653     {
1654         if( data->is_classifier )
1655         {
1656             CV_CALL(oob_sample_votes = cvCreateMat( nsamples, nclasses, CV_32SC1 ));
1657             cvZero(oob_sample_votes);
1658         }
1659         else
1660         {
1661             // oob_responses[0,i] = oob_predictions_sum[i]
1662             //    = sum of predicted values for the i-th sample
1663             // oob_responses[1,i] = oob_num_of_predictions[i]
1664             //    = number of summands (number of predictions for the i-th sample)
1665             CV_CALL(oob_responses = cvCreateMat( 2, nsamples, CV_32FC1 ));
1666             cvZero(oob_responses);
1667             cvGetRow( oob_responses, &oob_predictions_sum, 0 );
1668             cvGetRow( oob_responses, &oob_num_of_predictions, 1 );
1669         }
1670
1671         CV_CALL(oob_samples_perm_ptr     = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
1672         CV_CALL(samples_ptr              = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
1673         CV_CALL(missing_ptr              = (uchar*)cvAlloc( sizeof(uchar)*nsamples*dims ));
1674         CV_CALL(true_resp_ptr            = (float*)cvAlloc( sizeof(float)*nsamples ));
1675
1676         CV_CALL(data->get_vectors( 0, samples_ptr, missing_ptr, true_resp_ptr ));
1677         {
1678             double minval, maxval;
1679             CvMat responses = cvMat(1, nsamples, CV_32FC1, true_resp_ptr);
1680             cvMinMaxLoc( &responses, &minval, &maxval );
1681             maximal_response = (float)MAX( MAX( fabs(minval), fabs(maxval) ), 0 );
1682         }
1683     }
1684
1685     trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );
1686     memset( trees, 0, sizeof(trees[0])*max_ntrees );
1687
1688     CV_CALL(sample_idx_for_tree = cvCreateMat( 1, nsamples, CV_32SC1 ));
1689
1690     for (int i = 0; i < nsamples; i++)
1691         sample_idx_for_tree->data.i[i] = i;
1692     ntrees = 0;
1693     while( ntrees < max_ntrees )
1694     {
1695         int i, oob_samples_count = 0;
1696         double ncorrect_responses = 0; // used for estimation of variable importance
1697         CvForestTree* tree = 0;
1698
1699         trees[ntrees] = new CvForestERTree();
1700         tree = (CvForestERTree*)trees[ntrees];
1701         CV_CALL(tree->train( data, 0, this ));
1702
1703         if ( is_oob_or_vimportance )
1704         {
1705             CvMat sample, missing;
1706             // form array of OOB samples indices and get these samples
1707             sample   = cvMat( 1, dims, CV_32FC1, samples_ptr );
1708             missing  = cvMat( 1, dims, CV_8UC1,  missing_ptr );
1709
1710             oob_error = 0;
1711             for( i = 0; i < nsamples; i++,
1712                 sample.data.fl += dims, missing.data.ptr += dims )
1713             {
1714                 CvDTreeNode* predicted_node = 0;
1715
1716                 // predict oob samples
1717                 if( !predicted_node )
1718                     CV_CALL(predicted_node = tree->predict(&sample, &missing, true));
1719
1720                 if( !data->is_classifier ) //regression
1721                 {
1722                     double avg_resp, resp = predicted_node->value;
1723                     oob_predictions_sum.data.fl[i] += (float)resp;
1724                     oob_num_of_predictions.data.fl[i] += 1;
1725
1726                     // compute oob error
1727                     avg_resp = oob_predictions_sum.data.fl[i]/oob_num_of_predictions.data.fl[i];
1728                     avg_resp -= true_resp_ptr[i];
1729                     oob_error += avg_resp*avg_resp;
1730                     resp = (resp - true_resp_ptr[i])/maximal_response;
1731                     ncorrect_responses += exp( -resp*resp );
1732                 }
1733                 else //classification
1734                 {
1735                     double prdct_resp;
1736                     CvPoint max_loc;
1737                     CvMat votes;
1738
1739                     cvGetRow(oob_sample_votes, &votes, i);
1740                     votes.data.i[predicted_node->class_idx]++;
1741
1742                     // compute oob error
1743                     cvMinMaxLoc( &votes, 0, 0, 0, &max_loc );
1744
1745                     prdct_resp = data->cat_map->data.i[max_loc.x];
1746                     oob_error += (fabs(prdct_resp - true_resp_ptr[i]) < FLT_EPSILON) ? 0 : 1;
1747
1748                     ncorrect_responses += cvRound(predicted_node->value - true_resp_ptr[i]) == 0;
1749                 }
1750                 oob_samples_count++;
1751             }
1752             if( oob_samples_count > 0 )
1753                 oob_error /= (double)oob_samples_count;
1754
1755             // estimate variable importance
1756             if( var_importance && oob_samples_count > 0 )
1757             {
1758                 int m;
1759
1760                 memcpy( oob_samples_perm_ptr, samples_ptr, dims*nsamples*sizeof(float));
1761                 for( m = 0; m < dims; m++ )
1762                 {
1763                     double ncorrect_responses_permuted = 0;
1764                     // randomly permute values of the m-th variable in the oob samples
1765                     float* mth_var_ptr = oob_samples_perm_ptr + m;
1766
1767                     for( i = 0; i < nsamples; i++ )
1768                     {
1769                         int i1, i2;
1770                         float temp;
1771
1772                         i1 = (*rng)(nsamples);
1773                         i2 = (*rng)(nsamples);
1774                         CV_SWAP( mth_var_ptr[i1*dims], mth_var_ptr[i2*dims], temp );
1775
1776                         // turn values of (m-1)-th variable, that were permuted
1777                         // at the previous iteration, untouched
1778                         if( m > 1 )
1779                             oob_samples_perm_ptr[i*dims+m-1] = samples_ptr[i*dims+m-1];
1780                     }
1781
1782                     // predict "permuted" cases and calculate the number of votes for the
1783                     // correct class in the variable-m-permuted oob data
1784                     sample  = cvMat( 1, dims, CV_32FC1, oob_samples_perm_ptr );
1785                     missing = cvMat( 1, dims, CV_8UC1, missing_ptr );
1786                     for( i = 0; i < nsamples; i++,
1787                         sample.data.fl += dims, missing.data.ptr += dims )
1788                     {
1789                         double predct_resp, true_resp;
1790
1791                         predct_resp = tree->predict(&sample, &missing, true)->value;
1792                         true_resp   = true_resp_ptr[i];
1793                         if( data->is_classifier )
1794                             ncorrect_responses_permuted += cvRound(true_resp - predct_resp) == 0;
1795                         else
1796                         {
1797                             true_resp = (true_resp - predct_resp)/maximal_response;
1798                             ncorrect_responses_permuted += exp( -true_resp*true_resp );
1799                         }
1800                     }
1801                     var_importance->data.fl[m] += (float)(ncorrect_responses
1802                         - ncorrect_responses_permuted);
1803                 }
1804             }
1805         }
1806         ntrees++;
1807         if( term_crit.type != CV_TERMCRIT_ITER && oob_error < max_oob_err )
1808             break;
1809     }
1810     if( var_importance )
1811     {
1812         for ( int vi = 0; vi < var_importance->cols; vi++ )
1813                 var_importance->data.fl[vi] = ( var_importance->data.fl[vi] > 0 ) ?
1814                     var_importance->data.fl[vi] : 0;
1815         cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
1816     }
1817
1818     result = true;
1819
1820     cvFree( &oob_samples_perm_ptr );
1821     cvFree( &samples_ptr );
1822     cvFree( &missing_ptr );
1823     cvFree( &true_resp_ptr );
1824
1825     cvReleaseMat( &sample_idx_for_tree );
1826
1827     cvReleaseMat( &oob_sample_votes );
1828     cvReleaseMat( &oob_responses );
1829
1830     __END__;
1831
1832     return result;
1833 }
1834
1835 using namespace cv;
1836
1837 bool CvERTrees::train( const Mat& _train_data, int _tflag,
1838                       const Mat& _responses, const Mat& _var_idx,
1839                       const Mat& _sample_idx, const Mat& _var_type,
1840                       const Mat& _missing_mask, CvRTParams params )
1841 {
1842     CvMat tdata = _train_data, responses = _responses, vidx = _var_idx,
1843     sidx = _sample_idx, vtype = _var_type, mmask = _missing_mask;
1844     return train(&tdata, _tflag, &responses, vidx.data.ptr ? &vidx : 0,
1845                  sidx.data.ptr ? &sidx : 0, vtype.data.ptr ? &vtype : 0,
1846                  mmask.data.ptr ? &mmask : 0, params);
1847 }
1848
1849 // End of file.