Skip to content

Commit

Permalink
Adding some operations to PSparseMatrix
Browse files Browse the repository at this point in the history
  • Loading branch information
fverdugo committed Jan 21, 2024
1 parent 9b8ffc1 commit b7bc762
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 4 deletions.
82 changes: 82 additions & 0 deletions src/p_sparse_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,42 @@ function LinearAlgebra.fillstored!(a::AbstractSplitMatrix,v)
a
end

function Base.:*(a::Number,b::AbstractSplitMatrix)
own_own = a*b.blocks.own_own
own_ghost = a*b.blocks.own_ghost
ghost_own = a*b.blocks.ghost_own
ghost_ghost = a*b.blocks.ghost_ghost
blocks = split_matrix_blocks(own_own,own_ghost,ghost_own,ghost_ghost)
split_matrix(blocks,b.row_permutation,b.col_permutation)
end

function Base.:*(b::AbstractSplitMatrix,a::Number)
a*b
end

for op in (:+,:-)
@eval begin
function Base.$op(a::AbstractSplitMatrix)
own_own = $op(a.blocks.own_own)
own_ghost = $op(a.blocks.own_ghost)
ghost_own = $op(a.blocks.ghost_own)
ghost_ghost = $op(a.blocks.ghost_ghost)
blocks = split_matrix_blocks(own_own,own_ghost,ghost_own,ghost_ghost)
split_matrix(blocks,a.row_permutation,a.col_permutation)
end
function Base.$op(a::AbstractSplitMatrix,b::AbstractSplitMatrix)
@boundscheck @assert a.row_permutation == b.row_permutation
@boundscheck @assert a.col_permutation == b.col_permutation
own_own = $op(a.blocks.own_own,b.blocks.own_own)
own_ghost = $op(a.blocks.own_ghost,b.blocks.own_ghost)
ghost_own = $op(a.blocks.ghost_own,b.blocks.ghost_own)
ghost_ghost = $op(a.blocks.ghost_ghost,b.blocks.ghost_ghost)
blocks = split_matrix_blocks(own_own,own_ghost,ghost_own,ghost_ghost)
split_matrix(blocks,b.row_permutation,b.col_permutation)
end
end
end

function split_format_locally(A,rows,cols)
n_own_rows = own_length(rows)
n_own_cols = own_length(cols)
Expand Down Expand Up @@ -1582,6 +1618,52 @@ function psparse_consitent_impl!(B,A,::Type{<:AbstractSplitMatrix},cache)
end
end

function Base.:*(a::PSparseMatrix,b::PVector)
Ta = eltype(a)
Tb = eltype(b)
T = typeof(zero(Ta)*zero(Tb)+zero(Ta)*zero(Tb))
c = PVector{Vector{T}}(undef,partition(axes(a,1)))
mul!(c,a,b)
c
end

function Base.:*(a::Number,b::PSparseMatrix)
matrix_partition = map(partition(b)) do values
a*values
end
rows = partition(axes(b,1))
cols = partition(axes(b,2))
PSparseMatrix(matrix_partition,rows,cols,b.assembled)
end

function Base.:*(b::PSparseMatrix,a::Number)
a*b
end

for op in (:+,:-)
@eval begin
function Base.$op(a::PSparseMatrix)
matrix_partition = map(partition(a)) do a
$op(a)
end
rows = partition(axes(a,1))
cols = partition(axes(a,2))
PSparseMatrix(matrix_partition,rows,cols,a.assembled)
end
function Base.$op(a::PSparseMatrix,b::PSparseMatrix)
@boundscheck @assert matching_local_indices(axes(a,1),axes(b,1))
@boundscheck @assert matching_local_indices(axes(a,2),axes(b,2))
matrix_partition = map(partition(a),partition(b)) do a,b
$op(a,b)
end
rows = partition(axes(b,1))
cols = partition(axes(b,2))
assembled = a.assembled && b.assembled
PSparseMatrix(matrix_partition,rows,cols,assembled)
end
end
end

function LinearAlgebra.mul!(c::PVector,a::PSparseMatrix,b::PVector::Number::Number)
@boundscheck @assert matching_own_indices(axes(c,1),axes(a,1))
@boundscheck @assert matching_own_indices(axes(a,2),axes(b,1))
Expand Down
18 changes: 14 additions & 4 deletions test/p_sparse_matrix_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ function p_sparse_matrix_tests(distribute)
A = psparse(I,J,V,row_partition,col_partition) |> fetch
x = pones(partition(axes(A,2)))
y = A*x
@test isa(y,PVector)
dy = y - y

x = IterativeSolvers.cg(A,y)
Expand Down Expand Up @@ -299,10 +300,19 @@ function p_sparse_matrix_tests(distribute)
display((A,A))
display((b,b))

#Ar = renumber(A)
#br,cache = renumber(b,partition(axes(Ar,1)),reuse=true)
#cr = Ar\br
#renumber!(c,cr)
LinearAlgebra.fillstored!(A,3)
B = 2*A
@test eltype(partition(B)) == eltype(partition(A))
B = A*2
@test eltype(partition(B)) == eltype(partition(A))
B = +A
@test eltype(partition(B)) == eltype(partition(A))
B = -A
@test eltype(partition(B)) == eltype(partition(A))
C = B+A
@test eltype(partition(C)) == eltype(partition(A))
C = B-A
@test eltype(partition(C)) == eltype(partition(A))

end

0 comments on commit b7bc762

Please sign in to comment.