Skip to content

Commit

Permalink
Update GPU interpolation + remove unused function
Browse files Browse the repository at this point in the history
  • Loading branch information
jipolanco committed Oct 15, 2024
1 parent 494dac9 commit 185bce4
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 13 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.5.5"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
AbstractNFFTs = "7f219486-4aa7-41d6-80a7-e08ef20ceed7"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
Bessels = "0e736298-9ec6-45e8-9647-e4fc86a2fe38"
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
Expand All @@ -22,6 +23,7 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
[compat]
AbstractFFTs = "1.5.0"
AbstractNFFTs = "0.8.2"
Adapt = "4.0.4"
Atomix = "0.1.0"
Bessels = "0.2"
Bumper = "0.6, 0.7"
Expand Down
5 changes: 0 additions & 5 deletions src/Kernels/Kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,6 @@ end
kernel_indices(i, HalfSupport(M), args...)
end

# Returns a function which is callable from GPU kernels.
#
# Only the *type* of the argument is used: the half-support `M` is read from the
# second type parameter of `AbstractKernelData{K, M}`, and `K` is not used in the body.
# The returned closure takes a central index `i` plus any extra arguments and
# forwards them to `kernel_indices(i, HalfSupport(M), args...)`.
function kernel_indices_func(::AbstractKernelData{K, M}) where {K, M}
# NOTE(review): the closure presumably captures no runtime data because `M` is a
# type parameter, which would be what makes it GPU-callable — confirm.
@inline (i, args...) -> kernel_indices(i, HalfSupport(M), args...)
end

# Takes into account periodic wrapping.
# This is equivalent to calling `mod1(j, N)` for each index `j`, but much faster.
# We assume the central index `i` is in 1:N and that M < N / 2.
Expand Down
13 changes: 5 additions & 8 deletions src/interpolation/gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ using StaticArrays: MVector
# Interpolate onto a single point
@kernel function interpolate_to_point_naive_kernel!(
vp::NTuple{C},
@Const(gs::NTuple{D}),
@Const(points::NTuple{D}),
@Const(us::NTuple{C}),
@Const(pointperm),
@Const(Δxs::NTuple{D}), # grid step in each direction (oversampled grid)
evaluate::NTuple{D, <:Function}, # can't be marked Const for some reason
to_indices::NTuple{D, <:Function},
) where {C, D}
i = @index(Global, Linear)

Expand All @@ -26,15 +25,15 @@ using StaticArrays: MVector
Ns = size(first(us)) # grid dimensions

# Evaluate 1D kernels.
gs_eval = map((f, x) -> f(x), evaluate, x⃗)
gs_eval = map(Kernels.evaluate_kernel, gs, x⃗)

# Determine indices to load from `u` arrays.
indvals = ntuple(Val(D)) do n
@inbounds begin
gdata = gs_eval[n]
vals = gdata.values .* Δxs[n]
f = to_indices[n]
f(gdata.i, Ns[n]) => vals
inds = Kernels.kernel_indices(gdata.i, gs[n], Ns[n])
inds => vals
end
end

Expand All @@ -59,8 +58,6 @@ function interpolate!(
Base.require_one_based_indexing(x⃗s) # this is to make sure that all indices match
foreach(Base.require_one_based_indexing, vp_all)

evaluate = map(Kernels.evaluate_kernel_func, gs) # kernel evaluation functions
to_indices = map(Kernels.kernel_indices_func, gs) # functions returning spreading indices
xs_comp = StructArrays.components(x⃗s)
Δxs = map(Kernels.gridstep, gs)

Expand All @@ -84,7 +81,7 @@ function interpolate!(
ndrange = size(x⃗s) # iterate through points
workgroupsize = default_workgroupsize(backend, ndrange)
kernel! = interpolate_to_point_naive_kernel!(backend, workgroupsize)
kernel!(vp_sorted, xs_comp, us, pointperm_, Δxs, evaluate, to_indices; ndrange)
kernel!(vp_sorted, gs, xs_comp, us, pointperm_, Δxs; ndrange)

if sort_points === True()
kernel_perm! = interp_permute_kernel!(backend, workgroupsize)
Expand Down

0 comments on commit 185bce4

Please sign in to comment.