Testing use of rrule #1103

Open
willtebbutt opened this issue Nov 9, 2024 · 2 comments

@willtebbutt

willtebbutt commented Nov 9, 2024

Question❓

This package defines an extension for ChainRulesCore, in which an rrule for solve_up is defined:

function ChainRulesCore.rrule(::typeof(DiffEqBase.solve_up), prob::AbstractDEProblem,

I am attempting to make use of this rule in Mooncake (compintell/Mooncake.jl#320) so that Mooncake can differentiate functions which solve differential equations internally. I am also trying to test that I've implemented my wrapper around this rule correctly (see https://github.com/compintell/Mooncake.jl/pull/320/files#r1835360913). However, I have no idea whether my collection of tests is any good: at the moment I only check that the function build_and_solve (which solves the Lotka-Volterra equations) can be differentiated with Mooncake for a variety of sensealgs.

My question: how and where is the rrule for solve_up tested in this package, and is there an existing set of tests I could copy, or take inspiration from, to check that my wrapper for the rrule works in all cases?

@ChrisRackauckas
Member

So the difficulty here is that the vjp used in the adjoint rule of the differential equation is almost always (at this point) computed by a different AD engine from the one the user calls. Right now users tend to gravitate towards Zygote, and almost all ODE definitions are mutating, which Zygote cannot differentiate. So the user calls Zygote, it hits that chain rule, and the call is forwarded to SciMLSensitivity.jl, which defines the full adjoint and has the AD dependencies; it would then normally use Enzyme for the vjp. If the user hasn't loaded SciMLSensitivity.jl, they get an error: https://github.com/SciML/DiffEqBase.jl/blob/master/src/solve.jl#L1593-L1597

Because of that, the real test suite is in SciMLSensitivity.jl: this package just forwards the call there, and the rule is defined here so that we can give an informative error message.

Our current tests of "alternative AD frontends", i.e. non-Zygote ones, are in https://github.com/SciML/SciMLSensitivity.jl/blob/master/test/alternative_ad_frontend.jl. They're not exactly comprehensive, but the vast majority of the work in the package goes into the adjoint definitions rather than the AD front end, so the front-end plumbing has tended to just work. It's a bit of a mess that these tests don't live with DiffEqBase, but since most of the work is on the adjoints, the front-end parts in DiffEqBase don't change much. This situation will hopefully improve with JuliaLang/julia#55516, but I digress.

Looking at your set of front-end tests, one of the big things missing is exercising the solution interface a bit more. sum(sum(sol.u[end])) is one case, but you should also try things like sum(sol) and sum(sol[end]), and a higher-dimensional ODE with sum(sol[:,:,end]).
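A package-free sketch of the kind of test matrix described here: several losses that read the solution in different ways, each paired with a finite-difference reference gradient. Everything in it is a hypothetical stand-in, not DiffEqBase, SciMLSensitivity, or Mooncake API: `euler_states` replaces `solve` (its vector of states plays the role of `sol.u`), `cat(...; dims = 3)` plays the role of the array view behind `sol[:, :, end]`, and `lotka_volterra`, `loss_last`, `loss_all`, `loss_slice`, and `fd_gradient` are illustrative names. In a real test, the gradient computed through the wrapped `solve_up` rrule would presumably be compared against `fd_gradient` with `isapprox`.

```julia
# Hypothetical stand-in for `solve`: fixed-step forward Euler.
# `f(u, p)` is out-of-place and returns something shaped like `u`, so `u0` may be
# a vector or a matrix. The returned vector of states plays the role of `sol.u`.
function euler_states(f, u0, p; dt = 0.01, nsteps = 100)
    us = [copy(u0)]
    for _ in 1:nsteps
        push!(us, us[end] .+ dt .* f(us[end], p))
    end
    return us
end

# Lotka-Volterra right-hand side with p = (α, β, γ, δ).
function lotka_volterra(u::AbstractVector, p)
    α, β, γ, δ = p
    x, y = u[1], u[2]
    return [α * x - β * x * y, δ * x * y - γ * y]
end

# Matrix-valued state: each column is an independent predator-prey pair,
# giving a higher-dimensional ODE whose history can be sliced as [:, :, end].
lotka_volterra(u::AbstractMatrix, p) = reduce(hcat, [lotka_volterra(c, p) for c in eachcol(u)])

const U0_VEC = [1.0, 1.0]
const U0_MAT = [1.0 0.5; 1.0 2.0]

# Each loss reads the (stand-in) solution differently.
loss_last(p) = sum(sum(euler_states(lotka_volterra, U0_VEC, p)[end]))  # ~ sum(sum(sol.u[end]))
loss_all(p) = sum(sum, euler_states(lotka_volterra, U0_VEC, p))        # ~ sum(sol)
function loss_slice(p)                                                 # ~ sum(sol[:, :, end])
    sol = cat(euler_states(lotka_volterra, U0_MAT, p)...; dims = 3)    # 2 × 2 × (nsteps + 1)
    return sum(sol[:, :, end])
end

# Central-difference reference gradient. An AD gradient obtained through the
# wrapped rrule could be checked against this with `isapprox`.
function fd_gradient(loss, p; h = 1e-6)
    g = similar(p)
    for i in eachindex(p)
        e = zero(p)
        e[i] = h
        g[i] = (loss(p .+ e) - loss(p .- e)) / (2h)
    end
    return g
end

# Usage: compute a reference gradient for every access pattern.
p = [1.5, 1.0, 3.0, 1.0]
for loss in (loss_last, loss_all, loss_slice)
    println(nameof(loss), " => ", fd_gradient(loss, p))
end
```

The point of looping over several losses rather than one is that each access pattern (last state, whole history, slice of a matrix-valued state) can route the cotangent back through a different path, so a wrapper can pass one and still fail another.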

@willtebbutt
Author

Thanks for your help with this, @ChrisRackauckas.

I've now added more tests to the Mooncake PR, which I think cover everything you mentioned above. I've also replied to your comment on the PR with a couple of additional questions.
