added the optional balanced cross-validation in SVN::train_auto (by arman, ticket...
authorVadim Pisarevsky <no@email>
Mon, 29 Nov 2010 22:37:02 +0000 (22:37 +0000)
committerVadim Pisarevsky <no@email>
Mon, 29 Nov 2010 22:37:02 +0000 (22:37 +0000)
modules/ml/include/opencv2/ml/ml.hpp
modules/ml/src/svm.cpp

index bf3dd8f..7fbd2d8 100644 (file)
@@ -540,7 +540,8 @@ public:
         CvParamGrid pGrid      = get_default_grid(CvSVM::P),
         CvParamGrid nuGrid     = get_default_grid(CvSVM::NU),
         CvParamGrid coeffGrid  = get_default_grid(CvSVM::COEF),
-        CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE) );
+        CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE),
+        bool balanced=false );
 
     virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
 
@@ -561,7 +562,8 @@ public:
                             CvParamGrid pGrid      = CvSVM::get_default_grid(CvSVM::P),
                             CvParamGrid nuGrid     = CvSVM::get_default_grid(CvSVM::NU),
                             CvParamGrid coeffGrid  = CvSVM::get_default_grid(CvSVM::COEF),
-                            CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE) );
+                            CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
+                            bool balanced=false);
     CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;    
 #endif
     
index 7da0af7..bc0cff9 100644 (file)
@@ -1593,10 +1593,27 @@ bool CvSVM::train( const CvMat* _train_data, const CvMat* _responses,
     return ok;
 }
 
+struct indexedratio 
+{
+    double val;
+    int ind;
+    int count_smallest, count_biggest;
+    void eval() { val = (double) count_smallest/(count_smallest+count_biggest); }
+};
+
+static int CV_CDECL
+icvCmpIndexedratio( const void* a, const void* b )
+{
+    return ((const indexedratio*)a)->val < ((const indexedratio*)b)->val ? -1
+    : ((const indexedratio*)a)->val > ((const indexedratio*)b)->val ? 1
+    : 0;
+}
+
 bool CvSVM::train_auto( const CvMat* _train_data, const CvMat* _responses,
     const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params, int k_fold,
     CvParamGrid C_grid, CvParamGrid gamma_grid, CvParamGrid p_grid,
-    CvParamGrid nu_grid, CvParamGrid coef_grid, CvParamGrid degree_grid )
+    CvParamGrid nu_grid, CvParamGrid coef_grid, CvParamGrid degree_grid,
+    bool balanced)
 {
     bool ok = false;
     CvMat* responses = 0;
@@ -1757,6 +1774,105 @@ bool CvSVM::train_auto( const CvMat* _train_data, const CvMat* _responses,
         else
             CV_SWAP( responses->data.i[i1], responses->data.i[i2], y );
     }
+        
+    if (!is_regression && class_labels->cols==2 && balanced)
+    {
+        // count class samples
+        int num_0=0,num_1=0;
+        for (i=0; i<sample_count; ++i)
+        {
+            if (responses->data.i[i]==class_labels->data.i[0])
+                ++num_0;
+            else
+                ++num_1;
+        }
+        
+        int label_smallest_class;
+        int label_biggest_class;
+        if (num_0 < num_1)
+        {
+            label_biggest_class = class_labels->data.i[1];
+            label_smallest_class = class_labels->data.i[0]; 
+        }
+        else
+        {
+            label_biggest_class = class_labels->data.i[0];
+            label_smallest_class = class_labels->data.i[1];
+            int y;
+            CV_SWAP(num_0,num_1,y);
+        }
+        const double class_ratio = (double) num_0/sample_count;
+        // calculate class ratio of each fold
+        indexedratio *ratios=0;
+        ratios = (indexedratio*) cvAlloc(k_fold*sizeof(*ratios));
+        for (int k=0, i_begin=0; k<k_fold; ++k, i_begin+=testset_size)
+        {
+            int count0=0;
+            int count1=0;
+            int i_end = i_begin + (k<k_fold-1 ? testset_size : last_testset_size);
+            for (int i=i_begin; i<i_end; ++i)
+            {
+                if (responses->data.i[i]==label_smallest_class)
+                    ++count0;
+                else
+                    ++count1;
+            }
+            ratios[k].ind = k;
+            ratios[k].count_smallest = count0;
+            ratios[k].count_biggest = count1;
+            ratios[k].eval();
+        }
+        // initial distance
+        qsort(ratios, k_fold, sizeof(ratios[0]), icvCmpIndexedratio);
+        double old_dist = 0.0;
+        for (int k=0; k<k_fold; ++k)
+            old_dist += abs(ratios[k].val-class_ratio);
+        double new_dist = 1.0;
+        // iterate to make the folds more balanced
+        while (new_dist > 0.0)
+        {
+            if (ratios[0].count_biggest==0 || ratios[k_fold-1].count_smallest==0)
+                break; // we are not able to swap samples anymore
+            // what if we swap the samples, calculate the new distance
+            ratios[0].count_smallest++;
+            ratios[0].count_biggest--;
+            ratios[0].eval();
+            ratios[k_fold-1].count_smallest--;
+            ratios[k_fold-1].count_biggest++;
+            ratios[k_fold-1].eval();
+            qsort(ratios, k_fold, sizeof(ratios[0]), icvCmpIndexedratio);
+            new_dist = 0.0;
+            for (int k=0; k<k_fold; ++k)
+                new_dist += abs(ratios[k].val-class_ratio);
+            if (new_dist < old_dist)
+            {
+                // swapping really improves, so swap the samples
+                // index of the biggest_class sample from the minimum ratio fold
+                int i1 = ratios[0].ind * testset_size;
+                for ( ; i1<sample_count; ++i1)
+                {
+                    if (responses->data.i[i1]==label_biggest_class)
+                        break;
+                }
+                // index of the smallest_class sample from the maximum ratio fold
+                int i2 = ratios[k_fold-1].ind * testset_size;
+                for ( ; i2<sample_count; ++i2)
+                {
+                    if (responses->data.i[i2]==label_smallest_class)
+                        break;
+                }
+                // swap
+                const float* temp;
+                int y;
+                CV_SWAP( samples[i1], samples[i2], temp );
+                CV_SWAP( responses->data.i[i1], responses->data.i[i2], y );
+                old_dist = new_dist;
+            }
+            else
+                break; // does not improve, so break the loop
+        }
+        cvFree(&ratios);
+    }
 
     int* cls_lbls = class_labels ? class_labels->data.i : 0;
     C = C_grid.min_val;
@@ -2011,12 +2127,12 @@ bool CvSVM::train( const Mat& _train_data, const Mat& _responses,
 bool CvSVM::train_auto( const Mat& _train_data, const Mat& _responses,
                        const Mat& _var_idx, const Mat& _sample_idx, CvSVMParams _params, int k_fold,
                        CvParamGrid C_grid, CvParamGrid gamma_grid, CvParamGrid p_grid,
-                       CvParamGrid nu_grid, CvParamGrid coef_grid, CvParamGrid degree_grid )
+                       CvParamGrid nu_grid, CvParamGrid coef_grid, CvParamGrid degree_grid, bool balanced )
 {
     CvMat tdata = _train_data, responses = _responses, vidx = _var_idx, sidx = _sample_idx;
     return train_auto(&tdata, &responses, vidx.data.ptr ? &vidx : 0,
                       sidx.data.ptr ? &sidx : 0, _params, k_fold, C_grid, gamma_grid, p_grid,
-                      nu_grid, coef_grid, degree_grid);
+                      nu_grid, coef_grid, degree_grid, balanced);
 }
 
 float CvSVM::predict( const Mat& _sample, bool returnDFVal ) const