From bbdd0aecbd43d3690d244dbb7ff52ce7a736cdfd Mon Sep 17 00:00:00 2001 From: Vadim Pisarevsky Date: Tue, 5 Apr 2011 15:13:10 +0000 Subject: [PATCH] improved tree_engine.cpp sample (added train file data specification; print sorted variable importance table) --- samples/c/tree_engine.cpp | 105 +++++++++++++++++++++++++++++++--------------- 1 file changed, 72 insertions(+), 33 deletions(-) diff --git a/samples/c/tree_engine.cpp b/samples/c/tree_engine.cpp index e9a6446..3822840 100644 --- a/samples/c/tree_engine.cpp +++ b/samples/c/tree_engine.cpp @@ -1,5 +1,7 @@ #include "opencv2/ml/ml.hpp" +#include "opencv2/core/core_c.h" #include +#include void help() { @@ -10,41 +12,81 @@ void help() "CvRTrees rtrees;\n" "CvERTrees ertrees;\n" "CvGBTrees gbtrees;\n" - "Date is hard coded to come from filename = \"../../../opencv/samples/c/waveform.data\";\n" - "Or can come from filename = \"../../../opencv/samples/c/waveform.data\";\n" - "Call:\n" - "./tree_engine\n\n"); + "Call:\n\t./tree_engine [-r ] [-c] \n" + "where -r specified the 0-based index of the response (0 by default)\n" + "-c specifies that the response is categorical (it's ordered by default) and\n" + " is the name of training data file in comma-separated value format\n\n"); } -void print_result(float train_err, float test_err, const CvMat* var_imp) + + +int count_classes(CvMLData& data) +{ + cv::Mat r(data.get_responses()); + std::map rmap; + int i, n = (int)r.total(); + for( i = 0; i < n; i++ ) + { + float val = r.at(i); + int ival = cvRound(val); + if( ival != val ) + return -1; + rmap[ival] = 1; + } + return rmap.size(); +} + +void print_result(float train_err, float test_err, const CvMat* _var_imp) { printf( "train error %f\n", train_err ); printf( "test error %f\n\n", test_err ); - if (var_imp) + if (_var_imp) { - bool is_flt = false; - if ( CV_MAT_TYPE( var_imp->type ) == CV_32FC1) - is_flt = true; - printf( "variable impotance\n" ); - for( int i = 0; i < var_imp->cols; i++) + cv::Mat var_imp(_var_imp), sorted_idx; + cv::sortIdx(var_imp, sorted_idx, CV_SORT_EVERY_ROW + CV_SORT_DESCENDING); + + printf( "variable importance:\n" ); + int i, n = (int)var_imp.total(); + int type = var_imp.type(); + CV_Assert(type == CV_32F || type == CV_64F); + + for( i = 0; i < n; i++) { - printf( "%d %f\n", i, is_flt ? var_imp->data.fl[i] : var_imp->data.db[i] ); + int k = sorted_idx.at(i); + printf( "%d\t%f\n", k, type == CV_32F ? var_imp.at(k) : var_imp.at(k)); } } printf("\n"); } -int main() +int main(int argc, char** argv) { - const int train_sample_count = 300; - -#define LEPIOTA //Turn on discrete data set -#ifdef LEPIOTA //Of course, you might have to set the path here to what's on your machine ... - const char* filename = "../../opencv/samples/c/agaricus-lepiota.data"; -#else - const char* filename = "../../opencv/samples/c/waveform.data"; -#endif - printf("\n Reading in %s. If it is not found, you may have to change this hard-coded path in tree_engine.cpp\n\n",filename); + if(argc < 2) + { + help(); + return 0; + } + const char* filename = 0; + int response_idx = 0; + bool categorical_response = false; + + for(int i = 1; i < argc; i++) + { + if(strcmp(argv[i], "-r") == 0) + sscanf(argv[++i], "%d", &response_idx); + else if(strcmp(argv[i], "-c") == 0) + categorical_response = true; + else if(argv[i][0] != '-' ) + filename = argv[i]; + else + { + printf("Error. Invalid option %s\n", argv[i]); + help(); + return -1; + } + } + + printf("\nReading in %s...\n\n",filename); CvDTree dtree; CvBoost boost; CvRTrees rtrees; @@ -53,29 +95,26 @@ int main() CvMLData data; - CvTrainTestSplit spl( train_sample_count ); + + CvTrainTestSplit spl( 0.5f ); if ( data.read_csv( filename ) == 0) { - -#ifdef LEPIOTA - data.set_response_idx( 0 ); -#else - data.set_response_idx( 21 ); - data.change_var_type( 21, CV_VAR_CATEGORICAL ); -#endif - + data.set_response_idx( response_idx ); + if(categorical_response) + data.change_var_type( response_idx, CV_VAR_CATEGORICAL ); data.set_train_test_split( &spl ); printf("======DTREE=====\n"); dtree.train( &data, CvDTreeParams( 10, 2, 0, false, 16, 0, false, false, 0 )); print_result( dtree.calc_error( &data, CV_TRAIN_ERROR), dtree.calc_error( &data, CV_TEST_ERROR ), dtree.get_var_importance() ); -#ifdef LEPIOTA + if( categorical_response && count_classes(data) == 2 ) + { printf("======BOOST=====\n"); boost.train( &data, CvBoostParams(CvBoost::DISCRETE, 100, 0.95, 2, false, 0)); print_result( boost.calc_error( &data, CV_TRAIN_ERROR ), boost.calc_error( &data, CV_TEST_ERROR ), 0 ); //doesn't compute importance -#endif + } printf("======RTREES=====\n"); rtrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER )); -- 2.7.4