Skip to content

Commit

Permalink
added hacky sparse ancestry to example
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesknipp committed Oct 5, 2024
1 parent 8cb4338 commit 73dd433
Showing 1 changed file with 26 additions and 17 deletions.
43 changes: 26 additions & 17 deletions examples/particle-mcmc/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,38 +16,47 @@ function simulation_model(σx²::T, σy²::T) where {T<:Real}
return StateSpaceModel(dyn, obs)
end

true_params = randexp(Float32, 2);
# generate model and data
rng = MersenneTwister(1234)
true_params = randexp(rng, Float32, 2);
true_model = simulation_model(true_params...);

# simulate data
rng = MersenneTwister(1234);
_, _, data = sample(rng, true_model, 150);

# test the adaptive resampling procedure
bootstrap_filter = BF(256; threshold=0.5, resampler=Multinomial());
states, llbf = sample(rng, true_model, bootstrap_filter, data);
# run a bootstrap filter, resampling at every iteration with a Rejection resampler
filter = BF(1024; threshold=1.0, resampler=Rejection());
sparse_ancestry = begin
M = floor(filter.N * log(filter.N))
states = initialise(rng, true_model, filter, nothing)
sparse_ancestry = ParticleTree(states.filtered, Int64(M))

for t in eachindex(data)
proposed_states = predict(rng, true_model, filter, t, states, nothing)
states, log_marginal = update(
true_model, filter, t, proposed_states, data[t], nothing
)

prune!(sparse_ancestry, get_offspring(states.ancestors))
insert!(sparse_ancestry, states.filtered, states.ancestors)
end

sparse_ancestry
end;

# plot the smoothed states to validate the algorithm (currently broken)
smoothed_trend = try
fig = Figure(; size=(1200, 400))
fig = Figure(; size=(600, 400))
ax1 = Axis(fig[1, 1])
ax2 = Axis(fig[1, 2])

# this is gross but it works fro visualization purposes
all_paths = map(x -> hcat(x...), get_ancestry(sparse_ancestry))
mean_paths = mean(all_paths, weights(softmax(states.log_weights)))
n_paths = length(all_paths)

# plot smoothed states in black and observed data in red
lines!(ax1, mean_paths[1, :]; color=:black)
lines!(ax1, vcat(0, data...); color=:red, linestyle=:dash)

# plot ancestry tree in graded black and data in red
lines!.(ax2, getindex.(all_paths, 1, :), color=(:black, maximum([2 / n_paths, 1e-2])))
lines!(ax2, vcat(0, data...); color=:red, linestyle=:dash)
lines!.(ax1, getindex.(all_paths, 1, :), color=(:black, maximum([2 / n_paths, 1e-2])))
lines!(ax1, vcat(0, data...); color=:red, linestyle=:dash)

fig
catch
# keep this here until the callbacks are in a stable enough
@error "Sparse ancestry storage callbacks not yet implemented, this will error"
end

Expand Down

0 comments on commit 73dd433

Please sign in to comment.