1 #include "opencv2/ml/ml.hpp"
2 #include "opencv2/core/core_c.h"
// Usage text for the sample: a short description, the tree class names it
// exercises, and the command-line syntax (-r <response_column>, -c, csv file).
// NOTE(review): the call these literals are passed to, and the enclosing
// function header, are not visible in this view.
9 "\nThis sample demonstrates how to use different decision trees and forests including boosting and random trees:\n"
13 "CvERTrees ertrees;\n"
14 "CvGBTrees gbtrees;\n"
15 "Call:\n\t./tree_engine [-r <response_column>] [-c] <csv filename>\n"
16 "where -r <response_column> specified the 0-based index of the response (0 by default)\n"
17 "-c specifies that the response is categorical (it's ordered by default) and\n"
18 "<csv filename> is the name of training data file in comma-separated value format\n\n");
// Walks over every response value stored in `data`, rounding each one to the
// nearest integer.
// NOTE(review): the remainder of the loop body and the return statement are
// not visible in this view. Presumably the rounded values are recorded in
// `rmap` and the number of distinct values is returned; confirm against the
// full file.
22 int count_classes(CvMLData& data)
// Build a cv::Mat from the responses so elements can be read with at<float>().
24 cv::Mat r(data.get_responses());
// Keyed by the rounded response value.
25 std::map<int, int> rmap;
// n is the total number of elements in the response matrix.
26 int i, n = (int)r.total();
27 for( i = 0; i < n; i++ )
// Elements are read as float; assumes the responses are stored as CV_32F
// (TODO confirm).
29 float val = r.at<float>(i);
30 int ival = cvRound(val);
// Prints the train and test error, followed by the variable-importance table
// ordered from most to least important.
//   train_err - error value printed on the "train error" line
//   test_err  - error value printed on the "test error" line
//   _var_imp  - variable importance values (CV_32F or CV_64F)
// NOTE(review): some callers in this file pass 0 for _var_imp (the boost and
// gbtrees calls). No null guard is visible in this view; confirm that one
// surrounds the importance section in the full file.
38 void print_result(float train_err, float test_err, const CvMat* _var_imp)
40 printf( "train error %f\n", train_err );
41 printf( "test error %f\n\n", test_err );
// sorted_idx receives the element indices of var_imp ordered by descending
// value (sorting is per row; assumes var_imp is a single row - TODO confirm).
45 cv::Mat var_imp(_var_imp), sorted_idx;
46 cv::sortIdx(var_imp, sorted_idx, CV_SORT_EVERY_ROW + CV_SORT_DESCENDING);
48 printf( "variable importance:\n" );
49 int i, n = (int)var_imp.total();
50 int type = var_imp.type();
// Only float and double importance vectors are handled by the printing loop.
51 CV_Assert(type == CV_32F || type == CV_64F);
53 for( i = 0; i < n; i++)
// k is the index of the i-th most important variable.
55 int k = sorted_idx.at<int>(i);
// Print the variable index and its importance, reading the element with the
// accessor that matches the matrix type.
56 printf( "%d\t%f\n", k, type == CV_32F ? var_imp.at<float>(k) : var_imp.at<double>(k));
// Entry point: parses the command line, reads the CSV training data, then
// trains several tree-based models on it and prints the train/test error of
// each (plus variable importance where the model provides it).
// NOTE(review): several lines of this function (braces, declarations of
// response_idx, data and the model objects) are not visible in this view.
62 int main(int argc, char** argv)
69 const char* filename = 0;
71 bool categorical_response = false;
// Walk the arguments: "-r <idx>" sets the response column, "-c" marks the
// response as categorical, and an argument not starting with '-' is handled
// by the third branch (its body is not visible here).
73 for(int i = 1; i < argc; i++)
75 if(strcmp(argv[i], "-r") == 0)
// NOTE(review): argv[++i] is read with no bounds check visible here; if "-r"
// is the last argument this passes argv[argc] (a null pointer) to sscanf.
// The sscanf return value is also ignored, so a non-numeric index goes
// unnoticed.
76 sscanf(argv[++i], "%d", &response_idx);
77 else if(strcmp(argv[i], "-c") == 0)
78 categorical_response = true;
79 else if(argv[i][0] != '-' )
83 printf("Error. Invalid option %s\n", argv[i]);
89 printf("\nReading in %s...\n\n",filename);
// NOTE(review): 0.5f is presumably the train portion of the train/test
// split; confirm against the CvTrainTestSplit constructor.
99 CvTrainTestSplit spl( 0.5f );
// A return value of 0 from read_csv is the condition under which the
// training steps below run.
101 if ( data.read_csv( filename ) == 0)
103 data.set_response_idx( response_idx );
// With -c, the response column is switched to a categorical variable.
104 if(categorical_response)
105 data.change_var_type( response_idx, CV_VAR_CATEGORICAL );
106 data.set_train_test_split( &spl );
// Single decision tree.
108 printf("======DTREE=====\n");
109 dtree.train( &data, CvDTreeParams( 10, 2, 0, false, 16, 0, false, false, 0 ));
110 print_result( dtree.calc_error( &data, CV_TRAIN_ERROR), dtree.calc_error( &data, CV_TEST_ERROR ), dtree.get_var_importance() );
// Boosting is only run for a categorical response for which count_classes
// reports exactly two classes.
112 if( categorical_response && count_classes(data) == 2 )
114 printf("======BOOST=====\n");
115 boost.train( &data, CvBoostParams(CvBoost::DISCRETE, 100, 0.95, 2, false, 0));
116 print_result( boost.calc_error( &data, CV_TRAIN_ERROR ), boost.calc_error( &data, CV_TEST_ERROR ), 0 ); //doesn't compute importance
// Random trees.
119 printf("======RTREES=====\n");
120 rtrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
121 print_result( rtrees.calc_error( &data, CV_TRAIN_ERROR), rtrees.calc_error( &data, CV_TEST_ERROR ), rtrees.get_var_importance() );
// Extremely randomized trees, trained with the same parameter values as the
// random trees above.
123 printf("======ERTREES=====\n");
124 ertrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
125 print_result( ertrees.calc_error( &data, CV_TRAIN_ERROR), ertrees.calc_error( &data, CV_TEST_ERROR ), ertrees.get_var_importance() );
// Gradient boosted trees.
127 printf("======GBTREES=====\n");
128 gbtrees.train( &data, CvGBTreesParams(CvGBTrees::DEVIANCE_LOSS, 100, 0.05f, 0.6f, 10, true));
129 print_result( gbtrees.calc_error( &data, CV_TRAIN_ERROR), gbtrees.calc_error( &data, CV_TEST_ERROR ), 0 ); //doesn't compute importance
// Presumably the failure branch of the read_csv check above (the else is not
// visible here). NOTE(review): the message has no trailing newline.
132 printf("File can not be read");