Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify inference for predictions functionality #51

Merged
merged 35 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
b310e3a
Modify inference for predictions functionality
albertpod Jan 24, 2023
bc6ad92
Merge branch 'main' into dev-predict
albertpod Jan 24, 2023
abc5edb
Make format
albertpod Jan 24, 2023
b67474c
WIP: update inference function
albertpod Jan 30, 2023
5fe1ac6
Make format
albertpod Jan 30, 2023
1c9eb9b
WIP: Change inference
albertpod Jan 30, 2023
11077a0
Update inference
albertpod Feb 1, 2023
a9cb543
Merge branch 'main' into dev-predict
albertpod Feb 2, 2023
39c91ce
Add tests
albertpod Feb 6, 2023
4b223ed
Merge branch 'main' into dev-predict
albertpod Feb 7, 2023
b7aed19
Merge branch 'main' of https://github.com/biaslab/RxInfer.jl into dev…
albertpod Feb 13, 2023
c77e146
Merge branch 'main' into dev-predict
albertpod Feb 22, 2023
cef2171
Merge branch 'main' into dev-predict
albertpod Mar 6, 2023
528164d
Merge main into dev-predict
albertpod Jun 19, 2023
976a38e
Merge branch 'main' into dev-predict
albertpod Jul 23, 2023
31f35f5
Merge branch 'main' into dev-predict
albertpod Sep 9, 2023
eb9c691
Merge branch 'main' into dev-predict
albertpod Sep 12, 2023
ccf4489
fix: fix datavar tests
bvdmitri Sep 13, 2023
7f53ab7
improve check data is missing
bvdmitri Sep 13, 2023
3589fa0
more tests
bvdmitri Sep 13, 2023
b4a29fb
Update inference function
albertpod Sep 13, 2023
9f59b0d
Make format
albertpod Sep 13, 2023
447efcc
Make format
albertpod Sep 13, 2023
bb993ac
Update src/inference.jl
albertpod Sep 18, 2023
829b921
Update src/inference.jl
albertpod Sep 18, 2023
5eb6c46
Update src/inference.jl
albertpod Sep 18, 2023
d98d3c3
Update src/inference.jl
albertpod Sep 18, 2023
b0dd137
Add prediction test for coin model
albertpod Sep 18, 2023
2d2ac7f
Update tests
albertpod Sep 18, 2023
3eb5cda
Make format
albertpod Sep 18, 2023
1a54d0f
fix tests
bvdmitri Sep 18, 2023
9dd6738
fix inference tests
bvdmitri Sep 18, 2023
bdafd5d
fix examples
bvdmitri Sep 18, 2023
47ccf2b
update: Bump version to 2.12.0
bvdmitri Sep 19, 2023
c309034
Merge branch 'main' into dev-predict
bvdmitri Sep 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,7 @@ function inference(;
end
end
else # In this case, the prediction functionality should only be performed if the data allows missings and actually contains missing values.
foreach((variable, value) -> isdata(value) && __inference_check_dataismissing(data[variable]) && !allows_missings(value) ? error("datavar $(variable) has missings inside but does not allow it. Add `where {allow_missing = true }`") : nothing, keys(vardict), values(vardict))
predictvars = Dict(
variable => KeepLast() for (variable, value) in pairs(vardict) if
(isdata(value) && haskey(data, variable) && allows_missings(value) && __inference_check_dataismissing(data[variable]) && !isanonymous(value))
Expand Down Expand Up @@ -684,6 +685,10 @@ function inference(;

is_free_energy, S, T = unwrap_free_energy_option(free_energy)

if !isempty(actors_pr) && is_free_energy
error("Cannot compute Bethe Free Energy for models with prediction variables. Please set `free_energy = false`.")
albertpod marked this conversation as resolved.
Show resolved Hide resolved
end

if is_free_energy
fe_actor = ScoreActor(S, _iterations, 1)
fe_objective = BetheFreeEnergy(BetheFreeEnergyDefaultMarginalSkipStrategy, AsapScheduler(), free_energy_diagnostics)
Expand Down
53 changes: 47 additions & 6 deletions test/test_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -722,11 +722,6 @@ end

@testset "Predictions functionality" begin

# Given the current implementation of RxInfer ecosystem, this rule needs to be predifined
@rule NormalMeanPrecision(:μ, Marginalisation) (m_out::PointMass, q_τ::PointMass) = begin
return missing
end

# test #1 (array with missing + predictvars)
data = (y = [1.0, -500.0, missing, 100.0],)

Expand Down Expand Up @@ -857,7 +852,7 @@ end
@test haskey(result.predictions, :y)
@test typeof(result.predictions[:y]) <: NormalDistributionsFamily

# test vmp model
# test #7 vmp model
data = (y = [1.0, -10.0, 5.0],)
@model function vmp_model(n)
x = randomvar(n + 1)
Expand Down Expand Up @@ -894,6 +889,52 @@ end

@test first(result.posteriors[:γ]) != last(result.posteriors[:γ])
@test first(result.predictions[:o]) != last(result.predictions[:o])


# test #8 non gaussian likelihood (single datavar missing)
dataset = [1.0, 0.0, 1.0, missing,]
@model function coin_model(n)

y = datavar(Float64, n) where {allow_missing = true}

θ ~ Beta(4.0, 8.0)
for i in 1:n
y[i] ~ Bernoulli(θ)
end

end

result = inference(
model = coin_model(length(dataset)),
data = (y = dataset, )
)

@test typeof(last(result.predictions[:y])) <: Bernoulli
albertpod marked this conversation as resolved.
Show resolved Hide resolved

# test #9 allow_missing error handling
dataset = [1.0, 0.0, 1.0, missing,]
@model function coin_model(n)

y = datavar(Float64, n)

θ ~ Beta(4.0, 8.0)
for i in 1:n
y[i] ~ Bernoulli(θ)
end

end

@test_throws ErrorException inference(
model = coin_model(length(dataset)),
data = (y = dataset, )
)

#test #10 free_energy error handling
@test_throws ErrorException inference(
model = coin_model(length(dataset)),
data = (y = dataset, ),
free_energy = true
)
end

end
Loading