1 #include "opencv2/core/core.hpp"
2 #include "opencv2/ml/ml.hpp"
3 #include "opencv2/highgui/highgui.hpp"
// Colour of the status text that on_mouse draws with putText.
10 const Scalar WHITE_COLOR = CV_RGB(255,255,255);
// Title of the interactive window; on_mouse and the 'i' key redraw it via imshow(winName, img).
11 const string winName = "points";
// Spacing, in pixels, of the grid of test points that every find_decision_boundary_*
// function classifies and paints onto imgDst.
12 const int testStep = 5;
// Training data collected by on_mouse. trainedPoints and trainedPointsMarkers are
// parallel (pushed together): trainedPointsMarkers[i] is the class index of
// trainedPoints[i]. classColors[c] is the randomly chosen drawing colour of class c.
17 vector<Point> trainedPoints;
18 vector<int> trainedPointsMarkers;
19 vector<Scalar> classColors;
// Compile-time classifier switches (0/1); only _DT_ is set to 1 here.
// The #if lines that consume these flags are not visible in this extract.
// Presumably they guard the matching find_decision_boundary_* function and its call in main.
// NOTE(review): identifiers beginning with an underscore followed by an uppercase
// letter are reserved for the implementation in C++; consider e.g. USE_NBC.
21 #define _NBC_ 0 // normal Bayes classifier
22 #define _KNN_ 0 // k-nearest neighbors classifier
23 #define _SVM_ 0 // support vector machine
24 #define _DT_ 1 // decision tree
25 #define _BT_ 0 // boosted trees (discrete AdaBoost, CvBoost)
26 #define _GBT_ 0 // gradient boosted trees
27 #define _RF_ 0 // random forest
28 #define _ERT_ 0 // extremely randomized trees
29 #define _ANN_ 0 // artificial neural network (multilayer perceptron, CvANN_MLP)
30 #define _EM_ 0 // expectation-maximization (one Gaussian mixture per class)
// Mouse callback for the interactive window.
//   Left button release  -> append a training point at (x, y), labelled with the
//                           most recently added class.
//   Right button release -> append a new class with a random colour.
// The canvas img is then redrawn: status text, every training point as a filled
// circle in its class colour, and imshow(winName, img).
// NOTE(review): lossy extract. Braces, the bodies of the early checks, the `rng`
// and `text` declarations and any redraw condition are missing. The comments
// below describe only the visible lines.
32 static void on_mouse( int event, int x, int y, int /*flags*/, void* )
39 if( event == EVENT_LBUTTONUP )
// A point needs an existing class; the statement taken when classColors is
// empty is not visible (presumably an early return).
41 if( classColors.empty() )
44 trainedPoints.push_back( Point(x,y) );
45 trainedPointsMarkers.push_back( (int)(classColors.size()-1) );
48 else if( event == EVENT_RBUTTONUP )
// Caps the number of classes at two; the message below ties the cap to CvBoost.
// NOTE(review): its guard lines are not visible; presumably this cap is only
// compiled in when _BT_ is enabled. Confirm.
51 if( classColors.size() < 2 )
54 classColors.push_back( Scalar((uchar)rng(256), (uchar)rng(256), (uchar)rng(256)) );
59 cout << "New class can not be added, because CvBoost can only be used for 2-class classification" << endl;
// Status overlay: current class index, number of classes, number of points.
// NOTE(review): classColors.size()-1 is size_t arithmetic and wraps if no class
// exists yet; confirm the missing lines prevent reaching this point in that case.
71 text << "current class " << classColors.size()-1;
72 putText( img, text.str(), Point(10,25), CV_FONT_HERSHEY_SIMPLEX, 0.8f, WHITE_COLOR, 2 );
75 text << "total classes " << classColors.size();
76 putText( img, text.str(), Point(10,50), CV_FONT_HERSHEY_SIMPLEX, 0.8f, WHITE_COLOR, 2 );
79 text << "total points " << trainedPoints.size();
// NOTE(review): cvPoint (C API) here vs Point in the two calls above, and
// CV_FONT_HERSHEY_SIMPLEX vs the C++ FONT_HERSHEY_SIMPLEX; inconsistent but equivalent.
80 putText(img, text.str(), cvPoint(10,75), CV_FONT_HERSHEY_SIMPLEX, 0.8f, WHITE_COLOR, 2 );
// Thickness -1 draws filled circles.
83 for( size_t i = 0; i < trainedPoints.size(); i++ )
84 circle( img, trainedPoints[i], 5, classColors[trainedPointsMarkers[i]], -1 );
86 imshow( winName, img );
// Converts the clicked points into the matrices the ml classifiers are trained on.
//   samples - one row per point with two CV_32FC1 columns (x, y);
//   classes - the class index of each point, copied from trainedPointsMarkers.
90 static void prepare_train_data( Mat& samples, Mat& classes )
92 Mat( trainedPoints ).copyTo( samples );
93 Mat( trainedPointsMarkers ).copyTo( classes );
95 // reshape trainData and change its type
// reshape(1, rows) keeps the row count and turns each 2-channel Point element
// into two single-channel columns; convertTo then switches to float.
96 samples = samples.reshape( 1, samples.rows );
97 samples.convertTo( samples, CV_32FC1 );
// Trains a normal Bayes classifier on the clicked points and paints its decision
// regions onto imgDst (a copy of the canvas). It classifies a grid of points
// spaced testStep pixels apart.
101 static void find_decision_boundary_NBC()
103 img.copyTo( imgDst );
105 Mat trainSamples, trainClasses;
106 prepare_train_data( trainSamples, trainClasses );
// This constructor trains the model immediately.
109 CvNormalBayesClassifier normalBayesClassifier( trainSamples, trainClasses );
// Single (x, y) query row, refilled for every grid point.
111 Mat testSample( 1, 2, CV_32FC1 );
112 for( int y = 0; y < img.rows; y += testStep )
114 for( int x = 0; x < img.cols; x += testStep )
116 testSample.at<float>(0) = (float)x;
117 testSample.at<float>(1) = (float)y;
// predict() returns the class label as a float; it indexes classColors.
119 int response = (int)normalBayesClassifier.predict( testSample );
120 circle( imgDst, Point(x,y), 1, classColors[response] );
// Trains a k-nearest-neighbours classifier on the clicked points and paints its
// decision regions onto imgDst on a testStep-spaced grid.
// K is passed both to the constructor (as the model's max_k) and to find_nearest.
128 static void find_decision_boundary_KNN( int K )
130 img.copyTo( imgDst );
132 Mat trainSamples, trainClasses;
133 prepare_train_data( trainSamples, trainClasses );
// `false` selects classification rather than regression.
136 CvKNearest knnClassifier( trainSamples, trainClasses, Mat(), false, K );
138 Mat testSample( 1, 2, CV_32FC1 );
139 for( int y = 0; y < img.rows; y += testStep )
141 for( int x = 0; x < img.cols; x += testStep )
143 testSample.at<float>(0) = (float)x;
144 testSample.at<float>(1) = (float)y;
// find_nearest returns the voted class label as a float.
146 int response = (int)knnClassifier.find_nearest( testSample, K );
147 circle( imgDst, Point(x,y), 1, classColors[response] );
154 static void find_decision_boundary_SVM( CvSVMParams params )
156 img.copyTo( imgDst );
158 Mat trainSamples, trainClasses;
159 prepare_train_data( trainSamples, trainClasses );
162 CvSVM svmClassifier( trainSamples, trainClasses, Mat(), Mat(), params );
164 Mat testSample( 1, 2, CV_32FC1 );
165 for( int y = 0; y < img.rows; y += testStep )
167 for( int x = 0; x < img.cols; x += testStep )
169 testSample.at<float>(0) = (float)x;
170 testSample.at<float>(1) = (float)y;
172 int response = (int)svmClassifier.predict( testSample );
173 circle( imgDst, Point(x,y), 2, classColors[response], 1 );
178 for( int i = 0; i < svmClassifier.get_support_vector_count(); i++ )
180 const float* supportVector = svmClassifier.get_support_vector(i);
181 circle( imgDst, Point(supportVector[0],supportVector[1]), 5, CV_RGB(255,255,255), -1 );
// Trains a single decision tree on the clicked points and paints its decision
// regions onto imgDst on a testStep-spaced grid.
// NOTE(review): the declaration of `dtree` is not visible in this extract.
188 static void find_decision_boundary_DT()
190 img.copyTo( imgDst );
192 Mat trainSamples, trainClasses;
193 prepare_train_data( trainSamples, trainClasses );
// Variable types: the two features are ordered. The extra last entry (the
// response) is categorical, so the tree is trained as a classifier.
198 Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
199 var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;
// Depth at most 8; nodes with fewer than 2 samples are not split.
// No surrogate splits, and no cross-validation pruning (cv_folds = 0).
201 CvDTreeParams params;
202 params.max_depth = 8;
203 params.min_sample_count = 2;
204 params.use_surrogates = false;
205 params.cv_folds = 0; // the number of cross-validation folds
206 params.use_1se_rule = false;
207 params.truncate_pruned_tree = false;
209 dtree.train( trainSamples, CV_ROW_SAMPLE, trainClasses,
210 Mat(), Mat(), var_types, Mat(), params );
212 Mat testSample(1, 2, CV_32FC1 );
213 for( int y = 0; y < img.rows; y += testStep )
215 for( int x = 0; x < img.cols; x += testStep )
217 testSample.at<float>(0) = (float)x;
218 testSample.at<float>(1) = (float)y;
// predict() returns the reached leaf; its value is the predicted class label.
220 int response = (int)dtree.predict( testSample )->value;
221 circle( imgDst, Point(x,y), 2, classColors[response], 1 );
// Trains a boosted-trees ensemble (CvBoost, DISCRETE boosting) on the clicked
// points and paints its decision regions onto imgDst on a testStep-spaced grid.
// on_mouse caps the class count at two with a message naming CvBoost.
// NOTE(review): several CvBoostParams arguments and the declaration of `boost`
// are not visible in this extract.
228 void find_decision_boundary_BT()
230 img.copyTo( imgDst );
232 Mat trainSamples, trainClasses;
233 prepare_train_data( trainSamples, trainClasses );
// Ordered features, categorical response (classification).
238 Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
239 var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;
241 CvBoostParams params( CvBoost::DISCRETE, // boost_type
243 0.95, // weight_trim_rate
245 false, //use_surrogates
249 boost.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), var_types, Mat(), params );
251 Mat testSample(1, 2, CV_32FC1 );
252 for( int y = 0; y < img.rows; y += testStep )
254 for( int x = 0; x < img.cols; x += testStep )
256 testSample.at<float>(0) = (float)x;
257 testSample.at<float>(1) = (float)y;
259 int response = (int)boost.predict( testSample );
260 circle( imgDst, Point(x,y), 2, classColors[response], 1 );
// Trains gradient boosted trees (CvGBTrees, deviance loss) on the clicked points
// and paints the decision regions onto imgDst on a testStep-spaced grid.
// NOTE(review): several CvGBTreesParams arguments and the declaration of
// `gbtrees` are not visible. The `)` on the use_surrogates line is inside the
// comment, so the constructor call is closed on a line missing from this extract.
268 void find_decision_boundary_GBT()
270 img.copyTo( imgDst );
272 Mat trainSamples, trainClasses;
273 prepare_train_data( trainSamples, trainClasses );
// Ordered features, categorical response (classification).
278 Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
279 var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;
281 CvGBTreesParams params( CvGBTrees::DEVIANCE_LOSS, // loss_function_type
284 1.0f, // subsample_portion
286 false // use_surrogates )
289 gbtrees.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), var_types, Mat(), params );
291 Mat testSample(1, 2, CV_32FC1 );
292 for( int y = 0; y < img.rows; y += testStep )
294 for( int x = 0; x < img.cols; x += testStep )
296 testSample.at<float>(0) = (float)x;
297 testSample.at<float>(1) = (float)y;
299 int response = (int)gbtrees.predict( testSample );
300 circle( imgDst, Point(x,y), 2, classColors[response], 1 );
// Trains a random forest (CvRTrees) on the clicked points and paints its decision
// regions onto imgDst on a testStep-spaced grid.
// Settings: trees of depth at most 4; the forest stops after 5 trees
// (iteration-count termination only).
// NOTE(review): two CvRTParams arguments and the declaration of `rtrees` are not
// visible in this extract.
308 void find_decision_boundary_RF()
310 img.copyTo( imgDst );
312 Mat trainSamples, trainClasses;
313 prepare_train_data( trainSamples, trainClasses );
317 CvRTParams params( 4, // max_depth,
318 2, // min_sample_count,
319 0.f, // regression_accuracy,
320 false, // use_surrogates,
321 16, // max_categories,
323 false, // calc_var_importance,
325 5, // max_num_of_trees_in_the_forest,
326 0, // forest_accuracy,
327 CV_TERMCRIT_ITER // termcrit_type
// NOTE(review): unlike the DT/BT/GBT/ERT variants, the var_types argument here is
// an empty Mat(), so the response is not explicitly marked categorical. Confirm
// that CvRTrees treats the integer labels as classes rather than regression targets.
330 rtrees.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), Mat(), Mat(), params );
332 Mat testSample(1, 2, CV_32FC1 );
333 for( int y = 0; y < img.rows; y += testStep )
335 for( int x = 0; x < img.cols; x += testStep )
337 testSample.at<float>(0) = (float)x;
338 testSample.at<float>(1) = (float)y;
340 int response = (int)rtrees.predict( testSample );
341 circle( imgDst, Point(x,y), 2, classColors[response], 1 );
// Trains extremely randomized trees on the clicked points and paints their
// decision regions onto imgDst on a testStep-spaced grid.
// Uses the same visible CvRTParams as the random-forest variant, plus an
// explicit categorical response.
// NOTE(review): two CvRTParams arguments and the declaration of `ertrees` are not
// visible in this extract.
349 void find_decision_boundary_ERT()
351 img.copyTo( imgDst );
353 Mat trainSamples, trainClasses;
354 prepare_train_data( trainSamples, trainClasses );
// Ordered features, categorical response (classification).
359 Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
360 var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;
362 CvRTParams params( 4, // max_depth,
363 2, // min_sample_count,
364 0.f, // regression_accuracy,
365 false, // use_surrogates,
366 16, // max_categories,
368 false, // calc_var_importance,
370 5, // max_num_of_trees_in_the_forest,
371 0, // forest_accuracy,
372 CV_TERMCRIT_ITER // termcrit_type
375 ertrees.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), var_types, Mat(), params );
377 Mat testSample(1, 2, CV_32FC1 );
378 for( int y = 0; y < img.rows; y += testStep )
380 for( int x = 0; x < img.cols; x += testStep )
382 testSample.at<float>(0) = (float)x;
383 testSample.at<float>(1) = (float)y;
385 int response = (int)ertrees.predict( testSample );
386 circle( imgDst, Point(x,y), 2, classColors[response], 1 );
393 void find_decision_boundary_ANN( const Mat& layer_sizes )
395 img.copyTo( imgDst );
397 Mat trainSamples, trainClasses;
398 prepare_train_data( trainSamples, trainClasses );
400 // prerare trainClasses
401 trainClasses.create( trainedPoints.size(), classColors.size(), CV_32FC1 );
402 for( int i = 0; i < trainClasses.rows; i++ )
404 for( int k = 0; k < trainClasses.cols; k++ )
406 if( k == trainedPointsMarkers[i] )
407 trainClasses.at<float>(i,k) = 1;
409 trainClasses.at<float>(i,k) = 0;
413 Mat weights( 1, trainedPoints.size(), CV_32FC1, Scalar::all(1) );
416 CvANN_MLP ann( layer_sizes, CvANN_MLP::SIGMOID_SYM, 1, 1 );
417 ann.train( trainSamples, trainClasses, weights );
419 Mat testSample( 1, 2, CV_32FC1 );
420 for( int y = 0; y < img.rows; y += testStep )
422 for( int x = 0; x < img.cols; x += testStep )
424 testSample.at<float>(0) = (float)x;
425 testSample.at<float>(1) = (float)y;
427 Mat outputs( 1, classColors.size(), CV_32FC1, testSample.data );
428 ann.predict( testSample, outputs );
430 minMaxLoc( outputs, 0, 0, 0, &maxLoc );
431 circle( imgDst, Point(x,y), 2, classColors[maxLoc.x], 1 );
438 void find_decision_boundary_EM()
440 img.copyTo( imgDst );
442 Mat trainSamples, trainClasses;
443 prepare_train_data( trainSamples, trainClasses );
445 vector<cv::EM> em_models(classColors.size());
447 CV_Assert((int)trainClasses.total() == trainSamples.rows);
448 CV_Assert((int)trainClasses.type() == CV_32SC1);
450 for(size_t modelIndex = 0; modelIndex < em_models.size(); modelIndex++)
452 const int componentCount = 3;
453 em_models[modelIndex] = EM(componentCount, cv::EM::COV_MAT_DIAGONAL);
456 for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
458 if(trainClasses.at<int>(sampleIndex) == (int)modelIndex)
459 modelSamples.push_back(trainSamples.row(sampleIndex));
463 if(!modelSamples.empty())
464 em_models[modelIndex].train(modelSamples);
467 // classify coordinate plane points using the bayes classifier, i.e.
468 // y(x) = arg max_i=1_modelsCount likelihoods_i(x)
469 Mat testSample(1, 2, CV_32FC1 );
470 for( int y = 0; y < img.rows; y += testStep )
472 for( int x = 0; x < img.cols; x += testStep )
474 testSample.at<float>(0) = (float)x;
475 testSample.at<float>(1) = (float)y;
477 Mat logLikelihoods(1, em_models.size(), CV_64FC1, Scalar(-DBL_MAX));
478 for(size_t modelIndex = 0; modelIndex < em_models.size(); modelIndex++)
480 if(em_models[modelIndex].isTrained())
481 logLikelihoods.at<double>(modelIndex) = em_models[modelIndex].predict(testSample)[0];
484 minMaxLoc(logLikelihoods, 0, 0, 0, &maxLoc);
486 int response = maxLoc.x;
487 circle( imgDst, Point(x,y), 2, classColors[response], 1 );
// Fragment of main().
//   1. Prints the usage text.
//   2. Creates the 640x480 3-channel canvas (img) and the result image (imgDst).
//   3. Installs on_mouse on the interactive window.
//   4. Runs the key loop: Esc quits, 'i' clears the training data and canvas,
//      'r' runs the classifiers and shows each result in its own window.
// NOTE(review): lossy extract. The function header, loop statement, braces, the
// #if guards, the declaration and values of K, and most SVM parameters are not visible.
495 cout << "Use:" << endl
496 << " right mouse button - to add new class;" << endl
497 << " left mouse button - to add new point;" << endl
498 << " key 'r' - to run the ML model;" << endl
499 << " key 'i' - to init (clear) the data." << endl << endl;
// NOTE(review): the literal "points" duplicates winName, which on_mouse and the
// 'i' handler use; passing winName here would keep the names in sync.
501 cv::namedWindow( "points", 1 );
502 img.create( 480, 640, CV_8UC3 );
503 imgDst.create( 480, 640, CV_8UC3 );
505 imshow( "points", img );
506 setMouseCallback( "points", on_mouse );
510 uchar key = (uchar)waitKey();
// 27 is the Esc key.
512 if( key == 27 ) break;
514 if( key == 'i' ) // init
516 img = Scalar::all(0);
// NOTE(review): classColors is not cleared in the visible lines; confirm whether
// a missing line resets it.
519 trainedPoints.clear();
520 trainedPointsMarkers.clear();
522 imshow( winName, img );
525 if( key == 'r' ) // run
528 find_decision_boundary_NBC();
// NOTE(review): cvNamedWindow is the C API; the other windows use namedWindow.
529 cvNamedWindow( "NormalBayesClassifier", WINDOW_AUTOSIZE );
530 imshow( "NormalBayesClassifier", imgDst );
534 find_decision_boundary_KNN( K );
535 namedWindow( "kNN", WINDOW_AUTOSIZE );
536 imshow( "kNN", imgDst );
// Same call as above; K is presumably changed on a line not visible here.
539 find_decision_boundary_KNN( K );
540 namedWindow( "kNN2", WINDOW_AUTOSIZE );
541 imshow( "kNN2", imgDst );
545 //(1)-(2)separable and not sets
// C-SVC with a polynomial kernel; training stops after 1000 iterations.
547 params.svm_type = CvSVM::C_SVC;
548 params.kernel_type = CvSVM::POLY; //CvSVM::LINEAR;
555 params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 1000, 0.01);
557 find_decision_boundary_SVM( params );
558 namedWindow( "classificationSVM1", WINDOW_AUTOSIZE );
559 imshow( "classificationSVM1", imgDst );
// Same call as above; params are presumably changed on lines not visible here.
562 find_decision_boundary_SVM( params );
// NOTE(review): cvNamedWindow (C API) again; namedWindow elsewhere.
563 cvNamedWindow( "classificationSVM2", WINDOW_AUTOSIZE );
564 imshow( "classificationSVM2", imgDst );
568 find_decision_boundary_DT();
569 namedWindow( "DT", WINDOW_AUTOSIZE );
570 imshow( "DT", imgDst );
574 find_decision_boundary_BT();
575 namedWindow( "BT", WINDOW_AUTOSIZE );
576 imshow( "BT", imgDst);
580 find_decision_boundary_GBT();
581 namedWindow( "GBT", WINDOW_AUTOSIZE );
582 imshow( "GBT", imgDst);
586 find_decision_boundary_RF();
587 namedWindow( "RF", WINDOW_AUTOSIZE );
588 imshow( "RF", imgDst);
592 find_decision_boundary_ERT();
593 namedWindow( "ERT", WINDOW_AUTOSIZE );
594 imshow( "ERT", imgDst);
// MLP topology: 2 inputs (x, y), 5 hidden neurons, one output per class.
598 Mat layer_sizes1( 1, 3, CV_32SC1 );
599 layer_sizes1.at<int>(0) = 2;
600 layer_sizes1.at<int>(1) = 5;
// NOTE(review): implicit size_t -> int narrowing; an explicit (int) cast would
// silence conversion warnings.
601 layer_sizes1.at<int>(2) = classColors.size();
602 find_decision_boundary_ANN( layer_sizes1 );
603 namedWindow( "ANN", WINDOW_AUTOSIZE );
604 imshow( "ANN", imgDst );
608 find_decision_boundary_EM();
609 namedWindow( "EM", WINDOW_AUTOSIZE );
610 imshow( "EM", imgDst );