diff --git a/Cargo.lock b/Cargo.lock index d2acbec..38e61dc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -277,7 +277,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08e854964160a323e65baa19a0b1a027f76d590faba01f05c0cbc3187221a8c9" dependencies = [ "matrixmultiply", - "num-complex", + "num-complex 0.4.0", "num-integer", "num-traits", "rawpointer", @@ -286,12 +286,23 @@ dependencies = [ [[package]] name = "ndrustfft" -version = "0.1.6" +version = "0.2.0" dependencies = [ "criterion", "ndarray", "num-traits", - "rustfft", + "realfft", + "rustdct", + "rustfft 6.0.1", +] + +[[package]] +name = "num-complex" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "747d632c0c558b87dbabbe6a82f3b4ae03720d0646ac5b7b4dae89394be5f2c5" +dependencies = [ + "num-traits", ] [[package]] @@ -433,6 +444,15 @@ dependencies = [ "num_cpus", ] +[[package]] +name = "realfft" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7695c87f31dc3644760f23fb59a3fed47659703abf76cf2d111f03b9e712342" +dependencies = [ + "rustfft 6.0.1", +] + [[package]] name = "regex" version = "1.5.4" @@ -463,13 +483,36 @@ dependencies = [ "semver", ] +[[package]] +name = "rustdct" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fadcb505b98aa64da1dadb1498b912e3642aae4606623cb3ae952cd8da33f80d" +dependencies = [ + "rustfft 5.1.1", +] + +[[package]] +name = "rustfft" +version = "5.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1869bb2a6ff77380d52ff4bc631f165637035a55855c76aa462c85474dadc42f" +dependencies = [ + "num-complex 0.3.1", + "num-integer", + "num-traits", + "primal-check", + "strength_reduce", + "transpose", +] + [[package]] name = "rustfft" version = "6.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1d089e5c57521629a59f5f39bca7434849ff89bd6873b521afe389c1c602543" dependencies = [ - "num-complex", + "num-complex 0.4.0", "num-integer", "num-traits", "primal-check", diff --git a/Cargo.toml b/Cargo.toml index fdf7dd6..beb78e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,11 +1,11 @@ [package] name = "ndrustfft" -version = "0.1.6" +version = "0.2.0" authors = ["preiter "] edition = "2018" -description = "N-dimensional real-to-complex FFT and real-to-real DCT for Rust" +description = "N-dimensional c2c FFT, r2c FFT and r2r DCT for Rust" repository = "https://github.com/preiter93/ndrustfft" -keywords = ["fft", "dft", "dct", "rustfft", "ndarray"] +keywords = ["fft", "dft", "dct", "rustfft", "realfft", "rustdct", "ndarray"] readme = "README.md" license = "MIT" @@ -15,8 +15,10 @@ path = "src/lib.rs" [dependencies] ndarray = { version = "0.15.0", features = ["rayon"] } -rustfft = "6.0.1" +rustfft = "6.0" num-traits = "0.2.12" +rustdct = "0.6" +realfft = "2.0.1" [dev-dependencies] criterion = { version = "0.3.4", features = ["html_reports"] } diff --git a/README.md b/README.md index 3af3c8c..8602b31 100644 --- a/README.md +++ b/README.md @@ -2,12 +2,12 @@ ## ndrustfft: *n*-dimensional complex-to-complex FFT, real-to-complex FFT and real-to-real DCT -This library is a wrapper for `RustFFT` that enables performing FFTs of complex-, real-valued -data and DCT's on *n*-dimensional arrays (ndarray). +This library is a wrapper for `RustFFT`, `RustDCT` and `RealFft` +that enables performing FFTs and DCTs of complex- and real-valued +data on *n*-dimensional arrays (ndarray). -ndrustfft provides Handler structs for FFT's and DCTs, which must be provided -to the respective function (see implemented transforms below) alongside with the arrays. -The Handlers contain the transform plans and buffers which reduce allocation cost. +ndrustfft provides Handler structs for FFT's and DCTs, which must be provided alongside +with the arrays to the respective function (see below) . The Handlers implement a process function, which is a wrapper around Rustfft's process function with additional functionality. Transforms along the outermost axis are in general the fastest, while transforms along @@ -15,12 +15,17 @@ other axis' will create temporary copies of the input array. ### Implemented transforms #### Complex-to-complex -- `fft` / `ifft`: [`ndfft`],[`ndfft_par`], [`ndifft`],[`ndifft_par`] +- `fft` : [`ndfft`], [`ndfft_par`] +- `ifft`: [`ndifft`],[`ndifft_par`] #### Real-to-complex -- `fft_r2c` / `ifft_r2c`: [`ndfft_r2c`],[`ndfft_r2c_par`], [`ndifft_r2c`],[`ndifft_r2c_par`] +- `fft_r2c` : [`ndfft_r2c`], [`ndfft_r2c_par`], +#### Complex-to-real +- `ifft_r2c`: [`ndifft_r2c`],[`ndifft_r2c_par`] #### Real-to-real -- `fft_r2hc` / `ifft_r2hc`: [`ndfft_r2hc`],[`ndfft_r2hc_par`], [`ndifft_r2hc`],[`ndifft_r2hc_par`] - `dct1`: [`nddct1`],[`nddct1_par`] +- `dct2`: [`nddct2`],[`nddct2_par`] +- `dct3`: [`nddct3`],[`nddct3_par`] +- `dct4`: [`nddct4`],[`nddct4_par`] ### Parallel The library ships all functions with a parallel version @@ -30,7 +35,7 @@ which leverages the parallel abilities of ndarray. 2-Dimensional real-to-complex fft along first axis ```rust use ndarray::{Array2, Dim, Ix}; -use ndrustfft::{ndfft_r2c, Complex, FftHandler}; +use ndrustfft::{ndfft_r2c, Complex, R2cFftHandler}; let (nx, ny) = (6, 4); let mut data = Array2::::zeros((nx, ny)); @@ -38,7 +43,7 @@ let mut vhat = Array2::>::zeros((nx / 2 + 1, ny)); for (i, v) in data.iter_mut().enumerate() { *v = i as f64; } -let mut fft_handler: FftHandler = FftHandler::new(nx); +let mut fft_handler = R2cFftHandler::::new(nx); ndfft_r2c( &mut data.view_mut(), &mut vhat.view_mut(), diff --git a/benches/ndrustfft.rs b/benches/ndrustfft.rs index dc0eaea..6ca45be 100644 --- a/benches/ndrustfft.rs +++ b/benches/ndrustfft.rs @@ -1,7 +1,8 @@ use criterion::{criterion_group, criterion_main, Criterion}; use ndarray::{Array, Dim, Ix}; use ndrustfft::{nddct1, DctHandler}; -use ndrustfft::{ndfft, ndrfft, Complex, FftHandler}; +use ndrustfft::{ndfft, Complex, FftHandler}; +use ndrustfft::{ndfft_r2c, R2cFftHandler}; const SIZES: [usize; 4] = [128, 264, 512, 1024]; pub fn bench_fft2d(c: &mut Criterion) { @@ -32,9 +33,9 @@ pub fn bench_rfft2d(c: &mut Criterion) { for (i, v) in data.iter_mut().enumerate() { *v = i as f64; } - let mut handler: FftHandler = FftHandler::new(*n); + let mut handler = R2cFftHandler::::new(*n); group.bench_function(&name, |b| { - b.iter(|| ndrfft(&mut data.view_mut(), &mut vhat.view_mut(), &mut handler, 0)) + b.iter(|| ndfft_r2c(&mut data.view_mut(), &mut vhat.view_mut(), &mut handler, 0)) }); } group.finish(); diff --git a/benches/ndrustfft_par.rs b/benches/ndrustfft_par.rs index 2455c28..283ed19 100644 --- a/benches/ndrustfft_par.rs +++ b/benches/ndrustfft_par.rs @@ -1,7 +1,9 @@ use criterion::{criterion_group, criterion_main, Criterion}; use ndarray::{Array, Dim, Ix}; use ndrustfft::{nddct1_par, DctHandler}; -use ndrustfft::{ndfft_par, ndrfft_par, Complex, FftHandler}; +use ndrustfft::{ndfft_par, Complex, FftHandler}; +use ndrustfft::{ndfft_r2c_par, R2cFftHandler}; + const SIZES: [usize; 4] = [128, 264, 512, 1024]; pub fn bench_fft2d(c: &mut Criterion) { @@ -32,9 +34,9 @@ pub fn bench_rfft2d(c: &mut Criterion) { for (i, v) in data.iter_mut().enumerate() { *v = i as f64; } - let mut handler: FftHandler = FftHandler::new(*n); + let mut handler = R2cFftHandler::::new(*n); group.bench_function(&name, |b| { - b.iter(|| ndrfft_par(&mut data.view_mut(), &mut vhat.view_mut(), &mut handler, 0)) + b.iter(|| ndfft_r2c_par(&mut data.view_mut(), &mut vhat.view_mut(), &mut handler, 0)) }); } group.finish(); diff --git a/src/lib.rs b/src/lib.rs index 3b3c6a2..7bd893b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,11 @@ //! # ndrustfft: *n*-dimensional complex-to-complex FFT, real-to-complex FFT and real-to-real DCT //! -//! This library is a wrapper for `RustFFT` that enables performing FFTs of complex-, real-valued -//! data and DCT's on *n*-dimensional arrays (ndarray). +//! This library is a wrapper for `RustFFT`, `RustDCT` and `RealFft` +//! that enables performing FFTs and DCTs of complex- and real-valued +//! data on *n*-dimensional arrays (ndarray). //! -//! ndrustfft provides Handler structs for FFT's and DCTs, which must be provided -//! to the respective function (see implemented transforms below) alongside with the arrays. -//! The Handlers contain the transform plans and buffers which reduce allocation cost. +//! ndrustfft provides Handler structs for FFT's and DCTs, which must be provided alongside +//! with the arrays to the respective function (see below) . //! The Handlers implement a process function, which is a wrapper around Rustfft's //! process function with additional functionality. //! Transforms along the outermost axis are in general the fastest, while transforms along @@ -13,12 +13,17 @@ //! //! ## Implemented transforms //! ### Complex-to-complex -//! - `fft` / `ifft`: [`ndfft`],[`ndfft_par`], [`ndifft`],[`ndifft_par`] +//! - `fft` : [`ndfft`], [`ndfft_par`] +//! - `ifft`: [`ndifft`],[`ndifft_par`] //! ### Real-to-complex -//! - `fft_r2c` / `ifft_r2c`: [`ndfft_r2c`],[`ndfft_r2c_par`], [`ndifft_r2c`],[`ndifft_r2c_par`] +//! - `fft_r2c` : [`ndfft_r2c`], [`ndfft_r2c_par`], +//! ### Complex-to-real +//! - `ifft_r2c`: [`ndifft_r2c`],[`ndifft_r2c_par`] //! ### Real-to-real -//! - `fft_r2hc` / `ifft_r2hc`: [`ndfft_r2hc`],[`ndfft_r2hc_par`], [`ndifft_r2hc`],[`ndifft_r2hc_par`] //! - `dct1`: [`nddct1`],[`nddct1_par`] +//! - `dct2`: [`nddct2`],[`nddct2_par`] +//! - `dct3`: [`nddct3`],[`nddct3_par`] +//! - `dct4`: [`nddct4`],[`nddct4_par`] //! //! ## Parallel //! The library ships all functions with a parallel version @@ -28,7 +33,7 @@ //! 2-Dimensional real-to-complex fft along first axis //! ``` //! use ndarray::{Array2, Dim, Ix}; -//! use ndrustfft::{ndfft_r2c, Complex, FftHandler}; +//! use ndrustfft::{ndfft_r2c, Complex, R2cFftHandler}; //! //! let (nx, ny) = (6, 4); //! let mut data = Array2::::zeros((nx, ny)); @@ -36,7 +41,7 @@ //! for (i, v) in data.iter_mut().enumerate() { //! *v = i as f64; //! } -//! let mut fft_handler: FftHandler = FftHandler::new(nx); +//! let mut fft_handler = R2cFftHandler::::new(nx); //! ndfft_r2c( //! &mut data.view_mut(), //! &mut vhat.view_mut(), @@ -50,10 +55,13 @@ extern crate ndarray; extern crate rustfft; use ndarray::{Array1, ArrayBase, Dimension, Zip}; use ndarray::{Data, DataMut}; +use num_traits::FloatConst; +use realfft::{ComplexToReal, RealFftPlanner, RealToComplex}; +use rustdct::{Dct1, DctPlanner, TransformType2And3, TransformType4}; pub use rustfft::num_complex::Complex; pub use rustfft::num_traits::Zero; pub use rustfft::FftNum; -use rustfft::FftPlanner; +use rustfft::{Fft, FftPlanner}; use std::sync::Arc; /// Declare procedural macro which creates functions for the individual @@ -71,7 +79,7 @@ macro_rules! create_transform { handler: &mut $h, axis: usize, ) where - T: FftNum, + T: FftNum + FloatConst, R: Data, S: Data + DataMut, D: Dimension, @@ -110,7 +118,7 @@ macro_rules! create_transform_par { handler: &mut $h, axis: usize, ) where - T: FftNum, + T: FftNum + FloatConst, R: Data, S: Data + DataMut, D: Dimension, @@ -140,40 +148,41 @@ macro_rules! create_transform_par { }; } -/// # *n*-dimensional real-to-complex Fourier Transform. +/// # *n*-dimensional complex-to-complex Fourier Transform. /// -/// Transforms a real ndarray of size *n* to a complex array of size -/// *n/2+1* and vice versa. The transformation is performed along a single +/// Transforms a complex ndarray of size *n* to a complex array of size +/// *n* and vice versa. The transformation is performed along a single /// axis, all other array dimensions are unaffected. /// Performs best on sizes which are mutiple of 2 or 3. /// -/// The accompanying functions for the forward transform are [`ndfft_r2c`] (serial) and -/// [`ndfft_r2c_par`] (parallel). +/// The accompanying functions for the forward transform are [`ndfft`] (serial) and +/// [`ndfft_par`] (parallel). /// -/// The accompanying functions for the inverse transform are [`ndifft_r2c`] (serial) and -/// [`ndifft_r2c_par`] (parallel). +/// The accompanying functions for the inverse transform are [`ndifft`] (serial) and +/// [`ndifft_par`] (parallel). /// /// # Example -/// 2-Dimensional real-to-complex fft along first axis +/// 2-Dimensional complex-to-complex fft along first axis /// ``` /// use ndarray::{Array2, Dim, Ix}; -/// use ndrustfft::{ndfft_r2c, Complex, FftHandler}; +/// use ndrustfft::{ndfft, Complex, FftHandler}; /// /// let (nx, ny) = (6, 4); -/// let mut data = Array2::::zeros((nx, ny)); -/// let mut vhat = Array2::>::zeros((nx / 2 + 1, ny)); +/// let mut data = Array2::>::zeros((nx, ny)); +/// let mut vhat = Array2::>::zeros((nx, ny)); /// for (i, v) in data.iter_mut().enumerate() { -/// *v = i as f64; +/// v.re = i as f64; +/// v.im = i as f64; /// } /// let mut fft_handler: FftHandler = FftHandler::new(nx); -/// ndfft_r2c(&mut data, &mut vhat, &mut fft_handler, 0); +/// ndfft(&mut data, &mut vhat, &mut fft_handler, 0); /// ``` #[derive(Clone)] pub struct FftHandler { n: usize, m: usize, - plan_fwd: Arc>, - plan_bwd: Arc>, + plan_fwd: Arc>, + plan_bwd: Arc>, buffer: Vec>, } @@ -231,172 +240,6 @@ impl FftHandler { } } - fn fft_r2c_lane(&mut self, data: &[T], out: &mut [Complex]) { - Self::assert_size(self.n, data.len()); - Self::assert_size(self.m, out.len()); - for (b, d) in self.buffer.iter_mut().zip(data.iter()) { - *b = Complex::new(*d, T::zero()); - } - self.plan_fwd.process(&mut self.buffer); - for (b, d) in self.buffer[0..=self.n / 2].iter().zip(out.iter_mut()) { - *d = *b; - } - } - - fn fft_r2c_lane_par(&self, data: &[T], out: &mut [Complex]) { - Self::assert_size(self.n, data.len()); - Self::assert_size(self.m, out.len()); - let mut buffer = vec![Complex::zero(); self.n]; - for (b, d) in buffer.iter_mut().zip(data.iter()) { - *b = Complex::new(*d, T::zero()); - } - self.plan_fwd.process(&mut buffer); - for (b, d) in buffer[0..=self.n / 2].iter().zip(out.iter_mut()) { - *d = *b; - } - } - - #[allow(clippy::cast_precision_loss)] - fn ifft_r2c_lane(&mut self, data: &[Complex], out: &mut [T]) { - Self::assert_size(self.m, data.len()); - Self::assert_size(self.n, out.len()); - let m = data.len(); - for (b, d) in self.buffer[..m].iter_mut().zip(data.iter()) { - *b = *d; - } - for (b, d) in self.buffer[m..].iter_mut().rev().zip(data[1..].iter()) { - b.re = d.re; - b.im = -d.im; - } - self.plan_bwd.process(&mut self.buffer); - let n64 = T::from_f64(1. / self.n as f64).unwrap(); - for (b, d) in self.buffer.iter().zip(out.iter_mut()) { - *d = b.re * n64; - } - } - - #[allow(clippy::cast_precision_loss)] - fn ifft_r2c_lane_par(&self, data: &[Complex], out: &mut [T]) { - Self::assert_size(self.m, data.len()); - let m = data.len(); - let mut buffer = vec![Complex::zero(); self.n]; - for (b, d) in buffer[..m].iter_mut().zip(data.iter()) { - *b = *d; - } - for (b, d) in buffer[m..].iter_mut().rev().zip(data[1..].iter()) { - b.re = d.re; - b.im = -d.im; - } - self.plan_bwd.process(&mut buffer); - let n64 = T::from_f64(1. / self.n as f64).unwrap(); - for (b, d) in buffer.iter().zip(out.iter_mut()) { - *d = b.re * n64; - } - } - - /// Real to half-complex [r0, r1, r2, r3, i2, i1] - #[allow(clippy::cast_precision_loss)] - fn fft_r2hc_lane(&mut self, data: &[T], out: &mut [T]) { - Self::assert_size(self.n, data.len()); - Self::assert_size(self.n, out.len()); - for (b, d) in self.buffer.iter_mut().zip(data.iter()) { - *b = Complex::new(*d, T::zero()); - } - self.plan_fwd.process(&mut self.buffer); - // Transfer to half-complex format - out[0] = self.buffer[0].re; - out[self.n / 2] = self.buffer[self.n / 2].re; - let (left, right) = out.split_at_mut(self.n / 2); - for (b, (d1, d2)) in self.buffer[1..self.n / 2] - .iter() - .zip(left[1..].iter_mut().zip(right[1..].iter_mut().rev())) - { - *d1 = b.re; - *d2 = b.im; - } - } - - #[allow(clippy::cast_precision_loss)] - fn fft_r2hc_lane_par(&self, data: &[T], out: &mut [T]) { - Self::assert_size(self.n, data.len()); - Self::assert_size(self.n, out.len()); - let mut buffer = vec![Complex::zero(); self.n]; - for (b, d) in buffer.iter_mut().zip(data.iter()) { - *b = Complex::new(*d, T::zero()); - } - self.plan_fwd.process(&mut buffer); - // Transfer to half-complex format - out[0] = buffer[0].re; - out[self.n / 2] = buffer[self.n / 2].re; - let (left, right) = out.split_at_mut(self.n / 2); - for (b, (d1, d2)) in buffer[1..self.n / 2] - .iter() - .zip(left[1..].iter_mut().zip(right[1..].iter_mut().rev())) - { - *d1 = b.re; - *d2 = b.im; - } - } - - #[allow(clippy::cast_precision_loss, clippy::shadow_unrelated)] - fn ifft_r2hc_lane(&mut self, data: &[T], out: &mut [T]) { - Self::assert_size(self.n, data.len()); - Self::assert_size(self.n, out.len()); - self.buffer[0].re = data[0]; - self.buffer[0].im = T::zero(); - self.buffer[self.n / 2].re = data[self.n / 2]; - self.buffer[self.n / 2].im = T::zero(); - let (left, right) = data.split_at(self.n / 2); - for (b, (d1, d2)) in self.buffer[1..self.n / 2] - .iter_mut() - .zip(left[1..].iter().zip(right[1..].iter().rev())) - { - b.re = *d1; - b.im = *d2; - } - // Conjugate part - let (left, right) = self.buffer.split_at_mut(self.n / 2); - for (r, l) in right[1..].iter_mut().rev().zip(left[1..].iter()) { - r.re = l.re; - r.im = -l.im; - } - - self.plan_bwd.process(&mut self.buffer); - let n64 = T::from_f64(1. / self.n as f64).unwrap(); - for (b, d) in self.buffer.iter().zip(out.iter_mut()) { - *d = b.re * n64; - } - } - - #[allow(clippy::cast_precision_loss, clippy::shadow_unrelated)] - fn ifft_r2hc_lane_par(&self, data: &[T], out: &mut [T]) { - Self::assert_size(self.n, data.len()); - Self::assert_size(self.n, out.len()); - let mut buffer = vec![Complex::zero(); self.n]; - buffer[0].re = data[0]; - buffer[self.n / 2].re = data[self.n / 2]; - let (left, right) = data.split_at(self.n / 2); - for (b, (d1, d2)) in buffer[1..self.n / 2] - .iter_mut() - .zip(left[1..].iter().zip(right[1..].iter().rev())) - { - b.re = *d1; - b.im = *d2; - } - // Conjugate part - let (left, right) = buffer.split_at_mut(self.n / 2); - for (r, l) in right[1..].iter_mut().rev().zip(left[1..].iter()) { - r.re = l.re; - r.im = -l.im; - } - - self.plan_bwd.process(&mut buffer); - let n64 = T::from_f64(1. / self.n as f64).unwrap(); - for (b, d) in buffer.iter().zip(out.iter_mut()) { - *d = b.re * n64; - } - } - fn assert_size(n: usize, size: usize) { assert!( n == size, @@ -456,12 +299,154 @@ create_transform!( ifft_lane ); +create_transform_par!( + /// Complex-to-complex Fourier Transform (parallel). + /// + /// Further infos: see [`ndfft`] + ndfft_par, + Complex, + Complex, + FftHandler, + fft_lane +); + +create_transform_par!( + /// Complex-to-complex inverse Fourier Transform (parallel). + /// + /// Further infos: see [`ndifft`] + ndifft_par, + Complex, + Complex, + FftHandler, + ifft_lane +); + +// create_transform_par!( +// /// Real-to-complex Fourier Transform (parallel). +// /// +// /// Further infos: see [`ndfft_r2c`] +// ndfft_r2c_par, +// T, +// Complex, +// FftHandler, +// fft_r2c_lane_par +// ); + +// create_transform_par!( +// /// Complex-to-real inverse Fourier Transform (parallel). +// /// +// /// Further infos: see [`ndifft_r2c`] +// ndifft_r2c_par, +// Complex, +// T, +// FftHandler, +// ifft_r2c_lane_par +// ); + +/// # *n*-dimensional real-to-complex Fourier Transform. +/// +/// Transforms a real ndarray of size *n* to a complex array of size +/// *n/2+1* and vice versa. The transformation is performed along a single +/// axis, all other array dimensions are unaffected. +/// Performs best on sizes which are mutiple of 2 or 3. +/// +/// The accompanying functions for the forward transform are [`ndfft_r2c`] (serial) and +/// [`ndfft_r2c_par`] (parallel). +/// +/// The accompanying functions for the inverse transform are [`ndifft_r2c`] (serial) and +/// [`ndifft_r2c_par`] (parallel). +/// +/// # Example +/// 2-Dimensional real-to-complex fft along first axis +/// ``` +/// use ndarray::{Array2, Dim, Ix}; +/// use ndrustfft::{ndfft_r2c, Complex, R2cFftHandler}; +/// +/// let (nx, ny) = (6, 4); +/// let mut data = Array2::::zeros((nx, ny)); +/// let mut vhat = Array2::>::zeros((nx / 2 + 1, ny)); +/// for (i, v) in data.iter_mut().enumerate() { +/// *v = i as f64; +/// } +/// let mut fft_handler = R2cFftHandler::::new(nx); +/// ndfft_r2c(&mut data, &mut vhat, &mut fft_handler, 0); +/// ``` +#[derive(Clone)] +pub struct R2cFftHandler { + n: usize, + m: usize, + plan_fwd: Arc>, + plan_bwd: Arc>, +} + +impl R2cFftHandler { + /// Creates a new `RealFftPlanner`. + /// + /// # Arguments + /// + /// * `n` - Length of array along axis of which fft will be performed. + /// The size of the complex array after the fft is performed will be of + /// size *n / 2 + 1*. + /// + /// # Examples + /// + /// ``` + /// use ndrustfft::R2cFftHandler; + /// let handler = R2cFftHandler::::new(10); + /// ``` + #[allow(clippy::similar_names)] + #[must_use] + pub fn new(n: usize) -> Self { + let mut planner = RealFftPlanner::::new(); + let fwd = planner.plan_fft_forward(n); + let bwd = planner.plan_fft_inverse(n); + Self { + n, + m: n / 2 + 1, + plan_fwd: Arc::clone(&fwd), + plan_bwd: Arc::clone(&bwd), + } + } + + fn fft_r2c_lane(&self, data: &[T], out: &mut [Complex]) { + Self::assert_size(self.n, data.len()); + Self::assert_size(self.m, out.len()); + let mut indata = vec![T::zero(); self.n]; + for (a, b) in indata.iter_mut().zip(data.iter()) { + *a = *b; + } + self.plan_fwd.process(&mut indata, out).unwrap(); + } + + #[allow(clippy::cast_precision_loss)] + fn ifft_r2c_lane(&self, data: &[Complex], out: &mut [T]) { + Self::assert_size(self.m, data.len()); + Self::assert_size(self.n, out.len()); + let n64 = T::from_f64(1. / self.n as f64).unwrap(); + let mut indata = vec![Complex::zero(); self.m]; + for (a, b) in indata.iter_mut().zip(data.iter()) { + a.re = b.re * n64; + a.im = b.im * n64; + } + self.plan_bwd.process(&mut indata, out).unwrap(); + } + + fn assert_size(n: usize, size: usize) { + assert!( + n == size, + "Size mismatch in fft, got {} expected {}", + size, + n + ); + } +} + create_transform!( /// Real-to-complex Fourier Transform (serial). /// # Example /// ``` /// use ndarray::Array2; - /// use ndrustfft::{ndfft_r2c, Complex, FftHandler}; + /// use ndrustfft::{ndfft_r2c, Complex, R2cFftHandler}; /// /// let (nx, ny) = (6, 4); /// let mut data = Array2::::zeros((nx, ny)); @@ -469,13 +454,13 @@ create_transform!( /// for (i, v) in data.iter_mut().enumerate() { /// *v = i as f64; /// } - /// let mut handler: FftHandler = FftHandler::new(nx); + /// let mut handler = R2cFftHandler::::new(nx); /// ndfft_r2c(&mut data, &mut vhat, &mut handler, 0); /// ``` ndfft_r2c, T, Complex, - FftHandler, + R2cFftHandler, fft_r2c_lane ); @@ -484,7 +469,7 @@ create_transform!( /// # Example /// ``` /// use ndarray::Array2; - /// use ndrustfft::{ndifft_r2c, Complex, FftHandler}; + /// use ndrustfft::{ndifft_r2c, Complex, R2cFftHandler}; /// /// let (nx, ny) = (6, 4); /// let mut data = Array2::::zeros((nx, ny)); @@ -492,84 +477,16 @@ create_transform!( /// for (i, v) in vhat.iter_mut().enumerate() { /// v.re = i as f64; /// } - /// let mut handler: FftHandler = FftHandler::new(nx); + /// let mut handler = R2cFftHandler::::new(nx); /// ndifft_r2c(&mut vhat, &mut data, &mut handler, 0); /// ``` ndifft_r2c, Complex, T, - FftHandler, + R2cFftHandler, ifft_r2c_lane ); -create_transform!( - /// Real-to-real Fourier Transform (serial). - /// # Example - /// ``` - /// use ndarray::Array2; - /// use ndrustfft::{ndfft_r2hc, Complex, FftHandler}; - /// - /// let (nx, ny) = (6, 4); - /// let mut data = Array2::::zeros((nx, ny)); - /// let mut vhat = Array2::::zeros((nx, ny)); - /// for (i, v) in data.iter_mut().enumerate() { - /// *v = i as f64; - /// } - /// let mut handler: FftHandler = FftHandler::new(nx); - /// ndfft_r2hc(&mut data, &mut vhat, &mut handler, 0); - /// ``` - ndfft_r2hc, - T, - T, - FftHandler, - fft_r2hc_lane -); - -create_transform!( - /// Real-to-real Fourier Transform (serial). - /// # Example - /// ``` - /// use ndarray::Array2; - /// use ndrustfft::{ndifft_r2hc, Complex, FftHandler}; - /// - /// let (nx, ny) = (6, 4); - /// let mut data = Array2::::zeros((nx, ny)); - /// let mut vhat = Array2::::zeros((nx, ny)); - /// for (i, v) in vhat.iter_mut().enumerate() { - /// *v = i as f64; - /// } - /// let mut handler: FftHandler = FftHandler::new(nx); - /// ndifft_r2hc(&mut vhat, &mut data, &mut handler, 0); - /// ``` - ndifft_r2hc, - T, - T, - FftHandler, - ifft_r2hc_lane -); - -create_transform_par!( - /// Complex-to-complex Fourier Transform (parallel). - /// - /// Further infos: see [`ndfft`] - ndfft_par, - Complex, - Complex, - FftHandler, - fft_lane -); - -create_transform_par!( - /// Complex-to-complex inverse Fourier Transform (parallel). - /// - /// Further infos: see [`ndifft`] - ndifft_par, - Complex, - Complex, - FftHandler, - ifft_lane -); - create_transform_par!( /// Real-to-complex Fourier Transform (parallel). /// @@ -577,8 +494,8 @@ create_transform_par!( ndfft_r2c_par, T, Complex, - FftHandler, - fft_r2c_lane_par + R2cFftHandler, + fft_r2c_lane ); create_transform_par!( @@ -588,33 +505,11 @@ create_transform_par!( ndifft_r2c_par, Complex, T, - FftHandler, - ifft_r2c_lane_par -); - -create_transform_par!( - /// Real-to-real Fourier Transform (parallel). - /// - /// Further infos: see [`ndfft_r2hc`] - ndfft_r2hc_par, - T, - T, - FftHandler, - fft_r2hc_lane_par -); - -create_transform_par!( - /// Real-to-real inverse Fourier Transform (parallel). - /// - /// Further infos: see [`ndifft_r2hc`] - ndifft_r2hc_par, - T, - T, - FftHandler, - ifft_r2hc_lane_par + R2cFftHandler, + ifft_r2c_lane ); -/// # *n*-dimensional real-to-real Cosine Transform (DCT-I). +/// # *n*-dimensional real-to-real Cosine Transform. /// /// The dct transforms a real ndarray of size *n* to a real array of size *n*. /// The transformation is performed along a single axis, all other array @@ -644,11 +539,13 @@ create_transform_par!( #[derive(Clone)] pub struct DctHandler { n: usize, - plan: Arc>, - buffer: Vec>, + plan_dct1: Arc>, + plan_dct2: Arc>, + plan_dct3: Arc>, + plan_dct4: Arc>, } -impl DctHandler { +impl DctHandler { /// Creates a new `DctHandler`. /// /// # Arguments @@ -664,59 +561,58 @@ impl DctHandler { /// ``` #[must_use] pub fn new(n: usize) -> Self { - let m = 2 * (n - 1); - let mut planner = FftPlanner::::new(); - let fft = planner.plan_fft_forward(m); - let buffer = vec![Complex::zero(); m]; - DctHandler:: { + let mut planner = DctPlanner::::new(); + let dct1 = planner.plan_dct1(n); + let dct2 = planner.plan_dct2(n); + let dct3 = planner.plan_dct3(n); + let dct4 = planner.plan_dct4(n); + Self { n, - plan: Arc::clone(&fft), - buffer, + plan_dct1: Arc::clone(&dct1), + plan_dct2: Arc::clone(&dct2), + plan_dct3: Arc::clone(&dct3), + plan_dct4: Arc::clone(&dct4), } } - /// # Algorithm: - /// 1. Reorder: - /// (a,b,c,d) -> (a,b,c,d,c,b) - /// - /// 2. Compute FFT - /// -> (a*,b*,c*,d*,c*,b*) - /// - /// 3. Extract - /// (a*,b*,c*,d*) - fn dct1_lane(&mut self, data: &[T], out: &mut [T]) { - self.assert_size(data.len()); - let m = self.buffer.len(); - for b in &mut self.buffer.iter_mut() { - b.re = T::zero(); - b.im = T::zero(); - } - self.buffer[0] = Complex::new(data[0], T::zero()); - for (i, d) in data[1..].iter().enumerate() { - self.buffer[i + 1] = Complex::new(*d, T::zero()); - self.buffer[m - i - 1] = Complex::new(*d, T::zero()); + fn dct1_lane(&self, data: &[T], out: &mut [T]) { + Self::assert_size(self, data.len()); + Self::assert_size(self, out.len()); + let two = T::one() + T::one(); + for (b, d) in out.iter_mut().zip(data.iter()) { + *b = *d * two; } - self.plan.process(&mut self.buffer); - out[0] = self.buffer[0].re; - for (i, d) in out[1..].iter_mut().enumerate() { - *d = self.buffer[i + 1].re; + self.plan_dct1.process_dct1(out); + } + + fn dct2_lane(&self, data: &[T], out: &mut [T]) { + Self::assert_size(self, data.len()); + Self::assert_size(self, out.len()); + let two = T::one() + T::one(); + for (b, d) in out.iter_mut().zip(data.iter()) { + *b = *d * two; } + self.plan_dct2.process_dct2(out); } - fn dct1_lane_par(&self, data: &[T], out: &mut [T]) { - self.assert_size(data.len()); - let m = 2 * (self.n - 1); - let mut buffer = vec![Complex::zero(); m]; - buffer[0] = Complex::new(data[0], T::zero()); - for (i, d) in data[1..].iter().enumerate() { - buffer[i + 1] = Complex::new(*d, T::zero()); - buffer[m - i - 1] = Complex::new(*d, T::zero()); + fn dct3_lane(&self, data: &[T], out: &mut [T]) { + Self::assert_size(self, data.len()); + Self::assert_size(self, out.len()); + let two = T::one() + T::one(); + for (b, d) in out.iter_mut().zip(data.iter()) { + *b = *d * two; } - self.plan.process(&mut buffer); - out[0] = buffer[0].re; - for (i, d) in out[1..].iter_mut().enumerate() { - *d = buffer[i + 1].re; + self.plan_dct3.process_dct3(out); + } + + fn dct4_lane(&self, data: &[T], out: &mut [T]) { + Self::assert_size(self, data.len()); + Self::assert_size(self, out.len()); + let two = T::one() + T::one(); + for (b, d) in out.iter_mut().zip(data.iter()) { + *b = *d * two; } + self.plan_dct4.process_dct4(out); } fn assert_size(&self, size: usize) { @@ -761,7 +657,61 @@ create_transform_par!( T, T, DctHandler, - dct1_lane_par + dct1_lane +); + +create_transform!( + /// Real-to-real Discrete Cosine Transform of type 2 DCT-2 (serial). + nddct2, + T, + T, + DctHandler, + dct2_lane +); + +create_transform_par!( + /// Real-to-real Discrete Cosine Transform of type 2 DCT-2 (parallel). + nddct2_par, + T, + T, + DctHandler, + dct2_lane +); + +create_transform!( + /// Real-to-real Discrete Cosine Transform of type 3 DCT-3 (serial). + nddct3, + T, + T, + DctHandler, + dct3_lane +); + +create_transform_par!( + /// Real-to-real Discrete Cosine Transform of type 3 DCT-3 (parallel). + nddct3_par, + T, + T, + DctHandler, + dct3_lane +); + +create_transform!( + /// Real-to-real Discrete Cosine Transform of type 4 DCT-4 (serial). + nddct4, + T, + T, + DctHandler, + dct4_lane +); + +create_transform_par!( + /// Real-to-real Discrete Cosine Transform of type 4 DCT-4 (parallel). + nddct4_par, + T, + T, + DctHandler, + dct4_lane ); /// Tests @@ -857,6 +807,15 @@ mod test { // Assert approx_eq_complex(&vhat, &solution); approx_eq_complex(&v, &v_copy); + + // Transform Par + let mut v = test_matrix_complex(); + ndfft_par(&mut v, &mut vhat, &mut handler, 1); + ndifft_par(&mut vhat, &mut v, &mut handler, 1); + + // Assert + approx_eq_complex(&vhat, &solution); + approx_eq_complex(&v, &v_copy); } #[test] @@ -894,7 +853,7 @@ mod test { let v_copy = v.clone(); let (nx, ny) = (v.shape()[0], v.shape()[1]); let mut vhat = Array2::>::zeros((nx, ny / 2 + 1)); - let mut handler: FftHandler = FftHandler::new(ny); + let mut handler = R2cFftHandler::::new(ny); // Transform ndfft_r2c(&mut v, &mut vhat, &mut handler, 1); @@ -903,38 +862,19 @@ mod test { // Assert approx_eq_complex(&vhat, &solution); approx_eq(&v, &v_copy); - } - #[test] - fn test_fft_r2hc() { - // Solution from np.fft.rfft - let solution = array![ - [0.61, 0.543, -0.572, 0.048, -3.08, -2.562], - [2.795, 1.944, -1.291, 1.179, -1.51, 1.332], - [2.259, 0.36, 2.275, 0.979, 2.23, -0.242], - [-0.296, 2.696, 2.044, -4.282, 1.5, 3.592], - [0.573, 1.753, -2.155, -2.613, 1.695, 2.713], - [3.978, -1.596, -3.205, 1.154, -3.339, 0.633], - ]; - - // Setup + // Transform Par let mut v = test_matrix(); - let v_copy = v.clone(); - let (nx, ny) = (v.shape()[0], v.shape()[1]); - let mut vhat = Array2::::zeros((nx, ny)); - let mut handler: FftHandler = FftHandler::new(ny); - - // Transform - ndfft_r2hc(&mut v, &mut vhat, &mut handler, 1); - ndifft_r2hc(&mut vhat, &mut v, &mut handler, 1); + ndfft_r2c_par(&mut v, &mut vhat, &mut handler, 1); + ndifft_r2c_par(&mut vhat, &mut v, &mut handler, 1); - // Assert - approx_eq(&vhat, &solution); + // // Assert + approx_eq_complex(&vhat, &solution); approx_eq(&v, &v_copy); } #[test] - fn test_fft_dct1() { + fn test_dct1() { // Solution from scipy.fft.dct(x, type=1) let solution = array![ [2.469, 4.259, 0.6, 0.04, -4.957, -1.353], @@ -956,138 +896,89 @@ mod test { // Assert approx_eq(&vhat, &solution); - } - - #[test] - fn test_fft_par() { - // Solution from np.fft.fft - let solution_re = array![ - [0.61, 3.105, 2.508, 0.048, -3.652, -2.019], - [2.795, 0.612, 0.219, 1.179, -2.801, 3.276], - [2.259, 0.601, 0.045, 0.979, 4.506, 0.118], - [-0.296, -0.896, 0.544, -4.282, 3.544, 6.288], - [0.573, -0.96, -3.85, -2.613, -0.461, 4.467], - [3.978, -2.229, 0.133, 1.154, -6.544, -0.962], - ]; - - let solution_im = array![ - [0.61, -2.019, -3.652, 0.048, 2.508, 3.105], - [2.795, 3.276, -2.801, 1.179, 0.219, 0.612], - [2.259, 0.118, 4.506, 0.979, 0.045, 0.601], - [-0.296, 6.288, 3.544, -4.282, 0.544, -0.896], - [0.573, 4.467, -0.461, -2.613, -3.85, -0.96], - [3.978, -0.962, -6.544, 1.154, 0.133, -2.229], - ]; - - let mut solution: Array2> = Array2::zeros(solution_re.raw_dim()); - for (s, (s_re, s_im)) in solution - .iter_mut() - .zip(solution_re.iter().zip(solution_im.iter())) - { - s.re = *s_re; - s.im = *s_im; - } - - // Setup - let mut v = test_matrix_complex(); - let v_copy = v.clone(); - let (nx, ny) = (v.shape()[0], v.shape()[1]); - let mut vhat = Array2::>::zeros((nx, ny)); - let mut handler: FftHandler = FftHandler::new(ny); - // Transform - ndfft_par(&mut v, &mut vhat, &mut handler, 1); - ndifft_par(&mut vhat, &mut v, &mut handler, 1); + // Transform Par + let mut v = test_matrix(); + nddct1_par(&mut v, &mut vhat, &mut handler, 1); // Assert - approx_eq_complex(&vhat, &solution); - approx_eq_complex(&v, &v_copy); + approx_eq(&vhat, &solution); } #[test] - fn test_fft_r2c_par() { - // Solution from np.fft.rfft - let solution_re = array![ - [0.61, 0.543, -0.572, 0.048], - [2.795, 1.944, -1.291, 1.179], - [2.259, 0.36, 2.275, 0.979], - [-0.296, 2.696, 2.044, -4.282], - [0.573, 1.753, -2.155, -2.613], - [3.978, -1.596, -3.205, 1.154], - ]; - - let solution_im = array![ - [0., -2.562, -3.08, 0.], - [0., 1.332, -1.51, 0.], - [0., -0.242, 2.23, 0.], - [0., 3.592, 1.5, 0.], - [0., 2.713, 1.695, 0.], - [0., 0.633, -3.339, 0.], + fn test_dct2() { + // Solution from scipy.fft.dct(x, type=2) + let solution = array![ + [1.22, 5.25, -1.621, -0.619, -5.906, -1.105], + [5.59, -0.209, 4.699, 0.134, -3.907, 1.838], + [4.518, 1.721, 0.381, 1.492, 6.138, 0.513], + [-0.592, -3.746, 8.262, 1.31, 4.642, -6.125], + [1.146, -5.709, 5.75, -4.275, 0.78, -0.963], + [7.956, -2.873, -2.13, 0.006, -8.988, 2.56], ]; - let mut solution: Array2> = Array2::zeros(solution_re.raw_dim()); - for (s, (s_re, s_im)) in solution - .iter_mut() - .zip(solution_re.iter().zip(solution_im.iter())) - { - s.re = *s_re; - s.im = *s_im; - } - // Setup let mut v = test_matrix(); - let v_copy = v.clone(); let (nx, ny) = (v.shape()[0], v.shape()[1]); - let mut vhat = Array2::>::zeros((nx, ny / 2 + 1)); - let mut handler: FftHandler = FftHandler::new(ny); + let mut vhat = Array2::::zeros((nx, ny)); + let mut handler: DctHandler = DctHandler::new(ny); // Transform - ndfft_r2c_par(&mut v, &mut vhat, &mut handler, 1); - ndifft_r2c_par(&mut vhat, &mut v, &mut handler, 1); + nddct2(&mut v, &mut vhat, &mut handler, 1); // Assert - approx_eq_complex(&vhat, &solution); - approx_eq(&v, &v_copy); + approx_eq(&vhat, &solution); + + // Transform Par + let mut v = test_matrix(); + nddct2_par(&mut v, &mut vhat, &mut handler, 1); + + // Assert + approx_eq(&vhat, &solution); } #[test] - fn test_fft_r2hc_par() { - // Solution from np.fft.rfft + fn test_dct3() { + // Solution from scipy.fft.dct(x, type=3) let solution = array![ - [0.61, 0.543, -0.572, 0.048, -3.08, -2.562], - [2.795, 1.944, -1.291, 1.179, -1.51, 1.332], - [2.259, 0.36, 2.275, 0.979, 2.23, -0.242], - [-0.296, 2.696, 2.044, -4.282, 1.5, 3.592], - [0.573, 1.753, -2.155, -2.613, 1.695, 2.713], - [3.978, -1.596, -3.205, 1.154, -3.339, 0.633], + [2.898, 4.571, -0.801, 1.65, -5.427, -2.291], + [2.701, -0.578, 5.768, -0.335, -3.158, 0.882], + [2.348, -0.184, -1.258, 0.048, 5.472, 2.081], + [-3.421, -2.075, 6.944, 0.264, 7.505, -4.315], + [-1.43, -3.023, 6.317, -5.259, 1.991, -1.44], + [5.76, -4.047, 1.974, 0.376, -8.651, 0.117], ]; // Setup let mut v = test_matrix(); - let v_copy = v.clone(); let (nx, ny) = (v.shape()[0], v.shape()[1]); let mut vhat = Array2::::zeros((nx, ny)); - let mut handler: FftHandler = FftHandler::new(ny); + let mut handler: DctHandler = DctHandler::new(ny); // Transform - ndfft_r2hc_par(&mut v, &mut vhat, &mut handler, 1); - ndifft_r2hc_par(&mut vhat, &mut v, &mut handler, 1); + nddct3(&mut v, &mut vhat, &mut handler, 1); + + // Assert + approx_eq(&vhat, &solution); + + // Transform Par + let mut v = test_matrix(); + nddct3_par(&mut v, &mut vhat, &mut handler, 1); // Assert approx_eq(&vhat, &solution); - approx_eq(&v, &v_copy); } #[test] - fn test_fft_dct1_par() { - // Solution from scipy.fft.dct(x, type=1) + fn test_dct4() { + // Solution from scipy.fft.dct(x, type=4) let solution = array![ - [2.469, 4.259, 0.6, 0.04, -4.957, -1.353], - [3.953, -0.374, 4.759, -0.436, -2.643, 2.235], - [2.632, 0.818, -1.609, 1.053, 5.008, 1.008], - [-3.652, -2.628, 4.81, 2.632, 4.666, -7.138], - [-0.835, -2.982, 4.105, -3.192, 1.265, -2.297], - [8.743, -2.422, 1.167, -0.841, -7.506, 3.011], + [3.18, 2.73, -2.314, -2.007, -5.996, 2.127], + [3.175, 0.865, 4.939, -4.305, -0.443, 1.568], + [3.537, 0.677, 0.371, 4.186, 4.528, -1.531], + [-2.687, 1.838, 6.968, 0.899, 2.456, -8.79], + [-2.289, -1.002, 3.67, -5.705, 3.867, -4.349], + [4.192, -5.626, 1.789, -6.057, -4.61, 4.627], ]; // Setup @@ -1097,7 +988,14 @@ mod test { let mut handler: DctHandler = DctHandler::new(ny); // Transform - nddct1_par(&mut v, &mut vhat, &mut handler, 1); + nddct4(&mut v, &mut vhat, &mut handler, 1); + + // Assert + approx_eq(&vhat, &solution); + + // Transform Par + let mut v = test_matrix(); + nddct4_par(&mut v, &mut vhat, &mut handler, 1); // Assert approx_eq(&vhat, &solution);