fixed warnings
[profile/ivi/opencv.git] / samples / cpp / em.cpp
1 #include "opencv2/highgui.hpp"
2 #include "opencv2/ml.hpp"
3
4 using namespace cv;
5 using namespace cv::ml;
6
7 int main( int /*argc*/, char** /*argv*/ )
8 {
9     const int N = 4;
10     const int N1 = (int)sqrt((double)N);
11     const Scalar colors[] =
12     {
13         Scalar(0,0,255), Scalar(0,255,0),
14         Scalar(0,255,255),Scalar(255,255,0)
15     };
16
17     int i, j;
18     int nsamples = 100;
19     Mat samples( nsamples, 2, CV_32FC1 );
20     Mat labels;
21     Mat img = Mat::zeros( Size( 500, 500 ), CV_8UC3 );
22     Mat sample( 1, 2, CV_32FC1 );
23
24     samples = samples.reshape(2, 0);
25     for( i = 0; i < N; i++ )
26     {
27         // form the training samples
28         Mat samples_part = samples.rowRange(i*nsamples/N, (i+1)*nsamples/N );
29
30         Scalar mean(((i%N1)+1)*img.rows/(N1+1),
31                     ((i/N1)+1)*img.rows/(N1+1));
32         Scalar sigma(30,30);
33         randn( samples_part, mean, sigma );
34     }
35     samples = samples.reshape(1, 0);
36
37     // cluster the data
38     Ptr<EM> em_model = EM::train( samples, noArray(), labels, noArray(),
39             EM::Params(N, EM::COV_MAT_SPHERICAL,
40                        TermCriteria(TermCriteria::COUNT+TermCriteria::EPS, 300, 0.1)));
41
42     // classify every image pixel
43     for( i = 0; i < img.rows; i++ )
44     {
45         for( j = 0; j < img.cols; j++ )
46         {
47             sample.at<float>(0) = (float)j;
48             sample.at<float>(1) = (float)i;
49             int response = cvRound(em_model->predict2( sample, noArray() )[1]);
50             Scalar c = colors[response];
51
52             circle( img, Point(j, i), 1, c*0.75, FILLED );
53         }
54     }
55
56     //draw the clustered samples
57     for( i = 0; i < nsamples; i++ )
58     {
59         Point pt(cvRound(samples.at<float>(i, 0)), cvRound(samples.at<float>(i, 1)));
60         circle( img, pt, 1, colors[labels.at<int>(i)], FILLED );
61     }
62
63     imshow( "EM-clustering result", img );
64     waitKey(0);
65
66     return 0;
67 }