updated points_classifier sample to use bayes classifier after distributions estimati...
authorMaria Dimashova <no@email>
Mon, 16 Apr 2012 14:54:56 +0000 (14:54 +0000)
committerMaria Dimashova <no@email>
Mon, 16 Apr 2012 14:54:56 +0000 (14:54 +0000)
samples/cpp/points_classifier.cpp

index f8902d4..e00141c 100644 (file)
@@ -442,16 +442,30 @@ void find_decision_boundary_EM()
     Mat trainSamples, trainClasses;
     prepare_train_data( trainSamples, trainClasses );
 
-    cv::EM em;
-    cv::EM::Params params;
-    params.nclusters = classColors.size();
-    params.covMatType = cv::EM::COV_MAT_GENERIC;
-    params.startStep = cv::EM::START_AUTO_STEP;
-    params.termCrit = cv::TermCriteria(cv::TermCriteria::COUNT + cv::TermCriteria::COUNT, 10, 0.1);
+    vector<cv::EM> em_models(classColors.size());
 
-    // learn classifier
-    em.train( trainSamples, Mat(), params, &trainClasses );
+    CV_Assert((int)trainClasses.total() == trainSamples.rows);
+    CV_Assert((int)trainClasses.type() == CV_32SC1);
+
+    for(size_t modelIndex = 0; modelIndex < em_models.size(); modelIndex++)
+    {
+        const int componentCount = 3;
+        em_models[modelIndex] = EM(componentCount, cv::EM::COV_MAT_DIAGONAL);
+
+        Mat modelSamples;
+        for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
+        {
+            if(trainClasses.at<int>(sampleIndex) == (int)modelIndex)
+                modelSamples.push_back(trainSamples.row(sampleIndex));
+        }
+
+        // learn models
+        if(!modelSamples.empty())
+            em_models[modelIndex].train(modelSamples);
+    }
 
+    // classify coordinate plane points using the bayes classifier, i.e.
+    // y(x) = arg max_i=1_modelsCount likelihoods_i(x)
     Mat testSample(1, 2, CV_32FC1 );
     for( int y = 0; y < img.rows; y += testStep )
     {
@@ -460,7 +474,16 @@ void find_decision_boundary_EM()
             testSample.at<float>(0) = (float)x;
             testSample.at<float>(1) = (float)y;
 
-            int response = (int)em.predict( testSample );
+            Mat logLikelihoods(1, em_models.size(), CV_64FC1, Scalar(-DBL_MAX));
+            for(size_t modelIndex = 0; modelIndex < em_models.size(); modelIndex++)
+            {
+                if(em_models[modelIndex].isTrained())
+                    em_models[modelIndex].predict( testSample, noArray(), &logLikelihoods.at<double>(modelIndex) );
+            }
+            Point maxLoc;
+            minMaxLoc(logLikelihoods, 0, 0, 0, &maxLoc);
+
+            int response = maxLoc.x;
             circle( imgDst, Point(x,y), 2, classColors[response], 1 );
         }
     }