Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interface Changes for Use in Filtering #56

Merged
merged 33 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
440252c
added basic particle methods and filters
charlesknipp Aug 9, 2024
9fd4453
added qualifiers
charlesknipp Aug 12, 2024
3fd90c4
added parameter priors
charlesknipp Aug 12, 2024
884b9e3
Merge branch 'main' into ck/particle-methods
charlesknipp Aug 30, 2024
1def6a1
Merge branch 'main' into ck/particle-methods
charlesknipp Sep 24, 2024
a5a2e05
added adaptive resampling to bootstrap filter (WIP)
charlesknipp Sep 25, 2024
57da3ff
Julia fomatter changes
charlesknipp Sep 25, 2024
dc713b0
Merge branch 'ck/particle-methods' of https://github.com/TuringLang/S…
charlesknipp Sep 25, 2024
b846fa4
changed eltype for <: StateSpaceModel
charlesknipp Sep 26, 2024
4263ae7
updated naming conventions
charlesknipp Sep 26, 2024
5a2aeb4
formatter
charlesknipp Sep 26, 2024
8db658b
fixed adaptive resampling
charlesknipp Sep 27, 2024
15dfa9f
added particle ancestry
charlesknipp Oct 1, 2024
7e3c93d
formatter issues
charlesknipp Oct 1, 2024
f905a41
fixed metropolis and added rejection resampler
charlesknipp Oct 1, 2024
8ac1455
Keep track of free indices using stack
THargreaves Oct 2, 2024
f11a63e
updated particle types and organized directory
charlesknipp Oct 2, 2024
1fa3c93
weakened SSM type parameter assertions
charlesknipp Oct 4, 2024
8cb4338
improved particle state containment and resampling
charlesknipp Oct 4, 2024
73dd433
added hacky sparse ancestry to example
charlesknipp Oct 5, 2024
f71ab32
fixed RNG in rejection resampling
charlesknipp Oct 6, 2024
25cebf4
improved callbacks and resamplers
charlesknipp Oct 6, 2024
c729879
formatting
charlesknipp Oct 6, 2024
d13c80c
added conditional SMC
charlesknipp Oct 8, 2024
856cebb
improved linear model type structure
charlesknipp Oct 8, 2024
d7daf93
formatter
charlesknipp Oct 8, 2024
b29ba60
replaced extra with kwargs
charlesknipp Oct 11, 2024
ece40fa
formatter
charlesknipp Oct 11, 2024
75fdf2c
migrated filtering code
charlesknipp Oct 11, 2024
2cc4016
Add unittests for new interface
THargreaves Oct 14, 2024
c76278f
Update documentation to match kwargs
THargreaves Oct 18, 2024
04f9808
Rename extras/kwargs docs file
THargreaves Oct 18, 2024
5a8bba2
remove redundant forward simulations
charlesknipp Oct 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions examples/particle-mcmc/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
73 changes: 73 additions & 0 deletions examples/particle-mcmc/script.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
using AdvancedMH
using CairoMakie

include("simple-filters.jl")

true_params, simulation_model = let T = Float32
θ = randexp(T, 3)
dyn = LinearGaussianLatentDynamics(T[1 1; 0 1], diagm(θ[1:2]))
obs = LinearGaussianObservationProcess(T[0.5 0.5], diagm(θ[3:end]))
θ, StateSpaceModel(dyn, obs)
end

# simulate data
rng = MersenneTwister(1234)
_, _, data = sample(rng, simulation_model, 150)

# consider a default Gamma prior with Float32s
prior_dist = product_distribution(Gamma(1.0f0), Gamma(1.0f0), Gamma(1.0f0))

# test the adaptive resampling procedure
sample(rng, simulation_model, data, BF(512, 0.1); debug=true);


#=
Not crazy about this structure, especially since the RNG is referenced on
the global scope. I think we can make a PMCMC sampler type which includes
the filter algorithm within the sampler definition.

Another issue is that we lose information on the states. Granted, this is
also by design since that would cost a considerable amount of memory, but
is useful nonetheless. This also needs to interface with bundle_samples()
different than ususal, since we have the parameter space and the filtered
states.
=#
function density(θ::Vector{T}) where {T<:Real}
if insupport(prior_dist, θ)
dyn = LinearGaussianLatentDynamics(T[1 1; 0 1], diagm(θ[1:2]))
obs = LinearGaussianObservationProcess(T[0.5 0.5], diagm(θ[3:end]))

# _, ll = sample(rng, StateSpaceModel(dyn, obs), data, BF(512))
_, ll = sample(rng, StateSpaceModel(dyn, obs), data, KF())
return ll + logpdf(prior_dist, θ)
else
return -Inf
end
end

# plug it into the DensityModel interface for now
pmmh = RWMH(MvNormal(zeros(Float32, 3), (0.01f0) * I))
model = DensityModel(density)

# works with AdvancedMH out of the box
chains = sample(model, pmmh, 10_000)
burn_in = 1_000

# plot the posteriors
hist_plots = begin
param_post = hcat(getproperty.(chains[burn_in:end], :params)...)
fig = Figure(; size=(1200, 400))

for i in 1:3
# plot the posteriors with burn-in
hist(fig[1, i], param_post[i, :]; color=:gray, strokewidth=1, normalization=:pdf)

# plot the true values
vlines!(fig[1, i], true_params[i]; color=:red, linestyle=:dash, linewidth=2)
end

fig
end

# this is useful for SMC algorithms like SMC² or density tempered SMC
acc_ratio = mean(getproperty.(chains, :accepted))
Loading
Loading