diff --git a/src/initialization.jl b/src/initialization.jl index 993d10686..c707b0ff6 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -3,7 +3,7 @@ A collection of all the data required for `OverrideInit`. """ -struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap} +struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap, IProbDu0Map} """ The `AbstractNonlinearProblem` to solve for initialization. """ @@ -29,12 +29,18 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap} initialized will be returned as-is. """ initializeprobpmap::IProbPmap + """ + A function which takes the solution of `initializeprob` and returns the + `du0` vector of the original problem. + """ + initializeprob_du0map::IProbDu0Map function OverrideInitData(initprob::I, update_initprob!::J, initprobmap::K, - initprobpmap::L) where {I, J, K, L} + initprobpmap::L, initprob_du0map::M = nothing) where {I, J, K, L, M} @assert initprob isa Union{SCCNonlinearProblem, NonlinearProblem, NonlinearLeastSquaresProblem} - return new{I, J, K, L}(initprob, update_initprob!, initprobmap, initprobpmap) + return new{I, J, K, L, M}( + initprob, update_initprob!, initprobmap, initprobpmap, initprob_du0map) end end @@ -171,9 +177,12 @@ Keyword arguments: provided to the `OverrideInit` constructor takes priority over this keyword argument. If the former is `nothing`, this keyword argument will be used. If it is also not provided, an error will be thrown. +- `return_du0`: Whether to use `initializeprob_du0map` (if present) and return + `du0, u0, p, success`. """ function get_initial_values(prob, valp, f, alg::OverrideInit, - iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...) + iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, + reltol = nothing, return_du0 = false, kwargs...) u0 = state_values(valp) p = parameter_values(valp) @@ -214,5 +223,10 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, p = initdata.initializeprobpmap(valp, nlsol) end + if return_du0 + du0 = initdata.initializeprob_du0map === nothing ? nothing : initdata.initializeprob_du0map(nlsol) + return du0, u0, p, SciMLBase.successful_retcode(nlsol) + end + return u0, p, SciMLBase.successful_retcode(nlsol) end diff --git a/test/initialization.jl b/test/initialization.jl index e6211e59c..a2de41c8a 100644 --- a/test/initialization.jl +++ b/test/initialization.jl @@ -229,4 +229,66 @@ end @test p ≈ 0.0 @test success end + + @testset "DAEProblem" begin + function daerhs(du, u, p, t) + return [u[1] * t + p, u[1]^2 - u[2]^2] + end + # unknowns are u[2], p, D(u[1]), D(u[2]). Parameters are u[1], t + initprob = NonlinearProblem([1.0, 1.0, 1.0, 1.0], [1.0, 0.0]) do x, _p + u2, p, du1, du2 = x + u1, t = _p + return [u1^3 - u2^3, p^2 - 2p + 1, du1 - u1 * t - p, 2u1 * du1 - 2u2 * du2] + end + + update_initializeprob! = function (iprob, integ) + iprob.p[1] = integ.u[1] + iprob.p[2] = integ.t + end + initprobmap = function (nlsol) + return [parameter_values(nlsol)[1], nlsol.u[1]] + end + initprobpmap = function (_, nlsol) + return nlsol.u[2] + end + initprob_du0map = function (nlsol) + return nlsol.u[3:4] + end + initialization_data = SciMLBase.OverrideInitData( + initprob, update_initializeprob!, initprobmap, initprobpmap, initprob_du0map) + fn = DAEFunction(daerhs; initialization_data) + prob = DAEProblem(fn, [0.0, 0.0], [2.0, 0.0], (0.0, 1.0), 0.0) + integ = init(prob, DImplicitEuler(); initializealg = NoInit()) + + initialization_data2 = SciMLBase.OverrideInitData( + initprob, update_initializeprob!, initprobmap, initprobpmap) + fn2 = DAEFunction(daerhs; initialization_data = initialization_data2) + prob2 = DAEProblem(fn2, [0.0, 0.0], [2.0, 0.0], (0.0, 1.0), 0.0) + integ2 = init(prob2, DImplicitEuler(); initializealg = NoInit()) + + nlsolve_alg = FastShortcutNonlinearPolyalg() + @testset "Doesn't return `du0` by default" begin + @test length(SciMLBase.get_initial_values( + prob, integ, fn, SciMLBase.OverrideInit(), + Val(false); nlsolve_alg, abstol, reltol)) == 3 + end + @testset "`du0 === nothing` if missing `du0map`" begin + du0, u0, p, success = SciMLBase.get_initial_values( + prob2, integ2, fn2, SciMLBase.OverrideInit(), Val(false); + nlsolve_alg, abstol, reltol, return_du0 = true) + @test du0 === nothing + @test u0 ≈ [2.0, 2.0] + @test p ≈ 1.0 + @test success + end + @testset "With `return_du0 = true`" begin + du0, u0, p, success = SciMLBase.get_initial_values( + prob, integ, fn, SciMLBase.OverrideInit(), Val(false); + nlsolve_alg, abstol, reltol, return_du0 = true) + @test du0 ≈ [1.0, 1.0] + @test u0 ≈ [2.0, 2.0] + @test p ≈ 1.0 + @test success + end + end end