From 03eb4b02a21ee7a327450b57f3091cf52d789130 Mon Sep 17 00:00:00 2001 From: Joey Kraut Date: Sat, 6 Jan 2024 15:39:15 -0800 Subject: [PATCH] plonk: multiprover: proof-linking: Add proof-linking unit tests --- .../multiprover/proof_system/proof_linking.rs | 338 +++++++++++++++++- plonk/src/proof_system/proof_linking.rs | 223 ++++++------ relation/src/proof_linking/mod.rs | 8 + 3 files changed, 468 insertions(+), 101 deletions(-) diff --git a/plonk/src/multiprover/proof_system/proof_linking.rs b/plonk/src/multiprover/proof_system/proof_linking.rs index adac84ac9..be9421e5c 100644 --- a/plonk/src/multiprover/proof_system/proof_linking.rs +++ b/plonk/src/multiprover/proof_system/proof_linking.rs @@ -97,7 +97,7 @@ impl, E: Pairing MultiproverPlonkKzgSnark { /// Link a singleprover proof into a multiprover proof - pub fn link_singleprover_proof( + pub fn link_proofs( lhs_link_hint: &MpcLinkingHint, rhs_link_hint: &MpcLinkingHint, group_layout: &GroupLayout, @@ -228,3 +228,339 @@ impl, E: Pairing Ok(opening_proof) } } + +#[cfg(test)] +mod test { + use ark_bn254::{Bn254, G1Projective as Curve}; + use ark_ec::pairing::Pairing; + use ark_ff::Zero; + use ark_mpc::{ + algebra::{AuthenticatedScalarResult, Scalar}, + test_helpers::execute_mock_mpc, + MpcFabric, PARTY0, + }; + use itertools::Itertools; + use mpc_relation::{ + proof_linking::{GroupLayout, LinkableCircuit}, + PlonkCircuit, + }; + use rand::{thread_rng, Rng}; + + use crate::{ + errors::PlonkError, + multiprover::proof_system::{ + CollaborativeProof, MpcLinkingHint, MpcPlonkCircuit, MultiproverPlonkKzgSnark, + }, + proof_system::{ + proof_linking::test_helpers::{ + gen_commit_keys, gen_proving_keys, gen_test_circuit1, gen_test_circuit2, + CircuitSelector, GROUP_NAME, + }, + structs::ProvingKey, + PlonkKzgSnark, + }, + transcript::SolidityTranscript, + }; + + /// The number of linked witness elements to use in the tests + const WITNESS_ELEMS: usize = 10; + /// The test field used + type TestField = ::ScalarField; + + // ----------- + // | Helpers | + // ----------- + + /// Generate a test case proof, group layout, and link hint from the given + /// circuit + fn gen_circuit_proof_and_hint( + witness: &[AuthenticatedScalarResult], + circuit: CircuitSelector, + layout: Option, + fabric: &MpcFabric, + ) -> (CollaborativeProof, MpcLinkingHint, GroupLayout) { + let mut cs = MpcPlonkCircuit::new(fabric.clone()); + match circuit { + CircuitSelector::Circuit1 => gen_test_circuit1(&mut cs, witness, layout), + CircuitSelector::Circuit2 => gen_test_circuit2(&mut cs, witness, layout), + }; + cs.finalize_for_arithmetization().unwrap(); + + // Generate a proving key + let pk = gen_pk_from_singleprover(circuit, layout); + + // Get the layout + let group_layout = cs.get_link_group_layout(GROUP_NAME).unwrap(); + + // Generate a proof with a linking hint + let (proof, hint) = gen_test_proof(&cs, &pk, fabric); + (proof, hint, group_layout) + } + + /// Get a proving key by constructing a singleprover circuit of the same + /// topology + fn gen_pk_from_singleprover( + circuit_selector: CircuitSelector, + layout: Option, + ) -> ProvingKey { + let mut cs = PlonkCircuit::new_turbo_plonk(); + let dummy_witness = (0..WITNESS_ELEMS).map(|_| TestField::zero()).collect_vec(); + match circuit_selector { + CircuitSelector::Circuit1 => gen_test_circuit1(&mut cs, &dummy_witness, layout), + CircuitSelector::Circuit2 => gen_test_circuit2(&mut cs, &dummy_witness, layout), + }; + cs.finalize_for_arithmetization().unwrap(); + + let (pk, _) = gen_proving_keys(&cs); + pk + } + + /// Generate a proof and link hint for the circuit by proving its r1cs + /// relation + fn gen_test_proof( + circuit: &MpcPlonkCircuit, + pk: &ProvingKey, + fabric: &MpcFabric, + ) -> (CollaborativeProof, MpcLinkingHint) { + MultiproverPlonkKzgSnark::::prove_with_link_hint(circuit, pk, fabric.clone()) + .unwrap() + } + + /// Prove a link between two circuits and verify the link, return the result + /// as a result + async fn prove_and_verify_link( + lhs_hint: &MpcLinkingHint, + rhs_hint: &MpcLinkingHint, + lhs_proof: &CollaborativeProof, + rhs_proof: &CollaborativeProof, + layout: &GroupLayout, + fabric: &MpcFabric, + ) -> Result<(), PlonkError> { + let (commit_key, open_key) = gen_commit_keys(); + let proof = MultiproverPlonkKzgSnark::::link_proofs( + lhs_hint, + rhs_hint, + layout, + &commit_key, + fabric, + )?; + + let opened_link = proof.open_authenticated().await?; + let lhs_proof = lhs_proof.clone().open_authenticated().await?; + let rhs_proof = rhs_proof.clone().open_authenticated().await?; + + PlonkKzgSnark::::verify_link_proof::( + &lhs_proof, + &rhs_proof, + &opened_link, + layout, + &open_key, + ) + } + + // -------------- + // | Test Cases | + // -------------- + + /// Test the basic case of a valid link on two circuits + #[tokio::test] + async fn test_valid_link() { + let mut rng = thread_rng(); + let witness = (0..WITNESS_ELEMS).map(|_| Scalar::random(&mut rng)).collect_vec(); + + // Generate a proof and link in an MPC + let (res, _) = execute_mock_mpc(move |fabric| { + let witness = witness.clone(); + async move { + let witness = fabric.batch_share_scalar(witness, PARTY0); + + // Generate r1cs proofs for the two circuits + let (lhs_proof, lhs_hint, layout) = + gen_circuit_proof_and_hint(&witness, CircuitSelector::Circuit1, None, &fabric); + let (rhs_proof, rhs_hint, _) = gen_circuit_proof_and_hint( + &witness, + CircuitSelector::Circuit2, + Some(layout), + &fabric, + ); + + // Prove and verify the link + prove_and_verify_link( + &lhs_hint, &rhs_hint, &lhs_proof, &rhs_proof, &layout, &fabric, + ) + .await + } + }) + .await; + + assert!(res.is_ok()); + } + + /// Tests a valid link with a layout specified up front + #[tokio::test] + #[allow(non_snake_case)] + async fn test_valid_link__specific_layout() { + let mut rng = thread_rng(); + let witness = (0..WITNESS_ELEMS).map(|_| Scalar::random(&mut rng)).collect_vec(); + + // Generate a proof and link in an MPC + let (res, _) = execute_mock_mpc(move |fabric| { + let witness = witness.clone(); + async move { + let witness = fabric.batch_share_scalar(witness, PARTY0); + + // Generate r1cs proofs for the two circuits + let layout = GroupLayout { offset: 20, size: WITNESS_ELEMS, alignment: 8 }; + let (lhs_proof, lhs_hint, layout) = gen_circuit_proof_and_hint( + &witness, + CircuitSelector::Circuit1, + Some(layout), + &fabric, + ); + let (rhs_proof, rhs_hint, _) = gen_circuit_proof_and_hint( + &witness, + CircuitSelector::Circuit2, + Some(layout), + &fabric, + ); + + // Prove and verify the link + prove_and_verify_link( + &lhs_hint, &rhs_hint, &lhs_proof, &rhs_proof, &layout, &fabric, + ) + .await + } + }) + .await; + + assert!(res.is_ok()); + } + + /// Tests an invalid proof link wherein the witnesses used are different + #[tokio::test] + #[allow(non_snake_case)] + async fn test_invalid_proof_link__different_witnesses() { + let mut rng = thread_rng(); + + // Modify the second witness at a random location + let witness1 = (0..WITNESS_ELEMS).map(|_| Scalar::random(&mut rng)).collect_vec(); + let mut witness2 = witness1.clone(); + let modification_idx = rng.gen_range(0..WITNESS_ELEMS); + witness2[modification_idx] = Scalar::random(&mut rng); + + let (res, _) = execute_mock_mpc(move |fabric| { + let witness1 = witness1.clone(); + let witness2 = witness2.clone(); + + async move { + let witness1 = fabric.batch_share_scalar(witness1, PARTY0); + let witness2 = fabric.batch_share_scalar(witness2, PARTY0); + + // Generate r1cs proofs for the two circuits + let (lhs_proof, lhs_hint, layout) = + gen_circuit_proof_and_hint(&witness1, CircuitSelector::Circuit1, None, &fabric); + let (rhs_proof, rhs_hint, _) = gen_circuit_proof_and_hint( + &witness2, + CircuitSelector::Circuit2, + Some(layout), + &fabric, + ); + + // Prove and verify the link + prove_and_verify_link( + &lhs_hint, &rhs_hint, &lhs_proof, &rhs_proof, &layout, &fabric, + ) + .await + } + }) + .await; + + assert!(res.is_err()); + } + + /// Tests the case in which the correct witness is used to link but over + /// incorrectly aligned domains + #[tokio::test] + #[allow(non_snake_case)] + async fn test_invalid_proof_link__wrong_alignment() { + // Use the same witness between two circuits + let mut rng = thread_rng(); + let witness = (0..WITNESS_ELEMS).map(|_| Scalar::random(&mut rng)).collect_vec(); + + let (res, _) = execute_mock_mpc(move |fabric| { + let witness = witness.clone(); + async move { + let witness = fabric.batch_share_scalar(witness, PARTY0); + + // Generate r1cs proofs for the two circuits + let (lhs_proof, lhs_hint, mut layout) = gen_circuit_proof_and_hint( + &witness, + CircuitSelector::Circuit1, + None, // layout + &fabric, + ); + + // Modify the layout to be misaligned + layout.alignment += 1; + let (rhs_proof, rhs_hint, _) = gen_circuit_proof_and_hint( + &witness, + CircuitSelector::Circuit2, + Some(layout), + &fabric, + ); + + // Prove and verify the link + prove_and_verify_link( + &lhs_hint, &rhs_hint, &lhs_proof, &rhs_proof, &layout, &fabric, + ) + .await + } + }) + .await; + + assert!(res.is_err()); + } + + /// Tests the case in which the correct witness is used to link but over + /// domains at different offsets + #[tokio::test] + #[allow(non_snake_case)] + async fn test_invalid_proof_link__wrong_offset() { + // Use the same witness between two circuits + let mut rng = thread_rng(); + let witness = (0..WITNESS_ELEMS).map(|_| Scalar::random(&mut rng)).collect_vec(); + + let (res, _) = execute_mock_mpc(move |fabric| { + let witness = witness.clone(); + async move { + let witness = fabric.batch_share_scalar(witness, PARTY0); + + // Generate r1cs proofs for the two circuits + let (lhs_proof, lhs_hint, mut layout) = gen_circuit_proof_and_hint( + &witness, + CircuitSelector::Circuit1, + None, // layout + &fabric, + ); + + // Modify the layout to be misaligned + layout.offset -= 1; + let (rhs_proof, rhs_hint, _) = gen_circuit_proof_and_hint( + &witness, + CircuitSelector::Circuit2, + Some(layout), + &fabric, + ); + + // Prove and verify the link + prove_and_verify_link( + &lhs_hint, &rhs_hint, &lhs_proof, &rhs_proof, &layout, &fabric, + ) + .await + } + }) + .await; + + assert!(res.is_err()); + } +} diff --git a/plonk/src/proof_system/proof_linking.rs b/plonk/src/proof_system/proof_linking.rs index 37c385b80..3fb7aaa65 100644 --- a/plonk/src/proof_system/proof_linking.rs +++ b/plonk/src/proof_system/proof_linking.rs @@ -257,38 +257,37 @@ where } #[cfg(test)] -mod test { +pub mod test_helpers { + //! Helpers exported for proof-linking tests in the `plonk` crate + use ark_bn254::{Bn254, Fr as FrBn254}; - use ark_std::UniformRand; use itertools::Itertools; use jf_primitives::pcs::StructuredReferenceString; use lazy_static::lazy_static; use mpc_relation::{ proof_linking::{GroupLayout, LinkableCircuit}, - traits::Circuit, PlonkCircuit, }; - use rand::{thread_rng, Rng, SeedableRng}; + use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha20Rng; - use crate::{ - errors::PlonkError, - proof_system::{ - structs::{CommitKey, LinkingHint, OpenKey, Proof, ProvingKey, VerifyingKey}, - PlonkKzgSnark, UniversalSNARK, - }, - transcript::SolidityTranscript, + use crate::proof_system::{ + structs::{CommitKey, OpenKey, ProvingKey, VerifyingKey}, + PlonkKzgSnark, UniversalSNARK, }; - /// The name of the group used in testing - const GROUP_NAME: &str = "test_group"; + /// The field used in testing + pub type TestField = FrBn254; /// The maximum circuit degree used in testing - const MAX_DEGREE_TESTING: usize = 1000; + pub const MAX_DEGREE_TESTING: usize = 1000; + /// The name of the group used in testing + pub const GROUP_NAME: &str = "test_group"; + lazy_static! { /// The rng seed used to generate the test circuit's SRS /// /// We generate it once at setup to ensure that the SRS is the same between all constructions - static ref SRS_SEED: u64 = rand::thread_rng().gen(); + pub static ref SRS_SEED: u64 = rand::thread_rng().gen(); } // ----------------- @@ -299,7 +298,7 @@ mod test { /// /// Aids in the ergonomics of templating tests below #[derive(Copy, Clone)] - enum CircuitSelector { + pub enum CircuitSelector { /// The first circuit Circuit1, /// The second circuit @@ -307,79 +306,132 @@ mod test { } /// Generate a summation circuit with the given witness - fn gen_test_circuit1( - witness: &[FrBn254], + pub fn gen_test_circuit1>( + c: &mut C, + witness: &[C::Wire], layout: Option, - ) -> PlonkCircuit { - let mut rng = thread_rng(); - let mut circuit = PlonkCircuit::new_turbo_plonk(); - - // Add a few non-linked witnesses to the sum - let non_linked_witnesses = (0..10).map(|_| FrBn254::rand(&mut rng)).collect_vec(); - let sum = witness.iter().chain(non_linked_witnesses.iter()).cloned().sum::(); - let expected = circuit.create_public_variable(sum).unwrap(); + ) { + let sum = witness.iter().cloned().sum::(); + let expected = c.create_public_variable(sum).unwrap(); // Add a few public inputs to the circuit - for _ in 0..10 { - circuit.create_public_variable(FrBn254::rand(&mut rng)).unwrap(); + // Simplest just to re-use the witness values directly as these are purely for + // spacing + for val in witness.iter() { + c.create_public_variable(val.clone()).unwrap(); } - // Create a proof linking group and add the witnesses to it - let group = circuit.create_link_group(GROUP_NAME.to_string(), layout); - let mut witness_vars = witness + // Link the witnesses to the group + let group = c.create_link_group(GROUP_NAME.to_string(), layout); + let witness_vars = witness .iter() - .map(|&w| circuit.create_variable_with_link_groups(w, &[group.clone()]).unwrap()) + .map(|w| c.create_variable_with_link_groups(w.clone(), &[group.clone()]).unwrap()) .collect_vec(); - witness_vars.extend( - &mut non_linked_witnesses.into_iter().map(|w| circuit.create_variable(w).unwrap()), - ); - let sum = circuit.sum(&witness_vars).unwrap(); - circuit.enforce_equal(sum, expected).unwrap(); - circuit.finalize_for_arithmetization().unwrap(); + // Create a few more witnesses that are not linked + // Again, for interface simplicity just reuse the witness data to generate new + // values + witness.iter().map(|w| c.create_variable(w.clone() * w.clone()).unwrap()).collect_vec(); - circuit + let sum = c.sum(&witness_vars).unwrap(); + c.enforce_equal(sum, expected).unwrap(); } /// Generate a product circuit with the given witness - fn gen_test_circuit2( - witness: &[FrBn254], + pub fn gen_test_circuit2>( + c: &mut C, + witness: &[C::Wire], layout: Option, - ) -> PlonkCircuit { - let mut rng = thread_rng(); - let mut circuit = PlonkCircuit::new_turbo_plonk(); - - // Add a few non-linked witnesses to the product - let non_linked_witnesses = (0..10).map(|_| FrBn254::rand(&mut rng)).collect_vec(); - let product = - witness.iter().chain(non_linked_witnesses.iter()).copied().product::(); - let expected = circuit.create_public_variable(product).unwrap(); + ) { + // Compute the expected result + let mut product = witness[0].clone(); + for w in witness[1..].iter().cloned() { + product = product * w; + } + let expected = c.create_public_variable(product).unwrap(); // Add a few public inputs to the circuit - for _ in 0..10 { - circuit.create_public_variable(FrBn254::rand(&mut rng)).unwrap(); + // Simplest just to re-use the witness values directly as these are purely for + // spacing + for val in witness.iter() { + c.create_public_variable(val.clone()).unwrap(); } - // Create a link group with the placement - let group = circuit.create_link_group(GROUP_NAME.to_string(), layout); - let mut witness_vars = witness + // Link half the witnesses to the group + let group = c.create_link_group(GROUP_NAME.to_string(), layout); + let witness_vars = witness .iter() - .map(|w| circuit.create_variable_with_link_groups(*w, &[group.clone()]).unwrap()) + .map(|w| c.create_variable_with_link_groups(w.clone(), &[group.clone()]).unwrap()) .collect_vec(); - witness_vars.extend( - &mut non_linked_witnesses.into_iter().map(|w| circuit.create_variable(w).unwrap()), - ); + + // Create a few more witnesses that are not linked + // Again, for interface simplicity just reuse the witness data to generate new + // values + witness.iter().map(|w| c.create_variable(w.clone() * w.clone()).unwrap()).collect_vec(); // Constrain the product - let mut product = circuit.one(); + let mut product = c.one(); for var in &witness_vars { - product = circuit.mul(product, *var).unwrap(); + product = c.mul(product, *var).unwrap(); } - circuit.enforce_equal(product, expected).unwrap(); - circuit.finalize_for_arithmetization().unwrap(); + c.enforce_equal(product, expected).unwrap(); + } - circuit + /// Setup proving and verifying keys for a test circuit + pub fn gen_proving_keys( + circuit: &PlonkCircuit, + ) -> (ProvingKey, VerifyingKey) { + let mut rng = ChaCha20Rng::seed_from_u64(*SRS_SEED); + let srs = PlonkKzgSnark::::universal_setup_for_testing(MAX_DEGREE_TESTING, &mut rng) + .unwrap(); + + PlonkKzgSnark::::preprocess(&srs, circuit).unwrap() + } + + /// Generate commitment keys for a KZG commitment + /// + /// This is done separately from the proving key to allow helpers to + /// generate circuit-agnostic keys + pub fn gen_commit_keys() -> (CommitKey, OpenKey) { + let mut rng = ChaCha20Rng::seed_from_u64(*SRS_SEED); + let srs = PlonkKzgSnark::::universal_setup_for_testing(MAX_DEGREE_TESTING, &mut rng) + .unwrap(); + + ( + srs.extract_prover_param(MAX_DEGREE_TESTING), + srs.extract_verifier_param(MAX_DEGREE_TESTING), + ) } +} + +#[cfg(test)] +mod test { + use ark_bn254::{Bn254, Fr as FrBn254}; + use ark_std::UniformRand; + use itertools::Itertools; + use mpc_relation::{ + proof_linking::{GroupLayout, LinkableCircuit}, + PlonkCircuit, + }; + use rand::{thread_rng, Rng}; + + use crate::{ + errors::PlonkError, + proof_system::{ + structs::{LinkingHint, Proof}, + PlonkKzgSnark, + }, + transcript::SolidityTranscript, + }; + + use super::test_helpers::{ + gen_commit_keys, gen_proving_keys, gen_test_circuit1, gen_test_circuit2, CircuitSelector, + GROUP_NAME, + }; + + // ----------- + // | Helpers | + // ----------- /// Generate a test case proof, group layout, and link hint from the given /// circuit @@ -388,24 +440,21 @@ mod test { circuit: CircuitSelector, layout: Option, ) -> (Proof, LinkingHint, GroupLayout) { - let circuit = match circuit { - CircuitSelector::Circuit1 => gen_test_circuit1(witness, layout), - CircuitSelector::Circuit2 => gen_test_circuit2(witness, layout), + let mut cs = PlonkCircuit::new_turbo_plonk(); + match circuit { + CircuitSelector::Circuit1 => gen_test_circuit1(&mut cs, witness, layout), + CircuitSelector::Circuit2 => gen_test_circuit2(&mut cs, witness, layout), }; + cs.finalize_for_arithmetization().unwrap(); // Get the layout - let group_layout = circuit.get_link_group_layout(GROUP_NAME).unwrap(); + let group_layout = cs.get_link_group_layout(GROUP_NAME).unwrap(); // Generate a proof with a linking hint - let (proof, hint) = gen_test_proof(&circuit); - + let (proof, hint) = gen_test_proof(&cs); (proof, hint, group_layout) } - // ----------- - // | Helpers | - // ----------- - /// Generate a proof and link hint for the circuit by proving its r1cs /// relation fn gen_test_proof(circuit: &PlonkCircuit) -> (Proof, LinkingHint) { @@ -418,32 +467,6 @@ mod test { .unwrap() } - /// Setup proving and verifying keys for a test circuit - fn gen_proving_keys( - circuit: &PlonkCircuit, - ) -> (ProvingKey, VerifyingKey) { - let mut rng = ChaCha20Rng::seed_from_u64(*SRS_SEED); - let srs = PlonkKzgSnark::::universal_setup_for_testing(MAX_DEGREE_TESTING, &mut rng) - .unwrap(); - - PlonkKzgSnark::::preprocess(&srs, circuit).unwrap() - } - - /// Generate commitment keys for a KZG commitment - /// - /// This is done separately from the proving key to allow helpers to - /// generate circuit-agnostic keys - fn gen_commit_keys() -> (CommitKey, OpenKey) { - let mut rng = ChaCha20Rng::seed_from_u64(*SRS_SEED); - let srs = PlonkKzgSnark::::universal_setup_for_testing(MAX_DEGREE_TESTING, &mut rng) - .unwrap(); - - ( - srs.extract_prover_param(MAX_DEGREE_TESTING), - srs.extract_verifier_param(MAX_DEGREE_TESTING), - ) - } - /// Prove a link between two circuits and verify the link, return the result /// as a result fn prove_and_verify_link( diff --git a/relation/src/proof_linking/mod.rs b/relation/src/proof_linking/mod.rs index d1269eaf8..bf2c04ac2 100644 --- a/relation/src/proof_linking/mod.rs +++ b/relation/src/proof_linking/mod.rs @@ -69,6 +69,14 @@ pub struct CircuitLayout { } impl CircuitLayout { + /// Get the layout for a given group + /// + /// # Panics + /// Panics if the group does not exist + pub fn get_group_layout(&self, id: &str) -> GroupLayout { + self.group_layouts[id] + } + /// Get the domain size used to represent the circuit after proof linking /// gates are accounted for pub fn circuit_size(&self) -> usize {