This issue was moved to a discussion.

Matrix{RandomVariable} * RandomVariable errors #163

Closed
gvdr opened this issue Oct 15, 2023 · 1 comment

gvdr commented Oct 15, 2023

Hello. I'm trying to extend one of the examples but hitting a wall.

I'm building upon the Gaussian Linear Dynamical System example.

In particular, I'm trying to generalise it so that the transition matrix $A$ is not assumed to be known a priori (eventually, I'd like to get to a scenario where we learn most of those matrices). In the example, $A$ is passed as an input and then converted to a constvar to be used in the model. I would like to give it a prior in the model and infer it together with the rest of the parameters.

I tried a couple of things, but I can't make it work. A minimal reproducible example consists of redefining the model and the inference call as follows:

@model function rotate_ssm(n, x0, B, Q, P)
    
    # We create constvar references for better efficiency
    cB = constvar(B)
    cQ = constvar(Q)
    cP = constvar(P)
    
    # `x` is a sequence of hidden states
    x = randomvar(n)
    # THIS IS WHERE I DIVERGE FROM EXAMPLE:
    A = randomvar(2,2)
    # `y` is a sequence of "clamped" observations
    y = datavar(Vector{Float64}, n)
    
    x_prior ~ MvNormalMeanCovariance(mean(x0), cov(x0))
    x_prev = x_prior
    
    α ~ Uniform(0.0,3.0)
    β ~ Uniform(0.0,3.0)

    A .~ Gamma(α,β)
    
    for i in 1:n
        x[i] ~ MvNormalMeanCovariance(A*x_prev, cQ)
        y[i] ~ MvNormalMeanCovariance(cB * x[i], cP)
        x_prev = x[i]
    end

end

And I change the inference call accordingly:

result = inference(
    model = rotate_ssm(length(y), x0, B, Q, P), 
    data = (y = y,),
    free_energy = true
)

I tried to specify $A$ in a number of ways, but had no luck.

When I try to run it, I get the following error:

ERROR: MethodError: no method matching make_node(::typeof(*), ::FactorNodeCreationOptions{FullFactorisation, Nothing, Nothing}, ::RandomVariable, ::Matrix{RandomVariable}, ::RandomVariable)

@albertpod albertpod self-assigned this Oct 16, 2023
@albertpod albertpod added the good first issue Good for newcomers label Oct 16, 2023

albertpod (Member) commented Oct 16, 2023

Hi @gvdr! Thanks for trying out RxInfer.jl. Long story short, broadcasting isn't supported in RxInfer.jl.
The reason is mainly the intricacies of graph construction and inference. I see you are trying to build a hierarchical prior by putting a Uniform prior on the parameters of the Gamma. This is currently not supported out of the box either.

One way to circumvent the current problem is somewhat similar to what was suggested in this discussion:
#156
Check the docs for the Delta node as well.
DISCLAIMER: (1) the code below looks ugly, but this is the only way to enforce a Gamma prior on each element of your transition matrix. (2) CVI will be slow, hyperparameter-dependent and inaccurate because of the mean-field constraint in this case; I don't expect the inference to be accurate, so sampling-based toolboxes (such as Turing or NumPyro) could make the inference faster.

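# Packages needed by the snippets below (StableRNG and Optimisers.Descent are used in the meta block)
using RxInfer, StableRNGs, Optimisers

# assuming your matrix A has 4 elements.
# ugly, I know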
function f(x, a, b, c, d)
    A = [a b; c d]
    A*x
end

@model function rotate_ssm(n, x0, B, Q, P)
    
    # We create constvar references for better efficiency
    cB = constvar(B)
    cQ = constvar(Q)
    cP = constvar(P)
    
    # `x` is a sequence of hidden states
    x = randomvar(n)
    x̂ = randomvar(n)
    # `y` is a sequence of "clamped" observations
    y = datavar(Vector{Float64}, n)
    
    x_prior ~ MvNormalMeanCovariance(mean(x0), cov(x0))
    x_prev = x_prior
    
    a = randomvar(length(Q))
    
    for i in 1:length(Q)
        a[i] ~ Gamma(α=1.0, β=1.0)
    end

    for i in 1:n
        # here you'd want to write x̂[i] ~ f(x_prev, a); that isn't possible yet, but it's coming,
        # so we pass the elements one by one (the ugly way)
        x̂[i] ~ f(x_prev, a[1], a[2], a[3], a[4])
        x[i] ~ MvNormalMeanCovariance(x̂[i], cQ)
        y[i] ~ MvNormalMeanCovariance(cB * x[i], cP)
        x_prev = x[i]
    end

end

delta_meta = @meta begin 
    f() -> CVI(StableRNG(42), 100, 200, Optimisers.Descent(0.01))
end


x0 = MvNormalMeanCovariance(zeros(2), 1.0 * diageye(2))
result = inference(
    model = rotate_ssm(length(y), x0, B, Q, P),
    options = (limit_stack_depth = 500,),
    constraints = MeanField(),
    initmarginals = (
        x = MvNormalMeanCovariance(zeros(2), 1e4 * diageye(2)),
        x̂ = MvNormalMeanCovariance(zeros(2), 1e4 * diageye(2)),
    ),
    initmessages = (a = GammaShapeRate(1e-2, 1e2),),
    meta = delta_meta,
    data = (y = y,),
    free_energy = true,
    showprogress = true,
    iterations = 5,
    returnvars = KeepLast(),
)

Initial marginals are needed due to mean-field constraint; the init message is necessary for the CVI approximation method.
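
For anyone adapting the workaround to a different matrix layout, here is a minimal plain-Julia sketch (no RxInfer involved) of how the helper `f` maps the four scalars onto the matrix. The values of `x` and `a` are made up for illustration. The point to watch is that `[a b; c d]` fills the matrix row by row, whereas `reshape` fills column by column, so the order in which `a[1], ..., a[4]` are passed matters.

```julia
# Same helper as in the reply above: rebuild the 2×2 matrix from four scalars.
function f(x, a, b, c, d)
    A = [a b; c d]
    return A * x
end

# Illustrative values only (not taken from the original example).
x = [1.0, 2.0]
a = [0.5, 0.1, 0.2, 0.9]

# Passing a[1], a[2], a[3], a[4] in order fills the matrix row by row.
A_rowwise = [a[1] a[2]; a[3] a[4]]
@assert f(x, a...) == A_rowwise * x

# reshape fills column by column, so it needs a transpose to match.
@assert A_rowwise == permutedims(reshape(a, 2, 2))
```

If the posterior over `a` is later reshaped back into a matrix, the same row-by-row convention has to be used.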

I will convert this issue into a discussion. I will create two issues associated with your question: (1) throwing an error on broadcasting, (2) dealing with broadcasting (this will take some time).

There's of course always room for creating your own node and the associated rules.

To get a better idea of what RxInfer.jl can and cannot support, an understanding of Forney-style factor graphs helps significantly.

@ReactiveBayes ReactiveBayes locked and limited conversation to collaborators Oct 16, 2023
@albertpod albertpod converted this issue into discussion #164 Oct 16, 2023
