From: Danny <33044223+danielenricocahall@users.noreply.github.com> Date: Fri, 9 Apr 2021 16:56:14 +0000 (-0400) Subject: Merge pull request #19884 from danielenricocahall:fix-prediction-features-bug X-Git-Tag: submit/tizen/20220120.021815~1^2~1^2~70 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=a9a6801c6dcbe2baa989453e2c95c80388888879;p=platform%2Fupstream%2Fopencv.git Merge pull request #19884 from danielenricocahall:fix-prediction-features-bug Fix bug with predictions in RTrees/Boost * address bug where predict functions with invalid feature count in rtrees/boost models * compact matrix rep in tests * check 1..n-1 and n+1 in feature size validation test --- diff --git a/modules/ml/src/boost.cpp b/modules/ml/src/boost.cpp index 4b94410eeb..58f572b90d 100644 --- a/modules/ml/src/boost.cpp +++ b/modules/ml/src/boost.cpp @@ -490,6 +490,7 @@ public: float predict( InputArray samples, OutputArray results, int flags ) const CV_OVERRIDE { + CV_Assert( samples.cols() == getVarCount() && samples.type() == CV_32F ); return impl.predict(samples, results, flags); } diff --git a/modules/ml/src/rtrees.cpp b/modules/ml/src/rtrees.cpp index 1deee6f6c8..46af37ce11 100644 --- a/modules/ml/src/rtrees.cpp +++ b/modules/ml/src/rtrees.cpp @@ -479,6 +479,7 @@ public: float predict( InputArray samples, OutputArray results, int flags ) const CV_OVERRIDE { CV_TRACE_FUNCTION(); + CV_Assert( samples.cols() == getVarCount() && samples.type() == CV_32F ); return impl.predict(samples, results, flags); } diff --git a/modules/ml/test/test_rtrees.cpp b/modules/ml/test/test_rtrees.cpp index 1ec9b8d042..5a4fb34e74 100644 --- a/modules/ml/test/test_rtrees.cpp +++ b/modules/ml/test/test_rtrees.cpp @@ -95,6 +95,25 @@ TEST(ML_RTrees, 11142_sample_weights_classification) EXPECT_GE(error_with_weights, error_without_weights); } +TEST(ML_RTrees, bug_12974_throw_exception_when_predict_different_feature_count) +{ + int numFeatures = 5; + // create a 5 feature dataset and train the model + cv::Ptr model = RTrees::create(); + Mat samples(10, numFeatures, CV_32F); + randu(samples, 0, 10); + Mat labels = (Mat_(10,1) << 0,0,0,0,0,1,1,1,1,1); + cv::Ptr trainData = TrainData::create(samples, cv::ml::ROW_SAMPLE, labels); + model->train(trainData); + // try to predict on data which have fewer features - this should throw an exception + for(int i = 1; i < numFeatures - 1; ++i) { + Mat test(1, i, CV_32FC1); + ASSERT_THROW(model->predict(test), Exception); + } + // try to predict on data which have more features - this should also throw an exception + Mat test(1, numFeatures + 1, CV_32FC1); + ASSERT_THROW(model->predict(test), Exception); +} }} // namespace