1 /*M///////////////////////////////////////////////////////////////////////////////////////
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
10 // Intel License Agreement
11 // For Open Source Computer Vision Library
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
19 // * Redistribution's of source code must retain the above copyright notice,
20 // this list of conditions and the following disclaimer.
22 // * Redistribution's in binary form must reproduce the above copyright notice,
23 // this list of conditions and the following disclaimer in the documentation
24 // and/or other materials provided with the distribution.
26 // * The name of Intel Corporation may not be used to endorse or promote products
27 // derived from this software without specific prior written permission.
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
42 #include "test_precomp.hpp"
47 // auxiliary functions
// Rejects datasets that the normal Bayes helpers below cannot handle:
// any missing values (CV_StsBadArg), or a variable-type layout that fails
// the type check (CV_StsBadArg, "incorrect types of predictors or responses").
// NOTE(review): the embedded line numbers skip 57, so the last operand of the
// second `if` condition is not visible in this listing; `is_classifier` is
// presumably used there -- confirm against the full source.
49 void nbayes_check_data( CvMLData* _data )
51 if( _data->get_missing() )
52 CV_Error( CV_StsBadArg, "missing values are not supported" );
53 const CvMat* var_types = _data->get_var_types();
// True when the last entry of var_types (presumably the response) is categorical.
54 bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
// The L1 norm (sum) of all type codes must equal
// (rows + cols - 2) * CV_VAR_ORDERED + CV_VAR_CATEGORICAL; assuming var_types
// is a single row/column vector, that presumably means every predictor is
// ordered and exactly one variable is categorical.
55 if( ( fabs( cvNorm( var_types, 0, CV_L1 ) -
56 (var_types->rows + var_types->cols - 2)*CV_VAR_ORDERED - CV_VAR_CATEGORICAL ) > FLT_EPSILON ) ||
58 CV_Error( CV_StsBadArg, "incorrect types of predictors or responses" );
60 bool nbayes_train( CvNormalBayesClassifier* nbayes, CvMLData* _data )
62 nbayes_check_data( _data );
63 const CvMat* values = _data->get_values();
64 const CvMat* responses = _data->get_responses();
65 const CvMat* train_sidx = _data->get_train_sample_idx();
66 const CvMat* var_idx = _data->get_var_idx();
67 return nbayes->train( values, responses, var_idx, train_sidx );
// Evaluates `nbayes` on the test split (type == CV_TEST_ERROR) or otherwise on
// the training split of `_data`. A prediction counts as wrong when it differs
// from the stored response by more than FLT_EPSILON. The final value of err is
// the wrong-prediction count divided by sample_count and scaled by 100 (a
// percentage), or -FLT_MAX when there are no samples to evaluate.
// When `resp` is non-null and sample_count > 0, *resp is resized to
// sample_count, presumably so it can receive the per-sample predictions.
// NOTE(review): the embedded line numbers skip 70-71, 83, 86-87, 89-90, 94-95,
// 97-98 and 100-103, so the declarations of err/pred_resp/sample, the
// accumulation of d into err and the return are not visible in this listing.
69 float nbayes_calc_error( CvNormalBayesClassifier* nbayes, CvMLData* _data, int type, vector<float> *resp )
72 nbayes_check_data( _data );
73 const CvMat* values = _data->get_values();
74 const CvMat* response = _data->get_responses();
75 const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
76 int* sidx = sample_idx ? sample_idx->data.i : 0;
// Distance (in elements) between consecutive responses: 1 for a continuous
// matrix, otherwise the row step (presumably responses form a single column).
77 int r_step = CV_IS_MAT_CONT(response->type) ?
78 1 : response->step / CV_ELEM_SIZE(response->type);
// Without an explicit training index, every row of `values` is evaluated for
// CV_TRAIN_ERROR; a missing test index leaves sample_count at 0.
79 int sample_count = sample_idx ? sample_idx->cols : 0;
80 sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
82 if( resp && (sample_count > 0) )
84 resp->resize( sample_count );
85 pred_resp = &((*resp)[0]);
88 for( int i = 0; i < sample_count; i++ )
91 int si = sidx ? sidx[i] : i;
92 cvGetRow( values, &sample, si );
93 float r = (float)nbayes->predict( &sample, 0 );
// 0 for a correct prediction, 1 for a mismatch.
96 int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;
99 err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
// Validates `_data` for the k-nearest helpers and makes `_predictors` a header
// over every column of the value matrix except the response column, which must
// be the first or the last column (CV_StsBadArg otherwise).
// Also rejects datasets with missing values, and datasets whose var_idx does
// not match the check below.
// NOTE(review): the embedded line numbers skip 113 and 117 (presumably
// `if( resp_idx == 0 )` and `else`), so the branch structure is partly hidden.
// NOTE(review): var_idx is dereferenced without a null check.
104 void knearest_check_data_and_get_predictors( CvMLData* _data, CvMat* _predictors )
106 const CvMat* values = _data->get_values();
107 const CvMat* var_idx = _data->get_var_idx();
// Assuming var_idx is a 1xN or Nx1 vector, cols + rows == N + 1, so this
// requires var_idx to list values->cols - 1 variables, i.e. all predictors.
108 if( var_idx->cols + var_idx->rows != values->cols )
109 CV_Error( CV_StsBadArg, "var_idx is not supported" );
110 if( _data->get_missing() )
111 CV_Error( CV_StsBadArg, "missing values are not supported" );
112 int resp_idx = _data->get_response_idx();
// Response in the first column: take columns [1, cols).
114 cvGetCols( values, _predictors, 1, values->cols );
// Response in the last column: take columns [0, cols - 1).
115 else if( resp_idx == values->cols - 1 )
116 cvGetCols( values, _predictors, 0, values->cols - 1 );
118 CV_Error( CV_StsBadArg, "responses must be in the first or last column; other cases are not supported" );
// Trains `knearest` on the predictor columns of `_data`, restricted to the
// training split. Regression mode is used when the response variable is
// CV_VAR_ORDERED; otherwise the model is trained as a classifier.
// NOTE(review): the embedded line numbers skip 125, where `predictors` is
// presumably declared as a CvMat header -- not visible in this listing.
120 bool knearest_train( CvKNearest* knearest, CvMLData* _data )
122 const CvMat* responses = _data->get_responses();
123 const CvMat* train_sidx = _data->get_train_sample_idx();
124 bool is_regression = _data->get_var_type( _data->get_response_idx() ) == CV_VAR_ORDERED;
126 knearest_check_data_and_get_predictors( _data, &predictors );
127 return knearest->train( &predictors, responses, train_sidx, is_regression );
// Evaluates `knearest` (using `k` neighbours) on the test split
// (type == CV_TEST_ERROR) or otherwise on the training split of `_data`.
// Classification: the error is a percentage of mismatched predictions.
// Regression: the error is a per-sample term built from d = prediction -
// response, averaged over sample_count (not scaled by 100).
// Either way the error is -FLT_MAX when there are no samples.
// When `resp` is non-null and sample_count > 0, *resp is resized to hold
// sample_count values (presumably the per-sample predictions).
// NOTE(review): the embedded line numbers skip 130-131, 138, 144, 147, 149,
// 151-152, 156-157, 159-160, 162-164, 166-167, 171-172, 174-175 and 177-181,
// so declarations, the accumulation into err (presumably d*d for regression)
// and the return are not visible in this listing.
// NOTE(review): get_error() in this file does not call this helper.
129 float knearest_calc_error( CvKNearest* knearest, CvMLData* _data, int k, int type, vector<float> *resp )
132 const CvMat* response = _data->get_responses();
133 const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
134 int* sidx = sample_idx ? sample_idx->data.i : 0;
// 1 for a continuous response matrix, otherwise its row step in elements.
135 int r_step = CV_IS_MAT_CONT(response->type) ?
136 1 : response->step / CV_ELEM_SIZE(response->type);
137 bool is_regression = _data->get_var_type( _data->get_response_idx() ) == CV_VAR_ORDERED;
139 knearest_check_data_and_get_predictors( _data, &predictors );
// Without an explicit training index, all predictor rows are used for CV_TRAIN_ERROR.
140 int sample_count = sample_idx ? sample_idx->cols : 0;
141 sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? predictors.rows : sample_count;
142 float* pred_resp = 0;
143 if( resp && (sample_count > 0) )
145 resp->resize( sample_count );
146 pred_resp = &((*resp)[0]);
// Classification branch: count predictions that differ by more than FLT_EPSILON.
148 if ( !is_regression )
150 for( int i = 0; i < sample_count; i++ )
153 int si = sidx ? sidx[i] : i;
154 cvGetRow( &predictors, &sample, si );
155 float r = knearest->find_nearest( &sample, k );
158 int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;
161 err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
// Regression branch: average a term derived from the signed residual d.
165 for( int i = 0; i < sample_count; i++ )
168 int si = sidx ? sidx[i] : i;
169 cvGetRow( &predictors, &sample, si );
170 float r = knearest->find_nearest( &sample, k );
173 float d = r - response->data.fl[si*r_step];
176 err = sample_count ? err / (float)sample_count : -FLT_MAX;
// Maps an SVM type name from the validation file ("C_SVC", "NU_SVC",
// "ONE_CLASS", "EPS_SVR", "NU_SVR") to the matching CvSVM constant.
// Any other string raises CV_StsBadArg.
// NOTE(review): the embedded line numbers skip 185 (presumably
// `return CvSVM::C_SVC;`) and 195-196 (the tail after CV_Error); neither is
// visible in this listing.
182 int str_to_svm_type(string& str)
184 if( !str.compare("C_SVC") )
186 if( !str.compare("NU_SVC") )
187 return CvSVM::NU_SVC;
188 if( !str.compare("ONE_CLASS") )
189 return CvSVM::ONE_CLASS;
190 if( !str.compare("EPS_SVR") )
191 return CvSVM::EPS_SVR;
192 if( !str.compare("NU_SVR") )
193 return CvSVM::NU_SVR;
194 CV_Error( CV_StsBadArg, "incorrect svm type string" );
// Maps an SVM kernel name from the validation file ("LINEAR", "POLY", "RBF",
// "SIGMOID") to the matching CvSVM kernel constant.
// Any other string raises CV_StsBadArg.
// NOTE(review): the embedded line numbers skip 202 and 204 (presumably
// `return CvSVM::POLY;` and `return CvSVM::RBF;`), which are not visible here.
// NOTE(review): the error text says "svm type" although this parses the
// kernel type; consider "incorrect svm kernel type string".
197 int str_to_svm_kernel_type( string& str )
199 if( !str.compare("LINEAR") )
200 return CvSVM::LINEAR;
201 if( !str.compare("POLY") )
203 if( !str.compare("RBF") )
205 if( !str.compare("SIGMOID") )
206 return CvSVM::SIGMOID;
207 CV_Error( CV_StsBadArg, "incorrect svm type string" );
// Rejects datasets that the SVM helpers cannot handle: any missing values, or
// any categorical entry among the first var_types->cols - 1 variable types
// (presumably every variable except the response). Both raise CV_StsBadArg.
// NOTE(review): the declaration of `msg` (line 218) is not visible, so the
// buffer size used by sprintf cannot be checked from this listing.
210 void svm_check_data( CvMLData* _data )
212 if( _data->get_missing() )
213 CV_Error( CV_StsBadArg, "missing values are not supported" );
214 const CvMat* var_types = _data->get_var_types();
215 for( int i = 0; i < var_types->cols-1; i++ )
216 if (var_types->data.ptr[i] == CV_VAR_CATEGORICAL)
219 sprintf( msg, "incorrect type of %d-predictor", i );
220 CV_Error( CV_StsBadArg, msg );
223 bool svm_train( CvSVM* svm, CvMLData* _data, CvSVMParams _params )
225 svm_check_data(_data);
226 const CvMat* _train_data = _data->get_values();
227 const CvMat* _responses = _data->get_responses();
228 const CvMat* _var_idx = _data->get_var_idx();
229 const CvMat* _sample_idx = _data->get_train_sample_idx();
230 return svm->train( _train_data, _responses, _var_idx, _sample_idx, _params );
232 bool svm_train_auto( CvSVM* svm, CvMLData* _data, CvSVMParams _params,
233 int k_fold, CvParamGrid C_grid, CvParamGrid gamma_grid,
234 CvParamGrid p_grid, CvParamGrid nu_grid, CvParamGrid coef_grid,
235 CvParamGrid degree_grid )
237 svm_check_data(_data);
238 const CvMat* _train_data = _data->get_values();
239 const CvMat* _responses = _data->get_responses();
240 const CvMat* _var_idx = _data->get_var_idx();
241 const CvMat* _sample_idx = _data->get_train_sample_idx();
242 return svm->train_auto( _train_data, _responses, _var_idx,
243 _sample_idx, _params, k_fold, C_grid, gamma_grid, p_grid, nu_grid, coef_grid, degree_grid );
// Evaluates `svm` on the test split (type == CV_TEST_ERROR) or otherwise on
// the training split of `_data`.
// When the last variable type is categorical (classifier), the error is a
// percentage of predictions that differ from the response by more than
// FLT_EPSILON. Otherwise (regression) it is a per-sample term built from the
// residual d, averaged over sample_count.
// Either way the error is -FLT_MAX when there are no samples.
// When `resp` is non-null and sample_count > 0, *resp is resized to
// sample_count (presumably to receive the predictions).
// NOTE(review): the embedded line numbers skip 246, 248, 261, 264-266,
// 268-269, 273-274, 276-277, 279-281, 283-284, 288-289, 291-292 and 294-299,
// so declarations, the branch on is_classifier, the accumulation into err and
// the return are not visible in this listing.
245 float svm_calc_error( CvSVM* svm, CvMLData* _data, int type, vector<float> *resp )
247 svm_check_data(_data);
249 const CvMat* values = _data->get_values();
250 const CvMat* response = _data->get_responses();
251 const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
252 const CvMat* var_types = _data->get_var_types();
253 int* sidx = sample_idx ? sample_idx->data.i : 0;
// 1 for a continuous response matrix, otherwise its row step in elements.
254 int r_step = CV_IS_MAT_CONT(response->type) ?
255 1 : response->step / CV_ELEM_SIZE(response->type);
256 bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
// Without an explicit training index, all rows are used for CV_TRAIN_ERROR.
257 int sample_count = sample_idx ? sample_idx->cols : 0;
258 sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
259 float* pred_resp = 0;
260 if( resp && (sample_count > 0) )
262 resp->resize( sample_count );
263 pred_resp = &((*resp)[0]);
// Classification loop: 0/1 mismatch indicator per sample.
267 for( int i = 0; i < sample_count; i++ )
270 int si = sidx ? sidx[i] : i;
271 cvGetRow( values, &sample, si );
272 float r = svm->predict( &sample );
275 int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;
278 err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
// Regression loop: signed residual per sample.
282 for( int i = 0; i < sample_count; i++ )
285 int si = sidx ? sidx[i] : i;
286 cvGetRow( values, &sample, si );
287 float r = svm->predict( &sample );
290 float d = r - response->data.fl[si*r_step];
293 err = sample_count ? err / (float)sample_count : -FLT_MAX;
// Maps an MLP training-method name from the validation file ("BACKPROP",
// "RPROP") to the matching CvANN_MLP_TrainParams constant.
// Any other string raises CV_StsBadArg.
// NOTE(review): lines 307-308 (the tail after CV_Error) are not in this listing.
300 int str_to_ann_train_method( string& str )
302 if( !str.compare("BACKPROP") )
303 return CvANN_MLP_TrainParams::BACKPROP;
304 if( !str.compare("RPROP") )
305 return CvANN_MLP_TrainParams::RPROP;
306 CV_Error( CV_StsBadArg, "incorrect ann train method string" );
// Validates `_data` for the MLP helpers and makes `_inputs` a header over every
// column of the value matrix except the output (response) column, which must be
// the first or the last column (CV_StsBadArg otherwise).
// Also rejects datasets with missing values, and datasets whose var_idx does
// not match the check below.
// NOTE(review): the embedded line numbers skip 318 and 322 (presumably
// `if( resp_idx == 0 )` and `else`), so the branch structure is partly hidden.
// NOTE(review): var_idx is dereferenced without a null check.
309 void ann_check_data_and_get_predictors( CvMLData* _data, CvMat* _inputs )
311 const CvMat* values = _data->get_values();
312 const CvMat* var_idx = _data->get_var_idx();
// Assuming var_idx is a vector, this requires it to list values->cols - 1 variables.
313 if( var_idx->cols + var_idx->rows != values->cols )
314 CV_Error( CV_StsBadArg, "var_idx is not supported" );
315 if( _data->get_missing() )
316 CV_Error( CV_StsBadArg, "missing values are not supported" );
317 int resp_idx = _data->get_response_idx();
// Output in the first column: inputs are columns [1, cols).
319 cvGetCols( values, _inputs, 1, values->cols );
// Output in the last column: inputs are columns [0, cols - 1).
320 else if( resp_idx == values->cols - 1 )
321 cvGetCols( values, _inputs, 0, values->cols - 1 );
323 CV_Error( CV_StsBadArg, "outputs must be in the first or last column; other cases are not supported" );
// Converts the (integer-valued) responses of the training samples into a
// one-hot target matrix for the MLP.
// First pass: each distinct rounded response label is assigned a class index
// (cls_count++) in first-seen order and stored in `cls_map`.
// Second pass: `new_responses` is created as responses->rows x cls_count
// CV_32F, zero-filled, and row sidx gets a 1 in the column of its class.
// Only training rows are set; other rows stay all-zero.
// NOTE(review): the embedded line numbers skip 333-335 and 342, where
// cls_count is presumably declared/initialized and `cls_map[r]` presumably
// inserts the label (the size comparison on L343 relies on that) -- confirm.
// NOTE(review): unlike the *_calc_error helpers, train_sidx is dereferenced
// without a null check.
325 void ann_get_new_responses( CvMLData* _data, Mat& new_responses, map<int, int>& cls_map )
327 const CvMat* train_sidx = _data->get_train_sample_idx();
328 int* train_sidx_ptr = train_sidx->data.i;
329 const CvMat* responses = _data->get_responses();
330 float* responses_ptr = responses->data.fl;
// 1 for a continuous response matrix, otherwise its row step in elements.
331 int r_step = CV_IS_MAT_CONT(responses->type) ?
332 1 : responses->step / CV_ELEM_SIZE(responses->type);
336 for( int si = 0; si < train_sidx->cols; si++ )
338 int sidx = train_sidx_ptr[si];
339 int r = cvRound(responses_ptr[sidx*r_step]);
// Responses are expected to be integral class labels (debug-only check).
340 CV_DbgAssert( fabs(responses_ptr[sidx*r_step]-r) < FLT_EPSILON );
341 int cls_map_size = (int)cls_map.size();
// The map grew, so r is a new label: give it the next class index.
343 if ( (int)cls_map.size() > cls_map_size )
344 cls_map[r] = cls_count++;
346 new_responses.create( responses->rows, cls_count, CV_32F );
347 new_responses.setTo( 0 );
348 for( int si = 0; si < train_sidx->cols; si++ )
350 int sidx = train_sidx_ptr[si];
351 int r = cvRound(responses_ptr[sidx*r_step]);
352 int cidx = cls_map[r];
353 new_responses.ptr<float>(sidx)[cidx] = 1;
// Trains `ann` on the predictor columns of `_data` (restricted to the training
// split) against the one-hot targets in `new_responses`, with no sample
// weights. Returns CvANN_MLP::train's int result unchanged; the caller in
// CV_MLBaseTest::train treats a value >= 0 as success.
// NOTE(review): the declaration of `predictors` (line 359) is not visible in
// this listing.
356 int ann_train( CvANN_MLP* ann, CvMLData* _data, Mat& new_responses, CvANN_MLP_TrainParams _params, int flags = 0 )
358 const CvMat* train_sidx = _data->get_train_sample_idx();
360 ann_check_data_and_get_predictors( _data, &predictors );
// CvMat header over the cv::Mat data (C API expects CvMat*).
361 CvMat _new_responses = CvMat( new_responses );
362 return ann->train( &predictors, &_new_responses, 0, train_sidx, _params, flags );
// Evaluates `ann` on the test split (type == CV_TEST_ERROR) or otherwise on the
// training split of `_data`. The predicted class of each sample is the column
// of the maximum network output, and it is compared with the sample's rounded
// response. The error is the percentage of mismatches, or -FLT_MAX when there
// are no samples. Predicted class indices are written either into
// *resp_labels or into a local scratch vector (the selecting condition is on
// lines not shown).
// NOTE(review): the embedded line numbers skip 365-366, 372, 379-381, 384-386,
// 389-390, 395-396, 404, 406 and 408; line 404 presumably maps r through
// cls_map (otherwise the comparison on L254 mixes class indices with raw labels),
// and the err declaration/accumulation/return are not visible -- confirm.
364 float ann_calc_error( CvANN_MLP* ann, CvMLData* _data, map<int, int>& cls_map, int type , vector<float> *resp_labels )
367 const CvMat* responses = _data->get_responses();
368 const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
369 int* sidx = sample_idx ? sample_idx->data.i : 0;
// 1 for a continuous response matrix, otherwise its row step in elements.
370 int r_step = CV_IS_MAT_CONT(responses->type) ?
371 1 : responses->step / CV_ELEM_SIZE(responses->type);
373 ann_check_data_and_get_predictors( _data, &predictors );
// Without an explicit training index, all predictor rows are used for CV_TRAIN_ERROR.
374 int sample_count = sample_idx ? sample_idx->cols : 0;
375 sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? predictors.rows : sample_count;
376 float* pred_resp = 0;
377 vector<float> innresp;
378 if( sample_count > 0 )
382 resp_labels->resize( sample_count );
383 pred_resp = &((*resp_labels)[0]);
387 innresp.resize( sample_count );
388 pred_resp = &(innresp[0]);
// One network output per class seen in training.
391 int cls_count = (int)cls_map.size();
392 Mat output( 1, cls_count, CV_32FC1 );
393 CvMat _output = CvMat(output);
394 for( int i = 0; i < sample_count; i++ )
397 int si = sidx ? sidx[i] : i;
398 cvGetRow( &predictors, &sample, si );
399 ann->predict( &sample, &_output );
// best_cls.x receives the column index of the largest output.
400 CvPoint best_cls = {0,0};
401 cvMinMaxLoc( &_output, 0, 0, 0, &best_cls, 0 );
402 int r = cvRound(responses->data.fl[si*r_step]);
403 CV_DbgAssert( fabs(responses->data.fl[si*r_step]-r) < FLT_EPSILON );
405 int d = best_cls.x == r ? 0 : 1;
407 pred_resp[i] = (float)best_cls.x;
409 err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
// Maps a boosting type name from the validation file ("DISCRETE", "REAL",
// "LOGIT", "GENTLE") to the matching CvBoost constant.
// Any other string raises CV_StsBadArg.
// NOTE(review): lines 426-427 (the tail after CV_Error) are not in this listing.
415 int str_to_boost_type( string& str )
417 if ( !str.compare("DISCRETE") )
418 return CvBoost::DISCRETE;
419 if ( !str.compare("REAL") )
420 return CvBoost::REAL;
421 if ( !str.compare("LOGIT") )
422 return CvBoost::LOGIT;
423 if ( !str.compare("GENTLE") )
424 return CvBoost::GENTLE;
425 CV_Error( CV_StsBadArg, "incorrect boost type string" );
432 // ---------------------------------- MLBaseTest ---------------------------------------------------
// Sets up a test for the model named `_modelName`. The current RNG state is
// saved in initSeed (the destructor restores theRNG() from it), and the RNG is
// then reseeded with one of five fixed seeds. Finally the model object whose
// name matches modelName is allocated.
// NOTE(review): the embedded line numbers skip 441-442, 444-445 (presumably
// the declaration of `rng`, likely a reference to theRNG()), 451-459 and the
// allocation lines 465, 467, 469, 471 and 473 for svm/em/ann/dtree/boost;
// those are not visible in this listing.
434 CV_MLBaseTest::CV_MLBaseTest(const char* _modelName)
436 int64 seeds[] = { CV_BIG_INT(0x00009fff4f9c8d52),
437 CV_BIG_INT(0x0000a17166072c7c),
438 CV_BIG_INT(0x0201b32115cd1f9a),
439 CV_BIG_INT(0x0513cb37abcd1234),
440 CV_BIG_INT(0x0001a2b3c4d5f678)
443 int seedCount = sizeof(seeds)/sizeof(seeds[0]);
446 initSeed = rng.state;
// Pick a seed index in [0, seedCount) from the current RNG and switch to it.
448 rng.state = seeds[rng(seedCount)];
450 modelName = _modelName;
// Allocate only the model under test; other model pointers are presumably
// left null by the lines not shown.
460 if( !modelName.compare(CV_NBAYES) )
461 nbayes = new CvNormalBayesClassifier;
462 else if( !modelName.compare(CV_KNEAREST) )
463 knearest = new CvKNearest;
464 else if( !modelName.compare(CV_SVM) )
466 else if( !modelName.compare(CV_EM) )
468 else if( !modelName.compare(CV_ANN) )
470 else if( !modelName.compare(CV_DTREE) )
472 else if( !modelName.compare(CV_BOOST) )
474 else if( !modelName.compare(CV_RTREES) )
475 rtrees = new CvRTrees;
476 else if( !modelName.compare(CV_ERTREES) )
477 ertrees = new CvERTrees;
// Releases the validation file if it is open and restores theRNG() to the
// state saved in initSeed by the constructor.
// NOTE(review): the embedded line numbers skip 484-501, which presumably free
// the model objects -- not visible in this listing.
480 CV_MLBaseTest::~CV_MLBaseTest()
482 if( validationFS.isOpened() )
483 validationFS.release();
502 theRNG().state = initSeed;
// Reads the list of dataset names for this model from the validation file.
// The node path is: first top-level node -> "run_params" -> <modelName>.
// Sets test_case_count to the length of that sequence (-1 if it is not a
// sequence) and fills dataSetNames with its string entries. Returns
// cvtest::TS::OK on the visible path.
// NOTE(review): the embedded line numbers skip 506-507 and 509-510 (presumably
// a null check on _fs and an early return), so that guard is not visible.
// NOTE(review): cvGetFileNodeByName(...) on L300 is dereferenced without a null
// check, so a model missing from "run_params" would crash instead of yielding
// test_case_count = -1.
505 int CV_MLBaseTest::read_params( CvFileStorage* _fs )
508 test_case_count = -1;
511 CvFileNode* fn = cvGetRootFileNode( _fs, 0 );
512 fn = (CvFileNode*)cvGetSeqElem( fn->data.seq, 0 );
513 fn = cvGetFileNodeByName( _fs, fn, "run_params" );
514 CvSeq* dataSetNamesSeq = cvGetFileNodeByName( _fs, fn, modelName.c_str() )->data.seq;
515 test_case_count = dataSetNamesSeq ? dataSetNamesSeq->total : -1;
516 if( test_case_count > 0 )
518 dataSetNames.resize( test_case_count );
519 vector<string>::iterator it = dataSetNames.begin();
520 for( int i = 0; i < test_case_count; i++, it++ )
521 *it = ((CvFileNode*)cvGetSeqElem( dataSetNamesSeq, i ))->data.str.ptr;
524 return cvtest::TS::OK;;
// Test entry point. Opens the validation file (data path + validation file
// name) and reads the dataset list. For each test case it runs
// run_test_case(i), and if that succeeds, validate_test_results(i). The final
// code is reported through ts->set_failed_test_info. No test cases at all is
// reported as FAIL_INVALID_TEST_DATA.
// NOTE(review): `start_from` is not used in the visible lines, and the return
// value of read_params is ignored.
// NOTE(review): the embedded line numbers skip 542-543, presumably where a
// failing temp_code is stored into `code` -- not visible in this listing.
527 void CV_MLBaseTest::run( int start_from )
529 string filename = ts->get_data_path();
530 filename += get_validation_filename();
531 validationFS.open( filename, FileStorage::READ );
532 read_params( *validationFS );
534 int code = cvtest::TS::OK;
536 for (int i = 0; i < test_case_count; i++)
538 int temp_code = run_test_case( i );
539 if (temp_code == cvtest::TS::OK)
540 temp_code = validate_test_results( i );
541 if (temp_code != cvtest::TS::OK)
544 if ( test_case_count <= 0)
546 ts->printf( cvtest::TS::LOG, "validation file is not determined or not correct" );
547 code = cvtest::TS::FAIL_INVALID_TEST_DATA;
549 ts->set_failed_test_info( code );
// Loads dataset `test_case_idx` from "<data path><name>.data" into `data` and
// configures it from the validation file's
// validation/<modelName>/<name>/data_params node:
// "LS" -> train/test split (presumably the number of training samples),
// "resp_idx" -> response column, "types" -> variable-type string.
// Returns FAIL_INVALID_TEST_DATA when the data path is empty or the CSV cannot
// be read, otherwise cvtest::TS::OK.
// NOTE(review): L333 passes runtime-built text as the printf format, so a '%'
// in the path would be parsed as a conversion; prefer
// ts->printf( cvtest::TS::LOG, "file %s can not be read", filename.c_str() ).
// The size of `msg` (line 569) is not visible, so the sprintf on L332 may
// overflow for long paths.
// NOTE(review): the declaration of varTypes (presumably a string) and lines
// 555-557 are not in this listing; the CV_DbgAssert checks are presumably
// compiled out in release builds.
552 int CV_MLBaseTest::prepare_test_case( int test_case_idx )
554 int trainSampleCount, respIdx;
558 string dataPath = ts->get_data_path();
559 if ( dataPath.empty() )
561 ts->printf( cvtest::TS::LOG, "data path is empty" );
562 return cvtest::TS::FAIL_INVALID_TEST_DATA;
565 string dataName = dataSetNames[test_case_idx],
566 filename = dataPath + dataName + ".data";
567 if ( data.read_csv( filename.c_str() ) != 0)
570 sprintf( msg, "file %s can not be read", filename.c_str() );
571 ts->printf( cvtest::TS::LOG, msg );
572 return cvtest::TS::FAIL_INVALID_TEST_DATA;
575 FileNode dataParamsNode = validationFS.getFirstTopLevelNode()["validation"][modelName][dataName]["data_params"];
576 CV_DbgAssert( !dataParamsNode.empty() );
578 CV_DbgAssert( !dataParamsNode["LS"].empty() );
579 dataParamsNode["LS"] >> trainSampleCount;
580 CvTrainTestSplit spl( trainSampleCount );
581 data.set_train_test_split( &spl );
583 CV_DbgAssert( !dataParamsNode["resp_idx"].empty() );
584 dataParamsNode["resp_idx"] >> respIdx;
585 data.set_response_idx( respIdx );
587 CV_DbgAssert( !dataParamsNode["types"].empty() );
588 dataParamsNode["types"] >> varTypes;
589 data.set_var_types( varTypes.c_str() );
591 return cvtest::TS::OK;
// Returns (by reference) the validation file name that run() appends to the
// test data path.
// NOTE(review): the body (lines 595-598) is not present in this listing.
594 string& CV_MLBaseTest::get_validation_filename()
// Trains the model named by modelName on the current dataset. Hyperparameters
// are read from validation/<modelName>/<dataset>/model_params in the
// validation file. Returns FAIL_INVALID_OUTPUT (after logging) when training
// reports failure, otherwise cvtest::TS::OK.
// kNN training is commented out below, and the EM branch body (lines 629-631)
// is not in this listing.
// NOTE(review): the embedded line numbers skip, among others, 617 (presumably
// `CvSVMParams params;`), 639 (`new_responses`?), 642 (`CvMat layer_sizes =`),
// 667-668 (typeStr / USE_SURROGATE declarations) and 711-714 (presumably
// `if( !is_trained )`), which are not visible here.
599 int CV_MLBaseTest::train( int testCaseIdx )
601 bool is_trained = false;
602 FileNode modelParamsNode =
603 validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"];
605 if( !modelName.compare(CV_NBAYES) )
606 is_trained = nbayes_train( nbayes, &data );
607 else if( !modelName.compare(CV_KNEAREST) )
610 //is_trained = knearest->train( &data );
// SVM: type/kernel come as strings; numeric parameters are read directly.
612 else if( !modelName.compare(CV_SVM) )
614 string svm_type_str, kernel_type_str;
615 modelParamsNode["svm_type"] >> svm_type_str;
616 modelParamsNode["kernel_type"] >> kernel_type_str;
618 params.svm_type = str_to_svm_type( svm_type_str );
619 params.kernel_type = str_to_svm_kernel_type( kernel_type_str );
620 modelParamsNode["degree"] >> params.degree;
621 modelParamsNode["gamma"] >> params.gamma;
622 modelParamsNode["coef0"] >> params.coef0;
623 modelParamsNode["C"] >> params.C;
624 modelParamsNode["nu"] >> params.nu;
625 modelParamsNode["p"] >> params.p;
626 is_trained = svm_train( svm, &data, params );
628 else if( !modelName.compare(CV_EM) )
// MLP: one-hot targets are built first. The topology is
// inputs (all columns but the response) -> 100 -> 100 -> one output per class,
// and training stops after 300 iterations.
632 else if( !modelName.compare(CV_ANN) )
634 string train_method_str;
635 double param1, param2;
636 modelParamsNode["train_method"] >> train_method_str;
637 modelParamsNode["param1"] >> param1;
638 modelParamsNode["param2"] >> param2;
640 ann_get_new_responses( &data, new_responses, cls_map );
641 int layer_sz[] = { data.get_values()->cols - 1, 100, 100, (int)cls_map.size() };
643 cvMat( 1, (int)(sizeof(layer_sz)/sizeof(layer_sz[0])), CV_32S, layer_sz );
644 ann->create( &layer_sizes );
645 is_trained = ann_train( ann, &data, new_responses, CvANN_MLP_TrainParams(cvTermCriteria(CV_TERMCRIT_ITER,300,0.01),
646 str_to_ann_train_method(train_method_str), param1, param2) ) >= 0;
// Decision tree: regression accuracy fixed at 0, priors passed as 0.
648 else if( !modelName.compare(CV_DTREE) )
650 int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS;
651 float REG_ACCURACY = 0;
652 bool USE_SURROGATE, IS_PRUNED;
653 modelParamsNode["max_depth"] >> MAX_DEPTH;
654 modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
655 modelParamsNode["use_surrogate"] >> USE_SURROGATE;
656 modelParamsNode["max_categories"] >> MAX_CATEGORIES;
657 modelParamsNode["cv_folds"] >> CV_FOLDS;
658 modelParamsNode["is_pruned"] >> IS_PRUNED;
659 is_trained = dtree->train( &data,
660 CvDTreeParams(MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY, USE_SURROGATE,
661 MAX_CATEGORIES, CV_FOLDS, false, IS_PRUNED, 0 )) != 0;
// Boosting: the boost type is read as a string and mapped by str_to_boost_type.
663 else if( !modelName.compare(CV_BOOST) )
665 int BOOST_TYPE, WEAK_COUNT, MAX_DEPTH;
666 float WEIGHT_TRIM_RATE;
669 modelParamsNode["type"] >> typeStr;
670 BOOST_TYPE = str_to_boost_type( typeStr );
671 modelParamsNode["weak_count"] >> WEAK_COUNT;
672 modelParamsNode["weight_trim_rate"] >> WEIGHT_TRIM_RATE;
673 modelParamsNode["max_depth"] >> MAX_DEPTH;
674 modelParamsNode["use_surrogate"] >> USE_SURROGATE;
675 is_trained = boost->train( &data,
676 CvBoostParams(BOOST_TYPE, WEAK_COUNT, WEIGHT_TRIM_RATE, MAX_DEPTH, USE_SURROGATE, 0) ) != 0;
678 else if( !modelName.compare(CV_RTREES) )
680 int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS, NACTIVE_VARS, MAX_TREES_NUM;
681 float REG_ACCURACY = 0, OOB_EPS = 0.0;
682 bool USE_SURROGATE, IS_PRUNED;
683 modelParamsNode["max_depth"] >> MAX_DEPTH;
684 modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
685 modelParamsNode["use_surrogate"] >> USE_SURROGATE;
686 modelParamsNode["max_categories"] >> MAX_CATEGORIES;
687 modelParamsNode["cv_folds"] >> CV_FOLDS;
688 modelParamsNode["is_pruned"] >> IS_PRUNED;
689 modelParamsNode["nactive_vars"] >> NACTIVE_VARS;
690 modelParamsNode["max_trees_num"] >> MAX_TREES_NUM;
// NOTE(review): CV_FOLDS and IS_PRUNED are read above but not passed to
// CvRTParams. Variable importance is enabled for random trees.
691 is_trained = rtrees->train( &data, CvRTParams( MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY,
692 USE_SURROGATE, MAX_CATEGORIES, 0, true, // (calc_var_importance == true) <=> RF processes variable importance
693 NACTIVE_VARS, MAX_TREES_NUM, OOB_EPS, CV_TERMCRIT_ITER)) != 0;
695 else if( !modelName.compare(CV_ERTREES) )
697 int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS, NACTIVE_VARS, MAX_TREES_NUM;
698 float REG_ACCURACY = 0, OOB_EPS = 0.0;
699 bool USE_SURROGATE, IS_PRUNED;
700 modelParamsNode["max_depth"] >> MAX_DEPTH;
701 modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
702 modelParamsNode["use_surrogate"] >> USE_SURROGATE;
703 modelParamsNode["max_categories"] >> MAX_CATEGORIES;
704 modelParamsNode["cv_folds"] >> CV_FOLDS;
705 modelParamsNode["is_pruned"] >> IS_PRUNED;
706 modelParamsNode["nactive_vars"] >> NACTIVE_VARS;
707 modelParamsNode["max_trees_num"] >> MAX_TREES_NUM;
// NOTE(review): CV_FOLDS and IS_PRUNED are read above but not passed to
// CvRTParams. Unlike the random-trees branch, variable importance is disabled here.
708 is_trained = ertrees->train( &data, CvRTParams( MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY,
709 USE_SURROGATE, MAX_CATEGORIES, 0, false, // (calc_var_importance == true) <=> RF processes variable importance
710 NACTIVE_VARS, MAX_TREES_NUM, OOB_EPS, CV_TERMCRIT_ITER)) != 0;
// Failure path (presumably guarded by `if( !is_trained )` on a line not shown).
715 ts->printf( cvtest::TS::LOG, "in test case %d model training was failed", testCaseIdx );
716 return cvtest::TS::FAIL_INVALID_OUTPUT;
718 return cvtest::TS::OK;
// Computes the error of the trained model named by modelName on the test or
// training split (`type`), optionally filling `resp` with per-sample outputs.
// NBAYES, SVM and ANN use the helpers defined in this file. DTREE, BOOST,
// RTREES and ERTREES use each model's own calc_error.
// NOTE(review): the kNN branch ends in `*/` on L445, so its code appears to be
// commented out and the knearest_calc_error helper is never called; the EM
// branch body (line 737), the err declaration and the return are not visible
// in this listing.
721 float CV_MLBaseTest::get_error( int testCaseIdx, int type, vector<float> *resp )
724 if( !modelName.compare(CV_NBAYES) )
725 err = nbayes_calc_error( nbayes, &data, type, resp );
726 else if( !modelName.compare(CV_KNEAREST) )
731 validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"]["k"] >> k;
732 err = knearest->calc_error( &data, k, type, resp );*/
734 else if( !modelName.compare(CV_SVM) )
735 err = svm_calc_error( svm, &data, type, resp );
736 else if( !modelName.compare(CV_EM) )
738 else if( !modelName.compare(CV_ANN) )
739 err = ann_calc_error( ann, &data, cls_map, type, resp );
740 else if( !modelName.compare(CV_DTREE) )
741 err = dtree->calc_error( &data, type, resp );
742 else if( !modelName.compare(CV_BOOST) )
743 err = boost->calc_error( &data, type, resp );
744 else if( !modelName.compare(CV_RTREES) )
745 err = rtrees->calc_error( &data, type, resp );
746 else if( !modelName.compare(CV_ERTREES) )
747 err = ertrees->calc_error( &data, type, resp );
751 void CV_MLBaseTest::save( const char* filename )
753 if( !modelName.compare(CV_NBAYES) )
754 nbayes->save( filename );
755 else if( !modelName.compare(CV_KNEAREST) )
756 knearest->save( filename );
757 else if( !modelName.compare(CV_SVM) )
758 svm->save( filename );
759 else if( !modelName.compare(CV_EM) )
760 em->save( filename );
761 else if( !modelName.compare(CV_ANN) )
762 ann->save( filename );
763 else if( !modelName.compare(CV_DTREE) )
764 dtree->save( filename );
765 else if( !modelName.compare(CV_BOOST) )
766 boost->save( filename );
767 else if( !modelName.compare(CV_RTREES) )
768 rtrees->save( filename );
769 else if( !modelName.compare(CV_ERTREES) )
770 ertrees->save( filename );
// Loads the model selected by `modelName` from `filename` using that model's
// own load() method. A name that matches none of the known models is ignored.
// NOTE(review): the function's closing brace lies past the end of this
// listing.
773 void CV_MLBaseTest::load( const char* filename )
775 if( !modelName.compare(CV_NBAYES) )
776 nbayes->load( filename );
777 else if( !modelName.compare(CV_KNEAREST) )
778 knearest->load( filename );
779 else if( !modelName.compare(CV_SVM) )
780 svm->load( filename );
781 else if( !modelName.compare(CV_EM) )
782 em->load( filename );
783 else if( !modelName.compare(CV_ANN) )
784 ann->load( filename );
785 else if( !modelName.compare(CV_DTREE) )
786 dtree->load( filename );
787 else if( !modelName.compare(CV_BOOST) )
788 boost->load( filename );
789 else if( !modelName.compare(CV_RTREES) )
790 rtrees->load( filename );
791 else if( !modelName.compare(CV_ERTREES) )
792 ertrees->load( filename );