Add custom node contraction #342

Merged: 9 commits, Sep 30, 2024

Contributor

@blolt blolt commented Aug 4, 2024

Overview

Note: This is a re-opening of a previous PR, which I inadvertently closed.

This PR resolves #287 (or it will), allowing the user to specify custom composite node contraction through an argument on the infer function and an implementation of GraphPPL.NodeType.

Testing

Three unit tests have been created to test this functionality. The first is a simple test to confirm that the RxInfer backend can be parameterized, and the next two test the infer function on an HGF with contracted nodes. The test "Static Inference With Node Contraction" is currently failing with the following error:

ERROR: The `gcv` model macro does not support positional arguments. Use keyword arguments `gcv(κ = ..., ω = ..., z = ..., x = ..., y = ...)` instead.

However, the gcv macro is already called with keyword arguments only, so the issue likely lies elsewhere. I have tried several approaches, but none has borne fruit so far (that is, produced a passing HGF unit test).

The test "Static Inference With Node Contraction 1" was extended from RxInfer doc examples of an HGF and ReactiveMP's GCV node. It was originally passing with the GCV node, but I have not been able to get it to pass with the custom gcv model I introduced. Rules are not yet specified for the custom gcv node. The test is currently failing due to incorrect constraints.

@wouterwln
Member

Hi @blolt. Thanks for this PR! As far as I can see, you're 90% of the way there and definitely heading in the right direction. The last puzzle piece, I would say, is that GraphPPL has to realize that with a ReactiveMPGraphPPLBackend{True}, every Composite node turns into an Atomic node. This could be realized in:

function GraphPPL.NodeType(::ReactiveMPGraphPPLBackend, something::F) where {F}
    # Fallback to the default behaviour
    return GraphPPL.NodeType(GraphPPL.DefaultBackend(), something)
end

which for now falls back to the default backend implementation (which checks whether something was made with the @model macro: if so it is Composite, otherwise it is Atomic). If we add a separate dispatch on ReactiveMPGraphPPLBackend{True} that returns GraphPPL.Atomic(), the rest of your PR will probably work perfectly. I think the usage of Static is appropriate.
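
A minimal sketch of the dispatch pattern described above, written with stand-in types so that it runs without GraphPPL or ReactiveMP. Every name here (`ToyBackend`, `ContractionOn`, `ContractionOff`, `node_type`, `Atomic`, `Composite`, `made_with_model_macro`) is hypothetical and chosen for illustration; none of them is the libraries' actual API.

```julia
# Stand-ins for the two node kinds (hypothetical, for illustration only).
struct Atomic end
struct Composite end

# Type-level flags playing the role of the backend's True/False parameter.
struct ContractionOn end
struct ContractionOff end

# A toy backend parameterized by the contraction flag.
struct ToyBackend{C} end

# Pretend `submodel` was created with `@model` and `plain_node` was not.
submodel() = nothing
plain_node() = nothing
made_with_model_macro(f) = f === submodel

# Fallback: default behaviour, Composite for submodels and Atomic otherwise.
node_type(::ToyBackend, f) = made_with_model_macro(f) ? Composite() : Atomic()

# More specific method: with contraction enabled, every node is Atomic.
node_type(::ToyBackend{ContractionOn}, f) = Atomic()
```

Because the second method is more specific in its first argument, Julia selects it whenever the backend carries the `ContractionOn` flag, and the fallback handles every other backend.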

@blolt
Contributor Author

blolt commented Aug 26, 2024

Was it not the intention to let the user define custom node logic for a given backend? This was the example in the original issue,

GraphPPL.NodeType(::RxInferBackend{True}, ::typeof(gcv)) = GraphPPL.Atomic()

which uses a user-defined node gcv. Then again, I don't see much use for the Static parameter here, since the node type should be enough for dispatch. In any case, I just want to clarify whether contracting all composite nodes, rather than just particular kinds of composite nodes, is the intended use case.

@wouterwln
Member

Ah, you are right. I will check whether the current implementation works this way.

@blolt
Contributor Author

blolt commented Sep 9, 2024

Thanks Wouter! Just linking a recent paper with some rules on message-passing for HGF as I know one of the goals of this issue is to produce documentation with custom rules. Perhaps this will prove useful. Interesting paper in its own right, too.

https://arxiv.org/pdf/2305.10937

@bvdmitri
Member

@blolt it's a very cool feature. It has been a very busy period in BIASlab, but we didn't forget about your contribution! @wouterwln we should review this before the next update meeting on September 18th.

@wouterwln wouterwln self-assigned this Sep 17, 2024
@wouterwln
Member

Hi @blolt, I checked your PR, I would advise you to test the functionality in a simpler (sub-)model; there is a reason we like to contract the gcv submodel into a node, which is because it is actually really hard to run inference inside of it. I tried recreating some of the results with the following simpler model:

@model function submodel(x, z, y)
    p ~ NormalMeanVariance(0, 1)
    y ~ NormalMeanVariance(x + z + p, 1)
end

@model function larger_model(y)
    x ~ NormalMeanVariance(0, 1)
    z ~ NormalMeanVariance(0, 1)
    y ~ submodel(x = x, z = z)
end

This still won't work, and I think it is a bug on our part. As far as I understand, we can label the node as being Atomic, and GraphPPL will label it Stochastic as well, as is default for submodels. However, ReactiveMP.sdtype() will still be Deterministic, which prompts a DeltaMeta meta object which shouldn't happen. @bvdmitri any idea how to fix this? Is there a reason GraphPPL and ReactiveMP's implementation of Stochastic and Deterministic are mixed?

In any case, I would ask you to rewrite the tests with a simpler example, and I think the rest of your PR works as intended, as I was able to isolate this issue quite easily and the rest of the behaviour was as expected. Thank you very much!

@bvdmitri
Member

Thanks for your work, @blolt! I reviewed the PR and made a few adjustments, particularly to the API. Now, a user only needs to define the node using the @node macro for the contraction to work. For example:

@model function gcv(y, x, z, κ, ω)
    log_σ := κ * z + ω
    σ := exp(log_σ)
    y ~ Normal(mean = x, precision = σ)
end

@node typeof(gcv) Stochastic [ y, x, z, κ, ω ]

With allow_node_contraction = true, this should be enough to automatically identify the corresponding node and it will try to use the rules instead (after this PR is merged we need to redefine the GCV node as a submodel I guess).

I also updated the tests, but I noticed that inference isn't running for the multi-layer HGF. @wouterwln, could you take a look? The single-layer HGF works perfectly.

@bvdmitri
Member

  • we need to add the documentation section with examples

@wouterwln
Member

Unfortunately, the 2-layer HGF also breaks with the existing GCV node. The problem is with the initialization. The problem occurs around the 'topmost' GCV node, which is surrounded by x_2[i], x_2[i - 1], x_3[i], κ_2 and ω_2. Now, because of the initialization bug we found over the past couple of weeks, we can fire off all existing rules, using up all the initialized marginals, which sends the message passing algorithm into a deadlock. We can fix this by initializing not every x_2, but every other x_2 (so x_2[1:2:n]). This makes the inference run, and also proves that (luckily for this PR) the problem is not with the node contraction, and node contraction works perfectly, but (unfortunately for us) this weird behavior is due to some other bug. I'll push the fix and then everything should work as is.
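
The "every other x_2" selection mentioned above can be illustrated with a plain Julia range. This sketch only shows which indices `1:2:n` picks out; `n = 7` is an arbitrary value chosen for illustration, and no RxInfer initialization API is used or implied.

```julia
n = 7                                  # number of time steps (arbitrary, for illustration)

# Indices of the x_2 variables that would receive an initial marginal.
initialized = collect(1:2:n)

# The remaining indices are left uninitialized.
skipped = setdiff(1:n, initialized)

println("initialized: ", initialized)
println("skipped:     ", skipped)
```

With an odd `n`, both the first and the last index fall in the initialized set; with an even `n`, the last index is skipped.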


codecov bot commented Sep 18, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 84.94%. Comparing base (9facd12) to head (f5150d9).
Report is 10 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #342      +/-   ##
==========================================
+ Coverage   84.84%   84.94%   +0.09%     
==========================================
  Files          20       20              
  Lines        1511     1521      +10     
==========================================
+ Hits         1282     1292      +10     
  Misses        229      229              


@wouterwln wouterwln marked this pull request as ready for review September 25, 2024 12:33
@wouterwln
Member

@bvdmitri I think the PR is in a shape to merge now

Member

@bvdmitri bvdmitri left a comment

Nice job @blolt!! Thank you for the help and for reviewing, @wouterwln! I agree that the code can be merged as is, but I think we should add a documentation section in docs/src/manuals/inference/node_contraction.md, akin to the other manuals in the same folder. Let me know if you have time to work on it; otherwise I can spend some time on it next week.

Member

@bvdmitri bvdmitri left a comment

Great job! Thanks for the contribution!

@wouterwln wouterwln merged commit fc028f2 into ReactiveBayes:main Sep 30, 2024
7 checks passed
Development

Successfully merging this pull request may close these issues.

Allow contraction of composite nodes
3 participants