add the possibility to add traincascade accuracy
authorStevenPuttemans <steven.puttemans@kuleuven.be>
Fri, 3 Apr 2015 07:39:23 +0000 (09:39 +0200)
committerStevenPuttemans <steven.puttemans@kuleuven.be>
Fri, 3 Apr 2015 10:47:09 +0000 (12:47 +0200)
apps/traincascade/cascadeclassifier.cpp
apps/traincascade/cascadeclassifier.h
apps/traincascade/traincascade.cpp
doc/user_guide/ug_traincascade.rst

index 4d02983..aa97708 100644 (file)
@@ -136,7 +136,8 @@ bool CvCascadeClassifier::train( const string _cascadeDirName,
                                 const CvCascadeParams& _cascadeParams,
                                 const CvFeatureParams& _featureParams,
                                 const CvCascadeBoostParams& _stageParams,
-                                bool baseFormatSave )
+                                bool baseFormatSave,
+                                double acceptanceRatioBreakValue)
 {
     // Start recording clock ticks for training time output
     const clock_t begin_time = clock();
@@ -186,6 +187,7 @@ bool CvCascadeClassifier::train( const string _cascadeDirName,
     cout << "numStages: " << numStages << endl;
     cout << "precalcValBufSize[Mb] : " << _precalcValBufSize << endl;
     cout << "precalcIdxBufSize[Mb] : " << _precalcIdxBufSize << endl;
+    cout << "acceptanceRatioBreakValue : " << acceptanceRatioBreakValue << endl;
     cascadeParams.printAttrs();
     stageParams->printAttrs();
     featureParams->printAttrs();
@@ -208,15 +210,20 @@ bool CvCascadeClassifier::train( const string _cascadeDirName,
         if ( !updateTrainingSet( requiredLeafFARate, tempLeafFARate ) )
         {
             cout << "Train dataset for temp stage can not be filled. "
-                "Branch training terminated." << endl;
+                    "Branch training terminated." << endl;
             break;
         }
         if( tempLeafFARate <= requiredLeafFARate )
         {
             cout << "Required leaf false alarm rate achieved. "
-                 "Branch training terminated." << endl;
+                    "Branch training terminated." << endl;
             break;
         }
+        if( (tempLeafFARate <= acceptanceRatioBreakValue) && (acceptanceRatioBreakValue >= 0) ){
+            cout << "The required acceptanceRatio for the model has been reached to avoid overfitting of trainingdata. "
+                    "Branch training terminated." << endl;
+            break;
+}
 
         CvCascadeBoost* tempStage = new CvCascadeBoost;
         bool isStageTrained = tempStage->train( (CvFeatureEvaluator*)featureEvaluator,
index a4dbf90..2e6bf77 100644 (file)
@@ -96,7 +96,8 @@ public:
                 const CvCascadeParams& _cascadeParams,
                 const CvFeatureParams& _featureParams,
                 const CvCascadeBoostParams& _stageParams,
-                bool baseFormatSave = false );
+                bool baseFormatSave = false,
+                double acceptanceRatioBreakValue = -1.0 );
 private:
     int predict( int sampleIdx );
     void save( const std::string cascadeDirName, bool baseFormat = false );
index b860114..c6a7394 100644 (file)
@@ -17,6 +17,7 @@ int main( int argc, char* argv[] )
     int precalcValBufSize = 1024,
         precalcIdxBufSize = 1024;
     bool baseFormatSave = false;
+    double acceptanceRatioBreakValue = -1.0;
 
     CvCascadeParams cascadeParams;
     CvCascadeBoostParams stageParams;
@@ -37,6 +38,7 @@ int main( int argc, char* argv[] )
         cout << "  [-precalcValBufSize <precalculated_vals_buffer_size_in_Mb = " << precalcValBufSize << ">]" << endl;
         cout << "  [-precalcIdxBufSize <precalculated_idxs_buffer_size_in_Mb = " << precalcIdxBufSize << ">]" << endl;
         cout << "  [-baseFormatSave]" << endl;
+        cout << "  [-acceptanceRatioBreakValue <value> = " << acceptanceRatioBreakValue << ">]" << endl;
         cascadeParams.printDefaults();
         stageParams.printDefaults();
         for( int fi = 0; fi < fc; fi++ )
@@ -83,6 +85,10 @@ int main( int argc, char* argv[] )
         {
             baseFormatSave = true;
         }
+        else if( !strcmp( argv[i], "-acceptanceRatioBreakValue" ) )
+        {
+            acceptanceRatioBreakValue = atof(argv[++i]);
+        }
         else if ( cascadeParams.scanAttr( argv[i], argv[i+1] ) ) { i++; }
         else if ( stageParams.scanAttr( argv[i], argv[i+1] ) ) { i++; }
         else if ( !set )
@@ -108,6 +114,7 @@ int main( int argc, char* argv[] )
                       cascadeParams,
                       *featureParams[cascadeParams.featureType],
                       stageParams,
-                      baseFormatSave );
+                      baseFormatSave,
+                      acceptanceRatioBreakValue );
     return 0;
 }
index 9bf4ae9..49cdbd6 100644 (file)
@@ -294,6 +294,10 @@ Command line arguments of ``opencv_traincascade`` application grouped by purpose
 
         This argument is actual in case of Haar-like features. If it is specified, the cascade will be saved in the old format.
 
+    * ``-acceptanceRatioBreakValue``
+
+        This argument is used to determine how precise your model should keep learning and when to stop. A good guideline is to train not further than 10e-5, to ensure the model does not overtrain on your training data. By default this value is set to -1 to disable this feature.
+
 #.
 
     Cascade parameters: