-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add custom node contraction #342
Conversation
Hi @blolt . Thanks for this PR! As far as I can see you're 90% along the way, and are definitely in the right direction. The last puzzle piece I would say is that RxInfer.jl/src/model/graphppl.jl Lines 162 to 165 in 858000c
which for now defaults to the default backend implementation (which checks that if something is made with the @model macro it is Composite and otherwise it is Atomic . Probably if we make a separate dispatch on ReactiveMPGraphPPLBackend{True} and return GraphPPL.Atomic() the rest of your PR works perfectly. I think the usage of Static is appropriate.
|
Was it not the intention to let the user define custom node logic for a given backend? This was the example in the original issue,
which uses a user defined node |
Ah, you are right. I will check to see if the current implementation works this way |
Thanks Wouter! Just linking a recent paper with some rules on message-passing for HGF as I know one of the goals of this issue is to produce documentation with custom rules. Perhaps this will prove useful. Interesting paper in its own right, too. |
@blolt its a very cool feature, it has been a very busy period in BIASlab, but we didn't forget about your contribution! @wouterwln we should review this before next update meeting on September 18th |
Hi @blolt, I checked your PR, I would advise you to test the functionality in a simpler (sub-)model; there is a reason we like to contract the @model function submodel(x, z, y)
p ~ NormalMeanVariance(0, 1)
y ~ NormalMeanVariance(x + z + p, 1)
end
@model function larger_model(y)
x ~ NormalMeanVariance(0, 1)
z ~ NormalMeanVariance(0, 1)
y ~ submodel(x = x, z = z)
end This still won't work, and I think it is a bug on our part. As far as I understand, we can label the node as being In any case, I would ask you to rewrite the tests with a simpler example, and I think the rest of your PR works as intended, as I was able to isolate this issue quite easily and the rest of the behaviour was as expected. Thank you very much! |
Thanks for your work, @blolt! I reviewed the PR and made a few adjustments, particularly to the API. Now, a user only needs to define the node using the @model function gcv(y, x, z, κ, ω)
log_σ := κ * z + ω
σ := exp(log_σ)
y ~ Normal(mean = x, precision = σ)
end
@node typeof(gcv) Stochastic [ y, x, z, κ, ω ] With I also updated the tests, but I noticed that inference isn't running for the multi-layer HGF. @wouterwln, could you take a look? The single-layer HGF works perfectly. |
|
Unfortunately, the 2-layer HGF also breaks with the existing GCV node. The problem is with the initialization. The problem occurs around the 'topmost' GCV node, which is surrounded by |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #342 +/- ##
==========================================
+ Coverage 84.84% 84.94% +0.09%
==========================================
Files 20 20
Lines 1511 1521 +10
==========================================
+ Hits 1282 1292 +10
Misses 229 229 ☔ View full report in Codecov by Sentry. |
@bvdmitri I think the PR is in a shape to merge now |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice job @blolt !! Thank you for the help and reviewing @wouterwln ! The code can be merged as is, I agree, but I think we should add a documentation section in docs/src/manuals/inference/node_contraction.md
akin to other manuals in the same folder. Let me know if you have time to work on it, otherwise I can also spend some time on it next week.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great job! Thanks for the contribution!
Overview
Note: This is a re-opening of a previous PR, which I inadvertently closed.
This PR resolves #287 (or it will), allowing the user to specify custom composite node contraction through an argument on the infer function and an implementation of
GraphPPL.NodeType
.Testing
Three unit tests have been created to test this functionality. The first is a simple test to confirm that the RxInfer backend can be parameterized, and the next two test the infer function on an HGF with contracted nodes. The test
"Static Inference With Node Contraction"
is currently failing with the following error:However, all keyword arguments are being used for the
gcv
macro, so it seems likely that the issue lies elsewhere. I have tried different approaches, but none so far have bore fruit (a passing hgf unit test).The test
"Static Inference With Node Contraction 1"
was extended from RxInfer doc examples of an HGF and ReactiveMP's GCV node. It was originally passing with theGCV
node, but I have not been able to get it to pass with the customgcv
model I introduced. Rules are not yet specified for the custom gcv node. The test is currently failing due to incorrect constraints.