From a718f2e6eaf88a543d4bc2441b8be582f3ba6af8 Mon Sep 17 00:00:00 2001 From: berak Date: Mon, 13 Jan 2020 12:26:28 +0100 Subject: [PATCH] ml/python: fix digits samples(3.4) --- samples/python/digits.py | 20 +++++++++++++------- samples/python/digits_video.py | 14 +++++++++----- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/samples/python/digits.py b/samples/python/digits.py index f58e9dd..e5d8ceb 100755 --- a/samples/python/digits.py +++ b/samples/python/digits.py @@ -70,13 +70,8 @@ def deskew(img): img = cv.warpAffine(img, M, (SZ, SZ), flags=cv.WARP_INVERSE_MAP | cv.INTER_LINEAR) return img -class StatModel(object): - def load(self, fn): - self.model.load(fn) # Known bug: https://github.com/opencv/opencv/issues/4969 - def save(self, fn): - self.model.save(fn) -class KNearest(StatModel): +class KNearest(object): def __init__(self, k = 3): self.k = k self.model = cv.ml.KNearest_create() @@ -88,7 +83,13 @@ class KNearest(StatModel): _retval, results, _neigh_resp, _dists = self.model.findNearest(samples, self.k) return results.ravel() -class SVM(StatModel): + def load(self, fn): + self.model = cv.ml.KNearest_load(fn) + + def save(self, fn): + self.model.save(fn) + +class SVM(object): def __init__(self, C = 1, gamma = 0.5): self.model = cv.ml.SVM_create() self.model.setGamma(gamma) @@ -102,6 +103,11 @@ class SVM(StatModel): def predict(self, samples): return self.model.predict(samples)[1].ravel() + def load(self, fn): + self.model = cv.ml.SVM_load(fn) + + def save(self, fn): + self.model.save(fn) def evaluate_model(model, digits, samples, labels): resp = model.predict(samples) diff --git a/samples/python/digits_video.py b/samples/python/digits_video.py index dc035e4..7b07831 100755 --- a/samples/python/digits_video.py +++ b/samples/python/digits_video.py @@ -1,4 +1,12 @@ #!/usr/bin/env python +''' +Digit recognition from video. + +Run digits.py before, to train and save the SVM. + +Usage: + digits_video.py [{camera_id|video_file}] +''' # Python 2/3 compatibility from __future__ import print_function @@ -28,11 +36,7 @@ def main(): print('"%s" not found, run digits.py first' % classifier_fn) return - if True: - model = cv.ml.SVM_load(classifier_fn) - else: - model = cv.ml.SVM_create() - model.load_(classifier_fn) #Known bug: https://github.com/opencv/opencv/issues/4969 + model = cv.ml.SVM_load(classifier_fn) while True: _ret, frame = cap.read() -- 2.7.4