added new ML models to points_classifier sample
authorMaria Dimashova <no@email>
Sat, 30 Apr 2011 18:04:33 +0000 (18:04 +0000)
committerMaria Dimashova <no@email>
Sat, 30 Apr 2011 18:04:33 +0000 (18:04 +0000)
samples/c/tree_engine.cpp
samples/cpp/points_classifier.cpp

index 3822840..4f41884 100644 (file)
@@ -124,9 +124,9 @@ int main(int argc, char** argv)
         ertrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
         print_result( ertrees.calc_error( &data, CV_TRAIN_ERROR), ertrees.calc_error( &data, CV_TEST_ERROR ), ertrees.get_var_importance() );
 
-               printf("======GBTREES=====\n");
-               gbtrees.train( &data, CvGBTreesParams(CvGBTrees::DEVIANCE_LOSS, 100, 0.05f, 0.6f, 10, true));
-               print_result( gbtrees.calc_error( &data, CV_TRAIN_ERROR), gbtrees.calc_error( &data, CV_TEST_ERROR ), 0 ); //doesn't compute importance
+        printf("======GBTREES=====\n");
+        gbtrees.train( &data, CvGBTreesParams(CvGBTrees::DEVIANCE_LOSS, 100, 0.05f, 0.6f, 10, true));
+        print_result( gbtrees.calc_error( &data, CV_TRAIN_ERROR), gbtrees.calc_error( &data, CV_TEST_ERROR ), 0 ); //doesn't compute importance
     }
     else
         printf("File can not be read");
index c665608..66aec92 100644 (file)
@@ -12,19 +12,23 @@ const string winName = "points";
 const int testStep = 5;
 
 
-Mat img, img_dst;
+Mat img, imgDst;
 RNG rng;
 
 vector<Point>  trainedPoints;
 vector<int>    trainedPointsMarkers;
 vector<Scalar> classColors;
 
-#define KNN 0
-#define SVM 0
-#define DT  1
-#define RF  0
-#define ANN 0
-#define GMM 0
+#define NBC 0 // normal Bayessian classifier
+#define KNN 0 // k nearest neighbors classifier
+#define SVM 0 // support vectors machine
+#define DT  1 // decision tree
+#define BT  0 // ADA Boost
+#define GBT 1 // gradient boosted trees
+#define RF  0 // random forest
+#define ERT 0 // extremely randomized trees
+#define ANN 0 // artificial neural networks
+#define EM  0 // expectation-maximization
 
 void on_mouse( int event, int x, int y, int /*flags*/, void* )
 {
@@ -44,8 +48,18 @@ void on_mouse( int event, int x, int y, int /*flags*/, void* )
     }
     else if( event == CV_EVENT_RBUTTONUP )
     {
-        classColors.push_back( Scalar((uchar)rng(256), (uchar)rng(256), (uchar)rng(256)) );
-        updateFlag = true;
+#if BT
+        if( classColors.size() < 2 )
+        {
+#endif
+            classColors.push_back( Scalar((uchar)rng(256), (uchar)rng(256), (uchar)rng(256)) );
+            updateFlag = true;
+#if BT
+        }
+        else
+            cout << "New class can not be added, because CvBoost can only be used for 2-class classification" << endl;
+#endif
+
     }
 
     //draw
@@ -84,10 +98,37 @@ void prepare_train_data( Mat& samples, Mat& classes )
     samples.convertTo( samples, CV_32FC1 );
 }
 
+#if NBC
+void find_decision_boundary_NBC()
+{
+    img.copyTo( imgDst );
+
+    Mat trainSamples, trainClasses;
+    prepare_train_data( trainSamples, trainClasses );
+
+    // learn classifier
+    CvNormalBayesClassifier normalBayesClassifier( trainSamples, trainClasses );
+
+    Mat testSample( 1, 2, CV_32FC1 );
+    for( int y = 0; y < img.rows; y += testStep )
+    {
+        for( int x = 0; x < img.cols; x += testStep )
+        {
+            testSample.at<float>(0) = (float)x;
+            testSample.at<float>(1) = (float)y;
+
+            int response = (int)normalBayesClassifier.predict( testSample );
+            circle( imgDst, Point(x,y), 1, classColors[response] );
+        }
+    }
+}
+#endif
+
+
 #if KNN
 void find_decision_boundary_KNN( int K )
 {
-    img.copyTo( img_dst );
+    img.copyTo( imgDst );
 
     Mat trainSamples, trainClasses;
     prepare_train_data( trainSamples, trainClasses );
@@ -104,7 +145,7 @@ void find_decision_boundary_KNN( int K )
             testSample.at<float>(1) = (float)y;
 
             int response = (int)knnClassifier.find_nearest( testSample, K );
-            circle( img_dst, Point(x,y), 1, classColors[response] );
+            circle( imgDst, Point(x,y), 1, classColors[response] );
         }
     }
 }
@@ -113,7 +154,7 @@ void find_decision_boundary_KNN( int K )
 #if SVM
 void find_decision_boundary_SVM( CvSVMParams params )
 {
-    img.copyTo( img_dst );
+    img.copyTo( imgDst );
 
     Mat trainSamples, trainClasses;
     prepare_train_data( trainSamples, trainClasses );
@@ -130,7 +171,7 @@ void find_decision_boundary_SVM( CvSVMParams params )
             testSample.at<float>(1) = (float)y;
 
             int response = (int)svmClassifier.predict( testSample );
-            circle( img_dst, Point(x,y), 2, classColors[response], 1 );
+            circle( imgDst, Point(x,y), 2, classColors[response], 1 );
         }
     }
 
@@ -138,7 +179,7 @@ void find_decision_boundary_SVM( CvSVMParams params )
     for( int i = 0; i < svmClassifier.get_support_vector_count(); i++ )
     {
         const float* supportVector = svmClassifier.get_support_vector(i);
-        circle( img_dst, Point(supportVector[0],supportVector[1]), 5, CV_RGB(255,255,255), -1 );
+        circle( imgDst, Point(supportVector[0],supportVector[1]), 5, CV_RGB(255,255,255), -1 );
     }
 
 }
@@ -147,7 +188,7 @@ void find_decision_boundary_SVM( CvSVMParams params )
 #if DT
 void find_decision_boundary_DT()
 {
-    img.copyTo( img_dst );
+    img.copyTo( imgDst );
 
     Mat trainSamples, trainClasses;
     prepare_train_data( trainSamples, trainClasses );
@@ -178,16 +219,96 @@ void find_decision_boundary_DT()
             testSample.at<float>(1) = (float)y;
 
             int response = (int)dtree.predict( testSample )->value;
-            circle( img_dst, Point(x,y), 2, classColors[response], 1 );
+            circle( imgDst, Point(x,y), 2, classColors[response], 1 );
+        }
+    }
+}
+#endif
+
+#if BT
+void find_decision_boundary_BT()
+{
+    img.copyTo( imgDst );
+
+    Mat trainSamples, trainClasses;
+    prepare_train_data( trainSamples, trainClasses );
+
+    // learn classifier
+    CvBoost  boost;
+
+    Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
+    var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;
+
+    CvBoostParams  params( CvBoost::DISCRETE, // boost_type
+                           100, // weak_count
+                           0.95, // weight_trim_rate
+                           2, // max_depth
+                           false, //use_surrogates
+                           0 // priors
+                         );
+
+    boost.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), var_types, Mat(), params );
+
+    Mat testSample(1, 2, CV_32FC1 );
+    for( int y = 0; y < img.rows; y += testStep )
+    {
+        for( int x = 0; x < img.cols; x += testStep )
+        {
+            testSample.at<float>(0) = (float)x;
+            testSample.at<float>(1) = (float)y;
+
+            int response = (int)boost.predict( testSample );
+            circle( imgDst, Point(x,y), 2, classColors[response], 1 );
+        }
+    }
+}
+
+#endif
+
+#if GBT
+void find_decision_boundary_GBT()
+{
+    img.copyTo( imgDst );
+
+    Mat trainSamples, trainClasses;
+    prepare_train_data( trainSamples, trainClasses );
+
+    // learn classifier
+    CvGBTrees gbtrees;
+
+    Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
+    var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;
+
+    CvGBTreesParams  params( CvGBTrees::SQUARED_LOSS, // loss_function_type
+                             100, // weak_count
+                             0.05f, // shrinkage
+                             0.6f, // subsample_portion
+                             2, // max_depth
+                             true // use_surrogates )
+                         );
+
+    gbtrees.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), var_types, Mat(), params );
+
+    Mat testSample(1, 2, CV_32FC1 );
+    for( int y = 0; y < img.rows; y += testStep )
+    {
+        for( int x = 0; x < img.cols; x += testStep )
+        {
+            testSample.at<float>(0) = (float)x;
+            testSample.at<float>(1) = (float)y;
+
+            int response = (int)gbtrees.predict( testSample );
+            circle( imgDst, Point(x,y), 2, classColors[response], 1 );
         }
     }
 }
+
 #endif
 
 #if RF
 void find_decision_boundary_RF()
 {
-    img.copyTo( img_dst );
+    img.copyTo( imgDst );
 
     Mat trainSamples, trainClasses;
     prepare_train_data( trainSamples, trainClasses );
@@ -222,17 +343,61 @@ void find_decision_boundary_RF()
             testSample.at<float>(1) = (float)y;
 
             int response = (int)rtrees.predict( testSample );
-            circle( img_dst, Point(x,y), 2, classColors[response], 1 );
+            circle( imgDst, Point(x,y), 2, classColors[response], 1 );
         }
     }
 }
 
 #endif
 
+#if ERT
+void find_decision_boundary_ERT()
+{
+    img.copyTo( imgDst );
+
+    Mat trainSamples, trainClasses;
+    prepare_train_data( trainSamples, trainClasses );
+
+    // learn classifier
+    CvERTrees ertrees;
+
+    Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
+    var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;
+
+    CvRTParams  params( 4, // max_depth,
+                        2, // min_sample_count,
+                        0.f, // regression_accuracy,
+                        false, // use_surrogates,
+                        16, // max_categories,
+                        0, // priors,
+                        false, // calc_var_importance,
+                        1, // nactive_vars,
+                        5, // max_num_of_trees_in_the_forest,
+                        0, // forest_accuracy,
+                        CV_TERMCRIT_ITER // termcrit_type
+                       );
+
+    ertrees.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), var_types, Mat(), params );
+
+    Mat testSample(1, 2, CV_32FC1 );
+    for( int y = 0; y < img.rows; y += testStep )
+    {
+        for( int x = 0; x < img.cols; x += testStep )
+        {
+            testSample.at<float>(0) = (float)x;
+            testSample.at<float>(1) = (float)y;
+
+            int response = (int)ertrees.predict( testSample );
+            circle( imgDst, Point(x,y), 2, classColors[response], 1 );
+        }
+    }
+}
+#endif
+
 #if ANN
 void find_decision_boundary_ANN( const Mat&  layer_sizes )
 {
-    img.copyTo( img_dst );
+    img.copyTo( imgDst );
 
     Mat trainSamples, trainClasses;
     prepare_train_data( trainSamples, trainClasses );
@@ -268,16 +433,16 @@ void find_decision_boundary_ANN( const Mat&  layer_sizes )
             ann.predict( testSample, outputs );
             Point maxLoc;
             minMaxLoc( outputs, 0, 0, 0, &maxLoc );
-            circle( img_dst, Point(x,y), 2, classColors[maxLoc.x], 1 );
+            circle( imgDst, Point(x,y), 2, classColors[maxLoc.x], 1 );
         }
     }
 }
 #endif
 
-#if GMM
-void find_decision_boundary_GMM()
+#if EM
+void find_decision_boundary_EM()
 {
-    img.copyTo( img_dst );
+    img.copyTo( imgDst );
 
     Mat trainSamples, trainClasses;
     prepare_train_data( trainSamples, trainClasses );
@@ -308,7 +473,7 @@ void find_decision_boundary_GMM()
             testSample.at<float>(1) = (float)y;
 
             int response = (int)em.predict( testSample );
-            circle( img_dst, Point(x,y), 2, classColors[response], 1 );
+            circle( imgDst, Point(x,y), 2, classColors[response], 1 );
         }
     }
 }
@@ -316,9 +481,15 @@ void find_decision_boundary_GMM()
 
 int main()
 {
+    cout << "Use:" << endl
+         << "  right mouse button - to add new class;" << endl
+         << "  left mouse button - to add new point;" << endl
+         << "  key 'r' - to run the ML model;" << endl
+         << "  key 'i' - to init (clear) the data." << endl << endl;
+
     cv::namedWindow( "points", 1 );
     img.create( 480, 640, CV_8UC3 );
-    img_dst.create( 480, 640, CV_8UC3 );
+    imgDst.create( 480, 640, CV_8UC3 );
 
     imshow( "points", img );
     cvSetMouseCallback( "points", on_mouse );
@@ -342,16 +513,21 @@ int main()
 
         if( key == 'r' ) // run
         {
+#if NBC
+            find_decision_boundary_NBC();
+            cvNamedWindow( "NormalBayesClassifier", WINDOW_AUTOSIZE );
+            imshow( "NormalBayesClassifier", imgDst );
+#endif
 #if KNN
             int K = 3;
             find_decision_boundary_KNN( K );
             namedWindow( "kNN", WINDOW_AUTOSIZE );
-            imshow( "kNN", img_dst );
+            imshow( "kNN", imgDst );
 
             K = 15;
             find_decision_boundary_KNN( K );
             namedWindow( "kNN2", WINDOW_AUTOSIZE );
-            imshow( "kNN2", img_dst );
+            imshow( "kNN2", imgDst );
 #endif
 
 #if SVM
@@ -369,24 +545,42 @@ int main()
 
             find_decision_boundary_SVM( params );
             namedWindow( "classificationSVM1", WINDOW_AUTOSIZE );
-            imshow( "classificationSVM1", img_dst );
+            imshow( "classificationSVM1", imgDst );
 
             params.C = 10;
             find_decision_boundary_SVM( params );
             cvNamedWindow( "classificationSVM2", WINDOW_AUTOSIZE );
-            imshow( "classificationSVM2", img_dst );
+            imshow( "classificationSVM2", imgDst );
 #endif
 
 #if DT
             find_decision_boundary_DT();
             namedWindow( "DT", 1 );
-            imshow( "DT", img_dst );
+            imshow( "DT", imgDst );
+#endif
+
+#if BT
+            find_decision_boundary_BT();
+            namedWindow( "BT", 1 );
+            imshow( "BT", imgDst);
+#endif
+
+#if GBT
+            find_decision_boundary_GBT();
+            namedWindow( "GBT", 1 );
+            imshow( "GBT", imgDst);
 #endif
 
 #if RF
             find_decision_boundary_RF();
             namedWindow( "RF", 1 );
-            imshow( "RF", img_dst);
+            imshow( "RF", imgDst);
+#endif
+
+#if ERT
+            find_decision_boundary_ERT();
+            namedWindow( "ERT", 1 );
+            imshow( "ERT", imgDst);
 #endif
 
 #if ANN
@@ -396,13 +590,13 @@ int main()
             layer_sizes1.at<int>(2) = classColors.size();
             find_decision_boundary_ANN( layer_sizes1 );
             namedWindow( "ANN", WINDOW_AUTOSIZE );
-            imshow( "ANN", img_dst );
+            imshow( "ANN", imgDst );
 #endif
 
-#if GMM
-            find_decision_boundary_GMM();
-            namedWindow( "GMM", WINDOW_AUTOSIZE );
-            imshow( "GMM", img_dst );
+#if EM
+            find_decision_boundary_EM();
+            namedWindow( "EM", WINDOW_AUTOSIZE );
+            imshow( "EM", imgDst );
 #endif
         }
     }