first implementation KNearest wrapper on KDTree
authorDmitriy Anisimov <avdmitry@gmail.com>
Sat, 23 Aug 2014 14:41:32 +0000 (18:41 +0400)
committerDmitriy Anisimov <avdmitry@gmail.com>
Sat, 23 Aug 2014 14:41:32 +0000 (18:41 +0400)
modules/ml/include/opencv2/ml.hpp
modules/ml/src/knearest.cpp
modules/ml/test/test_emknearestkmeans.cpp

index 696facb..ebd11c7 100644 (file)
@@ -230,10 +230,11 @@ public:
     class CV_EXPORTS_W_MAP Params
     {
     public:
-        Params(int defaultK=10, bool isclassifier=true);
+        Params(int defaultK=10, bool isclassifier_=true, int Emax_=INT_MAX);
 
         CV_PROP_RW int defaultK;
         CV_PROP_RW bool isclassifier;
+        CV_PROP_RW int Emax; // for implementation with KDTree
     };
     virtual void setParams(const Params& p) = 0;
     virtual Params getParams() const = 0;
@@ -241,7 +242,10 @@ public:
                                OutputArray results,
                                OutputArray neighborResponses=noArray(),
                                OutputArray dist=noArray() ) const = 0;
-    static Ptr<KNearest> create(const Params& params=Params());
+
+    enum { DEFAULT=1, KDTREE=2 };
+
+    static Ptr<KNearest> create(const Params& params=Params(), int type=DEFAULT);
 };
 
 /****************************************************************************************\
index 3ead322..6d2bebf 100644 (file)
 namespace cv {
 namespace ml {
 
-KNearest::Params::Params(int k, bool isclassifier_)
+KNearest::Params::Params(int k, bool isclassifier_, int Emax_)
 {
     defaultK = k;
     isclassifier = isclassifier_;
+    Emax = Emax_;
 }
 
 
@@ -352,8 +353,156 @@ public:
     Params params;
 };
 
-Ptr<KNearest> KNearest::create(const Params& p)
+
+class KNearestKDTreeImpl : public KNearest
+{
+public:
+    KNearestKDTreeImpl(const Params& p)
+    {
+        params = p;
+    }
+
+    virtual ~KNearestKDTreeImpl() {}
+
+    Params getParams() const { return params; }
+    void setParams(const Params& p) { params = p; }
+
+    bool isClassifier() const { return params.isclassifier; }
+    bool isTrained() const { return !samples.empty(); }
+
+    String getDefaultModelName() const { return "opencv_ml_knn_kd"; }
+
+    void clear()
+    {
+        samples.release();
+        responses.release();
+    }
+
+    int getVarCount() const { return samples.cols; }
+
+    bool train( const Ptr<TrainData>& data, int flags )
+    {
+        Mat new_samples = data->getTrainSamples(ROW_SAMPLE);
+        Mat new_responses;
+        data->getTrainResponses().convertTo(new_responses, CV_32F);
+        bool update = (flags & UPDATE_MODEL) != 0 && !samples.empty();
+
+        CV_Assert( new_samples.type() == CV_32F );
+
+        if( !update )
+        {
+            clear();
+        }
+        else
+        {
+            CV_Assert( new_samples.cols == samples.cols &&
+                       new_responses.cols == responses.cols );
+        }
+
+        samples.push_back(new_samples);
+        responses.push_back(new_responses);
+
+        tr.build(samples);
+
+        return true;
+    }
+
+    float findNearest( InputArray _samples, int k,
+                       OutputArray _results,
+                       OutputArray _neighborResponses,
+                       OutputArray _dists ) const
+    {
+        float result = 0.f;
+        CV_Assert( 0 < k );
+
+        Mat test_samples = _samples.getMat();
+        CV_Assert( test_samples.type() == CV_32F && test_samples.cols == samples.cols );
+        int testcount = test_samples.rows;
+
+        if( testcount == 0 )
+        {
+            _results.release();
+            _neighborResponses.release();
+            _dists.release();
+            return 0.f;
+        }
+
+        Mat res, nr, d;
+        if( _results.needed() )
+        {
+            _results.create(testcount, 1, CV_32F);
+            res = _results.getMat();
+        }
+        if( _neighborResponses.needed() )
+        {
+            _neighborResponses.create(testcount, k, CV_32F);
+            nr = _neighborResponses.getMat();
+        }
+        if( _dists.needed() )
+        {
+            _dists.create(testcount, k, CV_32F);
+            d = _dists.getMat();
+        }
+
+        for (int i=0; i<test_samples.rows; ++i)
+        {
+            Mat _res, _nr, _d;
+            if (res.rows>i)
+            {
+                _res = res.row(i);
+            }
+            if (nr.rows>i)
+            {
+                _nr = nr.row(i);
+            }
+            if (d.rows>i)
+            {
+                _d = d.row(i);
+            }
+            tr.findNearest(test_samples.row(i), k, params.Emax, _res, _nr, _d, noArray());
+        }
+
+        return result; // currently always 0
+    }
+
+    float predict(InputArray inputs, OutputArray outputs, int) const
+    {
+        return findNearest( inputs, params.defaultK, outputs, noArray(), noArray() );
+    }
+
+    void write( FileStorage& fs ) const
+    {
+        fs << "is_classifier" << (int)params.isclassifier;
+        fs << "default_k" << params.defaultK;
+
+        fs << "samples" << samples;
+        fs << "responses" << responses;
+    }
+
+    void read( const FileNode& fn )
+    {
+        clear();
+        params.isclassifier = (int)fn["is_classifier"] != 0;
+        params.defaultK = (int)fn["default_k"];
+
+        fn["samples"] >> samples;
+        fn["responses"] >> responses;
+    }
+
+    KDTree tr;
+
+    Mat samples;
+    Mat responses;
+    Params params;
+};
+
+Ptr<KNearest> KNearest::create(const Params& p, int type)
 {
+    if (KDTREE==type)
+    {
+        return makePtr<KNearestKDTreeImpl>(p);
+    }
+
     return makePtr<KNearestImpl>(p);
 }
 
index 98b88c7..b404634 100644 (file)
@@ -312,9 +312,11 @@ void CV_KNearestTest::run( int /*start_from*/ )
     generateData( testData, testLabels, sizes, means, covs, CV_32FC1, CV_32FC1 );
 
     int code = cvtest::TS::OK;
-    Ptr<KNearest> knearest = KNearest::create(true);
-    knearest->train(trainData, cv::ml::ROW_SAMPLE, trainLabels);
-    knearest->findNearest( testData, 4, bestLabels);
+
+    // KNearest default implementation
+    Ptr<KNearest> knearest = KNearest::create();
+    knearest->train(trainData, ml::ROW_SAMPLE, trainLabels);
+    knearest->findNearest(testData, 4, bestLabels);
     float err;
     if( !calcErr( bestLabels, testLabels, sizes, err, true ) )
     {
@@ -326,6 +328,17 @@ void CV_KNearestTest::run( int /*start_from*/ )
         ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on test data.\n", err );
         code = cvtest::TS::FAIL_BAD_ACCURACY;
     }
+
+    // KNearest KDTree implementation
+    Ptr<KNearest> knearestKdt = KNearest::create(ml::KNearest::Params(), ml::KNearest::KDTREE);
+    knearestKdt->train(trainData, ml::ROW_SAMPLE, trainLabels);
+    knearestKdt->findNearest(testData, 4, bestLabels);
+    if( !calcErr( bestLabels, testLabels, sizes, err, true ) )
+    {
+        ts->printf( cvtest::TS::LOG, "Bad output labels.\n" );
+        code = cvtest::TS::FAIL_INVALID_OUTPUT;
+    }
+
     ts->set_failed_test_info( code );
 }