From c31164bf1e8f26bd87c06af9f335d3beca022078 Mon Sep 17 00:00:00 2001
From: Danny <33044223+danielenricocahall@users.noreply.github.com>
Date: Sat, 5 Sep 2020 14:52:10 -0400
Subject: [PATCH] Merge pull request #18126 from danielenricocahall:add-oob-error-sample-weighting

Account for sample weights in calculating OOB Error

* account for sample weights in oob error calculation

* redefine oob error functions

* fix ABI compatibility
---
 modules/ml/include/opencv2/ml.hpp |  9 ++++++++
 modules/ml/src/rtrees.cpp         | 26 ++++++++++++++++++++--
 modules/ml/test/test_rtrees.cpp   | 46 +++++++++++++++++++++++++++++++++++++++
 3 files changed, 79 insertions(+), 2 deletions(-)

diff --git a/modules/ml/include/opencv2/ml.hpp b/modules/ml/include/opencv2/ml.hpp
index adbd846..396a792 100644
--- a/modules/ml/include/opencv2/ml.hpp
+++ b/modules/ml/include/opencv2/ml.hpp
@@ -1294,6 +1294,15 @@ public:
      */
     CV_WRAP void getVotes(InputArray samples, OutputArray results, int flags) const;
 
+    /** Returns the OOB (out-of-bag) error value, computed at the training stage when calcOOBError is set to true.
+     * If this flag was set to false, 0 is returned. Each sample's error contribution is multiplied by its sample weight.
+     */
+#if CV_VERSION_MAJOR == 3
+    CV_WRAP double getOOBError() const;
+#else
+    /*CV_WRAP*/ virtual double getOOBError() const = 0;
+#endif
+
     /** Creates the empty model.
     Use StatModel::train to train the model, StatModel::train to create and train the model,
     Algorithm::load to load the pre-trained model.
diff --git a/modules/ml/src/rtrees.cpp b/modules/ml/src/rtrees.cpp
index 27ec096..1deee6f 100644
--- a/modules/ml/src/rtrees.cpp
+++ b/modules/ml/src/rtrees.cpp
@@ -216,13 +216,14 @@ public:
                     sample = Mat( nallvars, 1, CV_32F, psamples + sstep0*w->sidx[j], sstep1*sizeof(psamples[0]) );
 
                     double val = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags);
+                    double sample_weight = w->sample_weights[w->sidx[j]]; // weight of the sample being scored
                     if( !_isClassifier )
                     {
                         oobres[j] += val;
                         oobcount[j]++;
                         double true_val = w->ord_responses[w->sidx[j]];
                         double a = oobres[j]/oobcount[j] - true_val;
-                        oobError += a*a;
+                        oobError += sample_weight * a*a; // squared error scaled by the sample weight
                         val = (val - true_val)/max_response;
                         ncorrect_responses += std::exp( -val*val );
                     }
@@ -237,7 +238,7 @@ public:
                             if( votes[best_class] < votes[k] )
                                 best_class = k;
                         int diff = best_class != w->cat_responses[w->sidx[j]];
-                        oobError += diff;
+                        oobError += sample_weight * diff; // 0/1 misclassification scaled by the sample weight
                         ncorrect_responses += diff == 0;
                     }
                 }
@@ -421,6 +422,10 @@ public:
         }
     }
 
+    double getOOBError() const { // read-only access to the oobError member
+        return oobError;
+    }
+
     RTreeParams rparams;
     double oobError;
     vector<float> varImportance;
@@ -505,6 +510,12 @@ public:
     const vector<Node>& getNodes() const CV_OVERRIDE { return impl.getNodes(); }
     const vector<Split>& getSplits() const CV_OVERRIDE { return impl.getSplits(); }
     const vector<int>& getSubsets() const CV_OVERRIDE { return impl.getSubsets(); }
+#if CV_VERSION_MAJOR == 3
+    double getOOBError_() const { return impl.getOOBError(); } // non-virtual helper, reached from RTrees::getOOBError() via dynamic_cast
+#else
+    double getOOBError() const CV_OVERRIDE { return impl.getOOBError(); }
+#endif
+
     DTreesImplForRTrees impl;
 };
 
@@ -532,6 +543,17 @@ void RTrees::getVotes(InputArray input, OutputArray output, int flags) const
     return this_->getVotes_(input, output, flags);
 }
 
+#if CV_VERSION_MAJOR == 3
+double RTrees::getOOBError() const
+{
+    CV_TRACE_FUNCTION();
+    const RTreesImpl* this_ = dynamic_cast<const RTreesImpl*>(this); // template argument restored; a bare dynamic_cast(this) does not compile
+    if(!this_)
+        CV_Error(Error::StsNotImplemented, "the class is not RTreesImpl");
+    return this_->getOOBError_();
+}
+#endif
+
 }}
 
 // End of file.
diff --git a/modules/ml/test/test_rtrees.cpp b/modules/ml/test/test_rtrees.cpp
index ebf0c46..1ec9b8d 100644
--- a/modules/ml/test/test_rtrees.cpp
+++ b/modules/ml/test/test_rtrees.cpp
@@ -51,4 +51,50 @@ TEST(ML_RTrees, getVotes)
     EXPECT_EQ(result.at<float>(0, predicted_class), rt->predict(test));
 }
 
+TEST(ML_RTrees, 11142_sample_weights_regression)
+{
+    int n = 3;
+    // RTrees for regression
+    Ptr<cv::ml::RTrees> rt = cv::ml::RTrees::create();
+    //simple regression problem of x -> 2x
+    Mat data = (Mat_<float>(n,1) << 1, 2, 3);
+    Mat values = (Mat_<float>(n,1) << 2, 4, 6);
+    Mat weights = (Mat_<float>(n, 1) << 10, 10, 10);
+
+    Ptr<TrainData> trainData = TrainData::create(data, ml::ROW_SAMPLE, values);
+    rt->train(trainData);
+    double error_without_weights = round(rt->getOOBError());
+    rt->clear();
+    Ptr<TrainData> trainDataWithWeights = TrainData::create(data, ml::ROW_SAMPLE, values, Mat(), Mat(), weights );
+    rt->train(trainDataWithWeights);
+    double error_with_weights = round(rt->getOOBError());
+    // error with weights should be at least as large as error without weights
+    EXPECT_GE(error_with_weights, error_without_weights);
+}
+
+TEST(ML_RTrees, 11142_sample_weights_classification)
+{
+    int n = 12;
+    // RTrees for classification
+    Ptr<cv::ml::RTrees> rt = cv::ml::RTrees::create();
+
+    Mat data(n, 4, CV_32F);
+    randu(data, 0, 10);
+    Mat labels = (Mat_<int>(n,1) << 0,0,0,0, 1,1,1,1, 2,2,2,2);
+    Mat weights = (Mat_<float>(n, 1) << 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10);
+
+    // baseline: train without weights and read the OOB error before the model is cleared
+    rt->train(data, ml::ROW_SAMPLE, labels);
+    double error_without_weights = round(rt->getOOBError());
+    rt->clear();
+    // retrain on the same samples, this time actually passing the weighted TrainData
+    Ptr<TrainData> trainDataWithWeights = TrainData::create(data, ml::ROW_SAMPLE, labels, Mat(), Mat(), weights );
+    rt->train(trainDataWithWeights);
+    double error_with_weights = round(rt->getOOBError());
+    // error with weights should be at least as large as error without weights
+    EXPECT_GE(error_with_weights, error_without_weights);
+}
+
+
+
 }} // namespace
-- 
2.7.4