Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

In-place real-to-complex FFTs #65

Merged
merged 22 commits into from
May 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ jobs:
experimental: [false]
version:
- '1.7'
- '1.8'
- '~1.9.0-0'
- '1.9'
os:
- ubuntu-latest
arch:
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PencilFFTs"
uuid = "4a48f351-57a6-4416-9ec4-c37015456aae"
authors = ["Juan Ignacio Polanco <[email protected]>"]
version = "0.14.4"
version = "0.15.0"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand All @@ -16,7 +16,7 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
AbstractFFTs = "1"
FFTW = "1.6"
MPI = "0.19, 0.20"
PencilArrays = "0.17"
PencilArrays = "0.18"
Reexport = "1"
TimerOutputs = "0.5"
julia = "1.7"
6 changes: 6 additions & 0 deletions docs/src/PencilFFTs.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,9 @@ scale_factor(::PencilFFTPlan)
timer(::PencilFFTPlan)
is_inplace(::PencilFFTPlan)
```

## Internals

```@docs
ManyPencilArrayRFFT!
```
2 changes: 2 additions & 0 deletions docs/src/Transforms.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ BFFT
BFFT!

RFFT
RFFT!
BRFFT
BRFFT!

R2R
R2R!
Expand Down
1 change: 1 addition & 0 deletions src/PencilFFTs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ const AbstractTransformList{N} = NTuple{N, AbstractTransform} where N

include("global_params.jl")
include("plans.jl")
include("multiarrays_r2c.jl")
include("allocate.jl")
include("operations.jl")

Expand Down
85 changes: 76 additions & 9 deletions src/Transforms/r2c.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
## Real-to-complex and complex-to-real transforms.
using FFTW: FFTW

"""
RFFT()
Expand All @@ -10,6 +11,13 @@ See also
"""
struct RFFT <: AbstractTransform end

"""
    RFFT!()

In-place real-to-complex transform.

This is the in-place counterpart of [`RFFT`](@ref).
"""
struct RFFT! <: AbstractTransform
end

"""
BRFFT(d::Integer)
BRFFT((d1, d2, ..., dN))
Expand Down Expand Up @@ -40,23 +48,40 @@ struct BRFFT <: AbstractTransform
even_output :: Bool
end

_show_extra_info(io::IO, tr::BRFFT) = print(io, tr.even_output ? "{even}" : "{odd}")
"""
    BRFFT!(d::Integer)
    BRFFT!((d1, d2, ..., dN))

In-place counterpart of [`BRFFT`](@ref).

Only the parity of `d` is kept: it is stored in the `even_output` field.
"""
struct BRFFT! <: AbstractTransform
    # Parity flag, set to `iseven(d)` by the `BRFFT!(d::Integer)` constructor.
    even_output :: Bool
end

# Real-to-complex transforms, either out-of-place (`RFFT`) or in-place (`RFFT!`).
const TransformR2C = Union{RFFT, RFFT!}
# Complex-to-real transforms, either out-of-place (`BRFFT`) or in-place (`BRFFT!`).
const TransformC2R = Union{BRFFT, BRFFT!}

# Print the parity tag of a complex-to-real transform: "{even}" or "{odd}",
# depending on its `even_output` flag.
function _show_extra_info(io::IO, tr::TransformC2R)
    tag = tr.even_output ? "{even}" : "{odd}"
    return print(io, tag)
end

# Build the transform from an integer `d`; only the parity of `d` is kept.
function BRFFT(d::Integer)
    return BRFFT(iseven(d))
end

# The c2r transform is applied along the **last** dimension (opposite of FFTW),
# hence only the last entry of the tuple is used.
function BRFFT(ts::Tuple)
    return BRFFT(last(ts))
end

# In-place variant: build from an integer `d`, keeping only its parity.
function BRFFT!(d::Integer)
    return BRFFT!(iseven(d))
end

# In-place variant: the c2r transform is applied along the **last** dimension
# (opposite of FFTW), hence only the last entry of the tuple is used.
function BRFFT!(ts::Tuple)
    return BRFFT!(last(ts))
end

# `RFFT` and `BRFFT` are the out-of-place transforms.
function is_inplace(::Union{RFFT, BRFFT})
    return false
end

# `RFFT!` and `BRFFT!` are the in-place transforms.
function is_inplace(::Union{RFFT!, BRFFT!})
    return true
end

# Output length of a transform as a function of its input length `length_in`.
# Real-to-complex: `length_in ÷ 2 + 1`.
# Complex-to-real: `2 * length_in - 1`, minus one more when `even_output` is set.
length_output(::RFFT, length_in::Integer) = div(length_in, 2) + 1
length_output(tr::BRFFT, length_in::Integer) = 2 * length_in - 1 - tr.even_output
length_output(::TransformR2C, length_in::Integer) = div(length_in, 2) + 1
length_output(tr::TransformC2R, length_in::Integer) = 2 * length_in - 1 - tr.even_output

# Element type of the transform output given the input element type:
# a real `T` becomes `Complex{T}` for r2c, and `Complex{T}` becomes `T` for c2r.
eltype_output(::RFFT, ::Type{T}) where {T <: FFTReal} = Complex{T}
eltype_output(::BRFFT, ::Type{Complex{T}}) where {T <: FFTReal} = T
eltype_output(::TransformR2C, ::Type{T}) where {T <: FFTReal} = Complex{T}
eltype_output(::TransformC2R, ::Type{Complex{T}}) where {T <: FFTReal} = T

# Element type of the transform input derived from a real type `T`:
# `T` itself for r2c transforms, `Complex{T}` for c2r transforms.
eltype_input(::RFFT, ::Type{T}) where {T <: FFTReal} = T
eltype_input(::BRFFT, ::Type{T}) where {T <: FFTReal} = Complex{T}
eltype_input(::TransformR2C, ::Type{T}) where {T <: FFTReal} = T
eltype_input(::TransformC2R, ::Type{T}) where {T <: FFTReal} = Complex{T}

# Out-of-place r2c plan: all arguments are forwarded to `FFTW.plan_rfft`.
function plan(::RFFT, A::AbstractArray, args...; kwargs...)
    return FFTW.plan_rfft(A, args...; kwargs...)
end

# In-place r2c plan: all arguments are forwarded to the `plan_rfft!` wrapper
# defined in this file.
function plan(::RFFT!, A::AbstractArray, args...; kwargs...)
    return plan_rfft!(A, args...; kwargs...)
end

# NOTE: unlike most FFTW plans, this function also requires the length `d` of
# the transform output along the first transformed dimension.
Expand All @@ -65,23 +90,65 @@ function plan(tr::BRFFT, A::AbstractArray, dims; kwargs...)
d = length_output(tr, Nin)
FFTW.plan_brfft(A, d, dims; kwargs...)
end
# In-place c2r plan. As for the out-of-place version, the plan needs the
# length of the transform output along the first transformed dimension, which
# is derived here from the input length and the parity stored in `tr`.
function plan(tr::BRFFT!, A::AbstractArray, dims; kwargs...)
    len_in = size(A, first(dims))        # input length along first dimension
    len_out = length_output(tr, len_in)  # matching output length
    return plan_brfft!(A, len_out, dims; kwargs...)
end

# `binv` pairs each r2c transform with its c2r counterpart and vice versa.
# `d` is forwarded to the `BRFFT` / `BRFFT!` constructors and is otherwise unused.
function binv(::RFFT, d)
    return BRFFT(d)
end

function binv(::BRFFT, d)
    return RFFT()
end

function binv(::RFFT!, d)
    return BRFFT!(d)
end

function binv(::BRFFT!, d)
    return RFFT!()
end

function scale_factor(tr::BRFFT, A::ComplexArray, dims)
# Scale factor of a complex-to-real transform of `A` along `dims`: the product
# of the sizes of `A` along each transformed dimension, except that along the
# last entry of `dims` the size is replaced by the c2r output length.
function scale_factor(tr::TransformC2R, A::ComplexArray, dims)
    return prod(dims; init = one(Int)) do dim
        len = size(A, dim)
        if dim == last(dims)
            length_output(tr, len)
        else
            len
        end
    end
end

# Scale factor of a real-to-complex transform: delegated to `_prod_dims(A, dims)`.
scale_factor(::RFFT, A::RealArray, dims) = _prod_dims(A, dims)
scale_factor(::TransformR2C, A::RealArray, dims) = _prod_dims(A, dims)

# The r2c transform is applied along the first dimension; the remaining
# dimensions are filled by expanding a c2c transform (`FFT` for the
# out-of-place case, `FFT!` for the in-place case).
function expand_dims(tr::RFFT, ::Val{N}) where {N}
    N === 0 && return ()
    return (tr, expand_dims(FFT(), Val(N - 1))...)
end

function expand_dims(tr::RFFT!, ::Val{N}) where {N}
    N === 0 && return ()
    return (tr, expand_dims(FFT!(), Val(N - 1))...)
end

# Backward case: c2c backward transforms (`BFFT` / `BFFT!`) come first and the
# c2r transform `tr` is placed last. The recursion stops at `Val{1}` (only `tr`
# is left) or `Val{0}` (nothing to transform).
function expand_dims(tr::BRFFT, ::Val{N}) where {N}
    return (BFFT(), expand_dims(tr, Val(N - 1))...)
end

function expand_dims(tr::BRFFT, ::Val{1})
    return (tr, )
end

function expand_dims(tr::BRFFT, ::Val{0})
    return ()
end

function expand_dims(tr::BRFFT!, ::Val{N}) where {N}
    return (BFFT!(), expand_dims(tr, Val(N - 1))...)
end

function expand_dims(tr::BRFFT!, ::Val{1})
    return (tr, )
end

function expand_dims(tr::BRFFT!, ::Val{0})
    return ()
end

## FFTW wrappers for inplace RFFT plans

# Create an in-place real-to-complex FFTW plan.
#
# `X` carries the physical (unpadded) real size. The plan is built for a real
# input whose strides correspond to the padded size, i.e. twice the complex
# output size along the first transformed dimension.
function plan_rfft!(X::StridedArray{T,N}, region;
                    flags::Integer=FFTW.ESTIMATE,
                    timelimit::Real=FFTW.NO_TIMELIMIT) where {T<:FFTW.fftwReal,N}
    dims_real = size(X)                                      # physical input size (real)
    dims_complex = FFTW.rfft_output_size(dims_real, region)  # output size (complex)
    # Padded input size (real): doubled along the first transformed dimension.
    dims_padded = ntuple(Val(N)) do i
        i == first(region) ? 2dims_complex[i] : dims_complex[i]
    end
    if (flags & FFTW.ESTIMATE) != 0
        # No time measurement required: fake arrays are enough, since only
        # pointer, size and strides matter.
        input = FFTW.FakeArray{T,N}(dims_real, FFTW.colmajorstrides(dims_padded))
        output = FFTW.FakeArray{Complex{T}}(dims_complex)
    else
        # A real buffer is needed, and `X` itself is too small to hold the
        # padded data, so a new array is allocated.
        buf = Array{T}(undef, prod(dims_padded))
        input = view(reshape(buf, dims_padded), Base.OneTo.(dims_real)...)
        output = reshape(reinterpret(Complex{T}, buf), dims_complex)
    end
    return FFTW.rFFTWPlan{T,FFTW.FORWARD,true,N}(input, output, region, flags, timelimit)
end

# Create an in-place complex-to-real (backward) FFTW plan.
#
# The real output reuses the memory of the complex input `X`: `X` is
# reinterpreted as a padded real array, and the output is the view of that
# array restricted to the physical real size obtained from `d`.
function plan_brfft!(X::StridedArray{Complex{T},N}, d, region;
                     flags::Integer=FFTW.ESTIMATE,
                     timelimit::Real=FFTW.NO_TIMELIMIT) where {T<:FFTW.fftwReal,N}
    dims_complex = size(X)  # input size (complex)
    # Padded output size (real): doubled along the first transformed dimension.
    dims_padded = ntuple(Val(N)) do i
        i == first(region) ? 2dims_complex[i] : dims_complex[i]
    end
    dims_real = FFTW.brfft_output_size(X, d, region)  # physical output size (real)
    flat_real = reinterpret(T, reshape(X, prod(dims_complex)))
    padded_real = reshape(flat_real, dims_padded)
    output = view(padded_real, Base.OneTo.(dims_real)...)  # unpadded window of the padded array
    return FFTW.rFFTWPlan{Complex{T},FFTW.BACKWARD,true,N}(X, output, region, flags, timelimit)
end
42 changes: 33 additions & 9 deletions src/allocate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ size `dims`, and a tuple of `N` `PencilArray`s.

!!! note "In-place plans"

If `p` is an in-place plan, a
If `p` is an in-place real-to-real or complex-to-complex plan, a
[`ManyPencilArray`](https://jipolanco.github.io/PencilArrays.jl/dev/PencilArrays/#PencilArrays.ManyPencilArray)
is allocated. This
type holds `PencilArray` wrappers for the input and output transforms (as
is allocated. If `p` is an in-place real-to-complex plan, a
[`ManyPencilArrayRFFT!`](@ref) is allocated.

These types hold `PencilArray` wrappers for the input and output transforms (as
well as for intermediate transforms) which share the same space in memory.
The input and output `PencilArray`s should be respectively accessed by
calling [`first(::ManyPencilArray)`](https://jipolanco.github.io/PencilArrays.jl/dev/PencilArrays/#Base.first-Tuple{ManyPencilArray}) and
Expand All @@ -39,17 +41,26 @@ size `dims`, and a tuple of `N` `PencilArray`s.
# p * v_in # not allowed!!
```
"""
function allocate_input end
function allocate_input(p::PencilFFTPlan)
    # Lift the in-place flag to the type domain and dispatch on it.
    return _allocate_input(Val(is_inplace(p)), p)
end

# Out-of-place version
function allocate_input(p::PencilFFTPlan{T,N,false} where {T,N})
function _allocate_input(inplace::Val{false}, p::PencilFFTPlan)
    # A single uninitialised PencilArray with the plan's input element type,
    # on the plan's input pencil.
    return PencilArray{eltype_input(p)}(undef, pencil_input(p), p.extra_dims...)
end

# In-place version
function allocate_input(p::PencilFFTPlan{T,N,true} where {T,N})
function _allocate_input(inplace::Val{true}, p::PencilFFTPlan)
    # Splat the plan's transforms so that the container type can be selected
    # by dispatching on the transform types.
    transform_list = p.global_params.transforms
    return _allocate_input(inplace, p, transform_list...)
end

# In-place: generic case
function _allocate_input(inplace::Val{true}, p::PencilFFTPlan, transforms...)
pencils = map(pp -> pp.pencil_in, p.plans)

# Note that for each 1D plan, the input and output pencils are the same.
Expand All @@ -61,6 +72,16 @@ function allocate_input(p::PencilFFTPlan{T,N,true} where {T,N})
ManyPencilArray{T}(undef, pencils...; extra_dims=p.extra_dims)
end

# In-place: specific case of RFFT!
# Selected when the transforms are one `RFFT!` followed only by `FFT!`s.
# The container receives the input and output pencils of the first 1D plan,
# followed by the input pencils of the remaining plans.
function _allocate_input(
        inplace::Val{true}, p::PencilFFTPlan{T},
        ::Transforms.RFFT!, ::Vararg{Transforms.FFT!},
    ) where {T}
    first_plan = first(p.plans)
    other_pencils = map(pl -> pl.pencil_in, p.plans[2:end])
    return ManyPencilArrayRFFT!{T}(
        undef, first_plan.pencil_in, first_plan.pencil_out, other_pencils...;
        extra_dims=p.extra_dims,
    )
end

# Variant taking extra `dims`: delegated to `_allocate_many`, with
# `allocate_input` as the allocation function.
function allocate_input(p::PencilFFTPlan, dims...)
    return _allocate_many(allocate_input, p, dims...)
end

Expand All @@ -76,17 +97,20 @@ If `p` is an in-place plan, a [`ManyPencilArray`](https://jipolanco.github.io/Pe

See [`allocate_input`](@ref) for details.
"""
function allocate_output end
function allocate_output(p::PencilFFTPlan)
    # Lift the in-place flag to the type domain and dispatch on it.
    return _allocate_output(Val(is_inplace(p)), p)
end

# Out-of-place version.
function allocate_output(p::PencilFFTPlan{T,N,false} where {T,N})
function _allocate_output(inplace::Val{false}, p::PencilFFTPlan)
    # A single uninitialised PencilArray with the plan's output element type,
    # on the plan's output pencil.
    return PencilArray{eltype_output(p)}(undef, pencil_output(p), p.extra_dims...)
end

# For in-place plans, the output and input are the same ManyPencilArray.
allocate_output(p::PencilFFTPlan{T,N,true} where {T,N}) = allocate_input(p)
function _allocate_output(inplace::Val{true}, p::PencilFFTPlan)
    # In-place: input and output live in the same container.
    return _allocate_input(inplace, p)
end

# Variant taking extra `dims`: delegated to `_allocate_many`, with
# `allocate_output` as the allocation function.
function allocate_output(p::PencilFFTPlan, dims...)
    return _allocate_many(allocate_output, p, dims...)
end
Expand Down
79 changes: 79 additions & 0 deletions src/multiarrays_r2c.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# copied and modified from https://github.com/jipolanco/PencilArrays.jl/blob/master/src/multiarrays.jl
import PencilArrays: AbstractManyPencilArray, _make_arrays

"""
ManyPencilArrayRFFT!{T,N,M} <: AbstractManyPencilArray{N,M}

Container holding `M` different [`PencilArray`](https://jipolanco.github.io/PencilArrays.jl/dev/PencilArrays/#PencilArrays.PencilArray) views to the same
underlying data buffer. All views share the same dimensionality `N`.
The element type `T` of the first view is real, that of subsequent views is
`Complex{T}`.

This can be used to perform in-place real-to-complex transforms; see also [`Transforms.RFFT!`](@ref).
It is used internally for such transforms by [`allocate_input`](@ref) and should not be constructed directly.

---

ManyPencilArrayRFFT!{T}(undef, pencils...; extra_dims=())

Create a `ManyPencilArrayRFFT!` container that can hold data of type `T` and `Complex{T}` associated
to all the given [`Pencil`](https://jipolanco.github.io/PencilArrays.jl/dev/PencilArrays/#PencilArrays.Pencil)s.

The optional `extra_dims` argument is the same as for [`PencilArray`](https://jipolanco.github.io/PencilArrays.jl/dev/PencilArrays/#PencilArrays.PencilArray).

See also [`ManyPencilArray`](https://jipolanco.github.io/PencilArrays.jl/dev/PencilArrays/#PencilArrays.ManyPencilArray)
"""
struct ManyPencilArrayRFFT!{
        T,  # element type of real array
        N,  # number of dimensions of each array (including extra_dims)
        M,  # number of arrays
        Arrays <: Tuple{Vararg{PencilArray,M}},
        DataVector <: AbstractVector{T},
        DataVectorComplex <: AbstractVector{Complex{T}},
    } <: AbstractManyPencilArray{N, M}
    data         :: DataVector         # buffer seen as real values
    data_complex :: DataVectorComplex  # same memory seen as complex values
    arrays       :: Arrays             # real view first, then the complex views

    function ManyPencilArrayRFFT!{T}(
            init, real_pencil::Pencil{Np}, complex_pencils::Vararg{Pencil{Np}};
            extra_dims::Dims=()
        ) where {Np,T<:FFTReal}
        # Expected layout:
        #  - `real_pencil` describes a real array of dimensions `dims`, without
        #    padding and without permutation; its padded dimensions are
        #    (2*(dims[1] ÷ 2 + 1), dims[2:end]...);
        #  - `first(complex_pencils)` describes a complex array of dimensions
        #    (dims[1] ÷ 2 + 1, dims[2:end]...), without permutation.
        BufType = PencilArrays.typeof_array(real_pencil)
        @assert all(p -> PencilArrays.typeof_array(p) === BufType, complex_pencils)
        @assert size_global(real_pencil)[2:end] == size_global(first(complex_pencils))[2:end]
        @assert first(size_global(real_pencil)) ÷ 2 + 1 == first(size_global(first(complex_pencils)))

        # Buffer length counted in real elements: twice the largest complex
        # pencil length, times the number of extra-dimension entries.
        len_real = max(2 .* length.(complex_pencils)...) * prod(extra_dims)
        buf_real = BufType{T}(init, len_real)

        # The complex buffer aliases the memory of `buf_real` through a raw
        # pointer. `reinterpret(Complex{T}, buf_real)` is not used because
        # there is an issue with StridedView of ReinterpretArray, called by
        # _permutedims in PencilArrays.Transpositions.
        ptr = convert(Ptr{Complex{T}}, pointer(buf_real))
        buf_complex = unsafe_wrap(BufType, ptr, len_real ÷ 2)

        view_real = _make_real_array(buf_real, extra_dims, real_pencil)
        views_complex = PencilArrays._make_arrays(buf_complex, extra_dims, complex_pencils...)
        views = (view_real, views_complex...)

        return new{
            T, Np + length(extra_dims), 1 + length(complex_pencils),
            typeof(views), typeof(buf_real), typeof(buf_complex),
        }(buf_real, buf_complex, views)
    end
end

# Build the real-valued PencilArray on top of the shared buffer `data`.
#
# The parent array is padded along its first dimension to `2 * (n ÷ 2 + 1)`
# entries (`n` being the local size in memory order); the PencilArray wraps the
# view of that parent restricted to the unpadded local size.
function _make_real_array(data, extra_dims, p)
    local_dims = size_local(p, MemoryOrder())
    padded_first = 2 * (local_dims[1] ÷ 2 + 1)
    parent_dims = (padded_first, local_dims[2:end]..., extra_dims...)
    # Only the leading part of the buffer is needed for the padded parent.
    buffer = view(data, Base.OneTo(prod(parent_dims)))
    padded = reshape(buffer, parent_dims)
    visible = (Base.OneTo.(local_dims)..., Base.OneTo.(extra_dims)...)
    return PencilArray(p, view(padded, visible...))
end
Loading