-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Jeremy E Kozdon
committed
Jul 20, 2021
1 parent
0499a9c
commit 7abb847
Showing
5 changed files
with
123 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,60 +1,99 @@ | ||
""" | ||
MatShell{T}(obj, m, n) | ||
Create a `m×n` PETSc shell matrix object wrapping `obj`. | ||
If `obj` is a `Function`, then the multiply action `obj(y,x)`; otherwise it calls `mul!(y, obj, x)`. | ||
This can be changed by defining `PETSc._mul!`. | ||
MatShell( | ||
petsclib::PetscLib, | ||
obj::OType, | ||
comm::MPI.Comm, | ||
local_rows, | ||
local_cols, | ||
global_rows = LibPETSc.PETSC_DECIDE, | ||
global_cols = LibPETSc.PETSC_DECIDE, | ||
) | ||
Create a `global_rows X global_cols` PETSc shell matrix object wrapping `obj` | ||
with local size `local_rows X local_cols`. | ||
The `obj` will be registered as an `MATOP_MULT` function and if if `obj` is a | ||
`Function`, then the multiply action `obj(y,x)`; otherwise it calls `mul!(y, | ||
obj, x)`. | ||
# External Links | ||
$(_doc_external("Mat/MatCreateShell")) | ||
$(_doc_external("Mat/MatShellSetOperation")) | ||
$(_doc_external("Mat/MATOP_MULT")) | ||
""" | ||
mutable struct MatShell{T,A} <: AbstractMat{T} | ||
mutable struct MatShell{PetscLib, PetscScalar, OType} <: | ||
AbstractMat{PetscLib, PetscScalar} | ||
ptr::CMat | ||
obj::A | ||
obj::OType | ||
end | ||
|
||
struct MatOp{PetscLib, PetscInt, Op} end | ||
|
||
function (::MatOp{PetscLib, PetscInt, LibPETSc.MATOP_MULT})( | ||
M::CMat, | ||
cx::CVec, | ||
cy::CVec, | ||
)::PetscInt where {PetscLib, PetscInt} | ||
r_ctx = Ref{Ptr{Cvoid}}() | ||
LibPETSc.MatShellGetContext(PetscLib, M, r_ctx) | ||
ptr = r_ctx[] | ||
mat = unsafe_pointer_to_objref(ptr) | ||
|
||
struct MatOp{T,Op} end | ||
PetscScalar = getlib(PetscLib).PetscScalar | ||
x = unsafe_localarray(VecPtr(PetscLib, cx); write = false) | ||
y = unsafe_localarray(VecPtr(PetscLib, cy); read = false) | ||
|
||
_mul!(y, mat, x) | ||
|
||
function _mul!(y,mat::MatShell{T,F},x) where {T, F<:Function} | ||
Base.finalize(y) | ||
Base.finalize(x) | ||
return PetscInt(0) | ||
end | ||
|
||
function _mul!( | ||
y, | ||
mat::MatShell{PetscLib, PetscScalar, F}, | ||
x, | ||
) where {PetscLib, PetscScalar, F <: Function} | ||
mat.obj(y, x) | ||
end | ||
|
||
function _mul!(y,mat::MatShell{T},x) where {T} | ||
function _mul!(y, mat::MatShell, x) where {T} | ||
LinearAlgebra.mul!(y, mat.obj, x) | ||
end | ||
|
||
MatShell{T}(obj, m, n) where {T} = MatShell{T}(obj, MPI.COMM_SELF, m, n, m, n) | ||
|
||
|
||
@for_libpetsc begin | ||
function MatShell{$PetscScalar}(obj::A, comm::MPI.Comm, m, n, M, N) where {A} | ||
mat = MatShell{$PetscScalar,A}(C_NULL, obj) | ||
# we use the MatShell object itsel | ||
ctx = pointer_from_objref(mat) | ||
@chk ccall((:MatCreateShell, $libpetsc), PetscErrorCode, | ||
(MPI.MPI_Comm,$PetscInt,$PetscInt,$PetscInt,$PetscInt,Ptr{Cvoid},Ptr{CMat}), | ||
comm, m, n, M, N, ctx, mat) | ||
|
||
mulptr = @cfunction(MatOp{$PetscScalar, MATOP_MULT}(), $PetscInt, (CMat, CVec, CVec)) | ||
@chk ccall((:MatShellSetOperation, $libpetsc), PetscErrorCode, (CMat, MatOperation, Ptr{Cvoid}), mat, MATOP_MULT, mulptr) | ||
return mat | ||
end | ||
|
||
function (::MatOp{$PetscScalar, MATOP_MULT})(M::CMat,cx::CVec,cy::CVec)::$PetscInt | ||
r_ctx = Ref{Ptr{Cvoid}}() | ||
@chk ccall((:MatShellGetContext, $libpetsc), PetscErrorCode, (CMat, Ptr{Ptr{Cvoid}}), M, r_ctx) | ||
ptr = r_ctx[] | ||
mat = unsafe_pointer_to_objref(ptr) | ||
|
||
x = unsafe_localarray($PetscScalar, cx; write=false) | ||
y = unsafe_localarray($PetscScalar, cy; read=false) | ||
|
||
_mul!(y,mat,x) | ||
|
||
Base.finalize(y) | ||
Base.finalize(x) | ||
return $PetscInt(0) | ||
end | ||
|
||
# We have to use the macro here because of the @cfunction | ||
LibPETSc.@for_petsc function MatShell( | ||
petsclib::$PetscLib, | ||
obj::OType, | ||
comm::MPI.Comm, | ||
local_rows, | ||
local_cols, | ||
global_rows = LibPETSc.PETSC_DECIDE, | ||
global_cols = LibPETSc.PETSC_DECIDE, | ||
) where {OType} | ||
mat = MatShell{$PetscLib, $PetscScalar, OType}(C_NULL, obj) | ||
|
||
# we use the MatShell object itself | ||
ctx = pointer_from_objref(mat) | ||
|
||
LibPETSc.MatCreateShell( | ||
petsclib, | ||
comm, | ||
local_rows, | ||
local_cols, | ||
global_rows, | ||
global_cols, | ||
pointer_from_objref(mat), | ||
mat, | ||
) | ||
|
||
mulptr = @cfunction( | ||
MatOp{$PetscLib, $PetscInt, LibPETSc.MATOP_MULT}(), | ||
$PetscInt, | ||
(CMat, CVec, CVec) | ||
) | ||
LibPETSc.MatShellSetOperation(petsclib, mat, LibPETSc.MATOP_MULT, mulptr) | ||
|
||
return mat | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
using Test | ||
using PETSc | ||
using MPI | ||
|
||
@testset "MatShell" begin | ||
for petsclib in PETSc.petsclibs | ||
PETSc.initialize(petsclib) | ||
PetscScalar = petsclib.PetscScalar | ||
|
||
local_rows = 10 | ||
local_cols = 5 | ||
f!(x, y) = x .= [2y; 3y] | ||
x_jl = collect | ||
|
||
matshell = | ||
PETSc.MatShell(petsclib, f!, MPI.COMM_SELF, local_rows, local_cols) | ||
x = PetscScalar.(collect(1:5)) | ||
@test matshell * x == [2x; 3x] | ||
|
||
PETSc.finalize(petsclib) | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,4 @@ include("init.jl") | |
include("options.jl") | ||
include("vec.jl") | ||
include("mat.jl") | ||
include("matshell.jl") |