DiffEqBase Rule #320
Codecov Report: All modified and coverable lines are covered by tests ✅
…ooncake.jl into wct/sciml-base-rule
```julia
using DiffEqBase, Mooncake, OrdinaryDiffEqTsit5, Random, SciMLSensitivity, Test

function lotka_volterra!(du, u, p, t)
```
Start of DiffEqBase rule tests
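Only the signature `lotka_volterra!(du, u, p, t)` of the test's right-hand side is visible here. As a minimal sketch of what such an in-place RHS conventionally looks like, assuming the parameter ordering `p = [α, β, γ, δ]` (an assumption, not taken from the PR), here it is together with a hypothetical `euler_step` helper. The helper exists only to show that the function mutates `du` rather than returning a new vector:

```julia
# In-place Lotka-Volterra right-hand side using the (du, u, p, t) signature
# that DiffEqBase expects for in-place ODE functions.
# Assumed parameter ordering: p = [α, β, γ, δ].
function lotka_volterra!(du, u, p, t)
    x, y = u
    α, β, γ, δ = p
    du[1] = α * x - β * x * y   # prey: growth minus predation
    du[2] = -γ * y + δ * x * y  # predators: decay plus growth from predation
    return nothing
end

# Hypothetical helper: a single explicit Euler step, only to illustrate the
# in-place calling convention (this is not how OrdinaryDiffEqTsit5 solves).
function euler_step(f!, u, p, t, dt)
    du = similar(u)
    f!(du, u, p, t)
    return u .+ dt .* du
end

u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
du = zeros(2)
lotka_volterra!(du, u0, p, 0.0)
```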
```julia
using DiffEqBase, Mooncake

Mooncake.@from_rrule(
```
I think you'll need the solution adjoints as well.
Should this all live with DiffEqBase?
Thanks for this. What do you mean by the solution adjoints here?
I would be happy to move this over to DiffEqBase if that's what you would prefer -- also happy to have it here though 🤷
Yeah, I'd prefer it to live there so that I have some forced interaction with it and start to learn and use it a bit.
Okay -- I'll look into making a PR.
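`Mooncake.@from_rrule` builds a Mooncake rule out of an existing ChainRules `rrule`, i.e. a function that returns the primal output together with a pullback closure. The sketch below shows that convention in dependency-free form, using a hypothetical `scaled_sum` function. It substitutes `nothing` for ChainRulesCore's `NoTangent()` so that it runs with Base Julia alone:

```julia
# Hypothetical primal function: f(x, a) = a * sum(x)
scaled_sum(x::Vector{Float64}, a::Float64) = a * sum(x)

# rrule-style pair: return the primal value and a pullback closure.
# The pullback maps the output cotangent `ybar` to cotangents for
# (the function itself, x, a). `nothing` stands in for NoTangent().
function rrule_scaled_sum(x::Vector{Float64}, a::Float64)
    y = scaled_sum(x, a)
    function scaled_sum_pullback(ybar::Float64)
        xbar = fill(ybar * a, length(x))  # ∂y/∂x_i = a
        abar = ybar * sum(x)              # ∂y/∂a = sum(x)
        return (nothing, xbar, abar)
    end
    return y, scaled_sum_pullback
end

y, pb = rrule_scaled_sum([1.0, 2.0, 3.0], 2.0)
df, xbar, abar = pb(1.0)
```

A real rule for `solve_up` follows the same shape. The difference is that its pullback runs an adjoint solve instead of a closed-form derivative.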
I am surprised this works without the special indexing and construction rules on ODESolution that ChainRules requires. Can you describe a bit why indexing and construction just work?
It's hard to say why Mooncake.jl doesn't need a rule without understanding exactly why Zygote needs one, but I can break down some of what Mooncake has to do here. An example solution has type

```julia
julia> typeof(sol)
ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, Nothing, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, Vector{Float64}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}}, false}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, OrdinaryDiffEqCore.InterpolationData{ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, Vector{Float64}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}}, false}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, Nothing, OrdinaryDiffEqTsit5.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Nothing}, SciMLBase.DEStats, Nothing, Nothing, Nothing, Nothing}
```

for which the

```julia
julia> Mooncake.tangent_type(typeof(sol))
Tangent{@NamedTuple{u::Vector{Vector{Float64}}, u_analytic::NoTangent, errors::NoTangent, t::Vector{Float64}, k::Vector{Vector{Vector{Float64}}}, discretes::NoTangent, prob::MutableTangent{@NamedTuple{f::NoTangent, u0::Vector{Float64}, tspan::Tuple{Float64, Float64}, p::Vector{Float64}, kwargs::NoTangent, problem_type::NoTangent}}, alg::NoTangent, interp::Tangent{@NamedTuple{f::NoTangent, timeseries::Vector{Vector{Float64}}, ts::Vector{Float64}, ks::Vector{Vector{Vector{Float64}}}, alg_choice::NoTangent, dense::NoTangent, cache::Tangent{@NamedTuple{u::Vector{Float64}, uprev::Vector{Float64}, k1::Vector{Float64}, k2::Vector{Float64}, k3::Vector{Float64}, k4::Vector{Float64}, k5::Vector{Float64}, k6::Vector{Float64}, k7::Vector{Float64}, utilde::Vector{Float64}, tmp::Vector{Float64}, atmp::Vector{Float64}, stage_limiter!::NoTangent, step_limiter!::NoTangent, thread::NoTangent}}, differential_vars::NoTangent, sensitivitymode::NoTangent}}, dense::NoTangent, tslocation::NoTangent, stats::MutableTangent{@NamedTuple{nf::NoTangent, nf2::NoTangent, nw::NoTangent, nsolve::NoTangent, njacs::NoTangent, nnonliniter::NoTangent, nnonlinconvfail::NoTangent, nfpiter::NoTangent, nfpconvfail::NoTangent, ncondition::NoTangent, naccept::NoTangent, nreject::NoTangent, maxeig::Float64}}, alg_choice::NoTangent, retcode::NoTangent, resid::NoTangent, original::NoTangent, saved_subsystem::NoTangent}}
```

This is completely automatic, except for where I had to tell Mooncake that the tangent of a

The primal code for

```julia
julia> Base.code_ircode_by_type(Tuple{typeof(getindex), typeof(sol), typeof(:), Int}; interp=Mooncake.get_interpreter())
1-element Vector{Any}:
1 ── nothing::Nothing │
2 ── nothing::Nothing │
3 ── nothing::Nothing │
404 4 ── %4 = Core.getfield(_4, 1)::Int64 │
│ %5 = SciMLBase.getfield(_2, :u)::Vector{Vector{Float64}}_getindex
│ %6 = $(Expr(:boundscheck))::Bool ││╻ getindex
└─── goto #8 if not %6 │││
5 ── %8 = Base.sub_int(%4, 1)::Int64 │││
│ %9 = Base.bitcast(Base.UInt, %8)::UInt64 │││
│ %10 = Base.getfield(%5, :size)::Tuple{Int64} │││╻ length
│ %11 = $(Expr(:boundscheck, true))::Bool ││││╻ getindex
│ %12 = Base.getfield(%10, 1, %11)::Int64 │││││
│ %13 = Base.bitcast(Base.UInt, %12)::UInt64 │││
│ %14 = Base.ult_int(%9, %13)::Bool │││
└─── goto #7 if not %14 │││
6 ── goto #8 │
7 ── %17 = Core.tuple(%4)::Tuple{Int64} │││
│ invoke Base.throw_boundserror(%5::Vector{Vector{Float64}}, %17::Tuple{Int64})::Union{}
└─── unreachable │││
8 ┄─ %20 = Base.getfield(%5, :ref)::MemoryRef{Vector{Float64}}
│ %21 = Base.memoryrefnew(%20, %4, false)::MemoryRef{Vector{Float64}}
│ %22 = Base.memoryrefget(%21, :not_atomic, false)::Vector{Float64}
└─── goto #9 │
9 ── goto #10 │
10 ─ return %22 │
=> Vector{Float64}
```

so it's not a problem for Mooncake. I've not taken a look at the code for the constructor for

I don't know how helpful this is, sorry!
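The IR printed in the comment shows `sol[:, i]` reducing to `getfield(sol, :u)` followed by ordinary array indexing, and Mooncake already knows how to differentiate both. The toy stand-in below reproduces the same lowering. The `ToySolution` type is hypothetical and far simpler than `ODESolution`; it keeps only the saved states and times:

```julia
# Hypothetical, heavily simplified stand-in for ODESolution:
# only the saved states `u` and save times `t` are kept.
struct ToySolution
    u::Vector{Vector{Float64}}
    t::Vector{Float64}
end

# sol[:, i] returns the i-th saved state, mirroring the
# `getfield(sol, :u)` followed by `getindex` seen in the IR.
Base.getindex(sol::ToySolution, ::Colon, i::Int) = getfield(sol, :u)[i]

sol = ToySolution([[1.0, 1.0], [1.05, 0.8]], [0.0, 0.1])
state = sol[:, 2]
```

The indexing method returns the stored vector itself, not a copy. Nothing about it requires a bespoke AD rule.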
(Actually, I should probably fix the FunctionWrapper stuff before merging anything FunctionWrapper-related, because it's quite unsafe as things stand.)
That's very helpful, thanks. I'm surprised it's able to figure out how to differentiate the special getindex if it's using that Tangent type instead of the ODESolution, but I guess it must use the indexing in the forward pass and know what to dispatch into when going in reverse. Neat.
Yeah exactly. I actually don't think we have any rules in Mooncake specifically for
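To make "use the indexing in the forward pass and know what to dispatch into when going in reverse" concrete, here is a hand-written sketch of the reverse step for `sol[:, i]`. It assumes a NamedTuple-shaped tangent `(u = ..., t = ...)` that only loosely mirrors the `Tangent{@NamedTuple{...}}` printed earlier in the thread. `MiniSolution`, `zero_tangent`, and `getindex_pullback!` are all hypothetical names, not Mooncake's API:

```julia
# Hypothetical simplified solution type (a stand-in for ODESolution).
struct MiniSolution
    u::Vector{Vector{Float64}}
    t::Vector{Float64}
end

Base.getindex(sol::MiniSolution, ::Colon, i::Int) = sol.u[i]

# A tangent whose fields mirror MiniSolution's fields, loosely analogous to
# the NamedTuple-backed Tangent type Mooncake derives automatically.
zero_tangent(sol::MiniSolution) = (u = [zero(ui) for ui in sol.u], t = zero(sol.t))

# Reverse step for sol[:, i]: the output cotangent `dy` is accumulated into
# the i-th entry of the tangent's `u` field; the `t` field is left untouched.
function getindex_pullback!(dsol, dy::Vector{Float64}, i::Int)
    dui = dsol.u[i]
    dui .+= dy  # in-place accumulation into the stored tangent vector
    return dsol
end

sol = MiniSolution([[1.0, 1.0], [1.05, 0.8]], [0.0, 0.1])
y = sol[:, 2]                      # forward pass: ordinary indexing
dsol = zero_tangent(sol)
getindex_pullback!(dsol, [1.0, -1.0], 2)
```

Because the tangent mirrors the primal's field layout, the reverse step only needs the field name and the index that the forward pass used.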
I'm closing this in favour of SciML/SciMLSensitivity.jl#1151. I plan to keep this branch alive until everything is merged into the various bits of SciML, though.
The purpose of this PR is to integrate with the `rrule` for `solve_up` found here. This is a reasonably complicated rule to interact with, because … `FunctionWrappers`) which require their own rules, … `solve_up`, and … `@from_rrule` macro due to the presence of a `Vararg`. Currently I'm using positional arguments, but I think I should be using a `Vararg`.

Anyway, this is very much a WIP.
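The positional-versus-`Vararg` question comes down to which call signatures a rule's `Tuple` signature type matches. The Base-only sketch below uses a hypothetical `solve_up_like` function that is only loosely shaped like `solve_up(prob, sensealg, u0, p, args...)`; the exact DiffEqBase signature is not reproduced here. It shows that a fixed-arity signature misses calls with extra trailing arguments, while a `Vararg` tail covers them:

```julia
# Hypothetical function with trailing varargs, loosely shaped like solve_up.
solve_up_like(prob, sensealg, u0, p, args...) = (prob, sensealg, u0, p, args)

F = typeof(solve_up_like)

# Fixed positional signature: exactly four arguments after the function.
fixed_sig = Tuple{F, Any, Any, Any, Any}

# Vararg signature: four leading arguments followed by any number of extras.
vararg_sig = Tuple{F, Any, Any, Any, Any, Vararg{Any}}

# Concrete call signatures with and without a trailing extra argument.
call_no_extra = Tuple{F, Symbol, Nothing, Vector{Float64}, Vector{Float64}}
call_with_alg = Tuple{F, Symbol, Nothing, Vector{Float64}, Vector{Float64}, String}
```

A rule registered for `fixed_sig` would only fire for calls that pass no trailing arguments. The `Vararg` form covers both cases.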
Edit: also, I think I'm getting bitten by load / precompile ordering issues in package extensions, because I've had to temporarily shove the tangent type for `FunctionWrapper`s in Mooncake.jl itself, rather than making it an extension. I'm really hoping that I can get away from this somehow, as I really do not want Mooncake.jl to depend on FunctionWrappers.jl.