1 #include "opencv2/highgui.hpp"
2 #include "opencv2/legacy.hpp"
6 int main( int /*argc*/, char** /*argv*/ )
9 const int N1 = (int)sqrt((double)N);
10 const Scalar colors[] =
12 Scalar(0,0,255), Scalar(0,255,0),
13 Scalar(0,255,255),Scalar(255,255,0)
18 Mat samples( nsamples, 2, CV_32FC1 );
20 Mat img = Mat::zeros( Size( 500, 500 ), CV_8UC3 );
21 Mat sample( 1, 2, CV_32FC1 );
25 samples = samples.reshape(2, 0);
26 for( i = 0; i < N; i++ )
28 // form the training samples
29 Mat samples_part = samples.rowRange(i*nsamples/N, (i+1)*nsamples/N );
31 Scalar mean(((i%N1)+1)*img.rows/(N1+1),
32 ((i/N1)+1)*img.rows/(N1+1));
34 randn( samples_part, mean, sigma );
36 samples = samples.reshape(1, 0);
38 // initialize model parameters
41 params.weights = NULL;
44 params.cov_mat_type = CvEM::COV_MAT_SPHERICAL;
45 params.start_step = CvEM::START_AUTO_STEP;
46 params.term_crit.max_iter = 300;
47 params.term_crit.epsilon = 0.1;
48 params.term_crit.type = TermCriteria::COUNT|TermCriteria::EPS;
51 em_model.train( samples, Mat(), params, &labels );
54 // the piece of code shows how to repeatedly optimize the model
55 // with less-constrained parameters
56 //(COV_MAT_DIAGONAL instead of COV_MAT_SPHERICAL)
57 // when the output of the first stage is used as input for the second one.
59 params.cov_mat_type = CvEM::COV_MAT_DIAGONAL;
60 params.start_step = CvEM::START_E_STEP;
61 params.means = em_model.get_means();
62 params.covs = (const CvMat**)em_model.get_covs();
63 params.weights = em_model.get_weights();
65 em_model2.train( samples, Mat(), params, &labels );
66 // to use em_model2, replace em_model.predict()
67 // with em_model2.predict() below
69 // classify every image pixel
70 for( i = 0; i < img.rows; i++ )
72 for( j = 0; j < img.cols; j++ )
74 sample.at<float>(0) = (float)j;
75 sample.at<float>(1) = (float)i;
76 int response = cvRound(em_model.predict( sample ));
77 Scalar c = colors[response];
79 circle( img, Point(j, i), 1, c*0.75, FILLED );
83 //draw the clustered samples
84 for( i = 0; i < nsamples; i++ )
86 Point pt(cvRound(samples.at<float>(i, 0)), cvRound(samples.at<float>(i, 1)));
87 circle( img, pt, 1, colors[labels.at<int>(i)], FILLED );
90 imshow( "EM-clustering result", img );