Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add matrix_assembly! and precompute_nzindex! #172

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 230 additions & 0 deletions src/p_sparse_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1331,6 +1331,236 @@ function psparse_from_split_blocks(oo,oh,rowp,colp;assembled=true)
psparse_from_split_blocks(oo,oh,ho,hh,rowp,colp;assembled)
end

function matrix_assembly!(f, I, J, V, rows, cols)
yssamtu marked this conversation as resolved.
Show resolved Hide resolved
yssamtu marked this conversation as resolved.
Show resolved Hide resolved
# In-place three-way partition of `key`, with every array in `values` permuted
# in lockstep with it. After the call the entries are grouped as
#   front  : entries whose key has a nonzero entry in `global_to_own(part)`,
#   middle : entries with `key[k] <= 0` or `values[1][k] <= 0`,
#   back   : every remaining entry.
# Returns `(left_ptr - 1, mid_ptr, change)`:
#   * `left_ptr - 1` is the last index of the front group (equal to its length
#     when `firstindex(key) == 1`),
#   * `mid_ptr` is the first index of the back group,
#   * `change` lists the index pairs that were swapped, in the order they were
#     applied, so the same reordering can be replayed on another array
#     (see `perm_partition!`).
# NOTE(review): `values[1]` is read unconditionally, so at least one companion
# array must be supplied.
function dutch_national_flag_partition!(part, key, values::Vararg{Any,N}) where {N}
global_to_own_part = global_to_own(part)
left_ptr = firstindex(key)
mid_ptr = firstindex(key)
right_ptr = lastindex(key)
n_change = 0
# Pass 1: advance the pointers while no swap is required. If the whole array
# is scanned this way it is already partitioned and an empty swap list is
# returned without any further work.
while true
if mid_ptr > right_ptr
Tkey = eltype(key)
return left_ptr - 1, mid_ptr, Vector{Tuple{Tkey,Tkey}}(undef, n_change)
elseif key[mid_ptr] <= 0 || values[1][mid_ptr] <= 0
mid_ptr += 1
elseif global_to_own_part[key[mid_ptr]] != 0
if left_ptr != mid_ptr
# A front entry sits behind skipped entries: a swap is needed.
break
end
left_ptr += 1
mid_ptr += 1
else
if mid_ptr != right_ptr
# A back entry sits before the end of the unscanned range: a swap is needed.
break
end
right_ptr -= 1
end
end
# Remember where the first swap happens so the next two passes can both
# restart from this exact state.
left_start = left_ptr
mid_start = mid_ptr
right_start = right_ptr
# Pass 2: dry run that only counts the swaps so `change` can be allocated
# with its exact size. Nothing is moved here, so `mid_actual_ptr` tracks the
# position, in the still unmodified arrays, of the entry that would be sitting
# at `mid_ptr` if the swaps had really been performed.
mid_actual_ptr = mid_ptr
while mid_ptr <= right_ptr
if key[mid_actual_ptr] <= 0 || values[1][mid_actual_ptr] <= 0
mid_ptr += 1
mid_actual_ptr = mid_ptr
elseif global_to_own_part[key[mid_actual_ptr]] != 0
if left_ptr != mid_ptr
n_change += 1
end
left_ptr += 1
mid_ptr += 1
mid_actual_ptr = mid_ptr
else
if mid_ptr != right_ptr
n_change += 1
end
# A swap with `right_ptr` would bring that entry to `mid_ptr`.
mid_actual_ptr = right_ptr
right_ptr -= 1
end
end
Tkey = eltype(key)
change = Vector{Tuple{Tkey,Tkey}}(undef, n_change)
ptr = firstindex(change)
# Pass 3: rewind to the saved state and perform the swaps for real,
# recording each swapped index pair in `change`.
left_ptr = left_start
mid_ptr = mid_start
right_ptr = right_start
while mid_ptr <= right_ptr
if key[mid_ptr] <= 0 || values[1][mid_ptr] <= 0
mid_ptr += 1
elseif global_to_own_part[key[mid_ptr]] != 0
if left_ptr != mid_ptr
key[left_ptr], key[mid_ptr] = key[mid_ptr], key[left_ptr]
for i in 1:N
values[i][left_ptr], values[i][mid_ptr] = values[i][mid_ptr], values[i][left_ptr]
end
change[ptr] = (left_ptr, mid_ptr)
ptr += 1
end
left_ptr += 1
mid_ptr += 1
else
if mid_ptr != right_ptr
key[mid_ptr], key[right_ptr] = key[right_ptr], key[mid_ptr]
for i in 1:N
values[i][mid_ptr], values[i][right_ptr] = values[i][right_ptr], values[i][mid_ptr]
end
change[ptr] = (mid_ptr, right_ptr)
ptr += 1
end
# `mid_ptr` is not advanced: the entry just swapped in still has to be classified.
right_ptr -= 1
end
end
left_ptr - 1, mid_ptr, change
end
# Reorder the local triplets `(I, J, V)` (and `I_owner`) in place so that the
# entries kept on this part come first and the entries to be sent to other
# parts come last, then pack the entries to send into per-destination buffers.
#
# Arguments
#   I, J, V    : row ids, column ids and values of the local triplets (mutated).
#   I_owner    : owner of each row id in `I`; overwritten and shrunk, see below.
#   parts_snd  : destinations this part sends to; defines the buffer ordering.
#   rows_sa    : row index partition forwarded to `dutch_national_flag_partition!`.
#
# Returns `(I_snd, J_snd, V_snd, n_hold_data, snd_start_index, change_index, I_owner)`
# where the first three are `JaggedArray`s sharing the same `ptrs`, and
# `I_owner` has been recycled: it now has one entry per sent triplet holding
# the position of that triplet inside the flat send buffers.
function partition_and_prepare_snd_buf!(I, J, V, I_owner, parts_snd, rows_sa)
    n_hold_data, snd_start_index, change_index =
        dutch_national_flag_partition!(rows_sa, I, J, V, I_owner)
    snd_index = snd_start_index:lastindex(I)
    I_raw_snd_data = view(I, snd_index)
    J_raw_snd_data = view(J, snd_index)
    V_raw_snd_data = view(V, snd_index)
    I_raw_snd_owner = view(I_owner, snd_index)
    n_snd_data = length(I_raw_snd_data)
    I_snd_data = similar(I, n_snd_data)
    # The column buffer is modelled on `J` (it used to be modelled on `I`), so
    # it keeps the element type of the column ids even if it differs from the
    # one of the row ids.
    J_snd_data = similar(J, n_snd_data)
    V_snd_data = similar(V, n_snd_data)
    # Translate each owner id into its position in `parts_snd` and count how
    # many entries go to every destination.
    owner_to_p = Dict(owner => i for (i, owner) in enumerate(parts_snd))
    ptrs = zeros(Int32, length(parts_snd) + 1)
    for (i, owner) in enumerate(I_raw_snd_owner)
        p = owner_to_p[owner]
        I_raw_snd_owner[i] = p
        ptrs[p+1] += 1
    end
    # NOTE(review): presumably converts the per-destination counts into start
    # offsets -- confirm against `length_to_ptrs!`.
    length_to_ptrs!(ptrs)
    for (n, (i, j, v, p)) in enumerate(zip(I_raw_snd_data, J_raw_snd_data, V_raw_snd_data, I_raw_snd_owner))
        index = ptrs[p]
        I_snd_data[index] = i
        J_snd_data[index] = j
        V_snd_data[index] = v
        # Writing at `n` never overtakes the read position of the zipped view
        # (which starts at `snd_start_index >= 1`), so `I_owner` can be reused
        # to store the buffer position of the n-th sent entry.
        I_owner[n] = index
        ptrs[p] += 1
    end
    # NOTE(review): presumably undoes the increments done while filling the
    # buffers -- confirm against `rewind_ptrs!`.
    rewind_ptrs!(ptrs)
    resize!(I_owner, n_snd_data)
    I_snd = JaggedArray(I_snd_data, ptrs)
    J_snd = JaggedArray(J_snd_data, ptrs)
    V_snd = JaggedArray(V_snd_data, ptrs)
    return I_snd, J_snd, V_snd, n_hold_data, snd_start_index, change_index, I_owner
end
# Keep the first `n_hold_data` triplets of `I`, `J` and `V` and replace
# everything after them with the received triplets stored in the `data`
# field of `I_rcv`, `J_rcv` and `V_rcv`. The three arrays are resized in place.
function store_recv_data!(I, J, V, n_hold_data, I_rcv, J_rcv, V_rcv)
    n_received = length(I_rcv.data)
    n_total = n_hold_data + n_received
    foreach(a -> resize!(a, n_total), (I, J, V))
    tail = (n_hold_data+1):n_total
    I[tail] = I_rcv.data
    J[tail] = J_rcv.data
    V[tail] = V_rcv.data
    return nothing
end
# Split the triplets by column into an "own" group and a "ghost" group,
# renumber their ids with `map_global_to_own!` / `map_global_to_ghost!`, and
# build the four blocks of the local matrix with `f` (captured from the
# enclosing `matrix_assembly!`). Views of `perm` are then filled by
# `precompute_nzindex!` for the own/own and own/ghost blocks.
#
# Returns `(split_matrix(...), n_own_data, change_index, perm)`.
function split_and_compress!(I, J, V, perm, rows_fa, cols_fa)
    n_own_data, ghost_start_index, change_index =
        dutch_national_flag_partition!(cols_fa, J, I, V)
    own_range = firstindex(I):n_own_data
    ghost_range = ghost_start_index:lastindex(I)

    # Triplets whose column is in the "own" group.
    I_oo = view(I, own_range)
    J_oo = view(J, own_range)
    V_oo = view(V, own_range)
    # Triplets whose column is in the "ghost" group.
    I_og = view(I, ghost_range)
    J_og = view(J, ghost_range)
    V_og = view(V, ghost_range)

    # Renumber in place; this has to happen before the blocks are built.
    map_global_to_own!(I_oo, rows_fa)
    map_global_to_own!(J_oo, cols_fa)
    map_global_to_own!(I_og, rows_fa)
    map_global_to_ghost!(J_og, cols_fa)

    nrows_own, ncols_own = own_length(rows_fa), own_length(cols_fa)
    nrows_ghost, ncols_ghost = ghost_length(rows_fa), ghost_length(cols_fa)
    Ti, Tv = eltype(I), eltype(V)

    # Repeated entries are combined with `+`; the two ghost-row blocks are
    # built from empty triplets.
    own_own = f(I_oo, J_oo, V_oo, nrows_own, ncols_own, +)
    own_ghost = f(I_og, J_og, V_og, nrows_own, ncols_ghost, +)
    ghost_own = f(Ti[], Ti[], Tv[], nrows_ghost, ncols_own, +)
    ghost_ghost = f(Ti[], Ti[], Tv[], nrows_ghost, ncols_ghost, +)
    blocks = split_matrix_blocks(own_own, own_ghost, ghost_own, ghost_ghost)

    # Store, for each triplet, its position inside the block it was added to.
    precompute_nzindex!(view(perm, own_range), own_own, I_oo, J_oo)
    precompute_nzindex!(view(perm, ghost_range), own_ghost, I_og, J_og)

    rows_perm = local_permutation(rows_fa)
    cols_perm = local_permutation(cols_fa)
    matrix = split_matrix(blocks, rows_perm, cols_perm)
    return matrix, n_own_data, change_index, perm
end
# --- Driver of matrix_assembly!(f, I, J, V, rows, cols) ---
# Phase 1: look up the owner of every row id and pack the triplets that have
# to be sent into per-destination buffers.
I_owner = find_owner(rows, I)
rows_sa = map(union_ghost, rows, I, I_owner)
parts_snd, parts_rcv = assembly_neighbors(rows_sa)
# `perm_snd` is the (overwritten and shrunk) `I_owner` returned by
# `partition_and_prepare_snd_buf!`.
I_snd_buf, J_snd_buf, V_snd_buf, hold_data_size, snd_start_idx, change_snd, perm_snd = map(partition_and_prepare_snd_buf!, I, J, V, I_owner, parts_snd, rows_sa) |> tuple_of_arrays
# Phase 2: start the exchange of the three buffers over the same graph.
graph = ExchangeGraph(parts_snd, parts_rcv)
t_I = exchange(I_snd_buf, graph)
t_J = exchange(J_snd_buf, graph)
t_V = exchange(V_snd_buf, graph)
# Phase 3: once the exchanged data is fetched, append it after the triplets
# kept locally and build the local matrices.
# NOTE(review): `@fake_async` is defined elsewhere; presumably it wraps the
# block in a task-like object the caller has to fetch -- confirm.
@fake_async begin
I_rcv_buf = fetch(t_I)
J_rcv_buf = fetch(t_J)
V_rcv_buf = fetch(t_V)
map(store_recv_data!, I, J, V, hold_data_size, I_rcv_buf, J_rcv_buf, V_rcv_buf)
rows_fa = rows
J_owner = find_owner(cols, J)
cols_fa = map(union_ghost, cols, J, J_owner)
# `J_owner` is handed over as the `perm` argument of `split_and_compress!`,
# which overwrites it through `precompute_nzindex!`.
vals_fa, own_data_size, change_sparse, perm_sparse = map(split_and_compress!, I, J, V, J_owner, rows_fa, cols_fa) |> tuple_of_arrays
# Everything needed to redo the assembly with new values only. The method
# matrix_assembly!(A, V, cache) destructures this tuple positionally, so the
# field order must stay in sync with it.
cache = (; graph, V_snd_buf, V_rcv_buf, hold_data_size, snd_start_idx, change_snd, perm_snd, own_data_size, change_sparse, perm_sparse)
assembled = true
PSparseMatrix(vals_fa, rows_fa, cols_fa, assembled), cache
end
end

function matrix_assembly!(A, V, cache)
yssamtu marked this conversation as resolved.
Show resolved Hide resolved
yssamtu marked this conversation as resolved.
Show resolved Hide resolved
# Replay a recorded sequence of pairwise swaps on `V`, in order.
# `perm` holds the index pairs to exchange.
function perm_partition!(V, perm::Vector{Tuple{T,T}}) where {T}
    for swap in perm
        a, b = swap
        at_b = V[b]
        at_a = V[a]
        V[a] = at_b
        V[b] = at_a
    end
    return nothing
end
# Reorder `V` with the recorded swaps in `change_index`, then scatter the
# values from `snd_start_index` onwards into the flat send buffer
# `V_snd.data`, at the positions listed in `perm`.
function partition_and_prepare_snd_buf!(V_snd, V, snd_start_index, change_index, perm)
    perm_partition!(V, change_index)
    outgoing = view(V, snd_start_index:lastindex(V))
    buffer = V_snd.data
    for (slot, value) in zip(perm, outgoing)
        buffer[slot] = value
    end
    return nothing
end
# Keep the first `n_hold_data` values of `V` and replace everything after
# them with the received values in `V_rcv.data`. `V` is resized in place.
function store_recv_data!(V, n_hold_data, V_rcv)
    received = V_rcv.data
    n_total = n_hold_data + length(received)
    resize!(V, n_total)
    tail = (n_hold_data+1):n_total
    V[tail] = V_rcv.data
    return nothing
end
# Reorder `V` with the recorded swaps in `change_index`, then refill the
# own/own and own/ghost blocks of `A` with `sparse_matrix!`: the first
# `n_own_data` values go to `A.blocks.own_own`, the rest to
# `A.blocks.own_ghost`, each with the matching slice of `perm`.
function split_and_compress!(A, V, n_own_data, change_index, perm)
    perm_partition!(V, change_index)
    own_range = firstindex(V):n_own_data
    ghost_range = (n_own_data+1):lastindex(V)
    own_vals, ghost_vals = view(V, own_range), view(V, ghost_range)
    own_slots, ghost_slots = view(perm, own_range), view(perm, ghost_range)
    blocks = A.blocks
    sparse_matrix!(blocks.own_own, own_vals, own_slots)
    sparse_matrix!(blocks.own_ghost, ghost_vals, ghost_slots)
    return nothing
end
# --- Driver of matrix_assembly!(A, V, cache) ---
# The cache is destructured positionally, so the order of the names here must
# match the order in which the fields were stored by
# matrix_assembly!(f, I, J, V, rows, cols).
graph, V_snd_buf, V_rcv_buf, hold_data_size, snd_start_idx, change_snd, perm_snd, own_data_size, change_sparse, perm_sparse = cache
# Replay the recorded reordering on the new values and fill the send buffers.
map(partition_and_prepare_snd_buf!, V_snd_buf, V, snd_start_idx, change_snd, perm_snd)
# Exchange into the cached receive buffer.
t_V = exchange!(V_rcv_buf, V_snd_buf, graph)
# NOTE(review): `@fake_async` is defined elsewhere; presumably the caller has
# to fetch its result to obtain `A` -- confirm.
@fake_async begin
fetch(t_V)
# Append the received values after the ones kept locally (this resizes `V`),
# then refill the stored blocks of `A` in place.
map(store_recv_data!, V, hold_data_size, V_rcv_buf)
map(split_and_compress!, partition(A), V, own_data_size, change_sparse, perm_sparse)
A
end
end

function assemble(A::PSparseMatrix;kwargs...)
rows = map(remove_ghost,partition(axes(A,1)))
assemble(A,rows;kwargs...)
Expand Down
9 changes: 9 additions & 0 deletions src/sparse_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,15 @@ function precompute_nzindex(A,I,J)
K
end

"""
    precompute_nzindex!(K, A, I, J)

In-place variant of `precompute_nzindex`: for every pair `(I[p], J[p])` store
`nzindex(A, I[p], J[p])` in `K[p]`.

Pairs with `I[p] < 1` or `J[p] < 1` have no entry in `A`; their slot is set to
zero, so that a recycled buffer `K` never keeps a stale value that could be
mistaken for a valid index.

NOTE(review): zero is assumed to be the "no entry" marker expected by the
consumers of `K` (e.g. `sparse_matrix!`) -- confirm.

Returns `nothing`.
"""
function precompute_nzindex!(K::AbstractVector{Int32}, A, I, J)
    for (p, (i, j)) in enumerate(zip(I, J))
        if i < 1 || j < 1
            # Overwrite instead of skipping: `K` may be a reused buffer.
            K[p] = 0
        else
            K[p] = nzindex(A, i, j)
        end
    end
    return nothing
end

function sparse_matrix!(A,V,K;reset=true)
if reset
LinearAlgebra.fillstored!(A,0)
Expand Down
Loading