ete*_*ty1 2 numpy smoothing missing-data kalman-filter pykalman
我正在使用 pykalman 模块中的 KalmanFilter 并想知道它如何处理丢失的观察结果。根据文档:
在现实世界的系统中,传感器偶尔出现故障是很常见的。卡尔曼滤波器、卡尔曼平滑器和 EM 算法都可以处理这种情况。要使用它,只需在缺失的时间步长处对测量应用 NumPy 掩码:
from numpy import ma X = ma.array([1,2,3]) X 1 = ma.masked # 在时间步长 1 隐藏测量 kf.em(X).smooth(X)
我们可以平滑输入时间序列。由于这是一个“附加”功能,我认为它不是自动完成的;那么在变量中有 NaN 时的默认方法是什么?
这里解释了可能发生的理论方法;这也是pykalman所做的吗(在我看来这真的很棒):
让我们来看看源代码:
在filter_update函数中,pykalman 检查当前观察是否被屏蔽。
def filter_update(...)
# Make a masked observation if necessary
if observation is None:
n_dim_obs = observation_covariance.shape[0]
observation = np.ma.array(np.zeros(n_dim_obs))
observation.mask = True
else:
observation = np.ma.asarray(observation)
Run Code Online (Sandbox Code Playgroud)
它不会影响预测步骤。但是修正步骤有两个选择。它发生在_filter_correct函数中。
def _filter_correct(...)
if not np.any(np.ma.getmask(observation)):
# the normal Kalman Filter math
else:
n_dim_state = predicted_state_covariance.shape[0]
n_dim_obs = observation_matrix.shape[0]
kalman_gain = np.zeros((n_dim_state, n_dim_obs))
# !!!! the corrected state takes the result of the prediction !!!!
corrected_state_mean = predicted_state_mean
corrected_state_covariance = predicted_state_covariance
Run Code Online (Sandbox Code Playgroud)
正如你所看到的,这正是理论方法。
这是一个简短的示例和工作数据。
假设您有一个 GPS 接收器,并且您想在步行时跟踪自己。接收器具有良好的精度。为简化起见,假设您只往前走。
没有什么有趣的事情发生。由于良好的 GPS 信号,过滤器可以很好地估计您的位置。如果一段时间没有信号会怎样?
滤波器只能根据现有状态和系统动力学知识进行预测(见矩阵 Q)。随着每个预测步骤,不确定性增加。估计位置周围的 1-Sigma 范围变大。一旦再次出现新的观察结果,状态就会得到纠正。
这是代码和数据:
from pykalman import KalmanFilter
import numpy as np
import matplotlib.pyplot as plt
from numpy import ma
# enable or disable missing observations
use_mask = 1
# reading data (quick and dirty)
Time=[]
X=[]
for line in open('data/dataset_01.csv'):
f1, f2 = line.split(';')
Time.append(float(f1))
X.append(float(f2))
if (use_mask):
X = ma.asarray(X)
X[300:500] = ma.masked
# Filter Configuration
# time step
dt = Time[2] - Time[1]
# transition_matrix
F = [[1, dt, 0.5*dt*dt],
[0, 1, dt],
[0, 0, 1]]
# observation_matrix
H = [1, 0, 0]
# transition_covariance
Q = [[ 1, 0, 0],
[ 0, 1e-4, 0],
[ 0, 0, 1e-6]]
# observation_covariance
R = [0.04] # max error = 0.6m
# initial_state_mean
X0 = [0,
0,
0]
# initial_state_covariance
P0 = [[ 10, 0, 0],
[ 0, 1, 0],
[ 0, 0, 1]]
n_timesteps = len(Time)
n_dim_state = 3
filtered_state_means = np.zeros((n_timesteps, n_dim_state))
filtered_state_covariances = np.zeros((n_timesteps, n_dim_state, n_dim_state))
# Kalman-Filter initialization
kf = KalmanFilter(transition_matrices = F,
observation_matrices = H,
transition_covariance = Q,
observation_covariance = R,
initial_state_mean = X0,
initial_state_covariance = P0)
# iterative estimation for each new measurement
for t in range(n_timesteps):
if t == 0:
filtered_state_means[t] = X0
filtered_state_covariances[t] = P0
else:
filtered_state_means[t], filtered_state_covariances[t] = (
kf.filter_update(
filtered_state_means[t-1],
filtered_state_covariances[t-1],
observation = X[t])
)
position_sigma = np.sqrt(filtered_state_covariances[:, 0, 0]);
# plot of the resulted trajectory
plt.plot(Time, filtered_state_means[:, 0], "g-", label="Filtered position", markersize=1)
plt.plot(Time, filtered_state_means[:, 0] + position_sigma, "r--", label="+ sigma", markersize=1)
plt.plot(Time, filtered_state_means[:, 0] - position_sigma, "r--", label="- sigma", markersize=1)
plt.grid()
plt.legend(loc="upper left")
plt.xlabel("Time (s)")
plt.ylabel("Position (m)")
plt.show()
Run Code Online (Sandbox Code Playgroud)
更新
如果您屏蔽更长的时间 (300:700),它看起来会更有趣。
如您所见,位置返回。它的发生是因为转换矩阵 F 绑定了位置、速度和加速度的预测。
如果你看一下速度状态,它解释了下降的位置。
在 300 秒的时间点,加速度冻结。速度以恒定斜率下降并越过 0 值。在此时间点之后,位置必须返回。
要绘制速度,请使用以下代码:
velocity_sigma = np.sqrt(filtered_state_covariances[:, 1, 1]);
# plot of the estimated velocity
plt.plot(Time, filtered_state_means[:, 1], "g-", label="Filtered velocity", markersize=1)
plt.plot(Time, filtered_state_means[:, 1] + velocity_sigma, "r--", label="+ sigma", markersize=1)
plt.plot(Time, filtered_state_means[:, 1] - velocity_sigma, "r--", label="- sigma", markersize=1)
plt.grid()
plt.legend(loc="upper left")
plt.xlabel("Time (s)")
plt.ylabel("Velocity (m/s)")
plt.show()
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
3276 次 |
最近记录: |