Skip to content

Commit

Permalink
Improve load balancing
Browse files Browse the repository at this point in the history
  • Loading branch information
jipolanco committed Dec 19, 2023
1 parent fae7384 commit 982d4b0
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 20 deletions.
46 changes: 44 additions & 2 deletions src/blocking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ struct BlockData{
} <: AbstractBlockData
block_dims :: Dims{N} # size of each block (in number of elements)
block_sizes :: NTuple{N, Tr} # size of each block (in units of length)
buffers :: Buffers # length = nthreads()
buffers :: Buffers # length = nthreads
blocks_per_thread :: Vector{Int} # cumulative block offsets: thread i handles blocks (blocks_per_thread[i] + 1):blocks_per_thread[i + 1] (length = nthreads + 1)
indices :: Indices # index associated to each block (length = num_blocks)
buffers_for_indices :: Vector{NTuple{N, Vector{Int}}} # maps values of current buffer to indices in global array (length = nthreads())
cumulative_npoints_per_block :: Vector{Int} # cumulative sum of number of points in each block (length = 1 + num_blocks, initial value is 0)
Expand Down Expand Up @@ -46,8 +47,9 @@ function BlockData(::Type{T}, block_dims::Dims{D}, Ñs::Dims{D}, ::HalfSupport{M
cumulative_npoints_per_block = Vector{Int}(undef, nblocks + 1)
blockidx = Int[]
pointperm = Int[]
blocks_per_thread = zeros(Int, Nt + 1)
BlockData(
block_dims, block_sizes, buffers, indices, buffers_for_indices,
block_dims, block_sizes, buffers, blocks_per_thread, indices, buffers_for_indices,
cumulative_npoints_per_block, blockidx, pointperm,
)
end
Expand Down Expand Up @@ -84,6 +86,11 @@ function sort_points!(bd::BlockData, xp::AbstractVector)
@assert cumulative_npoints_per_block[begin] == 0
@assert cumulative_npoints_per_block[end] == Np

# Determine how many blocks each thread will manage. The idea is that, if the point
# distribution is inhomogeneous, then more threads are dedicated to areas where points
# are concentrated, improving load balance.
map_blocks_to_threads!(bd.blocks_per_thread, cumulative_npoints_per_block)

if Threads.nthreads() == 1
# This is the same as sortperm! but seems to be faster.
sort!(pointperm; by = i -> @inbounds(blockidx[i]), alg = QuickSort)
Expand All @@ -103,3 +110,38 @@ function sort_points!(bd::BlockData, xp::AbstractVector)

nothing
end

"""
    map_blocks_to_threads!(blocks_per_thread, cumulative_npoints_per_block)

Distribute consecutive blocks among threads so that each thread receives roughly the
same number of points, and store the resulting block offsets in `blocks_per_thread`.

# Arguments
- `blocks_per_thread`: output vector of length `Nt + 1`, where `Nt` is the number of
  threads. It is overwritten so that thread `i` is assigned the blocks
  `(blocks_per_thread[i] + 1):blocks_per_thread[i + 1]`. The first entry is always 0 and
  the last entry is always the total number of blocks.
- `cumulative_npoints_per_block`: one-based vector of length `nblocks + 1` whose first
  entry is 0 and whose entry `n + 1` is the number of points in blocks `1:n`.

# Returns
The mutated `blocks_per_thread`.

Blocks are assigned greedily: each thread takes consecutive blocks until it holds at
least `Np / Nt` points (`Np` being the total number of points). Once the blocks are
exhausted, the current thread ends at the last block and all remaining threads get an
empty range.
"""
function map_blocks_to_threads!(blocks_per_thread, cumulative_npoints_per_block)
    # Validate the indexing convention *before* indexing the array or mutating the output.
    Base.require_one_based_indexing(cumulative_npoints_per_block)
    @assert cumulative_npoints_per_block[begin] == 0

    Np = last(cumulative_npoints_per_block)             # total number of points
    Nt = length(blocks_per_thread) - 1                  # number of threads
    nblocks = length(cumulative_npoints_per_block) - 1  # total number of blocks
    Np_per_thread = Np / Nt                             # target number of points per thread

    blocks_per_thread[begin] = 0
    n = 0  # index of the last block assigned so far (0 = none yet)

    for i in 1:Nt
        npoints_in_current_thread = 0
        stop = false
        # Keep adding blocks to thread `i` until it reaches the target number of points.
        while npoints_in_current_thread < Np_per_thread
            n += 1
            if n > nblocks
                stop = true  # ran out of blocks before reaching the target
                break
            end
            npoints_in_block =
                cumulative_npoints_per_block[n + 1] - cumulative_npoints_per_block[n]
            npoints_in_current_thread += npoints_in_block
        end
        if stop
            # Thread `i` ends at the last block (inclusive); every following thread
            # starts and ends at the last block, i.e. it does no work.
            for j in i:Nt
                blocks_per_thread[begin + j] = nblocks
            end
            break
        end
        blocks_per_thread[begin + i] = n  # this thread ends at block `n` (inclusive)
    end

    # Make sure the last block is included, even if the final thread reached its target
    # before consuming all (possibly empty) trailing blocks.
    blocks_per_thread[end] = nblocks
    return blocks_per_thread
end
20 changes: 11 additions & 9 deletions src/interpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,22 +109,24 @@ function interpolate_from_arrays_blocked(
vs
end

function interpolate_blocked!(gs, blocks::BlockData, vp::AbstractArray, us, xp::AbstractArray)
function interpolate_blocked!(gs, bd::BlockData, vp::AbstractArray, us, xp::AbstractArray)
@assert axes(vp) === axes(xp)
(; block_dims, cumulative_npoints_per_block, pointperm, buffers, indices,) = blocks
(; block_dims, pointperm, buffers, indices,) = bd
Ms = map(Kernels.half_support, gs)
Nt = length(buffers) # usually equal to the number of threads
nblocks = length(indices)
# nblocks = length(indices)
Base.require_one_based_indexing(buffers)
Base.require_one_based_indexing(indices)
Threads.@threads :static for i 1:Nt
j_start = (i - 1) * nblocks ÷ Nt + 1
j_end = i * nblocks ÷ Nt
# j_start = (i - 1) * nblocks ÷ Nt + 1
# j_end = i * nblocks ÷ Nt
j_start = bd.blocks_per_thread[i] + 1
j_end = bd.blocks_per_thread[i + 1]
block = buffers[i]
inds_wrapped = blocks.buffers_for_indices[i]
inds_wrapped = bd.buffers_for_indices[i]
@inbounds for j j_start:j_end
a = cumulative_npoints_per_block[j]
b = cumulative_npoints_per_block[j + 1]
a = bd.cumulative_npoints_per_block[j]
b = bd.cumulative_npoints_per_block[j + 1]
a == b && continue # no points in this block (otherwise b > a)

# Indices of current block including padding
Expand All @@ -137,7 +139,7 @@ function interpolate_blocked!(gs, blocks::BlockData, vp::AbstractArray, us, xp::
# Iterate over all points in the current block
for k (a + 1):b
l = pointperm[k]
# @assert blocks.blockidx[l] == j # check that point is really in the current block
# @assert bd.blockidx[l] == j # check that point is really in the current block
x⃗ = xp[l] # if points have not been permuted
# x⃗ = xp[k] # if points have been permuted (may be slightly faster here, but requires permutation in sort_points!)
vp[l] = interpolate_blocked(gs, block, x⃗, Tuple(I₀))
Expand Down
20 changes: 11 additions & 9 deletions src/spreading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,32 +70,34 @@ function spread_from_point_blocked!(gs::NTuple, u::AbstractArray, x⃗₀, v::Nu
end

function spread_from_points_blocked!(
gs, blocks::BlockData, us::AbstractArray, xp::AbstractVector, vp::AbstractVector,
gs, bd::BlockData, us::AbstractArray, xp::AbstractVector, vp::AbstractVector,
)
(; block_dims, cumulative_npoints_per_block, pointperm, buffers, indices,) = blocks
(; block_dims, pointperm, buffers, indices,) = bd
Ms = map(Kernels.half_support, gs)
fill!(us, zero(eltype(us)))
Nt = length(buffers) # usually equal to the number of threads
nblocks = length(indices)
# nblocks = length(indices)
Base.require_one_based_indexing(buffers)
Base.require_one_based_indexing(indices)
lck = ReentrantLock()
Threads.@threads :static for i 1:Nt
j_start = (i - 1) * nblocks ÷ Nt + 1
j_end = i * nblocks ÷ Nt
# j_start = (i - 1) * nblocks ÷ Nt + 1
# j_end = i * nblocks ÷ Nt
j_start = bd.blocks_per_thread[i] + 1
j_end = bd.blocks_per_thread[i + 1]
block = buffers[i]
inds_wrapped = blocks.buffers_for_indices[i]
inds_wrapped = bd.buffers_for_indices[i]
@inbounds for j j_start:j_end
a = cumulative_npoints_per_block[j]
b = cumulative_npoints_per_block[j + 1]
a = bd.cumulative_npoints_per_block[j]
b = bd.cumulative_npoints_per_block[j + 1]
a == b && continue # no points in this block (otherwise b > a)

# Iterate over all points in the current block
I₀ = indices[j]
fill!(block, zero(eltype(block)))
for k (a + 1):b
l = pointperm[k]
# @assert blocks.blockidx[l] == j # check that point is really in the current block
# @assert bd.blockidx[l] == j # check that point is really in the current block
x⃗ = xp[l] # if points have not been permuted
# x⃗ = xp[k] # if points have been permuted (may be slightly faster here, but requires permutation in sort_points!)
v = vp[l]
Expand Down

0 comments on commit 982d4b0

Please sign in to comment.