update help according single standart for it
[profile/ivi/opencv.git] / samples / c / tree_engine.cpp
1 #include "opencv2/ml/ml.hpp"
2 #include "opencv2/core/core_c.h"
3 #include <stdio.h>
4 #include <map>
5
6 void help()
7 {
8         printf(
9                 "\nThis sample demonstrates how to use different decision trees and forests including boosting and random trees:\n"
10                 "CvDTree dtree;\n"
11                 "CvBoost boost;\n"
12                 "CvRTrees rtrees;\n"
13                 "CvERTrees ertrees;\n"
14                 "CvGBTrees gbtrees;\n"
15                 "Call:\n\t./tree_engine [-r <response_column>] [-c] <csv filename>\n"
16         "where -r <response_column> specified the 0-based index of the response (0 by default)\n"
17         "-c specifies that the response is categorical (it's ordered by default) and\n"
18         "<csv filename> is the name of training data file in comma-separated value format\n\n");
19 }
20
21
22 int count_classes(CvMLData& data)
23 {
24     cv::Mat r(data.get_responses());
25     std::map<int, int> rmap;
26     int i, n = (int)r.total();
27     for( i = 0; i < n; i++ )
28     {
29         float val = r.at<float>(i);
30         int ival = cvRound(val);
31         if( ival != val )
32             return -1;
33         rmap[ival] = 1; 
34     }
35     return rmap.size();
36 }
37
38 void print_result(float train_err, float test_err, const CvMat* _var_imp)
39 {
40     printf( "train error    %f\n", train_err );
41     printf( "test error    %f\n\n", test_err );
42        
43     if (_var_imp)
44     {
45         cv::Mat var_imp(_var_imp), sorted_idx;
46         cv::sortIdx(var_imp, sorted_idx, CV_SORT_EVERY_ROW + CV_SORT_DESCENDING);
47         
48         printf( "variable importance:\n" );
49         int i, n = (int)var_imp.total();
50         int type = var_imp.type();
51         CV_Assert(type == CV_32F || type == CV_64F);
52         
53         for( i = 0; i < n; i++)
54         {
55             int k = sorted_idx.at<int>(i);
56             printf( "%d\t%f\n", k, type == CV_32F ? var_imp.at<float>(k) : var_imp.at<double>(k));
57         }
58     }
59     printf("\n");
60 }
61
62 int main(int argc, char** argv)
63 {
64     if(argc < 2)
65     {
66         help();
67         return 0;
68     }
69     const char* filename = 0;
70     int response_idx = 0;
71     bool categorical_response = false;
72     
73     for(int i = 1; i < argc; i++)
74     {
75         if(strcmp(argv[i], "-r") == 0)
76             sscanf(argv[++i], "%d", &response_idx);
77         else if(strcmp(argv[i], "-c") == 0)
78             categorical_response = true;
79         else if(argv[i][0] != '-' )
80             filename = argv[i];
81         else
82         {
83             printf("Error. Invalid option %s\n", argv[i]);
84             help();
85             return -1;
86         }
87     }
88         
89     printf("\nReading in %s...\n\n",filename);
90     CvDTree dtree;
91     CvBoost boost;
92     CvRTrees rtrees;
93     CvERTrees ertrees;
94         CvGBTrees gbtrees;
95
96     CvMLData data;
97
98     
99     CvTrainTestSplit spl( 0.5f );
100     
101     if ( data.read_csv( filename ) == 0)
102     {
103         data.set_response_idx( response_idx );
104         if(categorical_response)
105             data.change_var_type( response_idx, CV_VAR_CATEGORICAL );
106         data.set_train_test_split( &spl );
107         
108         printf("======DTREE=====\n");
109         dtree.train( &data, CvDTreeParams( 10, 2, 0, false, 16, 0, false, false, 0 ));
110         print_result( dtree.calc_error( &data, CV_TRAIN_ERROR), dtree.calc_error( &data, CV_TEST_ERROR ), dtree.get_var_importance() );
111
112         if( categorical_response && count_classes(data) == 2 )
113         {
114         printf("======BOOST=====\n");
115         boost.train( &data, CvBoostParams(CvBoost::DISCRETE, 100, 0.95, 2, false, 0));
116         print_result( boost.calc_error( &data, CV_TRAIN_ERROR ), boost.calc_error( &data, CV_TEST_ERROR ), 0 ); //doesn't compute importance
117         }
118
119         printf("======RTREES=====\n");
120         rtrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
121         print_result( rtrees.calc_error( &data, CV_TRAIN_ERROR), rtrees.calc_error( &data, CV_TEST_ERROR ), rtrees.get_var_importance() );
122
123         printf("======ERTREES=====\n");
124         ertrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
125         print_result( ertrees.calc_error( &data, CV_TRAIN_ERROR), ertrees.calc_error( &data, CV_TEST_ERROR ), ertrees.get_var_importance() );
126
127         printf("======GBTREES=====\n");
128         gbtrees.train( &data, CvGBTreesParams(CvGBTrees::DEVIANCE_LOSS, 100, 0.05f, 0.6f, 10, true));
129         print_result( gbtrees.calc_error( &data, CV_TRAIN_ERROR), gbtrees.calc_error( &data, CV_TEST_ERROR ), 0 ); //doesn't compute importance
130     }
131     else
132         printf("File can not be read");
133
134     return 0;
135 }