1 #include "opencv2/core.hpp"
2 #include "opencv2/imgproc.hpp"
3 #include "opencv2/ml.hpp"
4 #include "opencv2/highgui.hpp"
6 #define _OCL_KNN_ 1 // whether to use the ocl::KNN method; enabled (1) by default
7 #define _OCL_SVM_ 1 // whether to use the ocl::svm method; enabled (1) by default
8 #include "opencv2/ocl/ocl.hpp"
15 using namespace cv::ml;
// Color constant (white) and the name of the window that shows the points.
17 const Scalar WHITE_COLOR = Scalar(255,255,255);
18 const string winName = "points";
// Pixel stride used when sampling the image plane for prediction (see the y/x loops).
19 const int testStep = 5;
// Training set collected from mouse clicks: positions and, in parallel, their class labels.
24 vector<Point> trainedPoints;
25 vector<int> trainedPointsMarkers;
// Number of classes; sizes the per-class color table and the per-class point counters.
26 const int MAX_CLASSES = 2;
27 vector<Vec3b> classColors(MAX_CLASSES);
29 vector<int> classCounters(MAX_CLASSES);
// Per-classifier switches (1 = on, 0 = off).
// NOTE(review): presumably these gate the matching code via #if; the guards are not visible here — confirm.
31 #define _NBC_ 1 // normal Bayesian classifier
32 #define _KNN_ 1 // k-nearest neighbors classifier
33 #define _SVM_ 1 // support vector machine
34 #define _DT_ 1 // decision tree
35 #define _BT_ 1 // AdaBoost
36 #define _GBT_ 0 // gradient boosted trees
37 #define _RF_ 1 // random forest
38 #define _ANN_ 1 // artificial neural networks
39 #define _EM_ 1 // expectation-maximization
// Mouse callback for the points window.
// event: mouse event code; x, y: cursor position; the flags and user-data arguments are unused.
41 static void on_mouse( int event, int x, int y, int /*flags*/, void* )
// On left-button release: store the clicked position together with the currently
// selected class label, and increment that class's point counter.
48 if( event == EVENT_LBUTTONUP )
50 trainedPoints.push_back( Point(x,y) );
51 trainedPointsMarkers.push_back( currentClass );
52 classCounters[currentClass]++;
// Draw every collected point as a filled circle (radius 5, thickness -1) in the
// color of its class, then show the image in the window.
62 for( size_t i = 0; i < trainedPoints.size(); i++ )
64 Vec3b c = classColors[trainedPointsMarkers[i]];
65 circle( img, trainedPoints[i], 5, Scalar(c), -1 );
68 imshow( winName, img );
// Converts a list of integer points into a floating-point sample matrix.
// pts: the points to convert.
// NOTE(review): presumably returns `samples`; the return statement is not visible here — confirm.
72 static Mat prepare_train_samples(const vector<Point>& pts)
// View the point vector as a Mat, reshape it to one channel with pts.size() rows
// (one point per row), and convert the result to CV_32F into `samples`.
75 Mat(pts).reshape(1, (int)pts.size()).convertTo(samples, CV_32F);
// Packs the collected points and their class labels into a TrainData object.
// Returns the TrainData built from the global trainedPoints / trainedPointsMarkers.
79 static Ptr<TrainData> prepare_train_data()
81 Mat samples = prepare_train_samples(trainedPoints);
// ROW_SAMPLE: each row of `samples` is one sample; the markers are passed as the third argument.
82 return TrainData::create(samples, ROW_SAMPLE, Mat(trainedPointsMarkers));
// Runs `model` on a grid of image coordinates and colors `dst` by predicted class.
// model: trained model to query; dst: image that receives one colored pixel per grid node.
// NOTE(review): the loops use the dimensions of the global `img` but write into `dst`;
// presumably both images have the same size — confirm.
85 static void predict_and_paint(const Ptr<StatModel>& model, Mat& dst)
// Single 1x2 float sample, reused for every grid node.
87 Mat testSample( 1, 2, CV_32FC1 );
// Visit every testStep-th pixel in both directions.
88 for( int y = 0; y < img.rows; y += testStep )
90 for( int x = 0; x < img.cols; x += testStep )
92 testSample.at<float>(0) = (float)x;
93 testSample.at<float>(1) = (float)y;
// The prediction is truncated to int and used directly as an index into classColors.
95 int response = (int)model->predict( testSample );
96 dst.at<Vec3b>(y, x) = classColors[response];
// Trains a NormalBayesClassifier with default-constructed Params on the collected
// points and paints its predictions into imgDst.
102 static void find_decision_boundary_NBC()
105 Ptr<NormalBayesClassifier> normalBayesClassifier = StatModel::train<NormalBayesClassifier>(prepare_train_data(), NormalBayesClassifier::Params());
107 predict_and_paint(normalBayesClassifier, imgDst);
// Trains a KNearest model on the collected points and paints its predictions into imgDst.
// K: value passed as the first argument of KNearest::Params.
// NOTE(review): the second Params argument (true) presumably selects classification mode — confirm.
113 static void find_decision_boundary_KNN( int K )
115 Ptr<KNearest> knn = StatModel::train<KNearest>(prepare_train_data(), KNearest::Params(K, true));
116 predict_and_paint(knn, imgDst);
// Trains an SVM with the given parameters on the collected points, paints its
// predictions into imgDst, and marks the support vectors.
// params: SVM training parameters (taken by value).
121 static void find_decision_boundary_SVM( SVM::Params params )
123 Ptr<SVM> svm = StatModel::train<SVM>(prepare_train_data(), params);
124 predict_and_paint(svm, imgDst);
// Draw each support vector (one per row of `sv`, read as floats) as a filled white
// circle of radius 5; the first two values of the row are used as x and y.
126 Mat sv = svm->getSupportVectors();
127 for( int i = 0; i < sv.rows; i++ )
129 const float* supportVector = sv.ptr<float>(i);
130 circle( imgDst, Point(saturate_cast<int>(supportVector[0]),saturate_cast<int>(supportVector[1])), 5, Scalar(255,255,255), -1 );
// Trains a decision tree (DTrees) on the collected points and paints its predictions into imgDst.
136 static void find_decision_boundary_DT()
138 DTrees::Params params;
// Tree settings: minimum sample count of 2, surrogates off, no cross-validation
// folds, 1SE rule off, pruned-tree truncation off.
140 params.minSampleCount = 2;
141 params.useSurrogates = false;
142 params.CVFolds = 0; // the number of cross-validation folds
143 params.use1SERule = false;
144 params.truncatePrunedTree = false;
146 Ptr<DTrees> dtree = StatModel::train<DTrees>(prepare_train_data(), params);
148 predict_and_paint(dtree, imgDst);
// Trains a Boost model on the collected points and paints its predictions into imgDst.
// Visible parameters: discrete boosting, weight trim rate 0.95, surrogates disabled.
153 static void find_decision_boundary_BT()
155 Boost::Params params( Boost::DISCRETE, // boost_type
157 0.95, // weight_trim_rate
159 false, //use_surrogates
163 Ptr<Boost> boost = StatModel::train<Boost>(prepare_train_data(), params);
164 predict_and_paint(boost, imgDst);
// Trains a gradient boosted trees (GBTrees) model on the collected points and paints
// its predictions into imgDst.
// Visible parameters: deviance loss, subsample portion 1.0, surrogates disabled.
// NOTE(review): _GBT_ is defined as 0 in this file, so this code is presumably compiled out — confirm.
170 static void find_decision_boundary_GBT()
172 GBTrees::Params params( GBTrees::DEVIANCE_LOSS, // loss_function_type
175 1.0f, // subsample_portion
177 false // use_surrogates )
180 Ptr<GBTrees> gbtrees = StatModel::train<GBTrees>(prepare_train_data(), params);
181 predict_and_paint(gbtrees, imgDst);
// Trains a random forest (RTrees) on the collected points and paints its predictions into imgDst.
186 static void find_decision_boundary_RF()
// Forest settings are documented per argument; the termination criterion caps the
// iteration count at 5, which the inline comment relates to the number of trees.
188 RTrees::Params params( 4, // max_depth,
189 2, // min_sample_count,
190 0.f, // regression_accuracy,
191 false, // use_surrogates,
192 16, // max_categories,
194 false, // calc_var_importance,
196 TermCriteria(TermCriteria::MAX_ITER, 5, 0) // max_num_of_trees_in_the_forest,
199 Ptr<RTrees> rtrees = StatModel::train<RTrees>(prepare_train_data(), params);
200 predict_and_paint(rtrees, imgDst);
// Trains a multi-layer perceptron (ANN_MLP) on the collected points and paints its
// predictions into imgDst.
// layer_sizes: network layer sizes, forwarded to ANN_MLP::Params.
206 static void find_decision_boundary_ANN( const Mat& layer_sizes )
// Symmetric sigmoid activation, back-propagation training, stopping after 300
// iterations or at FLT_EPSILON.
208 ANN_MLP::Params params(layer_sizes, ANN_MLP::SIGMOID_SYM, 1, 1, TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS, 300, FLT_EPSILON),
209 ANN_MLP::Params::BACKPROP, 0.001);
// Build the responses as one row per training point and one column per class:
// all zeros except a 1 in the column of the point's class label.
211 Mat trainClasses = Mat::zeros( (int)trainedPoints.size(), (int)classColors.size(), CV_32FC1 );
212 for( int i = 0; i < trainClasses.rows; i++ )
214 trainClasses.at<float>(i, trainedPointsMarkers[i]) = 1.f;
217 Mat samples = prepare_train_samples(trainedPoints);
218 Ptr<TrainData> tdata = TrainData::create(samples, ROW_SAMPLE, trainClasses);
220 Ptr<ANN_MLP> ann = StatModel::train<ANN_MLP>(tdata, params);
221 predict_and_paint(ann, imgDst);
// Trains one EM model per class on that class's points, then colors a grid of image
// coordinates in imgDst by the class whose model gives the highest log-likelihood.
226 static void find_decision_boundary_EM()
// Start from a copy of the current points image.
228 img.copyTo( imgDst );
230 Mat samples = prepare_train_samples(trainedPoints);
// One model slot per class color.
232 int i, j, nmodels = (int)classColors.size();
233 vector<Ptr<EM> > em_models(nmodels);
// Train the per-class models.
236 for( i = 0; i < nmodels; i++ )
238 const int componentCount = 3;
// Collect the rows of `samples` whose label equals class i.
240 modelSamples.release();
241 for( j = 0; j < samples.rows; j++ )
243 if( trainedPointsMarkers[j] == i )
244 modelSamples.push_back(samples.row(j));
// Train only when the class has at least one sample; otherwise the slot stays empty.
248 if( !modelSamples.empty() )
250 em_models[i] = EM::train(modelSamples, noArray(), noArray(), noArray(),
251 EM::Params(componentCount, EM::COV_MAT_DIAGONAL));
255 // classify coordinate plane points using the bayes classifier, i.e.
256 // y(x) = arg max_i=1_modelsCount likelihoods_i(x)
257 Mat testSample(1, 2, CV_32FC1 );
// Every entry starts at -DBL_MAX and is overwritten only for classes that have a model.
258 Mat logLikelihoods(1, nmodels, CV_64FC1, Scalar(-DBL_MAX));
// Visit every testStep-th pixel in both directions.
260 for( int y = 0; y < img.rows; y += testStep )
262 for( int x = 0; x < img.cols; x += testStep )
264 testSample.at<float>(0) = (float)x;
265 testSample.at<float>(1) = (float)y;
// Element [0] of predict2's result is stored as the log-likelihood for class i.
267 for( i = 0; i < nmodels; i++ )
269 if( !em_models[i].empty() )
270 logLikelihoods.at<double>(i) = em_models[i]->predict2(testSample, noArray())[0];
// The column of the maximum entry is the winning class index.
273 minMaxLoc(logLikelihoods, 0, 0, 0, &maxLoc);
274 imgDst.at<Vec3b>(y, x) = classColors[maxLoc.x];
// Print the usage help for the interactive demo.
282 cout << "Use:" << endl
283 << " key '0' .. '1' - switch to class #n" << endl
284 << " left mouse button - to add new point;" << endl
285 << " key 'r' - to run the ML model;" << endl
286 << " key 'i' - to init (clear) the data." << endl << endl;
// Create the window and allocate both images as 480-row by 640-column CV_8UC3.
288 cv::namedWindow( "points", 1 );
289 img.create( 480, 640, CV_8UC3 );
290 imgDst.create( 480, 640, CV_8UC3 );
// Show the points image and register the mouse callback on the window.
292 imshow( "points", img );
293 setMouseCallback( "points", on_mouse );
// Display colors for class 0 and class 1.
295 classColors[0] = Vec3b(0, 255, 0);
296 classColors[1] = Vec3b(0, 0, 255);
// Wait for a key press and dispatch on it; key code 27 triggers a break.
300 uchar key = (uchar)waitKey();
302 if( key == 27 ) break;
// 'i': clear the image to black and reset the collected points, labels and counters.
304 if( key == 'i' ) // init
306 img = Scalar::all(0);
308 trainedPoints.clear();
309 trainedPointsMarkers.clear();
310 classCounters.assign(MAX_CLASSES, 0);
312 imshow( winName, img );
// '0' or '1': select the class assigned to subsequently added points.
315 if( key == '0' || key == '1' )
317 currentClass = key - '0';
// 'r': run the classifiers on the collected points.
320 if( key == 'r' ) // run
// Find the smallest per-class point count.
// NOTE(review): the message below is presumably printed when that minimum is 0;
// the condition is not visible here — confirm.
323 minMaxLoc(classCounters, &minVal, 0, 0, 0);
326 printf("each class should have at least 1 point\n");
// Each classifier below paints into imgDst, which is then shown in its own window.
329 img.copyTo( imgDst );
331 find_decision_boundary_NBC();
332 imshow( "NormalBayesClassifier", imgDst );
// kNN is run twice; K is assigned in lines not visible here.
336 find_decision_boundary_KNN( K );
337 imshow( "kNN", imgDst );
340 find_decision_boundary_KNN( K );
341 imshow( "kNN2", imgDst );
345 // (1)-(2) separable and non-separable sets
// C-SVC with a polynomial kernel, stopping after 1000 iterations or at epsilon 0.01.
// The SVM is run twice; any parameter changes between the runs are not visible here.
347 params.svmType = SVM::C_SVC;
348 params.kernelType = SVM::POLY; //CvSVM::LINEAR;
355 params.termCrit = TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS, 1000, 0.01);
357 find_decision_boundary_SVM( params );
358 imshow( "classificationSVM1", imgDst );
361 find_decision_boundary_SVM( params );
362 imshow( "classificationSVM2", imgDst );
366 find_decision_boundary_DT();
367 imshow( "DT", imgDst );
371 find_decision_boundary_BT();
372 imshow( "BT", imgDst);
376 find_decision_boundary_GBT();
377 imshow( "GBT", imgDst);
381 find_decision_boundary_RF();
382 imshow( "RF", imgDst);
// Three-entry layer size vector: 2, 5, and one per class color.
386 Mat layer_sizes1( 1, 3, CV_32SC1 );
387 layer_sizes1.at<int>(0) = 2;
388 layer_sizes1.at<int>(1) = 5;
389 layer_sizes1.at<int>(2) = (int)classColors.size();
390 find_decision_boundary_ANN( layer_sizes1 );
391 imshow( "ANN", imgDst );
395 find_decision_boundary_EM();
396 imshow( "EM", imgDst );