added python cv2 port of letter_recog sample
authorAlexander Mordvintsev <no@email>
Sun, 5 Jun 2011 06:44:19 +0000 (06:44 +0000)
committerAlexander Mordvintsev <no@email>
Sun, 5 Jun 2011 06:44:19 +0000 (06:44 +0000)
samples/python2/letter_recog.py [new file with mode: 0644]

diff --git a/samples/python2/letter_recog.py b/samples/python2/letter_recog.py
new file mode 100644 (file)
index 0000000..efd4297
--- /dev/null
@@ -0,0 +1,136 @@
+import numpy as np\r
+import cv2\r
+\r
+def load_base(fn):\r
+    a = np.loadtxt(fn, np.float32, delimiter=',', converters={ 0 : lambda ch : ord(ch)-ord('A') })\r
+    samples, responses = a[:,1:], a[:,0]\r
+    return samples, responses\r
+\r
+# TODO move these to cv2\r
+CV_ROW_SAMPLE = 1\r
+CV_VAR_NUMERICAL   = 0\r
+CV_VAR_ORDERED     = 0\r
+CV_VAR_CATEGORICAL = 1\r
+\r
+\r
+class LetterStatModel(object):\r
+    train_ratio = 0.5\r
+    def load(self, fn):\r
+        self.model.load(fn)\r
+    def save(self, fn):\r
+        self.model.save(fn)\r
+\r
+class RTrees(LetterStatModel):\r
+    def __init__(self):\r
+        self.model = cv2.RTrees()\r
+\r
+    def train(self, samples, responses):\r
+        sample_n, var_n = samples.shape\r
+        var_types = np.array([CV_VAR_NUMERICAL] * var_n + [CV_VAR_CATEGORICAL], np.uint8)\r
+        #CvRTParams(10,10,0,false,15,0,true,4,100,0.01f,CV_TERMCRIT_ITER));\r
+        params = dict(max_depth=10 )\r
+        self.model.train(samples, CV_ROW_SAMPLE, responses, varType = var_types, params = params)\r
+\r
+    def predict(self, samples):\r
+        return np.float32( [self.model.predict(s) for s in samples] )\r
+        \r
+\r
+class KNearest(LetterStatModel):\r
+    def __init__(self):\r
+        self.model = cv2.KNearest()\r
+\r
+    def train(self, samples, responses):\r
+        self.model.train(samples, responses)\r
+\r
+    def predict(self, samples):\r
+        retval, results, neigh_resp, dists = self.model.find_nearest(samples, k = 10)\r
+        return results.ravel()\r
+\r
+\r
+class Boost(LetterStatModel):\r
+    def __init__(self):\r
+        self.model = cv2.Boost()\r
+        self.class_n = 26\r
+    \r
+    def train(self, samples, responses):\r
+        sample_n, var_n = samples.shape\r
+        new_samples = self.unroll_samples(samples)\r
+        new_responses = self.unroll_responses(responses)\r
+        var_types = np.array([CV_VAR_NUMERICAL] * var_n + [CV_VAR_CATEGORICAL, CV_VAR_CATEGORICAL], np.uint8)\r
+        #CvBoostParams(CvBoost::REAL, 100, 0.95, 5, false, 0 )\r
+        params = dict(max_depth=5) #, use_surrogates=False)\r
+        self.model.train(new_samples, CV_ROW_SAMPLE, new_responses, varType = var_types, params=params)\r
+\r
+    def predict(self, samples):\r
+        new_samples = self.unroll_samples(samples)\r
+        pred = np.array( [self.model.predict(s, returnSum = True) for s in new_samples] )\r
+        pred = pred.reshape(-1, self.class_n).argmax(1)\r
+        return pred\r
+\r
+    def unroll_samples(self, samples):\r
+        sample_n, var_n = samples.shape\r
+        new_samples = np.zeros((sample_n * self.class_n, var_n+1), np.float32)\r
+        new_samples[:,:-1] = np.repeat(samples, self.class_n, axis=0)\r
+        new_samples[:,-1] = np.tile(np.arange(self.class_n), sample_n)\r
+        return new_samples\r
+    \r
+    def unroll_responses(self, responses):\r
+        sample_n = len(responses)\r
+        new_responses = np.zeros(sample_n*self.class_n, np.int32)\r
+        resp_idx = np.int32( responses + np.arange(sample_n)*self.class_n )\r
+        new_responses[resp_idx] = 1\r
+        return new_responses\r
+\r
+\r
+class SVM(LetterStatModel):\r
+    train_ratio = 0.1\r
+    def __init__(self):\r
+        self.model = cv2.SVM()\r
+\r
+    def train(self, samples, responses):\r
+        params = dict( kernel_type = cv2.SVM_LINEAR, \r
+                       svm_type = cv2.SVM_C_SVC,\r
+                       C = 1 )\r
+        self.model.train(samples, responses, params = params)\r
+\r
+    def predict(self, samples):\r
+        return np.float32( [self.model.predict(s) for s in samples] )\r
+\r
+\r
+if __name__ == '__main__':\r
+    import argparse\r
+\r
+    models = [RTrees, KNearest, Boost, SVM] # MLP, NBayes\r
+    models = dict( [(cls.__name__.lower(), cls) for cls in models] )\r
+    \r
+    parser = argparse.ArgumentParser()\r
+    parser.add_argument('-model', default='rtrees', choices=models.keys())\r
+    parser.add_argument('-data', nargs=1, default='letter-recognition.data')\r
+    parser.add_argument('-load', nargs=1)\r
+    parser.add_argument('-save', nargs=1)\r
+    args = parser.parse_args()\r
+\r
+    print 'loading data %s ...' % args.data\r
+    samples, responses = load_base(args.data)\r
+    Model = models[args.model]\r
+    model = Model()\r
+\r
+    train_n = int(len(samples)*model.train_ratio)\r
+    if args.load is None:\r
+        print 'training %s ...' % Model.__name__\r
+        model.train(samples[:train_n], responses[:train_n])\r
+    else:\r
+        fn = args.load[0]\r
+        print 'loading model from %s ...' % fn\r
+        model.load(fn)\r
+\r
+    print 'testing...'\r
+    train_rate = np.mean(model.predict(samples[:train_n]) == responses[:train_n])\r
+    test_rate  = np.mean(model.predict(samples[train_n:]) == responses[train_n:])\r
+\r
+    print 'train rate: %f  test rate: %f' % (train_rate*100, test_rate*100)\r
+\r
+    if args.save is not None:\r
+        fn = args.save[0]\r
+        print 'saving model to %s ...' % fn\r
+        model.save(fn)\r