Skip to content

Commit

Permalink
Simplify main GPU kernels using Adapt (#34)
Browse files Browse the repository at this point in the history
* Define Adapt.adapt_structure for KernelData types

* Update GPU spreading

* Update GPU interpolation + remove unused function
  • Loading branch information
jipolanco authored Oct 15, 2024
1 parent ab21f69 commit 0df18d9
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 26 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.5.5"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
AbstractNFFTs = "7f219486-4aa7-41d6-80a7-e08ef20ceed7"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
Bessels = "0e736298-9ec6-45e8-9647-e4fc86a2fe38"
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
Expand All @@ -22,6 +23,7 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
[compat]
AbstractFFTs = "1.5.0"
AbstractNFFTs = "0.8.2"
Adapt = "4.0.4"
Atomix = "0.1.0"
Bessels = "0.2"
Bumper = "0.6, 0.7"
Expand Down
7 changes: 1 addition & 6 deletions src/Kernels/Kernels.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Kernels

using KernelAbstractions: KernelAbstractions as KA
using Adapt: Adapt, adapt

export HalfSupport

Expand Down Expand Up @@ -81,18 +82,12 @@ end
end

# Note: evaluate_kernel_func generates a function which is callable from GPU kernels.
# (Directly passing an AbstractKernelData to a GPU kernel fails, at least with CUDA).
# Evaluate the kernel described by `g` at position `x₀`.
# This simply builds the evaluation function associated to `g` (via
# `evaluate_kernel_func`) and immediately applies it to `x₀`.
@inline function evaluate_kernel(g::AbstractKernelData, x₀)
    kernel_func = evaluate_kernel_func(g)
    return kernel_func(x₀)
end

# Convenience method: take the half-support `M` from the type parameters of the
# kernel data and forward to the `kernel_indices(i, ::HalfSupport, args...)` method.
@inline kernel_indices(i, ::AbstractKernelData{K, M}, args...) where {K, M} =
    kernel_indices(i, HalfSupport(M), args...)

# Returns a function which is callable from GPU kernels.
# Build a function `(i, args...) -> indices` for the given kernel data.
# The returned function only depends on the half-support `M` (a type parameter),
# not on the kernel data object itself.
function kernel_indices_func(::AbstractKernelData{K, M}) where {K, M}
    @inline to_indices(i, args...) = kernel_indices(i, HalfSupport(M), args...)
    return to_indices
end

# Takes into account periodic wrapping.
# This is equivalent to calling mod1(j, N) for each j, but much much faster.
# We assume the central index `i` is in 1:N and that M < N / 2.
Expand Down
14 changes: 13 additions & 1 deletion src/Kernels/bspline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,27 @@ struct BSplineKernelData{
σ :: T
Δt :: T # knot separation
gk :: FourierCoefs # values in uniform Fourier grid

# Inner constructor taking all fields explicitly.
#  - σ:  kernel width parameter (stored as-is)
#  - Δt: knot separation
#  - gk: array holding the kernel values in the uniform Fourier grid; its type
#        determines the last type parameter of the struct.
# This is the method called by `Adapt.adapt_structure` and by the
# backend-based constructor.
function BSplineKernelData{M}(σ::T, Δt::T, gk) where {M, T <: AbstractFloat}
    new{M, T, typeof(gk)}(σ, Δt, gk)
end

function BSplineKernelData{M}(backend::KA.Backend, Δx::Real) where {M}
Δt = Δx
σ = sqrt(M / 6) * Δt
T = eltype(Δt)
gk = KA.allocate(backend, T, 0)
new{M, T, typeof(gk)}(T(σ), Δt, gk)
BSplineKernelData{M}(T(σ), Δt, gk)
end
end

# Adapt.jl hook for B-spline kernel data.
# The scalar fields (σ, Δt) are forwarded unchanged; only the Fourier
# coefficient array `gk` is passed through `adapt`.
function Adapt.adapt_structure(to, g::BSplineKernelData{M}) where {M}
    gk = adapt(to, g.gk)
    return BSplineKernelData{M}(g.σ, g.Δt, gk)
end

gridstep(g::BSplineKernelData) = g.Δt # assume Δx = Δt

BSplineKernelData(::HalfSupport{M}, args...) where {M} = BSplineKernelData{M}(args...)
Expand Down
14 changes: 13 additions & 1 deletion src/Kernels/gaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ struct GaussianKernelData{
cs :: NTuple{M, T} # precomputed exponentials
gk :: FourierCoefs # values in uniform Fourier grid

# Inner constructor taking all fields explicitly (no computation is done here).
# The type of `gk` sets the last type parameter of the struct.
function GaussianKernelData{M}(
        Δx::T, σ::T, τ::T, cs::NTuple{M, T}, gk,
    ) where {M, T <: AbstractFloat}
    G = typeof(gk)
    return new{M, T, G}(Δx, σ, τ, cs, gk)
end

function GaussianKernelData{M}(backend::KA.Backend, Δx::T, α::T) where {M, T <: AbstractFloat}
σ = α * Δx
τ = 2 * σ^2
Expand All @@ -74,10 +78,18 @@ struct GaussianKernelData{
exp(-x^2 / τ)
end
gk = KA.allocate(backend, T, 0)
new{M, T, typeof(gk)}(Δx, σ, τ, cs, gk)
GaussianKernelData{M}(Δx, σ, τ, cs, gk)
end
end

# Adapt.jl hook for Gaussian kernel data.
# Scalar fields (Δx, σ, τ) are forwarded unchanged; the precomputed
# exponentials `cs` and the Fourier coefficients `gk` are passed through `adapt`.
function Adapt.adapt_structure(to, g::GaussianKernelData{M}) where {M}
    cs = adapt(to, g.cs)
    gk = adapt(to, g.gk)
    return GaussianKernelData{M}(g.Δx, g.σ, g.τ, cs, gk)
end

GaussianKernelData(::HalfSupport{M}, args...) where {M} = GaussianKernelData{M}(args...)

function Base.show(io::IO, g::GaussianKernelData{M}) where {M}
Expand Down
14 changes: 13 additions & 1 deletion src/Kernels/kaiser_bessel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ struct KaiserBesselKernelData{
cs :: ApproxCoefs # coefficients of polynomial approximation
gk :: FourierCoefs

# Inner constructor taking all fields explicitly (no computation is done here).
# The types of `cs` and `gk` set the last two type parameters of the struct.
function KaiserBesselKernelData{M}(
        Δx::T, σ::T, w::T, β::T, β²::T, cs, gk,
    ) where {M, T <: AbstractFloat}
    C = typeof(cs)
    G = typeof(gk)
    return new{M, T, C, G}(Δx, σ, w, β, β², cs, gk)
end

function KaiserBesselKernelData{M}(backend::KA.Backend, Δx::T, β::T) where {M, T <: AbstractFloat}
w = M * Δx
σ = sqrt(kb_equivalent_variance(β)) * w
Expand All @@ -116,10 +120,18 @@ struct KaiserBesselKernelData{
cs = solve_piecewise_polynomial_coefficients(T, Val(M), Val(Npoly)) do x
besseli0(β * sqrt(1 - x^2))
end
new{M, T, typeof(cs), typeof(gk)}(Δx, σ, w, β, β², cs, gk)
KaiserBesselKernelData{M}(Δx, σ, w, β, β², cs, gk)
end
end

# Adapt.jl hook for Kaiser–Bessel kernel data.
# Scalar fields (Δx, σ, w, β, β²) are forwarded unchanged; the polynomial
# approximation coefficients `cs` and the Fourier coefficients `gk` are passed
# through `adapt`.
function Adapt.adapt_structure(to, g::KaiserBesselKernelData{M}) where {M}
    cs = adapt(to, g.cs)
    gk = adapt(to, g.gk)
    return KaiserBesselKernelData{M}(g.Δx, g.σ, g.w, g.β, g.β², cs, gk)
end

KaiserBesselKernelData(::HalfSupport{M}, args...) where {M} =
KaiserBesselKernelData{M}(args...)

Expand Down
14 changes: 13 additions & 1 deletion src/Kernels/kaiser_bessel_backwards.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ struct BackwardsKaiserBesselKernelData{
cs :: ApproxCoefs # coefficients of polynomial approximation
gk :: FourierCoefs

# Inner constructor taking all fields explicitly (no computation is done here).
# The types of `cs` and `gk` set the last two type parameters of the struct.
function BackwardsKaiserBesselKernelData{M}(
        Δx::T, σ::T, w::T, β::T, cs, gk,
    ) where {M, T <: AbstractFloat}
    C = typeof(cs)
    G = typeof(gk)
    return new{M, T, C, G}(Δx, σ, w, β, cs, gk)
end

function BackwardsKaiserBesselKernelData{M}(backend::KA.Backend, Δx::T, β::T) where {M, T <: AbstractFloat}
w = M * Δx
σ = sqrt(backwards_kb_equivalent_variance(β)) * w
Expand All @@ -87,10 +91,18 @@ struct BackwardsKaiserBesselKernelData{
s = sqrt(1 - x^2)
sinh(β * s) / (s * oftype(x, π))
end
new{M, T, typeof(cs), typeof(gk)}(Δx, σ, w, β, cs, gk)
BackwardsKaiserBesselKernelData{M}(Δx, σ, w, β, cs, gk)
end
end

# Adapt.jl hook for "backwards" Kaiser–Bessel kernel data.
# Scalar fields (Δx, σ, w, β) are forwarded unchanged; the polynomial
# approximation coefficients `cs` and the Fourier coefficients `gk` are passed
# through `adapt`.
function Adapt.adapt_structure(to, g::BackwardsKaiserBesselKernelData{M}) where {M}
    cs = adapt(to, g.cs)
    gk = adapt(to, g.gk)
    return BackwardsKaiserBesselKernelData{M}(g.Δx, g.σ, g.w, g.β, cs, gk)
end

BackwardsKaiserBesselKernelData(::HalfSupport{M}, args...) where {M} =
BackwardsKaiserBesselKernelData{M}(args...)

Expand Down
13 changes: 5 additions & 8 deletions src/interpolation/gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ using StaticArrays: MVector
# Interpolate onto a single point
@kernel function interpolate_to_point_naive_kernel!(
vp::NTuple{C},
@Const(gs::NTuple{D}),
@Const(points::NTuple{D}),
@Const(us::NTuple{C}),
@Const(pointperm),
@Const(Δxs::NTuple{D}), # grid step in each direction (oversampled grid)
evaluate::NTuple{D, <:Function}, # can't be marked Const for some reason
to_indices::NTuple{D, <:Function},
) where {C, D}
i = @index(Global, Linear)

Expand All @@ -26,15 +25,15 @@ using StaticArrays: MVector
Ns = size(first(us)) # grid dimensions

# Evaluate 1D kernels.
gs_eval = map((f, x) -> f(x), evaluate, x⃗)
gs_eval = map(Kernels.evaluate_kernel, gs, x⃗)

# Determine indices to load from `u` arrays.
indvals = ntuple(Val(D)) do n
@inbounds begin
gdata = gs_eval[n]
vals = gdata.values .* Δxs[n]
f = to_indices[n]
f(gdata.i, Ns[n]) => vals
inds = Kernels.kernel_indices(gdata.i, gs[n], Ns[n])
inds => vals
end
end

Expand All @@ -59,8 +58,6 @@ function interpolate!(
Base.require_one_based_indexing(x⃗s) # this is to make sure that all indices match
foreach(Base.require_one_based_indexing, vp_all)

evaluate = map(Kernels.evaluate_kernel_func, gs) # kernel evaluation functions
to_indices = map(Kernels.kernel_indices_func, gs) # functions returning spreading indices
xs_comp = StructArrays.components(x⃗s)
Δxs = map(Kernels.gridstep, gs)

Expand All @@ -84,7 +81,7 @@ function interpolate!(
ndrange = size(x⃗s) # iterate through points
workgroupsize = default_workgroupsize(backend, ndrange)
kernel! = interpolate_to_point_naive_kernel!(backend, workgroupsize)
kernel!(vp_sorted, xs_comp, us, pointperm_, Δxs, evaluate, to_indices; ndrange)
kernel!(vp_sorted, gs, xs_comp, us, pointperm_, Δxs; ndrange)

if sort_points === True()
kernel_perm! = interp_permute_kernel!(backend, workgroupsize)
Expand Down
13 changes: 5 additions & 8 deletions src/spreading/gpu.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# Spread from a single point
@kernel function spread_from_point_naive_kernel!(
us::NTuple{C},
@Const(gs::NTuple{D}),
@Const(points::NTuple{D}),
@Const(vp::NTuple{C}),
@Const(pointperm),
evaluate::NTuple{D, <:Function}, # can't be marked Const for some reason
to_indices::NTuple{D, <:Function},
) where {C, D}
i = @index(Global, Linear)

Expand All @@ -28,15 +27,15 @@
Ns = spread_actual_dims(Z, Ns_real) # divides the Ns_real[1] by 2 if Z <: Complex

# Evaluate 1D kernels.
gs_eval = map((f, x) -> f(x), evaluate, x⃗)
gs_eval = map(Kernels.evaluate_kernel, gs, x⃗)

# Determine indices to write in `u` arrays.
indvals = ntuple(Val(D)) do n
@inbounds begin
gdata = gs_eval[n]
vals = gdata.values
f = to_indices[n]
f(gdata.i, Ns[n]) => vals
inds = Kernels.kernel_indices(gdata.i, gs[n], Ns[n])
inds => vals
end
end

Expand Down Expand Up @@ -138,8 +137,6 @@ function spread_from_points!(
Base.require_one_based_indexing(x⃗s) # this is to make sure that all indices match
foreach(Base.require_one_based_indexing, vp_all)

evaluate = map(Kernels.evaluate_kernel_func, gs) # kernel evaluation functions
to_indices = map(Kernels.kernel_indices_func, gs) # functions returning spreading indices
xs_comp = StructArrays.components(x⃗s)

# Reinterpret `us_all` as real arrays, in case they are complex.
Expand Down Expand Up @@ -177,7 +174,7 @@ function spread_from_points!(
end

kernel! = spread_from_point_naive_kernel!(backend, workgroupsize)
kernel!(us_real, xs_comp, vp_sorted, pointperm_, evaluate, to_indices; ndrange)
kernel!(us_real, gs, xs_comp, vp_sorted, pointperm_; ndrange)

if sort_points === True()
foreach(KA.unsafe_free!, vp_sorted) # manually deallocate temporary arrays
Expand Down

0 comments on commit 0df18d9

Please sign in to comment.