1 #include "opencv2/core/core.hpp"
2 #include "opencv2/ml/ml.hpp"
10 using namespace cv::ml;
// Prints a description of the sample and of the letter-recognition dataset.
14 printf("\nThe sample demonstrates how to train Random Trees classifier\n"
15 "(or Boosting classifier, or MLP, or Knearest, or Nbayes, or Support Vector Machines - see main()) using the provided dataset.\n"
17 "We use the sample database letter-recognition.data\n"
18 "from UCI Repository, here is the link:\n"
20 "Newman, D.J. & Hettich, S. & Blake, C.L. & Merz, C.J. (1998).\n"
21 "UCI Repository of machine learning databases\n"
22 "[http://www.ics.uci.edu/~mlearn/MLRepository.html].\n"
23 "Irvine, CA: University of California, Department of Information and Computer Science.\n"
25 "The dataset consists of 20000 feature vectors along with the\n"
26 "responses - capital latin letters A..Z.\n"
27 "The first 16000 (10000 for boosting) samples are used for training\n"
28 "and the remaining 4000 (10000 for boosting) - to test the classifier.\n"
29 "======================================================\n");
// Prints the command-line usage; the trailing explanation names every
// alternative classifier that has a flag in the list.
30 printf("\nThis is letter recognition sample.\n"
31 "The usage: letter_recog [-data <path to letter-recognition.data>] \\\n"
32 " [-save <output XML file for the classifier>] \\\n"
33 " [-load <XML file with the pre-trained classifier>] \\\n"
34 " [-boost|-mlp|-knearest|-nbayes|-svm] # to use boost/mlp/knearest/nbayes/SVM classifier instead of default Random Trees\n" );
37 // This function reads data and responses from the file <filename>
// Parameters:
//   filename   - path of the text file to read
//   var_count  - number of float values parsed from each line
//   _data      - output matrix; parsed rows are appended with push_back
//   _responses - output matrix; receives a copy of the collected labels
39 read_num_class_data( const string& filename, int var_count,
40 Mat* _data, Mat* _responses )
// Scratch 1 x var_count float row holding the values of the current line.
45 Mat el_ptr(1, var_count, CV_32F);
// Labels are collected here first and copied to *_responses at the end.
47 vector<int> responses;
// Discard whatever the caller's response matrix held before.
50 _responses->release();
52 FILE* f = fopen( filename.c_str(), "rt" );
55 cout << "Could not read the database " << filename << endl;
// True when fgets fails (end of file or error) or the line contains no comma.
62 if( !fgets( buf, M, f ) || !strchr( buf, ',' ) )
// The label is the character code of the first byte of the line.
64 responses.push_back((int)buf[0]);
66 for( i = 0; i < var_count; i++ )
// %f parses one float into element i; %n stores the number of characters
// consumed so far into n.
// NOTE(review): the return value of sscanf is not checked on this line.
69 sscanf( ptr, "%f%n", &el_ptr.at<float>(i), &n );
// Append the parsed values as a new row of the data matrix.
74 _data->push_back(el_ptr);
// Wrap the label vector in a Mat and copy it into the output.
77 Mat(responses).copyTo(*_responses);
79 cout << "The database " << filename << " is loaded.\n";
// Loads a model of type T from <filename_to_load> via StatModel::load<T> and
// reports on stdout either that the file could not be read or that the
// classifier is loaded.
85 static Ptr<T> load_classifier(const string& filename_to_load)
87 // load classifier from the specified file
88 Ptr<T> model = StatModel::load<T>( filename_to_load );
90 cout << "Could not read the classifier " << filename_to_load << endl;
92 cout << "The classifier " << filename_to_load << " is loaded.\n";
// Builds a TrainData object from <data> and <responses> in which only the
// first <ntrain_samples> rows are marked in the sample mask.
98 prepare_train_data(const Mat& data, const Mat& responses, int ntrain_samples)
// 1 x data.rows byte mask, initially all zero.
100 Mat sample_idx = Mat::zeros( 1, data.rows, CV_8U );
// View onto the first ntrain_samples mask entries; setting it to 1 writes
// through to sample_idx.
101 Mat train_samples = sample_idx.colRange(0, ntrain_samples);
102 train_samples.setTo(Scalar::all(1));
104 int nvars = data.cols;
// One type entry per data column plus one extra entry at index nvars
// (presumably the response variable - confirm against TrainData::create).
105 Mat var_type( nvars + 1, 1, CV_8U );
106 var_type.setTo(Scalar::all(VAR_ORDERED));
// Only the extra entry is categorical; every data column stays ordered.
107 var_type.at<uchar>(nvars) = VAR_CATEGORICAL;
// Samples are stored one per row (ROW_SAMPLE).
109 return TrainData::create(data, ROW_SAMPLE, responses,
110 noArray(), sample_idx, noArray(), var_type);
// Shorthand for building a TermCriteria from an iteration limit and an epsilon.
// MAX_ITER is always part of the type; EPS is added only when eps > 0.
113 inline TermCriteria TC(int iters, double eps)
115 return TermCriteria(TermCriteria::MAX_ITER + (eps > 0 ? TermCriteria::EPS : 0), iters, eps);
// Predicts every row of <data> with <model>, scores each prediction against
// <responses>, prints the train/test recognition rates and optionally saves
// the model.
//   model            - trained classifier to evaluate
//   data             - all samples, one per row
//   responses        - integer label for each row of <data>
//   ntrain_samples   - threshold used by the i < ntrain_samples test below
//                      (presumably: rows below it count as training rows)
//   rdelta           - offset added to the prediction before comparing it
//                      with the stored label
//   filename_to_save - output file; nothing is saved when it is empty
118 static void test_and_save_classifier(const Ptr<StatModel>& model,
119 const Mat& data, const Mat& responses,
120 int ntrain_samples, int rdelta,
121 const string& filename_to_save)
123 int i, nsamples_all = data.rows;
124 double train_hr = 0, test_hr = 0;
126 // compute prediction error on train and test data
127 for( i = 0; i < nsamples_all; i++ )
129 Mat sample = data.row(i);
131 float r = model->predict( sample );
// r becomes 1 for a correct prediction (after the rdelta shift) and 0 otherwise.
132 r = std::abs(r + rdelta - responses.at<int>(i)) <= FLT_EPSILON ? 1.f : 0.f;
134 if( i < ntrain_samples )
// Guard the division the same way as train_hr below: with no test rows the
// unguarded form divides by zero and prints NaN.
140 test_hr = nsamples_all > ntrain_samples ? test_hr/(nsamples_all - ntrain_samples) : 1.;
141 train_hr = ntrain_samples > 0 ? train_hr/ntrain_samples : 1.;
143 printf( "Recognition rate: train = %.1f%%, test = %.1f%%\n",
144 train_hr*100., test_hr*100. );
146 if( !filename_to_save.empty() )
148 model->save( filename_to_save );
// Trains (or loads) a Random Trees classifier on the letter data, evaluates
// it, and prints the number of trees and the variable importance.
//   data_filename    - dataset file passed to read_num_class_data
//   filename_to_save - forwarded to test_and_save_classifier
//   filename_to_load - when not empty, the classifier is loaded from this file
154 build_rtrees_classifier( const string& data_filename,
155 const string& filename_to_save,
156 const string& filename_to_load )
// Each sample line is parsed as 16 feature values.
160 bool ok = read_num_class_data( data_filename, 16, &data, &responses );
166 int nsamples_all = data.rows;
// 80% of the rows (truncated to an integer) are used for training.
167 int ntrain_samples = (int)(nsamples_all*0.8);
169 // Create or load Random Trees classifier
170 if( !filename_to_load.empty() )
172 model = load_classifier<RTrees>(filename_to_load);
179 // create classifier by using <data> and <responses>
180 cout << "Training the classifier ...\n";
181 Ptr<TrainData> tdata = prepare_train_data(data, responses, ntrain_samples);
// The termination criteria are at most 100 iterations or eps 0.01; see the
// RTrees::Params documentation for the meaning of the other arguments.
182 model = StatModel::train<RTrees>(tdata, RTrees::Params(10,10,0,false,15,Mat(),true,4,TC(100,0.01f)));
// rdelta is 0: predictions are compared with the stored labels directly.
186 test_and_save_classifier(model, data, responses, ntrain_samples, 0, filename_to_save);
187 cout << "Number of trees: " << model->getRoots().size() << endl;
189 // Print variable importance
190 Mat var_importance = model->getVarImportance();
191 if( !var_importance.empty() )
// Sum of all importances, used to print each one as a percentage.
193 double rt_imp_sum = sum( var_importance )[0];
194 printf("var#\timportance (in %%):\n");
195 int i, n = (int)var_importance.total();
196 for( i = 0; i < n; i++ )
197 printf( "%-2d\t%-4.1f\n", i, 100.f*var_importance.at<float>(i)/rt_imp_sum);
// Trains (or loads) a boosted-tree classifier on the letter data and prints
// its recognition rates. The multi-class problem is turned into a 2-class one
// by "unrolling" every training sample once per class (see the comment below).
//   data_filename    - dataset file passed to read_num_class_data
//   filename_to_save - when not empty, the model is saved to this file
//   filename_to_load - when not empty, the classifier is loaded from this file
205 build_boost_classifier( const string& data_filename,
206 const string& filename_to_save,
207 const string& filename_to_load )
209 const int class_count = 26;
214 bool ok = read_num_class_data( data_filename, 16, &data, &responses );
221 int nsamples_all = data.rows;
// Half of the rows (truncated to an integer) are used for training.
222 int ntrain_samples = (int)(nsamples_all*0.5);
223 int var_count = data.cols;
225 // Create or load Boosted Tree classifier
226 if( !filename_to_load.empty() )
228 model = load_classifier<Boost>(filename_to_load);
235 // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
237 // As currently boosted tree classifier in MLL can only be trained
238 // for 2-class problems, we transform the training database by
239 // "unrolling" each training sample as many times as the number of
240 // classes (26) that we have.
242 // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// One row per (training sample, class) pair, with one extra column that
// holds the candidate class index.
244 Mat new_data( ntrain_samples*class_count, var_count + 1, CV_32F );
245 Mat new_responses( ntrain_samples*class_count, 1, CV_32S );
247 // 1. unroll the database type mask
248 printf( "Unrolling the database...\n");
249 for( i = 0; i < ntrain_samples; i++ )
251 const float* data_row = data.ptr<float>(i);
252 for( j = 0; j < class_count; j++ )
254 float* new_data_row = (float*)new_data.ptr<float>(i*class_count+j);
// Copy the original features, then append the candidate class index j.
255 memcpy(new_data_row, data_row, var_count*sizeof(data_row[0]));
256 new_data_row[var_count] = (float)j;
// The binary response is 1 when the sample's label equals letter j+'A', else 0.
257 new_responses.at<int>(i*class_count + j) = responses.at<int>(i) == j+'A';
// var_count + 2 type entries: all ordered except the last two, which are
// categorical (index var_count is the appended class-index column; index
// var_count+1 is presumably the response - confirm against TrainData::create).
261 Mat var_type( 1, var_count + 2, CV_8U );
262 var_type.setTo(Scalar::all(VAR_ORDERED));
263 var_type.at<uchar>(var_count) = var_type.at<uchar>(var_count+1) = VAR_CATEGORICAL;
265 Ptr<TrainData> tdata = TrainData::create(new_data, ROW_SAMPLE, new_responses,
266 noArray(), noArray(), noArray(), var_type);
267 vector<double> priors(2);
271 cout << "Training the classifier (may take a few minutes)...\n";
272 model = StatModel::train<Boost>(tdata, Boost::Params(Boost::GENTLE, 100, 0.95, 5, false, Mat(priors) ));
// Reusable 1 x (var_count + 1) sample used to query the model per class.
276 Mat temp_sample( 1, var_count + 1, CV_32F );
277 float* tptr = temp_sample.ptr<float>();
279 // compute prediction error on train and test data
280 double train_hr = 0, test_hr = 0;
281 for( i = 0; i < nsamples_all; i++ )
284 double max_sum = -DBL_MAX;
285 const float* ptr = data.ptr<float>(i);
286 for( k = 0; k < var_count; k++ )
289 for( j = 0; j < class_count; j++ )
// Query the model with candidate class j in the extra column and take the
// raw (unthresholded) output.
291 tptr[var_count] = (float)j;
292 float s = model->predict( temp_sample, noArray(), StatModel::RAW_OUTPUT );
296 best_class = j + 'A';
// r is 1 when the selected class matches the stored label, else 0.
300 double r = std::abs(best_class - responses.at<int>(i)) < FLT_EPSILON ? 1 : 0;
301 if( i < ntrain_samples )
// Guard the division the same way as train_hr below: with no test rows the
// unguarded form divides by zero and prints NaN.
307 test_hr = nsamples_all > ntrain_samples ? test_hr/(nsamples_all-ntrain_samples) : 1.;
308 train_hr = ntrain_samples > 0 ? train_hr/ntrain_samples : 1.;
309 printf( "Recognition rate: train = %.1f%%, test = %.1f%%\n",
310 train_hr*100., test_hr*100. );
312 cout << "Number of trees: " << model->getRoots().size() << endl;
314 // Save classifier to file if needed
315 if( !filename_to_save.empty() )
316 model->save( filename_to_save );
// Trains (or loads) a multi-layer perceptron on the letter data and evaluates it.
//   data_filename    - dataset file passed to read_num_class_data
//   filename_to_save - forwarded to test_and_save_classifier
//   filename_to_load - when not empty, the classifier is loaded from this file
323 build_mlp_classifier( const string& data_filename,
324 const string& filename_to_save,
325 const string& filename_to_load )
327 const int class_count = 26;
331 bool ok = read_num_class_data( data_filename, 16, &data, &responses );
337 int nsamples_all = data.rows;
// 80% of the rows (truncated to an integer) are used for training.
338 int ntrain_samples = (int)(nsamples_all*0.8);
340 // Create or load MLP classifier
341 if( !filename_to_load.empty() )
343 model = load_classifier<ANN_MLP>(filename_to_load);
350 // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
352 // MLP does not support categorical variables explicitly.
353 // So, instead of the output class label, we will use
354 // a binary vector of <class_count> components for training and,
355 // therefore, MLP will give us a vector of "probabilities" at the
358 // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// Training uses only the first ntrain_samples rows; the response matrix has
// one column per class and starts out all zero.
360 Mat train_data = data.rowRange(0, ntrain_samples);
361 Mat train_responses = Mat::zeros( ntrain_samples, class_count, CV_32F );
363 // 1. unroll the responses
364 cout << "Unrolling the responses...\n";
365 for( int i = 0; i < ntrain_samples; i++ )
// Labels are character codes; subtracting 'A' gives the column to set to 1.
367 int cls_label = responses.at<int>(i) - 'A';
368 train_responses.at<float>(i, cls_label) = 1.f;
371 // 2. train classifier
// Layer sizes: data.cols inputs, two layers of 100, class_count outputs.
372 int layer_sz[] = { data.cols, 100, 100, class_count };
373 int nlayers = (int)(sizeof(layer_sz)/sizeof(layer_sz[0]));
// layer_sizes wraps the layer_sz array as a 1 x nlayers integer matrix.
374 Mat layer_sizes( 1, nlayers, CV_32S, layer_sz );
// NOTE(review): two alternative settings follow (BACKPROP with 0.001, RPROP
// with 0.1); which one is active is decided by lines not shown here - confirm.
377 int method = ANN_MLP::Params::BACKPROP;
378 double method_param = 0.001;
381 int method = ANN_MLP::Params::RPROP;
382 double method_param = 0.1;
386 Ptr<TrainData> tdata = TrainData::create(train_data, ROW_SAMPLE, train_responses);
388 cout << "Training the classifier (may take a few minutes)...\n";
// TC(max_iter,0) has eps == 0, so the criteria use the iteration limit only.
389 model = StatModel::train<ANN_MLP>(tdata, ANN_MLP::Params(layer_sizes, ANN_MLP::SIGMOID_SYM, 0, 0, TC(max_iter,0), method, method_param));
// rdelta is 'A': the prediction is shifted by 'A' before being compared with
// the stored character-code label.
393 test_and_save_classifier(model, data, responses, ntrain_samples, 'A', filename_to_save);
// Trains a K-nearest-neighbours classifier on the letter data and evaluates it.
//   data_filename - dataset file passed to read_num_class_data
//   K             - forwarded as the first argument of KNearest::Params
398 build_knearest_classifier( const string& data_filename, int K )
402 bool ok = read_num_class_data( data_filename, 16, &data, &responses );
408 int nsamples_all = data.rows;
// 80% of the rows (truncated to an integer) are used for training.
409 int ntrain_samples = (int)(nsamples_all*0.8);
411 // create classifier by using <data> and <responses>
412 cout << "Training the classifier ...\n";
413 Ptr<TrainData> tdata = prepare_train_data(data, responses, ntrain_samples);
414 model = StatModel::train<KNearest>(tdata, KNearest::Params(K, true));
// An empty file name is passed, so the model is not saved.
417 test_and_save_classifier(model, data, responses, ntrain_samples, 0, string());
// Trains a Normal Bayes classifier on the letter data and evaluates it.
//   data_filename - dataset file passed to read_num_class_data
422 build_nbayes_classifier( const string& data_filename )
426 bool ok = read_num_class_data( data_filename, 16, &data, &responses );
430 Ptr<NormalBayesClassifier> model;
432 int nsamples_all = data.rows;
// 80% of the rows (truncated to an integer) are used for training.
433 int ntrain_samples = (int)(nsamples_all*0.8);
435 // create classifier by using <data> and <responses>
436 cout << "Training the classifier ...\n";
437 Ptr<TrainData> tdata = prepare_train_data(data, responses, ntrain_samples);
// Default-constructed parameters are used.
438 model = StatModel::train<NormalBayesClassifier>(tdata, NormalBayesClassifier::Params());
// An empty file name is passed, so the model is not saved.
441 test_and_save_classifier(model, data, responses, ntrain_samples, 0, string());
// Trains (or loads) an SVM classifier on the letter data and evaluates it.
//   data_filename    - dataset file passed to read_num_class_data
//   filename_to_save - forwarded to test_and_save_classifier
//   filename_to_load - when not empty, the classifier is loaded from this file
446 build_svm_classifier( const string& data_filename,
447 const string& filename_to_save,
448 const string& filename_to_load )
452 bool ok = read_num_class_data( data_filename, 16, &data, &responses );
458 int nsamples_all = data.rows;
// 80% of the rows (truncated to an integer) are used for training.
459 int ntrain_samples = (int)(nsamples_all*0.8);
461 // Create or load SVM classifier
462 if( !filename_to_load.empty() )
464 model = load_classifier<SVM>(filename_to_load);
471 // create classifier by using <data> and <responses>
472 cout << "Training the classifier ...\n";
473 Ptr<TrainData> tdata = prepare_train_data(data, responses, ntrain_samples);
// C-support vector classification with a linear kernel.
476 params.svmType = SVM::C_SVC;
477 params.kernelType = SVM::LINEAR;
480 model = StatModel::train<SVM>(tdata, params);
484 test_and_save_classifier(model, data, responses, ntrain_samples, 0, filename_to_save);
// Entry point: parses the command-line flags, then runs one of the
// build_*_classifier functions with the collected file names.
488 int main( int argc, char *argv[] )
// Defaults: nothing to save or load; the dataset is read from the current directory.
490 string filename_to_save = "";
491 string filename_to_load = "";
492 string data_filename = "./letter-recognition.data";
496 for( i = 1; i < argc; i++ )
498 if( strcmp(argv[i],"-data") == 0 ) // flag "-data letter-recognition.data"
501 data_filename = argv[i];
503 else if( strcmp(argv[i],"-save") == 0 ) // flag "-save filename.xml"
506 filename_to_save = argv[i];
508 else if( strcmp(argv[i],"-load") == 0) // flag "-load filename.xml"
511 filename_to_load = argv[i];
513 else if( strcmp(argv[i],"-boost") == 0)
517 else if( strcmp(argv[i],"-mlp") == 0 )
// "-knn" is accepted as an alternative spelling of "-knearest".
521 else if( strcmp(argv[i], "-knearest") == 0 || strcmp(argv[i], "-knn") == 0 )
525 else if( strcmp(argv[i], "-nbayes") == 0)
529 else if( strcmp(argv[i], "-svm") == 0)
// Dispatch to a builder. The k-nearest and Normal Bayes builders take no
// save/load file names, and k-nearest is run with K = 10.
539 build_rtrees_classifier( data_filename, filename_to_save, filename_to_load ) :
541 build_boost_classifier( data_filename, filename_to_save, filename_to_load ) :
543 build_mlp_classifier( data_filename, filename_to_save, filename_to_load ) :
545 build_knearest_classifier( data_filename, 10 ) :
547 build_nbayes_classifier( data_filename) :
549 build_svm_classifier( data_filename, filename_to_save, filename_to_load ):