Skip to content

Commit

Permalink
Fixed BlockPArray bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
JordiManyer committed Apr 12, 2024
1 parent e2ecabf commit a343032
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 13 deletions.
54 changes: 45 additions & 9 deletions src/BlockPartitionedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,17 +145,19 @@ end

function Base.copyto!(y::BlockPVector,x::BlockPVector)
@check blocklength(x) == blocklength(y)
for i in blockaxes(x,1)
copyto!(y[i],x[i])
yb, xb = blocks(y), blocks(x)
for i in 1:blocksize(x,1)
copyto!(yb[i],xb[i])
end
return y
end

function Base.copyto!(y::BlockPMatrix,x::BlockPMatrix)
@check blocksize(x) == blocksize(y)
for i in blockaxes(x,1)
for j in blockaxes(x,2)
copyto!(y[i,j],x[i,j])
yb, xb = blocks(y), blocks(x)
for i in 1:blocksize(x,1)
for j in 1:blocksize(x,2)
copyto!(yb[i,j],xb[i,j])
end
end
return y
Expand All @@ -169,6 +171,8 @@ function Base.fill!(a::BlockPVector,v)
end

function Base.sum(a::BlockPArray)
# TODO: This could use a single communication, instead of one for each block
# TODO: We could implement a generic reduce, that we apply to sum, all, any, etc..
return sum(map(sum,blocks(a)))
end

Expand Down Expand Up @@ -284,15 +288,47 @@ end

# LinearAlgebra API

function Base.:*(a::Number,b::BlockArray)
mortar(map(bi -> a*bi,blocks(b)))
end
Base.:*(b::BlockPMatrix,a::Number) = a*b
Base.:/(b::BlockPVector,a::Number) = (1/a)*b

function Base.:*(a::BlockPMatrix,b::BlockPVector)
c = similar(b)
mul!(c,a,b)
return c
end

for op in (:+,:-)
@eval begin
function Base.$op(a::BlockPArray)
mortar(map($op,blocks(a)))
end
function Base.$op(a::BlockPArray,b::BlockPArray)
@assert blocksize(a) == blocksize(b)
mortar(map($op,blocks(a),blocks(b)))
end
end
end

function LinearAlgebra.mul!(y::BlockPVector,A::BlockPMatrix,x::BlockPVector)
o = one(eltype(A))
mul!(y,A,x,o,o)
end

function LinearAlgebra.mul!(y::BlockPVector,A::BlockPMatrix,x::BlockPVector::Number::Number)
yb, Ab, xb = blocks(y), blocks(A), blocks(x)
z = zero(eltype(y))
o = one(eltype(A))
for i in blockaxes(A,2)
fill!(y[i],z)
for j in blockaxes(A,2)
mul!(y[i],A[i,j],x[j],o,o)
for i in 1:blocksize(A,1)
fill!(yb[i],z)
for j in 1:blocksize(A,2)
mul!(yb[i],Ab[i,j],xb[j],α,o)
end
rmul!(yb[i],β)
end
return y
end

function LinearAlgebra.dot(x::BlockPVector,y::BlockPVector)
Expand Down
8 changes: 4 additions & 4 deletions test/MultiFieldTests.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
module MultiFieldTests

using Gridap
using Gridap.FESpaces
using Gridap.MultiField
using Gridap.FESpaces, Gridap.MultiField, Gridap.Algebra
using GridapDistributed
using PartitionedArrays
using Test
Expand Down Expand Up @@ -74,8 +73,9 @@ function main(distribute, parts, mfs)
A1 = assemble_matrix(a1,UxP,UxP)
A2 = assemble_matrix(a2,UxP,UxP)

x = prandn(partition(axes(A1,2)))
@test norm(A1*x-A2*x) < 1.0e-9
x1 = allocate_in_domain(A1); fill!(x1,1.0)
x2 = allocate_in_domain(A2); fill!(x2,1.0)
@test norm(A1*x1-A2*x2) < 1.0e-9
end

function main(distribute, parts)
Expand Down

0 comments on commit a343032

Please sign in to comment.