added smoke test on EM, fixed EM reading #1570 (thanks to mr.pppoe),
authorMaria Dimashova <no@email>
Thu, 29 Mar 2012 08:55:43 +0000 (08:55 +0000)
committerMaria Dimashova <no@email>
Thu, 29 Mar 2012 08:55:43 +0000 (08:55 +0000)
modules/ml/src/em.cpp
modules/ml/test/test_emknearestkmeans.cpp

index 5f98b29..9ec262d 100644 (file)
@@ -141,8 +141,6 @@ void CvEM::read( CvFileStorage* fs, CvFileNode* node )
     CvFileNode* em_node = 0;
     CvFileNode* tmp_node = 0;
     CvSeq* seq = 0;
-    CvMat **tmp_covs = 0;
-    CvMat **tmp_cov_rotate_mats = 0;
 
     read_params( fs, node );
 
@@ -156,13 +154,10 @@ void CvEM::read( CvFileStorage* fs, CvFileNode* node )
     CV_CALL( inv_eigen_values = (CvMat*)cvReadByName( fs, em_node, "inv_eigen_values" ));
 
     // Size of all the following data
-    data_size = params.nclusters*2*sizeof(CvMat*);
-
-    CV_CALL( tmp_covs = (CvMat**)cvAlloc( data_size ));
-    memset( tmp_covs, 0, data_size );
-
-    tmp_cov_rotate_mats = tmp_covs + params.nclusters;
+    data_size = params.nclusters*sizeof(CvMat*);
 
+    CV_CALL( covs = (CvMat**)cvAlloc( data_size ));
+    memset( covs, 0, data_size );
     CV_CALL( tmp_node = cvGetFileNodeByName( fs, em_node, "covs" ));
     seq = tmp_node->data.seq;
     if( !CV_NODE_IS_SEQ(tmp_node->tag) || seq->total != params.nclusters)
@@ -170,24 +165,23 @@ void CvEM::read( CvFileStorage* fs, CvFileNode* node )
     CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
     for( int i = 0; i < params.nclusters; i++ )
     {
-        CV_CALL( tmp_covs[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
+        CV_CALL( covs[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
         CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
     }
 
+    CV_CALL( cov_rotate_mats = (CvMat**)cvAlloc( data_size ));
+    memset( cov_rotate_mats, 0, data_size );
     CV_CALL( tmp_node = cvGetFileNodeByName( fs, em_node, "cov_rotate_mats" ));
     seq = tmp_node->data.seq;
     if( !CV_NODE_IS_SEQ(tmp_node->tag) || seq->total != params.nclusters)
-        CV_ERROR( CV_StsParseError, "Missing or invalid sequence of rotated cov. matrices" );
+        CV_ERROR( CV_StsParseError, "Missing or invalid sequence of covariance matrices" );
     CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
     for( int i = 0; i < params.nclusters; i++ )
     {
-        CV_CALL( tmp_cov_rotate_mats[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
+        CV_CALL( cov_rotate_mats[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
         CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
     }
 
-    covs = tmp_covs;
-    cov_rotate_mats = tmp_cov_rotate_mats;
-
     ok = true;
     __END__;
 
@@ -862,10 +856,10 @@ void CvEM::kmeans( const CvVectors& train_data, int nclusters, CvMat* labels,
 {
     int i, nsamples = train_data.count, dims = train_data.dims;
     cv::Ptr<CvMat> temp_mat = cvCreateMat(nsamples, dims, CV_32F);
-    
+
     for( i = 0; i < nsamples; i++ )
         memcpy( temp_mat->data.ptr + temp_mat->step*i, train_data.data.fl[i], dims*sizeof(float));
-    
+
     cvKMeans2(temp_mat, nclusters, labels, termcrit, 10);
 }
 
@@ -1240,20 +1234,20 @@ CvEM::CvEM( const Mat& samples, const Mat& sample_idx, CvEMParams params )
 {
     means = weights = probs = inv_eigen_values = log_weight_div_det = 0;
     covs = cov_rotate_mats = 0;
-    
+
     // just invoke the train() method
     train(samples, sample_idx, params);
-}    
+}
 
 bool CvEM::train( const Mat& _samples, const Mat& _sample_idx,
                  CvEMParams _params, Mat* _labels )
 {
     CvMat samples = _samples, sidx = _sample_idx, labels, *plabels = 0;
-    
+
     if( _labels )
     {
         int nsamples = sidx.data.ptr ? sidx.rows : samples.rows;
-        
+
         if( !(_labels->data && _labels->type() == CV_32SC1 &&
               (_labels->cols == 1 || _labels->rows == 1) &&
               _labels->cols + _labels->rows - 1 == nsamples) )
@@ -1267,7 +1261,7 @@ float
 CvEM::predict( const Mat& _sample, Mat* _probs ) const
 {
     CvMat sample = _sample, probs, *pprobs = 0;
-    
+
     if( _probs )
     {
         int nclusters = params.nclusters;
index 93bbadf..0576e73 100644 (file)
@@ -332,6 +332,82 @@ void CV_EMTest::run( int /*start_from*/ )
     ts->set_failed_test_info( code );
 }
 
+class CV_EMTest_Smoke : public cvtest::BaseTest {
+public:
+    CV_EMTest_Smoke() {}
+protected:
+    virtual void run( int /*start_from*/ )
+    {
+        int code = cvtest::TS::OK;
+        CvEM em;
+
+        Mat samples = Mat(3,2,CV_32F);
+        samples.at<float>(0,0) = 1;
+        samples.at<float>(1,0) = 2;
+        samples.at<float>(2,0) = 3;
+
+        CvEMParams params;
+        params.nclusters = 2;
+
+        Mat labels;
+
+        em.train(samples, Mat(), params, &labels);
+
+        Mat firstResult(samples.rows, 1, CV_32FC1);
+        for( int i = 0; i < samples.rows; i++)
+            firstResult.at<float>(i) = em.predict( samples.row(i) );
+
+        // Write out
+        string filename = tempfile() + ".xml";
+        {
+            FileStorage fs = FileStorage(filename, FileStorage::WRITE);
+
+            try
+            {
+                em.write(fs.fs, "EM");
+            }
+            catch(...)
+            {
+                ts->printf( cvtest::TS::LOG, "Crash in write method.\n" );
+                ts->set_failed_test_info( cvtest::TS::FAIL_EXCEPTION );
+            }
+        }
+
+        em.clear();
+
+        // Read in
+        {
+            FileStorage fs = FileStorage(filename, FileStorage::READ);
+            FileNode fileNode = fs["EM"];
+
+            try
+            {
+                em.read(const_cast<CvFileStorage*>(fileNode.fs), const_cast<CvFileNode*>(fileNode.node));
+            }
+            catch(...)
+            {
+                ts->printf( cvtest::TS::LOG, "Crash in read method.\n" );
+                ts->set_failed_test_info( cvtest::TS::FAIL_EXCEPTION );
+            }
+        }
+
+        remove( filename.c_str() );
+
+        int errCaseCount = 0;
+        for( int i = 0; i < samples.rows; i++)
+            errCaseCount = std::abs(em.predict(samples.row(i)) - firstResult.at<float>(i)) < FLT_EPSILON ? 0 : 1;
+
+        if( errCaseCount > 0 )
+        {
+            ts->printf( cvtest::TS::LOG, "Different prediction results before writeing and after reading (errCaseCount=%d).\n", errCaseCount );
+            code = cvtest::TS::FAIL_BAD_ACCURACY;
+        }
+
+        ts->set_failed_test_info( code );
+    }
+};
+
 TEST(ML_KMeans, accuracy) { CV_KMeansTest test; test.safe_run(); }
 TEST(ML_KNearest, accuracy) { CV_KNearestTest test; test.safe_run(); }
 TEST(ML_EM, accuracy) { CV_EMTest test; test.safe_run(); }
+TEST(ML_EM, smoke) { CV_EMTest_Smoke test; test.safe_run(); }