lot's of changes; nonfree & photo modules added; SIFT & SURF -> nonfree module; Inpai...
[profile/ivi/opencv.git] / samples / c / mushroom.cpp
1 #include "opencv2/core/core_c.h"
2 #include "opencv2/ml/ml.hpp"
3 #include <stdio.h>
4
5 void help()
6 {
7         printf("\nThis program demonstrated the use of OpenCV's decision tree function for learning and predicting data\n"
8             "Usage :\n"
9             "./mushroom <path to agaricus-lepiota.data>\n"
10             "\n"
11             "The sample demonstrates how to build a decision tree for classifying mushrooms.\n"
12             "It uses the sample base agaricus-lepiota.data from UCI Repository, here is the link:\n"
13             "\n"
14             "Newman, D.J. & Hettich, S. & Blake, C.L. & Merz, C.J. (1998).\n"
15             "UCI Repository of machine learning databases\n"
16             "[http://www.ics.uci.edu/~mlearn/MLRepository.html].\n"
17             "Irvine, CA: University of California, Department of Information and Computer Science.\n"
18             "\n"
19             "// loads the mushroom database, which is a text file, containing\n"
20             "// one training sample per row, all the input variables and the output variable are categorical,\n"
21             "// the values are encoded by characters.\n\n");
22 }
23
24 int mushroom_read_database( const char* filename, CvMat** data, CvMat** missing, CvMat** responses )
25 {
26     const int M = 1024;
27     FILE* f = fopen( filename, "rt" );
28     CvMemStorage* storage;
29     CvSeq* seq;
30     char buf[M+2], *ptr;
31     float* el_ptr;
32     CvSeqReader reader;
33     int i, j, var_count = 0;
34
35     if( !f )
36         return 0;
37
38     // read the first line and determine the number of variables
39     if( !fgets( buf, M, f ))
40     {
41         fclose(f);
42         return 0;
43     }
44
45     for( ptr = buf; *ptr != '\0'; ptr++ )
46         var_count += *ptr == ',';
47     assert( ptr - buf == (var_count+1)*2 );
48
49     // create temporary memory storage to store the whole database
50     el_ptr = new float[var_count+1];
51     storage = cvCreateMemStorage();
52     seq = cvCreateSeq( 0, sizeof(*seq), (var_count+1)*sizeof(float), storage );
53
54     for(;;)
55     {
56         for( i = 0; i <= var_count; i++ )
57         {
58             int c = buf[i*2];
59             el_ptr[i] = c == '?' ? -1.f : (float)c;
60         }
61         if( i != var_count+1 )
62             break;
63         cvSeqPush( seq, el_ptr );
64         if( !fgets( buf, M, f ) || !strchr( buf, ',' ) )
65             break;
66     }
67     fclose(f);
68
69     // allocate the output matrices and copy the base there
70     *data = cvCreateMat( seq->total, var_count, CV_32F );
71     *missing = cvCreateMat( seq->total, var_count, CV_8U );
72     *responses = cvCreateMat( seq->total, 1, CV_32F );
73
74     cvStartReadSeq( seq, &reader );
75
76     for( i = 0; i < seq->total; i++ )
77     {
78         const float* sdata = (float*)reader.ptr + 1;
79         float* ddata = data[0]->data.fl + var_count*i;
80         float* dr = responses[0]->data.fl + i;
81         uchar* dm = missing[0]->data.ptr + var_count*i;
82
83         for( j = 0; j < var_count; j++ )
84         {
85             ddata[j] = sdata[j];
86             dm[j] = sdata[j] < 0;
87         }
88         *dr = sdata[-1];
89         CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
90     }
91
92     cvReleaseMemStorage( &storage );
93     delete el_ptr;
94     return 1;
95 }
96
97
98 CvDTree* mushroom_create_dtree( const CvMat* data, const CvMat* missing,
99                                 const CvMat* responses, float p_weight )
100 {
101     CvDTree* dtree;
102     CvMat* var_type;
103     int i, hr1 = 0, hr2 = 0, p_total = 0;
104     float priors[] = { 1, p_weight };
105
106     var_type = cvCreateMat( data->cols + 1, 1, CV_8U );
107     cvSet( var_type, cvScalarAll(CV_VAR_CATEGORICAL) ); // all the variables are categorical
108
109     dtree = new CvDTree;
110     
111     dtree->train( data, CV_ROW_SAMPLE, responses, 0, 0, var_type, missing,
112                   CvDTreeParams( 8, // max depth
113                                  10, // min sample count
114                                  0, // regression accuracy: N/A here
115                                  true, // compute surrogate split, as we have missing data
116                                  15, // max number of categories (use sub-optimal algorithm for larger numbers)
117                                  10, // the number of cross-validation folds
118                                  true, // use 1SE rule => smaller tree
119                                  true, // throw away the pruned tree branches
120                                  priors // the array of priors, the bigger p_weight, the more attention
121                                         // to the poisonous mushrooms
122                                         // (a mushroom will be judjed to be poisonous with bigger chance)
123                                  ));
124
125     // compute hit-rate on the training database, demonstrates predict usage.
126     for( i = 0; i < data->rows; i++ )
127     {
128         CvMat sample, mask;
129         cvGetRow( data, &sample, i );
130         cvGetRow( missing, &mask, i );
131         double r = dtree->predict( &sample, &mask )->value;
132         int d = fabs(r - responses->data.fl[i]) >= FLT_EPSILON;
133         if( d )
134         {
135             if( r != 'p' )
136                 hr1++;
137             else
138                 hr2++;
139         }
140         p_total += responses->data.fl[i] == 'p';
141     }
142
143     printf( "Results on the training database:\n"
144             "\tPoisonous mushrooms mis-predicted: %d (%g%%)\n"
145             "\tFalse-alarms: %d (%g%%)\n", hr1, (double)hr1*100/p_total,
146             hr2, (double)hr2*100/(data->rows - p_total) );
147
148     cvReleaseMat( &var_type );
149
150     return dtree;
151 }
152
153
154 static const char* var_desc[] =
155 {
156     "cap shape (bell=b,conical=c,convex=x,flat=f)",
157     "cap surface (fibrous=f,grooves=g,scaly=y,smooth=s)",
158     "cap color (brown=n,buff=b,cinnamon=c,gray=g,green=r,\n\tpink=p,purple=u,red=e,white=w,yellow=y)",
159     "bruises? (bruises=t,no=f)",
160     "odor (almond=a,anise=l,creosote=c,fishy=y,foul=f,\n\tmusty=m,none=n,pungent=p,spicy=s)",
161     "gill attachment (attached=a,descending=d,free=f,notched=n)",
162     "gill spacing (close=c,crowded=w,distant=d)",
163     "gill size (broad=b,narrow=n)",
164     "gill color (black=k,brown=n,buff=b,chocolate=h,gray=g,\n\tgreen=r,orange=o,pink=p,purple=u,red=e,white=w,yellow=y)",
165     "stalk shape (enlarging=e,tapering=t)",
166     "stalk root (bulbous=b,club=c,cup=u,equal=e,rhizomorphs=z,rooted=r)",
167     "stalk surface above ring (ibrous=f,scaly=y,silky=k,smooth=s)",
168     "stalk surface below ring (ibrous=f,scaly=y,silky=k,smooth=s)",
169     "stalk color above ring (brown=n,buff=b,cinnamon=c,gray=g,orange=o,\n\tpink=p,red=e,white=w,yellow=y)",
170     "stalk color below ring (brown=n,buff=b,cinnamon=c,gray=g,orange=o,\n\tpink=p,red=e,white=w,yellow=y)",
171     "veil type (partial=p,universal=u)",
172     "veil color (brown=n,orange=o,white=w,yellow=y)",
173     "ring number (none=n,one=o,two=t)",
174     "ring type (cobwebby=c,evanescent=e,flaring=f,large=l,\n\tnone=n,pendant=p,sheathing=s,zone=z)",
175     "spore print color (black=k,brown=n,buff=b,chocolate=h,green=r,\n\torange=o,purple=u,white=w,yellow=y)",
176     "population (abundant=a,clustered=c,numerous=n,\n\tscattered=s,several=v,solitary=y)",
177     "habitat (grasses=g,leaves=l,meadows=m,paths=p\n\turban=u,waste=w,woods=d)",
178     0
179 };
180
181
182 void print_variable_importance( CvDTree* dtree, const char** var_desc )
183 {
184     const CvMat* var_importance = dtree->get_var_importance();
185     int i;
186     char input[1000];
187
188     if( !var_importance )
189     {
190         printf( "Error: Variable importance can not be retrieved\n" );
191         return;
192     }
193
194     printf( "Print variable importance information? (y/n) " );
195     scanf( "%1s", input );
196     if( input[0] != 'y' && input[0] != 'Y' )
197         return;
198
199     for( i = 0; i < var_importance->cols*var_importance->rows; i++ )
200     {
201         double val = var_importance->data.db[i];
202         if( var_desc )
203         {
204             char buf[100];
205             int len = strchr( var_desc[i], '(' ) - var_desc[i] - 1;
206             strncpy( buf, var_desc[i], len );
207             buf[len] = '\0';
208             printf( "%s", buf );
209         }
210         else
211             printf( "var #%d", i );
212         printf( ": %g%%\n", val*100. );
213     }
214 }
215
216 void interactive_classification( CvDTree* dtree, const char** var_desc )
217 {
218     char input[1000];
219     const CvDTreeNode* root;
220     CvDTreeTrainData* data;
221
222     if( !dtree )
223         return;
224
225     root = dtree->get_root();
226     data = dtree->get_data();
227
228     for(;;)
229     {
230         const CvDTreeNode* node;
231         
232         printf( "Start/Proceed with interactive mushroom classification (y/n): " );
233         scanf( "%1s", input );
234         if( input[0] != 'y' && input[0] != 'Y' )
235             break;
236         printf( "Enter 1-letter answers, '?' for missing/unknown value...\n" ); 
237
238         // custom version of predict
239         node = root;
240         for(;;)
241         {
242             CvDTreeSplit* split = node->split;
243             int dir = 0;
244             
245             if( !node->left || node->Tn <= dtree->get_pruned_tree_idx() || !node->split )
246                 break;
247
248             for( ; split != 0; )
249             {
250                 int vi = split->var_idx, j;
251                 int count = data->cat_count->data.i[vi];
252                 const int* map = data->cat_map->data.i + data->cat_ofs->data.i[vi];
253
254                 printf( "%s: ", var_desc[vi] );
255                 scanf( "%1s", input );
256
257                 if( input[0] == '?' )
258                 {
259                     split = split->next;
260                     continue;
261                 }
262
263                 // convert the input character to the normalized value of the variable
264                 for( j = 0; j < count; j++ )
265                     if( map[j] == input[0] )
266                         break;
267                 if( j < count )
268                 {
269                     dir = (split->subset[j>>5] & (1 << (j&31))) ? -1 : 1;
270                     if( split->inversed )
271                         dir = -dir;
272                     break;
273                 }
274                 else
275                     printf( "Error: unrecognized value\n" );
276             }
277             
278             if( !dir )
279             {
280                 printf( "Impossible to classify the sample\n");
281                 node = 0;
282                 break;
283             }
284             node = dir < 0 ? node->left : node->right;
285         }
286
287         if( node )
288             printf( "Prediction result: the mushroom is %s\n",
289                     node->class_idx == 0 ? "EDIBLE" : "POISONOUS" );
290         printf( "\n-----------------------------\n" );
291     }
292 }
293
294
295 int main( int argc, char** argv )
296 {
297     CvMat *data = 0, *missing = 0, *responses = 0;
298     CvDTree* dtree;
299     const char* base_path = argc >= 2 ? argv[1] : "agaricus-lepiota.data";
300
301     help();
302
303     if( !mushroom_read_database( base_path, &data, &missing, &responses ) )
304     {
305         printf( "\nUnable to load the training database\n\n");
306         help();
307         return -1;
308     }
309
310     dtree = mushroom_create_dtree( data, missing, responses,
311         10 // poisonous mushrooms will have 10x higher weight in the decision tree
312         );
313     cvReleaseMat( &data );
314     cvReleaseMat( &missing );
315     cvReleaseMat( &responses );
316
317     print_variable_importance( dtree, var_desc );
318     interactive_classification( dtree, var_desc );
319     delete dtree;
320
321     return 0;
322 }