extended python interface for KalmanFilter
[profile/ivi/opencv.git] / samples / python2 / kalman.py
1 #!/usr/bin/python
2 """
3    Tracking of rotating point.
4    Rotation speed is constant.
5    Both state and measurements vectors are 1D (a point angle),
6    Measurement is the real point angle + gaussian noise.
7    The real and the estimated points are connected with yellow line segment,
8    the real and the measured points are connected with red line segment.
9    (if Kalman filter works correctly,
10     the yellow segment should be shorter than the red one).
11    Pressing any key (except ESC) will reset the tracking with a different speed.
12    Pressing ESC will stop the program.
13 """
14 import urllib2
15 import cv2
16 from math import cos, sin, sqrt
17 import sys
18 import numpy as np
19
20 if __name__ == "__main__":
21
22     img_height = 500
23     img_width = 500
24     img = np.array((img_height, img_width, 3), np.uint8)
25     kalman = cv2.KalmanFilter(2, 1, 0)
26     state = np.zeros((2, 1))  # (phi, delta_phi)
27     process_noise = np.zeros((2, 1))
28     measurement = np.zeros((1, 1))
29
30     code = -1L
31
32     cv2.namedWindow("Kalman")
33
34     while True:
35         state = 0.1 * np.random.randn(2, 1)
36
37         transition_matrix = np.array([[1., 1.], [0., 1.]])
38         kalman.setTransitionMatrix(transition_matrix)
39         measurement_matrix = 1. * np.ones((1, 2))
40         kalman.setMeasurementMatrix(measurement_matrix)
41
42         process_noise_cov = 1e-5
43         kalman.setProcessNoiseCov(process_noise_cov * np.eye(2))
44
45         measurement_noise_cov = 1e-1
46         kalman.setMeasurementNoiseCov(measurement_noise_cov * np.ones((1, 1)))
47
48         kalman.setErrorCovPost(1. * np.ones((2, 2)))
49
50         kalman.setStatePost(0.1 * np.random.randn(2, 1))
51
52         while True:
53             def calc_point(angle):
54                 return (np.around(img_width/2 + img_width/3*cos(angle), 0).astype(int),
55                          np.around(img_height/2 - img_width/3*sin(angle), 1).astype(int))
56
57             state_angle = state[0, 0]
58             state_pt = calc_point(state_angle)
59
60             prediction = kalman.predict()
61             predict_angle = prediction[0, 0]
62             predict_pt = calc_point(predict_angle)
63
64
65             measurement = measurement_noise_cov * np.random.randn(1, 1) 
66
67             # generate measurement
68             measurement = np.dot(measurement_matrix, state) + measurement
69
70             measurement_angle = measurement[0, 0]
71             measurement_pt = calc_point(measurement_angle)
72
73             # plot points
74             def draw_cross(center, color, d):
75                 cv2.line(img, (center[0] - d, center[1] - d),
76                               (center[0] + d, center[1] + d), color, 1, cv2.LINE_AA, 0)
77                 cv2.line(img, (center[0] + d, center[1] - d),
78                               (center[0] - d, center[1] + d), color, 1, cv2.LINE_AA, 0)
79
80             img = np.zeros((img_height, img_width, 3), np.uint8)
81             draw_cross(np.int32(state_pt), (255, 255, 255), 3)
82             draw_cross(np.int32(measurement_pt), (0, 0, 255), 3)
83             draw_cross(np.int32(predict_pt), (0, 255, 0), 3)
84
85             cv2.line(img, state_pt, measurement_pt, (0, 0, 255), 3, cv2.LINE_AA, 0)
86             cv2.line(img, state_pt, predict_pt, (0, 255, 255), 3, cv2.LINE_AA, 0)
87
88             kalman.correct(measurement)
89
90             process_noise = process_noise_cov * np.random.randn(2, 1)
91             
92             state = np.dot(transition_matrix, state) + process_noise
93
94             cv2.imshow("Kalman", img)
95
96             code = cv2.waitKey(100) % 0x100
97             if code != -1:
98                 break
99
100         if code in [27, ord('q'), ord('Q')]:
101             break
102
103     cv2.destroyWindow("Kalman")