fixed traincascade for ordered features
authorMaria Dimashova <no@email>
Thu, 22 Dec 2011 11:19:27 +0000 (11:19 +0000)
committerMaria Dimashova <no@email>
Thu, 22 Dec 2011 11:19:27 +0000 (11:19 +0000)
modules/ml/src/boost.cpp
modules/ml/src/tree.cpp
modules/traincascade/boost.cpp
modules/traincascade/boost.h

index b29feab..d69cb80 100644 (file)
@@ -1066,7 +1066,7 @@ CvBoost::train( const CvMat* _train_data, int _tflag,
         if( !tree->train( data, subsample_mask, this ) )
         {
             delete tree;
-            continue;
+            break;
         }
         //cvCheckArr( get_weak_response());
         cvSeqPush( weak, &tree );
index 371c717..956b262 100644 (file)
@@ -718,7 +718,7 @@ CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
         // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
         int* co, cur_ofs = 0;
         int vi, i;
-        int work_var_count = get_work_var_count();
+        int workVarCount = get_work_var_count();
         int count = isubsample_idx->rows + isubsample_idx->cols - 1;
 
         root = new_node( 0, count, 1, 0 );
@@ -740,7 +740,7 @@ CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
         }
 
         cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float)));
-        for( vi = 0; vi < work_var_count; vi++ )
+        for( vi = 0; vi < workVarCount; vi++ )
         {
             int ci = get_var_type(vi);
 
@@ -841,14 +841,14 @@ CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
         if (is_buf_16u)
         {
             unsigned short* sample_idx_dst = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols + 
-                get_work_var_count()*sample_count + root->offset);            
+                workVarCount*sample_count + root->offset);
             for (i = 0; i < count; i++)
                 sample_idx_dst[i] = (unsigned short)sample_idx_src[sidx[i]];
         }
         else
         {
             int* sample_idx_dst = buf->data.i + root->buf_idx*buf->cols + 
-                get_work_var_count()*sample_count + root->offset;            
+                workVarCount*sample_count + root->offset;
             for (i = 0; i < count; i++)
                 sample_idx_dst[i] = sample_idx_src[sidx[i]];
         }
@@ -1622,13 +1622,19 @@ bool CvDTree::do_train( const CvMat* _subsample_idx )
 
     CV_CALL( try_split_node(root));
 
-    if( data->params.cv_folds > 0 )
-        CV_CALL( prune_cv() );
+    if( root->split )
+    {
+        CV_Assert( root->left );
+        CV_Assert( root->right );
+
+        if( data->params.cv_folds > 0 )
+            CV_CALL( prune_cv() );
 
-    if( !data->shared )
-        data->free_train_data();
+        if( !data->shared )
+            data->free_train_data();
 
-    result = true;
+        result = true;
+    }
 
     __END__;
 
index 545ecb9..e449fac 100644 (file)
@@ -27,6 +27,115 @@ static CV_IMPLEMENT_QSORT_EX( icvSortUShAux, unsigned short, CV_CMP_NUM_IDX, con
 static const int MinBlockSize = 1 << 16;
 static const int BlockSizeDelta = 1 << 10;
 
+// TODO remove this code duplication with ml/precomp.hpp
+
+static int CV_CDECL icvCmpIntegers( const void* a, const void* b )
+{
+    return *(const int*)a - *(const int*)b;
+}
+
+static CvMat* cvPreprocessIndexArray( const CvMat* idx_arr, int data_arr_size, bool check_for_duplicates=false )
+{
+    CvMat* idx = 0;
+
+    CV_FUNCNAME( "cvPreprocessIndexArray" );
+
+    __BEGIN__;
+
+    int i, idx_total, idx_selected = 0, step, type, prev = INT_MIN, is_sorted = 1;
+    uchar* srcb = 0;
+    int* srci = 0;
+    int* dsti;
+
+    if( !CV_IS_MAT(idx_arr) )
+        CV_ERROR( CV_StsBadArg, "Invalid index array" );
+
+    if( idx_arr->rows != 1 && idx_arr->cols != 1 )
+        CV_ERROR( CV_StsBadSize, "the index array must be 1-dimensional" );
+
+    idx_total = idx_arr->rows + idx_arr->cols - 1;
+    srcb = idx_arr->data.ptr;
+    srci = idx_arr->data.i;
+
+    type = CV_MAT_TYPE(idx_arr->type);
+    step = CV_IS_MAT_CONT(idx_arr->type) ? 1 : idx_arr->step/CV_ELEM_SIZE(type);
+
+    switch( type )
+    {
+    case CV_8UC1:
+    case CV_8SC1:
+        // idx_arr is array of 1's and 0's -
+        // i.e. it is a mask of the selected components
+        if( idx_total != data_arr_size )
+            CV_ERROR( CV_StsUnmatchedSizes,
+            "Component mask should contain as many elements as the total number of input variables" );
+
+        for( i = 0; i < idx_total; i++ )
+            idx_selected += srcb[i*step] != 0;
+
+        if( idx_selected == 0 )
+            CV_ERROR( CV_StsOutOfRange, "No components/input_variables is selected!" );
+
+        break;
+    case CV_32SC1:
+        // idx_arr is array of integer indices of selected components
+        if( idx_total > data_arr_size )
+            CV_ERROR( CV_StsOutOfRange,
+            "index array may not contain more elements than the total number of input variables" );
+        idx_selected = idx_total;
+        // check if sorted already
+        for( i = 0; i < idx_total; i++ )
+        {
+            int val = srci[i*step];
+            if( val >= prev )
+            {
+                is_sorted = 0;
+                break;
+            }
+            prev = val;
+        }
+        break;
+    default:
+        CV_ERROR( CV_StsUnsupportedFormat, "Unsupported index array data type "
+                                           "(it should be 8uC1, 8sC1 or 32sC1)" );
+    }
+
+    CV_CALL( idx = cvCreateMat( 1, idx_selected, CV_32SC1 ));
+    dsti = idx->data.i;
+
+    if( type < CV_32SC1 )
+    {
+        for( i = 0; i < idx_total; i++ )
+            if( srcb[i*step] )
+                *dsti++ = i;
+    }
+    else
+    {
+        for( i = 0; i < idx_total; i++ )
+            dsti[i] = srci[i*step];
+
+        if( !is_sorted )
+            qsort( dsti, idx_total, sizeof(dsti[0]), icvCmpIntegers );
+
+        if( dsti[0] < 0 || dsti[idx_total-1] >= data_arr_size )
+            CV_ERROR( CV_StsOutOfRange, "the index array elements are out of range" );
+
+        if( check_for_duplicates )
+        {
+            for( i = 1; i < idx_total; i++ )
+                if( dsti[i] <= dsti[i-1] )
+                    CV_ERROR( CV_StsBadArg, "There are duplicated index array elements" );
+        }
+    }
+
+    __END__;
+
+    if( cvGetErrStatus() < 0 )
+        cvReleaseMat( &idx );
+
+    return idx;
+}
+
 //----------------------------- CascadeBoostParams -------------------------------------------------
 
 CvCascadeBoostParams::CvCascadeBoostParams() : minHitRate( 0.995F), maxFalseAlarm( 0.5F )
@@ -153,6 +262,171 @@ bool CvCascadeBoostParams::scanAttr( const String prmName, const String val)
     return res;        
 }
 
+CvDTreeNode* CvCascadeBoostTrainData::subsample_data( const CvMat* _subsample_idx )
+{
+    CvDTreeNode* root = 0;
+    CvMat* isubsample_idx = 0;
+    CvMat* subsample_co = 0;
+
+    bool isMakeRootCopy = true;
+
+    if( !data_root )
+        CV_Error( CV_StsError, "No training data has been set" );
+
+    if( _subsample_idx )
+    {
+        CV_Assert( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ) );
+
+        if( isubsample_idx->cols + isubsample_idx->rows - 1 == sample_count )
+        {
+            const int* sidx = isubsample_idx->data.i;
+            for( int i = 0; i < sample_count; i++ )
+            {
+                if( sidx[i] != i )
+                {
+                    isMakeRootCopy = false;
+                    break;
+                }
+            }
+        }
+        else
+            isMakeRootCopy = false;
+    }
+
+    if( isMakeRootCopy )
+    {
+        // make a copy of the root node
+        CvDTreeNode temp;
+        int i;
+        root = new_node( 0, 1, 0, 0 );
+        temp = *root;
+        *root = *data_root;
+        root->num_valid = temp.num_valid;
+        if( root->num_valid )
+        {
+            for( i = 0; i < var_count; i++ )
+                root->num_valid[i] = data_root->num_valid[i];
+        }
+        root->cv_Tn = temp.cv_Tn;
+        root->cv_node_risk = temp.cv_node_risk;
+        root->cv_node_error = temp.cv_node_error;
+    }
+    else
+    {
+        int* sidx = isubsample_idx->data.i;
+        // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
+        int* co, cur_ofs = 0;
+        int workVarCount = get_work_var_count();
+        int count = isubsample_idx->rows + isubsample_idx->cols - 1;
+
+        root = new_node( 0, count, 1, 0 );
+
+        CV_Assert( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
+        cvZero( subsample_co );
+        co = subsample_co->data.i;
+        for( int i = 0; i < count; i++ )
+            co[sidx[i]*2]++;
+        for( int i = 0; i < sample_count; i++ )
+        {
+            if( co[i*2] )
+            {
+                co[i*2+1] = cur_ofs;
+                cur_ofs += co[i*2];
+            }
+            else
+                co[i*2+1] = -1;
+        }
+
+        cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float)));
+        // subsample ordered variables
+        for( int vi = 0; vi < numPrecalcIdx; vi++ )
+        {
+            int ci = get_var_type(vi);
+            CV_Assert( ci < 0 );
+
+            int *src_idx_buf = (int*)(uchar*)inn_buf;
+            float *src_val_buf = (float*)(src_idx_buf + sample_count);
+            int* sample_indices_buf = (int*)(src_val_buf + sample_count);
+            const int* src_idx = 0;
+            const float* src_val = 0;
+            get_ord_var_data( data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx, sample_indices_buf );
+
+            int j = 0, idx, count_i;
+            int num_valid = data_root->get_num_valid(vi);
+            CV_Assert( num_valid == sample_count );
+
+            if (is_buf_16u)
+            {
+                unsigned short* udst_idx = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols +
+                    vi*sample_count + data_root->offset);
+                for( int i = 0; i < num_valid; i++ )
+                {
+                    idx = src_idx[i];
+                    count_i = co[idx*2];
+                    if( count_i )
+                        for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
+                            udst_idx[j] = (unsigned short)cur_ofs;
+                }
+            }
+            else
+            {
+                int* idst_idx = buf->data.i + root->buf_idx*buf->cols +
+                    vi*sample_count + root->offset;
+                for( int i = 0; i < num_valid; i++ )
+                {
+                    idx = src_idx[i];
+                    count_i = co[idx*2];
+                    if( count_i )
+                        for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
+                            idst_idx[j] = cur_ofs;
+                }
+            }
+        }
+
+        // subsample cv_lables
+        const int* src_lbls = get_cv_labels(data_root, (int*)(uchar*)inn_buf);
+        if (is_buf_16u)
+        {
+            unsigned short* udst = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols +
+                (workVarCount-1)*sample_count + root->offset);
+            for( int i = 0; i < count; i++ )
+                udst[i] = (unsigned short)src_lbls[sidx[i]];
+        }
+        else
+        {
+            int* idst = buf->data.i + root->buf_idx*buf->cols +
+                (workVarCount-1)*sample_count + root->offset;
+            for( int i = 0; i < count; i++ )
+                idst[i] = src_lbls[sidx[i]];
+        }
+
+        // subsample sample_indices
+        const int* sample_idx_src = get_sample_indices(data_root, (int*)(uchar*)inn_buf);
+        if (is_buf_16u)
+        {
+            unsigned short* sample_idx_dst = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols +
+                workVarCount*sample_count + root->offset);
+            for( int i = 0; i < count; i++ )
+                sample_idx_dst[i] = (unsigned short)sample_idx_src[sidx[i]];
+        }
+        else
+        {
+            int* sample_idx_dst = buf->data.i + root->buf_idx*buf->cols +
+                workVarCount*sample_count + root->offset;
+            for( int i = 0; i < count; i++ )
+                sample_idx_dst[i] = sample_idx_src[sidx[i]];
+        }
+
+        for( int vi = 0; vi < var_count; vi++ )
+            root->set_num_valid(vi, count);
+    }
+
+    cvReleaseMat( &isubsample_idx );
+    cvReleaseMat( &subsample_co );
+
+    return root;
+}
+
 //---------------------------- CascadeBoostTrainData -----------------------------
 
 CvCascadeBoostTrainData::CvCascadeBoostTrainData( const CvFeatureEvaluator* _featureEvaluator,
@@ -270,8 +544,8 @@ void CvCascadeBoostTrainData::setData( const CvFeatureEvaluator* _featureEvaluat
     }
     var_type->data.i[var_count] = cat_var_count;
     var_type->data.i[var_count+1] = cat_var_count+1;
-    work_var_count = ( cat_var_count ? 0 : numPrecalcIdx ) + 1;
-    buf_size = (work_var_count + 1) * sample_count;
+    work_var_count = ( cat_var_count ? 0 : numPrecalcIdx ) + 1/*cv_lables*/;
+    buf_size = (work_var_count + 1) * sample_count/*sample_indices*/;
     buf_count = 2;
     
     if ( is_buf_16u )
@@ -814,10 +1088,7 @@ void CvCascadeBoostTree::split_node_data( CvDTreeNode* node )
                     ldst++;
                 }
             }
-            assert( n1 == n);
-
-            left->set_num_valid(vi, (int)(ldst - ldst0));
-            right->set_num_valid(vi, (int)(rdst - rdst0));
+            CV_Assert( n1 == n );
         }   
         else
         {
@@ -844,10 +1115,7 @@ void CvCascadeBoostTree::split_node_data( CvDTreeNode* node )
                     ldst++;
                 }
             }
-
-            left->set_num_valid(vi, (int)(ldst - ldst0));
-            right->set_num_valid(vi, (int)(rdst - rdst0));
-            CV_Assert( n1 == n);
+            CV_Assert( n1 == n );
         }  
     }
 
@@ -860,11 +1128,11 @@ void CvCascadeBoostTree::split_node_data( CvDTreeNode* node )
 
     if (data->is_buf_16u)
     {
-        unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*buf->cols + 
+        unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*buf->cols +
             (workVarCount-1)*scount + left->offset);
-        unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*buf->cols + 
-            (workVarCount-1)*scount + right->offset);            
-        
+        unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*buf->cols +
+            (workVarCount-1)*scount + right->offset);
+
         for( int i = 0; i < n; i++ )
         {
             int idx = tempBuf[i];
@@ -883,11 +1151,11 @@ void CvCascadeBoostTree::split_node_data( CvDTreeNode* node )
     }
     else
     {
-        int *ldst = buf->data.i + left->buf_idx*buf->cols + 
+        int *ldst = buf->data.i + left->buf_idx*buf->cols +
             (workVarCount-1)*scount + left->offset;
-        int *rdst = buf->data.i + right->buf_idx*buf->cols + 
+        int *rdst = buf->data.i + right->buf_idx*buf->cols +
             (workVarCount-1)*scount + right->offset;
-        
+
         for( int i = 0; i < n; i++ )
         {
             int idx = tempBuf[i];
@@ -902,13 +1170,8 @@ void CvCascadeBoostTree::split_node_data( CvDTreeNode* node )
                 ldst++;
             }
         }
-    }        
-    for( int vi = 0; vi < data->var_count; vi++ )
-    {
-        left->set_num_valid(vi, (int)(nl));
-        right->set_num_valid(vi, (int)(nr));
     }
-
+    
     // split sample indices
     int *sampleIdx_src_buf = tempBuf + n;
     const int* sampleIdx_src = data->get_sample_indices(node, sampleIdx_src_buf);
@@ -959,6 +1222,12 @@ void CvCascadeBoostTree::split_node_data( CvDTreeNode* node )
         }
     }
 
+    for( int vi = 0; vi < data->var_count; vi++ )
+    {
+        left->set_num_valid(vi, (int)(nl));
+        right->set_num_valid(vi, (int)(nr));
+    }
+
     // deallocate the parent node data that is not needed anymore
     data->free_node_data(node); 
 }
@@ -1008,10 +1277,8 @@ bool CvCascadeBoost::train( const CvFeatureEvaluator* _featureEvaluator,
         CvCascadeBoostTree* tree = new CvCascadeBoostTree;
         if( !tree->train( data, subsample_mask, this ) )
         {
-            // TODO: may be should finish the loop (!!!)
-            assert(0);
             delete tree;
-            continue;
+            break;
         }
         cvSeqPush( weak, &tree );
         update_weights( tree );
index d4c3689..03dce69 100644 (file)
@@ -32,6 +32,8 @@ struct CvCascadeBoostTrainData : CvDTreeTrainData
                           const CvDTreeParams& _params=CvDTreeParams() );
     void precalculate();
 
+    virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
+
     virtual const int* get_class_labels( CvDTreeNode* n, int* labelsBuf );
     virtual const int* get_cv_labels( CvDTreeNode* n, int* labelsBuf);
     virtual const int* get_sample_indices( CvDTreeNode* n, int* indicesBuf );
@@ -67,7 +69,7 @@ public:
                         const CvCascadeBoostParams& _params=CvCascadeBoostParams() );
     virtual float predict( int sampleIdx, bool returnSum = false ) const;
 
-    float getThreshold() const { return threshold; }
+    float getThreshold() const { return threshold; }
     void write( FileStorage &fs, const Mat& featureMap ) const;
     bool read( const FileNode &node, const CvFeatureEvaluator* _featureEvaluator,
                const CvCascadeBoostParams& _params );