Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
FredericWantiez committed Oct 9, 2024
1 parent 6723a94 commit 1452069
Showing 1 changed file with 25 additions and 27 deletions.
52 changes: 25 additions & 27 deletions examples/particle-mcmc/lorenz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,45 +11,43 @@ include("particles.jl")
include("resamplers.jl")
include("simple-filters.jl")


Base.@kwdef struct Parameters{T<:Real}
β::T = 8/3
ρ::T = 28.
σ::T = 10.
ν::T = 1. # Obs noise variance
dt::T = 0.025 # Time step
β::T = 8 / 3
ρ::T = 28.0
σ::T = 10.0
ν::T = 1.0 # Obs noise variance
dt::T = 0.025 # Time step
end

function lorenz!(du, u, p::Parameters, t)
@unpack β, ρ, σ = p
du[1] = σ * (u[2] - u[1])
du[2] = u[1] *- u[3]) - u[2]
du[3] = u[1] * u[2] - β * u[3]
function lorenz!(du, u, p::Parameters, t)
@unpack β, ρ, σ = p
du[1] = σ * (u[2] - u[1])
du[2] = u[1] *- u[3]) - u[2]
return du[3] = u[1] * u[2] - β * u[3]
end


struct LatentNoiseProcess{T} <: LatentDynamics{Vector{T}}
σ::AbstractPDMat{T}
dt::T
integrator
σ::AbstractPDMat{T}
dt::T
integrator
end

struct ObservationNoiseProcess{T} <: ObservationProcess{Vector{T}}
σ::AbstractPDMat{T}
σ::AbstractPDMat{T}
end

function SSMProblems.distribution(dyn::LatentNoiseProcess, step::Integer, prev_state, extra)
reinit!(dyn.integrator, prev_state)
step!(dyn.integrator, dyn.dt, true)
return MvNormal(dyn.integrator.u, dyn.σ)
reinit!(dyn.integrator, prev_state)
step!(dyn.integrator, dyn.dt, true)
return MvNormal(dyn.integrator.u, dyn.σ)
end

function SSMProblems.distribution(dyn::LatentNoiseProcess, extra)
return MvNormal([1; 0; 0], dyn.σ)
return MvNormal([1; 0; 0], dyn.σ)
end

function SSMProblems.distribution(obs::ObservationNoiseProcess, step::Integer, state, extra)
return MvNormal(state, obs.σ * I)
return MvNormal(state, obs.σ * I)
end

# Simulate some data
Expand Down Expand Up @@ -78,7 +76,7 @@ filter = BF(Np; threshold=1.0, resampler=Systematic());
sparse_ancestry = AncestorCallback(eltype(model.dyn), filter.N, 1.0);
tree, llbf = sample(rng, model, filter, y; callback=sparse_ancestry);
lineage = get_ancestry(sparse_ancestry.tree)

# Fancy 3D plot
# fig = Figure()
# lines(fig[1, 1], hcat(x0, x...))
Expand All @@ -88,8 +86,8 @@ lineage = get_ancestry(sparse_ancestry.tree)

fig = Figure()
for i in eachindex(first(x))
lines(fig[i, 1], hcat(x0, x...)[i, :])
for path in lineage
lines!(fig[i, 1], hcat(path...)[i, :], color=:black, alpha=0.1)
end
end
lines(fig[i, 1], hcat(x0, x...)[i, :])
for path in lineage
lines!(fig[i, 1], hcat(path...)[i, :]; color=:black, alpha=0.1)
end
end

0 comments on commit 1452069

Please sign in to comment.