fixed many warnings from GCC 4.6.1
[profile/ivi/opencv.git] / modules / ml / src / boost.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40
41 #include "precomp.hpp"
42
43 static inline double
44 log_ratio( double val )
45 {
46     const double eps = 1e-5;
47
48     val = MAX( val, eps );
49     val = MIN( val, 1. - eps );
50     return log( val/(1. - val) );
51 }
52
53
54 CvBoostParams::CvBoostParams()
55 {
56     boost_type = CvBoost::REAL;
57     weak_count = 100;
58     weight_trim_rate = 0.95;
59     cv_folds = 0;
60     max_depth = 1;
61 }
62
63
64 CvBoostParams::CvBoostParams( int _boost_type, int _weak_count,
65                                         double _weight_trim_rate, int _max_depth,
66                                         bool _use_surrogates, const float* _priors )
67 {
68     boost_type = _boost_type;
69     weak_count = _weak_count;
70     weight_trim_rate = _weight_trim_rate;
71     split_criteria = CvBoost::DEFAULT;
72     cv_folds = 0;
73     max_depth = _max_depth;
74     use_surrogates = _use_surrogates;
75     priors = _priors;
76 }
77
78
79
80 ///////////////////////////////// CvBoostTree ///////////////////////////////////
81
82 CvBoostTree::CvBoostTree()
83 {
84     ensemble = 0;
85 }
86
87
88 CvBoostTree::~CvBoostTree()
89 {
90     clear();
91 }
92
93
94 void
95 CvBoostTree::clear()
96 {
97     CvDTree::clear();
98     ensemble = 0;
99 }
100
101
102 bool
103 CvBoostTree::train( CvDTreeTrainData* _train_data,
104                     const CvMat* _subsample_idx, CvBoost* _ensemble )
105 {
106     clear();
107     ensemble = _ensemble;
108     data = _train_data;
109     data->shared = true;
110     return do_train( _subsample_idx );
111 }
112
113
114 bool
115 CvBoostTree::train( const CvMat*, int, const CvMat*, const CvMat*,
116                     const CvMat*, const CvMat*, const CvMat*, CvDTreeParams )
117 {
118     assert(0);
119     return false;
120 }
121
122
123 bool
124 CvBoostTree::train( CvDTreeTrainData*, const CvMat* )
125 {
126     assert(0);
127     return false;
128 }
129
130
131 void
132 CvBoostTree::scale( double scale )
133 {
134     CvDTreeNode* node = root;
135
136     // traverse the tree and scale all the node values
137     for(;;)
138     {
139         CvDTreeNode* parent;
140         for(;;)
141         {
142             node->value *= scale;
143             if( !node->left )
144                 break;
145             node = node->left;
146         }
147
148         for( parent = node->parent; parent && parent->right == node;
149             node = parent, parent = parent->parent )
150             ;
151
152         if( !parent )
153             break;
154
155         node = parent->right;
156     }
157 }
158
159
160 void
161 CvBoostTree::try_split_node( CvDTreeNode* node )
162 {
163     CvDTree::try_split_node( node );
164
165     if( !node->left )
166     {
167         // if the node has not been split,
168         // store the responses for the corresponding training samples
169         double* weak_eval = ensemble->get_weak_response()->data.db;
170         cv::AutoBuffer<int> inn_buf(node->sample_count);
171         const int* labels = data->get_cv_labels( node, (int*)inn_buf );
172         int i, count = node->sample_count;
173         double value = node->value;
174
175         for( i = 0; i < count; i++ )
176             weak_eval[labels[i]] = value;
177     }
178 }
179
180
181 double
182 CvBoostTree::calc_node_dir( CvDTreeNode* node )
183 {
184     char* dir = (char*)data->direction->data.ptr;
185     const double* weights = ensemble->get_subtree_weights()->data.db;
186     int i, n = node->sample_count, vi = node->split->var_idx;
187     double L, R;
188
189     assert( !node->split->inversed );
190
191     if( data->get_var_type(vi) >= 0 ) // split on categorical var
192     {
193         cv::AutoBuffer<int> inn_buf(n);
194         const int* cat_labels = data->get_cat_var_data( node, vi, (int*)inn_buf );
195         const int* subset = node->split->subset;
196         double sum = 0, sum_abs = 0;
197
198         for( i = 0; i < n; i++ )
199         {
200             int idx = ((cat_labels[i] == 65535) && data->is_buf_16u) ? -1 : cat_labels[i];
201             double w = weights[i];
202             int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
203             sum += d*w; sum_abs += (d & 1)*w;
204             dir[i] = (char)d;
205         }
206
207         R = (sum_abs + sum) * 0.5;
208         L = (sum_abs - sum) * 0.5;
209     }
210     else // split on ordered var
211     {
212         cv::AutoBuffer<uchar> inn_buf(2*n*sizeof(int)+n*sizeof(float));
213         float* values_buf = (float*)(uchar*)inn_buf;
214         int* sorted_indices_buf = (int*)(values_buf + n);
215         int* sample_indices_buf = sorted_indices_buf + n;
216         const float* values = 0;
217         const int* sorted_indices = 0;
218         data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
219         int split_point = node->split->ord.split_point;
220         int n1 = node->get_num_valid(vi);
221
222         assert( 0 <= split_point && split_point < n1-1 );
223         L = R = 0;
224
225         for( i = 0; i <= split_point; i++ )
226         {
227             int idx = sorted_indices[i];
228             double w = weights[idx];
229             dir[idx] = (char)-1;
230             L += w;
231         }
232
233         for( ; i < n1; i++ )
234         {
235             int idx = sorted_indices[i];
236             double w = weights[idx];
237             dir[idx] = (char)1;
238             R += w;
239         }
240
241         for( ; i < n; i++ )
242             dir[sorted_indices[i]] = (char)0;
243     }
244
245     node->maxlr = MAX( L, R );
246     return node->split->quality/(L + R);
247 }
248
249
250 CvDTreeSplit*
251 CvBoostTree::find_split_ord_class( CvDTreeNode* node, int vi, float init_quality,
252                                     CvDTreeSplit* _split, uchar* _ext_buf )
253 {
254     const float epsilon = FLT_EPSILON*2;
255
256     const double* weights = ensemble->get_subtree_weights()->data.db;
257     int n = node->sample_count;
258     int n1 = node->get_num_valid(vi);
259
260     cv::AutoBuffer<uchar> inn_buf;
261     if( !_ext_buf )
262         inn_buf.allocate(n*(3*sizeof(int)+sizeof(float)));
263     uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
264     float* values_buf = (float*)ext_buf;
265     int* sorted_indices_buf = (int*)(values_buf + n);
266     int* sample_indices_buf = sorted_indices_buf + n;
267     const float* values = 0;
268     const int* sorted_indices = 0;
269     data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
270     int* responses_buf = sorted_indices_buf + n;
271     const int* responses = data->get_class_labels( node, responses_buf );
272     const double* rcw0 = weights + n;
273     double lcw[2] = {0,0}, rcw[2];
274     int i, best_i = -1;
275     double best_val = init_quality;
276     int boost_type = ensemble->get_params().boost_type;
277     int split_criteria = ensemble->get_params().split_criteria;
278
279     rcw[0] = rcw0[0]; rcw[1] = rcw0[1];
280     for( i = n1; i < n; i++ )
281     {
282         int idx = sorted_indices[i];
283         double w = weights[idx];
284         rcw[responses[idx]] -= w;
285     }
286
287     if( split_criteria != CvBoost::GINI && split_criteria != CvBoost::MISCLASS )
288         split_criteria = boost_type == CvBoost::DISCRETE ? CvBoost::MISCLASS : CvBoost::GINI;
289
290     if( split_criteria == CvBoost::GINI )
291     {
292         double L = 0, R = rcw[0] + rcw[1];
293         double lsum2 = 0, rsum2 = rcw[0]*rcw[0] + rcw[1]*rcw[1];
294
295         for( i = 0; i < n1 - 1; i++ )
296         {
297             int idx = sorted_indices[i];
298             double w = weights[idx], w2 = w*w;
299             double lv, rv;
300             idx = responses[idx];
301             L += w; R -= w;
302             lv = lcw[idx]; rv = rcw[idx];
303             lsum2 += 2*lv*w + w2;
304             rsum2 -= 2*rv*w - w2;
305             lcw[idx] = lv + w; rcw[idx] = rv - w;
306
307             if( values[i] + epsilon < values[i+1] )
308             {
309                 double val = (lsum2*R + rsum2*L)/(L*R);
310                 if( best_val < val )
311                 {
312                     best_val = val;
313                     best_i = i;
314                 }
315             }
316         }
317     }
318     else
319     {
320         for( i = 0; i < n1 - 1; i++ )
321         {
322             int idx = sorted_indices[i];
323             double w = weights[idx];
324             idx = responses[idx];
325             lcw[idx] += w;
326             rcw[idx] -= w;
327
328             if( values[i] + epsilon < values[i+1] )
329             {
330                 double val = lcw[0] + rcw[1], val2 = lcw[1] + rcw[0];
331                 val = MAX(val, val2);
332                 if( best_val < val )
333                 {
334                     best_val = val;
335                     best_i = i;
336                 }
337             }
338         }
339     }
340
341     CvDTreeSplit* split = 0;
342     if( best_i >= 0 )
343     {
344         split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
345         split->var_idx = vi;
346         split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
347         split->ord.split_point = best_i;
348         split->inversed = 0;
349         split->quality = (float)best_val;
350     }
351     return split;
352 }
353
354
355 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
356 static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )
357
358 CvDTreeSplit*
359 CvBoostTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
360 {
361     int ci = data->get_var_type(vi);
362     int n = node->sample_count;
363     int mi = data->cat_count->data.i[ci];
364
365     int base_size = (2*mi+3)*sizeof(double) + mi*sizeof(double*);
366     cv::AutoBuffer<uchar> inn_buf((2*mi+3)*sizeof(double) + mi*sizeof(double*));
367     if( !_ext_buf)
368         inn_buf.allocate( base_size + 2*n*sizeof(int) );
369     uchar* base_buf = (uchar*)inn_buf;
370     uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
371
372     int* cat_labels_buf = (int*)ext_buf;
373     const int* cat_labels = data->get_cat_var_data(node, vi, cat_labels_buf);
374     int* responses_buf = cat_labels_buf + n;
375     const int* responses = data->get_class_labels(node, responses_buf);
376     double lcw[2]={0,0}, rcw[2]={0,0};
377
378     double* cjk = (double*)cv::alignPtr(base_buf,sizeof(double))+2;
379     const double* weights = ensemble->get_subtree_weights()->data.db;
380     double** dbl_ptr = (double**)(cjk + 2*mi);
381     int i, j, k, idx;
382     double L = 0, R;
383     double best_val = init_quality;
384     int best_subset = -1, subset_i;
385     int boost_type = ensemble->get_params().boost_type;
386     int split_criteria = ensemble->get_params().split_criteria;
387
388     // init array of counters:
389     // c_{jk} - number of samples that have vi-th input variable = j and response = k.
390     for( j = -1; j < mi; j++ )
391         cjk[j*2] = cjk[j*2+1] = 0;
392
393     for( i = 0; i < n; i++ )
394     {
395         double w = weights[i];
396         j = ((cat_labels[i] == 65535) && data->is_buf_16u) ? -1 : cat_labels[i];
397         k = responses[i];
398         cjk[j*2 + k] += w;
399     }
400
401     for( j = 0; j < mi; j++ )
402     {
403         rcw[0] += cjk[j*2];
404         rcw[1] += cjk[j*2+1];
405         dbl_ptr[j] = cjk + j*2 + 1;
406     }
407
408     R = rcw[0] + rcw[1];
409
410     if( split_criteria != CvBoost::GINI && split_criteria != CvBoost::MISCLASS )
411         split_criteria = boost_type == CvBoost::DISCRETE ? CvBoost::MISCLASS : CvBoost::GINI;
412
413     // sort rows of c_jk by increasing c_j,1
414     // (i.e. by the weight of samples in j-th category that belong to class 1)
415     icvSortDblPtr( dbl_ptr, mi, 0 );
416
417     for( subset_i = 0; subset_i < mi-1; subset_i++ )
418     {
419         idx = (int)(dbl_ptr[subset_i] - cjk)/2;
420         const double* crow = cjk + idx*2;
421         double w0 = crow[0], w1 = crow[1];
422         double weight = w0 + w1;
423
424         if( weight < FLT_EPSILON )
425             continue;
426
427         lcw[0] += w0; rcw[0] -= w0;
428         lcw[1] += w1; rcw[1] -= w1;
429
430         if( split_criteria == CvBoost::GINI )
431         {
432             double lsum2 = lcw[0]*lcw[0] + lcw[1]*lcw[1];
433             double rsum2 = rcw[0]*rcw[0] + rcw[1]*rcw[1];
434
435             L += weight;
436             R -= weight;
437
438             if( L > FLT_EPSILON && R > FLT_EPSILON )
439             {
440                 double val = (lsum2*R + rsum2*L)/(L*R);
441                 if( best_val < val )
442                 {
443                     best_val = val;
444                     best_subset = subset_i;
445                 }
446             }
447         }
448         else
449         {
450             double val = lcw[0] + rcw[1];
451             double val2 = lcw[1] + rcw[0];
452
453             val = MAX(val, val2);
454             if( best_val < val )
455             {
456                 best_val = val;
457                 best_subset = subset_i;
458             }
459         }
460     }
461
462     CvDTreeSplit* split = 0;
463     if( best_subset >= 0 )
464     {
465         split = _split ? _split : data->new_split_cat( 0, -1.0f);
466         split->var_idx = vi;
467         split->quality = (float)best_val;
468         memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
469         for( i = 0; i <= best_subset; i++ )
470         {
471             idx = (int)(dbl_ptr[i] - cjk) >> 1;
472             split->subset[idx >> 5] |= 1 << (idx & 31);
473         }
474     }
475     return split;
476 }
477
478
479 CvDTreeSplit*
480 CvBoostTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
481 {
482     const float epsilon = FLT_EPSILON*2;
483     const double* weights = ensemble->get_subtree_weights()->data.db;
484     int n = node->sample_count;
485     int n1 = node->get_num_valid(vi);
486
487     cv::AutoBuffer<uchar> inn_buf;
488     if( !_ext_buf )
489         inn_buf.allocate(2*n*(sizeof(int)+sizeof(float)));
490     uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
491
492     float* values_buf = (float*)ext_buf;
493     int* indices_buf = (int*)(values_buf + n);
494     int* sample_indices_buf = indices_buf + n;
495     const float* values = 0;
496     const int* indices = 0;
497     data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices, sample_indices_buf );
498     float* responses_buf = (float*)(indices_buf + n);
499     const float* responses = data->get_ord_responses( node, responses_buf, sample_indices_buf );
500
501     int i, best_i = -1;
502     double L = 0, R = weights[n];
503     double best_val = init_quality, lsum = 0, rsum = node->value*R;
504     
505     // compensate for missing values
506     for( i = n1; i < n; i++ )
507     {
508         int idx = indices[i];
509         double w = weights[idx];
510         rsum -= responses[idx]*w;
511         R -= w;
512     }
513
514     // find the optimal split
515     for( i = 0; i < n1 - 1; i++ )
516     {
517         int idx = indices[i];
518         double w = weights[idx];
519         double t = responses[idx]*w;
520         L += w; R -= w;
521         lsum += t; rsum -= t;
522
523         if( values[i] + epsilon < values[i+1] )
524         {
525             double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
526             if( best_val < val )
527             {
528                 best_val = val;
529                 best_i = i;
530             }
531         }
532     }
533
534     CvDTreeSplit* split = 0;
535     if( best_i >= 0 )
536     {
537         split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
538         split->var_idx = vi;
539         split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
540         split->ord.split_point = best_i;
541         split->inversed = 0;
542         split->quality = (float)best_val;
543     }
544     return split;
545 }
546
547
548 CvDTreeSplit*
549 CvBoostTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
550 {
551     const double* weights = ensemble->get_subtree_weights()->data.db;
552     int ci = data->get_var_type(vi);
553     int n = node->sample_count;
554     int mi = data->cat_count->data.i[ci];
555     int base_size = (2*mi+3)*sizeof(double) + mi*sizeof(double*);
556     cv::AutoBuffer<uchar> inn_buf(base_size);
557     if( !_ext_buf )
558         inn_buf.allocate(base_size + n*(2*sizeof(int) + sizeof(float)));
559     uchar* base_buf = (uchar*)inn_buf;
560     uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
561
562     int* cat_labels_buf = (int*)ext_buf;
563     const int* cat_labels = data->get_cat_var_data(node, vi, cat_labels_buf);
564     float* responses_buf = (float*)(cat_labels_buf + n);
565     int* sample_indices_buf = (int*)(responses_buf + n);
566     const float* responses = data->get_ord_responses(node, responses_buf, sample_indices_buf);
567
568     double* sum = (double*)cv::alignPtr(base_buf,sizeof(double)) + 1;
569     double* counts = sum + mi + 1;
570     double** sum_ptr = (double**)(counts + mi);
571     double L = 0, R = 0, best_val = init_quality, lsum = 0, rsum = 0;
572     int i, best_subset = -1, subset_i;
573
574     for( i = -1; i < mi; i++ )
575         sum[i] = counts[i] = 0;
576
577     // calculate sum response and weight of each category of the input var
578     for( i = 0; i < n; i++ )
579     {
580         int idx = ((cat_labels[i] == 65535) && data->is_buf_16u) ? -1 : cat_labels[i];
581         double w = weights[i];
582         double s = sum[idx] + responses[i]*w;
583         double nc = counts[idx] + w;
584         sum[idx] = s;
585         counts[idx] = nc;
586     }
587
588     // calculate average response in each category
589     for( i = 0; i < mi; i++ )
590     {
591         R += counts[i];
592         rsum += sum[i];
593                 sum[i] = fabs(counts[i]) > DBL_EPSILON ? sum[i]/counts[i] : 0;
594         sum_ptr[i] = sum + i;
595     }
596
597     icvSortDblPtr( sum_ptr, mi, 0 );
598
599     // revert back to unnormalized sums
600     // (there should be a very little loss in accuracy)
601     for( i = 0; i < mi; i++ )
602         sum[i] *= counts[i];
603
604     for( subset_i = 0; subset_i < mi-1; subset_i++ )
605     {
606         int idx = (int)(sum_ptr[subset_i] - sum);
607         double ni = counts[idx];
608
609         if( ni > FLT_EPSILON )
610         {
611             double s = sum[idx];
612             lsum += s; L += ni;
613             rsum -= s; R -= ni;
614
615             if( L > FLT_EPSILON && R > FLT_EPSILON )
616             {
617                 double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
618                 if( best_val < val )
619                 {
620                     best_val = val;
621                     best_subset = subset_i;
622                 }
623             }
624         }
625     }
626
627     CvDTreeSplit* split = 0;
628     if( best_subset >= 0 )
629     {
630         split = _split ? _split : data->new_split_cat( 0, -1.0f);
631         split->var_idx = vi;
632         split->quality = (float)best_val;
633         memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
634         for( i = 0; i <= best_subset; i++ )
635         {
636             int idx = (int)(sum_ptr[i] - sum);
637             split->subset[idx >> 5] |= 1 << (idx & 31);
638         }
639     }
640     return split;
641 }
642
643
644 CvDTreeSplit*
645 CvBoostTree::find_surrogate_split_ord( CvDTreeNode* node, int vi, uchar* _ext_buf )
646 {
647     const float epsilon = FLT_EPSILON*2;
648     int n = node->sample_count;
649     cv::AutoBuffer<uchar> inn_buf;
650     if( !_ext_buf )
651         inn_buf.allocate(n*(2*sizeof(int)+sizeof(float)));
652     uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
653     float* values_buf = (float*)ext_buf;
654     int* indices_buf = (int*)(values_buf + n);
655     int* sample_indices_buf = indices_buf + n;
656     const float* values = 0;
657     const int* indices = 0;
658     data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices, sample_indices_buf );
659
660     const double* weights = ensemble->get_subtree_weights()->data.db;
661     const char* dir = (char*)data->direction->data.ptr;
662     int n1 = node->get_num_valid(vi);
663     // LL - number of samples that both the primary and the surrogate splits send to the left
664     // LR - ... primary split sends to the left and the surrogate split sends to the right
665     // RL - ... primary split sends to the right and the surrogate split sends to the left
666     // RR - ... both send to the right
667     int i, best_i = -1, best_inversed = 0;
668     double best_val;
669     double LL = 0, RL = 0, LR, RR;
670     double worst_val = node->maxlr;
671     double sum = 0, sum_abs = 0;
672     best_val = worst_val;
673
674     for( i = 0; i < n1; i++ )
675     {
676         int idx = indices[i];
677         double w = weights[idx];
678         int d = dir[idx];
679         sum += d*w; sum_abs += (d & 1)*w;
680     }
681
682     // sum_abs = R + L; sum = R - L
683     RR = (sum_abs + sum)*0.5;
684     LR = (sum_abs - sum)*0.5;
685
686     // initially all the samples are sent to the right by the surrogate split,
687     // LR of them are sent to the left by primary split, and RR - to the right.
688     // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
689     for( i = 0; i < n1 - 1; i++ )
690     {
691         int idx = indices[i];
692         double w = weights[idx];
693         int d = dir[idx];
694
695         if( d < 0 )
696         {
697             LL += w; LR -= w;
698             if( LL + RR > best_val && values[i] + epsilon < values[i+1] )
699             {
700                 best_val = LL + RR;
701                 best_i = i; best_inversed = 0;
702             }
703         }
704         else if( d > 0 )
705         {
706             RL += w; RR -= w;
707             if( RL + LR > best_val && values[i] + epsilon < values[i+1] )
708             {
709                 best_val = RL + LR;
710                 best_i = i; best_inversed = 1;
711             }
712         }
713     }
714
715     return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
716         (values[best_i] + values[best_i+1])*0.5f, best_i,
717         best_inversed, (float)best_val ) : 0;
718 }
719
720
721 CvDTreeSplit*
722 CvBoostTree::find_surrogate_split_cat( CvDTreeNode* node, int vi, uchar* _ext_buf )
723 {
724     const char* dir = (char*)data->direction->data.ptr;
725     const double* weights = ensemble->get_subtree_weights()->data.db;
726     int n = node->sample_count;
727     int i, mi = data->cat_count->data.i[data->get_var_type(vi)];
728
729     int base_size = (2*mi+3)*sizeof(double);
730     cv::AutoBuffer<uchar> inn_buf(base_size);
731     if( !_ext_buf )
732         inn_buf.allocate(base_size + n*sizeof(int));
733     uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
734     int* cat_labels_buf = (int*)ext_buf;
735     const int* cat_labels = data->get_cat_var_data(node, vi, cat_labels_buf);
736
737     // LL - number of samples that both the primary and the surrogate splits send to the left
738     // LR - ... primary split sends to the left and the surrogate split sends to the right
739     // RL - ... primary split sends to the right and the surrogate split sends to the left
740     // RR - ... both send to the right
741     CvDTreeSplit* split = data->new_split_cat( vi, 0 );
742     double best_val = 0;
743     double* lc = (double*)cv::alignPtr(cat_labels_buf + n, sizeof(double)) + 1;
744     double* rc = lc + mi + 1;
745
746     for( i = -1; i < mi; i++ )
747         lc[i] = rc[i] = 0;
748
749     // 1. for each category calculate the weight of samples
750     // sent to the left (lc) and to the right (rc) by the primary split
751     for( i = 0; i < n; i++ )
752     {
753         int idx = ((cat_labels[i] == 65535) && data->is_buf_16u) ? -1 : cat_labels[i];
754         double w = weights[i];
755         int d = dir[i];
756         double sum = lc[idx] + d*w;
757         double sum_abs = rc[idx] + (d & 1)*w;
758         lc[idx] = sum; rc[idx] = sum_abs;
759     }
760
761     for( i = 0; i < mi; i++ )
762     {
763         double sum = lc[i];
764         double sum_abs = rc[i];
765         lc[i] = (sum_abs - sum) * 0.5;
766         rc[i] = (sum_abs + sum) * 0.5;
767     }
768
769     // 2. now form the split.
770     // in each category send all the samples to the same direction as majority
771     for( i = 0; i < mi; i++ )
772     {
773         double lval = lc[i], rval = rc[i];
774         if( lval > rval )
775         {
776             split->subset[i >> 5] |= 1 << (i & 31);
777             best_val += lval;
778         }
779         else
780             best_val += rval;
781     }
782
783     split->quality = (float)best_val;
784     if( split->quality <= node->maxlr )
785         cvSetRemoveByPtr( data->split_heap, split ), split = 0;
786
787     return split;
788 }
789
790
791 void
792 CvBoostTree::calc_node_value( CvDTreeNode* node )
793 {
794     int i, n = node->sample_count;
795     const double* weights = ensemble->get_weights()->data.db;
796     cv::AutoBuffer<uchar> inn_buf(n*(sizeof(int) + ( data->is_classifier ? sizeof(int) : sizeof(int) + sizeof(float))));
797     int* labels_buf = (int*)(uchar*)inn_buf;
798     const int* labels = data->get_cv_labels(node, labels_buf);
799     double* subtree_weights = ensemble->get_subtree_weights()->data.db;
800     double rcw[2] = {0,0};
801     int boost_type = ensemble->get_params().boost_type;
802
803     if( data->is_classifier )
804     {
805         int* _responses_buf = labels_buf + n;
806         const int* _responses = data->get_class_labels(node, _responses_buf);
807         int m = data->get_num_classes();
808         int* cls_count = data->counts->data.i;
809         for( int k = 0; k < m; k++ )
810             cls_count[k] = 0;
811
812         for( i = 0; i < n; i++ )
813         {
814             int idx = labels[i];
815             double w = weights[idx];
816             int r = _responses[i];
817             rcw[r] += w;
818             cls_count[r]++;
819             subtree_weights[i] = w;
820         }
821
822         node->class_idx = rcw[1] > rcw[0];
823
824         if( boost_type == CvBoost::DISCRETE )
825         {
826             // ignore cat_map for responses, and use {-1,1},
827             // as the whole ensemble response is computes as sign(sum_i(weak_response_i)
828             node->value = node->class_idx*2 - 1;
829         }
830         else
831         {
832             double p = rcw[1]/(rcw[0] + rcw[1]);
833             assert( boost_type == CvBoost::REAL );
834
835             // store log-ratio of the probability
836             node->value = 0.5*log_ratio(p);
837         }
838     }
839     else
840     {
841         // in case of regression tree:
842         //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
843         //    n is the number of samples in the node.
844         //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
845         double sum = 0, sum2 = 0, iw;
846         float* values_buf = (float*)(labels_buf + n);
847         int* sample_indices_buf = (int*)(values_buf + n);
848         const float* values = data->get_ord_responses(node, values_buf, sample_indices_buf);
849
850         for( i = 0; i < n; i++ )
851         {
852             int idx = labels[i];
853             double w = weights[idx]/*priors[values[i] > 0]*/;
854             double t = values[i];
855             rcw[0] += w;
856             subtree_weights[i] = w;
857             sum += t*w;
858             sum2 += t*t*w;
859         }
860
861         iw = 1./rcw[0];
862         node->value = sum*iw;
863         node->node_risk = sum2 - (sum*iw)*sum;
864
865         // renormalize the risk, as in try_split_node the unweighted formula
866         // sqrt(risk)/n is used, rather than sqrt(risk)/sum(weights_i)
867         node->node_risk *= n*iw*n*iw;
868     }
869
870     // store summary weights
871     subtree_weights[n] = rcw[0];
872     subtree_weights[n+1] = rcw[1];
873 }
874
875
876 void CvBoostTree::read( CvFileStorage* fs, CvFileNode* fnode, CvBoost* _ensemble, CvDTreeTrainData* _data )
877 {
878     CvDTree::read( fs, fnode, _data );
879     ensemble = _ensemble;
880 }
881
882
883 void CvBoostTree::read( CvFileStorage*, CvFileNode* )
884 {
885     assert(0);
886 }
887
888 void CvBoostTree::read( CvFileStorage* _fs, CvFileNode* _node,
889                         CvDTreeTrainData* _data )
890 {
891     CvDTree::read( _fs, _node, _data );
892 }
893
894
895 /////////////////////////////////// CvBoost /////////////////////////////////////
896
897 CvBoost::CvBoost()
898 {
899     data = 0;
900     weak = 0;
901     default_model_name = "my_boost_tree";
902
903     active_vars = active_vars_abs = orig_response = sum_response = weak_eval =
904         subsample_mask = weights = subtree_weights = 0;
905     have_active_cat_vars = have_subsample = false;
906
907     clear();
908 }
909
910
911 void CvBoost::prune( CvSlice slice )
912 {
913     if( weak )
914     {
915         CvSeqReader reader;
916         int i, count = cvSliceLength( slice, weak );
917
918         cvStartReadSeq( weak, &reader );
919         cvSetSeqReaderPos( &reader, slice.start_index );
920
921         for( i = 0; i < count; i++ )
922         {
923             CvBoostTree* w;
924             CV_READ_SEQ_ELEM( w, reader );
925             delete w;
926         }
927
928         cvSeqRemoveSlice( weak, slice );
929     }
930 }
931
932
933 void CvBoost::clear()
934 {
935     if( weak )
936     {
937         prune( CV_WHOLE_SEQ );
938         cvReleaseMemStorage( &weak->storage );
939     }
940     if( data )
941         delete data;
942     weak = 0;
943     data = 0;
944     cvReleaseMat( &active_vars );
945     cvReleaseMat( &active_vars_abs );
946     cvReleaseMat( &orig_response );
947     cvReleaseMat( &sum_response );
948     cvReleaseMat( &weak_eval );
949     cvReleaseMat( &subsample_mask );
950     cvReleaseMat( &weights );
951     cvReleaseMat( &subtree_weights );
952
953     have_subsample = false;
954 }
955
956
957 CvBoost::~CvBoost()
958 {
959     clear();
960 }
961
962
963 CvBoost::CvBoost( const CvMat* _train_data, int _tflag,
964                   const CvMat* _responses, const CvMat* _var_idx,
965                   const CvMat* _sample_idx, const CvMat* _var_type,
966                   const CvMat* _missing_mask, CvBoostParams _params )
967 {
968     weak = 0;
969     data = 0;
970     default_model_name = "my_boost_tree";
971
972     active_vars = active_vars_abs = orig_response = sum_response = weak_eval =
973         subsample_mask = weights = subtree_weights = 0;
974
975     train( _train_data, _tflag, _responses, _var_idx, _sample_idx,
976            _var_type, _missing_mask, _params );
977 }
978
979
980 bool
981 CvBoost::set_params( const CvBoostParams& _params )
982 {
983     bool ok = false;
984
985     CV_FUNCNAME( "CvBoost::set_params" );
986
987     __BEGIN__;
988
989     params = _params;
990     if( params.boost_type != DISCRETE && params.boost_type != REAL &&
991         params.boost_type != LOGIT && params.boost_type != GENTLE )
992         CV_ERROR( CV_StsBadArg, "Unknown/unsupported boosting type" );
993
994     params.weak_count = MAX( params.weak_count, 1 );
995     params.weight_trim_rate = MAX( params.weight_trim_rate, 0. );
996     params.weight_trim_rate = MIN( params.weight_trim_rate, 1. );
997     if( params.weight_trim_rate < FLT_EPSILON )
998         params.weight_trim_rate = 1.f;
999
1000     if( params.boost_type == DISCRETE &&
1001         params.split_criteria != GINI && params.split_criteria != MISCLASS )
1002         params.split_criteria = MISCLASS;
1003     if( params.boost_type == REAL &&
1004         params.split_criteria != GINI && params.split_criteria != MISCLASS )
1005         params.split_criteria = GINI;
1006     if( (params.boost_type == LOGIT || params.boost_type == GENTLE) &&
1007         params.split_criteria != SQERR )
1008         params.split_criteria = SQERR;
1009
1010     ok = true;
1011
1012     __END__;
1013
1014     return ok;
1015 }
1016
1017
1018 bool
1019 CvBoost::train( const CvMat* _train_data, int _tflag,
1020               const CvMat* _responses, const CvMat* _var_idx,
1021               const CvMat* _sample_idx, const CvMat* _var_type,
1022               const CvMat* _missing_mask,
1023               CvBoostParams _params, bool _update )
1024 {
1025     bool ok = false;
1026     CvMemStorage* storage = 0;
1027
1028     CV_FUNCNAME( "CvBoost::train" );
1029
1030     __BEGIN__;
1031
1032     int i;
1033     
1034     set_params( _params );
1035
1036     cvReleaseMat( &active_vars );
1037     cvReleaseMat( &active_vars_abs );
1038
1039     if( !_update || !data )
1040     {
1041         clear();
1042         data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
1043             _sample_idx, _var_type, _missing_mask, _params, true, true );
1044
1045         if( data->get_num_classes() != 2 )
1046             CV_ERROR( CV_StsNotImplemented,
1047             "Boosted trees can only be used for 2-class classification." );
1048         CV_CALL( storage = cvCreateMemStorage() );
1049         weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
1050         storage = 0;
1051     }
1052     else
1053     {
1054         data->set_data( _train_data, _tflag, _responses, _var_idx,
1055             _sample_idx, _var_type, _missing_mask, _params, true, true, true );
1056     }
1057
1058     if ( (_params.boost_type == LOGIT) || (_params.boost_type == GENTLE) )
1059         data->do_responses_copy();
1060     
1061     update_weights( 0 );
1062
1063     for( i = 0; i < params.weak_count; i++ )
1064     {
1065         CvBoostTree* tree = new CvBoostTree;
1066         if( !tree->train( data, subsample_mask, this ) )
1067         {
1068             delete tree;
1069             break;
1070         }
1071         //cvCheckArr( get_weak_response());
1072         cvSeqPush( weak, &tree );
1073         update_weights( tree );
1074         trim_weights();
1075         if( cvCountNonZero(subsample_mask) == 0 )
1076             break;
1077     }
1078
1079     get_active_vars(); // recompute active_vars* maps and condensed_idx's in the splits.
1080     data->is_classifier = true;
1081     ok = true;
1082
1083     data->free_train_data();
1084
1085     __END__;
1086
1087     return ok;
1088 }
1089
1090 bool CvBoost::train( CvMLData* _data,
1091              CvBoostParams params,
1092              bool update )
1093 {
1094     bool result = false;
1095
1096     CV_FUNCNAME( "CvBoost::train" );
1097
1098     __BEGIN__;
1099
1100     const CvMat* values = _data->get_values();
1101     const CvMat* response = _data->get_responses();
1102     const CvMat* missing = _data->get_missing();
1103     const CvMat* var_types = _data->get_var_types();
1104     const CvMat* train_sidx = _data->get_train_sample_idx();
1105     const CvMat* var_idx = _data->get_var_idx();
1106
1107     CV_CALL( result = train( values, CV_ROW_SAMPLE, response, var_idx,
1108         train_sidx, var_types, missing, params, update ) );
1109
1110     __END__;
1111
1112     return result;
1113 }
1114
1115 void
1116 CvBoost::update_weights( CvBoostTree* tree )
1117 {
1118     CV_FUNCNAME( "CvBoost::update_weights" );
1119
1120     __BEGIN__;
1121
1122     int i, n = data->sample_count;
1123     double sumw = 0.;
1124     int step = 0;
1125     float* fdata = 0;
1126     int *sample_idx_buf;
1127     const int* sample_idx = 0;
1128     cv::AutoBuffer<uchar> inn_buf;
1129     int _buf_size = (params.boost_type == LOGIT) || (params.boost_type == GENTLE) ? data->sample_count*sizeof(int) : 0;
1130     if( !tree )
1131         _buf_size += n*sizeof(int);
1132     else
1133     {
1134         if( have_subsample )
1135             _buf_size += data->buf->step*(sizeof(float)+sizeof(uchar));
1136     }
1137     inn_buf.allocate(_buf_size);
1138     uchar* cur_buf_pos = (uchar*)inn_buf;
1139
1140     if ( (params.boost_type == LOGIT) || (params.boost_type == GENTLE) )
1141     {
1142         step = CV_IS_MAT_CONT(data->responses_copy->type) ?
1143             1 : data->responses_copy->step / CV_ELEM_SIZE(data->responses_copy->type);
1144         fdata = data->responses_copy->data.fl;
1145         sample_idx_buf = (int*)cur_buf_pos;
1146         cur_buf_pos = (uchar*)(sample_idx_buf + data->sample_count);
1147         sample_idx = data->get_sample_indices( data->data_root, sample_idx_buf );
1148     }
1149     CvMat* dtree_data_buf = data->buf;
1150     if( !tree ) // before training the first tree, initialize weights and other parameters
1151     {
1152         int* class_labels_buf = (int*)cur_buf_pos;
1153         cur_buf_pos = (uchar*)(class_labels_buf + n);
1154         const int* class_labels = data->get_class_labels(data->data_root, class_labels_buf);
1155         // in case of logitboost and gentle adaboost each weak tree is a regression tree,
1156         // so we need to convert class labels to floating-point values
1157
1158         double w0 = 1./n;
1159         double p[2] = { 1, 1 };
1160
1161         cvReleaseMat( &orig_response );
1162         cvReleaseMat( &sum_response );
1163         cvReleaseMat( &weak_eval );
1164         cvReleaseMat( &subsample_mask );
1165         cvReleaseMat( &weights );
1166         cvReleaseMat( &subtree_weights );
1167
1168         CV_CALL( orig_response = cvCreateMat( 1, n, CV_32S ));
1169         CV_CALL( weak_eval = cvCreateMat( 1, n, CV_64F ));
1170         CV_CALL( subsample_mask = cvCreateMat( 1, n, CV_8U ));
1171         CV_CALL( weights = cvCreateMat( 1, n, CV_64F ));
1172         CV_CALL( subtree_weights = cvCreateMat( 1, n + 2, CV_64F ));
1173
1174         if( data->have_priors )
1175         {
1176             // compute weight scale for each class from their prior probabilities
1177             int c1 = 0;
1178             for( i = 0; i < n; i++ )
1179                 c1 += class_labels[i];
1180             p[0] = data->priors->data.db[0]*(c1 < n ? 1./(n - c1) : 0.);
1181             p[1] = data->priors->data.db[1]*(c1 > 0 ? 1./c1 : 0.);
1182             p[0] /= p[0] + p[1];
1183             p[1] = 1. - p[0];
1184         }
1185
1186         if (data->is_buf_16u)
1187         {
1188             unsigned short* labels = (unsigned short*)(dtree_data_buf->data.s + data->data_root->buf_idx*dtree_data_buf->cols +
1189                 data->data_root->offset + (data->work_var_count-1)*data->sample_count);
1190             for( i = 0; i < n; i++ )
1191             {
1192                 // save original categorical responses {0,1}, convert them to {-1,1}
1193                 orig_response->data.i[i] = class_labels[i]*2 - 1;
1194                 // make all the samples active at start.
1195                 // later, in trim_weights() deactivate/reactive again some, if need
1196                 subsample_mask->data.ptr[i] = (uchar)1;
1197                 // make all the initial weights the same.
1198                 weights->data.db[i] = w0*p[class_labels[i]];
1199                 // set the labels to find (from within weak tree learning proc)
1200                 // the particular sample weight, and where to store the response.
1201                 labels[i] = (unsigned short)i;
1202             }
1203         }
1204         else
1205         {
1206             int* labels = dtree_data_buf->data.i + data->data_root->buf_idx*dtree_data_buf->cols +
1207                 data->data_root->offset + (data->work_var_count-1)*data->sample_count;
1208
1209             for( i = 0; i < n; i++ )
1210             {
1211                 // save original categorical responses {0,1}, convert them to {-1,1}
1212                 orig_response->data.i[i] = class_labels[i]*2 - 1;
1213                 // make all the samples active at start.
1214                 // later, in trim_weights() deactivate/reactive again some, if need
1215                 subsample_mask->data.ptr[i] = (uchar)1;
1216                 // make all the initial weights the same.
1217                 weights->data.db[i] = w0*p[class_labels[i]];
1218                 // set the labels to find (from within weak tree learning proc)
1219                 // the particular sample weight, and where to store the response.
1220                 labels[i] = i;
1221             }
1222         }
1223
1224         if( params.boost_type == LOGIT )
1225         {
1226             CV_CALL( sum_response = cvCreateMat( 1, n, CV_64F ));
1227
1228             for( i = 0; i < n; i++ )
1229             {
1230                 sum_response->data.db[i] = 0;
1231                 fdata[sample_idx[i]*step] = orig_response->data.i[i] > 0 ? 2.f : -2.f;
1232             }
1233
1234             // in case of logitboost each weak tree is a regression tree.
1235             // the target function values are recalculated for each of the trees
1236             data->is_classifier = false;
1237         }
1238         else if( params.boost_type == GENTLE )
1239         {
1240             for( i = 0; i < n; i++ )
1241                 fdata[sample_idx[i]*step] = (float)orig_response->data.i[i];
1242
1243             data->is_classifier = false;
1244         }
1245     }
1246     else
1247     {
1248         // at this moment, for all the samples that participated in the training of the most
1249         // recent weak classifier we know the responses. For other samples we need to compute them
1250         if( have_subsample )
1251         {
1252             float* values = (float*)cur_buf_pos;
1253             cur_buf_pos = (uchar*)(values + data->buf->step);
1254             uchar* missing = cur_buf_pos;
1255             cur_buf_pos = missing + data->buf->step;
1256             CvMat _sample, _mask;
1257
1258             // invert the subsample mask
1259             cvXorS( subsample_mask, cvScalar(1.), subsample_mask );
1260             data->get_vectors( subsample_mask, values, missing, 0 );
1261            
1262             _sample = cvMat( 1, data->var_count, CV_32F );
1263             _mask = cvMat( 1, data->var_count, CV_8U );
1264
1265             // run tree through all the non-processed samples
1266             for( i = 0; i < n; i++ )
1267                 if( subsample_mask->data.ptr[i] )
1268                 {
1269                     _sample.data.fl = values;
1270                     _mask.data.ptr = missing;
1271                     values += _sample.cols;
1272                     missing += _mask.cols;
1273                     weak_eval->data.db[i] = tree->predict( &_sample, &_mask, true )->value;
1274                 }
1275         }
1276
1277         // now update weights and other parameters for each type of boosting
1278         if( params.boost_type == DISCRETE )
1279         {
1280             // Discrete AdaBoost:
1281             //   weak_eval[i] (=f(x_i)) is in {-1,1}
1282             //   err = sum(w_i*(f(x_i) != y_i))/sum(w_i)
1283             //   C = log((1-err)/err)
1284             //   w_i *= exp(C*(f(x_i) != y_i))
1285
1286             double C, err = 0.;
1287             double scale[] = { 1., 0. };
1288
1289             for( i = 0; i < n; i++ )
1290             {
1291                 double w = weights->data.db[i];
1292                 sumw += w;
1293                 err += w*(weak_eval->data.db[i] != orig_response->data.i[i]);
1294             }
1295
1296             if( sumw != 0 )
1297                 err /= sumw;
1298             C = err = -log_ratio( err );
1299             scale[1] = exp(err);
1300
1301             sumw = 0;
1302             for( i = 0; i < n; i++ )
1303             {
1304                 double w = weights->data.db[i]*
1305                     scale[weak_eval->data.db[i] != orig_response->data.i[i]];
1306                 sumw += w;
1307                 weights->data.db[i] = w;
1308             }
1309
1310             tree->scale( C );
1311         }
1312         else if( params.boost_type == REAL )
1313         {
1314             // Real AdaBoost:
1315             //   weak_eval[i] = f(x_i) = 0.5*log(p(x_i)/(1-p(x_i))), p(x_i)=P(y=1|x_i)
1316             //   w_i *= exp(-y_i*f(x_i))
1317
1318             for( i = 0; i < n; i++ )
1319                 weak_eval->data.db[i] *= -orig_response->data.i[i];
1320
1321             cvExp( weak_eval, weak_eval );
1322
1323             for( i = 0; i < n; i++ )
1324             {
1325                 double w = weights->data.db[i]*weak_eval->data.db[i];
1326                 sumw += w;
1327                 weights->data.db[i] = w;
1328             }
1329         }
1330         else if( params.boost_type == LOGIT )
1331         {
1332             // LogitBoost:
1333             //   weak_eval[i] = f(x_i) in [-z_max,z_max]
1334             //   sum_response = F(x_i).
1335             //   F(x_i) += 0.5*f(x_i)
1336             //   p(x_i) = exp(F(x_i))/(exp(F(x_i)) + exp(-F(x_i))=1/(1+exp(-2*F(x_i)))
1337             //   reuse weak_eval: weak_eval[i] <- p(x_i)
1338             //   w_i = p(x_i)*1(1 - p(x_i))
1339             //   z_i = ((y_i+1)/2 - p(x_i))/(p(x_i)*(1 - p(x_i)))
1340             //   store z_i to the data->data_root as the new target responses
1341
1342             const double lb_weight_thresh = FLT_EPSILON;
1343             const double lb_z_max = 10.;
1344             /*float* responses_buf = data->get_resp_float_buf();
1345             const float* responses = 0;
1346             data->get_ord_responses(data->data_root, responses_buf, &responses);*/
1347
1348             /*if( weak->total == 7 )
1349                 putchar('*');*/
1350
1351             for( i = 0; i < n; i++ )
1352             {
1353                 double s = sum_response->data.db[i] + 0.5*weak_eval->data.db[i];
1354                 sum_response->data.db[i] = s;
1355                 weak_eval->data.db[i] = -2*s;
1356             }
1357
1358             cvExp( weak_eval, weak_eval );
1359
1360             for( i = 0; i < n; i++ )
1361             {
1362                 double p = 1./(1. + weak_eval->data.db[i]);
1363                 double w = p*(1 - p), z;
1364                 w = MAX( w, lb_weight_thresh );
1365                 weights->data.db[i] = w;
1366                 sumw += w;
1367                 if( orig_response->data.i[i] > 0 )
1368                 {
1369                     z = 1./p;
1370                     fdata[sample_idx[i]*step] = (float)MIN(z, lb_z_max);
1371                 }
1372                 else
1373                 {
1374                     z = 1./(1-p);
1375                     fdata[sample_idx[i]*step] = (float)-MIN(z, lb_z_max);
1376                 }
1377             }
1378         }
1379         else
1380         {
1381             // Gentle AdaBoost:
1382             //   weak_eval[i] = f(x_i) in [-1,1]
1383             //   w_i *= exp(-y_i*f(x_i))
1384             assert( params.boost_type == GENTLE );
1385
1386             for( i = 0; i < n; i++ )
1387                 weak_eval->data.db[i] *= -orig_response->data.i[i];
1388
1389             cvExp( weak_eval, weak_eval );
1390
1391             for( i = 0; i < n; i++ )
1392             {
1393                 double w = weights->data.db[i] * weak_eval->data.db[i];
1394                 weights->data.db[i] = w;
1395                 sumw += w;
1396             }
1397         }
1398     }
1399
1400     // renormalize weights
1401     if( sumw > FLT_EPSILON )
1402     {
1403         sumw = 1./sumw;
1404         for( i = 0; i < n; ++i )
1405             weights->data.db[i] *= sumw;
1406     }
1407
1408     __END__;
1409 }
1410
1411
1412 static CV_IMPLEMENT_QSORT_EX( icvSort_64f, double, CV_LT, int )
1413
1414
1415 void
1416 CvBoost::trim_weights()
1417 {
1418     //CV_FUNCNAME( "CvBoost::trim_weights" );
1419
1420     __BEGIN__;
1421
1422     int i, count = data->sample_count, nz_count = 0;
1423     double sum, threshold;
1424
1425     if( params.weight_trim_rate <= 0. || params.weight_trim_rate >= 1. )
1426         EXIT;
1427
1428     // use weak_eval as temporary buffer for sorted weights
1429     cvCopy( weights, weak_eval );
1430
1431     icvSort_64f( weak_eval->data.db, count, 0 );
1432
1433     // as weight trimming occurs immediately after updating the weights,
1434     // where they are renormalized, we assume that the weight sum = 1.
1435     sum = 1. - params.weight_trim_rate;
1436
1437     for( i = 0; i < count; i++ )
1438     {
1439         double w = weak_eval->data.db[i];
1440         if( sum <= 0 )
1441             break;
1442         sum -= w;
1443     }
1444
1445     threshold = i < count ? weak_eval->data.db[i] : DBL_MAX;
1446
1447     for( i = 0; i < count; i++ )
1448     {
1449         double w = weights->data.db[i];
1450         int f = w >= threshold;
1451         subsample_mask->data.ptr[i] = (uchar)f;
1452         nz_count += f;
1453     }
1454
1455     have_subsample = nz_count < count;
1456
1457     __END__;
1458 }
1459
1460
1461 const CvMat* 
1462 CvBoost::get_active_vars( bool absolute_idx )
1463 {
1464     CvMat* mask = 0;
1465     CvMat* inv_map = 0;
1466     CvMat* result = 0;
1467     
1468     CV_FUNCNAME( "CvBoost::get_active_vars" );
1469
1470     __BEGIN__;
1471     
1472     if( !weak )
1473         CV_ERROR( CV_StsError, "The boosted tree ensemble has not been trained yet" );
1474
1475     if( !active_vars || !active_vars_abs )
1476     {
1477         CvSeqReader reader;
1478         int i, j, nactive_vars;
1479         CvBoostTree* wtree;
1480         const CvDTreeNode* node;
1481         
1482         assert(!active_vars && !active_vars_abs);
1483         mask = cvCreateMat( 1, data->var_count, CV_8U );
1484         inv_map = cvCreateMat( 1, data->var_count, CV_32S );
1485         cvZero( mask );
1486         cvSet( inv_map, cvScalar(-1) );
1487
1488         // first pass: compute the mask of used variables
1489         cvStartReadSeq( weak, &reader );
1490         for( i = 0; i < weak->total; i++ )
1491         {
1492             CV_READ_SEQ_ELEM(wtree, reader);
1493
1494             node = wtree->get_root();
1495             assert( node != 0 );
1496             for(;;)
1497             {
1498                 const CvDTreeNode* parent;
1499                 for(;;)
1500                 {
1501                     CvDTreeSplit* split = node->split;
1502                     for( ; split != 0; split = split->next )
1503                         mask->data.ptr[split->var_idx] = 1;
1504                     if( !node->left )
1505                         break;
1506                     node = node->left;
1507                 }
1508
1509                 for( parent = node->parent; parent && parent->right == node;
1510                     node = parent, parent = parent->parent )
1511                     ;
1512
1513                 if( !parent )
1514                     break;
1515
1516                 node = parent->right;
1517             }
1518         }
1519
1520         nactive_vars = cvCountNonZero(mask);
1521         
1522         //if ( nactive_vars > 0 )
1523         {
1524             active_vars = cvCreateMat( 1, nactive_vars, CV_32S );
1525             active_vars_abs = cvCreateMat( 1, nactive_vars, CV_32S );
1526
1527             have_active_cat_vars = false;
1528
1529             for( i = j = 0; i < data->var_count; i++ )
1530             {
1531                 if( mask->data.ptr[i] )
1532                 {
1533                     active_vars->data.i[j] = i;
1534                     active_vars_abs->data.i[j] = data->var_idx ? data->var_idx->data.i[i] : i;
1535                     inv_map->data.i[i] = j;
1536                     if( data->var_type->data.i[i] >= 0 )
1537                         have_active_cat_vars = true;
1538                     j++;
1539                 }
1540             }
1541             
1542
1543             // second pass: now compute the condensed indices
1544             cvStartReadSeq( weak, &reader );
1545             for( i = 0; i < weak->total; i++ )
1546             {
1547                 CV_READ_SEQ_ELEM(wtree, reader);
1548                 node = wtree->get_root();
1549                 for(;;)
1550                 {
1551                     const CvDTreeNode* parent;
1552                     for(;;)
1553                     {
1554                         CvDTreeSplit* split = node->split;
1555                         for( ; split != 0; split = split->next )
1556                         {
1557                             split->condensed_idx = inv_map->data.i[split->var_idx];
1558                             assert( split->condensed_idx >= 0 );
1559                         }
1560
1561                         if( !node->left )
1562                             break;
1563                         node = node->left;
1564                     }
1565
1566                     for( parent = node->parent; parent && parent->right == node;
1567                         node = parent, parent = parent->parent )
1568                         ;
1569
1570                     if( !parent )
1571                         break;
1572
1573                     node = parent->right;
1574                 }
1575             }
1576         }
1577     }
1578
1579     result = absolute_idx ? active_vars_abs : active_vars;
1580
1581     __END__;
1582
1583     cvReleaseMat( &mask );
1584     cvReleaseMat( &inv_map );
1585
1586     return result;
1587 }
1588
1589
1590 float
1591 CvBoost::predict( const CvMat* _sample, const CvMat* _missing,
1592                   CvMat* weak_responses, CvSlice slice,
1593                   bool raw_mode, bool return_sum ) const
1594 {
1595     float value = -FLT_MAX;
1596
1597     CvMat sample, missing;
1598     CvSeqReader reader;
1599     double sum = 0;
1600     int wstep = 0;
1601     const float* sample_data;
1602
1603     if( !weak )
1604         CV_Error( CV_StsError, "The boosted tree ensemble has not been trained yet" );
1605
1606     if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
1607         (_sample->cols != 1 && _sample->rows != 1) ||
1608         (_sample->cols + _sample->rows - 1 != data->var_all && !raw_mode) ||
1609         (active_vars && _sample->cols + _sample->rows - 1 != active_vars->cols && raw_mode) )
1610             CV_Error( CV_StsBadArg,
1611         "the input sample must be 1d floating-point vector with the same "
1612         "number of elements as the total number of variables or "
1613         "as the number of variables used for training" );
1614
1615     if( _missing )
1616     {
1617         if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
1618             !CV_ARE_SIZES_EQ(_missing, _sample) )
1619             CV_Error( CV_StsBadArg,
1620             "the missing data mask must be 8-bit vector of the same size as input sample" );
1621     }
1622
1623     int i, weak_count = cvSliceLength( slice, weak );
1624     if( weak_count >= weak->total )
1625     {
1626         weak_count = weak->total;
1627         slice.start_index = 0;
1628     }
1629
1630     if( weak_responses )
1631     {
1632         if( !CV_IS_MAT(weak_responses) ||
1633             CV_MAT_TYPE(weak_responses->type) != CV_32FC1 ||
1634             (weak_responses->cols != 1 && weak_responses->rows != 1) ||
1635             weak_responses->cols + weak_responses->rows - 1 != weak_count )
1636             CV_Error( CV_StsBadArg,
1637             "The output matrix of weak classifier responses must be valid "
1638             "floating-point vector of the same number of components as the length of input slice" );
1639         wstep = CV_IS_MAT_CONT(weak_responses->type) ? 1 : weak_responses->step/sizeof(float);
1640     }
1641     
1642     int var_count = active_vars->cols;
1643     const int* vtype = data->var_type->data.i;
1644     const int* cmap = data->cat_map->data.i;
1645     const int* cofs = data->cat_ofs->data.i;
1646
1647     // if need, preprocess the input vector
1648     if( !raw_mode )
1649     {
1650         int step, mstep = 0;
1651         const float* src_sample;
1652         const uchar* src_mask = 0;
1653         float* dst_sample;
1654         uchar* dst_mask;
1655         const int* vidx = active_vars->data.i;
1656         const int* vidx_abs = active_vars_abs->data.i;
1657         bool have_mask = _missing != 0;
1658
1659         cv::AutoBuffer<float> buf(var_count + (var_count+3)/4);
1660         dst_sample = &buf[0];
1661         dst_mask = (uchar*)&buf[var_count];
1662
1663         src_sample = _sample->data.fl;
1664         step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(src_sample[0]);
1665
1666         if( _missing )
1667         {
1668             src_mask = _missing->data.ptr;
1669             mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step;
1670         }
1671
1672         for( i = 0; i < var_count; i++ )
1673         {
1674             int idx = vidx[i], idx_abs = vidx_abs[i];
1675             float val = src_sample[idx_abs*step];
1676             int ci = vtype[idx];
1677             uchar m = src_mask ? src_mask[idx_abs*mstep] : (uchar)0;
1678
1679             if( ci >= 0 )
1680             {
1681                 int a = cofs[ci], b = (ci+1 >= data->cat_ofs->cols) ? data->cat_map->cols : cofs[ci+1],
1682                     c = a;
1683                 int ival = cvRound(val);
1684                 if ( (ival != val) && (!m) )
1685                     CV_Error( CV_StsBadArg,
1686                         "one of input categorical variable is not an integer" );
1687
1688                 while( a < b )
1689                 {
1690                     c = (a + b) >> 1;
1691                     if( ival < cmap[c] )
1692                         b = c;
1693                     else if( ival > cmap[c] )
1694                         a = c+1;
1695                     else
1696                         break;
1697                 }
1698
1699                 if( c < 0 || ival != cmap[c] )
1700                 {
1701                     m = 1;
1702                     have_mask = true;
1703                 }
1704                 else
1705                 {
1706                     val = (float)(c - cofs[ci]);
1707                 }
1708             }
1709
1710             dst_sample[i] = val;
1711             dst_mask[i] = m;
1712         }
1713
1714         sample = cvMat( 1, var_count, CV_32F, dst_sample );
1715         _sample = &sample;
1716
1717         if( have_mask )
1718         {
1719             missing = cvMat( 1, var_count, CV_8UC1, dst_mask );
1720             _missing = &missing;
1721         }
1722     }
1723     else
1724     {
1725         if( !CV_IS_MAT_CONT(_sample->type & (_missing ? _missing->type : -1)) )
1726             CV_Error( CV_StsBadArg, "In raw mode the input vectors must be continuous" );
1727     }
1728
1729     cvStartReadSeq( weak, &reader );
1730     cvSetSeqReaderPos( &reader, slice.start_index );
1731
1732     sample_data = _sample->data.fl;
1733
1734     if( !have_active_cat_vars && !_missing && !weak_responses )
1735     {
1736         for( i = 0; i < weak_count; i++ )
1737         {
1738             CvBoostTree* wtree;
1739             const CvDTreeNode* node;
1740             CV_READ_SEQ_ELEM( wtree, reader );
1741             
1742             node = wtree->get_root();
1743             while( node->left )
1744             {
1745                 CvDTreeSplit* split = node->split;
1746                 int vi = split->condensed_idx;
1747                 float val = sample_data[vi];
1748                 int dir = val <= split->ord.c ? -1 : 1;
1749                 if( split->inversed )
1750                     dir = -dir;
1751                 node = dir < 0 ? node->left : node->right;
1752             }
1753             sum += node->value;
1754         }
1755     }
1756     else
1757     {
1758         const int* avars = active_vars->data.i;
1759         const uchar* m = _missing ? _missing->data.ptr : 0;
1760         
1761         // full-featured version
1762         for( i = 0; i < weak_count; i++ )
1763         {
1764             CvBoostTree* wtree;
1765             const CvDTreeNode* node;
1766             CV_READ_SEQ_ELEM( wtree, reader );
1767             
1768             node = wtree->get_root();
1769             while( node->left )
1770             {
1771                 const CvDTreeSplit* split = node->split;
1772                 int dir = 0;
1773                 for( ; !dir && split != 0; split = split->next )
1774                 {
1775                     int vi = split->condensed_idx;
1776                     int ci = vtype[avars[vi]];
1777                     float val = sample_data[vi];
1778                     if( m && m[vi] )
1779                         continue;
1780                     if( ci < 0 ) // ordered
1781                         dir = val <= split->ord.c ? -1 : 1;
1782                     else // categorical
1783                     {
1784                         int c = cvRound(val);
1785                         dir = CV_DTREE_CAT_DIR(c, split->subset);
1786                     }
1787                     if( split->inversed )
1788                         dir = -dir;
1789                 }
1790
1791                 if( !dir )
1792                 {
1793                     int diff = node->right->sample_count - node->left->sample_count;
1794                     dir = diff < 0 ? -1 : 1;
1795                 }
1796                 node = dir < 0 ? node->left : node->right;
1797             }
1798             if( weak_responses )
1799                 weak_responses->data.fl[i*wstep] = (float)node->value;
1800             sum += node->value;
1801         }
1802     }
1803
1804     if( return_sum )
1805         value = (float)sum;
1806     else
1807     {
1808         int cls_idx = sum >= 0;
1809         if( raw_mode )
1810             value = (float)cls_idx;
1811         else
1812             value = (float)cmap[cofs[vtype[data->var_count]] + cls_idx];
1813     }
1814
1815     return value;
1816 }
1817
1818 float CvBoost::calc_error( CvMLData* _data, int type, std::vector<float> *resp )
1819 {
1820     float err = 0;
1821     const CvMat* values = _data->get_values();
1822     const CvMat* response = _data->get_responses();
1823     const CvMat* missing = _data->get_missing();
1824     const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
1825     const CvMat* var_types = _data->get_var_types();
1826     int* sidx = sample_idx ? sample_idx->data.i : 0;
1827     int r_step = CV_IS_MAT_CONT(response->type) ?
1828                 1 : response->step / CV_ELEM_SIZE(response->type);
1829     bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
1830     int sample_count = sample_idx ? sample_idx->cols : 0;
1831     sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
1832     float* pred_resp = 0;
1833     if( resp && (sample_count > 0) )
1834     {
1835         resp->resize( sample_count );
1836         pred_resp = &((*resp)[0]);
1837     }
1838     if ( is_classifier )
1839     {
1840         for( int i = 0; i < sample_count; i++ )
1841         {
1842             CvMat sample, miss;
1843             int si = sidx ? sidx[i] : i;
1844             cvGetRow( values, &sample, si ); 
1845             if( missing ) 
1846                 cvGetRow( missing, &miss, si );             
1847             float r = (float)predict( &sample, missing ? &miss : 0 );
1848             if( pred_resp )
1849                 pred_resp[i] = r;
1850             int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;
1851             err += d;
1852         }
1853         err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
1854     }
1855     else
1856     {
1857         for( int i = 0; i < sample_count; i++ )
1858         {
1859             CvMat sample, miss;
1860             int si = sidx ? sidx[i] : i;
1861             cvGetRow( values, &sample, si );
1862             if( missing ) 
1863                 cvGetRow( missing, &miss, si );             
1864             float r = (float)predict( &sample, missing ? &miss : 0 );
1865             if( pred_resp )
1866                 pred_resp[i] = r;
1867             float d = r - response->data.fl[si*r_step];
1868             err += d*d;
1869         }
1870         err = sample_count ? err / (float)sample_count : -FLT_MAX;    
1871     }
1872     return err;
1873 }
1874
1875 void CvBoost::write_params( CvFileStorage* fs ) const
1876 {
1877     const char* boost_type_str =
1878         params.boost_type == DISCRETE ? "DiscreteAdaboost" :
1879         params.boost_type == REAL ? "RealAdaboost" :
1880         params.boost_type == LOGIT ? "LogitBoost" :
1881         params.boost_type == GENTLE ? "GentleAdaboost" : 0;
1882
1883     const char* split_crit_str =
1884         params.split_criteria == DEFAULT ? "Default" :
1885         params.split_criteria == GINI ? "Gini" :
1886         params.boost_type == MISCLASS ? "Misclassification" :
1887         params.boost_type == SQERR ? "SquaredErr" : 0;
1888
1889     if( boost_type_str )
1890         cvWriteString( fs, "boosting_type", boost_type_str );
1891     else
1892         cvWriteInt( fs, "boosting_type", params.boost_type );
1893
1894     if( split_crit_str )
1895         cvWriteString( fs, "splitting_criteria", split_crit_str );
1896     else
1897         cvWriteInt( fs, "splitting_criteria", params.split_criteria );
1898
1899     cvWriteInt( fs, "ntrees", params.weak_count );
1900     cvWriteReal( fs, "weight_trimming_rate", params.weight_trim_rate );
1901
1902     data->write_params( fs );
1903 }
1904
1905
1906 void CvBoost::read_params( CvFileStorage* fs, CvFileNode* fnode )
1907 {
1908     CV_FUNCNAME( "CvBoost::read_params" );
1909
1910     __BEGIN__;
1911
1912     CvFileNode* temp;
1913
1914     if( !fnode || !CV_NODE_IS_MAP(fnode->tag) )
1915         return;
1916
1917     data = new CvDTreeTrainData();
1918     CV_CALL( data->read_params(fs, fnode));
1919     data->shared = true;
1920
1921     params.max_depth = data->params.max_depth;
1922     params.min_sample_count = data->params.min_sample_count;
1923     params.max_categories = data->params.max_categories;
1924     params.priors = data->params.priors;
1925     params.regression_accuracy = data->params.regression_accuracy;
1926     params.use_surrogates = data->params.use_surrogates;
1927
1928     temp = cvGetFileNodeByName( fs, fnode, "boosting_type" );
1929     if( !temp )
1930         return;
1931
1932     if( temp && CV_NODE_IS_STRING(temp->tag) )
1933     {
1934         const char* boost_type_str = cvReadString( temp, "" );
1935         params.boost_type = strcmp( boost_type_str, "DiscreteAdaboost" ) == 0 ? DISCRETE :
1936                             strcmp( boost_type_str, "RealAdaboost" ) == 0 ? REAL :
1937                             strcmp( boost_type_str, "LogitBoost" ) == 0 ? LOGIT :
1938                             strcmp( boost_type_str, "GentleAdaboost" ) == 0 ? GENTLE : -1;
1939     }
1940     else
1941         params.boost_type = cvReadInt( temp, -1 );
1942
1943     if( params.boost_type < DISCRETE || params.boost_type > GENTLE )
1944         CV_ERROR( CV_StsBadArg, "Unknown boosting type" );
1945
1946     temp = cvGetFileNodeByName( fs, fnode, "splitting_criteria" );
1947     if( temp && CV_NODE_IS_STRING(temp->tag) )
1948     {
1949         const char* split_crit_str = cvReadString( temp, "" );
1950         params.split_criteria = strcmp( split_crit_str, "Default" ) == 0 ? DEFAULT :
1951                                 strcmp( split_crit_str, "Gini" ) == 0 ? GINI :
1952                                 strcmp( split_crit_str, "Misclassification" ) == 0 ? MISCLASS :
1953                                 strcmp( split_crit_str, "SquaredErr" ) == 0 ? SQERR : -1;
1954     }
1955     else
1956         params.split_criteria = cvReadInt( temp, -1 );
1957
1958     if( params.split_criteria < DEFAULT || params.boost_type > SQERR )
1959         CV_ERROR( CV_StsBadArg, "Unknown boosting type" );
1960
1961     params.weak_count = cvReadIntByName( fs, fnode, "ntrees" );
1962     params.weight_trim_rate = cvReadRealByName( fs, fnode, "weight_trimming_rate", 0. );
1963
1964     __END__;
1965 }
1966
1967
1968
1969 void
1970 CvBoost::read( CvFileStorage* fs, CvFileNode* node )
1971 {
1972     CV_FUNCNAME( "CvBoost::read" );
1973
1974     __BEGIN__;
1975
1976     CvSeqReader reader;
1977     CvFileNode* trees_fnode;
1978     CvMemStorage* storage;
1979     int i, ntrees;
1980
1981     clear();
1982     read_params( fs, node );
1983
1984     if( !data )
1985         EXIT;
1986
1987     trees_fnode = cvGetFileNodeByName( fs, node, "trees" );
1988     if( !trees_fnode || !CV_NODE_IS_SEQ(trees_fnode->tag) )
1989         CV_ERROR( CV_StsParseError, "<trees> tag is missing" );
1990
1991     cvStartReadSeq( trees_fnode->data.seq, &reader );
1992     ntrees = trees_fnode->data.seq->total;
1993
1994     if( ntrees != params.weak_count )
1995         CV_ERROR( CV_StsUnmatchedSizes,
1996         "The number of trees stored does not match <ntrees> tag value" );
1997
1998     CV_CALL( storage = cvCreateMemStorage() );
1999     weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
2000
2001     for( i = 0; i < ntrees; i++ )
2002     {
2003         CvBoostTree* tree = new CvBoostTree();
2004         CV_CALL(tree->read( fs, (CvFileNode*)reader.ptr, this, data ));
2005         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
2006         cvSeqPush( weak, &tree );
2007     }
2008     get_active_vars();
2009
2010     __END__;
2011 }
2012
2013
2014 void
2015 CvBoost::write( CvFileStorage* fs, const char* name ) const
2016 {
2017     CV_FUNCNAME( "CvBoost::write" );
2018
2019     __BEGIN__;
2020
2021     CvSeqReader reader;
2022     int i;
2023
2024     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_BOOSTING );
2025
2026     if( !weak )
2027         CV_ERROR( CV_StsBadArg, "The classifier has not been trained yet" );
2028
2029     write_params( fs );
2030     cvStartWriteStruct( fs, "trees", CV_NODE_SEQ );
2031
2032     cvStartReadSeq( weak, &reader );
2033
2034     for( i = 0; i < weak->total; i++ )
2035     {
2036         CvBoostTree* tree;
2037         CV_READ_SEQ_ELEM( tree, reader );
2038         cvStartWriteStruct( fs, 0, CV_NODE_MAP );
2039         tree->write( fs );
2040         cvEndWriteStruct( fs );
2041     }
2042
2043     cvEndWriteStruct( fs );
2044     cvEndWriteStruct( fs );
2045
2046     __END__;
2047 }
2048
2049
2050 CvMat*
2051 CvBoost::get_weights()
2052 {
2053     return weights;
2054 }
2055
2056
2057 CvMat*
2058 CvBoost::get_subtree_weights()
2059 {
2060     return subtree_weights;
2061 }
2062
2063
2064 CvMat*
2065 CvBoost::get_weak_response()
2066 {
2067     return weak_eval;
2068 }
2069
2070
2071 const CvBoostParams&
2072 CvBoost::get_params() const
2073 {
2074     return params;
2075 }
2076
2077 CvSeq* CvBoost::get_weak_predictors()
2078 {
2079     return weak;
2080 }
2081
2082 const CvDTreeTrainData* CvBoost::get_data() const
2083 {
2084     return data;
2085 }
2086
2087 using namespace cv;
2088
2089 CvBoost::CvBoost( const Mat& _train_data, int _tflag,
2090                const Mat& _responses, const Mat& _var_idx,
2091                const Mat& _sample_idx, const Mat& _var_type,
2092                const Mat& _missing_mask,
2093                CvBoostParams _params )
2094 {
2095     weak = 0;
2096     data = 0;
2097     default_model_name = "my_boost_tree";
2098     active_vars = active_vars_abs = orig_response = sum_response = weak_eval =
2099         subsample_mask = weights = subtree_weights = 0;
2100     
2101     train( _train_data, _tflag, _responses, _var_idx, _sample_idx,
2102           _var_type, _missing_mask, _params );
2103 }    
2104
2105
2106 bool
2107 CvBoost::train( const Mat& _train_data, int _tflag,
2108                const Mat& _responses, const Mat& _var_idx,
2109                const Mat& _sample_idx, const Mat& _var_type,
2110                const Mat& _missing_mask,
2111                CvBoostParams _params, bool _update )
2112 {
2113     CvMat tdata = _train_data, responses = _responses, vidx = _var_idx,
2114         sidx = _sample_idx, vtype = _var_type, mmask = _missing_mask;
2115     return train(&tdata, _tflag, &responses, vidx.data.ptr ? &vidx : 0,
2116           sidx.data.ptr ? &sidx : 0, vtype.data.ptr ? &vtype : 0,
2117           mmask.data.ptr ? &mmask : 0, _params, _update);
2118 }
2119
2120 float
2121 CvBoost::predict( const Mat& _sample, const Mat& _missing,
2122                   const Range& slice, bool raw_mode, bool return_sum ) const
2123 {
2124     CvMat sample = _sample, mmask = _missing;
2125     /*if( weak_responses )
2126     {
2127         int weak_count = cvSliceLength( slice, weak );
2128         if( weak_count >= weak->total )
2129         {
2130             weak_count = weak->total;
2131             slice.start_index = 0;
2132         }
2133         
2134         if( !(weak_responses->data && weak_responses->type() == CV_32FC1 &&
2135               (weak_responses->cols == 1 || weak_responses->rows == 1) &&
2136               weak_responses->cols + weak_responses->rows - 1 == weak_count) )
2137             weak_responses->create(weak_count, 1, CV_32FC1);
2138         pwr = &(wr = *weak_responses);
2139     }*/
2140     return predict(&sample, _missing.empty() ? 0 : &mmask, 0,
2141                    slice == Range::all() ? CV_WHOLE_SEQ : cvSlice(slice.start, slice.end),
2142                    raw_mode, return_sum);
2143 }
2144
2145 /* End of file. */