Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DiffEqBase Rule #320

Closed
wants to merge 26 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
6b3e1b9
Initial work
willtebbutt Oct 29, 2024
4897268
Merge in main
willtebbutt Nov 1, 2024
80f412e
Move array rules around
willtebbutt Nov 1, 2024
a62ec96
Merge branch 'main' into wct/sciml-base-rule
willtebbutt Nov 4, 2024
c790f5a
Merge in main
willtebbutt Nov 7, 2024
8d334b6
Fix up rrule with kwargs
willtebbutt Nov 7, 2024
6cb61bc
Fix up merge
willtebbutt Nov 7, 2024
52a2f9d
Improve error message very slightly
willtebbutt Nov 7, 2024
baa6e83
Improve test suite
willtebbutt Nov 7, 2024
3132e69
Merge branch 'main' into wct/sciml-base-rule
willtebbutt Nov 8, 2024
4353348
Merge branch 'main' into wct/sciml-base-rule
willtebbutt Nov 8, 2024
6bc80ae
Merge branch 'wct/sciml-base-rule' of https://github.com/compintell/M…
willtebbutt Nov 8, 2024
0ab46f1
Reformat
willtebbutt Nov 8, 2024
e80f338
Basics
willtebbutt Nov 8, 2024
4b00150
Run new functionality in CI
willtebbutt Nov 8, 2024
ece9e2f
Merge in main
willtebbutt Nov 11, 2024
bce71fc
Included tests involving matrix-valued ODE
willtebbutt Nov 12, 2024
5396567
Remove redundant rule
willtebbutt Nov 12, 2024
45f5e50
Remove deprecated indexing
willtebbutt Nov 12, 2024
fb4f1dc
Remove FunctionWrapper type
willtebbutt Nov 12, 2024
8dd2ac3
Merge branch 'main' into wct/sciml-base-rule
willtebbutt Nov 17, 2024
f501fc3
CRC interop
willtebbutt Nov 18, 2024
cbe23fe
Tidy up loading etc
willtebbutt Nov 18, 2024
3d00548
Merge in main
willtebbutt Nov 18, 2024
daf2a79
Resolve merge conflict
willtebbutt Nov 19, 2024
c2323ca
Remove extraneous using
willtebbutt Nov 19, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ jobs:
matrix:
test_group: [
{test_type: 'ext', label: 'differentiation_interface'},
{test_type: 'ext', label: 'diffeqbase'},
{test_type: 'ext', label: 'dynamic_ppl'},
{test_type: 'ext', label: 'luxlib'},
{test_type: 'ext', label: 'nnlib'},
Expand Down
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[weakdeps]
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
Expand All @@ -32,6 +33,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
[extensions]
MooncakeAllocCheckExt = "AllocCheck"
MooncakeCUDAExt = "CUDA"
MooncakeDiffEqBaseExt = "DiffEqBase"
MooncakeDynamicPPLExt = "DynamicPPL"
MooncakeJETExt = "JET"
MooncakeLuxLibExt = "LuxLib"
Expand All @@ -47,6 +49,7 @@ BenchmarkTools = "1"
CUDA = "5"
ChainRules = "1.71.0"
ChainRulesCore = "1"
DiffEqBase = "6"
DiffRules = "1"
DiffTests = "0.1"
DynamicPPL = "0.29, 0.30"
Expand Down
20 changes: 20 additions & 0 deletions ext/MooncakeDiffEqBaseExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
module MooncakeDiffEqBaseExt

using DiffEqBase, Mooncake

Mooncake.@from_rrule(

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you'll need the solution adjoints as well.

Should this all live with DiffEqBase?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this. What do you mean by the solution adjoints here?

I would be happy to move this over to DiffEqBase if that's what you would prefer -- also happy to have it here though 🤷

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I'd prefer it live there so that I have some forced interaction with it and start to learn it / use it a bit.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay -- I'll look into making a PR.

Mooncake.MinimalCtx,
Tuple{
typeof(DiffEqBase.solve_up),
DiffEqBase.AbstractDEProblem,
Union{Nothing, DiffEqBase.AbstractSensitivityAlgorithm},
Any,
Any,
Any,
},
true,
)

Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{typeof(DiffEqBase.numargs), Any}

end
9 changes: 9 additions & 0 deletions test/ext/diffeqbase/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[deps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
77 changes: 77 additions & 0 deletions test/ext/diffeqbase/diffeqbase.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
using Pkg
Pkg.activate(@__DIR__)
Pkg.develop(; path=joinpath(@__DIR__, "..", "..", ".."))

using DiffEqBase, Mooncake, OrdinaryDiffEqTsit5, SciMLSensitivity, StableRNGs, Test
using DiffEqBase: solve
using Mooncake.TestUtils: test_rule

function lotka_volterra!(du, u, p, t)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Start of DiffEqBase rule tests

x, y = u
α, β, δ, γ = p
du[1] = dx = α * x - β * x * y
du[2] = dy = -δ * y + γ * x * y
end

u0 = [1.0, 1.0]
tspan = (0.0, 1.0)
p = [1.5, 1.0, 4.0, 1.0]

# Have to use remake with a const global in order to get good performance.
const prob = ODEProblem(lotka_volterra!, u0, (0.0, 1.0), p)

function build_and_solve(u0, tspan, p, sensealg)
_prob = remake(prob; u0, p)
sol = solve(_prob, Tsit5(); abstol=1e-14, reltol=1e-14, sensealg, saveat=0.01)
return sum(sol) + sum(sum(sol.u[end]))
end

function matrix_ode!(du, u, p, t)
du .= reshape(p, 4, 4) * u
return nothing
end

u0_mat = rand(4, 8)
p_mat = rand(16)

# Have to use remake with a const global in order to get good performance.
const matrix_prob = ODEProblem(matrix_ode!, u0_mat, (0.0, 1.0), p_mat)

function build_and_solve_mat(u0, tspan, p, sensealg)
_prob = remake(matrix_prob; u0, p)
sol = solve(_prob, Tsit5(); abstol=1e-14, reltol=1e-14, sensealg, saveat=0.01)
return sol[:, :, end]
end

@testset "diffeqbase" begin
vjps = [false, true, EnzymeVJP(), ZygoteVJP(), ReverseDiffVJP(), ReverseDiffVJP(true)]
reduced_vjps = [false, EnzymeVJP(), ReverseDiffVJP(), ReverseDiffVJP(true)]

# These cases are excluded because Zygote also does not successfully work on them.
excluded_cases = Any[
BacksolveAdjoint(; autojacvec=false),
BacksolveAdjoint(; autojacvec=true),
QuadratureAdjoint(; autojacvec=EnzymeVJP()),
]

@testset "$sensealg" for sensealg in vcat(
[ForwardDiffSensitivity()],
[BacksolveAdjoint(; autojacvec=vjp) for vjp in vjps],
[GaussAdjoint(; autojacvec=vjp) for vjp in reduced_vjps],
[InterpolatingAdjoint(; autojacvec=vjp) for vjp in vjps],
[QuadratureAdjoint(; autojacvec=vjp) for vjp in reduced_vjps],
)
@info sensealg

test_rule(
StableRNG(123), build_and_solve, u0, tspan, p, sensealg;
is_primitive=false, debug_mode=false,
)

sensealg in excluded_cases && continue
test_rule(
StableRNG(123), build_and_solve_mat, u0_mat, tspan, p_mat, sensealg;
is_primitive=false, debug_mode=false,
)
end
end
Loading