Skip to content

Commit

Permalink
Merge pull request #1106 from jClugstor/ODESensitivityKwargs
Browse files Browse the repository at this point in the history
Switch positional `alg` argument to kwarg `sensealg` for `ODEForwardSensitivityProblem`
  • Loading branch information
ChrisRackauckas authored Sep 7, 2024
2 parents a24d6c6 + 477780d commit 6c95130
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 20 deletions.
4 changes: 2 additions & 2 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ function DiffEqBase._concrete_solve_adjoint(prob::SciMLBase.AbstractODEProblem,

# callback = nothing ensures only the callback in kwargs is used
_prob = ODEForwardSensitivityProblem(
_f, u0, prob.tspan, p, sensealg, callback = nothing)
_f, u0, prob.tspan, p; sensealg = sensealg, callback = nothing)
sol = solve(_prob, alg, args...; kwargs...)
_, du = extract_local_sensitivities(sol, sensealg, Val(true))
ts = current_time(sol)
Expand Down Expand Up @@ -739,7 +739,7 @@ function DiffEqBase._concrete_solve_forward(prob::SciMLBase.AbstractODEProblem,
args...; save_idxs = nothing,
kwargs...)
_prob = ODEForwardSensitivityProblem(
prob.f, u0, prob.tspan, p, sensealg, callback = nothing)
prob.f, u0, prob.tspan, p; sensealg = sensealg, callback = nothing)
sol = solve(_prob, args...; kwargs...)

if originator isa SciMLBase.EnzymeOriginator
Expand Down
41 changes: 35 additions & 6 deletions src/forward_sensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,9 @@ function ODEForwardSensitivityProblem(f::F, args...; kwargs...) where {F}
ODEForwardSensitivityProblem(ODEFunction(f), args...; kwargs...)
end

function ODEForwardSensitivityProblem(prob::ODEProblem, alg; kwargs...)
ODEForwardSensitivityProblem(
prob.f, state_values(prob), prob.tspan, parameter_values(prob), alg; kwargs...)
function ODEForwardSensitivityProblem(prob::ODEProblem; sensealg = ForwardSensitivity(), kwargs...)
_ODEForwardSensitivityProblem(
prob.f, state_values(prob), prob.tspan, parameter_values(prob), sensealg; kwargs...)
end

const FORWARD_SENSITIVITY_PARAMETER_COMPATIBILITY_MESSAGE = """
Expand Down Expand Up @@ -351,9 +351,28 @@ at time `sol.t[i]`. Note that all the functionality available to ODE solutions
is available in this case, including interpolations and plot recipes (the recipes
will plot the expanded system).
"""
function ODEForwardSensitivityProblem(f::F, u0, tspan, p = nothing;
sensealg = ForwardSensitivity(),
kwargs...) where {F <: DiffEqBase.AbstractODEFunction}

_ODEForwardSensitivityProblem(f,u0,tspan,p,sensealg; kwargs...)
end

# deprecated
function ODEForwardSensitivityProblem(f::F, u0,
tspan, p = nothing,
alg::ForwardSensitivity = ForwardSensitivity();
tspan, p,
alg::ForwardSensitivity;
nus = nothing, # determine if Nilss is used
w0 = nothing,
v0 = nothing,
kwargs...) where {F <: DiffEqBase.AbstractODEFunction}
Base.depwarn("The form of this function with `alg` as a positional argument is deprecated. Please use the `sensealg` keyword argument instead.", :ODEForwardSensitivityProblem)
_ODEForwardSensitivityProblem(f,u0,tspan,p,alg; nus,w0,v0,kwargs...)
end

function _ODEForwardSensitivityProblem(f::F, u0,
tspan, p,
alg::ForwardSensitivity;
nus = nothing, # determine if Nilss is used
w0 = nothing,
v0 = nothing,
Expand Down Expand Up @@ -465,11 +484,21 @@ has_continuous_callback(cb::DiscreteCallback) = false
has_continuous_callback(cb::ContinuousCallback) = true
has_continuous_callback(cb::CallbackSet) = !isempty(cb.continuous_callbacks)

# deprecated
function ODEForwardSensitivityProblem(f::DiffEqBase.AbstractODEFunction, u0,
tspan, p, alg::ForwardDiffSensitivity;
du0 = zeros(eltype(u0), length(u0), length(p)), # perturbations of initial condition
dp = I(length(p)), # perturbations of parameters
kwargs...)
Base.depwarn("The form of this function with `alg` as a positional argument is deprecated. Please use the `sensealg` keyword argument instead.", :ODEForwardSensitivity)
_ODEForwardSensitivityProblem(f, u0, tspan, p, alg, du0, dp, kwargs...)
end

function _ODEForwardSensitivityProblem(f::DiffEqBase.AbstractODEFunction, u0,
tspan, p, alg::ForwardDiffSensitivity;
du0 = zeros(eltype(u0), length(u0), length(p)), # perturbations of initial condition
dp = I(length(p)), # perturbations of parameters
kwargs...)
num_sen_par = size(du0, 2)
if num_sen_par != size(dp, 2)
error("Same number of perturbations of initial conditions and parameters required")
Expand Down Expand Up @@ -648,7 +677,7 @@ function SciMLBase.remake(
u0[1:(prob.f.numindvar)]
_tspan = tspan === nothing ? prob.tspan : tspan
ODEForwardSensitivityProblem(_f, _u0,
_tspan, _p, prob.problem_type.sensealg;
_tspan, _p; sensealg = prob.problem_type.sensealg,
prob.kwargs..., kwargs...)
end
SciMLBase.ODEFunction(f::ODEForwardSensitivityFunction; kwargs...) = f
6 changes: 3 additions & 3 deletions src/nilss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ function forward_sense(prob::NILSSProblem, nilss::NILSS, alg)
t1 = forward_prob.tspan[1]
t2 = forward_prob.tspan[1] + T_seg
_prob = ODEForwardSensitivityProblem(
S.f, u0, (t1, t2), p, sensealg; nus = nus, w0 = w0,
S.f, u0, (t1, t2), p; sensealg = sensealg, nus = nus, w0 = w0,
v0 = vstar0)

for iseg in 1:nseg
Expand All @@ -315,8 +315,8 @@ function forward_sense(prob::NILSSProblem, nilss::NILSS, alg)
renormalize!(R, b, w_perp, vstar_perp, y, vstar, w, iseg, numparams, nus)
t1 = forward_prob.tspan[1] + iseg * T_seg
t2 = forward_prob.tspan[1] + (iseg + 1) * T_seg
_prob = ODEForwardSensitivityProblem(S.f, y[:, 1, iseg + 1], (t1, t2), p,
sensealg; nus = nus,
_prob = ODEForwardSensitivityProblem(S.f, y[:, 1, iseg + 1], (t1, t2), p;
sensealg = sensealg, nus = nus,
w0 = vec(w[:, 1, iseg + 1, :]),
v0 = vec(vstar[:, :, 1, iseg + 1]))
end
Expand Down
23 changes: 14 additions & 9 deletions test/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,19 @@ p = [1.5, 1.0, 3.0]
prob = ODEForwardSensitivityProblem(f, [1.0; 1.0], (0.0, 10.0), p)
probInpl = ODEForwardSensitivityProblem(fb, [1.0; 1.0], (0.0, 10.0), p)
probnoad = ODEForwardSensitivityProblem(fb, [1.0; 1.0], (0.0, 10.0), p,
ForwardSensitivity(autodiff = false))
sensealg = ForwardSensitivity(autodiff = false))
probnoadjacvec = ODEForwardSensitivityProblem(fb, [1.0; 1.0], (0.0, 10.0), p,
ForwardSensitivity(autodiff = false,
sensealg = ForwardSensitivity(autodiff = false,
autojacvec = true))
probnoad2 = ODEForwardSensitivityProblem(f, [1.0; 1.0], (0.0, 10.0), p,
ForwardSensitivity(autodiff = false))
sensealg = ForwardSensitivity(autodiff = false))
probvecmat = ODEForwardSensitivityProblem(fb, [1.0; 1.0], (0.0, 10.0), p,
ForwardSensitivity(autojacvec = false,
sensealg = ForwardSensitivity(autojacvec = false,
autojacmat = true))

# tests that the deprecated version still works
dep_prob_const = ODEForwardSensitivityProblem(fb, [1.0; 1.0], (0.0, 10.0), p, ForwardSensitivity())

sol = solve(prob, Tsit5(), abstol = 1e-14, reltol = 1e-14)
@test_broken solve(probInpl, KenCarp4(), abstol = 1e-14, reltol = 1e-14).retcode == :Success
solInpl = solve(probInpl, KenCarp4(autodiff = false), abstol = 1e-14, reltol = 1e-14)
Expand All @@ -44,6 +48,7 @@ solnoadjacvec = solve(probnoadjacvec, KenCarp4(autodiff = false), abstol = 1e-14
reltol = 1e-14)
solnoad2 = solve(probnoad, KenCarp4(autodiff = false), abstol = 1e-14, reltol = 1e-14)
solvecmat = solve(probvecmat, Tsit5(), abstol = 1e-14, reltol = 1e-14)
solve_dep_prob_const = solve(probvecmat, Tsit5(), abstol = 1e-14, reltol = 1e-14)

x = sol[1:(sol.prob.f.numindvar), :]

Expand Down Expand Up @@ -125,7 +130,7 @@ _, dp_ts = extract_local_sensitivities(sol, sol.t)
### ForwardDiff version

prob = ODEForwardSensitivityProblem(f.f, [1.0; 1.0], (0.0, 10.0), p,
ForwardDiffSensitivity())
sensealg = ForwardDiffSensitivity())
sol = solve(prob, Tsit5(), abstol = 1e-14, reltol = 1e-14, saveat = 0.01)

xall, dpall = extract_local_sensitivities(sol)
Expand Down Expand Up @@ -184,17 +189,17 @@ f_MM = ODEFunction(rober_MM, mass_matrix = M)
f_no_MM = ODEFunction(rober_no_MM)

prob_MM_ForwardSensitivity = ODEForwardSensitivityProblem(f_MM, u0, tspan, p,
ForwardSensitivity())
sensealg = ForwardSensitivity())
sol_MM_ForwardSensitivity = solve(prob_MM_ForwardSensitivity, Rodas4(autodiff = false),
reltol = 1e-14, abstol = 1e-14)

prob_MM_ForwardDiffSensitivity = ODEForwardSensitivityProblem(f_MM, u0, tspan, p,
ForwardDiffSensitivity())
sensealg = ForwardDiffSensitivity())
sol_MM_ForwardDiffSensitivity = solve(prob_MM_ForwardDiffSensitivity,
Rodas4(autodiff = false), reltol = 1e-14,
abstol = 1e-14)

prob_no_MM = ODEForwardSensitivityProblem(f_no_MM, u0, tspan, p, ForwardSensitivity())
prob_no_MM = ODEForwardSensitivityProblem(f_no_MM, u0, tspan, p, sensealg = ForwardSensitivity())
sol_no_MM = solve(prob_no_MM, Rodas4(autodiff = false), reltol = 1e-14, abstol = 1e-14)

sen_MM_ForwardSensitivity = extract_local_sensitivities(sol_MM_ForwardSensitivity, 10.0,
Expand Down Expand Up @@ -253,7 +258,7 @@ prob = ODEForwardSensitivityProblem(f,
[1.0; 1.0],
(0.0, 10.0),
p,
absolutely_no_ad_sensealg)
sensealg = absolutely_no_ad_sensealg)
@test SciMLSensitivity.has_original_jac(prob.f)
@assert jac_call_count == 0
solve(prob, Tsit5(), abstol = 1e-14, reltol = 1e-14)
Expand Down

0 comments on commit 6c95130

Please sign in to comment.