-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DiffEqBase Rule #320
Closed
Closed
DiffEqBase Rule #320
Changes from all commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
6b3e1b9
Initial work
willtebbutt 4897268
Merge in main
willtebbutt 80f412e
Move array rules around
willtebbutt a62ec96
Merge branch 'main' into wct/sciml-base-rule
willtebbutt c790f5a
Merge in main
willtebbutt 8d334b6
Fix up rrule with kwargs
willtebbutt 6cb61bc
Fix up merge
willtebbutt 52a2f9d
Improve error message very slightly
willtebbutt baa6e83
Improve test suite
willtebbutt 3132e69
Merge branch 'main' into wct/sciml-base-rule
willtebbutt 4353348
Merge branch 'main' into wct/sciml-base-rule
willtebbutt 6bc80ae
Merge branch 'wct/sciml-base-rule' of https://github.com/compintell/M…
willtebbutt 0ab46f1
Reformat
willtebbutt e80f338
Basics
willtebbutt 4b00150
Run new functionality in CI
willtebbutt ece9e2f
Merge in main
willtebbutt bce71fc
Included tests involving matrix-valued ODE
willtebbutt 5396567
Remove redundant rule
willtebbutt 45f5e50
Remove deprecated indexing
willtebbutt fb4f1dc
Remove FunctionWrapper type
willtebbutt 8dd2ac3
Merge branch 'main' into wct/sciml-base-rule
willtebbutt f501fc3
CRC interop
willtebbutt cbe23fe
Tidy up loading etc
willtebbutt 3d00548
Merge in main
willtebbutt daf2a79
Resolve merge conflict
willtebbutt c2323ca
Remove extraneous using
willtebbutt File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
module MooncakeDiffEqBaseExt | ||
|
||
using DiffEqBase, Mooncake | ||
|
||
Mooncake.@from_rrule( | ||
Mooncake.MinimalCtx, | ||
Tuple{ | ||
typeof(DiffEqBase.solve_up), | ||
DiffEqBase.AbstractDEProblem, | ||
Union{Nothing, DiffEqBase.AbstractSensitivityAlgorithm}, | ||
Any, | ||
Any, | ||
Any, | ||
}, | ||
true, | ||
) | ||
|
||
Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{typeof(DiffEqBase.numargs), Any} | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
[deps] | ||
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" | ||
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" | ||
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a" | ||
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" | ||
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
using Pkg | ||
Pkg.activate(@__DIR__) | ||
Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) | ||
|
||
using DiffEqBase, Mooncake, OrdinaryDiffEqTsit5, SciMLSensitivity, StableRNGs, Test | ||
using DiffEqBase: solve | ||
using Mooncake.TestUtils: test_rule | ||
|
||
function lotka_volterra!(du, u, p, t) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Start of DiffEqBase rule tests |
||
x, y = u | ||
α, β, δ, γ = p | ||
du[1] = dx = α * x - β * x * y | ||
du[2] = dy = -δ * y + γ * x * y | ||
end | ||
|
||
u0 = [1.0, 1.0] | ||
tspan = (0.0, 1.0) | ||
p = [1.5, 1.0, 4.0, 1.0] | ||
|
||
# Have to use remake with a const global in order to get good performance. | ||
const prob = ODEProblem(lotka_volterra!, u0, (0.0, 1.0), p) | ||
|
||
function build_and_solve(u0, tspan, p, sensealg) | ||
_prob = remake(prob; u0, p) | ||
sol = solve(_prob, Tsit5(); abstol=1e-14, reltol=1e-14, sensealg, saveat=0.01) | ||
return sum(sol) + sum(sum(sol.u[end])) | ||
end | ||
|
||
function matrix_ode!(du, u, p, t) | ||
du .= reshape(p, 4, 4) * u | ||
return nothing | ||
end | ||
|
||
u0_mat = rand(4, 8) | ||
p_mat = rand(16) | ||
|
||
# Have to use remake with a const global in order to get good performance. | ||
const matrix_prob = ODEProblem(matrix_ode!, u0_mat, (0.0, 1.0), p_mat) | ||
|
||
function build_and_solve_mat(u0, tspan, p, sensealg) | ||
_prob = remake(matrix_prob; u0, p) | ||
sol = solve(_prob, Tsit5(); abstol=1e-14, reltol=1e-14, sensealg, saveat=0.01) | ||
return sol[:, :, end] | ||
end | ||
|
||
@testset "diffeqbase" begin | ||
vjps = [false, true, EnzymeVJP(), ZygoteVJP(), ReverseDiffVJP(), ReverseDiffVJP(true)] | ||
reduced_vjps = [false, EnzymeVJP(), ReverseDiffVJP(), ReverseDiffVJP(true)] | ||
|
||
# These cases are excluded because Zygote also does not successfully work on them. | ||
excluded_cases = Any[ | ||
BacksolveAdjoint(; autojacvec=false), | ||
BacksolveAdjoint(; autojacvec=true), | ||
QuadratureAdjoint(; autojacvec=EnzymeVJP()), | ||
] | ||
|
||
@testset "$sensealg" for sensealg in vcat( | ||
[ForwardDiffSensitivity()], | ||
[BacksolveAdjoint(; autojacvec=vjp) for vjp in vjps], | ||
[GaussAdjoint(; autojacvec=vjp) for vjp in reduced_vjps], | ||
[InterpolatingAdjoint(; autojacvec=vjp) for vjp in vjps], | ||
[QuadratureAdjoint(; autojacvec=vjp) for vjp in reduced_vjps], | ||
) | ||
@info sensealg | ||
|
||
test_rule( | ||
StableRNG(123), build_and_solve, u0, tspan, p, sensealg; | ||
is_primitive=false, debug_mode=false, | ||
) | ||
|
||
sensealg in excluded_cases && continue | ||
test_rule( | ||
StableRNG(123), build_and_solve_mat, u0_mat, tspan, p_mat, sensealg; | ||
is_primitive=false, debug_mode=false, | ||
) | ||
end | ||
end |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you'll need the solution adjoints as well.
Should this all live with DiffEqBase?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this. What do you mean by the solution adjoints here?
I would be happy to move this over to DiffEqBase if that's what you would prefer -- also happy to have it here though 🤷
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I'd prefer it live there so that I have some forced interaction with it and start to learn it / use it a bit.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay -- I'll look into making a PR.