
DiffEqBase Rule #320

Closed
wants to merge 26 commits into from

@willtebbutt (Member) commented Oct 29, 2024

The purpose of this PR is to integrate with the rrule for solve_up found here.

This is a reasonably complicated rule to interact with, because

  1. DiffEqBase makes use of quite a few interesting types (e.g. FunctionWrappers) which require their own rules,
  2. I'm still trying to figure out the interface for solve_up, and
  3. This case pushes the limits of our @from_rrule macro because of the presence of a Vararg. Currently I'm using positional arguments, but I think I should be using a Vararg.
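A minimal plain-Julia sketch of why the Vararg matters, assuming (as I understand it) that @from_rrule is given the call signature as a Tuple type: a signature written with fixed positional arguments matches exactly one arity, whereas one ending in Vararg matches any number of trailing arguments. toy_solve_up is a hypothetical stand-in with four fixed arguments plus trailing varargs, not DiffEqBase's actual solve_up.

```julia
# Hypothetical stand-in for a solver entry point with trailing positional varargs.
toy_solve_up(prob, sensealg, u0, p, args...) = (prob, sensealg, u0, p, args)

# A signature written with fixed positional arguments only matches one arity.
fixed_sig = Tuple{typeof(toy_solve_up), Any, Any, Any, Any, Any}

# A signature ending in Vararg matches any number of trailing arguments.
vararg_sig = Tuple{typeof(toy_solve_up), Any, Any, Any, Any, Vararg{Any}}

# Concrete call signatures with one and with two trailing arguments.
call_one = Tuple{typeof(toy_solve_up), Int, Int, Vector{Float64}, Vector{Float64}, Symbol}
call_two = Tuple{typeof(toy_solve_up), Int, Int, Vector{Float64}, Vector{Float64}, Symbol, Symbol}

println(call_one <: fixed_sig)   # true
println(call_two <: fixed_sig)   # false
println(call_one <: vararg_sig)  # true
println(call_two <: vararg_sig)  # true
```

With fixed positions, every arity the solver is called with needs its own declared signature; the Vararg form covers them all with one.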

Anyway, this is very much a WIP.

edit: also, I think I'm getting bitten by load / precompile ordering issues in package extensions, because I've had to temporarily put the tangent type for FunctionWrappers in Mooncake.jl itself rather than in an extension. I'm really hoping that I can get away from this somehow, as I really do not want Mooncake.jl to depend on FunctionWrappers.jl.

@willtebbutt willtebbutt marked this pull request as draft October 29, 2024 15:00

codecov bot commented Oct 29, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅



github-actions bot (Contributor) commented Oct 29, 2024

Performance Ratio:
Ratio of the time to compute the gradient to the time to compute the function.
Warning: results are very approximate! See here for more context.

┌────────────────────────────┬──────────┬──────────┬─────────────┬─────────┐
│                      Label │ Mooncake │   Zygote │ ReverseDiff │  Enzyme │
│                     String │   String │   String │      String │  String │
├────────────────────────────┼──────────┼──────────┼─────────────┼─────────┤
│                   sum_1000 │     70.9 │      1.0 │         5.5 │ missing │
│                  _sum_1000 │     6.65 │ 345000.0 │        34.3 │ missing │
│               sum_sin_1000 │     2.32 │     1.72 │        10.7 │ missing │
│              _sum_sin_1000 │     2.79 │    231.0 │        13.7 │ missing │
│                   kron_sum │     78.4 │     9.34 │       302.0 │ missing │
│              kron_view_sum │     53.0 │     8.36 │       205.0 │ missing │
│      naive_map_sin_cos_exp │     2.51 │  missing │         7.5 │ missing │
│            map_sin_cos_exp │     2.74 │     1.42 │        5.99 │ missing │
│      broadcast_sin_cos_exp │      2.6 │     3.28 │        1.46 │ missing │
│                 simple_mlp │     6.03 │     2.67 │        10.4 │ missing │
│                     gp_lml │     8.67 │     3.69 │     missing │ missing │
│ turing_broadcast_benchmark │     3.15 │  missing │        28.5 │ missing │
│         large_single_block │     4.16 │   4180.0 │        31.2 │ missing │
└────────────────────────────┴──────────┴──────────┴─────────────┴─────────┘


using DiffEqBase, Mooncake, OrdinaryDiffEqTsit5, Random, SciMLSensitivity, Test

function lotka_volterra!(du, u, p, t)
Member Author

Start of DiffEqBase rule tests
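The test file's body is truncated in this view. For reference, a typical in-place Lotka-Volterra right-hand side in the (du, u, p, t) convention looks like the sketch below; the parameter ordering (α, β, δ, γ) is an assumption, and this is not necessarily the PR's exact definition.

```julia
# In-place right-hand side du = f(u, p, t) for the Lotka-Volterra system.
# Assumed parameter ordering: p = (α, β, δ, γ).
function lotka_volterra!(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α * x - β * x * y   # prey: growth minus predation
    du[2] = -δ * y + γ * x * y  # predators: decay plus growth from predation
    return nothing
end

du = zeros(2)
lotka_volterra!(du, [1.0, 1.0], [1.5, 1.0, 3.0, 1.0], 0.0)
du  # [0.5, -2.0]
```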


using DiffEqBase, Mooncake

Mooncake.@from_rrule(


I think you'll need the solution adjoints as well.

Should this all live with DiffEqBase?

Member Author

Thanks for this. What do you mean by the solution adjoints here?

I would be happy to move this over to DiffEqBase if that's what you would prefer -- also happy to have it here though 🤷


Yeah, I'd prefer it to live there so that I'm forced to interact with it and start to learn and use it a bit.

Member Author

Okay -- I'll look into making a PR.

@willtebbutt willtebbutt marked this pull request as ready for review November 12, 2024 09:13
@ChrisRackauckas

I am surprised this works without the special indexing and construction rules on ODESolution that ChainRules requires. Can you describe a bit why those just work?

@willtebbutt (Member Author) commented Nov 12, 2024

It's hard to say why Mooncake.jl doesn't need a rule without understanding exactly why Zygote does need one, but I can break down some of what Mooncake has to do here.

An example solution has type

julia> typeof(sol)
ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, Nothing, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, Vector{Float64}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}}, false}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, OrdinaryDiffEqCore.InterpolationData{ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, 
Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, Vector{Float64}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}}, false}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, Nothing, OrdinaryDiffEqTsit5.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Nothing}, SciMLBase.DEStats, Nothing, Nothing, Nothing, Nothing}

for which the Mooncake.tangent_type is

julia> Mooncake.tangent_type(typeof(sol))
Tangent{@NamedTuple{u::Vector{Vector{Float64}}, u_analytic::NoTangent, errors::NoTangent, t::Vector{Float64}, k::Vector{Vector{Vector{Float64}}}, discretes::NoTangent, prob::MutableTangent{@NamedTuple{f::NoTangent, u0::Vector{Float64}, tspan::Tuple{Float64, Float64}, p::Vector{Float64}, kwargs::NoTangent, problem_type::NoTangent}}, alg::NoTangent, interp::Tangent{@NamedTuple{f::NoTangent, timeseries::Vector{Vector{Float64}}, ts::Vector{Float64}, ks::Vector{Vector{Vector{Float64}}}, alg_choice::NoTangent, dense::NoTangent, cache::Tangent{@NamedTuple{u::Vector{Float64}, uprev::Vector{Float64}, k1::Vector{Float64}, k2::Vector{Float64}, k3::Vector{Float64}, k4::Vector{Float64}, k5::Vector{Float64}, k6::Vector{Float64}, k7::Vector{Float64}, utilde::Vector{Float64}, tmp::Vector{Float64}, atmp::Vector{Float64}, stage_limiter!::NoTangent, step_limiter!::NoTangent, thread::NoTangent}}, differential_vars::NoTangent, sensitivitymode::NoTangent}}, dense::NoTangent, tslocation::NoTangent, stats::MutableTangent{@NamedTuple{nf::NoTangent, nf2::NoTangent, nw::NoTangent, nsolve::NoTangent, njacs::NoTangent, nnonliniter::NoTangent, nnonlinconvfail::NoTangent, nfpiter::NoTangent, nfpconvfail::NoTangent, ncondition::NoTangent, naccept::NoTangent, nreject::NoTangent, maxeig::Float64}}, alg_choice::NoTangent, retcode::NoTangent, resid::NoTangent, original::NoTangent, saved_subsystem::NoTangent}}

This is completely automatic, except that I had to tell Mooncake that the tangent of a FunctionWrapper is always NoTangent. I'm reasonably confident this will cause issues at some point, because you can put closures in FunctionWrappers, and those closures can themselves contain differentiable data and so have tangents other than NoTangent. This isn't a problem here, though, because the FunctionWrappers are just internal implementation details in this case.
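A toy sketch of the kind of structural recursion that produces a tangent type like the one above. ToyNoTangent, ToyTangent, ToyMutableTangent and toy_tangent_type are hypothetical names, and the rules (floats are differentiable; integers, symbols and functions are not; a struct with no differentiable fields collapses to ToyNoTangent; mutable structs get a mutable tangent) are a simplification of the idea, not Mooncake's implementation.

```julia
# Hypothetical marker types standing in for Mooncake's NoTangent, Tangent and
# MutableTangent. This mimics the idea only, not Mooncake's implementation.
struct ToyNoTangent end
struct ToyTangent{NT} end
struct ToyMutableTangent{NT} end

# Differentiable leaves: a float's tangent is the same float type.
toy_tangent_type(::Type{T}) where {T<:AbstractFloat} = T

# Non-differentiable leaves. Simplification: every Function is treated as
# non-differentiable, the same kind of shortcut as the FunctionWrapper rule,
# and wrong for closures that capture differentiable data.
toy_tangent_type(::Type{<:Union{Integer,Symbol,Nothing,Function}}) = ToyNoTangent

# Vectors: a vector of element tangents, or ToyNoTangent if elements have none.
function toy_tangent_type(::Type{Vector{T}}) where {T}
    inner = toy_tangent_type(T)
    return inner === ToyNoTangent ? ToyNoTangent : Vector{inner}
end

# Structs: recurse over the fields and build a NamedTuple of field tangents.
function toy_tangent_type(::Type{P}) where {P}
    Base.isstructtype(P) || error("no toy tangent rule for $P")
    tangents = map(toy_tangent_type, fieldtypes(P))
    # A struct with no differentiable fields collapses to ToyNoTangent.
    all(t -> t === ToyNoTangent, tangents) && return ToyNoTangent
    NT = NamedTuple{fieldnames(P), Tuple{tangents...}}
    return ismutabletype(P) ? ToyMutableTangent{NT} : ToyTangent{NT}
end

# Hypothetical miniature versions of an algorithm, a problem and a solution.
struct ToyAlg
    limiter::typeof(identity)
end

mutable struct ToyProblem
    u0::Vector{Float64}
    p::Vector{Float64}
    name::Symbol
end

struct ToySolution
    u::Vector{Vector{Float64}}
    t::Vector{Float64}
    prob::ToyProblem
    alg::ToyAlg
    retcode::Symbol
end

toy_tangent_type(ToySolution)
```

The result is a ToyTangent whose u and t entries keep their Float64 vector types, whose prob entry is a ToyMutableTangent, and whose alg and retcode entries are ToyNoTangent, mirroring the shape of the real tangent type shown above.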

The primal code for getindex optimises down to something really quite simple:

julia> Base.code_ircode_by_type(Tuple{typeof(getindex), typeof(sol), typeof(:), Int}; interp=Mooncake.get_interpreter())
1-element Vector{Any}:
    1 ──       nothing::Nothing
    2 ──       nothing::Nothing
    3 ──       nothing::Nothing
    4 ── %4  = Core.getfield(_4, 1)::Int64                  │    
    │    %5  = SciMLBase.getfield(_2, :u)::Vector{Vector{Float64}}
    │    %6  = $(Expr(:boundscheck))::Bool                  ││╻    getindex
    └───       goto #8 if not %6                            │││  
    5 ── %8  = Base.sub_int(%4, 1)::Int64                   │││  
    │    %9  = Base.bitcast(Base.UInt, %8)::UInt64          │││  
    │    %10 = Base.getfield(%5, :size)::Tuple{Int64}       │││╻    length
    │    %11 = $(Expr(:boundscheck, true))::Bool            ││││╻    getindex
    │    %12 = Base.getfield(%10, 1, %11)::Int64            │││││
    │    %13 = Base.bitcast(Base.UInt, %12)::UInt64         │││  
    │    %14 = Base.ult_int(%9, %13)::Bool                  │││  
    └───       goto #7 if not %14                           │││  
    6 ──       goto #8                                      │    
    7 ── %17 = Core.tuple(%4)::Tuple{Int64}                 │││  
    │          invoke Base.throw_boundserror(%5::Vector{Vector{Float64}}, %17::Tuple{Int64})::Union{}
    └───       unreachable                                  │││  
    8 ┄─ %20 = Base.getfield(%5, :ref)::MemoryRef{Vector{Float64}}
    │    %21 = Base.memoryrefnew(%20, %4, false)::MemoryRef{Vector{Float64}}
    │    %22 = Base.memoryrefget(%21, :not_atomic, false)::Vector{Float64}
    └───       goto #9                                      │    
    9 ──       goto #10                                     │    
    10 ─       return %22
     => Vector{Float64}

so it's not a problem for Mooncake. I've not looked at the code for the ODESolution constructor yet, but generally speaking Mooncake is fine with constructors provided that you can define the tangent type of the resulting object (Mooncake directly targets the :new or :splatnew instructions, rather than trying to intercept constructor calls).
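To illustrate that constructors bottom out in :new, here is a hypothetical ToyState struct in plain Julia (no Mooncake involved): the lowered code of its inner constructor contains an Expr with head :new, while the outer convenience constructor is just an ordinary call to the inner one.

```julia
# Hypothetical struct with an explicit inner constructor. new(...) is the
# primitive that actually allocates and fills in the object.
struct ToyState
    u::Vector{Float64}
    t::Float64
    ToyState(u::Vector{Float64}, t::Float64) = new(u, t)
end

# Outer convenience constructor: an ordinary call to the inner constructor.
ToyState(u::Vector{Float64}) = ToyState(u, 0.0)

# True if any statement in a lowered CodeInfo is an Expr with head :new.
contains_new(ci) = any(stmt -> stmt isa Expr && stmt.head === :new, ci.code)

inner_ci = only(code_lowered(ToyState, (Vector{Float64}, Float64)))
outer_ci = only(code_lowered(ToyState, (Vector{Float64},)))

println(contains_new(inner_ci))  # true
println(contains_new(outer_ci))  # false
```

A tool that handles :new directly therefore covers every constructor, however it is wrapped, as long as it knows the tangent type of the object being built.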

I don't know how helpful this is, sorry!

@willtebbutt (Member Author)

(actually I should probably fix the function wrapper stuff before merging anything FunctionWrapper related, because it's quite unsafe as things stand)
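To make the safety concern concrete: in plain Julia, a closure is a struct whose fields are the variables it captures, so a closure can carry differentiable data. make_scaled_sum is a hypothetical example and no FunctionWrappers are involved; it only shows why declaring "the tangent of anything function-like is NoTangent" can be wrong.

```julia
# A closure is an ordinary struct whose fields are the variables it captures.
function make_scaled_sum(a::Vector{Float64})
    return x -> sum(a) * x
end

f = make_scaled_sum([1.0, 2.0, 3.0])

println(fieldnames(typeof(f)))  # (:a,)
println(f(2.0))                 # 12.0

# The captured vector is reachable as a field, so a gradient with respect to
# a would need somewhere to live in f's tangent; NoTangent leaves no room.
println(getfield(f, :a))        # [1.0, 2.0, 3.0]
```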

@ChrisRackauckas

That's very helpful, thanks. I'm surprised it's able to figure out how to differentiate the special getindex if it's using that Tangent type instead of the ODESolution, but I guess it must use the indexing in the forward pass and know what to dispatch into when going in reverse. Neat.

@willtebbutt (Member Author) commented Nov 12, 2024

Yeah, exactly. I actually don't think we have any rules in Mooncake specifically for getindex; everything is done at a lower level. That is, we have rules for the various builtins in terms of which your specialised getindex method is implemented (getfield, and arrayref or memoryrefget on Julia 1.10 and 1.11 respectively, etc.), so we just differentiate the code generated to do that.
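A hand-written toy (not Mooncake's code) showing how a reverse pass for a custom getindex can be assembled purely from pullbacks of the primitives it is built from: a field read followed by vector indexing. MiniSol, getfield_u_pullback, index_pullback and minisol_getindex_pullback are hypothetical names, and the NamedTuple (u = ...,) stands in for a structural tangent.

```julia
# Hypothetical miniature solution type whose getindex, like the ODESolution
# method above, just reads a field and indexes into it.
struct MiniSol
    u::Vector{Vector{Float64}}
end
Base.getindex(s::MiniSol, ::Colon, i::Int) = s.u[i]

# Toy pullback for getfield(s, :u): the field's cotangent becomes the :u
# entry of a NamedTuple standing in for the structural tangent of s.
getfield_u_pullback(du) = (u = du,)

# Toy pullback for v[i]: scatter dy into slot i of a zero cotangent shaped like v.
function index_pullback(v, i, dy)
    dv = [zero(vi) for vi in v]
    dv[i] = dy
    return dv
end

# Reverse pass for s[:, i] == getindex(getfield(s, :u), i): apply the two
# pullbacks in the reverse order of the forward pass (chain rule).
minisol_getindex_pullback(s, i, dy) = getfield_u_pullback(index_pullback(s.u, i, dy))

s = MiniSol([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = s[:, 2]                                       # [3.0, 4.0]
ds = minisol_getindex_pullback(s, 2, [1.0, 1.0])
ds.u                                              # [[0.0, 0.0], [1.0, 1.0], [0.0, 0.0]]
```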

@willtebbutt (Member Author) commented Nov 19, 2024

I'm closing this in favour of SciML/SciMLSensitivity.jl#1151.

I plan to keep this branch alive until everything is merged into the various bits of SciML though.
