Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify inference for predictions functionality #51

Merged
merged 35 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
b310e3a
Modify inference for predictions functionality
albertpod Jan 24, 2023
bc6ad92
Merge branch 'main' into dev-predict
albertpod Jan 24, 2023
abc5edb
Make format
albertpod Jan 24, 2023
b67474c
WIP: update inference function
albertpod Jan 30, 2023
5fe1ac6
Make format
albertpod Jan 30, 2023
1c9eb9b
WIP: Change inference
albertpod Jan 30, 2023
11077a0
Update inference
albertpod Feb 1, 2023
a9cb543
Merge branch 'main' into dev-predict
albertpod Feb 2, 2023
39c91ce
Add tests
albertpod Feb 6, 2023
4b223ed
Merge branch 'main' into dev-predict
albertpod Feb 7, 2023
b7aed19
Merge branch 'main' of https://github.com/biaslab/RxInfer.jl into dev…
albertpod Feb 13, 2023
c77e146
Merge branch 'main' into dev-predict
albertpod Feb 22, 2023
cef2171
Merge branch 'main' into dev-predict
albertpod Mar 6, 2023
528164d
Merge main into dev-predict
albertpod Jun 19, 2023
976a38e
Merge branch 'main' into dev-predict
albertpod Jul 23, 2023
31f35f5
Merge branch 'main' into dev-predict
albertpod Sep 9, 2023
eb9c691
Merge branch 'main' into dev-predict
albertpod Sep 12, 2023
ccf4489
fix: fix datavar tests
bvdmitri Sep 13, 2023
7f53ab7
improve check data is missing
bvdmitri Sep 13, 2023
3589fa0
more tests
bvdmitri Sep 13, 2023
b4a29fb
Update inference function
albertpod Sep 13, 2023
9f59b0d
Make format
albertpod Sep 13, 2023
447efcc
Make format
albertpod Sep 13, 2023
bb993ac
Update src/inference.jl
albertpod Sep 18, 2023
829b921
Update src/inference.jl
albertpod Sep 18, 2023
5eb6c46
Update src/inference.jl
albertpod Sep 18, 2023
d98d3c3
Update src/inference.jl
albertpod Sep 18, 2023
b0dd137
Add prediction test for coin model
albertpod Sep 18, 2023
2d2ac7f
Update tests
albertpod Sep 18, 2023
3eb5cda
Make format
albertpod Sep 18, 2023
1a54d0f
fix tests
bvdmitri Sep 18, 2023
9dd6738
fix inference tests
bvdmitri Sep 18, 2023
bdafd5d
fix examples
bvdmitri Sep 18, 2023
47ccf2b
update: Bump version to 2.12.0
bvdmitri Sep 19, 2023
c309034
Merge branch 'main' into dev-predict
bvdmitri Sep 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 56 additions & 19 deletions src/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ import ReactiveMP: CountingReal

import ProgressMeter

obtain_prediction(variable::AbstractVariable) = getprediction(variable)
obtain_prediction(variables::AbstractArray{<:AbstractVariable}) = getpredictions(variables)

obtain_marginal(variable::AbstractVariable, strategy = SkipInitial()) = getmarginal(variable, strategy)
obtain_marginal(variables::AbstractArray{<:AbstractVariable}, strategy = SkipInitial()) = getmarginals(variables, strategy)

Expand All @@ -24,19 +27,19 @@ assign_message!(variable::AbstractVariable, message) = setmes
struct KeepEach end
struct KeepLast end

make_actor(::RandomVariable, ::KeepEach) = keep(Marginal)
make_actor(::Array{<:RandomVariable, N}, ::KeepEach) where {N} = keep(Array{Marginal, N})
make_actor(x::AbstractArray{<:RandomVariable}, ::KeepEach) = keep(typeof(similar(x, Marginal)))
make_actor(::AbstractVariable, ::KeepEach) = keep(Marginal)
make_actor(::Array{<:AbstractVariable, N}, ::KeepEach) where {N} = keep(Array{Marginal, N})
make_actor(x::AbstractArray{<:AbstractVariable}, ::KeepEach) = keep(typeof(similar(x, Marginal)))

make_actor(::RandomVariable, ::KeepEach, capacity::Integer) = circularkeep(Marginal, capacity)
make_actor(::Array{<:RandomVariable, N}, ::KeepEach, capacity::Integer) where {N} = circularkeep(Array{Marginal, N}, capacity)
make_actor(x::AbstractArray{<:RandomVariable}, ::KeepEach, capacity::Integer) = circularkeep(typeof(similar(x, Marginal)), capacity)
make_actor(::AbstractVariable, ::KeepEach, capacity::Integer) = circularkeep(Marginal, capacity)
make_actor(::Array{<:AbstractVariable, N}, ::KeepEach, capacity::Integer) where {N} = circularkeep(Array{Marginal, N}, capacity)
make_actor(x::AbstractArray{<:AbstractVariable}, ::KeepEach, capacity::Integer) = circularkeep(typeof(similar(x, Marginal)), capacity)

make_actor(::RandomVariable, ::KeepLast) = storage(Marginal)
make_actor(x::AbstractArray{<:RandomVariable}, ::KeepLast) = buffer(Marginal, size(x))
make_actor(::AbstractVariable, ::KeepLast) = storage(Marginal)
make_actor(x::AbstractArray{<:AbstractVariable}, ::KeepLast) = buffer(Marginal, size(x))

make_actor(::RandomVariable, ::KeepLast, capacity::Integer) = storage(Marginal)
make_actor(x::AbstractArray{<:RandomVariable}, ::KeepLast, capacity::Integer) = buffer(Marginal, size(x))
make_actor(::AbstractVariable, ::KeepLast, capacity::Integer) = storage(Marginal)
make_actor(x::AbstractArray{<:AbstractVariable}, ::KeepLast, capacity::Integer) = buffer(Marginal, size(x))

## Inference ensure update

Expand Down Expand Up @@ -173,15 +176,16 @@ This structure is used as a return value from the [`inference`](@ref) function.

See also: [`inference`](@ref)
"""
struct InferenceResult{P, F, M, R}
struct InferenceResult{P, A, F, M, R}
posteriors :: P
predictions :: A
free_energy :: F
model :: M
returnval :: R
end

Base.iterate(results::InferenceResult) = iterate((getfield(results, :posteriors), getfield(results, :free_energy), getfield(results, :model), getfield(results, :returnval)))
Base.iterate(results::InferenceResult, any) = iterate((getfield(results, :posteriors), getfield(results, :free_energy), getfield(results, :model), getfield(results, :returnval)), any)
Base.iterate(results::InferenceResult) = iterate((getfield(results, :posteriors), getfield(results, :predictions), getfield(results, :free_energy), getfield(results, :model), getfield(results, :returnval)))
Base.iterate(results::InferenceResult, any) = iterate((getfield(results, :posteriors), getfield(results, :predictions), getfield(results, :free_energy), getfield(results, :model), getfield(results, :returnval)), any)

function Base.show(io::IO, result::InferenceResult)
print(io, "Inference results:\n")
Expand All @@ -193,6 +197,13 @@ function Base.show(io::IO, result::InferenceResult)
join(io, keys(getfield(result, :posteriors)), ", ")
print(io, ")\n")

if !isnothing(getfield(result, :predictions))
print(io, rpad(" Predictions", lcolumnlen), " | ")
print(io, "available for (")
join(io, keys(getfield(result, :predictions)), ", ")
print(io, ")\n")
end

if !isnothing(getfield(result, :free_energy))
print(io, rpad(" Free Energy:", lcolumnlen), " | ")
print(IOContext(io, :compact => true, :limit => true, :displaysize => (1, 80)), result.free_energy)
Expand Down Expand Up @@ -431,6 +442,8 @@ function inference(;
model::ModelGenerator,
# NamedTuple or Dict with data, required
data,
# NamedTuple or Dict with predictions
predictions = nothing, # optional
# NamedTuple or Dict with initial marginals, optional, defaults to empty
initmarginals = nothing,
# NamedTuple or Dict with initial messages, optional, defaults to empty
Expand All @@ -443,6 +456,8 @@ function inference(;
options = nothing,
# Return structure info, optional, defaults to return everything at each iteration
returnvars = nothing,
# Return structure info, optional, defaults to return everything at each iteration
albertpod marked this conversation as resolved.
Show resolved Hide resolved
predictvars = nothing,
# Number of iterations, defaults to 1, we do not distinguish between VMP or Loopy belief or EP iterations
iterations = nothing,
# Do we compute FE, optional, defaults to false
Expand All @@ -462,6 +477,7 @@ function inference(;
warn = true
)
__inference_check_dicttype(:data, data)
__inference_check_dicttype(:predictions, predictions)
__inference_check_dicttype(:initmarginals, initmarginals)
__inference_check_dicttype(:initmessages, initmessages)
__inference_check_dicttype(:callbacks, callbacks)
Expand Down Expand Up @@ -508,7 +524,13 @@ function inference(;
returnvars = Dict(variable => returnoption for (variable, value) in pairs(vardict) if (israndom(value) && !isanonymous(value)))
end

# Check if `predictvars` is nothing but `data` has missing values
if predictvars === nothing
predictvars = Dict(variable => KeepLast() for (variable, value) in pairs(vardict) if (isdata(value) && !isempty(findall(ismissing, data[variable])) && !isanonymous(value)))
bvdmitri marked this conversation as resolved.
Show resolved Hide resolved
end

__inference_check_dicttype(:returnvars, returnvars)
__inference_check_dicttype(:predictvars, predictvars)

# Use `__check_has_randomvar` to filter out unknown or non-random variables in the `returnvar` specification
__check_has_randomvar(vardict, variable) = begin
Expand All @@ -522,19 +544,33 @@ function inference(;
return haskey_check && israndom_check
end

__check_has_prediction(vardict, variable) = begin
haskey_check = haskey(vardict, variable)
isdata_check = haskey_check ? isdata(vardict[variable]) : false
if warn && !haskey_check
@warn "`predictvars` object has `$(variable)` specification, but model has no variable named `$(variable)`. The `$(variable)` specification is ignored. Use `warn = false` to suppress this warning."
elseif warn && haskey_check && !isdata_check
@warn "`predictvars` object has `$(variable)` specification, but model has no **data** variable named `$(variable)`. The `$(variable)` specification is ignored. Use `warn = false` to suppress this warning."
end
return haskey_check && isdata_check
end

# Second, for each random variable entry we create an actor
actors = Dict(variable => make_actor(vardict[variable], value) for (variable, value) in pairs(returnvars) if __check_has_randomvar(vardict, variable))
actors_rv = Dict(variable => make_actor(vardict[variable], value) for (variable, value) in pairs(returnvars) if __check_has_randomvar(vardict, variable))
actors_pr = Dict(variable => make_actor(vardict[variable], value) for (variable, value) in pairs(predictvars) if __check_has_prediction(vardict, variable))

# At third, for each random variable entry we create a boolean flag to track their updates
updates = Dict(variable => MarginalHasBeenUpdated(false) for (variable, _) in pairs(actors))
updates = Dict(variable => MarginalHasBeenUpdated(false) for (variable, _) in pairs(merge(actors_rv, actors_pr)))
# updates = Dict(variable => MarginalHasBeenUpdated(false) for (variable, _) in pairs(actors_pr))

_iterations = something(iterations, 1)
_iterations isa Integer || error("`iterations` argument must be of type Integer or `nothing`")
_iterations > 0 || error("`iterations` arguments must be greater than zero")

try
on_marginal_update = inference_get_callback(callbacks, :on_marginal_update)
subscriptions = Dict(variable => subscribe!(obtain_marginal(vardict[variable]) |> ensure_update(fmodel, on_marginal_update, variable, updates[variable]), actor) for (variable, actor) in pairs(actors))
subscriptions_rv = Dict(variable => subscribe!(obtain_marginal(vardict[variable]) |> ensure_update(fmodel, on_marginal_update, variable, updates[variable]), actor) for (variable, actor) in pairs(actors_rv))
subscriptions_pr = Dict(variable => subscribe!(obtain_prediction(vardict[variable]) |> ensure_update(fmodel, on_marginal_update, variable, updates[variable]), actor) for (variable, actor) in pairs(actors_pr))

fe_actor = nothing
fe_subscription = VoidTeardown()
Expand Down Expand Up @@ -607,7 +643,7 @@ function inference(;
inference_invoke_callback(callbacks, :after_iteration, fmodel, iteration)
end

for (_, subscription) in pairs(subscriptions)
for (_, subscription) in pairs(merge(subscriptions_pr, subscriptions_rv))
unsubscribe!(subscription)
end

Expand All @@ -617,12 +653,13 @@ function inference(;

unsubscribe!(fe_subscription)

posterior_values = Dict(variable => __inference_postprocess(postprocess, getvalues(actor)) for (variable, actor) in pairs(actors))
posterior_values = Dict(variable => __inference_postprocess(postprocess, getvalues(actor)) for (variable, actor) in pairs(actors_rv))
predicted_values = Dict(variable => __inference_postprocess(postprocess, getvalues(actor)) for (variable, actor) in pairs(actors_pr))
fe_values = !isnothing(fe_actor) ? score_snapshot_iterations(fe_actor) : nothing

inference_invoke_callback(callbacks, :after_inference, fmodel)

return InferenceResult(posterior_values, fe_values, fmodel, freturval)
return InferenceResult(posterior_values, predicted_values, fe_values, fmodel, freturval)
catch error
__inference_process_error(error)
end
Expand Down
1 change: 1 addition & 0 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ function ReactiveMP.activate!(model::FactorGraphModel)

filter!(c -> isconnected(c), getconstant(model))
foreach(r -> activate!(r, options), getrandom(model))
foreach(d -> activate!(d, options), getdata(model))
foreach(n -> activate!(n, options), getnodes(model))
end

Expand Down