diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index daac645e..d2677881 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,8 +21,7 @@ jobs: experimental: [false] version: - '1.7' - - '1.8' - - '~1.9.0-0' + - '1.9' os: - ubuntu-latest arch: diff --git a/Project.toml b/Project.toml index 92f7c81f..95eae0c8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PencilFFTs" uuid = "4a48f351-57a6-4416-9ec4-c37015456aae" authors = ["Juan Ignacio Polanco "] -version = "0.14.4" +version = "0.15.0" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" @@ -16,7 +16,7 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" AbstractFFTs = "1" FFTW = "1.6" MPI = "0.19, 0.20" -PencilArrays = "0.17" +PencilArrays = "0.18" Reexport = "1" TimerOutputs = "0.5" julia = "1.7" diff --git a/docs/src/PencilFFTs.md b/docs/src/PencilFFTs.md index cb03891b..803824ea 100644 --- a/docs/src/PencilFFTs.md +++ b/docs/src/PencilFFTs.md @@ -28,3 +28,9 @@ scale_factor(::PencilFFTPlan) timer(::PencilFFTPlan) is_inplace(::PencilFFTPlan) ``` + +## Internals + +```@docs +ManyPencilArrayRFFT! +``` diff --git a/docs/src/Transforms.md b/docs/src/Transforms.md index 3c2e2c10..d4fe4a42 100644 --- a/docs/src/Transforms.md +++ b/docs/src/Transforms.md @@ -17,7 +17,9 @@ BFFT BFFT! RFFT +RFFT! BRFFT +BRFFT! R2R R2R! diff --git a/src/PencilFFTs.jl b/src/PencilFFTs.jl index f2784cea..8c70555a 100644 --- a/src/PencilFFTs.jl +++ b/src/PencilFFTs.jl @@ -27,6 +27,7 @@ const AbstractTransformList{N} = NTuple{N, AbstractTransform} where N include("global_params.jl") include("plans.jl") +include("multiarrays_r2c.jl") include("allocate.jl") include("operations.jl") diff --git a/src/Transforms/r2c.jl b/src/Transforms/r2c.jl index 27af8a78..4dd0cb30 100644 --- a/src/Transforms/r2c.jl +++ b/src/Transforms/r2c.jl @@ -1,4 +1,5 @@ ## Real-to-complex and complex-to-real transforms. +using FFTW: FFTW """ RFFT() @@ -10,6 +11,13 @@ See also """ struct RFFT <: AbstractTransform end +""" + RFFT!() + +In-place version of [`RFFT`](@ref). +""" +struct RFFT! <: AbstractTransform end + """ BRFFT(d::Integer) BRFFT((d1, d2, ..., dN)) @@ -40,23 +48,40 @@ struct BRFFT <: AbstractTransform even_output :: Bool end -_show_extra_info(io::IO, tr::BRFFT) = print(io, tr.even_output ? "{even}" : "{odd}") +""" + BRFFT!(d::Integer) + BRFFT!((d1, d2, ..., dN)) + +In-place version of [`BRFFT`](@ref). +""" +struct BRFFT! <: AbstractTransform + even_output :: Bool +end + +const TransformR2C = Union{RFFT, RFFT!} +const TransformC2R = Union{BRFFT, BRFFT!} + +_show_extra_info(io::IO, tr::TransformC2R) = print(io, tr.even_output ? "{even}" : "{odd}") BRFFT(d::Integer) = BRFFT(iseven(d)) BRFFT(ts::Tuple) = BRFFT(last(ts)) # c2r transform is applied along the **last** dimension (opposite of FFTW) +BRFFT!(d::Integer) = BRFFT!(iseven(d)) +BRFFT!(ts::Tuple) = BRFFT!(last(ts)) # c2r transform is applied along the **last** dimension (opposite of FFTW) is_inplace(::Union{RFFT, BRFFT}) = false +is_inplace(::Union{RFFT!, BRFFT!}) = true -length_output(::RFFT, length_in::Integer) = div(length_in, 2) + 1 -length_output(tr::BRFFT, length_in::Integer) = 2 * length_in - 1 - tr.even_output +length_output(::TransformR2C, length_in::Integer) = div(length_in, 2) + 1 +length_output(tr::TransformC2R, length_in::Integer) = 2 * length_in - 1 - tr.even_output -eltype_output(::RFFT, ::Type{T}) where {T <: FFTReal} = Complex{T} -eltype_output(::BRFFT, ::Type{Complex{T}}) where {T <: FFTReal} = T +eltype_output(::TransformR2C, ::Type{T}) where {T <: FFTReal} = Complex{T} +eltype_output(::TransformC2R, ::Type{Complex{T}}) where {T <: FFTReal} = T -eltype_input(::RFFT, ::Type{T}) where {T <: FFTReal} = T -eltype_input(::BRFFT, ::Type{T}) where {T <: FFTReal} = Complex{T} +eltype_input(::TransformR2C, ::Type{T}) where {T <: FFTReal} = T +eltype_input(::TransformC2R, ::Type{T}) where {T <: FFTReal} = Complex{T} plan(::RFFT, A::AbstractArray, args...; kwargs...) = FFTW.plan_rfft(A, args...; kwargs...) +plan(::RFFT!, A::AbstractArray, args...; kwargs...) = plan_rfft!(A, args...; kwargs...) # NOTE: unlike most FFTW plans, this function also requires the length `d` of # the transform output along the first transformed dimension. @@ -65,23 +90,65 @@ function plan(tr::BRFFT, A::AbstractArray, dims; kwargs...) d = length_output(tr, Nin) FFTW.plan_brfft(A, d, dims; kwargs...) end +function plan(tr::BRFFT!, A::AbstractArray, dims; kwargs...) + Nin = size(A, first(dims)) # input length along first dimension + d = length_output(tr, Nin) + plan_brfft!(A, d, dims; kwargs...) +end binv(::RFFT, d) = BRFFT(d) binv(::BRFFT, d) = RFFT() +binv(::RFFT!, d) = BRFFT!(d) +binv(::BRFFT!, d) = RFFT!() -function scale_factor(tr::BRFFT, A::ComplexArray, dims) +function scale_factor(tr::TransformC2R, A::ComplexArray, dims) prod(dims; init = one(Int)) do i n = size(A, i) i == last(dims) ? length_output(tr, n) : n end end -scale_factor(::RFFT, A::RealArray, dims) = _prod_dims(A, dims) +scale_factor(::TransformR2C, A::RealArray, dims) = _prod_dims(A, dims) # r2c along the first dimension, then c2c for the other dimensions. expand_dims(tr::RFFT, ::Val{N}) where {N} = N === 0 ? () : (tr, expand_dims(FFT(), Val(N - 1))...) +expand_dims(tr::RFFT!, ::Val{N}) where {N} = + N === 0 ? () : (tr, expand_dims(FFT!(), Val(N - 1))...) expand_dims(tr::BRFFT, ::Val{N}) where {N} = (BFFT(), expand_dims(tr, Val(N - 1))...) +expand_dims(tr::BRFFT!, ::Val{N}) where {N} = (BFFT!(), expand_dims(tr, Val(N - 1))...) expand_dims(tr::BRFFT, ::Val{1}) = (tr, ) expand_dims(tr::BRFFT, ::Val{0}) = () +expand_dims(tr::BRFFT!, ::Val{1}) = (tr, ) +expand_dims(tr::BRFFT!, ::Val{0}) = () + +## FFTW wrappers for inplace RFFT plans + +function plan_rfft!(X::StridedArray{T,N}, region; + flags::Integer=FFTW.ESTIMATE, + timelimit::Real=FFTW.NO_TIMELIMIT) where {T<:FFTW.fftwReal,N} + sz = size(X) # physical input size (real) + osize = FFTW.rfft_output_size(sz, region) # output size (complex) + isize = ntuple(i -> i == first(region) ? 2osize[i] : osize[i], Val(N)) # padded input size (real) + if flags&FFTW.ESTIMATE != 0 # time measurement not required + X_padded = FFTW.FakeArray{T,N}(sz, FFTW.colmajorstrides(isize)) # fake allocation, only pointer, size and strides matter + Y = FFTW.FakeArray{Complex{T}}(osize) + else # need to allocate new array since size of X is too small... + data = Array{T}(undef, prod(isize)) + X_padded = view(reshape(data, isize), Base.OneTo.(sz)...) # allocation + Y = reshape(reinterpret(Complex{T}, data), osize) + end + return FFTW.rFFTWPlan{T,FFTW.FORWARD,true,N}(X_padded, Y, region, flags, timelimit) +end + +function plan_brfft!(X::StridedArray{Complex{T},N}, d, region; + flags::Integer=FFTW.ESTIMATE, + timelimit::Real=FFTW.NO_TIMELIMIT) where {T<:FFTW.fftwReal,N} + isize = size(X) # input size (complex) + osize = ntuple(i -> i == first(region) ? 2isize[i] : isize[i], Val(N)) # padded output size (real) + sz = FFTW.brfft_output_size(X, d, region) # physical output size (real) + Yflat = reinterpret(T, reshape(X, prod(isize))) + Y = view(reshape(Yflat, osize), Base.OneTo.(sz)...) # Y is padded + return FFTW.rFFTWPlan{Complex{T},FFTW.BACKWARD,true,N}(X, Y, region, flags, timelimit) +end diff --git a/src/allocate.jl b/src/allocate.jl index 5a23ba19..ec0f18d2 100644 --- a/src/allocate.jl +++ b/src/allocate.jl @@ -12,10 +12,12 @@ size `dims`, and a tuple of `N` `PencilArray`s. !!! note "In-place plans" - If `p` is an in-place plan, a + If `p` is an in-place real-to-real or complex-to-complex plan, a [`ManyPencilArray`](https://jipolanco.github.io/PencilArrays.jl/dev/PencilArrays/#PencilArrays.ManyPencilArray) - is allocated. This - type holds `PencilArray` wrappers for the input and output transforms (as + is allocated. If `p` is an in-place real-to-complex plan, a + [`ManyPencilArrayRFFT!`](@ref) is allocated. + + These types hold `PencilArray` wrappers for the input and output transforms (as well as for intermediate transforms) which share the same space in memory. The input and output `PencilArray`s should be respectively accessed by calling [`first(::ManyPencilArray)`](https://jipolanco.github.io/PencilArrays.jl/dev/PencilArrays/#Base.first-Tuple{ManyPencilArray}) and @@ -39,17 +41,26 @@ size `dims`, and a tuple of `N` `PencilArray`s. # p * v_in # not allowed!! ``` """ -function allocate_input end +function allocate_input(p::PencilFFTPlan) + inplace = is_inplace(p) + _allocate_input(Val(inplace), p) +end # Out-of-place version -function allocate_input(p::PencilFFTPlan{T,N,false} where {T,N}) +function _allocate_input(inplace::Val{false}, p::PencilFFTPlan) T = eltype_input(p) pen = pencil_input(p) PencilArray{T}(undef, pen, p.extra_dims...) end # In-place version -function allocate_input(p::PencilFFTPlan{T,N,true} where {T,N}) +function _allocate_input(inplace::Val{true}, p::PencilFFTPlan) + (; transforms,) = p.global_params + _allocate_input(inplace, p, transforms...) +end + +# In-place: generic case +function _allocate_input(inplace::Val{true}, p::PencilFFTPlan, transforms...) pencils = map(pp -> pp.pencil_in, p.plans) # Note that for each 1D plan, the input and output pencils are the same. @@ -61,6 +72,16 @@ function allocate_input(p::PencilFFTPlan{T,N,true} where {T,N}) ManyPencilArray{T}(undef, pencils...; extra_dims=p.extra_dims) end +# In-place: specific case of RFFT! +function _allocate_input( + inplace::Val{true}, p::PencilFFTPlan{T}, + ::Transforms.RFFT!, ::Vararg{Transforms.FFT!}, + ) where {T} + plans = p.plans + pencils = (first(plans).pencil_in, first(plans).pencil_out, map(pp -> pp.pencil_in, plans[2:end])...) + ManyPencilArrayRFFT!{T}(undef, pencils...; extra_dims=p.extra_dims) +end + allocate_input(p::PencilFFTPlan, dims...) = _allocate_many(allocate_input, p, dims...) @@ -76,17 +97,20 @@ If `p` is an in-place plan, a [`ManyPencilArray`](https://jipolanco.github.io/Pe See [`allocate_input`](@ref) for details. """ -function allocate_output end +function allocate_output(p::PencilFFTPlan) + inplace = is_inplace(p) + _allocate_output(Val(inplace), p) +end # Out-of-place version. -function allocate_output(p::PencilFFTPlan{T,N,false} where {T,N}) +function _allocate_output(inplace::Val{false}, p::PencilFFTPlan) T = eltype_output(p) pen = pencil_output(p) PencilArray{T}(undef, pen, p.extra_dims...) end # For in-place plans, the output and input are the same ManyPencilArray. -allocate_output(p::PencilFFTPlan{T,N,true} where {T,N}) = allocate_input(p) +_allocate_output(inplace::Val{true}, p::PencilFFTPlan) = _allocate_input(inplace, p) allocate_output(p::PencilFFTPlan, dims...) = _allocate_many(allocate_output, p, dims...) diff --git a/src/multiarrays_r2c.jl b/src/multiarrays_r2c.jl new file mode 100644 index 00000000..0faf8c5b --- /dev/null +++ b/src/multiarrays_r2c.jl @@ -0,0 +1,79 @@ +# copied and modified from https://github.com/jipolanco/PencilArrays.jl/blob/master/src/multiarrays.jl +import PencilArrays: AbstractManyPencilArray, _make_arrays + +""" + ManyPencilArrayRFFT!{T,N,M} <: AbstractManyPencilArray{N,M} + +Container holding `M` different [`PencilArray`](https://jipolanco.github.io/PencilArrays.jl/dev/PencilArrays/#PencilArrays.PencilArray) views to the same +underlying data buffer. All views share the same and dimensionality `N`. +The element type `T` of the first view is real, that of subsequent views is +`Complex{T}`. + +This can be used to perform in-place real-to-complex plan, see also[`Transforms.RFFT!`](@ref). +It is used internally for such transforms by [`allocate_input`](@ref) and should not be constructed directly. + +--- + + ManyPencilArrayRFFT!{T}(undef, pencils...; extra_dims=()) + +Create a `ManyPencilArrayRFFT!` container that can hold data of type `T` and `Complex{T}` associated +to all the given [`Pencil`](https://jipolanco.github.io/PencilArrays.jl/dev/PencilArrays/#PencilArrays.Pencil)s. + +The optional `extra_dims` argument is the same as for [`PencilArray`](https://jipolanco.github.io/PencilArrays.jl/dev/PencilArrays/#PencilArrays.PencilArray). + +See also [`ManyPencilArray`](https://jipolanco.github.io/PencilArrays.jl/dev/PencilArrays/#PencilArrays.ManyPencilArray) +""" +struct ManyPencilArrayRFFT!{ + T, # element type of real array + N, # number of dimensions of each array (including extra_dims) + M, # number of arrays + Arrays <: Tuple{Vararg{PencilArray,M}}, + DataVector <: AbstractVector{T}, + DataVectorComplex <: AbstractVector{Complex{T}}, + } <: AbstractManyPencilArray{N, M} + data :: DataVector + data_complex :: DataVectorComplex + arrays :: Arrays + + function ManyPencilArrayRFFT!{T}( + init, real_pencil::Pencil{Np}, complex_pencils::Vararg{Pencil{Np}}; + extra_dims::Dims=() + ) where {Np,T<:FFTReal} + # real_pencil is a Pencil with dimensions `dims` of a real array with no padding and no permutation + # the padded dimensions are (2*(dims[1] ÷ 2 + 1), dims[2:end]...) + # first(complex_pencils) is a Pencil with dimensions of a complex array (dims[1] ÷ 2 + 1, dims[2:end]...) and no permutation + pencils = (real_pencil, complex_pencils...) + BufType = PencilArrays.typeof_array(real_pencil) + @assert all(p -> PencilArrays.typeof_array(p) === BufType, complex_pencils) + @assert size_global(real_pencil)[2:end] == size_global(first(complex_pencils))[2:end] + @assert first(size_global(real_pencil)) ÷ 2 + 1 == first(size_global(first(complex_pencils))) + + data_length = max(2 .* length.(complex_pencils)...) * prod(extra_dims) + data_real = BufType{T}(init, data_length) + + # we don't use data_complex = reinterpret(Complex{T}, data_real) + # since there is an issue with StridedView of ReinterpretArray, called by _permutedims in PencilArrays.Transpositions + ptr_complex = convert(Ptr{Complex{T}}, pointer(data_real)) + data_complex = unsafe_wrap(BufType, ptr_complex, data_length ÷ 2) + + array_real = _make_real_array(data_real, extra_dims, real_pencil) + arrays_complex = PencilArrays._make_arrays(data_complex, extra_dims, complex_pencils...) + arrays = (array_real, arrays_complex...) + + N = Np + length(extra_dims) + M = length(pencils) + new{T, N, M, typeof(arrays), typeof(data_real), typeof(data_complex)}(data_real, data_complex, arrays) + end +end + +function _make_real_array(data, extra_dims, p) + dims_space_local = size_local(p, MemoryOrder()) + dims_padded_local = (2*(dims_space_local[1] ÷ 2 + 1), dims_space_local[2:end]...) + dims = (dims_padded_local..., extra_dims...) + axes_local = (Base.OneTo.(dims_space_local)..., Base.OneTo.(extra_dims)...) + n = prod(dims) + vec = view(data, Base.OneTo(n)) + parent_arr = reshape(vec, dims) + arr = view(parent_arr, axes_local...) + PencilArray(p, arr) +end diff --git a/src/operations.jl b/src/operations.jl index 4bd4e485..8caab361 100644 --- a/src/operations.jl +++ b/src/operations.jl @@ -2,9 +2,9 @@ const RealOrComplex{T} = Union{T, Complex{T}} where T <: FFTReal const PlanArrayPair{P,A} = Pair{P,A} where {P <: PencilPlan1D, A <: PencilArray} # Types of array over which a PencilFFTPlan can operate. -# PencilArray and ManyPencilArray are respectively for out-of-place and in-place +# PencilArray, ManyPencilArray and ManyPencilArrayRFFT! are respectively for out-of-place, in-place and in-place rfft # transforms. -const FFTArray{T,N} = Union{PencilArray{T,N}, ManyPencilArray{T,N}} where {T,N} +const FFTArray{T,N} = Union{PencilArray{T,N}, ManyPencilArray{T,N}, ManyPencilArrayRFFT!{T,N}} where {T,N} # Collections of FFTArray (e.g. for vector components), for broadcasting plans # to each array. These types are basically those returned by `allocate_input` @@ -12,6 +12,8 @@ const FFTArray{T,N} = Union{PencilArray{T,N}, ManyPencilArray{T,N}} where {T,N} const FFTArrayCollection = Union{Tuple{Vararg{A}}, AbstractArray{A}} where {A <: FFTArray} +const PencilMultiarray{T,N} = Union{ManyPencilArray{T,N}, ManyPencilArrayRFFT!{T,N}} where {T,N} + # This allows to treat plans as scalars when broadcasting. # This means that, if u = (u1, u2, u3) is a tuple of PencilArrays # compatible with p, then p .* u does what one would expect, that is, it @@ -28,13 +30,24 @@ function LinearAlgebra.mul!( end end -# Backward transforms +# Backward transforms (unscaled) +function bmul!( + dst::FFTArray{T,N}, p::PencilFFTPlan{T,N}, src::FFTArray{Ti,N}, + ) where {T, N, Ti <: RealOrComplex} + @timeit_debug p.timer "PencilFFTs bmul!" begin + _check_arrays(p, dst, src) + _apply_plans!(Val(FFTW.BACKWARD), p, dst, src) + end +end + +# Inverse transforms (scaled) function LinearAlgebra.ldiv!( dst::FFTArray{T,N}, p::PencilFFTPlan{T,N}, src::FFTArray{Ti,N}, ) where {T, N, Ti <: RealOrComplex} @timeit_debug p.timer "PencilFFTs ldiv!" begin _check_arrays(p, dst, src) _apply_plans!(Val(FFTW.BACKWARD), p, dst, src) + _scale!(dst, inv(scale_factor(p))) end end @@ -48,13 +61,21 @@ function Base.:\(p::PencilFFTPlan, src::FFTArray) ldiv!(dst, p, src) end +function _scale!(dst::PencilArray{<:RealOrComplex{T},N}, inv_scale::Number) where {T,N} + dst .*= inv_scale +end + +function _scale!(dst::PencilMultiarray{<:RealOrComplex{T},N}, inv_scale::Number) where {T,N} + first(dst) .*= inv_scale +end + # Out-of-place version _maybe_allocate(allocator::Function, p::PencilFFTPlan{T,N,false} where {T,N}, ::PencilArray) = allocator(p) # In-place version _maybe_allocate(::Function, ::PencilFFTPlan{T,N,true} where {T,N}, - src::ManyPencilArray) = src + src::PencilMultiarray) = src # Fallback case. function _maybe_allocate(::Function, p::PencilFFTPlan, src::A) where {A} @@ -76,7 +97,7 @@ end function _check_arrays( p::PencilFFTPlan{T,N,true} where {T,N}, - Ain::ManyPencilArray, Aout::ManyPencilArray, + Ain::PencilMultiarray, Aout::PencilMultiarray, ) if Ain !== Aout throw(ArgumentError( @@ -121,6 +142,10 @@ for f in (:mul!, :ldiv!) (check_compatible(dst, src); $f.(dst, p, src)) end +bmul!(dst::FFTArrayCollection, p::PencilFFTPlan, + src::FFTArrayCollection) = + (check_compatible(dst, src); bmul!.(dst, p, src)) + for f in (:*, :\) @eval Base.$f(p::PencilFFTPlan, src::FFTArrayCollection) = $f.(p, src) @@ -143,11 +168,6 @@ function _apply_plans!( _apply_plans_out_of_place!(dir, full_plan, y, x, plans...) - if dir === Val(FFTW.BACKWARD) - # Scale transform. - y ./= scale_factor(full_plan) - end - y end @@ -163,9 +183,42 @@ function _apply_plans!( _apply_plans_in_place!(dir, full_plan, nothing, pp...) - if dir === Val(FFTW.BACKWARD) - # Scale transform. - first(A) ./= scale_factor(full_plan) + A +end + +# In-place RFFT version +function _apply_plans!( + dir::Val, full_plan::PencilFFTPlan{T,N,true}, + A::ManyPencilArrayRFFT!{T,N}, A_again::ManyPencilArrayRFFT!{T,N}) where {T<:FFTReal,N} + @assert A === A_again + + # pairs for 1D FFT! plans, RFFT! plan is treated separately + pairs = _make_pairs(full_plan.plans[2:end], A.arrays[3:end]) + + # Backward transforms are applied in reverse order. + pp = dir === Val(FFTW.BACKWARD) ? reverse(pairs) : pairs + + if dir === Val(FFTW.FORWARD) + # apply separately first transform (RFFT!) + _apply_rfft_plan_in_place!(dir, full_plan, A.arrays[2], first(full_plan.plans), A.arrays[1]) + # apply recursively all successive transforms (FFT!) + _apply_plans_in_place!(dir, full_plan, A.arrays[2], pp...) + elseif dir === Val(FFTW.BACKWARD) + # apply recursively all transforms but last (BFFT!) + _apply_plans_in_place!(dir, full_plan, nothing, pp...) + # transpose before last transform + t = if pp == () + nothing + else + @assert Base.mightalias(A.arrays[3], A.arrays[2]) # they're aliased! + t = Transpositions.Transposition(A.arrays[2], A.arrays[3], + method=full_plan.transpose_method) + transpose!(t, waitall=false) + end + # apply separately last transform (BRFFT!) + _apply_rfft_plan_in_place!(dir, full_plan, A.arrays[1], first(full_plan.plans), A.arrays[2]) + + _wait_mpi_operations!(t, full_plan.timer) end A @@ -240,6 +293,12 @@ end _apply_plans_in_place!(::Val, ::PencilFFTPlan, u_prev::PencilArray) = u_prev +function _apply_rfft_plan_in_place!(dir::Val, full_plan::PencilFFTPlan, A_out ::PencilArray{To,N}, p::PencilPlan1D{ti,to,Pi,Po,Tr}, A_in ::PencilArray{Ti,N}) where + {Ti<:RealOrComplex{T},To<:RealOrComplex{T},ti<:RealOrComplex{T},to<:RealOrComplex{T},Pi,Po,N,Tr<:Union{Transforms.RFFT!,Transforms.BRFFT!}} where T<:FFTReal + fft_plan = dir === Val(FFTW.FORWARD) ? p.fft_plan : p.bfft_plan + @timeit_debug full_plan.timer "FFT!" mul!(parent(A_out), fft_plan, parent(A_in)) +end + _split_first(a, b...) = (a, b) # (x, y, z, w) -> (x, (y, z, w)) function _make_pairs(plans::Tuple{Vararg{PencilPlan1D,N}}, diff --git a/test/rfft.jl b/test/rfft.jl index d955ca59..e1212c14 100644 --- a/test/rfft.jl +++ b/test/rfft.jl @@ -5,6 +5,7 @@ import MPI using BenchmarkTools using LinearAlgebra +using FFTW using Printf using Random using Test @@ -190,6 +191,104 @@ function test_rfft(size_in; benchmark=true) MPI.Barrier(comm) end +function test_rfft!(size_in; flags = FFTW.ESTIMATE, benchmark=true) + comm = MPI.COMM_WORLD + rank = MPI.Comm_rank(comm) + + rank == 0 && @info "Input data size: $size_in" + + # Test creating Pencil and creating plan. + pen = Pencil(size_in, comm) + + inplace_plan = @inferred PencilFFTPlan(pen, Transforms.RFFT!(), fftw_flags=flags) + outofplace_place = @inferred PencilFFTPlan(pen, Transforms.RFFT(), fftw_flags=flags) + + # Allocate and initialise scalar fields + u = @inferred allocate_input(inplace_plan) + x = first(u); x̂ = last(u) # Real and Complex views + + v = @inferred allocate_input(outofplace_place) + v̂ = @inferred allocate_output(outofplace_place) + + fill!(x, 0.0) + fill!(v, 0.0) + if rank == 0 + x[1] = 1.0; x[2] = 2.0 + v[1] = 1.0; v[2] = 2.0 + end + + @testset "RFFT! vs RFFT" begin + mul!(u, inplace_plan, u) + mul!(v̂, outofplace_place, v) + @test all(isapprox.(x̂, v̂, atol=1e-8)) + + ldiv!(u, inplace_plan, u) + ldiv!(v, outofplace_place, v̂) + @test all(isapprox.(x, v, atol=1e-8)) + rank == 0 && @test all(isapprox.(x[1:3], [1.0, 2.0, 0.0], atol = 1e-8)) + + rng = MersenneTwister(42) + init_random_field!(x̂, rng) + copyto!(parent(v̂), parent(x̂)) + + PencilFFTs.bmul!(u, inplace_plan, u) + PencilFFTs.bmul!(v, outofplace_place, v̂) + @test all(isapprox.(x, v, atol=1e-8)) + + mul!(u, inplace_plan, u) + mul!(v̂, outofplace_place, v) + @test all(isapprox.(x̂, v̂, atol=1e-8)) + end + if benchmark + println("micro-benchmarks: ") + println("- rfft!...\t") + @time mul!(u, inplace_plan, u) + @time mul!(u, inplace_plan, u) + println("- rfft...\t") + @time mul!(v̂, outofplace_place, v) + @time mul!(v̂, outofplace_place, v) + println("done ") + end + MPI.Barrier(comm) +end + +function test_1D_rfft!(size_in; flags=FFTW.ESTIMATE) + dims = (size_in,) + dims_padded = (2(dims[1] ÷ 2 + 1), dims[2:end]...) + dims_fourier = ((dims[1] ÷ 2 + 1), dims[2:end]...) + + A = zeros(Float64, dims_padded) + a = view(A, Base.OneTo.(dims)...) + â = reinterpret(Complex{Float64}, A) + + â2 = zeros(Complex{Float64}, dims_fourier) + a2 = zeros(Float64, dims) + + p = Transforms.plan_rfft!(a, 1, flags=flags) + p2 = FFTW.plan_rfft(a2, 1, flags=flags) + bp = Transforms.plan_brfft!(â, dims[1], 1, flags=flags) + bp2 = FFTW.plan_brfft(â, dims[1], 1, flags=flags) + + fill!(a2, 0.0); a2[1] = 1; a2[2] = 2; + fill!(a, 0.0); a[1] = 1; a[2] = 2; + + @testset "1D RFFT! vs RFFT" begin + mul!(â, p, a) + mul!(â2, p2, a2) + @test all(isapprox.(â2, â, atol = 1e-8)) + + mul!(a, bp, â) + mul!(a2, bp2, â2) + @test all(isapprox.(a2, a, atol = 1e-8)) + + a /= size_in + a2 /= size_in + @test all(isapprox.(a[1:3], [1.0, 2.0, 0.0], atol = 1e-8)) + end + + MPI.Barrier(comm) +end + MPI.Init() comm = MPI.COMM_WORLD rank = MPI.Comm_rank(comm) @@ -198,3 +297,11 @@ rank == 0 || redirect_stdout(devnull) test_rfft(DATA_DIMS_EVEN) println() test_rfft(DATA_DIMS_ODD, benchmark=false) + +test_1D_rfft!(first(DATA_DIMS_ODD)) +test_1D_rfft!(first(DATA_DIMS_EVEN), flags = FFTW.MEASURE) +test_1D_rfft!(first(DATA_DIMS_EVEN)) + +test_rfft!(DATA_DIMS_ODD, benchmark=false) +test_rfft!(DATA_DIMS_EVEN, benchmark=false) +# test_rfft!((256,256,256)) # similar execution times for large rfft and rfft!