Skip to content

Commit

Permalink
WIP: Working on KSP
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeremy E Kozdon committed Aug 4, 2021
1 parent 969b287 commit 1bd216c
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 75 deletions.
130 changes: 62 additions & 68 deletions src/ksp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,88 @@ const CKSP = Ptr{Cvoid}
const CKSPType = Cstring

abstract type AbstractKSP{PetscLib, PetscScalar} <: Factorization{PetscScalar} end

Base.@kwdef mutable struct KSP{PetscLib, PetscScalar} <: AbstractKSP{PetscLib, PetscScalar}

Base.@kwdef mutable struct KSP{PetscLib, PetscScalar} <:
AbstractKSP{PetscLib, PetscScalar}
ptr::CKSP = C_NULL
opts::Options{PetscLib} = Options(PetscLib)
A::Union{AbstractMat, Nothing} = nothing
P::Union{AbstractMat, Nothing} = nothing
end

function KSP(A::AbstractMat{PetscLib}; kwargs...) where PetscLib
function KSP(
A::AbstractMat{PetscLib},
P::AbstractMat{PetscLib} = A;
kwargs...,
) where {PetscLib}
@assert initialized(PetscLib)
opts = Options(PetscLib; kwargs...)
PetscScalar = PetscLib.PetscScalar
ksp = KSP{PetscLib, PetscScalar}(opts=opts, A = A)
#=
ksp = KSP{PetscLib, PetscScalar}(opts = opts)
comm = getcomm(A)

with(ksp.opts) do
@chk ccall((:KSPCreate, $libpetsc), PetscErrorCode, (MPI.MPI_Comm, Ptr{CKSP}), comm, ksp)
LibPETSc.KSPCreate(PetscLib, comm, ksp)
end
if comm == MPI.COMM_SELF

setoperators!(ksp, A, P)
setfromoptions!(ksp)

# If there is only one rank we can finalize the KSP with GC
if MPI.Comm_size(comm) == 1
finalizer(destroy, ksp)
end
=#

return ksp
end

function setoperators!(
ksp::AbstractKSP{PetscLib},
A::AbstractMat{PetscLib},
P::AbstractMat{PetscLib} = A,
) where {PetscLib}
LibPETSc.KSPSetOperators(PetscLib, ksp, A, P)
ksp.A = A
ksp.P = P
return ksp
end

function setfromoptions!(ksp::AbstractKSP{PetscLib}) where {PetscLib}
with(ksp.opts) do
LibPETSc.KSPSetFromOptions(PetscLib, ksp)
end
end

function destroy(ksp::AbstractKSP{PetscLib}) where {PetscLib}
finalized(PetscLib) || LibPETSc.MatDestroy(PetscLib, ksp)
return nothing
end

function solve!(
x::AbstractVec{PetscLib},
ksp::AbstractKSP{PetscLib},
b::AbstractVec{PetscLib},
) where {PetscLib}
with(ksp.opts) do
LibPETSc.KSPSolve(PetscLib, ksp, b, x)
end
return x
end

function LinearAlgebra.ldiv!(x::AbstractVec, ksp::AbstractKSP, b::AbstractVec)
solve!(x, ksp, b)
end
#=
function Base.:\(ksp::AbstractKSP, b::AbstractVec)
ldiv!(similar(b), ksp, b)
end
=#

#=
struct WrappedKSP{T, PetscLib} <: AbstractKSP{T, PetscLib}
ptr::CKSP
end
scalartype(::KSP{T}) where {T} = T
Base.eltype(::KSP{T}) where {T} = T
LinearAlgebra.transpose(ksp) = LinearAlgebra.Transpose(ksp)
LinearAlgebra.adjoint(ksp) = LinearAlgebra.Adjoint(ksp)
Expand Down Expand Up @@ -80,32 +131,6 @@ struct Fn_KSPComputeOperators{T} end
@for_libpetsc begin
function KSP{$PetscScalar}(comm::MPI.Comm; kwargs...)
@assert initialized($petsclib)
opts = Options($petsclib, kwargs...)
ksp = KSP{$PetscScalar, $PetscLib}(opts=opts)
with(ksp.opts) do
@chk ccall((:KSPCreate, $libpetsc), PetscErrorCode, (MPI.MPI_Comm, Ptr{CKSP}), comm, ksp)
end
if comm == MPI.COMM_SELF
finalizer(destroy, ksp)
end
return ksp
end
function destroy(ksp::KSP{$PetscScalar})
finalized($petsclib) ||
@chk ccall((:KSPDestroy, $libpetsc), PetscErrorCode, (Ptr{CKSP},), ksp)
return nothing
end
function setoperators!(ksp::KSP{$PetscScalar}, A::AbstractMat{$PetscScalar}, P::AbstractMat{$PetscScalar})
@chk ccall((:KSPSetOperators, $libpetsc), PetscErrorCode, (CKSP, CMat, CMat), ksp, A, P)
ksp._A = A
ksp._P = P
return nothing
end
function (::Fn_KSPComputeRHS{$PetscScalar})(
new_ksp_ptr::CKSP,
cb::CVec,
Expand Down Expand Up @@ -194,12 +219,6 @@ struct Fn_KSPComputeOperators{T} end
return nothing
end
function setfromoptions!(ksp::KSP{$PetscScalar})
with(ksp.opts) do
@chk ccall((:KSPSetFromOptions, $libpetsc), PetscErrorCode, (CKSP,), ksp)
end
end
function gettype(ksp::KSP{$PetscScalar})
t_r = Ref{CKSPType}()
@chk ccall((:KSPGetType, $libpetsc), PetscErrorCode, (CKSP, Ptr{CKSPType}), ksp, t_r)
Expand Down Expand Up @@ -227,14 +246,6 @@ struct Fn_KSPComputeOperators{T} end
return r_rnorm[]
end
function solve!(x::AbstractVec{$PetscScalar}, ksp::KSP{$PetscScalar}, b::AbstractVec{$PetscScalar})
with(ksp.opts) do
@chk ccall((:KSPSolve, $libpetsc), PetscErrorCode,
(CKSP, CVec, CVec), ksp, b, x)
end
return x
end
function solve!(ksp::KSP{$PetscScalar})
with(ksp.opts) do
@chk ccall((:KSPSolve, $libpetsc), PetscErrorCode,
Expand Down Expand Up @@ -266,21 +277,6 @@ function LinearAlgebra.ldiv!(x::AbstractVector{T}, ksp::KSPAT{T, LT}, b::Abstrac
end
Base.:\(ksp::KSPAT{T, LT}, b::AbstractVector{T}) where {T, LT} = ldiv!(similar(b), ksp, b)
"""
KSP(A, P; options...)
Construct a PETSc Krylov subspace solver.
Any PETSc options prefixed with `ksp_` and `pc_` can be passed as keywords.
"""
function KSP(A::AbstractMat{T}, P::AbstractMat{T}=A; kwargs...) where {T}
ksp = KSP{T}(getcomm(A); kwargs...)
setoperators!(ksp, A, P)
setfromoptions!(ksp)
return ksp
end
"""
KSP(da::AbstractDM; options...)
Expand All @@ -300,7 +296,6 @@ end
Base.show(io::IO, ksp::KSP) = _show(io, ksp)
"""
iters(ksp::KSP)
Expand All @@ -311,7 +306,6 @@ $(_doc_external("KSP/KSPGetIterationNumber"))
"""
iters
"""
resnorm(ksp::KSP)
Expand Down
9 changes: 5 additions & 4 deletions src/options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ mutable struct Options{T} <: AbstractOptions{T}
ptr::CPetscOptions
end

function Options(petsclib::PetscLibType)
function Options_(petsclib::PetscLibType)
@assert initialized(petsclib)
PetscLib = typeof(petsclib)
opts = Options{PetscLib}(C_NULL)
Expand All @@ -86,9 +86,10 @@ function Options(petsclib::PetscLibType)
return opts
end

Options(petsclib; kwargs...) = Options(petsclib, kwargs...)
function Options(petsclib, ps::Pair...)
opts = Options(petsclib)
Options(petsclib::PetscLibType; kwargs...) = Options_(petsclib, kwargs...)
Options(PetscLib::Type{<:PetscLibType}; kwargs...) = Options_(getlib(PetscLib), kwargs...)
function Options_(petsclib::PetscLibType, ps::Pair...)
opts = Options_(petsclib)
for (k, v) in ps
opts[k] = v
end
Expand Down
5 changes: 3 additions & 2 deletions src/sys.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
const CPetscObject = Ptr{Cvoid}

const UnionPetscTypes = Union{Options, AbstractVec, AbstractMat}
const UnionPetscTypes = Union{Options, AbstractVec, AbstractMat, AbstractKSP}

# allows us to pass PETSc_XXX objects directly into CXXX ccall signatures
Base.cconvert(::Type{CPetscObject}, obj::UnionPetscTypes) = obj
Expand All @@ -12,7 +12,8 @@ function Base.unsafe_convert(::Type{Ptr{CPetscObject}}, obj::UnionPetscTypes)
end

function getcomm(
obj::Union{AbstractVec{PetscLib}, AbstractMat{PetscLib}},
obj::Union{AbstractVec{PetscLib}, AbstractMat{PetscLib},
AbstractKSP{PetscLib}},
) where {PetscLib}
comm = MPI.Comm()
LibPETSc.PetscObjectGetComm(PetscLib, obj, comm)
Expand Down
7 changes: 7 additions & 0 deletions src/vec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -612,3 +612,10 @@ function getpetsctype(vec::AbstractVec{PetscLib}) where {PetscLib}
LibPETSc.VecGetType(PetscLib, vec, name_r)
return unsafe_string(name_r[])
end

function Base.similar(v::AbstractVec{PetscLib}) where {PetscLib}
r_x = Ref{CVec}()
LibPETSc.VecDuplicate(PetscLib, v, r_x)
x = VecPtr(PetscLib, r_x[], true)
return x
end
76 changes: 76 additions & 0 deletions test/ksp.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
using Test
using MPI
MPI.Initialized() || MPI.Init()
using PETSc
using LinearAlgebra: mul!

@testset "KSP" begin
comm = MPI.COMM_WORLD
mpisize = MPI.Comm_size(comm)
mpirank = MPI.Comm_rank(comm)

for petsclib in PETSc.petsclibs
PETSc.initialize(petsclib)
PetscScalar = petsclib.PetscScalar
PetscInt = petsclib.PetscInt

loc_num_rows = 10
loc_num_cols = 10
diag_nonzeros = 3
off_diag_non_zeros = 3

A = PETSc.MatAIJ(
petsclib,
comm,
loc_num_rows,
loc_num_cols,
diag_nonzeros,
off_diag_non_zeros,
)

# Get compatible vectors
(x, b) = PETSc.createvecs(A)

row_rng = PETSc.ownershiprange(A, false)
for i in row_rng
if i == 0
vals = [-2, 1]
row0idxs = [i]
col0idxs = [i, i + 1]
elseif i == mpisize * loc_num_rows - 1
vals = [-2, 1]
row0idxs = [i]
col0idxs = [i, i - 1]
else
vals = [1, -2, 1]
row0idxs = [i]
col0idxs = [i - 1, i, i + 1]
end
PETSc.setvalues!(
A,
PetscInt.(row0idxs),
PetscInt.(col0idxs),
PetscScalar.(vals),
)
x[i + 1] = (i + 1)^3
end
PETSc.assemble!(A)
PETSc.assemble!(x)

mul!(b, A, x)
y = similar(x)

ksp = PETSc.KSP(A; ksp_rtol = 1e-16, pc_type = "jacobi")
PETSc.solve!(y, ksp, b)
PETSc.withlocalarray!(x, y) do x, y
@test x y
end

# PETSc.destroy(ksp)
PETSc.destroy(A)
PETSc.destroy(x)
PETSc.destroy(y)
PETSc.destroy(b)
PETSc.finalize(petsclib)
end
end
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using MPI: mpiexec

# Do the MPI tests first so we do not have mpi running inside MPI
@testset "mpi tests" begin
for file in ("mpivec.jl", "mpimat.jl")
for file in ("mpivec.jl", "mpimat.jl", "ksp.jl")
@test mpiexec() do mpi_cmd
cmd =
`$mpi_cmd -n 4 $(Base.julia_cmd()) --startup-file=no --project $file`
Expand All @@ -18,3 +18,4 @@ include("options.jl")
include("vec.jl")
include("mat.jl")
include("matshell.jl")
include("ksp.jl")

0 comments on commit 1bd216c

Please sign in to comment.