Skip to content

Commit

Permalink
separate functions for correct and predict
Browse files Browse the repository at this point in the history
  • Loading branch information
baggepinnen committed Dec 6, 2024
1 parent 25441ce commit c262bd5
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions src/ukf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ end

abstract type AbstractUnscentedKalmanFilter <: AbstractKalmanFilter end

@with_kw mutable struct UnscentedKalmanFilter{IPD,IPM,AUGD,AUGM,DT,MT,R1T,R2T,D0T,XD,XD0,XM,Y,XT,RT,P,RJ,MET,CT,CCT,IT} <: AbstractUnscentedKalmanFilter
@with_kw mutable struct UnscentedKalmanFilter{IPD,IPM,AUGD,AUGM,DT,MT,R1T,R2T,D0T,XD,XD0,XM,Y,XT,RT,P,RJ,SMT,SCT,MMT,MCT,IT} <: AbstractUnscentedKalmanFilter
dynamics::DT
measurement::MT
R1::R1T
Expand All @@ -54,9 +54,10 @@ abstract type AbstractUnscentedKalmanFilter <: AbstractKalmanFilter end
nu::Int
p::P
reject::RJ = nothing
mean::MET = safe_mean
cov::CT = safe_cov
cross_cov::CCT = cross_cov
state_mean::SMT = safe_mean
state_cov::SCT = safe_cov
measurement_mean::MMT = safe_mean
measurement_cov::MCT = cross_cov
innovation::IT = -
end

Expand Down Expand Up @@ -110,13 +111,14 @@ For problems with challenging dynamics, a mechanism for rejection of sigma point
# Custom mean innovation functions
By default, standard arithmetic mean and `e(y, yh) = y - yh` are used as mean and innovation functions.
By passing the keyword arguments `mean`, `cov`, `cross_cov` and `innovation`, you may override those for use in situations where the state lives on a manifold. These functions must take the following signatures
- `mean(::AbstractVector{<:AbstractVector})`
- `cov(xs::AbstractVector{<:AbstractVector}, m = mean(xs))` where the first argument represent state sigma points and the second argument, which must be optional, represents the mean of those points.
- `cross_cov(xs::AbstractVector{<:AbstractVector}, x::AbstractVector, ys::AbstractVector{<:AbstractVector}, y::AbstractVector)` where the arguments represents (state sigma points, mean state, output sigma points, mean output)
By passing the keyword arguments `state_mean`, `state_cov`, `measurement_mean`, `measurement_cov` and `innovation`, you may override those for use in situations where the state lives on a manifold. These functions must take the following signatures
- `state_mean(xs::AbstractVector{<:AbstractVector})` computes the mean of the vector of vectors of state sigma points.
- `state_cov(xs::AbstractVector{<:AbstractVector}, m = mean(xs))` where the first argument represent state sigma points and the second argument, which must be optional, represents the mean of those points. The function should return the covariance matrix of the state sigma points.
- `measurement_mean(ys::AbstractVector{<:AbstractVector})` computes the mean of the vector of vectors of output sigma points.
- `measurement_cov(xs::AbstractVector{<:AbstractVector}, x::AbstractVector, ys::AbstractVector{<:AbstractVector}, y::AbstractVector)` where the arguments represents (state sigma points, mean state, output sigma points, mean output). The function should return the **cross-covariance** matrix between the state and output sigma points.
- `innovation(y::AbstractVector, yh::AbstractVector)` where the arguments represent (measured output, predicted output)
"""
function UnscentedKalmanFilter{IPD,IPM,AUGD,AUGM}(dynamics,measurement,R1,R2,d0=SimpleMvNormal(R1); Ts = 1.0, p = NullParameters(), nu::Int, ny::Int, reject=nothing, mean=safe_mean, cov=safe_cov, cross_cov=cross_cov, innovation=-) where {IPD,IPM,AUGD,AUGM}
function UnscentedKalmanFilter{IPD,IPM,AUGD,AUGM}(dynamics,measurement,R1,R2,d0=SimpleMvNormal(R1); Ts = 1.0, p = NullParameters(), nu::Int, ny::Int, reject=nothing, state_mean=safe_mean, state_cov=safe_cov, measurement_mean=safe_mean, measurement_cov=cross_cov, innovation=-) where {IPD,IPM,AUGD,AUGM}
nx = length(d0)
nw = size(R1, 1) # nw may be smaller than nx for augmented dynamics
ne = size(R2, 1)
Expand Down Expand Up @@ -166,8 +168,8 @@ function UnscentedKalmanFilter{IPD,IPM,AUGD,AUGM}(dynamics,measurement,R1,R2,d0=
x0 = convert_x0_type(d0.μ)
UnscentedKalmanFilter{IPD,IPM,AUGD,AUGM, typeof(dynamics), typeof(measurement), typeof(R1), typeof(R2), typeof(d0),
typeof(xsd), typeof(xsd0), typeof(xsm), typeof(ys),
typeof(x0), typeof(R), typeof(p), typeof(reject), typeof(mean), typeof(cov), typeof(cross_cov), typeof(innovation)}(
dynamics,measurement,R1,R2, d0, xsd,xsd0,xsm,ys, x0, R, 0, Ts, ny, nu, p, reject, mean, cov, cross_cov, innovation)
typeof(x0), typeof(R), typeof(p), typeof(reject), typeof(state_mean), typeof(state_cov), typeof(measurement_mean), typeof(measurement_cov), typeof(innovation)}(
dynamics,measurement,R1,R2, d0, xsd,xsd0,xsm,ys, x0, R, 0, Ts, ny, nu, p, reject, state_mean, state_cov, measurement_mean, measurement_cov, innovation)
end

function UnscentedKalmanFilter(dynamics,measurement,args...; kwargs...)
Expand All @@ -188,7 +190,7 @@ dynamics(kf::AbstractUnscentedKalmanFilter) = kf.dynamics
@inline has_ip(fun) = hasmethod(fun, Tuple{AbstractArray,AbstractArray,AbstractArray,AbstractArray,Real})

function predict!(ukf::UnscentedKalmanFilter{IPD,IPM,AUGD,AUGM}, u, p = parameters(ukf), t::Real = index(ukf)*ukf.Ts;
R1 = get_mat(ukf.R1, ukf.x, u, p, t), reject = ukf.reject, mean = ukf.mean, cov = ukf.cov) where {IPD,IPM,AUGD,AUGM}
R1 = get_mat(ukf.R1, ukf.x, u, p, t), reject = ukf.reject, mean = ukf.state_mean, cov = ukf.state_cov) where {IPD,IPM,AUGD,AUGM}
@unpack dynamics,measurement,x,xsd,xsd0,R = ukf
# xtyped = eltype(xsd)(x)
nx = length(x)
Expand Down Expand Up @@ -297,7 +299,7 @@ function safe_cov(xs::Vector{<:SVector}, m = safe_mean(xs))
end

function correct!(ukf::UnscentedKalmanFilter{IPD,IPM,AUGD,AUGM}, u, y, p=parameters(ukf), t::Real = index(ukf)*ukf.Ts;
R2 = get_mat(ukf.R2, ukf.x, u, p, t), mean = ukf.mean, cov = ukf.cov, cross_cov = ukf.cross_cov, innovation = ukf.innovation) where {IPD,IPM,AUGD,AUGM}
R2 = get_mat(ukf.R2, ukf.x, u, p, t), mean = ukf.measurement_mean, measurement_cov = ukf.measurement_cov, innovation = ukf.innovation) where {IPD,IPM,AUGD,AUGM}
(; measurement,x,xsm,ys,R,R1) = ukf
nx = length(x)
L = length(xsm[1])
Expand All @@ -310,7 +312,7 @@ function correct!(ukf::UnscentedKalmanFilter{IPD,IPM,AUGD,AUGM}, u, y, p=paramet
sigmapoints_c!(ukf)
propagate_sigmapoints_c!(ukf, u, p, t)
ym = mean(ys)
C = cross_cov(xsm, x, ys, ym)
C = measurement_cov(xsm, x, ys, ym)
e = innovation(y, ym)
S = compute_S(ukf)
Sᵪ = cholesky(Symmetric(S); check=false)
Expand All @@ -323,6 +325,11 @@ function correct!(ukf::UnscentedKalmanFilter{IPD,IPM,AUGD,AUGM}, u, y, p=paramet
(; ll, e, S, Sᵪ, K)
end

"""
cross_cov(xsm, x, ys, y)
Default `measurement_cov` function for `UnscentedKalmanFilter`. Computes the cross-covariance between the state and output sigma points.
"""
function cross_cov(xsm, x, ys, y)
T = eltype(x)
nx = length(x)
Expand Down

0 comments on commit c262bd5

Please sign in to comment.