refactored train and predict methods of em
authorMaria Dimashova <no@email>
Tue, 17 Apr 2012 06:29:40 +0000 (06:29 +0000)
committerMaria Dimashova <no@email>
Tue, 17 Apr 2012 06:29:40 +0000 (06:29 +0000)
modules/contrib/src/hybridtracker.cpp
modules/legacy/include/opencv2/legacy/legacy.hpp
modules/legacy/src/em.cpp
modules/ml/include/opencv2/ml/ml.hpp
modules/ml/src/em.cpp
modules/ml/test/test_emknearestkmeans.cpp
samples/cpp/points_classifier.cpp

index b499611..b4ed708 100644 (file)
@@ -213,7 +213,7 @@ void CvHybridTracker::updateTrackerWithEM(Mat image) {
     cv::Mat lbls;
     
     EM em_model(1, EM::COV_MAT_SPHERICAL, TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 10000, 0.001));
-    em_model.train(cvarrToMat(samples), lbls);
+    em_model.train(cvarrToMat(samples), noArray(), lbls);
     if(labels)
         lbls.copyTo(cvarrToMat(labels));
 
index 8df5a23..36e4303 100644 (file)
@@ -1826,7 +1826,7 @@ public:
     CV_WRAP cv::Mat getWeights() const;
     CV_WRAP cv::Mat getProbs() const;
 
-    CV_WRAP inline double getLikelihood() const { return emObj.isTrained() ? likelihood : DBL_MAX; }
+    CV_WRAP inline double getLikelihood() const { return emObj.isTrained() ? logLikelihood : DBL_MAX; }
 #endif
 
     CV_WRAP virtual void clear();
@@ -1847,7 +1847,7 @@ protected:
 
     cv::EM emObj;
     cv::Mat probs;
-    double likelihood;
+    double logLikelihood;
 
     CvMat meansHdr;
     std::vector<CvMat> covsHdrs;
index 543df09..54e8a1b 100644 (file)
@@ -56,12 +56,12 @@ CvEMParams::CvEMParams( int _nclusters, int _cov_mat_type, int _start_step,
                         probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit)
 {}
 
-CvEM::CvEM() : likelihood(DBL_MAX)
+CvEM::CvEM() : logLikelihood(DBL_MAX)
 {
 }
 
 CvEM::CvEM( const CvMat* samples, const CvMat* sample_idx,
-            CvEMParams params, CvMat* labels ) : likelihood(DBL_MAX)
+            CvEMParams params, CvMat* labels ) : logLikelihood(DBL_MAX)
 {
     train(samples, sample_idx, params, labels);
 }
@@ -96,16 +96,14 @@ void CvEM::write( CvFileStorage* _fs, const char* name ) const
 
 double CvEM::calcLikelihood( const Mat &input_sample ) const
 {
-    double likelihood;
-    emObj.predict(input_sample, noArray(), &likelihood);
-    return likelihood;
+    return emObj.predict(input_sample)[0];
 }
 
 float
 CvEM::predict( const CvMat* _sample, CvMat* _probs ) const
 {
     Mat prbs0 = cvarrToMat(_probs), prbs = prbs0, sample = cvarrToMat(_sample);
-       int cls = emObj.predict(sample, _probs ? _OutputArray(prbs) : cv::noArray());
+    int cls = static_cast<int>(emObj.predict(sample, _probs ? _OutputArray(prbs) : cv::noArray())[1]);
     if(_probs)
     {
         if( prbs.data != prbs0.data )
@@ -203,29 +201,27 @@ bool CvEM::train( const Mat& _samples, const Mat& _sample_idx,
                  CvEMParams _params, Mat* _labels )
 {
     CV_Assert(_sample_idx.empty());
-    Mat prbs, weights, means, likelihoods;
+    Mat prbs, weights, means, logLikelihoods;
     std::vector<Mat> covsHdrs;
     init_params(_params, prbs, weights, means, covsHdrs);
 
     emObj = EM(_params.nclusters, _params.cov_mat_type, _params.term_crit);
     bool isOk = false;
     if( _params.start_step == EM::START_AUTO_STEP )
-               isOk = emObj.train(_samples, _labels ? _OutputArray(*_labels) : cv::noArray(),
-                    probs, likelihoods);
+        isOk = emObj.train(_samples,
+                           logLikelihoods, _labels ? _OutputArray(*_labels) : cv::noArray(), probs);
     else if( _params.start_step == EM::START_E_STEP )
         isOk = emObj.trainE(_samples, means, covsHdrs, weights,
-                                        _labels ? _OutputArray(*_labels) : cv::noArray(),
-                     probs, likelihoods);
+                            logLikelihoods, _labels ? _OutputArray(*_labels) : cv::noArray(), probs);
     else if( _params.start_step == EM::START_M_STEP )
         isOk = emObj.trainM(_samples, prbs,
-                                       _labels ? _OutputArray(*_labels) : cv::noArray(),
-                    probs, likelihoods);
+                            logLikelihoods, _labels ? _OutputArray(*_labels) : cv::noArray(), probs);
     else
         CV_Error(CV_StsBadArg, "Bad start type of EM algorithm");
     
     if(isOk)
     {
-        likelihoods = sum(likelihoods).val[0];
+        logLikelihood = sum(logLikelihoods).val[0];
         set_mat_hdrs();
     }
 
@@ -235,8 +231,7 @@ bool CvEM::train( const Mat& _samples, const Mat& _sample_idx,
 float
 CvEM::predict( const Mat& _sample, Mat* _probs ) const
 {
-       int cls = emObj.predict(_sample, _probs ? _OutputArray(*_probs) : cv::noArray());
-    return (float)cls;
+    return static_cast<float>(emObj.predict(_sample, _probs ? _OutputArray(*_probs) : cv::noArray())[1]);
 }
 
 int CvEM::getNClusters() const
index 16aaae5..da83cf4 100644 (file)
@@ -577,27 +577,26 @@ public:
     CV_WRAP virtual void clear();
 
     CV_WRAP virtual bool train(InputArray samples,
+                       OutputArray logLikelihoods=noArray(),
                        OutputArray labels=noArray(),
-                       OutputArray probs=noArray(),
-                       OutputArray logLikelihoods=noArray());
+                       OutputArray probs=noArray());
     
     CV_WRAP virtual bool trainE(InputArray samples,
                         InputArray means0,
                         InputArray covs0=noArray(),
                         InputArray weights0=noArray(),
+                        OutputArray logLikelihoods=noArray(),
                         OutputArray labels=noArray(),
-                        OutputArray probs=noArray(),
-                        OutputArray logLikelihoods=noArray());
+                        OutputArray probs=noArray());
     
     CV_WRAP virtual bool trainM(InputArray samples,
                         InputArray probs0,
+                        OutputArray logLikelihoods=noArray(),
                         OutputArray labels=noArray(),
-                        OutputArray probs=noArray(),
-                        OutputArray logLikelihoods=noArray());
+                        OutputArray probs=noArray());
     
-    CV_WRAP int predict(InputArray sample,
-                OutputArray probs=noArray(),
-                CV_OUT double* logLikelihood=0) const;
+    CV_WRAP Vec2d predict(InputArray sample,
+                OutputArray probs=noArray()) const;
 
     CV_WRAP bool isTrained() const;
 
@@ -613,9 +612,9 @@ protected:
                               const Mat* weights0);
 
     bool doTrain(int startStep,
+                 OutputArray logLikelihoods,
                  OutputArray labels,
-                 OutputArray probs,
-                 OutputArray logLikelihoods);
+                 OutputArray probs);
     virtual void eStep();
     virtual void mStep();
 
@@ -623,7 +622,7 @@ protected:
     void decomposeCovs();
     void computeLogWeightDivDet();
 
-    void computeProbabilities(const Mat& sample, int& label, Mat* probs, double* logLikelihood) const;
+    Vec2d computeProbabilities(const Mat& sample, Mat* probs) const;
 
     // all inner matrices have type CV_64FC1
     CV_PROP_RW int nclusters;
index 545e107..d827d56 100644 (file)
@@ -81,22 +81,22 @@ void EM::clear()
 
     
 bool EM::train(InputArray samples,
+               OutputArray logLikelihoods,
                OutputArray labels,
-               OutputArray probs,
-               OutputArray logLikelihoods)
+               OutputArray probs)
 {
     Mat samplesMat = samples.getMat();
     setTrainData(START_AUTO_STEP, samplesMat, 0, 0, 0, 0);
-    return doTrain(START_AUTO_STEP, labels, probs, logLikelihoods);
+    return doTrain(START_AUTO_STEP, logLikelihoods, labels, probs);
 }
 
 bool EM::trainE(InputArray samples,
                 InputArray _means0,
                 InputArray _covs0,
                 InputArray _weights0,
+                OutputArray logLikelihoods,
                 OutputArray labels,
-                OutputArray probs,
-                OutputArray logLikelihoods)
+                OutputArray probs)
 {
     Mat samplesMat = samples.getMat();
     vector<Mat> covs0;
@@ -106,24 +106,24 @@ bool EM::trainE(InputArray samples,
 
     setTrainData(START_E_STEP, samplesMat, 0, !_means0.empty() ? &means0 : 0,
                  !_covs0.empty() ? &covs0 : 0, _weights0.empty() ? &weights0 : 0);
-    return doTrain(START_E_STEP, labels, probs, logLikelihoods);
+    return doTrain(START_E_STEP, logLikelihoods, labels, probs);
 }
 
 bool EM::trainM(InputArray samples,
                 InputArray _probs0,
+                OutputArray logLikelihoods,
                 OutputArray labels,
-                OutputArray probs,
-                OutputArray logLikelihoods)
+                OutputArray probs)
 {
     Mat samplesMat = samples.getMat();
     Mat probs0 = _probs0.getMat();
     
     setTrainData(START_M_STEP, samplesMat, !_probs0.empty() ? &probs0 : 0, 0, 0, 0);
-    return doTrain(START_M_STEP, labels, probs, logLikelihoods);
+    return doTrain(START_M_STEP, logLikelihoods, labels, probs);
 }
 
     
-int EM::predict(InputArray _sample, OutputArray _probs, double* logLikelihood) const
+Vec2d EM::predict(InputArray _sample, OutputArray _probs) const
 {
     Mat sample = _sample.getMat();
     CV_Assert(isTrained());
@@ -136,16 +136,14 @@ int EM::predict(InputArray _sample, OutputArray _probs, double* logLikelihood) c
         sample = tmp;
     }
 
-    int label;
     Mat probs;
     if( _probs.needed() )
     {
         _probs.create(1, nclusters, CV_64FC1);
         probs = _probs.getMat();
     }
-    computeProbabilities(sample, label, !probs.empty() ? &probs : 0, logLikelihood);
 
-    return label;
+    return computeProbabilities(sample, !probs.empty() ? &probs : 0);
 }
 
 bool EM::isTrained() const
@@ -394,7 +392,7 @@ void EM::computeLogWeightDivDet()
     }
 }
 
-bool EM::doTrain(int startStep, OutputArray labels, OutputArray probs, OutputArray logLikelihoods)
+bool EM::doTrain(int startStep, OutputArray logLikelihoods, OutputArray labels, OutputArray probs)
 {
     int dim = trainSamples.cols;
     // Precompute the empty initial train data in the cases of EM::START_E_STEP and START_AUTO_STEP
@@ -472,7 +470,7 @@ bool EM::doTrain(int startStep, OutputArray labels, OutputArray probs, OutputArr
     return true;
 }
 
-void EM::computeProbabilities(const Mat& sample, int& label, Mat* probs, double* logLikelihood) const
+Vec2d EM::computeProbabilities(const Mat& sample, Mat* probs) const
 {
     // L_ik = log(weight_k) - 0.5 * log(|det(cov_k)|) - 0.5 *(x_i - mean_k)' cov_k^(-1) (x_i - mean_k)]
     // q = arg(max_k(L_ik))
@@ -488,7 +486,7 @@ void EM::computeProbabilities(const Mat& sample, int& label, Mat* probs, double*
     int dim = sample.cols;
 
     Mat L(1, nclusters, CV_64FC1);
-    label = 0;
+    int label = 0;
     for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
     {
         const Mat centeredSample = sample - means.row(clusterIndex);
@@ -511,9 +509,6 @@ void EM::computeProbabilities(const Mat& sample, int& label, Mat* probs, double*
             label = clusterIndex;
     }
 
-    if(!probs && !logLikelihood)
-        return;
-
     double maxLVal = L.at<double>(label);
     Mat expL_Lmax = L; // exp(L_ij - L_iq)
     for(int i = 0; i < L.cols; i++)
@@ -528,8 +523,11 @@ void EM::computeProbabilities(const Mat& sample, int& label, Mat* probs, double*
         expL_Lmax.copyTo(*probs);
     }
 
-    if(logLikelihood)
-        *logLikelihood = std::log(expDiffSum)  + maxLVal - 0.5 * dim * CV_LOG2PI;
+    Vec2d res;
+    res[0] = std::log(expDiffSum)  + maxLVal - 0.5 * dim * CV_LOG2PI;
+    res[1] = label;
+
+    return res;
 }
 
 void EM::eStep()
@@ -547,8 +545,9 @@ void EM::eStep()
     for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
     {
         Mat sampleProbs = trainProbs.row(sampleIndex);
-        computeProbabilities(trainSamples.row(sampleIndex), trainLabels.at<int>(sampleIndex),
-                             &sampleProbs, &trainLogLikelihoods.at<double>(sampleIndex));
+        Vec2d res = computeProbabilities(trainSamples.row(sampleIndex), &sampleProbs);
+        trainLogLikelihoods.at<double>(sampleIndex) = res[0];
+        trainLabels.at<int>(sampleIndex) = static_cast<int>(res[1]);
     }
 }
 
index 0990ca0..3dedb3a 100644 (file)
@@ -373,11 +373,11 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params,
 
     cv::EM em(params.nclusters, params.covMatType, params.termCrit);
     if( params.startStep == EM::START_AUTO_STEP )
-        em.train( trainData, labels );
+        em.train( trainData, noArray(), labels );
     else if( params.startStep == EM::START_E_STEP )
-        em.trainE( trainData, *params.means, *params.covs, *params.weights, labels );
+        em.trainE( trainData, *params.means, *params.covs, *params.weights, noArray(), labels );
     else if( params.startStep == EM::START_M_STEP )
-        em.trainM( trainData, *params.probs, labels );
+        em.trainM( trainData, *params.probs, noArray(), labels );
 
     // check train error
     if( !calcErr( labels, trainLabels, sizes, err , false, false ) )
@@ -396,9 +396,8 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params,
     for( int i = 0; i < testData.rows; i++ )
     {
         Mat sample = testData.row(i);
-        double likelihood = 0;
         Mat probs;
-        labels.at<int>(i,0) = (int)em.predict( sample, probs, &likelihood );
+        labels.at<int>(i) = static_cast<int>(em.predict( sample, probs )[1]);
     }
     if( !calcErr( labels, testLabels, sizes, err, false, false ) )
     {
@@ -523,7 +522,7 @@ protected:
 
         Mat firstResult(samples.rows, 1, CV_32SC1);
         for( int i = 0; i < samples.rows; i++)
-            firstResult.at<int>(i) = em.predict(samples.row(i));
+            firstResult.at<int>(i) = static_cast<int>(em.predict(samples.row(i))[1]);
 
         // Write out
         string filename = tempfile() + ".xml";
@@ -564,7 +563,7 @@ protected:
 
         int errCaseCount = 0;
         for( int i = 0; i < samples.rows; i++)
-            errCaseCount = std::abs(em.predict(samples.row(i)) - firstResult.at<int>(i)) < FLT_EPSILON ? 0 : 1;
+            errCaseCount = std::abs(em.predict(samples.row(i))[1] - firstResult.at<int>(i)) < FLT_EPSILON ? 0 : 1;
 
         if( errCaseCount > 0 )
         {
@@ -637,10 +636,9 @@ protected:
         const double lambda = 1.;
         for(int i = 0; i < samples.rows; i++)
         {
-            double sampleLogLikelihoods0 = 0, sampleLogLikelihoods1 = 0;
             Mat sample = samples.row(i);
-            model0.predict(sample, noArray(), &sampleLogLikelihoods0);
-            model1.predict(sample, noArray(), &sampleLogLikelihoods1);
+            double sampleLogLikelihoods0 = model0.predict(sample)[0];
+            double sampleLogLikelihoods1 = model1.predict(sample)[0];
 
             int classID = sampleLogLikelihoods0 >= lambda * sampleLogLikelihoods1 ? 0 : 1;
 
index e00141c..2567ba8 100644 (file)
@@ -478,7 +478,7 @@ void find_decision_boundary_EM()
             for(size_t modelIndex = 0; modelIndex < em_models.size(); modelIndex++)
             {
                 if(em_models[modelIndex].isTrained())
-                    em_models[modelIndex].predict( testSample, noArray(), &logLikelihoods.at<double>(modelIndex) );
+                    logLikelihoods.at<double>(modelIndex) = em_models[modelIndex].predict(testSample)[0];
             }
             Point maxLoc;
             minMaxLoc(logLikelihoods, 0, 0, 0, &maxLoc);