Added throwing exception when saving untrained SVM model
authorDaniil Osokin <daniil.osokin@itseez.com>
Mon, 13 Jan 2014 07:41:54 +0000 (11:41 +0400)
committerDaniil Osokin <daniil.osokin@itseez.com>
Mon, 13 Jan 2014 09:50:30 +0000 (13:50 +0400)
modules/ml/src/svm.cpp
modules/ml/test/test_save_load.cpp

index 674365b..f158805 100644 (file)
@@ -2298,14 +2298,24 @@ void CvSVM::write_params( CvFileStorage* fs ) const
 }
 
 
+static bool isSvmModelApplicable(int sv_total, int var_all, int var_count, int class_count)
+{
+    return (sv_total > 0 && var_count > 0 && var_count <= var_all && class_count >= 0);
+}
+
+
 void CvSVM::write( CvFileStorage* fs, const char* name ) const
 {
     CV_FUNCNAME( "CvSVM::write" );
 
     __BEGIN__;
 
-    int i, var_count = get_var_count(), df_count, class_count;
+    int i, var_count = get_var_count(), df_count;
+    int class_count = class_labels ? class_labels->cols :
+                      params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
     const CvSVMDecisionFunc* df = decision_func;
+    if( !isSvmModelApplicable(sv_total, var_all, var_count, class_count) )
+        CV_ERROR( CV_StsParseError, "SVM model data is invalid, check sv_count, var_* and class_count tags" );
 
     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_SVM );
 
@@ -2314,9 +2324,6 @@ void CvSVM::write( CvFileStorage* fs, const char* name ) const
     cvWriteInt( fs, "var_all", var_all );
     cvWriteInt( fs, "var_count", var_count );
 
-    class_count = class_labels ? class_labels->cols :
-                  params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
-
     if( class_count )
     {
         cvWriteInt( fs, "class_count", class_count );
@@ -2454,7 +2461,6 @@ void CvSVM::read_params( CvFileStorage* fs, CvFileNode* svm_node )
     __END__;
 }
 
-
 void CvSVM::read( CvFileStorage* fs, CvFileNode* svm_node )
 {
     const double not_found_dbl = DBL_MAX;
@@ -2483,7 +2489,7 @@ void CvSVM::read( CvFileStorage* fs, CvFileNode* svm_node )
     var_count = cvReadIntByName( fs, svm_node, "var_count", var_all );
     class_count = cvReadIntByName( fs, svm_node, "class_count", 0 );
 
-    if( sv_total <= 0 || var_all <= 0 || var_count <= 0 || var_count > var_all || class_count < 0 )
+    if( !isSvmModelApplicable(sv_total, var_all, var_count, class_count) )
         CV_ERROR( CV_StsParseError, "SVM model data is invalid, check sv_count, var_* and class_count tags" );
 
     CV_CALL( class_labels = (CvMat*)cvReadByName( fs, svm_node, "class_labels" ));
index 9fd31b9..7300185 100644 (file)
@@ -155,6 +155,14 @@ TEST(ML_RTrees, save_load) { CV_SLMLTest test( CV_RTREES ); test.safe_run(); }
 TEST(ML_ERTrees, save_load) { CV_SLMLTest test( CV_ERTREES ); test.safe_run(); }
 
 
+TEST(ML_SVM, throw_exception_when_save_untrained_model)
+{
+    SVM svm;
+    string filename = tempfile("svm.xml");
+    ASSERT_THROW(svm.save(filename.c_str()), Exception);
+    remove(filename.c_str());
+}
+
 TEST(DISABLED_ML_SVM, linear_save_load)
 {
     CvSVM svm1, svm2, svm3;