add xml serialization
authormarina.kolpakova <marina.kolpakova@itseez.com>
Wed, 12 Dec 2012 10:20:42 +0000 (14:20 +0400)
committermarina.kolpakova <marina.kolpakova@itseez.com>
Fri, 1 Feb 2013 10:34:39 +0000 (14:34 +0400)
apps/sft/include/sft/octave.hpp
apps/sft/octave.cpp
apps/sft/sft.cpp
apps/traincascade/boost.cpp

index 29532cf..f5fd788 100644 (file)
@@ -144,6 +144,8 @@ public:
     virtual float predict( const Mat& _sample, Mat& _votes, bool raw_mode, bool return_sum ) const;
     virtual void setRejectThresholds(cv::Mat& thresholds);
 
+    virtual void write( cv::FileStorage &fs, const Mat& thresholds = Mat()) const;
+
     int logScale;
 
 protected:
@@ -155,6 +157,8 @@ protected:
 
     float predict( const Mat& _sample, const cv::Range range) const;
 private:
+    void traverse(const CvBoostTree* tree, cv::FileStorage& fs, const float* th = 0) const;
+
     cv::Rect boundingBox;
 
     int npositives;
index 5c44157..6b601ab 100644 (file)
@@ -47,6 +47,8 @@
 #include <opencv2/imgproc/imgproc.hpp>
 #include <opencv2/highgui/highgui.hpp>
 
+#include <queue>
+
 // ============ Octave ============ //
 sft::Octave::Octave(cv::Rect bb, int np, int nn, int ls, int shr)
 : logScale(ls), boundingBox(bb), npositives(np), nnegatives(nn), shrinkage(shr)
@@ -293,6 +295,89 @@ void sft::Octave::generateNegatives(const Dataset& dataset)
     dprintf("Processing negatives finished:\n\trequested %d negatives, viewed %d samples.\n", nnegatives, total);
 }
 
+template <typename T> int sgn(T val) {
+    return (T(0) < val) - (val < T(0));
+}
+
+void sft::Octave::traverse(const CvBoostTree* tree, cv::FileStorage& fs, const float* th) const
+{
+    std::queue<const CvDTreeNode*> nodes;
+    nodes.push( tree->get_root());
+    const CvDTreeNode* tempNode;
+    int leafValIdx = 0;
+    int internalNodeIdx = 1;
+    float* leafs = new float[(int)pow(2.f, get_params().max_depth)];
+
+    fs << "{";
+    fs << "internalNodes" << "[";
+    while (!nodes.empty())
+    {
+        tempNode = nodes.front();
+        CV_Assert( tempNode->left );
+        if ( !tempNode->left->left && !tempNode->left->right)
+        {
+            leafs[-leafValIdx] = (float)tempNode->left->value;
+            fs << leafValIdx-- ;
+        }
+        else
+        {
+            nodes.push( tempNode->left );
+            fs << internalNodeIdx++;
+        }
+        CV_Assert( tempNode->right );
+        if ( !tempNode->right->left && !tempNode->right->right)
+        {
+            leafs[-leafValIdx] = (float)tempNode->right->value;
+            fs << leafValIdx--;
+        }
+        else
+        {
+            nodes.push( tempNode->right );
+            fs << internalNodeIdx++;
+        }
+        int fidx = tempNode->split->var_idx;
+        fs << fidx;
+
+        fs << tempNode->split->ord.c;
+
+        nodes.pop();
+    }
+    fs << "]";
+
+    fs << "leafValues" << "[";
+    for (int ni = 0; ni < -leafValIdx; ni++)
+        fs << ( (!th) ? leafs[ni] : (sgn(leafs[ni]) * *th));
+    fs << "]";
+
+    fs << "}";
+}
+
+void sft::Octave::write( cv::FileStorage &fso, const Mat& thresholds) const
+{
+    fso << "{"
+        << "scale" << logScale
+        << "weaks" << weak->total
+        << "trees" << "[";
+        // should be replased with the H.L. one
+        CvSeqReader reader;
+        cvStartReadSeq( weak, &reader);
+
+        for(int i = 0; i < weak->total; i++ )
+        {
+            CvBoostTree* tree;
+            CV_READ_SEQ_ELEM( tree, reader );
+
+            if (!thresholds.empty())
+                traverse(tree, fso, thresholds.ptr<float>(0)+ i);
+            else
+                traverse(tree, fso);
+        }
+        //
+
+    fso << "]"
+        << "}";
+}
+
 bool sft::Octave::train(const Dataset& dataset, const FeaturePool& pool, int weaks, int treeDepth)
 {
     CV_Assert(treeDepth == 2);
index 869d50c..f3be928 100644 (file)
@@ -94,16 +94,41 @@ int main(int argc, char** argv)
 
     // 2. check and open output file
     cv::FileStorage fso(cfg.outXmlPath, cv::FileStorage::WRITE);
-    if(!fs.isOpened())
+    if(!fso.isOpened())
     {
         std::cout << "Training stopped. Output classifier Xml file " << cfg.outXmlPath << " can't be opened." << std::endl << std::flush;
         return 1;
     }
 
+    cv::FileStorage fsr(cfg.outXmlPath + ".raw.xml" , cv::FileStorage::WRITE);
+    if(!fsr.isOpened())
+    {
+        std::cout << "Training stopped. Output classifier Xml file " <<cfg.outXmlPath + ".raw.xml" << " can't be opened." << std::endl << std::flush;
+        return 1;
+    }
+
     // ovector strong;
     // strong.reserve(cfg.octaves.size());
 
-    // fso << "softcascade" << "{" << "octaves" << "[";
+    fso << cfg.cascadeName
+        << "{"
+        << "stageType"   << "BOOST"
+        << "featureType" << "ICF"
+        << "octavesNum"  << (int)cfg.octaves.size()
+        << "width"       << cfg.modelWinSize.width
+        << "height"      << cfg.modelWinSize.height
+        << "shrinkage"   << cfg.shrinkage
+        << "octaves"     << "[";
+
+    fsr << cfg.cascadeName
+        << "{"
+        << "stageType"   << "BOOST"
+        << "featureType" << "ICF"
+        << "octavesNum"  << (int)cfg.octaves.size()
+        << "width"       << cfg.modelWinSize.width
+        << "height"      << cfg.modelWinSize.height
+        << "shrinkage"   << cfg.shrinkage
+        << "octaves"     << "[";
 
     // 3. Train all octaves
     for (ivector::const_iterator it = cfg.octaves.begin(); it != cfg.octaves.end(); ++it)
@@ -137,6 +162,8 @@ int main(int argc, char** argv)
             cv::Mat thresholds;
             boost.setRejectThresholds(thresholds);
 
+            boost.write(fso, thresholds);
+            boost.write(fsr);
             // std::cout << "thresholds " << thresholds << std::endl;
 
             cv::FileStorage tfs(("thresholds." + cfg.resPath(it)).c_str(), cv::FileStorage::WRITE);
@@ -146,7 +173,8 @@ int main(int argc, char** argv)
         }
     }
 
-    // fso << "]" << "}";
+    fso << "]" << "}";
+    fsr << "]" << "}";
 
 //     // // 6. Set thresolds
 //     // cascade.prune();
index 3e17b5d..ea61cda 100644 (file)
@@ -1580,8 +1580,11 @@ bool CvCascadeBoost::isErrDesired()
     for( int i = 0; i < sCount; i++ )
         if( ((CvCascadeBoostTrainData*)data)->featureEvaluator->getCls( i ) == 1.0F )
             eval[numPos++] = predict( i, true );
+
     icvSortFlt( &eval[0], numPos, 0 );
+
     int thresholdIdx = (int)((1.0F - minHitRate) * numPos);
+
     threshold = eval[ thresholdIdx ];
     numPosTrue = numPos - thresholdIdx;
     for( int i = thresholdIdx - 1; i >= 0; i--)