Skip to content

Commit

Permalink
Started exploring generalized block solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
JordiManyer committed Dec 13, 2023
1 parent aec6326 commit 161443b
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 0 deletions.
116 changes: 116 additions & 0 deletions src/BlockSolvers/BlockDiagonalSolvers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@

struct BlockDiagonalSolver{N,A,B} <: Gridap.Algebra.LinearSolver
blocks :: B
solvers :: C
function BlockDiagonalSolver(
blocks :: AbstractVector{<:SolverBlock},
solvers :: AbstractVector{<:Gridap.Algebra.LinearSolver}
)
N = length(solvers)
@check length(blocks) == N

A = typeof(blocks)
B = typeof(solvers)
return new{N,A,B}(blocks,solvers)
end
end

# Constructors

function BlockDiagonalSolver(solvers::AbstractVector{<:Gridap.Algebra.LinearSolver};
is_nonlinear::Vector{Bool}=fill(false,length(solvers)))
blocks = map(nl -> nl ? NonlinearSystemBlock() : LinearSystemBlock(),is_nonlinear)
return BlockDiagonalSolver(blocks,solvers)
end

function BlockDiagonalSolver(funcs :: AbstractArray{<:Function},
trials :: AbstractArray{<:FESpace},
tests :: AbstractArray{<:FESpace},
solvers :: AbstractArray{<:Gridap.Algebra.LinearSolver};
is_nonlinear::Vector{Bool}=fill(false,length(solvers)))
blocks = map(funcs,trials,tests,is_nonlinear) do f,trial,test,nl
nl ? TriformBlock(f,trial,test) : BiformBlock(f,trial,test)
end
return BlockDiagonalSolver(blocks,solvers)
end

function BlockDiagonalSolver(mats::AbstractVector{<:AbstractMatrix},
solvers::AbstractVector{<:Gridap.Algebra.LinearSolver})
blocks = map(MatrixBlock,mats)
return BlockDiagonalSolver(blocks,solvers)
end

# Symbolic setup

struct BlockDiagonalSolverSS{A,B,C} <: Gridap.Algebra.SymbolicSetup
solver :: A
block_ss :: B
block_caches :: C
end

function Gridap.Algebra.symbolic_setup(solver::BlockDiagonalSolver,mat::AbstractBlockMatrix)
mat_blocks = diag(blocks(mat))
block_caches = map(instantiate_block_cache,solver.blocks,mat_blocks)
block_ss = map(symbolic_setup,solver.solvers,block_caches)
return BlockDiagonalSolverSS(solver,block_ss,block_caches)
end

function Gridap.Algebra.symbolic_setup(solver::BlockDiagonalSolver,mat::AbstractBlockMatrix,x::AbstractBlockVector)
mat_blocks = diag(blocks(mat))
vec_blocks = blocks(x)
block_caches = map(instantiate_block_cache,solver.blocks,mat_blocks,vec_blocks)
block_ss = map(symbolic_setup,solver.solvers,block_caches,vec_blocks)
return BlockDiagonalSolverSS(solver,block_ss,block_caches)
end

# Numerical setup

struct BlockDiagonalSolverNS{A,B,C} <: Gridap.Algebra.NumericalSetup
solver :: A
block_ns :: B
block_caches :: C
end

function Gridap.Algebra.numerical_setup(ss::BlockDiagonalSolverSS,mat::AbstractBlockMatrix)
solver = ss.solver
block_ns = map(numerical_setup,ss.block_ss,ss.block_caches)
return BlockDiagonalSolverNS(solver,block_ns,ss.block_caches)
end

function Gridap.Algebra.numerical_setup(ss::BlockDiagonalSolverSS,mat::AbstractBlockMatrix,x::AbstractBlockVector)
solver = ss.solver
vec_blocks = blocks(x)
block_ns = map(numerical_setup,ss.block_ss,ss.block_caches,vec_blocks)
return BlockDiagonalSolverNS(solver,block_ns,ss.block_caches)
end

function Gridap.Algebra.numerical_setup!(ns::BlockDiagonalSolverNS,mat::AbstractBlockMatrix)
solver = ns.solver
mat_blocks = diag(blocks(mat))
block_caches = map(update_block_cache!,ns.block_caches,solver.blocks,mat_blocks)
map(numerical_setup!,ns.block_ns,block_caches)
return ns
end

function Gridap.Algebra.numerical_setup!(ns::BlockDiagonalSolverNS,mat::AbstractBlockMatrix,x::AbstractBlockVector)
solver = ns.solver
mat_blocks = diag(blocks(mat))
vec_blocks = blocks(x)
block_caches = map(update_block_cache!,ns.block_caches,solver.blocks,mat_blocks,vec_blocks)
map(numerical_setup!,ns.block_ns,block_caches,vec_blocks)
return ns
end

function Gridap.Algebra.solve!(x::AbstractBlockVector,ns::BlockDiagonalSolverNS,b::AbstractBlockVector)
@check blocklength(x) == blocklength(b) == length(ns.block_ns)
for (iB,bns) in enumerate(ns.block_ns)
xi = x[Block(iB)]
bi = b[Block(iB)]
solve!(xi,bns,bi)
end
return x
end

function LinearAlgebra.ldiv!(x,ns::BlockDiagonalSolverNS,b)
solve!(x,ns,b)
end
77 changes: 77 additions & 0 deletions src/BlockSolvers/BlockSolverInterfaces.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@

abstract type SolverBlock end
abstract type LinearSolverBlock <: SolverBlock end
abstract type NonlinearSolverBlock <: SolverBlock end

struct MatrixBlock{A} <: LinearSolverBlock
mat :: A
function MatrixBlock(mat::AbstractMatrix)
A = typeof(mat)
return new{A}(mat)
end
end

struct LinearSystemBlock <: LinearSolverBlock end
struct NonlinearSystemBlock <: NonlinearSolverBlock end

struct BiformBlock <: LinearSolverBlock
f :: Function
trial :: FESpace
test :: FESpace
assem :: Assembler
end

struct TriformBlock <: NonlinearSolverBlock
f :: Function
trial :: FESpace
test :: FESpace
assem :: Assembler
end

# Instantiate blocks

function instantiate_block_cache(block::LinearSolverBlock,mat::AbstractMatrix)
@abstractmethod
end
function instantiate_block_cache(block::NonlinearSolverBlock,mat::AbstractMatrix,x::AbstractVector)
@abstractmethod
end
function instantiate_block_cache(block::LinearSolverBlock,mat::AbstractMatrix,x::AbstractVector)
instantiate_block_cache(block,mat)
end

function instantiate_block_cache(block::MatrixBlock,mat::AbstractMatrix)
return block.mat
end
function instantiate_block_cache(block::BiformBlock,mat::AbstractMatrix)
return assemble_matrix(block.f,block.assem,block.trial,block.test)
end
instantiate_block_cache(block::LinearSystemBlock,mat::AbstractMatrix) = mat

function instantiate_block_cache(block::TriformSolverBlock,mat::AbstractMatrix,x::AbstractVector)
uh = FEFunction(block.trial,x)
f(u,v) = block.f(uh,u,v)
return assemble_matrix(f,block.assem,block.trial,block.test)
end
instantiate_block_cache(block::NonlinearSystemBlock,mat::AbstractMatrix,x::AbstractVector) = mat

# Update blocks

function update_block_cache!(cache,block::LinearSolverBlock,mat::AbstractMatrix)
return cache
end
function update_block_cache!(cache,block::NonlinearSolverBlock,mat::AbstractMatrix,x::AbstractVector)
@abstractmethod
end
function update_block_cache!(cache,block::LinearSolverBlock,mat::AbstractMatrix,x::AbstractVector)
update_block!(cache,block,mat)
end

function update_block_cache!(cache,block::TriformBlock,mat::AbstractMatrix,x::AbstractVector)
uh = FEFunction(block.trial,x)
f(u,v) = block.f(uh,u,v)
assemble_matrix!(mat,f,block.assem,block.trial,block.test)
end
function update_block_cache!(cache,block::NonlinearSystemBlock,mat::AbstractMatrix,x::AbstractVector)
return cache
end
22 changes: 22 additions & 0 deletions src/BlockSolvers/BlockSolvers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
module BlockSolvers
using LinearAlgebra
using SparseArrays
using SparseMatricesCSR
using BlockArrays
using IterativeSolvers

using Gridap
using Gridap.Helpers, Gridap.Algebra, Gridap.CellData, Gridap.Arrays, Gridap.FESpaces, Gridap.MultiField
using PartitionedArrays
using GridapDistributed

using GridapSolvers.MultilevelTools
using GridapSolvers.SolverInterfaces

include("BlockSolverInterfaces.jl")
include("BlockDiagonalSolvers.jl")
include("BlockTriangularSolvers.jl")

export BlockDiagonalSolver
export BlockTriangularSolver
end
Empty file.

0 comments on commit 161443b

Please sign in to comment.