ml: fix legacy import in DTreesImpl
authorAlexander Alekhin <alexander.a.alekhin@gmail.com>
Mon, 12 Apr 2021 19:05:52 +0000 (19:05 +0000)
committerAlexander Alekhin <alexander.a.alekhin@gmail.com>
Mon, 12 Apr 2021 19:21:48 +0000 (19:21 +0000)
modules/ml/src/boost.cpp
modules/ml/src/tree.cpp

index 58f572b..be9c9a7 100644 (file)
@@ -490,7 +490,7 @@ public:
 
     float predict( InputArray samples, OutputArray results, int flags ) const CV_OVERRIDE
     {
-        CV_Assert( samples.cols() == getVarCount() && samples.type() == CV_32F );
+        CV_CheckEQ(samples.cols(), getVarCount(), "");
         return impl.predict(samples, results, flags);
     }
 
index 87181b1..1f82ff5 100644 (file)
@@ -43,6 +43,8 @@
 #include "precomp.hpp"
 #include <ctype.h>
 
+#include <opencv2/core/utils/logger.hpp>
+
 namespace cv {
 namespace ml {
 
@@ -1694,9 +1696,9 @@ void DTreesImpl::write( FileStorage& fs ) const
 void DTreesImpl::readParams( const FileNode& fn )
 {
     _isClassifier = (int)fn["is_classifier"] != 0;
-    /*int var_all = (int)fn["var_all"];
-    int var_count = (int)fn["var_count"];
-    int cat_var_count = (int)fn["cat_var_count"];
+    int varAll = (int)fn["var_all"];
+    int varCount = (int)fn["var_count"];
+    /*int cat_var_count = (int)fn["cat_var_count"];
     int ord_var_count = (int)fn["ord_var_count"];*/
 
     FileNode tparams_node = fn["training_params"];
@@ -1723,11 +1725,38 @@ void DTreesImpl::readParams( const FileNode& fn )
     readVectorOrMat(fn["var_idx"], varIdx);
     fn["var_type"] >> varType;
 
-    int format = 0;
-    fn["format"] >> format;
-    bool isLegacy = format < 3;
+    bool isLegacy = false;
+    if (fn["format"].empty())  // Export bug until OpenCV 3.2: https://github.com/opencv/opencv/pull/6314
+    {
+        if (!fn["cat_ofs"].empty())
+            isLegacy = false;  // 2.4 doesn't store "cat_ofs"
+        else if (!fn["missing_subst"].empty())
+            isLegacy = false;  // 2.4 doesn't store "missing_subst"
+        else if (!fn["class_labels"].empty())
+            isLegacy = false;  // 2.4 doesn't store "class_labels"
+        else if ((int)varType.size() != varAll)
+            isLegacy = true;  // 3.0+: https://github.com/opencv/opencv/blame/3.0.0/modules/ml/src/tree.cpp#L1576
+        else if (/*(int)varType.size() == varAll &&*/ varCount == varAll)
+            isLegacy = true;
+        else
+        {
+            // 3.0+:
+            // - https://github.com/opencv/opencv/blame/3.0.0/modules/ml/src/tree.cpp#L1552-L1553
+            // - https://github.com/opencv/opencv/blame/3.0.0/modules/ml/src/precomp.hpp#L296
+            isLegacy = !(varCount + 1 == varAll);
+        }
+        CV_LOG_INFO(NULL, "ML/DTrees: possible missing 'format' field due to bug of OpenCV export implementation. "
+                "Details: https://github.com/opencv/opencv/issues/5412. Consider re-exporting of saved ML model. "
+                "isLegacy = " << isLegacy);
+    }
+    else
+    {
+        int format = 0;
+        fn["format"] >> format;
+        CV_CheckGT(format, 0, "");
+        isLegacy = format < 3;
+    }
 
-    int varAll = (int)fn["var_all"];
     if (isLegacy && (int)varType.size() <= varAll)
     {
         std::vector<uchar> extendedTypes(varAll + 1, 0);