From b7bc762cf0c192f091671d18021d2914c599b547 Mon Sep 17 00:00:00 2001 From: Francesc Verdugo Date: Sun, 21 Jan 2024 16:17:58 +0100 Subject: [PATCH] Adding some operations to PSparseMatrix --- src/p_sparse_matrix.jl | 82 +++++++++++++++++++++++++++++++++++ test/p_sparse_matrix_tests.jl | 18 ++++++-- 2 files changed, 96 insertions(+), 4 deletions(-) diff --git a/src/p_sparse_matrix.jl b/src/p_sparse_matrix.jl index 2bb3413b..4964e80e 100644 --- a/src/p_sparse_matrix.jl +++ b/src/p_sparse_matrix.jl @@ -697,6 +697,42 @@ function LinearAlgebra.fillstored!(a::AbstractSplitMatrix,v) a end +function Base.:*(a::Number,b::AbstractSplitMatrix) + own_own = a*b.blocks.own_own + own_ghost = a*b.blocks.own_ghost + ghost_own = a*b.blocks.ghost_own + ghost_ghost = a*b.blocks.ghost_ghost + blocks = split_matrix_blocks(own_own,own_ghost,ghost_own,ghost_ghost) + split_matrix(blocks,b.row_permutation,b.col_permutation) +end + +function Base.:*(b::AbstractSplitMatrix,a::Number) + a*b +end + +for op in (:+,:-) + @eval begin + function Base.$op(a::AbstractSplitMatrix) + own_own = $op(a.blocks.own_own) + own_ghost = $op(a.blocks.own_ghost) + ghost_own = $op(a.blocks.ghost_own) + ghost_ghost = $op(a.blocks.ghost_ghost) + blocks = split_matrix_blocks(own_own,own_ghost,ghost_own,ghost_ghost) + split_matrix(blocks,a.row_permutation,a.col_permutation) + end + function Base.$op(a::AbstractSplitMatrix,b::AbstractSplitMatrix) + @boundscheck @assert a.row_permutation == b.row_permutation + @boundscheck @assert a.col_permutation == b.col_permutation + own_own = $op(a.blocks.own_own,b.blocks.own_own) + own_ghost = $op(a.blocks.own_ghost,b.blocks.own_ghost) + ghost_own = $op(a.blocks.ghost_own,b.blocks.ghost_own) + ghost_ghost = $op(a.blocks.ghost_ghost,b.blocks.ghost_ghost) + blocks = split_matrix_blocks(own_own,own_ghost,ghost_own,ghost_ghost) + split_matrix(blocks,b.row_permutation,b.col_permutation) + end + end +end + function split_format_locally(A,rows,cols) n_own_rows = own_length(rows) n_own_cols = own_length(cols) @@ -1582,6 +1618,52 @@ function psparse_consitent_impl!(B,A,::Type{<:AbstractSplitMatrix},cache) end end +function Base.:*(a::PSparseMatrix,b::PVector) + Ta = eltype(a) + Tb = eltype(b) + T = typeof(zero(Ta)*zero(Tb)+zero(Ta)*zero(Tb)) + c = PVector{Vector{T}}(undef,partition(axes(a,1))) + mul!(c,a,b) + c +end + +function Base.:*(a::Number,b::PSparseMatrix) + matrix_partition = map(partition(b)) do values + a*values + end + rows = partition(axes(b,1)) + cols = partition(axes(b,2)) + PSparseMatrix(matrix_partition,rows,cols,b.assembled) +end + +function Base.:*(b::PSparseMatrix,a::Number) + a*b +end + +for op in (:+,:-) + @eval begin + function Base.$op(a::PSparseMatrix) + matrix_partition = map(partition(a)) do a + $op(a) + end + rows = partition(axes(a,1)) + cols = partition(axes(a,2)) + PSparseMatrix(matrix_partition,rows,cols,a.assembled) + end + function Base.$op(a::PSparseMatrix,b::PSparseMatrix) + @boundscheck @assert matching_local_indices(axes(a,1),axes(b,1)) + @boundscheck @assert matching_local_indices(axes(a,2),axes(b,2)) + matrix_partition = map(partition(a),partition(b)) do a,b + $op(a,b) + end + rows = partition(axes(b,1)) + cols = partition(axes(b,2)) + assembled = a.assembled && b.assembled + PSparseMatrix(matrix_partition,rows,cols,assembled) + end + end +end + function LinearAlgebra.mul!(c::PVector,a::PSparseMatrix,b::PVector,α::Number,β::Number) @boundscheck @assert matching_own_indices(axes(c,1),axes(a,1)) @boundscheck @assert matching_own_indices(axes(a,2),axes(b,1)) diff --git a/test/p_sparse_matrix_tests.jl b/test/p_sparse_matrix_tests.jl index 8950032f..22f9f539 100644 --- a/test/p_sparse_matrix_tests.jl +++ b/test/p_sparse_matrix_tests.jl @@ -217,6 +217,7 @@ function p_sparse_matrix_tests(distribute) A = psparse(I,J,V,row_partition,col_partition) |> fetch x = pones(partition(axes(A,2))) y = A*x + @test isa(y,PVector) dy = y - y x = IterativeSolvers.cg(A,y) @@ -299,10 +300,19 @@ function p_sparse_matrix_tests(distribute) display((A,A)) display((b,b)) - #Ar = renumber(A) - #br,cache = renumber(b,partition(axes(Ar,1)),reuse=true) - #cr = Ar\br - #renumber!(c,cr) + LinearAlgebra.fillstored!(A,3) + B = 2*A + @test eltype(partition(B)) == eltype(partition(A)) + B = A*2 + @test eltype(partition(B)) == eltype(partition(A)) + B = +A + @test eltype(partition(B)) == eltype(partition(A)) + B = -A + @test eltype(partition(B)) == eltype(partition(A)) + C = B+A + @test eltype(partition(C)) == eltype(partition(A)) + C = B-A + @test eltype(partition(C)) == eltype(partition(A)) end