fixed compilation of some samples; fixed ANN_MLP::predict
authorVadim Pisarevsky <vadim.pisarevsky@gmail.com>
Wed, 30 Jul 2014 18:53:46 +0000 (22:53 +0400)
committerVadim Pisarevsky <vadim.pisarevsky@gmail.com>
Wed, 30 Jul 2014 18:53:46 +0000 (22:53 +0400)
modules/ml/src/ann_mlp.cpp
samples/cpp/bagofwords_classification.cpp
samples/cpp/em.cpp
samples/cpp/points_classifier.cpp

index 19f5572..750b3e6 100644 (file)
@@ -262,9 +262,9 @@ public:
                 int cols = layer_sizes[j];
 
                 layer_out = Mat(dn, cols, CV_64F, data);
-                Mat w = weights[i].rowRange(0, layer_in.cols);
+                Mat w = weights[j].rowRange(0, layer_in.cols);
                 gemm(layer_in, w, 1, noArray(), 0, layer_out);
-                calc_activ_func( layer_out, weights[i] );
+                calc_activ_func( layer_out, weights[j] );
 
                 layer_in = layer_out;
             }
@@ -682,6 +682,8 @@ public:
             train_backprop( inputs, outputs, sw, termcrit ) :
             train_rprop( inputs, outputs, sw, termcrit );
 
+        trained = true;
+
         return iter;
     }
 
index ef4f3c7..320acf3 100644 (file)
@@ -23,6 +23,7 @@
 #define DEBUG_DESC_PROGRESS
 
 using namespace cv;
+using namespace cv::ml;
 using namespace std;
 
 const string paramsFile = "params.xml";
@@ -677,7 +678,7 @@ void VocData::writeClassifierResultsFile( const string& out_dir, const string& o
         result_file.close();
     } else {
         string err_msg = "could not open classifier results file '" + output_file + "' for writing. Before running for the first time, a 'results' subdirectory should be created within the VOC dataset base directory. e.g. if the VOC data is stored in /VOC/VOC2010 then the path /VOC/results must be created.";
-        CV_Error(CV_StsError,err_msg.c_str());
+        CV_Error(Error::StsError,err_msg.c_str());
     }
 }
 
@@ -701,9 +702,9 @@ void VocData::writeClassifierResultsFile( const string& out_dir, const string& o
 string VocData::getResultsFilename(const string& obj_class, const VocTask task, const ObdDatasetType dataset, const int competition, const int number)
 {
     if ((competition < 1) && (competition != -1))
-        CV_Error(CV_StsBadArg,"competition argument should be a positive non-zero number or -1 to accept the default");
+        CV_Error(Error::StsBadArg,"competition argument should be a positive non-zero number or -1 to accept the default");
     if ((number < 1) && (number != -1))
-        CV_Error(CV_StsBadArg,"number argument should be a positive non-zero number or -1 to accept the default");
+        CV_Error(Error::StsBadArg,"number argument should be a positive non-zero number or -1 to accept the default");
 
     string dset, task_type;
 
@@ -815,7 +816,7 @@ void VocData::calcClassifierPrecRecall(const string& input_file, vector<float>&
             scoregt_file.close();
         } else {
             string err_msg = "could not open scoregt file '" + scoregt_file_str + "' for writing.";
-            CV_Error(CV_StsError,err_msg.c_str());
+            CV_Error(Error::StsError,err_msg.c_str());
         }
     }
 
@@ -974,7 +975,7 @@ void VocData::calcClassifierConfMatRow(const string& obj_class, const vector<Obd
         if (target_idx_it == output_headers.end())
         {
             string err_msg = "could not find the target object class '" + obj_class + "' in list of valid classes.";
-            CV_Error(CV_StsError,err_msg.c_str());
+            CV_Error(Error::StsError,err_msg.c_str());
         }
         /* convert iterator to index */
         target_idx = (int)std::distance(output_headers.begin(),target_idx_it);
@@ -1037,7 +1038,7 @@ void VocData::calcClassifierConfMatRow(const string& obj_class, const vector<Obd
                 if (class_idx_it == output_headers.end())
                 {
                     string err_msg = "could not find object class '" + img_objects[obj_idx].object_class + "' specified in the ground truth file of '" + images[ranking[image_idx]].id +"'in list of valid classes.";
-                    CV_Error(CV_StsError,err_msg.c_str());
+                    CV_Error(Error::StsError,err_msg.c_str());
                 }
                 /* convert iterator to index */
                 int class_idx = (int)std::distance(output_headers.begin(),class_idx_it);
@@ -1189,7 +1190,7 @@ void VocData::calcDetectorConfMatRow(const string& obj_class, const ObdDatasetTy
             if (class_idx_it == output_headers.end())
             {
                 string err_msg = "could not find object class '" + img_objects[max_gt_obj_idx].object_class + "' specified in the ground truth file of '" + images[ranking[image_idx]].id +"'in list of valid classes.";
-                CV_Error(CV_StsError,err_msg.c_str());
+                CV_Error(Error::StsError,err_msg.c_str());
             }
             /* convert iterator to index */
             int class_idx = (int)std::distance(output_headers.begin(),class_idx_it);
@@ -1282,7 +1283,7 @@ void VocData::savePrecRecallToGnuplot(const string& output_file, const vector<fl
         plot_file.close();
     } else {
         string err_msg = "could not open plot file '" + output_file_std + "' for writing.";
-        CV_Error(CV_StsError,err_msg.c_str());
+        CV_Error(Error::StsError,err_msg.c_str());
     }
 }
 
@@ -1446,7 +1447,7 @@ void VocData::readClassifierGroundTruth(const string& filename, vector<string>&
     if (!gtfile.is_open())
     {
         string err_msg = "could not open VOC ground truth textfile '" + filename + "'.";
-        CV_Error(CV_StsError,err_msg.c_str());
+        CV_Error(Error::StsError,err_msg.c_str());
     }
 
     string line;
@@ -1462,7 +1463,7 @@ void VocData::readClassifierGroundTruth(const string& filename, vector<string>&
             image_codes.push_back(image);
             object_present.push_back(obj_present == 1);
         } else {
-            if (!gtfile.eof()) CV_Error(CV_StsParseError,"error parsing VOC ground truth textfile.");
+            if (!gtfile.eof()) CV_Error(Error::StsParseError,"error parsing VOC ground truth textfile.");
         }
     }
     gtfile.close();
@@ -1488,13 +1489,13 @@ void VocData::readClassifierResultsFile(const string& input_file, vector<string>
                 image_codes.push_back(image);
                 scores.push_back(score);
             } else {
-                if(!result_file.eof()) CV_Error(CV_StsParseError,"error parsing VOC classifier results file.");
+                if(!result_file.eof()) CV_Error(Error::StsParseError,"error parsing VOC classifier results file.");
             }
         }
         result_file.close();
     } else {
         string err_msg = "could not open classifier results file '" + input_file + "' for reading.";
-        CV_Error(CV_StsError,err_msg.c_str());
+        CV_Error(Error::StsError,err_msg.c_str());
     }
 }
 
@@ -1545,13 +1546,13 @@ void VocData::readDetectorResultsFile(const string& input_file, vector<string>&
                     bounding_boxes[image_idx].push_back(bounding_box);
                 }
             } else {
-                if(!result_file.eof()) CV_Error(CV_StsParseError,"error parsing VOC detector results file.");
+                if(!result_file.eof()) CV_Error(Error::StsParseError,"error parsing VOC detector results file.");
             }
         }
         result_file.close();
     } else {
         string err_msg = "could not open detector results file '" + input_file + "' for reading.";
-        CV_Error(CV_StsError,err_msg.c_str());
+        CV_Error(Error::StsError,err_msg.c_str());
     }
 }
 
@@ -1595,23 +1596,23 @@ void VocData::extractVocObjects(const string filename, vector<ObdObject>& object
 
             //object class -------------
 
-            if (extractXMLBlock(object_contents, "name", 0, tag_contents) == -1) CV_Error(CV_StsError,"missing <name> tag in object definition of '" + filename + "'");
+            if (extractXMLBlock(object_contents, "name", 0, tag_contents) == -1) CV_Error(Error::StsError,"missing <name> tag in object definition of '" + filename + "'");
             object.object_class.swap(tag_contents);
 
             //object bounding box -------------
 
             int xmax, xmin, ymax, ymin;
 
-            if (extractXMLBlock(object_contents, "xmax", 0, tag_contents) == -1) CV_Error(CV_StsError,"missing <xmax> tag in object definition of '" + filename + "'");
+            if (extractXMLBlock(object_contents, "xmax", 0, tag_contents) == -1) CV_Error(Error::StsError,"missing <xmax> tag in object definition of '" + filename + "'");
             xmax = stringToInteger(tag_contents);
 
-            if (extractXMLBlock(object_contents, "xmin", 0, tag_contents) == -1) CV_Error(CV_StsError,"missing <xmin> tag in object definition of '" + filename + "'");
+            if (extractXMLBlock(object_contents, "xmin", 0, tag_contents) == -1) CV_Error(Error::StsError,"missing <xmin> tag in object definition of '" + filename + "'");
             xmin = stringToInteger(tag_contents);
 
-            if (extractXMLBlock(object_contents, "ymax", 0, tag_contents) == -1) CV_Error(CV_StsError,"missing <ymax> tag in object definition of '" + filename + "'");
+            if (extractXMLBlock(object_contents, "ymax", 0, tag_contents) == -1) CV_Error(Error::StsError,"missing <ymax> tag in object definition of '" + filename + "'");
             ymax = stringToInteger(tag_contents);
 
-            if (extractXMLBlock(object_contents, "ymin", 0, tag_contents) == -1) CV_Error(CV_StsError,"missing <ymin> tag in object definition of '" + filename + "'");
+            if (extractXMLBlock(object_contents, "ymin", 0, tag_contents) == -1) CV_Error(Error::StsError,"missing <ymin> tag in object definition of '" + filename + "'");
             ymin = stringToInteger(tag_contents);
 
             object.boundingBox.x = xmin-1;      //convert to 0-based indexing
@@ -1714,11 +1715,11 @@ void VocData::extractDataFromResultsFilename(const string& input_file, string& c
     size_t fnameend = input_file_std.rfind(".txt");
 
     if ((fnamestart == input_file_std.npos) || (fnameend == input_file_std.npos))
-        CV_Error(CV_StsError,"Could not extract filename of results file.");
+        CV_Error(Error::StsError,"Could not extract filename of results file.");
 
     ++fnamestart;
     if (fnamestart >= fnameend)
-        CV_Error(CV_StsError,"Could not extract filename of results file.");
+        CV_Error(Error::StsError,"Could not extract filename of results file.");
 
     //extract dataset and class names, triggering exception if the filename format is not correct
     string filename = input_file_std.substr(fnamestart, fnameend-fnamestart);
@@ -1729,11 +1730,11 @@ void VocData::extractDataFromResultsFilename(const string& input_file, string& c
     size_t classend = filename.find("_",classstart+1);
     if (classend == filename.npos) classend = filename.size();
     if ((datasetstart == filename.npos) || (classstart == filename.npos))
-        CV_Error(CV_StsError,"Error parsing results filename. Is it in standard format of 'comp<n>_{cls/det}_<dataset>_<objclass>.txt'?");
+        CV_Error(Error::StsError,"Error parsing results filename. Is it in standard format of 'comp<n>_{cls/det}_<dataset>_<objclass>.txt'?");
     ++datasetstart;
     ++classstart;
     if (((datasetstart-classstart) < 1) || ((classend-datasetstart) < 1))
-        CV_Error(CV_StsError,"Error parsing results filename. Is it in standard format of 'comp<n>_{cls/det}_<dataset>_<objclass>.txt'?");
+        CV_Error(Error::StsError,"Error parsing results filename. Is it in standard format of 'comp<n>_{cls/det}_<dataset>_<objclass>.txt'?");
 
     dataset_name = filename.substr(datasetstart,classstart-datasetstart-1);
     class_name = filename.substr(classstart,classend-classstart);
@@ -1781,7 +1782,7 @@ bool VocData::getClassifierGroundTruthImage(const string& obj_class, const strin
         return m_classifier_gt_all_present[std::distance(m_classifier_gt_all_ids.begin(),it)] != 0;
     } else {
         string err_msg = "could not find classifier ground truth for image '" + id + "' and class '" + obj_class + "'";
-        CV_Error(CV_StsError,err_msg.c_str());
+        CV_Error(Error::StsError,err_msg.c_str());
     }
 
     return true;
@@ -1814,7 +1815,7 @@ void VocData::getSortOrder(const vector<float>& values, vector<size_t>& order, b
 void VocData::readFileToString(const string filename, string& file_contents)
 {
     std::ifstream ifs(filename.c_str());
-    if (!ifs.is_open()) CV_Error(CV_StsError,"could not open text file");
+    if (!ifs.is_open()) CV_Error(Error::StsError,"could not open text file");
 
     stringstream oss;
     oss << ifs.rdbuf();
@@ -1829,7 +1830,7 @@ int VocData::stringToInteger(const string input_str)
     stringstream ss(input_str);
     if ((ss >> result).fail())
     {
-        CV_Error(CV_StsBadArg,"could not perform string to integer conversion");
+        CV_Error(Error::StsBadArg,"could not perform string to integer conversion");
     }
     return result;
 }
@@ -1841,7 +1842,7 @@ string VocData::integerToString(const int input_int)
     stringstream ss;
     if ((ss << input_int).fail())
     {
-        CV_Error(CV_StsBadArg,"could not perform integer to string conversion");
+        CV_Error(Error::StsBadArg,"could not perform integer to string conversion");
     }
     result = ss.str();
     return result;
@@ -2325,7 +2326,7 @@ static void removeBowImageDescriptorsByCount( vector<ObdImage>& images, vector<M
     CV_Assert( bowImageDescriptors.size() == objectPresent.size() );
 }
 
-static void setSVMParams( CvSVMParams& svmParams, CvMat& class_wts_cv, const Mat& responses, bool balanceClasses )
+static void setSVMParams( const SVM::Params& svmParams, Mat& class_wts_cv, const Mat& responses, bool balanceClasses )
 {
     int pos_ex = countNonZero(responses == 1);
     int neg_ex = countNonZero(responses == -1);
index a078588..be792a9 100644 (file)
@@ -2,6 +2,7 @@
 #include "opencv2/ml.hpp"
 
 using namespace cv;
+using namespace cv::ml;
 
 int main( int /*argc*/, char** /*argv*/ )
 {
@@ -34,8 +35,9 @@ int main( int /*argc*/, char** /*argv*/ )
     samples = samples.reshape(1, 0);
 
     // cluster the data
-    EM em_model(N, EM::COV_MAT_SPHERICAL, TermCriteria(TermCriteria::COUNT+TermCriteria::EPS, 300, 0.1));
-    em_model.train( samples, noArray(), labels, noArray() );
+    Ptr<EM> em_model = EM::train( samples, noArray(), labels, noArray(),
+            EM::Params(N, EM::COV_MAT_SPHERICAL,
+                       TermCriteria(TermCriteria::COUNT+TermCriteria::EPS, 300, 0.1)));
 
     // classify every image pixel
     for( i = 0; i < img.rows; i++ )
@@ -44,7 +46,7 @@ int main( int /*argc*/, char** /*argv*/ )
         {
             sample.at<float>(0) = (float)j;
             sample.at<float>(1) = (float)i;
-            int response = cvRound(em_model.predict( sample )[1]);
+            int response = cvRound(em_model->predict2( sample, noArray() )[1]);
             Scalar c = colors[response];
 
             circle( img, Point(j, i), 1, c*0.75, FILLED );
index 26858da..0a742f3 100644 (file)
@@ -12,6 +12,7 @@
 
 using namespace std;
 using namespace cv;
+using namespace cv::ml;
 
 const Scalar WHITE_COLOR = Scalar(255,255,255);
 const string winName = "points";
@@ -22,18 +23,20 @@ RNG rng;
 
 vector<Point>  trainedPoints;
 vector<int>    trainedPointsMarkers;
-vector<Scalar> classColors;
-
-#define _NBC_ 0 // normal Bayessian classifier
-#define _KNN_ 0 // k nearest neighbors classifier
-#define _SVM_ 0 // support vectors machine
+const int MAX_CLASSES = 2;
+vector<Vec3b>  classColors(MAX_CLASSES);
+int currentClass = 0;
+vector<int> classCounters(MAX_CLASSES);
+
+#define _NBC_ 1 // normal Bayessian classifier
+#define _KNN_ 1 // k nearest neighbors classifier
+#define _SVM_ 1 // support vectors machine
 #define _DT_  1 // decision tree
-#define _BT_  0 // ADA Boost
+#define _BT_  1 // ADA Boost
 #define _GBT_ 0 // gradient boosted trees
-#define _RF_  0 // random forest
-#define _ERT_ 0 // extremely randomized trees
-#define _ANN_ 0 // artificial neural networks
-#define _EM_  0 // expectation-maximization
+#define _RF_  1 // random forest
+#define _ANN_ 1 // artificial neural networks
+#define _EM_  1 // expectation-maximization
 
 static void on_mouse( int event, int x, int y, int /*flags*/, void* )
 {
@@ -44,76 +47,43 @@ static void on_mouse( int event, int x, int y, int /*flags*/, void* )
 
     if( event == EVENT_LBUTTONUP )
     {
-        if( classColors.empty() )
-            return;
-
         trainedPoints.push_back( Point(x,y) );
-        trainedPointsMarkers.push_back( (int)(classColors.size()-1) );
+        trainedPointsMarkers.push_back( currentClass );
+        classCounters[currentClass]++;
         updateFlag = true;
     }
-    else if( event == EVENT_RBUTTONUP )
-    {
-#if _BT_
-        if( classColors.size() < 2 )
-        {
-#endif
-            classColors.push_back( Scalar((uchar)rng(256), (uchar)rng(256), (uchar)rng(256)) );
-            updateFlag = true;
-#if _BT_
-        }
-        else
-            cout << "New class can not be added, because CvBoost can only be used for 2-class classification" << endl;
-#endif
-
-    }
 
     //draw
     if( updateFlag )
     {
         img = Scalar::all(0);
 
-        // put the text
-        stringstream text;
-        text << "current class " << classColors.size()-1;
-        putText( img, text.str(), Point(10,25), FONT_HERSHEY_SIMPLEX, 0.8f, WHITE_COLOR, 2 );
-
-        text.str("");
-        text << "total classes " << classColors.size();
-        putText( img, text.str(), Point(10,50), FONT_HERSHEY_SIMPLEX, 0.8f, WHITE_COLOR, 2 );
-
-        text.str("");
-        text << "total points " << trainedPoints.size();
-        putText(img, text.str(), Point(10,75), FONT_HERSHEY_SIMPLEX, 0.8f, WHITE_COLOR, 2 );
-
         // draw points
         for( size_t i = 0; i < trainedPoints.size(); i++ )
-            circle( img, trainedPoints[i], 5, classColors[trainedPointsMarkers[i]], -1 );
+        {
+            Vec3b c = classColors[trainedPointsMarkers[i]];
+            circle( img, trainedPoints[i], 5, Scalar(c), -1 );
+        }
 
         imshow( winName, img );
    }
 }
 
-static void prepare_train_data( Mat& samples, Mat& classes )
+static Mat prepare_train_samples(const vector<Point>& pts)
 {
-    Mat( trainedPoints ).copyTo( samples );
-    Mat( trainedPointsMarkers ).copyTo( classes );
-
-    // reshape trainData and change its type
-    samples = samples.reshape( 1, samples.rows );
-    samples.convertTo( samples, CV_32FC1 );
+    Mat samples;
+    Mat(pts).reshape(1, (int)pts.size()).convertTo(samples, CV_32F);
+    return samples;
 }
 
-#if _NBC_
-static void find_decision_boundary_NBC()
+static Ptr<TrainData> prepare_train_data()
 {
-    img.copyTo( imgDst );
-
-    Mat trainSamples, trainClasses;
-    prepare_train_data( trainSamples, trainClasses );
-
-    // learn classifier
-    CvNormalBayesClassifier normalBayesClassifier( trainSamples, trainClasses );
+    Mat samples = prepare_train_samples(trainedPoints);
+    return TrainData::create(samples, ROW_SAMPLE, Mat(trainedPointsMarkers));
+}
 
+static void predict_and_paint(const Ptr<StatModel>& model, Mat& dst)
+{
     Mat testSample( 1, 2, CV_32FC1 );
     for( int y = 0; y < img.rows; y += testStep )
     {
@@ -122,328 +92,146 @@ static void find_decision_boundary_NBC()
             testSample.at<float>(0) = (float)x;
             testSample.at<float>(1) = (float)y;
 
-            int response = (int)normalBayesClassifier.predict( testSample );
-            circle( imgDst, Point(x,y), 1, classColors[response] );
+            int response = (int)model->predict( testSample );
+            dst.at<Vec3b>(y, x) = classColors[response];
         }
     }
 }
-#endif
-
 
-#if _KNN_
-static void find_decision_boundary_KNN( int K )
+#if _NBC_
+static void find_decision_boundary_NBC()
 {
-    img.copyTo( imgDst );
-
-    Mat trainSamples, trainClasses;
-    prepare_train_data( trainSamples, trainClasses );
-
     // learn classifier
-#if defined HAVE_OPENCV_OCL && _OCL_KNN_
-    cv::ocl::KNearestNeighbour knnClassifier;
-    Mat temp, result;
-    knnClassifier.train(trainSamples, trainClasses, temp, false, K);
-    cv::ocl::oclMat testSample_ocl, reslut_ocl;
-#else
-    CvKNearest knnClassifier( trainSamples, trainClasses, Mat(), false, K );
-#endif
+    Ptr<NormalBayesClassifier> normalBayesClassifier = NormalBayesClassifier::create();
+    normalBayesClassifier->train(prepare_train_data());
 
-    Mat testSample( 1, 2, CV_32FC1 );
-    for( int y = 0; y < img.rows; y += testStep )
-    {
-        for( int x = 0; x < img.cols; x += testStep )
-        {
-            testSample.at<float>(0) = (float)x;
-            testSample.at<float>(1) = (float)y;
-#if defined HAVE_OPENCV_OCL && _OCL_KNN_
-            testSample_ocl.upload(testSample);
+    predict_and_paint(normalBayesClassifier, imgDst);
+}
+#endif
 
-            knnClassifier.find_nearest(testSample_ocl, K, reslut_ocl);
 
-            reslut_ocl.download(result);
-            int response = saturate_cast<int>(result.at<float>(0));
-            circle(imgDst, Point(x, y), 1, classColors[response]);
-#else
+#if _KNN_
+static void find_decision_boundary_KNN( int K )
+{
+    Ptr<KNearest> knn = KNearest::create(true);
+    knn->setDefaultK(K);
+    knn->train(prepare_train_data());
 
-            int response = (int)knnClassifier.find_nearest( testSample, K );
-            circle( imgDst, Point(x,y), 1, classColors[response] );
-#endif
-        }
-    }
+    predict_and_paint(knn, imgDst);
 }
 #endif
 
 #if _SVM_
-static void find_decision_boundary_SVM( CvSVMParams params )
+static void find_decision_boundary_SVM( SVM::Params params )
 {
-    img.copyTo( imgDst );
-
-    Mat trainSamples, trainClasses;
-    prepare_train_data( trainSamples, trainClasses );
+    Ptr<SVM> svm = SVM::create(params);
+    svm->train(prepare_train_data());
 
-    // learn classifier
-#if defined HAVE_OPENCV_OCL && _OCL_SVM_
-    cv::ocl::CvSVM_OCL svmClassifier(trainSamples, trainClasses, Mat(), Mat(), params);
-#else
-    CvSVM svmClassifier( trainSamples, trainClasses, Mat(), Mat(), params );
-#endif
+    predict_and_paint(svm, imgDst);
 
-    Mat testSample( 1, 2, CV_32FC1 );
-    for( int y = 0; y < img.rows; y += testStep )
+    Mat sv = svm->getSupportVectors();
+    for( int i = 0; i < sv.rows; i++ )
     {
-        for( int x = 0; x < img.cols; x += testStep )
-        {
-            testSample.at<float>(0) = (float)x;
-            testSample.at<float>(1) = (float)y;
-
-            int response = (int)svmClassifier.predict( testSample );
-            circle( imgDst, Point(x,y), 2, classColors[response], 1 );
-        }
+        const float* supportVector = sv.ptr<float>(i);
+        circle( imgDst, Point(saturate_cast<int>(supportVector[0]),saturate_cast<int>(supportVector[1])), 5, Scalar(255,255,255), -1 );
     }
-
-
-    for( int i = 0; i < svmClassifier.get_support_vector_count(); i++ )
-    {
-        const float* supportVector = svmClassifier.get_support_vector(i);
-        circle( imgDst, Point(saturate_cast<int>(supportVector[0]),saturate_cast<int>(supportVector[1])), 5, CV_RGB(255,255,255), -1 );
-    }
-
 }
 #endif
 
 #if _DT_
 static void find_decision_boundary_DT()
 {
-    img.copyTo( imgDst );
-
-    Mat trainSamples, trainClasses;
-    prepare_train_data( trainSamples, trainClasses );
-
-    // learn classifier
-    CvDTree  dtree;
-
-    Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
-    var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;
-
-    CvDTreeParams params;
-    params.max_depth = 8;
-    params.min_sample_count = 2;
-    params.use_surrogates = false;
-    params.cv_folds = 0; // the number of cross-validation folds
-    params.use_1se_rule = false;
-    params.truncate_pruned_tree = false;
-
-    dtree.train( trainSamples, CV_ROW_SAMPLE, trainClasses,
-                 Mat(), Mat(), var_types, Mat(), params );
-
-    Mat testSample(1, 2, CV_32FC1 );
-    for( int y = 0; y < img.rows; y += testStep )
-    {
-        for( int x = 0; x < img.cols; x += testStep )
-        {
-            testSample.at<float>(0) = (float)x;
-            testSample.at<float>(1) = (float)y;
-
-            int response = (int)dtree.predict( testSample )->value;
-            circle( imgDst, Point(x,y), 2, classColors[response], 1 );
-        }
-    }
+    DTrees::Params params;
+    params.maxDepth = 8;
+    params.minSampleCount = 2;
+    params.useSurrogates = false;
+    params.CVFolds = 0; // the number of cross-validation folds
+    params.use1SERule = false;
+    params.truncatePrunedTree = false;
+
+    Ptr<DTrees> dtree = DTrees::create(params);
+    dtree->train(prepare_train_data());
+
+    predict_and_paint(dtree, imgDst);
 }
 #endif
 
 #if _BT_
-void find_decision_boundary_BT()
+static void find_decision_boundary_BT()
 {
-    img.copyTo( imgDst );
-
-    Mat trainSamples, trainClasses;
-    prepare_train_data( trainSamples, trainClasses );
-
-    // learn classifier
-    CvBoost  boost;
-
-    Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
-    var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;
-
-    CvBoostParams  params( CvBoost::DISCRETE, // boost_type
-                           100, // weak_count
-                           0.95, // weight_trim_rate
-                           2, // max_depth
-                           false, //use_surrogates
-                           0 // priors
-                         );
-
-    boost.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), var_types, Mat(), params );
-
-    Mat testSample(1, 2, CV_32FC1 );
-    for( int y = 0; y < img.rows; y += testStep )
-    {
-        for( int x = 0; x < img.cols; x += testStep )
-        {
-            testSample.at<float>(0) = (float)x;
-            testSample.at<float>(1) = (float)y;
-
-            int response = (int)boost.predict( testSample );
-            circle( imgDst, Point(x,y), 2, classColors[response], 1 );
-        }
-    }
+    Boost::Params params( Boost::DISCRETE, // boost_type
+                          100, // weak_count
+                          0.95, // weight_trim_rate
+                          2, // max_depth
+                          false, //use_surrogates
+                          Mat() // priors
+                          );
+
+    Ptr<Boost> boost = Boost::create(params);
+    boost->train(prepare_train_data());
+    predict_and_paint(boost, imgDst);
 }
 
 #endif
 
 #if _GBT_
-void find_decision_boundary_GBT()
+static void find_decision_boundary_GBT()
 {
-    img.copyTo( imgDst );
-
-    Mat trainSamples, trainClasses;
-    prepare_train_data( trainSamples, trainClasses );
-
-    // learn classifier
-    CvGBTrees gbtrees;
-
-    Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
-    var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;
-
-    CvGBTreesParams  params( CvGBTrees::DEVIANCE_LOSS, // loss_function_type
-                             100, // weak_count
-                             0.1f, // shrinkage
-                             1.0f, // subsample_portion
-                             2, // max_depth
-                             false // use_surrogates )
-                           );
-
-    gbtrees.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), var_types, Mat(), params );
-
-    Mat testSample(1, 2, CV_32FC1 );
-    for( int y = 0; y < img.rows; y += testStep )
-    {
-        for( int x = 0; x < img.cols; x += testStep )
-        {
-            testSample.at<float>(0) = (float)x;
-            testSample.at<float>(1) = (float)y;
+    GBTrees::Params params( GBTrees::DEVIANCE_LOSS, // loss_function_type
+                         100, // weak_count
+                         0.1f, // shrinkage
+                         1.0f, // subsample_portion
+                         2, // max_depth
+                         false // use_surrogates )
+                         );
 
-            int response = (int)gbtrees.predict( testSample );
-            circle( imgDst, Point(x,y), 2, classColors[response], 1 );
-        }
-    }
+    Ptr<GBTrees> gbtrees = GBTrees::create(params);
+    gbtrees->train(prepare_train_data());
+    predict_and_paint(gbtrees, imgDst);
 }
-
 #endif
 
 #if _RF_
-void find_decision_boundary_RF()
+static void find_decision_boundary_RF()
 {
-    img.copyTo( imgDst );
-
-    Mat trainSamples, trainClasses;
-    prepare_train_data( trainSamples, trainClasses );
-
-    // learn classifier
-    CvRTrees  rtrees;
-    CvRTParams  params( 4, // max_depth,
+    RTrees::Params  params( 4, // max_depth,
                         2, // min_sample_count,
                         0.f, // regression_accuracy,
                         false, // use_surrogates,
                         16, // max_categories,
-                        0, // priors,
+                        Mat(), // priors,
                         false, // calc_var_importance,
                         1, // nactive_vars,
-                        5, // max_num_of_trees_in_the_forest,
-                        0, // forest_accuracy,
-                        CV_TERMCRIT_ITER // termcrit_type
+                        TermCriteria(TermCriteria::MAX_ITER, 5, 0) // max_num_of_trees_in_the_forest,
                        );
 
-    rtrees.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), Mat(), Mat(), params );
-
-    Mat testSample(1, 2, CV_32FC1 );
-    for( int y = 0; y < img.rows; y += testStep )
-    {
-        for( int x = 0; x < img.cols; x += testStep )
-        {
-            testSample.at<float>(0) = (float)x;
-            testSample.at<float>(1) = (float)y;
-
-            int response = (int)rtrees.predict( testSample );
-            circle( imgDst, Point(x,y), 2, classColors[response], 1 );
-        }
-    }
+    Ptr<RTrees> rtrees = RTrees::create(params);
+    rtrees->train(prepare_train_data());
+    predict_and_paint(rtrees, imgDst);
 }
 
 #endif
 
-#if _ERT_
-void find_decision_boundary_ERT()
-{
-    img.copyTo( imgDst );
-
-    Mat trainSamples, trainClasses;
-    prepare_train_data( trainSamples, trainClasses );
-
-    // learn classifier
-    CvERTrees ertrees;
-
-    Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
-    var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;
-
-    CvRTParams  params( 4, // max_depth,
-                        2, // min_sample_count,
-                        0.f, // regression_accuracy,
-                        false, // use_surrogates,
-                        16, // max_categories,
-                        0, // priors,
-                        false, // calc_var_importance,
-                        1, // nactive_vars,
-                        5, // max_num_of_trees_in_the_forest,
-                        0, // forest_accuracy,
-                        CV_TERMCRIT_ITER // termcrit_type
-                       );
-
-    ertrees.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), var_types, Mat(), params );
-
-    Mat testSample(1, 2, CV_32FC1 );
-    for( int y = 0; y < img.rows; y += testStep )
-    {
-        for( int x = 0; x < img.cols; x += testStep )
-        {
-            testSample.at<float>(0) = (float)x;
-            testSample.at<float>(1) = (float)y;
-
-            int response = (int)ertrees.predict( testSample );
-            circle( imgDst, Point(x,y), 2, classColors[response], 1 );
-        }
-    }
-}
-#endif
-
 #if _ANN_
-void find_decision_boundary_ANN( const Mat&  layer_sizes )
+static void find_decision_boundary_ANN( const Mat&  layer_sizes )
 {
-    img.copyTo( imgDst );
+    ANN_MLP::Params params(TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS, 300, FLT_EPSILON),
+                           ANN_MLP::Params::BACKPROP, 0.001);
+    Ptr<ANN_MLP> ann = ANN_MLP::create(layer_sizes, params, ANN_MLP::SIGMOID_SYM, 1, 1 );
 
-    Mat trainSamples, trainClasses;
-    prepare_train_data( trainSamples, trainClasses );
-
-    // prerare trainClasses
-    trainClasses.create( trainedPoints.size(), classColors.size(), CV_32FC1 );
-    for( int i = 0; i <  trainClasses.rows; i++ )
+    Mat trainClasses = Mat::zeros( trainedPoints.size(), classColors.size(), CV_32FC1 );
+    for( int i = 0; i < trainClasses.rows; i++ )
     {
-        for( int k = 0; k < trainClasses.cols; k++ )
-        {
-            if( k == trainedPointsMarkers[i] )
-                trainClasses.at<float>(i,k) = 1;
-            else
-                trainClasses.at<float>(i,k) = 0;
-        }
+        trainClasses.at<float>(i, trainedPointsMarkers[i]) = 1.f;
     }
 
-    Mat weights( 1, trainedPoints.size(), CV_32FC1, Scalar::all(1) );
+    Mat samples = prepare_train_samples(trainedPoints);
+    Ptr<TrainData> tdata = TrainData::create(samples, ROW_SAMPLE, trainClasses);
 
-    // learn classifier
-    CvANN_MLP  ann( layer_sizes, CvANN_MLP::SIGMOID_SYM, 1, 1 );
-    ann.train( trainSamples, trainClasses, weights );
+    ann->train(tdata);
 
     Mat testSample( 1, 2, CV_32FC1 );
+    Mat outputs;
     for( int y = 0; y < img.rows; y += testStep )
     {
         for( int x = 0; x < img.cols; x += testStep )
@@ -451,49 +239,50 @@ void find_decision_boundary_ANN( const Mat&  layer_sizes )
             testSample.at<float>(0) = (float)x;
             testSample.at<float>(1) = (float)y;
 
-            Mat outputs( 1, classColors.size(), CV_32FC1, testSample.data );
-            ann.predict( testSample, outputs );
+            ann->predict( testSample, outputs );
             Point maxLoc;
             minMaxLoc( outputs, 0, 0, 0, &maxLoc );
-            circle( imgDst, Point(x,y), 2, classColors[maxLoc.x], 1 );
+            imgDst.at<Vec3b>(y, x) = classColors[maxLoc.x];
         }
     }
 }
 #endif
 
 #if _EM_
-void find_decision_boundary_EM()
+static void find_decision_boundary_EM()
 {
     img.copyTo( imgDst );
 
-    Mat trainSamples, trainClasses;
-    prepare_train_data( trainSamples, trainClasses );
-
-    vector<cv::EM> em_models(classColors.size());
+    Mat samples = prepare_train_samples(trainedPoints);
 
-    CV_Assert((int)trainClasses.total() == trainSamples.rows);
-    CV_Assert((int)trainClasses.type() == CV_32SC1);
+    int i, j, nmodels = (int)classColors.size();
+    vector<Ptr<EM> > em_models(nmodels);
+    Mat modelSamples;
 
-    for(size_t modelIndex = 0; modelIndex < em_models.size(); modelIndex++)
+    for( i = 0; i < nmodels; i++ )
     {
         const int componentCount = 3;
-        em_models[modelIndex] = EM(componentCount, cv::EM::COV_MAT_DIAGONAL);
 
-        Mat modelSamples;
-        for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
+        modelSamples.release();
+        for( j = 0; j < samples.rows; j++ )
         {
-            if(trainClasses.at<int>(sampleIndex) == (int)modelIndex)
-                modelSamples.push_back(trainSamples.row(sampleIndex));
+            if( trainedPointsMarkers[j] == i )
+                modelSamples.push_back(samples.row(j));
         }
 
         // learn models
-        if(!modelSamples.empty())
-            em_models[modelIndex].train(modelSamples);
+        if( !modelSamples.empty() )
+        {
+            em_models[i] = EM::train(modelSamples, noArray(), noArray(), noArray(),
+                                   EM::Params(componentCount, EM::COV_MAT_DIAGONAL));
+        }
     }
 
     // classify coordinate plane points using the bayes classifier, i.e.
     // y(x) = arg max_i=1_modelsCount likelihoods_i(x)
     Mat testSample(1, 2, CV_32FC1 );
+    Mat logLikelihoods(1, nmodels, CV_64FC1, Scalar(-DBL_MAX));
+
     for( int y = 0; y < img.rows; y += testStep )
     {
         for( int x = 0; x < img.cols; x += testStep )
@@ -501,17 +290,14 @@ void find_decision_boundary_EM()
             testSample.at<float>(0) = (float)x;
             testSample.at<float>(1) = (float)y;
 
-            Mat logLikelihoods(1, em_models.size(), CV_64FC1, Scalar(-DBL_MAX));
-            for(size_t modelIndex = 0; modelIndex < em_models.size(); modelIndex++)
+            for( i = 0; i < nmodels; i++ )
             {
-                if(em_models[modelIndex].isTrained())
-                    logLikelihoods.at<double>(modelIndex) = em_models[modelIndex].predict(testSample)[0];
+                if( !em_models[i].empty() )
+                    logLikelihoods.at<double>(i) = em_models[i]->predict2(testSample, noArray())[0];
             }
             Point maxLoc;
             minMaxLoc(logLikelihoods, 0, 0, 0, &maxLoc);
-
-            int response = maxLoc.x;
-            circle( imgDst, Point(x,y), 2, classColors[response], 1 );
+            imgDst.at<Vec3b>(y, x) = classColors[maxLoc.x];
         }
     }
 }
@@ -520,7 +306,7 @@ void find_decision_boundary_EM()
 int main()
 {
     cout << "Use:" << endl
-         << "  right mouse button - to add new class;" << endl
+         << "  key '0' .. '1' - switch to class #n" << endl
          << "  left mouse button - to add new point;" << endl
          << "  key 'r' - to run the ML model;" << endl
          << "  key 'i' - to init (clear) the data." << endl << endl;
@@ -532,6 +318,9 @@ int main()
     imshow( "points", img );
     setMouseCallback( "points", on_mouse );
 
+    classColors[0] = Vec3b(0, 255, 0);
+    classColors[1] = Vec3b(0, 0, 255);
+
     for(;;)
     {
         uchar key = (uchar)waitKey();
@@ -542,15 +331,28 @@ int main()
         {
             img = Scalar::all(0);
 
-            classColors.clear();
             trainedPoints.clear();
             trainedPointsMarkers.clear();
+            classCounters.assign(MAX_CLASSES, 0);
 
             imshow( winName, img );
         }
 
+        if( key == '0' || key == '1' )
+        {
+            currentClass = key - '0';
+        }
+
         if( key == 'r' ) // run
         {
+            double minVal = 0;
+            minMaxLoc(classCounters, &minVal, 0, 0, 0);
+            if( minVal == 0 )
+            {
+                printf("each class should have at least 1 point\n");
+                continue;
+            }
+            img.copyTo( imgDst );
 #if _NBC_
             find_decision_boundary_NBC();
             namedWindow( "NormalBayesClassifier", WINDOW_AUTOSIZE );
@@ -570,16 +372,16 @@ int main()
 
 #if _SVM_
             //(1)-(2)separable and not sets
-            CvSVMParams params;
-            params.svm_type = CvSVM::C_SVC;
-            params.kernel_type = CvSVM::POLY; //CvSVM::LINEAR;
+            SVM::Params params;
+            params.svmType = SVM::C_SVC;
+            params.kernelType = SVM::POLY; //CvSVM::LINEAR;
             params.degree = 0.5;
             params.gamma = 1;
             params.coef0 = 1;
             params.C = 1;
             params.nu = 0.5;
             params.p = 0;
-            params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 1000, 0.01);
+            params.termCrit = TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS, 1000, 0.01);
 
             find_decision_boundary_SVM( params );
             namedWindow( "classificationSVM1", WINDOW_AUTOSIZE );
@@ -615,12 +417,6 @@ int main()
             imshow( "RF", imgDst);
 #endif
 
-#if _ERT_
-            find_decision_boundary_ERT();
-            namedWindow( "ERT", WINDOW_AUTOSIZE );
-            imshow( "ERT", imgDst);
-#endif
-
 #if _ANN_
             Mat layer_sizes1( 1, 3, CV_32SC1 );
             layer_sizes1.at<int>(0) = 2;