b6ff6c8263910bb0579107cec04b96b69f581ad2
[platform/upstream/opencv.git] / modules / legacy / src / em.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright( C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 //   * Redistribution's of source code must retain the above copyright notice,
20 //     this list of conditions and the following disclaimer.
21 //
22 //   * Redistribution's in binary form must reproduce the above copyright notice,
23 //     this list of conditions and the following disclaimer in the documentation
24 //     and/or other materials provided with the distribution.
25 //
26 //   * The name of Intel Corporation may not be used to endorse or promote products
27 //     derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 //(including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort(including negligence or otherwise) arising in any way out of
38 // the use of this software, even ifadvised of the possibility of such damage.
39 //
40 //M*/
41
42 #include "precomp.hpp"
43
44 using namespace cv;
45
46 CvEMParams::CvEMParams() : nclusters(10), cov_mat_type(CvEM::COV_MAT_DIAGONAL),
47     start_step(CvEM::START_AUTO_STEP), probs(0), weights(0), means(0), covs(0)
48 {
49     term_crit=cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON );
50 }
51
52 CvEMParams::CvEMParams( int _nclusters, int _cov_mat_type, int _start_step,
53                         CvTermCriteria _term_crit, const CvMat* _probs,
54                         const CvMat* _weights, const CvMat* _means, const CvMat** _covs ) :
55                         nclusters(_nclusters), cov_mat_type(_cov_mat_type), start_step(_start_step),
56                         probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit)
57 {}
58
59 CvEM::CvEM() : logLikelihood(DBL_MAX)
60 {
61 }
62
63 CvEM::CvEM( const CvMat* samples, const CvMat* sample_idx,
64             CvEMParams params, CvMat* labels ) : logLikelihood(DBL_MAX)
65 {
66     train(samples, sample_idx, params, labels);
67 }
68
69 CvEM::~CvEM()
70 {
71     clear();
72 }
73
74 void CvEM::clear()
75 {
76     emObj.clear();
77 }
78
79 void CvEM::read( CvFileStorage* fs, CvFileNode* node )
80 {
81     FileNode fn(fs, node);
82     emObj.read(fn);
83     set_mat_hdrs();
84 }
85
86 void CvEM::write( CvFileStorage* _fs, const char* name ) const
87 {
88     FileStorage fs(_fs, false);
89     if(name)
90         fs << name << "{";
91     emObj.write(fs);
92     if(name)
93         fs << "}";
94 }
95
96 double CvEM::calcLikelihood( const Mat &input_sample ) const
97 {
98     return emObj.predict(input_sample)[0];
99 }
100
101 float
102 CvEM::predict( const CvMat* _sample, CvMat* _probs ) const
103 {
104     Mat prbs0 = cvarrToMat(_probs), prbs = prbs0, sample = cvarrToMat(_sample);
105     int cls = static_cast<int>(emObj.predict(sample, _probs ? _OutputArray(prbs) :
106                                              (OutputArray)cv::noArray())[1]);
107     if(_probs)
108     {
109         if( prbs.data != prbs0.data )
110         {
111             CV_Assert( prbs.size == prbs0.size );
112             prbs.convertTo(prbs0, prbs0.type());
113         }
114     }
115     return (float)cls;
116 }
117
118 void CvEM::set_mat_hdrs()
119 {
120     if(emObj.isTrained())
121     {
122         meansHdr = emObj.get<Mat>("means");
123         int K = emObj.get<int>("nclusters");
124         covsHdrs.resize(K);
125         covsPtrs.resize(K);
126         const std::vector<Mat>& covs = emObj.get<std::vector<Mat> >("covs");
127         for(size_t i = 0; i < covsHdrs.size(); i++)
128         {
129             covsHdrs[i] = covs[i];
130             covsPtrs[i] = &covsHdrs[i];
131         }
132         weightsHdr = emObj.get<Mat>("weights");
133         probsHdr = probs;
134     }
135 }
136
137 static
138 void init_params(const CvEMParams& src,
139                  Mat& prbs, Mat& weights,
140                  Mat& means, std::vector<Mat>& covsHdrs)
141 {
142     prbs = cv::cvarrToMat(src.probs);
143     weights = cv::cvarrToMat(src.weights);
144     means = cv::cvarrToMat(src.means);
145
146     if(src.covs)
147     {
148         covsHdrs.resize(src.nclusters);
149         for(size_t i = 0; i < covsHdrs.size(); i++)
150             covsHdrs[i] = cv::cvarrToMat(src.covs[i]);
151     }
152 }
153
154 bool CvEM::train( const CvMat* _samples, const CvMat* _sample_idx,
155                   CvEMParams _params, CvMat* _labels )
156 {
157     CV_Assert(_sample_idx == 0);
158     Mat samples = cvarrToMat(_samples), labels0, labels;
159     if( _labels )
160         labels0 = labels = cvarrToMat(_labels);
161
162     bool isOk = train(samples, Mat(), _params, _labels ? &labels : 0);
163     CV_Assert( labels0.data == labels.data );
164
165     return isOk;
166 }
167
168 int CvEM::get_nclusters() const
169 {
170     return emObj.get<int>("nclusters");
171 }
172
173 const CvMat* CvEM::get_means() const
174 {
175     return emObj.isTrained() ? &meansHdr : 0;
176 }
177
178 const CvMat** CvEM::get_covs() const
179 {
180     return emObj.isTrained() ? (const CvMat**)&covsPtrs[0] : 0;
181 }
182
183 const CvMat* CvEM::get_weights() const
184 {
185     return emObj.isTrained() ? &weightsHdr : 0;
186 }
187
188 const CvMat* CvEM::get_probs() const
189 {
190     return emObj.isTrained() ? &probsHdr : 0;
191 }
192
193 using namespace cv;
194
195 CvEM::CvEM( const Mat& samples, const Mat& sample_idx, CvEMParams params )
196 {
197     train(samples, sample_idx, params, 0);
198 }
199
200 bool CvEM::train( const Mat& _samples, const Mat& _sample_idx,
201                  CvEMParams _params, Mat* _labels )
202 {
203     CV_Assert(_sample_idx.empty());
204     Mat prbs, weights, means, logLikelihoods;
205     std::vector<Mat> covshdrs;
206     init_params(_params, prbs, weights, means, covshdrs);
207
208     emObj = EM(_params.nclusters, _params.cov_mat_type, _params.term_crit);
209     bool isOk = false;
210     if( _params.start_step == EM::START_AUTO_STEP )
211         isOk = emObj.train(_samples,
212                            logLikelihoods, _labels ? _OutputArray(*_labels) :
213                            (OutputArray)cv::noArray(), probs);
214     else if( _params.start_step == EM::START_E_STEP )
215         isOk = emObj.trainE(_samples, means, covshdrs, weights,
216                             logLikelihoods, _labels ? _OutputArray(*_labels) :
217                             (OutputArray)cv::noArray(), probs);
218     else if( _params.start_step == EM::START_M_STEP )
219         isOk = emObj.trainM(_samples, prbs,
220                             logLikelihoods, _labels ? _OutputArray(*_labels) :
221                             (OutputArray)cv::noArray(), probs);
222     else
223         CV_Error(CV_StsBadArg, "Bad start type of EM algorithm");
224
225     if(isOk)
226     {
227         logLikelihood = sum(logLikelihoods).val[0];
228         set_mat_hdrs();
229     }
230
231     return isOk;
232 }
233
234 float
235 CvEM::predict( const Mat& _sample, Mat* _probs ) const
236 {
237     return static_cast<float>(emObj.predict(_sample, _probs ?
238                                             _OutputArray(*_probs) :
239                                             (OutputArray)cv::noArray())[1]);
240 }
241
242 int CvEM::getNClusters() const
243 {
244     return emObj.get<int>("nclusters");
245 }
246
247 Mat CvEM::getMeans() const
248 {
249     return emObj.get<Mat>("means");
250 }
251
252 void CvEM::getCovs(std::vector<Mat>& _covs) const
253 {
254     _covs = emObj.get<std::vector<Mat> >("covs");
255 }
256
257 Mat CvEM::getWeights() const
258 {
259     return emObj.get<Mat>("weights");
260 }
261
262 Mat CvEM::getProbs() const
263 {
264     return probs;
265 }
266
267
268 /* End of file. */