enabled gst
[profile/ivi/opencv.git] / modules / ml / src / rtrees.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                           License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Copyright (C) 2014, Itseez Inc, all rights reserved.
15 // Third party copyrights are property of their respective owners.
16 //
17 // Redistribution and use in source and binary forms, with or without modification,
18 // are permitted provided that the following conditions are met:
19 //
20 //   * Redistribution's of source code must retain the above copyright notice,
21 //     this list of conditions and the following disclaimer.
22 //
23 //   * Redistribution's in binary form must reproduce the above copyright notice,
24 //     this list of conditions and the following disclaimer in the documentation
25 //     and/or other materials provided with the distribution.
26 //
27 //   * The name of the copyright holders may not be used to endorse or promote products
28 //     derived from this software without specific prior written permission.
29 //
30 // This software is provided by the copyright holders and contributors "as is" and
31 // any express or implied warranties, including, but not limited to, the implied
32 // warranties of merchantability and fitness for a particular purpose are disclaimed.
33 // In no event shall the Intel Corporation or contributors be liable for any direct,
34 // indirect, incidental, special, exemplary, or consequential damages
35 // (including, but not limited to, procurement of substitute goods or services;
36 // loss of use, data, or profits; or business interruption) however caused
37 // and on any theory of liability, whether in contract, strict liability,
38 // or tort (including negligence or otherwise) arising in any way out of
39 // the use of this software, even if advised of the possibility of such damage.
40 //
41 //M*/
42
43 #include "precomp.hpp"
44
45 namespace cv {
46 namespace ml {
47
48 //////////////////////////////////////////////////////////////////////////////////////////
49 //                                  Random trees                                        //
50 //////////////////////////////////////////////////////////////////////////////////////////
51 RTrees::Params::Params()
52     : DTrees::Params(5, 10, 0.f, false, 10, 0, false, false, Mat())
53 {
54     calcVarImportance = false;
55     nactiveVars = 0;
56     termCrit = TermCriteria(TermCriteria::EPS + TermCriteria::COUNT, 50, 0.1);
57 }
58
59 RTrees::Params::Params( int _maxDepth, int _minSampleCount,
60                         double _regressionAccuracy, bool _useSurrogates,
61                         int _maxCategories, const Mat& _priors,
62                         bool _calcVarImportance, int _nactiveVars,
63                         TermCriteria _termCrit )
64     : DTrees::Params(_maxDepth, _minSampleCount, _regressionAccuracy, _useSurrogates,
65                      _maxCategories, 0, false, false, _priors)
66 {
67     calcVarImportance = _calcVarImportance;
68     nactiveVars = _nactiveVars;
69     termCrit = _termCrit;
70 }
71
72
73 class DTreesImplForRTrees : public DTreesImpl
74 {
75 public:
76     DTreesImplForRTrees() {}
77     virtual ~DTreesImplForRTrees() {}
78
79     void setRParams(const RTrees::Params& p)
80     {
81         rparams = p;
82     }
83
84     RTrees::Params getRParams() const
85     {
86         return rparams;
87     }
88
89     void clear()
90     {
91         DTreesImpl::clear();
92         oobError = 0.;
93         rng = RNG((uint64)-1);
94     }
95
96     const vector<int>& getActiveVars()
97     {
98         int i, nvars = (int)allVars.size(), m = (int)activeVars.size();
99         for( i = 0; i < nvars; i++ )
100         {
101             int i1 = rng.uniform(0, nvars);
102             int i2 = rng.uniform(0, nvars);
103             std::swap(allVars[i1], allVars[i2]);
104         }
105         for( i = 0; i < m; i++ )
106             activeVars[i] = allVars[i];
107         return activeVars;
108     }
109
110     void startTraining( const Ptr<TrainData>& trainData, int flags )
111     {
112         DTreesImpl::startTraining(trainData, flags);
113         int nvars = w->data->getNVars();
114         int i, m = rparams.nactiveVars > 0 ? rparams.nactiveVars : cvRound(std::sqrt((double)nvars));
115         m = std::min(std::max(m, 1), nvars);
116         allVars.resize(nvars);
117         activeVars.resize(m);
118         for( i = 0; i < nvars; i++ )
119             allVars[i] = varIdx[i];
120     }
121
122     void endTraining()
123     {
124         DTreesImpl::endTraining();
125         vector<int> a, b;
126         std::swap(allVars, a);
127         std::swap(activeVars, b);
128     }
129
130     bool train( const Ptr<TrainData>& trainData, int flags )
131     {
132         Params dp(rparams.maxDepth, rparams.minSampleCount, rparams.regressionAccuracy,
133                   rparams.useSurrogates, rparams.maxCategories, rparams.CVFolds,
134                   rparams.use1SERule, rparams.truncatePrunedTree, rparams.priors);
135         setDParams(dp);
136         startTraining(trainData, flags);
137         int treeidx, ntrees = (rparams.termCrit.type & TermCriteria::COUNT) != 0 ?
138             rparams.termCrit.maxCount : 10000;
139         int i, j, k, vi, vi_, n = (int)w->sidx.size();
140         int nclasses = (int)classLabels.size();
141         double eps = (rparams.termCrit.type & TermCriteria::EPS) != 0 &&
142             rparams.termCrit.epsilon > 0 ? rparams.termCrit.epsilon : 0.;
143         vector<int> sidx(n);
144         vector<uchar> oobmask(n);
145         vector<int> oobidx;
146         vector<int> oobperm;
147         vector<double> oobres(n, 0.);
148         vector<int> oobcount(n, 0);
149         vector<int> oobvotes(n*nclasses, 0);
150         int nvars = w->data->getNVars();
151         int nallvars = w->data->getNAllVars();
152         const int* vidx = !varIdx.empty() ? &varIdx[0] : 0;
153         vector<float> samplebuf(nallvars);
154         Mat samples = w->data->getSamples();
155         float* psamples = samples.ptr<float>();
156         size_t sstep0 = samples.step1(), sstep1 = 1;
157         Mat sample0, sample(nallvars, 1, CV_32F, &samplebuf[0]);
158         int predictFlags = _isClassifier ? (PREDICT_MAX_VOTE + RAW_OUTPUT) : PREDICT_SUM;
159
160         bool calcOOBError = eps > 0 || rparams.calcVarImportance;
161         double max_response = 0.;
162
163         if( w->data->getLayout() == COL_SAMPLE )
164             std::swap(sstep0, sstep1);
165
166         if( !_isClassifier )
167         {
168             for( i = 0; i < n; i++ )
169             {
170                 double val = std::abs(w->ord_responses[w->sidx[i]]);
171                 max_response = std::max(max_response, val);
172             }
173         }
174
175         if( rparams.calcVarImportance )
176             varImportance.resize(nallvars, 0.f);
177
178         for( treeidx = 0; treeidx < ntrees; treeidx++ )
179         {
180             for( i = 0; i < n; i++ )
181                 oobmask[i] = (uchar)1;
182
183             for( i = 0; i < n; i++ )
184             {
185                 j = rng.uniform(0, n);
186                 sidx[i] = w->sidx[j];
187                 oobmask[j] = (uchar)0;
188             }
189             int root = addTree( sidx );
190             if( root < 0 )
191                 return false;
192
193             if( calcOOBError )
194             {
195                 oobidx.clear();
196                 for( i = 0; i < n; i++ )
197                 {
198                     if( !oobmask[i] )
199                         oobidx.push_back(i);
200                 }
201                 int n_oob = (int)oobidx.size();
202                 // if there is no out-of-bag samples, we can not compute OOB error
203                 // nor update the variable importance vector; so we proceed to the next tree
204                 if( n_oob == 0 )
205                     continue;
206                 double ncorrect_responses = 0.;
207
208                 oobError = 0.;
209                 for( i = 0; i < n_oob; i++ )
210                 {
211                     j = oobidx[i];
212                     sample = Mat( nallvars, 1, CV_32F, psamples + sstep0*w->sidx[j], sstep1*sizeof(psamples[0]) );
213
214                     double val = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags);
215                     if( !_isClassifier )
216                     {
217                         oobres[j] += val;
218                         oobcount[j]++;
219                         double true_val = w->ord_responses[w->sidx[j]];
220                         double a = oobres[j]/oobcount[j] - true_val;
221                         oobError += a*a;
222                         val = (val - true_val)/max_response;
223                         ncorrect_responses += std::exp( -val*val );
224                     }
225                     else
226                     {
227                         int ival = cvRound(val);
228                         int* votes = &oobvotes[j*nclasses];
229                         votes[ival]++;
230                         int best_class = 0;
231                         for( k = 1; k < nclasses; k++ )
232                             if( votes[best_class] < votes[k] )
233                                 best_class = k;
234                         int diff = best_class != w->cat_responses[w->sidx[j]];
235                         oobError += diff;
236                         ncorrect_responses += diff == 0;
237                     }
238                 }
239
240                 oobError /= n_oob;
241                 if( rparams.calcVarImportance && n_oob > 1 )
242                 {
243                     oobperm.resize(n_oob);
244                     for( i = 0; i < n_oob; i++ )
245                         oobperm[i] = oobidx[i];
246
247                     for( vi_ = 0; vi_ < nvars; vi_++ )
248                     {
249                         vi = vidx ? vidx[vi_] : vi_;
250                         double ncorrect_responses_permuted = 0;
251                         for( i = 0; i < n_oob; i++ )
252                         {
253                             int i1 = rng.uniform(0, n_oob);
254                             int i2 = rng.uniform(0, n_oob);
255                             std::swap(i1, i2);
256                         }
257
258                         for( i = 0; i < n_oob; i++ )
259                         {
260                             j = oobidx[i];
261                             int vj = oobperm[i];
262                             sample0 = Mat( nallvars, 1, CV_32F, psamples + sstep0*w->sidx[j], sstep1*sizeof(psamples[0]) );
263                             for( k = 0; k < nallvars; k++ )
264                                 sample.at<float>(k) = sample0.at<float>(k);
265                             sample.at<float>(vi) = psamples[sstep0*w->sidx[vj] + sstep1*vi];
266
267                             double val = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags);
268                             if( !_isClassifier )
269                             {
270                                 val = (val - w->ord_responses[w->sidx[j]])/max_response;
271                                 ncorrect_responses_permuted += exp( -val*val );
272                             }
273                             else
274                                 ncorrect_responses_permuted += cvRound(val) == w->cat_responses[w->sidx[j]];
275                         }
276                         varImportance[vi] += (float)(ncorrect_responses - ncorrect_responses_permuted);
277                     }
278                 }
279             }
280             if( calcOOBError && oobError < eps )
281                 break;
282         }
283
284         if( rparams.calcVarImportance )
285         {
286             for( vi_ = 0; vi_ < nallvars; vi_++ )
287                 varImportance[vi_] = std::max(varImportance[vi_], 0.f);
288             normalize(varImportance, varImportance, 1., 0, NORM_L1);
289         }
290         endTraining();
291         return true;
292     }
293
294     void writeTrainingParams( FileStorage& fs ) const
295     {
296         DTreesImpl::writeTrainingParams(fs);
297         fs << "nactive_vars" << rparams.nactiveVars;
298     }
299
300     void write( FileStorage& fs ) const
301     {
302         if( roots.empty() )
303             CV_Error( CV_StsBadArg, "RTrees have not been trained" );
304
305         writeParams(fs);
306
307         fs << "oob_error" << oobError;
308         if( !varImportance.empty() )
309             fs << "var_importance" << varImportance;
310
311         int k, ntrees = (int)roots.size();
312
313         fs << "ntrees" << ntrees
314            << "trees" << "[";
315
316         for( k = 0; k < ntrees; k++ )
317         {
318             fs << "{";
319             writeTree(fs, roots[k]);
320             fs << "}";
321         }
322
323         fs << "]";
324     }
325
326     void readParams( const FileNode& fn )
327     {
328         DTreesImpl::readParams(fn);
329         rparams.maxDepth = params0.maxDepth;
330         rparams.minSampleCount = params0.minSampleCount;
331         rparams.regressionAccuracy = params0.regressionAccuracy;
332         rparams.useSurrogates = params0.useSurrogates;
333         rparams.maxCategories = params0.maxCategories;
334         rparams.priors = params0.priors;
335
336         FileNode tparams_node = fn["training_params"];
337         rparams.nactiveVars = (int)tparams_node["nactive_vars"];
338     }
339
340     void read( const FileNode& fn )
341     {
342         clear();
343
344         //int nclasses = (int)fn["nclasses"];
345         //int nsamples = (int)fn["nsamples"];
346         oobError = (double)fn["oob_error"];
347         int ntrees = (int)fn["ntrees"];
348
349         fn["var_importance"] >> varImportance;
350
351         readParams(fn);
352
353         FileNode trees_node = fn["trees"];
354         FileNodeIterator it = trees_node.begin();
355         CV_Assert( ntrees == (int)trees_node.size() );
356
357         for( int treeidx = 0; treeidx < ntrees; treeidx++, ++it )
358         {
359             FileNode nfn = (*it)["nodes"];
360             readTree(nfn);
361         }
362     }
363
364     RTrees::Params rparams;
365     double oobError;
366     vector<float> varImportance;
367     vector<int> allVars, activeVars;
368     RNG rng;
369 };
370
371
372 class RTreesImpl : public RTrees
373 {
374 public:
375     RTreesImpl() {}
376     virtual ~RTreesImpl() {}
377
378     String getDefaultModelName() const { return "opencv_ml_rtrees"; }
379
380     bool train( const Ptr<TrainData>& trainData, int flags )
381     {
382         return impl.train(trainData, flags);
383     }
384
385     float predict( InputArray samples, OutputArray results, int flags ) const
386     {
387         return impl.predict(samples, results, flags);
388     }
389
390     void write( FileStorage& fs ) const
391     {
392         impl.write(fs);
393     }
394
395     void read( const FileNode& fn )
396     {
397         impl.read(fn);
398     }
399
400     void setRParams(const Params& p) { impl.setRParams(p); }
401     Params getRParams() const { return impl.getRParams(); }
402
403     Mat getVarImportance() const { return Mat_<float>(impl.varImportance, true); }
404     int getVarCount() const { return impl.getVarCount(); }
405
406     bool isTrained() const { return impl.isTrained(); }
407     bool isClassifier() const { return impl.isClassifier(); }
408
409     const vector<int>& getRoots() const { return impl.getRoots(); }
410     const vector<Node>& getNodes() const { return impl.getNodes(); }
411     const vector<Split>& getSplits() const { return impl.getSplits(); }
412     const vector<int>& getSubsets() const { return impl.getSubsets(); }
413
414     DTreesImplForRTrees impl;
415 };
416
417
418 Ptr<RTrees> RTrees::create(const Params& params)
419 {
420     Ptr<RTreesImpl> p = makePtr<RTreesImpl>();
421     p->setRParams(params);
422     return p;
423 }
424
425 }}
426
427 // End of file.