diff --git a/src/Assembly.jl b/src/Assembly.jl index e0b3374..a5f10c9 100644 --- a/src/Assembly.jl +++ b/src/Assembly.jl @@ -262,34 +262,37 @@ function create_from_nz_locally_assembled( a, callback::Function = rows -> nothing ) - # Recover some data I,J,V = get_allocations(a) - test_gids, trial_gids = axes(a) + test_ids = partition(axes(a,1)) + trial_ids = partition(axes(a,2)) - rows = remove_ghost(unpermute(test_gids)) + rows = map(remove_ghost,map(unpermute,test_ids)) b = callback(rows) - # convert I and J to global dof ids - to_global_indices!(I,test_gids;ax=:rows) - to_global_indices!(J,trial_gids;ax=:cols) + map(map_local_to_global!,I,test_ids) + map(map_local_to_global!,J,trial_ids) - # Create the range for cols - cols = _setup_prange(trial_gids,J;ax=:cols) + # TODO: replace_ghost or union_ghost? + # Actually, do we want to change the ghosts at all? We could return the unpermuted trial_ids + J_owners = find_owner(trial_ids,J) + cols = map(replace_ghost,map(unpermute,trial_ids),J,J_owners) - # Convert again I,J to local numeration - to_local_indices!(I,rows;ax=:rows) - to_local_indices!(J,cols;ax=:cols) + map(map_global_to_local!,I,rows) + map(map_global_to_local!,J,cols) + + assembled = true + a_sys = change_axes(a,(PRange(rows),PRange(cols))) + values = map(create_from_nz,local_views(a_sys)) + A = PSparseMatrix(values,rows,cols,assembled) - A = _setup_matrix(I,J,V,rows,cols) return A, b end function create_from_nz_assembled( - a, + a, callback::Function = rows -> nothing, - async_callback::Function = b -> nothing + async_callback::Function = b -> empty_async_task ) - # Recover some data I,J,V = get_allocations(a) test_ids = partition(axes(a,1)) trial_ids = partition(axes(a,2)) @@ -298,36 +301,44 @@ function create_from_nz_assembled( map(map_local_to_global!,I,test_ids) map(map_local_to_global!,J,trial_ids) - # Move (I,J,V) triplets to the owner process of each row I. + # Overlapped COO communication and vector assembly rows = map(unpermute,test_ids) - # rows = replace_ghost(unpermute(test_gids),I,find_owner(test_gids,I)) t = PartitionedArrays.assemble_coo!(I,J,V,rows) - - # Overlap CSC communication with rhs assembly b = callback(rows) - - # Wait the transfer to finish wait(t) - # Create the Prange for the cols - J_owners = find_owner(trial_ids,J) - cols = map(union_ghost,map(unpermute,trial_ids),J,J_owners) - # Overlap rhs communications with CSC compression t2 = async_callback(b) + J_owners = find_owner(trial_ids,J) + cols = map(replace_ghost,map(unpermute,trial_ids),J,J_owners) # TODO: replace_ghost or union_ghost? - # Convert again I,J to local numeration map(map_global_to_local!,I,rows) map(map_global_to_local!,J,cols) + assembled = true a_sys = change_axes(a,(PRange(rows),PRange(cols))) values = map(create_from_nz,local_views(a_sys)) - A = PSparseMatrix(values,rows,cols,true) + A = PSparseMatrix(values,rows,cols,assembled) - # Wait the transfer to finish - if !isa(t2,Nothing) - wait(t2) - end + wait(t2) + return A, b +end + +function create_from_nz_subassembled( + a, + callback::Function = rows -> nothing, + async_callback::Function = b -> empty_async_task +) + rows = partition(axes(a,1)) + cols = partition(axes(a,2)) + + b = callback(rows) + t2 = async_callback(b) + + assembled = false + values = map(create_from_nz,local_views(a)) + A = PSparseMatrix(values,rows,cols,assembled) + wait(t2) return A, b end diff --git a/src/PArraysExtras.jl b/src/PArraysExtras.jl index be90762..4d87e2b 100644 --- a/src/PArraysExtras.jl +++ b/src/PArraysExtras.jl @@ -111,6 +111,10 @@ function SparseArrays.findnz(A::PartitionedArrays.SubSparseMatrix) return I[mask], J[mask], V[mask] end +# Async tasks + +const empty_async_task = PartitionedArrays.FakeTask(x -> nothing) + # Linear algebra function LinearAlgebra.axpy!(α,x::PVector,y::PVector)