changed GBT parameters in sample
authorMaria Dimashova <no@email>
Wed, 4 May 2011 14:49:02 +0000 (14:49 +0000)
committerMaria Dimashova <no@email>
Wed, 4 May 2011 14:49:02 +0000 (14:49 +0000)
samples/cpp/points_classifier.cpp

index acf01ab..0479e29 100644 (file)
@@ -272,6 +272,7 @@ void find_decision_boundary_GBT()
 
     Mat trainSamples, trainClasses;
     prepare_train_data( trainSamples, trainClasses );
+    trainClasses.convertTo( trainClasses, CV_32FC1 );
 
     // learn classifier
     CvGBTrees gbtrees;
@@ -279,13 +280,13 @@ void find_decision_boundary_GBT()
     Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
     var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;
 
-    CvGBTreesParams  params( CvGBTrees::SQUARED_LOSS, // loss_function_type
+    CvGBTreesParams  params( CvGBTrees::DEVIANCE_LOSS, // loss_function_type
                              100, // weak_count
-                             0.05f, // shrinkage
-                             0.6f, // subsample_portion
+                             0.1f, // shrinkage
+                             1.0f, // subsample_portion
                              2, // max_depth
                              false // use_surrogates )
-                         );
+                           );
 
     gbtrees.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), var_types, Mat(), params );
 
@@ -315,10 +316,6 @@ void find_decision_boundary_RF()
 
     // learn classifier
     CvRTrees  rtrees;
-
-    Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
-    var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;
-
     CvRTParams  params( 4, // max_depth,
                         2, // min_sample_count,
                         0.f, // regression_accuracy,
@@ -332,7 +329,7 @@ void find_decision_boundary_RF()
                         CV_TERMCRIT_ITER // termcrit_type
                        );
 
-    rtrees.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), var_types, Mat(), params );
+    rtrees.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), Mat(), Mat(), params );
 
     Mat testSample(1, 2, CV_32FC1 );
     for( int y = 0; y < img.rows; y += testStep )