Merge pull request #19884 from danielenricocahall:fix-prediction-features-bug
authorDanny <33044223+danielenricocahall@users.noreply.github.com>
Fri, 9 Apr 2021 16:56:14 +0000 (12:56 -0400)
committerGitHub <noreply@github.com>
Fri, 9 Apr 2021 16:56:14 +0000 (16:56 +0000)
Fix bug with predictions in RTrees/Boost

* address bug where predict functions with invalid feature count in rtrees/boost models

* compact matrix rep in tests

* check 1..n-1 and n+1 in feature size validation test

modules/ml/src/boost.cpp
modules/ml/src/rtrees.cpp
modules/ml/test/test_rtrees.cpp

index 4b94410eeb5a212f2aa15cdfbbd770241f80e17d..58f572b90d0ee6a6caa5a1056c769d12e115162f 100644 (file)
@@ -490,6 +490,7 @@ public:
 
     float predict( InputArray samples, OutputArray results, int flags ) const CV_OVERRIDE
     {
+        CV_Assert( samples.cols() == getVarCount() && samples.type() == CV_32F );
         return impl.predict(samples, results, flags);
     }
 
index 1deee6f6c8e2a251354e476bf68885b4cfd9ff87..46af37ce11900a0f8a1e25e88b1b8127dcec8ea7 100644 (file)
@@ -479,6 +479,7 @@ public:
     float predict( InputArray samples, OutputArray results, int flags ) const CV_OVERRIDE
     {
         CV_TRACE_FUNCTION();
+        CV_Assert( samples.cols() == getVarCount() && samples.type() == CV_32F );
         return impl.predict(samples, results, flags);
     }
 
index 1ec9b8d042fd25f2b0c7f59c82ff46a49263f1f7..5a4fb34e744b6b28cc3db6cc9db5ef373befa491 100644 (file)
@@ -95,6 +95,25 @@ TEST(ML_RTrees, 11142_sample_weights_classification)
     EXPECT_GE(error_with_weights, error_without_weights);
 }
 
+TEST(ML_RTrees, bug_12974_throw_exception_when_predict_different_feature_count)
+{
+    int numFeatures = 5;
+    // create a 5 feature dataset and train the model
+    cv::Ptr<RTrees> model = RTrees::create();
+    Mat samples(10, numFeatures, CV_32F);
+    randu(samples, 0, 10);
+    Mat labels = (Mat_<int>(10,1) << 0,0,0,0,0,1,1,1,1,1);
+    cv::Ptr<TrainData> trainData = TrainData::create(samples, cv::ml::ROW_SAMPLE, labels);
+    model->train(trainData);
+    // try to predict on data which have fewer features - this should throw an exception
+    for(int i = 1; i < numFeatures - 1; ++i) {
+        Mat test(1, i, CV_32FC1);
+        ASSERT_THROW(model->predict(test), Exception);
+    }
+    // try to predict on data which have more features - this should also throw an exception
+    Mat test(1, numFeatures + 1, CV_32FC1);
+    ASSERT_THROW(model->predict(test), Exception);
+}
 
 
 }} // namespace