Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Build examples with doc #19

Merged
merged 5 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docs/literate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Retrieve name of example and output directory
if length(ARGS) != 2
error("please specify the name of the example and the output directory")
end
const EXAMPLE = ARGS[1]
const OUTDIR = ARGS[2]

# Activate environment
# Note that each example's Project.toml must include Literate as a dependency
using Pkg: Pkg
const EXAMPLEPATH = joinpath(@__DIR__, "..", "examples", EXAMPLE)
Pkg.activate(EXAMPLEPATH)
Pkg.instantiate()
using Literate: Literate

# Convert to markdown and notebook
const SCRIPTJL = joinpath(EXAMPLEPATH, "script.jl")
Literate.markdown(SCRIPTJL, OUTDIR; name=EXAMPLE, execute=true)
4 changes: 2 additions & 2 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ mkpath(EXAMPLES_OUT)
# Workaround for https://github.com/JuliaLang/Pkg.jl/issues/2219
examples = filter!(isdir, readdir(joinpath(@__DIR__, "..", "examples"); join=true))
above = joinpath(@__DIR__, "..")
let script = "using Pkg; Pkg.activate(ARGS[1]); Pkg.instantiate(); Pkg.develop(path=\"$(above)\");"
let script = "using Pkg; Pkg.activate(ARGS[1]); Pkg.develop(path=\"$(above)\"); Pkg.instantiate()"
for example in examples
if !success(`$(Base.julia_cmd()) -e $script $example`)
error(
Expand Down Expand Up @@ -50,7 +50,7 @@ DocMeta.setdocmeta!(SSMProblems, :DocTestSetup, :(using SSMProblems); recursive=

makedocs(;
sitename="SSMProblems",
format=Documenter.HTML(),
format=Documenter.HTML(; size_threshold=1000 * 2^11), # 1Mb per page
modules=[SSMProblems],
pages=[
"Home" => "index.md",
Expand Down
7 changes: 7 additions & 0 deletions examples/smc/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
14 changes: 6 additions & 8 deletions examples/smc.jl → examples/smc/script.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# # Partilce Filter with adaptive resampling
using Random
using SSMProblems
using Distributions
Expand Down Expand Up @@ -26,15 +27,13 @@ function sweep!(
logweights = zeros(N)

for (timestep, observation) in enumerate(observations)
# Resample step
weights = get_weights(logweights)
if ess(weights) <= threshold * N
idx = resampling(rng, weights)
particles = particles[idx]
fill!(logweights, 0)
end

# Mutation step
for i in eachindex(particles)
latent_state = transition!!(rng, model, particles[i].state, timestep)
particles[i] = SSMProblems.Utils.Particle(particles[i], latent_state)
Expand All @@ -44,7 +43,6 @@ function sweep!(
end
end

# Return unweighted set
idx = resampling(rng, get_weights(logweights))
return particles[idx]
end
Expand Down Expand Up @@ -73,9 +71,9 @@ Base.@kwdef struct LinearSSM <: AbstractStateSpaceModel
end

# Simulation
T = 250
T = 150
seed = 1
N = 1_000
N = 500
rng = MersenneTwister(seed)

model = LinearSSM(0.2, 0.7)
Expand Down Expand Up @@ -111,6 +109,6 @@ end
samples = sample(rng, LinearSSM(), N, observations)
traces = reverse(hcat(map(SSMProblems.Utils.linearize, samples)...))

scatter(traces; color=:black, opacity=0.3, label=false)
plot!(x; label="True state")
plot!(mean(traces; dims=2); label="Posterior mean")
scatter(traces[:, 1:10]; color=:black, opacity=0.7, label=false)
plot!(x; label="True state", linewidth=2)
plot!(mean(traces; dims=2); label="Posterior mean", color=:orange, linewidth=2)
Loading