Skip to content

Commit

Permalink
improved particle state containment and resampling
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesknipp committed Oct 4, 2024
1 parent 1fa3c93 commit 8cb4338
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 130 deletions.
92 changes: 24 additions & 68 deletions examples/particle-mcmc/particles.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,21 @@
using DataStructures: Stack
using StatsBase

## PARTICLES ###############################################################################

abstract type AbstractParticleContainer{T} end

"""
store!(particles, new_states, [idx])
update the state component of the particle container, with optional parent indices supplied
for use in ancestry storage.
"""
function store! end

"""
reset_weights!(particles)
in-place method to reset the log weights of the particle cloud to zero; typically called
following a resampling step.
"""
function reset_weights! end

mutable struct ParticleContainer{T,WT<:Real} <: AbstractParticleContainer{T}
vals::Vector{T}
mutable struct ParticleContainer{T,WT<:Real}
filtered::Vector{T}
proposed::Vector{T}
ancestors::Vector{Int64}
log_weights::Vector{WT}

function ParticleContainer(
initial_states::Vector{T}, log_weights::Vector{WT}
) where {T,WT<:Real}
return new{T,WT}(
initial_states, similar(initial_states), eachindex(log_weights), log_weights
)
end
end

Base.collect(pc::ParticleContainer) = pc.vals
Expand All @@ -33,17 +26,16 @@ Base.keys(pc::ParticleContainer) = LinearIndices(pc.vals)
Base.@propagate_inbounds Base.getindex(pc::ParticleContainer, i::Int) = pc.vals[i]
Base.@propagate_inbounds Base.getindex(pc::ParticleContainer, i::Vector{Int}) = pc.vals[i]

function store!(pc::ParticleContainer, new_states, idx...; kwargs...)
setindex!(pc.vals, new_states, eachindex(pc))
return pc
end

function reset_weights!(pc::ParticleContainer{T,WT}) where {T,WT<:Real}
fill!(pc.log_weights, zero(WT))
return pc.log_weights
end

## JACOB-MURRAY PARTICLE STORAGE ###########################################################
function StatsBase.weights(pc::ParticleContainer)
return softmax(pc.log_weights)
end

## SPARSE PARTICLE STORAGE #################################################################

Base.append!(s::Stack, a::AbstractVector) = map(x -> push!(s, x), a)

Expand All @@ -69,10 +61,10 @@ Base.length(tree::ParticleTree) = length(tree.states)
Base.keys(tree::ParticleTree) = LinearIndices(tree.states)

function prune!(tree::ParticleTree, offspring::Vector{Int64})
## insert new offspring counts
# insert new offspring counts
setindex!(tree.offspring, offspring, tree.leaves)

## update each branch
# update each branch
@inbounds for i in eachindex(offspring)
j = tree.leaves[i]
while (j > 0) && (tree.offspring[j] == 0)
Expand All @@ -88,21 +80,21 @@ end
function insert!(
tree::ParticleTree{T}, states::Vector{T}, a::AbstractVector{Int64}
) where {T}
## parents of new generation
# parents of new generation
parents = getindex(tree.leaves, a)

## ensure there are enough dead branches
# ensure there are enough dead branches
if (length(tree.free_indices) < length(a))
@debug "expanding tree"
expand!(tree)
end

## find places for new states
# find places for new states
@inbounds for i in eachindex(states)
tree.leaves[i] = pop!(tree.free_indices)
end

## insert new generation and update parent child relationships
# insert new generation and update parent child relationships
setindex!(tree.states, states, tree.leaves)
setindex!(tree.parents, parents, tree.leaves)
return tree
Expand All @@ -127,42 +119,6 @@ function get_offspring(a::AbstractVector{Int64})
return offspring
end

## FILTERING WITH ANCESTRY #################################################################

mutable struct AncestryContainer{T,WT<:Real} <: AbstractParticleContainer{T}
tree::ParticleTree{T}
log_weights::Vector{WT}

function AncestryContainer(
initial_states::Vector{T}, log_weights::Vector{WT}, C::Int64=1
) where {T,WT<:Real}
N = length(log_weights)
M = floor(C * N * log(N))
tree = ParticleTree(initial_states, Int64(M))
return new{T,WT}(tree, log_weights)
end
end

function Base.collect(ac::AncestryContainer)
return getindex(ac.tree.states, ac.tree.leaves)
end

function Base.getindex(ac::AncestryContainer, a::AbstractVector{Int64})
return getindex(ac.tree.states, getindex(ac.tree.leaves, a))
end

function reset_weights!(ac::AncestryContainer{T,WT}) where {T,WT<:Real}
fill!(ac.log_weights, zero(WT))
return ac.log_weights
end

function store!(ac::AncestryContainer, new_states, idx)
prune!(ac.tree, get_offspring(idx))
insert!(ac.tree, new_states, idx)
return ac
end

# start at each leaf and retrace it's steps to the root node
function get_ancestry(tree::ParticleTree{T}) where {T}
paths = Vector{Vector{T}}(undef, length(tree.leaves))
@inbounds for (k, i) in enumerate(tree.leaves)
Expand Down
53 changes: 42 additions & 11 deletions examples/particle-mcmc/resamplers.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
using Random
using Distributions

function multinomial_resampling(
rng::AbstractRNG, weights::AbstractVector{WT}, n::Int64=length(weights); kwargs...
abstract type AbstractResampler end

## DOUBLE PRECISION STABLE ALGORITHMS ######################################################

struct Multinomial <: AbstractResampler end

function resample(
rng::AbstractRNG, ::Multinomial, weights::AbstractVector{WT}, n::Int64=length(weights)
) where {WT<:Real}
return rand(rng, Distributions.Categorical(weights), n)
end

function systematic_resampling(
rng::AbstractRNG, weights::AbstractVector{WT}, n::Int64=length(weights); kwargs...
struct Systematic <: AbstractResampler end

function resample(
rng::AbstractRNG, ::Systematic, weights::AbstractVector{WT}, n::Int64=length(weights)
) where {WT<:Real}
# pre-calculations
@inbounds v = n * weights[1]
Expand All @@ -30,24 +38,45 @@ function systematic_resampling(
return a
end

function resample(
rng::AbstractRNG,
alg::Systematic,
weights::AbstractVector{Float32},
n::Int64=length(weights),
)
try
return resample(rng, alg, weights, n)
catch e
throw(e("Systematic resampling is not numerically stable for single precision"))
end
end

## SINGLE PRECISION STABLE ALGORITHMS ######################################################

struct Metropolis{T<:Real} <: AbstractResampler
ε::T
function Metropolis::T=0.01) where {T<:Real}
return new{T}(ε)
end
end

# TODO: this should be done in the log domain and also parallelized
function metropolis_resampling(
function resample(
rng::AbstractRNG,
resampler::Metropolis,
weights::AbstractVector{WT},
n::Int64=length(weights);
ε::Float64=0.01,
kwargs...,
) where {WT<:Real}
# pre-calculations
β = mean(weights)
bins = Int64(cld(log(ε), log(1 - β)))
B = Int64(cld(log(resampler.ε), log(1 - β)))

# initialize the algorithm
a = Vector{Int64}(undef, n)

@inbounds for i in 1:n
k = i
for _ in 1:bins
for _ in 1:B
j = rand(rng, 1:n)
v = weights[j] / weights[k]
if rand(rng) v
Expand All @@ -60,9 +89,11 @@ function metropolis_resampling(
return a
end

struct Rejection <: AbstractResampler end

# TODO: this should be done in the log domain and also parallelized
function rejection_resampling(
rng::AbstractRNG, weights::AbstractVector{WT}, n::Int64=length(weights); kwargs...
function resample(
rng::AbstractRNG, ::Rejection, weights::AbstractVector{WT}, n::Int64=length(weights)
) where {WT<:Real}
# pre-calculations
max_weight = maximum(weights)
Expand Down
15 changes: 9 additions & 6 deletions examples/particle-mcmc/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,17 @@ rng = MersenneTwister(1234);
_, _, data = sample(rng, true_model, 150);

# test the adaptive resampling procedure
states, llbf = sample(rng, true_model, data, BF(2048, 0.5); store_ancestry=true);
bootstrap_filter = BF(256; threshold=0.5, resampler=Multinomial());
states, llbf = sample(rng, true_model, bootstrap_filter, data);

# plot the smoothed states to validate the algorithm
smoothed_trend = begin
# plot the smoothed states to validate the algorithm (currently broken)
smoothed_trend = try
fig = Figure(; size=(1200, 400))
ax1 = Axis(fig[1, 1])
ax2 = Axis(fig[1, 2])

# this is gross but it works fro visualization purposes
all_paths = map(x -> hcat(x...), get_ancestry(states.tree))
all_paths = map(x -> hcat(x...), get_ancestry(sparse_ancestry))
mean_paths = mean(all_paths, weights(softmax(states.log_weights)))
n_paths = length(all_paths)

Expand All @@ -46,6 +47,8 @@ smoothed_trend = begin
lines!(ax2, vcat(0, data...); color=:red, linestyle=:dash)

fig
catch
@error "Sparse ancestry storage callbacks not yet implemented, this will error"
end

## PARTICLE MCMC ###########################################################################
Expand All @@ -56,8 +59,8 @@ prior_dist = product_distribution(Gamma(1.0f0), Gamma(1.0f0));
# basic RWMH ala AdvancedMH
function density::Vector{T}) where {T<:Real}
if insupport(prior_dist, θ)
# _, ll = sample(rng, simulation_model(θ...), data, BF(512))
_, ll = sample(rng, simulation_model...), data, KF())
# _, ll = sample(rng, simulation_model(θ...), BF(512), data)
_, ll = sample(rng, simulation_model...), KF(), data)
return ll + logpdf(prior_dist, θ)
else
return -Inf
Expand Down
Loading

0 comments on commit 8cb4338

Please sign in to comment.