diff --git a/src/algebra/scalar/authenticated_scalar.rs b/src/algebra/scalar/authenticated_scalar.rs index f956827..781301d 100644 --- a/src/algebra/scalar/authenticated_scalar.rs +++ b/src/algebra/scalar/authenticated_scalar.rs @@ -1189,21 +1189,11 @@ where domain: D, ) -> Vec> { assert!(!x.is_empty(), "Cannot compute FFT of empty vector"); - let fabric = x[0].fabric(); - - // Extend to the next power of two - let n = x.len(); - let padding_length = n.next_power_of_two() - n; - let pad = fabric.zeros_authenticated(padding_length); - let padded_input = [x, &pad].concat(); // Take the FFT of the shares and the macs separately - let shares = padded_input.iter().map(|v| v.share()).collect_vec(); - let macs = padded_input.iter().map(|v| v.mac_share()).collect_vec(); - let modifiers = padded_input - .into_iter() - .map(|v| v.public_modifier) - .collect_vec(); + let shares = x.iter().map(|v| v.share()).collect_vec(); + let macs = x.iter().map(|v| v.mac_share()).collect_vec(); + let modifiers = x.iter().map(|v| v.public_modifier.clone()).collect_vec(); let (share_fft, mac_fft, modifier_fft) = if is_forward { ( @@ -1219,7 +1209,7 @@ where ) }; - let mut res = Vec::with_capacity(n); + let mut res = Vec::with_capacity(domain.size()); for (share, mac, modifier) in izip!(share_fft, mac_fft, modifier_fft) { res.push(AuthenticatedScalarResult { share: MpcScalarResult::new_shared(share), @@ -1488,12 +1478,13 @@ mod tests { async fn test_fft() { let mut rng = thread_rng(); let n: usize = rng.gen_range(0..100); + let domain_size = rng.gen_range(n..10 * n); let values = (0..n) .map(|_| Scalar::::random(&mut rng)) .collect_vec(); - let domain = Radix2EvaluationDomain::::new(n).unwrap(); + let domain = Radix2EvaluationDomain::::new(domain_size).unwrap(); let fft_res = domain.fft( &values .iter() @@ -1511,9 +1502,9 @@ mod tests { .into_iter() .map(|v| v + Scalar::one()) .collect_vec(); - let fft = AuthenticatedScalarResult::fft::>( - &shared_values, - ); + let fft = AuthenticatedScalarResult::fft_with_domain::< + Radix2EvaluationDomain, + >(&shared_values, domain); let opening = AuthenticatedScalarResult::open_authenticated_batch(&fft); future::join_all(opening.into_iter()) @@ -1531,12 +1522,13 @@ mod tests { async fn test_ifft() { let mut rng = thread_rng(); let n: usize = rng.gen_range(0..100); + let domain_size = rng.gen_range(n..10 * n); let values = (0..n) .map(|_| Scalar::::random(&mut rng)) .collect_vec(); - let domain = Radix2EvaluationDomain::::new(n).unwrap(); + let domain = Radix2EvaluationDomain::::new(domain_size).unwrap(); let ifft_res = domain.ifft( &values .iter() @@ -1555,9 +1547,9 @@ mod tests { .map(|v| v + Scalar::one()) .collect_vec(); - let ifft = AuthenticatedScalarResult::ifft::>( - &shared_values, - ); + let ifft = AuthenticatedScalarResult::ifft_with_domain::< + Radix2EvaluationDomain, + >(&shared_values, domain); let opening = AuthenticatedScalarResult::open_authenticated_batch(&ifft); future::join_all(opening.into_iter()) diff --git a/src/algebra/scalar/scalar.rs b/src/algebra/scalar/scalar.rs index c5edf6e..f40c57f 100644 --- a/src/algebra/scalar/scalar.rs +++ b/src/algebra/scalar/scalar.rs @@ -462,7 +462,7 @@ where domain: D, ) -> Vec> { assert!(!x.is_empty(), "Cannot compute fft of empty sequence"); - let n = x.len().next_power_of_two(); + let n = domain.size(); let fabric = x[0].fabric(); let ids = x.iter().map(|v| v.id).collect_vec(); @@ -495,7 +495,7 @@ where domain: D, ) -> Vec> { assert!(!x.is_empty(), "Cannot compute fft of empty sequence"); - let n = x.len().next_power_of_two(); + let n = domain.size(); let fabric = x[0].fabric(); let ids = x.iter().map(|v| v.id).collect_vec(); @@ -696,12 +696,13 @@ mod test { async fn test_circuit_fft() { let mut rng = thread_rng(); let n: usize = rng.gen_range(1..=100); + let domain_size = rng.gen_range(n..10 * n); let seq = (0..n) .map(|_| Scalar::::random(&mut rng)) .collect_vec(); - let domain = Radix2EvaluationDomain::::new(n).unwrap(); + let domain = Radix2EvaluationDomain::::new(domain_size).unwrap(); let fft_res = domain.fft(&seq.iter().map(|s| s.inner()).collect_vec()); let expected_res = fft_res.into_iter().map(Scalar::new).collect_vec(); @@ -710,12 +711,15 @@ mod test { async move { let seq_alloc = seq.iter().map(|x| fabric.allocate_scalar(*x)).collect_vec(); - let res = ScalarResult::fft::>(&seq_alloc); + let res = ScalarResult::fft_with_domain::>( + &seq_alloc, domain, + ); future::join_all(res.into_iter()).await } }) .await; + assert_eq!(res.len(), expected_res.len()); assert_eq!(res, expected_res); } @@ -724,12 +728,13 @@ mod test { async fn test_circuit_ifft() { let mut rng = thread_rng(); let n: usize = rng.gen_range(1..=100); + let domain_size = rng.gen_range(n..10 * n); let seq = (0..n) .map(|_| Scalar::::random(&mut rng)) .collect_vec(); - let domain = Radix2EvaluationDomain::::new(n).unwrap(); + let domain = Radix2EvaluationDomain::::new(domain_size).unwrap(); let ifft_res = domain.ifft(&seq.iter().map(|s| s.inner()).collect_vec()); let expected_res = ifft_res.into_iter().map(Scalar::new).collect_vec(); @@ -738,12 +743,15 @@ mod test { async move { let seq_alloc = seq.iter().map(|x| fabric.allocate_scalar(*x)).collect_vec(); - let res = ScalarResult::ifft::>(&seq_alloc); + let res = ScalarResult::ifft_with_domain::>( + &seq_alloc, domain, + ); future::join_all(res.into_iter()).await } }) .await; + assert_eq!(res.len(), expected_res.len()); assert_eq!(res, expected_res); } }