// Helper macros for the gradient-boosted-trees (GBT) implementation.
// pCvDTreeNode aliases CvDTreeNode* so it can be used in new[] expressions.
9 #define pCvDTreeNode CvDTreeNode*
// Ascending comparators + CV_IMPLEMENT_QSORT_EX instantiations generating
// the file-local in-place sorters icvSortFloat / icvSortInt.
11 #define CV_CMP_FLOAT(a,b) ((a) < (b))
12 static CV_IMPLEMENT_QSORT_EX( icvSortFloat, float, CV_CMP_FLOAT, float)
13 #define CV_CMP_INT(a,b) ((a) < (b))
14 static CV_IMPLEMENT_QSORT_EX( icvSortInt, int, CV_CMP_INT, int)
16 //===========================================================================
// Integer -> decimal string conversion (body elided in this excerpt);
// NOTE(review): presumably used to build per-class node names like "trees_<k>"
// in read()/write() below — verify against those functions.
17 static string ToString(int i)
26 //===========================================================================
27 //----------------------------- CvGBTreesParams -----------------------------
28 //===========================================================================
// Default parameter set: forwards fixed decision-tree defaults to the
// CvDTreeParams base, then sets the GBT-specific fields. Squared loss
// (regression) and an 80% per-iteration subsample are the defaults shown;
// weak_count/shrinkage defaults are set in lines elided from this excerpt.
30 CvGBTreesParams::CvGBTreesParams()
31 : CvDTreeParams( 3, 10, 0, false, 10, 0, false, false, 0 )
34 loss_function_type = CvGBTrees::SQUARED_LOSS;
35 subsample_portion = 0.8f;
39 //===========================================================================
// Full-control constructor: same fixed tree defaults as the default ctor,
// with every GBT-specific setting recorded verbatim from the caller.
41 CvGBTreesParams::CvGBTreesParams( int _loss_function_type, int _weak_count,
42 float _shrinkage, float _subsample_portion,
43 int _max_depth, bool _use_surrogates )
44 : CvDTreeParams( 3, 10, 0, false, 10, 0, false, false, 0 )
46 loss_function_type = _loss_function_type;
47 weak_count = _weak_count;
48 shrinkage = _shrinkage;
49 subsample_portion = _subsample_portion;
// max_depth / use_surrogates override the CvDTreeParams base-class values.
50 max_depth = _max_depth;
51 use_surrogates = _use_surrogates;
54 //===========================================================================
55 //------------------------------- CvGBTrees ---------------------------------
56 //===========================================================================
// Default constructor: zero out every owned matrix pointer so clear() and
// train() can test them safely before releasing or (re)allocating.
58 CvGBTrees::CvGBTrees()
62 default_model_name = "my_boost_tree";
63 orig_response = sum_response = sum_response_tmp = 0;
64 subsample_train = subsample_test = 0;
65 missing = sample_idx = 0;
73 //===========================================================================
75 int CvGBTrees::get_len(const CvMat* mat) const
77 return (mat->cols > mat->rows) ? mat->cols : mat->rows;
80 //===========================================================================
// Release everything the model owns: every per-class tree sequence (each
// CvDTree deleted individually, then its backing CvMemStorage), followed by
// all bookkeeping matrices. Safe to call on a freshly constructed object
// because all pointers start out null.
82 void CvGBTrees::clear()
87 CvSlice slice = CV_WHOLE_SEQ;
90 //data->shared = false;
// Walk each class's sequence of weak learners and delete the trees.
91 for (int i=0; i<class_count; ++i)
93 int weak_count = cvSliceLength( slice, weak[i] );
94 if ((weak[i]) && (weak_count))
96 cvStartReadSeq( weak[i], &reader );
97 cvSetSeqReaderPos( &reader, slice.start_index );
98 for (int j=0; j<weak_count; ++j)
100 CV_READ_SEQ_ELEM( tree, reader );
// Releasing the storage frees the CvSeq containers themselves.
107 for (int i=0; i<class_count; ++i)
108 if (weak[i]) cvReleaseMemStorage( &(weak[i]->storage) );
113 data->shared = false;
// cvReleaseMat tolerates null pointers, so unconditional release is fine.
119 cvReleaseMat( &orig_response );
120 cvReleaseMat( &sum_response );
121 cvReleaseMat( &sum_response_tmp );
122 cvReleaseMat( &subsample_train );
123 cvReleaseMat( &subsample_test );
124 cvReleaseMat( &sample_idx );
125 cvReleaseMat( &missing );
126 cvReleaseMat( &class_labels );
129 //===========================================================================
// Destructor (body elided in this excerpt).
// NOTE(review): presumably just calls clear() — verify against the full file.
131 CvGBTrees::~CvGBTrees()
136 //===========================================================================
// Convenience constructor: performs the same null-initialization as the
// default constructor, then immediately trains on the supplied data.
138 CvGBTrees::CvGBTrees( const CvMat* _train_data, int _tflag,
139 const CvMat* _responses, const CvMat* _var_idx,
140 const CvMat* _sample_idx, const CvMat* _var_type,
141 const CvMat* _missing_mask, CvGBTreesParams _params )
145 default_model_name = "my_boost_tree";
146 orig_response = sum_response = sum_response_tmp = 0;
147 subsample_train = subsample_test = 0;
148 missing = sample_idx = 0;
// Delegates all real work to the main train() overload below.
153 train( _train_data, _tflag, _responses, _var_idx, _sample_idx,
154 _var_type, _missing_mask, _params );
157 //===========================================================================
159 bool CvGBTrees::problem_type() const
161 switch (params.loss_function_type)
163 case DEVIANCE_LOSS: return false;
164 default: return true;
168 //===========================================================================
// CvMLData convenience wrapper: extracts the value/response/index/type/mask
// matrices from the container and forwards them to the main train() overload
// (row-sample layout). The `update` flag is forwarded but not supported.
171 CvGBTrees::train( CvMLData* _data, CvGBTreesParams _params, bool update )
174 result = train ( _data->get_values(), CV_ROW_SAMPLE,
175 _data->get_responses(), _data->get_var_idx(),
176 _data->get_train_sample_idx(), _data->get_var_types(),
177 _data->get_missing(), _params, update);
178 //update is not supported
182 //===========================================================================
// Main training routine. Phases: (1) wrap the data in CvDTreeTrainData with a
// scratch response matrix that find_gradient() will overwrite each iteration;
// (2) copy the original responses and, for classification, enumerate the
// distinct class labels; (3) normalize the sample-index vector; (4) set the
// base (zero-term) model; (5) boost: per iteration and per class, fit a
// regression tree to the current pseudo-responses, re-fit its leaf values,
// and accumulate shrinkage-weighted predictions. Update is not supported.
186 CvGBTrees::train( const CvMat* _train_data, int _tflag,
187 const CvMat* _responses, const CvMat* _var_idx,
188 const CvMat* _sample_idx, const CvMat* _var_type,
189 const CvMat* _missing_mask,
190 CvGBTreesParams _params, bool /*_update*/ ) //update is not supported
192 CvMemStorage* storage = 0;
195 bool is_regression = problem_type();
200 m - count of variables
202 int n = _train_data->rows;
203 int m = _train_data->cols;
// Only row-per-sample layout is accepted here (error path elided).
204 if (_tflag != CV_ROW_SAMPLE)
// new_responses is the scratch target the boosted trees are fit against;
// find_gradient() fills it with the loss gradient each iteration.
210 CvMat* new_responses = cvCreateMat( n, 1, CV_32F);
211 cvZero(new_responses);
213 data = new CvDTreeTrainData( _train_data, _tflag, new_responses, _var_idx,
214 _sample_idx, _var_type, _missing_mask, _params, true, true );
// Keep a private copy of the missing-value mask for prediction-time use.
217 missing = cvCreateMat(_missing_mask->rows, _missing_mask->cols,
218 _missing_mask->type);
219 cvCopy( _missing_mask, missing);
// Copy responses into a flat 1xN float vector regardless of input layout.
222 orig_response = cvCreateMat( 1, n, CV_32F );
223 int step = (_responses->cols > _responses->rows) ? 1 : _responses->step / CV_ELEM_SIZE(_responses->type);
224 switch (CV_MAT_TYPE(_responses->type))
228 for (int i=0; i<n; ++i)
229 orig_response->data.fl[i] = _responses->data.fl[i*step];
233 for (int i=0; i<n; ++i)
234 orig_response->data.fl[i] = (float) _responses->data.i[i*step];
237 CV_Error(CV_StsUnmatchedFormats, "Response should be a 32fC1 or 32sC1 vector.");
// Classification path: count distinct (integer-truncated) response values.
243 unsigned char * mask = new unsigned char[n];
245 // compute the count of different output classes
246 for (int i=0; i<n; ++i)
250 for (int j=i; j<n; ++j)
251 if (int(orig_response->data.fl[j]) == int(orig_response->data.fl[i]))
// Collect the distinct labels into class_labels (1 x class_count, int).
256 class_labels = cvCreateMat(1, class_count, CV_32S);
257 class_labels->data.i[0] = int(orig_response->data.fl[0]);
259 for (int i=1; i<n; ++i)
262 while ((k<j) && (int(orig_response->data.fl[i]) - class_labels->data.i[k]))
266 class_labels->data.i[k] = int(orig_response->data.fl[i]);
272 // inside gbt learning process only regression decision trees are built
273 data->is_classifier = false;
275 // preprocessing sample indices
278 int sample_idx_len = get_len(_sample_idx);
280 switch (CV_MAT_TYPE(_sample_idx->type))
// 32s input: copy the indices and sort them ascending.
284 sample_idx = cvCreateMat( 1, sample_idx_len, CV_32S );
285 for (int i=0; i<sample_idx_len; ++i)
286 sample_idx->data.i[i] = _sample_idx->data.i[i];
287 icvSortInt(sample_idx->data.i, sample_idx_len, 0);
// 8u/8s mask input: convert nonzero positions into an index list.
292 int active_samples_count = 0;
293 for (int i=0; i<sample_idx_len; ++i)
294 active_samples_count += int( _sample_idx->data.ptr[i] );
295 sample_idx = cvCreateMat( 1, active_samples_count, CV_32S );
296 active_samples_count = 0;
297 for (int i=0; i<sample_idx_len; ++i)
298 if (int( _sample_idx->data.ptr[i] ))
299 sample_idx->data.i[active_samples_count++] = i;
302 default: CV_Error(CV_StsUnmatchedFormats, "_sample_idx should be a 32sC1, 8sC1 or 8uC1 vector.");
// No sample_idx supplied: use every sample 0..n-1.
307 sample_idx = cvCreateMat( 1, n, CV_32S );
308 for (int i=0; i<n; ++i)
309 sample_idx->data.i[i] = i;
// Running ensemble predictions, one row per class series.
312 sum_response = cvCreateMat(class_count, n, CV_32F);
313 sum_response_tmp = cvCreateMat(class_count, n, CV_32F);
314 cvZero(sum_response);
318 in the case of a regression problem the initial guess (the zero term
319 in the sum) is set to the mean of all the training responses, that is
320 the best constant model
322 if (is_regression) base_value = find_optimal_value(sample_idx);
324 in the case of a classification problem the initial guess (the zero term
325 in the sum) is set to zero for all the trees sequences
327 else base_value = 0.0f;
329 current prediction on all training samples is set to be
330 equal to the base_value
332 cvSet( sum_response, cvScalar(base_value) );
// One CvSeq of weak learners per class, each on its own storage.
334 weak = new pCvSeq[class_count];
335 for (int i=0; i<class_count; ++i)
337 storage = cvCreateMemStorage();
338 weak[i] = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvDTree*), storage );
342 // subsample params and data
345 int samples_count = get_len(sample_idx);
// Degenerate portions (~0 or ~1) fall back to using the full sample set.
347 params.subsample_portion = params.subsample_portion <= FLT_EPSILON ||
348 1 - params.subsample_portion <= FLT_EPSILON
349 ? 1 : params.subsample_portion;
350 int train_sample_count = cvFloor(params.subsample_portion * samples_count);
351 if (train_sample_count == 0)
352 train_sample_count = samples_count;
353 int test_sample_count = samples_count - train_sample_count;
// subsample_train / subsample_test are headers over one shared idx_data
// buffer: the first train_sample_count entries train, the rest validate.
354 int* idx_data = new int[samples_count];
355 subsample_train = cvCreateMatHeader( 1, train_sample_count, CV_32SC1 );
356 *subsample_train = cvMat( 1, train_sample_count, CV_32SC1, idx_data );
357 if (test_sample_count)
359 subsample_test = cvCreateMatHeader( 1, test_sample_count, CV_32SC1 );
360 *subsample_test = cvMat( 1, test_sample_count, CV_32SC1,
361 idx_data + train_sample_count );
364 // training procedure
366 for ( int i=0; i < params.weak_count; ++i )
369 for ( int k=0; k < class_count; ++k )
// Fit one regression tree to the current pseudo-responses, then replace
// its leaf values with the loss-optimal constants (change_values).
372 CvDTree* tree = new CvDTree;
373 tree->train( data, subsample_train );
374 change_values(tree, k);
// Propagate this tree's (shrinkage-weighted) predictions to the held-out
// part of the subsample; the training part was updated in change_values.
380 int* sample_data = sample_idx->data.i;
381 int* subsample_data = subsample_test->data.i;
382 int s_step = (sample_idx->cols > sample_idx->rows) ? 1
383 : sample_idx->step/CV_ELEM_SIZE(sample_idx->type);
384 for (int j=0; j<get_len(subsample_test); ++j)
386 int idx = *(sample_data + subsample_data[j]*s_step);
388 if (_tflag == CV_ROW_SAMPLE)
389 cvGetRow( data->train_data, &x, idx);
391 cvGetCol( data->train_data, &x, idx);
395 if (_tflag == CV_ROW_SAMPLE)
396 cvGetRow( missing, &x_miss, idx);
398 cvGetCol( missing, &x_miss, idx);
400 res = (float)tree->predict(&x, &x_miss)->value;
404 res = (float)tree->predict(&x)->value;
406 sum_response_tmp->data.fl[idx + k*n] =
407 sum_response->data.fl[idx + k*n] +
408 params.shrinkage * res;
412 cvSeqPush( weak[k], &tree );
414 } // k=0..class_count
// Double-buffer swap: the freshly accumulated predictions become current.
416 tmp = sum_response_tmp;
417 sum_response_tmp = sum_response;
420 } // i=0..params.weak_count
423 cvReleaseMat(&new_responses);
424 data->free_train_data();
428 } // CvGBTrees::train(...)
430 //===========================================================================
// Three-way sign of x: -1 for negative, +1 for positive, 0 otherwise.
// The zero return also covers NaN (neither comparison fires), so the
// function always returns a value — the visible fragment fell off the end
// of the function on the x == 0 path, which is undefined behavior.
inline float Sign(float x)
{
    if (x < 0.0f) return -1.0f;
    else if (x > 0.0f) return 1.0f;
    return 0.0f;
}
439 //===========================================================================
// Writes the pseudo-response (negative loss gradient at the current ensemble
// prediction) into data->responses for every sample of the training
// subsample, for class series k. This is what the next weak tree is fit to.
441 void CvGBTrees::find_gradient(const int k)
443 int* sample_data = sample_idx->data.i;
444 int* subsample_data = subsample_train->data.i;
445 float* grad_data = data->responses->data.fl;
446 float* resp_data = orig_response->data.fl;
447 float* current_data = sum_response->data.fl;
449 switch (params.loss_function_type)
450 // loss_function_type in
451 // {SQUARED_LOSS, ABSOLUTE_LOSS, HUBER_LOSS, DEVIANCE_LOSS}
// Squared loss: gradient is the plain residual y - F(x).
455 for (int i=0; i<get_len(subsample_train); ++i)
457 int s_step = (sample_idx->cols > sample_idx->rows) ? 1
458 : sample_idx->step/CV_ELEM_SIZE(sample_idx->type);
459 int idx = *(sample_data + subsample_data[i]*s_step);
460 grad_data[idx] = resp_data[idx] - current_data[idx];
// Absolute loss: gradient is sign(y - F(x)).
466 for (int i=0; i<get_len(subsample_train); ++i)
468 int s_step = (sample_idx->cols > sample_idx->rows) ? 1
469 : sample_idx->step/CV_ELEM_SIZE(sample_idx->type);
470 int idx = *(sample_data + subsample_data[i]*s_step);
471 grad_data[idx] = Sign(resp_data[idx] - current_data[idx]);
// Huber loss: residuals clipped at delta, the alpha-quantile of |residual|
// (alpha set in lines elided here), giving robustness to outliers.
478 int n = get_len(subsample_train);
479 int s_step = (sample_idx->cols > sample_idx->rows) ? 1
480 : sample_idx->step/CV_ELEM_SIZE(sample_idx->type);
482 float* residuals = new float[n];
483 for (int i=0; i<n; ++i)
485 int idx = *(sample_data + subsample_data[i]*s_step);
486 residuals[i] = fabs(resp_data[idx] - current_data[idx]);
488 icvSortFloat(residuals, n, 0.0f);
490 delta = residuals[int(ceil(n*alpha))];
492 for (int i=0; i<n; ++i)
494 int idx = *(sample_data + subsample_data[i]*s_step);
495 float r = resp_data[idx] - current_data[idx];
496 grad_data[idx] = (fabs(r) > delta) ? delta*Sign(r) : r;
// Deviance (classification): gradient of the K-class log-loss,
// indicator(label == k) minus the softmax-style probability exp_fk/exp_sfi
// (the exp accumulation itself is in lines elided here).
504 for (int i=0; i<get_len(subsample_train); ++i)
508 int s_step = (sample_idx->cols > sample_idx->rows) ? 1
509 : sample_idx->step/CV_ELEM_SIZE(sample_idx->type);
510 int idx = *(sample_data + subsample_data[i]*s_step);
512 for (int j=0; j<class_count; ++j)
515 res = current_data[idx + j*sum_response->cols];
517 if (j == k) exp_fk = res;
520 int orig_label = int(resp_data[idx]);
522 grad_data[idx] = (float)(!(k-class_labels->data.i[orig_label]+1)) -
523 (float)(exp_fk / exp_sfi);
// Alternate branch: map the original label to its ensemble index first.
525 int ensemble_label = 0;
526 while (class_labels->data.i[ensemble_label] - orig_label)
529 grad_data[idx] = (float)(!(k-ensemble_label)) -
530 (float)(exp_fk / exp_sfi);
537 } // CvGBTrees::find_gradient(...)
539 //===========================================================================
// Re-fits the freshly trained tree's leaf values: each leaf's value becomes
// the loss-optimal constant (find_optimal_value) over the training-subsample
// points that land in it, then the shrinkage-weighted value is folded into
// the running predictions for class series _k.
541 void CvGBTrees::change_values(CvDTree* tree, const int _k)
// Cache, per subsample element, which leaf node the tree routes it to.
543 CvDTreeNode** predictions = new pCvDTreeNode[get_len(subsample_train)];
545 int* sample_data = sample_idx->data.i;
546 int* subsample_data = subsample_train->data.i;
547 int s_step = (sample_idx->cols > sample_idx->rows) ? 1
548 : sample_idx->step/CV_ELEM_SIZE(sample_idx->type);
553 for (int i=0; i<get_len(subsample_train); ++i)
555 int idx = *(sample_data + subsample_data[i]*s_step);
556 if (data->tflag == CV_ROW_SAMPLE)
557 cvGetRow( data->train_data, &x, idx);
559 cvGetCol( data->train_data, &x, idx);
// Use the missing-value mask when one was supplied at training time.
563 if (data->tflag == CV_ROW_SAMPLE)
564 cvGetRow( missing, &miss_x, idx);
566 cvGetCol( missing, &miss_x, idx);
568 predictions[i] = tree->predict(&x, &miss_x);
571 predictions[i] = tree->predict(&x);
575 CvDTreeNode** leaves;
576 int leaves_count = 0;
577 leaves = GetLeaves( tree, leaves_count);
578 for (int i=0; i<leaves_count; ++i)
581 int samples_in_leaf = 0;
582 for (int j=0; j<get_len(subsample_train); ++j)
584 if (leaves[i] == predictions[j]) samples_in_leaf++;
// Empty leaf: no sample reached it, so give it a neutral value.
587 if (!samples_in_leaf) // It should not be done anyways! but...
589 leaves[i]->value = 0.0;
// Gather the original-data indices of the samples in this leaf.
593 CvMat* leaf_idx = cvCreateMat(1, samples_in_leaf, CV_32S);
594 int* leaf_idx_data = leaf_idx->data.i;
596 for (int j=0; j<get_len(subsample_train); ++j)
598 int idx = *(sample_data + subsample_data[j]*s_step);
599 if (leaves[i] == predictions[j])
600 *leaf_idx_data++ = idx;
603 float value = find_optimal_value(leaf_idx);
604 leaves[i]->value = value;
606 leaf_idx_data = leaf_idx->data.i;
// Fold shrinkage * value into the next-iteration prediction buffer.
608 int len = sum_response_tmp->cols;
609 for (int j=0; j<get_len(leaf_idx); ++j)
611 int idx = leaf_idx_data[j];
612 sum_response_tmp->data.fl[idx + _k*len] =
613 sum_response->data.fl[idx + _k*len] +
614 params.shrinkage * value;
617 cvReleaseMat(&leaf_idx);
620 // releasing the memory
621 for (int i=0; i<get_len(subsample_train); ++i)
625 delete[] predictions;
627 for (int i=0; i<leaves_count; ++i)
635 //===========================================================================
// Second change_values implementation: instead of re-predicting every sample,
// it reads each leaf's sample indices straight from the train data
// (data->get_sample_indices), computes the optimal leaf value, and updates
// the prediction buffer for class series _k.
// NOTE(review): a second definition with an identical signature appears above;
// one of the two is presumably disabled by a preprocessor guard in the lines
// elided from this excerpt — verify against the full file.
637 void CvGBTrees::change_values(CvDTree* tree, const int _k)
640 CvDTreeNode** leaves;
641 int leaves_count = 0;
// Row offset of class _k inside the (class_count x n) prediction matrix.
642 int offset = _k*sum_response_tmp->cols;
646 leaves = GetLeaves( tree, leaves_count);
648 for (int i=0; i<leaves_count; ++i)
650 int n = leaves[i]->sample_count;
651 int* leaf_idx_data = new int[n];
652 data->get_sample_indices(leaves[i], leaf_idx_data);
653 //CvMat* leaf_idx = new CvMat();
654 //cvInitMatHeader(leaf_idx, n, 1, CV_32S, leaf_idx_data);
// leaf_idx is a stack-allocated CvMat header over leaf_idx_data
// (its initialization lines are elided in this excerpt).
656 leaf_idx.data.i = leaf_idx_data;
658 float value = find_optimal_value(&leaf_idx);
659 leaves[i]->value = value;
660 float val = params.shrinkage * value;
663 for (int j=0; j<n; ++j)
665 int idx = leaf_idx_data[j] + offset;
666 sum_response_tmp->data.fl[idx] = sum_response->data.fl[idx] + val;
669 //cvReleaseMat(&leaf_idx);
672 delete[] leaf_idx_data;
675 // releasing the memory
676 for (int i=0; i<leaves_count; ++i)
682 } //change_values(...);
684 //===========================================================================
// Returns the constant gamma minimizing the configured loss over the samples
// indexed by _Idx, given the current residuals (orig_response - sum_response):
// squared -> mean residual; absolute -> median residual; Huber -> robust
// one-step location estimate around the median; deviance -> Newton step.
686 float CvGBTrees::find_optimal_value( const CvMat* _Idx )
689 double gamma = (double)0.0;
691 int* idx = _Idx->data.i;
692 float* resp_data = orig_response->data.fl;
693 float* cur_data = sum_response->data.fl;
694 int n = get_len(_Idx);
696 switch (params.loss_function_type)
697 // SQUARED_LOSS=0, ABSOLUTE_LOSS=1, HUBER_LOSS=3, DEVIANCE_LOSS=4
// Squared loss: sum of residuals (division by n in lines elided here).
701 for (int i=0; i<n; ++i)
702 gamma += resp_data[idx[i]] - cur_data[idx[i]];
// Absolute loss: median of the residuals (even n averages the middle pair).
708 float* residuals = new float[n];
709 for (int i=0; i<n; ++i, ++idx)
710 residuals[i] = (resp_data[*idx] - cur_data[*idx]);
711 icvSortFloat(residuals, n, 0.0f);
713 gamma = residuals[n/2];
714 else gamma = (residuals[n/2-1] + residuals[n/2]) / 2.0f;
// Huber loss: start from the residual median, then add the mean of the
// delta-clipped deviations around it.
720 float* residuals = new float[n];
721 for (int i=0; i<n; ++i, ++idx)
722 residuals[i] = (resp_data[*idx] - cur_data[*idx]);
723 icvSortFloat(residuals, n, 0.0f);
726 float r_median = (n == n_half<<1) ?
727 (residuals[n_half-1] + residuals[n_half]) / 2.0f :
730 for (int i=0; i<n; ++i)
732 float dif = residuals[i] - r_median;
733 gamma += (delta < fabs(dif)) ? Sign(dif)*delta : dif;
// Deviance loss: Newton-Raphson step (K-1)/K * sum(g) / sum(|g|(1-|g|))
// over the stored gradients g in data->responses.
743 float* grad_data = data->responses->data.fl;
747 for (int i=0; i<n; ++i)
749 tmp = grad_data[idx[i]];
751 tmp2 += fabs(tmp)*(1-fabs(tmp));
758 gamma = ((double)(class_count-1)) / (double)class_count * (tmp1/tmp2);
766 } // CvGBTrees::find_optimal_value
768 //===========================================================================
771 void CvGBTrees::leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node )
773 if (node->left != NULL) leaves_get(leaves, count, node->left);
774 if (node->right != NULL) leaves_get(leaves, count, node->right);
775 if ((node->left == NULL) && (node->right == NULL))
776 leaves[count++] = node;
779 //---------------------------------------------------------------------------
781 CvDTreeNode** CvGBTrees::GetLeaves( const CvDTree* dtree, int& len )
784 CvDTreeNode** leaves = new pCvDTreeNode[(size_t)1 << params.max_depth];
785 leaves_get(leaves, len, const_cast<pCvDTreeNode>(dtree->get_root()));
789 //===========================================================================
// Prepares the per-iteration train/test split over sample_idx positions.
// When a held-out part exists, the shared index buffer is filled 0..n-1 and
// shuffled in place via CV_SWAP (the RNG lines are elided in this excerpt);
// otherwise subsample_train is simply the identity mapping.
791 void CvGBTrees::do_subsample()
794 int n = get_len(sample_idx);
795 int* idx = subsample_train->data.i;
797 for (int i = 0; i < n; i++ )
// Fisher-Yates-style shuffle of the whole index buffer.
801 for (int i = 0; i < n; i++)
806 CV_SWAP( idx[a], idx[b], t );
// No test part: every position trains, in order.
810 int n = get_len(sample_idx);
811 if (subsample_train == 0)
812 subsample_train = cvCreateMat(1, n, CV_32S);
813 int* subsample_data = subsample_train->data.i;
814 for (int i=0; i<n; ++i)
815 subsample_data[i] = i;
820 //===========================================================================
// Single-threaded prediction for one sample: sums shrinkage-weighted outputs
// of the trees in `slice` for each class series, adds base_value, and returns
// either the regression value (class_count == 1), the k-th class sum, or the
// original label of the argmax class. Optionally records every per-tree
// response into `weak_responses` (must be CV_32F with matching dimensions).
822 float CvGBTrees::predict_serial( const CvMat* _sample, const CvMat* _missing,
823 CvMat* weak_responses, CvSlice slice, int k) const
827 if (!weak) return 0.0f;
830 int weak_count = cvSliceLength( slice, weak[class_count-1] );
// Validate the optional per-tree output matrix (error paths elided).
835 if (CV_MAT_TYPE(weak_responses->type) != CV_32F)
837 if ((k >= 0) && (k<class_count) && (weak_responses->rows != 1))
839 if ((k == -1) && (weak_responses->rows != class_count))
841 if (weak_responses->cols != weak_count)
845 float* sum = new float[class_count];
846 memset(sum, 0, class_count*sizeof(float));
// Accumulate every requested tree's output per class series.
848 for (int i=0; i<class_count; ++i)
850 if ((weak[i]) && (weak_count))
852 cvStartReadSeq( weak[i], &reader );
853 cvSetSeqReaderPos( &reader, slice.start_index );
854 for (int j=0; j<weak_count; ++j)
856 CV_READ_SEQ_ELEM( tree, reader );
857 float p = (float)(tree->predict(_sample, _missing)->value);
858 sum[i] += params.shrinkage * p;
860 weak_responses->data.fl[i*weak_count+j] = p;
865 for (int i=0; i<class_count; ++i)
866 sum[i] += base_value;
// Regression: single series, return its sum directly (return elided).
868 if (class_count == 1)
875 if ((k>=0) && (k<class_count))
// Classification: argmax over class sums, then map back to original label.
884 for (int i=1; i<class_count; ++i)
894 int orig_class_label = -1;
895 for (int i=0; i<get_len(class_labels); ++i)
896 if (class_labels->data.i[i] == class_label+1)
897 orig_class_label = i;
899 int orig_class_label = class_labels->data.i[class_label];
901 return float(orig_class_label);
// Parallel loop body for CvGBTrees::predict(): each invocation processes a
// range of tree indices, accumulating shrinkage-weighted predictions of all
// k class series into the shared `sum` array under SumMutex.
905 class Tree_predictor : public cv::ParallelLoopBody
912 const CvMat* missing;
913 const float shrinkage;
// Shared across all workers; guards the += into the shared sum array.
915 static cv::Mutex SumMutex;
919 Tree_predictor() : weak(0), sum(0), k(0), sample(0), missing(0), shrinkage(1.0f) {}
920 Tree_predictor(pCvSeq* _weak, const int _k, const float _shrinkage,
921 const CvMat* _sample, const CvMat* _missing, float* _sum ) :
922 weak(_weak), sum(_sum), k(_k), sample(_sample),
923 missing(_missing), shrinkage(_shrinkage)
926 Tree_predictor( const Tree_predictor& p, cv::Split ) :
927 weak(p.weak), sum(p.sum), k(p.k), sample(p.sample),
928 missing(p.missing), shrinkage(p.shrinkage)
// Assignment intentionally does nothing (members are const/shared).
931 Tree_predictor& operator=( const Tree_predictor& )
934 virtual void operator()(const cv::Range& range) const
937 int begin = range.start;
940 int weak_count = end - begin;
943 for (int i=0; i<k; ++i)
// Accumulate this range of trees locally, then merge under the lock.
945 float tmp_sum = 0.0f;
946 if ((weak[i]) && (weak_count))
948 cvStartReadSeq( weak[i], &reader );
949 cvSetSeqReaderPos( &reader, begin );
950 for (int j=0; j<weak_count; ++j)
952 CV_READ_SEQ_ELEM( tree, reader );
953 tmp_sum += shrinkage*(float)(tree->predict(sample, missing)->value);
958 cv::AutoLock lock(SumMutex);
962 } // Tree_predictor::operator()
964 virtual ~Tree_predictor() {}
966 }; // class Tree_predictor
// Definition of the shared mutex declared static inside the class.
968 cv::Mutex Tree_predictor::SumMutex;
// Parallel prediction for one sample: distributes the trees in `slice`
// across threads via Tree_predictor, then resolves the result exactly like
// predict_serial (regression value, k-th class sum, or argmax label).
// Per-tree weak_responses output is not supported in this parallel path.
971 float CvGBTrees::predict( const CvMat* _sample, const CvMat* _missing,
972 CvMat* /*weak_responses*/, CvSlice slice, int k) const
975 if (!weak) return 0.0f;
976 float* sum = new float[class_count];
977 for (int i=0; i<class_count; ++i)
979 int begin = slice.start_index;
980 int end = begin + cvSliceLength( slice, weak[0] );
982 pCvSeq* weak_seq = weak;
983 Tree_predictor predictor = Tree_predictor(weak_seq, class_count,
984 params.shrinkage, _sample, _missing, sum);
// Each worker handles a sub-range of tree indices [begin, end).
986 cv::parallel_for_(cv::Range(begin, end), predictor);
// Shrinkage was already applied per tree inside the predictor.
988 for (int i=0; i<class_count; ++i)
989 sum[i] = sum[i] /** params.shrinkage*/ + base_value;
991 if (class_count == 1)
998 if ((k>=0) && (k<class_count))
1006 int class_label = 0;
1007 for (int i=1; i<class_count; ++i)
// Map the winning ensemble index back to the original class label.
1015 int orig_class_label = class_labels->data.i[class_label];
1017 return float(orig_class_label);
1021 //===========================================================================
// Serializes the training parameters: the loss as a symbolic string when
// recognized (falling back to the raw int), ensemble size, shrinkage,
// subsample portion, optional class labels, and the tree-data parameters.
1023 void CvGBTrees::write_params( CvFileStorage* fs ) const
1025 const char* loss_function_type_str =
1026 params.loss_function_type == SQUARED_LOSS ? "SquaredLoss" :
1027 params.loss_function_type == ABSOLUTE_LOSS ? "AbsoluteLoss" :
1028 params.loss_function_type == HUBER_LOSS ? "HuberLoss" :
1029 params.loss_function_type == DEVIANCE_LOSS ? "DevianceLoss" : 0;
1032 if( loss_function_type_str )
1033 cvWriteString( fs, "loss_function", loss_function_type_str );
1035 cvWriteInt( fs, "loss_function", params.loss_function_type );
1037 cvWriteInt( fs, "ensemble_length", params.weak_count );
1038 cvWriteReal( fs, "shrinkage", params.shrinkage );
1039 cvWriteReal( fs, "subsample_portion", params.subsample_portion );
1040 //cvWriteInt( fs, "max_tree_depth", params.max_depth );
1041 //cvWriteString( fs, "use_surrogate_splits", params.use_surrogates ? "true" : "false");
1042 if (class_labels) cvWrite( fs, "class_labels", class_labels);
// Temporarily restore the real classifier flag so the tree-data params are
// written correctly; training forces it to false (regression trees only).
1044 data->is_classifier = !problem_type();
1045 data->write_params( fs );
1046 data->is_classifier = 0;
1050 //===========================================================================
// Reads the training parameters back from file storage: rebuilds the shared
// CvDTreeTrainData, mirrors its tree parameters into `params`, decodes the
// loss (string or raw int), and restores scalar GBT settings + class labels.
1052 void CvGBTrees::read_params( CvFileStorage* fs, CvFileNode* fnode )
1054 CV_FUNCNAME( "CvGBTrees::read_params" );
1060 if( !fnode || !CV_NODE_IS_MAP(fnode->tag) )
1063 data = new CvDTreeTrainData();
1064 CV_CALL( data->read_params(fs, fnode));
1065 data->shared = true;
// Copy the tree params the trees were built with into our param block.
1067 params.max_depth = data->params.max_depth;
1068 params.min_sample_count = data->params.min_sample_count;
1069 params.max_categories = data->params.max_categories;
1070 params.priors = data->params.priors;
1071 params.regression_accuracy = data->params.regression_accuracy;
1072 params.use_surrogates = data->params.use_surrogates;
1074 temp = cvGetFileNodeByName( fs, fnode, "loss_function" );
// Accept either the symbolic loss name or the raw integer form.
1078 if( temp && CV_NODE_IS_STRING(temp->tag) )
1080 const char* loss_function_type_str = cvReadString( temp, "" );
1081 params.loss_function_type = strcmp( loss_function_type_str, "SquaredLoss" ) == 0 ? SQUARED_LOSS :
1082 strcmp( loss_function_type_str, "AbsoluteLoss" ) == 0 ? ABSOLUTE_LOSS :
1083 strcmp( loss_function_type_str, "HuberLoss" ) == 0 ? HUBER_LOSS :
1084 strcmp( loss_function_type_str, "DevianceLoss" ) == 0 ? DEVIANCE_LOSS : -1;
1087 params.loss_function_type = cvReadInt( temp, -1 );
// Value 2 is a hole in the loss enum (HUBER_LOSS==3), hence the extra check.
1090 if( params.loss_function_type < SQUARED_LOSS || params.loss_function_type > DEVIANCE_LOSS || params.loss_function_type == 2)
1091 CV_ERROR( CV_StsBadArg, "Unknown loss function" );
1093 params.weak_count = cvReadIntByName( fs, fnode, "ensemble_length" );
1094 params.shrinkage = (float)cvReadRealByName( fs, fnode, "shrinkage", 0.1 );
1095 params.subsample_portion = (float)cvReadRealByName( fs, fnode, "subsample_portion", 1.0 );
1097 if (data->is_classifier)
1099 class_labels = (CvMat*)cvReadByName( fs, fnode, "class_labels" );
1100 if( class_labels && !CV_IS_MAT(class_labels))
1101 CV_ERROR( CV_StsParseError, "class_labels must stored as a matrix");
// As in training: the member trees themselves are regression trees.
1103 data->is_classifier = 0;
// Serializes the whole model: parameters, base_value, class_count, then one
// sequence node per class ("trees_<j>", name built in elided lines) holding
// every weak tree written via CvDTree::write.
1111 void CvGBTrees::write( CvFileStorage* fs, const char* name ) const
1113 CV_FUNCNAME( "CvGBTrees::write" );
1121 cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_GBT );
// An untrained model (no weak sequences) cannot be written.
1124 CV_ERROR( CV_StsBadArg, "The model has not been trained yet" );
1127 cvWriteReal( fs, "base_value", base_value);
1128 cvWriteInt( fs, "class_count", class_count);
1130 for ( int j=0; j < class_count; ++j )
1134 cvStartWriteStruct( fs, s.c_str(), CV_NODE_SEQ );
1136 cvStartReadSeq( weak[j], &reader );
1138 for( i = 0; i < weak[j]->total; i++ )
1141 CV_READ_SEQ_ELEM( tree, reader );
1142 cvStartWriteStruct( fs, 0, CV_NODE_MAP );
1144 cvEndWriteStruct( fs );
1147 cvEndWriteStruct( fs );
1150 cvEndWriteStruct( fs );
1156 //===========================================================================
// Deserializes a model previously produced by write(): reads params,
// base_value and class_count, then rebuilds each class's weak-tree sequence
// from its "trees_<j>" node, validating node type and tree count.
1159 void CvGBTrees::read( CvFileStorage* fs, CvFileNode* node )
1162 CV_FUNCNAME( "CvGBTrees::read" );
1167 CvFileNode* trees_fnode;
1168 CvMemStorage* storage;
1173 read_params( fs, node );
1178 base_value = (float)cvReadRealByName( fs, node, "base_value", 0.0 );
1179 class_count = cvReadIntByName( fs, node, "class_count", 1 );
1181 weak = new pCvSeq[class_count];
1184 for (int j=0; j<class_count; ++j)
// `s` holds the per-class node name (built in lines elided here).
1189 trees_fnode = cvGetFileNodeByName( fs, node, s.c_str() );
1190 if( !trees_fnode || !CV_NODE_IS_SEQ(trees_fnode->tag) )
1191 CV_ERROR( CV_StsParseError, "<trees_x> tag is missing" );
1193 cvStartReadSeq( trees_fnode->data.seq, &reader );
1194 ntrees = trees_fnode->data.seq->total;
1196 if( ntrees != params.weak_count )
1197 CV_ERROR( CV_StsUnmatchedSizes,
1198 "The number of trees stored does not match <ntrees> tag value" );
1200 CV_CALL( storage = cvCreateMemStorage() );
1201 weak[j] = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvDTree*), storage );
// Each tree is read against the shared train-data object restored above.
1203 for( i = 0; i < ntrees; i++ )
1205 CvDTree* tree = new CvDTree();
1206 CV_CALL(tree->read( fs, (CvFileNode*)reader.ptr, data ));
1207 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
1208 cvSeqPush( weak[j], &tree );
1215 //===========================================================================
// Parallel loop body for calc_error(): each worker predicts a range of
// samples via CvGBTrees::predict_serial and writes results into the shared
// `predictions` array. No locking is needed: each index i is written by
// exactly one worker.
1217 class Sample_predictor : public cv::ParallelLoopBody
1220 const CvGBTrees* gbt;
1222 const CvMat* samples;
1223 const CvMat* missing;
1228 Sample_predictor() : gbt(0), predictions(0), samples(0), missing(0),
1229 idx(0), slice(CV_WHOLE_SEQ)
1232 Sample_predictor(const CvGBTrees* _gbt, float* _predictions,
1233 const CvMat* _samples, const CvMat* _missing,
1234 const CvMat* _idx, CvSlice _slice=CV_WHOLE_SEQ) :
1235 gbt(_gbt), predictions(_predictions), samples(_samples),
1236 missing(_missing), idx(_idx), slice(_slice)
1240 Sample_predictor( const Sample_predictor& p, cv::Split ) :
1241 gbt(p.gbt), predictions(p.predictions),
1242 samples(p.samples), missing(p.missing), idx(p.idx),
1247 virtual void operator()(const cv::Range& range) const
1249 int begin = range.start;
1250 int end = range.end;
1255 for (int i=begin; i<end; ++i)
// idx (when supplied) maps loop position -> sample row.
1257 int j = idx ? idx->data.i[i] : i;
1258 cvGetRow(samples, &x, j);
// Predict with or without a missing-value row, mirroring the input.
1261 predictions[i] = gbt->predict_serial(&x,0,0,slice);
1265 cvGetRow(missing, &miss, j);
1266 predictions[i] = gbt->predict_serial(&x,&miss,0,slice);
1269 } // Sample_predictor::operator()
1271 virtual ~Sample_predictor() {}
1273 }; // class Sample_predictor
// Computes train- or test-set error: predicts all selected samples in
// parallel (Sample_predictor), then for classification returns the
// misclassification percentage, and for regression the mean of the
// per-sample deviations (the accumulation line for `d` is elided here —
// presumably err += d*d, i.e. MSE; verify against the full file).
1277 // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
1279 CvGBTrees::calc_error( CvMLData* _data, int type, std::vector<float> *resp )
1283 const CvMat* _sample_idx = (type == CV_TRAIN_ERROR) ?
1284 _data->get_train_sample_idx() :
1285 _data->get_test_sample_idx();
1286 const CvMat* response = _data->get_responses();
// No explicit index set for training error means "use every row".
1288 int n = _sample_idx ? get_len(_sample_idx) : 0;
1289 n = (type == CV_TRAIN_ERROR && n == 0) ? _data->get_values()->rows : n;
1294 float* pred_resp = 0;
1295 bool needsFreeing = false;
// Write predictions into the caller's vector when provided, else a
// temporary buffer that must be freed below.
1300 pred_resp = &((*resp)[0]);
1304 pred_resp = new float[n];
1305 needsFreeing = true;
1308 Sample_predictor predictor = Sample_predictor(this, pred_resp, _data->get_values(),
1309 _data->get_missing(), _sample_idx);
1311 cv::parallel_for_(cv::Range(0,n), predictor);
1313 int* sidx = _sample_idx ? _sample_idx->data.i : 0;
1314 int r_step = CV_IS_MAT_CONT(response->type) ?
1315 1 : response->step / CV_ELEM_SIZE(response->type);
// Classification: percentage of mismatched labels.
1318 if ( !problem_type() )
1320 for( int i = 0; i < n; i++ )
1322 int si = sidx ? sidx[i] : i;
1323 int d = fabs((double)pred_resp[i] - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;
1326 err = err / (float)n * 100.0f;
// Regression: mean accumulated deviation.
1330 for( int i = 0; i < n; i++ )
1332 int si = sidx ? sidx[i] : i;
1333 float d = pred_resp[i] - response->data.fl[si*r_step];
1336 err = err / (float)n;
// C++ (cv::Mat) convenience constructor: same null-initialization as the
// C-API constructor, then trains via the cv::Mat train() overload.
1346 CvGBTrees::CvGBTrees( const cv::Mat& trainData, int tflag,
1347 const cv::Mat& responses, const cv::Mat& varIdx,
1348 const cv::Mat& sampleIdx, const cv::Mat& varType,
1349 const cv::Mat& missingDataMask,
1350 CvGBTreesParams _params )
1354 default_model_name = "my_boost_tree";
1355 orig_response = sum_response = sum_response_tmp = 0;
1356 subsample_train = subsample_test = 0;
1357 missing = sample_idx = 0;
1364 train(trainData, tflag, responses, varIdx, sampleIdx, varType, missingDataMask, _params, false);
// cv::Mat adapter: wraps each Mat in a CvMat header (no data copy) and
// forwards to the C-API train(), passing null for any empty optional matrix.
1367 bool CvGBTrees::train( const cv::Mat& trainData, int tflag,
1368 const cv::Mat& responses, const cv::Mat& varIdx,
1369 const cv::Mat& sampleIdx, const cv::Mat& varType,
1370 const cv::Mat& missingDataMask,
1371 CvGBTreesParams _params,
1374 CvMat _trainData = trainData, _responses = responses;
1375 CvMat _varIdx = varIdx, _sampleIdx = sampleIdx, _varType = varType;
1376 CvMat _missingDataMask = missingDataMask;
1378 return train( &_trainData, tflag, &_responses, varIdx.empty() ? 0 : &_varIdx,
1379 sampleIdx.empty() ? 0 : &_sampleIdx, varType.empty() ? 0 : &_varType,
1380 missingDataMask.empty() ? 0 : &_missingDataMask, _params, update);
1383 float CvGBTrees::predict( const cv::Mat& sample, const cv::Mat& _missing,
1384 const cv::Range& slice, int k ) const
1386 CvMat _sample = sample, miss = _missing;
1387 return predict(&_sample, _missing.empty() ? 0 : &miss, 0,
1388 slice==cv::Range::all() ? CV_WHOLE_SEQ : cvSlice(slice.start, slice.end), k);