From a21abd7a6438dc60fc8eed3bd4ed853636a851e5 Mon Sep 17 00:00:00 2001 From: JordiManyer Date: Tue, 5 Nov 2024 19:02:35 +1100 Subject: [PATCH] We can assemble again! --- src/Assembly.jl | 52 +++++++++++++++++++++++------------------- src/GridapExtras.jl | 13 +++++++++++ src/PArraysExtras.jl | 13 +++++++++++ test/parrays_update.jl | 14 ++++++++++++ 4 files changed, 69 insertions(+), 23 deletions(-) diff --git a/src/Assembly.jl b/src/Assembly.jl index 1797a99..e0b3374 100644 --- a/src/Assembly.jl +++ b/src/Assembly.jl @@ -50,6 +50,7 @@ struct DistributedCounter{S,T,N,A,B} <: GridapType end Base.axes(a::DistributedCounter) = a.axes +Base.axes(a::DistributedCounter,d::Integer) = a.axes[d] local_views(a::DistributedCounter) = a.counters const PVectorCounter{S,T,A,B} = DistributedCounter{S,T,1,A,B} @@ -79,8 +80,18 @@ struct DistributedAllocation{S,T,N,A,B} <: GridapType end Base.axes(a::DistributedAllocation) = a.axes +Base.axes(a::DistributedAllocation,d::Integer) = a.axes[d] local_views(a::DistributedAllocation) = a.allocs +function change_axes(a::DistributedAllocation{S,T,N},axes::NTuple{N,<:PRange}) where {S,T,N} + indices = map(partition,axes) + local_axes = map(indices...) do indices... + map(ids -> Base.OneTo(local_length(ids)), indices) + end + allocs = map(change_axes,a.allocs,local_axes) + DistributedAllocation(allocs,axes,a.strategy) +end + const PVectorAllocation{S,T} = DistributedAllocation{S,T,1} const PSparseMatrixAllocation{S,T} = DistributedAllocation{S,T,2} @@ -97,9 +108,9 @@ function collect_touched_ids(a::PVectorAllocation{<:TrackedArrayAllocation}) ghost_lids = ghost_to_local(ids) ghost_owners = ghost_to_owner(ids) - touched_ghost_lids = filter(lid -> a.touched[lid],ghost_lids) + touched_ghost_lids = filter(lid -> a.touched[lid], ghost_lids) touched_ghost_owners = collect(ghost_owners[touched_ghost_lids]) - touched_ghost_gids = to_global!(touched_ghost_lids,ids) + touched_ghost_gids = to_global!(touched_ghost_lids, ids) ghost = GhostIndices(n_global,touched_ghost_gids,touched_ghost_owners) return replace_ghost(rows,ghost) end @@ -280,43 +291,38 @@ function create_from_nz_assembled( ) # Recover some data I,J,V = get_allocations(a) - test_gids, trial_gids = axes(a) + test_ids = partition(axes(a,1)) + trial_ids = partition(axes(a,2)) # convert I and J to global dof ids - to_global_indices!(I,test_gids;ax=:rows) - to_global_indices!(J,trial_gids;ax=:cols) + map(map_local_to_global!,I,test_ids) + map(map_local_to_global!,J,trial_ids) - # Create the Prange for the rows - rows = unpermute(test_gids) + # Move (I,J,V) triplets to the owner process of each row I. + rows = map(unpermute,test_ids) # rows = replace_ghost(unpermute(test_gids),I,find_owner(test_gids,I)) + t = PartitionedArrays.assemble_coo!(I,J,V,rows) - # Move (I,J,V) triplets to the owner process of each row I. 
- J_owners = find_owner(trial_gids,J) - cols = union_ghost(unpermute(trial_gids),J,J_owners) - t = _assemble_coo!(I,J,V,rows;owners=Jo) - - # Here we can overlap computations - # This is a good place to overlap since - # sending the matrix rows is a lot of data - if !isa(b,Nothing) - bprange=_setup_prange_from_pvector_allocation(b) - b = callback(bprange) - end + # Overlap CSC communication with rhs assembly + b = callback(rows) # Wait the transfer to finish wait(t) # Create the Prange for the cols - cols = _setup_prange(trial_gids,J;ax=:cols,owners=Jo) + J_owners = find_owner(trial_ids,J) + cols = map(union_ghost,map(unpermute,trial_ids),J,J_owners) # Overlap rhs communications with CSC compression t2 = async_callback(b) # Convert again I,J to local numeration - to_local_indices!(I,rows;ax=:rows) - to_local_indices!(J,cols;ax=:cols) + map(map_global_to_local!,I,rows) + map(map_global_to_local!,J,cols) - A = _setup_matrix(a,I,J,V,rows,cols) + a_sys = change_axes(a,(PRange(rows),PRange(cols))) + values = map(create_from_nz,local_views(a_sys)) + A = PSparseMatrix(values,rows,cols,true) # Wait the transfer to finish if !isa(t2,Nothing) diff --git a/src/GridapExtras.jl b/src/GridapExtras.jl index df5b118..e072318 100644 --- a/src/GridapExtras.jl +++ b/src/GridapExtras.jl @@ -53,3 +53,16 @@ end end nothing end + +# change_axes + +function change_axes(a::Algebra.CounterCOO{T,A}, axes::A) where {T,A} + b = Algebra.CounterCOO{T}(axes) + b.nnz = a.nnz + b +end + +function change_axes(a::Algebra.AllocationCOO{T,A}, axes::A) where {T,A} + counter = change_axes(a.counter,axes) + Algebra.AllocationCOO(counter,a.I,a.J,a.V) +end diff --git a/src/PArraysExtras.jl b/src/PArraysExtras.jl index 19c9714..be90762 100644 --- a/src/PArraysExtras.jl +++ b/src/PArraysExtras.jl @@ -98,6 +98,19 @@ function locally_repartition!(w::PVector,v::PVector) return w end +# SubSparseMatrix extensions + +function SparseArrays.findnz(A::PartitionedArrays.SubSparseMatrix) + I,J,V = findnz(A.parent) + rowmap, colmap = A.inv_indices + for k in eachindex(I) + I[k] = rowmap[I[k]] + J[k] = colmap[J[k]] + end + mask = map((i,j) -> (i > 0 && j > 0), I, J) + return I[mask], J[mask], V[mask] +end + # Linear algebra function LinearAlgebra.axpy!(α,x::PVector,y::PVector) diff --git a/test/parrays_update.jl b/test/parrays_update.jl index ae00925..9e39681 100644 --- a/test/parrays_update.jl +++ b/test/parrays_update.jl @@ -37,6 +37,8 @@ assem = SparseMatrixAssembler(dist_V,dist_V) dist_A1 = assemble_matrix(dist_a1,dist_V,dist_V) all(centralize(dist_A1) - serial_A1 .< 1e-10) +centralize(dist_A1) + function PartitionedArrays.precompute_nzindex(A,I,J;skip=false) K = zeros(Int32,length(I)) for (p,(i,j)) in enumerate(zip(I,J)) @@ -48,7 +50,19 @@ function PartitionedArrays.precompute_nzindex(A,I,J;skip=false) K end +Aoo = own_own_values(dist_A1).items[1] +using SparseArrays +function SparseArrays.findnz(A::PartitionedArrays.SubSparseMatrix) + I,J,V = findnz(A.parent) + rowmap, colmap = A.inv_indices + for k in eachindex(I) + I[k] = rowmap[I[k]] + J[k] = colmap[J[k]] + end + mask = map((i,j) -> (i > 0 && j > 0), I, J) + return I[mask], J[mask], V[mask] +end ############################################################################################
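
Usage sketch for the findnz extension on PartitionedArrays.SubSparseMatrix introduced in this patch. This snippet is illustrative and not part of the diff: it assumes a PSparseMatrix named `A` is already at hand (for instance `dist_A1` assembled in test/parrays_update.jl) and that the parts live on the debug backend, so per-part blocks can be reached through the `.items` field, as the test script above already does.

using SparseArrays, PartitionedArrays

# `A` is an assumed, pre-existing PSparseMatrix (e.g. dist_A1 from the test above).
Aoo = own_own_values(A).items[1]         # own-own block of part 1, a SubSparseMatrix view
I, J, V = findnz(Aoo)                    # triplets in own-row/own-column numbering;
                                         # parent entries outside the own block are dropped
Aoo_csc = sparse(I, J, V, size(Aoo)...)  # compress back into a plain SparseMatrixCSC

Filtering through the `inv_indices` maps keeps only the entries whose row and column both belong to the view, which is why the returned triplets can be passed directly to `sparse` with the view's own dimensions.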