Fix epsilon comparison when splitting
authorEvan Heidtmann <evan.heidtmann@gmail.com>
Tue, 22 Mar 2016 00:33:36 +0000 (17:33 -0700)
committerEvan Heidtmann <evan.heidtmann@gmail.com>
Mon, 28 Mar 2016 21:16:32 +0000 (14:16 -0700)
modules/ml/src/tree.cpp

index 143e1fb..f803d25 100644 (file)
@@ -638,7 +638,6 @@ void DTreesImpl::calcValue( int nidx, const vector<int>& _sidx )
 
 DTreesImpl::WSplit DTreesImpl::findSplitOrdClass( int vi, const vector<int>& _sidx, double initQuality )
 {
-    const double epsilon = FLT_EPSILON*2;
     int n = (int)_sidx.size();
     int m = (int)classLabels.size();
 
@@ -688,7 +687,8 @@ DTreesImpl::WSplit DTreesImpl::findSplitOrdClass( int vi, const vector<int>& _si
         rsum2 -= 2*rv*wval - w2;
         lcw[idx] = lv + wval; rcw[idx] = rv - wval;
 
-        if( values[curr] + epsilon < values[next] )
+        float value_between = (values[next] + values[curr]) * 0.5f;
+        if( value_between > values[curr] && value_between < values[next] )
         {
             double val = (lsum2*R + rsum2*L)/(L*R);
             if( best_val < val )
@@ -985,7 +985,6 @@ DTreesImpl::WSplit DTreesImpl::findSplitCatClass( int vi, const vector<int>& _si
 
 DTreesImpl::WSplit DTreesImpl::findSplitOrdReg( int vi, const vector<int>& _sidx, double initQuality )
 {
-    const float epsilon = FLT_EPSILON*2;
     const double* weights = &w->sample_weights[0];
     int n = (int)_sidx.size();
 
@@ -1021,7 +1020,8 @@ DTreesImpl::WSplit DTreesImpl::findSplitOrdReg( int vi, const vector<int>& _sidx
         L += wval; R -= wval;
         lsum += t; rsum -= t;
 
-        if( values[curr] + epsilon < values[next] )
+        float value_between = (values[next] + values[curr]) * 0.5f;
+        if( value_between > values[curr] && value_between < values[next] )
         {
             double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
             if( best_val < val )