fixed many warnings from GCC 4.6.1
[profile/ivi/opencv.git] / modules / ml / test / test_mltests2.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 //   * Redistribution's of source code must retain the above copyright notice,
20 //     this list of conditions and the following disclaimer.
21 //
22 //   * Redistribution's in binary form must reproduce the above copyright notice,
23 //     this list of conditions and the following disclaimer in the documentation
24 //     and/or other materials provided with the distribution.
25 //
26 //   * The name of Intel Corporation may not be used to endorse or promote products
27 //     derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
39 //
40 //M*/
41
42 #include "test_precomp.hpp"
43
44 using namespace cv;
45 using namespace std;
46
47 // auxiliary functions
48 // 1. nbayes
49 void nbayes_check_data( CvMLData* _data )
50 {
51     if( _data->get_missing() )
52         CV_Error( CV_StsBadArg, "missing values are not supported" );
53     const CvMat* var_types = _data->get_var_types();
54     bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
55     if( ( fabs( cvNorm( var_types, 0, CV_L1 ) - 
56         (var_types->rows + var_types->cols - 2)*CV_VAR_ORDERED - CV_VAR_CATEGORICAL ) > FLT_EPSILON ) ||
57         !is_classifier )
58         CV_Error( CV_StsBadArg, "incorrect types of predictors or responses" );
59 }
60 bool nbayes_train( CvNormalBayesClassifier* nbayes, CvMLData* _data )
61 {
62     nbayes_check_data( _data );
63     const CvMat* values = _data->get_values();
64     const CvMat* responses = _data->get_responses();
65     const CvMat* train_sidx = _data->get_train_sample_idx();
66     const CvMat* var_idx = _data->get_var_idx();
67     return nbayes->train( values, responses, var_idx, train_sidx );
68 }
69 float nbayes_calc_error( CvNormalBayesClassifier* nbayes, CvMLData* _data, int type, vector<float> *resp )
70 {
71     float err = 0;
72     nbayes_check_data( _data );
73     const CvMat* values = _data->get_values();
74     const CvMat* response = _data->get_responses();
75     const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
76     int* sidx = sample_idx ? sample_idx->data.i : 0;
77     int r_step = CV_IS_MAT_CONT(response->type) ?
78         1 : response->step / CV_ELEM_SIZE(response->type);
79     int sample_count = sample_idx ? sample_idx->cols : 0;
80     sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
81     float* pred_resp = 0;
82     if( resp && (sample_count > 0) )
83     {
84         resp->resize( sample_count );
85         pred_resp = &((*resp)[0]);
86     }
87
88     for( int i = 0; i < sample_count; i++ )
89     {
90         CvMat sample;
91         int si = sidx ? sidx[i] : i;
92         cvGetRow( values, &sample, si ); 
93         float r = (float)nbayes->predict( &sample, 0 );
94         if( pred_resp )
95             pred_resp[i] = r;
96         int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;
97         err += d;
98     }
99     err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
100     return err;
101 }
102
103 // 2. knearest
104 void knearest_check_data_and_get_predictors( CvMLData* _data, CvMat* _predictors )
105 {
106     const CvMat* values = _data->get_values();
107     const CvMat* var_idx = _data->get_var_idx();
108     if( var_idx->cols + var_idx->rows != values->cols )
109         CV_Error( CV_StsBadArg, "var_idx is not supported" );
110     if( _data->get_missing() )
111         CV_Error( CV_StsBadArg, "missing values are not supported" );
112     int resp_idx = _data->get_response_idx();
113     if( resp_idx == 0)
114         cvGetCols( values, _predictors, 1, values->cols );
115     else if( resp_idx == values->cols - 1 )
116         cvGetCols( values, _predictors, 0, values->cols - 1 );
117     else
118         CV_Error( CV_StsBadArg, "responses must be in the first or last column; other cases are not supported" );
119 }
120 bool knearest_train( CvKNearest* knearest, CvMLData* _data )
121 {
122     const CvMat* responses = _data->get_responses();
123     const CvMat* train_sidx = _data->get_train_sample_idx();
124     bool is_regression = _data->get_var_type( _data->get_response_idx() ) == CV_VAR_ORDERED;
125     CvMat predictors;
126     knearest_check_data_and_get_predictors( _data, &predictors );
127     return knearest->train( &predictors, responses, train_sidx, is_regression );
128 }
129 float knearest_calc_error( CvKNearest* knearest, CvMLData* _data, int k, int type, vector<float> *resp )
130 {
131     float err = 0;
132     const CvMat* response = _data->get_responses();
133     const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
134     int* sidx = sample_idx ? sample_idx->data.i : 0;
135     int r_step = CV_IS_MAT_CONT(response->type) ?
136         1 : response->step / CV_ELEM_SIZE(response->type);
137     bool is_regression = _data->get_var_type( _data->get_response_idx() ) == CV_VAR_ORDERED;
138     CvMat predictors;
139     knearest_check_data_and_get_predictors( _data, &predictors );
140     int sample_count = sample_idx ? sample_idx->cols : 0;
141     sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? predictors.rows : sample_count;
142     float* pred_resp = 0;
143     if( resp && (sample_count > 0) )
144     {
145         resp->resize( sample_count );
146         pred_resp = &((*resp)[0]);
147     }
148     if ( !is_regression )
149     {
150         for( int i = 0; i < sample_count; i++ )
151         {
152             CvMat sample;
153             int si = sidx ? sidx[i] : i;
154             cvGetRow( &predictors, &sample, si ); 
155             float r = knearest->find_nearest( &sample, k );
156             if( pred_resp )
157                 pred_resp[i] = r;
158             int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;
159             err += d;
160         }
161         err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
162     }
163     else
164     {
165         for( int i = 0; i < sample_count; i++ )
166         {
167             CvMat sample;
168             int si = sidx ? sidx[i] : i;
169             cvGetRow( &predictors, &sample, si ); 
170             float r = knearest->find_nearest( &sample, k );
171             if( pred_resp )
172                 pred_resp[i] = r;
173             float d = r - response->data.fl[si*r_step];
174             err += d*d;
175         }
176         err = sample_count ? err / (float)sample_count : -FLT_MAX;    
177     }
178     return err;
179 }
180
181 // 3. svm
182 int str_to_svm_type(string& str)
183 {
184     if( !str.compare("C_SVC") )
185         return CvSVM::C_SVC;
186     if( !str.compare("NU_SVC") )
187         return CvSVM::NU_SVC;
188     if( !str.compare("ONE_CLASS") )
189         return CvSVM::ONE_CLASS;
190     if( !str.compare("EPS_SVR") )
191         return CvSVM::EPS_SVR;
192     if( !str.compare("NU_SVR") )
193         return CvSVM::NU_SVR;
194     CV_Error( CV_StsBadArg, "incorrect svm type string" );
195     return -1;
196 }
197 int str_to_svm_kernel_type( string& str )
198 {
199     if( !str.compare("LINEAR") )
200         return CvSVM::LINEAR;
201     if( !str.compare("POLY") )
202         return CvSVM::POLY;
203     if( !str.compare("RBF") )
204         return CvSVM::RBF;
205     if( !str.compare("SIGMOID") )
206         return CvSVM::SIGMOID;
207     CV_Error( CV_StsBadArg, "incorrect svm type string" );
208     return -1;
209 }
210 void svm_check_data( CvMLData* _data )
211 {
212     if( _data->get_missing() )
213         CV_Error( CV_StsBadArg, "missing values are not supported" );
214     const CvMat* var_types = _data->get_var_types();
215     for( int i = 0; i < var_types->cols-1; i++ )
216         if (var_types->data.ptr[i] == CV_VAR_CATEGORICAL)
217         {
218             char msg[50];
219             sprintf( msg, "incorrect type of %d-predictor", i );
220             CV_Error( CV_StsBadArg, msg );
221         }
222 }
223 bool svm_train( CvSVM* svm, CvMLData* _data, CvSVMParams _params )
224 {
225     svm_check_data(_data);
226     const CvMat* _train_data = _data->get_values();
227     const CvMat* _responses = _data->get_responses();
228     const CvMat* _var_idx = _data->get_var_idx();
229     const CvMat* _sample_idx = _data->get_train_sample_idx();
230     return svm->train( _train_data, _responses, _var_idx, _sample_idx, _params );
231 }
232 bool svm_train_auto( CvSVM* svm, CvMLData* _data, CvSVMParams _params,
233                     int k_fold, CvParamGrid C_grid, CvParamGrid gamma_grid,
234                     CvParamGrid p_grid, CvParamGrid nu_grid, CvParamGrid coef_grid,
235                     CvParamGrid degree_grid )
236 {
237     svm_check_data(_data);
238     const CvMat* _train_data = _data->get_values();
239     const CvMat* _responses = _data->get_responses();
240     const CvMat* _var_idx = _data->get_var_idx();
241     const CvMat* _sample_idx = _data->get_train_sample_idx();
242     return svm->train_auto( _train_data, _responses, _var_idx, 
243         _sample_idx, _params, k_fold, C_grid, gamma_grid, p_grid, nu_grid, coef_grid, degree_grid );
244 }
245 float svm_calc_error( CvSVM* svm, CvMLData* _data, int type, vector<float> *resp )
246 {
247     svm_check_data(_data);
248     float err = 0;
249     const CvMat* values = _data->get_values();
250     const CvMat* response = _data->get_responses();
251     const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
252     const CvMat* var_types = _data->get_var_types();
253     int* sidx = sample_idx ? sample_idx->data.i : 0;
254     int r_step = CV_IS_MAT_CONT(response->type) ?
255         1 : response->step / CV_ELEM_SIZE(response->type);
256     bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
257     int sample_count = sample_idx ? sample_idx->cols : 0;
258     sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
259     float* pred_resp = 0;
260     if( resp && (sample_count > 0) )
261     {
262         resp->resize( sample_count );
263         pred_resp = &((*resp)[0]);
264     }
265     if ( is_classifier )
266     {
267         for( int i = 0; i < sample_count; i++ )
268         {
269             CvMat sample;
270             int si = sidx ? sidx[i] : i;
271             cvGetRow( values, &sample, si ); 
272             float r = svm->predict( &sample );
273             if( pred_resp )
274                 pred_resp[i] = r;
275             int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;
276             err += d;
277         }
278         err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
279     }
280     else
281     {
282         for( int i = 0; i < sample_count; i++ )
283         {
284             CvMat sample;
285             int si = sidx ? sidx[i] : i;
286             cvGetRow( values, &sample, si );
287             float r = svm->predict( &sample );
288             if( pred_resp )
289                 pred_resp[i] = r;
290             float d = r - response->data.fl[si*r_step];
291             err += d*d;
292         }
293         err = sample_count ? err / (float)sample_count : -FLT_MAX;    
294     }
295     return err;
296 }
297
298 // 4. em
299 // 5. ann
300 int str_to_ann_train_method( string& str )
301 {
302     if( !str.compare("BACKPROP") )
303         return CvANN_MLP_TrainParams::BACKPROP;
304     if( !str.compare("RPROP") )
305         return CvANN_MLP_TrainParams::RPROP;
306     CV_Error( CV_StsBadArg, "incorrect ann train method string" );
307     return -1;
308 }
309 void ann_check_data_and_get_predictors( CvMLData* _data, CvMat* _inputs )
310 {
311     const CvMat* values = _data->get_values();
312     const CvMat* var_idx = _data->get_var_idx();
313     if( var_idx->cols + var_idx->rows != values->cols )
314         CV_Error( CV_StsBadArg, "var_idx is not supported" );
315     if( _data->get_missing() )
316         CV_Error( CV_StsBadArg, "missing values are not supported" );
317     int resp_idx = _data->get_response_idx();
318     if( resp_idx == 0)
319         cvGetCols( values, _inputs, 1, values->cols );
320     else if( resp_idx == values->cols - 1 )
321         cvGetCols( values, _inputs, 0, values->cols - 1 );
322     else
323         CV_Error( CV_StsBadArg, "outputs must be in the first or last column; other cases are not supported" );
324 }
325 void ann_get_new_responses( CvMLData* _data, Mat& new_responses, map<int, int>& cls_map )
326 {
327     const CvMat* train_sidx = _data->get_train_sample_idx();
328     int* train_sidx_ptr = train_sidx->data.i;
329     const CvMat* responses = _data->get_responses();
330     float* responses_ptr = responses->data.fl;
331     int r_step = CV_IS_MAT_CONT(responses->type) ?
332         1 : responses->step / CV_ELEM_SIZE(responses->type);
333     int cls_count = 0;
334     // construct cls_map
335     cls_map.clear();
336     for( int si = 0; si < train_sidx->cols; si++ )
337     {
338         int sidx = train_sidx_ptr[si];
339         int r = cvRound(responses_ptr[sidx*r_step]);
340         CV_DbgAssert( fabs(responses_ptr[sidx*r_step]-r) < FLT_EPSILON );
341         int cls_map_size = (int)cls_map.size();
342         cls_map[r];
343         if ( (int)cls_map.size() > cls_map_size )
344             cls_map[r] = cls_count++;
345     }
346     new_responses.create( responses->rows, cls_count, CV_32F );
347     new_responses.setTo( 0 );
348     for( int si = 0; si < train_sidx->cols; si++ )
349     {
350         int sidx = train_sidx_ptr[si];
351         int r = cvRound(responses_ptr[sidx*r_step]);
352         int cidx = cls_map[r];
353         new_responses.ptr<float>(sidx)[cidx] = 1;
354     }
355 }
356 int ann_train( CvANN_MLP* ann, CvMLData* _data, Mat& new_responses, CvANN_MLP_TrainParams _params, int flags = 0 )
357 {
358     const CvMat* train_sidx = _data->get_train_sample_idx();
359     CvMat predictors;
360     ann_check_data_and_get_predictors( _data, &predictors );
361     CvMat _new_responses = CvMat( new_responses );
362     return ann->train( &predictors, &_new_responses, 0, train_sidx, _params, flags );
363 }
364 float ann_calc_error( CvANN_MLP* ann, CvMLData* _data, map<int, int>& cls_map, int type , vector<float> *resp_labels )
365 {
366     float err = 0;
367     const CvMat* responses = _data->get_responses();
368     const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
369     int* sidx = sample_idx ? sample_idx->data.i : 0;
370     int r_step = CV_IS_MAT_CONT(responses->type) ?
371         1 : responses->step / CV_ELEM_SIZE(responses->type);
372     CvMat predictors;
373     ann_check_data_and_get_predictors( _data, &predictors );
374     int sample_count = sample_idx ? sample_idx->cols : 0;
375     sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? predictors.rows : sample_count;
376     float* pred_resp = 0;
377     vector<float> innresp;
378     if( sample_count > 0 )
379     {
380         if( resp_labels )
381         {
382             resp_labels->resize( sample_count );
383             pred_resp = &((*resp_labels)[0]);
384         }
385         else
386         {
387             innresp.resize( sample_count );
388             pred_resp = &(innresp[0]);
389         }
390     }
391     int cls_count = (int)cls_map.size();
392     Mat output( 1, cls_count, CV_32FC1 );
393     CvMat _output = CvMat(output);
394     for( int i = 0; i < sample_count; i++ )
395     {
396         CvMat sample;
397         int si = sidx ? sidx[i] : i;
398         cvGetRow( &predictors, &sample, si ); 
399         ann->predict( &sample, &_output );
400         CvPoint best_cls = {0,0};
401         cvMinMaxLoc( &_output, 0, 0, 0, &best_cls, 0 );
402         int r = cvRound(responses->data.fl[si*r_step]);
403         CV_DbgAssert( fabs(responses->data.fl[si*r_step]-r) < FLT_EPSILON );
404         r = cls_map[r];
405         int d = best_cls.x == r ? 0 : 1;
406         err += d;
407         pred_resp[i] = (float)best_cls.x;
408     }
409     err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
410     return err;
411 }
412
413 // 6. dtree
414 // 7. boost
415 int str_to_boost_type( string& str )
416 {
417     if ( !str.compare("DISCRETE") )
418         return CvBoost::DISCRETE;
419     if ( !str.compare("REAL") )
420         return CvBoost::REAL;    
421     if ( !str.compare("LOGIT") )
422         return CvBoost::LOGIT;
423     if ( !str.compare("GENTLE") )
424         return CvBoost::GENTLE;
425     CV_Error( CV_StsBadArg, "incorrect boost type string" );
426     return -1;
427 }
428
429 // 8. rtrees
430 // 9. ertrees
431
432 // ---------------------------------- MLBaseTest ---------------------------------------------------
433
434 CV_MLBaseTest::CV_MLBaseTest(const char* _modelName)
435 {
436     int64 seeds[] = { CV_BIG_INT(0x00009fff4f9c8d52),
437                       CV_BIG_INT(0x0000a17166072c7c),
438                       CV_BIG_INT(0x0201b32115cd1f9a),
439                       CV_BIG_INT(0x0513cb37abcd1234),
440                       CV_BIG_INT(0x0001a2b3c4d5f678)
441                     };
442
443     int seedCount = sizeof(seeds)/sizeof(seeds[0]);
444     RNG& rng = theRNG();
445
446     initSeed = rng.state;
447
448     rng.state = seeds[rng(seedCount)];
449
450     modelName = _modelName;
451     nbayes = 0;
452     knearest = 0;
453     svm = 0;
454     em = 0;
455     ann = 0;
456     dtree = 0;
457     boost = 0;
458     rtrees = 0;
459     ertrees = 0;
460     if( !modelName.compare(CV_NBAYES) )
461         nbayes = new CvNormalBayesClassifier;
462     else if( !modelName.compare(CV_KNEAREST) )
463         knearest = new CvKNearest;
464     else if( !modelName.compare(CV_SVM) )
465         svm = new CvSVM;
466     else if( !modelName.compare(CV_EM) )
467         em = new CvEM;
468     else if( !modelName.compare(CV_ANN) )
469         ann = new CvANN_MLP;
470     else if( !modelName.compare(CV_DTREE) )
471         dtree = new CvDTree;
472     else if( !modelName.compare(CV_BOOST) )
473         boost = new CvBoost;
474     else if( !modelName.compare(CV_RTREES) )
475         rtrees = new CvRTrees;
476     else if( !modelName.compare(CV_ERTREES) )
477         ertrees = new CvERTrees;
478 }
479
480 CV_MLBaseTest::~CV_MLBaseTest()
481 {
482     if( validationFS.isOpened() )
483         validationFS.release();
484     if( nbayes )
485         delete nbayes;
486     if( knearest ) 
487         delete knearest;
488     if( svm )
489         delete svm;
490     if( em )
491         delete em;
492     if( ann )
493         delete ann;
494     if( dtree )
495         delete dtree;
496     if( boost )
497         delete boost;
498     if( rtrees )
499         delete rtrees;
500     if( ertrees )
501         delete ertrees;
502     theRNG().state = initSeed;
503 }
504
505 int CV_MLBaseTest::read_params( CvFileStorage* _fs )
506 {
507     if( !_fs )
508         test_case_count = -1;
509     else
510     {
511         CvFileNode* fn = cvGetRootFileNode( _fs, 0 );
512         fn = (CvFileNode*)cvGetSeqElem( fn->data.seq, 0 );
513         fn = cvGetFileNodeByName( _fs, fn, "run_params" );
514         CvSeq* dataSetNamesSeq = cvGetFileNodeByName( _fs, fn, modelName.c_str() )->data.seq;
515         test_case_count = dataSetNamesSeq ? dataSetNamesSeq->total : -1;
516         if( test_case_count > 0 )
517         {
518             dataSetNames.resize( test_case_count );
519             vector<string>::iterator it = dataSetNames.begin();
520             for( int i = 0; i < test_case_count; i++, it++ )
521                 *it = ((CvFileNode*)cvGetSeqElem( dataSetNamesSeq, i ))->data.str.ptr;
522         }
523     }
524     return cvtest::TS::OK;;
525 }
526
527 void CV_MLBaseTest::run( int start_from )
528 {
529     string filename = ts->get_data_path();
530     filename += get_validation_filename();
531     validationFS.open( filename, FileStorage::READ );
532     read_params( *validationFS );
533     
534     int code = cvtest::TS::OK;
535     start_from = 0;
536     for (int i = 0; i < test_case_count; i++)
537     {
538         int temp_code = run_test_case( i );
539         if (temp_code == cvtest::TS::OK)
540             temp_code = validate_test_results( i );
541         if (temp_code != cvtest::TS::OK)
542             code = temp_code;
543     }
544     if ( test_case_count <= 0)
545     {
546         ts->printf( cvtest::TS::LOG, "validation file is not determined or not correct" );
547         code = cvtest::TS::FAIL_INVALID_TEST_DATA;
548     }
549     ts->set_failed_test_info( code );
550 }
551
552 int CV_MLBaseTest::prepare_test_case( int test_case_idx )
553 {
554     int trainSampleCount, respIdx;
555     string varTypes;
556     clear();
557
558     string dataPath = ts->get_data_path();
559     if ( dataPath.empty() )
560     {
561         ts->printf( cvtest::TS::LOG, "data path is empty" );
562         return cvtest::TS::FAIL_INVALID_TEST_DATA;
563     }
564
565     string dataName = dataSetNames[test_case_idx],
566         filename = dataPath + dataName + ".data";
567     if ( data.read_csv( filename.c_str() ) != 0)
568     {
569         char msg[100];
570         sprintf( msg, "file %s can not be read", filename.c_str() );
571         ts->printf( cvtest::TS::LOG, msg );
572         return cvtest::TS::FAIL_INVALID_TEST_DATA;
573     }
574
575     FileNode dataParamsNode = validationFS.getFirstTopLevelNode()["validation"][modelName][dataName]["data_params"];
576     CV_DbgAssert( !dataParamsNode.empty() );
577
578     CV_DbgAssert( !dataParamsNode["LS"].empty() );
579     dataParamsNode["LS"] >> trainSampleCount;
580     CvTrainTestSplit spl( trainSampleCount );
581     data.set_train_test_split( &spl );
582
583     CV_DbgAssert( !dataParamsNode["resp_idx"].empty() );
584     dataParamsNode["resp_idx"] >> respIdx;
585     data.set_response_idx( respIdx );
586
587     CV_DbgAssert( !dataParamsNode["types"].empty() );
588     dataParamsNode["types"] >> varTypes;
589     data.set_var_types( varTypes.c_str() );
590
591     return cvtest::TS::OK;
592 }
593
594 string& CV_MLBaseTest::get_validation_filename()
595 {
596     return validationFN;
597 }
598
599 int CV_MLBaseTest::train( int testCaseIdx )
600 {
601     bool is_trained = false;
602     FileNode modelParamsNode = 
603         validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"];
604
605     if( !modelName.compare(CV_NBAYES) )
606         is_trained = nbayes_train( nbayes, &data );
607     else if( !modelName.compare(CV_KNEAREST) )
608     {
609         assert( 0 );
610         //is_trained = knearest->train( &data );
611     }
612     else if( !modelName.compare(CV_SVM) )
613     {
614         string svm_type_str, kernel_type_str;
615         modelParamsNode["svm_type"] >> svm_type_str;
616         modelParamsNode["kernel_type"] >> kernel_type_str;
617         CvSVMParams params;
618         params.svm_type = str_to_svm_type( svm_type_str );
619         params.kernel_type = str_to_svm_kernel_type( kernel_type_str );
620         modelParamsNode["degree"] >> params.degree;
621         modelParamsNode["gamma"] >> params.gamma;
622         modelParamsNode["coef0"] >> params.coef0;
623         modelParamsNode["C"] >> params.C;
624         modelParamsNode["nu"] >> params.nu;
625         modelParamsNode["p"] >> params.p;
626         is_trained = svm_train( svm, &data, params );
627     }
628     else if( !modelName.compare(CV_EM) )
629     {
630         assert( 0 );
631     }
632     else if( !modelName.compare(CV_ANN) )
633     {
634         string train_method_str;
635         double param1, param2;
636         modelParamsNode["train_method"] >> train_method_str;
637         modelParamsNode["param1"] >> param1;
638         modelParamsNode["param2"] >> param2;
639         Mat new_responses;
640         ann_get_new_responses( &data, new_responses, cls_map );
641         int layer_sz[] = { data.get_values()->cols - 1, 100, 100, (int)cls_map.size() };
642         CvMat layer_sizes =
643             cvMat( 1, (int)(sizeof(layer_sz)/sizeof(layer_sz[0])), CV_32S, layer_sz );
644         ann->create( &layer_sizes );
645         is_trained = ann_train( ann, &data, new_responses, CvANN_MLP_TrainParams(cvTermCriteria(CV_TERMCRIT_ITER,300,0.01),
646             str_to_ann_train_method(train_method_str), param1, param2) ) >= 0;
647     }
648     else if( !modelName.compare(CV_DTREE) )
649     {
650         int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS;
651         float REG_ACCURACY = 0;
652         bool USE_SURROGATE, IS_PRUNED;
653         modelParamsNode["max_depth"] >> MAX_DEPTH;
654         modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
655         modelParamsNode["use_surrogate"] >> USE_SURROGATE;
656         modelParamsNode["max_categories"] >> MAX_CATEGORIES;
657         modelParamsNode["cv_folds"] >> CV_FOLDS;
658         modelParamsNode["is_pruned"] >> IS_PRUNED;
659         is_trained = dtree->train( &data, 
660             CvDTreeParams(MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY, USE_SURROGATE,
661             MAX_CATEGORIES, CV_FOLDS, false, IS_PRUNED, 0 )) != 0;
662     }
663     else if( !modelName.compare(CV_BOOST) )
664     {
665         int BOOST_TYPE, WEAK_COUNT, MAX_DEPTH;
666         float WEIGHT_TRIM_RATE;
667         bool USE_SURROGATE;
668         string typeStr;
669         modelParamsNode["type"] >> typeStr;
670         BOOST_TYPE = str_to_boost_type( typeStr );
671         modelParamsNode["weak_count"] >> WEAK_COUNT;
672         modelParamsNode["weight_trim_rate"] >> WEIGHT_TRIM_RATE;
673         modelParamsNode["max_depth"] >> MAX_DEPTH;
674         modelParamsNode["use_surrogate"] >> USE_SURROGATE;
675         is_trained = boost->train( &data,
676             CvBoostParams(BOOST_TYPE, WEAK_COUNT, WEIGHT_TRIM_RATE, MAX_DEPTH, USE_SURROGATE, 0) ) != 0;
677     }
678     else if( !modelName.compare(CV_RTREES) )
679     {
680         int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS, NACTIVE_VARS, MAX_TREES_NUM;
681         float REG_ACCURACY = 0, OOB_EPS = 0.0;
682         bool USE_SURROGATE, IS_PRUNED;
683         modelParamsNode["max_depth"] >> MAX_DEPTH;
684         modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
685         modelParamsNode["use_surrogate"] >> USE_SURROGATE;
686         modelParamsNode["max_categories"] >> MAX_CATEGORIES;
687         modelParamsNode["cv_folds"] >> CV_FOLDS;
688         modelParamsNode["is_pruned"] >> IS_PRUNED;
689         modelParamsNode["nactive_vars"] >> NACTIVE_VARS;
690         modelParamsNode["max_trees_num"] >> MAX_TREES_NUM;
691         is_trained = rtrees->train( &data, CvRTParams(  MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY,
692             USE_SURROGATE, MAX_CATEGORIES, 0, true, // (calc_var_importance == true) <=> RF processes variable importance
693             NACTIVE_VARS, MAX_TREES_NUM, OOB_EPS, CV_TERMCRIT_ITER)) != 0;
694     }
695     else if( !modelName.compare(CV_ERTREES) )
696     {
697         int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS, NACTIVE_VARS, MAX_TREES_NUM;
698         float REG_ACCURACY = 0, OOB_EPS = 0.0;
699         bool USE_SURROGATE, IS_PRUNED;
700         modelParamsNode["max_depth"] >> MAX_DEPTH;
701         modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
702         modelParamsNode["use_surrogate"] >> USE_SURROGATE;
703         modelParamsNode["max_categories"] >> MAX_CATEGORIES;
704         modelParamsNode["cv_folds"] >> CV_FOLDS;
705         modelParamsNode["is_pruned"] >> IS_PRUNED;
706         modelParamsNode["nactive_vars"] >> NACTIVE_VARS;
707         modelParamsNode["max_trees_num"] >> MAX_TREES_NUM;
708         is_trained = ertrees->train( &data, CvRTParams( MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY,
709             USE_SURROGATE, MAX_CATEGORIES, 0, false, // (calc_var_importance == true) <=> RF processes variable importance
710             NACTIVE_VARS, MAX_TREES_NUM, OOB_EPS, CV_TERMCRIT_ITER)) != 0;
711     }
712
713     if( !is_trained )
714     {
715         ts->printf( cvtest::TS::LOG, "in test case %d model training was failed", testCaseIdx );
716         return cvtest::TS::FAIL_INVALID_OUTPUT;
717     }
718     return cvtest::TS::OK;
719 }
720
721 float CV_MLBaseTest::get_error( int testCaseIdx, int type, vector<float> *resp )
722 {
723     float err = 0;
724     if( !modelName.compare(CV_NBAYES) )
725         err = nbayes_calc_error( nbayes, &data, type, resp );
726     else if( !modelName.compare(CV_KNEAREST) )
727     {
728         assert( 0 );
729         testCaseIdx = 0;
730         /*int k = 2;
731         validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"]["k"] >> k;
732         err = knearest->calc_error( &data, k, type, resp );*/
733     }
734     else if( !modelName.compare(CV_SVM) )
735         err = svm_calc_error( svm, &data, type, resp );
736     else if( !modelName.compare(CV_EM) )
737         assert( 0 );
738     else if( !modelName.compare(CV_ANN) )
739         err = ann_calc_error( ann, &data, cls_map, type, resp );
740     else if( !modelName.compare(CV_DTREE) )
741         err = dtree->calc_error( &data, type, resp );
742     else if( !modelName.compare(CV_BOOST) )
743         err = boost->calc_error( &data, type, resp );
744     else if( !modelName.compare(CV_RTREES) )
745         err = rtrees->calc_error( &data, type, resp );
746     else if( !modelName.compare(CV_ERTREES) )
747         err = ertrees->calc_error( &data, type, resp );
748     return err;
749 }
750
751 void CV_MLBaseTest::save( const char* filename )
752 {
753     if( !modelName.compare(CV_NBAYES) )
754         nbayes->save( filename );
755     else if( !modelName.compare(CV_KNEAREST) )
756         knearest->save( filename );
757     else if( !modelName.compare(CV_SVM) )
758         svm->save( filename );
759     else if( !modelName.compare(CV_EM) )
760         em->save( filename );
761     else if( !modelName.compare(CV_ANN) )
762         ann->save( filename );
763     else if( !modelName.compare(CV_DTREE) )
764         dtree->save( filename );
765     else if( !modelName.compare(CV_BOOST) )
766         boost->save( filename );
767     else if( !modelName.compare(CV_RTREES) )
768         rtrees->save( filename );
769     else if( !modelName.compare(CV_ERTREES) )
770         ertrees->save( filename );
771 }
772
773 void CV_MLBaseTest::load( const char* filename )
774 {
775     if( !modelName.compare(CV_NBAYES) )
776         nbayes->load( filename );
777     else if( !modelName.compare(CV_KNEAREST) )
778         knearest->load( filename );
779     else if( !modelName.compare(CV_SVM) )
780         svm->load( filename );
781     else if( !modelName.compare(CV_EM) )
782         em->load( filename );
783     else if( !modelName.compare(CV_ANN) )
784         ann->load( filename );
785     else if( !modelName.compare(CV_DTREE) )
786         dtree->load( filename );
787     else if( !modelName.compare(CV_BOOST) )
788         boost->load( filename );
789     else if( !modelName.compare(CV_RTREES) )
790         rtrees->load( filename );
791     else if( !modelName.compare(CV_ERTREES) )
792         ertrees->load( filename );
793 }
794
795 /* End of file. */