Skip to content

Commit

Permalink
add a trait to dedup code
Browse files Browse the repository at this point in the history
  • Loading branch information
kevaundray committed Sep 26, 2024
1 parent 0fc446a commit 654ccc3
Showing 1 changed file with 22 additions and 50 deletions.
72 changes: 22 additions & 50 deletions cryptography/polynomial/src/domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,60 +223,27 @@ impl Domain {
}
}

/// Computes a FFT of the field elements(scalars).
///
/// Note: This is essentially multiple inner products.
///
/// TODO: This method is still duplicated below
fn fft_scalar_inplace(twiddle_factors: &[Scalar], a: &mut [Scalar]) {
let n = a.len();
let log_n = log2_pow2(n);
assert_eq!(n, 1 << log_n);

// Bit-reversal permutation
for k in 0..n {
let rk = bitreverse(k as u32, log_n) as usize;
if k < rk {
a.swap(rk, k);
}
}

let mut m = 1;
for s in 0..log_n {
let w_m = twiddle_factors[s as usize];
for k in (0..n).step_by(2 * m) {
let mut w = Scalar::ONE;

for j in 0..m {
let t = if w == Scalar::ONE {
a[k + j + m]
} else if w == -Scalar::ONE {
-a[k + j + m]
} else {
a[k + j + m] * w
};
use std::ops::{Add, Mul, Neg, Sub};

trait FFTElement:
Sized
+ Copy
+ Add<Output = Self>
+ Sub<Output = Self>
+ Mul<Scalar, Output = Self>
+ Neg<Output = Self>
{
}

let u = a[k + j];
impl FFTElement for Scalar {}

a[k + j] = u + t;
a[k + j + m] = u - t;

w *= w_m;
}
}
m *= 2;
}
}
impl FFTElement for G1Projective {}

/// Computes a FFT of the group elements(points).
///
/// Note: This is essentially multiple multi-scalar multiplications.
fn fft_g1_inplace(twiddle_factors: &[Scalar], a: &mut [G1Projective]) {
fn fft_inplace<T: FFTElement>(twiddle_factors: &[Scalar], a: &mut [T]) {
let n = a.len();
let log_n = log2_pow2(n);
assert_eq!(n, 1 << log_n);

// Bit-reversal permutation
for k in 0..n {
let rk = bitreverse(k as u32, log_n) as usize;
if k < rk {
Expand All @@ -294,12 +261,9 @@ fn fft_g1_inplace(twiddle_factors: &[Scalar], a: &mut [G1Projective]) {
a[k + j + m]
} else if w == -Scalar::ONE {
-a[k + j + m]
} else if a[k + j + m].is_identity().into() {
G1Projective::identity()
} else {
a[k + j + m] * w
};

let u = a[k + j];
a[k + j] = u + t;
a[k + j + m] = u - t;
Expand All @@ -310,6 +274,14 @@ fn fft_g1_inplace(twiddle_factors: &[Scalar], a: &mut [G1Projective]) {
}
}

fn fft_scalar_inplace(twiddle_factors: &[Scalar], a: &mut [Scalar]) {
fft_inplace(twiddle_factors, a);
}

fn fft_g1_inplace(twiddle_factors: &[Scalar], a: &mut [G1Projective]) {
fft_inplace(twiddle_factors, a);
}

fn bitreverse(mut n: u32, l: u32) -> u32 {
let mut r = 0;
for _ in 0..l {
Expand Down

0 comments on commit 654ccc3

Please sign in to comment.