Skip to content

Commit

Permalink
plonk: multiprover: constraint-system: Implement MpcArithmetization
Browse files Browse the repository at this point in the history
… for `MpcPlonkCircuit`
  • Loading branch information
joeykraut committed Oct 22, 2023
1 parent 0f47cc0 commit b65fae7
Show file tree
Hide file tree
Showing 2 changed files with 244 additions and 6 deletions.
248 changes: 242 additions & 6 deletions plonk/src/multiprover/proof_system/constraint_system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,29 @@
//! MPC-enabled arithmetic
use ark_ec::CurveGroup;
use ark_ff::FftField;
use ark_ff::{FftField, Zero};
use ark_mpc::{
algebra::{AuthenticatedDensePoly, AuthenticatedScalarResult, Scalar, ScalarResult},
gadgets::prefix_product,
MpcFabric,
};
use ark_poly::{univariate::DensePolynomial, EvaluationDomain, Radix2EvaluationDomain};
use ark_poly::{
univariate::DensePolynomial, DenseUVPolynomial, EvaluationDomain, Radix2EvaluationDomain,
};
use async_trait::async_trait;
use futures::future;
use itertools::Itertools;
use jf_relation::{
constants::{GATE_WIDTH, N_MUL_SELECTORS},
constants::{compute_coset_representatives, GATE_WIDTH, N_MUL_SELECTORS},
errors::CircuitError,
gates::{
AdditionGate, BoolGate, ConstantGate, EqualityGate, Gate, IoGate, MultiplicationGate,
PaddingGate, SubtractionGate,
},
GateId, Variable, WireId,
};
use jf_utils::par_utils::parallelizable_slice_iter;
use rayon::prelude::*;

use super::MpcCircuitError;

Expand Down Expand Up @@ -227,7 +232,7 @@ where
/// Return an error if the circuit has not been finalized yet.
fn compute_extended_permutation_polynomials(
&self,
) -> Result<Vec<AuthenticatedDensePoly<C>>, MpcCircuitError>;
) -> Result<Vec<DensePolynomial<C::ScalarField>>, MpcCircuitError>;

/// Compute and return the product polynomial for permutation arguments.
/// Return an error if the circuit has not been finalized yet.
Expand Down Expand Up @@ -386,6 +391,7 @@ where
/// You should absolutely sure about what you are doing
/// You should normally only use this API if you already enforce `v` to be a
/// boolean value using other constraints.
#[allow(unused)]
pub(crate) fn create_boolean_variable_unchecked(
&mut self,
a: AuthenticatedScalarResult<C>,
Expand Down Expand Up @@ -485,7 +491,7 @@ impl<C: CurveGroup> MpcPlonkCircuit<C> {

// Compute the gate output
let expected_gate_output =
pub_input + self.compute_gate_output(q_lc, q_mul, q_hash, q_ecc, q_c, q_o, &w_vals);
pub_input + self.compute_gate_output(q_lc, q_mul, q_hash, q_ecc, q_c, &w_vals);
let gate_output = q_o * w_vals[4];

(gate_output - expected_gate_output).open()
Expand All @@ -496,14 +502,14 @@ impl<C: CurveGroup> MpcPlonkCircuit<C> {
/// This method differs from the single prover case because multiplication
/// induces a significantly higher overhead. So we only compute the
/// output of a gate if its selector is non-zero
#[allow(clippy::too_many_arguments)]
fn compute_gate_output(
&self,
q_lc: Vec<Scalar<C>>,
q_mul: Vec<Scalar<C>>,
q_hash: Vec<Scalar<C>>,
q_ecc: Scalar<C>,
q_c: Scalar<C>,
q_o: Scalar<C>,
wire_values: &[&AuthenticatedScalarResult<C>],
) -> AuthenticatedScalarResult<C> {
let mut res = self.fabric.zero_authenticated();
Expand Down Expand Up @@ -931,3 +937,233 @@ impl<C: CurveGroup> MpcCircuit<C> for MpcPlonkCircuit<C> {
}
}
}

/// Private permutation related methods
impl<C: CurveGroup> MpcPlonkCircuit<C> {
/// Copy constraints: precompute the extended permutation over circuit
/// wires. Refer to Sec 5.2 and Sec 8.1 of https://eprint.iacr.org/2019/953.pdf for more details.
#[inline]
fn compute_extended_id_permutation(&mut self) {
assert!(self.is_finalized());
let n = self.eval_domain.size();

// Compute the extended identity permutation
// id[i*n+j] = k[i] * g^j
let k: Vec<C::ScalarField> = compute_coset_representatives(self.num_wire_types, Some(n));

// Precompute domain elements
let group_elems: Vec<C::ScalarField> = self.eval_domain.elements().collect();

// Compute extended identity permutation
self.extended_id_permutation = vec![C::ScalarField::zero(); self.num_wire_types * n];
for (i, &coset_repr) in k.iter().enumerate() {
for (j, &group_elem) in group_elems.iter().enumerate() {
self.extended_id_permutation[i * n + j] = coset_repr * group_elem;
}
}
}

#[inline]
fn compute_extended_permutation(&self) -> Result<Vec<C::ScalarField>, MpcCircuitError> {
assert!(self.is_finalized());
let n = self.eval_domain.size();

// The extended wire permutation can be computed as
// extended_perm[i] = id[wire_perm[i].into() * n + wire_perm[i].1]
let extended_perm: Vec<C::ScalarField> = self
.wire_permutation
.iter()
.map(|&(wire_id, gate_id)| {
// if permutation value undefined, return 0
if wire_id >= self.num_wire_types {
C::ScalarField::zero()
} else {
self.extended_id_permutation[wire_id * n + gate_id]
}
})
.collect();

if extended_perm.len() != self.num_wire_types * n {
return Err(MpcCircuitError::ConstraintSystem(
CircuitError::ParameterError(
"Length of the extended permutation vector should be number of gate \
(including padded dummy gates) * number of wire types"
.to_string(),
),
));
}
Ok(extended_perm)
}
}

/// Finalization
impl<C: CurveGroup> MpcPlonkCircuit<C> {
/// Finalize the setup of the circuit before arithmetization.
pub fn finalize_for_arithmetization(&mut self) -> Result<(), MpcCircuitError> {
if self.is_finalized() {
return Ok(());
}

self.eval_domain = Radix2EvaluationDomain::new(self.num_gates())
.ok_or(CircuitError::DomainCreationError)
.map_err(MpcCircuitError::ConstraintSystem)?;
self.pad()?;
self.rearrange_gates()?;
self.compute_wire_permutation();
self.compute_extended_id_permutation();
Ok(())
}
}

impl<C: CurveGroup> MpcArithmetization<C> for MpcPlonkCircuit<C> {
fn srs_size(&self) -> Result<usize, MpcCircuitError> {
Ok(self.eval_domain_size()? + 2)
}

fn eval_domain_size(&self) -> Result<usize, MpcCircuitError> {
self.check_finalize_flag(true)?;
Ok(self.eval_domain.size())
}

fn compute_selector_polynomials(
&self,
) -> Result<Vec<DensePolynomial<<C>::ScalarField>>, MpcCircuitError> {
self.check_finalize_flag(true)?;
let domain = &self.eval_domain;
if domain.size() < self.num_gates() {
return Err(MpcCircuitError::ConstraintSystem(
CircuitError::ParameterError(
"Domain size should be bigger than number of constraint".to_string(),
),
));
}

// Order: (lc, mul, hash, o, c, ecc) as specified in spec
let selector_polys = parallelizable_slice_iter(&self.all_selectors())
.map(|selector| DensePolynomial::from_coefficients_vec(domain.ifft(selector)))
.collect();
Ok(selector_polys)
}

fn compute_extended_permutation_polynomials(
&self,
) -> Result<Vec<DensePolynomial<C::ScalarField>>, MpcCircuitError> {
self.check_finalize_flag(true)?;
let domain = &self.eval_domain;
let n = domain.size();
let extended_perm = self.compute_extended_permutation()?;

let extended_perm_polys: Vec<DensePolynomial<C::ScalarField>> =
parallelizable_slice_iter(&(0..self.num_wire_types).collect::<Vec<_>>()) // current par_utils only support slice iterator, not range iterator.
.map(|i| {
DensePolynomial::from_coefficients_vec(
domain.ifft(&extended_perm[i * n..(i + 1) * n]),
)
})
.collect();

Ok(extended_perm_polys)
}

fn compute_prod_permutation_polynomial(
&self,
beta: &<C>::ScalarField,
gamma: &<C>::ScalarField,
) -> Result<AuthenticatedDensePoly<C>, MpcCircuitError> {
self.check_finalize_flag(true)?;
let n = self.eval_domain.size();

let gamma = Scalar::new(*gamma);
let beta = Scalar::new(*beta);
let one = self.fabric.one_authenticated();

let mut numerators = Vec::with_capacity(self.num_wire_types() * (n - 1));
let mut denominators = Vec::with_capacity(self.num_wire_types() * (n - 1));

for j in 0..(n - 1) {
// Numerator
let mut a = one.clone();
// Denominator
let mut b = one.clone();

for i in 0..self.num_wire_types() {
let wire_value = &self.witness[self.wire_variable(i, j)];
let tmp = wire_value + gamma;
a = a * (&tmp + beta * Scalar::new(self.extended_id_permutation[i * n + j]));

let (perm_i, perm_j) = self.wire_permutation[i * n + j];
b = b
* (tmp + beta + Scalar::new(self.extended_id_permutation[perm_i * n + perm_j]));
}

numerators.push(a);
denominators.push(b);
}

// Divide the numerators and denominators, create a prefix product of this
// division, and then convert into a polynomial from evaluation form
let div_res = AuthenticatedScalarResult::batch_div(&numerators, &denominators);
let products = prefix_product(&div_res, &self.fabric);

// The last element of this product is one for a valid proof, we match the
// single-prover implementation and put this first in the resulting
// prefix product vec
let product_vec = [vec![one], products].concat();
let coeffs =
AuthenticatedScalarResult::ifft::<Radix2EvaluationDomain<C::ScalarField>>(&product_vec);

Ok(AuthenticatedDensePoly::from_coeffs(coeffs))
}

fn compute_wire_polynomials(&self) -> Result<Vec<AuthenticatedDensePoly<C>>, MpcCircuitError> {
self.check_finalize_flag(true)?;
let domain = &self.eval_domain;
if domain.size() < self.num_gates() {
return Err(MpcCircuitError::ConstraintSystem(
CircuitError::ParameterError(format!(
"Domain size {} should be bigger than number of constraint {}",
domain.size(),
self.num_gates()
)),
));
}

let witness = &self.witness;
let wire_polys: Vec<AuthenticatedDensePoly<C>> =
parallelizable_slice_iter(&self.wire_variables)
.take(self.num_wire_types())
.map(|wire_vars| {
let wire_vec: Vec<AuthenticatedScalarResult<C>> = wire_vars
.iter()
.map(|&var| witness[var].clone())
.collect_vec();

let coeffs = AuthenticatedScalarResult::ifft::<
Radix2EvaluationDomain<C::ScalarField>,
>(&wire_vec);

AuthenticatedDensePoly::from_coeffs(coeffs)
})
.collect();

assert_eq!(wire_polys.len(), self.num_wire_types());
Ok(wire_polys)
}

fn compute_pub_input_polynomial(&self) -> Result<AuthenticatedDensePoly<C>, MpcCircuitError> {
self.check_finalize_flag(true)?;

let domain = &self.eval_domain;
let mut pub_input_vec = self.fabric.zeros_authenticated(domain.size());

self.pub_input_gate_ids.iter().for_each(|&io_gate_id| {
let var = self.wire_variables[GATE_WIDTH][io_gate_id];
pub_input_vec[io_gate_id] = self.witness[var].clone();
});

let coeffs = AuthenticatedScalarResult::ifft::<Radix2EvaluationDomain<C::ScalarField>>(
&pub_input_vec,
);
Ok(AuthenticatedDensePoly::from_coeffs(coeffs))
}
}
2 changes: 2 additions & 0 deletions relation/src/constraint_system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1119,8 +1119,10 @@ impl<F: PrimeField> PlonkCircuit<F> {
// Compute the extended identity permutation
// id[i*n+j] = k[i] * g^j
let k: Vec<F> = compute_coset_representatives(self.num_wire_types, Some(n));

// Precompute domain elements
let group_elems: Vec<F> = self.eval_domain.elements().collect();

// Compute extended identity permutation
self.extended_id_permutation = vec![F::zero(); self.num_wire_types * n];
for (i, &coset_repr) in k.iter().enumerate() {
Expand Down

0 comments on commit b65fae7

Please sign in to comment.