From 7e35f76d06185d7ea53e1ff4d904b826c97c7407 Mon Sep 17 00:00:00 2001 From: StevenPuttemans Date: Fri, 6 Mar 2015 11:52:26 +0100 Subject: [PATCH] allowing people to manually define how sharp a cascade classifier model should be trained --- apps/traincascade/cascadeclassifier.cpp | 13 ++++++++++--- apps/traincascade/cascadeclassifier.h | 3 ++- apps/traincascade/traincascade.cpp | 9 ++++++++- doc/user_guide/ug_traincascade.markdown | 6 ++++++ 4 files changed, 26 insertions(+), 5 deletions(-) diff --git a/apps/traincascade/cascadeclassifier.cpp b/apps/traincascade/cascadeclassifier.cpp index c9b524f..a64566e 100644 --- a/apps/traincascade/cascadeclassifier.cpp +++ b/apps/traincascade/cascadeclassifier.cpp @@ -135,7 +135,8 @@ bool CvCascadeClassifier::train( const string _cascadeDirName, const CvCascadeParams& _cascadeParams, const CvFeatureParams& _featureParams, const CvCascadeBoostParams& _stageParams, - bool baseFormatSave ) + bool baseFormatSave, + double acceptanceRatioBreakValue ) { // Start recording clock ticks for training time output const clock_t begin_time = clock(); @@ -185,6 +186,7 @@ bool CvCascadeClassifier::train( const string _cascadeDirName, cout << "numStages: " << numStages << endl; cout << "precalcValBufSize[Mb] : " << _precalcValBufSize << endl; cout << "precalcIdxBufSize[Mb] : " << _precalcIdxBufSize << endl; + cout << "acceptanceRatioBreakValue : " << acceptanceRatioBreakValue << endl; cascadeParams.printAttrs(); stageParams->printAttrs(); featureParams->printAttrs(); @@ -207,13 +209,18 @@ bool CvCascadeClassifier::train( const string _cascadeDirName, if ( !updateTrainingSet( tempLeafFARate ) ) { cout << "Train dataset for temp stage can not be filled. " - "Branch training terminated." << endl; + "Branch training terminated." << endl; break; } if( tempLeafFARate <= requiredLeafFARate ) { cout << "Required leaf false alarm rate achieved. " - "Branch training terminated." << endl; + "Branch training terminated." << endl; + break; + } + if( (tempLeafFARate <= acceptanceRatioBreakValue) && (acceptanceRatioBreakValue < 0) ){ + cout << "The required acceptanceRatio for the model has been reached to avoid overfitting of trainingdata. " + "Branch training terminated." << endl; break; } diff --git a/apps/traincascade/cascadeclassifier.h b/apps/traincascade/cascadeclassifier.h index 6d6cb5b..d8e0448 100644 --- a/apps/traincascade/cascadeclassifier.h +++ b/apps/traincascade/cascadeclassifier.h @@ -94,7 +94,8 @@ public: const CvCascadeParams& _cascadeParams, const CvFeatureParams& _featureParams, const CvCascadeBoostParams& _stageParams, - bool baseFormatSave = false ); + bool baseFormatSave = false, + double acceptanceRatioBreakValue = -1.0 ); private: int predict( int sampleIdx ); void save( const std::string cascadeDirName, bool baseFormat = false ); diff --git a/apps/traincascade/traincascade.cpp b/apps/traincascade/traincascade.cpp index d1c3e4e..ba8afb2 100644 --- a/apps/traincascade/traincascade.cpp +++ b/apps/traincascade/traincascade.cpp @@ -15,6 +15,7 @@ int main( int argc, char* argv[] ) int precalcValBufSize = 256, precalcIdxBufSize = 256; bool baseFormatSave = false; + double acceptanceRatioBreakValue = -1.0; CvCascadeParams cascadeParams; CvCascadeBoostParams stageParams; @@ -36,6 +37,7 @@ int main( int argc, char* argv[] ) cout << " [-precalcIdxBufSize ]" << endl; cout << " [-baseFormatSave]" << endl; cout << " [-numThreads ]" << endl; + cout << " [-acceptanceRatioBreakValue = " << acceptanceRatioBreakValue << ">]" << endl; cascadeParams.printDefaults(); stageParams.printDefaults(); for( int fi = 0; fi < fc; fi++ ) @@ -86,6 +88,10 @@ int main( int argc, char* argv[] ) { numThreads = atoi(argv[++i]); } + else if( !strcmp( argv[i], "-acceptanceRatioBreakValue" ) ) + { + acceptanceRatioBreakValue = atof(argv[++i]); + } else if ( cascadeParams.scanAttr( argv[i], argv[i+1] ) ) { i++; } else if ( stageParams.scanAttr( argv[i], argv[i+1] ) ) { i++; } else if ( !set ) @@ -112,6 +118,7 @@ int main( int argc, char* argv[] ) cascadeParams, *featureParams[cascadeParams.featureType], stageParams, - baseFormatSave ); + baseFormatSave, + acceptanceRatioBreakValue ); return 0; } diff --git a/doc/user_guide/ug_traincascade.markdown b/doc/user_guide/ug_traincascade.markdown index 059d25e..ccceb50 100644 --- a/doc/user_guide/ug_traincascade.markdown +++ b/doc/user_guide/ug_traincascade.markdown @@ -256,6 +256,12 @@ Command line arguments of opencv_traincascade application grouped by purposes: Maximum number of threads to use during training. Notice that the actual number of used threads may be lower, depending on your machine and compilation options. + - -acceptanceRatioBreakValue \ + + This argument is used to determine how precise your model should keep learning and when to stop. + A good guideline is to train not further than 10e-5, to ensure the model does not overtrain on your training data. + By default this value is set to -1 to disable this feature. + -# Cascade parameters: - -stageType \ -- 2.7.4