Skip to content

Commit

Permalink
PathcBasedPrologationOps working in serial
Browse files Browse the repository at this point in the history
  • Loading branch information
JordiManyer committed Sep 2, 2024
1 parent ae0ea85 commit 2e2034d
Show file tree
Hide file tree
Showing 10 changed files with 231 additions and 141 deletions.
2 changes: 2 additions & 0 deletions src/LinearSolvers/LinearSolvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ export SymGaussSeidelSmoother
export GMGLinearSolver
export BlockDiagonalSmoother
export SchurComplementSolver
export SchwarzLinearSolver

# Wrappers for IterativeSolvers.jl
export IS_ConjugateGradientSolver
Expand Down Expand Up @@ -56,5 +57,6 @@ include("GMGLinearSolvers.jl")
include("IterativeLinearSolvers.jl")
include("SchurComplementSolvers.jl")
include("MatrixSolvers.jl")
include("SchwarzLinearSolvers.jl")

end
49 changes: 49 additions & 0 deletions src/LinearSolvers/SchwarzLinearSolvers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@

# TODO:
# - Implement the multiplicative case
# - Add support for weights/averaging when aggregating the additive case?

struct SchwarzLinearSolver{T,S,A} <: Algebra.LinearSolver
local_solvers::A
function SchwarzLinearSolver(
solver::Union{S,AbstractVector{<:S}};
type = :additive
) where S <: Algebra.LinearSolver
@check type in (:additive,:multiplicative)
@notimplementedif type == :multiplicative # TODO
A = typeof(solver)
new{type,S,A}(solver)
end
end

struct SchwarzSymbolicSetup{T,S,A,B} <: Algebra.SymbolicSetup
solver::SchwarzLinearSolver{T,S,A}
local_ss::B
end

function Algebra.symbolic_setup(s::SchwarzLinearSolver,mat::AbstractMatrix)
# TODO: This is where we should compute the comm coloring for the multiplicative case
expand(s) = map(m -> s,partition(mat))
expand(s::AbstractVector) = s

local_solvers = expand(s.local_solvers)
local_ss = map(symbolic_setup,local_solvers,partition(mat))
return SchwarzSymbolicSetup(s,local_ss)
end

struct SchwarzNumericalSetup{T,S,A,B} <: Algebra.NumericalSetup
solver::SchwarzLinearSolver{T,S,A}
local_ns::B
end

function Algebra.numerical_setup(ss::SchwarzSymbolicSetup,mat::PSparseMatrix)
local_ns = map(numerical_setup,ss.local_ss,partition(mat))
return SchwarzNumericalSetup(ss.solver,local_ns)
end

function Algebra.solve!(x::PVector,ns::SchwarzNumericalSetup{:additive},b::PVector)
map(solve!,partition(x),ns.local_ns,partition(b))
assemble!(x) |> wait
consistent!(x) |> wait
return x
end
134 changes: 57 additions & 77 deletions src/PatchBasedSmoothers/seq/PatchBasedLinearSolvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,112 +73,92 @@ struct PatchBasedSmootherNumericalSetup{A,B,C,D} <: Gridap.Algebra.NumericalSetu
caches :: D
end

function Gridap.Algebra.numerical_setup(ss::PatchBasedSymbolicSetup,A::AbstractMatrix)
@check !ss.solver.is_nonlinear
solver = ss.solver
Ph, Vh = solver.patch_space, solver.space
weights = solver.weighted ? compute_weight_operators(Ph,Vh) : nothing

ap(u,v) = solver.biform(u,v)
function assemble_patch_matrices(Ph::PatchFESpace,ap;local_solver=LUSolver())
assem = SparseMatrixAssembler(Ph,Ph)
Ap = assemble_matrix(ap,assem,Ph,Ph)
Ap_ns = numerical_setup(symbolic_setup(solver.local_solver,Ap),Ap)

# Caches
rp = allocate_in_range(Ap); fill!(rp,0.0)
dxp = allocate_in_domain(Ap); fill!(dxp,0.0)
caches = (rp,dxp)
Ap_ns = numerical_setup(symbolic_setup(local_solver,Ap),Ap)
return Ap, Ap_ns
end

return PatchBasedSmootherNumericalSetup(solver,nothing,Ap_ns,weights,caches)
function assemble_patch_matrices(Ph::DistributedPatchFESpace,ap;local_solver=LUSolver())
u, v = get_trial_fe_basis(Vh), get_fe_basis(Vh)
matdata = collect_cell_matrix(Ph,Ph,ap(u,v))
Ap, Ap_ns = map(local_views(Ph),matdata) do Ph, matdata
assem = SparseMatrixAssembler(Ph,Ph)
Ap = assemble_matrix(assem,matdata)
Ap_ns = numerical_setup(symbolic_setup(local_solver,Ap),Ap)
return Ap, Ap_ns
end |> PartitionedArrays.tuple_of_arrays
return Ap, Ap_ns
end

function Gridap.Algebra.numerical_setup(ss::PatchBasedSymbolicSetup,A::AbstractMatrix,x::AbstractVector)
@check ss.solver.is_nonlinear
solver = ss.solver
Ph, Vh = solver.patch_space, solver.space
weights = solver.weighted ? compute_weight_operators(Ph,Vh) : nothing

u0 = FEFunction(Vh,x)
ap(u,v) = solver.biform(u0,u,v)
function update_patch_matrices!(Ap,Ap_ns,Ph::PatchFESpace,ap)
assem = SparseMatrixAssembler(Ph,Ph)
Ap = assemble_matrix(ap,assem,Ph,Ph)
Ap_ns = numerical_setup(symbolic_setup(solver.local_solver,Ap),Ap)
assemble_matrix!(Ap,assem,Ph,Ph,ap)
numerical_setup!(Ap_ns,Ap)
end

function update_patch_matrices!(Ap,Ap_ns,Ph::DistributedPatchFESpace,ap)
u, v = get_trial_fe_basis(Vh), get_fe_basis(Vh)
matdata = collect_cell_matrix(Ph,Ph,ap(u,v))
map(Ap, Ap_ns, local_views(Ph), matdata) do Ap, Ap_ns, Ph, matdata
assem = SparseMatrixAssembler(Ph,Ph)
assemble_matrix!(Ap,assem,matdata)
numerical_setup!(Ap_ns,Ap)
end
end

# Caches
rp = allocate_in_range(Ap); fill!(rp,0.0)
dxp = allocate_in_domain(Ap); fill!(dxp,0.0)
caches = (rp,dxp)
function allocate_patch_workvectors(Ph::PatchFESpace,Vh::FESpace)
rp = zero_free_values(Ph)
dxp = zero_free_values(Ph)
return rp,dxp
end

return PatchBasedSmootherNumericalSetup(solver,Ap,Ap_ns,weights,caches)
function allocate_patch_workvectors(Ph::DistributedPatchFESpace,Vh::GridapDistributed.DistributedFESpace)
rp = zero_free_values(Ph)
dxp = zero_free_values(Ph)
r = zero_free_values(Vh)
x = zero_free_values(Vh)
return rp,dxp,r,x
end

function Gridap.Algebra.numerical_setup(ss::PatchBasedSymbolicSetup,A::PSparseMatrix)
function Gridap.Algebra.numerical_setup(ss::PatchBasedSymbolicSetup,A::AbstractMatrix)
@check !ss.solver.is_nonlinear
solver = ss.solver
local_solver = solver.local_solver
Ph, Vh = solver.patch_space, solver.space
weights = solver.weighted ? compute_weight_operators(Ph,Vh) : nothing

# Patch system solver (only local systems need to be solved)
u, v = get_trial_fe_basis(Vh), get_fe_basis(Vh)
contr = solver.biform(u,v)
matdata = collect_cell_matrix(Ph,Ph,contr)
Ap, Ap_ns = map(local_views(Ph),matdata) do Ph, matdata
assem = SparseMatrixAssembler(Ph,Ph)
Ap = assemble_matrix(assem,matdata)
Ap_ns = numerical_setup(symbolic_setup(solver.local_solver,Ap),Ap)
return Ap, Ap_ns
end |> PartitionedArrays.tuple_of_arrays

# Caches
rp = pfill(0.0,partition(Ph.gids))
dxp = pfill(0.0,partition(Ph.gids))
r = pfill(0.0,partition(Vh.gids))
x = pfill(0.0,partition(Vh.gids))
caches = (rp,dxp,r,x)

ap(u,v) = solver.biform(u,v)
Ap, Ap_ns = assemble_patch_matrices(Ph,ap;local_solver)
weights = solver.weighted ? compute_weight_operators(Ph,Vh) : nothing
caches = allocate_patch_workvectors(Ph,Vh)
return PatchBasedSmootherNumericalSetup(solver,nothing,Ap_ns,weights,caches)
end

function Gridap.Algebra.numerical_setup(ss::PatchBasedSymbolicSetup,A::PSparseMatrix,x::PVector)
function Gridap.Algebra.numerical_setup(ss::PatchBasedSymbolicSetup,A::AbstractMatrix,x::AbstractVector)
@check ss.solver.is_nonlinear
solver = ss.solver
local_solver = solver.local_solver
Ph, Vh = solver.patch_space, solver.space

u0 = FEFunction(Vh,x)
ap(u,v) = solver.biform(u0,u,v)
Ap, Ap_ns = assemble_patch_matrices(Ph,ap;local_solver)
weights = solver.weighted ? compute_weight_operators(Ph,Vh) : nothing

# Patch system solver (only local systems need to be solved)
u0, u, v = FEFunction(Vh,x), get_trial_fe_basis(Vh), get_fe_basis(Vh)
contr = solver.biform(u0,u,v)
matdata = collect_cell_matrix(Ph,Ph,contr)
Ap, Ap_ns = map(local_views(Ph),matdata) do Ph, matdata
assem = SparseMatrixAssembler(Ph,Ph)
Ap = assemble_matrix(assem,matdata)
Ap_ns = numerical_setup(symbolic_setup(solver.local_solver,Ap),Ap)
return Ap, Ap_ns
end |> PartitionedArrays.tuple_of_arrays

# Caches
rp = pfill(0.0,partition(Ph.gids))
dxp = pfill(0.0,partition(Ph.gids))
r = pfill(0.0,partition(Vh.gids))
x = pfill(0.0,partition(Vh.gids))
caches = (rp,dxp,r,x)
caches = allocate_patch_workvectors(Ph,Vh)
return PatchBasedSmootherNumericalSetup(solver,Ap,Ap_ns,weights,caches)
end

function Gridap.Algebra.numerical_setup!(ns::PatchBasedSmootherNumericalSetup,A::PSparseMatrix,x::PVector)
function Gridap.Algebra.numerical_setup!(ns::PatchBasedSmootherNumericalSetup,A::AbstractMatrix,x::AbstractVector)
@check ns.solver.is_nonlinear
solver = ns.solver
Ph, Vh = solver.patch_space, solver.space
Ap, Ap_ns = ns.local_A, ns.local_ns

u0, u, v = FEFunction(Vh,x), get_trial_fe_basis(Vh), get_fe_basis(Vh)
contr = solver.biform(u0,u,v)
matdata = collect_cell_matrix(Ph,Ph,contr)
map(Ap, Ap_ns, local_views(Ph), matdata) do Ap, Ap_ns, Ph, matdata
assem = SparseMatrixAssembler(Ph,Ph)
assemble_matrix!(Ap,assem,matdata)
numerical_setup!(Ap_ns,Ap)
end
u0 = FEFunction(Vh,x)
ap(u,v) = solver.biform(u0,u,v)
update_patch_matrices!(Ap,Ap_ns,Ph,ap)
return ns
end

function Gridap.Algebra.solve!(x::AbstractVector,ns::PatchBasedSmootherNumericalSetup,r::AbstractVector)
Expand Down
40 changes: 17 additions & 23 deletions src/PatchBasedSmoothers/seq/PatchTransferOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,8 @@ function _get_patch_cache(lev,sh,PD,lhs,rhs,is_nonlinear,cache_refine)
Ph = PatchFESpace(Uh,PD,cell_conformity;patches_mask)

# Solver caches
u, v = get_trial_fe_basis(Uh), get_fe_basis(Uh)
contr = is_nonlinear ? lhs(zero(Uh),u,v) : lhs(u,v)
matdata = collect_cell_matrix(Ph,Ph,contr)
Ap_ns, Ap = map(local_views(Ph),matdata) do Ph, matdata
assem = SparseMatrixAssembler(Ph,Ph)
Ap = assemble_matrix(assem,matdata)
Ap_ns = numerical_setup(symbolic_setup(LUSolver(),Ap),Ap)
return Ap_ns, Ap
end |> tuple_of_arrays
Ap = is_nonlinear ? Ap : nothing
ap(u,v) = is_nonlinear ? lhs(zero(Uh),u,v) : lhs(u,v)
Ap, Ap_ns = assemble_patch_matrices(Ph,ap) # TODO: This is using an LUSolver

duh = zero(Uh)
dxp, rp = zero_free_values(Ph), zero_free_values(Ph)
Expand All @@ -107,18 +99,12 @@ function MultilevelTools.update_transfer_operator!(op::PatchProlongationOperator
end

if !isa(fv_h,Nothing)
u, v = get_trial_fe_basis(Uh), get_fe_basis(Uh)
contr = op.is_nonlinear ? op.lhs(FEFunction(Uh,fv_h),u,v) : op.lhs(u,v)
matdata = collect_cell_matrix(Ph,Ph,contr)
map(Ap_ns,Ap,local_views(Ph),matdata) do Ap_ns, Ap, Ph, matdata
assem = SparseMatrixAssembler(Ph,Ph)
assemble_matrix!(Ap,assem,matdata)
numerical_setup!(Ap_ns,Ap)
end
ap(u,v) = op.is_nonlinear ? op.lhs(FEFunction(Uh,fv_h),u,v) : op.lhs(u,v)
update_patch_matrices!(Ap,Ap_ns,Ph,ap)
end
end

function LinearAlgebra.mul!(y::PVector,A::PatchProlongationOperator{Val{false}},x::PVector)
function LinearAlgebra.mul!(y::AbstractVector,A::PatchProlongationOperator{Val{false}},x::AbstractVector)
cache_refine, cache_patch, cache_redist = A.caches
model_h, Uh, fv_h, dv_h, UH, fv_H, dv_H = cache_refine
Ph, Ap_ns, Ap, duh, dxp, rp = cache_patch
Expand All @@ -130,7 +116,11 @@ function LinearAlgebra.mul!(y::PVector,A::PatchProlongationOperator{Val{false}},
uh = FEFunction(Uh,fv_h,dv_h)

assemble_vector!(v->A.rhs(uh,v),rp,Ph)
map(solve!,partition(dxp),Ap_ns,partition(rp))
if isa(y,PVector)
map(solve!,partition(dxp),Ap_ns,partition(rp))
else
solve!(dxp,Ap_ns,rp)
end
inject!(dxh,Ph,dxp)
fv_h .= fv_h .- dxh
copy!(y,fv_h)
Expand Down Expand Up @@ -230,7 +220,6 @@ struct PatchRestrictionOperator{R,A,B}
end

function PatchRestrictionOperator(lev,sh,Ip,rhs,qdegree;solver=LUSolver())

cache_refine = MultilevelTools._get_dual_projection_cache(lev,sh,qdegree,solver)
cache_redist = MultilevelTools._get_redistribution_cache(lev,sh,:residual,:restriction,:dual_projection,cache_refine)
cache_patch = Ip.caches[2]
Expand All @@ -242,6 +231,7 @@ function PatchRestrictionOperator(lev,sh,Ip,rhs,qdegree;solver=LUSolver())
end

function MultilevelTools.update_transfer_operator!(op::PatchRestrictionOperator,x::Union{PVector,Nothing})
# Note: Update is done in the prolongation operator, with which we share the cache
nothing
end

Expand All @@ -262,7 +252,7 @@ function setup_patch_restriction_operators(sh,patch_prolongations,rhs,qdegrees;k
end
end

function LinearAlgebra.mul!(y::PVector,A::PatchRestrictionOperator{Val{false}},x::PVector)
function LinearAlgebra.mul!(y::AbstractVector,A::PatchRestrictionOperator{Val{false}},x::AbstractVector)
cache_refine, cache_patch, _ = A.caches
model_h, Uh, VH, Mh_ns, rh, uh, assem, dΩhH = cache_refine
Ph, Ap_ns, Ap, duh, dxp, rp = cache_patch
Expand All @@ -272,7 +262,11 @@ function LinearAlgebra.mul!(y::PVector,A::PatchRestrictionOperator{Val{false}},x
copy!(fv_h,x)
fill!(rp,0.0)
prolongate!(rp,Ph,fv_h)
map(solve!,partition(dxp),Ap_ns,partition(rp))
if isa(y,PVector)
map(solve!,partition(dxp),Ap_ns,partition(rp))
else
solve!(dxp,Ap_ns,rp)
end
inject!(dxh,Ph,dxp)
consistent!(dxh) |> fetch

Expand Down
38 changes: 8 additions & 30 deletions src/PatchBasedSmoothers/seq/VankaSolvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,16 @@ function VankaSolver(space::MultiFieldFESpace,patch_cells::Table{<:Integer})
end

function VankaSolver(space::GridapDistributed.DistributedMultiFieldFESpace)
patch_ids = map(local_views(space)) do space
solver = VankaSolver(space)
return solver.patch_ids
end
return VankaSolver(patch_ids)
local_solvers = map(VankaSolver,local_views(space))
return SchwarzLinearSolver(local_solvers)
end

function VankaSolver(space::GridapDistributed.DistributedMultiFieldFESpace,patch_decomposition::DistributedPatchDecomposition)
patch_ids = map(local_views(space),local_views(patch_decomposition)) do space, patch_decomposition
solver = VankaSolver(space,patch_decomposition)
return solver.patch_ids
end
return VankaSolver(patch_ids)
function VankaSolver(
space::GridapDistributed.DistributedMultiFieldFESpace,
patch_decomposition::DistributedPatchDecomposition
)
local_solvers = map(VankaSolver,local_views(space),local_views(patch_decomposition))
return SchwarzLinearSolver(local_solvers)
end

struct VankaSS{A} <: Algebra.SymbolicSetup
Expand Down Expand Up @@ -93,22 +90,3 @@ function Algebra.solve!(x::AbstractVector,ns::VankaNS,b::AbstractVector)

return x
end

struct DistributedVankaNS{A,B} <: Algebra.NumericalSetup
solver::VankaSolver{A}
ns::B
end

function Algebra.numerical_setup(ss::VankaSS,mat::PSparseMatrix)
ns = map(partition(mat)) do mat
numerical_setup(ss,mat)
end
return DistributedVankaNS(ss.solver,ns)
end

function Algebra.solve!(x::PVector,ns::DistributedVankaNS,b::PVector)
map(solve!,partition(x),ns.ns,partition(b))
assemble!(x) |> wait
consistent!(x) |> wait
return x
end
Loading

0 comments on commit 2e2034d

Please sign in to comment.