Skip to content

Commit

Permalink
Renamed split_values -> split_format
Browse files Browse the repository at this point in the history
  • Loading branch information
fverdugo committed Jan 13, 2024
1 parent 7ca291d commit 997ac6e
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 32 deletions.
4 changes: 2 additions & 2 deletions src/PartitionedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ export PSparseMatrix
export old_psparse
export psparse
export psparse!
export split_values
export split_values!
export split_format
export split_format!
export old_psparse!
export own_ghost_values
export ghost_own_values
Expand Down
28 changes: 14 additions & 14 deletions src/p_sparse_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ function LinearAlgebra.fillstored!(a::AbstractSplitMatrix,v)
a
end

function split_locally(A,rows,cols)
function split_format_locally(A,rows,cols)
n_own_rows = own_length(rows)
n_own_cols = own_length(cols)
n_ghost_rows = ghost_length(rows)
Expand Down Expand Up @@ -775,7 +775,7 @@ function split_locally(A,rows,cols)
B, cache
end

function split_locally!(B::AbstractSplitMatrix,A,rows,cols,cache)
function split_format_locally!(B::AbstractSplitMatrix,A,rows,cols,cache)
(c1,c2,c3,c4,own_own_V,own_ghost_V,ghost_own_V,ghost_ghost_V) = cache
n_own_rows = own_length(rows)
n_own_cols = own_length(cols)
Expand Down Expand Up @@ -971,10 +971,10 @@ end
val_parameter(a) = a
val_parameter(::Val{a}) where a = a

function split_values(A::PSparseMatrix;reuse=Val(false))
function split_format(A::PSparseMatrix;reuse=Val(false))
rows = partition(axes(A,1))
cols = partition(axes(A,2))
values, cache = map(split_locally,partition(A),rows,cols) |> tuple_of_arrays
values, cache = map(split_format_locally,partition(A),rows,cols) |> tuple_of_arrays
B = PSparseMatrix(values,rows,cols,A.assembled)
if val_parameter(reuse) == false
B
Expand All @@ -983,10 +983,10 @@ function split_values(A::PSparseMatrix;reuse=Val(false))
end
end

function split_values!(B,A::PSparseMatrix,cache)
function split_format!(B,A::PSparseMatrix,cache)
rows = partition(axes(A,1))
cols = partition(axes(A,2))
map(split_locally!,partition(B),partition(A),rows,cols,cache)
map(split_format_locally!,partition(B),partition(A),rows,cols,cache)
B
end

Expand Down Expand Up @@ -1019,7 +1019,7 @@ instance of [`PSparseMatrix`](@ref) allowing latency hiding while performing
the communications needed in its setup.
"""
function psparse(f,I,J,V,rows,cols;
split=true,
split_format=true,
assembled=false,
assemble=true,
discover_rows=true,
Expand All @@ -1033,7 +1033,7 @@ function psparse(f,I,J,V,rows,cols;
# TODO for some particular cases
# this function allocates more
# intermediate results than needed
# One can e.g. merge the split and assemble steps
# One can e.g. merge the split_format and assemble steps
# Even the matrix compression step could be
# merged with the assembly step

Expand Down Expand Up @@ -1072,8 +1072,8 @@ function psparse(f,I,J,V,rows,cols;
map(map_local_to_global!,J,cols_sa)
end
A = PSparseMatrix(values_sa,rows_sa,cols_sa,assembled)
if split
B,cacheB = split_values(A;reuse=true)
if split_format
B,cacheB = PartitionedArrays.split_format(A;reuse=true)
else
B,cacheB = A,nothing
end
Expand All @@ -1090,7 +1090,7 @@ function psparse(f,I,J,V,rows,cols;
else
return @async begin
C, cacheC = fetch(t)
cache = (A,B,K,cacheB,cacheC,split,assembled)
cache = (A,B,K,cacheB,cacheC,split_format,assembled)
(C, cache)
end
end
Expand All @@ -1100,13 +1100,13 @@ end
psparse!(C::PSparseMatrix,V,cache)
"""
function psparse!(C,V,cache)
(A,B,K,cacheB,cacheC,split,assembled) = cache
(A,B,K,cacheB,cacheC,split_format,assembled) = cache
rows_sa = partition(axes(A,1))
cols_sa = partition(axes(A,2))
values_sa = partition(A)
map(setcoofast!,values_sa,V,K)
if split
split_values!(B,A,cacheB)
if split_format
split_format!(B,A,cacheB)
end
if !assembled && C.assembled
t = PartitionedArrays.assemble!(C,B,cacheC)
Expand Down
24 changes: 8 additions & 16 deletions test/p_sparse_matrix_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,21 +141,21 @@ function p_sparse_matrix_tests(distribute)
end
end |> tuple_of_arrays

A = psparse(I,J,V,row_partition,col_partition,split=false,assemble=false) |> fetch
B = split_values(A)
B, cache = split_values(A,reuse=true)
split_values!(B,A,cache)
A = psparse(I,J,V,row_partition,col_partition,split_format=false,assemble=false) |> fetch
B = split_format(A)
B, cache = split_format(A,reuse=true)
split_format!(B,A,cache)
C = assemble(B) |> fetch
C,cache = assemble(B,reuse=true) |> fetch
assemble!(C,B,cache) |> wait
display(C)

A = psparse(I,J,V,row_partition,col_partition,split=true,assemble=false) |> fetch
A = psparse(I,J,V,row_partition,col_partition,split=true,assemble=true) |> fetch
A = psparse(I,J,V,row_partition,col_partition,split_format=true,assemble=false) |> fetch
A = psparse(I,J,V,row_partition,col_partition,split_format=true,assemble=true) |> fetch
A = psparse(I,J,V,row_partition,col_partition) |> fetch
display(A)
# TODO Assembly in non-split format not yet implemented
#A = psparse(I,J,V,row_partition,col_partition,split=false,assemble=true) |> fetch
# TODO Assembly in non-split_format format not yet implemented
#A = psparse(I,J,V,row_partition,col_partition,split_format=false,assemble=true) |> fetch

A,cache = psparse(I,J,V,row_partition,col_partition,reuse=true) |> fetch
psparse!(A,V,cache) |> wait
Expand Down Expand Up @@ -304,13 +304,5 @@ function p_sparse_matrix_tests(distribute)
#cr = Ar\br
#renumber!(c,cr)

# TODO
# 3. Cleanup and documentation
# better names for precompute_nzindex, setcoofast
# better name for split_values, split_values!
# how to avoind conflicts with existing functions
# in particular: delete val_parameter
# renumber

end

0 comments on commit 997ac6e

Please sign in to comment.