Refactored SVMSGD class
authorMarina Noskova <marina.noskova@itseez.com>
Wed, 20 Jan 2016 09:59:44 +0000 (12:59 +0300)
committerMarina Noskova <marina.noskova@itseez.com>
Wed, 10 Feb 2016 13:56:14 +0000 (16:56 +0300)
include/opencv2/opencv.hpp
modules/ml/include/opencv2/ml.hpp
modules/ml/include/opencv2/ml/svmsgd.hpp [new file with mode: 0644]
modules/ml/src/precomp.hpp
modules/ml/src/svmsgd.cpp
modules/ml/test/test_mltests2.cpp
modules/ml/test/test_precomp.hpp
modules/ml/test/test_save_load.cpp
modules/ml/test/test_svmsgd.cpp [new file with mode: 0644]
modules/ts/src/ts_gtest.cpp
samples/cpp/train_svmsgd.cpp [new file with mode: 0644]

index 49b6a66..e411621 100644 (file)
@@ -75,6 +75,7 @@
 #endif
 #ifdef HAVE_OPENCV_ML
 #include "opencv2/ml.hpp"
+#include "opencv2/ml/svmsgd.hpp"
 #endif
 
 #endif
index d5debdb..791f580 100644 (file)
@@ -1513,126 +1513,6 @@ CV_EXPORTS void randMVNormal( InputArray mean, InputArray cov, int nsamples, Out
 CV_EXPORTS void createConcentricSpheresTestSet( int nsamples, int nfeatures, int nclasses,
                                                 OutputArray samples, OutputArray responses);
 
-/****************************************************************************************\
-*                        Stochastic Gradient Descent SVM Classifier                      *
-\****************************************************************************************/
-
-/*!
-@brief Stochastic Gradient Descent SVM classifier
-
-SVMSGD provides a fast and easy-to-use implementation of the SVM classifier using the Stochastic Gradient Descent approach, as presented in @cite bottou2010large.
-The gradient descent show amazing performance for large-scale problems, reducing the computing time. This allows a fast and reliable online update of the classifier for each new feature which
-is fundamental when dealing with variations of data over time (like weather and illumination changes in videosurveillance, for example).
-
-First, create the SVMSGD object. To enable the online update, a value for updateFrequency should be defined.
-
-Then the SVM model can be trained using the train features and the correspondent labels.
-
-After that, the label of a new feature vector can be predicted using the predict function. If the updateFrequency was defined in the constructor, the predict function will update the weights automatically.
-
-@code
-// Initialize object
-SVMSGD SvmSgd;
-
-// Train the Stochastic Gradient Descent SVM
-SvmSgd.train(trainFeatures, labels);
-
-// Predict label for the new feature vector (1xM)
-predictedLabel = SvmSgd.predict(newFeatureVector);
-@endcode
-
-*/
-class CV_EXPORTS_W SVMSGD {
-
-    public:
-        /** @brief SGDSVM constructor.
-
-        @param lambda regularization
-        @param learnRate learning rate
-        @param nIterations number of training iterations
-
-        */
-        SVMSGD(float lambda = 0.000001, float learnRate = 2, uint nIterations = 100000);
-
-        /** @brief SGDSVM constructor.
-
-        @param updateFrequency online update frequency
-        @param learnRateDecay learn rate decay over time: learnRate = learnRate * learnDecay
-        @param lambda regularization
-        @param learnRate learning rate
-        @param nIterations number of training iterations
-
-        */
-        SVMSGD(uint updateFrequency, float learnRateDecay = 1, float lambda = 0.000001, float learnRate = 2, uint nIterations = 100000);
-        virtual ~SVMSGD();
-        virtual SVMSGD* clone() const;
-
-        /** @brief Train the SGDSVM classifier.
-
-        The function trains the SGDSVM classifier using the train features and the correspondent labels (-1 or 1).
-
-        @param trainFeatures features used for training. Each row is a new sample.
-        @param labels mat (size Nx1 with N = number of features) with the label of each training feature.
-
-        */
-        virtual void train(cv::Mat trainFeatures, cv::Mat labels);
-
-        /** @brief Predict the label of a new feature vector.
-
-        The function predicts and returns the label of a new feature vector, using the previously trained SVM model.
-
-        @param newFeature new feature vector used for prediction
-
-        */
-        virtual float predict(cv::Mat newFeature);
-
-        /** @brief Returns the weights of the trained model.
-
-        */
-        virtual std::vector<float> getWeights(){ return _weights; };
-
-        /** @brief Sets the weights of the trained model.
-
-        @param weights weights used to predict the label of a new feature vector.
-
-        */
-        virtual void setWeights(std::vector<float> weights){ _weights = weights; };
-
-    private:
-        void updateWeights();
-        void generateRandomIndex();
-        float calcInnerProduct(float *rowDataPointer);
-        void updateWeights(float innerProduct, float *rowDataPointer, int label);
-
-        // Vector with SVM weights
-        std::vector<float> _weights;
-
-        // Random index generation
-        long long int _randomNumber;
-        unsigned int _randomIndex;
-
-        // Number of features and samples
-        unsigned int _nFeatures;
-        unsigned int _nTrainSamples;
-
-        // Parameters for learning
-        float _lambda;  //regularization
-        float _learnRate;  //learning rate
-        unsigned int _nIterations; //number of training iterations
-
-        // Vars to control the features slider matrix
-        bool _onlineUpdate;
-        bool _initPredict;
-        uint _slidingWindowSize;
-        uint _predictSlidingWindowSize;
-        float* _labelSlider;
-        float _learnRateDecay;
-
-        // Mat with features slider and correspondent counter
-        unsigned int _sliderCounter;
-        cv::Mat _featuresSlider;
-
-};
 
 //! @} ml
 
diff --git a/modules/ml/include/opencv2/ml/svmsgd.hpp b/modules/ml/include/opencv2/ml/svmsgd.hpp
new file mode 100644 (file)
index 0000000..f61a905
--- /dev/null
@@ -0,0 +1,134 @@
+/*M///////////////////////////////////////////////////////////////////////////////////////
+//
+//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
+//
+//  By downloading, copying, installing or using the software you agree to this license.
+//  If you do not agree to this license, do not download, install,
+//  copy or use the software.
+//
+//
+//                           License Agreement
+//                For Open Source Computer Vision Library
+//
+// Copyright (C) 2000, Intel Corporation, all rights reserved.
+// Copyright (C) 2013, OpenCV Foundation, all rights reserved.
+// Copyright (C) 2014, Itseez Inc, all rights reserved.
+// Third party copyrights are property of their respective owners.
+//
+// Redistribution and use in source and binary forms, with or without modification,
+// are permitted provided that the following conditions are met:
+//
+//   * Redistribution's of source code must retain the above copyright notice,
+//     this list of conditions and the following disclaimer.
+//
+//   * Redistribution's in binary form must reproduce the above copyright notice,
+//     this list of conditions and the following disclaimer in the documentation
+//     and/or other materials provided with the distribution.
+//
+//   * The name of the copyright holders may not be used to endorse or promote products
+//     derived from this software without specific prior written permission.
+//
+// This software is provided by the copyright holders and contributors "as is" and
+// any express or implied warranties, including, but not limited to, the implied
+// warranties of merchantability and fitness for a particular purpose are disclaimed.
+// In no event shall the Intel Corporation or contributors be liable for any direct,
+// indirect, incidental, special, exemplary, or consequential damages
+// (including, but not limited to, procurement of substitute goods or services;
+// loss of use, data, or profits; or business interruption) however caused
+// and on any theory of liability, whether in contract, strict liability,
+// or tort (including negligence or otherwise) arising in any way out of
+// the use of this software, even if advised of the possibility of such damage.
+//
+//M*/
+
+#ifndef __OPENCV_ML_SVMSGD_HPP__
+#define __OPENCV_ML_SVMSGD_HPP__
+
+#ifdef __cplusplus
+
+#include "opencv2/ml.hpp"
+
+namespace cv
+{
+namespace ml
+{
+
+
+/****************************************************************************************\
+*                        Stochastic Gradient Descent SVM Classifier                      *
+\****************************************************************************************/
+
+/*!
+@brief Stochastic Gradient Descent SVM classifier
+
+SVMSGD provides a fast and easy-to-use implementation of the SVM classifier using the Stochastic Gradient Descent approach, as presented in @cite bottou2010large.
+The gradient descent show amazing performance for large-scale problems, reducing the computing time. This allows a fast and reliable online update of the classifier for each new feature which
+is fundamental when dealing with variations of data over time (like weather and illumination changes in videosurveillance, for example).
+
+First, create the SVMSGD object. To enable the online update, a value for updateFrequency should be defined.
+
+Then the SVM model can be trained using the train features and the correspondent labels.
+
+After that, the label of a new feature vector can be predicted using the predict function. If the updateFrequency was defined in the constructor, the predict function will update the weights automatically.
+
+@code
+// Initialize object
+SVMSGD SvmSgd;
+
+// Train the Stochastic Gradient Descent SVM
+SvmSgd.train(trainFeatures, labels);
+
+// Predict label for the new feature vector (1xM)
+predictedLabel = SvmSgd.predict(newFeatureVector);
+@endcode
+
+*/
+
+class CV_EXPORTS_W SVMSGD : public cv::ml::StatModel
+{
+public:
+
+    enum SvmsgdType
+    {
+        ILLEGAL_VALUE,
+        SGD,                                     //Stochastic Gradient Descent
+        ASGD                                     //Average Stochastic Gradient Descent
+    };
+
+    /**
+     * @return the weights of the trained model.
+    */
+    CV_WRAP virtual Mat getWeights() = 0;
+
+    CV_WRAP virtual float getShift() = 0;
+
+    CV_WRAP static Ptr<SVMSGD> create();    
+
+    CV_WRAP virtual void setOptimalParameters(int type = ASGD) = 0;
+
+    CV_WRAP virtual int getType() const = 0;
+
+    CV_WRAP virtual void setType(int type) = 0;
+
+    CV_WRAP virtual float getLambda() const = 0;
+
+    CV_WRAP virtual void setLambda(float lambda) = 0;
+
+    CV_WRAP virtual float getGamma0() const = 0;
+
+    CV_WRAP virtual void setGamma0(float gamma0) = 0;
+
+    CV_WRAP virtual float getC() const = 0;
+
+    CV_WRAP virtual void setC(float c) = 0;
+
+    CV_WRAP virtual cv::TermCriteria getTermCriteria() const = 0;
+
+    CV_WRAP virtual void setTermCriteria(const cv::TermCriteria &val) = 0;
+};
+
+} //ml
+} //cv
+
+#endif  // __clpusplus
+#endif  // __OPENCV_ML_SVMSGD_HPP
index 8482198..9318e4a 100644 (file)
@@ -45,7 +45,7 @@
 #include "opencv2/ml.hpp"
 #include "opencv2/core/core_c.h"
 #include "opencv2/core/utility.hpp"
-
+#include "opencv2/ml/svmsgd.hpp"
 #include "opencv2/core/private.hpp"
 
 #include <assert.h>
index 3114e43..91377cf 100644 (file)
 //M*/
 
 #include "precomp.hpp"
+#include "limits"
 
 /****************************************************************************************\
 *                        Stochastic Gradient Descent SVM Classifier                      *
 \****************************************************************************************/
 
-namespace cv {
-namespace ml {
+namespace cv
+{
+namespace ml
+{
 
-SVMSGD::SVMSGD(float lambda, float learnRate, uint nIterations){
+class SVMSGDImpl : public SVMSGD
+{
 
-    // Initialize with random seed
-    _randomNumber = 1;
+public:
+    SVMSGDImpl();
 
-    // Initialize constants
-    _slidingWindowSize = 0;
-    _nFeatures = 0;
-    _predictSlidingWindowSize = 1;
+    virtual ~SVMSGDImpl() {}
 
-    // Initialize sliderCounter at index 0
-    _sliderCounter = 0;
+    virtual bool train(const Ptr<TrainData>& data, int);
 
-    // Parameters for learning
-    _lambda = lambda;  // regularization
-    _learnRate = learnRate;  // learning rate (ideally should be large at beginning and decay each iteration)
-    _nIterations = nIterations;  // number of training iterations
+    virtual float predict( InputArray samples, OutputArray results=noArray(), int flags = 0 ) const;
 
-    // True only in the first predict iteration
-    _initPredict = true;
+    virtual bool isClassifier() const { return params.svmsgdType == SGD || params.svmsgdType == ASGD; }
 
-    // Online update flag
-    _onlineUpdate = false;
-}
+    virtual bool isTrained() const;
 
-SVMSGD::SVMSGD(uint updateFrequency, float learnRateDecay, float lambda, float learnRate, uint nIterations){
+    virtual void clear();
 
-    // Initialize with random seed
-    _randomNumber = 1;
+    virtual void write(FileStorage& fs) const;
 
-    // Initialize constants
-    _slidingWindowSize = 0;
-    _nFeatures = 0;
-    _predictSlidingWindowSize = updateFrequency;
+    virtual void read(const FileNode& fn);
 
-    // Initialize sliderCounter at index 0
-    _sliderCounter = 0;
+    virtual Mat getWeights(){ return weights_; }
 
-    // Parameters for learning
-    _lambda = lambda;  // regularization
-    _learnRate = learnRate;  // learning rate (ideally should be large at beginning and decay each iteration)
-    _nIterations = nIterations;  // number of training iterations
+    virtual float getShift(){ return shift_; }
 
-    // True only in the first predict iteration
-    _initPredict = true;
+    virtual int getVarCount() const { return weights_.cols; }
 
-    // Online update flag
-    _onlineUpdate = true;
+    virtual String getDefaultName() const {return "opencv_ml_svmsgd";}
 
-    // Learn rate decay: _learnRate = _learnRate * _learnDecay
-    _learnRateDecay = learnRateDecay;
-}
+    virtual void setOptimalParameters(int type = ASGD);
 
-SVMSGD::~SVMSGD(){
+    virtual int getType() const;
 
-}
+    virtual void setType(int type);
+
+    CV_IMPL_PROPERTY(float, Lambda, params.lambda)
+    CV_IMPL_PROPERTY(float, Gamma0, params.gamma0)
+    CV_IMPL_PROPERTY(float, C, params.c)
+    CV_IMPL_PROPERTY_S(cv::TermCriteria, TermCriteria, params.termCrit)
 
-SVMSGD* SVMSGD::clone() const{
-    return new SVMSGD(*this);
+    private:
+        void updateWeights(InputArray sample, bool is_first_class, float gamma);
+    float calcShift(InputArray trainSamples, InputArray trainResponses) const;
+    std::pair<bool,bool> areClassesEmpty(Mat responses);
+    void writeParams( FileStorage& fs ) const;
+    void readParams( const FileNode& fn );
+    static inline bool isFirstClass(float val) { return val > 0; }
+
+
+    // Vector with SVM weights
+    Mat weights_;
+    float shift_;
+
+    // Random index generation
+    RNG rng_;
+
+    // Parameters for learning
+    struct SVMSGDParams
+    {
+        float lambda;                             //regularization
+        float gamma0;                             //learning rate
+        float c;
+        TermCriteria termCrit;
+        SvmsgdType svmsgdType;
+    };
+
+    SVMSGDParams params;
+};
+
+Ptr<SVMSGD> SVMSGD::create()
+{    
+    return makePtr<SVMSGDImpl>();
 }
 
-void SVMSGD::train(cv::Mat trainFeatures, cv::Mat labels){
 
-    // Initialize _nFeatures
-    _slidingWindowSize = trainFeatures.rows;
-    _nFeatures = trainFeatures.cols;
+bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
+{
+    clear();
+
+    Mat trainSamples = data->getTrainSamples();
+
+    // Initialize varCount
+    int trainSamplesCount_ = trainSamples.rows;
+    int varCount = trainSamples.cols;
 
-    float innerProduct;
     // Initialize weights vector with zeros
-    if (_weights.size()==0){
-        _weights.reserve(_nFeatures);
-        for (uint feat = 0; feat < _nFeatures; ++feat){
-            _weights.push_back(0.0);
-        }
+    weights_ = Mat::zeros(1, varCount, CV_32F);
+
+    Mat trainResponses = data->getTrainResponses();        // (trainSamplesCount x 1) matrix
+
+    std::pair<bool,bool> are_empty = areClassesEmpty(trainResponses);
+
+    if ( are_empty.first && are_empty.second )
+    {
+        weights_.release();
+        return false;
+    }
+    if ( are_empty.first || are_empty.second )
+    {
+        shift_ = are_empty.first ? -1 : 1;
+        return true;
+    }
+
+
+    Mat currentSample;
+    float gamma = 0;
+    Mat lastWeights = Mat::zeros(1, varCount, CV_32F);     //weights vector for calculating terminal criterion
+    Mat averageWeights;                                    //average weights vector for ASGD model
+    double err = DBL_MAX;
+    if (params.svmsgdType == ASGD)
+    {
+        averageWeights = Mat::zeros(1, varCount, CV_32F);
     }
 
     // Stochastic gradient descent SVM
-    for (uint iter = 0; iter < _nIterations; ++iter){
-        generateRandomIndex();
-        innerProduct = calcInnerProduct(trainFeatures.ptr<float>(_randomIndex));
-        int label = (labels.at<int>(_randomIndex,0) > 0) ? 1 : -1; // ensure that labels are -1 or 1
-        updateWeights(innerProduct, trainFeatures.ptr<float>(_randomIndex), label );
+    for (int iter = 0; (iter < params.termCrit.maxCount)&&(err > params.termCrit.epsilon); iter++)
+    {
+        //generate sample number
+        int randomNumber = rng_.uniform(0, trainSamplesCount_);
+
+        currentSample = trainSamples.row(randomNumber);
+
+        //update gamma
+        gamma = params.gamma0 * std::pow((1 + params.lambda * params.gamma0 * (float)iter), (-params.c));
+
+        bool is_first_class = isFirstClass(trainResponses.at<float>(randomNumber));
+        updateWeights( currentSample, is_first_class, gamma );
+
+        //average weights (only for ASGD model)
+        if (params.svmsgdType == ASGD)
+        {
+            averageWeights = ((float)iter/ (1 + (float)iter)) * averageWeights  + weights_ / (1 + (float) iter);
+        }
+
+        err = norm(weights_ - lastWeights);
+        weights_.copyTo(lastWeights);
+    }
+
+    if (params.svmsgdType == ASGD)
+    {
+        weights_ = averageWeights;
     }
+
+    shift_ = calcShift(trainSamples, trainResponses);
+
+    return true;
 }
 
-float SVMSGD::predict(cv::Mat newFeature){
-    float innerProduct;
+std::pair<bool,bool> SVMSGDImpl::areClassesEmpty(Mat responses)
+{
+    std::pair<bool,bool> are_classes_empty(true, true);
+    int limit_index = responses.rows;
+
+    for(int index = 0; index < limit_index; index++)
+    {
+        if (isFirstClass(responses.at<float>(index,0)))
+            are_classes_empty.first = false;
+        else
+            are_classes_empty.second = false;
 
-    if (_initPredict){
-        _nFeatures = newFeature.cols;
-        _slidingWindowSize = _predictSlidingWindowSize;
-        _featuresSlider = cv::Mat::zeros(_slidingWindowSize, _nFeatures, CV_32F);
-        _initPredict = false;
-        _labelSlider = new float[_predictSlidingWindowSize]();
-        _learnRate = _learnRate * _learnRateDecay;
+        if (!are_classes_empty.first && ! are_classes_empty.second)
+            break;
     }
 
-    innerProduct = calcInnerProduct(newFeature.ptr<float>(0));
+    return are_classes_empty;
+}
 
-    // Resultant label (-1 or 1)
-    int label = (innerProduct>=0) ? 1 : -1;
+float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
+{
+    float distance_to_classes[2] = { std::numeric_limits<float>::max(), std::numeric_limits<float>::max() };
 
-    if (_onlineUpdate){
-        // Update the featuresSlider with newFeature and _labelSlider with label
-        newFeature.row(0).copyTo(_featuresSlider.row(_sliderCounter));
-        _labelSlider[_sliderCounter] = float(label);
+    Mat trainSamples = _samples.getMat();
+    int trainSamplesCount = trainSamples.rows;
 
-        // Update weights with a random index
-        if (_sliderCounter == _slidingWindowSize-1){
-            generateRandomIndex();
-            updateWeights(innerProduct, _featuresSlider.ptr<float>(_randomIndex), int(_labelSlider[_randomIndex]) );
-        }
+    Mat trainResponses = _responses.getMat();
+
+    for (int samplesIndex = 0; samplesIndex < trainSamplesCount; samplesIndex++)
+    {
+        Mat currentSample = trainSamples.row(samplesIndex);
+        float scalar_product = currentSample.dot(weights_);
 
-        // _sliderCounter++ if < _slidingWindowSize
-        _sliderCounter = (_sliderCounter == _slidingWindowSize-1) ? 0 : (_sliderCounter+1);
+        bool is_first_class = isFirstClass(trainResponses.at<float>(samplesIndex));
+        int index = is_first_class ? 0:1;
+        float sign_to_mul = is_first_class ? 1 : -1;
+        float cur_distance = scalar_product * sign_to_mul ;
+
+        if (cur_distance < distance_to_classes[index])
+        {
+            distance_to_classes[index] = cur_distance;
+        }
     }
 
-    return float(label);
+    //todo: areClassesEmpty(); make const;
+    return -(distance_to_classes[0] - distance_to_classes[1]) / 2.f;
 }
 
-void SVMSGD::generateRandomIndex(){
-    // Choose random sample, using Mikolov's fast almost-uniform random number
-    _randomNumber = _randomNumber * (unsigned long long) 25214903917 + 11;
-    _randomIndex = uint(_randomNumber % (unsigned long long) _slidingWindowSize);
-}
+float SVMSGDImpl::predict( InputArray _samples, OutputArray _results, int ) const
+{
+    float result = 0;
+    cv::Mat samples = _samples.getMat();
+    int nSamples = samples.rows;
+    cv::Mat results;
 
-float SVMSGD::calcInnerProduct(float *rowDataPointer){
-    float innerProduct = 0;
-    for (uint feat = 0; feat < _nFeatures; ++feat){
-        innerProduct += _weights[feat] * rowDataPointer[feat];
+    CV_Assert( samples.cols == weights_.cols && samples.type() == CV_32F );
+
+    if( _results.needed() )
+    {
+        _results.create( nSamples, 1, samples.type() );
+        results = _results.getMat();
+    }
+    else
+    {
+        CV_Assert( nSamples == 1 );
+        results = Mat(1, 1, CV_32F, &result);
     }
-    return innerProduct;
+
+    Mat currentSample;
+    float criterion = 0;
+
+    for (int sampleIndex = 0; sampleIndex < nSamples; sampleIndex++)
+    {
+        currentSample = samples.row(sampleIndex);
+        criterion = currentSample.dot(weights_) + shift_;
+        results.at<float>(sampleIndex) = (criterion >= 0) ? 1 : -1;
+    }
+
+    return result;
 }
 
-void SVMSGD::updateWeights(float innerProduct, float *rowDataPointer, int label){
-    if (label * innerProduct > 1) {
+void SVMSGDImpl::updateWeights(InputArray _sample, bool is_first_class, float gamma)
+{
+    Mat sample = _sample.getMat();
+
+    int responce = is_first_class ? 1 : -1; // ensure that trainResponses are -1 or 1
+
+    if ( sample.dot(weights_) * responce > 1)
+    {
         // Not a support vector, only apply weight decay
-        for (uint feat = 0; feat < _nFeatures; feat++) {
-            _weights[feat] -= _learnRate * _lambda * _weights[feat];
-        }
-    } else {
+        weights_ *= (1.f - gamma * params.lambda);
+    }
+    else
+    {
         // It's a support vector, add it to the weights
-        for (uint feat = 0; feat < _nFeatures; feat++) {
-            _weights[feat] -= _learnRate * (_lambda * _weights[feat] - label * rowDataPointer[feat]);
-        }
+        weights_ -= (gamma * params.lambda) * weights_ - gamma * responce * sample;
+        //std::cout << "sample " << sample << std::endl;
+        //std::cout << "weights_ " << weights_ << std::endl;
+    }
+}
+
+bool SVMSGDImpl::isTrained() const
+{
+    return !weights_.empty();
+}
+
+void SVMSGDImpl::write(FileStorage& fs) const
+{
+    if( !isTrained() )
+        CV_Error( CV_StsParseError, "SVMSGD model data is invalid, it hasn't been trained" );
+
+    writeParams( fs );
+
+    fs << "shift" << shift_;
+    fs << "weights" << weights_;
+}
+
+void SVMSGDImpl::writeParams( FileStorage& fs ) const
+{
+    String SvmsgdTypeStr;
+
+    switch (params.svmsgdType)
+    {
+    case SGD:
+        SvmsgdTypeStr = "SGD";
+        break;
+    case ASGD:
+        SvmsgdTypeStr = "ASGD";
+        break;
+    case ILLEGAL_VALUE:
+        SvmsgdTypeStr = format("Uknown_%d", params.svmsgdType);
+    default:
+        std::cout << "params.svmsgdType isn't initialized" << std::endl;
+    }
+
+
+    fs << "svmsgdType" << SvmsgdTypeStr;
+
+    fs << "lambda" << params.lambda;
+    fs << "gamma0" << params.gamma0;
+    fs << "c" << params.c;
+
+    fs << "term_criteria" << "{:";
+    if( params.termCrit.type & TermCriteria::EPS )
+        fs << "epsilon" << params.termCrit.epsilon;
+    if( params.termCrit.type & TermCriteria::COUNT )
+        fs << "iterations" << params.termCrit.maxCount;
+    fs << "}";
+}
+
+
+
+void SVMSGDImpl::read(const FileNode& fn)
+{
+    clear();
+
+    readParams(fn);
+
+    shift_ = (float) fn["shift"];
+    fn["weights"] >> weights_;
+}
+
+void SVMSGDImpl::readParams( const FileNode& fn )
+{
+    String svmsgdTypeStr = (String)fn["svmsgdType"];
+    SvmsgdType svmsgdType =
+            svmsgdTypeStr == "SGD" ? SGD :
+                                     svmsgdTypeStr == "ASGD" ? ASGD : ILLEGAL_VALUE;
+
+    if( svmsgdType == ILLEGAL_VALUE )
+        CV_Error( CV_StsParseError, "Missing or invalid SVMSGD type" );
+
+    params.svmsgdType = svmsgdType;
+
+    CV_Assert ( fn["lambda"].isReal() );
+    params.lambda = (float)fn["lambda"];
+
+    CV_Assert ( fn["gamma0"].isReal() );
+    params.gamma0 = (float)fn["gamma0"];
+
+    CV_Assert ( fn["c"].isReal() );
+    params.c = (float)fn["c"];
+
+    FileNode tcnode = fn["term_criteria"];
+    if( !tcnode.empty() )
+    {
+        params.termCrit.epsilon = (double)tcnode["epsilon"];
+        params.termCrit.maxCount = (int)tcnode["iterations"];
+        params.termCrit.type = (params.termCrit.epsilon > 0 ? TermCriteria::EPS : 0) +
+                (params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0);
     }
+    else
+        params.termCrit = TermCriteria( TermCriteria::EPS + TermCriteria::COUNT, 1000, FLT_EPSILON );
+
 }
 
+void SVMSGDImpl::clear()
+{
+    weights_.release();
+    shift_ = 0;
 }
+
+
+SVMSGDImpl::SVMSGDImpl()
+{
+    clear();
+    rng_(0);
+
+    params.svmsgdType = ILLEGAL_VALUE;
+
+    // Parameters for learning
+    params.lambda = 0;                              // regularization
+    params.gamma0 = 0;                        // learning rate (ideally should be large at beginning and decay each iteration)
+    params.c = 0;
+
+    TermCriteria _termCrit(TermCriteria::COUNT + TermCriteria::EPS, 0, 0);
+    params.termCrit = _termCrit;
+}
+
+void SVMSGDImpl::setOptimalParameters(int type)
+{
+    switch (type)
+    {
+    case SGD:
+        params.svmsgdType = SGD;
+        params.lambda = 0.00001;
+        params.gamma0 = 0.05;
+        params.c = 1;
+        params.termCrit.maxCount = 50000;
+        params.termCrit.epsilon = 0.00000001;
+        break;
+
+    case ASGD:
+        params.svmsgdType = ASGD;
+        params.lambda = 0.00001;
+        params.gamma0 = 0.5;
+        params.c = 0.75;
+        params.termCrit.maxCount = 100000;
+        params.termCrit.epsilon = 0.000001;
+        break;
+
+    default:
+        CV_Error( CV_StsParseError, "SVMSGD model data is invalid" );
+    }
+}
+
+void SVMSGDImpl::setType(int type)
+{
+    switch (type)
+    {
+    case SGD:
+        params.svmsgdType = SGD;
+        break;
+    case ASGD:
+        params.svmsgdType = ASGD;
+        break;
+    default:
+        params.svmsgdType = ILLEGAL_VALUE;
+    }
+}
+
+int SVMSGDImpl::getType() const
+{
+    return params.svmsgdType;
 }
+}   //ml
+}   //cv
index 919fae6..6603a35 100644 (file)
@@ -193,6 +193,16 @@ int str_to_boost_type( String& str )
 // 8. rtrees
 // 9. ertrees
 
+int str_to_svmsgd_type( String& str )
+{
+    if ( !str.compare("SGD") )
+        return SVMSGD::SGD;
+    if ( !str.compare("ASGD") )
+        return SVMSGD::ASGD;
+    CV_Error( CV_StsBadArg, "incorrect boost type string" );
+    return -1;
+}
+
 // ---------------------------------- MLBaseTest ---------------------------------------------------
 
 CV_MLBaseTest::CV_MLBaseTest(const char* _modelName)
@@ -248,7 +258,9 @@ void CV_MLBaseTest::run( int )
 {
     string filename = ts->get_data_path();
     filename += get_validation_filename();
+
     validationFS.open( filename, FileStorage::READ );
+
     read_params( *validationFS );
 
     int code = cvtest::TS::OK;
@@ -436,6 +448,21 @@ int CV_MLBaseTest::train( int testCaseIdx )
         model = m;
     }
 
+    else if( modelName == CV_SVMSGD )
+    {
+        String svmsgdTypeStr;
+        modelParamsNode["svmsgdType"] >> svmsgdTypeStr;
+        Ptr<SVMSGD> m = SVMSGD::create();      
+        int type = str_to_svmsgd_type( svmsgdTypeStr );
+        m->setType(type);
+        //m->setType(str_to_svmsgd_type( svmsgdTypeStr ));
+        m->setLambda(modelParamsNode["lambda"]);
+        m->setGamma0(modelParamsNode["gamma0"]);
+        m->setC(modelParamsNode["c"]);
+        m->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 10000, 0.00001));
+        model = m;
+    }
+
     if( !model.empty() )
         is_trained = model->train(data, 0);
 
@@ -457,7 +484,7 @@ float CV_MLBaseTest::get_test_error( int /*testCaseIdx*/, vector<float> *resp )
     else if( modelName == CV_ANN )
         err = ann_calc_error( model, data, cls_map, type, resp );
     else if( modelName == CV_DTREE || modelName == CV_BOOST || modelName == CV_RTREES ||
-             modelName == CV_SVM || modelName == CV_NBAYES || modelName == CV_KNEAREST )
+             modelName == CV_SVM || modelName == CV_NBAYES || modelName == CV_KNEAREST || modelName == CV_SVMSGD )
         err = model->calcError( data, true, _resp );
     if( !_resp.empty() && resp )
         _resp.convertTo(*resp, CV_32F);
@@ -485,6 +512,8 @@ void CV_MLBaseTest::load( const char* filename )
         model = Algorithm::load<Boost>( filename );
     else if( modelName == CV_RTREES )
         model = Algorithm::load<RTrees>( filename );
+    else if( modelName == CV_SVMSGD )
+        model = Algorithm::load<SVMSGD>( filename );
     else
         CV_Error( CV_StsNotImplemented, "invalid stat model name");
 }
index 329b9bd..18cee96 100644 (file)
@@ -13,6 +13,7 @@
 #include <map>
 #include "opencv2/ts.hpp"
 #include "opencv2/ml.hpp"
+#include "opencv2/ml/svmsgd.hpp"
 #include "opencv2/core/core_c.h"
 
 #define CV_NBAYES   "nbayes"
@@ -24,6 +25,7 @@
 #define CV_BOOST    "boost"
 #define CV_RTREES   "rtrees"
 #define CV_ERTREES  "ertrees"
+#define CV_SVMSGD   "svmsgd"
 
 enum { CV_TRAIN_ERROR=0, CV_TEST_ERROR=1 };
 
@@ -38,6 +40,7 @@ using cv::ml::ANN_MLP;
 using cv::ml::DTrees;
 using cv::ml::Boost;
 using cv::ml::RTrees;
+using cv::ml::SVMSGD;
 
 class CV_MLBaseTest : public cvtest::BaseTest
 {
index 2d6f144..354c6e0 100644 (file)
@@ -150,12 +150,20 @@ int CV_SLMLTest::validate_test_results( int testCaseIdx )
 
 TEST(ML_NaiveBayes, save_load) { CV_SLMLTest test( CV_NBAYES ); test.safe_run(); }
 TEST(ML_KNearest, save_load) { CV_SLMLTest test( CV_KNEAREST ); test.safe_run(); }
-TEST(ML_SVM, save_load) { CV_SLMLTest test( CV_SVM ); test.safe_run(); }
+TEST(ML_SVM, save_load)
+{
+    CV_SLMLTest test( CV_SVM );
+    test.safe_run();
+}
 TEST(ML_ANN, save_load) { CV_SLMLTest test( CV_ANN ); test.safe_run(); }
 TEST(ML_DTree, save_load) { CV_SLMLTest test( CV_DTREE ); test.safe_run(); }
 TEST(ML_Boost, save_load) { CV_SLMLTest test( CV_BOOST ); test.safe_run(); }
 TEST(ML_RTrees, save_load) { CV_SLMLTest test( CV_RTREES ); test.safe_run(); }
 TEST(DISABLED_ML_ERTrees, save_load) { CV_SLMLTest test( CV_ERTREES ); test.safe_run(); }
+TEST(MV_SVMSGD, save_load){
+    CV_SLMLTest test( CV_SVMSGD );
+    test.safe_run();
+}
 
 class CV_LegacyTest : public cvtest::BaseTest
 {
@@ -201,6 +209,8 @@ protected:
             model = Algorithm::load<SVM>(filename);
         else if (modelName == CV_RTREES)
             model = Algorithm::load<RTrees>(filename);
+        else if (modelName == CV_SVMSGD)
+            model = Algorithm::load<SVMSGD>(filename);
         if (!model)
         {
             code = cvtest::TS::FAIL_INVALID_TEST_DATA;
@@ -260,6 +270,11 @@ TEST(ML_DTree, legacy_load) { CV_LegacyTest test(CV_DTREE, "_abalone.xml;_mushro
 TEST(ML_NBayes, legacy_load) { CV_LegacyTest test(CV_NBAYES, "_waveform.xml"); test.safe_run(); }
 TEST(ML_SVM, legacy_load) { CV_LegacyTest test(CV_SVM, "_poletelecomm.xml;_waveform.xml"); test.safe_run(); }
 TEST(ML_RTrees, legacy_load) { CV_LegacyTest test(CV_RTREES, "_waveform.xml"); test.safe_run(); }
+TEST(ML_SVMSGD, legacy_load)
+{
+    CV_LegacyTest test(CV_SVMSGD, "_waveform.xml");
+    test.safe_run();
+}
 
 /*TEST(ML_SVM, throw_exception_when_save_untrained_model)
 {
diff --git a/modules/ml/test/test_svmsgd.cpp b/modules/ml/test/test_svmsgd.cpp
new file mode 100644 (file)
index 0000000..9f4aafc
--- /dev/null
@@ -0,0 +1,182 @@
+/*M///////////////////////////////////////////////////////////////////////////////////////
+//
+//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
+//
+//  By downloading, copying, installing or using the software you agree to this license.
+//  If you do not agree to this license, do not download, install,
+//  copy or use the software.
+//
+//
+//                        Intel License Agreement
+//                For Open Source Computer Vision Library
+//
+// Copyright (C) 2000, Intel Corporation, all rights reserved.
+// Third party copyrights are property of their respective owners.
+//
+// Redistribution and use in source and binary forms, with or without modification,
+// are permitted provided that the following conditions are met:
+//
+//   * Redistribution's of source code must retain the above copyright notice,
+//     this list of conditions and the following disclaimer.
+//
+//   * Redistribution's in binary form must reproduce the above copyright notice,
+//     this list of conditions and the following disclaimer in the documentation
+//     and/or other materials provided with the distribution.
+//
+//   * The name of Intel Corporation may not be used to endorse or promote products
+//     derived from this software without specific prior written permission.
+//
+// This software is provided by the copyright holders and contributors "as is" and
+// any express or implied warranties, including, but not limited to, the implied
+// warranties of merchantability and fitness for a particular purpose are disclaimed.
+// In no event shall the Intel Corporation or contributors be liable for any direct,
+// indirect, incidental, special, exemplary, or consequential damages
+// (including, but not limited to, procurement of substitute goods or services;
+// loss of use, data, or profits; or business interruption) however caused
+// and on any theory of liability, whether in contract, strict liability,
+// or tort (including negligence or otherwise) arising in any way out of
+// the use of this software, even if advised of the possibility of such damage.
+//
+//M*/
+
+#include "test_precomp.hpp"
+#include "opencv2/highgui.hpp"
+
+using namespace cv;
+using namespace cv::ml;
+using cv::ml::SVMSGD;
+using cv::ml::TrainData;
+
+
+
+class CV_SVMSGDTrainTest : public cvtest::BaseTest
+{
+public:
+    CV_SVMSGDTrainTest(Mat _weights, float _shift);
+private:
+    virtual void run( int start_from );
+    float decisionFunction(Mat sample, Mat weights, float shift);
+
+    cv::Ptr<TrainData> data;
+    cv::Mat testSamples;
+    cv::Mat testResponses;
+    static const int TEST_VALUE_LIMIT = 50;
+};
+
+CV_SVMSGDTrainTest::CV_SVMSGDTrainTest(Mat weights, float shift)
+{
+    int datasize = 100000;
+    int varCount = weights.cols;
+    cv::Mat samples = cv::Mat::zeros( datasize, varCount, CV_32FC1 );
+    cv::Mat responses = cv::Mat::zeros( datasize, 1, CV_32FC1 );
+    cv::RNG rng(0);
+
+    float lowerLimit = -TEST_VALUE_LIMIT;
+    float upperLimit = TEST_VALUE_LIMIT;
+
+
+    rng.fill(samples, RNG::UNIFORM, lowerLimit, upperLimit);
+    for (int sampleIndex = 0; sampleIndex < datasize; sampleIndex++)
+    {
+        responses.at<float>( sampleIndex ) = decisionFunction(samples.row(sampleIndex), weights, shift) > 0 ? 1 : -1;
+    }
+
+    data = TrainData::create( samples, cv::ml::ROW_SAMPLE, responses );
+
+    int testSamplesCount = 100000;
+
+    testSamples.create(testSamplesCount, varCount, CV_32FC1);
+    rng.fill(testSamples, RNG::UNIFORM, lowerLimit, upperLimit);
+    testResponses.create(testSamplesCount, 1, CV_32FC1);
+
+    for (int i = 0 ; i < testSamplesCount; i++)
+    {
+        testResponses.at<float>(i) = decisionFunction(testSamples.row(i), weights, shift) > 0 ? 1 : -1;
+    }
+}
+
+void CV_SVMSGDTrainTest::run( int /*start_from*/ )
+{
+    cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
+
+    svmsgd->setOptimalParameters(SVMSGD::ASGD);
+
+    svmsgd->train( data );
+
+    Mat responses;
+
+    svmsgd->predict(testSamples, responses);
+
+    int errCount = 0;
+    int testSamplesCount = testSamples.rows;
+
+    for (int i = 0; i < testSamplesCount; i++)
+    {
+        if (responses.at<float>(i) * testResponses.at<float>(i) < 0 )
+            errCount++;
+    }
+
+    float err = (float)errCount / testSamplesCount;
+    std::cout << "err " << err << std::endl;
+
+    if ( err > 0.01 )
+    {
+        ts->set_failed_test_info( cvtest::TS::FAIL_BAD_ACCURACY );
+    }
+}
+
+float CV_SVMSGDTrainTest::decisionFunction(Mat sample, Mat weights, float shift)
+{
+    return sample.dot(weights) + shift;
+}
+
+TEST(ML_SVMSGD, train0)
+{
+    int varCount = 2;
+
+    Mat weights;
+    weights.create(1, varCount, CV_32FC1);
+    weights.at<float>(0) = 1;
+    weights.at<float>(1) = 0;
+
+    float shift = 5;
+
+    CV_SVMSGDTrainTest test(weights, shift);
+    test.safe_run();
+}
+
+TEST(ML_SVMSGD, train1)
+{
+    int varCount = 5;
+
+    Mat weights;
+    weights.create(1, varCount, CV_32FC1);
+
+    float lowerLimit = -1;
+    float upperLimit = 1;
+    cv::RNG rng(0);
+    rng.fill(weights, RNG::UNIFORM, lowerLimit, upperLimit);
+
+    float shift = rng.uniform(-5.f, 5.f);
+
+    CV_SVMSGDTrainTest test(weights, shift);
+    test.safe_run();
+}
+
+TEST(ML_SVMSGD, train2)
+{
+    int varCount = 100;
+
+    Mat weights;
+    weights.create(1, varCount, CV_32FC1);
+
+    float lowerLimit = -1;
+    float upperLimit = 1;
+    cv::RNG rng(0);
+    rng.fill(weights, RNG::UNIFORM, lowerLimit, upperLimit);
+
+    float shift = rng.uniform(-1000.f, 1000.f);
+
+    CV_SVMSGDTrainTest test(weights, shift);
+    test.safe_run();
+}
index 29a3996..5604eb7 100644 (file)
@@ -5659,7 +5659,7 @@ class TestCaseNameIs {
 
   // Returns true iff the name of test_case matches name_.
   bool operator()(const TestCase* test_case) const {
-    return test_case != NULL && strcmp(test_case->name(), name_.c_str()) == 0;
+     return test_case != NULL && strcmp(test_case->name(), name_.c_str()) == 0;
   }
 
  private:
diff --git a/samples/cpp/train_svmsgd.cpp b/samples/cpp/train_svmsgd.cpp
new file mode 100644 (file)
index 0000000..cee8217
--- /dev/null
@@ -0,0 +1,226 @@
+#include <opencv2/opencv.hpp>
+#include "opencv2/video/tracking.hpp"
+#include "opencv2/imgproc/imgproc.hpp"
+#include "opencv2/highgui/highgui.hpp"
+
+using namespace cv;
+using namespace cv::ml;
+
+#define WIDTH 841
+#define HEIGHT 594
+
+struct Data
+{
+    Mat img;
+    Mat samples;
+    Mat responses;
+    RNG rng;
+    //Point points[2];
+
+    Data()
+    {
+        img = Mat::zeros(HEIGHT, WIDTH, CV_8UC3);
+        imshow("Train svmsgd", img);
+    }
+};
+
+bool doTrain(const Mat samples,const Mat responses, Mat &weights, float &shift);
+bool findPointsForLine(const Mat &weights, float shift, Point (&points)[2]);
+bool findCrossPoint(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint);
+void fillSegments(std::vector<std::pair<Point,Point> > &segments);
+void redraw(Data data, const Point points[2]);
+void addPointsRetrainAndRedraw(Data &data, int x, int y);
+
+
+bool doTrain( const Mat samples, const Mat responses, Mat &weights, float &shift)
+{
+    cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
+    svmsgd->setOptimalParameters(SVMSGD::ASGD);
+    svmsgd->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 50000, 0.0000001));
+    svmsgd->setLambda(0.01);
+    svmsgd->setGamma0(1);
+   // svmsgd->setC(5);
+
+    cv::Ptr<TrainData> train_data = TrainData::create( samples, cv::ml::ROW_SAMPLE, responses );
+    svmsgd->train( train_data );
+
+    if (svmsgd->isTrained())
+    {
+        weights = svmsgd->getWeights();
+        shift = svmsgd->getShift();
+
+        std::cout << weights << std::endl;
+        std::cout << shift << std::endl;
+
+        return true;
+    }
+    return false;
+}
+
+
+bool findCrossPoint(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint)
+{
+    int x = 0;
+    int y = 0;
+    //с (0,0) всё плохо
+    if (segment.first.x == segment.second.x && weights.at<float>(1) != 0)
+    {
+        x = segment.first.x;
+        y = -(weights.at<float>(0) * x + shift) / weights.at<float>(1);
+        if (y >= 0 && y <= HEIGHT)
+        {
+            crossPoint.x = x;
+            crossPoint.y = y;
+            return true;
+        }
+    }
+    else if (segment.first.y == segment.second.y && weights.at<float>(0) != 0)
+    {
+        y = segment.first.y;
+        x = - (weights.at<float>(1) * y + shift) / weights.at<float>(0);
+        if (x >= 0 && x <= WIDTH)
+        {
+            crossPoint.x = x;
+            crossPoint.y = y;
+            return true;
+        }
+    }
+    return false;
+}
+
+bool findPointsForLine(const Mat &weights, float shift, Point (&points)[2])
+{
+    if (weights.empty())
+    {
+        return false;
+    }
+
+    int foundPointsCount = 0;
+    std::vector<std::pair<Point,Point> > segments;
+    fillSegments(segments);
+
+    for (int i = 0; i < 4; i++)
+    {
+        if (findCrossPoint(weights, shift, segments[i], points[foundPointsCount]))
+            foundPointsCount++;
+        if (foundPointsCount > 2)
+            break;
+    }
+    return true;
+}
+
+void fillSegments(std::vector<std::pair<Point,Point> > &segments)
+{
+    std::pair<Point,Point> curSegment;
+
+    curSegment.first = Point(0,0);
+    curSegment.second = Point(0,HEIGHT);
+    segments.push_back(curSegment);
+
+    curSegment.first = Point(0,0);
+    curSegment.second = Point(WIDTH,0);
+    segments.push_back(curSegment);
+
+    curSegment.first = Point(WIDTH,0);
+    curSegment.second = Point(WIDTH,HEIGHT);
+    segments.push_back(curSegment);
+
+    curSegment.first = Point(0,HEIGHT);
+    curSegment.second = Point(WIDTH,HEIGHT);
+    segments.push_back(curSegment);
+}
+
+void redraw(Data data, const Point points[2])
+{
+    data.img = Mat::zeros(HEIGHT, WIDTH, CV_8UC3);
+    Point center;
+    int radius = 3;
+    Scalar color;
+    for (int i = 0; i < data.samples.rows; i++)
+    {
+        center.x = data.samples.at<float>(i,0);
+        center.y = data.samples.at<float>(i,1);
+        color = (data.responses.at<float>(i) > 0) ? Scalar(128,128,0) : Scalar(0,128,128);
+        circle(data.img, center, radius, color, 5);
+    }
+    line(data.img, points[0],points[1],cv::Scalar(1,255,1));
+
+    imshow("Train svmsgd", data.img);
+}
+
+void addPointsRetrainAndRedraw(Data &data, int x, int y)
+{
+
+    Mat currentSample(1, 2, CV_32F);
+    //start
+/*
+    Mat _weights;
+    _weights.create(1, 2, CV_32FC1);
+    _weights.at<float>(0) = 1;
+    _weights.at<float>(1) = -1;
+
+    int _x, _y;
+
+    for (int i=0;i<199;i++)
+    {
+    _x = data.rng.uniform(0,800);
+    _y = data.rng.uniform(0,500);*/
+    currentSample.at<float>(0,0) = x;
+    currentSample.at<float>(0,1) = y;
+    //if (currentSample.dot(_weights) > 0)
+        //data.responses.push_back(1);
+   // else data.responses.push_back(-1);
+
+    //finish
+    data.samples.push_back(currentSample);
+
+
+
+    Mat weights(1, 2, CV_32F);
+    float shift = 0;
+
+    if (doTrain(data.samples, data.responses, weights, shift))
+    {
+        Point points[2];
+        shift = 0;
+
+        findPointsForLine(weights, shift, points);
+
+        redraw(data, points);
+    }
+}
+
+
+static void onMouse( int event, int x, int y, int, void* pData)
+{
+    Data &data = *(Data*)pData;
+
+    switch( event )
+    {
+    case CV_EVENT_LBUTTONUP:
+        data.responses.push_back(1);
+        addPointsRetrainAndRedraw(data, x, y);
+
+        break;
+
+    case CV_EVENT_RBUTTONDOWN:
+        data.responses.push_back(-1);
+        addPointsRetrainAndRedraw(data, x, y);
+        break;
+    }
+
+}
+
+int main()
+{
+
+    Data data;
+
+    setMouseCallback( "Train svmsgd", onMouse, &data );
+    waitKey();
+
+
+
+
+    return 0;
+}