added load/save feature for SVM classifier in letter_recog sample
authorVadim Pisarevsky <vadim.pisarevsky@gmail.com>
Tue, 19 Mar 2013 14:41:38 +0000 (18:41 +0400)
committerVadim Pisarevsky <vadim.pisarevsky@gmail.com>
Tue, 19 Mar 2013 14:41:38 +0000 (18:41 +0400)
samples/cpp/letter_recog.cpp

index 144dbe8..74d5971 100644 (file)
@@ -131,7 +131,7 @@ int build_rtrees_classifier( char* data_filename,
             printf( "Could not read the classifier %s\n", filename_to_load );
             return -1;
         }
-        printf( "The classifier %s is loaded.\n", data_filename );
+        printf( "The classifier %s is loaded.\n", filename_to_load );
     }
     else
     {
@@ -262,7 +262,7 @@ int build_boost_classifier( char* data_filename,
             printf( "Could not read the classifier %s\n", filename_to_load );
             return -1;
         }
-        printf( "The classifier %s is loaded.\n", data_filename );
+        printf( "The classifier %s is loaded.\n", filename_to_load );
     }
     else
     {
@@ -403,7 +403,7 @@ int build_mlp_classifier( char* data_filename,
             printf( "Could not read the classifier %s\n", filename_to_load );
             return -1;
         }
-        printf( "The classifier %s is loaded.\n", data_filename );
+        printf( "The classifier %s is loaded.\n", filename_to_load );
     }
     else
     {
@@ -639,10 +639,11 @@ int build_nbayes_classifier( char* data_filename )
 }
 
 static
-int build_svm_classifier( char* data_filename )
+int build_svm_classifier( char* data_filename, const char* filename_to_save, const char* filename_to_load )
 {
     CvMat* data = 0;
     CvMat* responses = 0;
+    CvMat* train_resp = 0;
     CvMat train_data;
     int nsamples_all = 0, ntrain_samples = 0;
     int var_count;
@@ -666,13 +667,29 @@ int build_svm_classifier( char* data_filename )
     ntrain_samples = (int)(nsamples_all*0.1);
     var_count = data->cols;
 
-    // train classifier
-    printf( "Training the classifier (may take a few minutes)...\n");
-    cvGetRows( data, &train_data, 0, ntrain_samples );
-    CvMat* train_resp = cvCreateMat( ntrain_samples, 1, CV_32FC1);
-    for (int i = 0; i < ntrain_samples; i++)
-        train_resp->data.fl[i] = responses->data.fl[i];
-    svm.train(&train_data, train_resp, 0, 0, param);
+    // Create or load Random Trees classifier
+    if( filename_to_load )
+    {
+        // load classifier from the specified file
+        svm.load( filename_to_load );
+        ntrain_samples = 0;
+        if( svm.get_var_count() == 0 )
+        {
+            printf( "Could not read the classifier %s\n", filename_to_load );
+            return -1;
+        }
+        printf( "The classifier %s is loaded.\n", filename_to_load );
+    }
+    else
+    {
+        // train classifier
+        printf( "Training the classifier (may take a few minutes)...\n");
+        cvGetRows( data, &train_data, 0, ntrain_samples );
+        train_resp = cvCreateMat( ntrain_samples, 1, CV_32FC1);
+        for (int i = 0; i < ntrain_samples; i++)
+            train_resp->data.fl[i] = responses->data.fl[i];
+        svm.train(&train_data, train_resp, 0, 0, param);
+    }
 
     // classification
     std::vector<float> _sample(var_count * (nsamples_all - ntrain_samples));
@@ -705,6 +722,9 @@ int build_svm_classifier( char* data_filename )
 
     printf("true_resp = %f%%\n", (float)true_resp / (nsamples_all - ntrain_samples) * 100);
 
+    if( filename_to_save )
+        svm.save( filename_to_save );
+
     cvReleaseMat( &train_resp );
     cvReleaseMat( &result );
     cvReleaseMat( &data );
@@ -775,7 +795,7 @@ int main( int argc, char *argv[] )
         method == 4 ?
         build_nbayes_classifier( data_filename) :
         method == 5 ?
-        build_svm_classifier( data_filename ):
+        build_svm_classifier( data_filename, filename_to_save, filename_to_load ):
         -1) < 0)
     {
         help();