Skip to content

Commit

Permalink
GPU: disable explicit synchronisation barriers by default (#32)
Browse files Browse the repository at this point in the history
* Disable explicit GPU synchronisation by default

* Warning in p.timer on GPU without synchronisation

* Fix get_timer

* Avoid warnings in internal calls to p.timer

* Add extra inference test
  • Loading branch information
jipolanco authored Sep 25, 2024
1 parent 1b320ff commit 323962d
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 27 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Changed

- Explicit GPU synchronisation barriers (using `KA.synchronize`) are now disabled by default.
This can now be re-enabled by passing `synchronise = true` as a plan argument.
Enabling synchronisation is useful for getting accurate timings (in `p.timer`) but
may result in decreased performance.

## [v0.5.3] - 2024-09-24

### Changed
Expand Down
22 changes: 12 additions & 10 deletions src/NonuniformFFTs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,9 @@ See also [`exec_type2!`](@ref).
function exec_type1! end

function exec_type1!(ûs_k::NTuple{C, AbstractArray{<:Complex}}, p::PlanNUFFT, vp::NTuple{C}) where {C}
(; backend, points, kernels, data, blocks, index_map, timer,) = p
(; backend, points, kernels, data, blocks, index_map,) = p
(; us,) = data
timer = get_timer_nowarn(p)

@timeit timer "Execute type 1" begin
check_nufft_uniform_data(p, ûs_k)
Expand All @@ -142,25 +143,25 @@ function exec_type1!(ûs_k::NTuple{C, AbstractArray{<:Complex}}, p::PlanNUFFT,
local workgroupsize = default_workgroupsize(backend, ndrange)
local kernel! = fill_with_zeros_kernel!(backend, workgroupsize, ndrange)
kernel!(us)
KA.synchronize(backend)
maybe_synchronise(p)
end

@timeit timer "(1) Spreading" begin
spread_from_points!(backend, blocks, kernels, us, points, vp)
KA.synchronize(backend)
maybe_synchronise(p)
end

@timeit timer "(2) Forward FFT" begin
ûs = _type1_fft!(data)
KA.synchronize(backend)
maybe_synchronise(p)
end

@timeit timer "(3) Deconvolution" begin
T = real(eltype(first(us)))
normfactor::T = prod(N -> 2π / N, size(first(us))) # FFT normalisation factor
ϕ̂s = map(fourier_coefficients, kernels)
copy_deconvolve_to_non_oversampled!(backend, ûs_k, ûs, index_map, ϕ̂s, normfactor) # truncate to original grid + normalise
KA.synchronize(backend)
maybe_synchronise(p)
end
end

Expand Down Expand Up @@ -208,8 +209,9 @@ See also [`exec_type1!`](@ref).
function exec_type2! end

function exec_type2!(vp::NTuple{C, AbstractVector}, p::PlanNUFFT, ûs_k::NTuple{C, AbstractArray{<:Complex}}) where {C}
(; backend, points, kernels, data, blocks, index_map, timer,) = p
(; backend, points, kernels, data, blocks, index_map,) = p
(; us,) = data
timer = get_timer_nowarn(p)

@timeit timer "Execute type 2" begin
check_nufft_uniform_data(p, ûs_k)
Expand All @@ -230,23 +232,23 @@ function exec_type2!(vp::NTuple{C, AbstractVector}, p::PlanNUFFT, ûs_k::NTuple
local workgroupsize = default_workgroupsize(backend, ndrange)
local kernel! = fill_with_zeros_kernel!(backend, workgroupsize, ndrange)
kernel!(ûs)
KA.synchronize(backend)
maybe_synchronise(p)
end

@timeit timer "(1) Deconvolution" begin
ϕ̂s = map(fourier_coefficients, kernels)
copy_deconvolve_to_oversampled!(backend, ûs, ûs_k, index_map, ϕ̂s)
KA.synchronize(backend)
maybe_synchronise(p)
end

@timeit timer "(2) Backward FFT" begin
_type2_fft!(data)
KA.synchronize(backend)
maybe_synchronise(p)
end

@timeit timer "(3) Interpolation" begin
interpolate!(backend, blocks, kernels, vp, us, points)
KA.synchronize(backend)
maybe_synchronise(p)
end
end

Expand Down
7 changes: 6 additions & 1 deletion src/abstractNFFTs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ Base.@constprop :aggressive function PlanNUFFT(
fftflags = FFTW.ESTIMATE, blocking = true, sortNodes = false,
window = default_kernel(),
fftshift = true, # for compatibility with NFFT.jl
synchronise = false,
kwargs...,
) where {Tr <: AbstractFloat}
# Note: the NFFT.jl package uses an odd window size, w = 2m + 1.
Expand All @@ -93,7 +94,11 @@ Base.@constprop :aggressive function PlanNUFFT(
sort_points = sortNodes ? True() : False() # this is type-unstable (unless constant propagation happens)
block_size = blocking ? default_block_size(Ns, backend) : nothing # also type-unstable
kernel = window isa AbstractKernel ? window : convert_window_function(window)
p = PlanNUFFT(Complex{Tr}, Ns, HalfSupport(m); backend, σ = Tr(σ), sort_points, fftshift, block_size, kernel, fftw_flags = fftflags)
p = PlanNUFFT(
Complex{Tr}, Ns, HalfSupport(m);
backend, σ = Tr(σ), sort_points, fftshift, block_size,
kernel, fftw_flags = fftflags, synchronise,
)
AbstractNFFTs.nodes!(p, xp)
p
end
32 changes: 19 additions & 13 deletions src/blocking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@ end

# Here the element type of `xp` can be either an NTuple{N, <:Real}, an SVector{N, <:Real},
# or anything else which has length `N`.
function set_points!(backend, ::NullBlockData, points::StructVector, xp, timer; transform::F = identity) where {F <: Function}
function set_points_impl!(
backend, ::NullBlockData, points::StructVector, xp, timer;
synchronise, transform::F = identity
) where {F <: Function}
length(points) == length(xp) || resize_no_copy!(points, length(xp))
KA.synchronize(backend)
maybe_synchronise(backend, synchronise)
Base.require_one_based_indexing(points)
@timeit timer "(1) Copy + fold" begin
# NOTE: we explicitly iterate through StructVector components because CUDA.jl
Expand All @@ -30,7 +33,7 @@ function set_points!(backend, ::NullBlockData, points::StructVector, xp, timer;
points_comp = StructArrays.components(points)
kernel! = copy_points_unblocked_kernel!(backend)
kernel!(transform, points_comp, xp; ndrange = size(xp))
KA.synchronize(backend) # mostly to get accurate timings
maybe_synchronise(backend, synchronise)
end
nothing
end
Expand Down Expand Up @@ -100,9 +103,9 @@ end
get_pointperm(bd::BlockDataGPU) = bd.pointperm
get_sort_points(bd::BlockDataGPU) = bd.sort_points

function set_points!(
function set_points_impl!(
backend::GPU, bd::BlockDataGPU, points::StructVector, xp, timer;
transform::F = identity,
transform::F = identity, synchronise,
) where {F <: Function}
(;
cumulative_npoints_per_block, nblocks_per_dir, block_sizes,
Expand All @@ -120,7 +123,7 @@ function set_points!(
resize_no_copy!(pointperm, Np)
resize_no_copy!(points, Np)
fill!(cumulative_npoints_per_block, 0)
KA.synchronize(backend)
maybe_synchronise(backend, synchronise)
end

# We avoid passing a StructVector to the kernel, so we pass `points` as a tuple of
Expand All @@ -137,15 +140,15 @@ function set_points!(
block_sizes, nblocks_per_dir, sort_points, transform;
ndrange,
)
KA.synchronize(backend)
maybe_synchronise(backend, synchronise)
end

@timeit timer "(2) Cumulative sum" begin
# Note: the Julia docs state that this can fail if the accumulation is done in
# place. With CUDA, this doesn't seem to be a problem, but we could allocate a
# separate array if it becomes an issue.
cumsum!(cumulative_npoints_per_block, cumulative_npoints_per_block)
KA.synchronize(backend)
maybe_synchronise(backend, synchronise)
end

# Compute permutation needed to sort points according to their block.
Expand All @@ -156,7 +159,7 @@ function set_points!(
block_sizes, nblocks_per_dir, transform;
ndrange,
)
KA.synchronize(backend)
maybe_synchronise(backend, synchronise)
end

# `pointperm` now contains the permutation needed to sort points
Expand All @@ -165,7 +168,7 @@ function set_points!(
@timeit timer "(4) Permute points" let
local kernel! = permute_kernel!(backend, workgroupsize)
kernel!(points_comp, xp, pointperm, transform; ndrange)
KA.synchronize(backend)
maybe_synchronise(backend, synchronise)
end
end

Expand Down Expand Up @@ -313,9 +316,13 @@ function BlockData(
)
end

function set_points!(backend::CPU, bd::BlockData, points::StructVector, xp, timer; transform::F = identity) where {F <: Function}
function set_points_impl!(
backend::CPU, bd::BlockData, points::StructVector, xp, timer;
transform::F = identity,
synchronise,
) where {F <: Function}
# This technically never happens, but we might use it as a way to disable blocking.
isempty(bd.buffers) && return set_points!(backend, NullBlockData(), points, xp, timer; transform)
isempty(bd.buffers) && return set_points_impl!(backend, NullBlockData(), points, xp, timer; transform, synchronise)

(; indices, cumulative_npoints_per_block, blockidx, pointperm, block_sizes,) = bd
N = type_length(eltype(xp)) # = number of dimensions
Expand All @@ -328,7 +335,6 @@ function set_points!(backend::CPU, bd::BlockData, points::StructVector, xp, time
resize_no_copy!(pointperm, Np)
resize_no_copy!(points, Np)
fill!(cumulative_npoints_per_block, 0)
KA.synchronize(backend)
end

@timeit timer "(1) Assign blocks" @inbounds for (i, x⃗) pairs(xp)
Expand Down
30 changes: 29 additions & 1 deletion src/plan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ The created plan contains all data needed to perform NUFFTs for non-uniform data
By default the plan creates its own timer.
One can visualise the time spent on different parts of the NUFFT computation using `p.timer`.
- `synchronise = false`: if `true`, add synchronisation barriers between calls to GPU kernels.
Enabling this is needed for accurate timings in `p.timer` when computing on a GPU, but may
result in reduced performance.
# FFT size and performance
For performance reasons, when doing FFTs one usually wants the size of the input along each
Expand Down Expand Up @@ -225,6 +229,7 @@ struct PlanNUFFT{
fftshift :: Bool
index_map :: IndexMap
timer :: Timer
synchronise :: Bool
end

function Base.show(io::IO, p::PlanNUFFT{T, N, Nc}) where {T, N, Nc}
Expand All @@ -240,6 +245,25 @@ function Base.show(io::IO, p::PlanNUFFT{T, N, Nc}) where {T, N, Nc}
nothing
end

# Return the timer stored in the plan using raw field access. This bypasses the
# `getproperty` overload for `PlanNUFFT`, so no warning is emitted (unlike `p.timer`).
function get_timer_nowarn(p::PlanNUFFT)
    return getfield(p, :timer)
end

# Return the plan's timer, warning first when the recorded timings may be incorrect:
# that is, when the plan runs on a GPU backend with synchronisation disabled.
function get_timer(p::PlanNUFFT)
    on_gpu = p.backend isa GPU
    if on_gpu && !p.synchronise
        @warn "synchronisation is disabled on GPU: timings will be incorrect"
    end
    return get_timer_nowarn(p)
end

# Property access on a plan: `p.timer` is routed through `get_timer` (which may emit a
# warning); every other property name maps directly onto the field of the same name.
@inline function Base.getproperty(p::PlanNUFFT, name::Symbol)
    name === :timer && return get_timer(p)
    return getfield(p, name)
end

"""
size(p::PlanNUFFT) -> (N₁, N₂, ...)
Expand Down Expand Up @@ -274,6 +298,9 @@ end

# Block dimensions already given as a tuple with the same length `N` as the first
# argument are returned unchanged.
function get_block_dims(::Dims{N}, bdims::NTuple{N}) where {N}
    return bdims
end

# Call `KA.synchronize(backend)` only when `synchronise` is `true`.
#
# Always returns `nothing`, so the return value does not depend on the flag. (The
# short-circuit form `synchronise && KA.synchronize(backend)` evaluates to `false` when
# synchronisation is disabled and to the result of `KA.synchronize` otherwise.)
# Note: this doesn't do anything on the CPU.
function maybe_synchronise(backend::KA.Backend, synchronise::Bool)
    if synchronise
        KA.synchronize(backend)
    end
    return nothing
end

# Convenience method: use the backend and the `synchronise` setting stored in the plan.
maybe_synchronise(p::PlanNUFFT) = maybe_synchronise(p.backend, p.synchronise)

# This constructor is generally not called directly.
function _PlanNUFFT(
::Type{T}, kernel::AbstractKernel, h::HalfSupport, σ_wanted, Ns::Dims{D},
Expand All @@ -284,6 +311,7 @@ function _PlanNUFFT(
sort_points::StaticBool = False(),
backend::KA.Backend = CPU(),
block_size::Union{Integer, Dims{D}, Nothing} = default_block_size(Ns, backend),
synchronise::Bool = false,
) where {T <: Number, D}
ks = init_wavenumbers(T, Ns)
# Determine dimensions of oversampled grid.
Expand Down Expand Up @@ -331,7 +359,7 @@ function _PlanNUFFT(
indmap = KA.allocate(backend, eltype(inds), length(k))
non_oversampled_indices!(indmap, k, inds; fftshift)
end
PlanNUFFT(kernel_data, backend, σ, points, nufft_data, blocks, fftshift, index_map, timer)
PlanNUFFT(kernel_data, backend, σ, points, nufft_data, blocks, fftshift, index_map, timer, synchronise)
end

function check_nufft_size(Ñ, ::HalfSupport{M}) where M
Expand Down
5 changes: 3 additions & 2 deletions src/set_points.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ end
# Here the element type of `xp` can be either an NTuple{N, <:Real}, an SVector{N, <:Real},
# or anything else which has length `N`.
function set_points!(p::PlanNUFFT, xp::AbstractVector; kwargs...)
(; points, timer,) = p
(; points, synchronise,) = p
timer = get_timer_nowarn(p)
N = ndims(p)
type_length(eltype(xp)) == N || throw(DimensionMismatch(lazy"expected $N-dimensional points"))
@timeit timer "Set points" set_points!(p.backend, p.blocks, points, xp, timer; kwargs...)
@timeit timer "Set points" set_points_impl!(p.backend, p.blocks, points, xp, timer; synchronise, kwargs...)
p
end

Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ NFFT = "efe261a4-0d2b-5849-be55-fc731d526b0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
2 changes: 2 additions & 0 deletions test/near_2pi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using Test
using NonuniformFFTs
using NonuniformFFTs: Kernels
using FFTW: fftfreq, rfftfreq
using TimerOutputs: TimerOutput

function type1_exact!(us, ks, xp, vp)
fill!(us, 0)
Expand All @@ -26,6 +27,7 @@ end
# NOTE: these parameters allow one to reproduce the issue that occurs when a point is very close to 2π.
N = 32
plan = PlanNUFFT(T, N; m = HalfSupport(8), σ = 1.5, block_size = 16)
@test @inferred((p -> p.timer)(plan)) isa TimerOutput
set_points!(plan, xp)

us = Array{T}(undef, N)
Expand Down

0 comments on commit 323962d

Please sign in to comment.