converted some more samples to C++
[profile/ivi/opencv.git] / samples / c / tree_engine.cpp
1 #include "opencv2/ml/ml.hpp"
2 #include <stdio.h>
3 /*
4 The sample demonstrates how to use different decision trees.
5 */
6 void print_result(float train_err, float test_err, const CvMat* var_imp)
7 {
8     printf( "train error    %f\n", train_err );
9     printf( "test error    %f\n\n", test_err );
10        
11     if (var_imp)
12     {
13         bool is_flt = false;
14         if ( CV_MAT_TYPE( var_imp->type ) == CV_32FC1)
15             is_flt = true;
16         printf( "variable impotance\n" );
17         for( int i = 0; i < var_imp->cols; i++)
18         {
19             printf( "%d     %f\n", i, is_flt ? var_imp->data.fl[i] : var_imp->data.db[i] );
20         }
21     }
22     printf("\n");
23 }
24
25 int main()
26 {
27     const int train_sample_count = 300;
28
29 //#define LEPIOTA
30 #ifdef LEPIOTA
31     const char* filename = "../../../OpenCV/samples/c/agaricus-lepiota.data";
32 #else
33     const char* filename = "../../../OpenCV/samples/c/waveform.data";
34 #endif
35
36     CvDTree dtree;
37     CvBoost boost;
38     CvRTrees rtrees;
39     CvERTrees ertrees;
40         CvGBTrees gbtrees;
41
42     CvMLData data;
43
44     CvTrainTestSplit spl( train_sample_count );
45     
46     if ( data.read_csv( filename ) == 0)
47     {
48
49 #ifdef LEPIOTA
50         data.set_response_idx( 0 );     
51 #else
52         data.set_response_idx( 21 );     
53         data.change_var_type( 21, CV_VAR_CATEGORICAL );
54 #endif
55
56         data.set_train_test_split( &spl );
57         
58         printf("======DTREE=====\n");
59         dtree.train( &data, CvDTreeParams( 10, 2, 0, false, 16, 0, false, false, 0 ));
60         print_result( dtree.calc_error( &data, CV_TRAIN_ERROR), dtree.calc_error( &data, CV_TEST_ERROR ), dtree.get_var_importance() );
61
62 #ifdef LEPIOTA
63         printf("======BOOST=====\n");
64         boost.train( &data, CvBoostParams(CvBoost::DISCRETE, 100, 0.95, 2, false, 0));
65         print_result( boost.calc_error( &data, CV_TRAIN_ERROR ), boost.calc_error( &data ), 0 );
66 #endif
67
68         printf("======RTREES=====\n");
69         rtrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
70         print_result( rtrees.calc_error( &data, CV_TRAIN_ERROR), rtrees.calc_error( &data, CV_TEST_ERROR ), rtrees.get_var_importance() );
71
72         printf("======ERTREES=====\n");
73         ertrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
74         print_result( ertrees.calc_error( &data, CV_TRAIN_ERROR), ertrees.calc_error( &data, CV_TEST_ERROR ), ertrees.get_var_importance() );
75
76                 printf("======GBTREES=====\n");
77                 gbtrees.train( &data, CvGBTreesParams(CvGBTrees::DEVIANCE_LOSS, 100, 0.05f, 0.6f, 10, true));
78                 print_result( gbtrees.calc_error( &data, CV_TRAIN_ERROR), gbtrees.calc_error( &data, CV_TEST_ERROR ), 0 );
79     }
80     else
81         printf("File can not be read");
82
83     return 0;
84 }