2 Digit recognition adjustment.
\r
3 Grid search is used to find the best parameters for SVN and KNearest classifiers.
\r
4 SVM adjustment follows the guidelines given in
\r
5 http://www.csie.ntu.edu.tw/~cjlin/papers/guide/guide.pdf
\r
7 Threading or cloud computing (with http://www.picloud.com/)) may be used
\r
8 to speedup the computation.
\r
11 digits_adjust.py [--model {svm|knearest}] [--cloud] [--env <PiCloud environment>]
\r
13 --model {svm|knearest} - select the classifier (SVM is the default)
\r
14 --cloud - use PiCloud computing platform
\r
15 --env - cloud environment name
\r
18 # TODO cloud env setup tutorial
\r
22 from multiprocessing.pool import ThreadPool
\r
24 from digits import *
\r
34 def cross_validate(model_class, params, samples, labels, kfold = 3, pool = None):
\r
36 folds = np.array_split(np.arange(n), kfold)
\r
38 model = model_class(**params)
\r
40 train_idx = list(folds)
\r
42 train_idx = np.hstack(train_idx)
\r
43 train_samples, train_labels = samples[train_idx], labels[train_idx]
\r
44 test_samples, test_labels = samples[test_idx], labels[test_idx]
\r
45 model.train(train_samples, train_labels)
\r
46 resp = model.predict(test_samples)
\r
47 score = (resp != test_labels).mean()
\r
51 scores = map(f, xrange(kfold))
\r
53 scores = pool.map(f, xrange(kfold))
\r
54 return np.mean(scores)
\r
58 def __init__(self, usecloud=False, cloud_env=''):
\r
59 if usecloud and not have_cloud:
\r
60 print 'warning: cloud module is not installed, running locally'
\r
62 self.usecloud = usecloud
\r
63 self.cloud_env = cloud_env
\r
66 print 'uploading dataset to cloud...'
\r
67 cloud.files.put(DIGITS_FN)
\r
68 self.preprocess_job = cloud.call(self.preprocess, _env=self.cloud_env)
\r
70 self._samples, self._labels = self.preprocess()
\r
72 def preprocess(self):
\r
74 cloud.files.get(DIGITS_FN)
\r
75 digits, labels = load_digits(DIGITS_FN)
\r
76 shuffle = np.random.permutation(len(digits))
\r
77 digits, labels = digits[shuffle], labels[shuffle]
\r
78 digits2 = map(deskew, digits)
\r
79 samples = preprocess_hog(digits2)
\r
80 return samples, labels
\r
82 def get_dataset(self):
\r
84 return cloud.result(self.preprocess_job)
\r
86 return self._samples, self._labels
\r
88 def run_jobs(self, f, jobs):
\r
90 jids = cloud.map(f, jobs, _env=self.cloud_env, _profile=True, _depends_on=self.preprocess_job)
\r
91 ires = cloud.iresult(jids)
\r
93 pool = ThreadPool(processes=cv2.getNumberOfCPUs())
\r
94 ires = pool.imap_unordered(f, jobs)
\r
97 def adjust_SVM(self):
\r
98 Cs = np.logspace(0, 10, 15, base=2)
\r
99 gammas = np.logspace(-7, 4, 15, base=2)
\r
100 scores = np.zeros((len(Cs), len(gammas)))
\r
103 print 'adjusting SVM (may take a long time) ...'
\r
106 samples, labels = self.get_dataset()
\r
107 params = dict(C = Cs[i], gamma=gammas[j])
\r
108 score = cross_validate(SVM, params, samples, labels)
\r
111 ires = self.run_jobs(f, np.ndindex(*scores.shape))
\r
112 for count, (i, j, score) in enumerate(ires):
\r
113 scores[i, j] = score
\r
114 print '%d / %d (best error: %.2f %%, last: %.2f %%)' % (count+1, scores.size, np.nanmin(scores)*100, score*100)
\r
117 print 'writing score table to "svm_scores.npz"'
\r
118 np.savez('svm_scores.npz', scores=scores, Cs=Cs, gammas=gammas)
\r
120 i, j = np.unravel_index(scores.argmin(), scores.shape)
\r
121 best_params = dict(C = Cs[i], gamma=gammas[j])
\r
122 print 'best params:', best_params
\r
123 print 'best error: %.2f %%' % (scores.min()*100)
\r
126 def adjust_KNearest(self):
\r
127 print 'adjusting KNearest ...'
\r
129 samples, labels = self.get_dataset()
\r
130 err = cross_validate(KNearest, dict(k=k), samples, labels)
\r
132 best_err, best_k = np.inf, -1
\r
133 for k, err in self.run_jobs(f, xrange(1, 9)):
\r
135 best_err, best_k = err, k
\r
136 print 'k = %d, error: %.2f %%' % (k, err*100)
\r
137 best_params = dict(k=best_k)
\r
138 print 'best params:', best_params, 'err: %.2f' % (best_err*100)
\r
142 if __name__ == '__main__':
\r
148 args, _ = getopt.getopt(sys.argv[1:], '', ['model=', 'cloud', 'env='])
\r
150 args.setdefault('--model', 'svm')
\r
151 args.setdefault('--env', '')
\r
152 if args['--model'] not in ['svm', 'knearest']:
\r
153 print 'unknown model "%s"' % args['--model']
\r
157 app = App(usecloud='--cloud' in args, cloud_env = args['--env'])
\r
158 if args['--model'] == 'knearest':
\r
159 app.adjust_KNearest()
\r
162 print 'work time: %f s' % (clock() - t)
\r