3aa4d9b1379da92c14160a6b05afa4382a940d35
[profile/ivi/opencv.git] / samples / cpp / points_classifier.cpp
1 #include "opencv2/opencv_modules.hpp"
2 #include "opencv2/core/core.hpp"
3 #include "opencv2/ml/ml.hpp"
4 #include "opencv2/highgui/highgui.hpp"
5 #ifdef HAVE_OPENCV_OCL
6 #define _OCL_KNN_ 1 // select whether using ocl::KNN method or not, default is using
7 #define _OCL_SVM_ 1 // select whether using ocl::svm method or not, default is using
8 #include "opencv2/ocl/ocl.hpp"
9 #endif
10
11 #include <stdio.h>
12
13 using namespace std;
14 using namespace cv;
15 using namespace cv::ml;
16
17 const Scalar WHITE_COLOR = Scalar(255,255,255);
18 const string winName = "points";
19 const int testStep = 5;
20
21 Mat img, imgDst;
22 RNG rng;
23
24 vector<Point>  trainedPoints;
25 vector<int>    trainedPointsMarkers;
26 const int MAX_CLASSES = 2;
27 vector<Vec3b>  classColors(MAX_CLASSES);
28 int currentClass = 0;
29 vector<int> classCounters(MAX_CLASSES);
30
31 #define _NBC_ 1 // normal Bayessian classifier
32 #define _KNN_ 1 // k nearest neighbors classifier
33 #define _SVM_ 1 // support vectors machine
34 #define _DT_  1 // decision tree
35 #define _BT_  1 // ADA Boost
36 #define _GBT_ 0 // gradient boosted trees
37 #define _RF_  1 // random forest
38 #define _ANN_ 1 // artificial neural networks
39 #define _EM_  1 // expectation-maximization
40
41 static void on_mouse( int event, int x, int y, int /*flags*/, void* )
42 {
43     if( img.empty() )
44         return;
45
46     int updateFlag = 0;
47
48     if( event == EVENT_LBUTTONUP )
49     {
50         trainedPoints.push_back( Point(x,y) );
51         trainedPointsMarkers.push_back( currentClass );
52         classCounters[currentClass]++;
53         updateFlag = true;
54     }
55
56     //draw
57     if( updateFlag )
58     {
59         img = Scalar::all(0);
60
61         // draw points
62         for( size_t i = 0; i < trainedPoints.size(); i++ )
63         {
64             Vec3b c = classColors[trainedPointsMarkers[i]];
65             circle( img, trainedPoints[i], 5, Scalar(c), -1 );
66         }
67
68         imshow( winName, img );
69    }
70 }
71
72 static Mat prepare_train_samples(const vector<Point>& pts)
73 {
74     Mat samples;
75     Mat(pts).reshape(1, (int)pts.size()).convertTo(samples, CV_32F);
76     return samples;
77 }
78
79 static Ptr<TrainData> prepare_train_data()
80 {
81     Mat samples = prepare_train_samples(trainedPoints);
82     return TrainData::create(samples, ROW_SAMPLE, Mat(trainedPointsMarkers));
83 }
84
85 static void predict_and_paint(const Ptr<StatModel>& model, Mat& dst)
86 {
87     Mat testSample( 1, 2, CV_32FC1 );
88     for( int y = 0; y < img.rows; y += testStep )
89     {
90         for( int x = 0; x < img.cols; x += testStep )
91         {
92             testSample.at<float>(0) = (float)x;
93             testSample.at<float>(1) = (float)y;
94
95             int response = (int)model->predict( testSample );
96             dst.at<Vec3b>(y, x) = classColors[response];
97         }
98     }
99 }
100
101 #if _NBC_
102 static void find_decision_boundary_NBC()
103 {
104     // learn classifier
105     Ptr<NormalBayesClassifier> normalBayesClassifier = StatModel::train<NormalBayesClassifier>(prepare_train_data(), NormalBayesClassifier::Params());
106
107     predict_and_paint(normalBayesClassifier, imgDst);
108 }
109 #endif
110
111
112 #if _KNN_
113 static void find_decision_boundary_KNN( int K )
114 {
115     Ptr<KNearest> knn = StatModel::train<KNearest>(prepare_train_data(), KNearest::Params(K, true));
116     predict_and_paint(knn, imgDst);
117 }
118 #endif
119
120 #if _SVM_
121 static void find_decision_boundary_SVM( SVM::Params params )
122 {
123     Ptr<SVM> svm = StatModel::train<SVM>(prepare_train_data(), params);
124     predict_and_paint(svm, imgDst);
125
126     Mat sv = svm->getSupportVectors();
127     for( int i = 0; i < sv.rows; i++ )
128     {
129         const float* supportVector = sv.ptr<float>(i);
130         circle( imgDst, Point(saturate_cast<int>(supportVector[0]),saturate_cast<int>(supportVector[1])), 5, Scalar(255,255,255), -1 );
131     }
132 }
133 #endif
134
135 #if _DT_
136 static void find_decision_boundary_DT()
137 {
138     DTrees::Params params;
139     params.maxDepth = 8;
140     params.minSampleCount = 2;
141     params.useSurrogates = false;
142     params.CVFolds = 0; // the number of cross-validation folds
143     params.use1SERule = false;
144     params.truncatePrunedTree = false;
145
146     Ptr<DTrees> dtree = StatModel::train<DTrees>(prepare_train_data(), params);
147
148     predict_and_paint(dtree, imgDst);
149 }
150 #endif
151
152 #if _BT_
153 static void find_decision_boundary_BT()
154 {
155     Boost::Params params( Boost::DISCRETE, // boost_type
156                           100, // weak_count
157                           0.95, // weight_trim_rate
158                           2, // max_depth
159                           false, //use_surrogates
160                           Mat() // priors
161                           );
162
163     Ptr<Boost> boost = StatModel::train<Boost>(prepare_train_data(), params);
164     predict_and_paint(boost, imgDst);
165 }
166
167 #endif
168
169 #if _GBT_
170 static void find_decision_boundary_GBT()
171 {
172     GBTrees::Params params( GBTrees::DEVIANCE_LOSS, // loss_function_type
173                          100, // weak_count
174                          0.1f, // shrinkage
175                          1.0f, // subsample_portion
176                          2, // max_depth
177                          false // use_surrogates )
178                          );
179
180     Ptr<GBTrees> gbtrees = StatModel::train<GBTrees>(prepare_train_data(), params);
181     predict_and_paint(gbtrees, imgDst);
182 }
183 #endif
184
185 #if _RF_
186 static void find_decision_boundary_RF()
187 {
188     RTrees::Params  params( 4, // max_depth,
189                         2, // min_sample_count,
190                         0.f, // regression_accuracy,
191                         false, // use_surrogates,
192                         16, // max_categories,
193                         Mat(), // priors,
194                         false, // calc_var_importance,
195                         1, // nactive_vars,
196                         TermCriteria(TermCriteria::MAX_ITER, 5, 0) // max_num_of_trees_in_the_forest,
197                        );
198
199     Ptr<RTrees> rtrees = StatModel::train<RTrees>(prepare_train_data(), params);
200     predict_and_paint(rtrees, imgDst);
201 }
202
203 #endif
204
205 #if _ANN_
206 static void find_decision_boundary_ANN( const Mat&  layer_sizes )
207 {
208     ANN_MLP::Params params(layer_sizes, ANN_MLP::SIGMOID_SYM, 1, 1, TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS, 300, FLT_EPSILON),
209                            ANN_MLP::Params::BACKPROP, 0.001);
210
211     Mat trainClasses = Mat::zeros( trainedPoints.size(), classColors.size(), CV_32FC1 );
212     for( int i = 0; i < trainClasses.rows; i++ )
213     {
214         trainClasses.at<float>(i, trainedPointsMarkers[i]) = 1.f;
215     }
216
217     Mat samples = prepare_train_samples(trainedPoints);
218     Ptr<TrainData> tdata = TrainData::create(samples, ROW_SAMPLE, trainClasses);
219
220     Ptr<ANN_MLP> ann = StatModel::train<ANN_MLP>(tdata, params);
221     predict_and_paint(ann, imgDst);
222 }
223 #endif
224
225 #if _EM_
226 static void find_decision_boundary_EM()
227 {
228     img.copyTo( imgDst );
229
230     Mat samples = prepare_train_samples(trainedPoints);
231
232     int i, j, nmodels = (int)classColors.size();
233     vector<Ptr<EM> > em_models(nmodels);
234     Mat modelSamples;
235
236     for( i = 0; i < nmodels; i++ )
237     {
238         const int componentCount = 3;
239
240         modelSamples.release();
241         for( j = 0; j < samples.rows; j++ )
242         {
243             if( trainedPointsMarkers[j] == i )
244                 modelSamples.push_back(samples.row(j));
245         }
246
247         // learn models
248         if( !modelSamples.empty() )
249         {
250             em_models[i] = EM::train(modelSamples, noArray(), noArray(), noArray(),
251                                    EM::Params(componentCount, EM::COV_MAT_DIAGONAL));
252         }
253     }
254
255     // classify coordinate plane points using the bayes classifier, i.e.
256     // y(x) = arg max_i=1_modelsCount likelihoods_i(x)
257     Mat testSample(1, 2, CV_32FC1 );
258     Mat logLikelihoods(1, nmodels, CV_64FC1, Scalar(-DBL_MAX));
259
260     for( int y = 0; y < img.rows; y += testStep )
261     {
262         for( int x = 0; x < img.cols; x += testStep )
263         {
264             testSample.at<float>(0) = (float)x;
265             testSample.at<float>(1) = (float)y;
266
267             for( i = 0; i < nmodels; i++ )
268             {
269                 if( !em_models[i].empty() )
270                     logLikelihoods.at<double>(i) = em_models[i]->predict2(testSample, noArray())[0];
271             }
272             Point maxLoc;
273             minMaxLoc(logLikelihoods, 0, 0, 0, &maxLoc);
274             imgDst.at<Vec3b>(y, x) = classColors[maxLoc.x];
275         }
276     }
277 }
278 #endif
279
280 int main()
281 {
282     cout << "Use:" << endl
283          << "  key '0' .. '1' - switch to class #n" << endl
284          << "  left mouse button - to add new point;" << endl
285          << "  key 'r' - to run the ML model;" << endl
286          << "  key 'i' - to init (clear) the data." << endl << endl;
287
288     cv::namedWindow( "points", 1 );
289     img.create( 480, 640, CV_8UC3 );
290     imgDst.create( 480, 640, CV_8UC3 );
291
292     imshow( "points", img );
293     setMouseCallback( "points", on_mouse );
294
295     classColors[0] = Vec3b(0, 255, 0);
296     classColors[1] = Vec3b(0, 0, 255);
297
298     for(;;)
299     {
300         uchar key = (uchar)waitKey();
301
302         if( key == 27 ) break;
303
304         if( key == 'i' ) // init
305         {
306             img = Scalar::all(0);
307
308             trainedPoints.clear();
309             trainedPointsMarkers.clear();
310             classCounters.assign(MAX_CLASSES, 0);
311
312             imshow( winName, img );
313         }
314
315         if( key == '0' || key == '1' )
316         {
317             currentClass = key - '0';
318         }
319
320         if( key == 'r' ) // run
321         {
322             double minVal = 0;
323             minMaxLoc(classCounters, &minVal, 0, 0, 0);
324             if( minVal == 0 )
325             {
326                 printf("each class should have at least 1 point\n");
327                 continue;
328             }
329             img.copyTo( imgDst );
330 #if _NBC_
331             find_decision_boundary_NBC();
332             imshow( "NormalBayesClassifier", imgDst );
333 #endif
334 #if _KNN_
335             int K = 3;
336             find_decision_boundary_KNN( K );
337             imshow( "kNN", imgDst );
338
339             K = 15;
340             find_decision_boundary_KNN( K );
341             imshow( "kNN2", imgDst );
342 #endif
343
344 #if _SVM_
345             //(1)-(2)separable and not sets
346             SVM::Params params;
347             params.svmType = SVM::C_SVC;
348             params.kernelType = SVM::POLY; //CvSVM::LINEAR;
349             params.degree = 0.5;
350             params.gamma = 1;
351             params.coef0 = 1;
352             params.C = 1;
353             params.nu = 0.5;
354             params.p = 0;
355             params.termCrit = TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS, 1000, 0.01);
356
357             find_decision_boundary_SVM( params );
358             imshow( "classificationSVM1", imgDst );
359
360             params.C = 10;
361             find_decision_boundary_SVM( params );
362             imshow( "classificationSVM2", imgDst );
363 #endif
364
365 #if _DT_
366             find_decision_boundary_DT();
367             imshow( "DT", imgDst );
368 #endif
369
370 #if _BT_
371             find_decision_boundary_BT();
372             imshow( "BT", imgDst);
373 #endif
374
375 #if _GBT_
376             find_decision_boundary_GBT();
377             imshow( "GBT", imgDst);
378 #endif
379
380 #if _RF_
381             find_decision_boundary_RF();
382             imshow( "RF", imgDst);
383 #endif
384
385 #if _ANN_
386             Mat layer_sizes1( 1, 3, CV_32SC1 );
387             layer_sizes1.at<int>(0) = 2;
388             layer_sizes1.at<int>(1) = 5;
389             layer_sizes1.at<int>(2) = classColors.size();
390             find_decision_boundary_ANN( layer_sizes1 );
391             imshow( "ANN", imgDst );
392 #endif
393
394 #if _EM_
395             find_decision_boundary_EM();
396             imshow( "EM", imgDst );
397 #endif
398         }
399     }
400
401     return 1;
402 }