Skip to content

Commit

Permalink
allow multiple measurement models in UKF
Browse files Browse the repository at this point in the history
  • Loading branch information
baggepinnen committed Dec 11, 2024
1 parent c262bd5 commit 21162d3
Show file tree
Hide file tree
Showing 10 changed files with 425 additions and 170 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down Expand Up @@ -46,6 +47,7 @@ MonteCarloMeasurements = "1"
PDMats = "0.10, 0.11"
Parameters = "0.12"
Plots = "1"
PlutoUI = "0.7.60"
Polyester = "0.6, 0.7"
Printf = "1.7"
Random = "1.7"
Expand Down
6 changes: 3 additions & 3 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
julia = "1.6"
ComponentArrays = "0.15"
DifferentiationInterface = "0.6"
Lux = "1.3"
SparseConnectivityTracer = "0.6"
SparseMatrixColorings = "0.4"
DifferentiationInterface = "0.6"
ComponentArrays = "0.15"
julia = "1.6"
11 changes: 11 additions & 0 deletions docs/src/neural_network.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,17 @@ plot(
DisplayAs.PNG(Plots.current()) # hide
```

## Smoothing
```@example ADAPTIVE_NN
@time xTe,RTe = smooth(sole, ekf)
@time xTu,RTu = smooth(solu, ukf)
plot(
plot(0:Ts:4000, reduce(hcat, xTe)'[:, nx+1:end], title="EKF parameters", c=1, alpha=0.2),
plot(0:Ts:4000, reduce(hcat, xTu)'[:, nx+1:end], title="UKF parameters", c=1, alpha=0.2),
legend = false,
)
```

## Benchmarking
The neural network used in this example has
```@example ADAPTIVE_NN
Expand Down
1 change: 1 addition & 0 deletions src/LowLevelParticleFilters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ abstract type AbstractFilter end

include("PFtypes.jl")
include("solutions.jl")
include("measurement_model.jl")
include("kalman.jl")
include("ukf.jl")
include("filtering.jl")
Expand Down
8 changes: 4 additions & 4 deletions src/ekf.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
abstract type AbstractExtendedKalmanFilter{IPD,IPM} <: AbstractKalmanFilter end
@with_kw struct ExtendedKalmanFilter{IPD, IPM, KF <: KalmanFilter, F, G, A, C} <: AbstractExtendedKalmanFilter{IPD,IPM}
struct ExtendedKalmanFilter{IPD, IPM, KF <: KalmanFilter, F, G, A, C} <: AbstractExtendedKalmanFilter{IPD,IPM}
kf::KF
dynamics::F
measurement::G
Expand Down Expand Up @@ -117,15 +117,15 @@ function predict!(kf::AbstractExtendedKalmanFilter{IPD}, u, p = parameters(kf),
kf.t += 1
end

function correct!(kf::AbstractExtendedKalmanFilter{<:Any, IPM}, u, y, p = parameters(kf), t::Real = index(kf); R2 = get_mat(kf.R2, kf.x, u, p, t)) where IPM
function correct!(kf::AbstractExtendedKalmanFilter{<:Any, IPM}, u, y, p = parameters(kf), t::Real = index(kf); R2 = get_mat(kf.R2, kf.x, u, p, t), measurement = kf.measurement) where IPM
@unpack x,R = kf
C = kf.Cjac(x, u, p, t)
if IPM
e = zeros(length(y))
kf.measurement(e, x, u, p, t)
measurement(e, x, u, p, t)
e .= y .- e
else
e = y .- kf.measurement(x, u, p, t)
e = y .- measurement(x, u, p, t)
end
S = symmetrize(C*R*C') + R2
Sᵪ = cholesky(Symmetric(S); check=false)
Expand Down
10 changes: 5 additions & 5 deletions src/kalman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function convert_x0_type(μ)
end
end

@with_kw mutable struct KalmanFilter{AT,BT,CT,DT,R1T,R2T,D0T,XT,RT,TS,P,αT} <: AbstractKalmanFilter
mutable struct KalmanFilter{AT,BT,CT,DT,R1T,R2T,D0T,XT,RT,TS,P,αT} <: AbstractKalmanFilter
A::AT
B::BT
C::CT
Expand All @@ -32,10 +32,10 @@ end
d0::D0T
x::XT
R::RT
t::Int = 0
Ts::TS = 1
p::P = NullParameters()
α::αT = 1.0
t::Int
Ts::TS
p::P
α::αT
end


Expand Down
145 changes: 145 additions & 0 deletions src/measurement_model.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
abstract type AbstractMeasurementModel end

struct ComponsiteMeasurementModel{M} <: AbstractMeasurementModel
models::M
end

struct UKFMeasurementModel{IPM,AUGM,MT,RT,IT,MET,CT,CCT,CAT} <: AbstractMeasurementModel
measurement::MT
R2::RT
ny::Int
ne::Int
innovation::IT
mean::MET
cov::CT
cross_cov::CCT
cache::CAT
end

UKFMeasurementModel{IPM,AUGM}(
measurement,
R2,
ny,
ne,
innovation,
mean,
cov,
cross_cov,
cache = nothing,
) where {IPM,AUGM} = UKFMeasurementModel{
IPM,
AUGM,
typeof(measurement),
typeof(R2),
typeof(innovation),
typeof(mean),
typeof(cov),
typeof(cross_cov),
cache,
}(
measurement,
R2,
ny,
ne,
innovation,
mean,
cov,
cross_cov,
cache,
)


function add_cache(model::UKFMeasurementModel{IPM,AUGM}, cache) where {IPM,AUGM}
UKFMeasurementModel{eltype(model.cache),IPM,AUGM}(
model.measurement,
model.R2,
model.ny,
model.ne,
model.innovation,
model.mean,
model.cov,
model.cross_cov,
cache,
)
end


function UKFMeasurementModel{T,IPM,AUGM}(
measurement,
R2;
nx,
ny,
ne = nothing,
innovation = -,
mean = safe_mean,
cov = safe_cov,
cross_cov = cross_cov,
static = nothing,
) where {T,IPM,AUGM}

ne = if ne === nothing
if !AUGM
0
elseif R2 isa AbstractArray
size(R2, 1)
else
error(
"The number of measurement noise variables, ne, can not be inferred from R2 when R2 is not an array, please provide the keyword argument `ne`.",
)
end
else
if AUGM && R2 isa AbstractArray && size(R2, 1) != ne
error(
"R2 must be square with size equal to the measurement vector length for non-augmented measurement",
)
end
end
if AUGM
L = nx + ne
else
L = nx
end
static2 = something(static, L < 50 && !IPM)
correct_sigma_point_cahce = SigmaPointCache{T}(nx, ne, ny, L, static2)
UKFMeasurementModel{
IPM,
AUGM,
typeof(measurement),
typeof(R2),
typeof(innovation),
typeof(mean),
typeof(cov),
typeof(cross_cov),
typeof(correct_sigma_point_cahce),
}(
measurement,
R2,
ny,
ne,
innovation,
mean,
cov,
cross_cov,
correct_sigma_point_cahce,
)
end



struct SigmaPointCache{X0, X1}
x0::X0
x1::X1
end

function SigmaPointCache{T}(nx, nw, ny, L, static) where T
if static
x0 = [@SVector zeros(T, nx + nw) for _ = 1:2L+1]
x1 = [@SVector zeros(T, ny) for _ = 1:2L+1]
else
x0 = [zeros(T, nx + nw) for _ = 1:2L+1]
x1 = [zeros(T, ny) for _ = 1:2L+1]
end
SigmaPointCache(x0, x1)
end

Base.eltype(spc::SigmaPointCache) = eltype(spc.x0)
Loading

0 comments on commit 21162d3

Please sign in to comment.