ml: refactor non-virtual methods
authorberak <px1704@web.de>
Tue, 24 Apr 2018 10:11:59 +0000 (12:11 +0200)
committerberak <px1704@web.de>
Tue, 24 Apr 2018 11:23:27 +0000 (13:23 +0200)
modules/ml/include/opencv2/ml.hpp
modules/ml/src/data.cpp
modules/ml/src/rtrees.cpp
modules/ml/src/svm.cpp

index 2b694c8..357aac1 100644 (file)
@@ -198,7 +198,7 @@ public:
     CV_WRAP virtual Mat getTestSampleWeights() const = 0;
     CV_WRAP virtual Mat getVarIdx() const = 0;
     CV_WRAP virtual Mat getVarType() const = 0;
-    CV_WRAP Mat getVarSymbolFlags() const;
+    CV_WRAP virtual Mat getVarSymbolFlags() const = 0;
     CV_WRAP virtual int getResponseType() const = 0;
     CV_WRAP virtual Mat getTrainSampleIdx() const = 0;
     CV_WRAP virtual Mat getTestSampleIdx() const = 0;
@@ -234,10 +234,10 @@ public:
     CV_WRAP virtual void shuffleTrainTest() = 0;
 
     /** @brief Returns matrix of test samples */
-    CV_WRAP Mat getTestSamples() const;
+    CV_WRAP virtual Mat getTestSamples() const = 0;
 
     /** @brief Returns vector of symbolic names captured in loadFromCSV() */
-    CV_WRAP void getNames(std::vector<String>& names) const;
+    CV_WRAP virtual void getNames(std::vector<String>& names) const = 0;
 
     CV_WRAP static Mat getSubVector(const Mat& vec, const Mat& idx);
 
@@ -727,7 +727,7 @@ public:
     regression (SVM::EPS_SVR or SVM::NU_SVR). If it is SVM::ONE_CLASS, no optimization is made and
     the usual %SVM with parameters specified in params is executed.
     */
-    CV_WRAP bool trainAuto(InputArray samples,
+    CV_WRAP virtual bool trainAuto(InputArray samples,
             int layout,
             InputArray responses,
             int kFold = 10,
@@ -737,7 +737,7 @@ public:
             Ptr<ParamGrid> nuGrid     = SVM::getDefaultGridPtr(SVM::NU),
             Ptr<ParamGrid> coeffGrid  = SVM::getDefaultGridPtr(SVM::COEF),
             Ptr<ParamGrid> degreeGrid = SVM::getDefaultGridPtr(SVM::DEGREE),
-            bool balanced=false);
+            bool balanced=false) = 0;
 
     /** @brief Retrieves all the support vectors
 
@@ -752,7 +752,7 @@ public:
     support vector, used for prediction, was derived from. They are returned in a floating-point
     matrix, where the support vectors are stored as matrix rows.
      */
-    CV_WRAP Mat getUncompressedSupportVectors() const;
+    CV_WRAP virtual Mat getUncompressedSupportVectors() const = 0;
 
     /** @brief Retrieves the decision function
 
@@ -1273,7 +1273,7 @@ public:
         @param results Array where the result of the calculation will be written.
         @param flags Flags for defining the type of RTrees.
     */
-    CV_WRAP void getVotes(InputArray samples, OutputArray results, int flags) const;
+    CV_WRAP virtual void getVotes(InputArray samples, OutputArray results, int flags) const = 0;
 
     /** Creates the empty model.
     Use StatModel::train to train the model, StatModel::train to create and train the model,
index cbd1c3f..1067c31 100644 (file)
@@ -50,13 +50,6 @@ static const int VAR_MISSED = VAR_ORDERED;
 
 TrainData::~TrainData() {}
 
-Mat TrainData::getTestSamples() const
-{
-    Mat idx = getTestSampleIdx();
-    Mat samples = getSamples();
-    return idx.empty() ? Mat() : getSubVector(samples, idx);
-}
-
 Mat TrainData::getSubVector(const Mat& vec, const Mat& idx)
 {
     if( idx.empty() )
@@ -119,6 +112,7 @@ Mat TrainData::getSubVector(const Mat& vec, const Mat& idx)
     return subvec;
 }
 
+
 class TrainDataImpl CV_FINAL : public TrainData
 {
 public:
@@ -155,6 +149,12 @@ public:
         return layout == ROW_SAMPLE ? samples.cols : samples.rows;
     }
 
+    Mat getTestSamples() const CV_OVERRIDE
+    {
+        Mat idx = getTestSampleIdx();
+        return idx.empty() ? Mat() : getSubVector(samples, idx);
+    }
+
     Mat getSamples() const CV_OVERRIDE { return samples; }
     Mat getResponses() const CV_OVERRIDE { return responses; }
     Mat getMissing() const CV_OVERRIDE { return missing; }
@@ -987,6 +987,27 @@ public:
         }
     }
 
+    void getNames(std::vector<String>& names) const CV_OVERRIDE
+    {
+        size_t n = nameMap.size();
+        TrainDataImpl::MapType::const_iterator it = nameMap.begin(),
+                                               it_end = nameMap.end();
+        names.resize(n+1);
+        names[0] = "?";
+        for( ; it != it_end; ++it )
+        {
+            String s = it->first;
+            int label = it->second;
+            CV_Assert( label > 0 && label <= (int)n );
+            names[label] = s;
+        }
+    }
+
+    Mat getVarSymbolFlags() const CV_OVERRIDE
+    {
+        return varSymbolFlags;
+    }
+
     FILE* file;
     int layout;
     Mat samples, missing, varType, varIdx, varSymbolFlags, responses, missingSubst;
@@ -996,30 +1017,6 @@ public:
     MapType nameMap;
 };
 
-void TrainData::getNames(std::vector<String>& names) const
-{
-    const TrainDataImpl* impl = dynamic_cast<const TrainDataImpl*>(this);
-    CV_Assert(impl != 0);
-    size_t n = impl->nameMap.size();
-    TrainDataImpl::MapType::const_iterator it = impl->nameMap.begin(),
-                                           it_end = impl->nameMap.end();
-    names.resize(n+1);
-    names[0] = "?";
-    for( ; it != it_end; ++it )
-    {
-        String s = it->first;
-        int label = it->second;
-        CV_Assert( label > 0 && label <= (int)n );
-        names[label] = s;
-    }
-}
-
-Mat TrainData::getVarSymbolFlags() const
-{
-    const TrainDataImpl* impl = dynamic_cast<const TrainDataImpl*>(this);
-    CV_Assert(impl != 0);
-    return impl->varSymbolFlags;
-}
 
 Ptr<TrainData> TrainData::loadFromCSV(const String& filename,
                                       int headerLines,
index 0751e37..cc5253e 100644 (file)
@@ -453,6 +453,7 @@ public:
     inline void setRegressionAccuracy(float val) CV_OVERRIDE { impl.params.setRegressionAccuracy(val); }
     inline cv::Mat getPriors() const CV_OVERRIDE { return impl.params.getPriors(); }
     inline void setPriors(const cv::Mat& val) CV_OVERRIDE { impl.params.setPriors(val); }
+    inline void getVotes(InputArray input, OutputArray output, int flags) const CV_OVERRIDE {return impl.getVotes(input,output,flags);}
 
     RTreesImpl() {}
     virtual ~RTreesImpl() CV_OVERRIDE {}
@@ -485,12 +486,6 @@ public:
         impl.read(fn);
     }
 
-    void getVotes_( InputArray samples, OutputArray results, int flags ) const
-    {
-        CV_TRACE_FUNCTION();
-        impl.getVotes(samples, results, flags);
-    }
-
     Mat getVarImportance() const CV_OVERRIDE { return Mat_<float>(impl.varImportance, true); }
     int getVarCount() const CV_OVERRIDE { return impl.getVarCount(); }
 
@@ -519,15 +514,6 @@ Ptr<RTrees> RTrees::load(const String& filepath, const String& nodeName)
     return Algorithm::load<RTrees>(filepath, nodeName);
 }
 
-void RTrees::getVotes(InputArray input, OutputArray output, int flags) const
-{
-    CV_TRACE_FUNCTION();
-    const RTreesImpl* this_ = dynamic_cast<const RTreesImpl*>(this);
-    if(!this_)
-        CV_Error(Error::StsNotImplemented, "the class is not RTreesImpl");
-    return this_->getVotes_(input, output, flags);
-}
-
 }}
 
 // End of file.
index d6518ef..b1a073b 100644 (file)
@@ -1250,7 +1250,7 @@ public:
         uncompressed_sv.release();
     }
 
-    Mat getUncompressedSupportVectors_() const
+    Mat getUncompressedSupportVectors() const CV_OVERRIDE
     {
         return uncompressed_sv;
     }
@@ -1982,10 +1982,10 @@ public:
         bool returnDFVal;
     };
 
-    bool trainAuto_(InputArray samples, int layout,
+    bool trainAuto(InputArray samples, int layout,
             InputArray responses, int kfold, Ptr<ParamGrid> Cgrid,
             Ptr<ParamGrid> gammaGrid, Ptr<ParamGrid> pGrid, Ptr<ParamGrid> nuGrid,
-            Ptr<ParamGrid> coeffGrid, Ptr<ParamGrid> degreeGrid, bool balanced)
+            Ptr<ParamGrid> coeffGrid, Ptr<ParamGrid> degreeGrid, bool balanced) CV_OVERRIDE
     {
         Ptr<TrainData> data = TrainData::create(samples, layout, responses);
         return this->trainAuto(
@@ -2353,26 +2353,6 @@ Ptr<SVM> SVM::load(const String& filepath)
     return svm;
 }
 
-Mat SVM::getUncompressedSupportVectors() const
-{
-    const SVMImpl* this_ = dynamic_cast<const SVMImpl*>(this);
-    if(!this_)
-        CV_Error(Error::StsNotImplemented, "the class is not SVMImpl");
-    return this_->getUncompressedSupportVectors_();
-}
-
-bool SVM::trainAuto(InputArray samples, int layout,
-            InputArray responses, int kfold, Ptr<ParamGrid> Cgrid,
-            Ptr<ParamGrid> gammaGrid, Ptr<ParamGrid> pGrid, Ptr<ParamGrid> nuGrid,
-            Ptr<ParamGrid> coeffGrid, Ptr<ParamGrid> degreeGrid, bool balanced)
-{
-  SVMImpl* this_ = dynamic_cast<SVMImpl*>(this);
-  if (!this_) {
-    CV_Error(Error::StsNotImplemented, "the class is not SVMImpl");
-  }
-  return this_->trainAuto_(samples, layout, responses,
-    kfold, Cgrid, gammaGrid, pGrid, nuGrid, coeffGrid, degreeGrid, balanced);
-}
 
 }
 }