added new sample on ML models
authorMaria Dimashova <no@email>
Sat, 30 Apr 2011 16:44:34 +0000 (16:44 +0000)
committerMaria Dimashova <no@email>
Sat, 30 Apr 2011 16:44:34 +0000 (16:44 +0000)
samples/cpp/points_classifier.cpp [new file with mode: 0644]

diff --git a/samples/cpp/points_classifier.cpp b/samples/cpp/points_classifier.cpp
new file mode 100644 (file)
index 0000000..c665608
--- /dev/null
@@ -0,0 +1,411 @@
+#include "opencv2/core/core.hpp"
+#include "opencv2/ml/ml.hpp"
+#include "opencv2/highgui/highgui.hpp"
+
+#include <stdio.h>
+
+using namespace std;
+using namespace cv;
+
+const Scalar WHITE_COLOR = CV_RGB(255,255,255);
+const string winName = "points";
+const int testStep = 5;
+
+
+Mat img, img_dst;
+RNG rng;
+
+vector<Point>  trainedPoints;
+vector<int>    trainedPointsMarkers;
+vector<Scalar> classColors;
+
+#define KNN 0
+#define SVM 0
+#define DT  1
+#define RF  0
+#define ANN 0
+#define GMM 0
+
+void on_mouse( int event, int x, int y, int /*flags*/, void* )
+{
+    if( img.empty() )
+        return;
+
+    int updateFlag = 0;
+
+    if( event == CV_EVENT_LBUTTONUP )
+    {
+        if( classColors.empty() )
+            return;
+
+        trainedPoints.push_back( Point(x,y) );
+        trainedPointsMarkers.push_back( classColors.size()-1 );
+        updateFlag = true;
+    }
+    else if( event == CV_EVENT_RBUTTONUP )
+    {
+        classColors.push_back( Scalar((uchar)rng(256), (uchar)rng(256), (uchar)rng(256)) );
+        updateFlag = true;
+    }
+
+    //draw
+    if( updateFlag )
+    {
+        img = Scalar::all(0);
+
+        // put the text
+        stringstream text;
+        text << "current class " << classColors.size()-1;
+        putText( img, text.str(), Point(10,25), CV_FONT_HERSHEY_SIMPLEX, 0.8f, WHITE_COLOR, 2 );
+
+        text.str("");
+        text << "total classes " << classColors.size();
+        putText( img, text.str(), Point(10,50), CV_FONT_HERSHEY_SIMPLEX, 0.8f, WHITE_COLOR, 2 );
+
+        text.str("");
+        text << "total points " << trainedPoints.size();
+        putText(img, text.str(), cvPoint(10,75), CV_FONT_HERSHEY_SIMPLEX, 0.8f, WHITE_COLOR, 2 );
+
+        // draw points
+        for( size_t i = 0; i < trainedPoints.size(); i++ )
+            circle( img, trainedPoints[i], 5, classColors[trainedPointsMarkers[i]], -1 );
+
+        imshow( winName, img );
+   }
+}
+
+void prepare_train_data( Mat& samples, Mat& classes )
+{
+    Mat( trainedPoints ).copyTo( samples );
+    Mat( trainedPointsMarkers ).copyTo( classes );
+
+    // reshape trainData and change its type
+    samples = samples.reshape( 1, samples.rows );
+    samples.convertTo( samples, CV_32FC1 );
+}
+
+#if KNN
+void find_decision_boundary_KNN( int K )
+{
+    img.copyTo( img_dst );
+
+    Mat trainSamples, trainClasses;
+    prepare_train_data( trainSamples, trainClasses );
+
+    // learn classifier
+    CvKNearest knnClassifier( trainSamples, trainClasses, Mat(), false, K );
+
+    Mat testSample( 1, 2, CV_32FC1 );
+    for( int y = 0; y < img.rows; y += testStep )
+    {
+        for( int x = 0; x < img.cols; x += testStep )
+        {
+            testSample.at<float>(0) = (float)x;
+            testSample.at<float>(1) = (float)y;
+
+            int response = (int)knnClassifier.find_nearest( testSample, K );
+            circle( img_dst, Point(x,y), 1, classColors[response] );
+        }
+    }
+}
+#endif
+
+#if SVM
+void find_decision_boundary_SVM( CvSVMParams params )
+{
+    img.copyTo( img_dst );
+
+    Mat trainSamples, trainClasses;
+    prepare_train_data( trainSamples, trainClasses );
+
+    // learn classifier
+    CvSVM svmClassifier( trainSamples, trainClasses, Mat(), Mat(), params );
+
+    Mat testSample( 1, 2, CV_32FC1 );
+    for( int y = 0; y < img.rows; y += testStep )
+    {
+        for( int x = 0; x < img.cols; x += testStep )
+        {
+            testSample.at<float>(0) = (float)x;
+            testSample.at<float>(1) = (float)y;
+
+            int response = (int)svmClassifier.predict( testSample );
+            circle( img_dst, Point(x,y), 2, classColors[response], 1 );
+        }
+    }
+
+
+    for( int i = 0; i < svmClassifier.get_support_vector_count(); i++ )
+    {
+        const float* supportVector = svmClassifier.get_support_vector(i);
+        circle( img_dst, Point(supportVector[0],supportVector[1]), 5, CV_RGB(255,255,255), -1 );
+    }
+
+}
+#endif
+
+#if DT
+void find_decision_boundary_DT()
+{
+    img.copyTo( img_dst );
+
+    Mat trainSamples, trainClasses;
+    prepare_train_data( trainSamples, trainClasses );
+
+    // learn classifier
+    CvDTree  dtree;
+
+    Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
+    var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;
+
+    CvDTreeParams params;
+    params.max_depth = 8;
+    params.min_sample_count = 2;
+    params.use_surrogates = false;
+    params.cv_folds = 0; // the number of cross-validation folds
+    params.use_1se_rule = false;
+    params.truncate_pruned_tree = false;
+
+    dtree.train( trainSamples, CV_ROW_SAMPLE, trainClasses,
+                 Mat(), Mat(), var_types, Mat(), params );
+
+    Mat testSample(1, 2, CV_32FC1 );
+    for( int y = 0; y < img.rows; y += testStep )
+    {
+        for( int x = 0; x < img.cols; x += testStep )
+        {
+            testSample.at<float>(0) = (float)x;
+            testSample.at<float>(1) = (float)y;
+
+            int response = (int)dtree.predict( testSample )->value;
+            circle( img_dst, Point(x,y), 2, classColors[response], 1 );
+        }
+    }
+}
+#endif
+
+#if RF
+void find_decision_boundary_RF()
+{
+    img.copyTo( img_dst );
+
+    Mat trainSamples, trainClasses;
+    prepare_train_data( trainSamples, trainClasses );
+
+    // learn classifier
+    CvRTrees  rtrees;
+
+    Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
+    var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;
+
+    CvRTParams  params( 4, // max_depth,
+                        2, // min_sample_count,
+                        0.f, // regression_accuracy,
+                        false, // use_surrogates,
+                        16, // max_categories,
+                        0, // priors,
+                        false, // calc_var_importance,
+                        1, // nactive_vars,
+                        5, // max_num_of_trees_in_the_forest,
+                        0, // forest_accuracy,
+                        CV_TERMCRIT_ITER // termcrit_type
+                       );
+
+    rtrees.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), var_types, Mat(), params );
+
+    Mat testSample(1, 2, CV_32FC1 );
+    for( int y = 0; y < img.rows; y += testStep )
+    {
+        for( int x = 0; x < img.cols; x += testStep )
+        {
+            testSample.at<float>(0) = (float)x;
+            testSample.at<float>(1) = (float)y;
+
+            int response = (int)rtrees.predict( testSample );
+            circle( img_dst, Point(x,y), 2, classColors[response], 1 );
+        }
+    }
+}
+
+#endif
+
+#if ANN
+void find_decision_boundary_ANN( const Mat&  layer_sizes )
+{
+    img.copyTo( img_dst );
+
+    Mat trainSamples, trainClasses;
+    prepare_train_data( trainSamples, trainClasses );
+
+    // prerare trainClasses
+    trainClasses.create( trainedPoints.size(), classColors.size(), CV_32FC1 );
+    for( int i = 0; i <  trainClasses.rows; i++ )
+    {
+        for( int k = 0; k < trainClasses.cols; k++ )
+        {
+            if( k == trainedPointsMarkers[i] )
+                trainClasses.at<float>(i,k) = 1;
+            else
+                trainClasses.at<float>(i,k) = 0;
+        }
+    }
+
+    Mat weights( 1, trainedPoints.size(), CV_32FC1, Scalar::all(1) );
+
+    // learn classifier
+    CvANN_MLP  ann( layer_sizes, CvANN_MLP::SIGMOID_SYM, 1, 1 );
+    ann.train( trainSamples, trainClasses, weights );
+
+    Mat testSample( 1, 2, CV_32FC1 );
+    for( int y = 0; y < img.rows; y += testStep )
+    {
+        for( int x = 0; x < img.cols; x += testStep )
+        {
+            testSample.at<float>(0) = (float)x;
+            testSample.at<float>(1) = (float)y;
+
+            Mat outputs( 1, classColors.size(), CV_32FC1, testSample.data );
+            ann.predict( testSample, outputs );
+            Point maxLoc;
+            minMaxLoc( outputs, 0, 0, 0, &maxLoc );
+            circle( img_dst, Point(x,y), 2, classColors[maxLoc.x], 1 );
+        }
+    }
+}
+#endif
+
+#if GMM
+void find_decision_boundary_GMM()
+{
+    img.copyTo( img_dst );
+
+    Mat trainSamples, trainClasses;
+    prepare_train_data( trainSamples, trainClasses );
+
+    CvEM em;
+    CvEMParams params;
+    params.covs      = NULL;
+    params.means     = NULL;
+    params.weights   = NULL;
+    params.probs     = NULL;
+    params.nclusters = classColors.size();
+    params.cov_mat_type       = CvEM::COV_MAT_GENERIC;
+    params.start_step         = CvEM::START_AUTO_STEP;
+    params.term_crit.max_iter = 10;
+    params.term_crit.epsilon  = 0.1;
+    params.term_crit.type     = CV_TERMCRIT_ITER | CV_TERMCRIT_EPS;
+
+
+    // learn classifier
+    em.train( trainSamples, Mat(), params, &trainClasses );
+
+    Mat testSample(1, 2, CV_32FC1 );
+    for( int y = 0; y < img.rows; y += testStep )
+    {
+        for( int x = 0; x < img.cols; x += testStep )
+        {
+            testSample.at<float>(0) = (float)x;
+            testSample.at<float>(1) = (float)y;
+
+            int response = (int)em.predict( testSample );
+            circle( img_dst, Point(x,y), 2, classColors[response], 1 );
+        }
+    }
+}
+#endif
+
+int main()
+{
+    cv::namedWindow( "points", 1 );
+    img.create( 480, 640, CV_8UC3 );
+    img_dst.create( 480, 640, CV_8UC3 );
+
+    imshow( "points", img );
+    cvSetMouseCallback( "points", on_mouse );
+
+    for(;;)
+    {
+        uchar key = waitKey();
+
+        if( key == 27 ) break;
+
+        if( key == 'i' ) // init
+        {
+            img = Scalar::all(0);
+
+            classColors.clear();
+            trainedPoints.clear();
+            trainedPointsMarkers.clear();
+
+            imshow( winName, img );
+        }
+
+        if( key == 'r' ) // run
+        {
+#if KNN
+            int K = 3;
+            find_decision_boundary_KNN( K );
+            namedWindow( "kNN", WINDOW_AUTOSIZE );
+            imshow( "kNN", img_dst );
+
+            K = 15;
+            find_decision_boundary_KNN( K );
+            namedWindow( "kNN2", WINDOW_AUTOSIZE );
+            imshow( "kNN2", img_dst );
+#endif
+
+#if SVM
+            //(1)-(2)separable and not sets
+            CvSVMParams params;
+            params.svm_type = CvSVM::C_SVC;
+            params.kernel_type = CvSVM::POLY; //CvSVM::LINEAR;
+            params.degree = 0.5;
+            params.gamma = 1;
+            params.coef0 = 1;
+            params.C = 1;
+            params.nu = 0.5;
+            params.p = 0;
+            params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 1000, 0.01);
+
+            find_decision_boundary_SVM( params );
+            namedWindow( "classificationSVM1", WINDOW_AUTOSIZE );
+            imshow( "classificationSVM1", img_dst );
+
+            params.C = 10;
+            find_decision_boundary_SVM( params );
+            cvNamedWindow( "classificationSVM2", WINDOW_AUTOSIZE );
+            imshow( "classificationSVM2", img_dst );
+#endif
+
+#if DT
+            find_decision_boundary_DT();
+            namedWindow( "DT", 1 );
+            imshow( "DT", img_dst );
+#endif
+
+#if RF
+            find_decision_boundary_RF();
+            namedWindow( "RF", 1 );
+            imshow( "RF", img_dst);
+#endif
+
+#if ANN
+            Mat layer_sizes1( 1, 3, CV_32SC1 );
+            layer_sizes1.at<int>(0) = 2;
+            layer_sizes1.at<int>(1) = 5;
+            layer_sizes1.at<int>(2) = classColors.size();
+            find_decision_boundary_ANN( layer_sizes1 );
+            namedWindow( "ANN", WINDOW_AUTOSIZE );
+            imshow( "ANN", img_dst );
+#endif
+
+#if GMM
+            find_decision_boundary_GMM();
+            namedWindow( "GMM", WINDOW_AUTOSIZE );
+            imshow( "GMM", img_dst );
+#endif
+        }
+    }
+
+    return 1;
+}