Skip to content

Commit

Permalink
Merge branch 'main' into ux-rule-suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
bvdmitri authored Oct 21, 2024
2 parents 277ffe4 + 635a856 commit 9bf9e01
Show file tree
Hide file tree
Showing 4 changed files with 2,883 additions and 2,890 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.10'
- '1.11'
os:
- ubuntu-latest
arch:
Expand Down
5,760 changes: 2,876 additions & 2,884 deletions examples/problem_specific/Autoregressive Models.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions test/ext/ProjectionExt/inference_with_projection_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,9 +408,9 @@ end
model = mymodel(C = C), data = (y = y,), meta = mymeta(), constraints = myconstraints(), initialization = myinitialization(), free_energy = true, iterations = 40
)

@test mean(result.posteriors[:a][end]) a atol = 2e-2
@test foo(mean(result.posteriors[:a][end]), mean(result.posteriors[:b][end])) foo(a, b) atol = 2e-2
@test mean(result.posteriors[][end]) foo(a, b) atol = 2e-2
@test mean(result.posteriors[:a][end]) a atol = 3e-2
@test foo(mean(result.posteriors[:a][end]), mean(result.posteriors[:b][end])) foo(a, b) atol = 3e-2
@test mean(result.posteriors[][end]) foo(a, b) atol = 3e-2
@test first(result.free_energy) > last(result.free_energy)
@test count(<(0), diff(result.free_energy)) > 0.95

Expand Down
5 changes: 3 additions & 2 deletions test/models/nonlinear/cvi_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,16 @@
## -------------------------------------------- ##
mz = res.posteriors[:z]
fe = res.free_energy

@test length(res.posteriors[:z]) === T

@test all(mean.(mz) .- 6 .* std.(mz) .< hidden .< (mean.(mz) .+ 6 .* std.(mz)))
@test (sum((mean.(mz) .- 4 .* std.(mz)) .< hidden .< (mean.(mz) .+ 4 .* std.(mz))) / T) > 0.95
@test (sum((mean.(mz) .- 3 .* std.(mz)) .< hidden .< (mean.(mz) .+ 3 .* std.(mz))) / T) > 0.90

# Free energy for the CVI may fluctuate
@test all(d -> d < 2.5, diff(fe)) # Check that the fluctuations are not big
@test abs(last(fe) - 308) < 1.0 # Check the final result with relatively low precision
@test all(d -> d < 3.0, diff(fe)) # Check that the fluctuations are not big
@test abs(last(fe) - 317) < 1.0 # Check the final result with relatively low precision
@test (first(fe) - last(fe)) > 0

## Create output plots
Expand Down

0 comments on commit 9bf9e01

Please sign in to comment.