From b9b19185bcf603e27853260dc1d9c763203e4b2c Mon Sep 17 00:00:00 2001 From: Alexander Alekhin Date: Mon, 12 Apr 2021 19:05:52 +0000 Subject: [PATCH] ml: fix legacy import in DTreesImpl --- modules/ml/src/boost.cpp | 2 +- modules/ml/src/tree.cpp | 43 ++++++++++++++++++++++++++++++++++++------- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/modules/ml/src/boost.cpp b/modules/ml/src/boost.cpp index 58f572b..be9c9a7 100644 --- a/modules/ml/src/boost.cpp +++ b/modules/ml/src/boost.cpp @@ -490,7 +490,7 @@ public: float predict( InputArray samples, OutputArray results, int flags ) const CV_OVERRIDE { - CV_Assert( samples.cols() == getVarCount() && samples.type() == CV_32F ); + CV_CheckEQ(samples.cols(), getVarCount(), ""); return impl.predict(samples, results, flags); } diff --git a/modules/ml/src/tree.cpp b/modules/ml/src/tree.cpp index 87181b1..1f82ff5 100644 --- a/modules/ml/src/tree.cpp +++ b/modules/ml/src/tree.cpp @@ -43,6 +43,8 @@ #include "precomp.hpp" #include +#include + namespace cv { namespace ml { @@ -1694,9 +1696,9 @@ void DTreesImpl::write( FileStorage& fs ) const void DTreesImpl::readParams( const FileNode& fn ) { _isClassifier = (int)fn["is_classifier"] != 0; - /*int var_all = (int)fn["var_all"]; - int var_count = (int)fn["var_count"]; - int cat_var_count = (int)fn["cat_var_count"]; + int varAll = (int)fn["var_all"]; + int varCount = (int)fn["var_count"]; + /*int cat_var_count = (int)fn["cat_var_count"]; int ord_var_count = (int)fn["ord_var_count"];*/ FileNode tparams_node = fn["training_params"]; @@ -1723,11 +1725,38 @@ void DTreesImpl::readParams( const FileNode& fn ) readVectorOrMat(fn["var_idx"], varIdx); fn["var_type"] >> varType; - int format = 0; - fn["format"] >> format; - bool isLegacy = format < 3; + bool isLegacy = false; + if (fn["format"].empty()) // Export bug until OpenCV 3.2: https://github.com/opencv/opencv/pull/6314 + { + if (!fn["cat_ofs"].empty()) + isLegacy = false; // 2.4 doesn't store "cat_ofs" + else if (!fn["missing_subst"].empty()) + isLegacy = false; // 2.4 doesn't store "missing_subst" + else if (!fn["class_labels"].empty()) + isLegacy = false; // 2.4 doesn't store "class_labels" + else if ((int)varType.size() != varAll) + isLegacy = true; // 3.0+: https://github.com/opencv/opencv/blame/3.0.0/modules/ml/src/tree.cpp#L1576 + else if (/*(int)varType.size() == varAll &&*/ varCount == varAll) + isLegacy = true; + else + { + // 3.0+: + // - https://github.com/opencv/opencv/blame/3.0.0/modules/ml/src/tree.cpp#L1552-L1553 + // - https://github.com/opencv/opencv/blame/3.0.0/modules/ml/src/precomp.hpp#L296 + isLegacy = !(varCount + 1 == varAll); + } + CV_LOG_INFO(NULL, "ML/DTrees: possible missing 'format' field due to bug of OpenCV export implementation. " + "Details: https://github.com/opencv/opencv/issues/5412. Consider re-exporting of saved ML model. " + "isLegacy = " << isLegacy); + } + else + { + int format = 0; + fn["format"] >> format; + CV_CheckGT(format, 0, ""); + isLegacy = format < 3; + } - int varAll = (int)fn["var_all"]; if (isLegacy && (int)varType.size() <= varAll) { std::vector extendedTypes(varAll + 1, 0); -- 2.7.4