From f8fb795175cbf9cfa0cc7b22f625302542d5b548 Mon Sep 17 00:00:00 2001 From: JordiManyer Date: Sat, 30 Dec 2023 23:21:10 +1100 Subject: [PATCH] Added restart for GMRES --- src/LinearSolvers/Krylov/FGMRESSolvers.jl | 66 +++++++++++++++----- src/LinearSolvers/Krylov/GMRESSolvers.jl | 76 +++++++++++++++++------ src/SolverInterfaces/ConvergenceLogs.jl | 6 ++ src/SolverInterfaces/SolverInterfaces.jl | 2 +- test/LinearSolvers/KrylovSolversTests.jl | 12 +++- 5 files changed, 123 insertions(+), 39 deletions(-) diff --git a/src/LinearSolvers/Krylov/FGMRESSolvers.jl b/src/LinearSolvers/Krylov/FGMRESSolvers.jl index 22cb5da9..412f8c24 100644 --- a/src/LinearSolvers/Krylov/FGMRESSolvers.jl +++ b/src/LinearSolvers/Krylov/FGMRESSolvers.jl @@ -2,18 +2,25 @@ # FGMRES Solver struct FGMRESSolver <: Gridap.Algebra.LinearSolver m :: Int + restart :: Bool + m_add :: Int Pr :: Gridap.Algebra.LinearSolver Pl :: Union{Gridap.Algebra.LinearSolver,Nothing} - outer_log :: ConvergenceLog{Float64} - inner_log :: ConvergenceLog{Float64} + log :: ConvergenceLog{Float64} end -function FGMRESSolver(m,Pr;Pl=nothing,maxiter=100,atol=1e-12,rtol=1.e-6,verbose=false,name="FGMRES") - outer_tols = SolverTolerances{Float64}(maxiter=maxiter,atol=atol,rtol=rtol) - outer_log = ConvergenceLog(name,outer_tols,verbose=verbose) - inner_tols = SolverTolerances{Float64}(maxiter=m,atol=atol,rtol=rtol) - inner_log = ConvergenceLog("$(name)_inner",inner_tols,verbose=verbose,nested=true) - return FGMRESSolver(m,Pr,Pl,outer_log,inner_log) +function FGMRESSolver(m,Pr;Pl=nothing,restart=false,m_add=1,maxiter=100,atol=1e-12,rtol=1.e-6,verbose=false,name="FGMRES") + tols = SolverTolerances{Float64}(maxiter=maxiter,atol=atol,rtol=rtol) + log = ConvergenceLog(name,tols,verbose=verbose) + return FGMRESSolver(m,restart,m_add,Pr,Pl,log) +end + +function restart(s::FGMRESSolver,k::Int) + if s.restart && (k > s.m) + print_message(s.log,"Restarting Krylov basis.") + return true + end + return false end AbstractTrees.children(s::FGMRESSolver) = [s.Pr,s.Pl] @@ -48,10 +55,34 @@ function get_solver_caches(solver::FGMRESSolver,A) return (V,Z,zl,H,g,c,s) end +function krylov_cache_length(ns::FGMRESNumericalSetup) + V, _, _, _, _, _, _ = ns.caches + return length(V) - 1 +end + +function expand_krylov_caches!(ns::FGMRESNumericalSetup) + V, Z, zl, H, g, c, s = ns.caches + + m = krylov_cache_length(ns) + m_add = ns.solver.m_add + m_new = m + m_add + + for _ in 1:m_add + push!(V,allocate_in_domain(ns.A)) + push!(Z,allocate_in_domain(ns.A)) + end + H_new = zeros(eltype(H),m_new+1,m_new); H_new[1:m+1,1:m] .= H + g_new = zeros(eltype(g),m_new+1); g_new[1:m+1] .= g + c_new = zeros(eltype(c),m_new); c_new[1:m] .= c + s_new = zeros(eltype(s),m_new); s_new[1:m] .= s + ns.caches = (V,Z,zl,H_new,g_new,c_new,s_new) + return H_new,g_new,c_new,s_new +end + function Gridap.Algebra.numerical_setup(ss::FGMRESSymbolicSetup, A::AbstractMatrix) solver = ss.solver Pr_ns = numerical_setup(symbolic_setup(solver.Pr,A),A) - Pl_ns = isa(solver.Pl,Nothing) ? nothing : numerical_setup(symbolic_setup(solver.Pl,A),A) + Pl_ns = !isnothing(solver.Pl) ? numerical_setup(symbolic_setup(solver.Pl,A),A) : nothing caches = get_solver_caches(solver,A) return FGMRESNumericalSetup(solver,A,Pr_ns,Pl_ns,caches) end @@ -59,7 +90,7 @@ end function Gridap.Algebra.numerical_setup(ss::FGMRESSymbolicSetup, A::AbstractMatrix, x::AbstractVector) solver = ss.solver Pr_ns = numerical_setup(symbolic_setup(solver.Pr,A,x),A,x) - Pl_ns = isa(solver.Pl,Nothing) ? nothing : numerical_setup(symbolic_setup(solver.Pl,A,x),A,x) + Pl_ns = !isnothing(solver.Pl) ? numerical_setup(symbolic_setup(solver.Pl,A,x),A,x) : nothing caches = get_solver_caches(solver,A) return FGMRESNumericalSetup(solver,A,Pr_ns,Pl_ns,caches) end @@ -84,8 +115,9 @@ end function Gridap.Algebra.solve!(x::AbstractVector,ns::FGMRESNumericalSetup,b::AbstractVector) solver, A, Pl, Pr, caches = ns.solver, ns.A, ns.Pl_ns, ns.Pr_ns, ns.caches - log, ilog = solver.outer_log, solver.inner_log V, Z, zl, H, g, c, s = caches + m = krylov_cache_length(ns) + log = solver.log fill!(V[1],zero(eltype(V[1]))) fill!(zl,zero(eltype(zl))) @@ -100,8 +132,13 @@ function Gridap.Algebra.solve!(x::AbstractVector,ns::FGMRESNumericalSetup,b::Abs V[1] ./= β fill!(H,0.0) fill!(g,0.0); g[1] = β - idone = init!(ilog,β) - while !idone + while !done && !restart(solver,j) + # Expand Krylov basis if needed + if j > m + H, g, c, s = expand_krylov_caches!(ns) + m = krylov_cache_length(ns) + end + # Arnoldi orthogonalization by Modified Gram-Schmidt fill!(V[j+1],zero(eltype(V[j+1]))) fill!(Z[j],zero(eltype(Z[j]))) @@ -127,7 +164,7 @@ function Gridap.Algebra.solve!(x::AbstractVector,ns::FGMRESNumericalSetup,b::Abs β = abs(g[j+1]) j += 1 - idone = update!(ilog,β) + done = update!(log,β) end j = j-1 @@ -141,7 +178,6 @@ function Gridap.Algebra.solve!(x::AbstractVector,ns::FGMRESNumericalSetup,b::Abs x .+= g[i] .* Z[i] end krylov_residual!(V[1],x,A,b,Pl,zl) - done = update!(log,β) end finalize!(log,β) diff --git a/src/LinearSolvers/Krylov/GMRESSolvers.jl b/src/LinearSolvers/Krylov/GMRESSolvers.jl index 14d603e9..b7afa0bf 100644 --- a/src/LinearSolvers/Krylov/GMRESSolvers.jl +++ b/src/LinearSolvers/Krylov/GMRESSolvers.jl @@ -1,18 +1,25 @@ # GMRES Solver struct GMRESSolver <: Gridap.Algebra.LinearSolver - m :: Int - Pr :: Union{Gridap.Algebra.LinearSolver,Nothing} - Pl :: Union{Gridap.Algebra.LinearSolver,Nothing} - outer_log :: ConvergenceLog{Float64} - inner_log :: ConvergenceLog{Float64} + m :: Int + restart :: Bool + m_add :: Int + Pr :: Union{Gridap.Algebra.LinearSolver,Nothing} + Pl :: Union{Gridap.Algebra.LinearSolver,Nothing} + log :: ConvergenceLog{Float64} end -function GMRESSolver(m;Pr=nothing,Pl=nothing,maxiter=100,atol=1e-12,rtol=1.e-6,verbose=false,name="GMRES") - outer_tols = SolverTolerances{Float64}(maxiter=maxiter,atol=atol,rtol=rtol) - outer_log = ConvergenceLog(name,outer_tols,verbose=verbose) - inner_tols = SolverTolerances{Float64}(maxiter=m,atol=atol,rtol=rtol) - inner_log = ConvergenceLog("$(name)_inner",inner_tols,verbose=verbose,nested=true) - return GMRESSolver(m,Pr,Pl,outer_log,inner_log) +function GMRESSolver(m;Pr=nothing,Pl=nothing,restart=false,m_add=1,maxiter=100,atol=1e-12,rtol=1.e-6,verbose=false,name="GMRES") + tols = SolverTolerances{Float64}(maxiter=maxiter,atol=atol,rtol=rtol) + log = ConvergenceLog(name,tols,verbose=verbose) + return GMRESSolver(m,restart,m_add,Pr,Pl,log) +end + +function restart(s::GMRESSolver,k::Int) + if s.restart && (k > s.m) + print_message(s.log,"Restarting Krylov basis.") + return true + end + return false end AbstractTrees.children(s::GMRESSolver) = [s.Pr,s.Pl] @@ -37,7 +44,7 @@ function get_solver_caches(solver::GMRESSolver,A) m, Pl, Pr = solver.m, solver.Pl, solver.Pr V = [allocate_in_domain(A) for i in 1:m+1] - zr = !isa(Pr,Nothing) ? allocate_in_domain(A) : nothing + zr = !isnothing(Pr) ? allocate_in_domain(A) : nothing zl = allocate_in_domain(A) H = zeros(m+1,m) # Hessenberg matrix @@ -47,10 +54,33 @@ function get_solver_caches(solver::GMRESSolver,A) return (V,zr,zl,H,g,c,s) end +function krylov_cache_length(ns::GMRESNumericalSetup) + V, _, _, _, _, _, _ = ns.caches + return length(V) - 1 +end + +function expand_krylov_caches!(ns::GMRESNumericalSetup) + V, zr, zl, H, g, c, s = ns.caches + + m = krylov_cache_length(ns) + m_add = ns.solver.m_add + m_new = m + m_add + + for _ in 1:m_add + push!(V,allocate_in_domain(ns.A)) + end + H_new = zeros(eltype(H),m_new+1,m_new); H_new[1:m+1,1:m] .= H + g_new = zeros(eltype(g),m_new+1); g_new[1:m+1] .= g + c_new = zeros(eltype(c),m_new); c_new[1:m] .= c + s_new = zeros(eltype(s),m_new); s_new[1:m] .= s + ns.caches = (V,zr,zl,H_new,g_new,c_new,s_new) + return H_new,g_new,c_new,s_new +end + function Gridap.Algebra.numerical_setup(ss::GMRESSymbolicSetup, A::AbstractMatrix) solver = ss.solver - Pr_ns = isa(solver.Pr,Nothing) ? nothing : numerical_setup(symbolic_setup(solver.Pr,A),A) - Pl_ns = isa(solver.Pl,Nothing) ? nothing : numerical_setup(symbolic_setup(solver.Pl,A),A) + Pr_ns = !isnothing(solver.Pr) ? numerical_setup(symbolic_setup(solver.Pr,A),A) : nothing + Pl_ns = !isnothing(solver.Pl) ? numerical_setup(symbolic_setup(solver.Pl,A),A) : nothing caches = get_solver_caches(solver,A) return GMRESNumericalSetup(solver,A,Pr_ns,Pl_ns,caches) end @@ -79,11 +109,12 @@ end function Gridap.Algebra.solve!(x::AbstractVector,ns::GMRESNumericalSetup,b::AbstractVector) solver, A, Pl, Pr, caches = ns.solver, ns.A, ns.Pl_ns, ns.Pr_ns, ns.caches - log, ilog = solver.outer_log, solver.inner_log V, zr, zl, H, g, c, s = caches + m = krylov_cache_length(ns) + log = solver.log fill!(V[1],zero(eltype(V[1]))) - fill!(zr,zero(eltype(zr))) + !isnothing(zr) && fill!(zr,zero(eltype(zr))) fill!(zl,zero(eltype(zl))) # Initial residual @@ -96,8 +127,13 @@ function Gridap.Algebra.solve!(x::AbstractVector,ns::GMRESNumericalSetup,b::Abst V[1] ./= β fill!(H,0.0) fill!(g,0.0); g[1] = β - idone = init!(ilog,β) - while !idone + while !done && !restart(solver,j) + # Expand Krylov basis if needed + if j > m + H, g, c, s = expand_krylov_caches!(ns) + m = krylov_cache_length(ns) + end + # Arnoldi orthogonalization by Modified Gram-Schmidt fill!(V[j+1],zero(eltype(V[j+1]))) krylov_mul!(V[j+1],A,V[j],Pr,Pl,zr,zl) @@ -122,7 +158,7 @@ function Gridap.Algebra.solve!(x::AbstractVector,ns::GMRESNumericalSetup,b::Abst β = abs(g[j+1]) j += 1 - idone = update!(ilog,β) + done = update!(log,β) end j = j-1 @@ -145,8 +181,8 @@ function Gridap.Algebra.solve!(x::AbstractVector,ns::GMRESNumericalSetup,b::Abst x .+= zr end krylov_residual!(V[1],x,A,b,Pl,zl) - done = update!(log,β) end + finalize!(log,β) return x end diff --git a/src/SolverInterfaces/ConvergenceLogs.jl b/src/SolverInterfaces/ConvergenceLogs.jl index 4e29888c..45790328 100644 --- a/src/SolverInterfaces/ConvergenceLogs.jl +++ b/src/SolverInterfaces/ConvergenceLogs.jl @@ -82,6 +82,12 @@ function finalize!(log::ConvergenceLog{T},r::T) where T return flag end +function print_message(log::ConvergenceLog{T},msg::String) where T + if log.verbose > SOLVER_VERBOSE_LOW + println(get_tabulation(log),msg) + end +end + function Base.show(io::IO,k::MIME"text/plain",log::ConvergenceLog) println(io,"ConvergenceLog[$(log.name)]") println(io," > tols: $(summary(log.tols))") diff --git a/src/SolverInterfaces/SolverInterfaces.jl b/src/SolverInterfaces/SolverInterfaces.jl index 812bf868..24563838 100644 --- a/src/SolverInterfaces/SolverInterfaces.jl +++ b/src/SolverInterfaces/SolverInterfaces.jl @@ -14,7 +14,7 @@ include("SolverInfos.jl") export SolverVerboseLevel, SolverConvergenceFlag export SolverTolerances, get_solver_tolerances, set_solver_tolerances! -export ConvergenceLog, init!, update!, finalize!, reset! +export ConvergenceLog, init!, update!, finalize!, reset!, print_message export SolverInfo diff --git a/test/LinearSolvers/KrylovSolversTests.jl b/test/LinearSolvers/KrylovSolversTests.jl index f0857fe2..ca114ff7 100644 --- a/test/LinearSolvers/KrylovSolversTests.jl +++ b/test/LinearSolvers/KrylovSolversTests.jl @@ -16,7 +16,7 @@ function test_solver(solver,op,Uh,dΩ) A, b = get_matrix(op), get_vector(op); ns = numerical_setup(symbolic_setup(solver,A),A) - x = allocate_in_domain(A) + x = allocate_in_domain(A); fill!(x,0.0) solve!(x,ns,b) u = interpolate(sol,Uh) @@ -69,10 +69,16 @@ function main(distribute,np) test_solver(gmres,op,Uh,dΩ) # GMRES without preconditioner - gmres = LinearSolvers.GMRESSolver(40;rtol=1.e-8,verbose=verbose) + gmres = LinearSolvers.GMRESSolver(10;rtol=1.e-8,verbose=verbose) test_solver(gmres,op,Uh,dΩ) - fgmres = LinearSolvers.FGMRESSolver(40,P;rtol=1.e-8,verbose=verbose) + gmres = LinearSolvers.GMRESSolver(10;restart=true,rtol=1.e-8,verbose=verbose) + test_solver(gmres,op,Uh,dΩ) + + fgmres = LinearSolvers.FGMRESSolver(10,P;rtol=1.e-8,verbose=verbose) + test_solver(fgmres,op,Uh,dΩ) + + fgmres = LinearSolvers.FGMRESSolver(10,P;restart=true,rtol=1.e-8,verbose=verbose) test_solver(fgmres,op,Uh,dΩ) pcg = LinearSolvers.CGSolver(P;rtol=1.e-8,verbose=verbose)