Skip to content

Commit

Permalink
Simplifying linear solver interface
Browse files Browse the repository at this point in the history
  • Loading branch information
fverdugo committed Feb 12, 2024
1 parent ffd9e38 commit 4bd2afc
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 66 deletions.
1 change: 0 additions & 1 deletion extensions/PartitionedSolvers/src/PartitionedSolvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ module PartitionedSolvers
using PartitionedArrays
using SparseArrays
using LinearAlgebra
using IncompleteLU

export setup
export solve!
Expand Down
34 changes: 12 additions & 22 deletions extensions/PartitionedSolvers/src/interfaces.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,25 @@

setup(a) = a.setup
solve!(a) = a.solve!
apply!(a) = a.apply!
setup!(a) = a.setup!
finalize!(a) = a.finalize!
residual!(a) = a.residual!
jacobian!(a) = a.jacobian!
residual_and_jacobian!(a) = a.residual_and_jacobian!

abstract type AbstractLinearSolver end

struct GenericLinearSolver{A,B,C,D,E} <: AbstractLinearSolver
struct GenericLinearSolver{A,B,C,D} <: AbstractLinearSolver
setup::A # (x,A,b) -> ls_setup
solve!::B # (x,ls_setup,b) -> results
setup!::C # (ls_setup,A) -> ls_setup
finalize!::D # ls_setup -> nothing
apply!::E # (x,ls_setup,b) -> x
end

setup(solver::GenericLinearSolver,x,A,b) = solver.setup(x,A,b)
solve!(solver::GenericLinearSolver,x,S,b) = solver.solve!(x,S,b)
setup!(solver::GenericLinearSolver,S,A) = solver.setup!(S,A)
finalize!(solver::GenericLinearSolver,S) = solver.finalize!(S)

function linear_solver(;
setup,
solve!,
setup!,
finalize! = ls_setup->nothing,
apply! = (x,ls_setup,b) -> begin
fill!(x,zero(eltype(x)))
solve!(x,ls_setup,b)
x
end
)
GenericLinearSolver(setup,solve!,setup!,finalize!,apply!)
GenericLinearSolver(setup,solve!,setup!,finalize!)
end

struct Preconditioner{A,B}
Expand All @@ -38,22 +28,22 @@ struct Preconditioner{A,B}
end

function LinearAlgebra.ldiv!(x,P::Preconditioner,b)
apply!(P.solver)(x,P.solver_setup,b)
fill!(x,zero(eltype(x)))
solve!(P.solver,x,P.solver_setup,b)
x
end

function preconditioner(x,A,b,solver)
solver_setup = setup(solver)(x,A,b)
solver_setup = setup(solver,x,A,b)
Preconditioner(solver,solver_setup)
end

function preconditioner!(P::Preconditioner,A)
setup!(P.solver)(P.solver_setup,A)
setup!(P.solver,P.solver_setup,A)
P
end

function finalize!(P::Preconditioner)
PartitionedSolvers.finalize!(P.solver)(P.solver_setup)
PartitionedSolvers.finalize!(P.solver,P.solver_setup)
end


34 changes: 12 additions & 22 deletions extensions/PartitionedSolvers/src/smoothers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,21 @@ function do_nothing_linear_solver()
setup(x,A,b) = nothing
solve!(x,::Nothing,b) = copy!(x,b)
setup!(::Nothing,A) = nothing
apply! = solve!
linear_solver(;setup,setup!,solve!,apply!)
linear_solver(;setup,setup!,solve!)
end

function lu_solver()
setup(x,A,b) = lu(A)
solve! = ldiv!
setup! = lu!
apply! = ldiv!
linear_solver(;setup,solve!,setup!,apply!)
end

function ilu_solver(;kwargs...)
setup(x,A,b) = ilu(A;kwargs...)
solve! = ldiv!
setup! = lu!
apply! = ldiv!
linear_solver(;setup,solve!,setup!,apply!)
linear_solver(;setup,solve!,setup!)
end

function diagonal_solver()
setup(x,A,b) = diag(A)
solve!(x,D,b) = x .= D .\ b
apply! = solve!
setup! = diag!
linear_solver(;setup,setup!,solve!,apply!)
linear_solver(;setup,setup!,solve!)
end

function richardson_solver(solver;niters)
Expand Down Expand Up @@ -70,23 +59,24 @@ end

function additive_schwartz_solver(local_solver)
function setup(x,A,b)
local_setups = map(PartitionedSolvers.setup(local_solver),own_values(x),own_own_values(A),own_values(b))
f = (x,A,b) -> PartitionedSolvers.setup(local_solver,x,A,b)
local_setups = map(f,own_values(x),own_own_values(A),own_values(b))
local_setups
end
function setup!(local_setups,A)
map(PartitionedSolvers.setup!(local_solver),own_own_values(A))
f = (S,A) -> PartitionedSolvers.setup!(local_solver,S,A)
map(f,local_setups,own_own_values(A))
local_setups
end
function solve!(x,local_setups,b)
map(PartitionedSolvers.solve!(local_solver),own_values(x),local_setups,own_values(b))
end
function apply!(x,local_setups,b)
map(PartitionedSolvers.apply!(local_solver),own_values(x),local_setups,own_values(b))
f = (x,S,b) -> PartitionedSolvers.solve!(local_solver,x,S,b)
map(f,own_values(x),local_setups,own_values(b))
end
function finalize!(local_setups)
map(PartitionedSolvers.finalize!(local_solver),local_setups)
f = (S) -> PartitionedSolvers.finalize!(local_solver,S)
map(f,local_setups)
nothing
end
linear_solver(;setup,setup!,solve!,apply!,finalize!)
linear_solver(;setup,setup!,solve!,finalize!)
end

42 changes: 21 additions & 21 deletions extensions/PartitionedSolvers/test/smoothers_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,47 +16,47 @@ x = pones(partition(axes(A,2)))
b = A*x

solver = lu_solver()
S = setup(solver)(x,A,b)
S = setup(solver,x,A,b)
y = similar(x)
solve!(solver)(y,S,b)
solve!(solver,y,S,b)
tol = 1.e-8
@test norm(y-x)/norm(x) < tol
setup!(solver)(S,2*A)
solve!(solver)(y,S,b)
setup!(solver,S,2*A)
solve!(solver,y,S,b)
@test norm(y-x/2)/norm(x/2) < tol
finalize!(solver)(S)
finalize!(solver,S)

solver = richardson_solver(lu_solver(),niters=1)
S = setup(solver)(x,A,b)
S = setup(solver,x,A,b)
y = similar(x)
y .= 0
solve!(solver)(y,S,b)
solve!(solver,y,S,b)
tol = 1.e-8
@test norm(y-x)/norm(x) < tol
setup!(solver)(S,2*A)
solve!(solver)(y,S,b)
setup!(solver,S,2*A)
solve!(solver,y,S,b)
@test norm(y-x/2)/norm(x/2) < tol
finalize!(solver)(S)
finalize!(solver,S)

solver = jacobi_solver(;niters=1000)
S = setup(solver)(x,A,b)
S = setup(solver,x,A,b)
y = similar(x)
y .= 0
solve!(solver)(y,S,b)
solve!(solver,y,S,b)
tol = 1.e-8
@test norm(y-x)/norm(x) < tol
setup!(solver)(S,2*A)
solve!(solver)(y,S,b)
setup!(solver,S,2*A)
solve!(solver,y,S,b)
@test norm(y-x/2)/norm(x/2) < tol
finalize!(solver)(S)
finalize!(solver,S)

solver = additive_schwartz_solver(ilu_solver())
S = setup(solver)(x,A,b)
solver = additive_schwartz_solver(lu_solver())
S = setup(solver,x,A,b)
y = similar(x)
y .= 0
solve!(solver)(y,S,b)
setup!(solver)(S,2*A)
solve!(solver)(y,S,b)
finalize!(solver)(S)
solve!(solver,y,S,b)
setup!(solver,S,2*A)
solve!(solver,y,S,b)
finalize!(solver,S)

end #module

0 comments on commit 4bd2afc

Please sign in to comment.