Merge remote-tracking branch 'origin/2.4' into merge-2.4
[profile/ivi/opencv.git] / samples / cpp / points_classifier.cpp
1 #include "opencv2/core/core.hpp"
2 #include "opencv2/ml/ml.hpp"
3 #include "opencv2/highgui/highgui.hpp"
4
5 #include <stdio.h>
6
7 using namespace std;
8 using namespace cv;
9
10 const Scalar WHITE_COLOR = Scalar(255,255,255);
11 const string winName = "points";
12 const int testStep = 5;
13
14 Mat img, imgDst;
15 RNG rng;
16
17 vector<Point>  trainedPoints;
18 vector<int>    trainedPointsMarkers;
19 vector<Scalar> classColors;
20
21 #define _NBC_ 0 // normal Bayessian classifier
22 #define _KNN_ 0 // k nearest neighbors classifier
23 #define _SVM_ 0 // support vectors machine
24 #define _DT_  1 // decision tree
25 #define _BT_  0 // ADA Boost
26 #define _GBT_ 0 // gradient boosted trees
27 #define _RF_  0 // random forest
28 #define _ERT_ 0 // extremely randomized trees
29 #define _ANN_ 0 // artificial neural networks
30 #define _EM_  0 // expectation-maximization
31
32 static void on_mouse( int event, int x, int y, int /*flags*/, void* )
33 {
34     if( img.empty() )
35         return;
36
37     int updateFlag = 0;
38
39     if( event == EVENT_LBUTTONUP )
40     {
41         if( classColors.empty() )
42             return;
43
44         trainedPoints.push_back( Point(x,y) );
45         trainedPointsMarkers.push_back( (int)(classColors.size()-1) );
46         updateFlag = true;
47     }
48     else if( event == EVENT_RBUTTONUP )
49     {
50 #if _BT_
51         if( classColors.size() < 2 )
52         {
53 #endif
54             classColors.push_back( Scalar((uchar)rng(256), (uchar)rng(256), (uchar)rng(256)) );
55             updateFlag = true;
56 #if _BT_
57         }
58         else
59             cout << "New class can not be added, because CvBoost can only be used for 2-class classification" << endl;
60 #endif
61
62     }
63
64     //draw
65     if( updateFlag )
66     {
67         img = Scalar::all(0);
68
69         // put the text
70         stringstream text;
71         text << "current class " << classColors.size()-1;
72         putText( img, text.str(), Point(10,25), FONT_HERSHEY_SIMPLEX, 0.8f, WHITE_COLOR, 2 );
73
74         text.str("");
75         text << "total classes " << classColors.size();
76         putText( img, text.str(), Point(10,50), FONT_HERSHEY_SIMPLEX, 0.8f, WHITE_COLOR, 2 );
77
78         text.str("");
79         text << "total points " << trainedPoints.size();
80         putText(img, text.str(), Point(10,75), FONT_HERSHEY_SIMPLEX, 0.8f, WHITE_COLOR, 2 );
81
82         // draw points
83         for( size_t i = 0; i < trainedPoints.size(); i++ )
84             circle( img, trainedPoints[i], 5, classColors[trainedPointsMarkers[i]], -1 );
85
86         imshow( winName, img );
87    }
88 }
89
90 static void prepare_train_data( Mat& samples, Mat& classes )
91 {
92     Mat( trainedPoints ).copyTo( samples );
93     Mat( trainedPointsMarkers ).copyTo( classes );
94
95     // reshape trainData and change its type
96     samples = samples.reshape( 1, samples.rows );
97     samples.convertTo( samples, CV_32FC1 );
98 }
99
100 #if _NBC_
101 static void find_decision_boundary_NBC()
102 {
103     img.copyTo( imgDst );
104
105     Mat trainSamples, trainClasses;
106     prepare_train_data( trainSamples, trainClasses );
107
108     // learn classifier
109     CvNormalBayesClassifier normalBayesClassifier( trainSamples, trainClasses );
110
111     Mat testSample( 1, 2, CV_32FC1 );
112     for( int y = 0; y < img.rows; y += testStep )
113     {
114         for( int x = 0; x < img.cols; x += testStep )
115         {
116             testSample.at<float>(0) = (float)x;
117             testSample.at<float>(1) = (float)y;
118
119             int response = (int)normalBayesClassifier.predict( testSample );
120             circle( imgDst, Point(x,y), 1, classColors[response] );
121         }
122     }
123 }
124 #endif
125
126
127 #if _KNN_
128 static void find_decision_boundary_KNN( int K )
129 {
130     img.copyTo( imgDst );
131
132     Mat trainSamples, trainClasses;
133     prepare_train_data( trainSamples, trainClasses );
134
135     // learn classifier
136     CvKNearest knnClassifier( trainSamples, trainClasses, Mat(), false, K );
137
138     Mat testSample( 1, 2, CV_32FC1 );
139     for( int y = 0; y < img.rows; y += testStep )
140     {
141         for( int x = 0; x < img.cols; x += testStep )
142         {
143             testSample.at<float>(0) = (float)x;
144             testSample.at<float>(1) = (float)y;
145
146             int response = (int)knnClassifier.find_nearest( testSample, K );
147             circle( imgDst, Point(x,y), 1, classColors[response] );
148         }
149     }
150 }
151 #endif
152
153 #if _SVM_
154 static void find_decision_boundary_SVM( CvSVMParams params )
155 {
156     img.copyTo( imgDst );
157
158     Mat trainSamples, trainClasses;
159     prepare_train_data( trainSamples, trainClasses );
160
161     // learn classifier
162     CvSVM svmClassifier( trainSamples, trainClasses, Mat(), Mat(), params );
163
164     Mat testSample( 1, 2, CV_32FC1 );
165     for( int y = 0; y < img.rows; y += testStep )
166     {
167         for( int x = 0; x < img.cols; x += testStep )
168         {
169             testSample.at<float>(0) = (float)x;
170             testSample.at<float>(1) = (float)y;
171
172             int response = (int)svmClassifier.predict( testSample );
173             circle( imgDst, Point(x,y), 2, classColors[response], 1 );
174         }
175     }
176
177
178     for( int i = 0; i < svmClassifier.get_support_vector_count(); i++ )
179     {
180         const float* supportVector = svmClassifier.get_support_vector(i);
181         circle( imgDst, Point(supportVector[0],supportVector[1]), 5, Scalar(255,255,255), -1 );
182     }
183
184 }
185 #endif
186
187 #if _DT_
188 static void find_decision_boundary_DT()
189 {
190     img.copyTo( imgDst );
191
192     Mat trainSamples, trainClasses;
193     prepare_train_data( trainSamples, trainClasses );
194
195     // learn classifier
196     CvDTree  dtree;
197
198     Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
199     var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;
200
201     CvDTreeParams params;
202     params.max_depth = 8;
203     params.min_sample_count = 2;
204     params.use_surrogates = false;
205     params.cv_folds = 0; // the number of cross-validation folds
206     params.use_1se_rule = false;
207     params.truncate_pruned_tree = false;
208
209     dtree.train( trainSamples, CV_ROW_SAMPLE, trainClasses,
210                  Mat(), Mat(), var_types, Mat(), params );
211
212     Mat testSample(1, 2, CV_32FC1 );
213     for( int y = 0; y < img.rows; y += testStep )
214     {
215         for( int x = 0; x < img.cols; x += testStep )
216         {
217             testSample.at<float>(0) = (float)x;
218             testSample.at<float>(1) = (float)y;
219
220             int response = (int)dtree.predict( testSample )->value;
221             circle( imgDst, Point(x,y), 2, classColors[response], 1 );
222         }
223     }
224 }
225 #endif
226
227 #if _BT_
228 void find_decision_boundary_BT()
229 {
230     img.copyTo( imgDst );
231
232     Mat trainSamples, trainClasses;
233     prepare_train_data( trainSamples, trainClasses );
234
235     // learn classifier
236     CvBoost  boost;
237
238     Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
239     var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;
240
241     CvBoostParams  params( CvBoost::DISCRETE, // boost_type
242                            100, // weak_count
243                            0.95, // weight_trim_rate
244                            2, // max_depth
245                            false, //use_surrogates
246                            0 // priors
247                          );
248
249     boost.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), var_types, Mat(), params );
250
251     Mat testSample(1, 2, CV_32FC1 );
252     for( int y = 0; y < img.rows; y += testStep )
253     {
254         for( int x = 0; x < img.cols; x += testStep )
255         {
256             testSample.at<float>(0) = (float)x;
257             testSample.at<float>(1) = (float)y;
258
259             int response = (int)boost.predict( testSample );
260             circle( imgDst, Point(x,y), 2, classColors[response], 1 );
261         }
262     }
263 }
264
265 #endif
266
267 #if _GBT_
268 void find_decision_boundary_GBT()
269 {
270     img.copyTo( imgDst );
271
272     Mat trainSamples, trainClasses;
273     prepare_train_data( trainSamples, trainClasses );
274
275     // learn classifier
276     CvGBTrees gbtrees;
277
278     Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
279     var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;
280
281     CvGBTreesParams  params( CvGBTrees::DEVIANCE_LOSS, // loss_function_type
282                              100, // weak_count
283                              0.1f, // shrinkage
284                              1.0f, // subsample_portion
285                              2, // max_depth
286                              false // use_surrogates )
287                            );
288
289     gbtrees.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), var_types, Mat(), params );
290
291     Mat testSample(1, 2, CV_32FC1 );
292     for( int y = 0; y < img.rows; y += testStep )
293     {
294         for( int x = 0; x < img.cols; x += testStep )
295         {
296             testSample.at<float>(0) = (float)x;
297             testSample.at<float>(1) = (float)y;
298
299             int response = (int)gbtrees.predict( testSample );
300             circle( imgDst, Point(x,y), 2, classColors[response], 1 );
301         }
302     }
303 }
304
305 #endif
306
307 #if _RF_
308 void find_decision_boundary_RF()
309 {
310     img.copyTo( imgDst );
311
312     Mat trainSamples, trainClasses;
313     prepare_train_data( trainSamples, trainClasses );
314
315     // learn classifier
316     CvRTrees  rtrees;
317     CvRTParams  params( 4, // max_depth,
318                         2, // min_sample_count,
319                         0.f, // regression_accuracy,
320                         false, // use_surrogates,
321                         16, // max_categories,
322                         0, // priors,
323                         false, // calc_var_importance,
324                         1, // nactive_vars,
325                         5, // max_num_of_trees_in_the_forest,
326                         0, // forest_accuracy,
327                         CV_TERMCRIT_ITER // termcrit_type
328                        );
329
330     rtrees.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), Mat(), Mat(), params );
331
332     Mat testSample(1, 2, CV_32FC1 );
333     for( int y = 0; y < img.rows; y += testStep )
334     {
335         for( int x = 0; x < img.cols; x += testStep )
336         {
337             testSample.at<float>(0) = (float)x;
338             testSample.at<float>(1) = (float)y;
339
340             int response = (int)rtrees.predict( testSample );
341             circle( imgDst, Point(x,y), 2, classColors[response], 1 );
342         }
343     }
344 }
345
346 #endif
347
348 #if _ERT_
349 void find_decision_boundary_ERT()
350 {
351     img.copyTo( imgDst );
352
353     Mat trainSamples, trainClasses;
354     prepare_train_data( trainSamples, trainClasses );
355
356     // learn classifier
357     CvERTrees ertrees;
358
359     Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
360     var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;
361
362     CvRTParams  params( 4, // max_depth,
363                         2, // min_sample_count,
364                         0.f, // regression_accuracy,
365                         false, // use_surrogates,
366                         16, // max_categories,
367                         0, // priors,
368                         false, // calc_var_importance,
369                         1, // nactive_vars,
370                         5, // max_num_of_trees_in_the_forest,
371                         0, // forest_accuracy,
372                         CV_TERMCRIT_ITER // termcrit_type
373                        );
374
375     ertrees.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), var_types, Mat(), params );
376
377     Mat testSample(1, 2, CV_32FC1 );
378     for( int y = 0; y < img.rows; y += testStep )
379     {
380         for( int x = 0; x < img.cols; x += testStep )
381         {
382             testSample.at<float>(0) = (float)x;
383             testSample.at<float>(1) = (float)y;
384
385             int response = (int)ertrees.predict( testSample );
386             circle( imgDst, Point(x,y), 2, classColors[response], 1 );
387         }
388     }
389 }
390 #endif
391
392 #if _ANN_
393 void find_decision_boundary_ANN( const Mat&  layer_sizes )
394 {
395     img.copyTo( imgDst );
396
397     Mat trainSamples, trainClasses;
398     prepare_train_data( trainSamples, trainClasses );
399
400     // prerare trainClasses
401     trainClasses.create( trainedPoints.size(), classColors.size(), CV_32FC1 );
402     for( int i = 0; i <  trainClasses.rows; i++ )
403     {
404         for( int k = 0; k < trainClasses.cols; k++ )
405         {
406             if( k == trainedPointsMarkers[i] )
407                 trainClasses.at<float>(i,k) = 1;
408             else
409                 trainClasses.at<float>(i,k) = 0;
410         }
411     }
412
413     Mat weights( 1, trainedPoints.size(), CV_32FC1, Scalar::all(1) );
414
415     // learn classifier
416     CvANN_MLP  ann( layer_sizes, CvANN_MLP::SIGMOID_SYM, 1, 1 );
417     ann.train( trainSamples, trainClasses, weights );
418
419     Mat testSample( 1, 2, CV_32FC1 );
420     for( int y = 0; y < img.rows; y += testStep )
421     {
422         for( int x = 0; x < img.cols; x += testStep )
423         {
424             testSample.at<float>(0) = (float)x;
425             testSample.at<float>(1) = (float)y;
426
427             Mat outputs( 1, classColors.size(), CV_32FC1, testSample.data );
428             ann.predict( testSample, outputs );
429             Point maxLoc;
430             minMaxLoc( outputs, 0, 0, 0, &maxLoc );
431             circle( imgDst, Point(x,y), 2, classColors[maxLoc.x], 1 );
432         }
433     }
434 }
435 #endif
436
437 #if _EM_
438 void find_decision_boundary_EM()
439 {
440     img.copyTo( imgDst );
441
442     Mat trainSamples, trainClasses;
443     prepare_train_data( trainSamples, trainClasses );
444
445     vector<cv::EM> em_models(classColors.size());
446
447     CV_Assert((int)trainClasses.total() == trainSamples.rows);
448     CV_Assert((int)trainClasses.type() == CV_32SC1);
449
450     for(size_t modelIndex = 0; modelIndex < em_models.size(); modelIndex++)
451     {
452         const int componentCount = 3;
453         em_models[modelIndex] = EM(componentCount, cv::EM::COV_MAT_DIAGONAL);
454
455         Mat modelSamples;
456         for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
457         {
458             if(trainClasses.at<int>(sampleIndex) == (int)modelIndex)
459                 modelSamples.push_back(trainSamples.row(sampleIndex));
460         }
461
462         // learn models
463         if(!modelSamples.empty())
464             em_models[modelIndex].train(modelSamples);
465     }
466
467     // classify coordinate plane points using the bayes classifier, i.e.
468     // y(x) = arg max_i=1_modelsCount likelihoods_i(x)
469     Mat testSample(1, 2, CV_32FC1 );
470     for( int y = 0; y < img.rows; y += testStep )
471     {
472         for( int x = 0; x < img.cols; x += testStep )
473         {
474             testSample.at<float>(0) = (float)x;
475             testSample.at<float>(1) = (float)y;
476
477             Mat logLikelihoods(1, em_models.size(), CV_64FC1, Scalar(-DBL_MAX));
478             for(size_t modelIndex = 0; modelIndex < em_models.size(); modelIndex++)
479             {
480                 if(em_models[modelIndex].isTrained())
481                     logLikelihoods.at<double>(modelIndex) = em_models[modelIndex].predict(testSample)[0];
482             }
483             Point maxLoc;
484             minMaxLoc(logLikelihoods, 0, 0, 0, &maxLoc);
485
486             int response = maxLoc.x;
487             circle( imgDst, Point(x,y), 2, classColors[response], 1 );
488         }
489     }
490 }
491 #endif
492
493 int main()
494 {
495     cout << "Use:" << endl
496          << "  right mouse button - to add new class;" << endl
497          << "  left mouse button - to add new point;" << endl
498          << "  key 'r' - to run the ML model;" << endl
499          << "  key 'i' - to init (clear) the data." << endl << endl;
500
501     cv::namedWindow( "points", 1 );
502     img.create( 480, 640, CV_8UC3 );
503     imgDst.create( 480, 640, CV_8UC3 );
504
505     imshow( "points", img );
506     setMouseCallback( "points", on_mouse );
507
508     for(;;)
509     {
510         uchar key = (uchar)waitKey();
511
512         if( key == 27 ) break;
513
514         if( key == 'i' ) // init
515         {
516             img = Scalar::all(0);
517
518             classColors.clear();
519             trainedPoints.clear();
520             trainedPointsMarkers.clear();
521
522             imshow( winName, img );
523         }
524
525         if( key == 'r' ) // run
526         {
527 #if _NBC_
528             find_decision_boundary_NBC();
529             namedWindow( "NormalBayesClassifier", WINDOW_AUTOSIZE );
530             imshow( "NormalBayesClassifier", imgDst );
531 #endif
532 #if _KNN_
533             int K = 3;
534             find_decision_boundary_KNN( K );
535             namedWindow( "kNN", WINDOW_AUTOSIZE );
536             imshow( "kNN", imgDst );
537
538             K = 15;
539             find_decision_boundary_KNN( K );
540             namedWindow( "kNN2", WINDOW_AUTOSIZE );
541             imshow( "kNN2", imgDst );
542 #endif
543
544 #if _SVM_
545             //(1)-(2)separable and not sets
546             CvSVMParams params;
547             params.svm_type = CvSVM::C_SVC;
548             params.kernel_type = CvSVM::POLY; //CvSVM::LINEAR;
549             params.degree = 0.5;
550             params.gamma = 1;
551             params.coef0 = 1;
552             params.C = 1;
553             params.nu = 0.5;
554             params.p = 0;
555             params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 1000, 0.01);
556
557             find_decision_boundary_SVM( params );
558             namedWindow( "classificationSVM1", WINDOW_AUTOSIZE );
559             imshow( "classificationSVM1", imgDst );
560
561             params.C = 10;
562             find_decision_boundary_SVM( params );
563             namedWindow( "classificationSVM2", WINDOW_AUTOSIZE );
564             imshow( "classificationSVM2", imgDst );
565 #endif
566
567 #if _DT_
568             find_decision_boundary_DT();
569             namedWindow( "DT", WINDOW_AUTOSIZE );
570             imshow( "DT", imgDst );
571 #endif
572
573 #if _BT_
574             find_decision_boundary_BT();
575             namedWindow( "BT", WINDOW_AUTOSIZE );
576             imshow( "BT", imgDst);
577 #endif
578
579 #if _GBT_
580             find_decision_boundary_GBT();
581             namedWindow( "GBT", WINDOW_AUTOSIZE );
582             imshow( "GBT", imgDst);
583 #endif
584
585 #if _RF_
586             find_decision_boundary_RF();
587             namedWindow( "RF", WINDOW_AUTOSIZE );
588             imshow( "RF", imgDst);
589 #endif
590
591 #if _ERT_
592             find_decision_boundary_ERT();
593             namedWindow( "ERT", WINDOW_AUTOSIZE );
594             imshow( "ERT", imgDst);
595 #endif
596
597 #if _ANN_
598             Mat layer_sizes1( 1, 3, CV_32SC1 );
599             layer_sizes1.at<int>(0) = 2;
600             layer_sizes1.at<int>(1) = 5;
601             layer_sizes1.at<int>(2) = classColors.size();
602             find_decision_boundary_ANN( layer_sizes1 );
603             namedWindow( "ANN", WINDOW_AUTOSIZE );
604             imshow( "ANN", imgDst );
605 #endif
606
607 #if _EM_
608             find_decision_boundary_EM();
609             namedWindow( "EM", WINDOW_AUTOSIZE );
610             imshow( "EM", imgDst );
611 #endif
612         }
613     }
614
615     return 1;
616 }