From c0f68ec400015b96d1bf9524118186527c411e4a Mon Sep 17 00:00:00 2001 From: "marina.kolpakova" Date: Wed, 12 Dec 2012 14:20:42 +0400 Subject: [PATCH] add xml serialization --- apps/sft/include/sft/octave.hpp | 4 ++ apps/sft/octave.cpp | 85 +++++++++++++++++++++++++++++++++++++++++ apps/sft/sft.cpp | 34 +++++++++++++++-- apps/traincascade/boost.cpp | 3 ++ 4 files changed, 123 insertions(+), 3 deletions(-) diff --git a/apps/sft/include/sft/octave.hpp b/apps/sft/include/sft/octave.hpp index 29532cf..f5fd788 100644 --- a/apps/sft/include/sft/octave.hpp +++ b/apps/sft/include/sft/octave.hpp @@ -144,6 +144,8 @@ public: virtual float predict( const Mat& _sample, Mat& _votes, bool raw_mode, bool return_sum ) const; virtual void setRejectThresholds(cv::Mat& thresholds); + virtual void write( cv::FileStorage &fs, const Mat& thresholds = Mat()) const; + int logScale; protected: @@ -155,6 +157,8 @@ protected: float predict( const Mat& _sample, const cv::Range range) const; private: + void traverse(const CvBoostTree* tree, cv::FileStorage& fs, const float* th = 0) const; + cv::Rect boundingBox; int npositives; diff --git a/apps/sft/octave.cpp b/apps/sft/octave.cpp index 5c44157..6b601ab 100644 --- a/apps/sft/octave.cpp +++ b/apps/sft/octave.cpp @@ -47,6 +47,8 @@ #include #include +#include + // ============ Octave ============ // sft::Octave::Octave(cv::Rect bb, int np, int nn, int ls, int shr) : logScale(ls), boundingBox(bb), npositives(np), nnegatives(nn), shrinkage(shr) @@ -293,6 +295,89 @@ void sft::Octave::generateNegatives(const Dataset& dataset) dprintf("Processing negatives finished:\n\trequested %d negatives, viewed %d samples.\n", nnegatives, total); } +template int sgn(T val) { + return (T(0) < val) - (val < T(0)); +} + +void sft::Octave::traverse(const CvBoostTree* tree, cv::FileStorage& fs, const float* th) const +{ + std::queue nodes; + nodes.push( tree->get_root()); + const CvDTreeNode* tempNode; + int leafValIdx = 0; + int internalNodeIdx = 1; + float* leafs = new float[(int)pow(2.f, get_params().max_depth)]; + + fs << "{"; + fs << "internalNodes" << "["; + while (!nodes.empty()) + { + tempNode = nodes.front(); + CV_Assert( tempNode->left ); + if ( !tempNode->left->left && !tempNode->left->right) + { + leafs[-leafValIdx] = (float)tempNode->left->value; + fs << leafValIdx-- ; + } + else + { + nodes.push( tempNode->left ); + fs << internalNodeIdx++; + } + CV_Assert( tempNode->right ); + if ( !tempNode->right->left && !tempNode->right->right) + { + leafs[-leafValIdx] = (float)tempNode->right->value; + fs << leafValIdx--; + } + else + { + nodes.push( tempNode->right ); + fs << internalNodeIdx++; + } + int fidx = tempNode->split->var_idx; + fs << fidx; + + fs << tempNode->split->ord.c; + + nodes.pop(); + } + fs << "]"; + + fs << "leafValues" << "["; + for (int ni = 0; ni < -leafValIdx; ni++) + fs << ( (!th) ? leafs[ni] : (sgn(leafs[ni]) * *th)); + fs << "]"; + + fs << "}"; +} + +void sft::Octave::write( cv::FileStorage &fso, const Mat& thresholds) const +{ + fso << "{" + << "scale" << logScale + << "weaks" << weak->total + << "trees" << "["; + // should be replased with the H.L. one + CvSeqReader reader; + cvStartReadSeq( weak, &reader); + + for(int i = 0; i < weak->total; i++ ) + { + CvBoostTree* tree; + CV_READ_SEQ_ELEM( tree, reader ); + + if (!thresholds.empty()) + traverse(tree, fso, thresholds.ptr(0)+ i); + else + traverse(tree, fso); + } + // + + fso << "]" + << "}"; +} + bool sft::Octave::train(const Dataset& dataset, const FeaturePool& pool, int weaks, int treeDepth) { CV_Assert(treeDepth == 2); diff --git a/apps/sft/sft.cpp b/apps/sft/sft.cpp index 869d50c..f3be928 100644 --- a/apps/sft/sft.cpp +++ b/apps/sft/sft.cpp @@ -94,16 +94,41 @@ int main(int argc, char** argv) // 2. check and open output file cv::FileStorage fso(cfg.outXmlPath, cv::FileStorage::WRITE); - if(!fs.isOpened()) + if(!fso.isOpened()) { std::cout << "Training stopped. Output classifier Xml file " << cfg.outXmlPath << " can't be opened." << std::endl << std::flush; return 1; } + cv::FileStorage fsr(cfg.outXmlPath + ".raw.xml" , cv::FileStorage::WRITE); + if(!fsr.isOpened()) + { + std::cout << "Training stopped. Output classifier Xml file " <featureEvaluator->getCls( i ) == 1.0F ) eval[numPos++] = predict( i, true ); + icvSortFlt( &eval[0], numPos, 0 ); + int thresholdIdx = (int)((1.0F - minHitRate) * numPos); + threshold = eval[ thresholdIdx ]; numPosTrue = numPos - thresholdIdx; for( int i = thresholdIdx - 1; i >= 0; i--) -- 2.7.4