Interface Changes for Use in Filtering #56

Merged: 33 commits, merged Oct 21, 2024 (diff below shows changes from 17 commits)

Commits:
440252c
added basic particle methods and filters
charlesknipp Aug 9, 2024
9fd4453
added qualifiers
charlesknipp Aug 12, 2024
3fd90c4
added parameter priors
charlesknipp Aug 12, 2024
884b9e3
Merge branch 'main' into ck/particle-methods
charlesknipp Aug 30, 2024
1def6a1
Merge branch 'main' into ck/particle-methods
charlesknipp Sep 24, 2024
a5a2e05
added adaptive resampling to bootstrap filter (WIP)
charlesknipp Sep 25, 2024
57da3ff
Julia formatter changes
charlesknipp Sep 25, 2024
dc713b0
Merge branch 'ck/particle-methods' of https://github.com/TuringLang/S…
charlesknipp Sep 25, 2024
b846fa4
changed eltype for <: StateSpaceModel
charlesknipp Sep 26, 2024
4263ae7
updated naming conventions
charlesknipp Sep 26, 2024
5a2aeb4
formatter
charlesknipp Sep 26, 2024
8db658b
fixed adaptive resampling
charlesknipp Sep 27, 2024
15dfa9f
added particle ancestry
charlesknipp Oct 1, 2024
7e3c93d
formatter issues
charlesknipp Oct 1, 2024
f905a41
fixed metropolis and added rejection resampler
charlesknipp Oct 1, 2024
8ac1455
Keep track of free indices using stack
THargreaves Oct 2, 2024
f11a63e
updated particle types and organized directory
charlesknipp Oct 2, 2024
1fa3c93
weakened SSM type parameter assertions
charlesknipp Oct 4, 2024
8cb4338
improved particle state containment and resampling
charlesknipp Oct 4, 2024
73dd433
added hacky sparse ancestry to example
charlesknipp Oct 5, 2024
f71ab32
fixed RNG in rejection resampling
charlesknipp Oct 6, 2024
25cebf4
improved callbacks and resamplers
charlesknipp Oct 6, 2024
c729879
formatting
charlesknipp Oct 6, 2024
d13c80c
added conditional SMC
charlesknipp Oct 8, 2024
856cebb
improved linear model type structure
charlesknipp Oct 8, 2024
d7daf93
formatter
charlesknipp Oct 8, 2024
b29ba60
replaced extra with kwargs
charlesknipp Oct 11, 2024
ece40fa
formatter
charlesknipp Oct 11, 2024
75fdf2c
migrated filtering code
charlesknipp Oct 11, 2024
2cc4016
Add unittests for new interface
THargreaves Oct 14, 2024
c76278f
Update documentation to match kwargs
THargreaves Oct 18, 2024
04f9808
Rename extras/kwargs docs file
THargreaves Oct 18, 2024
5a8bba2
remove redundant forward simulations
charlesknipp Oct 18, 2024
12 changes: 12 additions & 0 deletions examples/particle-mcmc/Project.toml
@@ -0,0 +1,12 @@
[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
180 changes: 180 additions & 0 deletions examples/particle-mcmc/particles.jl
@@ -0,0 +1,180 @@
using DataStructures: Stack

## PARTICLES ###############################################################################

abstract type AbstractParticleContainer{T} end

"""
store!(particles, new_states, [idx])

update the state component of the particle container, with optional parent indices supplied
for use in ancestry storage.
"""
function store! end

"""
reset_weights!(particles)

in-place method to reset the log weights of the particle cloud to zero; typically called
following a resampling step.
"""
function reset_weights! end

mutable struct ParticleContainer{T,WT<:Real} <: AbstractParticleContainer{T}
vals::Vector{T}
log_weights::Vector{WT}
end

Base.collect(pc::ParticleContainer) = pc.vals
Base.length(pc::ParticleContainer) = length(pc.vals)
Base.keys(pc::ParticleContainer) = LinearIndices(pc.vals)

# not sure if this is kosher, since it doesn't follow the convention of Base.getindex
Base.@propagate_inbounds Base.getindex(pc::ParticleContainer, i::Int) = pc.vals[i]
Base.@propagate_inbounds Base.getindex(pc::ParticleContainer, i::Vector{Int}) = pc.vals[i]

function store!(pc::ParticleContainer, new_states, idx...; kwargs...)
Review comment (Member):
We probably don't need the parent indices, that's up to the storage implementation to decide how they store ancestry paths:

Maybe something like that is enough?

""" Store new generation in the container
"""
store!(pc::AbstractParticleContainer, new_generation)

""" Get last generation
"""
load(pc::AbstractParticleContainer)

""" Get (log)-weights from storage
"""
weights(pc::AbstractParticleContainer)
logweights(pc::AbstractParticleContainer)

Review comment (Collaborator):

> We probably don't need the parent indices, that's up to the storage implementation to decide how they store ancestry paths

I might be misunderstanding you but I think we do, e.g. for smoothing.

If we just pass the particle container multiple vectors of states, it has no idea what the genealogy is so you can't perform naive smoothing on it by back-tracing ancestry.

> Maybe something like that is enough?

I'm a bit unsure about this too. It feels like the filter depends a bit too much on the implementation of the storage, which should ideally be independent.

My instinct is for the filter to maintain a minimal collection of variables for it to run. I think this generally would just be the current state (represented here as a combination of x values and log weights). It would update these independently of the storage.

Then at each time step, it passes what it currently has to the storage object which can do what it wants with it. The key idea is that the filter should still work even in the extreme case that the storage throws everything away.

Review comment (Collaborator):

Actually I might be a bit confused here.

@charlesknipp, what was the intention of ParticleContainer? Is it a type of particle storage (one that only remembers the current state) or just a representation of the current state?

If the latter, I don't think it and AncestryContainer should be subtypes of the same type, as they do different things.

Review comment (charlesknipp, Collaborator, Author):

I currently use ParticleContainer as a means of storage to preserve the weighted nature of the sample at step t, though I wonder whether we could move ancestry storage to a callback, which would be very elegant if possible.

Review comment (@THargreaves, Collaborator, Oct 3, 2024):

Ah okay, I'm following now.

I didn't make this clear, but from my point of view store is the callback.

But rather than defining it as a simple function, it is tied to a storage container.

So the particle filter has a state::ParticleState(xs, log_ws) (currently called ParticleContainer but just to make the difference really clear) which it updates either in-place or by replacing with a new ParticleState and then this is passed to store! after each step to do what it pleases.

And store! can dispatch on SparseAncestryStorage <: AbstractParticleStorage <: AbstractStorage or something like that, which is the Lawrence Murray algorithm you implemented.

Review comment (charlesknipp, Collaborator, Author):

Exactly, right now store! is just a means of updating an AbstractStorage (or AbstractParticleContainer in my code).

I really like the idea you present of separating particle storage and particle state, although that would imply needing to store the ancestry indices in the particle state (which would be necessary for sparse ancestry storage). I'm not 100% sure of the details yet, but I think I can make this look pretty elegant.

setindex!(pc.vals, new_states, eachindex(pc))
return pc
end
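
The state/storage split discussed in the review thread above could look roughly like the following. This is a hypothetical sketch, not code from this PR: the names ParticleState, AbstractStorage, NullStorage, and DenseHistory are all illustrative.

```julia
# Hypothetical sketch of separating filter state from particle storage.
abstract type AbstractStorage end

# the filter's minimal working state: particle values plus log-weights
struct ParticleState{T,WT<:Real}
    xs::Vector{T}
    log_ws::Vector{WT}
end

# a storage that throws everything away — the filter must still run
struct NullStorage <: AbstractStorage end
store!(::NullStorage, ::ParticleState, ancestors) = nothing

# a storage that keeps every generation along with its ancestor indices
struct DenseHistory{T} <: AbstractStorage
    generations::Vector{Vector{T}}
    ancestors::Vector{Vector{Int}}
end
DenseHistory{T}() where {T} = DenseHistory{T}(Vector{T}[], Vector{Int}[])

function store!(h::DenseHistory{T}, state::ParticleState{T}, ancestors) where {T}
    push!(h.generations, copy(state.xs))
    push!(h.ancestors, collect(Int, ancestors))
    return h
end
```

Under this design the filter updates its ParticleState in place and calls store! once per step; a sparse-ancestry storage would simply add another store! method.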

function reset_weights!(pc::ParticleContainer{T,WT}) where {T,WT<:Real}
fill!(pc.log_weights, zero(WT))
return pc.log_weights
end
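
A quick usage sketch of the container API above; the definitions are repeated in trimmed form so the snippet runs standalone.

```julia
# trimmed repeat of the ParticleContainer definitions above
mutable struct ParticleContainer{T,WT<:Real}
    vals::Vector{T}
    log_weights::Vector{WT}
end

Base.length(pc::ParticleContainer) = length(pc.vals)
Base.getindex(pc::ParticleContainer, i) = pc.vals[i]

function reset_weights!(pc::ParticleContainer{T,WT}) where {T,WT<:Real}
    fill!(pc.log_weights, zero(WT))
    return pc.log_weights
end

pc = ParticleContainer([0.1, 0.2, 0.3], log.([0.5, 0.3, 0.2]))
pc[[1, 1, 2]]       # vector indexing duplicates particles, as after resampling
reset_weights!(pc)  # log-weights are zeroed following a resampling step
```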

## JACOB-MURRAY PARTICLE STORAGE ###########################################################

Base.append!(s::Stack, a::AbstractVector) = map(x -> push!(s, x), a)

mutable struct ParticleTree{T}
states::Vector{T}
parents::Vector{Int64}
leaves::Vector{Int64}
offspring::Vector{Int64}
free_indices::Stack{Int64}

function ParticleTree(states::Vector{T}, M::Integer) where {T}
nodes = Vector{T}(undef, M)
initial_free_indices = Stack{Int64}()
append!(initial_free_indices, M:-1:(length(states) + 1))
@inbounds nodes[1:length(states)] = states
return new{T}(
nodes, zeros(Int64, M), 1:length(states), zeros(Int64, M), initial_free_indices
)
end
end

Base.length(tree::ParticleTree) = length(tree.states)
Base.keys(tree::ParticleTree) = LinearIndices(tree.states)

function prune!(tree::ParticleTree, offspring::Vector{Int64})
## insert new offspring counts
setindex!(tree.offspring, offspring, tree.leaves)

## update each branch
@inbounds for i in eachindex(offspring)
j = tree.leaves[i]
while (j > 0) && (tree.offspring[j] == 0)
push!(tree.free_indices, j)
j = tree.parents[j]
if j > 0
tree.offspring[j] -= 1
end
end
end
end

function insert!(
tree::ParticleTree{T}, states::Vector{T}, a::AbstractVector{Int64}
) where {T}
## parents of new generation
parents = getindex(tree.leaves, a)

## ensure there are enough dead branches
if (length(tree.free_indices) < length(a))
@debug "expanding tree"
expand!(tree)
end

## find places for new states
@inbounds for i in eachindex(states)
tree.leaves[i] = pop!(tree.free_indices)
end

## insert new generation and update parent child relationships
setindex!(tree.states, states, tree.leaves)
setindex!(tree.parents, parents, tree.leaves)
return tree
end

function expand!(tree::ParticleTree)
M = length(tree)
resize!(tree.states, 2 * M)

# new allocations must be zero valued, this is not a perfect solution
tree.parents = [tree.parents; zero(tree.parents)]
tree.offspring = [tree.offspring; zero(tree.offspring)]
append!(tree.free_indices, (2 * M):-1:(M + 1))
return tree
end

function get_offspring(a::AbstractVector{Int64})
offspring = zero(a)
for i in a
offspring[i] += 1
end
return offspring
end
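
A quick worked example of the offspring count (the definition is repeated so the snippet runs standalone): with ancestor indices [1, 1, 3], particle 1 is chosen twice, particle 2 dies out, and particle 3 survives once.

```julia
# repeated from above so the snippet is standalone
function get_offspring(a::AbstractVector{Int64})
    offspring = zero(a)
    for i in a
        offspring[i] += 1
    end
    return offspring
end

get_offspring([1, 1, 3])  # returns [2, 0, 1]
```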

## FILTERING WITH ANCESTRY #################################################################

mutable struct AncestryContainer{T,WT<:Real} <: AbstractParticleContainer{T}
tree::ParticleTree{T}
log_weights::Vector{WT}

function AncestryContainer(
initial_states::Vector{T}, log_weights::Vector{WT}, C::Int64=1
) where {T,WT<:Real}
N = length(log_weights)
M = floor(C * N * log(N))
tree = ParticleTree(initial_states, Int64(M))
return new{T,WT}(tree, log_weights)
end
end
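
The tree is allocated with M = ⌊C·N·log N⌋ slots, which appears to follow the Jacob–Murray path-storage sizing referenced in the section header above. A quick numeric check of the constructor's formula (N and C chosen for illustration):

```julia
# same sizing as the AncestryContainer constructor above
N, C = 1024, 1
M = Int64(floor(C * N * log(N)))
```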

function Base.collect(ac::AncestryContainer)
return getindex(ac.tree.states, ac.tree.leaves)
end

function Base.getindex(ac::AncestryContainer, a::AbstractVector{Int64})
return getindex(ac.tree.states, getindex(ac.tree.leaves, a))
end

function reset_weights!(ac::AncestryContainer{T,WT}) where {T,WT<:Real}
fill!(ac.log_weights, zero(WT))
return ac.log_weights
end

function store!(ac::AncestryContainer, new_states, idx)
prune!(ac.tree, get_offspring(idx))
insert!(ac.tree, new_states, idx)
return ac
end

# start at each leaf and retrace its steps to the root node
function get_ancestry(tree::ParticleTree{T}) where {T}
paths = Vector{Vector{T}}(undef, length(tree.leaves))
@inbounds for (k, i) in enumerate(tree.leaves)
j = tree.parents[i]
xi = tree.states[i]

xs = [xi]
while j > 0
push!(xs, tree.states[j])
j = tree.parents[j]
end
paths[k] = reverse(xs)
end
return paths
end
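
The leaf-to-root walk in get_ancestry can be illustrated on a flat toy genealogy; the states and parents below are made up for illustration, with 0 marking the root's parent.

```julia
# toy genealogy: node i's parent is parents[i]; 0 marks a root
states  = [10.0, 20.0, 30.0, 40.0, 50.0]
parents = [0, 1, 1, 3, 3]  # node 1 is the root; 4 and 5 descend through 3

# the same traversal as get_ancestry, for a single leaf
function trace_path(states, parents, leaf::Int)
    xs = [states[leaf]]
    j = parents[leaf]
    while j > 0
        push!(xs, states[j])
        j = parents[j]
    end
    return reverse(xs)
end

trace_path(states, parents, 5)  # returns [10.0, 30.0, 50.0]
```
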
84 changes: 84 additions & 0 deletions examples/particle-mcmc/resamplers.jl
@@ -0,0 +1,84 @@
using Random
using Distributions

function multinomial_resampling(
rng::AbstractRNG, weights::AbstractVector{WT}, n::Int64=length(weights); kwargs...
) where {WT<:Real}
return rand(rng, Distributions.Categorical(weights), n)
end

function systematic_resampling(
rng::AbstractRNG, weights::AbstractVector{WT}, n::Int64=length(weights); kwargs...
) where {WT<:Real}
# pre-calculations
@inbounds v = n * weights[1]
u = oftype(v, rand(rng))

# initialize sampling algorithm
a = Vector{Int64}(undef, n)
idx = 1

@inbounds for i in 1:n
while v < u
idx += 1
v += n * weights[idx]
end
a[i] = idx
u += one(u)
end

return a
end
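
One property worth noting: systematic resampling is an inverse-CDF scheme with a single stratified uniform, so index i receives either ⌊n·wᵢ⌋ or ⌈n·wᵢ⌉ copies. With equal weights the output is therefore deterministic regardless of the draw. The function is repeated in compact form so the snippet runs standalone:

```julia
using Random

# compact repeat of systematic_resampling above
function systematic_resampling(rng, weights, n::Int=length(weights))
    v = n * weights[1]
    u = oftype(v, rand(rng))
    a = Vector{Int}(undef, n)
    idx = 1
    for i in 1:n
        while v < u
            idx += 1
            v += n * weights[idx]
        end
        a[i] = idx
        u += one(u)
    end
    return a
end

systematic_resampling(MersenneTwister(42), [0.5, 0.5], 4)  # always [1, 1, 2, 2]
```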

# TODO: this should be done in the log domain and also parallelized
function metropolis_resampling(
rng::AbstractRNG,
weights::AbstractVector{WT},
n::Int64=length(weights);
ε::Float64=0.01,
kwargs...,
) where {WT<:Real}
# pre-calculations
β = mean(weights)
bins = Int64(cld(log(ε), log(1 - β)))

# initialize the algorithm
a = Vector{Int64}(undef, n)

@inbounds for i in 1:n
k = i
for _ in 1:bins
j = rand(rng, 1:n)
v = weights[j] / weights[k]
if rand(rng) ≤ v
k = j
end
end
a[i] = k
end

return a
end
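
One reading of the `bins` constant above (my interpretation, not stated in the PR): it is the smallest B with (1 − β)^B ≤ ε, i.e. B = ⌈log ε / log(1 − β)⌉, bounding the chance that a chain never accepts a move away from its starting index. A quick numeric check with an illustrative β:

```julia
# how many Metropolis steps are needed so that (1 - β)^B ≤ ε
ε = 0.01
β = 0.25  # illustrative mean weight
B = Int64(cld(log(ε), log(1 - β)))  # same formula as `bins` above

# B is the smallest integer satisfying the bound
(1 - β)^B ≤ ε < (1 - β)^(B - 1)
```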

# TODO: this should be done in the log domain and also parallelized
function rejection_resampling(
rng::AbstractRNG, weights::AbstractVector{WT}, n::Int64=length(weights); kwargs...
) where {WT<:Real}
# pre-calculations
max_weight = maximum(weights)

# initialize the algorithm
a = Vector{Int64}(undef, n)

@inbounds for i in 1:n
j = i
u = rand(rng)
while u > weights[j] / max_weight
            j = rand(rng, 1:n)
u = rand(rng)
end
a[i] = j
end

return a
end
97 changes: 97 additions & 0 deletions examples/particle-mcmc/script.jl
@@ -0,0 +1,97 @@
using AdvancedMH
using CairoMakie
using StatsBase: weights, mean

include("particles.jl")
include("resamplers.jl")
include("simple-filters.jl")

## FILTERING DEMONSTRATION #################################################################

# use a local level trend model
function simulation_model(σx²::T, σy²::T) where {T<:Real}
init = Gaussian(zeros(T, 2), PDMat(diagm(ones(T, 2))))
dyn = LinearGaussianLatentDynamics(T[1 1; 0 1], T[0; 0], [σx² 0; 0 0], init)
obs = LinearGaussianObservationProcess(T[1 0], [σy²;;])
return StateSpaceModel(dyn, obs)
end

true_params = randexp(Float32, 2);
true_model = simulation_model(true_params...);

# simulate data
rng = MersenneTwister(1234);
_, _, data = sample(rng, true_model, 150);

# test the adaptive resampling procedure
states, llbf = sample(rng, true_model, data, BF(2048, 0.5); store_ancestry=true);

# plot the smoothed states to validate the algorithm
smoothed_trend = begin
fig = Figure(; size=(1200, 400))
ax1 = Axis(fig[1, 1])
ax2 = Axis(fig[1, 2])

    # this is gross but it works for visualization purposes
all_paths = map(x -> hcat(x...), get_ancestry(states.tree))
mean_paths = mean(all_paths, weights(softmax(states.log_weights)))
n_paths = length(all_paths)

# plot smoothed states in black and observed data in red
lines!(ax1, mean_paths[1, :]; color=:black)
lines!(ax1, vcat(0, data...); color=:red, linestyle=:dash)

# plot ancestry tree in graded black and data in red
lines!.(ax2, getindex.(all_paths, 1, :), color=(:black, maximum([2 / n_paths, 1e-2])))
lines!(ax2, vcat(0, data...); color=:red, linestyle=:dash)

fig
end

## PARTICLE MCMC ###########################################################################

# consider a default Gamma prior with Float32s
prior_dist = product_distribution(Gamma(1.0f0), Gamma(1.0f0));

# basic RWMH ala AdvancedMH
function density(θ::Vector{T}) where {T<:Real}
if insupport(prior_dist, θ)
# _, ll = sample(rng, simulation_model(θ...), data, BF(512))
_, ll = sample(rng, simulation_model(θ...), data, KF())
return ll + logpdf(prior_dist, θ)
else
return -Inf
end
end

pmmh = RWMH(MvNormal(zeros(Float32, 2), (0.01f0) * I));
model = DensityModel(density);

# works with AdvancedMH out of the box
chains = sample(model, pmmh, 50_000);
burn_in = 1_000;

# plot the posteriors
hist_plots = begin
param_post = hcat(getproperty.(chains[burn_in:end], :params)...)
fig = Figure(; size=(1200, 400))

for i in 1:2
# plot the posteriors with burn-in
hist(
fig[1, i],
param_post[i, :];
color=(:black, 0.4),
strokewidth=1,
normalization=:pdf,
)

# plot the true values
vlines!(fig[1, i], true_params[i]; color=:red, linestyle=:dash, linewidth=3)
end

fig
end

# this is useful for SMC algorithms like SMC² or density tempered SMC
acc_ratio = mean(getproperty.(chains, :accepted))