From d4c6cf2b82cd422ec730a02140fc3fcaf4b0f084 Mon Sep 17 00:00:00 2001 From: Francesc Verdugo Date: Fri, 15 Nov 2024 08:20:09 +0100 Subject: [PATCH] Quick draft for backward_euler --- PartitionedSolvers/src/PartitionedSolvers.jl | 1 + PartitionedSolvers/src/ode_solvers.jl | 66 ++++++++++++++++++++ PartitionedSolvers/test/ode_solvers_tests.jl | 43 +++++++++++++ PartitionedSolvers/test/runtests.jl | 1 + 4 files changed, 111 insertions(+) create mode 100644 PartitionedSolvers/src/ode_solvers.jl create mode 100644 PartitionedSolvers/test/ode_solvers_tests.jl diff --git a/PartitionedSolvers/src/PartitionedSolvers.jl b/PartitionedSolvers/src/PartitionedSolvers.jl index 20c3ebe6..e3fd8a25 100644 --- a/PartitionedSolvers/src/PartitionedSolvers.jl +++ b/PartitionedSolvers/src/PartitionedSolvers.jl @@ -14,5 +14,6 @@ include("wrappers.jl") include("smoothers.jl") include("amg.jl") include("nonlinear_solvers.jl") +include("ode_solvers.jl") end # module diff --git a/PartitionedSolvers/src/ode_solvers.jl b/PartitionedSolvers/src/ode_solvers.jl new file mode 100644 index 00000000..57166f16 --- /dev/null +++ b/PartitionedSolvers/src/ode_solvers.jl @@ -0,0 +1,66 @@ + +function nonlinear_stage_problem(f,ode0,t,x0,coeffs) + workspace = nothing + nonlinear_problem(x0,residual(ode0),jacobian(ode0),workspace) do p + x = solution(p) + r = residual(p) + j = jacobian(p) + ode = update(ode0,coefficients=coeffs,residual=r,jacobian=j,solution=(t,f(x)...)) + r = residual(ode) + j = jacobian(ode) + p = update(p,residual=r,jacobian=j) + end +end + +function backward_euler_update(workspace,ode) + (;s,x,dt,ex) = workspace + t, = solution(ode) + x = solution(s) + coeffs = coefficients(ode) + p = nonlinear_stage_problem(ex,ode,t,x,coeffs) + p = update(p,solution=x) + s = update(s,problem=p) + workspace = (;s,x,dt,ex) +end + +function backward_euler_step(workspace,ode,phase=:start) + (;s,x,dt,ex) = workspace + t,u,v = solution(ode) + if phase === :start + t = first(interval(ode)) + phase = :advance + end + s = solve(s) + x = solution(s) + t += dt + u,v = ex(x) + u .= x + ode = update(ode,solution=(t,u,v)) + tend = last(interval(ode)) + if t >= tend + phase = :stop + end + workspace = (;s,x,dt,ex) + workspace = backward_euler_update(workspace,ode) + workspace,ode,phase +end + +function backward_euler(ode; + dt = (interval(ode)[end]-interval(ode)[1])/10, + solver = default_solver, + ) + @assert uses_mutable_types(ode) + coeffs = (1.0,1/dt) + t,u,v = solution(ode) + t = interval(ode)[1] + x = copy(u) + function ex(x) + v .= (x .- u) ./ dt + (x,v) + end + p = nonlinear_stage_problem(ex,ode,t,x,coeffs) + s = solver(p) + workspace = (;s,x,dt,ex) + ode_solver(backward_euler_update,backward_euler_step,ode,workspace) +end + diff --git a/PartitionedSolvers/test/ode_solvers_tests.jl b/PartitionedSolvers/test/ode_solvers_tests.jl new file mode 100644 index 00000000..e2c1a5e7 --- /dev/null +++ b/PartitionedSolvers/test/ode_solvers_tests.jl @@ -0,0 +1,43 @@ +module ODESolversTests + +import PartitionedSolvers as PS +using Test +import PartitionedArrays as PA + +function mock_ode(u) + r = zeros(1) + j = PA.sparse_matrix([1],[1],[0.0],1,1) + v = 0*u + x = (0,u,v) + ts = (0,10) + coeffs = (1.,1.) + workspace = nothing + PS.ode_problem(x,r,j,ts,coeffs,workspace) do ode + (t,u2,v2) = PS.solution(ode) + du,dv = PS.coefficients(ode) + r = PS.residual(ode) + j = PS.jacobian(ode) + if r !== nothing + r .= 2 .* u2.^2 .+ v2 .- 4*t .+ 1 + ode = PS.update(ode,residual = r) + end + if j !== nothing + j .= 4 .* u2 .* du .+ dv + ode = PS.update(ode,jacobian = j) + end + ode + end |> PS.update +end + +u = [2.0] +p = mock_ode(u) +s = PS.backward_euler(p) +for s in PS.history(s) + t,u,v = PS.solution(s) + @show PS.solution(s) + @test v[1] != 0 +end + + + +end # module diff --git a/PartitionedSolvers/test/runtests.jl b/PartitionedSolvers/test/runtests.jl index 3da585dd..2618f750 100644 --- a/PartitionedSolvers/test/runtests.jl +++ b/PartitionedSolvers/test/runtests.jl @@ -8,6 +8,7 @@ using Test @testset "interfaces" begin include("interfaces_tests.jl") end @testset "wrappers" begin include("wrappers_tests.jl") end @testset "nonlinear_solvers" begin include("nonlinear_solvers_tests.jl") end + @testset "ode_solvers" begin include("ode_solvers_tests.jl") end @testset "smoothers" begin include("smoothers_tests.jl") end @testset "amg" begin include("amg_tests.jl") end end