diff --git a/plonk/src/multiprover/proof_system/constraint_system.rs b/plonk/src/multiprover/proof_system/constraint_system.rs index c03ae504d..ddc72c049 100644 --- a/plonk/src/multiprover/proof_system/constraint_system.rs +++ b/plonk/src/multiprover/proof_system/constraint_system.rs @@ -2,17 +2,20 @@ //! MPC-enabled arithmetic use ark_ec::CurveGroup; -use ark_ff::FftField; +use ark_ff::{FftField, Zero}; use ark_mpc::{ algebra::{AuthenticatedDensePoly, AuthenticatedScalarResult, Scalar, ScalarResult}, + gadgets::prefix_product, MpcFabric, }; -use ark_poly::{univariate::DensePolynomial, EvaluationDomain, Radix2EvaluationDomain}; +use ark_poly::{ + univariate::DensePolynomial, DenseUVPolynomial, EvaluationDomain, Radix2EvaluationDomain, +}; use async_trait::async_trait; use futures::future; use itertools::Itertools; use jf_relation::{ - constants::{GATE_WIDTH, N_MUL_SELECTORS}, + constants::{compute_coset_representatives, GATE_WIDTH, N_MUL_SELECTORS}, errors::CircuitError, gates::{ AdditionGate, BoolGate, ConstantGate, EqualityGate, Gate, IoGate, MultiplicationGate, @@ -20,6 +23,8 @@ use jf_relation::{ }, GateId, Variable, WireId, }; +use jf_utils::par_utils::parallelizable_slice_iter; +use rayon::prelude::*; use super::MpcCircuitError; @@ -227,7 +232,7 @@ where /// Return an error if the circuit has not been finalized yet. fn compute_extended_permutation_polynomials( &self, - ) -> Result>, MpcCircuitError>; + ) -> Result>, MpcCircuitError>; /// Compute and return the product polynomial for permutation arguments. /// Return an error if the circuit has not been finalized yet. @@ -386,6 +391,7 @@ where /// You should absolutely sure about what you are doing /// You should normally only use this API if you already enforce `v` to be a /// boolean value using other constraints. + #[allow(unused)] pub(crate) fn create_boolean_variable_unchecked( &mut self, a: AuthenticatedScalarResult, @@ -485,7 +491,7 @@ impl MpcPlonkCircuit { // Compute the gate output let expected_gate_output = - pub_input + self.compute_gate_output(q_lc, q_mul, q_hash, q_ecc, q_c, q_o, &w_vals); + pub_input + self.compute_gate_output(q_lc, q_mul, q_hash, q_ecc, q_c, &w_vals); let gate_output = q_o * w_vals[4]; (gate_output - expected_gate_output).open() @@ -496,6 +502,7 @@ impl MpcPlonkCircuit { /// This method differs from the single prover case because multiplication /// induces a significantly higher overhead. So we only compute the /// output of a gate if its selector is non-zero + #[allow(clippy::too_many_arguments)] fn compute_gate_output( &self, q_lc: Vec>, @@ -503,7 +510,6 @@ impl MpcPlonkCircuit { q_hash: Vec>, q_ecc: Scalar, q_c: Scalar, - q_o: Scalar, wire_values: &[&AuthenticatedScalarResult], ) -> AuthenticatedScalarResult { let mut res = self.fabric.zero_authenticated(); @@ -931,3 +937,233 @@ impl MpcCircuit for MpcPlonkCircuit { } } } + +/// Private permutation related methods +impl MpcPlonkCircuit { + /// Copy constraints: precompute the extended permutation over circuit + /// wires. Refer to Sec 5.2 and Sec 8.1 of https://eprint.iacr.org/2019/953.pdf for more details. + #[inline] + fn compute_extended_id_permutation(&mut self) { + assert!(self.is_finalized()); + let n = self.eval_domain.size(); + + // Compute the extended identity permutation + // id[i*n+j] = k[i] * g^j + let k: Vec = compute_coset_representatives(self.num_wire_types, Some(n)); + + // Precompute domain elements + let group_elems: Vec = self.eval_domain.elements().collect(); + + // Compute extended identity permutation + self.extended_id_permutation = vec![C::ScalarField::zero(); self.num_wire_types * n]; + for (i, &coset_repr) in k.iter().enumerate() { + for (j, &group_elem) in group_elems.iter().enumerate() { + self.extended_id_permutation[i * n + j] = coset_repr * group_elem; + } + } + } + + #[inline] + fn compute_extended_permutation(&self) -> Result, MpcCircuitError> { + assert!(self.is_finalized()); + let n = self.eval_domain.size(); + + // The extended wire permutation can be computed as + // extended_perm[i] = id[wire_perm[i].into() * n + wire_perm[i].1] + let extended_perm: Vec = self + .wire_permutation + .iter() + .map(|&(wire_id, gate_id)| { + // if permutation value undefined, return 0 + if wire_id >= self.num_wire_types { + C::ScalarField::zero() + } else { + self.extended_id_permutation[wire_id * n + gate_id] + } + }) + .collect(); + + if extended_perm.len() != self.num_wire_types * n { + return Err(MpcCircuitError::ConstraintSystem( + CircuitError::ParameterError( + "Length of the extended permutation vector should be number of gate \ + (including padded dummy gates) * number of wire types" + .to_string(), + ), + )); + } + Ok(extended_perm) + } +} + +/// Finalization +impl MpcPlonkCircuit { + /// Finalize the setup of the circuit before arithmetization. + pub fn finalize_for_arithmetization(&mut self) -> Result<(), MpcCircuitError> { + if self.is_finalized() { + return Ok(()); + } + + self.eval_domain = Radix2EvaluationDomain::new(self.num_gates()) + .ok_or(CircuitError::DomainCreationError) + .map_err(MpcCircuitError::ConstraintSystem)?; + self.pad()?; + self.rearrange_gates()?; + self.compute_wire_permutation(); + self.compute_extended_id_permutation(); + Ok(()) + } +} + +impl MpcArithmetization for MpcPlonkCircuit { + fn srs_size(&self) -> Result { + Ok(self.eval_domain_size()? + 2) + } + + fn eval_domain_size(&self) -> Result { + self.check_finalize_flag(true)?; + Ok(self.eval_domain.size()) + } + + fn compute_selector_polynomials( + &self, + ) -> Result::ScalarField>>, MpcCircuitError> { + self.check_finalize_flag(true)?; + let domain = &self.eval_domain; + if domain.size() < self.num_gates() { + return Err(MpcCircuitError::ConstraintSystem( + CircuitError::ParameterError( + "Domain size should be bigger than number of constraint".to_string(), + ), + )); + } + + // Order: (lc, mul, hash, o, c, ecc) as specified in spec + let selector_polys = parallelizable_slice_iter(&self.all_selectors()) + .map(|selector| DensePolynomial::from_coefficients_vec(domain.ifft(selector))) + .collect(); + Ok(selector_polys) + } + + fn compute_extended_permutation_polynomials( + &self, + ) -> Result>, MpcCircuitError> { + self.check_finalize_flag(true)?; + let domain = &self.eval_domain; + let n = domain.size(); + let extended_perm = self.compute_extended_permutation()?; + + let extended_perm_polys: Vec> = + parallelizable_slice_iter(&(0..self.num_wire_types).collect::>()) // current par_utils only support slice iterator, not range iterator. + .map(|i| { + DensePolynomial::from_coefficients_vec( + domain.ifft(&extended_perm[i * n..(i + 1) * n]), + ) + }) + .collect(); + + Ok(extended_perm_polys) + } + + fn compute_prod_permutation_polynomial( + &self, + beta: &::ScalarField, + gamma: &::ScalarField, + ) -> Result, MpcCircuitError> { + self.check_finalize_flag(true)?; + let n = self.eval_domain.size(); + + let gamma = Scalar::new(*gamma); + let beta = Scalar::new(*beta); + let one = self.fabric.one_authenticated(); + + let mut numerators = Vec::with_capacity(self.num_wire_types() * (n - 1)); + let mut denominators = Vec::with_capacity(self.num_wire_types() * (n - 1)); + + for j in 0..(n - 1) { + // Numerator + let mut a = one.clone(); + // Denominator + let mut b = one.clone(); + + for i in 0..self.num_wire_types() { + let wire_value = &self.witness[self.wire_variable(i, j)]; + let tmp = wire_value + gamma; + a = a * (&tmp + beta * Scalar::new(self.extended_id_permutation[i * n + j])); + + let (perm_i, perm_j) = self.wire_permutation[i * n + j]; + b = b + * (tmp + beta + Scalar::new(self.extended_id_permutation[perm_i * n + perm_j])); + } + + numerators.push(a); + denominators.push(b); + } + + // Divide the numerators and denominators, create a prefix product of this + // division, and then convert into a polynomial from evaluation form + let div_res = AuthenticatedScalarResult::batch_div(&numerators, &denominators); + let products = prefix_product(&div_res, &self.fabric); + + // The last element of this product is one for a valid proof, we match the + // single-prover implementation and put this first in the resulting + // prefix product vec + let product_vec = [vec![one], products].concat(); + let coeffs = + AuthenticatedScalarResult::ifft::>(&product_vec); + + Ok(AuthenticatedDensePoly::from_coeffs(coeffs)) + } + + fn compute_wire_polynomials(&self) -> Result>, MpcCircuitError> { + self.check_finalize_flag(true)?; + let domain = &self.eval_domain; + if domain.size() < self.num_gates() { + return Err(MpcCircuitError::ConstraintSystem( + CircuitError::ParameterError(format!( + "Domain size {} should be bigger than number of constraint {}", + domain.size(), + self.num_gates() + )), + )); + } + + let witness = &self.witness; + let wire_polys: Vec> = + parallelizable_slice_iter(&self.wire_variables) + .take(self.num_wire_types()) + .map(|wire_vars| { + let wire_vec: Vec> = wire_vars + .iter() + .map(|&var| witness[var].clone()) + .collect_vec(); + + let coeffs = AuthenticatedScalarResult::ifft::< + Radix2EvaluationDomain, + >(&wire_vec); + + AuthenticatedDensePoly::from_coeffs(coeffs) + }) + .collect(); + + assert_eq!(wire_polys.len(), self.num_wire_types()); + Ok(wire_polys) + } + + fn compute_pub_input_polynomial(&self) -> Result, MpcCircuitError> { + self.check_finalize_flag(true)?; + + let domain = &self.eval_domain; + let mut pub_input_vec = self.fabric.zeros_authenticated(domain.size()); + + self.pub_input_gate_ids.iter().for_each(|&io_gate_id| { + let var = self.wire_variables[GATE_WIDTH][io_gate_id]; + pub_input_vec[io_gate_id] = self.witness[var].clone(); + }); + + let coeffs = AuthenticatedScalarResult::ifft::>( + &pub_input_vec, + ); + Ok(AuthenticatedDensePoly::from_coeffs(coeffs)) + } +} diff --git a/relation/src/constraint_system.rs b/relation/src/constraint_system.rs index 4e8796209..82d817a1d 100644 --- a/relation/src/constraint_system.rs +++ b/relation/src/constraint_system.rs @@ -1119,8 +1119,10 @@ impl PlonkCircuit { // Compute the extended identity permutation // id[i*n+j] = k[i] * g^j let k: Vec = compute_coset_representatives(self.num_wire_types, Some(n)); + // Precompute domain elements let group_elems: Vec = self.eval_domain.elements().collect(); + // Compute extended identity permutation self.extended_id_permutation = vec![F::zero(); self.num_wire_types * n]; for (i, &coset_repr) in k.iter().enumerate() {