fixed test on em
authorMaria Dimashova <no@email>
Mon, 16 Apr 2012 12:15:16 +0000 (12:15 +0000)
committerMaria Dimashova <no@email>
Mon, 16 Apr 2012 12:15:16 +0000 (12:15 +0000)
modules/ml/test/test_emknearestkmeans.cpp

index c6f6023..911b2d9 100644 (file)
@@ -129,7 +129,7 @@ int maxIdx( const vector<int>& count )
 }
 
 static
-bool getLabelsMap( const Mat& labels, const vector<int>& sizes, vector<int>& labelsMap )
+bool getLabelsMap( const Mat& labels, const vector<int>& sizes, vector<int>& labelsMap, bool checkClusterUniq=true )
 {
     size_t total = 0, nclusters = sizes.size();
     for(size_t i = 0; i < sizes.size(); i++)
@@ -158,21 +158,26 @@ bool getLabelsMap( const Mat& labels, const vector<int>& sizes, vector<int>& lab
         startIndex += sizes[clusterIndex];
 
         int cls = maxIdx( count );
-        CV_Assert( !buzy[cls] );
+        if(checkClusterUniq)
+            CV_Assert( !buzy[cls] );
 
         labelsMap[clusterIndex] = cls;
 
         buzy[cls] = true;
     }
-    for(size_t i = 0; i < buzy.size(); i++)
-        if(!buzy[i])
-            return false;
+
+    if(checkClusterUniq)
+    {
+        for(size_t i = 0; i < buzy.size(); i++)
+            if(!buzy[i])
+                return false;
+    }
 
     return true;
 }
 
 static
-bool calcErr( const Mat& labels, const Mat& origLabels, const vector<int>& sizes, float& err, bool labelsEquivalent = true )
+bool calcErr( const Mat& labels, const Mat& origLabels, const vector<int>& sizes, float& err, bool labelsEquivalent = true, bool checkClusterUniq=true )
 {
     err = 0;
     CV_Assert( !labels.empty() && !origLabels.empty() );
@@ -186,7 +191,7 @@ bool calcErr( const Mat& labels, const Mat& origLabels, const vector<int>& sizes
     bool isFlt = labels.type() == CV_32FC1;
     if( !labelsEquivalent )
     {
-        if( !getLabelsMap( labels, sizes, labelsMap ) )
+        if( !getLabelsMap( labels, sizes, labelsMap, checkClusterUniq ) )
             return false;
 
         for( int i = 0; i < labels.rows; i++ )
@@ -376,7 +381,7 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params,
         em.trainM( trainData, *params.probs, labels );
 
     // check train error
-    if( !calcErr( labels, trainLabels, sizes, err , false ) )
+    if( !calcErr( labels, trainLabels, sizes, err , false, false ) )
     {
         ts->printf( cvtest::TS::LOG, "Case index %i : Bad output labels.\n", caseIndex );
         code = cvtest::TS::FAIL_INVALID_OUTPUT;
@@ -396,7 +401,7 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params,
         Mat probs;
         labels.at<int>(i,0) = (int)em.predict( sample, probs, &likelihood );
     }
-    if( !calcErr( labels, testLabels, sizes, err, false ) )
+    if( !calcErr( labels, testLabels, sizes, err, false, false ) )
     {
         ts->printf( cvtest::TS::LOG, "Case index %i : Bad output labels.\n", caseIndex );
         code = cvtest::TS::FAIL_INVALID_OUTPUT;