1 /*M///////////////////////////////////////////////////////////////////////////////////////
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
10 // Intel License Agreement
11 // For Open Source Computer Vision Library
13 // Copyright( C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
19 // * Redistribution's of source code must retain the above copyright notice,
20 // this list of conditions and the following disclaimer.
22 // * Redistribution's in binary form must reproduce the above copyright notice,
23 // this list of conditions and the following disclaimer in the documentation
24 // and/or other materials provided with the distribution.
26 // * The name of Intel Corporation may not be used to endorse or promote products
27 // derived from this software without specific prior written permission.
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 //(including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort(including negligence or otherwise) arising in any way out of
38 // the use of this software, even ifadvised of the possibility of such damage.
42 #include "precomp.hpp"
46 CvEMParams::CvEMParams() : nclusters(10), cov_mat_type(CvEM::COV_MAT_DIAGONAL),
47 start_step(CvEM::START_AUTO_STEP), probs(0), weights(0), means(0), covs(0)
49 term_crit=cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON );
52 CvEMParams::CvEMParams( int _nclusters, int _cov_mat_type, int _start_step,
53 CvTermCriteria _term_crit, const CvMat* _probs,
54 const CvMat* _weights, const CvMat* _means, const CvMat** _covs ) :
55 nclusters(_nclusters), cov_mat_type(_cov_mat_type), start_step(_start_step),
56 probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit)
59 CvEM::CvEM() : logLikelihood(DBL_MAX)
63 CvEM::CvEM( const CvMat* samples, const CvMat* sample_idx,
64 CvEMParams params, CvMat* labels ) : logLikelihood(DBL_MAX)
66 train(samples, sample_idx, params, labels);
// Restores the model from a legacy C file-storage node.
// NOTE(review): only the start of the body is visible in this chunk; presumably the
// remaining lines load emObj from `fn` and refresh the cached C headers — confirm.
79 void CvEM::read( CvFileStorage* fs, CvFileNode* node )
// Wrap the raw C storage/node pair in the C++ FileNode API.
81 FileNode fn(fs, node);
// Serializes the model into a legacy C file storage under the node name `name`.
// NOTE(review): only the start of the body is visible in this chunk — confirm the
// remaining lines against the full file.
86 void CvEM::write( CvFileStorage* _fs, const char* name ) const
// Wrap the caller's C storage in the C++ API. The `false` argument presumably makes
// the wrapper non-owning, so the caller's CvFileStorage is not released — confirm.
88 FileStorage fs(_fs, false);
96 double CvEM::calcLikelihood( const Mat &input_sample ) const
98 return emObj.predict(input_sample)[0];
// Legacy C-API prediction for one sample. It computes `cls` from element [1] of
// cv::EM::predict's result. It requests the probability output only when `_probs`
// is non-null, and writes it back into the caller's matrix.
// NOTE(review): the return-type line, the guard around the copy-back and the
// return statement are not visible in this chunk — confirm against the full file.
102 CvEM::predict( const CvMat* _sample, CvMat* _probs ) const
// prbs starts as an alias of the caller's buffer (prbs0) so EM::predict can write into it directly.
104 Mat prbs0 = cvarrToMat(_probs), prbs = prbs0, sample = cvarrToMat(_sample);
// Element [1] of EM::predict's result is truncated to int and used as the class index.
105 int cls = static_cast<int>(emObj.predict(sample, _probs ? _OutputArray(prbs) :
106 (OutputArray)cv::noArray())[1]);
// If EM::predict wrote into a newly allocated buffer instead of the caller's, copy the
// result back. Sizes must match; the element type is converted to the caller's.
109 if( prbs.data != prbs0.data )
111 CV_Assert( prbs.size == prbs0.size );
112 prbs.convertTo(prbs0, prbs0.type());
// Refreshes the cached legacy C headers (meansHdr, covsHdrs/covsPtrs, weightsHdr) from
// the wrapped cv::EM model. The get_* accessors return these headers as CvMat pointers.
// Only the trained-model path is visible here.
// NOTE(review): some body lines are not visible in this chunk. covsHdrs/covsPtrs are
// presumably resized to K there — confirm, because the loop below indexes covs[i] and
// covsPtrs[i] for every element of covsHdrs.
118 void CvEM::set_mat_hdrs()
120 if(emObj.isTrained())
122 meansHdr = emObj.get<Mat>("means");
123 int K = emObj.get<int>("nclusters");
// Binding the returned vector to a const reference keeps that temporary alive for the scope.
126 const std::vector<Mat>& covs = emObj.get<std::vector<Mat> >("covs");
// One CvMat header per covariance matrix, plus a parallel pointer table that
// get_covs() exposes as a const CvMat** array.
127 for(size_t i = 0; i < covsHdrs.size(); i++)
129 covsHdrs[i] = covs[i];
130 covsPtrs[i] = &covsHdrs[i];
132 weightsHdr = emObj.get<Mat>("weights");
// Converts the optional legacy C initial guesses in `src` into cv::Mat headers for
// the C++ EM training calls: probabilities -> prbs, weights, means, and one matrix
// per cluster in covsHdrs.
// NOTE(review): some body lines are not visible in this chunk. The covariance loop
// dereferences src.covs (null by default), so presumably a null check precedes it —
// confirm.
138 void init_params(const CvEMParams& src,
139 Mat& prbs, Mat& weights,
140 Mat& means, std::vector<Mat>& covsHdrs)
142 prbs = cv::cvarrToMat(src.probs);
143 weights = cv::cvarrToMat(src.weights);
144 means = cv::cvarrToMat(src.means);
// One covariance header per cluster, taken from the caller's array of CvMat pointers.
148 covsHdrs.resize(src.nclusters);
149 for(size_t i = 0; i < covsHdrs.size(); i++)
150 covsHdrs[i] = cv::cvarrToMat(src.covs[i]);
// Legacy C-API training entry point. It wraps the inputs in cv::Mat headers and
// forwards to the cv::Mat overload of train(). sample_idx is not supported and must
// be null.
154 bool CvEM::train( const CvMat* _samples, const CvMat* _sample_idx,
155 CvEMParams _params, CvMat* _labels )
157 CV_Assert(_sample_idx == 0);
158 Mat samples = cvarrToMat(_samples), labels0, labels;
// labels aliases the caller's label matrix; labels0 remembers its original buffer.
// NOTE(review): the guard for this line (presumably `if( _labels )`) is not visible here.
160 labels0 = labels = cvarrToMat(_labels);
162 bool isOk = train(samples, Mat(), _params, _labels ? &labels : 0);
// Labels must be written in place. A reallocation would leave the caller's CvMat
// without the results, so it is treated as an error.
163 CV_Assert( labels0.data == labels.data );
168 int CvEM::get_nclusters() const
170 return emObj.get<int>("nclusters");
173 const CvMat* CvEM::get_means() const
175 return emObj.isTrained() ? &meansHdr : 0;
178 const CvMat** CvEM::get_covs() const
180 return emObj.isTrained() ? (const CvMat**)&covsPtrs[0] : 0;
183 const CvMat* CvEM::get_weights() const
185 return emObj.isTrained() ? &weightsHdr : 0;
188 const CvMat* CvEM::get_probs() const
190 return emObj.isTrained() ? &probsHdr : 0;
195 CvEM::CvEM( const Mat& samples, const Mat& sample_idx, CvEMParams params )
197 train(samples, sample_idx, params, 0);
// Trains the wrapped cv::EM model through the C++ Mat API. sample_idx is not
// supported and must be empty. _params.start_step selects the entry point:
// - train(): automatic initialization.
// - trainE(): starts with an E-step from the supplied means/covs/weights.
// - trainM(): starts with an M-step from the supplied probabilities.
// Any other start_step raises CV_StsBadArg.
// NOTE(review): several body lines are not visible in this chunk (e.g. the declaration
// of isOk and the code after the dispatch) — confirm against the full file.
200 bool CvEM::train( const Mat& _samples, const Mat& _sample_idx,
201 CvEMParams _params, Mat* _labels )
203 CV_Assert(_sample_idx.empty());
204 Mat prbs, weights, means, logLikelihoods;
205 std::vector<Mat> covshdrs;
206 init_params(_params, prbs, weights, means, covshdrs);
// Replace any previous model with a fresh one configured from the parameters.
208 emObj = EM(_params.nclusters, _params.cov_mat_type, _params.term_crit);
// NOTE(review): start_step is compared with EM:: constants, while CvEMParams' default
// constructor uses CvEM:: ones; presumably the values are identical — confirm.
// `probs` (as opposed to the local `prbs`) is not declared in the visible lines;
// presumably it is a CvEM member that receives the output probabilities.
210 if( _params.start_step == EM::START_AUTO_STEP )
211 isOk = emObj.train(_samples,
212 logLikelihoods, _labels ? _OutputArray(*_labels) :
213 (OutputArray)cv::noArray(), probs);
214 else if( _params.start_step == EM::START_E_STEP )
215 isOk = emObj.trainE(_samples, means, covshdrs, weights,
216 logLikelihoods, _labels ? _OutputArray(*_labels) :
217 (OutputArray)cv::noArray(), probs);
218 else if( _params.start_step == EM::START_M_STEP )
219 isOk = emObj.trainM(_samples, prbs,
220 logLikelihoods, _labels ? _OutputArray(*_labels) :
221 (OutputArray)cv::noArray(), probs);
223 CV_Error(CV_StsBadArg, "Bad start type of EM algorithm");
// Collapse the (presumably per-sample) log-likelihoods into a single total.
227 logLikelihood = sum(logLikelihoods).val[0];
// C++ Mat-API prediction for one sample. It returns element [1] of cv::EM::predict's
// result (the value the C-API overload treats as the class index) converted to float.
// When `_probs` is non-null, *_probs is passed to EM::predict as its probability output.
// NOTE(review): the return-type line preceding this one is not visible in this chunk.
235 CvEM::predict( const Mat& _sample, Mat* _probs ) const
237 return static_cast<float>(emObj.predict(_sample, _probs ?
238 _OutputArray(*_probs) :
239 (OutputArray)cv::noArray())[1]);
242 int CvEM::getNClusters() const
244 return emObj.get<int>("nclusters");
247 Mat CvEM::getMeans() const
249 return emObj.get<Mat>("means");
252 void CvEM::getCovs(std::vector<Mat>& _covs) const
254 _covs = emObj.get<std::vector<Mat> >("covs");
257 Mat CvEM::getWeights() const
259 return emObj.get<Mat>("weights");
262 Mat CvEM::getProbs() const