improved tree_engine.cpp sample (added train file data specification; print sorted...
authorVadim Pisarevsky <no@email>
Tue, 5 Apr 2011 15:13:10 +0000 (15:13 +0000)
committerVadim Pisarevsky <no@email>
Tue, 5 Apr 2011 15:13:10 +0000 (15:13 +0000)
samples/c/tree_engine.cpp

index e9a6446..3822840 100644 (file)
@@ -1,5 +1,7 @@
 #include "opencv2/ml/ml.hpp"
+#include "opencv2/core/core_c.h"
 #include <stdio.h>
+#include <map>
 
 void help()
 {
@@ -10,41 +12,81 @@ void help()
                "CvRTrees rtrees;\n"
                "CvERTrees ertrees;\n"
                "CvGBTrees gbtrees;\n"
-               "Date is hard coded to come from filename = \"../../../opencv/samples/c/waveform.data\";\n"
-               "Or can come from filename = \"../../../opencv/samples/c/waveform.data\";\n"
-               "Call:\n"
-               "./tree_engine\n\n");
+               "Call:\n\t./tree_engine [-r <response_column>] [-c] <csv filename>\n"
+        "where -r <response_column> specified the 0-based index of the response (0 by default)\n"
+        "-c specifies that the response is categorical (it's ordered by default) and\n"
+        "<csv filename> is the name of training data file in comma-separated value format\n\n");
 }
-void print_result(float train_err, float test_err, const CvMat* var_imp)
+
+
+int count_classes(CvMLData& data)
+{
+    cv::Mat r(data.get_responses());
+    std::map<int, int> rmap;
+    int i, n = (int)r.total();
+    for( i = 0; i < n; i++ )
+    {
+        float val = r.at<float>(i);
+        int ival = cvRound(val);
+        if( ival != val )
+            return -1;
+        rmap[ival] = 1; 
+    }
+    return rmap.size();
+}
+
+void print_result(float train_err, float test_err, const CvMat* _var_imp)
 {
     printf( "train error    %f\n", train_err );
     printf( "test error    %f\n\n", test_err );
        
-    if (var_imp)
+    if (_var_imp)
     {
-        bool is_flt = false;
-        if ( CV_MAT_TYPE( var_imp->type ) == CV_32FC1)
-            is_flt = true;
-        printf( "variable impotance\n" );
-        for( int i = 0; i < var_imp->cols; i++)
+        cv::Mat var_imp(_var_imp), sorted_idx;
+        cv::sortIdx(var_imp, sorted_idx, CV_SORT_EVERY_ROW + CV_SORT_DESCENDING);
+        
+        printf( "variable importance:\n" );
+        int i, n = (int)var_imp.total();
+        int type = var_imp.type();
+        CV_Assert(type == CV_32F || type == CV_64F);
+        
+        for( i = 0; i < n; i++)
         {
-            printf( "%d     %f\n", i, is_flt ? var_imp->data.fl[i] : var_imp->data.db[i] );
+            int k = sorted_idx.at<int>(i);
+            printf( "%d\t%f\n", k, type == CV_32F ? var_imp.at<float>(k) : var_imp.at<double>(k));
         }
     }
     printf("\n");
 }
 
-int main()
+int main(int argc, char** argv)
 {
-    const int train_sample_count = 300;
-
-#define LEPIOTA //Turn on discrete data set
-#ifdef LEPIOTA   //Of course, you might have to set the path here to what's on your machine ...
-    const char* filename = "../../opencv/samples/c/agaricus-lepiota.data";
-#else
-    const char* filename = "../../opencv/samples/c/waveform.data";
-#endif
-    printf("\n Reading in %s. If it is not found, you may have to change this hard-coded path in tree_engine.cpp\n\n",filename);
+    if(argc < 2)
+    {
+        help();
+        return 0;
+    }
+    const char* filename = 0;
+    int response_idx = 0;
+    bool categorical_response = false;
+    
+    for(int i = 1; i < argc; i++)
+    {
+        if(strcmp(argv[i], "-r") == 0)
+            sscanf(argv[++i], "%d", &response_idx);
+        else if(strcmp(argv[i], "-c") == 0)
+            categorical_response = true;
+        else if(argv[i][0] != '-' )
+            filename = argv[i];
+        else
+        {
+            printf("Error. Invalid option %s\n", argv[i]);
+            help();
+            return -1;
+        }
+    }
+        
+    printf("\nReading in %s...\n\n",filename);
     CvDTree dtree;
     CvBoost boost;
     CvRTrees rtrees;
@@ -53,29 +95,26 @@ int main()
 
     CvMLData data;
 
-    CvTrainTestSplit spl( train_sample_count );
+    
+    CvTrainTestSplit spl( 0.5f );
     
     if ( data.read_csv( filename ) == 0)
     {
-
-#ifdef LEPIOTA
-        data.set_response_idx( 0 );     
-#else
-        data.set_response_idx( 21 );     
-        data.change_var_type( 21, CV_VAR_CATEGORICAL );
-#endif
-
+        data.set_response_idx( response_idx );
+        if(categorical_response)
+            data.change_var_type( response_idx, CV_VAR_CATEGORICAL );
         data.set_train_test_split( &spl );
         
         printf("======DTREE=====\n");
         dtree.train( &data, CvDTreeParams( 10, 2, 0, false, 16, 0, false, false, 0 ));
         print_result( dtree.calc_error( &data, CV_TRAIN_ERROR), dtree.calc_error( &data, CV_TEST_ERROR ), dtree.get_var_importance() );
 
-#ifdef LEPIOTA
+        if( categorical_response && count_classes(data) == 2 )
+        {
         printf("======BOOST=====\n");
         boost.train( &data, CvBoostParams(CvBoost::DISCRETE, 100, 0.95, 2, false, 0));
         print_result( boost.calc_error( &data, CV_TRAIN_ERROR ), boost.calc_error( &data, CV_TEST_ERROR ), 0 ); //doesn't compute importance
-#endif
+        }
 
         printf("======RTREES=====\n");
         rtrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));