diff --git a/src/abstractNFFTs.jl b/src/abstractNFFTs.jl index 4a07d1a..3c20cdd 100644 --- a/src/abstractNFFTs.jl +++ b/src/abstractNFFTs.jl @@ -91,7 +91,7 @@ Base.@constprop :aggressive function PlanNUFFT( m, σ, reltol = AbstractNFFTs.accuracyParams(; kwargs...) backend = KA.get_backend(xp) # e.g. use GPU backend if xp is a GPU array sort_points = sortNodes ? True() : False() # this is type-unstable (unless constant propagation happens) - block_size = blocking ? default_block_size() : nothing # also type-unstable + block_size = blocking ? default_block_size(backend) : nothing # also type-unstable kernel = window isa AbstractKernel ? window : convert_window_function(window) p = PlanNUFFT(Complex{Tr}, Ns, HalfSupport(m); backend, σ = Tr(σ), sort_points, fftshift, block_size, kernel, fftw_flags = fftflags) AbstractNFFTs.nodes!(p, xp) diff --git a/test/accuracy.jl b/test/accuracy.jl index 2aec8cc..cc83fcb 100644 --- a/test/accuracy.jl +++ b/test/accuracy.jl @@ -95,7 +95,7 @@ function test_nufft_type1_1d( Np = 2 * N, m = HalfSupport(8), σ = 1.25, - block_size = NonuniformFFTs.default_block_size(), + block_size = NonuniformFFTs.default_block_size(CPU()), ) where {T <: Number} if T <: Real Tr = T @@ -154,7 +154,7 @@ function test_nufft_type2_1d( Np = 2 * N, m = HalfSupport(8), σ = 1.25, - block_size = NonuniformFFTs.default_block_size(), + block_size = NonuniformFFTs.default_block_size(CPU()), ) where {T <: Number} if T <: Real Tr = T diff --git a/test/multidimensional.jl b/test/multidimensional.jl index e4c1b83..4ead8f1 100644 --- a/test/multidimensional.jl +++ b/test/multidimensional.jl @@ -32,7 +32,7 @@ function test_nufft_type1( Np = 2 * first(Ns), m = HalfSupport(8), σ = 1.25, - block_size = NonuniformFFTs.default_block_size(), + block_size = NonuniformFFTs.default_block_size(CPU()), sort_points = False(), ) where {T <: Number} Tr = real(T) @@ -79,7 +79,7 @@ function test_nufft_type2( Np = 2 * first(Ns), m = HalfSupport(8), σ = 1.25, - block_size = NonuniformFFTs.default_block_size(), + block_size = NonuniformFFTs.default_block_size(CPU()), sort_points = False(), ) where {T <: Number} Tr = real(T)