Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add parallelism using threads #1

Merged
merged 9 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"

[compat]
Bessels = "0.2"
FFTW = "1.7"
LinearAlgebra = "1.9"
StaticArrays = "1.7"
StructArrays = "0.6"
ThreadsX = "0.1"
julia = "1.9"
3 changes: 0 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,6 @@ More details on optional parameters and on tuning accuracy is coming soon.
This package roughly follows the same notation and conventions of the [FINUFFT library](https://finufft.readthedocs.io/en/latest/)
and its [Julia interface](https://github.com/ludvigak/FINUFFT.jl), with a few differences detailed below.

For now, parallelism is not supported by this package, but this will come in the near future.
On a single thread, performance is comparable to (and often better than) that of other libraries, including those mentioned below.

### Conventions used by this package

We try to preserve as much as possible the conventions used in FFTW3.
Expand Down
6 changes: 6 additions & 0 deletions src/Kernels/Kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ function kernel_indices(i, ::AbstractKernelData{K, M}, N::Integer) where {K, M}
end
end

# This variant can be used when periodic wrapping is not needed.
# (Used when doing block partitioning for parallelisation using threads.)
# Non-wrapping variant: return the contiguous range of 2M grid indices
# surrounding index `i` (used when block padding makes wrapping unnecessary).
function kernel_indices(i, ::AbstractKernelData{K, M}) where {K, M}
    range(i - M + 1; length = 2M)
end

# Returns evaluation points around the normalised location X ∈ [0, 1/M).
# Note that points are returned in decreasing order.
function evaluation_points(::Val{M}, X) where {M}
Expand Down
45 changes: 15 additions & 30 deletions src/NonuniformFFTs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,34 +29,11 @@ export
exec_type1!,
exec_type2!

include("blocking.jl")
include("plan.jl")
include("set_points.jl")
include("spreading.jl")
include("interpolation.jl")
include("plan.jl")

# Here the element type of `xp` can either be an NTuple{N, <:Real}, an SVector{N, <:Real},
# or anything else which has length `N`.
# Copy non-uniform point locations into the plan, converting each element to an
# NTuple. Elements of `xp` may be NTuple{N, <:Real}, SVector{N, <:Real}, or
# anything else of length `N`.
function set_points!(p::PlanNUFFT{T, N}, xp::AbstractVector) where {T, N}
    (; points,) = p
    if type_length(eltype(xp)) != N
        throw(DimensionMismatch(lazy"expected $N-dimensional points"))
    end
    resize!(points, length(xp))
    Base.require_one_based_indexing(points)
    @inbounds for (n, x⃗) ∈ enumerate(xp)
        points[n] = NTuple{N}(x⃗)  # converts x⃗ to a Tuple if it's an SVector
    end
    p
end

# Number of scalar components of a point type (e.g. 3 for a 3D point).
type_length(::Type{<:NTuple{N}}) where {N} = N
type_length(::Type{T}) where {T} = length(T)  # generic fallback, e.g. for SVector

# Accept a tuple of coordinate vectors (one per dimension) by viewing them as a
# single vector of points via StructVector.
set_points!(p::PlanNUFFT{T, N}, xp::NTuple{N, AbstractVector}) where {T, N} =
    set_points!(p, StructVector(xp))

# 1D case: wrap the single coordinate vector in a tuple and dispatch to the
# tuple-of-vectors variant.
set_points!(p::PlanNUFFT{T, 1}, xp::AbstractVector{<:Real}) where {T} =
    set_points!(p, StructVector((xp,)))

function check_nufft_uniform_data(p::PlanNUFFT, ûs_k::AbstractArray{<:Complex})
(; ks,) = p.data
Expand All @@ -68,11 +45,15 @@ function check_nufft_uniform_data(p::PlanNUFFT, ûs_k::AbstractArray{<:Complex}
end

function exec_type1!(ûs_k::AbstractArray{<:Complex}, p::PlanNUFFT, charges)
(; points, kernels, data,) = p
(; points, kernels, data, blocks,) = p
(; us, ks,) = data
check_nufft_uniform_data(p, ûs_k)
fill!(us, zero(eltype(us)))
spread_from_points!(kernels, us, points, charges)
if with_blocking(blocks)
spread_from_points_blocked!(kernels, blocks, us, points, charges)
else
spread_from_points!(kernels, us, points, charges) # single-threaded case?
end
ûs = _type1_fft!(data)
T = real(eltype(us))
normfactor::T = prod(N -> 2π / N, size(us)) # FFT normalisation factor
Expand All @@ -94,12 +75,16 @@ function _type1_fft!(data::ComplexNUFFTData)
end

# Execute a type-2 NUFFT (uniform → non-uniform): evaluate the field described
# by the Fourier coefficients `ûs_k` at the plan's non-uniform points, writing
# the result into `vp`.
# NOTE: this reconstructs the merged version of the diff; the raw text had the
# old destructuring line and an unconditional `interpolate!` call interleaved
# with the new blocked/non-blocked branch.
function exec_type2!(vp::AbstractVector, p::PlanNUFFT, ûs_k::AbstractArray{<:Complex})
    (; points, kernels, data, blocks,) = p
    (; us, ks,) = data
    check_nufft_uniform_data(p, ûs_k)
    ϕ̂s = map(init_fourier_coefficients!, kernels, ks) # this takes time only the first time it's called
    _type2_copy_and_fft!(ûs_k, ϕ̂s, data)
    if with_blocking(blocks)
        interpolate_blocked!(kernels, blocks, vp, us, points)  # threaded, block-partitioned path
    else
        interpolate!(kernels, vp, us, points)
    end
    vp
end

Expand Down
97 changes: 97 additions & 0 deletions src/blocking.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
using ThreadsX: ThreadsX

# Parent type for the block-partitioning data used to parallelise spreading and
# interpolation with threads.
abstract type AbstractBlockData end

# Dummy type used when blocking has been disabled in the NUFFT plan.
struct NullBlockData <: AbstractBlockData end
with_blocking(::NullBlockData) = false       # no blocking → serial code path is taken
sort_points!(::NullBlockData, xp) = nothing  # nothing to sort when blocking is disabled

# Data needed for block partitioning of non-uniform points, which enables
# thread-based parallelisation of spreading and interpolation.
# Type parameters: T = element type of the block buffers, N = spatial
# dimension, Tr = real(T).
struct BlockData{
        T, N,
        Tr, # = real(T)
        Buffers <: AbstractVector{<:AbstractArray{T, N}},
        Indices <: CartesianIndices{N},
    } <: AbstractBlockData
    block_dims :: Dims{N} # size of each block (in number of elements)
    block_sizes :: NTuple{N, Tr} # size of each block (in units of length)
    buffers :: Buffers # local buffers, one per thread (see constructor)
    indices :: Indices # 0-based lower corner of each block in the oversampled grid
    cumulative_npoints_per_block :: Vector{Int} # cumulative sum of number of points in each block (length = 1 + num_blocks, initial value is 0)
    blockidx :: Vector{Int} # linear index of block associated to each point (length = Np)
    pointperm :: Vector{Int} # index permutation for sorting points according to their block (length = Np)
end

# Construct BlockData for an oversampled grid of dimensions Ñs partitioned into
# blocks of dimensions block_dims, for a kernel of half-support M.
function BlockData(::Type{T}, block_dims::Dims{D}, Ñs::Dims{D}, ::HalfSupport{M}) where {T, D, M}
    Nt = Threads.nthreads()
    # Nt = ifelse(Nt == 1, zero(Nt), Nt)  # this disables blocking if running on single thread
    dims = block_dims .+ 2M  # include padding for values outside of block
    Tr = real(T)
    block_sizes = map(Ñs, block_dims) do N, B
        @inline
        Δx = Tr(2π) / N  # grid step
        B * Δx
    end
    buffers = map(_ -> Array{T}(undef, dims), 1:Nt)  # one padded buffer per thread
    # 0-based lower corner of each block along each direction.
    # NOTE(review): if B does not divide N, the last block along that direction
    # covers fewer grid points than block_dims — presumably handled via the
    # periodic wrapping done when copying block data; confirm.
    indices_tup = map(Ñs, block_dims) do N, B
        range(0, N - 1; step = B)
    end
    indices = CartesianIndices(indices_tup)
    nblocks = length(indices)  # total number of blocks
    cumulative_npoints_per_block = Vector{Int}(undef, nblocks + 1)
    blockidx = Int[]   # resized and filled later by sort_points!
    pointperm = Int[]  # resized and filled later by sort_points!
    BlockData(block_dims, block_sizes, buffers, indices, cumulative_npoints_per_block, blockidx, pointperm)
end

# Blocking is considered to be disabled if there are no allocated buffers.
# NOTE(review): the BlockData constructor always allocates nthreads() ≥ 1
# buffers (the line that would disable this on a single thread is commented
# out there), so this is currently always true for BlockData — confirm intent.
with_blocking(bd::BlockData) = !isempty(bd.buffers)

# Assign each point of `xp` to a block and build the permutation which groups
# points by block. Fills, in `bd`:
#  - blockidx[i]: linear index of the block containing point i;
#  - cumulative_npoints_per_block: cumulative histogram of points per block,
#    so points of block j are at positions (cum[j] + 1):cum[j + 1] of the
#    permuted order;
#  - pointperm: permutation of point indices sorted by block.
# The points themselves are *not* permuted; they are later accessed through
# pointperm.
function sort_points!(bd::BlockData, xp::AbstractVector)
    with_blocking(bd) || return nothing
    (; indices, cumulative_npoints_per_block, blockidx, pointperm, block_sizes,) = bd
    fill!(cumulative_npoints_per_block, 0)
    to_linear_index = LinearIndices(axes(indices))  # maps Cartesian to linear index of a block
    Np = length(xp)
    resize!(blockidx, Np)
    resize!(pointperm, Np)

    @inbounds for (i, x⃗) ∈ pairs(xp)
        # Get index of block where point x⃗ is located.
        is = map(x⃗, block_sizes) do x, Δx  # we assume x⃗ is already in [0, 2π)
            # @assert 0 ≤ x < 2π
            1 + floor(Int, x / Δx)
        end
        # checkbounds(indices, CartesianIndex(is))
        n = to_linear_index[is...]  # linear index of block
        # Count in the slot shifted by one, so the in-place prefix sum below
        # directly yields the cumulative histogram.
        cumulative_npoints_per_block[n + 1] += 1
        pointperm[i] = i
        blockidx[i] = n
    end

    # Compute cumulative sum (we don't use cumsum! due to aliasing warning in its docs).
    for i ∈ eachindex(IndexLinear(), cumulative_npoints_per_block)[2:end]
        cumulative_npoints_per_block[i] += cumulative_npoints_per_block[i - 1]
    end
    @assert cumulative_npoints_per_block[begin] == 0
    @assert cumulative_npoints_per_block[end] == Np

    # Sort the permutation by block index. QuickSort is not stable; the order
    # of points *within* a block is presumably irrelevant here.
    if Threads.nthreads() == 1
        # This is the same as sortperm! but seems to be faster.
        sort!(pointperm; by = i -> @inbounds(blockidx[i]), alg = QuickSort)
        # sortperm!(pointperm, blockidx; alg = QuickSort)
    else
        ThreadsX.sort!(pointperm; by = i -> @inbounds(blockidx[i]), alg = ThreadsX.QuickSort())
    end

    # Verification
    # for i ∈ eachindex(cumulative_npoints_per_block)[begin:end - 1]
    #     a = cumulative_npoints_per_block[i] + 1
    #     b = cumulative_npoints_per_block[i + 1]
    #     for j ∈ a:b
    #         @assert blockidx[pointperm[j]] == i
    #     end
    # end

    nothing
end
104 changes: 102 additions & 2 deletions src/interpolation.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# Serial interpolation: evaluate the gridded field `us` at each non-uniform
# point of `xp`, writing results into `vp`.
# NOTE(review): this reconstructs the merged version of the diff — the raw text
# still contained the removed old-side lines calling `to_unit_cell`. Points are
# therefore no longer wrapped here; callers must guarantee points already lie
# in the main periodic cell — confirm against set_points!.
function interpolate!(gs, vp::AbstractArray, us, xp::AbstractArray)
    @assert axes(vp) === axes(xp)
    for i ∈ eachindex(vp)
        vp[i] = interpolate(gs, us, xp[i])
    end
    vp
end
Expand Down Expand Up @@ -35,6 +34,41 @@ end

# Single-array convenience variant: wrap `u` in a tuple and unwrap the single result.
interpolate(gs::NTuple, u::AbstractArray, x⃗) = only(interpolate(gs, (u,), x⃗))

# Blocked interpolation at point x⃗, reading from block-local arrays `us` whose
# 0-based corner in the global (oversampled) grid is I₀. Indices are shifted
# into the local block, so no periodic wrapping is performed.
function interpolate_blocked(
        gs::NTuple{D},
        us::NTuple{M, AbstractArray{T, D}} where {T},
        x⃗::NTuple{D},
        I₀::NTuple{D},
    ) where {D, M}
    @assert M > 0
    foreach(Base.require_one_based_indexing, us)
    dims = size(first(us))
    @assert all(u -> size(u) === dims, us)

    # Evaluate the 1D kernels at the point location (one evaluation per direction).
    evals = map(Kernels.evaluate_kernel, gs, x⃗)

    # Per-direction kernel weights, scaled by the grid step.
    ws = map(evals, gs) do geval, g
        geval.values .* gridstep(g)
    end

    # Offset translating global grid indices into block-local indices.
    halfsupports = map(Kernels.half_support, gs)
    offsets = halfsupports .- I₀

    # Indices to load from the `us` arrays.
    inds = map(evals, gs, offsets) do geval, g, δ
        # Note: this kernel_indices variant doesn't perform periodic wrapping.
        Kernels.kernel_indices(geval.i, g) .+ δ  # shift to beginning of current block
    end
    Is = CartesianIndices(inds)
    # Base.checkbounds(us[1], Is)  # check that indices fall inside the output array

    interpolate_from_arrays_blocked(us, Is, ws)
end

# Single-array convenience variant: wrap `u` in a tuple and unwrap the single result.
interpolate_blocked(gs::NTuple, u::AbstractArray, args...) = only(interpolate_blocked(gs, (u,), args...))

function interpolate_from_arrays(
us::NTuple{C, AbstractArray{T, D}} where {T},
inds::NTuple{D, Tuple},
Expand All @@ -54,3 +88,69 @@ function interpolate_from_arrays(
end
vs
end

# Accumulate, for each of the C arrays in `us`, the weighted sum
# Σ prod(1D weights) * u[I] over the kernel support. `vals` holds the
# per-direction kernel weights and `Is` the corresponding Cartesian indices
# into the arrays.
function interpolate_from_arrays_blocked(
        us::NTuple{C, AbstractArray{T, D}} where {T},
        Is::CartesianIndices{D},
        vals::NTuple{D, Tuple},
    ) where {C, D}
    z = zero(eltype(first(us)))
    acc = ntuple(_ -> z, Val(C))
    iter = CartesianIndices(map(eachindex, vals))
    @inbounds for ns ∈ iter  # ns = (ni, nj, ...)
        w = prod(map(getindex, vals, Tuple(ns)))  # product of 1D kernel weights
        I = Is[ns]
        acc = ntuple(Val(C)) do j
            @inline
            acc[j] + w * us[j][I]
        end
    end
    acc
end

# Threaded, block-partitioned interpolation.
# Blocks are split statically among threads; each thread copies one block of
# the global grid `us` (plus ±M padding, with periodic wrapping) into its
# private buffer, then interpolates all points assigned to that block from the
# buffer. Writes to disjoint entries of `vp` (one per point), so no
# synchronisation is needed.
function interpolate_blocked!(gs, blocks::BlockData, vp::AbstractArray, us, xp::AbstractArray)
    @assert axes(vp) === axes(xp)
    (; block_dims, cumulative_npoints_per_block, pointperm, buffers, indices,) = blocks
    Ms = map(Kernels.half_support, gs)
    Nt = length(buffers)  # usually equal to the number of threads
    nblocks = length(indices)
    Base.require_one_based_indexing(buffers)
    Base.require_one_based_indexing(indices)
    # :static scheduling gives a fixed task-per-index assignment, so buffers[i]
    # is only ever touched by the task running iteration i.
    Threads.@threads :static for i ∈ 1:Nt
        # Contiguous range of blocks handled by this task (balanced split of 1:nblocks).
        j_start = (i - 1) * nblocks ÷ Nt + 1
        j_end = i * nblocks ÷ Nt
        @inbounds for j ∈ j_start:j_end
            block = buffers[i]
            I₀ = indices[j]

            # Indices of current block including padding
            inds_block = (I₀ + oneunit(I₀) - CartesianIndex(Ms)):(I₀ + CartesianIndex(block_dims) + CartesianIndex(Ms))
            copy_to_block!(block, us, inds_block)  # copy local data to `block` array

            # Iterate over all points in the current block
            a = cumulative_npoints_per_block[j] + 1
            b = cumulative_npoints_per_block[j + 1]
            for k ∈ a:b
                l = pointperm[k]
                # @assert blocks.blockidx[l] == j  # check that point is really in the current block
                x⃗ = xp[l]  # if points have not been permuted
                # x⃗ = xp[k]  # if points have been permuted (may be slightly faster here, but requires permutation in sort_points!)
                vp[l] = interpolate_blocked(gs, block, x⃗, Tuple(I₀))
            end
        end
    end
    vp
end

# Copy the region `inds` of the global grid `us` into the local `block` buffer,
# wrapping indices periodically (the region may extend past the grid limits due
# to block padding).
# Returns the mutated `block`, following the convention for `!` functions.
# (The previous version returned the read-only source `us`, which looked like a
# slip; the only visible caller discards the return value.)
function copy_to_block!(block::AbstractArray, us::AbstractArray, inds::CartesianIndices)
    @assert size(block) == size(inds)
    Base.require_one_based_indexing(us)
    Ñs = size(us)
    @inbounds for i ∈ eachindex(block, inds)
        I = inds[i]
        is = map(wrap_periodic, Tuple(I), Ñs)  # wrap_periodic is defined in spreading.jl
        block[i] = us[is...]
    end
    block
end
32 changes: 29 additions & 3 deletions src/plan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,27 +49,45 @@ struct PlanNUFFT{
Kernels <: NTuple{N, AbstractKernelData{<:AbstractKernel, M, Treal}},
Points <: StructVector{NTuple{N, Treal}},
Data <: AbstractNUFFTData{T, N},
Blocks <: AbstractBlockData,
}
kernels :: Kernels
σ :: Treal # oversampling factor (≥ 1)
points :: Points # non-uniform points (real values)
data :: Data
blocks :: Blocks
end

"""
size(p::PlanNUFFT) -> (N₁, N₂, ...)

Return the dimensions of arrays containing uniform values.

This corresponds to the number of Fourier modes in each direction.
This corresponds to the number of Fourier modes in each direction (in the non-oversampled grid).
"""
Base.size(p::PlanNUFFT) = map(length, p.data.ks)

# Case of real-to-complex transform.
# Default target size of a block used for block partitioning of non-uniform points.
default_block_size() = 4096 # in number of linear elements

# Determine the dimensions of a block given the dimensions Ñs of the
# (oversampled) grid and the wanted number of linear elements per block
# `bsize`. Starting from a 1 × … × 1 block, the size is doubled in one
# direction at a time (cycling over directions) until the block contains at
# least `bsize` elements. The result is thus a power of 2 along each direction;
# note that it is not clamped to Ñs and may exceed it when bsize > prod(Ñs).
function get_block_dims(Ñs::Dims, bsize::Int)
    d = length(Ñs)
    bdims = map(one, Ñs)  # size along each direction (initially 1 in each direction)
    bprod = 1             # product of sizes
    i = 1                 # current dimension
    while bprod < bsize
        # Multiply block size by 2 in the current dimension.
        bdims = Base.setindex(bdims, bdims[i] << 1, i)
        bprod <<= 1
        i = ifelse(i == d, 1, i + 1)
    end
    bdims
end

# This constructor is generally not called directly.
function _PlanNUFFT(
::Type{T}, kernel::AbstractKernel, h::HalfSupport, σ_wanted, Ns::Dims{D};
fftw_flags = FFTW.MEASURE,
block_size::Union{Integer, Nothing} = default_block_size(),
) where {T <: Number, D}
ks = init_wavenumbers(T, Ns)
# Determine dimensions of oversampled grid.
Expand All @@ -87,8 +105,16 @@ function _PlanNUFFT(
Kernels.optimal_kernel(kernel, h, Δx̃, Ñ / N)
end
points = StructVector(ntuple(_ -> Tr[], Val(D)))
if block_size === nothing
blocks = NullBlockData() # disable blocking (→ can't use multithreading when spreading)
FFTW.set_num_threads(1) # also disable FFTW threading (avoids allocations)
else
block_dims = get_block_dims(Ñs, block_size)
blocks = BlockData(T, block_dims, Ñs, h)
FFTW.set_num_threads(Threads.nthreads())
end
nufft_data = init_plan_data(T, Ñs, ks; fftw_flags)
PlanNUFFT(kernel_data, σ, points, nufft_data)
PlanNUFFT(kernel_data, σ, points, nufft_data, blocks)
end

init_wavenumbers(::Type{T}, Ns::Dims) where {T <: AbstractFloat} = ntuple(Val(length(Ns))) do i
Expand Down
Loading