Skip to content

Commit

Permalink
algebra: scalar: Fix FFT over larger domain than input sequence
Browse files Browse the repository at this point in the history
  • Loading branch information
joeykraut committed Nov 3, 2023
1 parent 6ef64d5 commit 0c22cb0
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 28 deletions.
36 changes: 14 additions & 22 deletions src/algebra/scalar/authenticated_scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1189,21 +1189,11 @@ where
domain: D,
) -> Vec<AuthenticatedScalarResult<C>> {
assert!(!x.is_empty(), "Cannot compute FFT of empty vector");
let fabric = x[0].fabric();

// Extend to the next power of two
let n = x.len();
let padding_length = n.next_power_of_two() - n;
let pad = fabric.zeros_authenticated(padding_length);
let padded_input = [x, &pad].concat();

// Take the FFT of the shares and the macs separately
let shares = padded_input.iter().map(|v| v.share()).collect_vec();
let macs = padded_input.iter().map(|v| v.mac_share()).collect_vec();
let modifiers = padded_input
.into_iter()
.map(|v| v.public_modifier)
.collect_vec();
let shares = x.iter().map(|v| v.share()).collect_vec();
let macs = x.iter().map(|v| v.mac_share()).collect_vec();
let modifiers = x.iter().map(|v| v.public_modifier.clone()).collect_vec();

let (share_fft, mac_fft, modifier_fft) = if is_forward {
(
Expand All @@ -1219,7 +1209,7 @@ where
)
};

let mut res = Vec::with_capacity(n);
let mut res = Vec::with_capacity(domain.size());
for (share, mac, modifier) in izip!(share_fft, mac_fft, modifier_fft) {
res.push(AuthenticatedScalarResult {
share: MpcScalarResult::new_shared(share),
Expand Down Expand Up @@ -1488,12 +1478,13 @@ mod tests {
async fn test_fft() {
let mut rng = thread_rng();
let n: usize = rng.gen_range(0..100);
let domain_size = rng.gen_range(n..10 * n);

let values = (0..n)
.map(|_| Scalar::<TestCurve>::random(&mut rng))
.collect_vec();

let domain = Radix2EvaluationDomain::<TestPolyField>::new(n).unwrap();
let domain = Radix2EvaluationDomain::<TestPolyField>::new(domain_size).unwrap();
let fft_res = domain.fft(
&values
.iter()
Expand All @@ -1511,9 +1502,9 @@ mod tests {
.into_iter()
.map(|v| v + Scalar::one())
.collect_vec();
let fft = AuthenticatedScalarResult::fft::<Radix2EvaluationDomain<TestPolyField>>(
&shared_values,
);
let fft = AuthenticatedScalarResult::fft_with_domain::<
Radix2EvaluationDomain<TestPolyField>,
>(&shared_values, domain);

let opening = AuthenticatedScalarResult::open_authenticated_batch(&fft);
future::join_all(opening.into_iter())
Expand All @@ -1531,12 +1522,13 @@ mod tests {
async fn test_ifft() {
let mut rng = thread_rng();
let n: usize = rng.gen_range(0..100);
let domain_size = rng.gen_range(n..10 * n);

let values = (0..n)
.map(|_| Scalar::<TestCurve>::random(&mut rng))
.collect_vec();

let domain = Radix2EvaluationDomain::<TestPolyField>::new(n).unwrap();
let domain = Radix2EvaluationDomain::<TestPolyField>::new(domain_size).unwrap();
let ifft_res = domain.ifft(
&values
.iter()
Expand All @@ -1555,9 +1547,9 @@ mod tests {
.map(|v| v + Scalar::one())
.collect_vec();

let ifft = AuthenticatedScalarResult::ifft::<Radix2EvaluationDomain<TestPolyField>>(
&shared_values,
);
let ifft = AuthenticatedScalarResult::ifft_with_domain::<
Radix2EvaluationDomain<TestPolyField>,
>(&shared_values, domain);

let opening = AuthenticatedScalarResult::open_authenticated_batch(&ifft);
future::join_all(opening.into_iter())
Expand Down
20 changes: 14 additions & 6 deletions src/algebra/scalar/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ where
domain: D,
) -> Vec<ScalarResult<C>> {
assert!(!x.is_empty(), "Cannot compute fft of empty sequence");
let n = x.len().next_power_of_two();
let n = domain.size();

let fabric = x[0].fabric();
let ids = x.iter().map(|v| v.id).collect_vec();
Expand Down Expand Up @@ -495,7 +495,7 @@ where
domain: D,
) -> Vec<ScalarResult<C>> {
assert!(!x.is_empty(), "Cannot compute fft of empty sequence");
let n = x.len().next_power_of_two();
let n = domain.size();

let fabric = x[0].fabric();
let ids = x.iter().map(|v| v.id).collect_vec();
Expand Down Expand Up @@ -696,12 +696,13 @@ mod test {
async fn test_circuit_fft() {
let mut rng = thread_rng();
let n: usize = rng.gen_range(1..=100);
let domain_size = rng.gen_range(n..10 * n);

let seq = (0..n)
.map(|_| Scalar::<TestCurve>::random(&mut rng))
.collect_vec();

let domain = Radix2EvaluationDomain::<TestPolyField>::new(n).unwrap();
let domain = Radix2EvaluationDomain::<TestPolyField>::new(domain_size).unwrap();
let fft_res = domain.fft(&seq.iter().map(|s| s.inner()).collect_vec());
let expected_res = fft_res.into_iter().map(Scalar::new).collect_vec();

Expand All @@ -710,12 +711,15 @@ mod test {
async move {
let seq_alloc = seq.iter().map(|x| fabric.allocate_scalar(*x)).collect_vec();

let res = ScalarResult::fft::<Radix2EvaluationDomain<TestPolyField>>(&seq_alloc);
let res = ScalarResult::fft_with_domain::<Radix2EvaluationDomain<TestPolyField>>(
&seq_alloc, domain,
);
future::join_all(res.into_iter()).await
}
})
.await;

assert_eq!(res.len(), expected_res.len());
assert_eq!(res, expected_res);
}

Expand All @@ -724,12 +728,13 @@ mod test {
async fn test_circuit_ifft() {
let mut rng = thread_rng();
let n: usize = rng.gen_range(1..=100);
let domain_size = rng.gen_range(n..10 * n);

let seq = (0..n)
.map(|_| Scalar::<TestCurve>::random(&mut rng))
.collect_vec();

let domain = Radix2EvaluationDomain::<TestPolyField>::new(n).unwrap();
let domain = Radix2EvaluationDomain::<TestPolyField>::new(domain_size).unwrap();
let ifft_res = domain.ifft(&seq.iter().map(|s| s.inner()).collect_vec());
let expected_res = ifft_res.into_iter().map(Scalar::new).collect_vec();

Expand All @@ -738,12 +743,15 @@ mod test {
async move {
let seq_alloc = seq.iter().map(|x| fabric.allocate_scalar(*x)).collect_vec();

let res = ScalarResult::ifft::<Radix2EvaluationDomain<TestPolyField>>(&seq_alloc);
let res = ScalarResult::ifft_with_domain::<Radix2EvaluationDomain<TestPolyField>>(
&seq_alloc, domain,
);
future::join_all(res.into_iter()).await
}
})
.await;

assert_eq!(res.len(), expected_res.len());
assert_eq!(res, expected_res);
}
}

0 comments on commit 0c22cb0

Please sign in to comment.