diff --git a/examples/particle-mcmc/script.jl b/examples/particle-mcmc/script.jl index bbb90e1..6856a84 100644 --- a/examples/particle-mcmc/script.jl +++ b/examples/particle-mcmc/script.jl @@ -16,38 +16,47 @@ function simulation_model(σx²::T, σy²::T) where {T<:Real} return StateSpaceModel(dyn, obs) end -true_params = randexp(Float32, 2); +# generate model and data +rng = MersenneTwister(1234) +true_params = randexp(rng, Float32, 2); true_model = simulation_model(true_params...); - -# simulate data -rng = MersenneTwister(1234); _, _, data = sample(rng, true_model, 150); -# test the adaptive resampling procedure -bootstrap_filter = BF(256; threshold=0.5, resampler=Multinomial()); -states, llbf = sample(rng, true_model, bootstrap_filter, data); +# run a bootstrap filter, resampling at every iteration with a Rejection resampler +filter = BF(1024; threshold=1.0, resampler=Rejection()); +sparse_ancestry = begin + M = floor(filter.N * log(filter.N)) + states = initialise(rng, true_model, filter, nothing) + sparse_ancestry = ParticleTree(states.filtered, Int64(M)) + + for t in eachindex(data) + proposed_states = predict(rng, true_model, filter, t, states, nothing) + states, log_marginal = update( + true_model, filter, t, proposed_states, data[t], nothing + ) + + prune!(sparse_ancestry, get_offspring(states.ancestors)) + insert!(sparse_ancestry, states.filtered, states.ancestors) + end + + sparse_ancestry +end; -# plot the smoothed states to validate the algorithm (currently broken) smoothed_trend = try - fig = Figure(; size=(1200, 400)) + fig = Figure(; size=(600, 400)) ax1 = Axis(fig[1, 1]) - ax2 = Axis(fig[1, 2]) # this is gross but it works fro visualization purposes all_paths = map(x -> hcat(x...), get_ancestry(sparse_ancestry)) - mean_paths = mean(all_paths, weights(softmax(states.log_weights))) n_paths = length(all_paths) - # plot smoothed states in black and observed data in red - lines!(ax1, mean_paths[1, :]; color=:black) - lines!(ax1, vcat(0, data...); color=:red, linestyle=:dash) - # plot ancestry tree in graded black and data in red - lines!.(ax2, getindex.(all_paths, 1, :), color=(:black, maximum([2 / n_paths, 1e-2]))) - lines!(ax2, vcat(0, data...); color=:red, linestyle=:dash) + lines!.(ax1, getindex.(all_paths, 1, :), color=(:black, maximum([2 / n_paths, 1e-2]))) + lines!(ax1, vcat(0, data...); color=:red, linestyle=:dash) fig catch + # keep this here until the callbacks are in a stable enough @error "Sparse ancestry storage callbacks not yet implemented, this will error" end