1 #ifndef _OPENCV_BOOST_H_
2 #define _OPENCV_BOOST_H_
4 #include "traincascade_features.h"
// Parameter struct derived from CvBoostParams.
// The second constructor takes a minimum hit rate and a maximum false-alarm
// rate in addition to boosting settings; the data members that would hold
// them are not visible in this listing.
7 struct CvCascadeBoostParams : CvBoostParams
// Default constructor. The default values are set in the implementation file.
12 CvCascadeBoostParams();
// Constructs from explicit settings:
//   _boostType      - integer boosting-type selector (valid values are defined elsewhere)
//   _minHitRate     - presumably the required minimum hit rate -- confirm in implementation
//   _maxFalseAlarm  - presumably the allowed maximum false-alarm rate -- confirm in implementation
//   _weightTrimRate - weight trimming rate (double)
//   _maxDepth       - maximum depth, presumably of each weak tree -- confirm
//   _maxWeakCount   - presumably the maximum number of weak classifiers -- confirm
13 CvCascadeBoostParams( int _boostType, float _minHitRate, float _maxFalseAlarm,
14 double _weightTrimRate, int _maxDepth, int _maxWeakCount );
// Virtual destructor with an empty body.
15 virtual ~CvCascadeBoostParams() {}
// Writes to the given FileStorage. Declared const, so it does not modify this object.
16 void write( FileStorage &fs ) const;
// Reads from the given FileNode. Returns bool -- presumably a success flag; confirm in implementation.
17 bool read( const FileNode &node );
// Virtual, const. Presumably prints the default parameter values -- confirm in implementation.
18 virtual void printDefaults() const;
// Virtual, const. Presumably prints the current parameter values -- confirm in implementation.
19 virtual void printAttrs() const;
// Virtual. Takes a parameter name and a value, both as String passed by const value.
// Presumably parses `val` for the parameter named `prmName` and returns whether the
// name was recognised -- confirm in implementation.
20 virtual bool scanAttr( const String prmName, const String val);
// Training-data struct derived from CvDTreeTrainData.
// It is constructed from a CvFeatureEvaluator instead of raw matrices and
// declares a cache matrix for precalculated feature values (valCache).
23 struct CvCascadeBoostTrainData : CvDTreeTrainData
// Constructs from a feature evaluator and tree parameters only
// (no sample count or buffer sizes are passed to this overload).
25 CvCascadeBoostTrainData( const CvFeatureEvaluator* _featureEvaluator,
26 const CvDTreeParams& _params );
// Constructs from a feature evaluator, a sample count and two precalculation
// buffer sizes (units are not stated here -- confirm in implementation).
// _params defaults to a default-constructed CvDTreeParams.
27 CvCascadeBoostTrainData( const CvFeatureEvaluator* _featureEvaluator,
28 int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
29 const CvDTreeParams& _params = CvDTreeParams() );
// Virtual setter taking the same arguments as the constructor above.
// NOTE: the default for _params is a default argument on a virtual function;
// in C++ such defaults are selected by the static type at the call site.
30 virtual void setData( const CvFeatureEvaluator* _featureEvaluator,
31 int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
32 const CvDTreeParams& _params=CvDTreeParams() );
// Virtual. Takes a CvMat of subsample indices and returns a CvDTreeNode*;
// presumably the root node for the subsampled data -- confirm in implementation.
35 virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
// Virtual accessors that take a node and a caller-provided int buffer and
// return a const int*. Whether the returned pointer refers to the supplied
// buffer or to internal storage is not visible from the declarations.
37 virtual const int* get_class_labels( CvDTreeNode* n, int* labelsBuf );
38 virtual const int* get_cv_labels( CvDTreeNode* n, int* labelsBuf);
39 virtual const int* get_sample_indices( CvDTreeNode* n, int* indicesBuf );
// Virtual. For node n and variable index vi, takes three caller-provided
// buffers; ordValues and sortedIndices are pointer-to-pointer parameters,
// presumably used to hand the resulting data pointers back to the caller.
41 virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ordValuesBuf, int* sortedIndicesBuf,
42 const float** ordValues, const int** sortedIndices, int* sampleIndicesBuf );
// Virtual. Same buffer-plus-returned-pointer shape as the label accessors,
// for the variable with index vi.
43 virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* catValuesBuf );
// Virtual. Returns a float for the index pair (vi, si); presumably the value
// of variable vi for sample si -- confirm in implementation.
44 virtual float getVarValue( int vi, int si );
// Virtual. Presumably releases the training data held by this object -- confirm.
45 virtual void free_train_data();
// Pointer to a const feature evaluator. Ownership is not visible from this
// declaration; presumably non-owning -- confirm against the implementation.
47 const CvFeatureEvaluator* featureEvaluator;
48 Mat valCache; // precalculated feature values (CV_32FC1)
49 CvMat _resp; // for casting
// Two int counters; presumably how many variables have precalculated values
// and precalculated indices respectively -- confirm in implementation.
50 int numPrecalcVal, numPrecalcIdx;
// Tree class publicly derived from CvBoostTree.
// It adds prediction by sample index and read/write overloads that take a
// feature map. Access specifiers are not visible in this listing.
53 class CvCascadeBoostTree : public CvBoostTree
// Virtual, const. Takes a sample index and returns a CvDTreeNode*;
// presumably the node reached for that sample -- confirm in implementation.
56 virtual CvDTreeNode* predict( int sampleIdx ) const;
// Writes to the given FileStorage; featureMap is read-only input (const Mat&).
// Note: this member is not declared const.
57 void write( FileStorage &fs, const Mat& featureMap );
// Reads from the given FileNode, taking the ensemble and training data as pointers.
58 void read( const FileNode &node, CvBoost* _ensemble, CvDTreeTrainData* _data );
// Takes featureMap by non-const reference; presumably marks the features
// used by this tree in it -- confirm in implementation.
59 void markFeaturesInMap( Mat& featureMap );
// Virtual. Operates on node n; presumably an override of a base-class hook -- confirm.
61 virtual void split_node_data( CvDTreeNode* n );
// Boosted classifier class publicly derived from CvBoost.
// It trains from a CvFeatureEvaluator, predicts by sample index and exposes
// a threshold value. Access specifiers and the end of the class are not
// visible in this listing.
64 class CvCascadeBoost : public CvBoost
// Virtual. Trains from a feature evaluator, a sample count and two
// precalculation buffer sizes; returns bool (presumably success -- confirm).
// NOTE: _params has a default argument on a virtual function; in C++ such
// defaults are selected by the static type at the call site.
67 virtual bool train( const CvFeatureEvaluator* _featureEvaluator,
68 int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
69 const CvCascadeBoostParams& _params=CvCascadeBoostParams() );
// Virtual, const. Returns a float for the sample at sampleIdx. returnSum
// defaults to false; presumably it selects the raw sum instead of the
// thresholded decision -- confirm in implementation.
70 virtual float predict( int sampleIdx, bool returnSum = false ) const;
// Inline accessor returning the member `threshold`
// (its declaration is not visible in this listing).
72 float getThreshold() const { return threshold; }
// Writes to the given FileStorage; const member, featureMap is read-only input.
73 void write( FileStorage &fs, const Mat& featureMap ) const;
// Reads from the given FileNode using the supplied feature evaluator and
// parameters; returns bool (presumably success -- confirm).
74 bool read( const FileNode &node, const CvFeatureEvaluator* _featureEvaluator,
75 const CvCascadeBoostParams& _params );
// Takes featureMap by non-const reference; presumably marks every feature
// used by the ensemble in it -- confirm in implementation.
76 void markUsedFeaturesInMap( Mat& featureMap );
// Virtual hooks; presumably overrides of CvBoost members -- confirm against the base class.
78 virtual bool set_params( const CvBoostParams& _params );
79 virtual void update_weights( CvBoostTree* tree );
// Virtual. Returns bool; presumably whether the current error meets the
// minHitRate / maxFalseAlarm targets -- confirm in implementation.
80 virtual bool isErrDesired();
// Two float members named after the hit-rate and false-alarm targets.
83 float minHitRate, maxFalseAlarm;