Support loading old models in ML module
authorMaksim Shabunin <maksim.shabunin@itseez.com>
Tue, 16 Dec 2014 15:15:50 +0000 (18:15 +0300)
committerMaksim Shabunin <maksim.shabunin@itseez.com>
Wed, 31 Dec 2014 09:16:25 +0000 (12:16 +0300)
- added test for loading legacy files
- added version to new written models
- fixed loading of several fields in some models
- added generation of new fields from old data

modules/ml/src/ann_mlp.cpp
modules/ml/src/boost.cpp
modules/ml/src/data.cpp
modules/ml/src/inner_functions.cpp
modules/ml/src/precomp.hpp
modules/ml/src/rtrees.cpp
modules/ml/src/svm.cpp
modules/ml/src/tree.cpp
modules/ml/test/test_save_load.cpp

index 3e7d44e..ef52801 100644 (file)
@@ -1241,7 +1241,7 @@ public:
         clear();
 
         vector<int> _layer_sizes;
-        fn["layer_sizes"] >> _layer_sizes;
+        readVectorOrMat(fn["layer_sizes"], _layer_sizes);
         create( _layer_sizes );
 
         int i, l_count = layer_count();
index 5e0b307..236cd97 100644 (file)
@@ -434,13 +434,17 @@ public:
         bparams.priors = params0.priors;
 
         FileNode tparams_node = fn["training_params"];
-        String bts = (String)tparams_node["boosting_type"];
+        // check for old layout
+        String bts = (String)(fn["boosting_type"].empty() ?
+                         tparams_node["boosting_type"] : fn["boosting_type"]);
         bparams.boostType = (bts == "DiscreteAdaboost" ? Boost::DISCRETE :
                              bts == "RealAdaboost" ? Boost::REAL :
                              bts == "LogitBoost" ? Boost::LOGIT :
                              bts == "GentleAdaboost" ? Boost::GENTLE : -1);
         _isClassifier = bparams.boostType == Boost::DISCRETE;
-        bparams.weightTrimRate = (double)tparams_node["weight_trimming_rate"];
+        // check for old layout
+        bparams.weightTrimRate = (double)(fn["weight_trimming_rate"].empty() ?
+                                    tparams_node["weight_trimming_rate"] : fn["weight_trimming_rate"]);
     }
 
     void read( const FileNode& fn )
index 6b5ceb4..d2ac18f 100644 (file)
@@ -898,7 +898,7 @@ public:
 
         CV_Assert( m > 0 ); // if m==0, vi is an ordered variable
         const int* cmap = &catMap.at<int>(ofs[0]);
-        bool fastMap = (m == cmap[m] - cmap[0]);
+        bool fastMap = (m == cmap[m - 1] - cmap[0] + 1);
 
         if( fastMap )
         {
index dbc21ff..561abba 100644 (file)
@@ -115,6 +115,7 @@ void StatModel::save(const String& filename) const
 {
     FileStorage fs(filename, FileStorage::WRITE);
     fs << getDefaultModelName() << "{";
+    fs << "format" << (int)3;
     write(fs);
     fs << "}";
 }
index d308ae9..69ff030 100644 (file)
@@ -263,11 +263,27 @@ namespace ml
         vector<int> subsets;
         vector<int> classLabels;
         vector<float> missingSubst;
+        vector<int> varMapping;
         bool _isClassifier;
 
         Ptr<WorkData> w;
     };
 
+    template <typename T>
+    static inline void readVectorOrMat(const FileNode & node, std::vector<T> & v)
+    {
+        if (node.type() == FileNode::MAP)
+        {
+            Mat m;
+            node >> m;
+            m.copyTo(v);
+        }
+        else if (node.type() == FileNode::SEQ)
+        {
+            node >> v;
+        }
+    }
+
 }}
 
 #endif /* __OPENCV_ML_PRECOMP_HPP__ */
index 7c9cbaf..7441faa 100644 (file)
@@ -346,7 +346,7 @@ public:
         oobError = (double)fn["oob_error"];
         int ntrees = (int)fn["ntrees"];
 
-        fn["var_importance"] >> varImportance;
+        readVectorOrMat(fn["var_importance"], varImportance);
 
         readParams(fn);
 
index c7c32f0..a0df44f 100644 (file)
@@ -2038,7 +2038,8 @@ public:
     {
         Params _params;
 
-        String svm_type_str = (String)fn["svmType"];
+        // check for old naming
+        String svm_type_str = (String)(fn["svm_type"].empty() ? fn["svmType"] : fn["svm_type"]);
         int svmType =
             svm_type_str == "C_SVC" ? C_SVC :
             svm_type_str == "NU_SVC" ? NU_SVC :
index 416abd9..64f6616 100644 (file)
@@ -1597,7 +1597,10 @@ void DTreesImpl::writeParams(FileStorage& fs) const
     fs << "}";
 
     if( !varIdx.empty() )
+    {
+        fs << "global_var_idx" << 1;
         fs << "var_idx" << varIdx;
+    }
 
     fs << "var_type" << varType;
 
@@ -1726,9 +1729,8 @@ void DTreesImpl::readParams( const FileNode& fn )
     if( !tparams_node.empty() ) // training parameters are not necessary
     {
         params0.useSurrogates = (int)tparams_node["use_surrogates"] != 0;
-        params0.maxCategories = (int)tparams_node["max_categories"];
+        params0.maxCategories = (int)(tparams_node["max_categories"].empty() ? 16 : tparams_node["max_categories"]);
         params0.regressionAccuracy = (float)tparams_node["regression_accuracy"];
-
         params0.maxDepth = (int)tparams_node["max_depth"];
         params0.minSampleCount = (int)tparams_node["min_sample_count"];
         params0.CVFolds = (int)tparams_node["cross_validation_folds"];
@@ -1741,13 +1743,83 @@ void DTreesImpl::readParams( const FileNode& fn )
         tparams_node["priors"] >> params0.priors;
     }
 
-    fn["var_idx"] >> varIdx;
+    readVectorOrMat(fn["var_idx"], varIdx);
     fn["var_type"] >> varType;
 
-    fn["cat_ofs"] >> catOfs;
-    fn["cat_map"] >> catMap;
-    fn["missing_subst"] >> missingSubst;
-    fn["class_labels"] >> classLabels;
+    int format = 0;
+    fn["format"] >> format;
+    bool isLegacy = format < 3;
+
+    int varAll = (int)fn["var_all"];
+    if (isLegacy && (int)varType.size() <= varAll)
+    {
+        std::vector<uchar> extendedTypes(varAll + 1, 0);
+
+        int i = 0, n;
+        if (!varIdx.empty())
+        {
+            n = (int)varIdx.size();
+            for (; i < n; ++i)
+            {
+                int var = varIdx[i];
+                extendedTypes[var] = varType[i];
+            }
+        }
+        else
+        {
+            n = (int)varType.size();
+            for (; i < n; ++i)
+            {
+                extendedTypes[i] = varType[i];
+            }
+        }
+        extendedTypes[varAll] = (uchar)(_isClassifier ? VAR_CATEGORICAL : VAR_ORDERED);
+        extendedTypes.swap(varType);
+    }
+
+    readVectorOrMat(fn["cat_map"], catMap);
+
+    if (isLegacy)
+    {
+        // generating "catOfs" from "cat_count"
+        catOfs.clear();
+        classLabels.clear();
+        std::vector<int> counts;
+        readVectorOrMat(fn["cat_count"], counts);
+        unsigned int i = 0, j = 0, curShift = 0, size = (int)varType.size() - 1;
+        for (; i < size; ++i)
+        {
+            Vec2i newOffsets(0, 0);
+            if (varType[i] == VAR_CATEGORICAL) // only categorical vars are represented in catMap
+            {
+                newOffsets[0] = curShift;
+                curShift += counts[j];
+                newOffsets[1] = curShift;
+                ++j;
+            }
+            catOfs.push_back(newOffsets);
+        }
+        // other elements in "catMap" are "classLabels"
+        if (curShift < catMap.size())
+        {
+            classLabels.insert(classLabels.end(), catMap.begin() + curShift, catMap.end());
+            catMap.erase(catMap.begin() + curShift, catMap.end());
+        }
+    }
+    else
+    {
+        fn["cat_ofs"] >> catOfs;
+        fn["missing_subst"] >> missingSubst;
+        fn["class_labels"] >> classLabels;
+    }
+
+    // init var mapping for node reading (var indexes or varIdx indexes)
+    bool globalVarIdx = false;
+    fn["global_var_idx"] >> globalVarIdx;
+    if (globalVarIdx || varIdx.empty())
+        setRangeVector(varMapping, (int)varType.size());
+    else
+        varMapping = varIdx;
 
     initCompVarIdx();
     setDParams(params0);
@@ -1759,6 +1831,7 @@ int DTreesImpl::readSplit( const FileNode& fn )
 
     int vi = (int)fn["var"];
     CV_Assert( 0 <= vi && vi <= (int)varType.size() );
+    vi = varMapping[vi]; // convert to varIdx if needed
     split.varIdx = vi;
 
     if( varType[vi] == VAR_CATEGORICAL ) // split on categorical var
index bef2fd0..74e8eef 100644 (file)
@@ -158,6 +158,109 @@ TEST(ML_Boost, save_load) { CV_SLMLTest test( CV_BOOST ); test.safe_run(); }
 TEST(ML_RTrees, save_load) { CV_SLMLTest test( CV_RTREES ); test.safe_run(); }
 TEST(DISABLED_ML_ERTrees, save_load) { CV_SLMLTest test( CV_ERTREES ); test.safe_run(); }
 
+class CV_LegacyTest : public cvtest::BaseTest
+{
+public:
+    CV_LegacyTest(const std::string &_modelName, const std::string &_suffixes = std::string())
+        : cvtest::BaseTest(), modelName(_modelName), suffixes(_suffixes)
+    {
+    }
+    virtual ~CV_LegacyTest() {}
+protected:
+    void run(int)
+    {
+        unsigned int idx = 0;
+        for (;;)
+        {
+            if (idx >= suffixes.size())
+                break;
+            int found = (int)suffixes.find(';', idx);
+            string piece = suffixes.substr(idx, found - idx);
+            if (piece.empty())
+                break;
+            oneTest(piece);
+            idx += (unsigned int)piece.size() + 1;
+        }
+    }
+    void oneTest(const string & suffix)
+    {
+        using namespace cv::ml;
+
+        int code = cvtest::TS::OK;
+        string filename = ts->get_data_path() + "legacy/" + modelName + suffix;
+        bool isTree = modelName == CV_BOOST || modelName == CV_DTREE || modelName == CV_RTREES;
+        Ptr<StatModel> model;
+        if (modelName == CV_BOOST)
+            model = StatModel::load<Boost>(filename);
+        else if (modelName == CV_ANN)
+            model = StatModel::load<ANN_MLP>(filename);
+        else if (modelName == CV_DTREE)
+            model = StatModel::load<DTrees>(filename);
+        else if (modelName == CV_NBAYES)
+            model = StatModel::load<NormalBayesClassifier>(filename);
+        else if (modelName == CV_SVM)
+            model = StatModel::load<SVM>(filename);
+        else if (modelName == CV_RTREES)
+            model = StatModel::load<RTrees>(filename);
+        if (!model)
+        {
+            code = cvtest::TS::FAIL_INVALID_TEST_DATA;
+        }
+        else
+        {
+            Mat input = Mat(isTree ? 10 : 1, model->getVarCount(), CV_32F);
+            ts->get_rng().fill(input, RNG::UNIFORM, 0, 40);
+
+            if (isTree)
+                randomFillCategories(filename, input);
+
+            Mat output;
+            model->predict(input, output, StatModel::RAW_OUTPUT | (isTree ? DTrees::PREDICT_SUM : 0));
+            // just check if no internal assertions or errors thrown
+        }
+        ts->set_failed_test_info(code);
+    }
+    void randomFillCategories(const string & filename, Mat & input)
+    {
+        Mat catMap;
+        Mat catCount;
+        std::vector<uchar> varTypes;
+
+        FileStorage fs(filename, FileStorage::READ);
+        FileNode root = fs.getFirstTopLevelNode();
+        root["cat_map"] >> catMap;
+        root["cat_count"] >> catCount;
+        root["var_type"] >> varTypes;
+
+        int offset = 0;
+        int countOffset = 0;
+        uint var = 0, varCount = (uint)varTypes.size();
+        for (; var < varCount; ++var)
+        {
+            if (varTypes[var] == ml::VAR_CATEGORICAL)
+            {
+                int size = catCount.at<int>(0, countOffset);
+                for (int row = 0; row < input.rows; ++row)
+                {
+                    int randomChosenIndex = offset + ((uint)ts->get_rng()) % size;
+                    int value = catMap.at<int>(0, randomChosenIndex);
+                    input.at<float>(row, var) = (float)value;
+                }
+                offset += size;
+                ++countOffset;
+            }
+        }
+    }
+    string modelName;
+    string suffixes;
+};
+
+TEST(ML_ANN, legacy_load) { CV_LegacyTest test(CV_ANN, "_waveform.xml"); test.safe_run(); }
+TEST(ML_Boost, legacy_load) { CV_LegacyTest test(CV_BOOST, "_adult.xml;_1.xml;_2.xml;_3.xml"); test.safe_run(); }
+TEST(ML_DTree, legacy_load) { CV_LegacyTest test(CV_DTREE, "_abalone.xml;_mushroom.xml"); test.safe_run(); }
+TEST(ML_NBayes, legacy_load) { CV_LegacyTest test(CV_NBAYES, "_waveform.xml"); test.safe_run(); }
+TEST(ML_SVM, legacy_load) { CV_LegacyTest test(CV_SVM, "_poletelecomm.xml;_waveform.xml"); test.safe_run(); }
+TEST(ML_RTrees, legacy_load) { CV_LegacyTest test(CV_RTREES, "_waveform.xml"); test.safe_run(); }
 
 /*TEST(ML_SVM, throw_exception_when_save_untrained_model)
 {