Skip to content

Commit

Permalink
Added restart for GMRES
Browse files Browse the repository at this point in the history
  • Loading branch information
JordiManyer committed Dec 30, 2023
1 parent 119329f commit f8fb795
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 39 deletions.
66 changes: 51 additions & 15 deletions src/LinearSolvers/Krylov/FGMRESSolvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,25 @@
# FGMRES Solver
struct FGMRESSolver <: Gridap.Algebra.LinearSolver
m :: Int
restart :: Bool
m_add :: Int
Pr :: Gridap.Algebra.LinearSolver
Pl :: Union{Gridap.Algebra.LinearSolver,Nothing}
outer_log :: ConvergenceLog{Float64}
inner_log :: ConvergenceLog{Float64}
log :: ConvergenceLog{Float64}
end

function FGMRESSolver(m,Pr;Pl=nothing,maxiter=100,atol=1e-12,rtol=1.e-6,verbose=false,name="FGMRES")
outer_tols = SolverTolerances{Float64}(maxiter=maxiter,atol=atol,rtol=rtol)
outer_log = ConvergenceLog(name,outer_tols,verbose=verbose)
inner_tols = SolverTolerances{Float64}(maxiter=m,atol=atol,rtol=rtol)
inner_log = ConvergenceLog("$(name)_inner",inner_tols,verbose=verbose,nested=true)
return FGMRESSolver(m,Pr,Pl,outer_log,inner_log)
function FGMRESSolver(m,Pr;Pl=nothing,restart=false,m_add=1,maxiter=100,atol=1e-12,rtol=1.e-6,verbose=false,name="FGMRES")
tols = SolverTolerances{Float64}(maxiter=maxiter,atol=atol,rtol=rtol)
log = ConvergenceLog(name,tols,verbose=verbose)
return FGMRESSolver(m,restart,m_add,Pr,Pl,log)
end

function restart(s::FGMRESSolver,k::Int)
if s.restart && (k > s.m)
print_message(s.log,"Restarting Krylov basis.")
return true
end
return false
end

AbstractTrees.children(s::FGMRESSolver) = [s.Pr,s.Pl]
Expand Down Expand Up @@ -48,18 +55,42 @@ function get_solver_caches(solver::FGMRESSolver,A)
return (V,Z,zl,H,g,c,s)
end

function krylov_cache_length(ns::FGMRESNumericalSetup)
V, _, _, _, _, _, _ = ns.caches
return length(V) - 1
end

function expand_krylov_caches!(ns::FGMRESNumericalSetup)
V, Z, zl, H, g, c, s = ns.caches

m = krylov_cache_length(ns)
m_add = ns.solver.m_add
m_new = m + m_add

for _ in 1:m_add
push!(V,allocate_in_domain(ns.A))
push!(Z,allocate_in_domain(ns.A))
end
H_new = zeros(eltype(H),m_new+1,m_new); H_new[1:m+1,1:m] .= H
g_new = zeros(eltype(g),m_new+1); g_new[1:m+1] .= g
c_new = zeros(eltype(c),m_new); c_new[1:m] .= c
s_new = zeros(eltype(s),m_new); s_new[1:m] .= s
ns.caches = (V,Z,zl,H_new,g_new,c_new,s_new)
return H_new,g_new,c_new,s_new
end

function Gridap.Algebra.numerical_setup(ss::FGMRESSymbolicSetup, A::AbstractMatrix)
solver = ss.solver
Pr_ns = numerical_setup(symbolic_setup(solver.Pr,A),A)
Pl_ns = isa(solver.Pl,Nothing) ? nothing : numerical_setup(symbolic_setup(solver.Pl,A),A)
Pl_ns = !isnothing(solver.Pl) ? numerical_setup(symbolic_setup(solver.Pl,A),A) : nothing
caches = get_solver_caches(solver,A)
return FGMRESNumericalSetup(solver,A,Pr_ns,Pl_ns,caches)
end

function Gridap.Algebra.numerical_setup(ss::FGMRESSymbolicSetup, A::AbstractMatrix, x::AbstractVector)
solver = ss.solver
Pr_ns = numerical_setup(symbolic_setup(solver.Pr,A,x),A,x)
Pl_ns = isa(solver.Pl,Nothing) ? nothing : numerical_setup(symbolic_setup(solver.Pl,A,x),A,x)
Pl_ns = !isnothing(solver.Pl) ? numerical_setup(symbolic_setup(solver.Pl,A,x),A,x) : nothing
caches = get_solver_caches(solver,A)
return FGMRESNumericalSetup(solver,A,Pr_ns,Pl_ns,caches)
end
Expand All @@ -84,8 +115,9 @@ end

function Gridap.Algebra.solve!(x::AbstractVector,ns::FGMRESNumericalSetup,b::AbstractVector)
solver, A, Pl, Pr, caches = ns.solver, ns.A, ns.Pl_ns, ns.Pr_ns, ns.caches
log, ilog = solver.outer_log, solver.inner_log
V, Z, zl, H, g, c, s = caches
m = krylov_cache_length(ns)
log = solver.log

fill!(V[1],zero(eltype(V[1])))
fill!(zl,zero(eltype(zl)))
Expand All @@ -100,8 +132,13 @@ function Gridap.Algebra.solve!(x::AbstractVector,ns::FGMRESNumericalSetup,b::Abs
V[1] ./= β
fill!(H,0.0)
fill!(g,0.0); g[1] = β
idone = init!(ilog,β)
while !idone
while !done && !restart(solver,j)
# Expand Krylov basis if needed
if j > m
H, g, c, s = expand_krylov_caches!(ns)
m = krylov_cache_length(ns)
end

# Arnoldi orthogonalization by Modified Gram-Schmidt
fill!(V[j+1],zero(eltype(V[j+1])))
fill!(Z[j],zero(eltype(Z[j])))
Expand All @@ -127,7 +164,7 @@ function Gridap.Algebra.solve!(x::AbstractVector,ns::FGMRESNumericalSetup,b::Abs

β = abs(g[j+1])
j += 1
idone = update!(ilog,β)
done = update!(log,β)
end
j = j-1

Expand All @@ -141,7 +178,6 @@ function Gridap.Algebra.solve!(x::AbstractVector,ns::FGMRESNumericalSetup,b::Abs
x .+= g[i] .* Z[i]
end
krylov_residual!(V[1],x,A,b,Pl,zl)
done = update!(log,β)
end

finalize!(log,β)
Expand Down
76 changes: 56 additions & 20 deletions src/LinearSolvers/Krylov/GMRESSolvers.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
# GMRES Solver
struct GMRESSolver <: Gridap.Algebra.LinearSolver
m :: Int
Pr :: Union{Gridap.Algebra.LinearSolver,Nothing}
Pl :: Union{Gridap.Algebra.LinearSolver,Nothing}
outer_log :: ConvergenceLog{Float64}
inner_log :: ConvergenceLog{Float64}
m :: Int
restart :: Bool
m_add :: Int
Pr :: Union{Gridap.Algebra.LinearSolver,Nothing}
Pl :: Union{Gridap.Algebra.LinearSolver,Nothing}
log :: ConvergenceLog{Float64}
end

function GMRESSolver(m;Pr=nothing,Pl=nothing,maxiter=100,atol=1e-12,rtol=1.e-6,verbose=false,name="GMRES")
outer_tols = SolverTolerances{Float64}(maxiter=maxiter,atol=atol,rtol=rtol)
outer_log = ConvergenceLog(name,outer_tols,verbose=verbose)
inner_tols = SolverTolerances{Float64}(maxiter=m,atol=atol,rtol=rtol)
inner_log = ConvergenceLog("$(name)_inner",inner_tols,verbose=verbose,nested=true)
return GMRESSolver(m,Pr,Pl,outer_log,inner_log)
function GMRESSolver(m;Pr=nothing,Pl=nothing,restart=false,m_add=1,maxiter=100,atol=1e-12,rtol=1.e-6,verbose=false,name="GMRES")
tols = SolverTolerances{Float64}(maxiter=maxiter,atol=atol,rtol=rtol)
log = ConvergenceLog(name,tols,verbose=verbose)
return GMRESSolver(m,restart,m_add,Pr,Pl,log)
end

function restart(s::GMRESSolver,k::Int)
if s.restart && (k > s.m)
print_message(s.log,"Restarting Krylov basis.")
return true
end
return false
end

AbstractTrees.children(s::GMRESSolver) = [s.Pr,s.Pl]
Expand All @@ -37,7 +44,7 @@ function get_solver_caches(solver::GMRESSolver,A)
m, Pl, Pr = solver.m, solver.Pl, solver.Pr

V = [allocate_in_domain(A) for i in 1:m+1]
zr = !isa(Pr,Nothing) ? allocate_in_domain(A) : nothing
zr = !isnothing(Pr) ? allocate_in_domain(A) : nothing
zl = allocate_in_domain(A)

H = zeros(m+1,m) # Hessenberg matrix
Expand All @@ -47,10 +54,33 @@ function get_solver_caches(solver::GMRESSolver,A)
return (V,zr,zl,H,g,c,s)
end

function krylov_cache_length(ns::GMRESNumericalSetup)
V, _, _, _, _, _, _ = ns.caches
return length(V) - 1
end

function expand_krylov_caches!(ns::GMRESNumericalSetup)
V, zr, zl, H, g, c, s = ns.caches

m = krylov_cache_length(ns)
m_add = ns.solver.m_add
m_new = m + m_add

for _ in 1:m_add
push!(V,allocate_in_domain(ns.A))
end
H_new = zeros(eltype(H),m_new+1,m_new); H_new[1:m+1,1:m] .= H
g_new = zeros(eltype(g),m_new+1); g_new[1:m+1] .= g
c_new = zeros(eltype(c),m_new); c_new[1:m] .= c
s_new = zeros(eltype(s),m_new); s_new[1:m] .= s
ns.caches = (V,zr,zl,H_new,g_new,c_new,s_new)
return H_new,g_new,c_new,s_new
end

function Gridap.Algebra.numerical_setup(ss::GMRESSymbolicSetup, A::AbstractMatrix)
solver = ss.solver
Pr_ns = isa(solver.Pr,Nothing) ? nothing : numerical_setup(symbolic_setup(solver.Pr,A),A)
Pl_ns = isa(solver.Pl,Nothing) ? nothing : numerical_setup(symbolic_setup(solver.Pl,A),A)
Pr_ns = !isnothing(solver.Pr) ? numerical_setup(symbolic_setup(solver.Pr,A),A) : nothing
Pl_ns = !isnothing(solver.Pl) ? numerical_setup(symbolic_setup(solver.Pl,A),A) : nothing
caches = get_solver_caches(solver,A)
return GMRESNumericalSetup(solver,A,Pr_ns,Pl_ns,caches)
end
Expand Down Expand Up @@ -79,11 +109,12 @@ end

function Gridap.Algebra.solve!(x::AbstractVector,ns::GMRESNumericalSetup,b::AbstractVector)
solver, A, Pl, Pr, caches = ns.solver, ns.A, ns.Pl_ns, ns.Pr_ns, ns.caches
log, ilog = solver.outer_log, solver.inner_log
V, zr, zl, H, g, c, s = caches
m = krylov_cache_length(ns)
log = solver.log

fill!(V[1],zero(eltype(V[1])))
fill!(zr,zero(eltype(zr)))
!isnothing(zr) && fill!(zr,zero(eltype(zr)))
fill!(zl,zero(eltype(zl)))

# Initial residual
Expand All @@ -96,8 +127,13 @@ function Gridap.Algebra.solve!(x::AbstractVector,ns::GMRESNumericalSetup,b::Abst
V[1] ./= β
fill!(H,0.0)
fill!(g,0.0); g[1] = β
idone = init!(ilog,β)
while !idone
while !done && !restart(solver,j)
# Expand Krylov basis if needed
if j > m
H, g, c, s = expand_krylov_caches!(ns)
m = krylov_cache_length(ns)
end

# Arnoldi orthogonalization by Modified Gram-Schmidt
fill!(V[j+1],zero(eltype(V[j+1])))
krylov_mul!(V[j+1],A,V[j],Pr,Pl,zr,zl)
Expand All @@ -122,7 +158,7 @@ function Gridap.Algebra.solve!(x::AbstractVector,ns::GMRESNumericalSetup,b::Abst

β = abs(g[j+1])
j += 1
idone = update!(ilog,β)
done = update!(log,β)
end
j = j-1

Expand All @@ -145,8 +181,8 @@ function Gridap.Algebra.solve!(x::AbstractVector,ns::GMRESNumericalSetup,b::Abst
x .+= zr
end
krylov_residual!(V[1],x,A,b,Pl,zl)
done = update!(log,β)
end

finalize!(log,β)
return x
end
6 changes: 6 additions & 0 deletions src/SolverInterfaces/ConvergenceLogs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ function finalize!(log::ConvergenceLog{T},r::T) where T
return flag
end

function print_message(log::ConvergenceLog{T},msg::String) where T
if log.verbose > SOLVER_VERBOSE_LOW
println(get_tabulation(log),msg)
end
end

function Base.show(io::IO,k::MIME"text/plain",log::ConvergenceLog)
println(io,"ConvergenceLog[$(log.name)]")
println(io," > tols: $(summary(log.tols))")
Expand Down
2 changes: 1 addition & 1 deletion src/SolverInterfaces/SolverInterfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ include("SolverInfos.jl")

export SolverVerboseLevel, SolverConvergenceFlag
export SolverTolerances, get_solver_tolerances, set_solver_tolerances!
export ConvergenceLog, init!, update!, finalize!, reset!
export ConvergenceLog, init!, update!, finalize!, reset!, print_message

export SolverInfo

Expand Down
12 changes: 9 additions & 3 deletions test/LinearSolvers/KrylovSolversTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function test_solver(solver,op,Uh,dΩ)
A, b = get_matrix(op), get_vector(op);
ns = numerical_setup(symbolic_setup(solver,A),A)

x = allocate_in_domain(A)
x = allocate_in_domain(A); fill!(x,0.0)
solve!(x,ns,b)

u = interpolate(sol,Uh)
Expand Down Expand Up @@ -69,10 +69,16 @@ function main(distribute,np)
test_solver(gmres,op,Uh,dΩ)

# GMRES without preconditioner
gmres = LinearSolvers.GMRESSolver(40;rtol=1.e-8,verbose=verbose)
gmres = LinearSolvers.GMRESSolver(10;rtol=1.e-8,verbose=verbose)
test_solver(gmres,op,Uh,dΩ)

fgmres = LinearSolvers.FGMRESSolver(40,P;rtol=1.e-8,verbose=verbose)
gmres = LinearSolvers.GMRESSolver(10;restart=true,rtol=1.e-8,verbose=verbose)
test_solver(gmres,op,Uh,dΩ)

fgmres = LinearSolvers.FGMRESSolver(10,P;rtol=1.e-8,verbose=verbose)
test_solver(fgmres,op,Uh,dΩ)

fgmres = LinearSolvers.FGMRESSolver(10,P;restart=true,rtol=1.e-8,verbose=verbose)
test_solver(fgmres,op,Uh,dΩ)

pcg = LinearSolvers.CGSolver(P;rtol=1.e-8,verbose=verbose)
Expand Down

0 comments on commit f8fb795

Please sign in to comment.