clear();
vector<int> _layer_sizes;
- fn["layer_sizes"] >> _layer_sizes;
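+ // old files store "layer_sizes" as a Mat node; new files store it as a plain sequence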
+ readVectorOrMat(fn["layer_sizes"], _layer_sizes);
create( _layer_sizes );
int i, l_count = layer_count();
bparams.priors = params0.priors;
FileNode tparams_node = fn["training_params"];
- String bts = (String)tparams_node["boosting_type"];
+ // check for the old layout, where "boosting_type" is stored at the top level
+ String bts = (String)(fn["boosting_type"].empty() ?
+ tparams_node["boosting_type"] : fn["boosting_type"]);
bparams.boostType = (bts == "DiscreteAdaboost" ? Boost::DISCRETE :
bts == "RealAdaboost" ? Boost::REAL :
bts == "LogitBoost" ? Boost::LOGIT :
bts == "GentleAdaboost" ? Boost::GENTLE : -1);
_isClassifier = bparams.boostType == Boost::DISCRETE;
- bparams.weightTrimRate = (double)tparams_node["weight_trimming_rate"];
+ // check for the old layout, where "weight_trimming_rate" is stored at the top level
+ bparams.weightTrimRate = (double)(fn["weight_trimming_rate"].empty() ?
+ tparams_node["weight_trimming_rate"] : fn["weight_trimming_rate"]);
}
void read( const FileNode& fn )
CV_Assert( m > 0 ); // if m==0, vi is an ordered variable
const int* cmap = &catMap.at<int>(ofs[0]);
- bool fastMap = (m == cmap[m] - cmap[0]);
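+ // fast path when the m category values are consecutive; the old expression read
+ // cmap[m], one element past the end of this variable's range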
+ bool fastMap = (m == cmap[m - 1] - cmap[0] + 1);
if( fastMap )
{
{
FileStorage fs(filename, FileStorage::WRITE);
fs << getDefaultModelName() << "{";
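+ // record a format version so readers can tell new files apart from legacy ones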
+ fs << "format" << (int)3;
write(fs);
fs << "}";
}
vector<int> subsets;
vector<int> classLabels;
vector<float> missingSubst;
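+ // maps split variable indices read from file to global variable indices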
+ vector<int> varMapping;
bool _isClassifier;
Ptr<WorkData> w;
};
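+ // Old files store vectors as Mat nodes (MAP type); new files store them as plain
+ // sequences. Read either representation into a std::vector.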
+ template <typename T>
+ static inline void readVectorOrMat(const FileNode & node, std::vector<T> & v)
+ {
+ if (node.type() == FileNode::MAP)
+ {
+ Mat m;
+ node >> m;
+ m.copyTo(v);
+ }
+ else if (node.type() == FileNode::SEQ)
+ {
+ node >> v;
+ }
+ }
+
}}
#endif /* __OPENCV_ML_PRECOMP_HPP__ */
oobError = (double)fn["oob_error"];
int ntrees = (int)fn["ntrees"];
- fn["var_importance"] >> varImportance;
+ readVectorOrMat(fn["var_importance"], varImportance);
readParams(fn);
{
Params _params;
- String svm_type_str = (String)fn["svmType"];
+ // check for the old key name "svm_type"; newer files use "svmType"
+ String svm_type_str = (String)(fn["svm_type"].empty() ? fn["svmType"] : fn["svm_type"]);
int svmType =
svm_type_str == "C_SVC" ? C_SVC :
svm_type_str == "NU_SVC" ? NU_SVC :
fs << "}";
if( !varIdx.empty() )
+ {
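+ // flag that split nodes store global variable indices rather than offsets into var_idx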
+ fs << "global_var_idx" << 1;
fs << "var_idx" << varIdx;
+ }
fs << "var_type" << varType;
if( !tparams_node.empty() ) // training parameters are not necessary
{
params0.useSurrogates = (int)tparams_node["use_surrogates"] != 0;
- params0.maxCategories = (int)tparams_node["max_categories"];
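+ // some legacy files do not store "max_categories"; fall back to 16 when it is absent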
+ params0.maxCategories = (int)(tparams_node["max_categories"].empty() ? 16 : tparams_node["max_categories"]);
params0.regressionAccuracy = (float)tparams_node["regression_accuracy"];
-
params0.maxDepth = (int)tparams_node["max_depth"];
params0.minSampleCount = (int)tparams_node["min_sample_count"];
params0.CVFolds = (int)tparams_node["cross_validation_folds"];
tparams_node["priors"] >> params0.priors;
}
- fn["var_idx"] >> varIdx;
+ readVectorOrMat(fn["var_idx"], varIdx);
fn["var_type"] >> varType;
- fn["cat_ofs"] >> catOfs;
- fn["cat_map"] >> catMap;
- fn["missing_subst"] >> missingSubst;
- fn["class_labels"] >> classLabels;
+ int format = 0;
+ fn["format"] >> format;
+ bool isLegacy = format < 3;
+
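+ // legacy files store var_type only for the active variables and without the response;
+ // extend it to cover all var_all input variables plus the response column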
+ int varAll = (int)fn["var_all"];
+ if (isLegacy && (int)varType.size() <= varAll)
+ {
+ std::vector<uchar> extendedTypes(varAll + 1, 0);
+
+ int i = 0, n;
+ if (!varIdx.empty())
+ {
+ n = (int)varIdx.size();
+ for (; i < n; ++i)
+ {
+ int var = varIdx[i];
+ extendedTypes[var] = varType[i];
+ }
+ }
+ else
+ {
+ n = (int)varType.size();
+ for (; i < n; ++i)
+ {
+ extendedTypes[i] = varType[i];
+ }
+ }
+ extendedTypes[varAll] = (uchar)(_isClassifier ? VAR_CATEGORICAL : VAR_ORDERED);
+ extendedTypes.swap(varType);
+ }
+
+ readVectorOrMat(fn["cat_map"], catMap);
+
+ if (isLegacy)
+ {
+ // rebuild catOfs from the legacy "cat_count" entries
+ catOfs.clear();
+ classLabels.clear();
+ std::vector<int> counts;
+ readVectorOrMat(fn["cat_count"], counts);
+ unsigned int i = 0, j = 0, curShift = 0, size = (int)varType.size() - 1;
+ for (; i < size; ++i)
+ {
+ Vec2i newOffsets(0, 0);
+ if (varType[i] == VAR_CATEGORICAL) // only categorical vars are represented in catMap
+ {
+ newOffsets[0] = curShift;
+ curShift += counts[j];
+ newOffsets[1] = curShift;
+ ++j;
+ }
+ catOfs.push_back(newOffsets);
+ }
+ // the remaining elements of catMap are the class labels
+ if (curShift < catMap.size())
+ {
+ classLabels.insert(classLabels.end(), catMap.begin() + curShift, catMap.end());
+ catMap.erase(catMap.begin() + curShift, catMap.end());
+ }
+ }
+ else
+ {
+ fn["cat_ofs"] >> catOfs;
+ fn["missing_subst"] >> missingSubst;
+ fn["class_labels"] >> classLabels;
+ }
+
+ // init the mapping used when reading split nodes: new files store global variable
+ // indices, old files store indices into varIdx
+ bool globalVarIdx = false;
+ fn["global_var_idx"] >> globalVarIdx;
+ if (globalVarIdx || varIdx.empty())
+ setRangeVector(varMapping, (int)varType.size());
+ else
+ varMapping = varIdx;
initCompVarIdx();
setDParams(params0);
int vi = (int)fn["var"];
CV_Assert( 0 <= vi && vi <= (int)varType.size() );
+ vi = varMapping[vi]; // map a varIdx-relative index (old layout) to the global variable index
split.varIdx = vi;
if( varType[vi] == VAR_CATEGORICAL ) // split on categorical var
TEST(ML_RTrees, save_load) { CV_SLMLTest test( CV_RTREES ); test.safe_run(); }
TEST(DISABLED_ML_ERTrees, save_load) { CV_SLMLTest test( CV_ERTREES ); test.safe_run(); }
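+ // Loads model files written by the old ml module from the "legacy" test data folder
+ // and runs predict() on random input to make sure the old layouts can still be read.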
+class CV_LegacyTest : public cvtest::BaseTest
+{
+public:
+ CV_LegacyTest(const std::string &_modelName, const std::string &_suffixes = std::string())
+ : cvtest::BaseTest(), modelName(_modelName), suffixes(_suffixes)
+ {
+ }
+ virtual ~CV_LegacyTest() {}
+protected:
+ void run(int)
+ {
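+ // 'suffixes' holds a ';'-separated list of model file suffixes to test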
+ unsigned int idx = 0;
+ for (;;)
+ {
+ if (idx >= suffixes.size())
+ break;
+ int found = (int)suffixes.find(';', idx);
+ string piece = suffixes.substr(idx, found - idx);
+ if (piece.empty())
+ break;
+ oneTest(piece);
+ idx += (unsigned int)piece.size() + 1;
+ }
+ }
+ void oneTest(const string & suffix)
+ {
+ using namespace cv::ml;
+
+ int code = cvtest::TS::OK;
+ string filename = ts->get_data_path() + "legacy/" + modelName + suffix;
+ bool isTree = modelName == CV_BOOST || modelName == CV_DTREE || modelName == CV_RTREES;
+ Ptr<StatModel> model;
+ if (modelName == CV_BOOST)
+ model = StatModel::load<Boost>(filename);
+ else if (modelName == CV_ANN)
+ model = StatModel::load<ANN_MLP>(filename);
+ else if (modelName == CV_DTREE)
+ model = StatModel::load<DTrees>(filename);
+ else if (modelName == CV_NBAYES)
+ model = StatModel::load<NormalBayesClassifier>(filename);
+ else if (modelName == CV_SVM)
+ model = StatModel::load<SVM>(filename);
+ else if (modelName == CV_RTREES)
+ model = StatModel::load<RTrees>(filename);
+ if (!model)
+ {
+ code = cvtest::TS::FAIL_INVALID_TEST_DATA;
+ }
+ else
+ {
+ Mat input = Mat(isTree ? 10 : 1, model->getVarCount(), CV_32F);
+ ts->get_rng().fill(input, RNG::UNIFORM, 0, 40);
+
+ if (isTree)
+ randomFillCategories(filename, input);
+
+ Mat output;
+ model->predict(input, output, StatModel::RAW_OUTPUT | (isTree ? DTrees::PREDICT_SUM : 0));
+ // just check that no internal assertions or errors are thrown
+ }
+ ts->set_failed_test_info(code);
+ }
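+ // fill categorical columns with labels taken from the model's cat_map so that
+ // the trees see valid category values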
+ void randomFillCategories(const string & filename, Mat & input)
+ {
+ Mat catMap;
+ Mat catCount;
+ std::vector<uchar> varTypes;
+
+ FileStorage fs(filename, FileStorage::READ);
+ FileNode root = fs.getFirstTopLevelNode();
+ root["cat_map"] >> catMap;
+ root["cat_count"] >> catCount;
+ root["var_type"] >> varTypes;
+
+ int offset = 0;
+ int countOffset = 0;
+ uint var = 0, varCount = (uint)varTypes.size();
+ for (; var < varCount; ++var)
+ {
+ if (varTypes[var] == ml::VAR_CATEGORICAL)
+ {
+ int size = catCount.at<int>(0, countOffset);
+ for (int row = 0; row < input.rows; ++row)
+ {
+ int randomChosenIndex = offset + ((uint)ts->get_rng()) % size;
+ int value = catMap.at<int>(0, randomChosenIndex);
+ input.at<float>(row, var) = (float)value;
+ }
+ offset += size;
+ ++countOffset;
+ }
+ }
+ }
+ string modelName;
+ string suffixes;
+};
+
+TEST(ML_ANN, legacy_load) { CV_LegacyTest test(CV_ANN, "_waveform.xml"); test.safe_run(); }
+TEST(ML_Boost, legacy_load) { CV_LegacyTest test(CV_BOOST, "_adult.xml;_1.xml;_2.xml;_3.xml"); test.safe_run(); }
+TEST(ML_DTree, legacy_load) { CV_LegacyTest test(CV_DTREE, "_abalone.xml;_mushroom.xml"); test.safe_run(); }
+TEST(ML_NBayes, legacy_load) { CV_LegacyTest test(CV_NBAYES, "_waveform.xml"); test.safe_run(); }
+TEST(ML_SVM, legacy_load) { CV_LegacyTest test(CV_SVM, "_poletelecomm.xml;_waveform.xml"); test.safe_run(); }
+TEST(ML_RTrees, legacy_load) { CV_LegacyTest test(CV_RTREES, "_waveform.xml"); test.safe_run(); }
/*TEST(ML_SVM, throw_exception_when_save_untrained_model)
{