From 7120355e06013a894fa33bdd0ac7a607b7f2aff8 Mon Sep 17 00:00:00 2001 From: Maria Dimashova Date: Mon, 16 Apr 2012 14:54:56 +0000 Subject: [PATCH] updated points_classifier sample to use bayes classifier after distributions estimation by EM --- samples/cpp/points_classifier.cpp | 41 ++++++++++++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/samples/cpp/points_classifier.cpp b/samples/cpp/points_classifier.cpp index f8902d4..e00141c 100644 --- a/samples/cpp/points_classifier.cpp +++ b/samples/cpp/points_classifier.cpp @@ -442,16 +442,30 @@ void find_decision_boundary_EM() Mat trainSamples, trainClasses; prepare_train_data( trainSamples, trainClasses ); - cv::EM em; - cv::EM::Params params; - params.nclusters = classColors.size(); - params.covMatType = cv::EM::COV_MAT_GENERIC; - params.startStep = cv::EM::START_AUTO_STEP; - params.termCrit = cv::TermCriteria(cv::TermCriteria::COUNT + cv::TermCriteria::COUNT, 10, 0.1); + vector em_models(classColors.size()); - // learn classifier - em.train( trainSamples, Mat(), params, &trainClasses ); + CV_Assert((int)trainClasses.total() == trainSamples.rows); + CV_Assert((int)trainClasses.type() == CV_32SC1); + + for(size_t modelIndex = 0; modelIndex < em_models.size(); modelIndex++) + { + const int componentCount = 3; + em_models[modelIndex] = EM(componentCount, cv::EM::COV_MAT_DIAGONAL); + + Mat modelSamples; + for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++) + { + if(trainClasses.at(sampleIndex) == (int)modelIndex) + modelSamples.push_back(trainSamples.row(sampleIndex)); + } + + // learn models + if(!modelSamples.empty()) + em_models[modelIndex].train(modelSamples); + } + // classify coordinate plane points using the bayes classifier, i.e. + // y(x) = arg max_i=1_modelsCount likelihoods_i(x) Mat testSample(1, 2, CV_32FC1 ); for( int y = 0; y < img.rows; y += testStep ) { @@ -460,7 +474,16 @@ void find_decision_boundary_EM() testSample.at(0) = (float)x; testSample.at(1) = (float)y; - int response = (int)em.predict( testSample ); + Mat logLikelihoods(1, em_models.size(), CV_64FC1, Scalar(-DBL_MAX)); + for(size_t modelIndex = 0; modelIndex < em_models.size(); modelIndex++) + { + if(em_models[modelIndex].isTrained()) + em_models[modelIndex].predict( testSample, noArray(), &logLikelihoods.at(modelIndex) ); + } + Point maxLoc; + minMaxLoc(logLikelihoods, 0, 0, 0, &maxLoc); + + int response = maxLoc.x; circle( imgDst, Point(x,y), 2, classColors[response], 1 ); } } -- 2.7.4