In particular I'm trying to generalise it so that we don't assume we know the transition matrix $A$ a priori (eventually, I'd like to get to a scenario where we learn most of those matrices). In the example, $A$ is passed as an input and then converted to a `constvar` to be used in the model. I would like it to be given a prior in the model and then inferred together with the rest of the parameters.
I tried a couple of things, but I can't make it work. A minimal reproducible example consists of redefining the model and the inference call as follows:
```julia
@model function rotate_ssm(n, x0, B, Q, P)
    # We create constvar references for better efficiency
    cB = constvar(B)
    cQ = constvar(Q)
    cP = constvar(P)

    # `x` is a sequence of hidden states
    x = randomvar(n)

    # THIS IS WHERE I DIVERGE FROM EXAMPLE:
    A = randomvar(2, 2)

    # `y` is a sequence of "clamped" observations
    y = datavar(Vector{Float64}, n)

    x_prior ~ MvNormalMeanCovariance(mean(x0), cov(x0))
    x_prev = x_prior

    α ~ Uniform(0.0, 3.0)
    β ~ Uniform(0.0, 3.0)
    A .~ Gamma(α, β)

    for i in 1:n
        x[i] ~ MvNormalMeanCovariance(A * x_prev, cQ)
        y[i] ~ MvNormalMeanCovariance(cB * x[i], cP)
        x_prev = x[i]
    end
end
```
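Written out as equations, the model this snippet is aiming for is roughly the following. This assumes $x_0$ stands for `x_prior` (with $\mu_0$ = `mean(x0)` and $\Sigma_0$ = `cov(x0)`) and reads `Gamma(α, β)` in Distributions.jl's shape/scale convention; both are interpretations, not statements from the original post.

$$
\begin{aligned}
x_0 &\sim \mathcal{N}(\mu_0, \Sigma_0), \\
\alpha, \beta &\sim \mathcal{U}(0, 3), \\
A_{jk} &\sim \operatorname{Gamma}(\alpha, \beta), \qquad j, k \in \{1, 2\}, \\
x_i &\sim \mathcal{N}(A x_{i-1}, Q), \qquad i = 1, \dots, n, \\
y_i &\sim \mathcal{N}(B x_i, P).
\end{aligned}
$$

The third line is the part expressed with broadcasting (`A .~ ...`), and the first two lines make the prior hierarchical.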
And I changed the `inference` call accordingly:

```julia
result = inference(
    model = rotate_ssm(length(y), x0, B, Q, P),
    data = (y = y,),
    free_energy = true
)
```
I tried to specify $A$ in a number of ways, but never had any luck.
Hi @gvdr! Thanks for trying out RxInfer.jl. Long story short, broadcasting (e.g. `A .~ Gamma(α, β)`) isn't supported in RxInfer.jl.
The reason is mainly the intricacies of graph construction and inference. I also see you are trying to build a hierarchical prior by putting a Uniform prior on the parameters of the Gamma. This is currently not supported out of the box either.
One way to circumvent the current problem is somewhat similar to what was suggested in this discussion: #156
Check the docs for the Delta node as well:
DISCLAIMER: (1) the code below looks ugly, but it is the only way to place a Gamma prior on each element of your transition matrix. (2) CVI will be slow, sensitive to its hyperparameters, and inaccurate in this case because of the mean-field constraint. I don't think the inference will be accurate, so sampling-based toolboxes (such as Turing or NumPyro) could make the inference faster.
```julia
# Assuming your matrix A has 4 elements.
# Ugly, I know.
function f(x, a, b, c, d)
    A = [a b; c d]
    A * x
end

@model function rotate_ssm(n, x0, B, Q, P)
    # We create constvar references for better efficiency
    cB = constvar(B)
    cQ = constvar(Q)
    cP = constvar(P)

    # `x` is a sequence of hidden states
    x = randomvar(n)
    x̂ = randomvar(n)

    # `y` is a sequence of "clamped" observations
    y = datavar(Vector{Float64}, n)

    x_prior ~ MvNormalMeanCovariance(mean(x0), cov(x0))
    x_prev = x_prior

    # One Gamma-distributed variable per entry of A (length(Q) == 4 for a 2×2 Q)
    a = randomvar(length(Q))
    for i in 1:length(Q)
        a[i] ~ Gamma(α = 1.0, β = 1.0)
    end

    for i in 1:n
        # Here you'd want to do x̂[i] ~ f(x_prev, a), but we can't do that yet (it's coming),
        # so we do it the ugly way
        x̂[i] ~ f(x_prev, a[1], a[2], a[3], a[4])
        x[i] ~ MvNormalMeanCovariance(x̂[i], cQ)
        y[i] ~ MvNormalMeanCovariance(cB * x[i], cP)
        x_prev = x[i]
    end
end
```
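Because the workaround spreads $A$ over four scalar Gamma variables, it helps to check which entry of $A$ each `a[i]` ends up in. The plain-Julia sketch below needs no RxInfer; it reuses `f` from above with hypothetical numbers and contrasts it with `reshape`, which fills a matrix column by column.

```julia
# Same helper as in the workaround: builds A row by row from four scalars,
# so f(x, a[1], a[2], a[3], a[4]) uses A = [a[1] a[2]; a[3] a[4]].
function f(x, a, b, c, d)
    A = [a b; c d]
    return A * x
end

# Hypothetical values, chosen only to make the ordering easy to see.
avals = [1.0, 2.0, 3.0, 4.0]
x = [1.0, 1.0]

# A 2×2 matrix has length 4, which is why `randomvar(length(Q))`
# yields one variable per entry of A when Q is 2×2.
@show length(zeros(2, 2))

y_rowwise = f(x, avals...)               # A = [1 2; 3 4] → [3.0, 7.0]
y_colmajor = reshape(avals, 2, 2) * x    # A = [1 3; 2 4] → [4.0, 6.0]

# permutedims(reshape(...)) recovers the row-by-row layout used by f
A_rowwise = permutedims(reshape(avals, 2, 2))

println(y_rowwise)
println(y_colmajor)
println(A_rowwise * x == y_rowwise)
```

Julia arrays are column-major, so whichever mapping from `a` to `A` you pick, use the same convention when reading the inferred posteriors of `a[1]` to `a[4]` back into a matrix.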
```julia
using StableRNGs, Optimisers  # provide StableRNG and Optimisers.Descent used by CVI

delta_meta = @meta begin
    f() -> CVI(StableRNG(42), 100, 200, Optimisers.Descent(0.01))
end

x0 = MvNormalMeanCovariance(zeros(2), 1.0 * diageye(2))
result = inference(model = rotate_ssm(length(y), x0, B, Q, P), options = (limit_stack_depth = 500,), constraints = MeanField(),
    initmarginals = (x = MvNormalMeanCovariance(zeros(2), 1e4 * diageye(2)), x̂ = MvNormalMeanCovariance(zeros(2), 1e4 * diageye(2))),
    initmessages = (a = GammaShapeRate(1e-2, 1e2),),
    meta = delta_meta, data = (y = y,), free_energy = true, showprogress = true, iterations = 5, returnvars = KeepLast())
```
Initial marginals are needed because of the mean-field constraint; the initial message for `a` is needed by the CVI approximation method.
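To make the role of the initial marginals more concrete: assuming `MeanField()` fully factorizes the approximate posterior over every random variable in the model above (as the name suggests), the distribution being fitted has the form below, with $x_0$ denoting `x_prior`. Each factor is updated using the current estimates of the others, so the factors for `x` and `x̂` need a starting value before the first iteration, which is what `initmarginals` supplies.

$$
q\left(x_0, x_{1:n}, \hat{x}_{1:n}, a_{1:4}\right) = q(x_0) \prod_{i=1}^{n} q(x_i)\, q(\hat{x}_i) \prod_{j=1}^{4} q(a_j)
$$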
I will convert this issue into a discussion. I will also create two issues associated with your question: (1) throwing an error when broadcasting is used, (2) supporting broadcasting (this will take some time).
There's always the option of creating your own node and the associated message-passing rules for it.
To get a better idea of what RxInfer.jl can or cannot support, an understanding of Forney-style factor graphs helps significantly.
Hello. I'm trying to extend one of the examples but hitting a wall.
I'm building upon the Gaussian Linear Dynamical System example.
When I try to run it, I get the following error: