Skip to content

Commit

Permalink
Simplify problem constructors and minor cleanups.
Browse files Browse the repository at this point in the history
  • Loading branch information
michakraus committed Dec 5, 2024
1 parent c0d17f8 commit 7c2ddef
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 71 deletions.
15 changes: 2 additions & 13 deletions src/hamiltonian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,18 +111,7 @@ function HODE(lsys::HamiltonianSystem; kwargs...)
HODE(eqs.v, eqs.f, eqs.H; kwargs...)
end

function HODEProblem(lsys::HamiltonianSystem, tspan::Tuple, tstep::Real, ics::NamedTuple; kwargs...)
function HODEProblem(lsys::HamiltonianSystem, tspan::Tuple, tstep::Real, ics...; kwargs...)
eqs = functions(lsys)
HODEProblem(eqs.v, eqs.f, eqs.H, tspan, tstep, ics; kwargs...)
end

function HODEProblem(lsys::HamiltonianSystem, tspan::Tuple, tstep::Real, q₀::StateVariable, p₀::StateVariable; kwargs...)
ics = (q = q₀, p = p₀)
HODEProblem(lsys, tspan, tstep, ics; kwargs...)
end

function HODEProblem(lsys::HamiltonianSystem, tspan::Tuple, tstep::Real, q₀::AbstractArray, p₀::AbstractArray; kwargs...)
_q₀ = StateVariable(q₀)
_p₀ = StateVariable(p₀)
HODEProblem(lsys, tspan, tstep, _q₀, _p₀; kwargs...)
HODEProblem(eqs.v, eqs.f, eqs.H, tspan, tstep, ics...; kwargs...)
end
16 changes: 2 additions & 14 deletions src/lagrangian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,19 +126,7 @@ function LODE(lsys::LagrangianSystem; kwargs...)
LODE(eqs.ϑ, eqs.f, eqs.g, eqs.ω, eqs.L; kwargs...)
end

function LODEProblem(lsys::LagrangianSystem, tspan::Tuple, tstep::Real, ics::NamedTuple; kwargs...)
function LODEProblem(lsys::LagrangianSystem, tspan::Tuple, tstep::Real, ics...; kwargs...)
eqs = functions(lsys)
LODEProblem(eqs.ϑ, eqs.f, eqs.g, eqs.ω, eqs.L, tspan, tstep, ics; kwargs...)
end

function LODEProblem(lsys::LagrangianSystem, tspan::Tuple, tstep::Real, q₀::StateVariable, p₀::StateVariable, λ₀::AlgebraicVariable; kwargs...)
ics = (q = q₀, p = p₀, λ = λ₀)
LODEProblem(lsys, tspan, tstep, ics; kwargs...)
end

function LODEProblem(lsys::LagrangianSystem, tspan::Tuple, tstep::Real, q₀::AbstractArray, p₀::AbstractArray, λ₀::AbstractArray = zero(q₀); kwargs...)
_q₀ = StateVariable(q₀)
_p₀ = StateVariable(p₀)
_λ₀ = AlgebraicVariable(λ₀)
LODEProblem(lsys, tspan, tstep, _q₀, _p₀, _λ₀; kwargs...)
LODEProblem(eqs.ϑ, eqs.f, eqs.g, eqs.ω, eqs.L, tspan, tstep, ics...; kwargs...)
end
51 changes: 9 additions & 42 deletions src/lagrangian_degenerate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ struct DegenerateLagrangianSystem

code = (
L = substitute_parameters(build_function(equs_subs.L, t, X, V, params...; nanmath = false), params),
H = substitute_parameters(build_function(equs_subs.H, t, X, V, params...; nanmath = false), params),
H = substitute_parameters(build_function(equs_subs.H, t, X, params...; nanmath = false), params),
EL = substitute_parameters(build_function(equs_subs.EL, t, X, V, params...; nanmath = false)[2], params),
∇H = substitute_parameters(build_function(equs_subs.∇H, t, X, V, params...; nanmath = false)[2], params),
= substitute_parameters(build_function(equs_subs.ẋ, t, X, params...; nanmath = false)[2], params),
Expand Down Expand Up @@ -119,62 +119,29 @@ function ODE(lsys::DegenerateLagrangianSystem; kwargs...)
ODE(eqs.ẋ; invariants = (h = eqs.H,), kwargs...)
end

function ODEProblem(lsys::DegenerateLagrangianSystem, tspan::Tuple, tstep::Real, ics::NamedTuple; kwargs...)
function ODEProblem(lsys::DegenerateLagrangianSystem, tspan::Tuple, tstep::Real, ics...; kwargs...)
eqs = functions(lsys)
ODEProblem(eqs.ẋ, tspan, tstep, ics; invariants = (h = eqs.H,), kwargs...)
end

function ODEProblem(lsys::DegenerateLagrangianSystem, tspan::Tuple, tstep::Real, q₀::StateVariable; kwargs...)
ODEProblem(lsys, tspan, tstep, (q = q₀, ); kwargs...)
end

function ODEProblem(lsys::DegenerateLagrangianSystem, tspan::Tuple, tstep::Real, q₀::AbstractArray; kwargs...)
ODEProblem(lsys, tspan, tstep, StateVariable(q₀); kwargs...)
ODEProblem(eqs.ẋ, tspan, tstep, ics...; invariants = (h = eqs.H,), kwargs...)
end


function LODE(lsys::DegenerateLagrangianSystem; v̄ = functions(lsys).v, f̄ = functions(lsys).f, kwargs...)
eqs = functions(lsys)
LODE(eqs.ϑ, eqs.f, eqs.g, eqs.ω, eqs.L; v̄ = v̄, f̄ = f̄, invariants = (h = eqs.H,), kwargs...)
LODE(eqs.ϑ, eqs.f, eqs.g, eqs.ω, eqs.L; v̄ = v̄, f̄ = f̄, invariants = (h = (t,q,v,params) -> eqs.H(t,q,params),), kwargs...)
end

function LODEProblem(lsys::DegenerateLagrangianSystem, tspan::Tuple, tstep::Real, ics::NamedTuple; v̄ = functions(lsys).v, f̄ = functions(lsys).f, kwargs...)
function LODEProblem(lsys::DegenerateLagrangianSystem, tspan::Tuple, tstep::Real, ics...; v̄ = functions(lsys).v, f̄ = functions(lsys).f, kwargs...)
eqs = functions(lsys)
LODEProblem(eqs.ϑ, eqs.f, eqs.g, eqs.ω, eqs.L, tspan, tstep, ics; v̄ = v̄, f̄ = f̄, invariants = (h = eqs.H,), kwargs...)
end

function LODEProblem(lsys::DegenerateLagrangianSystem, tspan::Tuple, tstep::Real, q₀::StateVariable, p₀::StateVariable, λ₀::AlgebraicVariable; kwargs...)
ics = (q = q₀, p = p₀, λ = λ₀)
LODEProblem(lsys, tspan, tstep, ics; kwargs...)
end

function LODEProblem(lsys::DegenerateLagrangianSystem, tspan::Tuple, tstep::Real, q₀::AbstractArray, p₀::AbstractArray, λ₀::AbstractArray = zero(q₀); kwargs...)
_q₀ = StateVariable(q₀)
_p₀ = StateVariable(p₀)
_λ₀ = AlgebraicVariable(λ₀)
LODEProblem(lsys, tspan, tstep, _q₀, _p₀, _λ₀; kwargs...)
LODEProblem(eqs.ϑ, eqs.f, eqs.g, eqs.ω, eqs.L, tspan, tstep, ics...; v̄ = v̄, f̄ = f̄, invariants = (h = (t,q,v,params) -> eqs.H(t,q,params),), kwargs...)
end


function LDAE(lsys::DegenerateLagrangianSystem; v̄ = functions(lsys).v, f̄ = functions(lsys).f, kwargs...)
eqs = functions(lsys)
LDAE(eqs.ϑ, eqs.f, eqs.u, eqs.g, eqs.ϕ, eqs.ū, eqs.ḡ, eqs.ψ, eqs.ω, eqs.L; v̄ = v̄, f̄ = f̄, invariants = (h = eqs.H,), kwargs...)
LDAE(eqs.ϑ, eqs.f, eqs.u, eqs.g, eqs.ϕ, eqs.ū, eqs.ḡ, eqs.ψ, eqs.ω, eqs.L; v̄ = v̄, f̄ = f̄, invariants = (h = (t,q,v,params) -> eqs.H(t,q,params),), kwargs...)
end

function LDAEProblem(lsys::DegenerateLagrangianSystem, tspan::Tuple, tstep::Real, ics::NamedTuple; v̄ = functions(lsys).v, f̄ = functions(lsys).f, kwargs...)
function LDAEProblem(lsys::DegenerateLagrangianSystem, tspan::Tuple, tstep::Real, ics...; v̄ = functions(lsys).v, f̄ = functions(lsys).f, kwargs...)
eqs = functions(lsys)
LDAEProblem(eqs.ϑ, eqs.f, eqs.u, eqs.g, eqs.ϕ, eqs.ū, eqs.ḡ, eqs.ψ, eqs.ω, eqs.L, tspan, tstep, ics; v̄ = v̄, f̄ = f̄, invariants = (h = eqs.H,), kwargs...)
end

function LDAEProblem(lsys::DegenerateLagrangianSystem, tspan::Tuple, tstep::Real, q₀::StateVariable, p₀::StateVariable, λ₀::AlgebraicVariable, μ₀::AlgebraicVariable; kwargs...)
ics = (q = q₀, p = p₀, λ = λ₀, μ = μ₀)
LDAEProblem(lsys, tspan, tstep, ics; kwargs...)
end

function LDAEProblem(lsys::DegenerateLagrangianSystem, tspan::Tuple, tstep::Real, q₀::AbstractArray, p₀::AbstractArray, λ₀::AbstractArray = zero(q₀), μ₀::AbstractArray = zero(λ₀); kwargs...)
_q₀ = StateVariable(q₀)
_p₀ = StateVariable(p₀)
_λ₀ = AlgebraicVariable(λ₀)
_μ₀ = AlgebraicVariable(μ₀)
LDAEProblem(lsys, tspan, tstep, _q₀, _p₀, _λ₀, _μ₀; kwargs...)
LDAEProblem(eqs.ϑ, eqs.f, eqs.u, eqs.g, eqs.ϕ, eqs.ū, eqs.ḡ, eqs.ψ, eqs.ω, eqs.L, tspan, tstep, ics...; v̄ = v̄, f̄ = f̄, invariants = (h = (t,q,v,params) -> eqs.H(t,q,params),), kwargs...)
end
7 changes: 5 additions & 2 deletions test/lagrangian_lotka_volterra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ end

t₀, q₀, v₀ = (0.0, [2.0, 1.0], [0.5, 2.0])
p₀ = zero(v₀)
λ₀ = zero(v₀)

tspan = (0.0, 1.0)
tstep = 0.1
Expand All @@ -55,6 +56,8 @@ params_alt = (
b₂ = 1.0,
)

(p₀, t₀, q₀, v₀, params)


# Symbolic variables and parameters

Expand Down Expand Up @@ -117,7 +120,7 @@ f̃(ṗ₂, t₀, q₀, v₀, params)
@test_nowarn ODEProblem(deg_lag_sys, tspan, tstep, q₀; parameters = params)

@test_nowarn LODE(deg_lag_sys)
@test_nowarn LODEProblem(deg_lag_sys, tspan, tstep, q₀, v₀; parameters = params)
@test_nowarn LODEProblem(deg_lag_sys, tspan, tstep, q₀, p₀; parameters = params)

@test_nowarn LDAE(deg_lag_sys)
@test_nowarn LDAEProblem(deg_lag_sys, tspan, tstep, q₀, v₀; parameters = params)
@test_nowarn LDAEProblem(deg_lag_sys, tspan, tstep, q₀, p₀, λ₀; parameters = params)

0 comments on commit 7c2ddef

Please sign in to comment.