Merge pull request #18126 from danielenricocahall:add-oob-error-sample-weighting
authorDanny <33044223+danielenricocahall@users.noreply.github.com>
Sat, 5 Sep 2020 18:52:10 +0000 (14:52 -0400)
committerGitHub <noreply@github.com>
Sat, 5 Sep 2020 18:52:10 +0000 (18:52 +0000)
Account for sample weights in calculating OOB Error

* account for sample weights in oob error calculation

* redefine oob error functions

* fix ABI compatibility

modules/ml/include/opencv2/ml.hpp
modules/ml/src/rtrees.cpp
modules/ml/test/test_rtrees.cpp

index adbd84682edcdbd96965f8ff28e9b2b754c0d1d6..396a7921195df93fdae795d5751f953216f36de0 100644 (file)
@@ -1294,6 +1294,15 @@ public:
     */
     CV_WRAP void getVotes(InputArray samples, OutputArray results, int flags) const;
 
+    /** Returns the OOB error value, computed at the training stage when calcOOBError is set to true.
+     * If this flag was set to false, 0 is returned. The OOB error is also scaled by sample weighting.
+     */
+#if CV_VERSION_MAJOR == 3
+    CV_WRAP double getOOBError() const;
+#else
+    /*CV_WRAP*/ virtual double getOOBError() const = 0;
+#endif
+
     /** Creates the empty model.
     Use StatModel::train to train the model, StatModel::train to create and train the model,
     Algorithm::load to load the pre-trained model.
index 27ec096bc005cf948d6e4388c20dcea2b64a74f6..1deee6f6c8e2a251354e476bf68885b4cfd9ff87 100644 (file)
@@ -216,13 +216,14 @@ public:
                     sample = Mat( nallvars, 1, CV_32F, psamples + sstep0*w->sidx[j], sstep1*sizeof(psamples[0]) );
 
                     double val = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags);
+                    double sample_weight = w->sample_weights[w->sidx[j]];
                     if( !_isClassifier )
                     {
                         oobres[j] += val;
                         oobcount[j]++;
                         double true_val = w->ord_responses[w->sidx[j]];
                         double a = oobres[j]/oobcount[j] - true_val;
-                        oobError += a*a;
+                        oobError += sample_weight * a*a;
                         val = (val - true_val)/max_response;
                         ncorrect_responses += std::exp( -val*val );
                     }
@@ -237,7 +238,7 @@ public:
                             if( votes[best_class] < votes[k] )
                                 best_class = k;
                         int diff = best_class != w->cat_responses[w->sidx[j]];
-                        oobError += diff;
+                        oobError += sample_weight * diff;
                         ncorrect_responses += diff == 0;
                     }
                 }
@@ -421,6 +422,10 @@ public:
         }
     }
 
+    double getOOBError() const {
+        return oobError;
+    }
+
     RTreeParams rparams;
     double oobError;
     vector<float> varImportance;
@@ -505,6 +510,12 @@ public:
     const vector<Node>& getNodes() const CV_OVERRIDE { return impl.getNodes(); }
     const vector<Split>& getSplits() const CV_OVERRIDE { return impl.getSplits(); }
     const vector<int>& getSubsets() const CV_OVERRIDE { return impl.getSubsets(); }
+#if CV_VERSION_MAJOR == 3
+    double getOOBError_() const { return impl.getOOBError(); }
+#else
+    double getOOBError() const CV_OVERRIDE { return impl.getOOBError(); }
+#endif
+
 
     DTreesImplForRTrees impl;
 };
@@ -532,6 +543,17 @@ void RTrees::getVotes(InputArray input, OutputArray output, int flags) const
     return this_->getVotes_(input, output, flags);
 }
 
+#if CV_VERSION_MAJOR == 3
+double RTrees::getOOBError() const
+{
+    CV_TRACE_FUNCTION();
+    const RTreesImpl* this_ = dynamic_cast<const RTreesImpl*>(this);
+    if(!this_)
+        CV_Error(Error::StsNotImplemented, "the class is not RTreesImpl");
+    return this_->getOOBError_();
+}
+#endif
+
 }}
 
 // End of file.
index ebf0c465570ee626df4a900623b989f06c3d233f..1ec9b8d042fd25f2b0c7f59c82ff46a49263f1f7 100644 (file)
@@ -51,4 +51,50 @@ TEST(ML_RTrees, getVotes)
     EXPECT_EQ(result.at<float>(0, predicted_class), rt->predict(test));
 }
 
+TEST(ML_RTrees, 11142_sample_weights_regression)
+{
+    int n = 3;
+    // RTrees for regression
+    Ptr<ml::RTrees> rt = cv::ml::RTrees::create();
+    //simple regression problem of x -> 2x
+    Mat data = (Mat_<float>(n,1) << 1, 2, 3);
+    Mat values = (Mat_<float>(n,1) << 2, 4, 6);
+    Mat weights = (Mat_<float>(n, 1) << 10, 10, 10);
+
+    Ptr<TrainData> trainData = TrainData::create(data, ml::ROW_SAMPLE, values);
+    rt->train(trainData);
+    double error_without_weights = round(rt->getOOBError());
+    rt->clear();
+    Ptr<TrainData> trainDataWithWeights = TrainData::create(data, ml::ROW_SAMPLE, values, Mat(), Mat(), weights );
+    rt->train(trainDataWithWeights);
+    double error_with_weights = round(rt->getOOBError());
+    // error with weights should be larger than error without weights
+    EXPECT_GE(error_with_weights, error_without_weights);
+}
+
+TEST(ML_RTrees, 11142_sample_weights_classification)
+{
+    int n = 12;
+    // RTrees for classification
+    Ptr<ml::RTrees> rt = cv::ml::RTrees::create();
+
+    Mat data(n, 4, CV_32F);
+    randu(data, 0, 10);
+    Mat labels = (Mat_<int>(n,1) << 0,0,0,0, 1,1,1,1, 2,2,2,2);
+    Mat weights = (Mat_<float>(n, 1) << 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10);
+
+    rt->train(data, ml::ROW_SAMPLE, labels);
+    rt->clear();
+    double error_without_weights = round(rt->getOOBError());
+    Ptr<TrainData> trainDataWithWeights = TrainData::create(data, ml::ROW_SAMPLE, labels, Mat(), Mat(), weights );
+    rt->train(data, ml::ROW_SAMPLE, labels);
+    double error_with_weights = round(rt->getOOBError());
+    std::cout << error_without_weights << std::endl;
+    std::cout << error_with_weights << std::endl;
+    // error with weights should be larger than error without weights
+    EXPECT_GE(error_with_weights, error_without_weights);
+}
+
+
+
 }} // namespace