c2a238dafae7784a5f802ab3e5f4ea55f5549bf3
[profile/ivi/opencv.git] / samples / python2 / digits_adjust.py
1 '''\r
2 Digit recognition adjustment. \r
3 Grid search is used to find the best parameters for SVN and KNearest classifiers.\r
4 SVM adjustment follows the guidelines given in \r
5 http://www.csie.ntu.edu.tw/~cjlin/papers/guide/guide.pdf\r
6 \r
7 Threading or cloud computing (with http://www.picloud.com/)) may be used \r
8 to speedup the computation.\r
9 \r
10 Usage:\r
11   digits_adjust.py [--model {svm|knearest}] [--cloud] [--env <PiCloud environment>]\r
12   \r
13   --model {svm|knearest}   - select the classifier (SVM is the default)\r
14   --cloud                  - use PiCloud computing platform\r
15   --env                    - cloud environment name\r
16 \r
17 '''\r
18 # TODO cloud env setup tutorial\r
19 \r
20 import numpy as np\r
21 import cv2\r
22 from multiprocessing.pool import ThreadPool\r
23 \r
24 from digits import *\r
25 \r
26 try: \r
27     import cloud\r
28     have_cloud = True\r
29 except ImportError:\r
30     have_cloud = False\r
31     \r
32 \r
33 \r
34 def cross_validate(model_class, params, samples, labels, kfold = 3, pool = None):\r
35     n = len(samples)\r
36     folds = np.array_split(np.arange(n), kfold)\r
37     def f(i):\r
38         model = model_class(**params)\r
39         test_idx = folds[i]\r
40         train_idx = list(folds)\r
41         train_idx.pop(i)\r
42         train_idx = np.hstack(train_idx)\r
43         train_samples, train_labels = samples[train_idx], labels[train_idx]\r
44         test_samples, test_labels = samples[test_idx], labels[test_idx]\r
45         model.train(train_samples, train_labels)\r
46         resp = model.predict(test_samples)\r
47         score = (resp != test_labels).mean()\r
48         print ".",\r
49         return score\r
50     if pool is None:\r
51         scores = map(f, xrange(kfold))\r
52     else:\r
53         scores = pool.map(f, xrange(kfold))\r
54     return np.mean(scores)\r
55 \r
56 \r
57 class App(object):\r
58     def __init__(self, usecloud=False, cloud_env=''):\r
59         if usecloud and not have_cloud:\r
60             print 'warning: cloud module is not installed, running locally'\r
61             usecloud = False\r
62         self.usecloud = usecloud\r
63         self.cloud_env = cloud_env\r
64 \r
65         if self.usecloud:\r
66             print 'uploading dataset to cloud...'\r
67             cloud.files.put(DIGITS_FN)\r
68             self.preprocess_job = cloud.call(self.preprocess, _env=self.cloud_env)\r
69         else:\r
70             self._samples, self._labels = self.preprocess()\r
71 \r
72     def preprocess(self):\r
73         if self.usecloud:\r
74             cloud.files.get(DIGITS_FN)\r
75         digits, labels = load_digits(DIGITS_FN)\r
76         shuffle = np.random.permutation(len(digits))\r
77         digits, labels = digits[shuffle], labels[shuffle]\r
78         digits2 = map(deskew, digits)\r
79         samples = preprocess_hog(digits2)\r
80         return samples, labels\r
81 \r
82     def get_dataset(self):\r
83         if self.usecloud:\r
84             return cloud.result(self.preprocess_job)\r
85         else:\r
86             return self._samples, self._labels\r
87 \r
88     def run_jobs(self, f, jobs):\r
89         if self.usecloud:\r
90             jids = cloud.map(f, jobs, _env=self.cloud_env, _profile=True, _depends_on=self.preprocess_job)\r
91             ires = cloud.iresult(jids)\r
92         else:\r
93             pool = ThreadPool(processes=cv2.getNumberOfCPUs())\r
94             ires = pool.imap_unordered(f, jobs)\r
95         return ires\r
96             \r
97     def adjust_SVM(self):\r
98         Cs = np.logspace(0, 10, 15, base=2)\r
99         gammas = np.logspace(-7, 4, 15, base=2)\r
100         scores = np.zeros((len(Cs), len(gammas)))\r
101         scores[:] = np.nan\r
102 \r
103         print 'adjusting SVM (may take a long time) ...'\r
104         def f(job):\r
105             i, j = job\r
106             samples, labels = self.get_dataset()\r
107             params = dict(C = Cs[i], gamma=gammas[j])\r
108             score = cross_validate(SVM, params, samples, labels)\r
109             return i, j, score\r
110         \r
111         ires = self.run_jobs(f, np.ndindex(*scores.shape))\r
112         for count, (i, j, score) in enumerate(ires):\r
113             scores[i, j] = score\r
114             print '%d / %d (best error: %.2f %%, last: %.2f %%)' % (count+1, scores.size, np.nanmin(scores)*100, score*100)\r
115         print scores\r
116 \r
117         print 'writing score table to "svm_scores.npz"'\r
118         np.savez('svm_scores.npz', scores=scores, Cs=Cs, gammas=gammas)\r
119 \r
120         i, j = np.unravel_index(scores.argmin(), scores.shape)\r
121         best_params = dict(C = Cs[i], gamma=gammas[j])\r
122         print 'best params:', best_params\r
123         print 'best error: %.2f %%' % (scores.min()*100)\r
124         return best_params\r
125 \r
126     def adjust_KNearest(self):\r
127         print 'adjusting KNearest ...'\r
128         def f(k):\r
129             samples, labels = self.get_dataset()\r
130             err = cross_validate(KNearest, dict(k=k), samples, labels)\r
131             return k, err\r
132         best_err, best_k = np.inf, -1\r
133         for k, err in self.run_jobs(f, xrange(1, 9)):\r
134             if err < best_err:\r
135                 best_err, best_k = err, k\r
136             print 'k = %d, error: %.2f %%' % (k, err*100)\r
137         best_params = dict(k=best_k)\r
138         print 'best params:', best_params, 'err: %.2f' % (best_err*100)\r
139         return best_params\r
140 \r
141 \r
142 if __name__ == '__main__':\r
143     import getopt\r
144     import sys\r
145     \r
146     print __doc__\r
147 \r
148     args, _ = getopt.getopt(sys.argv[1:], '', ['model=', 'cloud', 'env='])\r
149     args = dict(args)\r
150     args.setdefault('--model', 'svm')\r
151     args.setdefault('--env', '')\r
152     if args['--model'] not in ['svm', 'knearest']:\r
153         print 'unknown model "%s"' % args['--model']\r
154         sys.exit(1)\r
155 \r
156     t = clock()\r
157     app = App(usecloud='--cloud' in args, cloud_env = args['--env'])\r
158     if args['--model'] == 'knearest':\r
159         app.adjust_KNearest()\r
160     else:\r
161         app.adjust_SVM()\r
162     print 'work time: %f s' % (clock() - t)\r