diff --git a/src/message.jl b/src/message.jl index cb6a48ac7..64c85e7c6 100644 --- a/src/message.jl +++ b/src/message.jl @@ -323,18 +323,24 @@ function materialize!(mapping::MessageMapping, dependencies) # Message is initial if it is not clamped and all of the inputs are either clamped or initial is_message_initial = !is_message_clamped && (__check_all(is_clamped_or_initial, messages) && __check_all(is_clamped_or_initial, marginals)) - result, addons = rule( - message_mapping_fform(mapping), - mapping.vtag, - mapping.vconstraint, - mapping.msgs_names, - messages, - mapping.marginals_names, - marginals, - mapping.meta, - mapping.addons, - mapping.factornode - ) + result, addons = if !isnothing(messages) && any(ismissing, TupleTools.flatten(getdata.(messages))) + missing, mapping.addons + elseif !isnothing(marginals) && any(ismissing, TupleTools.flatten(getdata.(marginals))) + missing, mapping.addons + else + rule( + message_mapping_fform(mapping), + mapping.vtag, + mapping.vconstraint, + mapping.msgs_names, + messages, + mapping.marginals_names, + marginals, + mapping.meta, + mapping.addons, + mapping.factornode + ) + end # Inject extra addons after the rule has been executed addons = message_mapping_addons(mapping, getdata(messages), getdata(marginals), result, addons) diff --git a/src/variables/data.jl b/src/variables/data.jl index dac060e9e..3ab823c4f 100644 --- a/src/variables/data.jl +++ b/src/variables/data.jl @@ -5,6 +5,8 @@ import Base: show mutable struct DataVariable{D, S} <: AbstractVariable name :: Symbol collection_type :: AbstractVariableCollectionType + prediction :: MarginalObservable + input_messages :: Vector{MessageObservable{AbstractMessage}} messageout :: S nconnected :: Int end @@ -70,7 +72,7 @@ datavar(name::Symbol, ::Type{D}, dims::Tuple) where {D} datavar(name::Symbol, ::Type{D}, dims::Vararg{Int}) where {D} = datavar(DataVariableCreationOptions(D), name, D, dims) datavar(options::DataVariableCreationOptions{S}, name::Symbol, ::Type{D}, collection_type::AbstractVariableCollectionType = VariableIndividual()) where {S, D} = - DataVariable{D, S}(name, collection_type, options.subject, 0) + DataVariable{D, S}(name, collection_type, MarginalObservable(), Vector{MessageObservable{AbstractMessage}}(), options.subject, 0) function datavar(options::DataVariableCreationOptions, name::Symbol, ::Type{D}, length::Int) where {D} return map(i -> datavar(similar(options), name, D, VariableVector(i)), 1:length) @@ -165,5 +167,21 @@ setanonymous!(::DataVariable, ::Bool) = nothing function setmessagein!(datavar::DataVariable, ::Int, messagein) datavar.nconnected += 1 + push!(datavar.input_messages, messagein) return nothing end + +marginal_prod_fn(datavar::DataVariable) = + marginal_prod_fn(FoldLeftProdStrategy(), ProdAnalytical(), UnspecifiedFormConstraint(), FormConstraintCheckLast()) + +_getprediction(datavar::DataVariable) = datavar.prediction +_setprediction!(datavar::DataVariable, observable) = connect!(_getprediction(datavar), observable) +_makeprediction(datavar::DataVariable) = collectLatest(AbstractMessage, Marginal, datavar.input_messages, marginal_prod_fn(datavar)) + +# options here must implement at least `Rocket.getscheduler` +function activate!(datavar::DataVariable, options) + + _setprediction!(datavar, _makeprediction(datavar)) + + return nothing +end \ No newline at end of file diff --git a/src/variables/random.jl b/src/variables/random.jl index 900addc00..919da8571 100644 --- a/src/variables/random.jl +++ b/src/variables/random.jl @@ -291,4 +291,4 @@ function initialize_output_messages!(chain::EqualityChain, randomvar::RandomVari randomvar.output_initialised = true return nothing -end +end \ No newline at end of file diff --git a/src/variables/variable.jl b/src/variables/variable.jl index 0dc48ea36..5d2be4800 100644 --- a/src/variables/variable.jl +++ b/src/variables/variable.jl @@ -1,7 +1,7 @@ export AbstractVariable, degree export is_clamped, is_marginalisation, is_moment_matching export FoldLeftProdStrategy, FoldRightProdStrategy, CustomProdStrategy -export getmarginal, getmarginals, setmarginal!, setmarginals!, name, as_variable +export getprediction, getpredictions, getmarginal, getmarginals, setmarginal!, setmarginals!, name, as_variable export setmessage!, setmessages! using Rocket @@ -80,6 +80,9 @@ add_pipeline_stage!(variable::AbstractVariable, stage) = error("Its not possible # Helper functions # Getters +getprediction(variable::AbstractVariable) = _getprediction(variable) +getpredictions(variables::AbstractArray{<:AbstractVariable}) = collectLatest(map(v -> getprediction(v), variables)) + getmarginal(variable::AbstractVariable) = getmarginal(variable, SkipInitial()) getmarginal(variable::AbstractVariable, skip_strategy::MarginalSkipStrategy) = apply_skip_filter(_getmarginal(variable), skip_strategy)