#!/usr/bin/env python
# (C) 2000 Huaiyu Zhu <hzhu@users.sourceforge.net>. Licence: GPL
# $Id: test_kalman.py,v 1.4 2001/08/26 12:40:44 hzhu Exp $
"""
test_kalman.py - testing class LinearSystem with kalman filtering
System description:
state: x+ = A*x + u + w
measurement: z = H*x + v
Much of the code is for recording history and displaying.
"""
from MatPy.Matrix import eye,zeros,norm,Matrix_cr,Matrix_r
from MatPy.Stats.distribs import randn
from MatPy.DynSys.kalman import LinearSystem
from MatPy.efuncs import triu
from MatPy.Graphics.trajplot import TrajPlot
class RandomWalk(LinearSystem):
    """RandomWalk: Wiener processes with given kernels.

    A LinearSystem whose state stacks n positions and n velocities
    (state dimension 2*n) and whose measurement observes the positions
    only (measurement dimension n).  Besides the system parameters, the
    class records the estimation error at every step and plots the
    trajectories of the true state and of the estimate.

    Note: Matrix, Gplot and wait are looked up as module globals when
    add_history / show_history run; in this file they are imported
    below the class definition, so the methods are usable only after
    the whole module has been executed.
    """

    def init_params(self, n):
        """Set the system parameters and return self.

        n : dimension of positions (the state has dimension 2*n).

        Attributes set:
            A      = [I, I*dt ; 0, I]  state (p, v) transition,
                     with time interval dt = 0.04
            Q2, Q  : random block-diagonal factor q and Q = q * q.T
                     (presumably the covariance of the state noise w)
            H      = [I, 0]  measurement matrix, picks the positions
            R2, R  : random factor r and R = r * r.T
                     (presumably the covariance of the measurement
                     noise v)
            RI     = R.I (presumably the inverse of R)
            n, m   : state dimension (2*n) and measurement dimension (n)
        """
        zz = zeros((n, n))
        ee = eye(n)
        dt = 0.04

        # Positions advance by velocity*dt; velocities carry over.
        self.A = Matrix_cr([[ee, ee*dt], [zz, ee]])

        # The two randn calls are kept in this order so that the sequence
        # of random draws is the same as before.
        self.Q2 = q = Matrix_cr([[randn((n, n))/n, zz],
                                 [zz, randn((n, n))/n]])
        self.Q = q * q.T

        self.H = Matrix_r([ee, zz])

        self.R2 = r = randn((n, n)) * .4
        self.R = r * r.T
        self.RI = self.R.I

        self.n = n*2
        self.m = n
        return self

    def init_history(self):
        """Initialize the history of the system.

        Creates the trajectory plot (two traces: "System" and
        "Estimate"), resets the list of recorded errors and prints the
        column header for the per-step output of add_history.
        """
        self.g = g = TrajPlot(2, names=["System", "Estimate"],
                              wait_time=0.1)
        g.title("Kalman filter: Trajectories of state and estimate")
        g.axis_equal = 1
        self.e = []
        # One pre-joined string in parentheses prints the same text whether
        # print is a statement (Python 2) or a function.
        print(" ".join((" "*10, " state ", " "*20, "estimate",
                        " "*10, " error ")))

    def add_history(self, x_sys, x_est):
        """Add one step to the history, plot it, then print it.

        x_sys : true state of the system at this step
        x_est : estimate of the state at this step

        The norm of (x_sys - x_est) is appended to self.e.
        """
        # Record the error of this step.  The printed value is taken from
        # this local rather than indexing self.e with a variable from the
        # caller's scope, so the method works regardless of who calls it.
        err = norm(x_sys - x_est)
        self.e.append(err)

        # record new data, and plot the first two axes
        xx = Matrix((x_sys[0], x_est[0]))
        yy = Matrix((x_sys[1], x_est[1]))
        self.g.plotadd(xx, yy)

        # print state, estimate and their distance
        print("%s %s %.4g" % (x_sys.T, x_est.T, err))

    def show_history(self, waittime):
        """Show the entire error history.

        Plots the recorded distances between states and estimates in a
        new Gplot window, then calls wait(waittime).

        waittime : passed unchanged to wait(); the driver script in this
                   file passes None when run as a program and 2 otherwise.
        """
        # plot errors
        g = Gplot()
        g.title("Kalman filter: Distances between states and estimates")
        g.plot([Matrix(self.e)])
        wait(waittime)
#------------------------------------------------------------------
""" Simulate linear system with Kalman filtering """
from MatPy.gplot import Gplot,wait
from MatPy.Matrix import Matrix
from MatPy.efuncs import abs
# Build the demo system with 2 position dimensions; init_params returns
# the instance itself, so configuring it in a separate statement is
# equivalent to chaining the call.
linsys = RandomWalk()
linsys.init_params(n=2)
linsys.init_state()
linsys.init_history()

# Random starting state with the same shape as the system's own state.
x = randn(linsys.x.shape)

# Run 20 steps: evolve the true state, measure it, filter the measurement,
# and record/plot the pair (true state, estimate).
for i in range(20):
    x, u = linsys.evolve(x)
    z = linsys.measure(x)
    tx = linsys.filter(z, u)
    linsys.add_history(x, tx)

# Wait time handed to show_history: 2 when this module is imported,
# None when it is run as a program.
waittime = 2
if __name__ == "__main__":
    waittime = None

linsys.show_history(waittime)
|