ml: update checks
authorAlexander Alekhin <alexander.a.alekhin@gmail.com>
Tue, 13 Apr 2021 11:09:14 +0000 (11:09 +0000)
committerAlexander Alekhin <alexander.a.alekhin@gmail.com>
Tue, 13 Apr 2021 11:09:14 +0000 (11:09 +0000)
modules/ml/src/rtrees.cpp
modules/ml/src/tree.cpp

index 46af37ce11900a0f8a1e25e88b1b8127dcec8ea7..56be5c0e226c9bbd08b271bb0631892eed4a8be1 100644 (file)
@@ -479,7 +479,7 @@ public:
     float predict( InputArray samples, OutputArray results, int flags ) const CV_OVERRIDE
     {
         CV_TRACE_FUNCTION();
-        CV_Assert( samples.cols() == getVarCount() && samples.type() == CV_32F );
+        CV_CheckEQ(samples.cols(), getVarCount(), "");
         return impl.predict(samples, results, flags);
     }
 
index 1f82ff508198a650c51c852d767608745bf9ecd4..5dae889013d92583e901fbbf111c76c947738a87 100644 (file)
@@ -1701,6 +1701,9 @@ void DTreesImpl::readParams( const FileNode& fn )
     /*int cat_var_count = (int)fn["cat_var_count"];
     int ord_var_count = (int)fn["ord_var_count"];*/
 
+    if (varAll <= 0)
+        CV_Error(Error::StsParseError, "The field \"var_all\" of DTree classifier is missing or non-positive");
+
     FileNode tparams_node = fn["training_params"];
 
     TreeParams params0 = TreeParams();