Skip to content

Commit

Permalink
Drafting the construction of a PSparseMatrix in split format
Browse files Browse the repository at this point in the history
  • Loading branch information
fverdugo committed Dec 18, 2023
1 parent a16f0dc commit 0ca9c3d
Show file tree
Hide file tree
Showing 3 changed files with 308 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/PartitionedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ include("p_vector.jl")
export PSparseMatrix
export psparse
export psparse!
export psparse_split_format!
export own_ghost_values
export ghost_own_values
include("p_sparse_matrix.jl")
Expand Down
288 changes: 288 additions & 0 deletions src/p_sparse_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -686,3 +686,291 @@ function IterativeSolvers.zerox(A::PSparseMatrix,b::PVector)
fill!(x, zero(T))
return x
end

struct MatrixSplit{A,B,C,D}
own_own::A
own_ghost::B
ghost_own::C
ghost_ghost::D
end

struct PSparseMatrixCacheSplit{A,B}
V_snd::A
V_rcv::A
k_snd::B
k_rcv::B
end

struct PSparseMatrixNew{A,B,C,D,E}
matrix_partition::A
row_partition::B
col_partition_for_own_rows::C
col_partition_for_ghost_rows::D
cache::E
end

function scoo_from_coo(coo,rows,cols)
global_to_own_row = global_to_own(rows)
global_to_own_col = global_to_own(cols)
n_own_own = 0
n_own_ghost = 0
n_ghost_own = 0
n_ghost_ghost = 0
for p in 1:nnz(coo)
gi = coo.I[p]
gj = coo.J[p]
i = global_to_own_row[gi]
j = global_to_own_row[gj]
if i != 0 && j != 0
n_own_own += 1
elseif i != 0 && j==0
n_own_ghost += 1
elseif j == 0 && j!= 0
n_ghost_own += 1
else
n_ghost_ghost += 1
end
end
Tv = eltype(coo)
m,n = size(coo)
own_own = spzeros_coo(Tv,m,n,n_own_own)
own_ghost = spzeros_coo(Tv,m,n,n_own_ghost)
ghost_own = spzeros_coo(Tv,m,n,n_ghost_own)
ghost_ghost = spzeros_coo(Tv,m,n,n_ghost_ghost)
scoo = MatrixSplit(own_own,own_ghost,ghost_own,ghost_ghost)
n_own_own = 0
n_own_ghost = 0
n_ghost_own = 0
n_ghost_ghost = 0
for p in 1:nnz(coo)
gi = coo.I[p]
gj = coo.J[p]
gv = coo.V[p]
i = global_to_own_row[gi]
j = global_to_own_row[gj]
if i != 0 && j != 0
n_own_own += 1
scoo.own_own.I[n_own_own] = gi
scoo.own_own.J[n_own_own] = gj
scoo.own_own.V[n_own_own] = gv
elseif i != 0 && j==0
n_own_ghost += 1
scoo.own_ghost.I[n_own_ghost] = gi
scoo.own_ghost.J[n_own_ghost] = gj
scoo.own_ghost.V[n_own_ghost] = gv
elseif j == 0 && j!= 0
n_ghost_own += 1
scoo.ghost_own.I[n_ghost_own] = gi
scoo.ghost_own.J[n_ghost_own] = gj
scoo.ghost_own.V[n_ghost_own] = gv
else
n_ghost_ghost += 1
scoo.ghost_ghost.I[n_ghost_ghost] = gi
scoo.ghost_ghost.J[n_ghost_ghost] = gj
scoo.ghost_ghost.V[n_ghost_ghost] = gv
end
end
scoo
end

function psparse_split_format!(I,J,V,row_partition,col_partition;kwargs...)
scoo = map(I,J,V,row_partition,col_partition) do I,J,V,rows,cols
m = global_length(rows)
n = global_length(cols)
coo = SparseMatrixCOO(I,J,V,m,n)
scoo_from_coo(coo,rows,cols)
end
psparse_split_format_from_scoo!(scoo,row_partition,col_partition;kwargs...)
end

function psparse_split_format_from_scoo!(scoo,rows,cols;exchange_graph_options=(;))
function find_ghost_rows(scoo,rows,cols)
I_ghost = vcat(scoo.ghost_own.I,scoo.ghost_ghost.I)
(;scoo,rows,cols,I_ghost)
end
function expand_rows(state)
rows = map(i->i.rows,state)
cols = map(i->i.cols,state)
I_ghost = map(i->i.I_ghost,state)
J_ghost = map(i->i.scoo.ghost_ghost.J,state)
I_ghost_owner = find_owner(rows,I_ghost)
J_ghost_owner = find_owner(cols,J_ghost)
rows_with_ghost = map(union_ghost,rows,I_ghost,I_ghost_owner)
cols_for_ghost_rows = map(union_ghost,cols,J_ghost,J_ghost_owner)
parts_snd, parts_rcv = assembly_neighbors(
rows_with_ghost;exchange_graph_options...)
function setup_state(
rows_with_ghost,cols_for_ghost_rows,parts_snd,parts_rcv,state)
(;rows_with_ghost,cols_for_ghost_rows,parts_snd,parts_rcv,state...)
end
map(setup_state,rows_with_ghost,cols_for_ghost_rows,parts_snd,parts_rcv,state)
end
function setup_scsc_ghost(state)
map_global_to_ghost!(state.scoo.ghost_own.I,state.rows_with_ghost)
map_global_to_ghost!(state.scoo.ghost_ghost.I,state.rows_with_ghost)
map_global_to_own!(state.scoo.ghost_own.J,state.cols_for_ghost_rows)
map_global_to_ghost!(state.scoo.ghost_ghost.J,state.cols_for_ghost_rows)
a = ghost_length(state.rows_with_ghost)
b = own_length(state.cols_for_ghost_rows)
c = ghost_length(state.cols_for_ghost_rows)
A_ghost_own = sparse(findnz(state.scoo.ghost_own)...,a,b)
A_ghost_ghost = sparse(findnz(state.scoo.ghost_ghost)...,a,c)
(;A_ghost_own,A_ghost_ghost,state...)
end
function setup_cache_snd(state)
owner_to_p = Dict(( owner=>i for (i,owner) in enumerate(state.parts_snd) ))
ptrs = zeros(Int32,length(state.parts_snd)+1)
ghost_to_owner_row = ghost_to_owner(state.rows_with_ghost)
ghost_to_global_row = ghost_to_global(state.rows_with_ghost)
own_to_global_col = own_to_global(state.cols_for_ghost_rows)
ghost_to_global_col = ghost_to_global(state.cols_for_ghost_rows)
for (i,_,_) in nziterator(state.A_ghost_own)
owner = ghost_to_owner_row[i]
ptrs[owner_to_p[owner]+1] += 1
end
for (i,_,_) in nziterator(state.A_ghost_ghost)
owner = ghost_to_owner_row[i]
ptrs[owner_to_p[owner]+1] += 1
end
length_to_ptrs!(ptrs)
Tv = eltype(state.A_ghost_own)
ndata = ptrs[end]-1
I_snd_data = zeros(Int,ndata)
J_snd_data = zeros(Int,ndata)
V_snd_data = zeros(Tv,ndata)
k_snd_data = zeros(Int32,ndata)
nnz_ghost_own = 0
for (k,(i,j,v)) in enumerate(nziterator(state.A_ghost_own))
owner = ghost_to_owner_row[i]
p = ptrs[owner_to_p[owner]]
I_snd_data[p] = ghost_to_global_row[i]
J_snd_data[p] = own_to_global_col[j]
V_snd_data[p] = v
k_snd_data[p] = k
ptrs[owner_to_p[owner]] += 1
nnz_ghost_own += 1
end
for (k,(i,j,v)) in enumerate(nziterator(state.A_ghost_ghost))
owner = ghost_to_owner_row[i]
p = ptrs[owner_to_p[owner]]
I_snd_data[p] = ghost_to_global_row[i]
J_snd_data[p] = ghost_to_global_col[j]
V_snd_data[p] = v
k_snd_data[p] = k+nnz_ghost_own
ptrs[owner_to_p[owner]] += 1
end
rewind_ptrs!(ptrs)
I_snd = JaggedArray(I_snd_data,ptrs)
J_snd = JaggedArray(J_snd_data,ptrs)
V_snd = JaggedArray(V_snd_data,ptrs)
k_snd = JaggedArray(k_snd_data,ptrs)
(;I_snd,J_snd,V_snd,k_snd,state...)
end
function exchange_coo(state)
I_snd = map(i->i.I_snd,state)
J_snd = map(i->i.J_snd,state)
V_snd = map(i->i.V_snd,state)
parts_snd = map(i->i.parts_snd,state)
parts_rcv = map(i->i.parts_rcv,state)
graph = ExchangeGraph(parts_snd,parts_rcv)
t_I = exchange(I_snd,graph)
t_J = exchange(J_snd,graph)
t_V = exchange(V_snd,graph)
@async begin
I_rcv = fetch(t_I)
J_rcv = fetch(t_J)
V_rcv = fetch(t_V)
function setup_state(I_rcv,J_rcv,V_rcv,state)
(;I_rcv,J_rcv,V_rcv,state...)
end
map(setup_state,I_rcv,J_rcv,V_rcv,state)
end
end
function setup_scoo_own(state)
I_rcv_data = state.I_rcv.data
J_rcv_data = state.J_rcv.data
V_rcv_data = state.V_rcv.data
global_to_own_col = global_to_own(state.cols)
is_ghost = map(j->global_to_own_col[j]==0,J_rcv_data)
is_own = .! is_ghost
append!(state.scoo.own_own.I,I_rcv_data[is_own])
append!(state.scoo.own_own.J,J_rcv_data[is_own])
append!(state.scoo.own_own.V,V_rcv_data[is_own])
append!(state.scoo.own_ghost.I,I_rcv_data[is_ghost])
append!(state.scoo.own_ghost.J,J_rcv_data[is_ghost])
append!(state.scoo.own_ghost.V,V_rcv_data[is_ghost])
(;is_own,state...)
end
function expand_cols(state)
cols = map(i->i.cols,state)
J_ghost = map(i->i.scoo.own_ghost.J,state)
J_ghost_owner = find_owner(cols,J_ghost)
cols_for_own_rows = map(union_ghost,cols,J_ghost,J_ghost_owner)
map(state,cols_for_own_rows) do state,cols_for_own_rows
(;cols_for_own_rows,state...)
end
end
function setup_scsc_own(state)
map_global_to_own!(state.scoo.own_own.I,state.rows)
map_global_to_own!(state.scoo.own_ghost.I,state.rows)
map_global_to_own!(state.scoo.own_own.J,state.cols_for_own_rows)
map_global_to_ghost!(state.scoo.own_ghost.J,state.cols_for_own_rows)
a = own_length(state.rows)
b = own_length(state.cols_for_own_rows)
c = ghost_length(state.cols_for_own_rows)
A_own_own = sparse(findnz(state.scoo.own_own)...,a,b)
A_own_ghost = sparse(findnz(state.scoo.own_ghost)...,a,c)
A = MatrixSplit(A_own_own,A_own_ghost,state.A_ghost_own,state.A_ghost_ghost)
I_rcv_data = state.I_rcv.data
J_rcv_data = state.J_rcv.data
V_rcv_data = state.V_rcv.data
ndata = length(I_rcv_data)
k_rcv_data = zeros(Int32,ndata)
n_own_nz = nnz(A_own_own)
global_to_own_row = global_to_own(state.rows)
global_to_own_col = global_to_own(state.cols_for_own_rows)
global_to_ghost_col = global_to_ghost(state.cols_for_own_rows)
for p in 1:ndata
gi = I_rcv_data[p]
gj = J_rcv_data[p]
i = global_to_own_row[gi]
k = if state.is_own[p]
j = global_to_own_col[gj]
nzindex(A_own_own,i,j)
else
j = global_to_ghost_col[gj]
nzindex(A_own_ghost,i,j)+n_own_nz
end
k_rcv_data[p] = k
end
k_rcv = JaggedArray(k_rcv_data,state.I_rcv.ptrs)
(;A,k_rcv,state...)
end
function finalize_psparse_scsc(state)
matrix_partition = map(i->i.A,state)
row_partition = map(i->i.rows_with_ghost,state)
col_partition_for_own_rows = map(i->i.cols_for_own_rows,state)
col_partition_for_ghost_rows = map(i->i.cols_for_ghost_rows,state)
cache = map(i->PSparseMatrixCacheSplit(i.V_snd,i.V_rcv,i.k_snd,i.k_rcv),state)
PSparseMatrixNew(
matrix_partition,
row_partition,
col_partition_for_own_rows,
col_partition_for_ghost_rows,
cache)
end
state1 = map(find_ghost_rows,scoo,rows,cols)
state2 = expand_rows(state1)
state3 = map(setup_scsc_ghost,state2)
state4 = map(setup_cache_snd,state3)
t = exchange_coo(state4)
@async begin
state5 = fetch(t)
state6 = map(setup_scoo_own,state5)
state7 = expand_cols(state6)
state8 = map(setup_scsc_own,state7)
finalize_psparse_scsc(state8)
end
end


19 changes: 19 additions & 0 deletions test/p_sparse_matrix_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,5 +122,24 @@ function p_sparse_matrix_tests(distribute)
@test norm(r) < 1.0e-9
display(A)

n = 10
parts = rank
row_partition = uniform_partition(parts,n)
col_partition = row_partition

I,J,V = map(parts) do part
if part == 1
[1,2,1,2,2], [2,6,1,2,1], [1.0,2.0,30.0,10.0,1.0]
elseif part == 2
[3,3,4,6], [3,9,4,2], [10.0,2.0,30.0,2.0]
elseif part == 3
[5,5,6,7], [5,6,6,7], [10.0,2.0,30.0,1.0]
else
[9,9,8,10,6], [9,3,8,10,5], [10.0,2.0,30.0,50.0,2.0]
end
end |> tuple_of_arrays

A = psparse_split_format!(I,J,V,row_partition,col_partition) |> fetch

end

0 comments on commit 0ca9c3d

Please sign in to comment.