Skip to content

Commit

Permalink
Fix rfft! test using latest PencilArrays.jl
Browse files Browse the repository at this point in the history
Also made some formatting changes.
  • Loading branch information
jipolanco committed Jul 15, 2023
1 parent dcb0bda commit ab1a2ce
Showing 1 changed file with 26 additions and 18 deletions.
44 changes: 26 additions & 18 deletions test/rfft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ function test_rfft(size_in; benchmark=true)
MPI.Barrier(comm)
end

function test_rfft!(size_in; flags = FFTW.ESTIMATE, benchmark=true)
function test_rfft!(size_in; flags = FFTW.ESTIMATE, benchmark = true)
comm = MPI.COMM_WORLD
rank = MPI.Comm_rank(comm)

Expand All @@ -200,8 +200,8 @@ function test_rfft!(size_in; flags = FFTW.ESTIMATE, benchmark=true)
# Test creating Pencil and creating plan.
pen = Pencil(size_in, comm)

inplace_plan = @inferred PencilFFTPlan(pen, Transforms.RFFT!(), fftw_flags=flags)
outofplace_place = @inferred PencilFFTPlan(pen, Transforms.RFFT(), fftw_flags=flags)
inplace_plan = @inferred PencilFFTPlan(pen, Transforms.RFFT!(); fftw_flags = flags)
outofplace_place = @inferred PencilFFTPlan(pen, Transforms.RFFT(); fftw_flags = flags)

# Allocate and initialise scalar fields
u = @inferred allocate_input(inplace_plan)
Expand All @@ -220,24 +220,27 @@ function test_rfft!(size_in; flags = FFTW.ESTIMATE, benchmark=true)
@testset "RFFT! vs RFFT" begin
mul!(u, inplace_plan, u)
mul!(v̂, outofplace_place, v)
@test all(isapprox.(x̂, v̂, atol=1e-8))
@test all(isapprox.(x̂, v̂; atol = 1e-8))

ldiv!(u, inplace_plan, u)
ldiv!(v, outofplace_place, v̂)
@test all(isapprox.(x, v, atol=1e-8))
rank == 0 && @test all(isapprox.(x[1:3], [1.0, 2.0, 0.0], atol = 1e-8))
@test all(isapprox.(x, v; atol = 1e-8))

if rank == 0
@test all(isapprox.(@view(x[1:3]), [1.0, 2.0, 0.0]; atol = 1e-8))
end

rng = MersenneTwister(42)
init_random_field!(x̂, rng)
copyto!(parent(v̂), parent(x̂))

PencilFFTs.bmul!(u, inplace_plan, u)
PencilFFTs.bmul!(v, outofplace_place, v̂)
@test all(isapprox.(x, v, atol=1e-8))
@test all(isapprox.(x, v; atol = 1e-8))

mul!(u, inplace_plan, u)
mul!(v̂, outofplace_place, v)
@test all(isapprox.(x̂, v̂, atol=1e-8))
@test all(isapprox.(x̂, v̂; atol = 1e-8))
end
if benchmark
println("micro-benchmarks: ")
Expand All @@ -252,7 +255,7 @@ function test_rfft!(size_in; flags = FFTW.ESTIMATE, benchmark=true)
MPI.Barrier(comm)
end

function test_1D_rfft!(size_in; flags=FFTW.ESTIMATE)
function test_1D_rfft!(size_in; flags = FFTW.ESTIMATE)
dims = (size_in,)
dims_padded = (2(dims[1] ÷ 2 + 1), dims[2:end]...)
dims_fourier = ((dims[1] ÷ 2 + 1), dims[2:end]...)
Expand All @@ -264,26 +267,31 @@ function test_1D_rfft!(size_in; flags=FFTW.ESTIMATE)
â2 = zeros(Complex{Float64}, dims_fourier)
a2 = zeros(Float64, dims)

p = Transforms.plan_rfft!(a, 1, flags=flags)
p2 = FFTW.plan_rfft(a2, 1, flags=flags)
bp = Transforms.plan_brfft!(â, dims[1], 1, flags=flags)
bp2 = FFTW.plan_brfft(â, dims[1], 1, flags=flags)
p = Transforms.plan_rfft!(a, 1; flags)
p2 = FFTW.plan_rfft(a2, 1; flags)
bp = Transforms.plan_brfft!(â, dims[1], 1; flags)
bp2 = FFTW.plan_brfft(â, dims[1], 1; flags)

fill!(a2, 0.0)
a2[1] = 1
a2[2] = 2

fill!(a2, 0.0); a2[1] = 1; a2[2] = 2;
fill!(a, 0.0); a[1] = 1; a[2] = 2;
fill!(a, 0.0)
a[1] = 1
a[2] = 2

@testset "1D RFFT! vs RFFT" begin
mul!(â, p, a)
mul!(â2, p2, a2)
@test all(isapprox.(â2, â, atol = 1e-8))
@test all(isapprox.(â2, â; atol = 1e-8))

mul!(a, bp, â)
mul!(a2, bp2, â2)
@test all(isapprox.(a2, a, atol = 1e-8))
@test all(isapprox.(a2, a; atol = 1e-8))

a /= size_in
a2 /= size_in
@test all(isapprox.(a[1:3], [1.0, 2.0, 0.0], atol = 1e-8))
@test all(isapprox.(@view(a[1:3]), [1.0, 2.0, 0.0]; atol = 1e-8))
end

MPI.Barrier(comm)
Expand Down

0 comments on commit ab1a2ce

Please sign in to comment.