From 9ddb23e02553ab1c117d581e31455fab039d714f Mon Sep 17 00:00:00 2001 From: Dmitriy Anisimov Date: Sat, 23 Aug 2014 18:41:32 +0400 Subject: [PATCH] first implementation KNearest wrapper on KDTree --- modules/ml/include/opencv2/ml.hpp | 8 +- modules/ml/src/knearest.cpp | 153 +++++++++++++++++++++++++++++- modules/ml/test/test_emknearestkmeans.cpp | 19 +++- 3 files changed, 173 insertions(+), 7 deletions(-) diff --git a/modules/ml/include/opencv2/ml.hpp b/modules/ml/include/opencv2/ml.hpp index 696facb..ebd11c7 100644 --- a/modules/ml/include/opencv2/ml.hpp +++ b/modules/ml/include/opencv2/ml.hpp @@ -230,10 +230,11 @@ public: class CV_EXPORTS_W_MAP Params { public: - Params(int defaultK=10, bool isclassifier=true); + Params(int defaultK=10, bool isclassifier_=true, int Emax_=INT_MAX); CV_PROP_RW int defaultK; CV_PROP_RW bool isclassifier; + CV_PROP_RW int Emax; // for implementation with KDTree }; virtual void setParams(const Params& p) = 0; virtual Params getParams() const = 0; @@ -241,7 +242,10 @@ public: OutputArray results, OutputArray neighborResponses=noArray(), OutputArray dist=noArray() ) const = 0; - static Ptr create(const Params& params=Params()); + + enum { DEFAULT=1, KDTREE=2 }; + + static Ptr create(const Params& params=Params(), int type=DEFAULT); }; /****************************************************************************************\ diff --git a/modules/ml/src/knearest.cpp b/modules/ml/src/knearest.cpp index 3ead322..6d2bebf 100644 --- a/modules/ml/src/knearest.cpp +++ b/modules/ml/src/knearest.cpp @@ -49,10 +49,11 @@ namespace cv { namespace ml { -KNearest::Params::Params(int k, bool isclassifier_) +KNearest::Params::Params(int k, bool isclassifier_, int Emax_) { defaultK = k; isclassifier = isclassifier_; + Emax = Emax_; } @@ -352,8 +353,156 @@ public: Params params; }; -Ptr KNearest::create(const Params& p) + +class KNearestKDTreeImpl : public KNearest +{ +public: + KNearestKDTreeImpl(const Params& p) + { + params = p; + } + + virtual ~KNearestKDTreeImpl() {} + + Params getParams() const { return params; } + void setParams(const Params& p) { params = p; } + + bool isClassifier() const { return params.isclassifier; } + bool isTrained() const { return !samples.empty(); } + + String getDefaultModelName() const { return "opencv_ml_knn_kd"; } + + void clear() + { + samples.release(); + responses.release(); + } + + int getVarCount() const { return samples.cols; } + + bool train( const Ptr& data, int flags ) + { + Mat new_samples = data->getTrainSamples(ROW_SAMPLE); + Mat new_responses; + data->getTrainResponses().convertTo(new_responses, CV_32F); + bool update = (flags & UPDATE_MODEL) != 0 && !samples.empty(); + + CV_Assert( new_samples.type() == CV_32F ); + + if( !update ) + { + clear(); + } + else + { + CV_Assert( new_samples.cols == samples.cols && + new_responses.cols == responses.cols ); + } + + samples.push_back(new_samples); + responses.push_back(new_responses); + + tr.build(samples); + + return true; + } + + float findNearest( InputArray _samples, int k, + OutputArray _results, + OutputArray _neighborResponses, + OutputArray _dists ) const + { + float result = 0.f; + CV_Assert( 0 < k ); + + Mat test_samples = _samples.getMat(); + CV_Assert( test_samples.type() == CV_32F && test_samples.cols == samples.cols ); + int testcount = test_samples.rows; + + if( testcount == 0 ) + { + _results.release(); + _neighborResponses.release(); + _dists.release(); + return 0.f; + } + + Mat res, nr, d; + if( _results.needed() ) + { + _results.create(testcount, 1, CV_32F); + res = _results.getMat(); + } + if( _neighborResponses.needed() ) + { + _neighborResponses.create(testcount, k, CV_32F); + nr = _neighborResponses.getMat(); + } + if( _dists.needed() ) + { + _dists.create(testcount, k, CV_32F); + d = _dists.getMat(); + } + + for (int i=0; ii) + { + _res = res.row(i); + } + if (nr.rows>i) + { + _nr = nr.row(i); + } + if (d.rows>i) + { + _d = d.row(i); + } + tr.findNearest(test_samples.row(i), k, params.Emax, _res, _nr, _d, noArray()); + } + + return result; // currently always 0 + } + + float predict(InputArray inputs, OutputArray outputs, int) const + { + return findNearest( inputs, params.defaultK, outputs, noArray(), noArray() ); + } + + void write( FileStorage& fs ) const + { + fs << "is_classifier" << (int)params.isclassifier; + fs << "default_k" << params.defaultK; + + fs << "samples" << samples; + fs << "responses" << responses; + } + + void read( const FileNode& fn ) + { + clear(); + params.isclassifier = (int)fn["is_classifier"] != 0; + params.defaultK = (int)fn["default_k"]; + + fn["samples"] >> samples; + fn["responses"] >> responses; + } + + KDTree tr; + + Mat samples; + Mat responses; + Params params; +}; + +Ptr KNearest::create(const Params& p, int type) { + if (KDTREE==type) + { + return makePtr(p); + } + return makePtr(p); } diff --git a/modules/ml/test/test_emknearestkmeans.cpp b/modules/ml/test/test_emknearestkmeans.cpp index 98b88c7..b404634 100644 --- a/modules/ml/test/test_emknearestkmeans.cpp +++ b/modules/ml/test/test_emknearestkmeans.cpp @@ -312,9 +312,11 @@ void CV_KNearestTest::run( int /*start_from*/ ) generateData( testData, testLabels, sizes, means, covs, CV_32FC1, CV_32FC1 ); int code = cvtest::TS::OK; - Ptr knearest = KNearest::create(true); - knearest->train(trainData, cv::ml::ROW_SAMPLE, trainLabels); - knearest->findNearest( testData, 4, bestLabels); + + // KNearest default implementation + Ptr knearest = KNearest::create(); + knearest->train(trainData, ml::ROW_SAMPLE, trainLabels); + knearest->findNearest(testData, 4, bestLabels); float err; if( !calcErr( bestLabels, testLabels, sizes, err, true ) ) { @@ -326,6 +328,17 @@ void CV_KNearestTest::run( int /*start_from*/ ) ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on test data.\n", err ); code = cvtest::TS::FAIL_BAD_ACCURACY; } + + // KNearest KDTree implementation + Ptr knearestKdt = KNearest::create(ml::KNearest::Params(), ml::KNearest::KDTREE); + knearestKdt->train(trainData, ml::ROW_SAMPLE, trainLabels); + knearestKdt->findNearest(testData, 4, bestLabels); + if( !calcErr( bestLabels, testLabels, sizes, err, true ) ) + { + ts->printf( cvtest::TS::LOG, "Bad output labels.\n" ); + code = cvtest::TS::FAIL_INVALID_OUTPUT; + } + ts->set_failed_test_info( code ); } -- 2.7.4