Vadim, Maria, Alex, Andrey and I fixed the EM algorithm
authorIlya Lysenkov <no@email>
Tue, 10 Apr 2012 12:45:07 +0000 (12:45 +0000)
committerIlya Lysenkov <no@email>
Tue, 10 Apr 2012 12:45:07 +0000 (12:45 +0000)
modules/ml/src/em.cpp

index f2f3520..d5b2e29 100644 (file)
@@ -86,7 +86,8 @@ bool EM::train(InputArray samples,
                OutputArray probs,
                OutputArray logLikelihoods)
 {
-    setTrainData(START_AUTO_STEP, samples.getMat(), 0, 0, 0, 0);
+    Mat samplesMat = samples.getMat();
+    setTrainData(START_AUTO_STEP, samplesMat, 0, 0, 0, 0);
     return doTrain(START_AUTO_STEP, labels, probs, logLikelihoods);
 }
 
@@ -98,12 +99,13 @@ bool EM::trainE(InputArray samples,
                 OutputArray probs,
                 OutputArray logLikelihoods)
 {
+    Mat samplesMat = samples.getMat();
     vector<Mat> covs0;
     _covs0.getMatVector(covs0);
     
     Mat means0 = _means0.getMat(), weights0 = _weights0.getMat();
 
-    setTrainData(START_E_STEP, samples.getMat(), 0, !_means0.empty() ? &means0 : 0,
+    setTrainData(START_E_STEP, samplesMat, 0, !_means0.empty() ? &means0 : 0,
                  !_covs0.empty() ? &covs0 : 0, _weights0.empty() ? &weights0 : 0);
     return doTrain(START_E_STEP, labels, probs, logLikelihoods);
 }
@@ -114,9 +116,10 @@ bool EM::trainM(InputArray samples,
                 OutputArray probs,
                 OutputArray logLikelihoods)
 {
+    Mat samplesMat = samples.getMat();
     Mat probs0 = _probs0.getMat();
     
-    setTrainData(START_M_STEP, samples.getMat(), !_probs0.empty() ? &probs0 : 0, 0, 0, 0);
+    setTrainData(START_M_STEP, samplesMat, !_probs0.empty() ? &probs0 : 0, 0, 0, 0);
     return doTrain(START_M_STEP, labels, probs, logLikelihoods);
 }
 
@@ -337,7 +340,11 @@ void EM::clusterTrainSamples()
 
     CV_Assert(meansFlt.type() == CV_32FC1);
     if(trainSamples.type() != CV_64FC1)
-        trainSamplesFlt.convertTo(trainSamples, CV_64FC1);
+    {
+        Mat trainSamplesBuffer;
+        trainSamplesFlt.convertTo(trainSamplesBuffer, CV_64FC1);
+        trainSamples = trainSamplesBuffer;
+    }
     meansFlt.convertTo(means, CV_64FC1);
 
     // Compute weights and covs