diff --git a/Cargo.toml b/Cargo.toml index 1e12ea5..17d4d52 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,10 @@ serde_json = { version = "1.0", default-features = false } itertools = "0.11" prettytable = "0.10.0" +tracing = "0.1.40" +tracing-subscriber = { version = "0.3.17", features = ["std", "env-filter"] } +tracing-forest = { version = "0.1.6", features = ["ansi", "smallvec"] } + [features] bench = [] @@ -49,4 +53,4 @@ incremental = false [workspace] members = [ "rs-script", -] \ No newline at end of file +] diff --git a/rs-script/Cargo.toml b/rs-script/Cargo.toml index 27bc01b..919b33b 100644 --- a/rs-script/Cargo.toml +++ b/rs-script/Cargo.toml @@ -17,3 +17,12 @@ num-traits = "0.2" ndarray = "0.15.6" itertools = "0.13.0" rayon = "1.10.0" +tracing = "0.1.40" +tracing-subscriber = { version = "0.3.17", features = ["std", "env-filter"] } +tracing-forest = { version = "0.1.6", features = ["ansi", "smallvec"] } +concrete-ntt = { git = "https://github.com/zama-ai/concrete-ntt", features = [ + "nightly", +] } + +[features] +sanity-check = [] diff --git a/rs-script/src/main.rs b/rs-script/src/main.rs index 9299a39..1c04ee5 100644 --- a/rs-script/src/main.rs +++ b/rs-script/src/main.rs @@ -1,5 +1,6 @@ mod poly; +use concrete_ntt::native64::Plan32; use fhe::bfv::{ BfvParameters, BfvParametersBuilder, Ciphertext, Encoding, Plaintext, PublicKey, SecretKey, }; @@ -8,9 +9,9 @@ use fhe_math::{ zq::Modulus, }; use fhe_traits::*; -use itertools::izip; +use itertools::{izip, Itertools}; use num_bigint::BigInt; -use num_traits::{Num, Signed, ToPrimitive, Zero}; +use num_traits::{FromPrimitive, Num, Signed, ToPrimitive, Zero}; use rand::rngs::StdRng; use rand::SeedableRng; use rayon::iter::{ParallelBridge, ParallelIterator}; @@ -21,6 +22,9 @@ use std::ops::Deref; use std::path::Path; use std::sync::Arc; use std::vec; +use tracing::info_span; +use tracing_forest::ForestLayer; +use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Registry}; use poly::*; @@ -266,20 +270,16 @@ impl InputValidationVectors { cyclo[0] = BigInt::from(1u64); // x^N term cyclo[N as usize] = BigInt::from(1u64); // x^0 term - // Print - /* - println!("m = {:?}\n", &m); - println!("k1 = {:?}\n", &k1); - println!("u = {:?}\n", &u); - println!("e0 = {:?}\n", &e0); - println!("e1 = {:?}\n", &e1); - */ - // Initialize matrices to store results let num_moduli = ctx.moduli().len(); let mut res = InputValidationVectors::new(num_moduli, N as usize); + let plan = PlanNtt::try_new(N as usize * 2).unwrap(); + #[cfg(feature = "sanity-check")] + let plan_cyclo = PlanNtt::try_new(N as usize * 4).unwrap(); + // Perform the main computation logic + #[allow(clippy::type_complexity)] let results: Vec<( usize, Vec, @@ -299,158 +299,210 @@ impl InputValidationVectors { pk1.coefficients().rows() ) .enumerate() + // .take(1) .par_bridge() .map( |(i, (qi, ct0_coeffs, ct1_coeffs, pk0_coeffs, pk1_coeffs))| { // --------------------------------------------------- ct0i --------------------------------------------------- - - // Convert to vectors of bigint, center, and reverse order. - let mut ct0i: Vec = - ct0_coeffs.iter().rev().map(|&x| BigInt::from(x)).collect(); - let mut ct1i: Vec = - ct1_coeffs.iter().rev().map(|&x| BigInt::from(x)).collect(); - let mut pk0i: Vec = - pk0_coeffs.iter().rev().map(|&x| BigInt::from(x)).collect(); - let mut pk1i: Vec = - pk1_coeffs.iter().rev().map(|&x| BigInt::from(x)).collect(); - - let qi_bigint = BigInt::from(qi.modulus()); - - reduce_and_center_coefficients_mut(&mut ct0i, &qi_bigint); - reduce_and_center_coefficients_mut(&mut ct1i, &qi_bigint); - reduce_and_center_coefficients_mut(&mut pk0i, &qi_bigint); - reduce_and_center_coefficients_mut(&mut pk1i, &qi_bigint); - - // k0qi = -t^{-1} mod qi - let koqi_u64 = qi.inv(qi.neg(t.modulus())).unwrap(); - let k0qi = BigInt::from(koqi_u64); // Do not need to center this - - // ki = k1 * k0qi - let ki = poly_scalar_mul(&k1, &k0qi); - - // Calculate ct0i_hat = pk0 * ui + e0i + ki - let ct0i_hat = { - let pk0i_times_u = poly_mul(&pk0i, &u); - assert_eq!((pk0i_times_u.len() as u64) - 1, 2 * (N - 1)); - - let e0_plus_ki = poly_add(&e0, &ki); - assert_eq!((e0_plus_ki.len() as u64) - 1, N - 1); - - poly_add(&pk0i_times_u, &e0_plus_ki) - }; - assert_eq!((ct0i_hat.len() as u64) - 1, 2 * (N - 1)); - - // Check whether ct0i_hat mod R_qi (the ring) is equal to ct0i - let mut ct0i_hat_mod_rqi = ct0i_hat.clone(); - reduce_in_ring(&mut ct0i_hat_mod_rqi, &cyclo, &qi_bigint); - assert_eq!(&ct0i, &ct0i_hat_mod_rqi); - - // Compute r2i numerator = ct0i - ct0i_hat and reduce/center the polynomial - let ct0i_minus_ct0i_hat = poly_sub(&ct0i, &ct0i_hat); - assert_eq!((ct0i_minus_ct0i_hat.len() as u64) - 1, 2 * (N - 1)); - let mut ct0i_minus_ct0i_hat_mod_zqi = ct0i_minus_ct0i_hat.clone(); - reduce_and_center_coefficients_mut(&mut ct0i_minus_ct0i_hat_mod_zqi, &qi_bigint); - - // Compute r2i as the quotient of numerator divided by the cyclotomic polynomial - // to produce: (ct0i - ct0i_hat) / (x^N + 1) mod Z_qi. Remainder should be empty. - let (r2i, r2i_rem) = poly_div(&ct0i_minus_ct0i_hat_mod_zqi, &cyclo); - assert!(r2i_rem.is_empty()); - assert_eq!((r2i.len() as u64) - 1, N - 2); // Order(r2i) = N - 2 - - // Assert that (ct0i - ct0i_hat) = (r2i * cyclo) mod Z_qi - let r2i_times_cyclo = poly_mul(&r2i, &cyclo); - let mut r2i_times_cyclo_mod_zqi = r2i_times_cyclo.clone(); - reduce_and_center_coefficients_mut(&mut r2i_times_cyclo_mod_zqi, &qi_bigint); - assert_eq!(&ct0i_minus_ct0i_hat_mod_zqi, &r2i_times_cyclo_mod_zqi); - assert_eq!((r2i_times_cyclo.len() as u64) - 1, 2 * (N - 1)); - - // Calculate r1i = (ct0i - ct0i_hat - r2i * cyclo) / qi mod Z_p. Remainder should be empty. - let r1i_num = poly_sub(&ct0i_minus_ct0i_hat, &r2i_times_cyclo); - assert_eq!((r1i_num.len() as u64) - 1, 2 * (N - 1)); - - let (r1i, r1i_rem) = poly_div(&r1i_num, &[qi_bigint.clone()]); - assert!(r1i_rem.is_empty()); - assert_eq!((r1i.len() as u64) - 1, 2 * (N - 1)); // Order(r1i) = 2*(N-1) - assert_eq!(&r1i_num, &poly_mul(&r1i, &[qi_bigint.clone()])); - - // Assert that ct0i = ct0i_hat + r1i * qi + r2i * cyclo mod Z_p - let r1i_times_qi = poly_scalar_mul(&r1i, &qi_bigint); - let mut ct0i_calculated = - poly_add(&poly_add(&ct0i_hat, &r1i_times_qi), &r2i_times_cyclo); - - while ct0i_calculated.len() > 0 && ct0i_calculated[0].is_zero() { - ct0i_calculated.remove(0); - } - - assert_eq!(&ct0i, &ct0i_calculated); - - // --------------------------------------------------- ct1i --------------------------------------------------- - - // Calculate ct1i_hat = pk1i * ui + e1i - let ct1i_hat = { - let pk1i_times_u = poly_mul(&pk1i, &u); - assert_eq!((pk1i_times_u.len() as u64) - 1, 2 * (N - 1)); - - poly_add(&pk1i_times_u, &e1) - }; - assert_eq!((ct1i_hat.len() as u64) - 1, 2 * (N - 1)); - - // Check whether ct1i_hat mod R_qi (the ring) is equal to ct1i - let mut ct1i_hat_mod_rqi = ct1i_hat.clone(); - reduce_in_ring(&mut ct1i_hat_mod_rqi, &cyclo, &qi_bigint); - assert_eq!(&ct1i, &ct1i_hat_mod_rqi); - - // Compute p2i numerator = ct1i - ct1i_hat - let ct1i_minus_ct1i_hat = poly_sub(&ct1i, &ct1i_hat); - assert_eq!((ct1i_minus_ct1i_hat.len() as u64) - 1, 2 * (N - 1)); - let mut ct1i_minus_ct1i_hat_mod_zqi = ct1i_minus_ct1i_hat.clone(); - reduce_and_center_coefficients_mut(&mut ct1i_minus_ct1i_hat_mod_zqi, &qi_bigint); - - // Compute p2i as the quotient of numerator divided by the cyclotomic polynomial, - // and reduce/center the resulting coefficients to produce: - // (ct1i - ct1i_hat) / (x^N + 1) mod Z_qi. Remainder should be empty. - let (p2i, p2i_rem) = poly_div(&ct1i_minus_ct1i_hat_mod_zqi, &cyclo.clone()); - assert!(p2i_rem.is_empty()); - assert_eq!((p2i.len() as u64) - 1, N - 2); // Order(p2i) = N - 2 - - // Assert that (ct1i - ct1i_hat) = (p2i * cyclo) mod Z_qi - let p2i_times_cyclo: Vec = poly_mul(&p2i, &cyclo); - let mut p2i_times_cyclo_mod_zqi = p2i_times_cyclo.clone(); - reduce_and_center_coefficients_mut(&mut p2i_times_cyclo_mod_zqi, &qi_bigint); - assert_eq!(&ct1i_minus_ct1i_hat_mod_zqi, &p2i_times_cyclo_mod_zqi); - assert_eq!((p2i_times_cyclo.len() as u64) - 1, 2 * (N - 1)); - - // Calculate p1i = (ct1i - ct1i_hat - p2i * cyclo) / qi mod Z_p. Remainder should be empty. - let p1i_num = poly_sub(&ct1i_minus_ct1i_hat, &p2i_times_cyclo); - assert_eq!((p1i_num.len() as u64) - 1, 2 * (N - 1)); - - let (p1i, p1i_rem) = poly_div(&p1i_num, &[BigInt::from(qi.modulus())]); - assert!(p1i_rem.is_empty()); - assert_eq!((p1i.len() as u64) - 1, 2 * (N - 1)); // Order(p1i) = 2*(N-1) - assert_eq!(&p1i_num, &poly_mul(&p1i, &[qi_bigint.clone()])); - - // Assert that ct1i = ct1i_hat + p1i * qi + p2i * cyclo mod Z_p - let p1i_times_qi = poly_scalar_mul(&p1i, &qi_bigint); - let mut ct1i_calculated = - poly_add(&poly_add(&ct1i_hat, &p1i_times_qi), &p2i_times_cyclo); - - while ct1i_calculated.len() > 0 && ct1i_calculated[0].is_zero() { - ct1i_calculated.remove(0); - } - - assert_eq!(&ct1i, &ct1i_calculated); - - /* - println!("qi = {:?}\n", &qi_bigint); - println!("ct0i = {:?}\n", &ct0i); - println!("k0qi = {:?}\n", &k0qi); - println!("pk0 = Polynomial({:?})\n", &pk0i); - println!("pk1 = Polynomial({:?})\n", &pk1i); - println!("ki = {:?}\n", &ki); - println!("ct0i_hat_mod_rqi = {:?}\n", &ct0i_hat_mod_rqi); - */ - - (i, r2i, r1i, k0qi, ct0i, ct1i, pk0i, pk1i, p1i, p2i) + info_span!("results", i).in_scope(|| { + // Convert to vectors of bigint, center, and reverse order. + let mut ct0i: Vec = + ct0_coeffs.iter().rev().map(|&x| BigInt::from(x)).collect(); + let mut ct1i: Vec = + ct1_coeffs.iter().rev().map(|&x| BigInt::from(x)).collect(); + let mut pk0i: Vec = + pk0_coeffs.iter().rev().map(|&x| BigInt::from(x)).collect(); + let mut pk1i: Vec = + pk1_coeffs.iter().rev().map(|&x| BigInt::from(x)).collect(); + + let qi_bigint = BigInt::from(qi.modulus()); + + reduce_and_center_coefficients_mut(&mut ct0i, &qi_bigint); + reduce_and_center_coefficients_mut(&mut ct1i, &qi_bigint); + reduce_and_center_coefficients_mut(&mut pk0i, &qi_bigint); + reduce_and_center_coefficients_mut(&mut pk1i, &qi_bigint); + + // k0qi = -t^{-1} mod qi + let koqi_u64 = qi.inv(qi.neg(t.modulus())).unwrap(); + let k0qi = BigInt::from(koqi_u64); // Do not need to center this + + // ki = k1 * k0qi + let ki = poly_scalar_mul(&k1, &k0qi); + + // Calculate ct0i_hat = pk0 * ui + e0i + ki + let ct0i_hat = info_span!("compute ct0i_hat").in_scope(|| { + let pk0i_times_u = poly_mul(&plan, &pk0i, &u); + + #[cfg(feature = "sanity-check")] + assert_eq!((pk0i_times_u.len() as u64) - 1, 2 * (N - 1)); + + let e0_plus_ki = poly_add(&e0, &ki); + + #[cfg(feature = "sanity-check")] + assert_eq!((e0_plus_ki.len() as u64) - 1, N - 1); + + poly_add(&pk0i_times_u, &e0_plus_ki) + }); + + #[cfg(feature = "sanity-check")] + assert_eq!((ct0i_hat.len() as u64) - 1, 2 * (N - 1)); + + // Check whether ct0i_hat mod R_qi (the ring) is equal to ct0i + let mut ct0i_hat_mod_rqi = ct0i_hat.clone(); + info_span!("reduce_in_ring: ct0i_hat_mod_rqi % qi").in_scope(|| { + reduce_in_ring(&mut ct0i_hat_mod_rqi, &cyclo, &qi_bigint); + }); + + #[cfg(feature = "sanity-check")] + assert_eq!(&ct0i, &ct0i_hat_mod_rqi); + + // Compute r2i numerator = ct0i - ct0i_hat and reduce/center the polynomial + let ct0i_minus_ct0i_hat = poly_sub(&ct0i, &ct0i_hat); + + #[cfg(feature = "sanity-check")] + assert_eq!((ct0i_minus_ct0i_hat.len() as u64) - 1, 2 * (N - 1)); + + let mut ct0i_minus_ct0i_hat_mod_zqi = ct0i_minus_ct0i_hat.clone(); + reduce_and_center_coefficients_mut( + &mut ct0i_minus_ct0i_hat_mod_zqi, + &qi_bigint, + ); + + // Compute r2i as the quotient of numerator divided by the cyclotomic polynomial + // to produce: (ct0i - ct0i_hat) / (x^N + 1) mod Z_qi. Remainder should be empty. + let (r2i, _r2i_rem) = info_span!("poly_div_cyclo: r2i") + .in_scope(|| poly_div_cyclo(&ct0i_minus_ct0i_hat_mod_zqi, cyclo.len() - 1)); + + #[cfg(feature = "sanity-check")] + { + assert!(_r2i_rem.is_empty()); + assert_eq!((r2i.len() as u64) - 1, N - 2); // Order(r2i) = N - 2 + } + + // Assert that (ct0i - ct0i_hat) = (r2i * cyclo) mod Z_qi + let r2i_times_cyclo = info_span!("poly_mul: r2i * cyclo") + .in_scope(|| poly_mul(&plan, &r2i, &cyclo)); + let mut r2i_times_cyclo_mod_zqi = r2i_times_cyclo.clone(); + reduce_and_center_coefficients_mut(&mut r2i_times_cyclo_mod_zqi, &qi_bigint); + #[cfg(feature = "sanity-check")] + { + assert_eq!(&ct0i_minus_ct0i_hat_mod_zqi, &r2i_times_cyclo_mod_zqi); + assert_eq!((r2i_times_cyclo.len() as u64) - 1, 2 * (N - 1)); + } + + // Calculate r1i = (ct0i - ct0i_hat - r2i * cyclo) / qi mod Z_p. Remainder should be empty. + let r1i_num = poly_sub(&ct0i_minus_ct0i_hat, &r2i_times_cyclo); + #[cfg(feature = "sanity-check")] + assert_eq!((r1i_num.len() as u64) - 1, 2 * (N - 1)); + + let (r1i, _r1i_rem) = info_span!("poly_div: r1i_num / qi") + .in_scope(|| poly_div(&r1i_num, &[qi_bigint.clone()])); + #[cfg(feature = "sanity-check")] + { + assert!(_r1i_rem.is_empty()); + assert_eq!((r1i.len() as u64) - 1, 2 * (N - 1)); // Order(r1i) = 2*(N-1) + assert_eq!(&r1i_num, &poly_mul(&plan_cyclo, &r1i, &[qi_bigint.clone()])); + } + + // Assert that ct0i = ct0i_hat + r1i * qi + r2i * cyclo mod Z_p + #[cfg(feature = "sanity-check")] + { + let r1i_times_qi = info_span!("poly_scalar_mul: r1i * qi_bigint") + .in_scope(|| poly_scalar_mul(&r1i, &qi_bigint)); + let mut ct0i_calculated = + poly_add(&poly_add(&ct0i_hat, &r1i_times_qi), &r2i_times_cyclo); + + while ct0i_calculated.len() > 0 && ct0i_calculated[0].is_zero() { + ct0i_calculated.remove(0); + } + + assert_eq!(&ct0i, &ct0i_calculated); + } + + // --------------------------------------------------- ct1i --------------------------------------------------- + + // Calculate ct1i_hat = pk1i * ui + e1i + let ct1i_hat = info_span!("poly_mul: pk1i * u)").in_scope(|| { + let pk1i_times_u = poly_mul(&plan, &pk1i, &u); + #[cfg(feature = "sanity-check")] + assert_eq!((pk1i_times_u.len() as u64) - 1, 2 * (N - 1)); + + poly_add(&pk1i_times_u, &e1) + }); + + #[cfg(feature = "sanity-check")] + assert_eq!((ct1i_hat.len() as u64) - 1, 2 * (N - 1)); + + // Check whether ct1i_hat mod R_qi (the ring) is equal to ct1i + let mut ct1i_hat_mod_rqi = ct1i_hat.clone(); + info_span!("reduce_in_ring: ct1i_hat_mod_rqi % qi_bigint") + .in_scope(|| reduce_in_ring(&mut ct1i_hat_mod_rqi, &cyclo, &qi_bigint)); + #[cfg(feature = "sanity-check")] + assert_eq!(&ct1i, &ct1i_hat_mod_rqi); + + // Compute p2i numerator = ct1i - ct1i_hat + let ct1i_minus_ct1i_hat = poly_sub(&ct1i, &ct1i_hat); + #[cfg(feature = "sanity-check")] + assert_eq!((ct1i_minus_ct1i_hat.len() as u64) - 1, 2 * (N - 1)); + let mut ct1i_minus_ct1i_hat_mod_zqi = ct1i_minus_ct1i_hat.clone(); + reduce_and_center_coefficients_mut( + &mut ct1i_minus_ct1i_hat_mod_zqi, + &qi_bigint, + ); + + // Compute p2i as the quotient of numerator divided by the cyclotomic polynomial, + // and reduce/center the resulting coefficients to produce: + // (ct1i - ct1i_hat) / (x^N + 1) mod Z_qi. Remainder should be empty. + let (p2i, _p2i_rem) = info_span!("poly_div_cyclo: p2i").in_scope(|| { + poly_div_cyclo(&ct1i_minus_ct1i_hat_mod_zqi, &cyclo.len() - 1) + }); + #[cfg(feature = "sanity-check")] + { + assert!(_p2i_rem.is_empty()); + assert_eq!((p2i.len() as u64) - 1, N - 2); // Order(p2i) = N - 2 + } + + // Assert that (ct1i - ct1i_hat) = (p2i * cyclo) mod Z_qi + let p2i_times_cyclo = info_span!("poly_mul p2i_times_cyclo") + .in_scope(|| poly_mul(&plan, &p2i, &cyclo)); + let mut p2i_times_cyclo_mod_zqi = p2i_times_cyclo.clone(); + reduce_and_center_coefficients_mut(&mut p2i_times_cyclo_mod_zqi, &qi_bigint); + #[cfg(feature = "sanity-check")] + { + assert_eq!(&ct1i_minus_ct1i_hat_mod_zqi, &p2i_times_cyclo_mod_zqi); + assert_eq!((p2i_times_cyclo.len() as u64) - 1, 2 * (N - 1)); + } + + // Calculate p1i = (ct1i - ct1i_hat - p2i * cyclo) / qi mod Z_p. Remainder should be empty. + let p1i_num = poly_sub(&ct1i_minus_ct1i_hat, &p2i_times_cyclo); + #[cfg(feature = "sanity-check")] + assert_eq!((p1i_num.len() as u64) - 1, 2 * (N - 1)); + + let (p1i, _p1i_rem) = info_span!("poly_div: p1i / p") + .in_scope(|| poly_div(&p1i_num, &[BigInt::from(qi.modulus())])); + #[cfg(feature = "sanity-check")] + { + assert!(_p1i_rem.is_empty()); + assert_eq!((p1i.len() as u64) - 1, 2 * (N - 1)); // Order(p1i) = 2*(N-1) + assert_eq!(&p1i_num, &poly_mul(&plan_cyclo, &p1i, &[qi_bigint.clone()])); + } + + // Assert that ct1i = ct1i_hat + p1i * qi + p2i * cyclo mod Z_p + #[cfg(feature = "sanity-check")] + { + let p1i_times_qi = poly_scalar_mul(&p1i, &qi_bigint); + let mut ct1i_calculated = + poly_add(&poly_add(&ct1i_hat, &p1i_times_qi), &p2i_times_cyclo); + + while ct1i_calculated.len() > 0 && ct1i_calculated[0].is_zero() { + ct1i_calculated.remove(0); + } + + assert_eq!(&ct1i, &ct1i_calculated); + } + + (i, r2i, r1i, k0qi, ct0i, ct1i, pk0i, pk1i, p1i, p2i) + }) }, ) .collect(); @@ -818,16 +870,22 @@ impl InputValidationBounds { } fn main() -> Result<(), Box> { - // Set up the BFV parameters - let N: u64 = 128; - let plaintext_modulus: u64 = 65537; - let moduli: Vec = vec![4503599625535489, 4503599626321921]; + let env_filter = EnvFilter::builder() + .with_default_directive(tracing::Level::INFO.into()) + .from_env_lossy(); - let params = BfvParametersBuilder::new() - .set_degree(N as usize) - .set_plaintext_modulus(plaintext_modulus) - .set_moduli(&moduli) - .build_arc()?; + let subscriber = Registry::default() + .with(env_filter) + .with(ForestLayer::default()); + + let _ = tracing::subscriber::set_global_default(subscriber); + + // TODO: add method `default_parameter_128(plaintext_nbits: usize, log2_n: usize) -> BfvParameters` in fhe-rs fork + // TODO: and cache? + let params = info_span!("BfvParameters::default_parameters_128") + .in_scope(|| BfvParameters::default_parameters_128(20)[4].clone()); + let N: u64 = params.degree() as u64; + let moduli = params.moduli().to_vec(); // Extract plaintext modulus let t = Modulus::new(params.plaintext())?; @@ -843,11 +901,15 @@ fn main() -> Result<(), Box> { // let m = t.random_vec(N as usize, &mut rng); let m: Vec = (-(N as i64 / 2)..(N as i64 / 2)).collect(); // m here is from lowest degree to largest as input into fhe.rs (REQUIRED) let pt = Plaintext::try_encode(&m, Encoding::poly(), ¶ms)?; - let (ct, u_rns, e0_rns, e1_rns) = pk.try_encrypt_extended(&pt, &mut rng)?; + let (ct, u_rns, e0_rns, e1_rns) = + info_span!("encrypt").in_scope(|| pk.try_encrypt_extended(&pt, &mut rng))?; // Sanity check. m = Decrypt(ct) - let m_decrypted = unsafe { t.center_vec_vt(&sk.try_decrypt(&ct)?.value.into_vec()) }; - assert_eq!(m_decrypted, m); + #[cfg(feature = "sanity-check")] + { + let m_decrypted = unsafe { t.center_vec_vt(&sk.try_decrypt(&ct)?.value.into_vec()) }; + assert_eq!(m_decrypted, m); + } // Extract context let ctx = params.ctx_at_level(pt.level())?.clone(); @@ -859,16 +921,18 @@ fn main() -> Result<(), Box> { )?; // Compute input validation vectors - let res = InputValidationVectors::compute(&pt, &u_rns, &e0_rns, &e1_rns, &ct, &pk)?; + let res = info_span!("InputValidationVectors::compute") + .in_scope(|| InputValidationVectors::compute(&pt, &u_rns, &e0_rns, &e1_rns, &ct, &pk))?; // Create output json with standard form polynomials let json_data = res.standard_form(&p).to_json(); // Calculate bounds --------------------------------------------------------------------- - let bounds = InputValidationBounds::compute(¶ms, pt.level())?; + let bounds = info_span!("InputValidationBounds::compute") + .in_scope(|| InputValidationBounds::compute(¶ms, pt.level()))?; // Check the constraints - bounds.check_constraints(&res, &p); + info_span!("bounds.check_constraints").in_scope(|| bounds.check_constraints(&res, &p)); let moduli_bitsize = { if let Some(&max_value) = ctx.moduli().iter().max() { diff --git a/rs-script/src/poly.rs b/rs-script/src/poly.rs index 402ac77..8dff88d 100644 --- a/rs-script/src/poly.rs +++ b/rs-script/src/poly.rs @@ -1,7 +1,14 @@ /// Provides helper methods that perform modular poynomial arithmetic over polynomials encoded in vectors /// of coefficients from largest degree to lowest. +use itertools::Itertools; use num_bigint::BigInt; use num_traits::*; +use rayon::iter::IntoParallelRefIterator; +use rayon::iter::ParallelIterator; + +// NTT related +pub(crate) use concrete_ntt::native64::Plan32 as PlanNtt; +pub(crate) type NttUint = u64; /// Adds two polynomials represented as vectors of `BigInt` coefficients in descending order of powers. /// @@ -30,9 +37,9 @@ pub fn poly_add(poly1: &[BigInt], poly2: &[BigInt]) -> Vec { // Add the coefficients let mut result = vec![BigInt::zero(); max_length]; - for i in 0..max_length { - result[i] = &extended_poly1[i] + &extended_poly2[i]; - } + result.iter_mut().enumerate().for_each(|(i, x)| { + *x = &extended_poly1[i] + &extended_poly2[i]; + }); result } @@ -74,7 +81,7 @@ pub fn poly_sub(poly1: &[BigInt], poly2: &[BigInt]) -> Vec { poly_add(poly1, &poly_neg(poly2)) } -/// Multiplies two polynomials represented as slices of `BigInt` coefficients naively. +/// Multiplies two polynomials represented as slices of `BigInt` coefficients using NNT. /// /// Given two polynomials `poly1` and `poly2`, where each polynomial is represented by a slice of /// coefficients, this function computes their product. The order of coefficients (ascending or @@ -90,20 +97,26 @@ pub fn poly_sub(poly1: &[BigInt], poly2: &[BigInt]) -> Vec { /// /// A vector of `BigInt` representing the coefficients of the resulting polynomial after multiplication, /// in the same order as the input polynomials. -pub fn poly_mul(poly1: &[BigInt], poly2: &[BigInt]) -> Vec { - let product_len = poly1.len() + poly2.len() - 1; - let mut product = vec![BigInt::zero(); product_len]; +pub fn poly_mul(plan: &PlanNtt, poly1: &[BigInt], poly2: &[BigInt]) -> Vec { + let product_len_orig = poly1.len() + poly2.len() - 1; - for i in 0..poly1.len() { - for j in 0..poly2.len() { - product[i + j] += &poly1[i] * &poly2[j]; - } - } + let mut poly1_padded = poly_bigint_into_uint_vec(poly1); + poly1_padded.resize(poly1.len().next_power_of_two() * 2, 0); + + let mut poly2_padded = poly_bigint_into_uint_vec(poly2); + poly2_padded.resize(poly1_padded.len(), 0); + + let product_len = poly1_padded.len(); + let mut product = vec![u64::zero(); product_len]; - product + plan.negacyclic_polymul(&mut product, &poly1_padded, &poly2_padded); + + vec_uint_into_bigint_poly(&product, product_len_orig) } -/// Divides one polynomial by another, returning the quotient and remainder, with both polynomials +/// Divides one polynomial by another recursively based "divide-and-conquer" strategy. +/// +/// Returns the quotient and remainder, with both polynomials /// represented by vectors of `BigInt` coefficients in descending order of powers. /// /// Given two polynomials `dividend` and `divisor`, where each polynomial is represented by a vector @@ -133,6 +146,43 @@ pub fn poly_div(dividend: &[BigInt], divisor: &[BigInt]) -> (Vec, Vec (Vec, Vec) { let mut quotient = vec![BigInt::zero(); dividend.len() - divisor.len() + 1]; let mut remainder = dividend.to_vec(); @@ -145,13 +195,73 @@ pub fn poly_div(dividend: &[BigInt], divisor: &[BigInt]) -> (Vec, Vec 0 && remainder[0].is_zero() { + // Remove leading zeros from the remainder + while !remainder.is_empty() && remainder[0].is_zero() { remainder.remove(0); } (quotient, remainder) } +/// Divides a polynomial by a cyclotomic polynomial, returning the quotient and remainder. +/// +/// * `dividend` - A slice of `BigInt` representing the coefficients of the dividend polynomial. +/// * `n` is the degree of the cyclotomic polynomial. +/// +/// Assumes `dividend` polynomial is represented by a vector of coefficients in descending order of powers. +pub fn poly_div_cyclo(dividend: &[BigInt], n: usize) -> (Vec, Vec) { + let m = dividend.len(); + let q_len = if m >= n { m - n } else { 0 }; + let mut remainder = dividend.to_vec(); + let mut quotient = Vec::with_capacity(q_len); + + for i in 0..q_len { + let q = std::mem::take(&mut remainder[i]); + quotient.push(q.clone()); + + let idx = i + n; + if idx < remainder.len() { + remainder[idx] -= "ient[i]; + } else { + remainder.resize(idx + 1, BigInt::zero()); + remainder[idx] -= "ient[i]; + } + } + + // Find the first non-zero index without modifying the vector + let first_non_zero = remainder + .iter() + .position(|x| !x.is_zero()) + .unwrap_or(remainder.len()); + let trimmed_remainder = &remainder[first_non_zero..]; + + (quotient, trimmed_remainder.to_vec()) +} + +// Computes the polynomial modulo a cyclotomic polynomial leveraging the stucture of the cyclotomic polynomial. +// +// * `dividend` is the polynomial to be reduced. +// * `n` is the degree of the cyclotomic polynomial. +// +// Assumes `dividend` polynomial is represented by a vector of coefficients in descending order of powers. +pub fn poly_modulo_cyclo(dividend: &[BigInt], n: usize) -> Vec { + let mut remainder = vec![BigInt::zero(); n]; + let degree = dividend.len() - 1; // Highest exponent + for (i, coeff) in dividend.iter().enumerate() { + let e = degree - i; // Actual exponent of the term + let q = e / n; + let sign = if q % 2 == 0 { + BigInt::one() + } else { + -BigInt::one() + }; + let r = e % n; + let remainder_index = n - 1 - r; + remainder[remainder_index] += sign * coeff; + } + remainder +} + /// Multiplies each coefficient of a polynomial by a scalar. /// /// This function takes a polynomial represented as a vector of `BigInt` coefficients and multiplies each @@ -168,7 +278,7 @@ pub fn poly_div(dividend: &[BigInt], divisor: &[BigInt]) -> (Vec, Vec Vec { - poly.iter().map(|coeff| coeff * scalar).collect() + poly.par_iter().map(|coeff| coeff * scalar).collect() } /// Reduces the coefficients of a polynomial by dividing it with a cyclotomic polynomial @@ -191,14 +301,13 @@ pub fn poly_scalar_mul(poly: &[BigInt], scalar: &BigInt) -> Vec { /// This function will panic if the remainder length exceeds the degree of the cyclotomic polynomial, /// which would indicate an issue with the division operation. pub fn reduce_coefficients_by_cyclo(coefficients: &mut Vec, cyclo: &[BigInt]) { - // Perform polynomial long division, assuming poly_div returns (quotient, remainder) - let (_, remainder) = poly_div(&coefficients, cyclo); + let remainder = poly_modulo_cyclo(coefficients, cyclo.len() - 1); - let N = cyclo.len() - 1; - let mut out: Vec = vec![BigInt::zero(); N]; + let n = cyclo.len() - 1; + let mut out: Vec = vec![BigInt::zero(); n]; // Calculate the starting index in `out` where the remainder should be copied - let start_idx = N - remainder.len(); + let start_idx = n - remainder.len(); // Copy the remainder into the `out` vector starting from `start_idx` out[start_idx..].clone_from_slice(&remainder); @@ -264,7 +373,7 @@ pub fn reduce_and_center_coefficients( ) -> Vec { let half_modulus = modulus / BigInt::from(2); coefficients - .iter() + .par_iter() .map(|x| reduce_and_center(x, modulus, &half_modulus)) .collect() } @@ -341,3 +450,16 @@ pub fn range_check_standard(vec: &[BigInt], bound: &BigInt, modulus: &BigInt) -> || (coeff >= &(modulus - bound) && coeff < modulus) }) } + +fn poly_bigint_into_uint_vec(vec: &[BigInt]) -> Vec { + vec.iter() + .map(|x| (x.to_i64().unwrap() as NttUint)) + .collect_vec() +} + +fn vec_uint_into_bigint_poly(vec: &[NttUint], size: usize) -> Vec { + vec.iter() + .take(size) + .map(|x| BigInt::from_i64(*x as i64).unwrap()) + .collect_vec() +}