incorporated several critical fixes in EM implementation from Albert G (ticket #264)
authorVadim Pisarevsky <no@email>
Sat, 20 Nov 2010 12:34:16 +0000 (12:34 +0000)
committerVadim Pisarevsky <no@email>
Sat, 20 Nov 2010 12:34:16 +0000 (12:34 +0000)
modules/ml/src/em.cpp

index 755cd3a..e540968 100644 (file)
@@ -789,8 +789,9 @@ double CvEM::run_em( const CvVectors& train_data )
     int nsamples = train_data.count, dims = train_data.dims, nclusters = params.nclusters;
     double min_variation = FLT_EPSILON;
     double min_det_value = MAX( DBL_MIN, pow( min_variation, dims ));
-    double likelihood_bias = -CV_LOG2PI * (double)nsamples * (double)dims / 2., _log_likelihood = -DBL_MAX;
+    double _log_likelihood = -DBL_MAX;
     int start_step = params.start_step;
+    double sum_max_val;
 
     int i, j, k, n;
     int is_general = 0, is_diagonal = 0, is_spherical = 0;
@@ -912,6 +913,7 @@ double CvEM::run_em( const CvVectors& train_data )
             // e-step: compute probs_ik from means_k, covs_k and weights_k.
             CV_CALL(cvLog( weights, log_weights ));
 
+            sum_max_val = 0.;
             // S_ik = -0.5[log(det(Sigma_k)) + (x_i - mu_k)' Sigma_k^(-1) (x_i - mu_k)] + log(weights_k)
             for( k = 0; k < nclusters; k++ )
             {
@@ -934,14 +936,16 @@ double CvEM::run_em( const CvVectors& train_data )
                         cvGEMM( centered_sample, u, 1, 0, 0, centered_sample, CV_GEMM_B_T );
                     for( j = 0; j < dims; j++ )
                         p += csample[j]*csample[j]*w_data[is_spherical ? 0 : j];
-                    pp[k] = -0.5*p + log_weights->data.db[k];
+                    //pp[k] = -0.5*p + log_weights->data.db[k];
+                    pp[k] = -0.5*(p+CV_LOG2PI * (double)dims) + log_weights->data.db[k];
 
                     // S_ik <- S_ik - max_j S_ij
                     if( k == nclusters - 1 )
                     {
-                        double max_val = 0;
-                        for( j = 0; j < nclusters; j++ )
+                        double max_val = pp[0];
+                        for( j = 1; j < nclusters; j++ )
                             max_val = MAX( max_val, pp[j] );
+                        sum_max_val += max_val;
                         for( j = 0; j < nclusters; j++ )
                             pp[j] -= max_val;
                     }
@@ -953,7 +957,7 @@ double CvEM::run_em( const CvVectors& train_data )
 
             // alpha_ik = exp( S_ik ) / sum_j exp( S_ij ),
             // log_likelihood = sum_i log (sum_j exp(S_ij))
-            for( i = 0, _log_likelihood = likelihood_bias; i < nsamples; i++ )
+            for( i = 0, _log_likelihood = 0; i < nsamples; i++ )
             {
                 double* pp = (double*)(probs->data.ptr + probs->step*i), sum = 0;
                 for( j = 0; j < nclusters; j++ )
@@ -966,9 +970,11 @@ double CvEM::run_em( const CvVectors& train_data )
                 }
                 _log_likelihood -= log( sum );
             }
+            _log_likelihood+=sum_max_val;
 
             // check termination criteria
-            if( fabs( (_log_likelihood - prev_log_likelihood) / prev_log_likelihood ) < params.term_crit.epsilon )
+            //if( fabs( (_log_likelihood - prev_log_likelihood) / prev_log_likelihood ) < params.term_crit.epsilon )
+            if( fabs( (_log_likelihood - prev_log_likelihood)  ) < params.term_crit.epsilon )
                 break;
             prev_log_likelihood = _log_likelihood;
         }