Skip to content

Commit

Permalink
Add tests in 2D
Browse files Browse the repository at this point in the history
Also, support passing points as SVector.
  • Loading branch information
jipolanco committed Dec 12, 2023
1 parent 48e0410 commit 68aed59
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 3 deletions.
10 changes: 8 additions & 2 deletions src/NonuniformFFTs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,16 +162,22 @@ function PlanNUFFT(N::Union{Integer, Dims}, args...; kws...)
PlanNUFFT(ComplexF64, N, args...; kws...)
end

function set_points!(p::PlanNUFFT{T, N}, xp::AbstractVector{<:NTuple{N}}) where {T, N}
# Here the element type of `xp` can either be an NTuple{N, <:Real}, an SVector{N, <:Real},
# or anything else which has length `N`.
function set_points!(p::PlanNUFFT{T, N}, xp::AbstractVector) where {T, N}
(; points,) = p
type_length(eltype(xp)) == N || throw(DimensionMismatch(lazy"expected $N-dimensional points"))
resize!(points, length(xp))
Base.require_one_based_indexing(points)
@inbounds for (i, x) enumerate(xp)
points[i] = x
points[i] = NTuple{N}(x) # converts `x` to Tuple if it's an SVector
end
p
end

type_length(::Type{T}) where {T} = length(T) # usually for SVector
type_length(::Type{<:NTuple{N}}) where {N} = N

function set_points!(p::PlanNUFFT{T, N}, xp::NTuple{N, AbstractVector}) where {T, N}
set_points!(p, StructVector(xp))
end
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,7 @@
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
13 changes: 12 additions & 1 deletion test/accuracy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,23 @@ function test_nufft_type2_1d(

err = l2_error(vp, vp_exact)

# Inference tests
if VERSION < v"1.10-"
# On Julia 1.9, there seems to be a runtime dispatch related to throwing a
# DimensionMismatch error using LazyStrings (in NonuniformFFTs.check_nufft_uniform_data).
JET.@test_opt ignored_modules=(Base,) NonuniformFFTs.set_points!(plan_nufft, xp)
JET.@test_opt ignored_modules=(Base,) NonuniformFFTs.exec_type2!(vp, plan_nufft, ûs)
else
JET.@test_opt NonuniformFFTs.set_points!(plan_nufft, xp)
JET.@test_opt NonuniformFFTs.exec_type2!(vp, plan_nufft, ûs)
end

check_nufft_error(T, kernel, m, σ, err)

err
end

@testset "NUFFTs: $T" for T (Float64, ComplexF64)
@testset "1D NUFFTs: $T" for T (Float64, ComplexF64)
@testset "Type 1 NUFFTs" begin
for M 4:10
m = HalfSupport(M)
Expand Down
172 changes: 172 additions & 0 deletions test/multidimensional.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
using Test
using Random: Random
using AbstractFFTs: fftfreq, rfftfreq
using JET: JET
using StaticArrays: SVector
using LinearAlgebra:
using NonuniformFFTs

function check_nufft_error(::Type{Float64}, ::BackwardsKaiserBesselKernel, ::HalfSupport{M}, σ, err) where {M}
if σ 1.25
err_min_kb = 4e-12 # error reaches a minimum at ~2e-12 for M = 10
@test err < max(10.0^(-1.20 * M) * 2, err_min_kb)
elseif σ 2.0
err_max_kb = max(6 * 10.0^(-1.9 * M), 4e-14) # error "plateaus" at ~2e-14 for M ≥ 8
@test err < err_max_kb
end
nothing
end

check_nufft_error(::Type{ComplexF64}, args...) = check_nufft_error(Float64, args...)

function l2_error(us, vs)
err = sum(zip(us, vs)) do (u, v)
abs2(u - v)
end
norm = sum(abs2, vs)
sqrt(err / norm)
end

function test_nufft_type1(
::Type{T}, Ns::Dims;
kernel = BackwardsKaiserBesselKernel(),
Np = 2 * first(Ns),
m = HalfSupport(8),
σ = 1.25,
) where {T <: Number}
Tr = real(T)
ks = map(N -> fftfreq(N, Tr(N)), Ns)
if T <: Real
ks = Base.setindex(ks, rfftfreq(Ns[1], Tr(Ns[1])), 1) # we perform r2c transform along first dimension
end

# Generate some non-uniform random data
rng = Random.Xoshiro(42)
d = length(Ns)
xp = rand(rng, SVector{d, Tr}, Np) # non-uniform points in [0, 1]ᵈ
for i eachindex(xp)
xp[i] = xp[i] .* 2π # rescale points to [0, 2π]ᵈ
end
vp = randn(rng, T, Np) # random values at points

# Compute "exact" non-uniform transform
ûs_exact = zeros(Complex{Tr}, map(length, ks))
for I CartesianIndices(ûs_exact)
k⃗ = SVector(map(getindex, ks, Tuple(I)))
for (x⃗, v) zip(xp, vp)
ûs_exact[I] += v * cis(-k⃗ x⃗)
end
end

# Compute NUFFT
ûs = Array{Complex{Tr}}(undef, map(length, ks))
plan_nufft = @inferred PlanNUFFT(T, Ns, m; σ, kernel)
NonuniformFFTs.set_points!(plan_nufft, xp)
NonuniformFFTs.exec_type1!(ûs, plan_nufft, vp)

# Check results
err = l2_error(ûs, ûs_exact)

# Inference tests
if VERSION < v"1.10-"
# On Julia 1.9, there seems to be a runtime dispatch related to throwing a
# DimensionMismatch error using LazyStrings (in NonuniformFFTs.check_nufft_uniform_data).
JET.@test_opt ignored_modules=(Base,) NonuniformFFTs.set_points!(plan_nufft, xp)
JET.@test_opt ignored_modules=(Base,) NonuniformFFTs.exec_type1!(ûs, plan_nufft, vp)
else
JET.@test_opt NonuniformFFTs.set_points!(plan_nufft, xp)
JET.@test_opt NonuniformFFTs.exec_type1!(ûs, plan_nufft, vp)
end

check_nufft_error(T, kernel, m, σ, err)

err
end

function test_nufft_type2(
::Type{T}, Ns::Dims;
kernel = BackwardsKaiserBesselKernel(),
Np = 2 * first(Ns),
m = HalfSupport(8),
σ = 1.25,
) where {T <: Number}
Tr = real(T)
ks = map(N -> fftfreq(N, Tr(N)), Ns)
if T <: Real
ks = Base.setindex(ks, rfftfreq(Ns[1], Tr(Ns[1])), 1) # we perform r2c transform along first dimension
end

# Generate some uniform random data + non-uniform points
rng = Random.Xoshiro(42)
ûs = randn(rng, Complex{Tr}, map(length, ks))
d = length(Ns)
xp = rand(rng, SVector{d, Tr}, Np) # non-uniform points in [0, 1]ᵈ
for i eachindex(xp)
xp[i] = xp[i] .* 2π # rescale points to [0, 2π]ᵈ
end

# Compute "exact" type-2 transform (interpolation)
vp_exact = zeros(T, Np)
for I CartesianIndices(ûs)
k⃗ = SVector(map(getindex, ks, Tuple(I)))
= ûs[I]
for i eachindex(xp, vp_exact)
x⃗ = xp[i]
if T <: Real
# Complex-to-real transform with Hermitian symmetry.
factor = ifelse(iszero(k⃗[1]), 1, 2)
s, c = sincos(k⃗ x⃗)
ur, ui = real(û), imag(û)
vp_exact[i] += factor * (c * ur - s * ui)
else
# Usual complex-to-complex transform.
vp_exact[i] +=* cis(k⃗ x⃗)
end
end
end

# Compute NUFFT
vp = Array{T}(undef, Np)
plan_nufft = @inferred PlanNUFFT(T, Ns, m; σ, kernel)
NonuniformFFTs.set_points!(plan_nufft, xp)
NonuniformFFTs.exec_type2!(vp, plan_nufft, ûs)

err = l2_error(vp, vp_exact)

# Inference tests
if VERSION < v"1.10-"
# On Julia 1.9, there seems to be a runtime dispatch related to throwing a
# DimensionMismatch error using LazyStrings (in NonuniformFFTs.check_nufft_uniform_data).
JET.@test_opt ignored_modules=(Base,) NonuniformFFTs.set_points!(plan_nufft, xp)
JET.@test_opt ignored_modules=(Base,) NonuniformFFTs.exec_type2!(vp, plan_nufft, ûs)
else
JET.@test_opt NonuniformFFTs.set_points!(plan_nufft, xp)
JET.@test_opt NonuniformFFTs.exec_type2!(vp, plan_nufft, ûs)
end

check_nufft_error(T, kernel, m, σ, err)

err
end

@testset "2D NUFFTs: $T" for T (Float64, ComplexF64)
Ns = (64, 64)
@testset "Type 1 NUFFTs" begin
for M 4:8 # for σ = 1.25, going beyond M = 8 gives no improvements
m = HalfSupport(M)
σ = 1.25
@testset "$kernel (m = $M, σ = )" for kernel (BackwardsKaiserBesselKernel(),)
test_nufft_type1(T, Ns; m, σ, kernel)
end
end
end
@testset "Type 2 NUFFTs" begin
for M 4:8 # for σ = 1.25, going beyond M = 8 gives no improvements
m = HalfSupport(M)
σ = 1.25
@testset "$kernel (m = $M, σ = )" for kernel (BackwardsKaiserBesselKernel(),)
test_nufft_type2(T, Ns; m, σ, kernel)
end
end
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@ end

@testset "NonuniformFFTs.jl" begin
@includetest "accuracy.jl"
@includetest "multidimensional.jl"
@includetest "uniform_points.jl"
end

0 comments on commit 68aed59

Please sign in to comment.