fixed em test
authorMaria Dimashova <no@email>
Thu, 29 Mar 2012 07:01:57 +0000 (07:01 +0000)
committerMaria Dimashova <no@email>
Thu, 29 Mar 2012 07:01:57 +0000 (07:01 +0000)
modules/ml/test/test_emknearestkmeans.cpp

index b9ddfca..93bbadf 100644 (file)
@@ -87,8 +87,10 @@ void generateData( Mat& data, Mat& labels, const vector<int>& sizes, const vecto
             r =  r * (*cit) + *mit; 
             if( labelType == CV_32FC1 )
                 labels.at<float>(p, 0) = (float)l;
-            else
+            else if( labelType == CV_32SC1 )
                 labels.at<int>(p, 0) = l;
+            else
+                CV_DbgAssert(0);
         }
     }
 }
@@ -201,20 +203,23 @@ void CV_KMeansTest::run( int /*start_from*/ )
     generateData( data, labels, sizes, means, covs, CV_32SC1 );
     
     int code = cvtest::TS::OK;
+    float err;
     Mat bestLabels;
     // 1. flag==KMEANS_PP_CENTERS
     kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_PP_CENTERS, noArray() );
-    if( calcErr( bestLabels, labels, sizes, false ) > 0.01f )
+    err = calcErr( bestLabels, labels, sizes, false );
+    if( err > 0.01f )
     {
-        ts->printf( cvtest::TS::LOG, "bad accuracy if flag==KMEANS_PP_CENTERS" );
+        ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_PP_CENTERS.\n", err );
         code = cvtest::TS::FAIL_BAD_ACCURACY;
     }
 
     // 2. flag==KMEANS_RANDOM_CENTERS
     kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_RANDOM_CENTERS, noArray() );
-    if( calcErr( bestLabels, labels, sizes, false ) > 0.01f )
+    err = calcErr( bestLabels, labels, sizes, false );
+    if( err > 0.01f )
     {
-        ts->printf( cvtest::TS::LOG, "bad accuracy if flag==KMEANS_PP_CENTERS" );
+        ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_PP_CENTERS.\n", err );
         code = cvtest::TS::FAIL_BAD_ACCURACY;
     }
 
@@ -224,9 +229,10 @@ void CV_KMeansTest::run( int /*start_from*/ )
     for( int i = 0; i < 0.5f * pointsCount; i++ )
         bestLabels.at<int>( rng.next() % pointsCount, 0 ) = rng.next() % 3;
     kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_USE_INITIAL_LABELS, noArray() );
-    if( calcErr( bestLabels, labels, sizes, false ) > 0.01f )
+    err = calcErr( bestLabels, labels, sizes, false );
+    if( err > 0.01f )
     {
-        ts->printf( cvtest::TS::LOG, "bad accuracy if flag==KMEANS_PP_CENTERS" );
+        ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_PP_CENTERS.\n", err );
         code = cvtest::TS::FAIL_BAD_ACCURACY;
     }
 
@@ -261,9 +267,10 @@ void CV_KNearestTest::run( int /*start_from*/ )
     KNearest knearest;
     knearest.train( trainData, trainLabels );
     knearest.find_nearest( testData, 4, &bestLabels );
-    if( calcErr( bestLabels, testLabels, sizes, true ) > 0.01f )
+    float err = calcErr( bestLabels, testLabels, sizes, true );
+    if( err > 0.01f )
     {
-        ts->printf( cvtest::TS::LOG, "bad accuracy on test data" );
+        ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on test data.\n", err );
         code = cvtest::TS::FAIL_BAD_ACCURACY;
     }
     ts->set_failed_test_info( code );
@@ -294,15 +301,17 @@ void CV_EMTest::run( int /*start_from*/ )
     generateData( testData, testLabels, sizes, means, covs, CV_32SC1 );
 
     int code = cvtest::TS::OK;
+    float err;
     ExpectationMaximization em;
     CvEMParams params;
     params.nclusters = 3;
     em.train( trainData, Mat(), params, &bestLabels );
 
     // check train error
-    if( calcErr( bestLabels, trainLabels, sizes, true ) > 0.002f )
+    err = calcErr( bestLabels, trainLabels, sizes, false );
+    if( err > 0.002f )
     {
-        ts->printf( cvtest::TS::LOG, "bad accuracy on train data" );
+        ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on train data.\n", err );
         code = cvtest::TS::FAIL_BAD_ACCURACY;
     }
 
@@ -313,9 +322,10 @@ void CV_EMTest::run( int /*start_from*/ )
         Mat sample( 1, testData.cols, CV_32FC1, testData.ptr<float>(i));
         bestLabels.at<int>(i,0) = (int)em.predict( sample, 0 );
     }
-    if( calcErr( bestLabels, testLabels, sizes, true ) > 0.005f )
+    err = calcErr( bestLabels, testLabels, sizes, false );
+    if( err > 0.005f )
     {
-        ts->printf( cvtest::TS::LOG, "bad accuracy on test data" );
+        ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on test data.\n", err );
         code = cvtest::TS::FAIL_BAD_ACCURACY;
     }
     
@@ -324,4 +334,4 @@ void CV_EMTest::run( int /*start_from*/ )
 
 TEST(ML_KMeans, accuracy) { CV_KMeansTest test; test.safe_run(); }
 TEST(ML_KNearest, accuracy) { CV_KNearestTest test; test.safe_run(); }
-TEST(ML_EMTest, accuracy) { CV_EMTest test; test.safe_run(); }
+TEST(ML_EM, accuracy) { CV_EMTest test; test.safe_run(); }