removed contrib, legacy and softcascade modules; removed latentsvm and datamatrix...
[platform/upstream/opencv.git] / samples / cpp / tree_engine.cpp
1 #include "opencv2/ml/ml.hpp"
2 #include "opencv2/core/core_c.h"
3 #include "opencv2/core/utility.hpp"
4 #include <stdio.h>
5 #include <map>
6
// Prints the usage banner for this sample: the list of tree-based models it
// exercises and the supported command-line options.
static void help()
{
    // The whole banner is kept in one constant so it is emitted with a single call.
    static const char* const usage_text =
        "\nThis sample demonstrates how to use different decision trees and forests including boosting and random trees:\n"
        "CvDTree dtree;\n"
        "CvBoost boost;\n"
        "CvRTrees rtrees;\n"
        "CvERTrees ertrees;\n"
        "CvGBTrees gbtrees;\n"
        "Call:\n\t./tree_engine [-r <response_column>] [-c] <csv filename>\n"
        "where -r <response_column> specified the 0-based index of the response (0 by default)\n"
        "-c specifies that the response is categorical (it's ordered by default) and\n"
        "<csv filename> is the name of training data file in comma-separated value format\n\n";

    printf("%s", usage_text);
}
21
22
23 static int count_classes(CvMLData& data)
24 {
25     cv::Mat r = cv::cvarrToMat(data.get_responses());
26     std::map<int, int> rmap;
27     int i, n = (int)r.total();
28     for( i = 0; i < n; i++ )
29     {
30         float val = r.at<float>(i);
31         int ival = cvRound(val);
32         if( ival != val )
33             return -1;
34         rmap[ival] = 1;
35     }
36     return (int)rmap.size();
37 }
38
39 static void print_result(float train_err, float test_err, const CvMat* _var_imp)
40 {
41     printf( "train error    %f\n", train_err );
42     printf( "test error    %f\n\n", test_err );
43
44     if (_var_imp)
45     {
46         cv::Mat var_imp = cv::cvarrToMat(_var_imp), sorted_idx;
47         cv::sortIdx(var_imp, sorted_idx, CV_SORT_EVERY_ROW + CV_SORT_DESCENDING);
48
49         printf( "variable importance:\n" );
50         int i, n = (int)var_imp.total();
51         int type = var_imp.type();
52         CV_Assert(type == CV_32F || type == CV_64F);
53
54         for( i = 0; i < n; i++)
55         {
56             int k = sorted_idx.at<int>(i);
57             printf( "%d\t%f\n", k, type == CV_32F ? var_imp.at<float>(k) : var_imp.at<double>(k));
58         }
59     }
60     printf("\n");
61 }
62
63 int main(int argc, char** argv)
64 {
65     if(argc < 2)
66     {
67         help();
68         return 0;
69     }
70     const char* filename = 0;
71     int response_idx = 0;
72     bool categorical_response = false;
73
74     for(int i = 1; i < argc; i++)
75     {
76         if(strcmp(argv[i], "-r") == 0)
77             sscanf(argv[++i], "%d", &response_idx);
78         else if(strcmp(argv[i], "-c") == 0)
79             categorical_response = true;
80         else if(argv[i][0] != '-' )
81             filename = argv[i];
82         else
83         {
84             printf("Error. Invalid option %s\n", argv[i]);
85             help();
86             return -1;
87         }
88     }
89
90     printf("\nReading in %s...\n\n",filename);
91     CvDTree dtree;
92     CvBoost boost;
93     CvRTrees rtrees;
94     CvERTrees ertrees;
95     CvGBTrees gbtrees;
96
97     CvMLData data;
98
99
100     CvTrainTestSplit spl( 0.5f );
101
102     if ( data.read_csv( filename ) == 0)
103     {
104         data.set_response_idx( response_idx );
105         if(categorical_response)
106             data.change_var_type( response_idx, CV_VAR_CATEGORICAL );
107         data.set_train_test_split( &spl );
108
109         printf("======DTREE=====\n");
110         dtree.train( &data, CvDTreeParams( 10, 2, 0, false, 16, 0, false, false, 0 ));
111         print_result( dtree.calc_error( &data, CV_TRAIN_ERROR), dtree.calc_error( &data, CV_TEST_ERROR ), dtree.get_var_importance() );
112
113         if( categorical_response && count_classes(data) == 2 )
114         {
115         printf("======BOOST=====\n");
116         boost.train( &data, CvBoostParams(CvBoost::DISCRETE, 100, 0.95, 2, false, 0));
117         print_result( boost.calc_error( &data, CV_TRAIN_ERROR ), boost.calc_error( &data, CV_TEST_ERROR ), 0 ); //doesn't compute importance
118         }
119
120         printf("======RTREES=====\n");
121         rtrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
122         print_result( rtrees.calc_error( &data, CV_TRAIN_ERROR), rtrees.calc_error( &data, CV_TEST_ERROR ), rtrees.get_var_importance() );
123
124         printf("======ERTREES=====\n");
125         ertrees.train( &data, CvRTParams( 18, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
126         print_result( ertrees.calc_error( &data, CV_TRAIN_ERROR), ertrees.calc_error( &data, CV_TEST_ERROR ), ertrees.get_var_importance() );
127
128         printf("======GBTREES=====\n");
129         if (categorical_response)
130             gbtrees.train( &data, CvGBTreesParams(CvGBTrees::DEVIANCE_LOSS, 100, 0.1f, 0.8f, 5, false));
131         else
132             gbtrees.train( &data, CvGBTreesParams(CvGBTrees::SQUARED_LOSS, 100, 0.1f, 0.8f, 5, false));
133         print_result( gbtrees.calc_error( &data, CV_TRAIN_ERROR), gbtrees.calc_error( &data, CV_TEST_ERROR ), 0 ); //doesn't compute importance
134     }
135     else
136         printf("File can not be read");
137
138     return 0;
139 }